Example #1
0
    def decode_test():
        """Round-trip SMILES from stdin through junction-tree reassembly.

        For every stdin line, the first whitespace-separated token is parsed
        into a MolTree, recovered, and reassembled via ``dfs_assemble`` starting
        at the root node. The result is compared with RDKit's canonical SMILES of
        the input. A mismatch prints ``gold decoded``; after each line the running
        mismatch count and the number of lines seen so far are printed.
        """
        mismatches = 0
        for line_no, line in enumerate(sys.stdin, start=1):
            smiles = line.split()[0]
            mol_tree = MolTree(smiles)
            mol_tree.recover()

            root = mol_tree.nodes[0]
            edit_mol = copy_edit_mol(root.mol)
            # Slot 0 is a placeholder; slot 1 maps each root atom index to itself.
            atom_maps = [{}] + [{} for _ in mol_tree.nodes]
            atom_maps[1] = {a.GetIdx(): a.GetIdx() for a in edit_mol.GetAtoms()}

            dfs_assemble(edit_mol, atom_maps, [], root, None)

            # Canonicalize by a SMILES round trip, then apply set_atommap
            # (presumably resets atom-map numbers — confirm in chemutils).
            decoded = Chem.MolFromSmiles(Chem.MolToSmiles(edit_mol.GetMol()))
            set_atommap(decoded)
            dec_smiles = Chem.MolToSmiles(decoded)

            gold_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
            if gold_smiles != dec_smiles:
                print(gold_smiles, dec_smiles)
                mismatches += 1
            print(mismatches, line_no)
Example #2
0
    def test_decode(self):
        """Reassembling each fixture's junction tree must reproduce its canonical SMILES."""
        for smiles in self.__smiles:
            tree = MolTree(smiles)
            tree.recover()

            nodes = tree.get_nodes()
            root = nodes[0]
            edit_mol = copy_edit_mol(root.get_mol())
            # Slot 0 is a placeholder; slot 1 maps each root atom index to itself.
            atom_maps = [{}] + [{} for _ in nodes]
            atom_maps[1] = {
                atom.GetIdx(): atom.GetIdx() for atom in edit_mol.GetAtoms()
            }

            dfs_assemble(edit_mol, atom_maps, [], root, None)

            round_trip = rdkit.Chem.MolToSmiles(edit_mol.GetMol())
            decoded_mol = rdkit.Chem.MolFromSmiles(round_trip)
            set_atommap(decoded_mol)
            dec_smiles = rdkit.Chem.MolToSmiles(decoded_mol)

            gold_smiles = rdkit.Chem.MolToSmiles(
                rdkit.Chem.MolFromSmiles(smiles))

            # Print the pair before asserting so a failure shows both strings.
            if gold_smiles != dec_smiles:
                print(gold_smiles, dec_smiles)

            self.assertEqual(gold_smiles, dec_smiles)
Example #3
0
 def tree_test():
     """Print the junction-tree decomposition of each SMILES read from stdin.

     For each line, the first whitespace-separated token is decomposed into a
     MolTree; a separator and the input SMILES are printed, followed by one line
     per tree node giving the node's SMILES and its neighbours' SMILES.
     """
     for line in sys.stdin:
         smiles = line.split()[0]
         tree = MolTree(smiles)
         print('-------------------------------------------')
         print(smiles)
         for node in tree.nodes:
             neighbor_smiles = [nb.smiles for nb in node.neighbors]
             print(node.smiles, neighbor_smiles)
Example #4
0
    def encode_latent_mean(self, smiles_list):
        """Return the concatenated tree/graph posterior means for a batch of SMILES.

        Args:
            smiles_list: iterable of SMILES strings.

        Returns:
            ``torch.cat([T_mean(tree_vec), G_mean(mol_vec)], dim=1)`` where
            ``tree_vec``/``mol_vec`` come from ``self.encode`` on the recovered
            MolTrees.
        """
        mol_batch = list(map(MolTree, smiles_list))
        for tree in mol_batch:
            tree.recover()

        _, tree_vec, mol_vec = self.encode(mol_batch)
        means = (self.T_mean(tree_vec), self.G_mean(mol_vec))
        return torch.cat(means, dim=1)
