def init_volume_boundary_pointcloud(
    batch_size: int,
    volume_size: Tuple[int, int, int],
    n_points: int,
    interp_mode: str,
    device: str,
    require_grad: bool = False,
):
    """
    Build a point cloud that hugs the boundary of a volume of the given
    size, together with a zero-initialized `Volumes` buffer.

    Args:
        batch_size: Number of point clouds / volumes in the batch.
        volume_size: Spatial size of the volume buffer; its entries are
            applied to the xyz coordinates in reversed order (`[2, 1, 0]`).
        n_points: Forwarded to `init_cube_point_cloud`.
        interp_mode: If equal to `"trilinear"`, the cube points are shrunk
            towards the cube center by `(size - 2) / (size - 1)`.
        device: Forwarded to `init_cube_point_cloud`.
        require_grad: Value assigned to `requires_grad` of both the point
            locations and the point features.

    Returns:
        A tuple `(pointclouds, initial_volumes)`.
    """
    # index that reverses the volume_size ordering before it is applied to xyz
    flip = [2, 1, 0]

    # points sampled from the sides of a [0, 1] cube
    xyz, rgb = init_cube_point_cloud(
        batch_size, n_points=n_points, device=device, rotate_y=True
    )

    size_t = torch.tensor(volume_size, dtype=xyz.dtype, device=xyz.device)

    if interp_mode == "trilinear":
        # pull the points slightly inside so that they fall on the boundary
        # of the first/last two voxels along each spatial dimension, which
        # exercises the trilinear interpolation scheme
        shrink = (size_t - 2) / (size_t - 1)
        xyz = (xyz - 0.5) * shrink[flip] + 0.5

    # stretch the cube so that it overlaps with the sides of the volume
    rel_scale = size_t / volume_size[0]
    rel_scale_xyz = rel_scale[flip]
    xyz = xyz * rel_scale_xyz[None, None]

    # gradients are needed for the differentiability check
    xyz.requires_grad = require_grad
    rgb.requires_grad = require_grad

    pointclouds = Pointclouds(xyz, features=rgb)

    # translate the volume so that the point cloud is centered around 0
    volume_translation = -0.5 * rel_scale_xyz

    # voxel size is 1 / (volume_size[0] - 1)
    volume_voxel_size = 1 / (volume_size[0] - 1.0)

    zero_features = xyz.new_zeros(batch_size, 3, *volume_size)
    zero_densities = xyz.new_zeros(batch_size, 1, *volume_size)
    initial_volumes = Volumes(
        features=zero_features,
        densities=zero_densities,
        volume_translation=volume_translation,
        voxel_size=volume_voxel_size,
    )

    return pointclouds, initial_volumes
def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32):
    """
    Verify that the local-to-world transform stored inside `Volumes`
    (built from `volume_translation` and `voxel_size`) equals a transform
    composed explicitly from `Scale(...).scale(...).translate(...)`.
    """
    device = torch.device("cuda:0")
    n_trials = 10  # number of random size / translation / voxel-size draws

    for _trial in range(n_trials):
        size = torch.randint(high=10, size=(3,), low=3).tolist()
        densities = torch.randn(
            size=[num_volumes, num_channels, *size],
            device=device,
            dtype=torch.float32,
        )

        # random transformation parameters
        volume_translation = torch.randn(num_volumes, 3)
        voxel_size = torch.rand(num_volumes, 3) * 3.0 + 0.5

        # reference transform composed step by step
        size_t = torch.tensor(list(size), dtype=torch.float32, device=device)
        local_offset = size_t[[2, 1, 0]][None].repeat(num_volumes, 1)
        expected_transform = Scale(0.5 * local_offset - 0.5, device=device)
        expected_transform = expected_transform.scale(voxel_size)
        expected_transform = expected_transform.translate(-volume_translation)

        # world grid of a volume initialized with the scale and translation
        v_trans_vs = Volumes(
            densities=densities,
            voxel_size=voxel_size,
            volume_translation=volume_translation,
        )
        grid_from_volume = v_trans_vs.get_coord_grid(world_coordinates=True)

        # local grid of a default volume, pushed through the reference transform
        v_default = Volumes(densities=densities)
        local_grid = v_default.get_coord_grid(world_coordinates=False)
        flat_local_grid = local_grid.view(num_volumes, -1, 3)
        flat_world_grid = expected_transform.transform_points(flat_local_grid)
        grid_from_transform = flat_world_grid.view(num_volumes, *size, 3)

        # both grids have to agree
        self.assertClose(grid_from_volume, grid_from_transform, atol=1e-5)

        # ... and so do the transformation matrices
        volume_matrix = v_trans_vs.get_local_to_world_coords_transform().get_matrix()
        self.assertClose(volume_matrix, expected_transform.get_matrix(), atol=1e-5)
