def forward(self, data):
    """Pool ``data`` on a voxel grid and rebuild its graph.

    The points are bucketed with ``nn_geometric.voxel_grid`` using
    ``self.pool_rad`` as the voxel size, the mean position of every bucket
    is computed, and each point is then re-assigned with ``nearest`` to one
    of those mean positions.  Features are reduced per final cluster with
    ``self._aggr`` and positions with ``'mean'``.

    Args:
        data: Object exposing ``pos``, ``batch`` and ``x``.  It is modified
            in place: ``x``, ``pos`` and ``batch`` are replaced by their
            pooled versions and ``edge_index`` / ``edge_attr`` are cleared.

    Returns:
        The result of ``self.graph_reg`` applied to the pooled ``data``.
    """
    # Pad the bounding box by half a voxel on both sides.
    half_cell = self.pool_rad * 0.5
    lower = data.pos.min(dim=0)[0] - half_cell
    upper = data.pos.max(dim=0)[0] + half_cell

    voxel_ids = nn_geometric.voxel_grid(
        data.pos, data.batch, self.pool_rad, start=lower, end=upper)
    voxel_ids, voxel_perm = consecutive_cluster(voxel_ids)

    # Per-voxel mean positions act as candidate targets for `nearest`.
    centroids = scatter(data.pos, voxel_ids, dim=0, reduce='mean')
    centroid_batch = data.batch[voxel_perm]

    assignment = nearest(data.pos, centroids, data.batch, centroid_batch)
    # Relabel again: presumably some centroids attract no point at all.
    assignment, keep = consecutive_cluster(assignment)

    data.x = scatter(data.x, assignment, dim=0, reduce=self._aggr)
    data.pos = scatter(data.pos, assignment, dim=0, reduce='mean')
    data.batch = data.batch[keep]

    # The old connectivity no longer matches the pooled nodes.
    data.edge_index = None
    data.edge_attr = None
    return self.graph_reg(data)
def forward(self, b1, b2):
    """Fuse two point-cloud branches onto one shared set of pooled nodes.

    The points of ``b1`` and ``b2`` are concatenated, sorted by batch index,
    bucketed with ``voxel_grid`` (voxel size ``self.pool_rad``) and then
    re-assigned with ``nearest`` to the per-voxel mean positions.  Each
    branch's features are mean-reduced per cluster separately and the two
    results are concatenated along the feature dimension.

    Args:
        b1: First branch; must expose ``pos``, ``batch`` and ``x``.  It is
            modified in place and returned.
        b2: Second branch with the same attributes; it is only read.

    Returns:
        ``b1`` with ``x``, ``pos`` and ``batch`` replaced by the fused,
        pooled values and ``edge_attr`` / ``edge_index`` set to ``None``.
    """
    n_b1 = b1.batch.size(0)

    pos = torch.cat([b1.pos, b2.pos], 0)
    batch = torch.cat([b1.batch, b2.batch], 0)

    # Sort by batch index; `inv_indx` restores the original
    # concatenation order (b1 first, then b2) later on.
    batch, sorted_indx = torch.sort(batch)
    inv_indx = torch.argsort(sorted_indx)
    pos = pos[sorted_indx, :]

    # Pad the bounding box by half a voxel on both sides.
    half_cell = self.pool_rad * 0.5
    start = pos.min(dim=0)[0] - half_cell
    end = pos.max(dim=0)[0] + half_cell
    cluster = torch_geometric.nn.voxel_grid(
        pos, batch, self.pool_rad, start=start, end=end)
    cluster, perm = consecutive_cluster(cluster)

    superpoint = scatter(pos, cluster, dim=0, reduce='mean')
    new_batch = batch[perm]

    cluster = nearest(pos, superpoint, batch, new_batch)
    cluster, perm = consecutive_cluster(cluster)
    pos = scatter(pos, cluster, dim=0, reduce='mean')

    # Build the mask directly as a bool tensor on the device of the tensor
    # it indexes; the previous CPU-only float->bool mask ignored the
    # device of the inputs.
    branch_mask = torch.zeros(
        batch.size(0), dtype=torch.bool, device=cluster.device)
    branch_mask[:n_b1] = True

    # Back to concatenation order so the mask separates b1 from b2.
    cluster = cluster[inv_indx]
    n_voxels = len(cluster.unique())

    # Output buffers are pre-filled with ones and handed to scatter via
    # `out=`.  NOTE(review): what ends up in rows for clusters that hold
    # points of only one branch depends on scatter's `out` semantics
    # under 'mean' -- confirm this is the intended fill value.
    x_b1 = torch.ones(n_voxels, b1.x.shape[1], device=b1.x.device)
    x_b2 = torch.ones(n_voxels, b2.x.shape[1], device=b2.x.device)
    x_b1 = scatter(b1.x, cluster[branch_mask], dim=0, out=x_b1,
                   reduce='mean')
    x_b2 = scatter(b2.x, cluster[~branch_mask], dim=0, out=x_b2,
                   reduce='mean')

    b1.x = torch.cat([x_b1, x_b2], 1)
    b1.pos = pos
    b1.batch = batch[perm]
    b1.edge_attr = None
    b1.edge_index = None
    return b1
