def process_files(input_dir):
    """
    Process all protein (pdb) and ligand (sdf) files in input directory.

    :param input_dir: directory containing PDBBind data
    :type input_dir: str
    :return structure_dict: dictionary containing each structure, keyed by PDB
        code. Each entry is a dict containing the protein as a Biopython
        object and the ligand as an RDKit Mol object.
    :rtype structure_dict: dict
    """
    structure_dict = {}
    pdb_files = fi.find_files(input_dir, 'pdb')

    for f in tqdm(pdb_files, desc='pdb files'):
        pdb_id = fi.get_pdb_code(f)
        # Create an entry for every pdb file seen (e.g. pockets too), even if
        # only '_protein' files contribute a 'protein' value.
        if pdb_id not in structure_dict:
            structure_dict[pdb_id] = {}
        if '_protein' in f:
            prot = ft.read_any(f)
            structure_dict[pdb_id]['protein'] = prot

    lig_files = fi.find_files(input_dir, 'sdf')
    for f in tqdm(lig_files, desc='ligand files'):
        pdb_id = fi.get_pdb_code(f)
        # setdefault guards against an sdf file with no matching pdb file,
        # which previously raised KeyError.
        structure_dict.setdefault(pdb_id, {})['ligand'] = get_ligand(f)

    return structure_dict
def _create_cluster_split(all_chain_sequences, clusterings, to_use, split_size,
                          min_fam_in_split):
    """Helper function for :func:`create_cluster_split`.

    Creates a single split of ``split_size`` elements while retaining
    diversity specified by ``min_fam_in_split``. Takes in
    ``all_chain_sequences`` and reference ``clusterings`` from PDB, as well as
    a set of valid indices to sample from, specified by ``to_use``.

    :return: Tuple of (indices of new split, indices remaining in dataset
        after removing those used in the split).
    """
    dataset_size = len(all_chain_sequences)
    # Map each PDB code back to its index in all_chain_sequences.
    code_to_idx = {
        fi.get_pdb_code(y[0][0]): i
        for (i, x) in enumerate(all_chain_sequences) for y in x
    }
    all_indices = set(range(dataset_size))
    split, used = set(), all_indices.difference(to_use)

    while len(split) < split_size:
        # FIX: random.sample() on a set was deprecated in Python 3.9 and
        # removed in 3.11; convert to a sorted list first (indices are ints,
        # so this is cheap and also makes the draw reproducible for a fixed
        # `random` seed).
        i = random.sample(sorted(to_use), 1)[0]
        pdb_code = fi.get_pdb_code(all_chain_sequences[i][0][0][0])
        # Pull in all members of the same sequence-identity cluster so that a
        # cluster never spans two splits.
        found = seq.find_cluster_members(pdb_code, clusterings)

        # Map back to source.
        found = {code_to_idx[x] for x in found}
        found = found.difference(used)

        # Ensure that at least min_fam_in_split families end up in each split
        # by capping how much of the split a single cluster may occupy.
        max_fam_size = int(math.ceil(split_size / min_fam_in_split))
        split = split.union(list(found)[:max_fam_size])
        to_use = to_use.difference(found)
        used = used.union(found)

    return split, to_use
def _create_cluster_split(all_chain_sequences, clusterings, to_use, split_size,
                          min_fam_in_split):
    """
    Create a split while retaining diversity specified by min_fam_in_split.
    Returns split and removes any pdbs in this split from the remaining
    dataset.
    """
    dataset_size = len(all_chain_sequences)
    # Map each PDB code back to its index in all_chain_sequences.
    code_to_idx = {
        fi.get_pdb_code(y[0][0]): i
        for (i, x) in enumerate(all_chain_sequences) for y in x
    }
    all_indices = set(range(dataset_size))
    split, used = set(), all_indices.difference(to_use)

    while len(split) < split_size:
        # FIX: random.sample() on a set was deprecated in Python 3.9 and
        # removed in 3.11; convert to a sorted list first (indices are ints,
        # so this is cheap and also makes the draw reproducible for a fixed
        # `random` seed).
        i = random.sample(sorted(to_use), 1)[0]
        pdb_code = fi.get_pdb_code(all_chain_sequences[i][0][0][0])
        # Pull in all members of the same sequence-identity cluster so that a
        # cluster never spans two splits.
        found = seq.find_cluster_members(pdb_code, clusterings)

        # Map back to source.
        found = {code_to_idx[x] for x in found}
        found = found.difference(used)

        # Ensure that at least min_fam_in_split families end up in each split
        # by capping how much of the split a single cluster may occupy.
        max_fam_size = int(math.ceil(split_size / min_fam_in_split))
        split = split.union(list(found)[:max_fam_size])
        to_use = to_use.difference(found)
        used = used.union(found)

    return split, to_use
def filter_fn(df):
    """Filter out ensembles whose two subunits share a SCOP class pair
    already present in ``scop_pairs`` (closure variable).

    Returns the subset of ``df`` whose ensembles have no such pair.
    """
    # Per-ensemble keep/drop decision, keyed by ensemble name.
    to_keep = {}
    for e, ensemble in df.groupby(['ensemble']):
        # Split the ensemble into bound/unbound subunit dataframes.
        names, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(ensemble)
        chains0 = bdf0[['structure', 'chain']].drop_duplicates()
        chains1 = bdf1[['structure', 'chain']].drop_duplicates()
        # SCOP index is keyed by lowercase pdb code + chain.
        chains0['pdb_code'] = chains0['structure'].apply(
            lambda x: fi.get_pdb_code(x).lower())
        chains1['pdb_code'] = chains1['structure'].apply(
            lambda x: fi.get_pdb_code(x).lower())
        # Collect SCOP classes found for each subunit's chains.
        scop0, scop1 = [], []
        for (pc, c) in chains0[['pdb_code', 'chain']].to_numpy():
            if (pc, c) in scop_index:
                scop0.append(scop_index.loc[(pc, c)].values)
        for (pc, c) in chains1[['pdb_code', 'chain']].to_numpy():
            if (pc, c) in scop_index:
                scop1.append(scop_index.loc[(pc, c)].values)
        # Deduplicate; an empty list means no SCOP annotation was found.
        scop0 = list(np.unique(np.concatenate(scop0))) \
            if len(scop0) > 0 else []
        scop1 = list(np.unique(np.concatenate(scop1))) \
            if len(scop1) > 0 else []
        # Canonicalize each cross-subunit pair by sorting, to match the
        # representation used when scop_pairs was built.
        pairs = [tuple(sorted((a, b))) for a in scop0 for b in scop1]
        to_keep[e] = True
        for p in pairs:
            if p in scop_pairs:
                to_keep[e] = False
    # Broadcast the per-ensemble decision back to every row of df.
    to_keep = pd.Series(to_keep)[df['ensemble']]
    return df[to_keep.values]
def get_idx_mapping(self):
    """Map each pocket file's PDB code to a sequential integer index.

    Only files whose name contains ``'_pocket'`` are counted; indices are
    assigned in the order the files appear in ``self.raw_file_names``.
    """
    pocket_files = (f for f in self.raw_file_names if '_pocket' in f)
    return {
        fi.get_pdb_code(f): idx
        for idx, f in enumerate(pocket_files)
    }
def form_scop_against():
    """Collect the unique SCOP classes present in the (module-level) dataset.

    Looks up each (pdb code, chain) pair in ``scop_index`` and returns the
    deduplicated array of all SCOP values found.
    """
    hits = []
    for entry in dataset:
        grouped = entry['atoms'].groupby(['ensemble', 'subunit', 'structure'])
        for (_, _, struct_name), struct_df in grouped:
            code = fi.get_pdb_code(struct_name).lower()
            for (_, chain), _ in struct_df.groupby(['model', 'chain']):
                if (code, chain) in scop_index:
                    hits.append(scop_index.loc[(code, chain)].values)
    return np.unique(np.concatenate(hits))
