示例#1
0
    def _process(self, data):
        """Pool ``data`` onto a voxel grid of edge length ``self._grid_size``.

        Points are assigned integer voxel coordinates by dividing ``data.pos``
        by the grid size and casting with ``.int()`` (which truncates toward
        zero). Points sharing a voxel (and, when ``data`` has a ``"batch"``
        entry, the same batch index) are merged by ``group_data`` according
        to ``self._mode``.

        When ``self._quantize_coords`` is set, ``"pos"`` is excluded from the
        grouping and replaced afterwards by the integer voxel coordinates of
        the points selected by ``consecutive_cluster``.

        Returns:
            The object returned by ``group_data``, possibly with ``pos``
            overwritten as described above.
        """
        if self._mode == "last":
            # NOTE(review): presumably shuffling makes the point kept by the
            # "last" mode a random one per voxel; confirm in shuffle_data.
            data = shuffle_data(data)

        voxel_coords = (data.pos / self._grid_size).int()

        # The coordinates are already expressed in voxel units, hence the
        # unit cell size passed to both clustering helpers.
        if "batch" in data:
            raw_cluster = voxel_grid(voxel_coords, data.batch, 1)
        else:
            raw_cluster = grid_cluster(voxel_coords, torch.tensor([1, 1, 1]))
        cluster_ids, representatives = consecutive_cluster(raw_cluster)

        # Quantized positions are written explicitly below, so keep
        # group_data from aggregating "pos" itself.
        ignored = ["pos"] if self._quantize_coords else []
        pooled = group_data(data,
                            cluster_ids,
                            representatives,
                            mode=self._mode,
                            skip_keys=ignored)

        if self._quantize_coords:
            pooled.pos = voxel_coords[representatives]

        return pooled
示例#2
0
def test_grid_cluster(test, dtype, device):
    """Run ``grid_cluster`` on one test case and compare with its expectation.

    ``test`` is a mapping with required ``'pos'``, ``'size'`` and
    ``'cluster'`` entries and optional ``'start'`` / ``'end'`` entries.
    Missing optional entries are forwarded to the ``tensor`` helper as
    ``None``.
    """
    required = [tensor(test[key], dtype, device) for key in ('pos', 'size')]
    optional = [
        tensor(test.get(key), dtype, device) for key in ('start', 'end')
    ]

    assignment = grid_cluster(*required, *optional)
    assert assignment.tolist() == test['cluster']
示例#3
0
def voxel_grid(
    pos: Tensor,
    size: Union[float, List[float], Tensor],
    batch: Optional[Tensor] = None,
    start: Optional[Union[float, List[float], Tensor]] = None,
    end: Optional[Union[float, List[float], Tensor]] = None,
) -> Tensor:
    r"""Voxel grid pooling from the, *e.g.*, `Dynamic Edge-Conditioned Filters
    in Convolutional Networks on Graphs <https://arxiv.org/abs/1704.02901>`_
    paper, which overlays a regular grid of user-defined size over a point
    cloud and clusters all points within the same voxel.

    Args:
        pos (Tensor): Node position matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times D}`.
        size (float or [float] or Tensor): Size of a voxel (in each dimension).
        batch (LongTensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0,
            \ldots,B-1\}}^N`, which assigns each node to a specific example.
            If set to :obj:`None`, all nodes are assigned to example
            :obj:`0`. (default: :obj:`None`)
        start (float or [float] or Tensor, optional): Start coordinates of the
            grid (in each dimension). If set to :obj:`None`, will be set to the
            minimum coordinates found in :attr:`pos`. (default: :obj:`None`)
        end (float or [float] or Tensor, optional): End coordinates of the grid
            (in each dimension). If set to :obj:`None`, will be set to the
            maximum coordinates found in :attr:`pos`. (default: :obj:`None`)

    :rtype: :class:`LongTensor`

    Raises:
        ImportError: If :obj:`grid_cluster` is :obj:`None`, i.e.
            `torch-cluster` is unavailable.
    """

    if grid_cluster is None:
        raise ImportError('`voxel_grid` requires `torch-cluster`.')

    pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
    _, dim = pos.size()

    size = size.tolist() if torch.is_tensor(size) else size
    start = start.tolist() if torch.is_tensor(start) else start
    end = end.tolist() if torch.is_tensor(end) else end

    size, start, end = repeat(size, dim), repeat(start, dim), repeat(end, dim)

    if batch is None:
        # Create the default batch vector on the same device as `pos`;
        # otherwise the concatenation below fails for non-CPU inputs.
        batch = torch.zeros(pos.size(0), dtype=torch.long, device=pos.device)

    # The batch index is appended as an extra coordinate with unit voxel
    # size, so the grid is additionally partitioned by example.
    pos = torch.cat([pos, batch.unsqueeze(-1).type_as(pos)], dim=-1)
    size = size + [1]
    start = None if start is None else start + [0]
    end = None if end is None else end + [batch.max().item()]

    size = torch.tensor(size, dtype=pos.dtype, device=pos.device)
    if start is not None:
        start = torch.tensor(start, dtype=pos.dtype, device=pos.device)
    if end is not None:
        end = torch.tensor(end, dtype=pos.dtype, device=pos.device)

    return grid_cluster(pos, size, start, end)
