Ejemplo n.º 1
0
    def __init__(self, device: torch.device, key_dtype: torch.dtype,
                 value_dtype: torch.dtype,
                 max_size: int = -1) -> None:
        """Create a hash table on ``device``.

        Args:
            device: target device. A ``"cpu"`` device creates a table with no
                preallocated key/value storage; any other device type creates
                a fixed-size table and requires ``max_size``.
            key_dtype: torch dtype of the keys; looked up in
                ``_TORCH_DTYPE_TO_ITEMSIZE`` to get the key width in bytes.
            value_dtype: torch dtype of the values; looked up the same way.
            max_size: capacity of the fixed-size (non-CPU) table, usually twice
                the number of keys. Ignored for CPU tables.

        Raises:
            AssertionError: if ``device`` is None, or if ``max_size <= 0`` for
                a non-CPU device.
        """
        # Validate ``device`` before dereferencing it. Previously this check ran
        # after ``device.type`` had already been read, so it could never fire
        # and a None device surfaced as an opaque AttributeError.
        assert device is not None, "you must specify device for cuda hash table."
        is_cpu = device.type == "cpu"
        self.is_cpu = is_cpu
        self.key_dtype = key_dtype
        self.value_dtype = value_dtype
        key_data_tv = tv.Tensor()
        value_data_tv = tv.Tensor()
        stream = 0
        if is_cpu:
            # CPU tables are built without preallocated key/value buffers;
            # empty tv tensors are passed to the native table instead.
            self.keys_data = None
            self.values_data = None
        else:
            assert max_size > 0, "you must provide max_size for fixed-size cuda hash table, usually *2 of num of keys"
            # Fixed-size storage is allocated here and handed to the native
            # table as tv views; ``self`` keeps the torch tensors alive.
            self.keys_data = torch.empty([max_size], dtype=key_dtype, device=device)
            self.values_data = torch.empty([max_size], dtype=value_dtype, device=device)
            key_data_tv = torch_tensor_to_tv(self.keys_data)
            value_data_tv = torch_tensor_to_tv(self.values_data)
            stream = get_current_stream()
        self.key_itemsize = _TORCH_DTYPE_TO_ITEMSIZE[self.key_dtype]
        self.value_itemsize = _TORCH_DTYPE_TO_ITEMSIZE[self.value_dtype]
        # Value dtypes that assign_arange_ accepts on the CPU path.
        self._valid_value_dtype_for_arange = {torch.int32, torch.int64}

        self._table = _HashTable(is_cpu, self.key_itemsize, self.value_itemsize,
                                 key_data_tv, value_data_tv, stream)
Ejemplo n.º 2
0
    def items(self, max_size: int = -1):
        """Export the table's keys and values into freshly allocated tensors.

        Args:
            max_size: length of the output ``keys``/``values`` buffers on the
                GPU path; ``-1`` means use the preallocated capacity
                (``self.values_data.shape[0]``). On the CPU path this argument
                is ignored and the table's ``size_cpu()`` is used instead.

        Returns:
            ``(keys, values, count)``.
            On GPU, ``count`` is a one-element device tensor (int32 for 4-byte
            keys, int64 for 8-byte keys) that is handed to the native table;
            presumably it receives the number of valid entries — TODO confirm.
            On CPU, ``count`` is a one-element int64 tensor holding
            ``size_cpu()``, and the output buffers have exactly that length.

        Raises:
            NotImplementedError: on GPU when the key itemsize is not 4 or 8.
        """
        if self.is_cpu:
            stream = 0
            # On CPU an empty tv tensor is passed as the counter argument.
            count_tv = tv.Tensor()
            max_size = self._table.size_cpu()
            count = torch.tensor([max_size], dtype=torch.int64)
            keys = torch.empty([max_size], dtype=self.key_dtype)
            values = torch.empty([max_size], dtype=self.value_dtype)
        else:
            stream = get_current_stream()
            storage = self.values_data
            assert storage is not None
            # The counter's width follows the key width.
            counter_dtypes = {
                4: (torch.int32, tv.uint32),
                8: (torch.int64, tv.uint64),
            }
            picked = counter_dtypes.get(self.key_itemsize)
            if picked is None:
                raise NotImplementedError
            torch_count_dtype, tv_count_dtype = picked
            count = torch.zeros([1], dtype=torch_count_dtype, device=storage.device)
            count_tv = torch_tensor_to_tv(count, dtype=tv_count_dtype)
            capacity = storage.shape[0] if max_size == -1 else max_size
            keys = torch.empty([capacity], dtype=self.key_dtype, device=storage.device)
            values = torch.empty([capacity], dtype=self.value_dtype, device=storage.device)
        self._table.items(torch_tensor_to_tv(keys), torch_tensor_to_tv(values),
                          count_tv, stream)
        return keys, values, count
