Building relation networks using TensorFlow

The relation function is pretty simple, right? We will understand relation networks better by implementing one in TensorFlow.

You can also check the code available as a Jupyter Notebook with an explanation here: https://github.com/sudharsan13296/Hands-On-Meta-Learning-With-Python/blob/master/04.%20Relation%20and%20Matching%20Networks%20Using%20Tensorflow/4.5%20Building%20Relation%20Network%20Using%20Tensorflow.ipynb.

First, we import all of the required libraries:

import tensorflow as tf
import numpy as np

We will randomly generate our data points. Let's say we have two classes in our dataset; we will randomly generate 1,000 data points, each with 18 features, for each of these classes:

classA = np.random.rand(1000,18)
classB = np.random.rand(1000,18)

We create our dataset by stacking both of these classes:

data = np.vstack([classA, classB])

Now, we set the labels; we assign label 1 to classA and label 0 to classB:

label = np.vstack([np.ones((len(classA),1)),np.zeros((len(classB),1))])

So, our dataset will have 2,000 records:

data.shape
(2000, 18)
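As a quick sanity check, the sketch below rebuilds the same dataset in plain NumPy and confirms that the features and labels line up row by row. The fixed seed (via np.random.default_rng) is an assumption added only for reproducibility; the original code does not seed its generator.

```python
import numpy as np

rng = np.random.default_rng(0)  # assumed seed, for reproducibility only

classA = rng.random((1000, 18))  # 1,000 points with 18 features, label 1
classB = rng.random((1000, 18))  # 1,000 points with 18 features, label 0

data = np.vstack([classA, classB])
label = np.vstack([np.ones((len(classA), 1)), np.zeros((len(classB), 1))])

print(data.shape, label.shape)                 # one label row per data row
print(label[:1000].min(), label[1000:].max())  # first half all 1s, second half all 0s
print(label.mean())                            # the two classes are balanced
```

Because the labels are stacked in the same order as the data, row k of label always belongs to row k of data.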

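Now, we define the placeholders for our support and query sets. During training, each 18-feature data point is split in half, with the first nine features going to xi and the last nine to xj, so both placeholders have nine columns: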

xi = tf.placeholder(tf.float32, [None, 9])
xj = tf.placeholder(tf.float32, [None, 9])
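The nine-column shape comes from the slicing used in the training loop: columns 0 to 8 feed xi and columns 9 to 17 feed xj. The sketch below shows that slicing on a small stand-in array whose values are arbitrary and made up for illustration.

```python
import numpy as np

# A small stand-in for `data`: 4 records with 18 features each (values are arbitrary)
data = np.arange(4 * 18, dtype=np.float32).reshape(4, 18)

support_part = data[:, 0:9]  # what gets fed to the xi placeholder
query_part = data[:, 9:]     # what gets fed to the xj placeholder

print(support_part.shape, query_part.shape)

# Nothing is lost: placing the halves side by side restores the original records
restored = np.hstack([support_part, query_part])
print(np.array_equal(restored, data))
```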

We also define a placeholder for the labels, y:

y = tf.placeholder(tf.float32, [None, 1]) 

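Now, we define our embedding function, which learns the embeddings of the support and query sets. We use a simple feedforward network, a single linear layer followed by a ReLU, that maps each nine-feature input to a one-dimensional embedding. The weights are created once, outside the function, so that the support and query sets share the same embedding:

# Shared embedding parameters, created once and reused for both xi and xj
embedding_weights = tf.Variable(tf.truncated_normal([9,1]))
embedding_bias = tf.Variable(tf.truncated_normal([1]))

def embedding_function(x):
    # Linear layer followed by ReLU
    a = tf.nn.xw_plus_b(x, embedding_weights, embedding_bias)
    embeddings = tf.nn.relu(a)
    return embeddings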
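To see what the embedding computes, here is a NumPy sketch of the same forward pass, ReLU(xW + b), mapping nine features to a single value. The random parameters stand in for the learned tf.Variable values, and the helper name embed is hypothetical. Because a relation network compares support and query points in a common embedding space, the sketch applies one set of parameters to both inputs.

```python
import numpy as np

rng = np.random.default_rng(1)

# Stand-ins for the learned parameters: weights of shape (9, 1), bias of shape (1,)
W = rng.standard_normal((9, 1))
b = rng.standard_normal(1)

def embed(x):
    """Hypothetical NumPy version of embedding_function: ReLU(x @ W + b)."""
    return np.maximum(x @ W + b, 0.0)

x_support = rng.random((5, 9))  # 5 support inputs, 9 features each
x_query = rng.random((5, 9))    # 5 query inputs, 9 features each

f_xi = embed(x_support)  # the same W and b ...
f_xj = embed(x_query)    # ... are used for both sets
print(f_xi.shape, f_xj.shape)
print((f_xi >= 0).all())  # ReLU output is never negative
```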

We compute the embeddings for the support set:

f_xi = embedding_function(xi)

We compute the embeddings for the query set:

f_xj = embedding_function(xj)

Now that we have the embeddings (feature vectors) of the support and query sets, we concatenate them along the feature axis:

Z = tf.concat([f_xi,f_xj],axis=1)
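Concatenating along axis=1 places each support embedding next to its query embedding, so row k of Z describes the k-th (support, query) pair. With one-dimensional embeddings, Z has two columns, which is why the relation function's first weight matrix has shape [2, 3]. A NumPy illustration with made-up embedding values:

```python
import numpy as np

# Made-up embeddings for 3 pairs; each embedding is one-dimensional
f_xi = np.array([[0.2], [0.0], [1.5]])
f_xj = np.array([[0.7], [0.3], [0.0]])

Z = np.concatenate([f_xi, f_xj], axis=1)  # NumPy counterpart of tf.concat(..., axis=1)
print(Z)
print(Z.shape)  # two columns: [support embedding, query embedding]
```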

We define our relation function as a three-layer neural network with ReLU activations on the two hidden layers and a sigmoid on the output, so that each relation score lies between 0 and 1:

def relation_function(x):
    w1 = tf.Variable(tf.truncated_normal([2,3]))
    b1 = tf.Variable(tf.truncated_normal([3]))

    w2 = tf.Variable(tf.truncated_normal([3,5]))
    b2 = tf.Variable(tf.truncated_normal([5]))

    w3 = tf.Variable(tf.truncated_normal([5,1]))
    b3 = tf.Variable(tf.truncated_normal([1]))

    #layer1
    z1 = tf.nn.xw_plus_b(x,w1,b1)
    a1 = tf.nn.relu(z1)

    #layer2
    z2 = tf.nn.xw_plus_b(a1,w2,b2)
    a2 = tf.nn.relu(z2)

    #layer3 takes the ReLU activations of layer2
    z3 = tf.nn.xw_plus_b(a2,w3,b3)

    #output: sigmoid maps the score into (0, 1)
    scores = tf.nn.sigmoid(z3)

    return scores
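The layer sizes are easier to follow in a plain NumPy forward pass. The sketch below follows the three-layer design described above: 2 inputs, hidden layers of 3 and 5 units with ReLU, and 1 sigmoid output, so every score falls between 0 and 1. The random parameters are placeholders for trained values, and relation_forward is a hypothetical name.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(2)

# Placeholder parameters with the same shapes as w1/b1, w2/b2, w3/b3
w1, b1 = rng.standard_normal((2, 3)), rng.standard_normal(3)
w2, b2 = rng.standard_normal((3, 5)), rng.standard_normal(5)
w3, b3 = rng.standard_normal((5, 1)), rng.standard_normal(1)

def relation_forward(Z):
    """Hypothetical NumPy counterpart of relation_function."""
    a1 = relu(Z @ w1 + b1)        # layer 1: (n, 2) -> (n, 3)
    a2 = relu(a1 @ w2 + b2)       # layer 2: (n, 3) -> (n, 5)
    return sigmoid(a2 @ w3 + b3)  # layer 3 + sigmoid: (n, 5) -> (n, 1)

Z = rng.random((4, 2))  # 4 concatenated (support, query) embedding pairs
scores = relation_forward(Z)
print(scores.shape)
print(scores.ravel())
```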

We now pass the concatenated feature vectors of the support and query sets to the relation function and get the relation scores:

relation_scores = relation_function(Z)

We compute loss_function as the mean squared error (MSE), that is, the mean of the squared differences between relation_scores and the true labels y:

loss_function = tf.reduce_mean(tf.squared_difference(relation_scores,y))
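For a concrete feel for the loss, take two pairs with relation scores 0.9 and 0.2 and true labels 1 and 0. The squared errors are 0.1^2 = 0.01 and 0.2^2 = 0.04, so the MSE is (0.01 + 0.04) / 2 = 0.025. The same computation in NumPy, mirroring tf.reduce_mean(tf.squared_difference(...)):

```python
import numpy as np

relation_scores = np.array([[0.9], [0.2]])
y = np.array([[1.0], [0.0]])

# Mean of the element-wise squared differences
loss = np.mean((relation_scores - y) ** 2)
print(round(loss, 4))
```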

We can minimize the loss using AdamOptimizer:

optimizer = tf.train.AdamOptimizer(0.1)
train = optimizer.minimize(loss_function)
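A rough idea of what a single optimizer step does, using the standard textbook form of the Adam update. TensorFlow's AdamOptimizer folds the bias corrections into the learning rate and applies epsilon slightly differently, so treat this as an approximation; beta1=0.9, beta2=0.999 and epsilon=1e-8 are the TensorFlow defaults, and adam_step is a hypothetical helper.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: one Adam update (standard textbook form)."""
    m = beta1 * m + (1 - beta1) * grad        # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)              # bias corrections for the zero initialisation
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

param = np.array([1.0, -2.0])
grad = np.array([0.5, -0.01])  # very different gradient magnitudes
m = np.zeros_like(param)
v = np.zeros_like(param)

param, m, v = adam_step(param, grad, m, v, t=1)
print(param)  # each entry moved by roughly lr = 0.1 against its gradient's sign
```

On the first step, m_hat / sqrt(v_hat) is essentially the sign of the gradient, so every weight moves by about the learning rate regardless of the gradient's size. With a learning rate of 0.1, that is a fairly large step for a network this small.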

Now, let's start our TensorFlow session:

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

Now, we train the network. In every episode, we feed the first nine features of each data point as the support input, xi, and the remaining nine as the query input, xj; both are perturbed with Gaussian noise with a standard deviation of 0.05:

for episode in range(1000):
    _, loss_value = sess.run([train, loss_function],
                             feed_dict={xi: data[:,0:9] + np.random.randn(*np.shape(data[:,0:9]))*0.05,
                                        xj: data[:,9:] + np.random.randn(*np.shape(data[:,9:]))*0.05,
                                        y: label})
    if episode % 100 == 0:
        print("Episode {}: loss {:.3f} ".format(episode, loss_value))
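The only randomness in each episode is this input jitter: every feature receives independent Gaussian noise with standard deviation 0.05, so the network never sees exactly the same numbers twice. The loop uses np.random.randn(...) * 0.05, which draws from the same distribution as rng.normal(0.0, 0.05, ...) below. A NumPy sketch of building one episode's inputs, with a hypothetical helper make_episode_inputs and a random stand-in for data:

```python
import numpy as np

def make_episode_inputs(data, noise_std=0.05, rng=None):
    """Hypothetical helper: split each record into halves and add Gaussian jitter."""
    rng = np.random.default_rng() if rng is None else rng
    support = data[:, 0:9] + rng.normal(0.0, noise_std, size=data[:, 0:9].shape)
    query = data[:, 9:] + rng.normal(0.0, noise_std, size=data[:, 9:].shape)
    return support, query

rng = np.random.default_rng(3)
data = rng.random((2000, 18))  # stand-in for the dataset built earlier

xi_batch, xj_batch = make_episode_inputs(data, rng=rng)
print(xi_batch.shape, xj_batch.shape)
print(np.std(xi_batch - data[:, 0:9]))  # close to the chosen noise_std of 0.05
```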

We can see the output as follows:

Episode 0: loss 0.495 
Episode 100: loss 0.250 
Episode 200: loss 0.250 
Episode 300: loss 0.250 
Episode 400: loss 0.250 
Episode 500: loss 0.250 
Episode 600: loss 0.250 
Episode 700: loss 0.250 
Episode 800: loss 0.250 
Episode 900: loss 0.250 
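Why does the loss settle at exactly 0.250? Both classes are generated with np.random.rand, that is, from the same uniform distribution, so the features carry no information about the label. The best a model can do on such data is, in expectation, to output the same constant score c for every pair. With half the labels equal to 1 and half equal to 0, the MSE of a constant prediction is

MSE(c) = 0.5 * (1 - c)^2 + 0.5 * c^2,

which is minimised at c = 0.5, where it equals 0.5 * 0.25 + 0.5 * 0.25 = 0.25. The sketch below checks this numerically; the grid of candidate constants is arbitrary.

```python
import numpy as np

# Balanced labels, as in the dataset: 1,000 ones followed by 1,000 zeros
y = np.vstack([np.ones((1000, 1)), np.zeros((1000, 1))])

candidates = np.linspace(0.0, 1.0, 101)  # constant scores to try: 0.00, 0.01, ..., 1.00
losses = np.array([np.mean((c - y) ** 2) for c in candidates])

best = candidates[np.argmin(losses)]
print(best, losses.min())
```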