def sample(self, pos=None, x=None, batch=None):
    """Voxel-grid subsample a sparse point cloud.

    Points are grouped with ``voxel_grid`` using ``self._subsampling_param``
    as the voxel size, cluster ids are made consecutive, and positions (and
    features, when given) are reduced per cluster with ``pool_pos``.

    Args:
        pos: Point positions; must be a 2-D tensor.
        x: Optional per-point features, pooled the same way as ``pos``.
        batch: Batch index of every point. ``None`` treats all points as a
            single sample (previously this crashed inside the pooling).

    Returns:
        Tuple ``(pooled_x, pooled_pos, pooled_batch)``; ``pooled_x`` is
        ``None`` when ``x`` is ``None``.

    Raises:
        ValueError: if ``pos`` is missing or is not 2-dimensional.
    """
    if pos is None:
        # Previously surfaced as an opaque AttributeError on ``None.shape``.
        raise ValueError("A pos tensor is required for sampling")
    if len(pos.shape) != 2:
        raise ValueError("This class is for sparse data and expects the pos tensor to be of dimension 2")
    if batch is None:
        # Single sample: every point belongs to batch element 0.
        batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device)

    pool = voxel_grid(pos, batch, self._subsampling_param)
    pool, perm = consecutive_cluster(pool)
    pooled_batch = pool_batch(perm, batch)
    pooled_x = None if x is None else pool_pos(pool, x)
    return pooled_x, pool_pos(pool, pos), pooled_batch
def __call__(self, data):
    """Build a hierarchy of voxel-subsampled point sets from ``data.pos``.

    For each voxel size in ``self.list_voxel_size`` the most recent point set
    is pooled with ``voxel_grid``, so ``points[i + 1]`` is ``points[i]``
    subsampled at ``self.list_voxel_size[i]`` and ``list_batch[i + 1]`` is its
    batch vector.

    Args:
        data: Object exposing ``pos``, ``x`` and ``y``; a ``batch`` vector is
            optional.

    Returns:
        ``MultiScaleBatch`` when ``data`` carries a (non-``None``) batch
        vector, otherwise ``MultiScaleData``.
    """
    batch = getattr(data, "batch", None)
    is_batched = batch is not None
    if not is_batched:
        # Single sample: every point is in batch element 0. Batch indices are
        # integer labels sized by the pooled tensor (pos), so build them as
        # ``long`` on pos's device rather than a float tensor sized from x.
        batch = torch.zeros(data.pos.shape[0], dtype=torch.long, device=data.pos.device)

    points = [data.pos]
    list_batch = [batch]
    # NOTE(review): list_pool is handed to the result but never populated
    # here; confirm whether the per-level cluster assignments belong in it.
    list_pool = []
    for voxel_size in self.list_voxel_size:
        pool = voxel_grid(points[-1], list_batch[-1], voxel_size)
        pool, perm = consecutive_cluster(pool)
        list_batch.append(pool_batch(perm, list_batch[-1]))
        points.append(pool_pos(pool, points[-1]))

    # Decide the result type up front instead of catching AttributeError
    # around the constructor, which hid errors raised inside MultiScaleBatch.
    if is_batched:
        return MultiScaleBatch(
            batch=batch, list_pool=list_pool, points=points, list_batch=list_batch,
            x=data.x, y=data.y, pos=data.pos)
    return MultiScaleData(
        list_pool=list_pool, points=points, x=data.x, y=data.y, pos=data.pos)
def community_pooling(cluster, data):
    """Pool features and edges of all cluster members.

    All members of a cluster are merged into a single node that is assigned:
      - the max cluster value for each feature (``scatter_max`` over ``x``)
      - the average position of the cluster nodes (``scatter_mean`` over
        ``pos`` and, when present, ``pos2D``)

    Edges, and ``internal_edge_index``/``internal_edge_attr`` when present, are
    remapped onto the pooled nodes with ``pool_edge``. ``cluster0`` and
    ``cluster1`` are copied over unchanged.

    Args:
        cluster: Cluster index of every node in ``data``; it is made
            consecutive with ``consecutive_cluster`` before pooling.
        data: Graph exposing ``x``, ``edge_index`` and ``edge_attr``; if it
            has a ``batch`` attribute it is treated as a batch of graphs.

    Returns:
        A new ``Batch`` when ``data`` has a ``batch`` attribute, otherwise a
        new ``Data``, holding the pooled graph.
    """
    # NOTE(review): community_pooling is defined more than once in this
    # listing; if the definitions share a module only the last one is bound.

    # determine which optional attributes the input carries
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos2D = hasattr(data, 'pos2D')
    has_cluster = hasattr(data, 'cluster0')

    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos
    x, _ = scatter_max(data.x, cluster, dim=0)

    # pool the edges
    edge_index, edge_attr = pool_edge(cluster, data.edge_index, data.edge_attr)

    # pool the positions; a missing or None ``pos`` stays None instead of
    # leaving ``pos`` unbound or calling scatter_mean on None
    pos = getattr(data, 'pos', None)
    if pos is not None:
        pos = scatter_mean(pos, cluster, dim=0)

    if hasattr(data, 'batch'):
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        pooled = Batch(batch=batch, x=x, edge_index=edge_index,
                       edge_attr=edge_attr, pos=pos)
    else:
        pooled = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

    # Optional attributes are attached the same way for Data and Batch
    # outputs (pos2D used to be dropped for batches).
    if has_internal_edges:
        pooled.internal_edge_index, pooled.internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)
    if has_pos2D:
        pooled.pos2D = scatter_mean(data.pos2D, cluster, dim=0)
    if has_cluster:
        pooled.cluster0 = data.cluster0
        pooled.cluster1 = data.cluster1

    return pooled
