def test_intersectionAndUnion_3classes():
    """
    Check per-class intersection, union and target counts for 3 classes.

    Pixel-wise (pred, target) pairs are (2,2), (0,0), (1,1) and (0,1), so
    each class is matched exactly once: intersection is [1,1,1].

    A union cannot be computed without knowing where the two label maps
    intersect. Taking the per-pixel set unions
        {2} union {2} -> {2}
        {0} union {0} -> {0}
        {1} union {1} -> {1}
        {0} union {1} -> {0,1}
    and counting how often each class appears gives [2,2,1].
    """
    prediction = np.array([[2, 0], [1, 0]])
    ground_truth = np.array([[2, 0], [1, 1]])
    n_classes = 3

    # Each returned array is expected to hold one count per class.
    inter, union, tgt = intersectionAndUnion(
        prediction, ground_truth, K=n_classes, ignore_index=255
    )

    for per_class_area in (inter, union, tgt):
        assert per_class_area.shape == (3, )

    assert np.allclose(inter, np.array([1, 1, 1]))
    assert np.allclose(tgt, np.array([1, 2, 1]))
    assert np.allclose(union, np.array([2, 2, 1]))
def test_intersectionAndUnion_2classes():
    """
    Check per-class intersection, union and target counts for 2 classes.

    Pixel-wise (pred, target) pairs are (0,0), (0,0), (1,1) and (0,1):
    class 0 is matched twice and class 1 once. A union cannot be computed
    without knowing where the two label maps intersect; here the expected
    per-class union is [3,2].
    """
    prediction = np.array([[0, 0], [1, 0]])
    ground_truth = np.array([[0, 0], [1, 1]])
    n_classes = 2

    # Each returned array is expected to hold one count per class.
    inter, union, tgt = intersectionAndUnion(
        prediction, ground_truth, K=n_classes, ignore_index=255
    )

    assert np.allclose(inter, np.array([2, 1]))
    assert np.allclose(tgt, np.array([2, 2]))
    assert np.allclose(union, np.array([3, 2]))
def test_test_intersectionAndUnion_ignore_label():
    """
    Check that pixels labelled with the ignore index are left out.

    Two of the four target pixels carry the ignore value 255. The expected
    counts only reflect the remaining pairs (0,0) and (0,1): the
    predictions at the ignored positions contribute to neither the
    intersection nor the union, and the 255 labels are not counted in the
    target area.
    """
    prediction = np.array([[1, 0], [1, 0]])
    ground_truth = np.array([[255, 0], [255, 1]])
    n_classes = 2

    # Each returned array is expected to hold one count per class.
    inter, union, tgt = intersectionAndUnion(
        prediction, ground_truth, K=n_classes, ignore_index=255
    )

    assert np.allclose(inter, np.array([1, 0]))
    assert np.allclose(tgt, np.array([1, 1]))
    assert np.allclose(union, np.array([2, 1]))
def update_metrics_cpu(self, pred, target, num_classes) -> None:
    """
    Feed one prediction/target pair into the running segmentation meters.

    Args:
        - pred: predicted labels, forwarded to `intersectionAndUnion`.
        - target: ground-truth labels, forwarded to `intersectionAndUnion`.
        - num_classes: number of classes, forwarded to
          `intersectionAndUnion` as its third positional argument.

    Returns:
        - None. Updates `intersection_meter`, `union_meter` and
          `target_meter`, and sets `self.accuracy` and `self.intersection`.
    """
    area_inter, area_union, area_target = intersectionAndUnion(
        pred, target, num_classes
    )

    self.intersection_meter.update(area_inter)
    self.union_meter.update(area_union)
    self.target_meter.update(area_target)

    # NOTE(review): accuracy is derived from the meters' `val` attributes;
    # whether `val` is the latest update or an accumulated value depends on
    # the meter class -- confirm.
    matched_total = sum(self.intersection_meter.val)
    labelled_total = sum(self.target_meter.val)
    # The small epsilon keeps the division defined when no pixel is labelled.
    self.accuracy = matched_total / (labelled_total + 1e-10)

    self.intersection = 0.0