예제 #1
0
    def test_get_perspective_transform3d(self, batch_size, device, dtype):
        """Round-trip check: the estimated transform must map ``points_src`` onto ``points_dst``.

        A 3D projective transform has 15 degrees of freedom, while 8 point
        correspondences impose 24 scalar constraints. Pairing 8 random source
        points with 8 independently perturbed targets is therefore, in general,
        inconsistent with *any* single transform, and a round-trip assertion on
        such data cannot be expected to pass. Instead, the targets here are the
        image of the sources under a known random per-batch anisotropic scale
        plus translation. That map is affine and hence projective, so all 8
        correspondences are exactly consistent with one transform.
        """
        # random source points, shape (batch_size, 8, 3)
        points_src = torch.rand(batch_size, 8, 3, device=device, dtype=dtype)

        # ground-truth map shared by all 8 points of a sample:
        # per-axis scale in [0.5, 1.5) and per-axis shift in [0, 1)
        scale = 0.5 + torch.rand(batch_size, 1, 3, device=device, dtype=dtype)
        shift = torch.rand(batch_size, 1, 3, device=device, dtype=dtype)
        points_dst = points_src * scale + shift

        # compute transform from source to target
        dst_homo_src = kornia.get_perspective_transform3d(points_src, points_dst)

        # every correspondence must be reproduced, not only the ones the
        # solver happens to use internally
        assert_allclose(
            kornia.transform_points(dst_homo_src, points_src), points_dst, rtol=1e-4, atol=1e-4)

        # compute gradient check
        points_src = utils.tensor_to_gradcheck_var(points_src)  # to var
        points_dst = utils.tensor_to_gradcheck_var(points_dst)  # to var
        assert gradcheck(
            kornia.get_perspective_transform3d, (points_src, points_dst), raise_exception=True)
예제 #2
0
    def test_get_perspective_transform3d_2(self, batch_size, device, dtype):
        """Regression test: map one random 3D box onto another and compare to stored matrices.

        Both boxes are drawn after ``torch.manual_seed(0)``. The reference
        matrices below are only valid for that exact sequence of RNG draws,
        so the draw order must not change. Reference values are recorded for
        ``batch_size`` 1 and 2 only.
        """
        torch.manual_seed(0)

        def random_box():
            # The first three arguments are drawn independently per sample. Each
            # of the last three draws a single value and repeats it across the
            # batch. Presumably these are box origin and box size respectively
            # (kornia.bbox_generator3d argument order) -- confirm.
            # The draw order below is the one the reference values depend on.
            return kornia.bbox_generator3d(
                torch.randint_like(torch.ones(batch_size), 0, 50, dtype=dtype),
                torch.randint_like(torch.ones(batch_size), 0, 50, dtype=dtype),
                torch.randint_like(torch.ones(batch_size), 0, 50, dtype=dtype),
                torch.randint(0, 50, (1, ), dtype=dtype).repeat(batch_size),
                torch.randint(0, 50, (1, ), dtype=dtype).repeat(batch_size),
                torch.randint(0, 50, (1, ), dtype=dtype).repeat(batch_size),
            ).to(device=device, dtype=dtype)

        # src must be drawn before dst to reproduce the recorded values
        src = random_box()
        dst = random_box()

        # Reference transforms for the boxes drawn above. Each matrix is a
        # diagonal scale plus translation. For batch_size == 2, both samples
        # share the same scales and differ only in translation.
        expected_by_batch = {
            1: [
                [
                    [3.3000, 0.0000, 0.0000, -118.2000],
                    [0.0000, 0.0769, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.5517, 28.7930],
                    [0.0000, 0.0000, 0.0000, 1.0000],
                ],
            ],
            2: [
                [
                    [0.9630, 0.0000, 0.0000, -9.3702],
                    [0.0000, 2.0000, 0.0000, -49.9999],
                    [0.0000, 0.0000, 0.3830, 44.0213],
                    [0.0000, 0.0000, 0.0000, 1.0000],
                ],
                [
                    [0.9630, 0.0000, 0.0000, -36.5555],
                    [0.0000, 2.0000, 0.0000, -14.0000],
                    [0.0000, 0.0000, 0.3830, 16.8940],
                    [0.0000, 0.0000, 0.0000, 1.0000],
                ],
            ],
        }
        # fail with a clear message instead of an UnboundLocalError on `expected`
        assert batch_size in expected_by_batch, (
            f"no reference values recorded for batch_size={batch_size}")

        out = kornia.get_perspective_transform3d(src, dst)
        expected = torch.tensor(expected_by_batch[batch_size], device=device, dtype=dtype)
        assert_close(out, expected, rtol=1e-4, atol=1e-4)

        # compute gradient check
        points_src = utils.tensor_to_gradcheck_var(src)  # to var
        points_dst = utils.tensor_to_gradcheck_var(dst)  # to var
        assert gradcheck(kornia.get_perspective_transform3d,
                         (points_src, points_dst),
                         raise_exception=True)