Exemplo n.º 1
0
def init_volume_boundary_pointcloud(
    batch_size: int,
    volume_size: Tuple[int, int, int],
    n_points: int,
    interp_mode: str,
    device: str,
    require_grad: bool = False,
):
    """
    Build a point cloud that closely follows the boundary of a volume of the
    given size, together with an all-zero `Volumes` buffer of that size.

    Args:
        batch_size: Number of point clouds / volumes in the batch.
        volume_size: Spatial size of the volume grid; it is used verbatim as
            the trailing dimensions of the zero feature/density buffers.
        n_points: Number of points requested from `init_cube_point_cloud`.
        interp_mode: When equal to "trilinear", the cube points are shrunk
            around the cube center so they land on the boundary of the
            first/last two voxels along each spatial dimension.
        device: Device forwarded to `init_cube_point_cloud`.
        require_grad: Value assigned to `requires_grad` of both the point
            coordinates and the point colors.

    Returns:
        A tuple `(pointclouds, initial_volumes)`.
    """
    # `volume_size` entries are applied to the point coordinates in reverse order
    reverse_idx = [2, 1, 0]

    # points (and colors) sampled from the sides of a [0, 1] cube
    points, colors = init_cube_point_cloud(
        batch_size, n_points=n_points, device=device, rotate_y=True
    )

    size_t = torch.tensor(volume_size, dtype=points.dtype, device=points.device)

    if interp_mode == "trilinear":
        # Shrink around the cube center (0.5) so that the points fall on the
        # boundary of the first/last two voxels of every dimension - this is
        # what checks the correctness of the trilinear interpolation scheme.
        shrink = ((size_t - 2) / (size_t - 1))[reverse_idx]
        points = (points - 0.5) * shrink + 0.5

    # stretch the cube so that it overlaps the sides of the volume
    rel_scale = size_t / volume_size[0]
    points = points * rel_scale[reverse_idx][None, None]

    # optionally accumulate gradients for the differentiability check
    for leaf in (points, colors):
        leaf.requires_grad = require_grad

    pointclouds = Pointclouds(points, features=colors)

    # shift the volume so that the point cloud is centered around 0
    volume_translation = -0.5 * rel_scale[reverse_idx]
    # voxel size is 1 / (volume_size[0] - 1)
    voxel_size = 1 / (volume_size[0] - 1.0)

    initial_volumes = Volumes(
        features=points.new_zeros(batch_size, 3, *volume_size),
        densities=points.new_zeros(batch_size, 1, *volume_size),
        volume_translation=volume_translation,
        voxel_size=voxel_size,
    )
    return pointclouds, initial_volumes
Exemplo n.º 2
0
    def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32):
        """
        Check that the `Transform3d` that `Volumes` builds internally from
        `volume_translation` and `voxel_size` agrees with an explicitly
        composed reference transform. The check compares both the resulting
        world coordinate grids and the transform matrices.
        """
        device = torch.device("cuda:0")

        # 10 trials with random grid sizes, translations and voxel sizes
        for _trial in range(10):
            size = torch.randint(high=10, size=(3,), low=3).tolist()
            densities = torch.randn(
                size=[num_volumes, num_channels, *size],
                device=device,
                dtype=torch.float32,
            )

            # per-volume translation and (anisotropic) voxel size
            volume_translation = torch.randn(num_volumes, 3)
            voxel_size = torch.rand(num_volumes, 3) * 3.0 + 0.5

            # reference local->world transform; the grid size is reversed
            # so that it lines up with the xyz point coordinates
            size_rev = torch.tensor(list(size), dtype=torch.float32, device=device)
            size_rev = size_rev[[2, 1, 0]][None].repeat(num_volumes, 1)
            expected_transform = Scale(0.5 * size_rev - 0.5, device=device)
            expected_transform = expected_transform.scale(voxel_size)
            expected_transform = expected_transform.translate(-volume_translation)

            # world grid computed directly by a parametrized Volumes object
            vol_param = Volumes(
                densities=densities,
                voxel_size=voxel_size,
                volume_translation=volume_translation,
            )
            grid_world_direct = vol_param.get_coord_grid(world_coordinates=True)

            # world grid obtained by mapping the default local grid through
            # the reference transform
            grid_local_default = Volumes(densities=densities).get_coord_grid(
                world_coordinates=False
            )
            grid_world_mapped = expected_transform.transform_points(
                grid_local_default.view(num_volumes, -1, 3)
            ).view(num_volumes, *size, 3)

            # grids and transform matrices must both agree
            self.assertClose(grid_world_direct, grid_world_mapped, atol=1e-5)
            self.assertClose(
                vol_param.get_local_to_world_coords_transform().get_matrix(),
                expected_transform.get_matrix(),
                atol=1e-5,
            )
