예제 #1
0
    def recover(self, original_mol):
        """Compute this node's label SMILES from its local neighborhood.

        Atom-map numbers are stamped temporarily on ``original_mol``:

        * this node's ``nid`` on its own clique atoms, unless it is a leaf;
        * each non-leaf neighbor's ``nid`` on that neighbor's atoms that lie
          outside this node's clique (a single-atom neighbor stamps its atom
          even when shared, overriding this node's mapping).

        The fragment covering this node and all its neighbors is then
        extracted, and its SMILES (round-tripped through
        ``Chem.MolFromSmiles`` / ``Chem.MolToSmiles``) is stored in
        ``self.label``. ``get_mol`` of that SMILES is stored in
        ``self.label_mol``. Finally the atom-map number of every atom in the
        fragment is reset to 0.

        Args:
            original_mol: molecule whose atoms are addressed by the clique
                indices; its atom-map numbers are modified and left at 0.

        Returns:
            The label SMILES (also stored in ``self.label``).
        """
        def stamp(atom_indices, map_num):
            # Write ``map_num`` as the atom-map number of each listed atom.
            for idx in atom_indices:
                original_mol.GetAtomWithIdx(idx).SetAtomMapNum(map_num)

        own = self.clique
        covered = list(own)
        if not self.is_leaf:
            stamp(own, self.nid)

        for neighbor in self.neighbors:
            covered.extend(neighbor.clique)
            if neighbor.is_leaf:
                # Leaf neighbors are never marked.
                continue
            # Shared atoms keep this node's mapping unless the neighbor is a
            # singleton, which is allowed to override it.
            override = len(neighbor.clique) == 1
            stamp([idx for idx in neighbor.clique
                   if idx not in own or override], neighbor.nid)

        covered = list(set(covered))
        fragment = get_clique_mol(original_mol, covered)
        roundtrip = Chem.MolFromSmiles(get_smiles(fragment))
        self.label = Chem.MolToSmiles(roundtrip)
        self.label_mol = get_mol(self.label)

        stamp(covered, 0)
        return self.label