def test_to(self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32):
    """
    Move `Volumes` objects back and forth between cpu and gpu and check
    tensor identity / separation and device placement after each move.
    """
    device = torch.device("cuda:0")
    device_cpu = torch.device("cpu")

    features = torch.randn(
        size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32
    )
    densities = torch.rand(size=[num_volumes, 1, *size], device=device, dtype=dtype)

    for features_ in (features, None):
        v = Volumes(densities=densities, features=features_)

        v_cpu = v.cpu()
        v_cuda = v_cpu.cuda()
        v_cuda_2 = v_cuda.cuda()
        v_cpu_2 = v_cuda_2.cpu()

        all_volumes = (v, v_cpu, v_cpu_2, v_cuda, v_cuda_2)
        cpu_volumes = (v_cpu, v_cpu_2)

        for v1, v2 in itertools.combinations(all_volumes, 2):
            # calling .cuda() on a cuda volume must not copy the tensors;
            # every other pair is expected to hold separate tensors
            no_copy_expected = v1 is v_cuda and v2 is v_cuda_2
            assert_fun = self.assertIs if no_copy_expected else self.assertSeparate

            assert_fun(v1._densities, v2._densities)
            if features_ is not None:
                assert_fun(v1._features, v2._features)

            for v_ in (v1, v2):
                expected_device = device_cpu if v_ in cpu_volumes else device
                self._check_vars_on_device(v_, expected_device)
def test_clone(
    self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32
):
    """
    Check that modifying the tensors of a cloned `Volumes` object does not
    modify the tensors of the object it was cloned from.
    """
    device = torch.device("cuda:0")
    features = torch.randn(
        size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32
    )
    densities = torch.rand(
        size=[num_volumes, 1, *size], device=device, dtype=torch.float32
    )

    # the single element that gets perturbed in the clone
    corner = (0, 0, 0, 0, 0)

    def corner_gap(cloned, original):
        # absolute difference of the perturbed element, as a python float
        return float((cloned[corner] - original[corner]).abs().max())

    for has_features in (True, False):
        source_features = features if has_features else None
        v = Volumes(densities=densities, features=source_features)
        vnew = v.clone()

        vnew._densities.data[corner] += 1.0
        self.assertNotAlmostEqual(corner_gap(vnew.densities(), v.densities()), 0.0)

        if has_features:
            vnew._features.data[corner] += 1.0
            self.assertNotAlmostEqual(
                corner_gap(vnew.features(), v.features()), 0.0
            )
