Beispiel #1
0
def run_tensorflow():
    """Train the human<->anime CycleGAN and the anime upscaler.

    The function enables GPU memory growth, switches Keras to the
    ``mixed_float16`` policy, builds two ``GeneratorV2`` translators, one
    ``UpsampleGenerator`` and three ``Discriminator`` models (each with its
    own loss-scaled Adam optimizer), restores the newest checkpoint under
    ``./checkpoints/train`` if one exists, and then trains forever.  On the
    first step and on every fifth step after it, image and scalar summaries
    are written below ``./logs/train_data/`` and a checkpoint is saved.

    Returns:
        Never returns normally; the training loop only ends through an
        exception (for example ``KeyboardInterrupt``).
    """

    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    # NOTE(review): this is the experimental mixed-precision API; newer
    # TensorFlow releases expose it without the ``experimental`` suffix.
    mixed_precision = tf.keras.mixed_precision.experimental
    mixed_precision.set_policy(mixed_precision.Policy("mixed_float16"))

    AnimeCleanData = getAnimeCleanData(BATCH_SIZE=5)
    CelebaData = getCelebaData(BATCH_SIZE=5)

    logdir = "./logs/train_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")
    file_writer = tf.summary.create_file_writer(logdir)

    def make_optimizer():
        """Return a dynamically loss-scaled Adam optimizer."""
        return mixed_precision.LossScaleOptimizer(
            tf.keras.optimizers.Adam(2e-4, beta_1=0.5), loss_scale="dynamic")

    generator_to_anime_optimizer = make_optimizer()
    generator_to_human_optimizer = make_optimizer()
    generator_anime_upscale_optimizer = make_optimizer()
    discriminator_human_optimizer = make_optimizer()
    discriminator_anime_optimizer = make_optimizer()
    discriminator_anime_upscale_optimizer = make_optimizer()

    generator_to_anime = GeneratorV2()
    generator_to_human = GeneratorV2()
    generator_anime_upscale = UpsampleGenerator()

    # input: Batch, 256,256,3
    discriminator_human = Discriminator()
    discriminator_anime = Discriminator()
    discriminator_anime_upscale = Discriminator()

    checkpoint_path = "./checkpoints/train"

    ckpt = tf.train.Checkpoint(
        generator_to_anime=generator_to_anime,
        generator_to_human=generator_to_human,
        generator_anime_upscale=generator_anime_upscale,
        discriminator_human=discriminator_human,
        discriminator_anime=discriminator_anime,
        discriminator_anime_upscale=discriminator_anime_upscale,
        generator_to_anime_optimizer=generator_to_anime_optimizer,
        generator_to_human_optimizer=generator_to_human_optimizer,
        generator_anime_upscale_optimizer=generator_anime_upscale_optimizer,
        discriminator_human_optimizer=discriminator_human_optimizer,
        discriminator_anime_optimizer=discriminator_anime_optimizer,
        discriminator_anime_upscale_optimizer=(
            discriminator_anime_upscale_optimizer),
    )

    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_path,
                                              max_to_keep=5)

    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print("Latest checkpoint restored!!")

    # out: Batch, 16, 16, 1
    # x is human, y is anime
    @tf.function
    def trainstep(real_human, real_anime, big_anime):
        """Run one optimisation step for all six models.

        Args:
            real_human: batch of human images.
            real_anime: batch of anime images.
            big_anime: batch of anime images used as target for the
                upscaler and as "real" input of the upscale discriminator.

        Returns:
            A tuple of 18 tensors: the two input batches, eight generated
            image batches and eight *unscaled* loss values, in the order
            expected by the logging code of the training loop.
        """
        with tf.GradientTape(persistent=True) as tape:
            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            gen_anime_loss = generator_loss(disc_fake_anime)
            gen_human_loss = generator_loss(disc_fake_human)

            total_cycle_loss = (cycle_loss(real_human, cycled_human) +
                                cycle_loss(real_anime, cycled_anime))

            # Total generator loss = adversarial + cycle + identity loss
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime))
            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human))

            disc_human_loss = discriminator_loss(disc_real_human,
                                                 disc_fake_human)
            disc_anime_loss = discriminator_loss(disc_real_anime,
                                                 disc_fake_anime)

            # Upscaling branch.  The models are called in training mode
            # like every other model in this step.
            fake_anime_upscale = generator_anime_upscale(fake_anime,
                                                         training=True)
            real_anime_upscale = generator_anime_upscale(real_anime,
                                                         training=True)
            same_anime_upscale = generator_anime_upscale(same_anime,
                                                         training=True)

            disc_fake_upscale = discriminator_anime_upscale(
                fake_anime_upscale, training=True)
            disc_cycle_upscale = discriminator_anime_upscale(
                real_anime_upscale, training=True)
            disc_same_upscale = discriminator_anime_upscale(
                same_anime_upscale, training=True)
            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)

            gen_upscale_loss = (
                generator_loss(disc_fake_upscale) * 3 +
                generator_loss(disc_cycle_upscale) +
                generator_loss(disc_same_upscale) +
                identity_loss(big_anime, real_anime_upscale) * 0.5 +
                identity_loss(big_anime, same_anime_upscale) * 0.5)

            disc_upscale_loss = discriminator_upscale_loss(
                disc_real_big, disc_fake_upscale, disc_cycle_upscale,
                disc_same_upscale)

            # Loss scaling has to be recorded by the tape.  The scaled
            # values are kept separate so that the unscaled losses can be
            # returned for logging.
            updates = [
                (generator_to_anime_optimizer, generator_to_anime,
                 generator_to_anime_optimizer.get_scaled_loss(
                     total_gen_anime_loss)),
                (generator_to_human_optimizer, generator_to_human,
                 generator_to_human_optimizer.get_scaled_loss(
                     total_gen_human_loss)),
                (discriminator_human_optimizer, discriminator_human,
                 discriminator_human_optimizer.get_scaled_loss(
                     disc_human_loss)),
                (discriminator_anime_optimizer, discriminator_anime,
                 discriminator_anime_optimizer.get_scaled_loss(
                     disc_anime_loss)),
                (generator_anime_upscale_optimizer, generator_anime_upscale,
                 generator_anime_upscale_optimizer.get_scaled_loss(
                     gen_upscale_loss)),
                (discriminator_anime_upscale_optimizer,
                 discriminator_anime_upscale,
                 discriminator_anime_upscale_optimizer.get_scaled_loss(
                     disc_upscale_loss)),
            ]

        # Compute every gradient before any variable is updated.
        scaled_gradients = [
            tape.gradient(scaled_loss, model.trainable_variables)
            for _, model, scaled_loss in updates
        ]
        # A persistent tape holds its resources until it is dropped.
        del tape

        for (optimizer, model, _), gradients in zip(updates,
                                                    scaled_gradients):
            optimizer.apply_gradients(
                zip(optimizer.get_unscaled_gradients(gradients),
                    model.trainable_variables))

        return (
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_upscale,
            same_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
            disc_upscale_loss,
        )

    def process_data_for_display(input_image):
        """Apply ``x * 0.5 + 0.5`` to an image batch before logging it."""
        return input_image * 0.5 + 0.5

    def next_batch(dataset, iterator):
        """Return ``(batch, iterator)``, restarting ``dataset`` if exhausted."""
        try:
            return next(iterator), iterator
        except StopIteration:
            iterator = iter(dataset)
            return next(iterator), iterator

    # Tags in the order of the values returned by ``trainstep``:
    # real_human is the CelebA batch, real_anime the anime batch.
    normalized_image_tags = (
        "CelebATrainImage",
        "AnimeTrainImage",
        "fake_anime",
        "cycled_human",
        "fake_human",
        "cycled_anime",
        "same_human",
        "same_anime",
    )
    # These two are logged without process_data_for_display, as before.
    raw_image_tags = ("fake_anime_upscale", "same_anime_upscale")
    scalar_tags = (
        "gen_anime_loss",
        "gen_human_loss",
        "disc_human_loss",
        "disc_anime_loss",
        "total_gen_anime_loss",
        "total_gen_human_loss",
        "gen_upscale_loss",
        "disc_upscale_loss",
    )

    # Create the iterators once; calling iter() on every step would start
    # a new pass over the data each time.
    anime_iterator = iter(AnimeCleanData)
    celeba_iterator = iter(CelebaData)

    counter = 0
    while True:
        counter += 1
        (AnimeBatchImage, BigAnimeBatchImage), anime_iterator = next_batch(
            AnimeCleanData, anime_iterator)
        CelebaBatchImage, celeba_iterator = next_batch(
            CelebaData, celeba_iterator)
        print(counter)

        outputs = trainstep(CelebaBatchImage, AnimeBatchImage,
                            BigAnimeBatchImage)

        # Log and checkpoint on steps 1, 6, 11, ...
        if (counter - 1) % 5:
            continue

        images = outputs[:8]
        upscaled_images = outputs[8:10]
        losses = outputs[10:]

        with file_writer.as_default():
            for tag, image in zip(normalized_image_tags, images):
                tf.summary.image(tag,
                                 process_data_for_display(image),
                                 step=counter)
            for tag, image in zip(raw_image_tags, upscaled_images):
                tf.summary.image(tag, image, step=counter)
            for tag, loss in zip(scalar_tags, losses):
                tf.summary.scalar(tag, loss, step=counter)

        ckpt_manager.save()
