コード例 #1
0
 def test_metric(self):
     """Check that a created SupervisedLoss exposes a 'd_class_loss' metric."""
     with self.test_session():
         gan = mock_gan()
         loss = SupervisedLoss(gan, loss_config)
         # create() is invoked before reading metrics; its return values are
         # not needed here. Presumably create() is what populates the
         # metrics -- confirm against SupervisedLoss.
         loss.create()
         metrics = loss.metrics
         # Identity check instead of `!= None`: avoids any overloaded
         # comparison operator on the metric value and gives a clearer
         # failure message.
         self.assertIsNotNone(metrics['d_class_loss'])
コード例 #2
0
 def test_create(self):
     """Check the values returned by SupervisedLoss.create().

     Expects the discriminator loss to have shape [1] and the generator
     loss to be None.
     """
     with self.test_session():
         gan = mock_gan()
         loss = SupervisedLoss(gan, loss_config)
         d_loss, g_loss = loss.create()
         d_shape = gan.ops.shape(d_loss)
         self.assertEqual(d_shape, [1])
         # Identity check rather than equality: None is a singleton, and
         # `==` could be intercepted by an overloaded operator.
         self.assertIsNone(g_loss)
コード例 #3
0
ファイル: cli.py プロジェクト: halflife2/HyperGAN
 def add_supervised_loss(self):
     """Optionally wrap the GAN's loss with a supervised class loss.

     When ``self.args.classloss`` is truthy, a SupervisedLoss is built from
     the GAN and its loss config, ``self.gan.loss`` is replaced by a
     MultiComponent that combines the supervised loss with the previous
     loss using ``combine='add'``, and the supervised loss is then created.
     Otherwise only a status message is printed.
     """
     if not self.args.classloss:
         print("[discriminator] Class loss is off.  Unsupervised learning mode activated.")
         return

     print("[discriminator] Class loss is on.  Semi-supervised learning mode activated.")
     class_loss = SupervisedLoss(self.gan, self.gan.config.loss)
     combined = MultiComponent(
         components=[class_loss, self.gan.loss],
         combine='add',
     )
     self.gan.loss = combined
     # NOTE(review): the original author tagged this step "EWW" --
     # presumably the create-after-wrapping ordering; confirm before
     # refactoring.
     class_loss.create()
コード例 #4
0
 def test_config(self):
     """Check that SupervisedLoss exposes a truthy ``test`` entry on its config."""
     with self.test_session():
         gan = mock_gan()
         supervised = SupervisedLoss(gan, loss_config)
         config = supervised.config
         self.assertTrue(config.test)