Example #5
0
def test_vae():
    """Check that JTNNVAE and DGLJTNNVAE agree on one batch when sharing weights.

    Builds the same batch twice: ``MolTree`` and ``DGLMolTree`` are both built
    from the module-level ``smiles_batch``, which is defined outside this view.
    The sub-module parameters of the two models are then tied together. Both
    forward passes run with the same random tensors ``e1``/``e2``, and every
    returned loss and accuracy must match.
    """
    # NOTE(review): the vocab file handle is never explicitly closed.
    vocab = [x.strip('\r\n ') for x in open('data/vocab.txt')]
    vocab = Vocab(vocab)
    # Reference (non-DGL) trees: recover + assemble, then node IDs from vocab.
    mol_batch = [MolTree(smiles) for smiles in smiles_batch]
    for mol_tree in mol_batch:
        mol_tree.recover()
        mol_tree.assemble()
    set_batch_nodeID(mol_batch, vocab)
    # The same molecules as DGL trees.
    nx_mol_batch = [DGLMolTree(smiles) for smiles in smiles_batch]
    for nx_mol_tree in nx_mol_batch:
        nx_mol_tree.recover()
        nx_mol_tree.assemble()
    dgl_set_batch_nodeID(nx_mol_batch, vocab)

    vae = JTNNVAE(vocab, 50, 50, 3)
    dglvae = DGLJTNNVAE(vocab, 50, 50, 3)
    # Random tensors of shape (batch, 25) passed identically to both models.
    e1 = torch.randn(len(smiles_batch), 25)
    e2 = torch.randn(len(smiles_batch), 25)

    # Tie parameters so both models compute with identical weights. The copy
    # direction differs by sub-module: the tree encoder and JTMPN weights go
    # reference -> DGL, while the MPN and decoder weights go DGL -> reference.
    dglvae.embedding = vae.embedding
    dgljtnn, dgljtmpn, dglmpn, dgldecoder = dglvae.jtnn, dglvae.jtmpn, dglvae.mpn, dglvae.decoder
    jtnn, mpn, decoder, jtmpn = vae.jtnn, vae.mpn, vae.decoder, vae.jtmpn
    dgljtnn.enc_tree_update.W_z = jtnn.W_z
    dgljtnn.enc_tree_update.W_h = jtnn.W_h
    dgljtnn.enc_tree_update.W_r = jtnn.W_r
    dgljtnn.enc_tree_update.U_r = jtnn.U_r
    dgljtnn.enc_tree_gather_update.W = jtnn.W
    dgljtnn.embedding = jtnn.embedding
    dgljtmpn.W_i = jtmpn.W_i
    dgljtmpn.gather_updater.W_o = jtmpn.W_o
    dgljtmpn.loopy_bp_updater.W_h = jtmpn.W_h
    mpn.W_i = dglmpn.W_i
    mpn.W_o = dglmpn.gather_updater.W_o
    mpn.W_h = dglmpn.loopy_bp_updater.W_h
    decoder.W = dgldecoder.W
    decoder.U = dgldecoder.U
    decoder.W_o = dgldecoder.W_o
    decoder.U_s = dgldecoder.U_s
    decoder.W_z = dgldecoder.dec_tree_edge_update.W_z
    decoder.W_r = dgldecoder.dec_tree_edge_update.W_r
    decoder.U_r = dgldecoder.dec_tree_edge_update.U_r
    decoder.W_h = dgldecoder.dec_tree_edge_update.W_h
    decoder.embedding = dgldecoder.embedding
    dglvae.T_mean = vae.T_mean
    dglvae.G_mean = vae.G_mean
    dglvae.T_var = vae.T_var
    dglvae.G_var = vae.G_var

    loss, kl_loss, wacc, tacc, aacc, sacc = vae(mol_batch, e1=e1, e2=e2)
    loss_dgl, kl_loss_dgl, wacc_dgl, tacc_dgl, aacc_dgl, sacc_dgl = dglvae(nx_mol_batch, e1=e1, e2=e2)

    assert torch.allclose(loss, loss_dgl)
    assert torch.allclose(kl_loss, kl_loss_dgl)
    assert torch.allclose(wacc, wacc_dgl)
    assert torch.allclose(tacc, tacc_dgl)
    # aacc/sacc are compared with NumPy — presumably plain numbers rather than
    # tensors; confirm against JTNNVAE.forward.
    assert np.allclose(aacc, aacc_dgl)
    assert np.allclose(sacc, sacc_dgl)