Beispiel #2
0
def run_tensorflow():
    """Train the shared-latent human<->anime model with a Wasserstein critic.

    The function enables GPU memory growth, switches Keras to the
    ``mixed_float16`` policy, builds two domain encoders plus a shared
    encoder, a shared decoder plus two domain decoders, the ``C_dann``
    classifier on the latent codes and the ``W_Discriminator`` ``D``.  Each
    of those eight models has its own loss-scaled Adam optimizer.  The
    newest checkpoint under ``./checkpoints/XWGan`` is restored if one
    exists, then training runs forever.  On the first step and every fifth
    step after it, summaries are written below ``./logs/XWGan/`` and a
    checkpoint is saved.

    ``batch_size`` is not defined in this function; it is read from the
    enclosing scope.

    Returns:
        Never returns normally; the training loop only ends through an
        exception (for example ``KeyboardInterrupt``).
    """

    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    mixed_precision = tf.keras.mixed_precision.experimental
    mixed_precision.set_policy(mixed_precision.Policy("mixed_float16"))

    AnimeCleanData = getAnimeCleanData(BATCH_SIZE=batch_size)
    CelebaData = getCelebaData(BATCH_SIZE=batch_size)

    logdir = "./logs/XWGan/" + datetime.now().strftime("%Y%m%d-%H%M%S")
    file_writer = tf.summary.create_file_writer(logdir)

    checkpoint_path = "./checkpoints/XWGan"

    encode_anime = encoder_seperate_layers()
    encode_human = encoder_seperate_layers()
    encode_share = encoder_shared_layers()

    decode_share = decoder_shared_layers()
    decode_human = decoder_seperate_layers()
    decode_anime = decoder_seperate_layers()
    c_dann = C_dann()
    D = W_Discriminator()

    # One optimizer per trained model, in the same order as ``losses`` and
    # ``list_variables`` inside ``trainstep``.
    optimizer_owners = [
        "encode_anime", "encode_human", "encode_share", "c_dann",
        "decode_share", "decode_anime", "decode_human", "D"
    ]
    optims = [
        mixed_precision.LossScaleOptimizer(
            tf.keras.optimizers.Adam(1e-4, beta_1=0.9, beta_2=0.99),
            loss_scale="dynamic") for _ in optimizer_owners
    ]
    named_optimizers = {
        owner + "_optimizer": optim
        for owner, optim in zip(optimizer_owners, optims)
    }

    # The critic and the optimizers are checkpointed too, so that a
    # restart does not pair restored generators with a freshly initialised
    # critic and empty optimizer state.
    ckpt = tf.train.Checkpoint(
        encode_anime=encode_anime,
        encode_human=encode_human,
        encode_share=encode_share,
        decode_share=decode_share,
        decode_human=decode_human,
        decode_anime=decode_anime,
        c_dann=c_dann,
        D=D,
        **named_optimizers)

    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_path,
                                              max_to_keep=5)

    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print("Latest checkpoint restored!!")

    @tf.function
    def trainstep(real_human, real_anime, big_anime):
        """Run one optimisation step for all eight models.

        Args:
            real_human: batch of human images.
            real_anime: batch of anime images.
            big_anime: unused; kept so the call signature stays the same.

        Returns:
            A tuple of 14 tensors: six image batches followed by eight
            scalars, in the order of ``print_string`` below.
        """
        # NOTE(review): encoders, decoders and c_dann are called without an
        # explicit ``training`` flag; confirm they contain no layers that
        # behave differently in training mode.
        with tf.GradientTape(persistent=True) as tape:
            latent_anime = encode_share(encode_anime(real_anime))
            latent_human = encode_share(encode_human(real_human))

            recon_anime = decode_anime(decode_share(latent_anime))
            recon_human = decode_human(decode_share(latent_human))

            # human -> anime, then back into the latent space through the
            # anime encoder.
            fake_anime = decode_anime(decode_share(latent_human))
            latent_human_cycled = encode_share(encode_anime(fake_anime))

            # anime -> human, then back into the latent space through the
            # human encoder.  (Decoding with decode_anime here would only
            # reproduce recon_anime.)
            fake_human = decode_human(decode_share(latent_anime))
            latent_anime_cycled = encode_share(encode_human(fake_human))

            # Same mode as the critic call used for the gradient penalty.
            disc_fake = D(fake_anime, training=True)
            disc_real = D(real_anime, training=True)

            c_dann_anime = c_dann(latent_anime)
            c_dann_human = c_dann(latent_human)

            loss_anime_encode = identity_loss(real_anime, recon_anime) * 3
            loss_human_encode = identity_loss(real_human, recon_human) * 3

            # Domain classification loss: anime latents are labelled 0,
            # human latents 1.
            # NOTE(review): the encoders minimise this loss with the same
            # sign as c_dann; verify that C_dann reverses gradients
            # internally if domain confusion is intended.
            anime_domain_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.zeros_like(c_dann_anime), logits=c_dann_anime))
            human_domain_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(c_dann_human), logits=c_dann_human))
            loss_domain_adversarial = anime_domain_loss + human_domain_loss
            loss_domain_adversarial = tf.math.minimum(loss_domain_adversarial,
                                                      100)
            loss_domain_adversarial = loss_domain_adversarial * 0.2
            tf.print(loss_domain_adversarial)

            loss_semantic_consistency = (
                identity_loss(latent_anime, latent_anime_cycled) * 3 +
                identity_loss(latent_human, latent_human_cycled) * 3)

            loss_gan = w_g_loss(disc_fake)

            anime_encode_total_loss = (loss_anime_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency + loss_gan)
            human_encode_total_loss = (loss_human_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency)
            share_encode_total_loss = (loss_anime_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency + loss_gan +
                                       loss_human_encode)

            share_decode_total_loss = (loss_anime_encode + loss_human_encode +
                                       loss_gan)
            anime_decode_total_loss = loss_anime_encode + loss_gan
            human_decode_total_loss = loss_human_encode

            loss_disc = w_d_loss(disc_real, disc_fake)
            loss_disc += gradient_penalty(partial(D, training=True),
                                          real_anime, fake_anime)

            losses = [
                anime_encode_total_loss, human_encode_total_loss,
                share_encode_total_loss, loss_domain_adversarial,
                share_decode_total_loss, anime_decode_total_loss,
                human_decode_total_loss, loss_disc
            ]

            # Scaling has to happen under the tape so it is differentiated.
            scaled_losses = [
                optim.get_scaled_loss(loss)
                for optim, loss in zip(optims, losses)
            ]

        # Same order as ``losses`` / ``optims``.
        list_variables = [
            encode_anime.trainable_variables, encode_human.trainable_variables,
            encode_share.trainable_variables, c_dann.trainable_variables,
            decode_share.trainable_variables, decode_anime.trainable_variables,
            decode_human.trainable_variables, D.trainable_variables
        ]
        scaled_gradients = [
            tape.gradient(scaled_loss, variables)
            for scaled_loss, variables in zip(scaled_losses, list_variables)
        ]
        # A persistent tape holds its resources until it is dropped.
        del tape

        gradients = [
            optim.get_unscaled_gradients(grads)
            for optim, grads in zip(optims, scaled_gradients)
        ]
        for optim, grads, variables in zip(optims, gradients, list_variables):
            optim.apply_gradients(zip(grads, variables))

        return (real_human, real_anime, recon_anime, recon_human, fake_anime,
                fake_human, loss_anime_encode, loss_human_encode,
                loss_domain_adversarial, loss_semantic_consistency,
                loss_gan, loss_disc, tf.reduce_mean(disc_fake),
                tf.reduce_mean(disc_real))

    def process_data_for_display(input_image):
        """Apply ``x * 0.5 + 0.5`` to an image batch before logging it."""
        return input_image * 0.5 + 0.5

    def next_batch(dataset, iterator):
        """Return ``(batch, iterator)``, restarting ``dataset`` if exhausted."""
        try:
            return next(iterator), iterator
        except StopIteration:
            iterator = iter(dataset)
            return next(iterator), iterator

    # Summary tags, in the order of the values returned by ``trainstep``.
    # The first ``image_count`` entries are image batches, the rest scalars.
    print_string = [
        "real_human", "real_anime", "recon_anime", "recon_human", "fake_anime",
        "fake_human", "loss_anime_encode", "loss_human_encode",
        "loss_domain_adversarial", "loss_semantic_consistency", "loss_gan",
        "loss_disc", "disc_fake", "disc_real"
    ]
    image_count = 6

    # Create the iterators once; calling iter() on every step would start
    # a new pass over the data each time.
    anime_iterator = iter(AnimeCleanData)
    celeba_iterator = iter(CelebaData)

    counter = 0
    while True:
        counter += 1
        (AnimeBatchImage, BigAnimeBatchImage), anime_iterator = next_batch(
            AnimeCleanData, anime_iterator)
        CelebaBatchImage, celeba_iterator = next_batch(
            CelebaData, celeba_iterator)
        print(counter)

        result = trainstep(CelebaBatchImage, AnimeBatchImage,
                           BigAnimeBatchImage)

        # Log and checkpoint on steps 1, 6, 11, ...
        if (counter - 1) % 5:
            continue

        with file_writer.as_default():
            for tag, image in zip(print_string[:image_count],
                                  result[:image_count]):
                tf.summary.image(tag,
                                 process_data_for_display(image),
                                 step=counter)
            for tag, value in zip(print_string[image_count:],
                                  result[image_count:]):
                print(tag, value)
                tf.summary.scalar(tag, value, step=counter)

        ckpt_manager.save()
