Creating and compiling an adversarial network

The adversarial network is a combined network that uses all four networks (the two generators and the two discriminators) in a single Keras model. Its main purpose is to train the generator networks: when we train the adversarial network, only the generators' weights are updated, while the discriminators' weights stay frozen. Let's create an adversarial model with this functionality.

  1. Start by creating two input layers for the network, as follows:
from keras.layers import Input

inputA = Input(shape=(128, 128, 3))
inputB = Input(shape=(128, 128, 3))

Both inputs accept images of shape (128, 128, 3). These are symbolic input tensors and don't hold actual values; they are used to build a Keras model (a TensorFlow graph).

  2. Next, use the generator networks to generate fake images, as follows:
generatedB = generatorAToB(inputA)
generatedA = generatorBToA(inputB)

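Here, the symbolic input layers are passed through the generators: inputA is translated into a fake domain-B image (generatedB), and inputB into a fake domain-A image (generatedA).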

  3. Now, reconstruct the original images by passing the generated images through the opposite generator networks, as follows:
reconstructedA = generatorBToA(generatedB)
reconstructedB = generatorAToB(generatedA)
  4. Use the generator networks to generate identity images, as follows:
generatedAId = generatorBToA(inputA)
generatedBId = generatorAToB(inputB)

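The generator network A (generatorAToB) translates an image from domain A to domain B. Similarly, the generator network B (generatorBToA) translates an image from domain B to domain A. The identity outputs feed each generator an image that is already in its target domain (for example, a domain-A image into generatorBToA); a well-behaved generator should return such an image largely unchanged.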

  5. Next, make both of the discriminator networks non-trainable, as follows:
discriminatorA.trainable = False
discriminatorB.trainable = False

We don't want the adversarial network to update the discriminators' weights; only the generators should learn here.

  6. Use the discriminator networks to predict whether each generated image is real or fake, as follows:
probsA = discriminatorA(generatedA)
probsB = discriminatorB(generatedB)
  7. Create a Keras model and specify the inputs and outputs for the network, as follows:
from keras.models import Model
adversarial_model = Model(inputs=[inputA, inputB], outputs=[probsA, probsB, reconstructedA, reconstructedB, generatedAId, generatedBId])

Our adversarial network takes two input tensors and returns six output tensors: two discriminator predictions, two reconstructed images, and two identity images.

  8. Next, compile the adversarial network, as follows:
adversarial_model.compile(loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
                          loss_weights=[1, 1, 10.0, 10.0, 1.0, 1.0],
                          optimizer=common_optimizer)

The adversarial network returns six values, so we need to specify a loss function for each output. For the first two outputs (the discriminator predictions), we use mean squared error loss, which forms the adversarial loss. For the next two outputs (the reconstructed images), we use mean absolute error loss, which forms the cycle-consistency loss. The last two outputs (the identity images) also use mean absolute error loss, which forms the identity loss. The weights for the six losses are 1, 1, 10.0, 10.0, 1.0, and 1.0, respectively, so each cycle-consistency term carries ten times the weight of the other terms. We use common_optimizer to train the network.
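As a rough sketch of what this compile() call asks Keras to minimize, the NumPy code below computes each per-output loss and combines them with the same loss_weights into a single weighted sum. It assumes the targets are the real label for the two discriminator outputs and the original input images for the four reconstruction and identity outputs; the helper names (mse, mae, weighted_total_loss) and the tiny constant arrays are hypothetical, and Keras's own internal reduction may differ in detail.

```python
import numpy as np

def mse(y_true, y_pred):
    # Mean squared error over all elements
    return float(np.mean((y_true - y_pred) ** 2))

def mae(y_true, y_pred):
    # Mean absolute error over all elements
    return float(np.mean(np.abs(y_true - y_pred)))

# Same order as the outputs of adversarial_model:
# [probsA, probsB, reconstructedA, reconstructedB, generatedAId, generatedBId]
LOSS_FNS = [mse, mse, mae, mae, mae, mae]
LOSS_WEIGHTS = [1.0, 1.0, 10.0, 10.0, 1.0, 1.0]

def weighted_total_loss(targets, outputs):
    """Return (total, per_output), where total = sum(weight_i * loss_i)."""
    per_output = [fn(t, o) for fn, t, o in zip(LOSS_FNS, targets, outputs)]
    total = sum(w * l for w, l in zip(LOSS_WEIGHTS, per_output))
    return total, per_output

# Toy example (hypothetical values): batch of 2, small 4x4 RGB "images"
batch_size = 2
real_labels = np.ones((batch_size, 7, 7, 1))
batchA = np.zeros((batch_size, 4, 4, 3))
batchB = np.zeros((batch_size, 4, 4, 3))

outputs = [
    np.full((batch_size, 7, 7, 1), 0.5),  # probsA: discriminator is unsure
    np.full((batch_size, 7, 7, 1), 0.5),  # probsB
    np.full((batch_size, 4, 4, 3), 0.1),  # reconstructedA, off by 0.1 everywhere
    np.full((batch_size, 4, 4, 3), 0.1),  # reconstructedB
    np.full((batch_size, 4, 4, 3), 0.2),  # generatedAId, off by 0.2 everywhere
    np.full((batch_size, 4, 4, 3), 0.2),  # generatedBId
]
# Adversarial outputs target the real label; the other four target the input images
targets = [real_labels, real_labels, batchA, batchB, batchA, batchB]

total, per_output = weighted_total_loss(targets, outputs)
print("per-output losses:", [round(l, 3) for l in per_output])
print("weighted total:", round(total, 3))
```

With these toy values the weighted total is about 2.9, of which 2.0 comes from the two cycle-consistency terms, which shows how strongly the 10.0 weights steer training toward faithful reconstructions.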

We have now successfully created a Keras model for the adversarial network. If you have difficulty understanding how a Keras model works, refer to the Keras documentation on functional models and the TensorFlow documentation on computational graphs.
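To make the data flow concrete, here is a toy NumPy sketch of the same generator wiring. The two generator functions are hypothetical stand-ins (exact inverses of each other that shift pixel values by a constant), not real networks, and the discriminators are left out. With such a pair, reconstruction is perfect, so the cycle-consistency error is essentially zero, yet the identity error is not, which illustrates why the identity outputs are tracked as a separate term.

```python
import numpy as np

# Hypothetical stand-in "generators": domain B is modeled as domain A shifted by +0.2
SHIFT = 0.2

def generatorAToB(images):
    return images + SHIFT

def generatorBToA(images):
    return images - SHIFT

rng = np.random.default_rng(0)
inputA = rng.uniform(0.0, 1.0, size=(2, 128, 128, 3))  # toy batch of domain-A images
inputB = rng.uniform(0.0, 1.0, size=(2, 128, 128, 3))  # toy batch of domain-B images

# Same wiring as adversarial_model (minus the discriminators)
generatedB = generatorAToB(inputA)
generatedA = generatorBToA(inputB)
reconstructedA = generatorBToA(generatedB)
reconstructedB = generatorAToB(generatedA)
generatedAId = generatorBToA(inputA)
generatedBId = generatorAToB(inputB)

# Cycle-consistency error: A -> B -> A should return the original image
cycle_error = np.mean(np.abs(reconstructedA - inputA))
# Identity error: a domain-A image fed to generatorBToA should come back unchanged
identity_error = np.mean(np.abs(generatedAId - inputA))
print(f"cycle MAE: {cycle_error:.3f}, identity MAE: {identity_error:.3f}")
```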

Before starting the training, perform the following two essential steps; TensorBoard will be used in later sections:

Add TensorBoard to store the losses and the graphs for visualization purposes, as follows:

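import time
from keras.callbacks import TensorBoard

# write_grads is only accepted by older Keras releases; newer versions deprecate or reject it
tensorboard = TensorBoard(log_dir="logs/{}".format(time.time()), write_images=True, write_grads=True,
                          write_graph=True)
tensorboard.set_model(generatorAToB)
tensorboard.set_model(generatorBToA)
tensorboard.set_model(discriminatorA)
tensorboard.set_model(discriminatorB)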

Create a four-dimensional array with all values equal to one, which represents the real label, and another four-dimensional array with all values equal to zero, which represents the fake label. Their shape, (batch_size, 7, 7, 1), must match the output shape of the discriminator networks, as follows:

import numpy as np

real_labels = np.ones((batch_size, 7, 7, 1))
fake_labels = np.zeros((batch_size, 7, 7, 1))

Use numpy's ones() and zeros() functions to create the desired ndarrays. Now that we have the essential components ready, let's start the training.
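For illustration, the sketch below shows how such labels are typically used with mean squared error (a least-squares GAN pattern, assumed here rather than taken from this chapter's training code). The batch size and the random discriminator scores are placeholder assumptions. The discriminator is pushed to score generated images as fake (0), while the adversarial model, whose first two outputs are discriminator predictions, pushes the generators to make those same scores real (1).

```python
import numpy as np

batch_size = 4  # assumed batch size for this illustration
real_labels = np.ones((batch_size, 7, 7, 1))
fake_labels = np.zeros((batch_size, 7, 7, 1))

# Placeholder for a discriminator's output on a batch of generated images:
# one score in [0, 1) per cell of a 7x7 grid, matching the label shape
rng = np.random.default_rng(42)
d_out_fake = rng.uniform(0.0, 1.0, size=(batch_size, 7, 7, 1))

# Discriminator objective on generated images: push scores toward 0 (fake)
d_loss_fake = np.mean((d_out_fake - fake_labels) ** 2)

# Generator objective through the frozen discriminator: push the same scores toward 1 (real)
g_loss_adv = np.mean((d_out_fake - real_labels) ** 2)

print(f"discriminator loss on fakes: {d_loss_fake:.3f}")
print(f"generator adversarial loss: {g_loss_adv:.3f}")
```

Because both objectives read the same scores against opposite targets, lowering one tends to raise the other, which is the tug-of-war that adversarial training relies on.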
