def test_simple_readout():
    """Compare DGL node/edge readouts against manually computed reductions.

    ``g1`` has 3 nodes and 3 edges; ``g2`` has 4 nodes and no edges. Every
    readout is checked on ``g1`` alone and then on the batch of both graphs,
    where the edge readouts are expected to produce a zero row for ``g2``.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # g2 deliberately keeps no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    # Random inputs, drawn in a fixed order: node feats, edge feats, weights.
    n1 = F.randn((3, 5))
    n2 = F.randn((4, 5))
    e1 = F.randn((3, 5))
    w1 = F.randn((3, ))
    w2 = F.randn((4, ))

    def weighted_sum(feat, weight):
        # Sum of rows of ``feat`` scaled by the per-row ``weight``.
        return F.sum(feat * F.unsqueeze(weight, 1), 0)

    def weighted_mean(feat, weight):
        # Weighted sum normalised by the total weight.
        return weighted_sum(feat, weight) / F.sum(F.unsqueeze(weight, 1), 0)

    # Expected reductions over dimension 0 (nodes or edges).
    s1, s2, se1 = F.sum(n1, 0), F.sum(n2, 0), F.sum(e1, 0)
    m1, m2, me1 = F.mean(n1, 0), F.mean(n2, 0), F.mean(e1, 0)
    max1, max2, maxe1 = F.max(n1, 0), F.max(n2, 0), F.max(e1, 0)
    ws1, ws2 = weighted_sum(n1, w1), weighted_sum(n2, w2)
    wm1, wm2 = weighted_mean(n1, w1), weighted_mean(n2, w2)

    g1.ndata['x'] = n1
    g2.ndata['x'] = n2
    g1.ndata['w'] = w1
    g2.ndata['w'] = w2
    g1.edata['x'] = e1

    # Single (unbatched) graph.
    single_cases = (
        (dgl.sum_nodes, ('x',), s1),
        (dgl.sum_nodes, ('x', 'w'), ws1),
        (dgl.sum_edges, ('x',), se1),
        (dgl.mean_nodes, ('x',), m1),
        (dgl.mean_nodes, ('x', 'w'), wm1),
        (dgl.mean_edges, ('x',), me1),
        (dgl.max_nodes, ('x',), max1),
        (dgl.max_edges, ('x',), maxe1),
    )
    for readout, args, expected in single_cases:
        assert F.allclose(readout(g1, *args), expected)

    # Batched graph: one output row per member graph; g2 has no edges.
    g = dgl.batch([g1, g2])
    batched_cases = (
        (dgl.sum_nodes, ('x',), F.stack([s1, s2], 0)),
        (dgl.mean_nodes, ('x',), F.stack([m1, m2], 0)),
        (dgl.max_nodes, ('x',), F.stack([max1, max2], 0)),
        (dgl.sum_nodes, ('x', 'w'), F.stack([ws1, ws2], 0)),
        (dgl.mean_nodes, ('x', 'w'), F.stack([wm1, wm2], 0)),
        (dgl.sum_edges, ('x',), F.stack([se1, F.zeros(5)], 0)),
        (dgl.mean_edges, ('x',), F.stack([me1, F.zeros(5)], 0)),
        (dgl.max_edges, ('x',), F.stack([maxe1, F.zeros(5)], 0)),
    )
    for readout, args, expected in batched_cases:
        assert F.allclose(readout(g, *args), expected)
def forward(self, graph):
    """Classify each graph from degree-derived node features with three GAT layers.

    Parameters
    ----------
    graph : DGLGraph
        The (possibly batched) input graph.

    Returns
    -------
    tuple
        ``(pred, graph_emb, node_emb)``: the sigmoid of the classifier output,
        the per-graph max-pooled node embedding, and the final node embeddings
        (as stored in ``graph.ndata['h']``).
    """
    degrees = graph.in_degrees()
    deg_col = degrees.view(-1, 1).float()
    # Node feature columns: [degree, degree > 3, 3 / degree, degree > 4].
    # NOTE(review): 3 / degree is inf for nodes with in-degree 0 — confirm
    # such nodes cannot occur in the input graphs.
    node_feat = th.cat(
        (
            deg_col,
            ((degrees - 3) > 0).view(-1, 1).float(),
            3 / deg_col,
            ((degrees - 4) > 0).view(-1, 1).float(),
        ),
        1,
    )

    for gat in (self.gat1, self.gat2, self.gat3):
        node_feat = gat(node_feat, graph)
    graph.ndata['h'] = node_feat

    # Graph representation: element-wise max over each graph's nodes.
    graph_emb = dgl.max_nodes(graph, 'h')
    logits = self.classify(self.drop_layer1(graph_emb))
    return F.sigmoid(logits), graph_emb, graph.ndata['h']
def forward(self, graph):
    """Predict a probability per graph from degree-derived node features.

    Node features are built from in-degrees only, then passed through two
    ReLU-activated graph convolutions, max-pooled per graph, and fed to a
    two-layer MLP with dropout and a final sigmoid.

    Parameters
    ----------
    graph : DGLGraph
        The (possibly batched) input graph.

    Returns
    -------
    Tensor
        Sigmoid outputs of ``self.layer2``.
    """
    degrees = graph.in_degrees()
    deg_col = degrees.view(-1, 1).float()
    # Node feature columns: [degree, degree > 3, 3 / degree, degree > 4].
    # NOTE(review): 3 / degree is inf for nodes with in-degree 0 — confirm
    # such nodes cannot occur in the input graphs.
    node_feat = th.cat(
        (
            deg_col,
            ((degrees - 3) > 0).view(-1, 1).float(),
            3 / deg_col,
            ((degrees - 4) > 0).view(-1, 1).float(),
        ),
        1,
    )

    # Graph convolutions, each followed by ReLU.
    for conv in (self.conv1, self.conv2):
        node_feat = F.relu(conv(graph, node_feat))
    graph.ndata['h'] = node_feat

    # Graph representation: element-wise max over each graph's nodes.
    graph_emb = dgl.max_nodes(graph, 'h')
    hidden = self.drop_layer1(F.relu(self.layer1(graph_emb)))
    return F.sigmoid(self.layer2(hidden))
def forward(self, g, h, e, snorm_n, snorm_e):
    """Run the conv stack with optional virtual-node layers, pool, and apply the MLP.

    After conv layer ``i``, if ``self.virtual_node_layers`` is set and has an
    entry ``i``, that layer updates both the node features and a running
    virtual-node state ``vn_h`` (starting at ``0``). ``snorm_e`` is unused.
    Pooling follows ``self.readout``; anything other than ``"sum"`` or
    ``"max"`` (including ``"mean"``) uses mean pooling.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    if self.edge_feat:
        e = self.embedding_e(e)

    vn_layers = self.virtual_node_layers
    vn_h = 0  # running virtual-node state threaded through vn_layers
    for idx, conv in enumerate(self.layers):
        h = conv(g, h, e, snorm_n)
        if vn_layers is not None and idx < len(vn_layers):
            vn_h, h = vn_layers[idx].forward(g, h, vn_h)
    g.ndata['h'] = h

    if self.readout == "sum":
        pooled = dgl.sum_nodes(g, 'h')
    elif self.readout == "max":
        pooled = dgl.max_nodes(g, 'h')
    else:
        pooled = dgl.mean_nodes(g, 'h')  # "mean" and the default
    return self.MLP_layer(pooled)
