Esempio n. 1
0
    def test_decode(self):
        """Round-trip each test SMILES through tree decomposition and reassembly.

        For every SMILES in the fixture: build a ``MolTree``, call
        ``recover()`` on it, reassemble a molecule with ``dfs_assemble``
        starting from the first tree node, and assert that the reassembled
        molecule's canonical SMILES equals the canonical SMILES of the input.
        Each SMILES runs in its own ``subTest`` so that every mismatch is
        reported, not just the first one.
        """
        for smiles in self.__smiles:
            with self.subTest(smiles=smiles):
                tree = MolTree(smiles)
                tree.recover()
                nodes = tree.get_nodes()

                cur_mol = copy_edit_mol(nodes[0].get_mol())
                # One atom map per node plus a leading placeholder at index 0;
                # presumably indexed by 1-based node id -- confirm against
                # dfs_assemble. The first node starts with an identity map
                # over its own atom indices.
                global_amap = [{}] + [{} for _ in nodes]
                global_amap[1] = {
                    atom.GetIdx(): atom.GetIdx()
                    for atom in cur_mol.GetAtoms()
                }

                dfs_assemble(cur_mol, global_amap, [], nodes[0], None)

                # Canonicalise through a SMILES round trip. A failed re-parse
                # (None) is reported as a test failure instead of crashing
                # further down in set_atommap / MolToSmiles.
                cur_mol = rdkit.Chem.MolFromSmiles(
                    rdkit.Chem.MolToSmiles(cur_mol.GetMol()))
                self.assertIsNotNone(
                    cur_mol, 'reassembled molecule could not be re-parsed')
                # NOTE(review): set_atommap with no number presumably resets
                # the atom-map numbers so they do not leak into the SMILES.
                set_atommap(cur_mol)
                dec_smiles = rdkit.Chem.MolToSmiles(cur_mol)

                gold_smiles = rdkit.Chem.MolToSmiles(
                    rdkit.Chem.MolFromSmiles(smiles))

                # assertEqual already reports both values on mismatch, so no
                # separate debug print is needed.
                self.assertEqual(gold_smiles, dec_smiles)
