There are a few helper functions I want to highlight in this section:
- The load_data function reads in the downloaded input data. Each image is stored as a 256 x 512 pair, two 256 x 256 images stitched side by side, so this function reads the images in and splits each one down the middle into two arrays:
def load_data(self, data_path):
    # Requires module-level `import numpy as np`.
    listOfFiles = self.grabListOfFiles(data_path, extension="jpg")
    imgs_temp = np.array(self.grabArrayOfImages(listOfFiles))
    imgs_A = []
    imgs_B = []
    for img in imgs_temp:
        # img is (height, width, channels); split along the width axis.
        # self.H (256) equals half the stitched width, so it marks the seam.
        imgs_A.append(img[:, :self.H])
        imgs_B.append(img[:, self.H:])
    imgs_A_out = self.norm_and_expand(np.array(imgs_A))
    imgs_B_out = self.norm_and_expand(np.array(imgs_B))
    return imgs_A_out, imgs_B_out
- norm_and_expand is a convenience function that scales pixel values from [0, 255] to [-1, 1] and puts the array in the shape the network expects:
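def norm_and_expand(self, arr):
    # Scale pixel values from [0, 255] to [-1, 1]
    arr = (arr.astype(np.float32) - 127.5) / 127.5
    # Grayscale (N, H, W) batches need a channel axis; RGB (N, H, W, 3)
    # batches already have one, and axis=3 there would give (N, H, W, 1, 3)
    if arr.ndim == 3:
        arr = np.expand_dims(arr, axis=3)
    return arr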
- grabListOfFiles returns the full paths of the files in a starting directory whose names end with a given extension:
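def grabListOfFiles(self, startingDirectory, extension=".webp"):
    # Requires module-level `import os`.
    # Only files directly inside startingDirectory are checked (no recursion).
    listOfFiles = []
    for file in os.listdir(startingDirectory):
        if file.endswith(extension):
            listOfFiles.append(os.path.join(startingDirectory, file))
    return listOfFiles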
- Given a list of file paths, grabArrayOfImages opens each image, converts it to grayscale or RGB, and returns a list of NumPy arrays:
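def grabArrayOfImages(self, listOfFiles, gray=False):
    # Requires module-level `import numpy as np` and `from PIL import Image`.
    imageArr = []
    for f in listOfFiles:
        # "L" is 8-bit grayscale; "RGB" is 3-channel color
        if gray:
            im = Image.open(f).convert("L")
        else:
            im = Image.open(f).convert("RGB")
        imData = np.asarray(im)
        imageArr.append(imData)
    return imageArr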
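The seam split performed by load_data can be checked in isolation. The following is a minimal NumPy-only sketch, assuming a stitched pair stored as a (256, 512, 3) array (height, width, channels). The helper name split_pair and the synthetic two-tone test image are hypothetical and not part of the original class.

```python
import numpy as np


def split_pair(img, half_width=256):
    """Hypothetical helper: split one stitched (H, 2*W, C) image into two (H, W, C) halves.

    Mirrors the slicing in load_data: axis 0 is height and axis 1 is width,
    so the seam is a column index on axis 1.
    """
    left = img[:, :half_width]
    right = img[:, half_width:]
    return left, right


# Synthetic stitched pair: left half filled with 10s, right half with 200s
stitched = np.concatenate(
    [np.full((256, 256, 3), 10, dtype=np.uint8),
     np.full((256, 256, 3), 200, dtype=np.uint8)],
    axis=1,
)
a, b = split_pair(stitched)
print(stitched.shape, a.shape, b.shape)  # -> (256, 512, 3) (256, 256, 3) (256, 256, 3)
```

Slicing axis 1 is the important detail: slicing axis 0 at the same index would instead cut the image into top and bottom strips.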
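The scaling and channel handling in norm_and_expand can also be sketched as a standalone NumPy function. The name normalize_batch is hypothetical. The sketch shows that the formula maps 0 to -1 and 255 to 1, and that expanding axis 3 only makes sense when the batch has no channel axis yet.

```python
import numpy as np


def normalize_batch(arr):
    """Hypothetical standalone version of norm_and_expand.

    Maps uint8 pixels in [0, 255] to float32 in [-1, 1] and adds a trailing
    channel axis only for grayscale batches shaped (N, H, W).
    """
    out = (arr.astype(np.float32) - 127.5) / 127.5
    if out.ndim == 3:
        out = np.expand_dims(out, axis=3)
    return out


gray = np.array([[[0, 255], [127, 128]]], dtype=np.uint8)  # shape (1, 2, 2)
rgb = np.zeros((1, 2, 2, 3), dtype=np.uint8)               # shape (1, 2, 2, 3)

print(normalize_batch(gray).shape)  # -> (1, 2, 2, 1)
print(normalize_batch(rgb).shape)   # -> (1, 2, 2, 3)
# Unconditional expansion on the RGB batch would insert an extra axis:
print(np.expand_dims(rgb, axis=3).shape)  # -> (1, 2, 2, 1, 3)
```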
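The extension filter in grabListOfFiles is a plain suffix match on file names. Here is a standard-library sketch of the same idea. The function name list_files and the temporary test directory are illustrative assumptions. Unlike the original, this version sorts the result, because os.listdir returns entries in arbitrary order.

```python
import os
import tempfile


def list_files(starting_directory, extension=".jpg"):
    """Hypothetical standalone version of grabListOfFiles.

    Returns full paths of entries directly inside starting_directory whose
    names end with `extension`, sorted for a deterministic order.
    """
    return sorted(
        os.path.join(starting_directory, name)
        for name in os.listdir(starting_directory)
        if name.endswith(extension)
    )


with tempfile.TemporaryDirectory() as tmp:
    # Create empty placeholder files with mixed extensions
    for name in ("b.jpg", "a.jpg", "notes.txt", "c.png"):
        open(os.path.join(tmp, name), "w").close()
    found = [os.path.basename(p) for p in list_files(tmp, ".jpg")]

print(found)  # -> ['a.jpg', 'b.jpg']
```

load_data passes "jpg" without the leading dot. That still matches "1.jpg", but a suffix match without the dot would also accept a name such as "notjpg".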
Now that the class is set up and working, we'll move on to training this model.