def shard_envs(input_path, output_path, num_threads=8, subsample=True):
    input_sharded = sh.Sharded.load(input_path)
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')
    output_sharded = sh.Sharded(output_path, keys)
    input_num_shards = input_sharded.get_num_shards()

    tmp_path = output_sharded.get_prefix() + f'_tmp@{input_num_shards:}'
    tmp_sharded = sh.Sharded(tmp_path, keys)

    # Only process shards whose output does not yet exist, so an interrupted
    # run can be resumed.
    not_written = []
    for i in range(input_num_shards):
        shard = output_sharded._get_shard(i)
        if not os.path.exists(shard):
            not_written.append(i)

    print(f'Using {num_threads:} threads')

    inputs = [(input_sharded, tmp_sharded, shard_num, subsample)
              for shard_num in not_written]
    # with multiprocessing.Pool(processes=num_threads) as pool:
    #     pool.starmap(_shard_envs, inputs)
    par.submit_jobs(_shard_envs, inputs, num_threads)

    sho.reshard(tmp_sharded, output_sharded)
    tmp_sharded.delete_files()
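
# Illustrative sketch (not part of the original source): `_shard_envs` is the
# per-shard worker dispatched by `par.submit_jobs` above, but its body is not
# shown here. A hypothetical worker with the same
# (input_sharded, tmp_sharded, shard_num, subsample) signature might look
# roughly like this: read one input shard, optionally thin it out, and write
# the matching tmp shard.
def _example_env_worker(input_sharded, tmp_sharded, shard_num, subsample):
    df = input_sharded.read_shard(shard_num)
    if subsample and len(df) > 0:
        # Placeholder subsampling; the real _shard_envs may select
        # environments very differently.
        df = df.sample(frac=0.1, random_state=shard_num)
    tmp_sharded._write_shard(shard_num, df)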
def split(input_sharded, output_root):
    """Split temporally and shuffle examples."""
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    # Re-key to (ensemble, subunit), to allow for shuffling across targets.
    tmp_sharded = sh.Sharded(
        f'{prefix:}_rekeyed@{num_shards:}', ['ensemble', 'subunit'])
    logger.info("Rekeying")
    sho.rekey(input_sharded, tmp_sharded)
    keys = tmp_sharded.get_keys()

    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info("Splitting")
    train_filter_fn = filters.form_filter_against_list(TRAIN, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(VAL, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(TEST, 'ensemble')

    sho.filter_sharded(
        tmp_sharded, train_sharded, train_filter_fn, num_shards)
    sho.filter_sharded(
        tmp_sharded, val_sharded, val_filter_fn, num_shards)
    sho.filter_sharded(
        tmp_sharded, test_sharded, test_filter_fn, num_shards)
    tmp_sharded.delete_files()
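
# Usage sketch (hypothetical paths): sharded datasets are addressed as
# 'prefix@num_shards', which is the convention that sh.get_prefix() and
# sh.get_num_shards() parse and that the f-strings in split() above
# reproduce. For example, with output_root = 'splits/psr@16' the derived
# datasets would be written to 'splits/psr_rekeyed@16', 'splits/psr_train@16',
# 'splits/psr_val@16' and 'splits/psr_test@16'.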
def split(input_sharded, output_root, info_csv, shuffle_buffer):
    """Split by protein."""
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    info = pd.read_csv(info_csv)
    info['ensemble'] = info.apply(
        lambda x: x['ligand'] + '__' + x['active_struc'].split('_')[2] +
                  '__' + x['inactive_struc'].split('_')[2],
        axis=1)
    info = info.set_index('ensemble')
    # Remove duplicate ensembles.
    info = info[~info.index.duplicated()]

    ensembles = input_sharded.get_names()['ensemble']
    in_use = info.loc[ensembles]
    active = in_use[in_use['label'] == 'A']
    inactive = in_use[in_use['label'] == 'I']

    # Split by protein.
    proteins = info['protein'].unique()
    i_test, i_val, i_train = splits.random_split(len(proteins), 0.6, 0.2, 0.2)
    p_train = proteins[i_train]
    p_val = proteins[i_val]
    p_test = proteins[i_test]

    logger.info(f'Train proteins: {p_train:}')
    logger.info(f'Val proteins: {p_val:}')
    logger.info(f'Test proteins: {p_test:}')

    train = info[info['protein'].isin(p_train)].index.tolist()
    val = info[info['protein'].isin(p_val)].index.tolist()
    test = info[info['protein'].isin(p_test)].index.tolist()

    logger.info(f'{len(train):} train examples, {len(val):} val examples, '
                f'{len(test):} test examples.')

    keys = input_sharded.get_keys()
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(
        input_sharded, train_sharded, train_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, val_sharded, val_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, test_sharded, test_filter_fn, shuffle_buffer)
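
# Worked example (synthetic values, not real data) of the ensemble naming used
# above: the ligand ID, the third '_'-separated token of active_struc, and the
# third token of inactive_struc are joined with '__'.
_example_info = pd.DataFrame([
    {'ligand': 'LIG1',
     'active_struc': 'x_y_6E67_A',
     'inactive_struc': 'x_y_6E68_A'},
])
_example_ensemble = _example_info.apply(
    lambda x: x['ligand'] + '__' + x['active_struc'].split('_')[2] +
              '__' + x['inactive_struc'].split('_')[2],
    axis=1)
# _example_ensemble[0] == 'LIG1__6E67__6E68'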
def filter_pairs(input_sharded_path, output_root, bsa, against_path,
                 shuffle_buffer):
    input_sharded = sh.Sharded.load(input_sharded_path)
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')
    output_sharded = sh.Sharded(output_root, keys)

    # We form the combined filter by starting with the identity filter and
    # composing with further filters.
    filter_fn = filters.identity_filter
    filter_fn = filters.compose(
        atom3d.filters.pdb.form_molecule_type_filter(allowed=['prot']),
        filter_fn)
    filter_fn = filters.compose(
        filters.form_size_filter(min_size=50), filter_fn)
    filter_fn = filters.compose(
        atom3d.filters.pdb.form_resolution_filter(3.5), filter_fn)
    filter_fn = filters.compose(
        atom3d.filters.pdb.form_source_filter(allowed=['diffraction', 'EM']),
        filter_fn)
    if bsa is not None:
        filter_fn = filters.compose(form_bsa_filter(bsa, 500), filter_fn)
    if against_path is not None:
        against = sh.Sharded.load(against_path)
        filter_fn = filters.compose(
            atom3d.filters.sequence.form_seq_filter_against(against, 0.3),
            filter_fn)
        filter_fn = filters.compose(
            form_scop_pair_filter_against(against, 'superfamily'), filter_fn)

    sho.filter_sharded(input_sharded, output_sharded, filter_fn)
    split(output_sharded, output_root, shuffle_buffer)
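
# Illustrative sketch (not the library code): the pattern above treats every
# filter as a function from a structure DataFrame to the subset of rows to
# keep, so composing two filters just applies one after the other. A minimal
# stand-in for filters.compose could look like this.
def _example_compose(f, g):
    def composed(df):
        # Apply g first, then f, mirroring how the combined filter above is
        # built up from the identity filter outward.
        return f(g(df))
    return composed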
def filter_sharded(input_sharded, output_sharded, filter_fn, shuffle_buffer=0):
    """Filter sharded dataset to new sharded dataset, using provided filter."""
    logging.basicConfig(format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)
    if not os.path.exists(os.path.dirname(output_sharded.path)):
        os.makedirs(os.path.dirname(output_sharded.path))
    input_num_shards = input_sharded.get_num_shards()

    # We will just map to tmp, then reshard.
    tmp_path = output_sharded.get_prefix() + f'_tmp@{input_num_shards:}'
    tmp_sharded = sh.Sharded(tmp_path, input_sharded.get_keys())

    logging.info(f'Filtering {input_sharded.path:} to {output_sharded.path:}')

    # Apply filter.
    for shard_num in tqdm.trange(input_num_shards):
        df = input_sharded.read_shard(shard_num)
        if len(df) > 0:
            df = filter_fn(df)
        tmp_sharded._write_shard(shard_num, df)

    num_input_structures = input_sharded.get_num_keyed()
    num_output_structures = tmp_sharded.get_num_keyed()
    logging.info(f'After filtering, have {num_output_structures:} / '
                 f'{num_input_structures:} left.')

    reshard(tmp_sharded, output_sharded, shuffle_buffer)
    tmp_sharded.delete_files()
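
# Illustrative sketch (not the library code): filter_fn above is expected to
# map a shard DataFrame to the filtered DataFrame. A minimal stand-in for
# filters.form_filter_against_list, keeping only rows whose `column` value is
# in `allowed`, could look like this.
def _example_filter_against_list(allowed, column):
    allowed = set(allowed)
    def filter_fn(df):
        # Keep rows whose value in `column` (e.g. 'ensemble') is allowed.
        return df[df[column].isin(allowed)]
    return filter_fn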
def split(input_sharded, output_root, scaffold_data, shuffle_buffer):
    """Split by scaffold."""
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    logger.info('Process scaffold and ensemble data')
    scaffold_list = scaffold_data['Scaffold'].tolist()
    ensemble_list = []
    for i in range(len(scaffold_data)):
        ensemble = ''
        ensemble += scaffold_data.iloc[i]['ligand']
        ensemble += '__' + scaffold_data.iloc[i]['active_struc'].split('_')[-1]
        ensemble += '__' + scaffold_data.iloc[i]['inactive_struc'].split('_')[-1]
        ensemble_list.append(ensemble)
    ensemble_list = np.array(ensemble_list)

    logger.info('Splitting by scaffold')
    train_idx, val_idx, test_idx = splits.scaffold_split(scaffold_list)
    train = ensemble_list[train_idx].tolist()
    val = ensemble_list[val_idx].tolist()
    test = ensemble_list[test_idx].tolist()

    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(
        input_sharded, train_sharded, train_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, val_sharded, val_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, test_sharded, test_filter_fn, shuffle_buffer)

    # Write splits to text files.
    np.savetxt(output_root.split('@')[0] + '_train.txt', train, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_val.txt', val, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_test.txt', test, fmt='%s')
def split(input_sharded, output_root, shuffle_buffer, cutoff=30):
    """Split by sequence identity."""
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    all_chain_sequences = []
    logger.info('Loading chain sequences')
    for _, shard in input_sharded.iter_shards():
        all_chain_sequences.extend(seq.get_all_chain_sequences_df(shard))

    logger.info('Splitting by cluster')
    train, val, test = atom3d.splits.sequence.cluster_split(
        all_chain_sequences, cutoff)

    # Will just look up ensembles.
    train = [x[0] for x in train]
    val = [x[0] for x in val]
    test = [x[0] for x in test]

    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(
        input_sharded, train_sharded, train_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, val_sharded, val_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, test_sharded, test_filter_fn, shuffle_buffer)

    # Write splits to text files.
    np.savetxt(output_root.split('@')[0] + '_train.txt', train, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_val.txt', val, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_test.txt', test, fmt='%s')
def shard_pairs(input_path, output_path, cutoff, cutoff_type, num_threads):
    input_sharded = sh.Sharded.load(input_path)
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')
    output_sharded = sh.Sharded(output_path, keys)
    input_num_shards = input_sharded.get_num_shards()

    tmp_path = output_sharded.get_prefix() + f'_tmp@{input_num_shards:}'
    tmp_sharded = sh.Sharded(tmp_path, keys)

    logger.info(f'Using {num_threads:} threads')

    inputs = [(input_sharded, tmp_sharded, shard_num, cutoff, cutoff_type)
              for shard_num in range(input_num_shards)]
    par.submit_jobs(_shard_pairs, inputs, num_threads)

    sho.reshard(tmp_sharded, output_sharded)
    tmp_sharded.delete_files()
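
# Usage sketch (hypothetical paths and parameter values; the per-shard worker
# _shard_pairs is referenced above but not shown here): shard_pairs reads an
# ensemble-keyed dataset, extracts interacting pairs per shard in parallel,
# then reshards the temporary output into the final dataset.
# shard_pairs('structures/pdb@32', 'pairs/pdb_pairs@32',
#             cutoff=8.0, cutoff_type='CA', num_threads=8)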
def prepare(input_sharded_path, output_root, score_dir):
    input_sharded = sh.Sharded.load(input_sharded_path)
    if score_dir is not None:
        prefix = sh.get_prefix(output_root)
        num_shards = sh.get_num_shards(output_root)
        keys = input_sharded.get_keys()
        filter_sharded = sh.Sharded(f'{prefix:}_filtered@{num_shards:}', keys)
        filter_fn = sc.form_score_filter(score_dir)
        logger.info('Filtering against score file.')
        sho.filter_sharded(input_sharded, filter_sharded, filter_fn)
        split(filter_sharded, output_root)
        filter_sharded.delete_files()
    else:
        split(input_sharded, output_root)
def split(input_sharded, output_root, scaffold_data, shuffle_buffer):
    """Split by scaffold."""
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    logger.info('Splitting by scaffold')
    scaffold_list = scaffold_data['Scaffold'].tolist()
    train_idx, val_idx, test_idx = splits.scaffold_split(scaffold_list)
    train = scaffold_data['pdb'][train_idx].tolist()
    val = scaffold_data['pdb'][val_idx].tolist()
    test = scaffold_data['pdb'][test_idx].tolist()

    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(
        input_sharded, train_sharded, train_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, val_sharded, val_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, test_sharded, test_filter_fn, shuffle_buffer)

    # Write splits to text files.
    np.savetxt(output_root.split('@')[0] + '_train.txt', train, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_val.txt', val, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_test.txt', test, fmt='%s')
def rekey(input_sharded, output_sharded, shuffle_buffer=0):
    """Rekey dataset."""
    dirname = os.path.dirname(output_sharded.path)
    if not os.path.exists(dirname) and dirname != '':
        os.makedirs(dirname, exist_ok=True)
    input_num_shards = input_sharded.get_num_shards()

    # We will just map to tmp, then reshard.
    tmp_path = output_sharded.get_prefix() + f'_tmp@{input_num_shards:}'
    tmp_sharded = sh.Sharded(tmp_path, output_sharded.get_keys())

    for shard_num in tqdm.trange(input_num_shards):
        df = input_sharded.read_shard(shard_num)
        tmp_sharded._write_shard(shard_num, df)

    num_input_structures = input_sharded.get_num_keyed()
    num_output_structures = tmp_sharded.get_num_keyed()
    logging.info(f'After rekeying, have {num_output_structures:} keyed, '
                 f'from {num_input_structures:} originally.')

    reshard(tmp_sharded, output_sharded, shuffle_buffer)
    tmp_sharded.delete_files()
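
# Usage sketch (hypothetical paths): rekey() rewrites the same rows under a
# new key hierarchy, e.g. regrouping an ensemble-keyed dataset by
# (ensemble, subunit) so that a later reshard can shuffle across targets, as
# the temporal split above does.
# input_sharded = sh.Sharded.load('data/scores@16')
# output_sharded = sh.Sharded('data/scores_rekeyed@16', ['ensemble', 'subunit'])
# rekey(input_sharded, output_sharded, shuffle_buffer=16)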
def split_dataset(dirname_all, dirname_split, input_ds_name, input_splits,
                  input_csvfile, input_exclude):
    # Read the sharded dataset
    input_sharded = sh.Sharded.load(input_ds_name)
    input_shard = input_sharded.read_shard(0)
    input_label = input_sharded.read_shard(0, 'labels')

    # Create output directories
    if not os.path.exists(dirname_all) and dirname_all != '':
        os.makedirs(dirname_all, exist_ok=True)
    if not os.path.exists(dirname_split) and dirname_split != '':
        os.makedirs(dirname_split, exist_ok=True)

    # Correct for ensemble = None
    input_shard['ensemble'] = input_shard['model']

    # Save the full (corrected) dataset
    sharded_all = sh.Sharded(dirname_all + '/qm9_all@1',
                             input_sharded.get_keys())
    sharded_all._write_shard(0, input_shard)
    sharded_all.add_to_shard(0, input_label, 'labels')

    # Read raw and split data
    label_data = pd.read_csv(input_csvfile)
    indices_ex = np.loadtxt(input_exclude, dtype=int)
    indices_tr = np.loadtxt(input_splits + '/indices_train.dat', dtype=int)
    indices_va = np.loadtxt(input_splits + '/indices_valid.dat', dtype=int)
    indices_te = np.loadtxt(input_splits + '/indices_test.dat', dtype=int)

    # Create lists of molecule IDs for exclusion and splits
    mol_ids = label_data['mol_id'].tolist()
    mol_ids_ex = label_data.loc[indices_ex]['mol_id'].tolist()
    mol_ids_te = label_data.loc[indices_te]['mol_id'].tolist()
    mol_ids_va = label_data.loc[indices_va]['mol_id'].tolist()
    mol_ids_tr = label_data.loc[indices_tr]['mol_id'].tolist()

    # Write lists of mol_ids to files
    with open(dirname_split + '/mol_ids_excluded.txt', 'w') as f:
        for mol_id in mol_ids_ex:
            f.write("%s\n" % mol_id)
    with open(dirname_split + '/mol_ids_training.txt', 'w') as f:
        for mol_id in mol_ids_tr:
            f.write("%s\n" % mol_id)
    with open(dirname_split + '/mol_ids_validation.txt', 'w') as f:
        for mol_id in mol_ids_va:
            f.write("%s\n" % mol_id)
    with open(dirname_split + '/mol_ids_test.txt', 'w') as f:
        for mol_id in mol_ids_te:
            f.write("%s\n" % mol_id)

    # Split the labels
    labels_te = input_label.loc[label_data['mol_id'].isin(
        mol_ids_te)].reset_index(drop=True)
    labels_va = input_label.loc[label_data['mol_id'].isin(
        mol_ids_va)].reset_index(drop=True)
    labels_tr = input_label.loc[label_data['mol_id'].isin(
        mol_ids_tr)].reset_index(drop=True)

    # Filter and write out training set
    filter_tr = filters.form_filter_against_list(mol_ids_tr, 'subunit')
    sharded_tr = sh.Sharded(dirname_split + '/train@1', sharded_all.get_keys())
    sho.filter_sharded(sharded_all, sharded_tr, filter_tr)
    sharded_tr.add_to_shard(0, labels_tr, 'labels')

    # Filter and write out validation set
    filter_va = filters.form_filter_against_list(mol_ids_va, 'structure')
    sharded_va = sh.Sharded(dirname_split + '/val@1', sharded_all.get_keys())
    sho.filter_sharded(sharded_all, sharded_va, filter_va)
    sharded_va.add_to_shard(0, labels_va, 'labels')

    # Filter and write out test set
    filter_te = filters.form_filter_against_list(mol_ids_te, 'structure')
    sharded_te = sh.Sharded(dirname_split + '/test@1', input_sharded.get_keys())
    sho.filter_sharded(sharded_all, sharded_te, filter_te)
    sharded_te.add_to_shard(0, labels_te, 'labels')
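
# Usage sketch (hypothetical path): reading a resulting split back mirrors the
# calls at the top of split_dataset -- one structure shard plus a 'labels'
# table attached to the same shard.
# train_sharded = sh.Sharded.load('qm9/splits/train@1')
# train_structures = train_sharded.read_shard(0)
# train_labels = train_sharded.read_shard(0, 'labels')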