Inference

Once the model is trained, we can use it to make predictions. We start by defining all the parameters. As with the previous model, inference needs some seed text to start from. Along with the seed, we provide the path of the vocabulary file and the path of the output file in which the generated lyrics will be stored, as well as the length of the text to generate:

import argparse
import codecs

import numpy as np
import tensorflow as tf

from modules.Model import *
from modules.Preprocessing import *
from collections import deque

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--vocabulary_file', type=str, required=True)
    parser.add_argument('--output_file', type=str, required=True)

    parser.add_argument('--seed', type=str, default="Yeah, oho ")
    parser.add_argument('--sample_length', type=int, default=1500)
    parser.add_argument('--log_frequency', type=int, default=100)

    args = parser.parse_args()
    model_name = args.model_name
    vocabulary_file = args.vocabulary_file
    output_file = args.output_file
    seed = args.seed
    sample_length = args.sample_length
    log_frequency = args.log_frequency
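As a quick sanity check, the parser defined above can be exercised in isolation by passing an explicit argument list instead of reading `sys.argv`; the flag values shown here are illustrative, not from the book:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, required=True)
parser.add_argument('--vocabulary_file', type=str, required=True)
parser.add_argument('--output_file', type=str, required=True)
parser.add_argument('--seed', type=str, default="Yeah, oho ")
parser.add_argument('--sample_length', type=int, default=1500)
parser.add_argument('--log_frequency', type=int, default=100)

# Parse an explicit list of arguments; flags that are omitted fall
# back to their declared defaults.
args = parser.parse_args([
    '--model_name', 'lyrics_model',
    '--vocabulary_file', 'vocab.pkl',
    '--output_file', 'generated.txt',
])
print(args.seed, args.sample_length)
```

Only the three required flags are supplied, so `seed`, `sample_length`, and `log_frequency` keep their defaults.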

Next, we load the model by providing the model name that we used in the training step, and we restore the vocabulary from the vocabulary file:

    model = Model(model_name)
    model.restore()
    classifier = model.get_classifier()

    vocabulary = Preprocessing()
    vocabulary.retrieve(vocabulary_file)
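The internals of the restored vocabulary are not shown in this excerpt, but the generation loop below relies on two lookups: `binary_vocabulary`, mapping each character to a one-hot vector, and `char_lookup`, mapping an index back to its character. A minimal sketch of how such structures could be built (the two names come from the code; the rest is illustrative):

```python
# Build a toy character vocabulary from a small corpus
chars = sorted(set("hello world"))

binary_vocabulary = {}  # char -> one-hot vector
char_lookup = {}        # index -> char
for index, char in enumerate(chars):
    vector = [0.0] * len(chars)
    vector[index] = 1.0
    binary_vocabulary[char] = vector
    char_lookup[index] = char

# Round trip: encode a character, then decode it via its argmax index
encoded = binary_vocabulary['h']
decoded = char_lookup[encoded.index(max(encoded))]
print(decoded)
```

The argmax-based decode mirrors what the inference loop does with the classifier's prediction vector.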

We use a deque as a sliding window over the text: the generated characters are appended to it, and the same window is fed back into the model in an iterative fashion:

    # Preparing the raw input data
    stack = deque([])
    sample_file = codecs.open(output_file, 'w', 'utf_8')

    for char in seed:
        if char not in vocabulary.vocabulary:
            print(char, "is not in the vocabulary file")
            char = u' '
        stack.append(char)
        sample_file.write(char)

    # Restoring the model and making inferences
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(model_name)

        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        for i in range(0, sample_length):
            # Encode the current window as a batch of one-hot vectors
            vector = []
            for char in stack:
                vector.append(vocabulary.binary_vocabulary[char])
            vector = np.array([vector])
            prediction = sess.run(classifier, feed_dict={model.x: vector})
            predicted_char = vocabulary.char_lookup[np.argmax(prediction)]

            # Slide the window: drop the oldest character, append the new one
            stack.popleft()
            stack.append(predicted_char)
            sample_file.write(predicted_char)

            if i % log_frequency == 0:
                print("Progress: {}%".format((i * 100) // sample_length))

    sample_file.close()
    print("Sample saved in {}".format(output_file))
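The sliding-window mechanics of the loop above can be illustrated without TensorFlow. In this sketch, `predict` is a hypothetical stand-in for the trained classifier (it simply cycles to the next character in a toy vocabulary), and `deque` with `maxlen` performs the same drop-oldest/append-newest step as the `popleft`/`append` pair in the code:

```python
from collections import deque

vocab = ['a', 'b', 'c', 'd']

def predict(window):
    # Stand-in for sess.run(classifier, ...): deterministically pick
    # the character that follows the last one in the window.
    last_index = vocab.index(window[-1])
    return (last_index + 1) % len(vocab)

seed = "ab"
stack = deque(seed, maxlen=len(seed))  # fixed-size sliding window

generated = []
for _ in range(6):
    predicted_char = vocab[predict(stack)]
    stack.append(predicted_char)  # maxlen drops the oldest char automatically
    generated.append(predicted_char)

print(''.join(generated))  # 'cdabcd'
```

Each prediction is appended to the window and the oldest character falls out, so the model always conditions on the most recent characters only.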