The training loop

We have previously had the luxury of calling .fit() on our model and letting Keras handle the painful process of breaking the data apart into minibatches and training for us.

Unfortunately, because we need to perform separate updates for the discriminator and the stacked model on each batch, we're going to have to do things the old-fashioned way, with a few loops. This is how things used to be done all the time, so while it's perhaps a little more work, it does admittedly leave me feeling nostalgic. The following code illustrates the training technique:

import numpy as np

num_examples = X_train.shape[0]
num_batches = int(num_examples / float(batch_size))
half_batch = int(batch_size / 2)

for epoch in range(epochs + 1):
    for batch in range(num_batches):
        # noise images for the batch
        noise = np.random.normal(0, 1, (half_batch, 100))
        fake_images = generator.predict(noise)
        fake_labels = np.zeros((half_batch, 1))

        # real images for batch
        idx = np.random.randint(0, X_train.shape[0], half_batch)
        real_images = X_train[idx]
        real_labels = np.ones((half_batch, 1))

        # Train the discriminator (real classified as ones and generated as zeros)
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        noise = np.random.normal(0, 1, (batch_size, 100))
        # Train the generator
        g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))

        # Plot the progress
        print("Epoch %d Batch %d/%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
              (epoch, batch, num_batches, d_loss[0], 100 * d_loss[1], g_loss))

        if batch % 50 == 0:
            save_imgs(generator, epoch, batch)
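
Before breaking the loop down, it may help to sanity check the bookkeeping in its first three lines. The following is a minimal sketch with made-up sizes (60,000 examples and a batch size of 64 are assumptions for illustration, not values fixed by the loop above):

```python
# Hypothetical sizes, chosen only for illustration.
num_examples = 60000
batch_size = 64

# Integer division gives the same result as int(num_examples / float(batch_size)).
num_batches = num_examples // batch_size
half_batch = batch_size // 2

# Examples that would be left over if each batch were a disjoint slice of the data.
remainder = num_examples - num_batches * batch_size

print(num_batches, half_batch, remainder)
```

Because the loop draws random indices rather than slicing X_train, num_batches only controls how many updates we count as one epoch.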

There is a lot going on here, to be sure. As before, let's break it down block by block. First, let's see the code to generate noise vectors:

noise = np.random.normal(0, 1, (half_batch, 100))
fake_images = generator.predict(noise)
fake_labels = np.zeros((half_batch, 1))

This code is generating a matrix of noise vectors (which we've previously called z) and sending it to the generator. It's getting a set of generated images back, which I'm calling fake images. We will use these to train the discriminator, so the labels we want to use are 0s, indicating that these are in fact generated images.
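
To make the shapes concrete, here is a small sketch that runs without a trained model. The toy_generator_predict function is a hypothetical stand-in for generator.predict (a fixed random projection followed by tanh), and half_batch = 32 is an assumed value:

```python
import numpy as np

half_batch = 32   # assumed value; the training loop derives it from batch_size
latent_dim = 100  # length of each noise vector, as in the training loop

rng = np.random.default_rng(0)

# Hypothetical stand-in for the trained generator: a fixed random projection
# from the noise vector to 28 x 28 pixels, squashed into [-1, 1] by tanh.
projection = rng.normal(0, 0.1, (latent_dim, 28 * 28))

def toy_generator_predict(z):
    return np.tanh(z @ projection).reshape(-1, 28, 28, 1)

noise = rng.normal(0, 1, (half_batch, latent_dim))
fake_images = toy_generator_predict(noise)
fake_labels = np.zeros((half_batch, 1))  # 0 marks a generated image

print(noise.shape, fake_images.shape, fake_labels.shape)
```

Each row of noise becomes one image, and each image gets one label.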

Note that the shape of fake_images is half_batch x 28 x 28 x 1. The half_batch is exactly what you think it is. We're creating half a batch of generated images because the other half of the batch will be real data, which we will assemble next. To get our real images, we will generate a random set of indices across X_train and use that slice of X_train as our real images, as shown in the following code:

idx = np.random.randint(0, X_train.shape[0], half_batch)
real_images = X_train[idx]
real_labels = np.ones((half_batch, 1))

Yes, we are sampling with replacement in this case. It does work out, but it's probably not the best way to implement minibatch training. It is, however, probably the easiest and most common.
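
If you would rather visit each example at most once per epoch, one common alternative is to shuffle the indices once and slice through them. The following is a sketch of that idea, not the approach used in this chapter, and the epoch_batches helper is a hypothetical name:

```python
import numpy as np

def epoch_batches(num_examples, batch_size, rng):
    """Yield index arrays that use each example at most once per epoch."""
    order = rng.permutation(num_examples)
    # Stop before the final partial batch so every batch has the same size.
    for start in range(0, num_examples - batch_size + 1, batch_size):
        yield order[start:start + batch_size]

rng = np.random.default_rng(0)
batches = list(epoch_batches(num_examples=10, batch_size=4, rng=rng))
seen = np.concatenate(batches)

print(len(batches), len(set(seen.tolist())))
```

Inside the training loop, each yielded index array would play the role of idx.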

Since we are using these images to train the discriminator, and because they are real images, we will assign them 1s as labels, rather than 0s. Now that we have our discriminator training set assembled, we will update the discriminator. Also note that we aren't using the soft labels that we discussed previously. That's because I wanted to keep things as easy to understand as possible. Luckily, the network doesn't require them in this case. We will use the following code to train the discriminator:

# Train the discriminator (real classified as ones and generated as zeros)
d_loss_real = discriminator.train_on_batch(real_images, real_labels)
d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

Notice that here I'm using the discriminator's train_on_batch() method. This is the first time I've used this method in the book. The train_on_batch() method does exactly one round of forward and backward propagation on the batch it is given. Every time we call it, it updates the model once, starting from the model's previous state.
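
To make the "one call, one update" behavior concrete without needing a compiled Keras model, the following sketch mimics it with a tiny logistic regression in NumPy. TinyLogisticModel is a hypothetical stand-in, not part of Keras, and only the train_on_batch name is borrowed:

```python
import numpy as np

class TinyLogisticModel:
    """Hypothetical stand-in for a compiled model; this is not Keras."""

    def __init__(self, num_features, learning_rate=0.1):
        self.w = np.zeros((num_features, 1))
        self.b = 0.0
        self.learning_rate = learning_rate
        self.updates = 0

    def train_on_batch(self, x, y):
        # Forward pass: sigmoid output and binary cross-entropy loss.
        p = 1.0 / (1.0 + np.exp(-(x @ self.w + self.b)))
        eps = 1e-7
        loss = float(-np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps)))
        # Backward pass: a single gradient step from the current weights.
        grad = (p - y) / len(x)
        self.w -= self.learning_rate * (x.T @ grad)
        self.b -= self.learning_rate * float(grad.sum())
        self.updates += 1
        return loss

rng = np.random.default_rng(0)
x = rng.normal(0, 1, (16, 3))
y = (x[:, :1] > 0).astype(float)  # label depends on the sign of the first feature

model = TinyLogisticModel(num_features=3)
first_loss = model.train_on_batch(x, y)
second_loss = model.train_on_batch(x, y)

print(first_loss, second_loss, model.updates)
```

Each call starts from the weights left behind by the previous call, which is why the second loss is lower than the first on the same batch.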

Also notice that I'm making the update for the real images and fake images separately. This is advice given in the GAN hacks Git repository that I referenced previously in the Generator architecture section. Especially in the early stages of training, when real images and fake images come from radically different distributions, batch normalization will cause problems with training if we put both sets of data in the same update.
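
One way to see why mixing can hurt is to look at the statistics batch normalization would compute. The sketch below uses made-up feature values (real activations centered at 0.5 and fake ones at -0.5 are assumptions) and a simplified batch_normalize helper with no learned scale or shift:

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up activations for 32 real and 32 fake images, 4 features each.
real = rng.normal(0.5, 0.1, (32, 4))
fake = rng.normal(-0.5, 0.1, (32, 4))

def batch_normalize(x, eps=1e-5):
    # Training-mode batch norm without the learned scale and shift.
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

# Normalize the real half on its own, then as part of a mixed batch.
separate_real = batch_normalize(real)
mixed = batch_normalize(np.concatenate([real, fake]))
mixed_real, mixed_fake = mixed[:32], mixed[32:]

print(separate_real.mean(), mixed_real.mean(), mixed_fake.mean())
```

Normalized on its own, the real half is centered at zero. In the mixed batch, the shared mean and variance describe neither half, and the two halves land on opposite sides of zero.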

Now that the discriminator has been updated, it's time to update the generator. This is done indirectly by updating the combined stack, as shown in the following code:

noise = np.random.normal(0, 1, (batch_size, 100))
g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))

To update the combined model, we create a new noise matrix, and this time it will be as large as the entire batch. We will use that as an input to the stack, which will cause the generator to generate images and the discriminator to evaluate those images. Finally, we will use the label of 1 because we want to backpropagate the error between the discriminator's prediction for each generated image and the label it would give a real image.
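
The effect of that label of 1 is easiest to see in the loss itself. The sketch below assumes the combined model is trained with binary cross-entropy (an assumption here, since the compile step isn't shown in this section) and uses made-up discriminator outputs:

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip predictions so the logarithms stay finite.
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)))

# The labels passed to combined.train_on_batch: every generated image is marked real.
target = np.ones((4, 1))

# Made-up discriminator outputs for four generated images.
fooled = np.full((4, 1), 0.9)      # discriminator scores them as probably real
not_fooled = np.full((4, 1), 0.1)  # discriminator scores them as probably fake

loss_fooled = binary_crossentropy(target, fooled)
loss_not_fooled = binary_crossentropy(target, not_fooled)

print(loss_fooled, loss_not_fooled)
```

With a target of 1, the loss is small when the discriminator is fooled and large when it is not, so minimizing it pushes the generator toward images the discriminator scores as real.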

Lastly, the training loop reports the discriminator and generator loss for each epoch and batch, and then, every 50 batches of every epoch, we use save_imgs to generate example images and save them to disk, as shown in the following code:

print("Epoch %d Batch %d/%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
      (epoch, batch, num_batches, d_loss[0], 100 * d_loss[1], g_loss))

if batch % 50 == 0:
    save_imgs(generator, epoch, batch)
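
The indexing in the print statement works because d_loss holds two numbers. This assumes the discriminator was compiled with an accuracy metric, so that each train_on_batch call returns a [loss, accuracy] pair. The values below are made up:

```python
import numpy as np

# Made-up return values of the two discriminator updates: [loss, accuracy].
d_loss_real = [0.75, 0.5]
d_loss_fake = [0.25, 1.0]

# np.add works element by element, so losses and accuracies are averaged separately.
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

report = "[D loss: %f, acc.: %.2f%%]" % (d_loss[0], 100 * d_loss[1])
print(report)  # prints: [D loss: 0.500000, acc.: 75.00%]
```

The generator loss, by contrast, is formatted with a single %f, which only works if combined.train_on_batch returns a single number.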

The save_imgs function uses the generator to create images as we go, so we can see the fruits of our labor. We will use the following code to define save_imgs:

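import matplotlib.pyplot as plt

def save_imgs(generator, epoch, batch):
    r, c = 5, 5  # rows and columns of the image grid
    noise = np.random.normal(0, 1, (r * c, 100))
    gen_imgs = generator.predict(noise)
    # rescale pixel values linearly: -1 maps to 0 and 1 maps to 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    # the images/ directory must already exist
    fig.savefig("images/mnist_%d_%d.png" % (epoch, batch))
    plt.close(fig)  # release the figure so figures don't pile up during training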

It uses only the generator, creating a noise matrix and retrieving a matrix of images in return. Then, using matplotlib.pyplot, it saves those images to disk in a 5 x 5 grid.
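
The one line in save_imgs that isn't obvious is the rescaling. Assuming the generator's output lies in the range -1 to 1 (for example, from a tanh output layer, which is an assumption here), it maps pixel values into the range 0 to 1:

```python
import numpy as np

# Made-up pixel values spanning the assumed generator output range of -1 to 1.
gen_imgs = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])

# The same linear rescaling used in save_imgs.
rescaled = 0.5 * gen_imgs + 0.5

print(rescaled.tolist())  # prints: [0.0, 0.25, 0.5, 0.75, 1.0]
```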

