def test_split_point_target(self):
    """split_point_target separates heatmap and regression channels.

    The target is built by concatenating a 2-channel heatmap and a
    4-channel regression tensor along dim -3. Splitting it must recover
    both parts with their original shapes and values.
    """
    torch.random.manual_seed(42)
    heatmap_target = torch.rand(3, 2, 10, 10)
    regression_target = torch.rand(3, 4, 10, 10)
    target = torch.cat([heatmap_target, regression_target], dim=-3)

    mixin = CenterNetMixin()
    heatmap, regression = mixin.split_point_target(target)

    # torch.allclose broadcasts its operands, so pin the shapes explicitly;
    # otherwise a wrongly-sized but broadcast-compatible split could pass.
    assert heatmap.shape == heatmap_target.shape
    assert regression.shape == regression_target.shape
    assert torch.allclose(heatmap, heatmap_target)
    assert torch.allclose(regression, regression_target)
def test_split_point_target_returns_views(self):
    """split_point_target returns views into the input, not copies.

    After splitting, ``target`` is mutated in place. If the outputs share
    storage with ``target``, they reflect the mutation and still match the
    corresponding slices of ``target``. Copies would keep the old values
    and fail the comparison.
    """
    torch.random.manual_seed(42)
    heatmap_target = torch.rand(3, 2, 10, 10)
    regression_target = torch.rand(3, 4, 10, 10)
    target = torch.cat([heatmap_target, regression_target], dim=-3)

    mixin = CenterNetMixin()
    heatmap, regression = mixin.split_point_target(target)

    # torch.allclose broadcasts, so check shapes before relying on it.
    assert heatmap.shape == heatmap_target.shape
    assert regression.shape == regression_target.shape

    # Scaling by 10 changes every (non-zero) value, so the comparisons below
    # only succeed if heatmap/regression alias target's memory.
    target.mul_(10)
    assert torch.allclose(heatmap, target[..., :-4, :, :])
    assert torch.allclose(regression, target[..., -4:, :, :])