def _parameterized_filter_intervals(vds: 'VariantDataset', intervals, keep: bool, mode: str) -> 'VariantDataset':
    intervals_table = None
    if isinstance(intervals, Table):
        expected = hl.tinterval(hl.tlocus(vds.reference_genome))
        if len(intervals.key) != 1 or intervals.key[0].dtype != expected:
            raise ValueError(
                f"'filter_intervals': expect a table with a single key of type {expected}; "
                f"found {list(intervals.key.dtype.values())}")
        intervals_table = intervals
        intervals = intervals.aggregate(hl.agg.collect(intervals.key[0]))

    if mode == 'variants_only':
        variant_data = hl.filter_intervals(vds.variant_data, intervals, keep)
        return VariantDataset(vds.reference_data, variant_data)
    if mode == 'split_at_boundaries':
        if not keep:
            raise ValueError("filter_intervals mode 'split_at_boundaries' not implemented for keep=False")
        par_intervals = intervals_table or hl.Table.parallelize(
            intervals.map(lambda x: hl.struct(interval=x)),
            schema=hl.tstruct(interval=intervals.dtype.element_type),
            key='interval')
        ref = segment_reference_blocks(vds.reference_data, par_intervals).drop(
            'interval_end', list(par_intervals.key)[0])
        return VariantDataset(ref, hl.filter_intervals(vds.variant_data, intervals, keep))

    return VariantDataset(hl.filter_intervals(vds.reference_data, intervals, keep),
                          hl.filter_intervals(vds.variant_data, intervals, keep))
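# Hedged usage sketch for the helper above: keep only chr20 variants while
# leaving reference data untouched ('variants_only' mode). 'vds' is assumed to
# be an existing VariantDataset (e.g. from hl.vds.read_vds); the path is
# hypothetical.
vds = hl.vds.read_vds('gs://my-bucket/dataset.vds')
chr20 = [hl.parse_locus_interval('chr20', reference_genome='GRCh38')]
filtered = _parameterized_filter_intervals(vds, chr20, keep=True, mode='variants_only')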
def test_reference_genome_liftover(self):
    grch37 = hl.get_reference('GRCh37')
    grch38 = hl.get_reference('GRCh38')

    self.assertTrue(not grch37.has_liftover('GRCh38') and not grch38.has_liftover('GRCh37'))
    grch37.add_liftover(resource('grch37_to_grch38_chr20.over.chain.gz'), 'GRCh38')
    grch38.add_liftover(resource('grch38_to_grch37_chr20.over.chain.gz'), 'GRCh37')
    self.assertTrue(grch37.has_liftover('GRCh38') and grch38.has_liftover('GRCh37'))

    ds = hl.import_vcf(resource('sample.vcf'))
    t = ds.annotate_rows(liftover=hl.liftover(hl.liftover(ds.locus, 'GRCh38'), 'GRCh37')).rows()
    self.assertTrue(t.all(t.locus == t.liftover))

    null_locus = hl.null(hl.tlocus('GRCh38'))

    rows = [
        {'l37': hl.locus('20', 1, 'GRCh37'), 'l38': null_locus},
        {'l37': hl.locus('20', 60000, 'GRCh37'), 'l38': null_locus},
        {'l37': hl.locus('20', 60001, 'GRCh37'), 'l38': hl.locus('chr20', 79360, 'GRCh38')},
        {'l37': hl.locus('20', 278686, 'GRCh37'), 'l38': hl.locus('chr20', 298045, 'GRCh38')},
        {'l37': hl.locus('20', 278687, 'GRCh37'), 'l38': hl.locus('chr20', 298046, 'GRCh38')},
        {'l37': hl.locus('20', 278688, 'GRCh37'), 'l38': null_locus},
        {'l37': hl.locus('20', 278689, 'GRCh37'), 'l38': null_locus},
        {'l37': hl.locus('20', 278690, 'GRCh37'), 'l38': null_locus},
        {'l37': hl.locus('20', 278691, 'GRCh37'), 'l38': hl.locus('chr20', 298047, 'GRCh38')},
        {'l37': hl.locus('20', 37007586, 'GRCh37'), 'l38': hl.locus('chr12', 32563117, 'GRCh38')},
        {'l37': hl.locus('20', 62965520, 'GRCh37'), 'l38': hl.locus('chr20', 64334167, 'GRCh38')},
        {'l37': hl.locus('20', 62965521, 'GRCh37'), 'l38': null_locus}
    ]
    schema = hl.tstruct(l37=hl.tlocus(grch37), l38=hl.tlocus(grch38))
    t = hl.Table.parallelize(rows, schema)
    self.assertTrue(t.all(hl.cond(hl.is_defined(t.l38),
                                  hl.liftover(t.l37, 'GRCh38') == t.l38,
                                  hl.is_missing(hl.liftover(t.l37, 'GRCh38')))))

    t = t.filter(hl.is_defined(t.l38))
    self.assertTrue(t.count() == 6)
    t = t.key_by('l38')
    t.count()
    self.assertTrue(list(t.key) == ['l38'])

    null_locus_interval = hl.null(hl.tinterval(hl.tlocus('GRCh38')))
    rows = [
        {'i37': hl.locus_interval('20', 1, 60000, True, False, 'GRCh37'), 'i38': null_locus_interval},
        {'i37': hl.locus_interval('20', 60001, 82456, True, True, 'GRCh37'),
         'i38': hl.locus_interval('chr20', 79360, 101815, True, True, 'GRCh38')}
    ]
    schema = hl.tstruct(i37=hl.tinterval(hl.tlocus(grch37)), i38=hl.tinterval(hl.tlocus(grch38)))
    t = hl.Table.parallelize(rows, schema)
    self.assertTrue(t.all(hl.liftover(t.i37, 'GRCh38') == t.i38))

    grch37.remove_liftover("GRCh38")
    grch38.remove_liftover("GRCh37")
def impute_sex_aggregator(call,
                          aaf,
                          aaf_threshold=0.0,
                          include_par=False,
                          female_threshold=0.4,
                          male_threshold=0.8) -> hl.expr.StructExpression:
    """:func:`.impute_sex` as an aggregator."""
    mt = call._indices.source
    rg = mt.locus.dtype.reference_genome
    x_contigs = hl.literal(
        hl.eval(
            hl.map(lambda x_contig: hl.parse_locus_interval(x_contig, rg), rg.x_contigs)))
    inbreeding = hl.agg.inbreeding(call, aaf)
    is_female = hl.if_else(
        inbreeding.f_stat < female_threshold,
        True,
        hl.if_else(inbreeding.f_stat > male_threshold, False, hl.missing(hl.tbool)))
    expression = hl.struct(is_female=is_female, **inbreeding)
    if not include_par:
        interval_type = hl.tarray(hl.tinterval(hl.tlocus(rg)))
        par_intervals = hl.literal(rg.par, interval_type)
        expression = hl.agg.filter(
            ~par_intervals.any(lambda par_interval: par_interval.contains(mt.locus)),
            expression)
    expression = hl.agg.filter(
        (aaf > aaf_threshold) & (aaf < (1 - aaf_threshold)), expression)
    expression = hl.agg.filter(
        x_contigs.any(lambda contig: contig.contains(mt.locus)), expression)

    return expression
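# Hedged usage sketch for impute_sex_aggregator: aggregate per sample over a
# MatrixTable 'mt' with GT entries; hl.variant_qc supplies an alt allele
# frequency to use as 'aaf'.
mt = hl.variant_qc(mt)
imputed = mt.annotate_cols(
    impute_sex=impute_sex_aggregator(mt.GT, mt.variant_qc.AF[1], aaf_threshold=0.01))
imputed.cols().show()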
def __init__(self, schema, paths, key, intervals):
    assert (key is None) == (intervals is None)
    self.schema = schema
    self.paths = paths
    self.key = key
    if intervals is not None:
        t = hl.expr.impute_type(intervals)
        # 'or', not 'and': reject anything that is not an array of intervals
        if not isinstance(t, hl.tarray) or not isinstance(t.element_type, hl.tinterval):
            raise TypeError("'intervals' must be an array of tintervals")
        pt = t.element_type.point_type
        if isinstance(pt, hl.tstruct):
            self._interval_type = t
        else:
            self._interval_type = hl.tarray(hl.tinterval(hl.tstruct(__point=pt)))

    if intervals is not None and t != self._interval_type:
        self.intervals = [hl.Interval(hl.Struct(__point=i.start),
                                      hl.Struct(__point=i.end),
                                      i.includes_start,
                                      i.includes_end)
                          for i in intervals]
    else:
        self.intervals = intervals
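# Hedged illustration of the '__point' wrapping performed above: intervals over
# a bare point type are lifted into single-field struct intervals so they line
# up with a struct-typed table key.
iv = hl.Interval(1, 10, includes_start=True, includes_end=False)
wrapped = hl.Interval(hl.Struct(__point=iv.start), hl.Struct(__point=iv.end),
                      iv.includes_start, iv.includes_end)
# wrapped.start is Struct(__point=1); the reader now sees interval<struct{__point: int32}>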
def test_constructors(self):
    rg = hl.ReferenceGenome("foo", ["1"], {"1": 100})

    schema = hl.tstruct(a=hl.tfloat64, b=hl.tfloat64, c=hl.tint32, d=hl.tint32)
    rows = [{'a': 2.0, 'b': 4.0, 'c': 1, 'd': 5}]
    kt = hl.Table.parallelize(rows, schema)
    kt = kt.annotate(d=hl.int64(kt.d))

    kt = kt.annotate(l1=hl.parse_locus("1:51"),
                     l2=hl.locus("1", 51, reference_genome=rg),
                     i1=hl.parse_locus_interval("1:51-56", reference_genome=rg),
                     i2=hl.interval(hl.locus("1", 51, reference_genome=rg),
                                    hl.locus("1", 56, reference_genome=rg)))

    expected_schema = {'a': hl.tfloat64, 'b': hl.tfloat64, 'c': hl.tint32, 'd': hl.tint64,
                       'l1': hl.tlocus(), 'l2': hl.tlocus(rg),
                       'i1': hl.tinterval(hl.tlocus(rg)), 'i2': hl.tinterval(hl.tlocus(rg))}

    self.assertTrue(all([expected_schema[f] == t for f, t in kt.row.dtype.items()]))
def overlaps(self, interval):
    """True if the supplied interval contains any value in common with this one.

    Parameters
    ----------
    interval : :class:`.Interval`
        Interval object with the same point type.

    Returns
    -------
    :obj:`bool`
    """
    return hl.eval(hl.literal(self, hl.tinterval(self._point_type)).overlaps(interval))
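# Hedged example mirroring the doctest style used by Interval.contains below:
# half-open [0, 10) shares points with [5, 20) but not with [10, 20).
interval_a = hl.Interval(0, 10)  # includes_start=True, includes_end=False by default
interval_b = hl.Interval(5, 20)
assert interval_a.overlaps(interval_b)
assert not interval_a.overlaps(hl.Interval(10, 20))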
def values(self):
    values = [(hl.tbool, True),
              (hl.tint32, 0),
              (hl.tint64, 0),
              (hl.tfloat32, 0.5),
              (hl.tfloat64, 0.5),
              (hl.tstr, "foo"),
              (hl.tstruct(x=hl.tint32), hl.Struct(x=0)),
              (hl.tarray(hl.tint32), [0, 1, 4]),
              (hl.tset(hl.tint32), {0, 1, 4}),
              (hl.tdict(hl.tstr, hl.tint32), {"a": 0, "b": 1, "c": 4}),
              (hl.tinterval(hl.tint32), hl.Interval(0, 1, True, False)),
              (hl.tlocus(hl.default_reference()), hl.Locus("1", 1)),
              (hl.tcall, hl.Call([0, 1]))]
    return values
def _object_hook(obj):
    if 'name' not in obj:
        return obj
    name = obj['name']
    if name == VariantDatasetCombiner.__name__:
        del obj['name']
        obj['vdses'] = [VDSMetadata(*x) for x in obj['vdses']]

        rg = hl.get_reference(obj['reference_genome'])
        obj['reference_genome'] = rg
        intervals_type = hl.tarray(hl.tinterval(hl.tlocus(rg)))
        intervals = intervals_type._convert_from_json(obj['gvcf_import_intervals'])
        obj['gvcf_import_intervals'] = intervals

        return VariantDatasetCombiner(**obj)
    return obj
def main():
    parser = argparse.ArgumentParser(description="Driver for hail's gVCF combiner")
    parser.add_argument('--sample-map',
                        help='path to the sample map (must be filesystem local). '
                             'The sample map should be tab separated with two columns. '
                             'The first column is the sample ID, and the second column '
                             'is the gVCF path.\n'
                             'WARNING: the sample names in the gVCFs will be overwritten',
                        required=True)
    parser.add_argument('--tmp-path',
                        help='path to folder for temp output (can be a cloud bucket)',
                        default='/tmp')
    parser.add_argument('--out-file', '-o',
                        help='path to final combiner output',
                        required=True)
    parser.add_argument('--json',
                        help='json to use for the import of the gVCFs '
                             '(must be filesystem local)',
                        required=True)
    parser.add_argument('--header',
                        help='external header, must be cloud based',
                        required=False)
    args = parser.parse_args()
    hl.init(default_reference=DEFAULT_REF,
            log='/hail-joint-caller-' + time.strftime('%Y%m%d-%H%M') + '.log')
    with open(args.json) as j:
        ty = hl.tarray(hl.tinterval(hl.tstruct(locus=hl.tlocus(reference_genome='GRCh38'))))
        intervals = ty._from_json(j.read())
    with open(args.sample_map) as m:
        samples = [line.strip().split('\t') for line in m]
    run_combiner(samples,
                 intervals,
                 args.out_file,
                 args.tmp_path,
                 args.header,
                 overwrite=True)
def __init__(self, path, intervals, filter_intervals):
    if intervals is not None:
        t = hl.expr.impute_type(intervals)
        # 'or', not 'and': reject anything that is not an array of intervals
        if not isinstance(t, hl.tarray) or not isinstance(t.element_type, hl.tinterval):
            raise TypeError("'intervals' must be an array of tintervals")
        pt = t.element_type.point_type
        if isinstance(pt, hl.tstruct):
            self._interval_type = t
        else:
            self._interval_type = hl.tarray(hl.tinterval(hl.tstruct(__point=pt)))

    self.path = path
    self.filter_intervals = filter_intervals
    if intervals is not None and t != self._interval_type:
        self.intervals = [hl.Interval(hl.Struct(__point=i.start),
                                      hl.Struct(__point=i.end),
                                      i.includes_start,
                                      i.includes_end)
                          for i in intervals]
    else:
        self.intervals = intervals
def to_dict(self) -> dict:
    intervals_typ = hl.tarray(hl.tinterval(hl.tlocus(self.reference_genome)))
    return {'name': self.__class__.__name__,
            'save_path': self.save_path,
            'output_path': self.output_path,
            'temp_path': self.temp_path,
            'reference_genome': str(self.reference_genome),
            'branch_factor': self.branch_factor,
            'target_records': self.target_records,
            'gvcf_batch_size': self.gvcf_batch_size,
            'gvcf_external_header': self.gvcf_external_header,  # put this here for humans
            'contig_recoding': self.contig_recoding,
            'gvcf_info_to_keep': None if self.gvcf_info_to_keep is None
                                 else list(self.gvcf_info_to_keep),
            'gvcf_reference_entry_fields_to_keep': None if self.gvcf_reference_entry_fields_to_keep is None
                                                   else list(self.gvcf_reference_entry_fields_to_keep),
            'vdses': [md for i in sorted(self.vdses, reverse=True) for md in self.vdses[i]],
            'gvcfs': self.gvcfs,
            'gvcf_sample_names': self.gvcf_sample_names,
            'gvcf_import_intervals': intervals_typ._convert_to_json(self.gvcf_import_intervals)}
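# Hedged round-trip sketch connecting to_dict() with the _object_hook above,
# assuming 'combiner' is an existing VariantDatasetCombiner instance.
import json

serialized = json.dumps(combiner.to_dict())
revived = json.loads(serialized, object_hook=_object_hook)
assert isinstance(revived, VariantDatasetCombiner)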
def main():
    parser = argparse.ArgumentParser(description="Driver for hail's gVCF combiner")
    parser.add_argument('--sample-map',
                        help='path to the sample map (must be filesystem local)',
                        required=True)
    parser.add_argument('--sample-file',
                        help='path to a file containing a line separated list '
                             'of samples to combine (must be filesystem local)')
    parser.add_argument('--tmp-path',
                        help='path to folder for temp output (can be a cloud bucket)',
                        default='/tmp')
    parser.add_argument('--out-file', '-o',
                        help='path to final combiner output',
                        required=True)
    parser.add_argument('--summarize',
                        help='if defined, run summarize, placing the rows table '
                             'of the output at the argument value')
    parser.add_argument('--json',
                        help='json to use for the import of the gVCFs '
                             '(must be filesystem local)',
                        required=True)
    args = parser.parse_args()
    samples = build_sample_list(args.sample_map, args.sample_file)
    with open(args.json) as j:
        ty = hl.tarray(hl.tinterval(hl.tstruct(locus=hl.tlocus(reference_genome='GRCh38'))))
        intervals = ty._from_json(j.read())
    hl.init(default_reference=DEFAULT_REF,
            log='/hail-joint-caller-' + time.strftime('%Y%m%d-%H%M') + '.log')
    run_combiner(samples,
                 intervals,
                 args.out_file,
                 args.tmp_path,
                 summary_path=args.summarize,
                 overwrite=True)
def test_segment_intervals(self):
    intervals = hl.Table.parallelize(
        [hl.struct(interval=hl.interval(0, 10)),
         hl.struct(interval=hl.interval(20, 50)),
         hl.struct(interval=hl.interval(52, 52))],
        schema=hl.tstruct(interval=hl.tinterval(hl.tint32)),
        key='interval')

    points1 = [-1, 5, 30, 40, 52, 53]

    segmented1 = hl.segment_intervals(intervals, points1)

    assert segmented1.aggregate(
        hl.agg.collect(segmented1.interval) == [hl.interval(0, 5),
                                                hl.interval(5, 10),
                                                hl.interval(20, 30),
                                                hl.interval(30, 40),
                                                hl.interval(40, 50),
                                                hl.interval(52, 52)])
def contains(self, value):
    """True if `value` is contained within the interval.

    Examples
    --------

    >>> interval2.contains(5)
    True

    >>> interval2.contains(6)
    False

    Parameters
    ----------
    value :
        Object with type :meth:`.point_type`.

    Returns
    -------
    :obj:`bool`
    """
    return hl.eval(hl.literal(self, hl.tinterval(self._point_type)).contains(value))
def visit_interval(self, node, visited_children):
    # children: the 'interval' token, whitespace, '<', the point type, '>'
    tinterval, _, _, point_t, _ = visited_children
    return hl.tinterval(point_t)
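# Small illustration of the type strings this visitor handles; hl.dtype is the
# public entry point for parsing them.
assert hl.dtype('interval<int32>') == hl.tinterval(hl.tint32)
assert hl.dtype('interval<locus<GRCh37>>') == hl.tinterval(hl.tlocus('GRCh37'))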
def segment_reference_blocks(ref: 'MatrixTable', intervals: 'Table') -> 'MatrixTable':
    """Returns a matrix table of reference blocks segmented according to intervals.

    Loci outside the given intervals are discarded. Reference blocks that start before but
    span an interval will appear at the interval start locus.

    Note
    ----
    Assumes disjoint intervals which do not span contigs.

    Requires start-inclusive intervals.

    Parameters
    ----------
    ref : :class:`.MatrixTable`
        MatrixTable of reference blocks.
    intervals : :class:`.Table`
        Table of intervals at which to segment reference blocks.

    Returns
    -------
    :class:`.MatrixTable`
    """
    interval_field = list(intervals.key)[0]
    if not intervals[interval_field].dtype == hl.tinterval(ref.locus.dtype):
        raise ValueError(
            f"expect intervals to be keyed by intervals of loci matching the VariantDataset:"
            f" found {intervals[interval_field].dtype} / {ref.locus.dtype}")
    intervals = intervals.select(_interval_dup=intervals[interval_field])

    if not intervals.aggregate(
            hl.agg.all(intervals[interval_field].includes_start
                       & (intervals[interval_field].start.contig
                          == intervals[interval_field].end.contig))):
        raise ValueError("expect intervals to be start-inclusive")

    starts = intervals.key_by(_start_locus=intervals[interval_field].start)
    starts = starts.annotate(_include_locus=True)
    refl = ref.localize_entries('_ref_entries', '_ref_cols')
    joined = refl.join(starts, how='outer')
    rg = ref.locus.dtype.reference_genome
    contigs = rg.contigs
    contig_idx_map = hl.literal({contigs[i]: i for i in range(len(contigs))}, 'dict<str, int32>')
    joined = joined.annotate(__contig_idx=contig_idx_map[joined.locus.contig])
    joined = joined.annotate(
        _ref_entries=joined._ref_entries.map(lambda e: e.annotate(__contig_idx=joined.__contig_idx)))
    dense = joined.annotate(
        dense_ref=hl.or_missing(
            joined._include_locus,
            hl.rbind(
                joined.locus.position,
                lambda pos: hl.enumerate(
                    hl.scan._densify(hl.len(joined._ref_cols), joined._ref_entries)
                ).map(
                    lambda idx_and_e: hl.rbind(
                        idx_and_e[0], idx_and_e[1],
                        lambda idx, e: hl.coalesce(
                            joined._ref_entries[idx],
                            hl.or_missing((e.__contig_idx == joined.__contig_idx)
                                          & (e.END >= pos), e))
                    ).drop('__contig_idx')))))
    dense = dense.filter(dense._include_locus).drop('_interval_dup', '_include_locus', '__contig_idx')

    # at this point, 'dense' is a table with dense rows of reference blocks, keyed by locus

    refl_filtered = refl.annotate(**{interval_field: intervals[refl.locus]._interval_dup})

    # remove rows that are not contained in an interval, and rows that are the start of an
    # interval (interval starts come from the 'dense' table)
    refl_filtered = refl_filtered.filter(
        hl.is_defined(refl_filtered[interval_field])
        & (refl_filtered.locus != refl_filtered[interval_field].start))

    # union dense interval starts with filtered table
    refl_filtered = refl_filtered.union(dense.transmute(_ref_entries=dense.dense_ref))

    # rewrite reference blocks to end at the first of (interval end, reference block end)
    refl_filtered = refl_filtered.annotate(
        interval_end=refl_filtered[interval_field].end.position
        - ~refl_filtered[interval_field].includes_end)
    refl_filtered = refl_filtered.annotate(
        _ref_entries=refl_filtered._ref_entries.map(
            lambda entry: entry.annotate(END=hl.min(entry.END, refl_filtered.interval_end))))

    return refl_filtered._unlocalize_entries('_ref_entries', '_ref_cols', list(ref.col_key))
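# Hedged usage sketch for segment_reference_blocks, assuming 'vds' is a GRCh38
# VariantDataset. parse_locus_interval defaults to start-inclusive intervals,
# which the function requires.
intervals = hl.Table.parallelize(
    [hl.struct(interval=hl.parse_locus_interval('chr20:1-5M', reference_genome='GRCh38'))],
    key='interval')
segmented = segment_reference_blocks(vds.reference_data, intervals)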
def filter_intervals(ds, intervals, keep=True) -> Union[Table, MatrixTable]:
    """Filter rows with a list of intervals.

    Examples
    --------

    Filter to loci falling within one interval:

    >>> ds_result = hl.filter_intervals(dataset, [hl.parse_locus_interval('17:38449840-38530994')])

    Remove all loci within a list of intervals:

    >>> intervals = [hl.parse_locus_interval(x) for x in ['1:50M-75M', '2:START-400000', '3-22']]
    >>> ds_result = hl.filter_intervals(dataset, intervals, keep=False)

    Notes
    -----
    Based on the ``keep`` argument, this method will either restrict to points in the
    supplied interval ranges, or remove all rows in those ranges.

    When ``keep=True``, partitions that don't overlap any supplied interval will not be
    loaded at all. This enables :func:`.filter_intervals` to be used for reasonably
    low-latency queries of small ranges of the dataset, even on large datasets.

    Parameters
    ----------
    ds : :class:`.MatrixTable` or :class:`.Table`
        Dataset to filter.
    intervals : :class:`.ArrayExpression` of type :py:data:`.tinterval`
        Intervals to filter on. The point type of the interval must be a prefix of the
        key or equal to the first field of the key.
    keep : :obj:`bool`
        If ``True``, keep only rows that fall within any interval in `intervals`. If
        ``False``, keep only rows that fall outside all intervals in `intervals`.

    Returns
    -------
    :class:`.MatrixTable` or :class:`.Table`
    """
    if isinstance(ds, MatrixTable):
        k_type = ds.row_key.dtype
    else:
        assert isinstance(ds, Table)
        k_type = ds.key.dtype

    point_type = intervals.dtype.element_type.point_type

    def is_struct_prefix(partial, full):
        if list(partial) != list(full)[:len(partial)]:
            return False
        for k, v in partial.items():
            if full[k] != v:
                return False
        return True

    if point_type == k_type[0]:
        needs_wrapper = True
        point_type = hl.tstruct(foo=point_type)
    elif isinstance(point_type, tstruct) and is_struct_prefix(point_type, k_type):
        needs_wrapper = False
    else:
        raise TypeError(
            "The point type is incompatible with key type of the dataset ('{}', '{}')".format(
                repr(point_type), repr(k_type)))

    def wrap_input(interval):
        if interval is None:
            raise TypeError("'filter_intervals' does not allow missing values in 'intervals'.")
        elif needs_wrapper:
            return Interval(Struct(foo=interval.start),
                            Struct(foo=interval.end),
                            interval.includes_start,
                            interval.includes_end)
        else:
            return interval

    intervals = hl.eval(intervals)
    intervals = hl.tarray(hl.tinterval(point_type))._convert_to_json(
        [wrap_input(i) for i in intervals])

    if isinstance(ds, MatrixTable):
        config = {'name': 'MatrixFilterIntervals',
                  'keyType': point_type._parsable_string(),
                  'intervals': intervals,
                  'keep': keep}
        return MatrixTable(MatrixToMatrixApply(ds._mir, config))
    else:
        config = {'name': 'TableFilterIntervals',
                  'keyType': point_type._parsable_string(),
                  'intervals': intervals,
                  'keep': keep}
        return Table(TableToTableApply(ds._tir, config))
import hail as hl

gvcfs = ['gs://hail-common/test-resources/HG00096.g.vcf.gz',
         'gs://hail-common/test-resources/HG00268.g.vcf.gz']
hl.init(default_reference='GRCh38')
parts_json = [
    {'start': {'locus': {'contig': 'chr20', 'position': 17821257}},
     'end': {'locus': {'contig': 'chr20', 'position': 18708366}},
     'includeStart': True, 'includeEnd': True},
    {'start': {'locus': {'contig': 'chr20', 'position': 18708367}},
     'end': {'locus': {'contig': 'chr20', 'position': 19776611}},
     'includeStart': True, 'includeEnd': True},
    {'start': {'locus': {'contig': 'chr20', 'position': 19776612}},
     'end': {'locus': {'contig': 'chr20', 'position': 21144633}},
     'includeStart': True, 'includeEnd': True},
]
parts = hl.tarray(hl.tinterval(hl.tstruct(locus=hl.tlocus('GRCh38'))))._convert_from_json(parts_json)

for mt in hl.import_gvcfs(gvcfs, parts):
    mt._force_count_rows()