def __getitem__(self, idx):
    """Return {"inputs", "targets", "metadata"} for the idx-th bed interval."""
    # The fasta extractor is created lazily on first access.
    if self.fasta_extractors is None:
        self.fasta_extractors = FastaStringExtractor(
            self.fasta_file,
            use_strand=False,  # self.use_strand,
            force_upper=self.force_upper)

    interval, labels = self.bed[idx]

    if self.auto_resize_len:
        # Center-resize the interval to the requested fixed length.
        interval = resize_interval(interval, self.auto_resize_len,
                                   anchor='center')

    # QUESTION: @kromme - why to we need max_seq_len?
    # if self.max_seq_len is not None:
    #     assert interval.stop - interval.start <= self.max_seq_len

    # Extract the raw sequence string for the (possibly resized) interval.
    extracted = self.fasta_extractors.extract(interval)
    ranges = GenomicRanges(interval.chrom, interval.start,
                           interval.stop, str(idx))

    return {"inputs": np.array(extracted),
            "targets": labels,
            "metadata": {"ranges": ranges}}
def __init__(self, fasta_file: str, bed_file: str = None,
             variant_type='snv', alphabet='DNA'):
    """Validate the variant type and set up the fasta extractor.

    Args:
        fasta_file: reference genome fasta path (sequences are upper-cased)
        bed_file: optional bed file with the regions to enumerate
        variant_type: one of {'all', 'snv', 'in', 'del'}
        alphabet: key into the module-level `alphabets` mapping (e.g. 'DNA')

    Raises:
        ValueError: if variant_type is not one of the recognized types
    """
    if variant_type not in {'all', 'snv', 'in', 'del'}:
        raise ValueError("variant_type should be one of "
                         "{'all', 'snv', 'in', 'del'}")
    self.bed_file = bed_file
    # Fix: `self.fasta` was first assigned the raw path and then immediately
    # overwritten with the extractor; keep only the extractor.
    self.fasta = FastaStringExtractor(fasta_file, force_upper=True)
    self.variant_type = variant_type
    self.alphabet = alphabets[alphabet]
class MyDataset(Dataset):
    """Example re-implementation of kipoiseq.dataloaders.SeqIntervalDl

    Args:
        intervals_file: bed3 file containing intervals
        fasta_file: file path; Genome sequence
        ignore_targets: if True, don't parse target columns from the bed
    """

    def __init__(self, intervals_file, fasta_file, ignore_targets=True):
        self.bt = BedDataset(intervals_file,
                             bed_columns=3,
                             ignore_targets=ignore_targets)
        self.fasta_file = fasta_file
        self.fasta_extractor = None  # created lazily in __getitem__
        self.transform = OneHot()  # one-hot encode DNA sequence

    def __len__(self):
        return len(self.bt)

    def __getitem__(self, idx):
        # Fix: the original re-created FastaStringExtractor (re-opening the
        # fasta file) on every item access, defeating the lazy-cache attribute
        # initialised in __init__. Create it once and reuse it.
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaStringExtractor(self.fasta_file)
        # get the intervals
        interval, targets = self.bt[idx]
        # resize to 500bp
        interval = resize_interval(interval, 500, anchor='center')
        # extract the sequence
        seq = self.fasta_extractor.extract(interval)
        # one-hot encode the sequence
        seq_onehot = self.transform(seq)
        ranges = GenomicRanges.from_interval(interval)
        return {"inputs": [seq_onehot], "metadata": [ranges]}
def __init__(self, fasta_file: str = None, ref_seq_extractor: BaseExtractor = None):
    """
    Sequence extractor which allows to obtain the alternative sequence,
    given some interval and variants inside this interval.

    Args:
        fasta_file: path to the fasta file (can be gzipped)
        ref_seq_extractor: extractor returning the reference sequence given
            some interval
    """
    # Exactly one of the two reference sources must be provided.
    if (fasta_file is None) == (ref_seq_extractor is None):
        raise ValueError("either fasta_file or ref_seq_extractor have to be specified")
    if fasta_file is not None:
        self._ref_seq_extractor = FastaStringExtractor(fasta_file, use_strand=True)
    else:
        self._ref_seq_extractor = ref_seq_extractor
def __getitem__(self, idx):
    """Return a one-hot encoded 500bp sequence centered on interval idx."""
    # Fix: the extractor was re-created (fasta re-opened) on every call.
    # Cache it on the instance instead; getattr keeps this safe even if
    # __init__ did not pre-set the attribute.
    if getattr(self, "fasta_extractor", None) is None:
        self.fasta_extractor = FastaStringExtractor(self.fasta_file)
    # get the intervals
    interval, targets = self.bt[idx]
    # resize to 500bp
    interval = resize_interval(interval, 500, anchor='center')
    # extract the sequence
    seq = self.fasta_extractor.extract(interval)
    # one-hot encode the sequence
    seq_onehot = self.transform(seq)
    ranges = GenomicRanges.from_interval(interval)
    return {"inputs": [seq_onehot], "metadata": [ranges]}
def __getitem__(self, idx):
    """Return sequence + distance-to-polyA inputs (and optional targets)
    for the idx-th interval, with range metadata."""
    # Lazily build the extractors once per dataset instance.
    if self.input_data_extractors is None:
        self.input_data_extractors = {
            "seq": FastaStringExtractor(self.fasta_file),
            "dist_polya_st": DistToClosestLandmarkExtractor(gtf_file=self.gtf,
                                                            landmarks=["polya"])
        }

    interval = self.bt[idx]

    out = {}
    out['inputs'] = {
        # Fix: use the cached extractors built above. The original ignored
        # them and instantiated a fresh FastaStringExtractor and
        # DistToClosestLandmarkExtractor for every single item, re-parsing
        # the fasta/gtf each time.
        "seq": one_hot_dna(
            self.input_data_extractors["seq"].extract(interval)),
        "dist_polya_st": np.squeeze(
            self.input_data_extractors["dist_polya_st"]([interval]), axis=0)
    }

    # use trained spline transformation to transform it
    out["inputs"]["dist_polya_st"] = np.squeeze(
        self.transformer.transform(out["inputs"]["dist_polya_st"][np.newaxis],
                                   warn=False),
        axis=0)

    if self.target_dataset is not None:
        out["targets"] = np.array([self.target_dataset[idx]])

    # get metadata
    out['metadata'] = {}
    out['metadata']['ranges'] = {}
    out['metadata']['ranges']['chr'] = interval.chrom
    out['metadata']['ranges']['start'] = interval.start
    out['metadata']['ranges']['end'] = interval.stop
    out['metadata']['ranges']['id'] = interval.name
    out['metadata']['ranges']['strand'] = interval.strand

    return out