def _process(self, data):
    """Group the points of ``data`` by the grid cell they fall into.

    Positions are divided by ``self._grid_size`` and converted with
    ``.int()`` to obtain integer cell coordinates, which are clustered with
    ``voxel_grid`` (when ``data`` has a ``batch`` key) or ``grid_cluster``
    (otherwise).  ``group_data`` then merges the points of every cell
    according to ``self._mode``.

    Args:
        data: Data object exposing ``pos`` and supporting ``in``.

    Returns:
        The grouped data object.  When ``self._quantize_coords`` is truthy,
        ``pos`` is excluded from grouping and replaced by the integer cell
        coordinates of one selected point per cell.
    """
    if self._mode == "last":
        # Presumably randomizes which point counts as "last" in a cell --
        # confirm against shuffle_data / group_data.
        data = shuffle_data(data)

    cell_coords = (data.pos / self._grid_size).int()

    if "batch" in data:
        cell_ids = voxel_grid(cell_coords, data.batch, 1)
    else:
        cell_ids = grid_cluster(cell_coords, torch.tensor([1, 1, 1]))
    cell_ids, representatives = consecutive_cluster(cell_ids)

    quantize = self._quantize_coords
    excluded = ["pos"] if quantize else []
    data = group_data(data, cell_ids, representatives, mode=self._mode,
                      skip_keys=excluded)

    if quantize:
        data.pos = cell_coords[representatives]
    return data
def __call__(self, data):
    """Coarsen every per-node tensor of ``data`` onto a voxel grid.

    Nodes are clustered with ``voxel_grid`` using ``self.size``,
    ``self.start`` and ``self.end``.  For each tensor whose first dimension
    equals the number of nodes:

    * ``"y"`` is one-hot encoded with ``self.num_classes`` classes, summed
      per cluster and reduced with ``argmax`` (a per-cluster vote);
    * ``"batch"`` is indexed with the permutation returned by
      ``consecutive_cluster``;
    * everything else is averaged per cluster with ``scatter_mean``.

    Args:
        data: Data object exposing ``pos`` and ``num_nodes``; modified in
            place.

    Returns:
        The same ``data`` object with coarsened tensors.

    Raises:
        ValueError: If any key of ``data`` contains the substring "edge".
    """
    num_nodes = data.num_nodes

    if "batch" in data:
        batch = data.batch
    else:
        # No batch vector: treat all nodes as one example.
        batch = data.pos.new_zeros(num_nodes, dtype=torch.long)

    cluster = voxel_grid(data.pos, batch, self.size, self.start, self.end)
    cluster, perm = consecutive_cluster(cluster)

    for key, item in data:
        if re.search("edge", key) is not None:
            raise ValueError(
                "GridSampling does not support coarsening of edges")

        if not (torch.is_tensor(item) and item.size(0) == num_nodes):
            continue

        if key == "y":
            votes = scatter_add(
                F.one_hot(item, num_classes=self.num_classes),
                cluster, dim=0)
            data[key] = votes.argmax(dim=-1)
        elif key == "batch":
            data[key] = item[perm]
        else:
            data[key] = scatter_mean(item, cluster, dim=0)

    return data
