batch_first in PyTorch LSTM

Issue

I am new to this field and still don’t understand how batch_first works in PyTorch’s LSTM. I tried some code that someone referred me to, and on my training data with batch_first = False it produces the same output for the official LSTM and the manual LSTM. However, when I change to batch_first = True, the two outputs no longer match. I need batch_first = True because my dataset is a tensor of shape (Batch, Sequence, Input size). Which part of the manual LSTM needs to be changed so that it produces the same output as the official LSTM when batch_first = True? Here is the code snippet:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# One batch holding a single 30-step sequence with one feature per step,
# i.e. a tensor of shape (1, 30, 1).
train_x = torch.tensor(
    [[[0.14285755], [0.0], [0.04761982], [0.04761982], [0.04761982],
      [0.04761982], [0.04761982], [0.09523869], [0.09523869], [0.09523869],
      [0.09523869], [0.09523869], [0.04761982], [0.04761982], [0.04761982],
      [0.04761982], [0.09523869], [0.0], [0.0], [0.0],
      [0.0], [0.09523869], [0.09523869], [0.09523869], [0.09523869],
      [0.09523869], [0.09523869], [0.09523869], [0.14285755], [0.14285755]]],
    requires_grad=True,
)

# Seed both RNGs so the random input-to-hidden weights are reproducible.
seed = 23
torch.manual_seed(seed)
np.random.seed(seed)

# Reference implementation: one layer, one direction, input size 1, hidden size 1.
pytorch_lstm = torch.nn.LSTM(1, 1, bidirectional=False, num_layers=1, batch_first=True)
weights = torch.randn(pytorch_lstm.weight_ih_l0.shape, dtype=torch.float)
pytorch_lstm.weight_ih_l0 = torch.nn.Parameter(weights)
# Set bias to Zero
pytorch_lstm.bias_ih_l0 = torch.nn.Parameter(torch.zeros(pytorch_lstm.bias_ih_l0.shape))
pytorch_lstm.weight_hh_l0 = torch.nn.Parameter(torch.ones(pytorch_lstm.weight_hh_l0.shape))
# Set bias to Zero (sized from the hidden-to-hidden bias itself)
pytorch_lstm.bias_hh_l0 = torch.nn.Parameter(torch.zeros(pytorch_lstm.bias_hh_l0.shape))
pytorch_lstm_out = pytorch_lstm(train_x)

batch_size = 1

# Manual Calculation
# nn.LSTM stores the four gates stacked along dim 0 in the order
# input (i), forget (f), cell (g), output (o); split them back apart.
W_ii, W_if, W_ig, W_io = pytorch_lstm.weight_ih_l0.split(1, dim=0)
b_ii, b_if, b_ig, b_io = pytorch_lstm.bias_ih_l0.split(1, dim=0)

W_hi, W_hf, W_hg, W_ho = pytorch_lstm.weight_hh_l0.split(1, dim=0)
b_hi, b_hf, b_hg, b_ho = pytorch_lstm.bias_hh_l0.split(1, dim=0)

# Zero initial hidden and cell state. The original snippet read an undefined
# name `batchsize` here; it must be the `batch_size` defined above.
prev_h = torch.zeros((batch_size, 1))
prev_c = torch.zeros((batch_size, 1))

# NOTE: the gates are evaluated once on the whole train_x tensor against the
# all-zero prev_h / prev_c. No state is carried from one time step to the
# next, so every position is treated as if it were the first step.
i_t = torch.sigmoid(F.linear(train_x, W_ii, b_ii) + F.linear(prev_h, W_hi, b_hi))
f_t = torch.sigmoid(F.linear(train_x, W_if, b_if) + F.linear(prev_h, W_hf, b_hf))
g_t = torch.tanh(F.linear(train_x, W_ig, b_ig) + F.linear(prev_h, W_hg, b_hg))
o_t = torch.sigmoid(F.linear(train_x, W_io, b_io) + F.linear(prev_h, W_ho, b_ho))
c_t = f_t * prev_c + i_t * g_t
h_t = o_t * torch.tanh(c_t)

# pytorch_lstm_out is (output, (h_n, c_n)).
print('nn.LSTM output {}, manual output {}'.format(pytorch_lstm_out[0], h_t))
print('nn.LSTM hidden {}, manual hidden {}'.format(pytorch_lstm_out[1][0], h_t))
print('nn.LSTM state {}, manual state {}'.format(pytorch_lstm_out[1][1], c_t))

Solution

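You have to iterate over the sequence one element at a time and feed the hidden and cell states computed at each time step into the next time step. (With batch_first = False the (1, 30, 1) input is read as 30 independent sequences of length 1, which is why the single-step calculation happened to match; with batch_first = True it is one sequence of 30 steps.)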

# Recurrent manual LSTM: process the sequence one time step at a time, feeding
# the hidden and cell state produced at step t into step t + 1.
h_t = torch.zeros((batch_size, 1))
c_t = torch.zeros((batch_size, 1))

hidden_seq = []

# train_x is indexed as [batch, time step, feature] below, so the number of
# time steps is the size of dimension 1 (no hard-coded sequence length).
for t in range(train_x.size(1)):
    x_t = train_x[:, t, :]
    i_t = torch.sigmoid(F.linear(x_t, W_ii, b_ii) + F.linear(h_t, W_hi, b_hi))
    f_t = torch.sigmoid(F.linear(x_t, W_if, b_if) + F.linear(h_t, W_hf, b_hf))
    g_t = torch.tanh(F.linear(x_t, W_ig, b_ig) + F.linear(h_t, W_hg, b_hg))
    o_t = torch.sigmoid(F.linear(x_t, W_io, b_io) + F.linear(h_t, W_ho, b_ho))
    # The previous c_t / h_t are consumed above and then replaced here.
    c_t = f_t * c_t + i_t * g_t
    h_t = o_t * torch.tanh(c_t)
    hidden_seq.append(h_t)

# Stack the per-step hidden states along a new dimension 1 so the result is
# laid out batch-first, like the nn.LSTM output it is compared against.
hidden_seq = torch.stack(hidden_seq, dim=1)

print('nn.LSTM output {}, manual output {}'.format(pytorch_lstm_out[0], hidden_seq))

Answered By – Alexander Riedel

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.

Leave a Reply

(*) Required, Your email will not be published