Exemple #1
0
    def __getitem__(self, idx):
        """Return the raw string sequence, labels and range metadata for bed row idx."""
        # Build the extractor on first access only.
        if self.fasta_extractors is None:
            self.fasta_extractors = FastaStringExtractor(
                self.fasta_file,
                use_strand=False,  # self.use_strand,
                force_upper=self.force_upper)

        interval, labels = self.bed[idx]

        # Optionally center-resize the interval to the configured length.
        if self.auto_resize_len:
            interval = resize_interval(interval,
                                       self.auto_resize_len,
                                       anchor='center')

        # QUESTION: @kromme - why to we need max_seq_len?
        # if self.max_seq_len is not None:
        #     assert interval.stop - interval.start <= self.max_seq_len

        # Extract the sequence as a plain string.
        extracted = self.fasta_extractors.extract(interval)

        ranges = GenomicRanges(interval.chrom, interval.start,
                               interval.stop, str(idx))
        return {
            "inputs": np.array(extracted),
            "targets": labels,
            "metadata": {"ranges": ranges},
        }
    def __init__(self,
                 fasta_file: str,
                 bed_file: str = None,
                 variant_type='snv',
                 alphabet='DNA'):
        """Prepare a fasta-backed variant generator.

        Args:
            fasta_file: path to the reference fasta.
            bed_file: optional bed file restricting the regions.
            variant_type: one of {'all', 'snv', 'in', 'del'}.
            alphabet: key into the module-level `alphabets` mapping.

        Raises:
            ValueError: if `variant_type` is not a recognized kind.
        """
        if variant_type not in {'all', 'snv', 'in', 'del'}:
            raise ValueError("variant_type should be one of "
                             "{'all', 'snv', 'in', 'del'}")

        self.bed_file = bed_file
        # Bug fix: the original assigned `self.fasta = fasta_file` and then
        # immediately overwrote it with the extractor, losing the path.
        # Keep the path and the extractor under separate attributes.
        self.fasta_file = fasta_file
        self.fasta = FastaStringExtractor(fasta_file, force_upper=True)
        self.variant_type = variant_type
        self.alphabet = alphabets[alphabet]
Exemple #3
0
class MyDataset(Dataset):
    """Example re-implementation of kipoiseq.dataloaders.SeqIntervalDl
    Args:
        intervals_file: bed3 file containing intervals
        fasta_file: file path; Genome sequence
    """
    def __init__(self, intervals_file, fasta_file, ignore_targets=True):
        self.bt = BedDataset(intervals_file,
                             bed_columns=3,
                             ignore_targets=ignore_targets)
        self.fasta_file = fasta_file
        self.fasta_extractor = None  # created lazily in __getitem__
        self.transform = OneHot()  # one-hot encode DNA sequence

    def __len__(self):
        return len(self.bt)

    def __getitem__(self, idx):
        # Bug fix: the extractor was re-created on every item, defeating the
        # `self.fasta_extractor = None` lazy-init placeholder set in __init__.
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaStringExtractor(self.fasta_file)

        # get the intervals
        interval, targets = self.bt[idx]

        # resize to 500bp
        interval = resize_interval(interval, 500, anchor='center')

        # extract the sequence
        seq = self.fasta_extractor.extract(interval)

        # one-hot encode the sequence
        seq_onehot = self.transform(seq)

        ranges = GenomicRanges.from_interval(interval)

        return {"inputs": [seq_onehot], "metadata": [ranges]}
Exemple #4
0
    def __init__(self, fasta_file: str = None, ref_seq_extractor: BaseExtractor = None):
        """
        Sequence extractor which allows to obtain the alternative sequence,
        given some interval and variants inside this interval.

        Args:
          fasta_file: path to the fasta file (can be gzipped)
          ref_seq_extractor: extractor returning the reference sequence given some interval

        Raises:
          ValueError: if neither or both of `fasta_file` and
            `ref_seq_extractor` are provided (exactly one is required).
        """
        if fasta_file is not None:
            if ref_seq_extractor is not None:
                # Bug fix: the old message ("either ... have to be specified")
                # was raised when BOTH arguments were given, which reads as if
                # one were missing.
                raise ValueError("fasta_file and ref_seq_extractor are "
                                 "mutually exclusive; specify only one")
            self._ref_seq_extractor = FastaStringExtractor(fasta_file, use_strand=True)
        else:
            if ref_seq_extractor is None:
                raise ValueError("either fasta_file or ref_seq_extractor have to be specified")
            self._ref_seq_extractor = ref_seq_extractor
Exemple #5
0
    def __getitem__(self, idx):
        """Return the one-hot encoded 500bp center-resized sequence for bed row idx."""
        # Bug fix: the extractor was re-created on every item; build it once
        # and reuse it afterwards (self.fasta_extractor starts out as None).
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaStringExtractor(self.fasta_file)

        # get the intervals
        interval, targets = self.bt[idx]

        # resize to 500bp
        interval = resize_interval(interval, 500, anchor='center')

        # extract the sequence
        seq = self.fasta_extractor.extract(interval)

        # one-hot encode the sequence
        seq_onehot = self.transform(seq)

        ranges = GenomicRanges.from_interval(interval)

        return {"inputs": [seq_onehot], "metadata": [ranges]}
Exemple #6
0
    def __getitem__(self, idx):
        """Return sequence + distance-to-polyA inputs (and optional targets) for interval idx."""
        # Build both extractors once; reuse them for every subsequent item.
        if self.input_data_extractors is None:
            self.input_data_extractors = {
                "seq":
                FastaStringExtractor(self.fasta_file),
                "dist_polya_st":
                DistToClosestLandmarkExtractor(gtf_file=self.gtf,
                                               landmarks=["polya"])
            }

        interval = self.bt[idx]

        out = {}

        # Bug fix: the cached extractors above were built but never used --
        # each item re-created a FastaStringExtractor and a
        # DistToClosestLandmarkExtractor from scratch.
        out['inputs'] = {
            "seq":
            one_hot_dna(
                self.input_data_extractors["seq"].extract(interval)),
            "dist_polya_st":
            np.squeeze(self.input_data_extractors["dist_polya_st"]([interval]),
                       axis=0)
        }

        # use trained spline transformation to transform it
        out["inputs"]["dist_polya_st"] = np.squeeze(self.transformer.transform(
            out["inputs"]["dist_polya_st"][np.newaxis], warn=False),
                                                    axis=0)

        if self.target_dataset is not None:
            out["targets"] = np.array([self.target_dataset[idx]])

        # get metadata
        out['metadata'] = {}
        out['metadata']['ranges'] = {}
        out['metadata']['ranges']['chr'] = interval.chrom
        out['metadata']['ranges']['start'] = interval.start
        out['metadata']['ranges']['end'] = interval.stop
        out['metadata']['ranges']['id'] = interval.name
        out['metadata']['ranges']['strand'] = interval.strand

        return out