def __call__(self, x):
    """Transform one dataset item into per-residue local environments.

    For every canonical residue, removes that residue's side chain from the
    structure and extracts all atoms near the (virtual) CB position, tagging
    each extracted environment with a ``subunit`` label. The item's ``atoms``
    dataframe is replaced by the concatenated environments.
    """
    x['id'] = fi.get_pdb_code(x['id'])
    df = x['atoms']
    subunits = []
    # df = df.set_index(['chain', 'residue', 'resname'], drop=False)
    # Drop atoms with missing coordinates.
    df = df.dropna(subset=['x', 'y', 'z'])
    #remove Hets and non-allowable atoms
    df = df[df['element'].isin(allowed_atoms)]
    df = df[df['hetero'].str.strip() == '']
    for chain_res, res_df in df.groupby(['chain', 'residue', 'resname']):
        # chain_res = res_df.index.values[0]
        # names.append('_'.join([str(x) for x in name]))
        chain, res, res_name = chain_res
        # only train on canonical residues
        if res_name not in res_label_dict:
            continue
        # sample each residue based on its frequency in train data
        if self.balance:
            if not np.random.random() < res_wt_dict[res_name]:
                continue
        # Skip residues that are missing any required backbone atom.
        if not np.all([b in res_df['name'].to_list() for b in bb_atoms]):
            # print('residue missing atoms... skipping')
            continue
        CA_pos = res_df[res_df['name'] == 'CA'][['x', 'y', 'z']].astype(
            np.float32).to_numpy()[0]
        # NOTE(review): gly_CB_mu is presumably a mean CA->CB offset used to
        # place a virtual CB (glycine has none) — confirm its definition.
        CB_pos = CA_pos + (np.ones_like(CA_pos) * gly_CB_mu)
        # remove current residue from structure
        subunit_df = df[(df.chain != chain) | (df.residue != res)]
        # add backbone atoms back in
        res_bb = res_df[res_df['name'].isin(bb_atoms)]
        subunit_df = pd.concat([subunit_df, res_bb]).reset_index(drop=True)
        # environment = all atoms within 10*sqrt(3) angstroms (to enable a 20A cube)
        kd_tree = scipy.spatial.KDTree(subunit_df[['x', 'y', 'z']].to_numpy())
        subunit_pt_idx = kd_tree.query_ball_point(CB_pos, r=10.0 * np.sqrt(3),
                                                  p=2.0)
        # .loc works positionally here because of the reset_index above
        # (RangeIndex labels coincide with positions).
        sub_df = subunit_df.loc[subunit_pt_idx]
        tmp = sub_df.copy()
        tmp['subunit'] = '_'.join([str(x) for x in chain_res])
        subunits.append(tmp)
    if len(subunits) == 0:
        # No usable residues: keep the schema but emit an empty frame.
        subunits = pd.DataFrame(columns=df.columns)
    else:
        subunits = pd.concat(subunits).reset_index(drop=True)
    x['atoms'] = subunits
    return x
def filter_fn(df):
    """Filter rows whose (structure, model, chain) SCOP class appears in
    ``scop_against`` (closure variable).

    Chains with no SCOP annotation are kept or dropped depending on the
    closure variable ``conservative`` (dropped when conservative is True).
    """
    # Keep/drop decision per (structure, model, chain) triple.
    to_keep = {}
    for (e, su, st), structure in df.groupby(
            ['ensemble', 'subunit', 'structure']):
        # SCOP index is keyed by lowercase pdb code + chain.
        pc = fi.get_pdb_code(st).lower()
        for (m, c), _ in structure.groupby(['model', 'chain']):
            if (pc, c) in scop_index:
                scop_found = scop_index.loc[(pc, c)].values
                # Drop the chain if any of its SCOP classes is in the
                # blacklist.
                if np.isin(scop_found, scop_against).any():
                    to_keep[(st, m, c)] = False
                else:
                    to_keep[(st, m, c)] = True
            else:
                # Unannotated chain: conservative mode drops it.
                to_keep[(st, m, c)] = not conservative
    # Broadcast the per-chain decision back to every row of df by indexing
    # the Series with the row-wise (structure, model, chain) tuples.
    to_keep = \
        pd.Series(to_keep)[pd.Index(df[['structure', 'model', 'chain']])]
    return df[to_keep.values]
