Initialization

Take the following steps to initialize the training class and the basic functionality needed to train the models:

  1. Create a train.py file and place the following imports at the top of the file:
#!/usr/bin/env python3
from gan import GAN
from generator import Generator
from discriminator import Discriminator
from keras.datasets import mnist
from keras.layers import Input
from random import randint
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from copy import deepcopy
  1. Create the top-level trainer class with the initialization step, as follows:
class Trainer:
    def __init__(self, height=55, width=35, channels=1, epochs=100, batch=16,
                 checkpoint=50, sim_path='', real_path='', data_limit=0.001,
                 generator_steps=2, discriminator_steps=1):

  1. Initialize all of the internal variables for the training script, as follows:
        self.W = width
        self.H = height
        self.C = channels
        self.EPOCHS = epochs
        self.BATCH = batch
        self.CHECKPOINT = checkpoint
        self.DATA_LIMIT = data_limit
        self.GEN_STEPS = generator_steps
        self.DISC_STEPS = discriminator_steps
  1. Load the real and simulated data sets, as follows:
        self.X_real = self.load_h5py(real_path)
        self.X_sim = self.load_h5py(sim_path)
  1. Build the two networks that SimGAN needs, the refiner (generator) and the discriminator, as follows:
        self.refiner = Generator(height=self.H, width=self.W, channels=self.C)
        self.discriminator = Discriminator(height=self.H, width=self.W, channels=self.C)
        self.discriminator.trainable = False
  1. Create the following inputs for the models:
        self.synthetic_image = Input(shape=(self.H, self.W, self.C))
        self.real_or_fake = Input(shape=(self.H, self.W, self.C))
  1. Connect each model to its input, as follows:
        self.refined_image = self.refiner.Generator(self.synthetic_image)
        self.discriminator_output = self.discriminator.Discriminator(self.real_or_fake)
        self.combined = self.discriminator.Discriminator(self.refined_image)
  1. Create the adversarial model with the inputs and outputs you just created, as follows:
        model_inputs = [self.synthetic_image]
        model_outputs = [self.refined_image, self.combined]
        self.gan = GAN(model_inputs=model_inputs, model_outputs=model_outputs)
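
The data flow set up in the last three steps can be easier to follow without Keras in the way. The sketch below is only an illustration of that flow: the refiner and discriminator here are hypothetical stand-in functions operating on nested lists (they are not the book's Generator and Discriminator classes, and their internals are assumptions). It shows that the adversarial model takes one input, the synthetic image, and produces two outputs, the refined image and the discriminator's score for that refined image.

```python
# Illustrative stand-ins only: plain functions over nested lists, not Keras models.

def refiner(image):
    """Hypothetical refiner: returns an image of the same shape.
    Here it simply clamps every pixel to the range [0, 1]."""
    return [[min(1.0, max(0.0, px)) for px in row] for row in image]


def discriminator(image):
    """Hypothetical discriminator: maps an image to a single score.
    Here the score is just the mean pixel value."""
    pixels = [px for row in image for px in row]
    return sum(pixels) / len(pixels)


def adversarial_model(synthetic_image):
    """Mirrors model_inputs = [synthetic_image] and
    model_outputs = [refined_image, combined]: one input, two outputs."""
    refined_image = refiner(synthetic_image)   # refiner applied to the synthetic input
    combined = discriminator(refined_image)    # discriminator applied to the refined image
    return refined_image, combined


synthetic = [[-0.5, 0.5], [1.5, 1.0]]
refined, score = adversarial_model(synthetic)
print(refined)  # [[0.0, 0.5], [1.0, 1.0]]
print(score)    # 0.625
```

The separate real_or_fake input plays no part in this chain. It feeds the discriminator directly, which is why the trainer keeps discriminator_output apart from combined.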