Exemple #7
0
 def __getitem__(self, idx):
     """One-hot encoded upstream sequence plus GTF attribute metadata for feature idx."""
     # Strand-aware, uppercase extractor, built on first use.
     if self.fasta_extractor is None:
         self.fasta_extractor = FastaStringExtractor(self.fasta_file,
                                                     use_strand=True,
                                                     force_upper=True)
     start_codon = self.start_codons[idx]
     upstream_interval = get_upstream(start_codon, self.n_upstream)
     encoded = self.input_transform(
         self.fasta_extractor.extract(upstream_interval))

     attrs = start_codon.attributes
     metadata = {"ranges": GenomicRanges.from_interval(upstream_interval)}
     # Each GTF attribute defaults to "" when absent.
     for key in ("gene_id", "transcript_id", "gene_biotype"):
         metadata[key] = attrs.get(key, [""])[0]
     return {"inputs": encoded, "metadata": metadata}
Exemple #8
0
def test_fastareader(use_strand, force_upper):
    """FastaStringExtractor output matches a manual read of the fasta file."""
    fasta_path = "tests/data/sample.fasta"
    with open(fasta_path, "r") as handle:
        # keep the second line of the file (the sequence line)
        for line_no, line in enumerate(handle):
            if line_no == 1:
                raw_seq = line.lstrip()
    extractor = FastaStringExtractor(fasta_path, use_strand, force_upper)

    for iv in (Interval("chr1", 0, 2, strand="-"), Interval("chr1", 3, 4)):
        observed = extractor.extract(iv)
        expected = raw_seq[iv.start:iv.end]
        if use_strand and iv.strand == "-":
            # reverse-complement the manual slice
            expected = "".join(comp[base] for base in reversed(expected))
        if force_upper:
            expected = expected.upper()
        assert observed == expected
Exemple #9
0
    def __init__(
        self,
        fasta_file,
        gtf_file,
    ):
        """Derive regions of interest from the GTF and attach a fasta-backed reference."""
        annotation_df = pr.read_gtf(gtf_file, as_df=True)
        regions = pr.PyRanges(get_roi_from_genome_annotation(annotation_df))

        super().__init__(
            regions_of_interest=regions,
            reference_sequence=FastaStringExtractor(fasta_file),
        )
Exemple #10
0
class SeqDataset(Dataset):
    """
    Args:
        intervals_file: bed3 file containing intervals
        fasta_file: file path; Genome sequence
        target_file: file path; path to the targets in the csv format
    """

    def __init__(self, intervals_file, fasta_file, target_file=None, use_linecache=False):
        # intervals: optionally use the linecache-backed bed reader
        reader_cls = BedToolLinecache if use_linecache else BedTool
        self.bt = reader_cls(intervals_file)
        self.fasta_file = fasta_file
        self.fasta_extractor = None  # built lazily on first __getitem__

        # Targets (optional csv)
        self.targets = pd.read_csv(target_file) if target_file is not None else None

    def __len__(self):
        return len(self.bt)

    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaStringExtractor(self.fasta_file)

        interval = self.bt[idx]

        # Intervals need to be 1000bp wide
        assert interval.stop - interval.start == 1000

        # empty dict when no target csv was supplied
        y = self.targets.iloc[idx].values if self.targets is not None else {}

        # Run the fasta extractor
        seq = one_hot_dna(self.fasta_extractor.extract(interval), dtype=np.float32) # TODO: Remove additional dtype after kipoiseq gets a new release
        return {
            "inputs": seq,
            "targets": y,
            "metadata": {"ranges": GenomicRanges.from_interval(interval)},
        }
Exemple #11
0
    def __getitem__(self, idx):
        """Return one-hot sequence, targets and range metadata for bed row idx."""
        # Build the extractor only on first access (lazy init).
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaStringExtractor(self.fasta_file)

        interval = self.bt[idx]

        # Intervals need to be 1000bp wide
        # NOTE(review): `assert` is stripped under `python -O`; presumably the
        # downstream model requires exactly 1000bp inputs — confirm.
        assert interval.stop - interval.start == 1000

        # Targets come from the csv when provided; an empty dict otherwise.
        if self.targets is not None:
            y = self.targets.iloc[idx].values
        else:
            y = {}

        # Run the fasta extractor
        seq = one_hot_dna(self.fasta_extractor.extract(interval), dtype=np.float32) # TODO: Remove additional dtype after kipoiseq gets a new release
        return {
            "inputs": seq,
            "targets": y,
            "metadata": {
                "ranges": GenomicRanges.from_interval(interval)
            }
        }
Exemple #12
0
    def __init__(self,
                 gtf_file,
                 fasta_file,
                 num_upstream,
                 num_downstream,
                 gtf_filter='gene_type == "protein_coding"',
                 anchor='tss',
                 transform=one_hot_dna,
                 interval_attrs=["gene_id", "Strand"],
                 use_strand=True):
        """Load and filter the GTF, resolve the anchor function and set up extraction."""
        # Read and filter gtf; the filter is either a pandas query string
        # or a callable applied to the dataframe.
        gtf = pr.read_gtf(gtf_file).df
        if gtf_filter:
            gtf = gtf.query(gtf_filter) if isinstance(gtf_filter, str) else gtf_filter(gtf)

        # Resolve the anchor: a known name from _function_mapping, or a callable.
        if isinstance(anchor, str):
            key = anchor.lower()
            if key not in self._function_mapping:
                raise Exception("No valid anchorpoint was chosen")
            anchor = self._function_mapping[key]
        self._gtf_anchor = anchor(gtf)

        # Other parameters
        self._use_strand = use_strand
        self._fa = FastaStringExtractor(fasta_file,
                                        use_strand=self._use_strand)
        # Identity transform when none is given.
        self._transform = transform if transform is not None else (lambda x: x)
        self._num_upstream = num_upstream
        self._num_downstream = num_downstream
        self._interval_attrs = interval_attrs
Exemple #13
0
    def __init__(
        self,
        fasta_file,
        gtf_file,
        vcf_file,
        vcf_file_tbi=None,
        vcf_lazy=True,
    ):
        """Build regions of interest from the GTF and wire fasta reference + VCF variants."""
        annotation = pr.read_gtf(gtf_file, as_df=True)
        regions = pr.PyRanges(get_roi_from_genome_annotation(annotation))

        from kipoiseq.extractors import MultiSampleVCF
        super().__init__(regions_of_interest=regions,
                         reference_sequence=FastaStringExtractor(fasta_file),
                         variants=MultiSampleVCF(vcf_file, lazy=vcf_lazy))
