def forward(self, content, style, alpha=1.0):
        """Stylize ``content`` with ``style``; in training mode also return losses.

        Args:
            content: content input, encoded with ``self.encode_with_intermediate``.
            style: style input, encoded with ``self.encode_with_intermediate``.
            alpha: accepted for interface compatibility; not read by this body.

        Returns:
            When ``self.training`` is falsy: only the stylized output produced
            by ``self.pair_inference``.
            Otherwise a 6-tuple ``(cs, loss_c, loss_s, loss_i, loss_cx, loss_tv)``
            where ``loss_i`` / ``loss_cx`` are the literal ``0`` when
            ``self.use_iden`` / ``self.use_cx`` are disabled.
        """
        # Style is encoded before content.
        style_feats = self.encode_with_intermediate(style)
        content_feats = self.encode_with_intermediate(content)

        # Only the last three encoder levels are fed to the feature pyramid.
        content_hidden = self.feature_pyramid(content_feats[-3:])
        style_hidden = self.feature_pyramid(style_feats[-3:])

        stylized, stylized_feats = self.pair_inference(
            content_feats, style_feats, content_hidden, style_hidden)
        if not self.training:
            return stylized

        outputs = [
            stylized,
            # perceptual (content) loss on the last three feature levels
            loss.perceptual_loss(stylized_feats[-3:], content_feats[-3:]),
            # AdaIN style loss over every feature level
            loss.adain_style_loss(stylized_feats, style_feats),
        ]

        identity_term = 0
        if self.use_iden:
            # Pair each input with itself; the trailing True is forwarded to
            # pair_inference (its meaning is defined there).
            cc, cc_feats = self.pair_inference(content_feats, content_feats,
                                               content_hidden, content_hidden,
                                               True)
            ss, ss_feats = self.pair_inference(style_feats, style_feats,
                                               style_hidden, style_hidden,
                                               True)
            identity_term = loss.identity_loss(cc, cc_feats, content,
                                               content_feats, ss, ss_feats,
                                               style, style_feats, 50)
        outputs.append(identity_term)

        outputs.append(loss.contextual_loss(stylized_feats, style_feats)
                       if self.use_cx else 0)
        outputs.append(loss.total_variation(stylized))
        return tuple(outputs)
