Putting the losses together

In TensorFlow, we can implement the whole GAN loss, as shown in the following code. As input, we take the discriminator's output logits for a batch of fake images from the generator and for a batch of real images from our dataset:

import tensorflow as tf

def gan_loss(logits_real, logits_fake):
    # Target label vectors for the generator and discriminator losses.
    true_labels = tf.ones_like(logits_real)
    fake_labels = tf.zeros_like(logits_fake)

    # DISCRIMINATOR loss has 2 parts: how well it classifies real images
    # and how well it classifies fake images.
    real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=true_labels)
    fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=fake_labels)

    # Combine and average the discriminator losses over the batch.
    discriminator_loss = tf.reduce_mean(real_image_loss + fake_image_loss)

    # GENERATOR is trying to make the discriminator output 1 for all its images,
    # so we use our target label vector of ones for computing the generator loss.
    generator_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=true_labels)

    # Average the generator loss over the batch.
    generator_loss = tf.reduce_mean(generator_loss)

    return discriminator_loss, generator_loss
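As a quick sanity check, here is a minimal usage sketch. The hand-picked logit values are hypothetical, not from this chapter, and it assumes the TensorFlow 1.x graph-and-session style used in the code above. A confident, correct discriminator (large positive logits on real images, large negative logits on fakes) should produce a small discriminator loss and a large generator loss:

# Hypothetical sanity check with hand-picked logits.
logits_real = tf.constant([[5.0], [4.0]])    # discriminator is sure these are real
logits_fake = tf.constant([[-5.0], [-4.0]])  # discriminator is sure these are fake
d_loss, g_loss = gan_loss(logits_real, logits_fake)
with tf.Session() as sess:
    # d_loss is near 0, g_loss is around 4.5: the discriminator is doing
    # well, so the generator is doing badly.
    print(sess.run([d_loss, g_loss]))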

You have probably noticed that it is impossible to maximize both the discriminator loss and the generator loss at the same time. This is the beauty of the GAN: as it trains, the model will hopefully reach some equilibrium, where the generator has to produce genuinely high-quality images in order to fool the discriminator.

TensorFlow only allows its optimizers to minimize, not maximize. As a result, we actually take the negatives of the loss functions described earlier, which turns maximizing them into minimizing them. We don't have to do anything extra ourselves, though, because expressing each loss as a cross-entropy via tf.nn.sigmoid_cross_entropy_with_logits() takes care of the sign flip for us.
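To see this concretely, here is a small sketch (our own check, not code from the chapter) showing that the generator's cross-entropy against a target of 1 is exactly -log(sigmoid(logit)), that is, -log D(G(z)); minimizing it is therefore the same as maximizing the discriminator's output on fake images:

# Cross-entropy against a target of 1, as used for the generator loss above.
logit = tf.constant([[0.5]])
ce = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit,
                                             labels=tf.ones_like(logit))
# The same quantity written out by hand: -log(D), where D = sigmoid(logit).
manual = -tf.log(tf.sigmoid(logit))
with tf.Session() as sess:
    print(sess.run([ce, manual]))  # both print roughly 0.474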