Exemplo n.º 1
0
def test_that_precision_metric_calculates_multiclass_correctly(y_ypred):
    """Compare averaged ``Precision`` with a manually computed reference.

    The metric is exercised twice on the same data:

    1. the whole batch is passed to a single ``update`` call and the result
       is compared against a reference computed directly from the tensors;
    2. every sample is passed to its own ``update`` call, and the result must
       equal the value from the single-batch run.
    """
    predicted, true = y_ypred
    # Small random noise on top of the probabilities; the very same perturbed
    # tensor feeds both the metric and the manual reference below.
    predicted = predicted.softmax(-1) + torch.randn_like(predicted) / 100
    num_classes = predicted.size(1)

    # Test adding a whole batch at once.
    precision = Precision(num_classes=num_classes, average=True)
    precision.update(predicted, true)

    # Boolean one-hot masks: the argmax class of each row, and the true class.
    predicted_mask = torch.zeros_like(predicted).scatter(
        1, predicted.argmax(-1).unsqueeze(-1), 1) > 0
    true_mask = torch.zeros_like(predicted).scatter(
        1, true.unsqueeze(-1), 1) > 0

    # Per-class count of samples whose argmax matches the true class.
    true_positive = (predicted_mask & true_mask).sum(dim=0)
    # NOTE(review): this denominator counts samples whose *true* label is each
    # class (TP + FN), i.e. the recall formula; textbook precision divides by
    # the number of *predicted* samples per class (TP + FP). Kept as-is so the
    # reference keeps matching the current Precision implementation; confirm
    # whether that is intended.
    all_positive = true_mask.sum(dim=0)
    manual_precision = true_positive.float() / all_positive.float()
    # A class with no samples gives 0 / 0 = NaN; count it as 0.
    manual_precision[torch.isnan(manual_precision)] = 0
    assert np.allclose(precision.value, manual_precision.mean())

    # Test adding all elements of the batch separately, as batches of one.
    new_precision = Precision(num_classes=num_classes, average=True)
    for x, y in zip(predicted, true):
        new_precision.update(x.unsqueeze(0), y.unsqueeze(0))
    assert np.allclose(new_precision.value, precision.value)
Exemplo n.º 2
0
def test_precision():
    """Check ``Precision`` on a tiny hand-made multiclass example.

    Two samples with three class scores each are fed one at a time. In both
    rows the highest score is at index 1; the true classes are 2 and 1.
    The per-class result must be ``[0, 1, 0]``. With ``average=True``, the
    result must round to 0.33, the mean of those per-class values.

    NOTE(review): for class 1 the argmax is hit twice with one correct, so
    TP / (TP + FP) would be 0.5; the expected 1 corresponds to
    TP / (TP + FN). Confirm that this convention is intended.
    """

    def _updated_calculator(average):
        # Build a fresh metric and feed it each sample as its own batch of one.
        # Multiclass classification: one row of class scores per sample.
        pred = torch.FloatTensor([[0.4, 0.5, 0.1], [0.1, 0.8, 0.1]])
        target = torch.LongTensor([[2], [1]])
        prec_calculator = Precision(num_classes=3, average=average)
        # Iterating a 2-D tensor yields its rows; unsqueeze(0) restores the
        # batch dimension: scores (1, 3), target (1, 1).
        for output_row, target_row in zip(pred, target):
            prec_calculator.update(output=output_row.unsqueeze(0),
                                   target=target_row.unsqueeze(0))
        return prec_calculator

    # Averaged over the classes.
    prec_calculator = _updated_calculator(average=True)
    assert round(prec_calculator.value, 2) == 0.33

    # One value per class.
    prec_calculator = _updated_calculator(average=False)
    assert np.allclose(np.array(prec_calculator.value), np.array([0, 1, 0]))
Exemplo n.º 3
0
def test_that_precision_string_repr_doesnt_throw_errors(y_ypred):
    """Render an updated ``Precision`` with ``str()`` and expect a ``±``."""
    logits, labels = y_ypred
    # Probabilities with a little random noise added on top.
    noisy_probs = logits.softmax(-1) + torch.randn_like(logits) / 100
    metric = Precision(num_classes=noisy_probs.size(1), average=True)
    metric.update(noisy_probs, labels)
    rendered = str(metric)
    assert "±" in rendered