def split(input_sharded, output_root):
    """Split temporally and shuffle examples.

    Re-keys ``input_sharded`` by ``['ensemble', 'subunit']`` into a temporary
    sharded dataset. From it, writes three sharded datasets,
    ``<prefix>_train``, ``<prefix>_val`` and ``<prefix>_test``. Each keeps the
    examples whose ``ensemble`` appears in the module-level ``TRAIN``, ``VAL``
    or ``TEST`` list respectively. The temporary dataset is deleted afterwards.

    Args:
        input_sharded: sharded dataset to split.
        output_root: sharded path whose prefix (``sh.get_prefix``) and shard
            count (``sh.get_num_shards``) name the temporary and output
            datasets.
    """
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    # Re-key to ensemble, subunit. To allow for shuffling across targets.
    tmp_sharded = sh.Sharded(
        f'{prefix:}_rekeyed@{num_shards:}', ['ensemble', 'subunit'])
    logger.info("Rekeying")
    sho.rekey(input_sharded, tmp_sharded)

    # Once the rekeyed copy has been written, always remove it, even if one of
    # the filtering passes below raises.
    try:
        keys = tmp_sharded.get_keys()
        train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
        val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
        test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

        logger.info("Splitting")
        train_filter_fn = filters.form_filter_against_list(TRAIN, 'ensemble')
        val_filter_fn = filters.form_filter_against_list(VAL, 'ensemble')
        test_filter_fn = filters.form_filter_against_list(TEST, 'ensemble')

        # NOTE(review): num_shards goes in the fourth-argument slot, where the
        # other split functions pass shuffle_buffer; confirm this is intended.
        sho.filter_sharded(
            tmp_sharded, train_sharded, train_filter_fn, num_shards)
        sho.filter_sharded(
            tmp_sharded, val_sharded, val_filter_fn, num_shards)
        sho.filter_sharded(
            tmp_sharded, test_sharded, test_filter_fn, num_shards)
    finally:
        tmp_sharded.delete_files()
def split(input_sharded, output_root, info_csv, shuffle_buffer):
    """Split by protein.

    Builds an ensemble name for every row of ``info_csv``:
    ``<ligand>__<active_struc.split('_')[2]>__<inactive_struc.split('_')[2]>``.
    Randomly assigns the CSV's unique ``protein`` values to train/val/test.
    Then writes ``<prefix>_train``, ``<prefix>_val`` and ``<prefix>_test``
    sharded datasets holding the examples of ``input_sharded`` whose ensemble
    belongs to each protein group.

    Args:
        input_sharded: sharded dataset keyed by ``['ensemble']``.
        output_root: sharded path whose prefix and shard count name the
            output datasets.
        info_csv: CSV read with ``pd.read_csv``; must provide ``ligand``,
            ``active_struc``, ``inactive_struc`` and ``protein`` columns.
        shuffle_buffer: forwarded to ``sho.filter_sharded``.

    Raises:
        RuntimeError: if ``input_sharded`` is not keyed by ``['ensemble']``.
        KeyError: if an ensemble of ``input_sharded`` is absent from
            ``info_csv``.
    """
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    info = pd.read_csv(info_csv)
    info['ensemble'] = info.apply(
        lambda x: x['ligand'] + '__' + x['active_struc'].split('_')[2]
        + '__' + x['inactive_struc'].split('_')[2], axis=1)
    info = info.set_index('ensemble')
    # Remove duplicate ensembles.
    info = info[~info.index.duplicated()]

    # Every ensemble in the input must be described in info_csv. The old code
    # only got this check implicitly from an otherwise-unused `.loc` lookup.
    ensembles = input_sharded.get_names()['ensemble']
    missing = sorted(set(ensembles) - set(info.index))
    if missing:
        raise KeyError(
            f'{len(missing):} ensembles not found in {info_csv}: '
            f'{missing[:10]}')

    # Split by protein.
    proteins = info['protein'].unique()
    # NOTE(review): the result is unpacked as (test, val, train) while the
    # fractions are (0.6, 0.2, 0.2). Confirm against the return order of
    # splits.random_split which set actually receives 60% of the proteins.
    i_test, i_val, i_train = splits.random_split(len(proteins), 0.6, 0.2, 0.2)
    p_train = proteins[i_train]
    p_val = proteins[i_val]
    p_test = proteins[i_test]
    logger.info(f'Train proteins: {p_train:}')
    logger.info(f'Val proteins: {p_val:}')
    logger.info(f'Test proteins: {p_test:}')

    # These lists (and the logged counts) cover every ensemble in info_csv,
    # not only those present in input_sharded.
    train = info[info['protein'].isin(p_train)].index.tolist()
    val = info[info['protein'].isin(p_val)].index.tolist()
    test = info[info['protein'].isin(p_test)].index.tolist()

    logger.info(f'{len(train):} train examples, {len(val):} val examples, '
                f'{len(test):} test examples.')

    keys = input_sharded.get_keys()
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(input_sharded, train_sharded, train_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, val_sharded, val_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, test_sharded, test_filter_fn,
                       shuffle_buffer)
