def stereo(self, mol_batch, mol_vec):
    """Score stereoisomer candidates against the molecule latent vectors.

    Candidate graphs are encoded with the MPN, projected by ``G_mean``, and
    compared to the corresponding molecule vector via cosine similarity.
    A per-molecule cross-entropy loss ranks the true stereoisomer.

    Returns a ``(loss, accuracy)`` pair.
    """
    cand_graphs = mol_batch['stereo_cand_graph_batch']
    cand_batch_idx = mol_batch['stereo_cand_batch_idx']
    cand_labels = mol_batch['stereo_cand_labels']
    cand_lengths = mol_batch['stereo_cand_lengths']

    # No molecule in this batch has more than one stereoisomer: nothing to
    # rank, so report zero loss and perfect accuracy.
    if not cand_labels:
        return cuda(torch.tensor(0.)), 1.

    idx = cuda(torch.LongTensor(cand_batch_idx))
    cand_vecs = self.G_mean(self.mpn(cand_graphs))
    target_vecs = mol_vec[idx]
    scores = F.cosine_similarity(cand_vecs, target_vecs)

    offset = 0
    n_correct = 0
    losses = []
    # `scores` is the concatenation of per-molecule candidate segments;
    # walk it with a running offset.
    for true_idx, seg_len in zip(cand_labels, cand_lengths):
        segment = scores[offset:offset + seg_len]
        offset += seg_len
        if segment.data[true_idx].item() >= segment.max().item():
            n_correct += 1
        target = cuda(torch.LongTensor([true_idx]))
        losses.append(
            F.cross_entropy(segment.view(1, -1), target, size_average=False))

    total_loss = sum(losses) / len(cand_labels)
    return total_loss, n_correct / len(cand_labels)
def forward(self, nodes):
    """Node-update function: h = ReLU(W [x ; m]).

    Concatenates the node feature ``x`` with the accumulated message ``m``
    and applies a linear layer followed by ReLU.

    Fixes vs. original:
    - catch only ``KeyError`` (the failure mode when the ``'m'`` field is
      absent) instead of a bare ``except:`` that swallowed everything;
    - build the zero fallback with ``x.new_zeros`` so it lives on the same
      device/dtype as ``x`` (the old ``torch.cuda.FloatTensor`` crashed on
      CPU-only runs) and matches the node batch size instead of a
      hard-coded single row.
    """
    x = nodes.data['x']
    try:
        m = nodes.data['m']
    except KeyError:
        # No accumulated message field yet; fall back to zeros.
        m = x.new_zeros(x.shape[0], self.hidden_size)
    return {
        'h': cuda(torch.relu(self.W(cuda(torch.cat([x, m], 1))))),
    }
def decode(self, tree_vec, mol_vec):
    """Decode latent vectors into a molecule (SMILES string).

    Pipeline: decode a junction tree from ``tree_vec``, re-encode the
    effective (non-failed) subtree with the tree encoder, assemble atoms
    via DFS guided by ``mol_vec``, then pick the best-scoring stereoisomer.
    Returns a SMILES string, or None when assembly/sanitization fails.
    """
    mol_tree, nodes_dict, effective_nodes = self.decoder.decode(tree_vec)
    effective_nodes_list = effective_nodes.tolist()
    # Keep only the dicts of nodes that were not marked as failed.
    nodes_dict = [nodes_dict[v] for v in effective_nodes_list]

    # Re-number nodes; 'nid' is 1-based and used for atom-map labels.
    for i, (node_id, node) in enumerate(zip(effective_nodes_list, nodes_dict)):
        node['idx'] = i
        node['nid'] = i + 1
        node['is_leaf'] = True
        if mol_tree.in_degree(node_id) > 1:
            node['is_leaf'] = False
            set_atommap(node['mol'], node['nid'])

    # Re-run tree message passing on the effective subtree only.
    mol_tree_sg = mol_tree.subgraph(effective_nodes)
    mol_tree_sg.copy_from_parent()
    mol_tree_msg, _ = self.jtnn([mol_tree_sg])
    mol_tree_msg = unbatch(mol_tree_msg)[0]
    mol_tree_msg.nodes_dict = nodes_dict

    # Start assembly from the root clique; global_amap maps clique-local
    # atom indices to atom indices in the growing molecule (index 0 unused).
    cur_mol = copy_edit_mol(nodes_dict[0]['mol'])
    global_amap = [{}] + [{} for node in nodes_dict]
    global_amap[1] = {atom.GetIdx(): atom.GetIdx()
                      for atom in cur_mol.GetAtoms()}

    cur_mol = self.dfs_assemble(
        mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None)
    if cur_mol is None:
        return None

    # Round-trip through SMILES to sanitize the assembled molecule.
    cur_mol = cur_mol.GetMol()
    set_atommap(cur_mol)
    cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
    if cur_mol is None:
        return None

    smiles2D = Chem.MolToSmiles(cur_mol)
    stereo_cands = decode_stereo(smiles2D)
    if len(stereo_cands) == 1:
        return stereo_cands[0]

    # Several stereoisomers: encode each and pick the one whose vector is
    # most similar (cosine) to mol_vec.
    stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
    stereo_cand_graphs, atom_x, bond_x = \
        zip(*stereo_graphs)
    stereo_cand_graphs = batch(stereo_cand_graphs)
    atom_x = cuda(torch.cat(atom_x))
    bond_x = cuda(torch.cat(bond_x))
    stereo_cand_graphs.ndata['x'] = atom_x
    stereo_cand_graphs.edata['x'] = bond_x
    # src_x is zero-initialized here; the MPN fills it via apply_edges.
    stereo_cand_graphs.edata['src_x'] = atom_x.new(
        bond_x.shape[0], atom_x.shape[1]).zero_()
    stereo_vecs = self.mpn(stereo_cand_graphs)
    stereo_vecs = self.G_mean(stereo_vecs)
    scores = F.cosine_similarity(stereo_vecs, mol_vec)
    _, max_id = scores.max(0)
    return stereo_cands[max_id.item()]
