Creating models

Let's create the models using the builder functions defined in the earlier Stage-I StackGAN section under A Keras Implementation of StackGAN. We will use a generator model, a discriminator model, a compressor model that compresses the text embedding, and an adversarial model containing both the generator and the discriminator. The code also builds ca_model, the conditioning augmentation network:

  1. Start by defining the optimizers required for the training:
dis_optimizer = Adam(lr=stage1_discriminator_lr, beta_1=0.5, beta_2=0.999)
gen_optimizer = Adam(lr=stage1_generator_lr, beta_1=0.5, beta_2=0.999)
  2. Then build and compile the different networks as follows:
ca_model = build_ca_model()
ca_model.compile(loss="binary_crossentropy", optimizer="adam")

stage1_dis = build_stage1_discriminator()
stage1_dis.compile(loss='binary_crossentropy', optimizer=dis_optimizer)

stage1_gen = build_stage1_generator()
stage1_gen.compile(loss="mse", optimizer=gen_optimizer)

embedding_compressor_model = build_embedding_compressor_model()
embedding_compressor_model.compile(loss="binary_crossentropy", optimizer="adam")

adversarial_model = build_adversarial_model(gen_model=stage1_gen, dis_model=stage1_dis)
adversarial_model.compile(loss=['binary_crossentropy', KL_loss], loss_weights=[1, 2.0],
                          optimizer=gen_optimizer, metrics=None)
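
As a rough illustration of how the adversarial model's two losses are combined: when a Keras model is compiled with several losses, it minimizes their weighted sum, with the coefficients taken from loss_weights. The sketch below reproduces that arithmetic in plain Python. The function name combined_loss and the sample loss values are made up for illustration and are not part of the StackGAN code.

```python
def combined_loss(losses, loss_weights):
    """Weighted sum of per-output losses (hypothetical helper for illustration)."""
    return sum(weight * loss for weight, loss in zip(loss_weights, losses))


# Illustrative values only: a binary cross-entropy of 0.5 and a KL loss of 0.25,
# combined with the weights [1, 2.0] passed to adversarial_model.compile().
total = combined_loss([0.5, 0.25], [1, 2.0])
print(total)  # 1 * 0.5 + 2.0 * 0.25 = 1.0
```

With these weights, the KL term counts twice as much as the cross-entropy term in the total that the generator's optimizer minimizes.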

Here, KL_loss is a custom loss function, which is defined as follows:

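from keras import backend as K

def KL_loss(y_true, y_pred):
    # y_pred holds the 128 means followed by the 128 log-sigmas
    mean = y_pred[:, :128]
    logsigma = y_pred[:, 128:]
    # KL divergence between N(mean, sigma^2) and the standard normal N(0, 1)
    loss = -logsigma + .5 * (-1 + K.exp(2. * logsigma) + K.square(mean))
    loss = K.mean(loss)
    return loss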
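
To sanity-check the formula, here is a minimal plain-Python sketch of the same computation, assuming each row of y_pred stores the means first and the log-sigmas second. The name kl_loss_reference, the latent_dim parameter, and the tiny 4-dimensional inputs are assumptions made for illustration; the real loss operates on Keras tensors with 128 dimensions.

```python
import math


def kl_loss_reference(y_pred, latent_dim=128):
    """Mean KL divergence between N(mean, sigma^2) and N(0, 1).

    Each row of y_pred is assumed to hold `latent_dim` means followed by
    `latent_dim` log-sigmas (hypothetical reference helper, not Keras code).
    """
    total, count = 0.0, 0
    for row in y_pred:
        means = row[:latent_dim]
        logsigmas = row[latent_dim:2 * latent_dim]
        for mean, logsigma in zip(means, logsigmas):
            total += -logsigma + 0.5 * (-1.0 + math.exp(2.0 * logsigma) + mean ** 2)
            count += 1
    return total / count


# mean = 0 and log-sigma = 0 describe N(0, 1) itself, so the divergence is zero.
standard = [[0.0] * 4 + [0.0] * 4]
# Shifting every mean to 1 adds 0.5 * mean^2 = 0.5 per dimension.
shifted = [[1.0] * 4 + [0.0] * 4]

print(kl_loss_reference(standard, latent_dim=4))  # 0.0
print(kl_loss_reference(shifted, latent_dim=4))   # 0.5
```

The loss is zero only when the predicted distribution matches the standard normal, which is why it acts as a regularizer on the conditioning variables.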

We now have the dataset and the models ready, so we can start training the model.

Also, add TensorBoard to store losses for visualization, as follows:

tensorboard = TensorBoard(log_dir="logs/{}".format(time.time()))
tensorboard.set_model(stage1_gen)
tensorboard.set_model(stage1_dis)
tensorboard.set_model(ca_model)
tensorboard.set_model(embedding_compressor_model)