예제 #1
0
def split(input_sharded, output_root):
    """Split temporally and shuffle examples.

    The input is first re-keyed by ``['ensemble', 'subunit']`` into a
    temporary sharded dataset ``<prefix>_rekeyed@<num_shards>``. That dataset
    is then filtered into ``<prefix>_train``, ``<prefix>_val`` and
    ``<prefix>_test`` (each ``@<num_shards>``) using the module-level
    ``TRAIN``, ``VAL`` and ``TEST`` ensemble lists. The temporary dataset is
    deleted afterwards, including when one of the steps raises.

    Args:
        input_sharded: sharded dataset to split.
        output_root: sharded output spec; its prefix and shard count
            (``sh.get_prefix`` / ``sh.get_num_shards``) name the outputs.
    """
    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    # Re-key to ensemble, subunit.  To allow for shuffling across targets.
    tmp_sharded = sh.Sharded(
        f'{prefix:}_rekeyed@{num_shards:}', ['ensemble', 'subunit'])
    partitions = (('train', TRAIN), ('val', VAL), ('test', TEST))
    try:
        logger.info("Rekeying")
        sho.rekey(input_sharded, tmp_sharded)

        keys = tmp_sharded.get_keys()
        logger.info("Splitting")
        for name, ensembles in partitions:
            out_sharded = sh.Sharded(f'{prefix:}_{name}@{num_shards:}', keys)
            filter_fn = filters.form_filter_against_list(ensembles, 'ensemble')
            # NOTE(review): the 4th argument here is num_shards, whereas the
            # other split variants in this file pass a shuffle_buffer —
            # confirm this is intended.
            sho.filter_sharded(tmp_sharded, out_sharded, filter_fn, num_shards)
    finally:
        # Remove the intermediate re-keyed dataset on success and on failure
        # so a crashed run does not leave it behind.
        tmp_sharded.delete_files()
예제 #2
0
def split(input_sharded, output_root, info_csv, shuffle_buffer):
    """Split by protein.

    Reads ``info_csv`` and derives an ensemble name for each row. It then
    randomly partitions the unique values of the 'protein' column into
    train/val/test groups and writes ``<prefix>_{train,val,test}@<num_shards>``
    sharded datasets holding the ensembles of each group.

    Args:
        input_sharded: sharded dataset keyed by ``['ensemble']``.
        output_root: sharded output spec; its prefix and shard count name
            the outputs.
        info_csv: CSV with 'ligand', 'active_struc', 'inactive_struc' and
            'protein' columns.
        shuffle_buffer: forwarded to ``sho.filter_sharded``.

    Raises:
        RuntimeError: if ``input_sharded`` is not keyed by ``['ensemble']``.
        KeyError: (pandas >= 1.0) if an ensemble of ``input_sharded`` has no
            row in ``info_csv``.
    """
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    info = pd.read_csv(info_csv)
    # Ensemble name: <ligand>__<3rd '_' token of active_struc>__<3rd '_'
    # token of inactive_struc>.
    info['ensemble'] = info.apply(
        lambda x: x['ligand'] + '__' + x['active_struc'].split('_')[2]
        + '__' + x['inactive_struc'].split('_')[2],
        axis=1)
    info = info.set_index('ensemble')
    # Remove duplicate ensembles (the first occurrence is kept).
    info = info[~info.index.duplicated()]

    # Look up every input ensemble so that one missing from info_csv fails
    # here rather than silently dropping out of all three splits.
    ensembles = input_sharded.get_names()['ensemble']
    _ = info.loc[ensembles]

    # Split by protein.
    proteins = info['protein'].unique()
    # NOTE(review): results are unpacked as (test, val, train) with
    # fractions (0.6, 0.2, 0.2). If splits.random_split returns
    # (train, val, test), the test set receives 60% of the proteins —
    # confirm against its signature.
    i_test, i_val, i_train = splits.random_split(len(proteins), 0.6, 0.2, 0.2)
    p_train = proteins[i_train]
    p_val = proteins[i_val]
    p_test = proteins[i_test]
    logger.info(f'Train proteins: {p_train:}')
    logger.info(f'Val proteins: {p_val:}')
    logger.info(f'Test proteins: {p_test:}')

    train = info[info['protein'].isin(p_train)].index.tolist()
    val = info[info['protein'].isin(p_val)].index.tolist()
    test = info[info['protein'].isin(p_test)].index.tolist()

    logger.info(f'{len(train):} train examples, {len(val):} val examples, '
                f'{len(test):} test examples.')

    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)
    for name, members in (('train', train), ('val', val), ('test', test)):
        out_sharded = sh.Sharded(f'{prefix:}_{name}@{num_shards:}', keys)
        filter_fn = filters.form_filter_against_list(members, 'ensemble')
        sho.filter_sharded(input_sharded, out_sharded, filter_fn,
                           shuffle_buffer)
