Exemple #1
0
def copy_edit_mol(mol: Chem.rdchem.Mol) -> Chem.rdchem.Mol:
    """
    Rebuild ``mol`` as a fresh editable molecule.

    Every atom is duplicated through ``copy_atom`` and every bond is
    re-created between the same atom indices with the same bond type.

    :param mol: molecule to copy.
    :returns: the newly built ``Chem.RWMol``.
    """
    editable = Chem.RWMol(Chem.MolFromSmiles(''))

    # Atoms are appended in iteration order, so the bond loop below reuses
    # the source atom indices unchanged.
    for source_atom in mol.GetAtoms():
        editable.AddAtom(copy_atom(source_atom))

    for source_bond in mol.GetBonds():
        editable.AddBond(
            source_bond.GetBeginAtom().GetIdx(),
            source_bond.GetEndAtom().GetIdx(),
            source_bond.GetBondType(),
        )

    return editable
Exemple #2
0
def atom_graph(mol: Chem.rdchem.Mol):
    """
    Generates the atom graph from an RDKit Mol object.

    Each atom becomes a node keyed by its atom index. The node carries one
    named attribute per atom property plus a ``features`` array holding the
    same values in the same order. Each bond becomes an edge between its
    begin and end atom indices, with a ``bond_type`` attribute.

    Function taken from https://github.com/maxhodak/keras-molecules/pull/32/files.

    :param mol: molecule to convert.
    :returns: a NetworkX graph, or ``None`` when ``mol`` is falsy.
    """
    if not mol:
        return None

    G = nx.Graph()
    for atom in mol.GetAtoms():
        # Query every property exactly once; the same values feed both the
        # named node attributes and the feature vector (dict insertion order
        # defines the feature order).
        attributes = {
            "atomic_num": atom.GetAtomicNum(),
            "formal_charge": atom.GetFormalCharge(),
            "chiral_tag": atom.GetChiralTag(),
            "hybridization": atom.GetHybridization(),
            "num_explicit_hs": atom.GetNumExplicitHs(),
            "is_aromatic": atom.GetIsAromatic(),
            "mass": atom.GetMass(),
            "implicit_valence": atom.GetImplicitValence(),
            "total_hydrogens": atom.GetTotalNumHs(),
        }
        G.add_node(
            atom.GetIdx(),
            **attributes,
            features=np.array(list(attributes.values())),
        )

    for bond in mol.GetBonds():
        G.add_edge(
            bond.GetBeginAtomIdx(),
            bond.GetEndAtomIdx(),
            bond_type=bond.GetBondType(),
        )
    return G
Exemple #3
0
def bond_graph(mol: Chem.rdchem.Mol):
    """
    Generates the bond graph from an RDKit Mol object.

    Here, unlike the atom graph, bonds are nodes, and are
    connected to each other by atoms. A node is keyed by the tuple
    ``(begin_atom_idx, end_atom_idx)`` of its bond. Two nodes are joined
    by an edge when their bonds belong to the same atom; the edge's
    ``atom`` attribute holds the index shared by the two node keys.

    :param mol: molecule to convert.
    :returns: a NetworkX graph, or ``None`` when ``mol`` is falsy.
    """
    if not mol:
        return None

    G = nx.Graph()
    for bond in mol.GetBonds():
        # Read each property once and reuse it for both the named
        # attribute and the numeric feature list.
        bond_type = bond.GetBondTypeAsDouble()
        aromatic = bond.GetIsAromatic()
        in_ring = bond.IsInRing()
        conjugated = bond.GetIsConjugated()
        G.add_node(
            (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()),
            bond_type=bond_type,
            aromatic=aromatic,
            stereo=bond.GetStereo(),
            in_ring=in_ring,
            is_conjugated=conjugated,
            # Stereo is stored as a node attribute but is not part of
            # the feature list.
            features=[
                bond_type,
                int(aromatic),
                int(in_ring),
                int(conjugated),
            ],
        )

    for atom in mol.GetAtoms():
        # combinations() yields nothing for fewer than two bonds, so no
        # explicit length guard is needed.
        for b1, b2 in combinations(atom.GetBonds(), 2):
            n1 = (b1.GetBeginAtomIdx(), b1.GetEndAtomIdx())
            n2 = (b2.GetBeginAtomIdx(), b2.GetEndAtomIdx())
            joining_node = list(set(n1).intersection(n2))[0]
            # The graph is undirected: one add_edge call covers both
            # directions.
            G.add_edge(n1, n2, atom=joining_node)
    return G
