def trainstep(real_human, real_anime, big_anime):
    """Run one UNIT-style training step over the shared-latent translation GAN.

    Encodes both domains into a shared latent space, reconstructs each domain,
    translates across domains, and updates every sub-network (encoders, shared
    encoder/decoder, domain classifier ``c_dann``, per-domain decoders, and the
    anime discriminator ``D``) with its own loss-scaled optimizer from the
    module-level ``optims`` list.

    Args:
        real_human: batch of human-domain images.
        real_anime: batch of anime-domain images.
        big_anime: unused here; kept for signature parity with the other
            ``trainstep`` variants in this file.

    Returns:
        Tuple of (inputs, reconstructions, translations, individual loss
        scalars, mean fake/real discriminator outputs) for logging.
    """
    with tf.GradientTape(persistent=True) as tape:
        # Shared-latent encodings of each domain.
        latent_anime = encode_share(encode_anime(real_anime))
        latent_human = encode_share(encode_human(real_human))

        # Within-domain reconstructions (autoencoder paths).
        recon_anime = decode_anime(decode_share(latent_anime))
        recon_human = decode_human(decode_share(latent_human))

        # Cross-domain translations plus re-encoded ("cycled") latents.
        fake_anime = decode_anime(decode_share(latent_human))
        latent_human_cycled = encode_share(encode_anime(fake_anime))
        # BUG FIX: the anime->human translation must use the *human* decoder
        # and be re-encoded with the *human* encoder. The original used
        # decode_anime/encode_anime here (copy-paste from the branch above),
        # which made fake_human an anime image and broke the semantic
        # consistency term for the anime latent.
        fake_human = decode_human(decode_share(latent_anime))
        latent_anime_cycled = encode_share(encode_human(fake_human))

        disc_fake = D(fake_anime)
        disc_real = D(real_anime)
        c_dann_anime = c_dann(latent_anime)
        c_dann_human = c_dann(latent_human)

        # Reconstruction (identity) losses, weighted x3.
        loss_anime_encode = identity_loss(real_anime, recon_anime) * 3
        loss_human_encode = identity_loss(real_human, recon_human) * 3

        # Domain-adversarial loss on the shared latents: c_dann tries to tell
        # which domain a latent came from (anime -> label 0, human -> label 1).
        loss_domain_adversarial = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(c_dann_anime), logits=c_dann_anime)
        ) + tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.ones_like(c_dann_human), logits=c_dann_human))
        # Clamp then down-weight so this term cannot dominate early training.
        loss_domain_adversarial = tf.math.minimum(loss_domain_adversarial, 100)
        loss_domain_adversarial = loss_domain_adversarial * 0.2
        tf.print(loss_domain_adversarial)  # diagnostic trace of the DANN term

        # Cycled latents should match the originals (semantic consistency).
        loss_semantic_consistency = (
            identity_loss(latent_anime, latent_anime_cycled) * 3
            + identity_loss(latent_human, latent_human_cycled) * 3)
        loss_gan = w_g_loss(disc_fake)

        # Per-subnetwork totals; ordering below must match ``optims``.
        anime_encode_total_loss = (loss_anime_encode + loss_domain_adversarial
                                   + loss_semantic_consistency + loss_gan)
        human_encode_total_loss = (loss_human_encode + loss_domain_adversarial
                                   + loss_semantic_consistency)
        share_encode_total_loss = (loss_anime_encode + loss_domain_adversarial
                                   + loss_semantic_consistency + loss_gan
                                   + loss_human_encode)
        share_decode_total_loss = (loss_anime_encode + loss_human_encode
                                   + loss_gan)
        anime_decode_total_loss = loss_anime_encode + loss_gan
        human_decode_total_loss = loss_human_encode

        # WGAN critic loss with gradient penalty for D.
        loss_disc = w_d_loss(disc_real, disc_fake)
        loss_disc += gradient_penalty(
            functools.partial(D, training=True), real_anime, fake_anime)

        # Order is significant: each entry pairs with the same index in
        # ``optims`` and ``list_variables`` below.
        losses = [
            anime_encode_total_loss, human_encode_total_loss,
            share_encode_total_loss, loss_domain_adversarial,
            share_decode_total_loss, anime_decode_total_loss,
            human_decode_total_loss, loss_disc,
        ]
        # Loss scaling must happen inside the tape so the scaling op is
        # recorded and gradients of the scaled losses are well-defined.
        scaled_losses = [
            optim.get_scaled_loss(loss)
            for optim, loss in zip(optims, losses)
        ]

    list_variables = [
        encode_anime.trainable_variables, encode_human.trainable_variables,
        encode_share.trainable_variables, c_dann.trainable_variables,
        decode_share.trainable_variables, decode_anime.trainable_variables,
        decode_human.trainable_variables, D.trainable_variables,
    ]
    gan_grad = [
        tape.gradient(scaled_loss, train_variable)
        for scaled_loss, train_variable in zip(scaled_losses, list_variables)
    ]
    gan_grad = [
        optim.get_unscaled_gradients(grads)
        for optim, grads in zip(optims, gan_grad)
    ]
    for optim, grad, trainable in zip(optims, gan_grad, list_variables):
        optim.apply_gradients(zip(grad, trainable))
    # Release the persistent tape's resources explicitly.
    del tape

    return (real_human, real_anime, recon_anime, recon_human, fake_anime,
            fake_human, loss_anime_encode, loss_human_encode,
            loss_domain_adversarial, loss_semantic_consistency, loss_gan,
            loss_disc, tf.reduce_mean(disc_fake), tf.reduce_mean(disc_real))
