Example #1
    def _gan_train_ops(self, generator_add, discriminator_add):
        step = training_util.create_global_step()
        # Increment the global step every time a train op is run, so the test
        # can count how many times each op was run.
        # NOTE: `use_locking=True` is required to avoid race conditions with
        # joint training.
        train_ops = namedtuples.GANTrainOps(
            generator_train_op=step.assign_add(generator_add, use_locking=True),
            discriminator_train_op=step.assign_add(
                discriminator_add, use_locking=True),
            global_step_inc_op=step.assign_add(1))
        return train_ops
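Because each train op above is just an `assign_add` on the global step, the step value after running encodes how often each op ran: N generator runs and M discriminator runs leave the step at N * generator_add + M * discriminator_add. As a hypothetical illustration, with `_gan_train_ops(10, 100)`, two generator steps and three discriminator steps would leave the global step at 2 * 10 + 3 * 100 = 320, which is how the surrounding test can recover the per-op run counts.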
Example #2
    def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
        step = training_util.create_global_step()
        train_ops = namedtuples.GANTrainOps(
            generator_train_op=constant_op.constant(3.0),
            discriminator_train_op=constant_op.constant(2.0),
            global_step_inc_op=step.assign_add(1))
        train_steps = namedtuples.GANTrainSteps(generator_train_steps=3,
                                                discriminator_train_steps=4)

        final_loss = slim_learning.train(
            train_op=train_ops,
            logdir='',
            global_step=step,
            number_of_steps=1,
            train_step_fn=train.get_sequential_train_steps(train_steps))
        self.assertTrue(np.isscalar(final_loss))
        self.assertEqual(17.0, final_loss)
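The expected value of 17.0 follows directly from the constants and the configured step counts: `get_sequential_train_steps` runs the generator train op 3 times and the discriminator train op 4 times per call and sums the values they return, so the final loss is 3 * 3.0 + 4 * 2.0 = 17.0.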
Example #3
def gan_train_ops(
        model,  # GANModel
        loss,  # GANLoss
        generator_optimizer,
        discriminator_optimizer,
        # Optional check flags.
        check_for_unused_update_ops=True,
        # Optional args to pass directly to the `create_train_op`.
        **kwargs):
    """Returns GAN train ops.

  The highest-level call in TFGAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    model: A GANModel.
    loss: A GANLoss.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    **kwargs: Keyword args to pass directly to
      `training.create_train_op` for both the generator and
      discriminator train op.

  Returns:
    A GANTrainOps tuple of (generator_train_op, discriminator_train_op,
    global_step_inc_op) that can be used to train a generator/discriminator
    pair.
  """
    # Create global step increment op.
    global_step = training_util.get_or_create_global_step()
    global_step_inc = global_step.assign_add(1)

    # Get generator and discriminator update ops. We split them so that update
    # ops aren't accidentally run multiple times. For now, throw an error if
    # there are update ops that aren't associated with either the generator or
    # the discriminator. Might modify the `kwargs` dictionary.
    gen_update_ops, dis_update_ops = _get_update_ops(
        kwargs, model.generator_scope.name, model.discriminator_scope.name,
        check_for_unused_update_ops)

    generator_global_step = None
    if isinstance(generator_optimizer,
                  sync_replicas_optimizer.SyncReplicasOptimizer):
        # TODO(joelshor): Figure out a way to get this to work without
        # including the dummy global step in the checkpoint.
        # WARNING: Making this variable a local variable causes sync replicas to
        # hang forever.
        generator_global_step = variable_scope.get_variable(
            'dummy_global_step_generator',
            shape=[],
            dtype=global_step.dtype.base_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        gen_update_ops += [generator_global_step.assign(global_step)]
    with ops.name_scope('generator_train'):
        gen_train_op = training.create_train_op(
            total_loss=loss.generator_loss,
            optimizer=generator_optimizer,
            variables_to_train=model.generator_variables,
            global_step=generator_global_step,
            update_ops=gen_update_ops,
            **kwargs)

    discriminator_global_step = None
    if isinstance(discriminator_optimizer,
                  sync_replicas_optimizer.SyncReplicasOptimizer):
        # See comment above `generator_global_step`.
        discriminator_global_step = variable_scope.get_variable(
            'dummy_global_step_discriminator',
            shape=[],
            dtype=global_step.dtype.base_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        dis_update_ops += [discriminator_global_step.assign(global_step)]
    with ops.name_scope('discriminator_train'):
        disc_train_op = training.create_train_op(
            total_loss=loss.discriminator_loss,
            optimizer=discriminator_optimizer,
            variables_to_train=model.discriminator_variables,
            global_step=discriminator_global_step,
            update_ops=dis_update_ops,
            **kwargs)

    return namedtuples.GANTrainOps(gen_train_op, disc_train_op,
                                   global_step_inc)
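For context, here is a minimal usage sketch of this function through the public TF 1.x `tf.contrib.gan` API. The model functions `generator_fn` and `discriminator_fn`, the input tensors `real_images` and `noise`, and the optimizer settings are hypothetical placeholders, not part of the example above.

import tensorflow as tf

tfgan = tf.contrib.gan  # assumes TF 1.x with the contrib GAN library available

# Build the model and loss (generator_fn/discriminator_fn and the input
# tensors real_images/noise are hypothetical and not defined here).
gan_model = tfgan.gan_model(
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    real_data=real_images,
    generator_inputs=noise)
gan_loss = tfgan.gan_loss(gan_model)

# gan_train_ops returns the GANTrainOps namedtuple constructed above.
train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5))

# The ops can then be driven with alternating generator/discriminator steps.
tfgan.gan_train(
    train_ops,
    logdir='/tmp/tfgan_logdir',
    hooks=[tf.train.StopAtStepHook(num_steps=10000)])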
Example #4
def gan_train_ops(
        model,
        loss,
        generator_optimizer,
        discriminator_optimizer,
        check_for_unused_update_ops=True,
        # Optional args to pass directly to the `create_train_op`.
        **kwargs):
    """Returns GAN train ops.

  The highest-level call in TFGAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    model: A GANModel.
    loss: A GANLoss.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    **kwargs: Keyword args to pass directly to
      `training.create_train_op` for both the generator and
      discriminator train op.

  Returns:
    A GANTrainOps tuple of (generator_train_op, discriminator_train_op,
    global_step_inc_op) that can be used to train a generator/discriminator
    pair.
  """
    if isinstance(model, namedtuples.CycleGANModel):
        # Get and store all arguments other than `model` and `loss` from locals.
        # The contents of locals() should not be modified, and changes may not
        # affect the underlying values, so make a copy.
        # https://docs.python.org/2/library/functions.html#locals
        saved_params = dict(locals())
        saved_params.pop('model', None)
        saved_params.pop('loss', None)
        kwargs = saved_params.pop('kwargs', {})
        saved_params.update(kwargs)
        with ops.name_scope('cyclegan_x2y_train'):
            train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                          **saved_params)
        with ops.name_scope('cyclegan_y2x_train'):
            train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                          **saved_params)
        return namedtuples.GANTrainOps(
            (train_ops_x2y.generator_train_op,
             train_ops_y2x.generator_train_op),
            (train_ops_x2y.discriminator_train_op,
             train_ops_y2x.discriminator_train_op),
            training_util.get_or_create_global_step().assign_add(1))

    # Create global step increment op.
    global_step = training_util.get_or_create_global_step()
    global_step_inc = global_step.assign_add(1)

    # Get generator and discriminator update ops. We split them so that update
    # ops aren't accidentally run multiple times. For now, throw an error if
    # there are update ops that aren't associated with either the generator or
    # the discriminator. Might modify the `kwargs` dictionary.
    gen_update_ops, dis_update_ops = _get_update_ops(
        kwargs, model.generator_scope.name, model.discriminator_scope.name,
        check_for_unused_update_ops)

    generator_global_step = None
    if isinstance(generator_optimizer,
                  sync_replicas_optimizer.SyncReplicasOptimizer):
        # TODO(joelshor): Figure out a way to get this to work without
        # including the dummy global step in the checkpoint.
        # https://github.com/imdone/tensorflow/issues/886
        # WARNING: Making this variable a local variable causes sync replicas to
        # hang forever.
        generator_global_step = variable_scope.get_variable(
            'dummy_global_step_generator',
            shape=[],
            dtype=global_step.dtype.base_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        gen_update_ops += [generator_global_step.assign(global_step)]
    with ops.name_scope('generator_train'):
        gen_train_op = training.create_train_op(
            total_loss=loss.generator_loss,
            optimizer=generator_optimizer,
            variables_to_train=model.generator_variables,
            global_step=generator_global_step,
            update_ops=gen_update_ops,
            **kwargs)

    discriminator_global_step = None
    if isinstance(discriminator_optimizer,
                  sync_replicas_optimizer.SyncReplicasOptimizer):
        # See comment above `generator_global_step`.
        discriminator_global_step = variable_scope.get_variable(
            'dummy_global_step_discriminator',
            shape=[],
            dtype=global_step.dtype.base_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        dis_update_ops += [discriminator_global_step.assign(global_step)]
    with ops.name_scope('discriminator_train'):
        disc_train_op = training.create_train_op(
            total_loss=loss.discriminator_loss,
            optimizer=discriminator_optimizer,
            variables_to_train=model.discriminator_variables,
            global_step=discriminator_global_step,
            update_ops=dis_update_ops,
            **kwargs)

    return namedtuples.GANTrainOps(gen_train_op, disc_train_op,
                                   global_step_inc)
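One difference worth noting in the CycleGAN branch: each train-op field of the returned GANTrainOps holds a pair of ops, one per translation direction, so callers can unpack them as in this hypothetical snippet.

# Hypothetical illustration for a CycleGANModel: the generator and
# discriminator fields each hold an (x2y, y2x) pair of train ops.
gen_x2y_op, gen_y2x_op = train_ops.generator_train_op
dis_x2y_op, dis_y2x_op = train_ops.discriminator_train_op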