To save variables from a TensorFlow session for future use, you can use the tf.train.Saver class. Let's start by creating a saver right after the writer variable:
writer = tf.summary.FileWriter(log_location, session.graph)
saver = tf.train.Saver(max_to_keep=5)
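The max_to_keep=5 argument tells the Saver to retain only the five most recent checkpoint files, deleting older ones as new checkpoints are written. TensorFlow handles this rotation internally; as a rough pure-Python sketch of the behaviour (not the Saver implementation itself, and the file names are illustrative):

```python
from collections import deque

def keep_latest(paths, max_to_keep=5):
    # Mimic Saver's max_to_keep: retain only the newest checkpoints,
    # discarding the oldest once the limit is exceeded.
    kept = deque(maxlen=max_to_keep)
    for path in paths:
        kept.append(path)
    return list(kept)

# Writing ten checkpoints with max_to_keep=5 leaves only the last five.
paths = ["model.ckpt-%d" % step for step in range(0, 1000, 100)]
print(keep_latest(paths))
```

This keeps disk usage bounded during long training runs while still leaving a few recent checkpoints to roll back to.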
Then, in the training loop, we will add the following code to save the model after every model_saving_step:
if step % model_saving_step == 0 or step == num_steps + 1:
    path = saver.save(session,
                      os.path.join(log_location, "model.ckpt"),
                      global_step=step)
    logmanager.logger.info('Model saved in file: %s' % path)
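Passing global_step=step makes the Saver append the step number to the file name (for example, model.ckpt-250), so each checkpoint is distinguishable. To see which steps the condition above actually fires on, here is a small pure-Python sketch; the values of num_steps and model_saving_step are illustrative assumptions, not from the original training script:

```python
def save_steps(num_steps, model_saving_step):
    # Steps at which the training-loop condition above would write a
    # checkpoint: every model_saving_step steps, plus the final step.
    saves = []
    for step in range(num_steps + 2):  # the loop's last step is num_steps + 1
        if step % model_saving_step == 0 or step == num_steps + 1:
            saves.append(step)
    return saves

print(save_steps(num_steps=1000, model_saving_step=250))
# [0, 250, 500, 750, 1000, 1001]
```

The extra step == num_steps + 1 check guarantees a checkpoint of the fully trained model even when num_steps is not a multiple of model_saving_step.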
After that, whenever we want to restore the model using the saved model, we can easily create a new Saver() instance and use the restore function as follows:
checkpoint_path = tf.train.latest_checkpoint(log_location)
restorer = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    restorer.restore(sess, checkpoint_path)
In the preceding code, we use tf.train.latest_checkpoint so that TensorFlow automatically chooses the latest model checkpoint in log_location. Then, we create a new Saver instance named restorer. Finally, we call its restore method to load the saved model into the session graph:
restorer.restore(sess, checkpoint_path)
You should note that we must call restore after running tf.global_variables_initializer. Otherwise, the restored variable values would be overridden by the initializer.
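Under the hood, tf.train.latest_checkpoint consults the "checkpoint" protocol file that the Saver maintains in the log directory. Since the checkpoints saved above carry their global step as a suffix, the selection effectively picks the highest step number; a minimal pure-Python sketch of that ordering (an assumption about the naming scheme, not TensorFlow's actual lookup):

```python
def latest_checkpoint(paths):
    # Pick the checkpoint with the highest global_step suffix, which is
    # the ordering tf.train.latest_checkpoint effectively follows for
    # checkpoints named "model.ckpt-<step>".
    return max(paths, key=lambda path: int(path.rsplit("-", 1)[1]))

ckpts = ["model.ckpt-100", "model.ckpt-900", "model.ckpt-500"]
print(latest_checkpoint(ckpts))  # model.ckpt-900
```

Relying on the helper rather than hard-coding a path means the restore code keeps working as training progresses and older checkpoints are rotated out.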