Example #1
0
def remove_salts_solvents(mol, hac=3):
    """
    Strip small fragments (e.g. solvents, counter-ions) from a molecule.

    The molecule is split into its disconnected fragments and every fragment
    whose atom count is not greater than 'hac' is discarded. Exactly one
    fragment must survive: it is returned. When zero or several fragments
    survive, a warning is emitted and None is returned.

    Adapted from the mol2vec package, available at:
        https://github.com/samoturk/mol2vec/blob/master/mol2vec/features.py
    """
    survivors = []
    for piece in rdmolops.GetMolFrags(mol, asMols=True):
        # only fragments strictly larger than the threshold are kept
        if piece.GetNumAtoms() > hac:
            survivors.append(piece)

    # the single unambiguous parent fragment is the only success case
    if len(survivors) == 1:
        return survivors[0]

    if survivors:
        problem = "molecule contains >1 fragment with >"
    else:
        problem = "molecule contains no fragments with >"
    warnings.warn(problem + str(hac) + " heavy atoms")
    return None
def filter_extract_mol(row, headers_dict):
    """Return the largest fragment of a row's molecule if the row passes all filters.

    A row is rejected (None is returned) when any of the following holds:
      * the NUM_RO5_VIOLATIONS field cannot be parsed as an int, or is > 0;
      * the PCHEMBL_VALUE field cannot be parsed as a float;
      * the RELATION field is not '=';
      * the SMILES cannot be parsed into a molecule;
      * the largest fragment's molecular weight is outside [300, 400].

    Args:
        row: Indexable record of field values.
        headers_dict: Mapping from column name to the position of that
            column in `row`.

    Returns:
        The largest fragment (by atom count) as a molecule, or None.
    """
    relation_idx = headers_dict['RELATION']
    ro5_idx = headers_dict['NUM_RO5_VIOLATIONS']
    pchembl_idx = headers_dict['PCHEMBL_VALUE']

    # Only conversion / lookup failures mean "unusable row"; a bare except
    # would also swallow KeyboardInterrupt and SystemExit.
    try:
        ro5_violations = int(row[ro5_idx])
        float(row[pchembl_idx])  # validated only, the value itself is unused
    except (ValueError, TypeError, IndexError, KeyError):
        return None

    if row[relation_idx] != '=' or ro5_violations > 0:
        return None

    # NOTE(review): `smiles_idx` is not defined inside this function and is
    # not looked up from `headers_dict`; presumably it is a module-level
    # name -- confirm, otherwise this line raises NameError.
    mol = Chem.MolFromSmiles(row[smiles_idx])
    if mol is None:
        return None

    mols = list(rdmolops.GetMolFrags(mol, asMols=True))
    if not mols:
        return None

    # keep the fragment with the most atoms
    mol = max(mols, key=lambda m: m.GetNumAtoms())

    mol_wt = Descriptors.MolWt(mol)
    if mol_wt < 300 or mol_wt > 400:
        return None
    return mol
Example #3
0
def keep_largest_fragment(mol: Chem.rdchem.Mol) -> Optional[Chem.rdchem.Mol]:
    """Return the disconnected fragment of `mol` with the most atoms.

    If the molecule yields no fragments at all, `mol` itself is returned.
    """
    pieces = rdmolops.GetMolFrags(mol, asMols=True)

    def atom_count(piece):
        return piece.GetNumAtoms()

    return max(pieces, key=atom_count, default=mol)
 def _split_products(self, smi):
     """Return one SMILES string per disconnected fragment of `smi`."""
     parent = AllChem.MolFromSmiles(smi)
     # each fragment is materialised as its own molecule, then serialised
     return [
         AllChem.MolToSmiles(piece)
         for piece in rdmolops.GetMolFrags(parent, asMols=True)
     ]
Example #5
0
def _getlargestFragment(mol):
    """Return the fragment of `mol` with the most heavy atoms.

    Fragments are sanitized while being extracted. Fragments that are None
    are ignored; on a tie the first fragment encountered wins. Returns None
    when no usable fragment exists.
    """
    pieces = rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
    usable = (piece for piece in pieces if piece is not None)
    # max() keeps the first maximal element, matching a strict "<" update rule
    return max(usable,
               key=lambda piece: piece.GetNumHeavyAtoms(),
               default=None)
Example #6
0
def n_disconnected(mol):
    """ Count the disconnected fragments of a molecule.

        Args:
            mol (skchem.Mol):
                The molecule for which to calculate the descriptor.

        Returns:
            int: the number of fragments reported by `GetMolFrags`.
        """

    fragment_atom_ids = rdmolops.GetMolFrags(mol)
    return len(fragment_atom_ids)
