Code example #1
def identity_split(dataset,
                   cutoff,
                   val_split=0.1,
                   test_split=0.1,
                   min_fam_in_split=5,
                   blast_db=None,
                   random_seed=None):
    """
    Splits a PDB dataset by sequence identity at the specified cutoff, with
    pairwise identities computed by BLAST.

    Generates train, val, test sets.

    Args:
        dataset (Dataset): dataset to split
        cutoff (float):
            sequence identity cutoff, between 0 and 1
        val_split (float): fraction of data used for validation. Default: 0.1
        test_split (float): fraction of data used for testing. Default: 0.1
        min_fam_in_split (int):
            minimum number of sequence clusters in the val/test sets. Default: 5
        blast_db (str):
            location of pre-computed BLAST DB for dataset. If None, compute and
            save in 'blast_db'. Default: None
        random_seed (int):  specifies random seed for shuffling. Default: None

    Returns:
        train_set (str[]):  pdbs in the train set
        val_set (str[]):  pdbs in the validation set
        test_set (str[]): pdbs in the test set

    """
    all_chain_sequences = [
        seq.get_chain_sequences(x['atoms']) for x in dataset
    ]
    # Flatten.
    flat_chain_sequences = [
        x for sublist in all_chain_sequences for x in sublist
    ]

    if blast_db is None:
        seq.write_to_blast_db(flat_chain_sequences, 'blast_db')
        blast_db = 'blast_db'

    if random_seed is not None:
        np.random.seed(random_seed)

    n = len(dataset)
    test_size = n * test_split
    val_size = n * val_split

    to_use = set(range(len(all_chain_sequences)))
    logger.info('generating validation set...')
    val_indices, to_use = _create_identity_split(all_chain_sequences, cutoff,
                                                 to_use, val_size,
                                                 min_fam_in_split, blast_db)
    logger.info('generating test set...')
    test_indices, to_use = _create_identity_split(all_chain_sequences, cutoff,
                                                  to_use, test_size,
                                                  min_fam_in_split, blast_db)
    train_indices = to_use

    return splits.split(dataset, train_indices, val_indices, test_indices)
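A minimal usage sketch for the function above. The import path of the dataset loader and the LMDB location are assumptions (not shown in the snippet); `da.load_dataset` is the loader used in the later examples.

# Hypothetical usage; module path and data path are assumptions.
import atom3d.datasets.datasets as da  # assumed location of load_dataset

dataset = da.load_dataset('lmdb/all', 'lmdb')  # hypothetical LMDB dataset
train_set, val_set, test_set = identity_split(dataset,
                                              cutoff=0.3,
                                              val_split=0.1,
                                              test_split=0.1,
                                              random_seed=42)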
Code example #2
File: prepare_lmdb.py  Project: maschka/atom3d
def split_lmdb_dataset(lmdb_path, train_txt, val_txt, test_txt, split_dir):
    logger.info(f'Splitting indices, loading data from {lmdb_path:}...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        # Read list of desired <target, decoy>
        split_set = set(map(tuple, pd.read_csv(split_txt, header=None, dtype=str).values))

        # Check if the <target, decoy> id is in the desired split set
        split_ids = list(filter(lambda id: eval(id) in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))
        return split_indices

    logger.info(f'Write results to {split_dir:}...')
    os.makedirs(os.path.join(split_dir, 'indices'), exist_ok=True)
    os.makedirs(os.path.join(split_dir, 'data'), exist_ok=True)

    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(split_dir, 'indices/train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(split_dir, 'indices/val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(split_dir, 'indices/test_indices.txt'))

    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset, os.path.join(split_dir, 'data/train'))
    da.make_lmdb_dataset(val_dataset, os.path.join(split_dir, 'data/val'))
    da.make_lmdb_dataset(test_dataset, os.path.join(split_dir, 'data/test'))
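A hedged invocation sketch for split_lmdb_dataset. The paths are hypothetical, and the split files are assumed to list one <target, decoy> pair per line, since the filter above parses each id as a tuple.

# Hypothetical paths; each CSV holds <target, decoy> pairs, one per line.
split_lmdb_dataset(lmdb_path='lmdb/all',
                   train_txt='splits/train.csv',
                   val_txt='splits/val.csv',
                   test_txt='splits/test.csv',
                   split_dir='lmdb/split')
# Afterwards, indices are under split_dir/indices and the split LMDBs under split_dir/data.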
Code example #3
def split_lmdb_dataset(lmdb_path, train_txt, val_txt, test_txt, split_dir):
    logger.info(f'Splitting indices, loading data from {lmdb_path:}...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])

        # Check if the target in id is in the desired target split set
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))
        return split_indices

    logger.info(f'Write results to {split_dir:}...')
    os.makedirs(os.path.join(split_dir, 'indices'), exist_ok=True)
    os.makedirs(os.path.join(split_dir, 'data'), exist_ok=True)

    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(split_dir,
                                         'indices/train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(split_dir, 'indices/val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(split_dir, 'indices/test_indices.txt'))

    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset, os.path.join(split_dir, 'data/train'))
    da.make_lmdb_dataset(val_dataset, os.path.join(split_dir, 'data/val'))
    da.make_lmdb_dataset(test_dataset, os.path.join(split_dir, 'data/test'))
Code example #4
def test_split():
    # Load LMDB dataset
    dataset = da.load_dataset('tests/test_data/lmdb', 'lmdb')
    # Split with defined indices
    indices_train, indices_val, indices_test = [3, 0], [2], [1]
    s = spl.split(dataset, indices_train, indices_val, indices_test)
    train_dataset, val_dataset, test_dataset = s
    # Check whether the frames are in the correct dataset
    assert dataset[0]['atoms'].equals(train_dataset[1]['atoms'])
    assert dataset[1]['atoms'].equals(test_dataset[0]['atoms'])
    assert dataset[2]['atoms'].equals(val_dataset[0]['atoms'])
    assert dataset[3]['atoms'].equals(train_dataset[0]['atoms'])
Code example #5
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt,
            score_path):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume PDB filetype.
    filetype = 'pdb'

    scores = Scores(score_path) if score_path else None

    file_list = fi.find_files(input_file_path, fo.patterns[filetype])

    lmdb_path = os.path.join(output_root, 'all')
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(file_list, filetype, transform=scores)
    da.make_lmdb_dataset(dataset, lmdb_path)

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the target in id is in the desired target split set
        split_ids = list(
            filter(lambda id: eval(id)[0] in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))
        return split_indices

    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))

    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'))
    da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'))
    da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'))
Code example #6
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt,
            score_path):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    scores = Scores(score_path) if score_path else None

    # Assume subdirectories containing the protein/pocket/ligand files are
    # structured as <input_file_path>/<pdbcode>
    pdbcodes = os.listdir(input_file_path)

    lmdb_path = os.path.join(output_root, 'all')
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')

    dataset = LBADataset(input_file_path, pdbcodes, transform=scores)
    da.make_lmdb_dataset(dataset, lmdb_path)

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the pdbcode in id is in the desired pdbcode split set
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))
        return split_indices

    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))

    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'))
    da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'))
    da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'))
Code example #7
File: prepare_lmdb.py  Project: drorlab/atom3d
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
    # Logger
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)
    # Assume GDB-specific version of XYZ format.
    filetype = 'xyz-gdb'
    # Compile a list of the input files
    file_list = fi.find_files(input_file_path, fo.patterns[filetype])
    # Write the LMDB dataset
    lmdb_path = os.path.join(output_root, 'all')
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(
        file_list,
        filetype,
        transform=_add_data_with_subtracted_thermochem_energy)
    da.make_lmdb_dataset(dataset, lmdb_path, filter_fn=bond_filter)
    # Only continue if we want to write split datasets
    if not split:
        return
    logger.info(f'Splitting indices...\n')
    # Load the dataset that has just been created
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')
    # Determine and write out the split indices
    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))
    # Write the split datasets
    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset,
                         os.path.join(output_root, 'train'),
                         filter_fn=bond_filter)
    da.make_lmdb_dataset(val_dataset,
                         os.path.join(output_root, 'val'),
                         filter_fn=bond_filter)
    da.make_lmdb_dataset(test_dataset,
                         os.path.join(output_root, 'test'),
                         filter_fn=bond_filter)
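A hedged call sketch for this prepare entry point; the paths below are hypothetical, and the split text files are assumed to hold one id per line as in the earlier _write_split_indices examples.

# Hypothetical paths; prepare() expects GDB-style .xyz files under input_file_path.
prepare(input_file_path='raw/gdb9_xyz',
        output_root='lmdb/qm9',
        split=True,
        train_txt='splits/train.txt',  # one id per line (assumed format)
        val_txt='splits/val.txt',
        test_txt='splits/test.txt')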
Code example #8
def cluster_split(dataset,
                  cutoff,
                  val_split=0.1,
                  test_split=0.1,
                  min_fam_in_split=5,
                  random_seed=None):
    """
    Splits pdb dataset using pre-computed sequence identity clusters from PDB.

    Generates train, val, test sets.

    We assume there is one PDB code per entry in dataset.

    Args:
        dataset (Dataset): dataset to split
        cutoff (float):
            sequence identity cutoff (can be .3, .4, .5, .7, .9, .95, 1.0)
        val_split (float): fraction of data used for validation. Default: 0.1
        test_split (float): fraction of data used for testing. Default: 0.1
        min_fam_in_split (int):
            minimum number of sequence clusters in the val/test sets. Default: 5
        random_seed (int):  specifies random seed for shuffling. Default: None

    Returns:
        train_set (str[]):  pdbs in the train set
        val_set (str[]):  pdbs in the validation set
        test_set (str[]): pdbs in the test set

    """
    if random_seed is not None:
        np.random.seed(random_seed)

    logger.info('Loading chain sequences')
    all_chain_sequences = [
        seq.get_chain_sequences(x['atoms']) for x in dataset
    ]

    pdb_codes = np.array(
        [fi.get_pdb_code(x[0][0][0]) for x in all_chain_sequences])
    n_orig = len(np.unique(pdb_codes))
    clusterings = seq.get_pdb_clusters(cutoff, np.unique(pdb_codes))

    # If code not present in clustering, we don't use.
    to_use = [i for (i, x) in enumerate(pdb_codes) if x in clusterings[0]]
    n = len(np.unique(pdb_codes[to_use]))
    to_use = set(to_use)

    logger.info(f'Removing {n_orig - n:} / {n_orig:} '
                f'sequences not found in the clustering.')

    test_size = n * test_split
    val_size = n * val_split

    logger.info('generating validation set...')
    val_indices, to_use = _create_cluster_split(all_chain_sequences,
                                                clusterings, to_use, val_size,
                                                min_fam_in_split)
    logger.info('generating test set...')
    test_indices, to_use = _create_cluster_split(all_chain_sequences,
                                                 clusterings, to_use,
                                                 test_size, min_fam_in_split)
    train_indices = to_use

    return splits.split(dataset, train_indices, val_indices, test_indices)
Code example #9
def cluster_split(dataset,
                  cutoff,
                  val_split=0.1,
                  test_split=0.1,
                  min_fam_in_split=5,
                  random_seed=None):
    """Splits pdb dataset using pre-computed sequence identity clusters from PDB, ensuring that no cluster spans multiple splits. 

    Clusters are randomly assigned to the validation and test sets, but to ensure that there is some diversity in each set (i.e. a split does not consist of a single sequence cluster), a minimum number of clusters per split is enforced. Some data examples may be removed in order to satisfy this constraint.
    
    This function assumes that the PDB code or PDB filename exists in the ``ensemble`` field of the ``atoms`` dataframe in the dataset.

    :param dataset: Dataset to perform the split on.
    :type dataset: ATOM3D Dataset
    :param cutoff: Sequence identity cutoff. Possible values: 0.3, 0.4, 0.5, 0.7, 0.9, 0.95, 1.0
    :type cutoff: float
    :param val_split: Fraction of data used in validation set, defaults to 0.1
    :type val_split: float, optional
    :param test_split: Fraction of data used in test set, defaults to 0.1
    :type test_split: float, optional
    :param min_fam_in_split: Minimum number of sequence clusters to be included in validation and test sets, defaults to 5
    :type min_fam_in_split: int, optional
    :param random_seed: Random seed for sampling clusters, defaults to None
    :type random_seed: int, optional

    :return: Tuple containing training, validation, and test sets, each as ATOM3D Dataset objects.
    :rtype: Tuple[Dataset]
    """

    if random_seed is not None:
        np.random.seed(random_seed)

    logger.info('Loading chain sequences')
    all_chain_sequences = [
        seq.get_chain_sequences(x['atoms']) for x in dataset
    ]

    pdb_codes = np.array(
        [fi.get_pdb_code(x[0][0][0]) for x in all_chain_sequences])
    n_orig = len(np.unique(pdb_codes))
    clusterings = seq.get_pdb_clusters(cutoff, np.unique(pdb_codes))

    # If code not present in clustering, we don't use.
    to_use = [i for (i, x) in enumerate(pdb_codes) if x in clusterings[0]]
    n = len(np.unique(pdb_codes[to_use]))
    to_use = set(to_use)

    logger.info(f'Removing {n_orig - n:} / {n_orig:} '
                f'sequences not found in the clustering.')

    test_size = n * test_split
    val_size = n * val_split

    logger.info('generating validation set...')
    val_indices, to_use = _create_cluster_split(all_chain_sequences,
                                                clusterings, to_use, val_size,
                                                min_fam_in_split)
    logger.info('generating test set...')
    test_indices, to_use = _create_cluster_split(all_chain_sequences,
                                                 clusterings, to_use,
                                                 test_size, min_fam_in_split)
    train_indices = to_use

    return splits.split(dataset, train_indices, val_indices, test_indices)
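A hedged usage sketch for cluster_split. The dataset loading mirrors the earlier examples (the import path and LMDB location are assumptions), and the cutoff must be one of the precomputed PDB clustering levels (0.3, 0.4, 0.5, 0.7, 0.9, 0.95, 1.0).

# Hypothetical usage; module path and data path are assumptions.
import atom3d.datasets.datasets as da  # assumed location of load_dataset

dataset = da.load_dataset('lmdb/all', 'lmdb')  # hypothetical LMDB dataset
train_set, val_set, test_set = cluster_split(dataset,
                                             cutoff=0.3,
                                             val_split=0.1,
                                             test_split=0.1,
                                             min_fam_in_split=5,
                                             random_seed=0)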
Code example #10
def identity_split(dataset,
                   cutoff,
                   val_split=0.1,
                   test_split=0.1,
                   min_fam_in_split=5,
                   blast_db=None,
                   random_seed=None):
    """Splits a dataset of proteins by sequence identity at specified cutoff value. Proteins are randomly selected to be placed in validation and test splits, along with all proteins within ``cutoff`` sequence identity (calculated by BLAST).

        To ensure that there is some diversity in each set (i.e. a split does not consist of a single sequence cluster), a minimum number of clusters in each split is enforced. Some data examples may be removed in order to satisfy this constraint.

        Note that the construction of this function means that it is effectively a cluster split. All examples within ``cutoff`` of the sampled query protein are added to the validation set, meaning that some examples near the edge of the cluster may in fact share less than ``cutoff`` sequence identity with other proteins in the dataset.
        Therefore, this does not satisfy the constraints for a strict sequence identity cutoff: that
        (1) no protein in the validation split shares greater than ``cutoff`` sequence identity with any protein in the train set, and
        (2) no protein in the test split shares greater than ``cutoff`` sequence identity with any protein in either the train or validation sets.
        A function that satisfies this more strict definition is currently under development.



    :param dataset: Dataset to perform the split on.
    :type dataset: ATOM3D Dataset
    :param cutoff: Sequence identity cutoff, between 0 and 1
    :type cutoff: float
    :param val_split: Fraction of data used in validation set, defaults to 0.1
    :type val_split: float, optional
    :param test_split: Fraction of data used in test set, defaults to 0.1
    :type test_split: float, optional
    :param min_fam_in_split: Minimum number of sequence clusters to be included in validation and test sets, defaults to 5
    :type min_fam_in_split: int, optional
    :param blast_db: Location of a pre-computed BLAST database for the dataset. If None, the database is computed and saved in 'blast_db', defaults to None
    :type blast_db: str, optional
    :param random_seed: Random seed for sampling clusters, defaults to None
    :type random_seed: int, optional

    :return: Tuple containing training, validation, and test sets, each as ATOM3D Dataset objects.
    :rtype: Tuple[Dataset]
    """

    all_chain_sequences = [
        seq.get_chain_sequences(x['atoms']) for x in dataset
    ]
    # Flatten.
    flat_chain_sequences = [
        x for sublist in all_chain_sequences for x in sublist
    ]

    # write all sequences to BLAST-formatted database
    if blast_db is None:
        seq.write_to_blast_db(flat_chain_sequences, 'blast_db')
        blast_db = 'blast_db'

    if random_seed is not None:
        np.random.seed(random_seed)

    n = len(dataset)
    test_size = n * test_split
    val_size = n * val_split

    to_use = set(range(len(all_chain_sequences)))
    logger.info('generating validation set...')
    val_indices, to_use = _create_identity_split(all_chain_sequences, cutoff,
                                                 to_use, val_size,
                                                 min_fam_in_split, blast_db)
    logger.info('generating test set...')
    test_indices, to_use = _create_identity_split(all_chain_sequences, cutoff,
                                                  to_use, test_size,
                                                  min_fam_in_split, blast_db)
    train_indices = to_use

    return splits.split(dataset, train_indices, val_indices, test_indices)