def decode_test():
    """Round-trip every SMILES on stdin through tree assembly and report mismatches."""
    mismatches = 0
    for idx, line in enumerate(sys.stdin):
        smiles = line.split()[0]
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        mol = copy_edit_mol(mol_tree.nodes[0].mol)
        # Slot 0 of the atom map is a dummy; slot i+1 belongs to tree node i.
        global_amap = [{} for _ in range(len(mol_tree.nodes) + 1)]
        global_amap[1] = {a.GetIdx(): a.GetIdx() for a in mol.GetAtoms()}
        dfs_assemble(mol, global_amap, [], mol_tree.nodes[0], None)
        # Canonicalize through an RDKit round trip before comparing.
        mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol.GetMol()))
        set_atommap(mol)
        decoded = Chem.MolToSmiles(mol)
        gold = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
        if gold != decoded:
            print(gold, decoded)
            mismatches += 1
        # Running tally printed after every molecule.
        print(mismatches, idx + 1)
def test_decode(self):
    """Assembling a recovered junction tree must reproduce the source molecule."""
    for smiles in self.__smiles:
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        root = mol_tree.get_nodes()[0]
        mol = copy_edit_mol(root.get_mol())
        # Slot 0 is a dummy; slot i+1 maps atoms of tree node i.
        amap = [{} for _ in range(len(mol_tree.get_nodes()) + 1)]
        amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in mol.GetAtoms()}
        dfs_assemble(mol, amap, [], root, None)
        # Canonicalize via an RDKit round trip before comparison.
        mol = rdkit.Chem.MolFromSmiles(rdkit.Chem.MolToSmiles(mol.GetMol()))
        set_atommap(mol)
        decoded = rdkit.Chem.MolToSmiles(mol)
        expected = rdkit.Chem.MolToSmiles(rdkit.Chem.MolFromSmiles(smiles))
        if expected != decoded:
            print(expected, decoded)
        self.assertEqual(expected, decoded)
def tree_test():
    """Print the cluster decomposition of every SMILES read from stdin."""
    for line in sys.stdin:
        smiles = line.split()[0]
        mol_tree = MolTree(smiles)
        print('-------------------------------------------')
        print(smiles)
        for node in mol_tree.nodes:
            neighbor_smiles = [nei.smiles for nei in node.neighbors]
            print(node.smiles, neighbor_smiles)
def encode_latent_mean(self, smiles_list):
    """Encode SMILES strings and return the concatenated tree/graph latent means."""
    batch = []
    for smiles in smiles_list:
        tree = MolTree(smiles)
        tree.recover()
        batch.append(tree)
    _, tree_vec, mol_vec = self.encode(batch)
    # Deterministic point estimate: use the posterior means, no sampling.
    return torch.cat([self.T_mean(tree_vec), self.G_mean(mol_vec)], dim=1)