def test_coord_grid_transforms(
    self, num_volumes=3, num_channels=4, dtype=torch.float32
):
    """
    Check that a world -> local -> world round trip of the volume
    coordinate grid reproduces the original grids, and that each slice of
    the world grid is constant in the coordinate paired with that axis.
    """
    device = torch.device("cuda:0")

    # grid axis that is sliced -> coordinate channel expected to be constant
    coord_of_plane = {1: 2, 2: 1, 3: 0}

    # 10 draws of random sizes / centers / voxel sizes
    for _ in range(10):
        size = torch.randint(high=10, size=(3,), low=3).tolist()
        center = torch.randn(num_volumes, 3, dtype=torch.float32, device=device)
        voxel_size = torch.rand(1, dtype=torch.float32, device=device) * 5.0 + 0.5

        # the same checks are run for a padded tensor and for a list input
        padded_densities = torch.randn(
            size=[num_volumes, num_channels, *size],
            device=device,
            dtype=torch.float32,
        )
        list_densities = TestVolumes._random_volume_list(
            num_volumes, 3, size, num_channels, device, rand_sizes=None
        )[0]

        for densities in (padded_densities, list_densities):
            v = Volumes(
                densities=densities,
                voxel_size=voxel_size,
                volume_translation=-center,
            )

            grid_local = v.get_coord_grid(world_coordinates=False)
            grid_world = v.get_coord_grid(world_coordinates=True)

            # world -> local -> world
            grid_local_2 = v.world_to_local_coords(grid_world)
            grid_world_2 = v.local_to_world_coords(grid_local_2)

            self.assertClose(grid_world, grid_world_2, atol=1e-5)
            self.assertClose(grid_local, grid_local_2, atol=1e-5)

            for plane_dim, coord_dim in coord_of_plane.items():
                for grid_plane in grid_world.split(1, dim=plane_dim):
                    plane_values = grid_plane.squeeze()[..., coord_dim]
                    per_volume = plane_values.reshape(num_volumes, -1)
                    # max == min means all values of the slice coincide
                    # for each batch element
                    self.assertClose(
                        per_volume.max(dim=1).values,
                        per_volume.min(dim=1).values,
                    )
def test_coord_grid_convention(
    self, num_volumes=3, num_channels=4, dtype=torch.float32
):
    """
    Check that for a trivial volume with spatial size DxHxW=5x7x5:
        1) xyz_world=(0, 0, 0) maps to xyz_local=(0, 0, 0).
        2) xyz_world=(-2, 3, -2) maps to xyz_local=(-1, 1, -1).
        3) The central slices of both the world and the local coordinate
           grid hold the value 0 in the corresponding coordinate.
        4) grid_sample(world_coordinate_grid, local_coordinate_grid)
           reproduces world_coordinate_grid, i.e. the local coordinate
           grid follows the grid_sample coordinate convention.
    """
    device = torch.device("cuda:0")
    densities = torch.randn(
        size=[num_volumes, num_channels, 5, 7, 5],
        device=device,
        dtype=torch.float32,
    )
    v_trivial = Volumes(densities=densities)

    def batched_point(xyz):
        # one 3D point per volume, shaped (num_volumes, 1, 3)
        point = torch.tensor(xyz, device=device, dtype=torch.float32)
        return point[None, None].repeat(num_volumes, 1, 1)

    # 1) the world origin maps to the local origin
    pts_world = torch.zeros(num_volumes, 1, 3, device=device, dtype=torch.float32)
    pts_local = v_trivial.world_to_local_coords(pts_world)
    self.assertClose(pts_local, torch.zeros_like(pts_local))

    # 2) x_world=(-2, 3, -2) maps to x_local=(-1, 1, -1)
    pts_local = v_trivial.world_to_local_coords(batched_point([-2, 3, -2]))
    self.assertClose(pts_local, batched_point([-1, 1, -1]))

    # 3) the central slices carry zero in the matching coordinate
    grid_world = v_trivial.get_coord_grid(world_coordinates=True)
    grid_local = v_trivial.get_coord_grid(world_coordinates=False)
    for grid in (grid_world, grid_local):
        central_slices = (
            grid[0, :, :, 2, 0],  # x
            grid[0, :, 3, :, 1],  # y
            grid[0, 2, :, :, 2],  # z
        )
        for coord_line in central_slices:
            self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7)

    # 4) sampling the world grid at the local grid locations has to
    # return the world grid itself
    world_channels_first = grid_world.permute(0, 4, 1, 2, 3)
    resampled = torch.nn.functional.grid_sample(
        world_channels_first, grid_local, align_corners=True
    )
    grid_world_resampled = resampled.permute(0, 2, 3, 4, 1)
    self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