Exemple #4
0
def tree_decomp(
        mol: Chem.rdchem.Mol) -> Tuple[List[List[int]], List[Tuple[int, int]]]:
    """
    Decompose a molecule into cliques and link the cliques into a tree.

    Cliques are built from isolated atoms (degree 0), from bonds that are
    not in a ring (atom pairs), and from the rings returned by
    ``Chem.GetSymmSSSR``. Rings sharing more than two atoms are merged.
    Cliques that share an atom are then connected by weighted candidate
    edges (extra single-atom cliques are added at some junctions), and the
    candidate edges are reduced with ``minimum_spanning_tree``.

    :param mol: molecule to decompose.
    :returns: ``(cliques, edges)``; ``cliques`` is a list of atom-index
        lists and ``edges`` is a list of clique-index pairs. ``edges`` is
        an empty list when no atom belongs to more than one clique.
    """
    n_atoms = mol.GetNumAtoms()
    cliques = []
    # Atoms without any bond form their own one-atom clique.
    for atom in mol.GetAtoms():
        if atom.GetDegree() == 0:
            cliques.append([atom.GetIdx()])

    # Every bond outside a ring is a two-atom clique.
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if not bond.IsInRing():
            cliques.append([a1, a2])

    # Each ring reported by GetSymmSSSR is a clique.
    ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
    cliques.extend(ssr)

    # nei_list[a] = indices of the cliques that contain atom a.
    nei_list = [[] for i in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)

    # Merge Rings with intersection > 2 atoms
    # Cliques of size <= 2 (single atoms, bonds) are never merged. Only
    # pairs with i < j are considered; the absorbed clique j is emptied.
    # NOTE: nei_list is not refreshed inside this loop, and
    # cliques[i].extend() mutates the list object -- if that object is the
    # one still being iterated by the enclosing `for atom` loop, the
    # appended atoms are visited as well. Statement order matters here.
    for i in range(len(cliques)):
        if len(cliques[i]) <= 2: continue
        for atom in cliques[i]:
            for j in nei_list[atom]:
                if i >= j or len(cliques[j]) <= 2: continue
                inter = set(cliques[i]) & set(cliques[j])
                if len(inter) > 2:
                    cliques[i].extend(cliques[j])
                    cliques[i] = list(set(cliques[i]))
                    cliques[j] = []

    # Drop the cliques emptied by merging and rebuild the atom -> cliques
    # index, since clique indices have shifted.
    cliques = [c for c in cliques if len(c) > 0]
    nei_list = [[] for i in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)

    # Build edges and add singleton cliques
    # edges maps (clique_a, clique_b) -> weight; a larger weight means a
    # more preferred edge (it is turned into a smaller cost further down).
    edges = defaultdict(int)
    for atom in range(n_atoms):
        # An atom in at most one clique cannot connect anything.
        if len(nei_list[atom]) <= 1:
            continue
        cnei = nei_list[atom]
        bonds = [c for c in cnei if len(cliques[c]) == 2]
        rings = [c for c in cnei if len(cliques[c]) > 4]
        if len(bonds) > 2 or (
                len(bonds) == 2 and len(cnei) > 2
        ):  # In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
            # Junction atom: add it as its own clique and connect every
            # neighbouring clique to it with weight 1.
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = 1
        elif len(rings) > 2:  # Multiple (n>2) complex rings
            # Same singleton construction, but with the weight
            # MST_MAX_WEIGHT - 1.
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = MST_MAX_WEIGHT - 1
        else:
            # Connect every pair of cliques around this atom; the weight is
            # the number of atoms they share (kept at its maximum when the
            # pair is seen again through another shared atom).
            for i in range(len(cnei)):
                for j in range(i + 1, len(cnei)):
                    c1, c2 = cnei[i], cnei[j]
                    inter = set(cliques[c1]) & set(cliques[c2])
                    if edges[(c1, c2)] < len(inter):
                        edges[(c1, c2)] = len(
                            inter)  # cnei[i] < cnei[j] by construction

    # Turn each weight v into a cost MST_MAX_WEIGHT - v, giving
    # (clique_a, clique_b, cost) triples, so that the minimum spanning tree
    # over costs is a maximum spanning tree over weights.
    edges = [u + (MST_MAX_WEIGHT - v, ) for u, v in edges.items()]
    if len(edges) == 0:
        return cliques, edges

    # Compute Maximum Spanning Tree
    row, col, data = zip(*edges)
    n_clique = len(cliques)
    clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
    junc_tree = minimum_spanning_tree(clique_graph)
    # Keep only the (row, col) positions of the edges retained by the tree.
    row, col = junc_tree.nonzero()
    edges = [(row[i], col[i]) for i in range(len(row))]

    return cliques, edges
