def forward(self, graph, inputs):
    # `dfn` is assumed to be DGL's built-in functions, i.e. `import dgl.function as dfn`
    msg_passing_fns = {}
    for stype, etype, dtype in graph.canonical_etypes:
        Wh = self.linears[stype](inputs[stype])
        graph.nodes[stype].data["Wh_%s" % etype] = Wh
        msg_fn = dfn.copy_u("Wh_%s" % etype, "m")
        reduce_fn = dfn.mean("m", "h")
        msg_passing_fns[etype] = (msg_fn, reduce_fn)
    graph.multi_update_all(msg_passing_fns, "sum")
    return graph.ndata["h"]

def call(self, graph, feat, edge_feat):
    graph = graph.local_var()
    graph.srcdata['h'] = feat
    graph.apply_edges(fn.copy_u('h', 's'))
    graph.edata['e'] = self.dense(
        tf.concat([graph.edata.pop('s'), edge_feat], axis=1))
    graph.update_all(fn.copy_e('e', 'm'), fn.sum(msg='m', out='h'))
    rst = self.dense2(graph.dstdata['h'])
    return feat + rst

def forward(self, g, feature):
    g.ndata['h'] = feature
    # g.update_all(msg, reduce): collect features from source nodes
    # and aggregate them in destination nodes
    g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
    # alternative: multiply source node features with edge weights and
    # aggregate them in destination nodes
    # g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'h_max'))
    g.apply_nodes(func=self.apply_mod)
    h = g.ndata.pop('h')
    return h

def forward(self, graph, feat_dict):
    funcs = {}
    for srctype, etype, dsttype in graph.canonical_etypes:
        Wh = self.weight[etype](feat_dict[srctype])
        graph.nodes[srctype].data['Wh_%s' % etype] = Wh
        funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
    graph.multi_update_all(funcs, 'sum')
    return {ntype: graph.nodes[ntype].data['h'] for ntype in graph.ntypes}

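# A minimal usage sketch for the relation-wise forward above, assuming the
# surrounding module follows the HeteroRGCN pattern; the class name
# `HeteroRGCNLayer` and the feature sizes are hypothetical.
import dgl
import torch

g = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('user', 'plays', 'game'): ([0, 1, 2], [0, 0, 1]),
})
feat_dict = {'user': torch.randn(3, 8), 'game': torch.randn(2, 8)}
# layer = HeteroRGCNLayer(8, 16)   # hypothetical constructor
# out = layer(g, feat_dict)        # dict: {'user': (3, 16), 'game': (2, 16)}
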
def forward(self, graph, feat):
    with graph.local_scope():
        if not self._allow_zero_in_degree:
            if (graph.in_degrees() == 0).any():
                raise DGLError(
                    'There are 0-in-degree nodes in the graph, '
                    'output for those nodes will be invalid. '
                    'This is harmful for some applications, '
                    'causing silent performance regression. '
                    'Adding self-loop on the input graph by '
                    'calling `g = dgl.add_self_loop(g)` will resolve '
                    'the issue. Setting ``allow_zero_in_degree`` '
                    'to be `True` when constructing this module will '
                    'suppress the check and let the code run.')
        if self._cached_h is not None:
            feat_list = self._cached_h
        else:
            feat_list = []
            # compute symmetric normalization D^{-1/2}
            degs = graph.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5)
            norm = norm.to(feat.device).unsqueeze(1)
            feat_list.append(feat.float())
            for i in range(self._k):
                feat = feat * norm
                feat = feat.float()
                graph.ndata['h'] = feat
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                feat = graph.ndata.pop('h')
                feat = feat * norm
                feat_list.append(feat)
        # combine the k propagated features with Taylor-series weights
        # lambda^i / i!, i.e. a truncated exponential of the normalized adjacency
        result = torch.zeros(
            feat_list[0].shape[0], self.out_feats).to(feat_list[0].device)
        for i, k_feat in enumerate(feat_list):
            result += self.fc(k_feat * (self._lambda.pow(i) / factorial(i)))
        if self.norm is not None:
            result = self.norm(result)
        # cache the propagated features
        if self._cached:
            self._cached_h = feat_list
        return result

def do_copy_reduce():
    funcs = OrderedDict()
    for i, rating in enumerate(dataset.possible_rating_values):
        rating = str(rating)
        graph.nodes['movie'].data['h%d' % i] = allfeat[i]
        # aggregate along the reverse (movie -> user) relation only
        funcs['rev-%s' % rating] = (fn.copy_u('h%d' % i, 'm'), fn.sum('m', 'h'))
    # message passing; 'stack' stacks the per-relation results along a new
    # dimension, hence the reshape below
    graph.multi_update_all(funcs, "stack")
    return graph.nodes['user'].data.pop('h').reshape(num_u, -1)

def forward(self, g, inputs):
    with g.local_scope():
        if isinstance(inputs, tuple) or g.is_block:
            if isinstance(inputs, tuple):
                src_inputs, dst_inputs = inputs
            else:
                src_inputs = inputs
                # dst_inputs is never read below; kept only for parity with
                # the tuple case
                dst_inputs = {
                    k: v[:g.number_of_dst_nodes(k)]
                    for k, v in inputs.items()
                }
            g.srcdata['h'] = src_inputs
            g.update_all(fn.copy_u('h', 'm'), self._lstm_reducer)
            h_neigh = g.dstdata['neigh']
        else:
            g.srcdata['h'] = inputs
            g.update_all(fn.copy_u('h', 'm'), self._lstm_reducer)
            h_neigh = g.dstdata['neigh']
        return h_neigh

def forward(self, g, h, logits, old_z, shared_tau=True, tau_1=None, tau_2=None):
    # operates on a node
    g = g.local_var()
    if self.dropout:
        h = self.dropout(h)
    g.ndata['h'] = h
    g.ndata['logits'] = logits
    g.update_all(message_func=fn.copy_u('logits', 'logits'),
                 reduce_func=adaptive_reduce_func)
    f1 = g.ndata.pop('f1')
    f2 = g.ndata.pop('f2')
    norm_f1 = self.ln_1(f1)
    norm_f2 = self.ln_2(f2)
    if shared_tau:
        z = torch.sigmoid(-(norm_f1 - tau_1)) * torch.sigmoid(-(norm_f2 - tau_2))
    else:
        # a separate tau for each layer
        z = torch.sigmoid(-(norm_f1 - self.tau_1)) * torch.sigmoid(-(norm_f2 - self.tau_2))
    gate = torch.min(old_z, z)
    g.update_all(message_func=fn.copy_u('h', 'feat'),
                 reduce_func=fn.sum(msg='feat', out='agg'))
    agg = g.ndata.pop('agg')
    normagg = agg * g.ndata['norm']  # normalization by dst degree
    if self.activation:
        normagg = self.activation(normagg)
    new_h = h + gate.unsqueeze(1) * normagg
    return new_h, z

def unnLaplacian(feat, D_invsqrt_left, D_invsqrt_right, graph):
    """Compute Feat * D^-1/2 A D^-1/2 (as a matrix product: D^-1/2 A D^-1/2 Feat).

    Note: sparse tensors have no broadcasting, and a matmul-based variant would
    also rely on the source nodes being laid out contiguously from 0 in `feat`,
    so message passing is used instead.
    """
    graph.srcdata['h'] = feat * D_invsqrt_right  # feat is the source-node feature
    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
    return graph.dstdata.pop('h') * D_invsqrt_left

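# A minimal sketch of how the two normalizers could be built, assuming
# symmetric D^-1/2 normalization on a homogeneous graph (the variable names
# mirror the function above, but this setup itself is hypothetical):
import dgl
import torch
import dgl.function as fn

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
D_invsqrt_right = torch.pow(g.out_degrees().float().clamp(min=1), -0.5).unsqueeze(-1)  # src side
D_invsqrt_left = torch.pow(g.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1)    # dst side
feat = torch.randn(g.num_nodes(), 4)
# unnLaplacian(feat, D_invsqrt_left, D_invsqrt_right, g)  # yields D^-1/2 A D^-1/2 feat
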
def forward(self, G, features):
    funcs = {}
    feat = self.activation(self.reduce(features))
    for _ in range(N_TIMESEPS):
        for etype in G.etypes:
            Wh = self.weight[etype](features)
            G.nodes['object'].data['Wh_%s' % etype] = Wh
            funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
        G.multi_update_all(funcs, 'sum')
        feat = self.gru(G.nodes['object'].data['h'], feat)
    return self.activation(feat)

def forward(self, graph, feat, logits):
    r"""Compute APPNP layer.

    Parameters
    ----------
    graph : DGLGraph
        The graph.
    feat : torch.Tensor
        The input feature of shape :math:`(N, *)`, where :math:`N` is the
        number of nodes and :math:`*` could be of any shape.
    logits : torch.Tensor
        Node logits consumed by ``adaptive_reduce_func``.

    Returns
    -------
    torch.Tensor
        The output feature of shape :math:`(N, *)`, where :math:`*` is the
        same as the input shape.
    """
    graph = graph.local_var()
    norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5)
    shp = norm.shape + (1,) * (feat.dim() - 1)
    norm = torch.reshape(norm, shp).to(feat.device)
    feat_0 = feat
    z = torch.FloatTensor([1.0]).cuda()
    for lidx in range(self._k):
        old_z = z
        # normalization by src node
        feat = feat * norm
        graph.ndata['h'] = feat
        old_feat = feat
        if lidx != 0:
            logits = self.weight_y(feat)
        graph.ndata['logits'] = logits
        graph.update_all(message_func=fn.copy_u('logits', 'logits'),
                         reduce_func=adaptive_reduce_func)
        f1 = graph.ndata.pop('f1')
        f2 = graph.ndata.pop('f2')
        norm_f1 = self.ln_1(f1)
        norm_f2 = self.ln_2(f2)
        z = torch.sigmoid(-(norm_f1 - self.tau_1)) * \
            torch.sigmoid(-(norm_f2 - self.tau_2))
        gate = torch.min(old_z, z)
        graph.edata['w'] = self.edge_drop(
            torch.ones(graph.number_of_edges(), 1).to(feat.device))
        graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
        feat = graph.ndata.pop('h')
        # normalization by dst node
        feat = feat * norm
        feat = z.unsqueeze(1) * feat + old_feat  # raw features
    return feat

def propagate_feature(g):
    with g.local_scope():
        g.multi_update_all({
            'write': (fn.copy_u('net_embed', 'm'), fn.mean('m', 'a_net_embed')),
            'publish_rev': (fn.copy_u('net_embed', 'm'), fn.mean('m', 'v_net_embed'))
        }, 'sum')
        paper_feats = torch.stack([
            g.nodes['paper'].data[k]
            for k in ('abstract_embed', 'title_embed', 'net_embed',
                      'v_net_embed', 'a_net_embed')
        ], dim=1)  # (N_p, 5, d)

        ap = find_neighbors(g, 'write_rev', 3).view(1, -1)  # (1, 3N_a)
        ap_abstract_embed = g.nodes['paper'].data['abstract_embed'][ap] \
            .view(g.num_nodes('author'), 3, -1)  # (N_a, 3, d)
        author_feats = torch.cat([
            g.nodes['author'].data['net_embed'].unsqueeze(dim=1),
            ap_abstract_embed
        ], dim=1)  # (N_a, 4, d)

        vp = find_neighbors(g, 'publish', 5).view(1, -1)  # (1, 5N_v)
        vp_abstract_embed = g.nodes['paper'].data['abstract_embed'][vp] \
            .view(g.num_nodes('venue'), 5, -1)  # (N_v, 5, d)
        venue_feats = torch.cat([
            g.nodes['venue'].data['net_embed'].unsqueeze(dim=1),
            vp_abstract_embed
        ], dim=1)  # (N_v, 6, d)

        return {
            'author': author_feats,
            'paper': paper_feats,
            'venue': venue_feats
        }

def forward(self, bg, node_feats, edge_feats):
    """Initialize input representations.

    Project the node/edge features and then concatenate the edge
    representations with the representations of their source nodes.
    """
    node_feats = self.project_nodes(node_feats)
    edge_feats = self.project_edges(edge_feats)
    bg = bg.local_var()
    bg.ndata['hv'] = node_feats
    bg.apply_edges(fn.copy_u('hv', 'he'))
    return torch.cat([bg.edata['he'], edge_feats], dim=1)

def forward(self, graph, ufeat=None, ifeat=None):
    num_u = graph.number_of_nodes('user')
    num_i = graph.number_of_nodes('item')
    funcs = {}
    for i, rating in enumerate(self.rating_vals):
        rating = str(rating)
        # W_r * x
        x_u = self.conv_u[i](ufeat)
        x_i = self.conv_i[i](ifeat)
        # left norm and dropout
        x_u = x_u * self.dropout(graph.nodes['user'].data['sqrt_deg'])
        x_i = x_i * self.dropout(graph.nodes['item'].data['sqrt_deg'])
        graph.nodes['user'].data['h%d' % i] = x_u
        graph.nodes['item'].data['h%d' % i] = x_i
        funcs[rating] = (fn.copy_u('h%d' % i, 'm'), fn.sum('m', 'h'))
        funcs['rev-%s' % rating] = (fn.copy_u('h%d' % i, 'm'), fn.sum('m', 'h'))
    # message passing
    graph.multi_update_all(funcs, self.agg)
    ufeat = graph.nodes['user'].data.pop('h').reshape((num_u, -1))
    ifeat = graph.nodes['item'].data.pop('h').reshape((num_i, -1))
    # right norm
    ufeat = ufeat * graph.nodes['user'].data['sqrt_deg']
    ifeat = ifeat * graph.nodes['item'].data['sqrt_deg']
    # non-linearity
    ufeat = F.relu(ufeat)
    ifeat = F.relu(ifeat)
    return ufeat, ifeat

def test_updates():
    def msg_func(edges):
        return {'m': edges.src['h']}

    def reduce_func(nodes):
        return {'y': F.sum(nodes.mailbox['m'], 1)}

    def apply_func(nodes):
        return {'y': nodes.data['y'] * 2}

    g = create_test_heterograph()
    x = F.randn((3, 5))
    g.nodes['user'].data['h'] = x

    for msg, red, apply in itertools.product(
            [fn.copy_u('h', 'm'), msg_func],
            [fn.sum('m', 'y'), reduce_func],
            [None, apply_func]):
        multiplier = 1 if apply is None else 2

        g['user', 'plays', 'game'].update_all(msg, red, apply)
        y = g.nodes['game'].data['y']
        assert F.array_equal(y[0], (x[0] + x[1]) * multiplier)
        assert F.array_equal(y[1], (x[1] + x[2]) * multiplier)
        del g.nodes['game'].data['y']

        g['user', 'plays', 'game'].send_and_recv(([0, 1, 2], [0, 1, 1]),
                                                 msg, red, apply)
        y = g.nodes['game'].data['y']
        assert F.array_equal(y[0], x[0] * multiplier)
        assert F.array_equal(y[1], (x[1] + x[2]) * multiplier)
        del g.nodes['game'].data['y']

        plays_g = g['user', 'plays', 'game']
        plays_g.send(([0, 1, 2], [0, 1, 1]), msg)
        plays_g.recv([0, 1], red, apply)
        y = g.nodes['game'].data['y']
        assert F.array_equal(y[0], x[0] * multiplier)
        assert F.array_equal(y[1], (x[1] + x[2]) * multiplier)
        del g.nodes['game'].data['y']

        # pulls from destination (game) node 0
        g['user', 'plays', 'game'].pull(0, msg, red, apply)
        y = g.nodes['game'].data['y']
        assert F.array_equal(y[0], (x[0] + x[1]) * multiplier)
        del g.nodes['game'].data['y']

        # pushes from source (user) node 0
        g['user', 'plays', 'game'].push(0, msg, red, apply)
        y = g.nodes['game'].data['y']
        assert F.array_equal(y[0], x[0] * multiplier)
        del g.nodes['game'].data['y']

def test_khop_adj():
    N = 20
    feat = F.randn((N, 5))
    g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
    for k in range(3):
        adj = F.tensor(dgl.khop_adj(g, k))
        # use the original graph to do message passing k times
        g.ndata['h'] = feat
        for _ in range(k):
            g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
        h_0 = g.ndata.pop('h')
        # use the k-hop adjacency matrix to do message passing once
        h_1 = F.matmul(adj, feat)
        assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3)

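# The equivalence this test relies on: one update_all with copy_u + sum gives
# every destination node the sum of its in-neighbors' features, so applying it
# k times matches a single multiplication by the k-hop adjacency matrix. A
# minimal orientation check (assuming dgl.khop_adj follows the same
# src -> dst convention as update_all):
import dgl
import torch
import dgl.function as fn

g = dgl.graph(([0], [1]))                    # single edge 0 -> 1
g.ndata['h'] = torch.tensor([[1.0], [0.0]])
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
# node 1 now holds node 0's feature: g.ndata['h'] == [[0.], [1.]]
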
def forward(self, graph, n_feats, e_weights=None):
    graph.ndata['h'] = n_feats
    if e_weights is None:
        graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
    else:
        graph.edata['ew'] = e_weights
        graph.update_all(fn.u_mul_e('h', 'ew', 'm'), fn.mean('m', 'h'))
    graph.ndata['h'] = self.layer(
        th.cat([graph.ndata['h'], n_feats], dim=-1))
    output = graph.ndata['h']
    return output

def calculate_customized_homophily(g, labels, K, multilabels=False):
    if (not multilabels) and labels.max() > 1:
        # one-hot encode the integer labels
        y = torch.zeros(size=(len(labels), int(labels.max()) + 1))
        y[torch.arange(len(labels)), labels] = 1
    else:
        y = labels
    g.ndata['y'] = y.clone()
    for _ in range(K):
        g.update_all(fn.copy_u('y', 'm'), fn.mean('m', 'y'))
    y_new = g.ndata.pop('y')
    y_new = F.normalize(y_new, dim=1, p=1)
    out = y_new[labels.long()].mean(0)
    return out.mean(0)

def forward(self, mg, feat):
    with mg.local_scope():
        if self.batch_norm is not None:
            feat = self.batch_norm(feat)
        mg.ndata['ft'] = feat
        if mg.number_of_edges() > 0:
            mg.update_all(fn.copy_u('ft', 'm'), self.reducer)
            neigh = mg.ndata['neigh']
            rst = self.fc_self(feat) + self.fc_neigh(neigh)
        else:
            rst = self.fc_self(feat)
        if self.activation is not None:
            rst = self.activation(rst)
        return rst

def forward(self, block, H, HBar=None):
    if self.training:
        with block.local_scope():
            H_src, H_dst = H
            HBar_src, agg_HBar_dst = HBar
            block.dstdata['agg_hbar'] = agg_HBar_dst
            # aggregate only the delta w.r.t. the historical embeddings and
            # add it to the pre-aggregated historical term
            block.srcdata['hdelta'] = H_src - HBar_src
            block.update_all(fn.copy_u('hdelta', 'm'),
                             fn.mean('m', 'hdelta_new'))
            h_neigh = block.dstdata['agg_hbar'] + block.dstdata['hdelta_new']
            h = self.W(th.cat([H_dst, h_neigh], 1))
            if self.activation is not None:
                h = self.activation(h)
            return h
    else:
        with block.local_scope():
            H_src, H_dst = H
            block.srcdata['h'] = H_src
            block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_new'))
            h_neigh = block.dstdata['h_new']
            h = self.W(th.cat([H_dst, h_neigh], 1))
            if self.activation is not None:
                h = self.activation(h)
            return h

def evaluate(g, field_ids, author_rank, true_relevance, field_paper):
    predict_rank = {}
    field_feat = g.nodes['field'].data['feat']
    apg = g['paper', 'writes_rev', 'author']
    for f in field_ids:
        pid = field_paper[f]
        paper_score = torch.matmul(
            g.nodes['paper'].data['feat'][pid], field_feat[f])
        sg = dgl.out_subgraph(apg, {'paper': pid}, relabel_nodes=True)
        sg.nodes['paper'].data['score'] = paper_score
        sg.update_all(fn.copy_u('score', 's'), fn.sum('s', 's'))
        predict_rank[f] = (sg.nodes['author'].data[dgl.NID],
                           sg.nodes['author'].data['s'])
    return calc_metrics(field_ids, author_rank, true_relevance, predict_rank)

def gen_mail(self, args, emb, input_nodes, pair_graph, frontier, mode='train'):
    pair_graph.ndata['feat'] = emb
    pair_graph = dgl.add_reverse_edges(pair_graph, copy_edata=True)
    pair_graph.update_all(MSG.get_edge_msg, fn.mean('m', 'msg'))
    frontier.ndata['msg'] = torch.zeros(
        (frontier.num_nodes(), self.nfeat_dim + 2))
    frontier.ndata['msg'][pair_graph.ndata[dgl.NID]] = \
        pair_graph.ndata['msg'].to('cpu')
    for _ in range(args.n_layer):
        frontier.update_all(fn.copy_u('msg', 'm'), fn.mean('m', 'msg'))
    mail = MSG.msg2mail(frontier.ndata['mail'][input_nodes],
                        frontier.ndata['msg'][input_nodes])
    return mail

def forward(self, graph, feat, get_attention=False):
    # check in-degrees and raise an error for isolated nodes
    if (graph.in_degrees() == 0).any():
        raise DGLError('There are 0-in-degree nodes in the graph, '
                       'output for those nodes will be invalid. '
                       'This is harmful for some applications, '
                       'causing silent performance regression. '
                       'Adding self-loop on the input graph by '
                       'calling `g = dgl.add_self_loop(g)` will resolve '
                       'the issue. Setting ``allow_zero_in_degree`` '
                       'to be `True` when constructing this module will '
                       'suppress the check and let the code run.')
    # projection process to get the importance vector y
    graph.ndata['y'] = torch.abs(
        torch.matmul(self.p, feat.T).view(-1)) / torch.norm(self.p, p=2)
    # use an edge message function to copy the weight from the src node
    graph.apply_edges(fn.copy_u('y', 'y'))
    # select the top-k neighbors
    subgraph = select_topk(graph, self.k, 'y')
    # sigmoid as an information threshold
    subgraph.ndata['y'] = torch.sigmoid(subgraph.ndata['y'])
    # elementwise vector-matrix multiplication for acceleration
    feat = subgraph.ndata['y'].view(-1, 1) * feat
    feat = self.feat_drop(feat)
    h = self.fc(feat).view(-1, self.num_heads, self.out_feats)
    el = (h * self.attn_l).sum(dim=-1).unsqueeze(-1)
    er = (h * self.attn_r).sum(dim=-1).unsqueeze(-1)
    # assign the values on the subgraph
    subgraph.srcdata.update({'ft': h, 'el': el})
    subgraph.dstdata.update({'er': er})
    # compute edge attention; el and er are a_l Wh_i and a_r Wh_j respectively
    subgraph.apply_edges(fn.u_add_v('el', 'er', 'e'))
    e = self.leaky_relu(subgraph.edata.pop('e'))
    # compute softmax
    subgraph.edata['a'] = self.attn_drop(edge_softmax(subgraph, e))
    # message passing
    subgraph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
    rst = subgraph.dstdata['ft']
    # activation
    if self.activation:
        rst = self.activation(rst)
    # residual connection
    if self.residual:
        rst = rst + self.residual_module(feat).view(
            feat.shape[0], -1, self.out_feats)
    if get_attention:
        return rst, subgraph.edata['a']
    else:
        return rst

def forward(self, graph, feat, eweight=None):
    with graph.local_scope():
        feat = self.linear(feat)
        graph.ndata['h'] = feat
        if eweight is None:
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
        else:
            graph.edata['w'] = eweight
            graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
        if self.pool:
            return self.pool(graph, graph.ndata['h'])
        else:
            return graph.ndata['h']

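# A hedged usage sketch for the optional edge-weight path above; the layer
# class name `WeightedConv` and the feature size are hypothetical.
import dgl
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
x = torch.randn(3, 4)
w = torch.rand(g.num_edges(), 1)          # one weight per edge, broadcast over features
# layer = WeightedConv(4, 4)              # hypothetical constructor
# out_unweighted = layer(g, x)            # plain sum aggregation
# out_weighted = layer(g, x, eweight=w)   # edge-weighted sum aggregation
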
def dgl_kernel_fusion(g, use_gpu, warmup_steps, total_steps):
    # expects g.ndata['x'] to be set before calling
    if use_gpu:
        th.cuda.synchronize()
    accum_time = 0
    for cnt in range(total_steps):
        tic = time.time()
        g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))
        y = g.ndata['y']
        if use_gpu:
            th.cuda.synchronize()
        toc = time.time()
        if cnt >= warmup_steps:
            accum_time += toc - tic
    # convert seconds to milliseconds for reporting
    print('dgl kernel fusion average speed {} ms'.format(
        accum_time * 1000 / (total_steps - warmup_steps)))

def forward(self, block, h):
    # full-graph equivalent: with g.local_scope():
    with block.local_scope():
        # full-graph equivalent: g.ndata['h'] = h
        h_src = h
        h_dst = h[:block.number_of_dst_nodes()]
        block.srcdata['h'] = h_src
        block.dstdata['h'] = h_dst
        # full-graph equivalent:
        # g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
        block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
        # full-graph equivalent:
        # return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))
        return self.activation(self.W(torch.cat(
            [block.dstdata['h'], block.dstdata['h_neigh']], 1)))

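# The slice h[:block.number_of_dst_nodes()] relies on a DGL block invariant:
# in a block produced by dgl.to_block, the destination nodes appear first
# among the source nodes, in the same order. A minimal check:
import dgl
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
block = dgl.to_block(g, dst_nodes=torch.tensor([1]))
assert torch.equal(
    block.srcdata[dgl.NID][:block.number_of_dst_nodes()],
    block.dstdata[dgl.NID])
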
def forward(self, G, feat_dict):
    # The input is a dictionary of node features for each type
    funcs = {}
    for srctype, etype, dsttype in G.canonical_etypes:
        # Compute W_r * h
        if srctype in feat_dict:
            Wh = self.weight[etype](feat_dict[srctype])
            # Save it in the graph for message passing
            G.nodes[srctype].data['Wh_%s' % etype] = Wh
            # Specify per-relation message passing functions: (message_func, reduce_func)
            funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
    # Trigger message passing of multiple types
    G.multi_update_all(funcs, 'sum')
    # Return the updated node feature dictionary
    return {ntype: G.nodes[ntype].data['h'] for ntype in G.ntypes
            if 'h' in G.nodes[ntype].data}

def forward(self, graph: dgl.DGLGraph,
            feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Args:
        graph: the graph
        feats: node features with node type as key and the corresponding
            features as value. Each tensor is of shape (N, D) where N is the
            number of nodes of the corresponding node type, and D is the
            feature size.

    Returns:
        Updated node features. Each tensor is of shape (N, D) where N is the
        number of nodes of the corresponding node type, and D is the feature
        size.
    """
    graph = graph.local_var()
    # assign data
    for nt, ft in feats.items():
        graph.nodes[nt].data.update({"ft": ft})
    for et in self.etypes:
        # per-relation mean and max aggregation
        graph[et].update_all(fn.copy_u("ft", "m"), fn.mean("m", "mean"), etype=et)
        graph[et].update_all(fn.copy_u("ft", "m"), fn.max("m", "max"), etype=et)
        nt = et[2]
        graph.apply_nodes(self._concatenate_node_feat, ntype=nt)
        # copy the updated feature from new_ft back to ft
        graph.nodes[nt].data.update({"ft": graph.nodes[nt].data["new_ft"]})
    return {nt: graph.nodes[nt].data["ft"] for nt in feats}

def forward(self, g, feat):
    with g.local_scope():
        if self.aggre_type == 'attention':
            h_src = self.feat_drop(feat[0]).view(-1, self.num_heads, self.in_size)
            h_dst = self.feat_drop(feat[1]).view(-1, self.num_heads, self.in_size)
            el = (h_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            # er = (h_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            g.srcdata.update({'ft': h_src, 'el': el})
            # g.srcdata.update({'ft': h_src, 'er': er})
            g.apply_edges(fn.copy_u('el', 'e'))
            # g.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(g.edata.pop('e'))
            g.edata['a'] = self.attn_drop(edge_softmax(g, e))
            g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
            rst = g.dstdata['ft'].flatten(1)
            if self.residual:
                rst = rst + h_dst.flatten(1)
            if self.activation:
                rst = self.activation(rst)
        elif self.aggre_type == 'mean':
            h_src = self.feat_drop(feat[0]).view(-1, self.in_size * self.num_heads)
            h_dst = self.feat_drop(feat[1]).view(-1, self.in_size * self.num_heads)
            g.srcdata['ft'] = h_src
            g.update_all(fn.copy_u('ft', 'm'), fn.mean('m', 'ft'))
            rst = g.dstdata['ft']  # + h_dst
        elif self.aggre_type == 'pool':
            h_src = self.feat_drop(feat[0]).view(-1, self.in_size * self.num_heads)
            h_dst = self.feat_drop(feat[1]).view(-1, self.in_size * self.num_heads)
            g.srcdata['ft'] = F.relu(self.fc_pool(h_src))
            g.update_all(fn.copy_u('ft', 'm'), fn.max('m', 'ft'))
            rst = g.dstdata['ft']  # + h_dst
        return rst

def forward(self, block):
    input_nodes = block.srcdata[dgl.NID]
    output_nodes = block.dstdata[dgl.NID]
    batch_size = block.number_of_dst_nodes()
    node_embed = self.node_embeddings
    node_type_embed = []

    with block.local_scope():
        for i in range(self.edge_type_count):
            edge_type = self.edge_types[i]
            block.srcdata[edge_type] = self.node_type_embeddings[input_nodes, i]
            block.dstdata[edge_type] = self.node_type_embeddings[output_nodes, i]
            block.update_all(fn.copy_u(edge_type, "m"),
                             fn.sum("m", edge_type),
                             etype=edge_type)
            node_type_embed.append(block.dstdata[edge_type])

        node_type_embed = torch.stack(node_type_embed, 1)
        tmp_node_type_embed = node_type_embed.unsqueeze(2).view(
            -1, 1, self.embedding_u_size)
        trans_w = self.trans_weights.unsqueeze(0).repeat(
            batch_size, 1, 1, 1).view(-1, self.embedding_u_size,
                                      self.embedding_size)
        trans_w_s1 = self.trans_weights_s1.unsqueeze(0).repeat(
            batch_size, 1, 1, 1).view(-1, self.embedding_u_size, self.dim_a)
        trans_w_s2 = self.trans_weights_s2.unsqueeze(0).repeat(
            batch_size, 1, 1, 1).view(-1, self.dim_a, 1)

        attention = F.softmax(
            torch.matmul(
                torch.tanh(torch.matmul(tmp_node_type_embed, trans_w_s1)),
                trans_w_s2,
            ).squeeze(2).view(-1, self.edge_type_count),
            dim=1,
        ).unsqueeze(1).repeat(1, self.edge_type_count, 1)

        node_type_embed = torch.matmul(attention, node_type_embed).view(
            -1, 1, self.embedding_u_size)
        node_embed = node_embed[output_nodes].unsqueeze(1).repeat(
            1, self.edge_type_count, 1) + torch.matmul(
                node_type_embed, trans_w).view(-1, self.edge_type_count,
                                               self.embedding_size)
        last_node_embed = F.normalize(node_embed, dim=2)

        return last_node_embed  # [batch_size, edge_type_count, embedding_size]