def test_vae():
    """End-to-end check that DGLJTNNVAE matches the reference JTNNVAE.

    The two models are given the same submodule parameters (copied
    below) and the same fixed noise tensors e1/e2, so every reported
    loss and accuracy must agree.
    """
    vocab = [x.strip('\r\n ') for x in open('data/vocab.txt')]
    vocab = Vocab(vocab)
    # Reference junction trees.
    mol_batch = [MolTree(smiles) for smiles in smiles_batch]
    for mol_tree in mol_batch:
        mol_tree.recover()
        mol_tree.assemble()
    set_batch_nodeID(mol_batch, vocab)
    # DGL junction trees for the same molecules.
    nx_mol_batch = [DGLMolTree(smiles) for smiles in smiles_batch]
    for nx_mol_tree in nx_mol_batch:
        nx_mol_tree.recover()
        nx_mol_tree.assemble()
    dgl_set_batch_nodeID(nx_mol_batch, vocab)
    vae = JTNNVAE(vocab, 50, 50, 3)
    dglvae = DGLJTNNVAE(vocab, 50, 50, 3)
    # Fixed noise so both models draw identical latent samples.
    e1 = torch.randn(len(smiles_batch), 25)
    e2 = torch.randn(len(smiles_batch), 25)
    dglvae.embedding = vae.embedding
    dgljtnn, dgljtmpn, dglmpn, dgldecoder = dglvae.jtnn, dglvae.jtmpn, dglvae.mpn, dglvae.decoder
    jtnn, mpn, decoder, jtmpn = vae.jtnn, vae.mpn, vae.decoder, vae.jtmpn
    # Tree encoder / candidate-scoring weights: reference -> DGL.
    dgljtnn.enc_tree_update.W_z = jtnn.W_z
    dgljtnn.enc_tree_update.W_h = jtnn.W_h
    dgljtnn.enc_tree_update.W_r = jtnn.W_r
    dgljtnn.enc_tree_update.U_r = jtnn.U_r
    dgljtnn.enc_tree_gather_update.W = jtnn.W
    dgljtnn.embedding = jtnn.embedding
    dgljtmpn.W_i = jtmpn.W_i
    dgljtmpn.gather_updater.W_o = jtmpn.W_o
    dgljtmpn.loopy_bp_updater.W_h = jtmpn.W_h
    # Graph encoder / tree decoder weights: DGL -> reference.
    # NOTE(review): copy direction differs from the block above; either
    # direction leaves both models referencing the same modules, so the
    # comparison still holds.
    mpn.W_i = dglmpn.W_i
    mpn.W_o = dglmpn.gather_updater.W_o
    mpn.W_h = dglmpn.loopy_bp_updater.W_h
    decoder.W = dgldecoder.W
    decoder.U = dgldecoder.U
    decoder.W_o = dgldecoder.W_o
    decoder.U_s = dgldecoder.U_s
    decoder.W_z = dgldecoder.dec_tree_edge_update.W_z
    decoder.W_r = dgldecoder.dec_tree_edge_update.W_r
    decoder.U_r = dgldecoder.dec_tree_edge_update.U_r
    decoder.W_h = dgldecoder.dec_tree_edge_update.W_h
    decoder.embedding = dgldecoder.embedding
    # Latent-projection layers are shared outright.
    dglvae.T_mean = vae.T_mean
    dglvae.G_mean = vae.G_mean
    dglvae.T_var = vae.T_var
    dglvae.G_var = vae.G_var
    loss, kl_loss, wacc, tacc, aacc, sacc = vae(mol_batch, e1=e1, e2=e2)
    loss_dgl, kl_loss_dgl, wacc_dgl, tacc_dgl, aacc_dgl, sacc_dgl = dglvae(nx_mol_batch, e1=e1, e2=e2)
    assert torch.allclose(loss, loss_dgl)
    assert torch.allclose(kl_loss, kl_loss_dgl)
    assert torch.allclose(wacc, wacc_dgl)
    assert torch.allclose(tacc, tacc_dgl)
    assert np.allclose(aacc, aacc_dgl)
    assert np.allclose(sacc, sacc_dgl)
def count():
    """Tally assembly-candidate statistics over all SMILES read from stdin.

    Returns:
        tuple: ``(cnt, n)`` where ``cnt`` is the total number of
        candidates across all tree nodes and ``n`` is the total number
        of tree nodes. (Previously both totals were computed and then
        silently discarded; returning them makes the function usable.)
    """
    cnt, n = 0, 0
    for line in sys.stdin:
        smiles = line.split()[0]
        tree = MolTree(smiles)
        tree.recover()
        tree.assemble()
        for node in tree.nodes:
            cnt += len(node.cands)
        n += len(tree.nodes)
    return cnt, n
def enum_test():
    """Report tree nodes whose ground-truth label is missing from the candidates."""
    for line in sys.stdin:
        smiles = line.split()[0]
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        mol_tree.assemble()
        for node in mol_tree.nodes:
            if node.label in node.cands:
                continue
            # Label not recovered by enumeration — dump diagnostics.
            print(mol_tree.smiles)
            print(node.smiles, [nei.smiles for nei in node.neighbors])
            print(node.label, len(node.cands))
def to_vocab(line):
    """Extract the set of cluster SMILES from one tab-separated record.

    Args:
        line (str): record of CID, SMILES and label separated by tabs.

    Returns:
        set: cluster vocabulary entries for the molecule. On parse
        failure an empty set is returned (previously ``None``, which
        crashed callers that union the results; an empty set has the
        same falsiness but merges safely).
    """
    vocab = set()
    cid, smi, _label = line.strip().split('\t')
    try:
        mol = MolTree(smi, skip_stereo=True)
    except SmilesFailure:
        print('SmilesFailure with CID', cid)
        return vocab  # empty set instead of None: consistent return type
    for node in mol.nodes:
        vocab.add(node.smiles)
    return vocab