Exemplo n.º 3
0
    def test_to(self,
                num_volumes=3,
                num_channels=4,
                size=(6, 8, 10),
                dtype=torch.float32):
        """
        Round-trip volumes between GPU and CPU and check how their tensors
        relate:
        - a cuda -> cuda move reuses the same tensors (`assertIs`);
        - every other pair is checked with `assertSeparate`;
        - all internal variables are on the expected device.
        """
        gpu = torch.device("cuda:0")
        cpu = torch.device("cpu")

        features = torch.randn(size=[num_volumes, num_channels, *size],
                               device=gpu,
                               dtype=torch.float32)
        densities = torch.rand(size=[num_volumes, 1, *size],
                               device=gpu,
                               dtype=dtype)

        for feats in (features, None):
            vol = Volumes(densities=densities, features=feats)
            vol_cpu = vol.cpu()
            vol_cuda = vol_cpu.cuda()
            vol_cuda_2 = vol_cuda.cuda()
            vol_cpu_2 = vol_cuda_2.cpu()
            on_cpu = (vol_cpu, vol_cpu_2)

            for first, second in itertools.combinations(
                    (vol, vol_cpu, vol_cpu_2, vol_cuda, vol_cuda_2), 2):
                # moving cuda -> cuda must not copy anything
                same_device_move = first is vol_cuda and second is vol_cuda_2
                check = self.assertIs if same_device_move else self.assertSeparate
                check(first._densities, second._densities)
                if feats is not None:
                    check(first._features, second._features)
                for candidate in (first, second):
                    expected = cpu if candidate in on_cpu else gpu
                    self._check_vars_on_device(candidate, expected)
Exemplo n.º 4
0
    def test_clone(
        self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32
    ):
        """
        Check that `Volumes.clone()` produces independent storage: bumping
        the first voxel of the clone must leave the original unchanged. This
        is checked for densities always, and for features when present.
        """
        device = torch.device("cuda:0")

        features = torch.randn(
            size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32
        )
        densities = torch.rand(
            size=[num_volumes, 1, *size], device=device, dtype=torch.float32
        )

        def first_voxel_gap(a, b):
            # absolute difference of the [0, 0, 0, 0, 0] entries as a float
            return float((a[0, 0, 0, 0, 0] - b[0, 0, 0, 0, 0]).abs().max())

        for has_features in (True, False):
            original = Volumes(
                densities=densities, features=features if has_features else None
            )
            cloned = original.clone()

            cloned._densities.data[0, 0, 0, 0, 0] += 1.0
            self.assertNotAlmostEqual(
                first_voxel_gap(cloned.densities(), original.densities()), 0.0
            )

            if not has_features:
                continue

            cloned._features.data[0, 0, 0, 0, 0] += 1.0
            self.assertNotAlmostEqual(
                first_voxel_gap(cloned.features(), original.features()), 0.0
            )