Exemple #2
0
def main(args):
    """Train the two-stage LR-denoising GAN + super-resolution GAN pipeline.

    Stage 1 (LR GAN): ``G_1`` is trained against discriminator ``D_1`` so its
    output on ``image`` resembles ``label_lr``; ``G_2`` is its cycle partner.
    Stage 2 (SR GAN): ``SR`` is trained against discriminator ``D_2`` on the
    output of ``G_1`` (treated as a constant input) so its output resembles
    ``label_hr``; ``G_3`` is its cycle partner.

    Args:
        args: namespace read for ``log_dir``, ``lr``, ``data_path``,
            ``batch_size`` and ``epochs``. An optional ``sample_image`` path
            overrides the hard-coded image used for SR previews.

    Side effects:
        Writes TensorBoard logs, per-epoch and final ``.pkl`` state dicts and
        preview PNGs into ``args.log_dir``. Moves every model to CUDA.
    """
    os.makedirs(args.log_dir, exist_ok=True)
    # Image used for the per-epoch / final SR previews. The original hard-coded
    # path is kept as the default; ``args.sample_image`` overrides it if set.
    sample_path = getattr(args, 'sample_image',
                          '/data/data/DIV2K/unsupervised/lr/0001x4d.png')

    # create models
    G_1 = Generator_lr(in_channels=3)
    G_2 = Generator_lr(in_channels=3)
    D_1 = Discriminator_lr(in_channels=3, in_h=16, in_w=16)
    SR = EDSR(n_colors=3)
    G_3 = Generator_sr(in_channels=3)
    D_2 = Discriminator_sr(in_channels=3, in_h=64, in_w=64)
    # Insertion order also fixes the order of the checkpoint writes below.
    models = {
        'G_1': G_1,
        'G_2': G_2,
        'D_1': D_1,
        'SR': SR,
        'G_3': G_3,
        'D_2': D_2,
    }
    for model in models.values():
        model.cuda()
        model.train()

    # tensorboard
    writer = SummaryWriter(log_dir=args.log_dir)

    # create optimizers: one Adam per model, lr = args.lr * multiplier
    lr_scale = {'G_1': 5, 'G_2': 5, 'D_1': 1, 'SR': 5, 'G_3': 1, 'D_2': 1}
    optim = {
        name: torch.optim.Adam(
            params=[p for p in model.parameters() if p.requires_grad],
            lr=args.lr * lr_scale[name])
        for name, model in models.items()
    }

    def zero_all_grads():
        """Clear the accumulated gradients of every model."""
        for optimizer in optim.values():
            optimizer.zero_grad()

    zero_all_grads()

    # get dataloader
    train_dataset = DIV2KDataset(root=args.data_path)
    trainloader = DataLoader(train_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=3)

    print('-' * 20)
    print('Start training')
    print('-' * 20)
    iter_index = 0
    for epoch in range(args.epochs):
        # G_1 / SR are switched to eval() for the preview at the end of each
        # epoch, so put them back into training mode here.
        G_1.train()
        SR.train()
        start = timeit.default_timer()
        for image, label_hr, label_lr in trainloader:
            iter_index += 1
            image = image.cuda()
            label_hr = label_hr.cuda()
            label_lr = label_lr.cuda()
            log_now = iter_index % 100 == 0
            log_step = iter_index // 100

            # ---- LR GAN: update D_1, then G_1 and G_2 ----
            zero_all_grads()
            # D loss for D_1
            image_clean = G_1(image)
            loss_D1 = discriminator_loss(discriminator=D_1,
                                         fake=image_clean,
                                         real=label_lr)
            loss_D1.backward()
            optim['D_1'].step()
            # The D-loss backward may have reached G_1 through `fake` (unless
            # discriminator_loss detaches it); drop those gradients so G_1 is
            # stepped only with its generator losses below.
            optim['G_1'].zero_grad()

            # GD loss for G_1
            loss_G1 = generator_discriminator_loss(generator=G_1,
                                                   discriminator=D_1,
                                                   input=image)
            loss_G1.backward()

            # cycle loss for G_1 and G_2
            loss_cycle = 10 * cycle_loss(G_1, G_2, image)
            loss_cycle.backward()

            # idt loss for G_1
            loss_idt = 5 * identity_loss(clean_image=label_lr, generator=G_1)
            loss_idt.backward()

            # tvloss for G_1
            loss_tv = 0.5 * tvloss(input=image, generator=G_1)
            loss_tv.backward()

            # optimize G_1 and G_2
            optim['G_1'].step()
            optim['G_2'].step()

            if log_now:
                print(
                    'iter {}: LR: loss_D1={}, loss_GD={}, loss_cycle={}, loss_idt={}, loss_tv={}'
                    .format(iter_index, loss_D1.item(), loss_G1.item(),
                            loss_cycle.item(), loss_idt.item(),
                            loss_tv.item()))
                for tag, value in (('LR/loss_D1', loss_D1),
                                   ('LR/loss_GD', loss_G1),
                                   ('LR/loss_cycle', loss_cycle),
                                   ('LR/loss_idt', loss_idt),
                                   ('LR/loss_tv', loss_tv)):
                    writer.add_scalar(tag, value.item(), log_step)
                writer.add_image('LR/origin', image[0], log_step)
                # Preview only: no autograd graph is needed.
                with torch.no_grad():
                    writer.add_image('LR/denoise', G_1(image)[0], log_step)

            # ---- SR GAN: update D_2, then SR and G_3 ----
            zero_all_grads()
            # G_1 is not trained in this phase; its output is a constant input.
            with torch.no_grad():
                image_clean = G_1(image)
            # D loss for D_2
            image_sr = SR(image_clean)
            loss_D2 = discriminator_loss(discriminator=D_2,
                                         fake=image_sr,
                                         real=label_hr)
            loss_D2.backward()
            optim['D_2'].step()
            # Same as the LR phase: keep D-loss gradients out of SR's update.
            optim['SR'].zero_grad()

            # GD loss for SR
            loss_SR = generator_discriminator_loss(generator=SR,
                                                   discriminator=D_2,
                                                   input=image_clean)
            loss_SR.backward()

            # cycle loss for SR and G_3
            loss_cycle = 10 * cycle_loss(SR, G_3, image_clean)
            loss_cycle.backward()

            # idt loss for SR
            loss_idt = 5 * identity_loss_sr(
                clean_image_lr=label_lr, clean_image_hr=label_hr, generator=SR)
            loss_idt.backward()

            # tvloss for SR
            loss_tv = 0.5 * tvloss(input=image_clean, generator=SR)
            loss_tv.backward()

            # G_1 received no gradients in this phase, so only SR and G_3 are
            # stepped (stepping G_1's Adam on zeroed grads would still apply a
            # momentum-only update).
            optim['SR'].step()
            optim['G_3'].step()

            if log_now:
                print(
                    '         SR: loss_D2={}, loss_SR={}, loss_cycle={}, loss_idt={}, loss_tv={}'
                    .format(loss_D2.item(), loss_SR.item(), loss_cycle.item(),
                            loss_idt.item(), loss_tv.item()))
                for tag, value in (('SR/loss_D2', loss_D2),
                                   ('SR/loss_SR', loss_SR),
                                   ('SR/loss_cycle', loss_cycle),
                                   ('SR/loss_idt', loss_idt),
                                   ('SR/loss_tv', loss_tv)):
                    writer.add_scalar(tag, value.item(), log_step)
                writer.add_image('SR/origin', image[0], log_step)
                with torch.no_grad():
                    writer.add_image('SR/clean_image',
                                     G_1(image)[0], log_step)
                    writer.add_image('SR/SR', SR(G_1(image))[0], log_step)
                writer.flush()

        end = timeit.default_timer()
        print('epoch {}, using {} seconds'.format(epoch, end - start))

        # Preview on a fixed image, then checkpoint every model.
        G_1.eval()
        SR.eval()
        with Image.open(sample_path) as sample, torch.no_grad():
            sr_image = resolv_sr(G_1, SR, sample)
        sr_image.save(
            os.path.join(args.log_dir, '0001x4d_sr_{}.png'.format(str(epoch))))

        for name, model in models.items():
            torch.save(model.state_dict(),
                       os.path.join(args.log_dir,
                                    'ep-' + str(epoch) + '_' + name + '.pkl'))

    writer.close()
    print('Training done.')
    for name, model in models.items():
        torch.save(model.state_dict(),
                   os.path.join(args.log_dir, 'final_weights_' + name + '.pkl'))

    with Image.open(sample_path) as sample:
        sample.save(os.path.join(args.log_dir, '0001x4d.png'))
        with torch.no_grad():
            sr_image = resolv_sr(G_1, SR, sample)
    sr_image.save(os.path.join(args.log_dir, '0001x4d_sr.png'))
