def forward(self, g, node_feats, edge_feats):
    """Apply one generalized message-passing layer.

    The structure (softmax aggregation with inverse temperature ``beta``,
    power-mean aggregation with exponent ``p``, optional message norm,
    residual connection, MLP) appears to follow DeeperGCN's GENConv.

    Parameters
    ----------
    g : DGLGraph
        Input graph; its features are only modified inside a local scope.
    node_feats : Tensor
        Node features. Their feature size must match the output of
        ``self.edge_encoder`` because the two are added on every edge.
    edge_feats : Tensor
        Raw edge features, encoded with ``self.edge_encoder``.

    Returns
    -------
    Tensor
        ``self.mlp(node_feats + aggregated_messages)``.

    Raises
    ------
    NotImplementedError
        If ``self.aggr`` is neither ``'softmax'`` nor ``'power'``.
    """
    with g.local_scope():
        # Node and edge feature dimension need to match.
        g.ndata['h'] = node_feats
        g.edata['h'] = self.edge_encoder(edge_feats)
        # Per-edge message m_uv = h_u + e_uv.
        g.apply_edges(fn.u_add_e('h', 'h', 'm'))

        if self.aggr == 'softmax':
            # Make messages strictly positive (ReLU + eps), then weight them
            # with edge_softmax scores computed from beta-scaled messages.
            g.edata['m'] = F.relu(g.edata['m']) + self.eps
            g.edata['a'] = edge_softmax(g, g.edata['m'] * self.beta)
            g.update_all(
                lambda edge: {'x': edge.data['m'] * edge.data['a']},
                fn.sum('x', 'm'))
        elif self.aggr == 'power':
            # Power mean: (mean_u m_uv ** p) ** (1 / p). Messages and the
            # mean are clamped to [1e-7, 10] to keep pow numerically stable.
            minv, maxv = 1e-7, 1e1
            torch.clamp_(g.edata['m'], minv, maxv)
            g.update_all(
                lambda edge: {'x': torch.pow(edge.data['m'], self.p)},
                fn.mean('x', 'm'))
            torch.clamp_(g.ndata['m'], minv, maxv)
            # The outer exponent is the root 1 / p; raising to p again would
            # compute (mean m^p)^p, which is not a power mean.
            g.ndata['m'] = torch.pow(g.ndata['m'], 1 / self.p)
        else:
            raise NotImplementedError(
                f'Aggregator {self.aggr} is not supported.')

        if self.msg_norm is not None:
            g.ndata['m'] = self.msg_norm(node_feats, g.ndata['m'])

        # Residual connection followed by the MLP.
        feats = node_feats + g.ndata['m']
        return self.mlp(feats)
def forward(self, g, split_list, node_feat, edge_feat):
    """GIN-style node update with a configurable neighbour reducer.

    Computes ``(1 + eps) * h_v + R_{u -> v}(h_u + e_uv)``, where ``R`` is
    the reducer built by ``self._reducer``. If ``self.apply_func`` is set,
    it is applied to the result as ``self.apply_func(g, result)``.

    Parameters
    ----------
    g : DGLGraph
        Input graph; features are written only to a local copy.
    split_list
        Not used by this method (presumably kept so all layers share one
        call signature; confirm with callers).
    node_feat : Tensor
        Node features ``h``.
    edge_feat : Tensor
        Edge features ``e``; must be addable to the source node features.

    Returns
    -------
    Tensor
        The updated node features.
    """
    scratch = g.local_var()
    scratch.ndata['h_n'] = node_feat
    scratch.edata['h_e'] = edge_feat

    # Message: source feature + edge feature; reduced into 'neigh'.
    scratch.update_all(fn.u_add_e('h_n', 'h_e', 'm'),
                       self._reducer('m', 'neigh'))
    out = (1 + self.eps) * node_feat + scratch.ndata['neigh']

    if self.apply_func is None:
        return out
    # apply_func is handed the caller's graph ``g``, not the local copy.
    return self.apply_func(g, out)
