Training the seq2seq model

Now, let's train the model. We will need some helper functions for padding the sentences and for calculating the accuracy of the model.

def pad_sentence_batch(sentence_batch, pad_int, max_sentence_len=50):
    """Pad every sentence in a batch to a fixed length.

    Args:
        sentence_batch: list of sentences, each a list of token ids.
        pad_int: integer id of the PAD token.
        max_sentence_len: fixed target length for every sentence
            (default 50, matching the original hard-coded value).

    Returns:
        A tuple ``(padded_seqs, seq_lens)`` where ``padded_seqs`` is the
        batch with every sentence exactly ``max_sentence_len`` long and
        ``seq_lens`` holds ``max_sentence_len`` once per sentence.
    """
    padded_seqs = []
    seq_lens = []
    for sentence in sentence_batch:
        # Truncate over-length sentences: the original computed a negative
        # padding count for them, leaving the row longer than the target
        # length and breaking the fixed-shape batch fed to the model.
        sentence = sentence[:max_sentence_len]
        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))
        # Use the named length instead of the duplicated magic number 50.
        seq_lens.append(max_sentence_len)
    return padded_seqs, seq_lens

def check_accuracy(logits, Y):
    """Mean per-sentence token accuracy of predicted ids against targets.

    Args:
        logits: 2-D array-like of predicted token ids, one row per sentence
            (e.g. the result of ``tf.argmax(model.logits, 2)``).
        Y: target batch; ``Y[i]`` must be at most as long as ``logits[i]``.

    Returns:
        Float in ``[0, 1]``: the average over sentences of the fraction of
        positions where prediction equals target. Returns ``0.0`` for an
        empty batch (the original raised ZeroDivisionError).
    """
    batch_size = logits.shape[0]
    if batch_size == 0:
        return 0.0  # guard: original divided by zero on an empty batch
    acc = 0.0
    for pred_row, target_row in zip(logits, Y):
        if len(target_row) == 0:
            # Empty target contributes zero accuracy instead of crashing.
            continue
        matches = sum(1 for p, t in zip(pred_row, target_row) if p == t)
        acc += matches / len(target_row)
    return acc / batch_size

We initialize our model and iterate the session for the defined number of epochs.

# Build the model and train for the configured number of epochs.
tf.reset_default_graph()
sess = tf.InteractiveSession()
model = Chatbot(size_layer, num_layers, embedded_size, vocabulary_size_from + 4,
                vocabulary_size_to + 4, learning_rate, batch_size)
sess.run(tf.global_variables_initializer())

# Create the prediction op ONCE, outside the loops. The original called
# tf.argmax(model.logits, 2) inside the inner loop, which in TF1 adds a new
# node to the graph on every batch — the graph grows without bound and
# training slows down over time.
predict_op = tf.argmax(model.logits, 2)
# Number of full batches per epoch; hoisted out of the loops since it is
# invariant (the original recomputed len(text_from) // batch_size three times).
num_batches = len(text_from) // batch_size

for i in range(epoch):
    total_loss, total_accuracy = 0, 0
    # Step through the data in full batches only; the trailing remainder
    # (fewer than batch_size examples) is dropped, as in the original.
    for k in range(0, num_batches * batch_size, batch_size):
        batch_x, seq_x = pad_sentence_batch(X[k: k + batch_size], PAD)
        batch_y, seq_y = pad_sentence_batch(Y[k: k + batch_size], PAD)
        predicted, loss, _ = sess.run([predict_op, model.cost, model.optimizer],
                                      feed_dict={model.X: batch_x,
                                                 model.Y: batch_y,
                                                 model.X_seq_len: seq_x,
                                                 model.Y_seq_len: seq_y})
        total_loss += loss
        total_accuracy += check_accuracy(predicted, batch_y)
    total_loss /= num_batches
    total_accuracy /= num_batches
    print('epoch: %d, avg loss: %f, avg accuracy: %f' % (i + 1, total_loss, total_accuracy))

OUTPUT: epoch: 47, avg loss: 0.682934, avg accuracy: 0.000000 epoch: 48, avg loss: 0.680367, avg accuracy: 0.000000 epoch: 49, avg loss: 0.677882, avg accuracy: 0.000000 epoch: 50, avg loss: 0.678484, avg accuracy: 0.000000
.
.
.
epoch: 1133, avg loss: 0.000464, avg accuracy: 1.000000
epoch: 1134, avg loss: 0.000462, avg accuracy: 1.000000
epoch: 1135, avg loss: 0.000460, avg accuracy: 1.000000
epoch: 1136, avg loss: 0.000457, avg accuracy: 1.000000
..................Content has been hidden....................

You can't read the all page of ebook, please click here login for view all page.
Reset
18.223.170.223