Ejemplo n.º 3
0
    def assign_arange_(self):
        """Overwrite the stored values with an "arange" sequence while iterating the table.

        Equivalent to: 1. get the keys via ``items()``, 2. re-insert those keys
        with ``arange(keys.shape[0])`` as their values.

        Returns:
            ``count``. On GPU, a one-element device tensor (int32 for 4-byte
            keys, int64 for 8-byte keys) passed to the native table. On CPU, a
            one-element int64 tensor holding ``size_cpu()``.

        Raises:
            AssertionError: on CPU when ``value_dtype`` is not int32/int64.
            NotImplementedError: on GPU when the key itemsize is not 4 or 8.
        """
        if self.is_cpu:
            # NOTE(review): the int32/int64 value-dtype check is only applied
            # on the CPU path; confirm whether the GPU kernel needs it too.
            assert self.value_dtype in self._valid_value_dtype_for_arange
            stream = 0
            count_tv = tv.Tensor()
            count = torch.tensor([self._table.size_cpu()], dtype=torch.int64)
        else:
            stream = get_current_stream()
            storage = self.values_data
            assert storage is not None
            if self.key_itemsize == 4:
                torch_count_dtype, tv_count_dtype = torch.int32, tv.uint32
            elif self.key_itemsize == 8:
                torch_count_dtype, tv_count_dtype = torch.int64, tv.uint64
            else:
                raise NotImplementedError
            count = torch.zeros([1], dtype=torch_count_dtype, device=storage.device)
            count_tv = torch_tensor_to_tv(count, dtype=tv_count_dtype)
        self._table.assign_arange_(count_tv, stream)
        return count
Ejemplo n.º 4
0
 def insert_exist_keys(self, keys: torch.Tensor, values: torch.Tensor):
     """Insert key/value pairs only for keys that already exist in the table.

     Args:
         keys: keys to update; one flag is produced per ``keys.shape[0]``.
         values: values to write for those keys.

     Returns:
         A uint8 tensor on ``keys.device`` with one entry per key, filled by
         the native table to report whether each insert succeeded.
     """
     key_view = torch_tensor_to_tv(keys)
     value_view = torch_tensor_to_tv(values)
     stream = 0 if self.is_cpu else get_current_stream()
     success_flags = torch.empty([keys.shape[0]], dtype=torch.uint8, device=keys.device)
     self._table.insert_exist_keys(key_view, value_view,
                                   torch_tensor_to_tv(success_flags), stream)
     return success_flags
Ejemplo n.º 5
0
    def insert(self, keys: torch.Tensor, values: Optional[torch.Tensor] = None):
        """Insert ``keys`` (and optionally ``values``) into the hash table.

        If ``values`` is None, only the keys are inserted and the values stored
        for them are undefined.

        Returns:
            Whatever the native table's ``insert`` returns.
        """
        key_view = torch_tensor_to_tv(keys)
        # An empty tv tensor signals "no values" to the native table.
        value_view = tv.Tensor() if values is None else torch_tensor_to_tv(values)
        stream = get_current_stream() if not self.is_cpu else 0
        return self._table.insert(key_view, value_view, stream)
