예제 #1
0
    def test_get_global_pred_target_pairs(self, batched):
        """Check ``CenterNetMixin.get_global_pred_target_pairs`` on a small example.

        The prediction heatmap has ``num_classes`` class channels followed by 4
        additional channels. Each target row holds four box values followed by
        a class id in the last column (presumably ``x1, y1, x2, y2, class``).
        The all ``-1`` row is presumably padding.

        The result is expected to pair, per class:
        * ``result[..., 0]`` with the max of that class's heatmap channel over
          both spatial dims;
        * ``result[..., 1]`` with 1.0 if the class occurs in ``target``, else 0.0.

        When ``batched`` is true, the heatmap, the target, and the expected
        target are all expanded to a batch of 2 identical samples.
        """
        torch.random.manual_seed(42)
        num_classes = 3

        # num_classes class channels followed by 4 extra (non-class) channels
        pred_heatmap = torch.rand(num_classes + 4, 10, 10)

        # classes {0, 1} present, class 2 not present
        target = torch.tensor(
            [
                [0, 0, 2.1, 2.1, 0],
                [0, 0, 2, 2, 0],
                [3, 3, 6, 6, 0],
                [5, 5, 10, 9, 1],
                [1, 1, 4.9, 4.9, 1],
                [-1, -1, -1, -1, -1],
            ]
        ).float()

        if batched:
            # Out-of-place unsqueeze: the names are rebound anyway, so there is
            # no reason to mutate the original tensors.
            pred_heatmap = pred_heatmap.unsqueeze(0).expand(2, -1, -1, -1)
            target = target.unsqueeze(0).expand(2, -1, -1)

        mixin = CenterNetMixin()
        result = mixin.get_global_pred_target_pairs(pred_heatmap, target)

        # expected pred is the max over the spatial dims of each class channel;
        # slicing by num_classes drops the 4 trailing non-class channels
        expected_pred = (
            pred_heatmap[..., :num_classes, :, :]
            .max(dim=-1).values
            .max(dim=-1).values
        )

        # expected target is 1, 1, 0 for classes 0, 1 present / 2 not present
        expected_target = torch.tensor([1.0, 1.0, 0.0])
        if batched:
            expected_target = expected_target.unsqueeze(0).expand(2, -1)

        assert torch.allclose(result[..., 0], expected_pred)
        assert torch.allclose(result[..., 1], expected_target)
예제 #2
0
 def get_global_pred_target_pairs_on_batch(type_heatmap: Tensor,
                                           tar_bbox: Tensor,
                                           tar_type: Optional[Tensor],
                                           pad_value: float = -1,
                                           **kwargs) -> Tensor:
     """Build global (prediction, target) pairs for the type heatmap of a batch.

     First merges ``tar_bbox`` and ``tar_type`` into one target tensor with
     ``CenterNetMixin.combine_box_target``. Then delegates to
     ``CenterNetMixin.get_global_pred_target_pairs``.

     Args:
         type_heatmap: Predicted type heatmap, passed through unchanged.
         tar_bbox: Target boxes, forwarded to ``combine_box_target``.
         tar_type: Optional target types, forwarded to ``combine_box_target``.
         pad_value: Padding marker forwarded as ``pad_value``.
         **kwargs: Extra keyword arguments forwarded to
             ``get_global_pred_target_pairs``.

     Returns:
         Whatever ``get_global_pred_target_pairs`` returns for these inputs.
     """
     merged_target = CenterNetMixin.combine_box_target(tar_bbox, tar_type)
     return CenterNetMixin.get_global_pred_target_pairs(
         type_heatmap,
         merged_target,
         pad_value=pad_value,
         **kwargs,
     )