# NOTE: This excerpt omits the original import block. The imports below are
# reconstructed from usage and assume TF 1.x contrib-era module paths.
import numpy as np

from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python import train
from tensorflow.contrib.gan.python.features.python import random_tensor_pool
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util


def _test_acgan_helper(self, create_gan_model_fn):
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0)
  loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
  self.assertIsInstance(loss, namedtuples.GANLoss)
  self.assertIsInstance(loss_ac_gen, namedtuples.GANLoss)
  self.assertIsInstance(loss_ac_dis, namedtuples.GANLoss)

  # Check values.
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run([
        loss.generator_loss, loss_ac_gen.generator_loss,
        loss_ac_dis.generator_loss
    ])
    loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run([
        loss.discriminator_loss, loss_ac_gen.discriminator_loss,
        loss_ac_dis.discriminator_loss
    ])

  self.assertLess(loss_gen_np, loss_dis_np)
  self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
  self.assertTrue(np.isscalar(loss_ac_dis_gen_np))
  self.assertTrue(np.isscalar(loss_ac_gen_dis_np))
  self.assertTrue(np.isscalar(loss_ac_dis_dis_np))

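# NOTE: `generator_model` and `discriminator_model` are used below but not
# defined in this excerpt. A minimal sketch of plausible stand-ins
# (hypothetical; the originals may differ): trivial one-variable models, so
# each scope contains a trainable variable.
def generator_model(inputs):
  return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs


def discriminator_model(inputs, _):
  return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
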
def test_doesnt_crash_when_in_nested_scope(self):
  with variable_scope.variable_scope('outer_scope'):
    gan_model = train.gan_model(
        generator_model,
        discriminator_model,
        real_data=array_ops.zeros([1, 2]),
        generator_inputs=random_ops.random_normal([1, 2]))

    # This should work inside a scope.
    train.gan_loss(gan_model, gradient_penalty_weight=1.0)

  # This should also work outside a scope.
  train.gan_loss(gan_model, gradient_penalty_weight=1.0)

def _test_tensor_pool_helper(self, create_gan_model_fn):
  model = create_gan_model_fn()
  if isinstance(model, namedtuples.InfoGANModel):

    def tensor_pool_fn_impl(input_values):
      generated_data, generator_inputs = input_values
      output_values = random_tensor_pool.tensor_pool(
          [generated_data] + generator_inputs, pool_size=5)
      return output_values[0], output_values[1:]

    tensor_pool_fn = tensor_pool_fn_impl
  else:

    def tensor_pool_fn_impl(input_values):
      return random_tensor_pool.tensor_pool(input_values, pool_size=5)

    tensor_pool_fn = tensor_pool_fn_impl

  loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
  self.assertIsInstance(loss, namedtuples.GANLoss)

  # Check values.
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    for _ in range(10):
      sess.run([loss.generator_loss, loss.discriminator_loss])

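# NOTE: `get_sync_optimizer` is used by the sync-replicas tests below but not
# defined in this excerpt. A minimal sketch of a plausible definition
# (hypothetical hyperparameters):
def get_sync_optimizer():
  return sync_replicas_optimizer.SyncReplicasOptimizer(
      gradient_descent.GradientDescentOptimizer(learning_rate=1.0),
      replicas_to_aggregate=1)
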
def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  g_opt = get_sync_optimizer()
  d_opt = get_sync_optimizer()
  train_ops = train.gan_train_ops(
      model,
      loss,
      g_opt,
      d_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True)

  sequential_train_hooks = train.get_sequential_train_hooks()(train_ops)
  self.assertLen(sequential_train_hooks, 4)
  sync_opts = [
      hook._sync_optimizer for hook in sequential_train_hooks
      if isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
  ]
  self.assertLen(sync_opts, 2)
  self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

  joint_train_hooks = train.get_joint_train_hooks()(train_ops)
  self.assertLen(joint_train_hooks, 5)
  sync_opts = [
      hook._sync_optimizer for hook in joint_train_hooks
      if isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
  ]
  self.assertLen(sync_opts, 2)
  self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

def _test_grad_penalty_helper(self, create_gan_model_fn):
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0)
  self.assertIsInstance(loss_gp, namedtuples.GANLoss)

  # Check values.
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    loss_gen_np, loss_gen_gp_np = sess.run(
        [loss.generator_loss, loss_gp.generator_loss])
    loss_dis_np, loss_dis_gp_np = sess.run(
        [loss.discriminator_loss, loss_gp.discriminator_loss])

  self.assertEqual(loss_gen_np, loss_gen_gp_np)
  self.assertLess(loss_dis_np, loss_dis_gp_np)

def _test_sync_replicas_helper(self,
                               create_gan_model_fn,
                               create_global_step=False):
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  num_trainable_vars = len(variables_lib.get_trainable_variables())

  if create_global_step:
    gstep = variable_scope.get_variable(
        'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False)
    ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

  g_opt = get_sync_optimizer()
  d_opt = get_sync_optimizer()
  train_ops = train.gan_train_ops(
      model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
  self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
  # No new trainable variables should have been added.
  self.assertEqual(num_trainable_vars,
                   len(variables_lib.get_trainable_variables()))

  g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
  d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

  # Check that update op is run properly.
  global_step = training_util.get_or_create_global_step()
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    variables.local_variables_initializer().run()

    g_opt.chief_init_op.run()
    d_opt.chief_init_op.run()

    gstep_before = global_step.eval()

    # Start required queue runner for SyncReplicasOptimizer.
    coord = coordinator.Coordinator()
    g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
    d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

    g_sync_init_op.run()
    d_sync_init_op.run()

    train_ops.generator_train_op.eval()
    # Check that global step wasn't incremented.
    self.assertEqual(gstep_before, global_step.eval())

    train_ops.discriminator_train_op.eval()
    # Check that global step wasn't incremented.
    self.assertEqual(gstep_before, global_step.eval())

    coord.request_stop()
    coord.join(g_threads + d_threads)

def test_sync_replicas(self, create_gan_model_fn, create_global_step):
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  num_trainable_vars = len(variables_lib.get_trainable_variables())

  if create_global_step:
    gstep = variable_scope.get_variable(
        'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False)
    ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

  g_opt = get_sync_optimizer()
  d_opt = get_sync_optimizer()
  train_ops = train.gan_train_ops(
      model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
  self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
  # No new trainable variables should have been added.
  self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars)

  # Sync hooks should be populated in the GANTrainOps.
  self.assertLen(train_ops.train_hooks, 2)
  for hook in train_ops.train_hooks:
    self.assertIsInstance(
        hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
  sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
  self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

  g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
  d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

  # Check that update op is run properly.
  global_step = training_util.get_or_create_global_step()
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    variables.local_variables_initializer().run()

    g_opt.chief_init_op.run()
    d_opt.chief_init_op.run()

    gstep_before = global_step.eval()

    # Start required queue runner for SyncReplicasOptimizer.
    coord = coordinator.Coordinator()
    g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
    d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

    g_sync_init_op.run()
    d_sync_init_op.run()

    train_ops.generator_train_op.eval()
    # Check that global step wasn't incremented.
    self.assertEqual(gstep_before, global_step.eval())

    train_ops.discriminator_train_op.eval()
    # Check that global step wasn't incremented.
    self.assertEqual(gstep_before, global_step.eval())

    coord.request_stop()
    coord.join(g_threads + d_threads)

def test_discriminator_only_sees_pool(self):
  """Checks that discriminator only sees pooled values."""

  def checker_gen_fn(_):
    return constant_op.constant(0.0)

  model = train.gan_model(
      checker_gen_fn,
      discriminator_model,
      real_data=array_ops.zeros([]),
      generator_inputs=random_ops.random_normal([]))

  def tensor_pool_fn(_):
    return (random_ops.random_uniform([]), random_ops.random_uniform([]))

  def checker_dis_fn(inputs, _):
    """Discriminator that checks that it only sees pooled Tensors."""
    self.assertFalse(constant_op.is_constant(inputs))
    return inputs

  model = model._replace(discriminator_fn=checker_dis_fn)
  train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)

def _test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
  model = create_gan_model_fn()
  loss = train.gan_loss(model)

  # Add generator and discriminator update ops.
  with variable_scope.variable_scope(model.generator_scope):
    gen_update_count = variable_scope.get_variable('gen_count', initializer=0)
    gen_update_op = gen_update_count.assign_add(1)
    ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
  with variable_scope.variable_scope(model.discriminator_scope):
    dis_update_count = variable_scope.get_variable('dis_count', initializer=0)
    dis_update_op = dis_update_count.assign_add(1)
    ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)

  # Add an update op outside the generator and discriminator scopes.
  if provide_update_ops:
    kwargs = {
        'update_ops': [constant_op.constant(1.0), gen_update_op, dis_update_op]
    }
  else:
    ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0))
    kwargs = {}

  g_opt = gradient_descent.GradientDescentOptimizer(1.0)
  d_opt = gradient_descent.GradientDescentOptimizer(1.0)

  with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
    train.gan_train_ops(
        model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs)
  train_ops = train.gan_train_ops(
      model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)

  with self.test_session(use_gpu=True) as sess:
    sess.run(variables.global_variables_initializer())
    self.assertEqual(0, gen_update_count.eval())
    self.assertEqual(0, dis_update_count.eval())

    train_ops.generator_train_op.eval()
    self.assertEqual(1, gen_update_count.eval())
    self.assertEqual(0, dis_update_count.eval())

    train_ops.discriminator_train_op.eval()
    self.assertEqual(1, gen_update_count.eval())
    self.assertEqual(1, dis_update_count.eval())

def test_grad_penalty(self, create_gan_model_fn, one_sided):
  """Test gradient penalty option."""
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  loss_gp = train.gan_loss(
      model,
      gradient_penalty_weight=1.0,
      gradient_penalty_one_sided=one_sided)
  self.assertIsInstance(loss_gp, namedtuples.GANLoss)

  # Check values.
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    loss_gen_np, loss_gen_gp_np = sess.run(
        [loss.generator_loss, loss_gp.generator_loss])
    loss_dis_np, loss_dis_gp_np = sess.run(
        [loss.discriminator_loss, loss_gp.discriminator_loss])

  self.assertEqual(loss_gen_np, loss_gen_gp_np)
  self.assertLess(loss_dis_np, loss_dis_gp_np)

def test_tensor_pool(self, create_gan_model_fn):
  """Test tensor pool option."""
  model = create_gan_model_fn()
  tensor_pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=5)
  loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
  self.assertIsInstance(loss, namedtuples.GANLoss)

  # Check values.
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    for _ in range(10):
      sess.run([loss.generator_loss, loss.discriminator_loss])

def _test_regularization_helper(self, get_gan_model_fn):
  # Evaluate losses without regularization.
  no_reg_loss = train.gan_loss(get_gan_model_fn())
  with self.test_session(use_gpu=True):
    no_reg_loss_gen_np = no_reg_loss.generator_loss.eval()
    no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()

  with ops.name_scope(get_gan_model_fn().generator_scope.name):
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                          constant_op.constant(3.0))
  with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                          constant_op.constant(2.0))

  # Check that losses now include the correct regularization values: 3.0 on
  # the generator loss and 2.0 on the discriminator loss.
  reg_loss = train.gan_loss(get_gan_model_fn())
  with self.test_session(use_gpu=True):
    reg_loss_gen_np = reg_loss.generator_loss.eval()
    reg_loss_dis_np = reg_loss.discriminator_loss.eval()

  self.assertEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np)
  self.assertEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np)

def _test_tensor_pool_helper(self, create_gan_model_fn):
  model = create_gan_model_fn()
  if isinstance(model, namedtuples.InfoGANModel):
    tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5)
  else:
    tensor_pool_fn = get_tensor_pool_fn(pool_size=5)
  loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
  self.assertIsInstance(loss, namedtuples.GANLoss)

  # Check values.
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    for _ in range(10):
      sess.run([loss.generator_loss, loss.discriminator_loss])

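# NOTE: `get_tensor_pool_fn` and `get_tensor_pool_fn_for_infogan` are not
# defined in this excerpt. A minimal sketch, assuming they are factories for
# the same closures written inline in the earlier version of
# `_test_tensor_pool_helper` above:
def get_tensor_pool_fn(pool_size):

  def tensor_pool_fn_impl(input_values):
    return random_tensor_pool.tensor_pool(input_values, pool_size=pool_size)

  return tensor_pool_fn_impl


def get_tensor_pool_fn_for_infogan(pool_size):

  def tensor_pool_fn_impl(input_values):
    # Pool the generated data and its structured generator inputs jointly so
    # they stay aligned after pooling.
    generated_data, generator_inputs = input_values
    output_values = random_tensor_pool.tensor_pool(
        [generated_data] + generator_inputs, pool_size=pool_size)
    return output_values[0], output_values[1:]

  return tensor_pool_fn_impl
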
def _test_output_type_helper(self, create_gan_model_fn):
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  g_opt = gradient_descent.GradientDescentOptimizer(1.0)
  d_opt = gradient_descent.GradientDescentOptimizer(1.0)
  train_ops = train.gan_train_ops(
      model,
      loss,
      g_opt,
      d_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True)
  self.assertIsInstance(train_ops, namedtuples.GANTrainOps)

def _test_run_helper(self, create_gan_model_fn):
  random_seed.set_random_seed(1234)
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  g_opt = gradient_descent.GradientDescentOptimizer(1.0)
  d_opt = gradient_descent.GradientDescentOptimizer(1.0)
  train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)

  final_step = train.gan_train(
      train_ops,
      logdir='',
      hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
  self.assertTrue(np.isscalar(final_step))
  self.assertEqual(2, final_step)

def test_patchgan(self, create_gan_model_fn):
  """Ensure that patch-based discriminators work end-to-end."""
  random_seed.set_random_seed(1234)
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  g_opt = gradient_descent.GradientDescentOptimizer(1.0)
  d_opt = gradient_descent.GradientDescentOptimizer(1.0)
  train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)

  final_step = train.gan_train(
      train_ops,
      logdir='',
      hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
  self.assertTrue(np.isscalar(final_step))
  self.assertEqual(2, final_step)

def test_output_type(self, create_gan_model_fn):
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  g_opt = gradient_descent.GradientDescentOptimizer(1.0)
  d_opt = gradient_descent.GradientDescentOptimizer(1.0)
  train_ops = train.gan_train_ops(
      model,
      loss,
      g_opt,
      d_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True)
  self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
  # Make sure there are no training hooks populated accidentally.
  self.assertEmpty(train_ops.train_hooks)

def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
  model = create_gan_model_fn()
  loss = train.gan_loss(model)

  # Add generator and discriminator update ops.
  with variable_scope.variable_scope(model.generator_scope):
    gen_update_count = variable_scope.get_variable('gen_count', initializer=0)
    gen_update_op = gen_update_count.assign_add(1)
    ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
  with variable_scope.variable_scope(model.discriminator_scope):
    dis_update_count = variable_scope.get_variable('dis_count', initializer=0)
    dis_update_op = dis_update_count.assign_add(1)
    ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)

  # Add an update op outside the generator and discriminator scopes.
  if provide_update_ops:
    kwargs = {
        'update_ops': [constant_op.constant(1.0), gen_update_op, dis_update_op]
    }
  else:
    ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0))
    kwargs = {}

  g_opt = gradient_descent.GradientDescentOptimizer(1.0)
  d_opt = gradient_descent.GradientDescentOptimizer(1.0)

  with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
    train.gan_train_ops(
        model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs)
  train_ops = train.gan_train_ops(
      model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)

  with self.test_session(use_gpu=True) as sess:
    sess.run(variables.global_variables_initializer())
    self.assertEqual(0, gen_update_count.eval())
    self.assertEqual(0, dis_update_count.eval())

    train_ops.generator_train_op.eval()
    self.assertEqual(1, gen_update_count.eval())
    self.assertEqual(0, dis_update_count.eval())

    train_ops.discriminator_train_op.eval()
    self.assertEqual(1, gen_update_count.eval())
    self.assertEqual(1, dis_update_count.eval())

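# NOTE: `create_gan_model` (used by `test_is_chief_in_train_hooks` below) is
# not defined in this excerpt. A minimal sketch, assuming it wraps
# `train.gan_model` with the hypothetical stub models sketched near the top
# of this file:
def create_gan_model():
  return train.gan_model(
      generator_model,
      discriminator_model,
      real_data=array_ops.zeros([1, 2]),
      generator_inputs=random_ops.random_normal([1, 2]))
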
def test_is_chief_in_train_hooks(self, is_chief):
  """Make sure is_chief is propagated correctly to sync hooks."""
  model = create_gan_model()
  loss = train.gan_loss(model)
  g_opt = get_sync_optimizer()
  d_opt = get_sync_optimizer()
  train_ops = train.gan_train_ops(
      model,
      loss,
      g_opt,
      d_opt,
      is_chief=is_chief,
      summarize_gradients=True,
      colocate_gradients_with_ops=True)

  self.assertLen(train_ops.train_hooks, 2)
  for hook in train_ops.train_hooks:
    self.assertIsInstance(
        hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
  is_chief_list = [hook._is_chief for hook in train_ops.train_hooks]
  self.assertListEqual(is_chief_list, [is_chief, is_chief])

def test_output_type(self, get_gan_model_fn):
  """Test output type."""
  loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
  self.assertIsInstance(loss, namedtuples.GANLoss)
  self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)

def test_mutual_info_penalty(self, create_gan_model_fn):
  """Test mutual information penalty option."""
  train.gan_loss(
      create_gan_model_fn(),
      mutual_information_penalty_weight=constant_op.constant(1.0))

def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
  train.gan_loss(
      create_gan_model_fn(),
      mutual_information_penalty_weight=constant_op.constant(1.0))

def _test_output_type_helper(self, get_gan_model_fn):
  loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
  self.assertIsInstance(loss, namedtuples.GANLoss)
  self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)