Example #7
0
def all_fragment_on_bond(mol,
                         asMols=False,
                         max_num_action=float("Inf"),
                         break_aromatic=True):
    """Cut every eligible bond of a molecule, one at a time, and collect the pieces.

    Deterministic counterpart of `random_bond_cut`: instead of one random
    fragment, the fragments produced by each single-bond cut are gathered.

    .. note::
        The result is always a subset of `all_bond_remove`: that function may
        also lower bond orders, whereas this one only cuts bonds.

    .. note::
        When `break_aromatic` is true and the molecule has bonds, the input
        molecule is kekulized with `clearAromaticFlags=True` before cutting.

    Args:
        mol: <Chem.Mol>
            input molecule
        asMols: bool, optional
            If true the set holds molecule objects, otherwise SMILES.
        max_num_action: float, optional
            Collection stops as soon as the set holds more than this many
            entries.
        break_aromatic: bool, optional
            Whether aromatic bonds are cut as well.
            (Default: True)

    Returns:
        set of fragments

    """
    mol.GetRingInfo().AtomRings()
    collected = set()
    all_bonds = list(mol.GetBonds())

    # nothing to cut
    if not all_bonds:
        return collected

    if break_aromatic:
        Chem.Kekulize(mol, clearAromaticFlags=True)

    for current in all_bonds:
        # aromatic bonds are skipped unless explicitly allowed
        if not break_aromatic and current.GetIsAromatic():
            continue

        cut = Chem.FragmentOnBonds(mol, [current.GetIdx()], addDummies=False)
        cut = dm.sanitize_mol(cut)
        if cut is None:
            continue

        for piece in rdmolops.GetMolFrags(cut, asMols=True):
            piece = dm.sanitize_mol(piece)
            if piece:
                collected.add(piece if asMols else dm.to_smiles(piece))
            # budget exhausted: abandon the remaining pieces and bonds
            if len(collected) > max_num_action:
                return collected

    return collected
Example #8
0
    def apply_retrorules(self, smile, rxns, explicit_hydrogens=False):
        '''Apply every reaction in `rxns` to a substrate SMILES.

           Args:
               smile: SMILES string of the substrate.
               rxns: dict mapping reaction name -> reaction object exposing
                   `RunReactants`.
               explicit_hydrogens: when True, hydrogens are added to the
                   substrate before running the reactions and removed from
                   each product afterwards.

           Returns:
               dict of rxn_name -> list of product sets, where each product
               set is a list of product SMILES that passed
               `self._check_valid_smile`. Duplicate and empty product sets
               are dropped, and reactions without any product set are
               omitted. An empty dict is returned when the substrate SMILES
               cannot be turned into a molecule.
        '''
        try:
            substrate_molecule = AllChem.MolFromSmiles(smile)
        except Exception:
            return {}

        # An unparsable SMILES yields None instead of raising; without this
        # guard AddHs(None) below would crash and every reaction would fail.
        if substrate_molecule is None:
            return {}

        # Same test as the original `== True` comparison, evaluated once.
        use_explicit_hs = explicit_hydrogens == True
        if use_explicit_hs:
            substrate_molecule = rdmolops.AddHs(substrate_molecule)

        rxn_product_dict = {}
        for rxn_name, rxn in rxns.items():
            try:
                products = rxn.RunReactants((substrate_molecule, ))
            except Exception:
                products = []
                print('Error running reactants for: ' + str(smile))

            smiles_products = []
            for product in products:
                sub_list = []
                for mol in product:
                    # fallback if the product cannot be split into fragments
                    mols = [mol]

                    if use_explicit_hs:
                        mol = rdmolops.RemoveHs(mol)

                    try:
                        mols = rdmolops.GetMolFrags(mol, asMols=True)
                    except Exception:
                        pass

                    for frag in mols:
                        # best effort: a fragment that cannot be converted or
                        # validated is skipped, the others are still kept
                        try:
                            p_smile = AllChem.MolToSmiles(frag)
                            p_smile = rdkit_smile(p_smile)
                            if self._check_valid_smile(
                                    p_smile, rxn_name=rxn_name) == True:
                                sub_list.append(p_smile)
                        except Exception:
                            pass

                if sub_list and sub_list not in smiles_products:
                    smiles_products.append(sub_list)

            if smiles_products:
                rxn_product_dict[rxn_name] = smiles_products

        return rxn_product_dict
