def valid(args, epoch, loader, model, device):
    """Run one validation epoch of the VAE.

    Computes reconstruction and KL losses over ``loader``, reduces them
    across distributed workers, and — on rank 0 — updates the progress
    bar with running epoch averages, saves an input/reconstruction grid
    for the first batch, and logs the final averages to wandb when
    enabled.
    """
    if get_rank() == 0:
        pbar = tqdm(loader, dynamic_ncols=True)

    else:
        pbar = loader

    model.eval()

    recon_total = 0
    kl_total = 0
    n_imgs = 0

    for i, img in enumerate(pbar):
        img = img.to(device)

        # Fix: validation must not build autograd graphs. Without
        # no_grad(), every forward pass retains its graph until the
        # losses go out of scope, inflating memory for no benefit.
        with torch.no_grad():
            out, mean, logvar = model(img, sample=False)
            recon = recon_loss(out, img)
            kl = kl_loss(mean, logvar)

        loss_dict = {'recon': recon, 'kl': kl}
        loss_reduced = reduce_loss_dict(loss_dict)

        if get_rank() == 0:
            batch = img.shape[0]
            # Weight the worker-reduced losses by batch size so the
            # running mean stays exact even with a smaller last batch.
            recon_total += loss_reduced['recon'] * batch
            kl_total += loss_reduced['kl'] * batch
            n_imgs += batch

            recon = recon_total / n_imgs
            kl = kl_total / n_imgs

            pbar.set_description(
                f'valid; epoch: {epoch}; recon: {recon.item():.2f}; kl: {kl.item():.2f}'
            )

            if i == 0:
                # First batch only: inputs stacked above reconstructions.
                utils.save_image(
                    torch.cat([img, out], 0),
                    f'sample_vae/{str(epoch).zfill(2)}.png',
                    nrow=8,
                    normalize=True,
                    range=(-1, 1),
                )

    if get_rank() == 0:
        if wandb and args.wandb:
            wandb.log(
                {
                    'Valid/Reconstruction': recon.item(),
                    'Valid/KL Divergence': kl.item(),
                },
                step=epoch,
            )
def train(args, epoch, loader, model, optimizer, scheduler, device):
    """Run one VAE training epoch, optimizing recon + args.beta * KL.

    On rank 0 the progress bar shows per-batch reduced losses and the
    current learning rate; losses are also sent to wandb when enabled.
    """
    is_main = get_rank() == 0
    pbar = tqdm(loader, dynamic_ncols=True) if is_main else loader

    model.train()

    for img in pbar:
        img = img.to(device)

        out, mean, logvar = model(img)
        recon = recon_loss(out, img)
        kl = kl_loss(mean, logvar)

        # beta-VAE objective: KL term scaled by args.beta.
        loss = recon + args.beta * kl

        model.zero_grad()
        loss.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # Average the loss terms across distributed workers.
        loss_reduced = reduce_loss_dict({'recon': recon, 'kl': kl})

        if is_main:
            recon = loss_reduced['recon']
            kl = loss_reduced['kl']
            lr = optimizer.param_groups[0]['lr']

            pbar.set_description(
                f'train; epoch: {epoch}; recon: {recon.item():.2f}; kl: {kl.item():.2f}; lr: {lr:.5f}'
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        'Train/Reconstruction': recon.item(),
                        'Train/KL Divergence': kl.item(),
                    }
                )