def community_pooling(cluster, data):
    """Pool features and edges of all cluster members.

    All members of a cluster are merged into a single node that is assigned:
      - the max cluster value for each feature (``scatter_max`` over ``x``)
      - the average position of the cluster nodes (``scatter_mean`` over
        ``pos`` and, when present, ``pos2D``)

    Edges, and ``internal_edge_index``/``internal_edge_attr`` when present, are
    remapped onto the pooled nodes with ``pool_edge``. ``cluster0`` and
    ``cluster1`` are copied over unchanged.

    Args:
        cluster: Cluster index of every node in ``data``; it is made
            consecutive with ``consecutive_cluster`` before pooling.
        data: Graph exposing ``x``, ``edge_index`` and ``edge_attr``; if it
            has a ``batch`` attribute it is treated as a batch of graphs.

    Returns:
        A new ``Batch`` when ``data`` has a ``batch`` attribute, otherwise a
        new ``Data``, holding the pooled graph.
    """
    # NOTE(review): community_pooling is defined more than once in this
    # listing; if the definitions share a module only the last one is bound.

    # determine which optional attributes the input carries
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos2D = hasattr(data, 'pos2D')
    has_cluster = hasattr(data, 'cluster0')

    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos
    x, _ = scatter_max(data.x, cluster, dim=0)

    # pool the edges
    edge_index, edge_attr = pool_edge(cluster, data.edge_index, data.edge_attr)

    # pool the positions; a missing or None ``pos`` stays None instead of
    # leaving ``pos`` unbound or calling scatter_mean on None
    pos = getattr(data, 'pos', None)
    if pos is not None:
        pos = scatter_mean(pos, cluster, dim=0)

    if hasattr(data, 'batch'):
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        pooled = Batch(batch=batch, x=x, edge_index=edge_index,
                       edge_attr=edge_attr, pos=pos)
    else:
        pooled = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

    # Optional attributes are attached the same way for Data and Batch
    # outputs (pos2D used to be dropped for batches).
    if has_internal_edges:
        pooled.internal_edge_index, pooled.internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)
    if has_pos2D:
        pooled.pos2D = scatter_mean(data.pos2D, cluster, dim=0)
    if has_cluster:
        pooled.cluster0 = data.cluster0
        pooled.cluster1 = data.cluster1

    return pooled
def community_pooling(cluster, data):
    """Pool features and edges of all cluster members.

    All members of a cluster are merged into a single node that is assigned:
      - the max cluster value for each feature (``scatter_max`` over ``x``)
      - the average position of the cluster nodes (``scatter_mean`` over
        ``pos`` and, when present, ``pos2D``)

    Edges, and ``internal_edge_index``/``internal_edge_attr`` when present, are
    remapped onto the pooled nodes with ``pool_edge``. ``cluster0`` and
    ``cluster1`` are copied over unchanged.

    Args:
        cluster: Cluster index of every node in ``data``; it is made
            consecutive with ``consecutive_cluster`` before pooling.
        data: Graph exposing ``x``, ``edge_index`` and ``edge_attr``; if it
            has a ``batch`` attribute it is treated as a batch of graphs.

    Returns:
        A new ``Batch`` when ``data`` has a ``batch`` attribute, otherwise a
        new ``Data``, holding the pooled graph.

    Example:
        >>> import torch
        >>> from torch_geometric.data import Data, Batch
        >>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5],
        ...                            [1, 0, 2, 1, 4, 3, 5, 4]], dtype=torch.long)
        >>> x = torch.tensor([[0], [1], [2], [3], [4], [5]], dtype=torch.float)
        >>> data = Data(x=x, edge_index=edge_index)
        >>> data.pos = torch.rand(data.num_nodes, 3)
        >>> batch = Batch.from_data_list([data, data])
        >>> cluster = community_detection(batch.edge_index, batch.num_nodes)
        >>> new_batch = community_pooling(cluster, batch)
    """
    # NOTE(review): community_pooling is defined more than once in this
    # listing; if the definitions share a module only the last one is bound.

    # determine which optional attributes the input carries
    has_internal_edges = hasattr(data, 'internal_edge_index')
    has_pos2D = hasattr(data, 'pos2D')
    has_cluster = hasattr(data, 'cluster0')

    cluster, perm = consecutive_cluster(cluster)
    cluster = cluster.to(data.x.device)

    # pool the node infos
    x, _ = scatter_max(data.x, cluster, dim=0)

    # pool the edges
    edge_index, edge_attr = pool_edge(cluster, data.edge_index, data.edge_attr)

    # pool the positions; a missing or None ``pos`` stays None instead of
    # leaving ``pos`` unbound or calling scatter_mean on None
    pos = getattr(data, 'pos', None)
    if pos is not None:
        pos = scatter_mean(pos, cluster, dim=0)

    if hasattr(data, 'batch'):
        batch = None if data.batch is None else pool_batch(perm, data.batch)
        pooled = Batch(batch=batch, x=x, edge_index=edge_index,
                       edge_attr=edge_attr, pos=pos)
    else:
        pooled = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

    # Optional attributes are attached the same way for Data and Batch
    # outputs (pos2D used to be dropped for batches).
    if has_internal_edges:
        pooled.internal_edge_index, pooled.internal_edge_attr = pool_edge(
            cluster, data.internal_edge_index, data.internal_edge_attr)
    if has_pos2D:
        pooled.pos2D = scatter_mean(data.pos2D, cluster, dim=0)
    if has_cluster:
        pooled.cluster0 = data.cluster0
        pooled.cluster1 = data.cluster1

    return pooled
