Helper methods

We have a few new helper methods. Instead of relying on separate data-loading code, we simply have a new method inside the training class. Also, now that we've developed CycleGAN, we need to be able to check both the style transfer and the reconstruction from that transferred style; the plotting function will do this for us.

Here are the steps:

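  1. Loading the data is fairly easy; it's just a rehash of our data loading from Chapter 4, Dreaming of New Outdoor Structures Using DCGAN:
    def load_data(self, data_path, amount_of_data=1.0):
        # Assumes module-level imports: os, numpy as np, matplotlib.pyplot as plt, and PIL's Image
        listOfFiles = self.grabListOfFiles(data_path, extension="jpg")
        X_train = np.array(self.grabArrayOfImages(listOfFiles))
        # Keep only the requested fraction of the dataset
        X_train = X_train[:int(amount_of_data*float(len(X_train)))]
        # Scale pixels from [0, 255] to [-1, 1]
        X_train = (X_train.astype(np.float32) - 127.5)/127.5
        # Grayscale images arrive as (N, H, W); add a channel axis only in that case
        if X_train.ndim == 3:
            X_train = np.expand_dims(X_train, axis=3)
        height, width, channels = np.shape(X_train[0])
        return X_train, height, width, channels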
  2. Grabbing the list of files is a straightforward method using the os module:
    def grabListOfFiles(self, startingDirectory, extension=".webp"):
        listOfFiles = []
        for file in os.listdir(startingDirectory):
            # Keep only files whose names end with the requested extension
            if file.endswith(extension):
                listOfFiles.append(os.path.join(startingDirectory, file))
        return listOfFiles
  3. Coin flipping is carried over from Chapter 4, Dreaming of New Outdoor Structures Using DCGAN:
    def flipCoin(self, chance=0.5):
        # One Bernoulli trial: returns 1 with probability `chance`, otherwise 0
        return np.random.binomial(1, chance)
  4. Image importing is merged in from the separate Python script introduced in Chapter 4, Dreaming of New Outdoor Structures Using DCGAN:
    def grabArrayOfImages(self, listOfFiles, gray=False):
        imageArr = []
        for f in listOfFiles:
            if gray:
                im = Image.open(f).convert("L")
            else:
                im = Image.open(f).convert("RGB")
            # PIL's resize takes (width, height); np.asarray returns (height, width[, channels])
            im = im.resize((self.RESIZE_WIDTH, self.RESIZE_HEIGHT))
            imData = np.asarray(im)
            imageArr.append(imData)
        return imageArr
  5. At each checkpoint, take an example from the test set, transfer its style from A to B, and then translate it back to A:
    def plot_checkpoint(self, b):
        # Use the same test image at every checkpoint so results are comparable
        image_A = self.X_test_A[5]
        image_A = np.reshape(image_A, [self.H_A_test, self.W_A_test, self.C_A_test])
        # Translate A -> B; predict() expects a batch, so add a leading axis
        fake_B = self.generator_A_to_B.Generator.predict(
            image_A.reshape(1, self.H_A, self.W_A, self.C_A))
        fake_B = np.reshape(fake_B, [self.H_A_test, self.W_A_test, self.C_A_test])
        # Translate B -> A to check how well the cycle reconstructs the original
        reconstructed_A = self.generator_B_to_A.Generator.predict(
            fake_B.reshape(1, self.H_A, self.W_A, self.C_A))
        reconstructed_A = np.reshape(reconstructed_A,
                                     [self.H_A_test, self.W_A_test, self.C_A_test])
        checkpoint_images = np.array([image_A, fake_B, reconstructed_A])

  6. Use Matplotlib's plotting function to plot all three images:
        # Rescale images from [-1, 1] to [0, 1], clipping any overshoot for imshow
        checkpoint_images = np.clip(0.5 * checkpoint_images + 0.5, 0.0, 1.0)

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axes = plt.subplots(1, 3)
        for i in range(3):
            image = checkpoint_images[i]
            image = np.reshape(image, [self.H_A_test, self.W_A_test, self.C_A_test])
            axes[i].imshow(image)
            axes[i].set_title(titles[i])
            axes[i].axis('off')
        fig.savefig("/data/batch_check_"+str(b)+".png")
        plt.close('all')
        return
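
The scaling in load_data and the rescaling in plot_checkpoint are inverses of each other. The following minimal NumPy sketch checks this on a synthetic batch. The functions normalize, denormalize, and take_fraction are hypothetical stand-ins for the inline expressions above, not part of the book's class.

```python
import numpy as np

def normalize(images):
    """Map uint8 pixels in [0, 255] to float32 in [-1, 1], as load_data does."""
    return (images.astype(np.float32) - 127.5) / 127.5

def denormalize(images):
    """Map [-1, 1] back to [0, 1] for display, as plot_checkpoint does."""
    return np.clip(0.5 * images + 0.5, 0.0, 1.0)

def take_fraction(images, amount_of_data=1.0):
    """Keep the first `amount_of_data` share of the images, as load_data does."""
    return images[:int(amount_of_data * float(len(images)))]

# Hypothetical batch: 10 random 4x4 RGB images
rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(10, 4, 4, 3), dtype=np.uint8)

subset = take_fraction(batch, amount_of_data=0.5)
scaled = normalize(subset)
print(subset.shape)                               # (5, 4, 4, 3)
print(scaled.min() >= -1.0, scaled.max() <= 1.0)  # True True
```

Mapping back with 0.5 * x + 0.5 gives exactly x / 255. A round trip through both functions therefore recovers the original pixel intensities on a [0, 1] scale.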
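
grabListOfFiles matches extensions case-sensitively. A file named IMG_001.JPG would therefore be skipped when the extension is "jpg". The sketch below shows a small variant that lowercases both sides and sorts the result. The name list_files_with_extension is hypothetical, and the case-insensitive matching and sorting are additions of this sketch rather than part of the original method.

```python
import os
import tempfile

def list_files_with_extension(directory, extension=".jpg"):
    """Hypothetical case-insensitive, sorted variant of grabListOfFiles."""
    extension = extension.lower()
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(extension)
        and os.path.isfile(os.path.join(directory, name))
    )

with tempfile.TemporaryDirectory() as tmp:
    # Create empty placeholder files with mixed-case extensions
    for name in ["a.jpg", "B.JPG", "notes.txt"]:
        open(os.path.join(tmp, name), "w").close()
    found = [os.path.basename(p) for p in list_files_with_extension(tmp)]
    print(found)  # ['B.JPG', 'a.jpg']
```

Sorting matters because os.listdir returns entries in arbitrary order. Without it, a fixed index such as X_test_A[5] in plot_checkpoint could point at a different image on another machine.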
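
flipCoin performs a single Bernoulli trial: np.random.binomial(1, chance) returns 1 with probability chance and 0 otherwise. The sketch below checks this empirically. It swaps in NumPy's seeded Generator API (np.random.default_rng) for reproducibility, and flip_coin is a hypothetical free-function version of the method.

```python
import numpy as np

def flip_coin(chance=0.5, rng=None):
    """Return 1 with probability `chance`, else 0 (one Bernoulli trial)."""
    rng = np.random.default_rng() if rng is None else rng
    return int(rng.binomial(1, chance))

rng = np.random.default_rng(42)
flips = [flip_coin(0.3, rng) for _ in range(10_000)]
print(set(flips))                         # {0, 1}
print(round(sum(flips) / len(flips), 2))  # close to 0.3
```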
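
One detail in grabArrayOfImages is easy to trip over. PIL's Image.resize takes a (width, height) tuple, while the array produced by np.asarray has shape (height, width, channels), or (height, width) for grayscale. This sketch makes the difference visible with hypothetical, deliberately non-square RESIZE_WIDTH and RESIZE_HEIGHT values. load_image_array is an illustrative helper that works on an in-memory image instead of a file.

```python
import numpy as np
from PIL import Image

RESIZE_WIDTH, RESIZE_HEIGHT = 64, 32  # hypothetical sizes, deliberately non-square

def load_image_array(im, gray=False):
    """Mirror the per-image steps of grabArrayOfImages on an in-memory image."""
    im = im.convert("L" if gray else "RGB")
    im = im.resize((RESIZE_WIDTH, RESIZE_HEIGHT))  # PIL takes (width, height)
    return np.asarray(im)                          # NumPy gives (height, width[, channels])

source = Image.new("RGB", (100, 50), color=(255, 0, 0))  # size is (width, height)
print(load_image_array(source).shape)             # (32, 64, 3)
print(load_image_array(source, gray=True).shape)  # (32, 64)
```

The two-dimensional grayscale result is why a channel axis has to be added for grayscale data. The (height, width, channels) order is also the one to use when reshaping images in plot_checkpoint.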
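
The check in plot_checkpoint is a cycle: A is translated to B and then back to A, and a well-trained model should return something close to the original. The sketch below isolates the shape bookkeeping, adding and removing the batch axis around each prediction. It also puts a number on the reconstruction using a mean absolute error, the L1 distance that CycleGAN's cycle-consistency loss is built on. The stand-in generators are hypothetical plain functions, not Keras models.

```python
import numpy as np

def round_trip(image, to_b, to_a):
    """Translate one H x W x C image A -> B -> A, handling the batch axis."""
    batch = image[np.newaxis, ...]  # (1, H, W, C), the batch shape a model expects
    fake_b = to_b(batch)
    reconstructed_a = to_a(fake_b)
    # Drop the batch axis so each result is a single H x W x C image
    return image, fake_b[0], reconstructed_a[0]

def cycle_l1(original, reconstructed):
    """Mean absolute (L1) difference between an image and its reconstruction."""
    return float(np.mean(np.abs(original - reconstructed)))

def invert(x):
    """Hypothetical stand-in generator: negate pixel values in [-1, 1]."""
    return -x

image_a = np.linspace(-1.0, 1.0, 2 * 3 * 3).reshape(2, 3, 3)
original, fake_b, rec_a = round_trip(image_a, invert, invert)
print(fake_b.shape, cycle_l1(original, rec_a))  # (2, 3, 3) 0.0
```

With real generators the error will not be zero. Tracking it across checkpoints offers a numeric complement to eyeballing the saved images.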
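
The plotting step can also be tried on its own. This sketch assumes the non-interactive Agg backend, so it can run on a headless training machine, and it writes to the system temp directory instead of /data. The function save_checkpoint_grid and the random input are hypothetical. The sketch also clips the rescaled values to [0, 1], since imshow expects RGB float data in that range.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # render straight to files; no display needed
import matplotlib.pyplot as plt
import numpy as np

def save_checkpoint_grid(images, path, titles=("Original", "Translated", "Reconstructed")):
    """Plot images scaled to [-1, 1] side by side and save the figure to `path`."""
    images = np.clip(0.5 * np.asarray(images) + 0.5, 0.0, 1.0)  # rescale to [0, 1]
    fig, axes = plt.subplots(1, len(titles))
    for ax, image, title in zip(axes, images, titles):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")
    fig.savefig(path)
    plt.close(fig)

rng = np.random.default_rng(0)
fake_checkpoint = rng.uniform(-1.0, 1.0, size=(3, 8, 8, 3))  # hypothetical data
out_path = os.path.join(tempfile.gettempdir(), "batch_check_demo.png")
save_checkpoint_grid(fake_checkpoint, out_path)
print(os.path.exists(out_path))  # True
```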

At each batch or epoch check, you should see an output image similar to this:
