Example #1
0
def split(input_sharded, output_root):
    """Split temporally and shuffle examples."""
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    # Shuffling across targets requires re-keying to (ensemble, subunit) first.
    tmp_sharded = sh.Sharded(
        f'{prefix:}_rekeyed@{num_shards:}', ['ensemble', 'subunit'])
    logger.info("Rekeying")
    sho.rekey(input_sharded, tmp_sharded)

    keys = tmp_sharded.get_keys()
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info("Splitting")
    # Filter the re-keyed dataset into each destination in turn.
    destinations = (
        (train_sharded, TRAIN),
        (val_sharded, VAL),
        (test_sharded, TEST),
    )
    for dest_sharded, ensemble_list in destinations:
        filter_fn = filters.form_filter_against_list(ensemble_list, 'ensemble')
        sho.filter_sharded(tmp_sharded, dest_sharded, filter_fn, num_shards)

    tmp_sharded.delete_files()
Example #2
0
def split(input_sharded, output_root, info_csv, shuffle_buffer):
    """Split by protein.

    Assigns whole proteins to train/val/test (0.6/0.2/0.2) so that no
    protein's ensembles leak across splits, then writes one sharded
    dataset per split under ``output_root``.

    Args:
        input_sharded: source sharded dataset, keyed by 'ensemble'.
        output_root: sharded path spec (prefix@num_shards) for the outputs.
        info_csv: CSV with 'ligand', 'active_struc', 'inactive_struc',
            'protein' columns describing each ensemble.
        shuffle_buffer: shuffle buffer size forwarded to sho.filter_sharded.

    Raises:
        RuntimeError: if the input is not sharded by ensemble.
        KeyError: if an input ensemble is missing from the CSV.
    """
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    info = pd.read_csv(info_csv)
    # Ensemble name is ligand__<active pdb>__<inactive pdb>, matching the
    # names used by the sharded dataset.
    info['ensemble'] = info.apply(
        lambda x: x['ligand'] + '__' + x['active_struc'].split('_')[
            2] + '__' + x['inactive_struc'].split('_')[2],
        axis=1)
    info = info.set_index('ensemble')
    # Remove duplicate ensembles.
    info = info[~info.index.duplicated()]

    # Consistency check: every ensemble in the input must appear in the CSV
    # (DataFrame.loc raises KeyError on missing labels).  The previous
    # version also derived unused 'active'/'inactive' subsets here.
    ensembles = input_sharded.get_names()['ensemble']
    info.loc[ensembles]

    # Split by protein.
    proteins = info['protein'].unique()
    # NOTE(review): unpack order is (test, val, train) — confirm it matches
    # random_split's return order for the 0.6/0.2/0.2 fractions.
    i_test, i_val, i_train = splits.random_split(len(proteins), 0.6, 0.2, 0.2)
    p_train = proteins[i_train]
    p_val = proteins[i_val]
    p_test = proteins[i_test]
    logger.info(f'Train proteins: {p_train:}')
    logger.info(f'Val proteins: {p_val:}')
    logger.info(f'Test proteins: {p_test:}')

    train = info[info['protein'].isin(p_train)].index.tolist()
    val = info[info['protein'].isin(p_val)].index.tolist()
    test = info[info['protein'].isin(p_test)].index.tolist()

    logger.info(f'{len(train):} train examples, {len(val):} val examples, '
                f'{len(test):} test examples.')

    keys = input_sharded.get_keys()
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(input_sharded, train_sharded, train_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, val_sharded, val_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, test_sharded, test_filter_fn,
                       shuffle_buffer)
Example #3
0
def prepare(input_sharded_path, output_root, score_dir):
    """Load the input sharded dataset, optionally filter by score, and split."""
    input_sharded = sh.Sharded.load(input_sharded_path)

    # No score directory: split the input directly.
    if score_dir is None:
        split(input_sharded, output_root)
        return

    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    keys = input_sharded.get_keys()
    filter_sharded = sh.Sharded(f'{prefix:}_filtered@{num_shards:}', keys)
    filter_fn = sc.form_score_filter(score_dir)
    logger.info('Filtering against score file.')
    sho.filter_sharded(input_sharded, filter_sharded, filter_fn)
    split(filter_sharded, output_root)
    filter_sharded.delete_files()
