def forward(self, cand_batch, mol_tree_batch):
    """Encode candidate attachment graphs with message passing and return
    one mean-pooled node representation per candidate.

    Also accumulates profiling counters (samples / nodes / edges / passes)
    on the module.
    """
    cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, \
        tree_mess_tgt_nodes = mol2dgl(cand_batch, mol_tree_batch)
    sample_count = len(cand_graphs)

    cand_graphs = batch(cand_graphs)
    cand_line_graph = line_graph(cand_graphs, no_backtracking=True)
    node_count = len(cand_graphs.nodes)
    edge_count = len(cand_graphs.edges)

    cand_graphs = self.run(
        cand_graphs, cand_line_graph,
        tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes,
        mol_tree_batch)
    cand_graphs = unbatch(cand_graphs)

    # Mean-pool the node features 'h' of each candidate graph.
    g_repr = torch.stack(
        [g.get_n_repr()['h'].mean(0) for g in cand_graphs], 0)

    # Profiling bookkeeping.
    self.n_samples_total += sample_count
    self.n_nodes_total += node_count
    self.n_edges_total += edge_count
    self.n_passes += 1

    return g_repr
def _load_graph(self):
    """Construct a DGLGraph and its no-backtracking line graph per molecule.

    For each molecule, atoms closer than ``self.cutoff`` are connected by
    an edge (self-loops removed), coordinates ``R`` and atomic numbers
    ``Z`` are attached as node features, and any user-supplied
    ``self.edge_funcs`` are applied to edges.

    Returns:
        tuple[list, list]: ``(graphs, line_graphs)`` of equal length.
    """
    num_graphs = self.label.shape[0]
    graphs = []
    line_graphs = []
    for idx in trange(num_graphs):
        n_atoms = self.N[idx]
        # All atomic coordinates of the idx-th molecular graph.
        R = self.R[self.N_cumsum[idx]:self.N_cumsum[idx + 1]]
        # Pairwise Euclidean distances between all atoms.
        dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
        # Keep all edges that don't exceed the cutoff and delete self-loops.
        # FIX: `np.bool` was removed in NumPy 1.24; use the builtin `bool`
        # (it has always been an alias for the builtin anyway).
        adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(n_atoms, dtype=bool)
        adj = adj.tocoo()
        u, v = torch.tensor(adj.row), torch.tensor(adj.col)
        g = dgl_graph((u, v))
        g.ndata['R'] = torch.tensor(R, dtype=torch.float32)
        g.ndata['Z'] = torch.tensor(
            self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]],
            dtype=torch.long)
        # Apply user-defined edge featurizers, if any.
        if self.edge_funcs is not None:
            for func in self.edge_funcs:
                g.apply_edges(func)
        graphs.append(g)
        l_g = dgl.line_graph(g, backtracking=False)
        line_graphs.append(l_g)
    return graphs, line_graphs
def forward(self, mol_trees):
    """Batch the junction trees and run belief propagation over them."""
    tree_batch = batch(mol_trees)
    # The line graph (without backtracking edges) lets messages flow
    # edge-to-edge during belief propagation.
    tree_batch_lg = line_graph(tree_batch, no_backtracking=True)
    return self.run(tree_batch, tree_batch_lg)
def forward(self, mol_graph):
    """Directed-MPNN encoder.

    Runs ``depth - 1`` rounds of message passing on the line graph of
    ``mol_graph``, aggregates the directed-edge messages onto nodes, and
    returns a mean-pooled graph representation.
    """
    mol_graph = mol_graph.local_var()
    line_g = dgl.line_graph(mol_graph, backtracking=False)

    # Seed each directed edge's message from its bond features.
    edge_input = self.W_i(mol_graph.edata['x'])
    line_g.ndata['msg_input'] = edge_input
    line_g.ndata['msg'] = F.relu(edge_input)

    # Message passing over the line graph (edge-to-edge).
    for _ in range(self.depth - 1):
        line_g.update_all(fn.copy_u('msg', 'msg'),
                          fn.sum('msg', 'nei_msg'))
        line_g.ndata['msg'] = F.relu(
            edge_input + self.W_h(line_g.ndata['nei_msg']))

    # Aggregate final edge messages onto nodes of the raw graph.
    mol_graph.edata['msg'] = line_g.ndata['msg']
    mol_graph.update_all(fn.copy_e('msg', 'msg'),
                         fn.sum('msg', 'nei_msg'))
    node_input = torch.cat(
        [mol_graph.ndata['x'], mol_graph.ndata['nei_msg']], dim=1)
    mol_graph.ndata['atom_hiddens'] = self.W_o(node_input)

    # Graph-level readout.
    return dgl.mean_nodes(mol_graph, 'atom_hiddens')