예제 #3
0
def split(input_sharded, output_root, scaffold_data, shuffle_buffer):
    """Split by scaffold.

    Derives an ensemble name for every row of ``scaffold_data`` and
    partitions the rows with ``splits.scaffold_split`` on the 'Scaffold'
    column. It then writes ``<prefix>_{train,val,test}@<num_shards>`` sharded
    datasets plus ``<root>_{train,val,test}.txt`` files listing each
    split's ensembles.

    Args:
        input_sharded: sharded dataset keyed by ``['ensemble']``.
        output_root: sharded output spec; the part before '@' is also the
            stem of the text files.
        scaffold_data: DataFrame with 'Scaffold', 'ligand', 'active_struc'
            and 'inactive_struc' columns.
        shuffle_buffer: forwarded to ``sho.filter_sharded``.

    Raises:
        RuntimeError: if ``input_sharded`` is not keyed by ``['ensemble']``.
    """
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    logger.info('Process scaffold and ensemble data')
    scaffold_list = scaffold_data['Scaffold'].tolist()
    # Ensemble name per row: <ligand>__<last '_' token of active_struc>__
    # <last '_' token of inactive_struc>; row order matches scaffold_list.
    ensemble_list = np.array([
        ligand + '__' + active.split('_')[-1] + '__' + inactive.split('_')[-1]
        for ligand, active, inactive in zip(scaffold_data['ligand'],
                                            scaffold_data['active_struc'],
                                            scaffold_data['inactive_struc'])
    ])

    logger.info('Splitting by scaffold')
    train_idx, val_idx, test_idx = splits.scaffold_split(scaffold_list)
    split_lists = {
        'train': ensemble_list[train_idx].tolist(),
        'val': ensemble_list[val_idx].tolist(),
        'test': ensemble_list[test_idx].tolist(),
    }

    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    logger.info('Writing sets')
    for name, ensembles in split_lists.items():
        out_sharded = sh.Sharded(f'{prefix:}_{name}@{num_shards:}', keys)
        filter_fn = filters.form_filter_against_list(ensembles, 'ensemble')
        sho.filter_sharded(input_sharded, out_sharded, filter_fn,
                           shuffle_buffer)

    # write splits to text files
    txt_stem = output_root.split('@')[0]
    for name, ensembles in split_lists.items():
        np.savetxt(txt_stem + f'_{name}.txt', ensembles, fmt='%s')
