def _process(self, data):
    """Voxel-downsample ``data`` on a regular grid of cell size ``self._grid_size``.

    Positions are mapped to integer voxel coordinates, points sharing a voxel
    (and, when present, a batch index) are clustered, and the data is grouped
    per cluster by ``group_data`` according to ``self._mode``.  When
    ``self._quantize_coords`` is set, ``pos`` is excluded from that grouping
    and is instead replaced by the integer voxel coordinates of the retained
    points.

    Args:
        data: Data object exposing ``pos`` and optionally ``batch`` (supports
            ``"batch" in data``); presumably a torch_geometric ``Data`` /
            ``Batch`` -- TODO confirm.

    Returns:
        The grouped data object.
    """
    if self._mode == "last":
        # NOTE(review): presumably randomizes which point of each voxel
        # survives in "last" mode -- confirm against group_data.
        data = shuffle_data(data)

    # Floor before casting: ``.int()`` alone truncates toward zero and would
    # send both [-g, 0) and [0, g) to voxel 0 (a double-width cell at the
    # origin).  For non-negative coordinates the result is unchanged.
    coords = torch.floor(data.pos / self._grid_size).int()

    if "batch" not in data:
        # Unit cells over already-quantized coordinates.  Building the size
        # tensor from ``coords`` keeps it on the same device/dtype (GPU-safe)
        # and sized to the actual number of spatial dimensions.
        unit_size = coords.new_ones(coords.size(1))
        cluster = grid_cluster(coords, unit_size)
    else:
        cluster = voxel_grid(coords, data.batch, 1)
    cluster, unique_pos_indices = consecutive_cluster(cluster)

    # Quantized positions are written back below, so don't aggregate ``pos``.
    skip_keys = ["pos"] if self._quantize_coords else []
    data = group_data(
        data, cluster, unique_pos_indices, mode=self._mode, skip_keys=skip_keys
    )
    if self._quantize_coords:
        data.pos = coords[unique_pos_indices]
    return data
def test_grid_cluster(test, dtype, device):
    """Compare ``grid_cluster`` output with the expected clustering of ``test``.

    ``pos`` and ``size`` are required entries of the case; ``start`` and
    ``end`` are optional and forwarded as whatever ``tensor`` makes of a
    missing value.
    """
    def as_tensor(value):
        return tensor(value, dtype, device)

    required = [as_tensor(test[key]) for key in ('pos', 'size')]
    optional = [as_tensor(test.get(key)) for key in ('start', 'end')]

    result = grid_cluster(*required, *optional)
    assert result.tolist() == test['cluster']
def voxel_grid(
    pos: Tensor,
    size: Union[float, List[float], Tensor],
    batch: Optional[Tensor] = None,
    start: Optional[Union[float, List[float], Tensor]] = None,
    end: Optional[Union[float, List[float], Tensor]] = None,
) -> Tensor:
    r"""Voxel grid pooling from, *e.g.*, the `Dynamic Edge-Conditioned Filters
    in Convolutional Networks on Graphs <https://arxiv.org/abs/1704.02901>`_
    paper, which overlays a regular grid of user-defined size over a point
    cloud and clusters all points within the same voxel.

    Args:
        pos (Tensor): Node position matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times D}`.
            A 1-D tensor is treated as a single-column matrix.
        size (float or [float] or Tensor): Size of a voxel (in each
            dimension).
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots,B-1\}}^N`, which assigns each
            node to a specific example. If omitted, all nodes are assigned
            to example 0. (default: :obj:`None`)
        start (float or [float] or Tensor, optional): Start coordinates of
            the grid (in each dimension). If set to :obj:`None`, will be set
            to the minimum coordinates found in :attr:`pos`.
            (default: :obj:`None`)
        end (float or [float] or Tensor, optional): End coordinates of the
            grid (in each dimension). If set to :obj:`None`, will be set to
            the maximum coordinates found in :attr:`pos`.
            (default: :obj:`None`)

    Raises:
        ImportError: If ``torch-cluster`` is not available.

    :rtype: :class:`LongTensor`
    """
    if grid_cluster is None:
        raise ImportError('`voxel_grid` requires `torch-cluster`.')

    pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
    _, dim = pos.size()

    size = size.tolist() if torch.is_tensor(size) else size
    start = start.tolist() if torch.is_tensor(start) else start
    end = end.tolist() if torch.is_tensor(end) else end

    size, start, end = repeat(size, dim), repeat(start, dim), repeat(end, dim)

    if batch is None:
        # Allocate on pos' device; a CPU-only default would make the
        # torch.cat below fail whenever pos lives on a GPU.
        batch = pos.new_zeros(pos.size(0), dtype=torch.long)

    # Append the batch index as an extra coordinate with unit voxel size so
    # points from different examples never fall into the same voxel.
    pos = torch.cat([pos, batch.unsqueeze(-1).type_as(pos)], dim=-1)
    size = size + [1]
    start = None if start is None else start + [0]
    end = None if end is None else end + [batch.max().item()]

    size = torch.tensor(size, dtype=pos.dtype, device=pos.device)
    if start is not None:
        start = torch.tensor(start, dtype=pos.dtype, device=pos.device)
    if end is not None:
        end = torch.tensor(end, dtype=pos.dtype, device=pos.device)

    return grid_cluster(pos, size, start, end)