def forward(self, g, h, e, snorm_n, snorm_e):
    """Run ``self.layer_count`` conv layers, each followed by a joining layer
    applied to (input embedding + conv output), then pool and apply the MLP.

    ``e`` and ``snorm_e`` are unused. Pooling follows ``self.readout``; any
    value other than ``"sum"`` or ``"max"`` (including ``"mean"``) uses mean
    pooling.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    h_init = h  # every joining layer sees the embedded input as a skip term
    for idx in range(self.layer_count):
        conv_out = self.layers[idx](g, h, snorm_n)
        h = self.joining_layers[idx](h_init + conv_out)
    g.ndata['h'] = h

    reducer = {"sum": dgl.sum_nodes, "max": dgl.max_nodes}.get(
        self.readout, dgl.mean_nodes)
    return self.MLP_layer(reducer(g, 'h'))
def forward(self, g, h, e, snorm_n, snorm_e):
    """Embed inputs (plus optional positional encoding), run the conv stack,
    pool nodes and apply the MLP head.

    The positional encoding is read from ``g.ndata['pos_enc']`` when
    ``self.pos_enc_dim > 0``. ``snorm_e`` is unused. Pooling follows
    ``self.readout``; anything other than ``"sum"`` or ``"max"`` (including
    ``"mean"``) uses mean pooling.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    if self.pos_enc_dim > 0:
        pos = g.ndata['pos_enc'].to(self.device)
        h = h + self.embedding_pos_enc(pos)
    if self.edge_feat:
        e = self.embedding_e(e)

    for conv in self.layers:
        h = conv(g, h, e, snorm_n)
    g.ndata['h'] = h

    pool = {"sum": dgl.sum_nodes, "max": dgl.max_nodes}.get(
        self.readout, dgl.mean_nodes)
    return self.MLP_layer(pool(g, 'h'))
