def remove_salts_solvents(mol, hac=3):
    """Remove solvents and ions that have at most ``hac`` heavy atoms.

    This function was obtained from the mol2vec package, available at:
    https://github.com/samoturk/mol2vec/blob/master/mol2vec/features.py

    ``mol`` is split into its disconnected fragments and only fragments with
    more than ``hac`` heavy atoms are kept. Exactly one such fragment must
    remain for the molecule to be accepted.

    Args:
        mol: Input molecule.
        hac: Heavy-atom-count threshold; fragments with ``<= hac`` heavy atoms
            are discarded as solvents/ions.

    Returns:
        The single remaining fragment, or None (after emitting a warning)
        when zero or more than one fragment exceeds ``hac`` heavy atoms.
    """
    # Count heavy atoms explicitly so explicit hydrogen atoms (if the caller
    # added them) do not push small solvents/ions over the threshold.
    fragments = [
        fragment
        for fragment in rdmolops.GetMolFrags(mol, asMols=True)
        if fragment.GetNumHeavyAtoms() > hac
    ]
    if len(fragments) > 1:
        warnings.warn("molecule contains >1 fragment with >" + str(hac) +
                      " heavy atoms")
        return None
    if not fragments:
        warnings.warn("molecule contains no fragments with >" + str(hac) +
                      " heavy atoms")
        return None
    return fragments[0]
def filter_extract_mol(row, headers_dict):
    """Return the largest fragment of a row's molecule if the row passes all filters.

    ``row`` is indexed with the positions stored in ``headers_dict`` (presumably
    a ChEMBL-style activity export row — confirm against the caller). The row
    is rejected (None returned) unless:

      * ``NUM_RO5_VIOLATIONS`` parses as an int and is not positive,
      * ``PCHEMBL_VALUE`` parses as a float,
      * ``RELATION`` is exactly ``'='``,
      * the SMILES parses, and
      * the molecular weight of the largest fragment is within [300, 400].

    Args:
        row: Sequence of cell values.
        headers_dict: Mapping from column name to index into ``row``.

    Returns:
        The largest fragment (by atom count) as an RDKit molecule, or None.
    """
    relation_idx = headers_dict['RELATION']
    ro5_idx = headers_dict['NUM_RO5_VIOLATIONS']
    pchembl_idx = headers_dict['PCHEMBL_VALUE']

    # A malformed/missing cell (or a row that is too short) rejects the row.
    # Only these errors are caught; a bare ``except:`` would also swallow
    # KeyboardInterrupt and SystemExit.
    parse_errors = (ValueError, TypeError, LookupError, OverflowError)
    try:
        ro5_violations = int(row[ro5_idx])
    except parse_errors:
        return None
    try:
        float(row[pchembl_idx])  # validation only; the value is not used
    except parse_errors:
        return None

    if row[relation_idx] != '=' or ro5_violations > 0:
        return None

    # NOTE(review): ``smiles_idx`` is neither defined here nor read from
    # headers_dict; it must be a module-level global — confirm, or derive it
    # from headers_dict like the other indices.
    mol = Chem.MolFromSmiles(row[smiles_idx])
    if not mol:
        return None
    fragments = rdmolops.GetMolFrags(mol, asMols=True)
    if not fragments:
        return None
    # Largest fragment; on ties the first one wins (same as a stable sort).
    mol = max(fragments, key=lambda m: m.GetNumAtoms())

    mol_wt = Descriptors.MolWt(mol)
    if mol_wt < 300 or mol_wt > 400:
        return None
    return mol
def keep_largest_fragment(mol: Chem.rdchem.Mol) -> Optional[Chem.rdchem.Mol]:
    """Only keep the largest fragment of the molecule.

    The fragment with the most atoms is returned; on ties, the first such
    fragment reported by ``GetMolFrags`` wins. When ``GetMolFrags`` yields no
    fragments, ``mol`` itself is returned.
    """
    pieces = rdmolops.GetMolFrags(mol, asMols=True)
    if not pieces:
        return mol
    largest = pieces[0]
    for candidate in pieces[1:]:
        if candidate.GetNumAtoms() > largest.GetNumAtoms():
            largest = candidate
    return largest
def _split_products(self, smi):
    """Split a (possibly multi-component) SMILES into one SMILES per fragment.

    Args:
        smi: SMILES string to split.

    Returns:
        List of SMILES strings, one for each disconnected fragment, in the
        order returned by ``GetMolFrags``.
    """
    parsed = AllChem.MolFromSmiles(smi)
    return [
        AllChem.MolToSmiles(fragment)
        for fragment in rdmolops.GetMolFrags(parsed, asMols=True)
    ]
