def __init__(self,
             device: torch.device,
             key_dtype: torch.dtype,
             value_dtype: torch.dtype,
             max_size: int = -1) -> None:
    """Create the hash table wrapper for ``device``.

    On CPU (``device.type == "cpu"``) this wrapper allocates no key/value
    storage: ``keys_data`` and ``values_data`` are ``None`` and empty
    ``tv.Tensor`` placeholders are handed to ``_HashTable``. On any other
    device two fixed-size buffers of ``max_size`` elements are allocated
    on ``device`` and handed to ``_HashTable`` instead.

    Args:
        device: device the table lives on.
        key_dtype: dtype of the keys; looked up in
            ``_TORCH_DTYPE_TO_ITEMSIZE``.
        value_dtype: dtype of the values; looked up in
            ``_TORCH_DTYPE_TO_ITEMSIZE``.
        max_size: capacity of the fixed-size buffers. Must be ``> 0`` on
            non-CPU devices (the assertion message suggests roughly twice
            the number of keys). Not used on CPU.

    Raises:
        AssertionError: on a non-CPU device when ``max_size <= 0``.
    """
    on_cpu = device.type == "cpu"
    self.is_cpu = on_cpu
    self.key_dtype = key_dtype
    self.value_dtype = value_dtype

    if on_cpu:
        # No storage is owned by the wrapper on CPU; the native table gets
        # empty placeholders.
        self.keys_data = None
        self.values_data = None
        keys_tv = tv.Tensor()
        values_tv = tv.Tensor()
    else:
        assert max_size > 0, "you must provide max_size for fixed-size cuda hash table, usually *2 of num of keys"
        # NOTE: ``device.type`` was already read above, so a ``None`` device
        # fails before reaching this check.
        assert device is not None, "you must specify device for cuda hash table."
        self.keys_data = torch.empty([max_size],
                                     dtype=key_dtype,
                                     device=device)
        self.values_data = torch.empty([max_size],
                                       dtype=value_dtype,
                                       device=device)
        keys_tv = torch_tensor_to_tv(self.keys_data)
        values_tv = torch_tensor_to_tv(self.values_data)

    # Stream handle 0 is used on CPU; otherwise ask for the current stream.
    stream = 0 if on_cpu else get_current_stream()

    self.key_itemsize = _TORCH_DTYPE_TO_ITEMSIZE[key_dtype]
    self.value_itemsize = _TORCH_DTYPE_TO_ITEMSIZE[value_dtype]
    # Value dtypes accepted by ``assign_arange_`` (checked there on CPU).
    self._valid_value_dtype_for_arange = {torch.int32, torch.int64}
    self._table = _HashTable(on_cpu, self.key_itemsize,
                             self.value_itemsize, keys_tv, values_tv,
                             stream)
def items(self, max_size: int = -1):
    """Dump the table content into freshly allocated key/value tensors.

    Args:
        max_size: number of slots allocated for the returned tensors on
            non-CPU devices; ``-1`` means "use the capacity of
            ``values_data``". Ignored on CPU, where the size reported by
            ``_table.size_cpu()`` is used instead.

    Returns:
        ``(keys, values, count)``.

        * CPU: ``keys`` / ``values`` have ``size_cpu()`` elements and
          ``count`` is an int64 tensor holding that size.
        * Other devices: ``keys`` / ``values`` have ``max_size`` elements
          and ``count`` is a one-element, zero-initialised tensor on the
          table's device that is passed to the native call. Presumably
          the native call stores the number of valid items there; TODO
          confirm against the native implementation.

    Raises:
        NotImplementedError: on a non-CPU device when the key item size
            is neither 4 nor 8 bytes.
    """
    if self.is_cpu:
        stream = 0
        # The CPU call receives an empty counter tensor.
        count_tv = tv.Tensor()
        max_size = self._table.size_cpu()
        count = torch.tensor([max_size], dtype=torch.int64)
        keys = torch.empty([max_size], dtype=self.key_dtype)
        values = torch.empty([max_size], dtype=self.value_dtype)
    else:
        stream = get_current_stream()
        assert self.values_data is not None
        table_device = self.values_data.device

        # Counter width follows the key width; the native side sees it as
        # an unsigned integer of the same size.
        counter_kinds = {
            4: (torch.int32, tv.uint32),
            8: (torch.int64, tv.uint64),
        }
        kind = counter_kinds.get(self.key_itemsize)
        if kind is None:
            raise NotImplementedError
        counter_dtype, counter_tv_dtype = kind
        count = torch.zeros([1], dtype=counter_dtype, device=table_device)
        count_tv = torch_tensor_to_tv(count, dtype=counter_tv_dtype)

        if max_size == -1:
            max_size = self.values_data.shape[0]
        keys = torch.empty([max_size],
                           dtype=self.key_dtype,
                           device=table_device)
        values = torch.empty([max_size],
                             dtype=self.value_dtype,
                             device=table_device)

    self._table.items(torch_tensor_to_tv(keys), torch_tensor_to_tv(values),
                      count_tv, stream)
    return keys, values, count