def forward(self, g, h, e, pos_enc=None):
    """Embed inputs, run layers conditioned on degree-based pseudo-coordinates,
    pool nodes and apply the MLP head.

    When ``self.pos_enc`` is truthy the input embedding comes from ``pos_enc``
    instead of ``h``. Pseudo-coordinates are produced per edge by
    ``self.compute_pseudo`` after node in-degrees are stored in
    ``g.ndata['deg']``. ``e`` is unused. Pooling follows ``self.readout``;
    anything other than ``"sum"`` or ``"max"`` uses mean pooling.
    """
    h = self.embedding_pos_enc(pos_enc) if self.pos_enc else self.embedding_h(h)

    # Edge pseudo-coordinates derived from node degrees.
    g.ndata['deg'] = g.in_degrees()
    g.apply_edges(self.compute_pseudo)
    pseudo = g.edata['pseudo'].to(self.device).float()

    # Each layer gets its own projection of the shared pseudo-coordinates.
    for idx, layer in enumerate(self.layers):
        h = layer(g, h, self.pseudo_proj[idx](pseudo))
    g.ndata['h'] = h

    if self.readout == "sum":
        pooled = dgl.sum_nodes(g, 'h')
    elif self.readout == "max":
        pooled = dgl.max_nodes(g, 'h')
    else:
        pooled = dgl.mean_nodes(g, 'h')  # "mean" and the default
    return self.MLP_layer(pooled)
def forward(self, g, h, e, snorm_n, snorm_e):
    """Run the (optionally GRU-gated) conv stack, apply jumping knowledge,
    pool nodes and apply the MLP head.

    ``self.JK`` selects the pooled node representation: ``'last'`` uses the
    final layer output, ``'sum'`` the sum of the input embedding and every
    layer output. ``in_feat_dropout`` is not applied to the embedded input
    here. ``snorm_e`` is unused.

    Pooling follows ``self.readout`` (``"sum"``, ``"max"`` or ``"mean"``).
    Any other value falls back to mean pooling. Previously an unknown value
    passed ``None`` into ``MLP_layer``.
    """
    h = self.embedding_h(h)

    h_list = [h] if self.JK == 'sum' else None
    last_idx = len(self.layers) - 1
    for i, conv in enumerate(self.layers):
        h_t = conv(g, h, e, snorm_n)
        # GRU gating between consecutive layers, skipped after the last one.
        if self.gru_enable and i != last_idx:
            h_t = self.gru(h, h_t)
        h = h_t
        if h_list is not None:
            h_list.append(h)
        g.ndata['h'] = h

    if self.JK == 'last':
        g.ndata['h'] = h
    elif self.JK == 'sum':
        g.ndata['h'] = sum(h_list)

    if self.readout == "sum":
        hg = dgl.sum_nodes(g, 'h')
    elif self.readout == "max":
        hg = dgl.max_nodes(g, 'h')
    else:
        hg = dgl.mean_nodes(g, 'h')  # "mean" and the default
    return self.MLP_layer(hg)
def forward(self, feats, bg):
    """Multi-task prediction for a batch of molecules.

    Parameters
    ----------
    feats : FloatTensor of shape (N, M0)
        Initial features for all atoms in the batch of molecules.
    bg : BatchedDGLGraph
        B batched DGLGraphs for processing multiple molecules in parallel.

    Returns
    -------
    FloatTensor of shape (B, n_tasks)
        Soft prediction for all tasks on the batch of molecules.
    """
    atom_h = feats
    for layer in self.gcn_layers:
        atom_h = layer(atom_h, bg)

    # Molecule representation: weighted sum of atoms concatenated with the
    # element-wise max over atoms.
    bg.ndata[self.atom_data_field] = atom_h
    bg.ndata[self.atom_weight_field] = self.atom_weighting(atom_h)
    mol_h = torch.cat(
        [
            dgl.sum_nodes(bg, self.atom_data_field, self.atom_weight_field),
            dgl.max_nodes(bg, self.atom_data_field),
        ],
        dim=1,
    )
    return self.soft_classifier(mol_h)
def forward(self, g, h, e):
    """Run the edge-aware conv stack, pool with per-node weights, apply the MLP.

    This variant targets reduced graphs. Convolutions receive edge features
    ``e``, and pooling weights nodes by ``g.ndata['weight']``. For original
    (unreduced) graphs, call ``conv(g, h)`` and pool without ``weight``.
    Pooling follows ``self.readout``; anything other than ``"sum"``, ``"max"``
    or ``"mean"`` uses mean pooling.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    for conv in self.layers:
        h = conv(g, h, e)
    g.ndata['h'] = h

    reducers = {
        "sum": dgl.sum_nodes,
        "max": dgl.max_nodes,
        "mean": dgl.mean_nodes,
    }
    reducer = reducers.get(self.readout, dgl.mean_nodes)
    return self.MLP_layer(reducer(g, feat='h', weight='weight'))
def forward(self, g, h, e, snorm_n, snorm_e, mlp=True, head=False, return_graph=False):
    """Run the conv stack and return one of several outputs.

    Returns, in order of precedence:
    * ``g`` itself (with ``g.ndata['h']`` set) when ``return_graph`` is true;
    * ``self.MLP_layer(hg)`` when ``mlp`` is true;
    * ``self.projection_head(hg)`` when ``head`` is true;
    * the pooled graph representation ``hg`` otherwise.

    Pooling follows ``self.readout``; anything other than ``"sum"`` or
    ``"max"`` (including ``"mean"``) uses mean pooling. ``e`` and ``snorm_e``
    are unused.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    for conv in self.layers:
        h = conv(g, h, snorm_n)
    g.ndata['h'] = h
    if return_graph:
        return g

    if self.readout == "sum":
        hg = dgl.sum_nodes(g, 'h')
    elif self.readout == "max":
        hg = dgl.max_nodes(g, 'h')
    else:
        hg = dgl.mean_nodes(g, 'h')  # "mean" and the default

    if mlp:
        return self.MLP_layer(hg)
    return self.projection_head(hg) if head else hg