예제 #2
0
    def __init__(self, smiles):
        """Decompose ``smiles`` into a junction tree of ``MolTreeNode``s.

        Sets ``smiles`` and ``mol`` (via ``get_mol``). It also sets
        ``smiles3D`` and ``smiles2D``, the isomeric and plain
        ``Chem.MolToSmiles`` outputs. ``stereo_cands`` comes from
        ``decode_stereo(smiles2D)``. Finally it sets ``nodes``, one node
        per clique from ``tree_decomp``.

        The clique whose smallest atom index is 0 (the last such clique if
        several match) is swapped into ``nodes[0]``. Each node then gets
        ``nid = position + 1``. Nodes with more than one neighbor have their
        ``mol`` atom-mapped with that ``nid``. ``is_leaf`` is True exactly
        when a node has one neighbor.
        """
        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo generation
        parsed = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(parsed, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(parsed)
        self.stereo_cands = decode_stereo(self.smiles2D)

        cliques, edges = tree_decomp(self.mol)
        self.nodes = nodes = []
        root = 0
        for idx, atoms in enumerate(cliques):
            fragment = get_clique_mol(self.mol, atoms)
            nodes.append(MolTreeNode(get_smiles(fragment), atoms))
            if min(atoms) == 0:
                root = idx

        for a, b in edges:
            nodes[a].add_neighbor(nodes[b])
            nodes[b].add_neighbor(nodes[a])

        # Put the clique containing atom 0 first.
        if root > 0:
            nodes[0], nodes[root] = nodes[root], nodes[0]

        for position, node in enumerate(nodes, start=1):
            node.nid = position
            degree = len(node.neighbors)
            # Leaf node mols are not atom-mapped.
            if degree > 1:
                set_atommap(node.mol, node.nid)
            node.is_leaf = degree == 1
예제 #3
0
파일: dglmol.py 프로젝트: tbwxmu/SAMPN
    def _recover_node(self, i, original_mol):
        """Compute the label SMILES of tree node ``i`` from its neighborhood.

        Graph-based counterpart of ``MolTreeNode.recover``. Node data lives
        in ``self.nodes_dict[i]`` and neighbors are ``self.successors(i)``.

        Atom-map numbers are stamped temporarily on ``original_mol``:

        * the node's ``'nid'`` on its own clique atoms, unless it is a leaf;
        * each non-leaf neighbor's ``'nid'`` on its atoms outside the node's
          clique (a single-atom neighbor also overrides shared atoms).

        The fragment over the node and its neighbors is extracted. Its
        round-tripped SMILES is stored as ``'label'`` and ``get_mol`` of it
        as ``'label_mol'``. The atom-map numbers of all fragment atoms are
        then reset to 0.

        Args:
            i: key of the node in ``self.nodes_dict``.
            original_mol: molecule addressed by the clique indices; its
                atom-map numbers are modified and left at 0.

        Returns:
            The label SMILES (also stored in ``nodes_dict[i]['label']``).
        """
        node = self.nodes_dict[i]

        def stamp(atom_indices, map_num):
            # Write ``map_num`` as the atom-map number of each listed atom.
            for idx in atom_indices:
                original_mol.GetAtomWithIdx(idx).SetAtomMapNum(map_num)

        own = node['clique']
        covered = list(own)
        if not node['is_leaf']:
            stamp(own, node['nid'])

        for j in self.successors(i).numpy():
            neighbor = self.nodes_dict[j]
            nbr_atoms = neighbor['clique']
            covered.extend(nbr_atoms)
            if neighbor['is_leaf']:
                # Leaf neighbors are never marked.
                continue
            # Singleton neighbors may override the mapping on shared atoms.
            override = len(nbr_atoms) == 1
            stamp([idx for idx in nbr_atoms if idx not in own or override],
                  neighbor['nid'])

        covered = list(set(covered))
        fragment = get_clique_mol(original_mol, covered)
        roundtrip = Chem.MolFromSmiles(get_smiles(fragment))
        node['label'] = Chem.MolToSmiles(roundtrip)
        node['label_mol'] = get_mol(node['label'])

        stamp(covered, 0)
        return node['label']
예제 #4
0
파일: dglmol.py 프로젝트: tbwxmu/SAMPN
    def __init__(self, smiles):
        """Build the junction tree of ``smiles`` as a DGL graph.

        With ``smiles`` None the graph is left empty (``nodes_dict == {}``).

        Otherwise:

        * ``smiles``, ``mol``, ``smiles3D``, ``smiles2D`` and
          ``stereo_cands`` are set.
        * Each clique from ``tree_decomp`` becomes a node whose data dict in
          ``nodes_dict`` holds ``smiles``, ``mol`` and ``clique``.
        * The clique whose smallest atom index is 0 (the last one if several)
          is swapped into node 0.
        * Edges are added in both directions, with their endpoints relabeled
          to match that swap.
        * Each node gets ``nid = key + 1``. Nodes with out-degree > 1 have
          their ``mol`` atom-mapped with the ``nid``. ``is_leaf`` is True
          exactly when the out-degree is 1.
        """
        DGLGraph.__init__(self)
        self.nodes_dict = {}
        if smiles is None:
            return
        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo generation
        parsed = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(parsed, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(parsed)
        self.stereo_cands = decode_stereo(self.smiles2D)

        cliques, edges = tree_decomp(self.mol)
        root = 0
        for idx, atoms in enumerate(cliques):
            frag_smiles = get_smiles(get_clique_mol(self.mol, atoms))
            self.nodes_dict[idx] = dict(
                smiles=frag_smiles,
                mol=get_mol(frag_smiles),
                clique=atoms,
            )
            if min(atoms) == 0:
                root = idx
        self.add_nodes(len(cliques))

        # Exchange the data of node 0 and the root node (dict keys keep
        # their insertion order).
        if root > 0:
            self.nodes_dict[0], self.nodes_dict[root] = \
                self.nodes_dict[root], self.nodes_dict[0]

        # Relabel endpoints 0 <-> root so edges follow the swapped data,
        # then add every edge in both directions.
        swap = {0: root, root: 0}
        src, dst = [], []
        for a, b in edges:
            a, b = swap.get(a, a), swap.get(b, b)
            src.extend((a, b))
            dst.extend((b, a))
        self.add_edges(np.array(src, dtype='int'), np.array(dst, dtype='int'))

        for key in self.nodes_dict:
            data = self.nodes_dict[key]
            data['nid'] = key + 1
            degree = self.out_degree(key)
            # Leaf node mols are not atom-mapped.
            if degree > 1:
                set_atommap(data['mol'], data['nid'])
            data['is_leaf'] = (degree == 1)
예제 #5
0
    def __init__(self, smiles):
        """Decompose ``smiles`` into a junction tree, recording bonds per edge.

        Builds the same tree as the plain ``MolTree``: ``smiles``, ``mol``,
        stereo fields, and ``nodes`` with the atom-0 clique moved to
        ``nodes[0]``, ``nid`` numbering and ``is_leaf`` flags. In addition,
        for each tree edge ``(x, y)`` it stores
        ``add_neighbor_bond``'s result in ``node_pair2bond[(x, y)]`` and
        ``node_pair2bond[(y, x)]``.

        It also counts ``n_edges`` (tree edges) and ``n_virtual_edges``
        (edges where either direction's bond is an ``RDKitBond``).
        """
        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo generation
        parsed = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(parsed, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(parsed)
        self.stereo_cands = decode_stereo(self.smiles2D)

        self.node_pair2bond = {}

        cliques, edges = tree_decomp(self.mol)
        self.nodes = nodes = []
        root = 0
        for idx, atoms in enumerate(cliques):
            fragment = get_clique_mol(self.mol, atoms)
            nodes.append(MolTreeNode(get_smiles(fragment), atoms))
            if min(atoms) == 0:
                root = idx

        self.n_edges = 0
        self.n_virtual_edges = 0
        for a, b in edges:
            first, second = nodes[a], nodes[b]
            first.add_neighbor(second)
            second.add_neighbor(first)
            ab_bond = first.add_neighbor_bond(second, self.mol)
            ba_bond = second.add_neighbor_bond(first, self.mol)
            # NOTE(review): these keys are clique indices *before* the root
            # swap below; after it, positions 0 and ``root`` of self.nodes
            # are exchanged. Confirm consumers expect pre-swap indices.
            self.node_pair2bond[(a, b)] = ab_bond
            self.node_pair2bond[(b, a)] = ba_bond
            if any(isinstance(bond, RDKitBond) for bond in (ab_bond, ba_bond)):
                self.n_virtual_edges += 1
            self.n_edges += 1

        # Put the clique containing atom 0 first.
        if root > 0:
            nodes[0], nodes[root] = nodes[root], nodes[0]

        for position, node in enumerate(nodes, start=1):
            node.nid = position
            degree = len(node.neighbors)
            # Leaf node mols are not atom-mapped.
            if degree > 1:
                set_atommap(node.mol, node.nid)
            node.is_leaf = degree == 1
예제 #6
0
    def build_mol_tree(self):
        """Build a directed junction tree over ``self.cliques``.

        Each clique ``i`` becomes node ``i`` with attributes ``label`` (the
        clique's SMILES via ``get_clique_mol``/``get_smiles``) and ``clq``
        (the clique's atom indices). Every pair in ``self.edges`` becomes
        two directed edges. Both edges carry ``anchor``, the atom indices
        shared by the two cliques (the same list object in both directions).
        When one or two atoms are shared, each direction also gets ``label``:
        the attachment position inside that edge's source clique, as
        computed by the nested ``_anchor_label``.

        Returns:
            networkx.DiGraph: the junction tree.
        """
        def _anchor_label(clique, shared):
            """Return the position of ``shared`` (1 or 2 atoms) in ``clique``.

            One shared atom gives its index in ``clique``. Two shared atoms
            (a shared bond) give the larger of their two positions, except
            that the closing bond between the last and the first position
            maps to 0. This treats ``clique`` as a cyclic atom sequence
            (presumably ring order — confirm against ``tree_decomp``). The
            result does not depend on the order of ``shared``, which comes
            from a set intersection.
            """
            if len(shared) == 1:
                return clique.index(shared[0])
            lo, hi = sorted(clique.index(atom) for atom in shared)
            if lo == 0 and hi == len(clique) - 1 and len(clique) > 2:
                return 0
            return hi

        cliques = self.cliques
        graph = nx.DiGraph()

        for i, clique in enumerate(cliques):
            cmol = get_clique_mol(self.mol, clique)
            graph.add_node(i)
            graph.nodes[i]['label'] = get_smiles(cmol)
            graph.nodes[i]['clq'] = clique

        for edge in self.edges:
            u, v = edge[0], edge[1]
            inter_atoms = list(set(cliques[u]) & set(cliques[v]))

            graph.add_edge(u, v)
            graph.add_edge(v, u)
            graph[u][v]['anchor'] = inter_atoms
            graph[v][u]['anchor'] = inter_atoms

            # Labels are defined only for a shared atom or a shared bond;
            # larger overlaps get no label.
            if len(inter_atoms) in (1, 2):
                graph[u][v]['label'] = _anchor_label(cliques[u], inter_atoms)
                graph[v][u]['label'] = _anchor_label(cliques[v], inter_atoms)

        return graph
예제 #7
0
    def tensorize(mol_batch, vocab, avocab, target=False, add_target=False):
        """Filter ``mol_batch`` and convert it into batched tree/graph tensors.

        ``set_anchor()`` is called on every molecule first. A molecule is
        dropped when its junction tree (``mol.mol_tree``) has an edge whose
        ``anchor`` lists more than two atoms; such anchors do not fit the
        2-column ``cgraph`` built below. Dropped molecules have their SMILES
        printed and are removed from ``mol_batch`` in place. A
        ``penalized_logp`` score is computed for every kept molecule.

        Args:
            mol_batch: mutable list of molecule objects exposing
                ``set_anchor()``, ``mol_tree``, ``mol_graph``, ``cliques``,
                ``order`` and ``smiles``. Filtered in place.
            vocab: vocabulary forwarded to ``MolTree.tensorize_graph`` for
                the junction trees.
            avocab: vocabulary forwarded for the atom graphs.
            target: if True, node/edge masks are added to the tree tensors
                and atom/bond masks to the graph tensors, derived from each
                tree node's ``revise`` flag.
            add_target: if True, each molecule's order list is prefixed with
                ``(None, node)`` entries for tree nodes flagged ``target``.

        Returns:
            ``((tree_batchG, graph_batchG), (tree_tensors, graph_tensors),
            all_orders, scores)``. ``scores`` is a ``FloatTensor`` with one
            entry per kept molecule.
        """
        kept, scores = [], []
        for mol in mol_batch:
            mol.set_anchor()
            tree = mol.mol_tree
            if any(len(tree[u][v]['anchor']) > 2 for u, v in tree.edges):
                print(mol.smiles)
                continue
            kept.append(mol)
            scores.append(penalized_logp(mol.smiles))
        if len(kept) != len(mol_batch):
            # Filter the caller's list in place.
            mol_batch[:] = kept
        scores = torch.FloatTensor(scores)

        tree_tensors, tree_batchG = MolTree.tensorize_graph(
            [x.mol_tree for x in mol_batch], vocab)
        graph_tensors, graph_batchG = MolTree.tensorize_graph(
            [x.mol_graph for x in mol_batch], avocab, tree=False)
        tree_scope = tree_tensors[-1]
        graph_scope = graph_tensors[-1]

        # Anchor atom indices per tree message, indexed by 'mess_idx'.
        cgraph = torch.zeros(len(tree_batchG.edges) + 1, 2).int()
        for u, v, attr in tree_batchG.edges(data=True):
            anchor = tree_batchG[u][v]['anchor']
            cgraph[attr['mess_idx'], :len(anchor)] = torch.LongTensor(anchor)

        # All atom indices of each tree node, shifted by the molecule's
        # offset in the batched atom graph; node attributes are updated too.
        max_cls_size = max(len(c) for x in mol_batch for c in x.cliques)
        dgraph = torch.zeros(len(tree_batchG) + 1, max_cls_size).long()
        for v, attr in tree_batchG.nodes(data=True):
            offset = graph_scope[attr['batch_id']][0]
            cls = [x + offset for x in attr['clq']]
            tree_batchG.nodes[v]['clq'] = cls
            tree_batchG.nodes[v]['bonds'] = [(x + offset, y + offset)
                                             for x, y in attr['bonds']]
            dgraph[v, :len(cls)] = torch.LongTensor(cls)

        # Atom-pair -> message index lookup for the batched atom graph.
        egraph = torch.zeros(len(graph_batchG) + 1,
                             len(graph_batchG) + 1).long()
        for u, v, attr in graph_batchG.edges(data=True):
            egraph[u, v] = attr['mess_idx']

        all_orders = []
        for i, hmol in enumerate(mol_batch):
            offset = tree_scope[i][0]
            order = [(x + offset, y + offset, tree_tensors[0][y + offset])
                     for x, y in hmol.order]
            if add_target:
                target_idx = [(None, x + offset) for x in hmol.mol_tree.nodes
                              if hmol.mol_tree.nodes[x]['target']]
                order = target_idx + order
            all_orders.append(order)

        tree_tensors = tree_tensors[:4] + (cgraph, dgraph, tree_scope)
        graph_tensors = graph_tensors[:4] + (egraph, graph_scope)

        if target:
            node_mask = torch.ones(len(tree_batchG) + 1, 1).int()
            edge_mask = torch.ones(len(tree_batchG.edges) + 1, 1).int()
            atom_mask = torch.zeros(len(graph_batchG) + 1, 1).int()
            bond_mask = torch.ones(len(graph_batchG.edges) + 1, 1).int()

            # Revised tree nodes get node_mask 0; atoms of every other node
            # get atom_mask 1.
            for v, attr in tree_batchG.nodes(data=True):
                if attr['revise']:
                    node_mask[v] = 0
                else:
                    atom_mask.scatter_(
                        0, dgraph[v, :len(attr['clq'])].unsqueeze(1), 1)

            # Tree messages touching a revised node are masked out.
            for u, v in tree_batchG.edges:
                if (tree_batchG.nodes[u]['revise']
                        or tree_batchG.nodes[v]['revise']):
                    edge_mask[tree_batchG[u][v]['mess_idx']] = 0

            # Atoms not covered by any non-revised node. The first hit is
            # skipped — presumably padding row 0 (TODO confirm).
            masked_atoms = (atom_mask == 0).nonzero()[:, 0]

            # Message indices on the out-edges of each masked atom.
            mess_list = [
                torch.LongTensor([graph_batchG[a1][edge[1]]['mess_idx']
                                  for edge in graph_batchG.edges(a1)])
                for a1 in masked_atoms[1:].tolist()
            ]
            # NOTE(review): bond_mask starts as all ones, so scattering 1
            # changes nothing. If the intent was to mask these messages it
            # should scatter 0 (or start from zeros) — confirm with the model.
            if mess_list:  # torch.cat rejects an empty list
                mess = torch.cat(mess_list, dim=0).unsqueeze(1)
                bond_mask.scatter_(0, mess, 1)

            tree_tensors = tree_tensors[:-1] + (node_mask, edge_mask,
                                                tree_scope)
            graph_tensors = graph_tensors[:-1] + (atom_mask, bond_mask,
                                                  graph_scope)

        return ((tree_batchG, graph_batchG), (tree_tensors, graph_tensors),
                all_orders, scores)