Esempio n. 1
0
    def test_acgan(self, create_gan_model_fn):
        """Checks that ACGAN auxiliary-conditioning weights yield valid losses.

        Three losses are built from the same model: a plain one, one with the
        generator-side auxiliary weight enabled, and one with the
        discriminator-side auxiliary weight enabled. All three must be
        `GANLoss` tuples. The plain generator loss must be below the plain
        discriminator loss, and every auxiliary variant must evaluate to a
        scalar.
        """
        model = create_gan_model_fn()
        plain = train.gan_loss(model)
        gen_weighted = train.gan_loss(model, aux_cond_generator_weight=1.0)
        dis_weighted = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
        variants = (plain, gen_weighted, dis_weighted)
        for variant in variants:
            self.assertIsInstance(variant, namedtuples.GANLoss)

        # Fetch all generator-side losses, then all discriminator-side losses.
        with self.test_session(use_gpu=True) as sess:
            variables.global_variables_initializer().run()
            gen_values = sess.run([v.generator_loss for v in variants])
            dis_values = sess.run([v.discriminator_loss for v in variants])

        self.assertLess(gen_values[0], dis_values[0])
        # Index 0 is the plain loss; the rest are the auxiliary variants.
        for value in list(gen_values[1:]) + list(dis_values[1:]):
            self.assertTrue(np.isscalar(value))
Esempio n. 2
0
    def test_doesnt_crash_when_in_nested_scope(self):
        """Checks that a gradient-penalty loss builds inside and outside a scope.

        The model is created under the 'outer_scope' variable scope. The loss
        is then built once while that scope is active and once after leaving
        it. The test passes if neither construction raises.
        """
        with variable_scope.variable_scope('outer_scope'):
            scoped_model = train.gan_model(
                generator_model,
                discriminator_model,
                real_data=array_ops.zeros([1, 2]),
                generator_inputs=random_ops.random_normal([1, 2]),
            )
            # First construction: still inside the enclosing scope.
            train.gan_loss(scoped_model, gradient_penalty_weight=1.0)

        # Second construction: after the scope has been exited.
        train.gan_loss(scoped_model, gradient_penalty_weight=1.0)
Esempio n. 3
0
    def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
        """Checks that both hook factories include one sync hook per optimizer.

        The sequential factory is expected to produce 4 hooks and the joint
        factory 5. In both cases exactly two hooks must be sync-replicas
        hooks, and together they must wrap the generator and discriminator
        optimizers.
        """
        model = create_gan_model_fn()
        loss = train.gan_loss(model)

        g_opt = get_sync_optimizer()
        d_opt = get_sync_optimizer()
        train_ops = train.gan_train_ops(
            model,
            loss,
            g_opt,
            d_opt,
            summarize_gradients=True,
            colocate_gradients_with_ops=True,
        )

        sync_hook_type = sync_replicas_optimizer._SyncReplicasOptimizerHook
        expected_opts = frozenset((g_opt, d_opt))
        # (hook-factory getter, expected total number of hooks)
        cases = (
            (train.get_sequential_train_hooks, 4),
            (train.get_joint_train_hooks, 5),
        )
        for get_hooks_fn, expected_count in cases:
            hooks = get_hooks_fn()(train_ops)
            self.assertLen(hooks, expected_count)
            wrapped_opts = [
                hook._sync_optimizer
                for hook in hooks
                if isinstance(hook, sync_hook_type)
            ]
            self.assertLen(wrapped_opts, 2)
            self.assertSetEqual(frozenset(wrapped_opts), expected_opts)
