def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume PDB filetype.
    filetype = 'pdb'

    file_list = fi.find_files(input_file_path, fo.patterns[filetype])

    lmdb_path = os.path.join(output_root, 'all')
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(file_list, filetype, transform=pdb_id_transform)
    da.make_lmdb_dataset(dataset, lmdb_path)

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        # Read the set of target names for this split, one per line.
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Ids are serialized tuples; keep an id if its first element (the target)
        # is in the desired target split set.
        split_ids = list(filter(lambda id: eval(id)[0] in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file.
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join([str(i) for i in split_indices]))

    _write_split_indices(train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    _write_split_indices(val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    _write_split_indices(test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))
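# Minimal usage sketch for prepare() above; all paths are hypothetical. The
# split txt files are assumed to list one target name per line, since
# _write_split_indices compares eval(id)[0] against that set:
#
#     prepare('/path/to/pdb_dir', '/path/to/output_root', split=True,
#             train_txt='/path/to/train_targets.txt',
#             val_txt='/path/to/val_targets.txt',
#             test_txt='/path/to/test_targets.txt')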
def _load_datasets(self, input_file_path, pdbcodes):
    protein_list = []
    pocket_list = []
    ligand_list = []
    for pdbcode in pdbcodes:
        protein_path = os.path.join(input_file_path, f'{pdbcode:}/{pdbcode:}_protein.pdb')
        pocket_path = os.path.join(input_file_path, f'{pdbcode:}/{pdbcode:}_pocket.pdb')
        ligand_path = os.path.join(input_file_path, f'{pdbcode:}/{pdbcode:}_ligand.sdf')
        # Only keep entries for which all three files are present.
        if os.path.exists(protein_path) and os.path.exists(pocket_path) and \
                os.path.exists(ligand_path):
            protein_list.append(protein_path)
            pocket_list.append(pocket_path)
            ligand_list.append(ligand_path)

    assert len(protein_list) == len(pocket_list) == len(ligand_list)
    logger.info(f'Found {len(protein_list):} protein/ligand files...')

    self._protein_dataset = da.load_dataset(
        protein_list, 'pdb', transform=SequenceReader(input_file_path))
    self._pocket_dataset = da.load_dataset(pocket_list, 'pdb', transform=None)
    self._ligand_dataset = da.load_dataset(
        ligand_list, 'sdf', include_bonds=True, transform=SmilesReader())
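# For reference, _load_datasets above assumes one subdirectory per PDB code
# under input_file_path, in a PDBBind-style layout (the codes below are
# hypothetical):
#
#     <input_file_path>/
#         1abc/
#             1abc_protein.pdb
#             1abc_pocket.pdb
#             1abc_ligand.sdf
#         2xyz/
#             ...
#
# Entries missing any of the three files are silently skipped.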
def split_lmdb_dataset(lmdb_path, train_txt, val_txt, test_txt, split_dir):
    logger.info(f'Splitting indices, load data from {lmdb_path:}...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        # Read list of desired <target, decoy> pairs.
        split_set = set(map(tuple, pd.read_csv(split_txt, header=None, dtype=str).values))
        # Check if the <target, decoy> id is in the desired split set.
        split_ids = list(filter(lambda id: eval(id) in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file.
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join([str(i) for i in split_indices]))
        return split_indices

    logger.info(f'Write results to {split_dir:}...')
    os.makedirs(os.path.join(split_dir, 'indices'), exist_ok=True)
    os.makedirs(os.path.join(split_dir, 'data'), exist_ok=True)

    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(split_dir, 'indices/train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(split_dir, 'indices/val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(split_dir, 'indices/test_indices.txt'))

    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset, os.path.join(split_dir, 'data/train'))
    da.make_lmdb_dataset(val_dataset, os.path.join(split_dir, 'data/val'))
    da.make_lmdb_dataset(test_dataset, os.path.join(split_dir, 'data/test'))
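# For reference, the split txt files consumed by this split_lmdb_dataset are
# parsed with pd.read_csv(header=None), i.e. plain two-column CSV rows of
# <target>,<decoy> pairs; each row is matched as a tuple against eval(id).
# A hypothetical example file:
#
#     targetA,decoy_001
#     targetA,decoy_002
#     targetB,decoy_001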
def split_lmdb_dataset(lmdb_path, train_txt, val_txt, test_txt, split_dir):
    logger.info(f'Splitting indices, load data from {lmdb_path:}...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        # Read the set of ids for this split, one id per line.
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the id is in the desired split set.
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file.
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join([str(i) for i in split_indices]))
        return split_indices

    logger.info(f'Write results to {split_dir:}...')
    os.makedirs(os.path.join(split_dir, 'indices'), exist_ok=True)
    os.makedirs(os.path.join(split_dir, 'data'), exist_ok=True)

    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(split_dir, 'indices/train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(split_dir, 'indices/val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(split_dir, 'indices/test_indices.txt'))

    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset, os.path.join(split_dir, 'data/train'))
    da.make_lmdb_dataset(val_dataset, os.path.join(split_dir, 'data/val'))
    da.make_lmdb_dataset(test_dataset, os.path.join(split_dir, 'data/test'))
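# Unlike the <target, decoy> variant above, this split_lmdb_dataset matches
# whole ids: each split txt file is assumed to hold one lmdb id per line,
# compared directly against lmdb_ds.ids(). Hypothetical example file:
#
#     id_0001
#     id_0002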
def _process_chunk(file_list, filetype, lmdb_path, balance):
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    if not os.path.exists(lmdb_path):
        os.makedirs(lmdb_path)
    dataset = da.load_dataset(file_list, filetype, transform=ResTransform(balance=balance))
    da.make_lmdb_dataset(dataset, lmdb_path)
def _load_active_and_inactive_datasets(self, input_file_path, id_codes):
    A_list = []  # active conformations
    I_list = []  # inactive conformations
    for code in id_codes:
        # Ids are of the form <ligand>__<active pdb>__<inactive pdb>.
        tokens = code.split('__')
        ligand = tokens[0]
        pdb1 = tokens[1]
        pdb2 = tokens[2]
        A_path = os.path.join(input_file_path, f'{ligand}_to_{pdb1}.pdb')
        I_path = os.path.join(input_file_path, f'{ligand}_to_{pdb2}.pdb')
        # Only keep the pair if both structure files exist.
        if os.path.exists(A_path) and os.path.exists(I_path):
            A_list.append(A_path)
            I_list.append(I_path)

    assert len(A_list) == len(I_list)
    logger.info(f'Found {len(A_list):} pairs of protein files...')

    self._active_dataset = da.load_dataset(A_list, 'pdb')
    self._inactive_dataset = da.load_dataset(I_list, 'pdb')
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
    # Logger
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume GDB-specific version of the XYZ format.
    filetype = 'xyz-gdb'

    # Compile a list of the input files.
    file_list = fi.find_files(input_file_path, fo.patterns[filetype])

    # Write the LMDB dataset.
    lmdb_path = os.path.join(output_root, 'all')
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(
        file_list, filetype, transform=_add_data_with_subtracted_thermochem_energy)
    da.make_lmdb_dataset(dataset, lmdb_path, filter_fn=bond_filter)

    # Only continue if we want to write split datasets.
    if not split:
        return

    logger.info(f'Splitting indices...\n')

    # Load the dataset that has just been created.
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    # Determine and write out the split indices.
    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))

    # Write the split datasets.
    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'),
                         filter_fn=bond_filter)
    da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'),
                         filter_fn=bond_filter)
    da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'),
                         filter_fn=bond_filter)
def prepare(input_file_path, output_root, split, balance, train_txt, val_txt, test_txt,
            num_threads, start):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    def _process_chunk(file_list, filetype, lmdb_path, balance):
        logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
        if not os.path.exists(lmdb_path):
            os.makedirs(lmdb_path)
        dataset = da.load_dataset(file_list, filetype,
                                  transform=ResTransform(balance=balance))
        da.make_lmdb_dataset(dataset, lmdb_path)

    # Assume PDB filetype.
    filetype = 'pdb'

    file_list = fi.find_files(input_file_path, fo.patterns[filetype])

    lmdb_path = os.path.join(output_root, 'all')
    if not os.path.exists(lmdb_path):
        os.makedirs(lmdb_path)

    # Split the file list into one chunk per thread and write each chunk into
    # its own temporary lmdb dataset at <lmdb_path>_tmp_<i>.
    chunk_size = (len(file_list) // num_threads) + 1
    chunks = [file_list[i:i + chunk_size] for i in range(0, len(file_list), chunk_size)]
    assert len(chunks) == num_threads
    for i in range(start, num_threads):
        logger.info(f'Processing chunk {i:}...')
        _process_chunk(chunks[i], 'pdb', f'{lmdb_path}_tmp_{i}', balance)

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the id is in the desired split set.
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file.
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join([str(i) for i in split_indices]))

    _write_split_indices(train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    _write_split_indices(val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    _write_split_indices(test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))
def make_lmdb_dataset(input_file_path, score_path, output_root):
    # Assume PDB filetype.
    filetype = 'pdb'

    scores = Scores(score_path) if score_path else None

    file_list = fi.find_files(input_file_path, fo.patterns[filetype])

    lmdb_path = os.path.join(output_root, 'data')
    os.makedirs(lmdb_path, exist_ok=True)

    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(file_list, filetype, transform=scores)
    da.make_lmdb_dataset(dataset, lmdb_path)
    return lmdb_path
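# Minimal sketch of how make_lmdb_dataset() can feed split_lmdb_dataset()
# (defined above); all paths are hypothetical:
#
#     lmdb_path = make_lmdb_dataset('/path/to/pdbs', '/path/to/scores.csv',
#                                   '/path/to/output_root')
#     split_lmdb_dataset(lmdb_path, '/path/to/train.txt', '/path/to/val.txt',
#                        '/path/to/test.txt', '/path/to/output_root/split')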
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt, score_path):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    scores = Scores(score_path) if score_path else None

    # Assume subdirectories containing the protein/pocket/ligand files are
    # structured as <input_file_path>/<pdbcode>.
    pdbcodes = os.listdir(input_file_path)

    lmdb_path = os.path.join(output_root, 'all')
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = LBADataset(input_file_path, pdbcodes, transform=scores)
    da.make_lmdb_dataset(dataset, lmdb_path)

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the pdbcode in the id is in the desired pdbcode split set.
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file.
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        with open(output_txt, 'w') as f:
            f.write('\n'.join([str(i) for i in split_indices]))
        return split_indices

    indices_train = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(output_root, 'train_indices.txt'))
    indices_val = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(output_root, 'val_indices.txt'))
    indices_test = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))

    train_dataset, val_dataset, test_dataset = spl.split(
        lmdb_ds, indices_train, indices_val, indices_test)
    da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'))
    da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'))
    da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'))
def split(in_path, output_root, train_txt, val_txt, test_txt):
    dataset = da.load_dataset(in_path, 'lmdb')

    logger.info(f'Writing train')
    train_indices = read_split_file(train_txt)
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'))

    logger.info(f'Writing val')
    val_indices = read_split_file(val_txt)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'))

    logger.info(f'Writing test')
    test_indices = read_split_file(test_txt)
    test_dataset = torch.utils.data.Subset(dataset, test_indices)
    da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'))
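# read_split_file() is used above but not shown here. A minimal sketch of the
# assumed behavior (one integer lmdb index per line, as written by the
# _write_split_indices helpers elsewhere in this collection); the real helper
# may live in a shared utility module:
def read_split_file(split_file):
    """Read a split file containing one numerical index per line."""
    with open(split_file, 'r') as f:
        return [int(x.strip()) for x in f if x.strip()]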
def main(input_dir, output_lmdb, filetype, score_path, serialization_format):
    """Script wrapper around make_lmdb_dataset to create an LMDB dataset."""
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    logger.info(f'filetype: {filetype}')

    # The GDB flavour of the XYZ format still uses the .xyz file extension.
    if filetype == 'xyz-gdb':
        fileext = 'xyz'
    else:
        fileext = filetype

    file_list = da.get_file_list(input_dir, fileext)
    logger.info(f'Found {len(file_list)} files.')

    dataset = da.load_dataset(file_list, filetype)
    da.make_lmdb_dataset(
        dataset, output_lmdb, serialization_format=serialization_format)
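# Hypothetical direct invocation of main() above; paths and argument values
# are illustrative only ('json' is assumed to be an accepted serialization
# format, and score_path is accepted but unused in this wrapper):
#
#     main('/path/to/xyz_files', '/path/to/output_lmdb', 'xyz-gdb',
#          None, 'json')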
def prepare(input_file_path, output_root, score_path, structures_per_rna):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume PDB filetype.
    filetype = 'pdb'

    scores = ar.Scores(score_path) if score_path else None

    logger.info(f'Splitting indices')
    file_list = fi.find_files(input_file_path, fo.patterns[filetype])
    random.shuffle(file_list)

    # Keep at most structures_per_rna shuffled structures per target.
    target_indices = col.defaultdict(list)
    for i, f in enumerate(file_list):
        target = get_target(f)
        if len(target_indices[target]) >= structures_per_rna:
            continue
        target_indices[target].append(i)

    dataset = da.load_dataset(file_list, filetype, transform=scores)

    logger.info(f'Writing train')
    train_indices = [i for target in TRAIN for i in target_indices[target]]
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'))

    logger.info(f'Writing val')
    val_indices = [i for target in VAL for i in target_indices[target]]
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'))

    logger.info(f'Writing test')
    test_indices = [i for target in TEST for i in target_indices[target]]
    test_dataset = torch.utils.data.Subset(dataset, test_indices)
    da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'))
def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume PDB filetype.
    filetype = 'pdb'

    file_list = fi.find_files(os.path.join(input_file_path, 'mutated'), fo.patterns[filetype])
    transform = MSPTransform(base_file_dir=input_file_path)

    lmdb_path = os.path.join(output_root, 'raw', 'MSP', 'data')
    if not os.path.exists(lmdb_path):
        os.makedirs(lmdb_path)
    logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
    dataset = da.load_dataset(file_list, filetype, transform=transform)
    da.make_lmdb_dataset(dataset, lmdb_path)

    if not split:
        return

    logger.info(f'Splitting indices...')
    lmdb_ds = da.load_dataset(lmdb_path, 'lmdb')

    split_data_path = os.path.join(output_root, 'splits', 'split-by-seqid30', 'data')
    split_idx_path = os.path.join(output_root, 'splits', 'split-by-seqid30', 'indices')
    if not os.path.exists(split_data_path):
        os.makedirs(split_data_path)
    if not os.path.exists(split_idx_path):
        os.makedirs(split_idx_path)

    def _write_split_indices(split_txt, lmdb_ds, output_txt):
        with open(split_txt, 'r') as f:
            split_set = set([x.strip() for x in f.readlines()])
        # Check if the id is in the desired split set.
        split_ids = list(filter(lambda id: id in split_set, lmdb_ds.ids()))
        # Convert ids into lmdb numerical indices and write into txt file.
        split_indices = lmdb_ds.ids_to_indices(split_ids)
        str_indices = [str(i) for i in split_indices]
        with open(output_txt, 'w') as f:
            f.write('\n'.join(str_indices))
        return split_indices

    logger.info(f'Writing train')
    train_indices = _write_split_indices(
        train_txt, lmdb_ds, os.path.join(split_idx_path, 'train_indices.txt'))
    train_dataset = torch.utils.data.Subset(lmdb_ds, train_indices)
    da.make_lmdb_dataset(train_dataset, os.path.join(split_data_path, 'train'))

    logger.info(f'Writing val')
    val_indices = _write_split_indices(
        val_txt, lmdb_ds, os.path.join(split_idx_path, 'val_indices.txt'))
    val_dataset = torch.utils.data.Subset(lmdb_ds, val_indices)
    da.make_lmdb_dataset(val_dataset, os.path.join(split_data_path, 'val'))

    logger.info(f'Writing test')
    test_indices = _write_split_indices(
        test_txt, lmdb_ds, os.path.join(split_idx_path, 'test_indices.txt'))
    test_dataset = torch.utils.data.Subset(lmdb_ds, test_indices)
    da.make_lmdb_dataset(test_dataset, os.path.join(split_data_path, 'test'))