def train(args, loader, encoder, generator, discriminator, discriminator_z, g1,
          vggnet, pwcnet, e_optim, d_optim, dz_optim, g1_optim,
          e_ema, e_tf, g1_ema, device):
    mmd_eval = functools.partial(mix_rbf_mmd2, sigma_list=[2.0, 5.0, 10.0, 20.0, 40.0, 80.0])
    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {
        "d": torch.tensor(0., device=device),
        "real_score": torch.tensor(0., device=device),
        "fake_score": torch.tensor(0., device=device),
        "r1_d": torch.tensor(0., device=device),
        "r1_e": torch.tensor(0., device=device),
        "rec": torch.tensor(0., device=device),
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
        g1_module = g1.module if args.train_latent_mlp else None
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator
        g1_module = g1 if args.train_latent_mlp else None

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

    # sample_x = accumulate_batches(loader, args.n_sample).to(device)
    sample_x = load_real_samples(args, loader)

    requires_grad(generator, False)  # always False
    generator.eval()  # generator is the EMA copy and stays in eval mode

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)
        batch = real_img.shape[0]

        # Train Encoder
        if args.toggle_grads:
            requires_grad(encoder, True)
            requires_grad(discriminator, False)

        pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0., device=device)
        kld_z = torch.tensor(0., device=device)
        mmd_z = torch.tensor(0., device=device)
        gan_z = torch.tensor(0., device=device)
        etf_z = torch.tensor(0., device=device)

        latent_real, logvar = encoder(real_img)
        if args.reparameterization:
            latent_real = reparameterize(latent_real, logvar)
        if args.train_latent_mlp:
            fake_img, _ = generator([g1(latent_real)], input_is_latent=True, return_latents=False)
        else:
            fake_img, _ = generator([latent_real], input_is_latent=False, return_latents=False)

        if args.lambda_adv > 0:
            if args.augment:
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                fake_img_aug = fake_img
            fake_pred = discriminator(fake_img_aug)
            adv_loss = g_nonsaturating_loss(fake_pred)

        if args.lambda_pix > 0:
            pix_loss = torch.mean((real_img - fake_img) ** 2)

        if args.lambda_vgg > 0:
            real_feat = vggnet(real_img)
            fake_feat = vggnet(fake_img)
            vgg_loss = torch.mean((real_feat - fake_feat) ** 2)

        if args.lambda_kld_z > 0:
            z_mean = latent_real.view(batch, -1)
            kld_z = -0.5 * torch.sum(1. + logvar - z_mean.pow(2) - logvar.exp()) / batch

        if args.lambda_mmd_z > 0:
            z_real = torch.randn(batch, args.latent_full, device=device)
            mmd_z = mmd_eval(latent_real, z_real)

        if args.lambda_gan_z > 0:
            fake_pred = discriminator_z(latent_real)
            gan_z = g_nonsaturating_loss(fake_pred)

        if args.use_latent_teacher_forcing and args.lambda_etf > 0:
            w_tf, _ = e_tf(real_img)
            if args.train_latent_mlp:
                w_pred = g1(latent_real)
            else:
                w_pred = generator.get_latent(latent_real)
            etf_z = torch.mean((w_tf - w_pred) ** 2)

        if args.train_on_fake and args.lambda_rec > 0:
            z_real = torch.randn(args.batch, args.latent_full, device=device)
            if args.train_latent_mlp:
                fake_img, _ = generator([g1(z_real)], input_is_latent=True, return_latents=False)
            else:
                fake_img, _ = generator([z_real], input_is_latent=False, return_latents=False)
            z_fake, z_logvar = encoder(fake_img)
            if args.reparameterization:
                z_fake = reparameterize(z_fake, z_logvar)
            rec_loss = torch.mean((z_real - z_fake) ** 2)
            loss_dict["rec"] = rec_loss

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        e_loss = (e_loss + args.lambda_kld_z * kld_z + args.lambda_mmd_z * mmd_z
                  + args.lambda_gan_z * gan_z + args.lambda_etf * etf_z
                  + rec_loss * args.lambda_rec)
        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss

        if args.train_latent_mlp and g1 is not None:
            g1.zero_grad()
        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()
        if args.train_latent_mlp and g1_optim is not None:
            g1_optim.step()

        e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0
        if e_regularize:
            # why not regularize on augmented real?
            real_img.requires_grad = True
            real_pred, logvar = encoder(real_img)
            if args.reparameterization:
                real_pred = reparameterize(real_pred, logvar)
            r1_loss_e = d_r1_loss(real_pred, real_img)

            encoder.zero_grad()
            (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward()
            e_optim.step()
            loss_dict["r1_e"] = r1_loss_e

        if not args.no_ema and e_ema is not None:
            accumulate(e_ema, e_module, accum)
            if args.train_latent_mlp:
                accumulate(g1_ema, g1_module, accum)

        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(discriminator, True)

        if not args.no_update_discriminator and args.lambda_adv > 0:
            latent_real, logvar = encoder(real_img)
            if args.reparameterization:
                latent_real = reparameterize(latent_real, logvar)
            if args.train_latent_mlp:
                fake_img, _ = generator([g1(latent_real)], input_is_latent=True, return_latents=False)
            else:
                fake_img, _ = generator([latent_real], input_is_latent=False, return_latents=False)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
                fake_img_aug = fake_img

            fake_pred = discriminator(fake_img_aug)
            real_pred = discriminator(real_img_aug)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

            # Latent discriminator: push encoder outputs toward the prior N(0, I)
            z_real = torch.randn(batch, args.latent_full, device=device)
            fake_pred = discriminator_z(latent_real.detach())
            real_pred = discriminator_z(z_real)
            d_loss_z = d_logistic_loss(real_pred, fake_pred)
            discriminator_z.zero_grad()
            d_loss_z.backward()
            dz_optim.step()

            if args.augment and args.augment_p == 0:
                # note: real_pred at this point comes from discriminator_z
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat

            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                # why not regularize on augmented real?
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss_d = d_r1_loss(real_pred, real_img)

                discriminator.zero_grad()
                # the 0 * real_pred term keeps the output in the autograd graph for DDP;
                # see https://github.com/rosinality/stylegan2-pytorch/issues/76
                (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward()
                d_optim.step()
                loss_dict["r1_d"] = r1_loss_d

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        r1_e_val = loss_reduced["r1_e"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    if args.train_latent_mlp:
                        g1_ema.eval()
                        fake_x, _ = generator([g1_ema(latent_x)], input_is_latent=True, return_latents=False)
                    else:
                        fake_x, _ = generator([latent_x], input_is_latent=False, return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x) ** 2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                            f"ref: {sample_pix_loss.item()};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Encoder": e_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1 D": r1_d_val,
                    "R1 E": r1_e_val,
                    "Pix Loss": pix_loss_val,
                    "VGG Loss": vgg_loss_val,
                    "Adv Loss": adv_loss_val,
                    "Rec Loss": rec_loss_val,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = encoder if args.no_ema else e_ema
                    e_eval.eval()
                    nrow = int(args.n_sample ** 0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_eval(sample_x)
                    if args.train_latent_mlp:
                        g1_ema.eval()
                        fake_img, _ = generator([g1_ema(latent_real)], input_is_latent=True, return_latents=False)
                    else:
                        fake_img, _ = generator([latent_real], input_is_latent=False, return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    e_eval.train()

            if i % args.save_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g1": g1_module.state_dict() if args.train_latent_mlp else None,
                        "g1_ema": g1_ema.state_dict() if args.train_latent_mlp else None,
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g1": g1_module.state_dict() if args.train_latent_mlp else None,
                        "g1_ema": g1_ema.state_dict() if args.train_latent_mlp else None,
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
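# NOTE: `reparameterize` is defined elsewhere in this repo. For reference, a
# minimal sketch is given below, assuming the standard VAE reparameterization
# trick (z = mu + sigma * eps); the repo's actual implementation may differ.
import torch


def reparameterize(mu, logvar):
    # logvar parameterizes log(sigma^2), so std = exp(0.5 * logvar)
    std = torch.exp(0.5 * logvar)
    # eps carries all the randomness; gradients flow through mu and std
    eps = torch.randn_like(std)
    return mu + eps * std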
def train(
    args,
    loader,
    encoder,
    generator,
    discriminator,
    discriminator3d,  # video discriminator
    posterior,
    prior,
    factor,  # a learnable matrix
    vggnet,
    e_optim,
    d_optim,
    dv_optim,
    q_optim,  # q for posterior
    p_optim,  # p for prior
    f_optim,  # f for factor
    e_ema,
    device,
):
    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    d_loss_val = 0
    e_loss_val = 0
    rec_loss_val = 0
    vgg_loss_val = 0
    adv_loss_val = 0
    loss_dict = {
        "d": torch.tensor(0., device=device),
        "real_score": torch.tensor(0., device=device),
        "fake_score": torch.tensor(0., device=device),
        "r1_d": torch.tensor(0., device=device),
        "r1_e": torch.tensor(0., device=device),
        "rec": torch.tensor(0., device=device),
    }

    if args.distributed:
        e_module = encoder.module
        d_module = discriminator.module
        g_module = generator.module
    else:
        e_module = encoder
        d_module = discriminator
        g_module = generator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    latent_full = args.latent_full
    factor_dim_full = args.factor_dim_full

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

    sample_x = accumulate_batches(loader, args.n_sample).to(device)
    utils.save_image(
        sample_x.view(-1, *list(sample_x.shape)[2:]),
        os.path.join(args.log_dir, 'sample', "real-img.png"),
        nrow=sample_x.shape[1],
        normalize=True,
        value_range=(-1, 1),
    )
    util.save_video(
        sample_x[0],
        os.path.join(args.log_dir, 'sample', "real-vid.mp4"),
    )

    requires_grad(generator, False)  # always False
    generator.eval()  # generator is the EMA copy and stays in eval mode

    if args.no_update_encoder:
        encoder = e_ema if e_ema is not None else encoder
        requires_grad(encoder, False)
        encoder.eval()

    from models.networks_3d import GANLoss
    criterionGAN = GANLoss()
    # criterionL1 = nn.L1Loss()

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        data = next(loader)
        real_seq = data['frames']
        real_seq = real_seq.to(device)  # [N, T, C, H, W]
        shape = list(real_seq.shape)
        N, T = shape[:2]

        # Train Encoder with frame-level objectives
        if args.toggle_grads:
            if not args.no_update_encoder:
                requires_grad(encoder, True)
            requires_grad(discriminator, False)

        pix_loss = vgg_loss = adv_loss = rec_loss = vid_loss = l1y_loss = torch.tensor(0., device=device)

        # real_seq -> encoder -> posterior -> generator -> fake_seq
        # f: [N, latent_full]; y: [N, T, D]
        fake_img, fake_seq, y_post = reconstruct_sequence(
            args, real_seq, encoder, generator, factor, posterior, i, ret_y=True)

        real_img = real_seq.view(N * T, *shape[2:])

        if args.lambda_adv > 0:
            if args.augment:
                fake_img_aug, _ = augment(fake_img, ada_aug_p)
            else:
                fake_img_aug = fake_img
            fake_pred = discriminator(fake_img_aug)
            adv_loss = g_nonsaturating_loss(fake_pred)

        # TODO: do we always put pix and vgg loss on all frames?
        if args.lambda_pix > 0:
            pix_loss = torch.mean((real_img - fake_img) ** 2)
        if args.lambda_vgg > 0:
            real_feat = vggnet(real_img)
            fake_feat = vggnet(fake_img)
            vgg_loss = torch.mean((real_feat - fake_feat) ** 2)

        # Train Encoder with video-level objectives
        if args.lambda_vid > 0:
            fake_pred = discriminator3d(flip_video(fake_seq.transpose(1, 2)))
            vid_loss = criterionGAN(fake_pred, True)
        if args.lambda_l1y > 0:
            # sparsity penalty on the dynamic coefficients
            l1y_loss = torch.mean(torch.abs(y_post))

        e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
        e_loss = e_loss + args.lambda_vid * vid_loss + args.lambda_l1y * l1y_loss
        loss_dict["e"] = e_loss
        loss_dict["pix"] = pix_loss
        loss_dict["vgg"] = vgg_loss
        loss_dict["adv"] = adv_loss

        if not args.no_update_encoder:
            encoder.zero_grad()
        posterior.zero_grad()
        e_loss.backward()
        q_optim.step()
        if not args.no_update_encoder:
            e_optim.step()

        if not args.no_update_encoder:
            if not args.no_ema and e_ema is not None:
                accumulate(e_ema, e_module, accum)

        # Train Discriminator
        if args.toggle_grads:
            requires_grad(encoder, False)
            requires_grad(discriminator, True)

        fake_img, fake_seq = reconstruct_sequence(args, real_seq, encoder, generator, factor, posterior)

        if not args.no_update_discriminator:
            if args.lambda_adv > 0:
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                    fake_img_aug, _ = augment(fake_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                    fake_img_aug = fake_img
                fake_pred = discriminator(fake_img_aug)
                real_pred = discriminator(real_img_aug)
                d_loss = d_logistic_loss(real_pred, fake_pred)

            # Train video discriminator
            if args.lambda_vid > 0:
                pred_real = discriminator3d(flip_video(real_seq.transpose(1, 2)))
                pred_fake = discriminator3d(flip_video(fake_seq.transpose(1, 2)))
                dv_loss_real = criterionGAN(pred_real, True)
                dv_loss_fake = criterionGAN(pred_fake, False)
                dv_loss = 0.5 * (dv_loss_real + dv_loss_fake)
                d_loss = d_loss + dv_loss

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            if args.lambda_adv > 0:
                discriminator.zero_grad()
            if args.lambda_vid > 0:
                discriminator3d.zero_grad()
            d_loss.backward()
            if args.lambda_adv > 0:
                d_optim.step()
            if args.lambda_vid > 0:
                dv_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat

            d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
            if d_regularize:
                # why not regularize on augmented real?
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss_d = d_r1_loss(real_pred, real_img)

                discriminator.zero_grad()
                # the 0 * real_pred term keeps the output in the autograd graph for DDP;
                # see https://github.com/rosinality/stylegan2-pytorch/issues/76
                (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward()
                d_optim.step()
                loss_dict["r1_d"] = r1_loss_d

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        e_loss_val = loss_reduced["e"].mean().item()
        r1_d_val = loss_reduced["r1_d"].mean().item()
        r1_e_val = loss_reduced["r1_e"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        rec_loss_val = loss_reduced["rec"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; "
                f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Encoder": e_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1 D": r1_d_val,
                    "R1 E": r1_e_val,
                    "Pix Loss": pix_loss_val,
                    "VGG Loss": vgg_loss_val,
                    "Adv Loss": adv_loss_val,
                    "Rec Loss": rec_loss_val,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    e_eval = encoder if args.no_ema else e_ema
                    e_eval.eval()
                    posterior.eval()
                    fake_img, fake_seq = reconstruct_sequence(
                        args, sample_x, e_eval, generator, factor, posterior)
                    utils.save_image(
                        torch.cat((sample_x, fake_seq), 1).view(-1, *shape[2:]),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-img_recon.png"),
                        nrow=T,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    util.save_video(
                        fake_seq[random.randint(0, args.n_sample - 1)],
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-vid_recon.mp4"),
                    )
                    fake_img, fake_seq = swap_sequence(
                        args, sample_x, e_eval, generator, factor, posterior)
                    utils.save_image(
                        torch.cat((sample_x, fake_seq), 1).view(-1, *shape[2:]),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-img_swap.png"),
                        nrow=T,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    e_eval.train()
                    posterior.train()

            if i % args.save_every == 0:
                e_eval = encoder if args.no_ema else e_ema
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if not args.debug and i % args.save_latest_every == 0:
                torch.save(
                    {
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_module.state_dict(),
                        "e_ema": e_eval.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
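# NOTE: `reconstruct_sequence` is defined elsewhere. The commented-out debug
# branches this trainer originally carried inline suggest the default path
# below (frame-wise encode -> posterior -> factor -> generator). This is a
# sketch reconstructed from those comments, not the authoritative helper;
# signature details (`i`, `ret_y`) are inferred from the call sites above.
def reconstruct_sequence(args, real_seq, encoder, generator, factor, posterior, i=None, ret_y=False):
    N, T = real_seq.shape[:2]
    latent_full = args.latent_full
    # Encode every frame independently: [N*T, latent_full]
    real_lat = encoder(real_seq.view(-1, *real_seq.shape[2:]))
    # Decompose into a static code f_post [N, latent_full] and dynamics y_post [N, T, D]
    f_post, y_post = posterior(real_lat.view(N, T, latent_full))
    z_post = factor(y_post)  # map dynamic coefficients back into latent space
    f_expand = f_post.unsqueeze(1).expand(-1, T, -1)  # broadcast static code over time
    w_post = f_expand + z_post  # [N, T, latent_full]
    fake_img, _ = generator([w_post.view(N * T, latent_full)],
                            input_is_latent=True, return_latents=False)
    fake_seq = fake_img.view(N, T, *real_seq.shape[2:])
    return (fake_img, fake_seq, y_post) if ret_y else (fake_img, fake_seq)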
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()

        loss_dict["r1"] = r1_loss

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = reduce_sum(mean_path_length).item() / get_world_size()

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
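# NOTE: the GAN loss helpers used throughout this file follow rosinality's
# stylegan2-pytorch. Minimal reference versions (assumed to match that repo's
# behavior) are sketched here:
import torch
from torch import autograd
from torch.nn import functional as F


def d_logistic_loss(real_pred, fake_pred):
    # softplus(-x) = -log sigmoid(x): non-saturating logistic discriminator loss
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()


def g_nonsaturating_loss(fake_pred):
    return F.softplus(-fake_pred).mean()


def d_r1_loss(real_pred, real_img):
    # R1 penalty: squared gradient norm of D's output w.r.t. the real images
    grad_real, = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()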
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device)

    args.n_sheets = int(np.ceil(args.n_classes / args.n_class_per_sheet))
    args.n_sample_per_sheet = args.n_sample_per_class * args.n_class_per_sheet
    args.n_sample = args.n_sample_per_sheet * args.n_sheets
    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_y = torch.arange(args.n_classes).repeat(args.n_sample_per_class, 1).t().reshape(-1).to(device)
    if args.n_sample > args.n_sample_per_class * args.n_classes:
        sample_y1 = make_fake_label(args.n_sample - args.n_sample_per_class * args.n_classes,
                                    args.n_classes, device)
        sample_y = torch.cat([sample_y, sample_y1], 0)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        for step_index in range(args.n_step_d):
            real_img, real_labels = next(loader)
            real_img, real_labels = real_img.to(device), real_labels.to(device)

            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_labels = make_fake_label(args.batch, args.n_classes, device)
            fake_img, _ = generator(noise, fake_labels)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img

            fake_pred = discriminator(fake_img, fake_labels)
            real_pred = discriminator(real_img_aug, real_labels)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(real_img_aug, real_labels)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_labels = make_fake_label(args.batch, args.n_classes, device)
        fake_img, _ = generator(noise, fake_labels)
        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)
        fake_pred = discriminator(fake_img, fake_labels)
        g_loss = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            # labels must match the (possibly shrunken) path-regularization batch
            fake_labels = make_fake_label(path_batch_size, args.n_classes, device)
            fake_img, latents = generator(noise, fake_labels, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = reduce_sum(mean_path_length).item() / get_world_size()

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update g_ema: g_ema = g * (1 - accum) + g_ema * accum, with EMA ramp-up
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5 ** (args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f};\n"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    for sheet_index in range(args.n_sheets):
                        sample_z_sheet = sample_z[sheet_index * args.n_sample_per_sheet:
                                                  (sheet_index + 1) * args.n_sample_per_sheet]
                        sample_y_sheet = sample_y[sheet_index * args.n_sample_per_sheet:
                                                  (sheet_index + 1) * args.n_sample_per_sheet]
                        sample, _ = g_ema([sample_z_sheet], sample_y_sheet)
                        utils.save_image(
                            sample,
                            os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}_{sheet_index}.png"),
                            nrow=args.n_sample_per_class,
                            normalize=True,
                            value_range=(-1, 1),
                        )

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent,
                        64, args.n_sample_fid, args.device,
                        n_classes=args.n_classes,
                    ).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                    f.write(f"{i:07d}; fid: {float(fid):.4f};\n")

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
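# NOTE: `make_fake_label` is not shown in this file. Its call sites only need a
# batch of random class indices, so a plausible sketch is the one-liner below;
# treat the name and exact behavior as assumptions.
import torch


def make_fake_label(batch, n_classes, device):
    # uniformly random class labels in [0, n_classes)
    return torch.randint(0, n_classes, (batch,), device=device)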
def train(args, loader, generator, encoder, discriminator, discriminator2,
          vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema, e_ema, device):
    kwargs_d = {'detach_aux': args.detach_d_aux_head}
    if args.dataset == 'imagefolder':
        loader = sample_data2(loader)
    else:
        loader = sample_data(loader)

    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    else:
        inception = real_mean = real_cov = None
    mean_latent = None

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d2_module = None
    if discriminator2 is not None:
        if args.distributed:
            d2_module = discriminator2.module
        else:
            d2_module = discriminator2

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    n_step_max = max(args.n_step_d, args.n_step_e)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        # real_img = next(loader).to(device)
        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(encoder, False)
        requires_grad(discriminator, True)
        if discriminator2 is not None:
            requires_grad(discriminator2, True)

        for step_index in range(args.n_step_d):
            real_img = real_imgs[step_index]
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img

            fake_pred = discriminator(fake_img, **kwargs_d)
            real_pred = discriminator(real_img_aug, **kwargs_d)

            d_loss_real1 = 0.
            if args.n_head_d > 1:
                fake_pred = fake_pred[0]
                real_pred, real_pred1 = real_pred[0], real_pred[1]
                d_loss_real1 = F.softplus(-real_pred1).mean()

            d_loss_real2 = 0.
            if args.decouple_d:
                real_pred2 = discriminator2(real_img_aug, **kwargs_d)
                d_loss_real2 = F.softplus(-real_pred2).mean()

            d_loss_real = F.softplus(-real_pred).mean()
            d_loss_fake = F.softplus(fake_pred).mean()
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True, return_latents=False)
            if args.augment:
                rec_img, _ = augment(rec_img, ada_aug_p)
            if not args.decouple_d:
                rec_pred = discriminator(rec_img, **kwargs_d)
                if args.n_head_d > 1:
                    rec_pred = rec_pred[1]
            else:
                rec_pred = discriminator2(rec_img, **kwargs_d)
            d_loss_rec = F.softplus(rec_pred).mean()

            d_loss = (d_loss_real + d_loss_real1 + d_loss_real2
                      + d_loss_fake * args.lambda_fake_d
                      + d_loss_rec * args.lambda_rec_d)
            loss_dict["rec_score"] = rec_pred.mean()
            loss_dict["d"] = d_loss

            if not args.decouple_d:
                discriminator.zero_grad()
                d_loss.backward()
                d_optim.step()
            else:
                discriminator.zero_grad()
                discriminator2.zero_grad()
                d_loss.backward()
                d_optim.step()
                d2_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img, **kwargs_d)
            if args.n_head_d > 1:
                real_pred = real_pred[0] + real_pred[1]
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        if d_regularize and args.decouple_d:
            real_img.requires_grad = True
            real_pred = discriminator2(real_img, **kwargs_d)
            if args.n_head_d > 1:
                real_pred = real_pred[0] + real_pred[1]
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator2.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d2_optim.step()

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        if discriminator2 is not None:
            requires_grad(discriminator2, False)

        real_img = real_imgs[0]
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)
        fake_pred = discriminator(fake_img, **kwargs_d)
        if args.n_head_d > 1:
            fake_pred = fake_pred[0]
        g_loss_fake = g_nonsaturating_loss(fake_pred)

        g_loss_rec = 0.
        if args.lambda_rec_g > 0:
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True, return_latents=False)
            if args.augment:
                rec_img, _ = augment(rec_img, ada_aug_p)
            if not args.decouple_d:
                rec_pred = discriminator(rec_img, **kwargs_d)
                if args.n_head_d > 1:
                    rec_pred = rec_pred[1]
            else:
                rec_pred = discriminator2(rec_img, **kwargs_d)
            g_loss_rec = g_nonsaturating_loss(rec_pred)

        g_loss = g_loss_fake * args.lambda_fake_g + g_loss_rec * args.lambda_rec_g
        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = reduce_sum(mean_path_length).item() / get_world_size()

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Train Encoder
        requires_grad(encoder, True)
        requires_grad(discriminator, False)
        requires_grad(generator, args.joint)
        if discriminator2 is not None:
            requires_grad(discriminator2, False)

        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        for step_index in range(args.n_step_e):
            # real_img = next(loader).to(device)
            real_img = real_imgs[step_index]
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True, return_latents=False)

            if args.lambda_adv > 0:
                if args.augment:
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    rec_img_aug = rec_img
                if not args.decouple_d:
                    rec_pred = discriminator(rec_img_aug, **kwargs_d)
                    if args.n_head_d > 1:
                        rec_pred = rec_pred[1]
                else:
                    rec_pred = discriminator2(rec_img_aug, **kwargs_d)
                adv_loss = g_nonsaturating_loss(rec_pred)

            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img) ** 2)
                else:
                    pix_loss = F.l1_loss(rec_img, real_img)
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img)) ** 2)

            e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
            loss_dict["e"] = e_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            if args.joint:
                encoder.zero_grad()
                generator.zero_grad()
                e_loss.backward()
                # generator.style.zero_grad()  # not necessary
                e_optim.step()
                g_optim.step()
            else:
                encoder.zero_grad()
                e_loss.backward()
                e_optim.step()

        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = generator([latent_x], input_is_latent=True, return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x) ** 2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                            f"ref: {sample_pix_loss.item()};\n")

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent,
                        64, args.n_sample_fid, args.device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                print("fid:", fid)
                with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                    f.write(f"{i:07d}; fid: {float(fid):.4f};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    # Fixed fake samples
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Reconstruction samples
                    e_ema.eval()
                    nrow = int(args.n_sample ** 0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_ema(sample_x)
                    fake_img, _ = g_ema([latent_real], input_is_latent=True, return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2": d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim": d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2": d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim": d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
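# NOTE: `requires_grad` and `accumulate` follow rosinality's stylegan2-pytorch;
# reference sketches (assumed, not copied from this repo) are given below.


def requires_grad(model, flag=True):
    # toggles which sub-network receives gradients in the alternating updates above
    for p in model.parameters():
        p.requires_grad = flag


def accumulate(model1, model2, decay=0.999):
    # EMA update: model1 <- decay * model1 + (1 - decay) * model2
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)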
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader, datatype="imagefolder")

    # inception related:
    if get_rank() == 0:
        from calc_inception import load_patched_inception_v3
        inception = load_patched_inception_v3().to(device)
        inception.eval()
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-' * 50}\n")

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img, args)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()

        loss_dict["r1"] = r1_loss

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # g_reg starts
        g_regularize = False
        if args.useG_reg:
            g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = reduce_sum(mean_path_length).item() / get_world_size()

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            # inception related:
            if args.eval_every > 0 and i % args.eval_every == 0:
                real_mean = real_cov = mean_latent = None
                with open(args.inception, "rb") as f:
                    embeds = pickle.load(f)
                    real_mean = embeds["mean"]
                    real_cov = embeds["cov"]
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent,
                        64, args.n_sample_fid, device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                    f.write(f"{i:07d}; fid: {float(fid):.4f};\n")

            if i % args.log_every == 0:
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f};\n"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}.png"),
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
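# NOTE: `sample_data` wraps a finite DataLoader in an endless generator so the
# training loop can call next(loader) indefinitely. A minimal sketch; how the
# `datatype` argument used above changes batch unpacking is an assumption and
# is elided here.
def sample_data(loader, datatype=None):
    while True:
        for batch in loader:
            yield batch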
def train(args, loader, loader2, generator, encoder, discriminator, vggnet,
          g_optim, e_optim, d_optim, g_ema, e_ema, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    d_loss_val = r1_val = real_score_val = recx_score_val = 0
    loss_dict = {
        "d": torch.tensor(0.0, device=device),
        "r1": torch.tensor(0.0, device=device),
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d_weight = torch.tensor(1.0, device=device)
    last_layer = None
    if args.use_adaptive_weight:
        if args.distributed:
            last_layer = generator.module.get_last_layer()
        else:
            last_layer = generator.get_last_layer()

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    r_t_dict = {'real': 0, 'recx': 0}  # r_t stat
    g_scale = 1

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        if args.debug:
            util.seed_everything(i)
        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator
        if args.lambda_adv > 0:
            requires_grad(generator, False)
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
            for step_index in range(args.n_step_d):
                real_img = real_imgs[step_index]
                latent_real, _ = encoder(real_img)
                rec_img, _ = generator([latent_real], input_is_latent=True)
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                    rec_img_aug = rec_img
                real_pred = discriminator(real_img_aug)
                rec_pred = discriminator(rec_img_aug)
                d_loss_real = F.softplus(-real_pred).mean()
                d_loss_rec = F.softplus(rec_pred).mean()

                loss_dict["real_score"] = real_pred.mean()
                loss_dict["recx_score"] = rec_pred.mean()

                d_loss = d_loss_real + d_loss_rec * args.lambda_rec_d
                loss_dict["d"] = d_loss

                discriminator.zero_grad()
                d_loss.backward()
                d_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat

            # Compute batchwise r_t
            r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch

            d_regularize = i % args.d_reg_every == 0
            if d_regularize:
                real_img.requires_grad = True
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                real_pred = discriminator(real_img_aug)
                r1_loss = d_r1_loss(real_pred, real_img)

                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
                d_optim.step()
            loss_dict["r1"] = r1_loss

            r_t_dict['recx'] = torch.sign(rec_pred).sum().item() / args.batch

        # Train AutoEncoder
        requires_grad(encoder, True)
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        if args.debug:
            util.seed_everything(i)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        for step_index in range(args.n_step_e):
            real_img = real_imgs[step_index]
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True)

            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img) ** 2)
                elif args.pix_loss == 'l1':
                    pix_loss = F.l1_loss(rec_img, real_img)
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img)) ** 2)
            if args.lambda_adv > 0:
                if args.augment:
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    rec_img_aug = rec_img
                rec_pred = discriminator(rec_img_aug)
                adv_loss = g_nonsaturating_loss(rec_pred)

            if args.use_adaptive_weight and i >= args.disc_iter_start:
                nll_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg
                g_loss = adv_loss * args.lambda_adv
                d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)

            ae_loss = (pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg
                       + d_weight * adv_loss * args.lambda_adv)
            loss_dict["ae"] = ae_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            encoder.zero_grad()
            generator.zero_grad()
            ae_loss.backward()
            e_optim.step()
            if args.g_decay is not None:
                scale_grad(generator, g_scale)
                g_scale *= args.g_decay
            g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = reduce_sum(mean_path_length).item() / get_world_size()

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update EMA
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5 ** (args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, 0 if args.no_ema_g else accum)
        accumulate(e_ema, e_module, 0 if args.no_ema_e else accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        ae_loss_val = loss_reduced["ae"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        if args.lambda_adv > 0:
            d_loss_val = loss_reduced["d"].mean().item()
            r1_val = loss_reduced["r1"].mean().item()
            real_score_val = loss_reduced["real_score"].mean().item()
            recx_score_val = loss_reduced["recx_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; ae: {ae_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"d_weight: {d_weight.item():.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample ** 0.5)
                    nchw = list(sample_x.shape)[1:]
                    # Reconstruction of real images
                    latent_x, _ = e_ema(sample_x)
                    rec_real, _ = g_ema([latent_x], input_is_latent=True)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         rec_real.reshape(args.n_sample // nrow, nrow, *nchw)), 1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    ref_pix_loss = torch.sum(torch.abs(sample_x - rec_real))
                    ref_vgg_loss = (torch.mean((vggnet(sample_x) - vggnet(rec_real)) ** 2)
                                    if vggnet is not None else torch.tensor(0.))
                    # Fixed fake samples
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; "
                        f"real_score: {real_score_val:.4f}; recx_score: {recx_score_val:.4f}; "
                        f"pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; "
                        f"ref_pix: {ref_pix_loss.item():.4f}; ref_vgg: {ref_vgg_loss.item():.4f}; "
                        f"d_weight: {d_weight.item():.4f}; "
                        f"\n"))

            if wandb and args.wandb:
                wandb.log({
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Path Length": path_length_val,
                })
path_length_val, }) if args.eval_every > 0 and i % args.eval_every == 0: with torch.no_grad(): fid_sa = fid_re = fid_sr = 0 g_ema.eval() e_ema.eval() if args.truncation < 1: mean_latent = g_ema.mean_latent(4096) # Real reconstruction FID if 'fid_recon' in args.which_metric: features = extract_feature_from_reconstruction( e_ema, g_ema, inception, args.truncation, mean_latent, loader2, args.device, mode='recon', ).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid_re = calc_fid(sample_mean, sample_cov, real_mean, real_cov) with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"{i:07d}; rec_real: {float(fid_re):.4f};\n") if i % args.save_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), ) if i % args.save_latest_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"latest.pt"), )
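# `calculate_adaptive_weight` above is called but not defined in this file.
# A minimal sketch, assuming the VQGAN-style heuristic: scale the adversarial
# term so its gradient at the generator's last layer matches the reconstruction
# (NLL) gradient. The 1e-4 epsilon and 1e4 clamp are assumed defaults, not
# confirmed from this repo.
import torch


def calculate_adaptive_weight(nll_loss, g_loss, last_layer, max_weight=1e4):
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    # Ratio of gradient norms, detached so it acts as a scalar multiplier only.
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    return torch.clamp(d_weight, 0.0, max_weight).detach()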
def train(args, loader, loader2, T_list, generator, encoder, discriminator, discriminator2, vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema, e_ema, device): # kwargs_d = {'detach_aux': args.detach_d_aux_head} inception = real_mean = real_cov = mean_latent = None if args.eval_every > 0: inception = nn.DataParallel(load_patched_inception_v3()).to(device) inception.eval() with open(args.inception, "rb") as f: embeds = pickle.load(f) real_mean = embeds["mean"] real_cov = embeds["cov"] if get_rank() == 0: if args.eval_every > 0: with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") if args.log_every > 0: with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") if args.dataset == 'imagefolder': loader = sample_data2(loader) else: loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} avg_pix_loss = util.AverageMeter() avg_vgg_loss = util.AverageMeter() if args.distributed: g_module = generator.module e_module = encoder.module d_module = discriminator.module else: g_module = generator e_module = encoder d_module = discriminator d2_module = None if discriminator2 is not None: if args.distributed: d2_module = discriminator2.module else: d2_module = discriminator2 accum = 0.5**(32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device) sample_z = torch.randn(args.n_sample, args.latent, device=device) sample_x, sample_idx = load_real_samples(args, loader) assert (sample_x.shape[1] >= args.nframe_num) sample_x1 = sample_x[:, 0, ...] sample_x2 = sample_x[:, -1, ...] # sample_idx = torch.randperm(args.n_sample) fid_batch_idx = sample_idx = None n_step_max = max(args.n_step_d, args.n_step_e) requires_grad(g_ema, False) requires_grad(e_ema, False) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break frames = [ get_batch(loader, device, T_list[i], not args.no_rand_T) for _ in range(n_step_max) ] # Train Discriminator requires_grad(generator, False) requires_grad(encoder, False) requires_grad(discriminator, True) for step_index in range(args.n_step_d): frames1, frames2 = frames[step_index] real_img = frames1 noise = mixing_noise(args.batch, args.latent, args.mixing, device) if args.use_ema: g_ema.eval() fake_img, _ = g_ema(noise) else: fake_img, _ = generator(noise) if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_pred = discriminator(fake_img) real_pred = discriminator(real_img_aug) d_loss_fake = F.softplus(fake_pred).mean() d_loss_real = F.softplus(-real_pred).mean() if args.use_frames2_d: real_pred2 = discriminator(frames2) d_loss_real += F.softplus(-real_pred2).mean() loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() d_loss_rec = 0. 
if not args.decouple_d and args.lambda_rec_d > 0: # Do not train D on x_rec if decouple_d if args.use_ema: e_ema.eval() g_ema.eval() latent_real, _ = e_ema(real_img) rec_img, _ = g_ema([latent_real], input_is_latent=True) else: latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True) rec_pred = discriminator(rec_img) d_loss_rec = F.softplus(rec_pred).mean() loss_dict["rec_score"] = rec_pred.mean() d_loss_cross = 0. if not args.decouple_d and args.lambda_cross_d > 0: if args.use_ema: e_ema.eval() w1, _ = e_ema(frames1) w2, _ = e_ema(frames2) else: w1, _ = encoder(frames1) w2, _ = encoder(frames2) dw = w2 - w1 dw_shuffle = dw[torch.randperm(args.batch), ...] if args.use_ema: g_ema.eval() cross_img, _ = g_ema([w1 + dw_shuffle], input_is_latent=True) else: cross_img, _ = generator([w1 + dw_shuffle], input_is_latent=True) cross_pred = discriminator(cross_img) d_loss_cross = F.softplus(cross_pred).mean() d_loss_fake_cross = 0. if args.lambda_fake_cross_d > 0: if args.use_ema: e_ema.eval() w1, _ = e_ema(frames1) w2, _ = e_ema(frames2) else: w1, _ = encoder(frames1) w2, _ = encoder(frames2) dw = w2 - w1 noise = mixing_noise(args.batch, args.latent, args.mixing, device) if args.use_ema: g_ema.eval() style = g_ema.get_styles(noise).view(args.batch, -1) else: style = generator.get_styles(noise).view(args.batch, -1) if dw.shape[1] < style.shape[1]: # W space dw = dw.repeat(1, args.n_latent) if args.use_ema: cross_img, _ = g_ema([style + dw], input_is_latent=True) else: cross_img, _ = generator([style + dw], input_is_latent=True) fake_cross_pred = discriminator(cross_img) d_loss_fake_cross = F.softplus(fake_cross_pred).mean() d_loss = (d_loss_real + d_loss_fake + d_loss_fake_cross * args.lambda_fake_cross_d + d_loss_rec * args.lambda_rec_d + d_loss_cross * args.lambda_cross_d) loss_dict["d"] = d_loss discriminator.zero_grad() d_loss.backward() d_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict["r1"] = r1_loss # Train Discriminator2 if args.decouple_d and discriminator2 is not None: requires_grad(generator, False) requires_grad(encoder, False) requires_grad(discriminator2, True) for step_index in range( args.n_step_e): # n_step_d2 is same as n_step_e frames1, frames2 = frames[step_index] real_img = frames1 if args.use_ema: e_ema.eval() g_ema.eval() latent_real, _ = e_ema(real_img) rec_img, _ = g_ema([latent_real], input_is_latent=True) else: latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True) real_pred1 = discriminator2(frames1) rec_pred = discriminator2(rec_img) d2_loss_real = F.softplus(-real_pred1).mean() d2_loss_rec = F.softplus(rec_pred).mean() if args.use_frames2_d: real_pred2 = discriminator2(frames2) d2_loss_real += F.softplus(-real_pred2).mean() if args.use_ema: e_ema.eval() w1, _ = e_ema(frames1) w2, _ = e_ema(frames2) else: w1, _ = encoder(frames1) w2, _ = encoder(frames2) dw = w2 - w1 dw_shuffle = dw[torch.randperm(args.batch), ...] 
if args.use_ema: g_ema.eval() cross_img, _ = g_ema([w1 + dw_shuffle], input_is_latent=True) else: cross_img, _ = generator([w1 + dw_shuffle], input_is_latent=True) cross_pred = discriminator2(cross_img) d2_loss_cross = F.softplus(cross_pred).mean() d2_loss = d2_loss_real + d2_loss_rec + d2_loss_cross loss_dict["d2"] = d2_loss loss_dict["rec_score"] = rec_pred.mean() loss_dict["cross_score"] = cross_pred.mean() discriminator2.zero_grad() d2_loss.backward() d2_optim.step() d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = discriminator2(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator2.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d2_optim.step() # Train Encoder and Generator requires_grad(encoder, True) requires_grad(generator, True) requires_grad(discriminator, False) requires_grad(discriminator2, False) pix_loss = vgg_loss = adv_loss = cross_loss = torch.tensor( 0., device=device) for step_index in range(args.n_step_e): frames1, frames2 = frames[step_index] real_img = frames1 latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True) if args.lambda_pix > 0: if args.pix_loss == 'l2': pix_loss = torch.mean((rec_img - real_img)**2) else: pix_loss = F.l1_loss(rec_img, real_img) if args.lambda_vgg > 0: vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2) if args.lambda_adv > 0: if not args.decouple_d: rec_pred = discriminator(rec_img) else: rec_pred = discriminator2(rec_img) adv_loss = g_nonsaturating_loss(rec_pred) if args.lambda_cross > 0: w1, _ = encoder(frames1) w2, _ = encoder(frames2) dw = w2 - w1 dw_shuffle = dw[torch.randperm(args.batch), ...] cross_img, _ = generator([w1 + dw_shuffle], input_is_latent=True) if not args.decouple_d: cross_pred = discriminator(cross_img) else: cross_pred = discriminator2(cross_img) cross_loss = g_nonsaturating_loss(cross_pred) e_loss = (pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv + cross_loss * args.lambda_cross) loss_dict["e"] = e_loss loss_dict["pix"] = pix_loss loss_dict["vgg"] = vgg_loss loss_dict["adv"] = adv_loss encoder.zero_grad() generator.zero_grad() e_loss.backward() e_optim.step() g_optim.step() # Train Generator requires_grad(generator, True) requires_grad(encoder, False) requires_grad(discriminator, False) requires_grad(discriminator2, False) frames1, frames2 = frames[0] real_img = frames1 g_loss_fake = 0. if args.lambda_fake_g > 0: noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: fake_img, _ = augment(fake_img, ada_aug_p) fake_pred = discriminator(fake_img) g_loss_fake = g_nonsaturating_loss(fake_pred) if args.no_sim_g: generator.zero_grad() (g_loss_fake * args.lambda_fake_g).backward() g_optim.step() g_loss_fake_cross = 0. 
if args.lambda_fake_cross_g > 0: if args.use_ema: e_ema.eval() w1, _ = e_ema(frames1) w2, _ = e_ema(frames2) else: w1, _ = encoder(frames1) w2, _ = encoder(frames2) dw = w2 - w1 noise = mixing_noise(args.batch, args.latent, args.mixing, device) style = generator.get_styles(noise).view(args.batch, -1) if dw.shape[1] < style.shape[1]: # W space dw = dw.repeat(1, args.n_latent) cross_img, _ = generator([style + dw], input_is_latent=True) fake_cross_pred = discriminator(cross_img) g_loss_fake_cross = g_nonsaturating_loss(fake_cross_pred) if args.no_sim_g: generator.zero_grad() (g_loss_fake_cross * args.lambda_fake_cross_g).backward() g_optim.step() g_loss = g_loss_fake * args.lambda_fake_g + g_loss_fake_cross * args.lambda_fake_cross_g loss_dict["g"] = g_loss if not args.no_sim_g: generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(e_ema, e_module, accum) accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() pix_loss_val = loss_reduced["pix"].mean().item() vgg_loss_val = loss_reduced["vgg"].mean().item() adv_loss_val = loss_reduced["adv"].mean().item() avg_pix_loss.update(pix_loss_val, real_img.shape[0]) avg_vgg_loss.update(vgg_loss_val, real_img.shape[0]) if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}; " f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}" )) if i % args.log_every == 0: with torch.no_grad(): latent_x, _ = e_ema(sample_x1) fake_x, _ = generator([latent_x], input_is_latent=True, return_latents=False) sample_pix_loss = torch.sum((sample_x1 - fake_x)**2) with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write( f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; " f"ref: {sample_pix_loss.item()};\n") if args.eval_every > 0 and i % args.eval_every == 0: with torch.no_grad(): g_ema.eval() e_ema.eval() # Sample FID if args.truncation < 1: mean_latent = g_ema.mean_latent(4096) features = extract_feature_from_samples( g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, args.device).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid_sa = calc_fid(sample_mean, sample_cov, real_mean, real_cov) # Recon FID features = extract_feature_from_recon_hybrid( e_ema, g_ema, inception, args.truncation, mean_latent, 
loader2, args.device, mode='recon', ).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid_re = calc_fid(sample_mean, sample_cov, real_mean, real_cov) # Hybrid FID features = extract_feature_from_recon_hybrid( e_ema, g_ema, inception, args.truncation, mean_latent, loader2, args.device, mode='hybrid', shuffle_idx=fid_batch_idx).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid_hy = calc_fid(sample_mean, sample_cov, real_mean, real_cov) # print("Sample FID:", fid_sa, "Recon FID:", fid_re, "Hybrid FID:", fid_hy) with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write( f"{i:07d}; sample fid: {float(fid_sa):.4f}; recon fid: {float(fid_re):.4f}; hybrid fid: {float(fid_hy):.4f};\n" ) if wandb and args.wandb: wandb.log({ "Generator": g_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, }) if i % args.log_every == 0: with torch.no_grad(): g_ema.eval() e_ema.eval() nrow = int(args.n_sample**0.5) nchw = list(sample_x1.shape)[1:] # Fixed fake samples sample, _ = g_ema([sample_z]) utils.save_image( sample, os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) # Fake hybrid samples w1, _ = e_ema(sample_x1) w2, _ = e_ema(sample_x2) dw = w2 - w1 style_z = g_ema.get_styles([sample_z ]).view(args.n_sample, -1) if dw.shape[1] < style_z.shape[1]: # W space dw = dw.repeat(1, args.n_latent) fake_img, _ = g_ema([style_z + dw], input_is_latent=True) drive = torch.cat(( sample_x1.reshape(args.n_sample, 1, *nchw), sample_x2.reshape(args.n_sample, 1, *nchw), ), 1) source = torch.cat(( sample.reshape(args.n_sample, 1, *nchw), fake_img.reshape(args.n_sample, 1, *nchw), ), 1) sample = torch.cat(( drive.reshape(args.n_sample // nrow, 2 * nrow, *nchw), source.reshape(args.n_sample // nrow, 2 * nrow, *nchw), ), 1) utils.save_image( sample.reshape(4 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample_hybrid.png"), nrow=2 * nrow, normalize=True, value_range=(-1, 1), ) # Reconstruction samples latent_real, _ = e_ema(sample_x1) fake_img, _ = g_ema([latent_real], input_is_latent=True, return_latents=False) sample = torch.cat( (sample_x1.reshape(args.n_sample // nrow, nrow, *nchw), fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1) utils.save_image( sample.reshape(2 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) # Cross reconstruction: [real_y1, real_y2; real_x1, fake_x2] w1, _ = e_ema(sample_x1) w2, _ = e_ema(sample_x2) dw = w2 - w1 dw = torch.cat(dw.chunk(2, 0)[::-1], 0) if sample_idx is None else dw[sample_idx, ...] 
fake_img, _ = g_ema([w1 + dw], input_is_latent=True, return_latents=False) # sample = torch.cat((sample_x2.reshape(args.n_sample//nrow, nrow, *nchw), # fake_img.reshape(args.n_sample//nrow, nrow, *nchw)), 1) drive = torch.cat(( torch.cat(sample_x1.chunk(2, 0)[::-1], 0).reshape( args.n_sample, 1, *nchw), torch.cat(sample_x2.chunk(2, 0)[::-1], 0).reshape( args.n_sample, 1, *nchw), ), 1) # [n_sample, 2, C, H, w] source = torch.cat(( sample_x1.reshape(args.n_sample, 1, *nchw), fake_img.reshape(args.n_sample, 1, *nchw), ), 1) # [n_sample, 2, C, H, w] sample = torch.cat(( drive.reshape(args.n_sample // nrow, 2 * nrow, *nchw), source.reshape(args.n_sample // nrow, 2 * nrow, *nchw), ), 1) # [n_sample//nrow, 4*nrow, C, H, W] utils.save_image( sample.reshape(4 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-cross.png"), nrow=2 * nrow, normalize=True, value_range=(-1, 1), ) if i % args.save_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "d2": d2_module.state_dict() if args.decouple_d else None, "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "d2_optim": d2_optim.state_dict() if args.decouple_d else None, "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), ) if i % args.save_latest_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "d2": d2_module.state_dict() if args.decouple_d else None, "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "d2_optim": d2_optim.state_dict() if args.decouple_d else None, "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"latest.pt"), )
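# The R1 and path-length penalties used throughout are called but not defined
# here; sketched below assuming the standard stylegan2-pytorch implementations.
import math

import torch
from torch import autograd


def d_r1_loss(real_pred, real_img):
    # R1: squared gradient norm of D's output w.r.t. the real images.
    grad_real, = autograd.grad(outputs=real_pred.sum(), inputs=real_img,
                               create_graph=True)
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    # Path length regularization: keep |J_w^T y| close to its running mean.
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3])
    grad, = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents,
                          create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths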
class StyleGAN2(pl.LightningModule): def __init__(self, hparams) -> None: super().__init__() self.hparams = hparams self.generator = Generator( self.hparams.size_h, self.hparams.size_w, self.hparams.log_size, self.hparams.latent, self.hparams.n_mlp, channel_multiplier=self.hparams.channel_multiplier) self.discriminator = Discriminator( self.hparams.size_h, self.hparams.size_w, self.hparams.log_size, channel_multiplier=self.hparams.channel_multiplier) self.g_ema = Generator( self.hparams.size_h, self.hparams.size_w, self.hparams.log_size, self.hparams.latent, self.hparams.n_mlp, channel_multiplier=self.hparams.channel_multiplier) self.img_dim = (3, self.hparams.size_h, self.hparams.size_w) self.accum = 0.5**(32 / (10 * 1000)) self.mean_path_length = 0 self.ada_aug_p = max(self.hparams.augment_p, 0) self.use_ada_augment = self.hparams.augment and self.hparams.augment_p == 0 if self.use_ada_augment: self.ada_augment = AdaptiveAugment(self.hparams.ada_target, self.hparams.ada_length, self.hparams.ada_every) self.g_ema.eval() accumulate(self.g_ema, self.generator, 0) self.example_input_array = [torch.zeros(1, self.hparams.latent)] def configure_optimizers(self): cfg = self.hparams g_reg_ratio = cfg.g_reg_every / (cfg.g_reg_every + 1) d_reg_ratio = cfg.d_reg_every / (cfg.d_reg_every + 1) g_optim = optim.Adam( self.generator.parameters(), lr=cfg.lr * g_reg_ratio, betas=(0**g_reg_ratio, 0.99**g_reg_ratio), ) d_optim = optim.Adam( self.discriminator.parameters(), lr=cfg.lr * d_reg_ratio, betas=(0**d_reg_ratio, 0.99**d_reg_ratio), ) return ( { 'optimizer': d_optim }, { 'optimizer': d_optim }, { 'optimizer': g_optim }, { 'optimizer': g_optim }, ) def training_step(self, batch, batch_idx, optimizer_idx): real_img = batch[0] batch_size = real_img.shape[0] if optimizer_idx == 0: requires_grad(self.generator, False) requires_grad(self.discriminator, True) noise = mixing_noise(batch_size, self.hparams.latent, self.hparams.mixing, self.device) fake_img, _ = self.generator(noise) if self.hparams.augment: real_img_aug, _ = augment(real_img, self.ada_aug_p) fake_img, _ = augment(fake_img, self.ada_aug_p) else: real_img_aug = real_img fake_pred = self.discriminator(fake_img) real_pred = self.discriminator(real_img_aug) d_loss = d_logistic_loss(real_pred, fake_pred) self.log('real_score', real_pred.mean()) self.log('fake_score', fake_pred.mean()) self.log('d', d_loss, prog_bar=True) if self.use_ada_augment: self.ada_aug_p = self.ada_augment.tune(real_pred) if self.hparams.augment: self.log('ada_aug_p', self.ada_aug_p, prog_bar=True) return {'loss': d_loss} if optimizer_idx == 1: if self.trainer.global_step % self.hparams.d_reg_every != 0: return requires_grad(self.generator, False) requires_grad(self.discriminator, True) real_img.requires_grad = True if self.hparams.augment: real_img_aug, _ = augment(real_img, self.ada_aug_p) else: real_img_aug = real_img real_pred = self.discriminator(real_img_aug) r1_loss = d_r1_loss(real_pred, real_img) self.log('r1', r1_loss, prog_bar=True) return { 'loss': (self.hparams.r1 / 2 * r1_loss * self.hparams.d_reg_every + 0 * real_pred[0]) } if optimizer_idx == 2: requires_grad(self.generator, True) requires_grad(self.discriminator, False) if self.hparams.top_k_batches > 0: with torch.no_grad(): noises, scores = [], [] rand = random.random() for _ in range(self.hparams.top_k_batches): noise = mixing_noise(self.hparams.top_k_batch_size, self.hparams.latent, self.hparams.mixing, self.device, rand=rand) fake_img, _ = self.generator(noise) if self.hparams.augment: fake_img, _ = 
augment(fake_img, self.ada_aug_p) score = self.discriminator(fake_img) noises.append(noise) scores.append(score) scores = torch.cat(scores) best_score_ids = torch.argsort( scores, descending=True)[:batch_size].squeeze(1) noise = [ torch.cat([n[idx] for n in noises], dim=0)[best_score_ids] for idx in range(len(noises[0])) ] else: noise = mixing_noise(batch_size, self.hparams.latent, self.hparams.mixing, self.device) fake_img, _ = self.generator(noise) if self.hparams.augment: fake_img, _ = augment(fake_img, self.ada_aug_p) fake_pred = self.discriminator(fake_img) g_loss = g_nonsaturating_loss(fake_pred) self.log('g', g_loss, prog_bar=True) return {'loss': g_loss} if optimizer_idx == 3: if self.trainer.global_step % self.hparams.g_reg_every != 0: return requires_grad(self.generator, True) requires_grad(self.discriminator, False) path_batch_size = max(1, batch_size // self.hparams.path_batch_shrink) noise = mixing_noise(path_batch_size, self.hparams.latent, self.hparams.mixing, self.device) fake_img, latents = self.generator(noise, return_latents=True) path_loss, self.mean_path_length, path_lengths = g_path_regularize( fake_img, latents, self.mean_path_length) weighted_path_loss = self.hparams.path_regularize * self.hparams.g_reg_every * path_loss if self.hparams.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] self.log('path', path_loss, prog_bar=True) self.log('mean_path', self.mean_path_length, prog_bar=True) self.log('path_length', path_lengths.mean()) return {'loss': weighted_path_loss} def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): accumulate(self.g_ema, self.generator, self.accum) def forward(self, batch): return self.g_ema([batch])
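# A hypothetical smoke test for the LightningModule above; every hparam value
# and the `make_dataloader` helper are illustrative assumptions, not this
# repo's actual config. With four optimizer dicts returned by
# configure_optimizers, Lightning calls training_step once per optimizer_idx
# for every batch; indices 1 and 3 self-gate on d_reg_every / g_reg_every.
#
#     from argparse import Namespace
#     import pytorch_lightning as pl
#
#     hparams = Namespace(
#         size_h=256, size_w=256, log_size=8, latent=512, n_mlp=8,
#         channel_multiplier=2, lr=2e-3, d_reg_every=16, g_reg_every=4,
#         mixing=0.9, r1=10, path_regularize=2, path_batch_shrink=2,
#         augment=False, augment_p=0, ada_target=0.6, ada_length=500_000,
#         ada_every=256, top_k_batches=0)
#     model = StyleGAN2(hparams)
#     trainer = pl.Trainer(gpus=1, max_steps=1000)
#     trainer.fit(model, make_dataloader(hparams))  # make_dataloader: placeholder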
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    print("Save checkpoint every %d iter." % args.save_iter)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    # The scaler must exist before the first scaler.scale(...) call below;
    # without this line the loop raises NameError.
    scaler = torch.cuda.amp.GradScaler()

    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter
        print("iter: %d" % i)

        if i > args.iter:
            print("Done!")
            break

        ##############################################################################
        # Casts operations to mixed precision
        # https://pytorch.org/docs/stable/notes/amp_examples.html
        with torch.cuda.amp.autocast():
            real_img = next(loader)
            real_img = real_img.to(device)

            requires_grad(generator, False)
            requires_grad(discriminator, True)

            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img

            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img_aug)
            d_loss = d_logistic_loss(real_pred, fake_pred)

            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        # Scales loss. Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(d_loss).backward()
        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(d_optim)
        scaler.update()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        # Runs the forward pass with autocasting.
        if d_regularize:
            with torch.cuda.amp.autocast():
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            scaler.scale(
                args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]
            ).backward()
            scaler.step(d_optim)
            scaler.update()

            loss_dict["r1"] = r1_loss

        # Runs the forward pass with autocasting.
        with torch.cuda.amp.autocast():
            requires_grad(generator, True)
            requires_grad(discriminator, False)

            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)

            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)

            fake_pred = discriminator(fake_img)
            g_loss = g_nonsaturating_loss(fake_pred)

            loss_dict["g"] = g_loss

        generator.zero_grad()
        scaler.scale(g_loss).backward()  # g_loss.backward()
        scaler.step(g_optim)  # g_optim.step()
        scaler.update()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            with torch.cuda.amp.autocast():
                path_batch_size = max(1, args.batch // args.path_batch_shrink)
                noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
                fake_img, latents = generator(noise, return_latents=True)

                path_loss, mean_path_length, path_lengths = g_path_regularize(
                    fake_img, latents, mean_path_length)

                weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
                if args.path_batch_shrink:
                    weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            generator.zero_grad()
            scaler.scale(weighted_path_loss).backward()  # weighted_path_loss.backward()
            scaler.step(g_optim)  # g_optim.step()
            # Updates the scale for next iteration
            scaler.update()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_iter == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
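# `mixing_noise` is used by every loop in this file; a minimal sketch assuming
# the standard stylegan2-pytorch style-mixing scheme. (The Lightning variant
# above also passes a `rand=` kwarg, so the real signature in this repo likely
# extends this one.)
import random

import torch


def make_noise(batch, latent_dim, n_noise, device):
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)
    return list(torch.randn(n_noise, batch, latent_dim, device=device).unbind(0))


def mixing_noise(batch, latent_dim, prob, device):
    # With probability `prob`, return two latents so the generator mixes styles.
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)
    return [make_noise(batch, latent_dim, 1, device)]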
def train(args, loader, loader2, generator, encoder, discriminator, discriminator2, vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema, e_ema, device): inception = real_mean = real_cov = mean_latent = None if args.eval_every > 0: inception = nn.DataParallel(load_patched_inception_v3()).to(device) inception.eval() with open(args.inception, "rb") as f: embeds = pickle.load(f) real_mean = embeds["mean"] real_cov = embeds["cov"] if get_rank() == 0: if args.eval_every > 0: with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") if args.log_every > 0: with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = { 'recx_score': torch.tensor(0.0, device=device), 'ae_fake': torch.tensor(0.0, device=device), 'ae_real': torch.tensor(0.0, device=device), 'pix': torch.tensor(0.0, device=device), 'vgg': torch.tensor(0.0, device=device), } avg_pix_loss = util.AverageMeter() avg_vgg_loss = util.AverageMeter() if args.distributed: g_module = generator.module e_module = encoder.module d_module = discriminator.module else: g_module = generator e_module = encoder d_module = discriminator d2_module = None if discriminator2 is not None: if args.distributed: d2_module = discriminator2.module else: d2_module = discriminator2 # When joint training enabled, d_weight balances reconstruction loss and adversarial loss on # recontructed real images. This does not balance the overall AE loss and GAN loss. d_weight = torch.tensor(1.0, device=device) last_layer = None if args.use_adaptive_weight: if args.distributed: last_layer = generator.module.get_last_layer() else: last_layer = generator.get_last_layer() g_scale = 1 ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 r_t_dict = {'real': 0, 'fake': 0, 'recx': 0} # r_t stat if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device) sample_z = torch.randn(args.n_sample, args.latent, device=device) sample_x = load_real_samples(args, loader) if sample_x.ndim > 4: sample_x = sample_x[:, 0, ...] input_is_latent = args.latent_space != 'z' # Encode in z space? 
n_step_max = max(args.n_step_d, args.n_step_e) requires_grad(g_ema, False) requires_grad(e_ema, False) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break if args.debug: util.seed_everything(i) real_imgs = [next(loader).to(device) for _ in range(n_step_max)] # Train Discriminator and Encoder requires_grad(generator, False) requires_grad(encoder, True) requires_grad(discriminator, True) requires_grad(discriminator2, True) for step_index in range(args.n_step_d): real_img = real_imgs[step_index] noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img_aug, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_img_aug = fake_img real_pred = discriminator(encoder(real_img_aug)[0]) fake_pred = discriminator(encoder(fake_img_aug)[0]) d_loss_real = F.softplus(-real_pred).mean() d_loss_fake = F.softplus(fake_pred).mean() loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() d_loss_rec = 0. if args.lambda_rec_d > 0: latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=input_is_latent) if args.augment: rec_img, _ = augment(rec_img, ada_aug_p) rec_pred = discriminator(encoder(rec_img)[0]) d_loss_rec = F.softplus(rec_pred).mean() loss_dict["recx_score"] = rec_pred.mean() r_t_dict['recx'] = torch.sign( rec_pred).sum().item() / args.batch d_loss = d_loss_real + d_loss_fake * args.lambda_fake_d + d_loss_rec * args.lambda_rec_d loss_dict["d"] = d_loss discriminator.zero_grad() encoder.zero_grad() d_loss.backward() d_optim.step() e_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch r_t_dict['fake'] = torch.sign(fake_pred).sum().item() / args.batch d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) else: real_img_aug = real_img real_pred = discriminator(encoder(real_img_aug)[0]) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() encoder.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() e_optim.step() loss_dict["r1"] = r1_loss # Train Generator requires_grad(generator, True) requires_grad(encoder, False) requires_grad(discriminator, False) requires_grad(discriminator2, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: fake_img_aug, _ = augment(fake_img, ada_aug_p) else: fake_img_aug = fake_img fake_pred = discriminator(encoder(fake_img_aug)[0]) g_loss_fake = g_nonsaturating_loss(fake_pred) loss_dict["g"] = g_loss_fake generator.zero_grad() (g_loss_fake * args.lambda_fake_g).backward() g_optim.step() g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() 
mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() # Train Encoder (and Generator) joint = (not args.no_joint) and (g_scale > 1e-6) # Train AE on fake samples (latent reconstruction) if args.lambda_rec_w + (args.lambda_pix_fake + args.lambda_vgg_fake + args.lambda_adv_fake) > 0: requires_grad(encoder, True) requires_grad(generator, joint) requires_grad(discriminator, False) requires_grad(discriminator2, False) for step_index in range(args.n_step_e): # mixing_prob = 0 if args.which_latent == 'w_tied' else args.mixing # noise = mixing_noise(args.batch, args.latent, mixing_prob, device) # fake_img, latent_fake = generator(noise, return_latents=True, detach_style=not args.no_detach_style) # if args.which_latent == 'w_tied': # latent_fake = latent_fake[:,0,:] # else: # latent_fake = latent_fake.view(args.batch, -1) # latent_pred, _ = encoder(fake_img) # ae_loss_fake = torch.mean((latent_pred - latent_fake.detach()) ** 2) ae_loss_fake = 0 mixing_prob = 0 if args.which_latent == 'w_tied' else args.mixing if args.lambda_rec_w > 0: noise = mixing_noise(args.batch, args.latent, mixing_prob, device) fake_img, latent_fake = generator( noise, return_latents=True, detach_style=not args.no_detach_style) if args.which_latent == 'w_tied': latent_fake = latent_fake[:, 0, :] else: latent_fake = latent_fake.view(args.batch, -1) latent_pred, _ = encoder(fake_img) ae_loss_fake = torch.mean( (latent_pred - latent_fake.detach())**2) if args.lambda_pix_fake + args.lambda_vgg_fake + args.lambda_adv_fake > 0: pix_loss = vgg_loss = adv_loss = torch.tensor( 0., device=device) noise = mixing_noise(args.batch, args.latent, mixing_prob, device) fake_img, _ = generator(noise, detach_style=False) fake_img = fake_img.detach() latent_pred, _ = encoder(fake_img) rec_img, _ = generator([latent_pred], input_is_latent=input_is_latent) if args.lambda_pix_fake > 0: if args.pix_loss == 'l2': pix_loss = torch.mean((rec_img - fake_img)**2) elif args.pix_loss == 'l1': pix_loss = F.l1_loss(rec_img, fake_img) if args.lambda_vgg_fake > 0: vgg_loss = torch.mean( (vggnet(fake_img) - vggnet(rec_img))**2) ae_loss_fake = (ae_loss_fake + pix_loss * args.lambda_pix_fake + vgg_loss * args.lambda_vgg_fake) loss_dict["ae_fake"] = ae_loss_fake if joint: encoder.zero_grad() generator.zero_grad() (ae_loss_fake * args.lambda_rec_w).backward() e_optim.step() if args.g_decay is not None: scale_grad(generator, g_scale) # Do NOT update F (or generator.style). Grad should be zero when style # is detached in generator, but we explicitly zero it, just in case. 
if not args.no_detach_style: generator.style.zero_grad() g_optim.step() else: encoder.zero_grad() (ae_loss_fake * args.lambda_rec_w).backward() e_optim.step() # Train AE on real samples (image reconstruction) if args.lambda_pix + args.lambda_vgg + args.lambda_adv > 0: requires_grad(encoder, True) requires_grad(generator, joint) requires_grad(discriminator, False) requires_grad(discriminator2, False) pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device) for step_index in range(args.n_step_e): real_img = real_imgs[step_index] latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=input_is_latent) if args.lambda_pix > 0: if args.pix_loss == 'l2': pix_loss = torch.mean((rec_img - real_img)**2) elif args.pix_loss == 'l1': pix_loss = F.l1_loss(rec_img, real_img) if args.lambda_vgg > 0: vgg_loss = torch.mean( (vggnet(real_img) - vggnet(rec_img))**2) if args.lambda_adv > 0: if args.augment: rec_img_aug, _ = augment(rec_img, ada_aug_p) else: rec_img_aug = rec_img rec_pred = discriminator(encoder(rec_img_aug)[0]) adv_loss = g_nonsaturating_loss(rec_pred) if args.use_adaptive_weight and i >= args.disc_iter_start: nll_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg g_loss = adv_loss * args.lambda_adv d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) ae_loss_real = (pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + d_weight * adv_loss * args.lambda_adv) loss_dict["ae_real"] = ae_loss_real loss_dict["pix"] = pix_loss loss_dict["vgg"] = vgg_loss loss_dict["adv"] = adv_loss if joint: encoder.zero_grad() generator.zero_grad() ae_loss_real.backward() e_optim.step() if args.g_decay is not None: scale_grad(generator, g_scale) g_optim.step() else: encoder.zero_grad() ae_loss_real.backward() e_optim.step() if args.g_decay is not None: g_scale *= args.g_decay # Update EMA ema_nimg = args.ema_kimg * 1000 if args.ema_rampup is not None: ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup) accum = 0.5**(args.batch / max(ema_nimg, 1e-8)) accumulate(g_ema, g_module, 0 if args.no_ema_g else accum) accumulate(e_ema, e_module, 0 if args.no_ema_e else accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() ae_real_val = loss_reduced["ae_real"].mean().item() ae_fake_val = loss_reduced["ae_fake"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() recx_score_val = loss_reduced["recx_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() pix_loss_val = loss_reduced["pix"].mean().item() vgg_loss_val = loss_reduced["vgg"].mean().item() avg_pix_loss.update(pix_loss_val, real_img.shape[0]) avg_vgg_loss.update(vgg_loss_val, real_img.shape[0]) if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; " f"ae_fake: {ae_fake_val:.4f}; ae_real: {ae_real_val:.4f}; " f"g: {g_loss_val:.4f}; path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}; " f"d_weight: {d_weight.item():.4f}; ")) if i % args.log_every == 0: with torch.no_grad(): g_ema.eval() e_ema.eval() nrow = int(args.n_sample**0.5) nchw = list(sample_x.shape)[1:] # Reconstruction of real images latent_x, _ = e_ema(sample_x) rec_real, _ = g_ema([latent_x], input_is_latent=input_is_latent) sample = torch.cat( 
(sample_x.reshape(args.n_sample // nrow, nrow, *nchw), rec_real.reshape(args.n_sample // nrow, nrow, *nchw)), 1) utils.save_image( sample.reshape(2 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) ref_pix_loss = torch.sum(torch.abs(sample_x - rec_real)) ref_vgg_loss = torch.mean( (vggnet(sample_x) - vggnet(rec_real))**2) if vggnet is not None else 0 # Fixed fake samples and reconstructions sample_gz, _ = g_ema([sample_z]) latent_gz, _ = e_ema(sample_gz) rec_fake, _ = g_ema([latent_gz], input_is_latent=input_is_latent) sample = torch.cat( (sample_gz.reshape(args.n_sample // nrow, nrow, *nchw), rec_fake.reshape(args.n_sample // nrow, nrow, *nchw)), 1) utils.save_image( sample.reshape(2 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write(( f"{i:07d}; " f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; " f"ae_fake: {ae_fake_val:.4f}; ae_real: {ae_real_val:.4f}; " f"g: {g_loss_val:.4f}; path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; " f"real_score: {real_score_val:.4f}; fake_score: {fake_score_val:.4f}; recx_score: {recx_score_val:.4f}; " f"pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; " f"ref_pix: {ref_pix_loss.item():.4f}; ref_vgg: {ref_vgg_loss.item():.4f}; " f"d_weight: {d_weight.item():.4f}; " f"\n")) if wandb and args.wandb: wandb.log({ "Generator": g_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, }) if args.eval_every > 0 and i % args.eval_every == 0: with torch.no_grad(): fid_sa = fid_re = fid_sr = 0 g_ema.eval() e_ema.eval() if args.truncation < 1: mean_latent = g_ema.mean_latent(4096) # Sample FID if 'fid_sample' in args.which_metric: features = extract_feature_from_samples( g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, args.device).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid_sa = calc_fid(sample_mean, sample_cov, real_mean, real_cov) # Sample reconstruction FID if 'fid_sample_recon' in args.which_metric: features = extract_feature_from_samples( g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, args.device, mode='recon', encoder=e_ema, input_is_latent=input_is_latent, ).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid_sr = calc_fid(sample_mean, sample_cov, real_mean, real_cov) # Real reconstruction FID if 'fid_recon' in args.which_metric: features = extract_feature_from_reconstruction( e_ema, g_ema, inception, args.truncation, mean_latent, loader2, args.device, input_is_latent=input_is_latent, mode='recon', ).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid_re = calc_fid(sample_mean, sample_cov, real_mean, real_cov) with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write( f"{i:07d}; sample: {float(fid_sa):.4f}; rec_fake: {float(fid_sr):.4f}; rec_real: {float(fid_re):.4f};\n" ) if i % args.save_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "d2": 
d2_module.state_dict() if args.decouple_d else None, "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "d2_optim": d2_optim.state_dict() if args.decouple_d else None, "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), ) if i % args.save_latest_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "d2": d2_module.state_dict() if args.decouple_d else None, "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "d2_optim": d2_optim.state_dict() if args.decouple_d else None, "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"latest.pt"), )
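# `scale_grad` (used together with args.g_decay above) is not defined in this
# file. One plausible reading, given that g_scale shrinks geometrically and is
# applied right before g_optim.step(): scale the generator's accumulated
# gradients in place.


def scale_grad(model, scale):
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scale)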
def train(args, loader, encoder, generator, discriminator, discriminator_w, vggnet, pwcnet, e_optim, g_optim, g1_optim, d_optim, dw_optim, e_ema, g_ema, device): loader = sample_data(loader) args.toggle_grads = True args.augment = False pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 e_loss_val = 0 rec_loss_val = 0 vgg_loss_val = 0 adv_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) loss_dict = {"d": torch.tensor(0., device=device), "real_score": torch.tensor(0., device=device), "fake_score": torch.tensor(0., device=device), "hybrid_score": torch.tensor(0., device=device), "r1_d": torch.tensor(0., device=device), "rec": torch.tensor(0., device=device),} avg_pix_loss = util.AverageMeter() avg_vgg_loss = util.AverageMeter() if args.distributed: e_module = encoder.module d_module = discriminator.module g_module = generator.module else: e_module = encoder d_module = discriminator g_module = generator accum = 0.5 ** (32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device) # sample_x = accumulate_batches(loader, args.n_sample).to(device) sample_x = load_real_samples(args, loader) sample_x1 = sample_x[:,0,...] sample_x2 = sample_x[:,-1,...] sample_idx = torch.randperm(args.n_sample) sample_z = torch.randn(args.n_sample, args.latent, device=device) for idx in pbar: i = idx + args.start_iter if get_rank() == 0: if i % args.log_every == 0: with torch.no_grad(): e_eval = e_ema e_eval.eval() g_ema.eval() nrow = int(args.n_sample ** 0.5) nchw = list(sample_x1.shape)[1:] # Recon latent_real, _ = e_eval(sample_x1) fake_img, _ = g_ema([latent_real], input_is_latent=True, return_latents=False) sample = torch.cat((sample_x1.reshape(args.n_sample//nrow, nrow, *nchw), fake_img.reshape(args.n_sample//nrow, nrow, *nchw)), 1) utils.save_image( sample.reshape(2*args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) # Cross w1, _ = e_eval(sample_x1) w2, _ = e_eval(sample_x2) delta_w = w2 - w1 delta_w = delta_w[sample_idx,...] fake_img, _ = g_ema([w1 + delta_w], input_is_latent=True, return_latents=False) sample = torch.cat((sample_x2.reshape(args.n_sample//nrow, nrow, *nchw), fake_img.reshape(args.n_sample//nrow, nrow, *nchw)), 1) utils.save_image( sample.reshape(2*args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-cross.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) # Sample sample, _ = g_ema([sample_z]) utils.save_image( sample, os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) e_eval.train() if i > args.iter: print("Done!") break frames = next(loader) # [N, T, C, H, W] batch = frames.shape[0] frames1 = frames[:,0,...] selected_indices = torch.sort(torch.multinomial(torch.ones(batch, args.nframe_num-1), 2)+1, 1)[0] frames2 = frames[range(batch),selected_indices[:,0],...] frames3 = frames[range(batch),selected_indices[:,1],...] 
frames1 = frames1.to(device) frames2 = frames2.to(device) frames3 = frames3.to(device)
# Train Discriminator
if args.toggle_grads: requires_grad(encoder, False) requires_grad(generator, False) requires_grad(discriminator, True) requires_grad(discriminator_w, True)
real_img, fake_img, _, _, _ = cross_reconstruction(encoder, generator, frames1, frames2, frames3, args.cond_disc)
if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img_aug, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_img_aug = fake_img
fake_pred = discriminator(fake_img_aug) real_pred = discriminator(real_img_aug) d_loss = d_logistic_loss(real_pred, fake_pred)
# Use separate names here so that fake_pred below still holds the [recon, hybrid] scores from cross_reconstruction when it is chunked for logging.
if args.lambda_gan > 0: noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img_g, _ = generator(noise) if args.augment: fake_img_g, _ = augment(fake_img_g, ada_aug_p) fake_pred_g = discriminator(fake_img_g) fake_loss = F.softplus(fake_pred_g) d_loss += fake_loss.mean() * args.lambda_gan
loss_dict["d"] = d_loss loss_dict["real_score"] = real_pred.mean()
# loss_dict["fake_score"] = fake_pred.mean()
fake_pred1, fake_pred2 = fake_pred.chunk(2, dim=0) loss_dict["fake_score"] = fake_pred1.mean() loss_dict["hybrid_score"] = fake_pred2.mean()
discriminator.zero_grad() d_loss.backward() d_optim.step()
if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat
d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
if d_regularize:
# why not regularize on augmented real?
real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss_d = d_r1_loss(real_pred, real_img) d_optim.zero_grad() (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward()
# Why 0* ? Answer is here https://github.com/rosinality/stylegan2-pytorch/issues/76
d_optim.step() loss_dict["r1_d"] = r1_loss_d
# Train Discriminator_W
if args.learned_prior and args.lambda_gan_w > 0: noise = mixing_noise(args.batch, args.latent, 0, device) fake_w = generator.get_latent(noise[0]) real_w, _ = encoder(frames1) fake_pred = discriminator_w(fake_w) real_pred = discriminator_w(real_w) d_loss_w = d_logistic_loss(real_pred, fake_pred) dw_optim.zero_grad() d_loss_w.backward() dw_optim.step()
# Train Encoder and Generator
if args.toggle_grads: requires_grad(encoder, True) requires_grad(generator, True) requires_grad(discriminator, False) requires_grad(discriminator_w, False)
pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0., device=device)
_, fake_img, x_real, x_recon, x_cross = cross_reconstruction(encoder, generator, frames1, frames2, frames3, args.cond_disc)
if args.lambda_adv > 0: if args.augment: fake_img_aug, _ = augment(fake_img, ada_aug_p) else: fake_img_aug = fake_img fake_pred = discriminator(fake_img_aug) adv_loss = g_nonsaturating_loss(fake_pred)
if args.lambda_pix > 0: if args.pix_loss == 'l2': pix_loss = torch.mean((x_recon - x_real) ** 2) else: pix_loss = F.l1_loss(x_recon, x_real)
if args.lambda_vgg > 0: real_feat = vggnet(x_real) fake_feat = vggnet(x_recon) if not args.vgg_on_cross else vggnet(x_cross) vgg_loss = torch.mean((fake_feat - real_feat) ** 2)
e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
if args.lambda_gan > 0 and not args.no_sim_opt: noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: fake_img, _ = augment(fake_img, ada_aug_p) fake_pred = discriminator(fake_img) g_loss = g_nonsaturating_loss(fake_pred) e_loss += g_loss * args.lambda_gan
loss_dict["e"] = e_loss loss_dict["pix"] = pix_loss loss_dict["vgg"] = vgg_loss loss_dict["adv"] = adv_loss e_optim.zero_grad() g_optim.zero_grad() e_loss.backward() e_optim.step() g_optim.step() if args.learned_prior: g_loss_w = 0. if args.lambda_gan_w > 0: noise = mixing_noise(args.batch, args.latent, 0, device) fake_w = generator.get_latent(noise[0]) fake_pred = discriminator_w(fake_w) g_loss_w += g_nonsaturating_loss(fake_pred) * args.lambda_gan_w if args.lambda_adv_w > 0: noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) fake_pred = discriminator(fake_img) g_loss_w += g_nonsaturating_loss(fake_pred) * args.lambda_adv_w g1_optim.zero_grad() g_loss_w.backward() g1_optim.step() if args.lambda_gan > 0 and args.no_sim_opt: noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: fake_img, _ = augment(fake_img, ada_aug_p) fake_pred = discriminator(fake_img) g_loss = g_nonsaturating_loss(fake_pred) * args.lambda_gan generator.zero_grad() g_loss.backward() g_optim.step() g_regularize = args.lambda_gan > 0 and args.g_reg_every > 0 and i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length ) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() # mean_path_length_avg = ( # reduce_sum(mean_path_length).item() / get_world_size() # ) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(e_ema, e_module, accum) accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() e_loss_val = loss_reduced["e"].mean().item() r1_d_val = loss_reduced["r1_d"].mean().item() pix_loss_val = loss_reduced["pix"].mean().item() vgg_loss_val = loss_reduced["vgg"].mean().item() adv_loss_val = loss_reduced["adv"].mean().item() rec_loss_val = loss_reduced["rec"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() hybrid_score_val = loss_reduced["hybrid_score"].mean().item() path_loss_val = loss_reduced["path"].mean().item() # path_length_val = loss_reduced["path_length"].mean().item() avg_pix_loss.update(pix_loss_val, real_img.shape[0]) avg_vgg_loss.update(vgg_loss_val, real_img.shape[0]) if get_rank() == 0: pbar.set_description( ( f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; " f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; " f"path: {path_loss_val:.4f}; augment: {ada_aug_p:.4f}" ) ) if i % args.log_every == 0: with torch.no_grad(): latent_x, _ = e_ema(sample_x1) fake_x, _ = generator([latent_x], input_is_latent=True, return_latents=False) sample_pix_loss = torch.sum((sample_x1 - fake_x) ** 2) with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write(f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; " f"ref: {sample_pix_loss.item()};\n") if wandb and args.wandb: wandb.log( { "Encoder": e_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1 D": r1_d_val, "Pix Loss": pix_loss_val, "VGG Loss": vgg_loss_val, "Adv Loss": 
adv_loss_val, "Rec Loss": rec_loss_val, "Real Score": real_score_val, "Fake Score": fake_score_val, "Hybrid Score": hybrid_score_val, } )
if i % args.save_every == 0: e_eval = e_ema
# Save the EMA generator under "g_ema" (was g_module.state_dict(), which just duplicated "g").
torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g": g_module.state_dict(), "g_ema": g_ema.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "g_optim": g_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), )
if i % args.save_latest_every == 0: torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g": g_module.state_dict(), "g_ema": g_ema.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "g_optim": g_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', "latest.pt"), )
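# --- editor's note -------------------------------------------------------
# Every train() in this file scores real/fake/reconstructed images with
# d_logistic_loss / g_nonsaturating_loss. A minimal sketch of both, assuming
# the rosinality stylegan2-pytorch definitions (the `_sketch_` prefix marks
# these as illustrative references, not the imports actually used):
import torch.nn.functional as F

def _sketch_d_logistic_loss(real_pred, fake_pred):
    # softplus(-real) + softplus(fake) = -log sigmoid(real) - log(1 - sigmoid(fake))
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()

def _sketch_g_nonsaturating_loss(fake_pred):
    # non-saturating generator loss: -log sigmoid(D(G(z)))
    return F.softplus(-fake_pred).mean()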
def train(args, loader, loader2, encoder, generator, discriminator, vggnet, pwcnet, e_optim, d_optim, e_ema, pca_state, device): inception = real_mean = real_cov = mean_latent = None if args.eval_every > 0: inception = nn.DataParallel(load_patched_inception_v3()).to(device) inception.eval() with open(args.inception, "rb") as f: embeds = pickle.load(f) real_mean = embeds["mean"] real_cov = embeds["cov"] if get_rank() == 0: if args.eval_every > 0: with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") if args.log_every > 0: with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) d_loss_val = 0 e_loss_val = 0 rec_loss_val = 0 vgg_loss_val = 0 adv_loss_val = 0 loss_dict = { "d": torch.tensor(0., device=device), "real_score": torch.tensor(0., device=device), "fake_score": torch.tensor(0., device=device), "r1_d": torch.tensor(0., device=device), "r1_e": torch.tensor(0., device=device), "rec": torch.tensor(0., device=device), } avg_pix_loss = util.AverageMeter() avg_vgg_loss = util.AverageMeter() if args.distributed: e_module = encoder.module d_module = discriminator.module g_module = generator.module else: e_module = encoder d_module = discriminator g_module = generator # accum = 0.5 ** (32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device) # sample_x = accumulate_batches(loader, args.n_sample).to(device) sample_x = load_real_samples(args, loader) if sample_x.ndim > 4: sample_x = sample_x[:, 0, ...] input_is_latent = args.latent_space != 'z' # Encode in z space? 
requires_grad(generator, False) # always False generator.eval() # Generator should be ema and in eval mode g_ema = generator # if args.no_ema or e_ema is None: # e_ema = encoder for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break real_img = next(loader) real_img = real_img.to(device) # Train Encoder if args.toggle_grads: requires_grad(encoder, True) requires_grad(discriminator, False) pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0., device=device) latent_real, _ = encoder(real_img) fake_img, _ = generator([latent_real], input_is_latent=input_is_latent) if args.lambda_adv > 0: if args.augment: fake_img_aug, _ = augment(fake_img, ada_aug_p) else: fake_img_aug = fake_img fake_pred = discriminator(fake_img_aug) adv_loss = g_nonsaturating_loss(fake_pred) if args.lambda_pix > 0: if args.pix_loss == 'l2': pix_loss = torch.mean((fake_img - real_img)**2) else: pix_loss = F.l1_loss(fake_img, real_img) if args.lambda_vgg > 0: real_feat = vggnet(real_img) fake_feat = vggnet(fake_img) vgg_loss = torch.mean((real_feat - fake_feat)**2) e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv loss_dict["e"] = e_loss loss_dict["pix"] = pix_loss loss_dict["vgg"] = vgg_loss loss_dict["adv"] = adv_loss encoder.zero_grad() e_loss.backward() e_optim.step() if args.train_on_fake: e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0 if e_regularize and args.lambda_rec > 0: noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, latent_fake = generator( noise, input_is_latent=input_is_latent, return_latents=True) latent_pred, _ = encoder(fake_img) if latent_pred.ndim < 3: latent_pred = latent_pred.unsqueeze(1).repeat( 1, latent_fake.size(1), 1) rec_loss = torch.mean((latent_fake - latent_pred)**2) encoder.zero_grad() (rec_loss * args.lambda_rec).backward() e_optim.step() loss_dict["rec"] = rec_loss # e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0 # if e_regularize: # # why not regularize on augmented real? # real_img.requires_grad = True # real_pred, _ = encoder(real_img) # r1_loss_e = d_r1_loss(real_pred, real_img) # encoder.zero_grad() # (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward() # e_optim.step() # loss_dict["r1_e"] = r1_loss_e if not args.no_ema and e_ema is not None: ema_nimg = args.ema_kimg * 1000 if args.ema_rampup is not None: ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup) accum = 0.5**(args.batch / max(ema_nimg, 1e-8)) accumulate(e_ema, e_module, accum) # Train Discriminator if args.toggle_grads: requires_grad(encoder, False) requires_grad(discriminator, True) if not args.no_update_discriminator and args.lambda_adv > 0: latent_real, _ = encoder(real_img) fake_img, _ = generator([latent_real], input_is_latent=input_is_latent) if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img_aug, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_img_aug = fake_img fake_pred = discriminator(fake_img_aug) real_pred = discriminator(real_img_aug) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["d"] = d_loss loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0 if d_regularize: # why not regularize on augmented real? 
real_img.requires_grad = True real_pred = discriminator(real_img) r1_loss_d = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward()
# Why 0* ? Answer is here https://github.com/rosinality/stylegan2-pytorch/issues/76
d_optim.step() loss_dict["r1_d"] = r1_loss_d
loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() e_loss_val = loss_reduced["e"].mean().item() r1_d_val = loss_reduced["r1_d"].mean().item() r1_e_val = loss_reduced["r1_e"].mean().item() pix_loss_val = loss_reduced["pix"].mean().item() vgg_loss_val = loss_reduced["vgg"].mean().item() adv_loss_val = loss_reduced["adv"].mean().item() rec_loss_val = loss_reduced["rec"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() avg_pix_loss.update(pix_loss_val, real_img.shape[0]) avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])
if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; " f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; " f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"))
if i % args.log_every == 0: with torch.no_grad(): latent_x, _ = e_ema(sample_x) fake_x, _ = g_ema([latent_x], input_is_latent=input_is_latent) sample_pix_loss = torch.sum((sample_x - fake_x)**2) with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write( f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; " f"ref: {sample_pix_loss.item()};\n")
if args.eval_every > 0 and i % args.eval_every == 0: with torch.no_grad(): g_ema.eval() e_ema.eval()
# Recon
features = extract_feature_from_reconstruction( e_ema, g_ema, inception, args.truncation, mean_latent, loader2, args.device, input_is_latent=input_is_latent, mode='recon', ).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid_re = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
# print("Recon FID:", fid_re)
with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"{i:07d}; recon fid: {float(fid_re):.4f};\n")
if wandb and args.wandb: wandb.log({ "Encoder": e_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1 D": r1_d_val, "R1 E": r1_e_val, "Pix Loss": pix_loss_val, "VGG Loss": vgg_loss_val, "Adv Loss": adv_loss_val, "Rec Loss": rec_loss_val, "Real Score": real_score_val, "Fake Score": fake_score_val, })
if i % args.log_every == 0: with torch.no_grad(): e_eval = encoder if args.no_ema else e_ema e_eval.eval() nrow = int(args.n_sample**0.5) nchw = list(sample_x.shape)[1:] latent_real, _ = e_eval(sample_x) fake_img, _ = generator([latent_real], input_is_latent=input_is_latent) sample = torch.cat( (sample_x.reshape(args.n_sample // nrow, nrow, *nchw), fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1) utils.save_image( sample.reshape(2 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) e_eval.train()
if i % args.save_every == 0: e_eval = encoder if args.no_ema else e_ema torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_module.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), )
if i % args.save_latest_every == 0: torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_module.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', "latest.pt"), )
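# --- editor's note -------------------------------------------------------
# The eval branch above reduces Inception features to (mean, cov) and calls
# calc_fid. A sketch of the Frechet distance it presumably computes,
# FID = ||mu_s - mu_r||^2 + Tr(C_s + C_r - 2 (C_s C_r)^(1/2)), assuming a
# scipy-based implementation (the repo's calc_fid may handle eps differently):
import numpy as np
from scipy import linalg

def _sketch_calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
    if not np.isfinite(cov_sqrt).all():
        # regularize near-singular covariances and retry the matrix sqrt
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real
    mean_diff = sample_mean - real_mean
    return mean_diff @ mean_diff + np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)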
def train(args, loader, loader2, generator, encoder, discriminator, discriminator2, vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema, e_ema, device): inception = real_mean = real_cov = mean_latent = None if args.eval_every > 0: inception = nn.DataParallel(load_patched_inception_v3()).to(device) inception.eval() with open(args.inception, "rb") as f: embeds = pickle.load(f) real_mean = embeds["mean"] real_cov = embeds["cov"] if get_rank() == 0: if args.eval_every > 0: with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") if args.log_every > 0: with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} avg_pix_loss = util.AverageMeter() avg_vgg_loss = util.AverageMeter() if args.distributed: g_module = generator.module e_module = encoder.module d_module = discriminator.module else: g_module = generator e_module = encoder d_module = discriminator d2_module = None if discriminator2 is not None: if args.distributed: d2_module = discriminator2.module else: d2_module = discriminator2 # accum = 0.5 ** (32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 r_t_dict = {'real': 0, 'fake': 0, 'recx': 0} # r_t stat real_diff = fake_diff = count = 0 g_scale = 1 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device) if args.decouple_d and args.augment: ada_aug_p2 = args.augment_p if args.augment_p > 0 else 0.0 # r_t_stat2 = 0 if args.augment_p == 0: ada_augment2 = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device) sample_z = torch.randn(args.n_sample, args.latent, device=device) sample_x = load_real_samples(args, loader) sample_x1 = sample_x2 = sample_idx = fid_batch_idx = None if sample_x.ndim > 4: sample_x1 = sample_x[:, 0, ...] sample_x2 = sample_x[:, -1, ...] sample_x = sample_x[:, 0, ...] n_step_max = max(args.n_step_d, args.n_step_e) requires_grad(g_ema, False) requires_grad(e_ema, False) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break real_imgs = [next(loader).to(device) for _ in range(n_step_max)] # Train Discriminator requires_grad(generator, False) requires_grad(encoder, False) requires_grad(discriminator, True) for step_index in range(args.n_step_d): real_img = real_imgs[step_index] noise = mixing_noise(args.batch, args.latent, args.mixing, device) if args.use_ema: g_ema.eval() fake_img, _ = g_ema(noise) else: fake_img, _ = generator(noise) if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_pred = discriminator(fake_img) real_pred = discriminator(real_img_aug) d_loss_fake = F.softplus(fake_pred).mean() d_loss_real = F.softplus(-real_pred).mean() loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() d_loss_rec = 0. 
if args.lambda_rec_d > 0 and not args.decouple_d:
# Do not train D on x_rec if decouple_d
if args.use_ema: e_ema.eval() g_ema.eval() latent_real, _ = e_ema(real_img) rec_img, _ = g_ema([latent_real], input_is_latent=True) else: latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True)
if args.augment: rec_img, _ = augment(rec_img, ada_aug_p)
rec_pred = discriminator(rec_img) d_loss_rec = F.softplus(rec_pred).mean() loss_dict["recx_score"] = rec_pred.mean()
d_loss = d_loss_real + d_loss_fake * args.lambda_fake_d + d_loss_rec * args.lambda_rec_d loss_dict["d"] = d_loss
discriminator.zero_grad() d_loss.backward() d_optim.step()
if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat
# Compute batchwise r_t
r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch r_t_dict['fake'] = torch.sign(fake_pred).sum().item() / args.batch
# NOTE: the diagnostics below assume lambda_rec_d > 0 with decouple_d off, so that rec_pred is defined above; guard them if other configs are used.
with torch.no_grad(): real_diff += torch.mean(real_pred - rec_pred).item() noise = mixing_noise(args.batch, args.latent, args.mixing, device) x_fake, _ = generator(noise) x_recf, _ = generator([encoder(x_fake)[0]], input_is_latent=True) recf_pred = discriminator(x_recf) fake_pred = discriminator(x_fake) fake_diff += torch.mean(fake_pred - recf_pred).item() count += 1
d_regularize = i % args.d_reg_every == 0
if d_regularize: real_img.requires_grad = True if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) else: real_img_aug = real_img real_pred = discriminator(real_img_aug) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() loss_dict["r1"] = r1_loss
# Train Discriminator2
if args.decouple_d and discriminator2 is not None: requires_grad(generator, False) requires_grad(encoder, False) requires_grad(discriminator2, True)
for step_index in range(args.n_step_e):
# n_step_d2 is same as n_step_e
real_img = real_imgs[step_index]
if args.use_ema: e_ema.eval() g_ema.eval() latent_real, _ = e_ema(real_img) rec_img, _ = g_ema([latent_real], input_is_latent=True) else: latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=True)
if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p2) rec_img, _ = augment(rec_img, ada_aug_p2) else: real_img_aug = real_img
rec_pred = discriminator2(rec_img) real_pred = discriminator2(real_img_aug) d2_loss_rec = F.softplus(rec_pred).mean() d2_loss_real = F.softplus(-real_pred).mean() d2_loss = d2_loss_real + d2_loss_rec loss_dict["d2"] = d2_loss loss_dict["recx_score"] = rec_pred.mean()
discriminator2.zero_grad() d2_loss.backward() d2_optim.step() real_diff += torch.mean(real_pred - rec_pred).item()
d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
if d_regularize: real_img.requires_grad = True real_pred = discriminator2(real_img) r1_loss = d_r1_loss(real_pred, real_img) discriminator2.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d2_optim.step()
if args.augment and args.augment_p == 0: ada_aug_p2 = ada_augment2.tune(rec_pred)
# r_t_stat2 = ada_augment2.r_t_stat
r_t_dict['recx'] = torch.sign(rec_pred).sum().item() / args.batch
# Train Encoder
# (assumes discriminator2 is a module even when decouple_d is off; pass a stub if it can be None)
requires_grad(encoder, True) requires_grad(generator, args.train_ge) requires_grad(discriminator, False) requires_grad(discriminator2, False)
pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
for step_index in range(args.n_step_e): real_img = real_imgs[step_index] latent_real, _ = encoder(real_img)
if args.use_ema: g_ema.eval() rec_img, _ = g_ema([latent_real], input_is_latent=True) else: rec_img, _ = generator([latent_real], input_is_latent=True)
if args.lambda_pix > 0: if args.pix_loss == 'l2': pix_loss = torch.mean((rec_img - real_img)**2) elif args.pix_loss == 'l1': pix_loss = F.l1_loss(rec_img, real_img) else: raise NotImplementedError
if args.lambda_vgg > 0: vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)
if args.lambda_adv > 0: if not args.decouple_d: if args.augment: rec_img_aug, _ = augment(rec_img, ada_aug_p) else: rec_img_aug = rec_img rec_pred = discriminator(rec_img_aug) else: if args.augment: rec_img_aug, _ = augment(rec_img, ada_aug_p2) else: rec_img_aug = rec_img rec_pred = discriminator2(rec_img_aug) adv_loss = g_nonsaturating_loss(rec_pred)
e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv loss_dict["e"] = e_loss loss_dict["pix"] = pix_loss loss_dict["vgg"] = vgg_loss loss_dict["adv"] = adv_loss
if args.train_ge: encoder.zero_grad() generator.zero_grad() e_loss.backward() e_optim.step() if args.g_decay < 1: manually_scale_grad(generator, g_scale) g_scale *= args.g_decay g_optim.step() else: encoder.zero_grad() e_loss.backward() e_optim.step()
# Train Generator
requires_grad(generator, True) requires_grad(encoder, False) requires_grad(discriminator, False) requires_grad(discriminator2, False)
real_img = real_imgs[0] noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise)
if args.augment: fake_img, _ = augment(fake_img, ada_aug_p)
fake_pred = discriminator(fake_img) g_loss_fake = g_nonsaturating_loss(fake_pred) loss_dict["g"] = g_loss_fake
generator.zero_grad() g_loss_fake.backward() g_optim.step()
g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())
loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean()
# Update EMA
ema_nimg = args.ema_kimg * 1000
if args.ema_rampup is not None: ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
accum = 0.5**(args.batch / max(ema_nimg, 1e-8)) accumulate(e_ema, e_module, accum) accumulate(g_ema, g_module, accum)
loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() recx_score_val = loss_reduced["recx_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() pix_loss_val = loss_reduced["pix"].mean().item() vgg_loss_val = loss_reduced["vgg"].mean().item() adv_loss_val = loss_reduced["adv"].mean().item() avg_pix_loss.update(pix_loss_val, real_img.shape[0]) avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])
if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}; " f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}" ))
if i % args.log_every == 0: with torch.no_grad(): latent_x, _ = e_ema(sample_x) fake_x, _ = generator([latent_x], input_is_latent=True, return_latents=False) sample_pix_loss = torch.sum((sample_x - fake_x)**2) with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write(( f"{i:07d}; pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; ref: {sample_pix_loss.item():.4f}; " f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; " f"real_score: {real_score_val:.4f}; fake_score: {fake_score_val:.4f}; recx_score: {recx_score_val:.4f}; " f"real_diff: {real_diff/count:.4f}; fake_diff: {fake_diff/count:.4f};\n" )) real_diff = fake_diff = count = 0
if args.eval_every > 0 and i % args.eval_every == 0: with torch.no_grad(): fid_sa = fid_re = fid_hy = 0
# Sample FID
g_ema.eval() if args.truncation < 1: mean_latent = g_ema.mean_latent(4096) features = extract_feature_from_samples( g_ema, inception, args.truncation, mean_latent, 64, args.n_sample_fid, args.device).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid_sa = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
# Recon FID
features = extract_feature_from_reconstruction( e_ema, g_ema, inception, args.truncation, mean_latent, loader2, args.device, mode='recon', ).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid_re = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
# Hybrid FID
if args.eval_hybrid: features = extract_feature_from_reconstruction( e_ema, g_ema, inception, args.truncation, mean_latent, loader2, args.device, mode='hybrid',
# shuffle_idx=fid_batch_idx
).numpy() sample_mean = np.mean(features, 0) sample_cov = np.cov(features, rowvar=False) fid_hy = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
# print("Sample FID:", fid_sa, "Recon FID:", fid_re, "Hybrid FID:", fid_hy)
with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write( f"{i:07d}; sample fid: {float(fid_sa):.4f}; recon fid: {float(fid_re):.4f}; hybrid fid: {float(fid_hy):.4f};\n" )
if wandb and args.wandb: wandb.log({ "Generator": g_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, })
if i % args.log_every == 0: with torch.no_grad(): g_ema.eval() e_ema.eval() nrow = int(args.n_sample**0.5) nchw = list(sample_x.shape)[1:]
# Fixed fake samples
sample, _ = g_ema([sample_z]) utils.save_image( sample, os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-sample.png"), nrow=int(args.n_sample**0.5), normalize=True, value_range=(-1, 1), )
# Reconstruction samples
latent_real, _ = e_ema(sample_x) fake_img, _ = g_ema([latent_real], input_is_latent=True, return_latents=False) sample = torch.cat( (sample_x.reshape(args.n_sample // nrow, nrow, *nchw), fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1) utils.save_image( sample.reshape(2 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-recon.png"), nrow=nrow, normalize=True,
value_range=(-1, 1), ) # Hybrid samples: [real_y1, real_y2; real_x1, fake_x2] if args.eval_hybrid: w1, _ = e_ema(sample_x1) w2, _ = e_ema(sample_x2) dw = w2 - w1 dw = torch.cat( dw.chunk(2, 0)[::-1], 0) if sample_idx is None else dw[sample_idx, ...] fake_img, _ = g_ema([w1 + dw], input_is_latent=True, return_latents=False) drive = torch.cat( (torch.cat(sample_x1.chunk(2, 0)[::-1], 0).reshape( args.n_sample, 1, *nchw), torch.cat(sample_x2.chunk(2, 0)[::-1], 0).reshape( args.n_sample, 1, *nchw)), 1) source = torch.cat( (sample_x1.reshape(args.n_sample, 1, *nchw), fake_img.reshape(args.n_sample, 1, *nchw)), 1) sample = torch.cat( (drive.reshape(args.n_sample // nrow, 2 * nrow, * nchw), source.reshape(args.n_sample // nrow, 2 * nrow, * nchw)), 1) utils.save_image( sample.reshape(4 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}-cross.png"), nrow=2 * nrow, normalize=True, value_range=(-1, 1), ) if i % args.save_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "d2": d2_module.state_dict() if args.decouple_d else None, "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "d2_optim": d2_optim.state_dict() if args.decouple_d else None, "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), ) if i % args.save_latest_every == 0: torch.save( { "g": g_module.state_dict(), "e": e_module.state_dict(), "d": d_module.state_dict(), "d2": d2_module.state_dict() if args.decouple_d else None, "g_ema": g_ema.state_dict(), "e_ema": e_ema.state_dict(), "g_optim": g_optim.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "d2_optim": d2_optim.state_dict() if args.decouple_d else None, "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"latest.pt"), )
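# --- editor's note -------------------------------------------------------
# g_path_regularize above is StyleGAN2's path-length regularizer. A reference
# sketch assuming the rosinality implementation; note that the callers' extra
# `0 * fake_img[0, 0, 0, 0]` term only ties otherwise-unused outputs into the
# graph when path_batch_shrink drops part of the batch:
import math
import torch
from torch import autograd

def _sketch_g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    # random projection, scaled so the expected perturbation norm is O(1)
    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
    grad, = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    # exponential moving average of the path-length target
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths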
def train(args, loader, generator, netsD, g_optim, rf_opt, info_opt, g_ema, device): loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if args.distributed: g_module = generator.module d_module1 = netsD[1].module d_module2 = netsD[2].module else: g_module = generator d_module1 = netsD[1] d_module2 = netsD[2] accum = 0.5**(32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device) sample_z = torch.randn(args.n_sample, args.latent, device=device) criterion_class = nn.CrossEntropyLoss() for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break real_img = next(loader) real_img = real_img.to(device) ############# train child discriminator ############# requires_grad(generator, False) requires_grad(netsD[2], True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) c_code = sample_c_code(args.batch, args.c_categories, device) # fake_img, _ = generator(noise, c_code) image_li, _, _ = generator(noise, c_code) fake_img = image_li[0] if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_pred = netsD[2](fake_img)[0] real_pred = netsD[2](real_img)[0] d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["d"] = d_loss loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() netsD[2].zero_grad() d_loss.backward() rf_opt[2].step() if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat d_regularize = i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True real_pred = netsD[2](real_img)[0] r1_loss = d_r1_loss(real_pred, real_img) netsD[2].zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() rf_opt[2].step() loss_dict["r1"] = r1_loss ############# train generator and info discriminator ############# requires_grad(generator, True) requires_grad(netsD[1], True) requires_grad(netsD[2], True) noise = mixing_noise(args.batch, args.latent, args.mixing, device) c_code = sample_c_code(args.batch, args.c_categories, device) image_li, code_li, _ = generator(noise, c_code) fake_img = image_li[0] mkd_images = image_li[2] masks = image_li[3] p_code = code_li[1] c_code = code_li[2] if args.augment: fake_img, _ = augment(fake_img, ada_aug_p) fake_pred = netsD[2](fake_img)[0] g_loss = g_nonsaturating_loss(fake_pred) loss_dict["g"] = g_loss pred_p = netsD[1](mkd_images[1])[1] p_info_loss = criterion_class( pred_p, torch.nonzero(p_code.long(), as_tuple=False)[:, 1]) pred_c = netsD[2](mkd_images[2])[1] c_info_loss = criterion_class( pred_c, torch.nonzero(c_code.long(), as_tuple=False)[:, 1]) loss_dict["p_info"] = p_info_loss loss_dict["c_info"] = c_info_loss binary_loss = binarization_loss(masks[1]) * 2e1 # oob_loss = torch.sum(bg_mk * ch_mk, dim=(-1,-2)).mean() * 1e-2 ms = masks[1].size() min_fg_cvg = 0.2 * ms[2] * ms[3] fg_cvg_loss = F.relu(min_fg_cvg - torch.sum(masks[1], dim=(-1, -2))).mean() * 1e-2 ms = masks[1].size() min_bg_cvg = 0.2 * ms[2] * ms[3] bg_cvg_loss = F.relu(min_bg_cvg - 
torch.sum( torch.ones_like(masks[1]) - masks[1], dim=(-1, -2))).mean() * 1e-2
loss_dict["bin"] = binary_loss loss_dict["cvg"] = fg_cvg_loss + bg_cvg_loss
generator_loss = g_loss + p_info_loss + c_info_loss + binary_loss + fg_cvg_loss + bg_cvg_loss
generator.zero_grad() netsD[1].zero_grad() netsD[2].zero_grad() generator_loss.backward() g_optim.step() info_opt[1].step() info_opt[2].step()
g_regularize = i % args.g_reg_every == 0
if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) c_code = sample_c_code(path_batch_size, args.c_categories, device) image_li, _, latents = generator(noise, c_code, return_latents=True) fake_img = image_li[0] path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())
loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(g_ema, g_module, accum)
loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() p_info_loss_val = loss_reduced["p_info"].mean().item() c_info_loss_val = loss_reduced["c_info"].mean().item() binary_loss_val = loss_reduced["bin"].mean().item() cvg_loss_val = loss_reduced["cvg"].mean().item() r1_val = loss_reduced["r1"].mean().item() path_loss_val = loss_reduced["path"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() path_length_val = loss_reduced["path_length"].mean().item()
if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " f"p_info: {p_info_loss_val:.4f}; c_info: {c_info_loss_val:.4f}; " f"bin: {binary_loss_val:.4f}; cvg: {cvg_loss_val:.4f}; " f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " f"augment: {ada_aug_p:.4f}"))
if wandb and args.wandb: wandb.log({ "Generator": g_loss_val, "Discriminator": d_loss_val, "p_info": p_info_loss_val, "c_info": c_info_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1": r1_val, "Path Length Regularization": path_loss_val, "Mean Path Length": mean_path_length, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length": path_length_val, })
# use value_range below for consistency with the rest of this file (the old `range` kwarg was removed from torchvision's save_image)
if i % 1000 == 0: with torch.no_grad(): g_ema.eval() c_code = sample_c_code(args.n_sample, args.c_categories, device) image_li, _, _ = g_ema([sample_z], c_code) utils.save_image( image_li[0], f"sample/{str(i).zfill(6)}_0.png", nrow=int(args.n_sample**0.5), normalize=True, value_range=(-1, 1), ) for j in range(3): utils.save_image( image_li[1][j], f"sample/{str(i).zfill(6)}_{str(1+j)}.png", nrow=int(args.n_sample**0.5), normalize=True, value_range=(-1, 1), ) for j in range(3): utils.save_image( image_li[2][j], f"sample/{str(i).zfill(6)}_{str(4+j)}.png", nrow=int(args.n_sample**0.5), normalize=True, value_range=(-1, 1), ) for j in range(2): utils.save_image( image_li[3][j], f"sample/{str(i).zfill(6)}_{str(7+j)}.png", nrow=int(args.n_sample**0.5), normalize=True, value_range=(0, 1), )
if i % 10000 == 0: torch.save( { "g": g_module.state_dict(),
# "d0": d_module0.state_dict(),
"d1": d_module1.state_dict(), "d2": d_module2.state_dict(), "g_ema": g_ema.state_dict(), "g_optim": g_optim.state_dict(),
"rf_optim2": rf_opt[2].state_dict(), "info_optim1": info_opt[1].state_dict(), "info_optim2": info_opt[2].state_dict(), "args": args, "ada_aug_p": ada_aug_p, }, f"checkpoint/{str(i).zfill(6)}.pt", )
def train(args, loader, encoder, generator, discriminator, vggnet, pwcnet, e_optim, d_optim, e_ema, device): loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) d_loss_val = 0 e_loss_val = 0 rec_loss_val = 0 vgg_loss_val = 0 adv_loss_val = 0 loss_dict = { "d": torch.tensor(0., device=device), "real_score": torch.tensor(0., device=device), "fake_score": torch.tensor(0., device=device), "r1_d": torch.tensor(0., device=device), "r1_e": torch.tensor(0., device=device), "rec": torch.tensor(0., device=device), } avg_pix_loss = util.AverageMeter() avg_vgg_loss = util.AverageMeter() if args.distributed: e_module = encoder.module d_module = discriminator.module g_module = generator.module else: e_module = encoder d_module = discriminator g_module = generator accum = 0.5**(32 / (10 * 1000)) ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device) sample_x = accumulate_batches(loader, args.n_sample).to(device) requires_grad(generator, False) # always False generator.eval() # Generator should be ema and in eval mode # if args.no_ema or e_ema is None: # e_ema = encoder for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break real_img1, real_img2 = next(loader) real_img1 = real_img1.to(device) real_img2 = real_img2.to(device) # Train Encoder if args.toggle_grads: requires_grad(encoder, True) requires_grad(discriminator, False) pix_loss = vgg_loss = adv_loss = rec_loss = torch.tensor(0., device=device) latent_real1 = encoder(real_img1) fake_img1, _ = generator([latent_real1], input_is_latent=True, return_latents=False) latent_real2 = encoder(real_img2) fake_img2, _ = generator([latent_real2], input_is_latent=True, return_latents=False) if args.lambda_adv > 0: if args.use_residual: fake_img_pair = fake_img1 - real_img2 else: fake_img_pair = torch.cat((fake_img1, real_img2), 1) if args.augment: fake_img_aug, _ = augment(fake_img_pair, ada_aug_p) else: fake_img_aug = fake_img_pair fake_pred = discriminator(fake_img_aug) adv_loss = g_nonsaturating_loss(fake_pred) if args.lambda_pix > 0: pix_loss = torch.mean((real_img1 - fake_img1)**2) if args.lambda_vgg > 0: real_feat = vggnet(real_img1) fake_feat = vggnet(fake_img1) vgg_loss = torch.mean((real_feat - fake_feat)**2) e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv loss_dict["e"] = e_loss loss_dict["pix"] = pix_loss loss_dict["vgg"] = vgg_loss loss_dict["adv"] = adv_loss encoder.zero_grad() e_loss.backward() e_optim.step() # if args.train_on_fake: # e_regularize = args.e_rec_every > 0 and i % args.e_rec_every == 0 # if e_regularize and args.lambda_rec > 0: # noise = mixing_noise(args.batch, args.latent, args.mixing, device) # fake_img, latent_fake = generator(noise, input_is_latent=False, return_latents=True) # latent_pred = encoder(fake_img) # if latent_pred.ndim < 3: # latent_pred = latent_pred.unsqueeze(1).repeat(1, latent_fake.size(1), 1) # rec_loss = torch.mean((latent_fake - latent_pred) ** 2) # encoder.zero_grad() # (rec_loss * args.lambda_rec).backward() # e_optim.step() # loss_dict["rec"] = rec_loss e_regularize = args.e_reg_every > 0 and i % args.e_reg_every == 0 if e_regularize: # why not regularize on augmented real? 
real_img_pair = torch.cat((real_img1, real_img2), 1) real_img_pair.requires_grad = True real_pred = encoder(real_img_pair) r1_loss_e = d_r1_loss(real_pred, real_img_pair) encoder.zero_grad() (args.r1 / 2 * r1_loss_e * args.e_reg_every + 0 * real_pred.view(-1)[0]).backward() e_optim.step() loss_dict["r1_e"] = r1_loss_e if not args.no_ema and e_ema is not None: accumulate(e_ema, e_module, accum) # Train Discriminator if args.toggle_grads: requires_grad(encoder, False) requires_grad(discriminator, True) if not args.no_update_discriminator and args.lambda_adv > 0: latent_real1 = encoder(real_img1) fake_img1, _ = generator([latent_real1], input_is_latent=True, return_latents=False) if args.use_residual: fake_img_pair = fake_img1 - real_img2 real_img_pair = real_img1 - real_img2 else: fake_img_pair = torch.cat((fake_img1, real_img2), 1) real_img_pair = torch.cat((real_img1, real_img2), 1) if args.augment: real_img_aug, _ = augment(real_img_pair, ada_aug_p) fake_img_aug, _ = augment(fake_img_pair, ada_aug_p) else: real_img_aug = real_img_pair fake_img_aug = fake_img_pair fake_pred = discriminator(fake_img_aug) real_pred = discriminator(real_img_aug) d_loss = d_logistic_loss(real_pred, fake_pred) loss_dict["d"] = d_loss loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() discriminator.zero_grad() d_loss.backward() d_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0 if d_regularize: # why not regularize on augmented real? if args.use_residual: real_img_pair = real_img1 - real_img2 else: real_img_pair = torch.cat((real_img1, real_img2), 1) real_img_pair.requires_grad = True real_pred = discriminator(real_img_pair) r1_loss_d = d_r1_loss(real_pred, real_img_pair) discriminator.zero_grad() (args.r1 / 2 * r1_loss_d * args.d_reg_every + 0 * real_pred.view(-1)[0]).backward() # Why 0* ? 
# Answer is here https://github.com/rosinality/stylegan2-pytorch/issues/76
d_optim.step() loss_dict["r1_d"] = r1_loss_d
loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() e_loss_val = loss_reduced["e"].mean().item() r1_d_val = loss_reduced["r1_d"].mean().item() r1_e_val = loss_reduced["r1_e"].mean().item() pix_loss_val = loss_reduced["pix"].mean().item() vgg_loss_val = loss_reduced["vgg"].mean().item() adv_loss_val = loss_reduced["adv"].mean().item() rec_loss_val = loss_reduced["rec"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() avg_pix_loss.update(pix_loss_val, real_img1.shape[0]) avg_vgg_loss.update(vgg_loss_val, real_img1.shape[0])
if get_rank() == 0: pbar.set_description(( f"d: {d_loss_val:.4f}; e: {e_loss_val:.4f}; r1_d: {r1_d_val:.4f}; r1_e: {r1_e_val:.4f}; " f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}; " f"rec: {rec_loss_val:.4f}; augment: {ada_aug_p:.4f}"))
if i % args.log_every == 0: with torch.no_grad(): latent_x = e_ema(sample_x) fake_x, _ = generator([latent_x], input_is_latent=True, return_latents=False) sample_pix_loss = torch.sum((sample_x - fake_x)**2) with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write( f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; " f"ref: {sample_pix_loss.item()};\n")
if wandb and args.wandb: wandb.log({ "Encoder": e_loss_val, "Discriminator": d_loss_val, "Augment": ada_aug_p, "Rt": r_t_stat, "R1 D": r1_d_val, "R1 E": r1_e_val, "Pix Loss": pix_loss_val, "VGG Loss": vgg_loss_val, "Adv Loss": adv_loss_val, "Rec Loss": rec_loss_val, "Real Score": real_score_val, "Fake Score": fake_score_val, })
if i % args.log_every == 0: with torch.no_grad(): e_eval = encoder if args.no_ema else e_ema e_eval.eval() nrow = int(args.n_sample**0.5) nchw = list(sample_x.shape)[1:] latent_real = e_eval(sample_x) fake_img, _ = generator([latent_real], input_is_latent=True, return_latents=False) sample = torch.cat( (sample_x.reshape(args.n_sample // nrow, nrow, *nchw), fake_img.reshape(args.n_sample // nrow, nrow, *nchw)), 1) utils.save_image( sample.reshape(2 * args.n_sample, *nchw), os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}.png"), nrow=nrow, normalize=True, value_range=(-1, 1), ) e_eval.train()
if i % args.save_every == 0: e_eval = encoder if args.no_ema else e_ema torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_module.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"), )
if i % args.save_latest_every == 0: torch.save( { "e": e_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_module.state_dict(), "e_ema": e_eval.state_dict(), "e_optim": e_optim.state_dict(), "d_optim": d_optim.state_dict(), "args": args, "ada_aug_p": ada_aug_p, "iter": i, }, os.path.join(args.log_dir, 'weight', "latest.pt"), )
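# --- editor's note -------------------------------------------------------
# A reference sketch of the R1 penalty used by every d_regularize /
# e_regularize branch above, assuming rosinality's d_r1_loss; the `+ 0 * pred`
# trick at the call sites keeps otherwise-unused parameters in the DDP graph
# (see the linked issue #76):
from torch import autograd

def _sketch_d_r1_loss(real_pred, real_img):
    # gradient of the summed critic score w.r.t. the real images
    grad_real, = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    # mean squared gradient norm over the batch
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()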