예제 #1
0
def test_warning_on_difference_in_number_of_classes():
    """Test that a warning is raised when the detected number of classes differs
    from the specified number of classes.

    ``preds`` and ``target`` only contain labels drawn from ``[0, 3)``, so at
    most 3 classes can be detected, while ``num_classes=4`` is requested.
    """
    preds = torch.randint(3, (10,))
    target = torch.randint(3, (10,))
    with pytest.warns(RuntimeWarning):
        jaccard_index(preds, target, num_classes=4)
예제 #2
0
@torch.no_grad()
def test(loader):
    """Evaluate ``model`` on ``loader`` and return the mean ShapeNet part IoU.

    For every shape in each batch, the predicted logits are restricted to the
    part labels of that shape's category. Targets are remapped to local part
    indices, and a per-shape Jaccard index is computed. Shape IoUs are averaged
    per category, and those per-category means are averaged into the result.

    Relies on the module-level names ``model``, ``device``, ``ShapeNet``,
    ``jaccard_index`` and ``scatter``.

    Args:
        loader: Iterable of batches exposing ``ptr``, ``y`` and ``category``;
            ``loader.dataset`` must expose ``num_classes``.

    Returns:
        float: mean over categories of the per-category mean IoU.
    """
    model.eval()

    # Category index -> category name; loop-invariant, so build it once.
    category_names = list(ShapeNet.seg_classes.keys())

    ious, categories = [], []
    # Scratch lookup from global part label to local part index. Only the
    # entries for the current shape's parts are written before being read.
    y_map = torch.empty(loader.dataset.num_classes, device=device).long()
    for data in loader:
        data = data.to(device)
        outs = model(data)

        # Number of points belonging to each shape in the batch.
        sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
        for out, y, category in zip(outs.split(sizes), data.y.split(sizes),
                                    data.category.tolist()):
            part = ShapeNet.seg_classes[category_names[category]]
            part = torch.tensor(part, device=device)

            y_map[part] = torch.arange(part.size(0), device=device)

            iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y],
                                num_classes=part.size(0), absent_score=1.0)
            ious.append(iou)

        categories.append(data.category)

    # The per-shape IoUs are tensors already; stack them instead of rebuilding
    # a new tensor element by element with torch.tensor(...).
    iou = torch.stack(ious).to(device)
    category = torch.cat(categories, dim=0)

    mean_iou = scatter(iou, category, reduce='mean')  # Per-category IoU.
    return float(mean_iou.mean())  # Global IoU.
예제 #3
0
def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
    """Compare ``jaccard_index`` with ``ignore_index`` set against a reference.

    The raw ``pred``/``target``/``expected`` inputs are wrapped with ``tensor``.
    The reference is cast to the dtype and device of the computed score before
    the comparison.
    """
    call_kwargs = dict(
        preds=tensor(pred),
        target=tensor(target),
        ignore_index=ignore_index,
        num_classes=num_classes,
        reduction=reduction,
    )
    score = jaccard_index(**call_kwargs)
    reference = tensor(expected).to(score)
    assert torch.allclose(score, reference)
예제 #4
0
def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
    """Compare per-class ``jaccard_index`` scores against a reference when
    ``absent_score`` is supplied.

    ``reduction="none"`` keeps one score per class. The reference is cast to
    the dtype and device of the computed scores before comparison.
    """
    call_kwargs = dict(
        preds=tensor(pred),
        target=tensor(target),
        ignore_index=ignore_index,
        absent_score=absent_score,
        num_classes=num_classes,
        reduction="none",
    )
    per_class = jaccard_index(**call_kwargs)
    reference = tensor(expected).to(per_class)
    assert torch.allclose(per_class, reference)
예제 #5
0
def test_jaccard(half_ones, reduction, ignore_index, expected):
    """Check ``jaccard_index`` on a cyclic three-class labelling.

    ``target`` holds 120 labels cycling 0, 1, 2 with shape ``(120, 1)``.
    Predictions start as an exact copy of the target. When ``half_ones`` is
    set, the first 60 predictions are overwritten with class 1.
    """
    labels = (torch.arange(120) % 3).view(-1, 1)
    preds = labels.clone()
    target = labels
    if half_ones:
        preds[:60] = 1
    jaccard_val = jaccard_index(
        preds=preds,
        target=target,
        num_classes=3,
        ignore_index=ignore_index,
        reduction=reduction,
    )
    # torch.allclose requires both operands to share a dtype, so align the
    # reference with the computed score (also accepts non-tensor expectations).
    expected = torch.as_tensor(expected).to(jaccard_val)
    assert torch.allclose(jaccard_val, expected, atol=1e-9)