def test_patchgan(self, create_gan_model_fn):
    """Ensure that patch-based discriminators work end-to-end.

    Builds a model with `create_gan_model_fn`, trains it with plain gradient
    descent until a stop hook fires after two steps, and checks that the
    returned step is the scalar 2.
    """
    # None of the usual utilities work in eager, so this is a no-op there.
    if tf.executing_eagerly():
        return

    tf.compat.v1.random.set_random_seed(1234)

    gan_model = create_gan_model_fn()
    gan_loss = tfgan.gan_loss(gan_model)
    generator_optimizer = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    discriminator_optimizer = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    train_ops = tfgan.gan_train_ops(
        gan_model, gan_loss, generator_optimizer, discriminator_optimizer)

    last_step = tfgan.gan_train(
        train_ops,
        logdir='',
        hooks=[tf.estimator.StopAtStepHook(num_steps=2)])

    self.assertTrue(np.isscalar(last_step))
    self.assertEqual(2, last_step)
def test_get_sync_estimator_spec(self):
    """Make sure spec is loaded with sync hooks for sync opts."""
    with tf.Graph().as_default():
        model = get_dummy_gan_model()
        loss = tfgan.gan_loss(model, dummy_loss_fn, dummy_loss_fn)
        generator_opt = get_sync_optimizer()
        discriminator_opt = get_sync_optimizer()

        # Passing get_hooks_fn=None asks for the default hooks function.
        spec = get_train_estimator_spec(
            model,
            loss,
            Optimizers(generator_opt, discriminator_opt),
            get_hooks_fn=None)

        self.assertLen(spec.training_hooks, 4)

        # Collect the optimizer wrapped by every sync-replicas hook.
        found_sync_opts = []
        for hook in spec.training_hooks:
            if isinstance(hook, get_sync_optimizer_hook_type()):
                found_sync_opts.append(hook._sync_optimizer)

        self.assertLen(found_sync_opts, 2)
        self.assertSetEqual(
            frozenset(found_sync_opts),
            frozenset((generator_opt, discriminator_opt)))
def test_output_type(self, create_gan_model_fn):
    """Checks that `gan_train_ops` returns a hook-free `GANTrainOps`."""
    # None of the usual utilities work in eager, so this is a no-op there.
    if tf.executing_eagerly():
        return

    gan_model = create_gan_model_fn()
    gan_loss = tfgan.gan_loss(gan_model)
    generator_optimizer = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    discriminator_optimizer = tf.compat.v1.train.GradientDescentOptimizer(1.0)

    result = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer,
        discriminator_optimizer,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    self.assertIsInstance(result, tfgan.GANTrainOps)
    # Plain optimizers must not cause any training hooks to be populated.
    self.assertEmpty(result.train_hooks)
def define_loss(gan_model, **kwargs):
    """Defines progressive GAN losses.

    The generator and discriminator both use wasserstein loss. In addition, a
    small penalty term is added to the discriminator loss to prevent it getting
    too large.

    Args:
      gan_model: A `GANModel` namedtuple.
      **kwargs: A dictionary of
          'gradient_penalty_weight': A float of gradient penalty weight for
            wasserstein loss.
          'gradient_penalty_target': A float of gradient norm target for
            wasserstein loss.
          'real_score_penalty_weight': A float weight of the additional
            penalty that keeps the real scores from drifting too far from
            zero.

    Returns:
      A `GANLoss` namedtuple whose discriminator loss includes the weighted
      real-score penalty.

    Raises:
      KeyError: If any of the three keys above is missing from `kwargs`.
    """
    gan_loss = tfgan.gan_loss(
        gan_model,
        generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        gradient_penalty_weight=kwargs['gradient_penalty_weight'],
        gradient_penalty_target=kwargs['gradient_penalty_target'],
        gradient_penalty_epsilon=0.0)

    # Mean squared discriminator output on real data; also exported as a
    # scalar summary.
    real_score_penalty = tf.reduce_mean(
        input_tensor=tf.square(gan_model.discriminator_real_outputs))
    tf.compat.v1.summary.scalar('real_score_penalty', real_score_penalty)

    penalized_discriminator_loss = (
        gan_loss.discriminator_loss +
        kwargs['real_score_penalty_weight'] * real_score_penalty)
    return gan_loss._replace(discriminator_loss=penalized_discriminator_loss)