def _getlargestFragment(mol):
    """Return the sanitized fragment of ``mol`` with the most heavy atoms.

    None entries are skipped; on ties the first fragment wins. Returns None
    when there is no fragment at all.
    """
    candidates = [
        piece
        for piece in rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
        if piece is not None
    ]
    if not candidates:
        return None
    return max(candidates, key=lambda piece: piece.GetNumHeavyAtoms())
def n_disconnected(mol):
    """Count the disconnected fragments of a molecule.

    Args:
        mol (skchem.Mol):
            The molecule for which to calculate the descriptor.

    Returns:
        int: number of atom-index groups reported by ``GetMolFrags``.
    """
    fragment_atom_ids = rdmolops.GetMolFrags(mol)
    return len(fragment_atom_ids)
def all_fragment_on_bond(mol, asMols=False, max_num_action=float("Inf"), break_aromatic=True):
    """Fragment all possible bonds in a molecule and return the set of resulting fragments.

    This is similar to `random_bond_cut`, but is not stochastic: it does not
    return a random fragment but all the fragments resulting from every
    potential bond break in the molecule.

    .. note::
        This will always be a subset of all_bond_remove, the main difference
        being that all_bond_remove allows decreasing bond order, while this
        one always breaks a molecule into two.

    The input molecule is not modified: when ``break_aromatic`` is True a
    kekulized copy is used.

    Args:
        mol: <Chem.Mol>
            input molecule
        asMols: bool, optional
            Whether to return results as mols or smiles
        max_num_action: float, optional
            Stop as soon as more than this many fragments have been collected
        break_aromatic: bool, optional
            Whether to attempt to break even aromatic bonds (Default: True)

    Returns:
        set of fragments (molecules or SMILES strings, per ``asMols``)
    """
    # Work on a copy: Kekulize(clearAromaticFlags=True) below would otherwise
    # strip the aromatic flags from the caller's molecule.
    mol = Chem.Mol(mol)
    mol.GetRingInfo().AtomRings()
    fragment_set = set()
    bonds = list(mol.GetBonds())
    if not bonds:
        return fragment_set

    if break_aromatic:
        Chem.Kekulize(mol, clearAromaticFlags=True)

    for bond in bonds:
        if not break_aromatic and bond.GetIsAromatic():
            continue
        truncate = Chem.FragmentOnBonds(mol, [bond.GetIdx()], addDummies=False)
        truncate = dm.sanitize_mol(truncate)
        if truncate is None:
            continue
        for frag in rdmolops.GetMolFrags(truncate, asMols=True):
            frag = dm.sanitize_mol(frag)
            if not frag:
                continue
            if not asMols:
                frag = dm.to_smiles(frag)
            fragment_set.add(frag)
            if len(fragment_set) > max_num_action:
                return fragment_set
    return fragment_set
def apply_retrorules(self, smile, rxns, explicit_hydrogens=False):
    """Apply every reaction in ``rxns`` to ``smile`` and collect the products.

    Args:
        smile: Substrate SMILES string.
        rxns: Dictionary mapping reaction name -> reaction object exposing
            ``RunReactants`` (presumably an RDKit ChemicalReaction).
        explicit_hydrogens: If True, explicit hydrogens are added to the
            substrate before the reactions run and removed from every product
            molecule afterwards.

    Returns:
        Dictionary mapping reaction name -> list of product sets. Each product
        set is a list of SMILES accepted by ``self._check_valid_smile``;
        duplicate and empty product sets are dropped, and reactions without
        any product set are omitted. An unparsable substrate yields ``{}``.
    """
    try:
        substrate_molecule = AllChem.MolFromSmiles(smile)
    except Exception:
        return {}
    # MolFromSmiles reports a parse failure by returning None, not by raising;
    # without this check AddHs/RunReactants would fail on None.
    if substrate_molecule is None:
        return {}
    if explicit_hydrogens:
        substrate_molecule = rdmolops.AddHs(substrate_molecule)

    rxn_product_dict = {}
    for rxn_name, rxn in rxns.items():
        try:
            products = rxn.RunReactants((substrate_molecule, ))
        except Exception:
            products = []
            print('Error running reactants for: ' + str(smile))

        smiles_products = []
        for product in products:
            sub_list = []
            for product_mol in product:
                if explicit_hydrogens:
                    product_mol = rdmolops.RemoveHs(product_mol)
                # Split each product into its components; if that fails, keep
                # the (H-stripped) product molecule as a single component.
                try:
                    components = rdmolops.GetMolFrags(product_mol, asMols=True)
                except Exception:
                    components = [product_mol]
                for component in components:
                    # Best effort: components that cannot be written or
                    # canonicalised are skipped.
                    try:
                        p_smile = rdkit_smile(AllChem.MolToSmiles(component))
                        if self._check_valid_smile(
                                p_smile, rxn_name=rxn_name) == True:
                            sub_list.append(p_smile)
                    except Exception:
                        pass
            if sub_list and sub_list not in smiles_products:
                smiles_products.append(sub_list)

        if smiles_products:
            rxn_product_dict[rxn_name] = smiles_products
    return rxn_product_dict
