Example #1
import numpy as np
import torch

# `Accuracy` is imported from the metrics library under test; its exact
# import path is not shown in this listing.

def test_accuracy():
    acc_calculator = Accuracy(average=False, topk=(2, ))

    # start with multiclass classification
    pred = torch.FloatTensor([[0.4, 0.5, 0.1], [0.1, 0.8, 0.1]])
    target = torch.LongTensor([[2], [1]])

    assert pred.size() == torch.Size([2, 3])

    # feed the two samples one at a time, as single-element batches
    for i in range(2):
        acc_calculator.update(output=pred[i, :].unsqueeze(0),
                              target=target[i, :])

    assert acc_calculator.value.shape == (2, 1)
    assert np.allclose(acc_calculator.value, np.array([[0], [1]]))

    acc_calculator = Accuracy(average=True, topk=(2, ))

    # same stream of single-sample updates, now with averaging enabled
    for i in range(2):
        acc_calculator.update(output=pred[i, :].unsqueeze(0),
                              target=target[i, :])

    assert acc_calculator.value == 0.5
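
Examples #2, #3, and #5 below receive a `y_ypred` pytest fixture that is not shown in this listing. A minimal sketch of what such a fixture could look like, assuming a small random multiclass batch (the batch size, class count, and seed are illustrative, not taken from the source):

import pytest
import torch

@pytest.fixture
def y_ypred():
    # hypothetical fixture: raw scores for 32 samples over 5 classes,
    # paired with matching integer class labels
    torch.manual_seed(0)
    predicted = torch.randn(32, 5)
    true = torch.randint(0, 5, (32, ))
    return predicted, true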
Example #2
def test_that_accuracy_metric_calculates_multiclass_correctly(y_ypred):
    predicted, true = y_ypred
    # add a little noise so ties between equal probabilities are broken
    predicted = predicted.softmax(-1) + torch.randn_like(predicted) / 100
    # test adding a batch:
    accuracy_calculator = Accuracy(average=True, topk=(1, ))
    accuracy_calculator.update(predicted, true)
    assert np.allclose(accuracy_calculator.value,
                       (true == predicted.argmax(-1)).float().mean().item())
    # test adding all elements in a batch separately
    new_accuracy_calculator = Accuracy(average=True, topk=(1, ))
    for x, y in zip(predicted, true):
        x, y = x.unsqueeze(0), y.unsqueeze(0)
        new_accuracy_calculator.update(x, y)
    assert np.allclose(new_accuracy_calculator.value,
                       accuracy_calculator.value)
Example #3
def test_that_accuracy_metric_calculates_top_k_correctly(y_ypred):
    predicted, true = y_ypred
    predicted = predicted.softmax(-1) + torch.randn_like(predicted) / 100
    last_accuracy = 0
    for k in range(1, predicted.size(1) + 1):
        accuracy_calculator = Accuracy(average=True, topk=(k, ))
        accuracy_calculator.update(predicted, true)
        assert 1 >= accuracy_calculator.value >= last_accuracy
        last_accuracy = accuracy_calculator.value

    accuracy_calculator = Accuracy(average=True,
                                   topk=tuple(range(1,
                                                    predicted.size(1) + 1)))
    accuracy_calculator.update(predicted, true)
    # top-k accuracy must be non-decreasing in k, so compare each
    # consecutive pair of values rather than skipping every other one
    assert all(val_1 >= val_0 for val_0, val_1 in zip(
        accuracy_calculator.value, accuracy_calculator.value[1:]))
Example #4
def test_that_accuracy_raises_errors_when_shapes_dont_match():
    predicted = torch.randn(5, 3)   # a batch of five predictions...
    true = torch.tensor([0, 0, 0])  # ...but only three targets
    accuracy_calculator = Accuracy(average=True, topk=(1, ))
    with pytest.raises(ValueError):
        accuracy_calculator.update(predicted, true)
Example #5
def test_that_accuracy_string_repr_doesnt_throw_errors(y_ypred):
    predicted, true = y_ypred
    predicted = predicted.softmax(-1) + torch.randn_like(predicted) / 100
    accuracy = Accuracy(average=True, topk=(1, 2))
    accuracy.update(predicted, true)
    assert "±" in str(accuracy)