Example #9
0
def _neutralise_sulphoxide(mol):
    """Rewrite charge-separated S(+)-O(-) pairs as neutral S=O.

    The reaction is applied to each disconnected fragment through
    `_apply_rxn`; fragments for which it returns None are discarded and the
    remaining ones are recombined into one molecule.

    Returns:
        The single processed fragment as-is, or the sanitized combination of
        several processed fragments. If no fragment survives, the input
        molecule is returned unchanged.
    """
    smirks = '[S+1:1][O-1:2]>>[S+0:1]=[O-0:2]'
    rxn = rdChemReactions.ReactionFromSmarts(smirks)
    frags = rdmolops.GetMolFrags(mol, asMols=True)
    processed = [_apply_rxn(frag, rxn) for frag in frags]
    n_frags = [frag for frag in processed if frag is not None]

    # Previously an empty list fell through to n_frags[0] and raised
    # IndexError; hand the input back instead.
    if not n_frags:
        return mol

    n_mol = n_frags[0]
    for frag in n_frags[1:]:
        n_mol = CombineMols(n_mol, frag)

    # Only a recombined molecule is re-sanitized, as before.
    if len(n_frags) > 1:
        SanitizeMol(n_mol)
    return n_mol
Example #10
0
def fuzzy_scaffolding(
    mols: List[Chem.rdchem.Mol],
    enforce_subs: List[str] = None,
    n_atom_cuttoff: int = 8,
    additional_templates: List[Chem.rdchem.Mol] = None,
    ignore_non_ring: bool = False,
    mcs_params: Dict[Any, Any] = None,
):
    """Generate fuzzy scaffold with enforceable group that needs to appear
    in the core, forcing to keep the full side chain if required.

    The molecules are first grouped by their generic Murcko scaffold. For each
    group an MCS is computed and used as the core of an R-group decomposition;
    the resulting R-groups are then passed to `trim_side_chain` to produce the
    final scaffold of every successfully decomposed molecule.

    NOTE(hadim): consider parallelize this (if possible).

    Args:
        mols: List of all molecules
        enforce_subs: List of substructure to enforce on the scaffold.
        n_atom_cuttoff: Minimum number of atom a core should have.
        additional_templates: Additional template to use to generate scaffolds.
        ignore_non_ring: Whether to ignore atom no in murcko ring system, even if they are in the framework.
        mcs_params: Arguments of MCS algorithm.

    Returns:
        scaffolds: set
            All found scaffolds in the molecules as valid smiles
        scaffold_infos: dict of dict
            Infos on the scaffold mapping, ignoring any side chain that had to be enforced.
            Key corresponds to generic scaffold smiles
            Values at ['smarts'] corresponds to smarts representation of the true scaffold (from MCS)
            Values at ['mols'] corresponds to list of molecules matching the scaffold
        scaffold_to_group: dict of list
            Map between each generic scaffold and the R-groups decomposition row
    """

    # Replace None defaults with fresh containers (avoids shared mutable defaults).
    if enforce_subs is None:
        enforce_subs = []

    if additional_templates is None:
        additional_templates = []

    if mcs_params is None:
        mcs_params = {}

    # R-group decomposition settings used for every scaffold cluster.
    rg_params = rdRGroupDecomposition.RGroupDecompositionParameters()
    rg_params.removeAllHydrogenRGroups = True
    rg_params.removeHydrogensPostMatch = True
    rg_params.alignment = rdRGroupDecomposition.RGroupCoreAlignment.MCS
    rg_params.matchingStrategy = rdRGroupDecomposition.RGroupMatching.Exhaustive
    rg_params.rgroupLabelling = rdRGroupDecomposition.RGroupLabelling.AtomMap
    rg_params.labels = rdRGroupDecomposition.RGroupLabels.AtomIndexLabels

    # Query adjustments applied to each core before it is handed to trim_side_chain.
    core_query_param = AdjustQueryParameters()
    core_query_param.makeDummiesQueries = True
    core_query_param.adjustDegree = False
    core_query_param.makeBondsGeneric = True

    # group molecules by they generic Murcko scaffold, allowing
    # side chain that contains cycle (might be a bad idea)
    scf2infos = collections.defaultdict(dict)
    scf2groups = {}
    all_scaffolds = set([])

    for m in mols:
        generic_m = MurckoScaffold.MakeScaffoldGeneric(m)
        scf = MurckoScaffold.GetScaffoldForMol(m)
        # Best effort: if the scaffold cannot be made generic, the plain
        # Murcko scaffold is used as-is.
        try:
            scf = MurckoScaffold.MakeScaffoldGeneric(scf)
        except:
            pass

        if ignore_non_ring:
            rw_scf = Chem.RWMol(scf)
            atms = [a.GetIdx() for a in rw_scf.GetAtoms() if not a.IsInRing()]
            # Remove from the highest index down so the indices still to be
            # removed are not shifted by earlier removals.
            atms.sort(reverse=True)
            for a in atms:
                rw_scf.RemoveAtom(a)
            # NOTE(review): with asMols=False these entries are tuples of atom
            # indices, whereas the other branch stores SMILES strings; the
            # dictionary keys below therefore differ in type between the two
            # modes -- confirm this is intended.
            scfs = list(rdmolops.GetMolFrags(rw_scf, asMols=False))
        else:
            scfs = [dm.to_smiles(scf)]

        # add templates mols if exists:
        for tmp in additional_templates:
            tmp = dm.to_mol(tmp)
            tmp_scf = MurckoScaffold.MakeScaffoldGeneric(tmp)
            if generic_m.HasSubstructMatch(tmp_scf):
                scfs.append(dm.to_smiles(tmp_scf))

        # Register the molecule under every scaffold key it belongs to.
        for scf in scfs:
            if scf2infos[scf].get("mols"):
                scf2infos[scf]["mols"].append(m)
            else:
                scf2infos[scf]["mols"] = [m]

    for scf in scf2infos:
        # cheat by adding murcko as last mol always
        popout = False
        # NOTE: from here on `mols` is rebound to the current cluster and no
        # longer refers to the function argument.
        mols = scf2infos[scf]["mols"]
        if len(mols) < 2:
            # A single-molecule cluster is padded with its own Murcko scaffold
            # (presumably because the MCS search needs two inputs -- TODO confirm).
            mols = mols + [MurckoScaffold.GetScaffoldForMol(mols[0])]
            popout = True

        # compute the MCS of the cluster
        mcs = rdFMCS.FindMCS(
            mols,
            atomCompare=rdFMCS.AtomCompare.CompareAny,
            bondCompare=rdFMCS.BondCompare.CompareAny,
            completeRingsOnly=True,
            **mcs_params,
        )

        mcsM = Chem.MolFromSmarts(mcs.smartsString)
        mcsM.UpdatePropertyCache(False)
        Chem.SetHybridization(mcsM)

        # Clusters whose core is too small are skipped: they get neither a
        # "smarts" entry in scf2infos nor an entry in scf2groups.
        if mcsM.GetNumAtoms() < n_atom_cuttoff:
            continue

        scf2infos[scf]["smarts"] = dm.to_smarts(mcsM)
        if popout:
            # drop the padding scaffold added above
            mols = mols[:-1]

        core_groups = []
        # generate rgroups based on the mcs core
        success_mols = []
        # Any failure in the decomposition leaves core_groups empty, so the
        # cluster contributes no scaffold.
        try:
            rg = rdRGroupDecomposition.RGroupDecomposition(mcsM, rg_params)
            for i, analog in enumerate(mols):
                # NOTE: this strips the conformers of the caller's molecules in place.
                analog.RemoveAllConformers()
                res = rg.Add(analog)
                # presumably a negative result means the molecule did not
                # match the core -- TODO confirm against the RDKit docs
                if not (res < 0):
                    success_mols.append(i)
            rg.Process()
            core_groups = rg.GetRGroupsAsRows()
        except Exception:
            pass

        # Keep only the molecules that were accepted, so they line up with
        # the rows in core_groups.
        mols = [mols[i] for i in success_mols]
        scf2groups[scf] = core_groups
        for mol, gp in zip(mols, core_groups):
            core = gp["Core"]
            # atom-map numbers of labelled core atoms that are not in a ring
            acceptable_groups = [
                a.GetAtomMapNum() for a in core.GetAtoms()
                if (a.GetAtomMapNum() and not a.IsInRing())
            ]

            rgroups = [
                gp[f"R{k}"] for k in acceptable_groups if f"R{k}" in gp.keys()
            ]
            if enforce_subs:
                # R-groups matching an enforced substructure are removed from
                # the list passed to trim_side_chain.
                # NOTE(review): enforce_subs is annotated as List[str] but each
                # entry is passed to GetSubstructMatch -- verify the expected type.
                rgroups = [
                    rgp for rgp in rgroups if not any([
                        len(rgp.GetSubstructMatch(frag)) > 0
                        for frag in enforce_subs
                    ])
                ]
            # Molecules whose side chains cannot be trimmed are silently skipped.
            try:
                scaff = trim_side_chain(
                    mol, AdjustQueryProperties(core, core_query_param),
                    rgroups)
            except:
                continue
            all_scaffolds.add(dm.to_smiles(scaff))

    return all_scaffolds, scf2infos, scf2groups