def forward(self, data):
    """Downsample ``data`` on a voxel grid and rebuild its graph.

    Points are bucketed with ``nn_geometric.voxel_grid`` (voxel size
    ``self.pool_rad``), the per-voxel mean position is computed with
    ``scatter_``, and every point is re-assigned with ``nearest`` to one of
    those mean positions.  Features are then pooled per target position
    with the reduction named by ``self._aggr``.

    Args:
        data: Object exposing ``pos``, ``batch`` and ``x``.  It is modified
            in place; note that only ``edge_attr`` is cleared here,
            ``edge_index`` is left untouched.

    Returns:
        The result of ``self.graph_reg`` applied to the pooled ``data``.
    """
    # Downsample to obtain the new data.  The bounding box handed to
    # voxel_grid is padded by half a voxel on each side (the original note
    # mentions pool_rad = 0.1).
    half_cell = self.pool_rad * 0.5
    lower = data.pos.min(dim=0)[0] - half_cell
    upper = data.pos.max(dim=0)[0] + half_cell
    voxel_ids = nn_geometric.voxel_grid(
        data.pos, data.batch, self.pool_rad, start=lower, end=upper)
    voxel_ids, perm = consecutive_cluster(voxel_ids)

    pooled_pos = scatter_('mean', data.pos, voxel_ids)
    pooled_batch = data.batch[perm]

    assignment = nearest(data.pos, pooled_pos, data.batch, pooled_batch)

    # self._aggr selects the pooling type (e.g. max or mean).
    data.x = scatter_(self._aggr, data.x, assignment,
                      dim_size=pooled_pos.size(0))
    data.pos = pooled_pos
    data.batch = pooled_batch
    data.edge_attr = None

    # Once downsampling is done, build the graph again.
    return self.graph_reg(data)
def max_pool_x(cluster, x, batch: Optional[torch.Tensor] = None, size: Optional[int] = None):
    r"""Max-pool node features according to the clustering in
    :attr:`cluster`.

    Args:
        cluster (LongTensor): Cluster vector :math:`\mathbf{c} \in \{ 0,
            \ldots, N - 1 \}^N`, which assigns each node to a specific
            cluster.
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor, optional): Batch vector :math:`\mathbf{b} \in
            {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific
            example. (default: :obj:`None`)
        size (int, optional): Accepted for interface compatibility but
            currently ignored by this implementation. (default:
            :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`LongTensor` or :obj:`None`)
    """
    # NOTE (TK): the dense `size` branch was disabled because it called
    # `batch.max()`, which is not defined when `batch` is None and broke
    # compilation.  For the same reason `pool_batch` is replaced by a
    # guarded index below.
    cluster, perm = consecutive_cluster(cluster)
    pooled = _max_pool_x(cluster, x)

    if batch is None:
        return pooled, None
    return pooled, batch[perm]
def max_pool(cluster, x: torch.Tensor, edge_index: torch.Tensor, batch: Optional[torch.Tensor]=None, transform: bool=None):
    r"""Pool and coarsen a graph given as separate tensors according to the
    clustering in :attr:`cluster`.

    All nodes within the same cluster are represented as one node.  Node
    features are reduced with ``_max_pool_x`` and the edges are coarsened
    with ``pool_edge`` (called without edge attributes).

    Args:
        cluster (LongTensor): Cluster vector :math:`\mathbf{c} \in \{ 0,
            \ldots, N - 1 \}^N`, which assigns each node to a specific
            cluster.
        x (Tensor): Node feature matrix.
        edge_index (Tensor): Edge indices of the graph to coarsen.
        batch (Tensor, optional): Batch vector; when given it is indexed
            with the permutation from ``consecutive_cluster``.
            (default: :obj:`None`)
        transform: Unused by this implementation. (default: :obj:`None`)

    :rtype: tuple ``(x, edge_index, batch, edge_attr)``, where ``batch`` is
        :obj:`None` if no batch vector was passed and ``edge_attr`` is
        whatever ``pool_edge`` returns for ``edge_attr=None``.
    """
    cluster, perm = consecutive_cluster(cluster)
    pooled_x = _max_pool_x(cluster, x)
    pooled_edge_index, pooled_edge_attr = pool_edge(
        cluster, edge_index, edge_attr=None)

    # Plain indexing stands in for pool_batch(perm, batch).
    pooled_batch = batch if batch is None else batch[perm]

    return pooled_x, pooled_edge_index, pooled_batch, pooled_edge_attr
