def pass_messages(self, g):
    # Edge-wise products of endpoint features plus a custom edge reduction
    g.apply_edges(GF.u_mul_v('norm', 'norm', 'coef'))
    g.apply_edges(GF.u_mul_v('x', 'x', 'm2'))
    g.apply_edges(GF.copy_u('x', 'm1'))
    g.apply_edges(self.edge_sum)
    # Aggregate both edge messages into the node fields 'f1' and 'f2'
    g.update_all(GF.copy_e('m1', 'm1'), GF.sum('m1', 'f1'))
    g.update_all(GF.copy_e('m2', 'm2'), GF.sum('m2', 'f2'))
def check_flow_compute2(create_node_flow):
    num_layers = 2
    g = generate_rand_graph(100)
    g.edata['h'] = F.ones((g.number_of_edges(), 10))
    nf = create_node_flow(g, num_layers)
    nf.copy_from_parent()
    g.ndata['h'] = g.ndata['h1']
    nf.layers[0].data['h'] = nf.layers[0].data['h1']
    for i in range(num_layers):
        nf.block_compute(i, SrcMulEdgeMessageFunction('h', 'h', 't'), fn.sum('t', 'h1'))
        nf.block_compute(i, fn.src_mul_edge('h', 'h', 'h'), fn.sum('h', 'h'))
        g.update_all(fn.src_mul_edge('h', 'h', 'h'), fn.sum('h', 'h'))
        assert_allclose(F.asnumpy(nf.layers[i + 1].data['h1']),
                        F.asnumpy(nf.layers[i + 1].data['h']),
                        rtol=1e-4, atol=1e-4)
        assert_allclose(F.asnumpy(nf.layers[i + 1].data['h']),
                        F.asnumpy(g.nodes[nf.layer_parent_nid(i + 1)].data['h']),
                        rtol=1e-4, atol=1e-4)

    nf = create_node_flow(g, num_layers)
    g.ndata['h'] = g.ndata['h1']
    nf.copy_from_parent()
    for i in range(nf.num_layers):
        nf.layers[i].data['h'] = nf.layers[i].data['h1']
    for i in range(num_layers):
        nf.block_compute(i, fn.u_mul_v('h', 'h', 't'), fn.sum('t', 's'))
        g.update_all(fn.u_mul_v('h', 'h', 't'), fn.sum('t', 's'))
        assert_allclose(F.asnumpy(nf.layers[i + 1].data['s']),
                        F.asnumpy(g.nodes[nf.layer_parent_nid(i + 1)].data['s']),
                        rtol=1e-4, atol=1e-4)
def GRANDConv(graph, feats, order):
    '''
    Parameters
    ----------
    graph : dgl.DGLGraph
        The input graph
    feats : Tensor (n_nodes * feat_dim)
        Node features
    order : int
        Propagation steps
    '''
    with graph.local_scope():
        ''' Calculate the symmetrically normalized adjacency matrix \hat{A} '''
        degs = graph.in_degrees().float().clamp(min=1)
        norm = th.pow(degs, -0.5).to(feats.device).unsqueeze(1)
        graph.ndata['norm'] = norm
        graph.apply_edges(fn.u_mul_v('norm', 'norm', 'weight'))

        ''' Graph Conv '''
        x = feats
        y = 0 + feats
        for i in range(order):
            graph.ndata['h'] = x
            graph.update_all(fn.u_mul_e('h', 'weight', 'm'), fn.sum('m', 'h'))
            x = graph.ndata.pop('h')
            y.add_(x)

    return y / (order + 1)
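# A minimal usage sketch for GRANDConv (not from the original repo; the graph size,
# feature dimension, and propagation order below are illustrative assumptions).
# It builds a small random graph and checks that the propagated features keep
# the input shape, since the function averages `order` propagation steps plus the input.
import dgl
import torch as th

g = dgl.add_self_loop(dgl.rand_graph(8, 20))   # 8 nodes, 20 random edges, plus self-loops
feats = th.randn(g.num_nodes(), 16)            # hypothetical feature dimension
out = GRANDConv(g, feats, order=3)
assert out.shape == feats.shape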
def forward(self, graph):
    node_num = graph.ndata['h'].size(0)
    Q = self.query(graph.ndata['h'])
    K = self.key(graph.ndata['h'])
    V = self.value(graph.ndata['h'])
    Q = self.transpose_for_scores(Q)
    K = self.transpose_for_scores(K)
    V = self.transpose_for_scores(V)
    graph.ndata['Q'] = Q
    graph.ndata['K'] = K
    graph.ndata['V'] = V
    # Dot-product attention scores on edges, normalized with edge softmax
    graph.apply_edges(fn.u_mul_v('K', 'Q', 'attn_probs'))
    graph.edata['attn_probs'] = graph.edata['attn_probs'].sum(-1, keepdim=True)
    graph.edata['attn_probs'] = edge_softmax(graph, graph.edata['attn_probs'])
    graph.edata['attn_probs'] = self.dropout(graph.edata['attn_probs'])
    # Weight source values by attention and aggregate into destination nodes
    graph.apply_edges(fn.u_mul_e('V', 'attn_probs', 'attn_values'))
    graph.update_all(fn.copy_e('attn_values', 'm'), fn.sum('m', 'h'))
    graph.ndata['h'] = graph.ndata['h'].view([node_num, -1])
    return graph
def forward(self, g, node_mask):
    # collect features from source nodes and aggregate them in destination nodes
    g.update_all(fn.copy_src('nodes', 'message'), fn.sum('message', 'message_sum'))
    msg = g.ndata.pop('message_sum')
    nodes = self.update_GRU(msg, g.ndata['nodes'])
    g.apply_edges(fn.u_mul_v('nodes', 'nodes', 'edge_message'))
    edges = g.edata.pop('edge_spans') * g.edata.pop('edge_message').unsqueeze(-1)
    return nodes, edges
def calc_weight(g):
    """Compute the row-normalized D^(-1/2) A D^(-1/2)."""
    with g.local_scope():
        g.ndata['in_degree'] = g.in_degrees().float().pow(-0.5)
        g.ndata['out_degree'] = g.out_degrees().float().pow(-0.5)
        g.apply_edges(fn.u_mul_v('out_degree', 'in_degree', 'weight'))
        g.update_all(fn.copy_e('weight', 'msg'), fn.sum('msg', 'norm'))
        g.apply_edges(fn.e_div_v('weight', 'norm', 'weight'))
        return g.edata['weight']
def propagate_attention(self, g):
    ''' copied from gqp '''
    g.apply_edges(fn.u_mul_v('q', 'k', 'e'))
    e = g.edata['e'].sum(dim=-1, keepdim=True) / (self.dk ** 0.5)
    g.edata['e'] = self.attn_drop(edge_softmax(g, e))
    g.update_all(fn.u_mul_e('v', 'e', 'e'), fn.sum('e', 'v'))
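# A self-contained restatement of the same pattern (all names and sizes below are
# illustrative assumptions, not part of the original class): per-edge scaled
# dot-product scores, edge_softmax over each destination's incoming edges, then
# attention-weighted aggregation of the source values.
import dgl
import dgl.function as fn
import torch
from dgl.nn.functional import edge_softmax

g = dgl.rand_graph(5, 15)
dk = 8
for key in ('q', 'k', 'v'):
    g.ndata[key] = torch.randn(g.num_nodes(), dk)
g.apply_edges(fn.u_mul_v('q', 'k', 'e'))
scores = g.edata['e'].sum(dim=-1, keepdim=True) / dk ** 0.5
g.edata['e'] = edge_softmax(g, scores)              # normalize over each node's in-edges
g.update_all(fn.u_mul_e('v', 'e', 'm'), fn.sum('m', 'out'))
print(g.ndata['out'].shape)                          # (5, 8)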
def calc_weight(g):
    """ Compute row_normalized(D^(-1/2) A D^(-1/2)) """
    with g.local_scope():
        # compute D^(-1/2) * D^(-1/2), assuming A is the identity
        g.ndata["in_deg"] = g.in_degrees().float().pow(-0.5)
        g.ndata["out_deg"] = g.out_degrees().float().pow(-0.5)
        g.apply_edges(fn.u_mul_v("out_deg", "in_deg", "weight"))
        # row-normalize weight
        g.update_all(fn.copy_e("weight", "msg"), fn.sum("msg", "norm"))
        g.apply_edges(fn.e_div_v("weight", "norm", "weight"))
        return g.edata["weight"]
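# A hedged usage sketch (the graph, feature names, and sizes are assumptions): the
# edge weights returned by calc_weight can drive one weighted propagation step via
# u_mul_e followed by a sum reduction.
import dgl
import dgl.function as fn
import torch

g = dgl.add_self_loop(dgl.rand_graph(6, 12))        # self-loops keep degrees >= 1
g.edata["w"] = calc_weight(g).unsqueeze(-1)          # (E, 1) so it broadcasts over feature dims
g.ndata["x"] = torch.randn(g.num_nodes(), 4)
g.update_all(fn.u_mul_e("x", "w", "m"), fn.sum("m", "x_new"))
print(g.ndata["x_new"].shape)                        # (6, 4)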
def forward(self, graph, feat):
    graph = graph.local_var()
    if isinstance(feat, tuple):
        feat_src, feat_dst = feat
    else:
        feat_src = feat_dst = feat
    h_self = feat_dst

    # DIN attention: take the two endpoint vectors, their difference, and their
    # element-wise product, project each to n_hidden with its own MLP, sum the
    # results, then map to a single score.
    ## difference and element-wise product of the two vectors
    graph.srcdata.update({'e_src': feat_src})
    graph.dstdata.update({'e_dst': feat_dst})
    graph.apply_edges(fn.u_sub_v('e_src', 'e_dst', 'e_sub'))
    graph.apply_edges(fn.u_mul_v('e_src', 'e_dst', 'e_mul'))
    ## project each term with its own MLP
    graph.srcdata["e_src"] = self.atten_src(feat_src)
    graph.dstdata["e_dst"] = self.atten_dst(feat_dst)
    graph.edata["e_sub"] = self.atten_sub(graph.edata["e_sub"])
    graph.edata["e_mul"] = self.atten_mul(graph.edata["e_mul"])
    ## "project then add" in place of "concat then project"
    graph.edata["e"] = graph.edata.pop("e_sub") + graph.edata.pop("e_mul")
    graph.apply_edges(fn.e_add_u('e', 'e_src', 'e'))
    graph.apply_edges(fn.e_add_v('e', 'e_dst', 'e'))
    graph.srcdata.pop("e_src")
    graph.dstdata.pop("e_dst")
    ## first-layer activation
    graph.edata["e"] = F.gelu(graph.edata["e"])
    ## second MLP layer maps to a scalar score
    graph.edata["e"] = self.leaky_relu(self.atten_out(graph.edata["e"]))

    # max pool
    graph.srcdata['h'] = F.gelu(self.fc_pool(feat_src))
    graph.apply_edges(fn.e_mul_u('e', 'h', 'h'))
    graph.update_all(fn.copy_e('h', 'm'), fn.max('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']

    # mean pool
    graph.srcdata['h'] = F.gelu(self.fc_pool2(feat_src))
    graph.apply_edges(fn.e_mul_u('e', 'h', 'h'))
    graph.update_all(fn.copy_e('h', 'm'), fn.mean('m', 'neigh'))
    h_neigh2 = graph.dstdata['neigh']

    # combine self, max-pooled, and mean-pooled representations
    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh) + self.fc_neigh2(h_neigh2)

    # output MLPs with residual connections
    if len(self.out_mlp) > 0:
        for layer in self.out_mlp:
            o = layer(F.gelu(rst))
            rst = rst + o
    return rst
def forward(self, g, x, edge_attr):
    with g.local_scope():
        x = self.linear(x)
        edge_embedding = self.bond_encoder(edge_attr)

        # Molecular graphs are undirected
        # g.out_degrees() is the same as g.in_degrees()
        degs = (g.out_degrees().float() + 1).to(x.device)
        norm = torch.pow(degs, -0.5).unsqueeze(-1)  # (N, 1)
        g.ndata['norm'] = norm
        g.apply_edges(fn.u_mul_v('norm', 'norm', 'norm'))

        g.ndata['x'] = x
        g.apply_edges(fn.copy_u('x', 'm'))
        g.edata['m'] = g.edata['norm'] * F.relu(g.edata['m'] + edge_embedding)
        g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'new_x'))
        out = g.ndata['new_x'] + F.relu(x + self.root_emb.weight) * 1. / degs.view(-1, 1)

        return out
def forward(self, g, feature):
    g = g.local_var()
    g.ndata['v'] = self.V(feature).view(-1, self._num_heads, self._out_feats)
    g.ndata['q'] = self.Q(feature).view(-1, self._num_heads, self._out_feats)
    g.ndata['k'] = self.K(feature).view(-1, self._num_heads, self._out_feats)
    g.apply_edges(fn.u_mul_v('q', 'k', 'u'))  # 'u': (E, num_heads, out_feats); reduced to (E, num_heads, 1) below
    u = g.edata['u'].sum(-1, keepdim=True) * (self._out_feats) ** (-0.5)
    a = edge_softmax(g, u)
    g.edata['a'] = a
    g.update_all(fn.u_mul_e('v', 'a', 'm'), fn.sum('m', 'ft'))  # 'ft': (N, num_heads, out_feats), flattened below
    rst = self.scale(g.ndata['ft'].view(-1, self._out_feats * self._num_heads))
    if self.res_fc is not None:
        rst = rst.view(-1, self._num_heads, self._in_feats).sum(dim=1) + feature
    return rst
def forward(self, g, feat_dict):
    funcs = {}
    for srctype, etype, dsttype in g.canonical_etypes:
        g.nodes[dsttype].data['h'] = feat_dict[dsttype]  # nodes' original features
        g.nodes[srctype].data['h'] = feat_dict[srctype]
        g.nodes[srctype].data['t_h'] = self.W_T[etype](feat_dict[srctype])  # src nodes' transformed features
        # compute the attention numerator (exp)
        g.apply_edges(fn.u_mul_v('t_h', 'h', 'x'), etype=etype)
        g.edges[etype].data['x'] = torch.exp(self.W_A[etype](g.edges[etype].data['x']))
        # first update to compute the attention denominator (\sum exp)
        funcs[etype] = (fn.copy_e('x', 'm'), fn.sum('m', 'att'))
    g.multi_update_all(funcs, 'sum')

    funcs = {}
    for srctype, etype, dsttype in g.canonical_etypes:
        # compute attention weights (numerator / denominator)
        g.apply_edges(fn.e_div_v('x', 'att', 'att'), etype=etype)
        funcs[etype] = (fn.u_mul_e('h', 'att', 'm'), fn.sum('m', 'h'))  # \sum(h0 * att) -> h1
    # second update to obtain h1
    g.multi_update_all(funcs, 'sum')

    # apply activation, layernorm, and dropout
    feat_dict = {}
    for ntype in g.ntypes:
        feat_dict[ntype] = self.dropout(self.layernorm(F.relu_(g.nodes[ntype].data['h'])))
    return feat_dict
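# A minimal heterograph sketch (the schema, node counts, and field names are
# assumptions for illustration): per-etype message/reduce functions are collected
# into a dict and applied in a single multi_update_all call, mirroring the two-pass
# structure used above for attention normalization.
import dgl
import dgl.function as fn
import torch

g = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('user', 'plays', 'game'): ([0, 1, 2], [0, 0, 1]),
})
g.nodes['user'].data['h'] = torch.randn(3, 4)
g.edges['follows'].data['w'] = torch.rand(2, 1)
g.edges['plays'].data['w'] = torch.rand(3, 1)

funcs = {
    'follows': (fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_agg')),
    'plays': (fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_agg')),
}
g.multi_update_all(funcs, 'sum')                 # cross-type aggregation by summation
print(g.nodes['game'].data['h_agg'].shape)       # (2, 4)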
def forward(self, g, node_feats, edge_feats):
    """Update node representations.

    Parameters
    ----------
    g : DGLGraph
        DGLGraph for a batch of graphs
    node_feats : LongTensor of shape (N, 1)
        Input categorical node features. N for the number of nodes.
    edge_feats : FloatTensor of shape (E, in_edge_feats)
        Input edge features. E for the number of edges.

    Returns
    -------
    FloatTensor of shape (N, hidden_feats)
        Output node representations
    """
    if self.gnn_type == 'gcn':
        degs = (g.in_degrees().float() + 1).to(node_feats.device)
        norm = torch.pow(degs, -0.5).unsqueeze(-1)  # (N, 1)
        g.ndata['norm'] = norm
        g.apply_edges(fn.u_mul_v('norm', 'norm', 'norm'))
        norm = g.edata.pop('norm')

    if self.virtual_node:
        virtual_node_feats = self.virtual_node_emb(
            torch.zeros(g.batch_size).to(node_feats.dtype).to(node_feats.device))

    h_list = [self.node_encoder(node_feats)]
    for l in range(len(self.layers)):
        if self.virtual_node:
            virtual_feats_broadcast = dgl.broadcast_nodes(g, virtual_node_feats)
            h_list[l] = h_list[l] + virtual_feats_broadcast

        if self.gnn_type == 'gcn':
            h = self.layers[l](g, h_list[l], edge_feats, degs, norm)
        else:
            h = self.layers[l](g, h_list[l], edge_feats)

        if self.batchnorms is not None:
            h = self.batchnorms[l](h)

        if self.activation is not None and l != self.n_layers - 1:
            h = self.activation(h)
        h = self.dropout(h)
        h_list.append(h)

        if l < self.n_layers - 1 and self.virtual_node:
            ### Update virtual node representation from real node representations
            virtual_node_feats_tmp = self.virtual_readout(g, h_list[l]) + virtual_node_feats
            if self.residual:
                virtual_node_feats = virtual_node_feats + self.dropout(
                    self.mlp_virtual_project[l](virtual_node_feats_tmp))
            else:
                virtual_node_feats = self.dropout(
                    self.mlp_virtual_project[l](virtual_node_feats_tmp))

    if self.jk:
        return torch.stack(h_list, dim=0).sum(0)
    else:
        return h_list[-1]
def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
    # define model and optimizer
    model = gen_model(in_feats, n_classes, args)
    model = model.to(device)

    if not args.standard_loss:
        loss_fcn = loge_cross_entropy
    else:
        loss_fcn = cross_entropy

    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # training loop
    total_time = 0
    best_val_acc, best_test_acc, best_val_loss = 0, 0, float("inf")

    accs, train_accs, val_accs, test_accs = [], [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []

    ### do the normalization only once
    deg_sqrt, deg_isqrt = compute_norm(graph)

    graph.srcdata.update({"src_norm": deg_isqrt})
    graph.dstdata.update({"dst_norm": deg_isqrt})
    graph.apply_edges(fn.u_mul_v("src_norm", "dst_norm", "gcn_norm"))

    graph.srcdata.update({"src_norm": deg_isqrt})
    graph.dstdata.update({"dst_norm": deg_sqrt})
    graph.apply_edges(fn.u_mul_v("src_norm", "dst_norm", "gcn_norm_adjust"))

    checkpoint_path = args.checkpoint_path
    if args.mode == "student":
        teacher_output = torch.load(
            os.path.join(checkpoint_path, f'best_pred_run{n_running}.pt')).cpu().cuda()
    else:
        teacher_output = None

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        if args.adjust_lr:
            adjust_learning_rate(optimizer, args.lr, epoch)

        acc, loss = train(args, model, graph, labels, train_idx, val_idx, test_idx,
                          optimizer, teacher_output, loss_fcn, evaluator, epoch=epoch)

        train_acc, val_acc, test_acc, train_loss, val_loss, test_loss, pred = evaluate(
            args, model, graph, labels, train_idx, val_idx, test_idx,
            args.use_labels, loss_fcn, evaluator)

        toc = time.time()
        total_time += toc - tic

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            best_test_acc = test_acc
            final_pred = pred
            if args.mode == "teacher":
                os.makedirs(checkpoint_path, exist_ok=True)
                save_checkpoint(final_pred, n_running, checkpoint_path)

        if epoch % args.log_every == 0:
            print(f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}")
            print(f"Time: {(total_time / epoch):.4f}, Loss: {loss.item():.4f}, Acc: {acc:.4f}")
            print(f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}")
            print(f"Train/Val/Test/Best val/Best test acc: "
                  f"{train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{best_test_acc:.4f}")

        for l, e in zip(
            [accs, train_accs, val_accs, test_accs, losses, train_losses, val_losses, test_losses],
            [acc, train_acc, val_acc, test_acc, loss.item(), train_loss, val_loss, test_loss],
        ):
            l.append(e)

    print("*" * 50)
    print(f"Average epoch time: {total_time / args.n_epochs}, Test acc: {best_test_acc}")

    if args.plot_curves:
        plot(accs, train_accs, val_accs, test_accs,
             losses, train_losses, val_losses, test_losses,
             n_running, args.n_epochs)

    if args.save_pred:
        os.makedirs(args.output_path, exist_ok=True)
        torch.save(F.softmax(final_pred, dim=1),
                   os.path.join(args.output_path, f"{n_running - 1}.pt"))

    return best_val_acc, best_test_acc