Ejemplo n.º 6
0
 def query(self, keys: torch.Tensor, values: Optional[torch.Tensor] = None):
     """Look up the value stored for each key.

     Args:
         keys: keys to look up; one result per ``keys.shape[0]``.
         values: optional output buffer. When None, a new tensor of
             ``self.value_dtype`` with ``keys.shape[0]`` entries is allocated
             on ``keys.device``; otherwise results are written into it.

     Returns:
         ``(values, is_empty)``, where ``is_empty`` is a uint8 tensor on
         ``keys.device`` with one entry per key, filled by the native table.
         NOTE(review): the name suggests nonzero means "key not found", while
         the old docstring described it as query success — confirm polarity.
     """
     key_view = torch_tensor_to_tv(keys)
     n_keys = keys.shape[0]
     out = values if values is not None else torch.empty(
         [n_keys], dtype=self.value_dtype, device=keys.device)
     out_view = torch_tensor_to_tv(out)
     stream = 0 if self.is_cpu else get_current_stream()
     is_empty = torch.empty([n_keys], dtype=torch.uint8, device=keys.device)
     self._table.query(key_view, out_view, torch_tensor_to_tv(is_empty), stream)
     return out, is_empty
Ejemplo n.º 7
0
    def generate_voxel_with_id(self,
                               pc: torch.Tensor,
                               clear_voxels: bool = True,
                               empty_mean: bool = False):
        """Generate voxels/indices/num_point_per_voxel/pc_voxel_ids from a
        point cloud.

        Args:
            pc: [N, 3+] point cloud. Must be on the same device type as
                ``self.device``.
            clear_voxels: if True, call zero on voxels.
            empty_mean: if True, fill empty locations of voxels with the mean.

        Returns:
            voxels: voxels (the first ``num_voxels`` rows of ``self.voxels``)
            indices: quantized coords
            num_per_voxel: number of points in a voxel
            pc_voxel_id: voxel id for every point. if not exists, -1.
        """
        assert pc.device.type == self.device.type, "your pc device is wrong"
        with torch.no_grad():
            num_points = pc.shape[0]
            pc_voxel_id = torch.empty([num_points],
                                      dtype=torch.int64,
                                      device=self.device)
            pc_voxel_id_tv = torch_tensor_to_tv(pc_voxel_id)
            # tv views needed by both the CUDA and the CPU kernels.
            pc_tv = torch_tensor_to_tv(pc)
            voxels_tv = torch_tensor_to_tv(self.voxels)
            indices_tv = torch_tensor_to_tv(self.indices)
            num_per_voxel_tv = torch_tensor_to_tv(self.num_per_voxel)

            if self.device.type != "cpu":
                # Per-call scratch buffers: hash storage with twice as many
                # rows as points (each [2] x int64 row viewed as one custom128
                # element), plus one int64 slot per point.
                expected_hash_data_num = num_points * 2
                hashdata = torch.empty([expected_hash_data_num, 2],
                                       dtype=torch.int64,
                                       device=pc.device)
                point_indice_data = torch.empty([num_points],
                                                dtype=torch.int64,
                                                device=pc.device)
                hashdata_tv = torch_tensor_to_tv(hashdata,
                                                 dtype=tv.custom128,
                                                 shape=[hashdata.shape[0]])
                point_indice_data_tv = torch_tensor_to_tv(point_indice_data)
                # NOTE(review): the stream is fetched outside the
                # ``torch.cuda.device(pc.device)`` context; confirm this is the
                # intended stream when pc is not on the current device.
                stream = get_current_stream()
                with torch.cuda.device(pc.device):
                    res = SpconvOps.point2voxel_cuda(
                        pc_tv, voxels_tv, indices_tv, num_per_voxel_tv,
                        hashdata_tv, point_indice_data_tv, pc_voxel_id_tv,
                        self.vsize, self.grid_size, self.grid_stride,
                        self.coors_range, empty_mean, clear_voxels, stream)
            else:
                # point2voxel_cpu takes no stream, so no CUDA stream is queried
                # on this path (a stream lookup here was unused and may require
                # a CUDA runtime on CPU-only setups — TODO confirm).
                hashdata_tv = torch_tensor_to_tv(self.hashdata, dtype=tv.int32)
                res = SpconvOps.point2voxel_cpu(
                    pc_tv, voxels_tv, indices_tv, num_per_voxel_tv,
                    hashdata_tv, pc_voxel_id_tv, self.vsize, self.grid_size,
                    self.grid_stride, self.coors_range, empty_mean,
                    clear_voxels)
            # The leading dimension of res[0] gives the number of voxels
            # produced; the preallocated outputs are trimmed to it.
            num_voxels = res[0].shape[0]

            return (self.voxels[:num_voxels], self.indices[:num_voxels],
                    self.num_per_voxel[:num_voxels], pc_voxel_id)