Ejemplo n.º 1
0
    def forward(self, g, feat):
        """Compute the attention-weighted representation of destination nodes.

        :param g: DGLGraph, a bipartite graph (containing a single relation)
        :param feat: tensor(N_src, d_in), or a pair
            (tensor(N_src, d_in), tensor(N_dst, d_in)), of input features
        :return: tensor(N_dst, d_out), representation of the destination
            nodes with respect to this relation
        """
        with g.local_scope():
            src_feat, dst_feat = expand_as_pair(feat, g)
            heads, head_dim = self.num_heads, self.d_k

            # Project, then split into heads:
            # (N, d_in) -> (N, d_out) -> (N, K, d_out/K)
            keys = self.k_linear(src_feat).view(-1, heads, head_dim)
            values = self.v_linear(src_feat).view(-1, heads, head_dim)
            queries = self.q_linear(dst_feat).view(-1, heads, head_dim)

            # Per-head linear maps:
            # out[n, h, j] = sum_i in[n, h, i] * w[h, i, j]
            keys = torch.einsum('nhi,hij->nhj', keys, self.w_att)
            values = torch.einsum('nhi,hij->nhj', values, self.w_msg)

            g.srcdata['k'] = keys
            g.srcdata['v'] = values
            g.dstdata['q'] = queries

            # Per-edge dot product of destination query and source key;
            # g.edata['t'] is (E, K, 1) before the squeeze.
            g.apply_edges(fn.v_dot_u('q', 'k', 't'))
            scores = g.edata.pop('t').squeeze(dim=-1)
            scores = scores * self.mu / math.sqrt(head_dim)

            weights = edge_softmax(g, scores)  # (E, K)
            # Keep a gradient-free copy of the latest attention weights on
            # the module.
            self.attn = weights.detach()
            g.edata['t'] = weights.unsqueeze(dim=-1)  # (E, K, 1)

            # Weight source values by the attention and sum per destination.
            g.update_all(fn.u_mul_e('v', 't', 'm'), fn.sum('m', 'h'))
            return g.dstdata['h'].view(-1, self.out_dim)  # (N_dst, d_out)
Ejemplo n.º 2
0
    def forward(self, G, h):
        """Apply one heterogeneous attention layer to every node type of ``G``.

        For each canonical edge type, source features are projected to keys
        and values and destination features to queries; attention weights
        are computed per edge and stored on the edges as ``'t'``. Messages
        are then aggregated across edge types (mean cross-reducer), and each
        node type's result is blended with its input through a sigmoid gate.

        :param G: heterogeneous DGLGraph
        :param h: mapping from node type name to that type's input features
        :return: dict mapping node type name to its output features
        """
        with G.local_scope():
            node_dict, edge_dict = self.node_dict, self.edge_dict

            for srctype, etype, dsttype in G.canonical_etypes:
                rel_graph = G[srctype, etype, dsttype]
                src_id, dst_id = node_dict[srctype], node_dict[dsttype]
                e_id = edge_dict[etype]

                # Type-specific projections, reshaped to (N, n_heads, d_k).
                keys = self.k_linears[src_id](h[srctype]).view(
                    -1, self.n_heads, self.d_k)
                values = self.v_linears[src_id](h[srctype]).view(
                    -1, self.n_heads, self.d_k)
                queries = self.q_linears[dst_id](h[dsttype]).view(
                    -1, self.n_heads, self.d_k)

                # Relation-specific per-head transforms of keys and values.
                keys = torch.einsum("bij,ijk->bik", keys,
                                    self.relation_att[e_id])
                values = torch.einsum("bij,ijk->bik", values,
                                      self.relation_msg[e_id])

                rel_graph.srcdata['k'] = keys
                rel_graph.dstdata['q'] = queries
                # Values get a per-relation field name so that different
                # relations sharing a source node type do not overwrite
                # each other before the joint update below.
                rel_graph.srcdata['v_%d' % e_id] = values

                rel_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))
                raw_score = rel_graph.edata.pop('t').sum(-1)
                raw_score = raw_score * self.relation_pri[e_id] / self.sqrt_dk
                attn_score = edge_softmax(rel_graph, raw_score, norm_by='dst')

                rel_graph.edata['t'] = attn_score.unsqueeze(-1)

            # One (message, reduce) pair per edge type known to the layer;
            # results for the same destination type are averaged.
            update_funcs = {}
            for etype, e_id in edge_dict.items():
                update_funcs[etype] = (fn.u_mul_e('v_%d' % e_id, 't', 'm'),
                                       fn.sum('m', 't'))
            G.multi_update_all(update_funcs, cross_reducer='mean')

            # Step 3: target-specific aggregation.
            # out = norm(alpha * drop(W[ntype] * agg) + (1 - alpha) * x),
            # where alpha is the sigmoid of a per-type skip value and the
            # norm is applied only when ``use_norm`` is set.
            new_h = {}
            for ntype in G.ntypes:
                n_id = node_dict[ntype]
                alpha = torch.sigmoid(self.skip[n_id])
                aggregated = G.nodes[ntype].data['t'].view(-1, self.out_dim)
                trans_out = self.drop(self.a_linears[n_id](aggregated))
                trans_out = trans_out * alpha + h[ntype] * (1 - alpha)
                new_h[ntype] = (self.norms[n_id](trans_out)
                                if self.use_norm else trans_out)
            return new_h
