예제 #1
0
    def test_decode(self):
        """Check that every test molecule survives a tree round trip.

        Each SMILES string in ``self.__smiles`` is converted to a ``MolTree``,
        reassembled starting from the tree's first node with ``dfs_assemble``,
        and the canonical SMILES of the rebuilt molecule is asserted equal to
        the canonical SMILES of the input.
        """
        for smiles in self.__smiles:
            tree = MolTree(smiles)
            tree.recover()

            nodes = tree.get_nodes()
            root = nodes[0]
            cur_mol = copy_edit_mol(root.get_mol())

            # One (initially empty) atom map per node plus a leading slot 0;
            # slot 1 starts as the identity map over the root's atoms.
            global_amap = [{} for _ in range(len(nodes) + 1)]
            identity = {}
            for atom in cur_mol.GetAtoms():
                identity[atom.GetIdx()] = atom.GetIdx()
            global_amap[1] = identity

            dfs_assemble(cur_mol, global_amap, [], root, None)

            # Round-trip through SMILES before the final canonicalisation.
            assembled = cur_mol.GetMol()
            assembled = rdkit.Chem.MolFromSmiles(
                rdkit.Chem.MolToSmiles(assembled))
            set_atommap(assembled)
            dec_smiles = rdkit.Chem.MolToSmiles(assembled)

            reference = rdkit.Chem.MolFromSmiles(smiles)
            gold_smiles = rdkit.Chem.MolToSmiles(reference)

            # Show the offending pair before the assertion aborts the test.
            if gold_smiles != dec_smiles:
                print(gold_smiles, dec_smiles)

            self.assertEqual(gold_smiles, dec_smiles)
예제 #2
0
    def decode_test():
        """Round-trip molecules read from stdin and report mismatches.

        The first whitespace-separated token of every stdin line is taken as
        a SMILES string. The molecule is decomposed into a ``MolTree``,
        reassembled from the first tree node, and compared (as canonical
        SMILES) with the input. After each line the running number of
        mismatches and the number of lines processed are printed.
        """
        wrong = 0
        for seen, line in enumerate(sys.stdin, start=1):
            smiles = line.split()[0]
            tree = MolTree(smiles)
            tree.recover()

            root = tree.nodes[0]
            cur_mol = copy_edit_mol(root.mol)

            # One (initially empty) atom map per node plus a leading slot 0;
            # slot 1 starts as the identity map over the root's atoms.
            global_amap = [{} for _ in range(len(tree.nodes) + 1)]
            identity = {}
            for atom in cur_mol.GetAtoms():
                identity[atom.GetIdx()] = atom.GetIdx()
            global_amap[1] = identity

            dfs_assemble(cur_mol, global_amap, [], root, None)

            # Round-trip through SMILES before the final canonicalisation.
            assembled = cur_mol.GetMol()
            assembled = Chem.MolFromSmiles(Chem.MolToSmiles(assembled))
            set_atommap(assembled)
            dec_smiles = Chem.MolToSmiles(assembled)

            reference = Chem.MolFromSmiles(smiles)
            gold_smiles = Chem.MolToSmiles(reference)
            if gold_smiles != dec_smiles:
                print(gold_smiles, dec_smiles)
                wrong += 1
            print(wrong, seen)
예제 #3
0
 def count():
     """Tally candidate and node counts over the molecules read from stdin.

     The first whitespace-separated token of each stdin line is used as a
     SMILES string, turned into a ``MolTree``, recovered and assembled.

     NOTE(review): the two totals are accumulated but neither returned nor
     printed in the code visible here; a reporting step may be missing.
     """
     total_cands = 0
     total_nodes = 0
     for line in sys.stdin:
         tree = MolTree(line.split()[0])
         tree.recover()
         tree.assemble()
         total_cands += sum(len(node.cands) for node in tree.nodes)
         total_nodes += len(tree.nodes)
예제 #4
0
 def enum_test():
     """Print every tree node whose label is missing from its candidates.

     The first whitespace-separated token of each stdin line is used as a
     SMILES string, turned into a ``MolTree``, recovered and assembled. For
     each node whose ``label`` is not in ``cands``, three lines are printed:
     the tree's SMILES, the node's SMILES with its neighbours' SMILES, and
     the label with the number of candidates.
     """
     for line in sys.stdin:
         tree = MolTree(line.split()[0])
         tree.recover()
         tree.assemble()
         for node in tree.nodes:
             if node.label in node.cands:
                 continue
             print(tree.smiles)
             neighbor_smiles = [nbr.smiles for nbr in node.neighbors]
             print(node.smiles, neighbor_smiles)
             print(node.label, len(node.cands))
예제 #5
0
 def test_enum(self):
     """Print every tree node whose label is missing from its candidates.

     Each SMILES string in ``self.__smiles`` is turned into a ``MolTree``,
     recovered and assembled. Nodes whose label is not among their
     candidates are reported on stdout; nothing is asserted here.
     """
     for smiles in self.__smiles:
         tree = MolTree(smiles)
         tree.recover()
         tree.assemble()
         for node in tree.get_nodes():
             if node.get_label() in node.get_candidates():
                 continue
             print(tree.get_smiles())
             neighbor_smiles = [
                 nbr.get_smiles() for nbr in node.get_neighbors()
             ]
             print(node.get_smiles(), neighbor_smiles)
             print(node.get_label(), len(node.get_candidates()))