def sample(self, pos=None, x=None, batch=None):
    """Subsample sparse point data on a voxel grid.

    Points are clustered with ``voxel_grid`` using
    ``self._subsampling_param`` as the voxel size; positions (and features,
    when given) are reduced per cluster with ``pool_pos``.

    Args:
        pos: Point positions; must be two-dimensional.
        x: Optional per-point features.
        batch: Batch vector forwarded to ``voxel_grid`` and ``pool_batch``.

    Returns:
        Tuple ``(x, pos, batch)`` of pooled features (``None`` when ``x``
        is ``None``), pooled positions and pooled batch vector.

    Raises:
        ValueError: If ``pos`` is not two-dimensional.
    """
    if len(pos.shape) != 2:
        raise ValueError("This class is for sparse data and expects the pos tensor to be of dimension 2")

    voxel_ids = voxel_grid(pos, batch, self._subsampling_param)
    voxel_ids, perm = consecutive_cluster(voxel_ids)
    pooled_batch = pool_batch(perm, batch)

    pooled_x = None if x is None else pool_pos(voxel_ids, x)
    pooled_pos = pool_pos(voxel_ids, pos)
    return pooled_x, pooled_pos, pooled_batch
def _prepare_data(self, data):
    """Quantize ``data`` onto a grid and build the sparse batch.

    Positions are divided by ``self._grid_size``, rounded and cast to
    ``long`` to obtain cell coordinates, which are clustered with
    ``voxel_grid``.  One point index per cluster (as returned by
    ``consecutive_cluster``) selects the coordinates, batch index and
    original position kept for that cluster; features are reduced with
    ``self._aggregate``.

    Args:
        data: Object exposing ``pos``, ``batch`` and ``x``.

    Returns:
        Tuple ``(sparse_data, cluster)``: the new ``Batch`` with ``x``,
        ``pos``, ``coords`` and ``batch``, and the per-point cluster
        assignment.
    """
    cell_coords = torch.round(data.pos / self._grid_size).long()

    cluster = voxel_grid(cell_coords, data.batch, 1)
    cluster, representatives = consecutive_cluster(cluster)

    kept_coords = cell_coords[representatives]
    kept_batch = data.batch[representatives]
    kept_pos = data.pos[representatives]
    pooled_x = self._aggregate(data.x, cluster, representatives)

    sparse_data = Batch(x=pooled_x, pos=kept_pos, coords=kept_coords,
                        batch=kept_batch)
    return sparse_data, cluster
def _process(self, data):
    """Cluster the nodes of ``data`` on a voxel grid and group them by mean.

    Args:
        data: Data object exposing ``pos`` and ``num_nodes``.  When it has
            no ``batch`` key, all nodes are treated as one example.

    Returns:
        The result of ``group_data`` called with ``mode="mean"``.
    """
    num_nodes = data.num_nodes
    batch = (data.batch if "batch" in data
             else data.pos.new_zeros(num_nodes, dtype=torch.long))

    voxel_ids = voxel_grid(data.pos, batch, self.size, self.start, self.end)
    voxel_ids, representatives = consecutive_cluster(voxel_ids)
    return group_data(data, voxel_ids, representatives, mode="mean")
def _process(self, data):
    """Group the points of ``data`` by their rounded grid cell.

    Positions are divided by ``self._grid_size`` and rounded with
    ``torch.round``; the rounded coordinates are clustered with
    ``voxel_grid`` (when ``data`` has a ``batch`` key) or ``grid_cluster``
    (otherwise), and ``group_data`` merges the points of every cell
    according to ``self._mode``.

    Args:
        data: Data object exposing ``pos`` and supporting ``in``.

    Returns:
        The grouped data object.  When ``self._quantize_coords`` is truthy
        it additionally carries ``coords``: the integer cell coordinates of
        one selected point per cell.
    """
    if self._mode == "last":
        # Presumably randomizes which point counts as "last" in a cell --
        # confirm against shuffle_data / group_data.
        data = shuffle_data(data)

    cell_coords = torch.round(data.pos / self._grid_size)

    if "batch" in data:
        cell_ids = voxel_grid(cell_coords, data.batch, 1)
    else:
        cell_ids = grid_cluster(cell_coords, torch.tensor([1, 1, 1]))
    cell_ids, representatives = consecutive_cluster(cell_ids)

    data = group_data(data, cell_ids, representatives, mode=self._mode)

    if self._quantize_coords:
        data.coords = cell_coords[representatives].int()
    return data