Exemple #14
0
    def __init__(self,
                 gtf_file,
                 fasta_file,
                 vcf_file,
                 feature_type,
                 infer_from_cds=False,
                 on_error_warn=True,
                 vcf_file_tbi=None,
                 **kwargs):
        """Wire together UTR interval fetching, variant matching and variant
        sequence extraction.

        Args:
            gtf_file: genome annotation used by UTRFetcher to locate UTRs.
            fasta_file: reference genome fasta.
            vcf_file: variants to inject into the reference sequence.
            feature_type: UTR feature type, forwarded to UTRFetcher.
            infer_from_cds: forwarded to UTRFetcher.
            on_error_warn: forwarded to UTRFetcher.
            vcf_file_tbi: not used in this constructor — presumably the tabix
                index path handled elsewhere; verify.
            **kwargs: accepted but ignored here.
        """
        self.gtf_file = gtf_file
        self.fasta_file = fasta_file
        self.vcf_file = vcf_file
        self.feature_type = feature_type
        self.infer_from_cds = infer_from_cds
        self.on_error_warn = on_error_warn

        # Fetches the UTR intervals from the annotation.
        self.interval_fetcher = UTRFetcher(gtf_file=gtf_file,
                                           feature_type=feature_type,
                                           infer_from_cds=infer_from_cds,
                                           on_error_warn=on_error_warn)
        self.multi_sample_VCF = MultiSampleVCF(vcf_file)
        self.reference_seq_extractor = FastaStringExtractor(fasta_file)

        df = self.interval_fetcher.df
        import pyranges
        # match variant with transcript_id
        self.variant_matcher = SingleVariantMatcher(self.vcf_file,
                                                    pranges=pyranges.PyRanges(
                                                        df.reset_index()))

        self.extractor = GenericSingleVariantMultiIntervalVCFSeqExtractor(
            interval_fetcher=self.interval_fetcher,
            reference_seq_extractor=self.reference_seq_extractor,
            variant_matcher=self.variant_matcher,
            multi_sample_VCF=self.multi_sample_VCF,
        )

        # # only needed metadata
        # self.metadatas = (
        #     (
        #         df.loc[~df.index.duplicated(keep='first')]
        #     ).drop(columns=['Start', 'End'])
        # )
        # generator for all sequences with variants
        # NOTE(review): only `self.extractor` is assigned above, not
        # `self._extractor` — presumably `_extractor` is a method defined
        # elsewhere in this class; confirm it exists.
        self.sequences = self._extractor()
Exemple #15
0
class FixedSeq5UtrDl(Dataset):
    """Dataset yielding one-hot sequences upstream of annotated start codons."""

    # number of bases taken upstream of each start codon
    n_upstream = 50

    def __init__(self,
                 gtf_file,
                 fasta_file,
                 disable_infer_transcripts=True,
                 disable_infer_genes=True):
        self.gtf_file = gtf_file
        self.fasta_file = fasta_file

        self.fasta_extractor = None  # built lazily in __getitem__

        # In-memory gffutils database of the annotation.
        self.db = gffutils.create_db(
            gtf_file,
            ":memory:",
            disable_infer_transcripts=disable_infer_transcripts,
            disable_infer_genes=disable_infer_genes)
        self.start_codons = list(self.db.features_of_type("start_codon"))
        self.input_transform = OneHot()

    def __len__(self):
        return len(self.start_codons)

    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaStringExtractor(self.fasta_file,
                                                        use_strand=True,
                                                        force_upper=True)
        codon = self.start_codons[idx]
        region = get_upstream(codon, self.n_upstream)
        encoded = self.input_transform(self.fasta_extractor.extract(region))

        # GTF attributes default to "" when absent.
        metadata = {"ranges": GenomicRanges.from_interval(region)}
        for attr in ("gene_id", "transcript_id", "gene_biotype"):
            metadata[attr] = codon.attributes.get(attr, [""])[0]
        return {"inputs": encoded, "metadata": metadata}
Exemple #16
0
    def __init__(self, fasta_file: str = None, reference_sequence: BaseExtractor = None, use_strand=True):
        """
        Sequence extractor which allows to obtain the alternative sequence,
        given some interval and variants inside this interval.

        Args:
            fasta_file: path to the fasta file (can be gzipped)
            reference_sequence: extractor returning the reference sequence given some interval
            use_strand (bool): if True, the extracted sequence
                is reverse complemented in case interval.strand == "-"

        Raises:
            ValueError: if neither or both of `fasta_file` and
                `reference_sequence` are provided (exactly one is required).
        """
        self._use_strand = use_strand

        # Bug fix: the old error messages referred to `ref_seq_extractor`,
        # which is not the name of any parameter of this signature
        # (the parameter is `reference_sequence`).
        if fasta_file is not None:
            if reference_sequence is not None:
                raise ValueError(
                    "fasta_file and reference_sequence are mutually "
                    "exclusive; specify only one of them")
            # strand handling is done by this class, so keep the underlying
            # fasta extractor strand-agnostic
            self._ref_seq_extractor = FastaStringExtractor(
                fasta_file, use_strand=False)
        else:
            if reference_sequence is None:
                raise ValueError(
                    "either fasta_file or reference_sequence have to be specified")
            self._ref_seq_extractor = reference_sequence
def test_extract(variant_seq_extractor):
    """Exercise VariantSeqExtractor.extract across intervals, strands and anchors."""
    variants = [Variant.from_cyvcf(v) for v in VCF(vcf_file)]

    # Fixed-length mode: (interval, anchor, expected sequence).
    fixed_len_cases = [
        (Interval('chr1', 2, 9), 5, 'CGAACGT'),
        (Interval('chr1', 2, 9, strand='-'), 5, 'ACGTTCG'),
        (Interval('chr1', 4, 14), 7, 'AACGTAACGT'),
        (Interval('chr1', 4, 14), 4, 'GAACGTAACG'),
        (Interval('chr1', 2, 5), 3, 'GCG'),
        (Interval('chr1', 24, 34), 27, 'TGATAACGTA'),
        (Interval('chr1', 25, 35), 34, 'TGATAACGTA'),
        (Interval('chr1', 34, 44), 37, 'AACGTAACGT'),
        (Interval('chr1', 34, 44), 100, 'AACGTAACGT'),
    ]
    for interval, anchor, expected in fixed_len_cases:
        seq = variant_seq_extractor.extract(interval, variants, anchor=anchor)
        assert len(seq) == interval.end - interval.start
        assert seq == expected

    # Variable-length mode: the result is not padded to the interval width.
    variable_len_cases = [
        (Interval('chr1', 5, 11, strand='+'), 'ACGTAA'),
        (Interval('chr1', 0, 3, strand='+'), 'ACG'),
    ]
    for interval, expected in variable_len_cases:
        seq = variant_seq_extractor.extract(interval,
                                            variants,
                                            anchor=10,
                                            fixed_len=False)
        assert seq == expected

    # An extractor built from a strand-aware reference gives the same result.
    interval = Interval('chr1', 0, 3, strand='+')
    ref_seq_extractor = FastaStringExtractor(fasta_file, use_strand=True)
    seq = VariantSeqExtractor(reference_sequence=ref_seq_extractor).extract(
        interval, variants, anchor=10, fixed_len=False)
    assert seq == 'ACG'
