def _gan_train_ops(self, generator_add, discriminator_add):
    """Build a `GANTrainOps` whose train ops only bump the global step.

    Running the generator op adds `generator_add` to the global step and
    running the discriminator op adds `discriminator_add`, so the value of the
    step can be used to count how many times the ops were run.

    Args:
        generator_add: Amount added to the global step by the generator op.
        discriminator_add: Amount added to the global step by the
            discriminator op.

    Returns:
        A `namedtuples.GANTrainOps` built from three `assign_add` ops on a
        newly created global step.
    """
    global_step = training_util.create_global_step()

    # NOTE: `use_locking=True` is required to avoid race conditions with
    # joint training.
    gen_op = global_step.assign_add(generator_add, use_locking=True)
    disc_op = global_step.assign_add(discriminator_add, use_locking=True)
    inc_op = global_step.assign_add(1)

    return namedtuples.GANTrainOps(
        generator_train_op=gen_op,
        discriminator_train_op=disc_op,
        global_step_inc_op=inc_op)
def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
    """Run one sequential GAN train step (3 generator / 4 discriminator runs).

    The train ops are plain constants, so the loss reported by
    `slim_learning.train` can be checked against a fixed number.
    """
    global_step = training_util.create_global_step()

    gen_value = constant_op.constant(3.0)
    disc_value = constant_op.constant(2.0)
    gan_ops = namedtuples.GANTrainOps(
        generator_train_op=gen_value,
        discriminator_train_op=disc_value,
        global_step_inc_op=global_step.assign_add(1))

    steps_per_call = namedtuples.GANTrainSteps(
        generator_train_steps=3, discriminator_train_steps=4)
    step_fn = train.get_sequential_train_steps(steps_per_call)

    loss_value = slim_learning.train(
        train_op=gan_ops,
        logdir='',
        global_step=global_step,
        number_of_steps=1,
        train_step_fn=step_fn)

    self.assertTrue(np.isscalar(loss_value))
    # 17.0 == 3 * 3.0 + 4 * 2.0; presumably the step function sums the value
    # of every op run -- confirm against `get_sequential_train_steps`.
    self.assertEqual(17.0, loss_value)
def gan_train_ops(
        model,  # GANModel
        loss,  # GANLoss
        generator_optimizer,
        discriminator_optimizer,
        # Optional check flags.
        check_for_unused_update_ops=True,
        # Optional args to pass directly to the `create_train_op`.
        **kwargs):
    """Creates the train ops for a generator/discriminator pair.

    This is the highest-level call in TFGAN. It is assembled from functions
    that can also be called individually when more control over a part of the
    GAN training process is needed.

    Args:
        model: A GANModel.
        loss: A GANLoss.
        generator_optimizer: The optimizer for generator updates.
        discriminator_optimizer: The optimizer for the discriminator updates.
        check_for_unused_update_ops: If `True`, throws an exception if there
            are update ops outside of the generator or discriminator scopes.
        **kwargs: Keyword args to pass directly to
            `training.create_train_op` for both the generator and
            discriminator train op.

    Returns:
        A GANTrainOps tuple holding the generator train op, the discriminator
        train op and an op that increments the global step by one.
    """
    # Create global step increment op.
    global_step = training_util.get_or_create_global_step()
    increment_step = global_step.assign_add(1)

    def _sync_replicas_step(optimizer, variable_name):
        """Returns a dummy step variable for sync-replicas optimizers."""
        if not isinstance(optimizer,
                          sync_replicas_optimizer.SyncReplicasOptimizer):
            return None
        # TODO(joelshor): Figure out a way to get this work without including
        # the dummy global step in the checkpoint.
        # WARNING: Making this variable a local variable causes sync replicas
        # to hang forever.
        return variable_scope.get_variable(
            variable_name,
            shape=[],
            dtype=global_step.dtype.base_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])

    # Get generator and discriminator update ops. We split them so that update
    # ops aren't accidentally run multiple times. For now, throw an error if
    # there are update ops that aren't associated with either the generator or
    # the discriminator. Might modify the `kwargs` dictionary.
    gen_update_ops, dis_update_ops = _get_update_ops(
        kwargs,
        model.generator_scope.name,
        model.discriminator_scope.name,
        check_for_unused_update_ops)

    # Generator: the train op gets its own dummy step (or None), kept equal to
    # the real global step through an extra update op.
    gen_step = _sync_replicas_step(
        generator_optimizer, 'dummy_global_step_generator')
    if gen_step is not None:
        gen_update_ops += [gen_step.assign(global_step)]
    with ops.name_scope('generator_train'):
        generator_op = training.create_train_op(
            total_loss=loss.generator_loss,
            optimizer=generator_optimizer,
            variables_to_train=model.generator_variables,
            global_step=gen_step,
            update_ops=gen_update_ops,
            **kwargs)

    # Discriminator: same treatment as the generator.
    dis_step = _sync_replicas_step(
        discriminator_optimizer, 'dummy_global_step_discriminator')
    if dis_step is not None:
        dis_update_ops += [dis_step.assign(global_step)]
    with ops.name_scope('discriminator_train'):
        discriminator_op = training.create_train_op(
            total_loss=loss.discriminator_loss,
            optimizer=discriminator_optimizer,
            variables_to_train=model.discriminator_variables,
            global_step=dis_step,
            update_ops=dis_update_ops,
            **kwargs)

    return namedtuples.GANTrainOps(
        generator_op, discriminator_op, increment_step)
