Example #1
def calc_loss_1(args, generator, discriminator, real_img, ada_aug_p):
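    # Discriminator logistic loss for one batch: real images vs. fresh
    # generator samples, with optional ADA augmentation applied to both.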
    losses = {}
    noise = mixing_noise(args.batch, args.latent, args.mixing, args.device)
    fake_img, _ = generator(noise)

    if args.augment:
        real_img_aug, _ = augment(real_img, ada_aug_p)
        fake_img, _ = augment(fake_img, ada_aug_p)
    else:
        real_img_aug = real_img

    fake_pred = discriminator(fake_img)
    real_pred = discriminator(real_img_aug)
    d_loss = d_logistic_loss(real_pred, fake_pred)

    losses['d'] = d_loss
    losses['real_score'] = real_pred.mean()
    losses['fake_score'] = fake_pred.mean()

    return losses, real_pred
Example #2
def calc_loss_3(args, generator, discriminator, ada_aug_p):
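    # Generator-side loss only: non-saturating logistic loss on the
    # discriminator's scores for fresh samples.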
    losses = {}

    noise = mixing_noise(args.batch, args.latent, args.mixing, args.device)
    fake_img, _ = generator(noise)

    if args.augment:
        fake_img, _ = augment(fake_img, ada_aug_p)

    fake_pred = discriminator(fake_img)
    g_loss = g_nonsaturating_loss(fake_pred)
    losses['g'] = g_loss

    return losses
Example #3
def train(args, loader, loader2, generator, encoder, discriminator, vggnet,
          g_optim, e_optim, d_optim, g_ema, e_ema, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    r1_val = real_score_val = recx_score_val = 0
    loss_dict = {
        "d": torch.tensor(0.0, device=device),
        "r1": torch.tensor(0.0, device=device)
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d_weight = torch.tensor(1.0, device=device)
    last_layer = None
    if args.use_adaptive_weight:
        if args.distributed:
            last_layer = generator.module.get_last_layer()
        else:
            last_layer = generator.get_last_layer()

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    r_t_dict = {'real': 0, 'recx': 0}  # r_t stat
    g_scale = 1
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length,
                                      args.ada_every, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

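    # The D and E loops below may run different numbers of steps; fetch
    # enough real batches per iteration to cover the longer of the two.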
    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        if args.debug: util.seed_everything(i)
        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator
        if args.lambda_adv > 0:
            requires_grad(generator, False)
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
            for step_index in range(args.n_step_d):
                real_img = real_imgs[step_index]
                latent_real, _ = encoder(real_img)
                rec_img, _ = generator([latent_real], input_is_latent=True)
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                    rec_img_aug = rec_img
                real_pred = discriminator(real_img_aug)
                rec_pred = discriminator(rec_img_aug)
                d_loss_real = F.softplus(-real_pred).mean()
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["real_score"] = real_pred.mean()
                loss_dict["recx_score"] = rec_pred.mean()

                d_loss = d_loss_real + d_loss_rec * args.lambda_rec_d
                loss_dict["d"] = d_loss

                discriminator.zero_grad()
                d_loss.backward()
                d_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat
            # Compute batchwise r_t
            r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch

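            # Lazy R1 regularization: apply the gradient penalty on reals only
            # every d_reg_every iterations, scaled to compensate.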
            d_regularize = i % args.d_reg_every == 0
            if d_regularize:
                real_img.requires_grad = True
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                real_pred = discriminator(real_img_aug)
                r1_loss = d_r1_loss(real_pred, real_img)
                discriminator.zero_grad()
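                # The 0 * real_pred[0] term keeps the D output in the graph so
                # DDP sees every parameter as used
                # (see rosinality/stylegan2-pytorch issue #76).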
                (args.r1 / 2 * r1_loss * args.d_reg_every +
                 0 * real_pred[0]).backward()
                d_optim.step()
            loss_dict["r1"] = r1_loss

            r_t_dict['recx'] = torch.sign(rec_pred).sum().item() / args.batch

        # Train AutoEncoder
        requires_grad(encoder, True)
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        if args.debug: util.seed_everything(i)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        for step_index in range(args.n_step_e):
            real_img = real_imgs[step_index]
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img)**2)
                elif args.pix_loss == 'l1':
                    pix_loss = F.l1_loss(rec_img, real_img)
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)
            if args.lambda_adv > 0:
                if args.augment:
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    rec_img_aug = rec_img
                rec_pred = discriminator(rec_img_aug)
                adv_loss = g_nonsaturating_loss(rec_pred)

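            # Adaptive adversarial weight (as in VQGAN): balance the
            # adversarial term against the reconstruction term using gradient
            # norms at the generator's last layer.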
            if args.use_adaptive_weight and i >= args.disc_iter_start:
                nll_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg
                g_loss = adv_loss * args.lambda_adv
                d_weight = calculate_adaptive_weight(nll_loss,
                                                     g_loss,
                                                     last_layer=last_layer)

            ae_loss = (pix_loss * args.lambda_pix +
                       vgg_loss * args.lambda_vgg +
                       d_weight * adv_loss * args.lambda_adv)
            loss_dict["ae"] = ae_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            encoder.zero_grad()
            generator.zero_grad()
            ae_loss.backward()
            e_optim.step()
            if args.g_decay is not None:
                scale_grad(generator, g_scale)
                g_scale *= args.g_decay
            g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update EMA
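        # Per-image decay: the EMA half-life equals ema_kimg thousand images,
        # with an optional linear ramp-up early in training.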
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5**(args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, 0 if args.no_ema_g else accum)
        accumulate(e_ema, e_module, 0 if args.no_ema_e else accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        ae_loss_val = loss_reduced["ae"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        if args.lambda_adv > 0:
            d_loss_val = loss_reduced["d"].mean().item()
            r1_val = loss_reduced["r1"].mean().item()
            real_score_val = loss_reduced["real_score"].mean().item()
            recx_score_val = loss_reduced["recx_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; ae: {ae_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"d_weight: {d_weight.item():.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"
            ))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    # Reconstruction of real images
                    latent_x, _ = e_ema(sample_x)
                    rec_real, _ = g_ema([latent_x], input_is_latent=True)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         rec_real.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    ref_pix_loss = torch.sum(torch.abs(sample_x - rec_real))
                    ref_vgg_loss = torch.mean(
                        (vggnet(sample_x) - vggnet(rec_real))**2
                    ) if vggnet is not None else torch.tensor(0., device=device)
                    # Fixed fake samples and reconstructions
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; "
                        f"real_score: {real_score_val:.4f}; recx_score: {recx_score_val:.4f}; "
                        f"pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; "
                        f"ref_pix: {ref_pix_loss.item():.4f}; ref_vgg: {ref_vgg_loss.item():.4f}; "
                        f"d_weight: {d_weight.item():.4f}; "
                        f"\n"))

            if wandb and args.wandb:
                wandb.log({
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Path Length": path_length_val,
                })

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    fid_sa = fid_re = fid_sr = 0
                    g_ema.eval()
                    e_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    # Real reconstruction FID
                    if 'fid_recon' in args.which_metric:
                        features = extract_feature_from_reconstruction(
                            e_ema,
                            g_ema,
                            inception,
                            args.truncation,
                            mean_latent,
                            loader2,
                            args.device,
                            mode='recon',
                        ).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_re = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(f"{i:07d}; rec_real: {float(fid_re):.4f};\n")

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"latest.pt"),
                )
Example #4
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device, save_dir):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

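    # Fixed EMA decay: a half-life of roughly 10k images assuming batch
    # size 32 (the standard StyleGAN2 setting).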
    accum = 0.5**(32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

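        # Manual ADA tuning: accumulate the sign of the real scores across
        # ranks and nudge the augmentation probability toward ada_target.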
        if args.augment and args.augment_p == 0:
            ada_augment_data = torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment += reduce_sum(ada_augment_data)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > args.ada_target:
                    sign = 1

                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 1000 == 0:  # save some samples

                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])

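                    # Tile n_sample pairs into one grid: each 6-channel sample
                    # is split into two RGB images laid side by side, with
                    # black bars separating cells.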
                    length = int(round(np.sqrt(args.n_sample), 0))
                    black_bar_width = 7
                    template = np.zeros((length * int(args.size) +
                                         ((length - 1) * black_bar_width),
                                         length * (int(args.size) * 2) +
                                         ((length - 1) * black_bar_width), 3))
                    position = [0, 0]
                    line_count = 1

                    for j in range(args.n_sample):

                        # save the first image from this pair
                        utils.save_image(torch.from_numpy(
                            sample.detach().cpu().numpy()[j, :3, :, :]),
                                         "temp.png",
                                         nrow=int(args.n_sample**0.5),
                                         normalize=True,
                                         value_range=(-1, 1))

                        first_image = np.asarray(Image.open("temp.png"))

                        # save the second image from this pair
                        utils.save_image(torch.from_numpy(
                            sample.detach().cpu().numpy()[j, 3:, :, :]),
                                         "temp.png",
                                         nrow=int(args.n_sample**0.5),
                                         normalize=True,
                                         value_range=(-1, 1))

                        second_image = np.asarray(Image.open("temp.png"))

                        # --------------------------------------------------------------------------
                        # unfold the pair and add black borders, if needed
                        # --------------------------------------------------------------------------
                        unfolded = np.column_stack((first_image, second_image))

                        # if needed, add a black column
                        if (((j + 1) % length) != 0):
                            unfolded = np.column_stack(
                                (unfolded,
                                 np.zeros((args.size, black_bar_width, 3))))

                        # if needed, add a black line
                        if ((line_count % length) != 0):
                            unfolded = np.vstack(
                                (unfolded,
                                 np.zeros(
                                     (black_bar_width, unfolded.shape[1], 3))))

                        # ------------------------------------------------------------
                        # add the pair to the template
                        # ------------------------------------------------------------
                        if ((line_count % length) !=
                                0):  # it's not the last line
                            if (((j + 1) % length) !=
                                    0):  # it's not the last column
                                template[position[0]:position[0] + args.size +
                                         black_bar_width,
                                         position[1]:position[1] +
                                         (args.size * 2) +
                                         black_bar_width, :] = unfolded

                            else:  # it's the last column
                                template[position[0]:position[0] + args.size +
                                         black_bar_width,
                                         position[1]:position[1] +
                                         (args.size * 2), :] = unfolded

                        else:  # it's the last line
                            if (((j + 1) % length) !=
                                    0):  # it's not the last column
                                template[position[0]:position[0] + args.size,
                                         position[1]:position[1] +
                                         (args.size * 2) +
                                         black_bar_width, :] = unfolded

                            else:  # it's the last column
                                template[position[0]:position[0] + args.size,
                                         position[1]:position[1] +
                                         (args.size * 2), :] = unfolded

                        # -------------------------------------------------------------
                        # prepare the next iteration
                        # -------------------------------------------------------------
                        if (((j + 1) % length) == 0):
                            position = [position[0] + unfolded.shape[0], 0]
                            line_count += 1

                        else:
                            position = [
                                position[0], position[1] + unfolded.shape[1]
                            ]

                    Image.fromarray(template.astype(np.uint8)).convert(
                        "RGBA").save(
                            os.path.join(save_dir, "samples",
                                         f"{str(i).zfill(6)}.png"))
                    os.remove("temp.png")

            if i % 2000 == 0:  # save the model
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    save_dir + f"/checkpoints/{str(i).zfill(6)}.pt",
                )
Example #5
    def training_step(self, batch, batch_idx, optimizer_idx):
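        # PyTorch Lightning step: optimizer_idx selects the D update (0),
        # R1 regularization (1), G update (2), or path-length
        # regularization (3).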
        real_img = batch[0]
        batch_size = real_img.shape[0]
        if optimizer_idx == 0:
            requires_grad(self.generator, False)
            requires_grad(self.discriminator, True)

            noise = mixing_noise(batch_size, self.hparams.latent,
                                 self.hparams.mixing, self.device)
            fake_img, _ = self.generator(noise)

            if self.hparams.augment:
                real_img_aug, _ = augment(real_img, self.ada_aug_p)
                fake_img, _ = augment(fake_img, self.ada_aug_p)
            else:
                real_img_aug = real_img

            fake_pred = self.discriminator(fake_img)
            real_pred = self.discriminator(real_img_aug)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            self.log('real_score', real_pred.mean())
            self.log('fake_score', fake_pred.mean())
            self.log('d', d_loss, prog_bar=True)

            if self.use_ada_augment:
                self.ada_aug_p = self.ada_augment.tune(real_pred)
            if self.hparams.augment:
                self.log('ada_aug_p', self.ada_aug_p, prog_bar=True)

            return {'loss': d_loss}
        if optimizer_idx == 1:
            if self.trainer.global_step % self.hparams.d_reg_every != 0:
                return

            requires_grad(self.generator, False)
            requires_grad(self.discriminator, True)
            real_img.requires_grad = True
            if self.hparams.augment:
                real_img_aug, _ = augment(real_img, self.ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = self.discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            self.log('r1', r1_loss, prog_bar=True)
            return {
                'loss':
                (self.hparams.r1 / 2 * r1_loss * self.hparams.d_reg_every +
                 0 * real_pred[0])
            }
        if optimizer_idx == 2:
            requires_grad(self.generator, True)
            requires_grad(self.discriminator, False)

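            # Top-k training: score several candidate noise batches with D
            # and keep only the best-scoring samples for this G update.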
            if self.hparams.top_k_batches > 0:
                with torch.no_grad():
                    noises, scores = [], []
                    rand = random.random()
                    for _ in range(self.hparams.top_k_batches):
                        noise = mixing_noise(self.hparams.top_k_batch_size,
                                             self.hparams.latent,
                                             self.hparams.mixing,
                                             self.device,
                                             rand=rand)
                        fake_img, _ = self.generator(noise)
                        if self.hparams.augment:
                            fake_img, _ = augment(fake_img, self.ada_aug_p)
                        score = self.discriminator(fake_img)
                        noises.append(noise)
                        scores.append(score)
                    scores = torch.cat(scores)
                    best_score_ids = torch.argsort(
                        scores, descending=True)[:batch_size].squeeze(1)
                    noise = [
                        torch.cat([n[idx]
                                   for n in noises], dim=0)[best_score_ids]
                        for idx in range(len(noises[0]))
                    ]
            else:
                noise = mixing_noise(batch_size, self.hparams.latent,
                                     self.hparams.mixing, self.device)

            fake_img, _ = self.generator(noise)

            if self.hparams.augment:
                fake_img, _ = augment(fake_img, self.ada_aug_p)

            fake_pred = self.discriminator(fake_img)
            g_loss = g_nonsaturating_loss(fake_pred)

            self.log('g', g_loss, prog_bar=True)
            return {'loss': g_loss}
        if optimizer_idx == 3:
            if self.trainer.global_step % self.hparams.g_reg_every != 0:
                return

            requires_grad(self.generator, True)
            requires_grad(self.discriminator, False)
            path_batch_size = max(1,
                                  batch_size // self.hparams.path_batch_shrink)
            noise = mixing_noise(path_batch_size, self.hparams.latent,
                                 self.hparams.mixing, self.device)
            fake_img, latents = self.generator(noise, return_latents=True)

            path_loss, self.mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, self.mean_path_length)

            weighted_path_loss = self.hparams.path_regularize * self.hparams.g_reg_every * path_loss

            if self.hparams.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            self.log('path', path_loss, prog_bar=True)
            self.log('mean_path', self.mean_path_length, prog_bar=True)
            self.log('path_length', path_lengths.mean())
            return {'loss': weighted_path_loss}
Example #6
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader, datatype="imagefolder")
    # inception related:
    if (get_rank() == 0):
        from calc_inception import load_patched_inception_v3
        inception = load_patched_inception_v3().to(device)
        inception.eval()
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n")

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img, args)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # Path-length regularization is optional here, gated by args.useG_reg
        g_regularize = False
        if args.useG_reg:
            g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                    f"augment: {ada_aug_p:.4f}"
                )
            )

            # inception related:
            if args.eval_every > 0 and i % args.eval_every == 0:
                real_mean = real_cov = mean_latent = None
                with open(args.inception, "rb") as f:
                    embeds = pickle.load(f)
                    real_mean = embeds["mean"]
                    real_cov = embeds["cov"]
                    # print("yahooo!\n")
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    # print("I am fine sir!\n")
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, device
                    ).numpy()
                    # print("I am normal sir!")
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                    with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                        f.write(f"{i:07d}; fid: {float(fid):.4f};\n")
                    # print("alright hurray \n")

            if i % args.log_every == 0:
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        (
                            f"{i:07d}; "
                            f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                            f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                            f"augment: {ada_aug_p:.4f};\n"
                        )
                    )

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}.png"),
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"latest.pt"),
                )
Example #7
def train(args, loader, loader2, encoder, generator, discriminator, vggnet,
          pwcnet, e_optim, d_optim, e_ema, pca_state, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {
        "d": torch.tensor(0., device=device),
        "real_score": torch.tensor(0., device=device),
        "fake_score": torch.tensor(0., device=device),
        "r1_d": torch.tensor(0., device=device),
        "r1_e": torch.tensor(0., device=device),
        "rec": torch.tensor(0., device=device),
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length,
                                      args.ada_every, device)

    # sample_x = accumulate_batches(loader, args.n_sample).to(device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    input_is_latent = args.latent_space != 'z'  # True when the encoder outputs W-space latents rather than z

    requires_grad(generator, False)  # always False
    generator.eval()  # Generator should be ema and in eval mode
    g_ema = generator

    # if args.no_ema or e_ema is None:
    #     e_ema = encoder

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # Train Encoder
        if args.toggle_grads:
            requires_grad(encoder, True)
            requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0.,
                                                                 device=device)
        latent_real, _ = encoder(real_img)
        fake_img, _ = generator([latent_real], input_is_latent=input_is_latent)

        if args.lambda_adv > 0:
            if args.augment:
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                fake_img_aug = fake_img
            fake_pred = discriminator(fake_img_aug)
            adv_loss = g_nonsaturating_loss(fake_pred)

        if args.lambda_pix > 0:
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((fake_img - real_img)**2)
            else:
                pix_loss = F.l1_loss(fake_img, real_img)

        if args.lambda_vgg > 0:
            real_feat = vggnet(real_img)
            fake_feat = vggnet(fake_img)
            vgg_loss = torch.mean((real_feat - fake_feat)**2)

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv

        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss

        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()

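        # Optional latent-cycle loss: encode generator samples and regress
        # the predicted latent onto the one that produced them.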
        if args.train_on_fake:
            e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0
            if e_regularize and args.lambda_rec > 0:
                noise = mixing_noise(args.batch, args.latent, args.mixing,
                                     device)
                fake_img, latent_fake = generator(
                    noise,
                    input_is_latent=input_is_latent,
                    return_latents=True)
                latent_pred, _ = encoder(fake_img)
                if latent_pred.ndim < 3:
                    latent_pred = latent_pred.unsqueeze(1).repeat(
                        1, latent_fake.size(1), 1)
                rec_loss = torch.mean((latent_fake - latent_pred)**2)
                encoder.zero_grad()
                (rec_loss * args.lambda_rec).backward()
                e_optim.step()
                loss_dict["rec"] = rec_loss

        # e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0
        # if e_regularize:
        #     # why not regularize on augmented real?
        #     real_img.requires_grad = True
        #     real_pred, _ = encoder(real_img)
        #     r1_loss_e = d_r1_loss(real_pred, real_img)

        #     encoder.zero_grad()
        #     (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward()
        #     e_optim.step()

        #     loss_dict["r1_e"] = r1_loss_e

        if not args.no_ema and e_ema is not None:
            ema_nimg = args.ema_kimg * 1000
            if args.ema_rampup is not None:
                ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
            accum = 0.5**(args.batch / max(ema_nimg, 1e-8))
            accumulate(e_ema, e_module, accum)

        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
        if not args.no_update_discriminator and args.lambda_adv > 0:
            latent_real, _ = encoder(real_img)
            fake_img, _ = generator([latent_real],
                                    input_is_latent=input_is_latent)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
                fake_img_aug = fake_img

            fake_pred = discriminator(fake_img_aug)
            real_pred = discriminator(real_img_aug)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat

            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                # why not regularize on augmented real?
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss_d = d_r1_loss(real_pred, real_img)

                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss_d * args.d_reg_every +
                 0 * real_pred.view(-1)[0]).backward()
                # Why 0* ? Answer is here https://github.com/rosinality/stylegan2-pytorch/issues/76
                d_optim.step()

                loss_dict["r1_d"] = r1_loss_d

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        r1_e_val = loss_reduced["r1_e"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = g_ema([latent_x],
                                      input_is_latent=input_is_latent)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                        f"ref: {sample_pix_loss.item()};\n")

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    # Recon
                    features = extract_feature_from_reconstruction(
                        e_ema,
                        g_ema,
                        inception,
                        args.truncation,
                        mean_latent,
                        loader2,
                        args.device,
                        input_is_latent=input_is_latent,
                        mode='recon',
                    ).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid_re = calc_fid(sample_mean, sample_cov, real_mean,
                                      real_cov)
                # print("Recon FID:", fid_re)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(f"{i:07d}; recon fid: {float(fid_re):.4f};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Encoder": e_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1 D": r1_d_val,
                    "R1 E": r1_e_val,
                    "Pix Loss": pix_loss_val,
                    "VGG Loss": vgg_loss_val,
                    "Adv Loss": adv_loss_val,
                    "Rec Loss": rec_loss_val,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = encoder if args.no_ema else e_ema
                    e_eval.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_eval(sample_x)
                    fake_img, _ = generator([latent_real],
                                            input_is_latent=input_is_latent)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    e_eval.train()

            if i % args.save_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"latest.pt"),
                )
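The checkpoints above bundle module weights, optimizer state, the ADA probability, and the iteration counter. Below is a minimal resume sketch; resume_latest is a hypothetical helper that assumes the dict layout written by the save calls above and restores only the entries it names:

import os

import torch


def resume_latest(args, encoder, discriminator, e_optim, d_optim, device):
    # Load the dict written by the `save_latest_every` branch above.
    ckpt = torch.load(os.path.join(args.log_dir, 'weight', 'latest.pt'),
                      map_location=device)
    encoder.load_state_dict(ckpt["e"])
    discriminator.load_state_dict(ckpt["d"])
    e_optim.load_state_dict(ckpt["e_optim"])
    d_optim.load_state_dict(ckpt["d_optim"])
    # Resume one iteration past the saved counter and restore the ADA p.
    args.start_iter = ckpt["iter"] + 1
    return ckpt["ada_aug_p"]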
Example #8
0
def train(args, loader, generator, netsD, g_optim, rf_opt, info_opt, g_ema,
          device):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module1 = netsD[1].module
        d_module2 = netsD[2].module

    else:
        g_module = generator
        d_module1 = netsD[1]
        d_module2 = netsD[2]

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    criterion_class = nn.CrossEntropyLoss()
    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        ############# train child discriminator #############
        requires_grad(generator, False)
        requires_grad(netsD[2], True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        c_code = sample_c_code(args.batch, args.c_categories, device)
        # fake_img, _ = generator(noise, c_code)
        image_li, _, _ = generator(noise, c_code)
        fake_img = image_li[0]

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = netsD[2](fake_img)[0]
        real_pred = netsD[2](real_img_aug)[0]

        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        netsD[2].zero_grad()
        d_loss.backward()
        rf_opt[2].step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = netsD[2](real_img)[0]
            r1_loss = d_r1_loss(real_pred, real_img)

            netsD[2].zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            rf_opt[2].step()

        loss_dict["r1"] = r1_loss

        ############# train generator and info discriminator #############
        requires_grad(generator, True)
        requires_grad(netsD[1], True)
        requires_grad(netsD[2], True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        c_code = sample_c_code(args.batch, args.c_categories, device)
        image_li, code_li, _ = generator(noise, c_code)
        fake_img = image_li[0]
        mkd_images = image_li[2]
        masks = image_li[3]
        p_code = code_li[1]
        c_code = code_li[2]

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = netsD[2](fake_img)[0]
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        pred_p = netsD[1](mkd_images[1])[1]
        p_info_loss = criterion_class(
            pred_p,
            torch.nonzero(p_code.long(), as_tuple=False)[:, 1])

        pred_c = netsD[2](mkd_images[2])[1]
        c_info_loss = criterion_class(
            pred_c,
            torch.nonzero(c_code.long(), as_tuple=False)[:, 1])

        loss_dict["p_info"] = p_info_loss
        loss_dict["c_info"] = c_info_loss

        binary_loss = binarization_loss(masks[1]) * 2e1
        # oob_loss = torch.sum(bg_mk * ch_mk, dim=(-1,-2)).mean() * 1e-2
        ms = masks[1].size()
        min_fg_cvg = 0.2 * ms[2] * ms[3]
        fg_cvg_loss = F.relu(min_fg_cvg -
                             torch.sum(masks[1], dim=(-1, -2))).mean() * 1e-2

        min_bg_cvg = 0.2 * ms[2] * ms[3]
        bg_cvg_loss = F.relu(min_bg_cvg - torch.sum(
            torch.ones_like(masks[1]) - masks[1], dim=(-1, -2))).mean() * 1e-2

        loss_dict["bin"] = binary_loss
        loss_dict["cvg"] = fg_cvg_loss + bg_cvg_loss

        generator_loss = g_loss + p_info_loss + c_info_loss + binary_loss + fg_cvg_loss + bg_cvg_loss

        generator.zero_grad()
        netsD[1].zero_grad()
        netsD[2].zero_grad()
        generator_loss.backward()
        g_optim.step()
        info_opt[1].step()
        info_opt[2].step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            c_code = sample_c_code(path_batch_size, args.c_categories, device)
            image_li, _, latents = generator(noise,
                                             c_code,
                                             return_latents=True)
            fake_img = image_li[0]

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        p_info_loss_val = loss_reduced["p_info"].mean().item()
        c_info_loss_val = loss_reduced["c_info"].mean().item()
        binary_loss_val = loss_reduced["bin"].mean().item()
        cvg_loss_val = loss_reduced["cvg"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"p_info: {p_info_loss_val:.4f}; c_info: {c_info_loss_val:.4f}; "
                f"bin: {binary_loss_val:.4f}; cvg: {cvg_loss_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "p_info": p_info_loss_val,
                    "c_info": c_info_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 1000 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    c_code = sample_c_code(args.n_sample, args.c_categories,
                                           device)
                    image_li, _, _ = g_ema([sample_z], c_code)

                    utils.save_image(
                        image_li[0],
                        f"sample/{str(i).zfill(6)}_0.png",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

                    for j in range(3):
                        utils.save_image(
                            image_li[1][j],
                            f"sample/{str(i).zfill(6)}_{str(1+j)}.png",
                            nrow=int(args.n_sample**0.5),
                            normalize=True,
                            value_range=(-1, 1),
                        )

                    for j in range(3):
                        utils.save_image(
                            image_li[2][j],
                            f"sample/{str(i).zfill(6)}_{str(4+j)}.png",
                            nrow=int(args.n_sample**0.5),
                            normalize=True,
                            value_range=(-1, 1),
                        )

                    for j in range(2):
                        utils.save_image(
                            image_li[3][j],
                            f"sample/{str(i).zfill(6)}_{str(7+j)}.png",
                            nrow=int(args.n_sample**0.5),
                            normalize=True,
                            value_range=(0, 1),
                        )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        # "d0": d_module0.state_dict(),
                        "d1": d_module1.state_dict(),
                        "d2": d_module2.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "rf_optim2": rf_opt[2].state_dict(),
                        "info_optim1": info_opt[1].state_dict(),
                        "info_optim2": info_opt[2].state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
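The p_info and c_info losses above need integer class targets recovered from one-hot codes. A standalone check (illustrative values only), showing that the torch.nonzero form used above is equivalent to argmax for strictly one-hot codes:

import torch
import torch.nn.functional as F

# A strictly one-hot code batch, shaped like sample_c_code's output above.
c_code = F.one_hot(torch.tensor([2, 0, 1]), num_classes=4).float()

# Targets as computed in the example ...
idx_nonzero = torch.nonzero(c_code.long(), as_tuple=False)[:, 1]
# ... and the equivalent, simpler form.
idx_argmax = c_code.argmax(dim=1)
assert torch.equal(idx_nonzero, idx_argmax)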
Example #9
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device):
    loader = sample_data(loader)

    pbar = range(args.iter)
    embedding = nn.Embedding(args.patch_number, args.embedding_dim).to(device)  # one row per patch label
    aux_loss_fn = nn.CrossEntropyLoss().to(device)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8,
                                      device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img, pos_label = next(loader)
        real_img = real_img.to(device)
        pos_label = pos_label.to(device)
        gen_label = torch.randint_like(pos_label, 0, args.patch_number)
        gen_label_emb = embedding(gen_label).detach()

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent - args.embedding_dim,
                             args.mixing, device)
        noise = [
            torch.cat([cur_noise, gen_label_emb], dim=1) for cur_noise in noise
        ]

        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred, fake_aux = discriminator(fake_img)
        real_pred, real_aux = discriminator(real_img_aug)
        d_loss = 0.5 * d_logistic_loss(real_pred, fake_pred) + \
                 0.25 * aux_loss_fn(fake_aux, gen_label) + \
                 0.25 * aux_loss_fn(real_aux, pos_label)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred, real_aux = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent - args.embedding_dim,
                             args.mixing, device)
        noise = [
            torch.cat([cur_noise, gen_label_emb], dim=1) for cur_noise in noise
        ]
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred, fake_aux = discriminator(fake_img)
        g_loss = (g_nonsaturating_loss(fake_pred) +
                  aux_loss_fn(fake_aux, gen_label)) / 2.0

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"experiments_2/sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"experiments_2/checkpoint/{str(i).zfill(6)}.pt",
                )
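Example #9 conditions the generator by reserving embedding_dim of the latent for a label embedding. A minimal standalone sketch of that concatenation, with illustrative sizes (names mirror the example above):

import torch
from torch import nn

batch, latent, embedding_dim, patch_number = 4, 512, 16, 16
embedding = nn.Embedding(patch_number, embedding_dim)

z = torch.randn(batch, latent - embedding_dim)
label = torch.randint(0, patch_number, (batch,))
# Detach so the adversarial losses do not update the embedding table,
# matching the .detach() in the example above.
z_cond = torch.cat([z, embedding(label).detach()], dim=1)
assert z_cond.shape == (batch, latent)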
Example #10
0
def train(args, loader, generator, encoder, discriminator, discriminator2,
          vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema, e_ema, device):
    kwargs_d = {'detach_aux': args.detach_d_aux_head}
    if args.dataset == 'imagefolder':
        loader = sample_data2(loader)
    else:
        loader = sample_data(loader)

    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    else:
        inception = real_mean = real_cov = None
    mean_latent = None

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d2_module = None
    if discriminator2 is not None:
        if args.distributed:
            d2_module = discriminator2.module
        else:
            d2_module = discriminator2

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    n_step_max = max(args.n_step_d, args.n_step_e)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        # real_img = next(loader).to(device)
        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(encoder, False)
        requires_grad(discriminator, True)
        if discriminator2 is not None:
            requires_grad(discriminator2, True)
        for step_index in range(args.n_step_d):
            # real_img = next(loader).to(device)
            real_img = real_imgs[step_index]
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
            fake_pred = discriminator(fake_img, **kwargs_d)
            real_pred = discriminator(real_img_aug, **kwargs_d)
            d_loss_real1 = 0.
            if args.n_head_d > 1:
                fake_pred = fake_pred[0]
                real_pred, real_pred1 = real_pred[0], real_pred[1]
                d_loss_real1 = F.softplus(-real_pred1).mean()
            d_loss_real2 = 0.
            if args.decouple_d:
                real_pred2 = discriminator2(real_img_aug, **kwargs_d)
                d_loss_real2 = F.softplus(-real_pred2).mean()
            d_loss_real = F.softplus(-real_pred).mean()
            d_loss_fake = F.softplus(fake_pred).mean()
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real],
                                   input_is_latent=True,
                                   return_latents=False)
            if args.augment:
                rec_img, _ = augment(rec_img, ada_aug_p)
            if not args.decouple_d:
                rec_pred = discriminator(rec_img, **kwargs_d)
                if args.n_head_d > 1:
                    rec_pred = rec_pred[1]
            else:
                rec_pred = discriminator2(rec_img, **kwargs_d)
            d_loss_rec = F.softplus(rec_pred).mean()

            d_loss = (d_loss_real + d_loss_real1 + d_loss_real2 +
                      d_loss_fake * args.lambda_fake_d +
                      d_loss_rec * args.lambda_rec_d)
            loss_dict["rec_score"] = rec_pred.mean()
            loss_dict["d"] = d_loss

            if not args.decouple_d:
                discriminator.zero_grad()
                d_loss.backward()
                d_optim.step()
            else:
                discriminator.zero_grad()
                discriminator2.zero_grad()
                d_loss.backward()
                d_optim.step()
                d2_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img, **kwargs_d)
            if args.n_head_d > 1:
                real_pred = real_pred[0] + real_pred[1]
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        if d_regularize and args.decouple_d:
            real_img.requires_grad = True
            real_pred = discriminator2(real_img, **kwargs_d)
            if args.n_head_d > 1:
                real_pred = real_pred[0] + real_pred[1]
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator2.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d2_optim.step()

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        if discriminator2 is not None:
            requires_grad(discriminator2, False)
        real_img = real_imgs[0]
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)
        fake_pred = discriminator(fake_img, **kwargs_d)
        if args.n_head_d > 1:
            fake_pred = fake_pred[0]
        g_loss_fake = g_nonsaturating_loss(fake_pred)

        g_loss_rec = 0.
        if args.lambda_rec_g > 0:
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real],
                                   input_is_latent=True,
                                   return_latents=False)
            if args.augment:
                rec_img, _ = augment(rec_img, ada_aug_p)
            if not args.decouple_d:
                rec_pred = discriminator(rec_img, **kwargs_d)
                if args.n_head_d > 1:
                    rec_pred = rec_pred[1]
            else:
                rec_pred = discriminator2(rec_img, **kwargs_d)
            g_loss_rec = g_nonsaturating_loss(rec_pred)

        g_loss = g_loss_fake * args.lambda_fake_g + g_loss_rec * args.lambda_rec_g
        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Train Encoder
        requires_grad(encoder, True)
        requires_grad(discriminator, False)
        requires_grad(generator, args.joint)
        if discriminator2 is not None:
            requires_grad(discriminator2, False)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        for step_index in range(args.n_step_e):
            # real_img = next(loader).to(device)
            real_img = real_imgs[step_index]
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real],
                                   input_is_latent=True,
                                   return_latents=False)
            if args.lambda_adv > 0:
                if args.augment:
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    rec_img_aug = rec_img
                if not args.decouple_d:
                    rec_pred = discriminator(rec_img_aug, **kwargs_d)
                    if args.n_head_d > 1:
                        rec_pred = rec_pred[1]
                else:
                    rec_pred = discriminator2(rec_img_aug, **kwargs_d)
                adv_loss = g_nonsaturating_loss(rec_pred)
            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img)**2)
                else:
                    pix_loss = F.l1_loss(rec_img, real_img)
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)

            e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
            loss_dict["e"] = e_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            if args.joint:
                encoder.zero_grad()
                generator.zero_grad()
                e_loss.backward()
                # generator.style.zero_grad()  # not necessary
                e_optim.step()
                g_optim.step()
            else:
                encoder.zero_grad()
                e_loss.backward()
                e_optim.step()

        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"
            ))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = generator([latent_x],
                                          input_is_latent=True,
                                          return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                        f"ref: {sample_pix_loss.item()};\n")

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64,
                        args.n_sample_fid, args.device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean,
                                   real_cov)
                print("fid:", fid)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(f"{i:07d}; fid: {float(fid):.4f};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    # Fixed fake samples
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Reconstruction samples
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_ema(sample_x)
                    fake_img, _ = g_ema([latent_real],
                                        input_is_latent=True,
                                        return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"latest.pt"),
                )
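Every discriminator step in these examples applies lazy R1 regularization through d_r1_loss, scaled by args.r1 / 2 * args.d_reg_every. A common StyleGAN2-style form of that helper, as a sketch (the repository's own definition may differ in reduction details):

import torch
from torch import autograd


def d_r1_loss(real_pred, real_img):
    # Gradient of the real logits with respect to the real images ...
    grad_real, = autograd.grad(outputs=real_pred.sum(),
                               inputs=real_img,
                               create_graph=True)
    # ... penalized by its squared norm, averaged over the batch.
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()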
Example #11
0
def train(args, loader, generator, discriminator, optimizer, g_ema, device):
    collect_info = True
    ckpt_dir = 'checkpoints/stylegan-acgd'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    fig_dir = 'figs/stylegan-acgd'
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)
    loader = sample_data(loader)
    pbar = range(args.iter)
    pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    accs = torch.ones(50)
    loss_dict = {}
    if args.gpu_num > 1:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator
    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    ada_ratio = 2
    # initialize the lr ratios so the wandb logging below is defined even
    # before the first adjustment step
    max_ratio = 2 ** min(4, ada_ratio)
    min_ratio = 2 ** min(0, 4 - ada_ratio)
    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)
        optimizer.step(d_loss)

        num_correct = torch.sum(real_pred > 0) + torch.sum(fake_pred < 0)
        acc = num_correct.item() / (fake_pred.shape[0] + real_pred.shape[0])

        loss_dict["loss"] = d_loss.item()
        loss_dict["real_score"] = real_pred.mean().item()
        loss_dict["fake_score"] = fake_pred.mean().item()


        # update ada_ratio
        accs[i % 50] = acc
        acc_indicator = accs.mean().item()
        if i % 2 == 0:
            if acc_indicator > 0.85:
                ada_ratio += 1
            elif acc_indicator < 0.75:
                ada_ratio -= 1
            max_ratio = 2 ** min(4, ada_ratio)
            min_ratio = 2 ** min(0, 4 - ada_ratio)
            if args.ada_train:
                print('Adjust lrs')
                optimizer.set_lr(lr_max=max_ratio * args.lr_d, lr_min=min_ratio * args.lr_d)

        accumulate(g_ema, g_module, accum)

        d_loss_val = loss_dict["loss"]
        real_score_val = loss_dict["real_score"]
        fake_score_val = loss_dict["fake_score"]

        pbar.set_description(
            (
                f"d: {d_loss_val:.4f}; g: {d_loss_val:.4f}; Acc: {acc:.4f}; "
                f"augment: {ada_aug_p:.4f}"
            )
        )
        if wandb and args.wandb:
            if collect_info:
                cgd_info = optimizer.get_info()
                wandb.log(
                    {
                        'CG iter num': cgd_info['iter_num'],
                        'CG runtime': cgd_info['time'],
                        'D gradient': cgd_info['grad_y'],
                        'G gradient': cgd_info['grad_x'],
                        'D hvp': cgd_info['hvp_y'],
                        'G hvp': cgd_info['hvp_x'],
                        'D cg': cgd_info['cg_y'],
                        'G cg': cgd_info['cg_x']
                    },
                    step=i,
                )
            wandb.log(
                {
                    "Generator": d_loss_val,
                    "Discriminator": d_loss_val,
                    "Ada ratio": ada_ratio,
                    'Generator lr': max_ratio * args.lr_d,
                    'Discriminator lr': min_ratio * args.lr_d,
                    "Rt": r_t_stat,
                    "Accuracy": acc_indicator,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val
                },
                step=i,
            )
        if i % 100 == 0:
            with torch.no_grad():
                g_ema.eval()
                sample, _ = g_ema([sample_z])
                utils.save_image(
                    sample,
                    f"figs/stylegan-acgd/{str(i).zfill(6)}.png",
                    nrow=int(args.n_sample ** 0.5),
                    normalize=True,
                    value_range=(-1, 1),
                )
        if i % 2000 == 0:
            torch.save(
                {
                    "g": g_module.state_dict(),
                    "d": d_module.state_dict(),
                    "g_ema": g_ema.state_dict(),
                    "d_optim": optimizer.state_dict(),
                    "args": args,
                    "ada_aug_p": ada_aug_p,
                },
                f"checkpoints/stylegan-acgd/fix{str(i).zfill(6)}.pt",
            )
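Each of these loops calls accumulate(g_ema, g_module, accum) with accum = 0.5 ** (32 / (10 * 1000)) ≈ 0.9978. A sketch of a common implementation of that EMA update (assuming both modules expose identical parameter names):

def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    # In-place exponential moving average:
    # model1 <- decay * model1 + (1 - decay) * model2
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)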
Example #12
0
def train(args, loader, loader2, T_list, generator, encoder, discriminator,
          discriminator2, vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema,
          e_ema, device):
    # kwargs_d = {'detach_aux': args.detach_d_aux_head}
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    if args.dataset == 'imagefolder':
        loader = sample_data2(loader)
    else:
        loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d2_module = None
    if discriminator2 is not None:
        if args.distributed:
            d2_module = discriminator2.module
        else:
            d2_module = discriminator2

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x, sample_idx = load_real_samples(args, loader)
    assert (sample_x.shape[1] >= args.nframe_num)
    sample_x1 = sample_x[:, 0, ...]
    sample_x2 = sample_x[:, -1, ...]
    # sample_idx = torch.randperm(args.n_sample)
    fid_batch_idx = sample_idx = None

    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        frames = [
            get_batch(loader, device, T_list[i], not args.no_rand_T)
            for _ in range(n_step_max)
        ]

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(encoder, False)
        requires_grad(discriminator, True)
        for step_index in range(args.n_step_d):
            frames1, frames2 = frames[step_index]
            real_img = frames1
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            if args.use_ema:
                g_ema.eval()
                fake_img, _ = g_ema(noise)
            else:
                fake_img, _ = generator(noise)
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img_aug)
            d_loss_fake = F.softplus(fake_pred).mean()
            d_loss_real = F.softplus(-real_pred).mean()
            if args.use_frames2_d:
                real_pred2 = discriminator(frames2)
                d_loss_real += F.softplus(-real_pred2).mean()
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            d_loss_rec = 0.
            if not args.decouple_d and args.lambda_rec_d > 0:  # Do not train D on x_rec if decouple_d
                if args.use_ema:
                    e_ema.eval()
                    g_ema.eval()
                    latent_real, _ = e_ema(real_img)
                    rec_img, _ = g_ema([latent_real], input_is_latent=True)
                else:
                    latent_real, _ = encoder(real_img)
                    rec_img, _ = generator([latent_real], input_is_latent=True)
                rec_pred = discriminator(rec_img)
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["rec_score"] = rec_pred.mean()

            d_loss_cross = 0.
            if not args.decouple_d and args.lambda_cross_d > 0:
                if args.use_ema:
                    e_ema.eval()
                    w1, _ = e_ema(frames1)
                    w2, _ = e_ema(frames2)
                else:
                    w1, _ = encoder(frames1)
                    w2, _ = encoder(frames2)
                dw = w2 - w1
                dw_shuffle = dw[torch.randperm(args.batch), ...]
                if args.use_ema:
                    g_ema.eval()
                    cross_img, _ = g_ema([w1 + dw_shuffle],
                                         input_is_latent=True)
                else:
                    cross_img, _ = generator([w1 + dw_shuffle],
                                             input_is_latent=True)
                cross_pred = discriminator(cross_img)
                d_loss_cross = F.softplus(cross_pred).mean()

            d_loss_fake_cross = 0.
            if args.lambda_fake_cross_d > 0:
                if args.use_ema:
                    e_ema.eval()
                    w1, _ = e_ema(frames1)
                    w2, _ = e_ema(frames2)
                else:
                    w1, _ = encoder(frames1)
                    w2, _ = encoder(frames2)
                dw = w2 - w1
                noise = mixing_noise(args.batch, args.latent, args.mixing,
                                     device)
                if args.use_ema:
                    g_ema.eval()
                    style = g_ema.get_styles(noise).view(args.batch, -1)
                else:
                    style = generator.get_styles(noise).view(args.batch, -1)
                if dw.shape[1] < style.shape[1]:  # W space
                    dw = dw.repeat(1, args.n_latent)
                if args.use_ema:
                    cross_img, _ = g_ema([style + dw], input_is_latent=True)
                else:
                    cross_img, _ = generator([style + dw],
                                             input_is_latent=True)
                fake_cross_pred = discriminator(cross_img)
                d_loss_fake_cross = F.softplus(fake_cross_pred).mean()

            d_loss = (d_loss_real + d_loss_fake +
                      d_loss_fake_cross * args.lambda_fake_cross_d +
                      d_loss_rec * args.lambda_rec_d +
                      d_loss_cross * args.lambda_cross_d)
            loss_dict["d"] = d_loss

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        # Train Discriminator2
        if args.decouple_d and discriminator2 is not None:
            requires_grad(generator, False)
            requires_grad(encoder, False)
            requires_grad(discriminator2, True)
            for step_index in range(
                    args.n_step_e):  # n_step_d2 is same as n_step_e
                frames1, frames2 = frames[step_index]
                real_img = frames1
                if args.use_ema:
                    e_ema.eval()
                    g_ema.eval()
                    latent_real, _ = e_ema(real_img)
                    rec_img, _ = g_ema([latent_real], input_is_latent=True)
                else:
                    latent_real, _ = encoder(real_img)
                    rec_img, _ = generator([latent_real], input_is_latent=True)
                real_pred1 = discriminator2(frames1)
                rec_pred = discriminator2(rec_img)
                d2_loss_real = F.softplus(-real_pred1).mean()
                d2_loss_rec = F.softplus(rec_pred).mean()
                if args.use_frames2_d:
                    real_pred2 = discriminator2(frames2)
                    d2_loss_real += F.softplus(-real_pred2).mean()

                if args.use_ema:
                    e_ema.eval()
                    w1, _ = e_ema(frames1)
                    w2, _ = e_ema(frames2)
                else:
                    w1, _ = encoder(frames1)
                    w2, _ = encoder(frames2)
                dw = w2 - w1
                dw_shuffle = dw[torch.randperm(args.batch), ...]
                if args.use_ema:
                    g_ema.eval()
                    cross_img, _ = g_ema([w1 + dw_shuffle],
                                         input_is_latent=True)
                else:
                    cross_img, _ = generator([w1 + dw_shuffle],
                                             input_is_latent=True)
                cross_pred = discriminator2(cross_img)
                d2_loss_cross = F.softplus(cross_pred).mean()

                d2_loss = d2_loss_real + d2_loss_rec + d2_loss_cross
                loss_dict["d2"] = d2_loss
                loss_dict["rec_score"] = rec_pred.mean()
                loss_dict["cross_score"] = cross_pred.mean()

                discriminator2.zero_grad()
                d2_loss.backward()
                d2_optim.step()

            d_regularize = i % args.d_reg_every == 0
            if d_regularize:
                real_img.requires_grad = True
                real_pred = discriminator2(real_img)
                r1_loss = d_r1_loss(real_pred, real_img)
                discriminator2.zero_grad()
                (args.r1 / 2 * r1_loss * args.d_reg_every +
                 0 * real_pred[0]).backward()
                d2_optim.step()

        # Train Encoder and Generator
        requires_grad(encoder, True)
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        if discriminator2 is not None:
            requires_grad(discriminator2, False)
        pix_loss = vgg_loss = adv_loss = cross_loss = torch.tensor(
            0., device=device)
        for step_index in range(args.n_step_e):
            frames1, frames2 = frames[step_index]
            real_img = frames1
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img)**2)
                else:
                    pix_loss = F.l1_loss(rec_img, real_img)
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)
            if args.lambda_adv > 0:
                if not args.decouple_d:
                    rec_pred = discriminator(rec_img)
                else:
                    rec_pred = discriminator2(rec_img)
                adv_loss = g_nonsaturating_loss(rec_pred)

            if args.lambda_cross > 0:
                w1, _ = encoder(frames1)
                w2, _ = encoder(frames2)
                dw = w2 - w1
                dw_shuffle = dw[torch.randperm(args.batch), ...]
                cross_img, _ = generator([w1 + dw_shuffle],
                                         input_is_latent=True)
                if not args.decouple_d:
                    cross_pred = discriminator(cross_img)
                else:
                    cross_pred = discriminator2(cross_img)
                cross_loss = g_nonsaturating_loss(cross_pred)

            e_loss = (pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg +
                      adv_loss * args.lambda_adv +
                      cross_loss * args.lambda_cross)
            loss_dict["e"] = e_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            encoder.zero_grad()
            generator.zero_grad()
            e_loss.backward()
            e_optim.step()
            g_optim.step()

        # Train Generator
        requires_grad(generator, True)
        requires_grad(encoder, False)
        requires_grad(discriminator, False)
        if discriminator2 is not None:
            requires_grad(discriminator2, False)
        frames1, frames2 = frames[0]
        real_img = frames1
        g_loss_fake = 0.
        if args.lambda_fake_g > 0:
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)
            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)
            fake_pred = discriminator(fake_img)
            g_loss_fake = g_nonsaturating_loss(fake_pred)
            if args.no_sim_g:
                generator.zero_grad()
                (g_loss_fake * args.lambda_fake_g).backward()
                g_optim.step()

        g_loss_fake_cross = 0.
        if args.lambda_fake_cross_g > 0:
            if args.use_ema:
                e_ema.eval()
                w1, _ = e_ema(frames1)
                w2, _ = e_ema(frames2)
            else:
                w1, _ = encoder(frames1)
                w2, _ = encoder(frames2)
            dw = w2 - w1
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            style = generator.get_styles(noise).view(args.batch, -1)
            if dw.shape[1] < style.shape[1]:  # W space
                dw = dw.repeat(1, args.n_latent)
            cross_img, _ = generator([style + dw], input_is_latent=True)
            fake_cross_pred = discriminator(cross_img)
            g_loss_fake_cross = g_nonsaturating_loss(fake_cross_pred)
            if args.no_sim_g:
                generator.zero_grad()
                (g_loss_fake_cross * args.lambda_fake_cross_g).backward()
                g_optim.step()

        g_loss = g_loss_fake * args.lambda_fake_g + g_loss_fake_cross * args.lambda_fake_cross_g
        loss_dict["g"] = g_loss
        if not args.no_sim_g:
            generator.zero_grad()
            g_loss.backward()
            g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # accumulate() lerps the EMA copies toward the live modules:
        # p_ema = accum * p_ema + (1 - accum) * p
        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"
            ))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x1)
                    fake_x, _ = generator([latent_x],
                                          input_is_latent=True,
                                          return_latents=False)
                    sample_pix_loss = torch.sum((sample_x1 - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; "
                        f"ref: {sample_pix_loss.item():.4f};\n")

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    # Sample FID
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64,
                        args.n_sample_fid, args.device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid_sa = calc_fid(sample_mean, sample_cov, real_mean,
                                      real_cov)
                    # Recon FID
                    features = extract_feature_from_recon_hybrid(
                        e_ema,
                        g_ema,
                        inception,
                        args.truncation,
                        mean_latent,
                        loader2,
                        args.device,
                        mode='recon',
                    ).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid_re = calc_fid(sample_mean, sample_cov, real_mean,
                                      real_cov)
                    # Hybrid FID
                    features = extract_feature_from_recon_hybrid(
                        e_ema,
                        g_ema,
                        inception,
                        args.truncation,
                        mean_latent,
                        loader2,
                        args.device,
                        mode='hybrid',
                        shuffle_idx=fid_batch_idx).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid_hy = calc_fid(sample_mean, sample_cov, real_mean,
                                      real_cov)
                # print("Sample FID:", fid_sa, "Recon FID:", fid_re, "Hybrid FID:", fid_hy)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(
                        f"{i:07d}; sample fid: {float(fid_sa):.4f}; recon fid: {float(fid_re):.4f}; hybrid fid: {float(fid_hy):.4f};\n"
                    )

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x1.shape)[1:]
                    # Fixed fake samples
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Fake hybrid samples
                    w1, _ = e_ema(sample_x1)
                    w2, _ = e_ema(sample_x2)
                    dw = w2 - w1
                    style_z = g_ema.get_styles([sample_z]).view(args.n_sample, -1)
                    if dw.shape[1] < style_z.shape[1]:  # encoder outputs W; tile to flattened W+
                        dw = dw.repeat(1, args.n_latent)
                    fake_img, _ = g_ema([style_z + dw], input_is_latent=True)
                    drive = torch.cat((
                        sample_x1.reshape(args.n_sample, 1, *nchw),
                        sample_x2.reshape(args.n_sample, 1, *nchw),
                    ), 1)
                    source = torch.cat((
                        sample.reshape(args.n_sample, 1, *nchw),
                        fake_img.reshape(args.n_sample, 1, *nchw),
                    ), 1)
                    sample = torch.cat((
                        drive.reshape(args.n_sample // nrow, 2 * nrow, *nchw),
                        source.reshape(args.n_sample // nrow, 2 * nrow, *nchw),
                    ), 1)
                    utils.save_image(
                        sample.reshape(4 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample_hybrid.png"),
                        nrow=2 * nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Reconstruction samples
                    latent_real, _ = e_ema(sample_x1)
                    fake_img, _ = g_ema([latent_real],
                                        input_is_latent=True,
                                        return_latents=False)
                    sample = torch.cat(
                        (sample_x1.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Cross reconstruction: [real_y1, real_y2; real_x1, fake_x2]
                    w1, _ = e_ema(sample_x1)
                    w2, _ = e_ema(sample_x2)
                    dw = w2 - w1
                    if sample_idx is None:
                        dw = torch.cat(dw.chunk(2, 0)[::-1], 0)
                    else:
                        dw = dw[sample_idx, ...]
                    fake_img, _ = g_ema([w1 + dw],
                                        input_is_latent=True,
                                        return_latents=False)
                    drive = torch.cat((
                        torch.cat(sample_x1.chunk(2, 0)[::-1], 0).reshape(
                            args.n_sample, 1, *nchw),
                        torch.cat(sample_x2.chunk(2, 0)[::-1], 0).reshape(
                            args.n_sample, 1, *nchw),
                    ), 1)  # [n_sample, 2, C, H, W]
                    source = torch.cat((
                        sample_x1.reshape(args.n_sample, 1, *nchw),
                        fake_img.reshape(args.n_sample, 1, *nchw),
                    ), 1)  # [n_sample, 2, C, H, W]
                    sample = torch.cat((
                        drive.reshape(args.n_sample // nrow, 2 * nrow, *nchw),
                        source.reshape(args.n_sample // nrow, 2 * nrow, *nchw),
                    ), 1)  # [n_sample//nrow, 4*nrow, C, H, W]
                    utils.save_image(
                        sample.reshape(4 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-cross.png"),
                        nrow=2 * nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"latest.pt"),
                )
Exemple #13
0
def train(args, loader, loader2, generator, encoder, discriminator,
          discriminator2, vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema,
          e_ema, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {
        'recx_score': torch.tensor(0.0, device=device),
        'ae_fake': torch.tensor(0.0, device=device),
        'ae_real': torch.tensor(0.0, device=device),
        'pix': torch.tensor(0.0, device=device),
        'vgg': torch.tensor(0.0, device=device),
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d2_module = None
    if discriminator2 is not None:
        if args.distributed:
            d2_module = discriminator2.module
        else:
            d2_module = discriminator2

    # When joint training is enabled, d_weight balances the reconstruction loss and the
    # adversarial loss on reconstructed real images. It does not balance the overall AE
    # loss against the GAN loss.
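    # A minimal sketch of calculate_adaptive_weight in the style popularized by
    # VQGAN (taming-transformers); the helper actually used here may differ:
    #
    #     def calculate_adaptive_weight(nll_loss, g_loss, last_layer):
    #         nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    #         g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    #         d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    #         return torch.clamp(d_weight, 0.0, 1e4).detach()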
    d_weight = torch.tensor(1.0, device=device)
    last_layer = None
    if args.use_adaptive_weight:
        if args.distributed:
            last_layer = generator.module.get_last_layer()
        else:
            last_layer = generator.get_last_layer()
    g_scale = 1

    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    r_t_dict = {'real': 0, 'fake': 0, 'recx': 0}  # r_t stat

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length,
                                      args.ada_every, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    input_is_latent = args.latent_space != 'z'  # feed W latents to the generator unless encoding in Z space

    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        if args.debug:
            util.seed_everything(i)
        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator and Encoder
        requires_grad(generator, False)
        requires_grad(encoder, True)
        requires_grad(discriminator, True)
        requires_grad(discriminator2, True)
        for step_index in range(args.n_step_d):
            real_img = real_imgs[step_index]
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
                fake_img_aug = fake_img
            real_pred = discriminator(encoder(real_img_aug)[0])
            fake_pred = discriminator(encoder(fake_img_aug)[0])
            d_loss_real = F.softplus(-real_pred).mean()
            d_loss_fake = F.softplus(fake_pred).mean()
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            d_loss_rec = 0.
            if args.lambda_rec_d > 0:
                latent_real, _ = encoder(real_img)
                rec_img, _ = generator([latent_real],
                                       input_is_latent=input_is_latent)
                if args.augment:
                    rec_img, _ = augment(rec_img, ada_aug_p)
                rec_pred = discriminator(encoder(rec_img)[0])
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["recx_score"] = rec_pred.mean()
                r_t_dict['recx'] = torch.sign(rec_pred).sum().item() / args.batch

            d_loss = d_loss_real + d_loss_fake * args.lambda_fake_d + d_loss_rec * args.lambda_rec_d
            loss_dict["d"] = d_loss

            discriminator.zero_grad()
            encoder.zero_grad()
            d_loss.backward()
            d_optim.step()
            e_optim.step()

        if args.augment and args.augment_p == 0:
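            # tune() nudges ada_aug_p toward args.ada_target based on the sign of
            # the real predictions (the r_t heuristic from StyleGAN2-ADA).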
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat
        r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch
        r_t_dict['fake'] = torch.sign(fake_pred).sum().item() / args.batch

        d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(encoder(real_img_aug)[0])
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            encoder.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()
            e_optim.step()
        loss_dict["r1"] = r1_loss

        # Train Generator
        requires_grad(generator, True)
        requires_grad(encoder, False)
        requires_grad(discriminator, False)
        requires_grad(discriminator2, False)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        if args.augment:
            fake_img_aug, _ = augment(fake_img, ada_aug_p)
        else:
            fake_img_aug = fake_img
        fake_pred = discriminator(encoder(fake_img_aug)[0])
        g_loss_fake = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = g_loss_fake
        generator.zero_grad()
        (g_loss_fake * args.lambda_fake_g).backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
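            # g_path_regularize, assuming rosinality's stylegan2-pytorch version,
            # is roughly:
            #     noise = torch.randn_like(fake_img) / math.sqrt(H * W)
            #     grad, = autograd.grad((fake_img * noise).sum(), latents,
            #                           create_graph=True)
            #     path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
            #     mean_pl = mean_pl + 0.01 * (path_lengths.mean() - mean_pl)
            #     path_loss = (path_lengths - mean_pl).pow(2).mean()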
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Train Encoder (and Generator)
        joint = (not args.no_joint) and (g_scale > 1e-6)

        # Train AE on fake samples (latent reconstruction)
        if args.lambda_rec_w + (args.lambda_pix_fake + args.lambda_vgg_fake +
                                args.lambda_adv_fake) > 0:
            requires_grad(encoder, True)
            requires_grad(generator, joint)
            requires_grad(discriminator, False)
            requires_grad(discriminator2, False)
            for step_index in range(args.n_step_e):

                ae_loss_fake = 0
                mixing_prob = 0 if args.which_latent == 'w_tied' else args.mixing
                if args.lambda_rec_w > 0:
                    noise = mixing_noise(args.batch, args.latent, mixing_prob,
                                         device)
                    fake_img, latent_fake = generator(
                        noise,
                        return_latents=True,
                        detach_style=not args.no_detach_style)
                    if args.which_latent == 'w_tied':
                        latent_fake = latent_fake[:, 0, :]
                    else:
                        latent_fake = latent_fake.view(args.batch, -1)
                    latent_pred, _ = encoder(fake_img)
                    ae_loss_fake = torch.mean(
                        (latent_pred - latent_fake.detach())**2)

                if args.lambda_pix_fake + args.lambda_vgg_fake + args.lambda_adv_fake > 0:
                    pix_loss = vgg_loss = adv_loss = torch.tensor(
                        0., device=device)
                    noise = mixing_noise(args.batch, args.latent, mixing_prob,
                                         device)
                    fake_img, _ = generator(noise, detach_style=False)
                    fake_img = fake_img.detach()

                    latent_pred, _ = encoder(fake_img)
                    rec_img, _ = generator([latent_pred],
                                           input_is_latent=input_is_latent)
                    if args.lambda_pix_fake > 0:
                        if args.pix_loss == 'l2':
                            pix_loss = torch.mean((rec_img - fake_img)**2)
                        elif args.pix_loss == 'l1':
                            pix_loss = F.l1_loss(rec_img, fake_img)
                    if args.lambda_vgg_fake > 0:
                        vgg_loss = torch.mean(
                            (vggnet(fake_img) - vggnet(rec_img))**2)

                    ae_loss_fake = (ae_loss_fake +
                                    pix_loss * args.lambda_pix_fake +
                                    vgg_loss * args.lambda_vgg_fake)

                loss_dict["ae_fake"] = ae_loss_fake

                if joint:
                    encoder.zero_grad()
                    generator.zero_grad()
                    (ae_loss_fake * args.lambda_rec_w).backward()
                    e_optim.step()
                    if args.g_decay is not None:
                        scale_grad(generator, g_scale)
                    # Do NOT update F (or generator.style). Grad should be zero when style
                    # is detached in generator, but we explicitly zero it, just in case.
                    if not args.no_detach_style:
                        generator.style.zero_grad()
                    g_optim.step()
                else:
                    encoder.zero_grad()
                    (ae_loss_fake * args.lambda_rec_w).backward()
                    e_optim.step()

        # Train AE on real samples (image reconstruction)
        if args.lambda_pix + args.lambda_vgg + args.lambda_adv > 0:
            requires_grad(encoder, True)
            requires_grad(generator, joint)
            requires_grad(discriminator, False)
            requires_grad(discriminator2, False)
            pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
            for step_index in range(args.n_step_e):
                real_img = real_imgs[step_index]
                latent_real, _ = encoder(real_img)
                rec_img, _ = generator([latent_real],
                                       input_is_latent=input_is_latent)
                if args.lambda_pix > 0:
                    if args.pix_loss == 'l2':
                        pix_loss = torch.mean((rec_img - real_img)**2)
                    elif args.pix_loss == 'l1':
                        pix_loss = F.l1_loss(rec_img, real_img)
                if args.lambda_vgg > 0:
                    vgg_loss = torch.mean(
                        (vggnet(real_img) - vggnet(rec_img))**2)
                if args.lambda_adv > 0:
                    if args.augment:
                        rec_img_aug, _ = augment(rec_img, ada_aug_p)
                    else:
                        rec_img_aug = rec_img
                    rec_pred = discriminator(encoder(rec_img_aug)[0])
                    adv_loss = g_nonsaturating_loss(rec_pred)

                if args.use_adaptive_weight and i >= args.disc_iter_start:
                    nll_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg
                    g_loss = adv_loss * args.lambda_adv
                    d_weight = calculate_adaptive_weight(nll_loss,
                                                         g_loss,
                                                         last_layer=last_layer)

                ae_loss_real = (pix_loss * args.lambda_pix +
                                vgg_loss * args.lambda_vgg +
                                d_weight * adv_loss * args.lambda_adv)
                loss_dict["ae_real"] = ae_loss_real
                loss_dict["pix"] = pix_loss
                loss_dict["vgg"] = vgg_loss
                loss_dict["adv"] = adv_loss

                if joint:
                    encoder.zero_grad()
                    generator.zero_grad()
                    ae_loss_real.backward()
                    e_optim.step()
                    if args.g_decay is not None:
                        scale_grad(generator, g_scale)
                    g_optim.step()
                else:
                    encoder.zero_grad()
                    ae_loss_real.backward()
                    e_optim.step()

        if args.g_decay is not None:
            g_scale *= args.g_decay

        # Update EMA
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5**(args.batch / max(ema_nimg, 1e-8))
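        # 0.5 ** (batch / ema_nimg) gives the EMA a half-life of roughly ema_nimg
        # images; the ramp-up above shortens that half-life early in training.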
        accumulate(g_ema, g_module, 0 if args.no_ema_g else accum)
        accumulate(e_ema, e_module, 0 if args.no_ema_e else accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        ae_real_val = loss_reduced["ae_real"].mean().item()
        ae_fake_val = loss_reduced["ae_fake"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        recx_score_val = loss_reduced["recx_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"ae_fake: {ae_fake_val:.4f}; ae_real: {ae_real_val:.4f}; "
                f"g: {g_loss_val:.4f}; path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"d_weight: {d_weight.item():.4f}; "))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    # Reconstruction of real images
                    latent_x, _ = e_ema(sample_x)
                    rec_real, _ = g_ema([latent_x],
                                        input_is_latent=input_is_latent)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         rec_real.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    ref_pix_loss = torch.sum(torch.abs(sample_x - rec_real))
                    ref_vgg_loss = torch.mean(
                        (vggnet(sample_x) - vggnet(rec_real))**2
                    ) if vggnet is not None else torch.tensor(0.0, device=device)
                    # Fixed fake samples and reconstructions
                    sample_gz, _ = g_ema([sample_z])
                    latent_gz, _ = e_ema(sample_gz)
                    rec_fake, _ = g_ema([latent_gz],
                                        input_is_latent=input_is_latent)
                    sample = torch.cat(
                        (sample_gz.reshape(args.n_sample // nrow, nrow, *nchw),
                         rec_fake.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"ae_fake: {ae_fake_val:.4f}; ae_real: {ae_real_val:.4f}; "
                        f"g: {g_loss_val:.4f}; path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; "
                        f"real_score: {real_score_val:.4f}; fake_score: {fake_score_val:.4f}; recx_score: {recx_score_val:.4f}; "
                        f"pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; "
                        f"ref_pix: {ref_pix_loss.item():.4f}; ref_vgg: {ref_vgg_loss.item():.4f}; "
                        f"d_weight: {d_weight.item():.4f}; "
                        f"\n"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    fid_sa = fid_re = fid_sr = 0
                    g_ema.eval()
                    e_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    # Sample FID
                    if 'fid_sample' in args.which_metric:
                        features = extract_feature_from_samples(
                            g_ema, inception, args.truncation, mean_latent, 64,
                            args.n_sample_fid, args.device).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_sa = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                    # Sample reconstruction FID
                    if 'fid_sample_recon' in args.which_metric:
                        features = extract_feature_from_samples(
                            g_ema,
                            inception,
                            args.truncation,
                            mean_latent,
                            64,
                            args.n_sample_fid,
                            args.device,
                            mode='recon',
                            encoder=e_ema,
                            input_is_latent=input_is_latent,
                        ).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_sr = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                    # Real reconstruction FID
                    if 'fid_recon' in args.which_metric:
                        features = extract_feature_from_reconstruction(
                            e_ema,
                            g_ema,
                            inception,
                            args.truncation,
                            mean_latent,
                            loader2,
                            args.device,
                            input_is_latent=input_is_latent,
                            mode='recon',
                        ).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_re = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(
                        f"{i:07d}; sample: {float(fid_sa):.4f}; rec_fake: {float(fid_sr):.4f}; rec_real: {float(fid_re):.4f};\n"
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"latest.pt"),
                )
Exemple #14
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device):
    loader = sample_data(loader)
    print("Save checkpoint very %d iter." % args.save_iter)
    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0
    # GradScaler drives the mixed-precision backward/step calls below.
    scaler = torch.cuda.amp.GradScaler()
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))
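    # StyleGAN2's default EMA decay: a half-life of 10k images at batch size 32.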
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter
        print("iter: %d" % i)
        if i > args.iter:
            print("Done!")

            break
        ##############################################################################
        # Casts operations to mixed precision
        # https://pytorch.org/docs/stable/notes/amp_examples.html
        with torch.cuda.amp.autocast():
            real_img = next(loader)
            real_img = real_img.to(device)

            requires_grad(generator, False)
            requires_grad(discriminator, True)

            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)

            else:
                real_img_aug = real_img

            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img_aug)
            d_loss = d_logistic_loss(real_pred, fake_pred)
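            # For reference, these rosinality-style losses reduce to
            #   d_logistic_loss(r, f) = softplus(-r).mean() + softplus(f).mean()
            #   g_nonsaturating_loss(f) = softplus(-f).mean()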

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(d_loss).backward()
        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(d_optim)
        scaler.update()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        # Runs the forward pass with autocasting.
        if d_regularize:
            with torch.cuda.amp.autocast():
                real_img.requires_grad = True
                real_pred = discriminator(real_img)

            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            scaler.scale((args.r1 / 2 * r1_loss * args.d_reg_every +
                          0 * real_pred[0])).backward()
            scaler.step(d_optim)
            scaler.update()

        loss_dict["r1"] = r1_loss

        # Runs the forward pass with autocasting.
        with torch.cuda.amp.autocast():
            requires_grad(generator, True)
            requires_grad(discriminator, False)

            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)

            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)

            fake_pred = discriminator(fake_img)
            g_loss = g_nonsaturating_loss(fake_pred)

            loss_dict["g"] = g_loss

        generator.zero_grad()
        scaler.scale(g_loss).backward()
        scaler.step(g_optim)
        scaler.update()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            with torch.cuda.amp.autocast():
                path_batch_size = max(1, args.batch // args.path_batch_shrink)
                noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                     device)
                fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            scaler.scale(weighted_path_loss).backward()
            scaler.step(g_optim)
            scaler.update()  # updates the scale for the next iteration
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)
        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_iter == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
Exemple #15
0
def train(args, loader, encoder, generator, discriminator, discriminator_w,
          vggnet, pwcnet, e_optim, g_optim, g1_optim, d_optim, dw_optim,
          e_ema, g_ema, device):
    loader = sample_data(loader)
    args.toggle_grads = True
    args.augment = False

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
    mean_path_length = 0
    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    loss_dict = {"d": torch.tensor(0., device=device),
                 "real_score": torch.tensor(0., device=device),
                 "fake_score": torch.tensor(0., device=device),
                 "hybrid_score": torch.tensor(0., device=device),
                 "r1_d": torch.tensor(0., device=device),
                 "rec": torch.tensor(0., device=device),}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

    # sample_x = accumulate_batches(loader, args.n_sample).to(device)
    sample_x = load_real_samples(args, loader)
    sample_x1 = sample_x[:,0,...]
    sample_x2 = sample_x[:,-1,...]
    sample_idx = torch.randperm(args.n_sample)
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if get_rank() == 0:
            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = e_ema
                    e_eval.eval()
                    g_ema.eval()
                    nrow = int(args.n_sample ** 0.5)
                    nchw = list(sample_x1.shape)[1:]
                    # Recon
                    latent_real, _ = e_eval(sample_x1)
                    fake_img, _ = g_ema([latent_real], input_is_latent=True, return_latents=False)
                    sample = torch.cat((sample_x1.reshape(args.n_sample//nrow, nrow, *nchw), 
                                        fake_img.reshape(args.n_sample//nrow, nrow, *nchw)), 1)
                    utils.save_image(
                        sample.reshape(2*args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Cross
                    w1, _ = e_eval(sample_x1)
                    w2, _ = e_eval(sample_x2)
                    delta_w = w2 - w1
                    delta_w = delta_w[sample_idx,...]
                    fake_img, _ = g_ema([w1 + delta_w], input_is_latent=True, return_latents=False)
                    sample = torch.cat((sample_x2.reshape(args.n_sample//nrow, nrow, *nchw), 
                                        fake_img.reshape(args.n_sample//nrow, nrow, *nchw)), 1)
                    utils.save_image(
                        sample.reshape(2*args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-cross.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Sample
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                e_eval.train()

        if i > args.iter:
            print("Done!")
            break

        frames = next(loader)  # [N, T, C, H, W]
        batch = frames.shape[0]
        frames1 = frames[:,0,...]
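        # Sample two additional, distinct frame indices per clip from
        # 1..nframe_num-1 (multinomial without replacement), sorted ascending.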
        selected_indices = torch.sort(torch.multinomial(torch.ones(batch, args.nframe_num-1), 2)+1, 1)[0]
        frames2 = frames[range(batch),selected_indices[:,0],...]
        frames3 = frames[range(batch),selected_indices[:,1],...]
        frames1 = frames1.to(device)
        frames2 = frames2.to(device)
        frames3 = frames3.to(device)

        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(generator, False)
            requires_grad(discriminator, True)
            requires_grad(discriminator_w, True)

        real_img, fake_img, _, _, _ = cross_reconstruction(encoder, generator, frames1, frames2, frames3, args.cond_disc)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img_aug, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img
            fake_img_aug = fake_img
        
        fake_pred = discriminator(fake_img_aug)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        if args.lambda_gan > 0:
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)
            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)
            fake_pred = discriminator(fake_img)
            fake_loss = F.softplus(fake_pred)
            d_loss += fake_loss.mean() * args.lambda_gan

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        # loss_dict["fake_score"] = fake_pred.mean()
        fake_pred1, fake_pred2 = fake_pred.chunk(2, dim=0)
        loss_dict["fake_score"] = fake_pred1.mean()
        loss_dict["hybrid_score"] = fake_pred2.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat
        
        d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
        if d_regularize:
            # why not regularize on augmented real?
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss_d = d_r1_loss(real_pred, real_img)

            d_optim.zero_grad()
            # Why the 0 * real_pred term? See
            # https://github.com/rosinality/stylegan2-pytorch/issues/76
            (args.r1 / 2 * r1_loss_d * args.d_reg_every +
             0 * real_pred.view(-1)[0]).backward()
            d_optim.step()

            loss_dict["r1_d"] = r1_loss_d
        
        # Train Discriminator_W
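        # discriminator_w is a latent-space critic: encoded latents of real frames
        # vs. latents drawn from the generator's mapping network (a learned prior).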
        if args.learned_prior and args.lambda_gan_w > 0:
            noise = mixing_noise(args.batch, args.latent, 0, device)
            fake_w = generator.get_latent(noise[0])
            real_w, _ = encoder(frames1)
            fake_pred = discriminator_w(fake_w)
            real_pred = discriminator_w(real_w)
            d_loss_w = d_logistic_loss(real_pred, fake_pred)
            dw_optim.zero_grad()
            d_loss_w.backward()
            dw_optim.step()

        # Train Encoder and Generator
        if args.toggle_grads:
            requires_grad(encoder, True)
            requires_grad(generator, True)
            requires_grad(discriminator, False)
            requires_grad(discriminator_w, False)
        pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0., device=device)

        _, fake_img, x_real, x_recon, x_cross = cross_reconstruction(encoder, generator, frames1, frames2, frames3, args.cond_disc)

        if args.lambda_adv > 0:
            if args.augment:
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                fake_img_aug = fake_img
            fake_pred = discriminator(fake_img_aug)
            adv_loss = g_nonsaturating_loss(fake_pred)

        if args.lambda_pix > 0:
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((x_recon - x_real) ** 2)
            else:
                pix_loss = F.l1_loss(x_recon, x_real)

        if args.lambda_vgg > 0:
            real_feat = vggnet(x_real)
            fake_feat = vggnet(x_recon) if not args.vgg_on_cross else vggnet(x_cross)
            vgg_loss = torch.mean((fake_feat - real_feat) ** 2)

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv

        if args.lambda_gan > 0 and not args.no_sim_opt:
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)
            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)
            fake_pred = discriminator(fake_img)
            g_loss = g_nonsaturating_loss(fake_pred)
            e_loss += g_loss * args.lambda_gan

        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss

        e_optim.zero_grad()
        g_optim.zero_grad()
        e_loss.backward()
        e_optim.step()
        g_optim.step()

        if args.learned_prior:
            g_loss_w = 0.
            if args.lambda_gan_w > 0:
                noise = mixing_noise(args.batch, args.latent, 0, device)
                fake_w = generator.get_latent(noise[0])
                fake_pred = discriminator_w(fake_w)
                g_loss_w += g_nonsaturating_loss(fake_pred) * args.lambda_gan_w
            if args.lambda_adv_w > 0:
                noise = mixing_noise(args.batch, args.latent, args.mixing, device)
                fake_img, _ = generator(noise)
                fake_pred = discriminator(fake_img)
                g_loss_w += g_nonsaturating_loss(fake_pred) * args.lambda_adv_w
            if torch.is_tensor(g_loss_w):  # both lambda terms may be zero
                g1_optim.zero_grad()
                g_loss_w.backward()
                g1_optim.step()
        
        if args.lambda_gan > 0 and args.no_sim_opt:
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)
            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)
            fake_pred = discriminator(fake_img)
            g_loss = g_nonsaturating_loss(fake_pred) * args.lambda_gan
            generator.zero_grad()
            g_loss.backward()
            g_optim.step()
        
        g_regularize = args.lambda_gan > 0 and args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            # mean_path_length_avg = (
            #     reduce_sum(mean_path_length).item() / get_world_size()
            # )
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        hybrid_score_val = loss_reduced["hybrid_score"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        # path_length_val = loss_reduced["path_length"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; "
                    f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                    f"path: {path_loss_val:.4f}; augment: {ada_aug_p:.4f}"
                )
            )

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x1)
                    fake_x, _ = generator([latent_x], input_is_latent=True, return_latents=False)
                    sample_pix_loss = torch.sum((sample_x1 - fake_x) ** 2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                            f"ref: {sample_pix_loss.item()};\n")

            if wandb and args.wandb:
                wandb.log(
                    {
                        "Encoder": e_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1 D": r1_d_val,
                        "Pix Loss": pix_loss_val,
                        "VGG Loss": vgg_loss_val,
                        "Adv Loss": adv_loss_val,
                        "Rec Loss": rec_loss_val,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Hybrid Score": hybrid_score_val,
                    }
                )

            if i % args.save_every == 0:
                e_eval = e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g": g_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )
            
            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g": g_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"latest.pt"),
                )
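
# For reference: the adversarial loss helpers called throughout these examples
# (d_logistic_loss, g_nonsaturating_loss) follow the standard StyleGAN2
# formulation, as in rosinality's stylegan2-pytorch. A minimal sketch, assuming
# the discriminator outputs raw logits:
import torch.nn.functional as F

def d_logistic_loss(real_pred, fake_pred):
    # softplus(-x) = -log sigmoid(x): push real scores up, fake scores down
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()

def g_nonsaturating_loss(fake_pred):
    # non-saturating generator loss: -log sigmoid(D(G(z)))
    return F.softplus(-fake_pred).mean()
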
def train(args, loader, generator, discriminator, extra, g_optim, d_optim,
          e_optim, g_ema, device, g_source, d_source):
    loader = sample_data(loader)

    imsave_path = os.path.join('samples', args.exp)
    model_path = os.path.join('checkpoints', args.exp)

    if not os.path.exists(imsave_path):
        os.makedirs(imsave_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # anchor points: when we sample noise close to these, we impose the
    # image-level adversarial loss (Eq. 4 in the paper)
    init_z = torch.randn(args.n_train, args.latent, device=device)
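    # (get_subspace is defined elsewhere; the assumed behavior is to pick random
    # rows of init_z as anchors and perturb them with small Gaussian noise,
    # e.g. z = init_z[ind] + subspace_std * torch.randn_like(init_z[ind]))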
    pbar = range(args.iter)
    sfm = nn.Softmax(dim=1)
    kl_loss = nn.KLDivLoss()
    sim = nn.CosineSimilarity()
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    g_module = generator
    d_module = discriminator
    g_ema_module = g_ema.module

    accum = 0.5**(32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    # which discriminator feature level implements the patch-level adversarial
    # loss: p_ind is drawn uniformly from [0, args.highp)
    lowp, highp = 0, args.highp

    # the following defines the constant noise used for generating images at different stages of training
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    requires_grad(g_source, False)
    requires_grad(d_source, False)
    sub_region_z = get_subspace(args, init_z.clone(), vis_flag=True)
    for idx in pbar:
        i = idx + args.start_iter
        which = i % args.subspace_freq  # 0: sample from the anchor region; otherwise sample normally

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)
        requires_grad(extra, True)

        if which > 0:
            # sample normally, apply patch-level adversarial loss
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        else:
            # sample from anchors, apply image-level adversarial loss
            noise = [get_subspace(args, init_z.clone())]

        fake_img, _ = generator(noise)

        if args.augment:
            real_img, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred, _ = discriminator(fake_img,
                                     extra=extra,
                                     flag=which,
                                     p_ind=np.random.randint(lowp, highp))
        real_pred, _ = discriminator(real_img,
                                     extra=extra,
                                     flag=which,
                                     p_ind=np.random.randint(lowp, highp),
                                     real=True)

        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        extra.zero_grad()
        d_loss.backward()
        d_optim.step()
        e_optim.step()

        if args.augment and args.augment_p == 0:
            ada_augment += torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
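                # once ~256 real predictions have been accumulated across ranks,
                # nudge ada_aug_p toward the target sign statistic r_t (the ADA
                # overfitting heuristic of Karras et al., 2020)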
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > args.ada_target:
                    sign = 1

                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred, _ = discriminator(real_img,
                                         extra=extra,
                                         flag=which,
                                         p_ind=np.random.randint(lowp, highp))
            real_pred = real_pred.view(real_img.size(0), -1)
            real_pred = real_pred.mean(dim=1).unsqueeze(1)

            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            extra.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()
            e_optim.step()
        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)
        requires_grad(extra, False)
        if which > 0:
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        else:
            noise = [get_subspace(args, init_z.clone())]

        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred, _ = discriminator(fake_img,
                                     extra=extra,
                                     flag=which,
                                     p_ind=np.random.randint(lowp, highp))
        g_loss = g_nonsaturating_loss(fake_pred)

        # distance consistency loss
        with torch.set_grad_enabled(False):
            z = torch.randn(args.feat_const_batch, args.latent, device=device)
            feat_ind = np.random.randint(1, g_source.module.n_latent - 1, size=args.feat_const_batch)

            # computing source distances
            source_sample, feat_source = g_source([z], return_feats=True)
            dist_source = torch.zeros(
                [args.feat_const_batch, args.feat_const_batch - 1]).cuda()

            # iterating over different elements in the batch
            for pair1 in range(args.feat_const_batch):
                tmpc = 0
                # comparing the possible pairs
                for pair2 in range(args.feat_const_batch):
                    if pair1 != pair2:
                        anchor_feat = torch.unsqueeze(
                            feat_source[feat_ind[pair1]][pair1].reshape(-1), 0)
                        compare_feat = torch.unsqueeze(
                            feat_source[feat_ind[pair1]][pair2].reshape(-1), 0)
                        dist_source[pair1, tmpc] = sim(anchor_feat,
                                                       compare_feat)
                        tmpc += 1
            dist_source = sfm(dist_source)

        # computing distances among target generations
        _, feat_target = generator([z], return_feats=True)
        dist_target = torch.zeros(
            [args.feat_const_batch, args.feat_const_batch - 1]).cuda()

        # iterating over different elements in the batch
        for pair1 in range(args.feat_const_batch):
            tmpc = 0
            for pair2 in range(
                    args.feat_const_batch):  # comparing the possible pairs
                if pair1 != pair2:
                    anchor_feat = torch.unsqueeze(
                        feat_target[feat_ind[pair1]][pair1].reshape(-1), 0)
                    compare_feat = torch.unsqueeze(
                        feat_target[feat_ind[pair1]][pair2].reshape(-1), 0)
                    dist_target[pair1, tmpc] = sim(anchor_feat, compare_feat)
                    tmpc += 1
        dist_target = sfm(dist_target)
        rel_loss = args.kl_wt * \
            kl_loss(torch.log(dist_target), dist_source) # distance consistency loss
        g_loss = g_loss + rel_loss

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        # free intermediate tensors to save GPU memory
        del rel_loss, g_loss, d_loss, fake_img, fake_pred, real_img, real_pred, anchor_feat, compare_feat, dist_source, dist_target, feat_source, feat_target

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema_module, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.img_freq == 0:
                with torch.set_grad_enabled(False):
                    g_ema.eval()
                    sample, _ = g_ema([sample_z.data])
                    sample_subz, _ = g_ema([sub_region_z.data])
                    utils.save_image(
                        sample,
                        f"%s/{str(i).zfill(6)}.png" % (imsave_path),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )
                    del sample

            if (i % args.save_freq == 0) and (i > 0):
                torch.save(
                    {
                        "g_ema": g_ema.state_dict(),
                        # uncomment the following lines only if you wish to resume training after saving. Otherwise, saving just the generator is sufficient for evaluations

                        #"g": g_module.state_dict(),
                        #"g_s": g_source.state_dict(),
                        #"d": d_module.state_dict(),
                        #"g_optim": g_optim.state_dict(),
                        #"d_optim": d_optim.state_dict(),
                    },
                    f"%s/{str(i).zfill(6)}.pt" % (model_path),
                )
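
# The two nested pair loops above build, for each anchor image, a softmax over
# its cosine similarities to every other batch element, then match the target
# distribution to the source one with a KL term. A vectorized sketch of the
# same computation (illustrative; `feats` stands for one feat_source[k] level
# flattened to [B, D]):
import torch
import torch.nn.functional as F

def pairwise_sim_softmax(feats):
    # feats: [B, D] -> softmax over off-diagonal pairwise cosine similarities, [B, B-1]
    B = feats.size(0)
    sim = F.cosine_similarity(feats.unsqueeze(1), feats.unsqueeze(0), dim=2)  # [B, B]
    off_diag = sim[~torch.eye(B, dtype=torch.bool, device=feats.device)].view(B, B - 1)
    return F.softmax(off_diag, dim=1)

# dist_source = pairwise_sim_softmax(source_feats); dist_target likewise; then
# rel_loss = kl_wt * F.kl_div(dist_target.log(), dist_source, reduction='batchmean')
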
def train(
    args,
    loader,
    encoder,
    generator,
    discriminator,
    discriminator3d,  # video discriminator
    posterior,
    prior,
    factor,  # a learnable matrix
    vggnet,
    e_optim,
    d_optim,
    dv_optim,
    q_optim,  # q for posterior
    p_optim,  # p for prior
    f_optim,  # f for factor
    e_ema,
    device
):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {"d": torch.tensor(0., device=device),
                 "real_score": torch.tensor(0., device=device),
                 "fake_score": torch.tensor(0., device=device),
                 "r1_d": torch.tensor(0., device=device),
                 "r1_e": torch.tensor(0., device=device),
                 "rec": torch.tensor(0., device=device),}

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    latent_full = args.latent_full
    factor_dim_full = args.factor_dim_full

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

    sample_x = accumulate_batches(loader, args.n_sample).to(device)
    utils.save_image(
        sample_x.view(-1, *list(sample_x.shape)[2:]),
        os.path.join(args.log_dir, 'sample', f"real-img.png"),
        nrow=sample_x.shape[1],
        normalize=True,
        value_range=(-1, 1),
    )
    util.save_video(
        sample_x[0],
        os.path.join(args.log_dir, 'sample', f"real-vid.mp4")
    )

    requires_grad(generator, False)  # always False
    generator.eval()  # Generator should be ema and in eval mode
    if args.no_update_encoder:
        encoder = e_ema if e_ema is not None else encoder
        requires_grad(encoder, False)
        encoder.eval()
    from models.networks_3d import GANLoss
    criterionGAN = GANLoss()
    # criterionL1 = nn.L1Loss()

    # if args.no_ema or e_ema is None:
    #     e_ema = encoder
    
    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        data = next(loader)
        real_seq = data['frames']
        real_seq = real_seq.to(device)  # [N, T, C, H, W]
        shape = list(real_seq.shape)
        N, T = shape[:2]

        # Train Encoder with frame-level objectives
        if args.toggle_grads:
            if not args.no_update_encoder:
                requires_grad(encoder, True)
            requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = rec_loss = vid_loss = l1y_loss = torch.tensor(0., device=device)

        # TODO: real_seq -> encoder -> posterior -> generator -> fake_seq
        # f: [N, latent_full]; y: [N, T, D]
        fake_img, fake_seq, y_post = reconstruct_sequence(args, real_seq, encoder, generator, factor, posterior, i, ret_y=True)
        # if args.debug == 'no_lstm':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))
        #     fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # elif args.debug == 'decomp':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))  # [N*T, latent_full]
        #     f_post = real_lat[::T, ...]
        #     z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1)
        #     if args.use_multi_head:
        #         y_post = []
        #         for z, w in zip(torch.split(z_post, 512, 2), factor.weight):
        #             y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1))
        #         y_post = torch.cat(y_post, 2)
        #     else:
        #         y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1)
        #     z_post_hat = factor(y_post)
        #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
        #     w_post = f_expand + z_post_hat
        #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # else:
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))
        #     # single head: f_post [N, latent_full]; y_post [N, T, D]
        #     # multi head: f_post [N, n_latent, latent]; y_post [N, T, n_latent, d]
        #     f_post, y_post = posterior(real_lat.view(N, T, latent_full))
        #     z_post = factor(y_post)
        #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
        #     w_post = f_expand + z_post  # shape [N, T, latent_full]
        #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])

        # TODO: sample frames
        real_img = real_seq.view(N*T, *shape[2:])
        # fake_img = fake_seq.view(N*T, *shape[2:])

        if args.lambda_adv > 0:
            if args.augment:
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                fake_img_aug = fake_img
            fake_pred = discriminator(fake_img_aug)
            adv_loss = g_nonsaturating_loss(fake_pred)

        # TODO: do we always put pix and vgg loss for all frames?
        if args.lambda_pix > 0:
            pix_loss = torch.mean((real_img - fake_img) ** 2)

        if args.lambda_vgg > 0:
            real_feat = vggnet(real_img)
            fake_feat = vggnet(fake_img)
            vgg_loss = torch.mean((real_feat - fake_feat) ** 2)
        
        # Train Encoder with video-level objectives
        # TODO: video adversarial loss
        if args.lambda_vid > 0:
            fake_pred = discriminator3d(flip_video(fake_seq.transpose(1, 2)))
            vid_loss = criterionGAN(fake_pred, True)
        
        if args.lambda_l1y > 0:
            # l1y_loss = criterionL1(y_post)
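            # an L1 sparsity penalty on the motion codes y_post, written out directly: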
            l1y_loss = torch.mean(torch.abs(y_post))

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        e_loss = e_loss + args.lambda_vid * vid_loss + args.lambda_l1y * l1y_loss
        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss
        
        if not args.no_update_encoder:
            encoder.zero_grad()
        posterior.zero_grad()
        e_loss.backward()
        q_optim.step()
        if not args.no_update_encoder:
            e_optim.step()

        # if args.train_on_fake:
        #     e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0
        #     if e_regularize and args.lambda_rec > 0:
        #         noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        #         fake_img, latent_fake = generator(noise, input_is_latent=False, return_latents=True)
        #         latent_pred = encoder(fake_img)
        #         if latent_pred.ndim < 3:
        #             latent_pred = latent_pred.unsqueeze(1).repeat(1, latent_fake.size(1), 1)
        #         rec_loss = torch.mean((latent_fake - latent_pred) ** 2)
        #         encoder.zero_grad()
        #         (rec_loss * args.lambda_rec).backward()
        #         e_optim.step()
        #         loss_dict["rec"] = rec_loss

        # e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0
        # if e_regularize:
        #     # why not regularize on augmented real?
        #     real_img.requires_grad = True
        #     real_pred = encoder(real_img)
        #     r1_loss_e = d_r1_loss(real_pred, real_img)

        #     encoder.zero_grad()
        #     (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward()
        #     e_optim.step()

        #     loss_dict["r1_e"] = r1_loss_e

        if not args.no_update_encoder:
            if not args.no_ema and e_ema is not None:
                accumulate(e_ema, e_module, accum)
        
        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
        fake_img, fake_seq = reconstruct_sequence(args, real_seq, encoder, generator, factor, posterior)
        # if args.debug == 'no_lstm':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))
        #     fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # elif args.debug == 'decomp':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))  # [N*T, latent_full]
        #     f_post = real_lat[::T, ...]
        #     z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1)
        #     if args.use_multi_head:
        #         y_post = []
        #         for z, w in zip(torch.split(z_post, 512, 2), factor.weight):
        #             y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1))
        #         y_post = torch.cat(y_post, 2)
        #     else:
        #         y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1)
        #     z_post_hat = factor(y_post)
        #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
        #     w_post = f_expand + z_post_hat
        #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # elif args.debug == 'coef':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))  # [N*T, latent_full]
        #     f_post = real_lat[::T, ...]
        #     z_post_hat = real_lat.view(N, T, -1) - f_post.unsqueeze(1)
        #     y_post = torch.mm(z_post_hat.view(N*T, -1), factor.weight).view(N, T, -1)
        #     z_post = factor(y_post)
        #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
        #     w_post = f_expand + z_post
        #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # else:
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))
        #     f_post, y_post = posterior(real_lat.view(N, T, latent_full))
        #     z_post = factor(y_post)
        #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
        #     w_post = f_expand + z_post  # shape [N, T, latent_full]
        #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # fake_img = fake_seq.view(N*T, *shape[2:])
        if not args.no_update_discriminator:
            # make sure d_loss is defined even when lambda_adv == 0
            d_loss = torch.tensor(0.0, device=device)
            if args.lambda_adv > 0:
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                    fake_img_aug, _ = augment(fake_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                    fake_img_aug = fake_img
                
                fake_pred = discriminator(fake_img_aug)
                real_pred = discriminator(real_img_aug)
                d_loss = d_logistic_loss(real_pred, fake_pred)

            # Train video discriminator
            if args.lambda_vid > 0:
                pred_real = discriminator3d(flip_video(real_seq.transpose(1, 2)))
                pred_fake = discriminator3d(flip_video(fake_seq.transpose(1, 2)))
                dv_loss_real = criterionGAN(pred_real, True)
                dv_loss_fake = criterionGAN(pred_fake, False)
                dv_loss = 0.5 * (dv_loss_real + dv_loss_fake)
                d_loss = d_loss + dv_loss

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            if args.lambda_adv > 0:
                discriminator.zero_grad()
            if args.lambda_vid > 0:
                discriminator3d.zero_grad()
            d_loss.backward()
            if args.lambda_adv > 0:
                d_optim.step()
            if args.lambda_vid > 0:
                dv_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat
            
            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                # why not regularize on augmented real?
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss_d = d_r1_loss(real_pred, real_img)

                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward()
                # Why 0* ? Answer is here https://github.com/rosinality/stylegan2-pytorch/issues/76
                d_optim.step()

                loss_dict["r1_d"] = r1_loss_d

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        r1_e_val = loss_reduced["r1_e"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; "
                    f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                    f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"
                )
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        "Encoder": e_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1 D": r1_d_val,
                        "R1 E": r1_e_val,
                        "Pix Loss": pix_loss_val,
                        "VGG Loss": vgg_loss_val,
                        "Adv Loss": adv_loss_val,
                        "Rec Loss": rec_loss_val,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                    }
                )

            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = encoder if args.no_ema else e_ema
                    e_eval.eval()
                    posterior.eval()
                    # N = sample_x.shape[0]
                    fake_img, fake_seq = reconstruct_sequence(args, sample_x, e_eval, generator, factor, posterior)
                    # if args.debug == 'no_lstm':
                    #     real_lat = encoder(sample_x.view(-1, *shape[2:]))
                    #     fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False)
                    #     fake_seq = fake_img.view(N, T, *shape[2:])
                    # elif args.debug == 'decomp':
                    #     real_lat = encoder(sample_x.view(-1, *shape[2:]))  # [N*T, latent_full]
                    #     f_post = real_lat[::T, ...]
                    #     z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1)
                    #     if args.use_multi_head:
                    #         y_post = []
                    #         for z, w in zip(torch.split(z_post, 512, 2), factor.weight):
                    #             y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1))
                    #         y_post = torch.cat(y_post, 2)
                    #     else:
                    #         y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1)
                    #     z_post_hat = factor(y_post)
                    #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
                    #     w_post = f_expand + z_post_hat
                    #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
                    #     fake_seq = fake_img.view(N, T, *shape[2:])
                    # else:
                    #     x_lat = encoder(sample_x.view(-1, *shape[2:]))
                    #     f_post, y_post = posterior(x_lat.view(N, T, latent_full))
                    #     z_post = factor(y_post)
                    #     f_expand = f_post.unsqueeze(1).expand(-1, T, -1)
                    #     w_post = f_expand + z_post
                    #     fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False)
                    #     fake_seq = fake_img.view(N, T, *shape[2:])
                    utils.save_image(
                        torch.cat((sample_x, fake_seq), 1).view(-1, *shape[2:]),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-img_recon.png"),
                        nrow=T,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    util.save_video(
                        fake_seq[random.randint(0, args.n_sample-1)],
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-vid_recon.mp4")
                    )
                    fake_img, fake_seq = swap_sequence(args, sample_x, e_eval, generator, factor, posterior)
                    utils.save_image(
                        torch.cat((sample_x, fake_seq), 1).view(-1, *shape[2:]),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-img_swap.png"),
                        nrow=T,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    e_eval.train()
                    posterior.train()

            if i % args.save_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )
            
            if not args.debug and i % args.save_latest_every == 0:
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    os.path.join(args.log_dir, 'weight', f"latest.pt"),
                )
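
# Several of these examples call AdaptiveAugment(...).tune(real_pred). A sketch
# of that helper, consistent with the manual ADA block in the earlier example
# (reduce_sum is the same distributed helper used there):
import torch

class AdaptiveAugment:
    def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
        self.ada_aug_target = ada_aug_target
        self.ada_aug_len = ada_aug_len
        self.update_every = update_every
        self.ada_update = 0
        self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
        self.r_t_stat = 0
        self.ada_aug_p = 0

    @torch.no_grad()
    def tune(self, real_pred):
        # accumulate (sum of prediction signs, number of predictions)
        self.ada_aug_buf += torch.tensor(
            (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
            device=real_pred.device,
        )
        self.ada_update += 1
        if self.ada_update % self.update_every == 0:
            self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
            pred_signs, n_pred = self.ada_aug_buf.tolist()
            self.r_t_stat = pred_signs / n_pred
            sign = 1 if self.r_t_stat > self.ada_aug_target else -1
            self.ada_aug_p = min(1, max(0, self.ada_aug_p + sign * n_pred / self.ada_aug_len))
            self.ada_aug_buf.mul_(0)
            self.ada_update = 0
        return self.ada_aug_p
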
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device)

    args.n_sheets = int(np.ceil(args.n_classes / args.n_class_per_sheet))
    args.n_sample_per_sheet = args.n_sample_per_class * args.n_class_per_sheet
    args.n_sample = args.n_sample_per_sheet * args.n_sheets
    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_y = torch.arange(args.n_classes).repeat(args.n_sample_per_class, 1).t().reshape(-1).to(device)
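    # e.g. n_classes=3, n_sample_per_class=2 -> sample_y = [0, 0, 1, 1, 2, 2]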
    if args.n_sample > args.n_sample_per_class * args.n_classes:
        sample_y1 = make_fake_label(args.n_sample - args.n_sample_per_class * args.n_classes, args.n_classes, device)
        sample_y = torch.cat([sample_y, sample_y1], 0)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        for step_index in range(args.n_step_d):
            real_img, real_labels = next(loader)
            real_img, real_labels = real_img.to(device), real_labels.to(device)

            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_labels = make_fake_label(args.batch, args.n_classes, device)
            fake_img, _ = generator(noise, fake_labels)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)

            else:
                real_img_aug = real_img

            fake_pred = discriminator(fake_img, fake_labels)
            real_pred = discriminator(real_img_aug, real_labels)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(real_img_aug, real_labels)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_labels = make_fake_label(args.batch, args.n_classes, device)
        fake_img, _ = generator(noise, fake_labels)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img, fake_labels)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_labels = make_fake_label(path_batch_size, args.n_classes, device)
            fake_img, latents = generator(noise, fake_labels, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update G_ema
        # G_ema = G * (1-ema_beta) + G_ema * ema_beta
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5 ** (args.batch / max(ema_nimg, 1e-8))
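        # e.g. batch=32, ema_kimg=10 after ramp-up: accum = 0.5 ** (32 / 10000) ≈ 0.9978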
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                    f"augment: {ada_aug_p:.4f}"
                )
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    }
                )
            
            if i % args.log_every == 0:
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        (
                            f"{i:07d}; "
                            f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                            f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                            f"augment: {ada_aug_p:.4f};\n"
                        )
                    )

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    for sheet_index in range(args.n_sheets):
                        sample_z_sheet = sample_z[sheet_index*args.n_sample_per_sheet:(sheet_index+1)*args.n_sample_per_sheet]
                        sample_y_sheet = sample_y[sheet_index*args.n_sample_per_sheet:(sheet_index+1)*args.n_sample_per_sheet]
                        sample, _ = g_ema([sample_z_sheet], sample_y_sheet)
                        utils.save_image(
                            sample,
                            os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}_{sheet_index}.png"),
                            nrow=args.n_sample_per_class,
                            normalize=True,
                            value_range=(-1, 1),
                        )
            
            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, args.device,
                        n_classes=args.n_classes,
                    ).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                # print("fid:", fid)
                with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                    f.write(f"{i:07d}; fid: {float(fid):.4f};\n")

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )
            
            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"latest.pt"),
                )
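
# make_fake_label is not shown in this example; a minimal sketch of the assumed
# behavior (uniformly random class ids for conditional sampling):
import torch

def make_fake_label(batch, n_classes, device):
    return torch.randint(0, n_classes, (batch,), device=device)
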
Exemple #19
0
def train(opt):
    lib.print_model_settings(locals().copy())

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset, train_dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        sampler=data_sampler(train_dataset, shuffle=True, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(train_dataset_log)
    print('-' * 80)

    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        sampler=data_sampler(valid_dataset, shuffle=False, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
    
    opt.num_class = len(converter.character)

    
    # styleModel = StyleTensorEncoder(input_dim=opt.input_channel)
    # genModel = AdaIN_Tensor_WordGenerator(opt)
    # disModel = MsImageDisV2(opt)

    # styleModel = StyleLatentEncoder(input_dim=opt.input_channel, norm='none')
    # mixModel = Mixer(opt,nblk=3, dim=opt.latent)
    genModel = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device)
    disModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel).to(device)
    g_ema = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device)
    ocrModel = ModelV1(opt).to(device)
    accumulate(g_ema, genModel, 0)

    # #  weight initialization
    # for currModel in [styleModel, mixModel]:
    #     for name, param in currModel.named_parameters():
    #         if 'localization_fc2' in name:
    #             print(f'Skip {name} as it is already initialized')
    #             continue
    #         try:
    #             if 'bias' in name:
    #                 init.constant_(param, 0.0)
    #             elif 'weight' in name:
    #                 init.kaiming_normal_(param)
    #         except Exception as e:  # for batchnorm.
    #             if 'weight' in name:
    #                 param.data.fill_(1)
    #             continue

    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
        ocrCriterion = torch.nn.L1Loss()
    else:
        if 'CTC' in opt.Prediction:
            ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
        else:
            ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    # vggRecCriterion = torch.nn.L1Loss()
    # vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True), vggRecCriterion)
    
    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length)

    if opt.distributed:
        genModel = torch.nn.parallel.DistributedDataParallel(
            genModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
        )
        
        disModel = torch.nn.parallel.DistributedDataParallel(
            disModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
        )
        ocrModel = torch.nn.parallel.DistributedDataParallel(
            ocrModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False
        )
    
    # styleModel = torch.nn.DataParallel(styleModel).to(device)
    # styleModel.train()
    
    # mixModel = torch.nn.DataParallel(mixModel).to(device)
    # mixModel.train()
    
    # genModel = torch.nn.DataParallel(genModel).to(device)
    # g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()

    # disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    # vggModel = torch.nn.DataParallel(vggModel).to(device)
    # vggModel.eval()

    # ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    # if opt.distributed:
    #     ocrModel.module.Transformation.eval()
    #     ocrModel.module.FeatureExtraction.eval()
    #     ocrModel.module.AdaptiveAvgPool.eval()
    #     # ocrModel.module.SequenceModeling.eval()
    #     ocrModel.module.Prediction.eval()
    # else:
    #     ocrModel.Transformation.eval()
    #     ocrModel.FeatureExtraction.eval()
    #     ocrModel.AdaptiveAvgPool.eval()
    #     # ocrModel.SequenceModeling.eval()
    #     ocrModel.Prediction.eval()
    ocrModel.eval()

    if opt.distributed:
        g_module = genModel.module
        d_module = disModel.module
    else:
        g_module = genModel
        d_module = disModel

    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    optimizer = optim.Adam(
        genModel.parameters(),
        lr=opt.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )
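
    # lazy regularization: the R1 / path-length penalties are applied only every
    # *_reg_every steps, so lr and the Adam betas are rescaled by the reg ratios
    # to keep the effective optimizer dynamics unchanged (StyleGAN2, Appendix B)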

    ## Loading pre-trained files
    if opt.modelFolderFlag:
        # pick the most recently written checkpoint (plain glob order is arbitrary)
        ckpts = sorted(glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth")),
                       key=os.path.getmtime)
        if len(ckpts) > 0:
            opt.saved_synth_model = ckpts[-1]

    if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None':
        if not opt.distributed:
            ocrModel = torch.nn.DataParallel(ocrModel)
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)
        #temporary fix
        if not opt.distributed:
            ocrModel = ocrModel.module
    
    if opt.saved_gen_model !='' and opt.saved_gen_model !='None':
        print(f'loading pretrained gen model from {opt.saved_gen_model}')
        checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage)
        genModel.module.load_state_dict(checkpoint['g'])
        g_ema.module.load_state_dict(checkpoint['g_ema'])

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)
        
        # styleModel.load_state_dict(checkpoint['styleModel'])
        # mixModel.load_state_dict(checkpoint['mixModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disModel.load_state_dict(checkpoint['disModel'])
        
        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])

    # if opt.imgReconLoss == 'l1':
    #     recCriterion = torch.nn.L1Loss()
    # elif opt.imgReconLoss == 'ssim':
    #     recCriterion = ssim
    # elif opt.imgReconLoss == 'ms-ssim':
    #     recCriterion = msssim
    

    # loss averager
    loss_avg = Averager()
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_vgg_per = Averager()
    loss_avg_vgg_sty = Averager()
    loss_avg_ocr = Averager()

    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()

    """ final options """
    with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    
    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (IndexError, ValueError):
            pass

    
    #get schedulers
    scheduler = get_scheduler(optimizer,opt)
    dis_scheduler = get_scheduler(dis_optimizer,opt)

    start_time = time.time()
    iteration = start_iter
    cntr=0
    
    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0

    sample_z = torch.randn(opt.n_sample, opt.latent, device=device)

    while True:
        # train part
        if opt.lr_policy != "None":
            scheduler.step()
            dis_scheduler.step()

        # Python 3 fix: iterator.next() does not exist; use next(iter(...)).
        # Note: building a fresh DataLoader iterator every step re-shuffles the
        # data but is slow; kept here to preserve the original control flow.
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, False)
        # requires_grad(styleModel, False)
        # requires_grad(mixModel, False)
        requires_grad(disModel, True)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)
        
        
        #forward pass from style and word generator
        # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        # use the actual batch size so a smaller final batch does not break the forward pass
        style = mixing_noise(batch_size, opt.latent, opt.mixing, device)
        # scInput = mixModel(style,text_2)
        if 'CTC' in opt.Prediction:
            images_recon_2,_ = genModel(style, text_2, input_is_latent=opt.input_latent)
        else:
            images_recon_2,_ = genModel(style, text_2[:,1:-1], input_is_latent=opt.input_latent)
        
        #Domain discriminator: Dis update
        if opt.augment:
            image_gt_tensors_aug, _ = augment(image_gt_tensors, ada_aug_p)
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)

        else:
            image_gt_tensors_aug = image_gt_tensors

        fake_pred = disModel(images_recon_2)
        real_pred = disModel(image_gt_tensors_aug)
        disCost = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = disCost*opt.disWeight
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        loss_avg_dis.add(disCost)

        disModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        if opt.augment and opt.augment_p == 0:
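            # ADA heuristic: accumulate the sign of the real logits; once enough
            # predictions are collected, nudge ada_aug_p toward opt.ada_target.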
            ada_augment_data = torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device
            )
            # reduce the per-step statistics across ranks before accumulating;
            # reducing the running buffer itself would re-count old statistics
            ada_augment += reduce_sum(ada_augment_data)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > opt.ada_target:
                    sign = 1

                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = cntr % opt.d_reg_every == 0

        if d_regularize:
            # Lazy R1 regularization: penalize the squared gradient of the real
            # logits w.r.t. the real images every d_reg_every steps only.
            image_gt_tensors.requires_grad = True
            real_pred = disModel(image_gt_tensors)

            r1_loss = d_r1_loss(real_pred, image_gt_tensors)

            disModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()

            dis_optimizer.step()

        loss_dict["r1"] = r1_loss

        
        # #[Style Encoder] + [Word Generator] update
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))
        
        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, True)
        # requires_grad(styleModel, True)
        # requires_grad(mixModel, True)
        requires_grad(disModel, False)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        # scInput = mixModel(style,text_2)

        # images_recon_2,_ = genModel([scInput], input_is_latent=opt.input_latent)
        style = mixing_noise(batch_size, opt.latent, opt.mixing, device)
        
        if 'CTC' in opt.Prediction:
            images_recon_2, _ = genModel(style, text_2)
        else:
            images_recon_2, _ = genModel(style, text_2[:,1:-1])

        if opt.augment:
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)

        fake_pred = disModel(images_recon_2)
        disGenCost = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = disGenCost

        # # #Adversarial loss
        # # disGenCost = disModel.module.calc_gen_loss(torch.cat((images_recon_2,image_input_tensors),dim=1))

        # #Input reconstruction loss
        # recCost = recCriterion(images_recon_2,image_gt_tensors)

        # #vgg loss
        # vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)
        #ocr loss
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
            preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            preds_gt = ocrModel(image_gt_tensors, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            ocrCost = ocrCriterion(preds_recon, preds_gt)
        else:
            if 'CTC' in opt.Prediction:
                
                preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False)
                # preds_o = preds_recon[:, :text_1.shape[1], :]
                preds_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_2, preds_size, length_2)
                
                #predict ocr recognition on generated images
                # preds_recon_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                _, preds_recon_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_recon_index.data, preds_size.data)

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                # preds_s = preds_s[:, :text_1.shape[1] - 1, :]
                preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index.data, preds_s_size.data)

                #predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                # preds_sc = preds_sc[:, :text_2.shape[1] - 1, :]
                preds_sc_size = torch.IntTensor([preds_sc.size(1)] * batch_size)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index.data, preds_sc_size.data)

            else:
                preds_recon = ocrModel(images_recon_2, text_for_pred[:, :-1], is_train=False)  # align with Attention.forward
                target_2 = text_2[:, 1:]  # without [GO] Symbol
                ocrCost = ocrCriterion(preds_recon.view(-1, preds_recon.shape[-1]), target_2.contiguous().view(-1))

                #predict ocr recognition on generated images
                _, preds_o_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_o_index, length_for_pred)
                for idx, pred in enumerate(labels_o_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_o_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index, length_for_pred)
                for idx, pred in enumerate(labels_s_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_s_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                
                #predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index, length_for_pred)
                for idx, pred in enumerate(labels_sc_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_sc_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

        # cost =  opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.vggPerWeight*vggPerCost + opt.vggStyWeight*vggStyleCost + opt.ocrWeight*ocrCost
        cost =  opt.disWeight*disGenCost + opt.ocrWeight*ocrCost

        # styleModel.zero_grad()
        genModel.zero_grad()
        # mixModel.zero_grad()
        disModel.zero_grad()
        # vggModel.zero_grad()
        ocrModel.zero_grad()
        
        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        g_regularize = cntr % opt.g_reg_every == 0

        if g_regularize:
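            # Lazy path-length regularization on the generator (StyleGAN2): keep
            # the Jacobian norm close to its running mean on a reduced batch.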
            image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))
        
            image_input_tensors = image_input_tensors.to(device)
            image_gt_tensors = image_gt_tensors.to(device)
            batch_size = image_input_tensors.size(0)

            text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
            text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

            path_batch_size = max(1, batch_size // opt.path_batch_shrink)

            # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
            # scInput = mixModel(style,text_2)

            # images_recon_2, latents = genModel([scInput],input_is_latent=opt.input_latent, return_latents=True)

            style = mixing_noise(path_batch_size, opt.latent, opt.mixing, device)
            
            
            if 'CTC' in opt.Prediction:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size], return_latents=True)
            else:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size,1:-1], return_latents=True)
            
            
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                images_recon_2, latents, mean_path_length
            )

            genModel.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * images_recon_2[0, 0, 0, 0]

            weighted_path_loss.backward()

            optimizer.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()


        #Individual losses
        loss_avg_gen.add(opt.disWeight*disGenCost)
        # image-reconstruction and VGG losses are disabled in this configuration;
        # log zeros to keep the averagers and plots aligned
        loss_avg_imgRecon.add(torch.tensor(0.0))
        loss_avg_vgg_per.add(torch.tensor(0.0))
        loss_avg_vgg_sty.add(torch.tensor(0.0))
        loss_avg_ocr.add(opt.ocrWeight*ocrCost)

        log_r1_val.add(loss_reduced["r1"])
        log_avg_path_loss_val.add(loss_reduced["path"])
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))
        
        if get_rank() == 0:
            # pbar.set_description(
            #     (
            #         f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
            #         f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
            #         f"augment: {ada_aug_p:.4f}"
            #     )
            # )

            if wandb and opt.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,   
                        "Path Length": path_length_val,
                    }
                )
            # if cntr % 100 == 0:
            #     with torch.no_grad():
            #         g_ema.eval()
            #         sample, _ = g_ema([scInput[:,:opt.latent],scInput[:,opt.latent:]])
            #         utils.save_image(
            #             sample,
            #             os.path.join(opt.trainDir, f"sample_{str(cntr).zfill(6)}.png"),
            #             nrow=int(opt.n_sample ** 0.5),
            #             normalize=True,
            #             range=(-1, 1),
            #         )


        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 
            
            #Save training images
            curr_batch_size = style[0].shape[0]
            images_recon_2, _ = g_ema(style, text_2[:curr_batch_size], input_is_latent=opt.input_latent)
            
            os.makedirs(os.path.join(opt.trainDir,str(iteration)), exist_ok=True)
            # images_recon_2 may use the reduced path-regularization batch, so
            # iterate over the smaller of the two batch sizes
            for trImgCntr in range(min(batch_size, curr_batch_size)):
                try:
                    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
                        save_image(tensor2im(image_input_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_sInput_'+labels_1[trImgCntr]+'.png'))
                        save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csGT_'+labels_2[trImgCntr]+'.png'))
                        save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csRecon_'+labels_2[trImgCntr]+'.png'))
                    else:
                        save_image(tensor2im(image_input_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_sInput_'+labels_1[trImgCntr]+'_'+labels_s_ocr[trImgCntr]+'.png'))
                        save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csGT_'+labels_2[trImgCntr]+'_'+labels_sc_ocr[trImgCntr]+'.png'))
                        save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csRecon_'+labels_2[trImgCntr]+'_'+labels_o_ocr[trImgCntr]+'.png'))
                except Exception:
                    print('Warning while saving training image')
            
            elapsed_time = time.time() - start_time
            # for log
            
            with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log:
                # styleModel.eval()
                genModel.eval()
                g_ema.eval()
                # mixModel.eval()
                disModel.eval()
                
                with torch.no_grad():                    
                    valid_loss, infer_time, length_of_data = validation_synth_v6(
                        iteration, g_ema, ocrModel, disModel, ocrCriterion, valid_loader, converter, opt)
                
                # styleModel.train()
                genModel.train()
                # mixModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = (
                    f'[{iteration+1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, '
                    f'Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f}, '
                    f'Train OCR loss: {loss_avg_ocr.val():0.5f}, '
                    f'Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, '
                    f'Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, '
                    f'Valid Synth loss: {valid_loss[0]:0.5f}, '
                    f'Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, '
                    f'Valid OCR loss: {valid_loss[6]:0.5f}, Elapsed_time: {elapsed_time:0.5f}')
                
                
                #plotting
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Synth-Loss'), loss_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item())
                
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Train-VGG-Per-Loss'), loss_avg_vgg_per.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Train-VGG-Sty-Loss'), loss_avg_vgg_sty.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-OCR-Loss'), loss_avg_ocr.val().item())

                lib.plot.plot(os.path.join(opt.plotDir,'Train-r1_val'), log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-ada_aug_p'), log_ada_aug_p.val().item())

                lib.plot.plot(os.path.join(opt.plotDir,'Valid-Synth-Loss'), valid_loss[0].item())
                lib.plot.plot(os.path.join(opt.plotDir,'Valid-Dis-Loss'), valid_loss[1].item())

                lib.plot.plot(os.path.join(opt.plotDir,'Valid-Gen-Loss'), valid_loss[2].item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Valid-ImgRecon1-Loss'), valid_loss[3].item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Valid-VGG-Per-Loss'), valid_loss[4].item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Valid-VGG-Sty-Loss'), valid_loss[5].item())
                lib.plot.plot(os.path.join(opt.plotDir,'Valid-OCR-Loss'), valid_loss[6].item())
                
                print(loss_log)

                loss_avg.reset()
                loss_avg_dis.reset()

                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_vgg_per.reset()
                loss_avg_vgg_sty.reset()
                loss_avg_ocr.reset()

                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()
                

            lib.plot.flush()

        lib.plot.tick()

        # save the model every 10000 iterations
        if iteration % 10000 == 0:
            torch.save({
                # 'styleModel':styleModel.state_dict(),
                # 'mixModel':mixModel.state_dict(),
                'genModel':g_module.state_dict(),
                'g_ema':g_ema.state_dict(),
                'disModel':d_module.state_dict(),
                'optimizer':optimizer.state_dict(),
                'dis_optimizer':dis_optimizer.state_dict()}, 
                os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth'))
            

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr += 1
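
For reference: the Averager helper used above is not defined in this example. A minimal sketch consistent with the calls above (add/val/reset, with val() returning a 0-dim tensor so .item() and float formatting both work) might look like the following; the repository's actual class may differ.

import torch

class Averager:
    """Running mean of scalar losses (sketch; the real implementation may differ)."""

    def __init__(self):
        self.reset()

    def add(self, v):
        # Accept 0-dim tensors or plain floats.
        self.total += float(v)
        self.count += 1

    def val(self):
        # Return a 0-dim tensor so callers can use .item(), as in the plotting code above.
        return torch.tensor(self.total / max(self.count, 1))

    def reset(self):
        self.total = 0.0
        self.count = 0
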
def train(args, loader, generator, discriminator, optimizer, g_ema, device):
    ckpt_dir = 'checkpoints/stylegan-acgd'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    fig_dir = 'figs/stylegan-acgd'
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)
    loader = sample_data(loader)
    pbar = range(args.iter)
    pbar = tqdm(pbar,
                initial=args.start_iter,
                dynamic_ncols=True,
                smoothing=0.01)
    mean_path_length = 0

    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    if args.gpu_num > 1:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator
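    # EMA decay: with 32 images per step this gives the generator average a
    # half-life of roughly 10k images (the StyleGAN2 default schedule).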
    accum = 0.5**(32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)

        d_loss = d_logistic_loss(real_pred, fake_pred)
        # d_loss = fake_pred.mean() - real_pred.mean()
        loss_dict["loss"] = d_loss.item()
        loss_dict["real_score"] = real_pred.mean().item()
        loss_dict["fake_score"] = fake_pred.mean().item()

        # lazy R1 regularization is switched off in this variant; the original
        # schedule condition is kept for reference
        # d_regularize = i % args.d_reg_every == 0
        d_regularize = False
        if d_regularize:
            real_img_cp = real_img.clone().detach()
            real_img_cp.requires_grad = True
            real_pred_cp = discriminator(real_img_cp)
            r1_loss = d_r1_loss(real_pred_cp, real_img_cp)
            d_loss += args.r1 / 2 * r1_loss * args.d_reg_every
        loss_dict["r1"] = r1_loss.item()

        # path-length regularization is likewise switched off here
        # g_regularize = i % args.g_reg_every == 0
        g_regularize = False
        if g_regularize:  # TODO adapt code for nn.DataParallel
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            d_loss += weighted_path_loss
            mean_path_length_avg = mean_path_length.item()

        loss_dict["path"] = path_loss.mean().item()
        loss_dict["path_length"] = path_lengths.mean().item()

        optimizer.step(d_loss)
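        # Note: this ACGD-style competitive optimizer takes the joint loss and
        # updates both generator and discriminator in a single step, so there
        # are no separate g_optim/d_optim calls in this variant.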
        # update ada_aug_p
        if args.augment and args.augment_p == 0:
            ada_augment_data = torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment += ada_augment_data
            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()
                r_t_stat = pred_signs / n_pred
                if r_t_stat > args.ada_target:
                    sign = 1
                else:
                    sign = -1
                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        accumulate(g_ema, g_module, accum)

        d_loss_val = loss_dict["loss"]
        r1_val = loss_dict['r1']
        path_loss_val = loss_dict["path"]
        real_score_val = loss_dict["real_score"]
        fake_score_val = loss_dict["fake_score"]
        path_length_val = loss_dict["path_length"]

        pbar.set_description((
            f"d: {d_loss_val:.4f}; g: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
            f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
            f"augment: {ada_aug_p:.4f}"))
        if wandb and args.wandb:
            wandb.log({
                "Generator": d_loss_val,
                "Discriminator": d_loss_val,
                "Augment": ada_aug_p,
                "Rt": r_t_stat,
                "R1": r1_val,
                "Path Length Regularization": path_loss_val,
                "Mean Path Length": mean_path_length,
                "Real Score": real_score_val,
                "Fake Score": fake_score_val,
                "Path Length": path_length_val,
            })
        if i % 100 == 0:
            with torch.no_grad():
                g_ema.eval()
                sample, _ = g_ema([sample_z])
                utils.save_image(
                    sample,
                    f"figs/stylegan-acgd/{str(i).zfill(6)}.png",
                    nrow=int(args.n_sample**0.5),
                    normalize=True,
                    range=(-1, 1),
                )
        if i % 1000 == 0:
            torch.save(
                {
                    "g": g_module.state_dict(),
                    "d": d_module.state_dict(),
                    "g_ema": g_ema.state_dict(),
                    "d_optim": optimizer.state_dict(),
                    "args": args,
                    "ada_aug_p": ada_aug_p,
                },
                f"checkpoints/stylegan-acgd/{str(i).zfill(6)}.pt",
            )
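
Several of these examples call d_logistic_loss and g_nonsaturating_loss. In the StyleGAN2 reference implementation these are the logistic adversarial losses in softplus form; a self-contained sketch under that assumption:

import torch.nn.functional as F

def d_logistic_loss(real_pred, fake_pred):
    # Push real logits up and fake logits down via the softplus form of the
    # logistic loss: log(1 + exp(-real)) + log(1 + exp(fake)).
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()

def g_nonsaturating_loss(fake_pred):
    # Non-saturating generator loss: minimize log(1 + exp(-D(fake))).
    return F.softplus(-fake_pred).mean()
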
Exemple #21
0
def train(args, loader, loader2, generator, encoder, discriminator,
          discriminator2, vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema,
          e_ema, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d2_module = None
    if discriminator2 is not None:
        if args.distributed:
            d2_module = discriminator2.module
        else:
            d2_module = discriminator2

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    r_t_dict = {'real': 0, 'fake': 0, 'recx': 0}  # r_t stat
    real_diff = fake_diff = count = 0
    g_scale = 1
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length,
                                      args.ada_every, device)
    if args.decouple_d and args.augment:
        ada_aug_p2 = args.augment_p if args.augment_p > 0 else 0.0
        # r_t_stat2 = 0
        if args.augment_p == 0:
            ada_augment2 = AdaptiveAugment(args.ada_target, args.ada_length,
                                           args.ada_every, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    sample_x1 = sample_x2 = sample_idx = fid_batch_idx = None
    if sample_x.ndim > 4:
        sample_x1 = sample_x[:, 0, ...]
        sample_x2 = sample_x[:, -1, ...]
        sample_x = sample_x[:, 0, ...]

    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(encoder, False)
        requires_grad(discriminator, True)
        for step_index in range(args.n_step_d):
            real_img = real_imgs[step_index]
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            if args.use_ema:
                g_ema.eval()
                fake_img, _ = g_ema(noise)
            else:
                fake_img, _ = generator(noise)
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img_aug)
            d_loss_fake = F.softplus(fake_pred).mean()
            d_loss_real = F.softplus(-real_pred).mean()
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            d_loss_rec = 0.
            if args.lambda_rec_d > 0 and not args.decouple_d:  # Do not train D on x_rec if decouple_d
                if args.use_ema:
                    e_ema.eval()
                    g_ema.eval()
                    latent_real, _ = e_ema(real_img)
                    rec_img, _ = g_ema([latent_real], input_is_latent=True)
                else:
                    latent_real, _ = encoder(real_img)
                    rec_img, _ = generator([latent_real], input_is_latent=True)
                if args.augment:
                    rec_img, _ = augment(rec_img, ada_aug_p)
                rec_pred = discriminator(rec_img)
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["recx_score"] = rec_pred.mean()

            d_loss = d_loss_real + d_loss_fake * args.lambda_fake_d + d_loss_rec * args.lambda_rec_d
            loss_dict["d"] = d_loss

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat
        # Compute batchwise r_t
        r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch
        r_t_dict['fake'] = torch.sign(fake_pred).sum().item() / args.batch

        with torch.no_grad():
            if args.lambda_rec_d > 0 and not args.decouple_d:
                # rec_pred is only defined when D was trained on reconstructions
                real_diff += torch.mean(real_pred - rec_pred).item()
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            x_fake, _ = generator(noise)
            x_recf, _ = generator([encoder(x_fake)[0]], input_is_latent=True)
            recf_pred = discriminator(x_recf)
            fake_pred = discriminator(x_fake)
            fake_diff += torch.mean(fake_pred - recf_pred).item()
            count += 1

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        # Train Discriminator2
        if args.decouple_d and discriminator2 is not None:
            requires_grad(generator, False)
            requires_grad(encoder, False)
            requires_grad(discriminator2, True)
            for step_index in range(
                    args.n_step_e):  # n_step_d2 is same as n_step_e
                real_img = real_imgs[step_index]
                if args.use_ema:
                    e_ema.eval()
                    g_ema.eval()
                    latent_real, _ = e_ema(real_img)
                    rec_img, _ = g_ema([latent_real], input_is_latent=True)
                else:
                    latent_real, _ = encoder(real_img)
                    rec_img, _ = generator([latent_real], input_is_latent=True)
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p2)
                    rec_img, _ = augment(rec_img, ada_aug_p2)
                else:
                    real_img_aug = real_img
                rec_pred = discriminator2(rec_img)
                real_pred = discriminator2(real_img_aug)
                d2_loss_rec = F.softplus(rec_pred).mean()
                d2_loss_real = F.softplus(-real_pred).mean()

                d2_loss = d2_loss_real + d2_loss_rec
                loss_dict["d2"] = d2_loss
                loss_dict["recx_score"] = rec_pred.mean()

                discriminator2.zero_grad()
                d2_loss.backward()
                d2_optim.step()

                real_diff += torch.mean(real_pred - rec_pred).item()

            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                real_img.requires_grad = True
                real_pred = discriminator2(real_img)
                r1_loss = d_r1_loss(real_pred, real_img)
                discriminator2.zero_grad()
                (args.r1 / 2 * r1_loss * args.d_reg_every +
                 0 * real_pred[0]).backward()
                d2_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p2 = ada_augment2.tune(rec_pred)
                # r_t_stat2 = ada_augment2.r_t_stat

        if args.lambda_rec_d > 0 or args.decouple_d:
            # rec_pred exists only when some discriminator saw reconstructions
            r_t_dict['recx'] = torch.sign(rec_pred).sum().item() / args.batch

        # Train Encoder
        requires_grad(encoder, True)
        requires_grad(generator, args.train_ge)
        requires_grad(discriminator, False)
        requires_grad(discriminator2, False)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        for step_index in range(args.n_step_e):
            real_img = real_imgs[step_index]
            latent_real, _ = encoder(real_img)
            if args.use_ema:
                g_ema.eval()
                rec_img, _ = g_ema([latent_real], input_is_latent=True)
            else:
                rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img)**2)
                elif args.pix_loss == 'l1':
                    pix_loss = F.l1_loss(rec_img, real_img)
                else:
                    raise NotImplementedError
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)
            if args.lambda_adv > 0:
                if not args.decouple_d:
                    if args.augment:
                        rec_img_aug, _ = augment(rec_img, ada_aug_p)
                    else:
                        rec_img_aug = rec_img
                    rec_pred = discriminator(rec_img_aug)
                else:
                    if args.augment:
                        rec_img_aug, _ = augment(rec_img, ada_aug_p2)
                    else:
                        rec_img_aug = rec_img
                    rec_pred = discriminator2(rec_img_aug)
                adv_loss = g_nonsaturating_loss(rec_pred)

            e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
            loss_dict["e"] = e_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            if args.train_ge:
                encoder.zero_grad()
                generator.zero_grad()
                e_loss.backward()
                e_optim.step()
                if args.g_decay < 1:
                    manually_scale_grad(generator, g_scale)
                    g_scale *= args.g_decay
                g_optim.step()
            else:
                encoder.zero_grad()
                e_loss.backward()
                e_optim.step()

        # Train Generator
        requires_grad(generator, True)
        requires_grad(encoder, False)
        requires_grad(discriminator, False)
        requires_grad(discriminator2, False)
        real_img = real_imgs[0]
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)
        fake_pred = discriminator(fake_img)
        g_loss_fake = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = g_loss_fake
        generator.zero_grad()
        g_loss_fake.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update EMA
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
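        # Ramp the EMA horizon up from zero early in training so g_ema/e_ema
        # track the live networks closely at first, then settle to ema_kimg*1000.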
        accum = 0.5**(args.batch / max(ema_nimg, 1e-8))
        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        recx_score_val = loss_reduced["recx_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"
            ))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = generator([latent_x],
                                          input_is_latent=True,
                                          return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; ref: {sample_pix_loss.item():.4f}; "
                        f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; "
                        f"real_score: {real_score_val:.4f}; fake_score: {fake_score_val:.4f}; recx_score: {recx_score_val:.4f}; "
                        f"real_diff: {real_diff/count:.4f}; fake_diff: {fake_diff/count:.4f};\n"
                    ))
                real_diff = fake_diff = count = 0

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    fid_sa = fid_re = fid_hy = 0
                    # Sample FID
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64,
                        args.n_sample_fid, args.device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid_sa = calc_fid(sample_mean, sample_cov, real_mean,
                                      real_cov)
                    # Recon FID
                    features = extract_feature_from_reconstruction(
                        e_ema,
                        g_ema,
                        inception,
                        args.truncation,
                        mean_latent,
                        loader2,
                        args.device,
                        mode='recon',
                    ).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid_re = calc_fid(sample_mean, sample_cov, real_mean,
                                      real_cov)
                    # Hybrid FID
                    if args.eval_hybrid:
                        features = extract_feature_from_reconstruction(
                            e_ema,
                            g_ema,
                            inception,
                            args.truncation,
                            mean_latent,
                            loader2,
                            args.device,
                            mode='hybrid',  # shuffle_idx=fid_batch_idx
                        ).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_hy = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                # print("Sample FID:", fid_sa, "Recon FID:", fid_re, "Hybrid FID:", fid_hy)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(
                        f"{i:07d}; sample fid: {float(fid_sa):.4f}; recon fid: {float(fid_re):.4f}; hybrid fid: {float(fid_hy):.4f};\n"
                    )

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    # Fixed fake samples
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Reconstruction samples
                    latent_real, _ = e_ema(sample_x)
                    fake_img, _ = g_ema([latent_real],
                                        input_is_latent=True,
                                        return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Hybrid samples: [real_y1, real_y2; real_x1, fake_x2]
                    if args.eval_hybrid:
                        w1, _ = e_ema(sample_x1)
                        w2, _ = e_ema(sample_x2)
                        dw = w2 - w1
                        dw = torch.cat(
                            dw.chunk(2, 0)[::-1],
                            0) if sample_idx is None else dw[sample_idx, ...]
                        fake_img, _ = g_ema([w1 + dw],
                                            input_is_latent=True,
                                            return_latents=False)
                        drive = torch.cat(
                            (torch.cat(sample_x1.chunk(2, 0)[::-1], 0).reshape(
                                args.n_sample, 1, *nchw),
                             torch.cat(sample_x2.chunk(2, 0)[::-1], 0).reshape(
                                 args.n_sample, 1, *nchw)), 1)
                        source = torch.cat(
                            (sample_x1.reshape(args.n_sample, 1, *nchw),
                             fake_img.reshape(args.n_sample, 1, *nchw)), 1)
                        sample = torch.cat(
                            (drive.reshape(args.n_sample // nrow, 2 * nrow, *
                                           nchw),
                             source.reshape(args.n_sample // nrow, 2 * nrow, *
                                            nchw)), 1)
                        utils.save_image(
                            sample.reshape(4 * args.n_sample, *nchw),
                            os.path.join(args.log_dir, 'sample',
                                         f"{str(i).zfill(6)}-cross.png"),
                            nrow=2 * nrow,
                            normalize=True,
                            value_range=(-1, 1),
                        )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', 'latest.pt'),
                )
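
The R1 penalty d_r1_loss used throughout is, in the StyleGAN2 reference code, the squared gradient norm of the real logits with respect to the real images; a sketch under that assumption:

from torch import autograd

def d_r1_loss(real_pred, real_img):
    # Gradient of the summed real logits w.r.t. the (requires_grad) real images;
    # create_graph=True so the penalty itself can be backpropagated.
    (grad_real,) = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
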
Exemple #22
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device, save_dir):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_augment_data = torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment += reduce_sum(ada_augment_data)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > args.ada_target:
                    sign = 1

                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 1000 == 0:  # save some samples

                with torch.no_grad():
                    g_ema.eval()

                    sample, _ = g_ema([sample_z])

                    utils.save_image(
                        sample,
                        save_dir + f"/samples/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % 2000 == 0:  # save the model
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    save_dir + f"/checkpoints/{str(i).zfill(6)}.pt",
                )
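
The inline sign-counting update in the loop above is what the AdaptiveAugment helper used by the later examples encapsulates. A minimal sketch of such a helper, assuming the constructor signature AdaptiveAugment(ada_aug_target, ada_aug_len, update_every, device) seen at the call sites (the upstream implementation in rosinality's stylegan2-pytorch may differ in detail):

import torch


class AdaptiveAugment:
    def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
        self.ada_aug_target = ada_aug_target  # target r_t, e.g. 0.6
        self.ada_aug_len = ada_aug_len  # images over which p can sweep 0 -> 1
        self.update_every = update_every  # predictions per adjustment
        self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
        self.r_t_stat = 0.0
        self.ada_aug_p = 0.0

    @torch.no_grad()
    def tune(self, real_pred):
        # Accumulate (sum of prediction signs, number of predictions).
        self.ada_aug_buf += torch.tensor(
            (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
            device=real_pred.device,
        )
        if self.ada_aug_buf[1] >= self.update_every:
            pred_signs, n_pred = self.ada_aug_buf.tolist()
            self.r_t_stat = pred_signs / n_pred
            sign = 1 if self.r_t_stat > self.ada_aug_target else -1
            self.ada_aug_p += sign * n_pred / self.ada_aug_len
            self.ada_aug_p = min(1.0, max(0.0, self.ada_aug_p))
            self.ada_aug_buf.mul_(0)
        return self.ada_aug_p
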
Exemple #23
0
def train(args, loader, drs_loader, generator, discriminator,
          drs_discriminator, g_optim, d_optim, drs_d_optim, g_ema, device,
          output_path):
    # loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
        drs_d_module = drs_discriminator.module
    else:
        g_module = generator
        d_module = discriminator
        drs_d_module = drs_discriminator

    logit_results = defaultdict(dict)

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    iter_dataloader = iter(loader)
    iter_drs_dataloader = iter(drs_loader)
    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        try:
            real_img, _ = next(iter_dataloader)
        except StopIteration:
            iter_dataloader = iter(loader)
            real_img, _ = next(iter_dataloader)

        try:
            drs_real_img, _ = next(iter_drs_dataloader)
        except StopIteration:
            iter_drs_dataloader = iter(drs_loader)
            drs_real_img, _ = next(iter_drs_dataloader)

        real_img = real_img.to(device)
        drs_real_img = drs_real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)
        requires_grad(drs_discriminator, True)
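        # The second ("drs") discriminator trains on its own data stream,
        # presumably to support discriminator rejection sampling later on.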

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            drs_real_img_aug, _ = augment(drs_real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img
            drs_real_img_aug = drs_real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)

        drs_fake_pred = drs_discriminator(fake_img)
        drs_real_pred = drs_discriminator(drs_real_img_aug)

        d_loss = d_logistic_loss(real_pred, fake_pred)
        drs_d_loss = d_logistic_loss(drs_real_pred, drs_fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["drs_d"] = drs_d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        drs_discriminator.zero_grad()
        drs_d_loss.backward()
        drs_d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

            drs_real_img.requires_grad = True
            drs_real_pred = drs_discriminator(drs_real_img)
            drs_r1_loss = d_r1_loss(drs_real_pred, drs_real_img)

            drs_discriminator.zero_grad()
            (args.r1 / 2 * drs_r1_loss * args.d_reg_every +
             0 * drs_real_pred[0]).backward()

            drs_d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)
        requires_grad(drs_discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        drs_d_loss_val = loss_reduced["drs_d"].mean().item()

        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0 and i > 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; drs_d: {drs_d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    save_path = output_path / 'fixed_sample'
                    save_path.mkdir(parents=True, exist_ok=True)
                    utils.save_image(
                        sample,
                        save_path / f"{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

                    random_sample_z = torch.randn(args.n_sample,
                                                  args.latent,
                                                  device=device)
                    sample, _ = g_ema([random_sample_z])
                    save_path = output_path / 'random_sample'
                    save_path.mkdir(parents=True, exist_ok=True)
                    utils.save_image(
                        sample,
                        save_path / f"{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % 5000 == 0:
                save_path = output_path / 'checkpoint'
                save_path.mkdir(parents=True, exist_ok=True)
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "drs_d": drs_d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "drs_d_optim": drs_d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    save_path / f"{str(i).zfill(6)}.pt",
                )
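
Every loop in these examples leans on the same small loss helpers. For reference, minimal definitions consistent with rosinality's stylegan2-pytorch (assumed here; the snippets only import them):

from torch import autograd
from torch.nn import functional as F


def d_logistic_loss(real_pred, fake_pred):
    # Logistic GAN loss for D: softplus(-real) + softplus(fake).
    real_loss = F.softplus(-real_pred)
    fake_loss = F.softplus(fake_pred)
    return real_loss.mean() + fake_loss.mean()


def g_nonsaturating_loss(fake_pred):
    # Non-saturating logistic loss for G: softplus(-fake).
    return F.softplus(-fake_pred).mean()


def d_r1_loss(real_pred, real_img):
    # R1 penalty: squared gradient norm of D(real) w.r.t. the real images.
    grad_real, = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

Exemple #24
0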
def train(args, loader, encoder, generator, discriminator, discriminator_z, g1,
          vggnet, pwcnet, e_optim, d_optim, dz_optim, g1_optim, e_ema, e_tf,
          g1_ema, device):
    mmd_eval = functools.partial(mix_rbf_mmd2,
                                 sigma_list=[2.0, 5.0, 10.0, 20.0, 40.0, 80.0])

    loader = sample_data(loader)
    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {
        "d": torch.tensor(0., device=device),
        "real_score": torch.tensor(0., device=device),
        "fake_score": torch.tensor(0., device=device),
        "r1_d": torch.tensor(0., device=device),
        "r1_e": torch.tensor(0., device=device),
        "rec": torch.tensor(0., device=device),
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
        g1_module = g1.module if args.train_latent_mlp else None
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator
        g1_module = g1 if args.train_latent_mlp else None

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    # sample_x = accumulate_batches(loader, args.n_sample).to(device)
    sample_x = load_real_samples(args, loader)

    requires_grad(generator, False)  # always False
    generator.eval()  # Generator should be ema and in eval mode

    # if args.no_ema or e_ema is None:
    #     e_ema = encoder

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        batch = real_img.shape[0]

        # Train Encoder
        if args.toggle_grads:
            requires_grad(encoder, True)
            requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0.,
                                                                 device=device)
        kld_z = torch.tensor(0., device=device)
        mmd_z = torch.tensor(0., device=device)
        gan_z = torch.tensor(0., device=device)
        etf_z = torch.tensor(0., device=device)
        latent_real, logvar = encoder(real_img)
        if args.reparameterization:
            latent_real = reparameterize(latent_real, logvar)
        if args.train_latent_mlp:
            fake_img, _ = generator([g1(latent_real)],
                                    input_is_latent=True,
                                    return_latents=False)
        else:
            fake_img, _ = generator([latent_real],
                                    input_is_latent=False,
                                    return_latents=False)

        if args.lambda_adv > 0:
            if args.augment:
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                fake_img_aug = fake_img
            fake_pred = discriminator(fake_img_aug)
            adv_loss = g_nonsaturating_loss(fake_pred)

        if args.lambda_pix > 0:
            pix_loss = torch.mean((real_img - fake_img)**2)

        if args.lambda_vgg > 0:
            real_feat = vggnet(real_img)
            fake_feat = vggnet(fake_img)
            vgg_loss = torch.mean((real_feat - fake_feat)**2)

        if args.lambda_kld_z > 0:
            z_mean = latent_real.view(batch, -1)
            kld_z = -0.5 * torch.sum(1. + logvar - z_mean.pow(2) -
                                     logvar.exp()) / batch
            # print(kld_z)

        if args.lambda_mmd_z > 0:
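            # RBF-kernel MMD between the encoder's latents and draws from a
            # standard normal, pushing q(z) toward N(0, I).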
            z_real = torch.randn(batch, args.latent_full, device=device)
            mmd_z = mmd_eval(latent_real, z_real)
            # print(mmd_z)

        if args.lambda_gan_z > 0:
            fake_pred = discriminator_z(latent_real)
            gan_z = g_nonsaturating_loss(fake_pred)
            # print(gan_z)

        if args.use_latent_teacher_forcing and args.lambda_etf > 0:
            w_tf, _ = e_tf(real_img)
            if args.train_latent_mlp:
                w_pred = g1(latent_real)
            else:
                w_pred = generator.get_latent(latent_real)
            etf_z = torch.mean((w_tf - w_pred)**2)
            # print(etf_z)

        if args.train_on_fake and args.lambda_rec > 0:
            z_real = torch.randn(args.batch, args.latent_full, device=device)
            if args.train_latent_mlp:
                fake_img, _ = generator([g1(z_real)],
                                        input_is_latent=True,
                                        return_latents=False)
            else:
                fake_img, _ = generator([z_real],
                                        input_is_latent=False,
                                        return_latents=False)
            # fake_img, _ = generator([z_real], input_is_latent=False, return_latents=True)
            z_fake, z_logvar = encoder(fake_img)
            if args.reparameterization:
                z_fake = reparameterize(z_fake, z_logvar)
            rec_loss = torch.mean((z_real - z_fake)**2)
            loss_dict["rec"] = rec_loss
            # print(rec_loss)

        e_loss = (pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg +
                  adv_loss * args.lambda_adv)
        e_loss = (e_loss + args.lambda_kld_z * kld_z +
                  args.lambda_mmd_z * mmd_z + args.lambda_gan_z * gan_z +
                  args.lambda_etf * etf_z + rec_loss * args.lambda_rec)

        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss

        if args.train_latent_mlp and g1 is not None:
            g1.zero_grad()
        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()
        if args.train_latent_mlp and g1_optim is not None:
            g1_optim.step()

        # if args.train_on_fake:
        #     e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0
        #     if e_regularize and args.lambda_rec > 0:
        #         # noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        #         # fake_img, latent_fake = generator(noise, input_is_latent=False, return_latents=True)
        #         z_real = torch.randn(args.batch, args.latent_full, device=device)
        #         fake_img, w_real = generator([z_real], input_is_latent=False, return_latents=True)
        #         z_fake, logvar = encoder(fake_img)
        #         if args.reparameterization:
        #             z_fake = reparameterize(z_fake, logvar)
        #         rec_loss = torch.mean((z_real - z_fake) ** 2)
        #         encoder.zero_grad()
        #         (rec_loss * args.lambda_rec).backward()
        #         e_optim.step()
        #         loss_dict["rec"] = rec_loss

        e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0
        if e_regularize:
            # why not regularize on augmented real?
            real_img.requires_grad = True
            real_pred, logvar = encoder(real_img)
            if args.reparameterization:
                real_pred = reparameterize(real_pred, logvar)
            r1_loss_e = d_r1_loss(real_pred, real_img)

            encoder.zero_grad()
            (args.r1 / 2 * r1_loss_e * args.e_reg_every +
             0 * real_pred.view(-1)[0]).backward()
            e_optim.step()

            loss_dict["r1_e"] = r1_loss_e

        if not args.no_ema and e_ema is not None:
            accumulate(e_ema, e_module, accum)
        if args.train_latent_mlp:
            accumulate(g1_ema, g1_module, accum)

        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
        if not args.no_update_discriminator and args.lambda_adv > 0:
            latent_real, logvar = encoder(real_img)
            if args.reparameterization:
                latent_real = reparameterize(latent_real, logvar)
            if args.train_latent_mlp:
                fake_img, _ = generator([g1(latent_real)],
                                        input_is_latent=True,
                                        return_latents=False)
            else:
                fake_img, _ = generator([latent_real],
                                        input_is_latent=False,
                                        return_latents=False)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
                fake_img_aug = fake_img

            fake_pred = discriminator(fake_img_aug)
            real_pred = discriminator(real_img_aug)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

            # Keep the latent discriminator's scores in separate variables so
            # the ADA tuner below still sees the image discriminator's
            # real_pred (the original code shadowed it here).
            z_real = torch.randn(batch, args.latent_full, device=device)
            fake_pred_z = discriminator_z(latent_real.detach())
            real_pred_z = discriminator_z(z_real)
            d_loss_z = d_logistic_loss(real_pred_z, fake_pred_z)
            discriminator_z.zero_grad()
            d_loss_z.backward()
            dz_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat

            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                # why not regularize on augmented real?
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss_d = d_r1_loss(real_pred, real_img)

                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss_d * args.d_reg_every +
                 0 * real_pred.view(-1)[0]).backward()
                # Why the 0*? See https://github.com/rosinality/stylegan2-pytorch/issues/76
                d_optim.step()

                loss_dict["r1_d"] = r1_loss_d

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        r1_e_val = loss_reduced["r1_e"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    if args.train_latent_mlp:
                        g1_ema.eval()
                        fake_x, _ = generator([g1_ema(latent_x)],
                                              input_is_latent=True,
                                              return_latents=False)
                    else:
                        fake_x, _ = generator([latent_x],
                                              input_is_latent=False,
                                              return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                        f"ref: {sample_pix_loss.item()};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Encoder": e_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1 D": r1_d_val,
                    "R1 E": r1_e_val,
                    "Pix Loss": pix_loss_val,
                    "VGG Loss": vgg_loss_val,
                    "Adv Loss": adv_loss_val,
                    "Rec Loss": rec_loss_val,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = encoder if args.no_ema else e_ema
                    e_eval.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_eval(sample_x)
                    if args.train_latent_mlp:
                        g1_ema.eval()
                        fake_img, _ = generator([g1_ema(latent_real)],
                                                input_is_latent=True,
                                                return_latents=False)
                    else:
                        fake_img, _ = generator([latent_real],
                                                input_is_latent=False,
                                                return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    e_eval.train()

            if i % args.save_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g1": (g1_module.state_dict()
                               if args.train_latent_mlp else None),
                        "g1_ema": (g1_ema.state_dict()
                                   if args.train_latent_mlp else None),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g1": (g1_module.state_dict()
                               if args.train_latent_mlp else None),
                        "g1_ema": (g1_ema.state_dict()
                                   if args.train_latent_mlp else None),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
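
Two more helpers this loop assumes: the EMA update behind accumulate and the VAE-style reparameterization used when args.reparameterization is set. Minimal sketches, with signatures inferred from the call sites above:

import torch


def accumulate(model1, model2, decay=0.999):
    # Exponential moving average of model2's parameters into model1;
    # the loops above pass decay = 0.5 ** (32 / (10 * 1000)).
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)


def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I), keeping the sample
    # differentiable w.r.t. mu and logvar.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
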
Exemple #25
0
def train(args, loader, encoder, generator, discriminator, vggnet, pwcnet,
          e_optim, d_optim, e_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {
        "d": torch.tensor(0., device=device),
        "real_score": torch.tensor(0., device=device),
        "fake_score": torch.tensor(0., device=device),
        "r1_d": torch.tensor(0., device=device),
        "r1_e": torch.tensor(0., device=device),
        "rec": torch.tensor(0., device=device),
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    sample_x = accumulate_batches(loader, args.n_sample).to(device)

    requires_grad(generator, False)  # always False
    generator.eval()  # Generator should be ema and in eval mode

    # if args.no_ema or e_ema is None:
    #     e_ema = encoder

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img1, real_img2 = next(loader)
        real_img1 = real_img1.to(device)
        real_img2 = real_img2.to(device)

        # Train Encoder
        if args.toggle_grads:
            requires_grad(encoder, True)
            requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0.,
                                                                 device=device)
        latent_real1 = encoder(real_img1)
        fake_img1, _ = generator([latent_real1],
                                 input_is_latent=True,
                                 return_latents=False)
        latent_real2 = encoder(real_img2)
        fake_img2, _ = generator([latent_real2],
                                 input_is_latent=True,
                                 return_latents=False)

        if args.lambda_adv > 0:
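            # D sees either the residual between the reconstruction and the
            # second image, or the two images stacked along channels.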
            if args.use_residual:
                fake_img_pair = fake_img1 - real_img2
            else:
                fake_img_pair = torch.cat((fake_img1, real_img2), 1)
            if args.augment:
                fake_img_aug, _ = augment(fake_img_pair, ada_aug_p)
            else:
                fake_img_aug = fake_img_pair
            fake_pred = discriminator(fake_img_aug)
            adv_loss = g_nonsaturating_loss(fake_pred)

        if args.lambda_pix > 0:
            pix_loss = torch.mean((real_img1 - fake_img1)**2)

        if args.lambda_vgg > 0:
            real_feat = vggnet(real_img1)
            fake_feat = vggnet(fake_img1)
            vgg_loss = torch.mean((real_feat - fake_feat)**2)

        e_loss = (pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg +
                  adv_loss * args.lambda_adv)

        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss

        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()

        # if args.train_on_fake:
        #     e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0
        #     if e_regularize and args.lambda_rec > 0:
        #         noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        #         fake_img, latent_fake = generator(noise, input_is_latent=False, return_latents=True)
        #         latent_pred = encoder(fake_img)
        #         if latent_pred.ndim < 3:
        #             latent_pred = latent_pred.unsqueeze(1).repeat(1, latent_fake.size(1), 1)
        #         rec_loss = torch.mean((latent_fake - latent_pred) ** 2)
        #         encoder.zero_grad()
        #         (rec_loss * args.lambda_rec).backward()
        #         e_optim.step()
        #         loss_dict["rec"] = rec_loss

        e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0
        if e_regularize:
            # why not regularize on augmented real?
            real_img_pair = torch.cat((real_img1, real_img2), 1)
            real_img_pair.requires_grad = True
            real_pred = encoder(real_img_pair)
            r1_loss_e = d_r1_loss(real_pred, real_img_pair)

            encoder.zero_grad()
            (args.r1 / 2 * r1_loss_e * args.e_reg_every +
             0 * real_pred.view(-1)[0]).backward()
            e_optim.step()

            loss_dict["r1_e"] = r1_loss_e

        if not args.no_ema and e_ema is not None:
            accumulate(e_ema, e_module, accum)

        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
        if not args.no_update_discriminator and args.lambda_adv > 0:
            latent_real1 = encoder(real_img1)
            fake_img1, _ = generator([latent_real1],
                                     input_is_latent=True,
                                     return_latents=False)

            if args.use_residual:
                fake_img_pair = fake_img1 - real_img2
                real_img_pair = real_img1 - real_img2
            else:
                fake_img_pair = torch.cat((fake_img1, real_img2), 1)
                real_img_pair = torch.cat((real_img1, real_img2), 1)

            if args.augment:
                real_img_aug, _ = augment(real_img_pair, ada_aug_p)
                fake_img_aug, _ = augment(fake_img_pair, ada_aug_p)
            else:
                real_img_aug = real_img_pair
                fake_img_aug = fake_img_pair

            fake_pred = discriminator(fake_img_aug)
            real_pred = discriminator(real_img_aug)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat

            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                # why not regularize on augmented real?
                if args.use_residual:
                    real_img_pair = real_img1 - real_img2
                else:
                    real_img_pair = torch.cat((real_img1, real_img2), 1)
                real_img_pair.requires_grad = True
                real_pred = discriminator(real_img_pair)
                r1_loss_d = d_r1_loss(real_pred, real_img_pair)

                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss_d * args.d_reg_every +
                 0 * real_pred.view(-1)[0]).backward()
                # Why the 0*? See https://github.com/rosinality/stylegan2-pytorch/issues/76
                d_optim.step()

                loss_dict["r1_d"] = r1_loss_d

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        r1_e_val = loss_reduced["r1_e"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img1.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img1.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x = e_ema(sample_x)
                    fake_x, _ = generator([latent_x],
                                          input_is_latent=True,
                                          return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                        f"ref: {sample_pix_loss.item()};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Encoder": e_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1 D": r1_d_val,
                    "R1 E": r1_e_val,
                    "Pix Loss": pix_loss_val,
                    "VGG Loss": vgg_loss_val,
                    "Adv Loss": adv_loss_val,
                    "Rec Loss": rec_loss_val,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = encoder if args.no_ema else e_ema
                    e_eval.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real = e_eval(sample_x)
                    fake_img, _ = generator([latent_real],
                                            input_is_latent=True,
                                            return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    e_eval.train()

            if i % args.save_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
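
Finally, the latent sampling and path-length pieces every generator update above calls into. Sketches consistent with rosinality's stylegan2-pytorch (assumed, not shown in these examples):

import math
import random

import torch
from torch import autograd


def make_noise(batch, latent_dim, n_noise, device):
    # One latent code, or a stack of n_noise codes for style mixing.
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)
    return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)


def mixing_noise(batch, latent_dim, prob, device):
    # With probability prob, return two latents to mix; otherwise one.
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)
    return [make_noise(batch, latent_dim, 1, device)]


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    # Path length regularization: penalize deviation of |J^T y| from its
    # running mean, with y ~ N(0, I / (H * W)).
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    grad, = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    path_mean = mean_path_length + decay * (
        path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths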