Exemple #18
0
class StringSeqIntervalDl(Dataset):
    """
    info:
        doc: >
           Dataloader for a combination of fasta and tab-delimited input files such as bed files. The dataloader extracts
           regions from the fasta file as defined in the tab-delimited `intervals_file`. Returned sequences are of the type
           np.array([str]).
    args:
        intervals_file:
            doc: bed3+<columns> file path containing intervals + (optionally) labels
            example:
              url: https://raw.githubusercontent.com/kipoi/kipoiseq/master/tests/data/intervals_51bp.tsv
              md5: a76e47b3df87fd514860cf27fdc10eb4
        fasta_file:
            doc: Reference genome FASTA file path.
            example:
              url: https://raw.githubusercontent.com/kipoi/kipoiseq/master/tests/data/hg38_chr22_32000000_32300000.fa
              md5: 01320157a250a3d2eea63e89ecf79eba
        num_chr_fasta:
            doc: True, the the dataloader will make sure that the chromosomes don't start with chr.
        label_dtype:
            doc: None, datatype of the task labels taken from the intervals_file. Example - str, int, float, np.float32
        auto_resize_len:
            doc: None, required sequence length.
        # max_seq_len:
        #     doc: maximum allowed sequence length
        # use_strand:
        #     doc: reverse-complement fasta sequence if bed file defines negative strand
        force_upper:
            doc: Force uppercase output of sequences
        ignore_targets:
            doc: if True, don't return any target variables
    output_schema:
        inputs:
            name: seq
            shape: ()
            doc: DNA sequence as string
            special_type: DNAStringSeq
            associated_metadata: ranges
        targets:
            shape: (None,)
            doc: (optional) values following the bed-entry - chr  start  end  target1   target2 ....
        metadata:
            ranges:
                type: GenomicRanges
                doc: Ranges describing inputs.seq
    postprocessing:
        variant_effects:
          bed_input:
            - intervals_file
    """
    # NOTE(review): the docstring above appears to be a machine-readable
    # kipoi dataloader spec (YAML) — presumably parsed downstream, so it is
    # deliberately left byte-identical (including the "the the" typo).
    def __init__(
            self,
            intervals_file,
            fasta_file,
            num_chr_fasta=False,
            label_dtype=None,
            auto_resize_len=None,
            # max_seq_len=None,
            # use_strand=False,
            force_upper=True,
            ignore_targets=False):

        self.num_chr_fasta = num_chr_fasta
        self.intervals_file = intervals_file
        self.fasta_file = fasta_file
        self.auto_resize_len = auto_resize_len
        # self.use_strand = use_strand
        self.force_upper = force_upper
        # self.max_seq_len = max_seq_len

        # if use_strand:
        #     # require a 6-column bed-file if strand is used
        #     bed_columns = 6
        # else:
        #     bed_columns = 3

        # Parse the bed file; targets are the columns following bed3.
        self.bed = BedDataset(self.intervals_file,
                              num_chr=self.num_chr_fasta,
                              bed_columns=3,
                              label_dtype=parse_dtype(label_dtype),
                              ignore_targets=ignore_targets)
        # The fasta extractor is created lazily in __getitem__.
        self.fasta_extractors = None

    def __len__(self) -> int:
        return len(self.bed)

    def __getitem__(self, idx):
        # Lazy creation of the fasta extractor on first access.
        if self.fasta_extractors is None:
            self.fasta_extractors = FastaStringExtractor(
                self.fasta_file,
                use_strand=False,  # self.use_strand,
                force_upper=self.force_upper)

        interval, labels = self.bed[idx]

        if self.auto_resize_len:
            # automatically center-resize the interval to `auto_resize_len`
            interval = resize_interval(interval,
                                       self.auto_resize_len,
                                       anchor='center')

        # QUESTION: @kromme - why to we need max_seq_len?
        # if self.max_seq_len is not None:
        #     assert interval.stop - interval.start <= self.max_seq_len

        # Run the fasta extractor and transform if necessary
        seq = self.fasta_extractors.extract(interval)

        return {
            "inputs": np.array(seq),
            "targets": labels,
            "metadata": {
                "ranges":
                GenomicRanges(interval.chrom, interval.start, interval.stop,
                              str(idx))
            }
        }

    @classmethod
    def get_output_schema(cls):
        """Return a copy of the class output schema, dropping targets when
        the default kwargs ignore them."""
        output_schema = deepcopy(cls.output_schema)
        kwargs = default_kwargs(cls)
        ignore_targets = kwargs['ignore_targets']
        if ignore_targets:
            output_schema.targets = None
        return output_schema