def test_enum(self):
    """test_enum.

    Each tree node's ground-truth label must appear among its
    enumerated assembly candidates.
    """
    for smiles in self.__smiles:
        tree = MolTree(smiles)
        tree.recover()
        tree.assemble()
        for node in tree.get_nodes():
            if node.get_label() not in node.get_candidates():
                print(tree.get_smiles())
                print(node.get_smiles(),
                      [x.get_smiles() for x in node.get_neighbors()])
                print(node.get_label(), len(node.get_candidates()))
            # The original only printed failures, so this test could
            # never fail; assert to actually enforce the check (the
            # sibling test_decode already asserts its condition).
            self.assertIn(node.get_label(), node.get_candidates())
def test_tree(self):
    """Every junction tree has nodes, and every node and neighbor has a SMILES."""
    for smiles in self.__smiles:
        mol_tree = MolTree(smiles)
        nodes = mol_tree.get_nodes()
        self.assertTrue(nodes)
        for node in nodes:
            self.assertTrue(node.get_smiles())
            neighbors_ok = all(nei.get_smiles() for nei in node.get_neighbors())
            self.assertTrue(neighbors_ok)
def reconstruct(self, smiles, prob_decode=False):
    """Encode a molecule, sample its latent posterior once, and decode it back."""
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    _, tree_vec, mol_vec = self.encode([mol_tree])
    half = int(self.latent_size / 2)
    tree_mean = self.T_mean(tree_vec)
    mol_mean = self.G_mean(mol_vec)
    # Following Mueller et al.: keep log-variances non-positive.
    tree_log_var = -torch.abs(self.T_var(tree_vec))
    mol_log_var = -torch.abs(self.G_var(mol_vec))
    # Reparameterization trick: mean + std * eps, one draw per component
    # (tree first, then graph — same RNG order as before).
    eps_tree = create_var(torch.randn(1, half), False)
    tree_vec = tree_mean + torch.exp(tree_log_var / 2) * eps_tree
    eps_mol = create_var(torch.randn(1, half), False)
    mol_vec = mol_mean + torch.exp(mol_log_var / 2) * eps_mol
    return self.decode(tree_vec, mol_vec, prob_decode)
def recon_eval(self, smiles):
    """Stochastically reconstruct a molecule 100 times (10 latent samples x 10 decodes).

    Args:
        smiles (str): input SMILES string.

    Returns:
        list: decoded SMILES strings (length 100), possibly containing
        duplicates or ``None`` entries — whatever ``self.decode`` yields.
    """
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    _, tree_vec, mol_vec = self.encode([mol_tree])
    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.
    # Bug fix: torch.randn requires integer sizes; the original passed
    # `self.latent_size / 2` (a float under Python 3) for the graph
    # epsilon, which raises TypeError. The sibling `reconstruct` already
    # casts with int() — do the same here.
    half = int(self.latent_size / 2)
    all_smiles = []
    for _ in range(10):
        epsilon = create_var(torch.randn(1, half), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(1, half), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
        for _ in range(10):
            new_smiles = self.decode(tree_vec, mol_vec, prob_decode=True)
            all_smiles.append(new_smiles)
    return all_smiles
def test_treedec():
    """Check the DGL tree decoder against the reference JTNN decoder.

    Both decoders share the same embedding and (after the copies below)
    the same weights, so their losses and accuracies on the same batch
    with the same random tree vectors must match.
    """
    mol_batch = [MolTree(smiles) for smiles in smiles_batch]
    for mol_tree in mol_batch:
        mol_tree.recover()
        mol_tree.assemble()
    tree_vec = torch.randn(len(mol_batch), 5)
    vocab = [x.strip('\r\n ') for x in open('data/vocab.txt')]
    vocab = Vocab(vocab)
    set_batch_nodeID(mol_batch, vocab)
    # Equivalent DGL junction trees for the same molecules.
    nx_mol_batch = [DGLMolTree(smiles) for smiles in smiles_batch]
    for nx_mol_tree in nx_mol_batch:
        nx_mol_tree.recover()
        nx_mol_tree.assemble()
    dgl_set_batch_nodeID(nx_mol_batch, vocab)
    emb = nn.Embedding(vocab.size(), 5)
    dgljtnn = DGLJTNNDecoder(vocab, 5, 5, emb)
    dgl_q_loss, dgl_p_loss, dgl_q_acc, dgl_p_acc = dgljtnn(nx_mol_batch, tree_vec)
    jtnn = JTNNDecoder(vocab, 5, 5, emb)
    # Copy the DGL decoder's parameters into the reference decoder so
    # both models compute with identical weights.
    jtnn.W = dgljtnn.W
    jtnn.U = dgljtnn.U
    jtnn.W_o = dgljtnn.W_o
    jtnn.U_s = dgljtnn.U_s
    jtnn.W_z = dgljtnn.dec_tree_edge_update.W_z
    jtnn.W_r = dgljtnn.dec_tree_edge_update.W_r
    jtnn.U_r = dgljtnn.dec_tree_edge_update.U_r
    jtnn.W_h = dgljtnn.dec_tree_edge_update.W_h
    q_loss, p_loss, q_acc, p_acc = jtnn(mol_batch, tree_vec)
    assert isclose(p_loss, dgl_p_loss)
    assert isclose(q_loss, dgl_q_loss)
    assert isclose(p_acc, dgl_p_acc)
    assert isclose(q_acc, dgl_q_acc)
import sys

from jtnn.mol_tree import MolTree

# Silence RDKit logging below CRITICAL during tree construction.
# NOTE(review): `rdkit` is referenced here but not imported in this
# visible chunk — presumably imported earlier in the file; confirm.
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)

