def test_nearest(dtype, device):
    """``nearest`` maps every point of ``x`` to the index of its closest ``y``.

    The expected assignment is the same whether or not batch vectors are
    supplied, so both call forms are checked against one expectation.
    """
    # Inner square (example 0) followed by outer square (example 1).
    x = tensor([[-1, -1], [-1, +1], [+1, +1], [+1, -1],
                [-2, -2], [-2, +2], [+2, +2], [+2, -2]], dtype, device)
    # Two query points per example, on the horizontal axis.
    y = tensor([[-1, 0], [+1, 0], [-2, 0], [+2, 0]], dtype, device)
    batch_x = tensor([0] * 4 + [1] * 4, torch.long, device)
    batch_y = tensor([0, 0, 1, 1], torch.long, device)
    expected = [0, 0, 1, 1, 2, 2, 3, 3]

    # First with explicit batch vectors, then without them.
    for call_args in ((x, y, batch_x, batch_y), (x, y)):
        assert nearest(*call_args).tolist() == expected
def nearest(x, y, batch_x=None, batch_y=None):
    r"""Clusters points in :obj:`x` together which are nearest to a given
    query point in :obj:`y`.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. (default: :obj:`None`)

    Returns:
        For every row of :obj:`x`, the index of its nearest row in :obj:`y`
        (searched only within the same example when batch vectors are given).

    Raises:
        ImportError: if :obj:`torch-cluster` is not installed.

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_geometric.nn import nearest

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        cluster = nearest(x, y, batch_x, batch_y)  # tensor([0, 0, 1, 1])
    """
    if torch_cluster is None:
        # The message must name this function, not a sibling (was `radius`).
        raise ImportError('`nearest` requires `torch-cluster`.')
    return torch_cluster.nearest(x, y, batch_x, batch_y)
def forward(self, data):
    """Voxel-pool ``data`` and rebuild its graph with ``self.graph_reg``.

    Points are binned into voxels of edge ``self.pool_rad`` over the bounding
    box padded by half a voxel. Each voxel's mean position becomes a centroid.
    Every point is then reassigned to its nearest centroid within the same
    example. Features are reduced with ``self._aggr`` and positions are
    averaged over that assignment. Edges are cleared before
    ``self.graph_reg`` is applied, and its result is returned.
    """
    half_voxel = self.pool_rad * 0.5
    lower = data.pos.min(dim=0)[0] - half_voxel
    upper = data.pos.max(dim=0)[0] + half_voxel
    voxel_ids = nn_geometric.voxel_grid(
        data.pos, data.batch, self.pool_rad, start=lower, end=upper)
    # consecutive_cluster relabels ids to 0..K-1; the second value is
    # presumably one member index per cluster (used to carry batch ids).
    voxel_ids, member_idx = consecutive_cluster(voxel_ids)
    centroids = scatter(data.pos, voxel_ids, dim=0, reduce='mean')
    centroid_batch = data.batch[member_idx]

    # Reassign each point to its closest centroid of the same example, then
    # relabel again in case some centroid attracted no points.
    assignment = nearest(data.pos, centroids, data.batch, centroid_batch)
    assignment, member_idx = consecutive_cluster(assignment)

    data.x = scatter(data.x, assignment, dim=0, reduce=self._aggr)
    data.pos = scatter(data.pos, assignment, dim=0, reduce='mean')
    data.batch = data.batch[member_idx]
    # Old edges refer to the pre-pooling nodes.
    data.edge_index = None
    data.edge_attr = None
    return self.graph_reg(data)
def forward(self, data):
    """Downsample ``data`` onto a voxel grid, then rebuild its graph.

    Points are binned into voxels of edge ``self.pool_rad``. Voxel mean
    positions become the new node positions. Every original point is then
    assigned to its nearest new node within the same example, and ``data.x``
    is pooled with ``self._aggr`` over that assignment. ``data.pos`` and
    ``data.batch`` are replaced by the voxel centroids and their example ids.
    Edges are cleared and ``self.graph_reg`` builds the new graph.

    Returns:
        Whatever ``self.graph_reg(data)`` returns for the pooled data.
    """
    # Downsample to obtain the new, coarser point set.
    cluster = nn_geometric.voxel_grid(
        data.pos, data.batch, self.pool_rad,
        # Grid spans the bounding box padded by half a voxel on each side.
        # (Original note said self.pool_rad = 0.1 -- confirm with config.)
        start=data.pos.min(dim=0)[0] - self.pool_rad * 0.5,
        end=data.pos.max(dim=0)[0] + self.pool_rad * 0.5)
    cluster, perm = consecutive_cluster(cluster)
    new_pos = scatter_('mean', data.pos, cluster)
    new_batch = data.batch[perm]

    cluster = nearest(data.pos, new_pos, data.batch, new_batch)
    # self._aggr selects the pooling reduction (e.g. max or mean);
    # dim_size keeps one output row per centroid in new_pos.
    data.x = scatter_(
        self._aggr, data.x, cluster, dim_size=new_pos.size(0))
    data.pos = new_pos
    data.batch = new_batch
    # The old edge_index indexes pre-pooling nodes and is invalid for the
    # pooled node set; clear it like edge_attr (as the sibling poolers do).
    data.edge_index = None
    data.edge_attr = None
    data = self.graph_reg(data)  # rebuild the graph on the downsampled points
    return data
