def test_nearest(dtype, device):
    """Check `nearest` with and without explicit batch vectors.

    The source points are the corners of a square, once at scale 1 and once
    at scale 2; the query points sit on the x-axis at -1, +1, -2 and +2.
    Both calls are expected to yield the same assignment.
    """
    corners = [[-1, -1], [-1, +1], [+1, +1], [+1, -1]]
    # Same eight points, in the same order, as writing them out by hand:
    # the unit square first, then the square scaled by two.
    source_points = [[s * cx, s * cy] for s in (1, 2) for cx, cy in corners]
    query_points = [[-1, 0], [+1, 0], [-2, 0], [+2, 0]]

    x = tensor(source_points, dtype, device)
    y = tensor(query_points, dtype, device)

    batch_x = tensor([0] * 4 + [1] * 4, torch.long, device)
    batch_y = tensor([0, 0, 1, 1], torch.long, device)

    expected = [0, 0, 1, 1, 2, 2, 3, 3]

    batched = nearest(x, y, batch_x, batch_y)
    assert batched.tolist() == expected

    unbatched = nearest(x, y)
    assert unbatched.tolist() == expected
Exemple #2
0
def nearest(x, y, batch_x=None, batch_y=None):
    r"""Clusters points in :obj:`x` together which are nearest to a given query
    point in :obj:`y`.

    This is a thin wrapper that forwards all arguments unchanged to
    :obj:`torch_cluster.nearest`.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. (default: :obj:`None`)

    :rtype: :class:`LongTensor`

    Raises:
        ImportError: If the optional :obj:`torch-cluster` dependency is not
            available (i.e. the module-level :obj:`torch_cluster` is
            :obj:`None`).

    .. code-block:: python

        import torch
        from torch_geometric.nn import nearest

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        cluster = nearest(x, y, batch_x, batch_y)
    """
    if torch_cluster is None:
        # The message must name this function; it previously said `radius`
        # (copy-paste from a sibling wrapper), which sent users looking in
        # the wrong place.
        raise ImportError('`nearest` requires `torch-cluster`.')

    return torch_cluster.nearest(x, y, batch_x, batch_y)
Exemple #3
0
    def forward(self, data):
        """Pool `data` onto a coarser point set and rebuild the graph.

        Points are first grouped with a voxel grid of size `self.pool_rad`,
        the per-voxel mean positions serve as query points for `nearest`, and
        the resulting assignment is used to aggregate features (with
        `self._aggr`) and positions (mean). `data` is modified in place: its
        `x`, `pos` and `batch` are replaced, `edge_index` and `edge_attr` are
        cleared, and the result of `self.graph_reg(data)` is returned.
        """
        # Pad the grid bounds by half a cell on every side.
        half_cell = self.pool_rad * 0.5
        voxel_ids = nn_geometric.voxel_grid(
            data.pos,
            data.batch,
            self.pool_rad,
            start=data.pos.min(dim=0)[0] - half_cell,
            end=data.pos.max(dim=0)[0] + half_cell)
        voxel_ids, keep = consecutive_cluster(voxel_ids)

        # Mean position per voxel, with the batch entry selected for it.
        centroids = scatter(data.pos, voxel_ids, dim=0, reduce='mean')
        centroid_batch = data.batch[keep]

        # Re-assign every point to its nearest centroid, then relabel.
        assignment = nearest(data.pos, centroids, data.batch, centroid_batch)
        assignment, keep = consecutive_cluster(assignment)

        data.x = scatter(data.x, assignment, dim=0, reduce=self._aggr)
        data.pos = scatter(data.pos, assignment, dim=0, reduce='mean')
        data.batch = data.batch[keep]

        # The old connectivity refers to the pre-pooling nodes; drop it
        # before handing the data to `graph_reg`.
        data.edge_index = None
        data.edge_attr = None

        return self.graph_reg(data)
