Exemple #1
0
    def __getitem__(self, idx):
        """Return the sequence, labels and metadata for the `idx`-th entry.

        Args:
          idx: index into `self.tsv`; each entry unpacks to
            `(interval, labels)`.

        Returns:
          dict with the keys
            - "inputs": `{"seq": ...}` - the squeezed output of the fasta
              extractor for the (optionally resized) interval
            - "targets": the labels read from `self.tsv`
            - "metadata": `{"ranges": GenomicRanges, "interval_from_task": ''}`
        """
        # The extractor is created lazily, on the first item access
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)

        interval, labels = self.tsv[idx]

        # Optionally resize the interval to the requested length
        target_len = self.auto_resize_len
        if target_len:
            interval = resize_interval(interval, target_len)

        # Extract the sequence for this single interval and drop the
        # singleton dimensions
        extracted = self.fasta_extractor([interval])
        seq = np.squeeze(extracted)

        # Intervals without strand information are reported as "*"
        strand = interval.strand
        if strand is None:
            strand = "*"

        ranges = GenomicRanges(chr=interval.chrom,
                               start=interval.start,
                               end=interval.stop,
                               id=str(idx),
                               strand=strand)

        metadata = {"ranges": ranges, "interval_from_task": ''}
        return {"inputs": {"seq": seq},
                "targets": labels,
                "metadata": metadata}
Exemple #2
0
def get_nonredundant_example_idx(ranges, width=200):
    """Pick example indices whose central regions do not overlap.

    Args:
      ranges: pandas.DataFrame returned by bpnet.cli.modisco.load_ranges.
        Must provide the columns `chrom`, `start` and `end`. An
        `example_idx` column (0 .. len(ranges) - 1) is written into this
        DataFrame in place.
      width: size of the central region that must not overlap between
        any two intervals. Intervals are resized to this width before
        overlaps are computed.
        NOTE(review): `width=None` skips the resize, but the later
        `width * 2` filter would then fail - confirm whether None is
        actually supported.

    Returns:
      The `name` column (holding the example index) of the retained
      intervals, cast to int.
    """
    from pybedtools import BedTool
    from bpnet.preproc import resize_interval

    # 1. Number the examples and keep only the coordinates + index.
    #    Strand information is not carried along.
    ranges['example_idx'] = np.arange(len(ranges))
    columns = ['chrom', 'start', 'end', 'example_idx']
    coords = ranges[columns]
    if width is not None:
        coords = resize_interval(coords, width, ignore_strand=True)

    # 2. Merge overlapping intervals and keep only the merged regions
    #    that are shorter than twice the width.
    intervals = BedTool.from_dataframe(coords)
    merged = intervals.sort().merge().to_dataframe()
    is_short = (merged.end - merged.start) < width * 2
    merged = merged[is_short]

    # 3. Intersect the intervals with the retained merged regions. With
    #    `wb=True` the matching region's fields are appended to each row;
    #    presumably they show up under the default column names `score`,
    #    `strand` and `thickStart` - verify against pybedtools.
    merged_bt = BedTool.from_dataframe(merged)
    overlaps = intervals.intersect(merged_bt, wb=True).to_dataframe()

    # 4. Keep the first interval for each distinct appended region.
    region_columns = ['score', 'strand', 'thickStart']
    unique_rows = overlaps.drop_duplicates(region_columns)
    keep_idx = unique_rows['name'].astype(int)

    return keep_idx
Exemple #3
0
def dataspec_stats(dataspec, regions=None, sample=None, peak_width=1000):
    """Print summary statistics of the tracks described by a dataspec.

    Prints the base frequency of the sequences, the per-task count and
    bias-count statistics, and a recommended value for `lambda`.

    Args:
      dataspec: passed to `DataSpec.load`
      regions: passed to `pybedtools.BedTool`. If None, the regions
        returned by `DataSpec.get_all_regions()` are used.
      sample: if not None and smaller than the number of regions, only
        this many randomly sampled regions are used
      peak_width: width to which every region is resized (strand ignored)
    """
    import random
    from pybedtools import BedTool
    from bpnet.preproc import resize_interval
    from genomelake.extractors import FastaExtractor

    def _print_track_stats(title, stats_by_task):
        # Print one titled section listing every statistic of every task
        print("")
        print(title)
        for task, stats in stats_by_task.items():
            print(f"- {task}")
            for stat_key, stat_value in stats.items():
                print(f"  {stat_key}: {stat_value}")

    ds = DataSpec.load(dataspec)

    if regions is None:
        regions = ds.get_all_regions()
    else:
        regions = list(BedTool(regions))

    # Optionally sub-sample the regions
    if sample is not None and sample < len(regions):
        logger.info(
            f"Using {sample} randomly sampled regions instead of {len(regions)}"
        )
        regions = random.sample(regions, k=sample)

    # Bring all the regions to the same width
    resized = []
    for interval in regions:
        resized.append(resize_interval(interval, peak_width, ignore_strand=True))
    regions = resized

    # Average the extracted sequences over the first two axes
    extracted = FastaExtractor(ds.fasta_file)(regions)
    base_freq = extracted.mean(axis=(0, 1))

    count_stats = _track_stats(ds.load_counts(regions, progbar=True))
    bias_count_stats = _track_stats(ds.load_bias_counts(regions, progbar=True))

    print("")
    print("Base frequency")
    for i, base in enumerate(('A', 'C', 'G', 'T')):
        print(f"- {base}: {base_freq[i]}")

    _print_track_stats("Count stats", count_stats)
    _print_track_stats("Bias stats", bias_count_stats)

    # Recommended lambda: mean (across tasks) of the "total median"
    # statistic, divided by 10
    total_medians = [task_stats["total median"] for task_stats in count_stats.values()]
    lamb = np.mean(total_medians) / 10
    print("")
    print(
        f"We recommend to set lambda=total_count_median / 10 = {lamb:.2f} (default=10) in `bpnet train --override=` "
        "to put 5x more weight on profile prediction than on total count prediction."
    )