def sample(self, tree_vec, mol_vec, e1=None, e2=None):
    """Reparameterized sampling of the tree and graph latent codes.

    Mean/log-variance heads are applied to each encoding; log-variance is
    forced non-positive via ``-|.|``. Optional ``e1``/``e2`` override the
    Gaussian noise (useful for deterministic decoding).

    Returns ``(tree_sample, mol_sample, z_mean, z_log_var)``.
    """
    t_mu = cuda(self.T_mean(tree_vec))
    t_logvar = -torch.abs(self.T_var(tree_vec))
    g_mu = self.G_mean(mol_vec)
    g_logvar = -torch.abs(self.G_var(mol_vec))

    # z = mu + sigma * eps, with sigma = exp(logvar / 2).
    if e1 is None:
        e1 = cuda(torch.randn(*t_mu.shape))
    sampled_tree = t_mu + torch.exp(t_logvar / 2) * e1

    if e2 is None:
        e2 = cuda(torch.randn(*g_mu.shape))
    sampled_mol = g_mu + torch.exp(g_logvar / 2) * e2

    z_mean = torch.cat([t_mu, g_mu], 1)
    z_log_var = torch.cat([t_logvar, g_logvar], 1)
    return sampled_tree, sampled_mol, z_mean, z_log_var
def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges,
        tree_mess_tgt_edges, tree_mess_tgt_nodes, mol_tree_batch):
    """Loopy-BP message passing over candidate graphs (JTMPN).

    Initializes edge messages from [source-atom features ; bond features],
    injects tree messages (``alpha``) from ``mol_tree_batch``, then runs
    ``depth - 1`` rounds of belief propagation on the line graph before
    gathering node representations on ``cand_graphs``.
    """
    n_nodes = cand_graphs.number_of_nodes()

    # Copy source-node features onto edges; shared with the line graph.
    cand_graphs.apply_edges(
        func=lambda edges: {'src_x': edges.src['x']},
    )

    # Line-graph nodes are the original edges: concat atom + bond features.
    get_bond_features = cand_line_graph.ndata['x']
    source_features = cand_line_graph.ndata['src_x']
    features = torch.cat([source_features, get_bond_features], 1)
    msg_input = self.W_i(features)
    cand_line_graph.ndata.update({
        'msg_input': msg_input,
        'msg': torch.relu(msg_input),
        'accum_msg': torch.zeros_like(msg_input),
    })

    # Zero-initialize node accumulators on the candidate graphs.
    zero_node_state = get_bond_features.new(n_nodes, self.hidden_size).zero_()
    cand_graphs.ndata.update({
        'm': zero_node_state.clone(),
        'h': zero_node_state.clone(),
    })
    cand_graphs.edata['alpha'] = \
        cuda(torch.zeros(cand_graphs.number_of_edges(), self.hidden_size))
    cand_graphs.ndata['alpha'] = zero_node_state

    # Inject junction-tree messages into the candidate graphs.
    if tree_mess_src_edges.shape[0] > 0:
        if PAPER:
            # Paper variant: copy tree edge messages onto matching
            # candidate-graph edges directly.
            src_u, src_v = tree_mess_src_edges.unbind(1)
            tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
            alpha = mol_tree_batch.edges[src_u, src_v].data['m']
            cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
        else:
            # Alternative: scatter-add tree messages onto target nodes,
            # then broadcast each node's alpha to its outgoing edges.
            src_u, src_v = tree_mess_src_edges.unbind(1)
            alpha = mol_tree_batch.edges[src_u, src_v].data['m']
            node_idx = (tree_mess_tgt_nodes.to(
                device=zero_node_state.device)[:, None].expand_as(alpha))
            node_alpha = zero_node_state.clone().scatter_add(
                0, node_idx, alpha)
            cand_graphs.ndata['alpha'] = node_alpha
            cand_graphs.apply_edges(
                func=lambda edges: {'alpha': edges.src['alpha']},
            )

    # depth - 1 rounds of loopy BP on the line graph.
    for i in range(self.depth - 1):
        cand_line_graph.update_all(
            mpn_loopy_bp_msg,
            mpn_loopy_bp_reduce,
            self.loopy_bp_updater,
        )

    # Final gather: aggregate edge messages into node representations.
    cand_graphs.update_all(
        mpn_gather_msg,
        mpn_gather_reduce,
        self.gather_updater,
    )

    return cand_graphs