Esempio n. 2
0
    def __init__(self, smiles):
        """Build the junction tree for ``smiles``.

        Stores the input SMILES and the mol returned by ``get_mol``. Records
        the canonical SMILES with stereo information (``smiles3D``) and
        without it (``smiles2D``), plus the stereo candidates derived from
        ``smiles2D``. It then:

        * decomposes the molecule into cliques with ``tree_decomp``;
        * wraps each clique in a ``MolTreeNode``;
        * links the nodes along the decomposition edges;
        * moves the last clique that contains atom 0 to the front;
        * numbers the nodes from 1, setting ``nid`` and ``is_leaf``.
        """
        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo generation: canonical SMILES with and without stereo info,
        # and the stereo candidates enumerated from the stereo-free form.
        parsed = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(parsed, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(parsed)
        self.stereo_cands = decode_stereo(self.smiles2D)

        cliques, edges = tree_decomp(self.mol)

        # One node per clique; remember the last clique holding atom 0.
        self.nodes = []
        root = 0
        for idx, clique in enumerate(cliques):
            clique_smiles = get_smiles(get_clique_mol(self.mol, clique))
            self.nodes.append(MolTreeNode(clique_smiles, clique))
            if min(clique) == 0:
                root = idx

        # Edges are undirected: register each endpoint with the other.
        for a, b in edges:
            node_a, node_b = self.nodes[a], self.nodes[b]
            node_a.add_neighbor(node_b)
            node_b.add_neighbor(node_a)

        # The root clique must sit at position 0.
        if root:
            self.nodes[0], self.nodes[root] = self.nodes[root], self.nodes[0]

        for nid, node in enumerate(self.nodes, start=1):
            node.nid = nid
            # Leaf node mol is not marked; only nodes with >1 neighbour are.
            if len(node.neighbors) > 1:
                set_atommap(node.mol, node.nid)
            node.is_leaf = len(node.neighbors) == 1
Esempio n. 3
0
    def decode(self, tree_vec, mol_vec, prob_decode):
        """Decode a pair of latent vectors into a molecule.

        Args:
            tree_vec: tree latent vector handed to ``self.decoder.decode``.
            mol_vec: molecule latent vector. It is used during assembly and
                to rank stereoisomer candidates by cosine similarity.
            prob_decode: forwarded to the tree decoder and to
                ``self.dfs_assemble``. Presumably it switches between
                sampling and greedy decoding -- confirm.

        Returns:
            The chosen entry of ``decode_stereo``'s output (presumably an
            isomeric SMILES string), or ``None`` in either of two cases:
            assembly returned ``None``, or the assembled molecule could not
            be re-parsed.
        """
        root, nodes = self.decoder.decode(tree_vec, prob_decode)

        # Mark nid & is_leaf & atommap: ids start at 1, and only nodes with
        # more than one neighbour get their atoms tagged with the id.
        for nid, node in enumerate(nodes, start=1):
            node.nid = nid
            node.is_leaf = len(node.neighbors) == 1
            if len(node.neighbors) > 1:
                set_atommap(node.mol, node.nid)

        tree_mess = self.jtnn([root])[0]

        cur_mol = copy_edit_mol(root.mol)
        # Leading placeholder plus one atom map per predicted node. Slot 1
        # gets an identity map over the root copy's atoms (assumes the root
        # has nid 1 -- confirm).
        global_amap = [{}]
        global_amap.extend({} for _ in nodes)
        global_amap[1] = {a.GetIdx(): a.GetIdx() for a in cur_mol.GetAtoms()}

        cur_mol = self.dfs_assemble(tree_mess, mol_vec, nodes, cur_mol,
                                    global_amap, [], root, None, prob_decode)
        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        # Round-trip through SMILES; a None here means invalid chemistry.
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        if cur_mol is None:
            return None

        stereo_cands = decode_stereo(Chem.MolToSmiles(cur_mol))
        if len(stereo_cands) == 1:
            return stereo_cands[0]

        # Several stereoisomers: encode each one and keep the candidate whose
        # embedding is most cosine-similar to mol_vec.
        cand_vecs = self.G_mean(self.mpn(mol2graph(stereo_cands)))
        similarity = nn.CosineSimilarity()(cand_vecs, mol_vec)
        best = similarity.max(dim=0)[1]
        return stereo_cands[best.item()]
Esempio n. 4
0
    def __init__(self, smiles):
        """Build the junction tree for ``smiles``.

        Stores the input SMILES, the mol returned by ``chemutils.get_mol``,
        the canonical SMILES with stereo information, and the stereo
        candidates of the stereo-free canonical SMILES. It then:

        * decomposes the molecule into cliques;
        * wraps each clique in a ``MolTreeNode``;
        * links the nodes along the decomposition edges;
        * moves the last clique that contains atom 0 to the front;
        * assigns node ids starting at 1.
        """
        self.__smiles = smiles
        self.__mol = chemutils.get_mol(smiles)
        self.__nodes = []

        # Stereo generation:
        parsed = Chem.MolFromSmiles(smiles)
        self.__smiles3d = Chem.MolToSmiles(parsed, isomericSmiles=True)
        smiles2d = Chem.MolToSmiles(parsed)
        self.__stereo_cands = chemutils.decode_stereo(smiles2d)

        cliques, edges = _tree_decomp(self.__mol)

        # One node per clique; remember the last clique that holds atom 0.
        root = 0
        for idx, clique in enumerate(cliques):
            node_smiles = chemutils.get_smiles(
                _get_clique_mol(self.__mol, clique))
            self.__nodes.append(MolTreeNode(node_smiles, clique))
            if min(clique) == 0:
                root = idx

        nodes = self.__nodes  # same list object, shorter name below

        # Edges are undirected: register each endpoint with the other.
        for src, dst in edges:
            nodes[src].add_neighbor(nodes[dst])
            nodes[dst].add_neighbor(nodes[src])

        # The root clique must sit at position 0.
        if root:
            nodes[0], nodes[root] = nodes[root], nodes[0]

        for node_id, node in enumerate(nodes, start=1):
            node.set_node_id(node_id)
            # Leaf node mol is not marked.
            # NOTE(review): whether a node with no neighbours (a single-clique
            # molecule) counts as a leaf depends on MolTreeNode.is_leaf --
            # confirm.
            if node.is_leaf():
                continue
            chemutils.set_atommap(node.get_mol(), node.get_node_id())