Exemplo n.º 1
0
 def foward_d(self, source, output, target):
     """Discriminator objective: adversarial loss on fake and real samples.

     NOTE: the method name is misspelled ("foward"); it is kept unchanged so
     existing callers keep working.

     Args:
         source: conditioning input shared by the fake and real pairs.
         output: sample treated as fake (presumably the generator output).
         target: sample treated as real.

     Returns:
         ``(d_loss, d_terms)`` where ``d_terms`` maps ``'d/real'``,
         ``'d/fake'`` and ``'d/loss'`` (their sum) to loss values, and
         ``d_loss`` is the same object as ``d_terms['d/loss']``.
     """
     cond = self.conditional
     fake_input = losses.conditional_input(source, output, cond)
     real_input = losses.conditional_input(source, target, cond)
     # Label False marks generated samples, True marks real ones.
     loss_fake = self.adv_loss(self.dnet(fake_input), False)
     loss_real = self.adv_loss(self.dnet(real_input), True)
     total = loss_real + loss_fake
     return total, {'d/real': loss_real, 'd/fake': loss_fake, 'd/loss': total}
Exemplo n.º 2
0
 def foward_d(self, source, output, target):
     """Discriminator objective with a gradient-penalty regulariser.

     NOTE: the method name is misspelled ("foward"); it is kept unchanged so
     existing callers keep working.

     Args:
         source: conditioning input shared by the fake and real pairs.
         output: sample treated as fake (presumably the generator output).
         target: sample treated as real.

     Returns:
         ``(d_loss, d_terms)`` where ``d_terms`` maps ``'d/real'``,
         ``'d/fake'``, ``'d/gp'`` and ``'d/loss'`` to loss values.
         ``'d/loss'`` is ``d_real + d_fake + self.gp_weight * d_gp``, and
         ``d_loss`` is that same object.
     """
     fake = losses.conditional_input(source, output, self.conditional)
     real = losses.conditional_input(source, target, self.conditional)
     # Label False marks generated samples, True marks real ones.
     d_fake = self.adv_loss(self.dnet(fake), False)
     d_real = self.adv_loss(self.dnet(real), True)
     # The helper is called ``gradient_penalty`` elsewhere in the code base;
     # the old call used the misspelling ``gradient_panelty``. Prefer the
     # correct name, and fall back to the legacy one for older versions of
     # the losses module.
     gradient_penalty = getattr(losses, 'gradient_penalty', None)
     if gradient_penalty is None:
         gradient_penalty = losses.gradient_panelty
     d_gp = gradient_penalty(self.dnet, real, fake)
     d_loss = d_real + d_fake + self.gp_weight * d_gp
     d_terms = {'d/real': d_real, 'd/fake': d_fake, 'd/gp': d_gp, 'd/loss': d_loss}
     return d_loss, d_terms
Exemplo n.º 3
0
def test_loss_terms():
    """Smoke-test the adversarial loss helpers and the gradient penalty.

    Builds real and fake conditional inputs from random images. It then
    checks that each loss helper runs and yields only finite values. The test
    relies on the module-level ``g``, ``d``, ``conditional`` and ``device``
    defined elsewhere in this test module.
    """
    source = torch.randn(10, 3, 128, 128, device=device())
    target = torch.randn(10, 3, 128, 128, device=device())
    output = g(source)

    real = L.conditional_input(source, target, conditional)
    fake = L.conditional_input(source, output, conditional)

    # torch.sigmoid is the canonical spelling; F.sigmoid was deprecated in
    # PyTorch releases.
    results = {
        'ce': L.adversarial_ce_loss(torch.sigmoid(d(fake)), 1),
        'ls': L.adversarial_ls_loss(d(fake), 1),
        'w': L.adversarial_w_loss(d(fake), True),
        'gp': L.gradient_penalty(d, real, fake),
    }
    # The original test only checked that the calls do not raise. Also
    # reject NaN/inf results. Assumes each helper returns a tensor or a
    # Python number — TODO confirm.
    for name, value in results.items():
        assert torch.isfinite(torch.as_tensor(value)).all(), name
Exemplo n.º 4
0
 def forward_g(self, source, output, target):
     """Generator objective: weighted reconstruction loss plus adversarial loss.

     Args:
         source: conditioning input paired with ``output`` for the
             discriminator.
         output: generated sample (presumably the generator's prediction).
         target: reference sample that ``output`` is reconstructed against.

     Returns:
         ``(g_loss, g_terms)`` where ``g_terms`` maps ``'g/recon'``,
         ``'g/adv'`` and ``'g/loss'`` to loss values. ``'g/loss'`` is
         ``recon * self.recon_weight + adv``, and ``g_loss`` is that same
         object.
     """
     reconstruction = self.recon_loss(output, target)
     # Score the generated pair with label True, i.e. push the
     # discriminator to treat it as real.
     disc_input = losses.conditional_input(source, output, self.conditional)
     adversarial = self.adv_loss(self.dnet(disc_input), True)
     total = reconstruction * self.recon_weight + adversarial
     terms = {'g/recon': reconstruction, 'g/adv': adversarial, 'g/loss': total}
     return total, terms