def forward(self, g, h, e, snorm_n, snorm_e):
    """Embed inputs, run the (optionally GRU-gated) conv stack, apply jumping
    knowledge and pool nodes for the MLP head.

    ``self.JK`` selects the pooled node representation: ``'last'`` uses the
    final layer output, ``'sum'`` the sum of the (position-encoded) input
    embedding and every layer output. ``self.readout`` selects pooling. The
    ``'directional'``/``'directional_abs'`` modes also pool a copy of ``h``
    scaled by column 1 of ``g.ndata['eig']`` and concatenate it before the
    mean-pooled ``h``. Unknown modes use mean pooling. ``snorm_e`` is unused.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    if self.pos_enc_dim > 0:
        h = h + self.embedding_pos_enc(g.ndata['pos_enc'].to(self.device))

    layer_outputs = [h] if self.JK == 'sum' else None
    if self.edge_feat:
        e = self.embedding_e(e)

    last_idx = len(self.layers) - 1
    for idx, conv in enumerate(self.layers):
        out = conv(g, h, e, snorm_n)
        # GRU gating between consecutive layers, skipped after the last one.
        if self.gru_enable and idx != last_idx:
            out = self.gru(h, out)
        h = out
        if layer_outputs is not None:
            layer_outputs.append(h)
        g.ndata['h'] = h

    if self.JK == 'last':
        g.ndata['h'] = h
    elif self.JK == 'sum':
        h = sum(layer_outputs)
        g.ndata['h'] = h

    mode = self.readout
    if mode == "sum":
        hg = dgl.sum_nodes(g, 'h')
    elif mode == "max":
        hg = dgl.max_nodes(g, 'h')
    elif mode == "mean":
        hg = dgl.mean_nodes(g, 'h')
    elif mode in ("directional_abs", "directional"):
        eig = g.ndata['eig'][:, 1:2].to(self.device)
        # NOTE(review): eig is an (N, 1) slice, so summing |eig| over dim=1
        # returns |eig| itself. The "normalisation" is then |eig| / |eig|
        # (NaN where eig == 0); a per-graph sum over nodes may have been intended.
        denom = torch.sum(torch.abs(eig), dim=1, keepdim=True)
        if mode == "directional_abs":
            g.ndata['dir'] = h * torch.abs(eig) / denom
            dir_pooled = dgl.mean_nodes(g, 'dir')
        else:
            g.ndata['dir'] = h * eig / denom
            dir_pooled = torch.abs(dgl.mean_nodes(g, 'dir'))
        hg = torch.cat([dir_pooled, dgl.mean_nodes(g, 'h')], dim=1)
    else:
        hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes
    return self.MLP_layer(hg)
def forward(self, g, h, e, snorm_n, snorm_e):
    """Run ``self.n_layers`` GIN layers with an optional residual back to the
    input embedding, pool nodes, and score each graph.

    ``self.residual`` may be falsy (no residual), ``"gated"`` (a sigmoid gate
    from ``self.W_g`` mixes layer output and input embedding), or any other
    truthy value (plain addition). Pooling follows ``self.readout`` and
    defaults to summation. ``e`` and ``snorm_e`` are unused.

    The plain residual is added out of place. An in-place ``+=`` on a layer
    output can break autograd when that output was saved for backward. It
    would also overwrite ``h_in`` if a layer returned its input unchanged.
    """
    # modified dtype for new dataset
    h = h.float()
    h = self.embedding_lin(h.cuda())  # .cuda(): requires a CUDA device
    h_in = h  # residual source shared by every layer

    for i in range(self.n_layers):
        h = self.ginlayers[i](g, h, snorm_n)
        if self.residual:
            if self.residual == "gated":
                z = torch.sigmoid(self.W_g(torch.cat([h, h_in], dim=1)))
                h = z * h + (torch.ones_like(z) - z) * h_in
            else:
                h = h + h_in

    g.ndata['h'] = self.linear_ro(h)

    if self.readout == "sum":
        hg = dgl.sum_nodes(g, 'h')
    elif self.readout == "max":
        hg = dgl.max_nodes(g, 'h')
    elif self.readout == "mean":
        hg = dgl.mean_nodes(g, 'h')
    else:
        hg = dgl.sum_nodes(g, 'h')  # default readout is summation
    return self.linear_prediction(hg)
def forward(self, g, h, e, snorm_n, snorm_e):
    """Run the conv stack with an optional per-layer residual, pool nodes and predict.

    Each layer's residual connects to that layer's own input. ``self.residual``
    may be falsy (none), ``"gated"`` (a sigmoid gate from ``self.W_g``), or any
    other truthy value (plain addition). Pooling follows ``self.readout`` and
    defaults to summation. ``e`` and ``snorm_e`` are unused.

    The plain residual is added out of place. An in-place ``+=`` on a layer
    output can break autograd when that output was saved for backward. It
    would also modify ``h_in`` if a layer returned its input unchanged.
    """
    # modified dtype for new dataset
    h = h.float()
    h = self.in_feat_dropout(self.embedding_lin(h))

    for conv in self.layers:
        h_in = h
        h = conv(g, h, snorm_n)
        if self.residual:
            if self.residual == "gated":
                z = torch.sigmoid(self.W_g(torch.cat([h, h_in], dim=1)))
                h = z * h + (torch.ones_like(z) - z) * h_in
            else:
                h = h + h_in

    g.ndata['h'] = self.linear_ro(h)

    if self.readout == "sum":
        hg = dgl.sum_nodes(g, 'h')
    elif self.readout == "max":
        hg = dgl.max_nodes(g, 'h')
    elif self.readout == "mean":
        hg = dgl.mean_nodes(g, 'h')
    else:
        hg = dgl.sum_nodes(g, 'h')  # default readout is summation
    return self.linear_predict(hg)
def forward(self, g, h, e, h_lap_pos_enc=None, h_wl_pos_enc=None):
    """Embed node/edge inputs (plus optional Laplacian / WL positional
    encodings), run the edge-updating conv stack, pool nodes and apply the MLP.

    The positional encodings are added after input dropout. Without edge
    features (``self.edge_feat`` falsy), ``e`` is replaced by a column of ones
    with one row per entry of ``e``. Pooling follows ``self.readout``;
    anything other than ``"sum"`` or ``"max"`` uses mean pooling.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    if self.lap_pos_enc:
        h = h + self.embedding_lap_pos_enc(h_lap_pos_enc.float())
    if self.wl_pos_enc:
        h = h + self.embedding_wl_pos_enc(h_wl_pos_enc)

    if not self.edge_feat:
        # edge feature set to 1
        e = torch.ones(e.size(0), 1).to(self.device)
    e = self.embedding_e(e)

    for conv in self.layers:
        h, e = conv(g, h, e)
    g.ndata['h'] = h

    pool = {"sum": dgl.sum_nodes, "max": dgl.max_nodes}.get(
        self.readout, dgl.mean_nodes)
    return self.MLP_layer(pool(g, 'h'))
