示例#1
0
    def test_split_point_target(self):
        """Check that ``split_point_target`` undoes the channel concatenation.

        A target is built by concatenating a 2-channel heatmap tensor and a
        4-channel regression tensor along dim ``-3``; the split is expected
        to hand back tensors matching each of the two original parts.
        """
        torch.random.manual_seed(42)
        expected_heatmap = torch.rand(3, 2, 10, 10)
        expected_regression = torch.rand(3, 4, 10, 10)
        combined = torch.cat([expected_heatmap, expected_regression], dim=-3)

        actual_heatmap, actual_regression = CenterNetMixin().split_point_target(combined)

        # Heatmap is checked first, then regression.
        comparisons = (
            (actual_heatmap, expected_heatmap),
            (actual_regression, expected_regression),
        )
        for actual, expected in comparisons:
            assert torch.allclose(actual, expected)
示例#2
0
    def test_split_point_target_returns_views(self):
        """Check that the split outputs track in-place edits to the target.

        The target is scaled in place *after* the split. The outputs are then
        compared against fresh slices of the mutated target (all but the last
        4 channels for the heatmap, the last 4 channels for the regression),
        so a split that copied its data would be expected to fail here.
        """
        torch.random.manual_seed(42)
        heatmap_part = torch.rand(3, 2, 10, 10)
        regression_part = torch.rand(3, 4, 10, 10)
        combined = torch.cat([heatmap_part, regression_part], dim=-3)

        splitter = CenterNetMixin()
        heatmap_out, regression_out = splitter.split_point_target(combined)

        # Mutate the source tensor in place; the outputs were obtained before this.
        combined.mul_(10)

        expected_heatmap = combined[..., :-4, :, :]
        expected_regression = combined[..., -4:, :, :]
        assert torch.allclose(heatmap_out, expected_heatmap)
        assert torch.allclose(regression_out, expected_regression)