예제 #4
0
def split(input_sharded, output_root, shuffle_buffer, cutoff=30):
    """Split by sequence identity.

    Collects chain sequences from every shard of ``input_sharded`` and
    splits them with ``atom3d.splits.sequence.cluster_split`` at ``cutoff``.
    It then writes ``<prefix>_{train,val,test}@<num_shards>`` sharded datasets
    plus ``<root>_{train,val,test}.txt`` files listing each split's
    ensembles.

    Args:
        input_sharded: sharded dataset keyed by ``['ensemble']``.
        output_root: sharded output spec; the part before '@' is also the
            stem of the text files.
        shuffle_buffer: forwarded to ``sho.filter_sharded``.
        cutoff: forwarded to ``cluster_split``.

    Raises:
        RuntimeError: if ``input_sharded`` is not keyed by ``['ensemble']``.
    """
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    logger.info('Loading chain sequences')
    all_chain_sequences = []
    for _, shard in input_sharded.iter_shards():
        all_chain_sequences.extend(seq.get_all_chain_sequences_df(shard))

    logger.info('Splitting by cluster')
    train, val, test = atom3d.splits.sequence.cluster_split(
        all_chain_sequences, cutoff)

    # Will just look up ensembles: keep the first element of each entry,
    # which is matched against the 'ensemble' field below.
    split_lists = {
        name: [x[0] for x in members]
        for name, members in (('train', train), ('val', val), ('test', test))
    }

    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    logger.info('Writing sets')
    for name, ensembles in split_lists.items():
        out_sharded = sh.Sharded(f'{prefix:}_{name}@{num_shards:}', keys)
        filter_fn = filters.form_filter_against_list(ensembles, 'ensemble')
        sho.filter_sharded(input_sharded, out_sharded, filter_fn,
                           shuffle_buffer)

    # write splits to text files
    txt_stem = output_root.split('@')[0]
    for name, ensembles in split_lists.items():
        np.savetxt(txt_stem + f'_{name}.txt', ensembles, fmt='%s')
예제 #5
0
def split(input_sharded, output_root, scaffold_data, shuffle_buffer):
    """Split by scaffold.

    Partitions the rows of ``scaffold_data`` with ``splits.scaffold_split``
    on its 'Scaffold' column and uses the matching 'pdb' values as ensemble
    names. It writes ``<prefix>_{train,val,test}@<num_shards>`` sharded
    datasets plus ``<root>_{train,val,test}.txt`` ID lists.

    Args:
        input_sharded: sharded dataset keyed by ``['ensemble']``.
        output_root: sharded output spec; the part before '@' is also the
            stem of the text files.
        scaffold_data: DataFrame with 'Scaffold' and 'pdb' columns.
        shuffle_buffer: forwarded to ``sho.filter_sharded``.

    Raises:
        RuntimeError: if ``input_sharded`` is not keyed by ``['ensemble']``.
    """
    keys = input_sharded.get_keys()
    if keys != ['ensemble']:
        raise RuntimeError('Can only apply to sharded by ensemble.')

    logger.info('Splitting by scaffold')
    scaffold_list = scaffold_data['Scaffold'].tolist()
    train_idx, val_idx, test_idx = splits.scaffold_split(scaffold_list)
    # scaffold_split only sees a plain list, so its indices are positions.
    # Select by position (.iloc): plain [] matches index labels and picks
    # the wrong rows (or raises) when scaffold_data has a non-default index.
    pdb_ids = scaffold_data['pdb']
    split_lists = {
        'train': pdb_ids.iloc[train_idx].tolist(),
        'val': pdb_ids.iloc[val_idx].tolist(),
        'test': pdb_ids.iloc[test_idx].tolist(),
    }

    prefix = sh.get_prefix(output_root)
    num_shards = sh.get_num_shards(output_root)

    logger.info('Writing sets')
    for name, ensembles in split_lists.items():
        out_sharded = sh.Sharded(f'{prefix:}_{name}@{num_shards:}', keys)
        filter_fn = filters.form_filter_against_list(ensembles, 'ensemble')
        sho.filter_sharded(input_sharded, out_sharded, filter_fn,
                           shuffle_buffer)

    # write splits to text files
    txt_stem = output_root.split('@')[0]
    for name, ensembles in split_lists.items():
        np.savetxt(txt_stem + f'_{name}.txt', ensembles, fmt='%s')
