import logging

import numpy as np
import pandas as pd

# atom3d module paths below are assumptions inferred from the aliases and
# the fully qualified atom3d.splits.sequence call used in this file.
import atom3d.protein.sequence as seq
import atom3d.shard.shard as sh
import atom3d.shard.shard_ops as sho
import atom3d.splits.sequence
import atom3d.util.filters as filters

logger = logging.getLogger(__name__)


def process_sharded_single(info_df, sharded):
    """Flatten sharded ensembles into a DataFrame of ligand, protein,
    label, SMILES, and sequence columns."""
    data = []
    for _, shard_df in sharded.iter_shards():
        # Group by scalar key so ensemble_name is a string, not a 1-tuple.
        for ensemble_name, ensemble_df in shard_df.groupby('ensemble'):
            # __get_subunit_name is a module-private helper defined
            # elsewhere in this file.
            active_name = __get_subunit_name(
                ensemble_df.subunit.unique(), mode='active')
            struct_df = ensemble_df[ensemble_df.subunit == active_name]
            # Drop the ligand chain ('L'), keeping only protein atoms.
            protein_df = struct_df[struct_df.chain != 'L']
            ligand_name = ensemble_name.split('_')[0]
            info = info_df.loc[(
                ligand_name,
                protein_df.structure.unique()[0].split('.')[0])]
            # Sequence: expect all chains to come from a single structure.
            chain_sequences = seq.get_all_chain_sequences_df(protein_df)
            assert len(chain_sequences) == 1
            seq_str = '\n'.join(s for (_, s) in chain_sequences[0][1])
            data.append({
                'ligand': ligand_name,
                'protein': info.protein,
                'label': (info.label == 'A'),
                'smiles': info.SMILES,
                'seq': seq_str,
            })
    data_df = pd.DataFrame(
        data, columns=['ligand', 'protein', 'label', 'smiles', 'seq'])
    return data_df
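# Usage sketch (hypothetical paths; assumes info_df is indexed by
# (ligand, structure id) and carries `protein`, `label`, and `SMILES`
# columns, as the lookups above require):
#
#   info_df = pd.read_csv('info.csv', index_col=['ligand', 'structure'])
#   sharded = sh.Sharded('raw@16', ['ensemble'])
#   data_df = process_sharded_single(info_df, sharded)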
def split(input_sharded, output_root, shuffle_buffer, cutoff=30):
    """Split by sequence identity."""
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    all_chain_sequences = []
    logger.info('Loading chain sequences')
    for _, shard in input_sharded.iter_shards():
        all_chain_sequences.extend(seq.get_all_chain_sequences_df(shard))

    logger.info('Splitting by cluster')
    train, val, test = atom3d.splits.sequence.cluster_split(
        all_chain_sequences, cutoff)

    # Will just look up ensembles.
    train = [x[0] for x in train]
    val = [x[0] for x in val]
    test = [x[0] for x in test]

    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix}_train@{num_shards}', keys)
    val_sharded = sh.Sharded(f'{prefix}_val@{num_shards}', keys)
    test_sharded = sh.Sharded(f'{prefix}_test@{num_shards}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(
        input_sharded, train_sharded, train_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, val_sharded, val_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, test_sharded, test_filter_fn, shuffle_buffer)

    # Write the split ensemble lists to text files alongside the shards.
    np.savetxt(output_root.split('@')[0] + '_train.txt', train, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_val.txt', val, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_test.txt', test, fmt='%s')
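# Usage sketch (hypothetical paths; the 'name@N' convention, where N is the
# shard count, is what sh.get_prefix / sh.get_num_shards parse):
#
#   input_sharded = sh.Sharded('lep@16', ['ensemble'])
#   split(input_sharded, 'lep_split@16', shuffle_buffer=10)
#
# This writes lep_split_train@16 / lep_split_val@16 / lep_split_test@16,
# plus lep_split_train.txt, lep_split_val.txt, and lep_split_test.txt
# listing the ensembles in each split.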
def form_seq_filter_against(sharded, cutoff):
    """
    Remove structures with too much sequence identity to a chain in sharded.

    We consider each chain in each structure separately, and remove the
    structure if any of them matches any chain in sharded.
    """
    blast_db_path = f'{sharded.path}.db'
    all_chain_sequences = []
    for _, shard in sharded.iter_shards():
        all_chain_sequences.extend(seq.get_all_chain_sequences_df(shard))
    # Build a BLAST database from every chain in the reference sharded.
    seq.write_to_blast_db(all_chain_sequences, blast_db_path)

    def filter_fn(df):
        to_keep = {}
        for structure_name, cs in seq.get_all_chain_sequences_df(df):
            # Keep an ensemble only if none of its chains hit the database.
            hits = seq.find_similar(cs, blast_db_path, cutoff, 1)
            ensemble = structure_name[0]
            to_keep[ensemble] = (len(hits) == 0)
        # Broadcast the per-ensemble decision to every row of the shard.
        to_keep = pd.Series(to_keep)[df['ensemble']]
        return df[to_keep.values]

    return filter_fn
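# Usage sketch (hypothetical paths): filter one sharded dataset against
# another so that no retained structure has a chain hitting the reference
# above the identity cutoff:
#
#   reference = sh.Sharded('train@16', ['ensemble'])
#   candidates = sh.Sharded('test_raw@16', ['ensemble'])
#   filtered = sh.Sharded('test@16', ['ensemble'])
#   filter_fn = form_seq_filter_against(reference, cutoff=30)
#   sho.filter_sharded(candidates, filtered, filter_fn, 10)  # shuffle buffer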