def forward(self, g, h, e, snorm_n, snorm_e):
    """Embed nodes, run layers conditioned on degree-based pseudo-coordinates,
    pool nodes and apply the MLP head.

    Each edge ``(u, v)`` gets pseudo-coordinates
    ``(1 / sqrt(indeg(u) + 1), 1 / sqrt(indeg(v) + 1))``. These are computed
    vectorised from ``g.in_degrees()`` rather than by one Python call per
    edge, and an edgeless graph yields a ``(0, 2)`` tensor. ``e`` and
    ``snorm_e`` are unused. Pooling follows ``self.readout``; anything other
    than ``"sum"`` or ``"max"`` uses mean pooling.
    """
    h = self.embedding_h(h)

    us, vs = g.edges()
    # +1 on every in-degree acts as a self-loop, so nodes with in-degree 0
    # never cause a division by zero.
    deg = g.in_degrees().float() + 1
    pseudo = torch.stack(
        (1.0 / torch.sqrt(deg[us]), 1.0 / torch.sqrt(deg[vs])), dim=1
    ).to(self.device)

    for i in range(len(self.layers)):
        h = self.layers[i](g, h, self.pseudo_proj[i](pseudo), snorm_n)
    g.ndata['h'] = h

    if self.readout == "sum":
        hg = dgl.sum_nodes(g, 'h')
    elif self.readout == "max":
        hg = dgl.max_nodes(g, 'h')
    else:
        hg = dgl.mean_nodes(g, 'h')  # "mean" and the default
    return self.MLP_layer(hg)