Exemple #3
0
    def trainstep(real_human, real_anime, big_anime):
        """Run one mixed-precision step for the CycleGAN and the anime upscaler.

        Every network is run with ``training=True``. Each loss is scaled with
        its own optimizer's ``get_scaled_loss`` inside the tape; the
        gradients are unscaled with ``get_unscaled_gradients`` before being
        applied.

        Args:
            real_human: batch from the human domain.
            real_anime: batch from the anime domain.
            big_anime: real target for the upscaling discriminator and the
                upscaler's identity terms (presumably the higher-resolution
                counterpart of ``real_anime`` — TODO confirm).

        Returns:
            Tuple of images followed by losses. The losses are the unscaled
            values; the loss-scaled copies are used only for gradients.
        """
        with tf.GradientTape(persistent=True) as tape:

            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            # calculate the loss
            gen_anime_loss = generator_loss(disc_fake_anime)
            gen_human_loss = generator_loss(disc_fake_human)

            total_cycle_loss = (cycle_loss(real_human, cycled_human) +
                                cycle_loss(real_anime, cycled_anime))

            # Total generator loss = adversarial + cycle + identity
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime))
            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human))

            disc_human_loss = discriminator_loss(disc_real_human,
                                                 disc_fake_human)
            disc_anime_loss = discriminator_loss(disc_real_anime,
                                                 disc_fake_anime)

            # Upscaler. Run in training mode like every other network in this
            # step so layers such as BatchNorm/Dropout (if any) use their
            # training behaviour.
            fake_anime_upscale = generator_anime_upscale(fake_anime,
                                                         training=True)
            # Despite its name this upscales real_anime; it is compared with
            # big_anime in the identity term below.
            cycle_anime_upscale = generator_anime_upscale(real_anime,
                                                          training=True)
            same_anime_upscale = generator_anime_upscale(same_anime,
                                                         training=True)

            disc_fake_upscale = discriminator_anime_upscale(
                fake_anime_upscale, training=True)
            disc_cycle_upscale = discriminator_anime_upscale(
                cycle_anime_upscale, training=True)
            disc_same_upscale = discriminator_anime_upscale(
                same_anime_upscale, training=True)

            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)

            gen_upscale_loss = (
                generator_loss(disc_fake_upscale) * 3 +
                generator_loss(disc_cycle_upscale) +
                generator_loss(disc_same_upscale) +
                identity_loss(big_anime, cycle_anime_upscale) * 0.5 +
                identity_loss(big_anime, same_anime_upscale) * 0.5)

            disc_upscale_loss = discriminator_upscale_loss(
                disc_real_big, disc_fake_upscale, disc_cycle_upscale,
                disc_same_upscale)

            # (optimizer, model, unscaled loss) for every network trained here.
            updates = (
                (generator_to_anime_optimizer, generator_to_anime,
                 total_gen_anime_loss),
                (generator_to_human_optimizer, generator_to_human,
                 total_gen_human_loss),
                (discriminator_human_optimizer, discriminator_human,
                 disc_human_loss),
                (discriminator_anime_optimizer, discriminator_anime,
                 disc_anime_loss),
                (generator_anime_upscale_optimizer, generator_anime_upscale,
                 gen_upscale_loss),
                (discriminator_anime_upscale_optimizer,
                 discriminator_anime_upscale, disc_upscale_loss),
            )
            # Loss scaling has to be recorded on the tape.
            scaled_losses = [
                optimizer.get_scaled_loss(unscaled)
                for optimizer, _, unscaled in updates
            ]

        # Calculate every gradient before applying any update.
        scaled_gradients = [
            tape.gradient(scaled, model.trainable_variables)
            for scaled, (_, model, _) in zip(scaled_losses, updates)
        ]

        # Unscale and apply, one optimizer at a time.
        for (optimizer, model, _), gradients in zip(updates, scaled_gradients):
            optimizer.apply_gradients(
                zip(optimizer.get_unscaled_gradients(gradients),
                    model.trainable_variables))

        return (
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_upscale,
            same_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
            disc_upscale_loss,
        )
Exemple #4
0
    def trainstep(real_human, real_anime, big_anime):
        """Run one training step of the shared-latent human/anime translator.

        Images are encoded by a domain-specific encoder followed by
        ``encode_share``. They are decoded by ``decode_share`` followed by a
        domain-specific decoder. Eight losses are paired with the optimizers
        in ``optims`` and the variables in ``list_variables`` (same order)
        and applied with loss scaling.

        Args:
            real_human: batch from the human domain.
            real_anime: batch from the anime domain.
            big_anime: accepted for interface compatibility; unused here.

        Returns:
            ``(real_human, real_anime, recon_anime, recon_human, fake_anime,
            fake_human)`` followed by the individual losses and the mean
            discriminator outputs on fake and real anime.
        """
        with tf.GradientTape(persistent=True) as tape:
            latent_anime = encode_share(encode_anime(real_anime))
            latent_human = encode_share(encode_human(real_human))

            # Same-domain reconstructions.
            recon_anime = decode_anime(decode_share(latent_anime))
            recon_human = decode_human(decode_share(latent_human))

            # Cross-domain translations decode with the *target* domain's
            # decoder and re-encode with that same domain's encoder.
            fake_anime = decode_anime(decode_share(latent_human))
            latent_human_cycled = encode_share(encode_anime(fake_anime))

            fake_human = decode_human(decode_share(latent_anime))
            latent_anime_cycled = encode_share(encode_human(fake_human))

            disc_fake = D(fake_anime)
            disc_real = D(real_anime)

            c_dann_anime = c_dann(latent_anime)
            c_dann_human = c_dann(latent_human)

            loss_anime_encode = identity_loss(real_anime, recon_anime) * 3
            loss_human_encode = identity_loss(real_human, recon_human) * 3

            # Domain classifier on latents: anime labelled 0, human labelled 1.
            # NOTE(review): the encoders below minimise this same loss (no
            # label flip / gradient reversal), so they help c_dann separate
            # the domains rather than confuse it — confirm this is intended.
            loss_domain_adversarial = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.zeros_like(c_dann_anime),
                    logits=c_dann_anime)) + tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=tf.ones_like(c_dann_human),
                            logits=c_dann_human))
            # Cap the loss at 100, then weight it by 0.2.
            loss_domain_adversarial = tf.math.minimum(loss_domain_adversarial,
                                                      100)
            loss_domain_adversarial = loss_domain_adversarial * 0.2
            tf.print(loss_domain_adversarial)

            loss_semantic_consistency = (
                identity_loss(latent_anime, latent_anime_cycled) * 3 +
                identity_loss(latent_human, latent_human_cycled) * 3)

            loss_gan = w_g_loss(disc_fake)

            # NOTE(review): fake_anime is produced from latent_human (the
            # encode_human path), so loss_gan contributes no gradient to
            # encode_anime's variables here; confirm which encoder should
            # carry it.
            anime_encode_total_loss = (loss_anime_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency + loss_gan)
            human_encode_total_loss = (loss_human_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency)
            share_encode_total_loss = (loss_anime_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency + loss_gan +
                                       loss_human_encode)

            share_decode_total_loss = loss_anime_encode + loss_human_encode + loss_gan
            anime_decode_total_loss = loss_anime_encode + loss_gan
            human_decode_total_loss = loss_human_encode

            # WGAN critic loss with gradient penalty.
            loss_disc = w_d_loss(disc_real, disc_fake)
            loss_disc += gradient_penalty(partial(D, training=True),
                                          real_anime, fake_anime)

            # Must stay aligned, element by element, with `optims` and with
            # `list_variables` below.
            losses = [
                anime_encode_total_loss, human_encode_total_loss,
                share_encode_total_loss, loss_domain_adversarial,
                share_decode_total_loss, anime_decode_total_loss,
                human_decode_total_loss, loss_disc
            ]

            scaled_losses = [
                optim.get_scaled_loss(loss)
                for optim, loss in zip(optims, losses)
            ]

        list_variables = [
            encode_anime.trainable_variables, encode_human.trainable_variables,
            encode_share.trainable_variables, c_dann.trainable_variables,
            decode_share.trainable_variables, decode_anime.trainable_variables,
            decode_human.trainable_variables, D.trainable_variables
        ]
        gan_grad = [
            tape.gradient(scaled_loss, train_variable) for scaled_loss,
            train_variable in zip(scaled_losses, list_variables)
        ]
        gan_grad = [
            optim.get_unscaled_gradients(x)
            for optim, x in zip(optims, gan_grad)
        ]
        for optim, grad, trainable in zip(optims, gan_grad, list_variables):
            optim.apply_gradients(zip(grad, trainable))

        return (real_human, real_anime, recon_anime, recon_human, fake_anime,
                fake_human, loss_anime_encode, loss_human_encode,
                loss_domain_adversarial, loss_semantic_consistency,
                loss_gan, loss_disc, tf.reduce_mean(disc_fake),
                tf.reduce_mean(disc_real))