def train(
    args,
    loader,
    encoder,
    generator,
    discriminator,
    cooccur,
    g_optim,
    d_optim,
    e_ema,
    g_ema,
    device,
):
    """Main swapping-autoencoder training loop.

    Each iteration alternates two phases:
      1. discriminator + patch co-occurrence discriminator update, with
         lazy R1 regularization every ``args.d_reg_every`` steps;
      2. encoder + generator update (L1 reconstruction, adversarial, and
         co-occurrence losses).
    EMA copies of encoder and generator are maintained; rank 0 handles
    progress display, wandb logging, sample grids, and checkpoints.
    """
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    d_loss_val = 0
    # Placeholder so loss_dict["r1"] exists on iterations that skip the
    # regularization step.
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    loss_dict = {}

    if args.distributed:
        # Unwrap DDP so EMA accumulation and checkpointing reach the
        # underlying modules.
        e_module = encoder.module
        g_module = generator.module
        d_module = discriminator.module
        c_module = cooccur.module

    else:
        e_module = encoder
        g_module = generator
        d_module = discriminator
        c_module = cooccur

    # EMA decay per step (~0.999 for this half-life parameterization).
    accum = 0.5**(32 / (10 * 1000))

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # ---- Phase 1: discriminator / co-occurrence update ----
        requires_grad(encoder, False)
        requires_grad(generator, False)
        requires_grad(discriminator, True)
        requires_grad(cooccur, True)

        # Split the batch in two: img1 supplies structure for both fakes,
        # img2 supplies the swapped texture for the hybrid.
        real_img1, real_img2 = real_img.chunk(2, dim=0)

        structure1, texture1 = encoder(real_img1)
        _, texture2 = encoder(real_img2)

        fake_img1 = generator(structure1, texture1)  # reconstruction
        fake_img2 = generator(structure1, texture2)  # texture-swapped hybrid

        fake_pred = discriminator(torch.cat((fake_img1, fake_img2), 0))
        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        # Patch co-occurrence discriminator: hybrid patches should match
        # the texture statistics of reference patches from real_img2.
        fake_patch = patchify_image(fake_img2, args.n_crop)
        real_patch = patchify_image(real_img2, args.n_crop)
        ref_patch = patchify_image(real_img2, args.ref_crop * args.n_crop)

        # ref_input is reused for the real-patch pass so both predictions
        # see the same reference features.
        fake_patch_pred, ref_input = cooccur(fake_patch, ref_patch, ref_batch=args.ref_crop)
        real_patch_pred, _ = cooccur(real_patch, ref_input=ref_input)

        cooccur_loss = d_logistic_loss(real_patch_pred, fake_patch_pred)

        loss_dict["d"] = d_loss
        loss_dict["cooccur"] = cooccur_loss
        loss_dict["real_score"] = real_pred.mean()
        fake_pred1, fake_pred2 = fake_pred.chunk(2, dim=0)
        loss_dict["fake_score"] = fake_pred1.mean()
        loss_dict["hybrid_score"] = fake_pred2.mean()

        d_optim.zero_grad()
        (d_loss + cooccur_loss).backward()
        d_optim.step()

        # Lazy R1 regularization for both discriminators.
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            real_patch.requires_grad = True
            real_patch_pred, _ = cooccur(real_patch, ref_patch, ref_batch=args.ref_crop)
            cooccur_r1_loss = d_r1_loss(real_patch_pred, real_patch)

            d_optim.zero_grad()

            # Multiplying by d_reg_every compensates for applying the
            # penalty only every d_reg_every steps.
            r1_loss_sum = args.r1 / 2 * r1_loss * args.d_reg_every
            r1_loss_sum += args.cooccur_r1 / 2 * cooccur_r1_loss * args.d_reg_every
            # Zero-weight terms keep both predictions in the graph so
            # every DDP parameter participates in this backward pass.
            r1_loss_sum += 0 * real_pred[0, 0] + 0 * real_patch_pred[0, 0]
            r1_loss_sum.backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss
        # NOTE(review): unlike r1_loss, cooccur_r1_loss has no placeholder
        # initialization — if training resumes at a start_iter where the
        # first iteration skips d_regularize, this raises NameError.
        # Confirm start_iter is always a multiple of d_reg_every.
        loss_dict["cooccur_r1"] = cooccur_r1_loss

        # ---- Phase 2: encoder / generator update ----
        requires_grad(encoder, True)
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        requires_grad(cooccur, False)

        structure1, texture1 = encoder(real_img1)
        _, texture2 = encoder(real_img2)

        fake_img1 = generator(structure1, texture1)
        fake_img2 = generator(structure1, texture2)

        # NOTE(review): this local shadows any module-level recon_loss
        # helper; harmless here but easy to trip over when editing.
        recon_loss = F.l1_loss(fake_img1, real_img1)

        fake_pred = discriminator(torch.cat((fake_img1, fake_img2), 0))
        g_loss = g_nonsaturating_loss(fake_pred)

        fake_patch = patchify_image(fake_img2, args.n_crop)
        ref_patch = patchify_image(real_img2, args.ref_crop * args.n_crop)
        fake_patch_pred, _ = cooccur(fake_patch, ref_patch, ref_batch=args.ref_crop)
        g_cooccur_loss = g_nonsaturating_loss(fake_patch_pred)

        loss_dict["recon"] = recon_loss
        loss_dict["g"] = g_loss
        loss_dict["g_cooccur"] = g_cooccur_loss

        g_optim.zero_grad()
        (recon_loss + g_loss + g_cooccur_loss).backward()
        g_optim.step()

        # Update the EMA copies of encoder and generator.
        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        # Average every logged loss across distributed workers.
        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        cooccur_val = loss_reduced["cooccur"].mean().item()
        recon_val = loss_reduced["recon"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        g_cooccur_val = loss_reduced["g_cooccur"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        cooccur_r1_val = loss_reduced["cooccur_r1"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        hybrid_score_val = loss_reduced["hybrid_score"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; c: {cooccur_val:.4f} g: {g_loss_val:.4f}; "
                f"g_cooccur: {g_cooccur_val:.4f}; recon: {recon_val:.4f}; r1: {r1_val:.4f}; "
                f"r1_cooccur: {cooccur_r1_val:.4f}"))

            if wandb and args.wandb and i % 10 == 0:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Cooccur": cooccur_val,
                        "Recon": recon_val,
                        "Generator Cooccur": g_cooccur_val,
                        "R1": r1_val,
                        "Cooccur R1": cooccur_r1_val,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Hybrid Score": hybrid_score_val,
                    },
                    step=i,
                )

            if i % 100 == 0:
                # Sample grid from the EMA models: reconstructions
                # followed by texture-swapped hybrids.
                with torch.no_grad():
                    e_ema.eval()
                    g_ema.eval()

                    structure1, texture1 = e_ema(real_img1)
                    _, texture2 = e_ema(real_img2)

                    fake_img1 = g_ema(structure1, texture1)
                    fake_img2 = g_ema(structure1, texture2)

                    sample = torch.cat((fake_img1, fake_img2), 0)

                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(sample.shape[0]**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "cooccur": c_module.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
def train(args, loader, generator, discriminator, vae, g_optim, d_optim, g_ema, device):
    """GAN training loop where the generator is conditioned on VAE latents.

    Per iteration: a frozen-forward VAE encodes the real batch into a
    latent that is concatenated with fresh noise and fed to the generator.
    The discriminator is updated (with lazy R1 every ``args.d_reg_every``
    steps); the generator is updated with a non-saturating adversarial
    loss plus a warmup-weighted latent-consistency (Gaussian NLL) loss.
    Rank 0 handles progress display, wandb logging, samples, checkpoints.
    """
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    d_loss_val = 0
    # Placeholder so loss_dict['r1'] exists on iterations that skip the
    # regularization step.
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    loss_dict = {}

    if args.distributed:
        # Unwrap DDP for EMA accumulation and checkpointing.
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Fixed noise for the periodic 8x8 sample grid.
    sample_z = torch.randn(8 * 8, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print('Done!')

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # ---- Discriminator update ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        # The VAE is only used as a feature/latent extractor here; no
        # gradients flow into it.
        with torch.no_grad():
            vae_latent, _, _ = vae(real_img)

        noise = make_noise(args.batch, args.latent, 1, device)
        # Generator input: random noise concatenated with the VAE latent.
        fake_img, _ = generator([torch.cat([noise, vae_latent], 1)])
        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        # Lazy R1 regularization.
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # d_reg_every factor compensates for the lazy schedule; the
            # 0 * real_pred[0] term keeps the prediction in the graph so
            # DDP sees every parameter in this backward pass.
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict['r1'] = r1_loss

        # ---- Generator update ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = make_noise(args.batch, args.latent, 1, device)
        fake_img, _ = generator([torch.cat([noise, vae_latent], 1)])
        fake_pred = discriminator(fake_img)

        # Latent consistency: re-encoding the fake image should yield a
        # distribution that explains the conditioning latent.
        _, mean, logvar = vae(fake_img)
        vae_loss = gaussian_nll_loss(vae_latent.detach(), mean, logvar)
        # Linear warmup of the consistency weight from 1e-5 at i=0 up to
        # 1 at i=args.vae_warmup, capped at 1 afterwards.
        vae_weight = min(1, ((1 - 1e-5) / args.vae_warmup) * i + 1e-5)
        vae_loss = vae_weight * vae_loss

        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict['g'] = g_loss
        loss_dict['vae'] = vae_loss

        generator.zero_grad()
        (g_loss + args.vae_regularize * vae_loss).backward()
        g_optim.step()

        # NOTE(review): no decay argument is passed — presumably
        # accumulate() has a default EMA decay; confirm its signature.
        accumulate(g_ema, g_module)

        # Average logged losses across distributed workers.
        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced['d'].mean().item()
        g_loss_val = loss_reduced['g'].mean().item()
        r1_val = loss_reduced['r1'].mean().item()
        vae_loss_val = loss_reduced['vae'].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; vae: {vae_loss_val:.4f}'
            ))

            if wandb and args.wandb:
                wandb.log({
                    'Generator': g_loss_val,
                    'Discriminator': d_loss_val,
                    'R1': r1_val,
                    'VAE': vae_loss_val,
                })

            if i % 100 == 0:
                # Tile the current batch's latents to cover all 64 sample
                # rows (the batch may be smaller than sample_z).
                n_repeat = sample_z.shape[0] // vae_latent.shape[0] + 1
                vae_latent = vae_latent.repeat(n_repeat, 1)
                vae_latent = vae_latent[:sample_z.shape[0]]

                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([torch.cat([sample_z, vae_latent], 1)])
                    utils.save_image(
                        sample,
                        f'sample/{str(i).zfill(6)}.png',
                        nrow=8,
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        'g': g_module.state_dict(),
                        'd': d_module.state_dict(),
                        'g_ema': g_ema.state_dict(),
                        'g_optim': g_optim.state_dict(),
                        'd_optim': d_optim.state_dict(),
                    },
                    f'checkpoint/{str(i).zfill(6)}.pt',
                )