def test_simple_readout():
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    n1 = F.randn((3, 5))
    n2 = F.randn((4, 5))
    e1 = F.randn((3, 5))
    s1 = F.sum(n1, 0)    # node sums
    s2 = F.sum(n2, 0)
    se1 = F.sum(e1, 0)   # edge sums
    m1 = F.mean(n1, 0)   # node means
    m2 = F.mean(n2, 0)
    me1 = F.mean(e1, 0)  # edge means
    w1 = F.randn((3,))
    w2 = F.randn((4,))
    max1 = F.max(n1, 0)
    max2 = F.max(n2, 0)
    maxe1 = F.max(e1, 0)
    ws1 = F.sum(n1 * F.unsqueeze(w1, 1), 0)
    ws2 = F.sum(n2 * F.unsqueeze(w2, 1), 0)
    wm1 = F.sum(n1 * F.unsqueeze(w1, 1), 0) / F.sum(F.unsqueeze(w1, 1), 0)
    wm2 = F.sum(n2 * F.unsqueeze(w2, 1), 0) / F.sum(F.unsqueeze(w2, 1), 0)

    g1.ndata['x'] = n1
    g2.ndata['x'] = n2
    g1.ndata['w'] = w1
    g2.ndata['w'] = w2
    g1.edata['x'] = e1

    # readouts on a single graph
    assert F.allclose(dgl.sum_nodes(g1, 'x'), s1)
    assert F.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1)
    assert F.allclose(dgl.sum_edges(g1, 'x'), se1)
    assert F.allclose(dgl.mean_nodes(g1, 'x'), m1)
    assert F.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
    assert F.allclose(dgl.mean_edges(g1, 'x'), me1)
    assert F.allclose(dgl.max_nodes(g1, 'x'), max1)
    assert F.allclose(dgl.max_edges(g1, 'x'), maxe1)

    # readouts on a batched graph return one row per graph
    g = dgl.batch([g1, g2])
    s = dgl.sum_nodes(g, 'x')
    m = dgl.mean_nodes(g, 'x')
    max_bg = dgl.max_nodes(g, 'x')
    assert F.allclose(s, F.stack([s1, s2], 0))
    assert F.allclose(m, F.stack([m1, m2], 0))
    assert F.allclose(max_bg, F.stack([max1, max2], 0))

    ws = dgl.sum_nodes(g, 'x', 'w')
    wm = dgl.mean_nodes(g, 'x', 'w')
    assert F.allclose(ws, F.stack([ws1, ws2], 0))
    assert F.allclose(wm, F.stack([wm1, wm2], 0))

    # edge readouts for the edgeless graph g2 fall back to zeros
    s = dgl.sum_edges(g, 'x')
    m = dgl.mean_edges(g, 'x')
    max_bg_e = dgl.max_edges(g, 'x')
    assert F.allclose(s, F.stack([se1, F.zeros(5)], 0))
    assert F.allclose(m, F.stack([me1, F.zeros(5)], 0))
    assert F.allclose(max_bg_e, F.stack([maxe1, F.zeros(5)], 0))
def forward(self, g, x, e, snorm_n, snorm_e):
    # h = self.embedding_h(h)
    # h = self.in_feat_dropout(h)
    # initialise the hidden edge states with zeros
    h = torch.zeros([g.number_of_edges(), self.h_dim]).float().to(self.device)
    src, dst = g.all_edges()
    for mpnn in self.layers:
        if self.edge_f:
            if self.dst_f:
                h = mpnn(g, src_feat=x[src], dst_feat=x[dst], e_feat=e,
                         h_feat=h, snorm_e=snorm_e)
            else:
                h = mpnn(g, src_feat=x[src], e_feat=e, h_feat=h, snorm_e=snorm_e)
        else:
            if self.dst_f:
                h = mpnn(g, src_feat=x[src], dst_feat=x[dst], h_feat=h, snorm_e=snorm_e)
            else:
                h = mpnn(g, src_feat=x[src], h_feat=h, snorm_e=snorm_e)

    g.edata['h'] = h
    if self.readout == "sum":
        hg = dgl.sum_edges(g, 'h')
    elif self.readout == "max":
        hg = dgl.max_edges(g, 'h')
    elif self.readout == "mean":
        hg = dgl.mean_edges(g, 'h')
    else:
        hg = dgl.mean_edges(g, 'h')  # default readout is mean over edges
    return self.MLP_layer(hg)
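# The forward pass above expects per-edge graph-size normalisation factors
# (snorm_e). A minimal, self-contained sketch of how such factors could be
# built for a batched graph, assuming the common 1/sqrt(|E_i|) per-graph
# normalisation; the helper name `build_edge_snorm` is hypothetical and not
# part of DGL.
import torch
import dgl

def build_edge_snorm(graphs):
    """Return a batched graph and a (total_edges, 1) tensor of edge norms."""
    batched = dgl.batch(graphs)
    norms = [torch.full((g.number_of_edges(), 1),
                        1.0 / max(g.number_of_edges(), 1) ** 0.5)
             for g in graphs]
    return batched, torch.cat(norms, dim=0)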
def forward(self, dgl_data):
    # stack the four graph-level readouts along a new trailing dim,
    # then take the element-wise max over them
    dgl_feat, _ = torch.max(
        torch.stack([
            dgl.mean_nodes(dgl_data, 'h'),
            dgl.max_nodes(dgl_data, 'h'),
            dgl.mean_edges(dgl_data, 'h'),
            dgl.max_edges(dgl_data, 'h'),
        ], 2), -1)
    return dgl_feat
def forward(self, dgl_data):
    # concatenate graph-level readouts depending on whether node and/or edge
    # features are requested, then apply the learned projection and activation
    if self.getnode and self.getedge:
        dgl_feat = torch.cat([
            dgl.mean_nodes(dgl_data, 'h'),
            dgl.max_nodes(dgl_data, 'h'),
            dgl.mean_edges(dgl_data, 'h'),
            dgl.max_edges(dgl_data, 'h'),
        ], -1)
    elif self.getnode:
        dgl_feat = torch.cat(
            [dgl.mean_nodes(dgl_data, 'h'), dgl.max_nodes(dgl_data, 'h')], -1)
    else:
        dgl_feat = torch.cat(
            [dgl.mean_edges(dgl_data, 'h'), dgl.max_edges(dgl_data, 'h')], -1)
    dgl_predict = self.activate(self.weight_node(dgl_feat))
    return dgl_predict
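# A minimal, self-contained sketch of the concatenated mean/max node-and-edge
# readout used by the two modules above, on toy graphs; feature sizes and
# values are made up for illustration, and the `dgl.graph` constructor
# (DGL 0.5+) is assumed.
import torch
import dgl

gs = [dgl.graph(([0, 1, 2], [1, 2, 0])), dgl.graph(([0, 1], [1, 0]))]
for g in gs:
    g.ndata['h'] = torch.randn(g.number_of_nodes(), 4)
    g.edata['h'] = torch.randn(g.number_of_edges(), 4)
bg = dgl.batch(gs)

feat = torch.cat([
    dgl.mean_nodes(bg, 'h'),
    dgl.max_nodes(bg, 'h'),
    dgl.mean_edges(bg, 'h'),
    dgl.max_edges(bg, 'h'),
], -1)
print(feat.shape)  # (2, 16): one row per graph, four readouts of dim 4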
def forward(self, g, x, e, snorm_n, snorm_e):  # snorm_n is used for the batched graph
    # h = self.embedding_h(h)
    # h = self.in_feat_dropout(h)
    h_node = torch.zeros([g.number_of_nodes(), self.node_in_dim]).float().to(self.device)
    h_edge = torch.zeros([g.number_of_edges(), self.h_dim]).float().to(self.device)
    src, dst = g.all_edges()
    for edge_layer, node_layer in zip(self.edge_layers, self.node_layers):
        if self.edge_f:
            if self.dst_f:
                h_edge = edge_layer(g, src_feat=x[src], dst_feat=x[dst], e_feat=e,
                                    h_feat=h_edge, snorm_e=snorm_e)
                h_node = node_layer(g, src_feat=x[src], dst_feat=x[dst], e_feat=e,
                                    h_feat=h_node, snorm_e=snorm_e, n_feat=x)
            else:
                h_edge = edge_layer(g, src_feat=x[src], e_feat=e, h_feat=h_edge,
                                    snorm_e=snorm_e)
                h_node = node_layer(g, src_feat=x[src], e_feat=e, h_feat=h_node,
                                    snorm_e=snorm_e, n_feat=x)
        else:
            if self.dst_f:
                h_edge = edge_layer(g, src_feat=x[src], dst_feat=x[dst], h_feat=h_edge,
                                    snorm_e=snorm_e)
                h_node = node_layer(g, src_feat=x[src], dst_feat=x[dst], h_feat=h_node,
                                    snorm_e=snorm_e, n_feat=x)
            else:
                h_edge = edge_layer(g, src_feat=x[src], h_feat=h_edge, snorm_e=snorm_e)
                h_node = node_layer(g, src_feat=x[src], h_feat=h_node, snorm_e=snorm_e,
                                    n_feat=x)

    g.edata['h'] = h_edge
    if self.node_update:
        g.ndata['h'] = h_node
    # print("g.data:", g.ndata['h'][0].shape)

    if self.readout == "sum":
        he = dgl.sum_edges(g, 'h')
        hn = dgl.sum_nodes(g, 'h')
    elif self.readout == "max":
        he = dgl.max_edges(g, 'h')
        hn = dgl.max_nodes(g, 'h')
    elif self.readout == "mean":
        he = dgl.mean_edges(g, 'h')
        hn = dgl.mean_nodes(g, 'h')
    else:
        # default readout is mean over edges/nodes
        he = dgl.mean_edges(g, 'h')
        hn = dgl.mean_nodes(g, 'h')
    # print(torch.cat([he, hn], dim=1).shape)

    # used for the global task
    out = self.Global_MLP_layer(torch.cat([he, hn], dim=1))
    # used for the transition task
    edge_out = self.edge_MLPReadout(h_edge)
    # return self.MLP_layer(he)
    return out
def forward(self, graph, feat):
    with graph.local_scope():
        graph.edata['e'] = feat
        readout = dgl.max_edges(graph, 'e')
        return readout
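# A minimal sketch of how the readout above might be exercised; the class name
# `MaxEdgeReadout` is hypothetical and stands in for whatever module owns this
# forward method.
import torch
import torch.nn as nn
import dgl

class MaxEdgeReadout(nn.Module):
    def forward(self, graph, feat):
        with graph.local_scope():
            graph.edata['e'] = feat
            return dgl.max_edges(graph, 'e')

g = dgl.batch([dgl.graph(([0, 1], [1, 0])), dgl.graph(([0], [0]))])
feat = torch.randn(g.number_of_edges(), 8)
out = MaxEdgeReadout()(g, feat)  # shape (2, 8): one max-pooled row per graph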