Exemplo n.º 5
0
    def test_coord_grid_transforms(
        self, num_volumes=3, num_channels=4, dtype=torch.float32
    ):
        """
        Check that a world -> local -> world round trip of the volume
        coordinate grid is consistent. Also check that every slice of the
        world grid along a spatial dimension holds a constant value in the
        matching coordinate.
        """
        device = torch.device("cuda:0")
        # spatial grid dim -> index of the coordinate that is constant on it
        const_coord_of_dim = {1: 2, 2: 1, 3: 0}

        # 10 trials with random grid sizes, centers and voxel sizes
        for _trial in range(10):
            size = torch.randint(high=10, size=(3,), low=3).tolist()
            center = torch.randn(num_volumes, 3, dtype=torch.float32, device=device)
            voxel_size = torch.rand(1, dtype=torch.float32, device=device) * 5.0 + 0.5

            padded_densities = torch.randn(
                size=[num_volumes, num_channels, *size],
                device=device,
                dtype=torch.float32,
            )
            listed_densities = TestVolumes._random_volume_list(
                num_volumes, 3, size, num_channels, device, rand_sizes=None
            )[0]

            for densities in (padded_densities, listed_densities):
                vol = Volumes(
                    densities=densities,
                    voxel_size=voxel_size,
                    volume_translation=-center,
                )

                grid_local = vol.get_coord_grid(world_coordinates=False)
                grid_world = vol.get_coord_grid(world_coordinates=True)

                # world -> local -> world must reproduce both grids
                grid_local_rt = vol.world_to_local_coords(grid_world)
                grid_world_rt = vol.local_to_world_coords(grid_local_rt)
                self.assertClose(grid_world, grid_world_rt, atol=1e-5)
                self.assertClose(grid_local, grid_local_rt, atol=1e-5)

                for plane_dim, coord_idx in const_coord_of_dim.items():
                    for grid_plane in grid_world.split(1, dim=plane_dim):
                        flat = grid_plane.squeeze()[..., coord_idx].reshape(
                            num_volumes, -1
                        )
                        # max == min per batch element <=> the slice is constant
                        self.assertClose(
                            flat.max(dim=1).values,
                            flat.min(dim=1).values,
                        )
Exemplo n.º 6
0
    def test_coord_grid_convention(
        self, num_volumes=3, num_channels=4, dtype=torch.float32
    ):
        """
        Check that for a trivial volume with spatial size DxHxW=5x7x5:
        1) xyz_world=(0, 0, 0) lands right in the middle of the volume
        with xyz_local=(0, 0, 0).
        2) xyz_world=(-2, 3, -2) results in xyz_local=(-1, 1, -1).
        3) The central voxel of the volume coordinate grid
        has coords x_world=(0, 0, 0) and x_local=(0, 0, 0).
        4) grid_sampler(world_coordinate_grid, local_coordinate_grid)
        is the same as world_coordinate_grid itself. I.e. the local coordinate
        grid matches the grid_sampler coordinate convention.
        """

        device = torch.device("cuda:0")

        densities = torch.randn(
            size=[num_volumes, num_channels, 5, 7, 5],
            device=device,
            dtype=torch.float32,
        )
        # default translation / voxel size
        v_trivial = Volumes(densities=densities)

        # check the case with x_world=(0,0,0)
        pts_world = torch.zeros(num_volumes, 1, 3, device=device, dtype=torch.float32)
        pts_local = v_trivial.world_to_local_coords(pts_world)
        pts_local_expected = torch.zeros_like(pts_local)
        self.assertClose(pts_local, pts_local_expected)

        # check the case with x_world=(-2, 3, -2); these magnitudes equal
        # (W-1)/2, (H-1)/2, (D-1)/2 for the 5x7x5 grid, which the test
        # expects to map to the local boundary values -1 / 1 / -1
        pts_world = torch.tensor([-2, 3, -2], device=device, dtype=torch.float32)[
            None, None
        ].repeat(num_volumes, 1, 1)
        pts_local = v_trivial.world_to_local_coords(pts_world)
        pts_local_expected = torch.tensor(
            [-1, 1, -1], device=device, dtype=torch.float32
        )[None, None].repeat(num_volumes, 1, 1)
        self.assertClose(pts_local, pts_local_expected)

        # check that the central voxel has coords x_world=(0, 0, 0) and x_local=(0, 0, 0)
        grid_world = v_trivial.get_coord_grid(world_coordinates=True)
        grid_local = v_trivial.get_coord_grid(world_coordinates=False)
        for grid in (grid_world, grid_local):
            # grid is indexed [batch, D, H, W, xyz]; 2 / 3 / 2 are the middle
            # indices of the W / H / D axes of the 5x7x5 grid
            x0 = grid[0, :, :, 2, 0]
            y0 = grid[0, :, 3, :, 1]
            z0 = grid[0, 2, :, :, 2]
            for coord_line in (x0, y0, z0):
                self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7)

        # resample grid_world using grid_sampler with local coords
        # -> make sure the resampled version is the same as original
        grid_world_resampled = torch.nn.functional.grid_sample(
            grid_world.permute(0, 4, 1, 2, 3), grid_local, align_corners=True
        ).permute(0, 2, 3, 4, 1)
        self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
