def forward_pass(image, label, gen, disc_f, disc_h, disc_j, model_en, batch_size, cont_dim, config, update_g=True, update_d=True):
    """Run one generator/encoder/discriminator forward pass and return the losses.

    Args:
        image: Batch of real images.
        label: Class labels; ignored (set to None) when config.conditional is False.
        gen: Generator model mapping (noise, label) -> image.
        disc_f: Image discriminator; returns (features_for_joint, score).
        disc_h: Latent discriminator; returns (features_for_joint, score).
        disc_j: Joint discriminator over (image-features, latent-features).
        model_en: Encoder mapping image -> latent code.
        batch_size: Number of samples per batch.
        cont_dim: Dimensionality of the continuous latent noise.
        config: Run configuration; only `conditional` is read here.
        update_g, update_d: Accepted for interface compatibility; not used
            in this function (no gradients are taken here).

    Returns:
        Tuple (g_e_loss, d_loss): generator/encoder loss and discriminator loss.
    """
    if not config.conditional:
        label = None

    # Sample latent noise and build both directions of the bidirectional pass:
    # noise -> fake image (generator) and real image -> latent code (encoder).
    z = tf.random.truncated_normal([batch_size, cont_dim])
    generated = gen(z, label, training=True)
    encoded = model_en(image, training=True)

    # Unary discriminators on images and latents; each also yields features
    # that feed the joint discriminator.
    f_feat_real, real_f_score = disc_f(image, label, training=True)
    f_feat_fake, fake_f_score = disc_f(generated, label, training=True)
    h_feat_real, real_h_score = disc_h(encoded, training=True)
    h_feat_fake, fake_h_score = disc_h(z, training=True)

    # Joint discriminator over paired (image, latent) features.
    real_j_score = disc_j(f_feat_real, h_feat_real, training=True)
    fake_j_score = disc_j(f_feat_fake, h_feat_fake, training=True)

    d_loss = disc_loss(real_f_score, real_h_score, real_j_score,
                       fake_f_score, fake_h_score, fake_j_score)
    g_e_loss = gen_en_loss(real_f_score, real_h_score, real_j_score,
                           fake_f_score, fake_h_score, fake_j_score)
    return g_e_loss, d_loss
def train_step(image, label, gen, disc_f, disc_h, disc_j, model_en, disc_optimizer, gen_en_optimizer, metric_loss_disc, metric_loss_gen_en, batch_size, cont_dim, config):
    """Run config.D_G_ratio adversarial update steps on one batch.

    Each iteration samples fresh noise, runs the bidirectional forward pass,
    then applies one discriminator update and one generator/encoder update.

    Args:
        image: Batch of real images.
        label: Class labels forwarded to gen/disc_f (may be None).
        gen: Generator model mapping (noise, label) -> image.
        disc_f: Image discriminator; returns (features_for_joint, score).
        disc_h: Latent discriminator; returns (features_for_joint, score).
        disc_j: Joint discriminator over (image-features, latent-features).
        model_en: Encoder mapping image -> latent code.
        disc_optimizer: Optimizer for disc_f/disc_h/disc_j variables.
        gen_en_optimizer: Optimizer for gen/model_en variables.
        metric_loss_disc: Keras metric accumulating the discriminator loss.
        metric_loss_gen_en: Keras metric accumulating the gen/encoder loss.
        batch_size: Number of samples per batch.
        cont_dim: Dimensionality of the continuous latent noise.
        config: Run configuration; reads `device` and `D_G_ratio`.

    Fixes vs. previous version: removed the unused `en_tape` (its gradient
    was never taken), and dropped `persistent=True` on both tapes since each
    tape's `.gradient()` is called exactly once — this also removes the need
    for the manual `del` cleanup of persistent tapes.
    """
    # Printed once per retrace when wrapped in tf.function — a tracing marker.
    print('Graph will be traced...')
    with tf.device('{}:*'.format(config.device)):
        for _ in range(config.D_G_ratio):
            fake_noise = tf.random.truncated_normal([batch_size, cont_dim])
            # disc_tape is opened inside gen_en_tape's context, so the
            # discriminator ops are recorded on BOTH tapes: gen_en_tape needs
            # them to backprop g_e_loss through the discriminators into
            # gen/model_en.
            with tf.GradientTape() as gen_en_tape:
                fake_img = gen(fake_noise, label, training=True)
                latent_code_real = model_en(image, training=True)
                with tf.GradientTape() as disc_tape:
                    real_f_to_j, real_f_score = disc_f(image, label, training=True)
                    fake_f_to_j, fake_f_score = disc_f(fake_img, label, training=True)
                    real_h_to_j, real_h_score = disc_h(latent_code_real, training=True)
                    fake_h_to_j, fake_h_score = disc_h(fake_noise, training=True)
                    real_j_score = disc_j(real_f_to_j, real_h_to_j, training=True)
                    fake_j_score = disc_j(fake_f_to_j, fake_h_to_j, training=True)
                    d_loss = disc_loss(real_f_score, real_h_score, real_j_score,
                                       fake_f_score, fake_h_score, fake_j_score)
                    g_e_loss = gen_en_loss(real_f_score, real_h_score, real_j_score,
                                           fake_f_score, fake_h_score, fake_j_score)

            # Build each variable list once and reuse it for both the
            # gradient computation and the optimizer update, so the
            # zip pairing is guaranteed consistent.
            disc_vars = (disc_f.trainable_variables
                         + disc_h.trainable_variables
                         + disc_j.trainable_variables)
            grad_disc = disc_tape.gradient(d_loss, disc_vars)
            disc_optimizer.apply_gradients(zip(grad_disc, disc_vars))
            metric_loss_disc.update_state(d_loss)  # accumulate per-step D loss

            gen_en_vars = gen.trainable_variables + model_en.trainable_variables
            grad_gen_en = gen_en_tape.gradient(g_e_loss, gen_en_vars)
            gen_en_optimizer.apply_gradients(zip(grad_gen_en, gen_en_vars))
            metric_loss_gen_en.update_state(g_e_loss)  # accumulate per-step G/E loss