def __getitem__(self, idx):
    """Extract and one-hot encode the sequence upstream of start codon idx."""
    if self.fasta_extractor is None:
        # Created on first access rather than in __init__.
        self.fasta_extractor = FastaStringExtractor(self.fasta_file,
                                                    use_strand=True,
                                                    force_upper=True)

    feature = self.start_codons[idx]
    interval = get_upstream(feature, self.n_upstream)
    encoded = self.input_transform(self.fasta_extractor.extract(interval))

    attrs = feature.attributes

    def first_attr(key):
        # First value of the GTF attribute, or "" when absent.
        return attrs.get(key, [""])[0]

    return {
        "inputs": encoded,
        "metadata": {
            "ranges": GenomicRanges.from_interval(interval),
            "gene_id": first_attr('gene_id'),
            "transcript_id": first_attr('transcript_id'),
            "gene_biotype": first_attr('gene_biotype'),
        }
    }
def test_fastareader(use_strand, force_upper):
    """Compare FastaStringExtractor output against a manual read of the fasta."""
    fp = "tests/data/sample.fasta"
    # The second line of the file holds the chr1 sequence.
    with open(fp, "r") as ifh:
        fasta_str = ifh.readlines()[1].lstrip()

    fr = FastaStringExtractor(fp, use_strand, force_upper)
    for interval in (Interval("chr1", 0, 2, strand="-"),
                     Interval("chr1", 3, 4)):
        expected = fasta_str[interval.start:interval.end]
        if use_strand and interval.strand == "-":
            # Reverse complement the expected slice.
            expected = "".join(comp[base] for base in reversed(expected))
        if force_upper:
            expected = expected.upper()
        assert fr.extract(interval) == expected
def __init__(
        self,
        fasta_file,
        gtf_file,
):
    """Build the dataloader from a fasta reference and a GTF annotation."""
    # Regions of interest are derived from the parsed genome annotation.
    annotation_df = pr.read_gtf(gtf_file, as_df=True)
    regions = pr.PyRanges(get_roi_from_genome_annotation(annotation_df))
    super().__init__(
        regions_of_interest=regions,
        reference_sequence=FastaStringExtractor(fasta_file),
    )
class SeqDataset(Dataset):
    """
    Args:
        intervals_file: bed3 file containing intervals
        fasta_file: file path; Genome sequence
        target_file: file path; path to the targets in the csv format
    """

    def __init__(self, intervals_file, fasta_file, target_file=None,
                 use_linecache=False):
        # intervals
        self.bt = (BedToolLinecache(intervals_file) if use_linecache
                   else BedTool(intervals_file))
        self.fasta_file = fasta_file
        self.fasta_extractor = None  # created lazily on first __getitem__
        # Targets
        self.targets = (pd.read_csv(target_file)
                        if target_file is not None else None)

    def __len__(self):
        return len(self.bt)

    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaStringExtractor(self.fasta_file)

        interval = self.bt[idx]
        # Intervals need to be 1000bp wide
        assert interval.stop - interval.start == 1000

        y = self.targets.iloc[idx].values if self.targets is not None else {}

        # Run the fasta extractor
        # TODO: Remove additional dtype after kipoiseq gets a new release
        seq = one_hot_dna(self.fasta_extractor.extract(interval),
                          dtype=np.float32)
        return {
            "inputs": seq,
            "targets": y,
            "metadata": {
                "ranges": GenomicRanges.from_interval(interval)
            }
        }
def __getitem__(self, idx):
    # Lazily create the fasta extractor on first access.
    if self.fasta_extractor is None:
        self.fasta_extractor = FastaStringExtractor(self.fasta_file)
    interval = self.bt[idx]

    # Intervals need to be 1000bp wide
    assert interval.stop - interval.start == 1000

    if self.targets is not None:
        y = self.targets.iloc[idx].values
    else:
        # NOTE(review): empty dict (not an empty array) when no targets are
        # configured — confirm downstream consumers handle this.
        y = {}

    # Run the fasta extractor
    seq = one_hot_dna(self.fasta_extractor.extract(interval), dtype=np.float32)
    # TODO: Remove additional dtype after kipoiseq gets a new release
    return {
        "inputs": seq,
        "targets": y,
        "metadata": {
            "ranges": GenomicRanges.from_interval(interval)
        }
    }
def __init__(self, gtf_file, fasta_file, num_upstream, num_downstream,
             gtf_filter='gene_type == "protein_coding"', anchor='tss',
             transform=one_hot_dna, interval_attrs=["gene_id", "Strand"],
             use_strand=True):
    """Read and filter the GTF, compute anchor positions, and set up the
    fasta extractor and sequence transform.

    Args:
        gtf_file: path to a gtf file
        fasta_file: reference genome fasta path
        num_upstream: nt extracted upstream of the anchor point
        num_downstream: nt extracted downstream of the anchor point
        gtf_filter: str passed to DataFrame.query, a callable applied to the
            gtf DataFrame, or falsy for no filtering
        anchor: predefined anchor name (e.g. 'tss') looked up in
            self._function_mapping, or a callable adding an anchor_pos column
        transform: callable applied to each extracted sequence (None = identity)
        interval_attrs: gtf columns returned as metadata
            # NOTE(review): mutable default argument — shared across instances
        use_strand: reverse complement sequences on the minus strand
    """
    # Read and filter gtf
    gtf = pr.read_gtf(gtf_file).df
    if gtf_filter:
        if isinstance(gtf_filter, str):
            gtf = gtf.query(gtf_filter)
        else:
            gtf = gtf_filter(gtf)
    # Extract anchor
    if isinstance(anchor, str):
        anchor = anchor.lower()
        if anchor in self._function_mapping:
            anchor = self._function_mapping[anchor]
        else:
            raise Exception("No valid anchorpoint was chosen")
    self._gtf_anchor = anchor(gtf)

    # Other parameters
    self._use_strand = use_strand
    self._fa = FastaStringExtractor(fasta_file, use_strand=self._use_strand)
    self._transform = transform
    if self._transform is None:
        self._transform = lambda x: x  # identity when no transform given
    self._num_upstream = num_upstream
    self._num_downstream = num_downstream
    self._interval_attrs = interval_attrs
def __init__(
        self,
        fasta_file,
        gtf_file,
        vcf_file,
        vcf_file_tbi=None,
        vcf_lazy=True,
):
    """Build the dataloader from a fasta reference, a GTF annotation and a
    (multi-sample) VCF of variants."""
    annotation_df = pr.read_gtf(gtf_file, as_df=True)
    regions = pr.PyRanges(get_roi_from_genome_annotation(annotation_df))
    # Imported here to keep the dependency local to this dataloader.
    from kipoiseq.extractors import MultiSampleVCF
    super().__init__(
        regions_of_interest=regions,
        reference_sequence=FastaStringExtractor(fasta_file),
        variants=MultiSampleVCF(vcf_file, lazy=vcf_lazy),
    )