def forward(self, graph, node_feat, edge_feat):
    """GIN-style node update with summed edge-aware messages.

    Computes ``mlp((1 + eps) * h_v + sum_{u -> v}(h_u + e_uv))``.

    Parameters
    ----------
    graph : DGLGraph
        Input graph; features are written only to a local copy.
    node_feat : Tensor
        Node features ``h``.
    edge_feat : Tensor
        Edge features ``e``; must be addable to the source node features.

    Returns
    -------
    Tensor
        Output of ``self.mlp``.
    """
    scoped = graph.local_var()
    scoped.ndata['h_n'] = node_feat
    scoped.edata['h_e'] = edge_feat

    # u = source node, e = edge: each message is h_u + e_uv, and the
    # messages arriving at each destination node are summed into 'neigh'.
    message_fn = fn.u_add_e('h_n', 'h_e', 'm')
    reduce_fn = fn.sum('m', 'neigh')
    scoped.update_all(message_fn, reduce_fn)

    combined = (1 + self.eps) * node_feat + scoped.ndata['neigh']
    return self.mlp(combined)
def forward(self, g, node_feats, categorical_edge_feats):
    """Compute updated node representations for a batch of graphs.

    Each categorical edge feature is embedded with its own table from
    ``self.edge_embeddings`` and the embeddings are summed. Each node's
    new feature is the sum, over its incoming edges, of (source node
    feature + edge embedding). That sum is passed through ``self.mlp`` and
    then, when they are set, through ``self.bn`` and ``self.activation``.

    Parameters
    ----------
    g : DGLGraph
        Batched graph; only a local copy of its features is modified.
    node_feats : FloatTensor of shape (N, emb_dim)
        Input features for the N nodes in the batch. emb_dim must equal
        the emb_dim given at construction time.
    categorical_edge_feats : list of LongTensor, each of shape (E)
        One tensor of category ids per edge-feature kind, for the E edges
        in the batch. Its length should equal ``len(self.edge_embeddings)``.

    Returns
    -------
    node_feats : float32 tensor of shape (N, emb_dim)
        Updated node representations.
    """
    # One embedding lookup per edge-feature kind, summed element-wise.
    per_kind = [self.edge_embeddings[idx](codes)
                for idx, codes in enumerate(categorical_edge_feats)]
    edge_embeds = torch.stack(per_kind, dim=0).sum(0)

    local_g = g.local_var()
    local_g.ndata['feat'] = node_feats
    local_g.edata['feat'] = edge_embeds
    local_g.update_all(fn.u_add_e('feat', 'feat', 'm'), fn.sum('m', 'feat'))
    out = self.mlp(local_g.ndata.pop('feat'))

    # Optional post-processing, applied in this fixed order.
    for post in (self.bn, self.activation):
        if post is not None:
            out = post(out)
    return out
