Example #1
    def backward_D(self, netD, real, fake):
        # Fake; stop backprop to the generator by detaching fake_B.
        if self.use_ada:
            pred_fake = netD(self.augment(fake.detach()))
            # Real.
            pred_real = netD(self.augment(real))
        else:
            pred_fake = netD(fake.detach())
            # Real.
            pred_real = netD(real)
        loss_D_fake, _ = self.criterionGAN(pred_fake, False)
        loss_D_real, _ = self.criterionGAN(pred_real, True)
        if self.use_ada:
            sum_logits = 0
            for pred in pred_real:
                sum_logits += pred.mean()
            training_stats.report('Loss/signs/real', (sum_logits / len(pred_real)).sign())
        # Combined loss.
        loss_D = loss_D_fake + loss_D_real
        loss_D.backward()
        return loss_D, [loss_D_fake, loss_D_real]
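
The pattern to note in Example 1 is fake.detach(): the discriminator loss must not propagate into the generator. A self-contained toy check of that behavior (the linear netG/netD below are throwaway stand-ins, not the networks from the example):

import torch

netG = torch.nn.Linear(4, 4)
netD = torch.nn.Linear(4, 1)

real = torch.randn(8, 4)
fake = netG(torch.randn(8, 4))

pred_fake = netD(fake.detach())  # detach: no gradient flows back into netG
pred_real = netD(real)
loss_D = torch.nn.functional.softplus(pred_fake).mean() \
       + torch.nn.functional.softplus(-pred_real).mean()
loss_D.backward()

assert netG.weight.grad is None       # generator untouched, as intended
assert netD.weight.grad is not None   # discriminator received gradients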
Example #2
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
        with torch.autograd.profiler.record_function('Emain_forward'):
            codes = self.run_E(real_img, real_c, sync=sync)
            # Debug artifact: dumps the encoder output to a file named 'codes',
            # overwriting it on every call (np.savetxt expects a 1-D or 2-D array).
            np.savetxt('codes', codes.detach().cpu().numpy())
            gen_img, _gen_ws = self.run_G(codes, real_c, sync=sync)
            codes_gen = self.run_E(gen_img, real_c, sync=sync)  # re-encoded codes; unused in the loss below
            l1 = self.lpips_loss(gen_img, real_img).mean()
            l2 = self.mse_loss(gen_img, real_img)
            loss = self.lambda1*l1 + self.lambda2*l2
            training_stats.report('Loss/E/loss1', l1)
            training_stats.report('Loss/E/loss2', l2)
            training_stats.report('Loss/E/loss', loss)
        with torch.autograd.profiler.record_function('Emain_backward'):
            loss.backward()

#----------------------------------------------------------------------------
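
Example 2 assumes self.lpips_loss and self.mse_loss exist on the loss object. One plausible construction (an assumption, not taken from the source) uses the lpips package:

import lpips
import torch

lpips_loss = lpips.LPIPS(net='vgg')  # perceptual distance; returns per-sample values
mse_loss = torch.nn.MSELoss()        # scalar pixel-space term

x = torch.rand(2, 3, 64, 64) * 2 - 1  # LPIPS expects inputs scaled to [-1, 1]
y = torch.rand(2, 3, 64, 64) * 2 - 1
print(lpips_loss(x, y).mean().item(), mse_loss(x, y).item())

This mirrors self.lambda1*l1 + self.lambda2*l2 in the example, where the two lambdas weight the perceptual and pixel terms.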
Example #3
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync,
                             gain):
        assert phase in ["Gmain", "Greg", "Gboth", "Dmain", "Dreg", "Dboth"]
        do_Gmain = phase in ["Gmain", "Gboth"]
        do_Dmain = phase in ["Dmain", "Dboth"]
        do_Gpl = (phase in ["Greg", "Gboth"]) and (self.pl_weight != 0)
        do_Dr1 = (phase in ["Dreg", "Dboth"]) and (self.r1_gamma != 0)

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function("Gmain_forward"):
                gen_img, _gen_ws = self.run_G(
                    gen_z, gen_c,
                    sync=(sync and not do_Gpl))  # May get synced by Gpl.
                gen_logits = self.run_D(gen_img, gen_c, sync=False)
                training_stats.report("Loss/scores/fake", gen_logits)
                training_stats.report("Loss/signs/fake", gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(
                    -gen_logits)  # -log(sigmoid(gen_logits))
                training_stats.report("Loss/G/loss", loss_Gmain)
            with torch.autograd.profiler.record_function("Gmain_backward"):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function("Gpl_forward"):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size],
                                             gen_c[:batch_size],
                                             sync=sync)
                pl_noise = torch.randn_like(gen_img) / np.sqrt(
                    gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function(
                        "pl_grads"), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(
                        outputs=[(gen_img * pl_noise).sum()],
                        inputs=[gen_ws],
                        create_graph=True,
                        only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report("Loss/pl_penalty", pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report("Loss/G/reg", loss_Gpl)
            with torch.autograd.profiler.record_function("Gpl_backward"):
                (gen_img[:, 0, 0, 0] * 0 +
                 loss_Gpl).mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function("Dgen_forward"):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
                gen_logits = self.run_D(
                    gen_img, gen_c, sync=False)  # Gets synced by loss_Dreal.
                training_stats.report("Loss/scores/fake", gen_logits)
                training_stats.report("Loss/signs/fake", gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(
                    gen_logits)  # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function("Dgen_backward"):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if do_Dmain or do_Dr1:
            name = "Dreal_Dr1" if do_Dmain and do_Dr1 else "Dreal" if do_Dmain else "Dr1"
            with torch.autograd.profiler.record_function(name + "_forward"):
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
                training_stats.report("Loss/scores/real", real_logits)
                training_stats.report("Loss/signs/real", real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(
                        -real_logits)  # -log(sigmoid(real_logits))
                    training_stats.report("Loss/D/loss",
                                          loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function(
                            "r1_grads"), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(
                            outputs=[real_logits.sum()],
                            inputs=[real_img_tmp],
                            create_graph=True,
                            only_inputs=True)[0]
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report("Loss/r1_penalty", r1_penalty)
                    training_stats.report("Loss/D/reg", loss_Dr1)

            with torch.autograd.profiler.record_function(name + "_backward"):
                (real_logits * 0 + loss_Dreal +
                 loss_Dr1).mean().mul(gain).backward()
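
A self-contained sanity check of the R1 term in Example 3: for a linear discriminator D(x) = (w * x).sum(), the per-sample penalty ||grad_x D(x)||^2 equals ||w||^2 regardless of the input. All names below are local to the snippet:

import torch

w = torch.randn(3, 8, 8)
x = torch.randn(1, 3, 8, 8, requires_grad=True)

real_logits = (w * x).sum()  # toy linear "discriminator"
r1_grads = torch.autograd.grad(outputs=[real_logits], inputs=[x], create_graph=True)[0]
r1_penalty = r1_grads.square().sum([1, 2, 3])

assert torch.allclose(r1_penalty, w.square().sum())  # == ||w||^2 for any x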
Example #4
    def accumulate_gradients(self, stage, real_img, real_c, gen_z, gen_c, sync, gain):
        assert stage in ["G_main", "G_reg", "G_both", "D_main", "D_reg", "D_both"]
        G_main = (stage in ["G_main", "G_both"])
        D_main = (stage in ["D_main", "D_both"])
        G_pl   = (stage in ["G_reg", "G_both"]) and (self.pl_weight != 0)
        D_r1   = (stage in ["D_reg", "D_both"]) and (self.r1_gamma != 0)

        # G_main: Maximize logits for generated images
        if G_main:
            with torch.autograd.profiler.record_function("G_main_forward"):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync = (sync and not G_pl)) # May get synced by G_pl
                gen_logits = self.run_D(gen_img, gen_c, sync = False)
                training_stats.report("Loss/scores/fake", gen_logits)
                training_stats.report("Loss/signs/fake", gen_logits.sign())
    
                if self.g_loss == "logistic":
                    loss_G_main = -torch.nn.functional.softplus(gen_logits) # log(1 - sigmoid(gen_logits)), saturating
                elif self.g_loss == "logistic_ns":
                    loss_G_main = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits)), non-saturating
                elif self.g_loss == "hinge":
                    loss_G_main = -torch.clamp(1.0 + gen_logits, min = 0)
                elif self.g_loss == "wgan":
                    loss_G_main = -gen_logits
                else:
                    raise ValueError(f"Unknown g_loss: {self.g_loss!r}")

                training_stats.report("Loss/G/loss", loss_G_main)
            with torch.autograd.profiler.record_function("G_main_backward"):
                loss_G_main.mean().mul(gain).backward()

        # G_pl: Apply path length regularization
        if G_pl:
            with torch.autograd.profiler.record_function("G_pl_forward"):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync = sync)
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function("pl_grads"), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph = True, only_inputs = True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report("Loss/pl_penalty", pl_penalty)
                loss_G_pl = pl_penalty * self.pl_weight
                training_stats.report("Loss/G/reg", loss_G_pl)
            with torch.autograd.profiler.record_function("G_pl_backward"):
                (gen_img[:, 0, 0, 0] * 0 + loss_G_pl).mean().mul(gain).backward()

        # D_main: Minimize logits for generated images
        loss_D_gen = 0
        if D_main:
            with torch.autograd.profiler.record_function("D_gen_forward"):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync = False)
                gen_logits = self.run_D(gen_img, gen_c, sync = False) # Gets synced by loss_D_real
                training_stats.report("Loss/scores/fake", gen_logits)
                training_stats.report("Loss/signs/fake", gen_logits.sign())
            
                if self.d_loss == "logistic":
                    loss_D_gen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
                elif self.d_loss == "hinge":
                    loss_D_gen = torch.clamp(1.0 + gen_logits, min = 0)
                elif self.d_loss == "wgan":
                    loss_D_gen = gen_logits
                else:
                    raise ValueError(f"Unknown d_loss: {self.d_loss!r}")

            with torch.autograd.profiler.record_function("D_gen_backward"):
                loss_D_gen.mean().mul(gain).backward()

        # D_main: Maximize logits for real images
        # D_r1: Apply R1 regularization
        if D_main or D_r1:
            name = "D_real_D_r1" if D_main and D_r1 else "D_real" if D_main else "D_r1"
            with torch.autograd.profiler.record_function(name + "_forward"):
                real_img_tmp = real_img.detach().requires_grad_(D_r1)
                real_logits = self.run_D(real_img_tmp, real_c, sync = sync)
                training_stats.report("Loss/scores/real", real_logits)
                training_stats.report("Loss/signs/real", real_logits.sign())

                loss_D_real = 0
                if D_main:
                    if self.d_loss == "logistic":
                        loss_D_real = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
                    elif self.d_loss == "hinge":
                        loss_D_real = torch.clamp(1.0 - real_logits, min = 0)
                    elif self.d_loss == "wgan":
                        # Original called tf.square on a torch tensor and used an undefined
                        # wgan_epsilon; fixed to torch, assuming the epsilon-drift
                        # coefficient is stored on self.
                        loss_D_real = -real_logits + real_logits.square() * self.wgan_epsilon
                    else:
                        raise ValueError(f"Unknown d_loss: {self.d_loss!r}")

                    training_stats.report("Loss/D/loss", loss_D_gen + loss_D_real)

                loss_D_r1 = 0
                if D_r1:
                    with torch.autograd.profiler.record_function("r1_grads"), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph = True, only_inputs = True)[0]
                    r1_penalty = r1_grads.square().sum([1,2,3])
                    loss_D_r1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report("Loss/r1_penalty", r1_penalty)
                    training_stats.report("Loss/D/reg", loss_D_r1)

            with torch.autograd.profiler.record_function(name + "_backward"):
                (real_logits * 0 + loss_D_real + loss_D_r1).mean().mul(gain).backward()
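
The "# -log(sigmoid(...))" comments that recur in these examples follow from two softplus identities, softplus(-x) = -log(sigmoid(x)) and softplus(x) = -log(1 - sigmoid(x)). A quick numeric confirmation:

import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, steps=11)
assert torch.allclose(F.softplus(-x), -torch.log(torch.sigmoid(x)), atol=1e-6)
assert torch.allclose(F.softplus(x), -torch.log(1 - torch.sigmoid(x)), atol=1e-5)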
Example #5
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync,
                             gain):
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        do_Gmain = (phase in ['Gmain', 'Gboth'])
        do_Dmain = (phase in ['Dmain', 'Dboth'])
        do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
        do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                minibatch_size = gen_z.shape[0]
                gen_img, _gen_ws = self.run_G(
                    gen_z, gen_c,
                    sync=(sync and not do_Gpl))  # May get synced by Gpl.
                gen_logits = self.run_D(gen_img, gen_c, sync=False)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())

                # top-k function based on: https://github.com/dvschultz/stylegan2-ada/blob/main/training/loss.py#L102
                if self.G_top_k:
                    D_fake_scores = gen_logits
                    k_frac = np.maximum(self.G_top_k_gamma**self.G.epochs,
                                        self.G_top_k_frac)
                    k = int(np.ceil(minibatch_size * k_frac))
                    lowest_k_scores, _ = torch.topk(
                        torch.squeeze(D_fake_scores), k=k,
                        largest=False)  # want the k smallest scores, not the largest
                    gen_logits = torch.unsqueeze(lowest_k_scores, dim=1)

                loss_Gmain = torch.nn.functional.softplus(
                    -gen_logits)  # -log(sigmoid(gen_logits))
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size],
                                             gen_c[:batch_size],
                                             sync=sync)
                pl_noise = torch.randn_like(gen_img) / np.sqrt(
                    gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function(
                        'pl_grads'), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(
                        outputs=[(gen_img * pl_noise).sum()],
                        inputs=[gen_ws],
                        create_graph=True,
                        only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                (gen_img[:, 0, 0, 0] * 0 +
                 loss_Gpl).mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
                gen_logits = self.run_D(
                    gen_img, gen_c, sync=False)  # Gets synced by loss_Dreal.
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(
                    gen_logits)  # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if do_Dmain or do_Dr1:
            name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(
                        -real_logits)  # -log(sigmoid(real_logits))
                    training_stats.report('Loss/D/loss',
                                          loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function(
                            'r1_grads'), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(
                            outputs=[real_logits.sum()],
                            inputs=[real_img_tmp],
                            create_graph=True,
                            only_inputs=True)[0]
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                (real_logits * 0 + loss_Dreal +
                 loss_Dr1).mean().mul(gain).backward()
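
The top-k filter in Example 5 keeps only the k lowest discriminator scores, so the generator trains on its least convincing samples. In isolation, using the largest=False form from the rewrite above:

import torch

scores = torch.tensor([2.0, -1.0, 0.5, -3.0])  # per-sample discriminator logits
k = 2
lowest_k, _ = torch.topk(scores, k=k, largest=False)
print(lowest_k)  # tensor([-3., -1.])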
Example #6
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        do_Gmain = (phase in ['Gmain', 'Gboth'])
        do_Dmain = (phase in ['Dmain', 'Dboth'])
        do_Gpl   = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
        do_Dr1   = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
        do_Gae   = False

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl and not do_Gae)) # May get synced later.
                gen_logits = self.run_D(gen_img, gen_c, sync=False)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

        # Gae: Apply autoencoder reconstruction loss. Disabled above (do_Gae = False);
        # the original block referenced undefined names (ae_loss, pl_penalty) and
        # reported/backpropagated loss_Gpl, so a plain MSE reconstruction objective
        # is assumed here as the intent.
        if do_Gae:
            with torch.autograd.profiler.record_function('Gae_forward'):
                structure, style = self.run_encoder(real_img, real_c, sync)
                rec_img = self.run_decoder(structure, style, sync)
                ae_loss = torch.nn.functional.mse_loss(rec_img, real_img)
                training_stats.report('Loss/G/ae_loss', ae_loss)
                loss_Gae = ae_loss * self.G_ae_weight
                training_stats.report('Loss/G/ae_loss_weighted', loss_Gae)
            with torch.autograd.profiler.record_function('Gae_backward'):
                (rec_img[:, 0, 0, 0] * 0 + loss_Gae).mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
                gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if do_Dmain or do_Dr1:
            name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                    r1_penalty = r1_grads.square().sum([1,2,3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
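
The gen_img[:, 0, 0, 0] * 0 + loss pattern that recurs in these backward passes adds a zero-scaled slice of the network output to the loss. It appears to be a workaround to make backward() traverse the full output tensor (e.g., so DistributedDataParallel's gradient hooks fire for every parameter) without changing the loss value or its gradients. A self-contained check of that last claim:

import torch

x = torch.randn(4, 3, 8, 8, requires_grad=True)
loss = x.square().mean()

plain = loss.mean()
guarded = (x[:, 0, 0, 0] * 0 + loss).mean()
assert torch.allclose(plain, guarded)  # the zero-scaled term contributes nothing

g_plain = torch.autograd.grad(plain, x, retain_graph=True)[0]
g_guarded = torch.autograd.grad(guarded, x)[0]
assert torch.allclose(g_plain, g_guarded)  # gradients are unchanged too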
Example #7
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain,
                             cur_nimg):
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        do_Gmain = (phase in ['Gmain', 'Gboth'])
        do_Dmain = (phase in ['Dmain', 'Dboth'])
        if phase in ['Dreg', 'Greg']: return  # no regularization needed for PG

        # Blurring schedule: fade blur_sigma linearly from blur_init_sigma to 0
        # over the first blur_fade_kimg thousand images.
        if self.blur_fade_kimg > 1:
            blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma
        else:
            blur_sigma = 0

        if do_Gmain:

            # Gmain: Maximize logits for generated images.
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img = self.run_G(gen_z, gen_c)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
                loss_Gmain = (-gen_logits).mean()

                # Logging
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                training_stats.report('Loss/G/loss', loss_Gmain)

            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.backward()

        if do_Dmain:

            # Dmain: Minimize logits for generated images.
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img = self.run_G(gen_z, gen_c, update_emas=True)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
                loss_Dgen = (F.relu(torch.ones_like(gen_logits) +
                                    gen_logits)).mean()

                # Logging
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())

            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.backward()

            # Dmain: Maximize logits for real images.
            with torch.autograd.profiler.record_function('Dreal_forward'):
                real_img_tmp = real_img.detach().requires_grad_(False)
                real_logits = self.run_D(real_img_tmp,
                                         real_c,
                                         blur_sigma=blur_sigma)
                loss_Dreal = (
                    F.relu(torch.ones_like(real_logits) - real_logits)).mean()

                # Logging
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())
                training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

            with torch.autograd.profiler.record_function('Dreal_backward'):
                loss_Dreal.backward()
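
Example 7's blur schedule fades blur_sigma linearly from blur_init_sigma to zero over the first blur_fade_kimg thousand images. A standalone trace with assumed values:

blur_init_sigma, blur_fade_kimg = 2.0, 200  # illustrative values, not from the source
for cur_nimg in [0, 100_000, 200_000, 300_000]:
    blur_sigma = max(1 - cur_nimg / (blur_fade_kimg * 1e3), 0) * blur_init_sigma
    print(cur_nimg, blur_sigma)  # -> 2.0, 1.0, 0.0, 0.0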