def gan_train_ops(
        model,
        loss,
        generator_optimizer,
        discriminator_optimizer,
        check_for_unused_update_ops=True,
        # Optional args to pass directly to the `create_train_op`.
        **kwargs):
    """Creates the train ops for a GAN, including CycleGAN models.

    This is the highest-level call in TFGAN. It is assembled from functions
    that can also be called individually when more control over a part of the
    GAN training process is needed.

    Args:
        model: A GANModel. When it is a `namedtuples.CycleGANModel`, train ops
            are built separately for `model.model_x2y` and `model.model_y2x`.
        loss: A GANLoss. For a CycleGAN model, `loss.loss_x2y` and
            `loss.loss_y2x` are used.
        generator_optimizer: The optimizer for generator updates.
        discriminator_optimizer: The optimizer for the discriminator updates.
        check_for_unused_update_ops: If `True`, throws an exception if there
            are update ops outside of the generator or discriminator scopes.
        **kwargs: Keyword args to pass directly to
            `training.create_train_op` for both the generator and
            discriminator train op.

    Returns:
        A GANTrainOps tuple holding the generator train op, the discriminator
        train op and an op that increments the global step by one. For a
        CycleGAN model the first two entries are `(x2y, y2x)` pairs of ops.
    """
    if isinstance(model, namedtuples.CycleGANModel):
        # Every argument except `model` and `loss` is forwarded unchanged to
        # the two per-direction calls.
        shared_args = {
            'generator_optimizer': generator_optimizer,
            'discriminator_optimizer': discriminator_optimizer,
            'check_for_unused_update_ops': check_for_unused_update_ops,
        }
        shared_args.update(kwargs)

        with ops.name_scope('cyclegan_x2y_train'):
            x2y_ops = gan_train_ops(
                model.model_x2y, loss.loss_x2y, **shared_args)
        with ops.name_scope('cyclegan_y2x_train'):
            y2x_ops = gan_train_ops(
                model.model_y2x, loss.loss_y2x, **shared_args)

        generator_ops = (
            x2y_ops.generator_train_op, y2x_ops.generator_train_op)
        discriminator_ops = (
            x2y_ops.discriminator_train_op, y2x_ops.discriminator_train_op)
        return namedtuples.GANTrainOps(
            generator_ops,
            discriminator_ops,
            training_util.get_or_create_global_step().assign_add(1))

    # Create global step increment op.
    global_step = training_util.get_or_create_global_step()
    increment_step = global_step.assign_add(1)

    def _sync_replicas_step(optimizer, variable_name):
        """Returns a dummy step variable for sync-replicas optimizers."""
        if not isinstance(optimizer,
                          sync_replicas_optimizer.SyncReplicasOptimizer):
            return None
        # TODO(joelshor): Figure out a way to get this work without including
        # the dummy global step in the checkpoint.
        # (id:885, https://github.com/imdone/tensorflow/issues/886)
        # WARNING: Making this variable a local variable causes sync replicas
        # to hang forever.
        return variable_scope.get_variable(
            variable_name,
            shape=[],
            dtype=global_step.dtype.base_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])

    # Get generator and discriminator update ops. We split them so that update
    # ops aren't accidentally run multiple times. For now, throw an error if
    # there are update ops that aren't associated with either the generator or
    # the discriminator. Might modify the `kwargs` dictionary.
    gen_update_ops, dis_update_ops = _get_update_ops(
        kwargs,
        model.generator_scope.name,
        model.discriminator_scope.name,
        check_for_unused_update_ops)

    # Generator: the train op gets its own dummy step (or None), kept equal to
    # the real global step through an extra update op.
    gen_step = _sync_replicas_step(
        generator_optimizer, 'dummy_global_step_generator')
    if gen_step is not None:
        gen_update_ops += [gen_step.assign(global_step)]
    with ops.name_scope('generator_train'):
        generator_op = training.create_train_op(
            total_loss=loss.generator_loss,
            optimizer=generator_optimizer,
            variables_to_train=model.generator_variables,
            global_step=gen_step,
            update_ops=gen_update_ops,
            **kwargs)

    # Discriminator: same treatment as the generator.
    dis_step = _sync_replicas_step(
        discriminator_optimizer, 'dummy_global_step_discriminator')
    if dis_step is not None:
        dis_update_ops += [dis_step.assign(global_step)]
    with ops.name_scope('discriminator_train'):
        discriminator_op = training.create_train_op(
            total_loss=loss.discriminator_loss,
            optimizer=discriminator_optimizer,
            variables_to_train=model.discriminator_variables,
            global_step=dis_step,
            update_ops=dis_update_ops,
            **kwargs)

    return namedtuples.GANTrainOps(
        generator_op, discriminator_op, increment_step)