def voxel_grid(pos, batch, size, start=None, end=None):
    """Assign each point of a batched point cloud to a voxel of a regular grid.

    Args:
        pos: Point positions; a 1-D tensor is treated as one column.
        batch: Per-point example index, appended as an extra grid coordinate
            so points of different examples never share a voxel.
        size: Voxel size, expanded per dimension by ``repeat``.
        start: Optional grid start; omitted (``None``) when not given.
        end: Optional grid end; omitted (``None``) when not given.

    Returns:
        The cluster assignment produced by ``grid_cluster``.
    """
    pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
    _, dim = pos.size()

    size, start, end = repeat(size, dim), repeat(start, dim), repeat(end, dim)

    pos = torch.cat([pos, batch.unsqueeze(-1).type_as(pos)], dim=-1)
    size = size + [1]
    start = None if start is None else start + [0]
    end = None if end is None else end + [batch.max().item() + 1]

    # Only convert bounds that were actually given: ``pos.new(None)`` raises,
    # which previously broke the default ``start=None`` / ``end=None`` call.
    size = torch.tensor(size, dtype=pos.dtype, device=pos.device)
    if start is not None:
        start = torch.tensor(start, dtype=pos.dtype, device=pos.device)
    if end is not None:
        end = torch.tensor(end, dtype=pos.dtype, device=pos.device)

    return grid_cluster(pos, size, start, end)
def test_grid_cluster_gpu(tensor, i):  # pragma: no cover
    """Run test case ``tests[i]`` on CUDA using the tensor type named ``tensor``.

    ``start`` and ``end`` are optional in the case and stay ``None`` when
    absent.
    """
    case = tests[i]
    cuda_type = getattr(torch.cuda, tensor)

    def to_cuda(values):
        if values is None:
            return None
        return cuda_type(values)

    pos = cuda_type(case['pos'])
    size = cuda_type(case['size'])
    start = to_cuda(case.get('start'))
    end = to_cuda(case.get('end'))

    assert grid_cluster(pos, size, start, end).tolist() == case['cluster']
def _process(self, data):
    """Voxel-downsample ``data`` on a regular grid of cell size ``self._grid_size``.

    Positions are rounded to integer voxel coordinates, points sharing a voxel
    (and, when present, a batch index) are clustered, and the data is grouped
    per cluster by ``group_data`` according to ``self._mode``.  When
    ``self._quantize_coords`` is set, the integer voxel coordinates of the
    retained points are stored in ``data.coords``.

    Args:
        data: Data object exposing ``pos`` and optionally ``batch`` (supports
            ``"batch" in data``); presumably a torch_geometric ``Data`` /
            ``Batch`` -- TODO confirm.

    Returns:
        The grouped data object.
    """
    if self._mode == "last":
        # NOTE(review): presumably randomizes which point of each voxel
        # survives in "last" mode -- confirm against group_data.
        data = shuffle_data(data)

    coords = torch.round(data.pos / self._grid_size)

    if "batch" not in data:
        # Unit cells over the rounded coordinates.  Deriving the size tensor
        # from ``coords`` keeps it on the same device/dtype (GPU-safe) and
        # sized to the actual number of spatial dimensions instead of 3.
        unit_size = coords.new_ones(coords.size(1))
        cluster = grid_cluster(coords, unit_size)
    else:
        cluster = voxel_grid(coords, data.batch, 1)
    cluster, unique_pos_indices = consecutive_cluster(cluster)

    data = group_data(data, cluster, unique_pos_indices, mode=self._mode)
    if self._quantize_coords:
        data.coords = coords[unique_pos_indices].int()
    return data