def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
    """Checks both hook factories keep the sync-optimizer hooks.

    The sequential hooks function is expected to yield 4 hooks and the joint
    one 5; in both cases exactly the two sync optimizers must be present.
    """
    # None of the usual utilities work in eager, so this is a no-op there.
    if tf.executing_eagerly():
        return

    def wrapped_sync_optimizers(hooks):
        """Returns the optimizers held by the sync-optimizer hooks."""
        collected = []
        for hook in hooks:
            if isinstance(hook, get_sync_optimizer_hook_type()):
                collected.append(hook._sync_optimizer)
        return collected

    gan_model = create_gan_model_fn()
    gan_loss = tfgan.gan_loss(gan_model)
    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    # Sequential training hooks.
    sequential_train_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    self.assertLen(sequential_train_hooks, 4)
    sequential_sync_opts = wrapped_sync_optimizers(sequential_train_hooks)
    self.assertLen(sequential_sync_opts, 2)
    self.assertSetEqual(
        frozenset(sequential_sync_opts), frozenset((g_opt, d_opt)))

    # Joint training hooks.
    joint_train_hooks = tfgan.get_joint_train_hooks()(train_ops)
    self.assertLen(joint_train_hooks, 5)
    joint_sync_opts = wrapped_sync_optimizers(joint_train_hooks)
    self.assertLen(joint_sync_opts, 2)
    self.assertSetEqual(
        frozenset(joint_sync_opts), frozenset((g_opt, d_opt)))