def __call__(self, data):
    """Coarsen every per-node tensor of ``data`` onto a voxel grid.

    Nodes are clustered with ``voxel_grid`` using ``self.size``,
    ``self.start`` and ``self.end``.  Each tensor whose first dimension
    equals the number of nodes is averaged per cluster with
    ``scatter_mean``, except ``'batch'``, which is indexed with the
    permutation returned by ``consecutive_cluster``.

    Args:
        data: Data object exposing ``pos`` and ``num_nodes``; modified in
            place.

    Returns:
        The same ``data`` object with coarsened tensors.

    Raises:
        ValueError: If any key of ``data`` contains the substring "edge".
    """
    num_nodes = data.num_nodes

    if 'batch' in data:
        batch = data.batch
    else:
        # No batch vector: treat all nodes as one example.
        batch = data.pos.new_zeros(num_nodes, dtype=torch.long)

    cluster = voxel_grid(data.pos, batch, self.size, self.start, self.end)
    cluster, perm = consecutive_cluster(cluster)

    for key, item in data:
        if re.search('edge', key) is not None:
            raise ValueError(
                'GridSampling does not support coarsening of edges')

        if not (torch.is_tensor(item) and item.size(0) == num_nodes):
            continue

        if key == 'batch':
            data[key] = item[perm]
        else:
            data[key] = scatter_mean(item, cluster, dim=0)

    return data
def __call__(self, data):
    """Build a multi-scale pyramid of positions by repeated voxel pooling.

    Starting from ``data.pos``, one coarser level is produced per entry of
    ``self.list_voxel_size``: the most recent level is clustered with
    ``voxel_grid`` and reduced with ``pool_pos`` / ``pool_batch``.

    Args:
        data: Object exposing ``pos``, ``x`` and ``y`` and, optionally,
            ``batch``.

    Returns:
        A ``MultiScaleBatch`` when ``data.batch`` is accessible, otherwise
        a ``MultiScaleData`` (which does not receive the batch lists).
    """
    try:
        batch = data.batch
    except AttributeError:
        # No batch vector: treat all points as one example.  Build it as a
        # long index vector sized from the points themselves (same
        # convention as the grid-sampling transforms) rather than as a
        # float tensor sized from `data.x`, which may be absent.
        batch = data.pos.new_zeros(data.pos.size(0), dtype=torch.long)

    points = [data.pos]
    # NOTE(review): `list_pool` is handed to the result below but is never
    # populated in this loop -- confirm whether `pool` should be appended.
    list_pool = []
    list_batch = [batch]

    for voxel_size in self.list_voxel_size:
        pool = voxel_grid(points[-1], list_batch[-1], voxel_size)
        pool, perm = consecutive_cluster(pool)
        list_batch.append(pool_batch(perm, list_batch[-1]))
        points.append(pool_pos(pool, points[-1]))

    try:
        res = MultiScaleBatch(
            batch=data.batch,
            list_pool=list_pool,
            points=points,
            list_batch=list_batch,
            x=data.x,
            y=data.y,
            pos=data.pos)
    except AttributeError:
        # Reached when `data` has no batch attribute (any other
        # AttributeError raised above ends up here as well).
        res = MultiScaleData(
            list_pool=list_pool,
            points=points,
            x=data.x,
            y=data.y,
            pos=data.pos)
    return res
