def forward(self, g, feat):
    """
    :param g: DGLGraph, a bipartite graph containing a single relation
    :param feat: tensor(N_src, d_in) or (tensor(N_src, d_in), tensor(N_dst, d_in)), input features
    :return: tensor(N_dst, d_out), representations of the destination nodes w.r.t. this relation
    """
    with g.local_scope():
        feat_src, feat_dst = expand_as_pair(feat, g)
        # (N_src, d_in) -> (N_src, d_out) -> (N_src, K, d_out/K)
        k = self.k_linear(feat_src).view(-1, self.num_heads, self.d_k)
        v = self.v_linear(feat_src).view(-1, self.num_heads, self.d_k)
        q = self.q_linear(feat_dst).view(-1, self.num_heads, self.d_k)

        # k[:, h] @= w_att[h] => k[n, h, j] = ∑(i) k[n, h, i] * w_att[h, i, j]
        k = torch.einsum('nhi,hij->nhj', k, self.w_att)
        v = torch.einsum('nhi,hij->nhj', v, self.w_msg)

        g.srcdata.update({'k': k, 'v': v})
        g.dstdata['q'] = q
        g.apply_edges(fn.v_dot_u('q', 'k', 't'))  # g.edata['t']: (E, K, 1)
        attn = g.edata.pop('t').squeeze(dim=-1) * self.mu / math.sqrt(self.d_k)
        attn = edge_softmax(g, attn)  # (E, K)
        self.attn = attn.detach()
        g.edata['t'] = attn.unsqueeze(dim=-1)  # (E, K, 1)

        g.update_all(fn.u_mul_e('v', 't', 'm'), fn.sum('m', 'h'))
        out = g.dstdata['h'].view(-1, self.out_dim)  # (N_dst, d_out)
        return out
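The relation-specific transforms above are a batched per-head einsum. The following standalone sketch (shapes and tensors are made up for illustration, not the layer's real parameters) confirms that torch.einsum('nhi,hij->nhj', ...) multiplies each head's vectors by that head's own matrix:

import torch

N, K, d_k = 5, 4, 8                      # nodes, heads, per-head dim (illustrative sizes)
k = torch.randn(N, K, d_k)               # per-head key vectors
w_att = torch.randn(K, d_k, d_k)         # one d_k x d_k matrix per head

out_einsum = torch.einsum('nhi,hij->nhj', k, w_att)
# Equivalent explicit loop: transform head h's keys with head h's matrix
out_loop = torch.stack([k[:, h] @ w_att[h] for h in range(K)], dim=1)
assert torch.allclose(out_einsum, out_loop, atol=1e-6)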
def forward(self, G, h):
    with G.local_scope():
        node_dict, edge_dict = self.node_dict, self.edge_dict
        for srctype, etype, dsttype in G.canonical_etypes:
            sub_graph = G[srctype, etype, dsttype]

            k_linear = self.k_linears[node_dict[srctype]]
            v_linear = self.v_linears[node_dict[srctype]]
            q_linear = self.q_linears[node_dict[dsttype]]

            k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
            v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
            q = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)

            e_id = self.edge_dict[etype]
            relation_att = self.relation_att[e_id]
            relation_pri = self.relation_pri[e_id]
            relation_msg = self.relation_msg[e_id]

            k = torch.einsum("bij,ijk->bik", k, relation_att)
            v = torch.einsum("bij,ijk->bik", v, relation_msg)

            sub_graph.srcdata['k'] = k
            sub_graph.dstdata['q'] = q
            sub_graph.srcdata['v_%d' % e_id] = v

            sub_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))

            attn_score = sub_graph.edata.pop('t').sum(-1) * relation_pri / self.sqrt_dk
            attn_score = edge_softmax(sub_graph, attn_score, norm_by='dst')

            sub_graph.edata['t'] = attn_score.unsqueeze(-1)

        G.multi_update_all({etype: (fn.u_mul_e('v_%d' % e_id, 't', 'm'),
                                    fn.sum('m', 't'))
                            for etype, e_id in edge_dict.items()},
                           cross_reducer='mean')

        new_h = {}
        for ntype in G.ntypes:
            '''
                Step 3: Target-specific Aggregation
                x = norm( W[node_type] * gelu( Agg(x) ) + x )
            '''
            n_id = node_dict[ntype]
            alpha = torch.sigmoid(self.skip[n_id])
            t = G.nodes[ntype].data['t'].view(-1, self.out_dim)
            trans_out = self.drop(self.a_linears[n_id](t))
            trans_out = trans_out * alpha + h[ntype] * (1 - alpha)
            if self.use_norm:
                new_h[ntype] = self.norms[n_id](trans_out)
            else:
                new_h[ntype] = trans_out
        return new_h
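Both implementations hinge on edge_softmax, which normalizes the raw dot-product scores over the incoming edges of each destination node (norm_by='dst' is the default). A toy sketch with made-up scores:

import dgl
import torch
from dgl.nn.functional import edge_softmax

# 3 edges: 0->2, 1->2, 1->3; node 2 has two in-edges, node 3 has one
g = dgl.graph(([0, 1, 1], [2, 2, 3]))
scores = torch.tensor([[1.0], [3.0], [0.5]])  # one raw score per edge (single head)

attn = edge_softmax(g, scores)
# The two edges into node 2 sum to 1 (~0.12 and ~0.88); the lone edge
# into node 3 gets weight 1 regardless of its raw score
print(attn)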
def forward(self, g, feats):
    """
    :param g: DGLGraph, a heterogeneous graph
    :param feats: Dict[str, tensor(N_i, d_in)], mapping from node type to input node features
    :return: Dict[str, tensor(N_i, d_out)], mapping from node type to output node features
    """
    with g.local_scope():
        for stype, etype, dtype in g.canonical_etypes:
            sg = g[stype, etype, dtype]
            feat_src, feat_dst = feats[stype], feats[dtype]
            # (N_i, d_in) -> (N_i, d_out) -> (N_i, K, d_out/K)
            k = self.k_linears[stype](feat_src).view(-1, self.num_heads, self.d_k)
            v = self.v_linears[stype](feat_src).view(-1, self.num_heads, self.d_k)
            q = self.q_linears[dtype](feat_dst).view(-1, self.num_heads, self.d_k)

            # k[:, h] @= w_att[h] => k[n, h, j] = ∑(i) k[n, h, i] * w_att[h, i, j]
            k = torch.einsum('nhi,hij->nhj', k, self.w_att[etype])
            v = torch.einsum('nhi,hij->nhj', v, self.w_msg[etype])

            sg.srcdata.update({'k': k, f'v_{etype}': v})
            sg.dstdata['q'] = q

            # Step 1: heterogeneous mutual attention
            sg.apply_edges(fn.v_dot_u('q', 'k', 't'))  # sg.edata['t']: (E, K, 1)
            attn = sg.edata.pop('t').squeeze(dim=-1) * self.mu[etype] / math.sqrt(self.d_k)
            attn = edge_softmax(sg, attn)  # (E, K)
            sg.edata['t'] = attn.unsqueeze(dim=-1)

        # Step 2: heterogeneous message passing + target-specific aggregation
        g.multi_update_all(
            {etype: (fn.u_mul_e(f'v_{etype}', 't', 'm'), fn.sum('m', 'h')) for etype in g.etypes},
            'mean')

        # Step 3: residual connection
        out_feats = {}
        for ntype in g.ntypes:
            alpha = torch.sigmoid(self.skip[ntype])
            h = g.nodes[ntype].data['h'].view(-1, self.out_dim)
            trans_out = self.drop(self.a_linears[ntype](h))
            out = alpha * trans_out + (1 - alpha) * feats[ntype]
            out_feats[ntype] = self.norms[ntype](out) if self.use_norm else out
        return out_feats
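Step 2 relies on multi_update_all with the 'mean' cross reducer: messages are first summed within each relation, and a destination node reached through several relations then averages the per-relation results. A toy heterograph (node/edge types and feature values are invented for illustration):

import dgl
import dgl.function as fn
import torch

g = dgl.heterograph({
    ('user', 'follows', 'user'): ([0], [1]),
    ('topic', 'liked-by', 'user'): ([0], [1]),
})
g.nodes['user'].data['x'] = torch.tensor([[2.0], [0.0]])
g.nodes['topic'].data['x'] = torch.tensor([[4.0]])

g.multi_update_all(
    {etype: (fn.copy_u('x', 'm'), fn.sum('m', 'h')) for etype in g.etypes},
    'mean')
# user 1 receives 2.0 via 'follows' and 4.0 via 'liked-by' -> mean is 3.0;
# user 0 receives no messages and gets the zero default of the sum reducer
print(g.nodes['user'].data['h'])  # tensor([[0.], [3.]])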