Exemple #5
0
    def trainstep_G(real_human, real_anime, big_anime):
        """Update the two translators and the anime upscaler (generators only).

        Discriminators are evaluated with ``training=True`` but only the
        three generator optimizers are applied. Several loss components are
        printed with ``tf.print`` every step.

        Args:
            real_human: batch from the human domain.
            real_anime: batch from the anime domain.
            big_anime: target compared against the upscaled ``real_anime``
                (presumably its high-resolution counterpart — TODO confirm).

        Returns:
            List of images followed by the generator losses, in the same
            order as before.
        """
        with tf.GradientTape(persistent=True) as tape:
            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            fake_anime_upscale = generator_anime_upscale(fake_anime,
                                                         training=True)
            real_anime_upscale = generator_anime_upscale(real_anime,
                                                         training=True)

            # The cycled/same upscales are only returned for inspection; their
            # loss terms are currently disabled.
            cycled_anime_upscale = generator_anime_upscale(cycled_anime,
                                                           training=True)
            same_anime_upscale = generator_anime_upscale(same_anime,
                                                         training=True)

            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)
            # calculate the loss
            gen_anime_loss = w_g_loss(disc_fake_anime)
            gen_human_loss = w_g_loss(disc_fake_human)

            total_cycle_loss = cycle_loss(real_human,
                                          cycled_human) + cycle_loss(
                                              real_anime, cycled_anime)

            anime_identity = identity_loss(real_anime, same_anime)
            human_identity = identity_loss(real_human, same_human)

            # Total generator loss = 1e3 * adversarial + cycle + identity
            total_gen_anime_loss = (gen_anime_loss * 1e3 + total_cycle_loss +
                                    anime_identity)

            tf.print("gen_anime_loss*1e3", gen_anime_loss * 1e3)
            tf.print("total_cycle_loss", total_cycle_loss)
            tf.print("identity_loss", anime_identity)
            tf.print("--------------------------")
            total_gen_human_loss = (gen_human_loss * 1e3 + total_cycle_loss +
                                    human_identity)

            upscale_adv_loss = w_g_loss(disc_fake_upscale)
            # Pixel term between big_anime and the upscaled *real* anime.
            upscale_identity = identity_loss(big_anime, real_anime_upscale)
            gen_upscale_loss = upscale_adv_loss * 1e3 + upscale_identity * 1e-6

            # Report the same identity term that enters gen_upscale_loss
            # (not a discriminator output).
            tf.print("w_g_loss(disc_fake_upscale)", upscale_adv_loss)
            tf.print("identity_loss(big_anime, real_anime_upscale)",
                     upscale_identity)

        generator_to_anime_gradients = tape.gradient(
            total_gen_anime_loss, generator_to_anime.trainable_variables)
        generator_to_human_gradients = tape.gradient(
            total_gen_human_loss, generator_to_human.trainable_variables)
        generator_upscale_gradients = tape.gradient(
            gen_upscale_loss, generator_anime_upscale.trainable_variables)
        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))
        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))
        generator_anime_upscale_optimizer.apply_gradients(
            zip(generator_upscale_gradients,
                generator_anime_upscale.trainable_variables))

        return [
            real_human,
            real_anime,
            fake_anime,
            cycled_anime,
            same_anime,
            fake_human,
            cycled_human,
            same_human,
            fake_anime_upscale,
            cycled_anime_upscale,
            same_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
        ]
