def _disc_max_logits_r1_loss(self, real_img: torch.Tensor, real_c: torch.Tensor, gain: int,
                         do_main: bool, do_reg: bool, log_values:dict) -> torch.Tensor:
        """Discriminator loss on real images: logistic term and/or R1 penalty.

        Args:
            real_img: real images; detached here and marked as requiring grad only
                when the R1 penalty is requested (``do_reg``).
            real_c: conditioning labels forwarded to ``self._disc_run``.
            gain: multiplier applied to the final mean loss.
            do_main: add ``softplus(-logits)``, i.e. ``-log(sigmoid(logits))``.
            do_reg: add the R1 penalty ``sum(grad**2) * r1_gamma / 2``, where the
                gradient is of the summed logits w.r.t. the real images.
            log_values: updated in place. ``'Loss/D/loss'`` is incremented with
                ``+=`` (so it must already exist when ``do_main`` is true); the R1
                entries ``'Loss/r1_penalty'`` and ``'Loss/D/reg'`` are overwritten.

        Returns:
            Scalar tensor ``mean(logits * 0 + main_term + r1_term) * gain``.
        """
        inputs = real_img.detach().requires_grad_(do_reg)
        logits = self._disc_run(inputs, real_c)
        self.logit_sign(logits)

        main_term = 0
        if do_main:
            main_term = F.softplus(-logits)  # == -log(sigmoid(logits))
            log_values['Loss/D/loss'] += main_term.mean()

        reg_term = 0
        if do_reg:
            # Squared norm of dD/d(real image), one value per sample.
            with conv2d_gradfix.no_weight_gradients():
                grads, = torch.autograd.grad(outputs=[logits.sum()], inputs=[inputs],
                                             create_graph=True, only_inputs=True)
            penalty = grads.square().sum([1, 2, 3])
            reg_term = penalty * (self.r1_gamma / 2)
            log_values['Loss/r1_penalty'] = penalty
            log_values['Loss/D/reg'] = reg_term

        # `logits * 0` keeps the discriminator output in the expression even when a
        # term is the int 0 — presumably so backward() always reaches D; confirm.
        total = logits * 0 + main_term + reg_term
        return total.mean().mul(gain)
    def _gen_pl_loss(self, gen_z: torch.Tensor, gen_c: torch.Tensor, gain: int) -> torch.Tensor:
        """Path-length regularization loss for the generator.

        Synthesizes images for the first ``gen_z.shape[0] // self.pl_batch_shrink``
        latents, projects them onto random noise scaled by ``1/sqrt(H*W)``, and
        penalizes the squared deviation of the per-sample gradient norm (w.r.t.
        ``gen_ws``) from the moving average ``self.pl_mean``. The moving average is
        updated in place via ``lerp`` with ``self.pl_decay``.

        Args:
            gen_z: latent codes; only the first ``batch_size`` rows are used.
            gen_c: conditioning labels, sliced the same way.
            gain: multiplier applied to the final mean loss.

        Returns:
            Scalar tensor ``mean(gen_img[:, 0, 0, 0] * 0 + pl_penalty * pl_weight) * gain``.
        """
        batch_size = gen_z.shape[0] // self.pl_batch_shrink
        gen_img, gen_ws = self._gen_run(gen_z[:batch_size], gen_c[:batch_size])
        pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
        with conv2d_gradfix.no_weight_gradients():
            pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True,
                                           only_inputs=True)[0]
        # Assumes gen_ws is (batch, num_ws, w_dim): sum over w_dim, mean over num_ws — TODO confirm.
        pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
        # Keep the moving-average buffer on the same device as the current batch.
        self.pl_mean = self.pl_mean.to(pl_lengths.device)
        pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
        self.pl_mean.copy_(pl_mean.detach())
        pl_penalty = (pl_lengths - pl_mean).square()
        loss_Gpl = pl_penalty * self.pl_weight

        # Log detached batch means. The raw tensors are per-sample and attached to the
        # create_graph=True graph; scalar loggers (e.g. Lightning's log_dict, if that is
        # what self.log_dict is) reject multi-element tensors.
        values = {'Loss/pl_penalty': pl_penalty.detach().mean(),
                  'Loss/G/reg': loss_Gpl.detach().mean()}
        self.log_dict(values)

        # `gen_img[:, 0, 0, 0] * 0` keeps the generator output in the expression so the
        # backward pass reaches it (presumably for DDP parameter participation — confirm).
        return (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain)
