def test_coords_manager(self):
    """Exercise CoordsManager map creation on a small, fixed coordinate set.

    Initializes a coordinate map from seven rows that contain duplicates,
    then derives a strided map, two transposed-strided maps (regular and
    forced) and a reduction map from it. Map sizes and batch bookkeeping
    are asserted; the manager state is also printed after each step for
    manual inspection.
    """
    # NOTE(review): another definition named test_coords_manager is visible
    # in this file; if both live in the same class the later one replaces
    # the earlier one -- confirm which is intended to run.
    key = CoordsKey(D=1)
    key.setTensorStride(1)
    cm = CoordsManager(D=1)

    # Seven rows, four of them distinct. The first column is presumably the
    # batch index (the assertions below expect batch indices {0, 1}).
    coords = torch.IntTensor(
        [[0, 1], [0, 1], [0, 2], [0, 2], [1, 0], [1, 0], [1, 1]])
    unique_coords = torch.unique(coords, dim=0)

    # Initialize map. With duplicates disallowed, the returned mapping is
    # expected to hold one entry per unique row.
    mapping, _inverse_mapping = cm.initialize(
        coords, key, force_remap=True, allow_duplicate_coords=False)
    # assertEqual (rather than assertTrue(a == b)) reports both values
    # when the check fails.
    self.assertEqual(len(unique_coords), len(mapping))

    print(mapping, len(mapping))
    cm.print_diagnostics(key)
    print(cm)
    self.assertEqual(cm.get_batch_size(), 2)
    self.assertEqual(cm.get_batch_indices(), {0, 1})

    # Create a strided map
    stride_key = cm.stride(key, [4])
    strided_coords = cm.get_coords(stride_key)
    self.assertEqual(len(strided_coords), 2)

    cm.print_diagnostics(key)
    print(cm)

    # Create a transposed stride map
    transposed_key = cm.transposed_stride(stride_key, [2], [3], [1])
    print('Transposed Stride: ', cm.get_coords(transposed_key))
    print(cm)

    # Create the transposed stride map again, this time forcing creation
    transposed_key = cm.transposed_stride(
        stride_key, [2], [3], [1], force_creation=True)
    print('Forced Transposed Stride: ', cm.get_coords(transposed_key))
    print(cm)

    # Create a reduction map (kept under its own name so the initial key
    # is not silently overwritten)
    reduction_key = cm.reduce()
    print('Reduction: ', cm.get_coords(reduction_key))
    print(cm)

    print('Reduction mapping: ', cm.get_row_indices_per_batch(stride_key))
    print(cm)
def test_coords_manager(self):
    """Smoke-run the CoordsManager map operations on random coordinates.

    Builds an initial coordinate map from a random 20 x 2 integer tensor and
    then creates strided, transposed-strided (regular and forced) and
    reduction maps. Nothing is asserted: each intermediate result and the
    manager state are printed for manual inspection.
    """
    coords_key = CoordsKey(D=1)
    coords_key.setTensorStride(1)
    manager = CoordsManager(D=1)

    # Random 20 x 2 values, scaled by 10 and converted to int.
    coords = (torch.rand(20, 2) * 10).int()
    print(coords)
    distinct_rows = torch.unique(coords, dim=0)
    print('Num unique: ', distinct_rows.shape)

    # Build the initial map from the raw coordinates.
    mapping = manager.initialize(
        coords,
        coords_key,
        force_remap=True,
        allow_duplicate_coords=False,
    )
    print(mapping, len(mapping))
    manager.print_diagnostics(coords_key)
    print(manager)
    print(manager.get_batch_size())
    print(manager.get_batch_indices())

    # Derive a strided map from the initial one.
    stride_key = manager.stride(coords_key, [4])
    strided = manager.get_coords(stride_key)
    print('Stride: ', strided)
    manager.print_diagnostics(coords_key)
    print(manager)

    # Transposed stride map: first a regular request, then a forced one.
    transposed_runs = (
        ('Transposed Stride: ', {}),
        ('Forced Transposed Stride: ', {'force_creation': True}),
    )
    for label, extra_kwargs in transposed_runs:
        transposed_key = manager.transposed_stride(
            stride_key, [2], [3], [1], **extra_kwargs)
        print(label, manager.get_coords(transposed_key))
        print(manager)

    # Reduction map.
    reduction_key = manager.reduce()
    reduced = manager.get_coords(reduction_key)
    print('Reduction: ', reduced)
    print(manager)

    row_indices = manager.get_row_indices_per_batch(stride_key)
    print('Reduction mapping: ', row_indices)
    print(manager)