Exemple #6
0
    def trainstep(real_human, real_anime):
        """Perform one CycleGAN update for human <-> anime translation.

        Both generators and both discriminators are run with
        ``training=True`` under a single persistent tape. All four gradients
        are taken first, then applied in the order: anime generator, human
        generator, discriminator_x, discriminator_y.

        Returns:
            ``(fake_anime, cycled_human, fake_human, cycled_anime, same_human,
            same_anime, gen_anime_loss, gen_human_loss, disc_x_loss,
            disc_y_loss, total_gen_anime_loss, total_gen_human_loss)``
        """
        with tf.GradientTape(persistent=True) as tape:
            # Translations and their round trips back to the source domain.
            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)
            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # Each generator fed its own target domain (identity mapping).
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            # discriminator_x judges humans, discriminator_y judges anime.
            disc_real_human = discriminator_x(real_human, training=True)
            disc_real_anime = discriminator_y(real_anime, training=True)
            disc_fake_human = discriminator_x(fake_human, training=True)
            disc_fake_anime = discriminator_y(fake_anime, training=True)

            gen_anime_loss = generator_loss(disc_fake_anime)
            gen_human_loss = generator_loss(disc_fake_human)

            total_cycle_loss = (cycle_loss(real_human, cycled_human) +
                                cycle_loss(real_anime, cycled_anime))

            # adversarial + shared cycle term + own identity term
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime))
            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human))

            disc_x_loss = discriminator_loss(disc_real_human, disc_fake_human)
            disc_y_loss = discriminator_loss(disc_real_anime, disc_fake_anime)

        # (loss, network, optimizer) in the order the updates are applied.
        schedule = (
            (total_gen_anime_loss, generator_to_anime,
             generator_to_anime_optimizer),
            (total_gen_human_loss, generator_to_human,
             generator_to_human_optimizer),
            (disc_x_loss, discriminator_x, discriminator_x_optimizer),
            (disc_y_loss, discriminator_y, discriminator_y_optimizer),
        )

        # Every gradient is taken before any variable changes.
        all_gradients = [
            tape.gradient(objective, network.trainable_variables)
            for objective, network, _ in schedule
        ]
        for (_, network, optimizer), gradients in zip(schedule, all_gradients):
            optimizer.apply_gradients(
                zip(gradients, network.trainable_variables))

        return (fake_anime, cycled_human, fake_human, cycled_anime,
                same_human, same_anime, gen_anime_loss, gen_human_loss,
                disc_x_loss, disc_y_loss, total_gen_anime_loss,
                total_gen_human_loss)
