def validation_epoch_end(self, outputs):
    batch = outputs[0]["batch"]
    gc.collect()
    th.cuda.empty_cache()

    val_fid = validation.fid(
        self.g_ema.to(batch.device),
        self.val_batch_size,
        self.fid_n_sample,
        self.fid_truncation,
        self.name,
    )["FID"]
    val_ppl = validation.ppl(
        self.g_ema.to(batch.device),
        self.val_batch_size,
        self.ppl_n_sample,
        self.ppl_space,
        self.ppl_crop,
        self.latent_size,
    )

    with th.no_grad():
        self.g_ema.eval()
        sample, _ = self.g_ema([self.sample_z.to(next(self.g_ema.parameters()).device)])
        grid = tv.utils.make_grid(
            sample, nrow=int(round(4.0 / 3 * self.n_sample ** 0.5)), normalize=True, range=(-1, 1)
        )
        self.logger.experiment.log(
            {"Generated Images EMA": [wandb.Image(grid, caption=f"Step {self.global_step}")]}
        )

        self.generator.eval()
        sample, _ = self.generator([self.sample_z.to(next(self.generator.parameters()).device)])
        grid = tv.utils.make_grid(
            sample, nrow=int(round(4.0 / 3 * self.n_sample ** 0.5)), normalize=True, range=(-1, 1)
        )
        self.logger.experiment.log(
            {"Generated Images": [wandb.Image(grid, caption=f"Step {self.global_step}")]}
        )
        self.generator.train()

    # val_fid = [score for score in outputs[0]["FID"] if score != -69][0]
    # val_ppl = [score for score in outputs[0]["PPL"] if score != -69][0]

    gc.collect()
    th.cuda.empty_cache()

    return {"val_loss": val_fid, "log": {"Validation/FID": val_fid, "Validation/PPL": val_ppl}}
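# `validation_epoch_end` above only reads `outputs[0]["batch"]` to recover the device of the
# validation data, so it assumes a `validation_step` that passes its batch through. The minimal
# sketch below is an assumption shown for context, not necessarily the repo's actual hook;
# the real one may also compute and return per-batch metrics.
def validation_step(self, batch, batch_idx):
    # FID and PPL are computed once per epoch in validation_epoch_end, so nothing is computed here
    return {"batch": batch}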
def train(args, loader, generator, discriminator, contrast_learner, augment, g_optim, d_optim, scaler, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = th.zeros(size=(1,), device=device)
    g_loss_val = 0
    path_loss = th.zeros(size=(1,), device=device)
    path_lengths = th.zeros(size=(1,), device=device)
    loss_dict = {}
    mse = th.nn.MSELoss()

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
        if contrast_learner is not None:
            cl_module = contrast_learner.module
    else:
        g_module = generator
        d_module = discriminator
        cl_module = contrast_learner

    sample_z = th.randn(args.n_sample, args.latent_size, device=device)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        # train discriminator
        requires_grad(generator, False)
        requires_grad(discriminator, True)
        discriminator.zero_grad()

        loss_dict["d"], loss_dict["real_score"], loss_dict["fake_score"] = 0, 0, 0
        loss_dict["cl_reg"], loss_dict["bc_reg"] = (
            th.tensor(0, device=device).float(),
            th.tensor(0, device=device).float(),
        )

        for _ in range(args.num_accumulate):
            # sample = []
            # for _ in range(0, len(sample_z), args.batch_size):
            #     subsample = next(loader)
            #     sample.append(subsample)
            # sample = th.cat(sample)
            # utils.save_image(sample, "reals-no-augment.png", nrow=10, normalize=True)
            # utils.save_image(augment(sample), "reals-augment.png", nrow=10, normalize=True)

            real_img = next(loader)
            real_img = real_img.to(device)

            # with th.cuda.amp.autocast():
            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob, device)
            fake_img, _ = generator(noise)

            if args.augment_D:
                fake_pred = discriminator(augment(fake_img))
                real_pred = discriminator(augment(real_img))
            else:
                fake_pred = discriminator(fake_img)
                real_pred = discriminator(real_img)

            # logistic loss
            real_loss = F.softplus(-real_pred)
            fake_loss = F.softplus(fake_pred)
            d_loss = real_loss.mean() + fake_loss.mean()

            loss_dict["d"] += d_loss.detach()
            loss_dict["real_score"] += real_pred.mean().detach()
            loss_dict["fake_score"] += fake_pred.mean().detach()

            if i > 10000 or i == 0:
                if args.contrastive > 0:
                    contrast_learner(fake_img.clone().detach(), accumulate=True)
                    contrast_learner(real_img, accumulate=True)
                    contrast_loss = cl_module.calculate_loss()
                    loss_dict["cl_reg"] += contrast_loss.detach()
                    d_loss += args.contrastive * contrast_loss

                if args.balanced_consistency > 0:
                    aug_fake_pred = discriminator(augment(fake_img.clone().detach()))
                    aug_real_pred = discriminator(augment(real_img))
                    consistency_loss = mse(real_pred, aug_real_pred) + mse(fake_pred, aug_fake_pred)
                    loss_dict["bc_reg"] += consistency_loss.detach()
                    d_loss += args.balanced_consistency * consistency_loss

            d_loss /= args.num_accumulate
            # scaler.scale(d_loss).backward()
            d_loss.backward()

        # scaler.step(d_optim)
        d_optim.step()

        # R1 regularization
        if args.r1 > 0 and i % args.d_reg_every == 0:
            discriminator.zero_grad()
            loss_dict["r1"] = 0

            for _ in range(args.num_accumulate):
                real_img = next(loader)
                real_img = real_img.to(device)
                real_img.requires_grad = True

                # with th.cuda.amp.autocast():
                # if args.augment_D:
                #     real_pred = discriminator(
                #         augment(real_img)
                #     )  # RuntimeError: derivative for grid_sampler_2d_backward is not implemented :(
                # else:
                real_pred = discriminator(real_img)
                real_pred_sum = real_pred.sum()

                (grad_real,) = th.autograd.grad(outputs=real_pred_sum, inputs=real_img, create_graph=True)
                # (grad_real,) = th.autograd.grad(outputs=scaler.scale(real_pred_sum), inputs=real_img, create_graph=True)
                # grad_real = grad_real * (1.0 / scaler.get_scale())

                # with th.cuda.amp.autocast():
                r1_loss = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
                weighted_r1_loss = args.r1 / 2.0 * r1_loss * args.d_reg_every + 0 * real_pred[0]

                loss_dict["r1"] += r1_loss.detach()

                weighted_r1_loss /= args.num_accumulate
                # scaler.scale(weighted_r1_loss).backward()
                weighted_r1_loss.backward()

            # scaler.step(d_optim)
            d_optim.step()

        # train generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        generator.zero_grad()

        loss_dict["g"] = 0
        for _ in range(args.num_accumulate):
            # with th.cuda.amp.autocast():
            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob, device)
            fake_img, _ = generator(noise)

            if args.augment_G:
                fake_img = augment(fake_img)

            fake_pred = discriminator(fake_img)

            # non-saturating loss
            g_loss = F.softplus(-fake_pred).mean()
            loss_dict["g"] += g_loss.detach()

            g_loss /= args.num_accumulate
            # scaler.scale(g_loss).backward()
            g_loss.backward()

        # scaler.step(g_optim)
        g_optim.step()

        # path length regularization
        if args.path_regularize > 0 and i % args.g_reg_every == 0:
            generator.zero_grad()
            loss_dict["path"], loss_dict["path_length"] = 0, 0

            for _ in range(args.num_accumulate):
                path_batch_size = max(1, args.batch_size // args.path_batch_shrink)

                # with th.cuda.amp.autocast():
                noise = make_noise(path_batch_size, args.latent_size, args.mixing_prob, device)
                fake_img, latents = generator(noise, return_latents=True)

                img_noise = th.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
                noisy_img_sum = (fake_img * img_noise).sum()

                (grad,) = th.autograd.grad(outputs=noisy_img_sum, inputs=latents, create_graph=True)
                # (grad,) = th.autograd.grad(outputs=scaler.scale(noisy_img_sum), inputs=latents, create_graph=True)
                # grad = grad * (1.0 / scaler.get_scale())

                # with th.cuda.amp.autocast():
                path_lengths = th.sqrt(grad.pow(2).sum(2).mean(1))
                path_mean = mean_path_length + 0.01 * (path_lengths.mean() - mean_path_length)
                path_loss = (path_lengths - path_mean).pow(2).mean()
                mean_path_length = path_mean.detach()

                loss_dict["path"] += path_loss.detach()
                loss_dict["path_length"] += path_lengths.mean().detach()

                weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
                if args.path_batch_shrink:
                    weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

                weighted_path_loss /= args.num_accumulate
                # scaler.scale(weighted_path_loss).backward()
                weighted_path_loss.backward()

            # scaler.step(g_optim)
            g_optim.step()

        # scaler.update()

        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item() / args.num_accumulate
        g_loss_val = loss_reduced["g"].mean().item() / args.num_accumulate
        cl_reg_val = loss_reduced["cl_reg"].mean().item() / args.num_accumulate
        bc_reg_val = loss_reduced["bc_reg"].mean().item() / args.num_accumulate
        r1_val = loss_reduced["r1"].mean().item() / args.num_accumulate
        path_loss_val = loss_reduced["path"].mean().item() / args.num_accumulate
        real_score_val = loss_reduced["real_score"].mean().item() / args.num_accumulate
        fake_score_val = loss_reduced["fake_score"].mean().item() / args.num_accumulate
        path_length_val = loss_reduced["path_length"].mean().item() / args.num_accumulate

        if get_rank() == 0:
            log_dict = {
                "Generator": g_loss_val,
                "Discriminator": d_loss_val,
                "Real Score": real_score_val,
                "Fake Score": fake_score_val,
                "Contrastive": cl_reg_val,
                "Consistency": bc_reg_val,
            }

            if args.log_spec_norm:
                G_norms = []
                for name, spec_norm in g_module.named_buffers():
                    if "spectral_norm" in name:
                        G_norms.append(spec_norm.cpu().numpy())
                G_norms = np.array(G_norms)

                D_norms = []
                for name, spec_norm in d_module.named_buffers():
                    if "spectral_norm" in name:
                        D_norms.append(spec_norm.cpu().numpy())
                D_norms = np.array(D_norms)

                log_dict["Spectral Norms/G min spectral norm"] = np.log(G_norms).min()
                log_dict["Spectral Norms/G mean spectral norm"] = np.log(G_norms).mean()
                log_dict["Spectral Norms/G max spectral norm"] = np.log(G_norms).max()
                log_dict["Spectral Norms/D min spectral norm"] = np.log(D_norms).min()
                log_dict["Spectral Norms/D mean spectral norm"] = np.log(D_norms).mean()
                log_dict["Spectral Norms/D max spectral norm"] = np.log(D_norms).max()

            if args.r1 > 0 and i % args.d_reg_every == 0:
                log_dict["R1"] = r1_val

            if args.path_regularize > 0 and i % args.g_reg_every == 0:
                log_dict["Path Length Regularization"] = path_loss_val
                log_dict["Mean Path Length"] = mean_path_length
                log_dict["Path Length"] = path_length_val

            if i % args.img_every == 0:
                gc.collect()
                th.cuda.empty_cache()
                with th.no_grad():
                    g_ema.eval()
                    sample = []
                    for sub in range(0, len(sample_z), args.batch_size):
                        subsample, _ = g_ema([sample_z[sub : sub + args.batch_size]])
                        sample.append(subsample.cpu())
                    sample = th.cat(sample)
                    grid = utils.make_grid(sample, nrow=10, normalize=True, range=(-1, 1))
                    # utils.save_image(sample, "fakes-no-augment.png", nrow=10, normalize=True)
                    # utils.save_image(augment(sample), "fakes-augment.png", nrow=10, normalize=True)
                    # exit()
                    log_dict["Generated Images EMA"] = [wandb.Image(grid, caption=f"Step {i}")]

            if i % args.eval_every == 0:
                start_time = time.time()

                pbar.set_description("Calculating FID...")
                fid_dict = validation.fid(g_ema, args.val_batch_size, args.fid_n_sample, args.fid_truncation, args.name)
                fid = fid_dict["FID"]
                density = fid_dict["Density"]
                coverage = fid_dict["Coverage"]

                pbar.set_description("Calculating PPL...")
                ppl = validation.ppl(
                    g_ema,
                    args.val_batch_size,
                    args.ppl_n_sample,
                    args.ppl_space,
                    args.ppl_crop,
                    args.latent_size,
                )

                pbar.set_description(
                    f"FID: {fid:.4f}; Density: {density:.4f}; Coverage: {coverage:.4f}; PPL: {ppl:.4f} in {time.time() - start_time:.1f}s"
                )

                log_dict["Evaluation/FID"] = fid
                log_dict["Evaluation/Density"] = density
                log_dict["Evaluation/Coverage"] = coverage
                log_dict["Evaluation/PPL"] = ppl

                gc.collect()
                th.cuda.empty_cache()

            wandb.log(log_dict)

            if i % args.checkpoint_every == 0:
                th.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        # "cl": cl_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"/home/hans/modelzoo/maua-sg2/{args.name}-{wandb.run.dir.split('/')[-1].split('-')[-1]}-{int(fid)}-{int(ppl)}-{str(i).zfill(6)}.pt",
                )
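# The training loop above leans on a few small helpers (`sample_data`, `requires_grad`,
# `make_noise`, `accumulate`) whose definitions live elsewhere in the repo. The sketches
# below are assumptions modeled on the rosinality stylegan2-pytorch reference code this
# loop follows; the repo's actual versions (for example how `make_noise` handles style
# mixing, or the device argument, which the second loop below omits) may differ.
import random  # only needed by the make_noise sketch below


def sample_data(loader):
    # wrap a finite DataLoader into an endless iterator of batches
    while True:
        for batch in loader:
            yield batch


def requires_grad(model, flag=True):
    # toggle gradient tracking for every parameter of a network
    for p in model.parameters():
        p.requires_grad = flag


def make_noise(batch, latent_dim, prob, device):
    # with probability `prob`, return two latent batches for style mixing, otherwise one
    if prob > 0 and random.random() < prob:
        return [th.randn(batch, latent_dim, device=device), th.randn(batch, latent_dim, device=device)]
    return [th.randn(batch, latent_dim, device=device)]


def accumulate(model1, model2, decay=0.999):
    # exponential moving average of generator weights: model1 <- decay * model1 + (1 - decay) * model2
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)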
def train(args, loader, generator, discriminator, contrast_learner, g_optim, d_optim, g_ema):
    # note: `device` and `augment` are used as module-level globals in this variant
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
        if contrast_learner is not None:
            cl_module = contrast_learner.module
    else:
        g_module = generator
        d_module = discriminator
        cl_module = contrast_learner

    loader = sample_data(loader)
    sample_z = th.randn(args.n_sample, args.latent_size, device=device)
    mse = th.nn.MSELoss()

    mean_path_length = 0
    ada_augment = th.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0
    fids = []

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        loss_dict = {
            "Generator": th.tensor(0, device=device).float(),
            "Discriminator": th.tensor(0, device=device).float(),
            "Real Score": th.tensor(0, device=device).float(),
            "Fake Score": th.tensor(0, device=device).float(),
            "Contrastive": th.tensor(0, device=device).float(),
            "Consistency": th.tensor(0, device=device).float(),
            "R1 Penalty": th.tensor(0, device=device).float(),
            "Path Length Regularization": th.tensor(0, device=device).float(),
            "Augment": th.tensor(0, device=device).float(),
            "Rt": th.tensor(0, device=device).float(),
        }

        # train discriminator
        requires_grad(generator, False)
        requires_grad(discriminator, True)
        discriminator.zero_grad()

        for _ in range(args.num_accumulate):
            real_img_og = next(loader).to(device)

            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob)
            fake_img_og, _ = generator(noise)

            if args.augment:
                fake_img, _ = augment(fake_img_og, ada_aug_p)
                real_img, _ = augment(real_img_og, ada_aug_p)
            else:
                fake_img = fake_img_og
                real_img = real_img_og

            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img)

            logistic_loss = d_logistic_loss(real_pred, fake_pred)
            loss_dict["Discriminator"] += logistic_loss.detach()
            loss_dict["Real Score"] += real_pred.mean().detach()
            loss_dict["Fake Score"] += fake_pred.mean().detach()
            d_loss = logistic_loss

            if args.contrastive > 0:
                contrast_learner(fake_img_og, fake_img, accumulate=True)
                contrast_learner(real_img_og, real_img, accumulate=True)
                contrast_loss = cl_module.calculate_loss()
                loss_dict["Contrastive"] += contrast_loss.detach()
                d_loss += args.contrastive * contrast_loss

            if args.balanced_consistency > 0:
                consistency_loss = mse(real_pred, discriminator(real_img_og)) + mse(
                    fake_pred, discriminator(fake_img_og)
                )
                loss_dict["Consistency"] += consistency_loss.detach()
                d_loss += args.balanced_consistency * consistency_loss

            d_loss /= args.num_accumulate
            d_loss.backward()

        d_optim.step()

        # R1 regularization
        if args.r1 > 0 and i % args.d_reg_every == 0:
            discriminator.zero_grad()

            for _ in range(args.num_accumulate):
                real_img = next(loader).to(device)
                real_img.requires_grad = True

                real_pred = discriminator(real_img)
                r1_loss = d_r1_penalty(real_img, real_pred, args)
                loss_dict["R1 Penalty"] += r1_loss.detach().squeeze()

                r1_loss = args.r1 * args.d_reg_every * r1_loss / args.num_accumulate
                r1_loss.backward()

            d_optim.step()

        # adaptive discriminator augmentation: adjust augmentation probability from the sign of real predictions
        if args.augment and args.augment_p == 0:
            ada_augment += th.tensor((th.sign(real_pred).sum().item(), real_pred.shape[0]), device=device)
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()
                r_t_stat = pred_signs / n_pred
                loss_dict["Rt"] = th.tensor(r_t_stat, device=device).float()

                if r_t_stat > args.ada_target:
                    sign = 1
                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

            loss_dict["Augment"] = th.tensor(ada_aug_p, device=device).float()

        # train generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        generator.zero_grad()

        for _ in range(args.num_accumulate):
            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob)
            fake_img, _ = generator(noise)

            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)

            fake_pred = discriminator(fake_img)
            g_loss = g_non_saturating_loss(fake_pred)
            loss_dict["Generator"] += g_loss.detach()

            g_loss /= args.num_accumulate
            g_loss.backward()

        g_optim.step()

        # path length regularization
        if args.path_regularize > 0 and i % args.g_reg_every == 0:
            generator.zero_grad()

            for _ in range(args.num_accumulate):
                path_loss, mean_path_length = g_path_length_regularization(generator, mean_path_length, args)
                loss_dict["Path Length Regularization"] += path_loss.detach()

                path_loss = args.path_regularize * args.g_reg_every * path_loss / args.num_accumulate
                path_loss.backward()

            g_optim.step()

        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)
        log_dict = {k: v.mean().item() / args.num_accumulate for k, v in loss_reduced.items() if v != 0}

        if get_rank() == 0:
            if args.log_spec_norm:
                G_norms = []
                for name, spec_norm in g_module.named_buffers():
                    if "spectral_norm" in name:
                        G_norms.append(spec_norm.cpu().numpy())
                G_norms = np.array(G_norms)

                D_norms = []
                for name, spec_norm in d_module.named_buffers():
                    if "spectral_norm" in name:
                        D_norms.append(spec_norm.cpu().numpy())
                D_norms = np.array(D_norms)

                log_dict["Spectral Norms/G min spectral norm"] = np.log(G_norms).min()
                log_dict["Spectral Norms/G mean spectral norm"] = np.log(G_norms).mean()
                log_dict["Spectral Norms/G max spectral norm"] = np.log(G_norms).max()
                log_dict["Spectral Norms/D min spectral norm"] = np.log(D_norms).min()
                log_dict["Spectral Norms/D mean spectral norm"] = np.log(D_norms).mean()
                log_dict["Spectral Norms/D max spectral norm"] = np.log(D_norms).max()

            if i % args.img_every == 0:
                gc.collect()
                th.cuda.empty_cache()
                with th.no_grad():
                    g_ema.eval()
                    sample = []
                    for sub in range(0, len(sample_z), args.batch_size):
                        subsample, _ = g_ema([sample_z[sub : sub + args.batch_size]])
                        sample.append(subsample.cpu())
                    sample = th.cat(sample)
                    grid = utils.make_grid(sample, nrow=10, normalize=True, range=(-1, 1))
                    log_dict["Generated Images EMA"] = [wandb.Image(grid, caption=f"Step {i}")]

            if i % args.eval_every == 0:
                fid_dict = validation.fid(g_ema, args.val_batch_size, args.fid_n_sample, args.fid_truncation, args.name)
                fid = fid_dict["FID"]
                fids.append(fid)
                density = fid_dict["Density"]
                coverage = fid_dict["Coverage"]

                ppl = validation.ppl(
                    g_ema,
                    args.val_batch_size,
                    args.ppl_n_sample,
                    args.ppl_space,
                    args.ppl_crop,
                    args.latent_size,
                )

                log_dict["Evaluation/FID"] = fid
                log_dict["Sweep/FID_smooth"] = gaussian_filter(np.array(fids), [5])[-1]
                log_dict["Evaluation/Density"] = density
                log_dict["Evaluation/Coverage"] = coverage
                log_dict["Evaluation/PPL"] = ppl

                gc.collect()
                th.cuda.empty_cache()

            wandb.log(log_dict)

            description = (
                f"FID: {fid:.4f} PPL: {ppl:.4f} Dens: {density:.4f} Cov: {coverage:.4f} "
                + f"G: {log_dict['Generator']:.4f} D: {log_dict['Discriminator']:.4f}"
            )
            if "Augment" in log_dict:
                description += f" Aug: {log_dict['Augment']:.4f}"  # Rt: {log_dict['Rt']:.4f}"
            if "R1 Penalty" in log_dict:
                description += f" R1: {log_dict['R1 Penalty']:.4f}"
            if "Path Length Regularization" in log_dict:
                description += f" Path: {log_dict['Path Length Regularization']:.4f}"
            pbar.set_description(description)

            if i % args.checkpoint_every == 0:
                check_name = "-".join(
                    [
                        args.name,
                        args.runname,
                        wandb.run.dir.split("/")[-1].split("-")[-1],
                        str(int(fid)),
                        str(args.size),
                        str(i).zfill(6),
                    ]
                )
                th.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        # "cl": cl_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"/home/hans/modelzoo/maua-sg2/{check_name}.pt",
                )
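# The second training loop factors its losses and regularizers into helpers
# (`d_logistic_loss`, `g_non_saturating_loss`, `d_r1_penalty`, `g_path_length_regularization`)
# that are not shown in this file. The sketches below are assumptions written to match the
# inline math of the first loop above; the repo's actual definitions (for instance where the
# R1 factor of 1/2 lives, or whether style mixing is used inside the path regularizer) may differ.


def d_logistic_loss(real_pred, fake_pred):
    # logistic discriminator loss: softplus(-D(x)) + softplus(D(G(z)))
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()


def g_non_saturating_loss(fake_pred):
    # non-saturating generator loss: softplus(-D(G(z)))
    return F.softplus(-fake_pred).mean()


def d_r1_penalty(real_img, real_pred, args):
    # R1 gradient penalty on real images: E[||grad_x D(x)||^2] / 2
    # (`args` is accepted only to match the call signature used above)
    (grad_real,) = th.autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() / 2.0
    # the 0 * real_pred[0] term keeps the prediction in the graph, as in the first loop above
    return penalty + 0 * real_pred[0]


def g_path_length_regularization(generator, mean_path_length, args, decay=0.01):
    # path length regularization: penalize deviation of the latent-space Jacobian norm
    # from its running mean (style mixing omitted here for brevity)
    path_batch_size = max(1, args.batch_size // args.path_batch_shrink)
    device = next(generator.parameters()).device
    noise = [th.randn(path_batch_size, args.latent_size, device=device)]
    fake_img, latents = generator(noise, return_latents=True)
    img_noise = th.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
    (grad,) = th.autograd.grad(outputs=(fake_img * img_noise).sum(), inputs=latents, create_graph=True)
    path_lengths = th.sqrt(grad.pow(2).sum(2).mean(1))
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    # the dummy 0 * fake_img term keeps unused generator outputs in the autograd graph
    path_loss = (path_lengths - path_mean).pow(2).mean() + 0 * fake_img[0, 0, 0, 0]
    return path_loss, path_mean.detach()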