Example #1
0
    def forward(self, x, d_train, ae):
        """Compute the loss(es) for one training phase, selected by the flags.

        ``ae`` takes precedence over ``d_train``:

        * ``ae`` -- latent reconstruction: sample ``z``, generate ``rec`` from
          it, re-encode ``rec`` and return the mean squared error between the
          re-encoded latent and ``z``.
        * ``d_train`` -- encoder/discriminator phase: score real ``x`` and
          generated samples via ``self.encode`` and return
          ``(loss_d, loss_zg)``. ``loss_zg`` is a non-saturating generator loss
          on ``self.z_discriminator`` applied to the encoder's latent for ``x``.
        * otherwise -- generator phase: return ``(loss_g, loss_zd)``.
          ``loss_zd`` trains ``self.z_discriminator`` to separate encoded
          latents of generated images from fresh ``randn`` samples.

        Side effect: calls ``requires_grad_`` on ``self.encoder`` (and, in the
        last two phases, on ``self.z_discriminator``). Those settings persist
        after this method returns.

        Args:
            x: Batch of real inputs. Outside the ``d_train`` phase only its
                batch size (``x.shape[0]``) and ``x.device`` are used.
            d_train: Select the encoder/discriminator phase (when ``ae`` is
                false).
            ae: Select the latent-reconstruction phase.

        Returns:
            A scalar loss tensor for ``ae``; otherwise a pair of loss tensors
            as described above.
        """
        batch_size = x.shape[0]
        # Sample latents on the input's device instead of relying on a
        # globally configured default tensor type / device.
        device = x.device

        if ae:
            self.encoder.requires_grad_(True)

            z = torch.randn(batch_size, self.latent_size, device=device)
            _, rec = self.generate(z=z, noise=True)

            z_rec, _ = self.encode(rec)

            # ``z`` is a freshly sampled tensor that does not require grad, so
            # it already acts as a constant target (no detach needed).
            lae = torch.mean((z_rec - z) ** 2)
            return lae

        if d_train:
            # Generated samples are only discriminator inputs in this phase;
            # no graph back into the generator is built.
            with torch.no_grad():
                _, Xp = self.generate(count=batch_size, noise=True)

            self.encoder.requires_grad_(True)
            # Freeze the latent discriminator so backprop of ``loss_zg`` does
            # not accumulate gradients in its parameters.
            self.z_discriminator.requires_grad_(False)

            z1, d_result_real = self.encode(x)
            _, d_result_fake = self.encode(Xp.detach())

            # NOTE(review): the "_gp" loss presumably applies a gradient
            # penalty w.r.t. ``x`` (compare ``z_real`` below, which is given
            # requires_grad explicitly); confirm the caller sets
            # ``x.requires_grad``.
            loss_d = losses.discriminator_logistic_simple_gp(d_result_fake, d_result_real, x)

            zd_result_fake = self.z_discriminator(z1)
            loss_zg = losses.generator_logistic_non_saturating(zd_result_fake)
            return loss_d, loss_zg

        # Generator phase.
        self.encoder.requires_grad_(False)
        self.z_discriminator.requires_grad_(True)

        z = torch.randn(batch_size, self.latent_size, device=device)
        _, rec = self.generate(count=batch_size, z=z, noise=True)

        z_fake, d_result_fake = self.encode(rec)

        # Detach so ``loss_zd`` does not backpropagate into the generator.
        zd_result_fake = self.z_discriminator(z_fake.detach())
        # The gradient-penalty loss is computed w.r.t. these real samples,
        # hence requires_grad.
        z_real = torch.randn(batch_size, self.latent_size, device=device).requires_grad_(True)
        zd_result_real = self.z_discriminator(z_real)

        loss_g = losses.generator_logistic_non_saturating(d_result_fake)
        loss_zd = losses.discriminator_logistic_simple_gp(zd_result_fake, zd_result_real, z_real)

        return loss_g, loss_zd
Example #2
0
    def forward(self, x, lod, blend_factor, d_train, ae):
        """Compute the loss for one training phase, selected by the flags.

        ``ae`` takes precedence over ``d_train``:

        * ``ae`` -- latent reconstruction: sample ``z``, generate ``rec``
          together with its styles ``s``, re-encode ``rec`` and return an MSE.
          With ``self.z_regression`` the MSE is between the first slice
          (along dim 1) of the encoding and ``z``; otherwise it is between the
          encoding and the (detached) styles.
        * ``d_train`` -- encoder/discriminator phase: score real ``x`` and
          generated samples via ``self.encode`` and return ``loss_d``.
        * otherwise -- generator phase: with the encoder frozen, return the
          non-saturating generator loss on the encoder's score of generated
          images.

        Side effect: calls ``self.encoder.requires_grad_``. The setting
        persists after this method returns.

        Args:
            x: Batch of real inputs. Outside the ``d_train`` phase only its
                batch size (``x.shape[0]``) and ``x.device`` are used.
            lod: Forwarded unchanged to ``self.generate`` / ``self.encode``
                (presumably a level-of-detail index -- confirm).
            blend_factor: Forwarded unchanged to ``self.generate`` /
                ``self.encode`` (presumably a resolution fade-in factor --
                confirm).
            d_train: Select the encoder/discriminator phase (when ``ae`` is
                false).
            ae: Select the latent-reconstruction phase.

        Returns:
            A scalar loss tensor.
        """
        batch_size = x.shape[0]
        # Sample latents on the input's device instead of relying on a
        # globally configured default tensor type / device.
        device = x.device

        if ae:
            self.encoder.requires_grad_(True)

            z = torch.randn(batch_size, self.latent_size, device=device)
            styles, rec = self.generate(lod,
                                        blend_factor,
                                        z=z,
                                        mixing=False,
                                        noise=True,
                                        return_styles=True)

            z_rec, _ = self.encode(rec, lod, blend_factor)

            assert z_rec.shape == styles.shape

            if self.z_regression:
                # ``z`` is a freshly sampled tensor that does not require
                # grad, so it already acts as a constant target.
                lae = torch.mean((z_rec[:, 0] - z) ** 2)
            else:
                # Styles are a detached target; gradient still reaches the
                # generator through ``rec``.
                lae = torch.mean((z_rec - styles.detach()) ** 2)

            return lae

        if d_train:
            # Generated samples are only discriminator inputs in this phase;
            # no graph back into the generator is built.
            with torch.no_grad():
                Xp = self.generate(lod,
                                   blend_factor,
                                   count=batch_size,
                                   noise=True)

            self.encoder.requires_grad_(True)

            _, d_result_real = self.encode(x, lod, blend_factor)
            _, d_result_fake = self.encode(Xp.detach(), lod, blend_factor)

            # NOTE(review): the "_gp" loss presumably applies a gradient
            # penalty w.r.t. ``x``; confirm the caller sets
            # ``x.requires_grad``.
            loss_d = losses.discriminator_logistic_simple_gp(
                d_result_fake, d_result_real, x)
            return loss_d

        # Generator phase: freeze the encoder so only the generator is updated
        # by ``loss_g``.
        self.encoder.requires_grad_(False)

        z = torch.randn(batch_size, self.latent_size, device=device)
        rec = self.generate(lod,
                            blend_factor,
                            count=batch_size,
                            z=z,
                            noise=True)

        _, d_result_fake = self.encode(rec, lod, blend_factor)

        loss_g = losses.generator_logistic_non_saturating(d_result_fake)
        return loss_g