def __call__(self, data):
    """Subsample every per-node tensor of ``data`` on a voxel grid.

    Nodes are clustered with ``voxel_grid`` using
    ``self._subsampling_param`` as the voxel size.  For each tensor whose
    first dimension equals the number of nodes:

    * ``'y'`` is one-hot encoded with ``self._num_classes`` classes, summed
      per cluster and reduced with ``argmax`` (a per-cluster vote);
    * everything else is reduced with ``pool_pos`` and cast back to the
      tensor's original dtype.

    Keys containing the substring "edge" are skipped and left unchanged.

    Args:
        data: Data object exposing ``pos``, ``batch`` and ``num_nodes``;
            modified in place.

    Returns:
        The same ``data`` object with subsampled tensors.
    """
    num_nodes = data.num_nodes
    pool = voxel_grid(data.pos, data.batch, self._subsampling_param)
    pool, _ = consecutive_cluster(pool)

    for key, item in data:
        if re.search('edge', key):
            continue
        if not (torch.is_tensor(item) and item.size(0) == num_nodes):
            continue

        if key == 'y':
            # Allocate the one-hot buffer on the labels' device: scatter
            # needs `self` and `index` to live on the same device.
            one_hot = torch.zeros(
                (item.shape[0], self._num_classes),
                device=item.device).scatter(1, item.unsqueeze(-1), 1)
            aggr_labels = scatter_add(one_hot, pool, dim=0)
            data[key] = torch.argmax(aggr_labels, -1)
        else:
            data[key] = pool_pos(pool, item).to(item.dtype)

    return data
def mgpool(x, pos, edge_index, batch, mask=None):
    """Coarsen a graph with graclus clustering using sparse products.

    A sparse assignment matrix ``P^T`` (clusters x nodes, all values one)
    is built from the ``graclus`` clustering and used to pool features,
    positions and the adjacency.  The product of the random-walk Laplacian
    with ``P`` is returned as well.

    All helper tensors are created on ``edge_index.device`` instead of
    being forced onto CUDA, so the function also runs on CPU-only hosts.

    Args:
        x: Node feature matrix; its first dimension is the node count.
        pos: Node positions, pooled the same way as ``x``.
        edge_index: Edge indices of the graph.
        batch: Batch vector, pooled with ``pool_batch``.
        mask: Unused.

    Returns:
        Tuple ``(new_adj, new_feat, new_pos, new_batch, index, values,
        origsize, newsize)`` where ``new_adj`` holds the indices of
        ``P^T A P``, ``index`` / ``values`` hold the sparse product of the
        Laplacian with ``P``, and ``origsize`` / ``newsize`` are the node
        counts before and after pooling.
    """
    device = edge_index.device
    origsize = x.shape[0]

    adj_values = torch.ones(edge_index.shape[1], device=device)

    cluster = graclus(edge_index)
    cluster, perm = consecutive_cluster(cluster)
    newsize = torch.unique(cluster).shape[0]
    new_batch = pool_batch(perm, batch)

    # Row = cluster id, column = node id.
    index = torch.stack(
        [cluster, torch.arange(0, origsize, device=device)], dim=0)
    values = torch.ones(cluster.shape[0], dtype=torch.float, device=device)

    # Compute random walk graph laplacian:
    laplacian_index, laplacian_weights = get_laplacian(
        edge_index, normalization='rw')
    laplacian_index, laplacian_weights = torch_sparse.coalesce(
        laplacian_index, laplacian_weights, m=origsize, n=origsize)

    # P^T matrix
    index, values = torch_sparse.coalesce(
        index, values, m=newsize, n=origsize)
    # P^T X
    new_feat = torch_sparse.spmm(
        index, values, m=newsize, n=origsize, matrix=x)
    # P^T POS
    new_pos = torch_sparse.spmm(
        index, values, m=newsize, n=origsize, matrix=pos)
    # P^T A
    new_adj, new_adj_val = torch_sparse.spspmm(
        index, values, edge_index, adj_values,
        m=newsize, k=origsize, n=origsize, coalesced=True)
    # P
    index, values = torch_sparse.transpose(
        index, values, m=newsize, n=origsize, coalesced=True)
    # (P^T A) P
    new_adj, new_adj_val = torch_sparse.spspmm(
        new_adj, new_adj_val, index, values,
        m=newsize, k=origsize, n=newsize, coalesced=True)

    # Precompute QP: Laplacian times P, with the entries of P reset to one.
    values = torch.ones(cluster.shape[0], dtype=torch.float, device=device)
    index, values = torch_sparse.spspmm(
        laplacian_index, laplacian_weights, index, values,
        m=origsize, k=origsize, n=newsize, coalesced=True)

    return (new_adj, new_feat, new_pos, new_batch, index, values,
            origsize, newsize)