def forward(self, bg, feats):
    """Readout.

    Parameters
    ----------
    bg : DGLGraph
        DGLGraph for a batch of graphs.
    feats : FloatTensor of shape (N, M1)
        * N is the total number of nodes in the batch of graphs
        * M1 is the input node feature size, which must match in_feats in
          initialization

    Returns
    -------
    h_g : FloatTensor of shape (B, 2 * M1)
        Weighted-sum readout concatenated with the element-wise max over nodes.
        * B is the number of graphs in the batch
        * M1 is the input node feature size, which must match in_feats in
          initialization
    """
    weighted = self.weight_and_sum(bg, feats)
    # local_scope keeps the temporary 'h' field from leaking into bg.
    with bg.local_scope():
        bg.ndata['h'] = feats
        node_max = dgl.max_nodes(bg, 'h')
    return torch.cat([weighted, node_max], dim=1)
def forward(self, g):
    """Three graph convolutions, max pooling over nodes, then a two-layer head.

    Node inputs are ``g.ndata['feat']`` viewed as rows of 73 floats. ReLU
    follows every convolution; dropout follows the first two convolutions and
    the first fully-connected layer. Returns the sigmoid of ``self.fc2``.
    """
    h = g.ndata['feat'].view(-1, 73).float()
    for conv in (self.conv1, self.conv2):
        h = self.dropout(self.relu(conv(g, h)))
    h = self.relu(self.conv3(g, h))

    g.ndata['h'] = h
    pooled = dgl.max_nodes(g, 'h')
    # NOTE(review): pooled rows are regrouped into rows of 73 * 2 = 146
    # values. If conv3's output width is not 146, this merges values from
    # different graphs into one row — confirm against the layer sizes.
    flat = pooled.reshape((-1, 73 * 2))

    hidden = self.dropout(self.sig(self.fc1(flat)))
    return self.sig(self.fc2(hidden))
def forward(self, g, graph_pooling):
    """Forward pass on the graph.

    Every embedding layer is followed by ReLU on both ``g.ndata["n_feat"]``
    and ``g.edata["e_feat"]``.

    :param g: The graph
    :param graph_pooling: If truthy, max-pool node embeddings over each graph
        and return one prediction per graph. Otherwise apply the prediction
        head to every node.
    :return: the prediction tensor when ``graph_pooling`` is truthy, otherwise
        ``g`` with per-node predictions stored in ``g.ndata["n_feat"]``.
    """
    def _embed(layer, graph):
        # Apply one embedding layer, then ReLU on node and edge features.
        graph = layer(graph)
        graph.ndata["n_feat"] = torch.relu(graph.ndata["n_feat"])
        graph.edata["e_feat"] = torch.relu(graph.edata["e_feat"])
        return graph

    for layer in self.embedding_layer[:-1]:
        g = _embed(layer, g)
    g = _embed(self.embedding_layer[-1], g)

    if graph_pooling:
        out = dgl.max_nodes(g, "n_feat")
        for fc in self.fc_layer:
            out = torch.relu(fc(out))
        return self.fc_out(out)

    node_out = g.ndata["n_feat"]
    for fc in self.fc_layer:
        node_out = torch.relu(fc(node_out))
    g.ndata["n_feat"] = self.fc_out(node_out)
    return g
def forward(self, g):
    """GIN forward pass with a prediction from every layer's pooled representation.

    The input features ``g.ndata['attr']`` and the output of each of the first
    ``num_layers - 1`` GIN layers are each pooled per graph. Each pooled
    vector goes through its own ``linears_prediction`` head with dropout, and
    the results are summed.

    Raises
    ------
    NotImplementedError
        If ``self.graph_pooling_type`` is not ``'sum'``, ``'mean'`` or ``'max'``.
    """
    h = g.ndata['attr'].to(self.device)

    # Representation at each layer, including the input.
    hidden_rep = [h]
    for idx in range(self.num_layers - 1):
        h = self.ginlayers[idx](g, h)
        hidden_rep.append(h)

    poolers = {
        'sum': dgl.sum_nodes,
        'mean': dgl.mean_nodes,
        'max': dgl.max_nodes,
    }
    score_over_layer = 0
    for idx, rep in enumerate(hidden_rep):
        g.ndata['h'] = rep
        pool = poolers.get(self.graph_pooling_type)
        if pool is None:
            raise NotImplementedError()
        layer_score = self.linears_prediction[idx](pool(g, 'h'))
        score_over_layer += F.dropout(
            layer_score, self.final_dropout, training=self.training)
    return score_over_layer