# Test molecules: mostly stereochemistry-rich, fused-ring SMILES.
smiles = [
    "O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1",
    "O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2",
    "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3",
    "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1",
    'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br',
    'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1',
    "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34",
    "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1"
]

# Smoke check: even a single-atom molecule yields a non-empty tree.
mol_tree = MolTree("C")
assert len(mol_tree.nodes) > 0


def tree_test():
    """Print the cluster decomposition of each SMILES read from stdin."""
    for s in sys.stdin:
        s = s.split()[0]
        tree = MolTree(s)
        print('-------------------------------------------')
        print(s)
        for node in tree.nodes:
            print(node.smiles, [x.smiles for x in node.neighbors])


def decode_test():
    # NOTE(review): this definition continues beyond the end of this
    # chunk; only its opening statements are visible here.
    wrong = 0
    for tot, s in enumerate(sys.stdin):
        s = s.split()[0]
def __getitem__(self, idx):
    """Return the fully prepared junction tree and property value at ``idx``."""
    tree = MolTree(self.data[idx])
    tree.recover()
    tree.assemble()
    return tree, self.prop_data[idx]
def test_treeenc():
    """Check the DGL tree encoder and graph message passing against the
    reference JTNN implementations: with shared weights, edge messages,
    tree vectors and candidate vectors must all agree.
    """
    mol_batch = [MolTree(smiles) for smiles in smiles_batch]
    for mol_tree in mol_batch:
        mol_tree.recover()
        mol_tree.assemble()
    vocab = [x.strip('\r\n ') for x in open('data/vocab.txt')]
    vocab = Vocab(vocab)
    set_batch_nodeID(mol_batch, vocab)
    emb = nn.Embedding(vocab.size(), 5)
    jtnn = JTNNEncoder(vocab, 5, emb)
    root_batch = [mol_tree.nodes[0] for mol_tree in mol_batch]
    tree_mess, tree_vec = jtnn(root_batch)
    # Equivalent DGL junction trees for the same molecules.
    nx_mol_batch = [DGLMolTree(smiles) for smiles in smiles_batch]
    for nx_mol_tree in nx_mol_batch:
        nx_mol_tree.recover()
        nx_mol_tree.assemble()
    dgl_set_batch_nodeID(nx_mol_batch, vocab)
    dgljtnn = DGLJTNNEncoder(vocab, 5, emb)
    # Share the reference encoder's weights with the DGL encoder.
    dgljtnn.enc_tree_update.W_z = jtnn.W_z
    dgljtnn.enc_tree_update.W_h = jtnn.W_h
    dgljtnn.enc_tree_update.W_r = jtnn.W_r
    dgljtnn.enc_tree_update.U_r = jtnn.U_r
    dgljtnn.enc_tree_gather_update.W = jtnn.W
    mol_tree_batch, dgl_tree_vec = dgljtnn(nx_mol_batch)
    dgl_tree_mess = mol_tree_batch.get_e_repr()['m']
    assert dgl_tree_mess.shape[0] == len(tree_mess)
    # Compare every per-edge message; collect all mismatches before failing.
    fail = False
    for u, v in tree_mess:
        eid = mol_tree_batch.get_edge_id(u, v)
        if not allclose(tree_mess[(u, v)], dgl_tree_mess[eid]):
            fail = True
            print(u, v, tree_mess[(u, v)], dgl_tree_mess[eid][0])
    assert not fail
    assert allclose(dgl_tree_vec, tree_vec)

    # Graph decoder
    cands = []
    dglcands = []
    jtmpn = JTMPN(5, 4)
    dgljtmpn = DGLJTMPN(5, 4)
    # Share the reference JTMPN's weights with the DGL version.
    dgljtmpn.W_i = jtmpn.W_i
    dgljtmpn.gather_updater.W_o = jtmpn.W_o
    dgljtmpn.loopy_bp_updater.W_h = jtmpn.W_h
    # Gather assembly candidates; leaves and single-candidate nodes are
    # skipped in both implementations.
    for i, mol_tree in enumerate(mol_batch):
        for node in mol_tree.nodes:
            if node.is_leaf or len(node.cands) == 1:
                continue
            cands.extend([(cand, mol_tree.nodes, node)
                          for cand in node.cand_mols])
    cand_vec = jtmpn(cands, tree_mess)
    for i, mol_tree in enumerate(nx_mol_batch):
        for node_id, node in mol_tree.nodes.items():
            if node['is_leaf'] or len(node['cands']) == 1:
                continue
            dglcands.extend([
                (cand, mol_tree, node_id) for cand in node['cand_mols']
            ])
    assert len(cands) == len(dglcands)
    # Candidate lists must line up molecule-for-molecule.
    for item, dglitem in zip(cands, dglcands):
        assert Chem.MolToSmiles(item[0]) == Chem.MolToSmiles(dglitem[0])
    dgl_cand_vec = dgljtmpn(dglcands, mol_tree_batch)
    # TODO: add check. Seems that the original implementation has a bug
    assert allclose(cand_vec, dgl_cand_vec)