def forward(self, mol_tree_batch):
    """Run tree message passing on an already-batched junction tree.

    Builds the no-backtracking line graph so belief propagation can pass
    messages edge-to-edge, then delegates to ``self.run``.
    """
    # Build line graph to prepare for belief propagation
    mol_tree_batch_lg = dgl.line_graph(mol_tree_batch, backtracking=False,
                                       shared=True)
    # HACK: expose the tree's edge frames as the line graph's node frames
    # so feature updates made on either view are visible in the other.
    mol_tree_batch_lg._node_frames = mol_tree_batch._edge_frames
    return self.run(mol_tree_batch, mol_tree_batch_lg)
def forward(self, mol_trees, tree_vec):
    """Training forward pass: compute the prediction loss for the decoder
    given the ground-truth junction trees and the tree latent vectors."""
    tree_batch = batch(mol_trees)
    tree_batch_lg = line_graph(tree_batch, no_backtracking=True)
    return self.run(tree_batch, tree_batch_lg, len(mol_trees), tree_vec)
def forward(self, mol_graph):
    """Encode a (batched) molecular graph and return the per-graph mean of
    node features 'h' after message passing."""
    mol_line_graph = dgl.line_graph(mol_graph, backtracking=False,
                                    shared=True)
    # HACK: share the raw graph's edge frames as the line graph's node
    # frames so updates in `self.run` are visible in both views.
    mol_line_graph._node_frames = mol_graph._edge_frames
    mol_graph = self.run(mol_graph, mol_line_graph)
    # TODO: replace with unbatch or readout
    g_repr = mean_nodes(mol_graph, 'h')
    return g_repr
def forward(self, mol_trees):
    """Encode a batch of junction trees, moving the batch to GPU when
    available (unless the NOCUDA environment variable is set)."""
    tree_batch = batch(mol_trees)

    # NOCUDA acts as an escape hatch for forcing CPU-only execution.
    use_cuda = torch.cuda.is_available() and not os.getenv('NOCUDA', None)
    if use_cuda:
        tree_batch = tree_batch.to('cuda:0')

    # Build line graph to prepare for belief propagation
    tree_batch_lg = dgl.line_graph(tree_batch, backtracking=False,
                                   shared=True)
    return self.run(tree_batch, tree_batch_lg)
def forward(self, mol_trees, tree_vec):
    """Training forward pass: compute the prediction loss for the decoder
    given the ground-truth junction trees and the tree latent vectors."""
    tree_batch = batch(mol_trees)
    tree_batch_lg = dgl.line_graph(tree_batch, backtracking=False,
                                   shared=True)
    # HACK: share edge frames with the line-graph nodes so feature updates
    # propagate between the two views.
    tree_batch_lg._node_frames = tree_batch._edge_frames
    return self.run(tree_batch, tree_batch_lg, len(mol_trees), tree_vec)
def forward(self, cand_batch, mol_tree_batch):
    """Encode candidate attachment graphs and return the per-graph mean of
    node features 'h'."""
    (cand_graphs, tree_mess_src_edges,
     tree_mess_tgt_edges, tree_mess_tgt_nodes) = cand_batch

    cand_lg = dgl.line_graph(cand_graphs, backtracking=False, shared=True)
    # HACK: expose edge features of the raw graph as node features of the
    # line graph by sharing the underlying frames.
    cand_lg._node_frames = cand_graphs._edge_frames

    cand_graphs = self.run(
        cand_graphs, cand_lg,
        tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes,
        mol_tree_batch)
    return mean_nodes(cand_graphs, 'h')
def forward(self, mol_graph):
    """Run message passing on ``mol_graph`` and return the per-graph mean
    of node features 'h', accumulating profiling counters on the module."""
    sample_count = mol_graph.batch_size
    node_count = mol_graph.number_of_nodes()
    edge_count = mol_graph.number_of_edges()

    mol_line_graph = line_graph(mol_graph, backtracking=False, shared=True)
    mol_graph = self.run(mol_graph, mol_line_graph)

    # TODO: replace with unbatch or readout
    g_repr = mean_nodes(mol_graph, 'h')

    # Profiling bookkeeping.
    self.n_samples_total += sample_count
    self.n_nodes_total += node_count
    self.n_edges_total += edge_count
    self.n_passes += 1

    return g_repr