def train(hparams):
    """Trains an MNIST GAN.

    Builds the input pipeline, the GANModel selected by `hparams.gan_type`,
    the loss and the train ops, then runs the joint training loop unless
    `hparams.max_number_of_steps` is 0.

    Args:
      hparams: An HParams instance containing the hyperparameters for
        training. The fields read here are `train_log_dir`, `batch_size`,
        `gan_type`, `noise_dims`, `grid_size` and `max_number_of_steps`.

    Raises:
      ValueError: If `hparams.gan_type` is not one of 'unconditional',
        'conditional' or 'infogan'.
    """
    # Fail fast on an unknown GAN type, before touching the filesystem or
    # building any part of the graph; otherwise `gan_model` would be left
    # unbound and the failure would surface later as a NameError.
    supported_gan_types = ('unconditional', 'conditional', 'infogan')
    if hparams.gan_type not in supported_gan_types:
        raise ValueError(
            'Unsupported gan_type: %r. Expected one of %r.' %
            (hparams.gan_type, supported_gan_types))

    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    # Force all input processing onto CPU in order to reserve the GPU for
    # the forward inference and back-propagation.
    with tf.name_scope('inputs'), tf.device('/cpu:0'):
        images, one_hot_labels = data_provider.provide_data(
            'train', hparams.batch_size, num_parallel_calls=4)

    # Define the GANModel tuple. Optionally, condition the GAN on the label
    # or use an InfoGAN to learn a latent representation.
    if hparams.gan_type == 'unconditional':
        gan_model = tfgan.gan_model(
            generator_fn=networks.unconditional_generator,
            discriminator_fn=networks.unconditional_discriminator,
            real_data=images,
            generator_inputs=tf.random.normal(
                [hparams.batch_size, hparams.noise_dims]))
    elif hparams.gan_type == 'conditional':
        noise = tf.random.normal([hparams.batch_size, hparams.noise_dims])
        gan_model = tfgan.gan_model(
            generator_fn=networks.conditional_generator,
            discriminator_fn=networks.conditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels))
    else:  # 'infogan' -- the only remaining value after validation above.
        cat_dim, cont_dim = 10, 2
        generator_fn = functools.partial(
            networks.infogan_generator, categorical_dim=cat_dim)
        discriminator_fn = functools.partial(
            networks.infogan_discriminator,
            categorical_dim=cat_dim,
            continuous_dim=cont_dim)
        unstructured_inputs, structured_inputs = util.get_infogan_noise(
            hparams.batch_size, cat_dim, cont_dim, hparams.noise_dims)
        gan_model = tfgan.infogan_model(
            generator_fn=generator_fn,
            discriminator_fn=discriminator_fn,
            real_data=images,
            unstructured_generator_inputs=unstructured_inputs,
            structured_generator_inputs=structured_inputs)
    tfgan.eval.add_gan_model_image_summaries(gan_model, hparams.grid_size)

    # Get the GANLoss tuple. You can pass a custom function, use one of the
    # already-implemented losses from the losses library, or use the
    # defaults.
    with tf.name_scope('loss'):
        if hparams.gan_type == 'infogan':
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.modified_generator_loss,
                discriminator_loss_fn=(
                    tfgan.losses.modified_discriminator_loss),
                mutual_information_penalty_weight=1.0,
                add_summaries=True)
        else:
            gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)
        tfgan.eval.add_regularization_loss_summaries(gan_model)

    # Get the GANTrain ops using custom optimizers. The TF1-style optimizer
    # and global-step helpers are reached through `tf.compat.v1`, matching
    # how the rest of this code accesses v1-only symbols.
    with tf.name_scope('train'):
        gen_lr, dis_lr = _learning_rate(hparams.gan_type)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.compat.v1.train.AdamOptimizer(
                gen_lr, 0.5),
            discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(
                dis_lr, 0.5),
            summarize_gradients=True,
            aggregation_method=(
                tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N))

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.compat.v1.train.get_or_create_global_step())
    ], name='status_message')
    if hparams.max_number_of_steps == 0:
        return
    tfgan.gan_train(
        train_ops,
        hooks=[
            tf.estimator.StopAtStepHook(
                num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        logdir=hparams.train_log_dir,
        get_hooks_fn=tfgan.get_joint_train_hooks(),
        save_checkpoint_secs=60)
def test_mutual_info_penalty(self, create_gan_model_fn):
    """Test mutual information penalty option.

    Only checks that building the loss with a tensor-valued penalty weight
    does not raise; the resulting loss is not inspected.
    """
    gan_model = create_gan_model_fn()
    penalty_weight = tf.constant(1.0)
    tfgan.gan_loss(
        gan_model, mutual_information_penalty_weight=penalty_weight)
def test_args_passed_in_correctly(self):
    """Checks that `add_summaries=False` reaches the custom loss functions."""

    def loss_fn(gan_model, add_summaries):
        """Constant loss that verifies the forwarded `add_summaries` flag."""
        self.assertFalse(add_summaries)
        del gan_model  # Unused.
        return 0

    model = get_gan_model()
    tfgan.gan_loss(model, loss_fn, loss_fn, add_summaries=False)
def test_no_reduction_or_add_summaries_loss(self):
    """Checks that a loss function taking only the model is accepted."""

    def loss_fn(_):
        """Constant loss accepting just one positional argument."""
        return 0

    model = get_gan_model()
    tfgan.gan_loss(model, loss_fn, loss_fn)
def test_sync_replicas(self, create_gan_model_fn, create_global_step):
    """Checks `gan_train_ops` with sync-replicas optimizers.

    Verifies that no trainable variables are added, that one sync hook per
    optimizer is populated, and that running either train op leaves the
    global step at its initial value.
    """
    # None of the usual utilities work in eager, so this is a no-op there.
    if tf.executing_eagerly():
        return

    gan_model = create_gan_model_fn()
    gan_loss = tfgan.gan_loss(gan_model)
    trainable_var_count = len(get_trainable_variables())

    if create_global_step:
        # Register a custom, non-trainable variable as the global step.
        custom_step = tf.compat.v1.get_variable(
            'custom_gstep', dtype=tf.int32, initializer=0, trainable=False)
        tf.compat.v1.add_to_collection(
            tf.compat.v1.GraphKeys.GLOBAL_STEP, custom_step)

    generator_opt = get_sync_optimizer()
    discriminator_opt = get_sync_optimizer()
    train_ops = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer=generator_opt,
        discriminator_optimizer=discriminator_opt)
    self.assertIsInstance(train_ops, tfgan.GANTrainOps)

    # No new trainable variables should have been added.
    self.assertLen(get_trainable_variables(), trainable_var_count)

    # Sync hooks should be populated in the GANTrainOps.
    self.assertLen(train_ops.train_hooks, 2)
    for hook in train_ops.train_hooks:
        self.assertIsInstance(hook, get_sync_optimizer_hook_type())
    hooked_optimizers = [
        hook._sync_optimizer for hook in train_ops.train_hooks
    ]
    self.assertSetEqual(
        frozenset(hooked_optimizers),
        frozenset((generator_opt, discriminator_opt)))

    generator_init_tokens = generator_opt.get_init_tokens_op(num_tokens=1)
    discriminator_init_tokens = discriminator_opt.get_init_tokens_op(
        num_tokens=1)

    # Check that update op is run properly.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with self.cached_session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(tf.compat.v1.local_variables_initializer())
        sess.run(generator_opt.chief_init_op)
        sess.run(discriminator_opt.chief_init_op)

        initial_step = sess.run(global_step)

        # Start required queue runner for SyncReplicasOptimizer.
        coord = tf.train.Coordinator()
        generator_threads = (
            generator_opt.get_chief_queue_runner().create_threads(
                sess, coord))
        discriminator_threads = (
            discriminator_opt.get_chief_queue_runner().create_threads(
                sess, coord))
        sess.run(generator_init_tokens)
        sess.run(discriminator_init_tokens)

        # Neither train op should increment the global step.
        sess.run(train_ops.generator_train_op)
        self.assertEqual(initial_step, sess.run(global_step))

        sess.run(train_ops.discriminator_train_op)
        self.assertEqual(initial_step, sess.run(global_step))

        coord.request_stop()
        coord.join(generator_threads + discriminator_threads)
def train(hparams):
    """Trains a CIFAR10 GAN.

    Creates the log directory if needed, builds the input pipeline, model,
    loss and train ops, and then runs `tfgan.gan_train` unless
    `hparams.max_number_of_steps` is 0.

    Args:
      hparams: An HParams instance containing the hyperparameters for
        training. The fields read directly here are `train_log_dir`,
        `ps_replicas`, `batch_size`, `max_number_of_steps`, `master` and
        `task`; the whole object is also passed to `_get_optimizers`.
    """
    log_dir_exists = tf.io.gfile.exists(hparams.train_log_dir)
    if not log_dir_exists:
        tf.io.gfile.makedirs(hparams.train_log_dir)

    # NOTE(review): the whole graph build and the training call are kept
    # inside the replica device scope -- confirm against the intended scope.
    device_setter = tf.compat.v1.train.replica_device_setter(
        hparams.ps_replicas)
    with tf.device(device_setter):
        # Force all input processing onto CPU in order to reserve the GPU
        # for the forward inference and back-propagation.
        with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
            images, _ = data_provider.provide_data(
                'train', hparams.batch_size, num_parallel_calls=4)

        # Define the GANModel tuple.
        generator_fn = networks.generator
        discriminator_fn = networks.discriminator
        noise = tf.random.normal([hparams.batch_size, 64])
        gan_model = tfgan.gan_model(
            generator_fn,
            discriminator_fn,
            real_data=images,
            generator_inputs=noise)
        tfgan.eval.add_gan_model_image_summaries(gan_model)

        # Get the GANLoss tuple. Use the selected GAN loss functions.
        with tf.compat.v1.name_scope('loss'):
            gan_loss = tfgan.gan_loss(
                gan_model, gradient_penalty_weight=1.0, add_summaries=True)

        # Get the GANTrain ops using the custom optimizers and optional
        # discriminator weight clipping.
        with tf.compat.v1.name_scope('train'):
            generator_optimizer, discriminator_optimizer = _get_optimizers(
                hparams)
            train_ops = tfgan.gan_train_ops(
                gan_model,
                gan_loss,
                generator_optimizer=generator_optimizer,
                discriminator_optimizer=discriminator_optimizer,
                summarize_gradients=True)

        # Run the alternating training loop. Skip it if no steps should be
        # taken (used for graph construction tests).
        step_as_string = tf.as_string(
            tf.compat.v1.train.get_or_create_global_step())
        status_message = tf.strings.join(
            ['Starting train step: ', step_as_string],
            name='status_message')
        if hparams.max_number_of_steps == 0:
            return

        training_hooks = [
            tf.estimator.StopAtStepHook(
                num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ]
        tfgan.gan_train(
            train_ops,
            hooks=training_hooks,
            logdir=hparams.train_log_dir,
            master=hparams.master,
            is_chief=hparams.task == 0)