Example #1
0
def train_step_P(net, x, y, optimizerP, args):
    """Run one WGAN-GP update of the discriminator net['P'].

    The discriminator scores channel-concatenated (x, y) pairs.  Fake
    pairs come from two sources, mixed by args['alpha']: the generator
    net['G'] (fake y) and the denoiser net['D'] (fake x).

    :param net: dict of modules with keys 'P' (discriminator),
        'G' (generator) and 'D' (denoiser)
    :param x: input batch (e.g. noisy images)
    :param y: target batch (e.g. clean images)
    :param optimizerP: optimizer over net['P'] parameters
    :param args: dict with 'alpha' (mixing weight) and 'lambda_gp'
        (gradient-penalty weight)
    :return: (loss, loss_x, loss_xg, loss_y, loss_yg)
    """
    alpha = args['alpha']
    # zero the gradient
    net['P'].zero_grad()
    # real data: channel-concatenated (x, y) pair
    real_data = torch.cat([x, y], 1)
    real_loss = net['P'](real_data).mean()
    # generator fake data -- no_grad: only P is updated in this step
    with torch.no_grad():
        fake_y = sample_generator(net['G'], x)
        fake_y_data = torch.cat([x, fake_y], 1)
    fake_y_loss = net['P'](fake_y_data.detach()).mean()
    grad_y_loss = gradient_penalty(real_data, fake_y_data, net['P'],
                                   args['lambda_gp'])
    loss_y = alpha * (fake_y_loss - real_loss)
    loss_yg = alpha * grad_y_loss
    # denoiser fake data: the denoiser predicts the noise, so x ~ y - D(y)
    with torch.no_grad():
        fake_x = y - net['D'](y)
        fake_x_data = torch.cat([fake_x, y], 1)
    fake_x_loss = net['P'](fake_x_data.detach()).mean()
    grad_x_loss = gradient_penalty(real_data, fake_x_data, net['P'],
                                   args['lambda_gp'])
    loss_x = (1 - alpha) * (fake_x_loss - real_loss)
    loss_xg = (1 - alpha) * grad_x_loss
    loss = loss_x + loss_xg + loss_y + loss_yg
    # backward
    loss.backward()
    optimizerP.step()

    return loss, loss_x, loss_xg, loss_y, loss_yg
Example #2
0
def train_D(A, B, A2B, B2A):
    """Jointly update the two discriminators D_A and D_B on one batch.

    Real samples (A, B) and translated samples (B2A, A2B) are scored by
    the matching discriminator; the adversarial losses plus weighted
    gradient penalties are minimized in a single optimizer step.
    """
    with tf.GradientTape() as tape:
        # Score reals and translations with each domain's discriminator.
        logits_real_A = D_A(A, training=True)
        logits_fake_A = D_A(B2A, training=True)
        logits_real_B = D_B(B, training=True)
        logits_fake_B = D_B(A2B, training=True)

        A_d_loss, B2A_d_loss = d_loss_fn(logits_real_A, logits_fake_A)
        B_d_loss, A2B_d_loss = d_loss_fn(logits_real_B, logits_fake_B)

        # Gradient penalties keep each discriminator near 1-Lipschitz.
        gp_A = loss.gradient_penalty(functools.partial(D_A, training=True), A, B2A)
        gp_B = loss.gradient_penalty(functools.partial(D_B, training=True), B, A2B)

        adv_loss = (A_d_loss + B2A_d_loss) + (B_d_loss + A2B_d_loss)
        D_loss = adv_loss + (gp_A + gp_B) * args.gradient_penalty_weight

    # One step over the concatenated variable lists of both discriminators.
    variables = D_A.trainable_variables + D_B.trainable_variables
    D_grad = tape.gradient(D_loss, variables)
    D_optimizer.apply_gradients(zip(D_grad, variables))

    return {'A_d_loss': A_d_loss + B2A_d_loss,
            'B_d_loss': B_d_loss + A2B_d_loss,
            'D_A_gp': gp_A,
            'D_B_gp': gp_B}