Exemplo n.º 7
0
    def test_unscaled(self):
        """
        Splat P random points with random features into zero volumes using
        `add_pointclouds_to_volumes(..., rescale_features=False)`. Then check
        the accumulated totals of each volume:
        - the densities sum to P;
        - the features sum to roughly 0.5 * P (the point features are drawn
          uniformly from [0, 1)).
        """
        D = 5
        P = 1000
        B, C, H, W = 2, 3, D, D
        densities = torch.zeros(B, 1, D, H, W)
        features = torch.zeros(B, C, D, H, W)
        volumes = Volumes(densities=densities, features=features)
        # P points per batch element, uniform in [-(D-1)/2, (D-1)/2) per axis;
        # the number of points is tied to P so the normalisation below matches
        points = torch.rand(B, P, 3) * (D - 1) - ((D - 1) * 0.5)
        point_features = torch.rand(B, P, C)
        pointclouds = Pointclouds(points=points, features=point_features)

        volumes2 = add_pointclouds_to_volumes(pointclouds,
                                              volumes,
                                              rescale_features=False)
        self.assertConstant(volumes2.densities().sum([2, 3, 4]) / P,
                            1,
                            atol=1e-5)
        self.assertConstant(volumes2.features().sum([2, 3, 4]) / P,
                            0.5,
                            atol=0.03)
Exemplo n.º 8
0
    def test_constructor(self,
                         num_volumes=3,
                         num_channels=4,
                         size=(6, 8, 10),
                         dtype=torch.float32):
        """
        Exercise every combination of valid / invalid `features`,
        `densities`, `voxel_size` and `volume_translation` arguments.

        A combination made only of valid arguments must construct without
        error; any combination containing an invalid one must raise
        `ValueError`.
        """
        device = torch.device("cuda:0")

        def randn(*shape):
            # float32 normal noise on the test device
            return torch.randn(size=list(shape), device=device, dtype=torch.float32)

        features = [
            randn(num_volumes, num_channels, *size),  # padded tensor
            randn(num_volumes, num_channels, *size).unbind(0),  # list of features
            None,  # no features
        ]
        bad_features = [
            randn(num_volumes, num_channels, 2, *size),  # 6 dims
            randn(num_volumes, *size),  # 4 dims
            randn(num_volumes, *size).unbind(0),  # list of 4 dim tensors
        ]

        densities = [
            randn(num_volumes, 1, *size),  # padded tensor
            randn(num_volumes, 1, *size).unbind(0),  # list of densities
        ]
        bad_densities = [
            None,  # omitted
            randn(num_volumes, 1, 1, *size),  # 6-dim tensor
            randn(num_volumes, 1, 1, *size).unbind(0),  # list of 5-dim densities
        ]

        # accepted voxel size specifications
        vox_sizes = [
            torch.Tensor([1.0, 1.0, 1.0]),
            [1.0, 1.0, 1.0],
            torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1),
            torch.Tensor([1.0])[None].repeat(num_volumes, 1),
            1.0,
            torch.Tensor([1.0]),
        ]
        # accepted volume translation specifications
        vol_translations = [
            torch.Tensor([1.0, 1.0, 1.0]),
            [1.0, 1.0, 1.0],
            torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1),
        ]
        # rejected voxel size specifications
        bad_vox_sizes = [
            torch.Tensor([1.0, 1.0, 1.0, 1.0]),
            [1.0, 1.0, 1.0, 1.0],
            torch.Tensor([]),
            None,
        ]
        # rejected volume translation specifications
        bad_vol_translations = [
            torch.Tensor([1.0, 1.0]),
            [1.0, 1.0],
            1.0,
            torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes + 1, 1),
        ]

        def labelled(good, bad):
            # pair every candidate with True (valid) or False (invalid)
            return [(item, True) for item in good] + [(item, False) for item in bad]

        for features_, features_ok in labelled(features, bad_features):
            for densities_, densities_ok in labelled(densities, bad_densities):
                for vox_size, size_ok in labelled(vox_sizes, bad_vox_sizes):
                    for vol_translation, trans_ok in labelled(
                            vol_translations, bad_vol_translations):
                        kwargs = dict(
                            features=features_,
                            densities=densities_,
                            voxel_size=vox_size,
                            volume_translation=vol_translation,
                        )
                        if all((size_ok, trans_ok, features_ok, densities_ok)):
                            # a fully valid combination must not throw
                            Volumes(**kwargs)
                        else:
                            self.assertRaises(ValueError, Volumes, **kwargs)