def __init__(self, gtf_file, fasta_file, vcf_file, feature_type,
             infer_from_cds=False, on_error_warn=True, vcf_file_tbi=None,
             **kwargs):
    """Wire up UTR interval fetching, variant matching and the
    variant-sequence generator.

    Args:
        gtf_file: genome annotation gtf path
        fasta_file: reference genome fasta path
        vcf_file: variants vcf path (used for both matching and extraction)
        feature_type: UTR feature type passed to UTRFetcher
        infer_from_cds: infer UTR intervals from CDS annotation
        on_error_warn: warn instead of raising on fetch errors
        vcf_file_tbi: accepted but not used here — presumably for interface
            compatibility with callers; TODO confirm
        **kwargs: ignored
    """
    self.gtf_file = gtf_file
    self.fasta_file = fasta_file
    self.vcf_file = vcf_file
    self.feature_type = feature_type
    self.infer_from_cds = infer_from_cds
    self.on_error_warn = on_error_warn

    self.interval_fetcher = UTRFetcher(gtf_file=gtf_file,
                                       feature_type=feature_type,
                                       infer_from_cds=infer_from_cds,
                                       on_error_warn=on_error_warn)
    self.multi_sample_VCF = MultiSampleVCF(vcf_file)
    self.reference_seq_extractor = FastaStringExtractor(fasta_file)

    df = self.interval_fetcher.df
    import pyranges
    # match variant with transcript_id
    self.variant_matcher = SingleVariantMatcher(
        self.vcf_file,
        pranges=pyranges.PyRanges(df.reset_index()))

    self.extractor = GenericSingleVariantMultiIntervalVCFSeqExtractor(
        interval_fetcher=self.interval_fetcher,
        reference_seq_extractor=self.reference_seq_extractor,
        variant_matcher=self.variant_matcher,
        multi_sample_VCF=self.multi_sample_VCF,
    )
    # # only needed metadata
    # self.metadatas = (
    #     (
    #         df.loc[~df.index.duplicated(keep='first')]
    #     ).drop(columns=['Start', 'End'])
    # )
    # generator for all sequences with variants
    self.sequences = self._extractor()
class FixedSeq5UtrDl(Dataset):
    """Dataloader yielding one-hot encoded fixed-length sequences upstream
    of annotated start codons."""

    # Number of nt extracted upstream of each start codon.
    n_upstream = 50

    def __init__(self, gtf_file, fasta_file, disable_infer_transcripts=True,
                 disable_infer_genes=True):
        self.gtf_file = gtf_file
        self.fasta_file = fasta_file
        self.fasta_extractor = None  # created lazily on first access
        self.db = gffutils.create_db(
            gtf_file,
            ":memory:",
            disable_infer_transcripts=disable_infer_transcripts,
            disable_infer_genes=disable_infer_genes)
        self.start_codons = list(self.db.features_of_type("start_codon"))
        self.input_transform = OneHot()

    def __len__(self):
        return len(self.start_codons)

    def __getitem__(self, idx):
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaStringExtractor(self.fasta_file,
                                                        use_strand=True,
                                                        force_upper=True)
        feature = self.start_codons[idx]
        interval = get_upstream(feature, self.n_upstream)
        encoded = self.input_transform(self.fasta_extractor.extract(interval))

        attrs = feature.attributes
        return {
            "inputs": encoded,
            "metadata": {
                "ranges": GenomicRanges.from_interval(interval),
                "gene_id": attrs.get('gene_id', [""])[0],
                "transcript_id": attrs.get('transcript_id', [""])[0],
                "gene_biotype": attrs.get('gene_biotype', [""])[0]
            }
        }
def __init__(self, fasta_file: str = None,
             reference_sequence: BaseExtractor = None, use_strand=True):
    """
    Sequence extractor which allows to obtain the alternative sequence,
    given some interval and variants inside this interval.

    Args:
        fasta_file: path to the fasta file (can be gzipped)
        reference_sequence: extractor returning the reference sequence given
            some interval
        use_strand (bool): if True, the extracted sequence
            is reverse complemented in case interval.strand == "-"
    """
    self._use_strand = use_strand
    # Exactly one of fasta_file / reference_sequence must be given.
    if (fasta_file is None) == (reference_sequence is None):
        raise ValueError(
            "either fasta_file or ref_seq_extractor have to be specified")
    if fasta_file is not None:
        # Strand handling is done by this extractor itself, so the underlying
        # fasta extractor always returns the plus-strand sequence.
        self._ref_seq_extractor = FastaStringExtractor(
            fasta_file, use_strand=False)
    else:
        self._ref_seq_extractor = reference_sequence
def test_extract(variant_seq_extractor):
    """Exercise VariantSeqExtractor.extract across anchors, strands and
    fixed/variable-length modes against known expected sequences."""
    variants = [Variant.from_cyvcf(v) for v in VCF(vcf_file)]

    interval = Interval('chr1', 2, 9)
    seq = variant_seq_extractor.extract(interval, variants, anchor=5)
    assert len(seq) == interval.end - interval.start
    assert seq == 'CGAACGT'

    # Minus strand: reverse complement of the same region.
    interval = Interval('chr1', 2, 9, strand='-')
    seq = variant_seq_extractor.extract(interval, variants, anchor=5)
    assert len(seq) == interval.end - interval.start
    assert seq == 'ACGTTCG'

    interval = Interval('chr1', 4, 14)
    seq = variant_seq_extractor.extract(interval, variants, anchor=7)
    assert len(seq) == interval.end - interval.start
    assert seq == 'AACGTAACGT'

    interval = Interval('chr1', 4, 14)
    seq = variant_seq_extractor.extract(interval, variants, anchor=4)
    assert len(seq) == interval.end - interval.start
    assert seq == 'GAACGTAACG'

    interval = Interval('chr1', 2, 5)
    seq = variant_seq_extractor.extract(interval, variants, anchor=3)
    assert len(seq) == interval.end - interval.start
    assert seq == 'GCG'

    interval = Interval('chr1', 24, 34)
    seq = variant_seq_extractor.extract(interval, variants, anchor=27)
    assert len(seq) == interval.end - interval.start
    assert seq == 'TGATAACGTA'

    interval = Interval('chr1', 25, 35)
    seq = variant_seq_extractor.extract(interval, variants, anchor=34)
    assert len(seq) == interval.end - interval.start
    assert seq == 'TGATAACGTA'

    interval = Interval('chr1', 34, 44)
    seq = variant_seq_extractor.extract(interval, variants, anchor=37)
    assert len(seq) == interval.end - interval.start
    assert seq == 'AACGTAACGT'

    # Anchor far beyond the interval — presumably clamped to the interval
    # boundary; result matches the in-range anchor above.
    interval = Interval('chr1', 34, 44)
    seq = variant_seq_extractor.extract(interval, variants, anchor=100)
    assert len(seq) == interval.end - interval.start
    assert seq == 'AACGTAACGT'

    # Variable-length extraction (fixed_len=False): no length guarantee.
    interval = Interval('chr1', 5, 11, strand='+')
    seq = variant_seq_extractor.extract(interval, variants, anchor=10,
                                        fixed_len=False)
    assert seq == 'ACGTAA'

    interval = Interval('chr1', 0, 3, strand='+')
    seq = variant_seq_extractor.extract(interval, variants, anchor=10,
                                        fixed_len=False)
    assert seq == 'ACG'

    # Extractor constructed from an explicit reference sequence extractor
    # instead of a fasta path gives the same result.
    interval = Interval('chr1', 0, 3, strand='+')
    ref_seq_extractor = FastaStringExtractor(fasta_file, use_strand=True)
    seq = VariantSeqExtractor(reference_sequence=ref_seq_extractor).extract(
        interval, variants, anchor=10, fixed_len=False)
    assert seq == 'ACG'