def forward(self, tree_graphs):
    """Encode batched junction trees with a GRU over the line graph.

    Messages are propagated leaf-to-root in level order on the
    no-backtracking line graph, then read out at the roots.

    Returns:
        tuple: ``(edge_hiddens, root_vec)`` — the final per-edge hidden
        states 'h' and the transformed root representations.
    """
    device = tree_graphs.device
    # Lazily embed node word IDs if the caller hasn't already.
    if 'x' not in tree_graphs.ndata:
        tree_graphs.ndata['x'] = self.embedding(tree_graphs.ndata['wid'])
    tree_graphs.apply_edges(fn.copy_u('x', 'src_x'))
    tree_graphs = tree_graphs.local_var()
    line_tree_graphs = dgl.line_graph(tree_graphs, backtracking=False)
    line_tree_graphs.ndata.update({
        'src_x': tree_graphs.edata['src_x'],
        'src_x_r': self.W_r(tree_graphs.edata['src_x']),
        # Exploit the fact that the reduce function is a sum of incoming messages,
        # and uncomputed messages are zero vectors.
        'h': torch.zeros(line_tree_graphs.num_nodes(),
                         self.hidden_size).to(device),
        'sum_h': torch.zeros(line_tree_graphs.num_nodes(),
                             self.hidden_size).to(device),
        'sum_gated_h': torch.zeros(line_tree_graphs.num_nodes(),
                                   self.hidden_size).to(device)
    })

    # Get the ID of the root nodes, the first node of all trees
    root_ids = get_root_ids(tree_graphs)
    # Level-order traversal: each pull updates one frontier of line-graph
    # nodes (i.e. tree edges) from their already-computed predecessors.
    for eid in level_order(tree_graphs,
                           root_ids.to(dtype=tree_graphs.idtype)):
        line_tree_graphs.pull(v=eid,
                              message_func=fn.copy_u('h', 'h_nei'),
                              reduce_func=fn.sum('h_nei', 'sum_h'))
        line_tree_graphs.pull(v=eid,
                              message_func=self.gru_message,
                              reduce_func=fn.sum('m', 'sum_gated_h'))
        line_tree_graphs.apply_nodes(self.gru_update, v=eid)

    # Readout: copy edge hidden states back and sum the incoming-edge
    # messages at the roots.
    tree_graphs.ndata['h'] = torch.zeros(tree_graphs.num_nodes(),
                                         self.hidden_size).to(device)
    tree_graphs.edata['h'] = line_tree_graphs.ndata['h']
    root_ids = root_ids.to(device)
    tree_graphs.pull(v=root_ids.to(dtype=tree_graphs.idtype),
                     message_func=fn.copy_e('h', 'm'),
                     reduce_func=fn.sum('m', 'h'))
    root_vec = torch.cat([
        tree_graphs.ndata['x'][root_ids],
        tree_graphs.ndata['h'][root_ids]
    ], dim=1)
    root_vec = self.W(root_vec)
    return tree_graphs.edata['h'], root_vec
def forward(self, cand_batch, mol_tree_batch):
    """Encode candidate attachment graphs, returning the per-graph mean of
    node features 'h' and updating the profiling counters."""
    (cand_graphs, tree_mess_src_edges,
     tree_mess_tgt_edges, tree_mess_tgt_nodes) = cand_batch
    sample_count = len(cand_graphs)

    cand_line_graph = line_graph(cand_graphs, backtracking=False,
                                 shared=True)
    node_count = cand_graphs.number_of_nodes()
    edge_count = cand_graphs.number_of_edges()

    cand_graphs = self.run(
        cand_graphs, cand_line_graph,
        tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes,
        mol_tree_batch)
    g_repr = mean_nodes(cand_graphs, 'h')

    # Profiling bookkeeping.
    self.n_samples_total += sample_count
    self.n_nodes_total += node_count
    self.n_edges_total += edge_count
    self.n_passes += 1

    return g_repr