def __getitem__(self, index: int):
    """Return the combined protein/pocket/ligand example at ``index``.

    Raises IndexError for out-of-range indices; applies the optional
    transform before returning.
    """
    if index < 0 or index >= self._num_examples:
        raise IndexError(index)

    protein = self._protein_dataset[index]
    pocket = self._pocket_dataset[index]
    ligand = self._ligand_dataset[index]

    item = {
        'atoms_protein': protein['atoms'],
        'atoms_pocket': pocket['atoms'],
        'atoms_ligand': ligand['atoms'],
        'id': fi.get_pdb_code(protein['id']),
        'seq': protein['seq'],
        'smiles': ligand['smiles'],
    }
    if self._transform:
        item = self._transform(item)
    return item
def process(self):
    """Build and save one PyTorch Geometric ``Data`` object per PDB complex.

    Ligand sdf files are converted to molecular graphs and cached by PDB
    code; when the matching pocket file is seen, the protein graph is built,
    the two graphs are combined, labeled, and saved as ``data_{i}.pt``.
    """
    label_file = os.path.join(self.root, 'pdbbind_refined_set_labels.csv')
    label_df = pd.read_csv(label_file)

    # FIX: the original kept a single `mol_graph` local and assumed the
    # ligand file for a PDB code always preceded its pocket file in
    # self.raw_paths; any other ordering raised NameError or paired the
    # wrong ligand with a pocket. Cache ligand graphs by PDB code instead.
    mol_graphs = {}
    i = 0
    for raw_path in self.raw_paths:
        pdb_code = fi.get_pdb_code(raw_path)
        if '_ligand' in raw_path:
            mol_graphs[pdb_code] = graph.mol_to_graph(
                dt.read_sdf_to_mol(raw_path, add_hs=True)[0])
        elif '_pocket' in raw_path:
            # Skip pockets whose ligand has not been seen (yet).
            if pdb_code not in mol_graphs:
                continue
            y = torch.FloatTensor([get_label(pdb_code, label_df)])
            prot_graph = graph.prot_df_to_graph(
                dt.bp_to_df(dt.read_any(raw_path, name=pdb_code)))
            node_feats, edge_index, edge_feats, pos = graph.combine_graphs(
                prot_graph, mol_graphs[pdb_code], edges_between=True)
            data = Data(node_feats, edge_index, edge_feats, y=y, pos=pos)
            data.pdb = pdb_code
            torch.save(
                data,
                os.path.join(self.processed_dir, 'data_{}.pt'.format(i)))
            i += 1
def main(datapath, out_path):
    """Parse the PDBBind refined-set index file and write cleaned csv files.

    Keeps only entries whose PDB code has a corresponding sdf file under
    ``out_path``, then writes the cleaned table and a pdb/label table.
    """
    valid_pdbs = {fi.get_pdb_code(f) for f in fi.find_files(out_path, 'sdf')}

    rows = []
    index_file = os.path.join(datapath, 'index/INDEX_refined_data.2019')
    with open(index_file) as handle:
        for line in handle:
            # Skip header/comment lines.
            if line.startswith('#'):
                continue
            fields = line.strip().split()
            if fields[0] not in valid_pdbs:
                continue
            # Drop field 5 (a '//' separator column in the index file).
            rows.append(fields[:5] + fields[6:])

    columns = ['pdb', 'res', 'year', 'neglog_aff', 'affinity', 'ref', 'ligand']
    refined_set = pd.DataFrame(rows, columns=columns)
    # Split e.g. "Kd=5nM" into measurement type and value.
    refined_set[['measurement', 'affinity']] = \
        refined_set['affinity'].str.split('=', expand=True)
    refined_set['ligand'] = refined_set['ligand'].str.strip('()')

    # Remove peptide ligands
    # - refined set size now 4,598
    # refined_set = refined_set[["-mer" not in l for l in refined_set.ligand]]

    refined_set.to_csv(
        os.path.join(out_path, 'pdbbind_refined_set_cleaned.csv'), index=False)

    labels = refined_set[['pdb', 'neglog_aff']].rename(
        columns={'neglog_aff': 'label'})
    labels.to_csv(
        os.path.join(out_path, 'pdbbind_refined_set_labels.csv'), index=False)