Beispiel #3
0
def run_tensorflow():
    """
    [summary] This is needed for tensorflow to free up my gpu ram...
    """

    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    AnimeCleanData = getAnimeCleanData(BATCH_SIZE=1)
    CelebAData = getCelebaData(BATCH_SIZE=1)

    AnimeCleanData_iter = iter(AnimeCleanData)
    CelebAData_iter = iter(CelebAData)

    logdir = "./logs/train_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")
    file_writer = tf.summary.create_file_writer(logdir)

    generator_to_anime_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    generator_to_human_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    generator_anime_upscale_optimizer = tf.keras.optimizers.Adam(2e-4,
                                                                 beta_1=0.5)
    discriminator_human_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_anime_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_anime_upscale_optimizer = tf.keras.optimizers.Adam(
        2e-4, beta_1=0.5)

    generator_to_anime = GeneratorV2()
    generator_to_human = GeneratorV2()

    generator_anime_upscale = UpsampleGenerator()

    # input: Batch, 256,256,3
    discriminator_human = W_Discriminator()
    discriminator_anime = W_Discriminator()

    discriminator_anime_upscale = W_Discriminator()

    checkpoint_path = "./checkpoints/train"

    ckpt = tf.train.Checkpoint(
        generator_to_anime=generator_to_anime,
        generator_to_human=generator_to_human,
        generator_anime_upscale=generator_anime_upscale,  # *
        discriminator_human=discriminator_human,
        discriminator_anime=discriminator_anime,
        discriminator_anime_upscale=discriminator_anime_upscale,  # *
        generator_to_anime_optimizer=generator_to_anime_optimizer,
        generator_to_human_optimizer=generator_to_human_optimizer,
        generator_anime_upscale_optimizer=generator_anime_upscale_optimizer,  # *
        discriminator_human_optimizer=discriminator_human_optimizer,
        discriminator_anime_optimizer=discriminator_anime_optimizer,
        discriminator_anime_upscale_optimizer=
        discriminator_anime_upscale_optimizer,  # *
    )

    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_path,
                                              max_to_keep=5)

    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print("Latest checkpoint restored!!")

    @tf.function
    def init_pool(AnimeBatchImage, BigAnimeBatchImage, CelebaBatchImage):
        """Run the generators once to create the initial pool tensors.

        The outputs are built the same way ``trainstep_G`` builds them:
        ``fake`` translates into the other domain, ``cycled`` translates
        there and back, ``same`` feeds a generator an image of its own
        target domain.

        Args:
            AnimeBatchImage: batch of anime images.
            BigAnimeBatchImage: unused; kept so the call signature stays
                the same.
            CelebaBatchImage: batch of human images.

        Returns:
            A list of nine tensors: fake/cycled/same anime, fake/cycled/same
            human and fake/cycled/same upscaled anime, in that order.
        """
        fake_anime_pool = generator_to_anime(CelebaBatchImage)
        fake_human_pool = generator_to_human(AnimeBatchImage)

        # anime -> human -> anime and human -> anime -> human
        cycled_anime_pool = generator_to_anime(fake_human_pool)
        cycled_human_pool = generator_to_human(fake_anime_pool)

        # identity mappings: each generator sees its own target domain
        same_anime_pool = generator_to_anime(AnimeBatchImage)
        same_human_pool = generator_to_human(CelebaBatchImage)

        fake_anime_upscale_pool = generator_anime_upscale(fake_anime_pool)
        cycled_anime_upscale_pool = generator_anime_upscale(cycled_anime_pool)
        same_anime_upscale_pool = generator_anime_upscale(same_anime_pool)

        data_pools = [
            fake_anime_pool,
            cycled_anime_pool,
            same_anime_pool,
            fake_human_pool,
            cycled_human_pool,
            same_human_pool,
            fake_anime_upscale_pool,
            cycled_anime_upscale_pool,
            same_anime_upscale_pool,
        ]
        return data_pools

    AnimeBatchImage, BigAnimeBatchImage = next(AnimeCleanData_iter)
    CelebaBatchImage = next(CelebAData_iter)
    data_pools = init_pool(AnimeBatchImage, BigAnimeBatchImage,
                           CelebaBatchImage)

    def add_data_to_pool(pool, new_data, pool_size=50):
        """Return the pool contents after adding ``new_data``.

        The pool is currently replaced wholesale by ``new_data``.  Assigning
        to the ``pool`` parameter inside this function cannot change the
        caller's variable, so the result is returned and the caller has to
        store it, e.g. ``pool = add_data_to_pool(pool, batch)``.

        Args:
            pool: current pool contents (not used yet).
            new_data: data that becomes the new pool contents.
            pool_size: intended maximum pool size (not used yet).

        Returns:
            ``new_data``.
        """
        # TODO: keep a shuffled history of at most ``pool_size`` entries
        # (shuffle, truncate, concat) instead of replacing the pool.
        return new_data

    def get_data_from_pool(pool, batch_size=8):
        """Return at most the first ``batch_size`` entries of a 4-axis pool.

        Args:
            pool: object indexable along four axes.
            batch_size: maximum number of entries taken from the first axis.

        Returns:
            ``pool`` sliced to ``batch_size`` along its first axis; the
            remaining three axes are kept in full.
        """
        leading = slice(None, batch_size)
        whole_axis = slice(None)
        return pool[leading, whole_axis, whole_axis, whole_axis]

    # out: Batch, 16, 16, 1
    # x is human, y is anime
    @tf.function
    def trainstep_G(real_human, real_anime, big_anime):
        """Run one optimisation step for the three generators.

        Only the two translation generators and the upscale generator are
        updated here; the discriminators are only used to score generated
        images.

        Args:
            real_human: batch of human images.
            real_anime: batch of anime images.
            big_anime: batch of anime images used as the target for the
                upscaled ``real_anime``.

        Returns:
            A list of 16 tensors: the two input batches, nine generated
            image batches and five loss values (see the return statement
            for the exact order).
        """
        with tf.GradientTape(persistent=True) as tape:
            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            fake_anime_upscale = generator_anime_upscale(fake_anime,
                                                         training=True)
            real_anime_upscale = generator_anime_upscale(real_anime,
                                                         training=True)
            # Not part of any loss below; only returned to the caller.
            cycled_anime_upscale = generator_anime_upscale(cycled_anime,
                                                           training=True)
            same_anime_upscale = generator_anime_upscale(same_anime,
                                                         training=True)

            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)

            # calculate the loss
            gen_anime_loss = w_g_loss(disc_fake_anime)
            gen_human_loss = w_g_loss(disc_fake_human)

            total_cycle_loss = (cycle_loss(real_human, cycled_human) +
                                cycle_loss(real_anime, cycled_anime))

            anime_identity_loss = identity_loss(real_anime, same_anime)
            human_identity_loss = identity_loss(real_human, same_human)

            # Total generator loss = adversarial + cycle + identity loss
            total_gen_anime_loss = (gen_anime_loss * 1e3 + total_cycle_loss +
                                    anime_identity_loss)
            total_gen_human_loss = (gen_human_loss * 1e3 + total_cycle_loss +
                                    human_identity_loss)

            upscale_adversarial_loss = w_g_loss(disc_fake_upscale)
            upscale_identity_loss = identity_loss(big_anime,
                                                  real_anime_upscale)
            gen_upscale_loss = (upscale_adversarial_loss * 1e3 +
                                upscale_identity_loss * 1e-6)

        # Debug output of the individual loss terms.
        tf.print("gen_anime_loss*1e3", gen_anime_loss * 1e3)
        tf.print("total_cycle_loss", total_cycle_loss)
        tf.print("identity_loss", anime_identity_loss)
        tf.print("--------------------------")
        tf.print("w_g_loss(disc_fake_upscale)", upscale_adversarial_loss)
        # Report the term that is actually part of gen_upscale_loss: the
        # upscaled image, not a discriminator output, compared to big_anime.
        tf.print("identity_loss(big_anime, real_anime_upscale)",
                 upscale_identity_loss)

        # Compute every gradient before any variable is updated.
        generator_to_anime_gradients = tape.gradient(
            total_gen_anime_loss, generator_to_anime.trainable_variables)
        generator_to_human_gradients = tape.gradient(
            total_gen_human_loss, generator_to_human.trainable_variables)
        generator_upscale_gradients = tape.gradient(
            gen_upscale_loss, generator_anime_upscale.trainable_variables)
        # A persistent tape holds its resources until it is dropped.
        del tape

        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))
        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))
        generator_anime_upscale_optimizer.apply_gradients(
            zip(generator_upscale_gradients,
                generator_anime_upscale.trainable_variables))

        return [
            real_human,
            real_anime,
            fake_anime,
            cycled_anime,
            same_anime,
            fake_human,
            cycled_human,
            same_human,
            fake_anime_upscale,
            cycled_anime_upscale,
            same_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
        ]

    @tf.function
    def trainstep_D(
        real_human,
        real_anime,
        big_anime,
        fake_anime,
        cycled_anime,
        same_anime,
        fake_human,
        cycled_human,
        same_human,
        fake_anime_upscale,
        cycled_anime_upscale,
        same_anime_upscale,
    ):
        """Run one optimizer step for the human, anime and upscale discriminators.

        Each discriminator's loss is ``w_d_loss(real_scores, fake_scores)`` plus
        a ``gradient_penalty`` term evaluated with that same discriminator on
        the same real/fake pair.

        Args:
            real_human: Real batch scored by ``discriminator_human``.
            real_anime: Real batch scored by ``discriminator_anime``.
            big_anime: Real batch scored by ``discriminator_anime_upscale``.
            fake_anime: Generated batch scored by ``discriminator_anime``.
            fake_human: Generated batch scored by ``discriminator_human``.
            fake_anime_upscale: Generated batch scored by
                ``discriminator_anime_upscale``.
            cycled_anime: Unused; kept so the caller can splat every pool.
            same_anime: Unused; kept so the caller can splat every pool.
            cycled_human: Unused; kept so the caller can splat every pool.
            same_human: Unused; kept so the caller can splat every pool.
            cycled_anime_upscale: Unused; kept so the caller can splat every
                pool.
            same_anime_upscale: Unused; kept so the caller can splat every
                pool.

        Returns:
            None. The three discriminators are updated in place and a few
            diagnostics for the upscale discriminator are emitted via
            ``tf.print``.
        """
        # The tape is persistent because three separate gradient() calls are
        # made against it below.
        with tf.GradientTape(persistent=True) as tape:
            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)
            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)

            discriminator_human_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_human, training=True),
                real_human,
                fake_human,
            )
            discriminator_anime_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_anime, training=True),
                real_anime,
                fake_anime,
            )
            # The penalty must be evaluated with the upscale discriminator
            # itself. Using ``discriminator_human`` here (as the code used to)
            # yields a term with no gradient w.r.t. the upscale
            # discriminator's variables, so it was never regularised.
            discriminator_upscale_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_anime_upscale, training=True),
                big_anime,
                fake_anime_upscale,
            )

            disc_human_loss = (w_d_loss(disc_real_human, disc_fake_human) +
                               discriminator_human_gradient_penalty)
            disc_anime_loss = (w_d_loss(disc_real_anime, disc_fake_anime) +
                               discriminator_anime_gradient_penalty)
            disc_upscale_loss = (w_d_loss(disc_real_big, disc_fake_upscale) +
                                 discriminator_upscale_gradient_penalty)
            tf.print("disc_real_big", disc_real_big)
            tf.print("disc_fake_upscale", disc_fake_upscale)
            tf.print("disc_upscale_loss", disc_upscale_loss)

        discriminator_human_gradients = tape.gradient(
            disc_human_loss, discriminator_human.trainable_variables)
        discriminator_anime_gradients = tape.gradient(
            disc_anime_loss, discriminator_anime.trainable_variables)
        discriminator_upscale_gradients = tape.gradient(
            disc_upscale_loss, discriminator_anime_upscale.trainable_variables)

        discriminator_human_optimizer.apply_gradients(
            zip(discriminator_human_gradients,
                discriminator_human.trainable_variables))
        discriminator_anime_optimizer.apply_gradients(
            zip(discriminator_anime_gradients,
                discriminator_anime.trainable_variables))
        discriminator_anime_upscale_optimizer.apply_gradients(
            zip(
                discriminator_upscale_gradients,
                discriminator_anime_upscale.trainable_variables,
            ))

    def process_data_for_display(input_image):
        """Return ``input_image`` scaled by 0.5 and then shifted up by 0.5.

        Applied to every image batch before it is written as an image summary;
        a value of -1 maps to 0 and a value of 1 maps to 1.
        """
        halved = input_image * 0.5
        return halved + 0.5

    # ``i`` is the zero-based step (summary step and logging cadence);
    # ``counter`` is the one-based step and is only printed.
    i = -1
    counter = 0

    # Summary tags, index-aligned with the list returned by trainstep_G:
    # entries 0-10 are image batches, entries 11-15 are scalar losses.
    print_string = [
        "real_human",
        "real_anime",
        "fake_anime",
        "cycled_anime",
        "same_anime",
        "fake_human",
        "cycled_human",
        "same_human",
        "fake_anime_upscale",
        "cycled_anime_upscale",
        "same_anime_upscale",
        "gen_anime_loss",
        "gen_human_loss",
        "total_gen_anime_loss",
        "total_gen_human_loss",
        "gen_upscale_loss",
    ]

    # Runs until the process is interrupted or a data iterator is exhausted.
    while True:
        i += 1
        counter += 1
        AnimeBatchImage, BigAnimeBatchImage = next(AnimeCleanData_iter)
        CelebaBatchImage = next(CelebAData_iter)
        print(counter)

        # Generator step first; its outputs feed the discriminator step.
        result = trainstep_G(CelebaBatchImage, AnimeBatchImage,
                             BigAnimeBatchImage)

        # result[2] .. result[10] are the nine generated image batches; each
        # one is pushed into the pool with the matching index.
        for pool_index in range(9):
            add_data_to_pool(data_pools[pool_index], result[pool_index + 2])

        # The discriminators are trained on whatever the pools hand back, not
        # directly on this step's generator outputs.
        pooled_samples = [get_data_from_pool(pool) for pool in data_pools]
        trainstep_D(CelebaBatchImage, AnimeBatchImage, BigAnimeBatchImage,
                    *pooled_samples)

        # Write summaries and a checkpoint on every fifth step only.
        if i % 5 == 0:
            with file_writer.as_default():
                for tag_index, tag in enumerate(print_string):
                    if tag_index < 11:
                        tf.summary.image(
                            tag,
                            process_data_for_display(result[tag_index]),
                            step=i)
                    else:
                        tf.summary.scalar(tag, result[tag_index], step=i)
            ckpt_manager.save()
