def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  model = create_gan_model_fn()
  loss = tfgan.gan_loss(model)

  # Add generator and discriminator update ops.
  with tf.compat.v1.variable_scope(model.generator_scope):
    gen_update_count = tf.compat.v1.get_variable('gen_count', initializer=0)
    gen_update_op = gen_update_count.assign_add(1)
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,
                                   gen_update_op)
  with tf.compat.v1.variable_scope(model.discriminator_scope):
    dis_update_count = tf.compat.v1.get_variable('dis_count', initializer=0)
    dis_update_op = dis_update_count.assign_add(1)
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,
                                   dis_update_op)

  # Add an update op outside the generator and discriminator scopes.
  if provide_update_ops:
    kwargs = {'update_ops': [tf.constant(1.0), gen_update_op, dis_update_op]}
  else:
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,
                                   tf.constant(1.0))
    kwargs = {}

  g_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
  d_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)

  with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
    tfgan.gan_train_ops(
        model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs)
  train_ops = tfgan.gan_train_ops(
      model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)

  with self.cached_session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    self.assertEqual(0, sess.run(gen_update_count))
    self.assertEqual(0, sess.run(dis_update_count))

    sess.run(train_ops.generator_train_op)
    self.assertEqual(1, sess.run(gen_update_count))
    self.assertEqual(0, sess.run(dis_update_count))

    sess.run(train_ops.discriminator_train_op)
    self.assertEqual(1, sess.run(gen_update_count))
    self.assertEqual(1, sess.run(dis_update_count))

def define_train_ops(gan_model, gan_loss, **kwargs):
  """Defines progressive GAN train ops.

  Args:
    gan_model: A `GANModel` namedtuple.
    gan_loss: A `GANLoss` namedtuple.
    **kwargs: A dictionary of
        'adam_beta1': A float of Adam optimizer beta1.
        'adam_beta2': A float of Adam optimizer beta2.
        'generator_learning_rate': A float of generator learning rate.
        'discriminator_learning_rate': A float of discriminator learning rate.

  Returns:
    A tuple of a `GANTrainOps` namedtuple and a list of variables tracking the
    state of the optimizers.
  """
  with tf.compat.v1.variable_scope('progressive_gan_train_ops') as var_scope:
    beta1, beta2 = kwargs['adam_beta1'], kwargs['adam_beta2']
    gen_opt = tf.compat.v1.train.AdamOptimizer(
        kwargs['generator_learning_rate'], beta1, beta2)
    dis_opt = tf.compat.v1.train.AdamOptimizer(
        kwargs['discriminator_learning_rate'], beta1, beta2)
    gan_train_ops = tfgan.gan_train_ops(gan_model, gan_loss, gen_opt, dis_opt)
  return gan_train_ops, tf.compat.v1.get_collection(
      tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=var_scope.name)

