def forward(self, input_hidden, graphs: dgl.DGLGraph, batch_num_nodes=None):
    """Refine node states for ``k_update_steps`` rounds and pad them per graph.

    Args:
        input_hidden: Tensor fed through ``self.input_proj``; the projection
            is broadcast to every node and edge of ``graphs`` as ``'h_t'``.
        graphs: Batched graph carrying node feature ``'F_n'`` and edge
            feature ``'F_e'``. Scratch fields (``'h_t'``, ``'F_n_t'``,
            ``'F_e_t'``, ``'s_n'``) are written onto it when it has edges.
        batch_num_nodes: Optional per-graph node counts used to split the
            node features; falls back to ``graphs.batch_num_nodes``.

    Returns:
        Tuple ``(io, io_mask)``: ``io`` is the per-graph node features padded
        with ``pad_sequence(batch_first=True)``; ``io_mask`` is True wherever
        a feature vector does not sum to zero.
    """
    split_sizes = graphs.batch_num_nodes if batch_num_nodes is None else batch_num_nodes
    h_t = self.input_proj(input_hidden)

    # With no edges there is nothing to propagate, so the untouched input
    # node features are returned.
    result_key = 'F_n'

    if graphs.number_of_edges() > 0:
        # Give every node and edge access to the current query hidden state.
        graphs.ndata['h_t'] = dgl.broadcast_nodes(graphs, h_t)
        graphs.edata['h_t'] = dgl.broadcast_edges(graphs, h_t)

        # Working node/edge states that are overwritten on each iteration.
        graphs.ndata['F_n_t'] = graphs.ndata['F_n']
        graphs.edata['F_e_t'] = graphs.edata['F_e']

        for _step in range(self.k_update_steps):
            score_input = torch.cat(
                [graphs.ndata['h_t'], graphs.ndata['F_n_t']], dim=-1)
            graphs.ndata['s_n'] = self.object_score(score_input)

            # The send/reduce callbacks are expected to produce
            # 'F_i_tplus1' (nodes) and 'F_e_tplus1' (edges).
            graphs.send(message_func=self.io_attention_send)
            graphs.recv(reduce_func=self.io_attention_reduce)

            graphs.ndata['F_n_t'] = graphs.ndata['F_i_tplus1']
            if self.update_relations:
                graphs.edata['F_e_t'] = graphs.edata['F_e_tplus1']

        result_key = 'F_n_t'

    per_graph = torch.split(graphs.ndata[result_key],
                            split_size_or_sections=split_sizes)
    io = pad_sequence(per_graph, batch_first=True)
    io_mask = io.sum(dim=-1) != 0
    return io, io_mask
def forward(self, graph, global_attr=None, out_edge_key='h_e'):
    """Recompute an edge feature from edge, endpoint and global inputs.

    The inputs enabled by ``self._use_edges``, ``self._use_sender_nodes``,
    ``self._use_receiver_nodes`` and ``self._use_globals`` are concatenated
    along the last dimension and passed through ``self.net``; the result is
    stored on the graph's edges under ``out_edge_key``.

    Args:
        graph: Graph whose edges are updated in place.
        global_attr: Optional graph-level feature. When given, it is
            broadcast to every edge and stored in
            ``graph.edata['expanded_global_attr']``. When ``None`` (the
            default) no broadcast happens and the global input is skipped.
        out_edge_key: Edge-data key receiving the output. When
            ``self.recurrent`` is true the current value under this key is
            also passed to ``self.net`` as a second argument.

    Returns:
        The same ``graph`` object, with ``out_edge_key`` updated.
    """
    has_globals = global_attr is not None

    def send_func(edges):
        # Gather the enabled inputs in a fixed order: edge, sender,
        # receiver, global.
        edges_to_collect = []
        if self._use_edges:
            edges_to_collect.append(edges.data[self.edge_key])
        if self._use_sender_nodes:
            edges_to_collect.append(edges.src[self.node_key])
        if self._use_receiver_nodes:
            edges_to_collect.append(edges.dst[self.node_key])
        if self._use_globals and has_globals:
            edges_to_collect.append(edges.data['expanded_global_attr'])

        collected_edges = torch.cat(edges_to_collect, dim=-1)

        if self.recurrent:
            return {
                out_edge_key: self.net(collected_edges,
                                       edges.data[out_edge_key])
            }
        return {out_edge_key: self.net(collected_edges)}

    # Broadcasting ``None`` is an error, so only expand the global feature
    # when the caller actually supplied one.
    if has_globals:
        graph.edata['expanded_global_attr'] = dgl.broadcast_edges(
            graph, global_attr)

    graph.apply_edges(send_func)
    return graph
