Example #1
def process_files(input_dir):
    """
    Process all protein (pdb) and ligand (sdf) files in input directory.
    
    :param input_dir: directory containing PDBBind data
    :type input_dir: str

    :return structure_dict: dictionary containing each structure, keyed by PDB
        code. Each entry is a dict containing the protein as a Biopython object
        and the ligand as an RDKit Mol object
    :rtype structure_dict: dict
    """
    structure_dict = {}
    pdb_files = fi.find_files(input_dir, 'pdb')

    for f in tqdm(pdb_files, desc='pdb files'):
        pdb_id = fi.get_pdb_code(f)
        if pdb_id not in structure_dict:
            structure_dict[pdb_id] = {}
        if '_protein' in f:
            prot = ft.read_any(f)
            structure_dict[pdb_id]['protein'] = prot

    lig_files = fi.find_files(input_dir, 'sdf')
    for f in tqdm(lig_files, desc='ligand files'):
        pdb_id = fi.get_pdb_code(f)
        structure_dict[pdb_id]['ligand'] = get_ligand(f)

    return structure_dict
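
A minimal usage sketch for this function, assuming a PDBBind-style layout with files such as 1abc_protein.pdb and 1abc_ligand.sdf; the input path below is a placeholder, not from the source:

# Hypothetical usage; the directory path is an assumption for illustration.
structures = process_files('/data/pdbbind/refined-set')
for pdb_id, entry in structures.items():
    protein = entry.get('protein')   # Biopython structure object
    ligand = entry.get('ligand')     # RDKit Mol object
    print(pdb_id, protein is not None, ligand is not None)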
Example #2
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume PDB filetype.
    filetype = 'pdb'

    file_list = fi.find_files(input_file_path, fo.patterns[filetype])

    lmdb_path = os.path.join(output_root, 'all')
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(file_list, filetype, transform=pdb_id_transform)
    da.make_lmdb_dataset(dataset, lmdb_path)

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the target in id is in the desired target split set
        split_ids = list(filter(lambda id: eval(id)[0] in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))

    _write_split_indices(train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    _write_split_indices(val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    _write_split_indices(test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))
Example #3
def main(data_dir, target_list, labels_dir, struct_format,
         num_cpus, overwrite, tmscore_exe):
    """ Compute rmsd, tm-score, gdt-ts, gdt-ha of decoy structures
    """
    logger = logging.getLogger(__name__)
    logger.info("Compute rmsd, tm-score, gdt-ts, gdt-ha of decoys in {:}".format(
        data_dir))

    os.makedirs(labels_dir, exist_ok=True)

    with open(target_list, 'r') as f:
        requested_filenames = \
            [os.path.join(labels_dir, '{:}.dat'.format(x.strip())) for x in f]
    logger.info("{:} requested keys".format(len(requested_filenames)))

    produced_filenames = []
    if not overwrite:
        produced_filenames = [f for f in fi.find_files(labels_dir, 'dat') \
                              if 'targets' not in f]
    logger.info("{:} produced keys".format(len(produced_filenames)))

    inputs = []
    for filename in requested_filenames:
        if filename in produced_filenames:
            continue
        target_name = util.get_target_name(filename)
        target_dir = os.path.join(data_dir, target_name)
        inputs.append((tmscore_exe, filename, target_name,
                       target_dir, struct_format))

    logger.info("{:} work keys".format(len(inputs)))
    par.submit_jobs(run_tmscore_per_target, inputs, num_cpus)
Example #4
def main(input_dir, output_lmdb, filetype, score_path, serialization_format):
    """Script wrapper to make_lmdb_dataset to create LMDB dataset."""
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    logger.info(f'filetype: {filetype}')
    if filetype == 'xyz-gdb':
        fileext = 'xyz'
    else:
        fileext = filetype
    file_list = fi.find_files(input_dir, fo.patterns[fileext])
    file_list.sort()
    logger.info(f'Found {len(file_list)} files.')

    if score_path:
        logger.info('Looking up scores...')
        scores = sc.Scores(score_path)
        new_file_list = scores.remove_missing(file_list)
        logger.info(f'Keeping {len(new_file_list)} / {len(file_list)}')
        file_list = new_file_list
    else:
        scores = None

    da.make_lmdb_dataset(file_list, output_lmdb, filetype, scores,
                         serialization_format)
Example #5
def run_tmscore_per_target(tmscore_exe, output_filename, target_name,
                           target_dir, struct_format):
    '''
    Run TM-score to compare all decoy structures of a target with its
    native structure. Write the result into a tab-delimited file with
    the following headers:
        <target>  <decoy>  <rmsd>  <tm_score>  <gdt_ts>  <gdt_ha>
    '''
    native = os.path.join(target_dir, '{:}.{:}'.format(
        target_name, struct_format))
    decoys = fi.find_files(target_dir, struct_format)
    logging.info("Running tm-scores for {:} with {:} decoys".format(
        target_name, len(decoys)))
    rows = []
    for decoy in decoys:
        result = run_tmscore_per_structure(tmscore_exe, decoy, native)
        if result is None:
            logging.warning("Skip target {:} decoy {:} due to failure".format(
                target_name, decoy))
            continue
        rmsd, tm, gdt_ts, gdt_ha = result
        rows.append([util.get_target_name(decoy), util.get_decoy_name(decoy),
                     rmsd, gdt_ts, gdt_ha, tm])
    df = pd.DataFrame(
        rows,
        columns=['target', 'decoy', 'rmsd', 'gdt_ts', 'gdt_ha', 'tm'])
    df = df.sort_values(
            ['rmsd', 'gdt_ts', 'gdt_ha', 'tm', 'decoy'],
            ascending=[True, False, False, False, False]).reset_index(drop=True)
    # Write to file
    df.to_csv(output_filename, sep='\t', index=False)
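
Since the function writes a plain tab-separated table, it can be loaded back for inspection with pandas; the snippet below is only an illustration and the output filename is a hypothetical example:

# Illustrative read-back of a per-target score table written above.
import pandas as pd

df = pd.read_csv('T0860.dat', sep='\t')   # hypothetical output filename
top = df.sort_values('tm', ascending=False).head(5)
print(top[['decoy', 'rmsd', 'tm', 'gdt_ts', 'gdt_ha']])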
Example #6
    def __init__(self, data_path):
        self._scores = {}
        score_paths = fi.find_files(data_path, 'sc')
        if len(score_paths) == 0:
            raise RuntimeError('No score files found.')
        for silent_file in score_paths:
            key = self._key_from_silent_file(silent_file)
            self._scores[key] = parse_scores(silent_file)

        self._scores = pd.concat(self._scores)
Example #7
def make_lmdb_dataset(input_file_path, filetype, cutoff, cutoff_type,
                      ensembler, output_root):
    lmdb_path = os.path.join(output_root, 'data')
    os.makedirs(lmdb_path, exist_ok=True)
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')

    file_list = fi.find_files(input_file_path, fo.patterns[filetype])
    logger.info(f'Found {len(file_list):} pdb files to process...')

    dataset = PPIDataset(file_list, cutoff, cutoff_type, ensembler)
    da.make_lmdb_dataset(dataset, lmdb_path)
    return lmdb_path
Example #8
def prepare(input_file_path, output_root, split, balance, train_txt, val_txt, test_txt, num_threads, start):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)
    
    def _process_chunk(file_list, filetype, lmdb_path, balance):
        logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
        if not os.path.exists(lmdb_path):
            os.makedirs(lmdb_path)
        dataset = da.load_dataset(file_list, filetype, transform=ResTransform(balance=balance))
        da.make_lmdb_dataset(dataset, lmdb_path)

    # Assume PDB filetype.
    filetype = 'pdb'

    file_list = fi.find_files(input_file_path, fo.patterns[filetype])
    
    lmdb_path = os.path.join(output_root, 'all')
    if not os.path.exists(lmdb_path):
        os.makedirs(lmdb_path)
        
    # dataset = da.load_dataset(file_list, filetype, transform=ResTransform(balance=balance))
    # da.make_lmdb_dataset(dataset, lmdb_path)
    
    chunk_size = (len(file_list) // num_threads) + 1
    chunks = [file_list[i:i + chunk_size] for i in range(0, len(file_list), chunk_size)]
    assert len(chunks) == num_threads
    
    for i in range(start, num_threads):
        logger.info(f'Processing chunk {i:}...')
        _process_chunk(chunks[i], 'pdb', f'{lmdb_path}_tmp_{i}', balance)
        

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the target in id is in the desired target split set
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))

    _write_split_indices(train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    _write_split_indices(val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    _write_split_indices(test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))
Example #9
def shard_dataset(input_dir, sharded_path, filetype, ensembler):
    """Shard whole input dataset."""
    logging.basicConfig(format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    dirname = os.path.dirname(sharded_path)
    if not os.path.exists(dirname) and dirname != '':
        os.makedirs(dirname, exist_ok=True)

    files = fi.find_files(input_dir, dt.patterns[filetype])
    ensemble_map = en.ensemblers[ensembler](files)
    Sharded.create_from_ensemble_map(ensemble_map, sharded_path)
Example #10
def gen_splits(target_list, input_dir, output_sharded_train, output_sharded_val,
               output_sharded_test, splitby, test_years, train_years, val_years,
               train_size, val_size, test_size,
               train_decoy_size, val_decoy_size, test_decoy_size,
               exclude_natives, shuffle, random_seed):
    """ Generate train/val/test sets from the input dataset. """
    targets_df = pd.read_csv(
        target_list, delimiter=r'\s*', engine='python').dropna()

    files = fi.find_files(input_dir, dt.patterns['pdb'])
    structures_df = pd.DataFrame(
        [[util.get_target_name(f), util.get_decoy_name(f), f] for f in files],
        columns=['target', 'decoy', 'path'])
    # Remove duplicates
    structures_df = structures_df.drop_duplicates(
        subset=['target', 'decoy'], keep='first').reset_index(drop=True)
    structures_df = pd.merge(structures_df, targets_df, on='target')

    # Keep only (target, year) pairs that also appear in structures_df
    targets_df = structures_df[['target', 'year']].drop_duplicates(
        keep='first').reset_index(drop=True)

    if splitby == 'random':
        targets_train, targets_val, targets_test = split_targets_random(
            targets_df, train_size, val_size, test_size, shuffle, random_seed)
    elif splitby == 'year':
        targets_train, targets_val, targets_test = split_targets_by_year(
            targets_df, test_years, train_years, val_years, val_size,
            shuffle, random_seed)
    else:
        raise ValueError('Unrecognized splitby option %s' % splitby)

    print('Generating dataset: train ({:} targets), val ({:} targets), '
          'test ({:} targets)'.format(len(targets_train), len(targets_val),
                                      len(targets_test)))

    train_set, val_set, test_set = generate_train_val_targets_tests(
        structures_df, targets_train, targets_val, targets_test,
        train_decoy_size, val_decoy_size, test_decoy_size,
        exclude_natives, random_seed)

    print('Finished generating dataset: train ({:} decoys), val ({:} decoys), '
          'test ({:} decoys)'.format(len(train_set), len(val_set), len(test_set)))

    for (output_sharded, dataset) in [(output_sharded_train, train_set),
                                      (output_sharded_val, val_set),
                                      (output_sharded_test, test_set)]:
        print('\nWriting out dataset to {:}'.format(output_sharded))
        files = dataset.path.unique()
        create_sharded_dataset(files, output_sharded)
Example #11
def load_scores(score_dir):
    """Create target_name -> (subunit_name -> RMS)."""
    score_files = fi.find_files(score_dir, 'sc')
    scores = {
        get_target_name(f): pd.read_csv(f,
                                        delimiter=r'\s*',
                                        index_col='description',
                                        engine='python')
        for f in score_files
    }
    # If duplicate structures present, remove all but first.
    for x, y in scores.items():
        scores[x] = y.loc[~y.index.duplicated(keep='first')]
    return scores
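
The returned mapping can then be indexed first by target and then by decoy name; a brief illustrative lookup, where the score directory, target, and decoy names are assumptions:

# Illustrative lookup; 'scores/', 'T0860', and 'decoy_0001' are placeholders.
scores = load_scores('scores/')
target_scores = scores['T0860']          # DataFrame indexed by 'description'
print(target_scores.loc['decoy_0001'])   # per-decoy score row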
Example #12
def convert_to_hdf5(input_dir, label_file, hdf_file):
    cif_files = fi.find_files(input_dir, 'cif')
    proteins = []
    pockets = []
    pdb_codes = []
    for f in tqdm(cif_files, desc='reading structures'):
        pdb_code = fi.get_pdb_code(f)
        if '_protein' in f:
            pdb_codes.append(pdb_code)
            df = dt.bp_to_df(dt.read_any(f))
            proteins.append(df)
        elif '_pocket' in f:
            df = dt.bp_to_df(dt.read_any(f))
            pockets.append(df)

    print('converting proteins...')
    protein_df = pd.concat(proteins)
    pocket_df = pd.concat(pockets)
    pdb_codes_df = pd.DataFrame({'pdb': pdb_codes})

    protein_df.to_hdf(hdf_file, 'proteins')
    pocket_df.to_hdf(hdf_file, 'pockets')
    pdb_codes_df.to_hdf(hdf_file, 'pdb_codes')

    print('converting ligands...')
    sdf_files = fi.find_files(input_dir, 'sdf')
    big_sdf = os.path.join(input_dir, 'all_ligands.sdf')
    dt.combine_sdfs(sdf_files, big_sdf)
    lig_df = PandasTools.LoadSDF(big_sdf, molColName='Mol')
    lig_df.index = pdb_codes
    lig_df.to_hdf(hdf_file, 'ligands')

    print('converting labels...')
    label_df = pd.read_csv(label_file)
    label_df = label_df.set_index('pdb').reindex(pdb_codes)
    label_df.to_hdf(hdf_file, 'labels')
Example #13
def make_lmdb_dataset(input_file_path, score_path, output_root):
    # Assume PDB filetype.
    filetype = 'pdb'

    scores = Scores(score_path) if score_path else None

    file_list = fi.find_files(input_file_path, fo.patterns[filetype])

    lmdb_path = os.path.join(output_root, 'data')
    os.makedirs(lmdb_path, exist_ok=True)

    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(file_list, filetype, transform=scores)
    da.make_lmdb_dataset(dataset, lmdb_path)
    return lmdb_path
Example #14
def read_labels(labels_dir, ext='dat'):
    '''
    Read all label files with extension <ext> in <labels_dir> into
    a pandas DataFrame.
    '''
    files = fi.find_files(labels_dir, ext)
    frames = []
    for filename in files:
        target_name = get_target_name(filename)
        df = pd.read_csv(filename,
                         delimiter=r'\s*',
                         engine='python',
                         index_col=[0, 1]).dropna()
        frames.append(df)
    all_df = pd.concat(frames, sort=False)
    return all_df
Example #15
def get_file_list(input_path, filetype):
    if filetype == 'lmdb':
        file_list = [input_path]
    elif os.path.isfile(input_path):
        with open(input_path) as f:
            all_paths = f.readlines()
        input_dir = os.path.dirname(input_path)
        file_list = []
        for x in all_paths:
            x = x.strip()
            if not fo.is_type(x, filetype):
                continue
            x = os.path.join(input_dir, x)
            file_list.append(x)
    else:
        file_list = fi.find_files(input_path,
                                  fo.patterns.get(filetype, filetype + r'$'))
    return sorted(file_list)
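
This helper handles three kinds of input: a path to an existing LMDB dataset, a text file listing relative file paths, or a directory to be searched with fi.find_files. A short illustration of the three call styles, with all paths hypothetical:

# Hypothetical calls; the paths are placeholders, not from the source.
get_file_list('datasets/all', 'lmdb')            # existing LMDB dataset path
get_file_list('splits/train_files.txt', 'pdb')   # text file of relative paths
get_file_list('raw/structures', 'pdb')           # directory searched with fi.find_files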
Example #16
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
    # Logger
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)
    # Assume GDB-specific version of XYZ format.
    filetype = 'xyz-gdb'
    # Compile a list of the input files
    file_list = fi.find_files(input_file_path, fo.patterns[filetype])
    # Write the LMDB dataset
    lmdb_path = os.path.join(output_root, 'all')
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(
        file_list,
        filetype,
        transform=_add_data_with_subtracted_thermochem_energy)
    da.make_lmdb_dataset(dataset, lmdb_path, filter_fn=bond_filter)
    # Only continue if we want to write split datasets
    if not split:
        return
    logger.info(f'Splitting indices...\n')
    # Load the dataset that has just been created
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')
    # Determine and write out the split indices
    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))
    # Write the split datasets
    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset,
                         os.path.join(output_root, 'train'),
                         filter_fn=bond_filter)
    da.make_lmdb_dataset(val_dataset,
                         os.path.join(output_root, 'val'),
                         filter_fn=bond_filter)
    da.make_lmdb_dataset(test_dataset,
                         os.path.join(output_root, 'test'),
                         filter_fn=bond_filter)
Example #17
def main(input_dir, output_lmdb, filetype, score_path, serialization_format):
    """Script wrapper to make_lmdb_dataset to create LMDB dataset."""
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    logger.info(f'filetype: {filetype}')
    if filetype == 'xyz-gdb':
        fileext = 'xyz'
    else:
        fileext = filetype
    file_list = fi.find_files(input_dir, fo.patterns[fileext])
    file_list.sort()
    logger.info(f'Found {len(file_list)} files.')

    dataset = da.load_dataset(file_list, filetype)
    da.make_lmdb_dataset(dataset,
                         output_lmdb,
                         filetype,
                         serialization_format=serialization_format)
Example #18
def main(datapath, out_path):
    valid_pdbs = [fi.get_pdb_code(f) for f in fi.find_files(out_path, 'sdf')]
    dat = []
    with open(os.path.join(datapath, 'index/INDEX_refined_data.2019')) as f:
        for line in f:
            if line.startswith('#'):
                continue
            fields = line.strip().split()
            if fields[0] not in valid_pdbs:
                continue
            dat.append(fields[:5] + fields[6:])
    refined_set = pd.DataFrame(dat,
                               columns=[
                                   'pdb', 'res', 'year', 'neglog_aff',
                                   'affinity', 'ref', 'ligand'
                               ])

    refined_set[['measurement', 'affinity']] = (
        refined_set['affinity'].str.split('=', expand=True))

    refined_set['ligand'] = refined_set['ligand'].str.strip('()')

    # Remove peptide ligands - refined set size now 4,598
    # refined_set = refined_set[["-mer" not in l for l in refined_set.ligand]]

    refined_set.to_csv(os.path.join(out_path,
                                    'pdbbind_refined_set_cleaned.csv'),
                       index=False)

    labels = refined_set[['pdb', 'neglog_aff']].rename(
        columns={'neglog_aff': 'label'})

    labels.to_csv(os.path.join(out_path, 'pdbbind_refined_set_labels.csv'),
                  index=False)
Example #19
def prepare(input_file_path, output_root, score_path, structures_per_rna):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume PDB filetype.
    filetype = 'pdb'

    scores = ar.Scores(score_path) if score_path else None

    logger.info(f'Splitting indices')
    file_list = fi.find_files(input_file_path, fo.patterns[filetype])
    random.shuffle(file_list)
    target_indices = col.defaultdict(list)
    for i, f in enumerate(file_list):
        target = get_target(f)
        if len(target_indices[target]) >= structures_per_rna:
            continue
        target_indices[target].append(i)

    dataset = da.load_dataset(file_list, filetype, transform=scores)

    logger.info(f'Writing train')
    train_indices = [f for target in TRAIN for f in target_indices[target]]
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'))

    logger.info(f'Writing val')
    val_indices = [f for target in VAL for f in target_indices[target]]
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'))

    logger.info(f'Writing test')
    test_indices = [f for target in TEST for f in target_indices[target]]
    test_dataset = torch.utils.data.Subset(dataset, test_indices)
    da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'))
Example #20
def get_file_list(input_path, filetype):
    if filetype == 'lmdb':
        file_list = [input_path]
    else:
        file_list = fi.find_files(input_path, fo.patterns[filetype])
    return file_list
Example #21
def test_find_files():
    file_list = fi.find_files(pdb_path, 'pdb', relative=None)
    assert file_list == [Path(x) for x in test_file_list]
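
Across these examples fi.find_files takes a root path plus either a bare extension ('pdb', 'sdf', 'sc') or a pattern from fo.patterns, and the test above shows it returning pathlib.Path objects. The sketch below is an assumption about that behavior for the bare-extension case only, not the library's actual implementation:

# Assumed behavior only: a minimal find_files-like helper for illustration.
from pathlib import Path


def find_files_sketch(root, extension):
    """Recursively collect files under root that end with the given extension."""
    return sorted(Path(root).rglob(f'*.{extension}'))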
Example #22
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume PDB filetype.
    filetype = 'pdb'

    file_list = fi.find_files(os.path.join(input_file_path, 'mutated'),
                              fo.patterns[filetype])
    transform = MSPTransform(base_file_dir=input_file_path)

    lmdb_path = os.path.join(output_root, 'raw', 'MSP', 'data')
    if not os.path.exists(lmdb_path):
        os.makedirs(lmdb_path)

    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    #dataset = da.load_dataset(file_list, filetype, transform=transform)
    #da.make_lmdb_dataset(dataset, lmdb_path)

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    split_data_path = os.path.join(output_root, 'splits', 'split-by-seqid30',
                                   'data')
    split_idx_path = os.path.join(output_root, 'splits', 'split-by-seqid30',
                                  'indices')
    if not os.path.exists(split_data_path):
        os.makedirs(split_data_path)
    if not os.path.exists(split_idx_path):
        os.makedirs(split_idx_path)

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the target in id is in the desired target split set
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        str_indices = [str(i) for i in split_indices]
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str_indices))
        return split_indices

    logger.info(f'Writing train')
    train_indices = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(split_idx_path, 'train_indices.txt'))
    print(train_indices)
    train_dataset = torch.utils.data.Subset(lmdb_ds, train_indices)
    da.make_lmdb_dataset(train_dataset, os.path.join(split_data_path, 'train'))

    logger.info(f'Writing val')
    val_indices = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(split_idx_path, 'val_indices.txt'))
    val_dataset = torch.utils.data.Subset(lmdb_ds, val_indices)
    da.make_lmdb_dataset(val_dataset, os.path.join(split_data_path, 'val'))

    logger.info(f'Writing test')
    test_indices = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(split_idx_path, 'test_indices.txt'))
    test_dataset = torch.utils.data.Subset(lmdb_ds, test_indices)
    da.make_lmdb_dataset(test_dataset, os.path.join(split_data_path, 'test'))