Initialization

Initialization in the CycleGAN implementation is more complex than in the previous implementations: we have to build a few more generators and discriminators to pass to the GAN class.

Let’s dive right in:

  1. Imports should be fairly obvious by now. These are the key modules and classes we import in this training class definition:
#!/usr/bin/env python3
from gan import GAN                      # custom class that assembles the adversarial model
from generator import Generator          # custom generator wrapper class
from discriminator import Discriminator  # custom discriminator wrapper class
from keras.layers import Input
from keras.datasets import mnist
from random import randint
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import os
from PIL import Image
  2. Define the Trainer class and its __init__ method with all of the input variables:
class Trainer:
    def __init__(self, height=64, width=64, epochs=50000,
                 batch=32, checkpoint=50,
                 train_data_path_A='', train_data_path_B='',
                 test_data_path_A='', test_data_path_B=''):
        # Training hyperparameters
        self.EPOCHS = epochs
        self.BATCH = batch
        self.RESIZE_HEIGHT = height
        self.RESIZE_WIDTH = width
        self.CHECKPOINT = checkpoint

We now have separate folders for domains A and B in both the training set and the test set. One of the interesting parts of this recipe is that we train the model on one dataset and demonstrate the generator's results on the test dataset. This helps ensure that we aren't simply overfitting to the relationship learned between the train_A and train_B datasets.

  3. Load all of the data into its respective class variables. We have some new helper functions that handle this:
        # Each call returns the image array plus its height, width, and channels
        self.X_train_A, self.H_A, self.W_A, self.C_A = self.load_data(
            train_data_path_A)
        self.X_train_B, self.H_B, self.W_B, self.C_B = self.load_data(
            train_data_path_B)
        self.X_test_A, self.H_A_test, self.W_A_test, self.C_A_test = self.load_data(
            test_data_path_A)
        self.X_test_B, self.H_B_test, self.W_B_test, self.C_B_test = self.load_data(
            test_data_path_B)

We'll cover the design of the load_data method later in this recipe. For now, just understand that it takes a string containing the path to a folder, reads every image with a given file extension in that folder, and returns the image data along with its height, width, and number of channels.
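The recipe defines its own load_data later, so the following is only a hypothetical stand-in that returns the same four values the calls above unpack. It is written as a standalone function rather than a Trainer method, and the default .png extension, the optional resize to (height, width), the RGB conversion, and the scaling of pixels to [-1, 1] are all assumptions rather than the author's choices.

```python
import os

import numpy as np
from PIL import Image


def load_data(path, extension=".png", size=None):
    """Hypothetical folder loader (not the recipe's own load_data).

    Reads every file in `path` whose name ends with `extension`, optionally
    resizes it to `size` = (height, width), and returns
    (images, height, width, channels) like the calls in Trainer.__init__.
    """
    files = sorted(f for f in os.listdir(path) if f.lower().endswith(extension))
    if not files:
        raise FileNotFoundError(f"no '{extension}' files found in {path!r}")

    images = []
    for name in files:
        with Image.open(os.path.join(path, name)) as img:
            img = img.convert("RGB")                  # force 3 channels
            if size is not None:
                img = img.resize((size[1], size[0]))  # PIL expects (width, height)
            images.append(np.asarray(img, dtype=np.float32))

    data = np.stack(images)            # shape: (N, H, W, C)
    data = (data - 127.5) / 127.5      # assumed scaling from [0, 255] to [-1, 1]
    _, height, width, channels = data.shape
    return data, height, width, channels
```

Inside the Trainer, a resize target like this would plausibly come from self.RESIZE_HEIGHT and self.RESIZE_WIDTH, but the text shown here does not confirm that.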

  4. We need the generators that go from A to B and from B to A. Instantiating these models is straightforward:
        self.generator_A_to_B = Generator(height=self.H_A, width=self.W_A,
                                          channels=self.C_A)
        self.generator_B_to_A = Generator(height=self.H_B, width=self.W_B,
                                          channels=self.C_B)
  5. Here's where we start to get serious. Add the following lines to the __init__ method of the training class:
        # Keras image inputs are (height, width, channels) in channels-last order
        self.orig_A = Input(shape=(self.H_A, self.W_A, self.C_A))
        self.orig_B = Input(shape=(self.H_B, self.W_B, self.C_B))

        # Translations into the other style
        self.fake_B = self.generator_A_to_B.Generator(self.orig_A)
        self.fake_A = self.generator_B_to_A.Generator(self.orig_B)
        # Round trips back to the original style
        self.reconstructed_A = self.generator_B_to_A.Generator(self.fake_B)
        self.reconstructed_B = self.generator_A_to_B.Generator(self.fake_A)
        # Identity mappings: an image already in the target style
        self.id_A = self.generator_B_to_A.Generator(self.orig_A)
        self.id_B = self.generator_A_to_B.Generator(self.orig_B)

This block contains a preliminary step followed by three distinct ideas:

  • First, the preliminary step: we define the original A and B images as Keras Input tensors. The variables orig_A and orig_B are the inputs shared among the next three components.
  • fake_A and fake_B are the outputs of the generators that translate an image from one style to the other. We call them fake because the generators synthesize them rather than drawing them from the dataset.
  • reconstructed_A and reconstructed_B take the fake A and B images and retranslate them into the original image style.
  • id_A and id_B are identity mappings: id_A passes an original A image through the B-to-A generator, which should leave an image that is already in style A unchanged, and id_B does the same for B images with the A-to-B generator. Ideally, these functions would not apply any style changes to these images.
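To make these mappings concrete, here is a small sketch that wires them up exactly as the block above does, but with plain NumPy functions standing in for the Keras generators. The helper names cycle_outputs and mae and the toy negate "generator" are hypothetical. The example also shows why the identity outputs are worth tracking separately: a pair of generators can reconstruct their inputs perfectly while still altering an image that is already in the target style.

```python
import numpy as np


def cycle_outputs(g_a_to_b, g_b_to_a, orig_a, orig_b):
    """Mirror the Trainer wiring with any callables that map an image batch
    to an image batch of the same shape (toy stand-ins for the Keras models)."""
    fake_b = g_a_to_b(orig_a)                  # A translated into style B
    fake_a = g_b_to_a(orig_b)                  # B translated into style A
    return {
        "fake_A": fake_a,
        "fake_B": fake_b,
        "reconstructed_A": g_b_to_a(fake_b),   # A -> B -> A
        "reconstructed_B": g_a_to_b(fake_a),   # B -> A -> B
        "id_A": g_b_to_a(orig_a),              # A through the B->A generator
        "id_B": g_a_to_b(orig_b),              # B through the A->B generator
    }


def mae(x, y):
    """Mean absolute (L1) difference between two arrays."""
    return float(np.mean(np.abs(x - y)))


def negate(x):
    """Toy 'generator': flips pixel values in [-1, 1]; it is its own inverse."""
    return -x


rng = np.random.default_rng(0)
orig_A = rng.uniform(-1, 1, size=(2, 4, 4, 3))
orig_B = rng.uniform(-1, 1, size=(2, 4, 4, 3))
out = cycle_outputs(negate, negate, orig_A, orig_B)

print(mae(out["reconstructed_A"], orig_A))   # 0.0: the cycle is perfect
print(mae(out["id_A"], orig_A) > 0)          # True: identity is not preserved
```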

You now have the key generator pieces needed to construct the GAN.

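There's more, though! We need discriminators that evaluate both A and B images, and we need validity outputs that score the fake_A and fake_B images. Both discriminators are frozen before they are used here, so that training the combined model updates only the generators: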

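        self.discriminator_A = Discriminator(height=self.H_A, width=self.W_A,
                                             channels=self.C_A)
        self.discriminator_B = Discriminator(height=self.H_B, width=self.W_B,
                                             channels=self.C_B)
        # Freeze the underlying Keras models (the .Discriminator attribute);
        # setting trainable on the Python wrapper alone would not affect Keras
        self.discriminator_A.Discriminator.trainable = False
        self.discriminator_B.Discriminator.trainable = False
        # Validity scores for the translated (fake) images
        self.valid_A = self.discriminator_A.Discriminator(self.fake_A)
        self.valid_B = self.discriminator_B.Discriminator(self.fake_B)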

Here's where a bit of the magic happens: because we set up our classes in a structured way, we can simply pass the model inputs and outputs to the GAN class, and it will construct our adversarial model:

        model_inputs = [self.orig_A, self.orig_B]
        model_outputs = [self.valid_A, self.valid_B,
                         self.reconstructed_A, self.reconstructed_B,
                         self.id_A, self.id_B]
        self.gan = GAN(model_inputs=model_inputs, model_outputs=model_outputs,
                       lambda_cycle=10.0, lambda_id=1.0)
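The GAN class itself is not shown in this part of the recipe. As a minimal sketch of what lambda_cycle and lambda_id plausibly control, assume the class follows a common Keras CycleGAN pattern: a least-squares (MSE) adversarial loss on the two validity outputs, an L1 (MAE) loss on the four image outputs, and a weighted sum over the six outputs in model_outputs order. The function names, loss choices, and output shapes below are assumptions, not the author's code. The validity targets are all ones because the generators want the discriminators to judge their fakes as real, and the image targets are the original images.

```python
import numpy as np


def mse(pred, target):
    """Mean squared error (least-squares adversarial loss)."""
    return float(np.mean((pred - target) ** 2))


def mae(pred, target):
    """Mean absolute error (L1 loss)."""
    return float(np.mean(np.abs(pred - target)))


def combined_generator_loss(outputs, targets, lambda_cycle=10.0, lambda_id=1.0):
    """Hypothetical weighted loss over the six outputs, in model_outputs order:
    [valid_A, valid_B, reconstructed_A, reconstructed_B, id_A, id_B]."""
    loss_fns = [mse, mse, mae, mae, mae, mae]
    weights = [1.0, 1.0, lambda_cycle, lambda_cycle, lambda_id, lambda_id]
    return sum(w * fn(p, t)
               for w, fn, p, t in zip(weights, loss_fns, outputs, targets))


# Toy batch: 2 images of 8x8x3 and 4x4x1 discriminator outputs (assumed shapes).
orig_A = np.zeros((2, 8, 8, 3))
orig_B = np.zeros((2, 8, 8, 3))
ones = np.ones((2, 4, 4, 1))                  # "real" label for every patch

targets = [ones, ones, orig_A, orig_B, orig_A, orig_B]
outputs = [
    np.full((2, 4, 4, 1), 0.5),               # valid_A: D_A unsure about fake_A
    np.full((2, 4, 4, 1), 0.5),               # valid_B: D_B unsure about fake_B
    np.full((2, 8, 8, 3), 0.25),              # reconstructed_A: off by 0.25
    orig_B.copy(),                            # reconstructed_B: perfect
    np.full((2, 8, 8, 3), 0.25),              # id_A: off by 0.25
    orig_B.copy(),                            # id_B: perfect
]

print(combined_generator_loss(outputs, targets))  # 3.25
```

Under these assumptions, the same 0.25 pixel error costs 2.5 in reconstructed_A but only 0.25 in id_A, so with lambda_cycle=10.0 and lambda_id=1.0 cycle consistency dominates the generators' objective.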
