def train(args, loader, generator, encoder, discriminator, vggnet, g_optim,
          e_optim, d_optim, g_ema, e_ema, device):
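    """Joint GAN + autoencoder training loop (StyleGAN2-style).

    Per iteration: (1) train D on real, generated, and reconstructed images;
    (2) train E on pixel/VGG/adversarial reconstruction losses; (3) train G on
    the same reconstruction losses plus its own adversarial loss. R1 and
    path-length regularization are applied lazily; EMA copies (g_ema, e_ema)
    are updated every step.
    """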
    kwargs_d = {'detach_aux': False}
    if args.dataset == 'imagefolder':
        loader = sample_data2(loader)
    else:
        loader = sample_data(loader)

    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    else:
        inception = real_mean = real_cov = None
    mean_latent = None

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    # EMA decay with a half-life of 10k images at batch size 32
    # (the stylegan2-pytorch default)
    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(encoder, False)
        requires_grad(discriminator, True)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        latent_real, _ = encoder(real_img)
        rec_img, _ = generator([latent_real], input_is_latent=True)
        real_pred = discriminator(real_img)
        fake_pred = discriminator(fake_img)
        rec_pred = discriminator(rec_img)
        d_loss_real = F.softplus(-real_pred).mean()
        d_loss_fake = F.softplus(fake_pred).mean()
        d_loss_rec = F.softplus(rec_pred).mean()
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()
        loss_dict["rec_score"] = rec_pred.mean()

        d_loss = d_loss_real + d_loss_fake + d_loss_rec
        loss_dict["d"] = d_loss

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
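            # The 0 * real_pred[0] term keeps D's output tied into the backward
            # graph (a common trick to avoid unused-output errors under DDP)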
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        # # Train Encoder and Generator
        # requires_grad(generator, True)
        # requires_grad(encoder, True)
        # requires_grad(discriminator, False)
        # pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        # noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        # fake_img, _ = generator(noise)
        # latent_real, _ = encoder(real_img)
        # rec_img, _ = generator([latent_real], input_is_latent=True)
        # fake_pred = discriminator(fake_img)
        # rec_pred = discriminator(rec_img)
        # g_loss_fake = g_nonsaturating_loss(fake_pred)
        # g_loss_rec = g_nonsaturating_loss(rec_pred)
        # adv_loss = g_loss_fake + g_loss_rec
        # if args.lambda_pix > 0:
        #     if args.pix_loss == 'l2':
        #         pix_loss = torch.mean((rec_img - real_img) ** 2)
        #     else:
        #         pix_loss = F.l1_loss(rec_img, real_img)
        # if args.lambda_vgg > 0:
        #     vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img)) ** 2)
        # e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        # loss_dict["e"] = e_loss
        # encoder.zero_grad()
        # generator.zero_grad()
        # e_loss.backward()
        # e_optim.step()
        # g_optim.step()

        # Train Encoder
        requires_grad(generator, False)
        requires_grad(encoder, True)
        requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        latent_real, _ = encoder(real_img)
        rec_img, _ = generator([latent_real], input_is_latent=True)
        rec_pred = discriminator(rec_img)
        g_loss_rec = g_nonsaturating_loss(rec_pred)
        adv_loss = g_loss_rec
        if args.lambda_pix > 0:
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((rec_img - real_img)**2)
            else:
                pix_loss = F.l1_loss(rec_img, real_img)
        if args.lambda_vgg > 0:
            vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv

        loss_dict["e"] = e_loss
        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()

        # Train Generator
        requires_grad(generator, True)
        requires_grad(encoder, False)
        requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        latent_real, _ = encoder(real_img)
        rec_img, _ = generator([latent_real], input_is_latent=True)
        fake_pred = discriminator(fake_img)
        rec_pred = discriminator(rec_img)
        g_loss_fake = g_nonsaturating_loss(fake_pred)
        g_loss_rec = g_nonsaturating_loss(rec_pred)
        adv_loss = g_loss_fake + g_loss_rec
        if args.lambda_pix > 0:
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((rec_img - real_img)**2)
            else:
                pix_loss = F.l1_loss(rec_img, real_img)
        if args.lambda_vgg > 0:
            vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)

        g_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv

        loss_dict["g"] = g_loss
        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
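            # Dummy term ties fake_img into the graph when path_batch_shrink
            # is used (same DDP-safety trick as in the R1 step)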
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        with torch.no_grad():
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((rec_img - real_img)**2)
            else:
                pix_loss = F.l1_loss(rec_img, real_img)
            vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)
            pix_loss_val = pix_loss.mean().item()
            vgg_loss_val = vgg_loss.mean().item()

        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = generator([latent_x],
                                          input_is_latent=True,
                                          return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                        f"ref: {sample_pix_loss.item()};\n")

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64,
                        args.n_sample_fid, args.device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean,
                                   real_cov)
                print("fid:", fid)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(f"{i:07d}: fid: {float(fid):.4f}\n")

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    # Fixed fake samples
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Reconstruction samples
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_ema(sample_x)
                    fake_img, _ = g_ema([latent_real],
                                        input_is_latent=True,
                                        return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
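
# For reference, minimal sketches of the loss/EMA helpers used above, assuming
# the standard stylegan2-pytorch formulations (the repo's own versions may
# differ in details):

def d_r1_loss_sketch(real_pred, real_img):
    # R1 penalty: squared gradient norm of D's output w.r.t. the real images
    grad_real, = torch.autograd.grad(outputs=real_pred.sum(),
                                     inputs=real_img,
                                     create_graph=True)
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()


def g_nonsaturating_loss_sketch(fake_pred):
    # Non-saturating generator loss: softplus(-D(G(z)))
    return F.softplus(-fake_pred).mean()


def g_path_regularize_sketch(fake_img, latents, mean_path_length, decay=0.01):
    # Path-length regularization: penalize deviation of |J^T y| from its
    # running mean, with y unit-variance image-space noise
    noise = torch.randn_like(fake_img) / (
        fake_img.shape[2] * fake_img.shape[3])**0.5
    grad, = torch.autograd.grad(outputs=(fake_img * noise).sum(),
                                inputs=latents,
                                create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths


def accumulate_sketch(ema_model, model, decay=0.999):
    # EMA update: ema <- decay * ema + (1 - decay) * model
    par1 = dict(ema_model.named_parameters())
    par2 = dict(model.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)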
Example #2
    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, 512, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g = nn.DataParallel(g)
    g.eval()

    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g.mean_latent(args.truncation_mean)

    else:
        mean_latent = None

    inception = nn.DataParallel(load_patched_inception_v3()).to(device)
    inception.eval()

    features = extract_feature_from_samples(g, inception, args.truncation,
                                            mean_latent, args.batch,
                                            args.n_sample, device).numpy()
    print(f"extracted {features.shape[0]} features")

    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    with open(args.inception, "rb") as f:
        embeds = pickle.load(f)
        real_mean = embeds["mean"]
        real_cov = embeds["cov"]

    # Compare the sampled statistics against the precomputed real statistics
    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
    print("fid:", fid)
Example #3
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
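    """Class-conditional StyleGAN2 training loop with optional ADA.

    Alternates D and G updates on the logistic loss, with lazy R1 and
    path-length regularization and an EMA generator whose decay ramps up via
    ema_kimg/ema_rampup. Fixed samples are logged as per-class sheets.
    """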
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device)

    # Lay fixed samples out in sheets: n_class_per_sheet classes per sheet,
    # n_sample_per_class samples per class
    args.n_sheets = int(np.ceil(args.n_classes / args.n_class_per_sheet))
    args.n_sample_per_sheet = args.n_sample_per_class * args.n_class_per_sheet
    args.n_sample = args.n_sample_per_sheet * args.n_sheets
    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_y = torch.arange(args.n_classes).repeat(args.n_sample_per_class, 1).t().reshape(-1).to(device)
    if args.n_sample > args.n_sample_per_class * args.n_classes:
        sample_y1 = make_fake_label(args.n_sample - args.n_sample_per_class * args.n_classes, args.n_classes, device)
        sample_y = torch.cat([sample_y, sample_y1], 0)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        for step_index in range(args.n_step_d):
            real_img, real_labels = next(loader)
            real_img, real_labels = real_img.to(device), real_labels.to(device)

            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_labels = make_fake_label(args.batch, args.n_classes, device)
            fake_img, _ = generator(noise, fake_labels)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)

            else:
                real_img_aug = real_img

            fake_pred = discriminator(fake_img, fake_labels)
            real_pred = discriminator(real_img_aug, real_labels)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(real_img_aug, real_labels)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_labels = make_fake_label(args.batch, args.n_classes, device)
        fake_img, _ = generator(noise, fake_labels)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img, fake_labels)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            # Labels must match path_batch_size (not args.batch), else the
            # generator receives mismatched batch dimensions
            fake_labels = make_fake_label(path_batch_size, args.n_classes, device)
            fake_img, latents = generator(noise, fake_labels, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update G_ema
        # G_ema = G * (1-ema_beta) + G_ema * ema_beta
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5 ** (args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                    f"augment: {ada_aug_p:.4f}"
                )
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    }
                )
            
            if i % args.log_every == 0:
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        (
                            f"{i:07d}; "
                            f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                            f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                            f"augment: {ada_aug_p:.4f};\n"
                        )
                    )

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    for sheet_index in range(args.n_sheets):
                        start = sheet_index * args.n_sample_per_sheet
                        end = (sheet_index + 1) * args.n_sample_per_sheet
                        sample_z_sheet = sample_z[start:end]
                        sample_y_sheet = sample_y[start:end]
                        sample, _ = g_ema([sample_z_sheet], sample_y_sheet)
                        utils.save_image(
                            sample,
                            os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}_{sheet_index}.png"),
                            nrow=args.n_sample_per_class,
                            normalize=True,
                            value_range=(-1, 1),
                        )
            
            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, args.device,
                        n_classes=args.n_classes,
                    ).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                # print("fid:", fid)
                with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                    f.write(f"{i:07d}; fid: {float(fid):.4f};\n")

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )
            
            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
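
# A minimal sketch of the fake-label sampler used above, assuming labels are
# drawn uniformly over the n_classes categories:
def make_fake_label_sketch(batch, n_classes, device):
    return torch.randint(0, n_classes, (batch,), device=device)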
Example #4

if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()

    parser.add_argument('--batch', type=int, default=16)
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--epoch', type=int, default=30000)
    parser.add_argument('--path_a', type=str, default='./data/collage/train')
    parser.add_argument('--name', type=str, default='new512')

    args = parser.parse_args()

    inception = load_patched_inception_v3().eval().to(device)
    torch.cuda.empty_cache()

    transform = transforms.Compose([
        transforms.Resize((args.size, args.size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    path_b = './fake/' + args.name + '/eval_%d' % (args.epoch)
    dset_b = ImageFolder(root=path_b, transform=transform)
    loader_b = iter(
        DataLoader(dset_b,
                   batch_size=args.batch,
                   num_workers=4,
                   sampler=InfiniteSamplerWrapper(dset_b)))
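
# A minimal sketch of how pooled inception features could be gathered from such
# a loader for FID (hypothetical helper; the original snippet is truncated and
# its own extraction code is not shown):
@torch.no_grad()
def extract_features_sketch(loader, inception, n_batches, device):
    feats = []
    for _ in range(n_batches):
        img, _ = next(loader)  # ImageFolder yields (image, label)
        feat = inception(img.to(device))[0].view(img.shape[0], -1)
        feats.append(feat.cpu())
    return torch.cat(feats, 0)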
Example #5
def train(args, loader, loader2, generator, encoder, discriminator, vggnet,
          g_optim, e_optim, d_optim, g_ema, e_ema, device):
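    """Autoencoder-centric training loop: encoder and generator updated jointly.

    If lambda_adv > 0, a discriminator separates real images from
    reconstructions; the E/G pair minimizes pixel + VGG + adversarial losses,
    optionally balancing the adversarial term with an adaptive weight computed
    at the generator's last layer.
    """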
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    d_loss_val = g_loss_val = r1_val = real_score_val = recx_score_val = 0
    loss_dict = {
        "d": torch.tensor(0.0, device=device),
        "r1": torch.tensor(0.0, device=device)
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d_weight = torch.tensor(1.0, device=device)
    last_layer = None
    if args.use_adaptive_weight:
        if args.distributed:
            last_layer = generator.module.get_last_layer()
        else:
            last_layer = generator.get_last_layer()

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    r_t_dict = {'real': 0, 'recx': 0}  # r_t stat
    g_scale = 1
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length,
                                      args.ada_every, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        if args.debug: util.seed_everything(i)
        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator
        if args.lambda_adv > 0:
            requires_grad(generator, False)
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
            for step_index in range(args.n_step_d):
                real_img = real_imgs[step_index]
                latent_real, _ = encoder(real_img)
                rec_img, _ = generator([latent_real], input_is_latent=True)
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                    rec_img_aug = rec_img
                real_pred = discriminator(real_img_aug)
                rec_pred = discriminator(rec_img_aug)
                d_loss_real = F.softplus(-real_pred).mean()
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["real_score"] = real_pred.mean()
                loss_dict["recx_score"] = rec_pred.mean()

                d_loss = d_loss_real + d_loss_rec * args.lambda_rec_d
                loss_dict["d"] = d_loss

                discriminator.zero_grad()
                d_loss.backward()
                d_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat
            # Compute batchwise r_t
            r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch

            d_regularize = i % args.d_reg_every == 0
            if d_regularize:
                real_img.requires_grad = True
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                real_pred = discriminator(real_img_aug)
                r1_loss = d_r1_loss(real_pred, real_img)
                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss * args.d_reg_every +
                 0 * real_pred[0]).backward()
                d_optim.step()
            loss_dict["r1"] = r1_loss

            r_t_dict['recx'] = torch.sign(rec_pred).sum().item() / args.batch

        # Train AutoEncoder
        requires_grad(encoder, True)
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        if args.debug: util.seed_everything(i)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        for step_index in range(args.n_step_e):
            real_img = real_imgs[step_index]
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img)**2)
                elif args.pix_loss == 'l1':
                    pix_loss = F.l1_loss(rec_img, real_img)
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)
            if args.lambda_adv > 0:
                if args.augment:
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    rec_img_aug = rec_img
                rec_pred = discriminator(rec_img_aug)
                adv_loss = g_nonsaturating_loss(rec_pred)

            # VQGAN-style adaptive weighting: rescale the adversarial term so
            # its gradient at G's last layer matches the reconstruction term's
            if args.use_adaptive_weight and i >= args.disc_iter_start:
                nll_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg
                g_loss = adv_loss * args.lambda_adv
                d_weight = calculate_adaptive_weight(nll_loss,
                                                     g_loss,
                                                     last_layer=last_layer)

            ae_loss = (pix_loss * args.lambda_pix +
                       vgg_loss * args.lambda_vgg +
                       d_weight * adv_loss * args.lambda_adv)
            loss_dict["ae"] = ae_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            encoder.zero_grad()
            generator.zero_grad()
            ae_loss.backward()
            e_optim.step()
            if args.g_decay is not None:
                scale_grad(generator, g_scale)
                g_scale *= args.g_decay
            g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update EMA
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5**(args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, 0 if args.no_ema_g else accum)
        accumulate(e_ema, e_module, 0 if args.no_ema_e else accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        ae_loss_val = loss_reduced["ae"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        if args.lambda_adv > 0:
            d_loss_val = loss_reduced["d"].mean().item()
            r1_val = loss_reduced["r1"].mean().item()
            real_score_val = loss_reduced["real_score"].mean().item()
            recx_score_val = loss_reduced["recx_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; ae: {ae_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"d_weight: {d_weight.item():.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"
            ))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    # Reconstruction of real images
                    latent_x, _ = e_ema(sample_x)
                    rec_real, _ = g_ema([latent_x], input_is_latent=True)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         rec_real.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    ref_pix_loss = torch.sum(torch.abs(sample_x - rec_real))
                    # Tensor fallback keeps ref_vgg_loss.item() valid below
                    # when vggnet is None
                    ref_vgg_loss = torch.mean(
                        (vggnet(sample_x) - vggnet(rec_real))**2
                    ) if vggnet is not None else torch.tensor(0., device=device)
                    # Fixed fake samples and reconstructions
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; "
                        f"real_score: {real_score_val:.4f}; recx_score: {recx_score_val:.4f}; "
                        f"pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; "
                        f"ref_pix: {ref_pix_loss.item():.4f}; ref_vgg: {ref_vgg_loss.item():.4f}; "
                        f"d_weight: {d_weight.item():.4f}; "
                        f"\n"))

            if wandb and args.wandb:
                wandb.log({
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Path Length": path_length_val,
                })

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    fid_sa = fid_re = fid_sr = 0
                    g_ema.eval()
                    e_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    # Real reconstruction FID
                    if 'fid_recon' in args.which_metric:
                        features = extract_feature_from_reconstruction(
                            e_ema,
                            g_ema,
                            inception,
                            args.truncation,
                            mean_latent,
                            loader2,
                            args.device,
                            mode='recon',
                        ).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_re = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(f"{i:07d}; rec_real: {float(fid_re):.4f};\n")

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"latest.pt"),
                )
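
# Minimal sketches of the two generator-side helpers used above, assuming a
# VQGAN-style adaptive weight and a plain gradient rescaling (the repo's own
# calculate_adaptive_weight / scale_grad may differ in details):

def calculate_adaptive_weight_sketch(nll_loss, g_loss, last_layer, eps=1e-4):
    # Balance reconstruction vs. adversarial gradient magnitudes at G's last layer
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + eps)
    return torch.clamp(d_weight, 0.0, 1e4).detach()


def scale_grad_sketch(model, scale):
    # Multiply all existing gradients by a scalar before the optimizer step
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scale)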
Example #6
def train(args, loader, loader2, generator, encoder, discriminator,
          discriminator2, vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema,
          e_ema, device):
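    """Joint GAN + autoencoder loop with an optional decoupled second discriminator.

    The main discriminator scores real, fake, and (unless decouple_d)
    reconstructed images; E and G are updated on the reconstruction losses,
    and G additionally on its own adversarial loss with lazy path-length
    regularization.
    """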
    # kwargs_d = {'detach_aux': args.detach_d_aux_head}
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    if args.dataset == 'imagefolder':
        loader = sample_data2(loader)
    else:
        loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    # Restore d2_module so the decoupled-D checkpoints below can be saved
    d2_module = None
    if discriminator2 is not None:
        if args.distributed:
            d2_module = discriminator2.module
        else:
            d2_module = discriminator2

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    r_t_dict = {'real': 0, 'fake': 0, 'recx': 0}  # r_t stat
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment2(args.ada_margin, args.ada_length,
                                       args.ada_every, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    sample_x1 = sample_x2 = sample_idx = fid_batch_idx = None
    if sample_x.ndim > 4:
        sample_x1 = sample_x[:, 0, ...]
        sample_x2 = sample_x[:, -1, ...]
        sample_x = sample_x[:, 0, ...]

    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(encoder, False)
        requires_grad(discriminator, True)
        for step_index in range(args.n_step_d):
            real_img = real_imgs[step_index]
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            if args.use_ema:
                g_ema.eval()
                fake_img, _ = g_ema(noise)
            else:
                fake_img, _ = generator(noise)

            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img)
            d_loss_fake = F.softplus(fake_pred).mean()
            d_loss_real = F.softplus(-real_pred).mean()
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            d_loss_rec = 0.
            # Tensor default keeps torch.sign(rec_pred) and the reduced
            # "recx_score" valid when the reconstruction branch is skipped
            rec_pred = torch.tensor(0., device=device)
            loss_dict["recx_score"] = rec_pred.mean()
            if args.lambda_rec_d > 0 and not args.decouple_d:
                if args.use_ema:
                    e_ema.eval()
                    g_ema.eval()
                    latent_real, _ = e_ema(real_img)
                    rec_img, _ = g_ema([latent_real], input_is_latent=True)
                else:
                    latent_real, _ = encoder(real_img)
                    rec_img, _ = generator([latent_real], input_is_latent=True)

                rec_pred = discriminator(rec_img)
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["recx_score"] = rec_pred.mean()

            d_loss = d_loss_real + d_loss_fake * args.lambda_fake_d + d_loss_rec * args.lambda_rec_d
            loss_dict["d"] = d_loss

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(fake_pred, rec_pred)
            r_t_stat = ada_augment.r_t_stat
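        # Batchwise sign statistics: fraction of positive D logits per image type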
        r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch
        r_t_dict['fake'] = torch.sign(fake_pred).sum().item() / args.batch
        r_t_dict['recx'] = torch.sign(rec_pred).sum().item() / args.batch

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        # Train Encoder
        requires_grad(encoder, True)
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        # requires_grad(discriminator2, False)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        for step_index in range(args.n_step_e):
            real_img = real_imgs[step_index]
            latent_real, _ = encoder(real_img)
            if args.use_ema:
                g_ema.eval()
                rec_img, _ = g_ema([latent_real], input_is_latent=True)
            else:
                rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img)**2)
                elif args.pix_loss == 'l1':
                    pix_loss = F.l1_loss(rec_img, real_img)
                else:
                    raise NotImplementedError
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)
            if args.lambda_adv > 0:
                rec_pred = discriminator(rec_img)
                adv_loss = g_nonsaturating_loss(rec_pred)

            e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
            loss_dict["e"] = e_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            encoder.zero_grad()
            generator.zero_grad()
            e_loss.backward()
            manually_scale_grad(generator, 1 - ada_aug_p)
            e_optim.step()
            g_optim.step()

        # Train Generator
        requires_grad(generator, True)
        requires_grad(encoder, False)
        requires_grad(discriminator, False)
        # requires_grad(discriminator2, False)
        real_img = real_imgs[0]
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        g_loss_fake = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = g_loss_fake
        generator.zero_grad()
        g_loss_fake.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        recx_score_val = loss_reduced["recx_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"
            ))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = generator([latent_x],
                                          input_is_latent=True,
                                          return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; ref: {sample_pix_loss.item():.4f}; "
                        f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; r_stat: {r_t_stat:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; "
                        f"real_score: {real_score_val:.4f}; fake_score: {fake_score_val:.4f}; recx_score: {recx_score_val:.4f};\n"
                    ))

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    fid_sa = fid_re = fid_hy = 0
                    # Sample FID
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64,
                        args.n_sample_fid, args.device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid_sa = calc_fid(sample_mean, sample_cov, real_mean,
                                      real_cov)
                    # Recon FID
                    features = extract_feature_from_recon_hybrid(
                        e_ema,
                        g_ema,
                        inception,
                        args.truncation,
                        mean_latent,
                        loader2,
                        args.device,
                        mode='recon',
                    ).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid_re = calc_fid(sample_mean, sample_cov, real_mean,
                                      real_cov)
                # print("Sample FID:", fid_sa, "Recon FID:", fid_re, "Hybrid FID:", fid_hy)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(
                        f"{i:07d}; sample fid: {float(fid_sa):.4f}; recon fid: {float(fid_re):.4f}; \n"
                    )

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    # Fixed fake samples
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Reconstruction samples
                    latent_real, _ = e_ema(sample_x)
                    fake_img, _ = g_ema([latent_real],
                                        input_is_latent=True,
                                        return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
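
# --- Editor's sketch (not part of the original source) ----------------------
# d_r1_loss(real_pred, real_img), called by all of these loops, is assumed to
# be the standard stylegan2-pytorch lazy R1 penalty. A minimal sketch under
# that assumption; the name _sketch_d_r1_loss is hypothetical:
from torch import autograd


def _sketch_d_r1_loss(real_pred, real_img):
    # Gradient of the summed real logits w.r.t. the real images; create_graph
    # keeps the penalty differentiable so it can be backpropagated itself.
    (grad_real,) = autograd.grad(outputs=real_pred.sum(),
                                 inputs=real_img,
                                 create_graph=True)
    # Mean over the batch of the per-image squared gradient norm.
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
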
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader, datatype="imagefolder")
    # inception related:
    if get_rank() == 0:
        from calc_inception import load_patched_inception_v3
        inception = load_patched_inception_v3().to(device)
        inception.eval()
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n")

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
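            # ADA heuristic: nudge ada_aug_p toward args.ada_target based on
            # the sign statistics of D's logits on real images.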
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img, args)

            discriminator.zero_grad()
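            # The 0 * real_pred[0] term contributes nothing numerically;
            # presumably it keeps D's output in the autograd graph so DDP does
            # not complain about unused outputs.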
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # Generator path-length regularization, applied only when enabled
        g_regularize = args.useG_reg and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
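                # Zero-valued dummy term; presumably ties fake_img into the
                # backward graph so DDP sees every generator output used.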
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                    f"augment: {ada_aug_p:.4f}"
                )
            )

            # inception related:
            if args.eval_every > 0 and i % args.eval_every == 0:
                real_mean = real_cov = mean_latent = None
                with open(args.inception, "rb") as f:
                    embeds = pickle.load(f)
                    real_mean = embeds["mean"]
                    real_cov = embeds["cov"]
                    # print("yahooo!\n")
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    # print("I am fine sir!\n")
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, device
                    ).numpy()
                    # print("I am normal sir!")
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                    with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                        f.write(f"{i:07d}; fid: {float(fid):.4f};\n")
                    # print("alright hurray \n")

            if i % args.log_every == 0:
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        (
                            f"{i:07d}; "
                            f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                            f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                            f"augment: {ada_aug_p:.4f};\n"
                        )
                    )

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}.png"),
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
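
# --- Editor's sketch (not part of the original source) ----------------------
# g_path_regularize, as used above, is assumed to be the StyleGAN2
# path-length penalty. A minimal sketch under that assumption; the name
# _sketch_g_path_regularize is hypothetical:
import math

import torch
from torch import autograd


def _sketch_g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    # Image-space noise scaled so the JVP magnitude is resolution-independent.
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    (grad,) = autograd.grad(outputs=(fake_img * noise).sum(),
                            inputs=latents,
                            create_graph=True)
    # Per-sample gradient length, assuming latents of shape
    # (batch, n_latent, latent_dim).
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    # Exponential moving average of the target path length.
    path_mean = mean_path_length + decay * (path_lengths.mean() -
                                            mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths
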
def train(args, loader, generator, encoder, discriminator, discriminator2,
          vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema, e_ema, device):
    if args.dataset == 'imagefolder':
        loader = sample_data2(loader)
    else:
        loader = sample_data(loader)

    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    else:
        inception = real_mean = real_cov = None
    mean_latent = None

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d2_module = None
    if discriminator2 is not None:
        if args.distributed:
            d2_module = discriminator2.module
        else:
            d2_module = discriminator2

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    sample_x1 = sample_x[:, 0, ...]
    sample_x2 = sample_x[:, -1, ...]
    sample_idx = torch.randperm(args.n_sample)

    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        frames = [get_batch(loader, device) for _ in range(n_step_max)]

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(encoder, False)
        requires_grad(discriminator, True)
        for step_index in range(args.n_step_d):
            frames1, frames2 = frames[step_index]
            real_img = frames1
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            if args.use_ema:
                g_ema.eval()
                fake_img, _ = g_ema(noise)
            else:
                fake_img, _ = generator(noise)
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img_aug)
            d_loss_fake = F.softplus(fake_pred).mean()
            d_loss_real = F.softplus(-real_pred).mean()
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            d_loss_rec = 0.
            if args.lambda_rec_d > 0 and not args.decouple_d:  # Do not train D on x_rec if decouple_d
                if args.use_ema:
                    e_ema.eval()
                    g_ema.eval()
                    latent_real, _ = e_ema(real_img)
                    rec_img, _ = g_ema([latent_real], input_is_latent=True)
                else:
                    latent_real, _ = encoder(real_img)
                    rec_img, _ = generator([latent_real], input_is_latent=True)
                rec_pred = discriminator(rec_img)
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["rec_score"] = rec_pred.mean()

            d_loss_cross = 0.
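            # Cross-frame hybrids: shuffle the latent deltas (w2 - w1) across
            # the batch, decode w1 + shuffled dw, and train D to reject them.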
            if args.lambda_cross_d > 0 and not args.decouple_d:
                if args.use_ema:
                    e_ema.eval()
                    w1, _ = e_ema(frames1)
                    w2, _ = e_ema(frames2)
                else:
                    w1, _ = encoder(frames1)
                    w2, _ = encoder(frames2)
                dw = w2 - w1
                dw_shuffle = dw[torch.randperm(args.batch), ...]
                if args.use_ema:
                    g_ema.eval()
                    cross_img, _ = g_ema([w1 + dw_shuffle],
                                         input_is_latent=True)
                else:
                    cross_img, _ = generator([w1 + dw_shuffle],
                                             input_is_latent=True)
                cross_pred = discriminator(cross_img)
                d_loss_cross = F.softplus(cross_pred).mean()

            d_loss_fake_cross = 0.
            if args.lambda_fake_cross_d > 0:
                if args.use_ema:
                    e_ema.eval()
                    w1, _ = e_ema(frames1)
                    w2, _ = e_ema(frames2)
                else:
                    w1, _ = encoder(frames1)
                    w2, _ = encoder(frames2)
                dw = w2 - w1
                noise = mixing_noise(args.batch, args.latent, args.mixing,
                                     device)
                if args.use_ema:
                    g_ema.eval()
                    style = g_ema.get_styles(noise).view(args.batch, -1)
                else:
                    style = generator.get_styles(noise).view(args.batch, -1)
                if dw.shape[1] < style.shape[1]:  # W space
                    dw = dw.repeat(1, args.n_latent)
                if args.use_ema:
                    cross_img, _ = g_ema([style + dw], input_is_latent=True)
                else:
                    cross_img, _ = generator([style + dw],
                                             input_is_latent=True)
                fake_cross_pred = discriminator(cross_img)
                d_loss_fake_cross = F.softplus(fake_cross_pred).mean()

            d_loss = (d_loss_real + d_loss_fake +
                      d_loss_fake_cross * args.lambda_fake_cross_d +
                      d_loss_rec * args.lambda_rec_d +
                      d_loss_cross * args.lambda_cross_d)
            loss_dict["d"] = d_loss

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        # Train Discriminator2
        if args.decouple_d and discriminator2 is not None:
            requires_grad(generator, False)
            requires_grad(encoder, False)
            requires_grad(discriminator2, True)
            # n_step_d2 is the same as n_step_e
            for step_index in range(args.n_step_e):
                frames1, frames2 = frames[step_index]
                real_img = frames1
                if args.use_ema:
                    e_ema.eval()
                    g_ema.eval()
                    latent_real, _ = e_ema(real_img)
                    rec_img, _ = g_ema([latent_real], input_is_latent=True)
                else:
                    latent_real, _ = encoder(real_img)
                    rec_img, _ = generator([latent_real], input_is_latent=True)
                rec_pred = discriminator2(rec_img)
                d2_loss_rec = F.softplus(rec_pred).mean()
                real_pred1 = discriminator2(frames1)
                d2_loss_real = F.softplus(-real_pred1).mean()
                if args.use_frames2_d:
                    real_pred2 = discriminator2(frames2)
                    d2_loss_real += F.softplus(-real_pred2).mean()

                if args.use_ema:
                    e_ema.eval()
                    w1, _ = e_ema(frames1)
                    w2, _ = e_ema(frames2)
                else:
                    w1, _ = encoder(frames1)
                    w2, _ = encoder(frames2)
                dw = w2 - w1
                dw_shuffle = dw[torch.randperm(args.batch), ...]
                cross_img, _ = generator([w1 + dw_shuffle],
                                         input_is_latent=True)
                cross_pred = discriminator2(cross_img)
                d2_loss_cross = F.softplus(cross_pred).mean()

                d2_loss = d2_loss_real + d2_loss_rec + d2_loss_cross
                loss_dict["d2"] = d2_loss
                loss_dict["rec_score"] = rec_pred.mean()
                loss_dict["cross_score"] = cross_pred.mean()

                discriminator2.zero_grad()
                d2_loss.backward()
                d2_optim.step()

            d_regularize = i % args.d_reg_every == 0
            if d_regularize:
                real_img.requires_grad = True
                real_pred = discriminator2(real_img)
                r1_loss = d_r1_loss(real_pred, real_img)
                discriminator2.zero_grad()
                (args.r1 / 2 * r1_loss * args.d_reg_every +
                 0 * real_pred[0]).backward()
                d2_optim.step()

        # Train Encoder
        requires_grad(encoder, True)
        requires_grad(generator, args.train_ge)
        requires_grad(discriminator, False)
        if discriminator2 is not None:
            requires_grad(discriminator2, False)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        for step_index in range(args.n_step_e):
            frames1, frames2 = frames[step_index]
            real_img = frames1
            latent_real, _ = encoder(real_img)
            if args.use_ema:
                g_ema.eval()
                rec_img, _ = g_ema([latent_real], input_is_latent=True)
            else:
                rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.lambda_adv > 0:
                if not args.decouple_d:
                    rec_pred = discriminator(rec_img)
                else:
                    rec_pred = discriminator2(rec_img)
                adv_loss = g_nonsaturating_loss(rec_pred)
            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img)**2)
                else:
                    pix_loss = F.l1_loss(rec_img, real_img)
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)

            e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
            loss_dict["e"] = e_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            if args.train_ge:
                encoder.zero_grad()
                generator.zero_grad()
                e_loss.backward()
                e_optim.step()
                g_optim.step()
            else:
                encoder.zero_grad()
                e_loss.backward()
                e_optim.step()

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        if discriminator2 is not None:
            requires_grad(discriminator2, False)
        frames1, frames2 = frames[0]
        real_img = frames1
        g_loss_fake = 0.
        if args.lambda_fake_g > 0:
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)
            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)
            fake_pred = discriminator(fake_img)
            g_loss_fake = g_nonsaturating_loss(fake_pred)

        g_loss_rec = 0.
        if args.lambda_rec_g > 0:
            if args.use_ema:
                e_ema.eval()
                latent_real, _ = e_ema(real_img)
            else:
                latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True)
            if not args.decouple_d:
                rec_pred = discriminator(rec_img)
            else:
                rec_pred = discriminator2(rec_img)
            g_loss_rec = g_nonsaturating_loss(rec_pred)

        g_loss_cross = 0.
        if args.lambda_cross_g > 0:
            if args.use_ema:
                e_ema.eval()
                w1, _ = e_ema(frames1)
                w2, _ = e_ema(frames2)
            else:
                w1, _ = encoder(frames1)
                w2, _ = encoder(frames2)
            dw = w2 - w1
            dw_shuffle = dw[torch.randperm(args.batch), ...]
            cross_img, _ = generator([w1 + dw_shuffle], input_is_latent=True)
            if not args.decouple_d:
                cross_pred = discriminator(cross_img)
            else:
                cross_pred = discriminator2(cross_img)
            g_loss_cross = g_nonsaturating_loss(cross_pred)

        g_loss_fake_cross = 0.
        if args.lambda_fake_cross_g > 0:
            if args.use_ema:
                e_ema.eval()
                w1, _ = e_ema(frames1)
                w2, _ = e_ema(frames2)
            else:
                w1, _ = encoder(frames1)
                w2, _ = encoder(frames2)
            dw = w2 - w1
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            style = generator.get_styles(noise).view(args.batch, -1)
            if dw.shape[1] < style.shape[1]:  # W space
                dw = dw.repeat(1, args.n_latent)
            cross_img, _ = generator([style + dw], input_is_latent=True)
            fake_cross_pred = discriminator(cross_img)
            g_loss_fake_cross = g_nonsaturating_loss(fake_cross_pred)

        g_loss = (g_loss_fake * args.lambda_fake_g +
                  g_loss_rec * args.lambda_rec_g +
                  g_loss_cross * args.lambda_cross_g +
                  g_loss_fake_cross * args.lambda_fake_cross_g)
        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"
            ))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = generator([latent_x],
                                          input_is_latent=True,
                                          return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                        f"ref: {sample_pix_loss.item()};\n")

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64,
                        args.n_sample_fid, args.device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean,
                                   real_cov)
                print("fid:", fid)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(f"{i:07d}; fid: {float(fid):.4f};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    # Fixed fake samples
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Reconstruction samples
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_ema(sample_x)
                    fake_img, _ = g_ema([latent_real],
                                        input_is_latent=True,
                                        return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
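
# --- Editor's sketch (not part of the original source) ----------------------
# accumulate(dst, src, accum), used above to maintain g_ema / e_ema, is
# assumed to be the usual stylegan2-pytorch parameter EMA helper. A minimal
# sketch under that assumption; the name _sketch_accumulate is hypothetical:
def _sketch_accumulate(model_ema, model, decay=0.999):
    params_ema = dict(model_ema.named_parameters())
    params = dict(model.named_parameters())
    for name in params_ema:
        # ema <- decay * ema + (1 - decay) * live weights, in place.
        params_ema[name].data.mul_(decay).add_(params[name].data,
                                               alpha=1 - decay)
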
Example #9
def train(args, loader, loader2, generator, encoder, discriminator,
          discriminator2, vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema,
          e_ema, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {
        'recx_score': torch.tensor(0.0, device=device),
        'ae_fake': torch.tensor(0.0, device=device),
        'ae_real': torch.tensor(0.0, device=device),
        'pix': torch.tensor(0.0, device=device),
        'vgg': torch.tensor(0.0, device=device),
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d2_module = None
    if discriminator2 is not None:
        if args.distributed:
            d2_module = discriminator2.module
        else:
            d2_module = discriminator2

    # When joint training is enabled, d_weight balances the reconstruction
    # loss against the adversarial loss on reconstructed real images. It does
    # not balance the overall AE loss against the GAN loss.
    d_weight = torch.tensor(1.0, device=device)
    last_layer = None
    if args.use_adaptive_weight:
        if args.distributed:
            last_layer = generator.module.get_last_layer()
        else:
            last_layer = generator.get_last_layer()
    g_scale = 1

    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    r_t_dict = {'real': 0, 'fake': 0, 'recx': 0}  # r_t stat
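    # Mean sign of D's logits for real, fake, and reconstructed inputs;
    # logged next to the ADA r_t statistic.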

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length,
                                      args.ada_every, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    input_is_latent = args.latent_space != 'z'  # False when E encodes into Z space, True for W-like spaces

    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        if args.debug:
            util.seed_everything(i)
        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator and Encoder
        requires_grad(generator, False)
        requires_grad(encoder, True)
        requires_grad(discriminator, True)
        requires_grad(discriminator2, True)
        for step_index in range(args.n_step_d):
            real_img = real_imgs[step_index]
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
                fake_img_aug = fake_img
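            # D scores the encoder's latent prediction rather than raw pixels,
            # so E and D are optimized together here as a single critic.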
            real_pred = discriminator(encoder(real_img_aug)[0])
            fake_pred = discriminator(encoder(fake_img_aug)[0])
            d_loss_real = F.softplus(-real_pred).mean()
            d_loss_fake = F.softplus(fake_pred).mean()
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            d_loss_rec = 0.
            if args.lambda_rec_d > 0:
                latent_real, _ = encoder(real_img)
                rec_img, _ = generator([latent_real],
                                       input_is_latent=input_is_latent)
                if args.augment:
                    rec_img, _ = augment(rec_img, ada_aug_p)
                rec_pred = discriminator(encoder(rec_img)[0])
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["recx_score"] = rec_pred.mean()
                r_t_dict['recx'] = torch.sign(
                    rec_pred).sum().item() / args.batch

            d_loss = d_loss_real + d_loss_fake * args.lambda_fake_d + d_loss_rec * args.lambda_rec_d
            loss_dict["d"] = d_loss

            discriminator.zero_grad()
            encoder.zero_grad()
            d_loss.backward()
            d_optim.step()
            e_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat
        r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch
        r_t_dict['fake'] = torch.sign(fake_pred).sum().item() / args.batch

        d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(encoder(real_img_aug)[0])
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            encoder.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()
            e_optim.step()
        loss_dict["r1"] = r1_loss

        # Train Generator
        requires_grad(generator, True)
        requires_grad(encoder, False)
        requires_grad(discriminator, False)
        requires_grad(discriminator2, False)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        if args.augment:
            fake_img_aug, _ = augment(fake_img, ada_aug_p)
        else:
            fake_img_aug = fake_img
        fake_pred = discriminator(encoder(fake_img_aug)[0])
        g_loss_fake = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = g_loss_fake
        generator.zero_grad()
        (g_loss_fake * args.lambda_fake_g).backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Train Encoder (and Generator)
        joint = (not args.no_joint) and (g_scale > 1e-6)
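        # When joint, the AE losses below also update G; once g_scale decays
        # below 1e-6 (via args.g_decay), the generator side is switched off.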

        # Train AE on fake samples (latent reconstruction)
        if args.lambda_rec_w + (args.lambda_pix_fake + args.lambda_vgg_fake +
                                args.lambda_adv_fake) > 0:
            requires_grad(encoder, True)
            requires_grad(generator, joint)
            requires_grad(discriminator, False)
            requires_grad(discriminator2, False)
            for step_index in range(args.n_step_e):
                ae_loss_fake = 0
                mixing_prob = 0 if args.which_latent == 'w_tied' else args.mixing
                if args.lambda_rec_w > 0:
                    noise = mixing_noise(args.batch, args.latent, mixing_prob,
                                         device)
                    fake_img, latent_fake = generator(
                        noise,
                        return_latents=True,
                        detach_style=not args.no_detach_style)
                    if args.which_latent == 'w_tied':
                        latent_fake = latent_fake[:, 0, :]
                    else:
                        latent_fake = latent_fake.view(args.batch, -1)
                    latent_pred, _ = encoder(fake_img)
                    ae_loss_fake = torch.mean(
                        (latent_pred - latent_fake.detach())**2)

                if args.lambda_pix_fake + args.lambda_vgg_fake + args.lambda_adv_fake > 0:
                    pix_loss = vgg_loss = adv_loss = torch.tensor(
                        0., device=device)
                    noise = mixing_noise(args.batch, args.latent, mixing_prob,
                                         device)
                    fake_img, _ = generator(noise, detach_style=False)
                    fake_img = fake_img.detach()

                    latent_pred, _ = encoder(fake_img)
                    rec_img, _ = generator([latent_pred],
                                           input_is_latent=input_is_latent)
                    if args.lambda_pix_fake > 0:
                        if args.pix_loss == 'l2':
                            pix_loss = torch.mean((rec_img - fake_img)**2)
                        elif args.pix_loss == 'l1':
                            pix_loss = F.l1_loss(rec_img, fake_img)
                    if args.lambda_vgg_fake > 0:
                        vgg_loss = torch.mean(
                            (vggnet(fake_img) - vggnet(rec_img))**2)

                    ae_loss_fake = (ae_loss_fake +
                                    pix_loss * args.lambda_pix_fake +
                                    vgg_loss * args.lambda_vgg_fake)

                loss_dict["ae_fake"] = ae_loss_fake

                if joint:
                    encoder.zero_grad()
                    generator.zero_grad()
                    (ae_loss_fake * args.lambda_rec_w).backward()
                    e_optim.step()
                    if args.g_decay is not None:
                        scale_grad(generator, g_scale)
                    # Do NOT update F (or generator.style). Grad should be zero when style
                    # is detached in generator, but we explicitly zero it, just in case.
                    if not args.no_detach_style:
                        generator.style.zero_grad()
                    g_optim.step()
                else:
                    encoder.zero_grad()
                    (ae_loss_fake * args.lambda_rec_w).backward()
                    e_optim.step()

        # Train AE on real samples (image reconstruction)
        if args.lambda_pix + args.lambda_vgg + args.lambda_adv > 0:
            requires_grad(encoder, True)
            requires_grad(generator, joint)
            requires_grad(discriminator, False)
            requires_grad(discriminator2, False)
            pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
            for step_index in range(args.n_step_e):
                real_img = real_imgs[step_index]
                latent_real, _ = encoder(real_img)
                rec_img, _ = generator([latent_real],
                                       input_is_latent=input_is_latent)
                if args.lambda_pix > 0:
                    if args.pix_loss == 'l2':
                        pix_loss = torch.mean((rec_img - real_img)**2)
                    elif args.pix_loss == 'l1':
                        pix_loss = F.l1_loss(rec_img, real_img)
                if args.lambda_vgg > 0:
                    vgg_loss = torch.mean(
                        (vggnet(real_img) - vggnet(rec_img))**2)
                if args.lambda_adv > 0:
                    if args.augment:
                        rec_img_aug, _ = augment(rec_img, ada_aug_p)
                    else:
                        rec_img_aug = rec_img
                    rec_pred = discriminator(encoder(rec_img_aug)[0])
                    adv_loss = g_nonsaturating_loss(rec_pred)

                if args.use_adaptive_weight and i >= args.disc_iter_start:
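                    # Adaptive weight, presumably the VQGAN/taming-transformers
                    # recipe: ratio of gradient norms of the reconstruction
                    # (NLL) loss and the adversarial loss at G's last layer.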
                    nll_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg
                    g_loss = adv_loss * args.lambda_adv
                    d_weight = calculate_adaptive_weight(nll_loss,
                                                         g_loss,
                                                         last_layer=last_layer)

                ae_loss_real = (pix_loss * args.lambda_pix +
                                vgg_loss * args.lambda_vgg +
                                d_weight * adv_loss * args.lambda_adv)
                loss_dict["ae_real"] = ae_loss_real
                loss_dict["pix"] = pix_loss
                loss_dict["vgg"] = vgg_loss
                loss_dict["adv"] = adv_loss

                if joint:
                    encoder.zero_grad()
                    generator.zero_grad()
                    ae_loss_real.backward()
                    e_optim.step()
                    if args.g_decay is not None:
                        scale_grad(generator, g_scale)
                    g_optim.step()
                else:
                    encoder.zero_grad()
                    ae_loss_real.backward()
                    e_optim.step()

        if args.g_decay is not None:
            g_scale *= args.g_decay

        # Update EMA
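        # accum = 0.5 ** (batch / ema_nimg) gives the parameter EMA a
        # half-life of ema_nimg images; ema_rampup shortens that half-life
        # early in training (as in StyleGAN2-ADA).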
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5**(args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, 0 if args.no_ema_g else accum)
        accumulate(e_ema, e_module, 0 if args.no_ema_e else accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        ae_real_val = loss_reduced["ae_real"].mean().item()
        ae_fake_val = loss_reduced["ae_fake"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        recx_score_val = loss_reduced["recx_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"ae_fake: {ae_fake_val:.4f}; ae_real: {ae_real_val:.4f}; "
                f"g: {g_loss_val:.4f}; path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"d_weight: {d_weight.item():.4f}; "))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    # Reconstruction of real images
                    latent_x, _ = e_ema(sample_x)
                    rec_real, _ = g_ema([latent_x],
                                        input_is_latent=input_is_latent)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         rec_real.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    ref_pix_loss = torch.sum(torch.abs(sample_x - rec_real))
                    if vggnet is not None:
                        ref_vgg_loss = torch.mean(
                            (vggnet(sample_x) - vggnet(rec_real))**2)
                    else:
                        # Tensor zero so the .item() call in the log line works
                        ref_vgg_loss = torch.tensor(0., device=device)
                    # Fixed fake samples and reconstructions
                    sample_gz, _ = g_ema([sample_z])
                    latent_gz, _ = e_ema(sample_gz)
                    rec_fake, _ = g_ema([latent_gz],
                                        input_is_latent=input_is_latent)
                    sample = torch.cat(
                        (sample_gz.reshape(args.n_sample // nrow, nrow, *nchw),
                         rec_fake.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"ae_fake: {ae_fake_val:.4f}; ae_real: {ae_real_val:.4f}; "
                        f"g: {g_loss_val:.4f}; path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; "
                        f"real_score: {real_score_val:.4f}; fake_score: {fake_score_val:.4f}; recx_score: {recx_score_val:.4f}; "
                        f"pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; "
                        f"ref_pix: {ref_pix_loss.item():.4f}; ref_vgg: {ref_vgg_loss.item():.4f}; "
                        f"d_weight: {d_weight.item():.4f}; "
                        f"\n"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    fid_sa = fid_re = fid_sr = 0
                    g_ema.eval()
                    e_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    # Sample FID
                    if 'fid_sample' in args.which_metric:
                        features = extract_feature_from_samples(
                            g_ema, inception, args.truncation, mean_latent, 64,
                            args.n_sample_fid, args.device).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_sa = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                    # Sample reconstruction FID
                    if 'fid_sample_recon' in args.which_metric:
                        features = extract_feature_from_samples(
                            g_ema,
                            inception,
                            args.truncation,
                            mean_latent,
                            64,
                            args.n_sample_fid,
                            args.device,
                            mode='recon',
                            encoder=e_ema,
                            input_is_latent=input_is_latent,
                        ).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_sr = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                    # Real reconstruction FID
                    if 'fid_recon' in args.which_metric:
                        features = extract_feature_from_reconstruction(
                            e_ema,
                            g_ema,
                            inception,
                            args.truncation,
                            mean_latent,
                            loader2,
                            args.device,
                            input_is_latent=input_is_latent,
                            mode='recon',
                        ).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_re = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(
                        f"{i:07d}; sample: {float(fid_sa):.4f}; rec_fake: {float(fid_sr):.4f}; rec_real: {float(fid_re):.4f};\n"
                    )
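                # calc_fid (helper not shown in this snippet) presumably
                # returns the Frechet distance between the two Gaussians:
                #   FID = ||mu_s - mu_r||^2 + Tr(C_s + C_r - 2 (C_s C_r)^(1/2))
                # with (mu_s, C_s) from the generated-feature statistics above
                # and (mu_r, C_r) the precomputed real statistics.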

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
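
# A minimal resume sketch (hypothetical helper, not part of the original
# script): it restores the state dicts written by the save blocks above and
# returns the iteration and ADA probability to restart from. Model and
# optimizer construction is assumed to happen before the call.
def resume_from_checkpoint(ckpt_path, generator, encoder, discriminator,
                           g_ema, e_ema, g_optim, e_optim, d_optim):
    ckpt = torch.load(ckpt_path, map_location='cpu')
    generator.load_state_dict(ckpt["g"])
    encoder.load_state_dict(ckpt["e"])
    discriminator.load_state_dict(ckpt["d"])
    g_ema.load_state_dict(ckpt["g_ema"])
    e_ema.load_state_dict(ckpt["e_ema"])
    g_optim.load_state_dict(ckpt["g_optim"])
    e_optim.load_state_dict(ckpt["e_optim"])
    d_optim.load_state_dict(ckpt["d_optim"])
    return ckpt["iter"], ckpt.get("ada_aug_p", 0.0)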
Example #10
    filename = os.path.join(args.ckpt_dir, f'fid_{args.n_sample}.csv')
    # load values that are already computed
    computed = []
    if os.path.exists(filename):
        with open(filename, 'r') as f:
            reader = csv.reader(f, delimiter=',')
            for row in reader:
                computed += [row[0]]
    
    # prepare to write
    f = open(filename, mode='a')
    writer = csv.writer(f, delimiter=',')
        
    # load inception model
    inception = load_patched_inception_v3()
    inception = inception.eval().to(device)
    
    ckpt_paths = sorted(
        glob(os.path.join(args.ckpt_dir, '*.ckpt')) +
        glob(os.path.join(args.ckpt_dir, '*.pt')) +
        glob(os.path.join(args.ckpt_dir, '*.pth')))
    print('records:', ckpt_paths)
    print('computed:', computed)
    for ckpt_path in ckpt_paths:
        print()
        print(f'working on {ckpt_path}')
        iteration = os.path.basename(ckpt_path).split('.')[0]
        if iteration in computed:
            print('already computed')
            continue
        
        args.ckpt_path = ckpt_path
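
# calc_fid is referenced by the FID evaluation above but its body is not part
# of these snippets. A plausible implementation, assuming the standard Frechet
# Inception Distance between Gaussian feature statistics:
def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    from scipy import linalg  # local import to keep the sketch self-contained
    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
    if not np.isfinite(cov_sqrt).all():
        # Regularize near-singular covariances before retrying the sqrt.
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real
    mean_diff = sample_mean - real_mean
    return (mean_diff @ mean_diff + np.trace(sample_cov) +
            np.trace(real_cov) - 2 * np.trace(cov_sqrt))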
Example #11
def train(args, loader, loader2, encoder, generator, discriminator, vggnet,
          pwcnet, e_optim, d_optim, e_ema, pca_state, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {
        "d": torch.tensor(0., device=device),
        "real_score": torch.tensor(0., device=device),
        "fake_score": torch.tensor(0., device=device),
        "r1_d": torch.tensor(0., device=device),
        "r1_e": torch.tensor(0., device=device),
        "rec": torch.tensor(0., device=device),
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length,
                                      args.ada_every, device)

    # sample_x = accumulate_batches(loader, args.n_sample).to(device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    input_is_latent = args.latent_space != 'z'  # False only when encoding into Z space

    requires_grad(generator, False)  # always False
    generator.eval()  # Generator should be ema and in eval mode
    g_ema = generator

    # if args.no_ema or e_ema is None:
    #     e_ema = encoder

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # Train Encoder
        if args.toggle_grads:
            requires_grad(encoder, True)
            requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0.,
                                                                 device=device)
        latent_real, _ = encoder(real_img)
        fake_img, _ = generator([latent_real], input_is_latent=input_is_latent)

        if args.lambda_adv > 0:
            if args.augment:
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                fake_img_aug = fake_img
            fake_pred = discriminator(fake_img_aug)
            adv_loss = g_nonsaturating_loss(fake_pred)
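            # g_nonsaturating_loss is presumably the standard non-saturating
            # generator loss, softplus(-fake_pred).mean(), which keeps
            # gradients useful even when the discriminator is confident.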

        if args.lambda_pix > 0:
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((fake_img - real_img)**2)
            else:
                pix_loss = F.l1_loss(fake_img, real_img)

        if args.lambda_vgg > 0:
            real_feat = vggnet(real_img)
            fake_feat = vggnet(fake_img)
            vgg_loss = torch.mean((real_feat - fake_feat)**2)

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv

        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss

        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()

        if args.train_on_fake:
            e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0
            if e_regularize and args.lambda_rec > 0:
                noise = mixing_noise(args.batch, args.latent, args.mixing,
                                     device)
                fake_img, latent_fake = generator(
                    noise,
                    input_is_latent=input_is_latent,
                    return_latents=True)
                latent_pred, _ = encoder(fake_img)
                if latent_pred.ndim < 3:
                    latent_pred = latent_pred.unsqueeze(1).repeat(
                        1, latent_fake.size(1), 1)
                rec_loss = torch.mean((latent_fake - latent_pred)**2)
                encoder.zero_grad()
                (rec_loss * args.lambda_rec).backward()
                e_optim.step()
                loss_dict["rec"] = rec_loss

        # e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0
        # if e_regularize:
        #     # why not regularize on augmented real?
        #     real_img.requires_grad = True
        #     real_pred, _ = encoder(real_img)
        #     r1_loss_e = d_r1_loss(real_pred, real_img)

        #     encoder.zero_grad()
        #     (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward()
        #     e_optim.step()

        #     loss_dict["r1_e"] = r1_loss_e

        if not args.no_ema and e_ema is not None:
            ema_nimg = args.ema_kimg * 1000
            if args.ema_rampup is not None:
                ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
            accum = 0.5**(args.batch / max(ema_nimg, 1e-8))
            accumulate(e_ema, e_module, accum)
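            # accum = 0.5 ** (batch / ema_nimg) gives the encoder EMA a
            # half-life of ema_kimg thousand images, optionally ramped up at
            # the start of training. accumulate presumably lerps parameters in
            # place, e.g. p_ema.mul_(accum).add_(p, alpha=1 - accum).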

        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
        if not args.no_update_discriminator and args.lambda_adv > 0:
            latent_real, _ = encoder(real_img)
            fake_img, _ = generator([latent_real],
                                    input_is_latent=input_is_latent)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
                fake_img_aug = fake_img

            fake_pred = discriminator(fake_img_aug)
            real_pred = discriminator(real_img_aug)
            d_loss = d_logistic_loss(real_pred, fake_pred)
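            # d_logistic_loss is presumably the standard logistic GAN loss:
            # softplus(-real_pred).mean() + softplus(fake_pred).mean().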

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat
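                # ADA heuristic: r_t tracks E[sign(D(real))]; tune() raises the
                # augmentation probability when r_t exceeds ada_target (a sign
                # of discriminator overfitting) and lowers it otherwise.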

            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                # why not regularize on augmented real?
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss_d = d_r1_loss(real_pred, real_img)

                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss_d * args.d_reg_every +
                 0 * real_pred.view(-1)[0]).backward()
                # Why 0* ? Answer is here https://github.com/rosinality/stylegan2-pytorch/issues/76
                d_optim.step()

                loss_dict["r1_d"] = r1_loss_d

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        r1_e_val = loss_reduced["r1_e"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = g_ema([latent_x],
                                      input_is_latent=input_is_latent)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                        f"ref: {sample_pix_loss.item()};\n")

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    # Recon
                    features = extract_feature_from_reconstruction(
                        e_ema,
                        g_ema,
                        inception,
                        args.truncation,
                        mean_latent,
                        loader2,
                        args.device,
                        input_is_latent=input_is_latent,
                        mode='recon',
                    ).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid_re = calc_fid(sample_mean, sample_cov, real_mean,
                                      real_cov)
                # print("Recon FID:", fid_re)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(f"{i:07d}; recon fid: {float(fid_re):.4f};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Encoder": e_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1 D": r1_d_val,
                    "R1 E": r1_e_val,
                    "Pix Loss": pix_loss_val,
                    "VGG Loss": vgg_loss_val,
                    "Adv Loss": adv_loss_val,
                    "Rec Loss": rec_loss_val,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = encoder if args.no_ema else e_ema
                    e_eval.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_eval(sample_x)
                    fake_img, _ = generator([latent_real],
                                            input_is_latent=input_is_latent)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    e_eval.train()

            if i % args.save_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                # Define e_eval here too, in case this branch runs on an
                # iteration where the save_every block above did not.
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
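
# reduce_loss_dict, used in the logging paths above, is not included in these
# snippets. A minimal sketch consistent with how it is called (averaging each
# loss across distributed ranks so that rank 0 can log them); get_world_size
# is assumed to come from the same distributed utilities as get_rank.
def reduce_loss_dict(loss_dict):
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        keys = sorted(loss_dict.keys())
        losses = torch.stack([loss_dict[k] for k in keys], 0)
        # Sum every rank's losses onto rank 0, then average there; only
        # rank 0 reads the reduced values.
        torch.distributed.reduce(losses, dst=0)
        if get_rank() == 0:
            losses /= world_size
        return {k: v for k, v in zip(keys, losses)}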