Exemple #1
0
def fix_valence_charge(mol: Chem.rdchem.Mol, inplace: bool = False) -> Optional[Chem.rdchem.Mol]:
    """Repair valence problems that stem from wrong formal charges.

    When the RDKit validation reports at least one issue, every atom gets its
    formal charge reassigned to
    ``implicit valence + explicit valence - default valence``.

    Args:
        mol: Input molecule, possibly with an incorrect valence on some atoms.
        inplace: If False, the correction is applied to a `copy.copy` of `mol`;
            if True, `mol` itself is mutated.

    Returns:
        The charge-corrected molecule, or the input molecule untouched when
        the validation reports nothing.
    """
    validator = rdMolStandardize.RDKitValidation()

    # Nothing reported by the validation: do not touch a healthy molecule.
    if len(validator.validate(mol)) == 0:
        return mol

    if not inplace:
        mol = copy.copy(mol)

    # Refresh the valence information (non-strict) before reading it per atom.
    mol.UpdatePropertyCache(False)

    for atom in mol.GetAtoms():
        total_valence = atom.GetImplicitValence() + atom.GetExplicitValence()
        default_valence = dm.PERIODIC_TABLE.GetDefaultValence(atom.GetSymbol())
        atom.SetFormalCharge(total_valence - default_valence)

    return mol
Exemple #2
0
def center_of_mass(
    mol: Chem.rdchem.Mol,
    use_atoms: bool = True,
    digits: Optional[int] = None,
    conf_id: int = -1,
) -> np.ndarray:
    """Compute the center of mass of a conformer of a molecule.

    Args:
        mol: a molecule holding the conformer to measure.
        use_atoms: If True, every atom is weighted by its mass (true center of
            mass). If False, all atoms weigh the same (geometrical center).
        digits: Number of digits to round to. No rounding when None.
        conf_id: the conformer id. With the default (-1) the coordinates come
            from `get_coords(mol)`; any other value selects that conformer.

    Returns:
        cm: Center of mass or geometrical center.
    """
    if conf_id == -1:
        # Default path, unchanged: let the shared helper pick the conformer.
        coords = get_coords(mol)
    else:
        # `conf_id` used to be silently ignored; read the requested conformer
        # directly so an explicit id is honoured.
        coords = mol.GetConformer(conf_id).GetPositions()

    if use_atoms:
        atom_weight = np.array([atom.GetMass() for atom in mol.GetAtoms()])
    else:
        atom_weight = np.ones(coords.shape[0])

    # Normalise to a column of fractions so it broadcasts over the coordinates.
    atom_weight = atom_weight[:, None] / atom_weight.sum()
    center = (coords * atom_weight).sum(axis=0)

    if digits is not None:
        center = center.round(digits)

    return center
def num_atom_in_ring(mol: Chem.rdchem.Mol):
    """Count the atoms of `mol` that are ring members.

    Returns:
        A one-element list (count vector) holding the number of ring atoms.
    """
    ring_atom_total = sum(1 for atom in mol.GetAtoms() if atom.IsInRing())
    return [ring_atom_total]
def is_zwitterion(mol: Chem.rdchem.Mol):
    """Identify whether the molecule is a zwitterion.

    A molecule is flagged when it carries at least one positive and at least
    one negative formal charge and its formal charges sum to zero. A molecule
    that is merely charged (for example a lone cation) is not a zwitterion.

    NOTE(review): charge-separated groups written with formal charges (for
    example a nitro group drawn as ``[N+](=O)[O-]``) also satisfy this test;
    confirm whether callers want those excluded.

    Args:
        mol: a molecule.

    Returns:
        ``[1]`` if the molecule is flagged as a zwitterion, ``[0]`` otherwise.
    """
    charges = [atom.GetFormalCharge() for atom in mol.GetAtoms()]

    has_positive = any(charge > 0 for charge in charges)
    has_negative = any(charge < 0 for charge in charges)
    net_neutral = sum(charges) == 0

    return [int(has_positive and has_negative and net_neutral)]
Exemple #5
0
def copy_edit_mol(mol: Chem.rdchem.Mol) -> Chem.rdchem.Mol:
    """Rebuild `mol` as a fresh editable molecule.

    Each atom is duplicated with `copy_atom` and appended in iteration order,
    then each bond is re-created between the same atom indices with the same
    bond type.

    Args:
        mol: the molecule to duplicate.

    Returns:
        A new `Chem.RWMol` holding the copied atoms and bonds.
    """
    # Start from an empty molecule and grow it atom by atom.
    editable = Chem.RWMol(Chem.MolFromSmiles(''))

    for source_atom in mol.GetAtoms():
        editable.AddAtom(copy_atom(source_atom))

    for source_bond in mol.GetBonds():
        editable.AddBond(
            source_bond.GetBeginAtomIdx(),
            source_bond.GetEndAtomIdx(),
            source_bond.GetBondType(),
        )

    return editable
