def update_g(real_x, real_x_p, real_z, real_z_p, v, v_p):
    """Perform a single optimisation step on ``generator``.

    Fake samples are produced from ``[z, v]`` for two input batches (the
    plain and ``_p`` variants). Real and fake ``(x, z)`` pairs from both
    batches are scored by ``discriminator``. The resulting
    ``gan_utils.benchmark_loss`` is differentiated with respect to the
    generator's trainable variables only, and ``gen_optimiser`` applies
    the update.

    Args:
        real_x: real x samples for the first batch.
        real_x_p: real x samples for the ``_p`` batch (presumably a second,
            independent batch -- confirm against the caller).
        real_z: conditioning z samples paired with ``real_x``.
        real_z_p: conditioning z samples paired with ``real_x_p``.
        v: generator noise, concatenated after ``real_z`` on axis 1.
        v_p: generator noise, concatenated after ``real_z_p`` on axis 1.

    Returns:
        The generator loss value returned by ``gan_utils.benchmark_loss``.

    NOTE(review): this step uses ``generator``/``discriminator`` while
    ``x_update_d`` uses ``generator_x``/``discriminator_x`` -- confirm the
    pairing is intentional.
    """
    # Generator input = conditioning z followed by noise v (feature axis 1).
    noise_in = tf.concat([real_z, v], axis=1)
    noise_in_p = tf.concat([real_z_p, v_p], axis=1)

    # Real (x, z) pairs do not depend on the generator, so they are built
    # before the tape starts recording.
    real_pair = tf.concat([real_x, real_z], axis=1)
    real_pair_p = tf.concat([real_x_p, real_z_p], axis=1)

    with tf.GradientTape() as tape:
        sample = generator.call(noise_in)
        sample_p = generator.call(noise_in_p)

        fake_pair = tf.concat([sample, real_z], axis=1)
        fake_pair_p = tf.concat([sample_p, real_z_p], axis=1)

        # Score in the fixed order real, fake, real_p, fake_p.
        score_real, score_fake, score_real_p, score_fake_p = (
            discriminator.call(pair)
            for pair in (real_pair, fake_pair, real_pair_p, fake_pair_p)
        )

        loss = gan_utils.benchmark_loss(
            score_real,
            score_fake,
            scaling_coef,
            sinkhorn_eps,
            sinkhorn_l,
            score_real_p,
            score_fake_p,
        )

    # Only the generator's weights are updated in this step.
    gen_weights = generator.trainable_variables
    gen_grads = tape.gradient(loss, gen_weights)
    gen_optimiser.apply_gradients(zip(gen_grads, generator.trainable_variables))
    return loss
def x_update_d(real_x, real_x_p, real_z, real_z_p, v, v_p):
    """Perform a single optimisation step on ``discriminator_x``.

    ``generator_x`` stays fixed and produces fake x samples from ``[z, v]``
    for both input batches (the plain and ``_p`` variants). Real and fake
    ``(x, z)`` pairs are scored by ``discriminator_x``. The discriminator
    is then updated to *maximise* ``gan_utils.benchmark_loss``, which is
    done by minimising its negation with ``dx_optimiser``.

    Args:
        real_x: real x samples for the first batch.
        real_x_p: real x samples for the ``_p`` batch (presumably a second,
            independent batch -- confirm against the caller).
        real_z: conditioning z samples paired with ``real_x``.
        real_z_p: conditioning z samples paired with ``real_x_p``.
        v: generator noise, concatenated after ``real_z`` on axis 1.
        v_p: generator noise, concatenated after ``real_z_p`` on axis 1.

    Returns:
        The minimised discriminator loss, i.e. ``-benchmark_loss(...)``.
        It is returned so callers can log it, mirroring ``update_g``.
        Previously this function returned ``None``, and callers that
        ignore the result are unaffected.
    """
    # Generator input = conditioning z followed by noise v (feature axis 1).
    gen_inputs = tf.concat([real_z, v], axis=1)
    gen_inputs_p = tf.concat([real_z_p, v_p], axis=1)

    # Concatenate (x, z) pairs for the discriminator.
    d_real = tf.concat([real_x, real_z], axis=1)
    d_real_p = tf.concat([real_x_p, real_z_p], axis=1)

    # The generator forward pass runs outside the tape. Only
    # discriminator_x's variables are differentiated below, so there is no
    # need to record it.
    fake_x = generator_x.call(gen_inputs)
    fake_x_p = generator_x.call(gen_inputs_p)
    d_fake = tf.concat([fake_x, real_z], axis=1)
    d_fake_p = tf.concat([fake_x_p, real_z_p], axis=1)

    with tf.GradientTape() as disc_tape:
        f_real = discriminator_x.call(d_real)
        f_fake = discriminator_x.call(d_fake)
        f_real_p = discriminator_x.call(d_real_p)
        f_fake_p = discriminator_x.call(d_fake_p)

        loss1 = gan_utils.benchmark_loss(
            f_real,
            f_fake,
            scaling_coef,
            sinkhorn_eps,
            sinkhorn_l,
            f_real_p,
            f_fake_p,
        )
        # The discriminator maximises the benchmark loss. The negation must
        # be recorded on the tape so the gradient stays connected.
        disc_loss = -loss1

    d_grads = disc_tape.gradient(disc_loss, discriminator_x.trainable_variables)
    dx_optimiser.apply_gradients(
        zip(d_grads, discriminator_x.trainable_variables))
    return disc_loss