Exemple #19
0
class StrandedSequenceVariantDataloader(Dataset):
    """ This Dataloader requires the following input files:
        1. bed3+ where a specific user-specified column (>3, 1-based) of the bed denotes the strand
        and a specific user-specified column (>3, 1-based) of the bed denotes the transcript id
        (or some other id that explains which exons in the bed belong together to form one sequence).
        All columns of the bed, except the first three, the id and the strand, are ignored. 
        2. fasta file that provides the reference genome
        3. bgzip compressed (single sample) vcf that provides the variants
        4. A chromosome order file (such as a fai file) that specifies the order of chromosomes 
        (must be valid for all files)
        The bed and vcf must both be sorted (by position) and a tabix index must be present.
        (must lie in the same directory and have the same name + .tbi)
        The num_chr flag indicates whether chromosomes are listed numerically or with a chr prefix.
        This must be consistent across all input files!
        The dataloader finds all intervals in the bed which contain at least one variant in the vcf.
        It then joins intervals belonging to the same transcript, as specified by the id, to a single sequence.
        For these sequences, it extracts the reference sequence from the fasta file, 
        injects the applicable variants and reverse complements according to the strand information.
        This means that if a vcf mixes variants from more than one patient, the results will not be
        meaningful. Split the vcf by patient and run the predictions separately in this case!
        Returns the reference sequence and variant sequence as 
        np.array([reference_sequence, variant_sequence]). 
        Region metadata is additionally provided
"""
    def __init__(self,
                 intervals_file,
                 fasta_file,
                 vcf_file,
                 chr_order_file,
                 vcf_file_tbi=None,
                 strand_column=6,
                 id_column=4,
                 num_chr=True):

        # workaround for test
        if vcf_file_tbi is not None and vcf_file_tbi.endswith("vcf_file_tbi"):
            os.rename(vcf_file_tbi,
                      vcf_file_tbi.replace("vcf_file_tbi", "vcf_file.tbi"))

        self.num_chr_fasta = num_chr
        self.intervals_file = intervals_file
        self.fasta_file = fasta_file
        self.vcf_file = vcf_file
        self.chr_order_file = chr_order_file

        # Convert the 1-based user-facing column numbers to 0-based indices.
        self.strand_column = strand_column - 1
        self.id_column = id_column - 1

        self.force_upper = True

        # "Parse" bed file
        self.bed = BedDataset(self.intervals_file,
                              num_chr=self.num_chr_fasta,
                              bed_columns=3,
                              label_dtype=str,
                              ignore_targets=False)

        # Intersect bed and vcf using bedtools
        # bedtools c flag: for each bed interval, counts number of vcf entries it overlaps
        bed_tool = pybedtools.BedTool(self.intervals_file)
        intersect_counts = list(
            bed_tool.intersect(self.vcf_file,
                               c=True,
                               sorted=True,
                               g=self.chr_order_file))
        intersect_counts = np.array(
            [isect.count for isect in intersect_counts])

        # Retain only those transcripts that intersect a variant
        utr5_bed = self.bed.df
        # NOTE(review): `id_col` is computed but never used below — candidate
        # for removal; kept as-is here.
        id_col = utr5_bed.iloc[:, self.id_column]
        retain_transcripts = utr5_bed[
            intersect_counts > 0].iloc[:, self.id_column]
        utr5_bed = utr5_bed[utr5_bed.iloc[:, self.id_column].isin(
            retain_transcripts)]

        # Aggregate 5utr positions per transcript
        # Each row becomes a one-element list of (start, stop) so that the
        # groupby 'sum' below concatenates the lists per transcript.
        tuples = list(zip(utr5_bed.iloc[:, 1], utr5_bed.iloc[:, 2]))
        pos = [[x] for x in tuples]
        id_chr_strand = list(
            zip(utr5_bed.iloc[:, self.id_column], utr5_bed.iloc[:, 0],
                utr5_bed.iloc[:, self.strand_column]))
        utr5_bed_posaggreg = pd.DataFrame({
            "pos": pos,
            "id_chr_strand": id_chr_strand
        })
        utr5_bed_posaggreg = utr5_bed_posaggreg.groupby("id_chr_strand").agg(
            {'pos': 'sum'})

        # Rebuild "bed": unpack the (id, chr, strand) group key into columns.
        utr5_bed_posaggreg["id"] = [x[0] for x in utr5_bed_posaggreg.index]
        utr5_bed_posaggreg["chr"] = [x[1] for x in utr5_bed_posaggreg.index]
        utr5_bed_posaggreg["strand"] = [x[2] for x in utr5_bed_posaggreg.index]
        self.bed = utr5_bed_posaggreg.reset_index()[[
            "id", "chr", "pos", "strand"
        ]]

        # Extractors and VCF reader are created lazily in __getitem__.
        self.fasta_extractor = None
        self.vcf = None
        self.vcf_extractor = None

    def __len__(self):
        return len(self.bed)

    def __getitem__(self, idx):
        # Lazily create the strand-aware fasta extractor, the VCF reader and
        # the variant sequence extractor on first access.
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaStringExtractor(
                self.fasta_file, use_strand=True, force_upper=self.force_upper)
        if self.vcf is None:
            self.vcf = MultiSampleVCF(self.vcf_file)
        if self.vcf_extractor is None:
            self.vcf_extractor = VariantSeqExtractor(self.fasta_file)

        entry = self.bed.iloc[idx]
        entry_id = entry["id"]
        entry_chr = entry["chr"]
        entry_pos = entry["pos"]
        entry_strand = entry["strand"]

        ref_exons = []
        var_exons = []
        exon_pos_strings = []
        exon_var_strings = []
        for exon in entry_pos:
            # We get the interval
            interval = pybedtools.Interval(to_scalar(entry_chr),
                                           to_scalar(exon[0]),
                                           to_scalar(exon[1]),
                                           strand=to_scalar(entry_strand))
            exon_pos_strings.append("%s-%s" % (str(exon[0]), str(exon[1])))

            # We get the reference sequence
            ref_seq = self.fasta_extractor.extract(interval)

            # We get the variants, insert them and also save them as metadata
            variants = list(self.vcf.fetch_variants(interval))
            if len(variants) == 0:
                ref_exons.append(ref_seq)
                var_exons.append(ref_seq)
            else:
                var_seq = self.vcf_extractor.extract(interval,
                                                     variants=variants,
                                                     anchor=0,
                                                     fixed_len=False)
                var_string = ";".join([str(var) for var in variants])

                ref_exons.append(ref_seq)
                var_exons.append(var_seq)
                exon_var_strings.append(var_string)

        # Combine: exons were stored in bed order; reverse for minus strand so
        # the concatenation reads 5'->3'.
        if entry_strand == "-":
            ref_exons.reverse()
            var_exons.reverse()
        ref_seq = "".join(ref_exons)
        var_seq = "".join(var_exons)
        pos_string = ";".join(exon_pos_strings)
        var_string = ";".join(exon_var_strings)

        return {
            "inputs": {
                "ref_seq": ref_seq,
                "alt_seq": var_seq,
            },
            "metadata": {
                "id": entry_id,
                "chr": entry_chr,
                "exon_positions": pos_string,
                "strand": entry_strand,
                "variants": var_string
            }
        }
Exemple #20
0
    def __getitem__(self, idx):
        """Return reference and variant-injected sequences (plus metadata)
        for the idx-th aggregated transcript entry."""
        # Lazily create the strand-aware fasta extractor, VCF reader and
        # variant sequence extractor on first access.
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaStringExtractor(
                self.fasta_file, use_strand=True, force_upper=self.force_upper)
        if self.vcf is None:
            self.vcf = MultiSampleVCF(self.vcf_file)
        if self.vcf_extractor is None:
            self.vcf_extractor = VariantSeqExtractor(self.fasta_file)

        # One row per transcript: id, chromosome, list of (start, stop)
        # exon tuples, and strand.
        entry = self.bed.iloc[idx]
        entry_id = entry["id"]
        entry_chr = entry["chr"]
        entry_pos = entry["pos"]
        entry_strand = entry["strand"]

        ref_exons = []
        var_exons = []
        exon_pos_strings = []
        exon_var_strings = []
        for exon in entry_pos:
            # We get the interval
            interval = pybedtools.Interval(to_scalar(entry_chr),
                                           to_scalar(exon[0]),
                                           to_scalar(exon[1]),
                                           strand=to_scalar(entry_strand))
            exon_pos_strings.append("%s-%s" % (str(exon[0]), str(exon[1])))

            # We get the reference sequence
            ref_seq = self.fasta_extractor.extract(interval)

            # We get the variants, insert them and also save them as metadata
            variants = list(self.vcf.fetch_variants(interval))
            if len(variants) == 0:
                # No variants in this exon: alt equals ref.
                ref_exons.append(ref_seq)
                var_exons.append(ref_seq)
            else:
                var_seq = self.vcf_extractor.extract(interval,
                                                     variants=variants,
                                                     anchor=0,
                                                     fixed_len=False)
                var_string = ";".join([str(var) for var in variants])

                ref_exons.append(ref_seq)
                var_exons.append(var_seq)
                exon_var_strings.append(var_string)

        # Combine: reverse exon order on the minus strand so the joined
        # sequence reads 5'->3'.
        if entry_strand == "-":
            ref_exons.reverse()
            var_exons.reverse()
        ref_seq = "".join(ref_exons)
        var_seq = "".join(var_exons)
        pos_string = ";".join(exon_pos_strings)
        var_string = ";".join(exon_var_strings)

        return {
            "inputs": {
                "ref_seq": ref_seq,
                "alt_seq": var_seq,
            },
            "metadata": {
                "id": entry_id,
                "chr": entry_chr,
                "exon_positions": pos_string,
                "strand": entry_strand,
                "variants": var_string
            }
        }