def convert_to_hdf5(input_dir, label_file, hdf_file):
    """Convert cif structures, sdf ligands, and a label csv into one HDF5
    file with keys 'proteins', 'pockets', 'pdb_codes', 'ligands', 'labels'.

    :param input_dir: directory containing *_protein / *_pocket cif files
        and ligand sdf files
    :param label_file: csv with a 'pdb' column used to align labels
    :param hdf_file: output path for the HDF5 store
    """
    cif_files = fi.find_files(input_dir, 'cif')
    proteins = []
    pockets = []
    pdb_codes = []
    for f in tqdm(cif_files, desc='reading structures'):
        pdb_code = fi.get_pdb_code(f)
        # Only '_protein' files contribute to the pdb_codes list; pockets are
        # collected separately.
        if '_protein' in f:
            pdb_codes.append(pdb_code)
            df = dt.bp_to_df(dt.read_any(f))
            proteins.append(df)
        elif '_pocket' in f:
            df = dt.bp_to_df(dt.read_any(f))
            pockets.append(df)
    print('converting proteins...')
    protein_df = pd.concat(proteins)
    pocket_df = pd.concat(pockets)
    pdb_codes = pd.DataFrame({'pdb': pdb_codes})
    protein_df.to_hdf(hdf_file, 'proteins')
    pocket_df.to_hdf(hdf_file, 'pockets')
    pdb_codes.to_hdf(hdf_file, 'pdb_codes')
    print('converting ligands...')
    sdf_files = fi.find_files(input_dir, 'sdf')
    big_sdf = os.path.join(input_dir, 'all_ligands.sdf')
    dt.combine_sdfs(sdf_files, big_sdf)
    lig_df = PandasTools.LoadSDF(big_sdf, molColName='Mol')
    # NOTE(review): this assigns a DataFrame as the index and assumes the sdf
    # order matches the cif-derived pdb_codes order — confirm both; a
    # mismatch silently mislabels ligands.
    lig_df.index = pdb_codes
    lig_df.to_hdf(hdf_file, 'ligands')
    print('converting labels...')
    label_df = pd.read_csv(label_file)
    # NOTE(review): reindex is called with the pdb_codes DataFrame rather
    # than a plain sequence of codes — verify this aligns labels as intended.
    label_df = label_df.set_index('pdb').reindex(pdb_codes)
    label_df.to_hdf(hdf_file, 'labels')
def cluster_split(dataset, cutoff, val_split=0.1, test_split=0.1,
                  min_fam_in_split=5, random_seed=None):
    """
    Splits pdb dataset using pre-computed sequence identity clusters from
    PDB. Generates train, val, test sets. We assume there is one PDB code
    per entry in dataset.

    Args:
        cutoff (float): sequence identity cutoff
            (can be .3, .4, .5, .7, .9, .95, 1.0)
        val_split (float): fraction of data used for validation. Default: 0.1
        test_split (float): fraction of data used for testing. Default: 0.1
        min_fam_in_split (int): controls variety of val/test sets. Default: 5
        random_seed (int): specifies random seed for shuffling. Default: None

    Returns:
        train_set (str[]): pdbs in the train set
        val_set (str[]): pdbs in the validation set
        test_set (str[]): pdbs in the test set
    """
    # NOTE(review): this seeds numpy's RNG, but the sampling inside
    # _create_cluster_split uses the stdlib `random` module — confirm the
    # seed actually controls the split.
    if random_seed is not None:
        np.random.seed(random_seed)

    logger.info('Loading chain sequences')
    all_chain_sequences = [
        seq.get_chain_sequences(x['atoms']) for x in dataset
    ]

    pdb_codes = np.array(
        [fi.get_pdb_code(x[0][0][0]) for x in all_chain_sequences])
    n_orig = len(np.unique(pdb_codes))
    clusterings = seq.get_pdb_clusters(cutoff, np.unique(pdb_codes))

    # If code not present in clustering, we don't use.
    to_use = [i for (i, x) in enumerate(pdb_codes) if x in clusterings[0]]
    n = len(np.unique(pdb_codes[to_use]))
    to_use = set(to_use)
    logger.info(f'Removing {n_orig - n:} / {n_orig:} '
                f'sequences due to not finding in clustering.')

    # Note: these sizes are floats; the helper compares len(split) against
    # them directly.
    test_size = n * test_split
    val_size = n * val_split

    logger.info('generating validation set...')
    val_indices, to_use = _create_cluster_split(
        all_chain_sequences, clusterings, to_use, val_size, min_fam_in_split)
    logger.info('generating test set...')
    test_indices, to_use = _create_cluster_split(
        all_chain_sequences, clusterings, to_use, test_size, min_fam_in_split)
    train_indices = to_use

    return splits.split(dataset, train_indices, val_indices, test_indices)
