def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    ckpt_dir = 'checkpoints/stylegan'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    fig_dir = 'figs/stylegan'
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # Discriminator step
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        # Adaptive discriminator augmentation: adjust ada_aug_p from the sign
        # statistics of the real predictions
        if args.augment and args.augment_p == 0:
            ada_augment_data = torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment += reduce_sum(ada_augment_data)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()
                r_t_stat = pred_signs / n_pred

                if r_t_stat > args.ada_target:
                    sign = 1
                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        # Lazy R1 regularization for the discriminator
        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()

        loss_dict["r1"] = r1_loss

        # Generator step
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # Lazy path length regularization for the generator
        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()
            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update the EMA copy of the generator
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        # Logging, sampling, and checkpointing on the main process only
        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"figs/stylegan/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"checkpoints/stylegan/{str(i).zfill(6)}.pt",
                )
def train(args, loader, generator, discriminator, optimizer, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)
    pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.gpu_num > 1:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["loss"] = d_loss.item()
        loss_dict["real_score"] = real_pred.mean().item()
        loss_dict["fake_score"] = fake_pred.mean().item()

        # Lazy R1 regularization, folded into the single zero-sum loss
        d_regularize = i % args.d_reg_every == 0
        # d_regularize = False
        if d_regularize:
            real_img_cp = real_img.clone().detach()
            real_img_cp.requires_grad = True
            real_pred_cp = discriminator(real_img_cp)
            r1_loss = d_r1_loss(real_pred_cp, real_img_cp)
            d_loss += args.r1 / 2 * r1_loss * args.d_reg_every

        loss_dict["r1"] = r1_loss.item()

        # Path length regularization is currently disabled
        # g_regularize = i % args.g_reg_every == 0
        g_regularize = False
        if g_regularize:
            # TODO adapt code for nn.DataParallel
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            d_loss += weighted_path_loss
            mean_path_length_avg = mean_path_length.item()

        loss_dict["path"] = path_loss.mean().item()
        loss_dict["path_length"] = path_lengths.mean().item()

        # Competitive optimizer: one simultaneous G/D update on the zero-sum loss
        optimizer.step(d_loss)

        # update ada_aug_p
        if args.augment and args.augment_p == 0:
            ada_augment_data = torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment += ada_augment_data

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()
                r_t_stat = pred_signs / n_pred

                if r_t_stat > args.ada_target:
                    sign = 1
                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        accumulate(g_ema, g_module, accum)

        d_loss_val = loss_dict["loss"]
        r1_val = loss_dict["r1"]
        path_loss_val = loss_dict["path"]
        real_score_val = loss_dict["real_score"]
        fake_score_val = loss_dict["fake_score"]
        path_length_val = loss_dict["path_length"]

        # A single zero-sum loss is optimized, so the same value is reported for G and D
        pbar.set_description((
            f"d: {d_loss_val:.4f}; g: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
            f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
            f"augment: {ada_aug_p:.4f}"))

        if wandb and args.wandb:
            wandb.log({
                "Generator": d_loss_val,
                "Discriminator": d_loss_val,
                "Augment": ada_aug_p,
                "Rt": r_t_stat,
                "R1": r1_val,
                "Path Length Regularization": path_loss_val,
                "Mean Path Length": mean_path_length,
                "Real Score": real_score_val,
                "Fake Score": fake_score_val,
                "Path Length": path_length_val,
            })

        if i % 100 == 0:
            with torch.no_grad():
                g_ema.eval()
                sample, _ = g_ema([sample_z])
                utils.save_image(
                    sample,
                    f"figs/stylegan-acgd/{str(i).zfill(6)}.png",
                    nrow=int(args.n_sample ** 0.5),
                    normalize=True,
                    range=(-1, 1),
                )

        if i % 100 == 0:
            torch.save(
                {
                    "g": g_module.state_dict(),
                    "d": d_module.state_dict(),
                    "g_ema": g_ema.state_dict(),
                    "d_optim": optimizer.state_dict(),
                    "args": args,
                    "ada_aug_p": ada_aug_p,
                },
                f"checkpoints/stylegan-acgd/{str(i).zfill(6)}.pt",
            )
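# The ACGD-style training loops in this file only rely on the competitive
# optimizer exposing the small surface sketched below: a step on a single
# zero-sum loss, learning-rate adjustment for the two players, diagnostics,
# and checkpointing. This class is a hypothetical summary of the calls the
# loops make, not the actual optimizer implementation; the name and the
# placeholder bodies are illustrative only.
class CompetitiveOptimizerInterface:
    def step(self, loss):
        # one simultaneous generator/discriminator update from the zero-sum loss
        raise NotImplementedError

    def set_lr(self, lr_max, lr_min):
        # adjust step sizes; per the logging above, lr_max is used for the
        # generator and lr_min for the discriminator
        raise NotImplementedError

    def get_info(self):
        # diagnostics dict: CG iteration count, runtime, gradient / HVP / CG norms
        raise NotImplementedError

    def state_dict(self):
        # optimizer state, saved alongside the model checkpoints
        raise NotImplementedError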
args.start_iter = 0

generator = Generator(
    args.size, args.latent, args.n_mlp,
    channel_multiplier=args.channel_multiplier).to(device)
discriminator = Discriminator(
    args.size, channel_multiplier=args.channel_multiplier).to(device)
g_ema = Generator(
    args.size, args.latent, args.n_mlp,
    channel_multiplier=args.channel_multiplier).to(device)
g_ema.eval()
accumulate(g_ema, generator, 0)

g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

g_optim = optim.Adam(
    generator.parameters(),
    lr=args.lr * g_reg_ratio,
    betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
)
d_optim = optim.Adam(
    discriminator.parameters(),
    lr=args.lr * d_reg_ratio,
    betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
)
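# Hypothetical wiring sketch: how the models and optimizers above would be
# passed to the Adam-based train() defined earlier. The dataset here is a
# stand-in of random images only to illustrate the expected input format
# (image tensors in [-1, 1], no labels); it is not part of the original code.
from torch.utils import data


class _RandomImages(data.Dataset):
    def __init__(self, n, size):
        self.n, self.size = n, size

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.rand(3, self.size, self.size) * 2 - 1


dataset = _RandomImages(1000, args.size)
loader = data.DataLoader(
    dataset,
    batch_size=args.batch,
    sampler=data.RandomSampler(dataset),
    drop_last=True,
)

train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)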
def train(args, loader, generator, discriminator, optimizer, g_ema, device):
    collect_info = True
    ckpt_dir = 'checkpoints/stylegan-acgd'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    fig_dir = 'figs/stylegan-acgd'
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    loader = sample_data(loader)

    pbar = range(args.iter)
    pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    accs = torch.tensor([1.0 for i in range(50)])
    loss_dict = {}

    if args.gpu_num > 1:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    ada_ratio = 2

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        # Single simultaneous G/D update on the zero-sum loss
        optimizer.step(d_loss)

        # Discriminator accuracy on this batch
        num_correct = torch.sum(real_pred > 0) + torch.sum(fake_pred < 0)
        acc = num_correct.item() / (fake_pred.shape[0] + real_pred.shape[0])

        loss_dict["loss"] = d_loss.item()
        loss_dict["real_score"] = real_pred.mean().item()
        loss_dict["fake_score"] = fake_pred.mean().item()

        # update ada_ratio: widen or narrow the G/D learning-rate ratio based on
        # the running discriminator accuracy over the last 50 batches
        accs[i % 50] = acc
        acc_indicator = sum(accs) / 50
        if i % 2 == 0:
            if acc_indicator > 0.85:
                ada_ratio += 1
            elif acc_indicator < 0.75:
                ada_ratio -= 1
        max_ratio = 2 ** min(4, ada_ratio)
        min_ratio = 2 ** min(0, 4 - ada_ratio)

        if args.ada_train:
            print('Adjust lrs')
            optimizer.set_lr(lr_max=max_ratio * args.lr_d,
                             lr_min=min_ratio * args.lr_d)

        accumulate(g_ema, g_module, accum)

        d_loss_val = loss_dict["loss"]
        real_score_val = loss_dict["real_score"]
        fake_score_val = loss_dict["fake_score"]

        pbar.set_description((
            f"d: {d_loss_val:.4f}; g: {d_loss_val:.4f}; Acc: {acc:.4f}; "
            f"augment: {ada_aug_p:.4f}"))

        if wandb and args.wandb:
            if collect_info:
                # Conjugate-gradient diagnostics from the competitive optimizer
                cgd_info = optimizer.get_info()
                wandb.log(
                    {
                        'CG iter num': cgd_info['iter_num'],
                        'CG runtime': cgd_info['time'],
                        'D gradient': cgd_info['grad_y'],
                        'G gradient': cgd_info['grad_x'],
                        'D hvp': cgd_info['hvp_y'],
                        'G hvp': cgd_info['hvp_x'],
                        'D cg': cgd_info['cg_y'],
                        'G cg': cgd_info['cg_x'],
                    },
                    step=i,
                )
            wandb.log(
                {
                    "Generator": d_loss_val,
                    "Discriminator": d_loss_val,
                    "Ada ratio": ada_ratio,
                    'Generator lr': max_ratio * args.lr_d,
                    'Discriminator lr': min_ratio * args.lr_d,
                    "Rt": r_t_stat,
                    "Accuracy": acc_indicator,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                },
                step=i,
            )

        if i % 100 == 0:
            with torch.no_grad():
                g_ema.eval()
                sample, _ = g_ema([sample_z])
                utils.save_image(
                    sample,
                    f"figs/stylegan-acgd/{str(i).zfill(6)}.png",
                    nrow=int(args.n_sample ** 0.5),
                    normalize=True,
                    range=(-1, 1),
                )

        if i % 2000 == 0:
            torch.save(
                {
                    "g": g_module.state_dict(),
                    "d": d_module.state_dict(),
                    "g_ema": g_ema.state_dict(),
                    "d_optim": optimizer.state_dict(),
                    "args": args,
                    "ada_aug_p": ada_aug_p,
                },
                f"checkpoints/stylegan-acgd/fix{str(i).zfill(6)}.pt",
            )
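# accumulate() and sample_data() are called by all three training loops above
# but are not defined in this section. The sketches below assume the standard
# StyleGAN2-style implementations: accumulate() keeps g_ema as an exponential
# moving average of the live generator (accum = 0.5 ** (32 / (10 * 1000)),
# i.e. a half-life of 10k images at batch size 32), and sample_data() turns a
# finite DataLoader into an infinite batch iterator. The project's own
# definitions, if present, take precedence.
def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        # par1 <- decay * par1 + (1 - decay) * par2
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)


def sample_data(loader):
    # loop over the loader forever so the training loop can call next() freely
    while True:
        for batch in loader:
            yield batch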