Training the GAN

Now that we have a generator, a discriminator, and our loss function, all that is left is to train! We will only sketch how to do this in TensorFlow, because there is nothing fancy in this part; it is just a matter of piecing together the components from the previous section, along with loading and feeding MNIST images as we did earlier.

First, set up two solvers: one for the discriminator and one for the generator. We use a smaller value of beta1 for the AdamOptimizer (0.5 rather than the default of 0.9), as this has been shown to help GAN training converge:

import tensorflow as tf  # TensorFlow 1.x API (tf.train, tf.variable_scope)

discriminator_solver = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.5)
generator_solver = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.5)
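To see why beta1 matters, here is the textbook Adam update for a single scalar parameter, written in plain Python as a sketch. It follows the update rule from the Adam paper rather than TensorFlow's exact kernel, and the helper names are hypothetical. beta1 is the decay rate of the running average of gradients, so a smaller value lets that average follow recent gradients more quickly; one commonly given intuition is that this helps when the loss landscape keeps shifting, as it does when two networks are trained against each other.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.5, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter (textbook formulation)."""
    m = beta1 * m + (1 - beta1) * grad         # running average of gradients (momentum)
    v = beta2 * v + (1 - beta2) * grad * grad  # running average of squared gradients
    m_hat = m / (1 - beta1 ** t)               # bias correction for the zero initialisation
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


def momentum_after(grads, beta1):
    """Run Adam over a gradient sequence and return the final first-moment estimate m."""
    param, m, v = 0.0, 0.0, 0.0
    for t, g in enumerate(grads, start=1):
        param, m, v = adam_step(param, g, m, v, t, beta1=beta1)
    return m
```

With gradients [1, 1, 1, -1], the momentum term has already changed sign after the flip when beta1=0.5, but it is still positive with the default beta1=0.9.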

Next, create a batch of random noise vectors with tf.random_uniform, sampling uniformly from [-1, 1). These are fed to the generator network to create a batch of generated images:

z = tf.random_uniform(shape=[batch_size, dim], minval=-1, maxval=1)
generator_sample = generator(z)
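The same noise can be sampled outside the graph with NumPy, which is handy for checking shapes or for generating images after training. This is a minimal sketch; batch_size=128 and dim=96 are hypothetical values, since the text leaves them unspecified. Like tf.random_uniform, rng.uniform samples from the half-open interval [-1, 1).

```python
import numpy as np


def sample_noise(batch_size, dim, seed=None):
    """Uniform noise in [-1, 1) with shape (batch_size, dim), mirroring the tf.random_uniform call."""
    rng = np.random.default_rng(seed)
    return rng.uniform(low=-1.0, high=1.0, size=(batch_size, dim))


# Hypothetical sizes: one batch of 128 noise vectors, each of length 96.
z = sample_noise(batch_size=128, dim=96, seed=0)
```

One difference: tf.random_uniform inside the graph draws fresh noise every time the graph is run, whereas sample_noise has to be called again for each batch.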

Then, we feed a batch of real images and our batch of generated samples to the discriminator. We use a variable scope here so that the second call reuses the discriminator's variables instead of creating a second set of weights:

with tf.variable_scope("") as scope:
    logits_real = discriminator(x)
    # We want to reuse the discriminator weights.
    scope.reuse_variables()
    logits_fake = discriminator(generator_sample)
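The effect of scope.reuse_variables() can be illustrated with a toy, pure-Python variable store. This is not TensorFlow code; ToyVariableStore and toy_discriminator are hypothetical names that only mimic TF1's rule that tf.get_variable raises a ValueError when asked to create a variable that already exists, unless reuse is enabled.

```python
class ToyVariableStore:
    """Toy model of TF1 get_variable semantics: create once, reuse only when allowed."""

    def __init__(self):
        self._vars = {}

    def get_variable(self, name, init_value, reuse=False):
        if name in self._vars:
            if not reuse:
                raise ValueError(f"Variable {name} already exists; set reuse=True")
            return self._vars[name]
        if reuse:
            raise ValueError(f"Variable {name} does not exist; cannot reuse it")
        self._vars[name] = [init_value]  # a mutable box standing in for a tf.Variable
        return self._vars[name]


def toy_discriminator(store, reuse=False):
    # Every call asks for the same named weight.
    return store.get_variable("discriminator/w", 0.0, reuse=reuse)


store = ToyVariableStore()
w_real = toy_discriminator(store)              # first call creates the weight
w_fake = toy_discriminator(store, reuse=True)  # second call reuses it
```

Because both calls return the same object, real and generated images are scored by a single set of discriminator weights, which is what the variable scope guarantees in the graph above.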

We separate the weights belonging to the discriminator and the generator, as we need to update them separately:

discriminator_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
generator_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
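tf.get_collection filters the collection by matching the scope argument against each variable's name with re.match. The sketch below mimics that filter on a hypothetical list of variable names, assuming the discriminator and generator create their weights under scopes named "discriminator" and "generator".

```python
import re


def filter_by_scope(var_names, scope):
    """Mimic tf.get_collection's scope filter: keep names where re.match(scope, name) succeeds."""
    return [name for name in var_names if re.match(scope, name)]


# Hypothetical variable names as TF1 would print them (note the ":0" output suffix).
all_vars = [
    "discriminator/dense/kernel:0",
    "discriminator/dense/bias:0",
    "generator/dense/kernel:0",
    "generator/dense/bias:0",
]
discriminator_vars = filter_by_scope(all_vars, "discriminator")
generator_vars = filter_by_scope(all_vars, "generator")
```

Because re.match only anchors at the start of the name, the scope "generator" would also match a variable under a scope such as "generator_old"; adding a trailing slash ("generator/") avoids that.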

Finally, we calculate our losses and send them to our optimizers with the relevant weights to update:


discriminator_loss, generator_loss = gan_loss(logits_real, logits_fake)

# Training steps.
discriminator_train_step = discriminator_solver.minimize(discriminator_loss, var_list=discriminator_vars)
generator_train_step = generator_solver.minimize(generator_loss, var_list=generator_vars)
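The gan_loss function comes from the previous section and is not repeated here. As a reminder of what it has to compute, here is one common formulation (the non-saturating GAN loss) written in NumPy, using the numerically stable form of sigmoid cross-entropy that tf.nn.sigmoid_cross_entropy_with_logits documents. gan_loss_sketch is a hypothetical name, and the previous section's version may differ in detail.

```python
import numpy as np


def sigmoid_cross_entropy_with_logits(logits, labels):
    # Stable form: max(x, 0) - x * z + log(1 + exp(-|x|))
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


def gan_loss_sketch(logits_real, logits_fake):
    """Non-saturating GAN losses computed from raw discriminator logits."""
    # Discriminator: label real images 1 and generated images 0.
    d_loss = (np.mean(sigmoid_cross_entropy_with_logits(logits_real, np.ones_like(logits_real)))
              + np.mean(sigmoid_cross_entropy_with_logits(logits_fake, np.zeros_like(logits_fake))))
    # Generator: wants its samples to be labelled 1 by the discriminator.
    g_loss = np.mean(sigmoid_cross_entropy_with_logits(logits_fake, np.ones_like(logits_fake)))
    return d_loss, g_loss
```

The discriminator is pushed to output high logits for real images and low logits for generated ones, while the generator's loss falls when its samples receive high logits.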

These are the main steps to training a GAN. All that is left is to write a training loop that iterates over batches of data, running the discriminator and generator training steps for each batch. Once training is done, you should be able to feed in any random noise vector, drawn the same way as during training, and generate an image.
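The structure of that loop can be sketched independently of TensorFlow. In this hypothetical helper, d_step(batch) and g_step() stand in for calls such as sess.run([discriminator_train_step, discriminator_loss], feed_dict={x: batch}) and sess.run([generator_train_step, generator_loss]); the generator step needs no feed because z is sampled inside the graph. One discriminator update followed by one generator update per batch is a common choice, assumed here rather than taken from the text.

```python
def train_gan(batches, d_step, g_step, num_epochs=1):
    """Alternate one discriminator update and one generator update per batch.

    d_step(batch) and g_step() are hypothetical callables that run one
    optimisation step each and return the corresponding loss.
    """
    history = []
    for _ in range(num_epochs):
        for batch in batches:
            d_loss = d_step(batch)  # update discriminator on real batch + fresh fakes
            g_loss = g_step()       # update generator using fresh noise
            history.append((d_loss, g_loss))
    return history
```

Pass batches as a list (or another re-iterable sequence) when num_epochs > 1, since a one-shot generator would be exhausted after the first epoch.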

 

As you can see in the following diagram, the images created are starting to resemble MNIST digits:
