Пример #1
0
    def weighted(y_true, y_pred, weights, mask=None):
        """Apply optional sample weights and an optional mask to `fn`'s loss.

        Arguments:
            y_true: `y_true` argument of `fn`.
            y_pred: `y_pred` argument of `fn`.
            weights: Weights tensor, or `None` for an unweighted loss.
            mask: Mask tensor, or `None`. It is cast to `y_pred.dtype`. When
                `weights` is `None` the mask is used as the weights. Otherwise
                it is multiplied into the broadcast weights. If the two shapes
                cannot be broadcast together, the mask is applied to the
                per-element loss directly instead.

        Returns:
            Scalar tensor.
        """
        # `fn` is the wrapped loss function from the enclosing scope.
        # score_array has ndim >= 2
        score_array = fn(y_true, y_pred)
        if mask is not None:
            mask = math_ops.cast(mask, y_pred.dtype)
            # Update weights with mask.
            if weights is None:
                weights = mask
            else:
                # Update dimensions of weights to match with mask if possible.
                mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
                    mask, None, weights)
                try:
                    # Broadcast weights if possible.
                    weights = weights_broadcast_ops.broadcast_weights(
                        weights, mask)
                    weights *= mask
                except ValueError:
                    # Mask and weights cannot be combined. Mask the loss
                    # directly and rescale it so the loss is proportional to
                    # the number of unmasked entries. The epsilon keeps a
                    # fully masked batch from computing 0 / 0 = NaN.
                    score_array *= mask
                    score_array /= K.mean(mask) + K.epsilon()
                    # TODO(psv): Handle case when mask and weight shapes are not
                    # compatible.

        # Apply sample weighting.
        if weights is not None:

            # Update dimensions of weights to match with values if possible.
            score_array, _, weights = metrics_module.squeeze_or_expand_dimensions(
                score_array, None, weights)
            try:
                # Broadcast weights if possible.
                weights = weights_broadcast_ops.broadcast_weights(
                    weights, score_array)
            except ValueError:
                # Weights cannot be broadcast to the loss. Average the loss over
                # its trailing axes so it has the same ndim as the weights.
                ndim = K.ndim(score_array)
                weight_ndim = K.ndim(weights)
                score_array = K.mean(score_array,
                                     axis=list(range(weight_ndim, ndim)))

            # Weighted mean: sum(loss * w) / sum(w). `safe_div` presumably
            # guards against a zero weight sum — confirm in metrics_module.
            score_array = math_ops.multiply(score_array, weights)
            score_array = math_ops.reduce_sum(score_array)
            weights = math_ops.reduce_sum(weights)
            score_array = metrics_module.safe_div(score_array, weights)
        return K.mean(score_array)
Пример #2
0
  def weighted(y_true, y_pred, weights, mask=None):
    """Apply optional sample weights and an optional mask to `fn`'s loss.

    Arguments:
        y_true: `y_true` argument of `fn`.
        y_pred: `y_pred` argument of `fn`.
        weights: Weights tensor, or `None` for an unweighted loss.
        mask: Mask tensor, or `None`. It is cast to `y_pred.dtype`. When
            `weights` is `None` the mask is used as the weights. Otherwise it
            is multiplied into the broadcast weights. If the two shapes cannot
            be broadcast together, the mask is applied to the per-element loss
            directly instead.

    Returns:
        Scalar tensor.
    """
    # `fn` is the wrapped loss function from the enclosing scope.
    # score_array has ndim >= 2
    score_array = fn(y_true, y_pred)
    if mask is not None:
      mask = math_ops.cast(mask, y_pred.dtype)
      # Update weights with mask.
      if weights is None:
        weights = mask
      else:
        # Update dimensions of weights to match with mask if possible.
        mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
            mask, None, weights)
        try:
          # Broadcast weights if possible.
          weights = weights_broadcast_ops.broadcast_weights(weights, mask)
          weights *= mask
        except ValueError:
          # Mask and weights cannot be combined. Mask the loss directly and
          # rescale it so the loss is proportional to the number of unmasked
          # entries. The epsilon keeps a fully masked batch from computing
          # 0 / 0 = NaN.
          score_array *= mask
          score_array /= K.mean(mask) + K.epsilon()
          # TODO(psv): Handle case when mask and weight shapes are not
          # compatible.

    # Apply sample weighting.
    if weights is not None:

      # Update dimensions of weights to match with values if possible.
      score_array, _, weights = metrics_module.squeeze_or_expand_dimensions(
          score_array, None, weights)
      try:
        # Broadcast weights if possible.
        weights = weights_broadcast_ops.broadcast_weights(weights, score_array)
      except ValueError:
        # Weights cannot be broadcast to the loss. Average the loss over its
        # trailing axes so it has the same ndim as the weights.
        ndim = K.ndim(score_array)
        weight_ndim = K.ndim(weights)
        score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))

      # Weighted mean: sum(loss * w) / sum(w). `safe_div` presumably guards
      # against a zero weight sum — confirm in metrics_module.
      score_array = math_ops.multiply(score_array, weights)
      score_array = math_ops.reduce_sum(score_array)
      weights = math_ops.reduce_sum(weights)
      score_array = metrics_module.safe_div(score_array, weights)
    return K.mean(score_array)
Пример #3
0
    def weighted(y_true, y_pred, weights, mask=None):
        """Apply an optional mask and optional sample weights to `fn`'s loss.

        Arguments:
            y_true: `y_true` argument of `fn`.
            y_pred: `y_pred` argument of `fn`.
            weights: Weights tensor, or `None` for an unweighted loss.
            mask: Mask tensor, or `None`. It is cast to `K.floatx()` and
                multiplied into the per-element loss. The result is rescaled
                by the mean of the mask.

        Returns:
            Scalar tensor.
        """
        # `fn` is the wrapped loss function from the enclosing scope.
        # score_array has ndim >= 2
        score_array = fn(y_true, y_pred)
        if mask is not None:
            # Cast the mask to floatX to avoid float64 upcasting in theano
            mask = math_ops.cast(mask, K.floatx())
            # mask should have the same shape as score_array
            score_array *= mask
            #  the loss per batch should be proportional
            #  to the number of unmasked samples.
            # The epsilon keeps a fully masked batch (mean(mask) == 0) from
            # computing 0 / 0 = NaN.
            score_array /= K.mean(mask) + K.epsilon()

        # Apply sample weighting.
        if weights is not None:

            # Update dimensions of weights to match with values if possible.
            score_array, _, weights = metrics_module.squeeze_or_expand_dimensions(
                score_array, None, weights)
            try:
                # Broadcast weights if possible.
                weights = weights_broadcast_ops.broadcast_weights(
                    weights, score_array)
            except ValueError:
                # Weights cannot be broadcast to the loss. Average the loss over
                # its trailing axes so it has the same ndim as the weights.
                ndim = K.ndim(score_array)
                weight_ndim = K.ndim(weights)
                score_array = K.mean(score_array,
                                     axis=list(range(weight_ndim, ndim)))

            # Weighted mean: sum(loss * w) / sum(w). `safe_div` presumably
            # guards against a zero weight sum — confirm in metrics_module.
            score_array = math_ops.multiply(score_array, weights)
            score_array = math_ops.reduce_sum(score_array)
            weights = math_ops.reduce_sum(weights)
            score_array = metrics_module.safe_div(score_array, weights)
        return K.mean(score_array)