Exemple #4
0
def bpnet_export_bw(
        model_dir,
        output_prefix,
        fasta_file=None,
        regions=None,
        contrib_method='grad',
        contrib_wildcard='*/profile/wn,*/counts/pre-act',  # specifies which contrib. scores to compute
        batch_size=256,
        scale_contribution=False,
        flip_negative_strand=False,
        gpu=0,
        memfrac_gpu=0.45):
    """Export model predictions and contribution scores to big-wig files

    Args:
      model_dir: directory from which the model is loaded
        (`BPNetSeqModel.from_mdir`). When `regions` is None, `dataspec.yml`
        is also read from this directory.
      output_prefix: prefix of the output files. Its directory part is
        created if it does not exist yet; a prefix without a directory
        part (e.g. `"preds"`) writes relative to the current directory.
      fasta_file: forwarded to `BPNetSeqModel.export_bw`
      regions: passed to `pybedtools.BedTool`. If None, the regions
        returned by `DataSpec.get_all_regions()` are used.
      contrib_method: forwarded to `BPNetSeqModel.export_bw`
      contrib_wildcard: comma-separated list of contribution score
        wildcards. Every `*/` is stripped and the resulting list is
        forwarded as `pred_summaries`.
      batch_size: forwarded to `BPNetSeqModel.export_bw`
      scale_contribution: forwarded to `BPNetSeqModel.export_bw`
      flip_negative_strand: forwarded to `BPNetSeqModel.export_bw`
      gpu: if not None, a TensorFlow session is created via
        `create_tf_session(gpu, ...)`
      memfrac_gpu: passed as `per_process_gpu_memory_fraction` to
        `create_tf_session`
    """
    from pybedtools import BedTool

    output_dir = os.path.dirname(output_prefix)
    # `dirname` returns '' for a prefix without a directory component and
    # `os.makedirs('')` raises FileNotFoundError, hence the guard.
    # The directory is created *before* the file logging is attached so
    # that the log file never targets a missing directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    add_file_logging(output_dir, logger, 'bpnet-export-bw')

    if gpu is not None:
        create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac_gpu)

    logger.info("Load model")

    bp = BPNetSeqModel.from_mdir(model_dir)

    if regions is not None:
        logger.info(
            f"Computing predictions and contribution scores for provided regions: {regions}"
        )
        regions = list(BedTool(regions))
    else:
        logger.info("--regions not provided. Using regions from dataspec.yml")
        ds = DataSpec.load(os.path.join(model_dir, 'dataspec.yml'))
        regions = ds.get_all_regions()

    # Every region has to match the width of the model's input
    seqlen = bp.input_seqlen()
    logger.info(
        f"Resizing regions (fix=center) to model's input width of: {seqlen}")
    regions = [resize_interval(interval, seqlen) for interval in regions]
    logger.info("Sort the bed file")
    regions = list(BedTool(regions).sort())

    bp.export_bw(regions=regions,
                 output_prefix=output_prefix,
                 contrib_method=contrib_method,
                 fasta_file=fasta_file,
                 pred_summaries=contrib_wildcard.replace("*/", "").split(","),
                 batch_size=batch_size,
                 scale_contribution=scale_contribution,
                 flip_negative_strand=flip_negative_strand,
                 chromosomes=None)  # infer chromosomes from the fasta file