Example #6
0
 def count():
     """Tally assembly candidates over junction trees built from stdin SMILES.

     For each stdin line, the first whitespace-separated token is parsed into a
     MolTree, which is recovered and assembled; ``len(node.cands)`` is then
     summed over all of its nodes.

     Returns:
         ``(cnt, n)``: the total candidate count and the total node count. The
         mean number of candidates per node is ``cnt / n`` when ``n > 0``.
     """
     cnt, n = 0, 0
     for s in sys.stdin:
         s = s.split()[0]
         tree = MolTree(s)
         tree.recover()
         tree.assemble()
         cnt += sum(len(node.cands) for node in tree.nodes)
         n += len(tree.nodes)
     # Hand the totals back so callers can report or aggregate them.
     return cnt, n
Example #7
0
 def enum_test():
     """Report tree nodes whose label is missing from their assembly candidates.

     For each stdin SMILES (first whitespace token per line), the MolTree is
     recovered and assembled. For every node whose ``label`` is not in
     ``cands``, the molecule SMILES, the node and neighbour SMILES, and the
     label with the candidate count are printed.
     """
     for line in sys.stdin:
         tree = MolTree(line.split()[0])
         tree.recover()
         tree.assemble()
         for node in tree.nodes:
             if node.label in node.cands:
                 continue
             print(tree.smiles)
             print(node.smiles, [nb.smiles for nb in node.neighbors])
             print(node.label, len(node.cands))
Example #8
0
def to_vocab(line):
    """Collect the cluster SMILES of one record's junction tree.

    Args:
        line: a ``cid<TAB>smiles<TAB>label`` record; surrounding whitespace is
            stripped before splitting.

    Returns:
        A set holding the ``smiles`` of every node of the MolTree built (with
        ``skip_stereo=True``) from the record. Returns ``None`` after printing
        the CID if MolTree raises ``SmilesFailure``.

    Raises:
        ValueError: if the record does not split into exactly three fields.
    """
    cid, smi, _label = line.strip().split('\t')
    try:
        mol = MolTree(smi, skip_stereo=True)
    except SmilesFailure:
        print('SmilesFailure with CID', cid)
        return None
    return {cluster.smiles for cluster in mol.nodes}
Example #9
0
 def test_enum(self):
     """Print diagnostics for nodes whose label is absent from their candidates."""
     for smiles in self.__smiles:
         tree = MolTree(smiles)
         tree.recover()
         tree.assemble()
         for node in tree.get_nodes():
             if node.get_label() in node.get_candidates():
                 continue
             print(tree.get_smiles())
             neighbour_smiles = [nb.get_smiles() for nb in node.get_neighbors()]
             print(node.get_smiles(), neighbour_smiles)
             print(node.get_label(), len(node.get_candidates()))