def decode(self, mol_vec):
    """Greedily decode a junction tree from a single latent vector.

    Starting from a predicted root clique, repeatedly decides whether to
    expand (add a child clique) or backtrack, re-running tree message
    passing after each structural change.  Nodes whose expansion failed
    are kept but flagged via the 'fail' feature.

    Args:
        mol_vec: latent vector with batch dimension 1 (asserted below).

    Returns:
        tuple: ``(mol_tree, all_nodes, effective_nodes)`` — the decoded
        tree, a dict of node dicts for every node created, and the sorted
        IDs of nodes that did not fail.
    """
    assert mol_vec.shape[0] == 1

    mol_tree = DGLMolTree(None)

    # Predict the root clique from the latent vector alone.
    init_hidden = cuda(torch.zeros(1, self.hidden_size))
    root_hidden = torch.cat([init_hidden, mol_vec], 1)
    root_hidden = F.relu(self.W(root_hidden))
    root_score = self.W_o(root_hidden)
    _, root_wid = torch.max(root_score, 1)
    root_wid = root_wid.view(1)

    mol_tree.add_nodes(1)  # root
    mol_tree.nodes[0].data['wid'] = root_wid
    mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
    mol_tree.nodes[0].data['h'] = init_hidden
    mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
    mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
        self.vocab.get_smiles(root_wid))

    # DFS stack of (node id, slots); `trace` appears unused here.
    stack, trace = [], []
    stack.append((0, self.vocab.get_slots(root_wid)))

    all_nodes = {0: root_node_dict}
    first = True
    new_node_id = 0
    new_edge_id = 0

    for step in range(MAX_DECODE_LEN):
        u, u_slots = stack[-1]
        udata = mol_tree.nodes[u].data
        x = udata['x']
        h = udata['h']

        # Predict stop: probability < 0.5 means "expand", otherwise
        # backtrack towards the root.
        p_input = torch.cat([x, h, mol_vec], 1)
        p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
        backtrack = (p_score.item() < 0.5)

        if not backtrack:
            # Predict next clique. Note that the prediction may fail due
            # to lack of assemblable components
            mol_tree.add_nodes(1)
            new_node_id += 1
            v = new_node_id
            mol_tree.add_edges(u, v)
            uv = new_edge_id
            new_edge_id += 1

            # Initialize the edge feature schema lazily on the first edge
            # so later edges get zero-filled defaults for these keys.
            if first:
                mol_tree.edata.update({
                    's': cuda(torch.zeros(1, self.hidden_size)),
                    'm': cuda(torch.zeros(1, self.hidden_size)),
                    'r': cuda(torch.zeros(1, self.hidden_size)),
                    'z': cuda(torch.zeros(1, self.hidden_size)),
                    'src_x': cuda(torch.zeros(1, self.hidden_size)),
                    'dst_x': cuda(torch.zeros(1, self.hidden_size)),
                    'rm': cuda(torch.zeros(1, self.hidden_size)),
                    'accum_rm': cuda(torch.zeros(1, self.hidden_size)),
                })
                first = False

            mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
            # keeping dst_x 0 is fine as h on new edge doesn't depend on that.

            # DGL doesn't dynamically maintain a line graph.
            mol_tree_lg = dgl.line_graph(mol_tree, backtracking=False,
                                         shared=True)
            mol_tree_lg.pull(
                uv,
                dec_tree_edge_msg,
                dec_tree_edge_reduce,
                self.dec_tree_edge_update.update_zm,
            )
            mol_tree.pull(
                v,
                dec_tree_node_msg,
                dec_tree_node_reduce,
            )

            # Predict the child's clique label from its fresh hidden state;
            # try the top-5 candidates until one is chemically assemblable.
            vdata = mol_tree.nodes[v].data
            h_v = vdata['h']
            q_input = torch.cat([h_v, mol_vec], 1)
            q_score = torch.softmax(
                self.W_o(torch.relu(self.W(q_input))), -1)
            _, sort_wid = torch.sort(q_score, 1, descending=True)
            sort_wid = sort_wid.squeeze()

            next_wid = None
            for wid in sort_wid.tolist()[:5]:
                slots = self.vocab.get_slots(wid)
                cand_node_dict = create_node_dict(
                    self.vocab.get_smiles(wid))
                if (have_slots(u_slots, slots) and
                        can_assemble(mol_tree, u, cand_node_dict)):
                    next_wid = wid
                    next_slots = slots
                    next_node_dict = cand_node_dict
                    break

            if next_wid is None:
                # Failed adding an actual children; v is a spurious node
                # and we mark it.
                vdata['fail'] = cuda(torch.tensor([1]))
                backtrack = True
            else:
                next_wid = cuda(torch.tensor([next_wid]))
                vdata['wid'] = next_wid
                vdata['x'] = self.embedding(next_wid)
                mol_tree.nodes_dict[v] = next_node_dict
                all_nodes[v] = next_node_dict
                stack.append((v, next_slots))
                # Add the reverse edge v -> u so messages can later flow
                # back towards the root.
                mol_tree.add_edge(v, u)
                vu = new_edge_id
                new_edge_id += 1
                mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data[
                    'x']
                mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data[
                    'x']
                mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data[
                    'x']

                # DGL doesn't dynamically maintain a line graph.
                mol_tree_lg = dgl.line_graph(mol_tree, backtracking=False,
                                             shared=True)
                mol_tree_lg.apply_nodes(
                    self.dec_tree_edge_update.update_r, uv)

        if backtrack:
            if len(stack) == 1:
                break  # At root, terminate

            pu, _ = stack[-2]
            u_pu = mol_tree.edge_id(u, pu)

            # Propagate the message along the edge back to the parent,
            # then update the parent's node state.
            mol_tree_lg.pull(
                u_pu,
                dec_tree_edge_msg,
                dec_tree_edge_reduce,
                self.dec_tree_edge_update,
            )
            mol_tree.pull(
                pu,
                dec_tree_node_msg,
                dec_tree_node_reduce,
            )
            stack.pop()

    # Keep only nodes whose expansion did not fail, in sorted ID order.
    effective_nodes = mol_tree.filter_nodes(
        lambda nodes: nodes.data['fail'] != 1)
    effective_nodes, _ = torch.sort(effective_nodes)
    return mol_tree, all_nodes, effective_nodes