Beispiel #4
0
def run_tensorflow():
    """Train two generators and two discriminators and log to TensorBoard.

    Builds the datasets, models and Adam optimizers, then loops forever
    running ``trainstep`` on one human batch and one anime batch per
    iteration. Every fifth iteration the generated images and the losses are
    written as summaries under ``../logs/train_data/<timestamp>``.

    The original author's note says the work is wrapped in a function because
    that is needed for TensorFlow to free up GPU RAM.

    NOTE: mixed precision and checkpointing are not enabled in this variant.

    Returns:
        Never returns normally; the loop ends only on an exception (for
        example ``StopIteration`` if a dataset yields no batches at all) or
        when the process is interrupted.
    """

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    AnimeCleanData = getAnimeCleanData(BATCH_SIZE=32)
    CelebaData = getCelebaData()

    logdir = "../logs/train_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")
    file_writer = tf.summary.create_file_writer(logdir)

    generator_to_anime_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    generator_to_human_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    # input: Batch, 256,256,3
    discriminator_x = Discriminator()
    discriminator_y = Discriminator()
    # out: Batch, 16, 16, 1

    generator_to_anime = Generator()
    generator_to_human = Generator()

    # x is human, y is anime
    @tf.function
    def trainstep(real_human, real_anime):
        """Update both generators and both discriminators on one batch pair.

        Returns a 12-tuple: the six generated image batches (fake, cycled and
        identity outputs) followed by six scalar losses, in the order listed
        in the ``return`` statement.
        """
        # Persistent because four gradient() calls share this one tape.
        with tf.GradientTape(persistent=True) as tape:

            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_x(real_human, training=True)
            disc_real_anime = discriminator_y(real_anime, training=True)

            disc_fake_human = discriminator_x(fake_human, training=True)
            disc_fake_anime = discriminator_y(fake_anime, training=True)

            # calculate the loss
            gen_anime_loss = generator_loss(disc_fake_anime)
            gen_human_loss = generator_loss(disc_fake_human)

            total_cycle_loss = (cycle_loss(real_human, cycled_human) +
                                cycle_loss(real_anime, cycled_anime))

            # Total generator loss = adversarial loss + cycle loss + identity
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime))
            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human))

            disc_x_loss = discriminator_loss(disc_real_human, disc_fake_human)
            disc_y_loss = discriminator_loss(disc_real_anime, disc_fake_anime)

        # Calculate the gradients for generator and discriminator
        generator_to_anime_gradients = tape.gradient(
            total_gen_anime_loss, generator_to_anime.trainable_variables)
        generator_to_human_gradients = tape.gradient(
            total_gen_human_loss, generator_to_human.trainable_variables)

        discriminator_x_gradients = tape.gradient(
            disc_x_loss, discriminator_x.trainable_variables)
        discriminator_y_gradients = tape.gradient(
            disc_y_loss, discriminator_y.trainable_variables)

        # Apply the gradients to the optimizer
        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))

        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))

        discriminator_x_optimizer.apply_gradients(
            zip(discriminator_x_gradients,
                discriminator_x.trainable_variables))

        discriminator_y_optimizer.apply_gradients(
            zip(discriminator_y_gradients,
                discriminator_y.trainable_variables))

        return (fake_anime, cycled_human, fake_human, cycled_anime,
                same_human, same_anime, gen_anime_loss, gen_human_loss,
                disc_x_loss, disc_y_loss, total_gen_anime_loss,
                total_gen_human_loss)

    def endless_batches(dataset):
        """Yield batches from ``dataset`` forever, restarting at each epoch end."""
        while True:
            produced = False
            for batch in dataset:
                produced = True
                yield batch
            if not produced:
                # An empty dataset would spin forever; ending the generator
                # makes next() raise StopIteration instead.
                return

    # Build each iterator once. Calling next(iter(dataset)) on every step
    # re-creates the dataset iterator each time, which is slow and, for an
    # unshuffled dataset, returns the same first batch on every step.
    anime_batches = endless_batches(AnimeCleanData)
    celeba_batches = endless_batches(CelebaData)

    # Summary tags, in the same order as the values returned by trainstep.
    image_tags = (
        "fake_anime",
        "cycled_human",
        "fake_human",
        "cycled_anime",
        "same_human",
        "same_anime",
    )
    scalar_tags = (
        "gen_anime_loss",
        "gen_human_loss",
        "disc_x_loss",
        "disc_y_loss",
        "total_gen_anime_loss",
        "total_gen_human_loss",
    )

    counter = 0
    i = -1
    while True:
        i += 1
        counter += 1
        AnimeBatchImage = next(anime_batches)
        CelebaBatchImage = next(celeba_batches)

        outputs = trainstep(CelebaBatchImage, AnimeBatchImage)

        # Only every fifth step is logged; the training step itself runs on
        # every iteration.
        if i % 5:
            continue

        images = outputs[:len(image_tags)]
        scalars = outputs[len(image_tags):]
        with file_writer.as_default():
            for tag, image_batch in zip(image_tags, images):
                tf.summary.image(tag, image_batch, step=counter)
            for tag, value in zip(scalar_tags, scalars):
                tf.summary.scalar(tag, value, step=counter)