def test_broadcast_edges():
    """Check ``dgl.broadcast_edges`` on a single graph and on a batch."""
    # Case 1: single graph -- the feature vector is repeated once per edge.
    path10 = dgl.DGLGraph(nx.path_graph(10))
    vec0 = F.randn((40, ))
    expected_single = F.stack(
        [vec0 for _ in range(path10.number_of_edges())], 0)
    assert F.allclose(dgl.broadcast_edges(path10, vec0), expected_single)

    # Case 2: batched graph (including one graph with no edges) -- row i of
    # the stacked features is repeated for every edge of the i-th member.
    path3 = dgl.DGLGraph(nx.path_graph(3))
    no_edges = dgl.DGLGraph()
    path12 = dgl.DGLGraph(nx.path_graph(12))
    members = [path10, path3, no_edges, path12]
    batched = dgl.batch(members)

    vecs = [vec0] + [F.randn((40, )) for _ in range(3)]

    expected_rows = []
    for member, vec in zip(members, vecs):
        expected_rows.extend([vec] * member.number_of_edges())
    expected_batched = F.stack(expected_rows, 0)

    actual = dgl.broadcast_edges(batched, F.stack(vecs, 0))
    assert F.allclose(actual, expected_batched)
def forward(self, g, h):
    """Produce a graph-level prediction and a per-edge prediction.

    Args:
        g: Input graph; its node field ``'h'`` is overwritten with the
            final layer's node output.
        h: Initial node features fed to the first layer in ``self.gat``.

    Returns:
        Tuple ``(h_pred, e_pred)``: ``h_pred`` is ``self.linear_h`` applied
        to the mean-node readout; ``e_pred`` is ``self.linear_e`` applied to
        each edge's output concatenated with the broadcast readout.
    """
    node_feats = h

    # All but the last layer flatten their heads and are followed by ELU.
    for layer_idx in range(self.num_layers - 1):
        node_feats, _ = self.gat[layer_idx](g, node_feats, merge='flatten')
        node_feats = F.elu(node_feats)

    # The last layer averages its heads and also yields edge outputs.
    node_feats, edge_feats = self.gat[-1](g, node_feats, merge='mean')

    # Graph-level prediction from the mean over each graph's nodes.
    g.ndata['h'] = node_feats
    graph_summary = dgl.mean_nodes(g, 'h')
    h_pred = self.linear_h(graph_summary)

    # Edge prediction: pair every edge with its graph's readout.
    summary_per_edge = dgl.broadcast_edges(g, graph_summary)
    fused = torch.cat([summary_per_edge, edge_feats], dim=1)
    e_pred = self.linear_e(fused)

    return h_pred, e_pred
def test_broadcast(idtype, g):
    """Check ``broadcast_nodes`` / ``broadcast_edges`` against unbatched parts.

    Args:
        idtype: Index dtype the graph is cast to before testing.
        g: Batched graph under test; it is moved to ``F.ctx()``.
    """
    g = g.astype(idtype).to(F.ctx())
    gfeat = F.randn((g.batch_size, 3))

    def tiled_row(index, count):
        # Row ``index`` of gfeat repeated ``count`` times along dim 0.
        return F.repeat(F.reshape(gfeat[index], (1, 3)), count, dim=0)

    # Test.0: broadcast_nodes -- every node holds its own graph's row.
    g.ndata['h'] = dgl.broadcast_nodes(g, gfeat)
    for index, part in enumerate(dgl.unbatch(g)):
        assert F.allclose(part.ndata['h'],
                          tiled_row(index, part.number_of_nodes()))

    # Test.1: broadcast_edges -- every edge holds its own graph's row.
    g.edata['h'] = dgl.broadcast_edges(g, gfeat)
    for index, part in enumerate(dgl.unbatch(g)):
        assert F.allclose(part.edata['h'],
                          tiled_row(index, part.number_of_edges()))
