import collections

import tqdm
import torch
import torch.nn.functional as F
import torchvision

# Project-local modules; adjust the import paths to this repo's layout.
import sinkhorn_autodiff
import gan_losses
import gan_utils
import weight_regularity


def train(data_loader, G, z_train, g_optimizer, z_sample, train_dict,
          args, myargs, **kwargs):
  for i, (imgs, _) in enumerate(tqdm.tqdm(data_loader)):
    train_dict['batches_done'] += 1
    summary = {}
    G.train()
    imgs = imgs.cuda()

    # train G
    z_train.sample_()
    f_imgs = G(z_train)
    sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
      x=imgs.view(imgs.size(0), -1), y=f_imgs.view(f_imgs.size(0), -1),
      epsilon=args.sinkhorn_eps, niter=args.sinkhorn_niter,
      cuda=True, pi_detach=args.sinkhorn_pi_detach)
    summary['sinkhorn_d'] = sinkhorn_d.item()

    g_loss = sinkhorn_d
    summary['g_loss'] = g_loss.item()

    G.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    if i % args.sample_every == 0:
      # sample images
      G.eval()
      f_imgs_sample = G(z_sample)
      merged_img = torchvision.utils.make_grid(
        f_imgs_sample, normalize=True, pad_value=1, nrow=16)
      myargs.writer.add_images(
        'G_z', merged_img.view(1, *merged_img.shape), train_dict['batches_done'])
      merged_img = torchvision.utils.make_grid(
        imgs, normalize=True, pad_value=1, nrow=16)
      myargs.writer.add_images(
        'x', merged_img.view(1, *merged_img.shape), train_dict['batches_done'])

      # checkpoint
      myargs.checkpoint.save_checkpoint(
        checkpoint_dict=myargs.checkpoint_dict, filename='ckpt.tar')

      # summary
      for key in summary:
        myargs.writer.add_scalar(
          'train_vs_batch/%s' % key, summary[key], train_dict['batches_done'])
      G.train()
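

# Reference sketch only: a log-domain Sinkhorn loss of the kind
# `sinkhorn_autodiff.sinkhorn_loss` is called with above. The function name,
# defaults, and the squared-L2 cost are assumptions, not the repo's code;
# `epsilon` is the entropic regularization, `niter` the number of Sinkhorn
# iterations, and `pi_detach` presumably stops gradients from flowing through
# the transport plan so that only the cost matrix is differentiated.
def sinkhorn_loss_sketch(x, y, epsilon=0.01, niter=100, pi_detach=False):
  """Entropy-regularized OT cost between point clouds x: (n, d) and y: (m, d)."""
  n, m = x.size(0), y.size(0)
  C = torch.cdist(x, y, p=2) ** 2                            # pairwise squared-L2 cost
  log_mu = torch.full((n,), 1.0 / n, device=x.device).log()  # uniform marginals
  log_nu = torch.full((m,), 1.0 / m, device=y.device).log()
  u = torch.zeros(n, device=x.device)
  v = torch.zeros(m, device=y.device)
  for _ in range(niter):
    # Log-domain Sinkhorn updates (numerically stable matrix scaling).
    u = epsilon * (log_mu - torch.logsumexp((-C + v[None, :]) / epsilon, dim=1))
    v = epsilon * (log_nu - torch.logsumexp((-C + u[:, None]) / epsilon, dim=0))
  pi = torch.exp((-C + u[:, None] + v[None, :]) / epsilon)   # transport plan
  if pi_detach:
    pi = pi.detach()
  return (pi * C).sum()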


def train(data_loader, G, G_ema, D, g_optimizer, d_optimizer, ema,
          z_train, z_sample, train_dict, args, myargs, **kwargs):
  for i, (imgs, _) in enumerate(tqdm.tqdm(data_loader)):
    train_dict['batches_done'] += 1
    step = train_dict['batches_done']
    summary = {}
    summary_d_logits_mean = {}
    summary_wd = {}
    G.train()
    G_ema.train()
    imgs = imgs.cuda()
    bs = imgs.size(0)

    z_train.sample_()
    f_imgs = G(z_train[:bs])

    # train D
    with torch.no_grad():
      sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
        x=imgs.view(bs, -1), y=f_imgs.view(bs, -1).detach(),
        epsilon=args.sinkhorn_eps, niter=args.sinkhorn_niter,
        cuda=True, pi_detach=args.sinkhorn_pi_detach)
    summary_wd['D_sinkhorn_d'] = sinkhorn_d.item()

    r_logit = D(imgs)
    r_logit_mean = r_logit.mean()
    f_logit = D(f_imgs.detach())
    f_logit_mean = f_logit.mean()
    summary_d_logits_mean['D_r_logit_mean'] = r_logit_mean.item()
    summary_d_logits_mean['D_f_logit_mean'] = f_logit_mean.item()

    # Wasserstein-1 Distance
    wd = r_logit_mean - f_logit_mean
    gp = gan_losses.wgan_gp_gradient_penalty(imgs.data, f_imgs.data, D)
    # The relu term cancels the -wd gradient once wd exceeds the Sinkhorn
    # estimate, i.e. it bounds the critic's Wasserstein estimate by sinkhorn_d.
    d_loss = -wd + gp * 10.0 + torch.relu(wd - sinkhorn_d.item())
    summary_wd['wd'] = wd.item()
    summary['gp'] = gp.item()
    summary['d_loss'] = d_loss.item()

    D.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    if step % args.n_critic == 0:
      # train G
      z_train.sample_()
      f_imgs = G(z_train)
      sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
        x=imgs.view(imgs.size(0), -1), y=f_imgs.view(f_imgs.size(0), -1),
        epsilon=args.sinkhorn_eps, niter=args.sinkhorn_niter,
        cuda=True, pi_detach=args.sinkhorn_pi_detach)
      summary_wd['G_sinkhorn_d'] = sinkhorn_d.item()

      f_logit = D(f_imgs)
      f_logit_mean = f_logit.mean()
      g_loss = -f_logit_mean + args.lambda_sinkhorn * sinkhorn_d
      summary_d_logits_mean['G_f_logit_mean'] = f_logit_mean.item()
      summary['g_loss'] = g_loss.item()

      D.zero_grad()
      G.zero_grad()
      g_loss.backward()
      g_optimizer.step()

    # end iter
    ema.update(train_dict['batches_done'])

    if i % args.sample_every == 0:
      # sample images
      G.eval()
      G_ema.eval()
      G_z = G(z_sample)
      merged_img = torchvision.utils.make_grid(
        G_z, normalize=True, pad_value=1, nrow=16)
      myargs.writer.add_images(
        'G_z', merged_img.view(1, *merged_img.shape), train_dict['batches_done'])
      # G_ema
      G_ema_z = G_ema(z_sample)
      merged_img = torchvision.utils.make_grid(
        G_ema_z, normalize=True, pad_value=1, nrow=16)
      myargs.writer.add_images(
        'G_ema_z', merged_img.view(1, *merged_img.shape), train_dict['batches_done'])
      # x
      merged_img = torchvision.utils.make_grid(
        imgs, normalize=True, pad_value=1, nrow=16)
      myargs.writer.add_images(
        'x', merged_img.view(1, *merged_img.shape), train_dict['batches_done'])

      # checkpoint
      myargs.checkpoint.save_checkpoint(
        checkpoint_dict=myargs.checkpoint_dict, filename='ckpt.tar')

      # summary
      for key in summary:
        myargs.writer.add_scalar(
          'train_vs_batch/%s' % key, summary[key], train_dict['batches_done'])
      myargs.writer.add_scalars(
        'train_vs_batch', summary_d_logits_mean, train_dict['batches_done'])
      myargs.writer.add_scalars('wd', summary_wd, train_dict['batches_done'])
      G.train()
    elif train_dict['batches_done'] <= 20000:
      for key in summary:
        myargs.writer.add_scalar(
          'train_vs_batch/%s' % key, summary[key], train_dict['batches_done'])
      myargs.writer.add_scalars(
        'train_vs_batch', summary_d_logits_mean, train_dict['batches_done'])
      myargs.writer.add_scalars('wd', summary_wd, train_dict['batches_done'])
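

# Reference sketch only: a standard WGAN-GP gradient penalty of the kind
# `gan_losses.wgan_gp_gradient_penalty(imgs.data, f_imgs.data, D)` is expected
# to return above (the 10.0 factor in d_loss is the usual penalty coefficient).
# The function name and the NCHW-image assumption are illustrative, not the
# repo's API.
def wgan_gp_gradient_penalty_sketch(real, fake, D):
  """Penalize (||grad_x D(x_hat)|| - 1)^2 on random real/fake interpolates."""
  alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
  x_hat = (alpha * real + (1. - alpha) * fake).requires_grad_(True)
  d_hat = D(x_hat)
  grad = torch.autograd.grad(
    outputs=d_hat, inputs=x_hat, grad_outputs=torch.ones_like(d_hat),
    create_graph=True, retain_graph=True, only_inputs=True)[0]
  grad_norm = grad.view(grad.size(0), -1).norm(2, dim=1)
  return ((grad_norm - 1.) ** 2).mean()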


def train_one_epoch(self):
  config = self.config.train_one_epoch
  if config.dummy_train:
    return
  myargs = self.myargs
  train_dict = self.train_dict
  pbar = tqdm.tqdm(self.data_loader, file=myargs.stdout)
  self._summary_create()

  for i, (imgs, _) in enumerate(pbar):
    self.G.train()
    self.D.train()
    train_dict['batches_done'] += 1

    imgs = imgs.cuda()
    bs = imgs.size(0)
    if bs != self.config.noise.batch_size_train:
      return

    self.z_train.sample_()
    with torch.no_grad():
      f_imgs = self.G(self.z_train[:bs])
    imgs.requires_grad_()
    f_imgs.requires_grad_()

    # train D
    D_r_logit = self.D(imgs)
    D_r_logit_mean = D_r_logit.mean()
    D_f_logit = self.D(f_imgs)
    D_f_logit_mean = D_f_logit.mean()
    self.summary_logit_mean['D_r_logit_mean'] = D_r_logit_mean.item()
    self.summary_logit_mean['D_f_logit_mean'] = D_f_logit_mean.item()

    self.d_optimizer.zero_grad()
    wd = D_f_logit_mean - D_r_logit_mean
    # Backward gp loss in this func
    gp = gan_losses.wgan_div_gradient_penalty(
      real_imgs=imgs, fake_imgs=f_imgs,
      real_validity=D_r_logit, fake_validity=D_f_logit,
      backward=True, retain_graph=True)

    if config.bound_type == 'constant':
      raise NotImplementedError
      D_loss = -wd + torch.relu(wd - float(config.bound))
      # D_loss = -wd + gp * config.gp_lambda + \
      #          torch.relu(wd - float(config.bound))
      self.summary_wd['bound'] = config.bound
    elif config.bound_type == 'sinkhorn':
      with torch.no_grad():
        sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
          x=imgs.view(bs, -1), y=f_imgs.view(bs, -1).detach(),
          epsilon=config.sinkhorn_eps, niter=config.sinkhorn_niter,
          cuda=True, pi_detach=config.sinkhorn_pi_detach)
      self.summary_wd['D_sinkhorn_d'] = sinkhorn_d.item()
      D_loss = -wd + torch.relu(wd - sinkhorn_d.item())
    else:
      D_loss = -wd
    self.summary_wd['wd'] = wd.item()
    self.summary['gp'] = gp.item()
    self.summary['D_loss'] = D_loss.item()

    D_loss.backward()
    self.d_optimizer.step()

    if i % config.n_critic == 0:
      # train G
      self.z_train.sample_()
      f_imgs = self.G(self.z_train)
      D_f_logit = self.D(f_imgs)
      D_f_logit_mean = D_f_logit.mean()
      g_loss_only = D_f_logit_mean

      if config.bound_type == 'sinkhorn':
        sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
          x=imgs.view(imgs.size(0), -1), y=f_imgs.view(f_imgs.size(0), -1),
          epsilon=config.sinkhorn_eps, niter=config.sinkhorn_niter,
          cuda=True, pi_detach=config.sinkhorn_pi_detach)
        self.summary_wd['G_sinkhorn_d'] = sinkhorn_d.item()
        G_loss = g_loss_only + config.sinkhorn_lambda * sinkhorn_d
      else:
        G_loss = g_loss_only
      self.summary_logit_mean['G_f_logit_mean'] = D_f_logit_mean.item()
      self.summary['g_loss_only'] = g_loss_only.item()
      self.summary['G_loss'] = G_loss.item()

      self.g_optimizer.zero_grad()
      G_loss.backward()
      self.g_optimizer.step()

    # end iter
    self.ema.update(train_dict['batches_done'])

    if i % config.sample_every == 0:
      # checkpoint
      myargs.checkpoint.save_checkpoint(
        checkpoint_dict=myargs.checkpoint_dict, filename='ckpt.tar')
      # summary
      self._summary_scalars()
    elif train_dict['batches_done'] <= config.sample_start_iter:
      self._summary_scalars()

    if i % (len(self.data_loader) // 4) == 0:
      # save images
      self._summary_figures()
      self._summary_images(imgs=imgs, itr=self.train_dict['batches_done'])
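

# Reference sketch only: a WGAN-div style gradient penalty of the kind
# `gan_losses.wgan_div_gradient_penalty` is called with above. The defaults
# k=2, p=6 follow the WGAN-div paper; `backward=True` mirrors the call above,
# backpropagating the penalty inside the helper and returning its value for
# logging. The exact signature and behavior of the repo's helper may differ.
def wgan_div_gradient_penalty_sketch(real_imgs, fake_imgs,
                                     real_validity, fake_validity,
                                     k=2., p=6., backward=False,
                                     retain_graph=False):
  """Penalize ||grad_x D(x)||^p on both real and fake samples."""
  real_grad = torch.autograd.grad(
    outputs=real_validity.sum(), inputs=real_imgs,
    create_graph=True, retain_graph=True, only_inputs=True)[0]
  fake_grad = torch.autograd.grad(
    outputs=fake_validity.sum(), inputs=fake_imgs,
    create_graph=True, retain_graph=True, only_inputs=True)[0]
  real_norm = real_grad.view(real_grad.size(0), -1).norm(2, dim=1)
  fake_norm = fake_grad.view(fake_grad.size(0), -1).norm(2, dim=1)
  gp = (real_norm ** p + fake_norm ** p).mean() * k / 2.
  if backward:
    gp.backward(retain_graph=retain_graph)
  return gp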


def train_one_epoch(self):
  config = self.config.train_one_epoch
  if config.dummy_train:
    return
  myargs = self.myargs
  train_dict = myargs.checkpoint_dict['train_dict']
  batch_size = self.config.noise.batch_size
  pbar = tqdm.tqdm(self.loaders[0], file=self.myargs.stdout)

  for i, (images, labels) in enumerate(pbar):
    if len(images) % batch_size != 0:
      return
    summary_dict = collections.defaultdict(dict)

    # Make sure G and D are in training mode, just in case they got set to eval.
    # For D, which typically doesn't have BN, this shouldn't matter much.
    self.G.train()
    self.D.train()
    self.GD.train()
    self.G_ema.train()

    images, labels = images.cuda(), labels.cuda()

    # Optionally toggle D and G's "require_grad"
    if config.toggle_grads:
      gan_utils.toggle_grad(self.D, True)
      gan_utils.toggle_grad(self.G, False)

    # How many chunks to split x and y into?
    x = torch.split(images, batch_size)
    y = torch.split(labels, batch_size)
    counter = 0

    for step_index in range(config.num_D_steps):
      self.G.optim.zero_grad()
      self.D.optim.zero_grad()
      if getattr(config, 'weigh_loss', False):
        self.alpha_optim.zero_grad()
      # Increment the iteration counter
      train_dict['batches_done'] += 1

      # If accumulating gradients, loop multiple times before an optimizer step
      for accumulation_index in range(config.num_D_accumulations):
        self.z_.sample_()
        dy = y[counter]
        gy = dy  # self.y_.sample_()
        D_fake, D_real, G_z = self.GD(
          z=self.z_[:batch_size], gy=gy, x=x[counter], dy=dy,
          train_G=False, split_D=config.split_D, return_G_z=True)

        # Compute components of D's loss, average them, and divide by
        # the number of gradient accumulations
        D_real_mean, D_fake_mean, wd, _ = gan_losses.wgan_discriminator_loss(
          r_logit=D_real, f_logit=D_fake)

        # entropy loss
        if getattr(config, 'weigh_loss', False) and config.use_entropy:
          weight_logit = 1. / torch.stack(self.alpha_params).exp()
          weight = F.softmax(weight_logit, dim=0)
          entropy = (-weight * F.log_softmax(weight_logit, dim=0)).sum()
          (-config.entropy_lambda * entropy).backward()

        # gp
        if getattr(config, 'weigh_loss', False) and hasattr(self, 'alpha_gp'):
          if config.use_entropy:
            weight_logit = 1. / torch.stack(self.alpha_params).exp()
            alpha_gp_logit = 1. / self.alpha_gp.exp()
            weight = alpha_gp_logit.exp() / weight_logit.exp().sum()
          else:
            weight = 1. / self.alpha_gp.exp()
          weight_gp = weight
          sigma_gp = torch.exp(self.alpha_gp / 2.)
          gp_loss, gp = gan_losses.wgan_gp_gradient_penalty_cond(
            x=x[counter], G_z=G_z, gy=gy, f=self.D,
            backward=True, gp_lambda=weight_gp, return_gp=True)
          (self.alpha_gp / 2.).backward()
          summary_dict['sigma']['sigma_gp'] = sigma_gp.item()
          summary_dict['weight']['weight_gp'] = weight_gp.item()
          summary_dict['gp']['gp_loss'] = gp_loss.item()
          summary_dict['gp']['gp_raw'] = gp.item()
        else:
          gp_loss = gan_losses.wgan_gp_gradient_penalty_cond(
            x=x[counter], G_z=G_z, gy=gy, f=self.D,
            backward=True, gp_lambda=config.gp_lambda)
          summary_dict['gp']['gp_loss'] = gp_loss.item()

        # wd
        if getattr(config, 'weigh_loss', False) and hasattr(self, 'alpha_wd'):
          if config.use_entropy:
            weight_logit = 1. / torch.stack(self.alpha_params).exp()
            alpha_wd_logit = 1. / self.alpha_wd.exp()
            weight = alpha_wd_logit.exp() / weight_logit.exp().sum()
          else:
            weight = 1. / self.alpha_wd.exp()
          weight_wd = weight
          sigma_wd = torch.exp(self.alpha_wd / 2.)
          D_loss = weight_wd.item() * (-wd) + \
                   weight_wd * wd.abs().item() + self.alpha_wd / 2.
          summary_dict['sigma']['sigma_wd'] = sigma_wd.item()
          summary_dict['weight']['weight_wd'] = weight_wd.item()
          summary_dict['wd']['wd_loss'] = weight_wd.item() * wd.abs().item()
          summary_dict['wd']['wd_raw'] = wd.item()
        elif getattr(config, 'd_sinkhorn', False):
          sinkhorn_c = config.sinkhorn_c
          with torch.no_grad():
            sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
              x=x[counter].view(batch_size, -1),
              y=G_z.view(batch_size, -1).detach(),
              epsilon=sinkhorn_c.sinkhorn_eps,
              niter=sinkhorn_c.sinkhorn_niter,
              cuda=True, pi_detach=sinkhorn_c.sinkhorn_pi_detach)
          summary_dict['wd']['D_sinkhorn_d'] = sinkhorn_d.item()
          D_loss = -wd + torch.relu(wd - sinkhorn_d.item())
          summary_dict['wd']['wd_raw'] = wd.item()
        else:
          D_loss = -wd
          summary_dict['wd']['wd_raw'] = wd.item()

        D_losses = D_loss / float(config.num_D_accumulations)
        # Accumulate gradients
        D_losses.backward()
        counter += 1

        summary_dict['D_logit_mean']['D_real_mean'] = D_real_mean.item()
        summary_dict['D_logit_mean']['D_fake_mean'] = D_fake_mean.item()
        summary_dict['scalars']['D_loss'] = D_loss.item()
      # End accumulation

      # Optionally apply ortho reg in D
      if config.D_ortho > 0.0:
        # Debug print to indicate we're using ortho reg in D.
        print('using modified ortho reg in D')
        weight_regularity.ortho(self.D, config.D_ortho)

      # for name, value in self.D.named_parameters():
      #   self.myargs.writer.add_histogram(
      #     name, value.grad, train_dict['batches_done'])
      self.D.optim.step()
      if getattr(config, 'weigh_loss', False):
        self.alpha_optim.step()

    # Optionally toggle "requires_grad"
    if config.toggle_grads:
      gan_utils.toggle_grad(self.D, False)
      gan_utils.toggle_grad(self.G, True)

    # Zero G's gradients by default before training G, for safety
    self.G.optim.zero_grad()
    if getattr(config, 'weigh_loss', False):
      self.alpha_optim.zero_grad()

    # If accumulating gradients, loop multiple times
    for accumulation_index in range(config.num_G_accumulations):
      self.z_.sample_()
      gy = dy  # self.y_.sample_()
      D_fake, G_z = self.GD(z=self.z_, gy=gy, train_G=True,
                            return_G_z=True, split_D=config.split_D)
      G_fake_mean, _ = gan_losses.wgan_generator_loss(f_logit=D_fake)

      # wd for generator
      if getattr(config, 'weigh_loss', False) and hasattr(self, 'alpha_g'):
        if config.use_entropy:
          weight_logit = 1. / torch.stack(self.alpha_params).exp()
          alpha_g_logit = 1. / self.alpha_g.exp()
          weight = alpha_g_logit.exp() / weight_logit.exp().sum()
        else:
          weight = 1. / self.alpha_g.exp()
        weight_g = weight
        sigma_g = torch.exp(self.alpha_g / 2.)
        wd_g = D_real_mean.item() - G_fake_mean
        G_loss = weight_g.item() * wd_g + \
                 weight_g * wd_g.abs().item() + self.alpha_g / 2.
        summary_dict['sigma']['sigma_g'] = sigma_g.item()
        summary_dict['weight']['weight_g'] = weight_g.item()
        summary_dict['wd']['wd_g_raw'] = wd_g.item()
        summary_dict['wd']['wd_g_loss'] = weight_g.item() * wd_g.abs().item()
      elif getattr(config, 'g_sinkhorn', False):
        sinkhorn_c = config.sinkhorn_c
        sinkhorn_d = sinkhorn_autodiff.sinkhorn_loss(
          x=x[-1].view(x[-1].size(0), -1), y=G_z.view(G_z.size(0), -1),
          epsilon=sinkhorn_c.sinkhorn_eps, niter=sinkhorn_c.sinkhorn_niter,
          cuda=True, pi_detach=sinkhorn_c.sinkhorn_pi_detach)
        summary_dict['wd']['G_sinkhorn_d'] = sinkhorn_d.item()
        g_loss_only = -G_fake_mean
        summary_dict['scalars']['g_loss_only'] = g_loss_only.item()
        G_loss = g_loss_only + sinkhorn_c.sinkhorn_lambda * sinkhorn_d
      else:
        G_loss = -G_fake_mean

      # Accumulate gradients
      G_loss = G_loss / float(config.num_G_accumulations)
      G_loss.backward()

      summary_dict['D_logit_mean']['G_fake_mean'] = G_fake_mean.item()
      summary_dict['scalars']['G_loss'] = G_loss.item()

    # Optionally apply modified ortho reg in G
    if config.G_ortho > 0.0:
      # Debug print to indicate we're using ortho reg in G
      print('using modified ortho reg in G')
      # Don't ortho reg shared; it makes no sense.
      # Really we should blacklist any embeddings for this.
      weight_regularity.ortho(
        self.G, config.G_ortho,
        blacklist=[param for param in self.G.shared.parameters()])
    self.G.optim.step()
    if getattr(config, 'weigh_loss', False):
      self.alpha_optim.step()

    # If we have an ema, update it, regardless of whether we test with it or not
    self.ema.update(train_dict['batches_done'])

    pbar.set_description('wd=%f' % wd.item())

    if i % config.sample_every == 0:
      # weights
      self.save_checkpoint(filename='ckpt.tar')
      self.summary_dicts(summary_dicts=summary_dict,
                         prefix='train_one_epoch',
                         step=train_dict['batches_done'])
      # singular values
      # G_sv_dict = utils_BigGAN.get_SVs(self.G, 'G')
      # D_sv_dict = utils_BigGAN.get_SVs(self.D, 'D')
      # myargs.writer.add_scalars('sv/G_sv_dict', G_sv_dict,
      #                           train_dict['batches_done'])
      # myargs.writer.add_scalars('sv/D_sv_dict', D_sv_dict,
      #                           train_dict['batches_done'])
    elif train_dict['batches_done'] <= config.sample_start_iter:
      # scalars
      self.summary_dicts(summary_dicts=summary_dict,
                         prefix='train_one_epoch',
                         step=train_dict['batches_done'])

    if (i + 1) % (len(self.loaders[0]) // 2) == 0:
      # save images
      # self.summary_figures(summary_dicts=summary_dict,
      #                      prefix='train_one_epoch')
      # samples
      self._summary_images(imgs=x[0])
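

# Reference sketch only: the `alpha_*` parameters above follow the
# homoscedastic-uncertainty loss-weighting pattern (weight = exp(-alpha)
# = 1/sigma^2 with sigma = exp(alpha / 2), plus an alpha / 2 penalty that keeps
# the weight from collapsing to zero). The code above additionally decouples
# the two gradient paths with .item(): the model sees a fixed weight, while
# alpha sees a fixed loss magnitude. The helper and variable names below are
# illustrative, not part of the repo, and the entropy/softmax variant across
# several alphas is omitted.
def weigh_loss_term_sketch(loss, alpha):
  """Return a weighted loss whose gradients update both the model and alpha."""
  weight = torch.exp(-alpha)                             # = 1 / sigma^2
  model_term = weight.item() * loss                      # gradient -> model only
  alpha_term = weight * loss.abs().item() + alpha / 2.   # gradient -> alpha only
  return model_term + alpha_term

# Usage sketch: a learnable log-variance for the Wasserstein term.
#   alpha_wd = torch.zeros([], requires_grad=True)
#   D_loss = weigh_loss_term_sketch(-wd, alpha_wd)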