Exemple #5
0
def all_bond_remove(
        mol: Chem.rdchem.Mol,
        as_mol: bool = True,
        allow_bond_decrease: bool = True,
        allow_atom_trim: bool = True,
        max_num_action=float("Inf"),
):
    """Remove bonds from a molecule

    For every bond, the bond is cut and the fragments of the sanitized
    result are collected. Optionally, double and triple bonds are also
    lowered to a weaker bond type and those results are collected too.

    Warning:
        This can be computationally expensive. ``Chem.Kekulize`` is called
        on ``mol`` itself with ``clearAromaticFlags=True``, so the input
        molecule may be modified.

    Args:
        mol: Input molecule
        as_mol: Return molecule objects when True, otherwise SMILES strings
            produced by ``dm.to_smiles``
        allow_bond_decrease: Allow decreasing bond type in addition to bond cut
        allow_atom_trim: Keep the fragments of a bond cut even when one of
            them has fewer than two atoms
        max_num_action: Maximum number of action to reduce complexity. The
            limit is compared against the number of collected results before
            each bond is processed, so the final list may be longer.

    Returns:
        All possible molecules from removing bonds

    """
    new_mols = []

    # Best-effort kekulization: a molecule that cannot be kekulized is
    # processed as-is. Only ordinary errors are ignored, so that
    # KeyboardInterrupt / SystemExit still propagate.
    try:
        Chem.Kekulize(mol, clearAromaticFlags=True)
    except Exception:
        pass

    for bond in mol.GetBonds():
        if len(new_mols) > max_num_action:
            break

        original_bond_type = bond.GetBondType()

        # Cut the bond on an editable copy and sanitize the result.
        emol = Chem.RWMol(mol)
        emol.RemoveBond(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
        new_mol = dm.sanitize_mol(emol.GetMol())

        # A failed sanitization skips this bond entirely, including the
        # bond-decrease variants below.
        if not new_mol:
            continue

        frag_list = list(rdmolops.GetMolFrags(new_mol, asMols=True))
        has_single_atom = any(x.GetNumAtoms() < 2 for x in frag_list)
        if not has_single_atom or allow_atom_trim:
            new_mols.extend(frag_list)

        if allow_bond_decrease:
            # Double and triple bonds can be lowered to a single bond.
            if original_bond_type in (dm.DOUBLE_BOND, dm.TRIPLE_BOND):
                new_mol = update_bond(mol, bond, dm.SINGLE_BOND)
                if new_mol is not None:
                    new_mols.extend(
                        list(rdmolops.GetMolFrags(new_mol, asMols=True)))
            # Triple bonds can additionally be lowered to a double bond.
            if original_bond_type == dm.TRIPLE_BOND:
                new_mol = update_bond(mol, bond, dm.DOUBLE_BOND)
                if new_mol is not None:
                    new_mols.extend(
                        list(rdmolops.GetMolFrags(new_mol, asMols=True)))

    new_mols = [candidate for candidate in new_mols if candidate is not None]

    if not as_mol:
        return [dm.to_smiles(x) for x in new_mols if x]

    return new_mols