def test_simple_readout():
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    n1 = F.randn((3, 5))
    n2 = F.randn((4, 5))
    e1 = F.randn((3, 5))
    s1 = F.sum(n1, 0)    # node sums
    s2 = F.sum(n2, 0)
    se1 = F.sum(e1, 0)   # edge sums
    m1 = F.mean(n1, 0)   # node means
    m2 = F.mean(n2, 0)
    me1 = F.mean(e1, 0)  # edge means
    w1 = F.randn((3,))
    w2 = F.randn((4,))
    max1 = F.max(n1, 0)
    max2 = F.max(n2, 0)
    maxe1 = F.max(e1, 0)
    ws1 = F.sum(n1 * F.unsqueeze(w1, 1), 0)
    ws2 = F.sum(n2 * F.unsqueeze(w2, 1), 0)
    wm1 = F.sum(n1 * F.unsqueeze(w1, 1), 0) / F.sum(F.unsqueeze(w1, 1), 0)
    wm2 = F.sum(n2 * F.unsqueeze(w2, 1), 0) / F.sum(F.unsqueeze(w2, 1), 0)

    g1.ndata['x'] = n1
    g2.ndata['x'] = n2
    g1.ndata['w'] = w1
    g2.ndata['w'] = w2
    g1.edata['x'] = e1

    assert F.allclose(dgl.sum_nodes(g1, 'x'), s1)
    assert F.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1)
    assert F.allclose(dgl.sum_edges(g1, 'x'), se1)
    assert F.allclose(dgl.mean_nodes(g1, 'x'), m1)
    assert F.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
    assert F.allclose(dgl.mean_edges(g1, 'x'), me1)
    assert F.allclose(dgl.max_nodes(g1, 'x'), max1)
    assert F.allclose(dgl.max_edges(g1, 'x'), maxe1)

    g = dgl.batch([g1, g2])
    s = dgl.sum_nodes(g, 'x')
    m = dgl.mean_nodes(g, 'x')
    max_bg = dgl.max_nodes(g, 'x')
    assert F.allclose(s, F.stack([s1, s2], 0))
    assert F.allclose(m, F.stack([m1, m2], 0))
    assert F.allclose(max_bg, F.stack([max1, max2], 0))

    ws = dgl.sum_nodes(g, 'x', 'w')
    wm = dgl.mean_nodes(g, 'x', 'w')
    assert F.allclose(ws, F.stack([ws1, ws2], 0))
    assert F.allclose(wm, F.stack([wm1, wm2], 0))

    s = dgl.sum_edges(g, 'x')
    m = dgl.mean_edges(g, 'x')
    max_bg_e = dgl.max_edges(g, 'x')
    assert F.allclose(s, F.stack([se1, F.zeros(5)], 0))
    assert F.allclose(m, F.stack([me1, F.zeros(5)], 0))
    assert F.allclose(max_bg_e, F.stack([maxe1, F.zeros(5)], 0))
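# The test above only exercises weighted readouts on nodes; dgl.sum_edges and
# dgl.mean_edges accept an edge weight field in the same way. A minimal
# sketch, assuming a PyTorch backend:
import dgl
import torch

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [2, 0, 1])
g.edata['x'] = torch.randn(3, 5)
g.edata['w'] = torch.randn(3)

weighted_sum = dgl.sum_edges(g, 'x', 'w')    # sum of 'x' weighted by 'w'
weighted_mean = dgl.mean_edges(g, 'x', 'w')  # weighted mean of 'x'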
def test_simple_readout():
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    n1 = th.randn(3, 5)
    n2 = th.randn(4, 5)
    e1 = th.randn(3, 5)
    s1 = n1.sum(0)    # node sums
    s2 = n2.sum(0)
    se1 = e1.sum(0)   # edge sums
    m1 = n1.mean(0)   # node means
    m2 = n2.mean(0)
    me1 = e1.mean(0)  # edge means
    w1 = th.randn(3)
    w2 = th.randn(4)
    ws1 = (n1 * w1[:, None]).sum(0)  # weighted node sums
    ws2 = (n2 * w2[:, None]).sum(0)
    wm1 = (n1 * w1[:, None]).sum(0) / w1[:, None].sum(0)  # weighted node means
    wm2 = (n2 * w2[:, None]).sum(0) / w2[:, None].sum(0)

    g1.ndata['x'] = n1
    g2.ndata['x'] = n2
    g1.ndata['w'] = w1
    g2.ndata['w'] = w2
    g1.edata['x'] = e1

    assert U.allclose(dgl.sum_nodes(g1, 'x'), s1)
    assert U.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1)
    assert U.allclose(dgl.sum_edges(g1, 'x'), se1)
    assert U.allclose(dgl.mean_nodes(g1, 'x'), m1)
    assert U.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
    assert U.allclose(dgl.mean_edges(g1, 'x'), me1)

    g = dgl.batch([g1, g2])
    s = dgl.sum_nodes(g, 'x')
    m = dgl.mean_nodes(g, 'x')
    assert U.allclose(s, th.stack([s1, s2], 0))
    assert U.allclose(m, th.stack([m1, m2], 0))

    ws = dgl.sum_nodes(g, 'x', 'w')
    wm = dgl.mean_nodes(g, 'x', 'w')
    assert U.allclose(ws, th.stack([ws1, ws2], 0))
    assert U.allclose(wm, th.stack([wm1, wm2], 0))

    s = dgl.sum_edges(g, 'x')
    m = dgl.mean_edges(g, 'x')
    assert U.allclose(s, th.stack([se1, th.zeros(5)], 0))
    assert U.allclose(m, th.stack([me1, th.zeros(5)], 0))
def forward(self, graph, edge_feat, node_feat, g_repr,
            edge_hidden, node_hidden, graph_hidden):
    graph.edata['edge_feat'] = edge_feat
    graph.ndata['node_feat'] = node_feat
    graph.edata['hidden1'] = edge_hidden[0][0]
    graph.ndata['hidden1'] = node_hidden[0][0]
    graph.edata['hidden2'] = edge_hidden[1][0]
    graph.ndata['hidden2'] = node_hidden[1][0]

    node_trf_func = lambda x: self.compute_node_repr(nodes=x, graph=graph, g_repr=g_repr)
    edge_trf_func = lambda x: self.compute_edge_repr(edges=x, graph=graph, g_repr=g_repr)

    # Update edge representations, then node representations.
    graph.apply_edges(edge_trf_func)
    graph.update_all(self.graph_message_func, self.graph_reduce_func, node_trf_func)

    # Per-graph readouts feed the global (u) update.
    e_comb = dgl.sum_edges(graph, 'edge_feat')
    n_comb = dgl.sum_nodes(graph, 'node_feat')
    u_out, u_hidden = self.compute_u_repr(n_comb, e_comb, g_repr, graph_hidden)

    e_feat = graph.edata['edge_feat']
    n_feat = graph.ndata['node_feat']
    h_e = (torch.unsqueeze(graph.edata['hidden1'], 0), torch.unsqueeze(graph.edata['hidden2'], 0))
    h_n = (torch.unsqueeze(graph.ndata['hidden1'], 0), torch.unsqueeze(graph.ndata['hidden2'], 0))

    # Clear all edge/node features written during this call.
    e_keys = list(graph.edata.keys())
    n_keys = list(graph.ndata.keys())
    for key in e_keys:
        graph.edata.pop(key)
    for key in n_keys:
        graph.ndata.pop(key)

    return e_feat, h_e, n_feat, h_n, u_out, u_hidden
def forward(self, graph, edge_feat, node_feat, g_repr):
    graph.edata['edge_feat'] = edge_feat
    graph.ndata['node_feat'] = node_feat

    node_trf_func = lambda x: self.compute_node_repr(nodes=x, graph=graph, g_repr=g_repr)
    edge_trf_func = lambda x: self.compute_edge_repr(edges=x, graph=graph, g_repr=g_repr)

    # Update edge representations, then node representations.
    graph.apply_edges(edge_trf_func)
    graph.update_all(self.graph_message_func, self.graph_reduce_func, node_trf_func)

    # Per-graph readouts feed the global (u) update.
    e_comb = dgl.sum_edges(graph, 'edge_feat')
    n_comb = dgl.sum_nodes(graph, 'node_feat')

    e_out = graph.edata['edge_feat']
    n_out = graph.ndata['node_feat']

    # Clear all edge/node features written during this call.
    e_keys = list(graph.edata.keys())
    n_keys = list(graph.ndata.keys())
    for key in e_keys:
        graph.edata.pop(key)
    for key in n_keys:
        graph.ndata.pop(key)

    return e_out, n_out, self.compute_u_repr(n_comb, e_comb, g_repr)
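# The manual pop() loops above are one way to leave the graph clean after the
# readout. A minimal sketch of the alternative using graph.local_scope(),
# assuming the same feature names; the helper name readout_with_local_scope
# is hypothetical, not part of the original code:
import dgl

def readout_with_local_scope(graph, edge_feat, node_feat):
    # Features written inside local_scope() are discarded on exit, so no
    # manual clean-up of edata/ndata keys is required.
    with graph.local_scope():
        graph.edata['edge_feat'] = edge_feat
        graph.ndata['node_feat'] = node_feat
        e_comb = dgl.sum_edges(graph, 'edge_feat')
        n_comb = dgl.sum_nodes(graph, 'node_feat')
    return e_comb, n_comb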
def forward(self, g, x, e, snorm_n, snorm_e):
    # h = self.embedding_h(h)
    # h = self.in_feat_dropout(h)
    h = torch.zeros([g.number_of_edges(), self.h_dim]).float().to(self.device)
    src, dst = g.all_edges()

    for mpnn in self.layers:
        if self.edge_f:
            if self.dst_f:
                h = mpnn(g, src_feat=x[src], dst_feat=x[dst], e_feat=e, h_feat=h, snorm_e=snorm_e)
            else:
                h = mpnn(g, src_feat=x[src], e_feat=e, h_feat=h, snorm_e=snorm_e)
        else:
            if self.dst_f:
                h = mpnn(g, src_feat=x[src], dst_feat=x[dst], h_feat=h, snorm_e=snorm_e)
            else:
                h = mpnn(g, src_feat=x[src], h_feat=h, snorm_e=snorm_e)

    g.edata['h'] = h
    if self.readout == "sum":
        hg = dgl.sum_edges(g, 'h')
    elif self.readout == "max":
        hg = dgl.max_edges(g, 'h')
    elif self.readout == "mean":
        hg = dgl.mean_edges(g, 'h')
    else:
        hg = dgl.mean_edges(g, 'h')  # default readout is the mean over edges

    return self.MLP_layer(hg)
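# The if/elif chain above can equivalently be written as a lookup table. A
# sketch only; the helper name edge_readout and its signature are assumptions:
import dgl

def edge_readout(g, field, mode="mean"):
    fns = {"sum": dgl.sum_edges, "max": dgl.max_edges, "mean": dgl.mean_edges}
    # Unknown modes fall back to the mean readout, matching the code above.
    return fns.get(mode, dgl.mean_edges)(g, field)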
def forward(self, g, x, e, snorm_n, snorm_e):
    # snorm_n is the per-node normalization used for batched graphs
    # h = self.embedding_h(h)
    # h = self.in_feat_dropout(h)
    h_node = torch.zeros([g.number_of_nodes(), self.node_in_dim]).float().to(self.device)
    h_edge = torch.zeros([g.number_of_edges(), self.h_dim]).float().to(self.device)
    src, dst = g.all_edges()

    for edge_layer, node_layer in zip(self.edge_layers, self.node_layers):
        if self.edge_f:
            if self.dst_f:
                h_edge = edge_layer(g, src_feat=x[src], dst_feat=x[dst], e_feat=e, h_feat=h_edge, snorm_e=snorm_e)
                h_node = node_layer(g, src_feat=x[src], dst_feat=x[dst], e_feat=e, h_feat=h_node, snorm_e=snorm_e, n_feat=x)
            else:
                h_edge = edge_layer(g, src_feat=x[src], e_feat=e, h_feat=h_edge, snorm_e=snorm_e)
                h_node = node_layer(g, src_feat=x[src], e_feat=e, h_feat=h_node, snorm_e=snorm_e, n_feat=x)
        else:
            if self.dst_f:
                h_edge = edge_layer(g, src_feat=x[src], dst_feat=x[dst], h_feat=h_edge, snorm_e=snorm_e)
                h_node = node_layer(g, src_feat=x[src], dst_feat=x[dst], h_feat=h_node, snorm_e=snorm_e, n_feat=x)
            else:
                h_edge = edge_layer(g, src_feat=x[src], h_feat=h_edge, snorm_e=snorm_e)
                h_node = node_layer(g, src_feat=x[src], h_feat=h_node, snorm_e=snorm_e, n_feat=x)

    g.edata['h'] = h_edge
    if self.node_update:
        g.ndata['h'] = h_node
        # print("g.data:", g.ndata['h'][0].shape)

    if self.readout == "sum":
        he = dgl.sum_edges(g, 'h')
        hn = dgl.sum_nodes(g, 'h')
    elif self.readout == "max":
        he = dgl.max_edges(g, 'h')
        hn = dgl.max_nodes(g, 'h')
    elif self.readout == "mean":
        he = dgl.mean_edges(g, 'h')
        hn = dgl.mean_nodes(g, 'h')
    else:
        he = dgl.mean_edges(g, 'h')  # default readout is the mean
        hn = dgl.mean_nodes(g, 'h')

    # print(torch.cat([he, hn], dim=1).shape)
    # used for the graph-level (global) task
    out = self.Global_MLP_layer(torch.cat([he, hn], dim=1))
    # used for the transition (edge-level) task
    edge_out = self.edge_MLPReadout(h_edge)
    # return self.MLP_layer(he)
    return out
def forward(self, graph, edge_feats_u, node_feats_u, edge_feat_reflected_u,
            mode="train", node_probability=None, joint_acts=None):
    graph.edata['edge_feat_u'] = edge_feats_u
    graph.edata['edge_feat_reflected_u'] = edge_feat_reflected_u
    graph.ndata['node_feat_u'] = node_feats_u

    n_weights = torch.zeros([node_feats_u.shape[0], 1])
    zero_indexes, offset = [0], 0
    num_nodes = graph.batch_num_nodes
    # Mark all 0-th index nodes of every graph in the batch
    for a in num_nodes[:-1]:
        offset += a
        zero_indexes.append(offset)
    n_weights[zero_indexes] = 1
    graph.ndata['weights'] = n_weights
    graph.ndata['mod_weights'] = 1 - n_weights

    graph.apply_nodes(self.compute_node_data)
    graph.apply_edges(self.compute_edge_data)

    self.utils_storage["indiv"].append(graph.ndata["indiv_util"].detach().numpy())
    self.utils_storage["pairs"].append(graph.edata["util_vals"].detach().numpy())
    self.utils_storage["batch_num_nodes"].append(graph.batch_num_nodes)
    self.utils_storage["batch_num_edges"].append(graph.batch_num_edges)

    if "inference" in mode:
        graph.ndata["probs"] = node_probability
        src, dst = graph.edges()
        src_list, dst_list = src.tolist(), dst.tolist()

        # Mark edges not connected to a 0-th index node
        e_nc_zero_weight = torch.zeros([edge_feats_u.shape[0], 1])
        all_nc_edges = [
            idx for idx, (src, dst) in enumerate(zip(src_list, dst_list))
            if (src not in zero_indexes) and (dst not in zero_indexes)
        ]
        e_nc_zero_weight[all_nc_edges] = 0.5
        graph.edata["nc_zero_weight"] = e_nc_zero_weight

        graph.apply_edges(self.graph_pair_inference_func)
        graph.update_all(message_func=self.graph_dst_inference_func,
                         reduce_func=self.graph_reduce_func,
                         apply_node_func=self.graph_node_inference_func)

        # Weighted per-graph readouts combine utilities from the different node/edge groups.
        total_connected = dgl.sum_nodes(graph, 'util_dst', 'weights')
        total_n_connected = dgl.sum_edges(graph, 'edge_all_sum_prob', 'nc_zero_weight')
        total_expected_others_util = dgl.sum_nodes(graph, "expected_indiv_util", "mod_weights").view(-1, 1)
        total_indiv_util_zero = dgl.sum_nodes(graph, "indiv_util", "weights")

        returned_values = (total_connected + total_n_connected) + \
                          (total_expected_others_util + total_indiv_util_zero)

        # Clear all edge/node features written during this call.
        e_keys = list(graph.edata.keys())
        n_keys = list(graph.ndata.keys())
        for key in e_keys:
            graph.edata.pop(key)
        for key in n_keys:
            graph.ndata.pop(key)
        return returned_values

    m_func = lambda x: self.graph_u_sum(graph, x, joint_acts)
    graph.update_all(message_func=m_func, reduce_func=self.graph_sum_all)

    indiv_u_zeros = graph.ndata['indiv_util']
    u_msg_sum_zeros = 0.5 * graph.ndata['u_msg_sum']
    graph.ndata['utils_sum_all'] = (indiv_u_zeros + u_msg_sum_zeros).gather(
        -1, torch.Tensor(joint_acts)[:, None].long())
    q_values = dgl.sum_nodes(graph, 'utils_sum_all')

    # Clear all edge/node features written during this call.
    e_keys = list(graph.edata.keys())
    n_keys = list(graph.ndata.keys())
    for key in e_keys:
        graph.edata.pop(key)
    for key in n_keys:
        graph.ndata.pop(key)
    return q_values
def forward(self, graph, feat):
    with graph.local_scope():
        # local_scope() keeps the temporary 'e' feature from leaking into
        # the caller's graph.
        graph.edata['e'] = feat
        readout = dgl.sum_edges(graph, 'e')
        return readout
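# A minimal usage sketch for the readout module above, assuming a PyTorch
# backend and that the module wrapping this forward() is called
# EdgeSumReadout (a hypothetical name):
import dgl
import torch

g1 = dgl.DGLGraph()
g1.add_nodes(3)
g1.add_edges([0, 1, 2], [1, 2, 0])           # 3 edges
g2 = dgl.DGLGraph()
g2.add_nodes(2)
g2.add_edges([0, 1], [1, 0])                 # 2 edges
bg = dgl.batch([g1, g2])

feat = torch.randn(bg.number_of_edges(), 4)  # one feature row per edge
readout = EdgeSumReadout()                   # hypothetical wrapper module
out = readout(bg, feat)                      # shape (2, 4): per-graph edge sums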
def sum_readout(g):
    return dgl.sum_edges(g, from_field)
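# 'from_field' above is a free variable captured from an enclosing scope.
# A hedged sketch of how such a readout closure is typically built; the
# factory name make_sum_readout and the default field 'h' are assumptions:
import dgl

def make_sum_readout(from_field='h'):
    def sum_readout(g):
        # Sums g.edata[from_field], one result row per graph in a batch.
        return dgl.sum_edges(g, from_field)
    return sum_readout

readout_fn = make_sum_readout('h')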