def community_pooling(cluster, data):
    """Pools features and edges of all cluster members

    All cluster members are pooled into a single node that is assigned:
    - the max cluster value for each feature
    - the average cluster nodes position

    Args:
        cluster ([type]): cluster assignment of every node
        data ([type]): graph (or batch of graphs) exposing ``x``,
            ``edge_index`` and ``edge_attr`` and, optionally, ``pos``,
            ``pos2D``, ``internal_edge_index`` / ``internal_edge_attr``,
            ``cluster0`` / ``cluster1`` and ``batch``

    Returns:
        a new ``Batch`` (when ``data`` has a ``batch`` attribute) or
        ``Data`` object holding the pooled graph
    """
    # determine what the batches has as attributes
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos2D = hasattr(data, 'pos2D')
    has_cluster = hasattr(data, 'cluster0')
    has_batch = hasattr(data, 'batch')
    # `pos` may be missing altogether or present but None; in both cases
    # there is nothing to average and the pooled graph gets pos=None
    # (previously this ended in an unbound local / a scatter over None).
    has_pos = getattr(data, 'pos', None) is not None

    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos
    x, _ = scatter_max(data.x, cluster, dim=0)

    # pool the edges
    edge_index, edge_attr = pool_edge(
        cluster, data.edge_index, data.edge_attr)

    # pool internal edges if necessary
    if has_internal_edges:
        internal_edge_index, internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)

    # pool the pos
    pos = scatter_mean(data.pos, cluster, dim=0) if has_pos else None
    if has_pos2D:
        pos2D = scatter_mean(data.pos2D, cluster, dim=0)

    if has_cluster:
        c0, c1 = data.cluster0, data.cluster1

    # pool batch
    if has_batch:
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        pooled = Batch(batch=batch, x=x, edge_index=edge_index,
                       edge_attr=edge_attr, pos=pos)
    else:
        pooled = Data(x=x, edge_index=edge_index,
                      edge_attr=edge_attr, pos=pos)

    if has_internal_edges:
        pooled.internal_edge_index = internal_edge_index
        pooled.internal_edge_attr = internal_edge_attr

    # NOTE(review): pos2D has only ever been carried over for un-batched
    # data; that behavior is kept -- confirm whether batches need it too.
    if has_pos2D and not has_batch:
        pooled.pos2D = pos2D

    if has_cluster:
        pooled.cluster0 = c0
        pooled.cluster1 = c1

    return pooled
def community_pooling(cluster, data):
    """Pool the features, positions and edges of every cluster into one node.

    Node features are reduced with ``scatter_max``, positions with
    ``scatter_mean`` and edges with ``pool_edge``.

    Args:
        cluster: cluster assignment of every node
        data: graph (or batch of graphs) exposing ``x``, ``edge_index`` and
            ``edge_attr`` and, optionally, ``pos``, ``pos2D``,
            ``internal_edge_index`` / ``internal_edge_attr``,
            ``cluster0`` / ``cluster1`` and ``batch``

    Returns:
        a new ``Batch`` (when ``data`` has a ``batch`` attribute) or
        ``Data`` object holding the pooled graph
    """
    # determine what the batches has as attributes
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos2D = hasattr(data, 'pos2D')
    has_cluster = hasattr(data, 'cluster0')
    is_batched = hasattr(data, 'batch')

    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos
    x, _ = scatter_max(data.x, cluster, dim=0)

    # pool the edges
    edge_index, edge_attr = pool_edge(
        cluster, data.edge_index, data.edge_attr)

    # pool internal edges if necessary
    if has_internal_edges:
        internal_edge_index, internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)

    # pool the pos; default to None so that graphs without positions
    # (attribute missing or set to None) no longer fail further down
    pos = None
    raw_pos = getattr(data, 'pos', None)
    if raw_pos is not None:
        pos = scatter_mean(raw_pos, cluster, dim=0)

    if has_pos2D:
        pos2D = scatter_mean(data.pos2D, cluster, dim=0)

    if has_cluster:
        c0, c1 = data.cluster0, data.cluster1

    # pool batch
    if is_batched:
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        new_data = Batch(batch=batch, x=x, edge_index=edge_index,
                         edge_attr=edge_attr, pos=pos)
    else:
        new_data = Data(x=x, edge_index=edge_index,
                        edge_attr=edge_attr, pos=pos)

    # copy the optional attributes onto the pooled object
    if has_internal_edges:
        new_data.internal_edge_index = internal_edge_index
        new_data.internal_edge_attr = internal_edge_attr

    # NOTE(review): pos2D is only propagated for un-batched data, exactly
    # as before -- confirm whether batched data should receive it as well.
    if has_pos2D and not is_batched:
        new_data.pos2D = pos2D

    if has_cluster:
        new_data.cluster0 = c0
        new_data.cluster1 = c1

    return new_data