def cluster_split(dataset, cutoff, val_split=0.1, test_split=0.1,
                  min_fam_in_split=5, random_seed=None):
    """Splits pdb dataset using pre-computed sequence identity clusters from
    PDB, ensuring that no cluster spans multiple splits.

    Clusters are selected randomly into validation and test sets, but to
    ensure that there is some diversity in each set (i.e. a split does not
    consist of a single sequence cluster), a minimum number of clusters in
    each split is enforced. Some data examples may be removed in order to
    satisfy this constraint.

    This function assumes that the PDB code or PDB filename exists in the
    ``ensemble`` field of the ``atoms`` dataframe in the dataset.

    :param dataset: Dataset to perform the split on.
    :type dataset: ATOM3D Dataset
    :param cutoff: Sequence identity cutoff. Possible values: 0.3, 0.4, 0.5,
        0.7, 0.9, 0.95, 1.0
    :type cutoff: float
    :param val_split: Fraction of data used in validation set, defaults to 0.1
    :type val_split: float, optional
    :param test_split: Fraction of data used in test set, defaults to 0.1
    :type test_split: float, optional
    :param min_fam_in_split: Minimum number of sequence clusters to be
        included in validation and test sets, defaults to 5
    :type min_fam_in_split: int, optional
    :param random_seed: Random seed for sampling clusters, defaults to None
    :type random_seed: int, optional

    :return: Tuple containing training, validation, and test sets, each as
        ATOM3D Dataset objects.
    :rtype: Tuple[Dataset]
    """
    # NOTE(review): this seeds numpy's RNG, but the sampling inside
    # _create_cluster_split uses the stdlib `random` module — confirm the
    # seed actually controls the split.
    if random_seed is not None:
        np.random.seed(random_seed)

    logger.info('Loading chain sequences')
    all_chain_sequences = [
        seq.get_chain_sequences(x['atoms']) for x in dataset
    ]

    pdb_codes = np.array(
        [fi.get_pdb_code(x[0][0][0]) for x in all_chain_sequences])
    n_orig = len(np.unique(pdb_codes))
    clusterings = seq.get_pdb_clusters(cutoff, np.unique(pdb_codes))

    # If code not present in clustering, we don't use.
    to_use = [i for (i, x) in enumerate(pdb_codes) if x in clusterings[0]]
    n = len(np.unique(pdb_codes[to_use]))
    to_use = set(to_use)
    logger.info(f'Removing {n_orig - n:} / {n_orig:} '
                f'sequences due to not finding in clustering.')

    # Note: these sizes are floats; the helper compares len(split) against
    # them directly.
    test_size = n * test_split
    val_size = n * val_split

    logger.info('generating validation set...')
    val_indices, to_use = _create_cluster_split(
        all_chain_sequences, clusterings, to_use, val_size, min_fam_in_split)
    logger.info('generating test set...')
    test_indices, to_use = _create_cluster_split(
        all_chain_sequences, clusterings, to_use, test_size, min_fam_in_split)
    train_indices = to_use

    return splits.split(dataset, train_indices, val_indices, test_indices)
def pdb_id_transform(x):
    """Replace the item's ``id`` field with its bare PDB code (in place)."""
    pdb_code = fi.get_pdb_code(x['id'])
    x['id'] = pdb_code
    return x