Exemple #5
0
    def __getitem__(self, idx):
        """Assemble the inputs, targets and optional metadata for row `idx`.

        Args:
          idx: positional row index into `self.dfm`. Columns 0-2 hold
            chrom, start and end of the interval.

        Returns:
          dict with the keys
            - "inputs": `{"seq": ...}` plus, when the dataspec defines
              bias tracks, `bias/<task>/profile` and `bias/<task>/counts`
              for every task
            - "targets": `<task>/profile` and `<task>/counts` for every
              task, plus `<task>/class` when `self.include_classes` is set
            - "metadata": only when `self.include_metadata` is set;
              contains the `GenomicRanges` of the target interval and
              `interval_from_task`
        """
        from pybedtools import Interval

        if self.fasta_extractor is None:
            # First call: set up the regular fasta/bigwig extractors
            self.fasta_extractor = FastaExtractor(self.ds.fasta_file, use_strand=True)

            task_extractors = {task: [BigwigExtractor(track) for track in task_spec.tracks]
                               for task, task_spec in self.ds.task_specs.items()
                               if task in self.tasks}
            self.bw_extractors = task_extractors

            bias_extractors = {task: [BigwigExtractor(track) for track in task_spec.tracks]
                               for task, task_spec in self.ds.bias_specs.items()}
            self.bias_bw_extractors = bias_extractors

        # Genomic interval of this particular datapoint
        chrom = self.dfm.iat[idx, 0]
        start = self.dfm.iat[idx, 1]
        end = self.dfm.iat[idx, 2]
        interval = Interval(chrom, start, end)

        # Optional interval transformation (e.g. for augmentation)
        if self.interval_transformer is not None:
            interval = self.interval_transformer(interval)

        # Two resized copies: one for the target tracks, one for the sequence
        target_interval = resize_interval(deepcopy(interval), self.peak_width)
        seq_interval = resize_interval(deepcopy(interval), self.seq_width)

        # The 4th column holds the task name only when the intervals come
        # from the dataspec, i.e. not when an intervals_file is used
        if self.intervals_file is None:
            interval_from_task = self.dfm.iat[idx, 3]
        else:
            interval_from_task = ''

        # Extract the DNA sequence + one-hot encode it
        inputs = {"seq": self.fasta_extractor([seq_interval])[0]}

        # Extract the profile counts from the bigwigs
        targets = {}
        for task, spec in self.ds.task_specs.items():
            if task not in self.tasks:
                continue
            profile = _run_extractors(self.bw_extractors[task],
                                      [target_interval],
                                      sum_tracks=spec.sum_tracks)[0]
            targets[f"{task}/profile"] = profile

        if self.track_transform is not None:
            for task in self.tasks:
                profile_key = f'{task}/profile'
                targets[profile_key] = self.track_transform(targets[profile_key])

        # Total number of counts: profile summed over its first axis
        for task in self.tasks:
            total = targets[f'{task}/profile'].sum(0)
            targets[f'{task}/counts'] = self.total_count_transform(total)

        if len(self.ds.bias_specs) > 0:
            # Extract the bias tracks
            bias_tracks = {}
            for bias_task, spec in self.ds.bias_specs.items():
                bias_tracks[bias_task] = _run_extractors(self.bias_bw_extractors[bias_task],
                                                         [target_interval],
                                                         sum_tracks=spec.sum_tracks)[0]

            # Per task, join all of its bias tracks along the last axis
            bias_inputs = {}
            for task in self.tasks:
                per_task = [bias_tracks[bt] for bt in self.task_bias_tracks[task]]
                bias_inputs[f"bias/{task}/profile"] = np.concatenate(per_task, axis=-1)

            if self.track_transform is not None:
                for task in self.tasks:
                    bias_key = f'bias/{task}/profile'
                    bias_inputs[bias_key] = self.track_transform(bias_inputs[bias_key])

            # Total number of bias counts
            for task in self.tasks:
                bias_total = bias_inputs[f'bias/{task}/profile'].sum(0)
                bias_inputs[f'bias/{task}/counts'] = self.total_count_transform(bias_total)

            inputs.update(bias_inputs)

        if self.include_classes:
            # Optionally add the binary labels stored in the additional
            # columns (4th onwards) of the tsv intervals file
            for i, task in enumerate(self.dfm_tasks):
                if task in self.tasks:
                    targets[f"{task}/class"] = self.dfm.iat[idx, i + 3]

        out = {"inputs": inputs, "targets": targets}

        if self.include_metadata:
            # Remember which genomic interval was used; a missing strand
            # is reported as "*"
            strand = target_interval.strand
            if strand is None:
                strand = "*"
            genomic_range = GenomicRanges(chr=target_interval.chrom,
                                          start=target_interval.start,
                                          end=target_interval.stop,
                                          id=idx,
                                          strand=strand)
            out['metadata'] = {"range": genomic_range,
                               "interval_from_task": interval_from_task}
        return out