Example #1
    def _test_acgan_helper(self, create_gan_model_fn):
        model = create_gan_model_fn()
        loss = train.gan_loss(model)
        loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0)
        loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
        self.assertTrue(isinstance(loss, namedtuples.GANLoss))
        self.assertTrue(isinstance(loss_ac_gen, namedtuples.GANLoss))
        self.assertTrue(isinstance(loss_ac_dis, namedtuples.GANLoss))

        # Check values.
        with self.test_session(use_gpu=True) as sess:
            variables.global_variables_initializer().run()
            loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run([
                loss.generator_loss, loss_ac_gen.generator_loss,
                loss_ac_dis.generator_loss
            ])
            loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run([
                loss.discriminator_loss, loss_ac_gen.discriminator_loss,
                loss_ac_dis.discriminator_loss
            ])

        self.assertTrue(loss_gen_np < loss_dis_np)
        self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
        self.assertTrue(np.isscalar(loss_ac_dis_gen_np))
        self.assertTrue(np.isscalar(loss_ac_gen_dis_np))
        self.assertTrue(np.isscalar(loss_ac_dis_dis_np))
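These snippets appear to come from TF-GAN's test suite in TensorFlow 1.x `contrib`, and they omit the imports they depend on. The following sketch lists the imports the examples seem to assume; the exact module paths vary between TensorFlow releases, so treat this as an approximation rather than a definitive list.

# Approximate import set inferred from the names used in these examples;
# module paths may differ between TensorFlow 1.x versions.
import numpy as np

from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python import train
from tensorflow.contrib.gan.python.features.python import random_tensor_pool
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util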
Example #2
  def _test_acgan_helper(self, create_gan_model_fn):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0)
    loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
    self.assertTrue(isinstance(loss, namedtuples.GANLoss))
    self.assertTrue(isinstance(loss_ac_gen, namedtuples.GANLoss))
    self.assertTrue(isinstance(loss_ac_dis, namedtuples.GANLoss))

    # Check values.
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run(
          [loss.generator_loss,
           loss_ac_gen.generator_loss,
           loss_ac_dis.generator_loss])
      loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run(
          [loss.discriminator_loss,
           loss_ac_gen.discriminator_loss,
           loss_ac_dis.discriminator_loss])

    self.assertTrue(loss_gen_np < loss_dis_np)
    self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
    self.assertTrue(np.isscalar(loss_ac_dis_gen_np))
    self.assertTrue(np.isscalar(loss_ac_gen_dis_np))
    self.assertTrue(np.isscalar(loss_ac_dis_dis_np))
Example #3
    def test_doesnt_crash_when_in_nested_scope(self):
        with variable_scope.variable_scope('outer_scope'):
            gan_model = train.gan_model(
                generator_model,
                discriminator_model,
                real_data=array_ops.zeros([1, 2]),
                generator_inputs=random_ops.random_normal([1, 2]))

            # This should work inside a scope.
            train.gan_loss(gan_model, gradient_penalty_weight=1.0)

        # This should also work outside a scope.
        train.gan_loss(gan_model, gradient_penalty_weight=1.0)
Example #4
  def test_doesnt_crash_when_in_nested_scope(self):
    with variable_scope.variable_scope('outer_scope'):
      gan_model = train.gan_model(
          generator_model,
          discriminator_model,
          real_data=array_ops.zeros([1, 2]),
          generator_inputs=random_ops.random_normal([1, 2]))

      # This should work inside a scope.
      train.gan_loss(gan_model, gradient_penalty_weight=1.0)

    # This should also work outside a scope.
    train.gan_loss(gan_model, gradient_penalty_weight=1.0)
Example #5
    def _test_tensor_pool_helper(self, create_gan_model_fn):
        model = create_gan_model_fn()
        if isinstance(model, namedtuples.InfoGANModel):

            def tensor_pool_fn_impl(input_values):
                generated_data, generator_inputs = input_values
                output_values = random_tensor_pool.tensor_pool(
                    [generated_data] + generator_inputs, pool_size=5)
                return output_values[0], output_values[1:]

            tensor_pool_fn = tensor_pool_fn_impl
        else:

            def tensor_pool_fn_impl(input_values):
                return random_tensor_pool.tensor_pool(input_values,
                                                      pool_size=5)

            tensor_pool_fn = tensor_pool_fn_impl
        loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
        self.assertTrue(isinstance(loss, namedtuples.GANLoss))

        # Check values.
        with self.test_session(use_gpu=True) as sess:
            variables.global_variables_initializer().run()
            for _ in range(10):
                sess.run([loss.generator_loss, loss.discriminator_loss])
Example #6
    def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
        model = create_gan_model_fn()
        loss = train.gan_loss(model)

        g_opt = get_sync_optimizer()
        d_opt = get_sync_optimizer()
        train_ops = train.gan_train_ops(model,
                                        loss,
                                        g_opt,
                                        d_opt,
                                        summarize_gradients=True,
                                        colocate_gradients_with_ops=True)

        sequential_train_hooks = train.get_sequential_train_hooks()(train_ops)
        self.assertLen(sequential_train_hooks, 4)
        sync_opts = [
            hook._sync_optimizer
            for hook in sequential_train_hooks if isinstance(
                hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
        ]
        self.assertLen(sync_opts, 2)
        self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

        joint_train_hooks = train.get_joint_train_hooks()(train_ops)
        self.assertLen(joint_train_hooks, 5)
        sync_opts = [
            hook._sync_optimizer for hook in joint_train_hooks if isinstance(
                hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
        ]
        self.assertLen(sync_opts, 2)
        self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
Example #7
  def _test_tensor_pool_helper(self, create_gan_model_fn):
    model = create_gan_model_fn()
    if isinstance(model, namedtuples.InfoGANModel):

      def tensor_pool_fn_impl(input_values):
        generated_data, generator_inputs = input_values
        output_values = random_tensor_pool.tensor_pool(
            [generated_data] + generator_inputs, pool_size=5)
        return output_values[0], output_values[1:]

      tensor_pool_fn = tensor_pool_fn_impl
    else:

      def tensor_pool_fn_impl(input_values):
        return random_tensor_pool.tensor_pool(input_values, pool_size=5)

      tensor_pool_fn = tensor_pool_fn_impl
    loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
    self.assertTrue(isinstance(loss, namedtuples.GANLoss))

    # Check values.
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      for _ in range(10):
        sess.run([loss.generator_loss, loss.discriminator_loss])
Example #8
  def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = train.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    sequential_train_hooks = train.get_sequential_train_hooks()(train_ops)
    self.assertLen(sequential_train_hooks, 4)
    sync_opts = [
        hook._sync_optimizer for hook in sequential_train_hooks if
        isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)]
    self.assertLen(sync_opts, 2)
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    joint_train_hooks = train.get_joint_train_hooks()(train_ops)
    self.assertLen(joint_train_hooks, 5)
    sync_opts = [
        hook._sync_optimizer for hook in joint_train_hooks if
        isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)]
    self.assertLen(sync_opts, 2)
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
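In training code (as opposed to these tests), the hook factories are normally handed to `train.gan_train` rather than called directly. A minimal sketch of that usage, assuming the `train_ops` built above and a hypothetical log directory:

# Sketch only: let gan_train apply the hooks produced by the factory.
final_step = train.gan_train(
    train_ops,
    logdir='/tmp/gan_logdir',  # hypothetical path
    get_hooks_fn=train.get_joint_train_hooks(),  # or train.get_sequential_train_hooks()
    hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=100)])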
Example #9
    def _test_grad_penalty_helper(self, create_gan_model_fn):
        model = create_gan_model_fn()
        loss = train.gan_loss(model)
        loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0)
        self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss))

        # Check values.
        with self.test_session(use_gpu=True) as sess:
            variables.global_variables_initializer().run()
            loss_gen_np, loss_gen_gp_np = sess.run(
                [loss.generator_loss, loss_gp.generator_loss])
            loss_dis_np, loss_dis_gp_np = sess.run(
                [loss.discriminator_loss, loss_gp.discriminator_loss])

        self.assertEqual(loss_gen_np, loss_gen_gp_np)
        self.assertTrue(loss_dis_np < loss_dis_gp_np)
Example #10
  def _test_grad_penalty_helper(self, create_gan_model_fn):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0)
    self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss))

    # Check values.
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      loss_gen_np, loss_gen_gp_np = sess.run(
          [loss.generator_loss, loss_gp.generator_loss])
      loss_dis_np, loss_dis_gp_np = sess.run(
          [loss.discriminator_loss, loss_gp.discriminator_loss])

    self.assertEqual(loss_gen_np, loss_gen_gp_np)
    self.assertTrue(loss_dis_np < loss_dis_gp_np)
Example #11
    def _test_sync_replicas_helper(self,
                                   create_gan_model_fn,
                                   create_global_step=False):
        model = create_gan_model_fn()
        loss = train.gan_loss(model)
        num_trainable_vars = len(variables_lib.get_trainable_variables())

        if create_global_step:
            gstep = variable_scope.get_variable('custom_gstep',
                                                dtype=dtypes.int32,
                                                initializer=0,
                                                trainable=False)
            ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

        g_opt = get_sync_optimizer()
        d_opt = get_sync_optimizer()
        train_ops = train.gan_train_ops(model,
                                        loss,
                                        generator_optimizer=g_opt,
                                        discriminator_optimizer=d_opt)
        self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
        # No new trainable variables should have been added.
        self.assertEqual(num_trainable_vars,
                         len(variables_lib.get_trainable_variables()))

        g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
        d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

        # Check that update op is run properly.
        global_step = training_util.get_or_create_global_step()
        with self.test_session(use_gpu=True) as sess:
            variables.global_variables_initializer().run()
            variables.local_variables_initializer().run()

            g_opt.chief_init_op.run()
            d_opt.chief_init_op.run()

            gstep_before = global_step.eval()

            # Start required queue runner for SyncReplicasOptimizer.
            coord = coordinator.Coordinator()
            g_threads = g_opt.get_chief_queue_runner().create_threads(
                sess, coord)
            d_threads = d_opt.get_chief_queue_runner().create_threads(
                sess, coord)

            g_sync_init_op.run()
            d_sync_init_op.run()

            train_ops.generator_train_op.eval()
            # Check that global step wasn't incremented.
            self.assertEqual(gstep_before, global_step.eval())

            train_ops.discriminator_train_op.eval()
            # Check that global step wasn't incremented.
            self.assertEqual(gstep_before, global_step.eval())

            coord.request_stop()
            coord.join(g_threads + d_threads)
Example #12
  def test_sync_replicas(self, create_gan_model_fn, create_global_step):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    num_trainable_vars = len(variables_lib.get_trainable_variables())

    if create_global_step:
      gstep = variable_scope.get_variable(
          'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False)
      ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = train.gan_train_ops(
        model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
    self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
    # No new trainable variables should have been added.
    self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars)

    # Sync hooks should be populated in the GANTrainOps.
    self.assertLen(train_ops.train_hooks, 2)
    for hook in train_ops.train_hooks:
      self.assertIsInstance(
          hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
    sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
    d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

    # Check that update op is run properly.
    global_step = training_util.get_or_create_global_step()
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()

      g_opt.chief_init_op.run()
      d_opt.chief_init_op.run()

      gstep_before = global_step.eval()

      # Start required queue runner for SyncReplicasOptimizer.
      coord = coordinator.Coordinator()
      g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
      d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

      g_sync_init_op.run()
      d_sync_init_op.run()

      train_ops.generator_train_op.eval()
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, global_step.eval())

      train_ops.discriminator_train_op.eval()
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, global_step.eval())

      coord.request_stop()
      coord.join(g_threads + d_threads)
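`get_sync_optimizer` is a test helper that is not shown in these excerpts. A minimal sketch consistent with how it is used above, assuming it simply wraps a gradient-descent optimizer in a single-replica `SyncReplicasOptimizer`:

# Hypothetical helper, reconstructed from its usage in the sync-replicas tests above.
def get_sync_optimizer():
  return sync_replicas_optimizer.SyncReplicasOptimizer(
      gradient_descent.GradientDescentOptimizer(learning_rate=1.0),
      replicas_to_aggregate=1)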
Example #13
 def test_discriminator_only_sees_pool(self):
   """Checks that discriminator only sees pooled values."""
   def checker_gen_fn(_):
     return constant_op.constant(0.0)
   model = train.gan_model(
       checker_gen_fn,
       discriminator_model,
       real_data=array_ops.zeros([]),
       generator_inputs=random_ops.random_normal([]))
   def tensor_pool_fn(_):
     return (random_ops.random_uniform([]), random_ops.random_uniform([]))
   def checker_dis_fn(inputs, _):
     """Discriminator that checks that it only sees pooled Tensors."""
     self.assertFalse(constant_op.is_constant(inputs))
     return inputs
   model = model._replace(
       discriminator_fn=checker_dis_fn)
   train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
Example #14
    def _test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
        model = create_gan_model_fn()
        loss = train.gan_loss(model)

        # Add generator and discriminator update ops.
        with variable_scope.variable_scope(model.generator_scope):
            gen_update_count = variable_scope.get_variable('gen_count',
                                                           initializer=0)
            gen_update_op = gen_update_count.assign_add(1)
            ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
        with variable_scope.variable_scope(model.discriminator_scope):
            dis_update_count = variable_scope.get_variable('dis_count',
                                                           initializer=0)
            dis_update_op = dis_update_count.assign_add(1)
            ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)

        # Add an update op outside the generator and discriminator scopes.
        if provide_update_ops:
            kwargs = {
                'update_ops':
                [constant_op.constant(1.0), gen_update_op, dis_update_op]
            }
        else:
            ops.add_to_collection(ops.GraphKeys.UPDATE_OPS,
                                  constant_op.constant(1.0))
            kwargs = {}

        g_opt = gradient_descent.GradientDescentOptimizer(1.0)
        d_opt = gradient_descent.GradientDescentOptimizer(1.0)

        with self.assertRaisesRegexp(ValueError,
                                     'There are unused update ops:'):
            train.gan_train_ops(model,
                                loss,
                                g_opt,
                                d_opt,
                                check_for_unused_update_ops=True,
                                **kwargs)
        train_ops = train.gan_train_ops(model,
                                        loss,
                                        g_opt,
                                        d_opt,
                                        check_for_unused_update_ops=False,
                                        **kwargs)

        with self.test_session(use_gpu=True) as sess:
            sess.run(variables.global_variables_initializer())
            self.assertEqual(0, gen_update_count.eval())
            self.assertEqual(0, dis_update_count.eval())

            train_ops.generator_train_op.eval()
            self.assertEqual(1, gen_update_count.eval())
            self.assertEqual(0, dis_update_count.eval())

            train_ops.discriminator_train_op.eval()
            self.assertEqual(1, gen_update_count.eval())
            self.assertEqual(1, dis_update_count.eval())
Example #15
    def test_grad_penalty(self, create_gan_model_fn, one_sided):
        """Test gradient penalty option."""
        model = create_gan_model_fn()
        loss = train.gan_loss(model)
        loss_gp = train.gan_loss(model,
                                 gradient_penalty_weight=1.0,
                                 gradient_penalty_one_sided=one_sided)
        self.assertIsInstance(loss_gp, namedtuples.GANLoss)

        # Check values.
        with self.test_session(use_gpu=True) as sess:
            variables.global_variables_initializer().run()
            loss_gen_np, loss_gen_gp_np = sess.run(
                [loss.generator_loss, loss_gp.generator_loss])
            loss_dis_np, loss_dis_gp_np = sess.run(
                [loss.discriminator_loss, loss_gp.discriminator_loss])

        self.assertEqual(loss_gen_np, loss_gen_gp_np)
        self.assertLess(loss_dis_np, loss_dis_gp_np)
Example #16
  def test_grad_penalty(self, create_gan_model_fn, one_sided):
    """Test gradient penalty option."""
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    loss_gp = train.gan_loss(
        model,
        gradient_penalty_weight=1.0,
        gradient_penalty_one_sided=one_sided)
    self.assertIsInstance(loss_gp, namedtuples.GANLoss)

    # Check values.
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      loss_gen_np, loss_gen_gp_np = sess.run(
          [loss.generator_loss, loss_gp.generator_loss])
      loss_dis_np, loss_dis_gp_np = sess.run(
          [loss.discriminator_loss, loss_gp.discriminator_loss])

    self.assertEqual(loss_gen_np, loss_gen_gp_np)
    self.assertLess(loss_dis_np, loss_dis_gp_np)
Example #17
  def test_tensor_pool(self, create_gan_model_fn):
    """Test tensor pool option."""
    model = create_gan_model_fn()
    tensor_pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=5)
    loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
    self.assertIsInstance(loss, namedtuples.GANLoss)

    # Check values.
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      for _ in range(10):
        sess.run([loss.generator_loss, loss.discriminator_loss])
Example #18
    def _test_regularization_helper(self, get_gan_model_fn):
        # Evaluate losses without regularization.
        no_reg_loss = train.gan_loss(get_gan_model_fn())
        with self.test_session(use_gpu=True):
            no_reg_loss_gen_np = no_reg_loss.generator_loss.eval()
            no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()

        with ops.name_scope(get_gan_model_fn().generator_scope.name):
            ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                                  constant_op.constant(3.0))
        with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
            ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                                  constant_op.constant(2.0))

        # Check that losses now include the correct regularization values.
        reg_loss = train.gan_loss(get_gan_model_fn())
        with self.test_session(use_gpu=True):
            reg_loss_gen_np = reg_loss.generator_loss.eval()
            reg_loss_dis_np = reg_loss.discriminator_loss.eval()

        # assertTrue(3.0, ...) always passes; check the regularization deltas explicitly.
        self.assertEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np)
        self.assertEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np)
Example #19
  def _test_regularization_helper(self, get_gan_model_fn):
    # Evaluate losses without regularization.
    no_reg_loss = train.gan_loss(get_gan_model_fn())
    with self.test_session(use_gpu=True):
      no_reg_loss_gen_np = no_reg_loss.generator_loss.eval()
      no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()

    with ops.name_scope(get_gan_model_fn().generator_scope.name):
      ops.add_to_collection(
          ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
    with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
      ops.add_to_collection(
          ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))

    # Check that losses now include the correct regularization values.
    reg_loss = train.gan_loss(get_gan_model_fn())
    with self.test_session(use_gpu=True):
      reg_loss_gen_np = reg_loss.generator_loss.eval()
      reg_loss_dis_np = reg_loss.discriminator_loss.eval()

    # assertTrue(3.0, ...) always passes; check the regularization deltas explicitly.
    self.assertEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np)
    self.assertEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np)
Example #20
  def _test_tensor_pool_helper(self, create_gan_model_fn):
    model = create_gan_model_fn()
    if isinstance(model, namedtuples.InfoGANModel):
      tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5)
    else:
      tensor_pool_fn = get_tensor_pool_fn(pool_size=5)
    loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
    self.assertTrue(isinstance(loss, namedtuples.GANLoss))

    # Check values.
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      for _ in range(10):
        sess.run([loss.generator_loss, loss.discriminator_loss])
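`get_tensor_pool_fn` and `get_tensor_pool_fn_for_infogan` are not defined in this excerpt. Based on the inline pool functions in Examples 5 and 7 above, plausible sketches are:

# Hypothetical helpers, reconstructed from the inline pool functions shown earlier.
def get_tensor_pool_fn(pool_size):
  def tensor_pool_fn_impl(input_values):
    return random_tensor_pool.tensor_pool(input_values, pool_size=pool_size)
  return tensor_pool_fn_impl


def get_tensor_pool_fn_for_infogan(pool_size):
  def tensor_pool_fn_impl(input_values):
    generated_data, generator_inputs = input_values
    output_values = random_tensor_pool.tensor_pool(
        [generated_data] + generator_inputs, pool_size=pool_size)
    return output_values[0], output_values[1:]
  return tensor_pool_fn_impl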
Example #21
    def _test_tensor_pool_helper(self, create_gan_model_fn):
        model = create_gan_model_fn()
        if isinstance(model, namedtuples.InfoGANModel):
            tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5)
        else:
            tensor_pool_fn = get_tensor_pool_fn(pool_size=5)
        loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
        self.assertTrue(isinstance(loss, namedtuples.GANLoss))

        # Check values.
        with self.test_session(use_gpu=True) as sess:
            variables.global_variables_initializer().run()
            for _ in range(10):
                sess.run([loss.generator_loss, loss.discriminator_loss])
Example #22
    def _test_output_type_helper(self, create_gan_model_fn):
        model = create_gan_model_fn()
        loss = train.gan_loss(model)

        g_opt = gradient_descent.GradientDescentOptimizer(1.0)
        d_opt = gradient_descent.GradientDescentOptimizer(1.0)
        train_ops = train.gan_train_ops(model,
                                        loss,
                                        g_opt,
                                        d_opt,
                                        summarize_gradients=True,
                                        colocate_gradients_with_ops=True)

        self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
Example #23
    def _test_run_helper(self, create_gan_model_fn):
        random_seed.set_random_seed(1234)
        model = create_gan_model_fn()
        loss = train.gan_loss(model)

        g_opt = gradient_descent.GradientDescentOptimizer(1.0)
        d_opt = gradient_descent.GradientDescentOptimizer(1.0)
        train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)

        final_step = train.gan_train(
            train_ops,
            logdir='',
            hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
        self.assertTrue(np.isscalar(final_step))
        self.assertEqual(2, final_step)
Example #24
  def _test_output_type_helper(self, create_gan_model_fn):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
    train_ops = train.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
Example #25
  def _test_run_helper(self, create_gan_model_fn):
    random_seed.set_random_seed(1234)
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
    train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)

    final_step = train.gan_train(
        train_ops,
        logdir='',
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
    self.assertTrue(np.isscalar(final_step))
    self.assertEqual(2, final_step)
Example #26
  def test_patchgan(self, create_gan_model_fn):
    """Ensure that patch-based discriminators work end-to-end."""
    random_seed.set_random_seed(1234)
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
    train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)

    final_step = train.gan_train(
        train_ops,
        logdir='',
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
    self.assertTrue(np.isscalar(final_step))
    self.assertEqual(2, final_step)
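Taken together, these tests exercise the usual TF-GAN flow: build a GANModel, derive a GANLoss, turn it into GANTrainOps, and run them with `gan_train`. A condensed sketch of that flow outside the test harness; `generator_fn`, `discriminator_fn`, and `real_images` are hypothetical stand-ins:

# Condensed sketch of the flow the tests exercise (not taken verbatim from them).
gan_model = train.gan_model(
    generator_fn,  # hypothetical model functions
    discriminator_fn,
    real_data=real_images,  # hypothetical batch of real data
    generator_inputs=random_ops.random_normal([16, 64]))
gan_loss = train.gan_loss(gan_model, gradient_penalty_weight=1.0)
train_ops = train.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=gradient_descent.GradientDescentOptimizer(1.0),
    discriminator_optimizer=gradient_descent.GradientDescentOptimizer(1.0))
train.gan_train(
    train_ops,
    logdir='/tmp/gan_logdir',  # hypothetical path
    hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1000)])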
Example #27
    def test_patchgan(self, create_gan_model_fn):
        """Ensure that patch-based discriminators work end-to-end."""
        random_seed.set_random_seed(1234)
        model = create_gan_model_fn()
        loss = train.gan_loss(model)

        g_opt = gradient_descent.GradientDescentOptimizer(1.0)
        d_opt = gradient_descent.GradientDescentOptimizer(1.0)
        train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)

        final_step = train.gan_train(
            train_ops,
            logdir='',
            hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
        self.assertTrue(np.isscalar(final_step))
        self.assertEqual(2, final_step)
Example #28
    def test_output_type(self, create_gan_model_fn):
        model = create_gan_model_fn()
        loss = train.gan_loss(model)

        g_opt = gradient_descent.GradientDescentOptimizer(1.0)
        d_opt = gradient_descent.GradientDescentOptimizer(1.0)
        train_ops = train.gan_train_ops(model,
                                        loss,
                                        g_opt,
                                        d_opt,
                                        summarize_gradients=True,
                                        colocate_gradients_with_ops=True)

        self.assertIsInstance(train_ops, namedtuples.GANTrainOps)

        # Make sure there are no training hooks populated accidentally.
        self.assertEmpty(train_ops.train_hooks)
Example #29
  def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    # Add generator and discriminator update ops.
    with variable_scope.variable_scope(model.generator_scope):
      gen_update_count = variable_scope.get_variable('gen_count', initializer=0)
      gen_update_op = gen_update_count.assign_add(1)
      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
    with variable_scope.variable_scope(model.discriminator_scope):
      dis_update_count = variable_scope.get_variable('dis_count', initializer=0)
      dis_update_op = dis_update_count.assign_add(1)
      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)

    # Add an update op outside the generator and discriminator scopes.
    if provide_update_ops:
      kwargs = {
          'update_ops': [
              constant_op.constant(1.0), gen_update_op, dis_update_op
          ]
      }
    else:
      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0))
      kwargs = {}

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)

    with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
      train.gan_train_ops(
          model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs)
    train_ops = train.gan_train_ops(
        model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)

    with self.test_session(use_gpu=True) as sess:
      sess.run(variables.global_variables_initializer())
      self.assertEqual(0, gen_update_count.eval())
      self.assertEqual(0, dis_update_count.eval())

      train_ops.generator_train_op.eval()
      self.assertEqual(1, gen_update_count.eval())
      self.assertEqual(0, dis_update_count.eval())

      train_ops.discriminator_train_op.eval()
      self.assertEqual(1, gen_update_count.eval())
      self.assertEqual(1, dis_update_count.eval())
Example #30
  def test_output_type(self, create_gan_model_fn):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
    train_ops = train.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    self.assertIsInstance(train_ops, namedtuples.GANTrainOps)

    # Make sure there are no training hooks populated accidentally.
    self.assertEmpty(train_ops.train_hooks)
Example #31
    def test_is_chief_in_train_hooks(self, is_chief):
        """Make sure is_chief is propagated correctly to sync hooks."""
        model = create_gan_model()
        loss = train.gan_loss(model)
        g_opt = get_sync_optimizer()
        d_opt = get_sync_optimizer()
        train_ops = train.gan_train_ops(model,
                                        loss,
                                        g_opt,
                                        d_opt,
                                        is_chief=is_chief,
                                        summarize_gradients=True,
                                        colocate_gradients_with_ops=True)

        self.assertLen(train_ops.train_hooks, 2)
        for hook in train_ops.train_hooks:
            self.assertIsInstance(
                hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
        is_chief_list = [hook._is_chief for hook in train_ops.train_hooks]
        self.assertListEqual(is_chief_list, [is_chief, is_chief])
Example #32
  def test_is_chief_in_train_hooks(self, is_chief):
    """Make sure is_chief is propagated correctly to sync hooks."""
    model = create_gan_model()
    loss = train.gan_loss(model)
    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = train.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        is_chief=is_chief,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    self.assertLen(train_ops.train_hooks, 2)
    for hook in train_ops.train_hooks:
      self.assertIsInstance(
          hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
    is_chief_list = [hook._is_chief for hook in train_ops.train_hooks]
    self.assertListEqual(is_chief_list, [is_chief, is_chief])
Example #33
 def test_output_type(self, get_gan_model_fn):
   """Test output type."""
   loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
   self.assertIsInstance(loss, namedtuples.GANLoss)
   self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
Example #34
 def test_mutual_info_penalty(self, create_gan_model_fn):
   """Test mutual information penalty option."""
   train.gan_loss(
       create_gan_model_fn(),
       mutual_information_penalty_weight=constant_op.constant(1.0))
Example #35
 def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
     train.gan_loss(
         create_gan_model_fn(),
         mutual_information_penalty_weight=constant_op.constant(1.0))
Example #36
 def _test_output_type_helper(self, get_gan_model_fn):
   loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
   self.assertTrue(isinstance(loss, namedtuples.GANLoss))
   self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
Example #37
 def _test_output_type_helper(self, get_gan_model_fn):
     loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
     self.assertTrue(isinstance(loss, namedtuples.GANLoss))
     self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
Example #38
 def test_output_type(self, get_gan_model_fn):
     """Test output type."""
     loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
     self.assertIsInstance(loss, namedtuples.GANLoss)
     self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
Example #39
 def test_mutual_info_penalty(self, create_gan_model_fn):
     """Test mutual information penalty option."""
     train.gan_loss(
         create_gan_model_fn(),
         mutual_information_penalty_weight=constant_op.constant(1.0))
Example #40
 def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
   train.gan_loss(create_gan_model_fn(),
                  mutual_information_penalty_weight=constant_op.constant(1.0))