def _disc_max_logits_r1_loss(self, real_img: torch.Tensor, real_c: torch.Tensor, gain: int, do_main: bool, do_reg: bool, log_values:dict) -> torch.Tensor: real_img_tmp = real_img.detach().requires_grad_(do_reg) real_logits = self._disc_run(real_img_tmp, real_c) self.logit_sign(real_logits) # training_stats.report('Loss/scores/real', real_logits) # training_stats.report('Loss/signs/real', real_logits.sign()) loss_Dreal = 0 if do_main: loss_Dreal = F.softplus(-real_logits) # -log(sigmoid(real_logits)) log_values['Loss/D/loss'] += loss_Dreal.mean() # training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) loss_Dr1 = 0 if do_reg: with conv2d_gradfix.no_weight_gradients(): r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0] r1_penalty = r1_grads.square().sum([1, 2, 3]) loss_Dr1 = r1_penalty * (self.r1_gamma / 2) log_values['Loss/r1_penalty'] = r1_penalty log_values['Loss/D/reg'] = loss_Dr1 return (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain)
def _gen_pl_loss(self, gen_z: torch.Tensor, gen_c: torch.Tensor, gain: float) -> torch.Tensor:
    """Compute the generator path-length regularization loss.

    Runs the generator on the first ``len(gen_z) // pl_batch_shrink`` samples
    (at least one). It measures ``||J_w^T y||`` for a random image-space
    direction ``y`` and penalizes its squared deviation from the running mean
    ``self.pl_mean``, which is updated in place.

    Args:
        gen_z: Latent codes; only the leading sub-batch is used.
        gen_c: Conditioning labels, sliced like ``gen_z``.
        gain: Multiplier applied to the final scalar loss.

    Returns:
        Scalar tensor ``mean(gen_img[:, 0, 0, 0] * 0 + loss_Gpl) * gain``.
    """
    # Keep at least one sample. An empty sub-batch would make pl_lengths.mean()
    # NaN and permanently corrupt the running mean below.
    batch_size = max(1, gen_z.shape[0] // self.pl_batch_shrink)
    gen_img, gen_ws = self._gen_run(gen_z[:batch_size], gen_c[:batch_size])
    # Random direction in image space, scaled by 1/sqrt(H*W) (dims 2 and 3).
    pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
    with conv2d_gradfix.no_weight_gradients():
        # create_graph=True so the penalty can itself be back-propagated.
        pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
    # Per-sample length: sum over dim 2, mean over dim 1 (gen_ws is at least 3-D), sqrt.
    pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
    # Keep the running mean on the same device as the current batch.
    self.pl_mean = self.pl_mean.to(pl_lengths.device)
    pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
    self.pl_mean.copy_(pl_mean.detach())
    pl_penalty = (pl_lengths - pl_mean).square()
    loss_Gpl = pl_penalty * self.pl_weight
    # Log detached scalar means. Logging APIs such as LightningModule.log_dict
    # reject multi-element tensors (assumes self.log_dict is such an API -- confirm).
    self.log_dict({
        'Loss/pl_penalty': pl_penalty.mean().detach(),
        'Loss/G/reg': loss_Gpl.mean().detach(),
    })
    # `gen_img[:, 0, 0, 0] * 0` adds zero but ties gen_img into the graph --
    # presumably so every G parameter receives a (zero) gradient; confirm.
    return (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain)
def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
    """Run one training phase and accumulate its gradients via ``backward()``.

    Nothing is returned. The selected losses are back-propagated in place and
    reported through ``training_stats``. ``self.pl_mean`` is updated in place
    whenever path-length regularization runs.

    Args:
        phase: One of "Gmain", "Greg", "Gboth", "Dmain", "Dreg", "Dboth".
            "*main" runs the adversarial loss and "*reg" the regularizer
            (path length for G, R1 for D). "*both" runs both. A regularizer is
            skipped when its weight (``pl_weight`` / ``r1_gamma``) is 0.
        real_img: Real images. R1 sums squared gradients over dims 1-3, so a
            4-D tensor is assumed (presumably NCHW -- TODO confirm).
        real_c: Labels for ``real_img``, passed to ``self.run_D``.
        gen_z: Latents for ``self.run_G``.
        gen_c: Labels for ``self.run_G`` / ``self.run_D``.
        sync: Forwarded only to the last G forward and to the real-image D
            forward; every other call passes ``sync=False``. Presumably controls
            distributed gradient synchronization -- confirm against run_G/run_D.
        gain: Multiplier applied to every loss before ``backward()``.
    """
    assert phase in ["Gmain", "Greg", "Gboth", "Dmain", "Dreg", "Dboth"]
    do_Gmain = phase in ["Gmain", "Gboth"]
    do_Dmain = phase in ["Dmain", "Dboth"]
    # Regularizers are skipped entirely when their weight is zero.
    do_Gpl = (phase in ["Greg", "Gboth"]) and (self.pl_weight != 0)
    do_Dr1 = (phase in ["Dreg", "Dboth"]) and (self.r1_gamma != 0)

    # Gmain: Maximize logits for generated images.
    if do_Gmain:
        with torch.autograd.profiler.record_function("Gmain_forward"):
            gen_img, _gen_ws = self.run_G(
                gen_z, gen_c, sync=(sync and not do_Gpl))  # May get synced by Gpl.
            gen_logits = self.run_D(gen_img, gen_c, sync=False)
            training_stats.report("Loss/scores/fake", gen_logits)
            training_stats.report("Loss/signs/fake", gen_logits.sign())
            # Non-saturating generator loss.
            loss_Gmain = torch.nn.functional.softplus(
                -gen_logits)  # -log(sigmoid(gen_logits))
            training_stats.report("Loss/G/loss", loss_Gmain)
        with torch.autograd.profiler.record_function("Gmain_backward"):
            loss_Gmain.mean().mul(gain).backward()

    # Gpl: Apply path length regularization.
    if do_Gpl:
        with torch.autograd.profiler.record_function("Gpl_forward"):
            # Regularize on a reduced batch: the first len // pl_batch_shrink samples.
            batch_size = gen_z.shape[0] // self.pl_batch_shrink
            gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
            # Random image-space direction, scaled by 1/sqrt(H*W) (dims 2 and 3).
            pl_noise = torch.randn_like(gen_img) / np.sqrt(
                gen_img.shape[2] * gen_img.shape[3])
            # create_graph=True so the penalty built from pl_grads can be back-propagated.
            with torch.autograd.profiler.record_function(
                    "pl_grads"), conv2d_gradfix.no_weight_gradients():
                pl_grads = torch.autograd.grad(outputs=[
                    (gen_img * pl_noise).sum()
                ], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
            # Per-sample length: sum over dim 2, mean over dim 1, then sqrt.
            pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
            # Exponential moving average of the lengths, written back without gradient.
            pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
            self.pl_mean.copy_(pl_mean.detach())
            pl_penalty = (pl_lengths - pl_mean).square()
            training_stats.report("Loss/pl_penalty", pl_penalty)
            loss_Gpl = pl_penalty * self.pl_weight
            training_stats.report("Loss/G/reg", loss_Gpl)
        with torch.autograd.profiler.record_function("Gpl_backward"):
            # `gen_img[:, 0, 0, 0] * 0` adds zero but ties gen_img into the graph --
            # presumably so every G parameter receives a (zero) gradient; confirm.
            (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

    # Dmain: Minimize logits for generated images.
    # loss_Dgen stays the int 0 when Dmain is off; it is only reported below.
    loss_Dgen = 0
    if do_Dmain:
        with torch.autograd.profiler.record_function("Dgen_forward"):
            gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
            gen_logits = self.run_D(
                gen_img, gen_c, sync=False)  # Gets synced by loss_Dreal.
            training_stats.report("Loss/scores/fake", gen_logits)
            training_stats.report("Loss/signs/fake", gen_logits.sign())
            loss_Dgen = torch.nn.functional.softplus(
                gen_logits)  # -log(1 - sigmoid(gen_logits))
        with torch.autograd.profiler.record_function("Dgen_backward"):
            loss_Dgen.mean().mul(gain).backward()

    # Dmain: Maximize logits for real images.
    # Dr1: Apply R1 regularization.
    if do_Dmain or do_Dr1:
        # Profiler label reflects which of the two terms are active.
        name = "Dreal_Dr1" if do_Dmain and do_Dr1 else "Dreal" if do_Dmain else "Dr1"
        with torch.autograd.profiler.record_function(name + "_forward"):
            # Input gradients are only needed for the R1 penalty.
            real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
            real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
            training_stats.report("Loss/scores/real", real_logits)
            training_stats.report("Loss/signs/real", real_logits.sign())

            loss_Dreal = 0
            if do_Dmain:
                loss_Dreal = torch.nn.functional.softplus(
                    -real_logits)  # -log(sigmoid(real_logits))
                training_stats.report("Loss/D/loss", loss_Dgen + loss_Dreal)

            loss_Dr1 = 0
            if do_Dr1:
                with torch.autograd.profiler.record_function(
                        "r1_grads"), conv2d_gradfix.no_weight_gradients():
                    # create_graph=True: the R1 penalty is itself back-propagated.
                    r1_grads = torch.autograd.grad(
                        outputs=[real_logits.sum()],
                        inputs=[real_img_tmp],
                        create_graph=True,
                        only_inputs=True)[0]
                r1_penalty = r1_grads.square().sum([1, 2, 3])
                loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                training_stats.report("Loss/r1_penalty", r1_penalty)
                training_stats.report("Loss/D/reg", loss_Dr1)

        with torch.autograd.profiler.record_function(name + "_backward"):
            # `real_logits * 0` keeps the result a tensor tied to D's output
            # whichever of the two terms is still the int 0.
            (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
def accumulate_gradients(self, stage, real_img, real_c, gen_z, gen_c, sync, gain):
    """Run one training stage and accumulate its gradients via ``backward()``.

    The adversarial objectives are selected by ``self.g_loss`` ("logistic",
    "logistic_ns", "hinge", "wgan") and ``self.d_loss`` ("logistic", "hinge",
    "wgan"). Nothing is returned. ``self.pl_mean`` is updated in place when
    path-length regularization runs.

    Args:
        stage: One of "G_main", "G_reg", "G_both", "D_main", "D_reg", "D_both".
            A regularizer is skipped when its weight (``pl_weight`` /
            ``r1_gamma``) is 0.
        real_img: Real images (4-D assumed by the R1 sum over dims 1-3).
        real_c: Labels for ``real_img``.
        gen_z: Latents for ``self.run_G``.
        gen_c: Labels for ``self.run_G`` / ``self.run_D``.
        sync: Forwarded only to the last G forward and to the real-image D
            forward; presumably distributed gradient sync -- confirm.
        gain: Multiplier applied to every loss before ``backward()``.

    Raises:
        ValueError: If a requested stage needs ``g_loss``/``d_loss`` and it
            names an unsupported loss.
    """
    assert stage in ["G_main", "G_reg", "G_both", "D_main", "D_reg", "D_both"]
    G_main = (stage in ["G_main", "G_both"])
    D_main = (stage in ["D_main", "D_both"])
    G_pl = (stage in ["G_reg", "G_both"]) and (self.pl_weight != 0)
    D_r1 = (stage in ["D_reg", "D_both"]) and (self.r1_gamma != 0)

    # Fail fast, before any forward pass. An unknown name previously surfaced
    # as UnboundLocalError / AttributeError deep inside the stage.
    if G_main and self.g_loss not in ("logistic", "logistic_ns", "hinge", "wgan"):
        raise ValueError(f"Unsupported g_loss: {self.g_loss!r}")
    if D_main and self.d_loss not in ("logistic", "hinge", "wgan"):
        raise ValueError(f"Unsupported d_loss: {self.d_loss!r}")

    # G_main: Maximize logits for generated images
    if G_main:
        with torch.autograd.profiler.record_function("G_main_forward"):
            gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not G_pl))  # May get synced by G_pl
            gen_logits = self.run_D(gen_img, gen_c, sync=False)
            training_stats.report("Loss/scores/fake", gen_logits)
            training_stats.report("Loss/signs/fake", gen_logits.sign())
            if self.g_loss == "logistic":
                # Saturating loss: log(1 - sigmoid(gen_logits)) == -softplus(gen_logits)
                loss_G_main = -torch.nn.functional.softplus(gen_logits)
            elif self.g_loss == "logistic_ns":
                loss_G_main = torch.nn.functional.softplus(-gen_logits)  # -log(sigmoid(gen_logits))
            else:
                # "hinge" and "wgan": the generator minimizes -D(G(z)). A clamped
                # form -clamp(1 + D(G(z)), min=0) would give zero gradient whenever
                # D(G(z)) < -1, i.e. exactly when the fakes are most clearly rejected.
                loss_G_main = -gen_logits
            training_stats.report("Loss/G/loss", loss_G_main)
        with torch.autograd.profiler.record_function("G_main_backward"):
            loss_G_main.mean().mul(gain).backward()

    # G_pl: Apply path length regularization
    if G_pl:
        with torch.autograd.profiler.record_function("G_pl_forward"):
            batch_size = gen_z.shape[0] // self.pl_batch_shrink
            gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
            # Random image-space direction, scaled by 1/sqrt(H*W)
            pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
            with torch.autograd.profiler.record_function("pl_grads"), conv2d_gradfix.no_weight_gradients():
                pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
            pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
            # Exponential moving average of the lengths, written back without gradient
            pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
            self.pl_mean.copy_(pl_mean.detach())
            pl_penalty = (pl_lengths - pl_mean).square()
            training_stats.report("Loss/pl_penalty", pl_penalty)
            loss_G_pl = pl_penalty * self.pl_weight
            training_stats.report("Loss/G/reg", loss_G_pl)
        with torch.autograd.profiler.record_function("G_pl_backward"):
            # `* 0` term adds zero but ties gen_img into the graph
            (gen_img[:, 0, 0, 0] * 0 + loss_G_pl).mean().mul(gain).backward()

    # D_main: Minimize logits for generated images
    loss_D_gen = 0
    if D_main:
        with torch.autograd.profiler.record_function("D_gen_forward"):
            gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
            gen_logits = self.run_D(gen_img, gen_c, sync=False)  # Gets synced by loss_D_real
            training_stats.report("Loss/scores/fake", gen_logits)
            training_stats.report("Loss/signs/fake", gen_logits.sign())
            if self.d_loss == "logistic":
                loss_D_gen = torch.nn.functional.softplus(gen_logits)  # -log(1 - sigmoid(gen_logits))
            elif self.d_loss == "hinge":
                loss_D_gen = torch.clamp(1.0 + gen_logits, min=0)
            else:  # "wgan"
                loss_D_gen = gen_logits
        with torch.autograd.profiler.record_function("D_gen_backward"):
            loss_D_gen.mean().mul(gain).backward()

    # D_main: Maximize logits for real images
    # D_r1: Apply R1 regularization
    if D_main or D_r1:
        name = "D_real_D_r1" if D_main and D_r1 else "D_real" if D_main else "D_r1"
        with torch.autograd.profiler.record_function(name + "_forward"):
            # Input gradients are only needed for the R1 penalty
            real_img_tmp = real_img.detach().requires_grad_(D_r1)
            real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
            training_stats.report("Loss/scores/real", real_logits)
            training_stats.report("Loss/signs/real", real_logits.sign())

            loss_D_real = 0
            if D_main:
                if self.d_loss == "logistic":
                    loss_D_real = torch.nn.functional.softplus(-real_logits)  # -log(sigmoid(real_logits))
                elif self.d_loss == "hinge":
                    loss_D_real = torch.clamp(1.0 - real_logits, min=0)
                else:  # "wgan"
                    # Squared-logit drift term penalizes large real logits. The
                    # 0.001 fallback is assumed to match TF StyleGAN's D_wgan
                    # default -- confirm.
                    wgan_epsilon = getattr(self, "wgan_epsilon", 0.001)
                    loss_D_real = -real_logits + real_logits.square() * wgan_epsilon
                training_stats.report("Loss/D/loss", loss_D_gen + loss_D_real)

            loss_D_r1 = 0
            if D_r1:
                with torch.autograd.profiler.record_function("r1_grads"), conv2d_gradfix.no_weight_gradients():
                    r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                r1_penalty = r1_grads.square().sum([1, 2, 3])
                loss_D_r1 = r1_penalty * (self.r1_gamma / 2)
                training_stats.report("Loss/r1_penalty", r1_penalty)
                training_stats.report("Loss/D/reg", loss_D_r1)

        with torch.autograd.profiler.record_function(name + "_backward"):
            # `real_logits * 0` keeps the result a tensor tied to D's output
            (real_logits * 0 + loss_D_real + loss_D_r1).mean().mul(gain).backward()
def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
    """Run one training phase and accumulate its gradients via ``backward()``.

    Like the StyleGAN2-ADA baseline, with one addition: when ``self.G_top_k``
    is set, the generator's main loss is computed only on a subset of ``k``
    fake logits. The kept fraction is ``max(G_top_k_gamma ** G.epochs,
    G_top_k_frac)``. ``self.pl_mean`` is updated in place when path-length
    regularization runs.

    Args:
        phase: One of 'Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'.
        real_img: Real images (4-D assumed by the R1 sum over dims 1-3).
        real_c: Labels for ``real_img``.
        gen_z: Latents for ``self.run_G``.
        gen_c: Labels for ``self.run_G`` / ``self.run_D``.
        sync: Forwarded only to the last G forward and to the real-image D
            forward; presumably distributed gradient sync -- confirm.
        gain: Multiplier applied to every loss before ``backward()``.
    """
    assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
    do_Gmain = (phase in ['Gmain', 'Gboth'])
    do_Dmain = (phase in ['Dmain', 'Dboth'])
    do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
    do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)

    # Gmain: Maximize logits for generated images.
    if do_Gmain:
        with torch.autograd.profiler.record_function('Gmain_forward'):
            minibatch_size = gen_z.shape[0]
            gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl))  # May get synced by Gpl.
            gen_logits = self.run_D(gen_img, gen_c, sync=False)
            training_stats.report('Loss/scores/fake', gen_logits)
            training_stats.report('Loss/signs/fake', gen_logits.sign())
            # top-k function based on: https://github.com/dvschultz/stylegan2-ada/blob/main/training/loss.py#L102
            if self.G_top_k:
                k_frac = np.maximum(self.G_top_k_gamma**self.G.epochs, self.G_top_k_frac)
                k = int(np.ceil(minibatch_size * k_frac))
                # Clamp to [1, batch]: k == 0 gives an empty loss (NaN mean) and
                # k > batch makes torch.topk raise.
                k = min(max(k, 1), minibatch_size)
                # flatten() (assumes one logit per sample) stays 1-D even for a
                # batch of one, where squeeze() would collapse to 0-d and break unsqueeze(1).
                # NOTE(review): this keeps the k *lowest* logits. The Top-k GAN
                # paper (Sinha et al., 2020) keeps the highest-scoring (most
                # realistic) samples -- confirm the intended direction.
                lowest_k_scores, _ = torch.topk(-gen_logits.flatten(), k=k)  # want smallest probabilities not largest
                gen_logits = (-lowest_k_scores).unsqueeze(1)
            loss_Gmain = torch.nn.functional.softplus(-gen_logits)  # -log(sigmoid(gen_logits))
            training_stats.report('Loss/G/loss', loss_Gmain)
        with torch.autograd.profiler.record_function('Gmain_backward'):
            loss_Gmain.mean().mul(gain).backward()

    # Gpl: Apply path length regularization.
    if do_Gpl:
        with torch.autograd.profiler.record_function('Gpl_forward'):
            batch_size = gen_z.shape[0] // self.pl_batch_shrink
            gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
            pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
            with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
                pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
            pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
            # Exponential moving average of the lengths, written back without gradient.
            pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
            self.pl_mean.copy_(pl_mean.detach())
            pl_penalty = (pl_lengths - pl_mean).square()
            training_stats.report('Loss/pl_penalty', pl_penalty)
            loss_Gpl = pl_penalty * self.pl_weight
            training_stats.report('Loss/G/reg', loss_Gpl)
        with torch.autograd.profiler.record_function('Gpl_backward'):
            (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

    # Dmain: Minimize logits for generated images.
    loss_Dgen = 0
    if do_Dmain:
        with torch.autograd.profiler.record_function('Dgen_forward'):
            gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
            gen_logits = self.run_D(gen_img, gen_c, sync=False)  # Gets synced by loss_Dreal.
            training_stats.report('Loss/scores/fake', gen_logits)
            training_stats.report('Loss/signs/fake', gen_logits.sign())
            loss_Dgen = torch.nn.functional.softplus(gen_logits)  # -log(1 - sigmoid(gen_logits))
        with torch.autograd.profiler.record_function('Dgen_backward'):
            loss_Dgen.mean().mul(gain).backward()

    # Dmain: Maximize logits for real images.
    # Dr1: Apply R1 regularization.
    if do_Dmain or do_Dr1:
        name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
        with torch.autograd.profiler.record_function(name + '_forward'):
            real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
            real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
            training_stats.report('Loss/scores/real', real_logits)
            training_stats.report('Loss/signs/real', real_logits.sign())

            loss_Dreal = 0
            if do_Dmain:
                loss_Dreal = torch.nn.functional.softplus(-real_logits)  # -log(sigmoid(real_logits))
                training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

            loss_Dr1 = 0
            if do_Dr1:
                with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                    r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                r1_penalty = r1_grads.square().sum([1, 2, 3])
                loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                training_stats.report('Loss/r1_penalty', r1_penalty)
                training_stats.report('Loss/D/reg', loss_Dr1)

        with torch.autograd.profiler.record_function(name + '_backward'):
            (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
    """Run one training phase and accumulate its gradients via ``backward()``.

    This is the StyleGAN2-ADA loss schedule plus an autoencoder reconstruction
    stage (``do_Gae``). That stage is still hard-disabled. Nothing is
    returned. ``self.pl_mean`` is updated in place when path-length
    regularization runs.

    Args:
        phase: One of 'Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'.
        real_img: Real images (4-D assumed by the R1 sum over dims 1-3).
        real_c: Labels for ``real_img``.
        gen_z: Latents for ``self.run_G``.
        gen_c: Labels for ``self.run_G`` / ``self.run_D``.
        sync: Forwarded only to the last G forward and to the real-image D
            forward; presumably distributed gradient sync -- confirm.
        gain: Multiplier applied to every loss before ``backward()``.
    """
    assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
    do_Gmain = (phase in ['Gmain', 'Gboth'])
    do_Dmain = (phase in ['Dmain', 'Dboth'])
    do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
    do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
    # Autoencoder reconstruction stage: not wired up yet, so it is kept disabled.
    do_Gae = False

    # Gmain: Maximize logits for generated images.
    if do_Gmain:
        with torch.autograd.profiler.record_function('Gmain_forward'):
            gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl and not do_Gae))  # May get synced later.
            gen_logits = self.run_D(gen_img, gen_c, sync=False)
            training_stats.report('Loss/scores/fake', gen_logits)
            training_stats.report('Loss/signs/fake', gen_logits.sign())
            loss_Gmain = torch.nn.functional.softplus(-gen_logits)  # -log(sigmoid(gen_logits))
            training_stats.report('Loss/G/loss', loss_Gmain)
        with torch.autograd.profiler.record_function('Gmain_backward'):
            loss_Gmain.mean().mul(gain).backward()

    # Gpl: Apply path length regularization.
    if do_Gpl:
        with torch.autograd.profiler.record_function('Gpl_forward'):
            batch_size = gen_z.shape[0] // self.pl_batch_shrink
            gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
            pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
            with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
                pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
            pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
            # Exponential moving average of the lengths, written back without gradient.
            pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
            self.pl_mean.copy_(pl_mean.detach())
            pl_penalty = (pl_lengths - pl_mean).square()
            training_stats.report('Loss/pl_penalty', pl_penalty)
            loss_Gpl = pl_penalty * self.pl_weight
            training_stats.report('Loss/G/reg', loss_Gpl)
        with torch.autograd.profiler.record_function('Gpl_backward'):
            (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

    # Gae: Autoencoder reconstruction of real images (disabled above).
    if do_Gae:
        with torch.autograd.profiler.record_function('Gae_forward'):
            structure, style = self.run_encoder(real_img, real_c, sync)
            rec_img = self.run_decoder(structure, style, sync)
            # TODO(review): the reconstruction metric was never defined; a
            # per-sample L1 distance is a placeholder -- confirm before enabling.
            ae_loss = (rec_img - real_img).abs().mean([1, 2, 3])
            training_stats.report('Loss/G/ae_loss', ae_loss)
            loss_Gae = ae_loss * self.G_ae_weight
            training_stats.report('Loss/G/ae_loss_weighted', loss_Gae)
        with torch.autograd.profiler.record_function('Gae_backward'):
            loss_Gae.mean().mul(gain).backward()

    # Dmain: Minimize logits for generated images.
    loss_Dgen = 0
    if do_Dmain:
        with torch.autograd.profiler.record_function('Dgen_forward'):
            gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
            gen_logits = self.run_D(gen_img, gen_c, sync=False)  # Gets synced by loss_Dreal.
            training_stats.report('Loss/scores/fake', gen_logits)
            training_stats.report('Loss/signs/fake', gen_logits.sign())
            loss_Dgen = torch.nn.functional.softplus(gen_logits)  # -log(1 - sigmoid(gen_logits))
        with torch.autograd.profiler.record_function('Dgen_backward'):
            loss_Dgen.mean().mul(gain).backward()

    # Dmain: Maximize logits for real images.
    # Dr1: Apply R1 regularization.
    if do_Dmain or do_Dr1:
        name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
        with torch.autograd.profiler.record_function(name + '_forward'):
            real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
            real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
            training_stats.report('Loss/scores/real', real_logits)
            training_stats.report('Loss/signs/real', real_logits.sign())

            loss_Dreal = 0
            if do_Dmain:
                loss_Dreal = torch.nn.functional.softplus(-real_logits)  # -log(sigmoid(real_logits))
                training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

            loss_Dr1 = 0
            if do_Dr1:
                with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                    r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                r1_penalty = r1_grads.square().sum([1, 2, 3])
                loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                training_stats.report('Loss/r1_penalty', r1_penalty)
                training_stats.report('Loss/D/reg', loss_Dr1)

        with torch.autograd.profiler.record_function(name + '_backward'):
            (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()