class StringSeqIntervalDl(Dataset):
    """
    info:
        doc: >
           Dataloader for a combination of fasta and tab-delimited input files such as bed files. The dataloader extracts
           regions from the fasta file as defined in the tab-delimited `intervals_file`. Returned sequences are of the type
           np.array([str]).
    args:
        intervals_file:
            doc: bed3+<columns> file path containing intervals + (optionally) labels
            example:
              url: https://raw.githubusercontent.com/kipoi/kipoiseq/master/tests/data/intervals_51bp.tsv
              md5: a76e47b3df87fd514860cf27fdc10eb4
        fasta_file:
            doc: Reference genome FASTA file path.
            example:
              url: https://raw.githubusercontent.com/kipoi/kipoiseq/master/tests/data/hg38_chr22_32000000_32300000.fa
              md5: 01320157a250a3d2eea63e89ecf79eba
        num_chr_fasta:
            doc: True, the the dataloader will make sure that the chromosomes don't start with chr.
        label_dtype:
            doc: None, datatype of the task labels taken from the intervals_file. Example - str, int, float, np.float32
        auto_resize_len:
            doc: None, required sequence length.
        # max_seq_len:
        #     doc: maximum allowed sequence length
        # use_strand:
        #     doc: reverse-complement fasta sequence if bed file defines negative strand
        force_upper:
            doc: Force uppercase output of sequences
        ignore_targets:
            doc: if True, don't return any target variables
    output_schema:
        inputs:
            name: seq
            shape: ()
            doc: DNA sequence as string
            special_type: DNAStringSeq
            associated_metadata: ranges
        targets:
            shape: (None,)
            doc: (optional) values following the bed-entry - chr  start  end  target1   target2 ....
        metadata:
            ranges:
                type: GenomicRanges
                doc: Ranges describing inputs.seq
    postprocessing:
        variant_effects:
          bed_input:
            - intervals_file
    """

    def __init__(self,
                 intervals_file,
                 fasta_file,
                 num_chr_fasta=False,
                 label_dtype=None,
                 auto_resize_len=None,
                 # max_seq_len=None,
                 # use_strand=False,
                 force_upper=True,
                 ignore_targets=False):
        self.num_chr_fasta = num_chr_fasta
        self.intervals_file = intervals_file
        self.fasta_file = fasta_file
        self.auto_resize_len = auto_resize_len
        # self.use_strand = use_strand
        self.force_upper = force_upper
        # self.max_seq_len = max_seq_len

        # if use_strand:
        #     # require a 6-column bed-file if strand is used
        #     bed_columns = 6
        # else:
        #     bed_columns = 3

        self.bed = BedDataset(self.intervals_file,
                              num_chr=self.num_chr_fasta,
                              bed_columns=3,
                              label_dtype=parse_dtype(label_dtype),
                              ignore_targets=ignore_targets)
        # Created lazily in __getitem__ (first access).
        self.fasta_extractors = None

    def __len__(self):
        return len(self.bed)

    def __getitem__(self, idx):
        # Lazily create the fasta extractor on first access.
        if self.fasta_extractors is None:
            self.fasta_extractors = FastaStringExtractor(
                self.fasta_file, use_strand=False,  # self.use_strand,
                force_upper=self.force_upper)

        interval, labels = self.bed[idx]

        if self.auto_resize_len:
            # automatically resize the sequence to the required length
            interval = resize_interval(interval, self.auto_resize_len,
                                       anchor='center')

        # QUESTION: @kromme - why to we need max_seq_len?
        # if self.max_seq_len is not None:
        #     assert interval.stop - interval.start <= self.max_seq_len

        # Run the fasta extractor and transform if necessary
        seq = self.fasta_extractors.extract(interval)

        return {
            "inputs": np.array(seq),
            "targets": labels,
            "metadata": {
                "ranges": GenomicRanges(interval.chrom, interval.start,
                                        interval.stop, str(idx))
            }
        }

    @classmethod
    def get_output_schema(cls):
        """Return a copy of the class output schema; drop targets when the
        default kwargs say targets are ignored."""
        output_schema = deepcopy(cls.output_schema)
        kwargs = default_kwargs(cls)
        ignore_targets = kwargs['ignore_targets']
        if ignore_targets:
            output_schema.targets = None
        return output_schema
