def test_topk():
    x = torch.Tensor([2, 4, 5, 6, 2, 9])
    batch = torch.tensor([0, 0, 1, 1, 1, 1])

    perm = topk(x, 0.5, batch)

    assert perm.tolist() == [1, 5, 3]
    assert x[perm].tolist() == [4, 9, 6]
    assert batch[perm].tolist() == [0, 1, 1]

    perm = topk(x, 3, batch)

    assert perm.tolist() == [1, 0, 5, 3, 2]
    assert x[perm].tolist() == [4, 2, 9, 6, 5]
    assert batch[perm].tolist() == [0, 0, 1, 1, 1]
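# A minimal sketch (an assumption, not the library implementation) of the `topk`
# helper the test above exercises: per graph in the batch, keep the highest-scoring
# nodes -- ceil(ratio * n) of them for a float ratio, min(k, n) for an integer k --
# and return their global indices in descending-score order.
import math

import torch


def topk_sketch(x, ratio, batch):
    perms = []
    for b in batch.unique(sorted=True):
        idx = (batch == b).nonzero(as_tuple=False).view(-1)
        order = x[idx].argsort(descending=True)
        if isinstance(ratio, int):
            k = min(ratio, idx.numel())
        else:
            k = int(math.ceil(ratio * idx.numel()))
        perms.append(idx[order][:k])
    return torch.cat(perms)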
def forward(self, x, edge_index, edge_attr=None, batch=None):
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    x = x.unsqueeze(-1) if x.dim() == 1 else x

    # SBTL score
    score_s = self.sbtl_layer(x, edge_index).squeeze()
    # FBTL score
    score_f = self.fbtl_layer(x).squeeze()
    # Combine the two scores with the hyperparameter alpha.
    score = score_s * self.alpha + score_f * (1 - self.alpha)
    score = score.unsqueeze(-1) if score.dim() == 0 else score

    if self.min_score is None:
        score = self.non_linearity(score)
    else:
        score = softmax(score, batch)

    perm = topk(score, self.ratio, batch)

    # Fusion
    if self.fusion_flag == 1:
        x = self.fusion(x, edge_index)

    x = x[perm] * score[perm].view(-1, 1)
    x = self.multiplier * x if self.multiplier != 1 else x

    batch = batch[perm]
    edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                       num_nodes=score.size(0))

    return x, edge_index, edge_attr, batch, perm
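# Hedged sketch of the `filter_adj` helper used by the forward passes in this file
# (an assumption about its behaviour, not the library code): keep only the edges
# whose endpoints both survive `perm`, and relabel the surviving nodes to their new
# consecutive indices.
import torch


def filter_adj_sketch(edge_index, edge_attr, perm, num_nodes):
    # Map old node index -> new node index; -1 marks dropped nodes.
    mapping = torch.full((num_nodes, ), -1, dtype=torch.long, device=perm.device)
    mapping[perm] = torch.arange(perm.size(0), device=perm.device)
    row, col = mapping[edge_index[0]], mapping[edge_index[1]]
    mask = (row >= 0) & (col >= 0)
    edge_index = torch.stack([row[mask], col[mask]], dim=0)
    edge_attr = edge_attr[mask] if edge_attr is not None else None
    return edge_index, edge_attr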
def forward(self, x, edge_index, edge_weight=None, batch=None):
    """"""
    N = x.size(0)

    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, fill_value=1, num_nodes=N)

    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    x = x.unsqueeze(-1) if x.dim() == 1 else x

    x_pool = x
    if self.GNN is not None:
        x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index,
                                        edge_weight=edge_weight)

    x_pool_j = x_pool[edge_index[0]]
    x_q = scatter(x_pool_j, edge_index[1], dim=0, reduce='max')
    x_q = self.lin(x_q)[edge_index[1]]

    score = self.att(torch.cat([x_q, x_pool_j], dim=-1)).view(-1)
    score = F.leaky_relu(score, self.negative_slope)
    score = softmax(score, edge_index[1], num_nodes=N)

    # Sample attention coefficients stochastically.
    score = F.dropout(score, p=self.dropout, training=self.training)

    v_j = x[edge_index[0]] * score.view(-1, 1)
    x = scatter(v_j, edge_index[1], dim=0, reduce='add')

    # Cluster selection.
    fitness = self.gnn_score(x, edge_index).sigmoid().view(-1)
    perm = topk(fitness, self.ratio, batch)
    x = x[perm] * fitness[perm].view(-1, 1)
    batch = batch[perm]

    # Graph coarsening.
    row, col = edge_index
    A = SparseTensor(row=row, col=col, value=edge_weight,
                     sparse_sizes=(N, N))
    S = SparseTensor(row=row, col=col, value=score,
                     sparse_sizes=(N, N))
    S = S[:, perm]

    A = S.t() @ A @ S

    if self.add_self_loops:
        A = A.fill_diag(1.)
    else:
        A = A.remove_diag()

    row, col, edge_weight = A.coo()
    edge_index = torch.stack([row, col], dim=0)

    return x, edge_index, edge_weight, batch, perm
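# The coarsening step above computes A' = S^T A S on SparseTensors, where S holds the
# (pooled) cluster-assignment scores. A dense toy equivalent, purely for illustration
# (the values are made up):
import torch

A_toy = torch.tensor([[0., 1.], [1., 0.]])   # 2-node adjacency
S_toy = torch.tensor([[1.], [1.]])           # both nodes assigned to a single cluster
A_coarse = S_toy.t() @ A_toy @ S_toy         # 1x1 coarsened adjacency: [[2.]]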
def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
    """"""
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    attn = x if attn is None else attn
    attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn
    score = self.gnn(attn, edge_index).view(-1)

    # Zero-mean the scores within each instance (graph).
    score = score.view(batch.max() + 1, -1)
    score = score - score.mean(1, keepdim=True)
    # score = score.view(-1)

    if self.min_score is None:
        score = self.nonlinearity(score)
    else:
        score = softmax(score, batch)

    perm = topk(score, self.ratio, batch, self.min_score)
    x = x[perm] * score[perm].view(-1, 1)
    x = self.multiplier * x if self.multiplier != 1 else x

    batch = batch[perm]
    edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                       num_nodes=score.size(0))

    # Note: the last return term is changed to `score`, i.e. the scores for all nodes.
    return x, edge_index, edge_attr, batch, perm, score.view(
        batch.max() + 1, -1)
def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
    """"""
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    attn = x if attn is None else attn
    attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn
    score = self.gnn(attn, edge_index).view(-1)

    if self.min_score is None:
        score = self.nonlinearity(score)
    else:
        score = softmax(score, batch)

    perm = topk(score, self.ratio, batch, self.min_score)
    x = x[perm] * score[perm].view(-1, 1)
    x = self.multiplier * x if self.multiplier != 1 else x

    batch = batch[perm]
    edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                       num_nodes=score.size(0))

    return x, edge_index, edge_attr, batch, perm, score[perm]
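# When `min_score` is set, the scores are normalised with a per-graph softmax. A slow
# but explicit sketch of that grouped softmax for a 1-D score vector (an assumption
# consistent with how `softmax(score, batch)` is used above):
import torch


def grouped_softmax_sketch(score, batch):
    out = torch.empty_like(score)
    for g in batch.unique():
        mask = batch == g
        out[mask] = torch.softmax(score[mask], dim=0)
    return out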
def forward(self, x, edge_index, edge_weight=None, batch=None):
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    # NxF
    x = x.unsqueeze(-1) if x.dim() == 1 else x

    # Add self-loops.
    fill_value = 1
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index=edge_index, edge_weight=edge_weight,
        fill_value=fill_value, num_nodes=num_nodes.sum())

    N = x.size(0)  # total number of nodes in the batch

    # ExF
    x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index,
                                    edge_weight=edge_weight)
    x_pool_j = x_pool[edge_index[1]]
    x_j = x[edge_index[1]]

    # ---Master query formation---
    # NxF
    X_q, _ = scatter_max(x_pool_j, edge_index[0], dim=0)
    # NxF
    M_q = self.lin_q(X_q)
    # ExF
    M_q = M_q[edge_index[0].tolist()]

    score = self.gat_att(torch.cat((M_q, x_pool_j), dim=-1))
    score = F.leaky_relu(score, self.negative_slope)
    score = softmax(score, edge_index[0], num_nodes=num_nodes.sum())

    # Sample attention coefficients stochastically.
    score = F.dropout(score, p=self.dropout_att, training=self.training)

    # ExF
    v_j = x_j * score.view(-1, 1)

    # ---Aggregation---
    # NxF
    out = scatter_add(v_j, edge_index[0], dim=0)

    # ---Cluster selection---
    # Nx1
    fitness = torch.sigmoid(self.gnn_score(x=out, edge_index=edge_index)).view(-1)
    perm = topk(x=fitness, ratio=self.ratio, batch=batch)
    x = out[perm] * fitness[perm].view(-1, 1)

    # ---Maintaining graph connectivity---
    batch = batch[perm]
    edge_index, edge_weight = graph_connectivity(
        device=x.device,
        perm=perm,
        edge_index=edge_index,
        edge_weight=edge_weight,
        score=score,
        ratio=self.ratio,
        batch=batch,
        N=N)

    return x, edge_index, edge_weight, batch, perm
def test_topk():
    x = torch.tensor([2, 4, 5, 6, 2, 9], dtype=torch.float)
    batch = torch.tensor([0, 0, 1, 1, 1, 1])

    perm = topk(x, 0.5, batch)

    assert perm.tolist() == [1, 5, 3]
    assert x[perm].tolist() == [4, 9, 6]
def forward(self, graph, x, batch=None):
    if batch is None:
        batch = graph.edge_index.new_zeros(x.size(0))

    score = self.score_layer(graph, x).squeeze()
    perm = topk(score, self.ratio, batch)
    x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
    batch = batch[perm]
    edge_index, edge_attr = filter_adj(graph.edge_index, graph.edge_weight,
                                       perm, num_nodes=score.size(0))

    return x, edge_index, edge_attr, batch, perm
def forward(self, input, edge_index, edge_attr=None, batch=None):
    if batch is None:
        batch = edge_index.new_zeros(input.size(0))

    score = self.score_layer(input, edge_index).squeeze()
    perm = topk(score, self.ratio, batch)
    input = input[perm] * self.non_linearity(score[perm]).view(-1, 1)
    batch = batch[perm]
    edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                       num_nodes=score.size(0))

    return input, edge_index, edge_attr, batch, perm
def forward(self, x, edge_index, edge_attr=None, batch=None):
    """"""
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    x = x.unsqueeze(-1) if x.dim() == 1 else x

    score = torch.tanh(self.gnn(x, edge_index).view(-1))
    perm = topk(score, self.ratio, batch)
    x = x[perm] * score[perm].view(-1, 1)
    batch = batch[perm]
    edge_index, edge_attr = filter_adj(
        edge_index, edge_attr, perm, num_nodes=score.size(0))

    return x, edge_index, edge_attr, batch, perm
def forward(self, x, edge_index=None, edge_attr=None, batch=None):
    if batch is None:
        batch = self.A.new_zeros(x.size(0))

    # x = x.unsqueeze(-1) if x.dim() == 1 else x
    score = self.score_layer(x, self.A).squeeze()
    perm = topk(score, self.ratio, batch)
    x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
    batch = batch[perm]
    edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                       num_nodes=score.size(0))

    return x, edge_index, edge_attr, batch, perm
def forward(self, x, edge_index, edge_attr, batch):
    score = self.score_layer(x, edge_index).squeeze()
    perm = topk(score, self.ratio, batch)
    x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
    batch = batch[perm]
    edge_index, edge_attr = filter_adj(
        edge_index, edge_attr, perm, num_nodes=score.size(0))

    a = gmp(x, batch)
    m = gap(x, batch)
    return torch.cat([m, a], dim=1)
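# The readout above concatenates graph-level mean (`gap`) and max (`gmp`) pooling. A
# slow but explicit sketch of that readout, assuming `gap`/`gmp` are the usual global
# mean/max pooling helpers:
import torch


def readout_sketch(x, batch):
    num_graphs = int(batch.max()) + 1
    mean = torch.stack([x[batch == g].mean(dim=0) for g in range(num_graphs)])
    maxv = torch.stack([x[batch == g].max(dim=0).values for g in range(num_graphs)])
    return torch.cat([mean, maxv], dim=1)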
def forward(self, x, edge_index, attention, batch=None, direction=1):
    e_batch = edge_index[0]
    degree = torch.bincount(e_batch)
    node_scores = direction * g_pooling(attention, e_batch).view(-1)
    node_scores = node_scores.mul(degree)

    perm = topk(node_scores, self.rate, batch)
    edge_index, _ = self.augment_adj(edge_index, None, x.size(0))
    edge_index, _ = filter_adj(edge_index, None, perm,
                               num_nodes=node_scores.size(0))

    x = x[perm]
    batch = batch[perm]
    return x, edge_index, batch, perm.view((1, -1))
def forward(self, x, edge_index, edge_attr=None, batch=None):
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    num_node = x.size(0)
    k = F.relu(self.lin_2(x))

    A = SparseTensor.from_edge_index(edge_index=edge_index,
                                     edge_attr=edge_attr,
                                     sparse_sizes=(num_node, num_node))
    I = SparseTensor.eye(num_node, device=self.args.device)
    A_wave = fill_diag(A, 1)

    s = A_wave @ k
    score = s.squeeze()
    perm = topk(score, self.ratio, batch)

    A = self.norm(A)
    K_neighbor = A * k.T
    x_neighbor = K_neighbor @ x

    # ----modified
    deg = sum(A, dim=1)
    deg_inv = deg.pow_(-1)
    deg_inv.masked_fill_(deg_inv == float('inf'), 0.)
    x_neighbor = x_neighbor * deg_inv.view(1, -1).T
    # ----

    x_self = x * k
    x = x_neighbor * (1 - self.args.combine_ratio) + x_self * self.args.combine_ratio
    x = x[perm]
    batch = batch[perm]
    edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                       num_nodes=s.size(0))

    return x, edge_index, edge_attr, batch, perm
def forward(self, x, edge_index, edge_attr=None, batch=None):
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    # Iterative fusion of the scores.
    for i in range(3):
        score_ = self.score_layer(x, edge_index).squeeze()
        if i > 0:
            score = score * score_ + score
        else:
            score = score_

    perm = topk(score, self.ratio, batch)
    x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
    batch = batch[perm]
    edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                       num_nodes=score.size(0))

    return x, edge_index, edge_attr, batch, perm
def forward(self, x, edge_index, closeness, degree, edge_attr=None, batch=None):
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    x = x.unsqueeze(-1) if x.dim() == 1 else x

    if self.layer == 1:
        score = torch.relu(self.gnn(x, edge_index).view(-1))
    elif self.layer == 2:
        score = torch.relu(self.gnn1(x, edge_index))
        score = torch.relu(self.gnn2(score, edge_index).view(-1))
    elif self.layer == 3:
        score = torch.relu(self.gnn1(x, edge_index))
        score = torch.relu(self.gnn2(score, edge_index))
        score = torch.relu(self.gnn3(score, edge_index).view(-1))

    # Centrality adjustment
    closeness = closeness * self.weight_closeness
    degree = degree * self.weight_degree
    centrality = closeness + degree
    if self.bias is not None:
        centrality += self.bias

    score = score * self.weight_score
    score = score + centrality
    score = F.relu(score)

    perm = topk(score, self.ratio, batch)
    tmp1 = x[perm]
    tmp2 = score[perm]
    x = tmp1 * tmp2.view(-1, 1)
    batch = batch[perm]

    return x, perm, batch
def forward(self, x, x_score, edge_index, edge_attr, batch=None):
    n = x.shape[0]
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    # Graph Pooling
    perm = topk(x_score.view(-1), self.ratio, batch)
    induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm,
                                                       num_nodes=x.shape[0])

    # isolate_mask = (perm.view(-1, 1) == induced_edge_index.view(-1).unique()).sum(dim=1) > 0
    # perm = perm[isolate_mask]
    # induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=x.shape[0])

    if edge_index.shape[1] > 0:
        row, col = edge_index
        S = torch.exp(torch.norm(x[row] - x[col], dim=1))
        th = torch.sort(S, descending=True).values[int(self.edge_ratio * (len(S) - 1))]
        select = (S > th)
        edge_index = edge_index[:, select]

    x = x[perm]
    batch = batch[perm]
    return x, induced_edge_index, induced_edge_attr, batch


# ############ add structure learning
# class LookHopsPool(torch.nn.Module):
#     def __init__(self, k, out_channels, ratio=0.8, edge_ratio=0.8):
#         super(LookHopsPool, self).__init__()
#         self.k = k
#         self.ratio = ratio
#         self.edge_ratio = edge_ratio
#         self.idx = 1
#         self.node_att = nn.Linear(k * out_channels, 1)
#         self.edge_att = nn.Linear(k * out_channels * 2, 1)
#         self.alpha = nn.Parameter(torch.tensor(1.0))
#
#     def forward(self, x, neighbor_info, edge_index, edge_dis, batch):
#         n = x.shape[0]
#         if batch is None:
#             batch = edge_index.new_zeros(x.size(0))
#         x_score = self.node_att(neighbor_info)
#         x = x * x_score
#         # Graph Pooling
#         perm = topk(x_score.view(-1), self.ratio, batch)
#         induced_edge_index, induced_edge_dis = filter_adj(edge_index, edge_dis, perm, num_nodes=x.shape[0])
#         x = x[perm]
#         batch = batch[perm]
#         # neighbor_info = neighbor_info[perm]
#
#         # row, col = induced_edge_index
#         induced_edge_weight = torch.exp(-self.alpha * induced_edge_dis)
#         if torch.isnan(induced_edge_weight).any():
#             print('NO')
#         # if edge_index.shape[1] > 0:
#         #     row, col = edge_index
#         #     S = torch.exp(torch.norm(x[row] - x[col], dim=1))
#         #     th = torch.sort(S, descending=True).values[int(self.edge_ratio * (len(S) - 1))]
#         #     select = (S > th)
#         #     edge_index = edge_index[:, select]
#         return x, induced_edge_index, induced_edge_weight, induced_edge_dis, batch
def forward(self, x, edge_index, edge_attr=None, batch=None):
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    x_information_score = self.calc_information_score(x, edge_index)
    score = torch.sum(torch.abs(x_information_score), dim=1)

    # Graph Pooling
    original_x = x
    perm = topk(score, self.ratio, batch)
    x = x[perm]
    batch = batch[perm]
    induced_edge_index, induced_edge_attr = filter_adj(
        edge_index, edge_attr, perm, num_nodes=score.size(0))

    # Discard structure learning layer, directly return
    if self.sl is False:
        return x, induced_edge_index, induced_edge_attr, batch, perm

    # Structure Learning
    if self.sample:
        # A fast mode for large graphs.
        # In large graphs, learning the possible edge weights between each pair of
        # nodes is time-consuming. To accelerate this process, we sample each node's
        # k-hop neighbors and then learn the edge weights between them.
        k_hop = 3
        if edge_attr is None:
            edge_attr = torch.ones((edge_index.size(1), ), dtype=torch.float,
                                   device=edge_index.device)

        hop_data = Data(x=original_x, edge_index=edge_index, edge_attr=edge_attr)
        for _ in range(k_hop - 1):
            hop_data = self.neighbor_augment(hop_data)
        hop_edge_index = hop_data.edge_index
        hop_edge_attr = hop_data.edge_attr
        new_edge_index, new_edge_attr = filter_adj(hop_edge_index, hop_edge_attr,
                                                   perm, num_nodes=score.size(0))

        row, col = new_edge_index
        if self.att.fast is not None:
            weights = (torch.cat([x[row], x[col]], dim=1) * self.att.fast).sum(dim=-1)
        else:
            tmps = torch.cat([x[row], x[col]], dim=1)
            # assert (tmps.shape[0]==self.att.)
            weights = (tmps * self.att).sum(dim=-1)
        weights = F.leaky_relu(weights, self.negative_slop) + new_edge_attr * self.lamb

        adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
        adj[row, col] = weights
        new_edge_index, weights = dense_to_sparse(adj)
        row, col = new_edge_index
        if self.sparse:
            new_edge_attr = self.sparse_attention(weights, row)
        else:
            new_edge_attr = softmax(weights, row, x.size(0))
        # Filter out zero-weight edges.
        adj[row, col] = new_edge_attr
        new_edge_index, new_edge_attr = dense_to_sparse(adj)
        # Release GPU memory.
        del adj
        torch.cuda.empty_cache()
    else:
        # Learn the possible edge weights between each pair of nodes in the pooled
        # subgraph; relatively slower.
        if edge_attr is None:
            induced_edge_attr = torch.ones((induced_edge_index.size(1), ),
                                           dtype=x.dtype,
                                           device=induced_edge_index.device)
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        shift_cum_num_nodes = torch.cat(
            [num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
        cum_num_nodes = num_nodes.cumsum(dim=0)

        adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
        # Construct a batch of fully connected graphs in block-diagonal matrix format.
        for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
            adj[idx_i:idx_j, idx_i:idx_j] = 1.0
        new_edge_index, _ = dense_to_sparse(adj)
        row, col = new_edge_index

        if self.att.fast is not None:
            weights = (torch.cat([x[row], x[col]], dim=1) * self.att.fast).sum(dim=-1)
        else:
            weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
        weights = F.leaky_relu(weights, self.negative_slop)
        adj[row, col] = weights

        induced_row, induced_col = induced_edge_index
        adj[induced_row, induced_col] += induced_edge_attr * self.lamb
        weights = adj[row, col]
        if self.sparse:
            new_edge_attr = self.sparse_attention(weights, row)
        else:
            new_edge_attr = softmax(weights, row, x.size(0))
        # Filter out zero-weight edges.
        adj[row, col] = new_edge_attr
        new_edge_index, new_edge_attr = dense_to_sparse(adj)
        # Release GPU memory.
        del adj
        torch.cuda.empty_cache()

    return x, new_edge_index, new_edge_attr, batch, perm
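# Both structure-learning branches above round-trip through `dense_to_sparse`. A
# hedged sketch of that utility for a 2-D adjacency matrix (an assumption about its
# behaviour, not the library code): return the indices and values of non-zero entries.
import torch


def dense_to_sparse_sketch(adj):
    index = adj.nonzero(as_tuple=False).t()
    values = adj[index[0], index[1]]
    return index, values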
def forward(self, x, edge_index, edge_attr, batch, h, neg_num, samp_bias1, samp_bias2):
    """
    :param x: node features after convolution
    :param edge_index:
    :param edge_attr:
    :param batch:
    :param h: node features before convolution
    :param neg_num:
    :param samp_bias1:
    :param samp_bias2:
    :return:
    """
    # I(h_i; x_i)
    res_mi_pos, res_mi_neg = self.disc1(x, h,
                                        process.negative_sampling_tg(batch, neg_num),
                                        samp_bias1, samp_bias2)
    mi_jsd_score = process.sp_func(res_mi_pos) + process.sp_func(torch.mean(res_mi_neg, dim=1))

    # Graph Pooling
    original_x = x
    perm = topk(mi_jsd_score, self.ratio, batch)
    x = x[perm]
    batch = batch[perm]
    induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm,
                                                       num_nodes=mi_jsd_score.size(0))

    # Discard structure learning layer, directly return
    if self.sl is False:
        return x, induced_edge_index, induced_edge_attr, batch

    # Structure Learning
    if self.sample:
        # A fast mode for large graphs.
        # In large graphs, learning the possible edge weights between each pair of
        # nodes is time-consuming. To accelerate this process, we sample each node's
        # k-hop neighbors and then learn the edge weights between them.
        k_hop = 3
        if edge_attr is None:
            edge_attr = torch.ones((edge_index.size(1),), dtype=torch.float,
                                   device=edge_index.device)

        hop_data = Data(x=original_x, edge_index=edge_index, edge_attr=edge_attr)
        for _ in range(k_hop - 1):
            hop_data = self.neighbor_augment(hop_data)
        hop_edge_index = hop_data.edge_index
        hop_edge_attr = hop_data.edge_attr
        new_edge_index, new_edge_attr = filter_adj(hop_edge_index, hop_edge_attr, perm,
                                                   num_nodes=mi_jsd_score.size(0))
        new_edge_index, new_edge_attr = add_remaining_self_loops(new_edge_index,
                                                                 new_edge_attr,
                                                                 0, x.size(0))

        row, col = new_edge_index
        weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
        weights = F.leaky_relu(weights, self.negative_slop) + new_edge_attr * self.lamb

        adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
        adj[row, col] = weights
        new_edge_index, weights = dense_to_sparse(adj)
        row, col = new_edge_index
        if self.sparse:
            new_edge_attr = self.sparse_attention(weights, row)
        else:
            new_edge_attr = softmax(weights, row, x.size(0))
        # Filter out zero-weight edges.
        adj[row, col] = new_edge_attr
        new_edge_index, new_edge_attr = dense_to_sparse(adj)
        # Release GPU memory.
        del adj
        torch.cuda.empty_cache()
    else:
        # Learn the possible edge weights between each pair of nodes in the pooled
        # subgraph; relatively slower.
        if edge_attr is None:
            induced_edge_attr = torch.ones((induced_edge_index.size(1),), dtype=x.dtype,
                                           device=induced_edge_index.device)
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        shift_cum_num_nodes = torch.cat([num_nodes.new_zeros(1),
                                         num_nodes.cumsum(dim=0)[:-1]], dim=0)
        cum_num_nodes = num_nodes.cumsum(dim=0)

        adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
        # Construct a batch of fully connected graphs in block-diagonal matrix format.
        for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
            adj[idx_i:idx_j, idx_i:idx_j] = 1.0
        new_edge_index, _ = dense_to_sparse(adj)
        row, col = new_edge_index

        weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
        weights = F.leaky_relu(weights, self.negative_slop)
        adj[row, col] = weights

        induced_row, induced_col = induced_edge_index
        adj[induced_row, induced_col] += induced_edge_attr * self.lamb
        weights = adj[row, col]
        if self.sparse:
            new_edge_attr = self.sparse_attention(weights, row)
        else:
            new_edge_attr = softmax(weights, row, x.size(0))
        # Filter out zero-weight edges.
        adj[row, col] = new_edge_attr
        new_edge_index, new_edge_attr = dense_to_sparse(adj)
        # Release GPU memory.
        del adj
        torch.cuda.empty_cache()

    return x, new_edge_index, new_edge_attr, batch