def prepare(input_sharded_path, output_root, score_dir):
    """Load a sharded dataset, optionally filter it by score files, and split.

    Args:
        input_sharded_path: path handed to ``sh.Sharded.load``.
        output_root: sharded path whose prefix and shard count name the
            intermediate ``_filtered`` dataset; also passed to ``split``.
        score_dir: directory handed to ``sc.form_score_filter``. When ``None``
            the loaded dataset is split directly, without filtering.
    """
    input_sharded = sh.Sharded.load(input_sharded_path)
    if score_dir is None:
        split(input_sharded, output_root)
        return

    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    keys = input_sharded.get_keys()
    filtered_sharded = sh.Sharded(f'{prefix:}_filtered@{num_shards:}', keys)
    filter_fn = sc.form_score_filter(score_dir)
    logger.info('Filtering against score file.')
    sho.filter_sharded(input_sharded, filtered_sharded, filter_fn)

    # The filtered copy is only an intermediate input to split(). Remove it
    # even if splitting fails, so temporary shards are not left behind.
    try:
        split(filtered_sharded, output_root)
    finally:
        filtered_sharded.delete_files()
def split(input_sharded, output_root, scaffold_data, shuffle_buffer):
    """Split by scaffold.

    Derives one ensemble name per row of ``scaffold_data``:
    ``<ligand>__<last '_' token of active_struc>__<last '_' token of
    inactive_struc>``. Partitions the rows with ``splits.scaffold_split`` over
    the ``Scaffold`` column. Writes ``<prefix>_{train,val,test}`` sharded
    datasets and ``<output_root before '@'>_{train,val,test}.txt`` listing
    the ensembles in each set.

    Args:
        input_sharded: sharded dataset keyed by ``['ensemble']``.
        output_root: sharded path (``<prefix>@<num_shards>``) naming outputs.
        scaffold_data: DataFrame with ``Scaffold``, ``ligand``,
            ``active_struc`` and ``inactive_struc`` columns.
        shuffle_buffer: forwarded to ``sho.filter_sharded``.

    Raises:
        RuntimeError: if ``input_sharded`` is not keyed by ``['ensemble']``.
    """
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    logger.info('Process scaffold and ensemble data')
    scaffold_list = scaffold_data['Scaffold'].tolist()
    # Build the names column-wise in one pass; the old code did a separate
    # .iloc row lookup per field.
    ensemble_list = np.array([
        ligand + '__' + active.split('_')[-1] + '__' + inactive.split('_')[-1]
        for ligand, active, inactive in zip(scaffold_data['ligand'],
                                            scaffold_data['active_struc'],
                                            scaffold_data['inactive_struc'])
    ])

    logger.info('Splitting by scaffold')
    train_idx, val_idx, test_idx = splits.scaffold_split(scaffold_list)
    # scaffold_split indexes positions in scaffold_list, which line up with
    # positions in ensemble_list.
    train = ensemble_list[train_idx].tolist()
    val = ensemble_list[val_idx].tolist()
    test = ensemble_list[test_idx].tolist()

    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(
        input_sharded, train_sharded, train_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, val_sharded, val_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, test_sharded, test_filter_fn, shuffle_buffer)

    # write splits to text files
    txt_prefix = output_root.split('@')[0]
    np.savetxt(txt_prefix + '_train.txt', train, fmt='%s')
    np.savetxt(txt_prefix + '_val.txt', val, fmt='%s')
    np.savetxt(txt_prefix + '_test.txt', test, fmt='%s')