示例#4
0
def voxel_grid(pos, batch, size, start=None, end=None):
    """Cluster points that fall into the same voxel of a regular grid.

    The batch index is appended to ``pos`` as an extra coordinate with voxel
    size 1, and the result is passed to ``grid_cluster``.

    Args:
        pos: Point positions; a 1-D tensor is treated as a single-column
            matrix.
        batch: Tensor assigning each point to an example.
        size: Voxel size, expanded per dimension by ``repeat``.
        start: Optional grid start coordinates. ``None`` is forwarded
            unchanged to ``grid_cluster``.
        end: Optional grid end coordinates. ``None`` is forwarded unchanged
            to ``grid_cluster``.

    Returns:
        The cluster assignment returned by ``grid_cluster``.
    """
    pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
    _, dim = pos.size()

    size, start, end = repeat(size, dim), repeat(start, dim), repeat(end, dim)

    pos = torch.cat([pos, batch.unsqueeze(-1).type_as(pos)], dim=-1)
    size = size + [1]
    start = None if start is None else start + [0]
    end = None if end is None else end + [batch.max().item() + 1]

    # Convert only the bounds that were actually given: the defaults are
    # None, which cannot be turned into a tensor and must reach
    # grid_cluster as None.
    size = torch.tensor(size, dtype=pos.dtype, device=pos.device)
    if start is not None:
        start = torch.tensor(start, dtype=pos.dtype, device=pos.device)
    if end is not None:
        end = torch.tensor(end, dtype=pos.dtype, device=pos.device)

    return grid_cluster(pos, size, start, end)
示例#5
0
def test_grid_cluster_gpu(tensor, i):  # pragma: no cover
    """Run test case ``tests[i]`` through ``grid_cluster`` on CUDA tensors.

    ``tensor`` is the name of a constructor looked up on ``torch.cuda``
    (presumably a tensor type name such as ``'FloatTensor'``; confirm against
    the parametrization). Optional ``'start'`` / ``'end'`` entries stay
    ``None`` when the case does not define them.
    """
    case = tests[i]
    to_cuda = getattr(torch.cuda, tensor)

    pos = to_cuda(case['pos'])
    size = to_cuda(case['size'])

    start = case.get('start')
    if start is not None:
        start = to_cuda(start)

    end = case.get('end')
    if end is not None:
        end = to_cuda(end)

    assignment = grid_cluster(pos, size, start, end)
    assert assignment.tolist() == case['cluster']
