def _get_estimator_spec(
    mode, gan_model, generator_loss_fn, discriminator_loss_fn,
    get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
    get_hooks_fn=None):
  """Builds the `EstimatorSpec` that corresponds to `mode`.

  Args:
    mode: A `model_fn_lib.ModeKeys` value. PREDICT and EVAL are handled
      explicitly; every other value is treated as TRAIN.
    gan_model: The GANModel tuple produced by the model function.
    generator_loss_fn: Callable taking `gan_model` and returning the generator
      loss.
    discriminator_loss_fn: Callable taking `gan_model` and returning the
      discriminator loss.
    get_eval_metric_ops_fn: Forwarded unchanged to `_get_eval_estimator_spec`.
    generator_optimizer: An optimizer, or a callable that builds one when
      invoked with no arguments.
    discriminator_optimizer: Same as `generator_optimizer`, for the
      discriminator.
    get_hooks_fn: Optional hooks factory for training. When falsy,
      `tfgan_train.get_sequential_train_hooks()` is used.

  Returns:
    An `EstimatorSpec` for PREDICT, otherwise whatever
    `_get_eval_estimator_spec` / `_get_train_estimator_spec` returns.
  """
  if mode == model_fn_lib.ModeKeys.PREDICT:
    return model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)

  # Both EVAL and TRAIN need the (generator, discriminator) loss pair.
  loss_pair = tfgan_tuples.GANLoss(
      generator_loss=generator_loss_fn(gan_model),
      discriminator_loss=discriminator_loss_fn(gan_model))

  if mode == model_fn_lib.ModeKeys.EVAL:
    return _get_eval_estimator_spec(
        gan_model, loss_pair, get_eval_metric_ops_fn)

  # Remaining case: TRAIN. Optimizers may be passed lazily as factories.
  if callable(generator_optimizer):
    gen_opt = generator_optimizer()
  else:
    gen_opt = generator_optimizer
  if callable(discriminator_optimizer):
    dis_opt = discriminator_optimizer()
  else:
    dis_opt = discriminator_optimizer
  hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
  return _get_train_estimator_spec(
      gan_model, loss_pair, gen_opt, dis_opt, hooks_fn)
Exemplo n.º 2
0
def dummy_loss_fn(gan_model):
    """Toy loss used for testing: the same scalar for generator and discriminator.

    The scalar is the sum of two reduced differences:
      * discriminator domain predictions on real minus generated data, and
      * real input data minus generated data.

    Args:
      gan_model: A GAN model tuple exposing
        `discriminator_input_data_domain_predication`,
        `discriminator_generated_data_domain_predication`, `input_data` and
        `generated_data`.

    Returns:
      A `GANLoss` whose generator and discriminator entries are the same value.
    """
    domain_term = math_ops.reduce_sum(
        gan_model.discriminator_input_data_domain_predication -
        gan_model.discriminator_generated_data_domain_predication)
    data_term = math_ops.reduce_sum(
        gan_model.input_data - gan_model.generated_data)
    combined = domain_term + data_term
    return tfgan_tuples.GANLoss(combined, combined)
