Example 1
def shard_envs(input_path, output_path, num_threads=8, subsample=True):
    input_sharded = sh.Sharded.load(input_path)
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')
    output_sharded = sh.Sharded(output_path, keys)
    input_num_shards = input_sharded.get_num_shards()

    tmp_path = output_sharded.get_prefix() + f'_tmp@{input_num_shards:}'
    tmp_sharded = sh.Sharded(tmp_path, keys)

    not_written = []
    for i in range(input_num_shards):
        shard = output_sharded._get_shard(i)
        if not os.path.exists(shard):
            not_written.append(i)

    print(f'Using {num_threads:} threads')

    # Only process shards that have not yet been written.
    inputs = [(input_sharded, tmp_sharded, shard_num, subsample)
              for shard_num in not_written]

    par.submit_jobs(_shard_envs, inputs, num_threads)

    sho.reshard(tmp_sharded, output_sharded)
    tmp_sharded.delete_files()
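
The worker `_shard_envs` is not shown in this example. A minimal sketch of a compatible worker, given the `(input_sharded, tmp_sharded, shard_num, subsample)` tuples built above; the body is an illustrative placeholder, not the real implementation:

def _shard_envs(input_sharded, output_sharded, shard_num, subsample):
    # Hypothetical per-shard worker: read one input shard, transform it
    # (placeholder logic), and write the result to the same-numbered shard
    # of the temporary dataset.
    df = input_sharded.read_shard(shard_num)
    if subsample:
        # Subsample whole ensembles so groups stay intact (rate assumed).
        keep = df['ensemble'].drop_duplicates().sample(frac=0.1)
        df = df[df['ensemble'].isin(keep)]
    output_sharded._write_shard(shard_num, df)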
Example 2
def split(input_sharded, output_root):
    """Split temporally and shuffle examples."""
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    # Re-key to (ensemble, subunit) to allow shuffling across targets.
    tmp_sharded = sh.Sharded(
        f'{prefix:}_rekeyed@{num_shards:}', ['ensemble', 'subunit'])
    logger.info("Rekeying")
    sho.rekey(input_sharded, tmp_sharded)

    keys = tmp_sharded.get_keys()
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info("Splitting")
    train_filter_fn = filters.form_filter_against_list(TRAIN, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(VAL, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(TEST, 'ensemble')

    sho.filter_sharded(
        tmp_sharded, train_sharded, train_filter_fn, num_shards)
    sho.filter_sharded(
        tmp_sharded, val_sharded, val_filter_fn, num_shards)
    sho.filter_sharded(
        tmp_sharded, test_sharded, test_filter_fn, num_shards)

    tmp_sharded.delete_files()
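
`form_filter_against_list` appears in nearly every example here. A plausible minimal implementation, assuming each filter takes and returns a pandas DataFrame containing the key column (a sketch, not the library's actual code):

def form_filter_against_list(allowed, key):
    # Build a filter keeping only rows whose `key` column value is in `allowed`.
    allowed = set(allowed)

    def filter_fn(df):
        return df[df[key].isin(allowed)]

    return filter_fn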
Example 3
def split(input_sharded, output_root, info_csv, shuffle_buffer):
    """Split by protein."""
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    info = pd.read_csv(info_csv)
    info['ensemble'] = info.apply(
        lambda x: x['ligand'] + '__' + x['active_struc'].split('_')[2]
                  + '__' + x['inactive_struc'].split('_')[2],
        axis=1)
    info = info.set_index('ensemble')
    # Remove duplicate ensembles.
    info = info[~info.index.duplicated()]

    ensembles = input_sharded.get_names()['ensemble']
    in_use = info.loc[ensembles]
    active = in_use[in_use['label'] == 'A']
    inactive = in_use[in_use['label'] == 'I']

    # Split by protein.
    proteins = info['protein'].unique()
    i_train, i_val, i_test = splits.random_split(len(proteins), 0.6, 0.2, 0.2)
    p_train = proteins[i_train]
    p_val = proteins[i_val]
    p_test = proteins[i_test]
    logger.info(f'Train proteins: {p_train:}')
    logger.info(f'Val proteins: {p_val:}')
    logger.info(f'Test proteins: {p_test:}')

    train = info[info['protein'].isin(p_train)].index.tolist()
    val = info[info['protein'].isin(p_val)].index.tolist()
    test = info[info['protein'].isin(p_test)].index.tolist()

    logger.info(f'{len(train):} train examples, {len(val):} val examples, '
                f'{len(test):} test examples.')

    keys = input_sharded.get_keys()
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(input_sharded, train_sharded, train_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, val_sharded, val_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, test_sharded, test_filter_fn,
                       shuffle_buffer)
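
For illustration, the ensemble key built from `info_csv` concatenates the ligand name with the third underscore-separated token of the two structure fields; with hypothetical CSV values:

row = {'ligand': 'LIG1',
       'active_struc': 'pose_active_6XYZ',
       'inactive_struc': 'pose_inactive_7ABC'}
ensemble = (row['ligand'] + '__' + row['active_struc'].split('_')[2]
            + '__' + row['inactive_struc'].split('_')[2])
# -> 'LIG1__6XYZ__7ABC'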
Example 4
def filter_pairs(input_sharded_path, output_root, bsa, against_path,
                 shuffle_buffer):
    input_sharded = sh.Sharded.load(input_sharded_path)
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')
    output_sharded = sh.Sharded(output_root, keys)
    # We form the combined filter by starting with the identity filter and
    # composing with further filters.
    filter_fn = filters.identity_filter

    filter_fn = filters.compose(
        atom3d.filters.pdb.form_molecule_type_filter(allowed=['prot']),
        filter_fn)
    filter_fn = filters.compose(filters.form_size_filter(min_size=50),
                                filter_fn)
    filter_fn = filters.compose(atom3d.filters.pdb.form_resolution_filter(3.5),
                                filter_fn)
    filter_fn = filters.compose(
        atom3d.filters.pdb.form_source_filter(allowed=['diffraction', 'EM']),
        filter_fn)
    if bsa is not None:
        filter_fn = filters.compose(form_bsa_filter(bsa, 500), filter_fn)
    if against_path is not None:
        against = sh.Sharded.load(against_path)
        filter_fn = filters.compose(
            atom3d.filters.sequence.form_seq_filter_against(against, 0.3),
            filter_fn)
        filter_fn = filters.compose(
            form_scop_pair_filter_against(against, 'superfamily'), filter_fn)

    sho.filter_sharded(input_sharded, output_sharded, filter_fn)
    split(output_sharded, output_root, shuffle_buffer)
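
The combined filter above relies on every filter taking and returning a DataFrame, so filters chain like ordinary functions. A minimal sketch of `identity_filter` and `compose` under that assumption (illustrative, not the library's code):

def identity_filter(df):
    # Pass-through: the neutral element for composition.
    return df


def compose(f, g):
    # Apply g first, then f, mirroring function composition f(g(df)),
    # so earlier-composed filters run before later ones.
    def composed(df):
        return f(g(df))
    return composed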
Example 5
def filter_sharded(input_sharded, output_sharded, filter_fn, shuffle_buffer=0):
    """Filter sharded dataset to new sharded dataset, using provided filter."""
    logging.basicConfig(format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    dirname = os.path.dirname(output_sharded.path)
    if not os.path.exists(dirname) and dirname != '':
        os.makedirs(dirname, exist_ok=True)

    input_num_shards = input_sharded.get_num_shards()

    # We will just map to tmp, then reshard.
    tmp_path = output_sharded.get_prefix() + f'_tmp@{input_num_shards:}'
    tmp_sharded = sh.Sharded(tmp_path, input_sharded.get_keys())

    logging.info(f'Filtering {input_sharded.path:} to {output_sharded.path:}')
    # Apply filter.
    for shard_num in tqdm.trange(input_num_shards):
        df = input_sharded.read_shard(shard_num)
        if len(df) > 0:
            df = filter_fn(df)
        tmp_sharded._write_shard(shard_num, df)

    num_input_structures = input_sharded.get_num_keyed()
    num_output_structures = tmp_sharded.get_num_keyed()
    logging.info(f'After filtering, have {num_output_structures:} / '
                 f'{num_input_structures:} left.')
    reshard(tmp_sharded, output_sharded, shuffle_buffer)
    tmp_sharded.delete_files()
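
A typical call, assuming a loaded dataset and the size filter from Example 4; paths and the buffer size are illustrative:

input_sharded = sh.Sharded.load('data/raw@16')
output_sharded = sh.Sharded('data/filtered@16', input_sharded.get_keys())
filter_sharded(input_sharded, output_sharded,
               filters.form_size_filter(min_size=50), shuffle_buffer=10)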
Example 6
def split(input_sharded, output_root, scaffold_data, shuffle_buffer):
    """Split by scaffold."""
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    logger.info('Process scaffold and ensemble data')
    scaffold_list = scaffold_data['Scaffold'].tolist()
    ensemble_list = []
    for i in range(len(scaffold_data)):
        ensemble = ''
        ensemble += scaffold_data.iloc[i]['ligand']
        ensemble += '__'+scaffold_data.iloc[i]['active_struc'].split('_')[-1]
        ensemble += '__'+scaffold_data.iloc[i]['inactive_struc'].split('_')[-1]
        ensemble_list.append(ensemble)
    ensemble_list = np.array(ensemble_list)
 
    logger.info('Splitting by scaffold')
    train_idx, val_idx, test_idx = splits.scaffold_split(scaffold_list)
    train = ensemble_list[train_idx].tolist()
    val = ensemble_list[val_idx].tolist()
    test = ensemble_list[test_idx].tolist()

    keys = input_sharded.get_keys()
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(
        input_sharded, train_sharded, train_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, val_sharded, val_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, test_sharded, test_filter_fn, shuffle_buffer)

    # write splits to text files
    np.savetxt(output_root.split('@')[0]+'_train.txt', train, fmt='%s')
    np.savetxt(output_root.split('@')[0]+'_val.txt', val, fmt='%s')
    np.savetxt(output_root.split('@')[0]+'_test.txt', test, fmt='%s')
Example 7
def split(input_sharded, output_root, shuffle_buffer, cutoff=30):
    """Split by sequence identity."""
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    all_chain_sequences = []
    logger.info('Loading chain sequences')
    for _, shard in input_sharded.iter_shards():
        all_chain_sequences.extend(seq.get_all_chain_sequences_df(shard))

    logger.info('Splitting by cluster')
    train, val, test = atom3d.splits.sequence.cluster_split(
        all_chain_sequences, cutoff)

    # Will just look up ensembles.
    train = [x[0] for x in train]
    val = [x[0] for x in val]
    test = [x[0] for x in test]

    keys = input_sharded.get_keys()
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(input_sharded, train_sharded, train_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, val_sharded, val_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, test_sharded, test_filter_fn,
                       shuffle_buffer)

    # write splits to text files
    np.savetxt(output_root.split('@')[0] + '_train.txt', train, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_val.txt', val, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_test.txt', test, fmt='%s')
Example 8
def shard_pairs(input_path, output_path, cutoff, cutoff_type, num_threads):
    input_sharded = sh.Sharded.load(input_path)
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')
    output_sharded = sh.Sharded(output_path, keys)
    input_num_shards = input_sharded.get_num_shards()

    tmp_path = output_sharded.get_prefix() + f'_tmp@{input_num_shards:}'
    tmp_sharded = sh.Sharded(tmp_path, keys)

    logger.info(f'Using {num_threads:} threads')

    inputs = [(input_sharded, tmp_sharded, shard_num, cutoff, cutoff_type)
              for shard_num in range(input_num_shards)]

    par.submit_jobs(_shard_pairs, inputs, num_threads)

    sho.reshard(tmp_sharded, output_sharded)
    tmp_sharded.delete_files()
Example 9
def prepare(input_sharded_path, output_root, score_dir):
    input_sharded = sh.Sharded.load(input_sharded_path)
    if score_dir is not None:
        prefix = sh.get_prefix(output_root)
        num_shards = sh.get_num_shards(output_root)
        keys = input_sharded.get_keys()
        filtered_sharded = sh.Sharded(f'{prefix:}_filtered@{num_shards:}', keys)
        filter_fn = sc.form_score_filter(score_dir)
        logger.info('Filtering against score file.')
        sho.filter_sharded(input_sharded, filtered_sharded, filter_fn)
        split(filtered_sharded, output_root)
        filtered_sharded.delete_files()
    else:
        split(input_sharded, output_root)
Example 10
def split(input_sharded, output_root, scaffold_data, shuffle_buffer):
    """Split by sequence identity."""
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    logger.info('Splitting by scaffold')
    scaffold_list = scaffold_data['Scaffold'].tolist()
    train_idx, val_idx, test_idx = splits.scaffold_split(scaffold_list)
    train = scaffold_data['pdb'][train_idx].tolist()
    val = scaffold_data['pdb'][val_idx].tolist()
    test = scaffold_data['pdb'][test_idx].tolist()

    keys = input_sharded.get_keys()
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(input_sharded, train_sharded, train_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, val_sharded, val_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, test_sharded, test_filter_fn,
                       shuffle_buffer)

    # write splits to text files
    np.savetxt(output_root.split('@')[0] + '_train.txt', train, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_val.txt', val, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_test.txt', test, fmt='%s')
Example 11
def rekey(input_sharded, output_sharded, shuffle_buffer=0):
    """Rekey dataset."""
    dirname = os.path.dirname(output_sharded.path)
    if not os.path.exists(dirname) and dirname != '':
        os.makedirs(dirname, exist_ok=True)

    input_num_shards = input_sharded.get_num_shards()

    # We will just map to tmp, then reshard.
    tmp_path = output_sharded.get_prefix() + f'_tmp@{input_num_shards:}'
    tmp_sharded = sh.Sharded(tmp_path, output_sharded.get_keys())

    for shard_num in tqdm.trange(input_num_shards):
        df = input_sharded.read_shard(shard_num)
        tmp_sharded._write_shard(shard_num, df)

    num_input_structures = input_sharded.get_num_keyed()
    num_output_structures = tmp_sharded.get_num_keyed()
    logging.info(f'After rekeying, have {num_output_structures:} keyed, '
                 f'from {num_input_structures:} originally.')
    reshard(tmp_sharded, output_sharded, shuffle_buffer)
    tmp_sharded.delete_files()
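
`reshard` is called in several of these examples but never shown. A simplified sketch of what it plausibly does (concatenate all shards, optionally shuffle, regroup by the output keys, and rewrite); this assumes the data fits in memory, whereas a real implementation would stream with a bounded shuffle buffer:

import random

import pandas as pd


def reshard(input_sharded, output_sharded, shuffle_buffer=0):
    # Gather every input shard into one frame.
    dfs = [input_sharded.read_shard(i)
           for i in range(input_sharded.get_num_shards())]
    df = pd.concat(dfs, ignore_index=True)

    # Group by the output keys so keyed entries stay together, then deal
    # the groups round-robin across the output shards.
    groups = [g for _, g in df.groupby(output_sharded.get_keys())]
    if shuffle_buffer:
        random.shuffle(groups)  # crude stand-in for buffered shuffling

    num_shards = output_sharded.get_num_shards()
    for shard_num in range(num_shards):
        chunk = groups[shard_num::num_shards]
        shard_df = pd.concat(chunk, ignore_index=True) if chunk else df.iloc[0:0]
        output_sharded._write_shard(shard_num, shard_df)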
Example 12
def split_dataset(dirname_all, dirname_split, input_ds_name, input_splits,
                  input_csvfile, input_exclude):

    # Read the sharded dataset
    input_sharded = sh.Sharded.load(input_ds_name)
    input_shard = input_sharded.read_shard(0)
    input_label = input_sharded.read_shard(0, 'labels')

    # Create output directories
    if not os.path.exists(dirname_all) and dirname_all != '':
        os.makedirs(dirname_all, exist_ok=True)
    if not os.path.exists(dirname_split) and dirname_split != '':
        os.makedirs(dirname_split, exist_ok=True)

    # Correct for ensemble = None
    input_shard['ensemble'] = input_shard['model']

    # Save the full (corrected) dataset
    sharded_all = sh.Sharded(dirname_all + '/qm9_all@1',
                             input_sharded.get_keys())
    sharded_all._write_shard(0, input_shard)
    sharded_all.add_to_shard(0, input_label, 'labels')

    # Read raw and split data
    label_data = pd.read_csv(input_csvfile)
    indices_ex = np.loadtxt(input_exclude, dtype=int)
    indices_tr = np.loadtxt(input_splits + '/indices_train.dat', dtype=int)
    indices_va = np.loadtxt(input_splits + '/indices_valid.dat', dtype=int)
    indices_te = np.loadtxt(input_splits + '/indices_test.dat', dtype=int)

    # Create lists of molecule IDs for exclusion and splits
    mol_ids = label_data['mol_id'].tolist()
    mol_ids_ex = label_data.loc[indices_ex]['mol_id'].tolist()
    mol_ids_te = label_data.loc[indices_te]['mol_id'].tolist()
    mol_ids_va = label_data.loc[indices_va]['mol_id'].tolist()
    mol_ids_tr = label_data.loc[indices_tr]['mol_id'].tolist()

    # Write lists of mol_ids to files
    with open(dirname_split + '/mol_ids_excluded.txt', 'w') as f:
        for mol_id in mol_ids_ex:
            f.write("%s\n" % mol_id)
    with open(dirname_split + '/mol_ids_training.txt', 'w') as f:
        for mol_id in mol_ids_tr:
            f.write("%s\n" % mol_id)
    with open(dirname_split + '/mol_ids_validation.txt', 'w') as f:
        for mol_id in mol_ids_va:
            f.write("%s\n" % mol_id)
    with open(dirname_split + '/mol_ids_test.txt', 'w') as f:
        for mol_id in mol_ids_te:
            f.write("%s\n" % mol_id)

    # Split the labels
    labels_te = input_label.loc[label_data['mol_id'].isin(
        mol_ids_te)].reset_index(drop=True)
    labels_va = input_label.loc[label_data['mol_id'].isin(
        mol_ids_va)].reset_index(drop=True)
    labels_tr = input_label.loc[label_data['mol_id'].isin(
        mol_ids_tr)].reset_index(drop=True)

    # Filter and write out training set
    filter_tr = filters.form_filter_against_list(mol_ids_tr, 'structure')
    sharded_tr = sh.Sharded(dirname_split + '/train@1', sharded_all.get_keys())
    sho.filter_sharded(sharded_all, sharded_tr, filter_tr)
    sharded_tr.add_to_shard(0, labels_tr, 'labels')

    # Filter and write out validation set
    filter_va = filters.form_filter_against_list(mol_ids_va, 'structure')
    sharded_va = sh.Sharded(dirname_split + '/val@1', sharded_all.get_keys())
    sho.filter_sharded(sharded_all, sharded_va, filter_va)
    sharded_va.add_to_shard(0, labels_va, 'labels')

    # Filter and write out test set
    filter_te = filters.form_filter_against_list(mol_ids_te, 'structure')
    sharded_te = sh.Sharded(dirname_split + '/test@1', sharded_all.get_keys())
    sho.filter_sharded(sharded_all, sharded_te, filter_te)
    sharded_te.add_to_shard(0, labels_te, 'labels')
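
To consume one of the resulting splits, the read pattern mirrors the calls at the top of this function (the paths are the ones created above):

train_sharded = sh.Sharded.load(dirname_split + '/train@1')
train_df = train_sharded.read_shard(0)
train_labels = train_sharded.read_shard(0, 'labels')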