Example #1
 @classmethod
 def from_mdir(cls, model_dir):
     import os
     from basepair.seqmodel import SeqModel
     # TODO - figure out also the fasta_file if present (from dataspec)
     from basepair.cli.schemas import DataSpec
     ds_path = os.path.join(model_dir, "dataspec.yaml")
     if os.path.exists(ds_path):
         ds = DataSpec.load(ds_path)
         fasta_file = ds.fasta_file
     else:
         fasta_file = None
     return cls(SeqModel.from_mdir(model_dir), fasta_file=fasta_file)
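A minimal usage sketch (hypothetical path; the class name BPNet is taken from Example #10, where BPNet.from_mdir is called):

# Hypothetical: assumes `models/my_run` contains dataspec.yaml plus the files
# SeqModel.from_mdir expects (the import path for BPNet is not shown above).
bpnet = BPNet.from_mdir("models/my_run")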
Example #2
def load_data(model_dir, cache_data=False,
              dataspec=None, hparams=None, data=None, preprocessor=None):
    if dataspec is not None:
        ds = DataSpec.load(dataspec)
    else:
        ds = DataSpec.load(os.path.join(model_dir, "dataspec.yaml"))
    if hparams is not None:
        hp = HParams.load(hparams)
    else:
        hp = HParams.load(os.path.join(model_dir, "hparams.yaml"))
    if data is not None:
        data_path = data
    else:
        data_path = os.path.join(model_dir, 'data.pkl')
    if os.path.exists(data_path):
        train, valid, test = read_pkl(data_path)
    else:
        train, valid, test = chip_exo_nexus(ds,
                                            peak_width=hp.data.peak_width,
                                            shuffle=hp.data.shuffle,
                                            valid_chr=hp.data.valid_chr,
                                            test_chr=hp.data.test_chr)

        # Pre-process the data
        logger.info("Pre-processing the data")
        if preprocessor is not None:
            preproc_path = preprocessor
        else:
            preproc_path = os.path.join(model_dir, "preprocessor.pkl")
        preproc = read_pkl(preproc_path)
        train[1] = preproc.transform(train[1])
        valid[1] = preproc.transform(valid[1])
        try:
            test[1] = preproc.transform(test[1])
        except Exception:
            logger.warning("Test set couldn't be processed")
            test = None
    return train, valid, test
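A usage sketch under assumed paths (the import location of load_data is not shown in the excerpt):

# Hypothetical: with dataspec=None, hparams=None and data=None, the loader
# falls back to <model_dir>/dataspec.yaml, hparams.yaml and data.pkl.
train, valid, test = load_data("models/my_run")
x_train, y_train = train[0], train[1]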
Example #3
def get_StrandedProfile_datasets(dataspec,
                                 peak_width=200,
                                 seq_width=None,
                                 shuffle=True,
                                 target_transformer=AppendCounts(),
                                 valid_chr=['chr2', 'chr3', 'chr4'],
                                 test_chr=['chr1', 'chr8', 'chr9'],
                                 all_chr=all_chr,
                                 exclude_chr=[],
                                 vmtouch=True):
    # valid and test chromosomes shouldn't be in the excluded set
    for vc in valid_chr:
        assert vc not in exclude_chr
    for vc in test_chr:
        assert vc not in exclude_chr

    if isinstance(dataspec, str):
        dataspec = DataSpec.load(dataspec)

    if vmtouch:
        # use vmtouch to load all files into memory
        dataspec.touch_all_files()

    return (
        StrandedProfile(
            dataspec,
            peak_width,
            seq_width=seq_width,
            # Only include chromosomes from `all_chr`
            incl_chromosomes=[
                c for c in all_chr
                if c not in valid_chr + test_chr + exclude_chr
            ],
            excl_chromosomes=valid_chr + test_chr + exclude_chr,
            shuffle=shuffle,
            target_transformer=target_transformer),
        StrandedProfile(dataspec,
                        peak_width,
                        seq_width=seq_width,
                        incl_chromosomes=valid_chr,
                        shuffle=shuffle,
                        target_transformer=target_transformer),
        StrandedProfile(dataspec,
                        peak_width,
                        seq_width=seq_width,
                        incl_chromosomes=test_chr,
                        shuffle=shuffle,
                        target_transformer=target_transformer))
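A usage sketch with an illustrative dataspec path (import location assumed):

# Hypothetical: returns (train, valid, test) StrandedProfile datasets,
# split by the default valid/test chromosome lists above.
train, valid, test = get_StrandedProfile_datasets("dataspec.yaml",
                                                  peak_width=200,
                                                  vmtouch=False)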
Example #4
def dataspec():
    return DataSpec(
        bigwigs={
            "task1": TaskSpec(
                task='task1',
                pos_counts=f"{ddir}/pos.bw",
                neg_counts=f"{ddir}/neg.bw",
            ),
            "task2": TaskSpec(
                task='task2',
                pos_counts=f"{ddir}/pos.bw",
                neg_counts=f"{ddir}/neg.bw",
            ),
        },
        fasta_file=f"{ddir}/ref.fa",
        peaks=f"{ddir}/peaks.bed")
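This reads like a pytest-style fixture; `ddir` is assumed to be defined in the enclosing test module. A hypothetical setup:

# Hypothetical: `ddir` must be defined in the same module as dataspec(), e.g.
# ddir = "tests/data"
ds = dataspec()
ds.touch_all_files()  # pre-load the underlying files into the page cache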
Example #5
 @classmethod
 def from_mdir(cls, model_dir):
     """
     Args:
       model_dir (str): Path to the model directory
     """
     import os
     from basepair.cli.schemas import DataSpec
     from keras.models import load_model
     from basepair.utils import read_pkl
     ds = DataSpec.load(os.path.join(model_dir, "dataspec.yaml"))
     model = load_model(os.path.join(model_dir, "model.h5"))
     preproc_file = os.path.join(model_dir, "preprocessor.pkl")
     if os.path.exists(preproc_file):
         preproc = read_pkl(preproc_file)
     else:
         preproc = None
     return cls(model=model, fasta_file=ds.fasta_file, tasks=list(ds.task_specs), preproc=preproc)
Example #6
    def __init__(self,
                 ds,
                 peak_width=200,
                 seq_width=None,
                 incl_chromosomes=None,
                 excl_chromosomes=None,
                 intervals_file=None,
                 bcolz=False,
                 in_memory=False,
                 include_metadata=True,
                 taskname_first=False,
                 tasks=None,
                 include_classes=False,
                 only_classes=False,
                 shuffle=True,
                 interval_transformer=None,
                 target_transformer=None,
                 profile_bias_pool_size=None):
        """Dataset for loading the bigwigs and fastas

        Args:
          ds (basepair.src.schemas.DataSpec): data specification containing the
            fasta file, bed files and bigWig file paths
          incl_chromosomes (list of str): chromosomes to include
          excl_chromosomes (list of str): chromosomes to exclude
          peak_width: resize the bed file to a certain width
          intervals_file: if specified, use these regions to train the model.
            If not specified, the regions are inferred from the dataspec.
          only_classes: if True, load only classes
          bcolz: If True, the bigwig/fasta files are in the genomelake bcolz format
          in_memory: If True, load the whole bcolz into memory. Only applicable when bcolz=True
          shuffle: if True, shuffle the regions
          target_transformer: transformer object applied to the targets
            (e.g. AppendCounts())
        """
        if isinstance(ds, str):
            self.ds = DataSpec.load(ds)
        else:
            self.ds = ds
        self.peak_width = peak_width
        if seq_width is None:
            self.seq_width = peak_width
        else:
            self.seq_width = seq_width
        self.shuffle = shuffle
        self.intervals_file = intervals_file
        self.incl_chromosomes = incl_chromosomes
        self.excl_chromosomes = excl_chromosomes
        self.target_transformer = target_transformer
        self.include_classes = include_classes
        self.only_classes = only_classes
        self.taskname_first = taskname_first
        if self.only_classes:
            assert self.include_classes
        self.profile_bias_pool_size = profile_bias_pool_size
        # not specified yet
        self.fasta_extractor = None
        self.bw_extractors = None
        self.bias_bw_extractors = None
        self.include_metadata = include_metadata
        self.interval_transformer = interval_transformer
        self.bcolz = bcolz
        self.in_memory = in_memory
        if not self.bcolz and self.in_memory:
            raise ValueError(
                "in_memory option only applicable when bcolz=True")

        # Load chromosome lengths
        if self.bcolz:
            p = json.loads(
                (Path(self.ds.fasta_file) / "metadata.json").read_text())
            self.chrom_lens = {c: v[0] for c, v in p['file_shapes'].items()}
        else:
            fa = FastaFile(self.ds.fasta_file)
            self.chrom_lens = {
                name: l
                for name, l in zip(fa.references, fa.lengths)
            }
            if len(self.chrom_lens) == 0:
                raise ValueError(
                    f"no chromosomes found in fasta file: {self.ds.fasta_file}. "
                    "Make sure the file path is correct and that the fasta index file {self.ds.fasta_file}.fai is up to date"
                )
            del fa

        if self.intervals_file is None:
            self.dfm = load_beds(
                bed_files={
                    task: task_spec.peaks
                    for task, task_spec in self.ds.task_specs.items()
                    if task_spec.peaks is not None
                },
                chromosome_lens=self.chrom_lens,
                excl_chromosomes=self.excl_chromosomes,
                incl_chromosomes=self.incl_chromosomes,
                resize_width=max(self.peak_width, self.seq_width))
            assert list(self.dfm.columns)[:4] == ["chrom", "start", "end", "task"]
            if self.shuffle:
                self.dfm = self.dfm.sample(frac=1)
            self.tsv = None
            self.dfm_tasks = None
        else:
            self.tsv = TsvReader(self.intervals_file,
                                 num_chr=False,
                                 label_dtype=int,
                                 mask_ambigous=-1,
                                 incl_chromosomes=incl_chromosomes,
                                 excl_chromosomes=excl_chromosomes,
                                 chromosome_lens=self.chrom_lens,
                                 resize_width=max(self.peak_width,
                                                  self.seq_width))
            if self.shuffle:
                self.tsv.shuffle_inplace()
            self.dfm = self.tsv.df  # use the data-frame from tsv
            self.dfm_tasks = self.tsv.get_target_names()

        # remember the tasks
        if tasks is None:
            self.tasks = list(self.ds.task_specs)
        else:
            self.tasks = tasks

        if self.bcolz and self.in_memory:
            self.fasta_extractor = ArrayExtractor(self.ds.fasta_file,
                                                  in_memory=True)
            self.bw_extractors = {
                task: [
                    ArrayExtractor(task_spec.pos_counts, in_memory=True),
                    ArrayExtractor(task_spec.neg_counts, in_memory=True)
                ]
                for task, task_spec in self.ds.task_specs.items()
                if task in self.tasks
            }
            self.bias_bw_extractors = {
                task: [
                    ArrayExtractor(task_spec.pos_counts, in_memory=True),
                    ArrayExtractor(task_spec.neg_counts, in_memory=True)
                ]
                for task, task_spec in self.ds.bias_specs.items()
                if task in self.tasks
            }

        if self.include_classes:
            assert self.dfm_tasks is not None

        if self.dfm_tasks is not None:
            assert set(self.tasks).issubset(self.dfm_tasks)

        # setup bias maps per task
        self.task_bias_tracks = {
            task: [
                bias for bias, spec in self.ds.bias_specs.items()
                if task in spec.tasks
            ]
            for task in self.tasks
        }
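A direct-construction sketch mirroring how Example #10 instantiates this class; the dataspec path is illustrative:

from basepair.datasets import StrandedProfile

# Hypothetical: a string `ds` is loaded via DataSpec.load internally.
dl = StrandedProfile("dataspec.yaml",
                     peak_width=200,
                     incl_chromosomes=['chr2', 'chr3', 'chr4'],
                     shuffle=False,
                     target_transformer=None)
for batch in dl.batch_iter(batch_size=32, num_workers=1):
    # each batch holds 'inputs' and 'targets' (see Examples #10 and #11)
    break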
Example #7
def get_gw_StrandedProfile_datasets(dataspec,
                                    intervals_file=None,
                                    peak_width=200,
                                    seq_width=None,
                                    shuffle=True,
                                    target_transformer=AppendCounts(),
                                    include_metadata=False,
                                    taskname_first=False,
                                    include_classes=False,
                                    only_classes=False,
                                    tasks=None,
                                    valid_chr=['chr2', 'chr3', 'chr4'],
                                    test_chr=['chr1', 'chr8', 'chr9'],
                                    exclude_chr=[],
                                    vmtouch=True,
                                    profile_bias_pool_size=None):
    # NOTE: only chromosomes chr1-22, chrX and chrY are considered here
    # (e.g. all other chromosomes like chrUn* are omitted)
    from basepair.metrics import BPNetMetric, PeakPredictionProfileMetric, pearson_spearman
    # valid and test chromosomes shouldn't be in the excluded set
    for vc in valid_chr:
        assert vc not in exclude_chr
    for vc in test_chr:
        assert vc not in exclude_chr

    dataspec = DataSpec.load(dataspec)
    if vmtouch:
        # use vmtouch to load all files into memory
        dataspec.touch_all_files()

    if tasks is None:
        tasks = list(dataspec.task_specs)

    train = StrandedProfile(
        dataspec,
        peak_width,
        seq_width=seq_width,
        intervals_file=intervals_file,
        include_metadata=include_metadata,
        taskname_first=taskname_first,
        include_classes=include_classes,
        only_classes=only_classes,
        tasks=tasks,
        incl_chromosomes=[
            c for c in all_chr if c not in valid_chr + test_chr + exclude_chr
        ],
        excl_chromosomes=valid_chr + test_chr + exclude_chr,
        shuffle=shuffle,
        target_transformer=target_transformer,
        profile_bias_pool_size=profile_bias_pool_size)

    valid = [('train-valid-genome-wide',
              StrandedProfile(dataspec,
                              peak_width,
                              seq_width=seq_width,
                              intervals_file=intervals_file,
                              include_metadata=include_metadata,
                              include_classes=include_classes,
                              only_classes=only_classes,
                              taskname_first=taskname_first,
                              tasks=tasks,
                              incl_chromosomes=valid_chr,
                              shuffle=shuffle,
                              target_transformer=target_transformer,
                              profile_bias_pool_size=profile_bias_pool_size))]
    if include_classes:
        # Only use binary classification for genome-wide evaluation
        valid = valid + [('valid-genome-wide',
                          StrandedProfile(
                              dataspec,
                              peak_width,
                              seq_width=seq_width,
                              intervals_file=intervals_file,
                              include_metadata=include_metadata,
                              include_classes=True,
                              only_classes=True,
                              taskname_first=taskname_first,
                              tasks=tasks,
                              incl_chromosomes=valid_chr,
                              shuffle=shuffle,
                              target_transformer=target_transformer,
                              profile_bias_pool_size=profile_bias_pool_size))]

    if not only_classes:
        # Add also the peak regions
        valid = valid + [
            (
                'valid-peaks',
                StrandedProfile(
                    dataspec,
                    peak_width,
                    seq_width=seq_width,
                    intervals_file=None,
                    include_metadata=include_metadata,
                    taskname_first=taskname_first,
                    tasks=tasks,
                    include_classes=False,  # dataspec doesn't contain labels
                    only_classes=only_classes,
                    incl_chromosomes=valid_chr,
                    shuffle=shuffle,
                    target_transformer=target_transformer,
                    profile_bias_pool_size=profile_bias_pool_size)),
            (
                'train-peaks',
                StrandedProfile(
                    dataspec,
                    peak_width,
                    seq_width=seq_width,
                    intervals_file=None,
                    include_metadata=include_metadata,
                    taskname_first=taskname_first,
                    tasks=tasks,
                    include_classes=False,  # dataspec doesn't contain labels
                    only_classes=only_classes,
                    incl_chromosomes=[
                        c for c in all_chr
                        if c not in valid_chr + test_chr + exclude_chr
                    ],
                    excl_chromosomes=valid_chr + test_chr + exclude_chr,
                    shuffle=shuffle,
                    target_transformer=target_transformer,
                    profile_bias_pool_size=profile_bias_pool_size)),
            # use the default metric for the peak sets
        ]
    return train, valid
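A usage sketch; note that unlike Example #3 this returns (train, valid) where valid is a list of (name, dataset) pairs:

# Hypothetical paths; an intervals_file with labels is required when
# include_classes=True (see the assertion in Example #6).
train, valid = get_gw_StrandedProfile_datasets("dataspec.yaml",
                                               intervals_file="labels.tsv",
                                               include_classes=True,
                                               vmtouch=False)
for name, dataset in valid:
    print(name, len(dataset.dfm))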
Example #8
def get_StrandedProfile_datasets2(dataspec,
                                  peak_width=200,
                                  intervals_file=None,
                                  seq_width=None,
                                  shuffle=True,
                                  target_transformer=AppendCounts(),
                                  include_metadata=False,
                                  valid_chr=['chr2', 'chr3', 'chr4'],
                                  test_chr=['chr1', 'chr8', 'chr9'],
                                  tasks=None,
                                  taskname_first=False,
                                  exclude_chr=[],
                                  augment_interval=False,
                                  interval_augmentation_shift=200,
                                  vmtouch=True,
                                  profile_bias_pool_size=None):
    from basepair.metrics import BPNetMetric, PeakPredictionProfileMetric, pearson_spearman
    # valid and test chromosomes shouldn't be in the excluded set
    for vc in valid_chr:
        assert vc not in exclude_chr
    for vc in test_chr:
        assert vc not in exclude_chr

    dataspec = DataSpec.load(dataspec)
    if vmtouch:
        # use vmtouch to load all files into memory
        dataspec.touch_all_files()

    if tasks is None:
        tasks = list(dataspec.task_specs)

    if augment_interval:
        interval_transformer = IntervalAugmentor(
            max_shift=interval_augmentation_shift, flip_strand=True)
    else:
        interval_transformer = None

    return (StrandedProfile(
        dataspec,
        peak_width,
        intervals_file=intervals_file,
        seq_width=seq_width,
        include_metadata=include_metadata,
        incl_chromosomes=[
            c for c in all_chr if c not in valid_chr + test_chr + exclude_chr
        ],
        excl_chromosomes=valid_chr + test_chr + exclude_chr,
        tasks=tasks,
        taskname_first=taskname_first,
        shuffle=shuffle,
        target_transformer=target_transformer,
        interval_transformer=interval_transformer,
        profile_bias_pool_size=profile_bias_pool_size), [
            ('valid-peaks',
             StrandedProfile(dataspec,
                             peak_width,
                             intervals_file=intervals_file,
                             seq_width=seq_width,
                             include_metadata=include_metadata,
                             incl_chromosomes=valid_chr,
                             tasks=tasks,
                             taskname_first=taskname_first,
                             interval_transformer=interval_transformer,
                             shuffle=shuffle,
                             target_transformer=target_transformer,
                             profile_bias_pool_size=profile_bias_pool_size)),
            ('train-peaks',
             StrandedProfile(dataspec,
                             peak_width,
                             intervals_file=intervals_file,
                             seq_width=seq_width,
                             include_metadata=include_metadata,
                             incl_chromosomes=[
                                 c for c in all_chr
                                 if c not in valid_chr + test_chr + exclude_chr
                             ],
                             excl_chromosomes=valid_chr + test_chr + exclude_chr,
                             tasks=tasks,
                             taskname_first=taskname_first,
                             interval_transformer=interval_transformer,
                             shuffle=shuffle,
                             target_transformer=target_transformer,
                             profile_bias_pool_size=profile_bias_pool_size)),
        ])
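A usage sketch exercising the interval-augmentation branch (paths illustrative):

# Hypothetical: augment_interval=True wraps intervals in an IntervalAugmentor
# with random shifts (up to 100 bp here) and strand flips.
train, valids = get_StrandedProfile_datasets2("dataspec.yaml",
                                              augment_interval=True,
                                              interval_augmentation_shift=100,
                                              vmtouch=False)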
Example #9
def evaluate(model_dir,
             output_dir=None,
             gpu=0,
             exclude_metrics=False,
             splits=['train', 'valid'],
             model_path=None,
             data=None,
             hparams=None,
             dataspec=None,
             preprocessor=None):
    """
    Args:
      model_dir: path to the model directory
      splits: For which data splits to compute the evaluation metrics
      exclude_metrics: if True, skip the metrics computed using model.evaluate(..)
    """
    if gpu is not None:
        create_tf_session(gpu)
    if dataspec is not None:
        ds = DataSpec.load(dataspec)
    else:
        ds = DataSpec.load(os.path.join(model_dir, "dataspec.yaml"))
    if hparams is not None:
        hp = HParams.load(hparams)
    else:
        hp = HParams.load(os.path.join(model_dir, "hparams.yaml"))
    if model_path is not None:
        model = load_model(model_path)
    else:
        model = load_model(os.path.join(model_dir, "model.h5"))
    if output_dir is None:
        output_dir = os.path.join(model_dir, "eval")
    train, valid, test = load_data(model_dir,
                                   dataspec=dataspec,
                                   hparams=hparams,
                                   data=data,
                                   preprocessor=preprocessor)
    data = dict(train=train, valid=valid, test=test)

    metrics = {}
    profile_metrics = []
    os.makedirs(os.path.join(output_dir, "plots"),
                exist_ok=True)
    for split in tqdm(splits):
        y_pred = model.predict(data[split][0])
        y_true = data[split][1]
        if not exclude_metrics:
            eval_metrics_values = model.evaluate(data[split][0],
                                                 data[split][1])
            eval_metrics = dict(zip(_listify(model.metrics_names),
                                    _listify(eval_metrics_values)))
            eval_metrics = {split + "/" + k.replace("_", "/"): v
                            for k, v in eval_metrics.items()}
            metrics = {**eval_metrics, **metrics}
        for task in ds.task_specs:
            # Counts

            yp = y_pred[ds.task2idx(task, "counts")].sum(axis=-1)
            yt = y_true["counts/" + task].sum(axis=-1)
            # compute the correlation
            rp = pearsonr(yt, yp)[0]
            rs = spearmanr(yt, yp)[0]
            metrics = {**metrics,
                       split + f"/counts/{task}/pearsonr": rp,
                       split + f"/counts/{task}/spearmanr": rs,
                       }

            fig = plt.figure(figsize=(5, 5))
            plt.scatter(yp, yt, alpha=0.5)
            plt.xlabel("Predicted")
            plt.ylabel("Observed")
            plt.title(f"R_pearson={rp:.2f}, R_spearman={rs:.2f}")
            plt.savefig(os.path.join(output_dir, f"plots/counts.{split}.{task}.png"))
            plt.close(fig)

            # Profile
            yp = softmax(y_pred[ds.task2idx(task, "profile")])
            yt = y_true["profile/" + task]
            df = eval_profile(yt, yp,
                              pos_min_threshold=hp.evaluate.pos_min_threshold,
                              neg_max_threshold=hp.evaluate.neg_max_threshold,
                              required_min_pos_counts=hp.evaluate.required_min_pos_counts,
                              binsizes=hp.evaluate.binsizes)
            df['task'] = task
            df['split'] = split
            # Evaluate for the smallest binsize
            auprc_min = df[df.binsize == min(hp.evaluate.binsizes)].iloc[0].auprc
            metrics[split + f'/profile/{task}/auprc'] = auprc_min
            profile_metrics.append(df)

    # Write the count metrics
    write_json(metrics, os.path.join(output_dir, "metrics.json"))

    # write the profile metrics
    dfm = pd.concat(profile_metrics)
    dfm.to_csv(os.path.join(output_dir, "profile_metrics.tsv"),
               sep='\t', index=False)
    return dfm, metrics
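A usage sketch (model directory path illustrative; gpu=None skips creating a TF session on a specific device):

# Hypothetical: writes metrics.json and profile_metrics.tsv under
# <model_dir>/eval by default.
dfm, metrics = evaluate("models/my_run", splits=['valid'], gpu=None)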
Example #10
def imp_score(model_dir,
              output_file,
              method="grad",
              split='all',
              batch_size=512,
              num_workers=10,
              h5_chunk_size=512,
              max_batches=-1,
              shuffle_seq=False,
              memfrac=0.45,
              exclude_chr='',
              overwrite=False,
              gpu=None):
    """Run importance scores for a BPNet model

    Args:
      model_dir: path to the model directory
      output_file: output file path (HDF5 format)
      method: which importance scoring method to use ('grad', 'deeplift' or 'ism')
      split: for which dataset split to compute the importance scores
      h5_chunk_size: hdf5 chunk size.
      exclude_chr: comma-separated list of chromosomes to exclude
      overwrite: if True, overwrite the output directory
      gpu (int): which GPU to use locally. If None, GPU is not used
    """
    add_file_logging(os.path.dirname(output_file), logger, 'modisco-score')
    if gpu is not None:
        create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac)
    else:
        # Don't use any GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = ''

    if os.path.exists(output_file):
        if overwrite:
            os.remove(output_file)
        else:
            raise ValueError(f"File exists {output_file}. Use overwrite=True to overwrite it")

    if exclude_chr:
        exclude_chr = exclude_chr.split(",")
    else:
        exclude_chr = []
    # load the config files
    logger.info("Loading the config files")
    model_dir = Path(model_dir)
    hp = HParams.load(model_dir / "hparams.yaml")
    ds = DataSpec.load(model_dir / "dataspec.yaml")
    tasks = list(ds.task_specs)
    # validate that the correct dataset was used
    if hp.data.name != 'get_StrandedProfile_datasets':
        logger.warning("hp.data.name != 'get_StrandedProfile_datasets'")

    if split == 'valid':
        assert len(exclude_chr) == 0
        incl_chromosomes = hp.data.kwargs['valid_chr']
        excl_chromosomes = None
    elif split == 'test':
        assert len(exclude_chr) == 0
        incl_chromosomes = hp.data.kwargs['test_chr']
        excl_chromosomes = None
    elif split == 'train':
        assert len(exclude_chr) == 0
        incl_chromosomes = None
        excl_chromosomes = hp.data.kwargs['valid_chr'] + hp.data.kwargs['test_chr'] + hp.data.kwargs.get('exclude_chr', [])
    elif split == 'all':
        incl_chromosomes = None
        excl_chromosomes = hp.data.kwargs.get('exclude_chr', []) + exclude_chr
        logger.info("Excluding chromosomes: {excl_chromosomes}")
    else:
        raise ValueError("split needs to be from {train,valid,test,all}")

    logger.info("Creating the dataset")
    from basepair.datasets import StrandedProfile
    seq_len = hp.data.kwargs['peak_width']
    dl_valid = StrandedProfile(ds,
                               incl_chromosomes=incl_chromosomes,
                               excl_chromosomes=excl_chromosomes,
                               peak_width=seq_len,
                               shuffle=False,
                               target_transformer=None)

    bpnet = BPNet.from_mdir(model_dir)

    writer = HDF5BatchWriter(output_file, chunk_size=h5_chunk_size)
    for i, batch in enumerate(tqdm(dl_valid.batch_iter(batch_size=batch_size, num_workers=num_workers))):
        if max_batches > 0 and i > max_batches:
            logging.info(f"max_batches: {max_batches} exceeded. Stopping the computation")
            break
        # append the bias model predictions
        # (batch['inputs'], batch['targets']) = bm((batch['inputs'], batch['targets']))

        # store the original batch containing 'inputs' and 'targets'
        wdict = batch

        if shuffle_seq:
            # Di-nucleotide shuffle the sequences
            if 'seq' in batch['inputs']:
                batch['inputs']['seq'] = onehot_dinucl_shuffle(batch['inputs']['seq'])
            else:
                batch['inputs'] = onehot_dinucl_shuffle(batch['inputs'])

        # loop through all tasks, pred_summary and strands
        for task_i, task in enumerate(tasks):
            for pred_summary in ['count', 'weighted']:
                # figure out the number of channels
                nstrands = batch['targets'][f'profile/{task}'].shape[-1]
                strand_hash = ["pos", "neg"]

                for strand_i in range(nstrands):
                    hyp_imp = bpnet.imp_score(batch['inputs'],
                                              task=task,
                                              strand=strand_hash[strand_i],
                                              method=method,
                                              pred_summary=pred_summary,
                                              batch_size=None)  # don't second-batch
                    # put importance scores to the dictionary
                    wdict[f"/hyp_imp/{task}/{pred_summary}/{strand_i}"] = hyp_imp
        writer.batch_write(wdict)
    writer.close()
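A usage sketch with illustrative paths:

# Hypothetical: scores the validation split on the CPU and writes an HDF5 file.
imp_score("models/my_run",
          output_file="models/my_run/imp_score.h5",
          method="grad",
          split="valid",
          gpu=None)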
Example #11
def imp_score_seqmodel(model_dir,
                       output_file,
                       dataspec=None,
                       peak_width=1000,
                       seq_width=None,
                       intp_pattern='*',  # specifies which imp. scores to compute
                       # skip_trim=False,  # skip trimming the output
                       method="deeplift",
                       batch_size=512,
                       max_batches=-1,
                       shuffle_seq=False,
                       memfrac=0.45,
                       num_workers=10,
                       h5_chunk_size=512,
                       exclude_chr='',
                       include_chr='',
                       overwrite=False,
                       skip_bias=False,
                       gpu=None):
    """Run importance scores for a BPNet model

    Args:
      model_dir: path to the model directory
      output_file: output file path (HDF5 format)
      method: which importance scoring method to use ('grad', 'deeplift' or 'ism')
      h5_chunk_size: hdf5 chunk size.
      exclude_chr: comma-separated list of chromosomes to exclude
      include_chr: comma-separated list of chromosomes to include
      overwrite: if True, overwrite the output directory
      skip_bias: if True, don't store the bias tracks in the output
      gpu (int): which GPU to use locally. If None, GPU is not used
    """
    add_file_logging(os.path.dirname(output_file), logger, 'modisco-score-seqmodel')
    if gpu is not None:
        create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac)
    else:
        # Don't use any GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = ''

    if os.path.exists(output_file):
        if overwrite:
            os.remove(output_file)
        else:
            raise ValueError(f"File exists {output_file}. Use overwrite=True to overwrite it")

    if seq_width is None:
        logger.info("Using seq_width = peak_width")
        seq_width = peak_width

    # make sure these are int's
    seq_width = int(seq_width)
    peak_width = int(peak_width)

    # Split
    intp_patterns = intp_pattern.split(",")

    # Allow chr inclusion / exclusion
    if exclude_chr:
        exclude_chr = exclude_chr.split(",")
    else:
        exclude_chr = None
    if include_chr:
        include_chr = include_chr.split(",")
    else:
        include_chr = None

    logger.info("Loading the config files")
    model_dir = Path(model_dir)

    if dataspec is None:
        # Specify dataspec
        dataspec = model_dir / "dataspec.yaml"
    ds = DataSpec.load(dataspec)

    logger.info("Creating the dataset")
    from basepair.datasets import StrandedProfile
    dl_valid = StrandedProfile(ds,
                               incl_chromosomes=include_chr,
                               excl_chromosomes=exclude_chr,
                               peak_width=peak_width,
                               seq_width=seq_width,
                               shuffle=False,
                               taskname_first=True,  # Required to work nicely with imp-score
                               target_transformer=None)

    # Setup importance score trimming
    if seq_width > peak_width:
        # Trim
        # make sure we can nicely trim the peak
        logger.info("Trimming the output")
        assert (seq_width - peak_width) % 2 == 0
        trim_start = (seq_width - peak_width) // 2
        trim_end = seq_width - trim_start
        assert trim_end - trim_start == peak_width
    elif seq_width == peak_width:
        trim_start = 0
        trim_end = peak_width
    else:
        raise ValueError("seq_width < peak_width")

    seqmodel = SeqModel.from_mdir(model_dir)

    # get all possible interpretation names
    # make sure they match the specified glob
    intp_names = [name for name, _ in seqmodel.get_intp_tensors(preact_only=False)
                  if fnmatch_any(name, intp_patterns)]
    logger.info(f"Using the following interpretation targets:")
    for n in intp_names:
        print(n)

    writer = HDF5BatchWriter(output_file, chunk_size=h5_chunk_size)
    for i, batch in enumerate(tqdm(dl_valid.batch_iter(batch_size=batch_size, num_workers=num_workers))):
        # store the original batch containing 'inputs' and 'targets'
        wdict = batch
        if skip_bias:
            wdict['inputs'] = {'seq': wdict['inputs']['seq']}  # ignore all other inputs

        if max_batches > 0 and i > max_batches:
            logging.info(f"max_batches: {max_batches} exceeded. Stopping the computation")
            break

        if shuffle_seq:
            # Di-nucleotide shuffle the sequences
            batch['inputs']['seq'] = onehot_dinucl_shuffle(batch['inputs']['seq'])

        for name in intp_names:
            hyp_imp = seqmodel.imp_score(batch['inputs']['seq'],
                                         name=name,
                                         method=method,
                                         batch_size=None)  # don't second-batch

            # put importance scores to the dictionary
            # also trim the importance scores appropriately so that
            # the output will always be w.r.t. the peak center
            wdict[f"/hyp_imp/{name}"] = hyp_imp[:, trim_start:trim_end]

        # trim the sequence as well
        if isinstance(wdict['inputs'], dict):
            # Trim the sequence
            wdict['inputs']['seq'] = wdict['inputs']['seq'][:, trim_start:trim_end]
        else:
            wdict['inputs'] = wdict['inputs'][:, trim_start:trim_end]

        writer.batch_write(wdict)
    writer.close()
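A usage sketch; choosing seq_width > peak_width exercises the trimming branch above (paths illustrative):

# Hypothetical: (1200 - 1000) is even, so the scores are trimmed by 100 bp
# on each side back to the 1000 bp peak window.
imp_score_seqmodel("models/my_run",
                   output_file="models/my_run/deeplift.imp_score.h5",
                   peak_width=1000,
                   seq_width=1200,
                   method="deeplift",
                   gpu=None)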
Example #12
        isf = ImpScoreFile(models_dir / exp / 'deeplift.imp_score.h5',
                           default_imp_score=imp_score)
        dfi_subset = pd.read_parquet(
            models_dir / exp / "deeplift/dfi_subset.parq",
            engine='fastparquet').assign(model=model_name).assign(exp=exp)
        mr = MultipleModiscoResult({
            t: models_dir / exp / f'deeplift/{t}/out/{imp_score}/modisco.h5'
            for t in get_tasks(exp)
        })
        return isf, dfi_subset, mr

    # load DataSpec
    # old config
    rdir = get_repo_root()
    dataspec_file = rdir / "src/chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml"
    ds = DataSpec.load(dataspec_file)

    # Load all files into the page cache
    logger.info("Touch all files")
    ds.touch_all_files()

    # --------------------------------------------
    # ### nexus/profile
    model_name = 'nexus/profile'
    imp_score = 'profile/wn'
    exp = models[model_name]
    isf, dfi_subset, mr = load_data(model_name, imp_score, exp)
    ranges_profile = isf.get_ranges()
    profiles = isf.get_profiles()
    # TODO - fix seqlets
    dfi_list.append(