def encode(self, mol_batch):
    """Encode a batch: molecular graphs via the MPN, junction trees via JTNN.

    Returns ``(mol_tree_batch, tree_vec, mol_vec)``. Also accumulates
    node/edge counters used for profiling.
    """
    graph_batch = mol_batch['mol_graph_batch']
    trees = mol_batch['mol_trees']

    graph_repr = cuda(self.mpn(graph_batch))
    tree_batch, tree_repr = self.jtnn(trees)

    # Bookkeeping: running totals over all encode() calls.
    self.n_nodes_total += graph_batch.number_of_nodes()
    self.n_edges_total += graph_batch.number_of_edges()
    self.n_tree_nodes_total += sum(t.number_of_nodes() for t in trees)
    self.n_passes += 1

    return tree_batch, tree_repr, graph_repr
def assm(self, mol_batch, mol_tree_batch, mol_vec):
    """Graph-assembly loss: score attachment candidates per tree node.

    Candidate subgraphs are encoded with the JTMPN, projected by ``G_mean``,
    and scored by a dot product against the molecule latent vector. Returns
    ``(loss, accuracy)`` where accuracy counts nodes whose true candidate
    achieved the maximum score.
    """
    cands = [mol_batch['cand_graph_batch'],
             mol_batch['tree_mess_src_e'],
             mol_batch['tree_mess_tgt_e'],
             mol_batch['tree_mess_tgt_n']]
    cand_vec = self.jtmpn(cands, mol_tree_batch)
    cand_vec = self.G_mean(cand_vec)

    # Map each candidate back to its molecule's latent vector.
    batch_idx = cuda(torch.LongTensor(mol_batch['cand_batch_idx']))
    mol_vec = mol_vec[batch_idx]

    # Batched dot product via (N,1,D) @ (N,D,1).
    mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
    cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
    scores = (mol_vec @ cand_vec)[:, 0, 0]

    cnt, tot, acc = 0, 0, 0
    all_loss = []
    for i, mol_tree in enumerate(mol_batch['mol_trees']):
        # Only non-leaf nodes with more than one candidate contribute.
        comp_nodes = [node_id for node_id, node in mol_tree.nodes_dict.items()
                      if len(node['cands']) > 1 and not node['is_leaf']]
        cnt += len(comp_nodes)
        # segmented accuracy and cross entropy
        for node_id in comp_nodes:
            node = mol_tree.nodes_dict[node_id]
            label = node['cands'].index(node['label'])
            ncand = len(node['cands'])
            # `scores` is a concatenation of per-node candidate segments;
            # `tot` is the running offset into it.
            cur_score = scores[tot:tot+ncand]
            tot += ncand

            if cur_score[label].item() >= cur_score.max().item():
                acc += 1

            label = cuda(torch.LongTensor([label]))
            all_loss.append(
                F.cross_entropy(cur_score.view(1, -1), label,
                                size_average=False))

    # NOTE(review): raises ZeroDivisionError if no tree has a competitive
    # node (cnt == 0) — presumably never happens with real data; verify.
    all_loss = sum(all_loss) / len(mol_batch['mol_trees'])
    return all_loss, acc / cnt
