import scipy.sparse
import torch
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv.gcn_conv import gcn_norm


def init_adj(self, edge_index):
    """ cache normalized adjacency and normalized strict two-hop adjacency,
    neither has self loops
    """
    n = self.num_nodes

    if isinstance(edge_index, SparseTensor):
        dev = edge_index.device  # was `adj_t.device`, read before assignment
        adj_t = edge_index
        # Round-trip through scipy to binarize edge weights.
        adj_t = scipy.sparse.csr_matrix(adj_t.to_scipy())
        adj_t[adj_t > 0] = 1
        adj_t[adj_t < 0] = 0
        adj_t = SparseTensor.from_scipy(adj_t).to(dev)
    elif isinstance(edge_index, torch.Tensor):
        row, col = edge_index
        adj_t = SparseTensor(row=col, col=row, value=None,
                             sparse_sizes=(n, n))

    # remove_diag is not in-place; the result must be reassigned.
    adj_t = adj_t.remove_diag()

    # Square the adjacency, then subtract the one-hop edges so that only
    # *strict* two-hop neighbors remain.
    adj_t2 = matmul(adj_t, adj_t)
    adj_t2 = adj_t2.remove_diag()
    adj_t = scipy.sparse.csr_matrix(adj_t.to_scipy())
    adj_t2 = scipy.sparse.csr_matrix(adj_t2.to_scipy())
    adj_t2 = adj_t2 - adj_t
    adj_t2[adj_t2 > 0] = 1
    adj_t2[adj_t2 < 0] = 0

    adj_t = SparseTensor.from_scipy(adj_t)
    adj_t2 = SparseTensor.from_scipy(adj_t2)

    # Symmetric GCN normalization without re-adding self loops.
    adj_t = gcn_norm(adj_t, None, n, add_self_loops=False)
    adj_t2 = gcn_norm(adj_t2, None, n, add_self_loops=False)

    self.adj_t = adj_t.to(edge_index.device)
    self.adj_t2 = adj_t2.to(edge_index.device)
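
# Usage sketch (hypothetical): init_adj is written as a method, but as defined
# above it is a module-level function taking `self`, so for illustration we
# pass a bare holder object supplying the `num_nodes` attribute it expects.
# The toy graph and feature sizes are assumptions, not from this repo.
class _Holder:
    pass


holder = _Holder()
holder.num_nodes = 4
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])  # undirected path 0-1-2-3
init_adj(holder, edge_index)

x = torch.randn(4, 16)
h1 = matmul(holder.adj_t, x)    # normalized one-hop aggregation
h2 = matmul(holder.adj_t2, x)   # normalized strict two-hop aggregation
h = torch.cat([h1, h2], dim=1)  # H2GCN-style concatenation of both hops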
import torch
import torch.nn.functional as F
from torch_scatter import scatter
from torch_sparse import SparseTensor
from torch_geometric.nn.pool.topk_pool import topk
from torch_geometric.utils import add_remaining_self_loops, softmax


def forward(self, x, edge_index, edge_weight=None, batch=None):
    """Pool the graph: score neighborhoods with attention, aggregate
    clusters, keep the top-`ratio` clusters by fitness, and coarsen the
    graph accordingly."""
    N = x.size(0)

    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, fill_value=1, num_nodes=N)

    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    x = x.unsqueeze(-1) if x.dim() == 1 else x

    x_pool = x
    if self.GNN is not None:
        x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index,
                                        edge_weight=edge_weight)

    x_pool_j = x_pool[edge_index[0]]
    x_q = scatter(x_pool_j, edge_index[1], dim=0, reduce='max')
    x_q = self.lin(x_q)[edge_index[1]]

    score = self.att(torch.cat([x_q, x_pool_j], dim=-1)).view(-1)
    score = F.leaky_relu(score, self.negative_slope)
    score = softmax(score, edge_index[1], num_nodes=N)

    # Sample attention coefficients stochastically.
    score = F.dropout(score, p=self.dropout, training=self.training)

    v_j = x[edge_index[0]] * score.view(-1, 1)
    x = scatter(v_j, edge_index[1], dim=0, reduce='add')

    # Cluster selection.
    fitness = self.gnn_score(x, edge_index).sigmoid().view(-1)
    perm = topk(fitness, self.ratio, batch)
    x = x[perm] * fitness[perm].view(-1, 1)
    batch = batch[perm]

    # Graph coarsening.
    row, col = edge_index
    A = SparseTensor(row=row, col=col, value=edge_weight,
                     sparse_sizes=(N, N))
    S = SparseTensor(row=row, col=col, value=score,
                     sparse_sizes=(N, N))
    S = S[:, perm]

    A = S.t() @ A @ S

    if self.add_self_loops:
        A = A.fill_diag(1.)
    else:
        A = A.remove_diag()

    row, col, edge_weight = A.coo()
    edge_index = torch.stack([row, col], dim=0)

    return x, edge_index, edge_weight, batch, perm
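
# Usage sketch: this forward appears to match torch_geometric's ASAPooling;
# assuming that module owns it, one coarsening step on a toy graph might look
# like the following (sizes and the 0.5 ratio are assumptions).
from torch_geometric.nn import ASAPooling

x = torch.randn(6, 16)
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5],
                           [1, 0, 2, 1, 4, 3, 5, 4]])

pool = ASAPooling(in_channels=16, ratio=0.5)
x_c, ei_c, ew_c, batch_c, perm = pool(x, edge_index)
# x_c holds features of the retained clusters; ei_c/ew_c describe the
# coarsened graph; perm indexes the original nodes chosen as cluster centers.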
import torch
from torch_sparse import SparseTensor


def get_adj(row, col, N, asymm_norm=False, set_diag=True, remove_diag=False):
    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))

    if set_diag:
        print('... setting diagonal entries')
        adj = adj.set_diag()
    elif remove_diag:
        print('... removing diagonal entries')
        adj = adj.remove_diag()
    else:
        print('... keeping diag elements as they are')

    if not asymm_norm:
        print('... performing symmetric normalization')
        deg = adj.sum(dim=1).to(torch.float)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)
    else:
        print('... performing asymmetric normalization')
        deg = adj.sum(dim=1).to(torch.float)
        deg_inv = deg.pow(-1.0)
        deg_inv[deg_inv == float('inf')] = 0
        adj = deg_inv.view(-1, 1) * adj

    adj = adj.to_scipy(layout='csr')

    return adj
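
# Usage sketch: build a symmetrically normalized scipy CSR adjacency
# (D^{-1/2} A D^{-1/2}) and precompute propagated features offline. The toy
# graph and the numpy feature matrix are assumptions.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
row, col = edge_index
adj = get_adj(row, col, N=3, asymm_norm=False, set_diag=True)

feats = torch.randn(3, 8).numpy()
feats_1hop = adj @ feats        # one propagation step
feats_2hop = adj @ feats_1hop   # two steps, and so on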
def get_adj(row, col, N, asymm_norm=False, set_diag=True, remove_diag=False):
    # Variant of get_adj above that skips logging and the scipy conversion,
    # returning the normalized adjacency as a torch_sparse SparseTensor.
    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))

    if set_diag:
        adj = adj.set_diag()
    elif remove_diag:
        adj = adj.remove_diag()

    if not asymm_norm:
        deg = adj.sum(dim=1).to(torch.float)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)
    else:
        deg = adj.sum(dim=1).to(torch.float)
        deg_inv = deg.pow(-1.0)
        deg_inv[deg_inv == float('inf')] = 0
        adj = deg_inv.view(-1, 1) * adj

    return adj
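
# Usage sketch: because this variant returns a SparseTensor, propagation can
# stay on-device. Shown here with row normalization D^{-1} A (toy sizes are
# assumptions).
from torch_sparse import matmul

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
row, col = edge_index
adj = get_adj(row, col, N=3, asymm_norm=True, set_diag=False, remove_diag=True)

x = torch.randn(3, 8)
out = matmul(adj, x)  # one step of D^{-1} A propagation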