Exemplo n.º 9
0
    def test_constructor_for_padded_lists(self):
        """
        Check that `features` must agree with `densities` in batch size,
        grid size and device. Each of them may be given either as a padded
        tensor or as a list of per-volume tensors.
        """
        device = torch.device("cuda:0")
        diff_device = torch.device("cpu")

        num_volumes = 3
        num_channels = 4
        size = (6, 8, 10)
        diff_size = (6, 8, 11)

        def randn(shape, on=device):
            # float32 normal noise on the requested device
            return torch.randn(size=shape, device=on, dtype=torch.float32)

        # valid densities: list form, then padded form
        ok_densities = [
            randn([num_volumes, 1, *size]).unbind(0),
            randn([num_volumes, 1, *size]),
        ]

        bad_features = [
            # list with diff batch size
            randn([num_volumes + 1, num_channels, *size]).unbind(0),
            # diff batch size
            randn([num_volumes + 1, num_channels, *size]),
            # list with different size
            randn([num_volumes, num_channels, *diff_size]).unbind(0),
            # different size
            randn([num_volumes, num_channels, *diff_size]),
            # different device
            randn([num_volumes, num_channels, *size], on=diff_device),
            # list with different device
            randn([num_volumes, num_channels, *size], on=diff_device).unbind(0),
        ]

        ok_features = [
            # list of features of correct size
            randn([num_volumes, num_channels, *size]).unbind(0),
            randn([num_volumes, num_channels, *size]),
        ]

        for densities in ok_densities:
            for features in bad_features:
                with self.assertRaises(ValueError):
                    Volumes(densities=densities, features=features)
            for features in ok_features:
                Volumes(densities=densities, features=features)
Exemplo n.º 10
0
    def test_get_item(
            self,
            num_volumes=5,
            num_channels=4,
            volume_size=(10, 13, 8),
            dtype=torch.float32,
    ):
        """
        Test indexing of `Volumes` with the following index types:
        - an int, a list, a slice, a bool tensor and an int64 tensor, which
          must all select the expected volumes;
        - a float tensor and a float scalar, which must raise `IndexError`.

        This is done for densities-only, padded and list-based volumes.
        """

        device = torch.device("cuda:0")

        # make sure we have at least 3 volumes to prevent indexing crash
        num_volumes = max(num_volumes, 3)

        features = torch.randn(
            size=[num_volumes, num_channels, *volume_size],
            device=device,
            dtype=torch.float32,
        )
        densities = torch.randn(size=[num_volumes, 1, *volume_size],
                                device=device,
                                dtype=torch.float32)

        features_list, rand_sizes = TestVolumes._random_volume_list(
            num_volumes, 3, volume_size, num_channels, device)
        densities_list, _ = TestVolumes._random_volume_list(
            num_volumes, 3, volume_size, 1, device, rand_sizes=rand_sizes)

        volume_translation = -torch.randn(num_volumes, 3).type_as(features)
        voxel_size = torch.rand(num_volumes, 1).type_as(features) + 0.5

        for features_, densities_ in zip(
            (None, features, features_list),
            (densities, densities, densities_list)):

            # init the volume structure
            v = Volumes(
                features=features_,
                densities=densities_,
                volume_translation=volume_translation,
                voxel_size=voxel_size,
            )

            # int index
            index = 1
            v_selected = v[index]
            self.assertEqual(len(v_selected), 1)
            self._check_indexed_volumes(v, v_selected, [(0, 1)])

            # list index
            index = [1, 2]
            v_selected = v[index]
            self.assertEqual(len(v_selected), len(index))
            self._check_indexed_volumes(v, v_selected, enumerate(index))

            # slice index - index with the slice object itself so that the
            # declared index is what is actually exercised
            index = slice(0, 2, 1)
            v_selected = v[index]
            self.assertEqual(len(v_selected), 2)
            self._check_indexed_volumes(v, v_selected, [(0, 0), (1, 1)])

            # bool tensor
            index = (torch.rand(num_volumes) > 0.5).to(device)
            index[:2] = True  # make sure smth is selected
            v_selected = v[index]
            self.assertEqual(len(v_selected), int(index.sum()))
            self._check_indexed_volumes(
                v,
                v_selected,
                zip(
                    torch.arange(index.sum()),
                    torch.nonzero(index, as_tuple=False).squeeze(),
                ),
            )

            # int tensor
            index = torch.tensor([1, 2], dtype=torch.int64, device=device)
            v_selected = v[index]
            self.assertEqual(len(v_selected), index.numel())
            self._check_indexed_volumes(v, v_selected,
                                        enumerate(index.tolist()))

            # invalid index: float tensor
            index = torch.tensor([1, 0, 1], dtype=torch.float32, device=device)
            with self.assertRaises(IndexError):
                v[index]
            # invalid index: floating point scalar
            index = 1.2
            with self.assertRaises(IndexError):
                v[index]