def test_unscaled(self):
    """
    Add random point clouds to zero-initialized volumes with
    `rescale_features=False` and check the per-point sums of the
    resulting densities and features.
    """
    D = 5
    P = 1000
    B, C, H, W = 2, 3, D, D

    volumes = Volumes(
        densities=torch.zeros(B, 1, D, H, W),
        features=torch.zeros(B, C, D, H, W),
    )

    # point locations drawn uniformly from [-(D - 1) / 2, (D - 1) / 2)
    half_extent = (D - 1) * 0.5
    points = torch.rand(B, P, 3) * (D - 1) - half_extent
    point_features = torch.rand(B, P, C)
    pointclouds = Pointclouds(points=points, features=point_features)

    volumes2 = add_pointclouds_to_volumes(
        pointclouds, volumes, rescale_features=False
    )

    spatial_dims = [2, 3, 4]
    density_per_point = volumes2.densities().sum(spatial_dims) / P
    feature_per_point = volumes2.features().sum(spatial_dims) / P
    self.assertConstant(density_per_point, 1, atol=1e-5)
    self.assertConstant(feature_per_point, 0.5, atol=0.03)
def test_constructor(self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32):
    """
    Call the `Volumes` constructor with every combination of valid and
    invalid features, densities, voxel sizes and volume translations.
    A combination made only of valid arguments has to construct without
    raising; any combination with an invalid argument has to raise
    `ValueError`.
    """
    device = torch.device("cuda:0")

    def randn(*shape):
        # random float32 tensor of the given shape on the test device
        return torch.randn(size=list(shape), device=device, dtype=torch.float32)

    # valid feature inputs
    features = [
        randn(num_volumes, num_channels, *size),  # padded tensor
        randn(num_volumes, num_channels, *size).unbind(0),  # list of features
        None,  # no features
    ]

    # invalid feature inputs
    bad_features = [
        randn(num_volumes, num_channels, 2, *size),  # 6 dims
        randn(num_volumes, *size),  # 4 dims
        randn(num_volumes, *size).unbind(0),  # list of tensors lacking a channel dim
    ]

    # valid density inputs
    densities = [
        randn(num_volumes, 1, *size),  # padded tensor
        randn(num_volumes, 1, *size).unbind(0),  # list of densities
    ]

    # invalid density inputs
    bad_densities = [
        None,  # omitted
        randn(num_volumes, 1, 1, *size),  # 6-dim tensor
        randn(num_volumes, 1, 1, *size).unbind(0),  # list of 5-dim densities
    ]

    # valid voxel sizes
    vox_sizes = [
        torch.Tensor([1.0, 1.0, 1.0]),
        [1.0, 1.0, 1.0],
        torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1),
        torch.Tensor([1.0])[None].repeat(num_volumes, 1),
        1.0,
        torch.Tensor([1.0]),
    ]

    # valid volume translations
    vol_translations = [
        torch.Tensor([1.0, 1.0, 1.0]),
        [1.0, 1.0, 1.0],
        torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1),
    ]

    # invalid voxel sizes
    bad_vox_sizes = [
        torch.Tensor([1.0, 1.0, 1.0, 1.0]),
        [1.0, 1.0, 1.0, 1.0],
        torch.Tensor([]),
        None,
    ]

    # invalid volume translations
    bad_vol_translations = [
        torch.Tensor([1.0, 1.0]),
        [1.0, 1.0],
        1.0,
        torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes + 1, 1),
    ]

    def tagged(good, bad):
        # pair every candidate with a flag telling whether it is valid
        return [(item, True) for item in good] + [(item, False) for item in bad]

    feature_cases = tagged(features, bad_features)
    density_cases = tagged(densities, bad_densities)
    vox_size_cases = tagged(vox_sizes, bad_vox_sizes)
    translation_cases = tagged(vol_translations, bad_vol_translations)

    for features_, features_ok in feature_cases:
        for densities_, densities_ok in density_cases:
            for vox_size, size_ok in vox_size_cases:
                for vol_translation, trans_ok in translation_cases:
                    kwargs = {
                        "features": features_,
                        "densities": densities_,
                        "voxel_size": vox_size,
                        "volume_translation": vol_translation,
                    }
                    if all((size_ok, trans_ok, features_ok, densities_ok)):
                        # all arguments are valid -> must not throw
                        Volumes(**kwargs)
                    else:
                        # at least one invalid argument -> ValueError
                        with self.assertRaises(ValueError):
                            Volumes(**kwargs)