Exemple #4
0
    def forward(self, data):
        """Downsample `data` and rebuild the graph on the pooled points.

        Points are grouped with a voxel grid of size `self.pool_rad`; the
        per-voxel mean positions become the new point set. Every original
        point is then assigned to its nearest new point and the features are
        aggregated with `self._aggr`. `data` is modified in place (`x`,
        `pos`, `batch` replaced, `edge_attr` cleared) and the result of
        `self.graph_reg(data)` is returned.
        """
        # Downsample to obtain the new data. `self.pool_rad` sets the extent
        # of each pooling cell (the original author notes a value of 0.1);
        # the grid bounds are padded by half a cell on every side.
        half_cell = self.pool_rad * 0.5
        voxel_ids = nn_geometric.voxel_grid(
            data.pos,
            data.batch,
            self.pool_rad,
            start=data.pos.min(dim=0)[0] - half_cell,
            end=data.pos.max(dim=0)[0] + half_cell)
        voxel_ids, keep = consecutive_cluster(voxel_ids)

        pooled_pos = scatter_('mean', data.pos, voxel_ids)
        pooled_batch = data.batch[keep]

        assignment = nearest(data.pos, pooled_pos, data.batch, pooled_batch)

        # `self._aggr` selects the pooling type (e.g. max or mean).
        data.x = scatter_(
            self._aggr, data.x, assignment, dim_size=pooled_pos.size(0))
        data.pos = pooled_pos
        data.batch = pooled_batch
        data.edge_attr = None

        # Once downsampling is done, run the graph-construction step again.
        return self.graph_reg(data)
Exemple #5
0
    def forward(self, b1, b2):
        """Merge two branches onto one shared, pooled point set.

        The positions of `b1` and `b2` are concatenated, sorted by batch
        index, grouped with a voxel grid of size `self.pool_rad` and then
        re-assigned to the nearest per-voxel mean position. Each branch's
        features are aggregated per cluster separately and concatenated along
        the feature dimension. `b1` is modified in place (`x`, `pos`, `batch`
        replaced, `edge_attr` and `edge_index` cleared) and returned.
        """
        merged_pos = torch.cat([b1.pos, b2.pos], 0)
        merged_batch = torch.cat([b1.batch, b2.batch], 0)

        # Work in batch-sorted order; `restore` maps results back to the
        # original (b1 followed by b2) ordering.
        sorted_batch, order = torch.sort(merged_batch)
        restore = torch.argsort(order)
        sorted_pos = merged_pos[order, :]

        # Pad the grid bounds by half a cell on every side.
        half_cell = self.pool_rad * 0.5
        grid_start = sorted_pos.min(dim=0)[0] - half_cell
        grid_end = sorted_pos.max(dim=0)[0] + half_cell

        voxel_ids = torch_geometric.nn.voxel_grid(sorted_pos,
                                                  sorted_batch,
                                                  self.pool_rad,
                                                  start=grid_start,
                                                  end=grid_end)
        voxel_ids, keep = consecutive_cluster(voxel_ids)

        centroids = scatter(sorted_pos, voxel_ids, dim=0, reduce='mean')
        centroid_batch = sorted_batch[keep]

        # Re-assign every point to its nearest centroid, then relabel.
        assignment = nearest(sorted_pos, centroids, sorted_batch,
                             centroid_batch)
        assignment, keep = consecutive_cluster(assignment)

        pooled_pos = scatter(sorted_pos, assignment, dim=0, reduce='mean')
        pooled_batch = sorted_batch[keep]

        # In the original ordering the first len(b1) entries belong to b1.
        from_b1 = torch.zeros(sorted_batch.size(0)).bool()
        from_b1[:b1.batch.size(0)] = 1

        # Undo the sort so the assignment lines up with `from_b1`.
        assignment = assignment[restore]
        n_voxels = len(assignment.unique())

        per_branch = ((b1, assignment[from_b1]), (b2, assignment[~from_b1]))
        pooled_feats = []
        for branch, branch_assignment in per_branch:
            # NOTE(review): the output buffer is pre-filled with ones and
            # passed via `out`; how that interacts with reduce='mean'
            # depends on the `scatter` implementation -- confirm intent.
            buffer = torch.ones(n_voxels,
                                branch.x.shape[1],
                                device=branch.x.device)
            pooled_feats.append(
                scatter(branch.x,
                        branch_assignment,
                        dim=0,
                        out=buffer,
                        reduce='mean'))

        b1.x = torch.cat(pooled_feats, 1)
        b1.pos = pooled_pos
        b1.batch = pooled_batch
        b1.edge_attr = None
        b1.edge_index = None

        return b1
Exemple #6
0
def point_cloud_nearest_dist(points_source, points_target):
    """Return, per source point, `utils.norm` of the offset to its nearest
    target point.

    The nearest target index for each source point comes from
    `torch_cluster.nearest`, called without batch vectors. The result is
    presumably one distance per source point -- this depends on how
    `utils.norm` reduces its input; confirm against that helper.
    """
    closest_idx = torch_cluster.nearest(points_source, points_target)

    # Offset from each source point to its matched target point.
    offsets = points_source - points_target[closest_idx, :]
    return utils.norm(offsets)