Example #1
def test_fasta_extractor_valid_intervals():
    extractor = FastaExtractor("tests/data/fasta_test.fa")
    intervals = [Interval("chr1", 0, 10), Interval("chr2", 0, 10)]
    expected_data = np.array(
        [
            [
                [1., 0., 0., 0.],
                [0., 1., 0., 0.],
                [0., 1., 0., 0.],
                [0., 0., 1., 0.],
                [0., 0., 0., 1.],
                [1., 0., 0., 0.],
                [0., 1., 0., 0.],
                [0., 1., 0., 0.],
                [0., 0., 1., 0.],
                [0., 0., 0., 1.],
            ],
            [
                [1., 0., 0., 0.],
                [0., 1., 0., 0.],
                [0., 0., 1., 0.],
                [0., 0., 0., 1.],
                [0.25, 0.25, 0.25, 0.25],
                [1., 0., 0., 0.],
                [0., 1., 0., 0.],
                [0., 0., 1., 0.],
                [0., 0., 0., 1.],
                [0.25, 0.25, 0.25, 0.25],
            ],
        ],
        dtype=np.float32,
    )
    data = extractor(intervals)
    assert (data == expected_data).all()
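
The test above pins down FastaExtractor's output contract: a float32 array of shape (num_intervals, interval_length, 4), one-hot encoded in A/C/G/T order, with ambiguous N bases encoded as a uniform 0.25 per base. A minimal usage sketch, assuming genomelake is installed and the FASTA file has a matching .fai index:

import numpy as np
from pybedtools import Interval
from genomelake.extractors import FastaExtractor

# Extract a single 10bp interval; output is (n_intervals, length, 4).
extractor = FastaExtractor("tests/data/fasta_test.fa")
batch = extractor([Interval("chr1", 0, 10)])
assert batch.shape == (1, 10, 4)
assert batch.dtype == np.float32
# Every position sums to 1; an N contributes 0.25 to each base column.
assert np.allclose(batch.sum(axis=-1), 1.0)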
Example #2
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)
        interval = self.bt[idx]

        if interval.stop - interval.start != self.SEQ_WIDTH:
            raise ValueError("Expected the interval to be {0} wide. Recieved stop - start = {1}".
                             format(self.SEQ_WIDTH, interval.stop - interval.start))

        if interval.name is not None:
            y = np.array([float(interval.name)])
        else:
            y = {}

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]))

        # Reformat so that it matches the Basset shape
        # seq = np.swapaxes(seq, 1, 0)[:,:,None]
        return {
            "inputs": {"data/genome_data_dir": seq},
            "targets": y,
            "metadata": {
                "ranges": GenomicRanges.from_interval(interval)
            }
        }
Example #3
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)

        interval, labels = self.tsv[idx]

        if self.auto_resize_len:
            # automatically resize the sequence to the desired length
            interval = resize_interval(interval, self.auto_resize_len)

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]))

        return {
            "inputs": {"seq": seq},
            "targets": labels,
            "metadata": {
                "ranges": GenomicRanges(chr=interval.chrom,
                                        start=interval.start,
                                        end=interval.stop,
                                        id=str(idx),
                                        strand=(interval.strand
                                                if interval.strand is not None
                                                else "*"),
                                        ),
                "interval_from_task": ''
            }
        }
Example #4
    def __init__(self, intervals_file, fasta_file):
        # intervals
        # if use_linecache:
        #     self.bt = BedToolLinecache(intervals_file)
        # else:
        self.bt = BedTool(intervals_file)
        self.fasta_extractor = FastaExtractor(fasta_file)
Example #5
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)
        interval = self.bt[idx]

        if interval.stop - interval.start != self.SEQ_WIDTH:
            raise ValueError(
                "Expected the interval to be {0} wide. Recieved stop - start = {1}"
                .format(self.SEQ_WIDTH, interval.stop - interval.start))

        if self.targets is not None:
            y = self.targets.iloc[idx].values
        else:
            y = {}

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]), axis=0)
        seq = np.expand_dims(np.swapaxes(seq, 1, 0), axis=1)
        return {
            "inputs": seq,
            "targets": y,
            "metadata": {
                "ranges": GenomicRanges.from_interval(interval)
            }
        }
Example #6
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)
        # look up the interval for this index
        interval = self.bt[idx]

        # Intervals need to be 1000bp wide
        assert interval.stop - interval.start == 1000

        # use the interval's name column as the target when present
        if interval.name is not None:
            y = np.array([float(interval.name)])
        else:
            y = {}

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]))

        # Reformat so that it matches the Basset shape
        # seq = np.swapaxes(seq, 1, 0)[:,:,None]
        return {
            "inputs": {"data/genome_data_dir": seq},
            "targets": y,
            "metadata": {
                "ranges": GenomicRanges.from_interval(interval)
            }
        }
Example #7
File: BPNet.py Project: dabai2/bpnet
    def get_seq(self,
                regions,
                variants=None,
                use_strand=False,
                fasta_file=None):
        """Get the one-hot-encoded sequence used to make model predictions and
        optionally augment it with the variants
        """
        if fasta_file is None:
            fasta_file = self.fasta_file

        if variants is not None:
            if use_strand:
                raise NotImplementedError(
                    "use_strand=True not implemented for variants")
            # Augment the regions using a variant
            if not isinstance(variants, list):
                variants = [variants] * len(regions)
            else:
                assert len(variants) == len(regions)
            seq = np.stack([
                extract_seq(interval, variant, fasta_file, one_hot=True)
                for variant, interval in zip(variants, regions)
            ])
        else:
            variants = [None] * len(regions)
            seq = FastaExtractor(fasta_file, use_strand=use_strand)(regions)
        return seq
Example #8
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)
        interval = self.bt[idx]

        if interval.stop - interval.start != self.SEQ_WIDTH:
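            # Re-center the interval on its midpoint; the SEQ_WIDTH % 2 term
            # preserves the exact width when SEQ_WIDTH is odd.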
            center = (interval.start + interval.stop) // 2
            interval.start = center - self.SEQ_WIDTH // 2
            interval.end = center + self.SEQ_WIDTH // 2 + self.SEQ_WIDTH % 2

        if self.targets is not None:
            y = self.targets.iloc[idx].values
        else:
            y = {}

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]), axis=0)
        # Reformat so that it matches the DeepSEA shape
        seq = np.swapaxes(seq, 1, 0)[:, None, :]
        return {
            "inputs": seq,
            "targets": y,
            "metadata": {
                "ranges": GenomicRanges.from_interval(interval)
            }
        }
Example #9
    def __init__(self,
                 intervals_file,
                 fasta_file,
                 gtf_file,
                 preproc_transformer,
                 target_file=None):
        gtf = pd.read_pickle(gtf_file)
        self.gtf = gtf[gtf["info"].str.contains('gene_type "protein_coding"')]
        self.gtf = self.gtf.rename(
            columns={"seqnames": "seqname"})  # concise>=0.6.5

        # distance transformer
        with open(preproc_transformer, "rb") as f:
            self.transformer = pickle.load(f)

        # intervals
        self.bt = pybedtools.BedTool(intervals_file)

        # extractors
        self.input_data_extractors = {
            "seq":
            FastaExtractor(fasta_file),
            "dist_polya_st":
            DistToClosestLandmarkExtractor(gtf_file=self.gtf,
                                           landmarks=["polya"])
        }

        # target
        if target_file:
            self.target_dataset = TxtDataset(target_file)
            assert len(self.target_dataset) == len(self.bt)
        else:
            self.target_dataset = None
Example #10
    def __init__(self,
                 intervals_file,
                 fasta_file,
                 dnase_file,
                 mappability_file=None,
                 use_linecache=True):

        # intervals
        if use_linecache:
            linecache.clearcache()
            BT = BedToolLinecache
        else:
            BT = BedTool

        self.bt = BT(intervals_file)

        # Fasta
        self.fasta_extractor = FastaExtractor(fasta_file)

        # DNase
        self.dnase_extractor = BigwigExtractor(dnase_file)
        # mappability
        if mappability_file is None:
            # download the mappability file if not existing
            mappability_file = os.path.join(
                this_dir, "../../template/dataloader_files",
                "wgEncodeDukeMapabilityUniqueness35bp.bigWig")
            if not os.path.exists(mappability_file):
                print("Downloading the mappability file")
                urlretrieve(
                    "http://hgdownload.cse.ucsc.edu/goldenPath/hg19/encodeDCC/wgEncodeMapability/wgEncodeDukeMapabilityUniqueness35bp.bigWig",
                    mappability_file)
                print("Download complete")

        self.mappability_extractor = BigwigExtractor(mappability_file)
Example #11
    def __getitem__(self, idx):
        if self.seq_extractor is None:
            self.seq_extractor = FastaExtractor(self.fasta_file)
            self.dist_extractor = DistToClosestLandmarkExtractor(gtf_file=self.gtf,
                                                                 landmarks=ALL_LANDMARKS)

        interval = self.bt[idx]

        if interval.stop - interval.start != self.SEQ_WIDTH:
            raise ValueError("Expected the interval to be {0} wide. Recieved stop - start = {1}".
                             format(self.SEQ_WIDTH, interval.stop - interval.start))
        out = {}
        out['inputs'] = {}
        # input - sequence
        out['inputs']['seq'] = np.squeeze(self.seq_extractor([interval]), axis=0)

        # input - distance
        dist_dict = self.dist_transformer.transform(self.dist_extractor([interval]))
        dist_dict = {k: np.squeeze(v, axis=0) for k, v in dist_dict.items()}  # squeeze the batch axis
        out['inputs'] = {**out['inputs'], **dist_dict}

        # targets
        if self.target_dataset is not None:
            out["targets"] = np.array([self.target_dataset[idx]])

        # metadata
        out['metadata'] = {}
        out['metadata']['ranges'] = GenomicRanges.from_interval(interval)

        return out
Example #12
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            # Fasta
            self.fasta_extractor = FastaExtractor(self.fasta_file)
            # DNase
            self.dnase_extractor = BigwigExtractor(self.dnase_file)
            self.mappability_extractor = BigwigExtractor(self.mappability_file)

        # Get the interval
        interval = self.bt[idx]
        if interval.stop - interval.start != self.SEQ_WIDTH:
            center = (interval.start + interval.stop) // 2
            interval.start = center - self.SEQ_WIDTH // 2
            interval.end = center + self.SEQ_WIDTH // 2 + self.SEQ_WIDTH % 2
        # Get the gencode features
        gencode_counts = np.array([v[idx].count for k, v in self.overlap_beds],
                                  dtype=bool)

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]), axis=0)
        seq_rc = seq[::-1, ::-1]

        # Dnase
        dnase = np.squeeze(self.dnase_extractor([interval]),
                           axis=0)[:, np.newaxis]
        dnase[np.isnan(dnase)] = 0  # NA fill
        dnase_rc = dnase[::-1]

        bigwig_list = [seq]
        bigwig_rc_list = [seq_rc]
        mappability = np.squeeze(self.mappability_extractor([interval]),
                                 axis=0)[:, np.newaxis]
        mappability[np.isnan(mappability)] = 0  # NA fill
        mappability_rc = mappability[::-1]
        bigwig_list.append(mappability)
        bigwig_rc_list.append(mappability_rc)
        bigwig_list.append(dnase)
        bigwig_rc_list.append(dnase_rc)

        ranges = GenomicRanges.from_interval(interval)
        ranges_rc = GenomicRanges.from_interval(interval)
        ranges_rc.strand = "-"

        return {
            "inputs": [
                np.concatenate(bigwig_list,
                               axis=-1),  # stack along the last axis
                np.concatenate(bigwig_rc_list, axis=-1),  # RC version
                np.append(self.meta_feat, gencode_counts)
            ],
            "targets": {},  # No Targets
            "metadata": {
                "ranges": ranges,
                "ranges_rc": ranges_rc
            }
        }
Example #13
    def __init__(self, intervals_file, fasta_file, target_file=None):

        self.bt = BedTool(intervals_file)
        self.fasta_extractor = FastaExtractor(fasta_file)

        # Targets
        if target_file is not None:
            self.targets = pd.read_csv(target_file)
        else:
            self.targets = None
Example #14
def dataspec_stats(dataspec, regions=None, sample=None, peak_width=1000):
    """Compute the stats about the tracks
    """
    import random
    from pybedtools import BedTool
    from bpnet.preproc import resize_interval
    from genomelake.extractors import FastaExtractor

    ds = DataSpec.load(dataspec)

    if regions is not None:
        regions = list(BedTool(regions))
    else:
        regions = ds.get_all_regions()

    if sample is not None and sample < len(regions):
        logger.info(
            f"Using {sample} randomly sampled regions instead of {len(regions)}"
        )
        regions = random.sample(regions, k=sample)

    # resize the regions
    regions = [
        resize_interval(interval, peak_width, ignore_strand=True)
        for interval in regions
    ]

    base_freq = FastaExtractor(ds.fasta_file)(regions).mean(axis=(0, 1))

    count_stats = _track_stats(ds.load_counts(regions, progbar=True))
    bias_count_stats = _track_stats(ds.load_bias_counts(regions, progbar=True))

    print("")
    print("Base frequency")
    for i, base in enumerate(['A', 'C', 'G', 'T']):
        print(f"- {base}: {base_freq[i]}")
    print("")
    print("Count stats")
    for task, stats in count_stats.items():
        print(f"- {task}")
        for stat_key, stat_value in stats.items():
            print(f"  {stat_key}: {stat_value}")
    print("")
    print("Bias stats")
    for task, stats in bias_count_stats.items():
        print(f"- {task}")
        for stat_key, stat_value in stats.items():
            print(f"  {stat_key}: {stat_value}")

    lamb = np.mean([v["total median"] for v in count_stats.values()]) / 10
    print("")
    print(
        f"We recommend setting lambda=total_count_median / 10 = {lamb:.2f} (default=10) in `bpnet train --override=` "
        "to put 5x more weight on profile prediction than on total count prediction."
    )
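
The base-frequency line above works because averaging a one-hot batch over the batch and position axes yields per-base frequencies directly (N bases contribute 0.25 to every column). A self-contained sketch with synthetic data standing in for real extractor output:

import numpy as np

# One interval of length 3: A, G, and an ambiguous N base.
onehot = np.array([[[1.00, 0.00, 0.00, 0.00],
                    [0.00, 0.00, 1.00, 0.00],
                    [0.25, 0.25, 0.25, 0.25]]])
base_freq = onehot.mean(axis=(0, 1))  # frequency of A, C, G, T
print(dict(zip("ACGT", base_freq)))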
Example #15
    def __init__(self, intervals_file, fasta_file,
                 use_linecache=False):

        # intervals
        if use_linecache:
            self.bt = BedToolLinecache(intervals_file)
        else:
            self.bt = BedTool(intervals_file)
        self.fasta_extractor = FastaExtractor(fasta_file)

        if len(self.bt) % 2 == 1:
            raise ValueError("Basenji strictly requires batch_size=2," +
                             " hence the bed file should have an od length")
Example #16
    def __init__(self,
                 intervals_file,
                 fasta_file,
                 gtf_file,
                 filter_protein_coding=True,
                 target_file=None,
                 use_linecache=False):
        if sys.version_info[0] != 3:
            warnings.warn(
                "Only Python 3 is supported. You are using Python {0}".format(
                    sys.version_info[0]))
        self.gtf = read_gtf(gtf_file)

        self.filter_protein_coding = filter_protein_coding

        if self.filter_protein_coding:
            if "gene_type" in self.gtf:
                self.gtf = self.gtf[self.gtf["gene_type"] == "protein_coding"]
            elif "gene_biotype" in self.gtf:
                self.gtf = self.gtf[self.gtf["gene_biotype"] ==
                                    "protein_coding"]
            else:
                warnings.warn(
                    "Gtf doesn't have the field 'gene_type' or 'gene_biotype'. "
                    "Considering genomic landmarks of all genes, not just protein_coding.")

        if not np.any(self.gtf.seqname.str.contains("chr")):
            self.gtf["seqname"] = "chr" + self.gtf["seqname"]

        # intervals
        if use_linecache:
            self.bt = BedToolLinecache(intervals_file)
        else:
            self.bt = BedTool(intervals_file)

        # extractors
        self.seq_extractor = FastaExtractor(fasta_file)
        self.dist_extractor = DistToClosestLandmarkExtractor(
            gtf_file=self.gtf, landmarks=ALL_LANDMARKS)

        # here the DATALOADER_DIR contains the path to the current directory
        self.dist_transformer = DistanceTransformer(
            ALL_LANDMARKS,
            DATALOADER_DIR + "/dataloader_files/position_transformer.pkl")

        # target
        if target_file:
            self.target_dataset = TxtDataset(target_file)
            assert len(self.target_dataset) == len(self.bt)
        else:
            self.target_dataset = None
Example #17
    def __init__(self,
                 intervals_file,
                 fasta_file,
                 dnase_file,
                 cell_line=None,
                 RNAseq_PC_file=None,
                 mappability_file=None,
                 use_linecache=True):

        # intervals
        if use_linecache:
            linecache.clearcache()
            BT = BedToolLinecache
        else:
            BT = BedTool

        self.bt = BT(intervals_file)

        # Fasta
        self.fasta_extractor = FastaExtractor(fasta_file)

        # DNase
        self.dnase_extractor = BigwigExtractor(dnase_file)
        # mappability
        if mappability_file is None:
            # download the mappability file if not existing
            mappability_file = os.path.join(
                this_dir, "../../template/dataloader_files",
                "wgEncodeDukeMapabilityUniqueness35bp.bigWig")
            if not os.path.exists(mappability_file):
                print("Downloading the mappability file")
                urlretrieve(
                    "http://hgdownload.cse.ucsc.edu/goldenPath/hg19/encodeDCC/wgEncodeMapability/wgEncodeDukeMapabilityUniqueness35bp.bigWig",
                    mappability_file)
                print("Download complete")

        self.mappability_extractor = BigwigExtractor(mappability_file)
        # Get the metadata features
        if cell_line is None:
            if RNAseq_PC_file is None:
                raise ValueError(
                    "RNAseq_PC_file has to be specified when cell_line=None")
            assert os.path.exists(RNAseq_PC_file)
        else:
            # Using the pre-defined cell-line
            rp = os.path.join(this_dir, "dataloader_files/RNAseq_features/")
            RNAseq_PC_file = os.path.join(rp, cell_line, "meta.txt")
        self.meta_feat = pd.read_csv(RNAseq_PC_file, sep="\t",
                                     header=None)[0].values
Example #18
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            # Fasta
            self.fasta_extractor = FastaExtractor(self.fasta_file)
            # DNase
            self.dnase_extractor = BigwigExtractor(self.dnase_file)

        # Get the interval
        interval = self.bt[idx]
        if interval.stop - interval.start != self.SEQ_WIDTH:
            center = (interval.start + interval.stop) // 2
            interval.start = center - self.SEQ_WIDTH // 2
            interval.end = center + self.SEQ_WIDTH // 2 + self.SEQ_WIDTH % 2

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]), axis=0)
        seq_rc = seq[::-1, ::-1]

        # Dnase
        dnase = np.squeeze(self.dnase_extractor([interval]),
                           axis=0)[:, np.newaxis]
        dnase[np.isnan(dnase)] = 0  # NA fill
        dnase_rc = dnase[::-1]

        bigwig_list = [seq]
        bigwig_rc_list = [seq_rc]
        bigwig_list.append(dnase)
        bigwig_rc_list.append(dnase_rc)

        ranges = GenomicRanges.from_interval(interval)
        ranges_rc = GenomicRanges.from_interval(interval)
        ranges_rc.strand = "-"

        return {
            "inputs": [
                np.concatenate(bigwig_list,
                               axis=-1),  # stack along the last axis
                np.concatenate(bigwig_rc_list, axis=-1),  # RC version
            ],
            "targets": {},  # No Targets
            "metadata": {
                "ranges": ranges,
                "ranges_rc": ranges_rc
            }
        }
Example #19
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)
        interval = self.bt[idx]

        if interval.stop - interval.start != self.SEQ_WIDTH:
            raise ValueError(
                "Expected the interval to be {0} wide. Recieved stop - start = {1}"
                .format(self.SEQ_WIDTH, interval.stop - interval.start))

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]), axis=0)
        return {
            "inputs": seq,
            "targets": {},  # No Targets
            "metadata": {
                "ranges": GenomicRanges.from_interval(interval)
            }
        }
Example #20
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)

        interval, labels = self.tsv[idx]

        # Intervals need to be 1000bp wide
        assert interval.stop - interval.start == 1000

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]))

        return {
            "inputs": {"data/genome_data_dir": seq},
            "targets": labels,
            "metadata": {
                "ranges": GenomicRanges(interval.chrom, interval.start, interval.stop, str(idx))
            }
        }
Example #21
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)

        interval = self.bt[idx]

        if self.targets is not None:
            y = self.targets.iloc[idx].values
        else:
            y = {}

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]), axis=0)
        return {
            "inputs": seq,
            "targets": y,
            "metadata": {
                "ranges": GenomicRanges.from_interval(interval)
            }
        }
Example #22
    def __init__(self,
                 intervals_file,
                 fasta_file,
                 dnase_file,
                 use_linecache=True):

        # intervals
        if use_linecache:
            linecache.clearcache()
            BT = BedToolLinecache
        else:
            BT = BedTool

        self.bt = BT(intervals_file)

        # Fasta
        self.fasta_extractor = FastaExtractor(fasta_file)

        # DNase
        self.dnase_extractor = BigwigExtractor(dnase_file)
Example #23
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)
        interval = self.bt[idx]

        # Intervals need to be 101bp wide
        assert interval.stop - interval.start == 101

        if self.targets is not None:
            y = self.targets.iloc[idx].values
        else:
            y = {}

        # Run the fasta extractor
        seq = self.fasta_extractor([interval]).squeeze()
        return {
            "inputs": seq,
            "targets": y,
            "metadata": {
                "ranges": GenomicRanges.from_interval(interval)
            }
        }
Example #24
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaExtractor(self.fasta_file)
            self.bigwig_extractors = {
                a: [BigwigExtractor(f) for f in self.bigwigs[a]]
                for a in self.bigwigs
            }

        interval, labels = self.tsv[idx]
        interval = resize_interval(interval, 1000)
        # Intervals need to be 1000bp wide
        assert interval.stop - interval.start == 1000

        # Run the fasta extractor
        seq = np.squeeze(self.fasta_extractor([interval]))

        interval_wide = resize_interval(deepcopy(interval), self.track_width)

        return {
            "inputs": {
                "seq": seq
            },
            "targets": {
                a:
                sum([e([interval_wide])[0]
                     for e in self.bigwig_extractors[a]]).sum()
                for a in self.bigwig_extractors
            },
            "metadata": {
                "ranges":
                GenomicRanges(interval.chrom, interval.start, interval.stop,
                              str(idx)),
                "ranges_wide":
                GenomicRanges.from_interval(interval_wide),
                "name":
                interval.name
            }
        }
Example #25
File: dataloader.py Project: yynst2/kipoi
    def __getitem__(self, idx):
        if self.input_data_extractors is None:
            self.input_data_extractors = {
                "seq":
                FastaExtractor(self.fasta_file),
                "dist_polya_st":
                DistToClosestLandmarkExtractor(gtf_file=self.gtf,
                                               landmarks=["polya"])
            }

        interval = self.bt[idx]

        out = {}

        out['inputs'] = {
            key: np.squeeze(extractor([interval]), axis=0)
            for key, extractor in self.input_data_extractors.items()
        }

        # use trained spline transformation to transform it
        out["inputs"]["dist_polya_st"] = np.squeeze(self.transformer.transform(
            out["inputs"]["dist_polya_st"][np.newaxis], warn=False),
                                                    axis=0)

        if self.target_dataset is not None:
            out["targets"] = np.array([self.target_dataset[idx]])

        # get metadata
        out['metadata'] = {}
        out['metadata']['ranges'] = {}
        out['metadata']['ranges']['chr'] = interval.chrom
        out['metadata']['ranges']['start'] = interval.start
        out['metadata']['ranges']['end'] = interval.stop
        out['metadata']['ranges']['id'] = interval.name
        out['metadata']['ranges']['strand'] = interval.strand

        return out
Example #26
def chip_exo_nexus(dataspec,
                   peak_width=200,
                   shuffle=True,
                   preprocessor=AppendTotalCounts(),
                   interval_augm=lambda x: x,
                   valid_chr=valid_chr,
                   test_chr=test_chr):
    """
    General dataloading function for ChIP-exo or ChIP-nexus data

    Args:
      dataspec: basepair.schemas.DataSpec object containing information about
        the bigwigs, fasta_file and peaks
      peak_width: final width of the interval to extract
      shuffle: if true, the order of the peaks will get shuffled
      preprocessor: preprocessor object - needs to implement .fit() and .transform() methods
      interval_augm: interval augmentor.
      valid_chr: list of chromosomes in the validation split
      test_chr: list of chromosomes in the test split

    Returns:
      (train, valid, test) tuple where train consists of:
        - x: one-hot encoded sequence, sample shape: (peak_width, 4)
        - y: dictionary containing fields:
          {task_id}/profile: sample shape - (peak_width, 2), count profile
          {task_id}/counts: sample shape - (2, ), total number of counts per strand
        - metadata: pandas dataframe storing the original intervals

    """
    for v in valid_chr:
        assert v not in test_chr

    def set_attrs_name(interval, name):
        """Add a name to the interval
        """
        interval.attrs['name'] = name
        return interval

    # Load intervals for all tasks.
    #   remember the task name in interval.name
    def get_bt(peaks):
        if peaks is None:
            return []
        else:
            return BedTool(peaks)

    # Resize and skip intervals outside of the genome
    from pysam import FastaFile
    fa = FastaFile(dataspec.fasta_file)
    #     intervals = len(get_bt(peaks))
    #     n_int = len(intervals)

    intervals = [
        set_attrs_name(resize_interval(interval_augm(interval), peak_width),
                       task) for task, ds in dataspec.task_specs.items()
        for i, interval in enumerate(get_bt(ds.peaks))
        if keep_interval(interval, peak_width, fa)
    ]
    #     if len(intervals) != n_int:
    #         logger.warn(f"Skipped {n_int - len(intervals)} intervals"
    #                     " outside of the genome size")

    if shuffle:
        Random(42).shuffle(intervals)

    # Setup metadata
    dfm = pd.DataFrame(
        dict(id=np.arange(len(intervals)),
             chr=[x.chrom for x in intervals],
             start=[x.start for x in intervals],
             end=[x.stop for x in intervals],
             task=[x.attrs['name'] for x in intervals]))

    logger.info("extract sequence")
    seq = FastaExtractor(dataspec.fasta_file)(intervals)

    logger.info("extract counts")
    cuts = {
        f"profile/{task}": spec.load_counts(intervals)
        for task, spec in tqdm(dataspec.task_specs.items())
    }
    # # sum across the sequence
    # for task in dataspec.task_specs:
    #     cuts[f"counts/{task}"] = cuts[f"profile/{task}"].sum(axis=1)
    assert len(seq) == len(dfm)
    assert len(seq) == len(cuts[list(cuts.keys())[0]])

    # Split by chromosomes
    is_test = dfm.chr.isin(test_chr)
    is_valid = dfm.chr.isin(valid_chr)
    is_train = (~is_test) & (~is_valid)

    train = [seq[is_train], get_dataset_item(cuts, is_train), dfm[is_train]]
    valid = [seq[is_valid], get_dataset_item(cuts, is_valid), dfm[is_valid]]
    test = [seq[is_test], get_dataset_item(cuts, is_test), dfm[is_test]]

    if preprocessor is not None:
        preprocessor.fit(train[1])
        train[1] = preprocessor.transform(train[1])
        valid[1] = preprocessor.transform(valid[1])
        test[1] = preprocessor.transform(test[1])

    train.append(preprocessor)
    return (train, valid, test)
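
A hedged usage sketch of the function above (the dataspec path is illustrative; DataSpec and chip_exo_nexus come from the surrounding codebase). Note that train carries the fitted preprocessor as a fourth element:

# Hypothetical usage; "dataspec.yml" is a placeholder path.
ds = DataSpec.load("dataspec.yml")
train, valid, test = chip_exo_nexus(ds, peak_width=200, shuffle=True)
x_train, y_train, meta_train = train[0], train[1], train[2]
# x_train: one-hot sequence (n, peak_width, 4)
# y_train: dict of count profiles keyed by "profile/{task}"
# train[3]: the preprocessor fitted on the training targets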
Example #27
    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            # Use array extractors
            if self.bcolz:
                self.fasta_extractor = ArrayExtractor(self.ds.fasta_file,
                                                      in_memory=False)
                self.bw_extractors = {
                    task: [
                        ArrayExtractor(task_spec.pos_counts, in_memory=False),
                        ArrayExtractor(task_spec.neg_counts, in_memory=False)
                    ]
                    for task, task_spec in self.ds.task_specs.items()
                    if task in self.tasks
                }
                self.bias_bw_extractors = {
                    task: [
                        ArrayExtractor(task_spec.pos_counts, in_memory=False),
                        ArrayExtractor(task_spec.neg_counts, in_memory=False)
                    ]
                    for task, task_spec in self.ds.bias_specs.items()
                    if task in self.tasks
                }
            else:
                # Use normal fasta/bigwig extractors
                assert not self.bcolz
                # first call
                self.fasta_extractor = FastaExtractor(self.ds.fasta_file,
                                                      use_strand=True)
                self.bw_extractors = {
                    task: [
                        BigwigExtractor(task_spec.pos_counts),
                        BigwigExtractor(task_spec.neg_counts)
                    ]
                    for task, task_spec in self.ds.task_specs.items()
                    if task in self.tasks
                }
                self.bias_bw_extractors = {
                    task: [
                        BigwigExtractor(task_spec.pos_counts),
                        BigwigExtractor(task_spec.neg_counts)
                    ]
                    for task, task_spec in self.ds.bias_specs.items()
                }

        # Setup the intervals
        interval = Interval(
            self.dfm.iat[idx, 0],  # chrom
            self.dfm.iat[idx, 1],  # start
            self.dfm.iat[idx, 2])  # end

        # Transform the input interval (for say augmentation...)
        if self.interval_transformer is not None:
            interval = self.interval_transformer(interval)

        target_interval = resize_interval(deepcopy(interval), self.peak_width)
        seq_interval = resize_interval(deepcopy(interval), self.seq_width)

        # This only kicks in when we specify the taskname from dataspec
        # to the 3rd column. E.g. it doesn't apply when using intervals_file
        interval_from_task = self.dfm.iat[
            idx, 3] if self.intervals_file is None else ''

        # extract seq + tracks
        sequence = self.fasta_extractor([seq_interval])[0]

        if not self.only_classes:
            if self.taskname_first:
                cuts = {
                    f"{task}/profile":
                    run_extractors(self.bw_extractors[task], [target_interval],
                                   ignore_strand=spec.ignore_strand)[0]
                    for task, spec in self.ds.task_specs.items()
                    if task in self.tasks
                }
            else:
                cuts = {
                    f"profile/{task}":
                    run_extractors(self.bw_extractors[task], [target_interval],
                                   ignore_strand=spec.ignore_strand)[0]
                    for task, spec in self.ds.task_specs.items()
                    if task in self.tasks
                }

            # Add counts
            if self.target_transformer is not None:
                cuts = self.target_transformer.transform(cuts)

            # Add bias tracks
            if len(self.ds.bias_specs) > 0:

                biases = {
                    bias_task:
                    run_extractors(self.bias_bw_extractors[bias_task],
                                   [target_interval],
                                   ignore_strand=spec.ignore_strand)[0]
                    for bias_task, spec in self.ds.bias_specs.items()
                }

                task_biases = {
                    f"bias/{task}/profile": np.concatenate(
                        [biases[bt] for bt in self.task_bias_tracks[task]],
                        axis=-1)
                    for task in self.tasks
                }

                if self.target_transformer is not None:
                    for task in self.tasks:
                        task_biases[f'bias/{task}/counts'] = np.log(
                            1 + task_biases[f'bias/{task}/profile'].sum(0))
                    # total_count_bias = np.concatenate([np.log(1 + x[k].sum(0))
                    #                                    for k, x in biases.items()], axis=-1)
                    # task_biases['bias/total_counts'] = total_count_bias

                if self.profile_bias_pool_size is not None:
                    for task in self.tasks:
                        task_biases[f'bias/{task}/profile'] = np.concatenate(
                            [
                                moving_average(
                                    task_biases[f'bias/{task}/profile'],
                                    n=pool_size) for pool_size in to_list(
                                        self.profile_bias_pool_size)
                            ],
                            axis=-1)

                sequence = {"seq": sequence, **task_biases}
        else:
            cuts = dict()

        if self.include_classes:
            if self.taskname_first:
                # Get the classes from the tsv file
                classes = {
                    f"{task}/class": self.dfm.iat[idx, i + 3]
                    for i, task in enumerate(self.dfm_tasks)
                    if task in self.tasks
                }
            else:
                classes = {
                    f"class/{task}": self.dfm.iat[idx, i + 3]
                    for i, task in enumerate(self.dfm_tasks)
                    if task in self.tasks
                }
            cuts = {**cuts, **classes}

        out = {"inputs": sequence, "targets": cuts}

        if self.include_metadata:
            out['metadata'] = {
                "range":
                GenomicRanges(
                    chr=target_interval.chrom,
                    start=target_interval.start,
                    end=target_interval.stop,
                    id=idx,
                    strand=(target_interval.strand
                            if target_interval.strand is not None else "*"),
                ),
                "interval_from_task":
                interval_from_task
            }
        return out
Example #28
def test_fasta_extractor_over_chr_end():
    extractor = FastaExtractor('tests/data/fasta_test.fa')
    intervals = [Interval('chr1', 0, 100), Interval('chr1', 1, 101)]
    with pytest.raises(ValueError):
        data = extractor(intervals)
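
Since extraction past a chromosome end raises ValueError, callers typically filter or clamp intervals first. A sketch using pysam to drop out-of-range intervals, assuming pysam is available (FastaFile reads the .fai index):

from pysam import FastaFile
from pybedtools import Interval

fa = FastaFile("tests/data/fasta_test.fa")
intervals = [Interval("chr1", 0, 100), Interval("chr1", 1, 101)]
# Keep only intervals fully contained in their chromosome.
valid = [i for i in intervals
         if i.start >= 0 and i.stop <= fa.get_reference_length(i.chrom)]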
Example #29
    def __init__(self, intervals_file, fasta_file):
        self.bt = BedTool(intervals_file)
        self.fasta_extractor = FastaExtractor(fasta_file)
Example #30
    def __getitem__(self, idx):
        from pybedtools import Interval

        if self.fasta_extractor is None:
            # first call
            # Use normal fasta/bigwig extractors
            self.fasta_extractor = FastaExtractor(self.ds.fasta_file, use_strand=True)

            self.bw_extractors = {task: [BigwigExtractor(track) for track in task_spec.tracks]
                                  for task, task_spec in self.ds.task_specs.items() if task in self.tasks}

            self.bias_bw_extractors = {task: [BigwigExtractor(track) for track in task_spec.tracks]
                                       for task, task_spec in self.ds.bias_specs.items()}

        # Get the genomic interval for that particular datapoint
        interval = Interval(self.dfm.iat[idx, 0],  # chrom
                            self.dfm.iat[idx, 1],  # start
                            self.dfm.iat[idx, 2])  # end

        # Transform the input interval (for say augmentation...)
        if self.interval_transformer is not None:
            interval = self.interval_transformer(interval)

        # resize the intervals to the desired widths
        target_interval = resize_interval(deepcopy(interval), self.peak_width)
        seq_interval = resize_interval(deepcopy(interval), self.seq_width)

        # This only kicks in when we specify the taskname from dataspec
        # to the 3rd column. E.g. it doesn't apply when using intervals_file
        interval_from_task = self.dfm.iat[idx, 3] if self.intervals_file is None else ''

        # extract DNA sequence + one-hot encode it
        sequence = self.fasta_extractor([seq_interval])[0]
        inputs = {"seq": sequence}

        # extract the profile counts from the bigwigs
        cuts = {f"{task}/profile": _run_extractors(self.bw_extractors[task],
                                                   [target_interval],
                                                   sum_tracks=spec.sum_tracks)[0]
                for task, spec in self.ds.task_specs.items() if task in self.tasks}
        if self.track_transform is not None:
            for task in self.tasks:
                cuts[f'{task}/profile'] = self.track_transform(cuts[f'{task}/profile'])

        # Add total number of counts
        for task in self.tasks:
            cuts[f'{task}/counts'] = self.total_count_transform(cuts[f'{task}/profile'].sum(0))

        if len(self.ds.bias_specs) > 0:
            # Extract the bias tracks
            biases = {bias_task: _run_extractors(self.bias_bw_extractors[bias_task],
                                                 [target_interval],
                                                 sum_tracks=spec.sum_tracks)[0]
                      for bias_task, spec in self.ds.bias_specs.items()}

            task_biases = {f"bias/{task}/profile": np.concatenate([biases[bt]
                                                                   for bt in self.task_bias_tracks[task]],
                                                                  axis=-1)
                           for task in self.tasks}

            if self.track_transform is not None:
                for task in self.tasks:
                    task_biases[f'bias/{task}/profile'] = self.track_transform(task_biases[f'bias/{task}/profile'])

            # Add total number of bias counts
            for task in self.tasks:
                task_biases[f'bias/{task}/counts'] = self.total_count_transform(task_biases[f'bias/{task}/profile'].sum(0))

            inputs = {**inputs, **task_biases}

        if self.include_classes:
            # Optionally, add binary labels from the additional columns in the tsv intervals file
            classes = {f"{task}/class": self.dfm.iat[idx, i + 3]
                       for i, task in enumerate(self.dfm_tasks) if task in self.tasks}
            cuts = {**cuts, **classes}

        out = {"inputs": inputs,
               "targets": cuts}

        if self.include_metadata:
            # remember the metadata (what genomic interval was used)
            out['metadata'] = {"range": GenomicRanges(chr=target_interval.chrom,
                                                      start=target_interval.start,
                                                      end=target_interval.stop,
                                                      id=idx,
                                                      strand=(target_interval.strand
                                                              if target_interval.strand is not None
                                                              else "*"),
                                                      ),
                               "interval_from_task": interval_from_task}
        return out