def forward(self, x, d_train, ae):
    """Compute the loss(es) for one training phase, selected by the flags.

    Exactly one phase runs; ``ae`` takes precedence over ``d_train``:

    * ``ae`` true -- latent reconstruction. Sample ``z``, generate images,
      re-encode them and return ``mean((E(G(z)) - z) ** 2)`` as a scalar.
    * ``d_train`` true -- discriminator step. Returns ``(loss_d, loss_zg)``:

      * ``loss_d`` is computed by ``losses.discriminator_logistic_simple_gp``
        from the second output of ``self.encode`` on real ``x`` and on
        generated images.
      * ``loss_zg`` is a non-saturating loss on ``self.z_discriminator``
        applied to the encodings of ``x``, with the z-discriminator frozen.
    * otherwise -- generator step. Returns ``(loss_g, loss_zd)``:

      * ``loss_g`` is the non-saturating loss on the encoder's second output
        for freshly generated images, with the encoder frozen.
      * ``loss_zd`` is the z-discriminator loss: encodings of generated
        images count as "fake" and prior samples as "real".

    ``requires_grad`` is toggled on ``self.encoder`` / ``self.z_discriminator``
    before the forward passes, so autograd only tracks the modules meant to be
    updated in the chosen phase.

    NOTE(review): the encoder is pushed to make encodings of *real* ``x``
    fool the z-discriminator (``d_train`` phase). The z-discriminator itself,
    however, is trained on encodings of *generated* images. Confirm this
    asymmetry is intended.

    Args:
        x: Batch of real samples. Outside the ``d_train`` phase only its
            batch size (``x.shape[0]``) and device are used.
        d_train: Select the discriminator step (ignored when ``ae`` is true).
        ae: Select the latent-reconstruction step.

    Returns:
        A scalar loss tensor for the ``ae`` phase, otherwise a pair of losses.
    """
    batch_size = x.shape[0]

    if ae:
        self.encoder.requires_grad_(True)

        # Sample on the input's device so the method does not depend on the
        # global default tensor type matching where the model lives.
        z = torch.randn(batch_size, self.latent_size, device=x.device)
        _, rec = self.generate(z=z, noise=True)
        z_rec, _ = self.encode(rec)

        # z is fresh noise with no autograd history, so no detach is needed.
        return torch.mean((z_rec - z) ** 2)

    if d_train:
        # Generated images are constants for this step.
        with torch.no_grad():
            _, fake = self.generate(count=batch_size, noise=True)

        self.encoder.requires_grad_(True)
        self.z_discriminator.requires_grad_(False)

        z_real_enc, d_result_real = self.encode(x)
        _, d_result_fake = self.encode(fake.detach())
        # x is passed as the third argument, presumably for a gradient
        # penalty on the real inputs (see the losses module to confirm).
        loss_d = losses.discriminator_logistic_simple_gp(d_result_fake, d_result_real, x)

        # The encoder plays "generator" against the frozen z-discriminator.
        zd_result_fake = self.z_discriminator(z_real_enc)
        loss_zg = losses.generator_logistic_non_saturating(zd_result_fake)
        return loss_d, loss_zg

    # Generator step.
    z = torch.randn(batch_size, self.latent_size, device=x.device)

    self.encoder.requires_grad_(False)
    self.z_discriminator.requires_grad_(True)

    _, rec = self.generate(count=batch_size, z=z, noise=True)
    z_fake, d_result_fake = self.encode(rec)

    # Detach so the z-discriminator loss does not flow back into the generator.
    zd_result_fake = self.z_discriminator(z_fake.detach())

    # requires_grad on the prior samples lets the loss differentiate w.r.t.
    # them (presumably for its gradient penalty term).
    z_prior = torch.randn(batch_size, self.latent_size, device=x.device).requires_grad_(True)
    zd_result_real = self.z_discriminator(z_prior)

    loss_g = losses.generator_logistic_non_saturating(d_result_fake)
    loss_zd = losses.discriminator_logistic_simple_gp(zd_result_fake, zd_result_real, z_prior)
    return loss_g, loss_zd
def forward(self, x, lod, blend_factor, d_train, ae):
    """Compute the loss for one training phase, selected by the flags.

    Exactly one phase runs; ``ae`` takes precedence over ``d_train``:

    * ``ae`` true -- latent reconstruction. Sample ``z``, generate images
      together with their styles ``s``, re-encode the images, and return the
      mean squared error against ``z`` when ``self.z_regression`` is set
      (comparing only index 0 along dim 1 of the encoding), or against the
      detached ``s`` otherwise.
    * ``d_train`` true -- discriminator step. Returns
      ``losses.discriminator_logistic_simple_gp`` computed from the second
      output of ``self.encode`` on real ``x`` and on generated images.
    * otherwise -- generator step. Returns the non-saturating generator loss
      on the encoder's second output for freshly generated images, with the
      encoder frozen.

    Args:
        x: Batch of real samples. Outside the ``d_train`` phase only its
            batch size (``x.shape[0]``) and device are used.
        lod: Passed through unchanged to ``self.generate`` / ``self.encode``
            (presumably the current level of detail / resolution stage).
        blend_factor: Passed through unchanged to ``self.generate`` /
            ``self.encode`` (presumably the fade-in weight between stages).
        d_train: Select the discriminator step (ignored when ``ae`` is true).
        ae: Select the latent-reconstruction step.

    Returns:
        A scalar loss tensor.

    Raises:
        AssertionError: In the ``ae`` phase, if the encoder output and the
            generator styles differ in shape.
    """
    batch_size = x.shape[0]

    if ae:
        self.encoder.requires_grad_(True)

        # Sample on the input's device so the method does not depend on the
        # global default tensor type matching where the model lives.
        z = torch.randn(batch_size, self.latent_size, device=x.device)
        s, rec = self.generate(lod, blend_factor, z=z, mixing=False, noise=True, return_styles=True)
        encoded, _ = self.encode(rec, lod, blend_factor)

        assert encoded.shape == s.shape, (
            f"encoder output shape {tuple(encoded.shape)} does not match "
            f"style shape {tuple(s.shape)}"
        )

        if self.z_regression:
            # Regress only the first entry along dim 1 onto the sampled z;
            # assumes encoded is (batch, n, latent_size) -- TODO confirm.
            return torch.mean((encoded[:, 0] - z) ** 2)
        # Styles are the target here, not something to optimise.
        return torch.mean((encoded - s.detach()) ** 2)

    if d_train:
        # Generated images are constants for this step.
        with torch.no_grad():
            fake = self.generate(lod, blend_factor, count=batch_size, noise=True)

        self.encoder.requires_grad_(True)

        _, d_result_real = self.encode(x, lod, blend_factor)
        _, d_result_fake = self.encode(fake.detach(), lod, blend_factor)
        # x is passed as the third argument, presumably for a gradient
        # penalty on the real inputs (see the losses module to confirm).
        return losses.discriminator_logistic_simple_gp(d_result_fake, d_result_real, x)

    # Generator step. z is fresh noise with no autograd history, so neither a
    # no_grad context nor a detach is needed.
    z = torch.randn(batch_size, self.latent_size, device=x.device)

    # Freeze the encoder so only the generator receives gradients.
    self.encoder.requires_grad_(False)

    rec = self.generate(lod, blend_factor, count=batch_size, z=z, noise=True)
    _, d_result_fake = self.encode(rec, lod, blend_factor)
    return losses.generator_logistic_non_saturating(d_result_fake)