def train(args, epoch, loader, model, optimizer, scheduler):
    torch.backends.cudnn.benchmark = True
    model.train()

    if get_rank() == 0:
        pbar = tqdm(loader, dynamic_ncols=True)
    else:
        pbar = loader

    for i, (img, annot) in enumerate(pbar):
        img = img.to('cuda')
        annot = annot.to('cuda')

        loss, _ = model(img, annot)
        loss_sum = loss['loss'] + args.aux_weight * loss['aux']

        model.zero_grad()
        loss_sum.backward()
        optimizer.step()
        scheduler.step()

        loss_dict = reduce_loss_dict(loss)
        loss = loss_dict['loss'].mean().item()
        aux_loss = loss_dict['aux'].mean().item()

        if get_rank() == 0:
            lr = optimizer.param_groups[0]['lr']
            pbar.set_description(
                f'epoch: {epoch + 1}; loss: {loss:.5f}; aux loss: {aux_loss:.5f}; lr: {lr:.5f}'
            )
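# The loops in this file lean on rosinality-style distributed helpers that are
# not shown in this excerpt. A minimal sketch, assuming torch.distributed is
# the backend; the names match the calls above, but treat this as an
# assumption, not the repo's actual implementation.
import torch
from torch import distributed as dist


def get_rank():
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank()


def reduce_loss_dict(loss_dict):
    # Average scalar losses across workers so rank 0 logs global values.
    world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        keys = sorted(loss_dict.keys())
        losses = torch.stack([loss_dict[k] for k in keys], 0)
        dist.reduce(losses, dst=0)  # sum onto rank 0
        if dist.get_rank() == 0:
            losses /= world_size    # average on the logging rank
        return dict(zip(keys, losses))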
def train(args, epoch, loader, model, optimizer, device, logger=None):
    model.train()

    if get_rank() == 0:
        pbar = tqdm(enumerate(loader), total=len(loader), dynamic_ncols=True)
    else:
        pbar = enumerate(loader)

    for idx, (images, targets, _) in pbar:
        model.zero_grad()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        _, loss_dict = model(images, targets=targets)
        loss_cls = loss_dict['loss_cls'].mean()
        loss_box = loss_dict['loss_reg'].mean()
        loss_center = loss_dict['loss_centerness'].mean()

        loss = loss_cls + loss_box + loss_center
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()

        loss_reduced = reduce_loss_dict(loss_dict)
        loss_cls = loss_reduced['loss_cls'].mean().item()
        loss_box = loss_reduced['loss_reg'].mean().item()
        loss_center = loss_reduced['loss_centerness'].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (f'epoch: {epoch + 1}; cls: {loss_cls:.4f}; '
                 f'box: {loss_box:.4f}; center: {loss_center:.4f}'))

            # write log to tensorboard
            if logger and idx % 10 == 0:
                total_step = (epoch * len(loader) + idx) * args.batch * args.n_gpu
                logger.add_scalar('training/loss_cls', loss_cls, total_step)
                logger.add_scalar('training/loss_box', loss_box, total_step)
                logger.add_scalar('training/loss_center', loss_center, total_step)
                logger.add_scalar('training/loss_all',
                                  loss_cls + loss_box + loss_center, total_step)
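# A hedged usage sketch for the detection loop above: wire a TensorBoard
# SummaryWriter in on rank 0 and call train() once per epoch. args.epochs and
# args.log_dir are hypothetical names, not taken from the source.
from torch.utils.tensorboard import SummaryWriter


def fit(args, model, optimizer, loader, device):
    logger = SummaryWriter(args.log_dir) if get_rank() == 0 else None
    for epoch in range(args.epochs):
        train(args, epoch, loader, model, optimizer, device, logger=logger)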
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    # note: `dataset` is a module-level global in this script
    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print('Done!')
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss
        loss_dict['real_score'] = real_pred.mean()
        loss_dict['fake_score'] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()

        loss_dict['r1'] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict['g'] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                # keep a graph dependency on the output so DDP sees used parameters
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()

            mean_path_length_avg = reduce_sum(mean_path_length).item() / get_world_size()

        loss_dict['path'] = path_loss
        loss_dict['path_length'] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced['d'].mean().item()
        g_loss_val = loss_reduced['g'].mean().item()
        r1_val = loss_reduced['r1'].mean().item()
        path_loss_val = loss_reduced['path'].mean().item()
        real_score_val = loss_reduced['real_score'].mean().item()
        fake_score_val = loss_reduced['fake_score'].mean().item()
        path_length_val = loss_reduced['path_length'].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; '
                f'path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}'
            ))

            if wandb and args.wandb:
                wandb.log(
                    {
                        'Generator': g_loss_val,
                        'Discriminator': d_loss_val,
                        'R1': r1_val,
                        'Path Length Regularization': path_loss_val,
                        'Mean Path Length': mean_path_length,
                        'Real Score': real_score_val,
                        'Fake Score': fake_score_val,
                        'Path Length': path_length_val,
                    },
                    step=i)

            run_dir = '/scratch/gobi2/lechang/' + args.run_name

            if i % (len(dataset) // args.batch) == 0:
                if not os.path.exists(run_dir):
                    os.mkdir(run_dir)

                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])

                    sample_dir = run_dir + '/sample/'
                    if not os.path.exists(sample_dir):
                        os.mkdir(sample_dir)

                    utils.save_image(
                        sample,
                        sample_dir + f'{str(i).zfill(6)}.png',
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )
                    wandb.log(
                        {
                            'G(z)': [
                                wandb.Image(sample[n][0], mode='F')
                                for n in range(sample.shape[0])
                            ]
                        },
                        step=i)

            if i % (200 * len(dataset) // args.batch) == 0:
                ckpt_dir = run_dir + '/ckpt/'
                if not os.path.exists(ckpt_dir):
                    os.mkdir(ckpt_dir)

                torch.save(
                    {
                        'g': g_module.state_dict(),
                        'd': d_module.state_dict(),
                        'g_ema': g_ema.state_dict(),
                        'g_optim': g_optim.state_dict(),
                        'd_optim': d_optim.state_dict(),
                    },
                    ckpt_dir + f'{str(i).zfill(6)}.pt',
                )
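# The GAN loops here call sample_data, requires_grad, and accumulate without
# defining them. Minimal sketches in the usual rosinality style, given as
# assumptions since the originals are not part of this excerpt.
def sample_data(loader):
    # Turn an epoch-based DataLoader into an infinite batch stream for next(loader).
    while True:
        for batch in loader:
            yield batch


def requires_grad(model, flag=True):
    # Toggle which network receives gradients in the alternating G/D updates.
    for p in model.parameters():
        p.requires_grad = flag


def accumulate(model1, model2, decay=0.999):
    # EMA update: g_ema tracks the live generator. Loops that pass a decay use
    # accum = 0.5 ** (32 / (10 * 1000)); calls without it fall back to 0.999.
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)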
def train(args, loader, generator, discriminator, contrast_learner, augment,
          g_optim, d_optim, scaler, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = th.zeros(size=(1,), device=device)
    g_loss_val = 0
    path_loss = th.zeros(size=(1,), device=device)
    path_lengths = th.zeros(size=(1,), device=device)
    loss_dict = {}
    mse = th.nn.MSELoss()

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
        if contrast_learner is not None:
            cl_module = contrast_learner.module
    else:
        g_module = generator
        d_module = discriminator
        cl_module = contrast_learner

    sample_z = th.randn(args.n_sample, args.latent_size, device=device)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        requires_grad(generator, False)
        requires_grad(discriminator, True)
        discriminator.zero_grad()

        loss_dict["d"], loss_dict["real_score"], loss_dict["fake_score"] = 0, 0, 0
        loss_dict["cl_reg"], loss_dict["bc_reg"] = (
            th.tensor(0, device=device).float(),
            th.tensor(0, device=device).float(),
        )

        for _ in range(args.num_accumulate):
            real_img = next(loader)
            real_img = real_img.to(device)

            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob, device)
            fake_img, _ = generator(noise)

            if args.augment_D:
                fake_pred = discriminator(augment(fake_img))
                real_pred = discriminator(augment(real_img))
            else:
                fake_pred = discriminator(fake_img)
                real_pred = discriminator(real_img)

            # logistic loss
            real_loss = F.softplus(-real_pred)
            fake_loss = F.softplus(fake_pred)
            d_loss = real_loss.mean() + fake_loss.mean()

            loss_dict["d"] += d_loss.detach()
            loss_dict["real_score"] += real_pred.mean().detach()
            loss_dict["fake_score"] += fake_pred.mean().detach()

            if i > 10000 or i == 0:
                if args.contrastive > 0:
                    contrast_learner(fake_img.clone().detach(), accumulate=True)
                    contrast_learner(real_img, accumulate=True)
                    contrast_loss = cl_module.calculate_loss()
                    loss_dict["cl_reg"] += contrast_loss.detach()
                    d_loss += args.contrastive * contrast_loss

                if args.balanced_consistency > 0:
                    aug_fake_pred = discriminator(augment(fake_img.clone().detach()))
                    aug_real_pred = discriminator(augment(real_img))
                    consistency_loss = mse(real_pred, aug_real_pred) + mse(fake_pred, aug_fake_pred)
                    loss_dict["bc_reg"] += consistency_loss.detach()
                    d_loss += args.balanced_consistency * consistency_loss

            d_loss /= args.num_accumulate
            d_loss.backward()  # amp variant: scaler.scale(d_loss).backward()

        d_optim.step()  # amp variant: scaler.step(d_optim)

        # R1 regularization
        if args.r1 > 0 and i % args.d_reg_every == 0:
            discriminator.zero_grad()
            loss_dict["r1"] = 0

            for _ in range(args.num_accumulate):
                real_img = next(loader)
                real_img = real_img.to(device)
                real_img.requires_grad = True

                # augmenting the real images here fails:
                # derivative for grid_sampler_2d_backward is not implemented
                real_pred = discriminator(real_img)
                real_pred_sum = real_pred.sum()
                (grad_real,) = th.autograd.grad(
                    outputs=real_pred_sum, inputs=real_img, create_graph=True)
                # amp variant: differentiate scaler.scale(real_pred_sum) and
                # rescale grad_real by 1.0 / scaler.get_scale()

                r1_loss = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
                weighted_r1_loss = args.r1 / 2.0 * r1_loss * args.d_reg_every + 0 * real_pred[0]
                loss_dict["r1"] += r1_loss.detach()

                weighted_r1_loss /= args.num_accumulate
                weighted_r1_loss.backward()  # amp variant: scaler.scale(...).backward()

            d_optim.step()  # amp variant: scaler.step(d_optim)

        requires_grad(generator, True)
        requires_grad(discriminator, False)
        generator.zero_grad()
        loss_dict["g"] = 0

        for _ in range(args.num_accumulate):
            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob, device)
            fake_img, _ = generator(noise)

            if args.augment_G:
                fake_img = augment(fake_img)

            fake_pred = discriminator(fake_img)

            # non-saturating loss
            g_loss = F.softplus(-fake_pred).mean()
            loss_dict["g"] += g_loss.detach()

            g_loss /= args.num_accumulate
            g_loss.backward()

        g_optim.step()

        # path length regularization
        if args.path_regularize > 0 and i % args.g_reg_every == 0:
            generator.zero_grad()
            loss_dict["path"], loss_dict["path_length"] = 0, 0

            for _ in range(args.num_accumulate):
                path_batch_size = max(1, args.batch_size // args.path_batch_shrink)
                noise = make_noise(path_batch_size, args.latent_size, args.mixing_prob, device)
                fake_img, latents = generator(noise, return_latents=True)

                img_noise = th.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
                noisy_img_sum = (fake_img * img_noise).sum()
                (grad,) = th.autograd.grad(
                    outputs=noisy_img_sum, inputs=latents, create_graph=True)
                # amp variant: differentiate scaler.scale(noisy_img_sum) and
                # rescale grad by 1.0 / scaler.get_scale()

                path_lengths = th.sqrt(grad.pow(2).sum(2).mean(1))
                path_mean = mean_path_length + 0.01 * (path_lengths.mean() - mean_path_length)
                path_loss = (path_lengths - path_mean).pow(2).mean()
                mean_path_length = path_mean.detach()

                loss_dict["path"] += path_loss.detach()
                loss_dict["path_length"] += path_lengths.mean().detach()

                weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
                if args.path_batch_shrink:
                    weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

                weighted_path_loss /= args.num_accumulate
                weighted_path_loss.backward()

            g_optim.step()

        # scaler.update()  # amp variant
        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item() / args.num_accumulate
        g_loss_val = loss_reduced["g"].mean().item() / args.num_accumulate
        cl_reg_val = loss_reduced["cl_reg"].mean().item() / args.num_accumulate
        bc_reg_val = loss_reduced["bc_reg"].mean().item() / args.num_accumulate
        r1_val = loss_reduced["r1"].mean().item() / args.num_accumulate
        path_loss_val = loss_reduced["path"].mean().item() / args.num_accumulate
        real_score_val = loss_reduced["real_score"].mean().item() / args.num_accumulate
        fake_score_val = loss_reduced["fake_score"].mean().item() / args.num_accumulate
        path_length_val = loss_reduced["path_length"].mean().item() / args.num_accumulate

        if get_rank() == 0:
            log_dict = {
                "Generator": g_loss_val,
                "Discriminator": d_loss_val,
                "Real Score": real_score_val,
                "Fake Score": fake_score_val,
                "Contrastive": cl_reg_val,
                "Consistency": bc_reg_val,
            }

            if args.log_spec_norm:
                G_norms = []
                for name, spec_norm in g_module.named_buffers():
                    if "spectral_norm" in name:
                        G_norms.append(spec_norm.cpu().numpy())
                G_norms = np.array(G_norms)

                D_norms = []
                for name, spec_norm in d_module.named_buffers():
                    if "spectral_norm" in name:
                        D_norms.append(spec_norm.cpu().numpy())
                D_norms = np.array(D_norms)

                log_dict["Spectral Norms/G min spectral norm"] = np.log(G_norms).min()
                log_dict["Spectral Norms/G mean spectral norm"] = np.log(G_norms).mean()
                log_dict["Spectral Norms/G max spectral norm"] = np.log(G_norms).max()
                log_dict["Spectral Norms/D min spectral norm"] = np.log(D_norms).min()
                log_dict["Spectral Norms/D mean spectral norm"] = np.log(D_norms).mean()
                log_dict["Spectral Norms/D max spectral norm"] = np.log(D_norms).max()

            if args.r1 > 0 and i % args.d_reg_every == 0:
                log_dict["R1"] = r1_val

            if args.path_regularize > 0 and i % args.g_reg_every == 0:
                log_dict["Path Length Regularization"] = path_loss_val
                log_dict["Mean Path Length"] = mean_path_length
                log_dict["Path Length"] = path_length_val

            if i % args.img_every == 0:
                gc.collect()
                th.cuda.empty_cache()
                with th.no_grad():
                    g_ema.eval()
                    sample = []
                    for sub in range(0, len(sample_z), args.batch_size):
                        subsample, _ = g_ema([sample_z[sub:sub + args.batch_size]])
                        sample.append(subsample.cpu())
                    sample = th.cat(sample)
                    grid = utils.make_grid(sample, nrow=10, normalize=True, range=(-1, 1))
                log_dict["Generated Images EMA"] = [wandb.Image(grid, caption=f"Step {i}")]

            if i % args.eval_every == 0:
                start_time = time.time()

                pbar.set_description("Calculating FID...")
                fid_dict = validation.fid(g_ema, args.val_batch_size, args.fid_n_sample,
                                          args.fid_truncation, args.name)
                fid = fid_dict["FID"]
                density = fid_dict["Density"]
                coverage = fid_dict["Coverage"]

                pbar.set_description("Calculating PPL...")
                ppl = validation.ppl(
                    g_ema,
                    args.val_batch_size,
                    args.ppl_n_sample,
                    args.ppl_space,
                    args.ppl_crop,
                    args.latent_size,
                )

                pbar.set_description(
                    f"FID: {fid:.4f}; Density: {density:.4f}; Coverage: {coverage:.4f}; "
                    f"PPL: {ppl:.4f} in {time.time() - start_time:.1f}s")

                log_dict["Evaluation/FID"] = fid
                log_dict["Evaluation/Density"] = density
                log_dict["Evaluation/Coverage"] = coverage
                log_dict["Evaluation/PPL"] = ppl

            gc.collect()
            th.cuda.empty_cache()
            wandb.log(log_dict)

            if i % args.checkpoint_every == 0:
                # note: fid and ppl only exist once the first evaluation has run
                th.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        # "cl": cl_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"/home/hans/modelzoo/maua-sg2/{args.name}-"
                    f"{wandb.run.dir.split('/')[-1].split('-')[-1]}-"
                    f"{int(fid)}-{int(ppl)}-{str(i).zfill(6)}.pt",
                )
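# The loop above inlines the softplus losses and the explicit R1 gradient; the
# other loops call d_logistic_loss / g_nonsaturating_loss / d_r1_loss instead.
# Factoring that inline math out gives the following sketches.
import torch
import torch.nn.functional as F


def d_logistic_loss(real_pred, fake_pred):
    # softplus(-D(x)) + softplus(D(G(z))): the logistic discriminator loss.
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()


def g_nonsaturating_loss(fake_pred):
    return F.softplus(-fake_pred).mean()


def d_r1_loss(real_pred, real_img):
    # R1 penalty: squared gradient norm of D at the real samples.
    (grad_real,) = torch.autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()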
def train(args, loader, generator, discriminator, contrast_learner, g_optim, d_optim, g_ema):
    # note: device and augment are module-level globals in this script
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
        if contrast_learner is not None:
            cl_module = contrast_learner.module
    else:
        g_module = generator
        d_module = discriminator
        cl_module = contrast_learner

    loader = sample_data(loader)
    sample_z = th.randn(args.n_sample, args.latent_size, device=device)
    mse = th.nn.MSELoss()

    mean_path_length = 0
    ada_augment = th.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0
    fids = []

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        loss_dict = {
            "Generator": th.tensor(0, device=device).float(),
            "Discriminator": th.tensor(0, device=device).float(),
            "Real Score": th.tensor(0, device=device).float(),
            "Fake Score": th.tensor(0, device=device).float(),
            "Contrastive": th.tensor(0, device=device).float(),
            "Consistency": th.tensor(0, device=device).float(),
            "R1 Penalty": th.tensor(0, device=device).float(),
            "Path Length Regularization": th.tensor(0, device=device).float(),
            "Augment": th.tensor(0, device=device).float(),
            "Rt": th.tensor(0, device=device).float(),
        }

        requires_grad(generator, False)
        requires_grad(discriminator, True)
        discriminator.zero_grad()

        for _ in range(args.num_accumulate):
            real_img_og = next(loader).to(device)
            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob)
            fake_img_og, _ = generator(noise)

            if args.augment:
                fake_img, _ = augment(fake_img_og, ada_aug_p)
                real_img, _ = augment(real_img_og, ada_aug_p)
            else:
                fake_img = fake_img_og
                real_img = real_img_og

            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img)
            logistic_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["Discriminator"] += logistic_loss.detach()
            loss_dict["Real Score"] += real_pred.mean().detach()
            loss_dict["Fake Score"] += fake_pred.mean().detach()

            d_loss = logistic_loss

            if args.contrastive > 0:
                contrast_learner(fake_img_og, fake_img, accumulate=True)
                contrast_learner(real_img_og, real_img, accumulate=True)
                contrast_loss = cl_module.calculate_loss()
                loss_dict["Contrastive"] += contrast_loss.detach()
                d_loss += args.contrastive * contrast_loss

            if args.balanced_consistency > 0:
                consistency_loss = mse(real_pred, discriminator(real_img_og)) + \
                    mse(fake_pred, discriminator(fake_img_og))
                loss_dict["Consistency"] += consistency_loss.detach()
                d_loss += args.balanced_consistency * consistency_loss

            d_loss /= args.num_accumulate
            d_loss.backward()

        d_optim.step()

        if args.r1 > 0 and i % args.d_reg_every == 0:
            discriminator.zero_grad()
            for _ in range(args.num_accumulate):
                real_img = next(loader).to(device)
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss = d_r1_penalty(real_img, real_pred, args)
                loss_dict["R1 Penalty"] += r1_loss.detach().squeeze()
                r1_loss = args.r1 * args.d_reg_every * r1_loss / args.num_accumulate
                r1_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_augment += th.tensor(
                (th.sign(real_pred).sum().item(), real_pred.shape[0]), device=device)
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()
                r_t_stat = pred_signs / n_pred
                loss_dict["Rt"] = th.tensor(r_t_stat, device=device).float()
                sign = 1 if r_t_stat > args.ada_target else -1
                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

            loss_dict["Augment"] = th.tensor(ada_aug_p, device=device).float()

        requires_grad(generator, True)
        requires_grad(discriminator, False)
        generator.zero_grad()

        for _ in range(args.num_accumulate):
            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob)
            fake_img, _ = generator(noise)
            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)
            fake_pred = discriminator(fake_img)
            g_loss = g_non_saturating_loss(fake_pred)
            loss_dict["Generator"] += g_loss.detach()
            g_loss /= args.num_accumulate
            g_loss.backward()

        g_optim.step()

        if args.path_regularize > 0 and i % args.g_reg_every == 0:
            generator.zero_grad()
            for _ in range(args.num_accumulate):
                path_loss, mean_path_length = g_path_length_regularization(
                    generator, mean_path_length, args)
                loss_dict["Path Length Regularization"] += path_loss.detach()
                path_loss = args.path_regularize * args.g_reg_every * path_loss / args.num_accumulate
                path_loss.backward()
            g_optim.step()

        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)
        log_dict = {
            k: v.mean().item() / args.num_accumulate
            for k, v in loss_reduced.items() if v != 0
        }

        if get_rank() == 0:
            if args.log_spec_norm:
                G_norms = []
                for name, spec_norm in g_module.named_buffers():
                    if "spectral_norm" in name:
                        G_norms.append(spec_norm.cpu().numpy())
                G_norms = np.array(G_norms)

                D_norms = []
                for name, spec_norm in d_module.named_buffers():
                    if "spectral_norm" in name:
                        D_norms.append(spec_norm.cpu().numpy())
                D_norms = np.array(D_norms)

                log_dict["Spectral Norms/G min spectral norm"] = np.log(G_norms).min()
                log_dict["Spectral Norms/G mean spectral norm"] = np.log(G_norms).mean()
                log_dict["Spectral Norms/G max spectral norm"] = np.log(G_norms).max()
                log_dict["Spectral Norms/D min spectral norm"] = np.log(D_norms).min()
                log_dict["Spectral Norms/D mean spectral norm"] = np.log(D_norms).mean()
                log_dict["Spectral Norms/D max spectral norm"] = np.log(D_norms).max()

            if i % args.img_every == 0:
                gc.collect()
                th.cuda.empty_cache()
                with th.no_grad():
                    g_ema.eval()
                    sample = []
                    for sub in range(0, len(sample_z), args.batch_size):
                        subsample, _ = g_ema([sample_z[sub:sub + args.batch_size]])
                        sample.append(subsample.cpu())
                    sample = th.cat(sample)
                    grid = utils.make_grid(sample, nrow=10, normalize=True, range=(-1, 1))
                log_dict["Generated Images EMA"] = [wandb.Image(grid, caption=f"Step {i}")]

            if i % args.eval_every == 0:
                fid_dict = validation.fid(g_ema, args.val_batch_size, args.fid_n_sample,
                                          args.fid_truncation, args.name)
                fid = fid_dict["FID"]
                fids.append(fid)
                density = fid_dict["Density"]
                coverage = fid_dict["Coverage"]
                ppl = validation.ppl(
                    g_ema,
                    args.val_batch_size,
                    args.ppl_n_sample,
                    args.ppl_space,
                    args.ppl_crop,
                    args.latent_size,
                )
                log_dict["Evaluation/FID"] = fid
                log_dict["Sweep/FID_smooth"] = gaussian_filter(np.array(fids), [5])[-1]
                log_dict["Evaluation/Density"] = density
                log_dict["Evaluation/Coverage"] = coverage
                log_dict["Evaluation/PPL"] = ppl

            gc.collect()
            th.cuda.empty_cache()
            wandb.log(log_dict)

            # note: fid, ppl, density, and coverage require the first evaluation to have run
            description = (
                f"FID: {fid:.4f} PPL: {ppl:.4f} Dens: {density:.4f} Cov: {coverage:.4f} "
                + f"G: {log_dict['Generator']:.4f} D: {log_dict['Discriminator']:.4f}")
            if "Augment" in log_dict:
                description += f" Aug: {log_dict['Augment']:.4f}"  # Rt: {log_dict['Rt']:.4f}"
            if "R1 Penalty" in log_dict:
                description += f" R1: {log_dict['R1 Penalty']:.4f}"
            if "Path Length Regularization" in log_dict:
                description += f" Path: {log_dict['Path Length Regularization']:.4f}"
            pbar.set_description(description)

            if i % args.checkpoint_every == 0:
                check_name = "-".join([
                    args.name,
                    args.runname,
                    wandb.run.dir.split("/")[-1].split("-")[-1],
                    str(int(fid)),   # str() required: join() only accepts strings
                    str(args.size),
                    str(i).zfill(6),
                ])
                th.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        # "cl": cl_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"/home/hans/modelzoo/maua-sg2/{check_name}.pt",
                )
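# g_path_regularize is called by several loops below; the contrastive loop
# above spells out the same computation inline. A sketch matching that inline
# math (the conditional variant further down additionally returns a NaN flag).
import math
import torch


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
    (grad,) = torch.autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    # exponential moving average of the path-length target
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths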
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()

            mean_path_length_avg = reduce_sum(mean_path_length).item() / get_world_size()

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}"
            ))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
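# mixing_noise drives style mixing in the loops above and below: with
# probability args.mixing it returns two latents, otherwise one. A sketch of
# the usual implementation (an assumption; note that one fork further down
# treats its return value as a tensor instead of a list).
import random
import torch


def mixing_noise(batch, latent_dim, prob, device):
    if prob > 0 and random.random() < prob:
        return list(torch.randn(2, batch, latent_dim, device=device).unbind(0))
    return [torch.randn(batch, latent_dim, device=device)]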
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device, save_dir):
    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_augment_data = torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device)
            ada_augment += reduce_sum(ada_augment_data)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()
                r_t_stat = pred_signs / n_pred
                sign = 1 if r_t_stat > args.ada_target else -1
                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()

            mean_path_length_avg = reduce_sum(mean_path_length).item() / get_world_size()

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 1000 == 0:
                # save some samples
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        save_dir + f"/samples/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 2000 == 0:
                # save the model
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    save_dir + f"/checkpoints/{str(i).zfill(6)}.pt",
                )
def train(args, loader, generator, encoder, discriminator, vggnet,
          g_optim, e_optim, d_optim, g_ema, e_ema, device):
    kwargs_d = {'detach_aux': False}

    if args.dataset == 'imagefolder':
        loader = sample_data2(loader)
    else:
        loader = sample_data(loader)

    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    else:
        inception = real_mean = real_cov = None
    mean_latent = None

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(encoder, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        latent_real, _ = encoder(real_img)
        rec_img, _ = generator([latent_real], input_is_latent=True)

        real_pred = discriminator(real_img)
        fake_pred = discriminator(fake_img)
        rec_pred = discriminator(rec_img)

        d_loss_real = F.softplus(-real_pred).mean()
        d_loss_fake = F.softplus(fake_pred).mean()
        d_loss_rec = F.softplus(rec_pred).mean()

        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()
        loss_dict["rec_score"] = rec_pred.mean()

        d_loss = d_loss_real + d_loss_fake + d_loss_rec
        loss_dict["d"] = d_loss

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()

        loss_dict["r1"] = r1_loss

        # # Train Encoder and Generator jointly (alternative kept from the source)
        # requires_grad(generator, True)
        # requires_grad(encoder, True)
        # requires_grad(discriminator, False)
        # pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        # noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        # fake_img, _ = generator(noise)
        # latent_real, _ = encoder(real_img)
        # rec_img, _ = generator([latent_real], input_is_latent=True)
        # fake_pred = discriminator(fake_img)
        # rec_pred = discriminator(rec_img)
        # g_loss_fake = g_nonsaturating_loss(fake_pred)
        # g_loss_rec = g_nonsaturating_loss(rec_pred)
        # adv_loss = g_loss_fake + g_loss_rec
        # if args.lambda_pix > 0:
        #     if args.pix_loss == 'l2':
        #         pix_loss = torch.mean((rec_img - real_img) ** 2)
        #     else:
        #         pix_loss = F.l1_loss(rec_img, real_img)
        # if args.lambda_vgg > 0:
        #     vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img)) ** 2)
        # e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        # loss_dict["e"] = e_loss
        # encoder.zero_grad()
        # generator.zero_grad()
        # e_loss.backward()
        # e_optim.step()
        # g_optim.step()

        # Train Encoder
        requires_grad(generator, False)
        requires_grad(encoder, True)
        requires_grad(discriminator, False)

        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        latent_real, _ = encoder(real_img)
        rec_img, _ = generator([latent_real], input_is_latent=True)
        rec_pred = discriminator(rec_img)
        g_loss_rec = g_nonsaturating_loss(rec_pred)
        adv_loss = g_loss_rec

        if args.lambda_pix > 0:
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((rec_img - real_img) ** 2)
            else:
                pix_loss = F.l1_loss(rec_img, real_img)
        if args.lambda_vgg > 0:
            vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img)) ** 2)

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        loss_dict["e"] = e_loss

        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()

        # Train Generator
        requires_grad(generator, True)
        requires_grad(encoder, False)
        requires_grad(discriminator, False)

        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        latent_real, _ = encoder(real_img)
        rec_img, _ = generator([latent_real], input_is_latent=True)
        fake_pred = discriminator(fake_img)
        rec_pred = discriminator(rec_img)
        g_loss_fake = g_nonsaturating_loss(fake_pred)
        g_loss_rec = g_nonsaturating_loss(rec_pred)
        adv_loss = g_loss_fake + g_loss_rec

        if args.lambda_pix > 0:
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((rec_img - real_img) ** 2)
            else:
                pix_loss = F.l1_loss(rec_img, real_img)
        if args.lambda_vgg > 0:
            vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img)) ** 2)

        g_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()

            mean_path_length_avg = reduce_sum(mean_path_length).item() / get_world_size()

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        with torch.no_grad():
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.pix_loss == 'l2':
                pix_loss = torch.mean((rec_img - real_img) ** 2)
            else:
                pix_loss = F.l1_loss(rec_img, real_img)
            vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img)) ** 2)

        pix_loss_val = pix_loss.mean().item()
        vgg_loss_val = vgg_loss.mean().item()

        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = generator([latent_x], input_is_latent=True, return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x) ** 2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                            f"ref: {sample_pix_loss.item()};\n")

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent,
                        64, args.n_sample_fid, args.device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                print("fid:", fid)
                with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                    f.write(f"{i:07d}: fid: {float(fid):.4f}\n")

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    # Fixed fake samples
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

                    # Reconstruction samples
                    e_ema.eval()
                    nrow = int(args.n_sample ** 0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_ema(sample_x)
                    fake_img, _ = g_ema([latent_real], input_is_latent=True, return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    start_iter = args.start_iter // get_world_size() // args.batch
    pbar = range(args.iter // get_world_size() // args.batch)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    seg_loss = torch.tensor(0.0, device=device)
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg, seg_loss_val, shift_loss_val = 0, 0, 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))

    sample_condition_img, sample_conditions, condition_img_color = random_condition_img(args.n_sample)

    if get_rank() == 0:
        os.makedirs('sample', exist_ok=True)
        os.makedirs(f'sample/{args.name}', exist_ok=True)
        os.makedirs(f'ckpts/{args.name}', exist_ok=True)
        if args.with_tensorboard:
            os.makedirs(f'tensorboard/{args.name}', exist_ok=True)
            writer = SummaryWriter(f'tensorboard/{args.name}')

    for idx in pbar:
        i = idx + start_iter
        if i > args.iter:
            print('Done!')
            break

        real_img, condition_img = next(loader)
        real_img = real_img.to(device)
        if args.condition_path is not None:
            condition_img = condition_img.to(device)
        else:
            condition_img = None

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _, _, _ = generator(noise, condition_img=condition_img)

        if args.with_rgbs:
            condition_img_encoder = F.interpolate(condition_img, size=args.resolution, mode='nearest')
            real_img = torch.cat((real_img, condition_img_encoder), dim=1)

        fake_pred, _ = discriminator(fake_img)
        real_pred, real_pred_feat = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss
        loss_dict['real_score'] = real_pred.mean()
        loss_dict['fake_score'] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred, _ = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()

        loss_dict['r1'] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _, _, parsing_feature = generator(noise, condition_img=condition_img)
        fake_pred, fake_pred_feat = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict['g'] = g_loss
        # seg/shift losses are not computed in this variant;
        # both log the zero-valued placeholder tensor
        loss_dict['seg'] = seg_loss
        loss_dict['shift_loss'] = seg_loss

        loss = g_loss
        generator.zero_grad()
        loss.backward()
        g_optim.step()

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            if args.condition_path is not None:
                condition_img = condition_img[range(path_batch_size)]
                condition_img.requires_grad = True
            fake_img, latents, _, _ = generator(noise, return_latents=True, condition_img=condition_img)

            # this fork's g_path_regularize also returns a NaN flag
            path_loss, mean_path_length, path_lengths, isNaN = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.g_reg_every * args.path_regularize * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()

            if not isNaN:
                g_optim.step()

            mean_path_length_avg = reduce_sum(mean_path_length).item() / get_world_size()

        loss_dict['path'] = path_loss
        loss_dict['path_length'] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced['d'].mean().item()
        g_loss_val = loss_reduced['g'].mean().item()
        r1_val = loss_reduced['r1'].mean().item()
        path_length_val = loss_reduced['path_length'].mean().item()
        if args.condition_path is not None and i % args.g_reg_every == 0:
            seg_loss_val = loss_reduced['seg'].mean().item()
            shift_loss_val = loss_reduced['shift_loss'].mean().item()

        if get_rank() == 0:
            pbar.set_description(f'mean path: {mean_path_length_avg:.4f}')

            if args.with_tensorboard:
                writer.add_scalar('Loss/Generator', g_loss_val, i)
                writer.add_scalar('Loss/Discriminator', d_loss_val, i)
                writer.add_scalar('Loss/R1', r1_val, i)
                writer.add_scalar('Loss/Path Length', path_length_val, i)
                writer.add_scalar('Loss/mean path', mean_path_length_avg, i)
                if args.condition_path is not None:
                    writer.add_scalar('Loss/seg_img', seg_loss_val, i)
                    writer.add_scalar('Loss/shift_loss', shift_loss_val, i)

            steps = get_world_size() * args.batch * (1 + i)
            if steps % 100000 < get_world_size() * args.batch or (
                    steps < 1000 and steps % 500 == get_world_size() * args.batch):
                with torch.no_grad():
                    g_ema.eval()
                    samples, featuresMaps, parsing_features = [], [], []
                    small_batch = args.n_sample // args.batch
                    if args.n_sample % args.batch != 0:
                        small_batch += 1

                    # only the condition changes along each row
                    rows = int(args.n_sample ** 0.5)
                    if args.condition_path is not None:
                        # this fork's mixing_noise returns a tensor, not a list
                        sample_z = mixing_noise(rows, args.latent, args.mixing, device)
                        sample_z = sample_z.unsqueeze(1).repeat(1, rows, 1, 1).view(
                            args.n_sample, sample_z.shape[1], sample_z.shape[2])
                    else:
                        sample_z = mixing_noise(args.n_sample, args.latent, args.mixing, device)

                    for k in range(small_batch):
                        start, end = k * args.batch, (k + 1) * args.batch
                        if k == small_batch - 1:
                            end = sample_z.shape[0]
                        if args.condition_path is not None:
                            sample_condition_img_sub = sample_condition_img[start:end]
                            sample_condition_img_sub = random_affine(
                                sample_condition_img_sub.clone(), Scale=0.0).to(device)
                        else:
                            sample_condition_img_sub = None
                        sample, _, _, _ = g_ema(
                            sample_z[start:end], condition_img=sample_condition_img_sub)
                        samples.append(sample.cpu().detach())

                    samples = torch.cat(samples, dim=0)
                    nrow = int(args.n_sample ** 0.5)
                    c, h, w = samples.shape[-3:]
                    samples = samples.reshape(nrow, nrow, c, h, w).transpose(1, 0).reshape(-1, c, h, w)
                    utils.save_image(
                        samples,
                        f'sample/{args.name}/{str(steps).zfill(6)}.png',
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

                    if i == 0:
                        c, h, w = condition_img_color.shape[-3:]
                        condition_img_color = condition_img_color.reshape(
                            nrow, nrow, c, h, w).transpose(1, 0).reshape(-1, c, h, w)
                        utils.save_image(
                            condition_img_color,
                            f'sample/{args.name}/seg_vis.png',
                            nrow=nrow,
                            normalize=True,
                            range=(-1, 1),
                        )

            if (steps + get_world_size() * args.batch) % 100000 < get_world_size() * args.batch \
                    and steps != args.start_iter:
                torch.save(
                    {
                        'g': g_module.state_dict(),
                        'd': d_module.state_dict(),
                        'g_ema': g_ema.state_dict(),
                        # 'g_optim': g_optim.state_dict(),
                        # 'd_optim': d_optim.state_dict(),
                    },
                    f'ckpts/{args.name}/{str(steps).zfill(6)}.pt',
                )
def train(args, loader, encoder, generator, discriminator, vggnet, pwcnet,
          e_optim, d_optim, e_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {
        "d": torch.tensor(0., device=device),
        "real_score": torch.tensor(0., device=device),
        "fake_score": torch.tensor(0., device=device),
        "r1_d": torch.tensor(0., device=device),
        "r1_e": torch.tensor(0., device=device),
        "rec": torch.tensor(0., device=device),
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

    sample_x = accumulate_batches(loader, args.n_sample).to(device)

    requires_grad(generator, False)  # always False
    generator.eval()  # generator should be the EMA copy and stay in eval mode
    # if args.no_ema or e_ema is None:
    #     e_ema = encoder

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        real_img1, real_img2 = next(loader)
        real_img1 = real_img1.to(device)
        real_img2 = real_img2.to(device)

        # Train Encoder
        if args.toggle_grads:
            requires_grad(encoder, True)
            requires_grad(discriminator, False)

        pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0., device=device)

        latent_real1 = encoder(real_img1)
        fake_img1, _ = generator([latent_real1], input_is_latent=True, return_latents=False)
        latent_real2 = encoder(real_img2)
        fake_img2, _ = generator([latent_real2], input_is_latent=True, return_latents=False)

        if args.lambda_adv > 0:
            # if args.augment:
            #     fake_img_aug1, _ = augment(fake_img1, ada_aug_p)
            # else:
            #     fake_img_aug1 = fake_img1
            fake_img_pair = torch.cat((fake_img1, fake_img2), 1)
            fake_pred = discriminator(fake_img_pair)
            adv_loss = g_nonsaturating_loss(fake_pred)

        if args.lambda_pix > 0:
            pix_loss = torch.mean((real_img1 - fake_img1) ** 2)
            if args.reconstruct_pair:
                pix_loss += torch.mean((real_img2 - fake_img2) ** 2)

        if args.lambda_vgg > 0:
            real_feat = vggnet(real_img1)
            fake_feat = vggnet(fake_img1)
            vgg_loss = torch.mean((real_feat - fake_feat) ** 2)
            if args.reconstruct_pair:
                real_feat = vggnet(real_img2)
                fake_feat = vggnet(fake_img2)
                vgg_loss += torch.mean((real_feat - fake_feat) ** 2)

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss

        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()

        # if args.train_on_fake:
        #     e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0
        #     if e_regularize and args.lambda_rec > 0:
        #         noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        #         fake_img, latent_fake = generator(noise, input_is_latent=False, return_latents=True)
        #         latent_pred = encoder(fake_img)
        #         if latent_pred.ndim < 3:
        #             latent_pred = latent_pred.unsqueeze(1).repeat(1, latent_fake.size(1), 1)
        #         rec_loss = torch.mean((latent_fake - latent_pred) ** 2)
        #         encoder.zero_grad()
        #         (rec_loss * args.lambda_rec).backward()
        #         e_optim.step()
        #         loss_dict["rec"] = rec_loss

        e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0
        if e_regularize:
            # why not regularize on augmented real?
            real_img_pair = torch.cat((real_img1, real_img2), 1)
            real_img_pair.requires_grad = True
            real_pred = encoder(real_img_pair)
            r1_loss_e = d_r1_loss(real_pred, real_img_pair)

            encoder.zero_grad()
            (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward()
            e_optim.step()

            loss_dict["r1_e"] = r1_loss_e

        if not args.no_ema and e_ema is not None:
            accumulate(e_ema, e_module, accum)

        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(discriminator, True)

        if not args.no_update_discriminator and args.lambda_adv > 0:
            latent_real1 = encoder(real_img1)
            fake_img1, _ = generator([latent_real1], input_is_latent=True, return_latents=False)
            latent_real2 = encoder(real_img2)
            fake_img2, _ = generator([latent_real2], input_is_latent=True, return_latents=False)

            # if args.augment:
            #     real_img_aug, _ = augment(real_img, ada_aug_p)
            #     fake_img_aug, _ = augment(fake_img, ada_aug_p)
            # else:
            #     real_img_aug = real_img
            #     fake_img_aug = fake_img

            fake_img_pair = torch.cat((fake_img1, fake_img2), 1)
            real_img_pair = torch.cat((real_img1, real_img2), 1)

            fake_pred = discriminator(fake_img_pair)
            real_pred = discriminator(real_img_pair)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

            # if args.augment and args.augment_p == 0:
            #     ada_aug_p = ada_augment.tune(real_pred)
            #     r_t_stat = ada_augment.r_t_stat

            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                # why not regularize on augmented real?
                real_img_pair = torch.cat((real_img1, real_img2), 1)
                real_img_pair.requires_grad = True
                real_pred = discriminator(real_img_pair)
                r1_loss_d = d_r1_loss(real_pred, real_img_pair)

                discriminator.zero_grad()
                # why 0 * real_pred? see https://github.com/rosinality/stylegan2-pytorch/issues/76
                (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward()
                d_optim.step()

                loss_dict["r1_d"] = r1_loss_d

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        r1_e_val = loss_reduced["r1_e"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img1.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img1.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x = e_ema(sample_x)
                    fake_x, _ = generator([latent_x], input_is_latent=True, return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x) ** 2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                            f"ref: {sample_pix_loss.item()};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Encoder": e_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1 D": r1_d_val,
                    "R1 E": r1_e_val,
                    "Pix Loss": pix_loss_val,
                    "VGG Loss": vgg_loss_val,
                    "Adv Loss": adv_loss_val,
                    "Rec Loss": rec_loss_val,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = encoder if args.no_ema else e_ema
                    e_eval.eval()
                    nrow = int(args.n_sample ** 0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real = e_eval(sample_x)
                    fake_img, _ = generator([latent_real], input_is_latent=True, return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    e_eval.train()

            if i % args.save_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
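# The pair-encoder loop above builds its fixed reconstruction panel with
# accumulate_batches(loader, args.n_sample). A sketch, assuming the loader is
# the infinite sample_data stream and may yield (img1, img2) pairs as above.
import torch


def accumulate_batches(loader, n_sample):
    batches, count = [], 0
    while count < n_sample:
        batch = next(loader)
        if isinstance(batch, (list, tuple)):  # paired loaders: keep the first view
            batch = batch[0]
        batches.append(batch)
        count += batch.shape[0]
    return torch.cat(batches, 0)[:n_sample]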
def train(args, generator, discriminator_photo, discriminator_cari, discriminator_feat_p, discriminator_feat_c, g_optim, d_optim_p, d_optim_c, d_optim_fp, d_optim_fc, g_ema, p_cls, c_cls, id_net, device): pbar = range(args.iter) if get_rank() == 0: if not os.path.exists(f'checkpoint/{args.name}'): os.makedirs(f'checkpoint/{args.name}') if not os.path.exists(f'sample/{args.name}'): os.makedirs(f'sample/{args.name}') pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) d_loss_val = 0 d_feat_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) gan_loss_val = 0 gan_feat_loss_val = 0 idt_loss_val = 0 attr_loss_val = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module_p = discriminator_photo.module d_module_c = discriminator_cari.module d_module_feat_p = discriminator_feat_p.module d_module_feat_c = discriminator_feat_c.module else: g_module = generator d_module_p = discriminator_photo d_module_c = discriminator_cari d_module_feat_p = discriminator_feat_p d_module_feat_c = discriminator_feat_c accum = 0.5**(32 / (10 * 1000)) ada_augment = torch.tensor([0.0, 0.0], device=device) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 ada_aug_step = args.ada_target / args.ada_length r_t_stat = 0 sample_z = torch.randn(args.n_sample, args.latent, device=device) criterion_BCE = nn.BCELoss() for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break ''' Discriminator for feat cari ''' requires_grad(generator, False) requires_grad(discriminator_feat_c, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) noise_fine = mixing_noise(args.batch, args.latent, args.mixing, device) ret = generator(noise, noise_fine, truncation_latent=mean_latent, mode='p2c') fake_feat = ret['co'] real_feat = ret['gt_co'].detach() fake_pred = discriminator_feat_c(fake_feat) real_pred = discriminator_feat_c(real_feat) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["d_feat_c"] = d_loss loss_dict["real_score_feat_c"] = real_pred.mean() loss_dict["fake_score_feat_c"] = fake_pred.mean() discriminator_feat_c.zero_grad() d_loss.backward() d_optim_fc.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_feat.requires_grad = True real_pred = discriminator_feat_c(real_feat) r1_loss = d_r1_loss(real_pred, real_feat) discriminator_feat_c.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim_fc.step() ''' Discriminator for feat photo ''' requires_grad(generator, False) requires_grad(discriminator_feat_p, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) noise_fine = mixing_noise(args.batch, args.latent, args.mixing, device) ret = generator(noise, noise_fine, truncation_latent=mean_latent, mode='c2p') fake_feat = ret['po'] real_feat = ret['gt_po'].detach() fake_pred = discriminator_feat_p(fake_feat) real_pred = discriminator_feat_p(real_feat) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["d_feat_p"] = d_loss loss_dict["real_score_feat_p"] = real_pred.mean() loss_dict["fake_score_feat_p"] = fake_pred.mean() discriminator_feat_p.zero_grad() d_loss.backward() d_optim_fp.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_feat.requires_grad = True real_pred = discriminator_feat_p(real_feat) r1_loss = d_r1_loss(real_pred, real_feat) discriminator_feat_p.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim_fp.step() ''' Discriminator for cari ''' requires_grad(generator, False) 
requires_grad(discriminator_cari, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) noise_fine = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img = generator(noise, noise_fine, truncation_latent=mean_latent, mode='p2c')['result'] real_img = generator(noise, noise_fine, truncation_latent=mean_latent, mode='c_gt') real_img = real_img.detach() fake_pred = discriminator_cari(fake_img) real_pred = discriminator_cari(real_img) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["d_c"] = d_loss loss_dict["real_score_c"] = real_pred.mean() loss_dict["fake_score_c"] = fake_pred.mean() discriminator_cari.zero_grad() d_loss.backward() d_optim_c.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator_cari(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator_cari.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim_c.step() ''' Discriminator for photo ''' requires_grad(generator, False) requires_grad(discriminator_photo, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) noise_fine = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img = generator(noise, noise_fine, truncation_latent=mean_latent, mode='c2p')['result'] real_img = generator(noise, noise_fine, truncation_latent=mean_latent, mode='p_gt') real_img = real_img.detach() fake_pred = discriminator_photo(fake_img) real_pred = discriminator_photo(real_img) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["d_p"] = d_loss loss_dict["real_score_p"] = real_pred.mean() loss_dict["fake_score_p"] = fake_pred.mean() discriminator_photo.zero_grad() d_loss.backward() d_optim_p.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator_photo(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator_photo.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim_p.step() loss_dict["r1"] = r1_loss requires_grad(generator.module.deformation_blocks_CP, True) requires_grad(generator.module.deformation_blocks_PC, True) requires_grad(discriminator_photo, False) requires_grad(discriminator_cari, False) requires_grad(discriminator_feat_p, False) requires_grad(discriminator_feat_c, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) ret_p2c = generator(noise, truncation_latent=mean_latent, mode='p2c') ret_p2c_recon = generator(noise, truncation_latent=mean_latent, mode='p2c_recon') ret_c2p = generator(noise, truncation_latent=mean_latent, mode='c2p') cyc_loss_p2c = 0 cyc_loss_c2p = 0 for lv in range(len(ret_p2c['po'])): cyc_loss_p2c += F.mse_loss(ret_p2c['po'][lv], ret_p2c['ro'][lv]) cyc_loss_c2p += F.mse_loss(ret_c2p['co'][lv], ret_c2p['ro'][lv]) cyc_loss = (cyc_loss_p2c + cyc_loss_c2p) / 2 attr_p_p2c = p_cls(ret_p2c['org']).detach() attr_c_p2c = c_cls(ret_p2c['result']) attr_loss_p2c = criterion_BCE(attr_c_p2c, attr_p_p2c) attr_c_c2p = c_cls(ret_c2p['org']).detach() attr_p_c2p = p_cls(ret_c2p['result']) attr_loss_c2p = criterion_BCE(attr_p_c2p, attr_c_c2p) attr_loss = (attr_loss_p2c + attr_loss_c2p) / 2 fake_pred_photo = discriminator_photo(ret_c2p['result']) fake_pred_cari = discriminator_cari(ret_p2c['result']) gan_loss_p2c = g_nonsaturating_loss(fake_pred_cari) gan_loss_c2p = g_nonsaturating_loss(fake_pred_photo) gan_loss = (gan_loss_p2c + gan_loss_c2p) / 2 fake_feat_photo = discriminator_feat_p(ret_c2p['po']) fake_feat_cari = 
discriminator_feat_c(ret_p2c['co']) gan_feat_loss_p2c = g_nonsaturating_loss(fake_feat_cari) gan_feat_loss_c2p = g_nonsaturating_loss(fake_feat_photo) gan_feat_loss = (gan_feat_loss_p2c + gan_feat_loss_c2p) / 2 cyc_id_loss = F.mse_loss(id_net(ret_p2c_recon['result']), id_net(ret_p2c_recon['org']).detach()) g_loss = 10 * gan_loss + 10 * cyc_loss + gan_feat_loss + gan_feat_loss + 10 * attr_loss + 10000 * cyc_id_loss loss_dict["gan"] = gan_loss loss_dict["cyc"] = cyc_loss loss_dict["attr"] = attr_loss loss_dict["feat"] = gan_feat_loss loss_dict["idt"] = cyc_id_loss generator.zero_grad() g_loss.backward() g_optim.step() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_p_val = loss_reduced["d_p"].mean().item() d_loss_c_val = loss_reduced["d_c"].mean().item() gan_loss_val = loss_reduced["gan"].mean().item() cyc_loss_val = loss_reduced["cyc"].mean().item() feat_loss_val = loss_reduced["feat"].mean().item() attr_loss_val = loss_reduced["attr"].mean().item() idt_loss_val = loss_reduced["idt"].mean().item() r1_val = loss_reduced["r1"].mean().item() real_score_p_val = loss_reduced["real_score_p"].mean().item() fake_score_p_val = loss_reduced["fake_score_p"].mean().item() real_score_c_val = loss_reduced["real_score_c"].mean().item() fake_score_c_val = loss_reduced["fake_score_c"].mean().item() if get_rank() == 0: pbar.set_description(( f"d_p: {d_loss_p_val:.4f}; d_c: {d_loss_c_val:.4f}; g: {gan_loss_val:.4f}, {cyc_loss_val:.4f}, {feat_loss_val:.4f}, {attr_loss_val:.4f}, {idt_loss_val:.4f}; r1: {r1_val:.4f}; " f"augment: {ada_aug_p:.4f}")) if wandb and args.wandb: wandb.log({ "Generator_gan": gan_loss_val, "Generator_idt": idt_loss_val, "Generator_attr": attr_loss_val, "Discriminator_p": d_loss_p_val, "Discriminator_c": d_loss_c_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Real Score_p": real_score_p_val, "Fake Score_p": fake_score_p_val, "Real Score_c": real_score_c_val, "Fake Score_c": fake_score_c_val, }) if i % 100 == 0: with torch.no_grad(): g_ema.eval() ret = g_ema([sample_z], truncation_latent=mean_latent, mode='p2c') utils.save_image( ret['result'], f"sample/{args.name}/p2c_exg_{str(i).zfill(6)}.png", nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1), ) utils.save_image( ret['org'], f"sample/{args.name}/p2c_gt_{str(i).zfill(6)}.png", nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1), ) ret = g_ema([sample_z], truncation_latent=mean_latent, mode='c2p') utils.save_image( ret['result'], f"sample/{args.name}/c2p_exg_{str(i).zfill(6)}.png", nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1), ) utils.save_image( ret['org'], f"sample/{args.name}/c2p_gt_{str(i).zfill(6)}.png", nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1), ) if i % 1000 == 0: torch.save( { "g": g_module.state_dict(), "d_p": d_module_p.state_dict(), "d_c": d_module_c.state_dict(), "d_feat_p": d_module_feat_p.state_dict(), "d_feat_c": d_module_feat_c.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim_p": d_optim_p.state_dict(), "d_optim_c": d_optim_c.state_dict(), "d_optim_fp": d_optim_fp.state_dict(), "d_optim_fc": d_optim_fc.state_dict(), "args": args, "ada_aug_p": ada_aug_p, }, f"checkpoint/{args.name}/{str(i).zfill(6)}.pt", )
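# The adversarial losses used throughout these loops (d_logistic_loss,
# g_nonsaturating_loss) follow rosinality's stylegan2-pytorch. For reference,
# a sketch of the two helpers, assuming F = torch.nn.functional as elsewhere
# in this file:

def d_logistic_loss_sketch(real_pred, fake_pred):
    # non-saturating logistic GAN loss for the discriminator:
    # softplus(-x) = -log(sigmoid(x)), softplus(x) = -log(1 - sigmoid(x))
    real_loss = F.softplus(-real_pred)
    fake_loss = F.softplus(fake_pred)
    return real_loss.mean() + fake_loss.mean()

def g_nonsaturating_loss_sketch(fake_pred):
    # the generator maximizes log(sigmoid(D(G(z))))
    return F.softplus(-fake_pred).mean()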
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): loader = sample_data(loader, datatype="imagefolder") # inception related: if (get_rank() == 0): from calc_inception import load_patched_inception_v3 inception = load_patched_inception_v3().to(device) inception.eval() if args.eval_every > 0: with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n") if args.log_every > 0: with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n") pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator accum = 0.5 ** (32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device) sample_z = torch.randn(args.n_sample, args.latent, device=device) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break real_img = next(loader) real_img = real_img.to(device) requires_grad(generator, False) requires_grad(discriminator, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_pred = discriminator(fake_img) real_pred = discriminator(real_img_aug) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["d"] = d_loss loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) else: real_img_aug = real_img real_pred = discriminator(real_img_aug) r1_loss = d_r1_loss(real_pred, real_img,args) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict["r1"] = r1_loss requires_grad(generator, True) requires_grad(discriminator, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: fake_img, _ = augment(fake_img, ada_aug_p) fake_pred = discriminator(fake_img) g_loss = g_nonsaturating_loss(fake_pred) loss_dict["g"] = g_loss generator.zero_grad() g_loss.backward() g_optim.step() #g_reg starts g_regularize = False if args.useG_reg==True: # print("I entered g_reg") g_regularize = i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length ) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if 
args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = ( reduce_sum(mean_path_length).item() / get_world_size() ) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() if get_rank() == 0: pbar.set_description( ( f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}" ) ) # inception related: if args.eval_every > 0 and i % args.eval_every == 0: real_mean = real_cov = mean_latent = None with open(args.inception, "rb") as f: embeds = pickle.load(f) real_mean = embeds["mean"] real_cov = embeds["cov"] # print("yahooo!\n") with torch.no_grad(): g_ema.eval() if args.truncation < 1: mean_latent = g_ema.mean_latent(4096) # print("I am fine sir!\n") features = extract_feature_from_samples( g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, device ).numpy() # print("I am normal sir!") sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"{i:07d}; fid: {float(fid):.4f};\n") # print("alright hurray \n") if i % args.log_every == 0: with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write( ( f"{i:07d}; " f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f};\n" ) ) if i % args.log_every == 0: with torch.no_grad(): g_ema.eval() sample, _ = g_ema([sample_z]) utils.save_image( sample, os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}.png"), nrow=int(args.n_sample ** 0.5), normalize=True, range=(-1, 1), ) if i % args.save_every == 0: torch.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), ) if i % args.save_latest_every == 0: torch.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"latest.pt"), )
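# The FID logged in the eval block above is the Frechet distance between
# Gaussian fits of Inception features of real and generated images. A
# self-contained sketch of what calc_fid is assumed to compute from the
# (mean, cov) pairs:

import numpy as np
from scipy import linalg

def calc_fid_sketch(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    # FID = ||mu_s - mu_r||^2 + Tr(C_s + C_r - 2 * sqrt(C_s @ C_r))
    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
    if not np.isfinite(cov_sqrt).all():
        # add a small diagonal offset if the product is near-singular
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt, _ = linalg.sqrtm(
            (sample_cov + offset) @ (real_cov + offset), disp=False)
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real
    mean_diff = sample_mean - real_mean
    return (mean_diff @ mean_diff + np.trace(sample_cov)
            + np.trace(real_cov) - 2 * np.trace(cov_sqrt))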
def train(args, loader, generator, encoder, discriminator, discriminator2, vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema, e_ema, device): # kwargs_d = {'detach_aux': args.detach_d_aux_head} if args.dataset == 'imagefolder': loader = sample_data2(loader) else: loader = sample_data(loader) if args.eval_every > 0: inception = nn.DataParallel(load_patched_inception_v3()).to(device) inception.eval() with open(args.inception, "rb") as f: embeds = pickle.load(f) real_mean = embeds["mean"] real_cov = embeds["cov"] else: inception = real_mean = real_cov = None mean_latent = None pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} avg_pix_loss = util.AverageMeter() avg_vgg_loss = util.AverageMeter() if args.distributed: g_module = generator.module e_module = encoder.module d_module = discriminator.module else: g_module = generator e_module = encoder d_module = discriminator d2_module = None if discriminator2 is not None: if args.distributed: d2_module = discriminator2.module else: d2_module = discriminator2 accum = 0.5**(32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device) sample_z = torch.randn(args.n_sample, args.latent, device=device) sample_x = load_real_samples(args, loader) sample_x1 = sample_x[:, 0, ...] sample_x2 = sample_x[:, -1, ...] sample_idx = torch.randperm(args.n_sample) n_step_max = max(args.n_step_d, args.n_step_e) requires_grad(g_ema, False) requires_grad(e_ema, False) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break frames = [get_batch(loader, device) for _ in range(n_step_max)] # Train Discriminator requires_grad(generator, False) requires_grad(encoder, False) requires_grad(discriminator, True) for step_index in range(args.n_step_d): frames1, frames2 = frames[step_index] real_img = frames1 noise = mixing_noise(args.batch, args.latent, args.mixing, device) if args.use_ema: g_ema.eval() fake_img, _ = g_ema(noise) else: fake_img, _ = generator(noise) if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_pred = discriminator(fake_img) real_pred = discriminator(real_img_aug) d_loss_fake = F.softplus(fake_pred).mean() d_loss_real = F.softplus(-real_pred).mean() loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() d_loss_rec = 0. if args.lambda_rec_d > 0 and not args.decouple_d: # Do not train D on x_rec if decouple_d if args.use_ema: e_ema.eval() g_ema.eval() latent_real, _ = e_ema(real_img) rec_img, _ = g_ema([latent_real], input_is_latent=True) else: latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True) rec_pred = discriminator(rec_img) d_loss_rec = F.softplus(rec_pred).mean() loss_dict["rec_score"] = rec_pred.mean() d_loss_cross = 0. if args.lambda_cross_d > 0 and not args.decouple_d: if args.use_ema: e_ema.eval() w1, _ = e_ema(frames1) w2, _ = e_ema(frames2) else: w1, _ = encoder(frames1) w2, _ = encoder(frames2) dw = w2 - w1 dw_shuffle = dw[torch.randperm(args.batch), ...] 
if args.use_ema: g_ema.eval() cross_img, _ = g_ema([w1 + dw_shuffle], input_is_latent=True) else: cross_img, _ = generator([w1 + dw_shuffle], input_is_latent=True) cross_pred = discriminator(cross_img) d_loss_cross = F.softplus(cross_pred).mean() d_loss_fake_cross = 0. if args.lambda_fake_cross_d > 0: if args.use_ema: e_ema.eval() w1, _ = e_ema(frames1) w2, _ = e_ema(frames2) else: w1, _ = encoder(frames1) w2, _ = encoder(frames2) dw = w2 - w1 noise = mixing_noise(args.batch, args.latent, args.mixing, device) if args.use_ema: g_ema.eval() style = g_ema.get_styles(noise).view(args.batch, -1) else: style = generator.get_styles(noise).view(args.batch, -1) if dw.shape[1] < style.shape[1]: # W space dw = dw.repeat(1, args.n_latent) if args.use_ema: cross_img, _ = g_ema([style + dw], input_is_latent=True) else: cross_img, _ = generator([style + dw], input_is_latent=True) fake_cross_pred = discriminator(cross_img) d_loss_fake_cross = F.softplus(fake_cross_pred).mean() d_loss = (d_loss_real + d_loss_fake + d_loss_fake_cross * args.lambda_fake_cross_d + d_loss_rec * args.lambda_rec_d + d_loss_cross * args.lambda_cross_d) loss_dict["d"] = d_loss discriminator.zero_grad() d_loss.backward() d_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict["r1"] = r1_loss # Train Discriminator2 if args.decouple_d and discriminator2 is not None: requires_grad(generator, False) requires_grad(encoder, False) requires_grad(discriminator2, True) for step_index in range( args.n_step_e): # n_step_d2 is same as n_step_e frames1, frames2 = frames[step_index] real_img = frames1 if args.use_ema: e_ema.eval() g_ema.eval() latent_real, _ = e_ema(real_img) rec_img, _ = g_ema([latent_real], input_is_latent=True) else: latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True) rec_pred = discriminator2(rec_img) d2_loss_rec = F.softplus(rec_pred).mean() real_pred1 = discriminator2(frames1) d2_loss_real = F.softplus(-real_pred1).mean() if args.use_frames2_d: real_pred2 = discriminator2(frames2) d2_loss_real += F.softplus(-real_pred2).mean() if args.use_ema: e_ema.eval() w1, _ = e_ema(frames1) w2, _ = e_ema(frames2) else: w1, _ = encoder(frames1) w2, _ = encoder(frames2) dw = w2 - w1 dw_shuffle = dw[torch.randperm(args.batch), ...] 
cross_img, _ = generator([w1 + dw_shuffle], input_is_latent=True) cross_pred = discriminator2(cross_img) d2_loss_cross = F.softplus(cross_pred).mean() d2_loss = d2_loss_real + d2_loss_rec + d2_loss_cross loss_dict["d2"] = d2_loss loss_dict["rec_score"] = rec_pred.mean() loss_dict["cross_score"] = cross_pred.mean() discriminator2.zero_grad() d2_loss.backward() d2_optim.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator2(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator2.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d2_optim.step() # Train Encoder requires_grad(encoder, True) requires_grad(generator, args.train_ge) requires_grad(discriminator, False) if discriminator2 is not None: requires_grad(discriminator2, False) pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device) for step_index in range(args.n_step_e): frames1, frames2 = frames[step_index] real_img = frames1 latent_real, _ = encoder(real_img) if args.use_ema: g_ema.eval() rec_img, _ = g_ema([latent_real], input_is_latent=True) else: rec_img, _ = generator([latent_real], input_is_latent=True) if args.lambda_adv > 0: if not args.decouple_d: rec_pred = discriminator(rec_img) else: rec_pred = discriminator2(rec_img) adv_loss = g_nonsaturating_loss(rec_pred) if args.lambda_pix > 0: if args.pix_loss == 'l2': pix_loss = torch.mean((rec_img - real_img)**2) else: pix_loss = F.l1_loss(rec_img, real_img) if args.lambda_vgg > 0: vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2) e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv loss_dict["e"] = e_loss loss_dict["pix"] = pix_loss loss_dict["vgg"] = vgg_loss loss_dict["adv"] = adv_loss if args.train_ge: encoder.zero_grad() generator.zero_grad() e_loss.backward() e_optim.step() g_optim.step() else: encoder.zero_grad() e_loss.backward() e_optim.step() # Train Generator requires_grad(generator, True) requires_grad(discriminator, False) if discriminator2 is not None: requires_grad(discriminator2, False) frames1, frames2 = frames[0] real_img = frames1 g_loss_fake = 0. if args.lambda_fake_g > 0: noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: fake_img, _ = augment(fake_img, ada_aug_p) fake_pred = discriminator(fake_img) g_loss_fake = g_nonsaturating_loss(fake_pred) g_loss_rec = 0. if args.lambda_rec_g > 0: if args.use_ema: e_ema.eval() latent_real, _ = e_ema(real_img) else: latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True) if not args.decouple_d: rec_pred = discriminator(rec_img) else: rec_pred = discriminator2(rec_img) g_loss_rec = g_nonsaturating_loss(rec_pred) g_loss_cross = 0. if args.lambda_cross_g > 0: if args.use_ema: e_ema.eval() w1, _ = e_ema(frames1) w2, _ = e_ema(frames2) else: w1, _ = encoder(frames1) w2, _ = encoder(frames2) dw = w2 - w1 dw_shuffle = dw[torch.randperm(args.batch), ...] cross_img, _ = generator([w1 + dw_shuffle], input_is_latent=True) if not args.decouple_d: cross_pred = discriminator(cross_img) else: cross_pred = discriminator2(cross_img) g_loss_cross = g_nonsaturating_loss(cross_pred) g_loss_fake_cross = 0. 
if args.lambda_fake_cross_g > 0: if args.use_ema: e_ema.eval() w1, _ = e_ema(frames1) w2, _ = e_ema(frames2) else: w1, _ = encoder(frames1) w2, _ = encoder(frames2) dw = w2 - w1 noise = mixing_noise(args.batch, args.latent, args.mixing, device) style = generator.get_styles(noise).view(args.batch, -1) if dw.shape[1] < style.shape[1]: # W space dw = dw.repeat(1, args.n_latent) cross_img, _ = generator([style + dw], input_is_latent=True) fake_cross_pred = discriminator(cross_img) g_loss_fake_cross = g_nonsaturating_loss(fake_cross_pred) g_loss = (g_loss_fake * args.lambda_fake_g + g_loss_rec * args.lambda_rec_g + g_loss_cross * args.lambda_cross_g + g_loss_fake_cross * args.lambda_fake_cross_g) loss_dict["g"] = g_loss generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(e_ema, e_module, accum) accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() pix_loss_val = loss_reduced["pix"].mean().item() vgg_loss_val = loss_reduced["vgg"].mean().item() adv_loss_val = loss_reduced["adv"].mean().item() avg_pix_loss.update(pix_loss_val, real_img.shape[0]) avg_vgg_loss.update(vgg_loss_val, real_img.shape[0]) if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}; " f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}" )) if i % args.log_every == 0: with torch.no_grad(): latent_x, _ = e_ema(sample_x) fake_x, _ = generator([latent_x], input_is_latent=True, return_latents=False) sample_pix_loss = torch.sum((sample_x - fake_x)**2) with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write( f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; " f"ref: {sample_pix_loss.item()};\n") if args.eval_every > 0 and i % args.eval_every == 0: with torch.no_grad(): g_ema.eval() if args.truncation < 1: mean_latent = g_ema.mean_latent(4096) features = extract_feature_from_samples( g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, args.device).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) print("fid:", fid) with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"{i:07d}; fid: {float(fid):.4f};\n") if wandb and args.wandb: wandb.log({ "Generator": g_loss_val, "Discriminator": 
d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, }) if i % args.log_every == 0: with torch.no_grad(): # Fixed fake samples g_ema.eval() sample, _ = g_ema([sample_z]) utils.save_image( sample, os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"), nrow=int(args.n_sample**0.5), normalize=True, value_range=(-1, 1), ) # Reconstruction samples e_ema.eval() nrow = int(args.n_sample**0.5) nchw = list(sample_x.shape)[1:] latent_real, _ = e_ema(sample_x) fake_img, _ = g_ema([latent_real], input_is_latent=True, return_latents=False) sample = torch.cat( (sample_x.reshape(args.n_sample // nrow, nrow, *nchw), fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1) utils.save_image( sample.reshape(2 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) if i % args.save_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "d2": d2_module.state_dict() if args.decouple_d else None, "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "d2_optim": d2_optim.state_dict() if args.decouple_d else None, "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), ) if i % args.save_latest_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "d2": d2_module.state_dict() if args.decouple_d else None, "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "d2_optim": d2_optim.state_dict() if args.decouple_d else None, "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"latest.pt"), )
def train(args, loader, drs_loader, generator, discriminator, drs_discriminator, g_optim, d_optim, drs_d_optim, g_ema, device, output_path): # loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module drs_d_module = drs_discriminator.module else: g_module = generator d_module = discriminator drs_d_module = drs_discriminator logit_results = defaultdict(dict) accum = 0.5**(32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device) sample_z = torch.randn(args.n_sample, args.latent, device=device) iter_dataloader = iter(loader) iter_drs_dataloader = iter(drs_loader) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break try: real_img, _ = next(iter_dataloader) except StopIteration: iter_dataloader = iter(loader) real_img, _ = next(iter_dataloader) try: drs_real_img, _ = next(iter_drs_dataloader) except StopIteration: iter_drs_dataloader = iter(drs_loader) drs_real_img, _ = next(iter_drs_dataloader) real_img = real_img.to(device) drs_real_img = drs_real_img.to(device) requires_grad(generator, False) requires_grad(discriminator, True) requires_grad(drs_discriminator, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) drs_real_img_aug, _ = augment(drs_real_img, ada_aug_p) fake_img, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img drs_real_img_aug = drs_real_img fake_pred = discriminator(fake_img) real_pred = discriminator(real_img_aug) drs_fake_pred = drs_discriminator(fake_img) drs_real_pred = drs_discriminator(drs_real_img_aug) d_loss = d_logistic_loss(real_pred, fake_pred) drs_d_loss = d_logistic_loss(drs_real_pred, drs_fake_pred) loss_dict["d"] = d_loss loss_dict["drs_d"] = drs_d_loss loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() drs_discriminator.zero_grad() drs_d_loss.backward() drs_d_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() drs_real_img.requires_grad = True drs_real_pred = drs_discriminator(drs_real_img) drs_r1_loss = d_r1_loss(drs_real_pred, drs_real_img) drs_discriminator.zero_grad() (args.r1 / 2 * drs_r1_loss * args.d_reg_every + 0 * drs_real_pred[0]).backward() drs_d_optim.step() loss_dict["r1"] = r1_loss requires_grad(generator, True) requires_grad(discriminator, False) requires_grad(drs_discriminator, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: fake_img, _ = augment(fake_img, ada_aug_p) fake_pred = discriminator(fake_img) g_loss = 
g_nonsaturating_loss(fake_pred) loss_dict["g"] = g_loss generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() drs_d_loss_val = loss_reduced["drs_d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() if get_rank() == 0 and i > 0: pbar.set_description(( f"d: {d_loss_val:.4f}; drs_d: {drs_d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}")) if wandb and args.wandb: wandb.log({ "Generator": g_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, }) if i % 100 == 0: with torch.no_grad(): g_ema.eval() sample, _ = g_ema([sample_z]) save_path = output_path / 'fixed_sample' save_path.mkdir(parents=True, exist_ok=True) utils.save_image( sample, save_path / f"{str(i).zfill(6)}.png", nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1), ) random_sample_z = torch.randn(args.n_sample, args.latent, device=device) sample, _ = g_ema([random_sample_z]) save_path = output_path / 'random_sample' save_path.mkdir(parents=True, exist_ok=True) utils.save_image( sample, save_path / f"{str(i).zfill(6)}.png", nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1), ) if i % 5000 == 0: save_path = output_path / 'checkpoint' save_path.mkdir(parents=True, exist_ok=True) torch.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), "drs_d": drs_d_module.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), "drs_d_optim": drs_d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, }, save_path / f"{str(i).zfill(6)}.pt", )
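# The extra drs_discriminator trained above is never used for sampling inside
# this loop; presumably it serves discriminator rejection sampling (Azadi et
# al., 2019) at generation time. A simplified, hypothetical sampling sketch --
# the paper's exact acceptance rule adds a numerical-stability correction
# term, and logit_max must be estimated beforehand as the maximum logit over
# a large pool of generated samples:

def drs_sample_sketch(g_ema, drs_discriminator, n, latent_dim, logit_max,
                      gamma=0.0, device='cuda'):
    noise = torch.randn(n, latent_dim, device=device)
    img, _ = g_ema([noise])
    logit = drs_discriminator(img).view(-1)
    # accept each sample with probability derived from the density-ratio
    # estimate implied by the discriminator logit
    accept_prob = torch.sigmoid(logit - logit_max - gamma)
    keep = torch.rand_like(accept_prob) < accept_prob
    return img[keep]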
def train2(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device, logger): loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) # init vars mean_path_length = 0 mean_path_length_avg = 0 losses = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator accum = 0.5**(32 / (10 * 1000)) ada_augment = torch.tensor([0.0, 0.0], device=device) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 ada_aug_step = args.ada_target / args.ada_length r_t_stat = 0 sample_z = torch.randn(args.n_sample, args.latent, device=device) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break real_img = next(loader) real_img = real_img.to(device) requires_grad(generator, False) requires_grad(discriminator, True) loss, real_pred = calc_loss_1(args, generator, discriminator, real_img, ada_aug_p) d_loss = loss['d'] discriminator.zero_grad() d_loss.backward() d_optim.step() losses.update(loss) # if args.augment and args.augment_p == 0: ada_augment_data = torch.tensor( (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device) ada_augment += reduce_sum(ada_augment_data) if ada_augment[1] > 255: pred_signs, n_pred = ada_augment.tolist() r_t_stat = pred_signs / n_pred if r_t_stat > args.ada_target: sign = 1 else: sign = -1 ada_aug_p += sign * ada_aug_step * n_pred ada_aug_p = min(1, max(0, ada_aug_p)) ada_augment.mul_(0) # loss2 d_regularize = i % args.d_reg_every == 0 if d_regularize: loss, real_pred = calc_loss_2(discriminator, real_img) r1_loss = loss['r1'] discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() losses.update(loss) # loss 3 requires_grad(generator, True) requires_grad(discriminator, False) loss = calc_loss_3(args, generator, discriminator, ada_aug_p) g_loss = loss['g'] losses.update(loss) generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = i % args.g_reg_every == 0 if g_regularize: loss, mean_path_length = calc_loss_4(args, generator, device, mean_path_length) loss['weighted_path'].backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) losses.update(loss) accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(losses) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}")) # write log write_log(loss_reduced, 'train', i, logger) if wandb and args.wandb: wandb.log({ "Generator": g_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, }) if i % 100 == 0: with torch.no_grad(): g_ema.eval() sample, _ = g_ema([sample_z]) utils.save_image( sample, f"sample/{str(i).zfill(6)}.png", nrow=int(args.n_sample**0.5), normalize=True,
range=(-1, 1), ) if i % 10000 == 0: torch.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, }, f"checkpoint/{str(i).zfill(6)}.pt", )
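# train2 inlines the adaptive-augmentation (ADA) probability update that the
# other loops delegate to an AdaptiveAugment helper. The same logic as a
# class, sketched from the inline code above (reduce_sum is the distributed
# helper already used in this file):

class AdaptiveAugmentSketch:
    def __init__(self, ada_target, ada_length, device):
        self.ada_target = ada_target
        self.ada_step = ada_target / ada_length
        self.buf = torch.tensor([0.0, 0.0], device=device)  # (sign sum, count)
        self.r_t_stat = 0.0
        self.ada_aug_p = 0.0

    def tune(self, real_pred):
        # r_t = E[sign(D(x_real))] is used as an overfitting heuristic
        data = torch.tensor(
            (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
            device=self.buf.device)
        self.buf += reduce_sum(data)
        if self.buf[1] > 255:
            pred_signs, n_pred = self.buf.tolist()
            self.r_t_stat = pred_signs / n_pred
            # raise augmentation strength if D is too confident on reals
            sign = 1 if self.r_t_stat > self.ada_target else -1
            self.ada_aug_p = min(
                1, max(0, self.ada_aug_p + sign * self.ada_step * n_pred))
            self.buf.mul_(0)
        return self.ada_aug_p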
if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() gc.collect() torch.cuda.empty_cache() if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}"))
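# For reference, the path-length regularizer called in the g_regularize
# blocks throughout this file, as implemented in rosinality's
# stylegan2-pytorch (assumes math and torch are imported as elsewhere here):

def g_path_regularize_sketch(fake_img, latents, mean_path_length, decay=0.01):
    # Jacobian-vector product J^T y for a random image-space direction y,
    # scaled so the expected path length is O(1) regardless of resolution
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    grad, = torch.autograd.grad(outputs=(fake_img * noise).sum(),
                                inputs=latents, create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    # exponential moving average of path lengths serves as the target
    path_mean = mean_path_length + decay * (path_lengths.mean()
                                            - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths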
def train(args, loader_src, loader_norm, generator, discriminator, ExpertModel, g_optim, d_optim, g_ema, device): # Save Path date = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') ImgSavePath = 'sample/{}'.format(date) CheckpointSavePath = 'checkpoint/{}'.format(date) if not os.path.exists(ImgSavePath): os.makedirs(ImgSavePath) if not os.path.exists(CheckpointSavePath): os.makedirs(CheckpointSavePath) shutil.copy('./train.py', './{}/train.py'.format(CheckpointSavePath)) shutil.copy('./model.py', './{}/model.py'.format(CheckpointSavePath)) loader_src = sample_data(loader_src) loader_norm = sample_data(loader_norm) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 r1_loss = torch.tensor(0.0, device=device) path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator accum = 0.5**(32 / (10 * 1000)) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break real_img = next(loader_src) # source set tgt_img = next(loader_norm) # normal set real_img = real_img.to(device) tgt_img = tgt_img.to(device) #################################### Train discrimiantor #################################### requires_grad(generator, False) requires_grad(discriminator, True) Profile_Fea, Profile_Map = ExpertModel( TrainingSize_Select(real_img, device, args), args) Profile_Syn_Img, _ = generator(Profile_Fea, Profile_Map) Front_Fea, Front_Map = ExpertModel( TrainingSize_Select(tgt_img, device, args), args) Front_Syn_Img, _ = generator(Front_Fea, Front_Map) Profile_Syn_Pred = discriminator(Profile_Syn_Img) Front_Syn_Pred = discriminator(Front_Syn_Img) Real_Pred = discriminator(tgt_img) d_loss = (d_logistic_loss(Real_Pred, Profile_Syn_Pred) + d_logistic_loss(Real_Pred, Front_Syn_Pred)) / 2 loss_dict["d"] = d_loss loss_dict["real_score"] = Real_Pred.mean() loss_dict["profile_fake_score"] = Profile_Syn_Pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict["r1"] = r1_loss #################################### Train generator #################################### requires_grad(generator, True) requires_grad(discriminator, False) Front_Fea, Front_Map = ExpertModel( TrainingSize_Select(tgt_img, device, args), args) Front_Syn_Img, _ = generator(Front_Fea, Front_Map) Front_Syn_Pred = discriminator(Front_Syn_Img) Front_Syn_Fea, _ = ExpertModel( TrainingSize_Select(Front_Syn_Img, device, args), args) Profile_Fea, Profile_Map = ExpertModel( TrainingSize_Select(real_img, device, args), args) Profile_Syn_Img, _ = generator(Profile_Fea, Profile_Map) Profile_Syn_Pred = discriminator(Profile_Syn_Img) Profile_Syn_Fea, _ = ExpertModel( TrainingSize_Select(Profile_Syn_Img, device, args), args) adv_g_loss = (g_nonsaturating_loss(Profile_Syn_Pred) + g_nonsaturating_loss(Front_Syn_Pred)) / 2 fea_loss = (feature_loss(Profile_Syn_Fea[0], Profile_Fea[0]) + feature_loss(Front_Syn_Fea[0], Front_Fea[0])) / 2 sym_loss = (SymLoss(Front_Syn_Img) + SymLoss(Profile_Syn_Img)) / 2 L1_loss = L1Loss(Front_Syn_Img, tgt_img) g_loss = 
args.lambda_adv * adv_g_loss + args.lambda_fea * fea_loss + args.lambda_sym * sym_loss + args.lambda_l1 * L1_loss loss_dict["g"] = g_loss loss_dict["adv_g_loss"] = args.lambda_adv * adv_g_loss loss_dict["fea_loss"] = args.lambda_fea * fea_loss loss_dict["symmetry_loss"] = args.lambda_sym * sym_loss generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = i % args.g_reg_every == 0 if g_regularize: noise, noise_map = ExpertModel( TrainingSize_Select(real_img, device, args), args) fake_img, latents = generator(noise, noise_map, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() fea_loss_val = loss_reduced["fea_loss"].mean().item() sym_loss_val = loss_reduced["symmetry_loss"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() profile_fake_score_val = loss_reduced["profile_fake_score"].mean( ).item() path_length_val = loss_reduced["path_length"].mean().item() if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; g_total: {g_loss_val:.4f}; fea: {fea_loss_val:.4f}; sym: {sym_loss_val:.4f}; r1: {r1_val:.4f};" f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}" )) if wandb and args.wandb: wandb.log({ "Generator": g_loss_val, "Discriminator": d_loss_val, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Profile Score": profile_fake_score_val, "Path Length": path_length_val, }) if i % 100 == 0: with torch.no_grad(): g_ema.eval() pro_fea, pro_map = ExpertModel( TrainingSize_Select(real_img, device, args), args) pro_syn, _ = g_ema(pro_fea, pro_map) tgt_fea, tgt_map = ExpertModel( TrainingSize_Select(tgt_img, device, args), args) tgt_syn, _ = g_ema(tgt_fea, tgt_map) result = torch.cat([real_img, pro_syn, tgt_img, tgt_syn], 2) utils.save_image( result, f"{ImgSavePath}/{str(i).zfill(6)}.png", nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1), ) if i % 100 == 0: torch.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), }, f"{CheckpointSavePath}/{str(i).zfill(6)}.pt", )
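# The definitions of SymLoss and feature_loss used above are not in this
# file. Hypothetical sketches of what they plausibly compute, assuming
# F = torch.nn.functional; names and exact reductions are assumptions:

def sym_loss_sketch(img):
    # encourage left/right symmetry of the frontalized face by comparing
    # the image with its horizontal flip (width is the last dimension)
    return F.l1_loss(img, torch.flip(img, dims=[-1]))

def feature_loss_sketch(syn_feat, src_feat):
    # keep expert-network features of the synthesized image close to those
    # of the source image (detached so the expert itself is not updated)
    return F.l1_loss(syn_feat, src_feat.detach())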
def train(args, loader, loader2, generator, encoder, discriminator, vggnet, g_optim, e_optim, d_optim, g_ema, e_ema, device): inception = real_mean = real_cov = mean_latent = None if args.eval_every > 0: inception = nn.DataParallel(load_patched_inception_v3()).to(device) inception.eval() with open(args.inception, "rb") as f: embeds = pickle.load(f) real_mean = embeds["mean"] real_cov = embeds["cov"] if get_rank() == 0: if args.eval_every > 0: with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") if args.log_every > 0: with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 d_loss_val = r1_val = real_score_val = recx_score_val = 0 loss_dict = { "d": torch.tensor(0.0, device=device), "r1": torch.tensor(0.0, device=device) } avg_pix_loss = util.AverageMeter() avg_vgg_loss = util.AverageMeter() if args.distributed: g_module = generator.module e_module = encoder.module d_module = discriminator.module else: g_module = generator e_module = encoder d_module = discriminator d_weight = torch.tensor(1.0, device=device) last_layer = None if args.use_adaptive_weight: if args.distributed: last_layer = generator.module.get_last_layer() else: last_layer = generator.get_last_layer() # accum = 0.5 ** (32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 r_t_dict = {'real': 0, 'recx': 0} # r_t stat g_scale = 1 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device) sample_z = torch.randn(args.n_sample, args.latent, device=device) sample_x = load_real_samples(args, loader) if sample_x.ndim > 4: sample_x = sample_x[:, 0, ...] 
    n_step_max = max(args.n_step_d, args.n_step_e)
    requires_grad(g_ema, False)
    requires_grad(e_ema, False)
    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break
        if args.debug:
            util.seed_everything(i)
        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator
        if args.lambda_adv > 0:
            requires_grad(generator, False)
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
            for step_index in range(args.n_step_d):
                real_img = real_imgs[step_index]
                latent_real, _ = encoder(real_img)
                rec_img, _ = generator([latent_real], input_is_latent=True)
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                    rec_img_aug = rec_img
                real_pred = discriminator(real_img_aug)
                rec_pred = discriminator(rec_img_aug)
                d_loss_real = F.softplus(-real_pred).mean()
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["real_score"] = real_pred.mean()
                loss_dict["recx_score"] = rec_pred.mean()
                d_loss = d_loss_real + d_loss_rec * args.lambda_rec_d
                loss_dict["d"] = d_loss
                discriminator.zero_grad()
                d_loss.backward()
                d_optim.step()
            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat
            # Compute batchwise r_t
            r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch
            d_regularize = i % args.d_reg_every == 0
            if d_regularize:
                real_img.requires_grad = True
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                real_pred = discriminator(real_img_aug)
                r1_loss = d_r1_loss(real_pred, real_img)
                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
                d_optim.step()
            loss_dict["r1"] = r1_loss
            r_t_dict['recx'] = torch.sign(rec_pred).sum().item() / args.batch

        # Train AutoEncoder
        requires_grad(encoder, True)
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        if args.debug:
            util.seed_everything(i)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        # Default adversarial weight; overwritten below when adaptive weighting
        # kicks in (otherwise d_weight would be undefined at its first use).
        d_weight = torch.tensor(1., device=device)
        for step_index in range(args.n_step_e):
            real_img = real_imgs[step_index]
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img) ** 2)
                elif args.pix_loss == 'l1':
                    pix_loss = F.l1_loss(rec_img, real_img)
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img)) ** 2)
            if args.lambda_adv > 0:
                if args.augment:
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    rec_img_aug = rec_img
                rec_pred = discriminator(rec_img_aug)
                adv_loss = g_nonsaturating_loss(rec_pred)
            if args.use_adaptive_weight and i >= args.disc_iter_start:
                nll_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg
                g_loss = adv_loss * args.lambda_adv
                d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            ae_loss = (pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg
                       + d_weight * adv_loss * args.lambda_adv)
            loss_dict["ae"] = ae_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss
            encoder.zero_grad()
            generator.zero_grad()
            ae_loss.backward()
            e_optim.step()
            if args.g_decay is not None:
                scale_grad(generator, g_scale)
                g_scale *= args.g_decay
            g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update EMA
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5 ** (args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, 0 if args.no_ema_g else accum)
        accumulate(e_ema, e_module, 0 if args.no_ema_e else accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        ae_loss_val = loss_reduced["ae"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        if args.lambda_adv > 0:
            d_loss_val = loss_reduced["d"].mean().item()
            r1_val = loss_reduced["r1"].mean().item()
            real_score_val = loss_reduced["real_score"].mean().item()
            recx_score_val = loss_reduced["recx_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; ae: {ae_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"d_weight: {d_weight.item():.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"
            ))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample ** 0.5)
                    nchw = list(sample_x.shape)[1:]
                    # Reconstruction of real images
                    latent_x, _ = e_ema(sample_x)
                    rec_real, _ = g_ema([latent_x], input_is_latent=True)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         rec_real.reshape(args.n_sample // nrow, nrow, *nchw)), 1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    ref_pix_loss = torch.sum(torch.abs(sample_x - rec_real))
                    # Keep this a tensor so that .item() in the log line below
                    # stays valid when no vggnet is provided.
                    ref_vgg_loss = (torch.mean((vggnet(sample_x) - vggnet(rec_real)) ** 2)
                                    if vggnet is not None else torch.tensor(0., device=device))
                    # Fixed fake samples and reconstructions
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; "
                        f"real_score: {real_score_val:.4f}; recx_score: {recx_score_val:.4f}; "
                        f"pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; "
                        f"ref_pix: {ref_pix_loss.item():.4f}; ref_vgg: {ref_vgg_loss.item():.4f}; "
                        f"d_weight: {d_weight.item():.4f}; "
                        f"\n"))

            if wandb and args.wandb:
                wandb.log({
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Path Length": path_length_val,
                })

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    fid_sa = fid_re = fid_sr = 0
                    g_ema.eval()
                    e_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    else:
                        # No truncation: pass None through to the feature extractor.
                        mean_latent = None
                    # Real reconstruction FID
                    if 'fid_recon' in args.which_metric:
                        features = extract_feature_from_reconstruction(
                            e_ema, g_ema, inception, args.truncation, mean_latent,
                            loader2, args.device, mode='recon',
                        ).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_re = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                    with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                        f.write(f"{i:07d}; rec_real: {float(fid_re):.4f};\n")

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', 'latest.pt'),
                )
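# The autoencoder step above relies on a calculate_adaptive_weight() helper
# that is not defined in this file. A minimal sketch consistent with how it is
# called here (and with the VQGAN-style recipe it appears to follow) is below;
# the 1e-4 epsilon and the [0, 1e4] clamp are assumed defaults, not values
# taken from this codebase. It balances the reconstruction ("nll") and
# adversarial terms by the ratio of their gradient norms at the generator's
# last layer.
import torch


def calculate_adaptive_weight(nll_loss, g_loss, last_layer):
    # Gradients of each objective w.r.t. the same (last) generator parameter.
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    # Scale the adversarial term so both losses move the last layer comparably.
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    return torch.clamp(d_weight, 0.0, 1e4).detach()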
def train(
    args,
    loader,
    encoder,
    generator,
    discriminator,
    discriminator3d,  # video discriminator
    posterior,
    prior,
    factor,  # a learnable matrix
    vggnet,
    e_optim,
    d_optim,
    dv_optim,
    q_optim,  # q for posterior
    p_optim,  # p for prior
    f_optim,  # f for factor
    e_ema,
    device,
):
    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {
        "d": torch.tensor(0., device=device),
        "real_score": torch.tensor(0., device=device),
        "fake_score": torch.tensor(0., device=device),
        "r1_d": torch.tensor(0., device=device),
        "r1_e": torch.tensor(0., device=device),
        "rec": torch.tensor(0., device=device),
    }
    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator
    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    latent_full = args.latent_full
    factor_dim_full = args.factor_dim_full
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)
    sample_x = accumulate_batches(loader, args.n_sample).to(device)
    utils.save_image(
        sample_x.view(-1, *list(sample_x.shape)[2:]),
        os.path.join(args.log_dir, 'sample', "real-img.png"),
        nrow=sample_x.shape[1],
        normalize=True,
        value_range=(-1, 1),
    )
    util.save_video(
        sample_x[0],
        os.path.join(args.log_dir, 'sample', "real-vid.mp4")
    )
    requires_grad(generator, False)  # always False
    generator.eval()  # Generator should be ema and in eval mode
    if args.no_update_encoder:
        encoder = e_ema if e_ema is not None else encoder
        requires_grad(encoder, False)
        encoder.eval()
    from models.networks_3d import GANLoss
    criterionGAN = GANLoss()
    # criterionL1 = nn.L1Loss()
    # if args.no_ema or e_ema is None:
    #     e_ema = encoder
    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break
        data = next(loader)
        real_seq = data['frames']
        real_seq = real_seq.to(device)  # [N, T, C, H, W]
        shape = list(real_seq.shape)
        N, T = shape[:2]

        # Train Encoder with frame-level objectives
        if args.toggle_grads:
            if not args.no_update_encoder:
                requires_grad(encoder, True)
            requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = rec_loss = vid_loss = l1y_loss = torch.tensor(0., device=device)
        # TODO: real_seq -> encoder -> posterior -> generator -> fake_seq
        # f: [N, latent_full]; y: [N, T, D]
        fake_img, fake_seq, y_post = reconstruct_sequence(args, real_seq, encoder, generator, factor, posterior, i, ret_y=True)
        # if args.debug == 'no_lstm':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))
        #     fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False)
        #     fake_seq = fake_img.view(N, T, *shape[2:])
        # elif args.debug == 'decomp':
        #     real_lat = encoder(real_seq.view(-1, *shape[2:]))  # [N*T, latent_full]
        #     f_post = real_lat[::T, ...]
# z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1) # if args.use_multi_head: # y_post = [] # for z, w in zip(torch.split(z_post, 512, 2), factor.weight): # y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1)) # y_post = torch.cat(y_post, 2) # else: # y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1) # z_post_hat = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post_hat # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # else: # real_lat = encoder(real_seq.view(-1, *shape[2:])) # # single head: f_post [N, latent_full]; y_post [N, T, D] # # multi head: f_post [N, n_latent, latent]; y_post [N, T, n_latent, d] # f_post, y_post = posterior(real_lat.view(N, T, latent_full)) # z_post = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post # shape [N, T, latent_full] # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # TODO: sample frames real_img = real_seq.view(N*T, *shape[2:]) # fake_img = fake_seq.view(N*T, *shape[2:]) if args.lambda_adv > 0: if args.augment: fake_img_aug, _ = augment(fake_img, ada_aug_p) else: fake_img_aug = fake_img fake_pred = discriminator(fake_img_aug) adv_loss = g_nonsaturating_loss(fake_pred) # TODO: do we always put pix and vgg loss for all frames? if args.lambda_pix > 0: pix_loss = torch.mean((real_img - fake_img) ** 2) if args.lambda_vgg > 0: real_feat = vggnet(real_img) fake_feat = vggnet(fake_img) vgg_loss = torch.mean((real_feat - fake_feat) ** 2) # Train Encoder with video-level objectives # TODO: video adversarial loss if args.lambda_vid > 0: fake_pred = discriminator3d(flip_video(fake_seq.transpose(1, 2))) vid_loss = criterionGAN(fake_pred, True) if args.lambda_l1y > 0: # l1y_loss = criterionL1(y_post) l1y_loss = torch.mean(torch.abs(y_post)) e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv e_loss = e_loss + args.lambda_vid * vid_loss + args.lambda_l1y * l1y_loss loss_dict["e"] = e_loss loss_dict["pix"] = pix_loss loss_dict["vgg"] = vgg_loss loss_dict["adv"] = adv_loss if not args.no_update_encoder: encoder.zero_grad() posterior.zero_grad() e_loss.backward() q_optim.step() if not args.no_update_encoder: e_optim.step() # if args.train_on_fake: # e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0 # if e_regularize and args.lambda_rec > 0: # noise = mixing_noise(args.batch, args.latent, args.mixing, device) # fake_img, latent_fake = generator(noise, input_is_latent=False, return_latents=True) # latent_pred = encoder(fake_img) # if latent_pred.ndim < 3: # latent_pred = latent_pred.unsqueeze(1).repeat(1, latent_fake.size(1), 1) # rec_loss = torch.mean((latent_fake - latent_pred) ** 2) # encoder.zero_grad() # (rec_loss * args.lambda_rec).backward() # e_optim.step() # loss_dict["rec"] = rec_loss # e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0 # if e_regularize: # # why not regularize on augmented real? 
# real_img.requires_grad = True # real_pred = encoder(real_img) # r1_loss_e = d_r1_loss(real_pred, real_img) # encoder.zero_grad() # (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward() # e_optim.step() # loss_dict["r1_e"] = r1_loss_e if not args.no_update_encoder: if not args.no_ema and e_ema is not None: accumulate(e_ema, e_module, accum) # Train Discriminator if args.toggle_grads: requires_grad(encoder, False) requires_grad(discriminator, True) fake_img, fake_seq = reconstruct_sequence(args, real_seq, encoder, generator, factor, posterior) # if args.debug == 'no_lstm': # real_lat = encoder(real_seq.view(-1, *shape[2:])) # fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # elif args.debug == 'decomp': # real_lat = encoder(real_seq.view(-1, *shape[2:])) # [N*T, latent_full] # f_post = real_lat[::T, ...] # z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1) # if args.use_multi_head: # y_post = [] # for z, w in zip(torch.split(z_post, 512, 2), factor.weight): # y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1)) # y_post = torch.cat(y_post, 2) # else: # y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1) # z_post_hat = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post_hat # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # elif args.debug == 'coef': # real_lat = encoder(real_seq.view(-1, *shape[2:])) # [N*T, latent_full] # f_post = real_lat[::T, ...] # z_post_hat = real_lat.view(N, T, -1) - f_post.unsqueeze(1) # y_post = torch.mm(z_post_hat.view(N*T, -1), factor.weight).view(N, T, -1) # z_post = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # else: # real_lat = encoder(real_seq.view(-1, *shape[2:])) # f_post, y_post = posterior(real_lat.view(N, T, latent_full)) # z_post = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post # shape [N, T, latent_full] # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # fake_img = fake_seq.view(N*T, *shape[2:]) if not args.no_update_discriminator: if args.lambda_adv > 0: if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img_aug, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_img_aug = fake_img fake_pred = discriminator(fake_img_aug) real_pred = discriminator(real_img_aug) d_loss = d_logistic_loss(real_pred, fake_pred) # Train video discriminator if args.lambda_vid > 0: pred_real = discriminator3d(flip_video(real_seq.transpose(1, 2))) pred_fake = discriminator3d(flip_video(fake_seq.transpose(1, 2))) dv_loss_real = criterionGAN(pred_real, True) dv_loss_fake = criterionGAN(pred_fake, False) dv_loss = 0.5 * (dv_loss_real + dv_loss_fake) d_loss = d_loss + dv_loss loss_dict["d"] = d_loss loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() if args.lambda_adv > 0: discriminator.zero_grad() if args.lambda_vid > 0: discriminator3d.zero_grad() d_loss.backward() if args.lambda_adv > 0: d_optim.step() if args.lambda_vid > 0: dv_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = 
ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0 if d_regularize: # why not regularize on augmented real? real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss_d = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward() # Why 0* ? Answer is here https://github.com/rosinality/stylegan2-pytorch/issues/76 d_optim.step() loss_dict["r1_d"] = r1_loss_d loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() e_loss_val = loss_reduced["e"].mean().item() r1_d_val = loss_reduced["r1_d"].mean().item() r1_e_val = loss_reduced["r1_e"].mean().item() pix_loss_val = loss_reduced["pix"].mean().item() vgg_loss_val = loss_reduced["vgg"].mean().item() adv_loss_val = loss_reduced["adv"].mean().item() rec_loss_val = loss_reduced["rec"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() if get_rank() == 0: pbar.set_description( ( f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; " f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; " f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}" ) ) if wandb and args.wandb: wandb.log( { "Encoder": e_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1 D": r1_d_val, "R1 E": r1_e_val, "Pix Loss": pix_loss_val, "VGG Loss": vgg_loss_val, "Adv Loss": adv_loss_val, "Rec Loss": rec_loss_val, "Real Score": real_score_val, "Fake Score": fake_score_val, } ) if i % args.log_every == 0: with torch.no_grad(): e_eval = encoder if args.no_ema else e_ema e_eval.eval() posterior.eval() # N = sample_x.shape[0] fake_img, fake_seq = reconstruct_sequence(args, sample_x, e_eval, generator, factor, posterior) # if args.debug == 'no_lstm': # real_lat = encoder(sample_x.view(-1, *shape[2:])) # fake_img, _ = generator([real_lat], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # elif args.debug == 'decomp': # real_lat = encoder(sample_x.view(-1, *shape[2:])) # [N*T, latent_full] # f_post = real_lat[::T, ...] 
# z_post = real_lat.view(N, T, -1) - f_post.unsqueeze(1) # if args.use_multi_head: # y_post = [] # for z, w in zip(torch.split(z_post, 512, 2), factor.weight): # y_post.append(torch.mm(z.view(N*T, -1), w).view(N, T, -1)) # y_post = torch.cat(y_post, 2) # else: # y_post = torch.mm(z_post.view(N*T, -1), factor.weight[0]).view(N, T, -1) # z_post_hat = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post_hat # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) # else: # x_lat = encoder(sample_x.view(-1, *shape[2:])) # f_post, y_post = posterior(x_lat.view(N, T, latent_full)) # z_post = factor(y_post) # f_expand = f_post.unsqueeze(1).expand(-1, T, -1) # w_post = f_expand + z_post # fake_img, _ = generator([w_post.view(N*T, latent_full)], input_is_latent=True, return_latents=False) # fake_seq = fake_img.view(N, T, *shape[2:]) utils.save_image( torch.cat((sample_x, fake_seq), 1).view(-1, *shape[2:]), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-img_recon.png"), nrow=T, normalize=True, value_range=(-1, 1), ) util.save_video( fake_seq[random.randint(0, args.n_sample-1)], os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-vid_recon.mp4") ) fake_img, fake_seq = swap_sequence(args, sample_x, e_eval, generator, factor, posterior) utils.save_image( torch.cat((sample_x, fake_seq), 1).view(-1, *shape[2:]), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-img_swap.png"), nrow=T, normalize=True, value_range=(-1, 1), ) e_eval.train() posterior.train() if i % args.save_every == 0: e_eval = encoder if args.no_ema else e_ema torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_module.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), ) if not args.debug and i % args.save_latest_every == 0: torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_module.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, }, os.path.join(args.log_dir, 'weight', f"latest.pt"), )
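# Every loop in this file maintains exponential-moving-average copies of the
# generator (and here also the encoder) through accumulate(g_ema, g_module,
# accum). A minimal sketch matching that call signature, in the spirit of the
# stylegan2-pytorch helper these loops are built around; note that a decay of
# 0 (as in the no_ema_* branches above) simply copies the live weights:
import torch


def accumulate(model1, model2, decay=0.999):
    # In-place EMA update: p1 = decay * p1 + (1 - decay) * p2.
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)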
def train(args, epoch, loader, target_loader, model, g_optimizer, l_optimizer, d_optimizer, device):
    model.train()
    global global_iter
    memory = Counter()
    pbar = tqdm(range(min(len(loader), len(target_loader))))
    iterator = iter(loader)
    target_iterator = iter(target_loader)
    for i in pbar:
        try:
            images, targets, _ = next(iterator)
            target_images, target_targets, _ = next(target_iterator)
        except StopIteration:
            # One of the loaders is exhausted: restart both and draw again.
            iterator = iter(loader)
            target_iterator = iter(target_loader)
            images, targets, _ = next(iterator)
            target_images, target_targets, _ = next(target_iterator)
        global_iter += 1
        model.zero_grad()
        images = images.to(device)
        targets = [target.to(device) for target in targets]
        target_images = target_images.to(device)
        target_targets = [target.to(device) for target in target_targets]
        _, loss_dict = model(images.tensors, targets=targets)
        loss_cls = loss_dict['loss_cls'].mean()
        loss_box = loss_dict['loss_box'].mean()
        loss_center = loss_dict['loss_center'].mean()
        loss = loss_cls + loss_box + loss_center
        del loss_cls, loss_box, loss_center
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 10)
        g_optimizer.step()
        d_optimizer.step()
        l_optimizer.step()
        del loss, _
        loss_reduced = reduce_loss_dict(loss_dict)
        loss_cls_item = loss_reduced['loss_cls'].mean().item()
        loss_box_item = loss_reduced['loss_box'].mean().item()
        loss_center_item = loss_reduced['loss_center'].mean().item()
        loss_discrep_item = loss_reduced['loss_discrep'].mean().item()
        del loss_dict, loss_reduced
        del images, targets, target_images, target_targets
        writer.add_scalar('loss/cls', loss_cls_item, global_iter)
        writer.add_scalar('loss/box', loss_box_item, global_iter)
        writer.add_scalar('loss/center', loss_center_item, global_iter)
        writer.add_scalar('loss/discrep', loss_discrep_item, global_iter)
        pbar.set_description((
            f'epoch: {epoch + 1}; cls: {loss_cls_item:.4f}; '
            f'box: {loss_box_item:.4f}; center: {loss_center_item:.4f}; discrep: {loss_discrep_item:.4f}'
        ))
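# The function above restarts its iterators by hand inside a try/except. The
# GAN loops in this file instead wrap the DataLoader in sample_data() and call
# next(loader) forever. A minimal sketch of that helper:
def sample_data(loader):
    # Yield batches indefinitely, transparently re-entering the DataLoader
    # (and re-shuffling, if a shuffling sampler is used) at each epoch end.
    while True:
        for batch in loader:
            yield batch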
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): loader = sample_data(loader) current_ckpt = args.current_ckpt # generate one fake image to check data correct test_imgs = next(loader) real_grid = utils.make_grid(test_imgs, nrow=2, normalize=True, range=(-1, 1)) wandb.log({"reals": [wandb.Image(real_grid, caption='Real Data')]}) pbar = tqdm(dynamic_ncols=True, smoothing=0.01, initial=current_ckpt + 1, total=args.iter) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator none_g_grads = set() test_in = torch.randn(1, args.latent, device=device) fake, latent = g_module([test_in], return_latents=True) path = g_path_regularize(fake, latent, 0) path[0].backward() for n, p in generator.named_parameters(): if p.grad is None: none_g_grads.add(n) test_in = torch.randn(1, 3, args.size, args.size, requires_grad=True, device=device) pred = d_module(test_in) r1_loss = d_r1_loss(pred, test_in) r1_loss.backward() none_d_grads = set() for n, p in discriminator.named_parameters(): if p.grad is None: none_d_grads.add(n) seed = torch.initial_seed() % 10000000 torch.manual_seed(20) torch.cuda.manual_seed_all(20) sample_z = torch.randn(4 * 4, args.latent, device=device) sample_z_chunks = torch.split(sample_z, args.batch) # reset seed torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) i = current_ckpt + 1 while i < args.iter: real_img = next(loader) real_img = real_img.to(device) requires_grad(generator, False) requires_grad(discriminator, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) fake_pred = discriminator(fake_img) real_pred = discriminator(real_img) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict['d'] = d_loss loss_dict['real_score'] = real_pred.mean() loss_dict['fake_score'] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() set_grad_none(discriminator, none_d_grads) d_optim.step() loss_dict['r1'] = r1_loss requires_grad(generator, True) requires_grad(discriminator, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) fake_pred = discriminator(fake_img) g_loss = g_nonsaturating_loss(fake_pred) loss_dict['g'] = g_loss generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = i % args.g_reg_every == 0 if g_regularize: noise = mixing_noise(args.batch // args.path_batch_shrink, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() set_grad_none(g_module, none_g_grads) g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict['path'] = path_loss loss_dict['path_length'] = 
path_lengths.mean() accumulate(g_ema, g_module) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced['d'].mean().item() g_loss_val = loss_reduced['g'].mean().item() r1_val = loss_reduced['r1'].mean().item() path_loss_val = loss_reduced['path'].mean().item() real_score_val = loss_reduced['real_score'].mean().item() fake_score_val = loss_reduced['fake_score'].mean().item() path_length_val = loss_reduced['path_length'].mean().item() if get_rank() == 0: pbar.set_postfix(d_loss=f'{d_loss_val:.4f}', g_loss=f'{g_loss_val:.4f}', r1_loss=f'{r1_val:.4f}', path=f'{path_loss_val:.4f}', mean=f'{mean_path_length_avg:.4f}') if wandb and args.wandb: wandb.log({ 'Generator': g_loss_val, 'Discriminator': d_loss_val, 'R1': r1_val, 'Path Length Regularization': path_loss_val, 'Mean Path Length': mean_path_length, 'Real Score': real_score_val, 'Fake Score': fake_score_val, 'Path Length': path_length_val, 'current_ckpt': current_ckpt, 'iteration': i, }) if i % 500 == 0: with torch.no_grad(): g_ema.eval() sample = generate_fake_images(g_ema, sample_z_chunks) if wandb and args.wandb: label = f'{str(i).zfill(8)}.png' image = utils.make_grid(sample, nrow=4, normalize=True, range=(-1, 1)) wandb.log( {"samples": [wandb.Image(image, caption=label)]}) else: utils.save_image( sample, f'sample/{str(i).zfill(8)}.png', nrow=8, normalize=True, range=(-1, 1), ) if i % 2000 == 0: ckpt_name = f'checkpoint/{str(i).zfill(8)}.pt' # remove the previous checkpoint shutil.rmtree('checkpoint') os.mkdir('checkpoint') torch.save( { 'g': g_module.state_dict(), 'd': d_module.state_dict(), 'g_ema': g_ema.state_dict(), 'g_optim': g_optim.state_dict(), 'd_optim': d_optim.state_dict(), }, ckpt_name, ) current_ckpt = i if wandb and args.wandb: wandb.save(ckpt_name) i = i + 1 pbar.update()
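# All of these loops report metrics through reduce_loss_dict(), which averages
# each scalar loss across distributed workers so that rank 0 logs global
# values. A sketch consistent with its usage here, assuming torch.distributed
# is initialized whenever the world size is greater than one:
import torch
from torch import distributed as dist


def get_world_size():
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size()


def reduce_loss_dict(loss_dict):
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        keys = sorted(loss_dict.keys())
        losses = torch.stack([loss_dict[k] for k in keys], 0)
        dist.reduce(losses, dst=0)  # sum all workers' losses onto rank 0
        if dist.get_rank() == 0:
            losses /= world_size  # mean over workers; meaningful on rank 0 only
        return dict(zip(keys, losses))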
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device, save_dir): loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator accum = 0.5**(32 / (10 * 1000)) ada_augment = torch.tensor([0.0, 0.0], device=device) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 ada_aug_step = args.ada_target / args.ada_length r_t_stat = 0 sample_z = torch.randn(args.n_sample, args.latent, device=device) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break real_img = next(loader) real_img = real_img.to(device) requires_grad(generator, False) requires_grad(discriminator, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_pred = discriminator(fake_img) real_pred = discriminator(real_img_aug) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["d"] = d_loss loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() if args.augment and args.augment_p == 0: ada_augment_data = torch.tensor( (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device) ada_augment += reduce_sum(ada_augment_data) if ada_augment[1] > 255: pred_signs, n_pred = ada_augment.tolist() r_t_stat = pred_signs / n_pred if r_t_stat > args.ada_target: sign = 1 else: sign = -1 ada_aug_p += sign * ada_aug_step * n_pred ada_aug_p = min(1, max(0, ada_aug_p)) ada_augment.mul_(0) d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict["r1"] = r1_loss requires_grad(generator, True) requires_grad(discriminator, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: fake_img, _ = augment(fake_img, ada_aug_p) fake_pred = discriminator(fake_img) g_loss = g_nonsaturating_loss(fake_pred) loss_dict["g"] = g_loss generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = 
loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}")) if wandb and args.wandb: wandb.log({ "Generator": g_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, }) if i % 1000 == 0: # save some samples with torch.no_grad(): g_ema.eval() sample, _ = g_ema([sample_z]) length = int(round(np.sqrt(args.n_sample), 0)) black_bar_width = 7 template = np.zeros((length * int(args.size) + ((length - 1) * black_bar_width), length * (int(args.size) * 2) + ((length - 1) * black_bar_width), 3)) position = [0, 0] line_count = 1 for j in range(args.n_sample): # save the first image from this pair utils.save_image(torch.from_numpy( sample.detach().cpu().numpy()[j, :3, :, :]), "temp.png", nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1)) first_image = np.asarray(Image.open("temp.png")) # save the second image from this pair utils.save_image(torch.from_numpy( sample.detach().cpu().numpy()[j, 3:, :, :]), "temp.png", nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1)) second_image = np.asarray(Image.open("temp.png")) # -------------------------------------------------------------------------- # unfold the pair and add black borders, if needed # -------------------------------------------------------------------------- unfolded = np.column_stack((first_image, second_image)) #if needed, add a black column if (((j + 1) % length) != 0): unfolded = np.column_stack( (unfolded, np.zeros((args.size, black_bar_width, 3)))) # if needed, add a black line if ((line_count % length) != 0): unfolded = np.vstack( (unfolded, np.zeros( (black_bar_width, unfolded.shape[1], 3)))) # -------------------------------------------------------------------------------------------------------------------------------------------------- # add the pair to the template # -------------------------------------------------------------------------------------------------------------------------------------------------- if ((line_count % length) != 0): # it's not the last line if (((j + 1) % length) != 0): # it's not the last column template[position[0]:position[0] + args.size + black_bar_width, position[1]:position[1] + (args.size * 2) + black_bar_width, :] = unfolded else: # it's the last column template[position[0]:position[0] + args.size + black_bar_width, position[1]:position[1] + (args.size * 2), :] = unfolded else: # it's the last line if (((j + 1) % length) != 0): # it's not the last column template[position[0]:position[0] + args.size, position[1]:position[1] + (args.size * 2) + black_bar_width, :] = unfolded else: # it's the last column template[position[0]:position[0] + args.size, position[1]:position[1] + (args.size * 2), :] = unfolded # ------------------------------------------------------------- # prepare the next iteration # ------------------------------------------------------------- if (((j + 1) % length) == 0): 
position = [position[0] + unfolded.shape[0], 0] line_count += 1 else: position = [ position[0], position[1] + unfolded.shape[1] ] Image.fromarray(template.astype(np.uint8)).convert( "RGBA").save(save_dir + "/samples/" + f"/{str(i).zfill(6)}.png") os.remove("temp.png") if i % 2000 == 0: #save the model torch.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, }, save_dir + f"/checkpoints/{str(i).zfill(6)}.pt", )
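# The adversarial objectives shared by every loop in this file
# (d_logistic_loss, g_nonsaturating_loss, d_r1_loss) are the standard
# StyleGAN2 losses; the softplus forms below match how they are computed
# inline elsewhere in this file. Minimal sketches:
import torch
from torch import autograd
from torch.nn import functional as F


def d_logistic_loss(real_pred, fake_pred):
    # Logistic discriminator loss: push real logits up, fake logits down.
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()


def g_nonsaturating_loss(fake_pred):
    # Non-saturating generator loss: make D score fakes as real.
    return F.softplus(-fake_pred).mean()


def d_r1_loss(real_pred, real_img):
    # R1 regularizer: squared gradient norm of D's output w.r.t. real images.
    grad_real, = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()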
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator accum = 0.5**(32 / (10 * 1000)) sample_z = torch.randn(args.n_sample, args.latent, device=device) sample_labels = [] while len(sample_labels) < args.n_sample: real_img, real_label = next(loader) sample_labels.append(real_label.to(device)) sample_labels = torch.cat(sample_labels, 0)[:args.n_sample] for idx in pbar: i = idx + args.start_iter if i > args.iter: print('Done!') break real_img, real_label = next(loader) real_img = real_img.to(device) real_label = real_label.to(device) requires_grad(generator, False) requires_grad(discriminator, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(real_label, noise) fake_pred = discriminator(real_label, fake_img) real_pred = discriminator(real_label, real_img) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict['d'] = d_loss loss_dict['real_score'] = real_pred.mean() loss_dict['fake_score'] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator(real_label, real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict['r1'] = r1_loss requires_grad(generator, True) requires_grad(discriminator, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(real_label, noise) fake_pred = discriminator(real_label, fake_img) g_loss = g_nonsaturating_loss(fake_pred) loss_dict['g'] = g_loss generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_img, latents = generator(real_label[:path_batch_size], noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict['path'] = path_loss loss_dict['path_length'] = path_lengths.mean() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced['d'].mean().item() g_loss_val = loss_reduced['g'].mean().item() r1_val = loss_reduced['r1'].mean().item() path_loss_val = loss_reduced['path'].mean().item() real_score_val = loss_reduced['real_score'].mean().item() fake_score_val = loss_reduced['fake_score'].mean().item() path_length_val = loss_reduced['path_length'].mean().item() if get_rank() == 0: pbar.set_description(( f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; ' f'path: {path_loss_val:.4f}; 
mean path: {mean_path_length_avg:.4f}' )) if wandb and args.wandb: wandb.log({ 'Generator': g_loss_val, 'Discriminator': d_loss_val, 'R1': r1_val, 'Path Length Regularization': path_loss_val, 'Mean Path Length': mean_path_length, 'Real Score': real_score_val, 'Fake Score': fake_score_val, 'Path Length': path_length_val, }) if i % 200 == 0: with torch.no_grad(): g_ema.eval() sample, _ = g_ema(sample_labels, [sample_z]) utils.save_image( sample, f'sample/{str(i).zfill(6)}.png', nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1), ) if i % 10000 == 0: torch.save( { 'g': g_module.state_dict(), 'd': d_module.state_dict(), 'g_ema': g_ema.state_dict(), 'g_optim': g_optim.state_dict(), 'd_optim': d_optim.state_dict(), }, f'checkpoint/{str(i).zfill(6)}.pt', )
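# Style mixing and the requires_grad() toggling used when alternating D and G
# updates come from two small helpers. Sketches matching their usage here; the
# two-latent mixing (rather than more) follows the stylegan2-pytorch lineage
# this file builds on:
import random
import torch


def make_noise(batch, latent_dim, n_noise, device):
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)
    return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)


def mixing_noise(batch, latent_dim, prob, device):
    # With probability `prob`, return two latents for style mixing; else one.
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)
    return [make_noise(batch, latent_dim, 1, device)]


def requires_grad(model, flag=True):
    # Freeze or unfreeze a network wholesale, e.g. freeze G while stepping D.
    for p in model.parameters():
        p.requires_grad = flag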
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 r1_loss = torch.tensor(0.0, device=device) path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator accum = 0.5 ** (32 / (10 * 1000)) sample_z = torch.randn(args.n_sample, args.latent, device=device) for idx in pbar: i = idx + args.start_iter if i > args.iter: print('Done!') break data = next(loader) key = np.random.randint(n_scales) real_stack = data[key].to(device) real_img, converted = real_stack[:, :3], real_stack[:, 3:] requires_grad(generator, False) requires_grad(discriminator, True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(converted, noise) fake = fake_img if args.img2dis else torch.cat([fake_img, converted], 1) fake_pred = discriminator(fake, key) real = real_img if args.img2dis else real_stack real_pred = discriminator(real, key) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict['d'] = d_loss loss_dict['real_score'] = real_pred.mean() loss_dict['fake_score'] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real.requires_grad = True real_pred = discriminator(real, key) r1_loss = d_r1_loss(real_pred, real) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict['r1'] = r1_loss requires_grad(generator, True) requires_grad(discriminator, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(converted, noise) fake = fake_img if args.img2dis else torch.cat([fake_img, converted], 1) fake_pred = discriminator(fake, key) g_loss = g_nonsaturating_loss(fake_pred) loss_dict['g'] = g_loss generator.zero_grad() g_loss.backward() g_optim.step() loss_dict['path'] = path_loss loss_dict['path_length'] = path_lengths.mean() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced['d'].mean().item() g_loss_val = loss_reduced['g'].mean().item() r1_val = loss_reduced['r1'].mean().item() path_loss_val = loss_reduced['path'].mean().item() real_score_val = loss_reduced['real_score'].mean().item() fake_score_val = loss_reduced['fake_score'].mean().item() path_length_val = loss_reduced['path_length'].mean().item() if get_rank() == 0: pbar.set_description( ( f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; ' f'path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}' ) ) if i % 100 == 0: writer.add_scalar("Generator", g_loss_val, i) writer.add_scalar("Discriminator", d_loss_val, i) writer.add_scalar("R1", r1_val, i) writer.add_scalar("Path Length Regularization", path_loss_val, i) writer.add_scalar("Mean Path Length", mean_path_length, i) writer.add_scalar("Real Score", real_score_val, i) writer.add_scalar("Fake Score", fake_score_val, i) writer.add_scalar("Path Length", path_length_val, i) if i % 500 == 0: with torch.no_grad(): g_ema.eval() converted_full = convert_to_coord_format(sample_z.size(0), args.size, args.size, device, integer_values=args.coords_integer_values) if args.generate_by_one: converted_full = 
convert_to_coord_format(1, args.size, args.size, device, integer_values=args.coords_integer_values) samples = [] for sz in sample_z: sample, _ = g_ema(converted_full, [sz.unsqueeze(0)]) samples.append(sample) sample = torch.cat(samples, 0) else: sample, _ = g_ema(converted_full, [sample_z]) utils.save_image( sample, os.path.join(path, 'outputs', args.output_dir, 'images', f'{str(i).zfill(6)}.png'), nrow=int(args.n_sample ** 0.5), normalize=True, range=(-1, 1), ) if i == 0: utils.save_image( fake_img, os.path.join( path, f'outputs/{args.output_dir}/images/fake_patch_{str(key)}_{str(i).zfill(6)}.png'), nrow=int(fake_img.size(0) ** 0.5), normalize=True, range=(-1, 1), ) utils.save_image( real_img, os.path.join( path, f'outputs/{args.output_dir}/images/real_patch_{str(key)}_{str(i).zfill(6)}.png'), nrow=int(real_img.size(0) ** 0.5), normalize=True, range=(-1, 1), ) if i % args.save_checkpoint_frequency == 0: torch.save( { 'g': g_module.state_dict(), 'd': d_module.state_dict(), 'g_ema': g_ema.state_dict(), 'g_optim': g_optim.state_dict(), 'd_optim': d_optim.state_dict(), }, os.path.join( path, f'outputs/{args.output_dir}/checkpoints/{str(i).zfill(6)}.pt'), ) if i > 0: cur_metrics = calculate_fid(g_ema, fid_dataset=fid_dataset, bs=args.fid_batch, size=args.coords_size, num_batches=args.fid_samples//args.fid_batch, latent_size=args.latent, save_dir=args.path_fid, integer_values=args.coords_integer_values) writer.add_scalar("fid", cur_metrics['frechet_inception_distance'], i) print(i, "fid", cur_metrics['frechet_inception_distance'])
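# g_path_regularize() implements StyleGAN2's path length regularization: the
# norm of the image-space Jacobian-vector product w.r.t. the latents is pulled
# toward a running mean. A sketch consistent with its call sites in this file
# (fake images, latents, and the running mean in; penalty, updated mean, and
# per-sample lengths out); the 0.01 EMA decay is the customary default,
# assumed rather than read from this codebase:
import math
import torch
from torch import autograd


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    # Random image-space direction, scaled so the JVP norm is resolution-free.
    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
    grad, = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    # Exponential moving average of the path length.
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths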
def train(opt):
    lib.print_model_settings(locals().copy())

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130
    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    train_dataset, train_dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        sampler=data_sampler(train_dataset, shuffle=True, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(train_dataset_log)
    print('-' * 80)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        sampler=data_sampler(valid_dataset, shuffle=False, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    # styleModel = StyleTensorEncoder(input_dim=opt.input_channel)
    # genModel = AdaIN_Tensor_WordGenerator(opt)
    # disModel = MsImageDisV2(opt)
    # styleModel = StyleLatentEncoder(input_dim=opt.input_channel, norm='none')
    # mixModel = Mixer(opt, nblk=3, dim=opt.latent)
    genModel = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device)
    disModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel).to(device)
    g_ema = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device)
    ocrModel = ModelV1(opt).to(device)
    accumulate(g_ema, genModel, 0)

    # # weight initialization
    # for currModel in [styleModel, mixModel]:
    #     for name, param in currModel.named_parameters():
    #         if 'localization_fc2' in name:
    #             print(f'Skip {name} as it is already initialized')
    #             continue
    #         try:
    #             if 'bias' in name:
    #                 init.constant_(param, 0.0)
    #             elif 'weight' in name:
    #                 init.kaiming_normal_(param)
    #         except Exception as e:  # for batchnorm.
# if 'weight' in name: # param.data.fill_(1) # continue if opt.contentLoss == 'vis' or opt.contentLoss == 'seq': ocrCriterion = torch.nn.L1Loss() else: if 'CTC' in opt.Prediction: ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device) else: ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 # vggRecCriterion = torch.nn.L1Loss() # vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True), vggRecCriterion) print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length) if opt.distributed: genModel = torch.nn.parallel.DistributedDataParallel( genModel, device_ids=[opt.local_rank], output_device=opt.local_rank, broadcast_buffers=False, ) disModel = torch.nn.parallel.DistributedDataParallel( disModel, device_ids=[opt.local_rank], output_device=opt.local_rank, broadcast_buffers=False, ) ocrModel = torch.nn.parallel.DistributedDataParallel( ocrModel, device_ids=[opt.local_rank], output_device=opt.local_rank, broadcast_buffers=False ) # styleModel = torch.nn.DataParallel(styleModel).to(device) # styleModel.train() # mixModel = torch.nn.DataParallel(mixModel).to(device) # mixModel.train() # genModel = torch.nn.DataParallel(genModel).to(device) # g_ema = torch.nn.DataParallel(g_ema).to(device) genModel.train() g_ema.eval() # disModel = torch.nn.DataParallel(disModel).to(device) disModel.train() # vggModel = torch.nn.DataParallel(vggModel).to(device) # vggModel.eval() # ocrModel = torch.nn.DataParallel(ocrModel).to(device) # if opt.distributed: # ocrModel.module.Transformation.eval() # ocrModel.module.FeatureExtraction.eval() # ocrModel.module.AdaptiveAvgPool.eval() # # ocrModel.module.SequenceModeling.eval() # ocrModel.module.Prediction.eval() # else: # ocrModel.Transformation.eval() # ocrModel.FeatureExtraction.eval() # ocrModel.AdaptiveAvgPool.eval() # # ocrModel.SequenceModeling.eval() # ocrModel.Prediction.eval() ocrModel.eval() if opt.distributed: g_module = genModel.module d_module = disModel.module else: g_module = genModel d_module = disModel g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1) d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1) optimizer = optim.Adam( genModel.parameters(), lr=opt.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), ) dis_optimizer = optim.Adam( disModel.parameters(), lr=opt.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), ) ## Loading pre-trained files if opt.modelFolderFlag: if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0: opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1] if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None': if not opt.distributed: ocrModel = torch.nn.DataParallel(ocrModel) print(f'loading pretrained ocr model from {opt.saved_ocr_model}') checkpoint = torch.load(opt.saved_ocr_model) ocrModel.load_state_dict(checkpoint) #temporary fix if not opt.distributed: ocrModel = ocrModel.module if opt.saved_gen_model !='' and opt.saved_gen_model !='None': print(f'loading pretrained gen model from {opt.saved_gen_model}') checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage) genModel.module.load_state_dict(checkpoint['g']) g_ema.module.load_state_dict(checkpoint['g_ema']) if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': print(f'loading pretrained synth model from {opt.saved_synth_model}') checkpoint = torch.load(opt.saved_synth_model) # 
styleModel.load_state_dict(checkpoint['styleModel']) # mixModel.load_state_dict(checkpoint['mixModel']) genModel.load_state_dict(checkpoint['genModel']) g_ema.load_state_dict(checkpoint['g_ema']) disModel.load_state_dict(checkpoint['disModel']) optimizer.load_state_dict(checkpoint["optimizer"]) dis_optimizer.load_state_dict(checkpoint["dis_optimizer"]) # if opt.imgReconLoss == 'l1': # recCriterion = torch.nn.L1Loss() # elif opt.imgReconLoss == 'ssim': # recCriterion = ssim # elif opt.imgReconLoss == 'ms-ssim': # recCriterion = msssim # loss averager loss_avg = Averager() loss_avg_dis = Averager() loss_avg_gen = Averager() loss_avg_imgRecon = Averager() loss_avg_vgg_per = Averager() loss_avg_vgg_sty = Averager() loss_avg_ocr = Averager() log_r1_val = Averager() log_avg_path_loss_val = Averager() log_avg_mean_path_length_avg = Averager() log_ada_aug_p = Averager() """ final options """ with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file: opt_log = '------------ Options -------------\n' args = vars(opt) for k, v in args.items(): opt_log += f'{str(k)}: {str(v)}\n' opt_log += '---------------------------------------\n' print(opt_log) opt_file.write(opt_log) """ start training """ start_iter = 0 if opt.saved_synth_model != '' and opt.saved_synth_model != 'None': try: start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0]) print(f'continue to train, start_iter: {start_iter}') except: pass #get schedulers scheduler = get_scheduler(optimizer,opt) dis_scheduler = get_scheduler(dis_optimizer,opt) start_time = time.time() iteration = start_iter cntr=0 mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} accum = 0.5 ** (32 / (10 * 1000)) ada_augment = torch.tensor([0.0, 0.0], device=device) ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0 ada_aug_step = opt.ada_target / opt.ada_length r_t_stat = 0 sample_z = torch.randn(opt.n_sample, opt.latent, device=device) while(True): # print(cntr) # train part if opt.lr_policy !="None": scheduler.step() dis_scheduler.step() image_input_tensors, image_gt_tensors, labels_1, labels_2 = iter(train_loader).next() image_input_tensors = image_input_tensors.to(device) image_gt_tensors = image_gt_tensors.to(device) batch_size = image_input_tensors.size(0) requires_grad(genModel, False) # requires_grad(styleModel, False) # requires_grad(mixModel, False) requires_grad(disModel, True) text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length) text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length) #forward pass from style and word generator # style = styleModel(image_input_tensors).squeeze(2).squeeze(2) style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device) # scInput = mixModel(style,text_2) if 'CTC' in opt.Prediction: images_recon_2,_ = genModel(style, text_2, input_is_latent=opt.input_latent) else: images_recon_2,_ = genModel(style, text_2[:,1:-1], input_is_latent=opt.input_latent) #Domain discriminator: Dis update if opt.augment: image_gt_tensors_aug, _ = augment(image_gt_tensors, ada_aug_p) images_recon_2, _ = augment(images_recon_2, ada_aug_p) else: image_gt_tensors_aug = image_gt_tensors fake_pred = disModel(images_recon_2) real_pred = disModel(image_gt_tensors_aug) disCost = d_logistic_loss(real_pred, fake_pred) loss_dict["d"] = disCost*opt.disWeight loss_dict["real_score"] = 
real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        loss_avg_dis.add(disCost)

        disModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        if opt.augment and opt.augment_p == 0:
            ada_augment += torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device)
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()
                r_t_stat = pred_signs / n_pred
                if r_t_stat > opt.ada_target:
                    sign = 1
                else:
                    sign = -1
                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = cntr % opt.d_reg_every == 0
        if d_regularize:
            image_gt_tensors.requires_grad = True
            image_input_tensors.requires_grad = True
            cat_tensor = image_gt_tensors
            real_pred = disModel(cat_tensor)
            r1_loss = d_r1_loss(real_pred, cat_tensor)

            disModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()
            dis_optimizer.step()

        loss_dict["r1"] = r1_loss

        # [Style Encoder] + [Word Generator] update
        # was iter(train_loader).next(), which is Python 2 only
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))
        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, True)
        # requires_grad(styleModel, True)
        # requires_grad(mixModel, True)
        requires_grad(disModel, False)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        # scInput = mixModel(style, text_2)
        # images_recon_2, _ = genModel([scInput], input_is_latent=opt.input_latent)

        style = mixing_noise(batch_size, opt.latent, opt.mixing, device)

        if 'CTC' in opt.Prediction:
            images_recon_2, _ = genModel(style, text_2)
        else:
            images_recon_2, _ = genModel(style, text_2[:, 1:-1])

        if opt.augment:
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)

        fake_pred = disModel(images_recon_2)
        disGenCost = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = disGenCost

        # Adversarial loss
        # disGenCost = disModel.module.calc_gen_loss(torch.cat((images_recon_2, image_input_tensors), dim=1))

        # Input reconstruction loss
        # recCost = recCriterion(images_recon_2, image_gt_tensors)

        # vgg loss
        # vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)

        # ocr loss
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)

        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
            preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            preds_gt = ocrModel(image_gt_tensors, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            ocrCost = ocrCriterion(preds_recon, preds_gt)
        else:
            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False)
                # preds_o = preds_recon[:, :text_1.shape[1], :]
                preds_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_2, preds_size, length_2)

                # predict ocr recognition on generated images
                # preds_recon_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                _, preds_recon_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_recon_index.data, preds_size.data)

                # predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                # preds_s = preds_s[:, :text_1.shape[1] - 1, :]
                preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index.data, preds_s_size.data)

                # predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                # preds_sc = preds_sc[:, :text_2.shape[1] - 1, :]
                preds_sc_size = torch.IntTensor([preds_sc.size(1)] * batch_size)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index.data, preds_sc_size.data)
            else:
                preds_recon = ocrModel(images_recon_2, text_for_pred[:, :-1], is_train=False)  # align with Attention.forward
                target_2 = text_2[:, 1:]  # without [GO] symbol
                ocrCost = ocrCriterion(preds_recon.view(-1, preds_recon.shape[-1]), target_2.contiguous().view(-1))

                # predict ocr recognition on generated images
                _, preds_o_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_o_index, length_for_pred)
                for idx, pred in enumerate(labels_o_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_o_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                # predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index, length_for_pred)
                for idx, pred in enumerate(labels_s_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_s_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                # predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index, length_for_pred)
                for idx, pred in enumerate(labels_sc_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_sc_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

        # cost = opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.vggPerWeight*vggPerCost + opt.vggStyWeight*vggStyleCost + opt.ocrWeight*ocrCost
        cost = opt.disWeight * disGenCost + opt.ocrWeight * ocrCost

        # styleModel.zero_grad()
        genModel.zero_grad()
        # mixModel.zero_grad()
        disModel.zero_grad()
        # vggModel.zero_grad()
        ocrModel.zero_grad()

        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        g_regularize = cntr % opt.g_reg_every == 0
        if g_regularize:
            image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))
            image_input_tensors = image_input_tensors.to(device)
            image_gt_tensors = image_gt_tensors.to(device)
            batch_size = image_input_tensors.size(0)

            text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
            text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

            path_batch_size = max(1, batch_size // opt.path_batch_shrink)

            # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
            # scInput = mixModel(style, text_2)
            # images_recon_2, latents = genModel([scInput], input_is_latent=opt.input_latent, return_latents=True)

            style = mixing_noise(path_batch_size, opt.latent, opt.mixing, device)

            if 'CTC' in opt.Prediction:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size], return_latents=True)
            else:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size, 1:-1], return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                images_recon_2, latents, mean_path_length)

            genModel.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * images_recon_2[0, 0, 0, 0]

            weighted_path_loss.backward()
            optimizer.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        # Individual losses
        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_imgRecon.add(torch.tensor(0.0))
        loss_avg_vgg_per.add(torch.tensor(0.0))
        loss_avg_vgg_sty.add(torch.tensor(0.0))
        loss_avg_ocr.add(opt.ocrWeight * ocrCost)

        # was log_r1_val.add(loss_reduced["path"]), which logged the path loss under the R1 name
        log_r1_val.add(loss_reduced["r1"])
        log_avg_path_loss_val.add(loss_reduced["path"])
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))

        if get_rank() == 0:
            # pbar.set_description((
            #     f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
            #     f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
            #     f"augment: {ada_aug_p:.4f}"))

            if wandb and opt.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            # if cntr % 100 == 0:
            #     with torch.no_grad():
            #         g_ema.eval()
            #         sample, _ = g_ema([scInput[:, :opt.latent], scInput[:, opt.latent:]])
            #         utils.save_image(
            #             sample,
            #             os.path.join(opt.trainDir, f"sample_{str(cntr).zfill(6)}.png"),
            #             nrow=int(opt.n_sample ** 0.5),
            #             normalize=True,
            #             range=(-1, 1),
            #         )

            # validation part
            # to see training progress, we also conduct validation when 'iteration == 0'
            if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
                # save training images
                curr_batch_size = style[0].shape[0]
                images_recon_2, _ = g_ema(style, text_2[:curr_batch_size], input_is_latent=opt.input_latent)

                os.makedirs(os.path.join(opt.trainDir, str(iteration)), exist_ok=True)
                for trImgCntr in range(batch_size):
                    try:
                        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
                            save_image(
                                tensor2im(image_input_tensors[trImgCntr].detach()),
                                os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_sInput_' + labels_1[trImgCntr] + '.png'))
                            save_image(
                                tensor2im(image_gt_tensors[trImgCntr].detach()),
                                os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_csGT_' + labels_2[trImgCntr] + '.png'))
                            save_image(
                                tensor2im(images_recon_2[trImgCntr].detach()),
                                os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_csRecon_' + labels_2[trImgCntr] + '.png'))
                        else:
                            save_image(
                                tensor2im(image_input_tensors[trImgCntr].detach()),
                                os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_sInput_' + labels_1[trImgCntr] + '_' + labels_s_ocr[trImgCntr] + '.png'))
                            save_image(
                                tensor2im(image_gt_tensors[trImgCntr].detach()),
                                os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_csGT_' + labels_2[trImgCntr] + '_' + labels_sc_ocr[trImgCntr] + '.png'))
                            save_image(
                                tensor2im(images_recon_2[trImgCntr].detach()),
                                os.path.join(opt.trainDir, str(iteration), str(trImgCntr) + '_csRecon_' + labels_2[trImgCntr] + '_' + labels_o_ocr[trImgCntr] + '.png'))
                    except Exception:  # was a bare except
                        print('Warning while saving training image')

                elapsed_time = time.time() - start_time

                # for log
                with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'), 'a') as log:
                    # styleModel.eval()
                    genModel.eval()
                    g_ema.eval()
                    # mixModel.eval()
                    disModel.eval()

                    with torch.no_grad():
                        valid_loss, infer_time, length_of_data = validation_synth_v6(
                            iteration, g_ema, ocrModel, disModel, ocrCriterion,
                            valid_loader, converter, opt)

                    # styleModel.train()
                    genModel.train()
                    # mixModel.train()
                    disModel.train()

                    # training loss and validation loss
                    loss_log = (
                        f'[{iteration + 1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, '
                        f'Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f}, '
                        f'Train OCR loss: {loss_avg_ocr.val():0.5f}, '
                        f'Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, '
                        f'Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, '
                        f'Valid Synth loss: {valid_loss[0]:0.5f}, '
                        f'Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, '
                        f'Valid OCR loss: {valid_loss[6]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                    )

                    # plotting
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Synth-Loss'), loss_avg.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Dis-Loss'), loss_avg_dis.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Gen-Loss'), loss_avg_gen.val().item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Train-VGG-Per-Loss'), loss_avg_vgg_per.val().item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Train-VGG-Sty-Loss'), loss_avg_vgg_sty.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-OCR-Loss'), loss_avg_ocr.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-r1_val'), log_r1_val.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-ada_aug_p'), log_ada_aug_p.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Synth-Loss'), valid_loss[0].item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Dis-Loss'), valid_loss[1].item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Gen-Loss'), valid_loss[2].item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-ImgRecon1-Loss'), valid_loss[3].item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-VGG-Per-Loss'), valid_loss[4].item())
                    # lib.plot.plot(os.path.join(opt.plotDir, 'Valid-VGG-Sty-Loss'), valid_loss[5].item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Valid-OCR-Loss'), valid_loss[6].item())

                    print(loss_log)

                    loss_avg.reset()
                    loss_avg_dis.reset()
                    loss_avg_gen.reset()
                    loss_avg_imgRecon.reset()
                    loss_avg_vgg_per.reset()
                    loss_avg_vgg_sty.reset()
                    loss_avg_ocr.reset()
                    log_r1_val.reset()
                    log_avg_path_loss_val.reset()
                    log_avg_mean_path_length_avg.reset()
                    log_ada_aug_p.reset()

                lib.plot.flush()

            lib.plot.tick()

            # save model every 1e+4 iterations (the original comment said 1e+5)
            if (iteration) % 1e+4 == 0:
                torch.save({
                    # 'styleModel': styleModel.state_dict(),
                    # 'mixModel': mixModel.state_dict(),
                    'genModel': g_module.state_dict(),
                    'g_ema': g_ema.state_dict(),
                    'disModel': d_module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'dis_optimizer': dis_optimizer.state_dict()},
                    os.path.join(opt.exp_dir, opt.exp_name, 'iter_' + str(iteration + 1) + '_synth.pth'))

            if (iteration + 1) == opt.num_iter:
                print('end the training')
                sys.exit()

        iteration += 1
        cntr += 1
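# ---------------------------------------------------------------------------
# Editor's note (hedged sketch): the CTC branch above turns recognizer output of
# shape (N, T, C) into the (T, N, C) log-probabilities that nn.CTCLoss expects,
# via log_softmax(2).permute(1, 0, 2). The self-contained toy below illustrates
# that shape convention with hypothetical sizes; it does not use the repo's
# converter/ocrModel/ocrCriterion objects.
import torch
import torch.nn as nn

def ctc_shape_convention_demo():
    batch_size, T, n_classes, max_len = 4, 26, 37, 10      # assumed sizes
    preds = torch.randn(batch_size, T, n_classes)          # recognizer output, (N, T, C)
    log_probs = preds.log_softmax(2).permute(1, 0, 2)      # -> (T, N, C) as CTCLoss expects
    input_lengths = torch.full((batch_size,), T, dtype=torch.long)
    targets = torch.randint(1, n_classes, (batch_size, max_len))  # index 0 = CTC blank
    target_lengths = torch.full((batch_size,), max_len, dtype=torch.long)
    ctc = nn.CTCLoss(blank=0, zero_infinity=True)
    return ctc(log_probs, targets, input_lengths, target_lengths)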
def train(args, loader, loader2, generator, encoder, discriminator, discriminator2,
          vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema, e_ema, device):
    # kwargs_d = {'detach_aux': args.detach_d_aux_head}
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n")

    if args.dataset == 'imagefolder':
        loader = sample_data2(loader)
    else:
        loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d2_module = None
    # if discriminator2 is not None:
    #     if args.distributed:
    #         d2_module = discriminator2.module
    #     else:
    #         d2_module = discriminator2

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    r_t_dict = {'real': 0, 'fake': 0, 'recx': 0}  # r_t stat

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment2(args.ada_margin, args.ada_length, args.ada_every, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    sample_x1 = sample_x2 = sample_idx = fid_batch_idx = None
    if sample_x.ndim > 4:
        sample_x1 = sample_x[:, 0, ...]
        sample_x2 = sample_x[:, -1, ...]
        sample_x = sample_x[:, 0, ...]

    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(encoder, False)
        requires_grad(discriminator, True)
        for step_index in range(args.n_step_d):
            real_img = real_imgs[step_index]
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            if args.use_ema:
                g_ema.eval()
                fake_img, _ = g_ema(noise)
            else:
                fake_img, _ = generator(noise)
            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img)
            d_loss_fake = F.softplus(fake_pred).mean()
            d_loss_real = F.softplus(-real_pred).mean()
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            d_loss_rec = rec_pred = 0.
            if args.lambda_rec_d > 0 and not args.decouple_d:
                if args.use_ema:
                    e_ema.eval()
                    g_ema.eval()
                    latent_real, _ = e_ema(real_img)
                    rec_img, _ = g_ema([latent_real], input_is_latent=True)
                else:
                    latent_real, _ = encoder(real_img)
                    rec_img, _ = generator([latent_real], input_is_latent=True)
                rec_pred = discriminator(rec_img)
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["recx_score"] = rec_pred.mean()

            d_loss = d_loss_real + d_loss_fake * args.lambda_fake_d + d_loss_rec * args.lambda_rec_d
            loss_dict["d"] = d_loss

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(fake_pred, rec_pred)
            r_t_stat = ada_augment.r_t_stat
        r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch
        r_t_dict['fake'] = torch.sign(fake_pred).sum().item() / args.batch
        r_t_dict['recx'] = torch.sign(rec_pred).sum().item() / args.batch

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        # Train Encoder
        requires_grad(encoder, True)
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        # requires_grad(discriminator2, False)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        for step_index in range(args.n_step_e):
            real_img = real_imgs[step_index]
            latent_real, _ = encoder(real_img)
            if args.use_ema:
                g_ema.eval()
                rec_img, _ = g_ema([latent_real], input_is_latent=True)
            else:
                rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img) ** 2)
                elif args.pix_loss == 'l1':
                    pix_loss = F.l1_loss(rec_img, real_img)
                else:
                    raise NotImplementedError
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img)) ** 2)
            if args.lambda_adv > 0:
                rec_pred = discriminator(rec_img)
                adv_loss = g_nonsaturating_loss(rec_pred)

            e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
            loss_dict["e"] = e_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            encoder.zero_grad()
            generator.zero_grad()
            e_loss.backward()
            manually_scale_grad(generator, 1 - ada_aug_p)
            e_optim.step()
            g_optim.step()

        # Train Generator
        requires_grad(generator, True)
        requires_grad(encoder, False)
        requires_grad(discriminator, False)
        # requires_grad(discriminator2, False)
        real_img = real_imgs[0]
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        g_loss_fake = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = g_loss_fake
        generator.zero_grad()
        g_loss_fake.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        recx_score_val = loss_reduced["recx_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = generator([latent_x], input_is_latent=True, return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x) ** 2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; ref: {sample_pix_loss.item():.4f}; "
                        f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; r_stat: {r_t_stat:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; "
                        f"real_score: {real_score_val:.4f}; fake_score: {fake_score_val:.4f}; recx_score: {recx_score_val:.4f};\n"
                    ))

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    fid_sa = fid_re = fid_hy = 0

                    # Sample FID
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64,
                        args.n_sample_fid, args.device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid_sa = calc_fid(sample_mean, sample_cov, real_mean, real_cov)

                    # Recon FID
                    features = extract_feature_from_recon_hybrid(
                        e_ema, g_ema, inception, args.truncation, mean_latent,
                        loader2, args.device, mode='recon',
                    ).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid_re = calc_fid(sample_mean, sample_cov, real_mean, real_cov)

                # print("Sample FID:", fid_sa, "Recon FID:", fid_re, "Hybrid FID:", fid_hy)
                with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; sample fid: {float(fid_sa):.4f}; recon fid: {float(fid_re):.4f};\n"
                    )

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample ** 0.5)
                    nchw = list(sample_x.shape)[1:]

                    # Fixed fake samples
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

                    # Reconstruction samples
                    latent_real, _ = e_ema(sample_x)
                    fake_img, _ = g_ema([latent_real], input_is_latent=True, return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2": d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim": d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2": d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim": d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
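# ---------------------------------------------------------------------------
# Editor's note (hedged sketch): the eval block above reduces Inception
# activations to a (mean, covariance) pair and calls calc_fid against the
# pre-computed real statistics. calc_fid is assumed to implement the standard
# Frechet distance FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 (C1 C2)^(1/2)),
# roughly as sketched below (frechet_distance is an illustrative name):
import numpy as np
from scipy import linalg

def frechet_distance(mu1, cov1, mu2, cov2, eps=1e-6):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(cov1 @ cov2, disp=False)
    if not np.isfinite(covmean).all():
        # regularize when the covariance product is near-singular
        offset = np.eye(cov1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((cov1 + offset) @ (cov2 + offset), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # sqrtm may return tiny imaginary parts
    return float(diff @ diff + np.trace(cov1) + np.trace(cov2) - 2 * np.trace(covmean))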
def train(args, loader, encoder, generator, discriminator, discriminator_z, g1,
          vggnet, pwcnet, e_optim, d_optim, dz_optim, g1_optim, e_ema, e_tf, g1_ema, device):
    mmd_eval = functools.partial(mix_rbf_mmd2, sigma_list=[2.0, 5.0, 10.0, 20.0, 40.0, 80.0])

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {
        "d": torch.tensor(0., device=device),
        "real_score": torch.tensor(0., device=device),
        "fake_score": torch.tensor(0., device=device),
        "r1_d": torch.tensor(0., device=device),
        "r1_e": torch.tensor(0., device=device),
        "rec": torch.tensor(0., device=device),
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
        g1_module = g1.module if args.train_latent_mlp else None
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator
        g1_module = g1 if args.train_latent_mlp else None

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

    # sample_x = accumulate_batches(loader, args.n_sample).to(device)
    sample_x = load_real_samples(args, loader)

    requires_grad(generator, False)  # always False
    generator.eval()  # generator should be the EMA copy and stay in eval mode
    # if args.no_ema or e_ema is None:
    #     e_ema = encoder

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)
        batch = real_img.shape[0]

        # Train Encoder
        if args.toggle_grads:
            requires_grad(encoder, True)
            requires_grad(discriminator, False)
        pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0., device=device)
        kld_z = torch.tensor(0., device=device)
        mmd_z = torch.tensor(0., device=device)
        gan_z = torch.tensor(0., device=device)
        etf_z = torch.tensor(0., device=device)

        latent_real, logvar = encoder(real_img)
        if args.reparameterization:
            latent_real = reparameterize(latent_real, logvar)
        if args.train_latent_mlp:
            fake_img, _ = generator([g1(latent_real)], input_is_latent=True, return_latents=False)
        else:
            fake_img, _ = generator([latent_real], input_is_latent=False, return_latents=False)

        if args.lambda_adv > 0:
            if args.augment:
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                fake_img_aug = fake_img
            fake_pred = discriminator(fake_img_aug)
            adv_loss = g_nonsaturating_loss(fake_pred)

        if args.lambda_pix > 0:
            pix_loss = torch.mean((real_img - fake_img) ** 2)

        if args.lambda_vgg > 0:
            real_feat = vggnet(real_img)
            fake_feat = vggnet(fake_img)
            vgg_loss = torch.mean((real_feat - fake_feat) ** 2)

        if args.lambda_kld_z > 0:
            z_mean = latent_real.view(batch, -1)
            kld_z = -0.5 * torch.sum(1. + logvar - z_mean.pow(2) - logvar.exp()) / batch
            # print(kld_z)

        if args.lambda_mmd_z > 0:
            z_real = torch.randn(batch, args.latent_full, device=device)
            mmd_z = mmd_eval(latent_real, z_real)
            # print(mmd_z)

        if args.lambda_gan_z > 0:
            fake_pred = discriminator_z(latent_real)
            gan_z = g_nonsaturating_loss(fake_pred)
            # print(gan_z)

        if args.use_latent_teacher_forcing and args.lambda_etf > 0:
            w_tf, _ = e_tf(real_img)
            if args.train_latent_mlp:
                w_pred = g1(latent_real)
            else:
                w_pred = generator.get_latent(latent_real)
            etf_z = torch.mean((w_tf - w_pred) ** 2)
            # print(etf_z)

        if args.train_on_fake and args.lambda_rec > 0:
            z_real = torch.randn(args.batch, args.latent_full, device=device)
            if args.train_latent_mlp:
                fake_img, _ = generator([g1(z_real)], input_is_latent=True, return_latents=False)
            else:
                fake_img, _ = generator([z_real], input_is_latent=False, return_latents=False)
            # fake_img, _ = generator([z_real], input_is_latent=False, return_latents=True)
            z_fake, z_logvar = encoder(fake_img)
            if args.reparameterization:
                z_fake = reparameterize(z_fake, z_logvar)
            rec_loss = torch.mean((z_real - z_fake) ** 2)
            loss_dict["rec"] = rec_loss
            # print(rec_loss)

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        e_loss = (e_loss + args.lambda_kld_z * kld_z + args.lambda_mmd_z * mmd_z
                  + args.lambda_gan_z * gan_z + args.lambda_etf * etf_z + rec_loss * args.lambda_rec)
        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss

        if args.train_latent_mlp and g1 is not None:
            g1.zero_grad()
        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()
        if args.train_latent_mlp and g1_optim is not None:
            g1_optim.step()

        # if args.train_on_fake:
        #     e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0
        #     if e_regularize and args.lambda_rec > 0:
        #         # noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        #         # fake_img, latent_fake = generator(noise, input_is_latent=False, return_latents=True)
        #         z_real = torch.randn(args.batch, args.latent_full, device=device)
        #         fake_img, w_real = generator([z_real], input_is_latent=False, return_latents=True)
        #         z_fake, logvar = encoder(fake_img)
        #         if args.reparameterization:
        #             z_fake = reparameterize(z_fake, logvar)
        #         rec_loss = torch.mean((z_real - z_fake) ** 2)
        #         encoder.zero_grad()
        #         (rec_loss * args.lambda_rec).backward()
        #         e_optim.step()
        #         loss_dict["rec"] = rec_loss

        e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0
        if e_regularize:
            # why not regularize on augmented real?
            real_img.requires_grad = True
            real_pred, logvar = encoder(real_img)
            if args.reparameterization:
                real_pred = reparameterize(real_pred, logvar)
            r1_loss_e = d_r1_loss(real_pred, real_img)

            encoder.zero_grad()
            (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward()
            e_optim.step()
            loss_dict["r1_e"] = r1_loss_e

        if not args.no_ema and e_ema is not None:
            accumulate(e_ema, e_module, accum)
            if args.train_latent_mlp:
                accumulate(g1_ema, g1_module, accum)

        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
        if not args.no_update_discriminator and args.lambda_adv > 0:
            latent_real, logvar = encoder(real_img)
            if args.reparameterization:
                latent_real = reparameterize(latent_real, logvar)
            if args.train_latent_mlp:
                fake_img, _ = generator([g1(latent_real)], input_is_latent=True, return_latents=False)
            else:
                fake_img, _ = generator([latent_real], input_is_latent=False, return_latents=False)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
                fake_img_aug = fake_img

            fake_pred = discriminator(fake_img_aug)
            real_pred = discriminator(real_img_aug)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

            # latent discriminator update
            z_real = torch.randn(batch, args.latent_full, device=device)
            fake_pred = discriminator_z(latent_real.detach())
            real_pred = discriminator_z(z_real)
            d_loss_z = d_logistic_loss(real_pred, fake_pred)
            discriminator_z.zero_grad()
            d_loss_z.backward()
            dz_optim.step()

            if args.augment and args.augment_p == 0:
                # NOTE: at this point real_pred comes from discriminator_z, so ADA is
                # tuned on the latent discriminator's output (kept as in the source)
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat

            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                # why not regularize on augmented real?
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss_d = d_r1_loss(real_pred, real_img)

                discriminator.zero_grad()
                # Why the 0* term? Answer is here:
                # https://github.com/rosinality/stylegan2-pytorch/issues/76
                (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward()
                d_optim.step()
                loss_dict["r1_d"] = r1_loss_d

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        r1_e_val = loss_reduced["r1_e"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    if args.train_latent_mlp:
                        g1_ema.eval()
                        fake_x, _ = generator([g1_ema(latent_x)], input_is_latent=True, return_latents=False)
                    else:
                        fake_x, _ = generator([latent_x], input_is_latent=False, return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x) ** 2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                            f"ref: {sample_pix_loss.item()};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Encoder": e_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1 D": r1_d_val,
                    "R1 E": r1_e_val,
                    "Pix Loss": pix_loss_val,
                    "VGG Loss": vgg_loss_val,
                    "Adv Loss": adv_loss_val,
                    "Rec Loss": rec_loss_val,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = encoder if args.no_ema else e_ema
                    e_eval.eval()
                    nrow = int(args.n_sample ** 0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_eval(sample_x)
                    if args.train_latent_mlp:
                        g1_ema.eval()
                        fake_img, _ = generator([g1_ema(latent_real)], input_is_latent=True, return_latents=False)
                    else:
                        fake_img, _ = generator([latent_real], input_is_latent=False, return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    e_eval.train()

            if i % args.save_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g1": g1_module.state_dict() if args.train_latent_mlp else None,
                        "g1_ema": g1_ema.state_dict() if args.train_latent_mlp else None,
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                # note: relies on e_eval having been set by the save_every branch above
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g1": g1_module.state_dict() if args.train_latent_mlp else None,
                        "g1_ema": g1_ema.state_dict() if args.train_latent_mlp else None,
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
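# ---------------------------------------------------------------------------
# Editor's note (hedged sketch): every loop in this file applies lazy R1
# regularization as (args.r1 / 2 * r1_loss * reg_every + 0 * real_pred[...]).backward().
# d_r1_loss is assumed to follow the common StyleGAN2 port: the squared norm of
# the critic's gradient w.r.t. the real batch. The `0 * real_pred` term keeps
# the output in the autograd graph so DistributedDataParallel sees a gradient
# for every parameter on regularization steps (see the issue linked above).
import torch
from torch import autograd

def d_r1_loss(real_pred, real_img):
    # create_graph=True so the penalty itself can be backpropagated
    (grad_real,) = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()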
def train(args, loader, generator, discriminator, extra, g_optim, d_optim, e_optim,
          g_ema, device, g_source, d_source):
    loader = sample_data(loader)

    imsave_path = os.path.join('samples', args.exp)
    model_path = os.path.join('checkpoints', args.exp)
    if not os.path.exists(imsave_path):
        os.makedirs(imsave_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # this defines the anchor points; when sampling noise close to these, we impose
    # the image-level adversarial loss (Eq. 4 in the paper)
    init_z = torch.randn(args.n_train, args.latent, device=device)

    pbar = range(args.iter)
    sfm = nn.Softmax(dim=1)
    kl_loss = nn.KLDivLoss()
    sim = nn.CosineSimilarity()

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    g_module = generator
    d_module = discriminator
    g_ema_module = g_ema.module

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    # this defines which level feature of the discriminator is used to implement the
    # patch-level adversarial loss: could be anything in [0, args.highp]
    lowp, highp = 0, args.highp

    # the constant noise used for generating images at different stages of training
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    requires_grad(g_source, False)
    requires_grad(d_source, False)

    sub_region_z = get_subspace(args, init_z.clone(), vis_flag=True)

    for idx in pbar:
        i = idx + args.start_iter
        # defines whether we sample from the anchor region in this iteration or not
        which = i % args.subspace_freq

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)
        requires_grad(extra, True)

        if which > 0:
            # sample normally, apply patch-level adversarial loss
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        else:
            # sample from anchors, apply image-level adversarial loss
            noise = [get_subspace(args, init_z.clone())]

        fake_img, _ = generator(noise)

        if args.augment:
            real_img, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred, _ = discriminator(
            fake_img, extra=extra, flag=which, p_ind=np.random.randint(lowp, highp))
        real_pred, _ = discriminator(
            real_img, extra=extra, flag=which, p_ind=np.random.randint(lowp, highp), real=True)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        extra.zero_grad()
        d_loss.backward()
        d_optim.step()
        e_optim.step()

        if args.augment and args.augment_p == 0:
            ada_augment += torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device)
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()
                r_t_stat = pred_signs / n_pred
                if r_t_stat > args.ada_target:
                    sign = 1
                else:
                    sign = -1
                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred, _ = discriminator(
                real_img, extra=extra, flag=which, p_ind=np.random.randint(lowp, highp))
            real_pred = real_pred.view(real_img.size(0), -1)
            real_pred = real_pred.mean(dim=1).unsqueeze(1)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            extra.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()
            e_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)
        requires_grad(extra, False)

        if which > 0:
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        else:
            noise = [get_subspace(args, init_z.clone())]

        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred, _ = discriminator(
            fake_img, extra=extra, flag=which, p_ind=np.random.randint(lowp, highp))
        g_loss = g_nonsaturating_loss(fake_pred)

        # distance consistency loss
        with torch.set_grad_enabled(False):
            z = torch.randn(args.feat_const_batch, args.latent, device=device)
            # was numpy.random.randint; normalized to the np alias used elsewhere
            feat_ind = np.random.randint(1, g_source.module.n_latent - 1, size=args.feat_const_batch)

            # computing source distances
            source_sample, feat_source = g_source([z], return_feats=True)
            dist_source = torch.zeros(
                [args.feat_const_batch, args.feat_const_batch - 1]).cuda()

            # iterating over different elements in the batch
            for pair1 in range(args.feat_const_batch):
                tmpc = 0
                # comparing the possible pairs
                for pair2 in range(args.feat_const_batch):
                    if pair1 != pair2:
                        anchor_feat = torch.unsqueeze(
                            feat_source[feat_ind[pair1]][pair1].reshape(-1), 0)
                        compare_feat = torch.unsqueeze(
                            feat_source[feat_ind[pair1]][pair2].reshape(-1), 0)
                        dist_source[pair1, tmpc] = sim(anchor_feat, compare_feat)
                        tmpc += 1
            dist_source = sfm(dist_source)

        # computing distances among target generations
        _, feat_target = generator([z], return_feats=True)
        dist_target = torch.zeros(
            [args.feat_const_batch, args.feat_const_batch - 1]).cuda()

        # iterating over different elements in the batch
        for pair1 in range(args.feat_const_batch):
            tmpc = 0
            for pair2 in range(args.feat_const_batch):  # comparing the possible pairs
                if pair1 != pair2:
                    anchor_feat = torch.unsqueeze(
                        feat_target[feat_ind[pair1]][pair1].reshape(-1), 0)
                    compare_feat = torch.unsqueeze(
                        feat_target[feat_ind[pair1]][pair2].reshape(-1), 0)
                    dist_target[pair1, tmpc] = sim(anchor_feat, compare_feat)
                    tmpc += 1
        dist_target = sfm(dist_target)

        rel_loss = args.kl_wt * kl_loss(torch.log(dist_target), dist_source)  # distance consistency loss
        g_loss = g_loss + rel_loss

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        # to save up space
        del rel_loss, g_loss, d_loss, fake_img, fake_pred, real_img, real_pred, \
            anchor_feat, compare_feat, dist_source, dist_target, feat_source, feat_target

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema_module, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.img_freq == 0:
                with torch.set_grad_enabled(False):
                    g_ema.eval()
                    sample, _ = g_ema([sample_z.data])
                    sample_subz, _ = g_ema([sub_region_z.data])
                    utils.save_image(
                        sample,
                        f"%s/{str(i).zfill(6)}.png" % (imsave_path),
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )
                    del sample

            if (i % args.save_freq == 0) and (i > 0):
                torch.save(
                    {
                        "g_ema": g_ema.state_dict(),
                        # uncomment the following lines only if you wish to resume training
                        # after saving; saving just the generator is sufficient for evaluations
                        # "g": g_module.state_dict(),
                        # "g_s": g_source.state_dict(),
                        # "d": d_module.state_dict(),
                        # "g_optim": g_optim.state_dict(),
                        # "d_optim": d_optim.state_dict(),
                    },
                    f"%s/{str(i).zfill(6)}.pt" % (model_path),
                )
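# ---------------------------------------------------------------------------
# Editor's note (hedged sketch): the two nested pair-loops above build, for each
# anchor image, a softmax over its cosine similarities to every other batch
# element, then take the KL divergence between the source and target
# distributions. A vectorized equivalent for a single feature level
# (pairwise_similarity_rows is a hypothetical helper; the loop version above
# additionally picks a random feature level per sample):
import torch
import torch.nn.functional as F

def pairwise_similarity_rows(feats):
    # feats: (B, ...) activations for one feature level
    b = feats.size(0)
    normed = F.normalize(feats.reshape(b, -1), dim=1)
    simmat = normed @ normed.t()                   # (B, B) cosine similarities
    mask = ~torch.eye(b, dtype=torch.bool, device=feats.device)
    offdiag = simmat[mask].reshape(b, b - 1)       # drop self-similarity
    return F.softmax(offdiag, dim=1)

# rel_loss would then correspond to:
#   kl_wt * F.kl_div(pairwise_similarity_rows(feat_target).log(),
#                    pairwise_similarity_rows(feat_source))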
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    # probe which parameters receive no gradient from each regularizer, so their
    # grads can be zeroed explicitly with set_grad_none on regularization steps
    none_g_grads = set()
    test_in = torch.randn(1, args.latent, device=device)
    fake, latent = g_module([test_in], return_latents=True)
    path = g_path_regularize(fake, latent, 0)
    path[0].backward()
    for n, p in generator.named_parameters():
        if p.grad is None:
            none_g_grads.add(n)

    test_in = torch.randn(1, 3, args.size, args.size, requires_grad=True, device=device)
    pred = d_module(test_in)
    r1_loss = d_r1_loss(pred, test_in)
    r1_loss.backward()
    none_d_grads = set()
    for n, p in discriminator.named_parameters():
        if p.grad is None:
            none_d_grads.add(n)

    sample_z = torch.randn(2 * 2, args.latent, device=device)

    for i in pbar:
        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss
        loss_dict['real_score'] = real_pred.mean()
        loss_dict['fake_score'] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            set_grad_none(discriminator, none_d_grads)
            d_optim.step()

        loss_dict['r1'] = r1_loss

        requires_grad(generator.proj, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        noise_proj_loss = sum([(generator.proj(noise_i) - noise_i).abs().sum() for noise_i in noise])
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)
        print(noise_proj_loss.item())  # debug print kept from the source

        loss_dict['g'] = g_loss

        generator.zero_grad()
        (g_loss + noise_proj_loss).backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            noise = mixing_noise(
                args.batch // args.path_batch_shrink, args.latent, args.mixing, device
            )
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            set_grad_none(g_module, none_g_grads)
            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict['path'] = path_loss
        loss_dict['path_length'] = path_lengths.mean()

        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced['d'].mean().item()
        g_loss_val = loss_reduced['g'].mean().item()
        r1_val = loss_reduced['r1'].mean().item()
        path_loss_val = loss_reduced['path'].mean().item()
        real_score_val = loss_reduced['real_score'].mean().item()
        fake_score_val = loss_reduced['fake_score'].mean().item()
        path_length_val = loss_reduced['path_length'].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; '
                f'path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}'
            ))

            if wandb and args.wandb:
                wandb.log({
                    'Generator': g_loss_val,
                    'Discriminator': d_loss_val,
                    'R1': r1_val,
                    'Path Length Regularization': path_loss_val,
                    'Mean Path Length': mean_path_length,
                    'Real Score': real_score_val,
                    'Fake Score': fake_score_val,
                    'Path Length': path_length_val,
                })

            if i % 10000 == 0:
                torch.save(
                    {
                        'g': g_module.state_dict(),
                        'd': d_module.state_dict(),
                        'g_ema': g_ema.state_dict(),
                        'g_optim': g_optim.state_dict(),
                        'd_optim': d_optim.state_dict(),
                    },
                    f'checkpoint/{str(i).zfill(6)}.pt',
                )

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f'sample/{str(i).zfill(6)}.png',
                        nrow=2,
                        normalize=True,
                        range=(-1, 1),
                    )
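# ---------------------------------------------------------------------------
# Editor's note (hedged sketch): the g_regularize branches above all call
# g_path_regularize(fake_img, latents, mean_path_length), which is assumed to
# follow the common StyleGAN2 port: estimate the Jacobian norm of the output
# w.r.t. the latents via a noise-weighted sum, and penalize deviation from an
# exponential moving average of the path length.
import math
import torch
from torch import autograd

def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3]
    )
    (grad,) = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    # running mean of the path length, updated with the given decay
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths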
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    # accum = 0.5 ** (32 / (10 * 1000))  # replaced by the EMA rampup computed per iteration below
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device)

    args.n_sheets = int(np.ceil(args.n_classes / args.n_class_per_sheet))
    args.n_sample_per_sheet = args.n_sample_per_class * args.n_class_per_sheet
    args.n_sample = args.n_sample_per_sheet * args.n_sheets
    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_y = torch.arange(args.n_classes).repeat(args.n_sample_per_class, 1).t().reshape(-1).to(device)
    if args.n_sample > args.n_sample_per_class * args.n_classes:
        sample_y1 = make_fake_label(args.n_sample - args.n_sample_per_class * args.n_classes, args.n_classes, device)
        sample_y = torch.cat([sample_y, sample_y1], 0)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(discriminator, True)
        for step_index in range(args.n_step_d):
            real_img, real_labels = next(loader)
            real_img, real_labels = real_img.to(device), real_labels.to(device)
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_labels = make_fake_label(args.batch, args.n_classes, device)
            fake_img, _ = generator(noise, fake_labels)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img

            fake_pred = discriminator(fake_img, fake_labels)
            real_pred = discriminator(real_img_aug, real_labels)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(real_img_aug, real_labels)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_labels = make_fake_label(args.batch, args.n_classes, device)
        fake_img, _ = generator(noise, fake_labels)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img, fake_labels)
        g_loss = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            # label batch must match the shrunk noise batch (the original passed args.batch)
            fake_labels = make_fake_label(path_batch_size, args.n_classes, device)
            fake_img, latents = generator(noise, fake_labels, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update G_ema: G_ema = G * (1 - ema_beta) + G_ema * ema_beta
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5 ** (args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"
            ))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f};\n"
                    ))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    for sheet_index in range(args.n_sheets):
                        sample_z_sheet = sample_z[sheet_index * args.n_sample_per_sheet:(sheet_index + 1) * args.n_sample_per_sheet]
                        sample_y_sheet = sample_y[sheet_index * args.n_sample_per_sheet:(sheet_index + 1) * args.n_sample_per_sheet]
                        sample, _ = g_ema([sample_z_sheet], sample_y_sheet)
                        utils.save_image(
                            sample,
                            os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}_{sheet_index}.png"),
                            nrow=args.n_sample_per_class,
                            normalize=True,
                            value_range=(-1, 1),
                        )

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64,
                        args.n_sample_fid, args.device, n_classes=args.n_classes,
                    ).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                # print("fid:", fid)
                with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                    f.write(f"{i:07d}; fid: {float(fid):.4f};\n")

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
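# ---------------------------------------------------------------------------
# Editor's note (hedged sketch): this last loop replaces the fixed
# accum = 0.5 ** (32 / (10 * 1000)) used elsewhere with a per-iteration decay
# derived from args.ema_kimg: 0.5 ** (batch / ema_nimg) gives the EMA a
# half-life of ema_nimg images, optionally ramped up early in training.
# accumulate() itself is assumed to be the usual parameter-space EMA:
import torch

def accumulate(model_ema, model, decay=0.999):
    params_ema = dict(model_ema.named_parameters())
    params = dict(model.named_parameters())
    for name in params_ema.keys():
        # EMA update: p_ema <- decay * p_ema + (1 - decay) * p
        params_ema[name].data.mul_(decay).add_(params[name].data, alpha=1 - decay)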