def test_accuracy():
    """Check ``Accuracy`` on empty input, top-k settings and invalid arguments.

    The fixture is five samples scored over four classes. Computed
    percentages are compared with ``torch.allclose`` rather than exact float
    equality, so the test does not depend on the metric producing a
    bit-exact value.
    """
    # Empty prediction: the metric is expected to report zero.
    pred = torch.empty(0, 4)
    label = torch.empty(0)
    accuracy = Accuracy(topk=1)
    acc = accuracy(pred, label)
    assert acc.item() == 0

    pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5],
                         [0.1, 0.1, 0.2, 0.6],
                         [0.9, 0.0, 0.0, 0.1],
                         [0.4, 0.7, 0.1, 0.1],
                         [0.0, 0.0, 0.99, 0]])

    # Top-1: every label below is the arg-max class of its row.
    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
    accuracy = Accuracy(topk=1)
    acc = accuracy(pred, true_label)
    assert torch.allclose(acc, torch.tensor(100.0))

    # Top-1 with score thresh=0.8: only two rows have a top score above 0.8
    # (0.9 and 0.99), which matches the expected 2 / 5 = 40.
    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
    accuracy = Accuracy(topk=1, thresh=0.8)
    acc = accuracy(pred, true_label)
    assert torch.allclose(acc, torch.tensor(40.0))

    # Top-2: each label is one of the two highest-scored classes of its row.
    accuracy = Accuracy(topk=2)
    label = torch.Tensor([3, 2, 0, 0, 2]).long()
    acc = accuracy(pred, label)
    assert torch.allclose(acc, torch.tensor(100.0))

    # Top-1 and top-2 together: the result is iterated, one entry per k.
    accuracy = Accuracy(topk=(1, 2))
    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
    acc = accuracy(pred, true_label)
    for a in acc:
        assert torch.allclose(a, torch.tensor(100.0))

    # topk is larger than the number of predicted classes (4).
    with pytest.raises(AssertionError):
        accuracy = Accuracy(topk=5)
        accuracy(pred, true_label)

    # Wrong topk type. Construction stays inside the block because the
    # assertion may fire either at construction or at call time.
    with pytest.raises(AssertionError):
        accuracy = Accuracy(topk='wrong type')
        accuracy(pred, true_label)

    # Label has six entries while pred has five rows (size mismatch). The
    # tensor is built outside the block so that only the metric itself can
    # satisfy the expected AssertionError.
    label = torch.Tensor([2, 3, 0, 1, 2, 0]).long()
    with pytest.raises(AssertionError):
        accuracy = Accuracy()
        accuracy(pred, label)

    # Wrong pred dimension: an extra trailing axis is added to pred.
    with pytest.raises(AssertionError):
        accuracy = Accuracy()
        accuracy(pred[:, :, None], true_label)
def test_accuracy():
    """Check ``Accuracy`` with ignore_index, top-k, threshold and bad inputs.

    Scenarios run in a fixed order against one score matrix of five samples
    over four classes; the expected percentages are compared with
    ``torch.allclose``.
    """

    def as_labels(values):
        # Build an integer label tensor from a plain list of class ids.
        return torch.Tensor(values).long()

    def expect_percent(result, percent):
        # Tolerance-based comparison against the expected percentage.
        assert torch.allclose(result, torch.tensor(percent))

    # test for empty pred
    empty_scores = torch.empty(0, 4)
    empty_target = torch.empty(0)
    metric = Accuracy(topk=1)
    assert metric(empty_scores, empty_target).item() == 0

    scores = torch.Tensor([
        [0.2, 0.3, 0.6, 0.5],
        [0.1, 0.1, 0.2, 0.6],
        [0.9, 0.0, 0.0, 0.1],
        [0.4, 0.7, 0.1, 0.1],
        [0.0, 0.0, 0.99, 0],
    ])

    # (labels, ignore_index, expected percentage), in the original order:
    #   1. ignore_index disabled
    #   2. ignore_index 1 with a wrong prediction of that index
    #   3. ignore_index 1 with a wrong prediction of other index
    #   4. ignore_index 4 with a wrong prediction of other index
    #   5. ignoring all the pixels
    ignore_cases = (
        ([2, 3, 0, 1, 2], None, 100.0),
        ([2, 3, 1, 1, 2], 1, 100.0),
        ([2, 0, 0, 1, 2], 1, 75.0),
        ([2, 0, 0, 1, 2], 4, 80.0),
        ([2, 2, 2, 2, 2], 2, 100.0),
    )
    for class_ids, ignored, percent in ignore_cases:
        target = as_labels(class_ids)
        metric = Accuracy(topk=1, ignore_index=ignored)
        expect_percent(metric(scores, target), percent)

    # test for top1
    gt = as_labels([2, 3, 0, 1, 2])
    expect_percent(Accuracy(topk=1)(scores, gt), 100.0)

    # test for top1 with score thresh=0.8
    gt = as_labels([2, 3, 0, 1, 2])
    expect_percent(Accuracy(topk=1, thresh=0.8)(scores, gt), 40.0)

    # test for top2
    metric = Accuracy(topk=2)
    top2_target = as_labels([3, 2, 0, 0, 2])
    expect_percent(metric(scores, top2_target), 100.0)

    # test for both top1 and top2
    metric = Accuracy(topk=(1, 2))
    gt = as_labels([2, 3, 0, 1, 2])
    for per_k in metric(scores, gt):
        expect_percent(per_k, 100.0)

    # topk is larger than pred class number
    with pytest.raises(AssertionError):
        Accuracy(topk=5)(scores, gt)

    # wrong topk type
    with pytest.raises(AssertionError):
        Accuracy(topk='wrong type')(scores, gt)

    # label size is larger than required
    with pytest.raises(AssertionError):
        oversized = as_labels([2, 3, 0, 1, 2, 0])  # size mismatch
        Accuracy()(scores, oversized)

    # wrong pred dimension
    with pytest.raises(AssertionError):
        Accuracy()(scores[:, :, None], gt)