def forward(self, tree_graphs, tree_vec):
    """Compute label- and stop-prediction losses/accuracies for a tree batch.

    The edges of ``tree_graphs`` are visited in the order produced by
    ``dfs_order``. Edge hidden states live on the nodes of the line graph of
    ``tree_graphs`` and are updated by message passing at each step. Inputs
    for the two prediction heads are accumulated along the way and scored
    in one shot at the end.

    Args:
        tree_graphs: Batched tree graph with node field ``'wid'``. If
            missing, node field ``'x'`` and edge field ``'src_x'`` are
            written onto it before the local view is taken.
        tree_vec: Per-graph vector; it is broadcast to every edge and also
            concatenated row-wise with ``batch_size``-row tensors, so it
            presumably has one row per graph -- TODO confirm with caller.

    Returns:
        Tuple ``(pred_loss, stop_loss, pred_acc, stop_acc)``. Both losses
        are the outputs of ``self.pred_loss`` / ``self.stop_loss`` divided
        by ``batch_size``; both accuracies are Python scalars (``.item()``).
    """
    device = tree_vec.device
    batch_size = tree_graphs.batch_size

    root_ids = get_root_ids(tree_graphs)

    # These two writes happen before ``local_var()`` below, i.e. on the
    # graph object the caller passed in.
    if 'x' not in tree_graphs.ndata:
        tree_graphs.ndata['x'] = self.embedding(tree_graphs.ndata['wid'])
    if 'src_x' not in tree_graphs.edata:
        tree_graphs.apply_edges(fn.copy_u('x', 'src_x'))
    # NOTE(review): presumably taken so the later feature writes stay off
    # the caller's graph -- confirm against DGL's ``local_var`` semantics.
    tree_graphs = tree_graphs.local_var()
    # Record, on each edge, the 'wid' of its destination node.
    tree_graphs.apply_edges(func=lambda edges: {'dst_wid': edges.dst['wid']})

    # Line graph: one node per edge of the tree graph. The code below reads
    # the tree's edge fields ('src_x', 'dst_wid') through line-graph ndata.
    line_tree_graphs = dgl.line_graph(tree_graphs, backtracking=False, shared=True)
    line_num_nodes = line_tree_graphs.num_nodes()
    line_tree_graphs.ndata.update({
        'src_x_r': self.W_r(line_tree_graphs.ndata['src_x']),
        # Exploit the fact that the reduce function is a sum of incoming messages,
        # and uncomputed messages are zero vectors.
        'h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
        'vec': dgl.broadcast_edges(tree_graphs, tree_vec),
        'sum_h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
        'sum_gated_h': torch.zeros(line_num_nodes, self.hidden_size).to(device)
    })

    # Input tensors for stop prediction (p) and label prediction (q).
    pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
    stop_hiddens, stop_targets = [], []

    # Predict root: zero hidden state, target is the root node's 'wid'.
    pred_hiddens.append(torch.zeros(batch_size, self.hidden_size).to(device))
    pred_targets.append(tree_graphs.ndata['wid'][root_ids.to(device)])
    pred_mol_vecs.append(tree_vec)

    # Traverse the tree and predict on children.
    # NOTE(review): ``p`` is used as a 0/1 flag per edge (0 selects the edge
    # for label prediction, and the stop target is 1 - p); its exact meaning
    # is defined by ``dfs_order`` -- confirm there.
    for eid, p in dfs_order(tree_graphs, root_ids.to(dtype=tree_graphs.idtype)):
        eid = eid.to(device)
        p = p.to(device=device, dtype=tree_graphs.idtype)

        # Message passing excluding the target
        line_tree_graphs.pull(v=eid, message_func=fn.copy_u('h', 'h_nei'),
                              reduce_func=fn.sum('h_nei', 'sum_h'))
        line_tree_graphs.pull(v=eid, message_func=self.gru_message,
                              reduce_func=fn.sum('m', 'sum_gated_h'))
        line_tree_graphs.apply_nodes(self.gru_update, v=eid)

        # Node aggregation including the target
        # By construction, the edges of the raw graph follow the order of
        # (i1, j1), (j1, i1), (i2, j2), (j2, i2), ... The order of the nodes
        # in the line graph corresponds to the order of the edges in the raw graph.
        # XOR with 1 flips the lowest bit, mapping edge id 2k <-> 2k + 1,
        # i.e. each edge to its reverse under the ordering above.
        eid = eid.long()
        reverse_eid = torch.bitwise_xor(eid, torch.tensor(1).to(device))
        cur_o = line_tree_graphs.ndata['sum_h'][eid] + \
            line_tree_graphs.ndata['h'][reverse_eid]

        # Gather targets
        mask = (p == torch.tensor(0).to(device))
        pred_list = eid[mask]
        stop_target = torch.tensor(1).to(device) - p

        # Hidden states for stop prediction
        stop_hidden = torch.cat([line_tree_graphs.ndata['src_x'][eid],
                                 cur_o,
                                 line_tree_graphs.ndata['vec'][eid]], dim=1)
        stop_hiddens.append(stop_hidden)
        # ``extend`` iterates the tensor, so the list accumulates one
        # element per edge rather than one tensor per step.
        stop_targets.extend(stop_target)

        # Hidden states for clique prediction
        if len(pred_list) > 0:
            pred_mol_vecs.append(line_tree_graphs.ndata['vec'][pred_list])
            pred_hiddens.append(line_tree_graphs.ndata['h'][pred_list])
            pred_targets.append(line_tree_graphs.ndata['dst_wid'][pred_list])

    # Last stop at root: sum the final edge states arriving at each root and
    # add one stop example per graph with target 0.
    root_ids = root_ids.to(device)
    cur_x = tree_graphs.ndata['x'][root_ids]
    tree_graphs.edata['h'] = line_tree_graphs.ndata['h']
    tree_graphs.pull(v=root_ids.to(dtype=tree_graphs.idtype),
                     message_func=fn.copy_e('h', 'm'),
                     reduce_func=fn.sum('m', 'cur_o'))
    stop_hidden = torch.cat([cur_x, tree_graphs.ndata['cur_o'][root_ids], tree_vec], dim=1)
    stop_hiddens.append(stop_hidden)
    stop_targets.extend(torch.zeros(batch_size).to(device))

    # Predict next clique
    pred_hiddens = torch.cat(pred_hiddens, dim=0)
    pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
    pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
    pred_vecs = F.relu(self.W(pred_vecs))
    pred_scores = self.W_o(pred_vecs)
    pred_targets = torch.cat(pred_targets, dim=0)

    pred_loss = self.pred_loss(pred_scores, pred_targets) / batch_size
    # Accuracy: fraction of rows whose highest-scoring class equals the target.
    _, preds = torch.max(pred_scores, dim=1)
    pred_acc = torch.eq(preds, pred_targets).float()
    pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

    # Predict stop
    stop_hiddens = torch.cat(stop_hiddens, dim=0)
    stop_vecs = F.relu(self.U(stop_hiddens))
    # NOTE(review): bare ``squeeze()`` drops every size-1 dim, so a single
    # stop example would yield a 0-d tensor; ``squeeze(-1)`` may be safer --
    # confirm the output width of ``self.U_s``.
    stop_scores = self.U_s(stop_vecs).squeeze()
    stop_targets = torch.Tensor(stop_targets).to(device)

    stop_loss = self.stop_loss(stop_scores, stop_targets) / batch_size
    # A score >= 0 counts as predicting 1; compared against the 0/1 targets.
    stops = torch.ge(stop_scores, 0).float()
    stop_acc = torch.eq(stops, stop_targets).float()
    stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

    return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()