Exemple #7
0
    def trainstep(real_human, real_anime, big_anime):
        """Run one training step for the shared translator and the upscaler.

        Phase 1 trains the single ``generator`` / ``discriminator`` pair that
        handles both translation directions. The target domain is selected by
        concatenating a constant tensor (+1 -> anime, -1 -> human) onto the
        generator input along axis 3.

        Phase 2 uses a separate tape and trains ``up_G`` / ``up_D``. It works
        on ``real_anime`` and the phase-1 ``fake_anime``, and uses
        ``big_anime`` as ``up_D``'s real reference.

        All losses go through the optimizers' ``get_scaled_loss`` /
        ``get_unscaled_gradients`` methods (the Keras ``LossScaleOptimizer``
        API).

        Args:
            real_human: batch of human-domain images. The domain-code tensors
                are built with its shape, so ``real_anime`` and the generated
                images are assumed to share that shape (TODO confirm).
            real_anime: batch of anime-domain images.
            big_anime: anime batch used as the "real" input to ``up_D``.

        Returns:
            Tuple ``(real_human, real_anime, fake_anime, cycled_human,
            fake_human, cycled_anime, same_human, same_anime, fake_anime_up,
            real_anime_up, gen_anime_loss, gen_human_loss, disc_human_loss,
            disc_anime_loss, gen_up_loss, dis_up_loss)``.
        """
        with tf.GradientTape(persistent=True) as tape:
            # Domain codes. One is concatenated onto an image to tell the shared
            # generator which domain to translate into.
            ones = tf.ones_like(real_human)
            neg_ones = tf.ones_like(real_human) * -1

            def get_domain_anime(img):
                return tf.concat([img, ones], 3)

            def get_domain_human(img):
                return tf.concat([img, neg_ones], 3)

            fake_anime = generator(get_domain_anime(real_human), training=True)
            cycled_human = generator(get_domain_human(fake_anime), training=True)

            fake_human = generator(get_domain_human(real_anime), training=True)
            cycled_anime = generator(get_domain_anime(fake_human), training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator(get_domain_human(real_human), training=True)
            same_anime = generator(get_domain_anime(real_anime), training=True)

            # The discriminator returns two outputs: an adversarial score and a
            # domain label (human vs. anime).
            disc_real_human, label_real_human = discriminator(real_human, training=True)
            disc_real_anime, label_real_anime = discriminator(real_anime, training=True)

            disc_fake_human, label_fake_human = discriminator(fake_human, training=True)
            disc_fake_anime, label_fake_anime = discriminator(fake_anime, training=True)

            _, label_cycled_human = discriminator(cycled_human, training=True)
            _, label_cycled_anime = discriminator(cycled_anime, training=True)

            _, label_same_human = discriminator(same_human, training=True)
            _, label_same_anime = discriminator(same_anime, training=True)

            # calculate the loss
            gen_anime_loss = generator_loss(disc_fake_anime)
            gen_human_loss = generator_loss(disc_fake_human)

            total_cycle_loss = cycle_loss(real_human, cycled_human) + cycle_loss(
                real_anime, cycled_anime
            )

            # Domain-label loss on every generated image pair. Human outputs go
            # in the first argument slot and anime outputs in the second, the
            # same order used for the label term of disc_loss below.
            gen_class_loss = (
                discriminator_loss(label_fake_human, label_fake_anime)
                + discriminator_loss(label_cycled_human, label_cycled_anime)
                + discriminator_loss(label_same_human, label_same_anime)
            )

            # Computed once and reused for both the total loss and logging.
            identity_anime = identity_loss(real_anime, same_anime)
            identity_human = identity_loss(real_human, same_human)

            # Total generator loss = adversarial + domain-label
            # + weighted cycle + identity losses.
            total_gen_loss = (
                gen_anime_loss
                + gen_human_loss
                + gen_class_loss
                + total_cycle_loss * 0.1
                + identity_anime
                + identity_human
            )

            tf.print("gen_anime_loss", gen_anime_loss)
            tf.print("gen_human_loss", gen_human_loss)
            tf.print("gen_class_loss", gen_class_loss)
            tf.print("total_cycle_loss", total_cycle_loss)
            tf.print("identity_loss(real_anime, same_anime)", identity_anime)
            tf.print("identity_loss(real_human, same_human)", identity_human)

            scaled_total_gen_loss = generator_optimizer.get_scaled_loss(
                total_gen_loss
            )

            disc_human_loss = discriminator_loss(disc_real_human, disc_fake_human)
            disc_anime_loss = discriminator_loss(disc_real_anime, disc_fake_anime)

            # Adversarial terms for both domains plus the domain-label term on
            # real images.
            disc_loss = (
                disc_human_loss
                + disc_anime_loss
                + discriminator_loss(label_real_human, label_real_anime)
            )

            scaled_disc_loss = discriminator_optimizer.get_scaled_loss(disc_loss)

        # Calculate the gradients for generator and discriminator.
        generator_gradients = generator_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_total_gen_loss, generator.trainable_variables)
        )
        discriminator_gradients = discriminator_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_loss, discriminator.trainable_variables)
        )
        # A persistent tape holds its resources until it is dropped.
        del tape

        generator_optimizer.apply_gradients(
            zip(generator_gradients, generator.trainable_variables)
        )
        discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, discriminator.trainable_variables)
        )

        # Phase 2: train the upscaler. fake_anime was produced outside this
        # tape, so no gradient flows back into `generator` from here.
        with tf.GradientTape(persistent=True) as tape:
            real_anime_up = up_G(real_anime)
            fake_anime_up = up_G(fake_anime)

            dis_fake_anime_up = up_D(fake_anime_up)
            dis_real_anime_up = up_D(real_anime_up)
            dis_ori_anime = up_D(big_anime)

            # The adversarial loss is taken on up_D's scores for the upscaled
            # images, not on the images themselves.
            gen_up_loss = (
                generator_loss(dis_fake_anime_up)
                + generator_loss(dis_real_anime_up) * 0.1
            )
            dis_up_loss = (
                discriminator_loss(dis_ori_anime, dis_fake_anime_up)
                + discriminator_loss(dis_ori_anime, dis_real_anime_up) * 0.1
            )
            scaled_gen_up_loss = up_G_optim.get_scaled_loss(gen_up_loss)
            scaled_up_disc_loss = up_D_optim.get_scaled_loss(dis_up_loss)

        up_G_gradients = up_G_optim.get_unscaled_gradients(
            tape.gradient(scaled_gen_up_loss, up_G.trainable_variables)
        )
        up_D_gradients = up_D_optim.get_unscaled_gradients(
            tape.gradient(scaled_up_disc_loss, up_D.trainable_variables)
        )
        del tape

        up_G_optim.apply_gradients(zip(up_G_gradients, up_G.trainable_variables))
        up_D_optim.apply_gradients(zip(up_D_gradients, up_D.trainable_variables))

        return (
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_up,
            real_anime_up,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            gen_up_loss,
            dis_up_loss,
        )
    def trainstep(real_human, real_anime, big_anime):
        """Run one training step for both translators and the anime upscaler.

        One persistent tape records:

        * ``generator_to_anime`` / ``generator_to_human``, each with its
          discriminator (``discriminator_anime`` / ``discriminator_human``);
        * ``generator_anime_upscale`` with ``discriminator_anime_upscale``.

        Generator losses use ``w_g_loss``. Discriminator losses use
        ``w_d_loss`` plus ``gradient_penalty`` terms, and each penalty is taken
        through the discriminator that the loss trains.

        Args:
            real_human: batch of human-domain images.
            real_anime: batch of anime-domain images.
            big_anime: anime batch used as the upscale discriminator's real
                input and as the identity target for ``real_anime_upscale``.

        Returns:
            List ``[real_human, real_anime, fake_anime, cycled_human,
            fake_human, cycled_anime, same_human, same_anime,
            fake_anime_upscale, real_anime_upscale, gen_anime_loss,
            gen_human_loss, disc_human_loss, disc_anime_loss,
            total_gen_anime_loss, total_gen_human_loss, gen_upscale_loss,
            disc_upscale_loss]``.
        """
        with tf.GradientTape(persistent=True) as tape:

            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            print("generator_to_anime", generator_to_anime.count_params())

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)
            print("discriminator_human", discriminator_human.count_params())

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            # The upscaler is applied to both the translated and the real
            # anime batch.
            fake_anime_upscale = generator_anime_upscale(fake_anime,
                                                         training=True)
            real_anime_upscale = generator_anime_upscale(real_anime,
                                                         training=True)

            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)
            disc_real_upscale = discriminator_anime_upscale(real_anime_upscale,
                                                            training=True)
            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)

            # calculate the loss
            gen_anime_loss = w_g_loss(disc_fake_anime)
            gen_human_loss = w_g_loss(disc_fake_human)

            total_cycle_loss = (cycle_loss(real_human, cycled_human) +
                                cycle_loss(real_anime, cycled_anime))

            # Total generator loss = adversarial + cycle + identity loss.
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime))

            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human))

            # Both upscaled batches are pushed toward "real". real_anime_upscale
            # is additionally tied to big_anime through identity_loss.
            gen_upscale_loss = (
                w_g_loss(disc_fake_upscale) + w_g_loss(disc_real_upscale) +
                identity_loss(big_anime, real_anime_upscale) * 0.3)

            discriminator_human_gradient_penalty = (gradient_penalty(
                functools.partial(discriminator_human, training=True),
                real_human,
                fake_human,
            ) * 10)
            discriminator_anime_gradient_penalty = (gradient_penalty(
                functools.partial(discriminator_anime, training=True),
                real_anime,
                fake_anime,
            ) * 10)

            # The upscale penalties must go through the upscale discriminator.
            # disc_upscale_loss is only differentiated w.r.t.
            # discriminator_anime_upscale's variables, so a penalty computed
            # through another network would contribute no gradient at all.
            upscale_critic = functools.partial(discriminator_anime_upscale,
                                               training=True)
            discriminator_upscale_gradient_penalty = (
                gradient_penalty(upscale_critic, big_anime,
                                 fake_anime_upscale) * 5 +
                gradient_penalty(upscale_critic, big_anime,
                                 real_anime_upscale) * 5)

            disc_human_loss = (w_d_loss(disc_real_human, disc_fake_human) +
                               discriminator_human_gradient_penalty)
            disc_anime_loss = (w_d_loss(disc_real_anime, disc_fake_anime) +
                               discriminator_anime_gradient_penalty)
            # Both upscaled batches count as "fake" against big_anime.
            disc_upscale_loss = w_d_loss(disc_real_big, disc_fake_upscale)
            disc_upscale_loss += (w_d_loss(disc_real_big, disc_real_upscale) +
                                  discriminator_upscale_gradient_penalty)

        generator_to_anime_gradients = tape.gradient(
            total_gen_anime_loss, generator_to_anime.trainable_variables)
        generator_to_human_gradients = tape.gradient(
            total_gen_human_loss, generator_to_human.trainable_variables)
        generator_upscale_gradients = tape.gradient(
            gen_upscale_loss, generator_anime_upscale.trainable_variables)

        discriminator_human_gradients = tape.gradient(
            disc_human_loss, discriminator_human.trainable_variables)
        discriminator_anime_gradients = tape.gradient(
            disc_anime_loss, discriminator_anime.trainable_variables)
        discriminator_upscale_gradients = tape.gradient(
            disc_upscale_loss, discriminator_anime_upscale.trainable_variables)
        # A persistent tape holds its resources until it is dropped.
        del tape

        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))

        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))

        generator_anime_upscale_optimizer.apply_gradients(
            zip(generator_upscale_gradients,
                generator_anime_upscale.trainable_variables))

        discriminator_human_optimizer.apply_gradients(
            zip(discriminator_human_gradients,
                discriminator_human.trainable_variables))

        discriminator_anime_optimizer.apply_gradients(
            zip(discriminator_anime_gradients,
                discriminator_anime.trainable_variables))

        discriminator_anime_upscale_optimizer.apply_gradients(
            zip(
                discriminator_upscale_gradients,
                discriminator_anime_upscale.trainable_variables,
            ))

        return [
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_upscale,
            real_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
            disc_upscale_loss,
        ]