Example #10
0
    def test_tree(self):
        """Every fixture decomposes into non-empty nodes with non-empty SMILES.

        For each fixture SMILES, the MolTree must have nodes; each node, and
        every one of its neighbours, must report a truthy SMILES string.
        """
        for smiles in self.__smiles:
            tree = MolTree(smiles)
            self.assertTrue(tree.get_nodes())

            for node in tree.get_nodes():
                self.assertTrue(node.get_smiles())
                neighbour_smiles = [
                    neighbour.get_smiles() for neighbour in node.get_neighbors()
                ]
                self.assertTrue(all(neighbour_smiles))
Example #11
0
    def reconstruct(self, smiles, prob_decode=False):
        """Encode ``smiles``, draw one latent sample, and decode it.

        Args:
            smiles: input SMILES string.
            prob_decode: forwarded to ``self.decode``.

        Returns:
            Whatever ``self.decode`` returns for the sampled tree/graph vectors.
        """
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        _, tree_vec, mol_vec = self.encode([mol_tree])

        # Following Mueller et al., log-variances are forced non-positive.
        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))

        half_size = int(self.latent_size / 2)

        def sample(mean, log_var):
            # Reparameterized draw: mean + std * standard-normal noise.
            noise = create_var(torch.randn(1, half_size), False)
            return mean + torch.exp(log_var / 2) * noise

        sampled_tree = sample(tree_mean, tree_log_var)
        sampled_mol = sample(mol_mean, mol_log_var)
        return self.decode(sampled_tree, sampled_mol, prob_decode)
Example #12
0
    def recon_eval(self, smiles, num_latent_samples=10, num_decodes=10):
        """Decode many stochastic reconstructions of ``smiles``.

        The molecule is encoded once. Then ``num_latent_samples`` latent vectors
        are drawn from the posterior, and each is decoded ``num_decodes`` times
        with ``prob_decode=True``.

        Args:
            smiles: input SMILES string.
            num_latent_samples: number of latent draws (default 10).
            num_decodes: stochastic decodes per latent draw (default 10).

        Returns:
            A list of ``num_latent_samples * num_decodes`` results from
            ``self.decode``. With the defaults this is 100 entries, as before.
        """
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        _, tree_vec, mol_vec = self.encode([mol_tree])

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(
            self.T_var(tree_vec))  #Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(
            self.G_var(mol_vec))  #Following Mueller et al.

        # torch.randn requires integer sizes, so convert the half-width once
        # and use it for both the tree and the graph noise.
        half_size = int(self.latent_size / 2)

        all_smiles = []
        for _latent in range(num_latent_samples):
            epsilon = create_var(torch.randn(1, half_size), False)
            tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
            epsilon = create_var(torch.randn(1, half_size), False)
            mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
            all_smiles.extend(
                self.decode(tree_vec, mol_vec, prob_decode=True)
                for _decode in range(num_decodes))
        return all_smiles