Example #3
0
def train_step_P(net, x, y, optimizerP, args): # Discriminator
    """Run one WGAN-GP update of net['P'] on channel 1 only.

    Channel 1 of each input tensor is extracted (keeping the channel
    dimension) before the discriminator scores the channel-concatenated
    (x, y) pair; otherwise this mirrors the full-channel variant.

    :param net: dict of modules with keys 'P' (discriminator),
        'G' (generator) and 'D' (denoiser)
    :param x: input batch with at least 2 channels (channel 1 is used)
    :param y: target batch, same layout as x
    :param optimizerP: optimizer over net['P'] parameters
    :param args: dict with 'alpha' and 'lambda_gp'
    :return: (loss, loss_x, loss_xg, loss_y, loss_yg)
    """
    # keep only channel 1, preserving the channel dimension
    x_ = x[:, 1, :, :].unsqueeze(1)
    y_ = y[:, 1, :, :].unsqueeze(1)
    alpha = args['alpha']
    # zero the gradient
    net['P'].zero_grad()
    # real data: channel-concatenated single-channel (x, y) pair
    real_data = torch.cat([x_, y_], 1)
    real_loss = net['P'](real_data).mean()
    # generator fake data -- no_grad: only P is updated in this step
    with torch.no_grad():
        fake_y = sample_generator(net['G'], x)  # generator consumes all channels
        fake_y = fake_y[:, 1, :, :].unsqueeze(1)
        fake_y_data = torch.cat([x_, fake_y], 1)
    fake_y_loss = net['P'](fake_y_data.detach()).mean()
    grad_y_loss = gradient_penalty(real_data, fake_y_data, net['P'], args['lambda_gp'])
    loss_y = alpha * (fake_y_loss - real_loss)
    loss_yg = alpha * grad_y_loss
    # denoiser fake data: the denoiser predicts the noise, so x ~ y - D(y)
    with torch.no_grad():
        fake_x = y - net['D'](y)
        fake_x = fake_x[:, 1, :, :].unsqueeze(1)
        fake_x_data = torch.cat([fake_x, y_], 1)
    fake_x_loss = net['P'](fake_x_data.detach()).mean()
    grad_x_loss = gradient_penalty(real_data, fake_x_data, net['P'], args['lambda_gp'])
    loss_x = (1 - alpha) * (fake_x_loss - real_loss)
    loss_xg = (1 - alpha) * grad_x_loss
    loss = loss_x + loss_xg + loss_y + loss_yg
    # backward
    loss.backward()
    optimizerP.step()

    return loss, loss_x, loss_xg, loss_y, loss_yg
Example #4
0
File: train.py  Project: oldrive/AttGAN
def train_step_D(x_a, atts_a):
    '''One joint discriminator + classifier training step for AttGAN.

    :param x_a:  real images carrying attribute set a
    :param atts_a:  the attribute vector "a" from the paper -- the
        attributes present in the real images, encoded as 0/1
    :return: combined discriminator + classifier loss (D_and_C_loss)
    '''
    with tf.GradientTape() as tape:
        atts_b = tf.random.shuffle(atts_a)  # attribute vector "b" from the paper
        atts_b = atts_b * 2 - 1  # remap values from 0/1 ==> -1/1

        real_z = G_enc(x_a, training=True)
        x_b = G_dec(real_z + [atts_b], training=True)  # generated images carrying attributes b

        # discriminator loss (WGAN loss + gradient penalty)
        xa_logit_D = D(x_a, training=True)
        xb_logit_D = D(x_b, training=True)
        wgan_d_loss = d_loss_fn(xa_logit_D, xb_logit_D)
        gp = loss.gradient_penalty(functools.partial(D, training=True), x_a,
                                   x_b)
        D_loss = wgan_d_loss + config.GP_WEIGHT * gp

        # attribute classifier loss on the real images
        xa_logit_C = C(x_a, training=True)
        C_loss = tf.reduce_mean(
            tf.losses.binary_crossentropy(atts_a, xa_logit_C))  # binary classification loss

        # reg_loss = tf.reduce_sum(D.func.reg_losses)  # the reference implementation also adds this loss

        D_and_C_loss = D_loss + config.C_ATTRIBUTE_LOSS_WEIGHT * C_loss
    # D and C share one optimizer and are updated from the combined loss.
    D_gradients = tape.gradient(
        D_and_C_loss, [*D.trainable_variables, *C.trainable_variables])
    D_and_C_optimizer.apply_gradients(
        zip(D_gradients, [*D.trainable_variables, *C.trainable_variables]))

    return D_and_C_loss