def line_graph(g, backtracking=True, shared=False):
    """Return the line graph of ``g`` with ``g``'s edge features copied
    onto the line-graph nodes.

    Args:
        g: the input DGL graph.
        backtracking: passed through to :func:`dgl.line_graph`; ``False``
            excludes reverse-direction (backtracking) successors.
        shared: passed through to :func:`dgl.line_graph`.

    Returns:
        The line graph, whose ``ndata`` mirrors ``g.edata``.
    """
    # Keyword arguments make the two boolean flags unambiguous at the
    # call site of dgl.line_graph.
    g2 = dgl.line_graph(g, backtracking=backtracking, shared=shared)
    # Each line-graph node corresponds to an edge of ``g``; expose the
    # edge features as node features.
    g2.ndata.update(g.edata)
    return g2
def forward(self, tree_graphs, tree_vec):
    """Teacher-forced decoder training pass over batched junction trees.

    Traverses each tree in DFS order, running GRU message passing on the
    no-backtracking line graph, and accumulates two objectives: clique
    label prediction and stop (expand-vs-backtrack) prediction.

    Args:
        tree_graphs: batched ground-truth junction trees.
        tree_vec: per-tree latent vectors, shape (batch_size, latent_dim).

    Returns:
        tuple: ``(pred_loss, stop_loss, pred_acc, stop_acc)`` with the
        losses normalized by batch size and the accuracies as floats.
    """
    device = tree_vec.device
    batch_size = tree_graphs.batch_size
    root_ids = get_root_ids(tree_graphs)

    # Lazily compute node embeddings / edge source features if absent.
    if 'x' not in tree_graphs.ndata:
        tree_graphs.ndata['x'] = self.embedding(tree_graphs.ndata['wid'])
    if 'src_x' not in tree_graphs.edata:
        tree_graphs.apply_edges(fn.copy_u('x', 'src_x'))

    tree_graphs = tree_graphs.local_var()
    # Record each edge's destination word ID: it is the prediction target
    # when expanding along that edge.
    tree_graphs.apply_edges(func=lambda edges: {'dst_wid': edges.dst['wid']})
    line_tree_graphs = dgl.line_graph(tree_graphs, backtracking=False,
                                      shared=True)
    line_num_nodes = line_tree_graphs.num_nodes()
    line_tree_graphs.ndata.update({
        'src_x_r': self.W_r(line_tree_graphs.ndata['src_x']),
        # Exploit the fact that the reduce function is a sum of incoming messages,
        # and uncomputed messages are zero vectors.
        'h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
        'vec': dgl.broadcast_edges(tree_graphs, tree_vec),
        'sum_h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
        'sum_gated_h': torch.zeros(line_num_nodes,
                                   self.hidden_size).to(device)
    })

    # input tensors for stop prediction (p) and label prediction (q)
    pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
    stop_hiddens, stop_targets = [], []

    # Predict root
    pred_hiddens.append(torch.zeros(batch_size, self.hidden_size).to(device))
    pred_targets.append(tree_graphs.ndata['wid'][root_ids.to(device)])
    pred_mol_vecs.append(tree_vec)

    # Traverse the tree and predict on children
    for eid, p in dfs_order(tree_graphs,
                            root_ids.to(dtype=tree_graphs.idtype)):
        eid = eid.to(device)
        p = p.to(device=device, dtype=tree_graphs.idtype)
        # Message passing excluding the target
        line_tree_graphs.pull(v=eid,
                              message_func=fn.copy_u('h', 'h_nei'),
                              reduce_func=fn.sum('h_nei', 'sum_h'))
        line_tree_graphs.pull(v=eid,
                              message_func=self.gru_message,
                              reduce_func=fn.sum('m', 'sum_gated_h'))
        line_tree_graphs.apply_nodes(self.gru_update, v=eid)

        # Node aggregation including the target
        # By construction, the edges of the raw graph follow the order of
        # (i1, j1), (j1, i1), (i2, j2), (j2, i2), ... The order of the nodes
        # in the line graph corresponds to the order of the edges in the raw graph.
        eid = eid.long()
        # XOR with 1 flips between the paired forward/reverse edge IDs
        # (2k <-> 2k+1), giving the reverse edge of each eid.
        reverse_eid = torch.bitwise_xor(eid, torch.tensor(1).to(device))
        cur_o = line_tree_graphs.ndata['sum_h'][eid] + \
            line_tree_graphs.ndata['h'][reverse_eid]

        # Gather targets: p == 0 marks forward (expansion) edges, which
        # get a clique-label prediction; stop target is 1 - p.
        mask = (p == torch.tensor(0).to(device))
        pred_list = eid[mask]
        stop_target = torch.tensor(1).to(device) - p

        # Hidden states for stop prediction
        stop_hidden = torch.cat([line_tree_graphs.ndata['src_x'][eid],
                                 cur_o,
                                 line_tree_graphs.ndata['vec'][eid]],
                                dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend(stop_target)

        # Hidden states for clique prediction
        if len(pred_list) > 0:
            pred_mol_vecs.append(line_tree_graphs.ndata['vec'][pred_list])
            pred_hiddens.append(line_tree_graphs.ndata['h'][pred_list])
            pred_targets.append(line_tree_graphs.ndata['dst_wid'][pred_list])

    # Last stop at root
    root_ids = root_ids.to(device)
    cur_x = tree_graphs.ndata['x'][root_ids]
    tree_graphs.edata['h'] = line_tree_graphs.ndata['h']
    tree_graphs.pull(v=root_ids.to(dtype=tree_graphs.idtype),
                     message_func=fn.copy_e('h', 'm'),
                     reduce_func=fn.sum('m', 'cur_o'))
    stop_hidden = torch.cat([cur_x,
                             tree_graphs.ndata['cur_o'][root_ids],
                             tree_vec], dim=1)
    stop_hiddens.append(stop_hidden)
    stop_targets.extend(torch.zeros(batch_size).to(device))

    # Predict next clique
    pred_hiddens = torch.cat(pred_hiddens, dim=0)
    pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
    pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
    pred_vecs = F.relu(self.W(pred_vecs))
    pred_scores = self.W_o(pred_vecs)
    pred_targets = torch.cat(pred_targets, dim=0)
    pred_loss = self.pred_loss(pred_scores, pred_targets) / batch_size
    _, preds = torch.max(pred_scores, dim=1)
    pred_acc = torch.eq(preds, pred_targets).float()
    pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

    # Predict stop
    stop_hiddens = torch.cat(stop_hiddens, dim=0)
    stop_vecs = F.relu(self.U(stop_hiddens))
    stop_scores = self.U_s(stop_vecs).squeeze()
    stop_targets = torch.Tensor(stop_targets).to(device)
    stop_loss = self.stop_loss(stop_scores, stop_targets) / batch_size
    # A non-negative logit counts as a "continue" decision.
    stops = torch.ge(stop_scores, 0).float()
    stop_acc = torch.eq(stops, stop_targets).float()
    stop_acc = torch.sum(stop_acc) / stop_targets.nelement()
    return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()