Training the complete network

We have looked at the individual pieces of how a GAN is trained. Let's summarize them and then look at the complete code used to train the GAN we created:

  • Train the discriminator network with real images
  • Train the discriminator network with fake images
  • Optimize the discriminator
  • Train the generator based on the discriminator feedback
  • Optimize the generator network alone
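The losses behind these steps can be worked through with plain numbers. The following is only a minimal sketch: it assumes the criterion is binary cross-entropy and that the real and fake labels are 1 and 0, and the discriminator outputs d_real and d_fake are made-up values chosen for illustration.

```python
import math


def bce(prediction, target):
    """Binary cross-entropy for one probability and one 0/1 target."""
    return -(target * math.log(prediction)
             + (1 - target) * math.log(1 - prediction))


REAL_LABEL, FAKE_LABEL = 1.0, 0.0  # assumed label values

d_real = 0.9  # hypothetical D(x): discriminator output for a real image
d_fake = 0.2  # hypothetical D(G(z)): discriminator output for a fake image

# Discriminator steps: real images against the real label,
# fake images against the fake label, then the two losses are summed.
err_d_real = bce(d_real, REAL_LABEL)  # equals -log(D(x))
err_d_fake = bce(d_fake, FAKE_LABEL)  # equals -log(1 - D(G(z)))
err_d = err_d_real + err_d_fake

# Generator steps: the same fake output is scored against the REAL label,
# so the loss falls as the discriminator is fooled more often.
err_g = bce(d_fake, REAL_LABEL)  # equals -log(D(G(z)))

print("Loss_D: %.4f  Loss_G: %.4f" % (err_d, err_g))
```

Minimizing err_d is the same as maximizing log(D(x)) + log(1 - D(G(z))), and minimizing err_g is the same as maximizing log(D(G(z))).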

We will use the following code to train the network:

for epoch in range(niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real, _ = data
        batch_size = real.size(0)
        if torch.cuda.is_available():
            real = real.cuda()
        input.resize_as_(real).copy_(real)
        label.resize_(batch_size).fill_(real_label)
        inputv = Variable(input)
        labelv = Variable(label)

        output = netD(inputv)
        errD_real = criterion(output, labelv)
        errD_real.backward()
        D_x = output.data.mean()

        # train with fake
        noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
        noisev = Variable(noise)
        fake = netG(noisev)
        labelv = Variable(label.fill_(fake_label))
        # detach() keeps the discriminator's gradients out of the generator
        output = netD(fake.detach())
        errD_fake = criterion(output, labelv)
        errD_fake.backward()
        D_G_z1 = output.data.mean()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        labelv = Variable(label.fill_(real_label))  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, labelv)
        errG.backward()
        D_G_z2 = output.data.mean()
        optimizerG.step()

        # .item() needs PyTorch 0.4 or later; on older versions use .data[0]
        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            # save the current real batch and samples generated from fixed_noise
            vutils.save_image(real,
                              '%s/real_samples.png' % outf,
                              normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.data,
                              '%s/fake_samples_epoch_%03d.png' % (outf, epoch),
                              normalize=True)

The vutils.save_image function takes a tensor and saves it as an image. If it is given a mini-batch of images, it saves them as a single grid of images.
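To get a feel for how large that grid is, the layout arithmetic can be sketched in plain Python. This is an illustration only: grid_size is a hypothetical helper and not part of torchvision, and it assumes the default arguments of 8 images per row (nrow=8) with 2 pixels of padding (padding=2).

```python
import math


def grid_size(num_images, height, width, nrow=8, padding=2):
    """Return (rows, cols, grid_height, grid_width) for a tiled mini-batch.

    Hypothetical helper: assumes nrow images per row and `padding` pixels
    around every tile.
    """
    cols = min(nrow, num_images)
    rows = math.ceil(num_images / cols)
    grid_height = (height + padding) * rows + padding
    grid_width = (width + padding) * cols + padding
    return rows, cols, grid_height, grid_width


# A mini-batch of 64 images, each 64 x 64 pixels
print(grid_size(64, 64, 64))  # (8, 8, 530, 530)
```

Under these assumptions, a batch of 64 images is laid out as 8 rows of 8 tiles in one image file.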

In the following sections, we will take a look at what the generated images and the real images look like. 