Exemple #9
0
    def trainstep_G(real_human, real_anime, big_anime):
        """Run one generator-only training step; discriminators are not updated.

        Updates ``generator_to_anime``, ``generator_to_human`` and
        ``generator_anime_upscale`` through their optimizers' loss-scaling API
        (``get_scaled_loss`` / ``get_unscaled_gradients``). The discriminators
        are only evaluated, to provide adversarial scores.

        Args:
            real_human: batch of human-domain images.
            real_anime: batch of anime-domain images.
            big_anime: target for the upscaled cycled and identity anime
                images (compared via ``identity_loss``).

        Returns:
            List ``[real_human, real_anime, fake_anime, cycled_anime,
            same_anime, fake_human, cycled_human, same_human,
            fake_anime_upscale, cycled_anime_upscale, same_anime_upscale,
            gen_anime_loss, gen_human_loss, total_gen_anime_loss,
            total_gen_human_loss, gen_upscale_loss]``.
        """
        with tf.GradientTape(persistent=True) as tape:
            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            # Upscale every anime image produced in this step.
            fake_anime_upscale = generator_anime_upscale(fake_anime,
                                                         training=True)
            cycled_anime_upscale = generator_anime_upscale(cycled_anime,
                                                           training=True)
            same_anime_upscale = generator_anime_upscale(same_anime,
                                                         training=True)

            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)
            disc_cycle_upscale = discriminator_anime_upscale(
                cycled_anime_upscale, training=True)
            disc_same_upscale = discriminator_anime_upscale(same_anime_upscale,
                                                            training=True)

            # calculate the loss
            gen_anime_loss = w_g_loss(disc_fake_anime)
            gen_human_loss = w_g_loss(disc_fake_human)

            total_cycle_loss = (cycle_loss(real_human, cycled_human) +
                                cycle_loss(real_anime, cycled_anime))

            # Total generator loss = 2 * adversarial + cycle + identity loss.
            # Both directions reuse the adversarial terms computed above.
            total_gen_anime_loss = (gen_anime_loss * 2 + total_cycle_loss +
                                    identity_loss(real_anime, same_anime))

            total_gen_human_loss = (gen_human_loss * 2 + total_cycle_loss +
                                    identity_loss(real_human, same_human))

            # Only the cycled and identity upscales are compared with
            # big_anime; all three upscales get an adversarial term.
            gen_upscale_loss = (
                w_g_loss(disc_fake_upscale) + w_g_loss(disc_cycle_upscale) +
                w_g_loss(disc_same_upscale) +
                identity_loss(big_anime, cycled_anime_upscale) +
                identity_loss(big_anime, same_anime_upscale))

            scaled_total_gen_anime_loss = generator_to_anime_optimizer.get_scaled_loss(
                total_gen_anime_loss)
            scaled_total_gen_human_loss = generator_to_human_optimizer.get_scaled_loss(
                total_gen_human_loss)
            scaled_gen_upscale_loss = generator_anime_upscale_optimizer.get_scaled_loss(
                gen_upscale_loss)

        generator_to_anime_gradients = generator_to_anime_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_total_gen_anime_loss,
                          generator_to_anime.trainable_variables))
        generator_to_human_gradients = generator_to_human_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_total_gen_human_loss,
                          generator_to_human.trainable_variables))
        generator_upscale_gradients = generator_anime_upscale_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_gen_upscale_loss,
                          generator_anime_upscale.trainable_variables))
        # A persistent tape holds its resources until it is dropped.
        del tape

        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))
        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))
        generator_anime_upscale_optimizer.apply_gradients(
            zip(generator_upscale_gradients,
                generator_anime_upscale.trainable_variables))

        return [
            real_human,
            real_anime,
            fake_anime,
            cycled_anime,
            same_anime,
            fake_human,
            cycled_human,
            same_human,
            fake_anime_upscale,
            cycled_anime_upscale,
            same_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
        ]
