def test_filter_heatmap_classes(self, return_inverse, keep_classes, with_regression):
    """Check that ``filter_heatmap_classes`` returns the expected channels.

    With ``return_inverse`` set, ``keep_classes`` lists the class channels
    expected to be removed; otherwise it lists the ones expected to remain.
    With ``with_regression`` set, the input carries four extra trailing
    channels that are expected to be present in the output as well.
    """
    torch.random.manual_seed(42)
    mixin = CenterNetMixin()

    all_classes = {0, 1, 2}
    requested = set(keep_classes)

    # Decide which class channels should disappear and which should survive.
    if return_inverse:
        dropped, retained = requested, all_classes - requested
    else:
        dropped, retained = all_classes - requested, keep_classes

    num_channels = len(all_classes) + 4 if with_regression else len(all_classes)
    heatmap = torch.rand(2, num_channels, 32, 32)

    result = mixin.filter_heatmap_classes(
        heatmap,
        keep_classes=keep_classes,
        return_inverse=return_inverse,
        with_regression=with_regression,
    )

    # Only the channel dimension may shrink; batch and spatial sizes are kept.
    assert result.shape[-3] == heatmap.shape[-3] - len(dropped)
    assert result.shape[-2:] == heatmap.shape[-2:]
    assert result.shape[0] == heatmap.shape[0]

    # Build the channel selection once; the last four channels are appended
    # when regression channels are part of the input.
    channel_index = tuple(retained)
    if with_regression:
        channel_index = channel_index + (-4, -3, -2, -1)
    expected = heatmap[..., channel_index, :, :]
    assert torch.allclose(result, expected)
def test_get_global_pred_target_pairs(self, batched):
    """Check ``get_global_pred_target_pairs`` on a single or batched example.

    The prediction half of each pair is compared against the per-class
    maximum of the heatmap (excluding the last four channels); the target
    half is compared against a per-class presence indicator.
    """
    torch.random.manual_seed(42)
    num_classes = 3
    pred_heatmap = torch.rand(num_classes + 4, 10, 10)

    # The last column carries ids 0 and 1 only, so class 2 never appears.
    # The final all -1 row is presumably padding - confirm against the mixin.
    target_rows = [
        [0, 0, 2.1, 2.1, 0],
        [0, 0, 2, 2, 0],
        [3, 3, 6, 6, 0],
        [5, 5, 10, 9, 1],
        [1, 1, 4.9, 4.9, 1],
        [-1, -1, -1, -1, -1],
    ]
    target = torch.tensor(target_rows).float()

    if batched:
        # Repeat the same example twice along a new leading batch dimension.
        pred_heatmap = pred_heatmap.unsqueeze(0).expand(2, -1, -1, -1)
        target = target.unsqueeze(0).expand(2, -1, -1)

    mixin = CenterNetMixin()
    result = mixin.get_global_pred_target_pairs(pred_heatmap, target)

    # Expected prediction: maximum over both spatial dimensions per class.
    class_scores = pred_heatmap[..., :-4, :, :]
    row_max = class_scores.max(dim=-1).values
    expected_pred = row_max.max(dim=-1).values

    # Expected target: 1 for classes 0 and 1 (present), 0 for class 2 (absent).
    expected_target = torch.tensor([1.0, 1.0, 0.0])
    if batched:
        expected_target = expected_target.unsqueeze(0).expand(2, -1)

    assert torch.allclose(result[..., 0], expected_pred)
    assert torch.allclose(result[..., 1], expected_target)
def test_split_regression(self):
    """Check that ``split_regression`` separates the first two channels from the rest."""
    torch.random.manual_seed(42)
    regression = torch.rand(3, 4, 10, 10)

    offset, size = CenterNetMixin().split_regression(regression)

    # Channels 0-1 are expected as the offset, channels 2+ as the size.
    expected_offset = regression[..., :2, :, :]
    expected_size = regression[..., 2:, :, :]
    assert torch.allclose(offset, expected_offset)
    assert torch.allclose(size, expected_size)
def test_combine_regression(self):
    """Check that ``combine_regression`` reassembles a split regression tensor."""
    torch.random.manual_seed(42)
    regression = torch.rand(3, 4, 10, 10)
    mixin = CenterNetMixin()

    # Slice the tensor into its first two and remaining channels, then
    # expect the combination to reproduce the original tensor.
    offset, size = regression[..., :2, :, :], regression[..., 2:, :, :]
    combined = mixin.combine_regression(offset, size)

    assert torch.allclose(combined, regression)
def test_append_heatmap_label(self, label_size):
    """Check where ``append_heatmap_label`` places the new channels.

    The first two and last four channels of the old label are expected to be
    unchanged, with the ``label_size`` new channels sitting between them.
    """
    torch.random.manual_seed(42)
    old_label = torch.rand(3, 6, 10, 10)
    new_label = torch.rand(3, label_size, 10, 10)

    final_label = CenterNetMixin().append_heatmap_label(old_label, new_label)

    leading = final_label[..., :2, :, :]
    inserted = final_label[..., 2:-4, :, :]
    trailing = final_label[..., -4:, :, :]

    assert torch.allclose(leading, old_label[..., :2, :, :])
    assert torch.allclose(trailing, old_label[..., -4:, :, :])
    assert torch.allclose(inserted, new_label)
