def test_no_update():
    """Calling compute() on a fresh ROC_AUC, with no update(), must raise NotComputableError."""
    expected_msg = r"EpochMetric must have at least one example before it can be computed"
    metric = ROC_AUC()
    with pytest.raises(NotComputableError, match=expected_msg):
        metric.compute()
def test_roc_auc_score_2():
    """ROC_AUC fed in mini-batches must match sklearn's roc_auc_score on the full data.

    Builds 100 seeded random scores and shuffled balanced binary targets. It streams
    them into the metric in equally sized batches and compares the accumulated
    result with ``roc_auc_score`` computed on the whole arrays at once.
    """
    np.random.seed(1)
    size = 100
    np_y_pred = np.random.rand(size, 1)
    # np.long was a deprecated alias of builtin int, removed in NumPy 1.24.
    # int64 always maps to a torch LongTensor, regardless of platform.
    np_y = np.zeros((size,), dtype=np.int64)
    np_y[size // 2 :] = 1
    np.random.shuffle(np_y)
    np_roc_auc = roc_auc_score(np_y, np_y_pred)

    roc_auc_metric = ROC_AUC()
    y_pred = torch.from_numpy(np_y_pred)
    y = torch.from_numpy(np_y)

    roc_auc_metric.reset()
    n_iters = 10
    batch_size = size // n_iters
    # Stepping by batch_size covers every sample, including any remainder
    # when size is not a multiple of n_iters.
    for idx in range(0, size, batch_size):
        roc_auc_metric.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    roc_auc = roc_auc_metric.compute()

    # Tolerant float comparison instead of exact equality.
    assert roc_auc == pytest.approx(np_roc_auc)
def _test(y_pred, y, batch_size, metric_device):
    """Check ROC_AUC against sklearn's roc_auc_score on (possibly distributed) inputs.

    The local ``y_pred`` / ``y`` are fed to the metric in one of two ways:
    - in mini-batches of ``batch_size``, when ``batch_size > 1``;
    - in a single update, otherwise.

    Both tensors are then gathered across processes with ``idist.all_gather``.
    The metric's result must be a float approximately equal to ``roc_auc_score``
    on the gathered arrays.

    Args:
        y_pred: prediction tensor passed to ``ROC_AUC.update``.
        y: target tensor; its first dimension gives the number of samples.
        batch_size: mini-batch size; values ``<= 1`` mean a single update.
        metric_device: device spec accepted by ``torch.device`` for the metric.
    """
    metric_device = torch.device(metric_device)
    roc_auc = ROC_AUC(device=metric_device)

    # NOTE(review): `rank` is not defined in this function. It presumably comes from
    # an enclosing test function (e.g. idist.get_rank()) — confirm.
    # No visible call below draws from the torch RNG, so this reseed mainly affects
    # the caller's later draws. It is kept unchanged to preserve that behaviour.
    torch.manual_seed(10 + rank)

    roc_auc.reset()
    if batch_size > 1:
        # Step by batch_size so every sample is seen exactly once. The previous
        # `n // batch_size + 1` iteration count issued a trailing empty update
        # whenever n was a multiple of batch_size.
        n_samples = y.shape[0]
        for idx in range(0, n_samples, batch_size):
            roc_auc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        roc_auc.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y = y.cpu().numpy()
    np_y_pred = y_pred.cpu().numpy()

    res = roc_auc.compute()
    assert isinstance(res, float)
    assert roc_auc_score(np_y, np_y_pred) == pytest.approx(res)
def test_roc_auc_score():
    """ROC_AUC on a single full-size batch must match sklearn's roc_auc_score.

    Uses 100 random scores against targets whose first half is 0 and second half
    is 1. All data goes to the metric in one update call.
    """
    size = 100
    np_y_pred = np.random.rand(size, 1)
    # np.long was a deprecated alias of builtin int, removed in NumPy 1.24.
    # int64 always maps to a torch LongTensor, regardless of platform.
    np_y = np.zeros((size,), dtype=np.int64)
    np_y[size // 2 :] = 1
    np_roc_auc = roc_auc_score(np_y, np_y_pred)

    roc_auc_metric = ROC_AUC()
    y_pred = torch.from_numpy(np_y_pred)
    y = torch.from_numpy(np_y)

    roc_auc_metric.reset()
    roc_auc_metric.update((y_pred, y))
    roc_auc = roc_auc_metric.compute()

    # Tolerant float comparison instead of exact equality.
    assert roc_auc == pytest.approx(np_roc_auc)