class StrandedSequenceVariantDataloader(Dataset):
    """This Dataloader requires the following input files:
    1. bed3+ where a specific user-specified column (>3, 1-based) of the
    bed denotes the strand and a specific user-specified column (>3,
    1-based) of the bed denotes the transcript id (or some other id that
    explains which exons in the bed belong together to form one sequence).
    All columns of the bed, except the first three, the id and the strand,
    are ignored.
    2. fasta file that provides the reference genome
    3. bgzip compressed (single sample) vcf that provides the variants
    4. A chromosome order file (such as a fai file) that specifies the
    order of chromosomes (must be valid for all files)
    The bed and vcf must both be sorted (by position) and a tabix index
    must be present. (must lie in the same directory and have the same
    name + .tbi)
    The num_chr flag indicates whether chromosomes are listed numerically
    or with a chr prefix. This must be consistent across all input files!
    The dataloader finds all intervals in the bed which contain at least
    one variant in the vcf. It then joins intervals belonging to the same
    transcript, as specified by the id, to a single sequence. For these
    sequences, it extracts the reference sequence from the fasta file,
    injects the applicable variants and reverse complements according to
    the strand information. This means that if a vcf mixes variants from
    more than one patient, the results will not be meaningful. Split the
    vcf by patient and run the predictions seperately in this case!
    Returns the reference sequence and variant sequence as
    np.array([reference_sequence, variant_sequence]).
    Region metadata is additionally provided
    """

    def __init__(self, intervals_file, fasta_file, vcf_file, chr_order_file,
                 vcf_file_tbi=None, strand_column=6, id_column=4,
                 num_chr=True):
        # workaround for test
        if vcf_file_tbi is not None and vcf_file_tbi.endswith("vcf_file_tbi"):
            os.rename(vcf_file_tbi,
                      vcf_file_tbi.replace("vcf_file_tbi", "vcf_file.tbi"))

        self.num_chr_fasta = num_chr
        self.intervals_file = intervals_file
        self.fasta_file = fasta_file
        self.vcf_file = vcf_file
        self.chr_order_file = chr_order_file
        # Convert 1-based user-facing column numbers to 0-based indices.
        self.strand_column = strand_column - 1
        self.id_column = id_column - 1
        self.force_upper = True

        # "Parse" bed file
        self.bed = BedDataset(self.intervals_file,
                              num_chr=self.num_chr_fasta,
                              bed_columns=3,
                              label_dtype=str,
                              ignore_targets=False)

        # Intersect bed and vcf using bedtools
        # bedtools c flag: for each bed interval, counts number of vcf entries it overlaps
        bed_tool = pybedtools.BedTool(self.intervals_file)
        intersect_counts = list(
            bed_tool.intersect(self.vcf_file, c=True, sorted=True,
                               g=self.chr_order_file))
        intersect_counts = np.array(
            [isect.count for isect in intersect_counts])

        # Retain only those transcripts that intersect a variant
        utr5_bed = self.bed.df
        # NOTE(review): id_col is never used afterwards.
        id_col = utr5_bed.iloc[:, self.id_column]
        retain_transcripts = utr5_bed[
            intersect_counts > 0].iloc[:, self.id_column]
        utr5_bed = utr5_bed[utr5_bed.iloc[:, self.id_column].isin(
            retain_transcripts)]

        # Aggregate 5utr positions per transcript
        tuples = list(zip(utr5_bed.iloc[:, 1], utr5_bed.iloc[:, 2]))
        pos = [[x] for x in tuples]
        id_chr_strand = list(
            zip(utr5_bed.iloc[:, self.id_column], utr5_bed.iloc[:, 0],
                utr5_bed.iloc[:, self.strand_column]))
        utr5_bed_posaggreg = pd.DataFrame({
            "pos": pos,
            "id_chr_strand": id_chr_strand
        })
        # 'sum' on lists concatenates them, so each transcript ends up with a
        # list of (start, end) exon tuples.
        utr5_bed_posaggreg = utr5_bed_posaggreg.groupby("id_chr_strand").agg(
            {'pos': 'sum'})

        # Rebuild "bed"
        utr5_bed_posaggreg["id"] = [x[0] for x in utr5_bed_posaggreg.index]
        utr5_bed_posaggreg["chr"] = [x[1] for x in utr5_bed_posaggreg.index]
        utr5_bed_posaggreg["strand"] = [x[2] for x in
                                        utr5_bed_posaggreg.index]
        self.bed = utr5_bed_posaggreg.reset_index()[[
            "id", "chr", "pos", "strand"
        ]]

        # All three are created lazily in __getitem__.
        self.fasta_extractor = None
        self.vcf = None
        self.vcf_extractor = None

    def __len__(self):
        return len(self.bed)

    def __getitem__(self, idx):
        # Lazily create extractors / VCF handles on first access.
        if self.fasta_extractor is None:
            self.fasta_extractor = FastaStringExtractor(
                self.fasta_file, use_strand=True,
                force_upper=self.force_upper)
        if self.vcf is None:
            self.vcf = MultiSampleVCF(self.vcf_file)
        if self.vcf_extractor is None:
            self.vcf_extractor = VariantSeqExtractor(self.fasta_file)

        entry = self.bed.iloc[idx]
        entry_id = entry["id"]
        entry_chr = entry["chr"]
        # List of (start, end) exon tuples aggregated in __init__.
        entry_pos = entry["pos"]
        entry_strand = entry["strand"]

        ref_exons = []
        var_exons = []
        exon_pos_strings = []
        exon_var_strings = []
        for exon in entry_pos:
            # We get the interval
            interval = pybedtools.Interval(to_scalar(entry_chr),
                                           to_scalar(exon[0]),
                                           to_scalar(exon[1]),
                                           strand=to_scalar(entry_strand))
            exon_pos_strings.append("%s-%s" % (str(exon[0]), str(exon[1])))
            # We get the reference sequence
            ref_seq = self.fasta_extractor.extract(interval)
            # We get the variants, insert them and also save them as metadata
            variants = list(self.vcf.fetch_variants(interval))
            if len(variants) == 0:
                ref_exons.append(ref_seq)
                var_exons.append(ref_seq)
            else:
                var_seq = self.vcf_extractor.extract(interval,
                                                     variants=variants,
                                                     anchor=0,
                                                     fixed_len=False)
                var_string = ";".join([str(var) for var in variants])
                ref_exons.append(ref_seq)
                var_exons.append(var_seq)
                exon_var_strings.append(var_string)

        # Combine
        # Exons were collected in genomic order; on the minus strand the
        # transcript order is the reverse.
        if entry_strand == "-":
            ref_exons.reverse()
            var_exons.reverse()
        ref_seq = "".join(ref_exons)
        var_seq = "".join(var_exons)
        pos_string = ";".join(exon_pos_strings)
        var_string = ";".join(exon_var_strings)

        return {
            "inputs": {
                "ref_seq": ref_seq,
                "alt_seq": var_seq,
            },
            "metadata": {
                "id": entry_id,
                "chr": entry_chr,
                "exon_positions": pos_string,
                "strand": entry_strand,
                "variants": var_string
            }
        }
def __getitem__(self, idx):
    # Lazily create extractors / VCF handles on first access.
    if self.fasta_extractor is None:
        self.fasta_extractor = FastaStringExtractor(
            self.fasta_file, use_strand=True, force_upper=self.force_upper)
    if self.vcf is None:
        self.vcf = MultiSampleVCF(self.vcf_file)
    if self.vcf_extractor is None:
        self.vcf_extractor = VariantSeqExtractor(self.fasta_file)

    entry = self.bed.iloc[idx]
    entry_id = entry["id"]
    entry_chr = entry["chr"]
    # List of (start, end) exon tuples for this transcript.
    entry_pos = entry["pos"]
    entry_strand = entry["strand"]

    ref_exons = []
    var_exons = []
    exon_pos_strings = []
    exon_var_strings = []
    for exon in entry_pos:
        # We get the interval
        interval = pybedtools.Interval(to_scalar(entry_chr),
                                       to_scalar(exon[0]),
                                       to_scalar(exon[1]),
                                       strand=to_scalar(entry_strand))
        exon_pos_strings.append("%s-%s" % (str(exon[0]), str(exon[1])))
        # We get the reference sequence
        ref_seq = self.fasta_extractor.extract(interval)
        # We get the variants, insert them and also save them as metadata
        variants = list(self.vcf.fetch_variants(interval))
        if len(variants) == 0:
            ref_exons.append(ref_seq)
            var_exons.append(ref_seq)
        else:
            var_seq = self.vcf_extractor.extract(interval,
                                                 variants=variants,
                                                 anchor=0,
                                                 fixed_len=False)
            var_string = ";".join([str(var) for var in variants])
            ref_exons.append(ref_seq)
            var_exons.append(var_seq)
            exon_var_strings.append(var_string)

    # Combine
    # Exons were collected in genomic order; reverse for minus-strand
    # transcripts so the joined sequence is in transcript order.
    if entry_strand == "-":
        ref_exons.reverse()
        var_exons.reverse()
    ref_seq = "".join(ref_exons)
    var_seq = "".join(var_exons)
    pos_string = ";".join(exon_pos_strings)
    var_string = ";".join(exon_var_strings)

    return {
        "inputs": {
            "ref_seq": ref_seq,
            "alt_seq": var_seq,
        },
        "metadata": {
            "id": entry_id,
            "chr": entry_chr,
            "exon_positions": pos_string,
            "strand": entry_strand,
            "variants": var_string
        }
    }
