def test_user_warning_recall():
    """Check that ``recall`` emits a UserWarning about the sorting method.

    Every warning raised during the call is recorded. The test passes if at
    least one of them is a ``UserWarning`` whose message contains
    ``"sorting method used"``. Matching against all recorded warnings, rather
    than only the last one, keeps the test robust to unrelated warnings that
    ``recall`` may emit after the one under test.
    """
    # NOTE(review): the two equal scores (0.04, 0.04) look like the trigger for
    # the sorting-method warning (tied predictions) — confirm against recall.
    y_pred = np.array([0.04, 0.04, 0.10])
    y_true = np.array([1., -1., 1.])
    with warnings.catch_warnings(record=True) as w:
        # "always" disables the once-per-location registry so the warning is
        # recorded even if an earlier test already triggered it.
        warnings.simplefilter("always")
        recall(y_true, y_pred)

    # Fail with a clear message instead of an IndexError when nothing was warned.
    assert w, "recall() emitted no warnings"
    matching = [
        str(record.message)
        for record in w
        if issubclass(record.category, UserWarning)
        and "sorting method used" in str(record.message)
    ]
    assert matching, (
        "expected a UserWarning mentioning 'sorting method used', got: "
        f"{[(record.category.__name__, str(record.message)) for record in w]}"
    )
def test_non_trivial_recall_trapz():
    """Compare ``recall`` on a small example against hand-computed values.

    Four scored samples, three of them labelled positive, are passed to
    ``recall``. The result must match ``[1/3, 1/3, 2/3, 1]``: the summed
    absolute deviation has to stay below 1e-6.
    """
    scores = np.array([0.25, 0.45, 0.60, 0.90])
    labels = np.array([1., 1., 0., 1.])
    observed = recall(labels, scores)

    expected = np.array([1. / 3., 1. / 3., 2. / 3., 1.])
    total_error = np.sum(np.abs(observed - expected))
    assert total_error < 1e-6