def test_combine_point_target(self):
    """Check that ``combine_point_target`` matches concatenation along dim -3."""
    torch.random.manual_seed(42)
    heatmap_target = torch.rand(3, 2, 10, 10)
    regression_target = torch.rand(3, 4, 10, 10)

    # Reference result: heatmap channels followed by regression channels.
    expected = torch.cat((heatmap_target, regression_target), dim=-3)

    mixin = CenterNetMixin()
    combined = mixin.combine_point_target(heatmap_target, regression_target)

    assert torch.allclose(combined, expected)
def get_global_pred_target_pairs_on_batch(
    type_heatmap: Tensor,
    tar_bbox: Tensor,
    tar_type: Optional[Tensor],
    pad_value: float = -1,
    **kwargs
) -> Tensor:
    """Compute global prediction/target pairs for types.

    Args:
        type_heatmap: Heatmap passed as the prediction argument to
            ``CenterNetMixin.get_global_pred_target_pairs``.
        tar_bbox: Box target, merged with ``tar_type`` via
            ``CenterNetMixin.combine_box_target``.
        tar_type: Type target merged with ``tar_bbox``; may be ``None``
            according to the annotation.
        pad_value: Forwarded unchanged as ``pad_value``.
        **kwargs: Forwarded unchanged to
            ``CenterNetMixin.get_global_pred_target_pairs``.

    Returns:
        Whatever ``CenterNetMixin.get_global_pred_target_pairs`` returns.
    """
    combined_target = CenterNetMixin.combine_box_target(tar_bbox, tar_type)
    return CenterNetMixin.get_global_pred_target_pairs(
        type_heatmap,
        combined_target,
        pad_value=pad_value,
        **kwargs
    )
def test_heatmap_max_score(self):
    """Check that ``heatmap_max_score`` returns the per-class spatial maximum.

    The last four channels of the input are not part of the expected result.
    """
    torch.random.manual_seed(42)
    num_classes = 3
    heatmap = torch.rand(3, num_classes + 4, 10, 10)

    # Reduce the class channels over both spatial dimensions, one at a time.
    class_channels = heatmap[..., :num_classes, :, :]
    row_max = class_channels.max(dim=-1).values
    expected = row_max.max(dim=-1).values

    actual = CenterNetMixin().heatmap_max_score(heatmap)

    assert actual.ndim == 2
    assert torch.allclose(actual, expected)
def test_split_point_target(self):
    """Check that ``split_point_target`` undoes a concatenation along dim -3."""
    torch.random.manual_seed(42)
    heatmap_target = torch.rand(3, 2, 10, 10)
    regression_target = torch.rand(3, 4, 10, 10)
    combined = torch.cat((heatmap_target, regression_target), dim=-3)

    heatmap, regression = CenterNetMixin().split_point_target(combined)

    # Each half is expected to equal the tensor it was built from.
    assert torch.allclose(heatmap, heatmap_target)
    assert torch.allclose(regression, regression_target)
def test_split_point_target_returns_views(self):
    """Check that ``split_point_target`` outputs follow in-place input changes."""
    torch.random.manual_seed(42)
    heatmap_target = torch.rand(3, 2, 10, 10)
    regression_target = torch.rand(3, 4, 10, 10)
    combined = torch.cat((heatmap_target, regression_target), dim=-3)

    mixin = CenterNetMixin()
    heatmap, regression = mixin.split_point_target(combined)

    # Scale the input in place AFTER splitting; the outputs are expected to
    # reflect the change, which only holds if they share the input's storage.
    combined.mul_(10)

    scaled_heatmap = combined[..., :-4, :, :]
    scaled_regression = combined[..., -4:, :, :]
    assert torch.allclose(heatmap, scaled_heatmap)
    assert torch.allclose(regression, scaled_regression)
def test_visualize_heatmap_no_background(self):
    """Check the length, value range and dtype of ``visualize_heatmap`` output.

    One result entry is expected per index of the heatmap's dim -3, each a
    ``torch.uint8`` tensor with values within [0, 255].
    """
    torch.random.manual_seed(42)
    heatmap = torch.rand(3, 2, 10, 10)

    rendered = CenterNetMixin().visualize_heatmap(heatmap)

    assert len(rendered) == heatmap.shape[-3]
    for entry in rendered:
        assert entry.min() >= 0
        assert entry.max() <= 255
        assert entry.dtype == torch.uint8
def get_local_pred_target_pairs_on_batch(
    type_heatmap: Tensor,
    tar_bbox: Tensor,
    tar_type: Tensor,
    pad_value: float = -1,
    **kwargs
) -> Union[Tensor, List[Tensor]]:
    """Compute prediction/target pairs for types via ``get_pred_target_pairs``.

    Args:
        type_heatmap: Heatmap passed as the prediction argument to
            ``CenterNetMixin.get_pred_target_pairs``.
        tar_bbox: Box target, merged with ``tar_type`` via
            ``CenterNetMixin.combine_box_target``.
        tar_type: Type target merged with ``tar_bbox``.
        pad_value: Forwarded unchanged as ``pad_value``.
        **kwargs: Forwarded unchanged to
            ``CenterNetMixin.get_pred_target_pairs``.

    Returns:
        Whatever ``CenterNetMixin.get_pred_target_pairs`` returns.
    """
    combined_target = CenterNetMixin.combine_box_target(tar_bbox, tar_type)
    # NOTE(review): ``upsample`` is neither a parameter nor a local of this
    # function, so it must resolve from an enclosing or module scope. If no
    # such name exists this call raises NameError - confirm and, if needed,
    # promote it to an explicit parameter.
    return CenterNetMixin.get_pred_target_pairs(
        type_heatmap,
        combined_target,
        upsample,
        pad_value=pad_value,
        **kwargs
    )