def setUp(self):
  """Builds the GANHead under test.

  Both loss functions are `dummy_loss`, and both optimizers are
  `training.GradientDescentOptimizer` with a learning rate of 1.0.
  """
  super(GANHeadTest, self).setUp()
  self.gan_head = head.gan_head(
      generator_loss_fn=dummy_loss,
      discriminator_loss_fn=dummy_loss,
      generator_optimizer=training.GradientDescentOptimizer(1.0),
      discriminator_optimizer=training.GradientDescentOptimizer(1.0))
  # assertIsInstance reports the offending type on failure; the previous
  # assertTrue(isinstance(...)) only reported "False is not true".
  self.assertIsInstance(self.gan_head, head.GANHead)
def setUp(self):
  """Creates the GANHead under test, with eval metrics from `self.get_metrics`.

  Both losses use `dummy_loss`. Both optimizers are
  `training.GradientDescentOptimizer` with a learning rate of 1.0.
  """
  super(GANHeadTest, self).setUp()
  # Build the generator optimizer first, then the discriminator one,
  # matching the argument order of the head constructor below.
  gen_optimizer = training.GradientDescentOptimizer(1.0)
  dis_optimizer = training.GradientDescentOptimizer(1.0)
  self.gan_head = head.gan_head(
      generator_loss_fn=dummy_loss,
      discriminator_loss_fn=dummy_loss,
      generator_optimizer=gen_optimizer,
      discriminator_optimizer=dis_optimizer,
      get_eval_metric_ops_fn=self.get_metrics)
  self.assertIsInstance(self.gan_head, head.GANHead)
def _model_fn(features, labels, mode):
  """Model function that builds a GAN head and delegates to `_gan_model_fn`.

  The optimizers, loss functions, hook factory, network functions and
  summary flags all come from the enclosing scope.

  Args:
    features: Forwarded unchanged to `_gan_model_fn`.
    labels: Forwarded unchanged to `_gan_model_fn`.
    mode: Forwarded unchanged to `_gan_model_fn`.

  Returns:
    Whatever `_gan_model_fn` returns for these arguments.
  """
  def _materialize(optimizer_or_factory):
    # Optimizers may be given directly or as zero-argument factories.
    # Factories are invoked here, on each model_fn call. NOTE(review):
    # presumably this ensures the optimizer is created in the graph this
    # model_fn builds -- confirm.
    if callable(optimizer_or_factory):
      return optimizer_or_factory()
    return optimizer_or_factory

  # Generator optimizer is resolved before the discriminator optimizer.
  gen_opt = _materialize(generator_optimizer)
  dis_opt = _materialize(discriminator_optimizer)
  model_head = head_lib.gan_head(
      generator_loss_fn,
      discriminator_loss_fn,
      gen_opt,
      dis_opt,
      use_loss_summaries,
      get_hooks_fn=get_hooks_fn)
  return _gan_model_fn(
      features, labels, mode, generator_fn, discriminator_fn, model_head,
      add_summaries)