class VariantCombinator:
    """Enumerates all possible variants (SNVs, insertions and/or
    deletions) over the regions defined in a bed file.

    Args:
      fasta_file: path to the reference genome fasta file.
      bed_file: path to a bed file with the regions of interest.
      variant_type: which variants to generate; one of
        {'all', 'snv', 'in', 'del'}.
      alphabet: key into the module-level `alphabets` mapping, e.g. 'DNA'.
    """

    def __init__(self,
                 fasta_file: str,
                 bed_file: str = None,
                 variant_type='snv',
                 alphabet='DNA'):
        if variant_type not in {'all', 'snv', 'in', 'del'}:
            raise ValueError("variant_type should be one of "
                             "{'all', 'snv', 'in', 'del'}")

        self.bed_file = bed_file
        # force_upper so ref/alt comparisons below are case-insensitive.
        # (Removed the dead `self.fasta = fasta_file` assignment that was
        # immediately overwritten by the extractor.)
        self.fasta = FastaStringExtractor(fasta_file, force_upper=True)
        self.variant_type = variant_type
        self.alphabet = alphabets[alphabet]

    def combination_variants_snv(self,
                                 interval: Interval) -> Iterable[Variant]:
        """Yields all possible single-nucleotide variants in the region.

          interval: interval of variants
        """
        seq = self.fasta.extract(interval)
        for pos, ref in zip(range(interval.start, interval.end), seq):
            pos = pos + 1  # 0 to 1 base
            for alt in self.alphabet:
                if ref != alt:
                    yield Variant(interval.chrom, pos, ref, alt)

    def combination_variants_insertion(self,
                                       interval,
                                       length=2) -> Iterable[Variant]:
        """Yields all possible insertions of size 2..`length` in the region.

          interval: interval of variants
          length: insertions up to length
        """
        if length < 2:
            raise ValueError('length argument should be larger than 1')

        seq = self.fasta.extract(interval)
        for pos, ref in zip(range(interval.start, interval.end), seq):
            pos = pos + 1  # 0 to 1 base
            for ins_len in range(2, length + 1):
                for alt in product(self.alphabet, repeat=ins_len):
                    yield Variant(interval.chrom, pos, ref, ''.join(alt))

    def combination_variants_deletion(self,
                                      interval,
                                      length=1) -> Iterable[Variant]:
        """Yields all possible deletions of size 1..`length` in the region.

          interval: interval of variants
          length: deletions up to length
        """
        # BUGFIX: the old check `length < 1 and length <= interval.width`
        # could never reject a too-large deletion length; the conditions
        # must be combined with `or` and the width test inverted.
        if length < 1 or length > interval.width:
            raise ValueError('length argument should be larger than 0'
                             ' and smaller than interval width')

        seq = self.fasta.extract(interval)
        for i, pos in enumerate(range(interval.start, interval.end)):
            pos = pos + 1  # 0 to 1 base
            for del_len in range(1, length + 1):
                if i + del_len <= len(seq):  # do not run past the region
                    yield Variant(interval.chrom, pos, seq[i:i + del_len], '')

    def combination_variants(self,
                             interval,
                             variant_type='snv',
                             in_length=2,
                             del_length=2) -> Iterable[Variant]:
        """Dispatches to the snv/insertion/deletion generators."""
        if variant_type in {'snv', 'all'}:
            yield from self.combination_variants_snv(interval)
        if variant_type in {'indel', 'in', 'all'}:
            yield from self.combination_variants_insertion(interval,
                                                           length=in_length)
        if variant_type in {'indel', 'del', 'all'}:
            yield from self.combination_variants_deletion(interval,
                                                          length=del_length)

    def __iter__(self) -> Iterable[Variant]:
        import pyranges as pr

        # Merge overlapping bed entries so no variant is emitted twice.
        gr = pr.read_bed(self.bed_file)
        gr = gr.merge(strand=False).sort()

        for interval in pyranges_to_intervals(gr):
            yield from self.combination_variants(interval, self.variant_type)

    def to_vcf(self, path):
        """Writes all generated variants to a minimal VCF file at `path`."""
        from cyvcf2 import Writer
        header = '''##fileformat=VCFv4.2
#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO
'''
        writer = Writer.from_string(path, header)

        for v in self:
            variant = writer.variant_from_string('\t'.join(
                [v.chrom,
                 str(v.pos), '.', v.ref, v.alt, '.', '.', '.']))
            writer.write_record(variant)
        # BUGFIX: close the writer so records are flushed to disk.
        writer.close()