Example #5
0
def main(opts):
    """End-to-end WGAN-GP training loop.

    Builds the data loader and the generator/discriminator pair,
    optionally resumes from ``opts.resume``, then alternates critic and
    generator updates (generator only every CRITIC_ITER iterations).
    After each epoch a fixed-noise sample grid and a checkpoint are
    written under ``opts.det``, and the loss curves are plotted at the
    end.

    :param opts: parsed options; uses path, batch_size, type, device,
        resume, epoch and det.
    """
    # Create the data loader
    loader = sunnerData.DataLoader(sunnerData.ImageDataset(
        root=[[opts.path]],
        transform=transforms.Compose([
            sunnertransforms.Resize((128, 128)),
            sunnertransforms.ToTensor(),
            sunnertransforms.ToFloat(),
            sunnertransforms.Transpose(sunnertransforms.BHWC2BCHW),
            sunnertransforms.Normalize(),
        ])),
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=4)

    # Create the model
    if opts.type == 'style':
        G = StyleGenerator().to(opts.device)
    else:
        G = Generator().to(opts.device)
    D = Discriminator().to(opts.device)

    # Load the pre-trained weight
    if os.path.exists(opts.resume):
        INFO("Load the pre-trained weight!")
        state = torch.load(opts.resume)
        G.load_state_dict(state['G'])
        D.load_state_dict(state['D'])
    else:
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    # Create the criterion, optimizer and scheduler
    optim_D = optim.Adam(D.parameters(), lr=0.0001, betas=(0.5, 0.999))
    optim_G = optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.999))
    scheduler_D = optim.lr_scheduler.ExponentialLR(optim_D, gamma=0.99)
    scheduler_G = optim.lr_scheduler.ExponentialLR(optim_G, gamma=0.99)

    # Train
    # fix_z: fixed latent batch so the per-epoch sample grids are comparable
    fix_z = torch.randn([opts.batch_size, 512]).to(opts.device)
    Loss_D_list = [0.0]
    Loss_G_list = [0.0]
    for ep in range(opts.epoch):
        bar = tqdm(loader)
        loss_D_list = []
        loss_G_list = []
        for i, (real_img, ) in enumerate(bar):
            # =======================================================================================================
            #   Update discriminator
            # =======================================================================================================
            # Compute adversarial loss toward discriminator
            real_img = real_img.to(opts.device)
            real_logit = D(real_img)
            fake_img = G(torch.randn([real_img.size(0), 512]).to(opts.device))
            fake_logit = D(fake_img.detach())
            # WGAN critic loss plus gradient penalty (fixed weight 10)
            d_loss = -(real_logit.mean() -
                       fake_logit.mean()) + gradient_penalty(
                           real_img.data, fake_img.data, D) * 10.0
            loss_D_list.append(d_loss.item())

            # Update discriminator
            optim_D.zero_grad()
            d_loss.backward()
            optim_D.step()

            # =======================================================================================================
            #   Update generator
            # =======================================================================================================
            # The generator is updated once per CRITIC_ITER critic updates;
            # i == 0 hits this branch, so loss_G_list is non-empty below.
            if i % CRITIC_ITER == 0:
                # Compute adversarial loss toward generator
                fake_img = G(
                    torch.randn([opts.batch_size, 512]).to(opts.device))
                fake_logit = D(fake_img)
                g_loss = -fake_logit.mean()
                loss_G_list.append(g_loss.item())

                # Update generator (clear D's gradients from the G backward)
                D.zero_grad()
                optim_G.zero_grad()
                g_loss.backward()
                optim_G.step()
            bar.set_description(" {} [G]: {} [D]: {}".format(
                ep, loss_G_list[-1], loss_D_list[-1]))

        # Save the result
        Loss_G_list.append(np.mean(loss_G_list))
        Loss_D_list.append(np.mean(loss_D_list))
        fake_img = G(fix_z)
        save_image(fake_img,
                   os.path.join(opts.det, 'images',
                                str(ep) + '.png'),
                   nrow=4,
                   normalize=True)
        state = {
            'G': G.state_dict(),
            'D': D.state_dict(),
            'Loss_G': Loss_G_list,
            'Loss_D': Loss_D_list,
        }
        torch.save(state, os.path.join(opts.det, 'models', 'latest.pth'))

        # Decay both learning rates once per epoch
        scheduler_D.step()
        scheduler_G.step()

    # Plot the total loss curve (drop the initial 0.0 placeholder)
    Loss_D_list = Loss_D_list[1:]
    Loss_G_list = Loss_G_list[1:]
    plotLossCurve(opts, Loss_D_list, Loss_G_list)
