Example No. 1
def valid(args, epoch, loader, model, device):
    if get_rank() == 0:
        pbar = tqdm(loader, dynamic_ncols=True)

    else:
        pbar = loader

    model.eval()

    recon_total = 0
    kl_total = 0
    n_imgs = 0

    for i, img in enumerate(pbar):
        img = img.to(device)

        out, mean, logvar = model(img, sample=False)
        recon = recon_loss(out, img)
        kl = kl_loss(mean, logvar)

        loss_dict = {'recon': recon, 'kl': kl}
        loss_reduced = reduce_loss_dict(loss_dict)

        if get_rank() == 0:
            batch = img.shape[0]
            recon_total += loss_reduced['recon'] * batch
            kl_total += loss_reduced['kl'] * batch
            n_imgs += batch
            recon = recon_total / n_imgs
            kl = kl_total / n_imgs

            pbar.set_description(
                f'valid; epoch: {epoch}; recon: {recon.item():.2f}; kl: {kl.item():.2f}'
            )

            if i == 0:
                utils.save_image(
                    torch.cat([img, out], 0),
                    f'sample_vae/{str(epoch).zfill(2)}.png',
                    nrow=8,
                    normalize=True,
                    range=(-1, 1),
                )

    if get_rank() == 0:
        if wandb and args.wandb:
            wandb.log(
                {
                    'Valid/Reconstruction': recon.item(),
                    'Valid/KL Divergence': kl.item(),
                },
                step=epoch,
            )
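
The validation loop above calls recon_loss and kl_loss, which are defined elsewhere in the training script. Below is a minimal sketch of what these helpers typically compute for a Gaussian VAE; the MSE reconstruction term and the per-batch averaging are assumptions and may differ from the original implementation.

import torch
import torch.nn.functional as F


def recon_loss(out, img):
    # Mean-squared reconstruction error, summed over pixels and
    # averaged over the batch (the reduction is an assumption).
    return F.mse_loss(out, img, reduction='sum') / img.shape[0]


def kl_loss(mean, logvar):
    # Closed-form KL divergence between N(mean, exp(logvar)) and N(0, I),
    # averaged over the batch.
    kl = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return kl / mean.shape[0]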
Example No. 2
def train(args, epoch, loader, model, optimizer, scheduler, device):
    if get_rank() == 0:
        pbar = tqdm(loader, dynamic_ncols=True)

    else:
        pbar = loader

    model.train()

    for img in pbar:
        img = img.to(device)

        out, mean, logvar = model(img)
        recon = recon_loss(out, img)
        kl = kl_loss(mean, logvar)
        loss = recon + args.beta * kl

        model.zero_grad()
        loss.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        loss_dict = {'recon': recon, 'kl': kl}
        loss_reduced = reduce_loss_dict(loss_dict)

        if get_rank() == 0:
            recon = loss_reduced['recon']
            kl = loss_reduced['kl']
            lr = optimizer.param_groups[0]['lr']

            pbar.set_description(
                f'train; epoch: {epoch}; recon: {recon.item():.2f}; kl: {kl.item():.2f}; lr: {lr:.5f}'
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        'Train/Reconstruction': recon.item(),
                        'Train/KL Divergence': kl.item(),
                    }
                )
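
Both examples log through reduce_loss_dict, which averages each loss across distributed workers so that rank 0 reports values for the global batch. A sketch of how such a helper is commonly written with torch.distributed follows; the exact reduction scheme in the original script may differ.

import torch
import torch.distributed as dist


def reduce_loss_dict(loss_dict):
    # Average each scalar loss tensor across all workers so rank 0 can log
    # values that reflect the whole global batch.
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = sorted(loss_dict.keys())
        losses = torch.stack([loss_dict[k] for k in keys], 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        return dict(zip(keys, losses))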
Example No. 3
def train(
    args,
    loader,
    encoder,
    generator,
    discriminator,
    cooccur,
    g_optim,
    d_optim,
    e_ema,
    g_ema,
    device,
):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    # Pre-initialize so the unconditional loss_dict logging below does not fail
    # when the first processed iteration is not an R1 regularization step.
    cooccur_r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    loss_dict = {}

    if args.distributed:
        e_module = encoder.module
        g_module = generator.module
        d_module = discriminator.module
        c_module = cooccur.module

    else:
        e_module = encoder
        g_module = generator
        d_module = discriminator
        c_module = cooccur

    accum = 0.5**(32 / (10 * 1000))

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

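        # Discriminator phase: freeze the encoder and generator, update the
        # image discriminator and the patch co-occurrence discriminator.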
        requires_grad(encoder, False)
        requires_grad(generator, False)
        requires_grad(discriminator, True)
        requires_grad(cooccur, True)

        real_img1, real_img2 = real_img.chunk(2, dim=0)

        structure1, texture1 = encoder(real_img1)
        _, texture2 = encoder(real_img2)

        fake_img1 = generator(structure1, texture1)
        fake_img2 = generator(structure1, texture2)

        fake_pred = discriminator(torch.cat((fake_img1, fake_img2), 0))
        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        fake_patch = patchify_image(fake_img2, args.n_crop)
        real_patch = patchify_image(real_img2, args.n_crop)
        ref_patch = patchify_image(real_img2, args.ref_crop * args.n_crop)
        fake_patch_pred, ref_input = cooccur(fake_patch,
                                             ref_patch,
                                             ref_batch=args.ref_crop)
        real_patch_pred, _ = cooccur(real_patch, ref_input=ref_input)
        cooccur_loss = d_logistic_loss(real_patch_pred, fake_patch_pred)

        loss_dict["d"] = d_loss
        loss_dict["cooccur"] = cooccur_loss
        loss_dict["real_score"] = real_pred.mean()
        fake_pred1, fake_pred2 = fake_pred.chunk(2, dim=0)
        loss_dict["fake_score"] = fake_pred1.mean()
        loss_dict["hybrid_score"] = fake_pred2.mean()

        d_optim.zero_grad()
        (d_loss + cooccur_loss).backward()
        d_optim.step()

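        # Lazy R1 regularization, applied every d_reg_every iterations to both
        # discriminators (the penalty is scaled by d_reg_every to compensate).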
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            real_patch.requires_grad = True
            real_patch_pred, _ = cooccur(real_patch,
                                         ref_patch,
                                         ref_batch=args.ref_crop)
            cooccur_r1_loss = d_r1_loss(real_patch_pred, real_patch)

            d_optim.zero_grad()

            r1_loss_sum = args.r1 / 2 * r1_loss * args.d_reg_every
            r1_loss_sum += args.cooccur_r1 / 2 * cooccur_r1_loss * args.d_reg_every
            r1_loss_sum += 0 * real_pred[0, 0] + 0 * real_patch_pred[0, 0]
            r1_loss_sum.backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss
        loss_dict["cooccur_r1"] = cooccur_r1_loss

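        # Generator / encoder phase: freeze the discriminators and train with
        # reconstruction, adversarial, and patch co-occurrence losses.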
        requires_grad(encoder, True)
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        requires_grad(cooccur, False)

        structure1, texture1 = encoder(real_img1)
        _, texture2 = encoder(real_img2)

        fake_img1 = generator(structure1, texture1)
        fake_img2 = generator(structure1, texture2)

        recon_loss = F.l1_loss(fake_img1, real_img1)

        fake_pred = discriminator(torch.cat((fake_img1, fake_img2), 0))
        g_loss = g_nonsaturating_loss(fake_pred)

        fake_patch = patchify_image(fake_img2, args.n_crop)
        ref_patch = patchify_image(real_img2, args.ref_crop * args.n_crop)
        fake_patch_pred, _ = cooccur(fake_patch,
                                     ref_patch,
                                     ref_batch=args.ref_crop)
        g_cooccur_loss = g_nonsaturating_loss(fake_patch_pred)

        loss_dict["recon"] = recon_loss
        loss_dict["g"] = g_loss
        loss_dict["g_cooccur"] = g_cooccur_loss

        g_optim.zero_grad()
        (recon_loss + g_loss + g_cooccur_loss).backward()
        g_optim.step()

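        # Update the exponential-moving-average copies of encoder and generator.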
        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        cooccur_val = loss_reduced["cooccur"].mean().item()
        recon_val = loss_reduced["recon"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        g_cooccur_val = loss_reduced["g_cooccur"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        cooccur_r1_val = loss_reduced["cooccur_r1"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        hybrid_score_val = loss_reduced["hybrid_score"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; c: {cooccur_val:.4f}; g: {g_loss_val:.4f}; "
                f"g_cooccur: {g_cooccur_val:.4f}; recon: {recon_val:.4f}; r1: {r1_val:.4f}; "
                f"r1_cooccur: {cooccur_r1_val:.4f}"))

            if wandb and args.wandb and i % 10 == 0:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Cooccur": cooccur_val,
                        "Recon": recon_val,
                        "Generator Cooccur": g_cooccur_val,
                        "R1": r1_val,
                        "Cooccur R1": cooccur_r1_val,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Hybrid Score": hybrid_score_val,
                    },
                    step=i,
                )

            if i % 100 == 0:
                with torch.no_grad():
                    e_ema.eval()
                    g_ema.eval()

                    structure1, texture1 = e_ema(real_img1)
                    _, texture2 = e_ema(real_img2)

                    fake_img1 = g_ema(structure1, texture1)
                    fake_img2 = g_ema(structure1, texture2)

                    sample = torch.cat((fake_img1, fake_img2), 0)

                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(sample.shape[0]**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "cooccur": c_module.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
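
This example relies on the usual StyleGAN2-style loss helpers. The sketch below shows how d_logistic_loss, g_nonsaturating_loss, and d_r1_loss are commonly implemented (non-saturating logistic loss plus an R1 gradient penalty); treat it as an illustration rather than the exact code of this repository.

from torch import autograd
from torch.nn import functional as F


def d_logistic_loss(real_pred, fake_pred):
    # Non-saturating logistic discriminator loss: push real scores up,
    # fake scores down.
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()


def g_nonsaturating_loss(fake_pred):
    # Generator side of the non-saturating logistic loss.
    return F.softplus(-fake_pred).mean()


def d_r1_loss(real_pred, real_img):
    # R1 penalty: squared gradient norm of the discriminator output with
    # respect to the real inputs (requires real_img.requires_grad = True).
    grad_real, = autograd.grad(outputs=real_pred.sum(),
                               inputs=real_img,
                               create_graph=True)
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()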
Example No. 4
def train(args, loader, generator, discriminator, vae, g_optim, d_optim, g_ema,
          device):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    sample_z = torch.randn(8 * 8, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print('Done!')

            break

        real_img = next(loader)
        real_img = real_img.to(device)

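        # Discriminator update: the generator is frozen and conditioned on
        # random noise concatenated with the VAE latent of the real images.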
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        with torch.no_grad():
            vae_latent, _, _ = vae(real_img)

        noise = make_noise(args.batch, args.latent, 1, device)
        fake_img, _ = generator([torch.cat([noise, vae_latent], 1)])
        fake_pred = discriminator(fake_img)

        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict['r1'] = r1_loss

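        # Generator update: non-saturating adversarial loss plus a VAE latent
        # consistency term whose weight ramps up over args.vae_warmup iterations.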
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = make_noise(args.batch, args.latent, 1, device)
        fake_img, _ = generator([torch.cat([noise, vae_latent], 1)])
        fake_pred = discriminator(fake_img)
        _, mean, logvar = vae(fake_img)
        vae_loss = gaussian_nll_loss(vae_latent.detach(), mean, logvar)

        vae_weight = min(1, ((1 - 1e-5) / args.vae_warmup) * i + 1e-5)
        vae_loss = vae_weight * vae_loss

        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict['g'] = g_loss
        loss_dict['vae'] = vae_loss

        generator.zero_grad()
        (g_loss + args.vae_regularize * vae_loss).backward()
        g_optim.step()

        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced['d'].mean().item()
        g_loss_val = loss_reduced['g'].mean().item()
        r1_val = loss_reduced['r1'].mean().item()
        vae_loss_val = loss_reduced['vae'].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; vae: {vae_loss_val:.4f}'
            ))

            if wandb and args.wandb:
                wandb.log({
                    'Generator': g_loss_val,
                    'Discriminator': d_loss_val,
                    'R1': r1_val,
                    'VAE': vae_loss_val,
                })

            if i % 100 == 0:
                n_repeat = sample_z.shape[0] // vae_latent.shape[0] + 1
                vae_latent = vae_latent.repeat(n_repeat, 1)
                vae_latent = vae_latent[:sample_z.shape[0]]
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([torch.cat([sample_z, vae_latent], 1)])
                    utils.save_image(
                        sample,
                        f'sample/{str(i).zfill(6)}.png',
                        nrow=8,
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        'g': g_module.state_dict(),
                        'd': d_module.state_dict(),
                        'g_ema': g_ema.state_dict(),
                        'g_optim': g_optim.state_dict(),
                        'd_optim': d_optim.state_dict(),
                    },
                    f'checkpoint/{str(i).zfill(6)}.pt',
                )
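
accumulate maintains the exponential-moving-average copies (g_ema, e_ema) used for sampling and checkpointing. A sketch of a typical implementation is shown below; the default decay of 0.999 is an assumption. Example No. 3 passes its own value, accum = 0.5 ** (32 / (10 * 1000)) ≈ 0.9978, while Example No. 4 calls accumulate(g_ema, g_module) and relies on the default.

def accumulate(model1, model2, decay=0.999):
    # Exponential moving average: blend model2's parameters into model1,
    # the EMA copy used for sampling and checkpoints.
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)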