def assign_arange_(self):
    """Iterate the table and assign "arange" values to its entries.

    Equivalent to:
      1. get the keys via ``items()``,
      2. insert those keys with ``arange(keys.shape[0])`` as values.

    Returns:
        ``count``. On CPU this is an int64 tensor holding
        ``_table.size_cpu()``. On other devices it is a one-element,
        zero-initialised tensor on the table's device that is passed to
        the native call (presumably filled with the number of entries;
        TODO confirm against the native implementation).

    Raises:
        AssertionError: on CPU when ``value_dtype`` is not one of the
            dtypes in ``_valid_value_dtype_for_arange``.
        NotImplementedError: on a non-CPU device when the key item size
            is neither 4 nor 8 bytes.
    """
    if self.is_cpu:
        # The value dtype is only validated on the CPU path.
        assert self.value_dtype in self._valid_value_dtype_for_arange
        stream = 0
        # The CPU call receives an empty counter tensor.
        count_tv = tv.Tensor()
        table_size = self._table.size_cpu()
        count = torch.tensor([table_size], dtype=torch.int64)
    else:
        stream = get_current_stream()
        assert self.values_data is not None
        table_device = self.values_data.device

        # Counter width follows the key width; the native side sees it as
        # an unsigned integer of the same size.
        counter_kinds = {
            4: (torch.int32, tv.uint32),
            8: (torch.int64, tv.uint64),
        }
        kind = counter_kinds.get(self.key_itemsize)
        if kind is None:
            raise NotImplementedError
        counter_dtype, counter_tv_dtype = kind
        count = torch.zeros([1], dtype=counter_dtype, device=table_device)
        count_tv = torch_tensor_to_tv(count, dtype=counter_tv_dtype)

    self._table.assign_arange_(count_tv, stream)
    return count
def insert_exist_keys(self, keys: torch.Tensor, values: torch.Tensor):
    """Insert key/value pairs whose keys already exist in the table.

    Args:
        keys: keys to update.
        values: values to store for ``keys``.

    Returns:
        A uint8 tensor with one entry per key, allocated on
        ``keys.device`` and filled by the native call, telling whether
        each insert succeeded.
    """
    keys_tv = torch_tensor_to_tv(keys)
    values_tv = torch_tensor_to_tv(values)
    # Stream handle 0 is used on CPU; otherwise ask for the current stream.
    stream = 0 if self.is_cpu else get_current_stream()

    num_keys = keys.shape[0]
    success_flags = torch.empty([num_keys],
                                dtype=torch.uint8,
                                device=keys.device)
    self._table.insert_exist_keys(keys_tv, values_tv,
                                  torch_tensor_to_tv(success_flags), stream)
    return success_flags
def insert(self, keys: torch.Tensor, values: Optional[torch.Tensor] = None):
    """Insert ``keys`` (and optionally ``values``) into the table.

    Args:
        keys: keys to insert.
        values: values to store for ``keys``. When ``None`` an empty
            ``tv.Tensor`` is passed to the native call: only the keys are
            inserted and their values are undefined.

    Returns:
        Whatever the native ``_table.insert`` call returns.
    """
    keys_tv = torch_tensor_to_tv(keys)
    if values is None:
        values_tv = tv.Tensor()
    else:
        values_tv = torch_tensor_to_tv(values)
    # Stream handle 0 is used on CPU; otherwise ask for the current stream.
    stream = 0 if self.is_cpu else get_current_stream()
    return self._table.insert(keys_tv, values_tv, stream)
