def _get_estimator_spec(
    mode, gan_model, generator_loss_fn, discriminator_loss_fn,
    get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
    get_hooks_fn=None):
  """Get the EstimatorSpec for the current mode."""
  if mode == model_fn_lib.ModeKeys.PREDICT:
    estimator_spec = model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)
  else:
    gan_loss = tfgan_tuples.GANLoss(
        generator_loss=generator_loss_fn(gan_model),
        discriminator_loss=discriminator_loss_fn(gan_model))
    if mode == model_fn_lib.ModeKeys.EVAL:
      estimator_spec = _get_eval_estimator_spec(
          gan_model, gan_loss, get_eval_metric_ops_fn)
    else:  # model_fn_lib.ModeKeys.TRAIN:
      gopt = (generator_optimizer() if callable(generator_optimizer)
              else generator_optimizer)
      dopt = (discriminator_optimizer() if callable(discriminator_optimizer)
              else discriminator_optimizer)
      get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
      estimator_spec = _get_train_estimator_spec(
          gan_model, gan_loss, gopt, dopt, get_hooks_fn)

  return estimator_spec
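# NOTE: A hedged sketch (not library code) of how an Estimator `model_fn`
# might dispatch to the helper above. `_make_gan_model` is an assumed
# stand-in for whatever builds the GANModel tuple from features/labels; the
# loss fns are assumed to come from this module's `tfgan_losses` import.
import tensorflow as tf

def _example_model_fn(features, labels, mode):
  gan_model = _make_gan_model(features, labels, mode)  # Assumed helper.
  return _get_estimator_spec(
      mode,
      gan_model,
      generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
      discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
      get_eval_metric_ops_fn=None,
      generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
      discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5))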
def dummy_loss_fn(gan_model):
  loss = math_ops.reduce_sum(
      gan_model.discriminator_input_data_domain_predication -
      gan_model.discriminator_generated_data_domain_predication)
  loss += math_ops.reduce_sum(gan_model.input_data - gan_model.generated_data)
  return tfgan_tuples.GANLoss(loss, loss)
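# NOTE: A hedged sketch (not library code) of exercising the dummy loss. It
# uses a hypothetical namedtuple that carries only the four StarGANModel
# fields the loss reads; field names match the library's (mis)spelling.
import collections
import tensorflow as tf

FakeStarGANModel = collections.namedtuple(
    'FakeStarGANModel',
    ['input_data', 'generated_data',
     'discriminator_input_data_domain_predication',
     'discriminator_generated_data_domain_predication'])

fake_model = FakeStarGANModel(
    input_data=tf.constant([1.0, 2.0]),
    generated_data=tf.constant([0.5, 1.5]),
    discriminator_input_data_domain_predication=tf.constant([0.9]),
    discriminator_generated_data_domain_predication=tf.constant([0.1]))

fake_loss = dummy_loss_fn(fake_model)
with tf.Session() as sess:
  # Both components carry the same scalar: (0.9 - 0.1) + (0.5 + 0.5) = 1.8.
  print(sess.run([fake_loss.generator_loss, fake_loss.discriminator_loss]))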
def _get_estimator_spec(
    mode, gan_model, generator_loss_fn, discriminator_loss_fn,
    get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
    get_hooks_fn=None, use_loss_summaries=True, is_chief=True):
  """Get the EstimatorSpec for the current mode."""
  if mode == model_fn_lib.ModeKeys.PREDICT:
    estimator_spec = model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)
  else:
    gan_loss = tfgan_tuples.GANLoss(
        generator_loss=generator_loss_fn(
            gan_model, add_summaries=use_loss_summaries),
        discriminator_loss=discriminator_loss_fn(
            gan_model, add_summaries=use_loss_summaries))
    if mode == model_fn_lib.ModeKeys.EVAL:
      estimator_spec = _get_eval_estimator_spec(
          gan_model, gan_loss, get_eval_metric_ops_fn)
    else:  # model_fn_lib.ModeKeys.TRAIN:
      if callable(generator_optimizer):
        generator_optimizer = generator_optimizer()
      if callable(discriminator_optimizer):
        discriminator_optimizer = discriminator_optimizer()
      get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
      estimator_spec = _get_train_estimator_spec(
          gan_model, gan_loss, generator_optimizer, discriminator_optimizer,
          get_hooks_fn, is_chief=is_chief)

  return estimator_spec
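# NOTE: A hedged sketch (not library code). The optimizer arguments may be
# zero-argument callables; one common reason is to defer optimizer
# construction until this helper runs inside the graph the Estimator builds
# for the current mode, rather than constructing it at configuration time.
import tensorflow as tf

gen_opt_fn = lambda: tf.train.AdamOptimizer(1e-4, beta1=0.5)
dis_opt_fn = lambda: tf.train.AdamOptimizer(1e-4, beta1=0.5)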
def _get_estimator_spec(mode, gan_model, generator_loss_fn,
                        discriminator_loss_fn, get_eval_metric_ops_fn,
                        generator_optimizer, discriminator_optimizer,
                        joint_train, is_on_tpu, gan_train_steps):
  """Get the TPUEstimatorSpec for the current mode."""
  if mode == model_fn_lib.ModeKeys.PREDICT:
    estimator_spec = tpu_estimator.TPUEstimatorSpec(
        mode=mode, predictions={'generated_data': gan_model.generated_data})
  elif mode == model_fn_lib.ModeKeys.EVAL:
    gan_loss = tfgan_tuples.GANLoss(
        generator_loss=generator_loss_fn(
            gan_model, add_summaries=not is_on_tpu),
        discriminator_loss=discriminator_loss_fn(
            gan_model, add_summaries=not is_on_tpu))
    # Eval losses for metrics must preserve batch dimension.
    gan_loss_no_reduction = tfgan_tuples.GANLoss(
        generator_loss=generator_loss_fn(
            gan_model, add_summaries=False, reduction=losses.Reduction.NONE),
        discriminator_loss=discriminator_loss_fn(
            gan_model, add_summaries=False, reduction=losses.Reduction.NONE))
    estimator_spec = _get_eval_estimator_spec(
        gan_model, gan_loss, gan_loss_no_reduction, get_eval_metric_ops_fn)
  else:  # model_fn_lib.ModeKeys.TRAIN:
    gan_loss = tfgan_tuples.GANLoss(
        generator_loss=generator_loss_fn(
            gan_model, add_summaries=not is_on_tpu),
        discriminator_loss=discriminator_loss_fn(
            gan_model, add_summaries=not is_on_tpu))
    # Construct optimizers if the arguments are callable. For TPUs, they must
    # be `CrossShardOptimizer`s.
    gopt = (generator_optimizer() if callable(generator_optimizer)
            else generator_optimizer)
    dopt = (discriminator_optimizer() if callable(discriminator_optimizer)
            else discriminator_optimizer)
    estimator_spec = _get_train_estimator_spec(
        gan_model, gan_loss, gopt, dopt, joint_train, gan_train_steps)

  return estimator_spec
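# NOTE: A hedged sketch (not library code). As the comment above says, on TPU
# the optimizers must be `CrossShardOptimizer`s so gradients are aggregated
# across shards; passing them as callables keeps construction lazy.
import tensorflow as tf
from tensorflow.contrib import tpu

tpu_gen_opt_fn = lambda: tpu.CrossShardOptimizer(tf.train.AdamOptimizer(1e-4))
tpu_dis_opt_fn = lambda: tpu.CrossShardOptimizer(tf.train.AdamOptimizer(1e-4))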
def create_loss(self, features, mode, logits, labels):
  """Returns a GANLoss tuple from the provided GANModel.

  See `Head` for more details.

  Args:
    features: Input `dict` of `Tensor` objects. Unused.
    mode: Estimator's `ModeKeys`.
    logits: A GANModel tuple.
    labels: Must be `None`.

  Returns:
    A GANLoss tuple.
  """
  _validate_logits_and_labels(logits, labels)
  del mode, labels, features  # Unused for this head.
  gan_model = logits  # Rename variable for clarity.
  return tfgan_tuples.GANLoss(
      generator_loss=self._generator_loss_fn(gan_model),
      discriminator_loss=self._discriminator_loss_fn(gan_model))
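# NOTE: A hedged usage sketch (not library code). `head` is assumed to be an
# instance of the class defining `create_loss`, with its loss fns already
# configured, and `gan_model` a GANModel tuple built elsewhere. `features` is
# unused by this head, so `None` is acceptable.
gan_loss = head.create_loss(
    features=None,
    mode=model_fn_lib.ModeKeys.TRAIN,
    logits=gan_model,
    labels=None)
gen_loss = gan_loss.generator_loss
dis_loss = gan_loss.discriminator_loss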
def gan_loss(
    # GANModel.
    model,
    # Loss functions.
    generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
    # Auxiliary losses.
    gradient_penalty_weight=None,
    gradient_penalty_epsilon=1e-10,
    mutual_information_penalty_weight=None,
    aux_cond_generator_weight=None,
    aux_cond_discriminator_weight=None,
    # Options.
    add_summaries=True):
  """Returns losses necessary to train generator and discriminator.

  Args:
    model: A GANModel tuple.
    generator_loss_fn: The loss function on the generator. Takes a GANModel
      tuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      GANModel tuple.
    gradient_penalty_weight: If not `None`, must be a non-negative Python
      number or Tensor indicating how much to weight the gradient penalty. See
      https://arxiv.org/pdf/1704.00028.pdf for more details.
    gradient_penalty_epsilon: If `gradient_penalty_weight` is not `None`, the
      small positive value used by the gradient penalty function for numerical
      stability. Note some applications will need to increase this value to
      avoid NaNs.
    mutual_information_penalty_weight: If not `None`, must be a non-negative
      Python number or Tensor indicating how much to weight the mutual
      information penalty. See https://arxiv.org/abs/1606.03657 for more
      details.
    aux_cond_generator_weight: If not `None`, add a classification loss as in
      https://arxiv.org/abs/1610.09585.
    aux_cond_discriminator_weight: If not `None`, add a classification loss as
      in https://arxiv.org/abs/1610.09585.
    add_summaries: Whether or not to add summaries for the losses.

  Returns:
    A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
    regularization losses.

  Raises:
    ValueError: If any of the auxiliary loss weights is provided and negative.
    ValueError: If `mutual_information_penalty_weight` is provided, but the
      `model` isn't an `InfoGANModel`.
  """
  # Validate arguments.
  gradient_penalty_weight = _validate_aux_loss_weight(
      gradient_penalty_weight, 'gradient_penalty_weight')
  mutual_information_penalty_weight = _validate_aux_loss_weight(
      mutual_information_penalty_weight, 'infogan_weight')
  aux_cond_generator_weight = _validate_aux_loss_weight(
      aux_cond_generator_weight, 'aux_cond_generator_weight')
  aux_cond_discriminator_weight = _validate_aux_loss_weight(
      aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')

  # Verify configuration for mutual information penalty.
  if (_use_aux_loss(mutual_information_penalty_weight) and
      not isinstance(model, namedtuples.InfoGANModel)):
    raise ValueError(
        'When `mutual_information_penalty_weight` is provided, `model` must '
        'be an `InfoGANModel`. Instead, was %s.' % type(model))

  # Verify configuration for auxiliary condition loss (ACGAN).
  if ((_use_aux_loss(aux_cond_generator_weight) or
       _use_aux_loss(aux_cond_discriminator_weight)) and
      not isinstance(model, namedtuples.ACGANModel)):
    raise ValueError(
        'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
        'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
        type(model))

  # Create standard losses.
  gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
  dis_loss = discriminator_loss_fn(model, add_summaries=add_summaries)

  # Add optional extra losses.
  if _use_aux_loss(gradient_penalty_weight):
    gp_loss = tfgan_losses.wasserstein_gradient_penalty(
        model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries)
    dis_loss += gradient_penalty_weight * gp_loss
  if _use_aux_loss(mutual_information_penalty_weight):
    info_loss = tfgan_losses.mutual_information_penalty(
        model, add_summaries=add_summaries)
    dis_loss += mutual_information_penalty_weight * info_loss
    gen_loss += mutual_information_penalty_weight * info_loss
  if _use_aux_loss(aux_cond_generator_weight):
    ac_gen_loss = tfgan_losses.acgan_generator_loss(
        model, add_summaries=add_summaries)
    gen_loss += aux_cond_generator_weight * ac_gen_loss
  if _use_aux_loss(aux_cond_discriminator_weight):
    ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
        model, add_summaries=add_summaries)
    dis_loss += aux_cond_discriminator_weight * ac_disc_loss

  # Gather regularization losses.
  if model.generator_scope:
    gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
  else:
    gen_reg_loss = 0
  if model.discriminator_scope:
    dis_reg_loss = losses.get_regularization_loss(
        model.discriminator_scope.name)
  else:
    dis_reg_loss = 0

  return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)
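# NOTE: A hedged end-to-end sketch (not library code), assuming the TF 1.x
# contrib packaging (`tf.contrib.gan`) and toy generator/discriminator fns.
import tensorflow as tf

tfgan = tf.contrib.gan

def _toy_generator_fn(noise):
  return tf.layers.dense(noise, 2)

def _toy_discriminator_fn(data, unused_conditioning):
  return tf.layers.dense(data, 1)

toy_model = tfgan.gan_model(
    _toy_generator_fn,
    _toy_discriminator_fn,
    real_data=tf.random_normal([16, 2]),
    generator_inputs=tf.random_normal([16, 4]))

# Wasserstein defaults plus the gradient penalty term (WGAN-GP style).
toy_loss = tfgan.gan_loss(toy_model, gradient_penalty_weight=10.0)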
def gan_loss(
    # GANModel.
    model,
    # Loss functions.
    generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
    # Auxiliary losses.
    gradient_penalty_weight=None,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    mutual_information_penalty_weight=None,
    aux_cond_generator_weight=None,
    aux_cond_discriminator_weight=None,
    tensor_pool_fn=None,
    # Options.
    add_summaries=True):
  """Returns losses necessary to train generator and discriminator.

  Args:
    model: A GANModel tuple.
    generator_loss_fn: The loss function on the generator. Takes a GANModel
      tuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      GANModel tuple.
    gradient_penalty_weight: If not `None`, must be a non-negative Python
      number or Tensor indicating how much to weight the gradient penalty. See
      https://arxiv.org/pdf/1704.00028.pdf for more details.
    gradient_penalty_epsilon: If `gradient_penalty_weight` is not `None`, the
      small positive value used by the gradient penalty function for numerical
      stability. Note some applications will need to increase this value to
      avoid NaNs.
    gradient_penalty_target: If `gradient_penalty_weight` is not `None`, a
      Python number or `Tensor` indicating the target value of gradient norm.
      See the CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to
      1.0.
    gradient_penalty_one_sided: If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    mutual_information_penalty_weight: If not `None`, must be a non-negative
      Python number or Tensor indicating how much to weight the mutual
      information penalty. See https://arxiv.org/abs/1606.03657 for more
      details.
    aux_cond_generator_weight: If not `None`, add a classification loss as in
      https://arxiv.org/abs/1610.09585.
    aux_cond_discriminator_weight: If not `None`, add a classification loss as
      in https://arxiv.org/abs/1610.09585.
    tensor_pool_fn: A function that takes (generated_data, generator_inputs),
      stores them in an internal pool, and returns previously stored
      (generated_data, generator_inputs). For example
      `tf.gan.features.tensor_pool`. Defaults to `None` (not using tensor
      pool).
    add_summaries: Whether or not to add summaries for the losses.

  Returns:
    A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
    regularization losses.

  Raises:
    ValueError: If any of the auxiliary loss weights is provided and negative.
    ValueError: If `mutual_information_penalty_weight` is provided, but the
      `model` isn't an `InfoGANModel`.
  """
  # Validate arguments.
  gradient_penalty_weight = _validate_aux_loss_weight(
      gradient_penalty_weight, 'gradient_penalty_weight')
  mutual_information_penalty_weight = _validate_aux_loss_weight(
      mutual_information_penalty_weight, 'infogan_weight')
  aux_cond_generator_weight = _validate_aux_loss_weight(
      aux_cond_generator_weight, 'aux_cond_generator_weight')
  aux_cond_discriminator_weight = _validate_aux_loss_weight(
      aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')

  # Verify configuration for mutual information penalty.
  if (_use_aux_loss(mutual_information_penalty_weight) and
      not isinstance(model, namedtuples.InfoGANModel)):
    raise ValueError(
        'When `mutual_information_penalty_weight` is provided, `model` must '
        'be an `InfoGANModel`. Instead, was %s.' % type(model))

  # Verify configuration for auxiliary condition loss (ACGAN).
  if ((_use_aux_loss(aux_cond_generator_weight) or
       _use_aux_loss(aux_cond_discriminator_weight)) and
      not isinstance(model, namedtuples.ACGANModel)):
    raise ValueError(
        'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
        'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
        type(model))

  # Optionally create pooled model.
  pooled_model = (_tensor_pool_adjusted_model(model, tensor_pool_fn)
                  if tensor_pool_fn else model)

  # Create standard losses.
  gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
  dis_loss = discriminator_loss_fn(pooled_model, add_summaries=add_summaries)

  # Add optional extra losses.
  if _use_aux_loss(gradient_penalty_weight):
    gp_loss = tfgan_losses.wasserstein_gradient_penalty(
        pooled_model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries)
    dis_loss += gradient_penalty_weight * gp_loss
  if _use_aux_loss(mutual_information_penalty_weight):
    gen_info_loss = tfgan_losses.mutual_information_penalty(
        model, add_summaries=add_summaries)
    dis_info_loss = (gen_info_loss
                     if tensor_pool_fn is None else
                     tfgan_losses.mutual_information_penalty(
                         pooled_model, add_summaries=add_summaries))
    gen_loss += mutual_information_penalty_weight * gen_info_loss
    dis_loss += mutual_information_penalty_weight * dis_info_loss
  if _use_aux_loss(aux_cond_generator_weight):
    ac_gen_loss = tfgan_losses.acgan_generator_loss(
        model, add_summaries=add_summaries)
    gen_loss += aux_cond_generator_weight * ac_gen_loss
  if _use_aux_loss(aux_cond_discriminator_weight):
    ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
        pooled_model, add_summaries=add_summaries)
    dis_loss += aux_cond_discriminator_weight * ac_disc_loss

  # Gather regularization losses.
  if model.generator_scope:
    gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
  else:
    gen_reg_loss = 0
  if model.discriminator_scope:
    dis_reg_loss = losses.get_regularization_loss(
        model.discriminator_scope.name)
  else:
    dis_reg_loss = 0

  return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)
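# NOTE: A hedged sketch (not library code) of the pooled variant, reusing
# `toy_model` from the earlier sketch. The pool feeds the discriminator a
# mix of current and previously generated samples.
import functools

pool_fn = functools.partial(tfgan.features.tensor_pool, pool_size=50)
pooled_loss = tfgan.gan_loss(
    toy_model,
    gradient_penalty_weight=10.0,
    gradient_penalty_one_sided=True,
    tensor_pool_fn=pool_fn)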