Example #1
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        x = x.unsqueeze(-1) if x.dim() == 1 else x

        # SBTL score (computed from node features and graph structure)
        score_s = self.sbtl_layer(x, edge_index).squeeze()
        # FBTL score (computed from node features only)
        score_f = self.fbtl_layer(x).squeeze()
        # blend the two scores with the hyperparameter alpha
        score = score_s * self.alpha + score_f * (1 - self.alpha)

        score = score.unsqueeze(-1) if score.dim() == 0 else score

        if self.min_score is None:
            score = self.non_linearity(score)
        else:
            score = softmax(score, batch)
        perm = topk(score, self.ratio, batch)

        # optional fusion step
        if self.fusion_flag == 1:
            x = self.fusion(x, edge_index)

        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm
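The snippet above does not show how sbtl_layer and fbtl_layer are constructed. A minimal sketch of what the two scoring branches and the alpha blend could look like, assuming the structure branch is a one-channel GCNConv and the feature branch a one-channel Linear layer (the names and layer choices are assumptions, not the original implementation):

import torch
from torch_geometric.nn import GCNConv

class TwoBranchScore(torch.nn.Module):
    """Hypothetical scoring module: one branch uses the graph structure,
    the other only the node features; alpha blends the two scores."""
    def __init__(self, in_channels, alpha=0.5):
        super().__init__()
        self.sbtl_layer = GCNConv(in_channels, 1)          # structure-aware score
        self.fbtl_layer = torch.nn.Linear(in_channels, 1)  # feature-only score
        self.alpha = alpha

    def forward(self, x, edge_index):
        score_s = self.sbtl_layer(x, edge_index).squeeze(-1)
        score_f = self.fbtl_layer(x).squeeze(-1)
        return score_s * self.alpha + score_f * (1 - self.alpha)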
Example #2
    def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        attn = x if attn is None else attn
        attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn
        score = self.gnn(attn, edge_index).view(-1)

        if self.min_score is None:
            score = self.nonlinearity(score)
        else:
            score = softmax(score, batch)

        perm = topk(score, self.ratio, batch, self.min_score)
        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm, score[perm]
Example #3
    def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        attn = x if attn is None else attn
        attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn
        score = self.gnn(attn, edge_index).view(-1)

        # Zero-mean the scores within each graph in the batch.
        # (The reshape assumes every graph has the same number of nodes.)
        score = score.view(batch.max() + 1, -1)
        score = score - score.mean(1, keepdim=True)
        score = score.view(-1)

        if self.min_score is None:
            score = self.nonlinearity(score)
        else:
            score = softmax(score, batch)

        perm = topk(score, self.ratio, batch, self.min_score)
        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        # Changed the last return value to `score`, i.e. the scores for all nodes.
        return x, edge_index, edge_attr, batch, perm, score.view(
            batch.max() + 1, -1)
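The per-instance centering above only works when every graph in the batch has the same number of nodes, since it reshapes the flat score vector into (num_graphs, nodes_per_graph). A sketch of the same zero-mean behaviour for variable-sized graphs, using torch_scatter (an alternative formulation, not part of the original code):

from torch_scatter import scatter_mean

def zero_mean_per_graph(score, batch):
    # Subtract each graph's mean score from its own nodes, whatever the graph sizes.
    mean = scatter_mean(score, batch, dim=0)  # one mean per graph
    return score - mean[batch]                # broadcast the means back to the nodes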
Example #4
def test_filter_adj():
    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3],
                               [1, 3, 0, 2, 1, 3, 0, 2]])
    edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
    perm = torch.tensor([2, 3])

    edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm)
    assert edge_index.tolist() == [[0, 1], [1, 0]]
    assert edge_attr.tolist() == [6, 8]
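The test illustrates the contract of filter_adj: keep only the edges whose endpoints both survive the pooling, and relabel the kept nodes consecutively in the order given by perm. A minimal sketch with the same semantics (an illustration, not the library's implementation):

import torch

def filter_adj_sketch(edge_index, edge_attr, perm, num_nodes):
    # Map old node ids to new consecutive ids; dropped nodes map to -1.
    mapping = edge_index.new_full((num_nodes,), -1)
    mapping[perm] = torch.arange(perm.size(0), device=perm.device)
    row, col = mapping[edge_index[0]], mapping[edge_index[1]]
    mask = (row >= 0) & (col >= 0)  # keep edges between surviving nodes only
    edge_attr = edge_attr[mask] if edge_attr is not None else None
    return torch.stack([row[mask], col[mask]]), edge_attr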
Example #5
 def forward(self, input, edge_index, edge_attr=None, batch=None):
     if batch is None:
         batch = edge_index.new_zeros(input.size(0))
     score = self.score_layer(input, edge_index).squeeze()
     perm = topk(score, self.ratio, batch)
     input = input[perm] * self.non_linearity(score[perm]).view(-1, 1)
     batch = batch[perm]
     edge_index, edge_attr = filter_adj(edge_index,
                                        edge_attr,
                                        perm,
                                        num_nodes=score.size(0))
     return input, edge_index, edge_attr, batch, perm
Example #6
 def forward(self, graph, x, batch=None):
     if batch is None:
         batch = graph.edge_index.new_zeros(x.size(0))
     score = self.score_layer(graph, x).squeeze()
     perm = topk(score, self.ratio, batch)
     x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
     batch = batch[perm]
     edge_index, edge_attr = filter_adj(graph.edge_index,
                                        graph.edge_weight,
                                        perm,
                                        num_nodes=score.size(0))
     return x, edge_index, edge_attr, batch, perm
Example #7
    def forward(self, x, edge_index=None, edge_attr=None, batch=None):
        if batch is None:
            batch = self.A.new_zeros(x.size(0))
        #x = x.unsqueeze(-1) if x.dim() == 1 else x
        score = self.score_layer(x, self.A).squeeze()

        perm = topk(score, self.ratio, batch)
        x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm
Example #8
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        score = torch.tanh(self.gnn(x, edge_index).view(-1))
        perm = topk(score, self.ratio, batch)
        x = x[perm] * score[perm].view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm
Example #9
    def forward(self, x, edge_index, edge_attr, batch):
        score = self.score_layer(x, edge_index).squeeze()

        perm = topk(score, self.ratio, batch)

        x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)

        batch = batch[perm]

        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        a = gmp(x, batch)
        m = gap(x, batch)

        return torch.cat([m, a], dim=1)
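Unlike the other variants, this one immediately reduces the pooled node features to a graph-level readout; gmp and gap are presumably PyTorch Geometric's global max/mean pooling, so the concatenated output has one row per graph and twice the feature width. A small usage sketch under that assumption:

import torch
from torch_geometric.nn import global_max_pool as gmp, global_mean_pool as gap

x = torch.randn(5, 16)                 # 5 pooled nodes with 16 features each
batch = torch.tensor([0, 0, 0, 1, 1])  # node-to-graph assignment for two graphs
readout = torch.cat([gap(x, batch), gmp(x, batch)], dim=1)
print(readout.shape)                   # torch.Size([2, 32])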
Example #10
    def forward(self, x, edge_index, attention, batch=None, direction=1):
        e_batch = edge_index[0]
        degree = torch.bincount(e_batch)
        node_scores = direction * g_pooling(attention, e_batch).view(-1)
        node_scores = node_scores.mul(degree)

        perm = topk(node_scores, self.rate, batch)

        edge_index, _ = self.augment_adj(edge_index, None, x.size(0))
        edge_index, _ = filter_adj(edge_index,
                                   None,
                                   perm,
                                   num_nodes=node_scores.size(0))
        x = x[perm]
        batch = batch[perm]

        return x, edge_index, batch, perm.view((1, -1))
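g_pooling is not defined in this snippet; if it is a mean-style pooling over e_batch (the source node of each edge), then multiplying the per-node mean attention by the out-degree recovers the per-node sum of incident edge attention. A small check of that reading (global_mean_pool is an assumption):

import torch
from torch_geometric.nn import global_mean_pool

edge_index = torch.tensor([[0, 0, 1, 2],
                           [1, 2, 0, 0]])
attention = torch.tensor([0.2, 0.4, 0.9, 0.1])

e_batch = edge_index[0]                 # source node of every edge
degree = torch.bincount(e_batch)        # out-degree per node: [2, 1, 1]
mean_att = global_mean_pool(attention.view(-1, 1), e_batch).view(-1)
node_scores = mean_att * degree         # equals the per-node sum: [0.6, 0.9, 0.1]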
Example #11
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        r"""Forward computation which computes the raw edge score, normalizes
        it, and merges the edges.

        Args:
            x (Tensor): The node features.
            edge_index (LongTensor): The edge indices.
            batch (LongTensor): Batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example.

        Return types:
            * **x** *(Tensor)* - The pooled node features.
            * **edge_index** *(LongTensor)* - The coarsened edge indices.
            * **batch** *(LongTensor)* - The coarsened batch vector.
            * **unpool_info** *(unpool_description)* - Information that is
              consumed by :func:`EdgePooling.unpool` for unpooling.
        """
        # e = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1)
        # e = self.lin(e).view(-1)
        # e = F.dropout(e, p=self.dropout, training=self.training)
        # e = self.compute_edge_score(e)
        # e = e + self.add_to_edge_score

        # TODO: change linear to consider both node and edge features
        # n = self.lin(x).view(-1)

        n = self.score_func(x, edge_index).view(-1)

        n = F.dropout(n, p=self.dropout, training=self.training)
        n = self.compute_node_score(n)

        perm = self.__merge_stars_with_attr__(x, edge_index, n)

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=n.size(0))
        x = x[perm]

        return x, edge_index, edge_attr, batch, perm
Example #12
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        num_node = x.size(0)

        k = F.relu(self.lin_2(x))

        A = SparseTensor.from_edge_index(edge_index=edge_index,
                                         edge_attr=edge_attr,
                                         sparse_sizes=(num_node, num_node))
        I = SparseTensor.eye(num_node, device=self.args.device)
        A_wave = fill_diag(A, 1)

        s = A_wave @ k

        score = s.squeeze()
        perm = topk(score, self.ratio, batch)

        A = self.norm(A)

        K_neighbor = A * k.T
        x_neighbor = K_neighbor @ x

        # ---- modified: normalize the aggregated neighbor features by node degree
        deg = sum(A, dim=1)  # sparse row-sum (presumably torch_sparse.sum, not the Python builtin)
        deg_inv = deg.pow_(-1)
        deg_inv.masked_fill_(deg_inv == float('inf'), 0.)
        x_neighbor = x_neighbor * deg_inv.view(1, -1).T
        # ----
        x_self = x * k

        x = x_neighbor * (
            1 - self.args.combine_ratio) + x_self * self.args.combine_ratio

        x = x[perm]
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=s.size(0))

        return x, edge_index, edge_attr, batch, perm
Example #13
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        # iterative fusion of the node scores
        for i in range(3):
            score_ = self.score_layer(x, edge_index).squeeze()
            if i > 0:
                score = score * score_ + score
            else:
                score = score_

        perm = topk(score, self.ratio, batch)
        x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index,
                                           edge_attr,
                                           perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm
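Because x and edge_index do not change inside the loop, score_ has the same value on every pass (assuming score_layer is deterministic at this point), so the three fusion iterations collapse to a closed form. A quick check of that equivalence:

import torch

s = torch.randn(7)       # stand-in for the per-node score_ from score_layer
score = s.clone()
for i in range(1, 3):    # the passes after the first one
    score = score * s + score
assert torch.allclose(score, s * (1 + s) ** 2)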
Example #14
    def forward(self, x, edge_index, edge_attr=None, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x_information_score = self.calc_information_score(x, edge_index)
        score = torch.sum(torch.abs(x_information_score), dim=1)

        # Graph Pooling
        original_x = x
        perm = topk(score, self.ratio, batch)
        x = x[perm]
        batch = batch[perm]
        induced_edge_index, induced_edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        # If the structure learning layer is disabled, return directly
        if self.sl is False:
            return x, induced_edge_index, induced_edge_attr, batch, perm

        # Structure Learning
        if self.sample:
            # A fast mode for large graphs.
            # In large graphs, learning a possible edge weight between every pair of nodes
            # is time consuming. To speed this up, we sample each node's k-hop neighbors
            # and only learn edge weights between them.
            k_hop = 3
            if edge_attr is None:
                edge_attr = torch.ones((edge_index.size(1), ),
                                       dtype=torch.float,
                                       device=edge_index.device)

            hop_data = Data(x=original_x,
                            edge_index=edge_index,
                            edge_attr=edge_attr)
            for _ in range(k_hop - 1):
                hop_data = self.neighbor_augment(hop_data)
            hop_edge_index = hop_data.edge_index
            hop_edge_attr = hop_data.edge_attr
            new_edge_index, new_edge_attr = filter_adj(hop_edge_index,
                                                       hop_edge_attr,
                                                       perm,
                                                       num_nodes=score.size(0))

            row, col = new_edge_index

            if self.att.fast is not None:
                weights = (torch.cat([x[row], x[col]], dim=1) *
                           self.att.fast).sum(dim=-1)
            else:
                tmps = torch.cat([x[row], x[col]], dim=1)
                # assert (tmps.shape[0]==self.att.)
                weights = (tmps * self.att).sum(dim=-1)

            weights = F.leaky_relu(
                weights, self.negative_slop) + new_edge_attr * self.lamb
            adj = torch.zeros((x.size(0), x.size(0)),
                              dtype=torch.float,
                              device=x.device)
            adj[row, col] = weights
            new_edge_index, weights = dense_to_sparse(adj)
            row, col = new_edge_index
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj
            torch.cuda.empty_cache()
        else:
            # Learn a possible edge weight between every pair of nodes in the
            # pooled subgraph; relatively slower.
            if edge_attr is None:
                induced_edge_attr = torch.ones(
                    (induced_edge_index.size(1), ),
                    dtype=x.dtype,
                    device=induced_edge_index.device)
            num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
            shift_cum_num_nodes = torch.cat(
                [num_nodes.new_zeros(1),
                 num_nodes.cumsum(dim=0)[:-1]], dim=0)
            cum_num_nodes = num_nodes.cumsum(dim=0)
            adj = torch.zeros((x.size(0), x.size(0)),
                              dtype=torch.float,
                              device=x.device)
            # Construct the batched fully connected graph in block-diagonal matrix format
            for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
                adj[idx_i:idx_j, idx_i:idx_j] = 1.0
            new_edge_index, _ = dense_to_sparse(adj)
            row, col = new_edge_index
            if self.att.fast is not None:
                weights = (torch.cat([x[row], x[col]], dim=1) *
                           self.att.fast).sum(dim=-1)
            else:
                weights = (torch.cat([x[row], x[col]], dim=1) *
                           self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop)
            adj[row, col] = weights
            induced_row, induced_col = induced_edge_index

            adj[induced_row, induced_col] += induced_edge_attr * self.lamb
            weights = adj[row, col]
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj
            torch.cuda.empty_cache()

        return x, new_edge_index, new_edge_attr, batch, perm
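When self.sparse is False, both the sampled and the dense branch turn the learned weights into a per-source-node attention distribution via softmax over the row index. A minimal standalone sketch of that normalization step using torch_geometric.utils.softmax (the toy numbers are illustrative only):

import torch
from torch_geometric.utils import softmax

row = torch.tensor([0, 0, 1, 1, 1])                # source node of each candidate edge
weights = torch.tensor([1.0, 3.0, 0.5, 0.5, 2.0])  # raw learned edge weights
normalized = softmax(weights, row)                 # sums to 1 within each source node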
Example #15
    def forward(self, x, edge_index, edge_attr, batch, h, neg_num, samp_bias1, samp_bias2):
        """

        :param x: node feature after convolution
        :param edge_index:
        :param edge_attr:
        :param batch:
        :param h: node feature before convolution
        :param neg_num:
        :param samp_bias1:
        :param samp_bias2:
        :return:
        """

        # I(h_i; x_i)
        res_mi_pos, res_mi_neg = self.disc1(x, h, process.negative_sampling_tg(batch, neg_num), samp_bias1, samp_bias2)
        mi_jsd_score = process.sp_func(res_mi_pos) + process.sp_func(torch.mean(res_mi_neg, dim=1))

        # Graph Pooling
        original_x = x
        perm = topk(mi_jsd_score, self.ratio, batch)
        x = x[perm]
        batch = batch[perm]
        induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=mi_jsd_score.size(0))

        # If the structure learning layer is disabled, return directly
        if self.sl is False:
            return x, induced_edge_index, induced_edge_attr, batch

        # Structure Learning
        if self.sample:
            # A fast mode for large graphs.
            # In large graphs, learning a possible edge weight between every pair of nodes
            # is time consuming. To speed this up, we sample each node's k-hop neighbors
            # and only learn edge weights between them.
            k_hop = 3
            if edge_attr is None:
                edge_attr = torch.ones((edge_index.size(1),), dtype=torch.float, device=edge_index.device)

            hop_data = Data(x=original_x, edge_index=edge_index, edge_attr=edge_attr)
            for _ in range(k_hop - 1):
                hop_data = self.neighbor_augment(hop_data)
            hop_edge_index = hop_data.edge_index
            hop_edge_attr = hop_data.edge_attr
            new_edge_index, new_edge_attr = filter_adj(hop_edge_index, hop_edge_attr, perm, num_nodes=mi_jsd_score.size(0))

            new_edge_index, new_edge_attr = add_remaining_self_loops(new_edge_index, new_edge_attr, 0, x.size(0))
            row, col = new_edge_index
            weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop) + new_edge_attr * self.lamb
            adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
            adj[row, col] = weights
            new_edge_index, weights = dense_to_sparse(adj)
            row, col = new_edge_index
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj
            torch.cuda.empty_cache()
        else:
            # Learn a possible edge weight between every pair of nodes in the pooled subgraph; relatively slower.
            if edge_attr is None:
                induced_edge_attr = torch.ones((induced_edge_index.size(1),), dtype=x.dtype,
                                               device=induced_edge_index.device)
            num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
            shift_cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
            cum_num_nodes = num_nodes.cumsum(dim=0)
            adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
            # Construct the batched fully connected graph in block-diagonal matrix format
            for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
                adj[idx_i:idx_j, idx_i:idx_j] = 1.0
            new_edge_index, _ = dense_to_sparse(adj)
            row, col = new_edge_index

            weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop)
            adj[row, col] = weights
            induced_row, induced_col = induced_edge_index

            adj[induced_row, induced_col] += induced_edge_attr * self.lamb
            weights = adj[row, col]
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj
            torch.cuda.empty_cache()

        return x, new_edge_index, new_edge_attr, batch
Example #16
    def forward(self, x, x_score, edge_index, edge_attr, batch=None):
        n = x.shape[0]  # note: n is not used below

        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        # Graph Pooling
        perm = topk(x_score.view(-1), self.ratio, batch)
        
        induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=x.shape[0])
        # isolate_mask=(perm.view(-1,1)==induced_edge_index.view(-1).unique()).sum(dim=1)>0
        # perm = perm[isolate_mask]
        # induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=x.shape[0])

        if edge_index.shape[1] > 0:
            row, col = edge_index
            S = torch.exp(torch.norm(x[row] - x[col], dim=1))
            th = torch.sort(S, descending=True).values[int(self.edge_ratio * (len(S) - 1))]
            select = (S > th)
            # Note: this filtered edge_index is not returned below; the induced
            # edges from filter_adj are used instead.
            edge_index = edge_index[:, select]
        
        x = x[perm]
        batch = batch[perm]
        return x, induced_edge_index, induced_edge_attr, batch
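The thresholding above keeps only the edges whose endpoint dissimilarity S is strictly above the value at position int(edge_ratio * (len(S) - 1)) of the descending sort, i.e. roughly the top edge_ratio fraction of edges. A toy illustration of the rule:

import torch

S = torch.tensor([0.1, 0.9, 0.5, 0.7, 0.3])  # one dissimilarity score per edge
edge_ratio = 0.4
th = torch.sort(S, descending=True).values[int(edge_ratio * (len(S) - 1))]  # th = 0.7
select = S > th   # strictly above the threshold: here only the 0.9 edge survives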


# ############ add structure learning
# class LookHopsPool(torch.nn.Module):
#     def __init__(self, k, out_channels,ratio=0.8,edge_ratio=0.8):
#         super(LookHopsPool, self).__init__()
#         self.k=k
#         self.ratio = ratio
#         self.edge_ratio=edge_ratio
#         self.idx=1
#         self.node_att = nn.Linear(k*out_channels,1)
#         self.edge_att = nn.Linear(k*out_channels*2,1)
#         self.alpha = nn.Parameter(torch.tensor(1.0))

#     def forward(self, x, neighbor_info, edge_index, edge_dis, batch):
#         n=x.shape[0]

#         if batch is None:
#             batch = edge_index.new_zeros(x.size(0))
        
#         x_score = self.node_att(neighbor_info)
#         x=x*x_score

#         # Graph Pooling
#         perm = topk(x_score.view(-1), self.ratio, batch)
#         induced_edge_index, induced_edge_dis = filter_adj(edge_index, edge_dis, perm, num_nodes=x.shape[0])
#         x = x[perm]
#         batch = batch[perm]
#         # neighbor_info = neighbor_info[perm]
#         # row,col=induced_edge_index
#         induced_edge_weight = torch.exp(-self.alpha*induced_edge_dis)
#         if torch.isnan(induced_edge_weight).any():
#             print('NO')

#         # if edge_index.shape[1]>0:
#         #     row,col=edge_index
#         #     S=torch.exp(torch.norm(x[row]-x[col],dim=1))
#         #     th=torch.sort(S,descending=True).values[int(self.edge_ratio*(len(S)-1))]
#         #     select=(S>th)
#         #     edge_index=edge_index[:,select]
        
#         return x, induced_edge_index, induced_edge_weight,induced_edge_dis, batch