def forward_pass(image,
                 label,
                 gen,
                 disc_f,
                 disc_h,
                 disc_j,
                 model_en,
                 batch_size,
                 cont_dim,
                 config,
                 update_g=True,
                 update_d=True):
    """Run generator, encoder and all discriminators once; return both losses.

    A fresh noise batch of shape ``[batch_size, cont_dim]`` is drawn with
    ``tf.random.truncated_normal``. It is used both as the generator input
    and as the "fake" latent fed to ``disc_h``. Real images are encoded with
    ``model_en`` to obtain the "real" latent.

    ``disc_f`` scores images, ``disc_h`` scores latents, and each returns a
    ``(features, score)`` pair. ``disc_j`` combines the F and H features into
    a joint score. All model calls use ``training=True``.

    Args:
        image: Batch of real images.
        label: Conditioning labels; replaced by ``None`` when
            ``config.conditional`` is falsy.
        gen, disc_f, disc_h, disc_j, model_en: Callable models.
        batch_size: Number of noise samples to draw.
        cont_dim: Dimensionality of each noise sample.
        config: Object exposing a ``conditional`` attribute.
        update_g, update_d: Accepted for interface compatibility; not used
            in this body.

    Returns:
        Tuple ``(g_e_loss, d_loss)`` from ``gen_en_loss`` and ``disc_loss``.
    """
    cond_label = label if config.conditional else None

    noise = tf.random.truncated_normal([batch_size, cont_dim])
    generated = gen(noise, cond_label, training=True)
    encoded = model_en(image, training=True)

    # Image discriminator: real images vs. generated images.
    f_real_feat, f_real = disc_f(image, cond_label, training=True)
    f_fake_feat, f_fake = disc_f(generated, cond_label, training=True)
    # Latent discriminator: encoder output vs. sampled noise.
    h_real_feat, h_real = disc_h(encoded, training=True)
    h_fake_feat, h_fake = disc_h(noise, training=True)
    # Joint discriminator over the paired F/H features.
    j_real = disc_j(f_real_feat, h_real_feat, training=True)
    j_fake = disc_j(f_fake_feat, h_fake_feat, training=True)

    scores = (f_real, h_real, j_real, f_fake, h_fake, j_fake)
    d_loss = disc_loss(*scores)
    g_e_loss = gen_en_loss(*scores)
    return g_e_loss, d_loss
def train_step(image, label, gen, disc_f, disc_h, disc_j, model_en,
               disc_optimizer, gen_en_optimizer, metric_loss_disc,
               metric_loss_gen_en, batch_size, cont_dim, config):
    """Run ``config.D_G_ratio`` discriminator updates, then one gen/enc update.

    Each discriminator iteration draws fresh noise and runs the full forward
    pass via ``forward_pass``, which also drops ``label`` when
    ``config.conditional`` is falsy. It then applies the ``d_loss``
    gradients to the variables of ``disc_f``, ``disc_h`` and ``disc_j``.

    The generator/encoder update uses the ``g_e_loss`` from the last
    iteration. Its gradients are taken from that iteration's tape, so they
    reflect the discriminator weights before that iteration's discriminator
    update, as in the original implementation.

    Args:
        image, label: Real batch and its labels.
        gen, disc_f, disc_h, disc_j, model_en: Models being trained.
        disc_optimizer: Optimizer for the three discriminators.
        gen_en_optimizer: Optimizer for the generator and encoder.
        metric_loss_disc, metric_loss_gen_en: Metrics updated with the
            respective losses via ``update_state``.
        batch_size, cont_dim: Shape of the noise batch.
        config: Provides ``device``, ``D_G_ratio`` and ``conditional``.

    Raises:
        ValueError: If ``config.D_G_ratio`` is less than 1. In that case no
            generator loss would ever be computed.
    """
    print('Graph will be traced...')

    if config.D_G_ratio < 1:
        raise ValueError(
            'config.D_G_ratio must be >= 1, got {}'.format(config.D_G_ratio))

    with tf.device('{}:*'.format(config.device)):
        for _ in range(config.D_G_ratio):
            # One persistent tape records the whole forward pass. It serves
            # both the discriminator gradient below and, on the last
            # iteration, the generator/encoder gradient after the loop.
            with tf.GradientTape(persistent=True) as tape:
                g_e_loss, d_loss = forward_pass(image, label, gen, disc_f,
                                                disc_h, disc_j, model_en,
                                                batch_size, cont_dim, config)

            # Collected after the forward pass so lazily-built models
            # already own their weights.
            disc_vars = (disc_f.trainable_variables +
                         disc_h.trainable_variables +
                         disc_j.trainable_variables)
            grad_disc = tape.gradient(d_loss, disc_vars)
            disc_optimizer.apply_gradients(zip(grad_disc, disc_vars))
            metric_loss_disc.update_state(
                d_loss)  # upgrade the value in metrics for single step.

        # `tape` and `g_e_loss` come from the final loop iteration.
        gen_en_vars = gen.trainable_variables + model_en.trainable_variables
        grad_gen_en = tape.gradient(g_e_loss, gen_en_vars)
        gen_en_optimizer.apply_gradients(zip(grad_gen_en, gen_en_vars))
        metric_loss_gen_en.update_state(g_e_loss)

        # Release the persistent tape's recorded resources.
        del tape