Example #6
0
    def trainstep(real_human, real_anime, big_anime):
        """One joint update of the shared-latent translation networks.

        Encodes both domains into a shared latent space, decodes
        reconstructions and cross-domain translations, then applies one
        optimizer step per (loss, variable-group) pair using the
        loss-scaling mixed-precision optimizers in ``optims``.
        """
        with tf.GradientTape(persistent=True) as tape:
            latent_anime = encode_share(encode_anime(real_anime))
            latent_human = encode_share(encode_human(real_human))

            recon_anime = decode_anime(decode_share(latent_anime))
            recon_human = decode_human(decode_share(latent_human))

            fake_anime = decode_anime(decode_share(latent_human))
            latent_human_cycled = encode_share(encode_anime(fake_anime))

            # NOTE(review): the human-direction path reuses decode_anime /
            # encode_anime; decode_human / encode_human look intended here --
            # confirm against the reference implementation.
            fake_human = decode_anime(decode_share(latent_anime))
            latent_anime_cycled = encode_share(encode_anime(fake_human))

            # NOTE(review): kl_loss is defined but never used below, and the
            # usual VAE KL subtracts exp(log_var) rather than adding it --
            # verify before enabling.
            def kl_loss(mean, log_var):
                loss = 1 + log_var - tf.math.square(mean) + tf.math.exp(
                    log_var)
                loss = tf.reduce_sum(loss, axis=-1) * -0.5
                return loss

            disc_fake = D(fake_anime)
            disc_real = D(real_anime)

            # Domain classifier on the shared latents
            c_dann_anime = c_dann(latent_anime)
            c_dann_human = c_dann(latent_human)

            loss_anime_encode = identity_loss(real_anime, recon_anime) * 3
            loss_human_encode = identity_loss(real_human, recon_human) * 3

            # Domain-adversarial loss on the shared latents
            # (anime labeled 0, human labeled 1).
            loss_domain_adversarial = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.zeros_like(c_dann_anime),
                    logits=c_dann_anime)) + tf.reduce_mean(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                            labels=tf.ones_like(c_dann_human),
                            logits=c_dann_human))
            # clamp for stability, then down-weight
            loss_domain_adversarial = tf.math.minimum(loss_domain_adversarial,
                                                      100)
            loss_domain_adversarial = loss_domain_adversarial * 0.2
            tf.print(loss_domain_adversarial)

            loss_semantic_consistency = (
                identity_loss(latent_anime, latent_anime_cycled) * 3 +
                identity_loss(latent_human, latent_human_cycled) * 3)

            loss_gan = w_g_loss(disc_fake)

            # Per-network totals; each is paired with one variable group
            # and one optimizer below.
            anime_encode_total_loss = (loss_anime_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency + loss_gan)
            human_encode_total_loss = (loss_human_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency)
            share_encode_total_loss = (loss_anime_encode +
                                       loss_domain_adversarial +
                                       loss_semantic_consistency + loss_gan +
                                       loss_human_encode)

            share_decode_total_loss = loss_anime_encode + loss_human_encode + loss_gan
            anime_decode_total_loss = loss_anime_encode + loss_gan
            human_decode_total_loss = loss_human_encode

            # loss_disc = (
            #     mse_loss(tf.ones_like(disc_fake), disc_fake)
            #     + mse_loss(tf.zeros_like(disc_real), disc_real)
            # ) * 10
            loss_disc = w_d_loss(disc_real, disc_fake)
            loss_disc += gradient_penalty(partial(D, training=True),
                                          real_anime, fake_anime)

            # Order must match list_variables / optims below.
            losses = [
                anime_encode_total_loss, human_encode_total_loss,
                share_encode_total_loss, loss_domain_adversarial,
                share_decode_total_loss, anime_decode_total_loss,
                human_decode_total_loss, loss_disc
            ]

            # Mixed precision: scale each loss with its own optimizer.
            scaled_losses = [
                optim.get_scaled_loss(loss)
                for optim, loss in zip(optims, losses)
            ]

        list_variables = [
            encode_anime.trainable_variables, encode_human.trainable_variables,
            encode_share.trainable_variables, c_dann.trainable_variables,
            decode_share.trainable_variables, decode_anime.trainable_variables,
            decode_human.trainable_variables, D.trainable_variables
        ]
        # Persistent tape: one gradient call per (loss, variable-group) pair.
        gan_grad = [
            tape.gradient(scaled_loss, train_variable) for scaled_loss,
            train_variable in zip(scaled_losses, list_variables)
        ]
        gan_grad = [
            optim.get_unscaled_gradients(x)
            for optim, x in zip(optims, gan_grad)
        ]
        for optim, grad, trainable in zip(optims, gan_grad, list_variables):
            optim.apply_gradients(zip(grad, trainable))
        # dis_grad = dis_optim.get_unscaled_gradients(
        #     tape.gradient(scaled_loss_disc, D.trainable_variables)
        # )
        # dis_optim.apply_gradients(zip(dis_grad, D.trainable_variables))

        return (real_human, real_anime, recon_anime, recon_human, fake_anime,
                fake_human, loss_anime_encode, loss_human_encode,
                loss_domain_adversarial, loss_semantic_consistency,
                loss_gan, loss_disc, tf.reduce_mean(disc_fake),
                tf.reduce_mean(disc_real))
