예제 #1
0
    def test_filter_heatmap_classes(self, return_inverse, keep_classes, with_regression):
        """Check that ``filter_heatmap_classes`` keeps exactly the requested class channels.

        Args:
            return_inverse: When true, ``keep_classes`` lists the classes to *drop*
                and every other class is expected to survive.
            keep_classes: Class indices forwarded to ``filter_heatmap_classes``.
            with_regression: When true, the heatmap carries 4 extra trailing
                channels that must be passed through untouched.
        """
        torch.random.manual_seed(42)
        mixin = CenterNetMixin()

        possible_classes = {0, 1, 2}
        if return_inverse:
            drop_classes = set(keep_classes)
            # Sort explicitly: set iteration order is not guaranteed, and the
            # expected tensor below is built by indexing in this order.
            real_keep_classes = sorted(possible_classes - drop_classes)
        else:
            drop_classes = possible_classes - set(keep_classes)
            real_keep_classes = keep_classes

        num_channels = len(possible_classes) + (4 if with_regression else 0)
        heatmap = torch.rand(2, num_channels, 32, 32)

        result = mixin.filter_heatmap_classes(
            heatmap, keep_classes=keep_classes, return_inverse=return_inverse, with_regression=with_regression
        )

        # only the dropped class channels disappear; batch and spatial dims are unchanged
        assert result.shape[-3] == heatmap.shape[-3] - len(drop_classes)
        assert result.shape[-2:] == heatmap.shape[-2:]
        assert result.shape[0] == heatmap.shape[0]

        channel_index = tuple(real_keep_classes)
        if with_regression:
            # the 4 trailing channels are always retained, after the kept classes
            channel_index += (-4, -3, -2, -1)
        expected = heatmap[..., channel_index, :, :]

        assert torch.allclose(result, expected)
예제 #2
0
    def test_get_global_pred_target_pairs(self, batched):
        """Check that ``get_global_pred_target_pairs`` pairs per-class max scores with presence flags.

        Column 0 of the result should be the spatial max of each class channel.
        Column 1 should be 1.0 for classes present in the target and 0.0 otherwise.

        Args:
            batched: When true, prediction and target are expanded to a batch of 2.
        """
        torch.random.manual_seed(42)
        num_classes = 3

        # num_classes class channels followed by 4 further (regression) channels
        pred_heatmap = torch.rand(num_classes + 4, 10, 10)

        # The last column holds the class id: classes {0, 1} are present and
        # class 2 is not. The all -1 row is presumably padding (-1 is the
        # default pad_value used elsewhere).
        target = torch.tensor(
            [
                [0, 0, 2.1, 2.1, 0],
                [0, 0, 2, 2, 0],
                [3, 3, 6, 6, 0],
                [5, 5, 10, 9, 1],
                [1, 1, 4.9, 4.9, 1],
                [-1, -1, -1, -1, -1],
            ]
        ).float()

        if batched:
            # out-of-place unsqueeze: no need to mutate the freshly built tensors
            pred_heatmap = pred_heatmap.unsqueeze(0).expand(2, -1, -1, -1)
            target = target.unsqueeze(0).expand(2, -1, -1)

        mixin = CenterNetMixin()
        result = mixin.get_global_pred_target_pairs(pred_heatmap, target)

        # expected pred is the per-class spatial max over the class channels only
        expected_pred = pred_heatmap[..., :-4, :, :].max(dim=-1).values.max(dim=-1).values

        # expected target is 1, 1, 0 for classes 0, 1 present / 2 not present
        expected_target = torch.tensor([1.0, 1.0, 0.0])
        if batched:
            expected_target = expected_target.unsqueeze(0).expand(2, -1)

        assert torch.allclose(result[..., 0], expected_pred)
        assert torch.allclose(result[..., 1], expected_target)
예제 #3
0
    def test_split_regression(self):
        """``split_regression`` yields channels 0-1 as offset and channels 2+ as size."""
        torch.random.manual_seed(42)
        regression = torch.rand(3, 4, 10, 10)

        offset, size = CenterNetMixin().split_regression(regression)

        expected_offset = regression[..., :2, :, :]
        expected_size = regression[..., 2:, :, :]
        assert torch.allclose(offset, expected_offset)
        assert torch.allclose(size, expected_size)
예제 #4
0
    def test_combine_regression(self):
        """``combine_regression`` reassembles offset and size into the original tensor."""
        torch.random.manual_seed(42)
        full_regression = torch.rand(3, 4, 10, 10)

        mixin = CenterNetMixin()
        offset_part, size_part = full_regression[..., :2, :, :], full_regression[..., 2:, :, :]

        recombined = mixin.combine_regression(offset_part, size_part)

        assert torch.allclose(recombined, full_regression)
예제 #5
0
    def test_append_heatmap_label(self, label_size):
        """``append_heatmap_label`` inserts the new label channels between the
        first 2 channels and the 4 trailing channels of the old label."""
        torch.random.manual_seed(42)
        old_label = torch.rand(3, 6, 10, 10)
        new_label = torch.rand(3, label_size, 10, 10)

        final_label = CenterNetMixin().append_heatmap_label(old_label, new_label)

        leading = final_label[..., :2, :, :]
        trailing = final_label[..., -4:, :, :]
        inserted = final_label[..., 2:-4, :, :]

        assert torch.allclose(leading, old_label[..., :2, :, :])
        assert torch.allclose(trailing, old_label[..., -4:, :, :])
        assert torch.allclose(inserted, new_label)