示例#3
0
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync,
                             gain):
        """Compute the losses for one training phase and back-propagate them.

        Every enabled loss term is scaled by ``gain``, and ``.backward()`` is called
        on it immediately, so gradients accumulate on the parameters. No optimizer
        step happens here and nothing is returned.

        Args:
            phase: one of "Gmain", "Greg", "Gboth", "Dmain", "Dreg", "Dboth".
                "main" enables the logistic adversarial loss, "reg" enables the
                regularizer (path length for G, R1 for D), and "both" enables both.
                A regularizer is skipped when its weight (``self.pl_weight`` /
                ``self.r1_gamma``) is 0.
            real_img, real_c: real images and their conditioning labels.
            gen_z, gen_c: latents and labels used to synthesize fake images.
            sync: forwarded only to the last forward pass of the network being
                trained in this phase; every other pass gets ``sync=False``.
                Presumably controls DDP gradient synchronization — confirm in
                run_G / run_D.
            gain: multiplier applied to every loss before ``.backward()``.
        """
        assert phase in ["Gmain", "Greg", "Gboth", "Dmain", "Dreg", "Dboth"]
        do_Gmain = phase in ["Gmain", "Gboth"]
        do_Dmain = phase in ["Dmain", "Dboth"]
        do_Gpl = (phase in ["Greg", "Gboth"]) and (self.pl_weight != 0)
        do_Dr1 = (phase in ["Dreg", "Dboth"]) and (self.r1_gamma != 0)

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function("Gmain_forward"):
                gen_img, _gen_ws = self.run_G(
                    gen_z, gen_c,
                    sync=(sync and not do_Gpl))  # May get synced by Gpl.
                gen_logits = self.run_D(gen_img, gen_c, sync=False)
                training_stats.report("Loss/scores/fake", gen_logits)
                training_stats.report("Loss/signs/fake", gen_logits.sign())
                # Non-saturating logistic generator loss.
                loss_Gmain = torch.nn.functional.softplus(
                    -gen_logits)  # -log(sigmoid(gen_logits))
                training_stats.report("Loss/G/loss", loss_Gmain)
            with torch.autograd.profiler.record_function("Gmain_backward"):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function("Gpl_forward"):
                # Path length is computed on a reduced batch (divided by pl_batch_shrink).
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size],
                                             gen_c[:batch_size],
                                             sync=sync)
                # Random image-space direction, scaled by 1/sqrt(H*W).
                pl_noise = torch.randn_like(gen_img) / np.sqrt(
                    gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function(
                        "pl_grads"), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[
                        (gen_img * pl_noise).sum()
                    ],
                                                   inputs=[gen_ws],
                                                   create_graph=True,
                                                   only_inputs=True)[0]
                # Assumes gen_ws is (batch, num_ws, w_dim): sum over w_dim, mean over num_ws — TODO confirm.
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                # Moving average of the path length; the buffer is updated in place.
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report("Loss/pl_penalty", pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report("Loss/G/reg", loss_Gpl)
            with torch.autograd.profiler.record_function("Gpl_backward"):
                # `gen_img[:, 0, 0, 0] * 0` keeps the generator output in the expression
                # (presumably so every G parameter participates in backward — confirm).
                (gen_img[:, 0, 0, 0] * 0 +
                 loss_Gpl).mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        # Initialized outside the branch because the Dreal section reports loss_Dgen + loss_Dreal.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function("Dgen_forward"):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
                gen_logits = self.run_D(
                    gen_img, gen_c, sync=False)  # Gets synced by loss_Dreal.
                training_stats.report("Loss/scores/fake", gen_logits)
                training_stats.report("Loss/signs/fake", gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(
                    gen_logits)  # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function("Dgen_backward"):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if do_Dmain or do_Dr1:
            name = "Dreal_Dr1" if do_Dmain and do_Dr1 else "Dreal" if do_Dmain else "Dr1"
            with torch.autograd.profiler.record_function(name + "_forward"):
                # Gradients w.r.t. the real images are needed only for R1.
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
                training_stats.report("Loss/scores/real", real_logits)
                training_stats.report("Loss/signs/real", real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(
                        -real_logits)  # -log(sigmoid(real_logits))
                    training_stats.report("Loss/D/loss",
                                          loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function(
                            "r1_grads"), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(
                            outputs=[real_logits.sum()],
                            inputs=[real_img_tmp],
                            create_graph=True,
                            only_inputs=True)[0]
                    # Per-sample squared gradient norm, summed over dims 1-3.
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report("Loss/r1_penalty", r1_penalty)
                    training_stats.report("Loss/D/reg", loss_Dr1)

            with torch.autograd.profiler.record_function(name + "_backward"):
                # `real_logits * 0` keeps D's output in the expression even when one
                # term is the int 0 (presumably so every D parameter participates — confirm).
                (real_logits * 0 + loss_Dreal +
                 loss_Dr1).mean().mul(gain).backward()
示例#4
0
    def accumulate_gradients(self, stage, real_img, real_c, gen_z, gen_c, sync, gain):
        """Compute the losses for one training stage and back-propagate them.

        The adversarial terms are selected by ``self.g_loss`` ("logistic",
        "logistic_ns", "hinge", "wgan") and ``self.d_loss`` ("logistic", "hinge",
        "wgan"). Each enabled term is scaled by ``gain`` and ``.backward()`` is
        called on it immediately; nothing is returned.

        Args:
            stage: one of "G_main", "G_reg", "G_both", "D_main", "D_reg", "D_both".
                A regularizer is skipped when its weight (``self.pl_weight`` /
                ``self.r1_gamma``) is 0.
            real_img, real_c: real images and their conditioning labels.
            gen_z, gen_c: latents and labels used to synthesize fake images.
            sync: forwarded only to the last forward pass of the network being
                trained in this stage (presumably DDP gradient sync — confirm).
            gain: multiplier applied to every loss before ``.backward()``.

        Raises:
            ValueError: if the ``g_loss`` / ``d_loss`` used by this stage is not
                one of the names listed above.
        """
        assert stage in ["G_main", "G_reg", "G_both", "D_main", "D_reg", "D_both"]
        G_main = (stage in ["G_main", "G_both"])
        D_main = (stage in ["D_main", "D_both"])
        G_pl   = (stage in ["G_reg", "G_both"]) and (self.pl_weight != 0)
        D_r1   = (stage in ["D_reg", "D_both"]) and (self.r1_gamma != 0)

        # G_main: Maximize logits for generated images
        if G_main:
            with torch.autograd.profiler.record_function("G_main_forward"):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync = (sync and not G_pl)) # May get synced by G_pl
                gen_logits = self.run_D(gen_img, gen_c, sync = False)
                training_stats.report("Loss/scores/fake", gen_logits)
                training_stats.report("Loss/signs/fake", gen_logits.sign())

                if self.g_loss == "logistic":
                    # Saturating loss: minimize log(1 - sigmoid(gen_logits)) == -softplus(gen_logits)
                    loss_G_main = -torch.nn.functional.softplus(gen_logits)
                elif self.g_loss == "logistic_ns":
                    loss_G_main = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                elif self.g_loss in ("hinge", "wgan"):
                    # Both use -D(G(z)). For hinge, clamping like the D loss would zero the
                    # gradient exactly for the fakes D rejects most strongly (logits < -1).
                    loss_G_main = -gen_logits
                else:
                    raise ValueError(f"Unknown g_loss: {self.g_loss!r}")

                training_stats.report("Loss/G/loss", loss_G_main)
            with torch.autograd.profiler.record_function("G_main_backward"):
                loss_G_main.mean().mul(gain).backward()

        # G_pl: Apply path length regularization
        if G_pl:
            with torch.autograd.profiler.record_function("G_pl_forward"):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync = sync)
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function("pl_grads"), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph = True, only_inputs = True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report("Loss/pl_penalty", pl_penalty)
                loss_G_pl = pl_penalty * self.pl_weight
                training_stats.report("Loss/G/reg", loss_G_pl)
            with torch.autograd.profiler.record_function("G_pl_backward"):
                (gen_img[:, 0, 0, 0] * 0 + loss_G_pl).mean().mul(gain).backward()

        # D_main: Minimize logits for generated images
        loss_D_gen = 0
        if D_main:
            with torch.autograd.profiler.record_function("D_gen_forward"):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync = False)
                gen_logits = self.run_D(gen_img, gen_c, sync = False) # Gets synced by loss_D_real
                training_stats.report("Loss/scores/fake", gen_logits)
                training_stats.report("Loss/signs/fake", gen_logits.sign())

                if self.d_loss == "logistic":
                    loss_D_gen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
                elif self.d_loss == "hinge":
                    loss_D_gen = torch.clamp(1.0 + gen_logits, min = 0)
                elif self.d_loss == "wgan":
                    loss_D_gen = gen_logits
                else:
                    raise ValueError(f"Unknown d_loss: {self.d_loss!r}")

            with torch.autograd.profiler.record_function("D_gen_backward"):
                loss_D_gen.mean().mul(gain).backward()

        # D_main: Maximize logits for real images
        # D_r1: Apply R1 regularization
        if D_main or D_r1:
            name = "D_real_D_r1" if D_main and D_r1 else "D_real" if D_main else "D_r1"
            with torch.autograd.profiler.record_function(name + "_forward"):
                real_img_tmp = real_img.detach().requires_grad_(D_r1)
                real_logits = self.run_D(real_img_tmp, real_c, sync = sync)
                training_stats.report("Loss/scores/real", real_logits)
                training_stats.report("Loss/signs/real", real_logits.sign())

                loss_D_real = 0
                if D_main:
                    if self.d_loss == "logistic":
                        loss_D_real = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
                    elif self.d_loss == "hinge":
                        loss_D_real = torch.clamp(1.0 - real_logits, min = 0)
                    elif self.d_loss == "wgan":
                        # Small drift penalty on the real logits. Defaults to 0.001 when the
                        # module defines no `wgan_epsilon` (mirrors the TF StyleGAN D_wgan
                        # default — confirm the intended value).
                        wgan_epsilon = getattr(self, "wgan_epsilon", 0.001)
                        loss_D_real = -real_logits + real_logits.square() * wgan_epsilon
                    else:
                        raise ValueError(f"Unknown d_loss: {self.d_loss!r}")

                    training_stats.report("Loss/D/loss", loss_D_gen + loss_D_real)

                loss_D_r1 = 0
                if D_r1:
                    with torch.autograd.profiler.record_function("r1_grads"), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph = True, only_inputs = True)[0]
                    r1_penalty = r1_grads.square().sum([1,2,3])
                    loss_D_r1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report("Loss/r1_penalty", r1_penalty)
                    training_stats.report("Loss/D/reg", loss_D_r1)

            with torch.autograd.profiler.record_function(name + "_backward"):
                # `real_logits * 0` keeps D's output in the expression even when a term is the int 0.
                (real_logits * 0 + loss_D_real + loss_D_r1).mean().mul(gain).backward()