def run(self, mol_tree_batch, mol_tree_batch_lg):
    """Tree-encoder message passing over a batched junction tree.

    Initializes node/edge fields, propagates messages level by level from
    leaves toward the roots on the line graph, then gathers per-node
    hidden states and returns the root vectors.
    """
    # Since tree roots are designated to 0. In the batched graph we can
    # simply find the corresponding node ID by looking at node_offset
    node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
    root_ids = node_offset[:-1]
    n_nodes = mol_tree_batch.number_of_nodes()
    n_edges = mol_tree_batch.number_of_edges()

    # Assign structure embeddings to tree nodes
    x = cuda(self.embedding(cuda(mol_tree_batch.ndata['wid'])))
    # NOTE(review): hard-codes CUDA allocation — fails on CPU-only runs.
    h = torch.cuda.FloatTensor(n_nodes, self.hidden_size).fill_(0)
    mol_tree_batch.ndata.update({
        'x': x,
        'h': h,
    })

    # Initialize the intermediate variables according to Eq (4)-(8).
    # Also initialize the src_x and dst_x fields.
    # TODO: context?
    mol_tree_batch.edata.update({
        's': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'm': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'r': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'z': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'src_x': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'dst_x': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'rm': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'accum_rm': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
    })

    # Send the source/destination node features to edges
    mol_tree_batch.apply_edges(
        func=lambda edges: {
            'src_x': edges.src['x'],
            'dst_x': edges.dst['x'],
        },
    )

    # Message passing
    # I exploited the fact that the reduce function is a sum of incoming
    # messages, and the uncomputed messages are zero vectors. Essentially,
    # we can always compute s_ij as the sum of incoming m_ij, no matter
    # if m_ij is actually computed or not.
    for eid in level_order(mol_tree_batch, root_ids):
        #eid = mol_tree_batch.edge_ids(u, v)
        mol_tree_batch_lg.pull(
            eid,
            enc_tree_msg,
            enc_tree_reduce,
            self.enc_tree_update,
        )

    # Readout
    mol_tree_batch.update_all(
        enc_tree_gather_msg,
        enc_tree_gather_reduce,
        self.enc_tree_gather_update,
    )

    root_vecs = mol_tree_batch.nodes[root_ids].data['h']

    return mol_tree_batch, root_vecs