Esempio n. 4
0
    def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
        """Checks handling of update ops outside the generator/discriminator scopes.

        A counter-increment update op is created in each of the generator
        and discriminator variable scopes, plus a constant op that belongs to
        neither scope. With `check_for_unused_update_ops=True`,
        `gan_train_ops` must raise `ValueError`. With the check disabled, each
        train op must run only the update op from its own scope.

        Args:
          create_gan_model_fn: Callable returning the GAN model under test.
          provide_update_ops: If True, pass the update ops explicitly through
            the `update_ops` argument; otherwise register the stray op in the
            `UPDATE_OPS` graph collection.
        """
        model = create_gan_model_fn()
        loss = train.gan_loss(model)

        # Add generator and discriminator update ops.
        with variable_scope.variable_scope(model.generator_scope):
            gen_update_count = variable_scope.get_variable('gen_count',
                                                           initializer=0)
            gen_update_op = gen_update_count.assign_add(1)
            ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
        with variable_scope.variable_scope(model.discriminator_scope):
            dis_update_count = variable_scope.get_variable('dis_count',
                                                           initializer=0)
            dis_update_op = dis_update_count.assign_add(1)
            ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)

        # Add an update op outside the generator and discriminator scopes.
        if provide_update_ops:
            kwargs = {
                'update_ops':
                [constant_op.constant(1.0), gen_update_op, dis_update_op]
            }
        else:
            ops.add_to_collection(ops.GraphKeys.UPDATE_OPS,
                                  constant_op.constant(1.0))
            kwargs = {}

        g_opt = gradient_descent.GradientDescentOptimizer(1.0)
        d_opt = gradient_descent.GradientDescentOptimizer(1.0)

        # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex`
        # that was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    'There are unused update ops:'):
            train.gan_train_ops(model,
                                loss,
                                g_opt,
                                d_opt,
                                check_for_unused_update_ops=True,
                                **kwargs)
        train_ops = train.gan_train_ops(model,
                                        loss,
                                        g_opt,
                                        d_opt,
                                        check_for_unused_update_ops=False,
                                        **kwargs)

        with self.test_session(use_gpu=True) as sess:
            sess.run(variables.global_variables_initializer())
            self.assertEqual(0, gen_update_count.eval())
            self.assertEqual(0, dis_update_count.eval())

            # The generator train op must only bump the generator counter.
            train_ops.generator_train_op.eval()
            self.assertEqual(1, gen_update_count.eval())
            self.assertEqual(0, dis_update_count.eval())

            # The discriminator train op must only bump the discriminator counter.
            train_ops.discriminator_train_op.eval()
            self.assertEqual(1, gen_update_count.eval())
            self.assertEqual(1, dis_update_count.eval())
Esempio n. 5
0
    def test_grad_penalty(self, create_gan_model_fn, one_sided):
        """Checks that a gradient penalty changes only the discriminator loss.

        The generator loss must be identical with and without the penalty.
        The discriminator loss must be strictly larger with it.

        Args:
          create_gan_model_fn: Callable returning the GAN model under test.
          one_sided: Forwarded as `gradient_penalty_one_sided`.
        """
        model = create_gan_model_fn()
        base = train.gan_loss(model)
        penalized = train.gan_loss(
            model,
            gradient_penalty_weight=1.0,
            gradient_penalty_one_sided=one_sided,
        )
        self.assertIsInstance(penalized, namedtuples.GANLoss)

        with self.test_session(use_gpu=True) as sess:
            variables.global_variables_initializer().run()
            gen_base, gen_penalized = sess.run(
                [base.generator_loss, penalized.generator_loss])
            dis_base, dis_penalized = sess.run(
                [base.discriminator_loss, penalized.discriminator_loss])

        self.assertEqual(gen_base, gen_penalized)
        self.assertLess(dis_base, dis_penalized)
Esempio n. 6
0
    def test_discriminator_only_sees_pool(self):
        """Checks that the discriminator receives pooled tensors, not raw ones.

        The generator emits a constant tensor, and the pool function swaps in
        fresh random tensors. The discriminator asserts that its input is
        never a constant, so an unpooled value reaching it fails the test.
        """
        def checker_gen_fn(_):
            return constant_op.constant(0.0)

        def tensor_pool_fn(_):
            # Ignore the pool input and hand back two new random tensors.
            first = random_ops.random_uniform([])
            second = random_ops.random_uniform([])
            return first, second

        def checker_dis_fn(inputs, _):
            """Discriminator that checks that it only sees pooled Tensors."""
            self.assertFalse(constant_op.is_constant(inputs))
            return inputs

        base_model = train.gan_model(
            checker_gen_fn,
            discriminator_model,
            real_data=array_ops.zeros([]),
            generator_inputs=random_ops.random_normal([]),
        )
        checked_model = base_model._replace(discriminator_fn=checker_dis_fn)
        train.gan_loss(checked_model, tensor_pool_fn=tensor_pool_fn)