def test_constructor_for_padded_lists(self):
    """
    Check the `Volumes` constructor for padded / list inputs: features
    whose batch size, spatial size or device disagree with the densities
    have to raise `ValueError`, matching features have to be accepted.
    """
    device = torch.device("cuda:0")
    diff_device = torch.device("cpu")
    num_volumes = 3
    num_channels = 4
    size = (6, 8, 10)
    diff_size = (6, 8, 11)

    def randn(shape, on=device):
        # random float32 tensor of the given shape on the requested device
        return torch.randn(size=list(shape), device=on, dtype=torch.float32)

    density_shape = (num_volumes, 1, *size)
    feature_shape = (num_volumes, num_channels, *size)
    extra_batch_shape = (num_volumes + 1, num_channels, *size)
    diff_size_shape = (num_volumes, num_channels, *diff_size)

    # valid densities
    ok_densities = [
        randn(density_shape).unbind(0),
        randn(density_shape),
    ]

    # features that do not match the densities
    bad_features = [
        randn(extra_batch_shape).unbind(0),  # list with diff batch size
        randn(extra_batch_shape),  # diff batch size
        randn(diff_size_shape).unbind(0),  # list with different size
        randn(diff_size_shape),  # different size
        randn(feature_shape, on=diff_device),  # different device
        randn(feature_shape, on=diff_device).unbind(0),  # list with different device
    ]

    # features that match the densities
    ok_features = [
        randn(feature_shape).unbind(0),  # list of features of correct size
        randn(feature_shape),
    ]

    for densities in ok_densities:
        for features in bad_features:
            with self.assertRaises(ValueError):
                Volumes(densities=densities, features=features)
        for features in ok_features:
            Volumes(densities=densities, features=features)
def test_get_item(
    self,
    num_volumes=5,
    num_channels=4,
    volume_size=(10, 13, 8),
    dtype=torch.float32,
):
    """
    Index a `Volumes` object with an int, a list, a slice, a bool tensor
    and an int tensor, and check the selected volumes. Float tensor and
    float scalar indices have to raise `IndexError`.
    """
    device = torch.device("cuda:0")

    # at least 3 volumes are needed by the indices used below
    num_volumes = max(num_volumes, 3)

    features = torch.randn(
        size=[num_volumes, num_channels, *volume_size],
        device=device,
        dtype=torch.float32,
    )
    densities = torch.randn(
        size=[num_volumes, 1, *volume_size], device=device, dtype=torch.float32
    )

    features_list, rand_sizes = TestVolumes._random_volume_list(
        num_volumes, 3, volume_size, num_channels, device
    )
    densities_list, _ = TestVolumes._random_volume_list(
        num_volumes, 3, volume_size, 1, device, rand_sizes=rand_sizes
    )

    volume_translation = -torch.randn(num_volumes, 3).type_as(features)
    voxel_size = torch.rand(num_volumes, 1).type_as(features) + 0.5

    # (features, densities) pairs: no features, padded tensors, lists
    input_pairs = (
        (None, densities),
        (features, densities),
        (features_list, densities_list),
    )

    for features_, densities_ in input_pairs:
        v = Volumes(
            features=features_,
            densities=densities_,
            volume_translation=volume_translation,
            voxel_size=voxel_size,
        )

        # int index
        picked = v[1]
        self.assertEqual(len(picked), 1)
        self._check_indexed_volumes(v, picked, [(0, 1)])

        # list index
        list_index = [1, 2]
        picked = v[list_index]
        self.assertEqual(len(picked), len(list_index))
        self._check_indexed_volumes(v, picked, enumerate(list_index))

        # slice index
        picked = v[0:2]
        self.assertEqual(len(picked), 2)
        self._check_indexed_volumes(v, picked, [(0, 0), (1, 1)])

        # bool tensor
        mask = (torch.rand(num_volumes) > 0.5).to(device)
        mask[:2] = True  # make sure something is selected
        picked = v[mask]
        self.assertEqual(len(picked), mask.sum())
        selected_to_source = zip(
            torch.arange(mask.sum()),
            torch.nonzero(mask, as_tuple=False).squeeze(),
        )
        self._check_indexed_volumes(v, picked, selected_to_source)

        # int tensor
        int_index = torch.tensor([1, 2], dtype=torch.int64, device=device)
        picked = v[int_index]
        self.assertEqual(len(picked), int_index.numel())
        self._check_indexed_volumes(v, picked, enumerate(int_index.tolist()))

        # invalid index: float tensor
        float_index = torch.tensor([1, 0, 1], dtype=torch.float32, device=device)
        with self.assertRaises(IndexError):
            picked = v[float_index]

        # invalid index: floating point scalar
        with self.assertRaises(IndexError):
            picked = v[1.2]