def decode(self, mol_vec):
    """Autoregressively decode a junction tree from a single latent vector.

    At each step the model predicts (a) whether to expand the current node
    or backtrack (stop prediction via U/U_s), and (b) which vocabulary
    clique to attach next (label prediction via W/W_o), rejecting labels
    that cannot be chemically assembled. Spurious nodes are flagged with
    ``fail = 1`` and excluded from ``effective_nodes``.

    Returns ``(mol_tree, all_nodes, effective_nodes)``.
    """
    assert mol_vec.shape[0] == 1

    mol_tree = MolTree(None)

    # Predict the root clique from [zeros ; mol_vec].
    init_hidden = torch.cuda.FloatTensor(1, self.hidden_size).fill_(0)
    root_hidden = torch.cat([init_hidden, mol_vec], 1)
    root_hidden = F.relu(self.W(root_hidden))
    root_score = self.W_o(root_hidden)
    _, root_wid = torch.max(root_score, 1)
    root_wid = root_wid.view(1)

    mol_tree.add_nodes(1)   # root
    mol_tree.nodes[0].data['wid'] = root_wid
    mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
    mol_tree.nodes[0].data['h'] = init_hidden
    mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
    mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
        self.vocab.get_smiles(root_wid))

    # stack holds the DFS path as (node_id, slots); decoding ends when we
    # backtrack past the root.
    stack, trace = [], []
    stack.append((0, self.vocab.get_slots(root_wid)))

    all_nodes = {0: root_node_dict}
    h = {}
    first = True
    new_node_id = 0
    new_edge_id = 0

    for step in range(MAX_DECODE_LEN):
        u, u_slots = stack[-1]
        udata = mol_tree.nodes[u].data
        x = udata['x']
        h = udata['h']

        # Predict stop
        p_input = torch.cat([x, h, mol_vec], 1)
        p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
        backtrack = (p_score.item() < 0.5)

        if not backtrack:
            # Predict next clique.  Note that the prediction may fail due
            # to lack of assemblable components
            mol_tree.add_nodes(1)
            new_node_id += 1
            v = new_node_id
            mol_tree.add_edges(u, v)
            uv = new_edge_id
            new_edge_id += 1

            # Lazily create the edge feature schema on the first edge.
            if first:
                mol_tree.edata.update({
                    's': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                    'm': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                    'r': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                    'z': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                    'src_x': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                    'dst_x': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                    'rm': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                    'accum_rm': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                })
                first = False

            mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
            # keeping dst_x 0 is fine as h on new edge doesn't depend on that.

            # DGL doesn't dynamically maintain a line graph.
            mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True)

            # Propagate edge message u->v, then compute v's hidden state.
            mol_tree_lg.pull(
                uv,
                dec_tree_edge_msg,
                dec_tree_edge_reduce,
                self.dec_tree_edge_update.update_zm,
            )
            mol_tree.pull(
                v,
                dec_tree_node_msg,
                dec_tree_node_reduce,
            )

            vdata = mol_tree.nodes[v].data
            h_v = vdata['h']
            q_input = torch.cat([h_v, mol_vec], 1)
            q_score = torch.softmax(
                self.W_o(torch.relu(self.W(q_input))), -1)
            _, sort_wid = torch.sort(q_score, 1, descending=True)
            sort_wid = sort_wid.squeeze()

            # Try the top-5 predicted cliques until one is assemblable.
            next_wid = None
            for wid in sort_wid.tolist()[:5]:
                slots = self.vocab.get_slots(wid)
                cand_node_dict = create_node_dict(self.vocab.get_smiles(wid))
                if (have_slots(u_slots, slots) and
                        can_assemble(mol_tree, u, cand_node_dict)):
                    next_wid = wid
                    next_slots = slots
                    next_node_dict = cand_node_dict
                    break

            if next_wid is None:
                # Failed adding an actual children; v is a spurious node
                # and we mark it.
                vdata['fail'] = cuda(torch.tensor([1]))
                backtrack = True
            else:
                next_wid = cuda(torch.tensor([next_wid]))
                vdata['wid'] = next_wid
                vdata['x'] = self.embedding(next_wid)
                mol_tree.nodes_dict[v] = next_node_dict
                all_nodes[v] = next_node_dict
                stack.append((v, next_slots))
                # Add the reverse edge v->u for backtracking messages.
                mol_tree.add_edge(v, u)
                vu = new_edge_id
                new_edge_id += 1
                mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x']
                mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x']
                mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x']

                # DGL doesn't dynamically maintain a line graph.
                mol_tree_lg = mol_tree.line_graph(
                    backtracking=False, shared=True)
                mol_tree_lg.apply_nodes(
                    self.dec_tree_edge_update.update_r,
                    uv
                )

        if backtrack:
            if len(stack) == 1:
                break   # At root, terminate

            pu, _ = stack[-2]
            u_pu = mol_tree.edge_id(u, pu)

            # Send the backtracking message u->parent and update the
            # parent's hidden state.
            mol_tree_lg.pull(
                u_pu,
                dec_tree_edge_msg,
                dec_tree_edge_reduce,
                self.dec_tree_edge_update,
            )
            mol_tree.pull(
                pu,
                dec_tree_node_msg,
                dec_tree_node_reduce,
            )
            stack.pop()

    effective_nodes = mol_tree.filter_nodes(
        lambda nodes: nodes.data['fail'] != 1)
    effective_nodes, _ = torch.sort(effective_nodes)
    return mol_tree, all_nodes, effective_nodes