def voxel_grid(pos, batch, size, start=None, end=None):
    """Assign each point of a batched point cloud to a voxel of a regular grid.

    The example index from ``batch`` is appended as an extra coordinate with
    voxel size 1 (start 0, end ``batch.max()``), so points belonging to
    different examples are never clustered together.

    Args:
        pos: Point positions; a 1-D tensor is treated as one column.
        batch: Per-point example index.
        size: Voxel size, expanded per dimension by ``repeat``.
        start: Optional grid start, expanded per dimension by ``repeat``.
        end: Optional grid end, expanded per dimension by ``repeat``.

    Returns:
        The cluster assignment produced by ``grid_cluster``.
    """
    if pos.dim() == 1:
        pos = pos.unsqueeze(-1)
    _, num_dims = pos.size()

    size, start, end = (repeat(bound, num_dims) for bound in (size, start, end))

    pos = torch.cat([pos, batch.unsqueeze(-1).type_as(pos)], dim=-1)
    size = size + [1]
    if start is not None:
        start = start + [0]
    if end is not None:
        end = end + [batch.max().item()]

    tensor_opts = {'dtype': pos.dtype, 'device': pos.device}
    size = torch.tensor(size, **tensor_opts)
    start = None if start is None else torch.tensor(start, **tensor_opts)
    end = None if end is None else torch.tensor(end, **tensor_opts)

    return grid_cluster(pos, size, start, end)
def voxel_grid(pos, size, start=None, end=None, batch=None, consecutive=True):
    """Cluster points into voxels of a regular grid, optionally per example.

    Args:
        pos: Point positions; a 1-D tensor is treated as one column.
        size: Voxel size, expanded per dimension by ``repeat_to``.
        start: Optional grid start, expanded per dimension by ``repeat_to``.
        end: Optional grid end, expanded per dimension by ``repeat_to``.
        batch: Optional per-point example index. When given it is appended
            as an extra coordinate (size 1, start 0, end ``batch.max() + 1``)
            so points of different examples never share a voxel.
        consecutive: If true, post-process the result with
            ``consecutive_cluster``.

    Returns:
        ``(cluster, batch)``: the cluster assignment and ``batch``, both as
        returned by ``consecutive_cluster`` when ``consecutive`` is true
        (NOTE(review): presumably compacted ids and a per-cluster batch --
        confirm), otherwise the raw ``grid_cluster`` output and the input
        ``batch``.
    """
    if pos.dim() == 1:
        pos = pos.unsqueeze(-1)
    _, num_dims = pos.size()

    def as_grid_tensor(value):
        return None if value is None else pos.new(repeat_to(value, num_dims))

    size = pos.new(repeat_to(size, num_dims))
    start = as_grid_tensor(start)
    end = as_grid_tensor(end)

    if batch is not None:
        pos = torch.cat([pos, batch.unsqueeze(-1).type_as(pos)], dim=-1)
        size = torch.cat([size, size.new_ones(1)], dim=-1)
        if start is not None:
            start = torch.cat([start, start.new_zeros(1)], dim=-1)
        if end is not None:
            last = end.new(1).fill_(batch.max() + 1)
            end = torch.cat([end, last], dim=-1)

    cluster = grid_cluster(pos, size, start, end)
    if not consecutive:
        return cluster, batch
    cluster, batch = consecutive_cluster(cluster, batch)
    return cluster, batch