def _test(y_pred, y, batch_size, metric_device):
    """Check a (possibly distributed) PrecisionRecallCurve against sklearn.

    The local ``y_pred``/``y`` are fed to the metric either in ``batch_size``
    chunks or in one call. The inputs are then all-gathered across processes,
    and the metric's computed curve is compared to sklearn's
    ``precision_recall_curve`` on the gathered data.

    Args:
        y_pred: local prediction scores (tensor).
        y: local targets (tensor).
        batch_size: chunk size for incremental updates; ``<= 1`` means a
            single ``update`` with all data.
        metric_device: device (or device string) the metric accumulates on.

    Note:
        ``rank`` is not defined in this function; it is read from the
        enclosing scope.
    """
    metric_device = torch.device(metric_device)
    prc = PrecisionRecallCurve(device=metric_device)

    # NOTE(review): y_pred / y are already created by the caller, so this seed
    # presumably only matters for any randomness inside the metric — confirm.
    torch.manual_seed(10 + rank)

    prc.reset()
    if batch_size > 1:
        # Ceil division: covers every sample without ever feeding an empty
        # trailing batch when the length is a multiple of batch_size.
        n_iters = (y.shape[0] + batch_size - 1) // batch_size
        for i in range(n_iters):
            idx = i * batch_size
            prc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        prc.update((y_pred, y))

    # gather y_pred, y so the sklearn reference sees the same global data as the metric
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y = y.cpu().numpy()
    np_y_pred = y_pred.cpu().numpy()

    res = prc.compute()

    assert isinstance(res, tuple)
    # Compute the reference curve once instead of once per component.
    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
    assert sk_precision == pytest.approx(res[0].cpu().numpy())
    assert sk_recall == pytest.approx(res[1].cpu().numpy())
    assert sk_thresholds == pytest.approx(res[2].cpu().numpy())
def test_no_sklearn(mock_no_sklearn):
    """Using PrecisionRecallCurve without sklearn must raise a RuntimeError."""
    expected_message = r"This contrib module requires sklearn to be installed."
    with pytest.raises(RuntimeError, match=expected_message):
        target = torch.tensor([1, 1])
        metric = PrecisionRecallCurve()
        # Predictions and targets deliberately share the same tensor.
        metric.update((target, target))
        metric.compute()
def test_check_compute_fn():
    """``check_compute_fn`` controls whether a failing compute_fn warns on update.

    The inputs are 2-D ``(8, 13)`` tensors. With ``check_compute_fn=True``,
    the first ``update`` is expected to emit an ``EpochMetricWarning``. With
    ``check_compute_fn=False``, the same ``update`` runs without that check.
    """
    scores = torch.zeros((8, 13))
    scores[:, 1] = 1
    labels = torch.zeros_like(scores)
    batch = (scores, labels)

    checked_metric = PrecisionRecallCurve(check_compute_fn=True)
    checked_metric.reset()
    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
        checked_metric.update(batch)

    unchecked_metric = PrecisionRecallCurve(check_compute_fn=False)
    unchecked_metric.update(batch)
def test_precision_recall_curve_integer_targets():
    """Compare PrecisionRecallCurve with sklearn when targets are int64.

    This test previously shared the name ``test_precision_recall_curve`` with a
    later definition in the module. The later definition overwrote it, so
    pytest never collected or ran it. It now has a distinct name.
    """
    size = 100
    np_y_pred = np.random.rand(size, 1)
    # np.long is unavailable on NumPy 1.24-1.26 (removed with the other
    # deprecated builtin aliases); use an explicit fixed-width integer dtype.
    np_y = np.zeros((size,), dtype=np.int64)
    np_y[size // 2 :] = 1
    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)

    precision_recall_curve_metric = PrecisionRecallCurve()
    y_pred = torch.from_numpy(np_y_pred)
    y = torch.from_numpy(np_y)

    precision_recall_curve_metric.update((y_pred, y))
    precision, recall, thresholds = precision_recall_curve_metric.compute()
    # np.asarray accepts both CPU tensors and ndarrays.
    precision = np.asarray(precision)
    recall = np.asarray(recall)
    thresholds = np.asarray(thresholds)

    # Values round-trip numpy->torch->numpy, so compare approximately rather
    # than with exact equality.
    assert pytest.approx(precision) == sk_precision
    assert pytest.approx(recall) == sk_recall
    np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
def test_precision_recall_curve():
    """PrecisionRecallCurve should match sklearn on random scores with float targets.

    The first half of the 100 samples is labelled 0 and the second half 1.
    Scores are uniform random values of shape ``(100, 1)``.
    """
    n_samples = 100
    scores = np.random.rand(n_samples, 1)
    targets = np.zeros((n_samples,))
    targets[n_samples // 2 :] = 1

    expected_precision, expected_recall, expected_thresholds = precision_recall_curve(targets, scores)

    metric = PrecisionRecallCurve()
    metric.update((torch.from_numpy(scores), torch.from_numpy(targets)))
    precision, recall, thresholds = (component.numpy() for component in metric.compute())

    assert pytest.approx(precision) == expected_precision
    assert pytest.approx(recall) == expected_recall
    # Thresholds pass through numpy->torch->numpy, so only near-equality is expected.
    np.testing.assert_array_almost_equal(thresholds, expected_thresholds)