def get_all_chain_sequences(pdb_dataset):
    """Return list of tuples of (pdb_code, chain_sequences) for PDB dataset."""
    results = []
    for entry in tqdm.tqdm(pdb_dataset):
        code = (fi.get_pdb_code(entry), )
        results.append((code, get_chain_sequences(entry)))
    return results
def test_get_pdb_code():
    """Check that PDB codes are extracted correctly from known test paths."""
    expected = ['103l', '117e', '11as', '2olx']
    codes = [fi.get_pdb_code(path) for path in test_file_list]
    assert codes == expected
def form_scop_pair_filter_against(sharded, level):
    """Remove pairs that have matching scop classes in both subunits.

    Scans the reference ``sharded`` dataset to build the set of SCOP class
    pairs (at the given SCOP ``level``) seen across the two subunits of each
    ensemble, then returns a filter function that drops any ensemble whose
    subunits share one of those pairs.
    """
    scop_index = scop.get_scop_index()[level]

    # First pass: collect every cross-subunit SCOP class pair present in the
    # reference shards.
    scop_pairs = []
    for _, shard in sh.iter_shards(sharded):
        for e, ensemble in shard.groupby(['ensemble']):
            # Split the ensemble into bound/unbound subunit dataframes.
            names, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(ensemble)
            chains0 = bdf0[['structure', 'chain']].drop_duplicates()
            chains1 = bdf1[['structure', 'chain']].drop_duplicates()
            # SCOP index is keyed by lowercase pdb code + chain.
            chains0['pdb_code'] = chains0['structure'].apply(
                lambda x: fi.get_pdb_code(x).lower())
            chains1['pdb_code'] = chains1['structure'].apply(
                lambda x: fi.get_pdb_code(x).lower())
            scop0, scop1 = [], []
            for (pc, c) in chains0[['pdb_code', 'chain']].to_numpy():
                if (pc, c) in scop_index:
                    scop0.append(scop_index.loc[(pc, c)].values)
            for (pc, c) in chains1[['pdb_code', 'chain']].to_numpy():
                if (pc, c) in scop_index:
                    scop1.append(scop_index.loc[(pc, c)].values)
            scop0 = list(np.unique(np.concatenate(scop0))) \
                if len(scop0) > 0 else []
            scop1 = list(np.unique(np.concatenate(scop1))) \
                if len(scop1) > 0 else []
            # Canonicalize each pair by sorting so lookup order is irrelevant.
            pairs = [tuple(sorted((a, b))) for a in scop0 for b in scop1]
            scop_pairs.extend(pairs)
    scop_pairs = set(scop_pairs)

    def filter_fn(df):
        """Keep only ensembles with no SCOP pair found in the reference."""
        to_keep = {}
        for e, ensemble in df.groupby(['ensemble']):
            names, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(ensemble)
            chains0 = bdf0[['structure', 'chain']].drop_duplicates()
            chains1 = bdf1[['structure', 'chain']].drop_duplicates()
            chains0['pdb_code'] = chains0['structure'].apply(
                lambda x: fi.get_pdb_code(x).lower())
            chains1['pdb_code'] = chains1['structure'].apply(
                lambda x: fi.get_pdb_code(x).lower())
            scop0, scop1 = [], []
            for (pc, c) in chains0[['pdb_code', 'chain']].to_numpy():
                if (pc, c) in scop_index:
                    scop0.append(scop_index.loc[(pc, c)].values)
            for (pc, c) in chains1[['pdb_code', 'chain']].to_numpy():
                if (pc, c) in scop_index:
                    scop1.append(scop_index.loc[(pc, c)].values)
            scop0 = list(np.unique(np.concatenate(scop0))) \
                if len(scop0) > 0 else []
            scop1 = list(np.unique(np.concatenate(scop1))) \
                if len(scop1) > 0 else []
            pairs = [tuple(sorted((a, b))) for a in scop0 for b in scop1]
            to_keep[e] = True
            for p in pairs:
                if p in scop_pairs:
                    to_keep[e] = False
        # Broadcast the per-ensemble decision back to every row of df.
        to_keep = pd.Series(to_keep)[df['ensemble']]
        return df[to_keep.values]

    return filter_fn