Exemplo n.º 11
0
    def test_feature_density_setters(self):
        """
        Tests getters and setters for padded/list representations.

        Builds volumes with random per-volume grid sizes and checks:
        - the padded and list getters against the input lists;
        - `_set_features` and `_set_densities` with valid inputs;
        - that both setters raise `ValueError` for inputs on another device,
          inputs with mismatched grid sizes, and `None`;
        - that `update_padded` returns volumes that share the grid sizes,
          local-to-world transform and device of the source volumes.
        """

        device = torch.device("cuda:0")
        diff_device = torch.device("cpu")

        num_volumes = 30
        num_channels = 4
        K = 20

        densities = []
        features = []
        grid_sizes = []
        diff_grid_sizes = []

        for _ in range(num_volumes):
            # each spatial size is drawn from [1, K - 1]
            grid_size = torch.randint(K - 1, size=(3, )).long() + 1
            densities.append(
                torch.rand((1, *grid_size), device=device,
                           dtype=torch.float32))
            features.append(
                torch.rand((num_channels, *grid_size),
                           device=device,
                           dtype=torch.float32))
            grid_sizes.append(grid_size)

            # strictly larger grid size, guaranteed not to match grid_size
            diff_grid_size = (copy.deepcopy(grid_size) +
                              torch.randint(2, size=(3, )).long() + 1)
            diff_grid_sizes.append(diff_grid_size)
        grid_sizes = torch.stack(grid_sizes).to(device)
        diff_grid_sizes = torch.stack(diff_grid_sizes).to(device)

        volumes = Volumes(densities=densities, features=features)
        self.assertClose(volumes.get_grid_sizes(), grid_sizes)

        # test the getters
        features_padded = volumes.features()
        densities_padded = volumes.densities()
        features_list = volumes.features_list()
        densities_list = volumes.densities_list()
        for x_pad, x_list in zip(
            (densities_padded, features_padded, densities_padded,
             features_padded),
            (densities_list, features_list, densities, features),
        ):
            self._check_padded(x_pad, x_list, grid_sizes)

        # test feature setters
        features_new = [
            torch.rand((num_channels, *grid_size),
                       device=device,
                       dtype=torch.float32) for grid_size in grid_sizes
        ]
        volumes._set_features(features_new)
        features_new_list = volumes.features_list()
        features_new_padded = volumes.features()
        for x_pad, x_list in zip(
            (features_new_padded, features_new_padded),
            (features_new, features_new_list),
        ):
            self._check_padded(x_pad, x_list, grid_sizes)

        # wrong features to update
        bad_features_new = [
            [
                torch.rand((num_channels, *grid_size),
                           device=diff_device,
                           dtype=torch.float32)
                for grid_size in diff_grid_sizes
            ],
            torch.rand(
                (num_volumes, num_channels, K + 1, K, K),
                device=device,
                dtype=torch.float32,
            ),
            None,
        ]
        for bad_features_new_ in bad_features_new:
            # the *feature* setter is the one under test here
            with self.assertRaises(ValueError):
                volumes._set_features(bad_features_new_)

        # test density setters
        densities_new = [
            torch.rand((1, *grid_size), device=device, dtype=torch.float32)
            for grid_size in grid_sizes
        ]
        volumes._set_densities(densities_new)
        densities_new_list = volumes.densities_list()
        densities_new_padded = volumes.densities()
        for x_pad, x_list in zip(
            (densities_new_padded, densities_new_padded),
            (densities_new, densities_new_list),
        ):
            self._check_padded(x_pad, x_list, grid_sizes)

        # wrong densities to update
        bad_densities_new = [
            [
                torch.rand((1, *grid_size),
                           device=diff_device,
                           dtype=torch.float32)
                for grid_size in diff_grid_sizes
            ],
            torch.rand((num_volumes, 1, K + 1, K, K),
                       device=device,
                       dtype=torch.float32),
            None,
        ]
        for bad_densities_new_ in bad_densities_new:
            with self.assertRaises(ValueError):
                volumes._set_densities(bad_densities_new_)

        # test update_padded
        volumes = Volumes(densities=densities, features=features)
        volumes_updated = volumes.update_padded(densities_new,
                                                new_features=features_new)
        densities_new_list = volumes_updated.densities_list()
        densities_new_padded = volumes_updated.densities()
        features_new_list = volumes_updated.features_list()
        features_new_padded = volumes_updated.features()
        for x_pad, x_list in zip(
            (
                densities_new_padded,
                densities_new_padded,
                features_new_padded,
                features_new_padded,
            ),
            (densities_new, densities_new_list, features_new,
             features_new_list),
        ):
            self._check_padded(x_pad, x_list, grid_sizes)
        # the updated volumes share (not copy) the geometric metadata
        self.assertIs(volumes.get_grid_sizes(),
                      volumes_updated.get_grid_sizes())
        self.assertIs(
            volumes.get_local_to_world_coords_transform(),
            volumes_updated.get_local_to_world_coords_transform(),
        )
        self.assertIs(volumes.device, volumes_updated.device)
