示例#1
0
    def trainstep(real_human, real_anime, big_anime):
        """Run one shared-latent (XGAN-style) training step on a human/anime pair.

        Both domains are encoded into a shared latent space, reconstructed,
        translated into the other domain and re-encoded. Every sub-network is
        then updated by its own optimizer: ``optims``, ``losses`` and
        ``list_variables`` are zipped index-by-index, so their orders must
        stay aligned.

        Args:
            real_human: batch from the human domain.
            real_anime: batch from the anime domain.
            big_anime: not used by this step (presumably accepted so every
                train step shares one call signature).

        Returns:
            Tuple ``(real_human, real_anime, recon_anime, recon_human,
            fake_anime, fake_human, loss_anime_encode, loss_human_encode,
            loss_domain_adversarial, loss_semantic_consistency, loss_gan,
            loss_disc, mean(D(fake_anime)), mean(D(real_anime)))``.
        """
        with tf.GradientTape(persistent=True) as tape:
            latent_anime = encode_share(encode_anime(real_anime))
            latent_human = encode_share(encode_human(real_human))

            # Same-domain reconstructions.
            recon_anime = decode_anime(decode_share(latent_anime))
            recon_human = decode_human(decode_share(latent_human))

            # human -> anime, then re-encode with the anime encoder.
            fake_anime = decode_anime(decode_share(latent_human))
            latent_human_cycled = encode_share(encode_anime(fake_anime))

            # anime -> human, then re-encode with the human encoder. Decoding
            # with decode_anime here would make fake_human identical to
            # recon_anime and void this half of the semantic-consistency loss.
            fake_human = decode_human(decode_share(latent_anime))
            latent_anime_cycled = encode_share(encode_human(fake_human))

            disc_fake = D(fake_anime)
            disc_real = D(real_anime)

            # Domain-classifier logits on the shared latents.
            c_dann_anime = c_dann(latent_anime)
            c_dann_human = c_dann(latent_human)

            loss_anime_encode = identity_loss(real_anime, recon_anime) * 3
            loss_human_encode = identity_loss(real_human, recon_human) * 3

            # Classifier targets: anime latents -> 0, human latents -> 1.
            loss_domain_adversarial = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.zeros_like(c_dann_anime),
                    logits=c_dann_anime)) + tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=tf.ones_like(c_dann_human),
                            logits=c_dann_human))
            # Cap at 100 (presumably to bound blow-ups), then weight by 0.2.
            loss_domain_adversarial = tf.math.minimum(loss_domain_adversarial,
                                                      100)
            loss_domain_adversarial = loss_domain_adversarial * 0.2
            tf.print(loss_domain_adversarial)

            # Latents should survive a cross-domain round trip.
            loss_semantic_consistency = (
                identity_loss(latent_anime, latent_anime_cycled) * 3 +
                identity_loss(latent_human, latent_human_cycled) * 3)

            loss_gan = w_g_loss(disc_fake)

            # NOTE(review): the encoders below minimize the same domain loss as
            # c_dann. That only pushes the latents towards domain invariance if
            # c_dann reverses gradients internally (gradient-reversal layer) —
            # confirm.
            anime_encode_total_loss = (loss_anime_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency + loss_gan)
            human_encode_total_loss = (loss_human_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency)
            share_encode_total_loss = (loss_anime_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency + loss_gan +
                                       loss_human_encode)

            share_decode_total_loss = loss_anime_encode + loss_human_encode + loss_gan
            anime_decode_total_loss = loss_anime_encode + loss_gan
            human_decode_total_loss = loss_human_encode

            # Critic loss plus gradient penalty on the same critic D.
            loss_disc = w_d_loss(disc_real, disc_fake)
            loss_disc += gradient_penalty(partial(D, training=True),
                                          real_anime, fake_anime)

            # Order must match list_variables and optims below.
            losses = [
                anime_encode_total_loss, human_encode_total_loss,
                share_encode_total_loss, loss_domain_adversarial,
                share_decode_total_loss, anime_decode_total_loss,
                human_decode_total_loss, loss_disc
            ]

            # optims presumably are loss-scale (mixed-precision) optimizers.
            scaled_losses = [
                optim.get_scaled_loss(loss)
                for optim, loss in zip(optims, losses)
            ]

        list_variables = [
            encode_anime.trainable_variables, encode_human.trainable_variables,
            encode_share.trainable_variables, c_dann.trainable_variables,
            decode_share.trainable_variables, decode_anime.trainable_variables,
            decode_human.trainable_variables, D.trainable_variables
        ]
        gan_grad = [
            tape.gradient(scaled_loss, train_variable) for scaled_loss,
            train_variable in zip(scaled_losses, list_variables)
        ]
        # Undo the scaling applied by get_scaled_loss before applying.
        gan_grad = [
            optim.get_unscaled_gradients(x)
            for optim, x in zip(optims, gan_grad)
        ]
        for optim, grad, trainable in zip(optims, gan_grad, list_variables):
            optim.apply_gradients(zip(grad, trainable))
        # The tape is persistent; drop it so its resources can be released.
        del tape

        return (real_human, real_anime, recon_anime, recon_human, fake_anime,
                fake_human, loss_anime_encode, loss_human_encode,
                loss_domain_adversarial, loss_semantic_consistency,
                loss_gan, loss_disc, tf.reduce_mean(disc_fake),
                tf.reduce_mean(disc_real))
    def trainstep(real_human, real_anime, big_anime):
        """Run one joint CycleGAN + anime-upscaler training step.

        Trains both translators (human <-> anime), both domain critics, the
        anime upscaler and its critic in one pass, each with its own optimizer
        and WGAN-style losses with gradient penalties.

        Args:
            real_human: batch from the human domain.
            real_anime: batch from the anime domain.
            big_anime: high-resolution anime batch used as the "real" input of
                the upscale critic; it is compared against the upscaled
                real_anime in the identity term (presumably its high-res
                counterpart — confirm).

        Returns:
            List ``[real_human, real_anime, fake_anime, cycled_human,
            fake_human, cycled_anime, same_human, same_anime,
            fake_anime_upscale, real_anime_upscale, gen_anime_loss,
            gen_human_loss, disc_human_loss, disc_anime_loss,
            total_gen_anime_loss, total_gen_human_loss, gen_upscale_loss,
            disc_upscale_loss]``.
        """
        with tf.GradientTape(persistent=True) as tape:

            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            print("generator_to_anime", generator_to_anime.count_params())

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)
            print("discriminator_human", discriminator_human.count_params())

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            fake_anime_upscale = generator_anime_upscale(fake_anime,
                                                         training=True)
            real_anime_upscale = generator_anime_upscale(real_anime,
                                                         training=True)

            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)

            disc_real_upscale = discriminator_anime_upscale(real_anime_upscale,
                                                            training=True)
            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)

            # Generator adversarial losses.
            gen_anime_loss = w_g_loss(disc_fake_anime)
            gen_human_loss = w_g_loss(disc_fake_human)

            total_cycle_loss = cycle_loss(real_human,
                                          cycled_human) + cycle_loss(
                                              real_anime, cycled_anime)

            # Total generator loss = adversarial loss + cycle loss + identity.
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime))

            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human))

            # The upscaler is pushed to fool its critic on both upscaled
            # outputs and to match big_anime on upscaled real_anime.
            gen_upscale_loss = (
                w_g_loss(disc_fake_upscale) + w_g_loss(disc_real_upscale)
                + identity_loss(big_anime, real_anime_upscale) * 0.3)

            discriminator_human_gradient_penalty = (gradient_penalty(
                functools.partial(discriminator_human, training=True),
                real_human,
                fake_human,
            ) * 10)
            discriminator_anime_gradient_penalty = (gradient_penalty(
                functools.partial(discriminator_anime, training=True),
                real_anime,
                fake_anime,
            ) * 10)
            # The upscale penalties must be taken on the upscale critic: a
            # penalty computed on another critic has no gradient w.r.t.
            # discriminator_anime_upscale's variables and would leave it
            # unconstrained.
            discriminator_upscale_gradient_penalty = (gradient_penalty(
                functools.partial(discriminator_anime_upscale, training=True),
                big_anime,
                fake_anime_upscale,
            ) * 5)
            discriminator_upscale_gradient_penalty += (gradient_penalty(
                functools.partial(discriminator_anime_upscale, training=True),
                big_anime,
                real_anime_upscale,
            ) * 5)

            disc_human_loss = (w_d_loss(disc_real_human, disc_fake_human) +
                               discriminator_human_gradient_penalty)
            disc_anime_loss = (w_d_loss(disc_real_anime, disc_fake_anime) +
                               discriminator_anime_gradient_penalty)
            # Both upscaled outputs are treated as fakes against big_anime.
            disc_upscale_loss = w_d_loss(disc_real_big, disc_fake_upscale)
            disc_upscale_loss += (w_d_loss(disc_real_big, disc_real_upscale) +
                                  discriminator_upscale_gradient_penalty)

        generator_to_anime_gradients = tape.gradient(
            total_gen_anime_loss, generator_to_anime.trainable_variables)
        generator_to_human_gradients = tape.gradient(
            total_gen_human_loss, generator_to_human.trainable_variables)
        generator_upscale_gradients = tape.gradient(
            gen_upscale_loss, generator_anime_upscale.trainable_variables)

        discriminator_human_gradients = tape.gradient(
            disc_human_loss, discriminator_human.trainable_variables)
        discriminator_anime_gradients = tape.gradient(
            disc_anime_loss, discriminator_anime.trainable_variables)

        discriminator_upscale_gradients = tape.gradient(
            disc_upscale_loss, discriminator_anime_upscale.trainable_variables)
        # The tape is persistent; drop it once all gradients are taken.
        del tape

        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))

        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))

        generator_anime_upscale_optimizer.apply_gradients(
            zip(generator_upscale_gradients,
                generator_anime_upscale.trainable_variables))

        discriminator_human_optimizer.apply_gradients(
            zip(discriminator_human_gradients,
                discriminator_human.trainable_variables))

        discriminator_anime_optimizer.apply_gradients(
            zip(discriminator_anime_gradients,
                discriminator_anime.trainable_variables))

        discriminator_anime_upscale_optimizer.apply_gradients(
            zip(
                discriminator_upscale_gradients,
                discriminator_anime_upscale.trainable_variables,
            ))

        return [
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_upscale,
            real_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
            disc_upscale_loss,
        ]