def test_feature_density_setters(self):
    """
    Tests getters and setters for padded/list representations.

    The test builds `Volumes` from lists of randomly sized densities and
    features and then checks that:
        * the padded / list getters agree with the inputs,
        * `_set_features` / `_set_densities` accept matching inputs and
          raise `ValueError` for mismatching ones,
        * `update_padded` returns a structure with the new tensors that
          shares the grid sizes, the local-to-world transform and the
          device with the original.
    """
    device = torch.device("cuda:0")
    diff_device = torch.device("cpu")
    num_volumes = 30
    num_channels = 4
    K = 20

    densities = []
    features = []
    grid_sizes = []
    diff_grid_sizes = []

    for _ in range(num_volumes):
        grid_size = torch.randint(K - 1, size=(3,)).long() + 1
        densities.append(
            torch.rand((1, *grid_size), device=device, dtype=torch.float32)
        )
        features.append(
            torch.rand((num_channels, *grid_size), device=device, dtype=torch.float32)
        )
        grid_sizes.append(grid_size)

        # a size that is strictly larger than grid_size in every dimension
        diff_grid_size = (
            copy.deepcopy(grid_size) + torch.randint(2, size=(3,)).long() + 1
        )
        diff_grid_sizes.append(diff_grid_size)

    grid_sizes = torch.stack(grid_sizes).to(device)
    diff_grid_sizes = torch.stack(diff_grid_sizes).to(device)

    volumes = Volumes(densities=densities, features=features)
    self.assertClose(volumes.get_grid_sizes(), grid_sizes)

    # test the getters
    features_padded = volumes.features()
    densities_padded = volumes.densities()
    features_list = volumes.features_list()
    densities_list = volumes.densities_list()
    for x_pad, x_list in zip(
        (densities_padded, features_padded, densities_padded, features_padded),
        (densities_list, features_list, densities, features),
    ):
        self._check_padded(x_pad, x_list, grid_sizes)

    # test feature setters
    features_new = [
        torch.rand((num_channels, *grid_size), device=device, dtype=torch.float32)
        for grid_size in grid_sizes
    ]
    volumes._set_features(features_new)
    features_new_list = volumes.features_list()
    features_new_padded = volumes.features()
    for x_pad, x_list in zip(
        (features_new_padded, features_new_padded),
        (features_new, features_new_list),
    ):
        self._check_padded(x_pad, x_list, grid_sizes)

    # wrong features to update
    bad_features_new = [
        [
            torch.rand(
                (num_channels, *grid_size), device=diff_device, dtype=torch.float32
            )
            for grid_size in diff_grid_sizes
        ],
        torch.rand(
            (num_volumes, num_channels, K + 1, K, K),
            device=device,
            dtype=torch.float32,
        ),
        None,
    ]
    for bad_features_new_ in bad_features_new:
        with self.assertRaises(ValueError):
            # the feature setter is the one under test here; the previous
            # version called the density setter, so the validation of
            # `_set_features` was never exercised
            volumes._set_features(bad_features_new_)

    # test density setters
    densities_new = [
        torch.rand((1, *grid_size), device=device, dtype=torch.float32)
        for grid_size in grid_sizes
    ]
    volumes._set_densities(densities_new)
    densities_new_list = volumes.densities_list()
    densities_new_padded = volumes.densities()
    for x_pad, x_list in zip(
        (densities_new_padded, densities_new_padded),
        (densities_new, densities_new_list),
    ):
        self._check_padded(x_pad, x_list, grid_sizes)

    # wrong densities to update
    bad_densities_new = [
        [
            torch.rand((1, *grid_size), device=diff_device, dtype=torch.float32)
            for grid_size in diff_grid_sizes
        ],
        torch.rand(
            (num_volumes, 1, K + 1, K, K), device=device, dtype=torch.float32
        ),
        None,
    ]
    for bad_densities_new_ in bad_densities_new:
        with self.assertRaises(ValueError):
            volumes._set_densities(bad_densities_new_)

    # test update_padded
    volumes = Volumes(densities=densities, features=features)
    volumes_updated = volumes.update_padded(
        densities_new, new_features=features_new
    )
    densities_new_list = volumes_updated.densities_list()
    densities_new_padded = volumes_updated.densities()
    features_new_list = volumes_updated.features_list()
    features_new_padded = volumes_updated.features()
    for x_pad, x_list in zip(
        (
            densities_new_padded,
            densities_new_padded,
            features_new_padded,
            features_new_padded,
        ),
        (densities_new, densities_new_list, features_new, features_new_list),
    ):
        self._check_padded(x_pad, x_list, grid_sizes)

    # the updated structure has to share these objects with the original
    self.assertIs(volumes.get_grid_sizes(), volumes_updated.get_grid_sizes())
    self.assertIs(
        volumes.get_local_to_world_coords_transform(),
        volumes_updated.get_local_to_world_coords_transform(),
    )
    self.assertIs(volumes.device, volumes_updated.device)