def _define_train_ops(cyclegan_model, cyclegan_loss, hparams):
  """Defines train ops that train `cyclegan_model` with `cyclegan_loss`.

  Args:
    cyclegan_model: A `CycleGANModel` namedtuple.
    cyclegan_loss: A `CycleGANLoss` namedtuple containing all losses for
      `cyclegan_model`.
    hparams: An HParams instance containing the hyperparameters for training.

  Returns:
    A `GANTrainOps` namedtuple.
  """
  gen_lr = _get_lr(hparams.generator_lr, hparams.max_number_of_steps)
  dis_lr = _get_lr(hparams.discriminator_lr, hparams.max_number_of_steps)
  gen_opt, dis_opt = _get_optimizer(gen_lr, dis_lr)
  train_ops = tfgan.gan_train_ops(
      cyclegan_model,
      cyclegan_loss,
      generator_optimizer=gen_opt,
      discriminator_optimizer=dis_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

  tf.summary.scalar('generator_lr', gen_lr)
  tf.summary.scalar('discriminator_lr', dis_lr)

  return train_ops

def _define_train_ops(cyclegan_model, cyclegan_loss):
  """Defines train ops that train `cyclegan_model` with `cyclegan_loss`.

  Args:
    cyclegan_model: A `CycleGANModel` namedtuple.
    cyclegan_loss: A `CycleGANLoss` namedtuple containing all losses for
      `cyclegan_model`.

  Returns:
    A `GANTrainOps` namedtuple.
  """
  gen_lr = _get_lr(FLAGS.generator_lr)
  dis_lr = _get_lr(FLAGS.discriminator_lr)
  gen_opt, dis_opt = _get_optimizer(gen_lr, dis_lr)
  train_ops = tfgan.gan_train_ops(
      cyclegan_model,
      cyclegan_loss,
      generator_optimizer=gen_opt,
      discriminator_optimizer=dis_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True,
      check_for_unused_update_ops=False,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

  tf.summary.scalar('generator_lr', gen_lr)
  tf.summary.scalar('discriminator_lr', dis_lr)

  return train_ops

def test_is_chief_in_train_hooks(self, is_chief):
  """Make sure is_chief is propagated correctly to sync hooks."""
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  model = create_gan_model()
  loss = tfgan.gan_loss(model)
  g_opt = get_sync_optimizer()
  d_opt = get_sync_optimizer()
  train_ops = tfgan.gan_train_ops(
      model,
      loss,
      g_opt,
      d_opt,
      is_chief=is_chief,
      summarize_gradients=True,
      colocate_gradients_with_ops=True)

  self.assertLen(train_ops.train_hooks, 2)
  for hook in train_ops.train_hooks:
    self.assertIsInstance(hook, get_sync_optimizer_hook_type())
  is_chief_list = [hook._is_chief for hook in train_ops.train_hooks]
  self.assertListEqual(is_chief_list, [is_chief, is_chief])

def test_run_helper(self, create_gan_model_fn):
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  tf.compat.v1.random.set_random_seed(1234)
  model = create_gan_model_fn()
  loss = tfgan.gan_loss(model)

  g_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
  d_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
  train_ops = tfgan.gan_train_ops(model, loss, g_opt, d_opt)

  final_step = tfgan.gan_train(
      train_ops,
      logdir='',
      hooks=[tf.estimator.StopAtStepHook(num_steps=2)])
  self.assertTrue(np.isscalar(final_step))
  self.assertEqual(2, final_step)

def test_output_type(self, create_gan_model_fn):
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  model = create_gan_model_fn()
  loss = tfgan.gan_loss(model)
  g_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
  d_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
  train_ops = tfgan.gan_train_ops(
      model,
      loss,
      g_opt,
      d_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True)

  self.assertIsInstance(train_ops, tfgan.GANTrainOps)
  # Make sure there are no training hooks populated accidentally.
  self.assertEmpty(train_ops.train_hooks)

def _define_train_ops(model, loss, gen_lr, dis_lr, beta1, beta2,
                      max_number_of_steps):
  """Defines train ops that train `stargan_model` with `stargan_loss`.

  Args:
    model: A `StarGANModel` namedtuple.
    loss: A `StarGANLoss` namedtuple containing all losses for `stargan_model`.
    gen_lr: A scalar float `Tensor` or a Python number. The Generator base
      learning rate.
    dis_lr: A scalar float `Tensor` or a Python number. The Discriminator base
      learning rate.
    beta1: A scalar float `Tensor` or a Python number. The beta1 parameter to
      the `AdamOptimizer`.
    beta2: A scalar float `Tensor` or a Python number. The beta2 parameter to
      the `AdamOptimizer`.
    max_number_of_steps: A Python number. The total number of steps to train.

  Returns:
    A `GANTrainOps` namedtuple.
  """
  gen_lr = _get_lr(gen_lr, max_number_of_steps)
  dis_lr = _get_lr(dis_lr, max_number_of_steps)
  gen_opt, dis_opt = _get_optimizer(gen_lr, dis_lr, beta1, beta2)
  train_ops = tfgan.gan_train_ops(
      model,
      loss,
      generator_optimizer=gen_opt,
      discriminator_optimizer=dis_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

  tf.summary.scalar('generator_lr', gen_lr)
  tf.summary.scalar('discriminator_lr', dis_lr)

  return train_ops

def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  model = create_gan_model_fn()
  loss = tfgan.gan_loss(model)
  g_opt = get_sync_optimizer()
  d_opt = get_sync_optimizer()
  train_ops = tfgan.gan_train_ops(
      model,
      loss,
      g_opt,
      d_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True)

  sequential_train_hooks = tfgan.get_sequential_train_hooks()(train_ops)
  self.assertLen(sequential_train_hooks, 4)
  sync_opts = [
      hook._sync_optimizer
      for hook in sequential_train_hooks
      if isinstance(hook, get_sync_optimizer_hook_type())
  ]
  self.assertLen(sync_opts, 2)
  self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

  joint_train_hooks = tfgan.get_joint_train_hooks()(train_ops)
  self.assertLen(joint_train_hooks, 5)
  sync_opts = [
      hook._sync_optimizer
      for hook in joint_train_hooks
      if isinstance(hook, get_sync_optimizer_hook_type())
  ]
  self.assertLen(sync_opts, 2)
  self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

def train(hparams): """Trains an MNIST GAN. Args: hparams: An HParams instance containing the hyperparameters for training. """ if not tf.io.gfile.exists(hparams.train_log_dir): tf.io.gfile.makedirs(hparams.train_log_dir) # Force all input processing onto CPU in order to reserve the GPU for # the forward inference and back-propagation. with tf.name_scope('inputs'), tf.device('/cpu:0'): images, one_hot_labels = data_provider.provide_data( 'train', hparams.batch_size, num_parallel_calls=4) # Define the GANModel tuple. Optionally, condition the GAN on the label or # use an InfoGAN to learn a latent representation. if hparams.gan_type == 'unconditional': gan_model = tfgan.gan_model( generator_fn=networks.unconditional_generator, discriminator_fn=networks.unconditional_discriminator, real_data=images, generator_inputs=tf.random.normal( [hparams.batch_size, hparams.noise_dims])) elif hparams.gan_type == 'conditional': noise = tf.random.normal([hparams.batch_size, hparams.noise_dims]) gan_model = tfgan.gan_model( generator_fn=networks.conditional_generator, discriminator_fn=networks.conditional_discriminator, real_data=images, generator_inputs=(noise, one_hot_labels)) elif hparams.gan_type == 'infogan': cat_dim, cont_dim = 10, 2 generator_fn = functools.partial(networks.infogan_generator, categorical_dim=cat_dim) discriminator_fn = functools.partial(networks.infogan_discriminator, categorical_dim=cat_dim, continuous_dim=cont_dim) unstructured_inputs, structured_inputs = util.get_infogan_noise( hparams.batch_size, cat_dim, cont_dim, hparams.noise_dims) gan_model = tfgan.infogan_model( generator_fn=generator_fn, discriminator_fn=discriminator_fn, real_data=images, unstructured_generator_inputs=unstructured_inputs, structured_generator_inputs=structured_inputs) tfgan.eval.add_gan_model_image_summaries(gan_model, hparams.grid_size) # Get the GANLoss tuple. You can pass a custom function, use one of the # already-implemented losses from the losses library, or use the defaults. with tf.name_scope('loss'): if hparams.gan_type == 'infogan': gan_loss = tfgan.gan_loss( gan_model, generator_loss_fn=tfgan.losses.modified_generator_loss, discriminator_loss_fn=tfgan.losses.modified_discriminator_loss, mutual_information_penalty_weight=1.0, add_summaries=True) else: gan_loss = tfgan.gan_loss(gan_model, add_summaries=True) tfgan.eval.add_regularization_loss_summaries(gan_model) # Get the GANTrain ops using custom optimizers. with tf.name_scope('train'): gen_lr, dis_lr = _learning_rate(hparams.gan_type) train_ops = tfgan.gan_train_ops( gan_model, gan_loss, generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5), discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5), summarize_gradients=True, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N) # Run the alternating training loop. Skip it if no steps should be taken # (used for graph construction tests). status_message = tf.strings.join([ 'Starting train step: ', tf.as_string(tf.train.get_or_create_global_step()) ], name='status_message') if hparams.max_number_of_steps == 0: return tfgan.gan_train( train_ops, hooks=[ tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps), tf.estimator.LoggingTensorHook([status_message], every_n_iter=10) ], logdir=hparams.train_log_dir, get_hooks_fn=tfgan.get_joint_train_hooks(), save_checkpoint_secs=60)
def test_sync_replicas(self, create_gan_model_fn, create_global_step):
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  model = create_gan_model_fn()
  loss = tfgan.gan_loss(model)
  num_trainable_vars = len(get_trainable_variables())

  if create_global_step:
    gstep = tf.compat.v1.get_variable(
        'custom_gstep', dtype=tf.int32, initializer=0, trainable=False)
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.GLOBAL_STEP, gstep)

  g_opt = get_sync_optimizer()
  d_opt = get_sync_optimizer()
  train_ops = tfgan.gan_train_ops(
      model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
  self.assertIsInstance(train_ops, tfgan.GANTrainOps)
  # No new trainable variables should have been added.
  self.assertLen(get_trainable_variables(), num_trainable_vars)

  # Sync hooks should be populated in the GANTrainOps.
  self.assertLen(train_ops.train_hooks, 2)
  for hook in train_ops.train_hooks:
    self.assertIsInstance(hook, get_sync_optimizer_hook_type())
  sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
  self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

  g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
  d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

  # Check that update op is run properly.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  with self.cached_session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(tf.compat.v1.local_variables_initializer())

    sess.run(g_opt.chief_init_op)
    sess.run(d_opt.chief_init_op)

    gstep_before = sess.run(global_step)

    # Start required queue runner for SyncReplicasOptimizer.
    coord = tf.train.Coordinator()
    g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
    d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

    sess.run(g_sync_init_op)
    sess.run(d_sync_init_op)

    sess.run(train_ops.generator_train_op)
    # Check that global step wasn't incremented.
    self.assertEqual(gstep_before, sess.run(global_step))

    sess.run(train_ops.discriminator_train_op)
    # Check that global step wasn't incremented.
    self.assertEqual(gstep_before, sess.run(global_step))

    coord.request_stop()
    coord.join(g_threads + d_threads)

def train(hparams): """Trains a CIFAR10 GAN. Args: hparams: An HParams instance containing the hyperparameters for training. """ if not tf.io.gfile.exists(hparams.train_log_dir): tf.io.gfile.makedirs(hparams.train_log_dir) with tf.device( tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)): # Force all input processing onto CPU in order to reserve the GPU for # the forward inference and back-propagation. with tf.compat.v1.name_scope('inputs'): with tf.device('/cpu:0'): images, _ = data_provider.provide_data('train', hparams.batch_size, num_parallel_calls=4) # Define the GANModel tuple. generator_fn = networks.generator discriminator_fn = networks.discriminator generator_inputs = tf.random.normal([hparams.batch_size, 64]) gan_model = tfgan.gan_model(generator_fn, discriminator_fn, real_data=images, generator_inputs=generator_inputs) tfgan.eval.add_gan_model_image_summaries(gan_model) # Get the GANLoss tuple. Use the selected GAN loss functions. with tf.compat.v1.name_scope('loss'): gan_loss = tfgan.gan_loss(gan_model, gradient_penalty_weight=1.0, add_summaries=True) # Get the GANTrain ops using the custom optimizers and optional # discriminator weight clipping. with tf.compat.v1.name_scope('train'): gen_opt, dis_opt = _get_optimizers(hparams) train_ops = tfgan.gan_train_ops(gan_model, gan_loss, generator_optimizer=gen_opt, discriminator_optimizer=dis_opt, summarize_gradients=True) # Run the alternating training loop. Skip it if no steps should be taken # (used for graph construction tests). status_message = tf.strings.join([ 'Starting train step: ', tf.as_string(tf.compat.v1.train.get_or_create_global_step()) ], name='status_message') if hparams.max_number_of_steps == 0: return tfgan.gan_train(train_ops, hooks=([ tf.estimator.StopAtStepHook( num_steps=hparams.max_number_of_steps), tf.estimator.LoggingTensorHook([status_message], every_n_iter=10) ]), logdir=hparams.train_log_dir, master=hparams.master, is_chief=hparams.task == 0)