def trainstep(real_human, real_anime, big_anime):
    """Run one CycleGAN-style step plus an anime-upscaling GAN pair.

    Trains the two translation generators (human<->anime), their two
    discriminators, and an additional upscaling generator/discriminator that
    maps anime images up to the resolution of ``big_anime``.

    Args:
        real_human: batch of human-domain images.
        real_anime: batch of anime-domain images (low resolution).
        big_anime: batch of high-resolution anime images used as the "real"
            reference for the upscale discriminator.

    Returns:
        List of images and loss scalars for logging/visualization.
    """
    with tf.GradientTape(persistent=True) as tape:
        # Forward cycles: human -> anime -> human and anime -> human -> anime.
        fake_anime = generator_to_anime(real_human, training=True)
        cycled_human = generator_to_human(fake_anime, training=True)
        fake_human = generator_to_human(real_anime, training=True)
        cycled_anime = generator_to_anime(fake_human, training=True)

        # same_human / same_anime feed the identity loss (generator applied
        # to its own target domain should be a no-op).
        same_human = generator_to_human(real_human, training=True)
        same_anime = generator_to_anime(real_anime, training=True)

        disc_real_human = discriminator_human(real_human, training=True)
        disc_real_anime = discriminator_anime(real_anime, training=True)
        disc_fake_human = discriminator_human(fake_human, training=True)
        disc_fake_anime = discriminator_anime(fake_anime, training=True)

        # Upscale branch: both translated and real anime are upscaled; the
        # discriminator compares them against genuine high-res anime.
        fake_anime_upscale = generator_anime_upscale(fake_anime, training=True)
        real_anime_upscale = generator_anime_upscale(real_anime, training=True)
        disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                        training=True)
        disc_real_upscale = discriminator_anime_upscale(real_anime_upscale,
                                                        training=True)
        disc_real_big = discriminator_anime_upscale(big_anime, training=True)

        # Adversarial losses for the translation generators.
        gen_anime_loss = w_g_loss(disc_fake_anime)
        gen_human_loss = w_g_loss(disc_fake_human)
        total_cycle_loss = cycle_loss(real_human, cycled_human) + cycle_loss(
            real_anime, cycled_anime)

        # Total generator loss = adversarial loss + cycle loss + identity loss.
        total_gen_anime_loss = (gen_anime_loss + total_cycle_loss
                                + identity_loss(real_anime, same_anime))
        total_gen_human_loss = (gen_human_loss + total_cycle_loss
                                + identity_loss(real_human, same_human))
        # The upscaler is trained to fool the upscale discriminator on both
        # inputs, with a weak identity pull toward the real high-res anime.
        gen_upscale_loss = (
            w_g_loss(disc_fake_upscale) + w_g_loss(disc_real_upscale)
            + identity_loss(big_anime, real_anime_upscale) * 0.3)

        discriminator_human_gradient_penalty = (gradient_penalty(
            functools.partial(discriminator_human, training=True),
            real_human,
            fake_human,
        ) * 10)
        discriminator_anime_gradient_penalty = (gradient_penalty(
            functools.partial(discriminator_anime, training=True),
            real_anime,
            fake_anime,
        ) * 10)
        # BUG FIX: the upscale penalty must be computed on the *upscale*
        # discriminator; the original penalized discriminator_human here,
        # leaving discriminator_anime_upscale without a gradient penalty.
        discriminator_upscale_gradient_penalty = (gradient_penalty(
            functools.partial(discriminator_anime_upscale, training=True),
            big_anime,
            fake_anime_upscale,
        ) * 5)
        discriminator_upscale_gradient_penalty += (gradient_penalty(
            functools.partial(discriminator_anime_upscale, training=True),
            big_anime,
            real_anime_upscale,
        ) * 5)

        disc_human_loss = (w_d_loss(disc_real_human, disc_fake_human)
                           + discriminator_human_gradient_penalty)
        disc_anime_loss = (w_d_loss(disc_real_anime, disc_fake_anime)
                           + discriminator_anime_gradient_penalty)
        disc_upscale_loss = w_d_loss(disc_real_big, disc_fake_upscale)
        disc_upscale_loss += (w_d_loss(disc_real_big, disc_real_upscale)
                              + discriminator_upscale_gradient_penalty)

    generator_to_anime_gradients = tape.gradient(
        total_gen_anime_loss, generator_to_anime.trainable_variables)
    generator_to_human_gradients = tape.gradient(
        total_gen_human_loss, generator_to_human.trainable_variables)
    generator_upscale_gradients = tape.gradient(
        gen_upscale_loss, generator_anime_upscale.trainable_variables)
    discriminator_human_gradients = tape.gradient(
        disc_human_loss, discriminator_human.trainable_variables)
    discriminator_anime_gradients = tape.gradient(
        disc_anime_loss, discriminator_anime.trainable_variables)
    discriminator_upscale_gradients = tape.gradient(
        disc_upscale_loss, discriminator_anime_upscale.trainable_variables)

    generator_to_anime_optimizer.apply_gradients(
        zip(generator_to_anime_gradients,
            generator_to_anime.trainable_variables))
    generator_to_human_optimizer.apply_gradients(
        zip(generator_to_human_gradients,
            generator_to_human.trainable_variables))
    generator_anime_upscale_optimizer.apply_gradients(
        zip(generator_upscale_gradients,
            generator_anime_upscale.trainable_variables))
    discriminator_human_optimizer.apply_gradients(
        zip(discriminator_human_gradients,
            discriminator_human.trainable_variables))
    discriminator_anime_optimizer.apply_gradients(
        zip(discriminator_anime_gradients,
            discriminator_anime.trainable_variables))
    discriminator_anime_upscale_optimizer.apply_gradients(
        zip(
            discriminator_upscale_gradients,
            discriminator_anime_upscale.trainable_variables,
        ))
    # Release the persistent tape's resources explicitly.
    del tape

    return [
        real_human,
        real_anime,
        fake_anime,
        cycled_human,
        fake_human,
        cycled_anime,
        same_human,
        same_anime,
        fake_anime_upscale,
        real_anime_upscale,
        gen_anime_loss,
        gen_human_loss,
        disc_human_loss,
        disc_anime_loss,
        total_gen_anime_loss,
        total_gen_human_loss,
        gen_upscale_loss,
        disc_upscale_loss,
    ]
