def test_loss_wrapper(self):
    """Check `LossFunctionWrapper` built around the 'mse' loss function.

    Verifies the wrapper's name, its default reduction, and the value it
    returns for a small batch when per-sample weights are supplied.
    """
    loss_fn = losses.get('mse')
    mse_obj = losses.LossFunctionWrapper(loss_fn, name=loss_fn.__name__)

    assert mse_obj.name == 'mean_squared_error'
    assert (
        mse_obj.reduction == losses_utils.Reduction.SUM_OVER_BATCH_SIZE)

    y_true = K.constant([[1., 9.], [2., 5.]])
    y_pred = K.constant([[4., 8.], [12., 3.]])
    sample_weight = K.constant([1.2, 0.5])
    loss = mse_obj(y_true, y_pred, sample_weight=sample_weight)

    # mse = [((4 - 1)^2 + (8 - 9)^2) / 2, ((12 - 2)^2 + (3 - 5)^2) / 2]
    # mse = [5, 52]
    # weighted_mse = [5 * 1.2, 52 * 0.5] = [6, 26]
    # reduced_weighted_mse = (6 + 26) / 2 = 16
    # The comparison must be asserted; a bare `np.allclose(...)` call
    # discards its result and the test would pass for any loss value.
    assert np.allclose(K.eval(loss), 16, atol=1e-2)
def _get_loss_object(self, loss):
    """Returns a `Loss` object for a user-supplied loss.

    The supplied value is resolved through `losses_mod.get`. If the result is
    not already a `Loss` instance it is wrapped in a `LossFunctionWrapper`
    named after the callable. The returned object is also flagged so that
    `SUM_OVER_BATCH_SIZE` reduction may be used with it.

    Args:
      loss: A string, function, or `Loss` object.

    Returns:
      A `Loss` object, or `None` when `loss` is `None`.

    Raises:
      ValueError: If no name can be determined for the resolved loss.
    """
    # An output is allowed to have no loss at all.
    if loss is None:
        return None

    resolved = losses_mod.get(loss)
    if isinstance(resolved, losses_mod.Loss):
        loss_obj = resolved
    else:
        # Plain callables need a name before they can be wrapped.
        wrapper_name = get_custom_object_name(resolved)
        if wrapper_name is None:
            raise ValueError(
                'Loss should be a callable, found: {}'.format(resolved))
        loss_obj = losses_mod.LossFunctionWrapper(resolved, name=wrapper_name)

    loss_obj._allow_sum_over_batch_size = True  # pylint: disable=protected-access
    return loss_obj