Example #1
  def __call__(self,
               y_true,
               y_pred,
               sample_weight=None,
               regularization_losses=None):
    """Computes the overall loss.

    Args:
      y_true: An arbitrary structure of Tensors representing the ground truth.
      y_pred: An arbitrary structure of Tensors representing a Model's outputs.
      sample_weight: An arbitrary structure of Tensors representing the
        per-sample loss weights. If one Tensor is passed, it is used for all
        losses. If multiple Tensors are passed, the structure should match
        `y_pred`.
      regularization_losses: Additional losses to be added to the total loss.

    Returns:
      The total loss as a scalar `tf.Tensor` (a zero Tensor if the model has
      no compiled loss).
    """
    y_true = self._conform_to_outputs(y_pred, y_true)
    sample_weight = self._conform_to_outputs(y_pred, sample_weight)

    if not self._built:
      self.build(y_pred)

    y_pred = tf.nest.flatten(y_pred)
    y_true = tf.nest.flatten(y_true)
    sample_weight = tf.nest.flatten(sample_weight)

    loss_values = []  # Used for gradient calculation.
    loss_metric_values = []  # Used for loss metric calculation.
    batch_dim = None
    zip_args = (y_true, y_pred, sample_weight, self._losses, self._loss_weights,
                self._per_output_metrics)
    for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
      if y_t is None or loss_obj is None:  # Ok to have no loss for an output.
        continue

      y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
      sw = apply_mask(y_p, sw, get_mask(y_p))
      loss_value = loss_obj(y_t, y_p, sample_weight=sw)

      loss_metric_value = loss_value
      # Correct for the `Mean` loss metrics counting each replica as a batch.
      if loss_obj.reduction == losses_utils.ReductionV2.SUM:
        loss_metric_value *= tf.distribute.get_strategy().num_replicas_in_sync

      if batch_dim is None:
        batch_dim = tf.compat.v1.shape(y_t)[0]
      if metric_obj is not None:
        metric_obj.update_state(loss_metric_value, sample_weight=batch_dim)

      if loss_weight is not None:
        loss_value *= loss_weight
        loss_metric_value *= loss_weight

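      # `scale_loss_for_distribution` divides the loss by the number of
      # replicas in sync, so that summing per-replica gradients yields the
      # correct overall gradient under a distribution strategy.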
      if (loss_obj.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE or
          loss_obj.reduction == losses_utils.ReductionV2.AUTO):
        loss_value = losses_utils.scale_loss_for_distribution(loss_value)

      loss_values.append(loss_value)
      loss_metric_values.append(loss_metric_value)

    if regularization_losses:
      regularization_losses = losses_utils.cast_losses_to_common_dtype(
          regularization_losses)
      reg_loss = tf.add_n(regularization_losses)
      loss_metric_values.append(reg_loss)
      loss_values.append(losses_utils.scale_loss_for_distribution(reg_loss))

    if loss_values:
      loss_metric_values = losses_utils.cast_losses_to_common_dtype(
          loss_metric_values)
      total_loss_metric_value = tf.add_n(loss_metric_values)
      self._loss_metric.update_state(
          total_loss_metric_value, sample_weight=batch_dim)

      loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
      total_loss = tf.add_n(loss_values)
      return total_loss
    else:
      # Ok for a model to have no compiled loss.
      return tf.zeros(shape=())
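
A minimal, self-contained sketch of what `__call__` computes: each output's loss object is applied to its matching targets and predictions, scaled by an optional loss weight, and any regularization losses are added before everything is summed into one scalar. The tensors and loss choices below are illustrative stand-ins, not Keras internals:

import tensorflow as tf

# Two model outputs, each with its own loss object and loss weight.
y_true = [tf.constant([[1.0], [0.0]]), tf.constant([[2.0], [3.0]])]
y_pred = [tf.constant([[0.9], [0.1]]), tf.constant([[1.5], [2.5]])]
loss_objs = [tf.keras.losses.MeanSquaredError(),
             tf.keras.losses.MeanAbsoluteError()]
loss_weights = [1.0, 0.5]
regularization_losses = [tf.constant(0.01)]  # e.g. from kernel regularizers

loss_values = []
for y_t, y_p, loss_obj, weight in zip(y_true, y_pred, loss_objs, loss_weights):
    loss_values.append(weight * loss_obj(y_t, y_p))  # scalar per output
loss_values.append(tf.add_n(regularization_losses))
total_loss = tf.add_n(loss_values)  # the scalar `__call__` returns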
Example #2
def _model_loss(model,
                inputs,
                targets,
                output_loss_metrics=None,
                sample_weights=None,
                training=False):
    """Calculates the loss for a given model.

  Arguments:
      model: The model on which metrics are being calculated.
      inputs: Either a dictionary of inputs to the model or a list of input
        arrays.
      targets: List of target arrays.
      output_loss_metrics: List of metrics that are used to aggregated output
        loss values.
      sample_weights: Optional list of sample weight arrays.
      training: Whether the model should be run in inference or training mode.

  Returns:
     Returns the model output, total loss, loss value calculated using the
     specified loss function and masks for each output. The total loss includes
     regularization losses and applies masking and sample weighting
     to the loss value.
  """
    # TODO(psv): Dedup code here with graph mode prepare_total_loss() fn.
    # Used to keep track of the total loss value (stateless).
    # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) +
    #                   loss_weight_2 * output_2_loss_fn(...) +
    #                   layer losses.
    total_loss = 0
    kwargs = {}
    if model._expects_training_arg:
        kwargs['training'] = training
    if len(inputs) == 1 and not isinstance(inputs, dict):
        inputs = inputs[0]

    # Allow mixed `NumPy` and `EagerTensor` input here.
    if any(
            isinstance(input_t, (np.ndarray, float, int))
            for input_t in tf.nest.flatten(inputs)):
        inputs = tf.nest.map_structure(tf.convert_to_tensor, inputs)

    outs = model(inputs, **kwargs)
    outs = tf.nest.flatten(outs)

    if targets:
        targets = training_utils_v1.cast_if_floating_dtype_and_mismatch(
            targets, outs)
    # TODO(sallymatson/psv): check if we should do same mismatch fix for weights
    if sample_weights:
        sample_weights = [
            training_utils_v1.cast_if_floating_dtype(tf.convert_to_tensor(val))
            if val is not None else None for val in sample_weights
        ]

    masks = [getattr(t, '_keras_mask', None) for t in outs]
    targets = tf.nest.flatten(targets)

    # Used to keep track of individual output losses.
    output_losses = []

    with backend.name_scope('loss'):
        loss_fns = [
            loss_fn for loss_fn in model.loss_functions if loss_fn is not None
        ]
        custom_losses = model.losses  # Regularization losses

        if not loss_fns and not custom_losses:
            if training:
                raise ValueError('The model cannot be trained '
                                 'because it has no loss to optimize.')
            else:
                raise ValueError('The model cannot be evaluated '
                                 'because it has no loss to compute.')

        for i, loss_fn in enumerate(loss_fns):
            weights = sample_weights[i] if sample_weights else None
            mask = masks[i]
            with backend.name_scope(model.output_names[i] + '_loss'):
                if mask is not None:
                    mask = tf.cast(mask, outs[i].dtype)
                    # Update weights with mask.
                    if weights is None:
                        weights = mask
                    else:
                        # Update dimensions of weights to match the mask where possible.
                        weights = tf.cast(weights, outs[i].dtype)
                        mask, _, weights = (
                            losses_utils.squeeze_or_expand_dimensions(
                                mask, sample_weight=weights))
                        weights *= mask

                if hasattr(loss_fn, 'reduction'):
                    per_sample_losses = loss_fn.call(targets[i], outs[i])
                    weighted_losses = losses_utils.compute_weighted_loss(
                        per_sample_losses,
                        sample_weight=weights,
                        reduction=losses_utils.ReductionV2.NONE)
                    loss_reduction = loss_fn.reduction

                    # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all
                    # compile use cases.
                    if loss_reduction == losses_utils.ReductionV2.AUTO:
                        loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

                    # Compute the stateless loss value.
                    output_loss = losses_utils.reduce_weighted_loss(
                        weighted_losses, reduction=loss_reduction)
                else:
                    # Compute the stateless loss value for a custom loss class.
                    # Here we assume the class handles loss reduction itself:
                    # if it returned a vector, we could not tell whether a
                    # custom optimizer expects a vector loss value or an
                    # unreduced per-sample loss.
                    output_loss = loss_fn(targets[i],
                                          outs[i],
                                          sample_weight=weights)
                    loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

            # For a single-output model we don't append a per-output loss
            # metric. With multiple outputs, each output's loss is computed
            # and tracked as part of `output_losses`.
            if len(model.outputs) > 1:
                # Keep track of the stateful output loss result.
                output_losses.append(output_loss_metrics[i](output_loss))

            # Scale output loss for distribution. For custom losses we assume
            # reduction was mean.
            if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
                output_loss = losses_utils.scale_loss_for_distribution(
                    output_loss)
            total_loss += model._loss_weights_list[i] * output_loss

        # Add regularization losses
        if custom_losses:
            total_loss += losses_utils.scale_loss_for_distribution(
                tf.add_n(custom_losses))
    return outs, total_loss, output_losses, masks
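
The mask handling above (casting an output's `_keras_mask` to the output dtype and folding it into the sample weights) can be shown in isolation. A minimal sketch with plain TensorFlow ops, using made-up tensors and omitting the `squeeze_or_expand_dimensions` broadcasting step for brevity:

import tensorflow as tf

out = tf.constant([[0.2], [0.7], [0.9]])  # one model output
y_t = tf.constant([[0.0], [1.0], [1.0]])  # matching target
mask = tf.constant([1.0, 1.0, 0.0])       # e.g. a `_keras_mask`
weights = tf.constant([1.0, 2.0, 1.0])    # per-sample weights

# Merge the mask into the weights, as in the `mask is not None` branch.
mask = tf.cast(mask, out.dtype)
weights = tf.cast(weights, out.dtype) * mask  # masked samples get weight 0

# SUM_OVER_BATCH_SIZE-style reduction of the weighted per-sample losses.
per_sample = tf.keras.losses.mean_squared_error(y_t, out)  # shape [3]
output_loss = tf.reduce_sum(per_sample * weights) / tf.cast(
    tf.size(per_sample), out.dtype)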