def trainstep_D(
    real_human,
    real_anime,
    big_anime,
    fake_anime,
    cycled_anime,
    same_anime,
    fake_human,
    cycled_human,
    same_human,
    fake_anime_upscale,
    cycled_anime_upscale,
    same_anime_upscale,
):
    """Update only the three discriminators from precomputed generator outputs.

    Used to run extra discriminator steps: the generated images are supplied
    by the caller, so no generator forward pass (or generator update) happens
    here. The ``cycled_*`` / ``same_*`` arguments are accepted for signature
    parity with the other ``trainstep_D`` variant but are unused in this one.
    Returns nothing; side effect is one optimizer step per discriminator.
    """
    with tf.GradientTape(persistent=True) as tape:
        disc_real_human = discriminator_human(real_human, training=True)
        disc_real_anime = discriminator_anime(real_anime, training=True)
        disc_fake_human = discriminator_human(fake_human, training=True)
        disc_fake_anime = discriminator_anime(fake_anime, training=True)
        disc_real_big = discriminator_anime_upscale(big_anime, training=True)
        disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                        training=True)

        discriminator_human_gradient_penalty = gradient_penalty(
            functools.partial(discriminator_human, training=True),
            real_human,
            fake_human,
        )
        discriminator_anime_gradient_penalty = gradient_penalty(
            functools.partial(discriminator_anime, training=True),
            real_anime,
            fake_anime,
        )
        # BUG FIX: the upscale penalty must be computed on the *upscale*
        # discriminator; the original penalized discriminator_human here.
        discriminator_upscale_gradient_penalty = gradient_penalty(
            functools.partial(discriminator_anime_upscale, training=True),
            big_anime,
            fake_anime_upscale,
        )

        # WGAN-style critic losses with gradient penalty.
        disc_human_loss = (w_d_loss(disc_real_human, disc_fake_human)
                           + discriminator_human_gradient_penalty)
        disc_anime_loss = (w_d_loss(disc_real_anime, disc_fake_anime)
                           + discriminator_anime_gradient_penalty)
        disc_upscale_loss = (w_d_loss(disc_real_big, disc_fake_upscale)
                             + discriminator_upscale_gradient_penalty)
        # Diagnostic traces for the upscale critic.
        tf.print("disc_real_big", disc_real_big)
        tf.print("disc_fake_upscale", disc_fake_upscale)
        tf.print("disc_upscale_loss", disc_upscale_loss)

    discriminator_human_gradients = tape.gradient(
        disc_human_loss, discriminator_human.trainable_variables)
    discriminator_anime_gradients = tape.gradient(
        disc_anime_loss, discriminator_anime.trainable_variables)
    discriminator_upscale_gradients = tape.gradient(
        disc_upscale_loss, discriminator_anime_upscale.trainable_variables)

    discriminator_human_optimizer.apply_gradients(
        zip(discriminator_human_gradients,
            discriminator_human.trainable_variables))
    discriminator_anime_optimizer.apply_gradients(
        zip(discriminator_anime_gradients,
            discriminator_anime.trainable_variables))
    discriminator_anime_upscale_optimizer.apply_gradients(
        zip(
            discriminator_upscale_gradients,
            discriminator_anime_upscale.trainable_variables,
        ))
    # Release the persistent tape's resources explicitly.
    del tape
