In text generation, the way we choose the next character is crucial. The most obvious approach, greedy sampling, always picks the most likely character, which leads to repetitive output that does not read like coherent language. This is why we use a different approach called stochastic sampling, which introduces a degree of randomness by drawing the next character from the predicted probability distribution.
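To make the difference concrete, here is a minimal sketch that contrasts the two strategies on a toy distribution. The three-character vocabulary, the probabilities, and the helper names `greedy_pick` and `stochastic_pick` are illustrative assumptions, not part of the model used in this article.

```python
import random

# Hypothetical next-character distribution over a tiny vocabulary.
chars = ['a', 'b', 'c']
probs = [0.5, 0.3, 0.2]

def greedy_pick(probs):
    """Always return the index of the most probable character."""
    return max(range(len(probs)), key=lambda i: probs[i])

def stochastic_pick(probs, rng):
    """Draw an index at random, weighted by the predicted probabilities."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

rng = random.Random(0)  # fixed seed so the run is repeatable
greedy_draws = [chars[greedy_pick(probs)] for _ in range(10)]
stochastic_draws = [chars[stochastic_pick(probs, rng)] for _ in range(10)]

print('greedy:    ', ''.join(greedy_draws))      # the same character every time
print('stochastic:', ''.join(stochastic_draws))  # drawn in proportion to probs
```

Greedy selection can never emit anything but the top-ranked character for a given input, whereas the stochastic draw lets less likely characters appear some of the time.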
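The amount of randomness is controlled by a temperature parameter: lower temperatures sharpen the distribution and make the output more predictable, while higher temperatures flatten it and make the output more surprising. Use the following code to re-weight the prediction probability distribution and sample a character index: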
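import numpy as np

def sample(preds, temperature=1.0):
    # Work in float64 to limit rounding error in the log/exp steps
    preds = np.asarray(preds).astype('float64')
    # Divide the log-probabilities by the temperature
    preds = np.log(preds) / temperature
    # Exponentiate and renormalize so the values sum to 1 again
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    # Draw one sample; the result is a one-hot vector
    probas = np.random.multinomial(1, preds, 1)
    # Return the index of the sampled character
    return np.argmax(probas)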
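To see what the re-weighting step does on its own, the sketch below applies the same log, divide, exponentiate, normalize sequence to a toy distribution, without the random draw. The helper name `reweight` and the example probabilities are assumptions made for illustration, and plain Python is used so the snippet runs without dependencies.

```python
import math

def reweight(preds, temperature=1.0):
    """Temperature-scale a distribution (same math as sample, minus the draw)."""
    scaled_logs = [math.log(p) / temperature for p in preds]
    exps = [math.exp(v) for v in scaled_logs]
    total = sum(exps)
    return [e / total for e in exps]

preds = [0.5, 0.3, 0.2]  # hypothetical model output
for temperature in [0.2, 0.5, 1.0, 1.2]:
    scaled = reweight(preds, temperature)
    print(temperature, [round(p, 3) for p in scaled])
```

At a temperature of 1.0 the distribution is unchanged. Below 1.0 the most likely character takes a larger share, and above 1.0 the shares move closer together.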
Now we alternate between training and text generation for 30 epochs, fitting the model for one epoch at a time. After each fit, we select a seed text at random, convert it to one-hot encoding, and generate 100 characters, appending each newly generated character to the seed text as we go.
After each epoch, text is generated at several different temperatures. This makes it possible to see how the generated text evolves as the model converges, and how the temperature affects the sampling strategy.
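The one-hot conversion can be looked at in isolation. The following is a minimal sketch with a toy vocabulary; the helper name `one_hot_encode` is hypothetical, and nested lists stand in for the NumPy array of shape (1, maxlen, len(chars)) that the model expects.

```python
def one_hot_encode(window, char_indices):
    """Encode a string as a batch of one: shape (1, len(window), vocab size)."""
    vocab_size = len(char_indices)
    encoded = [[0.0] * vocab_size for _ in window]
    for t, char in enumerate(window):
        # Set the position of this character to 1, leave the rest at 0
        encoded[t][char_indices[char]] = 1.0
    return [encoded]

chars = sorted(set('hello'))  # ['e', 'h', 'l', 'o']
char_indices = {c: i for i, c in enumerate(chars)}

batch = one_hot_encode('hell', char_indices)
for row in batch[0]:
    print(row)
```

Each row holds a single 1 at the index of the corresponding character, so the model receives one vector per time step.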
import random
import sys

for epoch in range(1, 31):  # 30 epochs, one fit per iteration
    print('epoch', epoch)
    # Fit the model for 1 epoch
    model.fit(x, y, batch_size=128, epochs=1, callbacks=callbacks_list)
    # Select a text seed randomly
    start_index = random.randint(0, len(text) - maxlen - 1)
    seed_text = text[start_index: start_index + maxlen]
    print('---Seeded text: "' + seed_text + '"')
    for temperature in [0.2, 0.5, 1.0, 1.2]:
        print('------ Selected temperature:', temperature)
        # Restart from the same seed so the temperatures are comparable
        generated_text = seed_text
        sys.stdout.write(generated_text)
        # We generate 100 characters
        for i in range(100):
            # One-hot encode the current window of maxlen characters
            sampled = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(generated_text):
                sampled[0, t, char_indices[char]] = 1.
            # Predict the next-character distribution and sample from it
            preds = model.predict(sampled, verbose=0)[0]
            next_index = sample(preds, temperature)
            next_char = chars[next_index]
            # Slide the window: append the new character, drop the oldest
            generated_text += next_char
            generated_text = generated_text[1:]
            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()
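The window handling inside the generation loop (append the new character, then drop the oldest one) can be tried without a trained model. This is only a sketch: `generate` and `echo_first` are hypothetical names, and `echo_first` is a stand-in for the `model.predict` plus `sample` step.

```python
def generate(seed, next_char_fn, length):
    """Generate `length` characters while keeping the input window at len(seed)."""
    window = seed
    output = []
    for _ in range(length):
        next_char = next_char_fn(window)
        output.append(next_char)
        # Append the new character and drop the oldest one
        window = (window + next_char)[1:]
    return ''.join(output), window

def echo_first(window):
    """Stand-in predictor: repeat the oldest character in the window."""
    return window[0]

generated, final_window = generate('abc', echo_first, 5)
print(generated)          # abcab
print(final_window)       # cab
print(len(final_window))  # 3, the same length as the seed
```

Because the window never grows, the input passed to the model always has exactly maxlen characters, which is what the one-hot encoding step relies on.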