Esempio n. 1
0
    def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
        """Verifies both hook factories return the sync-optimizer hooks.

        Builds GAN train ops with two sync optimizers, then checks the hook
        lists produced by `train.get_sequential_train_hooks()` and
        `train.get_joint_train_hooks()`: each must have the expected length
        and contain exactly two `_SyncReplicasOptimizerHook` instances whose
        optimizers are the generator and discriminator optimizers.

        Args:
          create_gan_model_fn: Zero-argument callable returning the GAN model
            to train (presumably injected by a parameterized-test decorator
            that is not visible here).
        """
        gan_model = create_gan_model_fn()
        gan_loss = train.gan_loss(gan_model)

        g_opt = get_sync_optimizer()
        d_opt = get_sync_optimizer()
        train_ops = train.gan_train_ops(
            gan_model,
            gan_loss,
            g_opt,
            d_opt,
            summarize_gradients=True,
            colocate_gradients_with_ops=True)

        def check_hooks(hooks, expected_hook_count):
            # Same assertions for either hook factory; only the total number
            # of hooks differs between the two.
            self.assertLen(hooks, expected_hook_count)
            found_optimizers = []
            for hook in hooks:
                if isinstance(
                        hook,
                        sync_replicas_optimizer._SyncReplicasOptimizerHook):
                    found_optimizers.append(hook._sync_optimizer)
            self.assertLen(found_optimizers, 2)
            self.assertSetEqual(frozenset(found_optimizers),
                                frozenset((g_opt, d_opt)))

        check_hooks(train.get_sequential_train_hooks()(train_ops), 4)
        check_hooks(train.get_joint_train_hooks()(train_ops), 5)
def _get_estimator_spec(mode,
                        gan_model,
                        loss_fn,
                        get_eval_metric_ops_fn,
                        generator_optimizer,
                        discriminator_optimizer,
                        get_hooks_fn=None):
    """Builds the `EstimatorSpec` that corresponds to `mode`.

    Args:
      mode: One of the `model_fn_lib.ModeKeys` values. Anything that is
        neither PREDICT nor EVAL is handled as TRAIN.
      gan_model: GAN model object; its `generated_data` attribute is used as
        the predictions in PREDICT mode.
      loss_fn: Callable taking `gan_model` and returning the GAN loss. It is
        not called in PREDICT mode.
      get_eval_metric_ops_fn: Forwarded to `_get_eval_estimator_spec` in EVAL
        mode.
      generator_optimizer: An optimizer, or a zero-argument callable that
        returns one; callables are invoked only in TRAIN mode.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator.
      get_hooks_fn: Hook factory forwarded to `_get_train_estimator_spec`.
        A falsy value is replaced by
        `tfgan_train.get_sequential_train_hooks()`.

    Returns:
      The estimator spec for the requested mode.
    """
    mode_keys = model_fn_lib.ModeKeys

    # Prediction needs neither a loss nor optimizers.
    if mode == mode_keys.PREDICT:
        return model_fn_lib.EstimatorSpec(
            mode=mode, predictions=gan_model.generated_data)

    gan_loss = loss_fn(gan_model)
    if mode == mode_keys.EVAL:
        return _get_eval_estimator_spec(gan_model, gan_loss,
                                        get_eval_metric_ops_fn)

    # Remaining case: model_fn_lib.ModeKeys.TRAIN.
    # Optimizers may be given as factories so they are only built here.
    gopt = generator_optimizer
    if callable(gopt):
        gopt = gopt()
    dopt = discriminator_optimizer
    if callable(dopt):
        dopt = dopt()
    if not get_hooks_fn:
        get_hooks_fn = tfgan_train.get_sequential_train_hooks()
    return _get_train_estimator_spec(gan_model, gan_loss, gopt, dopt,
                                     get_hooks_fn)
Esempio n. 3
0
    def __init__(self,
                 generator_loss_fn,
                 discriminator_loss_fn,
                 generator_optimizer,
                 discriminator_optimizer,
                 use_loss_summaries=True,
                 get_hooks_fn=None,
                 get_eval_metric_ops_fn=None,
                 name=None):
        """`Head` for GAN training.

        Args:
          generator_loss_fn: A TFGAN loss function for the generator. Takes a
            `GANModel` and returns a scalar.
          discriminator_loss_fn: Same as `generator_loss_fn`, but for the
            discriminator.
          generator_optimizer: The optimizer for generator updates.
          discriminator_optimizer: Same as `generator_optimizer`, but for the
            discriminator updates.
          use_loss_summaries: If `True` or `False`, both loss functions are
            wrapped so they are called with `add_summaries` set to that
            value. If `None`, the loss functions are stored unchanged and
            their own defaults apply.
          get_hooks_fn: A function that takes a `GANTrainOps` tuple and
            returns a list of hooks. Defaults to
            `train.get_sequential_train_hooks()`.
          get_eval_metric_ops_fn: A function that takes a `GANModel`, and
            returns a dict of metric results keyed by name. The output of
            this function is passed into `tf.estimator.EstimatorSpec` during
            evaluation.
          name: name of the head. If provided, summary and metrics keys will
            be suffixed by `"/" + name`.

        Raises:
          TypeError: If a loss function or a non-`None` `get_hooks_fn` is not
            callable, or if `name` is neither `None` nor a string.
          ValueError: If `use_loss_summaries` is not `True`, `False` or
            `None`.
        """
        # Argument validation; the order of checks decides which error a
        # caller sees first when several arguments are invalid.
        if not callable(generator_loss_fn):
            raise TypeError('generator_loss_fn must be callable.')
        if not callable(discriminator_loss_fn):
            raise TypeError('discriminator_loss_fn must be callable.')
        if use_loss_summaries not in (True, False, None):
            raise ValueError('use_loss_summaries must be True, False or None.')
        if not (get_hooks_fn is None or callable(get_hooks_fn)):
            raise TypeError('get_hooks_fn must be callable.')
        if not (name is None or isinstance(name, str)):
            raise TypeError('name must be string.')

        hooks_fn = get_hooks_fn
        if hooks_fn is None:
            hooks_fn = tfgan_train.get_sequential_train_hooks()

        gen_loss_fn = generator_loss_fn
        disc_loss_fn = discriminator_loss_fn
        if use_loss_summaries in (True, False):
            # Bind the summary switch now so later callers only pass the
            # model.
            gen_loss_fn = functools.partial(
                gen_loss_fn, add_summaries=use_loss_summaries)
            disc_loss_fn = functools.partial(
                disc_loss_fn, add_summaries=use_loss_summaries)

        self._generator_loss_fn = gen_loss_fn
        self._discriminator_loss_fn = disc_loss_fn
        self._generator_optimizer = generator_optimizer
        self._discriminator_optimizer = discriminator_optimizer
        self._get_hooks_fn = hooks_fn
        self._get_eval_metric_ops_fn = get_eval_metric_ops_fn
        self._name = name
