def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
  """Runs one global step of sequential GAN training and checks the loss.

  The generator train op is the constant 3.0 and the discriminator train op
  is the constant 2.0. With 3 generator steps and 4 discriminator steps per
  global step, the test expects a scalar loss of 17.0. That equals
  3 * 3.0 + 4 * 2.0, so the test presumes the per-run losses are summed.
  Confirm this against `train.get_sequential_train_steps`.

  NOTE(review): another definition with this exact name appears alongside
  this one. If both live in the same class, the later one replaces the
  earlier, so only one of them actually runs.
  """
  global_step = training_util.create_global_step()

  # Build the ops in the same order as before: generator constant,
  # discriminator constant, then the global-step increment.
  gen_op = constant_op.constant(3.0)
  dis_op = constant_op.constant(2.0)
  inc_op = global_step.assign_add(1)
  gan_ops = namedtuples.GANTrainOps(
      generator_train_op=gen_op,
      discriminator_train_op=dis_op,
      global_step_inc_op=inc_op)

  step_counts = namedtuples.GANTrainSteps(
      generator_train_steps=3, discriminator_train_steps=4)
  step_fn = train.get_sequential_train_steps(step_counts)

  # An empty logdir and a single global step keep the run minimal.
  loss = slim_learning.train(
      train_op=gan_ops,
      logdir='',
      global_step=global_step,
      number_of_steps=1,
      train_step_fn=step_fn)

  self.assertTrue(np.isscalar(loss))
  self.assertEqual(17.0, loss)
def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
  """Checks the loss reported by sequential GAN train steps for one step.

  Each generator run yields `gen_loss` and each discriminator run yields
  `dis_loss`. One global step performs `gen_steps` generator runs and
  `dis_steps` discriminator runs. The expected value, 3 * 3.0 + 4 * 2.0
  (exactly 17.0), presumes those per-run losses are added together. Confirm
  this against `train.get_sequential_train_steps`.

  NOTE(review): another definition with this exact name appears alongside
  this one. If both live in the same class, the later one replaces the
  earlier, so only one of them actually runs.
  """
  gen_loss, dis_loss = 3.0, 2.0
  gen_steps, dis_steps = 3, 4

  step = training_util.create_global_step()
  train_ops = namedtuples.GANTrainOps(
      generator_train_op=constant_op.constant(gen_loss),
      discriminator_train_op=constant_op.constant(dis_loss),
      global_step_inc_op=step.assign_add(1))
  sequential_fn = train.get_sequential_train_steps(
      namedtuples.GANTrainSteps(
          generator_train_steps=gen_steps,
          discriminator_train_steps=dis_steps))

  final_loss = slim_learning.train(
      train_op=train_ops,
      logdir='',
      global_step=step,
      number_of_steps=1,
      train_step_fn=sequential_fn)

  self.assertTrue(np.isscalar(final_loss))
  self.assertEqual(gen_steps * gen_loss + dis_steps * dis_loss, final_loss)