def split(input_sharded, output_root, shuffle_buffer, cutoff=30):
    """Split by sequence identity.

    Collects chain sequences from every shard of ``input_sharded`` via
    ``seq.get_all_chain_sequences_df``. Clusters them with
    ``atom3d.splits.sequence.cluster_split`` at ``cutoff``. Writes
    ``<prefix>_{train,val,test}`` sharded datasets and
    ``<output_root before '@'>_{train,val,test}.txt`` listing the selected
    ensembles.

    Args:
        input_sharded: sharded dataset keyed by ``['ensemble']``.
        output_root: sharded path (``<prefix>@<num_shards>``) naming outputs.
        shuffle_buffer: forwarded to ``sho.filter_sharded``.
        cutoff: forwarded to ``cluster_split``.

    Raises:
        RuntimeError: if ``input_sharded`` is not keyed by ``['ensemble']``.
    """
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    all_chain_sequences = []
    logger.info('Loading chain sequences')
    for _, shard in input_sharded.iter_shards():
        all_chain_sequences.extend(seq.get_all_chain_sequences_df(shard))

    logger.info('Splitting by cluster')
    train, val, test = atom3d.splits.sequence.cluster_split(
        all_chain_sequences, cutoff)

    # Will just look up ensembles. Presumably element 0 of each entry is the
    # ensemble name; confirm against get_all_chain_sequences_df.
    train = [x[0] for x in train]
    val = [x[0] for x in val]
    test = [x[0] for x in test]

    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(input_sharded, train_sharded, train_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, val_sharded, val_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, test_sharded, test_filter_fn,
                       shuffle_buffer)

    # write splits to text files
    txt_prefix = output_root.split('@')[0]
    np.savetxt(txt_prefix + '_train.txt', train, fmt='%s')
    np.savetxt(txt_prefix + '_val.txt', val, fmt='%s')
    np.savetxt(txt_prefix + '_test.txt', test, fmt='%s')
def split(input_sharded, output_root, scaffold_data, shuffle_buffer):
    """Split by scaffold.

    Partitions the rows of ``scaffold_data`` with ``splits.scaffold_split``
    over its ``Scaffold`` column. Uses each row's ``pdb`` value as the
    ensemble name to keep. Writes ``<prefix>_{train,val,test}`` sharded
    datasets and ``<output_root before '@'>_{train,val,test}.txt`` listing
    the names in each set.

    Args:
        input_sharded: sharded dataset keyed by ``['ensemble']``.
        output_root: sharded path (``<prefix>@<num_shards>``) naming outputs.
        scaffold_data: DataFrame with ``Scaffold`` and ``pdb`` columns.
        shuffle_buffer: forwarded to ``sho.filter_sharded``.

    Raises:
        RuntimeError: if ``input_sharded`` is not keyed by ``['ensemble']``.
    """
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    logger.info('Splitting by scaffold')
    scaffold_list = scaffold_data['Scaffold'].tolist()
    train_idx, val_idx, test_idx = splits.scaffold_split(scaffold_list)
    # scaffold_split indexes positions in scaffold_list, so select pdb ids
    # positionally. Plain [] on a Series selects by label, which picks the
    # wrong rows whenever the frame's index is not 0..n-1.
    pdb_ids = scaffold_data['pdb']
    train = pdb_ids.iloc[train_idx].tolist()
    val = pdb_ids.iloc[val_idx].tolist()
    test = pdb_ids.iloc[test_idx].tolist()

    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(input_sharded, train_sharded, train_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, val_sharded, val_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, test_sharded, test_filter_fn,
                       shuffle_buffer)

    # write splits to text files
    txt_prefix = output_root.split('@')[0]
    np.savetxt(txt_prefix + '_train.txt', train, fmt='%s')
    np.savetxt(txt_prefix + '_val.txt', val, fmt='%s')
    np.savetxt(txt_prefix + '_test.txt', test, fmt='%s')