Exemple #10
0
    def trainstep(real_human, real_anime, big_anime):
        """Run one least-squares training step for both translators and the upscaler.

        Discriminators are regressed (``mse_loss``) toward +1 on real inputs
        and -1 on generated inputs. Generators are regressed toward 0 on
        their fakes' scores.

        Every loss is passed through its optimizer's ``get_scaled_loss``, and
        every gradient through ``get_unscaled_gradients`` (the Keras
        ``LossScaleOptimizer`` API).

        Args:
            real_human: batch of human-domain images.
            real_anime: batch of anime-domain images.
            big_anime: anime batch used as the upscale discriminator's real
                input.

        Returns:
            List ``[real_human, real_anime, fake_anime, cycled_human,
            fake_human, cycled_anime, same_human, same_anime,
            fake_anime_upscale, same_anime_upscale, gen_anime_loss,
            gen_human_loss, disc_human_loss, disc_anime_loss,
            total_gen_anime_loss, total_gen_human_loss, gen_upscale_loss,
            disc_upscale_loss]``.
        """
        with tf.GradientTape(persistent=True) as tape:

            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            # calculate the loss
            gen_anime_loss = mse_loss(disc_fake_anime,
                                      tf.zeros_like(disc_fake_anime))
            gen_human_loss = mse_loss(disc_fake_human,
                                      tf.zeros_like(disc_fake_human))

            total_cycle_loss = (cycle_loss(real_human, cycled_human) +
                                cycle_loss(real_anime, cycled_anime))
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime) +
                                    mse_loss(real_anime, fake_anime) * 0.1)

            # This term must depend on generator_to_human's output. The
            # previous mse_loss(real_anime, fake_anime) gave it no gradient.
            # NOTE(review): the anime counterpart above is weighted by 0.1;
            # this one keeps the original weight of 1 -- confirm intent.
            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human) +
                                    mse_loss(real_human, fake_human))
            disc_human_loss = mse_loss(
                disc_real_human, tf.ones_like(disc_real_human)) + mse_loss(
                    disc_fake_human, -1 * tf.ones_like(disc_fake_human))
            # Target must match the anime discriminator's own output shape.
            disc_anime_loss = mse_loss(
                disc_real_anime, tf.ones_like(disc_real_anime)) + mse_loss(
                    disc_fake_anime, -1 * tf.ones_like(disc_fake_anime))

            fake_anime_upscale = generator_anime_upscale(fake_anime,
                                                         training=True)
            same_anime_upscale = generator_anime_upscale(same_anime,
                                                         training=True)
            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)
            disc_same_upscale = discriminator_anime_upscale(same_anime_upscale,
                                                            training=True)
            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)

            gen_upscale_loss = (
                mse_loss(disc_fake_upscale, tf.zeros_like(disc_fake_upscale)) +
                mse_loss(disc_same_upscale, tf.zeros_like(disc_same_upscale)) *
                0.1)

            print("generator_to_anime.count_params()",
                  generator_to_anime.count_params())
            print("discriminator_anime.count_params()",
                  discriminator_anime.count_params())
            print("generator_anime_upscale.count_params()",
                  generator_anime_upscale.count_params())
            print(
                "discriminator_anime_upscale.count_params()",
                discriminator_anime_upscale.count_params(),
            )

            # Only fake_anime_upscale is scored as "fake" here; same_anime_upscale
            # appears only in the generator loss.
            disc_upscale_loss = mse_loss(
                disc_fake_upscale,
                -1 * tf.ones_like(disc_fake_upscale)) + mse_loss(
                    disc_real_big, tf.ones_like(disc_real_big))

            scaled_total_gen_anime_loss = generator_to_anime_optimizer.get_scaled_loss(
                total_gen_anime_loss)
            scaled_total_gen_human_loss = generator_to_human_optimizer.get_scaled_loss(
                total_gen_human_loss)
            scaled_gen_upscale_loss = generator_anime_upscale_optimizer.get_scaled_loss(
                gen_upscale_loss)
            scaled_disc_human_loss = discriminator_human_optimizer.get_scaled_loss(
                disc_human_loss)
            scaled_disc_anime_loss = discriminator_anime_optimizer.get_scaled_loss(
                disc_anime_loss)
            scaled_disc_upscale_loss = discriminator_anime_upscale_optimizer.get_scaled_loss(
                disc_upscale_loss)

        generator_to_anime_gradients = generator_to_anime_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_total_gen_anime_loss,
                          generator_to_anime.trainable_variables))

        generator_to_human_gradients = generator_to_human_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_total_gen_human_loss,
                          generator_to_human.trainable_variables))

        generator_upscale_gradients = generator_anime_upscale_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_gen_upscale_loss,
                          generator_anime_upscale.trainable_variables))

        discriminator_human_gradients = discriminator_human_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_human_loss,
                          discriminator_human.trainable_variables))
        discriminator_anime_gradients = discriminator_anime_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_anime_loss,
                          discriminator_anime.trainable_variables))

        discriminator_upscale_gradients = discriminator_anime_upscale_optimizer.get_unscaled_gradients(
            tape.gradient(
                scaled_disc_upscale_loss,
                discriminator_anime_upscale.trainable_variables,
            ))
        # A persistent tape holds its resources until it is dropped.
        del tape

        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))

        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))

        generator_anime_upscale_optimizer.apply_gradients(
            zip(generator_upscale_gradients,
                generator_anime_upscale.trainable_variables))

        discriminator_human_optimizer.apply_gradients(
            zip(discriminator_human_gradients,
                discriminator_human.trainable_variables))

        discriminator_anime_optimizer.apply_gradients(
            zip(discriminator_anime_gradients,
                discriminator_anime.trainable_variables))

        discriminator_anime_upscale_optimizer.apply_gradients(
            zip(
                discriminator_upscale_gradients,
                discriminator_anime_upscale.trainable_variables,
            ))

        return [
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_upscale,
            same_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
            disc_upscale_loss,
        ]