Example #7
0
    def trainstep_D(
        real_human,
        real_anime,
        big_anime,
        fake_anime,
        cycled_anime,
        same_anime,
        fake_human,
        cycled_human,
        same_human,
        fake_anime_upscale,
        cycled_anime_upscale,
        same_anime_upscale,
    ):
        """One WGAN-GP update for all three discriminators.

        Scores real and generated samples with the human, anime and
        anime-upscale discriminators, adds a gradient penalty per
        discriminator, and applies one optimizer step per discriminator
        from a persistent tape.
        """
        with tf.GradientTape(persistent=True) as tape:
            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)
            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)
            # disc_same_upscale = discriminator_anime_upscale(
            #     same_anime_upscale, training=True
            # )

            # Gradient penalties (1-Lipschitz regularization), one per
            # discriminator, interpolating between its real and fake inputs.
            discriminator_human_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_human, training=True),
                real_human,
                fake_human,
            )
            discriminator_anime_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_anime, training=True),
                real_anime,
                fake_anime,
            )
            # Fix: penalize the upscale discriminator itself.  The original
            # passed discriminator_human here, so the penalty contributed no
            # gradient to discriminator_anime_upscale's variables.
            discriminator_upscale_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_anime_upscale, training=True),
                big_anime,
                fake_anime_upscale,
            )

            disc_human_loss = (w_d_loss(disc_real_human, disc_fake_human) +
                               discriminator_human_gradient_penalty)
            disc_anime_loss = (w_d_loss(disc_real_anime, disc_fake_anime) +
                               discriminator_anime_gradient_penalty)
            disc_upscale_loss = (w_d_loss(disc_real_big, disc_fake_upscale) +
                                 discriminator_upscale_gradient_penalty)
            tf.print("disc_real_big", disc_real_big)
            tf.print("disc_fake_upscale", disc_fake_upscale)
            tf.print("disc_upscale_loss", disc_upscale_loss)

        # One gradient computation and one optimizer step per discriminator.
        discriminator_human_gradients = tape.gradient(
            disc_human_loss, discriminator_human.trainable_variables)
        discriminator_anime_gradients = tape.gradient(
            disc_anime_loss, discriminator_anime.trainable_variables)
        discriminator_upscale_gradients = tape.gradient(
            disc_upscale_loss, discriminator_anime_upscale.trainable_variables)
        discriminator_human_optimizer.apply_gradients(
            zip(discriminator_human_gradients,
                discriminator_human.trainable_variables))
        discriminator_anime_optimizer.apply_gradients(
            zip(discriminator_anime_gradients,
                discriminator_anime.trainable_variables))
        discriminator_anime_upscale_optimizer.apply_gradients(
            zip(
                discriminator_upscale_gradients,
                discriminator_anime_upscale.trainable_variables,
            ))
    def trainstep(real_human, real_anime, big_anime):
        """One joint CycleGAN + upscale training step.

        Runs both translation directions plus the anime upscaler, computes
        generator losses (adversarial + cycle + identity) and discriminator
        losses (WGAN + gradient penalty), then applies one optimizer step
        per network from a persistent tape.
        """

        with tf.GradientTape(persistent=True) as tape:

            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            # NOTE(review): print() only fires on (re)trace if this is wrapped
            # in tf.function; use tf.print for per-step output.
            print("generator_to_anime", generator_to_anime.count_params())

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)
            print("discriminator_human", discriminator_human.count_params())

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            fake_anime_upscale = generator_anime_upscale(fake_anime,
                                                         training=True)
            real_anime_upscale = generator_anime_upscale(real_anime,
                                                         training=True)

            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)

            disc_real_upscale = discriminator_anime_upscale(real_anime_upscale,
                                                            training=True)
            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)
            # assert()
            # calculate the loss
            gen_anime_loss = w_g_loss(disc_fake_anime)
            gen_human_loss = w_g_loss(disc_fake_human)

            total_cycle_loss = cycle_loss(real_human,
                                          cycled_human) + cycle_loss(
                                              real_anime, cycled_anime)

            # Total generator loss = adversarial loss + cycle loss
            total_gen_anime_loss = (gen_anime_loss + total_cycle_loss +
                                    identity_loss(real_anime, same_anime))

            total_gen_human_loss = (gen_human_loss + total_cycle_loss +
                                    identity_loss(real_human, same_human))

            gen_upscale_loss = (
                w_g_loss(disc_fake_upscale) + w_g_loss(disc_real_upscale)
                # + mse_loss(big_anime, real_anime_upscale) * 0.1
                + identity_loss(big_anime, real_anime_upscale) * 0.3)

            # Gradient penalties (weight 10 / 5 per term)
            discriminator_human_gradient_penalty = (gradient_penalty(
                functools.partial(discriminator_human, training=True),
                real_human,
                fake_human,
            ) * 10)
            discriminator_anime_gradient_penalty = (gradient_penalty(
                functools.partial(discriminator_anime, training=True),
                real_anime,
                fake_anime,
            ) * 10)
            # NOTE(review): both upscale penalties pass discriminator_human,
            # yet they feed disc_upscale_loss, whose gradients are taken
            # w.r.t. discriminator_anime_upscale -- these look like they
            # should use discriminator_anime_upscale; confirm.
            discriminator_upscale_gradient_penalty = (gradient_penalty(
                functools.partial(discriminator_human, training=True),
                big_anime,
                fake_anime_upscale,
            ) * 5)
            discriminator_upscale_gradient_penalty += (gradient_penalty(
                functools.partial(discriminator_human, training=True),
                big_anime,
                real_anime_upscale,
            ) * 5)

            disc_human_loss = (w_d_loss(disc_real_human, disc_fake_human) +
                               discriminator_human_gradient_penalty)
            disc_anime_loss = (w_d_loss(disc_real_anime, disc_fake_anime) +
                               discriminator_anime_gradient_penalty)
            # # print("ggg",big_anime.shape)
            disc_upscale_loss = w_d_loss(disc_real_big, disc_fake_upscale)
            disc_upscale_loss += (w_d_loss(disc_real_big, disc_real_upscale) +
                                  discriminator_upscale_gradient_penalty)

        # One gradient computation per network from the persistent tape.
        generator_to_anime_gradients = tape.gradient(
            total_gen_anime_loss, generator_to_anime.trainable_variables)
        generator_to_human_gradients = tape.gradient(
            total_gen_human_loss, generator_to_human.trainable_variables)
        generator_upscale_gradients = tape.gradient(
            gen_upscale_loss, generator_anime_upscale.trainable_variables)

        discriminator_human_gradients = tape.gradient(
            disc_human_loss, discriminator_human.trainable_variables)
        discriminator_anime_gradients = tape.gradient(
            disc_anime_loss, discriminator_anime.trainable_variables)

        discriminator_upscale_gradients = tape.gradient(
            disc_upscale_loss, discriminator_anime_upscale.trainable_variables)

        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))

        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))

        generator_anime_upscale_optimizer.apply_gradients(
            zip(generator_upscale_gradients,
                generator_anime_upscale.trainable_variables))

        discriminator_human_optimizer.apply_gradients(
            zip(discriminator_human_gradients,
                discriminator_human.trainable_variables))

        discriminator_anime_optimizer.apply_gradients(
            zip(discriminator_anime_gradients,
                discriminator_anime.trainable_variables))

        discriminator_anime_upscale_optimizer.apply_gradients(
            zip(
                discriminator_upscale_gradients,
                discriminator_anime_upscale.trainable_variables,
            ))

        return [
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_upscale,
            real_anime_upscale,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            total_gen_anime_loss,
            total_gen_human_loss,
            gen_upscale_loss,
            disc_upscale_loss,
        ]