Ejemplo n.º 3
0
    def forward(self, g, feats):
        """Apply one heterogeneous attention layer to all node types.

        :param g: DGLGraph, a heterogeneous graph
        :param feats: Dict[str, tensor(N_i, d_in)], mapping from node type
            to input node features
        :return: Dict[str, tensor(N_i, d_out)], mapping from node type to
            output features
        """
        with g.local_scope():
            heads, head_dim = self.num_heads, self.d_k

            for stype, etype, dtype in g.canonical_etypes:
                rel_graph = g[stype, etype, dtype]
                src_feat = feats[stype]
                dst_feat = feats[dtype]

                # (N_i, d_in) -> (N_i, d_out) -> (N_i, K, d_out/K)
                keys = self.k_linears[stype](src_feat).view(
                    -1, heads, head_dim)
                values = self.v_linears[stype](src_feat).view(
                    -1, heads, head_dim)
                queries = self.q_linears[dtype](dst_feat).view(
                    -1, heads, head_dim)

                # Per-head, per-relation linear maps:
                # out[n, h, j] = sum_i in[n, h, i] * w[h, i, j]
                keys = torch.einsum('nhi,hij->nhj', keys, self.w_att[etype])
                values = torch.einsum('nhi,hij->nhj', values,
                                      self.w_msg[etype])

                # Values are stored under a relation-specific name so that
                # relations sharing a source type keep separate messages.
                rel_graph.srcdata['k'] = keys
                rel_graph.srcdata[f'v_{etype}'] = values
                rel_graph.dstdata['q'] = queries

                # Step 1: heterogeneous mutual attention.
                # rel_graph.edata['t'] is (E, K, 1) before the squeeze.
                rel_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))
                scores = rel_graph.edata.pop('t').squeeze(dim=-1)
                scores = scores * self.mu[etype] / math.sqrt(head_dim)
                weights = edge_softmax(rel_graph, scores)  # (E, K)
                rel_graph.edata['t'] = weights.unsqueeze(dim=-1)

            # Step 2: heterogeneous message passing + target-specific
            # aggregation (results from different relations are averaged).
            update_funcs = {}
            for etype in g.etypes:
                update_funcs[etype] = (fn.u_mul_e(f'v_{etype}', 't', 'm'),
                                       fn.sum('m', 'h'))
            g.multi_update_all(update_funcs, 'mean')

            # Step 3: residual connection gated by a per-type sigmoid.
            out_feats = {}
            for ntype in g.ntypes:
                alpha = torch.sigmoid(self.skip[ntype])
                aggregated = g.nodes[ntype].data['h'].view(-1, self.out_dim)
                trans_out = self.drop(self.a_linears[ntype](aggregated))
                out = alpha * trans_out + (1 - alpha) * feats[ntype]
                if self.use_norm:
                    out = self.norms[ntype](out)
                out_feats[ntype] = out
            return out_feats