Exemplo n.º 12
0
    def test_to(self,
                num_volumes=3,
                num_channels=4,
                size=(6, 8, 10),
                dtype=torch.float32):
        """
        Test the moving of the volumes from/to gpu and cpu.

        Check that `.to()` accepts both a device string and a `torch.device`:
        - a same-device move returns `self`;
        - a cross-device move returns a new object and leaves the source
          untouched.

        Then round-trip volumes via `.cpu()` / `.cuda()` and check tensor
        sharing/separation and the placement of all internal variables.
        """

        features = torch.randn(size=[num_volumes, num_channels, *size],
                               dtype=torch.float32)
        densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype)
        volumes = Volumes(densities=densities, features=features)

        # Test support for str and torch.device
        cpu_device = torch.device("cpu")

        converted_volumes = volumes.to("cpu")
        self.assertEqual(cpu_device, converted_volumes.device)
        self.assertEqual(cpu_device, volumes.device)
        self.assertIs(volumes, converted_volumes)

        converted_volumes = volumes.to(cpu_device)
        self.assertEqual(cpu_device, converted_volumes.device)
        self.assertEqual(cpu_device, volumes.device)
        self.assertIs(volumes, converted_volumes)

        cuda_device = torch.device("cuda:0")

        converted_volumes = volumes.to("cuda:0")
        self.assertEqual(cuda_device, converted_volumes.device)
        self.assertEqual(cpu_device, volumes.device)
        self.assertIsNot(volumes, converted_volumes)

        converted_volumes = volumes.to(cuda_device)
        self.assertEqual(cuda_device, converted_volumes.device)
        self.assertEqual(cpu_device, volumes.device)
        self.assertIsNot(volumes, converted_volumes)

        # Test device placement of internal tensors. Densities and features
        # are moved independently so they remain distinct tensors with their
        # own channel counts (1 vs num_channels).
        features = features.to(cuda_device)
        densities = densities.to(cuda_device)

        for features_ in (features, None):
            volumes = Volumes(densities=densities, features=features_)

            cpu_volumes = volumes.cpu()
            cuda_volumes = cpu_volumes.cuda()
            cuda_volumes2 = cuda_volumes.cuda()
            cpu_volumes2 = cuda_volumes2.cpu()

            for volumes1, volumes2 in itertools.combinations(
                (volumes, cpu_volumes, cpu_volumes2, cuda_volumes,
                 cuda_volumes2), 2):
                if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
                    # checks that we do not copy if the devices stay the same
                    assert_fun = self.assertIs
                else:
                    assert_fun = self.assertSeparate
                assert_fun(volumes1._densities, volumes2._densities)
                if features_ is not None:
                    assert_fun(volumes1._features, volumes2._features)
                for volumes_ in (volumes1, volumes2):
                    if volumes_ in (cpu_volumes, cpu_volumes2):
                        self._check_vars_on_device(volumes_, cpu_device)
                    else:
                        self._check_vars_on_device(volumes_, cuda_device)