Example #13
0
def test_treedec():
    """Check that JTNNDecoder and DGLJTNNDecoder agree on one batch with shared weights.

    Both decoders run on the module-level ``smiles_batch`` (defined outside this
    view) with the same random ``tree_vec`` and the same embedding module
    ``emb``. The reference decoder's weights are overwritten with the DGL
    decoder's, and the four returned values (named q_loss, p_loss, q_acc, p_acc)
    must match.
    """
    # Reference (non-DGL) trees.
    mol_batch = [MolTree(smiles) for smiles in smiles_batch]
    for mol_tree in mol_batch:
        mol_tree.recover()
        mol_tree.assemble()

    # One random row of width 5 per molecule, fed to both decoders.
    tree_vec = torch.randn(len(mol_batch), 5)
    # NOTE(review): the vocab file handle is never explicitly closed.
    vocab = [x.strip('\r\n ') for x in open('data/vocab.txt')]
    vocab = Vocab(vocab)

    set_batch_nodeID(mol_batch, vocab)

    # The same molecules as DGL trees.
    nx_mol_batch = [DGLMolTree(smiles) for smiles in smiles_batch]
    for nx_mol_tree in nx_mol_batch:
        nx_mol_tree.recover()
        nx_mol_tree.assemble()
    dgl_set_batch_nodeID(nx_mol_batch, vocab)

    # A single embedding module is shared by both decoders.
    emb = nn.Embedding(vocab.size(), 5)
    dgljtnn = DGLJTNNDecoder(vocab, 5, 5, emb)
    dgl_q_loss, dgl_p_loss, dgl_q_acc, dgl_p_acc = dgljtnn(nx_mol_batch, tree_vec)

    # Build the reference decoder after the DGL forward pass, then tie its
    # weights to the DGL decoder's (the DGL edge-update weights live on
    # dec_tree_edge_update).
    jtnn = JTNNDecoder(vocab, 5, 5, emb)
    jtnn.W = dgljtnn.W
    jtnn.U = dgljtnn.U
    jtnn.W_o = dgljtnn.W_o
    jtnn.U_s = dgljtnn.U_s
    jtnn.W_z = dgljtnn.dec_tree_edge_update.W_z
    jtnn.W_r = dgljtnn.dec_tree_edge_update.W_r
    jtnn.U_r = dgljtnn.dec_tree_edge_update.U_r
    jtnn.W_h = dgljtnn.dec_tree_edge_update.W_h
    q_loss, p_loss, q_acc, p_acc = jtnn(mol_batch, tree_vec)

    # `isclose` is a helper defined outside this view.
    assert isclose(p_loss, dgl_p_loss)
    assert isclose(q_loss, dgl_q_loss)
    assert isclose(p_acc, dgl_p_acc)
    assert isclose(q_acc, dgl_q_acc)
Example #14
0
    import sys
    from jtnn.mol_tree import MolTree
    # Silence RDKit log output below CRITICAL level.
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    # Sample SMILES inputs; not referenced by the code visible in this view.
    smiles = [
        "O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1",
        "O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2",
        "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3",
        "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1",
        'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br',
        'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1',
        "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34",
        "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1"
    ]
    # Smoke test: even the single-atom molecule "C" must yield at least one node.
    # NOTE(review): `assert` is stripped under `python -O`.
    mol_tree = MolTree("C")
    assert len(mol_tree.nodes) > 0

    def tree_test():
        """Print the junction-tree decomposition of each SMILES read from stdin.

        For each line, the first whitespace-separated token is decomposed into
        a MolTree; a separator and the input SMILES are printed, then one line
        per node with the node's SMILES and its neighbours' SMILES.
        """
        for line in sys.stdin:
            smiles = line.split()[0]
            tree = MolTree(smiles)
            print('-------------------------------------------')
            print(smiles)
            for node in tree.nodes:
                neighbor_smiles = [nb.smiles for nb in node.neighbors]
                print(node.smiles, neighbor_smiles)

    def decode_test():
        """Read SMILES from stdin, one per line (first whitespace-separated token).

        NOTE(review): the body visible here appears truncated — it only sets up a
        mismatch counter and extracts each line's first token. Confirm the full
        definition before relying on it.
        """
        wrong = 0  # running mismatch counter (not updated in the visible body)
        for tot, s in enumerate(sys.stdin):
            s = s.split()[0]
Example #15
0
 def __getitem__(self, idx):
     """Return ``(mol_tree, prop)`` for dataset item ``idx``.

     The MolTree is built from ``self.data[idx]`` and is recovered and assembled
     before being paired with ``self.prop_data[idx]``.
     """
     tree = MolTree(self.data[idx])
     tree.recover()
     tree.assemble()
     prop = self.prop_data[idx]
     return tree, prop