Exemple #6
0
def atom_indices_to_mol(mol: Chem.rdchem.Mol, copy: bool = False):
    """Tag every atom with its own index via the `molAtomMapNumber` property.

    Args:
        mol: a molecule
        copy: When exactly `True`, the tagging is done on a copy obtained from
            `copy_mol`; otherwise `mol` itself is modified.

    Returns:
        The tagged molecule (the copy or the input, depending on `copy`).
    """
    target = copy_mol(mol) if copy is True else mol

    for atom in target.GetAtoms():
        index_label = str(atom.GetIdx())
        atom.SetProp("molAtomMapNumber", index_label)

    return target
Exemple #7
0
def adjust_singleton(mol: Chem.rdchem.Mol) -> Optional[Chem.rdchem.Mol]:
    """Drop the atoms that are disconnected singleton nodes of the molecular graph.

    An atom is treated as a singleton when its explicit valence is zero.
    For example, the chlorine atom and the methane fragment are removed from
    `Cl.[N:1]1=CC(O)=CC2CCCCC12.CC.C`, but the ethane fragment is kept.

    Args:
        mol: a molecule.

    Returns:
        A new molecule built from an editable copy of `mol` with the singleton
        atoms removed.
    """
    editable = Chem.RWMol(mol)

    singleton_ids = [
        atom.GetIdx() for atom in mol.GetAtoms() if atom.GetExplicitValence() == 0
    ]

    # Remove from the highest index down so the remaining indices stay valid.
    for atom_idx in sorted(singleton_ids, reverse=True):
        editable.RemoveAtom(atom_idx)

    return editable.GetMol()
Exemple #8
0
def to_neutral(mol: Chem.rdchem.Mol) -> Optional[Chem.rdchem.Mol]:
    """Neutralize the charge of a molecule, modifying it in place.

    An atom has its formal charge reset to zero when it is negatively charged,
    or when it is positively charged while its explicit valence is at least
    the default valence of its element.

    Args:
        mol: a molecule, or None.

    Returns:
        mol: the same molecule object (None is passed through unchanged).
    """
    if mol is None:
        return mol

    for atom in mol.GetAtoms():
        charge = atom.GetFormalCharge()

        if charge >= 0:
            # Non-negative atoms are only reset when they are positively
            # charged and already at (or above) the element's default valence.
            at_or_above_default = atom.GetExplicitValence() >= PERIODIC_TABLE.GetDefaultValence(
                atom.GetSymbol()
            )
            if not (at_or_above_default and charge > 0):
                continue

        atom.SetFormalCharge(0)
        atom.UpdatePropertyCache(False)

    return mol
Exemple #9
0
def atom_graph(mol: Chem.rdchem.Mol):
    """
    Build a NetworkX graph whose nodes are the atoms of an RDKit Mol object.

    Each node is keyed by the atom index and stores the individual atom
    descriptors plus a `features` array holding the same descriptors in the
    same order. Each edge links the two atoms of a bond and stores the bond
    type as `bond_type`.

    Function taken from https://github.com/maxhodak/keras-molecules/pull/32/files.

    :returns: a NetworkX graph, or None when `mol` is falsy.
    """
    if not mol:
        return None

    graph = nx.Graph()

    for atom in mol.GetAtoms():
        # Read every descriptor once and reuse it for the feature vector.
        descriptors = {
            "atomic_num": atom.GetAtomicNum(),
            "formal_charge": atom.GetFormalCharge(),
            "chiral_tag": atom.GetChiralTag(),
            "hybridization": atom.GetHybridization(),
            "num_explicit_hs": atom.GetNumExplicitHs(),
            "is_aromatic": atom.GetIsAromatic(),
            "mass": atom.GetMass(),
            "implicit_valence": atom.GetImplicitValence(),
            "total_hydrogens": atom.GetTotalNumHs(),
        }
        graph.add_node(
            atom.GetIdx(),
            **descriptors,
            features=np.array(list(descriptors.values())),
        )

    for bond in mol.GetBonds():
        begin_idx = bond.GetBeginAtomIdx()
        end_idx = bond.GetEndAtomIdx()
        graph.add_edge(begin_idx, end_idx, bond_type=bond.GetBondType())

    return graph