def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
    """Teacher-forced decoder training pass over a batched junction tree.

    Traverses each tree in DFS order, collecting inputs/targets for the
    stop predictor (p: expand vs. backtrack) and the label predictor
    (q: which clique), then computes both losses and accuracies in one
    batched pass.  Returns ``(q_loss, p_loss, q_acc, p_acc)``.
    """
    times = []
    times.append((116, time.time()))
    # Tree roots are node 0 of each tree; recover their batched IDs from
    # the cumulative node counts.
    node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
    root_ids = node_offset[:-1]
    n_nodes = mol_tree_batch.number_of_nodes()
    n_edges = mol_tree_batch.number_of_edges()
    times.append((122, time.time()))

    mol_tree_batch.ndata.update({
        'x': self.embedding(mol_tree_batch.ndata['wid']),
        'h': torch.cuda.FloatTensor(n_nodes, self.hidden_size).fill_(0),
        'new': torch.cuda.ByteTensor(n_nodes).fill_(1)  # whether it's newly generated node
    })
    times.append((129, time.time()))

    mol_tree_batch.edata.update({
        's': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'm': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'r': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'z': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'src_x': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'dst_x': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'rm': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        'accum_rm': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0)
    })
    times.append((141, time.time()))

    mol_tree_batch.apply_edges(
        func=lambda edges: {'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
    )

    # input tensors for stop prediction (p) and label prediction (q)
    p_inputs = []
    p_targets = []
    q_inputs = []
    q_targets = []
    times.append((152, time.time()))

    # Predict root
    mol_tree_batch.pull(
        root_ids,
        dec_tree_node_msg,
        dec_tree_node_reduce,
        dec_tree_node_update,
    )
    times.append((161, time.time()))

    # Extract hidden states and store them for stop/label prediction
    h = mol_tree_batch.nodes[root_ids].data['h']
    x = mol_tree_batch.nodes[root_ids].data['x']
    p_inputs.append(torch.cat([x, h, tree_vec], 1))
    # If the out degree is 0 we don't generate any edges at all
    root_out_degrees = mol_tree_batch.out_degrees(root_ids).cuda()
    q_inputs.append(torch.cat([h, tree_vec], 1))
    q_targets.append(mol_tree_batch.nodes[root_ids].data['wid'])
    times.append((171, time.time()))

    # Traverse the tree and predict on children.  `root_out_degrees` is
    # used as a per-tree "still active" counter: it drops to -1 once a
    # tree is exhausted, so `root_out_degrees >= 0` masks active trees.
    for eid, p in dfs_order(mol_tree_batch, root_ids):
        u, v = mol_tree_batch.find_edges(eid)

        # Stop target: 1 where a tree is still expanding, flipped by `p`
        # (forward vs. backtracking edge), restricted to active trees.
        p_target_list = torch.cuda.LongTensor(root_out_degrees.shape).fill_(0)
        p_target_list[root_out_degrees > 0] = 1 - p.cuda()
        p_target_list = p_target_list[root_out_degrees >= 0]
        p_targets.append(torch.tensor(p_target_list).cuda())

        # Decrement exhausted trees below zero; decrement trees whose root
        # edge was just consumed.
        root_out_degrees -= (root_out_degrees == 0).long().cuda()
        root_out_degrees -= torch.tensor(
            np.isin(root_ids, v).astype('int64')).cuda()

        mol_tree_batch_lg.pull(
            eid,
            dec_tree_edge_msg,
            dec_tree_edge_reduce,
            self.dec_tree_edge_update,
        )

        is_new = mol_tree_batch.nodes[v].data['new']

        mol_tree_batch.pull(
            v,
            dec_tree_node_msg,
            dec_tree_node_reduce,
            dec_tree_node_update,
        )

        # Extract
        n_repr = mol_tree_batch.nodes[v].data
        h = n_repr['h']
        x = n_repr['x']
        tree_vec_set = tree_vec[root_out_degrees >= 0]
        wid = n_repr['wid']
        p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
        # Only newly generated nodes are needed for label prediction
        # NOTE: The following works since the uncomputed messages are zeros.
        q_input = torch.cat([h, tree_vec_set], 1)[is_new]
        q_target = wid[is_new]
        if q_input.shape[0] > 0:
            q_inputs.append(q_input)
            q_targets.append(q_target)

    # Final stop targets (0 = stop) for trees that finished exactly now.
    p_targets.append(
        torch.zeros((root_out_degrees == 0).sum()).long().cuda())
    times.append((214, time.time()))

    # Batch compute the stop/label prediction losses
    p_inputs = torch.cat(p_inputs, 0)
    p_targets = cuda(torch.cat(p_targets, 0))
    q_inputs = torch.cat(q_inputs, 0)
    q_targets = torch.cat(q_targets, 0)
    times.append((221, time.time()))

    q = self.W_o(torch.relu(self.W(q_inputs)))
    p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]
    times.append((225, time.time()))

    p_loss = F.binary_cross_entropy_with_logits(
        p, p_targets.float(), size_average=False
    ) / n_trees
    q_loss = F.cross_entropy(q, q_targets, size_average=False) / n_trees
    p_acc = ((p > 0).long() == p_targets).sum().float() / p_targets.shape[0]
    q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]
    times.append((233, time.time()))

    # Kept for external inspection/debugging.
    self.q_inputs = q_inputs
    self.q_targets = q_targets
    self.q = q
    self.p_inputs = p_inputs
    self.p_targets = p_targets
    self.p = p
    #print("Dec Profile:")
    #for i in range(len(times)-1):
    #    print("\t%d: %f" % (times[i][0],
    #        times[i+1][1]-times[i][1]))

    return q_loss, p_loss, q_acc, p_acc
