Style loss

The style loss is calculated across multiple layers. For each chosen layer, the style loss is the MSE between the gram matrix of the generated image's feature map and the gram matrix of the style image's feature map. The gram matrix captures the correlations between the channels (features) of a feature map. Let's understand how the gram matrix works by using the following diagram and a code implementation.

The following table shows a feature map of dimensions [2, 3, 3, 3], with the columns Batch_size, Channels, and Values:

To calculate the gram matrix, we flatten all the values of each channel and then find the correlations by multiplying the result by its transpose, as shown in the following table:

All we did was flatten the values of each channel into a single vector (or tensor). The following code implements this:

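import torch
import torch.nn as nn

class GramMatrix(nn.Module):

    def forward(self,input):
        # b: batch, c: channels, h: height, w: width
        b,c,h,w = input.size()
        # Flatten each channel into a vector of h*w values -> [b, c, h*w]
        features = input.view(b,c,h*w)
        # Batched product with the transpose -> [b, c, c]
        gram_matrix = torch.bmm(features,features.transpose(1,2))
        # Normalize by the number of values per channel
        gram_matrix.div_(h*w)
        return gram_matrix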
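To see what each entry of the result holds, the following sketch builds a tensor with the [2, 3, 3, 3] shape from the table (the values are arbitrary, generated with torch.arange) and compares the GramMatrix output with an explicit loop. Entry (i, j) for each batch item is the dot product of flattened channels i and j divided by h*w, that is, G[i, j] = (1/(h*w)) * sum over k of F[i, k] * F[j, k]. The class is repeated so the snippet runs on its own, and gram_by_loop is a hypothetical helper written only for this comparison.

```python
import torch
import torch.nn as nn


class GramMatrix(nn.Module):
    def forward(self, input):
        b, c, h, w = input.size()
        features = input.view(b, c, h * w)
        gram_matrix = torch.bmm(features, features.transpose(1, 2))
        gram_matrix.div_(h * w)
        return gram_matrix


def gram_by_loop(x):
    """Hypothetical reference: compute each gram entry as an explicit dot product."""
    b, c, h, w = x.shape
    out = torch.zeros(b, c, c)
    for n in range(b):
        for i in range(c):
            for j in range(c):
                fi = x[n, i].flatten()  # channel i as a vector of h*w values
                fj = x[n, j].flatten()  # channel j as a vector of h*w values
                out[n, i, j] = torch.dot(fi, fj) / (h * w)
    return out


# A feature map shaped like the table: batch 2, 3 channels, 3x3 spatial values
x = torch.arange(2 * 3 * 3 * 3, dtype=torch.float32).view(2, 3, 3, 3)

gram = GramMatrix()(x)
print(gram.shape)                                  # torch.Size([2, 3, 3])
print(torch.allclose(gram, gram_by_loop(x)))       # True
print(torch.allclose(gram, gram.transpose(1, 2)))  # True: gram matrices are symmetric
```

The result has one c x c matrix per batch item, so its size depends only on the number of channels, not on the spatial size of the feature map.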

We implement GramMatrix as another PyTorch module with a forward function so that we can use it like a PyTorch layer. This line extracts the different dimensions from the input feature map:

b,c,h,w = input.size()

Here, b represents the batch size, c the number of filters or channels, h the height, and w the width. In the next step, we use the following code to keep the batch and channel dimensions intact and flatten all the values along the height and width dimensions, as shown in the preceding figure:

features = input.view(b,c,h*w)
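The effect of view() can be checked on a small tensor with the table's [2, 3, 3, 3] shape (arbitrary values): each channel's 3 x 3 grid becomes one row of 9 values, read row by row, while the batch and channel dimensions are unchanged. Note that view() requires a memory layout compatible with the new shape (a contiguous tensor always qualifies); reshape() is the more forgiving alternative when that is not guaranteed.

```python
import torch

# Arbitrary values in the table's shape: [batch=2, channels=3, height=3, width=3]
x = torch.arange(54, dtype=torch.float32).view(2, 3, 3, 3)
b, c, h, w = x.size()

# Keep batch and channel dims, flatten height and width into one dim
features = x.view(b, c, h * w)
print(features.shape)  # torch.Size([2, 3, 9])

# Row 1 of batch item 0 holds channel 1's 3x3 grid as 9 consecutive values
print(features[0, 1])
print(torch.equal(features[0, 1], x[0, 1].flatten()))  # True
```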

The gram matrix is calculated by multiplying the flattened values by their transpose. We can do this with PyTorch's batch matrix multiplication function, torch.bmm(), as shown in the following code:

gram_matrix = torch.bmm(features,features.transpose(1,2))
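torch.bmm() multiplies two batches of matrices pair by pair: here a [b, c, h*w] tensor by its [b, h*w, c] transpose, giving one c x c matrix per batch item. The sketch below uses random values as a stand-in for a flattened feature map and checks that the batched product matches multiplying each batch item separately with torch.mm(), as well as the equivalent torch.einsum() expression, which spells out the sum over the flattened spatial index.

```python
import torch

torch.manual_seed(0)
b, c, hw = 2, 3, 9
features = torch.randn(b, c, hw)  # stand-in for input.view(b, c, h*w)

# Batched product: [b, c, hw] x [b, hw, c] -> [b, c, c]
gram = torch.bmm(features, features.transpose(1, 2))
print(gram.shape)  # torch.Size([2, 3, 3])

# Same result, computed one batch item at a time
per_item = torch.stack([torch.mm(f, f.t()) for f in features])

# Same result, written as an explicit sum over the spatial index k
via_einsum = torch.einsum("bik,bjk->bij", features, features)

print(torch.allclose(gram, per_item))    # True
print(torch.allclose(gram, via_einsum))  # True
```

torch.bmm() handles the whole batch in one call, which avoids a Python loop over batch items.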

Finally, we normalize the values of the gram matrix by dividing it by the number of elements in each flattened channel (h*w). This prevents a feature map with many spatial values from dominating the score. Once the GramMatrix is calculated, computing the style loss is simple, as implemented in this code:

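import torch.nn as nn

class StyleLoss(nn.Module):

    def forward(self,inputs,targets):
        # MSE between the gram matrix of the input feature map and the target gram matrix
        out = nn.MSELoss()(GramMatrix()(inputs),targets)
        return out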
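Because targets is compared directly with GramMatrix()(inputs), it must already be a gram matrix. A minimal usage sketch, assuming a single layer and random tensors standing in for the feature maps of the generated image and the style image: the style target is computed once and detached so it is treated as a constant, and the loss can be backpropagated to the generated features.

```python
import torch
import torch.nn as nn


class GramMatrix(nn.Module):
    def forward(self, input):
        b, c, h, w = input.size()
        features = input.view(b, c, h * w)
        gram_matrix = torch.bmm(features, features.transpose(1, 2))
        gram_matrix.div_(h * w)
        return gram_matrix


class StyleLoss(nn.Module):
    def forward(self, inputs, targets):
        return nn.MSELoss()(GramMatrix()(inputs), targets)


torch.manual_seed(0)
# Random stand-ins for one layer's feature maps, shape [b, c, h, w]
style_features = torch.randn(1, 8, 16, 16)
generated_features = torch.randn(1, 8, 16, 16, requires_grad=True)

# Style target: a gram matrix computed once and detached from the graph
style_target = GramMatrix()(style_features).detach()

loss = StyleLoss()(generated_features, style_target)
loss.backward()  # gradients flow back to generated_features
print(loss.item() > 0)                # True
print(generated_features.grad.shape)  # torch.Size([1, 8, 16, 16])

# Identical feature maps give identical gram matrices, hence zero loss
print(StyleLoss()(style_features, style_target).item())  # 0.0
```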

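The StyleLoss is implemented as another PyTorch layer. It calculates the MSE between the GramMatrix values of the input and the GramMatrix values of the style image. Note that targets is expected to be the already computed gram matrix of the style image's feature map, whereas inputs is a raw feature map.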
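Since the style loss is calculated across multiple layers, the per-layer values have to be combined. One common way, sketched here as an assumption rather than this chapter's exact training code, is to compute a StyleLoss for each layer and add the results, optionally scaled by per-layer weights. The feature maps are random tensors of different sizes standing in for the outputs of several convolutional layers, and the shapes, style_weights, and total_style_loss are hypothetical names and values.

```python
import torch
import torch.nn as nn


class GramMatrix(nn.Module):
    def forward(self, input):
        b, c, h, w = input.size()
        features = input.view(b, c, h * w)
        gram_matrix = torch.bmm(features, features.transpose(1, 2))
        gram_matrix.div_(h * w)
        return gram_matrix


class StyleLoss(nn.Module):
    def forward(self, inputs, targets):
        return nn.MSELoss()(GramMatrix()(inputs), targets)


torch.manual_seed(0)
# Hypothetical shapes for three layers: channels grow as spatial size shrinks
shapes = [(1, 16, 32, 32), (1, 32, 16, 16), (1, 64, 8, 8)]
style_maps = [torch.randn(s) for s in shapes]
generated_maps = [torch.randn(s, requires_grad=True) for s in shapes]

# Hypothetical per-layer weights (equal weighting here)
style_weights = [1.0, 1.0, 1.0]

# Gram targets of the style image, computed once per layer
style_targets = [GramMatrix()(m).detach() for m in style_maps]

# Weighted sum of the per-layer style losses
style_loss = StyleLoss()
total_style_loss = sum(
    w * style_loss(gen, target)
    for w, gen, target in zip(style_weights, generated_maps, style_targets)
)
total_style_loss.backward()
print(total_style_loss.item())
print([tuple(g.grad.shape) for g in generated_maps])
```

Each layer contributes a gram matrix whose size depends only on its channel count, so layers with different spatial sizes can be summed into a single scalar loss.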
