def _test_cyclegan_helper(self, create_gan_model_fn):
  """Builds a CycleGAN model via `create_gan_model_fn` and checks its losses.

  Verifies that `train.cyclegan_loss` returns a `namedtuples.CycleGANLoss`.
  After initializing global variables, it evaluates the four loss tensors and
  asserts that, in each direction, the generator loss exceeds the
  discriminator loss and every evaluated loss is a scalar.

  Args:
    create_gan_model_fn: Zero-argument callable returning the model that is
      passed to `train.cyclegan_loss`.
  """
  cyclegan_model = create_gan_model_fn()
  cyclegan_loss = train.cyclegan_loss(cyclegan_model)
  self.assertIsInstance(cyclegan_loss, namedtuples.CycleGANLoss)

  # Check values.
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    fetches = [
        cyclegan_loss.loss_x2y.generator_loss,
        cyclegan_loss.loss_x2y.discriminator_loss,
        cyclegan_loss.loss_y2x.generator_loss,
        cyclegan_loss.loss_y2x.discriminator_loss,
    ]
    # All four losses are fetched in a single run call.
    x2y_gen, x2y_dis, y2x_gen, y2x_dis = sess.run(fetches)

    for gen_value, dis_value in ((x2y_gen, x2y_dis), (y2x_gen, y2x_dis)):
      self.assertGreater(gen_value, dis_value)
    for evaluated in (x2y_gen, x2y_dis, y2x_gen, y2x_dis):
      self.assertTrue(np.isscalar(evaluated))
def _test_cyclegan_helper(self, create_gan_model_fn):
  """Sanity-checks `train.cyclegan_loss` on a model from `create_gan_model_fn`.

  The returned loss must be a `namedtuples.CycleGANLoss`. Once global
  variables are initialized, the x->y and y->x generator losses must each be
  greater than the matching discriminator loss, and all four evaluated losses
  must be scalars.

  Args:
    create_gan_model_fn: Zero-argument callable that builds the model under
      test.
  """
  loss = train.cyclegan_loss(create_gan_model_fn())
  self.assertIsInstance(loss, namedtuples.CycleGANLoss)

  x2y = loss.loss_x2y
  y2x = loss.loss_y2x
  # Check values.
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    evaluated = sess.run([
        x2y.generator_loss,
        x2y.discriminator_loss,
        y2x.generator_loss,
        y2x.discriminator_loss,
    ])
    gen_x2y, dis_x2y, gen_y2x, dis_y2x = evaluated
    self.assertGreater(gen_x2y, dis_x2y)
    self.assertGreater(gen_y2x, dis_y2x)
    for value in evaluated:
      self.assertTrue(np.isscalar(value))
def test_output_type_callable_cyclegan(self):
  """Checks `cyclegan_loss` on the model from `create_callable_cyclegan_model`.

  The returned loss must be a `namedtuples.CycleGANLoss`. After the call with
  `add_summaries=True`, the graph's SUMMARIES collection must be non-empty.
  """
  loss = train.cyclegan_loss(
      create_callable_cyclegan_model(), add_summaries=True)
  self.assertIsInstance(loss, namedtuples.CycleGANLoss)
  # assertNotEmpty states the intent directly and reports a clearer failure
  # than assertGreater(len(...), 0); it matches the other summary checks.
  self.assertNotEmpty(ops.get_collection(ops.GraphKeys.SUMMARIES))
def test_cyclegan_output_type(self, get_gan_model_fn):
  """Checks the output type of `cyclegan_loss` and that summaries exist.

  The returned loss must be a `namedtuples.CycleGANLoss`. After the call with
  `add_summaries=True`, the graph's SUMMARIES collection must be non-empty.

  Args:
    get_gan_model_fn: Zero-argument callable returning the model under test.
  """
  loss = train.cyclegan_loss(get_gan_model_fn(), add_summaries=True)
  self.assertIsInstance(loss, namedtuples.CycleGANLoss)
  # assertNotEmpty states the intent directly and reports a clearer failure
  # than assertGreater(len(...), 0).
  self.assertNotEmpty(ops.get_collection(ops.GraphKeys.SUMMARIES))
def test_cyclegan_output_type(self, get_gan_model_fn):
  """Verifies the loss type and summary creation of `train.cyclegan_loss`.

  Args:
    get_gan_model_fn: Zero-argument callable returning the model under test.
  """
  gan_model = get_gan_model_fn()
  cyclegan_loss = train.cyclegan_loss(gan_model, add_summaries=True)
  self.assertIsInstance(cyclegan_loss, namedtuples.CycleGANLoss)

  # With add_summaries=True, at least one summary op is expected in the graph.
  summary_ops = ops.get_collection(ops.GraphKeys.SUMMARIES)
  self.assertNotEmpty(summary_ops)