def _test(y_pred, y, batch_size, metric_device):
    """Compare ``PrecisionRecallCurve`` with scikit-learn's reference curve.

    The metric is updated with ``(y_pred, y)``, either in slices of
    ``batch_size`` or in one single update, and its result is then checked
    against ``precision_recall_curve`` evaluated on the data returned by
    ``idist.all_gather``.

    Args:
        y_pred: predictions passed to the metric.
        y: targets passed to the metric.
        batch_size: slice length used for the updates; any value ``<= 1``
            sends everything in a single update.
        metric_device: device specification (anything ``torch.device``
            accepts) on which the metric is created.

    Raises:
        AssertionError: if the metric result is not a tuple or any of its
            three components differs from the reference.
    """
    metric_device = torch.device(metric_device)
    prc = PrecisionRecallCurve(device=metric_device)

    # NOTE: `rank` is not a parameter; it is read from the enclosing scope.
    torch.manual_seed(10 + rank)

    prc.reset()
    if batch_size > 1:
        # The extra iteration picks up a trailing partial batch. When the
        # sample count is an exact multiple of `batch_size`, the last slice
        # is empty.
        n_iters = y.shape[0] // batch_size + 1
        for i in range(n_iters):
            idx = i * batch_size
            prc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        prc.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y = y.cpu().numpy()
    np_y_pred = y_pred.cpu().numpy()

    res = prc.compute()
    assert isinstance(res, Tuple)

    # Evaluate the reference once instead of once per assertion.
    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)

    assert sk_precision == pytest.approx(res[0].cpu().numpy())
    assert sk_recall == pytest.approx(res[1].cpu().numpy())
    assert sk_thresholds == pytest.approx(res[2].cpu().numpy())
def test_check_compute_fn():
    """Verify that ``check_compute_fn`` toggles the ``compute_fn`` warning.

    With ``check_compute_fn=True`` the update on this ``(8, 13)`` input must
    emit an ``EpochMetricWarning``. With ``check_compute_fn=False`` the same
    update must not emit that warning.
    """
    import warnings

    y_pred = torch.zeros((8, 13))
    y_pred[:, 1] = 1
    y_true = torch.zeros_like(y_pred)
    output = (y_pred, y_true)

    em = PrecisionRecallCurve(check_compute_fn=True)
    em.reset()
    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
        em.update(output)

    em = PrecisionRecallCurve(check_compute_fn=False)
    # Escalate only EpochMetricWarning to an error so that the disabled check
    # is actually asserted; unrelated warnings are left untouched.
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=EpochMetricWarning)
        em.update(output)