예제 #6
0
def split_dataset(dirname_all, dirname_split, input_ds_name, input_splits,
                  input_csvfile, input_exclude):
    """Split a single-shard QM9 sharded dataset into train/val/test.

    Reads shard 0 (atoms and 'labels') of ``input_ds_name`` and sets its
    'ensemble' column from 'model'. It saves the result to
    ``<dirname_all>/qm9_all@1``, writes the molecule IDs of each split to
    text files in ``dirname_split``, and writes ``train@1``, ``val@1`` and
    ``test@1`` sharded datasets (with their label rows) there.

    Args:
        dirname_all: directory for the full corrected dataset ('' means
            the current directory).
        dirname_split: directory for the split datasets and ID lists ('' means
            the current directory).
        input_ds_name: sharded dataset spec passed to ``sh.Sharded.load``.
        input_splits: directory holding ``indices_train.dat``,
            ``indices_valid.dat`` and ``indices_test.dat``.
        input_csvfile: CSV with a 'mol_id' column; the index files select
            rows of it by index label.
        input_exclude: file with the index labels of excluded molecules.
    """
    # Read the sharded dataset
    input_sharded = sh.Sharded.load(input_ds_name)
    input_shard = input_sharded.read_shard(0)
    input_label = input_sharded.read_shard(0, 'labels')

    # Create output directories; '' means the current directory.
    for dirname in (dirname_all, dirname_split):
        if dirname:
            os.makedirs(dirname, exist_ok=True)

    # Correct for ensemble = None
    input_shard['ensemble'] = input_shard['model']

    # Save the full (corrected) dataset
    sharded_all = sh.Sharded(os.path.join(dirname_all, 'qm9_all@1'),
                             input_sharded.get_keys())
    sharded_all._write_shard(0, input_shard)
    sharded_all.add_to_shard(0, input_label, 'labels')

    # Read raw and split data
    label_data = pd.read_csv(input_csvfile)
    indices_ex = np.loadtxt(input_exclude, dtype=int)
    indices_tr = np.loadtxt(os.path.join(input_splits, 'indices_train.dat'),
                            dtype=int)
    indices_va = np.loadtxt(os.path.join(input_splits, 'indices_valid.dat'),
                            dtype=int)
    indices_te = np.loadtxt(os.path.join(input_splits, 'indices_test.dat'),
                            dtype=int)

    # Create lists of molecule IDs for exclusion and splits
    mol_ids_ex = label_data.loc[indices_ex]['mol_id'].tolist()
    mol_ids_te = label_data.loc[indices_te]['mol_id'].tolist()
    mol_ids_va = label_data.loc[indices_va]['mol_id'].tolist()
    mol_ids_tr = label_data.loc[indices_tr]['mol_id'].tolist()

    # Write lists of mol_ids to files, one ID per line.
    id_lists = (
        ('excluded', mol_ids_ex),
        ('training', mol_ids_tr),
        ('validation', mol_ids_va),
        ('test', mol_ids_te),
    )
    for suffix, mol_ids in id_lists:
        path = os.path.join(dirname_split, f'mol_ids_{suffix}.txt')
        with open(path, 'w') as f:
            f.writelines('%s\n' % mol_id for mol_id in mol_ids)

    def _split_labels(mol_ids):
        """Return the label rows whose CSV mol_id is in ``mol_ids``."""
        # NOTE(review): the boolean mask is built from label_data but applied
        # to input_label, so this assumes both tables share the same row
        # index/order — confirm.
        mask = label_data['mol_id'].isin(mol_ids)
        return input_label.loc[mask].reset_index(drop=True)

    # Filter and write out each split. All three select molecules through
    # the same 'structure' field so train/val/test use one criterion.
    split_keys = sharded_all.get_keys()
    for name, mol_ids in (('train', mol_ids_tr), ('val', mol_ids_va),
                          ('test', mol_ids_te)):
        filter_fn = filters.form_filter_against_list(mol_ids, 'structure')
        sharded_out = sh.Sharded(os.path.join(dirname_split, f'{name}@1'),
                                 split_keys)
        sho.filter_sharded(sharded_all, sharded_out, filter_fn)
        sharded_out.add_to_shard(0, _split_labels(mol_ids), 'labels')