def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol, global_amap, fa_amap,
                 cur_node_id, fa_node_id):
    """Recursively attach clique fragments to build the full molecule.

    Enumerates chemically valid attachments of the current node's children,
    scores them with the JTMPN against ``mol_vec``, and tries them in
    descending-score order, recursing into non-leaf children.  Returns the
    assembled RWMol, or None if every candidate fails.
    """
    nodes_dict = mol_tree_msg.nodes_dict
    fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
    cur_node = nodes_dict[cur_node_id]

    fa_nid = fa_node['nid'] if fa_node is not None else -1
    prev_nodes = [fa_node] if fa_node is not None else []

    # Children = successors excluding the parent we came from.
    children_node_id = [v for v in mol_tree_msg.successors(cur_node_id).tolist()
                        if nodes_dict[v]['nid'] != fa_nid]
    children = [nodes_dict[v] for v in children_node_id]
    # Order: singleton (1-atom) cliques first, then larger cliques by
    # descending atom count.
    neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
    neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(),
                       reverse=True)
    singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
    neighbors = singletons + neighbors

    # Restrict the parent amap to entries involving the current node.
    cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap
                if nid == cur_node['nid']]
    cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
    if len(cands) == 0:
        return None

    cand_smiles, cand_mols, cand_amap = list(zip(*cands))
    cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]

    # Encode candidate attachments as DGL graphs and score against mol_vec.
    cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
        tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(cands)
    cand_graphs = batch(cand_graphs)
    atom_x = cuda(atom_x)
    bond_x = cuda(bond_x)
    cand_graphs.ndata['x'] = atom_x
    cand_graphs.edata['x'] = bond_x
    # src_x is zero-initialized; the JTMPN fills it via apply_edges.
    cand_graphs.edata['src_x'] = atom_x.new(bond_x.shape[0],
                                            atom_x.shape[1]).zero_()

    cand_vecs = self.jtmpn(
        (cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges,
         tree_mess_tgt_nodes),
        mol_tree_msg,
    )
    cand_vecs = self.G_mean(cand_vecs)
    mol_vec = mol_vec.squeeze()
    scores = cand_vecs @ mol_vec

    _, cand_idx = torch.sort(scores, descending=True)

    backup_mol = Chem.RWMol(cur_mol)
    # Try candidates best-first; restore the molecule from backup each try.
    for i in range(len(cand_idx)):
        cur_mol = Chem.RWMol(backup_mol)
        pred_amap = cand_amap[cand_idx[i].item()]
        new_global_amap = copy.deepcopy(global_amap)

        # Merge the candidate's local atom map into the global map
        # (skip the parent — its atoms are already placed).
        for nei_id, ctr_atom, nei_atom in pred_amap:
            if nei_id == fa_nid:
                continue
            new_global_amap[nei_id][nei_atom] = \
                new_global_amap[cur_node['nid']][ctr_atom]

        cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
        new_mol = cur_mol.GetMol()
        # Sanitize via SMILES round-trip; invalid chemistry rejects the
        # candidate.
        new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

        if new_mol is None:
            continue

        result = True
        for nei_node_id, nei_node in zip(children_node_id, children):
            if nei_node['is_leaf']:
                continue
            cur_mol = self.dfs_assemble(
                mol_tree_msg, mol_vec, cur_mol, new_global_amap, pred_amap,
                nei_node_id, cur_node_id)
            if cur_mol is None:
                result = False
                break

        if result:
            return cur_mol

    return None