Exemplo n.º 3
0
def _get_estimator_spec(
    mode, gan_model, generator_loss_fn, discriminator_loss_fn,
    get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
    get_hooks_fn=None, use_loss_summaries=True, is_chief=True):
  """Builds the `EstimatorSpec` that corresponds to `mode`.

  Args:
    mode: A `model_fn_lib.ModeKeys` value. PREDICT and EVAL are handled
      explicitly; every other value is treated as TRAIN.
    gan_model: The GANModel tuple produced by the model function.
    generator_loss_fn: Callable `(gan_model, add_summaries=...)` returning the
      generator loss.
    discriminator_loss_fn: Callable `(gan_model, add_summaries=...)` returning
      the discriminator loss.
    get_eval_metric_ops_fn: Forwarded unchanged to `_get_eval_estimator_spec`.
    generator_optimizer: An optimizer, or a callable that builds one when
      invoked with no arguments.
    discriminator_optimizer: Same as `generator_optimizer`, for the
      discriminator.
    get_hooks_fn: Optional hooks factory for training. When falsy,
      `tfgan_train.get_sequential_train_hooks()` is used.
    use_loss_summaries: Passed as `add_summaries` to both loss functions.
    is_chief: Forwarded as `is_chief` to `_get_train_estimator_spec`.

  Returns:
    An `EstimatorSpec` for PREDICT, otherwise whatever
    `_get_eval_estimator_spec` / `_get_train_estimator_spec` returns.
  """
  if mode == model_fn_lib.ModeKeys.PREDICT:
    return model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)

  g_loss = generator_loss_fn(gan_model, add_summaries=use_loss_summaries)
  d_loss = discriminator_loss_fn(gan_model, add_summaries=use_loss_summaries)
  loss_pair = tfgan_tuples.GANLoss(generator_loss=g_loss,
                                   discriminator_loss=d_loss)

  if mode == model_fn_lib.ModeKeys.EVAL:
    return _get_eval_estimator_spec(
        gan_model, loss_pair, get_eval_metric_ops_fn)

  # Remaining case: TRAIN. Optimizers may be passed lazily as factories.
  gen_opt = (generator_optimizer() if callable(generator_optimizer)
             else generator_optimizer)
  dis_opt = (discriminator_optimizer() if callable(discriminator_optimizer)
             else discriminator_optimizer)
  hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
  return _get_train_estimator_spec(
      gan_model, loss_pair, gen_opt, dis_opt, hooks_fn, is_chief=is_chief)
Exemplo n.º 4
0
def _get_estimator_spec(mode, gan_model, generator_loss_fn,
                        discriminator_loss_fn, get_eval_metric_ops_fn,
                        generator_optimizer, discriminator_optimizer,
                        joint_train, is_on_tpu, gan_train_steps):
    """Builds the TPU estimator spec that corresponds to `mode`.

    Args:
      mode: A `model_fn_lib.ModeKeys` value. PREDICT and EVAL are handled
        explicitly; every other value is treated as TRAIN.
      gan_model: The GANModel tuple produced by the model function.
      generator_loss_fn: Callable `(gan_model, add_summaries=..., reduction=...)`
        returning the generator loss (`reduction` is only passed in EVAL).
      discriminator_loss_fn: Same contract as `generator_loss_fn`, for the
        discriminator.
      get_eval_metric_ops_fn: Forwarded unchanged to `_get_eval_estimator_spec`.
      generator_optimizer: An optimizer, or a callable that builds one when
        invoked with no arguments.
      discriminator_optimizer: Same as `generator_optimizer`, for the
        discriminator.
      joint_train: Forwarded to `_get_train_estimator_spec`.
      is_on_tpu: When true, the reduced losses are built without summaries.
      gan_train_steps: Forwarded to `_get_train_estimator_spec`.

    Returns:
      A `tpu_estimator.TPUEstimatorSpec` for PREDICT, otherwise whatever
      `_get_eval_estimator_spec` / `_get_train_estimator_spec` returns.
    """
    if mode == model_fn_lib.ModeKeys.PREDICT:
        return tpu_estimator.TPUEstimatorSpec(
            mode=mode,
            predictions={'generated_data': gan_model.generated_data})

    # EVAL and TRAIN share the reduced loss pair; summaries only off-TPU.
    summarize = not is_on_tpu
    reduced_loss = tfgan_tuples.GANLoss(
        generator_loss=generator_loss_fn(gan_model, add_summaries=summarize),
        discriminator_loss=discriminator_loss_fn(
            gan_model, add_summaries=summarize))

    if mode == model_fn_lib.ModeKeys.EVAL:
        # Eval losses for metrics must preserve batch dimension.
        no_reduction = losses.Reduction.NONE
        per_example_loss = tfgan_tuples.GANLoss(
            generator_loss=generator_loss_fn(
                gan_model, add_summaries=False, reduction=no_reduction),
            discriminator_loss=discriminator_loss_fn(
                gan_model, add_summaries=False, reduction=no_reduction))
        return _get_eval_estimator_spec(gan_model, reduced_loss,
                                        per_example_loss,
                                        get_eval_metric_ops_fn)

    # TRAIN: instantiate optimizers passed as factories. This function does
    # not wrap them; on TPU the caller is expected to supply
    # `CrossShardOptimizer`s.
    if callable(generator_optimizer):
        generator_optimizer = generator_optimizer()
    if callable(discriminator_optimizer):
        discriminator_optimizer = discriminator_optimizer()
    return _get_train_estimator_spec(gan_model, reduced_loss,
                                     generator_optimizer,
                                     discriminator_optimizer, joint_train,
                                     gan_train_steps)