def _neutralise_sulphoxide(mol):
    """Rewrite charge-separated sulphoxides ``[S+]-[O-]`` as neutral ``S=O``.

    The molecule is split into fragments and ``_apply_rxn`` is applied to each
    one with the SMIRKS below. Fragments for which ``_apply_rxn`` returns None
    are dropped. A single surviving fragment is returned as is; two or more
    are recombined with ``CombineMols`` and then sanitized. With no surviving
    fragment an IndexError is raised.
    """
    transform = rdChemReactions.ReactionFromSmarts('[S+1:1][O-1:2]>>[S+0:1]=[O-0:2]')
    processed = [
        _apply_rxn(piece, transform)
        for piece in rdmolops.GetMolFrags(mol, asMols=True)
    ]
    survivors = [piece for piece in processed if piece is not None]
    if len(survivors) == 1:
        return survivors[0]
    merged = CombineMols(survivors[0], survivors[1])
    for extra in survivors[2:]:
        merged = CombineMols(merged, extra)
    SanitizeMol(merged)
    return merged
def fuzzy_scaffolding(
    mols: List[Chem.rdchem.Mol],
    enforce_subs: List[str] = None,
    n_atom_cuttoff: int = 8,
    additional_templates: List[Chem.rdchem.Mol] = None,
    ignore_non_ring: bool = False,
    mcs_params: Dict[Any, Any] = None,
):
    """Generate fuzzy scaffolds with enforceable groups that need to appear in
    the core, forcing the full side chain to be kept if required.

    NOTE(hadim): consider parallelizing this (if possible).

    Args:
        mols: List of all molecules. Conformers of the molecules that go
            through R-group decomposition are removed in place.
        enforce_subs: List of substructures to enforce on the scaffold.
        n_atom_cuttoff: Minimum number of atoms a core should have.
        additional_templates: Additional templates to use to generate scaffolds.
        ignore_non_ring: Whether to ignore atoms not in a Murcko ring system,
            even if they are in the framework. Molecules are then grouped by
            the SMILES of each remaining ring-system fragment.
        mcs_params: Arguments of the MCS algorithm.

    Returns:
        scaffolds: set
            All found scaffolds in the molecules as valid smiles
        scaffold_infos: dict of dict
            Infos on the scaffold mapping, ignoring any side chain that had to
            be enforced. Keys correspond to generic scaffold smiles.
            Values at ['smarts'] correspond to the smarts representation of
            the true scaffold (from MCS).
            Values at ['mols'] correspond to the list of molecules matching
            the scaffold.
        scaffold_to_group: dict of list
            Map between each generic scaffold and the R-groups decomposition rows
    """
    if enforce_subs is None:
        enforce_subs = []
    if additional_templates is None:
        additional_templates = []
    if mcs_params is None:
        mcs_params = {}

    rg_params = rdRGroupDecomposition.RGroupDecompositionParameters()
    rg_params.removeAllHydrogenRGroups = True
    rg_params.removeHydrogensPostMatch = True
    rg_params.alignment = rdRGroupDecomposition.RGroupCoreAlignment.MCS
    rg_params.matchingStrategy = rdRGroupDecomposition.RGroupMatching.Exhaustive
    rg_params.rgroupLabelling = rdRGroupDecomposition.RGroupLabelling.AtomMap
    rg_params.labels = rdRGroupDecomposition.RGroupLabels.AtomIndexLabels

    core_query_param = AdjustQueryParameters()
    core_query_param.makeDummiesQueries = True
    core_query_param.adjustDegree = False
    core_query_param.makeBondsGeneric = True

    # The generic scaffolds of the extra templates do not depend on the input
    # molecule, so compute them once rather than once per molecule.
    template_scaffolds = [
        MurckoScaffold.MakeScaffoldGeneric(dm.to_mol(tmp))
        for tmp in additional_templates
    ]

    # Group molecules by their generic Murcko scaffold, allowing side chains
    # that contain cycles (might be a bad idea).
    scf2infos = collections.defaultdict(dict)
    scf2groups = {}
    all_scaffolds = set()

    for m in mols:
        generic_m = MurckoScaffold.MakeScaffoldGeneric(m)
        scf = MurckoScaffold.GetScaffoldForMol(m)
        try:
            scf = MurckoScaffold.MakeScaffoldGeneric(scf)
        except Exception:
            pass  # best effort: fall back to the non-generic Murcko scaffold

        if ignore_non_ring:
            rw_scf = Chem.RWMol(scf)
            # Delete from the highest index down so the remaining indices stay
            # valid while atoms are removed.
            atms = sorted(
                (a.GetIdx() for a in rw_scf.GetAtoms() if not a.IsInRing()),
                reverse=True,
            )
            for a in atms:
                rw_scf.RemoveAtom(a)
            # Key on the SMILES of each remaining ring system. Atom-index tuples
            # (asMols=False) are not comparable between different molecules.
            scfs = [
                dm.to_smiles(frag)
                for frag in rdmolops.GetMolFrags(
                    rw_scf, asMols=True, sanitizeFrags=False)
            ]
        else:
            scfs = [dm.to_smiles(scf)]

        # Add the template scaffolds contained in this molecule.
        for tmp_scf in template_scaffolds:
            if generic_m.HasSubstructMatch(tmp_scf):
                scfs.append(dm.to_smiles(tmp_scf))

        for key in scfs:
            scf2infos[key].setdefault("mols", []).append(m)

    for scf in scf2infos:
        # Cheat: a singleton cluster gets the Murcko scaffold of its molecule
        # appended so the MCS is computed over two molecules; it is removed
        # again before the R-group decomposition.
        popout = False
        cluster_mols = scf2infos[scf]["mols"]
        if len(cluster_mols) < 2:
            cluster_mols = cluster_mols + [MurckoScaffold.GetScaffoldForMol(cluster_mols[0])]
            popout = True

        # Compute the MCS of the cluster.
        mcs = rdFMCS.FindMCS(
            cluster_mols,
            atomCompare=rdFMCS.AtomCompare.CompareAny,
            bondCompare=rdFMCS.BondCompare.CompareAny,
            completeRingsOnly=True,
            **mcs_params,
        )
        mcsM = Chem.MolFromSmarts(mcs.smartsString)
        mcsM.UpdatePropertyCache(False)
        Chem.SetHybridization(mcsM)

        if mcsM.GetNumAtoms() < n_atom_cuttoff:
            continue

        scf2infos[scf]["smarts"] = dm.to_smarts(mcsM)
        if popout:
            cluster_mols = cluster_mols[:-1]

        # Generate R-groups based on the MCS core. Decomposition failures are
        # tolerated: the cluster then simply yields no groups.
        core_groups = []
        success_mols = []
        try:
            rg = rdRGroupDecomposition.RGroupDecomposition(mcsM, rg_params)
            for i, analog in enumerate(cluster_mols):
                analog.RemoveAllConformers()
                res = rg.Add(analog)
                if not (res < 0):
                    success_mols.append(i)
            rg.Process()
            core_groups = rg.GetRGroupsAsRows()
        except Exception:
            pass
        # Rows of GetRGroupsAsRows line up with the successfully added molecules.
        cluster_mols = [cluster_mols[i] for i in success_mols]
        scf2groups[scf] = core_groups

        for mol, gp in zip(cluster_mols, core_groups):
            core = gp["Core"]
            acceptable_groups = [
                a.GetAtomMapNum()
                for a in core.GetAtoms()
                if a.GetAtomMapNum() and not a.IsInRing()
            ]
            rgroups = [gp[f"R{k}"] for k in acceptable_groups if f"R{k}" in gp]
            if enforce_subs:
                # Drop R-groups containing an enforced substructure so that
                # they are kept as part of the scaffold.
                rgroups = [
                    rgp for rgp in rgroups
                    if not any(
                        len(rgp.GetSubstructMatch(frag)) > 0
                        for frag in enforce_subs)
                ]
            try:
                scaff = trim_side_chain(
                    mol, AdjustQueryProperties(core, core_query_param), rgroups)
            except Exception:
                continue
            all_scaffolds.add(dm.to_smiles(scaff))

    return all_scaffolds, scf2infos, scf2groups
