def test_user_warning_trapz():
    """Check that ``integration='trapz'`` emits the sorting UserWarning.

    All warnings raised during the call are recorded and searched for a
    ``UserWarning`` whose message contains "sorting method used". Searching
    the whole list (rather than only the last entry) keeps the test valid
    when the call emits additional, unrelated warnings, and turns the
    "no warning at all" case into a readable assertion failure instead of
    an ``IndexError``.
    """
    # NOTE(review): the two tied scores (0.04) carry opposite labels, which
    # is presumably what triggers the warning -- confirm against
    # average_precision.
    y_pred = np.array([0.04, 0.04, 0.10])
    y_true = np.array([1., -1., 1.])

    with warnings.catch_warnings(record=True) as caught:
        # "always" disables the once-per-location filter so the warning is
        # recorded even if it already fired earlier in the test session.
        warnings.simplefilter("always")
        average_precision(y_true, y_pred, integration='trapz')

    matching = [
        entry for entry in caught
        if issubclass(entry.category, UserWarning)
        and "sorting method used" in str(entry.message)
    ]
    assert matching, (
        "expected a UserWarning mentioning 'sorting method used'; "
        "recorded warnings: %r" % [str(entry.message) for entry in caught]
    )
def test_non_trivial_perfect_ap_voc2010():
    """Expect an AP of 1.0 from 'voc2010' when the only negative ranks last."""
    scores = np.array([0.82, 0.75, 0.60, 0.90])
    labels = np.array([1., 1., 0., 1.])
    expected = 1.
    tolerance = 1e-6

    # The single negative (score 0.60) sits below every positive score.
    computed = average_precision(labels, scores, integration='voc2010')

    assert abs(computed - expected) < tolerance
def test_non_trivial_ap_trapz():
    """Compare the 'trapz' AP of an imperfect ranking with a fixed reference."""
    scores = np.array([0.25, 0.45, 0.60, 0.90])
    labels = np.array([1., 1., 0., 1.])
    expected = 0.7638888888888888
    tolerance = 1e-6

    # Here the negative (score 0.60) outranks two of the three positives.
    computed = average_precision(labels, scores, integration='trapz')

    assert abs(computed - expected) < tolerance
def test_perfect_pos_predictions_voc2010():
    """Expect an AP of 1.0 from 'voc2010' when every sample is positive."""
    scores = np.array([0.92, 0.99, 0.97])
    labels = np.array([1., 1., 1.])
    expected = 1.
    tolerance = 1e-6

    computed = average_precision(labels, scores, integration='voc2010')

    assert abs(computed - expected) < tolerance