Exemplo n.º 1
0
    def _test_multiple_steps_helper(self, get_hooks_fn_fn):
        """Check the global step reached when sub-step counts are > 1.

        Builds train ops whose generator/discriminator updates add 10 and
        100 respectively, asks for 3 generator and 4 discriminator steps,
        and runs `train.gan_train` under a one-step `StopAtStepHook`.

        Args:
          get_hooks_fn_fn: Callable that receives a
            `namedtuples.GANTrainSteps` and returns the `get_hooks_fn`
            handed to `train.gan_train`.
        """
        gen_steps, dis_steps = 3, 4
        gen_add, dis_add = 10, 100

        ops = self._gan_train_ops(generator_add=gen_add,
                                  discriminator_add=dis_add)
        steps = namedtuples.GANTrainSteps(generator_train_steps=gen_steps,
                                          discriminator_train_steps=dis_steps)

        # Build the hooks factory before the stop hook to keep the original
        # evaluation order of the gan_train keyword arguments.
        hooks_fn = get_hooks_fn_fn(steps)
        stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps=1)
        result = train.gan_train(ops,
                                 get_hooks_fn=hooks_fn,
                                 logdir='',
                                 hooks=[stop_hook])

        # Presumably: 1 step from the stop hook's budget plus each sub-step's
        # increment (3 * 10 + 4 * 100) -- TODO confirm against _gan_train_ops.
        expected = 1 + gen_steps * gen_add + dis_steps * dis_add
        self.assertTrue(np.isscalar(result))
        self.assertEqual(expected, result)
Exemplo n.º 2
0
  def _test_multiple_steps_helper(self, get_hooks_fn_fn):
    """Verify the final global step for multi-step generator/discriminator runs.

    Uses train ops from `self._gan_train_ops` (generator adds 10,
    discriminator adds 100), 3 generator and 4 discriminator steps, and a
    `StopAtStepHook` with `num_steps=1`.

    Args:
      get_hooks_fn_fn: Callable mapping a `namedtuples.GANTrainSteps` to the
        `get_hooks_fn` argument of `train.gan_train`.
    """
    n_gen, n_dis = 3, 4
    inc_gen, inc_dis = 10, 100

    ops = self._gan_train_ops(generator_add=inc_gen, discriminator_add=inc_dis)
    step_counts = namedtuples.GANTrainSteps(
        generator_train_steps=n_gen, discriminator_train_steps=n_dis)

    hooks_factory = get_hooks_fn_fn(step_counts)
    stopper = basic_session_run_hooks.StopAtStepHook(num_steps=1)
    last_step = train.gan_train(
        ops, get_hooks_fn=hooks_factory, logdir='', hooks=[stopper])

    # Expected value mirrors the literal 1 + 3 * 10 + 4 * 100.
    want = 1 + n_gen * inc_gen + n_dis * inc_dis
    self.assertTrue(np.isscalar(last_step))
    self.assertEqual(want, last_step)
Exemplo n.º 3
0
    def _test_run_helper(self, create_gan_model_fn):
        """Build a GAN model, train it for two steps, and check the result.

        Args:
          create_gan_model_fn: Zero-argument callable returning the model
            passed to `train.gan_loss` and `train.gan_train_ops`.
        """
        # Fix the graph-level seed before any model construction.
        random_seed.set_random_seed(1234)
        gan_model = create_gan_model_fn()
        gan_loss_value = train.gan_loss(gan_model)

        # Generator optimizer is built first, then discriminator, as before.
        learning_rate = 1.0
        ops = train.gan_train_ops(
            gan_model,
            gan_loss_value,
            gradient_descent.GradientDescentOptimizer(learning_rate),
            gradient_descent.GradientDescentOptimizer(learning_rate))

        num_steps = 2
        result = train.gan_train(
            ops,
            logdir='',
            hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=num_steps)])
        self.assertTrue(np.isscalar(result))
        self.assertEqual(num_steps, result)
Exemplo n.º 4
0
  def _test_run_helper(self, create_gan_model_fn):
    """Run `train.gan_train` for two steps on a freshly built model.

    Args:
      create_gan_model_fn: Zero-argument callable producing the GAN model.
    """
    random_seed.set_random_seed(1234)
    gan_model = create_gan_model_fn()
    gan_loss_value = train.gan_loss(gan_model)

    gen_optimizer = gradient_descent.GradientDescentOptimizer(1.0)
    dis_optimizer = gradient_descent.GradientDescentOptimizer(1.0)
    ops = train.gan_train_ops(
        gan_model, gan_loss_value, gen_optimizer, dis_optimizer)

    stop_after = 2
    stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps=stop_after)
    last_step = train.gan_train(ops, logdir='', hooks=[stop_hook])

    self.assertTrue(np.isscalar(last_step))
    self.assertEqual(stop_after, last_step)
Exemplo n.º 5
0
    def test_patchgan(self, create_gan_model_fn):
        """Check end-to-end training with a patch-based discriminator.

        Args:
          create_gan_model_fn: Zero-argument callable returning the model
            under test (presumably one with a patch-based discriminator).
        """
        random_seed.set_random_seed(1234)
        patch_model = create_gan_model_fn()
        patch_loss = train.gan_loss(patch_model)

        lr = 1.0
        gen_opt = gradient_descent.GradientDescentOptimizer(lr)
        dis_opt = gradient_descent.GradientDescentOptimizer(lr)
        ops = train.gan_train_ops(patch_model, patch_loss, gen_opt, dis_opt)

        target_step = 2
        reached = train.gan_train(
            ops,
            logdir='',
            hooks=[
                basic_session_run_hooks.StopAtStepHook(num_steps=target_step)
            ])
        self.assertTrue(np.isscalar(reached))
        self.assertEqual(target_step, reached)
Exemplo n.º 6
0
  def test_patchgan(self, create_gan_model_fn):
    """Train a patch-based-discriminator GAN for two steps end-to-end.

    Args:
      create_gan_model_fn: Zero-argument callable that builds the model.
    """
    # Seed the graph before the model's variables are created.
    random_seed.set_random_seed(1234)
    patch_model = create_gan_model_fn()
    patch_loss = train.gan_loss(patch_model)

    ops = train.gan_train_ops(
        patch_model,
        patch_loss,
        gradient_descent.GradientDescentOptimizer(1.0),
        gradient_descent.GradientDescentOptimizer(1.0))

    n_steps = 2
    halt = basic_session_run_hooks.StopAtStepHook(num_steps=n_steps)
    reached = train.gan_train(ops, logdir='', hooks=[halt])

    self.assertTrue(np.isscalar(reached))
    self.assertEqual(n_steps, reached)