class VariantCombinator:
    """Enumerates all possible variants (SNVs, insertions, deletions) in
    regions read from a bed file, against a reference fasta.

    Args:
        fasta_file: reference genome fasta path (sequences are upper-cased)
        bed_file: bed file with the regions to enumerate (used by __iter__)
        variant_type: one of {'all', 'snv', 'in', 'del'}
        alphabet: key into the module-level `alphabets` mapping (e.g. 'DNA')
    """

    def __init__(self, fasta_file: str, bed_file: str = None,
                 variant_type='snv', alphabet='DNA'):
        if variant_type not in {'all', 'snv', 'in', 'del'}:
            raise ValueError("variant_type should be one of "
                             "{'all', 'snv', 'in', 'del'}")
        self.bed_file = bed_file
        # Fix: `self.fasta` was assigned the raw path and then immediately
        # overwritten with the extractor; keep only the extractor.
        self.fasta = FastaStringExtractor(fasta_file, force_upper=True)
        self.variant_type = variant_type
        self.alphabet = alphabets[alphabet]

    def combination_variants_snv(self, interval: Interval) -> Iterable[Variant]:
        """Returns all the possible variants in the regions.

        interval: interval of variants
        """
        seq = self.fasta.extract(interval)
        for pos, ref in zip(range(interval.start, interval.end), seq):
            pos = pos + 1  # 0 to 1 base
            for alt in self.alphabet:
                if ref != alt:
                    yield Variant(interval.chrom, pos, ref, alt)

    def combination_variants_insertion(self, interval,
                                       length=2) -> Iterable[Variant]:
        """Returns all the possible insertion variants in the regions.

        interval: interval of variants
        length: insertions up to length
        """
        if length < 2:
            raise ValueError('length argument should be larger than 1')
        seq = self.fasta.extract(interval)
        for pos, ref in zip(range(interval.start, interval.end), seq):
            pos = pos + 1  # 0 to 1 base
            for l in range(2, length + 1):
                for alt in product(self.alphabet, repeat=l):
                    yield Variant(interval.chrom, pos, ref, ''.join(alt))

    def combination_variants_deletion(self, interval,
                                      length=1) -> Iterable[Variant]:
        """Returns all the possible deletion variants in the regions.

        interval: interval of variants
        length: deletions up to length
        """
        # Fix: the original guard `length < 1 and length <= interval.width`
        # could never reject an invalid length as intended; reject when the
        # length is non-positive OR exceeds the interval width.
        if length < 1 or length > interval.width:
            raise ValueError('length argument should be larger than 0'
                             ' and smaller than interval width')
        seq = self.fasta.extract(interval)
        for i, pos in enumerate(range(interval.start, interval.end)):
            pos = pos + 1  # 0 to 1 base
            for j in range(1, length + 1):
                # Deletions must not run past the end of the region.
                if i + j <= len(seq):
                    yield Variant(interval.chrom, pos, seq[i:i + j], '')

    def combination_variants(self, interval, variant_type='snv',
                             in_length=2, del_length=2) -> Iterable[Variant]:
        """Yield variants of the requested type(s) for the interval."""
        # NOTE(review): 'indel' is accepted here but rejected by __init__'s
        # validation — confirm which set of type names is intended.
        if variant_type in {'snv', 'all'}:
            yield from self.combination_variants_snv(interval)
        if variant_type in {'indel', 'in', 'all'}:
            yield from self.combination_variants_insertion(interval,
                                                           length=in_length)
        if variant_type in {'indel', 'del', 'all'}:
            yield from self.combination_variants_deletion(interval,
                                                          length=del_length)

    def __iter__(self) -> Iterable[Variant]:
        """Iterate variants over all (merged, sorted) regions of bed_file."""
        import pyranges as pr
        gr = pr.read_bed(self.bed_file)
        # Merging overlapping regions avoids yielding duplicate variants.
        gr = gr.merge(strand=False).sort()
        for interval in pyranges_to_intervals(gr):
            yield from self.combination_variants(interval, self.variant_type)

    def to_vcf(self, path):
        """Write all generated variants to a minimal VCF file at `path`."""
        from cyvcf2 import Writer
        header = '''##fileformat=VCFv4.2
#CHROM	POS	ID	REF	ALT	QUAL	FILTER	INFO
'''
        writer = Writer.from_string(path, header)
        for v in self:
            variant = writer.variant_from_string('\t'.join(
                [v.chrom, str(v.pos), '.', v.ref, v.alt, '.', '.', '.']))
            writer.write_record(variant)
