Example #1
def modisco_enrich_patterns(patterns_pkl_file,
                            modisco_dir,
                            output_file,
                            impsf=None):
    """Add stacked_seqlet_imp to pattern `attrs`

    Args:
      patterns_pkl_file: path to the patterns.pkl file
      modisco_dir: modisco directory containing modisco.h5
      output_file: output file path for the enriched patterns.pkl
      impsf: optional pre-loaded ImpScoreFile; if None, it is loaded
        from `modisco_dir`
    """
    from pathlib import Path
    from tqdm import tqdm
    from basepair.utils import read_pkl, write_pkl
    from basepair.cli.imp_score import ImpScoreFile
    from basepair.modisco.core import StackedSeqletImp
    # `logger` and `ModiscoResult` are assumed to be available at module level

    logger.info("Loading patterns")
    modisco_dir = Path(modisco_dir)
    patterns = read_pkl(patterns_pkl_file)

    mr = ModiscoResult(modisco_dir / 'modisco.h5')
    mr.open()

    if impsf is None:
        imp_file = ImpScoreFile.from_modisco_dir(modisco_dir)
        logger.info("Loading ImpScoreFile into memory")
        imp_file.cache()
    else:
        logger.info("Using the provided ImpScoreFile")
        imp_file = impsf

    logger.info("Extracting profile and importance scores")
    extended_patterns = []
    for p in tqdm(patterns):
        p = p.copy()
        profile_width = p.len_profile()
        # get the shifted seqlets
        seqlets = [
            s.pattern_align(**p.attrs['align'])
            for s in mr._get_seqlets(p.name)
        ]

        # keep only valid seqlets
        valid_seqlets = [
            s for s in seqlets if s.valid_resize(profile_width,
                                                 imp_file.get_seqlen() + 1)
        ]
        # extract the importance scores
        p.attrs['stacked_seqlet_imp'] = imp_file.extract(
            valid_seqlets, profile_width=profile_width)

        p.attrs['n_seqlets'] = mr.n_seqlets(*p.name.split("/"))
        extended_patterns.append(p)

    write_pkl(extended_patterns, output_file)
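
A minimal usage sketch for modisco_enrich_patterns; the paths below are hypothetical and assume a finished modisco run directory containing modisco.h5:

# Enrich the discovered patterns with stacked seqlet importance scores.
# All paths are placeholders.
modisco_enrich_patterns(
    patterns_pkl_file="modisco_run/patterns.pkl",
    modisco_dir="modisco_run",
    output_file="modisco_run/patterns.enriched.pkl",
)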
Example #2
def load_data(model_dir, cache_data=False,
              dataspec=None, hparams=None, data=None, preprocessor=None):
    """Load the (train, valid, test) splits for a model directory.

    Splits are read from a cached `data.pkl` if present; otherwise they are
    regenerated with `chip_exo_nexus` and the targets are transformed with
    the stored preprocessor.
    """
    if dataspec is not None:
        ds = DataSpec.load(dataspec)
    else:
        ds = DataSpec.load(os.path.join(model_dir, "dataspec.yaml"))
    if hparams is not None:
        hp = HParams.load(hparams)
    else:
        hp = HParams.load(os.path.join(model_dir, "hparams.yaml"))
    if data is not None:
        data_path = data
    else:
        data_path = os.path.join(model_dir, 'data.pkl')
    if os.path.exists(data_path):
        train, valid, test = read_pkl(data_path)
    else:
        train, valid, test = chip_exo_nexus(ds,
                                            peak_width=hp.data.peak_width,
                                            shuffle=hp.data.shuffle,
                                            valid_chr=hp.data.valid_chr,
                                            test_chr=hp.data.test_chr)

        # Pre-process the data
        logger.info("Pre-processing the data")
        # Use the explicitly provided preprocessor path if given, otherwise
        # fall back to the one stored in the model directory
        if preprocessor is not None:
            preproc_path = preprocessor
        else:
            preproc_path = os.path.join(model_dir, "preprocessor.pkl")
        preproc = read_pkl(preproc_path)
        train[1] = preproc.transform(train[1])
        valid[1] = preproc.transform(valid[1])
        try:
            test[1] = preproc.transform(test[1])
        except Exception:
            logger.warning("Test set couldn't be processed")
            test = None
    return train, valid, test
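
A usage sketch for load_data, assuming a model directory holding dataspec.yaml and hparams.yaml plus an optional cached data.pkl; the directory name is a placeholder:

# Reads the cached splits if data.pkl exists, otherwise regenerates and
# preprocesses them.
train, valid, test = load_data(model_dir="runs/my_model")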
Example #3
 @classmethod
 def from_mdir(cls, model_dir):
     """
     Args:
       model_dir (str): Path to the model directory
     """
     import os
     from basepair.cli.schemas import DataSpec
     from keras.models import load_model
     from basepair.utils import read_pkl
     ds = DataSpec.load(os.path.join(model_dir, "dataspec.yaml"))
     model = load_model(os.path.join(model_dir, "model.h5"))
     preproc_file = os.path.join(model_dir, "preprocessor.pkl")
     if os.path.exists(preproc_file):
         preproc = read_pkl(preproc_file)
     else:
         preproc = None
     return cls(model=model, fasta_file=ds.fasta_file, tasks=list(ds.task_specs), preproc=preproc)
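
A usage sketch, assuming from_mdir is a classmethod on a model-wrapper class; the class name and path below are hypothetical:

# Restore the Keras model, the DataSpec-derived fasta file and task list,
# and the optional preprocessor from a trained model directory.
predictor = ModelWrapper.from_mdir("runs/my_model")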
Example #4
 @classmethod
 def load(cls, file_path):
     """Load model from a file
     """
     from basepair.utils import read_pkl
     # TODO - update to json
     return read_pkl(file_path)
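
A usage sketch, assuming load is a classmethod on the class whose instance was previously pickled; the class name and path are hypothetical:

# Deserialize a previously pickled instance.
obj = SomeConfig.load("runs/my_model/config.pkl")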