Esempio n. 4
0
def gan_head(generator_loss_fn,
             discriminator_loss_fn,
             generator_optimizer,
             discriminator_optimizer,
             use_loss_summaries=True,
             get_hooks_fn=None,
             get_eval_metric_ops_fn=None,
             name=None):
    """Creates a `GANHead`.

    Args:
      generator_loss_fn: A TFGAN loss function for the generator. Takes a
        `GANModel` and returns a scalar.
      discriminator_loss_fn: Same as `generator_loss_fn`, but for the
        discriminator.
      generator_optimizer: The optimizer for generator updates.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
        If `None`, uses defaults.
      get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
        list of hooks. Defaults to `train.get_sequential_train_hooks()`.
      get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns
        a dict of metric results keyed by name. The output of this function
        is passed into `tf.estimator.EstimatorSpec` during evaluation.
      name: name of the head. If provided, summary and metrics keys will be
        suffixed by `"/" + name`.

    Returns:
      An instance of `GANHead`.
    """
    if get_hooks_fn is None:
        # Build the default per call rather than in the signature: a call
        # expression used as a default value runs once, when the function is
        # defined, and the resulting object is then shared by every caller.
        get_hooks_fn = tfgan_train.get_sequential_train_hooks()
    return GANHead(generator_loss_fn=generator_loss_fn,
                   discriminator_loss_fn=discriminator_loss_fn,
                   generator_optimizer=generator_optimizer,
                   discriminator_optimizer=discriminator_optimizer,
                   use_loss_summaries=use_loss_summaries,
                   get_hooks_fn=get_hooks_fn,
                   get_eval_metric_ops_fn=get_eval_metric_ops_fn,
                   name=name)
Esempio n. 5
0
def _get_estimator_spec(mode,
                        gan_model,
                        generator_loss_fn,
                        discriminator_loss_fn,
                        get_eval_metric_ops_fn,
                        generator_optimizer,
                        discriminator_optimizer,
                        get_hooks_fn=None,
                        use_loss_summaries=True,
                        is_chief=True):
    """Builds the `EstimatorSpec` that corresponds to `mode`.

    Args:
      mode: One of the `model_fn_lib.ModeKeys` values. Anything that is
        neither PREDICT nor EVAL is handled as TRAIN.
      gan_model: GAN model object; its `generated_data` attribute is used as
        the predictions in PREDICT mode.
      generator_loss_fn: Callable invoked as
        `generator_loss_fn(gan_model, add_summaries=use_loss_summaries)`.
        Not called in PREDICT mode.
      discriminator_loss_fn: Same as `generator_loss_fn`, but for the
        discriminator.
      get_eval_metric_ops_fn: Forwarded to `_get_eval_estimator_spec` in EVAL
        mode.
      generator_optimizer: An optimizer, or a zero-argument callable that
        returns one; callables are invoked only in TRAIN mode.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator.
      get_hooks_fn: Hook factory forwarded to `_get_train_estimator_spec`.
        A falsy value is replaced by
        `tfgan_train.get_sequential_train_hooks()`.
      use_loss_summaries: Value passed as `add_summaries` to both loss
        functions.
      is_chief: Forwarded unchanged to `_get_train_estimator_spec`.

    Returns:
      The estimator spec for the requested mode.
    """
    mode_keys = model_fn_lib.ModeKeys

    # Prediction needs neither losses nor optimizers.
    if mode == mode_keys.PREDICT:
        return model_fn_lib.EstimatorSpec(
            mode=mode, predictions=gan_model.generated_data)

    gan_loss = tfgan_tuples.GANLoss(
        generator_loss=generator_loss_fn(
            gan_model, add_summaries=use_loss_summaries),
        discriminator_loss=discriminator_loss_fn(
            gan_model, add_summaries=use_loss_summaries))

    if mode == mode_keys.EVAL:
        return _get_eval_estimator_spec(gan_model, gan_loss,
                                        get_eval_metric_ops_fn)

    # Remaining case: model_fn_lib.ModeKeys.TRAIN.
    # Optimizers may be given as factories so they are only built here.
    gen_opt = generator_optimizer
    if callable(gen_opt):
        gen_opt = gen_opt()
    disc_opt = discriminator_optimizer
    if callable(disc_opt):
        disc_opt = disc_opt()
    if not get_hooks_fn:
        get_hooks_fn = tfgan_train.get_sequential_train_hooks()
    return _get_train_estimator_spec(gan_model,
                                     gan_loss,
                                     gen_opt,
                                     disc_opt,
                                     get_hooks_fn,
                                     is_chief=is_chief)