from typing import Dict

import numpy as np
import dgl.function as fn
from dgl.nn.functional import edge_softmax  # dgl.nn.pytorch.softmax in DGL < 0.5


def forward(self, v, k: Dict = None, q: Dict = None, G=None, **kwargs):
    """Forward pass of the linear layer

    Args:
        G: minibatch of (homo)graphs
        v: dict of value edge-features
        k: dict of key edge-features
        q: dict of query node-features
    Returns:
        dict of new node-features, keyed by degree d, each of shape
        [n_nodes, n_channels, 2 * d + 1]
    """
    with G.local_scope():
        # Add node features to local graph scope
        ## We use the stacked tensor representation for attention
        for m, d in self.f_value.structure:
            G.edata[f'v{d}'] = v[f'{d}'].view(-1, self.n_heads, m // self.n_heads, 2 * d + 1)
        G.edata['k'] = fiber2head(k, self.n_heads, self.f_key, squeeze=True)
        G.ndata['q'] = fiber2head(q, self.n_heads, self.f_key, squeeze=True)

        # Compute attention weights
        ## Inner product between (key) neighborhood and (query) center
        G.apply_edges(fn.e_dot_v('k', 'q', 'e'))

        ## Apply scaled softmax over each node's incoming edges
        e = G.edata.pop('e')
        if self.new_dgl:
            # In DGL 0.5.3, e has an extra dimension compared to DGL 0.4.3;
            # in the following, we get rid of it by reshaping.
            n_edges = G.edata['k'].shape[0]
            e = e.view([n_edges, self.n_heads])
        e = e / np.sqrt(self.f_key.n_features)
        G.edata['a'] = edge_softmax(G, e)

        # Perform attention-weighted message-passing
        for d in self.f_value.degrees:
            G.update_all(self.udf_u_mul_e(d), fn.sum('m', f'out{d}'))

        output = {}
        for m, d in self.f_value.structure:
            output[f'{d}'] = G.ndata[f'out{d}'].view(-1, m, 2 * d + 1)

        return output
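
# The message-passing loop above calls self.udf_u_mul_e(d), which is not shown
# here. Below is a minimal sketch of such a user-defined message function,
# assuming attention weights 'a' of shape [n_edges, n_heads] and values 'v{d}'
# of shape [n_edges, n_heads, n_channels, 2 * d + 1] (shapes inferred from the
# forward pass above; the body is an illustrative assumption, not necessarily
# the repository's implementation):
def udf_u_mul_e(self, d_out):
    """Return a message function that attention-weights degree-d_out values."""
    def fnc(edges):
        attn = edges.data['a']           # [n_edges, n_heads]
        value = edges.data[f'v{d_out}']  # [n_edges, n_heads, n_channels, 2d+1]
        # Broadcast the per-head attention weight over the channel and
        # spherical-harmonic dimensions of the value features.
        msg = attn.unsqueeze(-1).unsqueeze(-1) * value
        return {'m': msg}
    return fnc
# fn.sum('m', f'out{d}') then sums these messages per destination node, and the
# final .view(-1, m, 2 * d + 1) merges the head and channel dimensions.
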
import torch as th
import dgl.function as fn
from dgl.nn.functional import edge_softmax


def forward(self, g, f_feat, b_feat, u_feat, v_feat):
    """Forward pass on a bipartite heterograph (or block) with node types
    'u' and 'v', and edge types 'forward' (u -> v) and 'backward' (v -> u).

    Args:
        g: DGL heterograph or block
        f_feat: features of the 'forward' edges
        b_feat: features of the 'backward' edges
        u_feat: features of the 'u' source nodes
        v_feat: features of the 'v' source nodes
    Returns:
        updated forward-edge, backward-edge, u-node, and v-node features
    """
    g.srcnodes['u'].data['h'] = u_feat
    g.srcnodes['v'].data['h'] = v_feat
    g.dstnodes['u'].data['h'] = u_feat[:g.number_of_dst_nodes(ntype='u')]
    g.dstnodes['v'].data['h'] = v_feat[:g.number_of_dst_nodes(ntype='v')]
    g.edges['forward'].data['h'] = f_feat
    g.edges['backward'].data['h'] = b_feat

    # formula 3 and 4 (optimized implementation to save memory)
    g.srcnodes['u'].data.update(
        {'he_u': self.u_linear(g.srcnodes['u'].data['h'])})
    g.srcnodes['v'].data.update(
        {'he_v': self.v_linear(g.srcnodes['v'].data['h'])})
    g.dstnodes['u'].data.update(
        {'he_u': self.u_linear(g.dstnodes['u'].data['h'])})
    g.dstnodes['v'].data.update(
        {'he_v': self.v_linear(g.dstnodes['v'].data['h'])})
    g.edges['forward'].data.update({'he_e': self.e_linear(f_feat)})
    g.edges['backward'].data.update({'he_e': self.e_linear(b_feat)})
    g.apply_edges(
        lambda edges: {'he': edges.data['he_e'] + edges.dst['he_u'] + edges.src['he_v']},
        etype='backward')
    g.apply_edges(
        lambda edges: {'he': edges.data['he_e'] + edges.src['he_u'] + edges.dst['he_v']},
        etype='forward')
    hf = g.edges['forward'].data['he']
    hb = g.edges['backward'].data['he']
    if self.activation is not None:
        hf = self.activation(hf)
        hb = self.activation(hb)

    # formula 6: concatenate source-node and edge features
    g.apply_edges(
        lambda edges: {'h_ve': th.cat([edges.src['h'], edges.data['h']], -1)},
        etype='backward')
    g.apply_edges(
        lambda edges: {'h_ue': th.cat([edges.src['h'], edges.data['h']], -1)},
        etype='forward')

    # formula 7, self-attention
    g.srcnodes['u'].data['h_att_u'] = self.W_ATTN_u(g.srcnodes['u'].data['h'])
    g.srcnodes['v'].data['h_att_v'] = self.W_ATTN_v(g.srcnodes['v'].data['h'])
    g.dstnodes['u'].data['h_att_u'] = self.W_ATTN_u(g.dstnodes['u'].data['h'])
    g.dstnodes['v'].data['h_att_v'] = self.W_ATTN_v(g.dstnodes['v'].data['h'])

    # Step 1: dot product between edge message and destination-node query
    g.apply_edges(fn.e_dot_v('h_ve', 'h_att_u', 'edotv'), etype='backward')
    g.apply_edges(fn.e_dot_v('h_ue', 'h_att_v', 'edotv'), etype='forward')

    # Step 2: softmax over each destination node's incoming edges
    g.edges['backward'].data['sfm'] = edge_softmax(
        g['backward'], g.edges['backward'].data['edotv'])
    g.edges['forward'].data['sfm'] = edge_softmax(
        g['forward'], g.edges['forward'].data['edotv'])

    # Step 3: broadcast the softmax value to each edge, completing the
    # attention weighting
    g.apply_edges(
        lambda edges: {'attn': edges.data['h_ve'] * edges.data['sfm']},
        etype='backward')
    g.apply_edges(
        lambda edges: {'attn': edges.data['h_ue'] * edges.data['sfm']},
        etype='forward')

    # Step 4: aggregate attention-weighted messages to the dst nodes,
    # completing formula 7
    g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_u'), etype='backward')
    g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_v'), etype='forward')

    # formula 5
    h_nu = self.W_u(g.dstnodes['u'].data['agg_u'])
    h_nv = self.W_v(g.dstnodes['v'].data['agg_v'])
    if self.activation is not None:
        h_nu = self.activation(h_nu)
        h_nv = self.activation(h_nv)

    # Dropout
    hf = self.dropout(hf)
    hb = self.dropout(hb)
    h_nu = self.dropout(h_nu)
    h_nv = self.dropout(h_nv)

    # formula 8: concatenate the nodes' self-transform with the aggregated
    # neighborhood features
    hu = th.cat([self.Vu(g.dstnodes['u'].data['h']), h_nu], -1)
    hv = th.cat([self.Vv(g.dstnodes['v'].data['h']), h_nv], -1)

    return hf, hb, hu, hv
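
# A minimal usage sketch of the graph layout this forward() expects, assuming
# edge types ('u', 'forward', 'v') and ('v', 'backward', 'u') that are reverses
# of each other (inferred from the forward pass above; `layer` and the feature
# sizes are hypothetical):
import dgl
import torch as th

g = dgl.heterograph({
    ('u', 'forward', 'v'): (th.tensor([0, 1, 1]), th.tensor([0, 0, 1])),
    ('v', 'backward', 'u'): (th.tensor([0, 0, 1]), th.tensor([0, 1, 1])),
})
u_feat = th.randn(g.num_nodes('u'), 16)       # one row per 'u' node
v_feat = th.randn(g.num_nodes('v'), 16)       # one row per 'v' node
f_feat = th.randn(g.num_edges('forward'), 8)  # one row per 'forward' edge
b_feat = th.randn(g.num_edges('backward'), 8) # one row per 'backward' edge
# hf, hb, hu, hv = layer(g, f_feat, b_feat, u_feat, v_feat)
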