Example #1
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume PDB filetype.
    filetype = 'pdb'

    file_list = fi.find_files(input_file_path, fo.patterns[filetype])

    lmdb_path = os.path.join(output_root, 'all')
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(file_list, filetype, transform=pdb_id_transform)
    da.make_lmdb_dataset(dataset, lmdb_path)

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the target in id is in the desired target split set
        split_ids = list(filter(lambda id: eval(id)[0] in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))

    _write_split_indices(train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    _write_split_indices(val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    _write_split_indices(test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))
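
The `pdb_id_transform` used above is not shown. As a hedged sketch only (not the original helper), a transform that stores each id as the string form of a `(target, decoy)` tuple, which is what `eval(id)[0]` in `_write_split_indices` expects, could look like this, assuming each dataset item carries a `file_path` key and paths follow a `<target>/<decoy>.pdb` layout:

import os

def pdb_id_transform(item):
    # Hypothetical sketch: derive (target, decoy) from a path assumed to look
    # like .../<target>/<decoy>.pdb; the real transform may differ.
    path = item['file_path']
    decoy = os.path.splitext(os.path.basename(path))[0]
    target = os.path.basename(os.path.dirname(path))
    # Store the id as the string form of a tuple so eval(id)[0] recovers the target.
    item['id'] = str((target, decoy))
    return item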
Example #2
    def _load_datasets(self, input_file_path, pdbcodes):
        protein_list = []
        pocket_list = []
        ligand_list = []
        for pdbcode in pdbcodes:
            protein_path = os.path.join(input_file_path,
                                        f'{pdbcode:}/{pdbcode:}_protein.pdb')
            pocket_path = os.path.join(input_file_path,
                                       f'{pdbcode:}/{pdbcode:}_pocket.pdb')
            ligand_path = os.path.join(input_file_path,
                                       f'{pdbcode:}/{pdbcode:}_ligand.sdf')
            if os.path.exists(protein_path) and os.path.exists(pocket_path) and \
                    os.path.exists(ligand_path):
                protein_list.append(protein_path)
                pocket_list.append(pocket_path)
                ligand_list.append(ligand_path)
        assert len(protein_list) == len(pocket_list) == len(ligand_list)
        logger.info(f'Found {len(protein_list):} protein/ligand files...')

        self._protein_dataset = da.load_dataset(
            protein_list, 'pdb', transform=SequenceReader(input_file_path))
        self._pocket_dataset = da.load_dataset(pocket_list,
                                               'pdb',
                                               transform=None)
        self._ligand_dataset = da.load_dataset(ligand_list,
                                               'sdf',
                                               include_bonds=True,
                                               transform=SmilesReader())
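
`SequenceReader` and `SmilesReader` are defined outside this snippet. A minimal, hypothetical `SmilesReader`, assuming each ligand item exposes a `file_path` key and that RDKit is available, might be:

from rdkit import Chem

class SmilesReader:
    # Hedged sketch only; the source's SmilesReader may work differently.
    def __call__(self, item):
        mol = Chem.MolFromMolFile(item['file_path'], sanitize=False)
        item['smiles'] = Chem.MolToSmiles(mol) if mol is not None else None
        return item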
Example #3
def split_lmdb_dataset(lmdb_path, train_txt, val_txt, test_txt, split_dir):
    logger.info(f'Splitting indices, load data from {lmdb_path:}...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        # Read list of desired <target, decoy>
        split_set = set(map(tuple, pd.read_csv(split_txt, header=None, dtype=str).values))

        # Check if the <target, decoy> id is in the desired split set
        split_ids = list(filter(lambda id: eval(id) in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))
        return split_indices

    logger.info(f'Write results to {split_dir:}...')
    os.makedirs(os.path.join(split_dir, 'indices'), exist_ok=True)
    os.makedirs(os.path.join(split_dir, 'data'), exist_ok=True)

    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(split_dir, 'indices/train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(split_dir, 'indices/val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(split_dir, 'indices/test_indices.txt'))

    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset, os.path.join(split_dir, 'data/train'))
    da.make_lmdb_dataset(val_dataset, os.path.join(split_dir, 'data/val'))
    da.make_lmdb_dataset(test_dataset, os.path.join(split_dir, 'data/test'))
Example #4
def split_lmdb_dataset(lmdb_path, train_txt, val_txt, test_txt, split_dir):
    logger.info(f'Splitting indices, load data from {lmdb_path:}...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])

        # Check if the target in id is in the desired target split set
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))
        return split_indices

    logger.info(f'Write results to {split_dir:}...')
    os.makedirs(os.path.join(split_dir, 'indices'), exist_ok=True)
    os.makedirs(os.path.join(split_dir, 'data'), exist_ok=True)

    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(split_dir,
                                         'indices/train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(split_dir, 'indices/val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(split_dir, 'indices/test_indices.txt'))

    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset, os.path.join(split_dir, 'data/train'))
    da.make_lmdb_dataset(val_dataset, os.path.join(split_dir, 'data/val'))
    da.make_lmdb_dataset(test_dataset, os.path.join(split_dir, 'data/test'))
Example #5
def _process_chunk(file_list, filetype, lmdb_path, balance):
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    if not os.path.exists(lmdb_path):
        os.makedirs(lmdb_path)
    dataset = da.load_dataset(file_list,
                              filetype,
                              transform=ResTransform(balance=balance))
    da.make_lmdb_dataset(dataset, lmdb_path)
Example #6
    def _load_active_and_inactive_datasets(self, input_file_path, id_codes):
        A_list = []  # Active conformations
        I_list = []  # Inactive conformations
        for code in id_codes:
            # id codes are formatted as <ligand>__<pdb1>__<pdb2>
            ligand, pdb1, pdb2 = code.split('__')
            A_path = os.path.join(input_file_path, f'{ligand}_to_{pdb1}.pdb')
            I_path = os.path.join(input_file_path, f'{ligand}_to_{pdb2}.pdb')
            
            if os.path.exists(A_path) and os.path.exists(I_path):
                A_list.append(A_path)
                I_list.append(I_path)

        assert len(A_list) == len(I_list)
        logger.info(f'Found {len(A_list):} pairs of protein files...')
        self._active_dataset = da.load_dataset(A_list, 'pdb')
        self._inactive_dataset = da.load_dataset(I_list, 'pdb')
Example #7
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
    # Logger
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)
    # Assume GDB-specific version of XYZ format.
    filetype = 'xyz-gdb'
    # Compile a list of the input files
    file_list = fi.find_files(input_file_path, fo.patterns[filetype])
    # Write the LMDB dataset
    lmdb_path = os.path.join(output_root, 'all')
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(
        file_list,
        filetype,
        transform=_add_data_with_subtracted_thermochem_energy)
    da.make_lmdb_dataset(dataset, lmdb_path, filter_fn=bond_filter)
    # Only continue if we want to write split datasets
    if not split:
        return
    logger.info(f'Splitting indices...\n')
    # Load the dataset that has just been created
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')
    # Determine and write out the split indices
    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))
    # Write the split datasets
    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset,
                         os.path.join(output_root, 'train'),
                         filter_fn=bond_filter)
    da.make_lmdb_dataset(val_dataset,
                         os.path.join(output_root, 'val'),
                         filter_fn=bond_filter)
    da.make_lmdb_dataset(test_dataset,
                         os.path.join(output_root, 'test'),
                         filter_fn=bond_filter)
Example #8
def prepare(input_file_path, output_root, split, balance, train_txt, val_txt, test_txt, num_threads, start):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    def _process_chunk(file_list, filetype, lmdb_path, balance):
        logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
        if not os.path.exists(lmdb_path):
            os.makedirs(lmdb_path)
        dataset = da.load_dataset(file_list, filetype, transform=ResTransform(balance=balance))
        da.make_lmdb_dataset(dataset, lmdb_path)

    # Assume PDB filetype.
    filetype = 'pdb'

    file_list = fi.find_files(input_file_path, fo.patterns[filetype])
    
    lmdb_path = os.path.join(output_root, 'all')
    if not os.path.exists(lmdb_path):
        os.makedirs(lmdb_path)
        
    # dataset = da.load_dataset(file_list, filetype, transform=ResTransform(balance=balance))
    # da.make_lmdb_dataset(dataset, lmdb_path)
    
    # Split the file list into num_threads contiguous chunks, padding with
    # empty chunks so that chunks[i] is always defined even when the files do
    # not divide evenly across threads.
    chunk_size = max(1, (len(file_list) + num_threads - 1) // num_threads)
    chunks = [file_list[i:i + chunk_size] for i in range(0, len(file_list), chunk_size)]
    chunks += [[] for _ in range(num_threads - len(chunks))]
    assert len(chunks) == num_threads
    
    for i in range(start, num_threads):
        logger.info(f'Processing chunk {i:}...')
        _process_chunk(chunks[i], 'pdb', f'{lmdb_path}_tmp_{i}', balance)
        

    if not split:
        return

    logger.info(f'Splitting indices...')
    # Note: the per-chunk LMDBs written above still need to be merged into
    # lmdb_path before this load can succeed.
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the target in id is in the desired target split set
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))

    _write_split_indices(train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    _write_split_indices(val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    _write_split_indices(test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))
Example #9
def make_lmdb_dataset(input_file_path, score_path, output_root):
    # Assume PDB filetype.
    filetype = 'pdb'

    scores = Scores(score_path) if score_path else None

    file_list = fi.find_files(input_file_path, fo.patterns[filetype])

    lmdb_path = os.path.join(output_root, 'data')
    os.makedirs(lmdb_path, exist_ok=True)

    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(file_list, filetype, transform=scores)
    da.make_lmdb_dataset(dataset, lmdb_path)
    return lmdb_path
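
The `Scores` transform passed above is defined elsewhere. As a rough sketch, assuming the labels are stored in a CSV with `id` and `label` columns and that each dataset item has an `id` key, it could be implemented along these lines:

import pandas as pd

class Scores:
    # Hypothetical sketch; the actual Scores class in the source may differ.
    def __init__(self, score_path):
        self._scores = pd.read_csv(score_path).set_index('id')['label'].to_dict()

    def __call__(self, item):
        item['scores'] = self._scores.get(item['id'])
        return item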
Example #10
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt,
            score_path):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    scores = Scores(score_path) if score_path else None

    # Assume subdirectories containing the protein/pocket/ligand files are
    # structured as <input_file_path>/<pdbcode>
    pdbcodes = os.listdir(input_file_path)

    lmdb_path = os.path.join(output_root, 'all')
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')

    dataset = LBADataset(input_file_path, pdbcodes, transform=scores)
    da.make_lmdb_dataset(dataset, lmdb_path)

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the pdbcode in id is in the desired pdbcode split set
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))
        return split_indices

    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))

    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'))
    da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'))
    da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'))
Example #11
def split(in_path, output_root, train_txt, val_txt, test_txt):
    dataset = da.load_dataset(in_path, 'lmdb')

    logger.info(f'Writing train')
    train_indices = read_split_file(train_txt)
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'))

    logger.info(f'Writing val')
    val_indices = read_split_file(val_txt)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'))

    logger.info(f'Writing test')
    test_indices = read_split_file(test_txt)
    test_dataset = torch.utils.data.Subset(dataset, test_indices)
    da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'))
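
`read_split_file` is assumed here; a minimal sketch that matches what `_write_split_indices` writes in the other examples (one integer LMDB index per line) would be:

def read_split_file(split_txt):
    # Hedged sketch: parse one integer index per non-empty line.
    with open(split_txt) as f:
        return [int(line.strip()) for line in f if line.strip()]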
Example #12
def main(input_dir, output_lmdb, filetype, score_path, serialization_format):
    """Script wrapper to make_lmdb_dataset to create LMDB dataset."""
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    logger.info(f'filetype: {filetype}')
    if filetype == 'xyz-gdb':
        fileext = 'xyz'
    else:
        fileext = filetype
    file_list = da.get_file_list(input_dir, fileext)
    logger.info(f'Found {len(file_list)} files.')

    dataset = da.load_dataset(file_list, filetype)
    da.make_lmdb_dataset(
        dataset, output_lmdb, serialization_format=serialization_format)
Example #13
def prepare(input_file_path, output_root, score_path, structures_per_rna):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume PDB filetype.
    filetype = 'pdb'

    scores = ar.Scores(score_path) if score_path else None

    logger.info(f'Splitting indices')
    file_list = fi.find_files(input_file_path, fo.patterns[filetype])
    random.shuffle(file_list)
    target_indices = col.defaultdict(list)
    for i, f in enumerate(file_list):
        target = get_target(f)
        if len(target_indices[target]) >= structures_per_rna:
            continue
        target_indices[target].append(i)

    dataset = da.load_dataset(file_list, filetype, transform=scores)

    logger.info(f'Writing train')
    train_indices = [idx for target in TRAIN for idx in target_indices[target]]
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'))

    logger.info(f'Writing val')
    val_indices = [idx for target in VAL for idx in target_indices[target]]
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'))

    logger.info(f'Writing test')
    test_indices = [idx for target in TEST for idx in target_indices[target]]
    test_dataset = torch.utils.data.Subset(dataset, test_indices)
    da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'))
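
`get_target` and the `TRAIN`/`VAL`/`TEST` target lists are defined outside this snippet. A hypothetical `get_target`, assuming decoy structures are grouped in per-target subdirectories, could be:

import os

def get_target(file_path):
    # Hypothetical sketch: use the parent directory name as the RNA target id.
    return os.path.basename(os.path.dirname(file_path))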
Example #14
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume PDB filetype.
    filetype = 'pdb'

    file_list = fi.find_files(os.path.join(input_file_path, 'mutated'),
                              fo.patterns[filetype])
    transform = MSPTransform(base_file_dir=input_file_path)

    lmdb_path = os.path.join(output_root, 'raw', 'MSP', 'data')
    if not os.path.exists(lmdb_path):
        os.makedirs(lmdb_path)

    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(file_list, filetype, transform=transform)
    da.make_lmdb_dataset(dataset, lmdb_path)

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    split_data_path = os.path.join(output_root, 'splits', 'split-by-seqid30',
                                   'data')
    split_idx_path = os.path.join(output_root, 'splits', 'split-by-seqid30',
                                  'indices')
    if not os.path.exists(split_data_path):
        os.makedirs(split_data_path)
    if not os.path.exists(split_idx_path):
        os.makedirs(split_idx_path)

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the target in id is in the desired target split set
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str(i) for i in split_indices))
        return split_indices

    logger.info(f'Writing train')
    train_indices = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(split_idx_path, 'train_indices.txt'))
    train_dataset = torch.utils.data.Subset(lmdb_ds, train_indices)
    da.make_lmdb_dataset(train_dataset, os.path.join(split_data_path, 'train'))

    logger.info(f'Writing val')
    val_indices = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(split_idx_path, 'val_indices.txt'))
    val_dataset = torch.utils.data.Subset(lmdb_ds, val_indices)
    da.make_lmdb_dataset(val_dataset, os.path.join(split_data_path, 'val'))

    logger.info(f'Writing test')
    test_indices = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(split_idx_path, 'test_indices.txt'))
    test_dataset = torch.utils.data.Subset(lmdb_ds, test_indices)
    da.make_lmdb_dataset(test_dataset, os.path.join(split_data_path, 'test'))