Exemplo n.º 13
0
    def test_coord_grid_convention_heterogeneous(self,
                                                 num_channels=4,
                                                 dtype=torch.float32):
        """
        Check that for a list of 2 trivial volumes with
        spatial sizes DxHxW=(5x7x5, 3x5x5):
        1) xyz_world=(0, 0, 0) lands right in the middle of the volume
        with xyz_local=(0, 0, 0).
        2) xyz_world=((-2, 3, -2), (-2, -2, 1)) results
        in xyz_local=((-1, 1, -1), (-1, -1, 1)).
        3) The central voxel of the volume coordinate grid
        has coords x_world=(0, 0, 0) and x_local=(0, 0, 0).
        4) grid_sampler(world_coordinate_grid, local_coordinate_grid)
        is the same as world_coordinate_grid itself. I.e. the local coordinate
        grid matches the grid_sampler coordinate convention.

        Requires a CUDA device (``cuda:0``).

        Args:
            num_channels: Number of density channels of each test volume.
            dtype: dtype of every tensor created by the test.
        """

        device = torch.device("cuda:0")

        sizes = [(5, 7, 5), (3, 5, 5)]

        densities_list = [
            torch.randn(size=[num_channels, *size],
                        device=device,
                        dtype=dtype) for size in sizes
        ]

        # init the volume
        v_trivial = Volumes(densities=densities_list)

        # check the border point locations
        pts_world = torch.tensor([[-2.0, 3.0, -2.0], [-2.0, -2.0, 1.0]],
                                 device=device,
                                 dtype=dtype)[:, None]
        pts_local = v_trivial.world_to_local_coords(pts_world)
        pts_local_expected = torch.tensor(
            [[-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]],
            device=device,
            dtype=dtype)[:, None]
        self.assertClose(pts_local, pts_local_expected)

        # check that the central voxel has coords x_world=(0, 0, 0) and
        # x_local=(0, 0, 0); the (D, H, W) index of each volume's central
        # voxel is size // 2 along every axis (sizes are odd)
        centers = [tuple(s // 2 for s in size) for size in sizes]
        grid_world = v_trivial.get_coord_grid(world_coordinates=True)
        grid_local = v_trivial.get_coord_grid(world_coordinates=False)
        for grid in (grid_world, grid_local):
            for vol_idx, (cz, cy, cx) in enumerate(centers):
                x0 = grid[vol_idx, :, :, cx, 0]
                y0 = grid[vol_idx, :, cy, :, 1]
                z0 = grid[vol_idx, cz, :, :, 2]
                for coord_line in (x0, y0, z0):
                    self.assertClose(coord_line,
                                     torch.zeros_like(coord_line),
                                     atol=1e-7)

        # resample grid_world using grid_sampler with local coords
        # -> make sure the resampled version is the same as original.
        # The grid may extend beyond a volume's own size, so crop each
        # volume's grid to its size first.
        for grid_world_, grid_local_, size in zip(grid_world, grid_local,
                                                  sizes):
            grid_world_crop = grid_world_[:size[0], :size[1], :size[2], :][
                None]
            grid_local_crop = grid_local_[:size[0], :size[1], :size[2], :][
                None]
            grid_world_crop_resampled = torch.nn.functional.grid_sample(
                grid_world_crop.permute(0, 4, 1, 2, 3),
                grid_local_crop,
                align_corners=True,
            ).permute(0, 2, 3, 4, 1)
            self.assertClose(grid_world_crop_resampled,
                             grid_world_crop,
                             atol=1e-7)