def community_pooling(cluster, data):
    """Pools features and edges of all cluster members

    All cluster members are pooled into a single node that is assigned:
    - the max cluster value for each feature
    - the average cluster nodes position

    Args:
        cluster ([type]): cluster assignment of every node
        data ([type]): graph (or batch of graphs) exposing ``x``,
            ``edge_index`` and ``edge_attr`` and, optionally, ``pos``,
            ``pos2D``, ``internal_edge_index`` / ``internal_edge_attr``,
            ``cluster0`` / ``cluster1`` and ``batch``

    Returns:
        a new ``Batch`` (when ``data`` has a ``batch`` attribute) or
        ``Data`` object holding the pooled graph

    Example:
        >>> import numpy as np
        >>> import torch
        >>> from torch_geometric.data import Data, Batch
        >>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5],
        ...                            [1, 0, 2, 1, 4, 3, 5, 4]],
        ...                           dtype=torch.long)
        >>> x = torch.tensor([[0], [1], [2], [3], [4], [5]],
        ...                  dtype=torch.float)
        >>> data = Data(x=x, edge_index=edge_index)
        >>> data.pos = torch.tensor(np.random.rand(data.num_nodes, 3))
        >>> c = community_detection(data.edge_index, data.num_nodes)
        >>> batch = Batch().from_data_list([data, data])
        >>> cluster = community_detection(batch.edge_index, batch.num_nodes)
        >>> new_batch = community_pooling(cluster, batch)
    """
    # determine what the batches has as attributes
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos2D = hasattr(data, 'pos2D')
    has_cluster = hasattr(data, 'cluster0')
    has_batch = hasattr(data, 'batch')
    # a `pos` attribute that is missing or None means "no positions";
    # the pooled graph then simply carries pos=None instead of failing
    has_pos = getattr(data, 'pos', None) is not None

    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos
    x, _ = scatter_max(data.x, cluster, dim=0)

    # pool the edges
    edge_index, edge_attr = pool_edge(
        cluster, data.edge_index, data.edge_attr)

    # pool internal edges if necessary
    if has_internal_edges:
        internal_edge_index, internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)

    # pool the pos
    if has_pos:
        pos = scatter_mean(data.pos, cluster, dim=0)
    else:
        pos = None

    if has_pos2D:
        pos2D = scatter_mean(data.pos2D, cluster, dim=0)

    if has_cluster:
        c0, c1 = data.cluster0, data.cluster1

    # pool batch
    if has_batch:
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        result = Batch(batch=batch, x=x, edge_index=edge_index,
                       edge_attr=edge_attr, pos=pos)
    else:
        result = Data(x=x, edge_index=edge_index,
                      edge_attr=edge_attr, pos=pos)

    if has_internal_edges:
        result.internal_edge_index = internal_edge_index
        result.internal_edge_attr = internal_edge_attr

    # NOTE(review): pos2D is attached for un-batched data only, matching
    # the previous behavior -- confirm whether batches should get it too.
    if has_pos2D and not has_batch:
        result.pos2D = pos2D

    if has_cluster:
        result.cluster0 = c0
        result.cluster1 = c1

    return result
def test_consecutive_cluster():
    """Check ``consecutive_cluster`` on ids with gaps and a duplicate.

    The expected labels are the ranks of the input ids among the sorted
    unique ids; the expected permutation holds one input index per output
    label (index 6 for the duplicated id 100).
    """
    raw_ids = torch.tensor([8, 2, 10, 15, 100, 1, 100])
    expected_labels = [2, 1, 3, 4, 5, 0, 5]
    expected_perm = [5, 1, 0, 2, 3, 6]

    relabeled, representatives = consecutive_cluster(raw_ids)

    assert relabeled.tolist() == expected_labels
    assert representatives.tolist() == expected_perm