Matching networks

Matching networks are yet another simple and efficient one-shot learning algorithm, published by Google's DeepMind team. They can even produce labels for classes that were not observed during training.

Let's say we have a support set, $S$, containing $k$ examples: $S = \{(x_1, y_1), (x_2, y_2), \dots, (x_k, y_k)\}$. When given a query point (a new, unseen example), $\hat{x}$, the matching network predicts the class of $\hat{x}$ by comparing it with the support set.

We can define this as $P(\hat{y} \mid \hat{x}, S)$, where $P$ is the parameterized neural network, $\hat{y}$ is the predicted class for the query point, $\hat{x}$, and $S$ is the support set. $P$ returns the probability of $\hat{x}$ belonging to each of the classes in the dataset. Then, we select the class of $\hat{x}$ as the one with the highest probability. But how does this work exactly? How is this probability computed? Let's see that now.
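Written as an equation, the decision rule just described is the following restatement, where $y$ ranges over the classes present in the support set:

$$\hat{y} = \arg\max_{y} P(y \mid \hat{x}, S)$$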

The output, $\hat{y}$, for the query point, $\hat{x}$, can be predicted as follows:

$$\hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)\, y_i$$

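Let's decipher this equation. $x_i$ and $y_i$ are the inputs and labels of the support set. $\hat{x}$ is the query input, that is, the input whose label we want to predict. $a$ is the attention mechanism between $\hat{x}$ and $x_i$. But how do we perform attention? Here we use a simple attention mechanism: the softmax function over the cosine distance, $c$, between $\hat{x}$ and $x_i$ (in practice, $c$ is computed as the cosine similarity, so the more similar a support example is to the query, the larger its weight):

$$a(\hat{x}, x_i) = \frac{e^{c(\hat{x},\, x_i)}}{\sum_{j=1}^{k} e^{c(\hat{x},\, x_j)}}$$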
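A minimal sketch of this attention in plain Python, assuming the inputs are already numeric vectors (the values below are made up for illustration) and computing $c$ as cosine similarity. `cosine_similarity` and `attention` are hypothetical helper names, not part of any library:

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between vectors u and v (1.0 means same direction)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def attention(query, support_inputs):
    """Softmax over the cosine scores between the query and each support input."""
    scores = [cosine_similarity(query, x_i) for x_i in support_inputs]
    # Subtract the max score before exponentiating for numerical stability;
    # this does not change the softmax result.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical toy vectors standing in for the support set and the query.
support_inputs = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
query = [0.9, 0.1]
weights = attention(query, support_inputs)
print([round(w, 3) for w in weights])
```

The weights sum to 1, and the support vector pointing in nearly the same direction as the query receives the largest one.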

Computing the cosine distance directly between raw inputs such as $\hat{x}$ and $x_i$ (for example, image pixels) is not meaningful. So, first, we learn their embeddings and then calculate the cosine distance between the embeddings. We use two different embedding functions, $f$ and $g$, to learn the embeddings of the query input, $\hat{x}$, and the support set input, $x_i$, respectively. We will see exactly how these two embedding functions, $f$ and $g$, learn the embeddings in the upcoming section.

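So, we can rewrite our attention equation as follows:

$$a(\hat{x}, x_i) = \frac{e^{c(f(\hat{x}),\, g(x_i))}}{\sum_{j=1}^{k} e^{c(f(\hat{x}),\, g(x_j))}}$$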

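We can write the previous equation more compactly as follows, where the softmax is taken over all $k$ support examples:

$$a(\hat{x}, x_i) = \operatorname{softmax}\big(c(f(\hat{x}), g(x_i))\big)$$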
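Substituting this attention back into the prediction equation for $\hat{y}$ gives the whole computation in a single expression:

$$\hat{y} = \sum_{i=1}^{k} \frac{e^{c(f(\hat{x}),\, g(x_i))}}{\sum_{j=1}^{k} e^{c(f(\hat{x}),\, g(x_j))}}\, y_i$$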

So, after calculating the attention matrix, $a$, we multiply it with the support set labels, $y_i$. But how can we multiply the support set labels with our attention matrix? First, we convert the support set labels into one-hot encoded vectors and then multiply them with the attention matrix. As a result, we get the probability of $\hat{x}$ belonging to each of the classes in the support set. Finally, we apply argmax and select $\hat{y}$ as the class with the maximum probability.
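The one-hot weighting and the argmax can be sketched as follows. The attention weights and labels are hypothetical numbers chosen for illustration, and `one_hot` and `predict` are illustrative helper names. Because two support examples share class 1, their weights add up:

```python
def one_hot(label, num_classes):
    """Return a one-hot list with 1.0 at position `label`."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def predict(attention_weights, support_labels, num_classes):
    """Sum the one-hot support labels weighted by attention, then take the argmax."""
    probs = [0.0] * num_classes
    for a_i, y_i in zip(attention_weights, support_labels):
        for c, v in enumerate(one_hot(y_i, num_classes)):
            probs[c] += a_i * v
    predicted = max(range(num_classes), key=lambda c: probs[c])
    return probs, predicted


# Hypothetical attention weights for four support examples (they sum to 1).
attention_weights = [0.1, 0.5, 0.3, 0.1]
support_labels = [0, 1, 1, 2]
probs, predicted = predict(attention_weights, support_labels, num_classes=3)
print([round(p, 3) for p in probs], predicted)
```

Class 1 collects 0.5 + 0.3 = 0.8 from its two support examples, so it is the predicted class. Since the attention weights sum to 1, the resulting class probabilities do too.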

Still not clear about matching networks? Look at the following diagram. As you can see, we have three classes in our support set, {lion, elephant, dog}, and a new query image, $\hat{x}$. First, we feed the support set to the embedding function $g$ and the query image to the embedding function $f$, learn their embeddings, and calculate the cosine distance between them; then, we apply softmax attention over this cosine distance. Next, we multiply the attention matrix with the one-hot encoded support set labels to get the class probabilities, and select $\hat{y}$ as the class with the highest probability. As you can see in the following diagram, the query image is an elephant, and we have a high probability at index 1, so we predict the class of $\hat{x}$ as 1 (elephant):


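To tie the steps together, here is a toy end-to-end sketch of the lion/elephant/dog walkthrough. It makes several assumptions that are not from the original: the images are replaced by hand-made 3-dimensional feature vectors, and the learned embedding functions $f$ and $g$ are replaced by small fixed linear maps (`W_F`, `W_G`) with arbitrary weights, so it shows only the data flow, not training. All names here are hypothetical:

```python
import math

CLASSES = ["lion", "elephant", "dog"]

# Hypothetical fixed weights standing in for the learned embedding networks.
W_F = [[1.0, 0.2, 0.0], [0.0, 1.0, 0.0], [0.0, 0.1, 1.0]]  # query embedding f
W_G = [[1.0, 0.0, 0.1], [0.0, 1.0, 0.0], [0.1, 0.0, 1.0]]  # support embedding g


def linear(W, x):
    """Matrix-vector product W @ x using plain lists."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def f(x):
    return linear(W_F, x)


def g(x):
    return linear(W_G, x)


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def matching_net_predict(query, support):
    """Classify `query` given `support`, a list of (x_i, y_i) pairs with integer labels."""
    q = f(query)
    # Softmax attention over cosine similarities between f(query) and g(x_i).
    scores = [cosine(q, g(x_i)) for x_i, _ in support]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    attn = [e / total for e in exps]
    # Adding a_i to probs[y_i] is the same as adding a_i times the one-hot vector of y_i.
    probs = [0.0] * len(CLASSES)
    for a_i, (_, y_i) in zip(attn, support):
        probs[y_i] += a_i
    return probs, max(range(len(CLASSES)), key=lambda c: probs[c])


# One hypothetical support example per class: 0 = lion, 1 = elephant, 2 = dog.
support = [([1.0, 0.0, 0.0], 0), ([0.0, 1.0, 0.0], 1), ([0.0, 0.0, 1.0], 2)]
query = [0.1, 0.9, 0.2]  # hypothetical features of an elephant-like image
probs, predicted = matching_net_predict(query, support)
print(CLASSES[predicted], [round(p, 3) for p in probs])
```

With these toy values the query is predicted as class 1 (elephant). In a real matching network, $f$ and $g$ would be trained so that embeddings of same-class examples score a high cosine similarity; the fixed matrices here only make the example runnable.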