def test_no_update():
    """Calling ``compute()`` on a fresh metric with no updates must raise."""
    metric = AveragePrecision()
    expected_msg = r"EpochMetric must have at least one example before it can be computed"
    with pytest.raises(NotComputableError, match=expected_msg):
        metric.compute()
def test_ap_score_2():
    """AP accumulated over mini-batches must match sklearn on the full data.

    Builds 100 scores with a shuffled, half-positive binary target. It feeds
    them to ``AveragePrecision`` in mini-batches and compares the result with
    ``average_precision_score`` computed on all samples at once.
    """
    np.random.seed(1)
    size = 100
    np_y_pred = np.random.rand(size, 1)
    # ``np.long`` no longer exists in NumPy >= 1.24; int64 maps to torch.long.
    np_y = np.zeros((size,), dtype=np.int64)
    np_y[size // 2:] = 1
    np.random.shuffle(np_y)
    np_ap = average_precision_score(np_y, np_y_pred)

    ap_metric = AveragePrecision()
    y_pred = torch.from_numpy(np_y_pred)
    y = torch.from_numpy(np_y)

    ap_metric.reset()
    n_iters = 10
    batch_size = size // n_iters
    # Step over the whole range so no trailing samples are dropped even if
    # ``size`` is not a multiple of ``n_iters``.
    for idx in range(0, size, batch_size):
        ap_metric.update((y_pred[idx:idx + batch_size], y[idx:idx + batch_size]))

    ap = ap_metric.compute()
    assert ap == pytest.approx(np_ap)
def _test(y_pred, y, n_iters, metric_device):
    """Check ``AveragePrecision`` on ``metric_device`` against sklearn.

    With ``n_iters <= 1`` the inputs are fed in a single ``update`` call.
    Otherwise they are fed in ``n_iters`` roughly equal mini-batches. The
    metric result is then compared with ``average_precision_score`` on the
    ``y_pred``/``y`` tensors gathered via ``idist.all_gather``.

    NOTE(review): ``rank`` is a free variable resolved from an enclosing
    scope that is not visible here.
    """
    metric_device = torch.device(metric_device)
    ap = AveragePrecision(device=metric_device)
    torch.manual_seed(10 + rank)

    ap.reset()
    # Feed either the whole tensors or the mini-batches, never both, so that
    # every sample is counted exactly once and the batched path is tested
    # on its own.
    if n_iters > 1:
        batch_size = y.shape[0] // n_iters + 1
        # Stepping over the range avoids issuing empty trailing slices.
        for idx in range(0, y.shape[0], batch_size):
            ap.update((y_pred[idx:idx + batch_size], y[idx:idx + batch_size]))
    else:
        ap.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y = y.cpu().numpy()
    np_y_pred = y_pred.cpu().numpy()

    res = ap.compute()
    assert isinstance(res, float)
    assert average_precision_score(np_y, np_y_pred) == pytest.approx(res)
def test_ap_score():
    """Multilabel AP from a single update must match sklearn's result.

    Uses 100 samples with 5 labels each: random scores in [0, 1) and random
    binary targets.
    """
    # Seed so a failure is reproducible.
    np.random.seed(0)
    size = 100
    np_y_pred = np.random.rand(size, 5)
    # ``np.long`` no longer exists in NumPy >= 1.24; int64 maps to torch.long.
    np_y = np.random.randint(0, 2, size=(size, 5), dtype=np.int64)
    np_ap = average_precision_score(np_y, np_y_pred)

    ap_metric = AveragePrecision()
    y_pred = torch.from_numpy(np_y_pred)
    y = torch.from_numpy(np_y)

    ap_metric.reset()
    ap_metric.update((y_pred, y))
    ap = ap_metric.compute()

    assert ap == pytest.approx(np_ap)
def test_ap_score_with_activation():
    """AP with a softmax ``activation`` must match sklearn on softmaxed scores.

    The raw scores are passed to the metric, which is configured with
    ``torch.nn.Softmax(dim=1)``. The reference applies ``torch.softmax`` along
    dim 1 before calling ``average_precision_score``.
    """
    # Seed so a failure is reproducible.
    np.random.seed(0)
    size = 100
    np_y_pred = np.random.rand(size, 5)
    np_y_pred_softmax = torch.softmax(torch.from_numpy(np_y_pred), dim=1).numpy()
    # ``np.long`` no longer exists in NumPy >= 1.24; int64 maps to torch.long.
    np_y = np.random.randint(0, 2, size=(size, 5), dtype=np.int64)
    np_ap = average_precision_score(np_y, np_y_pred_softmax)

    ap_metric = AveragePrecision(activation=torch.nn.Softmax(dim=1))
    y_pred = torch.from_numpy(np_y_pred)
    y = torch.from_numpy(np_y)

    ap_metric.reset()
    ap_metric.update((y_pred, y))
    ap = ap_metric.compute()

    assert ap == pytest.approx(np_ap)