Beispiel #5
0
def run_tensorflow():
    """Train the single-generator model plus an upsampler, with checkpoints.

    One ``GeneratorV2`` handles both directions; the target domain is
    signalled by concatenating a block of +1 (anime) or -1 (human) to the
    input along axis 3. A separate ``UpsampleGenerator`` / ``Discriminator``
    pair (``up_G`` / ``up_D``) is trained in the same step. Mixed precision
    (``mixed_float16``) is enabled and every optimizer is wrapped in a
    dynamic ``LossScaleOptimizer``.

    Every fifth iteration the images and losses are written as summaries
    under ``./logs/Startrain_data/<timestamp>`` and a checkpoint is saved to
    ``./checkpoints/StarTrain``.

    The original author's note says the work is wrapped in a function because
    that is needed for TensorFlow to free up GPU RAM.

    Returns:
        Never returns normally; the loop ends only on an exception (for
        example ``StopIteration`` if a dataset yields no batches at all) or
        when the process is interrupted.
    """

    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    mixed_precision = tf.keras.mixed_precision.experimental

    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_policy(policy)

    AnimeCleanData = getAnimeCleanData(BATCH_SIZE=batch_size)
    CelebaData = getCelebaData(BATCH_SIZE=batch_size)

    logdir = "./logs/Startrain_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")
    file_writer = tf.summary.create_file_writer(logdir)

    generator_optimizer = mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.Adam(2e-4, beta_1=0.5), loss_scale="dynamic"
    )

    discriminator_optimizer = mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.Adam(2e-4, beta_1=0.5), loss_scale="dynamic"
    )

    up_G_optim = mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.Adam(2e-4, beta_1=0.5), loss_scale="dynamic"
    )
    up_D_optim = mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.Adam(2e-4, beta_1=0.5), loss_scale="dynamic"
    )
    up_G = UpsampleGenerator()
    up_D = Discriminator()

    generator = GeneratorV2()
    # input: Batch, 256,256,3
    discriminator = StarDiscriminator()

    checkpoint_path = "./checkpoints/StarTrain"

    # The upsampler pair and its optimizers are tracked too; without them the
    # upsampler would restart from scratch after every resume. Restoring an
    # older checkpoint that lacks these entries still works because the
    # restore status is not asserted.
    ckpt = tf.train.Checkpoint(
        generator=generator,
        discriminator=discriminator,
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        up_G=up_G,
        up_D=up_D,
        up_G_optim=up_G_optim,
        up_D_optim=up_D_optim,
    )

    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print("Latest checkpoint restored!!")

    # out: Batch, 16, 16, 1
    # x is human, y is anime
    @tf.function
    def trainstep(real_human, real_anime, big_anime):
        """Run one update of the generator, discriminator and upsampler pair.

        Returns a 16-tuple: ten image batches followed by six scalar losses,
        in the order listed in the ``return`` statement.
        """
        # Persistent because the generator and the discriminator each take
        # their own gradient() call from this tape.
        with tf.GradientTape(persistent=True) as tape:
            ones = tf.ones_like(real_human)
            neg_ones = tf.ones_like(real_human) * -1

            def get_domain_anime(img):
                # Tag the input with +1 along axis 3 to request the anime domain.
                return tf.concat([img, ones], 3)

            def get_domain_human(img):
                # Tag the input with -1 along axis 3 to request the human domain.
                return tf.concat([img, neg_ones], 3)

            fake_anime = generator(get_domain_anime(real_human), training=True)
            cycled_human = generator(get_domain_human(fake_anime), training=True)

            fake_human = generator(get_domain_human(real_anime), training=True)
            cycled_anime = generator(get_domain_anime(fake_human), training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator(get_domain_human(real_human), training=True)
            same_anime = generator(get_domain_anime(real_anime), training=True)

            # The discriminator returns a (score, domain label) pair.
            disc_real_human, label_real_human = discriminator(real_human, training=True)
            disc_real_anime, label_real_anime = discriminator(real_anime, training=True)

            disc_fake_human, label_fake_human = discriminator(fake_human, training=True)
            disc_fake_anime, label_fake_anime = discriminator(fake_anime, training=True)

            _, label_cycled_human = discriminator(cycled_human, training=True)
            _, label_cycled_anime = discriminator(cycled_anime, training=True)

            _, label_same_human = discriminator(same_human, training=True)
            _, label_same_anime = discriminator(same_anime, training=True)

            # calculate the loss
            gen_anime_loss = generator_loss(disc_fake_anime)
            gen_human_loss = generator_loss(disc_fake_human)

            total_cycle_loss = cycle_loss(real_human, cycled_human) + cycle_loss(
                real_anime, cycled_anime
            )

            gen_class_loss = (
                discriminator_loss(label_fake_human, label_fake_anime)
                + discriminator_loss(label_cycled_human, label_cycled_anime)
                + discriminator_loss(label_same_human, label_same_anime)
            )

            # Computed once and reused for both the total loss and the prints.
            identity_anime_loss = identity_loss(real_anime, same_anime)
            identity_human_loss = identity_loss(real_human, same_human)

            # Total generator loss = adversarial + class + cycle + identity
            total_gen_loss = (
                gen_anime_loss
                + gen_human_loss
                + gen_class_loss
                + total_cycle_loss * 0.1
                + identity_anime_loss
                + identity_human_loss
            )

            tf.print("gen_anime_loss", gen_anime_loss)
            tf.print("gen_human_loss", gen_human_loss)
            tf.print("gen_class_loss", gen_class_loss)
            tf.print("total_cycle_loss", total_cycle_loss)
            tf.print("identity_loss(real_anime, same_anime)", identity_anime_loss)
            tf.print("identity_loss(real_human, same_human)", identity_human_loss)

            scaled_total_gen_loss = generator_optimizer.get_scaled_loss(
                total_gen_loss
            )

            disc_human_loss = discriminator_loss(disc_real_human, disc_fake_human)
            disc_anime_loss = discriminator_loss(disc_real_anime, disc_fake_anime)

            disc_loss = (
                disc_human_loss
                + disc_anime_loss
                + discriminator_loss(label_real_human, label_real_anime)
            )

            scaled_disc_loss = discriminator_optimizer.get_scaled_loss(disc_loss)

        # Calculate the gradients for generator and discriminator, undoing the
        # loss scaling before they are applied.
        generator_gradients = generator_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_total_gen_loss, generator.trainable_variables)
        )
        discriminator_gradients = discriminator_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_loss, discriminator.trainable_variables)
        )

        generator_optimizer.apply_gradients(
            zip(generator_gradients, generator.trainable_variables)
        )

        discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, discriminator.trainable_variables)
        )

        # Second stage: the upsampler pair, on its own tape.
        # NOTE(review): up_G / up_D are called without training=True, unlike
        # every other model call in this step -- confirm that is intended.
        with tf.GradientTape(persistent=True) as up_tape:
            real_anime_up = up_G(real_anime)
            fake_anime_up = up_G(fake_anime)

            dis_fake_anime_up = up_D(fake_anime_up)
            dis_real_anime_up = up_D(real_anime_up)
            dis_ori_anime = up_D(big_anime)

            # The adversarial term must be fed the discriminator's score of
            # the upsampled fake, not the upsampled image itself.
            gen_up_loss = (
                generator_loss(dis_fake_anime_up)
                + generator_loss(dis_real_anime_up) * 0.1
            )
            dis_up_loss = (
                discriminator_loss(dis_ori_anime, dis_fake_anime_up)
                + discriminator_loss(dis_ori_anime, dis_real_anime_up) * 0.1
            )
            scaled_gen_up_loss = up_G_optim.get_scaled_loss(gen_up_loss)
            scaled_dis_up_loss = up_D_optim.get_scaled_loss(dis_up_loss)

        up_G_gradients = up_G_optim.get_unscaled_gradients(
            up_tape.gradient(scaled_gen_up_loss, up_G.trainable_variables)
        )
        up_D_gradients = up_D_optim.get_unscaled_gradients(
            up_tape.gradient(scaled_dis_up_loss, up_D.trainable_variables)
        )

        up_G_optim.apply_gradients(zip(up_G_gradients, up_G.trainable_variables))

        up_D_optim.apply_gradients(zip(up_D_gradients, up_D.trainable_variables))

        return (
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_up,
            real_anime_up,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            gen_up_loss,
            dis_up_loss,
        )

    def process_data_for_display(input_image):
        """Scale by 0.5 and shift by 0.5 (maps -1 to 0 and 1 to 1)."""
        return input_image * 0.5 + 0.5

    def endless_batches(dataset):
        """Yield batches from ``dataset`` forever, restarting at each epoch end."""
        while True:
            produced = False
            for batch in dataset:
                produced = True
                yield batch
            if not produced:
                # An empty dataset would spin forever; ending the generator
                # makes next() raise StopIteration instead.
                return

    # Summary tags, index-aligned with the tuple returned by trainstep:
    # entries 0-9 are image batches, entries 10-15 are scalar losses.
    print_string = [
        "real_human",
        "real_anime",
        "fake_anime",
        "cycled_human",
        "fake_human",
        "cycled_anime",
        "same_human",
        "same_anime",
        "fake_anime_up",
        "real_anime_up",
        "gen_anime_loss",
        "gen_human_loss",
        "disc_human_loss",
        "disc_anime_loss",
        "gen_up_loss",
        "dis_up_loss",
    ]

    # Build each iterator once. Calling next(iter(dataset)) on every step
    # re-creates the dataset iterator each time, which is slow and, for an
    # unshuffled dataset, returns the same first batch on every step.
    anime_batches = endless_batches(AnimeCleanData)
    celeba_batches = endless_batches(CelebaData)

    counter = 0
    i = -1
    while True:
        i += 1
        counter += 1
        AnimeBatchImage, BigAnimeBatchImage = next(anime_batches)
        CelebaBatchImage = next(celeba_batches)
        print(counter)

        result = trainstep(CelebaBatchImage, AnimeBatchImage, BigAnimeBatchImage)

        # Only every fifth step is logged and checkpointed; the training step
        # itself runs on every iteration.
        if i % 5:
            continue

        with file_writer.as_default():
            for j, (tag, value) in enumerate(zip(print_string, result)):
                if j < 10:
                    tf.summary.image(
                        tag,
                        process_data_for_display(value),
                        step=counter,
                    )
                else:
                    tf.summary.scalar(tag, value, step=counter)

        ckpt_manager.save()