Exemple #22
0
class AnchoredGTFDl(Dataset):
    """
    info:
        doc: >
            Dataloader for a combination of fasta and gtf files. The dataloader extracts fixed length regions
            around anchor points. Anchor points are extracted from the gtf based on the anchor parameter.
            The sequences corresponding to the region are then extracted from the fasta file and optionally
            transformed using a function given by the transform parameter.
    args:
        gtf_file:
            doc: Path to a gtf file (str)
            example:
                url: https://zenodo.org/record/1466102/files/example_files-gencode.v24.annotation_chr22.gtf
                md5: c0d1bf7738f6a307b425e4890621e7d9
        fasta_file:
            doc: Reference genome FASTA file path (str)
            example:
                url: https://zenodo.org/record/1466102/files/example_files-hg38_chr22.fa
                md5: b0f5cdd4f75186f8a4d2e23378c57b5b
        num_upstream:
            doc: Number of nt by which interval is extended upstream of the anchor point
        num_downstream:
            doc: Number of nt by which interval is extended downstream of the anchor point
        gtf_filter:
            doc: >
                Allows to filter the gtf before extracting the anchor points. Can be str, callable
                or None. If str, it is interpreted as argument to pandas .query(). If callable,
                it is interpreted as function that filters a pandas dataframe and returns the
                filtered df.
        anchor:
            doc: >
                Defines the anchor points. Can be str or callable. If it is a callable, it is
                treated as function that takes a pandas dataframe and returns a modified version
                of the dataframe where each row represents one anchor point, the position of
                which is stored in the column called anchor_pos. If it is a string, a predefined function
                is loaded. Currently available are tss (anchor is the start of a gene), start_codon
                (anchor is the start of the start_codon), stop_codon (anchor is the position right after
                the stop_codon), polya (anchor is the position right after the end of a gene).
        transform:
            doc: Callable (or None) to transform the extracted sequence (e.g. one-hot)
        interval_attrs:
            doc: Metadata to extract from the gtf, e.g. ["gene_id", "Strand"]
        use_strand:
            doc: True or False
    output_schema:
        inputs:
            name: seq
            shape: (None, 4)
            special_type: DNAStringSeq
            doc: exon sequence with flanking intronic sequence
            associated_metadata: ranges
        metadata:
            gene_id:
                type: str
                doc: gene id
            Strand:
                type: str
                doc: Strand
            ranges:
                type: GenomicRanges
                doc: ranges that the sequences were extracted
    """
    # Maps predefined anchor names to functions that add an `anchor_pos`
    # column to the gtf dataframe.
    _function_mapping = {
        "tss":
        lambda x: AnchoredGTFDl.anchor_to_feature_start(
            x, "gene", use_strand=True),
        "start_codon":
        lambda x: AnchoredGTFDl.anchor_to_feature_start(
            x, "start_codon", use_strand=True),
        "stop_codon":
        lambda x: AnchoredGTFDl.anchor_to_feature_end(
            x, "stop_codon", use_strand=True),
        "polya":
        lambda x: AnchoredGTFDl.anchor_to_feature_end(
            x, "gene", use_strand=True)
    }

    def __init__(self,
                 gtf_file,
                 fasta_file,
                 num_upstream,
                 num_downstream,
                 gtf_filter='gene_type == "protein_coding"',
                 anchor='tss',
                 transform=one_hot_dna,
                 interval_attrs=("gene_id", "Strand"),
                 use_strand=True):
        # NOTE: interval_attrs defaults to a tuple (was a mutable list)
        # to avoid the shared-mutable-default pitfall; it is only
        # iterated, so callers passing lists are unaffected.

        # Read and filter gtf
        gtf = pr.read_gtf(gtf_file).df
        if gtf_filter:
            if isinstance(gtf_filter, str):
                gtf = gtf.query(gtf_filter)
            else:
                gtf = gtf_filter(gtf)
        # Extract anchor
        if isinstance(anchor, str):
            anchor = anchor.lower()
            if anchor in self._function_mapping:
                anchor = self._function_mapping[anchor]
            else:
                # ValueError is more precise than the old bare Exception
                # and is still caught by any `except Exception` caller.
                raise ValueError("No valid anchorpoint was chosen")
        self._gtf_anchor = anchor(gtf)

        # Other parameters
        self._use_strand = use_strand
        self._fa = FastaStringExtractor(fasta_file,
                                        use_strand=self._use_strand)
        self._transform = transform
        if self._transform is None:
            self._transform = lambda x: x  # identity when no transform given
        self._num_upstream = num_upstream
        self._num_downstream = num_downstream
        self._interval_attrs = interval_attrs

    def _create_anchored_interval(self, row, num_upstream, num_downstream):
        """Build the genomic interval around `row.anchor_pos`.

        On the negative strand (when strand is honored) upstream and
        downstream are swapped so they follow the direction of
        transcription.
        """
        if self._use_strand and row.Strand == "-":
            # negative strand
            start = row.anchor_pos - num_downstream
            end = row.anchor_pos + num_upstream
        else:
            # positive strand
            start = row.anchor_pos - num_upstream
            end = row.anchor_pos + num_downstream

        return Interval(row.Chromosome, start, end, strand=row.Strand)

    def __len__(self):
        return len(self._gtf_anchor)

    def __getitem__(self, idx):
        """Return the transformed sequence and metadata for anchor `idx`."""
        row = self._gtf_anchor.iloc[idx]
        interval = self._create_anchored_interval(
            row,
            num_upstream=self._num_upstream,
            num_downstream=self._num_downstream)
        sequence = self._fa.extract(interval)
        sequence = self._transform(sequence)
        metadata_dict = {k: row.get(k, '') for k in self._interval_attrs}
        metadata_dict["ranges"] = GenomicRanges(interval.chrom, interval.start,
                                                interval.stop, str(idx))
        return {"inputs": np.array(sequence), "metadata": metadata_dict}

    @staticmethod
    def anchor_to_feature_start(gtf, feature, use_strand):
        """Anchor each `feature` row at its start (End for '-' strand)."""
        # .copy() prevents pandas SettingWithCopyWarning when assigning
        # the new column below.
        gtf = gtf.query('Feature == @feature').copy()
        if use_strand:
            gtf["anchor_pos"] = ((gtf.Start * (gtf.Strand == "+")) +
                                 (gtf.End * (gtf.Strand == "-")))
        else:
            gtf["anchor_pos"] = gtf.Start
        return gtf

    @staticmethod
    def anchor_to_feature_end(gtf, feature, use_strand):
        """Anchor each `feature` row at its end (Start for '-' strand)."""
        gtf = gtf.query('Feature == @feature').copy()
        if use_strand:
            gtf["anchor_pos"] = ((gtf.End * (gtf.Strand == "+")) +
                                 (gtf.Start * (gtf.Strand == "-")))
        else:
            gtf["anchor_pos"] = gtf.End
        return gtf
Exemple #23
0
 def __init__(self, fasta_file: str):
     """Initialize the extractor with a strand-aware fasta backend.

     Args:
       fasta_file: path to the fasta file (can be gzipped)
     """
     # use_strand=True: intervals on the '-' strand yield the
     # reverse-complemented sequence.
     self.fasta = FastaStringExtractor(fasta_file, use_strand=True)