def trainstep_D(
    real_human,
    real_anime,
    big_anime,
    fake_anime,
    cycled_anime,
    same_anime,
    fake_human,
    cycled_human,
    same_human,
    fake_anime_upscale,
    cycled_anime_upscale,
    same_anime_upscale,
):
    """Update the three discriminators (mixed-precision variant).

    Like the sibling ``trainstep_D``, but the upscale discriminator is trained
    against three generated inputs (fake/cycled/same upscales, averaged) and
    every gradient goes through the loss-scale optimizers'
    ``get_scaled_loss`` / ``get_unscaled_gradients`` pair. The unused
    ``cycled_anime`` / ``same_anime`` / ``cycled_human`` / ``same_human``
    parameters are kept for signature parity. Returns nothing; side effect is
    one optimizer step per discriminator.
    """
    with tf.GradientTape(persistent=True) as tape:
        disc_real_human = discriminator_human(real_human, training=True)
        disc_real_anime = discriminator_anime(real_anime, training=True)
        disc_fake_human = discriminator_human(fake_human, training=True)
        disc_fake_anime = discriminator_anime(fake_anime, training=True)
        disc_real_big = discriminator_anime_upscale(big_anime, training=True)
        disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                        training=True)
        disc_cycled_upscale = discriminator_anime_upscale(
            cycled_anime_upscale, training=True)
        disc_same_upscale = discriminator_anime_upscale(same_anime_upscale,
                                                        training=True)

        discriminator_human_gradient_penalty = gradient_penalty(
            functools.partial(discriminator_human, training=True),
            real_human,
            fake_human,
        )
        discriminator_anime_gradient_penalty = gradient_penalty(
            functools.partial(discriminator_anime, training=True),
            real_anime,
            fake_anime,
        )
        # BUG FIX (x2): the original (a) computed all three upscale penalties
        # on discriminator_human instead of discriminator_anime_upscale, and
        # (b) assigned with "=" each time, discarding the first two penalties.
        # Penalize the upscale critic on all three generated inputs and
        # accumulate, matching the "+=" pattern used elsewhere in this file.
        discriminator_upscale_gradient_penalty = gradient_penalty(
            functools.partial(discriminator_anime_upscale, training=True),
            big_anime,
            fake_anime_upscale,
        )
        discriminator_upscale_gradient_penalty += gradient_penalty(
            functools.partial(discriminator_anime_upscale, training=True),
            big_anime,
            cycled_anime_upscale,
        )
        discriminator_upscale_gradient_penalty += gradient_penalty(
            functools.partial(discriminator_anime_upscale, training=True),
            big_anime,
            same_anime_upscale,
        )

        disc_human_loss = (w_d_loss(disc_real_human, disc_fake_human)
                           + discriminator_human_gradient_penalty)
        disc_anime_loss = (w_d_loss(disc_real_anime, disc_fake_anime)
                           + discriminator_anime_gradient_penalty)
        # Average the three critic terms (and accumulated penalties) so this
        # loss stays on the same scale as the single-input discriminators.
        disc_upscale_loss = (w_d_loss(disc_real_big, disc_fake_upscale)
                             + w_d_loss(disc_real_big, disc_cycled_upscale)
                             + w_d_loss(disc_real_big, disc_same_upscale)
                             + discriminator_upscale_gradient_penalty) / 3.0

        # Loss scaling must happen inside the tape so the scaling op is
        # recorded and gradients of the scaled losses are well-defined.
        scaled_disc_human_loss = discriminator_human_optimizer.get_scaled_loss(
            disc_human_loss)
        scaled_disc_anime_loss = discriminator_anime_optimizer.get_scaled_loss(
            disc_anime_loss)
        scaled_disc_upscale_loss = (
            discriminator_anime_upscale_optimizer.get_scaled_loss(
                disc_upscale_loss))

    discriminator_human_gradients = (
        discriminator_human_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_human_loss,
                          discriminator_human.trainable_variables)))
    discriminator_anime_gradients = (
        discriminator_anime_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_anime_loss,
                          discriminator_anime.trainable_variables)))
    discriminator_upscale_gradients = (
        discriminator_anime_upscale_optimizer.get_unscaled_gradients(
            tape.gradient(
                scaled_disc_upscale_loss,
                discriminator_anime_upscale.trainable_variables,
            )))

    discriminator_human_optimizer.apply_gradients(
        zip(discriminator_human_gradients,
            discriminator_human.trainable_variables))
    discriminator_anime_optimizer.apply_gradients(
        zip(discriminator_anime_gradients,
            discriminator_anime.trainable_variables))
    discriminator_anime_upscale_optimizer.apply_gradients(
        zip(
            discriminator_upscale_gradients,
            discriminator_anime_upscale.trainable_variables,
        ))
    # Release the persistent tape's resources explicitly.
    del tape