Example #1
 def test_agg_explode(self):
     t = hl.Table.parallelize([
         hl.struct(a=[1, 2]),
         hl.struct(a=hl.empty_array(hl.tint32)),
         hl.struct(a=hl.null(hl.tarray(hl.tint32))),
         hl.struct(a=[3]),
         hl.struct(a=[hl.null(hl.tint32)])
     ])
     self.assertCountEqual(t.aggregate(hl.agg.explode(lambda elt: hl.agg.collect(elt), t.a)),
                           [1, 2, None, 3])
Example #2
 def test_multi_way_zip_join_globals(self):
     t1 = hl.utils.range_table(1).annotate_globals(x=hl.null(hl.tint32))
     t2 = hl.utils.range_table(1).annotate_globals(x=5)
     t3 = hl.utils.range_table(1).annotate_globals(x=0)
     expected = hl.struct(__globals=hl.array([
         hl.struct(x=hl.null(hl.tint32)),
         hl.struct(x=5),
         hl.struct(x=0)]))
     joined = hl.Table._multi_way_zip_join([t1, t2, t3], '__data', '__globals')
     self.assertEqual(hl.eval(joined.globals), hl.eval(expected))
Example #3
    def test_reference_genome_liftover(self):
        grch37 = hl.get_reference('GRCh37')
        grch38 = hl.get_reference('GRCh38')

        self.assertTrue(not grch37.has_liftover('GRCh38') and not grch38.has_liftover('GRCh37'))
        grch37.add_liftover(resource('grch37_to_grch38_chr20.over.chain.gz'), 'GRCh38')
        grch38.add_liftover(resource('grch38_to_grch37_chr20.over.chain.gz'), 'GRCh37')
        self.assertTrue(grch37.has_liftover('GRCh38') and grch38.has_liftover('GRCh37'))

        ds = hl.import_vcf(resource('sample.vcf'))
        t = ds.annotate_rows(liftover=hl.liftover(hl.liftover(ds.locus, 'GRCh38'), 'GRCh37')).rows()
        self.assertTrue(t.all(t.locus == t.liftover))

        null_locus = hl.null(hl.tlocus('GRCh38'))

        rows = [
            {'l37': hl.locus('20', 1, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60000, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60001, 'GRCh37'), 'l38': hl.locus('chr20', 79360, 'GRCh38')},
            {'l37': hl.locus('20', 278686, 'GRCh37'), 'l38': hl.locus('chr20', 298045, 'GRCh38')},
            {'l37': hl.locus('20', 278687, 'GRCh37'), 'l38': hl.locus('chr20', 298046, 'GRCh38')},
            {'l37': hl.locus('20', 278688, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278689, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278690, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278691, 'GRCh37'), 'l38': hl.locus('chr20', 298047, 'GRCh38')},
            {'l37': hl.locus('20', 37007586, 'GRCh37'), 'l38': hl.locus('chr12', 32563117, 'GRCh38')},
            {'l37': hl.locus('20', 62965520, 'GRCh37'), 'l38': hl.locus('chr20', 64334167, 'GRCh38')},
            {'l37': hl.locus('20', 62965521, 'GRCh37'), 'l38': null_locus}
        ]
        schema = hl.tstruct(l37=hl.tlocus(grch37), l38=hl.tlocus(grch38))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.cond(hl.is_defined(t.l38),
                                      hl.liftover(t.l37, 'GRCh38') == t.l38,
                                      hl.is_missing(hl.liftover(t.l37, 'GRCh38')))))

        t = t.filter(hl.is_defined(t.l38))
        self.assertTrue(t.count() == 6)

        t = t.key_by('l38')
        t.count()
        self.assertTrue(list(t.key) == ['l38'])

        null_locus_interval = hl.null(hl.tinterval(hl.tlocus('GRCh38')))
        rows = [
            {'i37': hl.locus_interval('20', 1, 60000, True, False, 'GRCh37'), 'i38': null_locus_interval},
            {'i37': hl.locus_interval('20', 60001, 82456, True, True, 'GRCh37'),
             'i38': hl.locus_interval('chr20', 79360, 101815, True, True, 'GRCh38')}
        ]
        schema = hl.tstruct(i37=hl.tinterval(hl.tlocus(grch37)), i38=hl.tinterval(hl.tlocus(grch38)))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.liftover(t.i37, 'GRCh38') == t.i38))

        grch37.remove_liftover("GRCh38")
        grch38.remove_liftover("GRCh37")
Example #4
 def test_refs_with_process_joins(self):
     mt = hl.utils.range_matrix_table(10, 10)
     mt = mt.annotate_entries(
         a_literal=hl.literal(['a']),
         a_col_join=hl.is_defined(mt.cols()[mt.col_key]),
         a_row_join=hl.is_defined(mt.rows()[mt.row_key]),
         an_entry_join=hl.is_defined(mt[mt.row_key, mt.col_key]),
         the_global_failure=hl.cond(True, mt.globals, hl.null(mt.globals.dtype)),
         the_row_failure=hl.cond(True, mt.row, hl.null(mt.row.dtype)),
         the_col_failure=hl.cond(True, mt.col, hl.null(mt.col.dtype)),
         the_entry_failure=hl.cond(True, mt.entry, hl.null(mt.entry.dtype)),
     )
     mt.count()
Example #5
    def test_export_import_plink_same(self):
        mt = get_dataset()
        mt = mt.select_rows(rsid=hl.delimit([mt.locus.contig, hl.str(mt.locus.position), mt.alleles[0], mt.alleles[1]], ':'),
                            cm_position=15.0)
        mt = mt.select_cols(fam_id=hl.null(hl.tstr), pat_id=hl.null(hl.tstr), mat_id=hl.null(hl.tstr),
                            is_female=hl.null(hl.tbool), is_case=hl.null(hl.tbool))
        mt = mt.select_entries('GT')

        bfile = '/tmp/test_import_export_plink'
        hl.export_plink(mt, bfile, ind_id=mt.s, cm_position=mt.cm_position)

        mt_imported = hl.import_plink(bfile + '.bed', bfile + '.bim', bfile + '.fam',
                                      a2_reference=True, reference_genome='GRCh37')
        self.assertTrue(mt._same(mt_imported))
        self.assertTrue(mt.aggregate_rows(hl.agg.all(mt.cm_position == 15.0)))
Example #6
    def test_from_entry_expr_options(self):
        def build_mt(a):
            data = [{'v': 0, 's': 0, 'x': a[0]},
                    {'v': 0, 's': 1, 'x': a[1]},
                    {'v': 0, 's': 2, 'x': a[2]}]
            ht = hl.Table.parallelize(data, hl.dtype('struct{v: int32, s: int32, x: float64}'))
            mt = ht.to_matrix_table(['v'], ['s'])
            ids = mt.key_cols_by()['s'].collect()
            return mt.choose_cols([ids.index(0), ids.index(1), ids.index(2)])

        def check(expr, mean_impute, center, normalize, expected):
            actual = np.squeeze(BlockMatrix.from_entry_expr(expr,
                                                            mean_impute=mean_impute,
                                                            center=center,
                                                            normalize=normalize).to_numpy())
            assert np.allclose(actual, expected)

        a = np.array([0.0, 1.0, 2.0])

        mt = build_mt(a)
        check(mt.x, False, False, False, a)
        check(mt.x, False, True, False, a - 1.0)
        check(mt.x, False, False, True, a / np.sqrt(5))
        check(mt.x, False, True, True, (a - 1.0) / np.sqrt(2))
        check(mt.x + 1 - 1, False, False, False, a)

        mt = build_mt([0.0, hl.null('float64'), 2.0])
        check(mt.x, True, False, False, a)
        check(mt.x, True, True, False, a - 1.0)
        check(mt.x, True, False, True, a / np.sqrt(5))
        check(mt.x, True, True, True, (a - 1.0) / np.sqrt(2))
        with self.assertRaises(Exception):
            BlockMatrix.from_entry_expr(mt.x)
Example #7
    def test_annotate_intervals(self):
        ds = get_dataset()

        bed1 = hl.import_bed(resource('example1.bed'), reference_genome='GRCh37')
        bed2 = hl.import_bed(resource('example2.bed'), reference_genome='GRCh37')
        bed3 = hl.import_bed(resource('example3.bed'), reference_genome='GRCh37')
        self.assertTrue(list(bed2.key.dtype) == ['interval'])
        self.assertTrue(list(bed2.row.dtype) == ['interval', 'target'])

        interval_list1 = hl.import_locus_intervals(resource('exampleAnnotation1.interval_list'))
        interval_list2 = hl.import_locus_intervals(resource('exampleAnnotation2.interval_list'))
        self.assertTrue(list(interval_list2.key.dtype) == ['interval'])
        self.assertTrue(list(interval_list2.row.dtype) == ['interval', 'target'])

        ann = ds.annotate_rows(in_interval=bed1[ds.locus]).rows()
        self.assertTrue(ann.all((ann.locus.position <= 14000000) |
                                (ann.locus.position >= 17000000) |
                                (hl.is_missing(ann.in_interval))))

        for bed in [bed2, bed3]:
            ann = ds.annotate_rows(target=bed[ds.locus].target).rows()
            expr = (hl.case()
                    .when(ann.locus.position <= 14000000, ann.target == 'gene1')
                    .when(ann.locus.position >= 17000000, ann.target == 'gene2')
                    .default(ann.target == hl.null(hl.tstr)))
            self.assertTrue(ann.all(expr))

        self.assertTrue(ds.annotate_rows(in_interval=interval_list1[ds.locus]).rows()
                        ._same(ds.annotate_rows(in_interval=bed1[ds.locus]).rows()))

        self.assertTrue(ds.annotate_rows(target=interval_list2[ds.locus].target).rows()
                        ._same(ds.annotate_rows(target=bed2[ds.locus].target).rows()))
Example #8
    def test_aggregate_ir(self):
        ds = (hl.utils.range_matrix_table(5, 5)
              .annotate_globals(g1=5)
              .annotate_entries(e1=3))

        x = [("col_idx", lambda e: ds.aggregate_cols(e)),
             ("row_idx", lambda e: ds.aggregate_rows(e))]

        for name, f in x:
            r = f(hl.struct(x=agg.sum(ds[name]) + ds.g1,
                            y=agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1,
                            z=agg.sum(ds.g1 + ds[name]) + ds.g1,
                            mean=agg.mean(ds[name])))
            self.assertEqual(convert_struct_to_dict(r), {u'x': 15, u'y': 13, u'z': 40, u'mean': 2.0})

            r = f(5)
            self.assertEqual(r, 5)

            r = f(hl.null(hl.tint32))
            self.assertEqual(r, None)

            r = f(agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1)
            self.assertEqual(r, 13)

        r = ds.aggregate_entries(agg.filter((ds.row_idx % 2 != 0) & (ds.col_idx % 2 != 0),
                                            agg.sum(ds.e1 + ds.g1 + ds.row_idx + ds.col_idx)) + ds.g1)
        self.assertEqual(r, 48)
Example #9
def downsample(x, y, label=None, n_divisions=500) -> ArrayExpression:
    """Downsample (x, y) coordinate datapoints.

    Parameters
    ----------
    x : :class:`.NumericExpression`
        X-values to be downsampled.
    y : :class:`.NumericExpression`
        Y-values to be downsampled.
    label : :class:`.StringExpression` or :class:`.ArrayExpression`
        Additional data for each (x, y) coordinate. Can pass in multiple fields in an :class:`.ArrayExpression`.
    n_divisions : :obj:`int`
        Factor by which to downsample (default value = 500). A lower input results in fewer output datapoints.

    Returns
    -------
    :class:`.ArrayExpression`
        Expression for downsampled coordinate points (x, y). The element type of the array is
        :py:data:`.ttuple` of :py:data:`.tfloat64`, :py:data:`.tfloat64`, and :py:data:`.tarray` of :py:data:`.tstr`.
    """
    if label is None:
        label = hl.null(hl.tarray(hl.tstr))
    elif isinstance(label, StringExpression):
        label = hl.array([label])
    return _agg_func('downsample', [x, y, label], tarray(ttuple(tfloat64, tfloat64, tarray(tstr))),
                     constructor_args=[n_divisions])
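A minimal usage sketch (assumes a live Hail 0.2 session; the table and values are illustrative, not from the source):

import hail as hl

ht = hl.utils.range_table(100)
pts = ht.aggregate(hl.agg.downsample(hl.float64(ht.idx),
                                     hl.float64(ht.idx) ** 2,
                                     n_divisions=10))
# pts is a Python list of (x, y, label) tuples; label is missing here since none was supplied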
Example #10
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        filters=hl.set(hl.flatten(ts.data.map(lambda d: hl.array(d.filters)))),
        info=hl.struct(
            DP=hl.sum(ts.data.map(lambda d: d.info.DP)),
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                          .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(lambda e: renumber_entry(e, old_to_new)),
                            hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                lambda j: combined_allele_index[tmp.data[i].alleles[j]])))),
            hl.dict(hl.range(0, hl.len(tmp.alleles)).map(
                lambda j: hl.tuple([tmp.alleles[j], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))

    return tmp.drop('data', 'g')
Example #11
    def test_aggregate2(self):
        schema = hl.tstruct(status=hl.tint32, GT=hl.tcall, qPheno=hl.tint32)

        rows = [{'status': 0, 'GT': hl.Call([0, 0]), 'qPheno': 3},
                {'status': 0, 'GT': hl.Call([0, 1]), 'qPheno': 13}]

        kt = hl.Table.parallelize(rows, schema)

        result = convert_struct_to_dict(
            kt.group_by(status=kt.status)
                .aggregate(
                x1=agg.collect(kt.qPheno * 2),
                x2=agg.explode(lambda elt: agg.collect(elt), [kt.qPheno, kt.qPheno + 1]),
                x3=agg.min(kt.qPheno),
                x4=agg.max(kt.qPheno),
                x5=agg.sum(kt.qPheno),
                x6=agg.product(hl.int64(kt.qPheno)),
                x7=agg.count(),
                x8=agg.count_where(kt.qPheno == 3),
                x9=agg.fraction(kt.qPheno == 1),
                x10=agg.stats(hl.float64(kt.qPheno)),
                x11=agg.hardy_weinberg_test(kt.GT),
                x13=agg.inbreeding(kt.GT, 0.1),
                x14=agg.call_stats(kt.GT, ["A", "T"]),
                x15=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')))[0],
                x16=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')).c.banana)[0],
                x17=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tarray(hl.tint32))),
                x18=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tset(hl.tint32))),
                x19=agg.take(kt.GT, 1, ordering=-kt.qPheno)
            ).take(1)[0])

        expected = {u'status': 0,
                    u'x13': {u'n_called': 2, u'expected_homs': 1.64, u'f_stat': -1.777777777777777,
                             u'observed_homs': 1},
                    u'x14': {u'AC': [3, 1], u'AF': [0.75, 0.25], u'AN': 4, u'homozygote_count': [1, 0]},
                    u'x15': {u'a': 5, u'c': {u'banana': u'apple'}, u'b': u'foo'},
                    u'x10': {u'min': 3.0, u'max': 13.0, u'sum': 16.0, u'stdev': 5.0, u'n': 2, u'mean': 8.0},
                    u'x8': 1, u'x9': 0.0, u'x16': u'apple',
                    u'x11': {u'het_freq_hwe': 0.5, u'p_value': 0.5},
                    u'x2': [3, 4, 13, 14], u'x3': 3, u'x1': [6, 26], u'x6': 39, u'x7': 2, u'x4': 13, u'x5': 16,
                    u'x17': [],
                    u'x18': [],
                    u'x19': [hl.Call([0, 1])]}

        self.maxDiff = None

        self.assertDictEqual(result, expected)
Example #12
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], '')).fold(
                lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref: hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r: hl.array([ref]).extend(
                            al[1:].map(
                                lambda a: hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at: hl.cond(
                                        (_allele_ints['SNP'] == at)
                                        | (_allele_ints['Insertion'] == at)
                                        | (_allele_ints['Deletion'] == at)
                                        | (_allele_ints['MNP'] == at)
                                        | (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal: hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl: hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles: hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index: hl.range(0, hl.len(row.data)).flatmap(
                            lambda i: hl.cond(
                                hl.is_missing(row.data[i].__entries),
                                hl.range(0, hl.len(gbl.g[i].__cols)).map(
                                    lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                hl.bind(
                                    lambda old_to_new: row.data[i].__entries.map(
                                        lambda e: renumber_entry(e, old_to_new)),
                                    hl.range(0, hl.len(alleles.local[i])).map(
                                        lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(
        TableMapRows(
            ts._tir,
            Apply(merge_function._name, merge_function._ret_type,
                  TopLevelReference('row'), TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
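The renumbering step in isolation, as a hedged sketch (hypothetical values, assumes a running Hail session): old_to_new maps one dataset's local allele indices to indices into the merged allele list.

old_to_new = hl.literal({0: 0, 1: 2})  # local allele 1 became global allele 2
la = hl.literal([0, 1])                # a local LA field
print(hl.eval(la.map(lambda lak: old_to_new[lak])))  # [0, 2]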
Example #13
def unphase_mt(mt: hl.MatrixTable) -> hl.MatrixTable:
    """
    Generate an unphased version of a MatrixTable (assumes the call is in mt.GT and is diploid or haploid only)
    """
    return mt.annotate_entries(
        GT=hl.case()
        .when(mt.GT.is_diploid(), hl.call(mt.GT[0], mt.GT[1], phased=False))
        .when(mt.GT.is_haploid(), hl.call(mt.GT[0], phased=False))
        .default(hl.null(hl.tcall)))
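The case/default-to-missing pattern can be checked on a single call (a sketch, assuming a running Hail session):

phased = hl.call(0, 1, phased=True)
unphased = (hl.case()
            .when(phased.is_diploid(), hl.call(phased[0], phased[1], phased=False))
            .default(hl.null(hl.tcall)))
print(hl.eval(unphased))  # 0/1 -- phase removed; anything unhandled falls through to missing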
Example #14
def unify_saige_ht_variant_schema(ht):
    shared = ('markerID', 'AC', 'AF', 'N', 'BETA', 'SE', 'Tstat', 'varT',
              'varTstar')
    new_floats = ('AF.Cases', 'AF.Controls')
    new_ints = ('N.Cases', 'N.Controls')
    shared_end = ('Pvalue', 'gene', 'annotation')
    if 'AF.Cases' not in list(ht.row):
        ht = ht.select(*shared,
                       **{field: hl.null(hl.tfloat64) for field in new_floats},
                       **{field: hl.null(hl.tint32) for field in new_ints},
                       **{field: ht[field] for field in shared_end})
    else:
        ht = ht.select(*shared, *new_floats, *new_ints, *shared_end)
    return ht.annotate(SE=hl.float64(ht.SE), AC=hl.int32(ht.AC))
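The core trick, padding a table to a common schema with typed missing values, in miniature (field names are illustrative, assumes a running Hail session):

ht = hl.utils.range_table(3)
ht = ht.annotate(**{f: hl.null(hl.tfloat64) for f in ('AF.Cases', 'AF.Controls')})
ht.describe()  # both fields now exist with type float64; every value is missing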
Example #15
    def test_call_fields(self):
        expected = hl.Table.parallelize(
            [hl.struct(locus=hl.locus("X", 16050036), s="C1046::HG02024",
                       GT=hl.call(0, 0), GTA=hl.null(hl.tcall), GTZ=hl.call(0, 1)),
             hl.struct(locus=hl.locus("X", 16050036), s="C1046::HG02025",
                       GT=hl.call(1), GTA=hl.null(hl.tcall), GTZ=hl.call(0)),
             hl.struct(locus=hl.locus("X", 16061250), s="C1046::HG02024",
                       GT=hl.call(2, 2), GTA=hl.call(2, 1), GTZ=hl.call(1, 1)),
             hl.struct(locus=hl.locus("X", 16061250), s="C1046::HG02025",
                       GT=hl.call(2), GTA=hl.null(hl.tcall), GTZ=hl.call(1))],
            key=['locus', 's'])

        mt = hl.import_vcf(resource('generic.vcf'), call_fields=['GT', 'GTA', 'GTZ'])
        entries = mt.entries()
        entries = entries.key_by('locus', 's')
        entries = entries.select('GT', 'GTA', 'GTZ')
        self.assertTrue(entries._same(expected))
Example #16
def split_position_end(position):
    return hl.or_missing(
        hl.is_defined(position),
        hl.bind(
            lambda start: hl.cond(start == "?", hl.null(hl.tint), hl.int(start)),
            position.split("-")[-1]))
Example #17
def load_gene_data(directory: str,
                   pheno_key_dict,
                   gene_ht_map_path: str,
                   n_cases: int = -1,
                   n_controls: int = -1,
                   heritability: float = -1.0,
                   saige_version: str = 'NA',
                   inv_normalized: str = 'NA',
                   overwrite: bool = False):
    output_ht_path = f'{directory}/gene_results.ht'
    print(f'Loading: {directory}/*.gene.txt ...')
    types = {f'Nmarker_MACCate_{i}': hl.tint32 for i in range(1, 9)}
    types.update({
        x: hl.tfloat64
        for x in ('Pvalue', 'Pvalue_Burden', 'Pvalue_SKAT', 'Pvalue_skato_NA',
                  'Pvalue_burden_NA', 'Pvalue_skat_NA')
    })
    ht = hl.import_table(f'{directory}/*.gene.txt',
                         delimiter=' ',
                         impute=True,
                         types=types)
    if n_cases == -1: n_cases = hl.null(hl.tint)
    if n_controls == -1: n_controls = hl.null(hl.tint)
    if heritability == -1.0: heritability = hl.null(hl.tfloat)
    if saige_version == 'NA': saige_version = hl.null(hl.tstr)
    if inv_normalized == 'NA': inv_normalized = hl.null(hl.tstr)

    fields = ht.Gene.split('_')
    gene_ht = hl.read_table(gene_ht_map_path).select('interval').distinct()
    ht = ht.key_by(
        gene_id=fields[0],
        gene_symbol=fields[1],
        annotation=fields[2],
        **pheno_key_dict).drop('Gene').naive_coalesce(10).annotate_globals(
            n_cases=n_cases,
            n_controls=n_controls,
            heritability=heritability,
            saige_version=saige_version,
            inv_normalized=inv_normalized)
    ht = ht.annotate(total_variants=hl.sum(
        [v for k, v in list(ht.row_value.items()) if 'Nmarker' in k]),
                     interval=gene_ht.key_by('gene_id')[ht.gene_id].interval)
    ht = ht.checkpoint(output_ht_path,
                       overwrite=overwrite,
                       _read_if_exists=not overwrite).drop(
                           'n_cases', 'n_controls')
Example #18
        def make_entry_struct(e, alleles_len, has_non_ref, row):
            handled_fields = dict()
            handled_names = {
                'LA', 'gvcf_info', 'END', 'LAD', 'AD', 'LGT', 'GT', 'LPL',
                'PL', 'LPGT', 'PGT'
            }

            if 'END' not in row.info:
                raise hl.utils.FatalError(
                    "the Hail GVCF combiner expects GVCFs to have an 'END' field in INFO."
                )
            if 'GT' not in e:
                raise hl.utils.FatalError(
                    "the Hail GVCF combiner expects GVCFs to have a 'GT' field in FORMAT."
                )

            handled_fields['LA'] = hl.range(
                0, alleles_len - hl.cond(has_non_ref, 1, 0))
            handled_fields['LGT'] = get_lgt(e, alleles_len, has_non_ref, row)
            if 'AD' in e:
                handled_fields['LAD'] = hl.cond(has_non_ref, e.AD[:-1], e.AD)
            if 'PGT' in e:
                handled_fields['LPGT'] = e.PGT
            if 'PL' in e:
                handled_fields['LPL'] = hl.cond(
                    has_non_ref,
                    hl.cond(alleles_len > 2, e.PL[:-alleles_len],
                            hl.null(e.PL.dtype)),
                    hl.cond(alleles_len > 1, e.PL, hl.null(e.PL.dtype)))
                handled_fields['RGQ'] = hl.cond(
                    has_non_ref,
                    e.PL[hl.call(0,
                                 alleles_len - 1).unphased_diploid_gt_index()],
                    hl.null(e.PL.dtype.element_type))

            handled_fields['END'] = row.info.END
            handled_fields['gvcf_info'] = (hl.case().when(
                hl.is_missing(row.info.END),
                hl.struct(**(parse_as_fields(row.info.select(
                    *info_to_keep), has_non_ref)))).or_missing())

            pass_through_fields = {
                k: v
                for k, v in e.items() if k not in handled_names
            }
            return hl.struct(**handled_fields, **pass_through_fields)
Example #19
 def get_lgt(e, n_alleles, has_non_ref, row):
     index = e.GT.unphased_diploid_gt_index()
     n_no_nonref = n_alleles - hl.int(has_non_ref)
     triangle_without_nonref = hl.triangle(n_no_nonref)
     return (hl.case()
             .when(index < triangle_without_nonref, e.GT)
             .when(index < hl.triangle(n_alleles), hl.null('call'))
             .or_error('invalid GT ' + hl.str(e.GT) + ' at site ' + hl.str(row.locus)))
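get_lgt relies on triangle numbers: n alleles admit n * (n + 1) / 2 unphased diploid genotypes, so any genotype index at or past hl.triangle(n_no_nonref) must involve the trailing <NON_REF> allele. A quick check (assumes a running Hail session):

print(hl.eval(hl.triangle(3)))  # 6 -- the genotypes 0/0, 0/1, 1/1, 0/2, 1/2, 2/2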
Example #20
def make_pheno_manifest(export=True):
    mt0 = load_final_sumstats_mt(filter_sumstats=False,
                                 filter_variants=False,
                                 separate_columns_by_pop=False,
                                 annotate_with_nearest_gene=False)

    ht = mt0.cols()
    annotate_dict = {}

    annotate_dict.update({
        'pops': hl.delimit(ht.pheno_data.pop),
        'num_pops': hl.len(ht.pheno_data.pop)
    })

    for field in ['n_cases', 'n_controls', 'heritability', 'lambda_gc']:
        for pop in ['AFR', 'AMR', 'CSA', 'EAS', 'EUR', 'MID']:
            new_field = field if field != 'heritability' else 'saige_heritability'  # new field name (only applicable to saige heritability)
            idx = ht.pheno_data.pop.index(pop)
            field_expr = ht.pheno_data[field]
            annotate_dict.update({
                f'{new_field}_{pop}':
                hl.if_else(hl.is_nan(idx), hl.null(field_expr[0].dtype),
                           field_expr[idx])
            })
    annotate_dict.update({'filename': get_pheno_id(tb=ht) + '.tsv.bgz'})
    ht = ht.annotate(**annotate_dict)

    dropbox_manifest = hl.import_table(
        f'{ldprune_dir}/UKBB_Pan_Populations-Manifest_20200615-manifest_info.tsv',
        impute=True,
        key='File')
    dropbox_manifest = dropbox_manifest.filter(
        dropbox_manifest['is_old_file'] != '1')
    bgz = dropbox_manifest.filter(~dropbox_manifest.File.contains('.tbi'))
    bgz = bgz.rename({'File': 'filename'})
    tbi = dropbox_manifest.filter(dropbox_manifest.File.contains('.tbi'))
    tbi = tbi.annotate(
        filename=tbi.File.replace('.tbi', '')).key_by('filename')

    dropbox_annotate_dict = {}

    rename_dict = {
        'dbox link': 'dropbox_link',
        'size (bytes)': 'size_in_bytes'
    }

    dropbox_annotate_dict.update({'filename_tabix': tbi[ht.filename].File})
    for field in ['dbox link', 'wget', 'size (bytes)', 'md5 hex']:
        for tb, suffix in [(bgz, ''), (tbi, '_tabix')]:
            dropbox_annotate_dict.update({
                (rename_dict[field] if field in rename_dict else field.replace(
                     ' ', '_')) + suffix:
                tb[ht.filename][field]
            })
    ht = ht.annotate(**dropbox_annotate_dict)
    ht = ht.drop('pheno_data')
    ht.describe()
    ht.show()
Example #21
def load_variant_data(directory: str,
                      pheno_key_dict,
                      ukb_vep_path: str,
                      extension: str = 'single.txt',
                      n_cases: int = -1,
                      n_controls: int = -1,
                      heritability: float = -1.0,
                      saige_version: str = 'NA',
                      inv_normalized: str = 'NA',
                      overwrite: bool = False,
                      legacy_annotations: bool = False,
                      num_partitions: int = 1000):
    output_ht_path = f'{directory}/variant_results.ht'
    ht = hl.import_table(f'{directory}/*.{extension}',
                         delimiter=' ',
                         impute=True)
    print(f'Loading: {directory}/*.{extension} ...')
    marker_id_col = 'markerID' if extension == 'single.txt' else 'SNPID'
    locus_alleles = ht[marker_id_col].split('_')
    if n_cases == -1: n_cases = hl.null(hl.tint)
    if n_controls == -1: n_controls = hl.null(hl.tint)
    if heritability == -1.0: heritability = hl.null(hl.tfloat)
    if saige_version == 'NA': saige_version = hl.null(hl.tstr)
    if inv_normalized == 'NA': inv_normalized = hl.null(hl.tstr)

    ht = ht.key_by(locus=hl.parse_locus(locus_alleles[0]),
                   alleles=locus_alleles[1].split('/'),
                   **pheno_key_dict).distinct().naive_coalesce(num_partitions)
    if marker_id_col == 'SNPID':
        ht = ht.drop('CHR', 'POS', 'rsid', 'Allele1', 'Allele2')
    ht = ht.transmute(Pvalue=ht['p.value']).annotate_globals(
        n_cases=n_cases,
        n_controls=n_controls,
        heritability=heritability,
        saige_version=saige_version,
        inv_normalized=inv_normalized)
    ht = ht.drop('varT', 'varTstar', 'N', 'Tstat')
    ht = ht.annotate(**get_vep_formatted_data(
        ukb_vep_path, legacy_annotations=legacy_annotations)[hl.struct(
            locus=ht.locus, alleles=ht.alleles
        )])  # TODO: fix this for variants that overlap multiple genes
    ht = ht.checkpoint(output_ht_path,
                       overwrite=overwrite,
                       _read_if_exists=not overwrite).drop(
                           'n_cases', 'n_controls', 'heritability')
Example #22
def load_prescription_data(prescription_data_tsv_path: str, prescription_mapping_tsv_path):
    ht = hl.import_table(prescription_data_tsv_path, types={'eid': hl.tint, 'data_provider': hl.tint}, key='eid')
    mapping_ht = hl.import_table(prescription_mapping_tsv_path, impute=True, key='Original_Prescription')
    ht = ht.annotate(issue_date=hl.cond(hl.len(ht.issue_date) == 0, hl.null(hl.tint64),
                                        hl.experimental.strptime(ht.issue_date + ' 00:00:00', '%d/%m/%Y %H:%M:%S', 'GMT')),
                     **mapping_ht[ht.drug_name])
    ht = ht.filter(ht.Generic_Name != '').key_by('eid', 'Generic_Name', 'Drug_Category_and_Indication').collect_by_key()
    ht = ht.annotate(values=hl.sorted(ht.values, key=lambda x: x.issue_date))
    return ht.to_matrix_table(row_key=['eid'], col_key=['Generic_Name'], col_fields=['Drug_Category_and_Indication'])
Example #23
    def test_explode_cols(self):
        mt = hl.utils.range_matrix_table(4, 4)
        mt = mt.annotate_entries(e=mt.row_idx * 10 + mt.col_idx)

        self.assertTrue(mt.annotate_cols(x=[1]).explode_cols('x').drop('x')._same(mt))

        self.assertEqual(mt.annotate_cols(x=hl.empty_array('int')).explode_cols('x').count_cols(), 0)
        self.assertEqual(mt.annotate_cols(x=hl.null('array<int>')).explode_cols('x').count_cols(), 0)
        self.assertEqual(mt.annotate_cols(x=hl.range(0, mt.col_idx)).explode_cols('x').count_cols(), 6)
Example #25
def test_ndarray_eval():
    data_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    misshapen_data_list1 = [[4], [1, 2, 3]]
    misshapen_data_list2 = [[[1], [2, 3]]]
    misshapen_data_list3 = [[4], [1, 2, 3], 5]

    nd_expr = hl.nd.array(data_list)
    evaled = hl.eval(nd_expr)
    np_equiv = np.array(data_list, dtype=np.int32)
    np_equiv_fortran_style = np.asfortranarray(np_equiv)
    np_equiv_extra_dimension = np_equiv.reshape((3, 1, 3))
    assert (np.array_equal(evaled, np_equiv))
    assert (evaled.strides == np_equiv.strides)

    assert hl.eval(hl.nd.array([[], []])).strides == (8, 8)
    assert np.array_equal(hl.eval(hl.nd.array([])), np.array([]))

    zero_array = np.zeros((10, 10), dtype=np.int64)
    evaled_zero_array = hl.eval(hl.literal(zero_array))

    assert np.array_equal(evaled_zero_array, zero_array)
    assert zero_array.dtype == evaled_zero_array.dtype

    # Testing correct interpretation of numpy strides
    assert np.array_equal(hl.eval(hl.literal(np_equiv_fortran_style)),
                          np_equiv_fortran_style)
    assert np.array_equal(hl.eval(hl.literal(np_equiv_extra_dimension)),
                          np_equiv_extra_dimension)

    # Testing from hail arrays
    assert np.array_equal(hl.eval(hl.nd.array(hl.range(6))), np.arange(6))
    assert np.array_equal(hl.eval(hl.nd.array(hl.int64(4))), np.array(4))

    # Testing from nested hail arrays
    assert np.array_equal(
        hl.eval(hl.nd.array(hl.array([hl.array(x) for x in data_list]))),
        np.arange(9).reshape((3, 3)) + 1)

    # Testing missing data
    assert hl.eval(hl.nd.array(hl.null(hl.tarray(hl.tint32)))) is None

    with pytest.raises(ValueError) as exc:
        hl.nd.array(misshapen_data_list1)
    assert "inner dimensions do not match" in str(exc.value)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.nd.array(hl.array(misshapen_data_list1)))
    assert "inner dimensions do not match" in str(exc.value)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.nd.array(hl.array(misshapen_data_list2)))
    assert "inner dimensions do not match" in str(exc.value)

    with pytest.raises(ValueError) as exc:
        hl.nd.array(misshapen_data_list3)
    assert "inner dimensions do not match" in str(exc.value)
Example #26
def get_expr_for_worst_transcript_consequence_annotations_struct(
        vep_sorted_transcript_consequences_root,
        include_coding_annotations=True):
    """Retrieves the top-ranked transcript annotation based on the ranking computed by
    get_expr_for_vep_sorted_transcript_consequences_array(..)

    Args:
        vep_sorted_transcript_consequences_root (ArrayExpression):
        include_coding_annotations (bool):
    """

    transcript_consequences = {
        "biotype": hl.tstr,
        "canonical": hl.tint,
        "category": hl.tstr,
        "cdna_start": hl.tint,
        "cdna_end": hl.tint,
        "codons": hl.tstr,
        "gene_id": hl.tstr,
        "gene_symbol": hl.tstr,
        "hgvs": hl.tstr,
        "hgvsc": hl.tstr,
        "major_consequence": hl.tstr,
        "major_consequence_rank": hl.tint,
        "transcript_id": hl.tstr,
    }

    if include_coding_annotations:
        transcript_consequences.update({
            "amino_acids": hl.tstr,
            "domains": hl.tstr,
            "hgvsp": hl.tstr,
            "lof": hl.tstr,
            "lof_flags": hl.tstr,
            "lof_filter": hl.tstr,
            "lof_info": hl.tstr,
            "polyphen_prediction": hl.tstr,
            "protein_id": hl.tstr,
            "sift_prediction": hl.tstr,
        })

    return hl.cond(
        vep_sorted_transcript_consequences_root.size() == 0,
        hl.struct(
            **{
                field: hl.null(field_type)
                for field, field_type in transcript_consequences.items()
            }),
        hl.bind(
            lambda worst_transcript_consequence:
            (worst_transcript_consequence.annotate(domains=hl.delimit(
                hl.set(worst_transcript_consequence.domains))).select(
                    *transcript_consequences.keys())),
            vep_sorted_transcript_consequences_root[0],
        ),
    )
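The empty-array fallback above, reduced to a sketch: when there is nothing to select, return a struct of typed missing values so the schema stays constant either way (assumes a running Hail session).

arr = hl.empty_array(hl.tstruct(gene_id=hl.tstr))
first_or_missing = hl.cond(arr.size() == 0,
                           hl.struct(gene_id=hl.null(hl.tstr)),
                           arr[0])
print(hl.eval(first_or_missing))  # Struct(gene_id=None)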
Example #27
def test_ndarray_map():
    a = hl.nd.array([[2, 3, 4], [5, 6, 7]])
    b = hl.map(lambda x: -x, a)
    c = hl.map(lambda x: True, a)

    assert_ndarrays_eq((b, [[-2, -3, -4], [-5, -6, -7]]),
                       (c, [[True, True, True], [True, True, True]]))

    assert hl.eval(hl.null(hl.tndarray(hl.tfloat, 1)).map(lambda x: x * 2)) is None
Example #28
def create_frequency_bins_expr(
        AC: hl.expr.NumericExpression,
        AF: hl.expr.NumericExpression) -> hl.expr.StringExpression:
    """
    Create bins for frequencies in preparation for aggregating QUAL by frequency bin.

    Bins:
        - singleton
        - doubleton
        - 0.00005
        - 0.0001
        - 0.0002
        - 0.0005
        - 0.001
        - 0.002
        - 0.005
        - 0.01
        - 0.02
        - 0.05
        - 0.1
        - 0.2
        - 0.5
        - 1

    NOTE: Frequencies should be frequencies from raw data.
    Used when creating site quality distribution json files.

    :param AC: Field in input that contains the allele count information
    :param AF: Field in input that contains the allele frequency information
    :return: Expression containing bin name
    :rtype: hl.expr.StringExpression
    """
    bin_expr = (hl.case()
                .when(AC == 1, "binned_singleton")
                .when(AC == 2, "binned_doubleton")
                .when((AC > 2) & (AF < 0.00005), "binned_0.00005")
                .when((AF >= 0.00005) & (AF < 0.0001), "binned_0.0001")
                .when((AF >= 0.0001) & (AF < 0.0002), "binned_0.0002")
                .when((AF >= 0.0002) & (AF < 0.0005), "binned_0.0005")
                .when((AF >= 0.0005) & (AF < 0.001), "binned_0.001")
                .when((AF >= 0.001) & (AF < 0.002), "binned_0.002")
                .when((AF >= 0.002) & (AF < 0.005), "binned_0.005")
                .when((AF >= 0.005) & (AF < 0.01), "binned_0.01")
                .when((AF >= 0.01) & (AF < 0.02), "binned_0.02")
                .when((AF >= 0.02) & (AF < 0.05), "binned_0.05")
                .when((AF >= 0.05) & (AF < 0.1), "binned_0.1")
                .when((AF >= 0.1) & (AF < 0.2), "binned_0.2")
                .when((AF >= 0.2) & (AF < 0.5), "binned_0.5")
                .when((AF >= 0.5) & (AF <= 1), "binned_1")
                .default(hl.null(hl.tstr)))
    return bin_expr
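A hypothetical application of the binning expression (the table and the info.AC/info.AF field names are illustrative, not from the source):

# ht = ht.annotate(qual_bin=create_frequency_bins_expr(AC=ht.info.AC[0], AF=ht.info.AF[0]))
# ht.aggregate(hl.agg.counter(ht.qual_bin))  # e.g. {'binned_singleton': ..., 'binned_0.01': ...}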
Example #29
 def _genotype_fields(self):
     # Convert the mt genotype entries into num_alt, gq, ab, dp, and sample_id.
     is_called = hl.is_defined(self.mt.GT)
     return {
         'num_alt': hl.cond(is_called, self.mt.GT.n_alt_alleles(), -1),
         'gq': hl.cond(is_called, self.mt.GQ, hl.null(hl.tint)),
         'ab': hl.bind(
             lambda total: hl.cond(is_called & (total != 0) & (hl.len(self.mt.AD) > 1),
                                   hl.float(self.mt.AD[1] / total),
                                   hl.null(hl.tfloat)),
             hl.sum(self.mt.AD)),
         'dp': hl.cond(is_called, hl.int(hl.min(self.mt.DP, 32000)), hl.null(hl.tfloat)),
         'sample_id': self.mt.s
     }
Example #30
 def parse_first_occurrence(x):
     return (hl.case(missing_false=True)
         .when(hl.is_defined(hl.parse_float(x)), hl.float64(x))  # Source of the first code ...
         .when(hl.literal(pseudo_dates).contains(hl.str(x)), hl.null(hl.tfloat64))  # Setting past and future dates to missing
         .when(hl.str(x) == '1902-02-02', 0.0)  # Matches DOB
         .when(hl.str(x) == '1903-03-03',  # Within year of birth (taking midpoint between month of birth and EOY)
               (hl.experimental.strptime('1970-12-31 00:00:00', '%Y-%m-%d %H:%M:%S', 'GMT') -
                hl.experimental.strptime('1970-' + month + '-15 00:00:00', '%Y-%m-%d %H:%M:%S',
                                         'GMT')) / 2)
         .default(hl.experimental.strptime(hl.str(x) + ' 00:00:00', '%Y-%m-%d %H:%M:%S', 'GMT') - dob
     ))
Example #31
def test_ndarray_slice():
    np_rect_prism = np.arange(24).reshape((2, 3, 4))
    rect_prism = hl.nd.array(np_rect_prism)
    np_mat = np.arange(8).reshape((2, 4))
    mat = hl.nd.array(np_mat)
    np_flat = np.arange(20)
    flat = hl.nd.array(np_flat)

    assert_ndarrays_eq(
        (rect_prism[:, :, :], np_rect_prism[:, :, :]),
        (rect_prism[:, :, 1], np_rect_prism[:, :, 1]),
        (rect_prism[0:1, 1:3, 0:2], np_rect_prism[0:1, 1:3, 0:2]),
        (rect_prism[:, :, 1:4:2], np_rect_prism[:, :, 1:4:2]),
        (rect_prism[:, 2, 1:4:2], np_rect_prism[:, 2, 1:4:2]),
        (rect_prism[0, 2, 1:4:2], np_rect_prism[0, 2, 1:4:2]),
        (rect_prism[0, :, 1:4:2] + rect_prism[:, :1, 1:4:2], np_rect_prism[0, :, 1:4:2] + np_rect_prism[:, :1, 1:4:2]),
        (rect_prism[0:, :, 1:4:2] + rect_prism[:, :1, 1:4:2], np_rect_prism[0:, :, 1:4:2] + np_rect_prism[:, :1, 1:4:2]),
        (mat[0, 1:4:2] + mat[:, 1:4:2], np_mat[0, 1:4:2] + np_mat[:, 1:4:2]),
        (rect_prism[0, 0, -3:-1], np_rect_prism[0, 0, -3:-1]),
        (flat[15:5:-1], np_flat[15:5:-1]),
        (flat[::-1], np_flat[::-1]),
        (flat[::22], np_flat[::22]),
        (flat[::-22], np_flat[::-22]),
        (flat[15:5], np_flat[15:5]),
        (flat[3:12:-1], np_flat[3:12:-1]),
        (flat[12:3:1], np_flat[12:3:1]),
        (mat[::-1, :], np_mat[::-1, :]),
        (flat[4:1:-2], np_flat[4:1:-2]),
        (flat[0:0:1], np_flat[0:0:1]),
        (flat[-4:-1:2], np_flat[-4:-1:2])
    )

    assert hl.eval(flat[hl.null(hl.tint32):4:1]) is None
    assert hl.eval(flat[4:hl.null(hl.tint32)]) is None
    assert hl.eval(flat[4:10:hl.null(hl.tint32)]) is None
    assert hl.eval(rect_prism[:, :, 0:hl.null(hl.tint32):1]) is None
    assert hl.eval(rect_prism[hl.null(hl.tint32), :, :]) is None

    with pytest.raises(FatalError) as exc:
        hl.eval(flat[::0])
    assert "Slice step cannot be zero" in str(exc)
Example #32
def get_codings():
    """
    Read codings data from Duncan's repo and load into hail Table

    :return: Hail table with codings
    :rtype: Table
    """
    root = f'{tempfile.gettempdir()}/PHESANT'
    if subprocess.check_call(['git', 'clone', 'https://github.com/astheeggeggs/PHESANT.git', root]):
        raise Exception('Could not clone repo')
    hts = []
    coding_dir = f'{root}/WAS/codings'
    for coding_file in os.listdir(f'{coding_dir}'):
        hl.hadoop_copy(f'file://{coding_dir}/{coding_file}', f'{coding_dir}/{coding_file}')
        ht = hl.import_table(f'{coding_dir}/{coding_file}')
        if 'node_id' not in ht.row:
            ht = ht.annotate(node_id=hl.null(hl.tstr), parent_id=hl.null(hl.tstr), selectable=hl.null(hl.tstr))
        ht = ht.annotate(coding_id=hl.int(coding_file.split('.')[0].replace('coding', '')))
        hts.append(ht)
    full_ht = hts[0].union(*hts[1:]).key_by('coding_id', 'coding')
    return full_ht.repartition(10)
Example #33
def create_all_values():
    return hl.struct(
        f32=hl.float32(3.14),
        i64=hl.int64(-9),
        m=hl.null(hl.tfloat64),
        astruct=hl.struct(a=hl.null(hl.tint32), b=5.5),
        mstruct=hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr)),
        aset=hl.set(['foo', 'bar', 'baz']),
        mset=hl.null(hl.tset(hl.tfloat64)),
        d=hl.dict({hl.array(['a', 'b']): 0.5, hl.array(['x', hl.null(hl.tstr), 'z']): 0.3}),
        md=hl.null(hl.tdict(hl.tint32, hl.tstr)),
        h38=hl.locus('chr22', 33878978, 'GRCh38'),
        ml=hl.null(hl.tlocus('GRCh37')),
        i=hl.interval(
            hl.locus('1', 999),
            hl.locus('1', 1001)),
        c=hl.call(0, 1),
        mc=hl.null(hl.tcall),
        t=hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.null(hl.tstr)]),
        mt=hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool))
    )
Example #35
def finalize_annotated_table_for_seqr_variants(
        mt: hl.MatrixTable) -> hl.MatrixTable:
    """Given a messily-but-completely annotated Hail MatrixTable of variants,
    return a new MatrixTable with appropriate formatting to export to Elasticsearch
    and consume  by Seqr.

    TO-EXTREMELY-DO: Create a app/common Python 3 module with code for SeqrAnnotatedVariant,
    with methods to im/export to/from Hail/Elasticsearch.

    :param vep_mt: A VCF loaded into hail 0.2, VEP has been run,
    and reference/computed fields have been added.
    :type vep_mt: hl.MatrixTable
    :return: A hail matrix table of variants and VEP annotations with
    proper formatting to be consumed by Seqr.
    :rtype: hl.MatrixTable
    """
    mt = mt.annotate_rows(
        sortedTranscriptConsequences=
        get_expr_for_vep_sorted_transcript_consequences_array(vep_root=mt.vep))

    mt = mt.annotate_rows(
        mainTranscript=hl.cond(
            hl.len(mt.sortedTranscriptConsequences) > 0,
            mt.sortedTranscriptConsequences[0],
            hl.null(
                "struct {biotype: str,canonical: int32,cdna_start: int32,cdna_end: int32,codons: str,gene_id: str,gene_symbol: str,hgvsc: str,hgvsp: str,transcript_id: str,amino_acids: str,lof: str,lof_filter: str,lof_flags: str,lof_info: str,polyphen_prediction: str,protein_id: str,protein_start: int32,sift_prediction: str,consequence_terms: array<str>,domains: array<str>,major_consequence: str,category: str,hgvs: str,major_consequence_rank: int32,transcript_rank: int32}"
            )),
        #allele_id=clinvar_mt.index_rows(mt.row_key).vep.id,
        alt=get_expr_for_alt_allele(mt),
        chrom=get_expr_for_contig(mt.locus),
        #clinvar_clinical_significance=clinvar_mt.index_rows(mt.row_key).clinical_significance,
        domains=get_expr_for_vep_protein_domains_set(
            vep_transcript_consequences_root=mt.vep.transcript_consequences),
        geneIds=hl.set(
            mt.vep.transcript_consequences.map(lambda c: c.gene_id)),
        # gene_id_to_consequence_json=get_expr_for_vep_gene_id_to_consequence_map(
        #     vep_sorted_transcript_consequences_root=mt.sortedTranscriptConsequences,
        #     gene_ids=clinvar_mt.gene_ids
        # ),
        #gold_stars= clinvar_mt.index_entries(mt.row_key,mt.col_key).gold_stars,
        pos=get_expr_for_start_pos(mt),
        ref=get_expr_for_ref_allele(mt),
        #review_status=clinvar_mt.index_rows(mt.locus,mt.alleles).review_status,
        transcript_consequence_terms=get_expr_for_vep_consequence_terms_set(
            vep_transcript_consequences_root=mt.sortedTranscriptConsequences),
        transcript_ids=get_expr_for_vep_transcript_ids_set(
            vep_transcript_consequences_root=mt.sortedTranscriptConsequences),
        transcript_id_to_consequence_json=
        get_expr_for_vep_transcript_id_to_consequence_map(
            vep_transcript_consequences_root=mt.sortedTranscriptConsequences),
        variant_id=get_expr_for_variant_id(mt),
        xpos=get_expr_for_xpos(mt.locus))
    return mt
Example #36
    def test_trio_matrix_null_keys(self):
        ped = hl.Pedigree.read(resource('triomatrix.fam'))
        ht = hl.import_fam(resource('triomatrix.fam'))

        mt = hl.import_vcf(resource('triomatrix.vcf'))
        mt = mt.annotate_cols(fam=ht[mt.s].fam_id)

        # Make keys all null
        mt = mt.key_cols_by(s=hl.null(hl.tstr))

        tt = hl.trio_matrix(mt, ped, complete_trios=True)
        self.assertEqual(tt.count_cols(), 0)
Example #37
def mwzj_hts_by_tree(all_hts,
                     temp_dir,
                     globals_for_col_key,
                     debug=False,
                     inner_mode='overwrite',
                     repartition_final: int = None):
    chunk_size = int(len(all_hts)**0.5) + 1
    outer_hts = []

    checkpoint_kwargs = {inner_mode: True}
    if repartition_final is not None:
        intervals = get_n_even_intervals(repartition_final)
        checkpoint_kwargs['_intervals'] = intervals

    if debug: print(f'Running chunk size {chunk_size}...')
    for i in range(chunk_size):
        if i * chunk_size >= len(all_hts): break
        hts = all_hts[i * chunk_size:(i + 1) * chunk_size]
        if debug:
            print(
                f'Going from {i * chunk_size} to {(i + 1) * chunk_size} ({len(hts)} HTs)...'
            )
        try:
            if isinstance(hts[0], str):
                hts = list(map(lambda x: hl.read_table(x), hts))
            ht = hl.Table.multi_way_zip_join(hts, 'row_field_name',
                                             'global_field_name')
        except:
            if debug:
                print(
                    f'problem in range {i * chunk_size}-{i * chunk_size + chunk_size}'
                )
                _ = [ht.describe() for ht in hts]
            raise
        outer_hts.append(
            ht.checkpoint(f'{temp_dir}/temp_output_{i}.ht',
                          **checkpoint_kwargs))
    ht = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer',
                                     'global_field_name_outer')
    ht = ht.transmute(inner_row=hl.flatmap(
        lambda i: hl.cond(
            hl.is_missing(ht.row_field_name_outer[i].row_field_name),
            hl.range(0, hl.len(ht.global_field_name_outer[i].global_field_name)).map(
                lambda _: hl.null(ht.row_field_name_outer[i].row_field_name.dtype.element_type)),
            ht.row_field_name_outer[i].row_field_name),
        hl.range(hl.len(ht.global_field_name_outer))))
    ht = ht.transmute_globals(inner_global=hl.flatmap(
        lambda x: x.global_field_name, ht.global_field_name_outer))
    ht = ht.transmute_globals(inner_global=hl.flatmap(
        lambda x: x.global_field_name, ht.global_field_name_outer))
    mt = ht._unlocalize_entries('inner_row', 'inner_global',
                                globals_for_col_key)
    return mt
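The building block here is hl.Table.multi_way_zip_join; in miniature (a sketch, assuming a running Hail session and input tables with identical row and global types):

t1 = hl.utils.range_table(2).annotate(a='left')
t2 = hl.utils.range_table(2).annotate(a='right')
zipped = hl.Table.multi_way_zip_join([t1, t2], 'row_field_name', 'global_field_name')
zipped.show()  # each row carries an array 'row_field_name' with one entry per input table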
Example #38
    def test_nulls_in_distinct_joins(self):

        # MatrixAnnotateRowsTable uses left distinct join
        mr = hl.utils.range_matrix_table(7, 3, 4)
        matrix1 = mr.key_rows_by(new_key=hl.cond(
            (mr.row_idx == 3) | (mr.row_idx == 5), hl.null(hl.tint32), mr.row_idx))
        matrix2 = mr.key_rows_by(new_key=hl.cond(
            (mr.row_idx == 4) | (mr.row_idx == 6), hl.null(hl.tint32), mr.row_idx))

        joined = matrix1.select_rows(
            idx1=matrix1.row_idx, idx2=matrix2.rows()[matrix1.new_key].row_idx)

        def row(new_key, idx1, idx2):
            return hl.Struct(new_key=new_key, idx1=idx1, idx2=idx2)

        expected = [
            row(0, 0, 0),
            row(1, 1, 1),
            row(2, 2, 2),
            row(4, 4, None),
            row(6, 6, None),
            row(None, 3, None),
            row(None, 5, None)
        ]
        self.assertEqual(joined.rows().collect(), expected)

        # union_cols uses inner distinct join
        matrix1 = matrix1.annotate_entries(ridx=matrix1.row_idx,
                                           cidx=matrix1.col_idx)
        matrix2 = matrix2.annotate_entries(ridx=matrix2.row_idx,
                                           cidx=matrix2.col_idx)
        matrix2 = matrix2.key_cols_by(col_idx=matrix2.col_idx + 3)

        expected = hl.utils.range_matrix_table(3, 6, 1)
        expected = expected.key_rows_by(new_key=expected.row_idx)
        expected = expected.annotate_entries(ridx=expected.row_idx,
                                             cidx=expected.col_idx % 3)

        self.assertTrue(matrix1.union_cols(matrix2)._same(expected))
Example #39
    def test_null_joins_2(self):
        tr = hl.utils.range_table(7, 1)
        table1 = tr.key_by(new_key=hl.cond((tr.idx == 3) | (tr.idx == 5),
                                           hl.null(hl.tint32), tr.idx),
                           key2=tr.idx)
        table1 = table1.select(idx1=table1.idx)
        table2 = tr.key_by(new_key=hl.cond((tr.idx == 4) | (tr.idx == 6),
                                           hl.null(hl.tint32), tr.idx),
                           key2=tr.idx)
        table2 = table2.select(idx2=table2.idx)

        left_join = table1.join(table2, 'left')
        right_join = table1.join(table2, 'right')
        inner_join = table1.join(table2, 'inner')
        outer_join = table1.join(table2, 'outer')

        def row(new_key, key2, idx1, idx2):
            return hl.Struct(new_key=new_key, key2=key2, idx1=idx1, idx2=idx2)

        left_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2),
                              row(4, 4, 4, None), row(6, 6, 6, None),
                              row(None, 3, 3, None), row(None, 5, 5, None)]

        right_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2),
                               row(3, 3, None, 3), row(5, 5, None, 5),
                               row(None, 4, None, 4), row(None, 6, None, 6)]

        inner_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2)]

        outer_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2),
                               row(3, 3, None, 3), row(4, 4, 4, None),
                               row(5, 5, None, 5), row(6, 6, 6, None),
                               row(None, 3, 3, None), row(None, 4, None, 4),
                               row(None, 5, 5, None), row(None, 6, None, 6)]

        self.assertEqual(left_join.collect(), left_join_expected)
        self.assertEqual(right_join.collect(), right_join_expected)
        self.assertEqual(inner_join.collect(), inner_join_expected)
        self.assertEqual(outer_join.collect(), outer_join_expected)
Example #40
def unphase_call_expr(call_expr: hl.expr.CallExpression) -> hl.expr.CallExpression:
    """
    Generate unphased version of a call expression (which can be phased or not)

    :param call_expr: Input call expression
    :return: unphased call expression
    """
    return (
        hl.case()
        .when(call_expr.is_diploid(), hl.call(call_expr[0], call_expr[1], phased=False))
        .when(call_expr.is_haploid(), hl.call(call_expr[0], phased=False))
        .default(hl.null(hl.tcall))
    )
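A one-line sanity check (assumes a running Hail session):

print(hl.eval(unphase_call_expr(hl.call(0, 1, phased=True))))  # 0/1 -- phase removed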
Exemple #42
0
def add_default_plink_fields(mt):
    """Add fields to PLINK"""
    return mt\
        .annotate_rows(rsid=hl.null(hl.tstr))\
        .annotate_cols(
            fam_id=hl.null(hl.tstr), pat_id=hl.null(hl.tstr), mat_id=hl.null(hl.tstr),
            is_female=hl.null(hl.tbool), is_case=hl.null(hl.tbool)
        )
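A short sketch of applying the helper (the dataset is illustrative; the point is that rsid and the pedigree fields exist afterwards, all missing, so downstream code expecting PLINK-style fields runs unchanged):

import hail as hl

mt = hl.balding_nichols_model(1, 5, 10)
mt = add_default_plink_fields(mt)
print(mt.col.dtype)  # now includes fam_id, pat_id, mat_id, is_female, is_case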
Exemple #43
0
def make_pheno_manifest():
    mt0 = load_final_sumstats_mt(filter_sumstats=False,
                                 filter_variants=False,
                                 separate_columns_by_pop=False,
                                 annotate_with_nearest_gene=False)
    ht = mt0.cols()
    annotate_dict = {}

    annotate_dict.update({
        'pops': hl.delimit(ht.pheno_data.pop),
        'num_pops': hl.len(ht.pheno_data.pop)
    })

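    # One scalar column per (field, pop) pair below; the value is missing when
    # the phenotype has no data for that population.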
    for field in ['n_cases', 'n_controls', 'heritability', 'lambda_gc']:
        for pop in ['AFR', 'AMR', 'CSA', 'EAS', 'EUR', 'MID']:
            new_field = field if field != 'heritability' else 'saige_heritability'  # new field name (only applicable to saige heritability)
            idx = ht.pheno_data.pop.index(pop)
            field_expr = ht.pheno_data[field]
            annotate_dict.update({
                f'{new_field}_{pop}':
                hl.if_else(hl.is_missing(idx), hl.null(field_expr[0].dtype),
                           field_expr[idx])
            })
    annotate_dict.update({
        'filename':
        (ht.trait_type + '-' + ht.phenocode + '-' + ht.pheno_sex +
         hl.if_else(hl.len(ht.coding) > 0, '-' + ht.coding, '') +
         hl.if_else(hl.len(ht.modifier) > 0, '-' + ht.modifier, '')).replace(
             ' ', '_').replace('/', '_') + '.tsv.bgz'
    })
    ht = ht.annotate(**annotate_dict)
    aws_bucket = 'https://pan-ukb-us-east-1.s3.amazonaws.com/sumstats_release'
    ht = ht.annotate(aws_link=aws_bucket + '/' + ht.filename,
                     aws_link_tabix=aws_bucket + '_tabix/' + ht.filename +
                     '.tbi')

    other_fields_ht = hl.import_table(
        f'{ldprune_dir}/release/md5_hex_and_file_size.tsv.bgz',
        force_bgz=True,
        key=PHENO_KEY_FIELDS)
    other_fields = [
        'size_in_bytes', 'size_in_bytes_tabix', 'md5_hex', 'md5_hex_tabix'
    ]

    ht = ht.annotate(wget='wget ' + ht.aws_link,
                     wget_tabix='wget ' + ht.aws_link_tabix,
                     **{f: other_fields_ht[ht.key][f]
                        for f in other_fields})

    ht = ht.drop('pheno_data', 'pheno_indices')
    ht.export(f'{bucket}/combined_results/phenotype_manifest.tsv.bgz')
Exemple #45
0
def create_all_values_datasets():
    all_values = hl.struct(
        f32=hl.float32(3.14),
        i64=hl.int64(-9),
        m=hl.null(hl.tfloat64),
        astruct=hl.struct(a=hl.null(hl.tint32), b=5.5),
        mstruct=hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr)),
        aset=hl.set(['foo', 'bar', 'baz']),
        mset=hl.null(hl.tset(hl.tfloat64)),
        d=hl.dict({
            hl.array(['a', 'b']): 0.5,
            hl.array(['x', hl.null(hl.tstr), 'z']): 0.3
        }),
        md=hl.null(hl.tdict(hl.tint32, hl.tstr)),
        h38=hl.locus('chr22', 33878978, 'GRCh38'),
        ml=hl.null(hl.tlocus('GRCh37')),
        i=hl.interval(hl.locus('1', 999), hl.locus('1', 1001)),
        c=hl.call(0, 1),
        mc=hl.null(hl.tcall),
        t=hl.tuple([hl.call(1, 2, phased=True), 'foo',
                    hl.null(hl.tstr)]),
        mt=hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool)))

    def prefix(s, p):
        return hl.struct(**{p + k: s[k] for k in s})

    all_values_table = (hl.utils.range_table(
        5, n_partitions=3).annotate_globals(
            **prefix(all_values, 'global_')).annotate(**all_values).cache())

    all_values_matrix_table = (hl.utils.range_matrix_table(
        3, 2, n_partitions=2).annotate_globals(
            **prefix(all_values, 'global_')).annotate_rows(
                **prefix(all_values, 'row_')).annotate_cols(
                    **prefix(all_values, 'col_')).annotate_entries(
                        **prefix(all_values, 'entry_')).cache())

    return all_values_table, all_values_matrix_table
Exemple #47
0
    def test_nulls_in_distinct_joins(self):

        # MatrixAnnotateRowsTable uses left distinct join
        mr = hl.utils.range_matrix_table(7, 3, 4)
        matrix1 = mr.key_rows_by(new_key=hl.cond((mr.row_idx == 3) | (mr.row_idx == 5),
                                                 hl.null(hl.tint32), mr.row_idx))
        matrix2 = mr.key_rows_by(new_key=hl.cond((mr.row_idx == 4) | (mr.row_idx == 6),
                                                 hl.null(hl.tint32), mr.row_idx))

        joined = matrix1.select_rows(idx1=matrix1.row_idx,
                                     idx2=matrix2.rows()[matrix1.new_key].row_idx)

        def row(new_key, idx1, idx2):
            return hl.Struct(new_key=new_key, idx1=idx1, idx2=idx2)

        expected = [row(0, 0, 0),
                    row(1, 1, 1),
                    row(2, 2, 2),
                    row(4, 4, None),
                    row(6, 6, None),
                    row(None, 3, None),
                    row(None, 5, None)]
        self.assertEqual(joined.rows().collect(), expected)

        # union_cols uses inner distinct join
        matrix1 = matrix1.annotate_entries(ridx=matrix1.row_idx,
                                           cidx=matrix1.col_idx)
        matrix2 = matrix2.annotate_entries(ridx=matrix2.row_idx,
                                           cidx=matrix2.col_idx)
        matrix2 = matrix2.key_cols_by(col_idx=matrix2.col_idx + 3)

        expected = hl.utils.range_matrix_table(3, 6, 1)
        expected = expected.key_rows_by(new_key=expected.row_idx)
        expected = expected.annotate_entries(ridx=expected.row_idx,
                                             cidx=expected.col_idx % 3)

        self.assertTrue(matrix1.union_cols(matrix2)._same(expected))
Exemple #48
0
    def test_aggregate_ir(self):
        kt = hl.utils.range_table(10).annotate_globals(g1=5)
        r = kt.aggregate(hl.struct(x=agg.sum(kt.idx) + kt.g1,
                                   y=agg.filter(kt.idx % 2 != 0, agg.sum(kt.idx + 2)) + kt.g1,
                                   z=agg.sum(kt.g1 + kt.idx) + kt.g1))
        self.assertEqual(convert_struct_to_dict(r), {u'x': 50, u'y': 40, u'z': 100})

        r = kt.aggregate(5)
        self.assertEqual(r, 5)

        r = kt.aggregate(hl.null(hl.tint32))
        self.assertEqual(r, None)

        r = kt.aggregate(agg.filter(kt.idx % 2 != 0, agg.sum(kt.idx + 2)) + kt.g1)
        self.assertEqual(r, 40)
Exemple #49
0
    def transform_entries(old_entry):
        def with_local_a_index(local_a_index):
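            # Downcode the local PLs to a biallelic PL: for each biallelic genotype
            # index i, take the minimum LPL over the local genotypes that downcode
            # to i with respect to local_a_index.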
            new_pl = hl.or_missing(
                hl.is_defined(old_entry.LPL),
                hl.or_missing(
                    hl.is_defined(local_a_index),
                    hl.range(0, 3).map(lambda i: hl.min(
                        hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                            .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), local_a_index) == hl.unphased_diploid_gt_index_call(i))
                            .map(lambda idx: old_entry.LPL[idx])))))
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                               old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                               old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)

        lai = hl.fold(lambda accum, elt:
                        hl.cond(old_entry.LA[elt] == ds[new_id].a_index,
                                elt, accum),
                        hl.null(hl.tint32),
                        hl.range(0, hl.len(old_entry.LA)))
        return hl.bind(with_local_a_index, lai)
Exemple #50
0
    def phase_y_nonpar(
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
    ) -> hl.expr.ArrayExpression:
        """
        Return phased genotype calls in the non-PAR region of chromosome Y (requires both the father and proband calls to be haploid to return phase).

        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :return: Array containing: phased proband call, phased father call, missing mother call
        :rtype: ArrayExpression
        """
        return hl.or_missing(
            proband_call.is_haploid() & father_call.is_haploid() & (father_call[0] == proband_call[0]),
            hl.array([
                hl.call(proband_call[0], phased=True),
                hl.call(father_call[0], phased=True),
                hl.null(hl.tcall)
            ])
        )
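A quick illustrative check (values chosen by hand; hl.call(0) builds a haploid call):

import hail as hl

result = phase_y_nonpar(hl.call(0), hl.call(0))
print(hl.eval(result))  # [phased proband call, phased father call, missing mother call]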
Exemple #51
0
    def test_filter_na(self):
        mt = hl.utils.range_matrix_table(1, 1)

        self.assertEqual(mt.filter_rows(hl.null(hl.tbool)).count_rows(), 0)
        self.assertEqual(mt.filter_cols(hl.null(hl.tbool)).count_cols(), 0)
        self.assertEqual(mt.filter_entries(hl.null(hl.tbool)).entries().count(), 0)
Exemple #52
0
def variant_qc(mt, name='variant_qc') -> MatrixTable:
    """Compute common variant statistics (quality control metrics).

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    >>> dataset_result = hl.variant_qc(dataset)

    Notes
    -----
    This method computes variant statistics from the genotype data, returning
    a new struct field `name` with the following metrics based on the fields
    present in the entry schema.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `AF` (``array<float64>``) -- Calculated allele frequency, one element
      per allele, including the reference. Sums to one. Equivalent to
      `AC` / `AN`.
    - `AC` (``array<int32>``) -- Calculated allele count, one element per
      allele, including the reference. Sums to `AN`.
    - `AN` (``int32``) -- Total number of called alleles.
    - `homozygote_count` (``array<int32>``) -- Number of homozygotes per
      allele. One element per allele, including the reference.
    - `n_called` (``int64``) -- Number of samples with a defined `GT`.
    - `n_not_called` (``int64``) -- Number of samples with a missing `GT`.
    - `call_rate` (``float64``) -- Fraction of samples with a defined `GT`.
      Equivalent to `n_called` / :meth:`.count_cols`.
    - `n_het` (``int64``) -- Number of heterozygous samples.
    - `n_non_ref` (``int64``) -- Number of samples with at least one called
      non-reference allele.
    - `het_freq_hwe` (``float64``) -- Expected frequency of heterozygous
      samples under Hardy-Weinberg equilibrium. See
      :func:`.functions.hardy_weinberg_test` for details.
    - `p_value_hwe` (``float64``) -- p-value from test of Hardy-Weinberg equilibrium.
      See :func:`.functions.hardy_weinberg_test` for details.

    Warning
    -------
    `het_freq_hwe` and `p_value_hwe` are calculated as in
    :func:`.functions.hardy_weinberg_test`, with non-diploid calls
    (``ploidy != 2``) ignored in the counts. As this test is only
    statistically rigorous in the biallelic setting, :func:`.variant_qc`
    sets both fields to missing for multiallelic variants. Consider using
    :func:`~hail.methods.split_multi` to split multi-allelic variants beforehand.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
    """
    require_row_key_variant(mt, 'variant_qc')

    exprs = {}
    struct_exprs = []

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    n_samples = mt.count_cols()

    if has_field_of_type('DP', hl.tint32):
        exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'variant_qc': expect an entry field 'GT' of type 'call'")
    exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    struct_exprs.append(hl.agg.call_stats(mt.GT, mt.alleles))


    # the structure of this function makes it easy to add new nested computations
    def flatten_struct(*struct_exprs):
        flat = {}
        for struct in struct_exprs:
            for k, v in struct.items():
                flat[k] = v
        return hl.struct(
            **flat,
            **exprs,
        )

    mt = mt.annotate_rows(**{name: hl.bind(flatten_struct, *struct_exprs)})

    hwe = hl.hardy_weinberg_test(mt[name].homozygote_count[0],
                                 mt[name].AC[1] - 2 * mt[name].homozygote_count[1],
                                 mt[name].homozygote_count[1])
    hwe = hwe.select(het_freq_hwe=hwe.het_freq_hwe, p_value_hwe=hwe.p_value)
    mt = mt.annotate_rows(**{name: mt[name].annotate(n_not_called=n_samples - mt[name].n_called,
                                                     call_rate=mt[name].n_called / n_samples,
                                                     n_het=mt[name].n_called - hl.sum(mt[name].homozygote_count),
                                                     n_non_ref=mt[name].n_called - mt[name].homozygote_count[0],
                                                     **hl.cond(hl.len(mt.alleles) == 2,
                                                               hwe,
                                                               hl.null(hwe.dtype)))})
    return mt
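As a hedged illustration of reading the nested results (the dataset is synthetic; field access follows the schema described in the notes above):

import hail as hl

mt = hl.balding_nichols_model(1, 10, 10)
mt = hl.variant_qc(mt)
rows = mt.rows()
rows.select(AF=rows.variant_qc.AF, call_rate=rows.variant_qc.call_rate).show(3)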
Exemple #53
0
def get_allele_type(allele_idx):
    return hl.cond(allele_idx > 0, mt[variant_atypes][allele_idx - 1], hl.null(hl.tint32))
Exemple #54
0
    def test_locus_windows(self):
        def assert_eq(a, b):
            self.assertTrue(np.array_equal(a, np.array(b)))

        centimorgans = hl.literal([0.1, 1.0, 1.0, 1.5, 1.9])

        mt = hl.balding_nichols_model(1, 5, 5).add_row_index()
        mt = mt.annotate_rows(cm=centimorgans[hl.int32(mt.row_idx)]).cache()

        starts, stops = hl.linalg.utils.locus_windows(mt.locus, 2)
        assert_eq(starts, [0, 0, 0, 1, 2])
        assert_eq(stops, [3, 4, 5, 5, 5])

        starts, stops = hl.linalg.utils.locus_windows(mt.locus, 0.5, coord_expr=mt.cm)
        assert_eq(starts, [0, 1, 1, 1, 3])
        assert_eq(stops, [1, 4, 4, 5, 5])

        starts, stops = hl.linalg.utils.locus_windows(mt.locus, 1.0, coord_expr=2 * centimorgans[hl.int32(mt.row_idx)])
        assert_eq(starts, [0, 1, 1, 1, 3])
        assert_eq(stops, [1, 4, 4, 5, 5])

        rows = [{'locus': hl.Locus('1', 1), 'cm': 1.0},
                {'locus': hl.Locus('1', 2), 'cm': 3.0},
                {'locus': hl.Locus('1', 4), 'cm': 4.0},
                {'locus': hl.Locus('2', 1), 'cm': 2.0},
                {'locus': hl.Locus('2', 1), 'cm': 2.0},
                {'locus': hl.Locus('3', 3), 'cm': 5.0}]

        ht = hl.Table.parallelize(rows,
                                  hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64),
                                  key=['locus'])

        starts, stops = hl.linalg.utils.locus_windows(ht.locus, 1)
        assert_eq(starts, [0, 0, 2, 3, 3, 5])
        assert_eq(stops, [2, 2, 3, 5, 5, 6])

        starts, stops = hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
        assert_eq(starts, [0, 1, 1, 3, 3, 5])
        assert_eq(stops, [1, 3, 3, 5, 5, 6])

        with self.assertRaises(ValueError) as cm:
            hl.linalg.utils.locus_windows(ht.order_by(ht.cm).locus, 1.0)
        self.assertTrue('ascending order' in str(cm.exception))

        with self.assertRaises(ExpressionException) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=hl.utils.range_table(1).idx)
        self.assertTrue('different source' in str(cm.exception))

        with self.assertRaises(ExpressionException) as cm:
            hl.linalg.utils.locus_windows(hl.locus('1', 1), 1.0)
        self.assertTrue("no source" in str(cm.exception))

        with self.assertRaises(ExpressionException) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=0.0)
        self.assertTrue("no source" in str(cm.exception))

        ht = ht.annotate_globals(x=hl.locus('1', 1), y=1.0)
        with self.assertRaises(ExpressionException) as cm:
            hl.linalg.utils.locus_windows(ht.x, 1.0)
        self.assertTrue("row-indexed" in str(cm.exception))
        with self.assertRaises(ExpressionException) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0, ht.y)
        self.assertTrue("row-indexed" in str(cm.exception))

        ht = hl.Table.parallelize([{'locus': hl.null(hl.tlocus()), 'cm': 1.0}],
                                  hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64), key=['locus'])
        with self.assertRaises(ValueError) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0)
        self.assertTrue("missing value for 'locus_expr'" in str(cm.exception))
        with self.assertRaises(ValueError) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
        self.assertTrue("missing value for 'locus_expr'" in str(cm.exception))

        ht = hl.Table.parallelize([{'locus': hl.Locus('1', 1), 'cm': hl.null(hl.tfloat64)}],
                                  hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64), key=['locus'])
        with self.assertRaises(ValueError) as cm:
            hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
        self.assertTrue("missing value for 'coord_expr'" in str(cm.exception))
Exemple #55
0
def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`
    """

    if [x.dtype for x in left.row_key.values()] != [x.dtype for x in right.row_key.values()]:
        raise ValueError(f"row key types do not match:\n"
                         f"  left:  {list(left.row_key.values())}\n"
                         f"  right: {list(right.row_key.values())}")

    if [x.dtype for x in left.col_key.values()] != [x.dtype for x in right.col_key.values()]: 
        raise ValueError(f"column key types do not match:\n"
                         f"  left:  {list(left.col_key.values())}\n"
                         f"  right: {list(right.col_key.values())}")

    left = left.select_rows(left_row=left.row)
    left_t = left.localize_entries('left_entries', 'left_cols')
    right = right.select_rows(right_row=right.row)
    right_t = right.localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')
    ht = ht.annotate_globals(
        left_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.left_cols.map(lambda x: hl.tuple([x[f] for f in left.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])),
        right_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.right_cols.map(lambda x: hl.tuple([x[f] for f in right.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])))
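    # Build the joined column ordering: a full cartesian product of left/right
    # column indices for keys present on both sides, and a struct with a missing
    # index on the side that lacks the key.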
    ht = ht.annotate_globals(
        key_indices=hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
            .map(lambda k: hl.struct(k=k, left_indices=ht.left_keys.get(k), right_indices=ht.right_keys.get(k)))
            .flatmap(lambda s: hl.case()
                     .when(hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices),
                           hl.range(0, s.left_indices.length()).flatmap(
                               lambda i: hl.range(0, s.right_indices.length()).map(
                                   lambda j: hl.struct(k=s.k, left_index=s.left_indices[i],
                                                       right_index=s.right_indices[j]))))
                     .when(hl.is_defined(s.left_indices),
                           s.left_indices.map(
                               lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.null('int32'))))
                     .when(hl.is_defined(s.right_indices),
                           s.right_indices.map(
                               lambda elt: hl.struct(k=s.k, left_index=hl.null('int32'), right_index=elt)))
                     .or_error('assertion error')))
    ht = ht.annotate(__entries=ht.key_indices.map(lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                                                                      right_entry=ht.right_entries[s.right_index])))
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i] for i, f in enumerate(left.col_key)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))
    ht = ht.drop('left_entries', 'left_cols', 'left_keys', 'right_entries', 'right_cols', 'right_keys', 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', list(left.col_key))
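A brief usage sketch (range matrix tables stand in for real datasets; their row and column key types match, as the function requires):

import hail as hl

left = hl.utils.range_matrix_table(4, 3)
right = hl.utils.range_matrix_table(5, 2)
fo = full_outer_join_mt(left, right)
fo.entries().show(5)  # left_entry / right_entry are missing where only one side has the key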
Exemple #56
0
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
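        # The longest per-dataset reference becomes the combined ref; alternate
        # alleles from shorter refs are right-padded with the extra reference
        # bases (symbolic alleles are left unchanged).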
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref:
            hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r:
                        hl.array([ref]).extend(
                            al[1:].map(
                                lambda a:
                                hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at:
                                    hl.cond(
                                        (_allele_ints['SNP'] == at) |
                                        (_allele_ints['Insertion'] == at) |
                                        (_allele_ints['Deletion'] == at) |
                                        (_allele_ints['MNP'] == at) |
                                        (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal:
                hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles:
                hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i:
                            hl.cond(hl.is_missing(row.data[i].__entries),
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                      .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
                                           TopLevelReference('row'),
                                           TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
Exemple #57
0
    def test_annotate(self):
        schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32))

        rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]},
                {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []},
                {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}]

        kt = hl.Table.parallelize(rows, schema)

        self.assertTrue(kt.annotate()._same(kt))

        result1 = convert_struct_to_dict(kt.annotate(foo=kt.a + 1,
                                                     foo2=kt.a).take(1)[0])

        self.assertDictEqual(result1, {'a': 4,
                                       'b': 1,
                                       'c': 3,
                                       'd': 5,
                                       'e': "hello",
                                       'f': [1, 2, 3],
                                       'foo': 5,
                                       'foo2': 4})

        result3 = convert_struct_to_dict(kt.annotate(
            x1=kt.f.map(lambda x: x * 2),
            x2=kt.f.map(lambda x: [x, x + 1]).flatmap(lambda x: x),
            x3=hl.min(kt.f),
            x4=hl.max(kt.f),
            x5=hl.sum(kt.f),
            x6=hl.product(kt.f),
            x7=kt.f.length(),
            x8=kt.f.filter(lambda x: x == 3),
            x9=kt.f[1:],
            x10=kt.f[:],
            x11=kt.f[1:2],
            x12=kt.f.map(lambda x: [x, x + 1]),
            x13=kt.f.map(lambda x: [[x, x + 1], [x + 2]]).flatmap(lambda x: x),
            x14=hl.cond(kt.a < kt.b, kt.c, hl.null(hl.tint32)),
            x15={1, 2, 3}
        ).take(1)[0])

        self.assertDictEqual(result3, {'a': 4,
                                       'b': 1,
                                       'c': 3,
                                       'd': 5,
                                       'e': "hello",
                                       'f': [1, 2, 3],
                                       'x1': [2, 4, 6], 'x2': [1, 2, 2, 3, 3, 4],
                                       'x3': 1, 'x4': 3, 'x5': 6, 'x6': 6, 'x7': 3, 'x8': [3],
                                       'x9': [2, 3], 'x10': [1, 2, 3], 'x11': [2],
                                       'x12': [[1, 2], [2, 3], [3, 4]],
                                       'x13': [[1, 2], [3], [2, 3], [4], [3, 4], [5]],
                                       'x14': None, 'x15': set([1, 2, 3])})
        kt.annotate(
            x1=kt.a + 5,
            x2=5 + kt.a,
            x3=kt.a + kt.b,
            x4=kt.a - 5,
            x5=5 - kt.a,
            x6=kt.a - kt.b,
            x7=kt.a * 5,
            x8=5 * kt.a,
            x9=kt.a * kt.b,
            x10=kt.a / 5,
            x11=5 / kt.a,
            x12=kt.a / kt.b,
            x13=-kt.a,
            x14=+kt.a,
            x15=kt.a == kt.b,
            x16=kt.a == 5,
            x17=5 == kt.a,
            x18=kt.a != kt.b,
            x19=kt.a != 5,
            x20=5 != kt.a,
            x21=kt.a > kt.b,
            x22=kt.a > 5,
            x23=5 > kt.a,
            x24=kt.a >= kt.b,
            x25=kt.a >= 5,
            x26=5 >= kt.a,
            x27=kt.a < kt.b,
            x28=kt.a < 5,
            x29=5 < kt.a,
            x30=kt.a <= kt.b,
            x31=kt.a <= 5,
            x32=5 <= kt.a,
            x33=(kt.a == 0) & (kt.b == 5),
            x34=(kt.a == 0) | (kt.b == 5),
            x35=False,
            x36=True
        )
Exemple #58
0
    def test_filter_missing(self):
        ht = hl.utils.range_table(1, 1)

        self.assertEqual(ht.filter(hl.null(hl.tbool)).count(), 0)
Exemple #59
0
def transform_one(mt, vardp_outlier=100_000) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    There is a strong assumption that this function will be called on a matrix
    table with one column.
    """
    mt = localize(mt)
    if mt.row.dtype not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e:
                        hl.struct(
                            DP=e.DP,
                            END=row.info.END,
                            GQ=e.GQ,
                            LA=hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                            LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                            LGT=e.GT,
                            LPGT=e.PGT,
                            LPL=hl.cond(has_non_ref,
                                        hl.cond(alleles_len > 2,
                                                e.PL[:-alleles_len],
                                                hl.null(e.PL.dtype)),
                                        hl.cond(alleles_len > 1,
                                                e.PL,
                                                hl.null(e.PL.dtype))),
                            MIN_DP=e.MIN_DP,
                            PID=e.PID,
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                            SB=e.SB,
                            gvcf_info=hl.case()
                                .when(hl.is_missing(row.info.END),
                                      hl.struct(
                                          ClippingRankSum=row.info.ClippingRankSum,
                                          BaseQRankSum=row.info.BaseQRankSum,
                                          MQ=row.info.MQ,
                                          MQRankSum=row.info.MQRankSum,
                                          MQ_DP=row.info.MQ_DP,
                                          QUALapprox=row.info.QUALapprox,
                                          RAW_MQ=row.info.RAW_MQ,
                                          ReadPosRankSum=row.info.ReadPosRankSum,
                                          VarDP=hl.cond(row.info.VarDP > vardp_outlier,
                                                        row.info.DP, row.info.VarDP)))
                                .or_missing()
                        ))),
            ),
            mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(TableMapRows(mt._tir, Apply(transform_row._name, TopLevelReference('row'))))
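An illustrative invocation (hedged: transform_one assumes combiner-module helpers such as localize are in scope, and the gVCF path below is a placeholder):

import hail as hl

mt = hl.import_vcf('data/sample.g.vcf.bgz', force_bgz=True,
                   reference_genome='GRCh38', array_elements_required=False)
ht = transform_one(mt)  # localized table, ready for combine()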