def forward(self, g, h):
    """Compute node embeddings, then apply the head matching ``self._data_type``.

    * ``'nc'``: classify every node.
    * ``'gc'`` / ``'rg'``: pool nodes per graph (``self._readout``: ``"sum"``,
      ``"max"``, otherwise mean) and classify the pooled vector.
    * ``'ec'``: classify every edge from its concatenated source and
      destination embeddings.

    Returns ``None`` for any other data type.
    """
    h = self._forward(g, h)
    kind = self._data_type

    if kind == 'nc':
        return self.classifier(h)

    if kind in ('gc', 'rg'):
        g.ndata['h'] = h
        pool = {"sum": dgl.sum_nodes, "max": dgl.max_nodes}.get(
            self._readout, dgl.mean_nodes)
        return self.classifier(pool(g, 'h'))

    if kind == 'ec':
        def _score_edge(edges):
            pair = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
            return {'e': self.classifier(pair)}

        g.ndata['h'] = h
        g.apply_edges(_score_edge)
        return g.edata['e']

    return None
def forward(self, g, node_feats):
    """Computes graph representations out of node features.

    Node features are projected by ``in_project``, passed through the optional
    ``activation`` and then ``out_project``, and finally pooled per graph
    according to ``self.mode``.

    Parameters
    ----------
    g : DGLGraph
        DGLGraph for a batch of graphs.
    node_feats : float32 tensor of shape (V, node_feats)
        Input node features, V for the number of nodes.

    Returns
    -------
    graph_feats : float32 tensor of shape (G, graph_feats)
        Graph representations computed. G for the number of graphs.

    Raises
    ------
    NotImplementedError
        If ``self.mode`` is not ``'max'``, ``'mean'`` or ``'sum'``. Such a
        value previously surfaced as an ``UnboundLocalError``.
    """
    if self.mode not in ('max', 'mean', 'sum'):
        raise NotImplementedError(
            "Expected mode to be 'max', 'mean' or 'sum', got {!r}".format(self.mode))

    node_feats = self.in_project(node_feats)
    if self.activation is not None:
        node_feats = self.activation(node_feats)
    node_feats = self.out_project(node_feats)

    # local_scope keeps the temporary 'h' field from leaking into g.
    with g.local_scope():
        g.ndata['h'] = node_feats
        if self.mode == 'max':
            graph_feats = dgl.max_nodes(g, 'h')
        elif self.mode == 'mean':
            graph_feats = dgl.mean_nodes(g, 'h')
        else:
            graph_feats = dgl.sum_nodes(g, 'h')
    return graph_feats
def forward(self, bg, feats):
    """Multi-task prediction for a batch of molecules.

    Parameters
    ----------
    bg : BatchedDGLGraph
        B batched DGLGraphs for processing multiple molecules in parallel.
    feats : FloatTensor of shape (N, M0)
        Initial features for all atoms in the batch of molecules.

    Returns
    -------
    FloatTensor of shape (B, n_tasks)
        Soft prediction for all tasks on the batch of molecules.
    """
    atom_h = feats
    for layer in self.gnn_layers:
        atom_h = layer(bg, atom_h)

    # Molecule representation: weighted sum concatenated with max over atoms.
    mol_sum = self.weighted_sum_readout(bg, atom_h)
    with bg.local_scope():
        bg.ndata['h'] = atom_h
        mol_max = dgl.max_nodes(bg, 'h')

    if not isinstance(bg, BatchedDGLGraph):
        # Add a leading batch axis for a single, unbatched graph (presumably
        # its readouts lack one — confirm for the DGL version in use).
        mol_sum = mol_sum.unsqueeze(0)
        mol_max = mol_max.unsqueeze(0)

    return self.soft_classifier(torch.cat([mol_sum, mol_max], dim=1))
def forward(self, bg, feats):
    """Concatenate the weighted-sum readout with the per-graph max over nodes.

    Parameters
    ----------
    bg : DGLGraph
        The (batched) graph whose nodes carry ``feats``.
    feats : Tensor
        Node features, one row per node.

    Returns
    -------
    Tensor
        ``cat([weight_and_sum(bg, feats), max_nodes], dim=1)``.
    """
    summed = self.weight_and_sum(bg, feats)
    # local_scope keeps the temporary 'h' field from leaking into bg.
    with bg.local_scope():
        bg.ndata['h'] = feats
        maxed = dgl.max_nodes(bg, 'h')
    return torch.cat([summed, maxed], dim=1)
def forward(self, dgl_data):
    """Pool node and/or edge features ``'h'`` (mean and max) and apply a linear head.

    Node readouts are used when ``self.getnode`` is truthy. Edge readouts are
    used when ``self.getedge`` is truthy or when node readouts are off, so at
    least one kind is always present. Node readouts come first in the
    concatenation.
    """
    use_nodes = bool(self.getnode)
    use_edges = not use_nodes or bool(self.getedge)

    pieces = []
    if use_nodes:
        pieces += [dgl.mean_nodes(dgl_data, 'h'), dgl.max_nodes(dgl_data, 'h')]
    if use_edges:
        pieces += [dgl.mean_edges(dgl_data, 'h'), dgl.max_edges(dgl_data, 'h')]

    pooled = torch.cat(pieces, -1)
    return self.activate(self.weight_node(pooled))
