def test_topk():
    x = torch.tensor([2, 4, 5, 6, 2, 9], dtype=torch.float)
    batch = torch.tensor([0, 0, 1, 1, 1, 1])

    perm = topk(x, 0.5, batch)

    assert perm.tolist() == [1, 5, 3]
    assert x[perm].tolist() == [4, 9, 6]
    assert batch[perm].tolist() == [0, 1, 1]

    perm = topk(x, 3, batch)

    assert perm.tolist() == [1, 0, 5, 3, 2]
    assert x[perm].tolist() == [4, 2, 9, 6, 5]
    assert batch[perm].tolist() == [0, 0, 1, 1, 1]
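
For reference, the behaviour asserted above can be reproduced with a small illustrative sketch of per-graph top-k selection (not the torch_geometric implementation, just a readable equivalent): a float ratio keeps ceil(ratio * n) nodes per graph, an integer ratio keeps min(ratio, n) nodes per graph, and the returned indices are grouped by graph with the highest scores first.

# Illustrative sketch only; `topk` in the examples below is the helper from
# torch_geometric's TopK pooling utilities.
import math
import torch

def topk_sketch(x, ratio, batch):
    perms = []
    for b in batch.unique(sorted=True):
        idx = (batch == b).nonzero(as_tuple=False).view(-1)
        if isinstance(ratio, int):
            k = min(ratio, idx.numel())              # integer ratio: fixed number of nodes
        else:
            k = int(math.ceil(ratio * idx.numel()))  # float ratio: fraction of nodes
        order = x[idx].argsort(descending=True)[:k]  # highest scores first
        perms.append(idx[order])
    return torch.cat(perms)

# With the tensors from the test above:
# topk_sketch(x, 0.5, batch) -> [1, 5, 3]
# topk_sketch(x, 3, batch)   -> [1, 0, 5, 3, 2]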
Example #2
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        x = x.unsqueeze(-1) if x.dim() == 1 else x

        # SBTL score
        score_s = self.sbtl_layer(x, edge_index).squeeze()
        # FBTL score
        score_f = self.fbtl_layer(x).squeeze()
        # blend the two scores with the hyperparameter alpha
        score = score_s * self.alpha + score_f * (1 - self.alpha)

        score = score.unsqueeze(-1) if score.dim() == 0 else score

        if self.min_score is None:
            score = self.non_linearity(score)
        else:
            score = softmax(score, batch)
        perm = topk(score, self.ratio, batch)

        # fusion
        if self.fusion_flag == 1:
            x = self.fusion(x, edge_index)

        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm
Example #3
    def forward(self, x, edge_index, edge_weight=None, batch=None):
        """"""
        N = x.size(0)

        edge_index, edge_weight = add_remaining_self_loops(edge_index,
                                                           edge_weight,
                                                           fill_value=1,
                                                           num_nodes=N)

        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        x_pool = x
        if self.GNN is not None:
            x_pool = self.gnn_intra_cluster(x=x,
                                            edge_index=edge_index,
                                            edge_weight=edge_weight)

        x_pool_j = x_pool[edge_index[0]]
        x_q = scatter(x_pool_j, edge_index[1], dim=0, reduce='max')
        x_q = self.lin(x_q)[edge_index[1]]

        score = self.att(torch.cat([x_q, x_pool_j], dim=-1)).view(-1)
        score = F.leaky_relu(score, self.negative_slope)
        score = softmax(score, edge_index[1], num_nodes=N)

        # Sample attention coefficients stochastically.
        score = F.dropout(score, p=self.dropout, training=self.training)

        v_j = x[edge_index[0]] * score.view(-1, 1)
        x = scatter(v_j, edge_index[1], dim=0, reduce='add')

        # Cluster selection.
        fitness = self.gnn_score(x, edge_index).sigmoid().view(-1)
        perm = topk(fitness, self.ratio, batch)
        x = x[perm] * fitness[perm].view(-1, 1)
        batch = batch[perm]

        # Graph coarsening.
        row, col = edge_index
        A = SparseTensor(row=row,
                         col=col,
                         value=edge_weight,
                         sparse_sizes=(N, N))
        S = SparseTensor(row=row, col=col, value=score, sparse_sizes=(N, N))
        S = S[:, perm]

        A = S.t() @ A @ S

        if self.add_self_loops:
            A = A.fill_diag(1.)
        else:
            A = A.remove_diag()

        row, col, edge_weight = A.coo()
        edge_index = torch.stack([row, col], dim=0)

        return x, edge_index, edge_weight, batch, perm
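
The graph coarsening step above contracts the graph onto the selected cluster centres: with A the (self-looped) adjacency and S the column-sliced soft-assignment matrix built from the attention scores, the pooled adjacency is S.t() @ A @ S, exactly as in the line above. A minimal dense sketch of the same computation on a toy graph (illustrative values only; the example above uses SparseTensor instead of dense tensors):

import torch

# 4 original nodes; nodes 1 and 2 are kept as cluster centres (hypothetical perm).
A = torch.tensor([[1., 1., 0., 0.],
                  [1., 1., 1., 0.],
                  [0., 1., 1., 1.],
                  [0., 0., 1., 1.]])              # adjacency with self-loops
S_full = torch.tensor([[0.7, 0.1, 0.0, 0.0],
                       [0.3, 0.5, 0.2, 0.0],
                       [0.0, 0.4, 0.5, 0.3],
                       [0.0, 0.0, 0.3, 0.7]])     # soft assignments; each column sums to 1
perm = torch.tensor([1, 2])                       # clusters selected by topk
S = S_full[:, perm]                               # N x K, mirrors S[:, perm] above
A_pool = S.t() @ A @ S                            # K x K coarsened adjacency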
Example #4
    def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        attn = x if attn is None else attn
        attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn
        score = self.gnn(attn, edge_index).view(-1)

        # zero-mean the scores within each graph
        score = score.view(batch.max() + 1, -1)
        score = score - score.mean(1, keepdim=True)
        score = score.view(-1)

        if self.min_score is None:
            score = self.nonlinearity(score)
        else:
            score = softmax(score, batch)

        perm = topk(score, self.ratio, batch, self.min_score)
        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        # changed the last return value: score holds the scores of all nodes, one row per graph
        return x, edge_index, edge_attr, batch, perm, score.view(
            batch.max() + 1, -1)
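
The zero-mean step in the example above reshapes the scores to (num_graphs, -1), which only works when every graph in the batch has the same number of nodes. A size-agnostic alternative could be sketched with a scatter mean (a hypothetical alternative using torch_scatter, not what the original layer does):

from torch_scatter import scatter_mean

def zero_mean_per_graph(score, batch):
    # subtract each graph's mean score from its own nodes, regardless of graph size
    mean = scatter_mean(score, batch, dim=0)   # one mean per graph
    return score - mean[batch]                 # broadcast back to the nodes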
Example #5
    def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        attn = x if attn is None else attn
        attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn
        score = self.gnn(attn, edge_index).view(-1)

        if self.min_score is None:
            score = self.nonlinearity(score)
        else:
            score = softmax(score, batch)

        perm = topk(score, self.ratio, batch, self.min_score)
        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm, score[perm]
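
The forward pass above closely mirrors torch_geometric's built-in SAGPooling layer. A minimal usage sketch of that layer on a made-up toy graph (assumes a standard torch_geometric installation):

import torch
from torch_geometric.nn import SAGPooling

pool = SAGPooling(in_channels=16, ratio=0.5)
x = torch.randn(6, 16)                              # 6 nodes, 16 features
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])     # a 6-cycle
x, edge_index, _, batch, perm, score = pool(x, edge_index)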
Example #6
    def forward(self, x, edge_index, edge_weight=None, batch=None):
        
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        # NxF
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        # Add Self Loops
        fill_value = 1
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index=edge_index, edge_weight=edge_weight,
            fill_value=fill_value, num_nodes=num_nodes.sum())

        N = x.size(0) # total num of nodes in batch

        # ExF
        x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index, edge_weight=edge_weight)
        x_pool_j = x_pool[edge_index[1]]
        x_j = x[edge_index[1]]
        
        #---Master query formation---
        # NxF
        X_q, _ = scatter_max(x_pool_j, edge_index[0], dim=0)
        # NxF
        M_q = self.lin_q(X_q)    
        # ExF
        M_q = M_q[edge_index[0]]

        score = self.gat_att(torch.cat((M_q, x_pool_j), dim=-1))
        score = F.leaky_relu(score, self.negative_slope)
        score = softmax(score, edge_index[0], num_nodes=num_nodes.sum())

        # Sample attention coefficients stochastically.
        score = F.dropout(score, p=self.dropout_att, training=self.training)
        # ExF
        v_j = x_j * score.view(-1, 1)
        #---Aggregation---
        # NxF
        out = scatter_add(v_j, edge_index[0], dim=0)
        
        #---Cluster Selection
        # Nx1
        fitness = torch.sigmoid(self.gnn_score(x=out, edge_index=edge_index)).view(-1)
        perm = topk(x=fitness, ratio=self.ratio, batch=batch)
        x = out[perm] * fitness[perm].view(-1, 1)
        
        #---Maintaining Graph Connectivity
        batch = batch[perm]
        edge_index, edge_weight = graph_connectivity(
            device = x.device,
            perm=perm,
            edge_index=edge_index,
            edge_weight=edge_weight,
            score=score,
            ratio=self.ratio,
            batch=batch,
            N=N)
 
        
        return x, edge_index, edge_weight, batch, perm
Example #7
def test_topk():
    x = torch.tensor([2, 4, 5, 6, 2, 9], dtype=torch.float)
    batch = torch.tensor([0, 0, 1, 1, 1, 1])

    perm = topk(x, 0.5, batch)

    assert perm.tolist() == [1, 5, 3]
    assert x[perm].tolist() == [4, 9, 6]
Example #8
    def forward(self, graph, x, batch=None):
        if batch is None:
            batch = graph.edge_index.new_zeros(x.size(0))
        score = self.score_layer(graph, x).squeeze()
        perm = topk(score, self.ratio, batch)
        x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(graph.edge_index,
                                           graph.edge_weight,
                                           perm,
                                           num_nodes=score.size(0))
        return x, edge_index, edge_attr, batch, perm
Example #9
    def forward(self, input, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(input.size(0))
        score = self.score_layer(input, edge_index).squeeze()
        perm = topk(score, self.ratio, batch)
        input = input[perm] * self.non_linearity(score[perm]).view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))
        return input, edge_index, edge_attr, batch, perm
Example #10
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        score = torch.tanh(self.gnn(x, edge_index).view(-1))
        perm = topk(score, self.ratio, batch)
        x = x[perm] * score[perm].view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm
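
Most of the examples here follow the same pattern: score every node, call topk to decide which nodes to keep per graph, gate the kept features by their scores, and remap the surviving edges with filter_adj. A minimal end-to-end sketch of that pattern on a toy graph (random scores stand in for a learned GNN; the import path below matches older torch_geometric releases and may differ in newer ones):

import torch
from torch_geometric.nn.pool.topk_pool import topk, filter_adj

x = torch.randn(5, 8)                               # 5 nodes, 8 features
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0]])        # a 5-cycle
batch = torch.zeros(5, dtype=torch.long)            # a single graph

score = torch.rand(5)                               # stand-in for a learned node score
perm = topk(score, 0.5, batch)                      # keep roughly half of the nodes
x = x[perm] * score[perm].view(-1, 1)               # gate kept features by their scores
edge_index, _ = filter_adj(edge_index, None, perm, num_nodes=5)
batch = batch[perm]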
Example #11
    def forward(self, x, edge_index=None, edge_attr=None, batch=None):
        if batch is None:
            batch = self.A.new_zeros(x.size(0))
        #x = x.unsqueeze(-1) if x.dim() == 1 else x
        score = self.score_layer(x, self.A).squeeze()

        perm = topk(score, self.ratio, batch)
        x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm
Example #12
    def forward(self, x, edge_index, edge_attr, batch):
        score = self.score_layer(x, edge_index).squeeze()

        perm = topk(score, self.ratio, batch)

        x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)

        batch = batch[perm]

        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        a = gmp(x, batch)  # global max pooling readout
        m = gap(x, batch)  # global mean pooling readout

        return torch.cat([m, a], dim=1)
Example #13
    def forward(self, x, edge_index, attention, batch=None, direction=1):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        e_batch = edge_index[0]
        degree = torch.bincount(e_batch)
        node_scores = direction * g_pooling(attention, e_batch).view(-1)
        node_scores = node_scores.mul(degree)

        perm = topk(node_scores, self.rate, batch)

        edge_index, _ = self.augment_adj(edge_index, None, x.size(0))
        edge_index, _ = filter_adj(edge_index,
                                   None,
                                   perm,
                                   num_nodes=node_scores.size(0))
        x = x[perm]
        batch = batch[perm]

        return x, edge_index, batch, perm.view((1, -1))
Example #14
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        num_node = x.size(0)

        k = F.relu(self.lin_2(x))

        A = SparseTensor.from_edge_index(edge_index=edge_index,
                                         edge_attr=edge_attr,
                                         sparse_sizes=(num_node, num_node))
        I = SparseTensor.eye(num_node, device=self.args.device)
        A_wave = fill_diag(A, 1)

        s = A_wave @ k

        score = s.squeeze()
        perm = topk(score, self.ratio, batch)

        A = self.norm(A)

        K_neighbor = A * k.T
        x_neighbor = K_neighbor @ x

        # ----modified
        # `sum` below is expected to be the torch_sparse reduction
        # (e.g. from torch_sparse import sum), not the Python builtin
        deg = sum(A, dim=1)
        deg_inv = deg.pow_(-1)
        deg_inv.masked_fill_(deg_inv == float('inf'), 0.)
        x_neighbor = x_neighbor * deg_inv.view(1, -1).T
        # ----
        x_self = x * k

        x = x_neighbor * (
            1 - self.args.combine_ratio) + x_self * self.args.combine_ratio

        x = x[perm]
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=s.size(0))

        return x, edge_index, edge_attr, batch, perm
Example #15
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        # iterative fusion of the node scores
        for i in range(3):
            score_ = self.score_layer(x, edge_index).squeeze()
            if i > 0:
                score = score * score_ + score
            else:
                score = score_

        perm = topk(score, self.ratio, batch)
        x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm
Example #16
    def forward(self,
                x,
                edge_index,
                closeness,
                degree,
                edge_attr=None,
                batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        if self.layer == 1:
            score = torch.relu(self.gnn(x, edge_index).view(-1))
        elif self.layer == 2:
            score = torch.relu(self.gnn1(x, edge_index))
            score = torch.relu(self.gnn2(score, edge_index).view(-1))
        elif self.layer == 3:
            score = torch.relu(self.gnn1(x, edge_index))
            score = torch.relu(self.gnn2(score, edge_index))
            score = torch.relu(self.gnn3(score, edge_index).view(-1))
        # centrality adjustment
        closeness = closeness * self.weight_closeness
        degree = degree * self.weight_degree
        centrality = closeness + degree
        if self.bias is not None:
            centrality += self.bias

        score = score * self.weight_score
        score = score + centrality
        score = F.relu(score)

        perm = topk(score, self.ratio, batch)
        tmp1 = x[perm]
        tmp2 = score[perm]
        x = tmp1 * tmp2.view(-1, 1)
        batch = batch[perm]

        return x, perm, batch
Example #17
    def forward(self, x, x_score, edge_index, edge_attr, batch=None):
        n = x.shape[0]

        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        # Graph Pooling
        perm = topk(x_score.view(-1), self.ratio, batch)
        
        induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=x.shape[0])
        # isolate_mask=(perm.view(-1,1)==induced_edge_index.view(-1).unique()).sum(dim=1)>0
        # perm = perm[isolate_mask]
        # induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=x.shape[0])

        if edge_index.shape[1] > 0:
            # keep only the top edge_ratio fraction of edges by feature distance
            # (note: this thresholded edge_index is not part of the values returned below)
            row, col = edge_index
            S = torch.exp(torch.norm(x[row] - x[col], dim=1))
            th = torch.sort(S, descending=True).values[int(self.edge_ratio * (len(S) - 1))]
            select = (S > th)
            edge_index = edge_index[:, select]
        
        x = x[perm]
        batch = batch[perm]
        return x, induced_edge_index, induced_edge_attr, batch


# ############ add structure learning
# class LookHopsPool(torch.nn.Module):
#     def __init__(self, k, out_channels,ratio=0.8,edge_ratio=0.8):
#         super(LookHopsPool, self).__init__()
#         self.k=k
#         self.ratio = ratio
#         self.edge_ratio=edge_ratio
#         self.idx=1
#         self.node_att = nn.Linear(k*out_channels,1)
#         self.edge_att = nn.Linear(k*out_channels*2,1)
#         self.alpha = nn.Parameter(torch.tensor(1.0))

#     def forward(self, x, neighbor_info, edge_index, edge_dis, batch):
#         n=x.shape[0]

#         if batch is None:
#             batch = edge_index.new_zeros(x.size(0))
        
#         x_score = self.node_att(neighbor_info)
#         x=x*x_score

#         # Graph Pooling
#         perm = topk(x_score.view(-1), self.ratio, batch)
#         induced_edge_index, induced_edge_dis = filter_adj(edge_index, edge_dis, perm, num_nodes=x.shape[0])
#         x = x[perm]
#         batch = batch[perm]
#         # neighbor_info = neighbor_info[perm]
#         # row,col=induced_edge_index
#         induced_edge_weight = torch.exp(-self.alpha*induced_edge_dis)
#         if torch.isnan(induced_edge_weight).any():
#             print('NO')

#         # if edge_index.shape[1]>0:
#         #     row,col=edge_index
#         #     S=torch.exp(torch.norm(x[row]-x[col],dim=1))
#         #     th=torch.sort(S,descending=True).values[int(self.edge_ratio*(len(S)-1))]
#         #     select=(S>th)
#         #     edge_index=edge_index[:,select]
        
#         return x, induced_edge_index, induced_edge_weight,induced_edge_dis, batch
Example #18
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x_information_score = self.calc_information_score(x, edge_index)
        score = torch.sum(torch.abs(x_information_score), dim=1)

        # Graph Pooling
        original_x = x
        perm = topk(score, self.ratio, batch)
        x = x[perm]
        batch = batch[perm]
        induced_edge_index, induced_edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        # Structure learning disabled: return the pooled graph directly
        if self.sl is False:
            return x, induced_edge_index, induced_edge_attr, batch, perm

        # Structure Learning
        if self.sample:
            # A fast mode for large graphs.
            # In large graphs, learning the possible edge weights between every pair of nodes is
            # time-consuming. To accelerate this, we sample each node's k-hop neighbors and then
            # learn the edge weights between them.
            k_hop = 3
            if edge_attr is None:
                edge_attr = torch.ones((edge_index.size(1), ),
                                       dtype=torch.float,
                                       device=edge_index.device)

            hop_data = Data(x=original_x,
                            edge_index=edge_index,
                            edge_attr=edge_attr)
            for _ in range(k_hop - 1):
                hop_data = self.neighbor_augment(hop_data)
            hop_edge_index = hop_data.edge_index
            hop_edge_attr = hop_data.edge_attr
            new_edge_index, new_edge_attr = filter_adj(hop_edge_index,
                                                       hop_edge_attr,
                                                       perm,
                                                       num_nodes=score.size(0))

            row, col = new_edge_index

            if self.att.fast is not None:
                weights = (torch.cat([x[row], x[col]], dim=1) *
                           self.att.fast).sum(dim=-1)
            else:
                tmps = torch.cat([x[row], x[col]], dim=1)
                # assert (tmps.shape[0]==self.att.)
                weights = (tmps * self.att).sum(dim=-1)

            weights = F.leaky_relu(
                weights, self.negative_slop) + new_edge_attr * self.lamb
            adj = torch.zeros((x.size(0), x.size(0)),
                              dtype=torch.float,
                              device=x.device)
            adj[row, col] = weights
            new_edge_index, weights = dense_to_sparse(adj)
            row, col = new_edge_index
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj
            torch.cuda.empty_cache()
        else:
            # Learn the possible edge weights between every pair of nodes in the pooled subgraph (relatively slower).
            if edge_attr is None:
                induced_edge_attr = torch.ones(
                    (induced_edge_index.size(1), ),
                    dtype=x.dtype,
                    device=induced_edge_index.device)
            num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
            shift_cum_num_nodes = torch.cat(
                [num_nodes.new_zeros(1),
                 num_nodes.cumsum(dim=0)[:-1]], dim=0)
            cum_num_nodes = num_nodes.cumsum(dim=0)
            adj = torch.zeros((x.size(0), x.size(0)),
                              dtype=torch.float,
                              device=x.device)
            # Construct a batched fully-connected graph in block-diagonal matrix format
            for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
                adj[idx_i:idx_j, idx_i:idx_j] = 1.0
            new_edge_index, _ = dense_to_sparse(adj)
            row, col = new_edge_index
            if self.att.fast is not None:
                weights = (torch.cat([x[row], x[col]], dim=1) *
                           self.att.fast).sum(dim=-1)
            else:
                weights = (torch.cat([x[row], x[col]], dim=1) *
                           self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop)
            adj[row, col] = weights
            induced_row, induced_col = induced_edge_index

            adj[induced_row, induced_col] += induced_edge_attr * self.lamb
            weights = adj[row, col]
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj
            torch.cuda.empty_cache()

        return x, new_edge_index, new_edge_attr, batch, perm
Example #19
    def forward(self, x, edge_index, edge_attr, batch, h, neg_num, samp_bias1, samp_bias2):
        """

        :param x: node feature after convolution
        :param edge_index:
        :param edge_attr:
        :param batch:
        :param h: node feature before convolution
        :param neg_num:
        :param samp_bias1:
        :param samp_bias2:
        :return:
        """

        # I(h_i; x_i)
        res_mi_pos, res_mi_neg = self.disc1(x, h, process.negative_sampling_tg(batch, neg_num), samp_bias1, samp_bias2)
        mi_jsd_score = process.sp_func(res_mi_pos) + process.sp_func(torch.mean(res_mi_neg, dim=1))

        # Graph Pooling
        original_x = x
        perm = topk(mi_jsd_score, self.ratio, batch)
        x = x[perm]
        batch = batch[perm]
        induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=mi_jsd_score.size(0))

        # Structure learning disabled: return the pooled graph directly
        if self.sl is False:
            return x, induced_edge_index, induced_edge_attr, batch

        # Structure Learning
        if self.sample:
            # A fast mode for large graphs.
            # In large graphs, learning the possible edge weights between every pair of nodes is
            # time-consuming. To accelerate this, we sample each node's k-hop neighbors and then
            # learn the edge weights between them.
            k_hop = 3
            if edge_attr is None:
                edge_attr = torch.ones((edge_index.size(1),), dtype=torch.float, device=edge_index.device)

            hop_data = Data(x=original_x, edge_index=edge_index, edge_attr=edge_attr)
            for _ in range(k_hop - 1):
                hop_data = self.neighbor_augment(hop_data)
            hop_edge_index = hop_data.edge_index
            hop_edge_attr = hop_data.edge_attr
            new_edge_index, new_edge_attr = filter_adj(hop_edge_index, hop_edge_attr, perm, num_nodes=mi_jsd_score.size(0))

            new_edge_index, new_edge_attr = add_remaining_self_loops(new_edge_index, new_edge_attr, 0, x.size(0))
            row, col = new_edge_index
            weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop) + new_edge_attr * self.lamb
            adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
            adj[row, col] = weights
            new_edge_index, weights = dense_to_sparse(adj)
            row, col = new_edge_index
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj
            torch.cuda.empty_cache()
        else:
            # Learn the possible edge weights between every pair of nodes in the pooled subgraph (relatively slower).
            if edge_attr is None:
                induced_edge_attr = torch.ones((induced_edge_index.size(1),), dtype=x.dtype,
                                               device=induced_edge_index.device)
            num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
            shift_cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
            cum_num_nodes = num_nodes.cumsum(dim=0)
            adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
            # Construct a batched fully-connected graph in block-diagonal matrix format
            for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
                adj[idx_i:idx_j, idx_i:idx_j] = 1.0
            new_edge_index, _ = dense_to_sparse(adj)
            row, col = new_edge_index

            weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop)
            adj[row, col] = weights
            induced_row, induced_col = induced_edge_index

            adj[induced_row, induced_col] += induced_edge_attr * self.lamb
            weights = adj[row, col]
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj
            torch.cuda.empty_cache()

        return x, new_edge_index, new_edge_attr, batch