Example No. 1
0
    def test_loss_wrapper(self):
        """Check `LossFunctionWrapper` around the built-in `mse` loss.

        Verifies the wrapper's name and default reduction, then checks that a
        sample-weighted call on a small 2x2 batch yields the hand-computed
        value of 16.
        """
        loss_fn = losses.get('mse')
        mse_obj = losses.LossFunctionWrapper(loss_fn, name=loss_fn.__name__)

        assert mse_obj.name == 'mean_squared_error'
        assert (
            mse_obj.reduction == losses_utils.Reduction.SUM_OVER_BATCH_SIZE)

        y_true = K.constant([[1., 9.], [2., 5.]])
        y_pred = K.constant([[4., 8.], [12., 3.]])
        sample_weight = K.constant([1.2, 0.5])
        loss = mse_obj(y_true, y_pred, sample_weight=sample_weight)

        # mse = [((4 - 1)^2 + (8 - 9)^2) / 2, ((12 - 2)^2 + (3 - 5)^2) / 2]
        # mse = [5, 52]
        # weighted_mse = [5 * 1.2, 52 * 0.5] = [6, 26]
        # reduced_weighted_mse = (6 + 26) / 2 = 16
        # The result of np.allclose must be asserted; a bare call discards
        # the boolean and the test could never fail on a wrong loss value.
        assert np.allclose(K.eval(loss), 16, atol=1e-2)
Example No. 2
0
  def _get_loss_object(self, loss):
    """Normalizes a user-supplied loss into a `Loss` instance.

    The input is resolved through `losses_mod.get`. If the result is already a
    `losses_mod.Loss`, it is used directly; otherwise it is wrapped in a
    `losses_mod.LossFunctionWrapper` named after the callable. In either case
    the returned object is flagged so that `SUM_OVER_BATCH_SIZE` reduction is
    permitted for it.

    Args:
      loss: A string, function, `Loss` object, or None.

    Returns:
      A `Loss` object, or None when `loss` is None.

    Raises:
      ValueError: If the resolved loss has no name obtainable via
        `get_custom_object_name`.
    """
    # An output is allowed to have no associated loss.
    if loss is None:
      return None

    resolved = losses_mod.get(loss)
    if isinstance(resolved, losses_mod.Loss):
      loss_obj = resolved
    else:
      # Plain callables need a name before they can be wrapped.
      fn_name = get_custom_object_name(resolved)
      if fn_name is None:
        raise ValueError(
            'Loss should be a callable, found: {}'.format(resolved))
      loss_obj = losses_mod.LossFunctionWrapper(resolved, name=fn_name)

    # Note: this mutates the object in place, including a caller-supplied
    # `Loss` instance that was passed through unchanged above.
    loss_obj._allow_sum_over_batch_size = True  # pylint: disable=protected-access
    return loss_obj