def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    """Fine-tune a StyleGAN2 generator/discriminator pair.

    Each iteration performs:
      * one discriminator update with ``d_logistic_loss``, plus a lazy R1
        penalty every ``args.d_reg_every`` iterations;
      * one generator update with ``g_nonsaturating_loss``, plus a lazy
        path-length penalty every ``args.g_reg_every`` iterations;
      * an ``accumulate(g_ema, g_module, accum)`` call that blends ``g_ema``
        toward the generator.

    Rank 0 updates the progress bar and optionally logs to wandb. It also
    writes a sample grid to ``log/<style>/finetune-XXXXXX.jpg`` every 100
    iterations (and at ``i + 1 == args.iter``). It saves ``g_ema`` to
    ``<model_path>/<style>/finetune-XXXXXX.pt`` every ``args.save_every``
    iterations and at the last iteration.

    Args:
        args: parsed options (iteration counts, batch/latent sizes, loss
            weights, augmentation settings, output locations).
        loader: data loader of real images; wrapped with ``sample_data`` and
            consumed with ``next()``.
        generator, discriminator: networks being trained. When
            ``args.distributed`` is set their ``.module`` is used, presumably
            because they are DistributedDataParallel-wrapped.
        g_optim, d_optim: optimizers for the generator and discriminator.
        g_ema: averaged generator used for sampling and checkpointing.
        device: device on which noise and real batches are placed.
    """
    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, ncols=140, dynamic_ncols=False, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    # Pre-initialized so the loss_dict entries below exist even on iterations
    # where the lazy regularizers do not run.
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    # EMA decay for g_ema: accum ** (10000 / 32) == 0.5.
    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    # Adaptive augmentation is used only when augmentation is on and no fixed p was given.
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)

    # Fixed latent codes so sample grids are comparable across iterations.
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # ---------------- discriminator step ----------------
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        # Lazy R1 regularization, scaled by d_reg_every to compensate for running it rarely.
        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # The zero-weighted real_pred term adds nothing to the loss; presumably it
            # keeps every output in the graph for distributed training — confirm.
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()

        # Assigned every iteration (with the last computed value) so the key always exists.
        loss_dict["r1"] = r1_loss

        # ---------------- generator step ----------------
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # Lazy path-length regularization on a (possibly shrunk) batch.
        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                # Zero-weighted term; presumably keeps fake_img in the graph — confirm.
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()
            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())

        # Assigned every iteration so reduce_loss_dict always finds these keys.
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:05d}; d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    sample = F.interpolate(sample, 256)
                    utils.save_image(
                        sample,
                        f"log/{args.style}/finetune-{i:06d}.jpg",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if (i + 1) % args.save_every == 0 or (i + 1) == args.iter:
                # Only g_ema is stored; the commented keys would make a resumable checkpoint.
                torch.save(
                    {
                        #"g": g_module.state_dict(),
                        #"d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        #"g_optim": g_optim.state_dict(),
                        #"d_optim": d_optim.state_dict(),
                        #"args": args,
                        #"ada_aug_p": ada_aug_p,
                    },
                    # Fixed "fintune" typo so checkpoints match the "finetune" sample naming.
                    f"{args.model_path}/{args.style}/finetune-{i + 1:06d}.pt",
                )
def _sample_style_batch(args, instyles, Simgs, exstyles, which, device):
    """Draw one batch of DualStyleGAN generator inputs via ``get_paired_data``.

    Returns ``(exstyle, instyle, paired_img, z_plus_latent)``:
      * ``exstyle`` is moved to ``device``.
      * ``instyle`` depends on ``which``. When ``which == 0`` it is fresh
        ``mixing_noise``. Otherwise it is ``[paired intrinsic code]``, taken
        from ``get_paired_data`` and moved to ``device``.
      * ``paired_img`` is returned exactly as ``get_paired_data`` gave it
        (not moved to ``device``); callers move it only when they use it.
      * ``z_plus_latent`` is False for noise input and True for paired codes.
    """
    exstyle, paired_instyle, paired_img = get_paired_data(instyles, Simgs, exstyles,
                                                          batch_size=args.batch, random_ind=8)
    exstyle = exstyle.to(device)
    if which == 0:
        # sample z^+_e, z for Lsty, Lcon and Ladv
        instyle = mixing_noise(args.batch, args.latent, args.mixing, device)
        z_plus_latent = False
    else:
        # sample z^+_e, z^+_i and S for Eq. (4)
        instyle = [paired_instyle.to(device)]
        z_plus_latent = True
    return exstyle, instyle, paired_img, z_plus_latent


def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, instyles, Simgs, exstyles, vggloss, id_loss, device):
    """Train DualStyleGAN with the adversarial, style, content and perceptual losses.

    ``which = i % args.subspace_freq`` selects the input mode for each
    iteration (see ``_sample_style_batch``). When ``which > 0``, paired data
    is used and the VGG L1 perceptual loss (Lperc, weighted by
    ``vgg_weights * args.perc_loss``) is added.

    The generator objective is Ladv + Lperc + Lsty + Lreg + LID, where:
      * Lsty is the contextual loss on VGG layer 2 (``args.CX_loss``) plus an
        MSE on the spatially averaged VGG features of layers 1 and 2
        (``args.style_loss``).
      * Lreg is the L2 norm of ``g_module.res`` parameters.
      * LID is ``id_loss`` against the ``use_res=False`` reconstruction.

    Lazy R1 and path-length regularizers run every ``args.d_reg_every`` /
    ``args.g_reg_every`` iterations. Only ``g_module.res`` is averaged into
    ``g_ema.res``. Rank 0 writes samples to ``log/<style>/dualstylegan-XXXXXX.jpg``.
    It saves ``g_ema`` to ``<model_path>/<style>/<model_name>-XXXXXX.pt`` from
    ``args.save_begin`` on, every ``args.save_every`` iterations, and at the
    last iteration.

    Args:
        args: parsed options (iteration counts, loss weights, augmentation,
            output locations).
        loader: data loader of real style images; wrapped with ``sample_data``.
        generator, discriminator, g_optim, d_optim, g_ema: networks/optimizers.
        instyles, Simgs, exstyles: paired data forwarded to ``get_paired_data``.
        vggloss: callable returning a list of VGG feature maps for an image.
        id_loss: identity loss callable; used only when ``args.id_loss != 0``.
        device: device for all tensors.
    """
    loader = sample_data(loader)
    vgg_weights = [0.0, 0.5, 1.0, 0.0, 0.0]

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, smoothing=0.01, ncols=180, dynamic_ncols=False)

    mean_path_length = 0
    d_loss_val = 0
    # Pre-initialized so the loss_dict entries below exist even on iterations
    # where the lazy regularizers do not run.
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    # EMA decay for g_ema.res: accum ** (10000 / 32) == 0.5.
    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)

    # Fixed inputs so sample grids are comparable across iterations.
    sample_instyle = torch.randn(args.n_sample, args.latent, device=device)
    sample_exstyle, _, _ = get_paired_data(instyles, Simgs, exstyles, batch_size=args.n_sample, random_ind=8)
    sample_exstyle = sample_exstyle.to(device)

    for idx in pbar:
        i = idx + args.start_iter
        which = i % args.subspace_freq  # defines whether we use paired data

        if i > args.iter:
            print("Done!")
            break

        # sample S
        real_img = next(loader)
        real_img = real_img.to(device)

        # ---------------- discriminator step ----------------
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        exstyle, instyle, paired_img, z_plus_latent = _sample_style_batch(
            args, instyles, Simgs, exstyles, which, device)
        if which != 0:
            # In paired mode the paired image replaces the loader batch as the real sample.
            real_img = paired_img.to(device)

        fake_img, _ = generator(instyle, exstyle, use_res=True, z_plus_latent=z_plus_latent)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss  # Ladv
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # Zero-weighted real_pred term; presumably keeps the output in the graph — confirm.
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()

        # Assigned every iteration so reduce_loss_dict always finds the key.
        loss_dict["r1"] = r1_loss

        # ---------------- generator step ----------------
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        exstyle, instyle, real_img, z_plus_latent = _sample_style_batch(
            args, instyles, Simgs, exstyles, which, device)
        real_img = real_img.to(device)

        fake_img, _ = generator(instyle, exstyle, use_res=True, z_plus_latent=z_plus_latent)

        # Targets: VGG features of the paired image and the use_res=False reconstruction.
        with torch.no_grad():
            real_img_256 = F.adaptive_avg_pool2d(real_img, 256).detach()
            real_feats = vggloss(real_img_256)
            real_styles = [
                F.adaptive_avg_pool2d(real_feat, output_size=1).detach()
                for real_feat in real_feats
            ]
            real_content, _ = generator(instyle, None, use_res=False, z_plus_latent=z_plus_latent)
            real_content_256 = F.adaptive_avg_pool2d(real_content, 256).detach()

        fake_img_256 = F.adaptive_avg_pool2d(fake_img, 256)
        fake_feats = vggloss(fake_img_256)
        fake_styles = [
            F.adaptive_avg_pool2d(fake_feat, output_size=1)
            for fake_feat in fake_feats
        ]

        # Lsty: contextual loss on layer 2 and/or MSE of channel-averaged layers 1 and 2.
        sty_loss = (torch.tensor(0.0, device=device) if args.CX_loss == 0 else
                    FCX.contextual_loss(fake_feats[2], real_feats[2].detach(),
                                        band_width=0.2, loss_type='cosine') * args.CX_loss)
        if args.style_loss > 0:
            sty_loss += ((F.mse_loss(fake_styles[1], real_styles[1]) +
                          F.mse_loss(fake_styles[2], real_styles[2])) * args.style_loss)

        ID_loss = (torch.tensor(0.0, device=device) if args.id_loss == 0 else
                   id_loss(fake_img_256, real_content_256) * args.id_loss)

        # Lperc is only applied in paired mode.
        gr_loss = torch.tensor(0.0, device=device)
        if which > 0:
            for ii, weight in enumerate(vgg_weights):
                if weight * args.perc_loss > 0:
                    gr_loss += F.l1_loss(fake_feats[ii], real_feats[ii].detach()) * weight * args.perc_loss

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)
        l2_reg_loss = sum(torch.norm(p) for p in g_module.res.parameters()) * args.L2_reg_loss

        loss_dict["g"] = g_loss  # Ladv
        loss_dict["gr"] = gr_loss  # Lperc
        loss_dict["l2"] = l2_reg_loss  # Lreg in Lcon
        loss_dict["id"] = ID_loss  # LID in Lcon
        loss_dict["sty"] = sty_loss  # Lsty
        # Out-of-place sum so loss_dict["g"] keeps the pure adversarial term.
        g_loss = g_loss + gr_loss + sty_loss + l2_reg_loss + ID_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            instyle = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            exstyle, _, _ = get_paired_data(instyles, Simgs, exstyles, batch_size=path_batch_size, random_ind=8)
            exstyle = exstyle.to(device)

            fake_img, latents = generator(instyle, exstyle, return_latents=True, use_res=True, z_plus_latent=False)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                # Zero-weighted term; presumably keeps fake_img in the graph — confirm.
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()
            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())

        # Assigned every iteration so reduce_loss_dict always finds these keys.
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema.res, g_module.res, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        gr_loss_val = loss_reduced["gr"].mean().item()
        sty_loss_val = loss_reduced["sty"].mean().item()
        l2_loss_val = loss_reduced["l2"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        id_loss_val = loss_reduced["id"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:d}; d: {d_loss_val:.3f}; g: {g_loss_val:.3f}; gr: {gr_loss_val:.3f}; sty: {sty_loss_val:.3f}; l2: {l2_loss_val:.3f}; id: {id_loss_val:.3f}; "
                f"r1: {r1_val:.3f}; path: {path_loss_val:.3f}; mean path: {mean_path_length_avg:.3f}; "
                f"augment: {ada_aug_p:.4f};"))

            if i % 100 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_instyle], sample_exstyle, use_res=True)
                    sample = F.interpolate(sample, 256)
                    utils.save_image(
                        sample,
                        f"log/{args.style}/dualstylegan-{i:06d}.jpg",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if ((i + 1) >= args.save_begin and (i + 1) % args.save_every == 0) or (i + 1) == args.iter:
                # Only g_ema is stored; the commented keys would make a resumable checkpoint.
                torch.save(
                    {
                        #"g": g_module.state_dict(),
                        #"d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        #"g_optim": g_optim.state_dict(),
                        #"d_optim": d_optim.state_dict(),
                        #"args": args,
                        #"ada_aug_p": ada_aug_p,
                    },
                    f"{args.model_path}/{args.style}/{args.model_name}-{i + 1:06d}.pt",
                )
) discriminator = nn.parallel.DistributedDataParallel( discriminator, device_ids=[args.local_rank], output_device=args.local_rank, broadcast_buffers=False, ) transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), ]) dataset = MultiResolutionDataset(args.path, transform, args.size) loader = data.DataLoader( dataset, batch_size=args.batch, sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), drop_last=True, ) if get_rank() == 0 and wandb is not None and args.wandb: wandb.init(project="stylegan 2") train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
def _pretrain_style_inputs(which, real_zs, pspstyle, g_module):
    """Choose the (intrinsic noise, external style, z_plus_latent) inputs for pretraining.

    When ``which > 0`` the external style is ``g_module.get_latent(z2)``
    (detached) and z1 is passed as a plain z. Otherwise z1 is repeated to
    ``g_module.n_latent`` codes and the encoder's style code ``pspstyle`` is used.
    """
    if which > 0:
        # set z~_2 = z2
        noise = [real_zs[0]]
        externalstyle = g_module.get_latent(real_zs[1]).detach()
        z_plus_latent = False
    else:
        # set z~_2 = E(g(z2))
        noise = [real_zs[0].unsqueeze(1).repeat(1, g_module.n_latent, 1)]
        externalstyle = pspstyle
        z_plus_latent = True
    return noise, externalstyle, z_plus_latent