def test_to(self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32):
    """
    Test the moving of the volumes from/to gpu and cpu.

    The first part checks `Volumes.to` with both `str` and `torch.device`
    arguments: moving to the current device has to return the same
    object, moving to another device has to return a new object and
    leave the original in place. The second part checks the device
    placement and the identity / separation of the internal tensors
    after chains of `.cpu()` / `.cuda()` calls.
    """
    features = torch.randn(
        size=[num_volumes, num_channels, *size], dtype=torch.float32
    )
    densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype)
    volumes = Volumes(densities=densities, features=features)

    # Test support for str and torch.device
    cpu_device = torch.device("cpu")

    converted_volumes = volumes.to("cpu")
    self.assertEqual(cpu_device, converted_volumes.device)
    self.assertEqual(cpu_device, volumes.device)
    self.assertIs(volumes, converted_volumes)

    converted_volumes = volumes.to(cpu_device)
    self.assertEqual(cpu_device, converted_volumes.device)
    self.assertEqual(cpu_device, volumes.device)
    self.assertIs(volumes, converted_volumes)

    cuda_device = torch.device("cuda:0")

    converted_volumes = volumes.to("cuda:0")
    self.assertEqual(cuda_device, converted_volumes.device)
    self.assertEqual(cpu_device, volumes.device)
    self.assertIsNot(volumes, converted_volumes)

    converted_volumes = volumes.to(cuda_device)
    self.assertEqual(cuda_device, converted_volumes.device)
    self.assertEqual(cpu_device, volumes.device)
    self.assertIsNot(volumes, converted_volumes)

    # Test device placement of internal tensors
    features = features.to(cuda_device)
    # move the densities themselves; the previous version assigned
    # `features.to(cuda_device)` here, which made the densities the very
    # same tensor as the features
    densities = densities.to(cuda_device)

    for features_ in (features, None):
        volumes = Volumes(densities=densities, features=features_)

        cpu_volumes = volumes.cpu()
        cuda_volumes = cpu_volumes.cuda()
        cuda_volumes2 = cuda_volumes.cuda()
        cpu_volumes2 = cuda_volumes2.cpu()

        for volumes1, volumes2 in itertools.combinations(
            (volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2
        ):
            if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
                # checks that we do not copy if the devices stay the same
                assert_fun = self.assertIs
            else:
                assert_fun = self.assertSeparate
            assert_fun(volumes1._densities, volumes2._densities)
            if features_ is not None:
                assert_fun(volumes1._features, volumes2._features)
            for volumes_ in (volumes1, volumes2):
                if volumes_ in (cpu_volumes, cpu_volumes2):
                    self._check_vars_on_device(volumes_, cpu_device)
                else:
                    self._check_vars_on_device(volumes_, cuda_device)
def test_coord_grid_convention_heterogeneous(self, num_channels=4, dtype=torch.float32):
    """
    Check that for a list of 2 trivial volumes with spatial sizes
    DxHxW=(5x7x5, 3x5x5):
        1) xyz_world=((-2, 3, -2), (-2, -2, 1)) maps to
           xyz_local=((-1, 1, -1), (-1, -1, 1)).
        2) The central slices of both the world and the local coordinate
           grid of each volume hold the value 0 in the corresponding
           coordinate.
        3) grid_sample(world_coordinate_grid, local_coordinate_grid),
           evaluated on the grids cropped to the size of each volume,
           reproduces the cropped world_coordinate_grid, i.e. the local
           coordinate grid follows the grid_sample coordinate convention.
    """
    device = torch.device("cuda:0")
    sizes = [(5, 7, 5), (3, 5, 5)]
    densities_list = [
        torch.randn(size=[num_channels, *size], device=device, dtype=torch.float32)
        for size in sizes
    ]

    v_trivial = Volumes(densities=densities_list)

    def one_point_per_volume(rows):
        # (num_volumes, 3) list of points -> tensor shaped (num_volumes, 1, 3)
        return torch.tensor(rows, device=device, dtype=torch.float32)[:, None]

    # 1) border point locations
    pts_world = one_point_per_volume([[-2.0, 3.0, -2.0], [-2.0, -2.0, 1.0]])
    pts_local = v_trivial.world_to_local_coords(pts_world)
    pts_local_expected = one_point_per_volume(
        [[-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]]
    )
    self.assertClose(pts_local, pts_local_expected)

    # 2) central slices carry zero in the matching coordinate
    grid_world = v_trivial.get_coord_grid(world_coordinates=True)
    grid_local = v_trivial.get_coord_grid(world_coordinates=False)
    for grid in (grid_world, grid_local):
        for vol_idx, (depth, height, width) in enumerate(sizes):
            # central indices: (2, 3, 2) for 5x7x5 and (1, 2, 2) for 3x5x5
            d_mid, h_mid, w_mid = depth // 2, height // 2, width // 2
            central_slices = (
                grid[vol_idx, :, :, w_mid, 0],  # x
                grid[vol_idx, :, h_mid, :, 1],  # y
                grid[vol_idx, d_mid, :, :, 2],  # z
            )
            for coord_line in central_slices:
                self.assertClose(
                    coord_line, torch.zeros_like(coord_line), atol=1e-7
                )

    # 3) sampling the cropped world grid at the cropped local grid
    # locations has to return the cropped world grid itself
    for grid_world_, grid_local_, (depth, height, width) in zip(
        grid_world, grid_local, sizes
    ):
        grid_world_crop = grid_world_[:depth, :height, :width, :][None]
        grid_local_crop = grid_local_[:depth, :height, :width, :][None]
        resampled = torch.nn.functional.grid_sample(
            grid_world_crop.permute(0, 4, 1, 2, 3),
            grid_local_crop,
            align_corners=True,
        )
        grid_world_crop_resampled = resampled.permute(0, 2, 3, 4, 1)
        self.assertClose(grid_world_crop_resampled, grid_world_crop, atol=1e-7)