示例#3
0
    def trainstep_D(
        real_human,
        real_anime,
        big_anime,
        fake_anime,
        cycled_anime,
        same_anime,
        fake_human,
        cycled_human,
        same_human,
        fake_anime_upscale,
        cycled_anime_upscale,
        same_anime_upscale,
    ):
        """Run one update of the three critics (human, anime, anime-upscale).

        Each critic gets a Wasserstein loss plus a gradient penalty on itself,
        and is stepped by its own optimizer.

        Only real_human, real_anime, big_anime, fake_human, fake_anime and
        fake_anime_upscale are read; the cycled_*/same_* arguments are
        accepted but ignored by this step.

        Args:
            real_human: real human-domain batch.
            real_anime: real anime-domain batch.
            big_anime: real batch for the upscale critic.
            fake_anime: generated anime batch.
            fake_human: generated human batch.
            fake_anime_upscale: generated batch for the upscale critic.

        Returns:
            None. The critics' variables are updated in place.
        """
        with tf.GradientTape(persistent=True) as tape:
            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)
            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)

            discriminator_human_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_human, training=True),
                real_human,
                fake_human,
            )
            discriminator_anime_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_anime, training=True),
                real_anime,
                fake_anime,
            )
            # Must be taken on the upscale critic itself; a penalty on another
            # critic has no gradient w.r.t. discriminator_anime_upscale's
            # variables and would leave it unconstrained.
            discriminator_upscale_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_anime_upscale, training=True),
                big_anime,
                fake_anime_upscale,
            )

            disc_human_loss = (w_d_loss(disc_real_human, disc_fake_human) +
                               discriminator_human_gradient_penalty)
            disc_anime_loss = (w_d_loss(disc_real_anime, disc_fake_anime) +
                               discriminator_anime_gradient_penalty)
            disc_upscale_loss = (w_d_loss(disc_real_big, disc_fake_upscale) +
                                 discriminator_upscale_gradient_penalty)
            tf.print("disc_real_big", disc_real_big)
            tf.print("disc_fake_upscale", disc_fake_upscale)
            tf.print("disc_upscale_loss", disc_upscale_loss)

        discriminator_human_gradients = tape.gradient(
            disc_human_loss, discriminator_human.trainable_variables)
        discriminator_anime_gradients = tape.gradient(
            disc_anime_loss, discriminator_anime.trainable_variables)
        discriminator_upscale_gradients = tape.gradient(
            disc_upscale_loss, discriminator_anime_upscale.trainable_variables)
        # The tape is persistent; drop it once all gradients are taken.
        del tape

        discriminator_human_optimizer.apply_gradients(
            zip(discriminator_human_gradients,
                discriminator_human.trainable_variables))
        discriminator_anime_optimizer.apply_gradients(
            zip(discriminator_anime_gradients,
                discriminator_anime.trainable_variables))
        discriminator_anime_upscale_optimizer.apply_gradients(
            zip(
                discriminator_upscale_gradients,
                discriminator_anime_upscale.trainable_variables,
            ))
