Esempio n. 1
0
    def decode(self, tree_vec, mol_vec, prob_decode):
        """Decode a (tree, molecule) latent pair into a SMILES string.

        ``self.decoder`` predicts the junction tree from ``tree_vec``,
        ``self.dfs_assemble`` turns it into a molecule, and when several
        stereo candidates exist the one whose vector (``self.mpn`` followed by
        ``self.G_mean``) has the highest cosine similarity to ``mol_vec`` is
        returned.

        Returns ``None`` if assembly fails or if the assembled molecule cannot
        be re-parsed from its own SMILES.
        """
        pred_root, pred_nodes = self.decoder.decode(tree_vec, prob_decode)

        # Number nodes from 1. Only nodes with more than one neighbor have
        # their atoms tagged with the node id.
        for nid, node in enumerate(pred_nodes, start=1):
            node.nid = nid
            degree = len(node.neighbors)
            node.is_leaf = (degree == 1)
            if degree > 1:
                set_atommap(node.mol, node.nid)

        tree_mess = self.jtnn([pred_root])[0]

        cur_mol = copy_edit_mol(pred_root.mol)
        # One dict per node id plus an extra empty slot at index 0.
        global_amap = [{}] + [{} for _ in pred_nodes]
        # Slot 1 starts as the identity map over the root molecule's atoms.
        global_amap[1] = {
            idx: idx for idx in (atom.GetIdx() for atom in cur_mol.GetAtoms())
        }

        cur_mol = self.dfs_assemble(
            tree_mess, mol_vec, pred_nodes, cur_mol, global_amap, [],
            pred_root, None, prob_decode,
        )
        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        # Round-trip through SMILES; a failed parse means no valid molecule.
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        if cur_mol is None:
            return None

        stereo_cands = decode_stereo(Chem.MolToSmiles(cur_mol))
        if len(stereo_cands) == 1:
            return stereo_cands[0]

        # Rank the stereo candidates against the molecule latent vector.
        cand_vecs = self.G_mean(self.mpn(mol2graph(stereo_cands)))
        scores = nn.CosineSimilarity()(cand_vecs, mol_vec)
        max_id = scores.max(dim=0)[1]
        # NOTE(review): `.data[0]` assumes max_id is indexable (1-element
        # tensor); confirm against the PyTorch version this file targets.
        return stereo_cands[max_id.data[0]]