Example #9
0
    def train_epoch(self):
        """Train all models for one epoch of text style transfer.

        Iterates the two style-domain loaders in lockstep.  Each step
        first updates the discriminators on detached hidden sequences
        (with an optional WGAN-GP penalty), then updates the
        encoder/generator from the reconstruction loss plus -- gated by
        the threshold / two-stage settings -- the adversarial loss.
        """
        self.models.train()
        self.epoch += 1

        # record training statistics
        avg_meters = {
            'loss_rec': AverageMeter('Loss Rec', ':.4e'),
            'loss_adv': AverageMeter('Loss Adv', ':.4e'),
            'loss_disc': AverageMeter('Loss Disc', ':.4e'),
            'time': AverageMeter('Time', ':6.3f')
        }
        progress_meter = ProgressMeter(len(self.train_loaders[0]),
                                       avg_meters.values(),
                                       prefix="Epoch: [{}]".format(self.epoch))

        # begin training from minibatches
        for ix, (data_0, data_1) in enumerate(zip(*self.train_loaders)):
            start_time = time.time()

            # load text and labels
            src_0, src_len_0, labels_0 = data_0
            src_0, labels_0 = src_0.to(args.device), labels_0.to(args.device)
            src_1, src_len_1, labels_1 = data_1
            src_1, labels_1 = src_1.to(args.device), labels_1.to(args.device)

            # encode
            encoder = self.models['encoder']
            z_0 = encoder(labels_0, src_0, src_len_0)  # (batch_size, dim_z)
            z_1 = encoder(labels_1, src_1, src_len_1)

            # recon & transfer
            # (last generator flag False -> reconstruct, True -> transfer)
            generator = self.models['generator']
            inputs_0 = (z_0, labels_0, src_0)
            h_ori_seq_0, pred_ori_0 = generator(*inputs_0, src_len_0, False)
            h_trans_seq_0_to_1, _ = generator(*inputs_0, src_len_1, True)

            inputs_1 = (z_1, labels_1, src_1)
            h_ori_seq_1, pred_ori_1 = generator(*inputs_1, src_len_1, False)
            h_trans_seq_1_to_0, _ = generator(*inputs_1, src_len_0, True)

            # discriminate real and transfer
            # (detached so the discriminator step doesn't touch the generator)
            disc_0, disc_1 = self.models['disc_0'], self.models['disc_1']
            d_0_real = disc_0(h_ori_seq_0.detach())  # detached
            d_0_fake = disc_0(h_trans_seq_1_to_0.detach())
            d_1_real = disc_1(h_ori_seq_1.detach())
            d_1_fake = disc_1(h_trans_seq_0_to_1.detach())

            # discriminator loss
            loss_disc = (loss_fn(args.gan_type)(d_0_real, self.ones) +
                         loss_fn(args.gan_type)(d_0_fake, self.zeros) +
                         loss_fn(args.gan_type)(d_1_real, self.ones) +
                         loss_fn(args.gan_type)(d_1_fake, self.zeros))
            # gradient penalty (wgan-gp only)
            if args.gan_type == 'wgan-gp':
                loss_disc += args.gp_weight * gradient_penalty(
                    h_ori_seq_0,  # real data for 0
                    h_trans_seq_1_to_0,  # fake data for 0
                    disc_0)
                loss_disc += args.gp_weight * gradient_penalty(
                    h_ori_seq_1,  # real data for 1
                    h_trans_seq_0_to_1,  # fake data for 1
                    disc_1)
            avg_meters['loss_disc'].update(loss_disc.item(), src_0.size(0))

            self.disc_optimizer.zero_grad()
            loss_disc.backward()
            self.disc_optimizer.step()

            # reconstruction loss
            loss_rec = (
                F.cross_entropy(  # Recon 0 -> 0
                    pred_ori_0.view(-1, pred_ori_0.size(-1)),
                    src_0[1:].view(-1),
                    ignore_index=bert_tokenizer.pad_token_id,
                    reduction='sum') + F.cross_entropy(  # Recon 1 -> 1
                        pred_ori_1.view(-1, pred_ori_1.size(-1)),
                        src_1[1:].view(-1),
                        ignore_index=bert_tokenizer.pad_token_id,
                        reduction='sum')) / (
                            2.0 * args.batch_size
                        )  # match scale with the original paper
            avg_meters['loss_rec'].update(loss_rec.item(), src_0.size(0))

            # generator loss
            d_0_fake = disc_0(h_trans_seq_1_to_0)  # not detached
            d_1_fake = disc_1(h_trans_seq_0_to_1)
            loss_adv = (loss_fn(args.gan_type, disc=False)
                        (d_0_fake, self.ones) +
                        loss_fn(args.gan_type, disc=False)(d_1_fake, self.ones)
                        ) / 2.0  # match scale with the original paper
            avg_meters['loss_adv'].update(loss_adv.item(), src_0.size(0))

            # XXX: threshold for training stability
            if (not args.two_stage):
                if (args.threshold is not None and loss_disc < args.threshold):
                    loss = loss_rec + args.rho * loss_adv
                else:
                    loss = loss_rec
            else:  # two_stage training
                if (args.second_stage_num > args.epochs - self.epoch):
                    # last second_stage; flow loss_adv gradients
                    loss = loss_rec + args.rho * loss_adv
                else:
                    loss = loss_rec
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            avg_meters['time'].update(time.time() - start_time)

            # log progress
            if (ix + 1) % args.log_interval == 0:
                progress_meter.display(ix + 1)

        progress_meter.display(len(self.train_loaders[0]))
