Creating the network 

Let's look at the code first and then walk through it. You may be surprised at how familiar the code looks:

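import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class IMDBRnn(nn.Module):

    def __init__(self,n_vocab,hidden_size,n_cat,bs=1,nl=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.bs = bs   # batch size
        self.nl = nl   # number of stacked LSTM layers
        self.e = nn.Embedding(n_vocab,hidden_size)
        self.rnn = nn.LSTM(hidden_size,hidden_size,nl)
        self.fc2 = nn.Linear(hidden_size,n_cat)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self,inp):
        # inp has shape [sequence_length, batch_size]
        bs = inp.size()[1]
        if bs != self.bs:
            self.bs = bs
        e_out = self.e(inp)
        # Zero initial hidden and cell states with the same type as the embeddings
        h0 = c0 = Variable(e_out.data.new(*(self.nl,self.bs,self.hidden_size)).zero_())
        rnn_o,_ = self.rnn(e_out,(h0,c0))
        # Keep only the output of the last time step
        rnn_o = rnn_o[-1]
        # Pass self.training so that dropout is switched off in eval mode
        fc = F.dropout(self.fc2(rnn_o),p=0.8,training=self.training)
        return self.softmax(fc)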

The init method creates an embedding layer with one row per vocabulary entry and an embedding dimension of hidden_size. It also creates an LSTM and a linear layer. The last layer is a LogSoftmax layer, which converts the results from the linear layer to log-probabilities.
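To make the role of the last layer concrete, here is a minimal sketch of what LogSoftmax returns. The scores are made-up values for illustration, not outputs of the model. The layer produces log-probabilities, so exponentiating each row gives values that sum to 1; this is the form that a negative log-likelihood loss such as nn.NLLLoss expects.

```python
import torch
import torch.nn as nn

log_softmax = nn.LogSoftmax(dim=-1)

# Hypothetical raw scores for two examples and three categories
scores = torch.tensor([[2.0, 0.5, -1.0],
                       [0.0, 0.0, 0.0]])

log_probs = log_softmax(scores)   # log-probabilities, all <= 0
probs = log_probs.exp()           # ordinary probabilities

print(log_probs)
print(probs.sum(dim=-1))          # each row sums to 1, up to floating-point error
```
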

In the forward function, we pass input data of size [200, 32], that is, [sequence_length, batch_size]. The embedding layer replaces each token in the batch with its embedding, so the size becomes [200, 32, 100], where 100 is the embedding dimension. The LSTM layer takes the output of the embedding layer along with two state variables, the initial hidden state and the initial cell state. These must be of the same type as the embedding output, and their size should be [num_layers, batch_size, hidden_size]. The LSTM processes the data as a sequence and generates an output of shape [sequence_length, batch_size, hidden_size], where each index along the first dimension holds the output for that time step. In this case, we take only the output of the last time step, which has shape [batch_size, hidden_size], and pass it to a linear layer that maps it to the output categories. Since the model tends to overfit, we add dropout. You can experiment with the dropout probability.
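The shapes described above can be checked with a short standalone sketch. The sequence length, batch size, and hidden size are the ones quoted in the text; the vocabulary size of 1000, the 3 output categories, and the random token ids are assumptions chosen only for illustration.

```python
import torch
import torch.nn as nn

seq_len, batch_size, hidden_size = 200, 32, 100   # sizes quoted in the text
n_vocab, n_cat, n_layers = 1000, 3, 2             # illustrative assumptions

embedding = nn.Embedding(n_vocab, hidden_size)
lstm = nn.LSTM(hidden_size, hidden_size, n_layers)
linear = nn.Linear(hidden_size, n_cat)

# Random token ids standing in for a real batch: [seq_len, batch_size]
tokens = torch.randint(0, n_vocab, (seq_len, batch_size))

embedded = embedding(tokens)
# Initial hidden and cell states: [num_layers, batch_size, hidden_size]
h0 = embedded.new_zeros(n_layers, batch_size, hidden_size)
c0 = embedded.new_zeros(n_layers, batch_size, hidden_size)

outputs, _ = lstm(embedded, (h0, c0))
last_step = outputs[-1]           # output of the last time step
scores = linear(last_step)

print(tuple(tokens.shape))        # (200, 32)
print(tuple(embedded.shape))      # (200, 32, 100)
print(tuple(outputs.shape))       # (200, 32, 100)
print(tuple(last_step.shape))     # (32, 100)
print(tuple(scores.shape))        # (32, 3)
```
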