Exemple #24
0
class VariantSeqExtractor(BaseExtractor):

    def __init__(self, fasta_file: str = None, ref_seq_extractor: BaseExtractor = None):
        """
        Sequence extractor which allows to obtain the alternative sequence,
        given some interval and variants inside this interval.

        Exactly one of `fasta_file` and `ref_seq_extractor` must be given.

        Args:
          fasta_file: path to the fasta file (can be gzipped)
          ref_seq_extractor: extractor returning the reference sequence given some interval

        Raises:
          ValueError: if both or neither of the two arguments are given.
        """
        if fasta_file is not None:
            if ref_seq_extractor is not None:
                # BUGFIX: the old message for this branch wrongly claimed
                # that neither argument was specified.
                raise ValueError("only one of fasta_file or ref_seq_extractor may be specified")
            self._ref_seq_extractor = FastaStringExtractor(fasta_file, use_strand=True)
        else:
            if ref_seq_extractor is None:
                raise ValueError("either fasta_file or ref_seq_extractor have to be specified")
            self._ref_seq_extractor = ref_seq_extractor

    @property
    def ref_seq_extractor(self):
        """BaseExtractor: the extractor used to fetch reference sequences."""
        return self._ref_seq_extractor

    def extract(self, interval, variants, anchor, fixed_len=True, **kwargs):
        """Extract the alternative sequence for `interval` with `variants` applied.

        Args:
          interval: pybedtools.Interval Region of interest from
            which to query the sequence. 0-based
          variants: List[cyvcf2.Variant]: variants overlapping the `interval`.
            can also be indels. 1-based
          anchor: absolute position w.r.t. the interval start. (0-based).
            E.g. for an interval of `chr1:10-20` the anchor of 10 denotes
            the point chr1:10 in the 0-based coordinate system. The anchor
            stays fixed in the output; indels grow/shrink the sequence
            away from it.
          fixed_len: if True, the return sequence will have the same length
            as the `interval` (e.g. `interval.end - interval.start`)

        Returns:
          A single sequence (`str`) with all the variants applied.
        """
        # Preprocessing: clip the anchor into the interval so the
        # upstream/downstream split below is always well defined.
        anchor = max(min(anchor, interval.end), interval.start)
        variant_pairs = self._variant_to_sequence(variants)

        # 1. Split variants overlapping with anchor
        # and interval start end if not fixed_len
        variant_pairs = self._split_overlapping(variant_pairs, anchor)

        if not fixed_len:
            variant_pairs = self._split_overlapping(
                variant_pairs, interval.start, which='right')
            variant_pairs = self._split_overlapping(
                variant_pairs, interval.end, which='left')

        variant_pairs = list(variant_pairs)

        # 2. split the variants into upstream (at/after the anchor) and
        # downstream (before the anchor) and sort each group by position
        # (downstream descending, so iteration moves away from the anchor)
        upstream_variants = sorted(
            filter(lambda x: x[0].start >= anchor, variant_pairs),
            key=lambda x: x[0].start
        )

        downstream_variants = sorted(
            filter(lambda x: x[0].start < anchor, variant_pairs),
            key=lambda x: x[0].start,
            reverse=True
        )

        # 3. Extend start and end position for deletions so the final
        # sequence can still be cut to the fixed interval length
        if fixed_len:
            istart, iend = self._updated_interval(
                interval, upstream_variants, downstream_variants)
        else:
            istart, iend = interval.start, interval.end

        # 4. Iterate from the anchor point outwards. At each
        # register the interval from which to take the reference sequence
        # as well as the interval for the variant
        down_sb = self._downstream_builder(
            downstream_variants, interval, anchor, istart)

        up_sb = self._upstream_builder(
            upstream_variants, interval, anchor, iend)

        # 5. fetch the sequence and restore intervals in builder
        # (fills the reference-gap intervals with actual sequence)
        seq = self._fetch(interval, istart, iend)
        up_sb.restore(seq)
        down_sb.restore(seq)

        # 6. Concate sequences from the upstream and downstream splits. Concat
        # upstream and downstream sequence. Cut to fix the length.
        down_str = down_sb.concat()
        up_str = up_sb.concat()

        if fixed_len:
            down_str, up_str = self._cut_to_fix_len(
                down_str, up_str, interval, anchor)

        seq = down_str + up_str

        # Reverse-complement when the interval is on the negative strand.
        if interval.strand == '-':
            seq = complement(seq)[::-1]

        return seq

    @staticmethod
    def _variant_to_sequence(variants):
        """Yield (ref, alt) `pyfaidx.Sequence` pairs for each variant.

        Both sequences are anchored at the variant start; their ends
        differ whenever ref and alt have different lengths.
        """
        for variant in variants:
            chrom, begin = variant.chrom, variant.start
            yield (
                Sequence(name=chrom, seq=variant.ref,
                         start=begin, end=begin + len(variant.ref)),
                Sequence(name=chrom, seq=variant.alt,
                         start=begin, end=begin + len(variant.alt)),
            )

    @staticmethod
    def _split_overlapping(variant_pairs, anchor, which='both'):
        """Split every (ref, alt) pair that strictly spans `anchor`.

        `which` selects the emitted pieces ('left', 'right' or 'both');
        pairs not overlapping the anchor pass through unchanged.
        """
        for ref, alt in variant_pairs:
            if not (ref.start < anchor < ref.end):
                yield ref, alt
                continue
            cut = anchor - ref.start
            if which in ('left', 'both'):
                yield ref[:cut], alt[:cut]
            if which in ('right', 'both'):
                yield ref[cut:], alt[cut:]

    @staticmethod
    def _updated_interval(interval, up_variants, down_variants):
        """Widen the interval to compensate for deletions.

        Every deletion upstream of the anchor pushes the end further
        right; every deletion downstream pushes the start further left,
        so the extracted sequence can keep its fixed length.
        """
        grow_end = sum(max(0, len(ref) - len(alt))
                       for ref, alt in up_variants)
        grow_start = sum(max(0, len(ref) - len(alt))
                         for ref, alt in down_variants)
        return interval.start - grow_start, interval.end + grow_end

    @staticmethod
    def _downstream_builder(down_variants, interval, anchor, istart):
        """Build the sequence pieces from the anchor leftwards to `istart`.

        `down_variants` must be sorted by start, descending, as produced
        in `extract`; pieces are appended right-to-left and reversed at
        the end so the builder reads left-to-right.
        """
        down_sb = IntervalSeqBuilder()

        # Alternate reference gaps and alt sequences while walking from
        # the anchor towards istart.
        prev = anchor
        for ref, alt in down_variants:
            if ref.end <= istart:
                # variant lies completely before the region of interest
                break
            down_sb.append(Interval(interval.chrom, ref.end, prev))
            down_sb.append(alt)
            prev = ref.start
        down_sb.append(Interval(interval.chrom, istart, prev))
        down_sb.reverse()

        return down_sb

    @staticmethod
    def _upstream_builder(up_variants, interval, anchor, iend):
        """Build the sequence pieces from the anchor rightwards to `iend`.

        `up_variants` must be sorted by start, ascending, as produced in
        `extract`.
        """
        up_sb = IntervalSeqBuilder()

        # Alternate reference gaps and alt sequences while walking from
        # the anchor towards iend.
        prev = anchor
        for ref, alt in up_variants:
            if ref.start >= iend:
                # variant lies completely past the region of interest
                break
            up_sb.append(Interval(interval.chrom, prev, ref.start))
            up_sb.append(alt)
            prev = ref.end
        up_sb.append(Interval(interval.chrom, prev, iend))

        return up_sb

    def _fetch(self, interval, istart, iend):
        """Fetch the (possibly widened) reference slice as a
        coordinate-aware `pyfaidx.Sequence`."""
        region = Interval(interval.chrom, istart, iend)
        raw = self._ref_seq_extractor.extract(region)
        return Sequence(name=interval.chrom, seq=raw, start=istart, end=iend)

    @staticmethod
    def _cut_to_fix_len(down_str, up_str, interval, anchor):
        """Trim both halves so their combined length equals the interval's.

        The downstream half keeps its tail (closest to the anchor); the
        upstream half keeps its head.
        """
        keep_down = anchor - interval.start
        keep_up = interval.end - anchor
        trimmed_down = down_str[-keep_down:] if keep_down else ''
        trimmed_up = up_str[:keep_up] if keep_up else ''
        return trimmed_down, trimmed_up