def _test_multiple_steps_helper(self, get_hooks_fn_fn):
    """Check ``gan_train`` with a multi-step generator/discriminator schedule.

    Builds train ops via ``self._gan_train_ops`` and asks for 3 generator and
    4 discriminator train steps through the hooks produced by
    ``get_hooks_fn_fn``. Training is stopped with
    ``StopAtStepHook(num_steps=1)``, and the final step returned by
    ``train.gan_train`` is checked.

    Args:
        get_hooks_fn_fn: Factory called with a ``namedtuples.GANTrainSteps``;
            its result is passed to ``train.gan_train`` as ``get_hooks_fn``.
    """
    # NOTE(review): this class appears to define `_test_multiple_steps_helper`
    # twice with identical bodies; the later definition silently replaces this
    # one. Consider deleting the duplicate.
    ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
    schedule = namedtuples.GANTrainSteps(
        generator_train_steps=3,
        discriminator_train_steps=4,
    )
    hooks_fn = get_hooks_fn_fn(schedule)
    stop_hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=1)]
    result = train.gan_train(
        ops,
        get_hooks_fn=hooks_fn,
        logdir='',
        hooks=stop_hooks,
    )

    self.assertTrue(np.isscalar(result))
    # Presumably 1 for the outer step, plus 10 per generator step and 100 per
    # discriminator step (see the *_add arguments above). TODO confirm
    # against `_gan_train_ops`.
    expected = 1 + 3 * 10 + 4 * 100
    self.assertEqual(expected, result)
def _test_multiple_steps_helper(self, get_hooks_fn_fn):
    """Verify the final step after one ``gan_train`` run with 3 G / 4 D steps.

    Args:
        get_hooks_fn_fn: Callable taking a ``namedtuples.GANTrainSteps`` and
            returning the ``get_hooks_fn`` handed to ``train.gan_train``.
    """
    # NOTE(review): `_test_multiple_steps_helper` appears to be defined twice
    # in this class with identical bodies; only the last definition is used.
    gen_steps, dis_steps = 3, 4
    gen_add, dis_add = 10, 100

    gan_ops = self._gan_train_ops(generator_add=gen_add, discriminator_add=dis_add)
    step_config = namedtuples.GANTrainSteps(
        generator_train_steps=gen_steps, discriminator_train_steps=dis_steps)

    final = train.gan_train(
        gan_ops,
        get_hooks_fn=get_hooks_fn_fn(step_config),
        logdir='',
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])

    self.assertTrue(np.isscalar(final))
    # Presumably: the outer step contributes 1, and each sub-step contributes
    # its *_add amount. TODO confirm against `_gan_train_ops`.
    self.assertEqual(1 + gen_steps * gen_add + dis_steps * dis_add, final)
def _test_run_helper(self, create_gan_model_fn):
    """Build a GAN from ``create_gan_model_fn``, train two steps, check the result.

    Args:
        create_gan_model_fn: Zero-argument callable returning the model that
            is passed to ``train.gan_loss`` and ``train.gan_train_ops``.
    """
    # NOTE(review): this class appears to define `_test_run_helper` twice with
    # identical bodies; the later definition silently replaces this one.
    random_seed.set_random_seed(1234)  # fixed seed so runs are repeatable
    gan_model = create_gan_model_fn()
    gan_loss = train.gan_loss(gan_model)

    learning_rate = 1.0
    gen_opt = gradient_descent.GradientDescentOptimizer(learning_rate)
    dis_opt = gradient_descent.GradientDescentOptimizer(learning_rate)
    ops = train.gan_train_ops(gan_model, gan_loss, gen_opt, dis_opt)

    stop_hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=2)]
    last_step = train.gan_train(ops, logdir='', hooks=stop_hooks)

    self.assertTrue(np.isscalar(last_step))
    self.assertEqual(2, last_step)
def _test_run_helper(self, create_gan_model_fn):
    """Run ``train.gan_train`` for two steps on a freshly built GAN model.

    Asserts that the returned final step is a scalar equal to 2.

    Args:
        create_gan_model_fn: Zero-argument callable producing the GAN model.
    """
    # NOTE(review): `_test_run_helper` appears to be defined twice in this
    # class with identical bodies; only the last definition takes effect.
    random_seed.set_random_seed(1234)
    model = create_gan_model_fn()
    model_loss = train.gan_loss(model)

    optimizers = (
        gradient_descent.GradientDescentOptimizer(1.0),  # generator
        gradient_descent.GradientDescentOptimizer(1.0),  # discriminator
    )
    gan_ops = train.gan_train_ops(model, model_loss, *optimizers)

    step = train.gan_train(
        gan_ops,
        logdir='',
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)],
    )

    self.assertTrue(np.isscalar(step))
    self.assertEqual(2, step)
def test_patchgan(self, create_gan_model_fn):
    """Ensure that patch-based discriminators work end-to-end.

    The body was previously a line-for-line copy of ``_test_run_helper``. It
    now delegates to that helper, so both tests stay in sync: seed,
    build the model via ``create_gan_model_fn``, train two steps with
    gradient-descent optimizers, and assert the final step is the scalar 2.

    Args:
        create_gan_model_fn: Zero-argument callable returning the GAN model
            under test (presumably one using a patch-based discriminator;
            TODO confirm against the parameterization).
    """
    # NOTE(review): `test_patchgan` appears to be defined twice in this class;
    # the later definition silently replaces this one.
    self._test_run_helper(create_gan_model_fn)
def test_patchgan(self, create_gan_model_fn):
    """Ensure that patch-based discriminators work end-to-end.

    Reuses ``_test_run_helper`` instead of duplicating its body. That helper
    seeds, builds the model via ``create_gan_model_fn``, trains two steps
    and asserts the final step is the scalar 2.

    Args:
        create_gan_model_fn: Zero-argument callable returning the GAN model
            under test (presumably one with a patch-based discriminator;
            TODO confirm against the parameterization).
    """
    # NOTE(review): `test_patchgan` appears to be defined twice in this class;
    # only the last definition is collected by the test runner.
    self._test_run_helper(create_gan_model_fn)