Exemplo n.º 5
0
    def create_loss(self, features, mode, logits, labels):
        """Returns a GANLoss tuple from the provided GANModel.

        See `Head` for more details.

        Args:
          features: Input `dict` of `Tensor` objects. Unused.
          mode: Estimator's `ModeKeys`. Unused.
          logits: A GANModel tuple (this head passes the model in the
            `logits` slot).
          labels: Must be `None`.

        Returns:
          A GANLoss tuple built from `self._generator_loss_fn` and
          `self._discriminator_loss_fn`.
        """
        _validate_logits_and_labels(logits, labels)
        del features, mode, labels  # Not needed by this head.
        model = logits  # For this head the "logits" are the GANModel itself.
        gen_loss = self._generator_loss_fn(model)
        disc_loss = self._discriminator_loss_fn(model)
        return tfgan_tuples.GANLoss(generator_loss=gen_loss,
                                    discriminator_loss=disc_loss)
Exemplo n.º 6
0
def gan_loss(
        # GANModel.
        model,
        # Loss functions.
        generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
        # Auxiliary losses.
        gradient_penalty_weight=None,
        gradient_penalty_epsilon=1e-10,
        mutual_information_penalty_weight=None,
        aux_cond_generator_weight=None,
        aux_cond_discriminator_weight=None,
        # Options.
        add_summaries=True):
    """Returns losses necessary to train generator and discriminator.

    Args:
      model: A GANModel tuple.
      generator_loss_fn: The loss function on the generator. Takes a GANModel
        tuple.
      discriminator_loss_fn: The loss function on the discriminator. Takes a
        GANModel tuple.
      gradient_penalty_weight: If not `None`, must be a non-negative Python
        number or Tensor indicating how much to weight the gradient penalty.
        See https://arxiv.org/pdf/1704.00028.pdf for more details.
      gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
        small positive value used by the gradient penalty function for
        numerical stability. Note some applications will need to increase this
        value to avoid NaNs.
      mutual_information_penalty_weight: If not `None`, must be a non-negative
        Python number or Tensor indicating how much to weight the mutual
        information penalty. The penalty is added to both the generator and
        the discriminator loss. See https://arxiv.org/abs/1606.03657 for more
        details.
      aux_cond_generator_weight: If not None: add a classification loss as in
        https://arxiv.org/abs/1610.09585
      aux_cond_discriminator_weight: If not None: add a classification loss as
        in https://arxiv.org/abs/1610.09585
      add_summaries: Whether or not to add summaries for the losses.

    Returns:
      A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
      regularization losses.

    Raises:
      ValueError: If any of the auxiliary loss weights is provided and negative.
      ValueError: If `mutual_information_penalty_weight` is provided, but the
        `model` isn't an `InfoGANModel`.
      ValueError: If `aux_cond_generator_weight` or
        `aux_cond_discriminator_weight` is provided, but the `model` isn't an
        `ACGANModel`.
    """
    # Validate arguments. The label passed to the validator is the public
    # keyword name, so any validation error points at the actual argument.
    gradient_penalty_weight = _validate_aux_loss_weight(
        gradient_penalty_weight, 'gradient_penalty_weight')
    mutual_information_penalty_weight = _validate_aux_loss_weight(
        mutual_information_penalty_weight, 'mutual_information_penalty_weight')
    aux_cond_generator_weight = _validate_aux_loss_weight(
        aux_cond_generator_weight, 'aux_cond_generator_weight')
    aux_cond_discriminator_weight = _validate_aux_loss_weight(
        aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')

    # Verify configuration for mutual information penalty
    if (_use_aux_loss(mutual_information_penalty_weight)
            and not isinstance(model, namedtuples.InfoGANModel)):
        raise ValueError(
            'When `mutual_information_penalty_weight` is provided, `model` must be '
            'an `InfoGANModel`. Instead, was %s.' % type(model))

    # Verify configuration for mutual auxiliary condition loss (ACGAN).
    if ((_use_aux_loss(aux_cond_generator_weight)
         or _use_aux_loss(aux_cond_discriminator_weight))
            and not isinstance(model, namedtuples.ACGANModel)):
        raise ValueError(
            'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
            'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
            type(model))

    # Create standard losses.
    gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
    dis_loss = discriminator_loss_fn(model, add_summaries=add_summaries)

    # Add optional extra losses.
    if _use_aux_loss(gradient_penalty_weight):
        gp_loss = tfgan_losses.wasserstein_gradient_penalty(
            model,
            epsilon=gradient_penalty_epsilon,
            add_summaries=add_summaries)
        dis_loss += gradient_penalty_weight * gp_loss
    if _use_aux_loss(mutual_information_penalty_weight):
        # The same penalty term is shared by both players.
        info_loss = tfgan_losses.mutual_information_penalty(
            model, add_summaries=add_summaries)
        dis_loss += mutual_information_penalty_weight * info_loss
        gen_loss += mutual_information_penalty_weight * info_loss
    if _use_aux_loss(aux_cond_generator_weight):
        ac_gen_loss = tfgan_losses.acgan_generator_loss(
            model, add_summaries=add_summaries)
        gen_loss += aux_cond_generator_weight * ac_gen_loss
    if _use_aux_loss(aux_cond_discriminator_weight):
        ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
            model, add_summaries=add_summaries)
        dis_loss += aux_cond_discriminator_weight * ac_disc_loss

    # Gather regularization losses registered under each sub-network's scope;
    # a model without a recorded scope contributes no regularization.
    if model.generator_scope:
        gen_reg_loss = losses.get_regularization_loss(
            model.generator_scope.name)
    else:
        gen_reg_loss = 0
    if model.discriminator_scope:
        dis_reg_loss = losses.get_regularization_loss(
            model.discriminator_scope.name)
    else:
        dis_reg_loss = 0

    return namedtuples.GANLoss(gen_loss + gen_reg_loss,
                               dis_loss + dis_reg_loss)
