We have looked at the individual pieces of GAN training. Let's summarize them as follows and then look at the complete code we will use to train the GAN we created:
- Train the discriminator on real images
- Train the discriminator on fake images
- Update the discriminator's weights
- Train the generator using the discriminator's feedback
- Update only the generator's weights
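Steps 1, 2, and 4 compute their losses with the same binary cross-entropy criterion and differ only in the target label. As a minimal sketch, assuming the criterion is nn.BCELoss with its default mean reduction and that real and fake images are labeled 1 and 0, the check below confirms that a target of 1 gives -log(p) and a target of 0 gives -log(1 - p). Summing the discriminator's two losses therefore minimizes -[log(D(x)) + log(1 - D(G(z)))], while labeling fakes as real for the generator minimizes -log(D(G(z))). The scores d_real and d_fake are made-up values for illustration.

```python
import math

import torch
import torch.nn as nn

# Assumption: the training criterion is nn.BCELoss, which expects probabilities in (0, 1).
criterion = nn.BCELoss()

d_real = torch.tensor([0.9, 0.7])  # hypothetical D(x) scores for two real images
d_fake = torch.tensor([0.2, 0.4])  # hypothetical D(G(z)) scores for two fake images
ones = torch.ones(2)               # real label = 1
zeros = torch.zeros(2)             # fake label = 0

# Discriminator loss: real images labeled 1, fake images labeled 0
err_real = criterion(d_real, ones)
err_fake = criterion(d_fake, zeros)
expected_real = -(math.log(0.9) + math.log(0.7)) / 2
expected_fake = -(math.log(1 - 0.2) + math.log(1 - 0.4)) / 2

# Generator loss: the same fake scores, but labeled as real
err_g = criterion(d_fake, ones)
expected_g = -(math.log(0.2) + math.log(0.4)) / 2

print(f'D real: {err_real.item():.4f}  D fake: {err_fake.item():.4f}  G: {err_g.item():.4f}')
```

Because the generator is pushed to raise D(G(z)) directly, its loss stays large when the discriminator confidently rejects fakes, which is why the loop flips the label instead of minimizing log(1 - D(G(z))).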
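We will use the following code to train the network. It relies on the networks, loss function, optimizers, data loader, and hyperparameters defined earlier: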
```python
import torch
import torchvision.utils as vutils

# Assumes netD, netG, criterion, optimizerD, optimizerG, dataloader, niter, nz, outf,
# real_label, fake_label and fixed_noise were defined earlier, with both networks
# and fixed_noise already moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for epoch in range(niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ############################
        # train with real
        netD.zero_grad()
        real, _ = data
        batch_size = real.size(0)
        real = real.to(device)
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        output = netD(real)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach() stops this loss from back-propagating into the generator
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ############################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()  # only the generator's weights are updated here

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            # save the current real batch and samples generated from the fixed noise
            vutils.save_image(real,
                              '%s/real_samples.png' % outf,
                              normalize=True)
            with torch.no_grad():
                fake = netG(fixed_noise)
            vutils.save_image(fake,
                              '%s/fake_samples_epoch_%03d.png' % (outf, epoch),
                              normalize=True)
```
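Two details of this loop are easy to miss: fake.detach() keeps the discriminator's loss from back-propagating into the generator, and although errG.backward() sends gradients through netD, only optimizerG.step() is called, so the discriminator's weights do not change during the generator step (its leftover gradients are cleared by netD.zero_grad() at the start of the next iteration). The self-contained sketch below checks both claims on hypothetical toy networks, a linear generator and a one-layer sigmoid discriminator standing in for netG and netD, with random made-up data. It illustrates the update pattern only; it is not the chapter's DCGAN.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
nz = 4
# Hypothetical toy networks standing in for netG and netD
netG = nn.Linear(nz, 2)
netD = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizerD = optim.SGD(netD.parameters(), lr=0.1)
optimizerG = optim.SGD(netG.parameters(), lr=0.1)

real = torch.randn(8, 2)  # hypothetical "real" batch of 2-D points
real_label, fake_label = 1.0, 0.0


def snapshot(net):
    """Copy a network's parameters so they can be compared later."""
    return [p.detach().clone() for p in net.parameters()]


g_before = snapshot(netG)

# (1) Discriminator step: real batch, then a fake batch passed through detach()
netD.zero_grad()
label = torch.full((8,), real_label)
errD_real = criterion(netD(real).squeeze(1), label)
errD_real.backward()
fake = netG(torch.randn(8, nz))
label.fill_(fake_label)
errD_fake = criterion(netD(fake.detach()).squeeze(1), label)
errD_fake.backward()
optimizerD.step()

# Because of detach(), no gradient has reached the generator yet
g_grads_after_d = [p.grad for p in netG.parameters()]
d_before = snapshot(netD)

# (2) Generator step: gradients flow back through netD, but only optimizerG steps
netG.zero_grad()
label.fill_(real_label)
errG = criterion(netD(fake).squeeze(1), label)
errG.backward()
optimizerG.step()

d_after = snapshot(netD)
g_after = snapshot(netG)
print(f'Loss_D: {(errD_real + errD_fake).item():.4f} Loss_G: {errG.item():.4f}')
```

Without the detach() call, errD_fake.backward() would also fill the generator's gradients; they would then be discarded by netG.zero_grad(), so the result would be the same but with wasted computation.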
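The vutils.save_image function takes a tensor and saves it as an image file. Given a mini-batch of images, it arranges them into a grid and saves the grid as a single image. With normalize=True, it first rescales the pixel values to the range [0, 1] using the tensor's minimum and maximum.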
In the following sections, we will look at the real images and the images produced by the generator.