Matching networks in TensorFlow

Now, we will see how to build a matching network in TensorFlow, step by step; the complete code is given at the end of the section.

First, we import the libraries. Note that this code uses TensorFlow 1.x, since it relies on the tf.contrib module:

import tensorflow as tf
slim = tf.contrib.slim
rnn = tf.contrib.rnn

Next, we define a class called Matching_network, in which we build our network:

class Matching_network():

We define the __init__ method, where we store the hyperparameters and define placeholders for the support set and query set:


def __init__(self, lr, n_way, k_shot, batch_size=32):

    #store the hyperparameters, since they are used by the other methods
    self.lr = lr
    self.n_way = n_way
    self.k_shot = k_shot
    self.batch_size = batch_size

    #number of processing steps for the LSTM in f (10 is an illustrative choice)
    self.processing_steps = 10

    #placeholder for support set
    self.support_set_image = tf.placeholder(tf.float32, [None, n_way * k_shot, 28, 28, 1])
    self.support_set_label = tf.placeholder(tf.int32, [None, n_way * k_shot, ])

    #placeholder for query set
    self.query_image = tf.placeholder(tf.float32, [None, 28, 28, 1])
    self.query_label = tf.placeholder(tf.int32, [None, ])

Our support set and query set contain images. Before feeding these raw images to the embedding functions, we first extract features from them using a convolutional network, and then we feed the extracted features of the support set and query set to the embedding functions g and f, respectively.

So, we will define a function called image_encoder, which encodes the features of an image. We use a four-layer convolutional network with max pooling as our image encoder:


def image_encoder(self, image):

    with slim.arg_scope([slim.conv2d], num_outputs=64, kernel_size=3, normalizer_fn=slim.batch_norm):
        #conv1
        net = slim.conv2d(image)
        net = slim.max_pool2d(net, [2, 2])

        #conv2
        net = slim.conv2d(net)
        net = slim.max_pool2d(net, [2, 2])

        #conv3
        net = slim.conv2d(net)
        net = slim.max_pool2d(net, [2, 2])

        #conv4
        net = slim.conv2d(net)
        net = slim.max_pool2d(net, [2, 2])

    return tf.reshape(net, [-1, 1 * 1 * 64])
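
Once the class is complete, we can sanity check the encoder's output shape. The following is a quick, illustrative check (not part of the original code); the hyperparameter values are placeholders, and we simply build the encoder op on the query image placeholder:

model = Matching_network(lr=1e-3, n_way=5, k_shot=1)

#four rounds of 3 x 3 convolution and 2 x 2 max pooling shrink each
#28 x 28 x 1 image down to 1 x 1 x 64, which is flattened into a
#64-dimensional feature vector
encoded = model.image_encoder(model.query_image)
print(encoded.get_shape())   #(?, 64)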

Now, we define our embedding functions. We have already seen how the embedding functions f and g work in the Embedding function section, so we can define them directly as follows:

#embedding function for extracting support set embeddings
def g(self, x_i):

    forward_cell = rnn.BasicLSTMCell(32)
    backward_cell = rnn.BasicLSTMCell(32)

    #x_i is a list of encoded support set images; the bidirectional LSTM
    #produces a contextual embedding for each of them
    outputs, state_forward, state_backward = rnn.static_bidirectional_rnn(forward_cell, backward_cell, x_i, dtype=tf.float32)

    #add a skip connection from the original encodings to the LSTM outputs
    return tf.add(tf.stack(x_i), tf.stack(outputs))


#embedding function for extracting query set embeddings
def f(self, XHat, g_embedding):
    cell = rnn.BasicLSTMCell(64)
    prev_state = cell.zero_state(self.batch_size, tf.float32)

    for step in range(self.processing_steps):
        output, state = cell(XHat, prev_state)

        h_k = tf.add(output, XHat)

        #content-based attention over the support set embeddings
        content_based_attention = tf.nn.softmax(tf.multiply(prev_state[1], g_embedding))

        r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0)

        prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))

    return output

Now, we define a function called cosine_similarity for computing the cosine similarity between the support set and query set embeddings:

def cosine_similarity(self, target, support_set):
    target_normed = target
    sup_similarity = []
    for i in tf.unstack(support_set):
        i_normed = tf.nn.l2_normalize(i, 1)
        similarity = tf.matmul(tf.expand_dims(target_normed, 1), tf.expand_dims(i_normed, 2))
        sup_similarity.append(similarity)

    return tf.squeeze(tf.stack(sup_similarity, axis=1))
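
To get a feel for what this computes, here is a small NumPy illustration (not part of the network code) of one query embedding compared against a five-example support set; note that, as in the method above, only the support embeddings are L2-normalized:

import numpy as np

#one query embedding and five support embeddings (batch dimension omitted)
query = np.array([1.0, 0.0])
support = np.array([[2.0, 0.0], [0.0, 3.0], [1.0, 1.0], [-1.0, 0.0], [0.5, 0.5]])

#normalize each support embedding and take its dot product with the query,
#mirroring what cosine_similarity does for each support example
support_normed = support / np.linalg.norm(support, axis=1, keepdims=True)
similarity = support_normed.dot(query)
print(similarity)   #[ 1.  0.  0.707  -1.  0.707]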

Finally, we define a function called train to perform our training operation. Along with the support set and the query image, we also pass the query label so that we can compute the loss. Let's see this step by step:

def train(self, support_set_image, support_set_label, query_image, query_label):

First, we encode the features of support set images using our image encoder:

    support_set_image_encoded = [self.image_encoder(i) for i in tf.unstack(support_set_image, axis=1)]

Then, we will also encode the features of query set images using the image encoder:

    query_image_encoded = self.image_encoder(query_image)

Next, we will learn the embeddings of our support set using our embedding function, g:

     g_embedding = self.g(support_set_image_encoded) 

Similarly, we will also learn the embeddings of our query set using our embedding function, f:

    f_embedding = self.f(query_image_encoded, g_embedding) 

Now, we calculate cosine_similarity between both of these embeddings:

    embeddings_similarity = self.cosine_similarity(f_embedding, g_embedding) 

Then, we perform softmax attention over this similarity:

    attention = tf.nn.softmax(embeddings_similarity)

We predict a query set label by multiplying our attention matrix with one-hot encoded support set labels:

    y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))
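
To see what this multiplication does, here is a small NumPy illustration (not part of the network code) for a single query image in a 5-way, 1-shot episode:

import numpy as np

#attention weights over the five support examples
attention = np.array([0.1, 0.6, 0.1, 0.1, 0.1])

#support set labels, one example per class
support_set_label = np.array([0, 1, 2, 3, 4])
one_hot_labels = np.eye(5)[support_set_label]

#the predicted label distribution is the attention-weighted sum of the one-hot labels
y_hat = attention.dot(one_hot_labels)
print(y_hat)   #[0.1 0.6 0.1 0.1 0.1] -> class 1 is predicted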

Next, we get probabilities:

    self.probabilities = tf.squeeze(y_hat)

We select the index that has the highest probability as the class of the query image:

    self.predictions = tf.argmax(self.probabilities, 1)

Next, we define our loss function; we use softmax cross-entropy as our loss function:

    self.loss_op = tf.losses.sparse_softmax_cross_entropy(query_label, self.probabilities)

We minimize our loss function using AdamOptimizer:

    self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss_op)

Now, we will see the final code of our matching network as a whole:


class Matching_network():

    #initialize all the variables
    def __init__(self, lr, n_way, k_shot, batch_size=32):

        #store the hyperparameters, since they are used by the other methods
        self.lr = lr
        self.n_way = n_way
        self.k_shot = k_shot
        self.batch_size = batch_size

        #number of processing steps for the LSTM in f (10 is an illustrative choice)
        self.processing_steps = 10

        #placeholder for support set
        self.support_set_image = tf.placeholder(tf.float32, [None, n_way * k_shot, 28, 28, 1])
        self.support_set_label = tf.placeholder(tf.int32, [None, n_way * k_shot, ])

        #placeholder for query set
        self.query_image = tf.placeholder(tf.float32, [None, 28, 28, 1])
        self.query_label = tf.placeholder(tf.int32, [None, ])

    #encoder function for extracting features from the image
    def image_encoder(self, image):

        with slim.arg_scope([slim.conv2d], num_outputs=64, kernel_size=3, normalizer_fn=slim.batch_norm):
            #conv1
            net = slim.conv2d(image)
            net = slim.max_pool2d(net, [2, 2])

            #conv2
            net = slim.conv2d(net)
            net = slim.max_pool2d(net, [2, 2])

            #conv3
            net = slim.conv2d(net)
            net = slim.max_pool2d(net, [2, 2])

            #conv4
            net = slim.conv2d(net)
            net = slim.max_pool2d(net, [2, 2])

        return tf.reshape(net, [-1, 1 * 1 * 64])

    #embedding function for extracting support set embeddings
    def g(self, x_i):

        forward_cell = rnn.BasicLSTMCell(32)
        backward_cell = rnn.BasicLSTMCell(32)
        outputs, state_forward, state_backward = rnn.static_bidirectional_rnn(forward_cell, backward_cell, x_i, dtype=tf.float32)

        return tf.add(tf.stack(x_i), tf.stack(outputs))

    #embedding function for extracting query set embeddings
    def f(self, XHat, g_embedding):
        cell = rnn.BasicLSTMCell(64)
        prev_state = cell.zero_state(self.batch_size, tf.float32)

        for step in range(self.processing_steps):
            output, state = cell(XHat, prev_state)

            h_k = tf.add(output, XHat)

            content_based_attention = tf.nn.softmax(tf.multiply(prev_state[1], g_embedding))

            r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0)

            prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))

        return output

    #cosine similarity function for calculating the similarity between support set and query set embeddings
    def cosine_similarity(self, target, support_set):
        target_normed = target
        sup_similarity = []
        for i in tf.unstack(support_set):
            i_normed = tf.nn.l2_normalize(i, 1)
            similarity = tf.matmul(tf.expand_dims(target_normed, 1), tf.expand_dims(i_normed, 2))
            sup_similarity.append(similarity)

        return tf.squeeze(tf.stack(sup_similarity, axis=1))

    def train(self, support_set_image, support_set_label, query_image, query_label):

        #encode the features of query set images using our image encoder
        query_image_encoded = self.image_encoder(query_image)

        #encode the features of support set images using our image encoder
        support_set_image_encoded = [self.image_encoder(i) for i in tf.unstack(support_set_image, axis=1)]

        #generate support set embeddings using our embedding function g
        g_embedding = self.g(support_set_image_encoded)

        #generate query set embeddings using our embedding function f
        f_embedding = self.f(query_image_encoded, g_embedding)

        #calculate the cosine similarity between both of these embeddings
        embeddings_similarity = self.cosine_similarity(f_embedding, g_embedding)

        #perform softmax attention over the embedding similarity
        attention = tf.nn.softmax(embeddings_similarity)

        #predict the query set label by multiplying the attention matrix with the one-hot encoded support set labels
        y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))

        #get the probabilities
        self.probabilities = tf.squeeze(y_hat)

        #select the index that has the highest probability as the class of the query image
        self.predictions = tf.argmax(self.probabilities, 1)

        #we use softmax cross entropy as our loss function
        self.loss_op = tf.losses.sparse_softmax_cross_entropy(query_label, self.probabilities)

        #we minimize the loss using the Adam optimizer
        self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss_op)
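
To wrap up, the following is a minimal usage sketch (not part of the original code) showing how we might build the graph and run a single training step on a randomly generated 5-way, 1-shot episode batch; the hyperparameter values and the random data are purely illustrative:

import numpy as np

#build the graph for 5-way 1-shot classification with a batch of 32 episodes
model = Matching_network(lr=1e-3, n_way=5, k_shot=1, batch_size=32)
model.train(model.support_set_image, model.support_set_label, model.query_image, model.query_label)

#one batch of random episodes, just to check that the graph runs
support_images = np.random.rand(32, 5, 28, 28, 1)
support_labels = np.random.randint(0, 5, size=(32, 5))
query_images = np.random.rand(32, 28, 28, 1)
query_labels = np.random.randint(0, 5, size=(32,))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loss, _ = sess.run([model.loss_op, model.train_op],
                       feed_dict={model.support_set_image: support_images,
                                  model.support_set_label: support_labels,
                                  model.query_image: query_images,
                                  model.query_label: query_labels})
    print(loss)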