Example #4
0
def split(input_sharded, output_root, scaffold_data, shuffle_buffer):
    """Split by scaffold.

    Groups ensembles by their ligand scaffold, splits scaffold groups into
    train/val/test, writes one sharded dataset per split, and records each
    split's ensemble names in a text file next to ``output_root``.

    Args:
        input_sharded: source sharded dataset, keyed by 'ensemble'.
        output_root: sharded path spec (prefix@num_shards) for the outputs.
        scaffold_data: DataFrame with 'Scaffold', 'ligand', 'active_struc',
            'inactive_struc' columns, one row per ensemble.
        shuffle_buffer: shuffle buffer size forwarded to sho.filter_sharded.

    Raises:
        RuntimeError: if the input is not sharded by ensemble.
    """
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    logger.info('Process scaffold and ensemble data')
    scaffold_list = scaffold_data['Scaffold'].tolist()
    # Ensemble name is ligand__<active pdb>__<inactive pdb>, matching the
    # names used by the sharded dataset.
    ensemble_list = np.array([
        row['ligand']
        + '__' + row['active_struc'].split('_')[-1]
        + '__' + row['inactive_struc'].split('_')[-1]
        for _, row in scaffold_data.iterrows()
    ])

    logger.info('Splitting by scaffold')
    train_idx, val_idx, test_idx = splits.scaffold_split(scaffold_list)
    train = ensemble_list[train_idx].tolist()
    val = ensemble_list[val_idx].tolist()
    test = ensemble_list[test_idx].tolist()

    # Keys were validated above; no need to re-check here.
    keys = input_sharded.get_keys()
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(
        input_sharded, train_sharded, train_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, val_sharded, val_filter_fn, shuffle_buffer)
    sho.filter_sharded(
        input_sharded, test_sharded, test_filter_fn, shuffle_buffer)

    # write splits to text files
    np.savetxt(output_root.split('@')[0]+'_train.txt', train, fmt='%s')
    np.savetxt(output_root.split('@')[0]+'_val.txt', val, fmt='%s')
    np.savetxt(output_root.split('@')[0]+'_test.txt', test, fmt='%s')
Example #5
0
def split(input_sharded, output_root, shuffle_buffer, cutoff=30):
    """Split by sequence identity.

    Clusters chains by sequence identity, splits clusters into
    train/val/test, writes one sharded dataset per split, and records each
    split's ensemble names in a text file next to ``output_root``.

    Args:
        input_sharded: source sharded dataset, keyed by 'ensemble'.
        output_root: sharded path spec (prefix@num_shards) for the outputs.
        shuffle_buffer: shuffle buffer size forwarded to sho.filter_sharded.
        cutoff: sequence-identity threshold (percent) for clustering.

    Raises:
        RuntimeError: if the input is not sharded by ensemble.
    """
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    all_chain_sequences = []
    logger.info('Loading chain sequences')
    for _, shard in input_sharded.iter_shards():
        all_chain_sequences.extend(seq.get_all_chain_sequences_df(shard))

    logger.info('Splitting by cluster')
    train, val, test = atom3d.splits.sequence.cluster_split(
        all_chain_sequences, cutoff)

    # Will just look up ensembles.
    train = [x[0] for x in train]
    val = [x[0] for x in val]
    test = [x[0] for x in test]

    # Keys were validated above; no need to re-check here.
    keys = input_sharded.get_keys()
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(input_sharded, train_sharded, train_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, val_sharded, val_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, test_sharded, test_filter_fn,
                       shuffle_buffer)

    # write splits to text files
    np.savetxt(output_root.split('@')[0] + '_train.txt', train, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_val.txt', val, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_test.txt', test, fmt='%s')
Example #6
0
def split(input_sharded, output_root, scaffold_data, shuffle_buffer):
    """Split by scaffold.

    (Docstring fixed: the previous one said "sequence identity", but this
    function splits on ligand scaffolds.)  Splits pdb codes by scaffold
    group, writes one sharded dataset per split, and records each split's
    pdb codes in a text file next to ``output_root``.

    Args:
        input_sharded: source sharded dataset, keyed by 'ensemble'.
        output_root: sharded path spec (prefix@num_shards) for the outputs.
        scaffold_data: DataFrame with 'Scaffold' and 'pdb' columns.
        shuffle_buffer: shuffle buffer size forwarded to sho.filter_sharded.

    Raises:
        RuntimeError: if the input is not sharded by ensemble.
    """
    if input_sharded.get_keys() != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    logger.info('Splitting by scaffold')
    scaffold_list = scaffold_data['Scaffold'].tolist()
    train_idx, val_idx, test_idx = splits.scaffold_split(scaffold_list)
    train = scaffold_data['pdb'][train_idx].tolist()
    val = scaffold_data['pdb'][val_idx].tolist()
    test = scaffold_data['pdb'][test_idx].tolist()

    # Keys were validated above; no need to re-check here.
    keys = input_sharded.get_keys()
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    train_sharded = sh.Sharded(f'{prefix:}_train@{num_shards:}', keys)
    val_sharded = sh.Sharded(f'{prefix:}_val@{num_shards:}', keys)
    test_sharded = sh.Sharded(f'{prefix:}_test@{num_shards:}', keys)

    logger.info('Writing sets')
    train_filter_fn = filters.form_filter_against_list(train, 'ensemble')
    val_filter_fn = filters.form_filter_against_list(val, 'ensemble')
    test_filter_fn = filters.form_filter_against_list(test, 'ensemble')

    sho.filter_sharded(input_sharded, train_sharded, train_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, val_sharded, val_filter_fn,
                       shuffle_buffer)
    sho.filter_sharded(input_sharded, test_sharded, test_filter_fn,
                       shuffle_buffer)

    # write splits to text files
    np.savetxt(output_root.split('@')[0] + '_train.txt', train, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_val.txt', val, fmt='%s')
    np.savetxt(output_root.split('@')[0] + '_test.txt', test, fmt='%s')