def forward(self, dgl_data):
    """Element-wise maximum of the mean/max node and mean/max edge readouts of ``'h'``.

    The four readouts are stacked along a new last axis (dim 2), and the
    maximum is taken over that axis.
    """
    readouts = [
        dgl.mean_nodes(dgl_data, 'h'),
        dgl.max_nodes(dgl_data, 'h'),
        dgl.mean_edges(dgl_data, 'h'),
        dgl.max_edges(dgl_data, 'h'),
    ]
    stacked = torch.stack(readouts, dim=2)
    return stacked.max(dim=-1).values
def graph_pooling(self, g):
    """Pool the node field ``'h'`` of ``g`` per graph.

    ``self.graph_pooling_type`` selects ``'max'``, ``'mean'`` or ``'sum'``.

    Raises
    ------
    NotImplementedError
        For any other pooling type. Such a value previously left ``hg``
        unbound and failed with ``UnboundLocalError``.
    """
    pooling = self.graph_pooling_type
    if pooling == 'max':
        return dgl.max_nodes(g, 'h')
    if pooling == 'mean':
        return dgl.mean_nodes(g, 'h')
    if pooling == 'sum':
        return dgl.sum_nodes(g, 'h')
    raise NotImplementedError(
        "Unsupported graph_pooling_type: {!r}".format(pooling))
def forward(self, g, x, e, snorm_n, snorm_e):
    """Run paired edge/node layers, pool both edge and node states, and apply
    the global MLP.

    Edge and node hidden states start at zero. At every layer both receive the
    source-node features ``x[src]``. They additionally receive destination-node
    features ``x[dst]`` when ``self.dst_f`` is set and raw edge features ``e``
    when ``self.edge_f`` is set. Node states are written back to the graph only
    when ``self.node_update`` is set. Pooling follows ``self.readout`` (sum /
    max / otherwise mean), applied to both edges and nodes.

    ``snorm_n`` is accepted for interface compatibility but not used. The
    discarded per-edge ``edge_MLPReadout`` output is no longer computed.

    Returns
    -------
    Tensor
        ``Global_MLP_layer(cat([edge_readout, node_readout], dim=1))``.
    """
    h_node = torch.zeros(
        [g.number_of_nodes(), self.node_in_dim], dtype=torch.float32, device=self.device)
    h_edge = torch.zeros(
        [g.number_of_edges(), self.h_dim], dtype=torch.float32, device=self.device)
    src, dst = g.all_edges()

    # Inputs shared by every edge and node layer; which ones are passed is
    # fixed by the edge_f / dst_f flags.
    shared = {'src_feat': x[src]}
    if self.dst_f:
        shared['dst_feat'] = x[dst]
    if self.edge_f:
        shared['e_feat'] = e

    for edge_layer, node_layer in zip(self.edge_layers, self.node_layers):
        h_edge = edge_layer(g, h_feat=h_edge, snorm_e=snorm_e, **shared)
        h_node = node_layer(g, h_feat=h_node, snorm_e=snorm_e, n_feat=x, **shared)

    g.edata['h'] = h_edge
    if self.node_update:
        g.ndata['h'] = h_node

    if self.readout == "sum":
        he, hn = dgl.sum_edges(g, 'h'), dgl.sum_nodes(g, 'h')
    elif self.readout == "max":
        he, hn = dgl.max_edges(g, 'h'), dgl.max_nodes(g, 'h')
    else:
        # "mean" and the default
        he, hn = dgl.mean_edges(g, 'h'), dgl.mean_nodes(g, 'h')

    # Graph-level (global) task output.
    return self.Global_MLP_layer(torch.cat([he, hn], dim=1))
def forward(self, graph: dgl.DGLGraph):
    """Encode node inputs, run the message-passing layers, and predict from
    the concatenated mean and max of node field ``'feat'``.

    Each message-passing layer is called for its effect on ``graph``; its
    return value is discarded.
    """
    graph.apply_nodes(self.input_node_func)
    for layer in self.mp_layers:
        layer(graph)  # return value unused; presumably updates graph in place

    pooled = torch.cat(
        [dgl.mean_nodes(graph, 'feat'), dgl.max_nodes(graph, 'feat')], dim=-1)
    return self.output(pooled)
def readout_fn(readout, graphs, h):
    """Pool node feature field ``h`` of ``graphs`` according to ``readout``.

    ``"sum"`` and ``"max"`` select the matching DGL node readout. Anything
    else, including ``"mean"``, uses mean pooling.
    """
    if readout == "sum":
        reducer = dgl.sum_nodes
    elif readout == "max":
        reducer = dgl.max_nodes
    else:
        reducer = dgl.mean_nodes
    return reducer(graphs, h)