0
    def trainstep_D(
        real_human,
        real_anime,
        big_anime,
        fake_anime,
        cycled_anime,
        same_anime,
        fake_human,
        cycled_human,
        same_human,
        fake_anime_upscale,
        cycled_anime_upscale,
        same_anime_upscale,
    ):
        """One mixed-precision WGAN-GP update for all three discriminators.

        Scores real and generated samples with the human, anime and
        anime-upscale discriminators, adds gradient penalties, scales each
        loss with its LossScaleOptimizer, then unscales the gradients and
        applies one optimizer step per discriminator.
        """
        with tf.GradientTape(persistent=True) as tape:
            disc_real_human = discriminator_human(real_human, training=True)
            disc_real_anime = discriminator_anime(real_anime, training=True)

            disc_fake_human = discriminator_human(fake_human, training=True)
            disc_fake_anime = discriminator_anime(fake_anime, training=True)

            disc_real_big = discriminator_anime_upscale(big_anime,
                                                        training=True)
            disc_fake_upscale = discriminator_anime_upscale(fake_anime_upscale,
                                                            training=True)
            disc_cycled_upscale = discriminator_anime_upscale(
                cycled_anime_upscale, training=True)
            disc_same_upscale = discriminator_anime_upscale(same_anime_upscale,
                                                            training=True)

            # Gradient penalties (1-Lipschitz regularization), one per
            # discriminator, interpolating between its real and fake inputs.
            discriminator_human_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_human, training=True),
                real_human,
                fake_human,
            )
            discriminator_anime_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_anime, training=True),
                real_anime,
                fake_anime,
            )
            # Fix: accumulate the three upscale penalties (the original
            # reassigned with '=', discarding the first two) and penalize
            # the upscale discriminator itself rather than
            # discriminator_human, whose penalty contributed no gradient to
            # discriminator_anime_upscale's variables.
            discriminator_upscale_gradient_penalty = gradient_penalty(
                functools.partial(discriminator_anime_upscale, training=True),
                big_anime,
                fake_anime_upscale,
            )
            discriminator_upscale_gradient_penalty += gradient_penalty(
                functools.partial(discriminator_anime_upscale, training=True),
                big_anime,
                cycled_anime_upscale,
            )
            discriminator_upscale_gradient_penalty += gradient_penalty(
                functools.partial(discriminator_anime_upscale, training=True),
                big_anime,
                same_anime_upscale,
            )

            disc_human_loss = (w_d_loss(disc_real_human, disc_fake_human) +
                               discriminator_human_gradient_penalty)
            disc_anime_loss = (w_d_loss(disc_real_anime, disc_fake_anime) +
                               discriminator_anime_gradient_penalty)
            # average the three upscale critic terms (fake / cycled / same)
            disc_upscale_loss = (w_d_loss(disc_real_big, disc_fake_upscale) +
                                 w_d_loss(disc_real_big, disc_cycled_upscale) +
                                 w_d_loss(disc_real_big, disc_same_upscale) +
                                 discriminator_upscale_gradient_penalty) / 3.0
            # Mixed precision: scale each loss with its own optimizer.
            scaled_disc_human_loss = discriminator_human_optimizer.get_scaled_loss(
                disc_human_loss)
            scaled_disc_anime_loss = discriminator_anime_optimizer.get_scaled_loss(
                disc_anime_loss)
            scaled_disc_upscale_loss = discriminator_anime_upscale_optimizer.get_scaled_loss(
                disc_upscale_loss)

        # Unscale gradients, then one optimizer step per discriminator.
        discriminator_human_gradients = discriminator_human_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_human_loss,
                          discriminator_human.trainable_variables))
        discriminator_anime_gradients = discriminator_anime_optimizer.get_unscaled_gradients(
            tape.gradient(scaled_disc_anime_loss,
                          discriminator_anime.trainable_variables))
        discriminator_upscale_gradients = discriminator_anime_upscale_optimizer.get_unscaled_gradients(
            tape.gradient(
                scaled_disc_upscale_loss,
                discriminator_anime_upscale.trainable_variables,
            ))
        discriminator_human_optimizer.apply_gradients(
            zip(discriminator_human_gradients,
                discriminator_human.trainable_variables))
        discriminator_anime_optimizer.apply_gradients(
            zip(discriminator_anime_gradients,
                discriminator_anime.trainable_variables))
        discriminator_anime_upscale_optimizer.apply_gradients(
            zip(
                discriminator_upscale_gradients,
                discriminator_anime_upscale.trainable_variables,
            ))