Example #16
0
def test_treeenc():
    """Check the DGL tree encoder and JTMPN against the reference implementations.

    Part 1 compares the tree messages and tree vectors from JTNNEncoder and
    DGLJTNNEncoder, which share ``emb`` and tied weights. Part 2 collects
    assembly candidates from both tree representations, checks that they match
    by SMILES, and compares the candidate vectors from JTMPN and DGLJTMPN.
    Input SMILES come from the module-level ``smiles_batch`` (defined outside
    this view).
    """
    # Reference (non-DGL) trees.
    mol_batch = [MolTree(smiles) for smiles in smiles_batch]
    for mol_tree in mol_batch:
        mol_tree.recover()
        mol_tree.assemble()

    # NOTE(review): the vocab file handle is never explicitly closed.
    vocab = [x.strip('\r\n ') for x in open('data/vocab.txt')]
    vocab = Vocab(vocab)

    set_batch_nodeID(mol_batch, vocab)

    emb = nn.Embedding(vocab.size(), 5)
    jtnn = JTNNEncoder(vocab, 5, emb)

    # The reference encoder is driven from each tree's root node. tree_mess is
    # a mapping keyed by (u, v) node pairs (see the loop below).
    root_batch = [mol_tree.nodes[0] for mol_tree in mol_batch]
    tree_mess, tree_vec = jtnn(root_batch)

    # The same molecules as DGL trees.
    nx_mol_batch = [DGLMolTree(smiles) for smiles in smiles_batch]
    for nx_mol_tree in nx_mol_batch:
        nx_mol_tree.recover()
        nx_mol_tree.assemble()
    dgl_set_batch_nodeID(nx_mol_batch, vocab)

    # Tie the DGL encoder's update weights to the reference encoder's.
    dgljtnn = DGLJTNNEncoder(vocab, 5, emb)
    dgljtnn.enc_tree_update.W_z = jtnn.W_z
    dgljtnn.enc_tree_update.W_h = jtnn.W_h
    dgljtnn.enc_tree_update.W_r = jtnn.W_r
    dgljtnn.enc_tree_update.U_r = jtnn.U_r
    dgljtnn.enc_tree_gather_update.W = jtnn.W

    # DGL edge messages are read from the batched graph's edge feature 'm'.
    mol_tree_batch, dgl_tree_vec = dgljtnn(nx_mol_batch)
    dgl_tree_mess = mol_tree_batch.get_e_repr()['m']

    # Every reference (u, v) message must match the DGL message on edge u->v.
    # Mismatches are all printed before failing.
    assert dgl_tree_mess.shape[0] == len(tree_mess)
    fail = False
    for u, v in tree_mess:
        eid = mol_tree_batch.get_edge_id(u, v)
        if not allclose(tree_mess[(u, v)], dgl_tree_mess[eid]):
            fail = True
            # NOTE(review): prints the full reference message but only
            # element [0] of the DGL message — confirm that is intended.
            print(u, v, tree_mess[(u, v)], dgl_tree_mess[eid][0])
    assert not fail

    assert allclose(dgl_tree_vec, tree_vec)

    # Graph decoder (JTMPN) over assembly candidates, with tied weights.
    cands = []
    dglcands = []
    jtmpn = JTMPN(5, 4)
    dgljtmpn = DGLJTMPN(5, 4)
    dgljtmpn.W_i = jtmpn.W_i
    dgljtmpn.gather_updater.W_o = jtmpn.W_o
    dgljtmpn.loopy_bp_updater.W_h = jtmpn.W_h

    # Reference candidates: (cand_mol, tree node list, node), skipping leaves
    # and nodes with a single candidate. (The enumerate index `i` is unused.)
    for i, mol_tree in enumerate(mol_batch):
        for node in mol_tree.nodes:
            if node.is_leaf or len(node.cands) == 1:
                continue
            cands.extend([(cand, mol_tree.nodes, node) for cand in node.cand_mols])

    cand_vec = jtmpn(cands, tree_mess)

    # DGL candidates: (cand_mol, DGL tree, node_id), with the same filtering.
    for i, mol_tree in enumerate(nx_mol_batch):
        for node_id, node in mol_tree.nodes.items():
            if node['is_leaf'] or len(node['cands']) == 1:
                continue
            dglcands.extend([
                (cand, mol_tree, node_id) for cand in node['cand_mols']
            ])

    # Both candidate lists must line up one-to-one by canonical SMILES.
    assert len(cands) == len(dglcands)
    for item, dglitem in zip(cands, dglcands):
        assert Chem.MolToSmiles(item[0]) == Chem.MolToSmiles(dglitem[0])

    dgl_cand_vec = dgljtmpn(dglcands, mol_tree_batch)

    # TODO(review): a bug is suspected in the original (non-DGL) JTMPN
    # implementation; confirm whether this assertion is expected to pass.
    assert allclose(cand_vec, dgl_cand_vec)