def forward(self, agg_graph: dgl.DGLGraph, prop_graph: dgl.DGLGraph,
            traversal_order, new_node_ids) -> torch.Tensor:
    """Compute node representations with a time-aware SAGE-style update.

    Stage 1 aggregates one-hop messages on ``agg_graph``. Stage 2 turns the
    per-node aggregates into running (cumulative) aggregates within each
    group of ``new_node_ids``. The result is then combined with the node's
    own features through ``fc_self`` / ``fc_neigh``.

    Parameters
    ----------
    agg_graph : dgl.DGLGraph
        Graph providing ``ndata['nfeat']``, ``ndata['timestamp']``,
        ``ndata['degree']`` and ``edata['efeat']``. Only a local copy is
        modified.
    prop_graph : dgl.DGLGraph
        Currently unused. Kept for interface compatibility; it was only
        referenced by a disabled ``prop_nodes`` propagation path.
    traversal_order
        Currently unused, for the same reason as ``prop_graph``.
    new_node_ids : sequence of index arrays
        Groups of node ids. Within each group, the stage-2 reduction runs
        in the listed order. NOTE(review): the scatter-back step assumes
        the concatenated groups form a permutation of all node ids; confirm
        this against the caller.

    Returns
    -------
    torch.Tensor
        ``fc_neigh(h_neigh)`` for ``"gcn"``, otherwise
        ``fc_self(h_self) + fc_neigh(h_neigh)``.

    Raises
    ------
    KeyError
        If ``self._agg_type`` is not one of ``"mean"``, ``"gcn"``,
        ``"pool"``, ``"lstm"``.
    """
    tg = agg_graph.local_var()

    # Node self-features combined with their timestamps by encode_time;
    # edge features projected by fc_edge.
    h_self = self.encode_time(tg.ndata["nfeat"], tg.ndata["timestamp"])
    tg.ndata["nfeat"] = h_self
    tg.edata["efeat"] = self.fc_edge(tg.edata["efeat"])
    degs = tg.ndata["degree"]

    # Stage 1: one-hop aggregation on agg_graph, message = h_src + e_uv.
    message = fn.u_add_e("nfeat", "efeat", "m")
    if self._agg_type == "pool":
        tg.edata["efeat"] = F.relu(self.fc_pool(tg.edata["efeat"]))
        tg.update_all(message, fn.max("m", "neigh"))
    elif self._agg_type in ["mean", "gcn", "lstm"]:
        tg.update_all(message, fn.sum("m", "neigh"))
    else:
        raise KeyError("Aggregator type {} not recognized.".format(
            self._agg_type))
    h_neigh = tg.ndata["neigh"]

    def _running_reduce(values, reduce_fn):
        # Reduce each group's rows in their listed order, concatenate the
        # groups, then invert the grouping permutation so that row i maps
        # back to node i.
        grouped = torch.cat([reduce_fn(values[ids]) for ids in new_node_ids],
                            dim=0)
        n_rows = grouped.shape[0]
        ridx = torch.arange(n_rows)
        ridx[np.concatenate(new_node_ids)] = torch.arange(n_rows)
        return grouped[ridx]

    # Stage 2: running aggregation within each group of new_node_ids.
    if self._agg_type == "mean":
        h_neigh = _running_reduce(h_neigh, lambda x: torch.cumsum(x, dim=0))
        h_neigh = h_neigh / degs.unsqueeze(-1)
    elif self._agg_type == "gcn":
        h_neigh = _running_reduce(h_neigh, lambda x: torch.cumsum(x, dim=0))
        h_neigh = (h_self + h_neigh) / (degs.unsqueeze(-1) + 1)
    elif self._agg_type == "pool":
        # torch.cummax returns a (values, indices) namedtuple; passing those
        # namedtuples to torch.cat raises a TypeError, so keep the values.
        h_neigh = _running_reduce(
            h_neigh, lambda x: torch.cummax(x, dim=0).values)
    elif self._agg_type == "lstm":
        h_neigh = _running_reduce(h_neigh, self._lstm_reducer)

    if self._agg_type == "gcn":
        return self.fc_neigh(h_neigh)
    return self.fc_self(h_self) + self.fc_neigh(h_neigh)
import dgl import dgl.function as fn import torch as th import torch.nn as nn import torch.nn.functional as F from dgl import DGLGraph import Processtest from dgl.nn import GATConv, SAGEConv, GINConv # 消息函数和聚合函数 gcn_msg = fn.u_add_e('h', 'w', 'm') #gcn_msg = fn.copy_src(src='h', out='m') gcn_reduce = fn.sum(msg='m', out='h') class GCNLayer(nn.Module): def __init__(self, in_feats, out_feats): super(GCNLayer, self).__init__() self.linear = nn.Linear(in_feats, out_feats) def forward(self, g, feature): with g.local_scope(): g.ndata['h'] = feature g.update_all(gcn_msg, gcn_reduce) h = g.ndata['h'] return self.linear(h) class Net(nn.Module): def __init__(self):