Beispiel #6
0
def run_tensorflow():
    """Train the human <-> anime translation GAN and its anime upscaler.

    Steps performed:

    * enable memory growth on every visible GPU and switch Keras to the
      ``mixed_float16`` policy;
    * build two translation generators, one upscale generator and three
      discriminators, each with its own loss-scaled Adam optimizer;
    * restore the latest checkpoint from ``./checkpoints/LSgan`` if present;
    * train forever, writing images and losses to TensorBoard under
      ``./logs/LSgan/<timestamp>`` and saving a checkpoint every fifth step.

    The function never returns normally; interrupt the process to stop it.
    """

    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    mixed_precision = tf.keras.mixed_precision.experimental

    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_policy(policy)

    AnimeCleanData = getAnimeCleanData(BATCH_SIZE=batch_size)
    CelebaData = getCelebaData(BATCH_SIZE=batch_size)

    def make_optimizer(learning_rate):
        """Return an Adam optimizer wrapped with dynamic loss scaling."""
        return mixed_precision.LossScaleOptimizer(
            tf.keras.optimizers.Adam(learning_rate, beta_1=0.5),
            loss_scale="dynamic")

    # The human->anime generator uses 1e-5; every other network uses 1e-4.
    generator_to_anime_optimizer = make_optimizer(1e-5)
    generator_to_human_optimizer = make_optimizer(1e-4)

    generator_anime_upscale_optimizer = make_optimizer(1e-4)

    discriminator_human_optimizer = make_optimizer(1e-4)
    discriminator_anime_optimizer = make_optimizer(1e-4)

    discriminator_anime_upscale_optimizer = make_optimizer(1e-4)

    generator_to_anime = GeneratorV2()
    generator_to_human = GeneratorV2()

    generator_anime_upscale = UpsampleGenerator()

    # input: Batch, 256,256,3
    discriminator_human = LS_Discriminator()
    discriminator_anime = LS_Discriminator()

    discriminator_anime_upscale = LS_Discriminator()

    checkpoint_path = "./checkpoints/LSgan"

    # The keyword names below are part of the checkpoint format; keep them
    # stable so existing checkpoints remain restorable.
    ckpt = tf.train.Checkpoint(
        generator_to_anime=generator_to_anime,
        generator_to_human=generator_to_human,
        generator_anime_upscale=generator_anime_upscale,
        discriminator_human=discriminator_human,
        discriminator_anime=discriminator_anime,
        discriminator_anime_upscale=discriminator_anime_upscale,
        generator_to_anime_optimizer=generator_to_anime_optimizer,
        generator_to_human_optimizer=generator_to_human_optimizer,
        generator_anime_upscale_optimizer=generator_anime_upscale_optimizer,
        discriminator_human_optimizer=discriminator_human_optimizer,
        discriminator_anime_optimizer=discriminator_anime_optimizer,
        discriminator_anime_upscale_optimizer=(
            discriminator_anime_upscale_optimizer),
    )

    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              checkpoint_path,
                                              max_to_keep=5)
    logdir = "./logs/LSgan/" + datetime.now().strftime("%Y%m%d-%H%M%S")
    file_writer = tf.summary.create_file_writer(logdir)
    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print("Latest checkpoint restored!!")

    # out: Batch, 16, 16, 1
    # x is human, y is anime
    @tf.function
    def trainstep(real_human, real_anime, big_anime):
        """Run one optimization step for all six networks.

        Args:
            real_human: batch of human images.
            real_anime: batch of anime images.
            big_anime: batch of anime images used as the "real" samples for
                the upscale discriminator.

        Returns:
            A list of 18 tensors: ten image batches followed by eight loss
            values, in the same order as ``print_string`` below.
        """

        with tf.GradientTape(persistent=True) as tape:

            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            # Generators push the discriminator output on fakes toward 0,
            # while the discriminators are trained toward +1 on real samples
            # and -1 on fakes.
            gen_anime_loss = mse_loss(disc_fake_anime,
                                      tf.zeros_like(disc_fake_anime))
            gen_human_loss = mse_loss(disc_fake_human,
                                      tf.zeros_like(disc_fake_human))

            total_cycle_loss = (cycle_loss(real_human, cycled_human) +
                                cycle_loss(real_anime, cycled_anime))
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime) +
                                    mse_loss(real_anime, fake_anime) * 0.1)

            # NOTE(review): the last term compares real_anime with fake_anime,
            # neither of which is produced by generator_to_human. It looks
            # copied from the anime loss above -- confirm whether
            # mse_loss(real_human, fake_human) was intended.
            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human) +
                                    mse_loss(real_anime, fake_anime))

            disc_human_loss = (
                mse_loss(disc_real_human, tf.ones_like(disc_real_human)) +
                mse_loss(disc_fake_human,
                         -1 * tf.ones_like(disc_fake_human)))
            # The "real" target is built from disc_real_anime (the original
            # used disc_real_human here, a copy-paste slip).
            disc_anime_loss = (
                mse_loss(disc_real_anime, tf.ones_like(disc_real_anime)) +
                mse_loss(disc_fake_anime,
                         -1 * tf.ones_like(disc_fake_anime)))

            fake_anime_upscale = generator_anime_upscale(fake_anime,
                                                         training=True)
            same_anime_upscale = generator_anime_upscale(same_anime,
                                                         training=True)
            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)
            disc_same_upscale = discriminator_anime_upscale(same_anime_upscale,
                                                            training=True)
            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)

            gen_upscale_loss = (
                mse_loss(disc_fake_upscale, tf.zeros_like(disc_fake_upscale)) +
                mse_loss(disc_same_upscale, tf.zeros_like(disc_same_upscale)) *
                0.1)

            # Plain Python prints inside a tf.function run only while the
            # function is being traced, not on every training step.
            print("generator_to_anime.count_params()",
                  generator_to_anime.count_params())
            print("discriminator_anime.count_params()",
                  discriminator_anime.count_params())
            print("generator_anime_upscale.count_params()",
                  generator_anime_upscale.count_params())
            print(
                "discriminator_anime_upscale.count_params()",
                discriminator_anime_upscale.count_params(),
            )

            disc_upscale_loss = (
                mse_loss(disc_fake_upscale,
                         -1 * tf.ones_like(disc_fake_upscale)) +
                mse_loss(disc_real_big, tf.ones_like(disc_real_big)))

            # Loss scaling is done inside the tape so the scaling is part of
            # what gets differentiated.
            updates = [
                (optimizer, model, optimizer.get_scaled_loss(loss))
                for optimizer, model, loss in (
                    (generator_to_anime_optimizer, generator_to_anime,
                     total_gen_anime_loss),
                    (generator_to_human_optimizer, generator_to_human,
                     total_gen_human_loss),
                    (generator_anime_upscale_optimizer,
                     generator_anime_upscale, gen_upscale_loss),
                    (discriminator_human_optimizer, discriminator_human,
                     disc_human_loss),
                    (discriminator_anime_optimizer, discriminator_anime,
                     disc_anime_loss),
                    (discriminator_anime_upscale_optimizer,
                     discriminator_anime_upscale, disc_upscale_loss),
                )
            ]

        # All gradients are computed before any optimizer step is applied.
        gradients = [
            optimizer.get_unscaled_gradients(
                tape.gradient(scaled_loss, model.trainable_variables))
            for optimizer, model, scaled_loss in updates
        ]
        # A persistent tape keeps its resources until it is dropped.
        del tape

        for (optimizer, model, _), grads in zip(updates, gradients):
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        return [
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_upscale,
            same_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
            disc_upscale_loss,
        ]

    def process_data_for_display(input_image):
        """Rescale an image for display as ``input_image * 0.5 + 0.5``."""
        return input_image * 0.5 + 0.5

    def cycle_batches(dataset):
        """Yield batches from ``dataset`` forever.

        The iterator is created once and only re-created when it runs out,
        instead of calling ``iter(dataset)`` on every step (which restarts a
        re-iterable dataset each time).  If a fresh iterator yields nothing,
        the generator stops, so ``next()`` raises ``StopIteration``.
        """
        while True:
            iterator = iter(dataset)
            try:
                first_batch = next(iterator)
            except StopIteration:
                return
            yield first_batch
            yield from iterator

    # TensorBoard tag for each entry returned by trainstep, in return order.
    print_string = [
        "real_human",
        "real_anime",
        "fake_anime",
        "cycled_human",
        "fake_human",
        "cycled_anime",
        "same_human",
        "same_anime",
        "fake_anime_upscale",
        "same_anime_upscale",
        "gen_anime_loss",
        "gen_human_loss",
        "disc_human_loss",
        "disc_anime_loss",
        "total_gen_anime_loss",
        "total_gen_human_loss",
        "gen_upscale_loss",
        "disc_upscale_loss",
    ]
    # The first ten entries are logged as images, the remaining ones as
    # scalars.
    num_image_outputs = 10

    anime_batches = cycle_batches(AnimeCleanData)
    celeba_batches = cycle_batches(CelebaData)

    counter = 0
    last_time = time.time()

    while True:
        counter = counter + 1
        AnimeBatchImage, BigAnimeBatchImage = next(anime_batches)
        CelebaBatchImage = next(celeba_batches)
        print(counter)
        print(time.time() - last_time)
        last_time = time.time()

        # Every step trains; only every fifth step (1, 6, 11, ...) also logs
        # the results and saves a checkpoint.
        result = trainstep(CelebaBatchImage, AnimeBatchImage,
                           BigAnimeBatchImage)

        if (counter - 1) % 5 == 0:
            with file_writer.as_default():
                for j, (name, value) in enumerate(zip(print_string, result)):
                    if j < num_image_outputs:
                        tf.summary.image(
                            name,
                            process_data_for_display(value),
                            step=counter,
                        )
                    else:
                        print(name, value)
                        tf.summary.scalar(
                            name,
                            value,
                            step=counter,
                        )

            ckpt_manager.save()