Example #1
# `generator`, `discriminator`, `gen_optimiser`, `gan_utils` and the Sinkhorn
# hyper-parameters are captured from the enclosing scope.
def update_g(real_x, real_x_p, real_z, real_z_p, v, v_p):
    # generator inputs: noise z concatenated with the conditioning variable v
    gen_inputs = tf.concat([real_z, v], axis=1)
    gen_inputs_p = tf.concat([real_z_p, v_p], axis=1)
    # concatenate real inputs for the WGAN discriminator: (x, z)
    d_real = tf.concat([real_x, real_z], axis=1)
    d_real_p = tf.concat([real_x_p, real_z_p], axis=1)
    with tf.GradientTape() as gen_tape:
        fake_x = generator.call(gen_inputs)
        fake_x_p = generator.call(gen_inputs_p)
        d_fake = tf.concat([fake_x, real_z], axis=1)
        d_fake_p = tf.concat([fake_x_p, real_z_p], axis=1)
        f_real = discriminator.call(d_real)
        f_fake = discriminator.call(d_fake)
        f_real_p = discriminator.call(d_real_p)
        f_fake_p = discriminator.call(d_fake_p)
        # compute the loss via gan_utils.benchmark_loss (a @tf.function,
        # so AutoGraph applies)
        gen_loss = gan_utils.benchmark_loss(f_real, f_fake, scaling_coef,
                                            sinkhorn_eps, sinkhorn_l,
                                            f_real_p, f_fake_p)
    # update generator parameters
    generator_grads = gen_tape.gradient(gen_loss,
                                        generator.trainable_variables)
    gen_optimiser.apply_gradients(
        zip(generator_grads, generator.trainable_variables))
    return gen_loss
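
For context, here is a minimal, self-contained sketch of the scope that update_g closes over. The model architectures, dimensions, optimiser settings and hyper-parameter values are all placeholders, and gan_utils.benchmark_loss (presumably a Sinkhorn-based loss, judging by its parameters) is stubbed with a plain critic-gap surrogate so the sketch runs on its own:

import types

import tensorflow as tf

x_dim, z_dim, v_dim, batch = 8, 4, 2, 32

# stand-in MLPs; the real architectures are not shown in the example
generator = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(x_dim),
])
discriminator = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),
])
# build both models once so that .call() works on them afterwards
generator(tf.zeros([1, z_dim + v_dim]))
discriminator(tf.zeros([1, x_dim + z_dim]))

gen_optimiser = tf.keras.optimizers.Adam(1e-4)
scaling_coef, sinkhorn_eps, sinkhorn_l = 1.0, 0.1, 10


def _benchmark_loss(f_real, f_fake, scaling_coef, sinkhorn_eps,
                    sinkhorn_l, f_real_p, f_fake_p):
    # surrogate only: critic gap averaged over both mini-batches
    return (tf.reduce_mean(f_fake) - tf.reduce_mean(f_real)
            + tf.reduce_mean(f_fake_p) - tf.reduce_mean(f_real_p))


gan_utils = types.SimpleNamespace(benchmark_loss=tf.function(_benchmark_loss))

# one generator step on random data
args = [tf.random.normal([batch, d])
        for d in (x_dim, x_dim, z_dim, z_dim, v_dim, v_dim)]
print(float(update_g(*args)))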
Example #2
# `generator_x`, `discriminator_x`, `dx_optimiser`, `gan_utils` and the
# Sinkhorn hyper-parameters are captured from the enclosing scope.
def x_update_d(real_x, real_x_p, real_z, real_z_p, v, v_p):
    gen_inputs = tf.concat([real_z, v], axis=1)
    gen_inputs_p = tf.concat([real_z_p, v_p], axis=1)
    # concatenate real inputs for the WGAN discriminator: (x, z)
    d_real = tf.concat([real_x, real_z], axis=1)
    d_real_p = tf.concat([real_x_p, real_z_p], axis=1)
    # generator passes stay outside the tape: only the discriminator
    # is updated in this step
    fake_x = generator_x.call(gen_inputs)
    fake_x_p = generator_x.call(gen_inputs_p)
    d_fake = tf.concat([fake_x, real_z], axis=1)
    d_fake_p = tf.concat([fake_x_p, real_z_p], axis=1)

    with tf.GradientTape() as disc_tape:
        f_real = discriminator_x.call(d_real)
        f_fake = discriminator_x.call(d_fake)
        f_real_p = discriminator_x.call(d_real_p)
        f_fake_p = discriminator_x.call(d_fake_p)
        # compute the loss via gan_utils.benchmark_loss (a @tf.function,
        # so AutoGraph applies)
        loss1 = gan_utils.benchmark_loss(f_real, f_fake, scaling_coef,
                                         sinkhorn_eps, sinkhorn_l,
                                         f_real_p, f_fake_p)
        # the critic maximises loss1, so minimise its negation
        # disc_loss = - tf.math.minimum(loss1, 1)
        disc_loss = -loss1
    # update discriminator parameters
    d_grads = disc_tape.gradient(disc_loss,
                                 discriminator_x.trainable_variables)
    dx_optimiser.apply_gradients(
        zip(d_grads, discriminator_x.trainable_variables))
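
Note the sign flip at the end: the critic maximises the benchmark loss, so the objective minimised under the tape is -loss1 (the commented-out line preserves an experiment that clipped the loss at 1 before negating). The snippet does not show how the steps are scheduled; below is a hedged sketch of the usual WGAN-style alternation, assuming a dataset that yields the six tensors in the order the closure expects and a matching, hypothetical generator-update closure x_update_g:

# hypothetical alternating schedule; `dataset`, `x_update_d` and a
# matching `x_update_g` are assumed to be in scope
n_critic = 5  # critic updates per generator update (a common WGAN choice)
for step, batch in enumerate(dataset):
    real_x, real_x_p, real_z, real_z_p, v, v_p = batch
    x_update_d(real_x, real_x_p, real_z, real_z_p, v, v_p)
    if step % n_critic == n_critic - 1:
        x_update_g(real_x, real_x_p, real_z, real_z_p, v, v_p)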