示例#5
0
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync,
                             gain):
        """Compute the losses for one training phase and back-propagate them.

        Same as the standard logistic StyleGAN2-ADA loss, except that when
        ``self.G_top_k`` is set, the generator's adversarial loss uses only ``k``
        of the fake logits. ``k`` is ``ceil(batch * max(G_top_k_gamma ** G.epochs,
        G_top_k_frac))``. Each enabled term is scaled by ``gain`` and
        ``.backward()`` is called on it immediately; nothing is returned.

        Args:
            phase: one of "Gmain", "Greg", "Gboth", "Dmain", "Dreg", "Dboth".
            real_img, real_c: real images and their conditioning labels.
            gen_z, gen_c: latents and labels used to synthesize fake images.
            sync: forwarded only to the last forward pass of the network being
                trained in this phase (presumably DDP gradient sync — confirm).
            gain: multiplier applied to every loss before ``.backward()``.
        """
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        do_Gmain = (phase in ['Gmain', 'Gboth'])
        do_Dmain = (phase in ['Dmain', 'Dboth'])
        do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
        do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                minibatch_size = gen_z.shape[0]
                gen_img, _gen_ws = self.run_G(
                    gen_z, gen_c,
                    sync=(sync and not do_Gpl))  # May get synced by Gpl.
                gen_logits = self.run_D(gen_img, gen_c, sync=False)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())

                # top-k function based on: https://github.com/dvschultz/stylegan2-ada/blob/main/training/loss.py#L102
                if self.G_top_k:
                    k_frac = np.maximum(self.G_top_k_gamma**self.G.epochs,
                                        self.G_top_k_frac)
                    k = int(np.ceil(minibatch_size * k_frac))
                    # flatten() instead of squeeze(): with a (1, 1) batch, squeeze() gives a
                    # 0-d tensor and the later unsqueeze(1) raises. Assumes one logit per
                    # sample, i.e. gen_logits is (N, 1) — TODO confirm.
                    lowest_k_scores, _ = torch.topk(
                        -gen_logits.flatten(),
                        k=k)  # want smallest probabilities not largest
                    gen_logits = (-lowest_k_scores).unsqueeze(1)
                    # NOTE(review): the top-k GAN paper (Sinha et al., 2020) keeps the fakes
                    # D scores as *most* realistic (largest logits); this keeps the smallest,
                    # as in the referenced fork — confirm this is intended.

                loss_Gmain = torch.nn.functional.softplus(
                    -gen_logits)  # -log(sigmoid(gen_logits))
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size],
                                             gen_c[:batch_size],
                                             sync=sync)
                pl_noise = torch.randn_like(gen_img) / np.sqrt(
                    gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function(
                        'pl_grads'), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[
                        (gen_img * pl_noise).sum()
                    ],
                                                   inputs=[gen_ws],
                                                   create_graph=True,
                                                   only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                (gen_img[:, 0, 0, 0] * 0 +
                 loss_Gpl).mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
                gen_logits = self.run_D(
                    gen_img, gen_c, sync=False)  # Gets synced by loss_Dreal.
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(
                    gen_logits)  # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if do_Dmain or do_Dr1:
            name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(
                        -real_logits)  # -log(sigmoid(real_logits))
                    training_stats.report('Loss/D/loss',
                                          loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function(
                            'r1_grads'), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(
                            outputs=[real_logits.sum()],
                            inputs=[real_img_tmp],
                            create_graph=True,
                            only_inputs=True)[0]
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                (real_logits * 0 + loss_Dreal +
                 loss_Dr1).mean().mul(gain).backward()
示例#6
0
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
        """Compute the losses for one training phase and back-propagate them.

        Standard logistic StyleGAN2-ADA loss (G main + path length, D main + R1),
        plus a disabled placeholder for an autoencoder ("Gae") term. Each enabled
        term is scaled by ``gain`` and ``.backward()`` is called on it
        immediately; nothing is returned.

        Args:
            phase: one of "Gmain", "Greg", "Gboth", "Dmain", "Dreg", "Dboth".
            real_img, real_c: real images and their conditioning labels.
            gen_z, gen_c: latents and labels used to synthesize fake images.
            sync: forwarded only to the last forward pass of the network being
                trained in this phase (presumably DDP gradient sync — confirm).
            gain: multiplier applied to every loss before ``.backward()``.
        """
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        do_Gmain = (phase in ['Gmain', 'Gboth'])
        do_Dmain = (phase in ['Dmain', 'Dboth'])
        do_Gpl   = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
        do_Dr1   = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
        do_Gae   = False  # Autoencoder loss is not implemented yet; see the Gae branch below.

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl and not do_Gae)) # May get synced later.
                gen_logits = self.run_D(gen_img, gen_c, sync=False)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

        # Gae: autoencoder reconstruction loss (placeholder, disabled by do_Gae = False).
        if do_Gae:
            # TODO: intended flow appears to be: encode real_img with self.run_encoder into
            # (structure, style), decode with self.run_decoder, and penalize the
            # reconstruction, weighted by self.G_ae_weight. The reconstruction loss itself
            # is not defined anywhere, so fail loudly instead of reporting or
            # back-propagating leftover path-length tensors (pl_penalty / loss_Gpl / gen_img).
            raise NotImplementedError("Gae (autoencoder) loss is not implemented yet")

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
                gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if do_Dmain or do_Dr1:
            name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                    r1_penalty = r1_grads.square().sum([1,2,3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()