class AnchoredGTFDl(Dataset):
    """
    info:
        doc: >
            Dataloader for a combination of fasta and gtf files. The dataloader extracts fixed length regions
            around anchor points. Anchor points are extracted from the gtf based on the anchor parameter.
            The sequences corresponding to the region are then extracted from the fasta file and optionally
            trnasformed using a function given by the transform parameter.
    args:
        gtf_file:
            doc: Path to a gtf file (str)
            example:
                url: https://zenodo.org/record/1466102/files/example_files-gencode.v24.annotation_chr22.gtf
                md5: c0d1bf7738f6a307b425e4890621e7d9
        fasta_file:
            doc: Reference genome FASTA file path (str)
            example:
                url: https://zenodo.org/record/1466102/files/example_files-hg38_chr22.fa
                md5: b0f5cdd4f75186f8a4d2e23378c57b5b
        num_upstream:
            doc: Number of nt by which interval is extended upstream of the anchor point
        num_downstream:
            doc: Number of nt by which interval is extended downstream of the anchor point
        gtf_filter:
            doc: >
                Allows to filter the gtf before extracting the anchor points. Can be str, callable
                or None. If str, it is interpreted as argument to pandas .query(). If callable,
                it is interpreted as function that filters a pandas dataframe and returns the
                filtered df.
        anchor:
            doc: >
                Defines the anchor points. Can be str or callable. If it is a callable, it is
                treated as function that takes a pandas dataframe and returns a modified version
                of the dataframe where each row represents one anchor point, the position of
                which is stored in the column called anchor_pos. If it is a string, a predefined
                function is loaded. Currently available are tss (anchor is the start of a gene),
                start_codon (anchor is the start of the start_codon), stop_codon (anchor is the
                position right after the stop_codon), polya (anchor is the position right after
                the end of a gene).
        transform:
            doc: Callable (or None) to transform the extracted sequence (e.g. one-hot)
        interval_attrs:
            doc: Metadata to extract from the gtf, e.g. ["gene_id", "Strand"]
        use_strand:
            doc: True or False
    output_schema:
        inputs:
            name: seq
            shape: (None, 4)
            special_type: DNAStringSeq
            doc: exon sequence with flanking intronic sequence
            associated_metadata: ranges
        metadata:
            gene_id:
                type: str
                doc: gene id
            Strand:
                type: str
                doc: Strand
            ranges:
                type: GenomicRanges
                doc: ranges that the sequences were extracted
    """

    # Maps the string anchor names accepted by __init__ to the functions
    # that add an anchor_pos column to the gtf dataframe.
    _function_mapping = {
        "tss": lambda x: AnchoredGTFDl.anchor_to_feature_start(
            x, "gene", use_strand=True),
        "start_codon": lambda x: AnchoredGTFDl.anchor_to_feature_start(
            x, "start_codon", use_strand=True),
        "stop_codon": lambda x: AnchoredGTFDl.anchor_to_feature_end(
            x, "stop_codon", use_strand=True),
        "polya": lambda x: AnchoredGTFDl.anchor_to_feature_end(
            x, "gene", use_strand=True)
    }

    def __init__(self, gtf_file, fasta_file, num_upstream, num_downstream,
                 gtf_filter='gene_type == "protein_coding"', anchor='tss',
                 transform=one_hot_dna, interval_attrs=["gene_id", "Strand"],
                 use_strand=True):
        # NOTE(review): interval_attrs uses a mutable default argument —
        # shared across instances if ever mutated.
        # Read and filter gtf
        gtf = pr.read_gtf(gtf_file).df
        if gtf_filter:
            if isinstance(gtf_filter, str):
                gtf = gtf.query(gtf_filter)
            else:
                gtf = gtf_filter(gtf)
        # Extract anchor
        if isinstance(anchor, str):
            anchor = anchor.lower()
            if anchor in self._function_mapping:
                anchor = self._function_mapping[anchor]
            else:
                raise Exception("No valid anchorpoint was chosen")
        self._gtf_anchor = anchor(gtf)

        # Other parameters
        self._use_strand = use_strand
        self._fa = FastaStringExtractor(fasta_file,
                                        use_strand=self._use_strand)
        self._transform = transform
        if self._transform is None:
            self._transform = lambda x: x  # identity when no transform given
        self._num_upstream = num_upstream
        self._num_downstream = num_downstream
        self._interval_attrs = interval_attrs

    def _create_anchored_interval(self, row, num_upstream, num_downstream):
        # On the minus strand, "upstream" extends toward larger coordinates.
        if self._use_strand == True and row.Strand == "-":
            # negative strand
            start = row.anchor_pos - num_downstream
            end = row.anchor_pos + num_upstream
        else:
            # positive strand
            start = row.anchor_pos - num_upstream
            end = row.anchor_pos + num_downstream
        # NOTE(review): start may go negative near chromosome boundaries —
        # confirm the extractor handles that.
        interval = Interval(row.Chromosome, start, end, strand=row.Strand)
        return interval

    def __len__(self):
        return len(self._gtf_anchor)

    def __getitem__(self, idx):
        row = self._gtf_anchor.iloc[idx]
        interval = self._create_anchored_interval(
            row,
            num_upstream=self._num_upstream,
            num_downstream=self._num_downstream)
        sequence = self._fa.extract(interval)
        sequence = self._transform(sequence)
        # Missing gtf columns become empty strings in the metadata.
        metadata_dict = {k: row.get(k, '') for k in self._interval_attrs}
        metadata_dict["ranges"] = GenomicRanges(interval.chrom,
                                                interval.start,
                                                interval.stop, str(idx))
        return {"inputs": np.array(sequence), "metadata": metadata_dict}

    @staticmethod
    def anchor_to_feature_start(gtf, feature, use_strand):
        # Anchor at the feature start; strand-aware, End is the 5' end on "-".
        # NOTE(review): assigning into a .query() result may raise pandas'
        # SettingWithCopyWarning — consider .copy().
        gtf = gtf.query('Feature == @feature')
        if use_strand:
            gtf["anchor_pos"] = ((gtf.Start * (gtf.Strand == "+")) +
                                 (gtf.End * (gtf.Strand == "-")))
        else:
            gtf["anchor_pos"] = gtf.Start
        return gtf

    @staticmethod
    def anchor_to_feature_end(gtf, feature, use_strand):
        # Anchor at the feature end; strand-aware, Start is the 3' end on "-".
        gtf = gtf.query('Feature == @feature')
        if use_strand:
            gtf["anchor_pos"] = ((gtf.End * (gtf.Strand == "+")) +
                                 (gtf.Start * (gtf.Strand == "-")))
        else:
            gtf["anchor_pos"] = gtf.End
        return gtf
    def __init__(self, fasta_file: str):
        """Set up the reference-sequence extractor.

        Args:
            fasta_file: path to the fasta file (can be gzipped)
        """
        # use_strand=True: the extractor reverse-complements sequences for
        # intervals on the negative strand.
        self.fasta = FastaStringExtractor(fasta_file, use_strand=True)
