Example #1
def test_issue_2484(idtype):
    import dgl.function as fn
    g = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())
    x = F.copy_to(F.randn((4, )), F.ctx())
    g.ndata['x'] = x
    g.pull([2, 1], fn.u_add_v('x', 'x', 'm'), fn.sum('m', 'x'))
    y1 = g.ndata['x']

    g.ndata['x'] = x
    g.pull([1, 2], fn.u_add_v('x', 'x', 'm'), fn.sum('m', 'x'))
    y2 = g.ndata['x']

    assert F.allclose(y1, y2)
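
A minimal, self-contained sketch (with a hypothetical graph and feature names, not taken from any example on this page) of what fn.u_add_v computes: for every edge (u, v), the output field holds the source feature plus the destination feature.

import torch
import dgl
import dgl.function as fn

g = dgl.graph(([0, 1, 2], [1, 2, 3]))                # 4 nodes, edges 0->1, 1->2, 2->3
g.ndata['x'] = torch.arange(4, dtype=torch.float32)  # x = [0., 1., 2., 3.]
g.apply_edges(fn.u_add_v('x', 'x', 'm'))             # m[e] = x[src(e)] + x[dst(e)]
print(g.edata['m'])                                  # tensor([1., 3., 5.])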
Example #2
    def forward(self, g, node_feats, edge_feats, node_only=False):
        r"""Update node and edge representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges in the batch of graphs.
        node_only : bool
            Whether to update node representations only. If False, edge representations
            will be updated as well. Defaults to False.

        Returns
        -------
        new_node_feats : float32 tensor of shape (V, node_out_feats)
            Updated node representations.
        new_edge_feats : float32 tensor of shape (E, edge_out_feats)
            Updated edge representations.
        """
        g = g.local_var()

        # Update node features
        node_node_feats = self.activation(self.node_to_node(node_feats))
        g.edata['e2n'] = self.activation(self.edge_to_node(edge_feats))
        g.update_all(fn.copy_edge('e2n', 'm'), fn.sum('m', 'e2n'))
        edge_node_feats = g.ndata.pop('e2n')
        new_node_feats = self.activation(
            self.update_node(
                torch.cat([node_node_feats, edge_node_feats], dim=1)))

        if node_only:
            return new_node_feats

        # Update edge features
        g.ndata['left_hv'] = self.left_node_to_edge(node_feats)
        g.ndata['right_hv'] = self.right_node_to_edge(node_feats)
        g.apply_edges(fn.u_add_v('left_hv', 'right_hv', 'first'))
        g.apply_edges(fn.u_add_v('right_hv', 'left_hv', 'second'))
        first_edge_feats = self.activation(g.edata.pop('first'))
        second_edge_feats = self.activation(g.edata.pop('second'))
        third_edge_feats = self.activation(self.edge_to_edge(edge_feats))
        new_edge_feats = self.activation(
            self.update_edge(
                torch.cat(
                    [first_edge_feats, second_edge_feats, third_edge_feats],
                    dim=1)))

        return new_node_feats, new_edge_feats
Example #3
    def forward(self, graph, feat):
        graph = graph.local_var()
        h = self.feat_drop(feat)
        feat = self.fc(h).view(-1, self._num_heads, self._out_feats)
        el = (feat * self.attn_l).sum(dim=-1).unsqueeze(-1)
        er = (feat * self.attn_r).sum(dim=-1).unsqueeze(-1)
        graph.ndata.update({"ft": feat, "el": el, "er": er})
        # compute edge attention
        graph.apply_edges(fn.u_add_v("el", "er", "e"))
        # apply leaky relu
        graph.apply_edges(self.relu_udf)

        # compute softmax/sparsemax
        if self.sparsemax:
            graph.apply_edges(self.sparsemax_udf)
        else:
            graph.edata["a"] = edge_softmax(graph, graph.edata.pop("e"))

        # attention dropout
        graph.apply_edges(self.attn_drop_udf)

        # message passing
        graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
        rst = graph.ndata["ft"]
        # residual
        if self.res_fc is not None:
            resval = self.res_fc(h).view(h.shape[0], -1, self._out_feats)
            rst = rst + resval
        # activation
        if self.activation:
            rst = self.activation(rst)
        return rst
Example #4
    def forward(self, graph, feat):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    assert False

            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                if not hasattr(self, "fc_src"):
                    self.fc_src, self.fc_dst = self.fc, self.fc
                feat_src, feat_dst = h_src, h_dst
                feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src, feat_dst = h_src, h_dst
                feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
                if graph.is_block:
                    feat_dst = feat_src[: graph.number_of_dst_nodes()]

            if self._norm == "both":
                degs = graph.out_degrees().float().clamp(min=1)
                norm = torch.pow(degs, -0.5)
                shp = norm.shape + (1,) * (feat_src.dim() - 1)
                norm = torch.reshape(norm, shp)
                feat_src = feat_src * norm
            
            # Implement GeniePath adaptive-breadth function only
            graph.srcdata.update({"ft": feat_src})
            graph.dstdata.update({"ft_dst": feat_dst})
            
            graph.apply_edges(fn.u_add_v("ft", "ft_dst", "e"))
            e = graph.edata.pop("e")
            e = self.attn * torch.tanh(e)
            
            # The GeniePath paper doesn't use LeakyReLU
            #e = self.leaky_relu(graph.edata.pop("e"))
            # compute softmax
            graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
            # message passing
            graph.srcdata.update({"ft": feat_src})
            graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
            rst = graph.dstdata["ft"]

            if self._norm == "both":
                degs = graph.in_degrees().float().clamp(min=1)
                norm = torch.pow(degs, 0.5)
                shp = norm.shape + (1,) * (feat_dst.dim() - 1)
                norm = torch.reshape(norm, shp)
                rst = rst * norm

            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
                rst = rst + resval
            # activation
            if self._activation is not None:
                rst = self._activation(rst)
            return rst
Example #5
    def forward(self, graph, node_feat, edge_feat):
        with graph.local_scope():
            h_src = h_dst = node_feat
            feat_src = feat_dst = self.fc(h_src).view(-1, self._edata_channels, self._out_feats)
            e_feat = self.edge_fc(edge_feat).view(-1, self._edata_channels, 1)
            graph.edata.update({'feat': e_feat})
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({'feat': feat_src, 'el': el})
            graph.dstdata.update({'er': er})

            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = graph.edata.pop('e') * e_feat
            e = self.leaky_relu(e)

            # compute softmax
            graph.edata['a'] = edge_softmax(graph, e)

            # message passing
            def message_func(edges):
                feat_with_e = th.cat([edges.src['feat'], edges.data['feat']], 2)
                # apply an fc layer to project the node features concatenated with E_p back to out_feat_dim
                feat_with_e = self.nfeat_with_e_fc(feat_with_e)
                return {'m': edges.data['a'] * feat_with_e}

            graph.update_all(message_func,
                             fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']
            rst = th.sigmoid(rst)
            return rst
Example #6
    def forward(self, graph, feat):
        '''
        :param graph: DGLGraph
        :param feat: tensor of shape <N, b, F>
        :return: updated node features and the edge attention coefficients
        '''
        with graph.local_scope():
            N, b, _ = feat.size()
            graph = graph.local_var()
            graph = graph.to(feat.device)
            feat = torch.cat([self.fc1(feat[:get_Parameter('taxi_size')]), self.fc2(feat[get_Parameter('taxi_size'):])], dim=0)
            feat_src = feat_dst = feat.view(N, b, self._num_heads, self._out_feats)
            #feat_src = feat_dst = self.fc(feat).view(N, b, self._num_heads, self._out_feats)
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_l).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({'ft': feat_src, 'el': el})
            graph.dstdata.update({'er': er})

            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            #graph.apply_edges(fn.u_mul_e('e', 'w', 'e'))
            e = self.leaky_relu(graph.edata.pop('e'))
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
            #print(graph.edata['a'].size())
            graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']
            rst = rst.reshape(N, -1, self._num_heads*self._out_feats)
            return rst, graph.edata['a']
Example #7
    def forward(self, g, h, e):

        h_in = h  # for residual connection

        g.ndata['h'] = h
        g.ndata['Ah'] = self.A(h)
        g.ndata['Bh'] = self.B(h)
        g.ndata['Dh'] = self.D(h)
        g.ndata['Eh'] = self.E(h)
        #g.update_all(self.message_func,self.reduce_func)
        g.apply_edges(fn.u_add_v('Dh', 'Eh', 'e'))
        g.edata['sigma'] = torch.sigmoid(g.edata['e'])
        g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'),
                     fn.sum('m', 'sum_sigma_h'))
        g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma'))
        g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_sigma_h'] / (
            g.ndata['sum_sigma'] + 1e-6)
        h = g.ndata['h']  # result of graph convolution

        if self.batch_norm:
            h = self.bn_node_h(h)  # batch normalization

        h = F.relu(h)  # non-linear activation

        if self.residual:
            h = h_in + h  # residual connection

        h = F.dropout(h, self.dropout, training=self.training)

        return h, e
Example #8
    def forward(self, g, feat_src, feat_dst):
        """
        :param g: DGLGraph, bipartite graph from neighbor nodes to target nodes
        :param feat_src: tensor(N_src, d), input features of the neighbor nodes
        :param feat_dst: tensor(N_dst, d), input features of the target nodes
        :return: tensor(N_dst, d), output features of the target nodes
        """
        with g.local_scope():
            # Note: the HeCo authors apply attn_drop in a way that differs from the original GAT;
            # strictly speaking this is incorrect, yet it improves node clustering performance...
            attn_l = self.attn_drop(self.attn_l)
            attn_r = self.attn_drop(self.attn_r)
            el = (feat_src * attn_l).sum(dim=-1).unsqueeze(dim=-1)  # (N_src, 1)
            er = (feat_dst * attn_r).sum(dim=-1).unsqueeze(dim=-1)  # (N_dst, 1)
            g.srcdata.update({'ft': feat_src, 'el': el})
            g.dstdata['er'] = er
            g.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(g.edata.pop('e'))
            g.edata['a'] = edge_softmax(g, e)  # (E, 1)

            # message passing
            g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
            ret = g.dstdata['ft']
            if self.activation:
                ret = self.activation(ret)
            return ret
Example #9
    def forward(self, g, feat):
        """
        :param g: DGLGraph, bipartite graph (containing only one relation type)
        :param feat: tensor(N_src, d_in) or (tensor(N_src, d_in), tensor(N_dst, d_in)), input features
        :return: tensor(N_dst, K*d_out), representations of the target nodes under this relation
        """
        with g.local_scope():
            feat_src, feat_dst = expand_as_pair(feat, g)
            feat_src = self.fc_src(self.feat_drop(feat_src)).view(-1, self.num_heads, self.out_dim)
            feat_dst = self.fc_dst(self.feat_drop(feat_dst)).view(-1, self.num_heads, self.out_dim)

            # a^T (z_u || z_v) = (a_l^T || a_r^T) (z_u || z_v) = a_l^T z_u + a_r^T z_v = el + er
            el = (feat_src * self.attn_src[:, :self.out_dim]).sum(dim=-1, keepdim=True)  # (N_src, K, 1)
            er = (feat_dst * self.attn_src[:, self.out_dim:]).sum(dim=-1, keepdim=True)  # (N_dst, K, 1)
            g.srcdata.update({'ft': feat_src, 'el': el})
            g.dstdata['er'] = er
            g.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(g.edata.pop('e'))
            g.edata['a'] = edge_softmax(g, e)  # (E, K, 1)

            # message passing
            g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
            ret = g.dstdata['ft'].view(-1, self.num_heads * self.out_dim)
            if self.activation:
                ret = self.activation(ret)
            return ret
Example #10
 def forward(self, graph, feat):
     graph = graph.local_var()
     if isinstance(feat, tuple):
         h_src = self.feat_drop(feat[0])
         h_dst = self.feat_drop(feat[1])
         feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
         feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
     else:
         h_src = h_dst = self.feat_drop(feat)
         feat_src = feat_dst = self.fc(h_src).view(
             -1, self._num_heads, self._out_feats)
     el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
     er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
     graph.srcdata.update({'ft': feat_src, 'el': el})
     graph.dstdata.update({'er': er})
     # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
     graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
     e = self.leaky_relu(graph.edata.pop('e'))
     # compute softmax
     graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
     # message passing
     graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                      fn.sum('m', 'ft'))
     rst = graph.dstdata['ft']
     # residual
     if self.res_fc is not None:
         resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
         rst = rst + resval
     # activation
     if self.activation:
         rst = self.activation(rst)
     return rst
Example #11
    def forward(self,
                g,
                h,
                logits,
                old_z,
                attn_l,
                attn_r,
                *,
                shared_tau=True,
                tau1=None,
                tau2=None):
        with g.local_scope():
            h = self.dropout(h)

            if self.fc is not None:
                feat = self.fc(h).view(-1, self._num_heads, self._out_feats)
            else:
                feat = h
            g.ndata["h"] = feat  # (n_node, n_feat)
            g.ndata["logits"] = logits

            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5).to(feat.device).unsqueeze(1)
            g.ndata["degree"] = degs

            el = (feat * attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat * attn_r).sum(dim=-1).unsqueeze(-1)
            g.ndata.update({"ft": feat, "el": el, "er": er})
            # compute edge attention
            g.apply_edges(fn.u_add_v("el", "er", "e"))
            e = self.leaky_relu(g.edata.pop("e"))
            # compute softmax
            g.edata["a"] = self.dropout(edge_softmax(g, e))

            g.update_all(
                message_func=adaptive_attn_message_func,
                reduce_func=adaptive_attn_reduce_func,
            )
            f1 = g.ndata.pop("f1")
            f2 = g.ndata.pop("f2")
            norm_f1 = self.ln1(f1)
            norm_f2 = self.ln2(f2)
            if shared_tau:
                z = torch.sigmoid((-1) * (norm_f1 - tau1)) * torch.sigmoid(
                    (-1) * (norm_f2 - tau2))
            else:
                # tau for each layer
                z = torch.sigmoid(
                    (-1) * (norm_f1 - self.tau1)) * torch.sigmoid(
                        (-1) * (norm_f2 - self.tau2))

            gate = torch.min(old_z, z)

            agg = g.ndata.pop("agg")
            normagg = agg * norm.unsqueeze(1)  # normalization by tgt degree

            if self.activation:
                normagg = self.activation(normagg)
            new_h = feat + gate.unsqueeze(2) * normagg
            return new_h, z
Example #12
    def update_all_p_norm(self, graph):

        """ 
        Attempt at robust p-norm á:
        def robust_norm(x, p):
                a = np.abs(x).max()
                return a * norm1(x / a, p)

            def norm1(x, p):
                "First-pass implementation of p-norm."
                return (np.abs(x)**p).sum() ** (1./p) """

        p = torch.clamp(self.P,1,100)
        
        graph.apply_edges(fn.u_add_v('Dh', 'Eh', 'DEh'))
        graph.edata['e'] = graph.edata['DEh'] + graph.edata['Ce']
        graph.edata['sigma'] = torch.sigmoid(graph.edata['e']) # n_{ij}

        alpha = torch.max(torch.abs(torch.cat((graph.ndata['Bh'],graph.edata['sigma']), dim=0)))

        graph.ndata['Bh_pow'] = (torch.abs(graph.ndata['Bh'])/alpha).pow(p)
        graph.edata['sig_pow'] = (torch.abs(graph.edata['sigma'])/alpha).pow(p)
        # u_mul_e: element-wise product |Bh/alpha|^p * sigma^p on each edge ("m"),
        # summed over incoming edges into 'sum_sigma_h'.
        graph.update_all(fn.u_mul_e('Bh_pow', 'sig_pow', 'm'), fn.sum('m', 'sum_sigma_h'))

        # copy_e copies sigma^p onto each edge as the message; summing it gives the
        # normalizer 'sum_sigma' used in the gated aggregation below.
        graph.update_all(fn.copy_e('sig_pow', 'm'), fn.sum('m', 'sum_sigma'))

        graph.ndata['h'] = graph.ndata['Ah'] + ((graph.ndata['sum_sigma_h'] / (graph.ndata['sum_sigma'] + 1e-6)) * alpha).pow(torch.div(1, p))  # Ah + p-norm aggregation

        #graph.update_all(self.message_func,self.reduce_func) 
        h = graph.ndata['h'] # result of graph convolution
        e = graph.edata['e'] # result of graph convolution
        # Call update function outside of update_all
        return h, e
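
The docstring of the example above motivates the rescaling by the maximum magnitude: a direct p-norm overflows for large inputs, while the rescaled form stays finite. A quick standalone check of that claim (plain NumPy, not tied to the layer above):

import numpy as np

def norm1(x, p):
    # direct p-norm: overflows as soon as |x|**p exceeds the float64 range
    return (np.abs(x) ** p).sum() ** (1. / p)

def robust_norm(x, p):
    # rescale by the max magnitude so x / a lies in [-1, 1] before powering
    a = np.abs(x).max()
    return a * norm1(x / a, p)

x = np.array([1e200, 2e200])
print(norm1(x, 4))        # inf (overflow)
print(robust_norm(x, 4))  # ~2.03e200, finite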
Example #13
def check_apply_edges(create_node_flow):
    num_layers = 2
    for i in range(num_layers):
        g = generate_rand_graph(100)
        g.ndata["f"] = F.randn((100, 10))
        nf = create_node_flow(g, num_layers)
        nf.copy_from_parent()
        new_feats = F.randn((nf.block_size(i), 5))

        def update_func(edges):
            return {'h2': new_feats, "f2": edges.src["f"] + edges.dst["f"]}

        nf.apply_block(i, update_func)
        assert_array_equal(F.asnumpy(nf.blocks[i].data['h2']),
                           F.asnumpy(new_feats))

        # should also work for negative block ids
        nf.apply_block(-num_layers + i, update_func)
        assert_array_equal(F.asnumpy(nf.blocks[i].data['h2']),
                           F.asnumpy(new_feats))

        eids = nf.block_parent_eid(i)
        srcs, dsts = g.find_edges(eids)
        expected_f_sum = g.nodes[srcs].data["f"] + g.nodes[dsts].data["f"]
        assert_array_equal(F.asnumpy(nf.blocks[i].data['f2']),
                           F.asnumpy(expected_f_sum))

        # test built-in
        nf.apply_block(i, fn.u_add_v('f', 'f', 'f2'))
        eids = nf.block_parent_eid(i)
        srcs, dsts = g.find_edges(eids)
        expected_f_sum = g.nodes[srcs].data["f"] + g.nodes[dsts].data["f"]
        assert_array_equal(F.asnumpy(nf.blocks[i].data['f2']),
                           F.asnumpy(expected_f_sum))
Example #14
    def forward(self, g, h):
        g = g.local_var()
        if not self.use_pp or not self.training:
            norm = self.get_norm(g)

            # g.ndata['h'] = h
            # g.update_all(fn.copy_src(src='h', out='m'),
            #              fn.sum(msg='m', out='h'))
            # ah = g.ndata.pop('h')

            if self._aggre_type == 'mean':
                g.ndata['h'] = h
                g.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'h'))
                ah = g.ndata.pop('h')
            elif self._aggre_type == 'gcn':
                g.ndata['h'] = h
                g.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
                # divide in_degrees
                # degs = graph.in_degrees().float()
                # degs = degs.to(feat.device)
                # h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.unsqueeze(-1) + 1)
                ah = g.ndata.pop('h')
                ah = ah * norm
            elif self._aggre_type == 'pool':
                g.ndata['h'] = F.relu(self.fc_pool(h))
                g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'h'))
                ah = g.ndata['h']
            elif self._aggre_type == 'lstm':
                g.ndata['h'] = h
                g.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
                ah = g.ndata['h']
            elif self._aggre_type == 'attn':
                feat = self.fc_attn(h).view(-1, self.num_heads, self._in_feats)
                el = (feat * self.attn_l).sum(dim=-1).unsqueeze(-1)
                er = (feat * self.attn_r).sum(dim=-1).unsqueeze(-1)
                g.ndata.update({'ft': feat, 'el': el, 'er': er})
                g.apply_edges(fn.u_add_v('el', 'er', 'e'))
                e = self.leaky_relu(g.edata.pop('e'))
                g.edata['a'] = edge_softmax(g, e)
                g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
                ah = g.ndata['ft']
                ah = ah.squeeze(1)
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(
                    self._aggre_type))

            h = self.concat(h, ah, norm)
        if self.dropout:
            h = self.dropout(h)
        # GraphSAGE GCN does not require fc_self.
        # if self._aggre_type == 'gcn':
        #     rst = self.fc_neigh(ah)
        # else:
        #     rst = self.fc_self(h) + self.fc_neigh(ah)
        h = self.linear(h)
        h = self.lynorm(h)
        if self.activation:
            h = self.activation(h)
        return h
Example #15
 def forward(self, graph, n_feat, e_feat):
     graph = graph.local_var()
     graph.ndata['h'] = n_feat
     graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
     n_feat += graph.ndata['h']
     graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
     e_feat += graph.edata['e']
     return n_feat, e_feat
Example #16
 def forward(self, g, feat_src, feat_dst):
     el = (feat_src * self.attn_l).sum(dim=-1,
                                       keepdim=True)  # (N_src, K, 1)
     er = (feat_dst * self.attn_r).sum(dim=-1,
                                       keepdim=True)  # (N_dst, K, 1)
     g.srcdata['el'] = el
     g.dstdata['er'] = er
     g.apply_edges(fn.u_add_v('el', 'er', 'e'))
     return g.edata.pop('e')
Example #17
    def forward(self, g, h, e):

        ########## Message-passing sub-layer ##########

        h_in = h  # for residual connection
        e_in = e  # for residual connection

        if self.batch_norm == True:
            h = self.norm1_h(h)  # batch normalization
            e = self.norm1_e(e)  # batch normalization

        # Linear transformations of nodes and edges
        g.ndata['h'] = h
        g.edata['e'] = e
        g.ndata['Ah'] = self.A(h)  # node update, self-connection
        g.ndata['Bh'] = self.B(h)  # node update, neighbor projection
        g.ndata['Ch'] = self.C(h)  # edge update, source node projection
        g.ndata['Dh'] = self.D(h)  # edge update, destination node projection
        g.edata['Ee'] = self.E(e)  # edge update, edge projection

        # Graph convolution with dense attention mechanism
        g.apply_edges(fn.u_add_v('Ch', 'Dh', 'CDh'))
        g.edata['e'] = g.edata['CDh'] + g.edata['Ee']
        # Dense attention mechanism
        g.edata['sigma'] = torch.sigmoid(g.edata['e'])
        g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'),
                     fn.sum('m', 'sum_sigma_h'))
        g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma'))
        # Gated-Mean aggregation
        g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_sigma_h'] / (
            g.ndata['sum_sigma'] + 1e-10)
        h = g.ndata['h']  # result of graph convolution
        e = g.edata['e']  # result of graph convolution

        if self.residual == True:
            h = h_in + h  # residual connection
            e = e_in + e  # residual connection

        ############ Feedforward sub-layer ############

        h_in = h  # for residual connection
        e_in = e  # for residual connection

        if self.batch_norm == True:
            h = self.norm2_h(h)  # batch normalization
            e = self.norm2_e(e)  # batch normalization

        # MLPs on updated node and edge features
        h = self.ff_h(h)
        e = self.ff_e(e)

        if self.residual == True:
            h = h_in + h  # residual connection
            e = e_in + e  # residual connection

        return h, e
Example #18
    def forward(self, graph, feat):
        r"""Compute graph attention network layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor
            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
            is size of input feature, :math:`N` is the number of nodes.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
            is the number of heads, and :math:`D_{out}` is size of output feature.
        """
        graph = graph.local_var()
        h = self.feat_drop(feat)
        feat = self.fc(h).view(-1, self._num_heads, self._out_feats)
        if self._num_heads == 8:
            feat_normal = feat[:, 3:, :]
            feat_eig = feat[:, :3, :]
            el_normal = (feat_normal * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er_normal = (feat_normal * self.attn_r).sum(dim=-1).unsqueeze(-1)
            el_eig = (th.abs(graph.ndata['eig'][:, 1:3]).unsqueeze(1).expand(
                -1, 3, -1) * self.attn_l_eig).sum(dim=-1).unsqueeze(-1)
            er_eig = (th.abs(graph.ndata['eig'][:, 1:3]).unsqueeze(1).expand(
                -1, 3, -1) * self.attn_r_eig).sum(dim=-1).unsqueeze(-1)
            el = th.cat([el_normal, el_eig], dim=1)
            er = th.cat([er_normal, er_eig], dim=1)
        else:
            el = (th.abs(graph.ndata['eig'][:, 1:3]).unsqueeze(1).expand(
                -1, self._num_heads, -1) *
                  self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (th.abs(graph.ndata['eig'][:, 1:3]).unsqueeze(1).expand(
                -1, self._num_heads, -1) *
                  self.attn_r).sum(dim=-1).unsqueeze(-1)
        graph.ndata.update({'ft': feat, 'el': el, 'er': er})
        # compute edge attention
        graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
        e = self.leaky_relu(graph.edata.pop('e'))
        # compute softmax
        graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
        # message passing
        graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
        rst = graph.ndata['ft']
        # residual
        if self.res_fc is not None:
            resval = self.res_fc(h).view(h.shape[0], -1, self._out_feats)
            rst = rst + resval
        # activation
        if self.activation:
            rst = self.activation(rst)
        return rst
Example #19
    def forward(self, graph, feat):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError(
                        'There are 0-in-degree nodes in the graph, '
                        'output for those nodes will be invalid. '
                        'This is harmful for some applications, '
                        'causing silent performance regression. '
                        'Adding self-loop on the input graph by '
                        'calling `g = dgl.add_self_loop(g)` will resolve '
                        'the issue. Setting ``allow_zero_in_degree`` '
                        'to be `True` when constructing this module will '
                        'suppress the check and let the code run.')

            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                if not hasattr(self, 'fc_src'):
                    self.fc_src, self.fc_dst = self.fc, self.fc
                feat_src = self.fc_src(h_src).view(*h_src.shape[:-1],
                                                   self._num_heads,
                                                   self._out_feats)
                feat_dst = self.fc_dst(h_dst).view(*h_dst.shape[:-1],
                                                   self._num_heads,
                                                   self._out_feats)
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src).view(
                    *h_src.shape[:-1], self._num_heads, self._out_feats)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]

            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({'ft': feat_src, 'el': el})
            graph.dstdata.update({'er': er})
            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(graph.edata.pop('e'))
            # compute softmax
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
            # message passing
            graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']
            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1,
                                                 self._out_feats)
                rst = rst + resval
            # activation
            if self.activation:
                rst = self.activation(rst)
            return rst
Example #20
    def forward(self,
                g,
                h,
                logits,
                old_z,
                attn_l,
                attn_r,
                shared_tau=True,
                tau_1=None,
                tau_2=None):
        g = g.local_var()
        if self.feat_drop:
            h = self.feat_drop(h)

        if hasattr(self, 'fc'):
            feat = self.fc(h).view(-1, self._num_heads, self._out_feats)
        else:
            feat = h
        g.ndata['h'] = feat  # (n_node, n_feat)
        g.ndata['logits'] = logits

        el = (feat * attn_l).sum(dim=-1).unsqueeze(-1)
        er = (feat * attn_r).sum(dim=-1).unsqueeze(-1)
        g.ndata.update({'ft': feat, 'el': el, 'er': er})
        # compute edge attention
        g.apply_edges(fn.u_add_v('el', 'er', 'e'))
        e = self.leaky_relu(g.edata.pop('e'))
        # compute softmax
        g.edata['a'] = self.attn_drop(edge_softmax(g, e))

        g.update_all(message_func=adaptive_attn_message_func,
                     reduce_func=adaptive_attn_reduce_func)
        f1 = g.ndata.pop('f1')
        f2 = g.ndata.pop('f2')
        norm_f1 = self.ln_1(f1)
        norm_f2 = self.ln_2(f2)
        if shared_tau:
            z = F.sigmoid((-1) * (norm_f1 - tau_1)) * F.sigmoid(
                (-1) * (norm_f2 - tau_2))
        else:
            # tau for each layer
            z = F.sigmoid((-1) * (norm_f1 - self.tau_1)) * F.sigmoid(
                (-1) * (norm_f2 - self.tau_2))

        gate = torch.min(old_z, z)

        agg = g.ndata.pop('agg')
        normagg = agg * g.ndata['norm'].unsqueeze(
            1)  # normalization by tgt degree

        if self.activation:
            normagg = self.activation(normagg)
        new_h = feat + gate.unsqueeze(2) * normagg
        return new_h, z
Example #21
    def forward(self, g, node_feats, edge_feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        node_feats : float32 tensor of shape (V, node_in_feats) or (V, n_head, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        float32 tensor of shape (V, node_out_feats) or (V, n_head, node_out_feats)
            Updated node features.
        """

        g = g.local_var()
        # In the paper, the src node, dst node and edge features are concatenated
        # and multiplied by a single matrix. We have optimized this step
        # by using three separate matrix multiplications.
        g.ndata['src'] = self.dropout(self.attn_src(node_feats))
        g.ndata['dst'] = self.dropout(self.attn_dst(node_feats))
        edg_atn = self.dropout(self.attn_edg(edge_feats)).unsqueeze(-2)
        g.apply_edges(fn.u_add_v('src', 'dst', 'e'))
        atn_scores = self.act(g.edata.pop('e') + edg_atn)

        atn_scores = self.attn_dot(atn_scores)
        atn_scores = self.dropout(edge_softmax(g, atn_scores))

        g.ndata['src'] = self.msg_src(node_feats)
        g.ndata['dst'] = self.msg_dst(node_feats)
        g.apply_edges(fn.u_add_v('src', 'dst', 'e'))
        atn_inp = g.edata.pop('e') + self.msg_edg(edge_feats).unsqueeze(-2)
        atn_inp = self.act(atn_inp)
        g.edata['msg'] = atn_scores * atn_inp
        g.update_all(fn.copy_e('msg', 'm'), fn.sum('m', 'feat'))
        out = g.ndata.pop('feat') + self.wgt_n(node_feats)
        return self.act(out)
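
The comment in the example above relies on the identity that one linear map applied to a concatenation equals the sum of its blocks applied to the parts (ignoring bias terms). A small self-contained check with made-up dimensions, independent of the classes on this page:

import torch

# W [x_u || x_v || e] == W_u x_u + W_v x_v + W_e e  when W = [W_u | W_v | W_e]
d_n, d_e, d_out = 4, 3, 5
W = torch.randn(d_out, 2 * d_n + d_e)
W_u, W_v, W_e = W[:, :d_n], W[:, d_n:2 * d_n], W[:, 2 * d_n:]

x_u, x_v, e = torch.randn(d_n), torch.randn(d_n), torch.randn(d_e)
concatenated = W @ torch.cat([x_u, x_v, e])
split = W_u @ x_u + W_v @ x_v + W_e @ e
assert torch.allclose(concatenated, split, atol=1e-6)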
Example #22
 def expected_output():
     g.srcdata.update({'ft': feat_src, 'el': el})
     g.dstdata.update({'er': er})
     g.apply_edges(fn.u_add_v('el', 'er', 'e'))
     e = leaky_relu(g.edata.pop('e'))
     g.edata['out'] = th.exp(e)
     g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
     g.apply_edges(fn.e_div_v('out', 'out_sum', 'out1'))
     # Omit attn_drop for deterministic execution
     g.edata['a'] = g.edata['out1']
     # message passing
     g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
     rst = g.dstdata['ft']
     return rst
Example #23
    def forward(self, graph: dgl.DGLHeteroGraph, feat: tuple, dst_node_transformation_weight: nn.Parameter,
                src_node_transformation_weight: nn.Parameter, src_nodes_attention_weight: nn.Parameter):
        r"""Compute graph attention network layer.
        Parameters
        ----------
        graph : specific relational DGLHeteroGraph
        feat : pair of torch.Tensor
            The pair contains two tensors of shape (N_{in}, D_{in_{src}})` and (N_{out}, D_{in_{dst}}).
        dst_node_transformation_weight: Parameter (input_dst_dim, n_heads * hidden_dim)
        src_node_transformation_weight: Parameter (input_src_dim, n_heads * hidden_dim)
        src_nodes_attention_weight: Parameter (n_heads, 2 * hidden_dim)
        Returns
        -------
        torch.Tensor, shape (N, H, D_out)` where H is the number of heads, and D_out is size of output feature.
        """
        graph = graph.local_var()
        # Tensor, (N_src, input_src_dim)
        feat_src = self.dropout(feat[0])
        # Tensor, (N_dst, input_dst_dim)
        feat_dst = self.dropout(feat[1])
        # Tensor, (N_src, n_heads, hidden_dim) -> (N_src, input_src_dim) * (input_src_dim, n_heads * hidden_dim)
        feat_src = torch.matmul(feat_src, src_node_transformation_weight).view(-1, self._num_heads, self._out_feats)
        # Tensor, (N_dst, n_heads, hidden_dim) -> (N_dst, input_dst_dim) * (input_dst_dim, n_heads * hidden_dim)
        feat_dst = torch.matmul(feat_dst, dst_node_transformation_weight).view(-1, self._num_heads, self._out_feats)

        # First decompose the weight vector into [a_l || a_r], then
        # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j. This implementation is much more efficient.
        # Tensor, (N_dst, n_heads, 1),   (N_dst, n_heads, hidden_dim) * (n_heads, hidden_dim)
        e_dst = (feat_dst * src_nodes_attention_weight[:, :self._out_feats]).sum(dim=-1, keepdim=True)
        # Tensor, (N_src, n_heads, 1),   (N_src, n_heads, hidden_dim) * (n_heads, hidden_dim)
        e_src = (feat_src * src_nodes_attention_weight[:, self._out_feats:]).sum(dim=-1, keepdim=True)
        # (N_src, n_heads, hidden_dim), (N_src, n_heads, 1)
        graph.srcdata.update({'ft': feat_src, 'e_src': e_src})
        # (N_dst, n_heads, 1)
        graph.dstdata.update({'e_dst': e_dst})
        # compute edge attention, e_src and e_dst are a_src * Wh_src and a_dst * Wh_dst respectively.
        graph.apply_edges(fn.u_add_v('e_src', 'e_dst', 'e'))
        # shape (edges_num, heads, 1)
        e = self.leaky_relu(graph.edata.pop('e'))

        # compute softmax
        graph.edata['a'] = edge_softmax(graph, e)

        graph.update_all(fn.u_mul_e('ft', 'a', 'msg'), fn.sum('msg', 'ft'))
        # (N_dst, n_heads * hidden_dim),   (N_dst, n_heads, hidden_dim) reshape
        dst_features = graph.dstdata.pop('ft').reshape(-1, self._num_heads * self._out_feats)

        dst_features = F.relu(dst_features)

        return dst_features
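
The attention decomposition described in the comment above (and used throughout the GAT-style examples on this page) can be checked in isolation. The shapes below are made up for illustration; attn plays the role of src_nodes_attention_weight:

import torch

# Per head: a^T [Wh_i || Wh_j] == a_l^T Wh_i + a_r^T Wh_j, with a = [a_l || a_r]
n_heads, d = 4, 8
attn = torch.randn(n_heads, 2 * d)
wh_i = torch.randn(n_heads, d)   # per-head features of one source node
wh_j = torch.randn(n_heads, d)   # per-head features of one destination node

full = (attn * torch.cat([wh_i, wh_j], dim=-1)).sum(-1)
decomposed = (attn[:, :d] * wh_i).sum(-1) + (attn[:, d:] * wh_j).sum(-1)
assert torch.allclose(full, decomposed, atol=1e-6)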
Example #24
    def forward(self, graph, feat, get_attention=False):
            # Check in degree and generate error
            if (graph.in_degrees()==0).any():
                raise DGLError('There are 0-in-degree nodes in the graph, '
                                   'output for those nodes will be invalid. '
                                   'This is harmful for some applications, '
                                   'causing silent performance regression. '
                                   'Adding self-loop on the input graph by '
                                   'calling `g = dgl.add_self_loop(g)` will resolve '
                                   'the issue. Setting ``allow_zero_in_degree`` '
                                   'to be `True` when constructing this module will '
                                   'suppress the check and let the code run.')
            # projection process to get importance vector y
            graph.ndata['y'] = torch.abs(torch.matmul(self.p,feat.T).view(-1))/torch.norm(self.p,p=2)
            # Use edge message passing function to get the weight from src node
            graph.apply_edges(fn.copy_u('y','y'))
            # Select Top k neighbors
            subgraph = select_topk(graph,self.k,'y')
            # Sigmoid as information threshold
            subgraph.ndata['y'] = torch.sigmoid(subgraph.ndata['y'])
            # Using vector matrix elementwise mul for acceleration
            feat = subgraph.ndata['y'].view(-1,1)*feat
            feat = self.feat_drop(feat)
            h = self.fc(feat).view(-1, self.num_heads, self.out_feats)
            el = (h * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (h * self.attn_r).sum(dim=-1).unsqueeze(-1)
            # Assign the value on the subgraph
            subgraph.srcdata.update({'ft': h, 'el': el})
            subgraph.dstdata.update({'er': er})
            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            subgraph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(subgraph.edata.pop('e'))
            # compute softmax
            subgraph.edata['a'] = self.attn_drop(edge_softmax(subgraph, e))
            # message passing
            subgraph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                             fn.sum('m', 'ft'))
            rst = subgraph.dstdata['ft']
            # activation
            if self.activation:
                rst = self.activation(rst)
            # Residual
            if self.residual:
                rst = rst + self.residual_module(feat).view(feat.shape[0],-1,self.out_feats)

            if get_attention:
                return rst, subgraph.edata['a']
            else:
                return rst
Example #25
 def call(self, graph, feat):
     with graph.local_scope():
         if isinstance(feat, tuple):
             h_src = self.feat_drop(feat[0])
             h_dst = self.feat_drop(feat[1])
             if not hasattr(self, 'fc_src'):
                 self.fc_src, self.fc_dst = self.fc, self.fc
             feat_src = tf.reshape(self.fc_src(h_src),
                                   (-1, self._num_heads, self._out_feats))
             feat_dst = tf.reshape(self.fc_dst(h_dst),
                                   (-1, self._num_heads, self._out_feats))
         else:
             h_src = h_dst = self.feat_drop(feat)
             feat_src = feat_dst = tf.reshape(
                 self.fc(h_src), (-1, self._num_heads, self._out_feats))
             if graph.is_block:
                 feat_dst = feat_src[:graph.number_of_dst_nodes()]
         # NOTE: GAT paper uses "first concatenation then linear projection"
         # to compute attention scores, while ours is "first projection then
         # addition", the two approaches are mathematically equivalent:
         # We decompose the weight vector a mentioned in the paper into
         # [a_l || a_r], then
         # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
         # Our implementation is much more efficient because we do not need to
         # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
         # addition could be optimized with DGL's built-in function u_add_v,
         # which further speeds up computation and saves memory footprint.
         el = tf.reduce_sum(feat_src * self.attn_l, axis=-1, keepdims=True)
         er = tf.reduce_sum(feat_dst * self.attn_r, axis=-1, keepdims=True)
         graph.srcdata.update({'ft': feat_src, 'el': el})
         graph.dstdata.update({'er': er})
         # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
         graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
         e = self.leaky_relu(graph.edata.pop('e'))
         # compute softmax
         graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
         # message passing
         graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
         rst = graph.dstdata['ft']
         # residual
         if self.res_fc is not None:
             resval = tf.reshape(self.res_fc(h_dst),
                                 (h_dst.shape[0], -1, self._out_feats))
             rst = rst + resval
         # activation
         if self.activation:
             rst = self.activation(rst)
         return rst
Example #26
 def forward(self, graph, feat, soft_label):
     with graph.local_scope():
         if not self._allow_zero_in_degree:
             if (graph.in_degrees() == 0).any():
                 raise DGLError('There are 0-in-degree nodes in the graph, '
                                'output for those nodes will be invalid. '
                                'This is harmful for some applications, '
                                'causing silent performance regression. '
                                'Adding self-loop on the input graph by '
                                'calling `g = dgl.add_self_loop(g)` will resolve '
                                'the issue. Setting ``allow_zero_in_degree`` '
                                'to be `True` when constructing this module will '
                                'suppress the check and let the code run.')
         if self.ptype == 'ind':
             feat_src = h_dst = self.feat_drop(feat)
             el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
             er = th.zeros(graph.num_nodes(), device=graph.device)
         elif self.ptype == 'tra':
             feat_src = self.feat_drop(self.fc_emb)
             feat_dst = h_dst = th.zeros(graph.num_nodes(), device=graph.device)
             el = feat_src
             er = feat_dst
         cog_label = soft_label
         graph.srcdata.update({'ft': cog_label, 'el': el})
         graph.dstdata.update({'er': er})
          # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
         graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
         # graph.edata['e'] = th.ones(graph.num_edges(), device=graph.device)  # non-parameterized PLP
         e = graph.edata.pop('e')
         # compute softmax
         graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
         att = graph.edata['a'].squeeze()
         # message passing
         graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                          fn.sum('m', 'ft'))
         if self.mlp_layers > 0:
             rst = th.sigmoid(self.lr_alpha) * graph.dstdata['ft'] + \
                   th.sigmoid(-self.lr_alpha) * self.mlp(feat)
         else:
             rst = graph.dstdata['ft']
         # residual
         if self.res_fc is not None:
             resval = self.res_fc(h_dst)
             rst = rst + resval
         # activation
         if self.activation:
             rst = self.activation(rst)
         return rst, att, th.sigmoid(self.lr_alpha).squeeze(), el.squeeze(), er.squeeze()
Example #27
 def forward(self, sg, feat):
     with sg.local_scope():
         if self.batch_norm is not None:
             feat = self.batch_norm(feat)
         q = self.fc_q(feat)
         k = self.fc_k(feat)
         v = self.fc_v(feat)
         sg.ndata.update({'q': q, 'k': k, 'v': v})
         sg.apply_edges(fn.u_add_v('q', 'k', 'e'))
         e = self.attn_e(th.sigmoid(sg.edata['e']))
         sg.edata['a'] = edge_softmax(sg, e)
         sg.update_all(fn.u_mul_e('v', 'a', 'm'), fn.sum('m', 'ft'))
         rst = sg.ndata['ft']
         if self.activation is not None:
             rst = self.activation(rst)
         return rst
Example #28
 def forward(self, feat):
     # equation (1)
     g = self.graph.local_var()
     g.ndata['h'] = feat.mm(getattr(self, 'W'))
     g.ndata['el'] = feat.mm(getattr(self, 'al'))
     g.ndata['er'] = feat.mm(getattr(self, 'ar'))
     g.apply_edges(fn.u_add_v('el', 'er', 'e'))
     e = F.leaky_relu(g.edata['e'])
     # compute softmax over the attention logits to get edge weights 'w'
     g.edata['w'] = F.softmax(e)
     # message passing (src_mul_edge weights each source feature by 'w' on the edge)
     g.update_all(fn.src_mul_edge('h', 'w', 'm'), fn.sum('m', 'h'))
     rst = g.ndata['h']
     #rst = self.linear(rst)
     #rst = self.activation(rst)
     return rst
Example #29
    def forward(self, graph: DGLGraph, features, drop_edge_ids=None):
        ###Attention computation: pre-normalization structure
        graph = graph.local_var()
        h = self.graph_norm(features)
        # feat_head = self.fc_head(self.feat_drop(h)).view(-1, self._num_heads, self._att_dim)
        # feat_tail = self.fc_tail(self.feat_drop(h)).view(-1, self._num_heads, self._att_dim)
        feat_head = torch.tanh(self.fc_head(self.feat_drop(h))).view(
            -1, self._num_heads, self._att_dim)
        feat_tail = torch.tanh(self.fc_tail(self.feat_drop(h))).view(
            -1, self._num_heads, self._att_dim)
        # feat_head = F.relu(self.fc_head(self.feat_drop(h))).view(-1, self._num_heads, self._att_dim)
        # feat_tail = F.relu(self.fc_tail(self.feat_drop(h))).view(-1, self._num_heads, self._att_dim)
        feat = self.fc(self.feat_drop(h)).view(-1, self._num_heads,
                                               self._att_dim)
        eh = (feat_head * self.attn_h).sum(dim=-1).unsqueeze(-1)
        et = (feat_tail * self.attn_t).sum(dim=-1).unsqueeze(-1)
        graph.ndata.update({'ft': feat, 'eh': eh, 'et': et})
        graph.apply_edges(fn.u_add_v('eh', 'et', 'e'))
        attations = graph.edata.pop('e')
        attations = self.leaky_relu(attations)
        if drop_edge_ids is not None:
            attations[drop_edge_ids] = self.attention_mask_value

        if self.top_k <= 0:
            graph.edata['a'] = edge_softmax(graph, attations)
        else:
            if self.topk_type == 'local':
                graph.edata['e'] = attations
                attations = self.topk_attention(graph)
                graph.edata['a'] = edge_softmax(
                    graph, attations)  ##return attention scores
            else:
                graph.edata['e'] = edge_softmax(graph, attations)
                graph.edata['a'] = self.topk_attention_softmax(graph)

        rst = self.ppr_estimation(graph=graph)
        rst = rst.flatten(1)
        rst = self.fc_out(rst)
        resval = self.res_fc(features)
        rst = resval + self.feat_drop(rst)

        rst_ff = self.feed_forward(self.ff_norm(rst))
        rst = rst + self.feat_drop(rst_ff)
        # +++++++
        attations = graph.edata.pop('a')
        # +++++++
        return rst, attations
Example #30
    def forward(self, g, h, e):
        
        h_in = h # for residual connection
        e_in = e # for residual connection
        
        g.ndata['h']  = h 
        g.ndata['Ah'] = self.A(h) 
        g.ndata['Bh'] = self.B(h) 
        g.ndata['Dh'] = self.D(h)
        g.ndata['Eh'] = self.E(h) 
        g.edata['e']  = e 
        g.edata['Ce'] = self.C(e) 

        g.apply_edges(fn.u_add_v('Dh', 'Eh', 'DEh'))
        g.edata['e'] = g.edata['DEh'] + g.edata['Ce']
        g.edata['sigma'] = torch.sigmoid(g.edata['e'])
        g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma')) 

        g.ndata['eee'] = g.ndata['Bh'] / (g.ndata['sum_sigma'] + 1e-6)  # normalize Bh by the summed gates before message passing
        
        g.update_all(fn.u_mul_e('eee', 'sigma', 'm'), self._reducer) 

        g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_sigma_h']  # Ah plus the gated neighbour aggregation

        #h, e = self.update_all_p_norm(g)
        h = g.ndata['h'] # result of graph convolution
        e = g.edata['e'] # result of graph convolution

        if self.batch_norm:
            h = self.bn_node_h(h) # batch normalization  
            e = self.bn_node_e(e) # batch normalization  
        
        h = F.relu(h) # non-linear activation
        e = F.relu(e) # non-linear activation

        if self.residual:
            h = h_in + h # residual connection
            e = e_in + e # residual connection
        
        h = F.dropout(h, self.dropout, training=self.training)
        e = F.dropout(e, self.dropout, training=self.training)
       
        return h, e