def _fg_bonds_and_labels(mol, group):
    """Return the bonds to cut around one functional group and their dummy labels.

    Args:
        mol: Molecule with explicit hydrogens.
        group: List of atom indices forming one functional group.

    Returns:
        (bond_indices, dummy_labels) as parallel lists suitable for
        ``Chem.FragmentOnBonds(mol, bond_indices, dummyLabels=dummy_labels)``.
    """
    bond_indices = []
    dummy_labels = []
    single_atom = len(group) == 1
    for idx in group:
        atom = mol.GetAtomWithIdx(idx)
        atomic_num = atom.GetAtomicNum()
        heavy_nbrs = [
            nbr.GetAtomicNum() for nbr in atom.GetNeighbors()
            if nbr.GetAtomicNum() != 1
        ]
        # keep_h: bonds to hydrogens are not cut, so those H atoms stay in the
        #   FG environment.
        # mark_aromatic: cuts towards aromatic neighbours are labelled (1, 1)
        #   instead of (0, 0).
        if atomic_num == 6:
            # Carbonyl carbons keep their H atoms to discriminate aldehydes
            # from ketones.
            keep_h = any(
                nbr.GetAtomicNum() == 8
                and mol.GetBondBetweenAtoms(idx, nbr.GetIdx()).GetBondType()
                == Chem.BondType.DOUBLE
                for nbr in atom.GetNeighbors())
            mark_aromatic = False
        elif atomic_num == 7:
            # Lone N: keep H atoms (primary/secondary/... amines) and, when it
            # has exactly one carbon neighbour, flag an aromatic neighbour to
            # discriminate anilines from amines.
            keep_h = single_atom
            mark_aromatic = single_atom and heavy_nbrs.count(6) == 1
        elif atomic_num == 8:
            # O always keeps its H atoms; a lone O whose only heavy neighbour
            # is a carbon flags an aromatic neighbour (phenols vs alcohols).
            keep_h = True
            mark_aromatic = (single_atom and len(heavy_nbrs) == 1
                             and heavy_nbrs.count(6) == 1)
        elif atomic_num == 16:
            keep_h = single_atom
            mark_aromatic = False
        else:
            keep_h = False
            mark_aromatic = False

        for nbr in atom.GetNeighbors():
            jdx = nbr.GetIdx()
            if jdx in group or (keep_h and nbr.GetAtomicNum() == 1):
                continue
            bond_indices.append(mol.GetBondBetweenAtoms(idx, jdx).GetIdx())
            dummy_labels.append(
                (1, 1) if mark_aromatic and nbr.GetIsAromatic() else (0, 0))
    return bond_indices, dummy_labels


