Esempio n. 1
0
    def test_crop(self, device, dtype):
        """Crop and resize a single 3D box out of one 4x4x4 volume.

        ``inp`` is ``arange(64)`` viewed as (1, 1, 4, 4, 4), so the voxel at
        index (d, h, w) holds the value ``16 * d + 4 * h + w``.

        The expected values below are fractional because they were produced
        with ``align_corners=False``. The flag is therefore passed explicitly,
        so the test does not silently break if the library default changes.
        """
        inp = torch.arange(0., 64., device=device,
                           dtype=dtype).view(1, 1, 4, 4, 4)

        depth, height, width = 2, 2, 2
        expected = torch.tensor([[[[[25.1667, 27.1667], [30.5000, 32.5000]],
                                   [[46.5000, 48.5000], [51.8333, 53.8333]]]]],
                                device=device,
                                dtype=dtype)

        # Eight box corners, one per row. The three columns span 0..3, 0..2
        # and 1..3 respectively (presumably x, y, z per kornia's 3D box
        # convention -- confirm).
        boxes = torch.tensor([[
            [0, 0, 1],
            [3, 0, 1],
            [3, 2, 1],
            [0, 2, 1],
            [0, 0, 3],
            [3, 0, 3],
            [3, 2, 3],
            [0, 2, 3],
        ]],
                             device=device,
                             dtype=dtype)  # 1x8x3

        patches = kornia.crop_and_resize3d(
            inp, boxes, (depth, height, width), align_corners=False)
        assert_allclose(patches, expected)
Esempio n. 2
0
    def test_crop_batch(self, device, dtype):
        """Crop and resize one box per sample from a batch of two volumes.

        Sample 0 is ``arange(0, 64)`` and sample 1 is ``arange(0, 128, 2)``,
        each viewed as a 4x4x4 volume. With ``align_corners=True`` the expected
        patches are exactly the voxel values found at the box corners.
        """
        opts = {"device": device, "dtype": dtype}

        first_volume = torch.arange(0.0, 64.0, **opts).view(1, 1, 4, 4, 4)
        second_volume = torch.arange(0.0, 128.0, step=2, **opts).view(1, 1, 4, 4, 4)
        inp = torch.cat([first_volume, second_volume], dim=0)

        out_size = (2, 2, 2)  # (depth, height, width)

        first_patch = [[[16.0, 19.0], [24.0, 27.0]], [[48.0, 51.0], [56.0, 59.0]]]
        second_patch = [[[0.0, 6.0], [16.0, 22.0]], [[64.0, 70.0], [80.0, 86.0]]]
        expected = torch.tensor([[first_patch], [second_patch]], **opts)

        # The second sample's box is the first box shifted down by one
        # along the last coordinate column.
        first_box = [
            [0, 0, 1], [3, 0, 1], [3, 2, 1], [0, 2, 1],
            [0, 0, 3], [3, 0, 3], [3, 2, 3], [0, 2, 3],
        ]
        second_box = [[c0, c1, c2 - 1] for c0, c1, c2 in first_box]
        boxes = torch.tensor([first_box, second_box], **opts)  # 2x8x3

        patches = kornia.crop_and_resize3d(inp, boxes, out_size, align_corners=True)
        assert_allclose(patches, expected)