Example #1
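The training loops below follow the style of rosinality's stylegan2-pytorch. The surrounding imports are not shown in the listing; a plausible reconstruction (an assumption, not part of the original) is:

import math
import os
import random

import torch
from torch import autograd
from torch.nn import functional as F
from torchvision import utils
from tqdm import tqdm

try:
    import wandb  # optional; the loops guard on `wandb and args.wandb`
except ImportError:
    wandb = None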
def train(args, loader, generator, discriminator, optimizer, g_ema, device):
    # Create the output directories used by the sampling/checkpoint code below.
    os.makedirs('checkpoints/stylegan-acgd', exist_ok=True)
    os.makedirs('figs/stylegan-acgd', exist_ok=True)

    loader = sample_data(loader)  # wrap the DataLoader in an infinite iterator
    pbar = range(args.iter)
    pbar = tqdm(pbar,
                initial=args.start_iter,
                dynamic_ncols=True,
                smoothing=0.01)
    mean_path_length = 0

    # Placeholders so the logging below is defined before the first
    # regularization step runs.
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    if args.gpu_num > 1:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator
    accum = 0.5**(32 / (10 * 1000))  # EMA decay: half-life of 10k images at batch size 32
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)  # fixed latents for sample grids

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)

        d_loss = d_logistic_loss(real_pred, fake_pred)
        loss_dict["loss"] = d_loss.item()
        loss_dict["real_score"] = real_pred.mean().item()
        loss_dict["fake_score"] = fake_pred.mean().item()

        # Lazy R1 regularization: penalize D's gradient on real images every
        # d_reg_every steps, scaled so the effective strength stays constant.
        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img_cp = real_img.clone().detach()
            real_img_cp.requires_grad = True
            real_pred_cp = discriminator(real_img_cp)
            r1_loss = d_r1_loss(real_pred_cp, real_img_cp)
            d_loss += args.r1 / 2 * r1_loss * args.d_reg_every
        loss_dict["r1"] = r1_loss.item()

        # Path length regularization is disabled for now; the schedule is kept
        # commented out until the code works under nn.DataParallel.
        # g_regularize = i % args.g_reg_every == 0
        g_regularize = False
        if g_regularize:  # TODO adapt code for nn.DataParallel
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                # Dummy zero-valued term that keeps fake_img in the autograd graph.
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            d_loss += weighted_path_loss
            mean_path_length_avg = mean_path_length.item()

        loss_dict["path"] = path_loss.mean().item()
        loss_dict["path_length"] = path_lengths.mean().item()

        optimizer.step(d_loss)  # ACGD-style optimizer: step() consumes the zero-sum loss directly

        # Adaptive discriminator augmentation: tune ada_aug_p from the sign of
        # D's predictions on real images.
        if args.augment and args.augment_p == 0:
            ada_augment_data = torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment += ada_augment_data
            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()
                r_t_stat = pred_signs / n_pred
                if r_t_stat > args.ada_target:
                    sign = 1
                else:
                    sign = -1
                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        accumulate(g_ema, g_module, accum)

        d_loss_val = loss_dict["loss"]
        r1_val = loss_dict['r1']
        path_loss_val = loss_dict["path"]
        real_score_val = loss_dict["real_score"]
        fake_score_val = loss_dict["fake_score"]
        path_length_val = loss_dict["path_length"]

        # ACGD couples G and D through a single zero-sum loss, so the same
        # value is reported for both players.
        pbar.set_description((
            f"d: {d_loss_val:.4f}; g: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
            f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
            f"augment: {ada_aug_p:.4f}"))
        if wandb and args.wandb:
            wandb.log({
                "Generator": d_loss_val,
                "Discriminator": d_loss_val,
                "Augment": ada_aug_p,
                "Rt": r_t_stat,
                "R1": r1_val,
                "Path Length Regularization": path_loss_val,
                "Mean Path Length": mean_path_length,
                "Real Score": real_score_val,
                "Fake Score": fake_score_val,
                "Path Length": path_length_val,
            })
        if i % 100 == 0:
            with torch.no_grad():
                g_ema.eval()
                sample, _ = g_ema([sample_z])
                utils.save_image(
                    sample,
                    f"figs/stylegan-acgd/{str(i).zfill(6)}.png",
                    nrow=int(args.n_sample**0.5),
                    normalize=True,
                    range=(-1, 1),  # renamed to value_range in newer torchvision
                )
        if i % 100 == 0:
            torch.save(
                {
                    "g": g_module.state_dict(),
                    "d": d_module.state_dict(),
                    "g_ema": g_ema.state_dict(),
                    "d_optim": optimizer.state_dict(),
                    "args": args,
                    "ada_aug_p": ada_aug_p,
                },
                f"checkpoints/stylegan-acgd/{str(i).zfill(6)}.pt",
            )
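The examples call several helpers that are not shown in the listing. Below is a minimal sketch of sample_data, make_noise, and mixing_noise, assuming they follow the usual stylegan2-pytorch definitions (and using the imports sketched above):

def sample_data(loader):
    # Loop over the DataLoader forever so the training loop can call next().
    while True:
        for batch in loader:
            yield batch


def make_noise(batch, latent_dim, n_noise, device):
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)
    return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)


def mixing_noise(batch, latent_dim, prob, device):
    # With probability `prob`, return two latents for style mixing;
    # otherwise a single latent.
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)
    return [make_noise(batch, latent_dim, 1, device)]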
Example #2
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device):
    ckpt_dir = 'checkpoints/stylegan'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    fig_dir = 'figs/stylegan'
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))  # EMA decay: half-life of 10k images at batch size 32
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # Discriminator step: freeze G, train D.
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        # Adaptive discriminator augmentation: tune ada_aug_p from the sign of
        # D's predictions on real images (StyleGAN2-ADA heuristic).
        if args.augment and args.augment_p == 0:
            ada_augment_data = torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment += reduce_sum(ada_augment_data)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > args.ada_target:
                    sign = 1

                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        # Lazy R1 regularization on real images every d_reg_every steps.
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # The 0 * real_pred[0] term keeps real_pred in the graph so that
            # distributed training does not see an unused output.
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # Generator step: freeze D, train G.
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # Lazy path length regularization for G every g_reg_every steps.
        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                # Dummy zero-valued term that keeps fake_img in the autograd graph.
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)  # EMA update of the sampling copy of G

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"figs/stylegan/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),  # renamed to value_range in newer torchvision
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"checkpoints/stylegan/{str(i).zfill(6)}.pt",
                )
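The loss and regularization helpers are likewise not shown. A sketch of the standard StyleGAN2 formulations these loops appear to use follows (the distributed helpers get_rank, get_world_size, reduce_sum, and reduce_loss_dict used by Example #2 are not sketched here):

def d_logistic_loss(real_pred, fake_pred):
    # Non-saturating logistic loss for the discriminator.
    real_loss = F.softplus(-real_pred)
    fake_loss = F.softplus(fake_pred)
    return real_loss.mean() + fake_loss.mean()


def g_nonsaturating_loss(fake_pred):
    # Non-saturating logistic loss for the generator.
    return F.softplus(-fake_pred).mean()


def d_r1_loss(real_pred, real_img):
    # R1: squared gradient norm of D's output w.r.t. the real images.
    grad_real, = autograd.grad(outputs=real_pred.sum(),
                               inputs=real_img,
                               create_graph=True)
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    # Penalize deviation of the per-sample path length from its running mean.
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    grad, = autograd.grad(outputs=(fake_img * noise).sum(),
                          inputs=latents,
                          create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths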
def train(args, loader, generator, discriminator, optimizer, g_ema, device):
    collect_info = True  # also log the ACGD solver diagnostics to wandb below
    ckpt_dir = 'checkpoints/stylegan-acgd'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    fig_dir = 'figs/stylegan-acgd'
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)
    loader = sample_data(loader)
    pbar = range(args.iter)
    pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    accs = torch.ones(50)  # rolling window of D accuracy over the last 50 iterations
    loss_dict = {}
    if args.gpu_num > 1:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator
    accum = 0.5 ** (32 / (10 * 1000))  # EMA decay: half-life of 10k images at batch size 32
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    ada_ratio = 2  # controls the gap between the two players' learning rates
    # Initialize the ratios here so the wandb logging below is defined even
    # before the first adjustment step (i % 2 == 0) has run.
    max_ratio = 2 ** min(4, ada_ratio)
    min_ratio = 2 ** min(0, 4 - ada_ratio)
    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)
        optimizer.step(d_loss)  # ACGD-style optimizer: step() consumes the zero-sum loss directly

        # Batch accuracy of D: real images classified positive, fakes negative.
        num_correct = torch.sum(real_pred > 0) + torch.sum(fake_pred < 0)
        acc = num_correct.item() / (fake_pred.shape[0] + real_pred.shape[0])

        loss_dict["loss"] = d_loss.item()
        loss_dict["real_score"] = real_pred.mean().item()
        loss_dict["fake_score"] = fake_pred.mean().item()

        # Update ada_ratio: widen the learning-rate gap while D stays above 85%
        # accuracy, narrow it below 75%, then rescale the optimizer's lrs.
        accs[i % 50] = acc
        acc_indicator = accs.mean().item()
        if i % 2 == 0:
            if acc_indicator > 0.85:
                ada_ratio += 1
            elif acc_indicator < 0.75:
                ada_ratio -= 1
            max_ratio = 2 ** min(4, ada_ratio)  # cap the larger lr at 16x lr_d
            min_ratio = 2 ** min(0, 4 - ada_ratio)  # shrink the smaller lr once ada_ratio > 4
            if args.ada_train:
                print('Adjust lrs')
                optimizer.set_lr(lr_max=max_ratio * args.lr_d,
                                 lr_min=min_ratio * args.lr_d)

        accumulate(g_ema, g_module, accum)

        d_loss_val = loss_dict["loss"]
        real_score_val = loss_dict["real_score"]
        fake_score_val = loss_dict["fake_score"]

        # ACGD couples G and D through a single zero-sum loss, so the same
        # value is reported for both players.
        pbar.set_description(
            (
                f"d: {d_loss_val:.4f}; g: {d_loss_val:.4f}; Acc: {acc:.4f}; "
                f"augment: {ada_aug_p:.4f}"
            )
        )
        if wandb and args.wandb:
            if collect_info:
                # Solver diagnostics exposed by the ACGD optimizer: CG iteration
                # count, runtime, and gradient/Hessian-vector-product norms.
                cgd_info = optimizer.get_info()
                wandb.log(
                    {
                        'CG iter num': cgd_info['iter_num'],
                        'CG runtime': cgd_info['time'],
                        'D gradient': cgd_info['grad_y'],
                        'G gradient': cgd_info['grad_x'],
                        'D hvp': cgd_info['hvp_y'],
                        'G hvp': cgd_info['hvp_x'],
                        'D cg': cgd_info['cg_y'],
                        'G cg': cgd_info['cg_x']
                    },
                    step=i,
                )
            wandb.log(
                {
                    "Generator": d_loss_val,
                    "Discriminator": d_loss_val,
                    "Ada ratio": ada_ratio,
                    'Generator lr': max_ratio * args.lr_d,
                    'Discriminator lr': min_ratio * args.lr_d,
                    "Rt": r_t_stat,
                    "Accuracy": acc_indicator,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val
                },
                step=i,
            )
        if i % 100 == 0:
            with torch.no_grad():
                g_ema.eval()
                sample, _ = g_ema([sample_z])
                utils.save_image(
                    sample,
                    f"figs/stylegan-acgd/{str(i).zfill(6)}.png",
                    nrow=int(args.n_sample ** 0.5),
                    normalize=True,
                    range=(-1, 1),  # renamed to value_range in newer torchvision
                )
        if i % 2000 == 0:
            torch.save(
                {
                    "g": g_module.state_dict(),
                    "d": d_module.state_dict(),
                    "g_ema": g_ema.state_dict(),
                    "d_optim": optimizer.state_dict(),
                    "args": args,
                    "ada_aug_p": ada_aug_p,
                },
                f"checkpoints/stylegan-acgd/fix{str(i).zfill(6)}.pt",
            )
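Finally, requires_grad and accumulate (the EMA update driven by accum above) are assumed to follow the usual stylegan2-pytorch definitions; a minimal sketch:

def requires_grad(model, flag=True):
    # Toggle gradient computation for all parameters of a model.
    for p in model.parameters():
        p.requires_grad = flag


def accumulate(model1, model2, decay=0.999):
    # Exponential moving average of model2's parameters into model1 (g_ema),
    # so sampling uses a smoothed copy of the generator.
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)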