def _test(y_pred, y, n_iters, metric_device):
    """Compare ``CohenKappa`` with scikit-learn's ``cohen_kappa_score``.

    The metric is first fed the complete ``(y_pred, y)`` pair. When
    ``n_iters > 1`` the same data is then fed a second time, sliced into
    ``n_iters`` consecutive chunks. The reference value is computed on the
    tensors returned by ``idist.all_gather``.

    Args:
        y_pred: predictions tensor passed to the metric.
        y: targets tensor; ``y.shape[0]`` determines the chunk size.
        n_iters: number of chunked updates performed when greater than 1.
        metric_device: anything accepted by ``torch.device``; forwarded to
            ``CohenKappa(device=...)``.

    NOTE(review): ``rank`` is not defined in this function; it is presumably
    bound in an enclosing or module scope (e.g. ``idist.get_rank()``) --
    confirm against the caller.
    """
    metric = CohenKappa(device=torch.device(metric_device))

    torch.manual_seed(10 + rank)

    metric.reset()
    metric.update((y_pred, y))

    if n_iters > 1:
        # Chunk size is always >= 1, so the stepped range below yields
        # exactly n_iters start offsets: 0, chunk, 2 * chunk, ...
        chunk = y.shape[0] // n_iters + 1
        for start in range(0, n_iters * chunk, chunk):
            stop = start + chunk
            metric.update((y_pred[start:stop], y[start:stop]))

    # Gather predictions first, then targets, before computing the metric.
    gathered_pred = idist.all_gather(y_pred)
    gathered_y = idist.all_gather(y)

    np_y = gathered_y.cpu().numpy()
    np_y_pred = gathered_pred.cpu().numpy()

    result = metric.compute()
    assert isinstance(result, float)
    assert cohen_kappa_score(np_y, np_y_pred) == pytest.approx(result)
def test_input_types():
    """Check that ``CohenKappa.update`` rejects inputs inconsistent with stored ones.

    A first update supplies float predictions (``torch.rand``) and
    ``torch.long`` targets, both of shape ``(4, 3)``. Each following update
    is expected to raise ``ValueError`` with the matching message:

    * integer predictions from ``torch.randint`` instead of float ones,
    * ``torch.int32`` targets instead of ``torch.long`` ones,
    * ``long`` predictions of shape ``(10,)`` paired with ``(10, 5)`` targets.
    """
    metric = CohenKappa()
    metric.reset()

    # Reference update: these tensors define what the metric has stored.
    reference_output = (
        torch.rand(4, 3),
        torch.randint(0, 2, size=(4, 3), dtype=torch.long),
    )
    metric.update(reference_output)

    pred_mismatch = r"Incoherent types between input y_pred and stored predictions"
    target_mismatch = r"Incoherent types between input y and stored targets"

    # Builders are called lazily so each tensor pair is created inside the
    # corresponding ``pytest.raises`` context.
    bad_cases = [
        (
            pred_mismatch,
            lambda: (
                torch.randint(0, 5, size=(4, 3)),
                torch.randint(0, 2, size=(4, 3)),
            ),
        ),
        (
            target_mismatch,
            lambda: (
                torch.rand(4, 3),
                torch.randint(0, 2, size=(4, 3)).to(torch.int32),
            ),
        ),
        (
            pred_mismatch,
            lambda: (
                torch.randint(0, 2, size=(10,)).long(),
                torch.randint(0, 2, size=(10, 5)).long(),
            ),
        ),
    ]

    for expected_message, make_output in bad_cases:
        with pytest.raises(ValueError, match=expected_message):
            metric.update(make_output())
def test_cohen_kappa_all_weights(weights):
    """Check ``CohenKappa(weights=weights)`` against scikit-learn.

    Random binary predictions and targets of shape ``(100, 1)`` are scored
    both by ``cohen_kappa_score`` and by the metric under test, using the
    same ``weights`` setting on both sides.

    Args:
        weights: weighting scheme forwarded to ``CohenKappa`` and to
            ``cohen_kappa_score``. NOTE(review): presumably supplied by a
            pytest parametrization or fixture that is not visible here --
            confirm.
    """
    size = 100

    # ``np.long`` was removed in NumPy 1.24, so the explicit 64-bit integer
    # dtype is used; it maps to ``torch.int64`` through ``torch.from_numpy``.
    np_y_pred = np.random.randint(0, 2, size=(size, 1), dtype=np.int64)
    np_y = np.random.randint(0, 2, size=(size, 1), dtype=np.int64)

    # The reference must use the same weighting as the metric under test;
    # otherwise weighted variants are compared against the unweighted score.
    np_ck = cohen_kappa_score(np_y, np_y_pred, weights=weights)

    ck_metric = CohenKappa(weights=weights)
    y_pred = torch.from_numpy(np_y_pred)
    y = torch.from_numpy(np_y)

    ck_metric.reset()
    ck_metric.update((y_pred, y))
    ck = ck_metric.compute()

    assert ck == pytest.approx(np_ck)
def test_multilabel_inputs():
    """Check that 2-D binary inputs are rejected as multilabel-indicator data.

    For each column count in ``(4, 6, 8)`` a pair of random 0/1 ``long``
    tensors of shape ``(10, n_columns)`` is passed through
    ``reset`` / ``update`` / ``compute``; a ``ValueError`` mentioning
    "multilabel-indicator is not supported" must be raised somewhere in
    that sequence.
    """
    metric = CohenKappa()

    for n_columns in (4, 6, 8):
        shape = (10, n_columns)
        with pytest.raises(ValueError, match=r"multilabel-indicator is not supported"):
            metric.reset()
            # Predictions are drawn before targets, as a tuple literal would do.
            y_pred = torch.randint(0, 2, size=shape).long()
            y = torch.randint(0, 2, size=shape).long()
            metric.update((y_pred, y))
            metric.compute()