Example #1
import contextlib
import importlib
import importlib.util
from pathlib import Path

# `ar` is the scores helper imported at the top of the source module
# (e.g. `import atom3d.util.rosetta as ar` in ATOM3D).


    def __init__(self, file_list, transform=None):
        """Construct a dataset backed by Rosetta silent files.

        Args:
            file_list: paths of the silent files to read.
            transform: optional callable applied to each retrieved example.
        """

        # Fail fast if PyRosetta is not installed.
        if importlib.util.find_spec("pyrosetta") is None:
            raise RuntimeError(
                'Need to install pyrosetta to process silent files.')

        # Import PyRosetta lazily, discarding the banner it prints on import
        # and on init ("-mute all" silences its Rosetta-side logging).
        with contextlib.redirect_stdout(None):
            self.pyrosetta = importlib.import_module('pyrosetta')
            self.pyrpose = importlib.import_module(
                'pyrosetta.rosetta.core.pose')
            self.pyrps = importlib.import_module(
                'pyrosetta.rosetta.core.import_pose.pose_stream')
            self.pyrosetta.init("-mute all")

        # Resolve every input path and load the per-structure scores.
        self._file_list = [Path(x).absolute() for x in file_list]
        self._scores = ar.Scores(self._file_list)
        self._transform = transform

        self._num_examples = len(self._scores)
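
The guard-then-import pattern above (probe with importlib.util.find_spec, then import inside a redirected stdout) works for any heavyweight optional dependency. Below is a minimal standalone sketch of the same idea; the helper name load_optional and the use of numpy as the stand-in dependency are illustrative, not part of the snippet above:

import contextlib
import importlib
import importlib.util
import io

def load_optional(module_name):
    """Import `module_name` if installed, discarding anything it prints."""
    if importlib.util.find_spec(module_name) is None:
        raise RuntimeError(f'Need to install {module_name} first.')
    # Capture chatty import-time output instead of letting it reach the console.
    with contextlib.redirect_stdout(io.StringIO()):
        return importlib.import_module(module_name)

np = load_optional('numpy')  # raises RuntimeError if numpy is missing
print(np.__name__)

Redirecting to io.StringIO() rather than None keeps sys.stdout usable for any code that writes to it directly during import.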
Example #2
import collections as col
import logging
import os
import random
import sys

import torch.utils.data

# `ar`, `da`, `fi`, `fo`, `get_target`, and the TRAIN/VAL/TEST target lists
# come from elsewhere in the source module (in ATOM3D these are the dataset
# and file-utility helpers, e.g. `import atom3d.datasets as da`).

logger = logging.getLogger(__name__)


def prepare(input_file_path, output_root, score_path, structures_per_rna):
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s %(process)d: ' +
                        '%(message)s',
                        level=logging.INFO)

    # Assume PDB filetype.
    filetype = 'pdb'

    # Optional per-structure scores, attached to the dataset as its transform.
    scores = ar.Scores(score_path) if score_path else None

    logger.info('Splitting indices')
    file_list = fi.find_files(input_file_path, fo.patterns[filetype])
    random.shuffle(file_list)
    # After shuffling, keep at most `structures_per_rna` structures per
    # target, so the cap selects a random subset for each RNA.
    target_indices = col.defaultdict(list)
    for i, f in enumerate(file_list):
        target = get_target(f)
        if len(target_indices[target]) >= structures_per_rna:
            continue
        target_indices[target].append(i)

    dataset = da.load_dataset(file_list, filetype, transform=scores)

    logger.info('Writing train')
    train_indices = [i for target in TRAIN for i in target_indices[target]]
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'))

    logger.info('Writing val')
    val_indices = [i for target in VAL for i in target_indices[target]]
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'))

    logger.info('Writing test')
    test_indices = [i for target in TEST for i in target_indices[target]]
    test_dataset = torch.utils.data.Subset(dataset, test_indices)
    da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'))
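
Stripped of the dataset plumbing, prepare boils down to a per-target cap followed by index-based splits. A self-contained sketch of just that logic on toy data; split_by_target and the toy filenames are hypothetical names for illustration:

import collections
import random

def split_by_target(files, cap, get_target, splits):
    """Cap examples per target, then bucket indices by split membership."""
    random.shuffle(files)
    per_target = collections.defaultdict(list)
    for i, name in enumerate(files):
        target = get_target(name)
        if len(per_target[target]) < cap:  # keep at most `cap` per target
            per_target[target].append(i)
    # One index list per split, in the order its targets are listed.
    return {split: [i for t in targets for i in per_target[t]]
            for split, targets in splits.items()}

files = [f'T{n}_model_{k}.pdb' for n in (1, 2, 3) for k in range(5)]
splits = {'train': ['T1', 'T2'], 'val': ['T3']}
print(split_by_target(files, cap=2,
                      get_target=lambda f: f.split('_')[0],
                      splits=splits))
# e.g. {'train': [0, 3, 1, 2], 'val': [4, 6]}  (indices vary with the shuffle)

Returning indices rather than the files themselves matches the snippet above, where the indices feed torch.utils.data.Subset over a dataset built from the same shuffled file_list.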