def query(self, keys: torch.Tensor, values: Optional[torch.Tensor] = None):
    """Look up ``keys`` in the table.

    Args:
        keys: keys to look up.
        values: optional output buffer for the looked-up values. When
            ``None``, a new tensor of ``self.value_dtype`` with one slot
            per key is allocated on ``keys.device``.

    Returns:
        ``(values, is_empty)``: the value buffer handed to the native
        query, and a uint8 tensor with one entry per key filled by the
        native query. NOTE(review): the previous docstring described the
        flag as "whether query success" while the variable is named
        ``is_empty``; the polarity is decided by the native code, which is
        not visible here - confirm before relying on it.
    """
    keys_tv = torch_tensor_to_tv(keys)
    num_keys = keys.shape[0]
    if values is None:
        values = torch.empty([num_keys],
                             dtype=self.value_dtype,
                             device=keys.device)
    values_tv = torch_tensor_to_tv(values)
    # Stream handle 0 is used on CPU; otherwise ask for the current stream.
    stream = 0 if self.is_cpu else get_current_stream()

    is_empty = torch.empty([num_keys], dtype=torch.uint8, device=keys.device)
    self._table.query(keys_tv, values_tv, torch_tensor_to_tv(is_empty),
                      stream)
    return values, is_empty
def generate_voxel_with_id(self,
                           pc: torch.Tensor,
                           clear_voxels: bool = True,
                           empty_mean: bool = False):
    """Generate voxels/indices/num_point_per_voxel/pc_voxel_ids from a point cloud.

    The work is done by ``SpconvOps.point2voxel_cuda`` on non-CPU devices
    and by ``SpconvOps.point2voxel_cpu`` on CPU. Both write into the
    preallocated ``self.voxels`` / ``self.indices`` / ``self.num_per_voxel``
    buffers; the returned tensors are slices of those buffers, so they
    share storage with them.

    Args:
        pc: [N, 3+] point cloud. Its device type must match
            ``self.device``.
        clear_voxels: if True, call zero on voxels.
        empty_mean: if True, fill empty locations of voxels with mean.

    Returns:
        voxels: voxels
        indices: quantized coords
        num_per_voxel: number of points in a voxel
        pc_voxel_id: voxel id for every point. if not exists, -1.

    Raises:
        AssertionError: if the device type of ``pc`` differs from
            ``self.device.type``.
    """
    assert pc.device.type == self.device.type, "your pc device is wrong"
    num_points = pc.shape[0]
    with torch.no_grad():
        # Per-point output, filled by the native op.
        pc_voxel_id = torch.empty([num_points],
                                  dtype=torch.int64,
                                  device=self.device)
        pc_voxel_id_tv = torch_tensor_to_tv(pc_voxel_id)

        # Conversions shared by both back ends.
        pc_tv = torch_tensor_to_tv(pc)
        voxels_tv = torch_tensor_to_tv(self.voxels)
        indices_tv = torch_tensor_to_tv(self.indices)
        num_per_voxel_tv = torch_tensor_to_tv(self.num_per_voxel)

        if self.device.type != "cpu":
            # Scratch buffers for the CUDA op: a hash table with twice as
            # many slots as points (each slot is two int64, viewed as one
            # 128-bit element) and one index per point.
            expected_hash_data_num = num_points * 2
            hashdata = torch.empty([expected_hash_data_num, 2],
                                   dtype=torch.int64,
                                   device=pc.device)
            point_indice_data = torch.empty([num_points],
                                            dtype=torch.int64,
                                            device=pc.device)
            hashdata_tv = torch_tensor_to_tv(hashdata,
                                             dtype=tv.custom128,
                                             shape=[hashdata.shape[0]])
            point_indice_data_tv = torch_tensor_to_tv(point_indice_data)

            stream = get_current_stream()
            with torch.cuda.device(pc.device):
                res = SpconvOps.point2voxel_cuda(
                    pc_tv, voxels_tv, indices_tv, num_per_voxel_tv,
                    hashdata_tv, point_indice_data_tv, pc_voxel_id_tv,
                    self.vsize, self.grid_size, self.grid_stride,
                    self.coors_range, empty_mean, clear_voxels, stream)
        else:
            # The CPU op takes no stream, so none is queried here; this
            # matches the other methods in this file, which only ask for
            # the current stream on non-CPU devices.
            hashdata_tv = torch_tensor_to_tv(self.hashdata, dtype=tv.int32)
            res = SpconvOps.point2voxel_cpu(
                pc_tv, voxels_tv, indices_tv, num_per_voxel_tv, hashdata_tv,
                pc_voxel_id_tv, self.vsize, self.grid_size,
                self.grid_stride, self.coors_range, empty_mean,
                clear_voxels)

        # The first result's leading dimension is the number of voxels
        # produced; only that prefix of the buffers is returned.
        num_voxels = res[0].shape[0]
        return (self.voxels[:num_voxels], self.indices[:num_voxels],
                self.num_per_voxel[:num_voxels], pc_voxel_id)