Example #11
0
def _fg_has_carbonyl_oxygen(mol, atom):
    """Return True if `atom` has an oxygen neighbour attached by a double bond."""
    idx = atom.GetIdx()
    return any(
        nbr.GetAtomicNum() == 8 and str(
            mol.GetBondBetweenAtoms(idx, nbr.GetIdx()).GetBondType()) ==
        "DOUBLE" for nbr in atom.GetNeighbors())


def _fg_cut_policy(mol, atom, group):
    """Return (skip_hydrogens, label_aromatic) for the bonds leaving `atom`.

    skip_hydrogens: bonds to hydrogen neighbours are not cut.
    label_aromatic: cuts towards an aromatic neighbour get the (1, 1) label.
    """
    atomic_num = atom.GetAtomicNum()
    single_atom_group = len(group) == 1

    if atomic_num == 6:
        # Carbonyl carbons keep their hydrogens (aldehyde vs ketone).
        return _fg_has_carbonyl_oxygen(mol, atom), False

    if atomic_num in (7, 8):
        if not single_atom_group:
            # multi-atom groups: oxygen keeps its hydrogens, nitrogen does not
            return atomic_num == 8, False
        heavy_nbrs = [
            nbr.GetAtomicNum() for nbr in atom.GetNeighbors()
            if nbr.GetAtomicNum() != 1
        ]
        one_carbon = heavy_nbrs.count(6) == 1
        if atomic_num == 7:
            # anilines vs amines
            return True, one_carbon
        # alcohols vs phenols: oxygen whose only heavy neighbour is a carbon
        return True, len(heavy_nbrs) == 1 and one_carbon

    if atomic_num == 16:
        return single_atom_group, False

    return False, False