Esempio n. 2
0
    def __init__(self, smiles):
        """Build the clique tree of the molecule given by ``smiles``.

        Stores the parsed molecule, its isomeric and default SMILES and the
        stereo candidates, then creates one ``MolTreeNode`` per clique from
        ``tree_decomp``, links neighbors along the returned edges, and moves
        the clique containing atom 0 to position 0 of ``self.nodes``.
        """
        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo generation
        rd_mol = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(rd_mol, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(rd_mol)
        self.stereo_cands = decode_stereo(self.smiles2D)

        cliques, edges = tree_decomp(self.mol)
        self.nodes = []
        root = 0
        for idx, clique in enumerate(cliques):
            clique_smiles = get_smiles(get_clique_mol(self.mol, clique))
            self.nodes.append(MolTreeNode(clique_smiles, clique))
            # The clique holding atom index 0 becomes the root.
            if min(clique) == 0:
                root = idx

        # Edges are undirected: register each endpoint with the other.
        for x, y in edges:
            node_x, node_y = self.nodes[x], self.nodes[y]
            node_x.add_neighbor(node_y)
            node_y.add_neighbor(node_x)

        # Put the root first. Neighbor links hold node objects, so they are
        # unaffected by the reordering.
        if root > 0:
            root_node = self.nodes[root]
            self.nodes[root] = self.nodes[0]
            self.nodes[0] = root_node

        for nid, node in enumerate(self.nodes, start=1):
            degree = len(node.neighbors)
            node.nid = nid
            node.is_leaf = (degree == 1)
            if degree > 1:  # Leaf node mol is not marked
                set_atommap(node.mol, node.nid)
Esempio n. 3
0
 def test(self, x_batch, x_tree, reselect_num=1, prop="logp", sim_type="binary"):
     """Decode one molecule from a random latent vector and score the result.

     Args:
         x_batch: 4-tuple ``(x_graphs, x_tensors, x_orders, x_scores)``.
         x_tree: source tree object; its ``mol`` and ``smiles`` are read.
         reselect_num: forwarded unchanged to ``self.decoder.decode``.
         prop: property name forwarded to ``get_prop``.
         sim_type: similarity type forwarded to ``similarity``.

     Returns:
         ``(smiles, sim, reselect, score1, score2)``. When decoding yields
         ``None`` the source SMILES is returned with ``sim=1.0``,
         ``reselect=0`` and ``score2 == score1``; when scoring the decoded
         molecule raises, the same fallback is returned but with the
         decoder's ``reselect`` value.
     """
     x_graphs, x_tensors, x_orders, x_scores = x_batch
     x_tensors = make_cuda(x_tensors)
     score1 = x_scores[0]

     # Only the per-node tree vectors are consumed below; the other encoder
     # outputs are discarded.
     _, x_tree_node_vecs, _ = self.encode(x_tensors, x_orders)

     latent_vec = torch.autograd.Variable(torch.randn((2 * self.latent_size)), requires_grad=True).to(device)

     # Re-parse the source molecule from its SMILES before handing it to the
     # decoder.
     mol = Chem.MolFromSmiles(Chem.MolToSmiles(x_tree.mol))

     new_mol, reselect = self.decoder.decode(x_tensors, latent_vec, x_tree_node_vecs, x_graphs, mol, reselect_num)

     if new_mol is None:
         return x_tree.smiles, 1.0, 0, score1, score1
     set_atommap(new_mol)
     try:
         new_smiles = Chem.MolToSmiles(new_mol, isomericSmiles=False)
         score2 = get_prop(new_smiles, prop=prop)
         sim = similarity(x_tree.smiles, new_smiles, sim_type)
     except Exception as e:
         # Best-effort evaluation: report the failure and fall back to the
         # unchanged source molecule.
         print(e)
         return x_tree.smiles, 1.0, reselect, score1, score1
     if score1 == score2 and sim < 1:
         print("special case: %s and %s" % (x_tree.smiles, new_smiles))
     return new_smiles, sim, reselect, score1, score2
Esempio n. 4
0
    def decode(self, x_tree_vecs, x_mol_vecs):
        """Decode a single latent pair into a SMILES string.

        Batch decoding is not supported: both inputs must have size 1 along
        dimension 0. Returns ``None`` when no node is predicted, when
        assembly fails, or when the assembled molecule cannot be re-parsed
        from its SMILES; a single predicted node returns the root's SMILES.
        """
        # currently do not support batch decoding
        assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1

        pred_root, pred_nodes = self.decoder.decode(x_tree_vecs, x_mol_vecs)
        n_nodes = len(pred_nodes)
        if n_nodes == 0:
            return None
        if n_nodes == 1:
            return pred_root.smiles

        # Number nodes from 1. Only nodes with more than one neighbor have
        # their atoms tagged with the node id.
        for nid, node in enumerate(pred_nodes, start=1):
            node.nid = nid
            degree = len(node.neighbors)
            node.is_leaf = (degree == 1)
            if degree > 1:
                set_atommap(node.mol, node.nid)

        scope = [(0, len(pred_nodes))]
        jtenc_holder, mess_dict = JTNNEncoder.tensorize_nodes(pred_nodes, scope)
        _, mess_matrix = self.jtnn(*jtenc_holder)
        # Important: the first element is a matrix, mess_dict is a python dict
        tree_mess = (mess_matrix, mess_dict)

        # Sum over dimension 1, then apply the A_assm layer.
        x_mol_vec_pooled = self.A_assm(x_mol_vecs.sum(dim=1)).squeeze()

        cur_mol = copy_edit_mol(pred_root.mol)
        # One dict per node id plus an extra empty slot at index 0.
        global_amap = [{}] + [{} for _ in pred_nodes]
        # Slot 1 starts as the identity map over the root molecule's atoms.
        global_amap[1] = {
            idx: idx for idx in (atom.GetIdx() for atom in cur_mol.GetAtoms())
        }

        cur_mol = self.dfs_assemble(
            tree_mess, x_mol_vec_pooled, pred_nodes, cur_mol, global_amap,
            [], pred_root, None,
        )
        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        # Round-trip through SMILES; a failed parse means no valid molecule.
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        if cur_mol is None:
            return None
        return Chem.MolToSmiles(cur_mol)
Esempio n. 5
0
    def __init__(self, smiles):
        """Build the clique tree of ``smiles`` as a DGL graph.

        With ``smiles`` set to ``None`` only the empty graph and an empty
        ``nodes_dict`` are created. Otherwise one graph node is added per
        clique from ``tree_decomp``, per-node data is kept in
        ``self.nodes_dict``, and every tree edge is inserted in both
        directions. The clique containing atom 0 is relabelled as node 0.
        """
        DGLGraph.__init__(self)
        self.nodes_dict = {}
        if smiles is None:
            return
        self.smiles = smiles
        self.mol = get_mol(smiles)

        rd_mol = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(rd_mol, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(rd_mol)
        self.stereo_cands = decode_stereo(self.smiles2D)

        cliques, edges = tree_decomp(self.mol)
        root = 0
        for idx, clique in enumerate(cliques):
            clique_smiles = get_smiles(get_clique_mol(self.mol, clique))
            self.nodes_dict[idx] = {
                'smiles': clique_smiles,
                'mol': get_mol(clique_smiles),
                'clique': clique,
            }
            # The clique holding atom index 0 becomes the root.
            if min(clique) == 0:
                root = idx
        self.add_nodes(len(cliques))

        # Exchange the contents of entry 0 and the root entry, key by key.
        if root > 0:
            first_entry = self.nodes_dict[0]
            root_entry = self.nodes_dict[root]
            for key in first_entry:
                first_entry[key], root_entry[key] = root_entry[key], first_entry[key]

        def relabel(idx):
            # Mirror the 0 <-> root exchange on edge endpoints.
            if idx == root:
                return 0
            if idx == 0:
                return root
            return idx

        n_directed = len(edges) * 2
        src = np.zeros((n_directed, ), dtype='int')
        dst = np.zeros((n_directed, ), dtype='int')
        for k, (raw_u, raw_v) in enumerate(edges):
            u, v = relabel(raw_u), relabel(raw_v)
            forward = 2 * k
            backward = forward + 1
            src[forward], dst[forward] = u, v
            src[backward], dst[backward] = v, u
        self.add_edges(src, dst)

        for idx, info in self.nodes_dict.items():
            info['nid'] = idx + 1
            degree = self.out_degree(idx)
            # Nodes with a single outgoing edge are leaves and are not marked.
            if degree > 1:
                set_atommap(info['mol'], info['nid'])
            info['is_leaf'] = (degree == 1)
Esempio n. 6
0
    def __init__(self, smiles):
        """Build the clique tree of ``smiles`` and record per-edge bonds.

        Beyond the node list, this stores the result of ``add_neighbor_bond``
        for every directed node pair in ``self.node_pair2bond`` and counts
        edges in ``self.n_edges``; ``self.n_virtual_edges`` counts edges for
        which either direction's result is an ``RDKitBond`` instance.
        """
        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo generation
        rd_mol = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(rd_mol, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(rd_mol)
        self.stereo_cands = decode_stereo(self.smiles2D)

        self.node_pair2bond = {}

        cliques, edges = tree_decomp(self.mol)
        self.nodes = []
        root = 0
        for idx, clique in enumerate(cliques):
            clique_smiles = get_smiles(get_clique_mol(self.mol, clique))
            self.nodes.append(MolTreeNode(clique_smiles, clique))
            # The clique holding atom index 0 becomes the root.
            if min(clique) == 0:
                root = idx

        self.n_edges = 0
        self.n_virtual_edges = 0
        for x, y in edges:
            node_x, node_y = self.nodes[x], self.nodes[y]
            node_x.add_neighbor(node_y)
            node_y.add_neighbor(node_x)
            xy_bond = node_x.add_neighbor_bond(node_y, self.mol)
            yx_bond = node_y.add_neighbor_bond(node_x, self.mol)
            # NOTE(review): keys are the edge indices as returned by
            # tree_decomp, i.e. positions before the root swap below; confirm
            # that consumers of node_pair2bond expect that.
            self.node_pair2bond[(x, y)] = xy_bond
            self.node_pair2bond[(y, x)] = yx_bond
            if any(isinstance(bond, RDKitBond) for bond in (xy_bond, yx_bond)):
                self.n_virtual_edges += 1
            self.n_edges += 1

        # Put the root first. Neighbor links hold node objects, so they are
        # unaffected by the reordering.
        if root > 0:
            root_node = self.nodes[root]
            self.nodes[root] = self.nodes[0]
            self.nodes[0] = root_node

        for nid, node in enumerate(self.nodes, start=1):
            degree = len(node.neighbors)
            node.nid = nid
            node.is_leaf = (degree == 1)
            if degree > 1:  # Leaf node mol is not marked
                set_atommap(node.mol, node.nid)
Esempio n. 7
0
def can_assemble(node_x, node_y):
    """Return True if ``enum_assemble`` yields at least one candidate.

    ``node_x`` is treated as the center (nid 1, not a leaf) and ``node_y`` is
    appended to its existing neighbors. Note that this overwrites ``nid``,
    ``is_leaf`` and the atom maps of ``node_x`` and of every neighbor.
    """
    node_x.nid = 1
    node_x.is_leaf = False
    set_atommap(node_x.mol, node_x.nid)

    neis = node_x.neighbors + [node_y]
    for nid, nei in enumerate(neis, start=2):
        nei.nid = nid
        nei.is_leaf = (len(nei.neighbors) <= 1)
        # Leaves are mapped with 0, every other neighbor with its own id.
        set_atommap(nei.mol, 0 if nei.is_leaf else nei.nid)

    def atom_count(node):
        return node.mol.GetNumAtoms()

    # Single-atom neighbors go first, then the rest from largest to smallest.
    singletons = [nei for nei in neis if atom_count(nei) == 1]
    multi_atom = sorted(
        (nei for nei in neis if atom_count(nei) > 1),
        key=atom_count,
        reverse=True,
    )
    cands, _ = enum_assemble(node_x, singletons + multi_atom)
    return len(cands) > 0