def pretrain(args, loader, generator, discriminator, g_optim, d_optim, g_ema, encoder, vggloss, device, inject_index=5, savemodel=True):
    """Pre-train the DualStyleGAN style path to reproduce style-mixed images.

    For each batch, two codes ``real_zs`` (z1, z2) are drawn. The target is
    ``generator(real_zs, inject_index=inject_index, use_res=False)``. The
    generator is then run with ``use_res=True`` on z1 plus an external style
    (see ``_pretrain_style_inputs``).

    The generator is trained to match the target with a VGG L1 perceptual
    loss (``vgg_weights``) plus a 0.1-weighted non-saturating adversarial
    loss. The discriminator uses a 0.1-weighted logistic loss. Lazy R1 and
    path-length regularizers run every ``args.d_reg_every`` /
    ``args.g_reg_every`` iterations. Only ``g_module.res`` is averaged into
    ``g_ema.res``.

    Args:
        args: parsed options (iteration counts, loss weights, augmentation,
            output locations, ``model_name``).
        loader: data loader of real images; wrapped with ``sample_data``.
        generator, discriminator, g_optim, d_optim, g_ema: networks/optimizers.
        encoder: maps a 256px image to a style code (called with
            ``z_plus_latent=True``).
        vggloss: callable returning a list of VGG feature maps.
        device: device for all tensors.
        inject_index: mixing index used to build the target images.
        savemodel: when False, no checkpoints are written.
    """
    loader = sample_data(loader)
    vgg_weights = [0.5, 0.5, 0.5, 0.0, 0.0]

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, ncols=140, dynamic_ncols=False, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    # Pre-initialized so the loss_dict entries below exist even on iterations
    # where the lazy regularizers do not run.
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    # EMA decay for g_ema.res: accum ** (10000 / 32) == 0.5.
    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)

    # Fixed sample inputs/targets, saved once for visual reference.
    sample_zs = mixing_noise(args.n_sample, args.latent, 1.0, device)
    with torch.no_grad():
        source_img, _ = generator([sample_zs[0]], None, input_is_latent=False, z_plus_latent=False, use_res=False)
        source_img = source_img.detach()
        target_img, _ = generator(sample_zs, None, input_is_latent=False, z_plus_latent=False,
                                  inject_index=inject_index, use_res=False)
        target_img = target_img.detach()
        style_img, _ = generator([sample_zs[1]], None, input_is_latent=False, z_plus_latent=False, use_res=False)
        _, sample_style = encoder(F.adaptive_avg_pool2d(style_img, 256), randomize_noise=False,
                                  return_latents=True, z_plus_latent=True, return_z_plus_latent=False)
        sample_style = sample_style.detach()

    if get_rank() == 0:
        utils.save_image(F.adaptive_avg_pool2d(source_img, 256), f"log/{args.model_name}-instyle.jpg",
                         nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1))
        utils.save_image(F.adaptive_avg_pool2d(target_img, 256), f"log/{args.model_name}-target.jpg",
                         nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1))
        utils.save_image(F.adaptive_avg_pool2d(style_img, 256), f"log/{args.model_name}-exstyle.jpg",
                         nrow=int(args.n_sample**0.5), normalize=True, range=(-1, 1))

    for idx in pbar:
        i = idx + args.start_iter
        which = i % args.subspace_freq

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # real_zs contains z1 and z2
        real_zs = mixing_noise(args.batch, args.latent, 1.0, device)
        with torch.no_grad():
            # g(z^+_l) with l=inject_index
            target_img, _ = generator(real_zs, None, input_is_latent=False, z_plus_latent=False,
                                      inject_index=inject_index, use_res=False)
            target_img = target_img.detach()
            # g(z2)
            style_img, _ = generator([real_zs[1]], None, input_is_latent=False, z_plus_latent=False, use_res=False)
            style_img = style_img.detach()
            # E(g(z2))
            _, pspstyle = encoder(F.adaptive_avg_pool2d(style_img, 256), randomize_noise=False,
                                  return_latents=True, z_plus_latent=True, return_z_plus_latent=False)
            pspstyle = pspstyle.detach()

        # ---------------- discriminator step ----------------
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise, externalstyle, z_plus_latent = _pretrain_style_inputs(which, real_zs, pspstyle, g_module)
        fake_img, _ = generator(noise, externalstyle, use_res=True, z_plus_latent=z_plus_latent)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred) * 0.1

        loss_dict["d"] = d_loss  # Ladv
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # Zero-weighted real_pred term; presumably keeps the output in the graph — confirm.
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()

        # Assigned every iteration so reduce_loss_dict always finds the key.
        loss_dict["r1"] = r1_loss

        # ---------------- generator step ----------------
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise, externalstyle, z_plus_latent = _pretrain_style_inputs(which, real_zs, pspstyle, g_module)
        fake_img, _ = generator(noise, externalstyle, use_res=True, z_plus_latent=z_plus_latent)

        real_feats = vggloss(F.adaptive_avg_pool2d(target_img, 256).detach())
        fake_feats = vggloss(F.adaptive_avg_pool2d(fake_img, 256))
        gr_loss = torch.tensor(0.0, device=device)
        for ii, weight in enumerate(vgg_weights):
            if weight > 0:
                gr_loss += F.l1_loss(fake_feats[ii], real_feats[ii].detach()) * weight

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred) * 0.1

        loss_dict["g"] = g_loss  # Ladv
        loss_dict["gr"] = gr_loss  # L_perc
        # Out-of-place add: the previous `g_loss += gr_loss` mutated the tensor
        # already stored in loss_dict["g"], so the logged "g" included L_perc.
        g_loss = g_loss + gr_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            # Same latent size as the z codes fed to get_latent elsewhere (was hard-coded 512).
            externalstyle = torch.randn(path_batch_size, args.latent, device=device)
            externalstyle = g_module.get_latent(externalstyle).detach()
            fake_img, latents = generator(noise, externalstyle, return_latents=True, use_res=True, z_plus_latent=False)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                # Zero-weighted term; presumably keeps fake_img in the graph — confirm.
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()
            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())

        # Assigned every iteration so reduce_loss_dict always finds these keys.
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema.res, g_module.res, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        gr_loss_val = loss_reduced["gr"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:d}; d: {d_loss_val:.3f}; g: {g_loss_val:.3f}; gr: {gr_loss_val:.3f}; r1: {r1_val:.3f}; "
                f"path: {path_loss_val:.3f}; mean path: {mean_path_length_avg:.3f}; "
                f"augment: {ada_aug_p:.1f}"))

            if i % 300 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_zs[0].unsqueeze(1).repeat(1, g_module.n_latent, 1)],
                                      sample_style, use_res=True, z_plus_latent=True)
                    sample = F.interpolate(sample, 256)
                    utils.save_image(
                        sample,
                        f"log/{args.model_name}-{i:06d}.jpg",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if savemodel and ((i + 1) % args.save_every == 0 or (i + 1) == args.iter):
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"{args.model_path}/{args.model_name}-{i + 1:06d}.pt",
                )