def topk_attention(self, graph: DGLGraph):
    # The graph is expected to already contain self-loop edges.
    graph = graph.local_var()

    def send_edge_message(edges):
        return {'m_e': edges.data['e']}

    def topk_attn_reduce_func(nodes):
        # For each destination node, find its k-th largest incoming
        # attention score (k clamped to the actual in-degree).
        topk = self.top_k
        attentions = nodes.mailbox['m_e']
        neighbor_num = attentions.shape[1]
        if topk > neighbor_num:
            topk = neighbor_num
        topk_atts, _ = torch.topk(attentions, k=topk, dim=1)
        kth_attn_value = topk_atts[:, topk - 1]
        return {'kth_e': kth_attn_value}

    # register_message_func/register_reduce_func are deprecated and
    # redundant here; update_all takes both functions directly.
    graph.update_all(message_func=send_edge_message,
                     reduce_func=topk_attn_reduce_func)

    def edge_score_update(edges):
        # Mask every score below the destination node's k-th largest.
        scores, kth_score = edges.data['e'], edges.dst['kth_e']
        scores[scores < kth_score] = self.attention_mask_value
        return {'e': scores}

    graph.apply_edges(edge_score_update)
    topk_attentions = graph.edata.pop('e')
    return topk_attentions

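# --- Illustration (not part of the model): a minimal, self-contained
# sketch of the k-th-value thresholding used above, on a toy tensor.
# Shapes and the mask value are hypothetical.
import torch

top_k, mask_value = 2, float('-inf')
# toy mailbox: 3 destination nodes, 4 incoming attention scores each
attentions = torch.tensor([[0.9, 0.1, 0.5, 0.3],
                           [0.2, 0.8, 0.4, 0.6],
                           [0.7, 0.7, 0.1, 0.2]])
topk_atts, _ = torch.topk(attentions, k=top_k, dim=1)
kth = topk_atts[:, top_k - 1].unsqueeze(-1)        # k-th largest per node
masked = attentions.masked_fill(attentions < kth, mask_value)
print(masked)  # scores below the k-th largest are masked out
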
def forward(self, graph: dgl.DGLGraph):
    graph.apply_nodes(self.input_node_func)
    graph.apply_edges(self.input_edge_func, etype='bond')
    for mp_layer in self.mp_layers:
        mp_layer(graph)
    graph.apply_nodes(self.output_node_func)

def forward(self, graph: dgl.DGLGraph, h):
    # equation (1): linear projection of the node features
    z = self.fc(h)
    graph.ndata['z'] = z
    # equation (2): compute unnormalized attention scores per edge
    graph.apply_edges(self.edge_attention)
    # equations (3) & (4): softmax-normalize and aggregate messages
    graph.update_all(self.message_func, self.reduce_func)
    return graph.ndata.pop('h')

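# --- Illustration (not part of the model): a self-contained toy run of
# the same apply_edges/update_all pattern behind equations (2)-(4). A
# dot-product score stands in for the model's learned attention, and
# all names here are hypothetical.
import dgl
import torch
import torch.nn.functional as F

g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
g.ndata['z'] = torch.randn(3, 4)

def edge_attention(edges):   # eq. (2): unnormalized score per edge
    return {'e': (edges.src['z'] * edges.dst['z']).sum(-1, keepdim=True)}

def message_func(edges):     # eq. (3): ship z and e to the mailbox
    return {'z': edges.src['z'], 'e': edges.data['e']}

def reduce_func(nodes):      # eq. (4): softmax-weighted neighbor sum
    alpha = F.softmax(nodes.mailbox['e'], dim=1)
    return {'h': (alpha * nodes.mailbox['z']).sum(dim=1)}

g.apply_edges(edge_attention)
g.update_all(message_func, reduce_func)
print(g.ndata['h'].shape)    # torch.Size([3, 4])
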
def topk_attention_softmax(self, graph: DGLGraph):
    graph = graph.local_var()

    def send_edge_message(edges):
        return {'m_e': edges.data['e'], 'm_e_id': edges.data['e_id']}

    def topk_attn_reduce_func(nodes):
        # For each destination node, keep the ids of the k incoming
        # edges with the largest (head-summed) attention scores;
        # -1 marks padding for nodes whose in-degree is below k.
        topk = self.top_k
        attentions = nodes.mailbox['m_e']
        edge_ids = nodes.mailbox['m_e_id']
        topk_edge_ids = torch.full(size=(edge_ids.shape[0], topk),
                                   fill_value=-1, dtype=torch.long,
                                   device=attentions.device)
        attentions_sum = attentions.sum(dim=2)
        neighbor_num = attentions_sum.shape[1]
        if topk > neighbor_num:
            topk = neighbor_num
        _, top_k_neighbor_idx = torch.topk(attentions_sum, k=topk, dim=1)
        top_k_neighbor_idx = top_k_neighbor_idx.squeeze(dim=-1)
        row_idxes = torch.arange(0, top_k_neighbor_idx.shape[0],
                                 device=attentions.device).view(-1, 1)
        top_k_attention = attentions[row_idxes, top_k_neighbor_idx]
        top_k_edge_ids = edge_ids[row_idxes, top_k_neighbor_idx]
        # Normalization constant over the surviving top-k edges.
        top_k_attention_norm = top_k_attention.sum(dim=1)
        topk_edge_ids[:, :topk] = top_k_edge_ids
        return {'topk_eid': topk_edge_ids, 'topk_norm': top_k_attention_norm}

    # register_* APIs are deprecated; update_all takes both functions.
    graph.update_all(message_func=send_edge_message,
                     reduce_func=topk_attn_reduce_func)

    # Build a 0/1 mask over all edges from the selected edge ids
    # (placing tensors on the graph's device rather than testing
    # torch.cuda.is_available(), which breaks for CPU graphs).
    topk_edge_ids = graph.ndata['topk_eid'].flatten()
    topk_edge_ids = topk_edge_ids[topk_edge_ids >= 0]
    mask_edges = torch.zeros((graph.number_of_edges(), 1),
                             device=graph.device)
    mask_edges[topk_edge_ids] = 1
    attentions = graph.edata['e'].squeeze(dim=-1)
    attentions = attentions * mask_edges
    graph.edata['e'] = attentions.unsqueeze(dim=-1)

    def edge_score_update(edges):
        # Renormalize so the surviving scores per node sum to one.
        scores = edges.data['e'] / edges.dst['topk_norm']
        return {'e': scores}

    graph.apply_edges(edge_score_update)
    topk_attentions = graph.edata.pop('e')
    return topk_attentions

def forward(self, graph: dgl.DGLGraph):
    graph.apply_nodes(self.input_node_func)
    if self.fourier_encodings > 0:
        graph.edata['d'] = fourier_encode_dist(
            graph.edata['d'], num_encodings=self.fourier_encodings)
    graph.apply_edges(self.input_edge_func)
    for mp_layer in self.mp_layers:
        mp_layer(graph)
    if self.node_wise_output_layers > 0:
        graph.apply_nodes(self.output_node_func)
    return graph.ndata['feat']

def forward(self, graph: dgl.DGLGraph) -> torch.Tensor:
    # Edge features must be read inside the local scope, before the
    # temporary writes are discarded.
    with graph.local_scope():
        graph.apply_edges(self._apply_edges, etype="forward")
        h_forward = graph.edges["forward"].data["h"]
    with graph.local_scope():
        graph.apply_edges(self._apply_edges, etype="reverse")
        h_reverse = graph.edges["reverse"].data["h"]
    return self.layers(h_forward) + self.layers(h_reverse)

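# --- Illustration (not part of the model): why the edge data must be
# read *inside* local_scope() -- writes are discarded on exit. The toy
# heterograph and feature names are hypothetical.
import dgl
import torch

g = dgl.heterograph({
    ('n', 'forward', 'n'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
    ('n', 'reverse', 'n'): (torch.tensor([1, 2]), torch.tensor([0, 1])),
})
with g.local_scope():
    g.edges['forward'].data['h'] = torch.ones(2, 4)
    h_forward = g.edges['forward'].data['h']      # read while in scope
print('h' in g.edges['forward'].data)             # False: write discarded
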
def forward(self, graph: DGLGraph, features, drop_edge_ids=None):
    # Attention computation with a pre-normalization structure.
    graph = graph.local_var()
    h = self.graph_norm(features)
    # tanh is used for the head/tail transforms; plain linear and ReLU
    # projections were also tried here.
    feat_head = torch.tanh(self.fc_head(self.feat_drop(h))).view(
        -1, self._num_heads, self._att_dim)
    feat_tail = torch.tanh(self.fc_tail(self.feat_drop(h))).view(
        -1, self._num_heads, self._att_dim)
    feat = self.fc(self.feat_drop(h)).view(-1, self._num_heads, self._att_dim)
    eh = (feat_head * self.attn_h).sum(dim=-1).unsqueeze(-1)
    et = (feat_tail * self.attn_t).sum(dim=-1).unsqueeze(-1)
    graph.ndata.update({'ft': feat, 'eh': eh, 'et': et})
    graph.apply_edges(fn.u_add_v('eh', 'et', 'e'))
    attentions = self.leaky_relu(graph.edata.pop('e'))
    if drop_edge_ids is not None:
        attentions[drop_edge_ids] = self.attention_mask_value
    if self.top_k <= 0:
        graph.edata['a'] = edge_softmax(graph, attentions)
    else:
        if self.topk_type == 'local':
            # Mask to the top-k scores first, then softmax.
            graph.edata['e'] = attentions
            attentions = self.topk_attention(graph)
            graph.edata['a'] = edge_softmax(graph, attentions)
        else:
            # Softmax first, then mask and renormalize over the top-k.
            graph.edata['e'] = edge_softmax(graph, attentions)
            graph.edata['a'] = self.topk_attention_softmax(graph)
    rst = self.ppr_estimation(graph=graph)
    rst = self.fc_out(rst.flatten(1))
    resval = self.res_fc(features)
    rst = resval + self.feat_drop(rst)
    rst_ff = self.feed_forward(self.ff_norm(rst))
    rst = rst + self.feat_drop(rst_ff)
    attentions = graph.edata.pop('a')  # also return the attention scores
    return rst, attentions

def forward(self, graph: dgl.DGLGraph):
    if self.fourier_encodings > 0:
        graph.edata['d'] = fourier_encode_dist(
            graph.edata['d'], num_encodings=self.fourier_encodings)
    graph.apply_edges(self.input_edge_func)
    graph.update_all(message_func=self.message_function,
                     reduce_func=self.reduce_func(msg='m', out='m_sum'))
    if self.node_wise_output_layers > 0:
        graph.apply_nodes(self.output_node_func)
    # Concatenate the chosen graph-level readouts (e.g. sum/mean/max).
    readouts_to_cat = [
        dgl.readout_nodes(graph, 'feat', op=aggr)
        for aggr in self.readout_aggregators
    ]
    readout = torch.cat(readouts_to_cat, dim=-1)
    return self.output(readout)

def forward(self, graph: dgl.DGLGraph):
    # Broadcast a single learned embedding to every node.
    graph.ndata['feat'] = self.node_embedding[None, :].expand(
        graph.number_of_nodes(), -1)
    if self.fourier_encodings > 0:
        graph.edata['d'] = fourier_encode_dist(
            graph.edata['d'], num_encodings=self.fourier_encodings)
    graph.apply_edges(self.input_edge_func)
    for mp_layer in self.mp_layers:
        mp_layer(graph)
    if self.node_wise_output_layers > 0:
        graph.apply_nodes(self.output_node_func)
    readouts_to_cat = [
        dgl.readout_nodes(graph, 'feat', op=aggr)
        for aggr in self.readout_aggregators
    ]
    readout = torch.cat(readouts_to_cat, dim=-1)
    return self.output(readout)

class D2GCN(nn.Module):
    def __init__(self, in_feat_dim, out_feat_dim, dropout=0.0,
                 feature_drop=0.0, att_drop=0.0, alpha=0.2):
        # dropout/feature_drop/att_drop/alpha were free globals in the
        # original; they are taken as constructor arguments here.
        super(D2GCN, self).__init__()
        self.fedge = nn.Sequential(
            nn.Linear(in_feat_dim * 2, in_feat_dim // 64),
            nn.BatchNorm1d(in_feat_dim // 64), nn.Dropout(dropout),
            nn.LeakyReLU(),
            nn.Linear(in_feat_dim // 64, out_feat_dim),
            nn.BatchNorm1d(out_feat_dim), nn.Dropout(dropout), nn.ReLU())
        self.feat_drop = nn.Dropout(feature_drop) if feature_drop else (lambda x: x)
        self.att_drop = nn.Dropout(att_drop) if att_drop else (lambda x: x)
        self.attn_l = nn.Parameter(torch.Tensor(size=(1, out_feat_dim)))
        self.attn_r = nn.Parameter(torch.Tensor(size=(1, out_feat_dim)))
        self.relu = nn.LeakyReLU(alpha)
        self.softmax = edge_softmax
        self.fnode = nn.Sequential(
            nn.Linear(in_feat_dim + out_feat_dim, out_feat_dim // 64),
            nn.BatchNorm1d(out_feat_dim // 64), nn.Dropout(dropout),
            nn.LeakyReLU(),
            nn.Linear(out_feat_dim // 64, out_feat_dim),
            nn.BatchNorm1d(out_feat_dim), nn.Dropout(dropout), nn.ReLU())
        nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.data, gain=1.414)

    def build_graph(self, num_nodes, device):
        # Fully connected graph without self-loops; each directed edge
        # is added exactly once (the original loop added every edge
        # twice, and dropped the result of graph.to(device)).
        src, dst = [], []
        for i in range(num_nodes):
            for j in range(num_nodes):
                if i != j:
                    src.append(i)
                    dst.append(j)
        self.g = dgl.graph((src, dst), num_nodes=num_nodes).to(device)

    def send_source(self, edges):
        edge_feature = self.fedge(
            torch.cat((edges.src['h'], edges.dst['h']), dim=1))
        msg = self.fnode(torch.cat((edges.src['h'], edge_feature), dim=1))
        m = torch.mul(msg, edges.data['a_drop'])
        return {'m': m}

    def simple_reduce(self, nodes):
        # Residual aggregation: summed messages plus the node's own state.
        return {'h': torch.sum(nodes.mailbox['m'], dim=1) + nodes.data['h']}

    def edge_attention(self, edges):
        a = self.relu(edges.src['a1'] + edges.dst['a2'])
        return {'a': a}

    def edge_softmax(self):
        att = self.softmax(self.g, self.g.edata.pop('a'))
        self.g.edata['a_drop'] = self.att_drop(att)

    def forward(self, n_feature):
        # Note: the attention products assume in_feat_dim == out_feat_dim.
        a1 = (n_feature * self.attn_l).sum(dim=-1).unsqueeze(-1)
        a2 = (n_feature * self.attn_r).sum(dim=-1).unsqueeze(-1)
        self.g.ndata.update({'h': n_feature, 'a1': a1, 'a2': a2})
        self.g.apply_edges(self.edge_attention)
        self.edge_softmax()
        # send/recv and the register_* APIs are deprecated; update_all
        # performs the same message passing in one call.
        self.g.update_all(self.send_source, self.simple_reduce)
        return self.g.ndata.pop('h')

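# --- Illustration (not part of the model): a hypothetical usage sketch
# of D2GCN. Dimensions are arbitrary; as noted above, the attention
# products in forward() assume in_feat_dim == out_feat_dim.
import torch

model = D2GCN(in_feat_dim=256, out_feat_dim=256)
model.build_graph(num_nodes=6, device=torch.device('cpu'))
h = torch.randn(6, 256)
out = model(h)
print(out.shape)  # torch.Size([6, 256])
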
def forward(self, dgl_data: dgl.DGLGraph):
    dgl_data.apply_edges(self.edge_init)
    return dgl_data

def forward(self, dgl_data: dgl.DGLGraph):
    # Note: update_all only updates the node features; the edge
    # features are refreshed by apply_edges afterwards.
    dgl_data.update_all(self.gcn_msg, self.gcn_reduce)
    dgl_data.apply_edges(self.edge_update)
    return dgl_data

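# --- Illustration (not part of the model): the same update_all /
# apply_edges ordering on a toy graph with DGL built-in functions;
# feature names are hypothetical.
import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata['h'] = torch.randn(3, 4)

# node update: sum incoming neighbor features into 'h'
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
# edge update: recompute edge features from the refreshed node states
g.apply_edges(fn.u_sub_v('h', 'w'))
print(g.ndata['h'].shape, g.edata['w'].shape)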