def forward(self, g, h, h_en):
    """Forward computation."""
    with g.local_scope():
        h_src, h_dst = expand_as_pair(h)
        h_src_en, h_dst_en = expand_as_pair(h_en)
        g.srcdata['x'] = h_src
        g.dstdata['x'] = h_dst
        g.srcdata['en'] = h_src_en
        g.dstdata['en'] = h_dst_en
        if not self.batch_norm:
            # g.update_all(self.message, fn.mean('e', 'x'))
            g.apply_edges(self.message)
            g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x'))
            g.update_all(fn.copy_e('e_en', 'e_en'), fn.mean('e_en', 'en'))
        else:
            g.apply_edges(self.message)
            g.edata['e'] = self.bn(g.edata['e'])
            g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x'))
            g.update_all(fn.copy_e('e_en', 'e_en'), fn.mean('e_en', 'en'))
        return g.dstdata['x'], g.dstdata['en']  # + h_en
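As a minimal, self-contained sketch of the apply-edges-then-reduce pattern used above (a plain homogeneous graph and a trivial edge message standing in for self.message; all names here are illustrative, not from the original layer):

import dgl
import dgl.function as fn
import torch

g = dgl.rand_graph(6, 20)
g.ndata['x'] = torch.randn(6, 4)
# edge message: the source feature itself (stands in for self.message)
g.apply_edges(fn.copy_u('x', 'e'))
# max-reduce the edge messages into destination nodes
g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x'))
print(g.ndata['x'].shape)  # (6, 4)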
def _pull_nodes(nodes):
    # compute ground truth
    g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
    o1 = g.ndata.pop('o1')
    g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
    o2 = g.ndata.pop('o2')
    g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
    o3 = g.ndata.pop('o3')
    # v2v spmv
    g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
           fn.sum(msg='m1', out='o1'), _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    # v2v fallback to e2v
    g.pull(nodes, fn.src_mul_edge(src='h', edge='w2', out='m2'),
           fn.sum(msg='m2', out='o2'), _afunc)
    assert U.allclose(o2, g.ndata.pop('o2'))
    # v2v fallback to degree bucketing
    g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
           fn.max(msg='m1', out='o3'), _afunc)
    assert U.allclose(o3, g.ndata.pop('o3'))
    # multi builtins, both v2v spmv
    g.pull(nodes,
           [fn.src_mul_edge(src='h', edge='w1', out='m1'),
            fn.src_mul_edge(src='h', edge='w1', out='m2')],
           [fn.sum(msg='m1', out='o1'),
            fn.sum(msg='m2', out='o2')],
           _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    assert U.allclose(o1, g.ndata.pop('o2'))
    # multi builtins, one v2v spmv, one fallback to e2v
    g.pull(nodes,
           [fn.src_mul_edge(src='h', edge='w1', out='m1'),
            fn.src_mul_edge(src='h', edge='w2', out='m2')],
           [fn.sum(msg='m1', out='o1'),
            fn.sum(msg='m2', out='o2')],
           _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    assert U.allclose(o2, g.ndata.pop('o2'))
    # multi builtins, one v2v spmv, one fallback to e2v,
    # one fallback to degree bucketing
    g.pull(nodes,
           [fn.src_mul_edge(src='h', edge='w1', out='m1'),
            fn.src_mul_edge(src='h', edge='w2', out='m2'),
            fn.src_mul_edge(src='h', edge='w1', out='m3')],
           [fn.sum(msg='m1', out='o1'),
            fn.sum(msg='m2', out='o2'),
            fn.max(msg='m3', out='o3')],
           _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    assert U.allclose(o2, g.ndata.pop('o2'))
    assert U.allclose(o3, g.ndata.pop('o3'))
def forward(self, graph, feat, e_feat):
    r"""Compute GraphSAGE layer.

    Parameters
    ----------
    graph : DGLGraph
        The graph.
    feat : torch.Tensor
        The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
        is size of input feature, :math:`N` is the number of nodes.
    e_feat : torch.Tensor
        Edge feature with one entry per edge, multiplied into each message.

    Returns
    -------
    torch.Tensor
        The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
        is size of output feature.
    """
    graph = graph.local_var()
    feat = self.feat_drop(feat)
    h_self = feat
    graph.edata['e'] = e_feat
    if self._aggre_type == 'sum':
        graph.ndata['h'] = feat
        graph.update_all(fn.u_mul_e('h', 'e', 'm'), fn.sum('m', 'neigh'))
        h_neigh = graph.ndata['neigh']
    elif self._aggre_type == 'mean':
        graph.ndata['h'] = feat
        graph.update_all(fn.u_mul_e('h', 'e', 'm'), fn.mean('m', 'neigh'))
        h_neigh = graph.ndata['neigh']
    elif self._aggre_type == 'gcn':
        graph.ndata['h'] = feat
        graph.update_all(fn.u_mul_e('h', 'e', 'm'), fn.sum('m', 'neigh'))
        # divide by in_degrees
        degs = graph.in_degrees().float()
        degs = degs.to(feat.device)
        h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.unsqueeze(-1) + 1)
    elif self._aggre_type == 'pool':
        graph.ndata['h'] = F.relu(self.fc_pool(feat))
        graph.update_all(fn.u_mul_e('h', 'e', 'm'), fn.max('m', 'neigh'))
        h_neigh = graph.ndata['neigh']
    elif self._aggre_type == 'lstm':
        graph.ndata['h'] = feat
        graph.update_all(fn.u_mul_e('h', 'e', 'm'), self._lstm_reducer)
        h_neigh = graph.ndata['neigh']
    else:
        raise KeyError('Aggregator type {} not recognized.'.format(
            self._aggre_type))
    # GraphSAGE GCN does not require fc_self.
    if self._aggre_type == 'gcn':
        rst = self.fc_neigh(h_neigh)
    else:
        rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
    # activation
    if self.activation is not None:
        rst = self.activation(rst)
    # normalization
    if self.norm is not None:
        rst = self.norm(rst)
    return rst
def forward(self, g, h, e):
    h_in = h  # for residual connection
    if not self.dgl_builtin:
        h = self.dropout(h)
        g.ndata['h'] = h
        # g.update_all(fn.copy_src(src='h', out='m'),
        #              self.aggregator,
        #              self.nodeapply)
        if self.aggregator_type == 'maxpool':
            g.ndata['h'] = self.aggregator.linear(g.ndata['h'])
            g.ndata['h'] = self.aggregator.activation(g.ndata['h'])
            g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'c'),
                         self.nodeapply)
        elif self.aggregator_type == 'lstm':
            g.update_all(fn.copy_src(src='h', out='m'), self.aggregator,
                         self.nodeapply)
        else:
            g.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'c'),
                         self.nodeapply)
        h = g.ndata['h']
    else:
        # For original graphs
        # h = self.sageconv(g, h)
        # For reduced graphs
        h = self.sageconv(g, h, edge_weight=e)
    if self.batch_norm:
        h = self.batchnorm_h(h)
    if self.residual:
        h = h_in + h  # residual connection
    return h
def sample_frontier(self, block_id, g, seed_nodes):
    # number of neighbors to sample for each GNN layer, starting from the first
    fanout = self.fanouts[block_id] if self.fanouts is not None else None
    g = dgl.in_subgraph(g, seed_nodes)
    # only keep edges that happened before the current timestamp
    g.remove_edges(torch.where(g.edata['timestamp'] > self.ts)[0])
    if self.args.valid_path:
        if block_id != self.args.n_layer - 1:
            g.dstdata['sample_time'] = self.frontiers[block_id + 1].srcdata['sample_time']
        g.apply_edges(self.sample_prob)
        g.remove_edges(torch.where(g.edata['timespan'] < 0)[0])
        g_re = dgl.reverse(g, copy_edata=True, copy_ndata=True)
        g_re.update_all(self.sample_time, fn.max('st', 'sample_time'))
        g = dgl.reverse(g_re, copy_edata=True, copy_ndata=True)
    if fanout is None:
        frontier = g
    else:
        if block_id == self.args.n_layer - 1:
            if self.args.bandit:
                frontier = dgl.sampling.sample_neighbors(
                    g, seed_nodes, fanout, prob='q_ij')
            else:
                frontier = dgl.sampling.sample_neighbors(g, seed_nodes, fanout)
        else:
            frontier = dgl.sampling.sample_neighbors(g, seed_nodes, fanout)
    self.frontiers[block_id] = frontier
    return frontier
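A standalone sketch of the fanout-sampling step above, with a fixed fanout and without the temporal filtering (graph size and fanout are illustrative):

import dgl
import torch

g = dgl.rand_graph(100, 1000)
seed_nodes = torch.tensor([0, 1, 2])
frontier = dgl.sampling.sample_neighbors(g, seed_nodes, fanout=5)
print(frontier.num_edges())  # at most 5 sampled in-edges per seed node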
def agg(self, x, B):
    h = x
    x = self.dropout(x)
    for i in range(self.K):
        if i == 0:
            if self.aggregator == 'pool':
                x = torch.matmul(x, self.weight_pool_in)
                if self.bias:
                    x = x + self.bias_in
            if self.aggregator == 'gcn':
                B[i].srcdata['h'] = torch.matmul(x, self.weight_gcn_in)
            else:
                B[i].srcdata['h'] = x
                B[i].dstdata['h'] = x[:B[i].number_of_dst_nodes()]
        else:
            if self.aggregator == 'pool':
                hh = torch.matmul(B[i - 1].dstdata['h'], self.weight_pool_hid)
                if self.bias:
                    hh = hh + self.bias_hid
            else:
                hh = B[i - 1].dstdata['h']
            if self.aggregator == 'gcn':
                B[i].srcdata['h'] = torch.matmul(hh, self.weight_gcn_hid)
            else:
                B[i].srcdata['h'] = hh
                B[i].dstdata['h'] = hh[:B[i].number_of_dst_nodes()]
        if self.aggregator == 'gcn':
            B[i].update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
        elif self.aggregator == 'mean':
            B[i].update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
        elif self.aggregator == 'lstm':
            B[i].update_all(fn.copy_src('h', 'm'),
                            self.lstm_reducer_in if i == 0 else self.lstm_reducer_hid)
        else:
            B[i].update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
        h_neigh = B[i].dstdata['neigh']
        if i == 0:
            h = torch.matmul(B[i].dstdata['h'], self.weight_in[0, :, :]) \
                + (torch.matmul(h_neigh, self.weight_in[1, :, :])
                   if self.aggregator != 'gcn' else 0)
            if self.bias:
                h = h + self.bias_in_k[0, :] \
                    + (self.bias_in_k[1, :] if self.aggregator != 'gcn' else 0)
        elif i == self.K - 1:
            h = torch.matmul(B[i].dstdata['h'], self.weight_out[0, :, :]) \
                + (torch.matmul(h_neigh, self.weight_out[1, :, :])
                   if self.aggregator != 'gcn' else 0)
            if self.bias:
                h = h + self.bias_out_k[0, :] \
                    + (self.bias_out_k[1, :] if self.aggregator != 'gcn' else 0)
        else:
            h = torch.matmul(B[i].dstdata['h'], self.weight_hid[i - 1, 0, :, :]) \
                + (torch.matmul(h_neigh, self.weight_hid[i - 1, 1, :, :])
                   if self.aggregator != 'gcn' else 0)
            if self.bias:
                h = h + self.bias_hid_k[0, :] \
                    + (self.bias_hid_k[1, :] if self.aggregator != 'gcn' else 0)
        if self.activation and i != self.K - 1:
            h = self.activation(h, inplace=False)
        if i != self.K - 1:
            h = self.dropout(h)
        if self.norm:
            norm = torch.norm(h, dim=1)
            norm = norm + (norm == 0).long()
            h = h / norm.unsqueeze(-1)
        B[i].dstdata['h'] = h
    return h
def fit(self, train_labels, train_mask):
    """Trains the model.

    Parameters
    ----------
    train_labels : torch.LongTensor
        Tensor of target data of size n_train_nodes.
    train_mask : torch.ByteTensor
        Boolean mask of size n_nodes indicating the nodes used in training.
    """
    # Add initial node labels
    if train_labels.is_cuda:
        init_labels = torch.cuda.FloatTensor(self.graph.number_of_nodes()).fill_(0)
    else:
        init_labels = torch.zeros(self.graph.number_of_nodes(), dtype=torch.float)
    init_labels[train_mask] = train_labels.float()
    self.graph.ndata["l"] = init_labels
    # Propagate
    self.graph.update_all(
        message_func=fn.copy_src(src="l", out="m"),
        reduce_func=fn.max(msg="m", out="l"),
    )
    # Put back positive seed nodes
    self.graph.ndata["l"] = torch.max(self.graph.ndata["l"], init_labels)
    self.predictions = self.graph.ndata["l"]
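A standalone sketch of the single propagation step used in fit(): every node takes the max label among its in-neighbors, and seed labels are then restored (toy chain graph; values are illustrative):

import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 3]))   # a small chain 0->1->2->3
labels = torch.tensor([1.0, 0.0, 0.0, 0.0])
g.ndata['l'] = labels
g.update_all(fn.copy_u('l', 'm'), fn.max('m', 'l'))
g.ndata['l'] = torch.max(g.ndata['l'], labels)  # keep seed labels
print(g.ndata['l'])  # tensor([1., 1., 0., 0.]), zeros for unreached nodes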
def collate(self, items):
    '''
    items: edge ids in graph g. We sample iteratively k times and batch
    the results into one single subgraph.
    '''
    # only sample edges before the current timestamp
    current_ts = self.g.edata['timestamp'][items[0]]
    # pass the current timestamp to the graph sampler
    self.graph_sampler.ts = current_ts
    # For link prediction, we use a negative_sampler to generate a
    # negative graph for the loss computation.
    if self.negative_sampler is None:
        neg_pair_graph = None
        input_nodes, pair_graph, blocks = self._collate(items)
    else:
        input_nodes, pair_graph, neg_pair_graph, blocks = \
            self._collate_with_negative_sampling(items)
    # sample the k-hop subgraph and batch the frontiers into one graph
    for i in range(self.n_layer - 1):
        self.graph_sampler.frontiers[0].add_edges(
            *self.graph_sampler.frontiers[i + 1].edges())
    frontier = self.graph_sampler.frontiers[0]
    # compute each node's last-update timestamp
    frontier.update_all(fn.copy_e('timestamp', 'ts'), fn.max('ts', 'timestamp'))
    return input_nodes, pair_graph, neg_pair_graph, [frontier]
def track_time(graph_name, format, feat_size, msg_type, reduce_type):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph_name, format)
    graph = graph.to(device)
    graph.ndata['h'] = torch.randn((graph.num_nodes(), feat_size),
                                   device=device)
    graph.edata['e'] = torch.randn((graph.num_edges(), 1), device=device)

    msg_builtin_dict = {
        'copy_u': fn.copy_u('h', 'x'),
        'u_mul_e': fn.u_mul_e('h', 'e', 'x'),
    }
    reduce_builtin_dict = {
        'sum': fn.sum('x', 'h_new'),
        'mean': fn.mean('x', 'h_new'),
        'max': fn.max('x', 'h_new'),
    }

    # dry run
    graph.update_all(msg_builtin_dict[msg_type],
                     reduce_builtin_dict[reduce_type])

    # timing
    with utils.Timer() as t:
        for i in range(3):
            graph.update_all(msg_builtin_dict[msg_type],
                             reduce_builtin_dict[reduce_type])

    return t.elapsed_secs / 3
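A self-contained variant of the same measurement, assuming only DGL and PyTorch (the `utils` helpers above are benchmark-suite specific); the dry run serves as a warm-up so kernel setup is not timed:

import time
import dgl
import dgl.function as fn
import torch

g = dgl.rand_graph(1000, 10000)
g.ndata['h'] = torch.randn(g.num_nodes(), 32)

g.update_all(fn.copy_u('h', 'x'), fn.max('x', 'h_new'))  # warm-up run

start = time.perf_counter()
for _ in range(3):
    g.update_all(fn.copy_u('h', 'x'), fn.max('x', 'h_new'))
print((time.perf_counter() - start) / 3, 'seconds per update_all')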
def forward(self, g, h):
    g = g.local_var()
    if not self.use_pp or not self.training:
        norm = self.get_norm(g)
        # g.ndata['h'] = h
        # g.update_all(fn.copy_src(src='h', out='m'),
        #              fn.sum(msg='m', out='h'))
        # ah = g.ndata.pop('h')
        if self._aggre_type == 'mean':
            g.ndata['h'] = h
            g.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'h'))
            ah = g.ndata.pop('h')
        elif self._aggre_type == 'gcn':
            g.ndata['h'] = h
            g.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
            # divide by in_degrees
            # degs = graph.in_degrees().float()
            # degs = degs.to(feat.device)
            # h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.unsqueeze(-1) + 1)
            ah = g.ndata.pop('h')
            ah = ah * norm
        elif self._aggre_type == 'pool':
            g.ndata['h'] = F.relu(self.fc_pool(h))
            g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'h'))
            ah = g.ndata['h']
        elif self._aggre_type == 'lstm':
            g.ndata['h'] = h
            g.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
            ah = g.ndata['h']
        elif self._aggre_type == 'attn':
            feat = self.fc_attn(h).view(-1, self.num_heads, self._in_feats)
            el = (feat * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat * self.attn_r).sum(dim=-1).unsqueeze(-1)
            g.ndata.update({'ft': feat, 'el': el, 'er': er})
            g.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(g.edata.pop('e'))
            g.edata['a'] = edge_softmax(g, e)
            g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
            ah = g.ndata['ft']
            ah = ah.squeeze(1)
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(
                self._aggre_type))
        h = self.concat(h, ah, norm)
    if self.dropout:
        h = self.dropout(h)
    # GraphSAGE GCN does not require fc_self.
    # if self._aggre_type == 'gcn':
    #     rst = self.fc_neigh(ah)
    # else:
    #     rst = self.fc_self(h) + self.fc_neigh(ah)
    h = self.linear(h)
    h = self.lynorm(h)
    if self.activation:
        h = self.activation(h)
    return h
def get_current_ts(pos_graph, neg_graph):
    with pos_graph.local_scope():
        pos_graph_ = dgl.add_reverse_edges(pos_graph, copy_edata=True)
        pos_graph_.update_all(fn.copy_e('timestamp', 'times'),
                              fn.max('times', 'ts'))
        current_ts = pos_ts = pos_graph_.ndata['ts']
        num_pos_nodes = pos_graph_.num_nodes()
    with neg_graph.local_scope():
        neg_graph_ = dgl.add_reverse_edges(neg_graph)
        neg_graph_.edata['timestamp'] = pos_graph_.edata['timestamp']
        neg_graph_.update_all(fn.copy_e('timestamp', 'times'),
                              fn.max('times', 'ts'))
        num_pos_nodes = torch.where(pos_graph_.ndata['ts'] > 0)[0].shape[0]
        pos_ts = pos_graph_.ndata['ts'][:num_pos_nodes]
        neg_ts = neg_graph_.ndata['ts'][num_pos_nodes:]
        current_ts = torch.cat([pos_ts, neg_ts])
    return current_ts, pos_ts, num_pos_nodes
def __init__(self, pool_type):
    super(GraphPooling, self).__init__()
    self.pool_type = pool_type
    if pool_type == 'mean':
        self.reduce_func = fn.mean(msg='m', out='h')
    elif pool_type == 'max':
        self.reduce_func = fn.max(msg='m', out='h')
    elif pool_type == 'min':
        self.reduce_func = fn.min(msg='m', out='h')
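The stored reduce_func is meant to pair with a copy message function at call time. A plausible companion forward, written as a self-contained sketch (the original forward is not shown, so this pairing is an assumption):

import dgl
import dgl.function as fn
import torch
import torch.nn as nn

class GraphPoolingSketch(nn.Module):
    """Hypothetical companion to the reducer selection above."""
    def __init__(self, pool_type):
        super().__init__()
        self.reduce_func = {'mean': fn.mean('m', 'h'),
                            'max': fn.max('m', 'h'),
                            'min': fn.min('m', 'h')}[pool_type]

    def forward(self, g, feat):
        with g.local_scope():
            g.ndata['h'] = feat
            # copy node features onto edges as messages 'm', then reduce
            g.update_all(fn.copy_u('h', 'm'), self.reduce_func)
            return g.ndata['h']

g = dgl.rand_graph(8, 30)
out = GraphPoolingSketch('max')(g, torch.randn(8, 4))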
def pool_agg(self, g):
    x = g.ndata['x']
    x = self.dropout(x)
    h = torch.matmul(x, self.weight_pool_in)
    if self.bias:
        h = h + self.bias_in
    g.srcdata['h'] = h
    for i in range(self.K):
        if i == 0:
            g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = g.dstdata['neigh']
            h = torch.matmul(torch.cat([g.srcdata['h'], h_neigh], dim=1),
                             self.weight_in)
            if self.activation:
                h = self.activation(h, inplace=False)
            norm = torch.norm(h, dim=1)
            h = h / (norm.unsqueeze(-1) + 0.05)
            g.srcdata['h'] = h
        elif i == self.K - 1:
            h = torch.matmul(g.srcdata['h'], self.weight_pool_hid)
            if self.bias:
                h = h + self.bias_hid
            g.srcdata['h'] = h
            g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = g.dstdata['neigh']
            h = torch.matmul(torch.cat([g.srcdata['h'], h_neigh], dim=1),
                             self.weight_out)
            norm = torch.norm(h, dim=1)
            h = h / (norm.unsqueeze(-1) + 0.05)
            g.ndata['z'] = h
        else:
            h = torch.matmul(g.srcdata['h'], self.weight_pool_hid)
            if self.bias:
                h = h + self.bias_hid
            g.srcdata['h'] = h
            g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = g.dstdata['neigh']
            h = torch.matmul(torch.cat([g.srcdata['h'], h_neigh], dim=1),
                             self.weight_hid[i - 1, :, :])
            if self.activation:
                h = self.activation(h, inplace=False)
            norm = torch.norm(h, dim=1)
            h = h / (norm.unsqueeze(-1) + 0.05)
            g.srcdata['h'] = h
    return g
def edge_softmax(self, g):
    # compute the max
    g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max'))
    # minus the max and exp
    g.apply_edges(lambda edges: {'a': torch.exp(edges.data['a']
                                                - edges.dst['a_max'])})
    # compute dropout
    g.apply_edges(lambda edges: {'a_drop': self.attn_drop(edges.data['a'])})
    # compute normalizer
    g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z'))
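Subtracting the per-destination max before exponentiating keeps the softmax numerically stable. A standalone sketch of the same steps, including the final division that the method above leaves to its caller (an assumption about how 'a' and 'z' are later combined):

import dgl
import dgl.function as fn
import torch

g = dgl.rand_graph(5, 15)
g.edata['a'] = torch.randn(g.num_edges(), 1)
g.update_all(fn.copy_e('a', 'm'), fn.max('m', 'a_max'))      # per-dst max
g.apply_edges(lambda e: {'a': torch.exp(e.data['a'] - e.dst['a_max'])})
g.update_all(fn.copy_e('a', 'm'), fn.sum('m', 'z'))          # normalizer
g.apply_edges(lambda e: {'a': e.data['a'] / e.dst['z']})     # softmax
print(g.edata['a'].sum())  # ~= number of destination nodes with in-edges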
def _test(fld):
    def message_func(edges):
        return {'m': edges.src[fld]}

    def message_func_edge(edges):
        if len(edges.src[fld].shape) == 1:
            return {'m': edges.src[fld] * edges.data['e1']}
        else:
            return {'m': edges.src[fld] * edges.data['e2']}

    def reduce_func(nodes):
        return {fld: mx.nd.max(nodes.mailbox['m'], axis=1)}

    def apply_func(nodes):
        return {fld: 2 * nodes.data[fld]}

    g = simple_graph()

    # update all
    v1 = g.ndata[fld]
    g.update_all(fn.copy_src(src=fld, out='m'),
                 fn.max(msg='m', out=fld), apply_func)
    v2 = g.ndata[fld]
    g.set_n_repr({fld: v1})
    g.update_all(message_func, reduce_func, apply_func)
    v3 = g.ndata[fld]
    assert np.allclose(v2.asnumpy(), v3.asnumpy(), rtol=1e-05, atol=1e-05)

    # update all with edge weights
    v1 = g.ndata[fld]
    g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
                 fn.max(msg='m', out=fld), apply_func)
    v2 = g.ndata[fld]
    g.set_n_repr({fld: v1})
    g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
                 fn.max(msg='m', out=fld), apply_func)
    v3 = g.ndata[fld].squeeze()
    g.set_n_repr({fld: v1})
    g.update_all(message_func_edge, reduce_func, apply_func)
    v4 = g.ndata[fld]
    assert np.allclose(v2.asnumpy(), v3.asnumpy(), rtol=1e-05, atol=1e-05)
    assert np.allclose(v3.asnumpy(), v4.asnumpy(), rtol=1e-05, atol=1e-05)
def forward(self, graph, x):
    # sanity check: warn if any linear layer holds NaN weights
    num_nan = 0
    for item in self.fc_msg:
        if isinstance(item, nn.Linear):
            num_nan += item.weight.isnan().sum()
    for item in self.fc_udt:
        if isinstance(item, nn.Linear):
            num_nan += item.weight.isnan().sum()
    if num_nan > 0:
        print("NaN found in model parameters.")

    graph.ndata['in_feats'] = x
    graph.update_all(self.message, fn.max('m', 'r'))
    return self.fc_udt(
        torch.cat([graph.dstdata['in_feats'], graph.dstdata['r']], dim=1))
def forward(self, g, x):
    with g.local_scope():
        g.ndata['x'] = x
        g.ndata['z'] = self.gate_m(x)
        g.update_all(fn.copy_u('x', 'x'), fn.mean('x', 'mean_z'))
        g.update_all(fn.copy_u('z', 'z'), fn.max('z', 'max_z'))
        nft = torch.cat([g.ndata['x'], g.ndata['max_z'], g.ndata['mean_z']],
                        dim=1)
        gate = self.gate_fn(nft).sigmoid()
        attn_out = self.gatlayer(g, x)
        node_num = g.num_nodes()
        gated_out = ((gate.view(-1) * attn_out.view(-1, self.out_feats).T).T).view(
            node_num, self.num_heads, self.out_feats)
        gated_out = gated_out.mean(1)
        merge = self.merger_layer(torch.cat([x, gated_out], dim=1))
        return merge
def forward(self, graph, feat):
    graph = graph.local_var()
    if isinstance(feat, tuple):
        feat_src, feat_dst = feat
    else:
        feat_src = feat_dst = feat
    h_self = feat_dst

    # DIN attention: take the two vectors, their difference, and their
    # product; map each through an MLP to n_hidden, sum the results,
    # then map the sum to a scalar with a second MLP.
    ## compute the difference and the product of the two vectors
    graph.srcdata.update({'e_src': feat_src})
    graph.dstdata.update({'e_dst': feat_dst})
    graph.apply_edges(fn.u_sub_v('e_src', 'e_dst', 'e_sub'))
    graph.apply_edges(fn.u_mul_v('e_src', 'e_dst', 'e_mul'))
    ## apply the per-part MLPs
    graph.srcdata["e_src"] = self.atten_src(feat_src)
    graph.dstdata["e_dst"] = self.atten_dst(feat_dst)
    graph.edata["e_sub"] = self.atten_sub(graph.edata["e_sub"])
    graph.edata["e_mul"] = self.atten_mul(graph.edata["e_mul"])
    ## "MLP then add" replaces "concat then MLP"
    graph.edata["e"] = graph.edata.pop("e_sub") + graph.edata.pop("e_mul")
    graph.apply_edges(fn.e_add_u('e', 'e_src', 'e'))
    graph.apply_edges(fn.e_add_v('e', 'e_dst', 'e'))
    graph.srcdata.pop("e_src")
    graph.dstdata.pop("e_dst")
    ## first-layer activation
    graph.edata["e"] = F.gelu(graph.edata["e"])
    ## second MLP maps the score to a scalar
    graph.edata["e"] = self.leaky_relu(self.atten_out(graph.edata["e"]))

    # max pool
    graph.srcdata['h'] = F.gelu(self.fc_pool(feat_src))
    graph.apply_edges(fn.e_mul_u('e', 'h', 'h'))
    graph.update_all(fn.copy_e('h', 'm'), fn.max('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']

    # mean pool
    graph.srcdata['h'] = F.gelu(self.fc_pool2(feat_src))
    graph.apply_edges(fn.e_mul_u('e', 'h', 'h'))
    graph.update_all(fn.copy_e('h', 'm'), fn.mean('m', 'neigh'))
    h_neigh2 = graph.dstdata['neigh']

    # combine self, max-pooled, and mean-pooled terms
    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh) + self.fc_neigh2(h_neigh2)
    # output MLPs with residual connections
    if len(self.out_mlp) > 0:
        for layer in self.out_mlp:
            o = layer(F.gelu(rst))
            rst = rst + o
    return rst
def forward(self, nf, logits):
    r"""Compute edge softmax.

    Parameters
    ----------
    nf : NodeFlow
    logits : torch.Tensor
        The input edge feature.

    Returns
    -------
    Unnormalized scores : torch.Tensor
        This part gives :math:`\exp(z_{ij})`'s
    Normalizer : torch.Tensor
        This part gives :math:`\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})`

    Notes
    -----
    * Input shape: :math:`(N, *, 1)` where * means any number of
      additional dimensions, :math:`N` is the number of edges.
    * Unnormalized scores shape: :math:`(N, *, 1)` where all but the
      last dimension are the same shape as the input.
    * Normalizer shape: :math:`(M, *, 1)` where :math:`M` is the number
      of nodes and all but the first and the last dimensions are the
      same as the input.
    """
    self._logits_name = get_edata_name(nf, self.index, self._logits_name)
    self._max_logits_name = get_ndata_name(nf, self.index + 1,
                                           self._max_logits_name)
    self._normalizer_name = get_ndata_name(nf, self.index + 1,
                                           self._normalizer_name)
    nf.blocks[self.index].data[self._logits_name] = logits

    # compute the max (for numerical stability)
    nf.block_compute(self.index,
                     fn.copy_edge(self._logits_name, self._logits_name),
                     fn.max(self._logits_name, self._max_logits_name))
    # minus the max and exp
    nf.apply_block(self.index, lambda edges: {
        self._logits_name: torch.exp(edges.data[self._logits_name]
                                     - edges.dst[self._max_logits_name])})
    # pop out the temporary feature _max_logits, otherwise
    # get_ndata_name could have huge overhead
    nf.layers[self.index + 1].data.pop(self._max_logits_name)
    # compute the normalizer
    nf.block_compute(self.index,
                     fn.copy_edge(self._logits_name, self._logits_name),
                     fn.sum(self._logits_name, self._normalizer_name))
    return nf.blocks[self.index].data.pop(self._logits_name), \
        nf.layers[self.index + 1].data.pop(self._normalizer_name)
def forward(self, g, feat):
    with g.local_scope():
        if self.aggre_type == 'attention':
            h_src = self.feat_drop(feat[0]).view(-1, self.num_heads, self.in_size)
            h_dst = self.feat_drop(feat[1]).view(-1, self.num_heads, self.in_size)
            el = (h_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            # er = (h_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            g.srcdata.update({'ft': h_src, 'el': el})
            # g.srcdata.update({'ft': h_src, 'er': er})
            g.apply_edges(fn.copy_u('el', 'e'))
            # g.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(g.edata.pop('e'))
            g.edata['a'] = self.attn_drop(edge_softmax(g, e))
            g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
            rst = g.dstdata['ft'].flatten(1)
            if self.residual:
                rst = rst + h_dst.flatten(1)
            if self.activation:
                rst = self.activation(rst)
        elif self.aggre_type == 'mean':
            h_src = self.feat_drop(feat[0]).view(-1, self.in_size * self.num_heads)
            h_dst = self.feat_drop(feat[1]).view(-1, self.in_size * self.num_heads)
            g.srcdata['ft'] = h_src
            g.update_all(fn.copy_u('ft', 'm'), fn.mean('m', 'ft'))
            rst = g.dstdata['ft']  # + h_dst
        elif self.aggre_type == 'pool':
            h_src = self.feat_drop(feat[0]).view(-1, self.in_size * self.num_heads)
            h_dst = self.feat_drop(feat[1]).view(-1, self.in_size * self.num_heads)
            g.srcdata['ft'] = F.relu(self.fc_pool(h_src))
            g.update_all(fn.copy_u('ft', 'm'), fn.max('m', 'ft'))
            rst = g.dstdata['ft']  # + h_dst
        return rst
def forward(self, graph: dgl.DGLGraph,
            feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Args:
        graph: the graph
        feats: node features with node type as key and the corresponding
            features as value. Each tensor is of shape (N, D) where N is
            the number of nodes of the corresponding node type, and D is
            the feature size.

    Returns:
        updated node features. Each tensor is of shape (N, D) where N is
        the number of nodes of the corresponding node type, and D is the
        feature size.
    """
    graph = graph.local_var()

    # assign data
    for nt, ft in feats.items():
        graph.nodes[nt].data.update({"ft": ft})

    for et in self.etypes:
        # option 1
        graph[et].update_all(fn.copy_u("ft", "m"), fn.mean("m", "mean"), etype=et)
        graph[et].update_all(fn.copy_u("ft", "m"), fn.max("m", "max"), etype=et)
        nt = et[2]
        graph.apply_nodes(self._concatenate_node_feat, ntype=nt)
        # copy the updated feature from new_ft to ft
        graph.nodes[nt].data.update({"ft": graph.nodes[nt].data["new_ft"]})

    return {nt: graph.nodes[nt].data["ft"] for nt in feats}
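A toy heterograph illustrating the per-etype mean/max aggregation above, with the concatenation apply-nodes step inlined for brevity (node types and sizes are illustrative):

import dgl
import dgl.function as fn
import torch

g = dgl.heterograph({('user', 'rates', 'item'): ([0, 0, 1], [0, 1, 1])})
g.nodes['user'].data['ft'] = torch.randn(2, 4)
g.nodes['item'].data['ft'] = torch.randn(2, 4)

et = ('user', 'rates', 'item')
g.update_all(fn.copy_u('ft', 'm'), fn.mean('m', 'mean'), etype=et)
g.update_all(fn.copy_u('ft', 'm'), fn.max('m', 'max'), etype=et)
new_ft = torch.cat([g.nodes['item'].data['mean'],
                    g.nodes['item'].data['max']], dim=1)
print(new_ft.shape)  # (2, 8)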
def forward(self, g, h):
    h_in = h  # for residual connection
    if not self.dgl_builtin:
        h = self.dropout(h)
        g.ndata['h'] = h
        # g.update_all(fn.copy_src(src='h', out='m'),
        #              self.aggregator,
        #              self.nodeapply)
        if self.aggregator_type == 'maxpool':
            g.ndata['h'] = self.aggregator.linear(g.ndata['h'])
            g.ndata['h'] = self.aggregator.activation(g.ndata['h'])
            g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'c'),
                         self.nodeapply)
        elif self.aggregator_type == 'lstm':
            g.update_all(fn.copy_src(src='h', out='m'), self.aggregator,
                         self.nodeapply)
        elif self.aggregator_type == 'sumpool':
            P = torch.clamp(self.P, 1, 100)
            g.ndata['h_pow'] = torch.abs(g.ndata['h']).pow(P)
            g.update_all(fn.copy_src('h_pow', 'm'), fn.sum('m', 'c'),
                         self.nodeapply)
        else:
            g.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'c'),
                         self.nodeapply)
        h = g.ndata['h']
    else:
        h = self.sageconv(g, h)
    if self.batch_norm:
        h = self.batchnorm_h(h)
    if self.residual:
        h = h_in + h  # residual connection
    return h
def forward(self, nf):
    if self.preprocess:
        for i in range(nf.num_layers):
            h = nf.layers[i].data.pop('features')
            neigh = nf.layers[i].data.pop('neigh')
            if self.dropout:
                h = self.dropout(h)
            h = self.fc_self(h) + self.fc_neigh(neigh)
            skip_start = (0 == self.n_layers - 1)
            if skip_start:
                h = torch.cat((h, self.activation(h)), dim=1)
            else:
                h = self.activation(h)
            nf.layers[i].data['h'] = h
    else:
        for lid in range(nf.num_layers):
            nf.layers[lid].data['h'] = nf.layers[lid].data.pop('features')
    for lid, layer in enumerate(self.layers):
        for i in range(lid, nf.num_layers - 1):
            h = nf.layers[i].data.pop('h')
            h = self.dropout(h)
            nf.layers[i].data['h'] = h
            if self.aggregator_type == 'mean':
                nf.block_compute(i, fn.copy_src(src='h', out='m'),
                                 fn.mean('m', 'neigh'), layer)
            elif self.aggregator_type == 'gcn':
                nf.block_compute(i, fn.copy_src(src='h', out='m'),
                                 fn.sum('m', 'neigh'), layer)
            elif self.aggregator_type == 'pool':
                nf.block_compute(i, fn.copy_src(src='h', out='m'),
                                 fn.max('m', 'neigh'), layer)
            elif self.aggregator_type == 'lstm':
                reducer = self.reducer[i]

                def _reducer(self, nodes):
                    m = nodes.mailbox['m']  # (B, L, D)
                    batch_size = m.shape[0]
                    h = (m.new_zeros((1, batch_size, self._in_feats)),
                         m.new_zeros((1, batch_size, self._in_feats)))
                    _, (rst, _) = reducer(m, h)
                    return {'neigh': rst.squeeze(0)}

                nf.block_compute(i, fn.copy_src(src='h', out='m'),
                                 _reducer, layer)
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(
                    self.aggregator_type))
        # set up new feat
        for i in range(lid + 1, nf.num_layers):
            h = nf.layers[i].data.pop('activation')
            nf.layers[i].data['h'] = h
    h = nf.layers[nf.num_layers - 1].data.pop('h')
    return h
def forward(self, graph, feat):
    r"""

    Description
    -----------
    Compute GraphSAGE layer.

    Parameters
    ----------
    graph : DGLGraph
        The graph.
    feat : torch.Tensor or pair of torch.Tensor
        If a torch.Tensor is given, it represents the input feature of shape
        :math:`(N, D_{in})` where :math:`D_{in}` is size of input feature,
        :math:`N` is the number of nodes. If a pair of torch.Tensor is
        given, the pair must contain two tensors of shape
        :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.

    Returns
    -------
    torch.Tensor
        The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
        is size of output feature.
    """
    with graph.local_scope():
        if isinstance(feat, tuple):
            feat_src = self.feat_drop(feat[0])
            feat_dst = self.feat_drop(feat[1])
        else:
            feat_src = feat_dst = self.feat_drop(feat)
            if graph.is_block:
                feat_dst = feat_src[:graph.number_of_dst_nodes()]
        h_self = feat_dst

        # Handle the case of graphs without edges
        if graph.number_of_edges() == 0:
            graph.dstdata['neigh'] = torch.zeros(
                feat_dst.shape[0], self._in_src_feats).to(feat_dst)

        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst  # same as above if homogeneous
            graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
            # divide by in_degrees
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (
                degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'lstm':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'ginmean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_src('h', 'm'),
                             self._gin_reducer('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'cheb':
            def unnLaplacian(feat, D_invsqrt_left, D_invsqrt_right, graph):
                """Compute Feat * D^-1/2 A D^-1/2; written as a matrix
                product this is D^-1/2 A D^-1/2 Feat."""
                # tmp = torch.zeros((D_invsqrt.shape[0], D_invsqrt.shape[0])).to(graph.device)
                # sparse tensors have no broadcasting; this also relies on the
                # src nodes being laid out contiguously from 0 in feat
                # print("adj : ", graph.adj(transpose=False, ctx=graph.device).shape)
                # graph.srcdata['h'] = (torch.mm((graph.adj(transpose=False, ctx=graph.device)), (feat * D_invsqrt))) * D_invsqrt[::graph.number_of_dst_nodes()]
                # graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
                # return graph.srcdata['h']
                graph.srcdata['h'] = feat * D_invsqrt_right  # feat is srcfeat
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                return graph.dstdata.pop('h') * D_invsqrt_left

            D_invsqrt_right = torch.pow(
                graph.out_degrees().float().clamp(min=1), -0.5).unsqueeze(-1)
            D_invsqrt_left = torch.pow(
                graph.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1)
            # print("D_invsqrt shape: ", D_invsqrt.shape)
            # graph.srcdata['h'] = feat_src
            # graph.dstdata['h'] = feat_dst
            # g = dgl.to_homogeneous(graph, ndata=['h'])
            # dgl._ffi.base.DGLError: Expect number of features to match number of nodes (len(u)). Got 70 and 76 instead.
            # print(g)
            # The block differs every time, so it is safe to call dgl's
            # method each time instead of precomputing lambda_max.
            try:
                lambda_max = laplacian_lambda_max(graph)
            except BaseException:
                # if the largest eigenvalue is not found
                dgl_warning(
                    "Largest eigenvalue not found, using default value 2 for lambda_max",
                    RuntimeWarning)
                lambda_max = torch.tensor(2)  # .to(feat.device)
            if isinstance(lambda_max, list):
                lambda_max = torch.tensor(lambda_max)  # .to(feat.device)
            if lambda_max.dim() == 1:
                lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)
            # broadcast from (B, 1) to (N, 1)
            # lambda_max = lambda_max * torch.ones((feat.shape[0], 1))
            # re_norm = (2 / lambda_max) * torch.ones((graph.number_of_dst_nodes(), 1)).to(graph.device)
            re_norm = (2 / lambda_max.to(graph.device)) * torch.ones(
                (graph.number_of_dst_nodes(), 1), device=graph.device)
            self._cheb_Xt = X_0 = feat_dst
            graph.srcdata['h'] = feat_src * D_invsqrt_right  # feat is srcfeat
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            X_1 = -re_norm * graph.dstdata['h'] * D_invsqrt_left + X_0 * (
                re_norm - 1)
            self._cheb_Xt = torch.cat((self._cheb_Xt, X_1.float()), 1)
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(
                self._aggre_type))

        # GraphSAGE GCN does not require fc_self.
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        elif self._aggre_type == 'ginmean':
            rst = (1 + self.eps) * h_self + h_neigh
            rst = self.fc_gin(rst)
            if self.norm is not None:
                rst = self.norm(rst)
            return rst
        elif self._aggre_type == 'cheb':
            rst = self._cheb_linear(self._cheb_Xt)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        # activation
        if self.activation is not None:
            rst = self.activation(rst)
        # normalization
        if self.norm is not None:
            rst = self.norm(rst)
        return rst
# Graph Conv and Relational Graph Conv
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from allennlp.common import FromParams

gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce_sum = fn.sum(msg='m', out='h')
gcn_reduce_max = fn.max(msg='m', out='h')
# gcn_reduce_u_mul_v = fn.u_mul_v('m', 'h')

from depricated.archival_gnns import GraphEncoder


class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h': h}


class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        # The body was truncated in the source; the lines below follow the
        # standard DGL GCN pattern for this module layout.
        g.ndata['h'] = feature
        g.update_all(gcn_msg, gcn_reduce_sum)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')
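A minimal usage sketch for the GCN layer above, under the assumption that its forward follows the standard pattern reconstructed in the comment (sizes and activation are illustrative):

import dgl
import torch
import torch.nn.functional as F

g = dgl.rand_graph(10, 40)
layer = GCN(in_feats=8, out_feats=4, activation=F.relu)
out = layer(g, torch.randn(10, 8))
print(out.shape)  # (10, 4)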
def forward(self, g, features):
    '''
    Inputs:
        g: The graph
        features: H^{l}, i.e. node features with shape
            [num_nodes, features_per_node]
    Returns:
        rst: H^{l+1}, i.e. node embeddings of the l+1 layer (depth) with
            shape [num_nodes, hidden_per_node]
    Variables:
        msg_func: Message function, i.e. what to aggregate
            (e.g. sending node embeddings)
        reduce_func: Reduce function, i.e. how to aggregate
            (e.g. summing neighbor embeddings)
    Notice: 'h' means the node feature/embedding itself,
        'm' means the node's mailbox
    '''
    # create an independent instance of the graph to manipulate
    g = g.local_var()

    # H^{k-1}_{v}
    h_self = features

    # calculate H^{k}_{N(v)} in line 4 of algorithm 1,
    # based on the different aggregators
    if self._aggre_type == 'mean':
        g.ndata['h'] = features
        msg_func = fn.copy_src('h', 'm')
        reduce_func = fn.mean('m', 'neigh')
        g.update_all(msg_func, reduce_func)
        # h_neigh is H^{k}_{N(v)}
        h_neigh = g.ndata.pop('neigh')
    elif self._aggre_type == 'gcn':
        # part of equation (2) in the paper
        g.ndata['h'] = features
        msg_func = fn.copy_src('h', 'm')
        reduce_func = fn.sum('m', 'neigh')
        g.update_all(msg_func, reduce_func)
        h_neigh = g.ndata.pop('neigh')
        # H^{k-1}_{v} U {H^{k-1}_{u}} in equation (2):
        # g.ndata.pop('neigh') represents {H^{k-1}_{u} for u in N(v)},
        # g.ndata['h'] represents {H^{k-1}_{v}}
        h_neigh = h_neigh + g.ndata.pop('h')
        # divide by in_degrees: the MEAN() operation in equation (2)
        degs = g.in_degrees().to(features)
        # Notice: h_neigh now also contains the self term,
        # so it is more than H^{k}_{N(v)}
        h_neigh = h_neigh / (degs.unsqueeze(-1) + 1)
    elif self._aggre_type == 'pool':
        g.ndata['h'] = F.relu(self.fc_pool(features))
        msg_func = fn.copy_src('h', 'm')
        reduce_func = fn.max('m', 'neigh')
        g.update_all(msg_func, reduce_func)
        # h_neigh is H^{k}_{N(v)}
        h_neigh = g.ndata.pop('neigh')
    else:
        raise KeyError('Aggregator type {} not recognized.'.format(
            self._aggre_type))

    # calculate H^{k}_{v} in line 5 of algorithm 1
    if self._aggre_type == 'gcn':
        rst = self.fc_neigh(h_neigh)
    else:
        rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
    # activation
    if self._activation_func is not None:
        rst = self._activation_func(rst)
    # normalization in line 7 of algorithm 1
    # l2_norm = torch.norm(rst, p=2, dim=1)
    # l2_norm = l2_norm.unsqueeze(1)
    # rst = torch.div(rst, l2_norm)
    return rst
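A quick standalone check of the mean-aggregation step used above: fn.copy_u plus fn.mean matches an explicit average over in-neighbors (toy graph; values chosen so the result is easy to verify by hand):

import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 2], [2, 2, 0]))
g.ndata['h'] = torch.arange(6.0).view(3, 2)
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
# node 2 averages nodes 0 and 1: ([0,1] + [2,3]) / 2 = [1,2]
print(g.ndata['neigh'][2])  # tensor([1., 2.])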
def test_update_all_multi_fallback():
    # create a graph with zero in-degree nodes
    g = dgl.DGLGraph()
    g.add_nodes(10)
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    g.ndata['h'] = th.randn(10, D)
    g.edata['w1'] = th.randn(16,)
    g.edata['w2'] = th.randn(16, D)

    def _mfunc_hxw1(edges):
        return {'m1': edges.src['h'] * th.unsqueeze(edges.data['w1'], 1)}

    def _mfunc_hxw2(edges):
        return {'m2': edges.src['h'] * edges.data['w2']}

    def _rfunc_m1(nodes):
        return {'o1': th.sum(nodes.mailbox['m1'], 1)}

    def _rfunc_m2(nodes):
        return {'o2': th.sum(nodes.mailbox['m2'], 1)}

    def _rfunc_m1max(nodes):
        return {'o3': th.max(nodes.mailbox['m1'], 1)[0]}

    def _afunc(nodes):
        ret = {}
        for k, v in nodes.data.items():
            if k.startswith('o'):
                ret[k] = 2 * v
        return ret

    # compute ground truth
    g.update_all(_mfunc_hxw1, _rfunc_m1, _afunc)
    o1 = g.ndata.pop('o1')
    g.update_all(_mfunc_hxw2, _rfunc_m2, _afunc)
    o2 = g.ndata.pop('o2')
    g.update_all(_mfunc_hxw1, _rfunc_m1max, _afunc)
    o3 = g.ndata.pop('o3')

    # v2v spmv
    g.update_all(fn.src_mul_edge(src='h', edge='w1', out='m1'),
                 fn.sum(msg='m1', out='o1'), _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    # v2v fallback to e2v
    g.update_all(fn.src_mul_edge(src='h', edge='w2', out='m2'),
                 fn.sum(msg='m2', out='o2'), _afunc)
    assert U.allclose(o2, g.ndata.pop('o2'))
    # v2v fallback to degree bucketing
    g.update_all(fn.src_mul_edge(src='h', edge='w1', out='m1'),
                 fn.max(msg='m1', out='o3'), _afunc)
    assert U.allclose(o3, g.ndata.pop('o3'))
    # multi builtins, both v2v spmv
    g.update_all([fn.src_mul_edge(src='h', edge='w1', out='m1'),
                  fn.src_mul_edge(src='h', edge='w1', out='m2')],
                 [fn.sum(msg='m1', out='o1'),
                  fn.sum(msg='m2', out='o2')],
                 _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    assert U.allclose(o1, g.ndata.pop('o2'))
    # multi builtins, one v2v spmv, one fallback to e2v
    g.update_all([fn.src_mul_edge(src='h', edge='w1', out='m1'),
                  fn.src_mul_edge(src='h', edge='w2', out='m2')],
                 [fn.sum(msg='m1', out='o1'),
                  fn.sum(msg='m2', out='o2')],
                 _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    assert U.allclose(o2, g.ndata.pop('o2'))
    # multi builtins, one v2v spmv, one fallback to e2v,
    # one fallback to degree bucketing
    g.update_all([fn.src_mul_edge(src='h', edge='w1', out='m1'),
                  fn.src_mul_edge(src='h', edge='w2', out='m2'),
                  fn.src_mul_edge(src='h', edge='w1', out='m3')],
                 [fn.sum(msg='m1', out='o1'),
                  fn.sum(msg='m2', out='o2'),
                  fn.max(msg='m3', out='o3')],
                 _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    assert U.allclose(o2, g.ndata.pop('o2'))
    assert U.allclose(o3, g.ndata.pop('o3'))
def forward(self, agg_graph: dgl.DGLGraph, prop_graph: dgl.DGLGraph,
            traversal_order, new_node_ids) -> torch.Tensor:
    tg = agg_graph.local_var()
    pg = prop_graph.local_var()

    nfeat = tg.ndata["nfeat"]
    # h_self = nfeat
    h_self = self.encode_time(nfeat, tg.ndata["timestamp"])
    tg.ndata["nfeat"] = h_self
    tg.edata["efeat"] = self.fc_edge(tg.edata["efeat"])
    # efeat = tg.edata["efeat"]
    # tg.apply_edges(lambda edges: {
    #     "efeat":
    #     torch.cat((edges.src["nfeat"], edges.data["efeat"]), dim=1)
    # })
    # tg.edata["efeat"] = self.encode_time(tg.edata["efeat"], tg.edata["timestamp"])
    degs = tg.ndata["degree"]

    # agg_graph aggregation
    if self._agg_type == "pool":
        tg.edata["efeat"] = F.relu(self.fc_pool(tg.edata["efeat"]))
        tg.update_all(fn.u_add_e("nfeat", "efeat", "m"), fn.max("m", "neigh"))
        h_neigh = tg.ndata["neigh"]
    elif self._agg_type in ["mean", "gcn", "lstm"]:
        tg.update_all(fn.u_add_e("nfeat", "efeat", "m"), fn.sum("m", "neigh"))
        h_neigh = tg.ndata["neigh"]
    else:
        raise KeyError("Aggregator type {} not recognized.".format(
            self._agg_type))
    pg.ndata["neigh"] = h_neigh

    # prop_graph propagation
    if False:
        if self._agg_type == "mean":
            pg.prop_nodes(traversal_order,
                          message_func=fn.copy_src("neigh", "tmp"),
                          reduce_func=fn.sum("tmp", "acc"))
            h_neigh = h_neigh + pg.ndata["acc"]
            h_neigh = h_neigh / degs.unsqueeze(-1)
        elif self._agg_type == "gcn":
            pg.prop_nodes(traversal_order,
                          message_func=fn.copy_src("neigh", "tmp"),
                          reduce_func=fn.sum("tmp", "acc"))
            h_neigh = h_neigh + pg.ndata["acc"]
            h_neigh = (h_self + h_neigh) / (degs.unsqueeze(-1) + 1)
        elif self._agg_type == "pool":
            pg.prop_nodes(traversal_order,
                          message_func=fn.copy_src("neigh", "tmp"),
                          reduce_func=fn.max("tmp", "acc"))
            h_neigh = torch.max(h_neigh, pg.ndata["acc"])
        elif self._agg_type == "lstm":
            h_neighs = [self._lstm_reducer(h_neigh[ids])
                        for ids in new_node_ids]
            h_neighs = torch.cat(h_neighs, dim=0)
            ridx = torch.arange(h_neighs.shape[0])
            ridx[np.concatenate(new_node_ids)] = torch.arange(
                h_neighs.shape[0])
            h_neigh = h_neighs[ridx]
    else:
        if self._agg_type == "mean":
            h_neighs = [torch.cumsum(h_neigh[ids], dim=0)
                        for ids in new_node_ids]
            h_neighs = torch.cat(h_neighs, dim=0)
            ridx = torch.arange(h_neighs.shape[0])
            ridx[np.concatenate(new_node_ids)] = torch.arange(
                h_neighs.shape[0])
            h_neigh = h_neighs[ridx]
            h_neigh = h_neigh / degs.unsqueeze(-1)
        elif self._agg_type == "gcn":
            h_neighs = [torch.cumsum(h_neigh[ids], dim=0)
                        for ids in new_node_ids]
            h_neighs = torch.cat(h_neighs, dim=0)
            ridx = torch.arange(h_neighs.shape[0])
            ridx[np.concatenate(new_node_ids)] = torch.arange(
                h_neighs.shape[0])
            h_neigh = h_neighs[ridx]
            h_neigh = (h_self + h_neigh) / (degs.unsqueeze(-1) + 1)
        elif self._agg_type == "pool":
            # torch.cummax returns a (values, indices) namedtuple;
            # only the values are needed here
            h_neighs = [torch.cummax(h_neigh[ids], dim=0).values
                        for ids in new_node_ids]
            h_neighs = torch.cat(h_neighs, dim=0)
            ridx = torch.arange(h_neighs.shape[0])
            ridx[np.concatenate(new_node_ids)] = torch.arange(
                h_neighs.shape[0])
            h_neigh = h_neighs[ridx]
        elif self._agg_type == "lstm":
            h_neighs = [self._lstm_reducer(h_neigh[ids])
                        for ids in new_node_ids]
            h_neighs = torch.cat(h_neighs, dim=0)
            ridx = torch.arange(h_neighs.shape[0])
            ridx[np.concatenate(new_node_ids)] = torch.arange(
                h_neighs.shape[0])
            h_neigh = h_neighs[ridx]

    if self._agg_type == "gcn":
        rst = self.fc_neigh(h_neigh)
    else:
        rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
    return rst
# Sends a message of node feature h
# Equivalent to => return {'m': edges.src['h']}
# msg = fn.copy_src(src='h', out='m')
#
# def reduce(nodes):
#     accum = torch.mean(nodes.mailbox['m'], 1)
#     return {'h': accum}

# def msg_func(edges):
#     return {'m': torch.mul(edges.data['feat'], edges.src['h'])}

msg_func = fn.u_mul_e('h', 'feat', 'm')
reduce_mean = fn.mean('m', 'h')
reduce_sum = fn.sum('m', 'h')
reduce_max = fn.max('m', 'h')

# def reduce(nodes):
#     accum = torch.sum(nodes.mailbox['m'], 1)
#     return {'h': accum}


class NodeApplyModule(nn.Module):
    # Update node feature h_v with (W h_v + b)
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, node):
        h = self.linear(node.data['h'])
        return {'h': h}
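A quick check that the builtin u_mul_e message matches the commented-out UDF above (edge feature times source feature), here under mean reduction (graph and feature sizes are illustrative):

import dgl
import dgl.function as fn
import torch

g = dgl.rand_graph(4, 10)
g.ndata['h'] = torch.randn(4, 3)
g.edata['feat'] = torch.randn(10, 3)

g.update_all(fn.u_mul_e('h', 'feat', 'm'), fn.mean('m', 'out_builtin'))
g.update_all(lambda e: {'m': e.src['h'] * e.data['feat']},
             fn.mean('m', 'out_udf'))
assert torch.allclose(g.ndata['out_builtin'], g.ndata['out_udf'])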
def forward(self, g, features):
    '''
    Inputs:
        g: The graph
        features: H^{l}, BLOCK.SRC and BLOCK.DST features as a tuple with
            shapes [N_{src}, D_{in_{src}}] and [N_{dst}, D_{in_{dst}}],
            where D_{in} is the size of the input feature
    Returns:
        rst: H^{l+1}, node embeddings of the l+1 layer (depth) with shape
            [N_{dst}, D_{out}]
    Variables:
        msg_func: Message function, i.e. what to aggregate
            (e.g. sending node embeddings)
        reduce_func: Reduce function, i.e. how to aggregate
            (e.g. summing neighbor embeddings)
    Notice: 'h' means the node feature/embedding itself,
        'm' means the node's mailbox
    '''
    # create an independent instance of the graph to manipulate
    g = g.local_var()

    # split (feature_src, feature_dst)
    feat_src = features[0]
    feat_dst = features[1]

    # H^{k-1}_{u}
    h_self = feat_dst

    # calculate H^{k}_{N(u)} in line 11 of algorithm 2 with different
    # aggregators: aggregate neighbor (block.src) information.
    # g.srcdata and g.dstdata are more convenient here; they would be
    # identical to g.ndata on a homogeneous graph.
    if self._aggre_type == 'mean':
        g.srcdata['h'] = feat_src
        msg_func = fn.copy_src('h', 'm')
        reduce_func = fn.mean('m', 'neigh')
        g.update_all(msg_func, reduce_func)
        # h_neigh is H^{k}_{N(u)}
        h_neigh = g.dstdata['neigh']
    elif self._aggre_type == 'gcn':
        # check whether feat_src and feat_dst have the same shape,
        # otherwise we can't sum them later
        dgl.utils.check_eq_shape(features)
        # part of equation (2) in the paper
        g.srcdata['h'] = feat_src
        g.dstdata['h'] = feat_dst
        msg_func = fn.copy_src('h', 'm')
        reduce_func = fn.sum('m', 'neigh')
        g.update_all(msg_func, reduce_func)
        h_neigh = g.dstdata['neigh']
        # H^{k-1}_{v} U {H^{k-1}_{u}} in equation (2):
        # g.dstdata['neigh'] holds BLOCK.DST with aggregation from SRC,
        # g.dstdata['h'] holds the original BLOCK.DST without aggregation
        h_neigh = h_neigh + g.dstdata['h']
        # divide by in_degrees: the MEAN() operation in equation (2)
        degs = g.in_degrees().to(feat_dst)
        # Notice: h_neigh now also contains the self term,
        # so it is more than H^{k}_{N(u)}
        h_neigh = h_neigh / (degs.unsqueeze(-1) + 1)
    elif self._aggre_type == 'pool':
        # equation (3) in the paper
        g.srcdata['h'] = self.relu(self.fc_pool(feat_src))
        msg_func = fn.copy_src('h', 'm')
        reduce_func = fn.max('m', 'neigh')
        g.update_all(msg_func, reduce_func)
        # h_neigh is H^{k}_{N(u)}
        h_neigh = g.dstdata['neigh']
    else:
        raise KeyError('Aggregator type {} not recognized.'.format(
            self._aggre_type))

    # calculate H^{k}_{v} in line 11 of algorithm 2
    # Notice: the GCN aggregator differs from the others, see equation (2)
    if self._aggre_type == 'gcn':
        rst = self.fc_neigh(h_neigh)
    else:
        # line 12 of algorithm 2
        rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
    # activation
    if self._activation_func is not None:
        rst = self._activation_func(rst)
    # normalization in line 13 of algorithm 2
    # l2_norm = torch.norm(rst, p=2, dim=1)
    # l2_norm = l2_norm.unsqueeze(1)
    # rst = torch.div(rst, l2_norm)
    return rst