def _fg_bonds_to_cut(mol, idx, group, skip_hydrogens, label_aromatic):
    """Collect bond indices and dummy labels for the bonds leaving `group` at atom `idx`."""
    bond_ids = []
    dummy_labels = []
    for nbr in mol.GetAtomWithIdx(idx).GetNeighbors():
        jdx = nbr.GetIdx()
        if jdx in group:
            continue
        if skip_hydrogens and nbr.GetAtomicNum() == 1:
            continue
        bond_ids.append(mol.GetBondBetweenAtoms(idx, jdx).GetIdx())
        if label_aromatic and nbr.GetIsAromatic():
            dummy_labels.append((1, 1))
        else:
            dummy_labels.append((0, 0))
    return bond_ids, dummy_labels


def identify_functional_groups(smi):
    """Return the functional groups of a molecule as a list of SMILES.

    The SMILES is parsed and hydrogens are made explicit. Atoms matched by
    any pattern in PATT_TUPLE are marked, connected marked atoms are merged
    into groups, purely aromatic groups are dropped, and each remaining group
    is cut out of the molecule together with dummy atoms at the cut points.
    Environments made only of aromatic atoms (ignoring dummies and hydrogens)
    are removed from the result.

    If anything fails during this procedure, the SMILES of the whole molecule
    (with explicit hydrogens) is returned as the only entry.

    Args:
        smi: SMILES string of the molecule.

    Returns:
        list of unique SMILES strings, in no particular order.
    """
    ## We decided to start from a SMILES and add explicit hydrogens inside the function
    mol = Chem.MolFromSmiles(smi)
    mol = rdmolops.AddHs(mol)
    try:
        ## Since heteroatoms are included in PATT_TUPLE, we remove the first part of the original function
        marked = set()
        for patt in PATT_TUPLE:
            for path in mol.GetSubstructMatches(patt):
                marked.update(path)

        # merge all connected marked atoms to a single FG
        groups = []
        while marked:
            grp = {marked.pop()}
            merge(mol, marked, grp)
            groups.append(list(grp))

        ## The pattern matching alone is not selective enough, so groups made
        ## only of aromatic atoms are dropped. A new list is built because
        ## removing from `groups` while iterating over it skips elements.
        groups = [
            g for g in groups
            if {mol.GetAtomWithIdx(idx).GetIsAromatic() for idx in g} != {True}
        ]

        ## Identify bonds to break and hydrogens to keep for every FG
        bonds = []
        labels = []
        for g in groups:
            group_bonds = []
            group_labels = []
            for idx in g:
                atom = mol.GetAtomWithIdx(idx)
                skip_h, label_aromatic = _fg_cut_policy(mol, atom, g)
                atom_bonds, atom_labels = _fg_bonds_to_cut(
                    mol, idx, g, skip_h, label_aromatic)
                group_bonds.extend(atom_bonds)
                group_labels.extend(atom_labels)
            labels.append(group_labels)
            bonds.append(group_bonds)

        ## Build final fragments
        fgs_envs = set()
        for g, g_bonds, g_labels in zip(groups, bonds, labels):
            frag = Chem.FragmentOnBonds(mol, g_bonds, dummyLabels=g_labels)
            for atom_ids in rdmolops.GetMolFrags(frag):
                # keep only the piece that contains the group itself
                if g[0] in atom_ids:
                    fgs_envs.add(
                        Chem.MolFragmentToSmiles(frag,
                                                 atom_ids,
                                                 canonical=True,
                                                 allHsExplicit=True))

        ## Drop environments whose real atoms are all aromatic (again without
        ## mutating the collection being iterated).
        kept = []
        for env in fgs_envs:
            fg = Chem.MolFromSmiles(env)
            if fg is None:
                fg = Chem.MolFromSmarts(env)
            aromatic_flags = {
                atom.GetIsAromatic()
                for atom in fg.GetAtoms()
                if atom.GetSymbol() not in ["*", "H"]
            }
            if aromatic_flags != {True}:
                kept.append(env)
        return kept

    except Exception:
        ## When the molecules is as small as a single FG
        return [Chem.MolToSmiles(mol, canonical=True, allHsExplicit=True)]