Exemple #10
0
def bond_graph(mol: Chem.rdchem.Mol):
    """
    Generates the bond graph from an RDKit Mol object.

    Here, unlike the atom graph, bonds are nodes, and are
    connected to each other by atoms. Each node is keyed by the
    `(begin atom index, end atom index)` tuple of its bond and stores the bond
    descriptors plus a numeric `features` list (the stereo descriptor is kept
    as a node attribute but is not part of `features`). Each edge stores the
    index of the atom shared by the two bonds as `atom`.

    :returns: a NetworkX graph, or None when `mol` is falsy.
    """
    if not mol:
        return None

    graph = nx.Graph()

    for bond in mol.GetBonds():
        bond_order = bond.GetBondTypeAsDouble()
        aromatic = bond.GetIsAromatic()
        in_ring = bond.IsInRing()
        conjugated = bond.GetIsConjugated()
        graph.add_node(
            (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()),
            bond_type=bond_order,
            aromatic=aromatic,
            stereo=bond.GetStereo(),
            in_ring=in_ring,
            is_conjugated=conjugated,
            features=[
                bond_order,
                int(aromatic),
                int(in_ring),
                int(conjugated),
            ],
        )

    for atom in mol.GetAtoms():
        # Every pair of bonds around this atom is joined by this very atom, so
        # its index is the joining node. `combinations` yields nothing for
        # fewer than two bonds, so no explicit length guard is required.
        joining_node = atom.GetIdx()
        for first, second in combinations(atom.GetBonds(), 2):
            node_a = (first.GetBeginAtomIdx(), first.GetEndAtomIdx())
            node_b = (second.GetBeginAtomIdx(), second.GetEndAtomIdx())
            # The graph is undirected: a single insertion covers both directions.
            graph.add_edge(node_a, node_b, atom=joining_node)

    return graph
Exemple #11
0
def sasa(
    mol: Chem.rdchem.Mol,
    conf_id: Optional[Union[int, List[int]]] = None,
    n_jobs: int = 1,
) -> np.ndarray:
    """Compute Solvent Accessible Surface Area of all the conformers
    using FreeSASA (https://freesasa.github.io/). Values are returned
    as an array and also stored within each conformer as a property
    called `rdkit_free_sasa`.

    Example:

    ```python
    smiles = "O=C(C)Oc1ccccc1C(=O)O"
    mol = dm.to_mol(smiles)
    mol = dm.conformers.generate(mol)

    # Compute SASA for all the conformers without parallelization
    sasa_values = dm.conformers.sasa(mol, conf_id=None, n_jobs=1)

    # The value is also stored on each conformer.
    conf = mol.GetConformer(0)
    props = conf.GetPropsAsDict()
    print(props["rdkit_free_sasa"])
    ```

    Args:
        mol: a molecule
        conf_id: Id (or list of ids) of the conformers to compute. If None,
            compute all.
        n_jobs: Number of jobs for parallelization. Set to 1 to disable
            and -1 to use all cores.

    Returns:
        An array with one SASA value per requested conformer, in the order of
        the requested ids.

    Raises:
        ValueError: if the molecule has no conformers.
    """
    from rdkit.Chem import rdFreeSASA

    if mol.GetNumConformers() == 0:
        raise ValueError(
            "The molecule has 0 conformers. You can generate conformers with `dm.conformers.generate(mol)`."
        )

    # Get Van der Waals radii (angstrom)
    radii = [
        dm.PERIODIC_TABLE.GetRvdw(atom.GetAtomicNum())
        for atom in mol.GetAtoms()
    ]

    # Which conformers to compute
    if conf_id is None:
        # Use the real conformer ids: they are not guaranteed to be the
        # contiguous range 0..n-1 (for example after conformers were removed).
        conf_ids = [conf.GetId() for conf in mol.GetConformers()]
    elif isinstance(conf_id, int):
        conf_ids = [conf_id]
    else:
        conf_ids = conf_id

    # Compute solvent accessible surface area
    # NOTE(review): with n_jobs != 1 the property is set inside the worker;
    # whether it reaches the caller's `mol` depends on how `dm.JobRunner`
    # dispatches the work - confirm.
    def _get_sasa(i):
        conf = mol.GetConformer(i)
        value = rdFreeSASA.CalcSASA(mol, radii, confIdx=conf.GetId())
        conf.SetDoubleProp("rdkit_free_sasa", value)
        return value

    runner = dm.JobRunner(n_jobs=n_jobs)
    sasa_values = runner(_get_sasa, conf_ids)
    return np.array(sasa_values)