def forward(self, b1, b2):
    """Pool the points of two branches onto one shared set of voxels.

    The points of ``b1`` and ``b2`` are concatenated and sorted by example id.
    They are grouped with ``voxel_grid``, reassigned to their ``nearest``
    voxel centroid, and averaged into new positions. Each branch's features
    are mean-pooled per voxel, and the two results are concatenated along the
    feature dimension. A voxel that contains no point of a branch gets the
    value 1 in that branch's feature columns.

    Args:
        b1, b2: objects exposing ``pos``, ``batch`` and ``x`` (presumably PyG
            ``Data``/``Batch`` -- confirm with callers).

    Returns:
        ``b1``, modified in place: ``x`` is ``[b1 features | b2 features]`` per
        voxel, ``pos`` holds voxel centroids, ``batch`` voxel example ids, and
        ``edge_index``/``edge_attr`` are cleared.
    """
    n_b1 = b1.batch.size(0)
    pos = torch.cat([b1.pos, b2.pos], 0)
    batch = torch.cat([b1.batch, b2.batch], 0)
    # Sort points by example id (presumably required downstream) and keep
    # the inverse permutation to return to concatenation order later.
    batch, sorted_indx = torch.sort(batch)
    inv_indx = torch.argsort(sorted_indx)
    pos = pos[sorted_indx, :]

    start = pos.min(dim=0)[0] - self.pool_rad * 0.5
    end = pos.max(dim=0)[0] + self.pool_rad * 0.5
    cluster = torch_geometric.nn.voxel_grid(pos, batch, self.pool_rad,
                                            start=start, end=end)
    cluster, perm = consecutive_cluster(cluster)
    superpoint = scatter(pos, cluster, dim=0, reduce='mean')
    new_batch = batch[perm]

    cluster = nearest(pos, superpoint, batch, new_batch)
    cluster, perm = consecutive_cluster(cluster)
    pos = scatter(pos, cluster, dim=0, reduce='mean')
    # Cluster ids are consecutive 0..K-1, so K == number of pooled rows.
    n_voxels = pos.size(0)

    # Back to concatenation order: the first n_b1 entries belong to b1.
    cluster = cluster[inv_indx]

    def _branch_mean(feats, index):
        # Average first, then fill empty voxels. Seeding a mean-scatter with
        # a ones-initialised ``out=`` buffer would add 1 to every sum before
        # dividing, biasing all non-empty voxels.
        mean = scatter(feats, index, dim=0, dim_size=n_voxels, reduce='mean')
        empty = torch.bincount(index, minlength=n_voxels) == 0
        mean[empty] = 1
        return mean

    x_b1 = _branch_mean(b1.x, cluster[:n_b1])
    x_b2 = _branch_mean(b2.x, cluster[n_b1:])
    x = torch.cat([x_b1, x_b2], 1)
    batch = batch[perm]

    b1.x = x
    b1.pos = pos
    b1.batch = batch
    b1.edge_attr = None
    b1.edge_index = None
    return b1
def point_cloud_nearest_dist(points_source, points_target,
                             batch_source=None, batch_target=None):
    """Distance from each source point to its nearest target point.

    Args:
        points_source: source points (presumably an ``(N, D)`` tensor).
        points_target: target points (presumably ``(M, D)``, same ``D``).
        batch_source: optional example id per source point. Pass it together
            with ``batch_target`` to restrict matching to the same example.
        batch_target: optional example id per target point.
            With both ``None`` (default), all points form a single example.

    Returns:
        ``utils.norm(points_source - nearest_pos)``, where ``nearest_pos`` is
        the matched target point for every source point (presumably one
        distance per source point -- confirm ``utils.norm`` reduces the last
        dimension).
    """
    nearest_ind = torch_cluster.nearest(
        points_source, points_target, batch_source, batch_target)
    nearest_pos = points_target[nearest_ind, :]
    return utils.norm(points_source - nearest_pos)