Example #12
0
def perform_lrp(model, hyper, trial=0, sample=None, epsilon=0.1, gamma=0.1):
    """Propagate relevance through a trained model and draw per-fragment heatmaps.

    The trained model of one trial is loaded from
    ``../result/<model>/<hyper>/trial_XX/``, the activations of its
    intermediate layers are collected for the test split, and relevance is
    propagated backwards layer by layer with the ``lrp_*`` helpers. For every
    test complex whose prediction error is at most 0.2, one PNG per pocket
    fragment is written to ``../analysis/<model>/<hyper>/heatmap`` showing the
    ligand, the fragment, their close contacts and the atom relevances.

    Args:
        model: Name of the model; used to build the result and figure paths.
        hyper: Name of the hyper-parameter setting; used to build the paths.
        trial: Index of the trial directory to load.
        sample: Optional index (or indices) selecting a subset of the test
            split; None uses the whole test split.
        epsilon: Value forwarded as `epsilon` to `lrp_dense` for all dense and
            embedding layers except the output layer (which uses 0).
        gamma: Value forwarded as `gamma` to `lrp_gcn_gamma`.

    Returns:
        None. Figures are written to disk.
    """
    # Memory growth is only configurable when a GPU is actually present;
    # indexing [0] unconditionally raised IndexError on CPU-only machines.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        tf.config.experimental.set_memory_growth(gpus[0], True)

    # Make folder (creates every missing level in one call)
    fig_path = "../analysis/{}/{}/heatmap".format(model, hyper)
    os.makedirs(fig_path, exist_ok=True)

    # Load results
    base_path = "../result/{}/{}/".format(model, hyper)
    path = base_path + 'trial_{:02d}/'.format(trial)

    # Load hyper
    # NOTE(review): the loaded values are not used further down; the read is
    # kept so a missing or unreadable hyper.csv still fails early.
    with open(path + 'hyper.csv', newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            hyper = dict(row)

    # Load model
    custom_objects = {
        'NodeEmbedding': NodeEmbedding,
        'GraphConvolution': GraphConvolution,
        'Normalize': Normalize,
        'GlobalPooling': GlobalPooling
    }
    model = load_model(path + 'best_model.h5', custom_objects=custom_objects)
    print([l.name for l in model.layers])

    # Load data
    data = np.load(path + 'data_split.npz')
    dataset = Dataset('refined', 5)
    if sample is not None:
        dataset.split_by_idx(32, data['train'], data['valid'],
                             data['test'][sample])
    else:
        dataset.split_by_idx(32, data['train'], data['valid'], data['test'])
    data.close()

    # Predict: collect the activation of every layer needed below
    true_y = dataset.test_y
    outputs = {}
    for layer_name in [
            'node_embedding', 'node_embedding_1', 'normalize', 'normalize_1',
            'activation', 'add', 'activation_1', 'add_1', 'global_pooling',
            'activation_2', 'activation_3', 'activation_4',
            'atom_feature_input'
    ]:
        sub_model = tf.keras.models.Model(
            inputs=model.input, outputs=model.get_layer(layer_name).output)
        outputs[layer_name] = sub_model.predict(dataset.test,
                                                steps=dataset.test_step,
                                                verbose=0)[:len(true_y)]

    # Output layer: LRP-0
    dense_2_weights = model.get_layer('dense_2').get_weights()
    relevance = lrp_dense(outputs['activation_3'],
                          outputs['activation_4'],
                          dense_2_weights[0],
                          dense_2_weights[1],
                          epsilon=0)

    # Dense layer: LRP-e
    dense_1_weights = model.get_layer('dense_1').get_weights()
    relevance = lrp_dense(outputs['activation_2'],
                          relevance,
                          dense_1_weights[0],
                          dense_1_weights[1],
                          epsilon=epsilon)

    # Dense layer: LRP-e
    dense_0_weights = model.get_layer('dense').get_weights()
    relevance = lrp_dense(outputs['global_pooling'],
                          relevance,
                          dense_0_weights[0],
                          dense_0_weights[1],
                          epsilon=epsilon)

    # Pooling layer
    relevance = lrp_pooling(outputs['activation_1'], relevance)

    # Add layer
    relevance_1, relevance_2 = lrp_add(
        [outputs['add'], outputs['activation_1']], relevance)

    # GCN layer: LRP-g
    relevance = lrp_gcn_gamma(
        outputs['add'],
        relevance_2,
        outputs['normalize_1'],
        model.get_layer('graph_convolution_1').get_weights()[0],
        gamma=gamma) + relevance_1

    # The embedding activations are collected above under the names
    # 'node_embedding' / 'node_embedding_1'. The previous 'graph_embedding*'
    # keys were never stored in `outputs`, so the lookups raised KeyError.
    # NOTE(review): the layer names are assumed to match the collected ones.

    # Add layer
    relevance_1, relevance_2 = lrp_add(
        [outputs['node_embedding_1'], outputs['activation']], relevance)

    # GCN layer: LRP-g
    relevance = lrp_gcn_gamma(
        outputs['node_embedding_1'],
        relevance_2,
        outputs['normalize'],
        model.get_layer('graph_convolution').get_weights()[0],
        gamma=gamma) + relevance_1

    # Embedding layer : LRP-e
    embedding_1_weights = model.get_layer('node_embedding_1').get_weights()
    relevance = lrp_dense(outputs['node_embedding'],
                          relevance,
                          embedding_1_weights[0],
                          embedding_1_weights[1],
                          epsilon=epsilon)

    # Embedding layer : LRP-e
    embedding_0_weights = model.get_layer('node_embedding').get_weights()
    relevance = lrp_dense(outputs['atom_feature_input'],
                          relevance,
                          embedding_0_weights[0],
                          embedding_0_weights[1],
                          epsilon=epsilon)

    # Sum over the last axis and divide by the true target value
    relevance = tf.math.reduce_sum(relevance, axis=-1).numpy()
    relevance = np.divide(relevance, np.expand_dims(true_y, -1))

    # Preset
    DrawingOptions.bondLineWidth = 1.5
    DrawingOptions.elemDict = {}
    DrawingOptions.dotsPerAngstrom = 20
    DrawingOptions.atomLabelFontSize = 4
    DrawingOptions.atomLabelMinFontSize = 4
    DrawingOptions.dblBondOffset = 0.3
    DrawingOptions.wedgeDashedBonds = False

    # Load data (the archive is closed again once the indices are read)
    dataframe = pd.read_pickle('../data/5A.pkl')
    with np.load(path + 'data_split.npz') as split:
        if sample is not None:
            test_set = split['test'][sample]
        else:
            test_set = split['test']

    # Draw images for test molecules
    colormap = cm.get_cmap('seismic')
    for idx, test_idx in enumerate(test_set):
        print('Drawing figure for {}/{}'.format(idx, len(test_set)))
        pdb_code = dataframe.iloc[test_idx]['code']
        error = np.absolute(dataframe.iloc[test_idx]['output'] -
                            outputs['activation_4'][idx])[0]
        # only reasonably well predicted complexes are drawn
        if error > 0.2:
            continue

        for mol_ligand, mol_pocket in zip(
                Chem.SDMolSupplier(
                    '../data/refined-set/{}/{}_ligand.sdf'.format(
                        pdb_code, pdb_code)),
                Chem.SDMolSupplier(
                    '../data/refined-set/{}/{}_pocket.sdf'.format(
                        pdb_code, pdb_code))):

            # Crop atoms: keep those within 5 of any ligand atom
            mol = Chem.CombineMols(mol_ligand, mol_pocket)
            distance = np.array(rdmolops.Get3DDistanceMatrix(mol))
            cropped_idx = np.argwhere(
                np.min(distance[:, :mol_ligand.GetNumAtoms()], axis=1) <= 5
            ).flatten()
            # relevance per atom of the full complex, zero for cropped-out atoms
            unpadded_relevance = np.zeros((mol.GetNumAtoms(), ))
            np.put(unpadded_relevance, cropped_idx, relevance[idx])
            scale = max(max(unpadded_relevance),
                        math.fabs(min(unpadded_relevance))) * 3

            # Separate fragments in Combined Mol
            idxs_frag = rdmolops.GetMolFrags(mol)
            mols_frag = rdmolops.GetMolFrags(mol, asMols=True)
            n_first = mols_frag[0].GetNumAtoms()

            # Draw fragment and interaction
            for i, (mol_frag,
                    idx_frag) in enumerate(zip(mols_frag[1:], idxs_frag[1:])):
                # Ignore water
                if mol_frag.GetNumAtoms() == 1:
                    continue

                # Generate 2D image
                mol_combined = Chem.CombineMols(mols_frag[0], mol_frag)
                AllChem.Compute2DCoords(mol_combined)
                fig = Draw.MolToMPL(mol_combined, coordScale=1)
                axis = fig.axes[0]
                axis.set_axis_off()

                # Draw line between close atoms (5A)
                flag = False
                for j in range(mol_ligand.GetNumAtoms()):
                    for k in idx_frag:
                        if distance[j, k] <= 5:
                            # Draw connection
                            coord_li = mol_combined._atomPs[j]
                            coord_po = mol_combined._atomPs[
                                idx_frag.index(k) + n_first]
                            x, y = np.array([[coord_li[0], coord_po[0]],
                                             [coord_li[1], coord_po[1]]])
                            line = Line2D(x,
                                          y,
                                          color='b',
                                          linewidth=1,
                                          alpha=0.3)
                            axis.add_line(line)
                            flag = True

                # Draw heatmap for atoms.
                # Atom j of the two-fragment drawing is not atom j of the full
                # complex: translate through the fragments' original indices,
                # otherwise pocket atoms are coloured with the relevance of
                # unrelated atoms.
                combined_to_full = list(idxs_frag[0]) + list(idx_frag)
                for j in range(mol_combined.GetNumAtoms()):
                    atom_relevance = unpadded_relevance[combined_to_full[j]]
                    relevance_li = atom_relevance / scale + 0.5
                    highlight = plt.Circle(
                        (mol_combined._atomPs[j][0],
                         mol_combined._atomPs[j][1]),
                        0.035 * math.fabs(atom_relevance / scale) + 0.008,
                        color=colormap(relevance_li),
                        alpha=0.8,
                        zorder=0)
                    axis.add_artist(highlight)

                # Save (only fragments with at least one close contact)
                if flag:
                    fig_name = fig_path + '/{}_lrp_{}_{}_{}.png'.format(
                        trial, test_idx, pdb_code, i)
                    fig.savefig(fig_name, bbox_inches='tight')
                plt.close(fig)
Example #13
0
def all_bond_remove(
        mol: Chem.rdchem.Mol,
        as_mol: bool = True,
        allow_bond_decrease: bool = True,
        allow_atom_trim: bool = True,
        max_num_action=float("Inf"),
):
    """Remove bonds from a molecule

    Every bond is removed in turn and the fragments of the sanitized result
    are collected. Optionally, double and triple bonds are additionally
    lowered to a smaller bond order.

    Warning:
        This can be computationally expensive.

    Note:
        Kekulization is attempted directly on `mol` with
        `clearAromaticFlags=True`, so the input molecule may be modified.

    Args:
        mol: Input molecule
        as_mol: Return molecule objects if True, SMILES strings otherwise
        allow_bond_decrease: Allow decreasing bond type in addition to bond cut
        allow_atom_trim: Also keep the fragments of a bond removal that
            produces a fragment with fewer than two atoms
        max_num_action: Iteration over the bonds stops once more than this
            many results have been collected

    Returns:
        All possible molecules from removing bonds

    """
    new_mols = []

    # Best effort: continue with the molecule as-is when it cannot be
    # kekulized, without swallowing KeyboardInterrupt / SystemExit.
    try:
        Chem.Kekulize(mol, clearAromaticFlags=True)
    except Exception:
        pass

    for bond in mol.GetBonds():
        if len(new_mols) > max_num_action:
            break

        original_bond_type = bond.GetBondType()

        # cut the bond on an editable copy so `mol` keeps all its bonds
        emol = Chem.RWMol(mol)
        emol.RemoveBond(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
        new_mol = dm.sanitize_mol(emol.GetMol())

        if not new_mol:
            continue

        frag_list = list(rdmolops.GetMolFrags(new_mol, asMols=True))
        has_single_atom = any(x.GetNumAtoms() < 2 for x in frag_list)
        if not has_single_atom or allow_atom_trim:
            new_mols.extend(frag_list)

        if not allow_bond_decrease:
            continue

        # double and triple bonds can be lowered to single ...
        if original_bond_type in [dm.DOUBLE_BOND, dm.TRIPLE_BOND]:
            new_mol = update_bond(mol, bond, dm.SINGLE_BOND)
            if new_mol is not None:
                new_mols.extend(rdmolops.GetMolFrags(new_mol, asMols=True))
        # ... and triple bonds also to double
        if original_bond_type == dm.TRIPLE_BOND:
            new_mol = update_bond(mol, bond, dm.DOUBLE_BOND)
            if new_mol is not None:
                new_mols.extend(rdmolops.GetMolFrags(new_mol, asMols=True))

    new_mols = [candidate for candidate in new_mols if candidate is not None]

    if not as_mol:
        return [dm.to_smiles(x) for x in new_mols if x]

    return new_mols
Example #14
0
def get_largest_fragment(mol):
    """Return the fragment of `mol` with the most atoms.

    Fragments are extracted without sanitization. If the molecule yields no
    fragments, `mol` itself is returned.
    """
    pieces = rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False)

    def atom_count(piece):
        return piece.GetNumAtoms()

    return max(pieces, key=atom_count, default=mol)