def identify_functional_groups(smi):
    """Return the SMILES of the functional-group (FG) environments of a molecule.

    Atoms matched by any pattern in ``PATT_TUPLE`` are marked, and connected
    marked atoms are merged (via ``merge``) into groups. Groups made only of
    aromatic atoms are discarded. For every remaining group, the bonds leaving
    the group are cut with ``Chem.FragmentOnBonds``. The fragment containing
    the group is then written as SMILES with all hydrogens explicit.
    Environments whose non-dummy, non-H atoms are all aromatic are dropped.

    Args:
        smi: SMILES string. Explicit hydrogens are added inside the function.
            An unparsable SMILES is not caught: the ``AddHs`` call happens
            before the fallback ``try``.

    Returns:
        List of unique FG-environment SMILES in first-seen order. If any step
        of the extraction raises, ``[SMILES of the whole molecule]`` is
        returned instead.
    """
    mol = Chem.MolFromSmiles(smi)
    mol = rdmolops.AddHs(mol)
    try:
        # Heteroatoms are already covered by PATT_TUPLE, so only the pattern
        # matches are marked.
        marked = set()
        for patt in PATT_TUPLE:
            for path in mol.GetSubstructMatches(patt):
                marked.update(path)

        # Merge all connected marked atoms into single FGs.
        groups = []
        while marked:
            grp = {marked.pop()}
            merge(mol, marked, grp)
            groups.append(list(grp))

        # The pattern filtering alone does not exclude purely aromatic groups.
        # Build a filtered list: removing from ``groups`` while iterating over
        # it would skip the element following each removed group.
        groups = [
            g for g in groups
            if not all(mol.GetAtomWithIdx(i).GetIsAromatic() for i in g)
        ]

        fg_envs = []
        for g in groups:
            bonds, labels = _fg_bonds_and_labels(mol, g)
            frag = Chem.FragmentOnBonds(mol, bonds, dummyLabels=labels)
            for atom_ids in rdmolops.GetMolFrags(frag):
                if g[0] in atom_ids:
                    fg_envs.append(
                        Chem.MolFragmentToSmiles(frag, atom_ids, canonical=True,
                                                 allHsExplicit=True))

        # De-duplicate in first-seen order; list(set(...)) gave an order that
        # depends on the string hash seed.
        fg_envs = list(dict.fromkeys(fg_envs))

        # Drop environments whose real (non-dummy, non-H) atoms are all aromatic.
        kept = []
        for env in fg_envs:
            fg = Chem.MolFromSmiles(env)
            if fg is None:
                fg = Chem.MolFromSmarts(env)
            aromatic_flags = {
                atom.GetIsAromatic() for atom in fg.GetAtoms()
                if atom.GetSymbol() not in ["*", "H"]
            }
            if aromatic_flags != {True}:
                kept.append(env)
        return kept
    except Exception:
        # Fallback (per the original author: when the molecule is as small as
        # a single FG): return the whole molecule.
        return [Chem.MolToSmiles(mol, canonical=True, allHsExplicit=True)]