예제 #6
0
    def test_combine_point_target(self):
        """``combine_point_target`` concatenates heatmap and regression targets along dim -3."""
        torch.random.manual_seed(42)
        heatmap_part = torch.rand(3, 2, 10, 10)
        regression_part = torch.rand(3, 4, 10, 10)
        expected = torch.cat((heatmap_part, regression_part), dim=-3)

        combined = CenterNetMixin().combine_point_target(heatmap_part, regression_part)

        assert torch.allclose(combined, expected)
예제 #7
0
 def get_global_pred_target_pairs_on_batch(type_heatmap: Tensor,
                                           tar_bbox: Tensor,
                                           tar_type: Optional[Tensor],
                                           pad_value: float = -1,
                                           **kwargs) -> Tensor:
     """Build global prediction/target pairs for the type heatmap.

     The box and type targets are first merged with
     ``CenterNetMixin.combine_box_target``. The merged target, ``type_heatmap``,
     ``pad_value`` and any extra keyword arguments are then forwarded to
     ``CenterNetMixin.get_global_pred_target_pairs``.

     Returns:
         The value returned by ``CenterNetMixin.get_global_pred_target_pairs``.
     """
     merged_target = CenterNetMixin.combine_box_target(tar_bbox, tar_type)
     return CenterNetMixin.get_global_pred_target_pairs(
         type_heatmap,
         merged_target,
         pad_value=pad_value,
         **kwargs,
     )
예제 #8
0
    def test_heatmap_max_score(self):
        """``heatmap_max_score`` returns the spatial max of each class channel.

        The heatmap holds ``num_classes`` class channels followed by 4 further
        channels, which must be ignored.
        """
        torch.random.manual_seed(42)
        num_classes = 3
        heatmap = torch.rand(3, num_classes + 4, 10, 10)
        expected = heatmap[..., :num_classes, :, :].max(dim=-1).values.max(dim=-1).values

        mixin = CenterNetMixin()
        actual = mixin.heatmap_max_score(heatmap)

        # torch.allclose compares broadcastable tensors, so check the exact
        # (batch, num_classes) shape rather than only the rank.
        assert actual.ndim == 2
        assert actual.shape == expected.shape
        assert torch.allclose(actual, expected)
예제 #9
0
    def test_split_point_target(self):
        """``split_point_target`` separates the heatmap channels from the 4 trailing regression channels."""
        torch.random.manual_seed(42)
        heatmap_target = torch.rand(3, 2, 10, 10)
        regression_target = torch.rand(3, 4, 10, 10)
        combined = torch.cat((heatmap_target, regression_target), dim=-3)

        heatmap, regression = CenterNetMixin().split_point_target(combined)

        for actual, expected in ((heatmap, heatmap_target), (regression, regression_target)):
            assert torch.allclose(actual, expected)
예제 #10
0
    def test_split_point_target_returns_views(self):
        """``split_point_target`` must return views into its input rather than copies.

        After the combined target is scaled in place, both returned pieces must
        reflect the new values.
        """
        torch.random.manual_seed(42)
        heatmap_target = torch.rand(3, 2, 10, 10)
        regression_target = torch.rand(3, 4, 10, 10)
        combined = torch.cat((heatmap_target, regression_target), dim=-3)

        heatmap_view, regression_view = CenterNetMixin().split_point_target(combined)

        # in-place update: only views that share storage with `combined` observe it
        combined.mul_(10)
        assert torch.allclose(heatmap_view, combined[..., :-4, :, :])
        assert torch.allclose(regression_view, combined[..., -4:, :, :])
예제 #11
0
    def test_visualize_heatmap_no_background(self):
        """``visualize_heatmap`` yields one uint8 image with values in [0, 255] per heatmap channel."""
        torch.random.manual_seed(42)
        heatmap = torch.rand(3, 2, 10, 10)

        images = CenterNetMixin().visualize_heatmap(heatmap)

        assert len(images) == heatmap.shape[-3]
        for image in images:
            assert image.min() >= 0
            assert image.max() <= 255
            assert image.dtype == torch.uint8
예제 #12
0
 def get_local_pred_target_pairs_on_batch(
         type_heatmap: Tensor,
         tar_bbox: Tensor,
         tar_type: Tensor,
         pad_value: float = -1,
         **kwargs) -> Union[Tensor, List[Tensor]]:
     """Build local prediction/target pairs for the type heatmap.

     The box and type targets are merged with
     ``CenterNetMixin.combine_box_target``. The merged target is then passed,
     together with ``type_heatmap``, ``upsample``, ``pad_value`` and any extra
     keyword arguments, to ``CenterNetMixin.get_pred_target_pairs``.

     Returns:
         The value returned by ``CenterNetMixin.get_pred_target_pairs``.
     """
     merged_target = CenterNetMixin.combine_box_target(tar_bbox, tar_type)
     # NOTE(review): `upsample` is neither a parameter nor a local here, so it
     # must resolve to a module-level name defined elsewhere. If no such global
     # exists, this raises NameError. Confirm against the rest of the module.
     return CenterNetMixin.get_pred_target_pairs(
         type_heatmap,
         merged_target,
         upsample,
         pad_value=pad_value,
         **kwargs,
     )