def test_get_global_pred_target_pairs(self, batched):
    """Check ``CenterNetMixin.get_global_pred_target_pairs`` on a hand-built target.

    The heatmap has ``num_classes`` class channels followed by 4 extra
    channels. The last column of each target row is the class id.
    Presumably the rows are ``[x1, y1, x2, y2, class]`` (TODO confirm).
    The all ``-1`` row looks like padding. Classes 0 and 1 appear in the
    target; class 2 does not.

    Expected result, per class:
    - ``[..., 0]``: the maximum heatmap value over the two trailing
      (spatial) dims.
    - ``[..., 1]``: 1.0 if the class is present in the target, else 0.0.

    When ``batched`` is true, inputs and expectations are expanded along a
    new leading batch dimension of size 2.
    """
    torch.random.manual_seed(42)
    num_classes = 3
    # num_classes class channels + 4 extra channels (excluded from expected_pred)
    pred_heatmap = torch.rand(num_classes + 4, 10, 10)
    # last column is the class id: classes {0, 1} present, class 2 not present;
    # the all -1 row appears to be padding
    target = torch.tensor(
        [
            [0, 0, 2.1, 2.1, 0],
            [0, 0, 2, 2, 0],
            [3, 3, 6, 6, 0],
            [5, 5, 10, 9, 1],
            [1, 1, 4.9, 4.9, 1],
            [-1, -1, -1, -1, -1],
        ]
    ).float()
    if batched:
        pred_heatmap = pred_heatmap.unsqueeze(0).expand(2, -1, -1, -1)
        target = target.unsqueeze(0).expand(2, -1, -1)

    mixin = CenterNetMixin()
    result = mixin.get_global_pred_target_pairs(pred_heatmap, target)

    # expected pred is the max over the spatial dims of each class channel
    expected_pred = (
        pred_heatmap[..., :num_classes, :, :].max(dim=-1).values.max(dim=-1).values
    )
    # expected target is 1, 1, 0 for classes 0, 1 present / 2 not present
    expected_target = torch.tensor([1.0, 1.0, 0.0])
    if batched:
        expected_target = expected_target.unsqueeze(0).expand(2, -1)

    # torch.allclose broadcasts, so check shapes explicitly; otherwise a result
    # missing the batch dimension would still compare equal
    assert result[..., 0].shape == expected_pred.shape
    assert result[..., 1].shape == expected_target.shape
    assert torch.allclose(result[..., 0], expected_pred)
    assert torch.allclose(result[..., 1], expected_target)
def get_global_pred_target_pairs_on_batch(
    type_heatmap: Tensor,
    tar_bbox: Tensor,
    tar_type: Optional[Tensor],
    pad_value: float = -1,
    **kwargs,
) -> Tensor:
    """Build global prediction/target pairs for the type heatmap of a batch.

    First merges the box and type targets with
    ``CenterNetMixin.combine_box_target``. Then passes the merged target,
    together with ``type_heatmap``, to
    ``CenterNetMixin.get_global_pred_target_pairs``.

    Args:
        type_heatmap: predicted type heatmap, forwarded unchanged.
        tar_bbox: box targets given to ``combine_box_target``.
        tar_type: optional type targets given to ``combine_box_target``.
        pad_value: forwarded as ``pad_value`` to
            ``get_global_pred_target_pairs``. Presumably marks padded
            target rows; TODO confirm.
        **kwargs: extra keyword arguments forwarded to
            ``get_global_pred_target_pairs``.

    Returns:
        The tensor produced by ``CenterNetMixin.get_global_pred_target_pairs``.
    """
    merged_target = CenterNetMixin.combine_box_target(tar_bbox, tar_type)
    return CenterNetMixin.get_global_pred_target_pairs(
        type_heatmap,
        merged_target,
        pad_value=pad_value,
        **kwargs,
    )