Training Loop

We now reach the most important part of the code, where all the functions defined earlier are brought together.

Note - Batch size used here is 128

  • step 1 - Load the generator by calling the function img_generator().
  • step 2 - Load the discriminator by calling the function img_discriminator() and compile it with binary cross-entropy loss and the optimizer optimizer_d, which we defined in the hyperparameters section.
  • step 3 - Feed the generator and the discriminator to the function dcgan() and compile the resulting model with binary cross-entropy loss and the optimizer optimizer_g, which we defined in the hyperparameters section.
  • step 4 - Create a new batch of original images and masked images. Generate new fake images by feeding the batch of masked images to the generator.
  • step 5 - Concatenate the original and generated images so that the first 128 images are all original and the next 128 images are all fake. It is important that you do not shuffle the data here, otherwise the model will be hard to train. Label the generated images as 0 and the original images as 0.9 instead of 1. This is one-sided label smoothing on the original images. Label smoothing is used to make the network more resilient to adversarial examples. It is called one-sided because only the labels of the real images are smoothed.
  • step 6 - Set discriminator.trainable to True to enable training of the discriminator, and feed this set of 256 images and their corresponding labels to the discriminator for classification.
  • step 7 - Now set discriminator.trainable to False and feed a new batch of 128 masked images, labeled as 1, to the GAN (the model built by dcgan()) for classification. Setting discriminator.trainable to False ensures that the discriminator is not updated while the generator is being trained.
  • step 8 - Repeat steps 4 through 7 for the desired number of epochs.
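
As a minimal sketch of step 5 in isolation, the snippet below builds the discriminator batch and its one-sided smoothed labels with NumPy alone. The small 8x8x3 placeholder arrays stand in for real image batches and are an assumption made only for illustration; the variable names mirror those used in the training function.

```python
import numpy as np

batch_size = 128    # batch size stated in the text
smooth_real = 0.9   # smoothed label for the original images

# Placeholder batches; the 8x8x3 shape is an assumption for illustration only
original = np.ones((batch_size, 8, 8, 3), dtype=np.float32)
generated_images = np.zeros((batch_size, 8, 8, 3), dtype=np.float32)

# Originals first, generated images second, with no shuffling
dis_train = np.concatenate([original, generated_images])

# Start with every label at 0 (fake), then smooth only the real half
dis_lab = np.zeros(2 * batch_size)
dis_lab[:batch_size] = smooth_real

print(dis_train.shape[0], dis_lab[:2], dis_lab[-2:])
```

Because the labels are assigned by position, the order of the concatenated batch and the order of the labels must stay aligned.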

 

We have placed the calls to plot_generated_images_combined() and generated_images_plot() so that both plots are produced after the 1st iteration of the 1st epoch and at the end of every epoch.

Feel free to move these plot calls to match how often you want plots displayed.
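
One way to control the plotting frequency is a small predicate that the loop consults before plotting. The helper below is a hypothetical sketch, not part of the chapter's code: should_plot and its every_n_epochs parameter are names introduced here for illustration. It assumes a 0-based epoch counter and a 1-based iteration counter, matching the loop in train().

```python
def should_plot(epoch, iteration, iterations_per_epoch, every_n_epochs=1):
    """Return True when the plots should be drawn.

    epoch is 0-based and iteration is 1-based (hypothetical helper).
    """
    # Always plot after the very first training step
    first_step = (epoch == 0 and iteration == 1)
    # Otherwise plot only on the last iteration of a scheduled epoch
    end_of_epoch = (iteration == iterations_per_epoch)
    on_schedule = ((epoch + 1) % every_n_epochs == 0)
    return first_step or (end_of_epoch and on_schedule)


# Example: with 10 iterations per epoch, plot at the end of every 2nd epoch
print(should_plot(0, 1, 10, every_n_epochs=2))
print(should_plot(1, 10, 10, every_n_epochs=2))
```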

# Imports needed by train(), if not already loaded
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

def train(X_train, noised_train_data,
          input_shape, smooth_real,
          epochs, batch_size,
          optimizer_g, optimizer_d):

    # define two empty lists to store the discriminator
    # and the generator losses
    discriminator_losses = []
    generator_losses = []

    # Number of iterations per epoch for the given batch size
    iterations = X_train.shape[0] // batch_size

    # Load the generator and the discriminator
    generator = img_generator(input_shape)
    discriminator = img_discriminator(input_shape)

    # Compile the discriminator with binary_crossentropy loss
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer_d)

    # Feed the generator and the discriminator to the function dcgan
    # to form the DCGAN architecture
    gan = dcgan(discriminator, generator, input_shape)

    # Compile the DCGAN with binary_crossentropy loss
    gan.compile(loss='binary_crossentropy', optimizer=optimizer_g)

    for i in range(epochs):
        print('Epoch %d' % (i + 1))
        # Use tqdm to get an estimate of time remaining
        for j in tqdm(range(1, iterations + 1)):

            # batch of original images (batch = batch_size)
            original = X_train[np.random.randint(0, X_train.shape[0], size=batch_size)]

            # batch of noised images (batch = batch_size)
            noise = noised_train_data[np.random.randint(0, noised_train_data.shape[0], size=batch_size)]

            # Generate fake images
            generated_images = generator.predict(noise)

            # Labels for the discriminator batch, initialized to 0 (fake)
            dis_lab = np.zeros(2 * batch_size)

            # data for discriminator: originals first, generated images second
            dis_train = np.concatenate([original, generated_images])

            # label smoothing for original images
            dis_lab[:batch_size] = smooth_real

            # Train discriminator on the combined original and generated images
            discriminator.trainable = True
            discriminator_loss = discriminator.train_on_batch(dis_train, dis_lab)

            # save the losses
            discriminator_losses.append(discriminator_loss)

            # Train generator on a fresh batch, with the discriminator frozen
            gen_lab = np.ones(batch_size)
            discriminator.trainable = False
            sample_indices = np.random.randint(0, X_train.shape[0], size=batch_size)
            original = X_train[sample_indices]
            noise = noised_train_data[sample_indices]

            generator_loss = gan.train_on_batch(noise, gen_lab)

            # save the losses
            generator_losses.append(generator_loss)

            # plots after the 1st iteration of the 1st epoch
            if i == 0 and j == 1:
                print('Iteration - %d' % j)
                generated_images_plot(original, noise, generator)
                plot_generated_images_combined(original, noise, generator)

        # report the latest losses at the end of each epoch
        print("Discriminator Loss: ", discriminator_loss,
              ", Adversarial Loss: ", generator_loss)

        # training plot 1
        generated_images_plot(original, noise, generator)
        # training plot 2
        plot_generated_images_combined(original, noise, generator)

    # plot the training losses
    plt.figure()
    plt.plot(range(len(discriminator_losses)), discriminator_losses,
             color='red', label='Discriminator loss')
    plt.plot(range(len(generator_losses)), generator_losses,
             color='blue', label='Adversarial loss')
    plt.title('Discriminator and Adversarial loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss (Adversarial/Discriminator)')
    plt.legend()
    plt.show()

    return generator

generator = train(X_train, noised_train_data,
                  input_shape, smooth_real,
                  epochs, batch_size,
                  optimizer_g, optimizer_d)
       
 
                                                               
Figure 14.11.1: Generated images plotted with training plots at the end of the 1st iteration of epoch 1

Figure 14.11.2: Generated images plotted with training plots at the end of epoch 2

Figure 14.11.3: Generated images plotted with training plots at the end of epoch 5

Figure 14.11: Images generated during training

Figure 14.12: Plot of Discriminator and Adversarial Loss during training

Experiment with the learning rates of both the generator and the discriminator to find the best values for your use case. In general, when training GANs, you train for a large number of epochs and then use the loss versus iteration plot above to identify the minimum at which you would like training to stop.
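
Reading the minimum off a noisy loss curve by eye can be difficult, so one option is to smooth the recorded losses first. The helper below is a hypothetical sketch, not part of the chapter's code: smoothed_minimum and its window parameter are names introduced here, and it assumes you have kept the per-iteration loss list (for example generator_losses) available after training. A low adversarial loss does not by itself guarantee better images, so treat the result as a candidate point to inspect visually.

```python
import numpy as np

def smoothed_minimum(losses, window=50):
    """Return (iteration_index, smoothed_loss) at the minimum of a
    moving average of the losses (hypothetical helper)."""
    losses = np.asarray(losses, dtype=float)
    # Keep the window between 1 and the number of recorded losses
    window = max(1, min(window, len(losses)))
    kernel = np.ones(window) / window
    # smoothed[k] is the mean of losses[k : k + window]
    smoothed = np.convolve(losses, kernel, mode='valid')
    best = int(np.argmin(smoothed))
    # Report the last iteration covered by the best window
    return best + window - 1, float(smoothed[best])


# Toy loss curve used only to demonstrate the call
toy_losses = [5, 4, 3, 2, 1, 2, 3, 4, 5]
index, value = smoothed_minimum(toy_losses, window=3)
print(index, value)
```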