def perform_lrp(model, hyper, trial=0, sample=None, epsilon=0.1, gamma=0.1):
    """Run layer-wise relevance propagation (LRP) on a trained model and draw
    per-atom relevance heatmaps for test complexes.

    Args:
        model: Model name, used to build the ``../analysis`` and ``../result``
            paths. Rebound below to the loaded Keras model.
        hyper: Hyper-parameter set name, used to build the same paths.
        trial: Trial number, formatted as ``trial_XX``.
        sample: Optional index/indices into the test split; when given, only
            those test entries are explained and drawn.
        epsilon: Passed to ``lrp_dense`` for the dense and embedding layers.
        gamma: Passed to ``lrp_gcn_gamma`` for the graph-convolution layers.

    Side effects:
        Creates ``../analysis/<model>/<hyper>/heatmap`` if needed and writes
        one PNG per (test entry, pocket fragment) pair that has a ligand atom
        within 5 Å of the fragment and a prediction error <= 0.2.
    """
    # Enable memory growth on every visible GPU. TensorFlow requires the
    # setting to match across GPUs; the list is empty on CPU-only hosts. A
    # RuntimeError means the GPUs are already initialised (e.g. a second call
    # in the same process), in which case the existing setting is kept.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError:
            pass

    # Output folder (parent folders are created as needed).
    fig_path = "../analysis/{}/{}/heatmap".format(model, hyper)
    os.makedirs(fig_path, exist_ok=True)

    # Results of the requested trial.
    base_path = "../result/{}/{}/".format(model, hyper)
    path = base_path + 'trial_{:02d}/'.format(trial)

    # Hyper-parameters stored with the trial (the last CSV row wins); they are
    # not used further below.
    with open(path + 'hyper.csv', newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            hyper = dict(row)

    custom_objects = {
        'NodeEmbedding': NodeEmbedding,
        'GraphConvolution': GraphConvolution,
        'Normalize': Normalize,
        'GlobalPooling': GlobalPooling
    }
    model = load_model(path + 'best_model.h5', custom_objects=custom_objects)
    print([l.name for l in model.layers])

    # Load the split once; the selected test indices are reused for drawing.
    with np.load(path + 'data_split.npz') as data:
        test_set = data['test'] if sample is None else data['test'][sample]
        dataset = Dataset('refined', 5)
        dataset.split_by_idx(32, data['train'], data['valid'], test_set)

    # Activations of every layer needed by the LRP backward pass.
    true_y = dataset.test_y
    outputs = {}
    for layer_name in [
            'node_embedding', 'node_embedding_1', 'normalize', 'normalize_1',
            'activation', 'add', 'activation_1', 'add_1', 'global_pooling',
            'activation_2', 'activation_3', 'activation_4', 'atom_feature_input'
    ]:
        sub_model = tf.keras.models.Model(
            inputs=model.input, outputs=model.get_layer(layer_name).output)
        outputs[layer_name] = sub_model.predict(
            dataset.test, steps=dataset.test_step, verbose=0)[:len(true_y)]

    # Output layer: LRP-0.
    weights = model.get_layer('dense_2').get_weights()
    relevance = lrp_dense(outputs['activation_3'], outputs['activation_4'],
                          weights[0], weights[1], epsilon=0)

    # Dense layers: LRP-e.
    weights = model.get_layer('dense_1').get_weights()
    relevance = lrp_dense(outputs['activation_2'], relevance,
                          weights[0], weights[1], epsilon=epsilon)
    weights = model.get_layer('dense').get_weights()
    relevance = lrp_dense(outputs['global_pooling'], relevance,
                          weights[0], weights[1], epsilon=epsilon)

    # Pooling layer.
    relevance = lrp_pooling(outputs['activation_1'], relevance)

    # Add layer 1 + GCN layer 1 (LRP-g).
    relevance_1, relevance_2 = lrp_add(
        [outputs['add'], outputs['activation_1']], relevance)
    relevance = lrp_gcn_gamma(
        outputs['add'], relevance_2, outputs['normalize_1'],
        model.get_layer('graph_convolution_1').get_weights()[0],
        gamma=gamma) + relevance_1

    # Add layer 0 + GCN layer 0 (LRP-g). The embedding layers are the
    # 'node_embedding*' layers whose activations were collected above.
    relevance_1, relevance_2 = lrp_add(
        [outputs['node_embedding_1'], outputs['activation']], relevance)
    relevance = lrp_gcn_gamma(
        outputs['node_embedding_1'], relevance_2, outputs['normalize'],
        model.get_layer('graph_convolution').get_weights()[0],
        gamma=gamma) + relevance_1

    # Embedding layers: LRP-e.
    weights = model.get_layer('node_embedding_1').get_weights()
    relevance = lrp_dense(outputs['node_embedding'], relevance,
                          weights[0], weights[1], epsilon=epsilon)
    weights = model.get_layer('node_embedding').get_weights()
    relevance = lrp_dense(outputs['atom_feature_input'], relevance,
                          weights[0], weights[1], epsilon=epsilon)

    relevance = tf.math.reduce_sum(relevance, axis=-1).numpy()
    relevance = np.divide(relevance, np.expand_dims(true_y, -1))

    # Drawing presets.
    DrawingOptions.bondLineWidth = 1.5
    DrawingOptions.elemDict = {}
    DrawingOptions.dotsPerAngstrom = 20
    DrawingOptions.atomLabelFontSize = 4
    DrawingOptions.atomLabelMinFontSize = 4
    DrawingOptions.dblBondOffset = 0.3
    DrawingOptions.wedgeDashedBonds = False

    dataframe = pd.read_pickle('../data/5A.pkl')

    # pyplot.get_cmap exists in all Matplotlib versions (cm.get_cmap was
    # removed in 3.9).
    colormap = plt.get_cmap('seismic')
    for idx, test_idx in enumerate(test_set):
        print('Drawing figure for {}/{}'.format(idx, len(test_set)))
        pdb_code = dataframe.iloc[test_idx]['code']
        error = np.absolute(dataframe.iloc[test_idx]['output'] -
                            outputs['activation_4'][idx])[0]
        if error > 0.2:
            continue

        ligand_path = '../data/refined-set/{}/{}_ligand.sdf'.format(pdb_code, pdb_code)
        pocket_path = '../data/refined-set/{}/{}_pocket.sdf'.format(pdb_code, pdb_code)
        for mol_ligand, mol_pocket in zip(Chem.SDMolSupplier(ligand_path),
                                          Chem.SDMolSupplier(pocket_path)):
            # Crop atoms: relevance is placed on the atoms within 5 Å of any
            # ligand atom.
            mol = Chem.CombineMols(mol_ligand, mol_pocket)
            distance = np.array(rdmolops.Get3DDistanceMatrix(mol))
            cropped_idx = np.argwhere(
                np.min(distance[:, :mol_ligand.GetNumAtoms()], axis=1) <= 5
            ).flatten()
            unpadded_relevance = np.zeros((mol.GetNumAtoms(), ))
            np.put(unpadded_relevance, cropped_idx, relevance[idx])
            scale = max(max(unpadded_relevance),
                        math.fabs(min(unpadded_relevance))) * 3
            if scale == 0:
                # All-zero relevance: avoid dividing by zero when colouring.
                scale = 1.0

            # Separate fragments in the combined mol; fragment 0 is drawn as
            # the ligand together with each remaining fragment.
            idxs_frag = rdmolops.GetMolFrags(mol)
            mols_frag = rdmolops.GetMolFrags(mol, asMols=True)
            n_first = mols_frag[0].GetNumAtoms()

            for i, (mol_frag, idx_frag) in enumerate(
                    zip(mols_frag[1:], idxs_frag[1:])):
                # Ignore water (single-atom fragments).
                if mol_frag.GetNumAtoms() == 1:
                    continue

                # Generate 2D image.
                mol_combined = Chem.CombineMols(mols_frag[0], mol_frag)
                AllChem.Compute2DCoords(mol_combined)
                fig = Draw.MolToMPL(mol_combined, coordScale=1)
                fig.axes[0].set_axis_off()

                # Atom j of mol_combined is atom source_idx[j] of `mol`
                # (fragment 0 first, then this fragment's atoms in order).
                source_idx = list(idxs_frag[0]) + list(idx_frag)

                # Draw lines between close atoms (5 Å).
                flag = False
                for j in range(mol_ligand.GetNumAtoms()):
                    for pos, k in enumerate(idx_frag):
                        if distance[j, k] <= 5:
                            coord_li = mol_combined._atomPs[j]
                            coord_po = mol_combined._atomPs[pos + n_first]
                            x, y = np.array([[coord_li[0], coord_po[0]],
                                             [coord_li[1], coord_po[1]]])
                            line = Line2D(x, y, color='b', linewidth=1, alpha=0.3)
                            fig.axes[0].add_line(line)
                            flag = True

                # Draw the heatmap, reading each atom's relevance through the
                # index mapping above.
                for j in range(mol_combined.GetNumAtoms()):
                    atom_relevance = unpadded_relevance[source_idx[j]]
                    highlight = plt.Circle(
                        (mol_combined._atomPs[j][0], mol_combined._atomPs[j][1]),
                        0.035 * math.fabs(atom_relevance / scale) + 0.008,
                        color=colormap(atom_relevance / scale + 0.5),
                        alpha=0.8,
                        zorder=0)
                    fig.axes[0].add_artist(highlight)

                # Save only figures with at least one contact, but always close
                # the figure so unsaved figures do not accumulate.
                if flag:
                    fig_name = fig_path + '/{}_lrp_{}_{}_{}.png'.format(
                        trial, test_idx, pdb_code, i)
                    fig.savefig(fig_name, bbox_inches='tight')
                plt.close(fig)
def all_bond_remove(
    mol: Chem.rdchem.Mol,
    as_mol: bool = True,
    allow_bond_decrease: bool = True,
    allow_atom_trim: bool = True,
    max_num_action=float("Inf"),
):
    """Remove bonds from a molecule

    Warning:
        This can be computationally expensive.

    Args:
        mol: Input molecule. It is not modified; a (kekulized) copy is used.
        as_mol: Return molecules when True, SMILES strings otherwise.
        allow_bond_decrease: Allow decreasing bond type in addition to bond cut
        allow_atom_trim: Keep the fragments of a bond cut even when one of
            them is a single atom.
        max_num_action: Stop processing further bonds once more than this many
            molecules have been collected.

    Returns:
        All possible molecules (or SMILES) from removing bonds
    """
    # Copy first: Kekulize(clearAromaticFlags=True) would otherwise modify the
    # caller's molecule in place.
    mol = Chem.Mol(mol)
    new_mols = []
    try:
        Chem.Kekulize(mol, clearAromaticFlags=True)
    except Exception:
        pass  # best effort: continue with the non-kekulized copy

    for bond in mol.GetBonds():
        if len(new_mols) > max_num_action:
            break

        original_bond_type = bond.GetBondType()
        emol = Chem.RWMol(mol)
        emol.RemoveBond(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
        new_mol = dm.sanitize_mol(emol.GetMol())
        if not new_mol:
            continue

        frag_list = list(rdmolops.GetMolFrags(new_mol, asMols=True))
        has_single_atom = any(x.GetNumAtoms() < 2 for x in frag_list)
        if not has_single_atom or allow_atom_trim:
            new_mols.extend(frag_list)

        if allow_bond_decrease:
            if original_bond_type in [dm.DOUBLE_BOND, dm.TRIPLE_BOND]:
                new_mol = update_bond(mol, bond, dm.SINGLE_BOND)
                if new_mol is not None:
                    new_mols.extend(rdmolops.GetMolFrags(new_mol, asMols=True))
            if original_bond_type == dm.TRIPLE_BOND:
                new_mol = update_bond(mol, bond, dm.DOUBLE_BOND)
                if new_mol is not None:
                    new_mols.extend(rdmolops.GetMolFrags(new_mol, asMols=True))

    new_mols = [m for m in new_mols if m is not None]
    if not as_mol:
        return [dm.to_smiles(x) for x in new_mols if x]
    return new_mols
def get_largest_fragment(mol):
    """Return the (unsanitized) fragment of ``mol`` with the most atoms.

    Fragments are extracted with ``sanitizeFrags=False``. On ties the first
    fragment wins; when there are no fragments, ``mol`` itself is returned.
    """
    best = mol
    best_size = None
    for piece in rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False):
        size = piece.GetNumAtoms()
        if best_size is None or size > best_size:
            best, best_size = piece, size
    return best