Example #1
0
    def test_create(self):
        """Check that WassersteinLoss.create() returns scalar D and G losses.

        A scalar tensor has an empty shape. Comparing the shape against an
        empty list is stricter than summing its dimensions, which would also
        accept zero-size shapes such as ``[0]`` or ``[0, 3]``.
        """
        with self.test_session():
            gan = mock_gan()

            loss = WassersteinLoss(gan, loss_config)
            d_loss, g_loss = loss.create()
            # NOTE(review): gan.ops.shape is assumed to return an iterable of
            # dimension sizes (list/tuple/TensorShape-like) — TODO confirm.
            # list() normalises it so the comparison against [] is exact.
            self.assertEqual(list(gan.ops.shape(d_loss)), [])
            self.assertEqual(list(gan.ops.shape(g_loss)), [])
Example #2
0
    def _create(self, d_real, d_fake):
        """Blend the Wasserstein, least-squares and standard GAN losses.

        Weights are read from ``self.config``:

        * ``alpha`` scales the Wasserstein terms.
        * ``beta`` scales the least-squares terms.
        * The standard terms get ``1 - min(alpha + beta, 1)``, the part of a
          unit budget left unclaimed. This is zero once ``alpha + beta >= 1``.

        The weights are not renormalised, so when ``alpha + beta > 1`` the
        combined weight exceeds 1.

        Returns:
            ``[d_loss, g_loss]``: the blended discriminator and generator losses.
        """
        cfg = self.config
        w_alpha = cfg.alpha
        w_beta = cfg.beta

        # Each parent loss implementation is called unbound on this instance,
        # so all three are computed from the same d_real / d_fake inputs.
        wgan_d, wgan_g = WassersteinLoss._create(self, d_real, d_fake)
        ls_d, ls_g = LeastSquaresLoss._create(self, d_real, d_fake)
        std_d, std_g = StandardLoss._create(self, d_real, d_fake)

        # Clamping the sum at 1 keeps the standard-loss weight non-negative.
        std_weight = 1 - min(w_alpha + w_beta, 1)

        def mix(wgan_term, ls_term, std_term):
            # Weighted sum of one (D or G) term from each of the three losses.
            return wgan_term*w_alpha + ls_term*w_beta + std_weight*std_term

        return [mix(wgan_d, ls_d, std_d), mix(wgan_g, ls_g, std_g)]
Example #3
0
 def test_config(self):
     """Check that a WassersteinLoss exposes the config it was built with.

     NOTE(review): presumes ``loss_config`` sets a truthy ``test`` entry —
     confirm against where loss_config is defined.
     """
     with self.test_session():
         gan = hg.GAN()
         wasserstein = WassersteinLoss(gan, loss_config)
         cfg = wasserstein.config
         self.assertTrue(cfg.test)