Esempio n. 7
0
    def test_tensor_pool(self, create_gan_model_fn):
        """Checks that gan_loss works with a random tensor pool of size 5.

        Both losses are evaluated ten times in one session, so the pool is
        run repeatedly.
        """
        def pool_fn(tensors):
            return random_tensor_pool.tensor_pool(tensors, pool_size=5)

        model = create_gan_model_fn()
        loss = train.gan_loss(model, tensor_pool_fn=pool_fn)
        self.assertIsInstance(loss, namedtuples.GANLoss)

        with self.test_session(use_gpu=True) as sess:
            variables.global_variables_initializer().run()
            fetches = [loss.generator_loss, loss.discriminator_loss]
            for _ in range(10):
                sess.run(fetches)
Esempio n. 8
0
    def test_regularization_helper(self, get_gan_model_fn):
        """Checks that scoped regularization losses are added to the GAN losses.

        The losses are evaluated once without regularization. Then constants
        3.0 and 2.0 are added to the REGULARIZATION_LOSSES collection under
        the generator and discriminator name scopes. Rebuilt losses must
        exceed the originals by exactly those amounts, up to floating-point
        rounding.
        """
        # Evaluate losses without regularization.
        no_reg_loss = train.gan_loss(get_gan_model_fn())
        with self.test_session(use_gpu=True):
            no_reg_loss_gen_np = no_reg_loss.generator_loss.eval()
            no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()

        with ops.name_scope(get_gan_model_fn().generator_scope.name):
            ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                                  constant_op.constant(3.0))
        with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
            ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                                  constant_op.constant(2.0))

        # Check that losses now include the correct regularization values.
        reg_loss = train.gan_loss(get_gan_model_fn())
        with self.test_session(use_gpu=True):
            reg_loss_gen_np = reg_loss.generator_loss.eval()
            reg_loss_dis_np = reg_loss.discriminator_loss.eval()

        # The loss values are floating point (presumably float32), so
        # (x + 3.0) - x need not round back to exactly 3.0. Compare with a
        # tolerance rather than exact equality.
        self.assertAlmostEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np,
                               places=5)
        self.assertAlmostEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np,
                               places=5)
Esempio n. 9
0
    def test_run_helper(self, create_gan_model_fn):
        """Checks that gan_train stops at the requested step and returns it.

        A StopAtStepHook with num_steps=2 must make gan_train return the
        scalar 2.
        """
        random_seed.set_random_seed(1234)
        model = create_gan_model_fn()
        loss = train.gan_loss(model)

        # Generator optimizer first, then discriminator optimizer.
        optimizers = [
            gradient_descent.GradientDescentOptimizer(1.0) for _ in range(2)
        ]
        train_ops = train.gan_train_ops(model, loss, *optimizers)

        stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps=2)
        final_step = train.gan_train(train_ops, logdir='', hooks=[stop_hook])
        self.assertTrue(np.isscalar(final_step))
        self.assertEqual(2, final_step)
Esempio n. 10
0
    def test_patchgan(self, create_gan_model_fn):
        """Checks that patch-based discriminators train end-to-end.

        Two training steps are run via a StopAtStepHook, and gan_train must
        report the scalar final step 2.
        """
        random_seed.set_random_seed(1234)
        patch_model = create_gan_model_fn()
        patch_loss = train.gan_loss(patch_model)

        gen_optimizer = gradient_descent.GradientDescentOptimizer(1.0)
        dis_optimizer = gradient_descent.GradientDescentOptimizer(1.0)
        train_ops = train.gan_train_ops(
            patch_model, patch_loss, gen_optimizer, dis_optimizer)

        hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=2)]
        final_step = train.gan_train(train_ops, logdir='', hooks=hooks)
        self.assertTrue(np.isscalar(final_step))
        self.assertEqual(2, final_step)
Esempio n. 11
0
    def test_output_type(self, create_gan_model_fn):
        """Checks that gan_train_ops returns a GANTrainOps without train hooks.

        Plain gradient-descent optimizers are used, so no train hooks should
        be populated.
        """
        model = create_gan_model_fn()
        loss = train.gan_loss(model)

        gen_optimizer = gradient_descent.GradientDescentOptimizer(1.0)
        dis_optimizer = gradient_descent.GradientDescentOptimizer(1.0)
        train_ops = train.gan_train_ops(
            model,
            loss,
            gen_optimizer,
            dis_optimizer,
            summarize_gradients=True,
            colocate_gradients_with_ops=True,
        )

        self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
        # No hooks should have been populated accidentally.
        self.assertEmpty(train_ops.train_hooks)
