The training method

 The training method is the key method in this class: it defines how we train each of the models and how the data flows between them. The following steps will walk you through how we train this architecture:

  1. We're going to define the train method and start out by defining the number of real and generated images per batch (half of the batch size each). From there, we'll loop through the number of epochs and, within each epoch, through every label:
def train(self):
    count_generated_images = int(self.BATCH/2)
    count_real_images = int(self.BATCH/2)
    for e in range(self.EPOCHS):
        for label in self.LABELS:
  2. We need to grab the 3D samples that match the label we defined in the instantiation of the class:
            # Grab the Real 3D Samples
            all_3D_samples = self.X_train_3D[np.where(self.Y_train_3D == label)]
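The selection above is plain NumPy boolean indexing; here is a minimal, self-contained sketch with made-up arrays standing in for `self.X_train_3D` and `self.Y_train_3D`:

```python
import numpy as np

# Toy stand-ins for the training set: four 2x2x2 "voxel" samples, two per label
X = np.arange(4 * 2 * 2 * 2).reshape(4, 2, 2, 2)
y = np.array([0, 1, 0, 1])

# Keep only the samples whose label matches
label = 0
samples_for_label = X[np.where(y == label)]
print(samples_for_label.shape)  # (2, 2, 2, 2): two samples survived the mask
```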
  3. After we have all of the 3D samples that match our label, let's pick a random starting index that still leaves room for a full half-batch of samples after it:
            starting_index = randint(0, (len(all_3D_samples) - count_real_images))
  4. Once we have the random starting index, we grab the next count_real_images samples beginning at that index:
            real_3D_samples = all_3D_samples[starting_index : starting_index + count_real_images]
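The randint-plus-slice pattern simply takes a contiguous window of samples at a random offset; a small sketch, with an invented pool and window size (note that Python's `random.randint` is inclusive on both ends, so the last valid start still leaves a full window):

```python
import numpy as np
from random import randint

pool = np.arange(100)   # stand-in for all_3D_samples
batch_size = 16         # stand-in for count_real_images

# Any start in [0, len(pool) - batch_size] leaves room for a full window
start = randint(0, len(pool) - batch_size)
window = pool[start : start + batch_size]
print(window.shape)  # (16,) -- always a full window, never a short slice
```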
  5. Create y_real_labels as an array of ones, one per real sample (a one means real to the discriminator):
            y_real_labels = np.ones([count_real_images, 1])
  6. Repeat the same selection process for the encoded samples:
            # Grab encoded samples for this training batch
            all_encoded_samples = self.X_train_2D_encoded[np.where(self.Y_train_2D == label)]
            starting_index = randint(0, (len(all_encoded_samples) - count_generated_images))
            batch_encoded_samples = all_encoded_samples[starting_index : starting_index + count_generated_images]
  7. Reshape the encoded samples to match the shape of the model input:
            batch_encoded_samples = batch_encoded_samples.reshape(count_generated_images, 1, 1, 1, self.LATENT_SPACE_SIZE)
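The reshape just adds the three singleton spatial dimensions that a volumetric generator's input layer expects; a quick sketch, with an assumed latent size of 100 and a made-up half-batch count:

```python
import numpy as np

latent_size = 100                      # assumed LATENT_SPACE_SIZE
count = 8                              # assumed half-batch of encoded samples
flat = np.zeros((count, latent_size))  # stand-in for batch_encoded_samples

# (8, 100) -> (8, 1, 1, 1, 100): one latent vector per sample,
# shaped as a 1x1x1 "volume" with latent_size channels
shaped = flat.reshape(count, 1, 1, 1, latent_size)
print(shaped.shape)  # (8, 1, 1, 1, 100)
```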
  8. Generate samples using the generator, and generate labels of zero for those samples to train the discriminator:
            x_generated_3D_samples = self.generator.Generator.predict(batch_encoded_samples)
            y_generated_labels = np.zeros([count_generated_images, 1])
  9. Combine those datasets (real and fake) for training the discriminator:
            # Combine to train on the discriminator
            x_batch = np.concatenate([real_3D_samples, x_generated_3D_samples])
            y_batch = np.concatenate([y_real_labels, y_generated_labels])
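The concatenation stacks real and generated samples along the first (sample) axis, and the label vectors line up with them element for element; a minimal sketch with invented half-batch shapes:

```python
import numpy as np

real = np.ones((8, 16, 16, 16, 1))    # pretend real half-batch of voxel grids
fake = np.zeros((8, 16, 16, 16, 1))   # pretend generated half-batch
y_real = np.ones([8, 1])              # ones: real
y_fake = np.zeros([8, 1])             # zeros: generated

x_batch = np.concatenate([real, fake])      # (16, 16, 16, 16, 1)
y_batch = np.concatenate([y_real, y_fake])  # (16, 1)
print(x_batch.shape, y_batch.shape)
```

Because both arrays are concatenated in the same order, row i of y_batch is always the label for sample i of x_batch.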
  10. Using the newly concatenated data, train the discriminator. We unfreeze the discriminator for this update and freeze it again afterward, so that the subsequent GAN update trains only the generator:
            # Now, train the discriminator with this batch
            self.discriminator.Discriminator.trainable = True
            discriminator_loss = self.discriminator.Discriminator.train_on_batch(x_batch, y_batch)[0]
            self.discriminator.Discriminator.trainable = False

  11. Just as we did previously, use that selection method to grab random indexes into the encoded samples and create the GAN training data. The labels are all ones because we want the generator to be rewarded for fooling the discriminator:
            # Grab a full batch of encoded samples for the generator
            starting_index = randint(0, (len(all_encoded_samples) - self.BATCH))
            x_batch_encoded_samples = all_encoded_samples[starting_index : starting_index + self.BATCH]
            x_batch_encoded_samples = x_batch_encoded_samples.reshape(int(self.BATCH), 1, 1, 1, self.LATENT_SPACE_SIZE)
            y_generated_labels = np.ones([self.BATCH, 1])
  12. The generator is trained on the encoded samples and the generated labels:
            generator_loss = self.gan.gan_model.train_on_batch(x_batch_encoded_samples, y_generated_labels)
  13. In this step, we're printing the loss for each of our models, along with the epoch number and the label that we're training on:
            print('Epoch: ' + str(int(e)) + ' Label: ' + str(int(label)) +
                  ', [Discriminator :: Loss: ' + str(discriminator_loss) + ']' +
                  ', [Generator :: Loss: ' + str(generator_loss) + ']')
  14. The final step is to checkpoint the model at our configured checkpoint interval:
            if e % self.CHECKPOINT == 0 and e != 0:
                self.plot_checkpoint(e, label)
    return
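Before moving on, it can help to see the steps above as one loop. The control flow of train can be mocked end to end with NumPy stubs in place of the Keras models; everything here (the class name, sizes, labels, and the zero-valued stub outputs and losses) is invented purely to illustrate the structure:

```python
import numpy as np
from random import randint

class MockTrainer:
    """Structural sketch of the train loop; stubs replace the real models."""
    def __init__(self):
        self.BATCH, self.EPOCHS, self.LABELS = 8, 2, [0, 1]
        self.LATENT_SPACE_SIZE = 100
        n = 64
        self.X_train_3D = np.zeros((n, 4, 4, 4, 1))
        self.Y_train_3D = np.tile([0, 1], n // 2)
        self.X_train_2D_encoded = np.zeros((n, self.LATENT_SPACE_SIZE))
        self.Y_train_2D = self.Y_train_3D.copy()
        self.losses = []

    def train(self):
        half = int(self.BATCH / 2)
        for e in range(self.EPOCHS):
            for label in self.LABELS:
                # Steps 2-4: random half-batch of real samples for this label
                real = self.X_train_3D[np.where(self.Y_train_3D == label)]
                i = randint(0, len(real) - half)
                real_batch = real[i : i + half]
                # Steps 6-7: random half-batch of encoded samples, reshaped
                enc = self.X_train_2D_encoded[np.where(self.Y_train_2D == label)]
                j = randint(0, len(enc) - half)
                z = enc[j : j + half].reshape(half, 1, 1, 1, self.LATENT_SPACE_SIZE)
                # Step 8 stub: a real generator would predict from z here
                fake_batch = np.zeros_like(real_batch)
                # Step 9: combined discriminator batch
                x = np.concatenate([real_batch, fake_batch])
                y = np.concatenate([np.ones([half, 1]), np.zeros([half, 1])])
                assert x.shape[0] == self.BATCH and y.shape[0] == self.BATCH
                # Steps 10 and 12 stub: record placeholder D and G losses
                self.losses.append((0.0, 0.0))

trainer = MockTrainer()
trainer.train()
print(len(trainer.losses))  # 2 epochs x 2 labels = 4 iterations
```

The point of the stub is only the shape of the loop: two selections per label per epoch, one discriminator update on the combined batch, then one generator update.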

Let's understand how to build the internal plot_checkpoint method in this next set of steps!