示例#6
0
    def _process(self, data):
        """Pool ``data`` onto a voxel grid of edge length ``self._grid_size``.

        Voxel coordinates are obtained by rounding ``data.pos /
        self._grid_size``. Points sharing a voxel (and, when ``data`` has a
        ``"batch"`` entry, the same batch index) are merged by ``group_data``
        according to ``self._mode``.

        When ``self._quantize_coords`` is set, the integer voxel coordinates
        of the points selected by ``consecutive_cluster`` are stored in
        ``coords`` on the result.

        Returns:
            The object returned by ``group_data``, possibly with ``coords``
            attached.
        """
        if self._mode == "last":
            # NOTE(review): presumably shuffling makes the point kept by the
            # "last" mode a random one per voxel; confirm in shuffle_data.
            data = shuffle_data(data)

        rounded = torch.round(data.pos / self._grid_size)

        # The coordinates are already expressed in voxel units, hence the
        # unit cell size passed to both clustering helpers.
        if "batch" in data:
            raw_cluster = voxel_grid(rounded, data.batch, 1)
        else:
            raw_cluster = grid_cluster(rounded, torch.tensor([1, 1, 1]))
        cluster_ids, representatives = consecutive_cluster(raw_cluster)

        pooled = group_data(data, cluster_ids, representatives,
                            mode=self._mode)
        if self._quantize_coords:
            pooled.coords = rounded[representatives].int()

        return pooled
示例#7
0
def voxel_grid(pos, batch, size, start=None, end=None):
    """Cluster points that fall into the same voxel of a regular grid.

    The batch index is appended to ``pos`` as an extra coordinate with voxel
    size 1, and the result is passed to ``grid_cluster``.

    Args:
        pos: Point positions; a 1-D tensor is treated as a single-column
            matrix.
        batch: Tensor assigning each point to an example.
        size: Voxel size, expanded per dimension by ``repeat``.
        start: Optional grid start coordinates. ``None`` is forwarded
            unchanged to ``grid_cluster``.
        end: Optional grid end coordinates. ``None`` is forwarded unchanged
            to ``grid_cluster``.

    Returns:
        The cluster assignment returned by ``grid_cluster``.
    """
    if pos.dim() == 1:
        pos = pos.unsqueeze(-1)
    _, dim = pos.size()

    size = repeat(size, dim)
    start = repeat(start, dim)
    end = repeat(end, dim)

    batch_column = batch.unsqueeze(-1).type_as(pos)
    pos = torch.cat([pos, batch_column], dim=-1)

    # Extend every grid parameter by one entry for the batch coordinate.
    size = size + [1]
    if start is not None:
        start = start + [0]
    if end is not None:
        end = end + [batch.max().item()]

    # Grid parameters must share dtype and device with the positions.
    like_pos = {'dtype': pos.dtype, 'device': pos.device}
    size = torch.tensor(size, **like_pos)
    if start is not None:
        start = torch.tensor(start, **like_pos)
    if end is not None:
        end = torch.tensor(end, **like_pos)

    return grid_cluster(pos, size, start, end)
def voxel_grid(pos, size, start=None, end=None, batch=None, consecutive=True):
    """Cluster points by voxel and return ``(cluster, batch)``.

    Args:
        pos: Point positions; a 1-D tensor is treated as a single-column
            matrix.
        size: Voxel size, expanded per dimension by ``repeat_to``.
        start: Optional grid start coordinates.
        end: Optional grid end coordinates.
        batch: Optional tensor assigning each point to an example. When
            given, it is appended to ``pos`` as an extra coordinate with
            voxel size 1.
        consecutive: When true, ``cluster`` and ``batch`` are both replaced
            by the outputs of ``consecutive_cluster(cluster, batch)``.

    Returns:
        A ``(cluster, batch)`` tuple.
    """
    if pos.dim() == 1:
        pos = pos.unsqueeze(-1)

    _, num_dims = pos.size()

    size = pos.new(repeat_to(size, num_dims))
    if start is not None:
        start = pos.new(repeat_to(start, num_dims))
    if end is not None:
        end = pos.new(repeat_to(end, num_dims))

    if batch is not None:
        batch_column = batch.unsqueeze(-1).type_as(pos)
        pos = torch.cat([pos, batch_column], dim=-1)

        unit = size.new(1).fill_(1)
        size = torch.cat([size, unit], dim=-1)

        if start is not None:
            origin = start.new(1).fill_(0)
            start = torch.cat([start, origin], dim=-1)
        if end is not None:
            upper = end.new(1).fill_(batch.max() + 1)
            end = torch.cat([end, upper], dim=-1)

    cluster = grid_cluster(pos, size, start, end)

    if not consecutive:
        return cluster, batch

    cluster, batch = consecutive_cluster(cluster, batch)
    return cluster, batch