def test_warning_on_difference_in_number_of_classes():
    """Test that a warning is raised when the detected number of classes differs from ``num_classes``.

    ``preds`` and ``target`` only contain labels in ``[0, 3)``, so at most three
    classes can be detected, while ``num_classes=4`` is requested.
    """
    preds = torch.randint(3, (10,))
    target = torch.randint(3, (10,))
    with pytest.warns(RuntimeWarning):
        jaccard_index(preds, target, num_classes=4)
def test(loader):
    """Evaluate ``model`` on ``loader`` and return the mean part-segmentation IoU.

    For each shape in a batch, the network output is restricted to the part
    labels of that shape's ShapeNet category. The ground-truth labels are
    remapped to ``0..len(parts)-1``, and the Jaccard index is computed with
    ``absent_score=1.0``. Per-shape IoUs are averaged per category with
    ``scatter``. The result is then averaged over the categories that actually
    occur in ``loader``.

    Args:
        loader: Iterable of batches exposing ``ptr``, ``y`` and ``category``.
            ``loader.dataset.num_classes`` must also exist. Presumably this is
            a PyG ``DataLoader`` over ``ShapeNet`` — verify against caller.

    Returns:
        float: Mean, over present categories, of the per-category mean IoU.
    """
    model.eval()
    # Category names in index order; invariant, so build once rather than per shape.
    category_names = list(ShapeNet.seg_classes.keys())
    ious, categories = [], []
    # Scratch lookup table mapping global part labels to per-category part indices.
    y_map = torch.empty(loader.dataset.num_classes, device=device).long()
    # Inference only: do not build autograd graphs (harmless if already disabled).
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            outs = model(data)
            # Number of points belonging to each shape in the batch.
            sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
            for out, y, category in zip(outs.split(sizes), data.y.split(sizes),
                                        data.category.tolist()):
                part = ShapeNet.seg_classes[category_names[category]]
                part = torch.tensor(part, device=device)
                y_map[part] = torch.arange(part.size(0), device=device)
                iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y],
                                    num_classes=part.size(0), absent_score=1.0)
                ious.append(iou)
            categories.append(data.category)
    iou = torch.stack(ious)
    category = torch.cat(categories, dim=0)
    mean_iou = scatter(iou, category, reduce='mean')  # Per-category IoU.
    # scatter() sizes its output by category.max() + 1 and leaves empty bins at 0.
    # Average only over categories that actually occur so absent ones don't bias it.
    present = torch.unique(category)
    return float(mean_iou[present].mean())  # Mean over present categories.
def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
    """Check that ``jaccard_index`` honours ``ignore_index`` under the given reduction."""
    options = dict(
        ignore_index=ignore_index,
        num_classes=num_classes,
        reduction=reduction,
    )
    score = jaccard_index(preds=tensor(pred), target=tensor(target), **options)
    # Cast the reference to the result's dtype/device before comparing.
    reference = tensor(expected).to(score)
    assert torch.allclose(score, reference)
def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
    """Check the per-class scores of ``jaccard_index`` for the given ``absent_score``."""
    predicted, actual = tensor(pred), tensor(target)
    options = dict(
        ignore_index=ignore_index,
        absent_score=absent_score,
        num_classes=num_classes,
        reduction="none",
    )
    per_class = jaccard_index(preds=predicted, target=actual, **options)
    # Cast the reference to the result's dtype/device before comparing.
    assert torch.allclose(per_class, tensor(expected).to(per_class))
def test_jaccard(half_ones, reduction, ignore_index, expected):
    """Compare ``jaccard_index`` on a cyclic 3-class column against ``expected``.

    When ``half_ones`` is set, the first 60 of the 120 predictions are forced
    to class 1.
    """
    # Two independent (120, 1) tensors holding 0, 1, 2, 0, 1, 2, ...
    preds, target = ((torch.arange(120) % 3).view(-1, 1) for _ in range(2))
    if half_ones:
        preds[:60] = 1
    options = dict(num_classes=3, ignore_index=ignore_index, reduction=reduction)
    score = jaccard_index(preds=preds, target=target, **options)
    assert torch.allclose(score, expected, atol=1e-9)