def optimize(self, smiles, sim_cutoff, lr=2.0, num_iter=20):
    """Gradient-ascent optimization of a molecule in latent space.

    Starting from the latent mean of ``smiles``, take ``num_iter``
    gradient steps of size ``lr`` on the property predictor
    ``self.propNN``, then binary-search the visited vectors for the
    furthest step whose decoded molecule still has Tanimoto similarity
    >= ``sim_cutoff`` to the input.

    Returns:
        tuple: ``(new_smiles, sim)`` if an improved molecule within the
        similarity cutoff is found, else ``(smiles, 1.0)``.
    """
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    _, tree_vec, mol_vec = self.encode([mol_tree])
    mol = Chem.MolFromSmiles(smiles)
    # Reference fingerprint for similarity checks against the input.
    fp1 = AllChem.GetMorganFingerprint(mol, 2)
    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.
    mean = torch.cat([tree_mean, mol_mean], dim=1)
    # NOTE(review): log_var is computed but never used below — the
    # ascent starts from the deterministic mean; confirm intentional.
    log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
    cur_vec = create_var(mean.data, True)
    visited = []
    # Gradient ascent on the property prediction w.r.t. the latent vector.
    for step in range(num_iter):
        prop_val = self.propNN(cur_vec).squeeze()
        grad = torch.autograd.grad(prop_val, cur_vec)[0]
        cur_vec = cur_vec.data + lr * grad.data
        cur_vec = create_var(cur_vec, True)
        visited.append(cur_vec)
    # Binary search over the trajectory: later steps have higher
    # predicted property but drift further from the input molecule.
    l, r = 0, num_iter - 1
    while l < r - 1:
        mid = int((l + r) / 2)
        new_vec = visited[mid]
        tree_vec, mol_vec = torch.chunk(new_vec, 2, dim=1)
        new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
        if new_smiles is None:
            # Undecodable point: retreat toward the start.
            r = mid - 1
            continue
        new_mol = Chem.MolFromSmiles(new_smiles)
        fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
        sim = DataStructs.TanimotoSimilarity(fp1, fp2)
        if sim < sim_cutoff:
            r = mid - 1
        else:
            l = mid
    # Dead code kept verbatim: an earlier linear-scan alternative.
    """
    best_vec = visited[0]
    for new_vec in visited:
        tree_vec,mol_vec = torch.chunk(new_vec, 2, dim=1)
        new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
        if new_smiles is None:
            continue
        new_mol = Chem.MolFromSmiles(new_smiles)
        fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
        sim = DataStructs.TanimotoSimilarity(fp1, fp2)
        if sim >= sim_cutoff:
            best_vec = new_vec
    """
    tree_vec, mol_vec = torch.chunk(visited[l], 2, dim=1)
    #tree_vec,mol_vec = torch.chunk(best_vec, 2, dim=1)
    new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
    if new_smiles is None:
        return smiles, 1.0
    # Re-validate the chosen point before returning it.
    new_mol = Chem.MolFromSmiles(new_smiles)
    fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
    sim = DataStructs.TanimotoSimilarity(fp1, fp2)
    if sim >= sim_cutoff:
        return new_smiles, sim
    else:
        return smiles, 1.0