예제 #6
0
    def reconstruct(self, smiles, prob_decode=False):
        """Encode ``smiles``, sample one latent point and decode it.

        Args:
            smiles: SMILES string of the molecule to reconstruct.
            prob_decode: Forwarded unchanged to ``self.decode``.

        Returns:
            Whatever ``self.decode`` returns for the sampled tree and
            molecule latent vectors.
        """
        tree = MolTree(smiles)
        tree.recover()
        _, tree_vec, mol_vec = self.encode([tree])

        half = int(self.latent_size / 2)

        # Log-variances are forced non-positive (following Mueller et al.).
        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))

        # Reparameterised samples: the tree noise is drawn first, then the
        # molecule noise.
        tree_noise = create_var(torch.randn(1, half), False)
        tree_sample = tree_mean + torch.exp(tree_log_var / 2) * tree_noise
        mol_noise = create_var(torch.randn(1, half), False)
        mol_sample = mol_mean + torch.exp(mol_log_var / 2) * mol_noise

        return self.decode(tree_sample, mol_sample, prob_decode)
예제 #7
0
    def recon_eval(self, smiles):
        """Decode many stochastic reconstructions of ``smiles``.

        The molecule is encoded once. Ten latent points are then sampled, and
        each one is decoded ten times with ``prob_decode=True``.

        Args:
            smiles: SMILES string of the molecule to reconstruct.

        Returns:
            A list with the 100 results of ``self.decode``, in sampling order.
        """
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        _, tree_vec, mol_vec = self.encode([mol_tree])

        # Log-variances are forced non-positive (following Mueller et al.).
        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))

        # torch.randn needs integer sizes; a bare ``latent_size / 2`` is a
        # float under Python 3 and raises TypeError.
        half_latent = int(self.latent_size / 2)

        all_smiles = []
        for _ in range(10):
            epsilon = create_var(torch.randn(1, half_latent), False)
            tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
            epsilon = create_var(torch.randn(1, half_latent), False)
            mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
            for _ in range(10):
                new_smiles = self.decode(tree_vec, mol_vec, prob_decode=True)
                all_smiles.append(new_smiles)
        return all_smiles
예제 #8
0
 def __getitem__(self, idx):
     """Return the assembled ``MolTree`` and property entry at ``idx``.

     The tree is built from ``self.data[idx]``, then recovered and assembled;
     it is returned together with ``self.prop_data[idx]`` as a 2-tuple.
     """
     tree = MolTree(self.data[idx])
     tree.recover()
     tree.assemble()
     target = self.prop_data[idx]
     return tree, target
예제 #9
0
    def optimize(self, smiles, sim_cutoff, lr=2.0, num_iter=20):
        """Gradient-ascend the latent code of ``smiles`` under a similarity floor.

        The molecule is encoded and its latent mean is moved for ``num_iter``
        steps along the gradient of ``self.propNN``. The visited latent points
        are then searched for a late step whose decoded molecule still has a
        Tanimoto similarity (Morgan fingerprints, second argument 2) of at
        least ``sim_cutoff`` to the input.

        Only the latent means are evaluated: the ascent starts from the mean
        and nothing in this method samples from the posterior.

        Args:
            smiles: SMILES string of the starting molecule.
            sim_cutoff: Minimum similarity the result must keep to ``smiles``.
            lr: Step size of each gradient-ascent update.
            num_iter: Number of gradient-ascent steps to take.

        Returns:
            ``(new_smiles, sim)`` when a decoded molecule meets the cutoff,
            otherwise ``(smiles, 1.0)``, i.e. the unchanged input.
        """
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        _, tree_vec, mol_vec = self.encode([mol_tree])

        mol = Chem.MolFromSmiles(smiles)
        fp1 = AllChem.GetMorganFingerprint(mol, 2)

        def decode_with_similarity(latent):
            """Decode ``latent``; return (smiles, similarity) or (None, None)."""
            tree_part, mol_part = torch.chunk(latent, 2, dim=1)
            candidate = self.decode(tree_part, mol_part, prob_decode=False)
            if candidate is None:
                return None, None
            candidate_mol = Chem.MolFromSmiles(candidate)
            fp2 = AllChem.GetMorganFingerprint(candidate_mol, 2)
            return candidate, DataStructs.TanimotoSimilarity(fp1, fp2)

        mean = torch.cat([self.T_mean(tree_vec), self.G_mean(mol_vec)], dim=1)
        cur_vec = create_var(mean.data, True)

        visited = []
        for _ in range(num_iter):
            prop_val = self.propNN(cur_vec).squeeze()
            grad = torch.autograd.grad(prop_val, cur_vec)[0]
            cur_vec = create_var(cur_vec.data + lr * grad.data, True)
            visited.append(cur_vec)

        # No step was taken (num_iter <= 0): nothing to decode, keep the input.
        if not visited:
            return smiles, 1.0

        # Bisect over the visited steps. NOTE(review): this presumably relies
        # on similarity to the input shrinking as the ascent proceeds; the
        # bounds are kept exactly as before so the chosen step is unchanged.
        lo, hi = 0, num_iter - 1
        while lo < hi - 1:
            mid = (lo + hi) // 2
            new_smiles, sim = decode_with_similarity(visited[mid])
            if new_smiles is None or sim < sim_cutoff:
                hi = mid - 1
            else:
                lo = mid

        # The chosen step is re-checked, so a failed search falls back to the
        # original molecule rather than returning a dissimilar one.
        new_smiles, sim = decode_with_similarity(visited[lo])
        if new_smiles is not None and sim >= sim_cutoff:
            return new_smiles, sim
        return smiles, 1.0