Complete source

In this example, we will load the weights from the MNIST convolutional autoencoder example. We will restore the weights of the encoder part only, freeze the CONV layers, and train the FC layers to perform digit classification:

import tensorflow as tf 
import numpy as np 
import os 
from models import CAE_CNN_Encoder

SAVE_FOLDER = '/tmp/cae_cnn_transfer'

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

model = CAE_CNN_Encoder(latent_size=20)
model_in = model.input
model_out = model.output
labels_in = model.labels

# Get all conv weights
list_convs = [v for v in tf.global_variables() if "conv" in v.name]

# Get fc1 and logits
list_fc_layers = [v for v in tf.global_variables() if "fc" in v.name or "logits" in v.name]

# Define the saver object to load only the conv variables
saver_load_autoencoder = tf.train.Saver(var_list=list_convs)

# Define saver object to save all the variables during training
saver = tf.train.Saver()

# Define loss for classification
with tf.name_scope("LOSS"): loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=model_out, labels=labels_in)) correct_prediction = tf.equal(tf.argmax(model_out,1), tf.argmax(labels_in,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # Solver configuration with tf.name_scope("Solver"): train_step = tf.train.AdamOptimizer(1e-4).minimize(loss, var_list=list_fc_layers)

# Initialize variables
init = tf.global_variables_initializer()

# Avoid allocating the whole GPU memory
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.200)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
sess.run(init)

# Restore only the CONV weights (from the autoencoder)
saver_load_autoencoder.restore(sess, "/tmp/cae_cnn/model.ckpt-34")

# Add some tensors to observe on TensorBoard
tf.summary.image("input_image", model.image_in, 4)
tf.summary.scalar("loss", loss)
merged_summary = tf.summary.merge_all()
writer = tf.summary.FileWriter(SAVE_FOLDER)
writer.add_graph(sess.graph)

##### Train #####
num_epoch = 200
batch_size = 10
for epoch in range(num_epoch):
    for i in range(int(mnist.train.num_examples / batch_size)):
        # Get a batch of batch_size images and labels
        batch = mnist.train.next_batch(batch_size)
        # Dump summaries periodically
        if i % 5000 == 0:
            s = sess.run(merged_summary, feed_dict={model_in: batch[0], labels_in: batch[1]})
            writer.add_summary(s, i)
        # Run the actual training step (also get the loss value)
        _, val_loss, t_acc = sess.run((train_step, loss, accuracy),
                                      feed_dict={model_in: batch[0], labels_in: batch[1]})
    # Report the loss once per epoch and checkpoint all variables
    print('Epoch: %d/%d loss:%f' % (epoch, num_epoch, val_loss))
    print('Save model:', epoch)
    saver.save(sess, os.path.join(SAVE_FOLDER, "model.ckpt"), epoch)
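
The script imports CAE_CNN_Encoder from models, which was built in the convolutional autoencoder example. For reference, here is a minimal sketch of the interface the training script relies on; the layer sizes and exact architecture are illustrative assumptions, not the book's exact model. What matters is that the class exposes input, output, labels, and image_in attributes, and that its variable names contain "conv" for the encoder layers and "fc"/"logits" for the classifier head, so the two variable lists above pick them up:

import tensorflow as tf

class CAE_CNN_Encoder(object):
    # Minimal sketch (assumption): the real class lives in models.py and was
    # defined in the autoencoder example; layer sizes here may differ from it.
    def __init__(self, latent_size=20):
        self.__x = tf.placeholder(tf.float32, shape=[None, 784], name='IMAGE_IN')
        self.__x_image = tf.reshape(self.__x, [-1, 28, 28, 1])
        self.__y = tf.placeholder(tf.float32, shape=[None, 10], name='Y')

        with tf.name_scope('ENCODER'):
            # Variable names contain "conv", so list_convs picks them up
            conv1 = tf.layers.conv2d(self.__x_image, 16, (3, 3), strides=(2, 2),
                                     padding='same', activation=tf.nn.relu, name='conv1')
            conv2 = tf.layers.conv2d(conv1, 32, (3, 3), strides=(2, 2),
                                     padding='same', activation=tf.nn.relu, name='conv2')

        with tf.name_scope('CLASSIFIER'):
            flat = tf.layers.flatten(conv2)
            # Variable names contain "fc"/"logits", so list_fc_layers picks them up
            fc1 = tf.layers.dense(flat, latent_size, activation=tf.nn.relu, name='fc1')
            self.__logits = tf.layers.dense(fc1, 10, name='logits')

    @property
    def input(self):
        return self.__x

    @property
    def image_in(self):
        return self.__x_image

    @property
    def output(self):
        return self.__logits

    @property
    def labels(self):
        return self.__y

Once training starts writing summaries, you can inspect the loss curve and the input images by pointing TensorBoard at the save folder with tensorboard --logdir /tmp/cae_cnn_transfer.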