Example #1
    def infer(self, g, nids_eq_pos, eids_eq_pos, nids_eq_pos_leaf, g_inter,
              readout_ids):
        # Part I: self-attention
        h = g.nodes[nids_eq_pos].data['h']
        if self.rel_pos:
            g.edges[eids_eq_pos].data['ak'] = self.embed_ak(
                g.edges[eids_eq_pos].data['etype'])

        g.nodes[nids_eq_pos].data['q'] = self.proj_q[0](h).view(
            -1, self.h, self.d_k)
        g.nodes[nids_eq_pos].data['k'] = self.proj_k[0](h).view(
            -1, self.h, self.d_k)
        g.nodes[nids_eq_pos].data['v'] = self.proj_v[0](h).view(
            -1, self.h, self.d_k)

        g.apply_edges(
            lambda edges:
            {'e': (edges.src['k'] * edges.dst['q']).sum(dim=-1, keepdim=True)},
            eids_eq_pos)
        e = g.edges[eids_eq_pos].data['e']
        # relative positional encoding
        if self.rel_pos:
            g.apply_edges(
                lambda edges: {
                    'e_rel': (edges.data['ak'].unsqueeze(1) * edges.dst['q']).
                    sum(dim=-1, keepdim=True)
                }, eids_eq_pos)
            e = e + g.edges[eids_eq_pos].data['e_rel']
        # softmax
        g.edges[eids_eq_pos].data['a'] = self.drop_att[0](edge_softmax(
            g, e / np.sqrt(self.d_k), eids_eq_pos))
        # spmm
        g.send_and_recv(eids_eq_pos, fn.u_mul_e('v', 'a', 'm'),
                        fn.sum('m', 'o'))
        o = g.nodes[nids_eq_pos].data['o'].view(-1, self.d_k * self.h)
        o = self.drop_h[0](self.proj_o[0](o))
        g.nodes[nids_eq_pos].data['h'] = self.norm_in[0](h + o)

        # Part II: attend to memory
        h = g.nodes[nids_eq_pos_leaf].data['h']
        q = self.proj_q[1](h).view(-1, self.h, self.d_k)
        g_inter.nodes[readout_ids].data['q'] = q
        g_inter.apply_edges(
            lambda edges:
            {'e': (edges.src['k'] * edges.dst['q']).sum(dim=-1, keepdim=True)})
        # softmax
        g_inter.edata['a'] = self.drop_att[1](edge_softmax(
            g_inter, g_inter.edata['e'] / np.sqrt(self.d_k)))
        # spmm
        g_inter.update_all(fn.u_mul_e('v', 'a', 'm'), fn.sum('m', 'o'))
        o = g_inter.nodes[readout_ids].data['o'].view(-1, self.d_k * self.h)
        o = self.drop_h[1](self.proj_o[1](o))
        g.nodes[nids_eq_pos_leaf].data['h'] = h + o
        h = self.norm_in[1](g.nodes[nids_eq_pos].data['h'])

        # FFN
        h = self.norm_inter(h + self.ffn(h))
        g.nodes[nids_eq_pos].data['h'] = h
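Example #1 restricts both the softmax and the aggregation to a subset of edges (eids_eq_pos). Below is a minimal standalone sketch of that pattern on a made-up toy graph with invented feature sizes; it assumes a recent DGL where edge_softmax is importable from dgl.nn.functional and accepts an edge-ID argument (older releases expose it elsewhere under dgl.nn). The same import assumption applies to the other sketches in this listing.

import torch
import dgl
import dgl.function as fn
from dgl.nn.functional import edge_softmax

# Toy graph: 4 nodes, 6 directed edges, 2 heads of dimension 8.
g = dgl.graph(([0, 0, 1, 2, 2, 3], [1, 2, 2, 0, 3, 1]))
g.ndata['q'] = torch.randn(4, 2, 8)
g.ndata['k'] = torch.randn(4, 2, 8)
g.ndata['v'] = torch.randn(4, 2, 8)

eids = torch.tensor([0, 1, 3, 4])  # attend over this edge subset only

# Dot-product scores on the selected edges only.
g.apply_edges(
    lambda edges: {'e': (edges.src['k'] * edges.dst['q']).sum(-1, keepdim=True)},
    eids)
e = g.edges[eids].data['e'] / 8 ** 0.5

# Softmax restricted to the subset, then weighted aggregation over the
# same subset via send_and_recv.
g.edges[eids].data['a'] = edge_softmax(g, e, eids)
g.send_and_recv(eids, fn.u_mul_e('v', 'a', 'm'), fn.sum('m', 'o'))
print(g.ndata['o'].shape)  # torch.Size([4, 2, 8])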
Example #2
 def forward(self, graph, feat):
     graph = graph.local_var()
     if isinstance(feat, tuple):
         h_src = self.feat_drop(feat[0])
         h_dst = self.feat_drop(feat[1])
         feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
         feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
     else:
         h_src = h_dst = self.feat_drop(feat)
         feat_src = feat_dst = self.fc(h_src).view(
             -1, self._num_heads, self._out_feats)
     el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
     er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
     graph.srcdata.update({'ft': feat_src, 'el': el})
     graph.dstdata.update({'er': er})
     # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
     graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
     e = self.leaky_relu(graph.edata.pop('e'))
     # compute softmax
     graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
     # message passing
     graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                      fn.sum('m', 'ft'))
     rst = graph.dstdata['ft']
     # residual
     if self.res_fc is not None:
         resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
         rst = rst + resval
     # activation
     if self.activation:
         rst = self.activation(rst)
     return rst
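Example #2 is essentially a GAT layer: project, score each edge with a_l·Wh_i + a_r·Wh_j, normalize with edge_softmax, then aggregate. The same pipeline can be run without the surrounding module; the graph, feature sizes, and parameter tensors below are invented for illustration.

import torch
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn.functional import edge_softmax

num_heads, in_feats, out_feats = 4, 32, 16
g = dgl.graph(([0, 1, 2, 3, 0], [1, 2, 3, 0, 2]))
feat = torch.randn(g.num_nodes(), in_feats)

fc = torch.nn.Linear(in_feats, num_heads * out_feats, bias=False)
attn_l = torch.randn(1, num_heads, out_feats)
attn_r = torch.randn(1, num_heads, out_feats)

ft = fc(feat).view(-1, num_heads, out_feats)
el = (ft * attn_l).sum(dim=-1, keepdim=True)  # a_l . Wh_i, shape (N, H, 1)
er = (ft * attn_r).sum(dim=-1, keepdim=True)  # a_r . Wh_j, shape (N, H, 1)

g.ndata.update({'ft': ft, 'el': el, 'er': er})
g.apply_edges(fn.u_add_v('el', 'er', 'e'))    # e_ij = a_l.Wh_i + a_r.Wh_j
e = F.leaky_relu(g.edata.pop('e'), negative_slope=0.2)
g.edata['a'] = edge_softmax(g, e)             # normalize over each node's in-edges
g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
print(g.ndata['ft'].shape)                    # torch.Size([N, num_heads, out_feats])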
Example #3
    def forward(self, graph: DGLGraph, features, drop_edge_ids=None):
        # Attention computation: pre-normalization structure
        graph = graph.local_var()
        h = self.graph_norm(features)
        # feat_head = self.fc_head(self.feat_drop(h)).view(-1, self._num_heads, self._att_dim)
        # feat_tail = self.fc_tail(self.feat_drop(h)).view(-1, self._num_heads, self._att_dim)
        feat_head = torch.tanh(self.fc_head(self.feat_drop(h))).view(
            -1, self._num_heads, self._att_dim)
        feat_tail = torch.tanh(self.fc_tail(self.feat_drop(h))).view(
            -1, self._num_heads, self._att_dim)
        # feat_head = F.relu(self.fc_head(self.feat_drop(h))).view(-1, self._num_heads, self._att_dim)
        # feat_tail = F.relu(self.fc_tail(self.feat_drop(h))).view(-1, self._num_heads, self._att_dim)
        feat = self.fc(self.feat_drop(h)).view(-1, self._num_heads,
                                               self._att_dim)
        eh = (feat_head * self.attn_h).sum(dim=-1).unsqueeze(-1)
        et = (feat_tail * self.attn_t).sum(dim=-1).unsqueeze(-1)
        graph.ndata.update({'ft': feat, 'eh': eh, 'et': et})
        graph.apply_edges(fn.u_add_v('eh', 'et', 'e'))
        attations = graph.edata.pop('e')
        attations = self.leaky_relu(attations)
        if drop_edge_ids is not None:
            attations[drop_edge_ids] = self.attention_mask_value

        if self.top_k <= 0:
            graph.edata['a'] = edge_softmax(graph, attations)
        else:
            if self.topk_type == 'local':
                graph.edata['e'] = attations
                attations = self.topk_attention(graph)
                graph.edata['a'] = edge_softmax(
                    graph, attations)  # normalized attention scores
            else:
                graph.edata['e'] = edge_softmax(graph, attations)
                graph.edata['a'] = self.topk_attention_softmax(graph)

        rst = self.ppr_estimation(graph=graph)
        rst = rst.flatten(1)
        rst = self.fc_out(rst)
        resval = self.res_fc(features)
        rst = resval + self.feat_drop(rst)

        rst_ff = self.feed_forward(self.ff_norm(rst))
        rst = rst + self.feat_drop(rst_ff)
        # pull the final attention scores off the graph before returning
        attations = graph.edata.pop('a')
        return rst, attations
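The drop_edge_ids branch above masks edges by overwriting their logits with self.attention_mask_value before the softmax. A small sketch of that masking trick, with a made-up graph and a hypothetical mask value of -1e9: the masked edge ends up with (numerically) zero weight while the remaining in-edges of each destination still sum to one.

import torch
import dgl
from dgl.nn.functional import edge_softmax

g = dgl.graph(([0, 1, 2, 3], [1, 2, 0, 0]))
logits = torch.randn(g.num_edges(), 1)

# Push the logits of dropped edges to a very negative value so that,
# after edge_softmax, their attention weights are effectively zero.
drop_edge_ids = torch.tensor([3])
logits[drop_edge_ids] = -1e9

a = edge_softmax(g, logits)
print(a.squeeze(-1))  # the masked edge gets ~0 attention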
Example #4
 def compute_attention(self, g):
     # compute attention weights and store them on edges
     g = g.local_var()
     for i in range(self._n_relations):
         e_idxs = g.filter_edges(lambda edges: edges.data['type'] == i)
         self.W_r = self.W_R[i]
         g.apply_edges(self._att_score, e_idxs)
     w = edge_softmax(g, g.edata.pop('att_w'))
     return w
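Example #4 scores each relation separately (filter_edges plus a per-relation weight) but normalizes all in-edges of a node with a single edge_softmax at the end. A sketch of that pattern follows; the graph, edge types, and the scoring lambda are invented here and merely stand in for the example's model-specific _att_score.

import torch
import dgl
from dgl.nn.functional import edge_softmax

g = dgl.graph(([0, 0, 1, 2], [1, 2, 2, 0]))
g.edata['type'] = torch.tensor([0, 1, 0, 1])
g.ndata['h'] = torch.randn(g.num_nodes(), 8)

# One relation-specific projection per edge type (invented scoring).
W_R = [torch.randn(8, 8) for _ in range(2)]

for i in range(2):
    e_idxs = g.filter_edges(lambda edges: edges.data['type'] == i)
    W_r = W_R[i]
    g.apply_edges(
        lambda edges: {'att_w': ((edges.src['h'] @ W_r) *
                                 (edges.dst['h'] @ W_r)).sum(-1, keepdim=True)},
        e_idxs)

# A single softmax over all in-edges of each node, regardless of relation.
w = edge_softmax(g, g.edata.pop('att_w'))
print(w.squeeze(-1))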
Example #5
    def forward(self, graph, feat):
        r"""Compute graph attention network layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor
            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
            is size of input feature, :math:`N` is the number of nodes.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
            is the number of heads, and :math:`D_{out}` is size of output feature.
        """
        graph = graph.local_var()
        h = self.feat_drop(feat)
        feat = self.fc(h).view(-1, self._num_heads, self._out_feats)
        if self._num_heads == 8:
            feat_normal = feat[:, 3:, :]
            feat_eig = feat[:, :3, :]
            el_normal = (feat_normal * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er_normal = (feat_normal * self.attn_r).sum(dim=-1).unsqueeze(-1)
            el_eig = (th.abs(graph.ndata['eig'][:, 1:3]).unsqueeze(1).expand(
                -1, 3, -1) * self.attn_l_eig).sum(dim=-1).unsqueeze(-1)
            er_eig = (th.abs(graph.ndata['eig'][:, 1:3]).unsqueeze(1).expand(
                -1, 3, -1) * self.attn_r_eig).sum(dim=-1).unsqueeze(-1)
            el = th.cat([el_normal, el_eig], dim=1)
            er = th.cat([er_normal, er_eig], dim=1)
        else:
            el = (th.abs(graph.ndata['eig'][:, 1:3]).unsqueeze(1).expand(
                -1, self._num_heads, -1) *
                  self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (th.abs(graph.ndata['eig'][:, 1:3]).unsqueeze(1).expand(
                -1, self._num_heads, -1) *
                  self.attn_r).sum(dim=-1).unsqueeze(-1)
        graph.ndata.update({'ft': feat, 'el': el, 'er': er})
        # compute edge attention
        graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
        e = self.leaky_relu(graph.edata.pop('e'))
        # compute softmax
        graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
        # message passing
        graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
        rst = graph.ndata['ft']
        # residual
        if self.res_fc is not None:
            resval = self.res_fc(h).view(h.shape[0], -1, self._out_feats)
            rst = rst + resval
        # activation
        if self.activation:
            rst = self.activation(rst)
        return rst
Example #6
 def forward(self, g, feature):
     with g.local_scope():
         z = self.fc(feature)
         g.ndata['z'] = z
         # Equation (2)
         g.apply_edges(self.edge_attention)  # calculate e_{ij}
         # Softmax over each source node's out-edges -> compute it on the reversed graph
         rg = g.reverse(copy_ndata=False, copy_edata=True)
         g.edata['alpha'] = edge_softmax(rg, rg.edata['e'])
         # Equation (3)
         g.update_all(self.message_func, self.reduce_func)
         # output            
         h = g.ndata['h']
         return h
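By default edge_softmax normalizes over the incoming edges of each destination node; Example #6 instead wants a softmax over each source node's outgoing edges, which is why it calls edge_softmax on the reversed graph. A small check of the difference, on a made-up graph with random logits:

import torch
import dgl
import dgl.function as fn
from dgl.nn.functional import edge_softmax

g = dgl.graph(([0, 0, 0, 1], [1, 2, 3, 3]))
e = torch.randn(g.num_edges(), 1)

# Default: normalize over the incoming edges of each destination node.
a_dst = edge_softmax(g, e)

# On the reversed graph the same call normalizes over each source node's
# outgoing edges instead, which is what the layer above wants.
rg = g.reverse(copy_ndata=False, copy_edata=False)
a_src = edge_softmax(rg, e)

# Check: a_src sums to 1 over each original node's out-edges
# (= the in-edges of the reversed graph).
rg.edata['a'] = a_src
rg.update_all(fn.copy_e('a', 'm'), fn.sum('m', 's'))
print(rg.ndata['s'].squeeze(-1))  # ~[1., 1., 0., 0.]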
Example #7
    def forward(self, graph: dgl.DGLHeteroGraph, feat: tuple, dst_node_transformation_weight: nn.Parameter,
                src_node_transformation_weight: nn.Parameter, src_nodes_attention_weight: nn.Parameter):
        r"""Compute graph attention network layer.
        Parameters
        ----------
        graph : specific relational DGLHeteroGraph
        feat : pair of torch.Tensor
            The pair contains two tensors of shape (N_{in}, D_{in_{src}})` and (N_{out}, D_{in_{dst}}).
        dst_node_transformation_weight: Parameter (input_dst_dim, n_heads * hidden_dim)
        src_node_transformation_weight: Parameter (input_src_dim, n_heads * hidden_dim)
        src_nodes_attention_weight: Parameter (n_heads, 2 * hidden_dim)
        Returns
        -------
        torch.Tensor, shape (N, H, D_out)` where H is the number of heads, and D_out is size of output feature.
        """
        graph = graph.local_var()
        # Tensor, (N_src, input_src_dim)
        feat_src = self.dropout(feat[0])
        # Tensor, (N_dst, input_dst_dim)
        feat_dst = self.dropout(feat[1])
        # Tensor, (N_src, n_heads, hidden_dim) -> (N_src, input_src_dim) * (input_src_dim, n_heads * hidden_dim)
        feat_src = torch.matmul(feat_src, src_node_transformation_weight).view(-1, self._num_heads, self._out_feats)
        # Tensor, (N_dst, n_heads, hidden_dim) -> (N_dst, input_dst_dim) * (input_dst_dim, n_heads * hidden_dim)
        feat_dst = torch.matmul(feat_dst, dst_node_transformation_weight).view(-1, self._num_heads, self._out_feats)

        # first decompose the weight vector into [a_l || a_r], then
        # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j; this decomposition is more efficient
        # Tensor, (N_dst, n_heads, 1),   (N_dst, n_heads, hidden_dim) * (n_heads, hidden_dim)
        e_dst = (feat_dst * src_nodes_attention_weight[:, :self._out_feats]).sum(dim=-1, keepdim=True)
        # Tensor, (N_src, n_heads, 1),   (N_src, n_heads, hidden_dim) * (n_heads, hidden_dim)
        e_src = (feat_src * src_nodes_attention_weight[:, self._out_feats:]).sum(dim=-1, keepdim=True)
        # (N_src, n_heads, hidden_dim), (N_src, n_heads, 1)
        graph.srcdata.update({'ft': feat_src, 'e_src': e_src})
        # (N_dst, n_heads, 1)
        graph.dstdata.update({'e_dst': e_dst})
        # compute edge attention, e_src and e_dst are a_src * Wh_src and a_dst * Wh_dst respectively.
        graph.apply_edges(fn.u_add_v('e_src', 'e_dst', 'e'))
        # shape (edges_num, heads, 1)
        e = self.leaky_relu(graph.edata.pop('e'))

        # compute softmax
        graph.edata['a'] = edge_softmax(graph, e)

        graph.update_all(fn.u_mul_e('ft', 'a', 'msg'), fn.sum('msg', 'ft'))
        # (N_dst, n_heads * hidden_dim),   (N_dst, n_heads, hidden_dim) reshape
        dst_features = graph.dstdata.pop('ft').reshape(-1, self._num_heads * self._out_feats)

        dst_features = F.relu(dst_features)

        return dst_features
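edge_softmax also works on a heterogeneous graph as long as a single edge type is involved, which is how Example #7 applies it per relation. A sketch on a made-up single-relation bipartite graph; the node types, shapes, and the [a_dst || a_src] weight layout below are invented to mirror the split used above.

import torch
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn.functional import edge_softmax

# Single-relation bipartite graph: 3 'user' nodes -> 2 'item' nodes.
graph = dgl.heterograph({
    ('user', 'clicks', 'item'): (torch.tensor([0, 1, 2, 2]),
                                 torch.tensor([0, 0, 1, 0]))})

num_heads, out_feats = 2, 4
feat_src = torch.randn(3, num_heads, out_feats)  # projected 'user' features
feat_dst = torch.randn(2, num_heads, out_feats)  # projected 'item' features
attn = torch.randn(num_heads, 2 * out_feats)     # [a_dst || a_src] layout

e_dst = (feat_dst * attn[:, :out_feats]).sum(dim=-1, keepdim=True)
e_src = (feat_src * attn[:, out_feats:]).sum(dim=-1, keepdim=True)

graph.srcdata.update({'ft': feat_src, 'e_src': e_src})
graph.dstdata.update({'e_dst': e_dst})
graph.apply_edges(fn.u_add_v('e_src', 'e_dst', 'e'))
e = F.leaky_relu(graph.edata.pop('e'))

# Normalized over the in-edges of every destination ('item') node.
graph.edata['a'] = edge_softmax(graph, e)
graph.update_all(fn.u_mul_e('ft', 'a', 'msg'), fn.sum('msg', 'ft'))
print(graph.dstdata['ft'].shape)  # torch.Size([2, 2, 4])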
Example #8
    def forward(self, v, k: Dict = None, q: Dict = None, G=None, **kwargs):
        """Forward pass of the linear layer

        Args:
            G: minibatch of (h**o)graphs
            v: dict of value edge-features
            k: dict of key edge-features
            q: dict of query node-features
        Returns: 
            tensor with new features [B, n_points, n_features_out]
        """
        with G.local_scope():
            # Add node features to local graph scope
            ## We use the stacked tensor representation for attention
            for m, d in self.f_value.structure:
                G.edata[f'v{d}'] = v[f'{d}'].view(-1, self.n_heads,
                                                  m // self.n_heads, 2 * d + 1)
            G.edata['k'] = fiber2head(k,
                                      self.n_heads,
                                      self.f_key,
                                      squeeze=True)
            G.ndata['q'] = fiber2head(q,
                                      self.n_heads,
                                      self.f_key,
                                      squeeze=True)

            # Compute attention weights
            ## Inner product between (key) neighborhood and (query) center
            G.apply_edges(fn.e_dot_v('k', 'q', 'e'))

            ## Apply softmax
            e = G.edata.pop('e')
            if self.new_dgl:
                # In DGL 5.3, e has an extra dimension compared to DGL 4.3;
                # in the following, we get rid of it by reshaping.
                n_edges = G.edata['k'].shape[0]
                e = e.view([n_edges, self.n_heads])
            e = e / np.sqrt(self.f_key.n_features)
            G.edata['a'] = edge_softmax(G, e)

            # Perform attention-weighted message-passing
            for d in self.f_value.degrees:
                G.update_all(self.udf_u_mul_e(d), fn.sum('m', f'out{d}'))

            output = {}
            for m, d in self.f_value.structure:
                output[f'{d}'] = G.ndata[f'out{d}'].view(-1, m, 2 * d + 1)

            return output
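In Example #8 the keys and values live on edges while the queries live on nodes, so the score is computed with fn.e_dot_v. A stripped-down sketch of that edge-key variant with invented shapes; the per-degree fiber handling of the SE(3) layer is omitted, and the value/attention product is done directly in edata because there is no built-in message function that multiplies two edge features.

import torch
import dgl
import dgl.function as fn
from dgl.nn.functional import edge_softmax

n_heads, d_k = 2, 4
g = dgl.graph(([0, 1, 2, 2], [1, 2, 0, 1]))

# Keys and values live on edges, queries on nodes.
g.edata['k'] = torch.randn(g.num_edges(), n_heads, d_k)
g.edata['v'] = torch.randn(g.num_edges(), n_heads, d_k)
g.ndata['q'] = torch.randn(g.num_nodes(), n_heads, d_k)

# Inner product between each edge's key and its destination's query.
g.apply_edges(fn.e_dot_v('k', 'q', 'e'))  # shape (E, n_heads, 1)
g.edata['a'] = edge_softmax(g, g.edata.pop('e') / d_k ** 0.5)

# Weight the edge values by their attention and sum into each destination.
g.edata['m'] = g.edata['v'] * g.edata['a']
g.update_all(fn.copy_e('m', 'msg'), fn.sum('msg', 'out'))
print(g.ndata['out'].shape)  # torch.Size([N, n_heads, d_k])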
Example #9
    def forward(self, g, feature):
        g = g.local_var()
        g.ndata['v'] = self.V(feature).view(-1, self._num_heads,
                                            self._out_feats)
        g.ndata['q'] = self.Q(feature).view(-1, self._num_heads,
                                            self._out_feats)
        g.ndata['k'] = self.K(feature).view(-1, self._num_heads,
                                            self._out_feats)

        g.apply_edges(fn.u_mul_v('q', 'k', 'u'))
        # shape: (E, heads, 1)
        u = g.edata['u'].sum(-1, keepdim=True) * (self._out_feats)**(-0.5)
        a = edge_softmax(g, u)
        g.edata['a'] = a
        g.update_all(fn.u_mul_e('v', 'a', 'm'), fn.sum('m', 'ft'))
        # shape: (N, heads * out_feats)
        rst = self.scale(g.ndata['ft'].view(-1,
                                            self._out_feats * self._num_heads))
        if self.res_fc is not None:
            rst = rst.view(-1, self._num_heads,
                           self._in_feats).sum(dim=1) + feature
        return rst
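A small variation on Example #9: the element-wise fn.u_mul_v followed by a sum over the feature dimension can be fused into the built-in fn.u_dot_v, which already appears in Example #10. A sketch with invented graph and shapes:

import torch
import dgl
import dgl.function as fn
from dgl.nn.functional import edge_softmax

num_heads, out_feats = 2, 8
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
for name in ('q', 'k', 'v'):
    g.ndata[name] = torch.randn(g.num_nodes(), num_heads, out_feats)

# fn.u_dot_v fuses the element-wise product and the sum over the feature
# dimension into a single built-in message function.
g.apply_edges(fn.u_dot_v('q', 'k', 'u'))   # shape (E, num_heads, 1)
g.edata['a'] = edge_softmax(g, g.edata['u'] * out_feats ** -0.5)
g.update_all(fn.u_mul_e('v', 'a', 'm'), fn.sum('m', 'ft'))
print(g.ndata['ft'].shape)                 # torch.Size([4, 2, 8])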
Example #10
    def graph_nn(self, g, h, ctx, c, graph_membership):

        g = g.local_var()
        c_broadcast = F.embedding(graph_membership, c)
        fuse = self.W4(self.read_drop(h)) * self.W5(self.read_drop(ctx))
        cat = th.cat([h, ctx, fuse], dim=1)

        src_ctx = self.W7(cat) * self.W8(c_broadcast)
        dst_ctx = self.W6(cat)

        g.srcdata.update({"s_e": src_ctx})
        g.dstdata.update({"d_e": dst_ctx})
        g.apply_edges(fn.u_dot_v("s_e", "d_e", "e"))
        e = g.edata.pop('e')

        g.edata['a'] = edge_softmax(g, e)
        g.ndata['ft'] = self.W9(cat) * self.W10(c_broadcast)

        g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 's'))
        ctx = self.W11(ctx) + self.W11b(g.ndata['s'])

        rst = ctx

        return rst
Example #11
    def forward(self, graph, feat):
        r"""Compute graph attention network layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
            is the number of heads, and :math:`D_{out}` is size of output feature.
        """
        with graph.local_scope():
            if isinstance(feat, tuple):
                assert False
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                feat_src = self.fc_src(h_src).view(-1, self._num_heads,
                                                   self._out_feats)
                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads,
                                                   self._out_feats)
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src).view(
                    -1, self._num_heads, self._out_feats)

            # NOTE: GAT paper uses "first concatenation then linear projection"
            # to compute attention scores, while ours is "first projection then
            # addition", the two approaches are mathematically equivalent:
            # We decompose the weight vector a mentioned in the paper into
            # [a_l || a_r], then
            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is more efficient because we do not need to
            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
            # addition could be optimized with DGL's built-in function u_add_v,
            # which further speeds up computation and reduces the memory footprint.

            feat_src = feat_dst = self.apply_allgather(feat_src)

            el = (feat_src * self.attn_l).sum(dim=-1, keepdim=True)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)

            graph.srcdata.update({'ft': feat_src, 'el': el})
            graph.dstdata.update({'er': er})
            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(graph.edata.pop('e'))
            # compute softmax
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
            # message passing
            graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']
            rst = rst[:self._local_n_nodes, ...]
            # residual
            if self.res_fc is not None:
                if not self._apply_gather:
                    h_dst = h_dst[:self._local_n_nodes, ...]
                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1,
                                                 self._out_feats)
                rst += resval
            # activation
            if self.activation:
                rst = self.activation(rst)
            return rst
Example #12
 def node_update(self, graph, e_feat, n_feat):
     graph.edata['a'] = edge_softmax(graph, e_feat)
     graph.update_all(fn.u_mul_e('W_3h', 'a', 'm'), fn.sum('m', 'n_tmp'))
     n_ = self.relu(
         self.node_batch_norm(graph.ndata['n_tmp'] + graph.ndata['W_2h']))
     return n_ + n_feat
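Across all of these examples the contract of edge_softmax is the same: given one logit (or one logit per head) per edge, it returns weights that sum to 1 over the incoming edges of every destination node. A quick sanity check on a made-up graph:

import torch
import dgl
import dgl.function as fn
from dgl.nn.functional import edge_softmax

g = dgl.graph(([0, 1, 2, 3, 3], [1, 1, 3, 0, 1]))
logits = torch.randn(g.num_edges(), 4, 1)  # e.g. one score per head

g.edata['a'] = edge_softmax(g, logits)

# Per destination node, the weights over its in-edges sum to 1;
# nodes without in-edges (node 2 here) simply receive a zero sum.
g.update_all(fn.copy_e('a', 'm'), fn.sum('m', 's'))
print(g.ndata['s'].squeeze(-1))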