Esempio n. 12
0
    def test_is_chief_in_train_hooks(self, is_chief):
        """Checks that `is_chief` is propagated to both sync-replicas hooks.

        Args:
          is_chief: Value passed to `gan_train_ops`; both resulting hooks
            must report it.
        """
        model = create_gan_model()
        loss = train.gan_loss(model)
        g_opt = get_sync_optimizer()
        d_opt = get_sync_optimizer()
        train_ops = train.gan_train_ops(
            model,
            loss,
            g_opt,
            d_opt,
            is_chief=is_chief,
            summarize_gradients=True,
            colocate_gradients_with_ops=True,
        )

        hooks = train_ops.train_hooks
        self.assertLen(hooks, 2)
        sync_hook_type = sync_replicas_optimizer._SyncReplicasOptimizerHook
        for hook in hooks:
            self.assertIsInstance(hook, sync_hook_type)
        observed = [hook._is_chief for hook in hooks]
        self.assertListEqual(observed, [is_chief] * 2)
Esempio n. 13
0
    def test_sync_replicas(self, create_gan_model_fn, create_global_step):
        """Checks GAN train ops built with sync-replicas optimizers.

        Verifies that:
          * no extra trainable variables are created;
          * two sync-replicas hooks wrapping the two optimizers are attached;
          * running the generator and discriminator train ops leaves the
            global step unchanged.

        The chief queue-runner threads are always stopped and joined, even if
        an assertion fails, so a failing run does not leak threads.

        Args:
          create_gan_model_fn: Callable returning the GAN model under test.
          create_global_step: If True, register a custom int32 variable in
            the GLOBAL_STEP collection before building the train ops.
        """
        model = create_gan_model_fn()
        loss = train.gan_loss(model)
        num_trainable_vars = len(variables_lib.get_trainable_variables())

        if create_global_step:
            gstep = variable_scope.get_variable('custom_gstep',
                                                dtype=dtypes.int32,
                                                initializer=0,
                                                trainable=False)
            ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

        g_opt = get_sync_optimizer()
        d_opt = get_sync_optimizer()
        train_ops = train.gan_train_ops(model,
                                        loss,
                                        generator_optimizer=g_opt,
                                        discriminator_optimizer=d_opt)
        self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
        # No new trainable variables should have been added.
        self.assertLen(variables_lib.get_trainable_variables(),
                       num_trainable_vars)

        # Sync hooks should be populated in the GANTrainOps.
        self.assertLen(train_ops.train_hooks, 2)
        for hook in train_ops.train_hooks:
            self.assertIsInstance(
                hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
        sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
        self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

        g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
        d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

        # Check that update op is run properly.
        global_step = training_util.get_or_create_global_step()
        with self.test_session(use_gpu=True) as sess:
            variables.global_variables_initializer().run()
            variables.local_variables_initializer().run()

            g_opt.chief_init_op.run()
            d_opt.chief_init_op.run()

            gstep_before = global_step.eval()

            # Start required queue runners for SyncReplicasOptimizer. Threads
            # are collected as they are created so that the `finally` clause
            # can stop and join whatever was started, even on failure.
            coord = coordinator.Coordinator()
            threads = []
            try:
                threads.extend(g_opt.get_chief_queue_runner().create_threads(
                    sess, coord))
                threads.extend(d_opt.get_chief_queue_runner().create_threads(
                    sess, coord))

                g_sync_init_op.run()
                d_sync_init_op.run()

                train_ops.generator_train_op.eval()
                # Check that global step wasn't incremented.
                self.assertEqual(gstep_before, global_step.eval())

                train_ops.discriminator_train_op.eval()
                # Check that global step wasn't incremented.
                self.assertEqual(gstep_before, global_step.eval())
            finally:
                coord.request_stop()
                coord.join(threads)
Esempio n. 14
0
 def test_mutual_info_penalty(self, create_gan_model_fn):
     """Checks that gan_loss accepts a Tensor-valued mutual-information weight.

     The test passes as long as building the loss does not raise.
     """
     model = create_gan_model_fn()
     penalty_weight = constant_op.constant(1.0)
     train.gan_loss(model, mutual_information_penalty_weight=penalty_weight)
Esempio n. 15
0
 def test_output_type(self, get_gan_model_fn):
     """Checks that gan_loss returns a GANLoss and records summaries on request."""
     gan_model = get_gan_model_fn()
     loss = train.gan_loss(gan_model, add_summaries=True)
     self.assertIsInstance(loss, namedtuples.GANLoss)
     summaries = ops.get_collection(ops.GraphKeys.SUMMARIES)
     self.assertNotEmpty(summaries)