def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
    """Check that ``iou`` returns ``expected`` for a given ``ignore_index``.

    ``pred``, ``target`` and ``expected`` are plain Python values that are
    converted with ``tensor`` before use; ``ignore_index``, ``num_classes``
    and ``reduction`` are forwarded to ``iou`` unchanged.
    (Presumably all arguments come from a parametrization that is not
    visible in this block.)
    """
    pred_tensor = tensor(pred)
    target_tensor = tensor(target)

    score = iou(
        preds=pred_tensor,
        target=target_tensor,
        ignore_index=ignore_index,
        num_classes=num_classes,
        reduction=reduction,
    )

    # ``.to(score)`` aligns the reference with the result's dtype/device so
    # ``allclose`` compares like with like.
    reference = tensor(expected).to(score)
    assert torch.allclose(score, reference)
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
    """Check that ``iou`` returns ``expected`` for a given ``absent_score``.

    The metric is always requested with ``reduction='none'``. ``pred``,
    ``target`` and ``expected`` are plain Python values converted with
    ``tensor``; ``ignore_index``, ``absent_score`` and ``num_classes`` are
    forwarded to ``iou`` unchanged.
    (Presumably all arguments come from a parametrization that is not
    visible in this block.)
    """
    pred_tensor = tensor(pred)
    target_tensor = tensor(target)

    score = iou(
        preds=pred_tensor,
        target=target_tensor,
        ignore_index=ignore_index,
        absent_score=absent_score,
        num_classes=num_classes,
        reduction='none',
    )

    # ``.to(score)`` aligns the reference with the result's dtype/device so
    # ``allclose`` compares like with like.
    reference = tensor(expected).to(score)
    assert torch.allclose(score, reference)
def test_iou(half_ones, reduction, ignore_index, expected):
    """Check ``iou`` on a synthetic three-label problem.

    Predictions and target both start as the repeating labels 0, 1, 2 laid
    out as a ``(120, 1)`` column, i.e. a perfect match. When ``half_ones``
    is true, the first 60 predictions are overwritten with label 1 so that
    they partly disagree with the target.

    ``reduction`` and ``ignore_index`` are forwarded to ``iou`` unchanged and
    the result is compared with ``expected`` using ``atol=1e-9``.
    (``expected`` is passed to ``torch.allclose`` as-is, so it is assumed to
    already be a tensor -- TODO confirm against the parametrization.)
    """
    preds = (torch.arange(120) % 3).view(-1, 1)
    target = (torch.arange(120) % 3).view(-1, 1)
    if half_ones:
        # Force the first half of the predictions to a single label.
        preds[:60] = 1

    # The keyword is ``preds`` (not ``pred``) to agree with the other ``iou``
    # tests in this file, which call the same function with ``preds=``.
    iou_val = iou(
        preds=preds,
        target=target,
        ignore_index=ignore_index,
        reduction=reduction,
    )
    assert torch.allclose(iou_val, expected, atol=1e-9)