Exemple #12
0
def tree_decomp(
        mol: Chem.rdchem.Mol) -> Tuple[List[List[int]], List[Tuple[int, int]]]:
    """Decompose a molecule into cliques and link them with a spanning tree.

    Cliques are lists of atom indices: isolated atoms (degree 0), non-ring
    bonds (two atoms) and the rings returned by `Chem.GetSymmSSSR`, where
    rings sharing more than two atoms are merged. Extra single-atom cliques
    are appended for atoms that sit at the junction of several cliques.

    Args:
        mol: a molecule.

    Returns:
        A `(cliques, edges)` pair. `cliques` is the list of atom-index lists;
        `edges` is a list of `(clique index, clique index)` pairs taken from
        the spanning tree, or an empty list when no two cliques share an atom.
    """
    n_atoms = mol.GetNumAtoms()
    cliques = []
    # Isolated atoms form their own one-atom clique.
    for atom in mol.GetAtoms():
        if atom.GetDegree() == 0:
            cliques.append([atom.GetIdx()])

    # Every bond outside a ring becomes a two-atom clique.
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if not bond.IsInRing():
            cliques.append([a1, a2])

    # Rings are added after the atom and bond cliques.
    ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
    cliques.extend(ssr)

    # nei_list[a] lists the indices of the cliques that contain atom `a`.
    nei_list = [[] for i in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)

    # Merge Rings with intersection > 2 atoms
    # Only cliques with more than two atoms take part, and only pairs with
    # i < j are considered; the absorbed clique `j` is emptied, not removed,
    # so clique indices stay stable and `nei_list` (not refreshed here) can
    # still be used. Emptied cliques are skipped by the `<= 2` length test.
    # NOTE: on the first merge for a given `i`, `extend` grows the very list
    # the enclosing `for atom in cliques[i]` is walking, so the absorbed
    # atoms are visited too. After the rebinding on the following line, later
    # merges extend the new list, which the running loop does not see.
    for i in range(len(cliques)):
        if len(cliques[i]) <= 2: continue
        for atom in cliques[i]:
            for j in nei_list[atom]:
                if i >= j or len(cliques[j]) <= 2: continue
                inter = set(cliques[i]) & set(cliques[j])
                if len(inter) > 2:
                    cliques[i].extend(cliques[j])
                    cliques[i] = list(set(cliques[i]))
                    cliques[j] = []

    # Drop the emptied cliques and rebuild the atom -> cliques lookup, since
    # the clique indices have changed.
    cliques = [c for c in cliques if len(c) > 0]
    nei_list = [[] for i in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)

    # Build edges and add singleton cliques
    # `edges` maps a (clique index, clique index) pair to a weight; missing
    # pairs read as 0 because of `defaultdict(int)`.
    edges = defaultdict(int)
    for atom in range(n_atoms):
        # Atoms that belong to at most one clique connect nothing.
        if len(nei_list[atom]) <= 1:
            continue
        cnei = nei_list[atom]
        # `bonds`: two-atom cliques around the atom; `rings`: cliques with
        # more than four atoms around the atom.
        bonds = [c for c in cnei if len(cliques[c]) == 2]
        rings = [c for c in cnei if len(cliques[c]) > 4]
        if len(bonds) > 2 or (
                len(bonds) == 2 and len(cnei) > 2
        ):  # In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
            # The atom becomes its own clique, linked to each neighbouring
            # clique with weight 1.
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = 1
        elif len(rings) > 2:  # Multiple (n>2) complex rings
            # Same singleton construction, with weight MST_MAX_WEIGHT - 1.
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = MST_MAX_WEIGHT - 1
        else:
            # Link every pair of neighbouring cliques directly; the weight is
            # the largest number of shared atoms seen so far for that pair.
            for i in range(len(cnei)):
                for j in range(i + 1, len(cnei)):
                    c1, c2 = cnei[i], cnei[j]
                    inter = set(cliques[c1]) & set(cliques[c2])
                    if edges[(c1, c2)] < len(inter):
                        edges[(c1, c2)] = len(
                            inter)  # cnei[i] < cnei[j] by construction

    # Flatten to (c1, c2, MST_MAX_WEIGHT - weight) triples: flipping the
    # weights lets a minimum spanning tree favour the largest original weights.
    edges = [u + (MST_MAX_WEIGHT - v, ) for u, v in edges.items()]
    if len(edges) == 0:
        return cliques, edges

    # Compute Maximum Spanning Tree
    # (`minimum_spanning_tree` is presumably scipy.sparse.csgraph's - confirm
    # against the file's imports.)
    row, col, data = zip(*edges)
    n_clique = len(cliques)
    clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
    junc_tree = minimum_spanning_tree(clique_graph)
    # Keep only the endpoints of the edges retained in the tree.
    row, col = junc_tree.nonzero()
    edges = [(row[i], col[i]) for i in range(len(row))]

    return cliques, edges
Exemple #13
0
def has_charge(mol: Chem.rdchem.Mol) -> bool:
    """Tell whether any atom of `mol` carries a non-zero formal charge.

    Args:
        mol: a molecule.

    Returns:
        True as soon as one charged atom is found, False otherwise.
    """
    return any(atom.GetFormalCharge() != 0 for atom in mol.GetAtoms())