Example #1
def _parameterized_filter_intervals(vds: 'VariantDataset', intervals,
                                    keep: bool, mode: str) -> 'VariantDataset':
    intervals_table = None
    if isinstance(intervals, Table):
        expected = hl.tinterval(hl.tlocus(vds.reference_genome))
        if len(intervals.key) != 1 or intervals.key[0].dtype != expected:
            raise ValueError(
                f"'filter_intervals': expect a table with a single key of type {expected}; "
                f"found {list(intervals.key.dtype.values())}")
        intervals_table = intervals
        intervals = intervals.aggregate(hl.agg.collect(intervals.key[0]))

    if mode == 'variants_only':
        variant_data = hl.filter_intervals(vds.variant_data, intervals, keep)
        return VariantDataset(vds.reference_data, variant_data)
    if mode == 'split_at_boundaries':
        if not keep:
            raise ValueError(
                "filter_intervals mode 'split_at_boundaries' not implemented for keep=False"
            )
        par_intervals = intervals_table or hl.Table.parallelize(
            intervals.map(lambda x: hl.struct(interval=x)),
            schema=hl.tstruct(interval=intervals.dtype.element_type),
            key='interval')
        ref = segment_reference_blocks(vds.reference_data, par_intervals).drop(
            'interval_end',
            list(par_intervals.key)[0])
        return VariantDataset(
            ref, hl.filter_intervals(vds.variant_data, intervals, keep))

    return VariantDataset(
        hl.filter_intervals(vds.reference_data, intervals, keep),
        hl.filter_intervals(vds.variant_data, intervals, keep))
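
A minimal usage sketch for the helper above, assuming an existing VariantDataset named vds (the dataset and the interval string are placeholders):

import hail as hl

# Hypothetical call: keep only 20:1M-2M variants, leaving reference data untouched.
intervals = [hl.parse_locus_interval('20:1M-2M', reference_genome='GRCh37')]
filtered = _parameterized_filter_intervals(vds, intervals, keep=True,
                                           mode='variants_only')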
Example #2
    def test_reference_genome_liftover(self):
        grch37 = hl.get_reference('GRCh37')
        grch38 = hl.get_reference('GRCh38')

        self.assertTrue(not grch37.has_liftover('GRCh38') and not grch38.has_liftover('GRCh37'))
        grch37.add_liftover(resource('grch37_to_grch38_chr20.over.chain.gz'), 'GRCh38')
        grch38.add_liftover(resource('grch38_to_grch37_chr20.over.chain.gz'), 'GRCh37')
        assert grch37.has_liftover('GRCh38')
        assert grch38.has_liftover('GRCh37')

        ds = hl.import_vcf(resource('sample.vcf'))
        t = ds.annotate_rows(liftover=hl.liftover(hl.liftover(ds.locus, 'GRCh38'), 'GRCh37')).rows()
        assert t.all(t.locus == t.liftover)

        null_locus = hl.null(hl.tlocus('GRCh38'))

        rows = [
            {'l37': hl.locus('20', 1, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60000, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60001, 'GRCh37'), 'l38': hl.locus('chr20', 79360, 'GRCh38')},
            {'l37': hl.locus('20', 278686, 'GRCh37'), 'l38': hl.locus('chr20', 298045, 'GRCh38')},
            {'l37': hl.locus('20', 278687, 'GRCh37'), 'l38': hl.locus('chr20', 298046, 'GRCh38')},
            {'l37': hl.locus('20', 278688, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278689, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278690, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278691, 'GRCh37'), 'l38': hl.locus('chr20', 298047, 'GRCh38')},
            {'l37': hl.locus('20', 37007586, 'GRCh37'), 'l38': hl.locus('chr12', 32563117, 'GRCh38')},
            {'l37': hl.locus('20', 62965520, 'GRCh37'), 'l38': hl.locus('chr20', 64334167, 'GRCh38')},
            {'l37': hl.locus('20', 62965521, 'GRCh37'), 'l38': null_locus}
        ]
        schema = hl.tstruct(l37=hl.tlocus(grch37), l38=hl.tlocus(grch38))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.cond(hl.is_defined(t.l38),
                                      hl.liftover(t.l37, 'GRCh38') == t.l38,
                                      hl.is_missing(hl.liftover(t.l37, 'GRCh38')))))

        t = t.filter(hl.is_defined(t.l38))
        self.assertTrue(t.count() == 6)

        t = t.key_by('l38')
        t.count()
        self.assertTrue(list(t.key) == ['l38'])

        null_locus_interval = hl.null(hl.tinterval(hl.tlocus('GRCh38')))
        rows = [
            {'i37': hl.locus_interval('20', 1, 60000, True, False, 'GRCh37'), 'i38': null_locus_interval},
            {'i37': hl.locus_interval('20', 60001, 82456, True, True, 'GRCh37'),
             'i38': hl.locus_interval('chr20', 79360, 101815, True, True, 'GRCh38')}
        ]
        schema = hl.tstruct(i37=hl.tinterval(hl.tlocus(grch37)), i38=hl.tinterval(hl.tlocus(grch38)))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.liftover(t.i37, 'GRCh38') == t.i38))

        grch37.remove_liftover("GRCh38")
        grch38.remove_liftover("GRCh37")
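
The same round trip sketched outside the test harness; the chain-file path is a placeholder for any GRCh37-to-GRCh38 chain file:

import hail as hl

rg37 = hl.get_reference('GRCh37')
rg37.add_liftover('gs://my-bucket/grch37_to_grch38.over.chain.gz', 'GRCh38')
# Per the expected pairs in the test above, this yields Locus('chr20', 79360, 'GRCh38').
lifted = hl.eval(hl.liftover(hl.locus('20', 60001, 'GRCh37'), 'GRCh38'))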
Example #3
    def test_reference_genome_liftover(self):
        grch37 = hl.get_reference('GRCh37')
        grch38 = hl.get_reference('GRCh38')

        self.assertTrue(not grch37.has_liftover('GRCh38') and not grch38.has_liftover('GRCh37'))
        grch37.add_liftover(resource('grch37_to_grch38_chr20.over.chain.gz'), 'GRCh38')
        grch38.add_liftover(resource('grch38_to_grch37_chr20.over.chain.gz'), 'GRCh37')
        self.assertTrue(grch37.has_liftover('GRCh38') and grch38.has_liftover('GRCh37'))

        ds = hl.import_vcf(resource('sample.vcf'))
        t = ds.annotate_rows(liftover=hl.liftover(hl.liftover(ds.locus, 'GRCh38'), 'GRCh37')).rows()
        self.assertTrue(t.all(t.locus == t.liftover))

        null_locus = hl.null(hl.tlocus('GRCh38'))

        rows = [
            {'l37': hl.locus('20', 1, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60000, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60001, 'GRCh37'), 'l38': hl.locus('chr20', 79360, 'GRCh38')},
            {'l37': hl.locus('20', 278686, 'GRCh37'), 'l38': hl.locus('chr20', 298045, 'GRCh38')},
            {'l37': hl.locus('20', 278687, 'GRCh37'), 'l38': hl.locus('chr20', 298046, 'GRCh38')},
            {'l37': hl.locus('20', 278688, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278689, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278690, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278691, 'GRCh37'), 'l38': hl.locus('chr20', 298047, 'GRCh38')},
            {'l37': hl.locus('20', 37007586, 'GRCh37'), 'l38': hl.locus('chr12', 32563117, 'GRCh38')},
            {'l37': hl.locus('20', 62965520, 'GRCh37'), 'l38': hl.locus('chr20', 64334167, 'GRCh38')},
            {'l37': hl.locus('20', 62965521, 'GRCh37'), 'l38': null_locus}
        ]
        schema = hl.tstruct(l37=hl.tlocus(grch37), l38=hl.tlocus(grch38))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.cond(hl.is_defined(t.l38),
                                      hl.liftover(t.l37, 'GRCh38') == t.l38,
                                      hl.is_missing(hl.liftover(t.l37, 'GRCh38')))))

        t = t.filter(hl.is_defined(t.l38))
        self.assertTrue(t.count() == 6)

        t = t.key_by('l38')
        t.count()
        self.assertTrue(list(t.key) == ['l38'])

        null_locus_interval = hl.null(hl.tinterval(hl.tlocus('GRCh38')))
        rows = [
            {'i37': hl.locus_interval('20', 1, 60000, True, False, 'GRCh37'), 'i38': null_locus_interval},
            {'i37': hl.locus_interval('20', 60001, 82456, True, True, 'GRCh37'),
             'i38': hl.locus_interval('chr20', 79360, 101815, True, True, 'GRCh38')}
        ]
        schema = hl.tstruct(i37=hl.tinterval(hl.tlocus(grch37)), i38=hl.tinterval(hl.tlocus(grch38)))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.liftover(t.i37, 'GRCh38') == t.i38))

        grch37.remove_liftover("GRCh38")
        grch38.remove_liftover("GRCh37")
Example #4
def impute_sex_aggregator(call,
                          aaf,
                          aaf_threshold=0.0,
                          include_par=False,
                          female_threshold=0.4,
                          male_threshold=0.8) -> hl.expr.StructExpression:
    """:func:`.impute_sex` as an aggregator."""
    mt = call._indices.source
    rg = mt.locus.dtype.reference_genome
    x_contigs = hl.literal(
        hl.eval(
            hl.map(lambda x_contig: hl.parse_locus_interval(x_contig, rg),
                   rg.x_contigs)))
    inbreeding = hl.agg.inbreeding(call, aaf)
    is_female = hl.if_else(
        inbreeding.f_stat < female_threshold, True,
        hl.if_else(inbreeding.f_stat > male_threshold, False,
                   hl.null(hl.tbool)))
    expression = hl.struct(is_female=is_female, **inbreeding)
    if not include_par:
        interval_type = hl.tarray(hl.tinterval(hl.tlocus(rg)))
        par_intervals = hl.literal(rg.par, interval_type)
        expression = hl.agg.filter(
            ~par_intervals.any(
                lambda par_interval: par_interval.contains(mt.locus)),
            expression)
    expression = hl.agg.filter(
        (aaf > aaf_threshold) & (aaf < (1 - aaf_threshold)), expression)
    expression = hl.agg.filter(
        x_contigs.any(lambda contig: contig.contains(mt.locus)), expression)

    return expression
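
A hedged usage sketch: the aggregator runs once per column, so it belongs inside annotate_cols; mt and its af row field are assumptions for illustration.

import hail as hl

# Rough alternate allele frequency per row, then per-sample sex imputation.
mt = mt.annotate_rows(af=hl.agg.mean(mt.GT.n_alt_alleles()) / 2)
sex_ht = mt.annotate_cols(sex=impute_sex_aggregator(mt.GT, aaf=mt.af)).cols()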
Example #5
    def __init__(self, schema, paths, key, intervals):
        assert (key is None) == (intervals is None)
        self.schema = schema
        self.paths = paths
        self.key = key

        if intervals is not None:
            t = hl.expr.impute_type(intervals)
            if not isinstance(t, hl.tarray) or not isinstance(
                    t.element_type, hl.tinterval):
                raise TypeError("'intervals' must be an array of tintervals")
            pt = t.element_type.point_type
            if isinstance(pt, hl.tstruct):
                self._interval_type = t
            else:
                self._interval_type = hl.tarray(
                    hl.tinterval(hl.tstruct(__point=pt)))

        if intervals is not None and t != self._interval_type:
            self.intervals = [
                hl.Interval(hl.Struct(__point=i.start),
                            hl.Struct(__point=i.end), i.includes_start,
                            i.includes_end) for i in intervals
            ]
        else:
            self.intervals = intervals
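
The __point wrapping above can be reproduced by hand; a sketch with plain integer intervals, whose point type is not a struct:

import hail as hl

ivs = [hl.Interval(0, 10), hl.Interval(20, 30)]
t = hl.expr.impute_type(ivs)  # array<interval<int32>>
wrapped = [hl.Interval(hl.Struct(__point=i.start), hl.Struct(__point=i.end),
                       i.includes_start, i.includes_end) for i in ivs]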
Example #6
    def test_constructors(self):
        rg = hl.ReferenceGenome("foo", ["1"], {"1": 100})

        schema = hl.tstruct(a=hl.tfloat64, b=hl.tfloat64, c=hl.tint32, d=hl.tint32)
        rows = [{'a': 2.0, 'b': 4.0, 'c': 1, 'd': 5}]
        kt = hl.Table.parallelize(rows, schema)
        kt = kt.annotate(d=hl.int64(kt.d))

        kt = kt.annotate(l1=hl.parse_locus("1:51"),
                         l2=hl.locus("1", 51, reference_genome=rg),
                         i1=hl.parse_locus_interval("1:51-56", reference_genome=rg),
                         i2=hl.interval(hl.locus("1", 51, reference_genome=rg),
                                        hl.locus("1", 56, reference_genome=rg)))

        expected_schema = {'a': hl.tfloat64, 'b': hl.tfloat64, 'c': hl.tint32, 'd': hl.tint64,
                           'l1': hl.tlocus(), 'l2': hl.tlocus(rg),
                           'i1': hl.tinterval(hl.tlocus(rg)), 'i2': hl.tinterval(hl.tlocus(rg))}

        self.assertTrue(all([expected_schema[f] == t for f, t in kt.row.dtype.items()]))
Example #7
    def overlaps(self, interval):
        """True if the the supplied interval contains any value in common with this one.

        Parameters
        ----------
        interval : :class:`.Interval`
            Interval object with the same point type.

        Returns
        -------
        :obj:`bool`
        """

        return hl.eval(hl.literal(self, hl.tinterval(self._point_type)).overlaps(interval))
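
Because overlaps evaluates eagerly through hl.eval, it can be called on plain Interval values:

import hail as hl

a = hl.Interval(3, 6)          # start-inclusive, end-exclusive by default
a.overlaps(hl.Interval(5, 9))  # True: 5 lies in both intervals
a.overlaps(hl.Interval(6, 9))  # False: 6 is excluded from a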
Example #8
 def values(self):
     values = [(hl.tbool, True), (hl.tint32, 0), (hl.tint64, 0),
               (hl.tfloat32, 0.5), (hl.tfloat64, 0.5), (hl.tstr, "foo"),
               (hl.tstruct(x=hl.tint32), hl.Struct(x=0)),
               (hl.tarray(hl.tint32), [0, 1, 4]),
               (hl.tset(hl.tint32), {0, 1, 4}),
               (hl.tdict(hl.tstr, hl.tint32), {
                   "a": 0,
                   "b": 1,
                   "c": 4
               }), (hl.tinterval(hl.tint32), hl.Interval(0, 1, True,
                                                         False)),
               (hl.tlocus(hl.default_reference()), hl.Locus("1", 1)),
               (hl.tcall, hl.Call([0, 1]))]
     return values
Example #9
    def overlaps(self, interval):
        """True if the the supplied interval contains any value in common with this one.

        Parameters
        ----------
        interval : :class:`.Interval`
            Interval object with the same point type.

        Returns
        -------
        :obj:`bool`
        """

        return hl.eval(
            hl.literal(self,
                       hl.tinterval(self._point_type)).overlaps(interval))
Example #10
 def values(self):
     values = [
         (hl.tbool, True),
         (hl.tint32, 0),
         (hl.tint64, 0),
         (hl.tfloat32, 0.5),
         (hl.tfloat64, 0.5),
         (hl.tstr, "foo"),
         (hl.tstruct(x=hl.tint32), hl.Struct(x=0)),
         (hl.tarray(hl.tint32), [0, 1, 4]),
         (hl.tset(hl.tint32), {0, 1, 4}),
         (hl.tdict(hl.tstr, hl.tint32), {"a": 0, "b": 1, "c": 4}),
         (hl.tinterval(hl.tint32), hl.Interval(0, 1, True, False)),
         (hl.tlocus(hl.default_reference()), hl.Locus("1", 1)),
         (hl.tcall, hl.Call([0, 1]))
     ]
     return values
Example #11
    def _object_hook(obj):
        if 'name' not in obj:
            return obj
        name = obj['name']
        if name == VariantDatasetCombiner.__name__:
            del obj['name']
            obj['vdses'] = [VDSMetadata(*x) for x in obj['vdses']]

            rg = hl.get_reference(obj['reference_genome'])
            obj['reference_genome'] = rg
            intervals_type = hl.tarray(hl.tinterval(hl.tlocus(rg)))
            intervals = intervals_type._convert_from_json(
                obj['gvcf_import_intervals'])
            obj['gvcf_import_intervals'] = intervals

            return VariantDatasetCombiner(**obj)
        return obj
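
The hook above is meant for json.load; a sketch of reading back a saved combiner plan (the file path is a placeholder for output written via to_dict and json.dump):

import json

with open('combiner_plan.json') as f:
    combiner = json.load(f, object_hook=_object_hook)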
Example #12
def main():
    parser = argparse.ArgumentParser(
        description="Driver for hail's gVCF combiner")
    parser.add_argument(
        '--sample-map',
        help='path to the sample map (must be filesystem local). '
        'The sample map should be tab separated with two columns. '
        'The first column is the sample ID, and the second column '
        'is the gVCF path.\n'
        'WARNING: the sample names in the gVCFs will be overwritten',
        required=True)
    parser.add_argument(
        '--tmp-path',
        help='path to folder for temp output (can be a cloud bucket)',
        default='/tmp')
    parser.add_argument('--out-file',
                        '-o',
                        help='path to final combiner output',
                        required=True)
    parser.add_argument('--json',
                        help='json to use for the import of the gVCFs '
                        '(must be filesystem local)',
                        required=True)
    parser.add_argument('--header',
                        help='external header, must be cloud based',
                        required=False)
    args = parser.parse_args()
    hl.init(default_reference=DEFAULT_REF,
            log='/hail-joint-caller-' + time.strftime('%Y%m%d-%H%M') + '.log')
    with open(args.json) as j:
        ty = hl.tarray(
            hl.tinterval(
                hl.tstruct(locus=hl.tlocus(reference_genome='GRCh38'))))
        intervals = ty._from_json(j.read())
    with open(args.sample_map) as m:
        samples = [l.strip().split('\t') for l in m]
    run_combiner(samples,
                 intervals,
                 args.out_file,
                 args.tmp_path,
                 args.header,
                 overwrite=True)
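
A hedged sketch of producing the JSON file that the --json flag expects, using the same private conversion helpers seen throughout these examples:

import json
import hail as hl

ty = hl.tarray(hl.tinterval(hl.tstruct(locus=hl.tlocus('GRCh38'))))
ivs = [hl.Interval(hl.Struct(locus=hl.Locus('chr20', 1, 'GRCh38')),
                   hl.Struct(locus=hl.Locus('chr20', 100000, 'GRCh38')),
                   includes_start=True, includes_end=False)]
with open('intervals.json', 'w') as f:  # placeholder path for --json
    json.dump(ty._convert_to_json(ivs), f)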
Example #13
    def __init__(self, path, intervals, filter_intervals):
        if intervals is not None:
            t = hl.expr.impute_type(intervals)
            if not isinstance(t, hl.tarray) or not isinstance(t.element_type, hl.tinterval):
                raise TypeError("'intervals' must be an array of tintervals")
            pt = t.element_type.point_type
            if isinstance(pt, hl.tstruct):
                self._interval_type = t
            else:
                self._interval_type = hl.tarray(hl.tinterval(hl.tstruct(__point=pt)))

        self.path = path
        self.filter_intervals = filter_intervals
        if intervals is not None and t != self._interval_type:
            self.intervals = [hl.Interval(hl.Struct(__point=i.start),
                                          hl.Struct(__point=i.end),
                                          i.includes_start,
                                          i.includes_end) for i in intervals]
        else:
            self.intervals = intervals
Example #14
 def to_dict(self) -> dict:
     intervals_typ = hl.tarray(
         hl.tinterval(hl.tlocus(self.reference_genome)))
     return {
         'name': self.__class__.__name__,
         'save_path': self.save_path,
         'output_path': self.output_path,
         'temp_path': self.temp_path,
         'reference_genome': str(self.reference_genome),
         'branch_factor': self.branch_factor,
         'target_records': self.target_records,
         'gvcf_batch_size': self.gvcf_batch_size,
         'gvcf_external_header': self.gvcf_external_header,  # put this here for humans
         'contig_recoding': self.contig_recoding,
         'gvcf_info_to_keep':
             None if self.gvcf_info_to_keep is None else list(self.gvcf_info_to_keep),
         'gvcf_reference_entry_fields_to_keep':
             None if self.gvcf_reference_entry_fields_to_keep is None
             else list(self.gvcf_reference_entry_fields_to_keep),
         'vdses': [md for i in sorted(self.vdses, reverse=True)
                   for md in self.vdses[i]],
         'gvcfs': self.gvcfs,
         'gvcf_sample_names': self.gvcf_sample_names,
         'gvcf_import_intervals': intervals_typ._convert_to_json(self.gvcf_import_intervals),
     }
Example #15
def main():
    parser = argparse.ArgumentParser(
        description="Driver for hail's gVCF combiner")
    parser.add_argument(
        '--sample-map',
        help='path to the sample map (must be filesystem local)',
        required=True)
    parser.add_argument('--sample-file',
                        help='path to a file containing a line-separated list '
                        'of samples to combine (must be filesystem local)')
    parser.add_argument(
        '--tmp-path',
        help='path to folder for temp output (can be a cloud bucket)',
        default='/tmp')
    parser.add_argument('--out-file',
                        '-o',
                        help='path to final combiner output',
                        required=True)
    parser.add_argument(
        '--summarize',
        help='if defined, run summarize, placing the rows table '
        'of the output at the argument value')
    parser.add_argument('--json',
                        help='json to use for the import of the gVCFs '
                        '(must be filesystem local)',
                        required=True)
    args = parser.parse_args()
    samples = build_sample_list(args.sample_map, args.sample_file)
    with open(args.json) as j:
        ty = hl.tarray(
            hl.tinterval(
                hl.tstruct(locus=hl.tlocus(reference_genome='GRCh38'))))
        intervals = ty._from_json(j.read())
    hl.init(default_reference=DEFAULT_REF,
            log='/hail-joint-caller-' + time.strftime('%Y%m%d-%H%M') + '.log')
    run_combiner(samples,
                 intervals,
                 args.out_file,
                 args.tmp_path,
                 summary_path=args.summarize,
                 overwrite=True)
Example #16
    def test_segment_intervals(self):
        intervals = hl.Table.parallelize(
            [
                hl.struct(interval=hl.interval(0, 10)),
                hl.struct(interval=hl.interval(20, 50)),
                hl.struct(interval=hl.interval(52, 52))
            ],
            schema=hl.tstruct(interval=hl.tinterval(hl.tint32)),
            key='interval')

        points1 = [-1, 5, 30, 40, 52, 53]

        segmented1 = hl.segment_intervals(intervals, points1)

        assert segmented1.aggregate(
            hl.agg.collect(segmented1.interval) == [
                hl.interval(0, 5),
                hl.interval(5, 10),
                hl.interval(20, 30),
                hl.interval(30, 40),
                hl.interval(40, 50),
                hl.interval(52, 52)
            ])
Example #17
    def contains(self, value):
        """True if `value` is contained within the interval.

        Examples
        --------

        >>> interval2.contains(5)
        True

        >>> interval2.contains(6)
        False

        Parameters
        ----------
        value :
            Object with type :meth:`.point_type`.

        Returns
        -------
        :obj:`bool`
        """

        return hl.eval(hl.literal(self, hl.tinterval(self._point_type)).contains(value))
Example #18
    def contains(self, value):
        """True if `value` is contained within the interval.

        Examples
        --------

        >>> interval2.contains(5)
        True

        >>> interval2.contains(6)
        False

        Parameters
        ----------
        value :
            Object with type :meth:`.point_type`.

        Returns
        -------
        :obj:`bool`
        """

        return hl.eval(
            hl.literal(self, hl.tinterval(self._point_type)).contains(value))
Example #19
 def visit_interval(self, node, visited_children):
     _tinterval, _, _langle, point_t, _rangle = visited_children
     return hl.tinterval(point_t)
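
The visitor above backs Hail's type-string parser; its public entry point is hl.dtype:

import hail as hl

t = hl.dtype('interval<locus<GRCh38>>')
assert t == hl.tinterval(hl.tlocus('GRCh38'))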
Example #20
def segment_reference_blocks(ref: 'MatrixTable',
                             intervals: 'Table') -> 'MatrixTable':
    """Returns a matrix table of reference blocks segmented according to intervals.

    Loci outside the given intervals are discarded. Reference blocks that start before
    but span an interval will appear at the interval start locus.

    Note
    ----
        Assumes disjoint intervals which do not span contigs.

        Requires start-inclusive intervals.

    Parameters
    ----------
    ref : :class:`.MatrixTable`
        MatrixTable of reference blocks.
    intervals : :class:`.Table`
        Table of intervals at which to segment reference blocks.

    Returns
    -------
    :class:`.MatrixTable`
    """
    interval_field = list(intervals.key)[0]
    if intervals[interval_field].dtype != hl.tinterval(ref.locus.dtype):
        raise ValueError(
            f"expect intervals to be keyed by intervals of loci matching the VariantDataset:"
            f" found {intervals[interval_field].dtype} / {ref.locus.dtype}")
    intervals = intervals.select(_interval_dup=intervals[interval_field])

    if not intervals.aggregate(
            hl.agg.all(intervals[interval_field].includes_start
                       & (intervals[interval_field].start.contig ==
                          intervals[interval_field].end.contig))):
        raise ValueError("expect intervals to be start-inclusive")

    starts = intervals.key_by(_start_locus=intervals[interval_field].start)
    starts = starts.annotate(_include_locus=True)
    refl = ref.localize_entries('_ref_entries', '_ref_cols')
    joined = refl.join(starts, how='outer')
    rg = ref.locus.dtype.reference_genome
    contigs = rg.contigs
    contig_idx_map = hl.literal({contigs[i]: i
                                 for i in range(len(contigs))},
                                'dict<str, int32>')
    joined = joined.annotate(__contig_idx=contig_idx_map[joined.locus.contig])
    joined = joined.annotate(_ref_entries=joined._ref_entries.map(
        lambda e: e.annotate(__contig_idx=joined.__contig_idx)))
    dense = joined.annotate(dense_ref=hl.or_missing(
        joined._include_locus,
        hl.rbind(
            joined.locus.position, lambda pos: hl.enumerate(
                hl.scan._densify(hl.len(joined._ref_cols), joined._ref_entries)
            ).map(lambda idx_and_e: hl.rbind(
                idx_and_e[0], idx_and_e[1], lambda idx, e: hl.coalesce(
                    joined._ref_entries[idx],
                    hl.or_missing((e.__contig_idx == joined.__contig_idx) &
                                  (e.END >= pos), e))).drop('__contig_idx')))))
    dense = dense.filter(dense._include_locus).drop('_interval_dup',
                                                    '_include_locus',
                                                    '__contig_idx')

    # at this point, 'dense' is a table with dense rows of reference blocks, keyed by locus

    refl_filtered = refl.annotate(
        **{interval_field: intervals[refl.locus]._interval_dup})

    # remove rows that are not contained in an interval, and rows that are the start of an
    # interval (interval starts come from the 'dense' table)
    refl_filtered = refl_filtered.filter(
        hl.is_defined(refl_filtered[interval_field])
        & (refl_filtered.locus != refl_filtered[interval_field].start))

    # union dense interval starts with filtered table
    refl_filtered = refl_filtered.union(
        dense.transmute(_ref_entries=dense.dense_ref))

    # rewrite reference blocks to end at the first of (interval end, reference block end)
    refl_filtered = refl_filtered.annotate(
        interval_end=refl_filtered[interval_field].end.position -
        ~refl_filtered[interval_field].includes_end)
    refl_filtered = refl_filtered.annotate(
        _ref_entries=refl_filtered._ref_entries.map(
            lambda entry: entry.annotate(END=hl.min(entry.END, refl_filtered.
                                                    interval_end))))

    return refl_filtered._unlocalize_entries('_ref_entries', '_ref_cols',
                                             list(ref.col_key))
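
A hedged usage sketch for the function above; vds is an assumed VariantDataset and the interval string is a placeholder:

import hail as hl

intervals = hl.Table.parallelize(
    [hl.struct(interval=hl.parse_locus_interval('chr20:1M-2M',
                                                reference_genome='GRCh38'))],
    schema=hl.tstruct(interval=hl.tinterval(hl.tlocus('GRCh38'))),
    key='interval')
segmented = segment_reference_blocks(vds.reference_data, intervals)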
Example #21
def filter_intervals(ds, intervals, keep=True) -> Union[Table, MatrixTable]:
    """Filter rows with a list of intervals.

    Examples
    --------

    Filter to loci falling within one interval:

    >>> ds_result = hl.filter_intervals(dataset, [hl.parse_locus_interval('17:38449840-38530994')])

    Remove all loci within list of intervals:

    >>> intervals = [hl.parse_locus_interval(x) for x in ['1:50M-75M', '2:START-400000', '3-22']]
    >>> ds_result = hl.filter_intervals(dataset, intervals, keep=False)

    Notes
    -----
    Based on the ``keep`` argument, this method will either restrict to points
    in the supplied interval ranges, or remove all rows in those ranges.

    When ``keep=True``, partitions that don't overlap any supplied interval
    will not be loaded at all.  This enables :func:`.filter_intervals` to be
    used for reasonably low-latency queries of small ranges of the dataset, even
    on large datasets.

    Parameters
    ----------
    ds : :class:`.MatrixTable` or :class:`.Table`
        Dataset to filter.
    intervals : :class:`.ArrayExpression` of type :py:data:`.tinterval`
        Intervals to filter on.  The point type of the interval must
        be a prefix of the key or equal to the first field of the key.
    keep : :obj:`bool`
        If ``True``, keep only rows that fall within any interval in `intervals`.
        If ``False``, keep only rows that fall outside all intervals in
        `intervals`.

    Returns
    -------
    :class:`.MatrixTable` or :class:`.Table`

    """

    if isinstance(ds, MatrixTable):
        k_type = ds.row_key.dtype
    else:
        assert isinstance(ds, Table)
        k_type = ds.key.dtype

    point_type = intervals.dtype.element_type.point_type

    def is_struct_prefix(partial, full):
        if list(partial) != list(full)[:len(partial)]:
            return False
        for k, v in partial.items():
            if full[k] != v:
                return False
        return True

    if point_type == k_type[0]:
        needs_wrapper = True
        point_type = hl.tstruct(foo=point_type)
    elif isinstance(point_type, tstruct) and is_struct_prefix(
            point_type, k_type):
        needs_wrapper = False
    else:
        raise TypeError(
            "The point type is incompatible with key type of the dataset ('{}', '{}')"
            .format(repr(point_type), repr(k_type)))

    def wrap_input(interval):
        if interval is None:
            raise TypeError(
                "'filter_intervals' does not allow missing values in 'intervals'."
            )
        elif needs_wrapper:
            return Interval(Struct(foo=interval.start),
                            Struct(foo=interval.end), interval.includes_start,
                            interval.includes_end)
        else:
            return interval

    intervals_type = intervals.dtype
    intervals = hl.eval(intervals)
    intervals = hl.tarray(hl.tinterval(point_type))._convert_to_json(
        [wrap_input(i) for i in intervals])

    if isinstance(ds, MatrixTable):
        config = {
            'name': 'MatrixFilterIntervals',
            'keyType': point_type._parsable_string(),
            'intervals': intervals,
            'keep': keep
        }
        return MatrixTable(MatrixToMatrixApply(ds._mir, config))
    else:
        config = {
            'name': 'TableFilterIntervals',
            'keyType': point_type._parsable_string(),
            'intervals': intervals,
            'keep': keep
        }
        return Table(TableToTableApply(ds._tir, config))
Example #22
import hail as hl

gvcfs = ['gs://hail-common/test-resources/HG00096.g.vcf.gz',
         'gs://hail-common/test-resources/HG00268.g.vcf.gz']
hl.init(default_reference='GRCh38')
parts_json = [
    {'start': {'locus': {'contig': 'chr20', 'position': 17821257}},
     'end': {'locus': {'contig': 'chr20', 'position': 18708366}},
     'includeStart': True,
     'includeEnd': True},
    {'start': {'locus': {'contig': 'chr20', 'position': 18708367}},
     'end': {'locus': {'contig': 'chr20', 'position': 19776611}},
     'includeStart': True,
     'includeEnd': True},
    {'start': {'locus': {'contig': 'chr20', 'position': 19776612}},
     'end': {'locus': {'contig': 'chr20', 'position': 21144633}},
     'includeStart': True,
     'includeEnd': True},
]

parts = hl.tarray(hl.tinterval(hl.tstruct(locus=hl.tlocus('GRCh38'))))._convert_from_json(parts_json)
for mt in hl.import_gvcfs(gvcfs, parts):
    mt._force_count_rows()
Example #23
def filter_intervals(ds, intervals, keep=True) -> Union[Table, MatrixTable]:
    """Filter rows with a list of intervals.

    Examples
    --------

    Filter to loci falling within one interval:

    >>> ds_result = hl.filter_intervals(dataset, [hl.parse_locus_interval('17:38449840-38530994')])

    Remove all loci within list of intervals:

    >>> intervals = [hl.parse_locus_interval(x) for x in ['1:50M-75M', '2:START-400000', '3-22']]
    >>> ds_result = hl.filter_intervals(dataset, intervals, keep=False)

    Notes
    -----
    Based on the ``keep`` argument, this method will either restrict to points
    in the supplied interval ranges, or remove all rows in those ranges.

    When ``keep=True``, partitions that don't overlap any supplied interval
    will not be loaded at all.  This enables :func:`.filter_intervals` to be
    used for reasonably low-latency queries of small ranges of the dataset, even
    on large datasets.

    Parameters
    ----------
    ds : :class:`.MatrixTable` or :class:`.Table`
        Dataset to filter.
    intervals : :class:`.ArrayExpression` of type :py:data:`.tinterval`
        Intervals to filter on.  The point type of the interval must
        be a prefix of the key or equal to the first field of the key.
    keep : :obj:`bool`
        If ``True``, keep only rows that fall within any interval in `intervals`.
        If ``False``, keep only rows that fall outside all intervals in
        `intervals`.

    Returns
    -------
    :class:`.MatrixTable` or :class:`.Table`

    """

    if isinstance(ds, MatrixTable):
        k_type = ds.row_key.dtype
    else:
        assert isinstance(ds, Table)
        k_type = ds.key.dtype

    point_type = intervals.dtype.element_type.point_type

    def is_struct_prefix(partial, full):
        if list(partial) != list(full)[:len(partial)]:
            return False
        for k, v in partial.items():
            if full[k] != v:
                return False
        return True

    if point_type == k_type[0]:
        needs_wrapper = True
        point_type = hl.tstruct(foo=point_type)
    elif isinstance(point_type, tstruct) and is_struct_prefix(point_type, k_type):
        needs_wrapper = False
    else:
        raise TypeError("The point type is incompatible with key type of the dataset ('{}', '{}')".format(repr(point_type), repr(k_type)))

    def wrap_input(interval):
        if interval is None:
            raise TypeError("'filter_intervals' does not allow missing values in 'intervals'.")
        elif needs_wrapper:
            return Interval(Struct(foo=interval.start),
                            Struct(foo=interval.end),
                            interval.includes_start,
                            interval.includes_end)
        else:
            return interval

    intervals_type = intervals.dtype
    intervals = hl.eval(intervals)
    intervals = hl.tarray(hl.tinterval(point_type))._convert_to_json([wrap_input(i) for i in intervals])

    if isinstance(ds, MatrixTable):
        config = {
            'name': 'MatrixFilterIntervals',
            'keyType': point_type._parsable_string(),
            'intervals': intervals,
            'keep': keep
        }
        return MatrixTable(MatrixToMatrixApply(ds._mir, config))
    else:
        config = {
            'name': 'TableFilterIntervals',
            'keyType': point_type._parsable_string(),
            'intervals': intervals,
            'keep': keep
        }
        return Table(TableToTableApply(ds._tir, config))