class VariantSeqExtractor(BaseExtractor):
    """Extractor that applies a set of variants to a reference sequence."""

    def __init__(self, fasta_file: str = None, ref_seq_extractor: BaseExtractor = None):
        """
        Sequence extractor which allows to obtain the alternative sequence,
        given some interval and variants inside this interval.

        Args:
            fasta_file: path to the fasta file (can be gzipped)
            ref_seq_extractor: extractor returning the reference sequence given some interval

        Raises:
            ValueError: unless exactly one of `fasta_file` or
                `ref_seq_extractor` is given.
        """
        if fasta_file is not None:
            if ref_seq_extractor is not None:
                # NOTE(review): this branch fires when BOTH arguments are
                # supplied; the message reads as if neither was — consider
                # "only one of ..." wording.
                raise ValueError("either fasta_file or ref_seq_extractor have to be specified")
            self._ref_seq_extractor = FastaStringExtractor(fasta_file, use_strand=True)
        else:
            if ref_seq_extractor is None:
                raise ValueError("either fasta_file or ref_seq_extractor have to be specified")
            self._ref_seq_extractor = ref_seq_extractor

    @property
    def ref_seq_extractor(self):
        # Read-only access to the underlying reference extractor.
        return self._ref_seq_extractor

    def extract(self, interval, variants, anchor, fixed_len=True, **kwargs):
        """
        Args:
            interval: pybedtools.Interval Region of interest from which
                to query the sequence. 0-based
            variants: List[cyvcf2.Variant]: variants overlapping the
                `interval`. can also be indels. 1-based
            anchor: absolute position w.r.t. the interval start. (0-based).
                E.g. for an interval of `chr1:10-20` the anchor of 10 denotes
                the point chr1:10 in the 0-based coordinate system.
            fixed_len: if True, the return sequence will have the same length
                as the `interval` (e.g. `interval.end - interval.start`)

        Returns:
            A single sequence (`str`) with all the variants applied.
        """
        # Preprocessing: clamp the anchor into the interval bounds
        anchor = max(min(anchor, interval.end), interval.start)
        variant_pairs = self._variant_to_sequence(variants)

        # 1. Split variants overlapping with anchor
        # and interval start/end if not fixed_len, so that every remaining
        # variant lies entirely on one side of each split point
        variant_pairs = self._split_overlapping(variant_pairs, anchor)

        if not fixed_len:
            variant_pairs = self._split_overlapping(
                variant_pairs, interval.start, which='right')
            variant_pairs = self._split_overlapping(
                variant_pairs, interval.end, which='left')

        variant_pairs = list(variant_pairs)

        # 2. Split the variants into upstream and downstream of the anchor
        # and sort the variants in each group (downstream ones are sorted
        # in reverse because they are consumed walking away from the anchor)
        upstream_variants = sorted(
            filter(lambda x: x[0].start >= anchor, variant_pairs),
            key=lambda x: x[0].start
        )

        downstream_variants = sorted(
            filter(lambda x: x[0].start < anchor, variant_pairs),
            key=lambda x: x[0].start,
            reverse=True
        )

        # 3. Extend start and end position to compensate for deletions,
        # so a fixed-length output can still be filled
        if fixed_len:
            istart, iend = self._updated_interval(
                interval, upstream_variants, downstream_variants)
        else:
            istart, iend = interval.start, interval.end

        # 4. Iterate from the anchor point outwards. At each variant,
        # register the interval from which to take the reference sequence
        # as well as the interval for the variant
        down_sb = self._downstream_builder(
            downstream_variants, interval, anchor, istart)

        up_sb = self._upstream_builder(
            upstream_variants, interval, anchor, iend)

        # 5. Fetch the reference sequence once and let the builders resolve
        # their registered intervals against it
        seq = self._fetch(interval, istart, iend)
        up_sb.restore(seq)
        down_sb.restore(seq)

        # 6. Concatenate sequences from the upstream and downstream splits,
        # then join both halves. Cut to fix the length if requested.
        down_str = down_sb.concat()
        up_str = up_sb.concat()

        if fixed_len:
            down_str, up_str = self._cut_to_fix_len(
                down_str, up_str, interval, anchor)

        seq = down_str + up_str

        # Negative-strand intervals are returned reverse-complemented
        if interval.strand == '-':
            seq = complement(seq)[::-1]

        return seq

    @staticmethod
    def _variant_to_sequence(variants):
        """
        Convert variant objects to `pyfaidx.Sequence` pairs
        (reference, alternative) carrying their genomic coordinates.
        """
        for v in variants:
            ref = Sequence(name=v.chrom, seq=v.ref,
                           start=v.start, end=v.start + len(v.ref))
            alt = Sequence(name=v.chrom, seq=v.alt,
                           start=v.start, end=v.start + len(v.alt))
            yield ref, alt

    @staticmethod
    def _split_overlapping(variant_pairs, anchor, which='both'):
        """
        Split the variants hitting the anchor into two, keeping the left
        and/or right fragment as selected by `which`.
        """
        for ref, alt in variant_pairs:
            if ref.start < anchor < ref.end:
                mid = anchor - ref.start
                if which == 'left' or which == 'both':
                    yield ref[:mid], alt[:mid]
                if which == 'right' or which == 'both':
                    yield ref[mid:], alt[mid:]
            else:
                yield ref, alt

    @staticmethod
    def _updated_interval(interval, up_variants, down_variants):
        """Widen the interval by the net length lost to deletions on each
        side, so the final fixed-length cut has enough sequence."""
        istart = interval.start
        iend = interval.end

        for ref, alt in up_variants:
            diff_len = len(alt) - len(ref)
            # diff_len < 0 means a deletion: subtracting the negative value
            # pushes iend further right
            if diff_len < 0:
                iend -= diff_len

        for ref, alt in down_variants:
            diff_len = len(alt) - len(ref)
            # symmetric: deletions pull istart further left
            if diff_len < 0:
                istart += diff_len

        return istart, iend

    @staticmethod
    def _downstream_builder(down_variants, interval, anchor, istart):
        """Collect alternating (reference-gap, alt-allele) pieces walking
        left from the anchor down to istart; reversed at the end so the
        builder reads left-to-right."""
        down_sb = IntervalSeqBuilder()
        prev = anchor
        for ref, alt in down_variants:
            if ref.end <= istart:
                # variant lies entirely outside the (extended) interval
                break
            down_sb.append(Interval(interval.chrom, ref.end, prev))
            down_sb.append(alt)
            prev = ref.start

        down_sb.append(Interval(interval.chrom, istart, prev))
        down_sb.reverse()

        return down_sb

    @staticmethod
    def _upstream_builder(up_variants, interval, anchor, iend):
        """Collect alternating (reference-gap, alt-allele) pieces walking
        right from the anchor up to iend."""
        up_sb = IntervalSeqBuilder()
        prev = anchor
        for ref, alt in up_variants:
            if ref.start >= iend:
                # variant lies entirely outside the (extended) interval
                break
            up_sb.append(Interval(interval.chrom, prev, ref.start))
            up_sb.append(alt)
            prev = ref.end

        up_sb.append(Interval(interval.chrom, prev, iend))

        return up_sb

    def _fetch(self, interval, istart, iend):
        """Fetch the reference sequence for [istart, iend) and wrap it in a
        coordinate-aware `Sequence` for the builders to slice."""
        seq = self._ref_seq_extractor.extract(Interval(interval.chrom, istart, iend))
        seq = Sequence(name=interval.chrom, seq=seq, start=istart, end=iend)
        return seq

    @staticmethod
    def _cut_to_fix_len(down_str, up_str, interval, anchor):
        """Trim both halves back to the requested interval length: the
        downstream half keeps its tail, the upstream half keeps its head."""
        down_len = anchor - interval.start
        up_len = interval.end - anchor
        down_str = down_str[-down_len:] if down_len else ''
        up_str = up_str[: up_len] if up_len else ''
        return down_str, up_str