Exemplo n.º 7
0
def gan_loss(
    # GANModel.
    model,
    # Loss functions.
    generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
    # Auxiliary losses.
    gradient_penalty_weight=None,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    mutual_information_penalty_weight=None,
    aux_cond_generator_weight=None,
    aux_cond_discriminator_weight=None,
    tensor_pool_fn=None,
    # Options.
    add_summaries=True):
  """Returns losses necessary to train generator and discriminator.

  When `tensor_pool_fn` is given, every discriminator-side loss (standard
  discriminator loss, gradient penalty, discriminator share of the mutual
  information penalty, ACGAN discriminator loss) is computed on the pooled
  model. Generator-side losses always use the unpooled `model`.

  Args:
    model: A GANModel tuple.
    generator_loss_fn: The loss function on the generator. Takes a GANModel
      tuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      GANModel tuple.
    gradient_penalty_weight: If not `None`, must be a non-negative Python number
      or Tensor indicating how much to weight the gradient penalty. See
      https://arxiv.org/pdf/1704.00028.pdf for more details.
    gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
      small positive value used by the gradient penalty function for numerical
      stability. Note some applications will need to increase this value to
      avoid NaNs.
    gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
      number or `Tensor` indicating the target value of gradient norm. See the
      CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
    gradient_penalty_one_sided: If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    mutual_information_penalty_weight: If not `None`, must be a non-negative
      Python number or Tensor indicating how much to weight the mutual
      information penalty. See https://arxiv.org/abs/1606.03657 for more
      details.
    aux_cond_generator_weight: If not None: add a classification loss as in
      https://arxiv.org/abs/1610.09585
    aux_cond_discriminator_weight: If not None: add a classification loss as in
      https://arxiv.org/abs/1610.09585
    tensor_pool_fn: A function that takes (generated_data, generator_inputs),
      stores them in an internal pool and returns previous stored
      (generated_data, generator_inputs). For example
      `tf.gan.features.tensor_pool`. Defaults to None (not using tensor pool).
    add_summaries: Whether or not to add summaries for the losses.

  Returns:
    A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
    regularization losses.

  Raises:
    ValueError: If any of the auxiliary loss weights is provided and negative.
    ValueError: If `mutual_information_penalty_weight` is provided, but the
      `model` isn't an `InfoGANModel`.
    ValueError: If `aux_cond_generator_weight` or
      `aux_cond_discriminator_weight` is provided, but the `model` isn't an
      `ACGANModel`.
  """
  # Validate arguments. The label passed to the validator is the public keyword
  # name, so any validation error points at the actual argument.
  gradient_penalty_weight = _validate_aux_loss_weight(
      gradient_penalty_weight, 'gradient_penalty_weight')
  mutual_information_penalty_weight = _validate_aux_loss_weight(
      mutual_information_penalty_weight, 'mutual_information_penalty_weight')
  aux_cond_generator_weight = _validate_aux_loss_weight(
      aux_cond_generator_weight, 'aux_cond_generator_weight')
  aux_cond_discriminator_weight = _validate_aux_loss_weight(
      aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')

  # Verify configuration for mutual information penalty
  if (_use_aux_loss(mutual_information_penalty_weight) and
      not isinstance(model, namedtuples.InfoGANModel)):
    raise ValueError(
        'When `mutual_information_penalty_weight` is provided, `model` must be '
        'an `InfoGANModel`. Instead, was %s.' % type(model))

  # Verify configuration for mutual auxiliary condition loss (ACGAN).
  if ((_use_aux_loss(aux_cond_generator_weight) or
       _use_aux_loss(aux_cond_discriminator_weight)) and
      not isinstance(model, namedtuples.ACGANModel)):
    raise ValueError(
        'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
        'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
        type(model))

  # Optionally create pooled model (used only for discriminator-side losses).
  pooled_model = (_tensor_pool_adjusted_model(model, tensor_pool_fn) if
                  tensor_pool_fn else model)

  # Create standard losses.
  gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
  dis_loss = discriminator_loss_fn(pooled_model, add_summaries=add_summaries)

  # Add optional extra losses.
  if _use_aux_loss(gradient_penalty_weight):
    gp_loss = tfgan_losses.wasserstein_gradient_penalty(
        pooled_model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries)
    dis_loss += gradient_penalty_weight * gp_loss
  if _use_aux_loss(mutual_information_penalty_weight):
    gen_info_loss = tfgan_losses.mutual_information_penalty(
        model, add_summaries=add_summaries)
    # Without a pool both players share one penalty term; with a pool the
    # discriminator's term is recomputed on the pooled model.
    dis_info_loss = (gen_info_loss if tensor_pool_fn is None else
                     tfgan_losses.mutual_information_penalty(
                         pooled_model, add_summaries=add_summaries))
    gen_loss += mutual_information_penalty_weight * gen_info_loss
    dis_loss += mutual_information_penalty_weight * dis_info_loss
  if _use_aux_loss(aux_cond_generator_weight):
    ac_gen_loss = tfgan_losses.acgan_generator_loss(
        model, add_summaries=add_summaries)
    gen_loss += aux_cond_generator_weight * ac_gen_loss
  if _use_aux_loss(aux_cond_discriminator_weight):
    ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
        pooled_model, add_summaries=add_summaries)
    dis_loss += aux_cond_discriminator_weight * ac_disc_loss

  # Gather regularization losses registered under each sub-network's scope;
  # a model without a recorded scope contributes no regularization.
  if model.generator_scope:
    gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
  else:
    gen_reg_loss = 0
  if model.discriminator_scope:
    dis_reg_loss = losses.get_regularization_loss(
        model.discriminator_scope.name)
  else:
    dis_reg_loss = 0

  return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)