def mgpool(x, pos, edge_index, batch, mask=None):
    """Coarsen a graph with graclus matching and precompute L P.

    Nodes are matched with ``graclus`` and each cluster becomes one coarse
    node. With ``P`` the (origsize x newsize) 0/1 assignment matrix
    (``P[i, cluster[i]] = 1``), ``A`` the unit-weight adjacency built from
    ``edge_index`` and ``L`` the random-walk Laplacian from
    ``get_laplacian(..., normalization='rw')``, this computes ``P^T X``,
    ``P^T POS``, the sparsity pattern of ``P^T A P`` and the sparse ``L P``.

    Args:
        x: Node features, one row per node (defines ``origsize`` and the
            device used for the helper tensors).
        pos: Node positions, one row per node.
        edge_index: Sparse COO connectivity with one column per edge.
        batch: Batch index of each node.
        mask: Unused; kept for interface compatibility.

    Returns:
        ``(new_adj, new_feat, new_pos, new_batch, index, values, origsize,
        newsize)`` where ``new_adj`` holds the indices of ``P^T A P``,
        ``index``/``values`` are the sparse ``L P`` (origsize x newsize), and
        ``origsize``/``newsize`` are the node counts before/after pooling.
    """
    # NOTE(review): the values of P^T A P (new_adj_val) are computed but not
    # returned; confirm callers only need the connectivity.
    device = x.device  # was hard-coded .cuda(), which broke CPU inputs
    origsize = x.shape[0]

    # Pass num_nodes so trailing isolated nodes still get a cluster and the
    # cluster vector lines up with the rows of x.
    cluster = graclus(edge_index, num_nodes=origsize)
    cluster, perm = consecutive_cluster(cluster)
    newsize = torch.unique(cluster).shape[0]
    new_batch = pool_batch(perm, batch)

    # Random-walk graph Laplacian L (origsize x origsize).
    laplacian_index, laplacian_weights = get_laplacian(
        edge_index, normalization='rw', num_nodes=origsize)
    laplacian_index, laplacian_weights = torch_sparse.coalesce(
        laplacian_index, laplacian_weights, m=origsize, n=origsize)

    # spspmm below is told its inputs are coalesced, so make A coalesced
    # explicitly instead of assuming edge_index already is.
    adj_index, adj_values = torch_sparse.coalesce(
        edge_index, torch.ones(edge_index.shape[1], device=device),
        m=origsize, n=origsize)

    # P^T as a sparse (newsize x origsize) matrix with entries (cluster[i], i).
    pt_index = torch.stack([cluster, torch.arange(0, origsize, device=device)], dim=0)
    pt_values = torch.ones(cluster.shape[0], dtype=torch.float, device=device)
    pt_index, pt_values = torch_sparse.coalesce(pt_index, pt_values, m=newsize, n=origsize)

    new_feat = torch_sparse.spmm(pt_index, pt_values, m=newsize, n=origsize, matrix=x)  # P^T X
    new_pos = torch_sparse.spmm(pt_index, pt_values, m=newsize, n=origsize, matrix=pos)  # P^T POS
    new_adj, new_adj_val = torch_sparse.spspmm(
        pt_index, pt_values, adj_index, adj_values,
        m=newsize, k=origsize, n=origsize, coalesced=True)  # P^T A
    p_index, p_values = torch_sparse.transpose(
        pt_index, pt_values, m=newsize, n=origsize, coalesced=True)  # P
    new_adj, new_adj_val = torch_sparse.spspmm(
        new_adj, new_adj_val, p_index, p_values,
        m=newsize, k=origsize, n=newsize, coalesced=True)  # (P^T A) P

    # Precompute L P; P's values are all ones, so they are reused directly.
    index, values = torch_sparse.spspmm(
        laplacian_index, laplacian_weights, p_index, p_values,
        m=origsize, k=origsize, n=newsize, coalesced=True)
    return new_adj, new_feat, new_pos, new_batch, index, values, origsize, newsize