Example #17
0
    def optimize(self, smiles, sim_cutoff, lr=2.0, num_iter=20):
        """Improve ``self.propNN`` by gradient ascent in latent space, keeping similarity.

        Starting from the posterior mean of ``smiles``, the method takes
        ``num_iter`` gradient steps of size ``lr`` on ``self.propNN``. It then
        binary-searches the visited latent points for a late step whose
        decoding keeps a Morgan (radius 2) Tanimoto similarity of at least
        ``sim_cutoff`` to the input.

        Args:
            smiles: input SMILES string.
            sim_cutoff: minimum Tanimoto similarity the result must keep.
            lr: gradient step size.
            num_iter: number of gradient steps.

        Returns:
            ``(new_smiles, sim)`` if the selected candidate decodes, parses and
            meets the cutoff; otherwise ``(smiles, 1.0)``. ``(smiles, 1.0)`` is
            also returned when ``num_iter < 1``, since there is nothing to search.
        """
        if num_iter < 1:
            return smiles, 1.0

        mol_tree = MolTree(smiles)
        mol_tree.recover()
        _, tree_vec, mol_vec = self.encode([mol_tree])

        mol = Chem.MolFromSmiles(smiles)
        fp1 = AllChem.GetMorganFingerprint(mol, 2)

        def similarity(candidate_smiles):
            # Tanimoto similarity to the input, or None if the candidate is
            # missing or RDKit cannot parse it (MolFromSmiles returns None).
            if candidate_smiles is None:
                return None
            candidate_mol = Chem.MolFromSmiles(candidate_smiles)
            if candidate_mol is None:
                return None
            fp2 = AllChem.GetMorganFingerprint(candidate_mol, 2)
            return DataStructs.TanimotoSimilarity(fp1, fp2)

        def decode_latent(vec):
            # The latent vector is [tree half | graph half] along dim 1.
            tree_part, mol_part = torch.chunk(vec, 2, dim=1)
            return self.decode(tree_part, mol_part, prob_decode=False)

        # Start from the posterior mean; the variance heads are not needed here.
        tree_mean = self.T_mean(tree_vec)
        mol_mean = self.G_mean(mol_vec)
        mean = torch.cat([tree_mean, mol_mean], dim=1)
        cur_vec = create_var(mean.data, True)

        visited = []
        for _ in range(num_iter):
            prop_val = self.propNN(cur_vec).squeeze()
            grad = torch.autograd.grad(prop_val, cur_vec)[0]
            cur_vec = create_var(cur_vec.data + lr * grad.data, True)
            visited.append(cur_vec)

        # Binary search over step indices; this presumes similarity roughly
        # decreases as more steps are taken. Undecodable or unparseable points
        # count as failures.
        # NOTE(review): when the loop ends with hi == lo + 1, visited[hi] is
        # never decoded; kept as-is to preserve the original selection.
        lo, hi = 0, num_iter - 1
        while lo < hi - 1:
            mid = (lo + hi) // 2
            sim = similarity(decode_latent(visited[mid]))
            if sim is None or sim < sim_cutoff:
                hi = mid - 1
            else:
                lo = mid

        new_smiles = decode_latent(visited[lo])
        sim = similarity(new_smiles)
        if sim is not None and sim >= sim_cutoff:
            return new_smiles, sim
        return smiles, 1.0