def test_accuracy():
    """Check top-2 accuracy, per-sample and averaged, on a two-sample batch.

    Samples are fed to the metric one at a time. The expected results are:
    the first sample misses (class 2 has the lowest score of three), the
    second hits, so the averaged value is 0.5.
    """
    per_sample_metric = Accuracy(average=False, topk=(2, ))

    # Multiclass classification: two samples scored over three classes.
    pred = torch.FloatTensor([[0.4, 0.5, 0.1], [0.1, 0.8, 0.1]])
    target = torch.LongTensor([[2], [1]])
    assert pred.size() == torch.Size([2, 3])

    # Without averaging, the metric keeps one row per update and one
    # column per requested k.
    for scores, label in zip(pred, target):
        per_sample_metric.update(output=scores.unsqueeze(0), target=label)
    per_sample_value = per_sample_metric.value
    assert per_sample_value.shape == (2, 1)
    expected = np.array([[0], [1]])
    assert np.allclose(per_sample_value, expected)

    # With averaging, the same two updates collapse to their mean.
    averaged_metric = Accuracy(average=True, topk=(2, ))
    for scores, label in zip(pred, target):
        averaged_metric.update(output=scores.unsqueeze(0), target=label)
    assert averaged_metric.value == 0.5
def test_that_accuracy_metric_calculates_multiclass_correctly(y_ypred):
    """Top-1 accuracy must match a direct argmax comparison.

    The result must also be the same whether the data is added as one batch
    or one sample at a time.
    """
    predicted, true = y_ypred
    probabilities = predicted.softmax(-1)
    noise = torch.randn_like(predicted) / 100
    predicted = probabilities + noise

    # Whole batch in a single update.
    batch_metric = Accuracy(average=True, topk=(1, ))
    batch_metric.update(predicted, true)
    hits = (true == predicted.argmax(-1)).float()
    reference = hits.mean().item()
    assert np.allclose(batch_metric.value, reference)

    # Every element fed separately as a batch of size one.
    single_metric = Accuracy(average=True, topk=(1, ))
    for sample, label in zip(predicted, true):
        single_metric.update(sample.unsqueeze(0), label.unsqueeze(0))
    assert np.allclose(single_metric.value, batch_metric.value)
def test_that_accuracy_metric_calculates_top_k_correctly(y_ypred):
    """Top-k accuracy must never decrease as k grows.

    Checked twice: once with a separate metric per k, and once with a single
    metric that is asked for every k from 1 to the number of classes.
    """
    predicted, true = y_ypred
    predicted = predicted.softmax(-1) + torch.randn_like(predicted) / 100
    num_classes = predicted.size(1)

    # One metric per k: each value must stay within [previous value, 1].
    last_accuracy = 0
    for k in range(1, num_classes + 1):
        accuracy_calculator = Accuracy(average=True, topk=(k, ))
        accuracy_calculator.update(predicted, true)
        assert 1 >= accuracy_calculator.value >= last_accuracy
        last_accuracy = accuracy_calculator.value

    # One metric for all k at once.
    accuracy_calculator = Accuracy(
        average=True, topk=tuple(range(1, num_classes + 1)))
    accuracy_calculator.update(predicted, true)
    values = accuracy_calculator.value
    # Compare every adjacent pair (k, k + 1). Slicing with [::2] / [1::2]
    # would only cover the disjoint pairs (0, 1), (2, 3), ... and skip the
    # pairs in between.
    # NOTE(review): assumes `value` is ordered by k along its first axis,
    # as the original slicing already did.
    assert all(val_1 >= val_0 for val_0, val_1 in zip(values[:-1], values[1:]))
def test_that_accuracy_raises_errors_when_shapes_dont_match():
    """Updating with 5 predictions but only 3 targets must raise ValueError."""
    metric = Accuracy(average=True, topk=(1, ))
    scores = torch.randn(5, 3)
    labels = torch.tensor([0, 0, 0])
    with pytest.raises(ValueError):
        metric.update(scores, labels)
def test_that_accuracy_string_repr_doesnt_throw_errors(y_ypred):
    """Converting an updated metric to a string must work and contain "±"."""
    predicted, true = y_ypred
    probabilities = predicted.softmax(-1)
    predicted = probabilities + torch.randn_like(predicted) / 100

    metric = Accuracy(average=True, topk=(1, 2))
    metric.update(predicted, true)

    rendered = str(metric)
    assert "±" in rendered