示例#4
0
    def trainstep_D(
        real_human,
        real_anime,
        big_anime,
        fake_anime,
        cycled_anime,
        same_anime,
        fake_human,
        cycled_human,
        same_human,
        fake_anime_upscale,
        cycled_anime_upscale,
        same_anime_upscale,
    ):
        """Run one mixed-precision update of the human, anime and upscale critics.

        Each critic gets a Wasserstein loss plus a gradient penalty taken on
        that same critic. The upscale critic sees three fake sources
        (fake/cycled/same upscaled anime). Their Wasserstein terms and their
        summed penalties are averaged together (divided by 3).

        Losses are scaled with each optimizer's ``get_scaled_loss`` and the
        gradients unscaled with ``get_unscaled_gradients`` before being
        applied.

        cycled_anime, same_anime, cycled_human and same_human are accepted
        but not read by this step.

        Args:
            real_human: real human-domain batch.
            real_anime: real anime-domain batch.
            big_anime: real batch for the upscale critic.
            fake_anime: generated anime batch.
            fake_human: generated human batch.
            fake_anime_upscale, cycled_anime_upscale, same_anime_upscale:
                generated batches scored by the upscale critic.

        Returns:
            None. The critics' variables are updated in place.
        """
        with tf.GradientTape(persistent=True) as tape:
            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)
            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)
            disc_cycled_upscale = discriminator_anime_upscale(
                cycled_anime_upscale, training=True)
            disc_same_upscale = discriminator_anime_upscale(same_anime_upscale,
                                                            training=True)

            discriminator_human_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_human, training=True),
                real_human,
                fake_human,
            )
            discriminator_anime_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_anime, training=True),
                real_anime,
                fake_anime,
            )
            # One penalty per upscale fake source, all on the upscale critic,
            # and accumulated (not overwritten) so each source is regularised.
            upscale_critic = functools.partial(discriminator_anime_upscale,
                                               training=True)
            discriminator_upscale_gradient_penalty = (
                gradient_penalty(upscale_critic, big_anime,
                                 fake_anime_upscale) +
                gradient_penalty(upscale_critic, big_anime,
                                 cycled_anime_upscale) +
                gradient_penalty(upscale_critic, big_anime,
                                 same_anime_upscale))

            disc_human_loss = (w_d_loss(disc_real_human, disc_fake_human) +
                               discriminator_human_gradient_penalty)
            disc_anime_loss = (w_d_loss(disc_real_anime, disc_fake_anime) +
                               discriminator_anime_gradient_penalty)
            # Average over the three fake sources.
            disc_upscale_loss = (w_d_loss(disc_real_big, disc_fake_upscale) +
                                 w_d_loss(disc_real_big, disc_cycled_upscale) +
                                 w_d_loss(disc_real_big, disc_same_upscale) +
                                 discriminator_upscale_gradient_penalty) / 3.0
            scaled_disc_human_loss = discriminator_human_optimizer.get_scaled_loss(
                disc_human_loss)
            scaled_disc_anime_loss = discriminator_anime_optimizer.get_scaled_loss(
                disc_anime_loss)
            scaled_disc_upscale_loss = discriminator_anime_upscale_optimizer.get_scaled_loss(
                disc_upscale_loss)

        discriminator_human_gradients = discriminator_human_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_human_loss,
                          discriminator_human.trainable_variables))
        discriminator_anime_gradients = discriminator_anime_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_anime_loss,
                          discriminator_anime.trainable_variables))
        discriminator_upscale_gradients = discriminator_anime_upscale_optimizer.get_unscaled_gradients(
            tape.gradient(
                scaled_disc_upscale_loss,
                discriminator_anime_upscale.trainable_variables,
            ))
        # The tape is persistent; drop it once all gradients are taken.
        del tape

        discriminator_human_optimizer.apply_gradients(
            zip(discriminator_human_gradients,
                discriminator_human.trainable_variables))
        discriminator_anime_optimizer.apply_gradients(
            zip(discriminator_anime_gradients,
                discriminator_anime.trainable_variables))
        discriminator_anime_upscale_optimizer.apply_gradients(
            zip(
                discriminator_upscale_gradients,
                discriminator_anime_upscale.trainable_variables,
            ))