Example #1
File: test_misc.py Project: jigold/hail
 def test_maximal_independent_set_types(self):
     ht = hl.utils.range_table(10)
     ht = ht.annotate(i=hl.struct(a='1', b=hl.rand_norm(0, 1)),
                      j=hl.struct(a='2', b=hl.rand_norm(0, 1)))
     ht = ht.annotate(ii=hl.struct(id=ht.i, rank=hl.rand_norm(0, 1)),
                      jj=hl.struct(id=ht.j, rank=hl.rand_norm(0, 1)))
     hl.maximal_independent_set(ht.ii, ht.jj).count()
Example #2
File: test_table.py Project: lfrancioli/hail
 def test_explode_on_set(self):
     t = hl.utils.range_table(1)
     t = t.annotate(a=hl.set(['a', 'b', 'c']))
     t = t.explode('a')
     self.assertEqual(set(t.collect()),
                      hl.eval(hl.set([hl.struct(idx=0, a='a'),
                                      hl.struct(idx=0, a='b'),
                                      hl.struct(idx=0, a='c')])))
Example #3
File: test_table.py Project: lfrancioli/hail
    def test_from_pandas_works(self):
        d = {'a': [1, 2], 'b': ['foo', 'bar']}
        df = pd.DataFrame(data=d)
        t = hl.Table.from_pandas(df, key='a')

        d2 = [hl.struct(a=hl.int64(1), b='foo'), hl.struct(a=hl.int64(2), b='bar')]
        t2 = hl.Table.parallelize(d2, key='a')

        self.assertTrue(t._same(t2))
Example #4
File: test_table.py Project: danking/hail
 def test_multi_way_zip_join_globals(self):
     t1 = hl.utils.range_table(1).annotate_globals(x=hl.null(hl.tint32))
     t2 = hl.utils.range_table(1).annotate_globals(x=5)
     t3 = hl.utils.range_table(1).annotate_globals(x=0)
     expected = hl.struct(__globals=hl.array([
         hl.struct(x=hl.null(hl.tint32)),
         hl.struct(x=5),
         hl.struct(x=0)]))
     joined = hl.Table._multi_way_zip_join([t1, t2, t3], '__data', '__globals')
     self.assertEqual(hl.eval(joined.globals), hl.eval(expected))
Example #5
File: test_matrix_table.py Project: tpoterba/hail
 def test_agg_explode(self):
     t = hl.Table.parallelize([
         hl.struct(a=[1, 2]),
         hl.struct(a=hl.empty_array(hl.tint32)),
         hl.struct(a=hl.null(hl.tarray(hl.tint32))),
         hl.struct(a=[3]),
         hl.struct(a=[hl.null(hl.tint32)])
     ])
     self.assertCountEqual(t.aggregate(hl.agg.explode(lambda elt: hl.agg.collect(elt), t.a)),
                           [1, 2, None, 3])
Example #6
File: test_reference_genome.py Project: bcajes/hail
    def test_liftover_strand(self):
        grch37 = hl.get_reference('GRCh37')
        grch37.add_liftover(resource('grch37_to_grch38_chr20.over.chain.gz'), 'GRCh38')

        self.assertEqual(hl.eval(hl.liftover(hl.locus('20', 60001, 'GRCh37'), 'GRCh38', include_strand=True)),
                         hl.eval(hl.struct(result=hl.locus('chr20', 79360, 'GRCh38'), is_negative_strand=False)))

        self.assertEqual(hl.eval(hl.liftover(hl.locus_interval('20', 37007582, 37007586, True, True, 'GRCh37'),
                                             'GRCh38', include_strand=True)),
                         hl.eval(hl.struct(result=hl.locus_interval('chr12', 32563117, 32563121, True, True, 'GRCh38'),
                                           is_negative_strand=True)))

        grch37.remove_liftover("GRCh38")
Example #7
File: vcf_combiner.py Project: jigold/hail
def summarize(mt):
    """Computes summary statistics

    Calls :func:`.quick_summary`. Calling both this and :func:`.quick_summary`, will lead
    to :func:`.quick_summary` being executed twice.

    Note
    ----
    You will not be able to run :func:`.combine_gvcfs` with the output of this
    function.
    """
    mt = quick_summary(mt)
    mt = hl.experimental.densify(mt)
    return mt.annotate_rows(info=hl.rbind(
        hl.agg.call_stats(lgt_to_gt(mt.LGT, mt.LA), mt.alleles),
        lambda gs: hl.struct(
            # here, we alphabetize the INFO fields by GATK convention
            AC=gs.AC[1:],  # The VCF spec indicates that AC and AF have Number=A, so we need
            AF=gs.AF[1:],  # to drop the first element from each of these.
            AN=gs.AN,
            BaseQRankSum=hl.median(hl.agg.collect(mt.entry.gvcf_info.BaseQRankSum)),
            ClippingRankSum=hl.median(hl.agg.collect(mt.entry.gvcf_info.ClippingRankSum)),
            DP=hl.agg.sum(mt.entry.DP),
            MQ=hl.median(hl.agg.collect(mt.entry.gvcf_info.MQ)),
            MQRankSum=hl.median(hl.agg.collect(mt.entry.gvcf_info.MQRankSum)),
            MQ_DP=mt.info.MQ_DP,
            QUALapprox=mt.info.QUALapprox,
            RAW_MQ=mt.info.RAW_MQ,
            ReadPosRankSum=hl.median(hl.agg.collect(mt.entry.gvcf_info.ReadPosRankSum)),
            SB_TABLE=mt.info.SB_TABLE,
            VarDP=mt.info.VarDP,
        )))
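A minimal usage sketch for the function above. The input path, and the assumption that the MatrixTable is a sparse combiner output carrying the entry and row fields that summarize references (DP, LGT, LA, gvcf_info, info.MQ_DP, and so on), are hypothetical:

    import hail as hl

    # Sparse MatrixTable written by the gvcf combiner; the path is a placeholder.
    mt = hl.read_matrix_table('gs://my-bucket/combined-gvcfs.mt')
    mt = summarize(mt)  # densifies, then annotates the alphabetized INFO struct
    mt.rows().show()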
Example #8
File: vcf_combiner.py Project: jigold/hail
 def merge_alleles(alleles):
     from hail.expr.functions import _num_allele_type, _allele_ints
     return hl.rbind(
         alleles.map(lambda a: hl.or_else(a[0], ''))
                .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
         lambda ref:
         hl.rbind(
             alleles.map(
                 lambda al: hl.rbind(
                     al[0],
                     lambda r:
                     hl.array([ref]).extend(
                         al[1:].map(
                             lambda a:
                             hl.rbind(
                                 _num_allele_type(r, a),
                                 lambda at:
                                 hl.cond(
                                     (_allele_ints['SNP'] == at) |
                                     (_allele_ints['Insertion'] == at) |
                                     (_allele_ints['Deletion'] == at) |
                                     (_allele_ints['MNP'] == at) |
                                     (_allele_ints['Complex'] == at),
                                     a + ref[hl.len(r):],
                                     a)))))),
             lambda lal:
             hl.struct(
                 globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                 local=lal)))
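To illustrate the shape of merge_alleles' input and output, a small sketch: each element of the input array is one dataset's alleles (reference allele first). The concrete alleles here are made up, and the ordering of the non-reference alleles in `globl` is not guaranteed since they pass through a set:

    import hail as hl

    merged = merge_alleles(hl.literal([['A', 'T'], ['AC', 'ACC']]))
    print(hl.eval(merged.globl))  # alleles rewritten against the longest reference ('AC')
    print(hl.eval(merged.local))  # per-input allele lists in the same representation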
Example #9
File: vcf_combiner.py Project: bcajes/hail
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        filters=hl.set(hl.flatten(ts.data.map(lambda d: hl.array(d.filters)))),
        info=hl.struct(
            DP=hl.sum(ts.data.map(lambda d: d.info.DP)),
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                          .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(lambda e: renumber_entry(e, old_to_new)),
                            hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                lambda j: combined_allele_index[tmp.data[i].alleles[j]])))),
            hl.dict(hl.range(0, hl.len(tmp.alleles)).map(
                lambda j: hl.tuple([tmp.alleles[j], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))

    return tmp.drop('data', 'g')
Example #10
File: test_matrix_table.py Project: tpoterba/hail
    def test_aggregate_ir(self):
        ds = (hl.utils.range_matrix_table(5, 5)
              .annotate_globals(g1=5)
              .annotate_entries(e1=3))

        x = [("col_idx", lambda e: ds.aggregate_cols(e)),
             ("row_idx", lambda e: ds.aggregate_rows(e))]

        for name, f in x:
            r = f(hl.struct(x=agg.sum(ds[name]) + ds.g1,
                            y=agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1,
                            z=agg.sum(ds.g1 + ds[name]) + ds.g1,
                            mean=agg.mean(ds[name])))
            self.assertEqual(convert_struct_to_dict(r), {u'x': 15, u'y': 13, u'z': 40, u'mean': 2.0})

            r = f(5)
            self.assertEqual(r, 5)

            r = f(hl.null(hl.tint32))
            self.assertEqual(r, None)

            r = f(agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1)
            self.assertEqual(r, 13)

        r = ds.aggregate_entries(agg.filter((ds.row_idx % 2 != 0) & (ds.col_idx % 2 != 0),
                                            agg.sum(ds.e1 + ds.g1 + ds.row_idx + ds.col_idx)) + ds.g1)
        self.assertEqual(r, 48)
Example #11
File: test_misc.py Project: danking/hail
    def test_filter_intervals_compound_key(self):
        ds = hl.import_vcf(resource('sample.vcf'), min_partitions=20)
        ds = (ds.annotate_rows(variant=hl.struct(locus=ds.locus, alleles=ds.alleles))
              .key_rows_by('locus', 'alleles'))

        intervals = [hl.Interval(hl.Struct(locus=hl.Locus('20', 10639222), alleles=['A', 'T']),
                                 hl.Struct(locus=hl.Locus('20', 10644700), alleles=['A', 'T']))]
        self.assertEqual(hl.filter_intervals(ds, intervals).count_rows(), 3)
Example #12
File: test_matrix_table.py Project: tpoterba/hail
    def test_select_entries(self):
        mt = hl.utils.range_matrix_table(10, 10, n_partitions=4)
        mt = mt.annotate_entries(a=hl.struct(b=mt.row_idx, c=mt.col_idx), foo=mt.row_idx * 10 + mt.col_idx)
        mt = mt.select_entries(mt.a.b, mt.a.c, mt.foo)
        mt = mt.annotate_entries(bc=mt.b * 10 + mt.c)
        mt_entries = mt.entries()

        assert (mt_entries.all(mt_entries.bc == mt_entries.foo))
Example #13
File: test_impex.py Project: lfrancioli/hail
    def test_call_fields(self):
        expected = hl.Table.parallelize(
            [hl.struct(locus = hl.locus("X", 16050036), s = "C1046::HG02024",
                       GT = hl.call(0, 0), GTA = hl.null(hl.tcall), GTZ = hl.call(0, 1)),
             hl.struct(locus = hl.locus("X", 16050036), s = "C1046::HG02025",
                       GT = hl.call(1), GTA = hl.null(hl.tcall), GTZ = hl.call(0)),
             hl.struct(locus = hl.locus("X", 16061250), s = "C1046::HG02024",
                       GT = hl.call(2, 2), GTA = hl.call(2, 1), GTZ = hl.call(1, 1)),
             hl.struct(locus = hl.locus("X", 16061250), s = "C1046::HG02025",
                       GT = hl.call(2), GTA = hl.null(hl.tcall), GTZ = hl.call(1))],
            key=['locus', 's'])

        mt = hl.import_vcf(resource('generic.vcf'), call_fields=['GT', 'GTA', 'GTZ'])
        entries = mt.entries()
        entries = entries.key_by('locus', 's')
        entries = entries.select('GT', 'GTA', 'GTZ')
        self.assertTrue(entries._same(expected))
Example #14
File: test_impex.py Project: lfrancioli/hail
    def test_haploid(self):
        expected = hl.Table.parallelize(
            [hl.struct(locus = hl.locus("X", 16050036), s = "C1046::HG02024",
                       GT = hl.call(0, 0), AD = [10, 0], GQ = 44),
             hl.struct(locus = hl.locus("X", 16050036), s = "C1046::HG02025",
                       GT = hl.call(1), AD = [0, 6], GQ = 70),
             hl.struct(locus = hl.locus("X", 16061250), s = "C1046::HG02024",
                       GT = hl.call(2, 2), AD = [0, 0, 11], GQ = 33),
             hl.struct(locus = hl.locus("X", 16061250), s = "C1046::HG02025",
                       GT = hl.call(2), AD = [0, 0, 9], GQ = 24)],
            key=['locus', 's'])

        mt = hl.import_vcf(resource('haploid.vcf'))
        entries = mt.entries()
        entries = entries.key_by('locus', 's')
        entries = entries.select('GT', 'AD', 'GQ')
        self.assertTrue(entries._same(expected))
Example #15
File: qc.py Project: tpoterba/hail
 def flatten_struct(*struct_exprs):
     flat = {}
     for struct in struct_exprs:
         for k, v in struct.items():
             flat[k] = v
     # `exprs` is not defined in this excerpt; in the original qc.py it is
     # captured from the enclosing function's scope.
     return hl.struct(
         **flat,
         **exprs,
     )
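As excerpted, the function depends on an `exprs` mapping captured from the enclosing scope (see the comment above). A self-contained sketch of the same flattening idea, with that closure replaced by an explicit keyword-argument dict (this signature is hypothetical, not the original qc.py API):

    import hail as hl

    def flatten_structs(*struct_exprs, **extra):
        flat = {}
        for struct in struct_exprs:
            for k, v in struct.items():
                flat[k] = v
        return hl.struct(**flat, **extra)

    s = flatten_structs(hl.struct(a=1), hl.struct(b=2.0), c='x')
    # s: struct{a: int32, b: float64, c: str}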
Example #16
File: test_table.py Project: lfrancioli/hail
 def test_localize_entries(self):
     ref_schema = hl.tstruct(row_idx=hl.tint32,
                             __entries=hl.tarray(hl.tstruct(v=hl.tint32)))
     ref_data = [{'row_idx': i, '__entries': [{'v': i+j} for j in range(6)]}
                 for i in range(8)]
     ref_tab = hl.Table.parallelize(ref_data, ref_schema).key_by('row_idx')
     ref_tab = ref_tab.select_globals(__cols=[hl.struct(col_idx=i) for i in range(6)])
     mt = hl.utils.range_matrix_table(8, 6)
     mt = mt.annotate_entries(v=mt.row_idx+mt.col_idx)
     t = mt._localize_entries('__entries', '__cols')
     self.assertTrue(t._same(ref_tab))
Example #17
File: sparse_split_multi.py Project: jigold/hail
 def struct_from_min_rep(i):
     return hl.bind(lambda mr:
                    (hl.case()
                     .when(ds.locus == mr.locus,
                           hl.struct(
                               locus=ds.locus,
                               alleles=[mr.alleles[0], mr.alleles[1]],
                               a_index=i,
                               was_split=True))
                     .or_error("Found non-left-aligned variant in sparse_split_multi")),
                    hl.min_rep(ds.locus, [ds.alleles[0], ds.alleles[i]]))
Example #18
File: test_matrix_table.py Project: tpoterba/hail
    def test_explode_rows(self):
        mt = hl.utils.range_matrix_table(4, 4)
        mt = mt.annotate_entries(e=mt.row_idx * 10 + mt.col_idx)

        self.assertTrue(mt.annotate_rows(x=[1]).explode_rows('x').drop('x')._same(mt))

        self.assertEqual(mt.annotate_rows(x=hl.empty_array('int')).explode_rows('x').count_rows(), 0)
        self.assertEqual(mt.annotate_rows(x=hl.null('array<int>')).explode_rows('x').count_rows(), 0)
        self.assertEqual(mt.annotate_rows(x=hl.range(0, mt.row_idx)).explode_rows('x').count_rows(), 6)
        mt = mt.annotate_rows(x=hl.struct(y=hl.range(0, mt.row_idx)))
        self.assertEqual(mt.explode_rows(mt.x.y).count_rows(), 6)
Example #19
File: helpers.py Project: jigold/hail
def create_all_values():
    return hl.struct(
        f32=hl.float32(3.14),
        i64=hl.int64(-9),
        m=hl.null(hl.tfloat64),
        astruct=hl.struct(a=hl.null(hl.tint32), b=5.5),
        mstruct=hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr)),
        aset=hl.set(['foo', 'bar', 'baz']),
        mset=hl.null(hl.tset(hl.tfloat64)),
        d=hl.dict({hl.array(['a', 'b']): 0.5, hl.array(['x', hl.null(hl.tstr), 'z']): 0.3}),
        md=hl.null(hl.tdict(hl.tint32, hl.tstr)),
        h38=hl.locus('chr22', 33878978, 'GRCh38'),
        ml=hl.null(hl.tlocus('GRCh37')),
        i=hl.interval(
            hl.locus('1', 999),
            hl.locus('1', 1001)),
        c=hl.call(0, 1),
        mc=hl.null(hl.tcall),
        t=hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.null(hl.tstr)]),
        mt=hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool))
    )
Example #20
File: family_methods.py Project: bcajes/hail
 def solve(p_de_novo):
     return (
         hl.case()
             .when(kid.GQ < min_gq, failure)
             .when((kid.DP / (parent.DP) < min_dp_ratio) |
                   (kid_ad_ratio < min_child_ab), failure)
             .when((hl.sum(parent.AD) == 0), failure)
             .when(parent.AD[1] / hl.sum(parent.AD) > max_parent_ab, failure)
             .when(p_de_novo < min_p, failure)
             .when(~is_snp, hl.case()
                   .when((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1),
                         hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                   .when((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles <= 5),
                         hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                   .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.3),
                         hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                   .or_missing())
             .default(hl.case()
                      .when(((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (dp_ratio > 0.2)) |
                            ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1)) |
                            ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles < 10) & (kid.DP > 10)),
                            hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                      .when((p_de_novo > 0.5) & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)),
                            hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                      .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.2),
                            hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                      .or_missing()
                      )
     )
Example #21
File: test_file_formats.py Project: bcajes/hail
def create_all_values_datasets():
    all_values = hl.struct(
        f32=hl.float32(3.14),
        i64=hl.int64(-9),
        m=hl.null(hl.tfloat64),
        astruct=hl.struct(a=hl.null(hl.tint32), b=5.5),
        mstruct=hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr)),
        aset=hl.set(['foo', 'bar', 'baz']),
        mset=hl.null(hl.tset(hl.tfloat64)),
        d=hl.dict({hl.array(['a', 'b']): 0.5, hl.array(['x', hl.null(hl.tstr), 'z']): 0.3}),
        md=hl.null(hl.tdict(hl.tint32, hl.tstr)),
        h38=hl.locus('chr22', 33878978, 'GRCh38'),
        ml=hl.null(hl.tlocus('GRCh37')),
        i=hl.interval(
            hl.locus('1', 999),
            hl.locus('1', 1001)),
        c=hl.call(0, 1),
        mc=hl.null(hl.tcall),
        t=hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.null(hl.tstr)]),
        mt=hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool))
    )

    def prefix(s, p):
        return hl.struct(**{p + k: s[k] for k in s})

    all_values_table = (hl.utils.range_table(5, n_partitions=3)
                        .annotate_globals(**prefix(all_values, 'global_'))
                        .annotate(**all_values)
                        .cache())

    all_values_matrix_table = (hl.utils.range_matrix_table(3, 2, n_partitions=2)
                               .annotate_globals(**prefix(all_values, 'global_'))
                               .annotate_rows(**prefix(all_values, 'row_'))
                               .annotate_cols(**prefix(all_values, 'col_'))
                               .annotate_entries(**prefix(all_values, 'entry_'))
                               .cache())

    return all_values_table, all_values_matrix_table
Example #22
File: vcf_combiner.py Project: jigold/hail
def quick_summary(mt):
    """compute aggregate INFO fields that do not require densify"""
    return mt.annotate_rows(
        info=hl.struct(
            MQ_DP=hl.agg.sum(mt.entry.gvcf_info.MQ_DP),
            QUALapprox=hl.agg.sum(mt.entry.gvcf_info.QUALapprox),
            RAW_MQ=hl.agg.sum(mt.entry.gvcf_info.RAW_MQ),
            VarDP=hl.agg.sum(mt.entry.gvcf_info.VarDP),
            SB_TABLE=hl.array([
                hl.agg.sum(mt.entry.SB[0]),
                hl.agg.sum(mt.entry.SB[1]),
                hl.agg.sum(mt.entry.SB[2]),
                hl.agg.sum(mt.entry.SB[3]),
            ])))
Example #23
File: test_table.py Project: lfrancioli/hail
    def test_aggregate_ir(self):
        kt = hl.utils.range_table(10).annotate_globals(g1=5)
        r = kt.aggregate(hl.struct(x=agg.sum(kt.idx) + kt.g1,
                                   y=agg.filter(kt.idx % 2 != 0, agg.sum(kt.idx + 2)) + kt.g1,
                                   z=agg.sum(kt.g1 + kt.idx) + kt.g1))
        self.assertEqual(convert_struct_to_dict(r), {u'x': 50, u'y': 40, u'z': 100})

        r = kt.aggregate(5)
        self.assertEqual(r, 5)

        r = kt.aggregate(hl.null(hl.tint32))
        self.assertEqual(r, None)

        r = kt.aggregate(agg.filter(kt.idx % 2 != 0, agg.sum(kt.idx + 2)) + kt.g1)
        self.assertEqual(r, 40)
Example #24
File: ldscsim.py Project: jigold/hail
def add_sim_description(mt, starttime, stoptime, runtime, h2=None, pi=1, is_annot_inf=False,
                        annot_coef_dict=None, annot_regex=None, h2_normalize=True,
                        is_popstrat=False, cov_coef_dict=None, cov_regex=None, path_to_save=None):
    '''Annotates mt with a description of the simulation'''
    sim_id = 0
    # find the first unused sim_desc{id} global field
    # (str.replace is used instead of str.strip, which strips characters, not a prefix)
    while str(sim_id) in [x.replace('sim_desc', '') for x in list(mt.globals) if 'sim_desc' in x]:
        sim_id += 1
    sim_desc = hl.struct(h2=none_to_null(h2), pi=pi, starttime=str(starttime),
                         stoptime=str(stoptime), runtime=str(runtime),
                         is_annot_inf=is_annot_inf, annot_coef_dict=none_to_null(annot_coef_dict),
                         annot_regex=none_to_null(annot_regex), h2_normalize=h2_normalize,
                         is_popstrat=is_popstrat, cov_coef_dict=none_to_null(cov_coef_dict),
                         cov_regex=none_to_null(cov_regex), path_to_save=none_to_null(path_to_save))
    mt = mt._annotate_all(global_exprs={f'sim_desc{sim_id}': sim_desc})
    return mt
Example #25
File: test_impex.py Project: lfrancioli/hail
    def test_import_bgen_variant_filtering_from_exprs(self):
        bgen_file = resource('example.8bits.bgen')
        hl.index_bgen(bgen_file, contig_recoding={'01': '1'})

        everything = hl.import_bgen(bgen_file, ['GT'])
        self.assertEqual(everything.count(), (199, 500))

        desired_variants = hl.struct(locus=everything.locus, alleles=everything.alleles)

        actual = hl.import_bgen(bgen_file,
                                ['GT'],
                                n_partitions=10,
                                variants=desired_variants) # filtering with everything

        self.assertTrue(everything._same(actual))
Example #26
File: test_linalg.py Project: danking/hail
    def test_block_matrix_entries(self):
        n_rows, n_cols = 5, 3
        rows = [{'i': i, 'j': j, 'entry': float(i + j)} for i in range(n_rows) for j in range(n_cols)]
        schema = hl.tstruct(i=hl.tint32, j=hl.tint32, entry=hl.tfloat64)
        table = hl.Table.parallelize([hl.struct(i=row['i'], j=row['j'], entry=row['entry']) for row in rows], schema)
        table = table.annotate(i=hl.int64(table.i),
                               j=hl.int64(table.j)).key_by('i', 'j')

        ndarray = np.reshape(list(map(lambda row: row['entry'], rows)), (n_rows, n_cols))

        for block_size in [1, 2, 1024]:
            block_matrix = BlockMatrix.from_numpy(ndarray, block_size)
            entries_table = block_matrix.entries()
            self.assertEqual(entries_table.count(), n_cols * n_rows)
            self.assertEqual(len(entries_table.row), 3)
            self.assertTrue(table._same(entries_table))
Example #27
    def phase_diploid_proband(
            locus: hl.expr.LocusExpression,
            alleles: hl.expr.ArrayExpression,
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
            mother_call: hl.expr.CallExpression
    ) -> hl.expr.ArrayExpression:
        """
        Returns phased genotype calls in the case of a diploid proband
        (autosomes, PAR regions of sex chromosomes or non-PAR regions of a female proband)

        :param LocusExpression locus: Locus in the trio MatrixTable
        :param ArrayExpression alleles: Alleles in the trio MatrixTable
        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :param CallExpression mother_call: Input mother genotype call
        :return: Array containing: phased proband call, phased father call, phased mother call
        :rtype: ArrayExpression
        """

        proband_v = proband_call.one_hot_alleles(alleles)
        father_v = hl.cond(
            locus.in_x_nonpar() | locus.in_y_nonpar(),
            hl.or_missing(father_call.is_haploid(), hl.array([father_call.one_hot_alleles(alleles)])),
            call_to_one_hot_alleles_array(father_call, alleles)
        )
        mother_v = call_to_one_hot_alleles_array(mother_call, alleles)

        combinations = hl.flatmap(
            lambda f:
            hl.zip_with_index(mother_v)
                .filter(lambda m: m[1] + f[1] == proband_v)
                .map(lambda m: hl.struct(m=m[0], f=f[0])),
            hl.zip_with_index(father_v)
        )

        return (
            hl.or_missing(
                hl.is_defined(combinations) & (hl.len(combinations) == 1),
                hl.array([
                    hl.call(father_call[combinations[0].f], mother_call[combinations[0].m], phased=True),
                    hl.cond(father_call.is_haploid(), hl.call(father_call[0], phased=True), phase_parent_call(father_call, combinations[0].f)),
                    phase_parent_call(mother_call, combinations[0].m)
                ])
            )
        )
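A hedged call sketch for the function above, assuming `mt` is the result of hl.trio_matrix, which exposes proband_entry, father_entry, and mother_entry structs:

    import hail as hl

    # mt = hl.trio_matrix(dataset, pedigree, complete_trios=True)
    phased = phase_diploid_proband(mt.locus, mt.alleles,
                                   mt.proband_entry.GT,
                                   mt.father_entry.GT,
                                   mt.mother_entry.GT)
    # phased: [proband, father, mother] phased calls, or missing when the
    # inheritance pattern is ambiguous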
Example #28
File: test_impex.py Project: lfrancioli/hail
    def test_import_bgen_locus_filtering_from_exprs(self):
        bgen_file = resource('example.8bits.bgen')
        hl.index_bgen(bgen_file, contig_recoding={'01': '1'})

        everything = hl.import_bgen(bgen_file, ['GT'])
        self.assertEqual(everything.count(), (199, 500))

        actual_struct = hl.import_bgen(bgen_file,
                                       ['GT'],
                                       variants=hl.struct(locus=everything.locus))

        self.assertTrue(everything._same(actual_struct))

        actual_locus = hl.import_bgen(bgen_file,
                                      ['GT'],
                                      variants=everything.locus)

        self.assertTrue(everything._same(actual_locus))
Example #29
File: test_matrix_table.py Project: tpoterba/hail
    def test_agg_call_stats(self):
        t = hl.Table.parallelize([
            hl.struct(c=hl.call(0, 0)),
            hl.struct(c=hl.call(0, 1)),
            hl.struct(c=hl.call(0, 2, phased=True)),
            hl.struct(c=hl.call(1)),
            hl.struct(c=hl.call(0)),
            hl.struct(c=hl.call())
        ])
        actual = t.aggregate(hl.agg.call_stats(t.c, ['A', 'T', 'G']))
        expected = hl.struct(AC=[5, 2, 1],
                             AF=[5.0 / 8.0, 2.0 / 8.0, 1.0 / 8.0],
                             AN=8,
                             homozygote_count=[1, 0, 0])

        self.assertEqual(actual, hl.eval(expected))
Example #30
File: test_matrix_table.py Project: tpoterba/hail
    def test_hardy_weinberg_test(self):
        mt = hl.import_vcf(resource('HWE_test.vcf'))
        mt = mt.select_rows(**hl.agg.hardy_weinberg_test(mt.GT))
        rt = mt.rows()
        expected = hl.Table.parallelize([
            hl.struct(
                locus=hl.locus('20', pos),
                alleles=alleles,
                het_freq_hwe=r,
                p_value=p)
            for (pos, alleles, r, p) in [
                (1, ['A', 'G'], 0.0, 0.5),
                (2, ['A', 'G'], 0.25, 0.5),
                (3, ['T', 'C'], 0.5357142857142857, 0.21428571428571427),
                (4, ['T', 'A'], 0.5714285714285714, 0.6571428571428573),
                (5, ['G', 'A'], 0.3333333333333333, 0.5)]],
            key=['locus', 'alleles'])
        self.assertTrue(rt.filter(rt.locus.position != 6)._same(expected))

        rt6 = rt.filter(rt.locus.position == 6).collect()[0]
        self.assertEqual(rt6['p_value'], 0.5)
        self.assertTrue(math.isnan(rt6['het_freq_hwe']))
Example #31
File: vcf_combiner.py Project: zscu/hail
def transform_one(mt, vardp_outlier=100_000) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    There is a strong assumption that this function will be called on a matrix
    table with one column.
    """
    mt = localize(mt)
    if mt.row.dtype not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles),
                '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                       locus=row.locus,
                       alleles=hl.cond(has_non_ref, row.alleles[:-1],
                                       row.alleles),
                       rsid=row.rsid,
                       __entries=row.__entries.map(lambda e: hl.struct(
                           DP=e.DP,
                           END=row.info.END,
                           GQ=e.GQ,
                           LA=hl.range(
                               0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                           LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                           LGT=e.GT,
                           LPGT=e.PGT,
                           LPL=hl.cond(
                               has_non_ref,
                               hl.cond(alleles_len > 2, e.PL[:-alleles_len],
                                       hl.null(e.PL.dtype)),
                               hl.cond(alleles_len > 1, e.PL,
                                       hl.null(e.PL.dtype))),
                           MIN_DP=e.MIN_DP,
                           PID=e.PID,
                           RGQ=hl.cond(
                               has_non_ref,
                               e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                               hl.null(e.PL.dtype.element_type)),
                           SB=e.SB,
                           gvcf_info=hl.case().when(
                               hl.is_missing(row.info.END),
                               hl.struct(
                                   ClippingRankSum=row.info.ClippingRankSum,
                                   BaseQRankSum=row.info.BaseQRankSum,
                                   MQ=row.info.MQ,
                                   MQRankSum=row.info.MQRankSum,
                                   MQ_DP=row.info.MQ_DP,
                                   QUALapprox=row.info.QUALapprox,
                                   RAW_MQ=row.info.RAW_MQ,
                                   ReadPosRankSum=row.info.ReadPosRankSum,
                                   VarDP=hl.cond(
                                       row.info.VarDP > vardp_outlier,
                                       row.info.DP,
                                       row.info.VarDP))).or_missing()))),
            ), mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(
        TableMapRows(mt._tir,
                     Apply(transform_row._name, TopLevelReference('row'))))
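A usage sketch under the docstring's assumptions; the path and reference genome are placeholders:

    import hail as hl

    # A single-sample GVCF imported as the docstring requires.
    mt = hl.import_vcf('data/NA12878.g.vcf.bgz',
                       reference_genome='GRCh38',
                       array_elements_required=False)
    ht = transform_one(mt)  # localized table ready for combining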
Example #32
def simulate_phenotypes(mt,
                        genotype,
                        h2,
                        pi=None,
                        rg=None,
                        annot=None,
                        popstrat=None,
                        popstrat_var=None,
                        exact_h2=False):
    r"""Simulate phenotypes for testing LD score regression.

    Simulates betas (SNP effects) under the infinitesimal, spike & slab, or 
    annotation-informed models, depending on parameters passed. Optionally adds
    population stratification.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        :class:`.MatrixTable` containing genotypes to be used. Also should contain 
        variant annotations as row fields if running the annotation-informed
        model or covariates as column fields if adding population stratification.
    genotype : :class:`.Expression` or :class:`.CallExpression`
        Entry field containing genotypes of individuals to be used for the
        simulation.
    h2 : :obj:`float` or :obj:`int` or :obj:`list` or :class:`numpy.ndarray`
        SNP-based heritability of simulated trait.
    pi : :obj:`float` or :obj:`int` or :obj:`list` or :class:`numpy.ndarray`, optional
        Probability of SNP being causal when simulating under the spike & slab 
        model.
    rg : :obj:`float` or :obj:`int` or :obj:`list` or :class:`numpy.ndarray`, optional
        Genetic correlation between traits.
    annot : :class:`.Expression`, optional
        Row field to use as our aggregated annotations.
    popstrat: :class:`.Expression`, optional
        Column field to use as our aggregated covariates for adding population
        stratification.
    exact_h2: :obj:`bool`, optional
        Whether to exactly simulate ratio of variance of genetic component of 
        phenotype to variance of phenotype to be h2. If `False`, ratio will be
        h2 in expectation. Observed h2 in the simulation will be close to 
        expected h2 for large-scale simulations.

    Returns
    -------
    :class:`.MatrixTable`
        :class:`.MatrixTable` with simulated betas and phenotypes, simulated according
        to specified model.
    """
    h2 = h2.tolist() if type(h2) is np.ndarray else (
        [h2] if type(h2) is not list else h2)
    pi = pi.tolist() if type(pi) is np.ndarray else pi
    uid = Env.get_uid(base=100)
    mt = annotate_all(
        mt=mt,
        row_exprs={} if annot is None else {'annot_' + uid: annot},
        col_exprs={} if popstrat is None else {'popstrat_' + uid: popstrat},
        entry_exprs={
            'gt_' + uid:
            genotype.n_alt_alleles()
            if genotype.dtype is dtype('call') else genotype
        })
    mt, pi, rg = make_betas(mt=mt,
                            h2=h2,
                            pi=pi,
                            annot=None if annot is None else mt['annot_' + uid],
                            rg=rg)
    mt = calculate_phenotypes(
        mt=mt,
        genotype=mt['gt_' + uid],
        beta=mt['beta'],
        h2=h2,
        popstrat=None if popstrat is None else mt['popstrat_' + uid],
        popstrat_var=popstrat_var,
        exact_h2=exact_h2)
    mt = annotate_all(
        mt=mt,
        global_exprs={
            'ldscsim': hl.struct(**{
                'h2': h2[0] if len(h2) == 1 else h2,
                **({} if pi == [None] else {'pi': pi}),
                **({} if rg == [None] else {'rg': rg[0] if len(rg) == 1 else rg}),
                **({} if annot is None else {'is_annot_inf': True}),
                **({} if popstrat is None else {'is_popstrat_inf': True}),
                **({} if popstrat_var is None else {'popstrat_var': popstrat_var}),
                'exact_h2': exact_h2,
            })
        })
    mt = _clean_fields(mt, uid)
    return mt
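A minimal call under the spike & slab model; the dataset path and entry field are placeholders:

    import hail as hl

    mt = hl.read_matrix_table('data/genotypes.mt')  # placeholder
    sim = simulate_phenotypes(mt, genotype=mt.GT, h2=0.5, pi=0.01)
    # betas and phenotypes are added as fields, and the parameters are
    # recorded in the `ldscsim` global struct assembled above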
Example #33
def merge_stats_counters_expr(
    stats: hl.expr.ArrayExpression,
) -> hl.expr.StructExpression:
    """
    Merges multiple stats counters, assuming that they were computed on non-overlapping data.

    Examples:

    - Merge stats computed on indel and snv separately
    - Merge stats computed on bi-allelic and multi-allelic variants separately
    - Merge stats computed on autosomes and sex chromosomes separately

    :param stats: An array of stats counters to merge
    :return: Merged stats Struct
    """
    def add_stats(i: hl.expr.StructExpression,
                  j: hl.expr.StructExpression) -> hl.expr.StructExpression:
        """
        Merges two stats counters. It assumes that all stats counter fields are present in both structs.

        :param i: accumulator: struct with mean, n and variance
        :param j: new element: stats_struct -- needs to contain mean, n and variance
        :return: Accumulation over all elements: struct with mean, n and variance
        """
        delta = j.mean - i.mean
        n_tot = i.n + j.n
        return hl.struct(
            min=hl.min(i.min, j.min),
            max=hl.max(i.max, j.max),
            mean=(i.mean * i.n + j.mean * j.n) / n_tot,
            variance=i.variance + j.variance +
            (delta * delta * i.n * j.n) / n_tot,
            n=n_tot,
            sum=i.sum + j.sum,
        )

    # Gather all metrics present in all stats counters
    metrics = set(stats[0])
    dropped_metrics = set()
    for stat_expr in stats[1:]:
        stat_expr_metrics = set(stat_expr)
        dropped_metrics = dropped_metrics.union(
            stat_expr_metrics.difference(metrics))
        metrics = metrics.intersection(stat_expr_metrics)
    if dropped_metrics:
        logger.warning(
            f"The following metrics will be dropped during stats counter merging as they do not appear in all counters: {', '.join(dropped_metrics)}"
        )

    # Because merging standard deviation requires having the mean and n,
    # check that they are also present if `stdev` is. Otherwise remove stdev
    if "stdev" in metrics:
        missing_fields = [x for x in ["n", "mean"] if x not in metrics]
        if missing_fields:
            logger.warning(
                f'Cannot merge `stdev` from given stats counters since they are missing the following fields: {",".join(missing_fields)}'
            )
            metrics.remove("stdev")

    # Create a struct with all possible stats for merging.
    # This step helps when folding because we can rely on the struct schema
    # Note that for intermediate merging, we compute the variance rather than the stdev
    all_stats = hl.array(stats).map(lambda x: hl.struct(
        min=x.min if "min" in metrics else hl.null(hl.tfloat64),
        max=x.max if "max" in metrics else hl.null(hl.tfloat64),
        mean=x.mean if "mean" in metrics else hl.null(hl.tfloat64),
        variance=x.stdev * x.stdev if "stdev" in metrics else hl.null(hl.tfloat64),
        n=x.n if "n" in metrics else hl.null(hl.tfloat64),
        sum=x.sum if "sum" in metrics else hl.null(hl.tfloat64),
    ))

    # Merge the stats
    agg_stats = all_stats[1:].fold(add_stats, all_stats[0])

    # Return only the metrics that were present in all independent stats counters
    # If `stdev` is present, then compute it from the variance
    return agg_stats.select(
        **{
            metric: agg_stats[metric] if metric != "stdev"
            else hl.sqrt(agg_stats.variance)
            for metric in metrics
        })
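A sketch of the intended call pattern: compute hl.agg.stats on disjoint strata, then merge the counters. The table and fields here are fabricated for illustration:

    import hail as hl

    ht = hl.utils.range_table(100)
    ht = ht.annotate(x=hl.rand_norm(), is_even=ht.idx % 2 == 0)
    per_stratum = ht.aggregate(
        hl.agg.group_by(ht.is_even, hl.agg.stats(ht.x)), _localize=False)
    merged = merge_stats_counters_expr(per_stratum.values())
    print(hl.eval(merged))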
Example #34
 def f3stats(ht):
     return ht.aggregate(
         hl.struct(
             n=hl.agg.count_where(hl.is_defined(ht["feature3"])),
             med=hl.median(hl.agg.collect(ht["feature3"])),
         ))
Example #35
File: import_gtf.py Project: joonan30/hail
def import_gtf(path,
               reference_genome=None,
               skip_invalid_contigs=False,
               min_partitions=None) -> hl.Table:
    """Import a GTF file.

       The GTF file format is identical to the GFF version 2 file format,
       and so this function can be used to import GFF version 2 files as
       well.

       See https://www.ensembl.org/info/website/upload/gff.html for more
       details on the GTF/GFF2 file format.

       The :class:`.Table` returned by this function will be keyed by the
       ``interval`` row field and will include the following row fields:

       .. code-block:: text

           'source': str
           'feature': str
           'score': float64
           'strand': str
           'frame': int32
           'interval': interval<>

       There will also be corresponding fields for every tag found in the
       attribute field of the GTF file.

       Note
       ----

       This function will return an ``interval`` field of type :class:`.tinterval`
       constructed from the ``seqname``, ``start``, and ``end`` fields in the
       GTF file. This interval is inclusive of both the start and end positions
       in the GTF file. 

       If the ``reference_genome`` parameter is specified, the start and end
       points of the ``interval`` field will be of type :class:`.tlocus`.
       Otherwise, the start and end points of the ``interval`` field will be of
       type :class:`.tstruct` with fields ``seqname`` (type :class:`str`) and
       ``position`` (type :class:`.tint32`).

       Furthermore, if the ``reference_genome`` parameter is specified and
       ``skip_invalid_contigs`` is ``True``, this import function will skip
       lines in the GTF where ``seqname`` is not consistent with the reference
       genome specified.

       Example
       -------

       >>> ht = hl.experimental.import_gtf('data/test.gtf', 
       ...                                 reference_genome='GRCh37',
       ...                                 skip_invalid_contigs=True)

       >>> ht.describe()  # doctest: +NOTEST
       ----------------------------------------
       Global fields:
       None
       ----------------------------------------
       Row fields:
           'source': str
           'feature': str
           'score': float64
           'strand': str
           'frame': int32
           'gene_type': str
           'exon_id': str
           'havana_transcript': str
           'level': str
           'transcript_name': str
           'gene_status': str
           'gene_id': str
           'transcript_type': str
           'tag': str
           'transcript_status': str
           'gene_name': str
           'transcript_id': str
           'exon_number': str
           'havana_gene': str
           'interval': interval<locus<GRCh37>>
       ----------------------------------------
       Key: ['interval']
       ----------------------------------------

       Parameters
       ----------

       path : :obj:`str`
           File to import.
       reference_genome : :obj:`str` or :class:`.ReferenceGenome`, optional
           Reference genome to use.
       skip_invalid_contigs : :obj:`bool`
           If ``True`` and `reference_genome` is not ``None``, skip lines where
           ``seqname`` is not consistent with the reference genome.
       min_partitions : :obj:`int` or :obj:`None`
           Minimum number of partitions (passed to import_table).

       Returns
       -------
       :class:`.Table`
       """

    ht = hl.import_table(path,
                         min_partitions=min_partitions,
                         comment='#',
                         no_header=True,
                         types={
                             'f3': hl.tint,
                             'f4': hl.tint,
                             'f5': hl.tfloat,
                             'f7': hl.tint
                         },
                         missing='.',
                         delimiter='\t')

    ht = ht.rename({
        'f0': 'seqname',
        'f1': 'source',
        'f2': 'feature',
        'f3': 'start',
        'f4': 'end',
        'f5': 'score',
        'f6': 'strand',
        'f7': 'frame',
        'f8': 'attribute'
    })

    ht = ht.annotate(attribute=hl.dict(
        hl.map(lambda x: (x.split(' ')[0],
                          x.split(' ')[1].replace('"', '').replace(';$', '')),
               ht['attribute'].split('; '))))

    attributes = ht.aggregate(
        hl.agg.explode(lambda x: hl.agg.collect_as_set(x),
                       ht['attribute'].keys()))

    ht = ht.transmute(
        **{
            x: hl.or_missing(ht['attribute'].contains(x), ht['attribute'][x])
            for x in attributes if x
        })

    if reference_genome:
        if reference_genome == 'GRCh37':
            ht = ht.annotate(seqname=ht['seqname'].replace('^chr', ''))
        else:
            ht = ht.annotate(seqname=hl.case()
                             .when(ht['seqname'].startswith('HLA'), ht['seqname'])
                             .when(ht['seqname'].startswith('chrHLA'), ht['seqname'].replace('^chr', ''))
                             .when(ht['seqname'].startswith('chr'), ht['seqname'])
                             .default('chr' + ht['seqname']))
        if skip_invalid_contigs:
            valid_contigs = hl.literal(
                set(hl.get_reference(reference_genome).contigs))
            ht = ht.filter(valid_contigs.contains(ht['seqname']))
        ht = ht.transmute(
            interval=hl.locus_interval(ht['seqname'],
                                       ht['start'],
                                       ht['end'],
                                       includes_start=True,
                                       includes_end=True,
                                       reference_genome=reference_genome))
    else:
        ht = ht.transmute(interval=hl.interval(
            hl.struct(seqname=ht['seqname'], position=ht['start']),
            hl.struct(seqname=ht['seqname'], position=ht['end']),
            includes_start=True,
            includes_end=True))

    ht = ht.key_by('interval')

    return ht
Example #36
File: vcf_combiner.py Project: vedasha/hail
def reannotate(mt, gatk_ht, summ_ht):
    """Re-annotate a sparse MT with annotations from certain GATK tools

    `gatk_ht` should be a table from the rows of a VCF, with `info` having at least
    the following fields.  Be aware that fields not present in this list will
    be dropped.
    ```
        struct {
            AC: array<int32>,
            AF: array<float64>,
            AN: int32,
            BaseQRankSum: float64,
            ClippingRankSum: float64,
            DP: int32,
            FS: float64,
            MQ: float64,
            MQRankSum: float64,
            MQ_DP: int32,
            NEGATIVE_TRAIN_SITE: bool,
            POSITIVE_TRAIN_SITE: bool,
            QD: float64,
            QUALapprox: int32,
            RAW_MQ: float64,
            ReadPosRankSum: float64,
            SB_TABLE: array<int32>,
            SOR: float64,
            VQSLOD: float64,
            VarDP: int32,
            culprit: str
        }
    ```
    `summ_ht` should be the output of :func:`.summarize` as a rows table.

    Note
    ----
    You will not be able to run :func:`.combine_gvcfs` with the output of this
    function.
    """
    def check(ht):
        keys = list(ht.key)
        if keys[0] != 'locus':
            raise TypeError(
                f'table inputs must have first key "locus", found {keys}')
        if keys != ['locus']:
            return hl.Table(TableKeyBy(ht._tir, ['locus'], is_sorted=True))
        return ht

    gatk_ht, summ_ht = [check(ht) for ht in (gatk_ht, summ_ht)]
    return mt.annotate_rows(
        info=hl.rbind(
            gatk_ht[mt.locus].info, summ_ht[mt.locus].info,
            lambda ginfo, hinfo: hl.struct(
                AC=hl.or_else(hinfo.AC, ginfo.AC),
                AF=hl.or_else(hinfo.AF, ginfo.AF),
                AN=hl.or_else(hinfo.AN, ginfo.AN),
                BaseQRankSum=hl.or_else(hinfo.BaseQRankSum, ginfo.BaseQRankSum),
                ClippingRankSum=hl.or_else(hinfo.ClippingRankSum, ginfo.ClippingRankSum),
                DP=hl.or_else(hinfo.DP, ginfo.DP),
                FS=ginfo.FS,
                MQ=hl.or_else(hinfo.MQ, ginfo.MQ),
                MQRankSum=hl.or_else(hinfo.MQRankSum, ginfo.MQRankSum),
                MQ_DP=hl.or_else(hinfo.MQ_DP, ginfo.MQ_DP),
                NEGATIVE_TRAIN_SITE=ginfo.NEGATIVE_TRAIN_SITE,
                POSITIVE_TRAIN_SITE=ginfo.POSITIVE_TRAIN_SITE,
                QD=ginfo.QD,
                QUALapprox=hl.or_else(hinfo.QUALapprox, ginfo.QUALapprox),
                RAW_MQ=hl.or_else(hinfo.RAW_MQ, ginfo.RAW_MQ),
                ReadPosRankSum=hl.or_else(hinfo.ReadPosRankSum, ginfo.ReadPosRankSum),
                SB_TABLE=hl.or_else(hinfo.SB_TABLE, ginfo.SB_TABLE),
                SOR=ginfo.SOR,
                VQSLOD=ginfo.VQSLOD,
                VarDP=hl.or_else(hinfo.VarDP, ginfo.VarDP),
                culprit=ginfo.culprit,
            )),
        qual=gatk_ht[mt.locus].qual,
        filters=gatk_ht[mt.locus].filters,
    )
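A hedged sketch of wiring reannotate together with :func:`.summarize`; the paths are placeholders, and the GATK sites table is assumed to come from a VCF's rows:

    import hail as hl

    mt = hl.read_matrix_table('gs://bucket/sparse.mt')  # placeholder
    gatk_ht = hl.import_vcf('gs://bucket/gatk_sites.vcf.bgz').rows()
    summ_ht = summarize(mt).rows()
    mt = reannotate(mt, gatk_ht, summ_ht)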
Example #37
File: vcf_combiner.py Project: vedasha/hail
def transform_one(mt) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.
    """
    mt = localize(mt)
    if mt.row.dtype not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles),
                '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                       locus=row.locus,
                       alleles=hl.cond(has_non_ref, row.alleles[:-1],
                                       row.alleles),
                       rsid=row.rsid,
                       info=row.info.annotate(SB_TABLE=hl.array([
                           hl.sum(row.__entries.map(lambda d: d.SB[0])),
                           hl.sum(row.__entries.map(lambda d: d.SB[1])),
                           hl.sum(row.__entries.map(lambda d: d.SB[2])),
                           hl.sum(row.__entries.map(lambda d: d.SB[3])),
                       ])).select(
                           "MQ_DP",
                           "QUALapprox",
                           "RAW_MQ",
                           "SB_TABLE",
                           "VarDP",
                       ),
                       __entries=row.__entries.map(lambda e: hl.struct(
                           BaseQRankSum=row.info['BaseQRankSum'],
                           ClippingRankSum=row.info['ClippingRankSum'],
                           DP=e.DP,
                           END=row.info.END,
                           GQ=e.GQ,
                           LA=hl.range(
                               0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                           LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                           LGT=e.GT,
                           LPGT=e.PGT,
                           LPL=hl.cond(
                               has_non_ref,
                               hl.cond(alleles_len > 2, e.PL[:-alleles_len],
                                       hl.null(e.PL.dtype)),
                               hl.cond(alleles_len > 1, e.PL,
                                       hl.null(e.PL.dtype))),
                           MIN_DP=e.MIN_DP,
                           MQ=row.info['MQ'],
                           MQRankSum=row.info['MQRankSum'],
                           PID=e.PID,
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                           ReadPosRankSum=row.info['ReadPosRankSum'],
                       ))),
            ), mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(
        TableMapRows(mt._tir,
                     Apply(transform_row._name, TopLevelReference('row'))))
Example #38
File: doctest_write_data.py Project: zscu/hail
ds = hl.import_vcf('data/sample.vcf.bgz')
ds = ds.sample_rows(0.03)
ds = ds.annotate_rows(use_as_marker=hl.rand_bool(0.5),
                      panel_maf=0.1,
                      anno1=5,
                      anno2=0,
                      consequence="LOF",
                      gene="A",
                      score=5.0)
ds = ds.annotate_rows(a_index=1)
ds = hl.sample_qc(hl.variant_qc(ds))
ds = ds.annotate_cols(is_case=True,
                      pheno=hl.struct(is_case=hl.rand_bool(0.5),
                                      is_female=hl.rand_bool(0.5),
                                      age=hl.rand_norm(65, 10),
                                      height=hl.rand_norm(70, 10),
                                      blood_pressure=hl.rand_norm(120, 20),
                                      cohort_name="cohort1"),
                      cov=hl.struct(PC1=hl.rand_norm(0, 1)),
                      cov1=hl.rand_norm(0, 1),
                      cov2=hl.rand_norm(0, 1),
                      cohort="SIGMA")
ds = ds.annotate_globals(
    global_field_1=5,
    global_field_2=10,
    pli={
        'SCN1A': 0.999,
        'SONIC': 0.014
    },
    populations=['AFR', 'EAS', 'EUR', 'SAS', 'AMR', 'HIS'])
ds = ds.annotate_rows(gene=['TTN'])
Example #39
def compute_stratified_metrics_filter(
    ht: hl.Table,
    qc_metrics: Dict[str, hl.expr.NumericExpression],
    strata: Optional[Dict[str, hl.expr.Expression]] = None,
    lower_threshold: float = 4.0,
    upper_threshold: float = 4.0,
    metric_threshold: Optional[Dict[str, Tuple[float, float]]] = None,
    filter_name: str = "qc_metrics_filters",
) -> hl.Table:
    """
    Compute median, MAD, and upper and lower thresholds for each metric used in outlier filtering.

    :param ht: HT containing relevant sample QC metric annotations
    :param qc_metrics: Dictionary mapping metric names to expressions for which to compute the critical outlier-filtering values
    :param strata: Dictionary of annotations used for stratification. These annotations should be discrete types!
    :param lower_threshold: Lower MAD threshold
    :param upper_threshold: Upper MAD threshold
    :param metric_threshold: Can be used to specify different (lower, upper) thresholds for one or more metrics
    :param filter_name: Name of resulting filters annotation
    :return: Table grouped by strata, with upper and lower threshold values computed for each sample QC metric
    """
    _metric_threshold = {
        metric: (lower_threshold, upper_threshold)
        for metric in qc_metrics
    }
    if metric_threshold is not None:
        _metric_threshold.update(metric_threshold)

    def make_filters_expr(ht: hl.Table,
                          qc_metrics: Iterable[str]) -> hl.expr.SetExpression:
        return hl.set(
            hl.filter(
                lambda x: hl.is_defined(x),
                [
                    hl.or_missing(ht[f"fail_{metric}"], metric)
                    for metric in qc_metrics
                ],
            ))

    if strata is None:
        strata = {}

    ht = ht.select(**qc_metrics, **strata).key_by("s").persist()

    agg_expr = hl.struct(
        **{
            metric: hl.bind(
                lambda x: x.annotate(
                    lower=x.median - _metric_threshold[metric][0] * x.mad,
                    upper=x.median + _metric_threshold[metric][1] * x.mad,
                ),
                get_median_and_mad_expr(ht[metric]),
            )
            for metric in qc_metrics
        })

    if strata:
        ht = ht.annotate_globals(qc_metrics_stats=ht.aggregate(
            hl.agg.group_by(hl.tuple([ht[x] for x in strata]), agg_expr),
            _localize=False,
        ))
        metrics_stats_expr = ht.qc_metrics_stats[hl.tuple(
            [ht[x] for x in strata])]
    else:
        ht = ht.annotate_globals(
            qc_metrics_stats=ht.aggregate(agg_expr, _localize=False))
        metrics_stats_expr = ht.qc_metrics_stats

    fail_exprs = {
        f"fail_{metric}": (ht[metric] <= metrics_stats_expr[metric].lower)
        | (ht[metric] >= metrics_stats_expr[metric].upper)
        for metric in qc_metrics
    }
    ht = ht.transmute(**fail_exprs)
    stratified_filters = make_filters_expr(ht, qc_metrics)
    return ht.annotate(**{filter_name: stratified_filters})
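A short usage sketch; the input table and its metric and strata fields are assumptions:

    import hail as hl

    ht = sample_qc_ht  # a table keyed by sample with per-sample QC metrics (hypothetical)
    ht = compute_stratified_metrics_filter(
        ht,
        qc_metrics={'call_rate': ht.call_rate, 'dp_mean': ht.dp_stats.mean},
        strata={'platform': ht.platform},
        lower_threshold=4.0,
        upper_threshold=4.0)
    # samples failing any metric get that metric's name in `qc_metrics_filters`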
Example #40
def variant_qc(mt, name='variant_qc') -> MatrixTable:
    """Compute common variant statistics (quality control metrics).

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    >>> dataset_result = hl.variant_qc(dataset)

    Notes
    -----
    This method computes variant statistics from the genotype data, returning
    a new struct field `name` with the following metrics based on the fields
    present in the entry schema.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `AF` (``array<float64>``) -- Calculated allele frequency, one element
      per allele, including the reference. Sums to one. Equivalent to
      `AC` / `AN`.
    - `AC` (``array<int32>``) -- Calculated allele count, one element per
      allele, including the reference. Sums to `AN`.
    - `AN` (``int32``) -- Total number of called alleles.
    - `homozygote_count` (``array<int32>``) -- Number of homozygotes per
      allele. One element per allele, including the reference.
    - `call_rate` (``float64``) -- Fraction of calls neither missing nor filtered.
      Equivalent to `n_called` / :meth:`.count_cols`.
    - `n_called` (``int64``) -- Number of samples with a defined `GT`.
    - `n_not_called` (``int64``) -- Number of samples with a missing `GT`.
    - `n_filtered` (``int64``) -- Number of filtered entries.
    - `n_het` (``int64``) -- Number of heterozygous samples.
    - `n_non_ref` (``int64``) -- Number of samples with at least one called
      non-reference allele.
    - `het_freq_hwe` (``float64``) -- Expected frequency of heterozygous
      samples under Hardy-Weinberg equilibrium. See
      :func:`.functions.hardy_weinberg_test` for details.
    - `p_value_hwe` (``float64``) -- p-value from test of Hardy-Weinberg equilibrium.
      See :func:`.functions.hardy_weinberg_test` for details.

    Warning
    -------
    `het_freq_hwe` and `p_value_hwe` are calculated as in
    :func:`.functions.hardy_weinberg_test`, with non-diploid calls
    (``ploidy != 2``) ignored in the counts. As this test is only
    statistically rigorous in the biallelic setting, :func:`.variant_qc`
    sets both fields to missing for multiallelic variants. Consider using
    :func:`~hail.methods.split_multi` to split multi-allelic variants beforehand.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
    """
    require_row_key_variant(mt, 'variant_qc')

    bound_exprs = {}
    gq_dp_exprs = {}

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    if has_field_of_type('DP', hl.tint32):
        gq_dp_exprs['dp_stats'] = hl.agg.stats(mt.DP).select(
            'mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        gq_dp_exprs['gq_stats'] = hl.agg.stats(mt.GQ).select(
            'mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError(
            "'variant_qc': expect an entry field 'GT' of type 'call'")

    bound_exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    bound_exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT']))
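    # Pull the matrix table's column count ('n_cols') into the row-aggregation
    # scope via the expression IR, so n_filtered can be computed as the column
    # count minus the number of unfiltered entries.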
    n_cols_ref = hl.expr.construct_expr(
        hl.ir.Ref('n_cols'), hl.tint32, mt._row_indices,
        hl.utils.LinkedList(hl.expr.expressions.Aggregation))
    bound_exprs['n_filtered'] = hl.int64(n_cols_ref) - hl.agg.count()
    bound_exprs['call_stats'] = hl.agg.call_stats(mt.GT, mt.alleles)

    result = hl.rbind(
        hl.struct(**bound_exprs), lambda e1: hl.rbind(
            hl.case().when(
                hl.len(mt.alleles) == 2,
                hl.hardy_weinberg_test(
                    e1.call_stats.homozygote_count[0], e1.call_stats.AC[
                        1] - 2 * e1.call_stats.homozygote_count[1], e1.
                    call_stats.homozygote_count[1])).or_missing(), lambda hwe:
            hl.struct(
                **{
                    **gq_dp_exprs,
                    **e1.call_stats, 'call_rate':
                    hl.float(e1.n_called) /
                    (e1.n_called + e1.n_not_called + e1.n_filtered),
                    'n_called':
                    e1.n_called,
                    'n_not_called':
                    e1.n_not_called,
                    'n_filtered':
                    e1.n_filtered,
                    'n_het':
                    e1.n_called - hl.sum(e1.call_stats.homozygote_count),
                    'n_non_ref':
                    e1.n_called - e1.call_stats.homozygote_count[0],
                    'het_freq_hwe':
                    hwe.het_freq_hwe,
                    'p_value_hwe':
                    hwe.p_value
                })))

    return mt.annotate_rows(**{name: result})
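
# A short follow-up sketch (assumed `dataset`, mirroring the docstring example):
# keep variants that pass call-rate and Hardy-Weinberg thresholds. Multiallelic
# variants have a missing `p_value_hwe` and are removed by this filter.
dataset_result = hl.variant_qc(dataset)
dataset_result = dataset_result.filter_rows(
    (dataset_result.variant_qc.call_rate > 0.95)
    & (dataset_result.variant_qc.p_value_hwe > 1e-6)
)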
Code example #41
def sample_qc(mt, name='sample_qc') -> MatrixTable:
    """Compute per-sample metrics useful for quality control.

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    Compute sample QC metrics and remove low-quality samples:

    >>> dataset = hl.sample_qc(dataset, name='sample_qc')
    >>> filtered_dataset = dataset.filter_cols((dataset.sample_qc.dp_stats.mean > 20) & (dataset.sample_qc.r_ti_tv > 1.5))

    Notes
    -----

    This method computes summary statistics per sample from a genetic matrix and stores
    the results as a new column-indexed struct field in the matrix, named based on the
    `name` parameter.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `call_rate` (``float64``) -- Fraction of calls not missing or filtered.
      Equivalent to `n_called` divided by :meth:`.count_rows`.
    - `n_called` (``int64``) -- Number of non-missing calls.
    - `n_not_called` (``int64``) -- Number of missing calls.
    - `n_filtered` (``int64``) -- Number of filtered entries.
    - `n_hom_ref` (``int64``) -- Number of homozygous reference calls.
    - `n_het` (``int64``) -- Number of heterozygous calls.
    - `n_hom_var` (``int64``) -- Number of homozygous alternate calls.
    - `n_non_ref` (``int64``) -- Sum of `n_het` and `n_hom_var`.
    - `n_snp` (``int64``) -- Number of SNP alternate alleles.
    - `n_insertion` (``int64``) -- Number of insertion alternate alleles.
    - `n_deletion` (``int64``) -- Number of deletion alternate alleles.
    - `n_singleton` (``int64``) -- Number of private alleles.
    - `n_transition` (``int64``) -- Number of transition (A-G, C-T) alternate alleles.
    - `n_transversion` (``int64``) -- Number of transversion alternate alleles.
    - `n_star` (``int64``) -- Number of star (upstream deletion) alleles.
    - `r_ti_tv` (``float64``) -- Transition/Transversion ratio.
    - `r_het_hom_var` (``float64``) -- Het/HomVar call ratio.
    - `r_insertion_deletion` (``float64``) -- Insertion/Deletion allele ratio.

    Missing values ``NA`` may result from division by zero.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
        Dataset with a new column-indexed field `name`.
    """

    require_row_key_variant(mt, 'sample_qc')

    from hail.expr.functions import _num_allele_type, _allele_types

    allele_types = _allele_types[:]
    allele_types.extend(['Transition', 'Transversion'])
    allele_enum = {i: v for i, v in enumerate(allele_types)}
    allele_ints = {v: k for k, v in allele_enum.items()}

    def allele_type(ref, alt):
        return hl.bind(
            lambda at: hl.cond(
                at == allele_ints['SNP'],
                hl.cond(hl.is_transition(ref, alt), allele_ints['Transition'],
                        allele_ints['Transversion']), at),
            _num_allele_type(ref, alt))

    variant_ac = Env.get_uid()
    variant_atypes = Env.get_uid()
    mt = mt.annotate_rows(
        **{
            variant_ac:
            hl.agg.call_stats(mt.GT, mt.alleles).AC,
            variant_atypes:
            mt.alleles[1:].map(lambda alt: allele_type(mt.alleles[0], alt))
        })

    bound_exprs = {}
    gq_dp_exprs = {}

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    if has_field_of_type('DP', hl.tint32):
        gq_dp_exprs['dp_stats'] = hl.agg.stats(mt.DP).select(
            'mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        gq_dp_exprs['gq_stats'] = hl.agg.stats(mt.GQ).select(
            'mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError(
            "'sample_qc': expect an entry field 'GT' of type 'call'")

    bound_exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    bound_exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT']))

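    # Pull the matrix table's row count ('n_rows') into the column-aggregation
    # scope via the expression IR, mirroring the 'n_cols' trick in variant_qc.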
    n_rows_ref = hl.expr.construct_expr(
        hl.ir.Ref('n_rows'), hl.tint64, mt._col_indices,
        hl.utils.LinkedList(hl.expr.expressions.Aggregation))
    bound_exprs['n_filtered'] = n_rows_ref - hl.agg.count()
    bound_exprs['n_hom_ref'] = hl.agg.count_where(mt['GT'].is_hom_ref())
    bound_exprs['n_het'] = hl.agg.count_where(mt['GT'].is_het())
    bound_exprs['n_singleton'] = hl.agg.sum(
        hl.sum(
            hl.range(0, mt['GT'].ploidy).map(
                lambda i: mt[variant_ac][mt['GT'][i]] == 1)))

    def get_allele_type(allele_idx):
        return hl.cond(allele_idx > 0, mt[variant_atypes][allele_idx - 1],
                       hl.null(hl.tint32))

    bound_exprs['allele_type_counts'] = hl.agg.explode(
        lambda elt: hl.agg.counter(elt),
        hl.range(0,
                 mt['GT'].ploidy).map(lambda i: get_allele_type(mt['GT'][i])))

    zero = hl.int64(0)

    result_struct = hl.rbind(
        hl.struct(**bound_exprs), lambda x: hl.rbind(
            hl.struct(
                **{
                    **gq_dp_exprs, 'call_rate':
                    hl.float64(x.n_called) /
                    (x.n_called + x.n_not_called + x.n_filtered),
                    'n_called':
                    x.n_called,
                    'n_not_called':
                    x.n_not_called,
                    'n_filtered':
                    x.n_filtered,
                    'n_hom_ref':
                    x.n_hom_ref,
                    'n_het':
                    x.n_het,
                    'n_hom_var':
                    x.n_called - x.n_hom_ref - x.n_het,
                    'n_non_ref':
                    x.n_called - x.n_hom_ref,
                    'n_singleton':
                    x.n_singleton,
                    'n_snp': (x.allele_type_counts.get(
                        allele_ints["Transition"], zero) + x.allele_type_counts
                              .get(allele_ints["Transversion"], zero)),
                    'n_insertion':
                    x.allele_type_counts.get(allele_ints["Insertion"], zero),
                    'n_deletion':
                    x.allele_type_counts.get(allele_ints["Deletion"], zero),
                    'n_transition':
                    x.allele_type_counts.get(allele_ints["Transition"], zero),
                    'n_transversion':
                    x.allele_type_counts.get(allele_ints["Transversion"], zero
                                             ),
                    'n_star':
                    x.allele_type_counts.get(allele_ints["Star"], zero)
                }), lambda s: s.annotate(r_ti_tv=divide_null(
                    hl.float64(s.n_transition), s.n_transversion),
                                         r_het_hom_var=divide_null(
                                             hl.float64(s.n_het), s.n_hom_var),
                                         r_insertion_deletion=divide_null(
                                             hl.float64(s.n_insertion), s.
                                             n_deletion))))

    mt = mt.annotate_cols(**{name: result_struct})
    mt = mt.drop(variant_ac, variant_atypes)

    return mt
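
# An illustrative follow-up (assumed `dataset`): inspect the new column field
# and drop samples with a low call rate or an excess of singletons. Thresholds
# here are placeholders, not recommendations.
dataset = hl.sample_qc(dataset)
dataset.sample_qc.describe()  # shows the struct schema documented above
dataset = dataset.filter_cols(
    (dataset.sample_qc.call_rate > 0.97) & (dataset.sample_qc.n_singleton < 500)
)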
Code example #42
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-url",
        help="URL of ExAC sites VCF",
        default=
        "gs://gnomad-public/legacy/exac_browser/ExAC.r1.sites.vep.vcf.gz")
    parser.add_argument("--output-url",
                        help="URL to write Hail table to",
                        required=True)
    parser.add_argument("--subset",
                        help="Filter variants to this chrom:start-end range")
    args = parser.parse_args()

    hl.init(log="/tmp/hail.log")

    print("\n=== Importing VCF ===")

    ds = hl.import_vcf(args.input_url,
                       force_bgz=True,
                       min_partitions=2000,
                       skip_invalid_loci=True).rows()

    if args.subset:
        print(f"\n=== Filtering to interval {args.subset} ===")
        subset_interval = hl.parse_locus_interval(args.subset)
        ds = ds.filter(subset_interval.contains(ds.locus))

    print("\n=== Splitting multiallelic variants ===")

    ds = hl.split_multi(ds)

    ds = ds.repartition(2000, shuffle=True)

    # Get value corresponding to the split variant
    ds = ds.annotate(info=ds.info.annotate(
        **{
            field: hl.or_missing(hl.is_defined(ds.info[field]), ds.info[field][
                ds.a_index - 1])
            for field in PER_ALLELE_FIELDS
        }))

    # For DP_HIST and GQ_HIST, the first value in the array contains the histogram for all individuals,
    # which is the same in each alt allele's variant.
    ds = ds.annotate(info=ds.info.annotate(
        DP_HIST=hl.struct(all=ds.info.DP_HIST[0],
                          alt=ds.info.DP_HIST[ds.a_index]),
        GQ_HIST=hl.struct(all=ds.info.GQ_HIST[0],
                          alt=ds.info.GQ_HIST[ds.a_index]),
    ))

    ds = ds.cache()

    print("\n=== Munging data ===")

    # Convert "NA" and empty strings into null values
    # Convert fields in chunks to avoid "Method code too large" errors
    for i in range(0, len(SELECT_INFO_FIELDS), 10):
        ds = ds.annotate(info=ds.info.annotate(
            **{
                field: hl.or_missing(
                    hl.is_defined(ds.info[field]),
                    hl.bind(
                        lambda value: hl.cond(
                            (value == "") | (value == "NA"),
                            hl.null(ds.info[field].dtype), ds.info[field]),
                        hl.str(ds.info[field]),
                    ),
                )
                for field in SELECT_INFO_FIELDS[i:i + 10]
            }))

    # Convert field types
    ds = ds.annotate(info=ds.info.annotate(
        **{
            field: hl.cond(ds.info[field] == "", hl.null(hl.tint),
                           hl.int(ds.info[field]))
            for field in CONVERT_TO_INT_FIELDS
        }))
    ds = ds.annotate(info=ds.info.annotate(
        **{
            field: hl.cond(ds.info[field] == "", hl.null(hl.tfloat),
                           hl.float(ds.info[field]))
            for field in CONVERT_TO_FLOAT_FIELDS
        }))

    # Format VEP annotations to mimic the output of hail.vep
    ds = ds.annotate(info=ds.info.annotate(CSQ=ds.info.CSQ.map(
        lambda s: s.replace("%3A", ":").replace("%3B", ";").replace(
            "%3D", "=").replace("%25", "%").replace("%2C", ","))))
    ds = ds.annotate(vep=hl.struct(
        transcript_consequences=ds.info.CSQ.map(lambda csq_str: hl.bind(
            lambda csq_values: hl.struct(
                **{
                    field: hl.cond(csq_values[index] == "", hl.null(hl.tstr),
                                   csq_values[index])
                    for index, field in enumerate(VEP_FIELDS)
                }),
            csq_str.split("\\|"),
        )).filter(lambda annotation: annotation.Feature.startswith("ENST")).
        filter(lambda annotation: hl.int(annotation.ALLELE_NUM) == ds.a_index).
        map(lambda annotation: annotation.select(
            amino_acids=annotation.Amino_acids,
            biotype=annotation.BIOTYPE,
            canonical=annotation.CANONICAL == "YES",
            # cDNA_position may contain either "start-end" or, when start == end, "start"
            cdna_start=split_position_start(annotation.cDNA_position),
            cdna_end=split_position_end(annotation.cDNA_position),
            codons=annotation.Codons,
            consequence_terms=annotation.Consequence.split("&"),
            distance=hl.int(annotation.DISTANCE),
            domains=hl.or_missing(
                hl.is_defined(annotation.DOMAINS),
                annotation.DOMAINS.split("&").map(lambda d: hl.struct(
                    db=d.split(":")[0], name=d.split(":")[1])),
            ),
            exon=annotation.EXON,
            gene_id=annotation.Gene,
            gene_symbol=annotation.SYMBOL,
            gene_symbol_source=annotation.SYMBOL_SOURCE,
            hgnc_id=annotation.HGNC_ID,
            hgvsc=annotation.HGVSc,
            hgvsp=annotation.HGVSp,
            lof=annotation.LoF,
            lof_filter=annotation.LoF_filter,
            lof_flags=annotation.LoF_flags,
            lof_info=annotation.LoF_info,
            # PolyPhen field contains "polyphen_prediction(polyphen_score)"
            polyphen_prediction=hl.or_missing(
                hl.is_defined(annotation.PolyPhen),
                annotation.PolyPhen.split("\\(")[0]),
            protein_id=annotation.ENSP,
            # Protein_position may contain either "start-end" or, when start == end, "start"
            protein_start=split_position_start(annotation.Protein_position),
            protein_end=split_position_end(annotation.Protein_position),
            # SIFT field contains "sift_prediction(sift_score)"
            sift_prediction=hl.or_missing(hl.is_defined(annotation.SIFT),
                                          annotation.SIFT.split("\\(")[0]),
            transcript_id=annotation.Feature,
        ))))

    ds = ds.annotate(vep=ds.vep.annotate(most_severe_consequence=hl.bind(
        lambda all_consequence_terms: hl.or_missing(
            all_consequence_terms.size() != 0,
            hl.sorted(all_consequence_terms, key=consequence_term_rank)[0]),
        ds.vep.transcript_consequences.flatmap(lambda c: c.consequence_terms),
    )))

    ds = ds.cache()

    print("\n=== Adding derived fields ===")

    ds = ds.annotate(
        sorted_transcript_consequences=sorted_transcript_consequences_v3(
            ds.vep))

    ds = ds.select(
        "filters",
        "qual",
        "rsid",
        "sorted_transcript_consequences",
        AC=ds.info.AC,
        AC_Adj=ds.info.AC_Adj,
        AC_Hemi=ds.info.AC_Hemi,
        AC_Hom=ds.info.AC_Hom,
        AF=ds.info.AF,
        AN=ds.info.AN,
        AN_Adj=ds.info.AN_Adj,
        BaseQRankSum=ds.info.BaseQRankSum,
        CCC=ds.info.CCC,
        ClippingRankSum=ds.info.ClippingRankSum,
        DB=ds.info.DB,
        DP=ds.info.DP,
        DS=ds.info.DS,
        END=ds.info.END,
        FS=ds.info.FS,
        GQ_MEAN=ds.info.GQ_MEAN,
        GQ_STDDEV=ds.info.GQ_STDDEV,
        HWP=ds.info.HWP,
        HaplotypeScore=ds.info.HaplotypeScore,
        InbreedingCoeff=ds.info.InbreedingCoeff,
        MLEAC=ds.info.MLEAC,
        MLEAF=ds.info.MLEAF,
        MQ=ds.info.MQ,
        MQ0=ds.info.MQ0,
        MQRankSum=ds.info.MQRankSum,
        NCC=ds.info.NCC,
        NEGATIVE_TRAIN_SITE=ds.info.NEGATIVE_TRAIN_SITE,
        POSITIVE_TRAIN_SITE=ds.info.POSITIVE_TRAIN_SITE,
        QD=ds.info.QD,
        ReadPosRankSum=ds.info.ReadPosRankSum,
        VQSLOD=ds.info.VQSLOD,
        culprit=ds.info.culprit,
        DP_HIST=ds.info.DP_HIST,
        GQ_HIST=ds.info.GQ_HIST,
        DOUBLETON_DIST=ds.info.DOUBLETON_DIST,
        AC_CONSANGUINEOUS=ds.info.AC_CONSANGUINEOUS,
        AN_CONSANGUINEOUS=ds.info.AN_CONSANGUINEOUS,
        Hom_CONSANGUINEOUS=ds.info.Hom_CONSANGUINEOUS,
        AGE_HISTOGRAM_HET=ds.info.AGE_HISTOGRAM_HET,
        AGE_HISTOGRAM_HOM=ds.info.AGE_HISTOGRAM_HOM,
        AC_POPMAX=ds.info.AC_POPMAX,
        AN_POPMAX=ds.info.AN_POPMAX,
        POPMAX=ds.info.POPMAX,
        K1_RUN=ds.info.K1_RUN,
        K2_RUN=ds.info.K2_RUN,
        K3_RUN=ds.info.K3_RUN,
        ESP_AF_POPMAX=ds.info.ESP_AF_POPMAX,
        ESP_AF_GLOBAL=ds.info.ESP_AF_GLOBAL,
        ESP_AC=ds.info.ESP_AC,
        KG_AF_POPMAX=ds.info.KG_AF_POPMAX,
        KG_AF_GLOBAL=ds.info.KG_AF_GLOBAL,
        KG_AC=ds.info.KG_AC,
        AC_FEMALE=ds.info.AC_FEMALE,
        AN_FEMALE=ds.info.AN_FEMALE,
        AC_MALE=ds.info.AC_MALE,
        AN_MALE=ds.info.AN_MALE,
        populations=hl.struct(
            **{
                pop_id: hl.struct(
                    AC=ds.info[f"AC_{pop_id}"],
                    AN=ds.info[f"AN_{pop_id}"],
                    hemi=ds.info[f"Hemi_{pop_id}"],
                    hom=ds.info[f"Hom_{pop_id}"],
                )
                for pop_id in
                ["AFR", "AMR", "EAS", "FIN", "NFE", "OTH", "SAS"]
            }),
        colocated_variants=hl.bind(
            lambda this_variant_id: variant_ids(ds.old_locus, ds.old_alleles).
            filter(lambda v_id: v_id != this_variant_id),
            variant_id(ds.locus, ds.alleles),
        ),
        variant_id=variant_id(ds.locus, ds.alleles),
        xpos=x_position(ds.locus),
    )

    print("\n=== Writing table ===")

    ds.write(args.output_url)
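
# `split_position_start` and `split_position_end` are called above but not shown
# in this snippet. A plausible sketch, assuming only the "start-end" / bare
# "start" formats noted in the comments (the real helpers may handle more cases):
def split_position_start(position):
    # First component of "start-end", or the whole string when there is no "-".
    return hl.or_missing(
        hl.is_defined(position) & (position != ""),
        hl.int(position.split("-")[0]),
    )


def split_position_end(position):
    # Last component, which equals the start when no "-" is present.
    return hl.or_missing(
        hl.is_defined(position) & (position != ""),
        hl.bind(lambda parts: hl.int(parts[hl.len(parts) - 1]),
                position.split("-")),
    )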
Code example #43
def ld_score_regression(weight_expr,
                        ld_score_expr,
                        chi_sq_exprs,
                        n_samples_exprs,
                        n_blocks=200,
                        two_step_threshold=30,
                        n_reference_panel_variants=None) -> Table:
    r"""Estimate SNP-heritability and level of confounding biases from
    GWAS summary statistics.

    Given a set or multiple sets of genome-wide association study (GWAS)
    summary statistics, :func:`.ld_score_regression` estimates the heritability
    of a trait or set of traits and the level of confounding biases present in
    the underlying studies by regressing chi-squared statistics on LD scores,
    leveraging the model:

    .. math::

        \mathrm{E}[\chi_j^2] = 1 + Na + \frac{Nh_g^2}{M}l_j

    *  :math:`\mathrm{E}[\chi_j^2]` is the expected chi-squared statistic
       for variant :math:`j` resulting from a test of association between
       variant :math:`j` and a trait.
    *  :math:`l_j = \sum_{k} r_{jk}^2` is the LD score of variant
       :math:`j`, calculated as the sum of squared correlation coefficients
       between variant :math:`j` and nearby variants. See :func:`ld_score`
       for further details.
    *  :math:`a` captures the contribution of confounding biases, such as
       cryptic relatedness and uncontrolled population structure, to the
       association test statistic.
    *  :math:`h_g^2` is the SNP-heritability, or the proportion of variation
       in the trait explained by the effects of variants included in the
       regression model above.
    *  :math:`M` is the number of variants used to estimate :math:`h_g^2`.
    *  :math:`N` is the number of samples in the underlying association study.

    For more details on the method implemented in this function, see:

    * `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__

    Examples
    --------

    Run the method on a matrix table of summary statistics, where the rows
    are variants and the columns are different phenotypes:

    >>> mt_gwas = ld_score_all_phenos_sumstats
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=mt_gwas['ld_score'],
    ...     ld_score_expr=mt_gwas['ld_score'],
    ...     chi_sq_exprs=mt_gwas['chi_squared'],
    ...     n_samples_exprs=mt_gwas['n'])


    Run the method on a table with summary statistics for a single
    phenotype:

    >>> ht_gwas = ld_score_one_pheno_sumstats
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=ht_gwas['chi_squared_50_irnt'],
    ...     n_samples_exprs=ht_gwas['n_50_irnt'])

    Run the method on a table with summary statistics for multiple
    phenotypes:

    >>> ht_gwas = ld_score_one_pheno_sumstats
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=[ht_gwas['chi_squared_50_irnt'],
    ...                        ht_gwas['chi_squared_20160']],
    ...     n_samples_exprs=[ht_gwas['n_50_irnt'],
    ...                      ht_gwas['n_20160']])

    Notes
    -----
    The ``exprs`` provided as arguments to :func:`.ld_score_regression`
    must all be from the same object, either a :class:`Table` or a
    :class:`MatrixTable`.

    **If the arguments originate from a table:**

    *  The table must be keyed by fields ``locus`` of type
       :class:`.tlocus` and ``alleles``, a :py:data:`.tarray` of
       :py:data:`.tstr` elements.
    *  ``weight_expr``, ``ld_score_expr``, ``chi_sq_exprs``, and
       ``n_samples_exprs`` must be row-indexed fields.
    *  The number of expressions passed to ``n_samples_exprs`` must be
       equal to one or the number of expressions passed to
       ``chi_sq_exprs``. If just one expression is passed to
       ``n_samples_exprs``, that sample size expression is assumed to
       apply to all sets of statistics passed to ``chi_sq_exprs``.
       Otherwise, the expressions passed to ``chi_sq_exprs`` and
       ``n_samples_exprs`` are matched by index.
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have generic :obj:`int` values
       ``0``, ``1``, etc. corresponding to the ``0th``, ``1st``, etc.
       expressions passed to the ``chi_sq_exprs`` argument.

    **If the arguments originate from a matrix table:**

    *  The dimensions of the matrix table must be variants
       (rows) by phenotypes (columns).
    *  The rows of the matrix table must be keyed by fields
       ``locus`` of type :class:`.tlocus` and ``alleles``,
       a :py:data:`.tarray` of :py:data:`.tstr` elements.
    *  The columns of the matrix table must be keyed by a field
       of type :py:data:`.tstr` that uniquely identifies phenotypes
       represented in the matrix table. The column key must be a single
       expression; compound keys are not accepted.
    *  ``weight_expr`` and ``ld_score_expr`` must be row-indexed
       fields.
    *  ``chi_sq_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  ``n_samples_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have values corresponding to the
       column keys of the input matrix table.

    This function returns a :class:`Table` with one row per set of summary
    statistics passed to the ``chi_sq_exprs`` argument. The following
    row-indexed fields are included in the table:

    *  **phenotype** (:py:data:`.tstr`) -- The name of the phenotype. The
       returned table is keyed by this field. See the notes below for
       details on the possible values of this field.
    *  **mean_chi_sq** (:py:data:`.tfloat64`) -- The mean chi-squared
       test statistic for the given phenotype.
    *  **intercept** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          intercept :math:`1 + Na`.
       -  **standard_error**  (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    *  **snp_heritability** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          SNP-heritability :math:`h_g^2`.
       -  **standard_error** (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    Warning
    -------
    :func:`.ld_score_regression` considers only the rows for which both row
    fields ``weight_expr`` and ``ld_score_expr`` are defined. Rows with missing
    values in either field are removed prior to fitting the LD score
    regression model.

    Parameters
    ----------
    weight_expr : :class:`.Float64Expression`
                  Row-indexed expression for the LD scores used to derive
                  variant weights in the model.
    ld_score_expr : :class:`.Float64Expression`
                    Row-indexed expression for the LD scores used as covariates
                    in the model.
    chi_sq_exprs : :class:`.Float64Expression` or :obj:`list` of
                        :class:`.Float64Expression`
                        One or more row-indexed (if table) or entry-indexed
                        (if matrix table) expressions for chi-squared
                        statistics resulting from genome-wide association
                        studies.
    n_samples_exprs : :class:`.NumericExpression` or :obj:`list` of
                     :class:`.NumericExpression`
                     One or more row-indexed (if table) or entry-indexed
                     (if matrix table) expressions indicating the number of
                     samples used in the studies that generated the test
                     statistics supplied to ``chi_sq_exprs``.
    n_blocks : :obj:`int`
               The number of blocks used in the jackknife approach to
               estimating standard errors.
    two_step_threshold : :obj:`int`
                         Variants with chi-squared statistics greater than this
                         value are excluded in the first step of the two-step
                         procedure used to fit the model.
    n_reference_panel_variants : :obj:`int`, optional
                                 Number of variants used to estimate the
                                 SNP-heritability :math:`h_g^2`.

    Returns
    -------
    :class:`.Table`
        Table keyed by ``phenotype`` with intercept and heritability estimates
        for each phenotype passed to the function."""

    chi_sq_exprs = wrap_to_list(chi_sq_exprs)
    n_samples_exprs = wrap_to_list(n_samples_exprs)

    assert ((len(chi_sq_exprs) == len(n_samples_exprs))
            or (len(n_samples_exprs) == 1))
    __k = 2  # number of covariates, including intercept

    ds = chi_sq_exprs[0]._indices.source

    analyze('ld_score_regression/weight_expr', weight_expr, ds._row_indices)
    analyze('ld_score_regression/ld_score_expr', ld_score_expr,
            ds._row_indices)

    # format input dataset
    if isinstance(ds, MatrixTable):
        if len(chi_sq_exprs) != 1:
            raise ValueError("""Only one chi_sq_expr allowed if originating
                from a matrix table.""")
        if len(n_samples_exprs) != 1:
            raise ValueError("""Only one n_samples_expr allowed if
                originating from a matrix table.""")

        col_key = list(ds.col_key)
        if len(col_key) != 1:
            raise ValueError("""Matrix table must be keyed by a single
                phenotype field.""")

        analyze('ld_score_regression/chi_squared_expr', chi_sq_exprs[0],
                ds._entry_indices)
        analyze('ld_score_regression/n_samples_expr', n_samples_exprs[0],
                ds._entry_indices)

        ds = ds._select_all(row_exprs={
            '__locus': ds.locus,
            '__alleles': ds.alleles,
            '__w_initial': weight_expr,
            '__w_initial_floor': hl.max(weight_expr, 1.0),
            '__x': ld_score_expr,
            '__x_floor': hl.max(ld_score_expr, 1.0)
        },
                            row_key=['__locus', '__alleles'],
                            col_exprs={'__y_name': ds[col_key[0]]},
                            col_key=['__y_name'],
                            entry_exprs={
                                '__y': chi_sq_exprs[0],
                                '__n': n_samples_exprs[0]
                            })
        ds = ds.annotate_entries(**{'__w': ds.__w_initial})

        ds = ds.filter_rows(
            hl.is_defined(ds.__locus) & hl.is_defined(ds.__alleles)
            & hl.is_defined(ds.__w_initial) & hl.is_defined(ds.__x))

    else:
        assert isinstance(ds, Table)
        for y in chi_sq_exprs:
            analyze('ld_score_regression/chi_squared_expr', y, ds._row_indices)
        for n in n_samples_exprs:
            analyze('ld_score_regression/n_samples_expr', n, ds._row_indices)

        ys = ['__y{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ws = ['__w{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ns = ['__n{:}'.format(i) for i, _ in enumerate(n_samples_exprs)]

        ds = ds.select(**dict(
            **{
                '__locus': ds.locus,
                '__alleles': ds.alleles,
                '__w_initial': weight_expr,
                '__x': ld_score_expr
            }, **{y: chi_sq_exprs[i]
                  for i, y in enumerate(ys)}, **{w: weight_expr
                                                 for w in ws}, **
            {n: n_samples_exprs[i]
             for i, n in enumerate(ns)}))
        ds = ds.key_by(ds.__locus, ds.__alleles)

        table_tmp_file = new_temp_file()
        ds.write(table_tmp_file)
        ds = hl.read_table(table_tmp_file)

        hts = [
            ds.select(
                **{
                    '__w_initial': ds.__w_initial,
                    '__w_initial_floor': hl.max(ds.__w_initial, 1.0),
                    '__x': ds.__x,
                    '__x_floor': hl.max(ds.__x, 1.0),
                    '__y_name': i,
                    '__y': ds[ys[i]],
                    '__w': ds[ws[i]],
                    '__n': hl.int(ds[ns[i]])
                }) for i, y in enumerate(ys)
        ]

        mts = [
            ht.to_matrix_table(row_key=['__locus', '__alleles'],
                               col_key=['__y_name'],
                               row_fields=[
                                   '__w_initial', '__w_initial_floor', '__x',
                                   '__x_floor'
                               ]) for ht in hts
        ]

        ds = mts[0]
        for i in range(1, len(ys)):
            ds = ds.union_cols(mts[i])

        ds = ds.filter_rows(
            hl.is_defined(ds.__locus) & hl.is_defined(ds.__alleles)
            & hl.is_defined(ds.__w_initial) & hl.is_defined(ds.__x))

    mt_tmp_file1 = new_temp_file()
    ds.write(mt_tmp_file1)
    mt = hl.read_matrix_table(mt_tmp_file1)

    if not n_reference_panel_variants:
        M = mt.count_rows()
    else:
        M = n_reference_panel_variants

    # block variants for each phenotype
    n_phenotypes = mt.count_cols()

    mt = mt.annotate_entries(__in_step1=(hl.is_defined(mt.__y) &
                                         (mt.__y < two_step_threshold)),
                             __in_step2=hl.is_defined(mt.__y))

    mt = mt.annotate_cols(__col_idx=hl.int(hl.scan.count()),
                          __m_step1=hl.agg.count_where(mt.__in_step1),
                          __m_step2=hl.agg.count_where(mt.__in_step2))

    col_keys = list(mt.col_key)

    ht = mt.localize_entries(entries_array_field_name='__entries',
                             columns_array_field_name='__cols')

    ht = ht.annotate(__entries=hl.rbind(
        hl.scan.array_agg(lambda entry: hl.scan.count_where(entry.__in_step1),
                          ht.__entries),
        lambda step1_indices: hl.map(
            lambda i: hl.rbind(
                hl.int(hl.or_else(step1_indices[i], 0)), ht.__cols[
                    i].__m_step1, ht.__entries[i], lambda step1_idx, m_step1,
                entry: hl.rbind(
                    hl.map(
                        lambda j: hl.int(hl.floor(j * (m_step1 / n_blocks))),
                        hl.range(0, n_blocks + 1)), lambda step1_separators: hl
                    .rbind(
                        hl.set(step1_separators).contains(step1_idx),
                        hl.sum(
                            hl.map(lambda s1: step1_idx >= s1, step1_separators
                                   )) - 1, lambda is_separator, step1_block:
                        entry.annotate(__step1_block=step1_block,
                                       __step2_block=hl.cond(
                                           ~entry.__in_step1 & is_separator,
                                           step1_block - 1, step1_block))))),
            hl.range(0, hl.len(ht.__entries)))))

    mt = ht._unlocalize_entries('__entries', '__cols', col_keys)

    mt_tmp_file2 = new_temp_file()
    mt.write(mt_tmp_file2)
    mt = hl.read_matrix_table(mt_tmp_file2)

    # initial coefficient estimates
    mt = mt.annotate_cols(__initial_betas=[
        1.0, (hl.agg.mean(mt.__y) - 1.0) / hl.agg.mean(mt.__x)
    ])
    mt = mt.annotate_cols(__step1_betas=mt.__initial_betas,
                          __step2_betas=mt.__initial_betas)

    # step 1 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step1, 1.0 /
            (mt.__w_initial_floor * 2.0 *
             (mt.__step1_betas[0] + mt.__step1_betas[1] * mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step1_betas=hl.agg.filter(
            mt.__in_step1,
            hl.agg.linreg(y=mt.__y, x=[1.0, mt.__x], weight=mt.__w).beta))
        mt = mt.annotate_cols(__step1_h2=hl.max(
            hl.min(mt.__step1_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step1_betas=[
            mt.__step1_betas[0], mt.__step1_h2 * hl.agg.mean(mt.__n) / M
        ])

    # step 1 block jackknife
    mt = mt.annotate_cols(__step1_block_betas=hl.agg.array_agg(
        lambda i: hl.agg.filter(
            (mt.__step1_block != i) & mt.__in_step1,
            hl.agg.linreg(y=mt.__y, x=[1.0, mt.__x], weight=mt.__w).beta),
        hl.range(n_blocks)))

    mt = mt.annotate_cols(__step1_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step1_betas - (n_blocks - 1) * x,
        mt.__step1_block_betas))

    mt = mt.annotate_cols(
        __step1_jackknife_mean=hl.map(
            lambda i: hl.mean(
                hl.map(lambda x: x[i], mt.__step1_block_betas_bias_corrected)),
            hl.range(0, __k)),
        __step1_jackknife_variance=hl.map(
            lambda i: (hl.sum(
                hl.map(lambda x: x[i]**2, mt.__step1_block_betas_bias_corrected
                       )) - hl.sum(
                           hl.map(lambda x: x[i], mt.
                                  __step1_block_betas_bias_corrected))**
                       2 / n_blocks) / (n_blocks - 1) / n_blocks,
            hl.range(0, __k)))

    # step 2 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step2, 1.0 /
            (mt.__w_initial_floor * 2.0 *
             (mt.__step2_betas[0] + mt.__step2_betas[1] * mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            hl.agg.filter(
                mt.__in_step2,
                hl.agg.linreg(
                    y=mt.__y -
                    mt.__step1_betas[0], x=[mt.__x], weight=mt.__w).beta[0])
        ])
        mt = mt.annotate_cols(__step2_h2=hl.max(
            hl.min(mt.__step2_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0], mt.__step2_h2 * hl.agg.mean(mt.__n) / M
        ])

    # step 2 block jackknife
    mt = mt.annotate_cols(__step2_block_betas=hl.agg.array_agg(
        lambda i: hl.agg.filter((mt.__step2_block != i) & mt.__in_step2,
                                hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                              x=[mt.__x],
                                              weight=mt.__w).beta[0]),
        hl.range(n_blocks)))

    mt = mt.annotate_cols(__step2_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step2_betas[1] - (n_blocks - 1) * x,
        mt.__step2_block_betas))

    mt = mt.annotate_cols(
        __step2_jackknife_mean=hl.mean(mt.__step2_block_betas_bias_corrected),
        __step2_jackknife_variance=(
            hl.sum(mt.__step2_block_betas_bias_corrected**2) -
            hl.sum(mt.__step2_block_betas_bias_corrected)**2 / n_blocks) /
        (n_blocks - 1) / n_blocks)

    # combine step 1 and step 2 block jackknifes
    mt = mt.annotate_entries(
        __step2_initial_w=1.0 /
        (mt.__w_initial_floor * 2.0 *
         (mt.__initial_betas[0] + mt.__initial_betas[1] * mt.__x_floor)**2))

    mt = mt.annotate_cols(
        __final_betas=[mt.__step1_betas[0], mt.__step2_betas[1]],
        __c=(hl.agg.sum(mt.__step2_initial_w * mt.__x) /
             hl.agg.sum(mt.__step2_initial_w * mt.__x**2)))

    mt = mt.annotate_cols(__final_block_betas=hl.map(
        lambda i: (mt.__step2_block_betas[i] - mt.__c *
                   (mt.__step1_block_betas[i][0] - mt.__final_betas[0])),
        hl.range(0, n_blocks)))

    mt = mt.annotate_cols(__final_block_betas_bias_corrected=(
        n_blocks * mt.__final_betas[1] -
        (n_blocks - 1) * mt.__final_block_betas))

    mt = mt.annotate_cols(
        __final_jackknife_mean=[
            mt.__step1_jackknife_mean[0],
            hl.mean(mt.__final_block_betas_bias_corrected)
        ],
        __final_jackknife_variance=[
            mt.__step1_jackknife_variance[0],
            (hl.sum(mt.__final_block_betas_bias_corrected**2) -
             hl.sum(mt.__final_block_betas_bias_corrected)**2 / n_blocks) /
            (n_blocks - 1) / n_blocks
        ])

    # convert coefficient to heritability estimate
    mt = mt.annotate_cols(
        phenotype=mt.__y_name,
        mean_chi_sq=hl.agg.mean(mt.__y),
        intercept=hl.struct(estimate=mt.__final_betas[0],
                            standard_error=hl.sqrt(
                                mt.__final_jackknife_variance[0])),
        snp_heritability=hl.struct(
            estimate=(M / hl.agg.mean(mt.__n)) * mt.__final_betas[1],
            standard_error=hl.sqrt((M / hl.agg.mean(mt.__n))**2 *
                                   mt.__final_jackknife_variance[1])))

    # format and return results
    ht = mt.cols()
    ht = ht.key_by(ht.phenotype)
    ht = ht.select(ht.mean_chi_sq, ht.intercept, ht.snp_heritability)

    ht_tmp_file = new_temp_file()
    ht.write(ht_tmp_file)
    ht = hl.read_table(ht_tmp_file)

    return ht
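
# The delete-one-block jackknife used in the steps above, sketched in plain
# NumPy for clarity (an illustration of the standard estimator, not code from
# this module).
import numpy as np

def block_jackknife(estimate, block_estimates):
    """Jackknife mean and variance from delete-one-block refits.

    `estimate` is the statistic fit on all blocks; `block_estimates[i]` is the
    same statistic refit with block i held out, mirroring __step*_block_betas.
    """
    n = len(block_estimates)
    # Bias-corrected pseudovalues, as in __step1_block_betas_bias_corrected.
    pseudovalues = n * estimate - (n - 1) * np.asarray(block_estimates)
    mean = pseudovalues.mean()
    variance = ((np.sum(pseudovalues ** 2)
                 - np.sum(pseudovalues) ** 2 / n) / (n - 1) / n)
    return mean, variance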
Code example #44
mt.describe()

mt_split = hl.split_multi(mt)
mt_split = mt_split.select_entries(
    GT=hl.downcode(mt_split.GT, mt_split.a_index))
mt_split = mt_split.annotate_rows(info=hl.struct(
    END=mt_split.info.END,
    SVTYPE=mt_split.info.SVTYPE,
    AA=mt_split.info.AA,
    AC=mt_split.info.AC[mt_split.a_index - 1],
    AF=mt_split.info.AF[mt_split.a_index - 1],
    NS=mt_split.info.NS,
    AN=mt_split.info.AN,
    EAS_AF=mt_split.info.EAS_AF[mt_split.a_index - 1],
    EUR_AF=mt_split.info.EUR_AF[mt_split.a_index - 1],
    AFR_AF=mt_split.info.AFR_AF[mt_split.a_index - 1],
    AMR_AF=mt_split.info.AMR_AF[mt_split.a_index - 1],
    SAS_AF=mt_split.info.SAS_AF[mt_split.a_index - 1],
    VT=(hl.case().when((mt_split.alleles[0].length() == 1)
                       & (mt_split.alleles[1].length() == 1), 'SNP').when(
                           mt_split.alleles[0].matches('<CN*>')
                           | mt_split.alleles[1].matches('<CN*>'),
                           'SV').default('INDEL')),
    EX_TARGET=mt_split.info.EX_TARGET,
    MULTI_ALLELIC=mt_split.info.MULTI_ALLELIC,
    DP=mt_split.info.DP))
mt_split.describe()
mt_split = mt_split.drop('old_locus', 'old_alleles', 'a_index')

mt_split = mt_split.annotate_cols(
    sex=ht_samples[mt_split.s].gender,
Code example #45
def default_generate_gene_lof_summary(
    mt: hl.MatrixTable,
    collapse_indels: bool = False,
    tx: bool = False,
    lof_csq_set: Set[str] = LOF_CSQ_SET,
    meta_root: str = "meta",
    pop_field: str = "pop",
    filter_loftee: bool = False,
) -> hl.Table:
    """
    Generate summary counts for loss-of-function (LoF), missense, and synonymous variants.

    Also calculates p, the proportion of haplotypes carrying a putative LoF (pLoF) variant,
    and observed/expected (OE) ratio of samples with homozygous pLoF variant calls.

    Summary counts are (all per gene):
        - Number of samples with no pLoF variants.
        - Number of samples with heterozygous pLoF variants.
        - Number of samples with homozygous pLoF variants.
        - Total number of sites with genotype calls.
        - All of the above stats grouped by population.

    Assumes MT was created using `default_generate_gene_lof_matrix`.

    .. note::
        Assumes LoF variants in MT were filtered (LOFTEE pass and no LoF flag only).
        If LoF variants have not been filtered and `filter_loftee` is True,
        expects MT has the row annotation `vep`.

    :param mt: Input MatrixTable.
    :param collapse_indels: Whether to collapse indels. Default is False.
    :param tx: Whether input MT has transcript expression data. Default is False.
    :param lof_csq_set: Set containing LoF transcript consequence strings. Default is LOF_CSQ_SET.
    :param meta_root: String indicating top level name for sample metadata. Default is 'meta'.
    :param pop_field: String indicating field with sample population assignment information. Default is 'pop'.
    :param filter_loftee: Filters to LOFTEE pass variants (and no LoF flags) only. Default is False.
    :return: Table with het/hom summary counts.
    """
    if collapse_indels:
        grouping = ["gene_id", "gene", "most_severe_consequence"]
        if tx:
            grouping.append("expressed")
        else:
            grouping.extend(["transcript_id", "canonical"])
        mt = (
            mt.group_rows_by(*grouping)
            .aggregate_rows(
                n_sites=hl.agg.sum(mt.n_sites),
                n_sites_array=hl.agg.array_sum(mt.n_sites_array),
                classic_caf=hl.agg.sum(mt.classic_caf),
                max_af=hl.agg.max(mt.max_af),
                classic_caf_array=hl.agg.array_sum(mt.classic_caf_array),
            )
            .aggregate_entries(
                num_homs=hl.agg.sum(mt.num_homs),
                num_hets=hl.agg.sum(mt.num_hets),
                defined_sites=hl.agg.sum(mt.defined_sites),
            )
            .result()
        )

    if filter_loftee:
        lof_ht = get_most_severe_consequence_for_summary(mt.rows())
        mt = mt.filter_rows(
            hl.is_defined(lof_ht[mt.row_key].lof)
            & (lof_ht[mt.row_key].lof == "HC")
            & (lof_ht[mt.row_key].no_lof_flags)
        )

    ht = mt.annotate_rows(
        lof=hl.struct(
            **get_het_hom_summary_dict(
                csq_set=lof_csq_set,
                most_severe_csq_expr=mt.most_severe_consequence,
                defined_sites_expr=mt.defined_sites,
                num_homs_expr=mt.num_homs,
                num_hets_expr=mt.num_hets,
                pop_expr=mt[meta_root][pop_field],
            ),
        ),
        missense=hl.struct(
            **get_het_hom_summary_dict(
                csq_set={"missense_variant"},
                most_severe_csq_expr=mt.most_severe_consequence,
                defined_sites_expr=mt.defined_sites,
                num_homs_expr=mt.num_homs,
                num_hets_expr=mt.num_hets,
                pop_expr=mt[meta_root][pop_field],
            ),
        ),
        synonymous=hl.struct(
            **get_het_hom_summary_dict(
                csq_set={"synonymous_variant"},
                most_severe_csq_expr=mt.most_severe_consequence,
                defined_sites_expr=mt.defined_sites,
                num_homs_expr=mt.num_homs,
                num_hets_expr=mt.num_hets,
                pop_expr=mt[meta_root][pop_field],
            ),
        ),
    ).rows()
    ht = ht.annotate(
        p=(1 - hl.sqrt(hl.float64(ht.lof.no_alt_calls) / ht.lof.defined)),
        pop_p=hl.dict(
            hl.array(ht.lof.pop_defined).map(
                lambda x: (
                    x[0],
                    1 - hl.sqrt(hl.float64(ht.lof.pop_no_alt_calls.get(x[0])) / x[1]),
                )
            )
        ),
    )
    ht = ht.annotate(exp_hom_lof=ht.lof.defined * ht.p * ht.p)
    return ht.annotate(oe=ht.lof.obs_hom / ht.exp_hom_lof)
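
# Worked toy numbers (illustrative only) for the estimates above: if 9,025 of
# 10,000 genotyped samples in a gene carry no pLoF allele, then under
# Hardy-Weinberg p = 1 - sqrt(9025 / 10000) = 0.05, the expected homozygote
# count is 10000 * 0.05 ** 2 = 25, and observing 5 homozygotes gives OE = 0.2.
defined, no_alt_calls, obs_hom = 10_000, 9_025, 5
p = 1 - (no_alt_calls / defined) ** 0.5  # 0.05
exp_hom_lof = defined * p * p            # 25.0
oe = obs_hom / exp_hom_lof               # 0.2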
Code example #46
def make_group_sum_expr_dict(
    t: Union[hl.MatrixTable, hl.Table],
    subset: str,
    label_groups: Dict[str, List[str]],
    sort_order: List[str] = SORT_ORDER,
    delimiter: str = "-",
    metric_first_field: bool = True,
    metrics: List[str] = ["AC", "AN", "nhomalt"],
) -> Dict[str, Dict[str, Union[hl.expr.Int64Expression,
                               hl.expr.StructExpression]]]:
    """
    Compute the sum of call stats annotations for a specified group of annotations, compare to the annotated version, and display the result on stdout.

    For example, if subset1 consists of pop1, pop2, and pop3, check that t.info.AC-subset1 == sum(t.info.AC-subset1-pop1, t.info.AC-subset1-pop2, t.info.AC-subset1-pop3).

    :param t: Input MatrixTable or Table containing call stats annotations to be summed.
    :param subset: String indicating sample subset.
    :param label_groups: Dictionary containing an entry for each label group, where key is the name of the grouping, e.g. "sex" or "pop", and value is a list of all possible values for that grouping (e.g. ["XY", "XX"] or ["afr", "nfe", "amr"]).
    :param sort_order: List containing order to sort label group combinations. Default is SORT_ORDER.
    :param delimiter: String to use as delimiter when making group label combinations. Default is "-".
    :param metric_first_field: If True, metric precedes subset in the Table's fields, e.g. AC-hgdp. If False, subset precedes metric, hgdp-AC. Default is True.
    :param metrics: List of metrics to sum and compare to annotated versions. Default is ["AC", "AN", "nhomalt"].
    :return: Dictionary of sample sum field check expressions and display fields.
    """
    t = t.rows() if isinstance(t, hl.MatrixTable) else t

    # Check if subset string is provided to avoid adding a delimiter to empty string
    # (An empty string is passed to run this check on the entire callset)
    if subset:
        subset += delimiter

    label_combos = make_label_combos(label_groups, label_delimiter=delimiter)
    # Grab the first group for the check and remove it from the label_groups dictionary. In gnomAD, this is 'adj', as we do not retain the raw metric counts for all sample groups and so do not check raw sample sums.
    group = label_groups.pop("group")[0]
    # sum_group is the type of high-level annotation to sum, e.g. 'pop', 'pop-sex', 'sex'.
    sum_group = delimiter.join(
        sorted(label_groups.keys(), key=lambda x: sort_order.index(x)))
    info_fields = t.info.keys()

    # Loop through metrics and the label combos to build a dictionary
    # where the key is a string representing the sum_group annotations and the value is the sum of these annotations.
    # If metric_first_field is True, metric is AC, subset is tgp, group is adj, sum_group is pop, then the values below are:
    # sum_group_exprs = ["AC-tgp-adj-pop1", "AC-tgp-adj-pop2", "AC-tgp-adj-pop3"]
    # annot_dict = {'sum-AC-tgp-adj-pop': hl.sum(["AC-tgp-adj-pop1", "AC-tgp-adj-pop2", "AC-tgp-adj-pop3"])}
    annot_dict = {}
    for metric in metrics:
        if metric_first_field:
            field_prefix = f"{metric}{delimiter}{subset}"
        else:
            field_prefix = f"{subset}{metric}{delimiter}"

        sum_group_exprs = []
        for label in label_combos:
            field = f"{field_prefix}{label}"
            if field in info_fields:
                sum_group_exprs.append(t.info[field])
            else:
                logger.warning("%s is not in table's info field", field)

        annot_dict[
            f"sum{delimiter}{field_prefix}{group}{delimiter}{sum_group}"] = hl.sum(
                sum_group_exprs)

    # If metric_first_field is True, metric is AC, subset is tgp, sum_group is pop, and group is adj, then the values below are:
    # check_field_left = "AC-tgp-adj"
    # check_field_right = "sum-AC-tgp-adj-pop" to match the annotation dict key from above
    field_check_expr = {}
    for metric in metrics:
        if metric_first_field:
            check_field_left = f"{metric}{delimiter}{subset}{group}"
        else:
            check_field_left = f"{subset}{metric}{delimiter}{group}"
        check_field_right = f"sum{delimiter}{check_field_left}{delimiter}{sum_group}"
        field_check_expr[f"{check_field_left} = {check_field_right}"] = {
            "expr":
            hl.agg.count_where(
                t.info[check_field_left] != annot_dict[check_field_right]),
            "display_fields":
            hl.struct(
                **{
                    check_field_left: t.info[check_field_left],
                    check_field_right: annot_dict[check_field_right],
                }),
        }
    return field_check_expr
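
# Hypothetical usage sketch (`release_ht` and the label values are assumed):
# verify per-population AC/AN/nhomalt sums against the subset's totals.
field_check_expr = make_group_sum_expr_dict(
    t=release_ht,
    subset="tgp",
    label_groups={"group": ["adj"], "pop": ["afr", "amr", "eas"]},
)
# Each key reads like "AC-tgp-adj = sum-AC-tgp-adj-pop"; its paired expression
# counts the rows where the annotated total disagrees with the computed sum.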
Code example #47
def get_summary_counts(
    ht: hl.Table,
    freq_field: str = "freq",
    filter_field: str = "filters",
    filter_decoy: bool = False,
    index: int = 0,
) -> hl.Table:
    """
    Generate a struct with summary counts across variant categories.

    Summary counts:
        - Number of variants
        - Number of indels
        - Number of SNVs
        - Number of LoF variants
        - Number of LoF variants that pass LOFTEE (including with LoF flags)
        - Number of LoF variants that pass LOFTEE without LoF flags
        - Number of OS (other splice) variants annotated by LOFTEE
        - Number of LoF variants that fail LOFTEE filters

    Also annotates Table's globals with total variant counts.

    Before calculating summary counts, function:
        - Filters out low confidence regions
        - Filters to canonical transcripts
        - Uses the most severe consequence

    Assumes that:
        - Input HT is annotated with VEP.
        - Multiallelic variants have been split and/or input HT contains bi-allelic variants only.
        - freq_expr was calculated with `annotate_freq`.
        - (Frequency index 0 from `annotate_freq` is frequency for all pops calculated on adj genotypes only.)

    :param ht: Input Table.
    :param freq_field: Name of field in HT containing frequency annotation (array of structs). Default is "freq".
    :param filter_field: Name of field in HT containing variant filter information. Default is "filters".
    :param filter_decoy: Whether to filter decoy regions. Default is False.
    :param index: Which index of freq_expr to use for annotation. Default is 0.
    :return: Table grouped by frequency bin and aggregated across summary count categories.
    """
    logger.info("Checking if multi-allelic variants have been split...")
    max_alleles = ht.aggregate(hl.agg.max(hl.len(ht.alleles)))
    if max_alleles > 2:
        logger.info("Splitting multi-allelics and VEP transcript consequences...")
        ht = hl.split_multi_hts(ht)

    logger.info("Filtering to PASS variants in high confidence regions...")
    ht = ht.filter((hl.len(ht[filter_field]) == 0))
    ht = filter_low_conf_regions(ht, filter_decoy=filter_decoy)

    logger.info(
        "Filtering to canonical transcripts and getting VEP summary annotations..."
    )
    ht = filter_vep_to_canonical_transcripts(ht)
    ht = get_most_severe_consequence_for_summary(ht)

    logger.info("Annotating with frequency bin information...")
    ht = ht.annotate(freq_bin=freq_bin_expr(ht[freq_field], index))

    logger.info(
        "Annotating HT globals with total counts/total allele counts per variant category..."
    )
    summary_counts = ht.aggregate(
        hl.struct(
            **get_summary_counts_dict(
                ht.locus,
                ht.alleles,
                ht.lof,
                ht.no_lof_flags,
                ht.most_severe_csq,
                prefix_str="total_",
            )
        )
    )
    summary_ac_counts = ht.aggregate(
        hl.struct(
            **get_summary_ac_dict(
                ht[freq_field][index].AC,
                ht.lof,
                ht.no_lof_flags,
                ht.most_severe_csq,
            )
        )
    )
    ht = ht.annotate_globals(
        summary_counts=summary_counts.annotate(**summary_ac_counts)
    )
    return ht.group_by("freq_bin").aggregate(
        **get_summary_counts_dict(
            ht.locus,
            ht.alleles,
            ht.lof,
            ht.no_lof_flags,
            ht.most_severe_csq,
        )
    )
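
A minimal usage sketch (hedged: the path below is hypothetical, and the input table is assumed to already carry VEP annotations and a `freq` array produced by `annotate_freq`, as the docstring requires):

ht = hl.read_table("gs://my-bucket/release_sites.ht")  # hypothetical input
summary_ht = get_summary_counts(ht, filter_decoy=True)
print(hl.eval(summary_ht.summary_counts))  # global totals per variant category
summary_ht.show()  # per-freq_bin summary counts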
Code example #48
0
def compare_subset_freqs(
    t: Union[hl.MatrixTable, hl.Table],
    subsets: List[str],
    verbose: bool,
    show_percent_sites: bool = True,
    delimiter: str = "-",
    metric_first_field: bool = True,
    metrics: List[str] = ["AC", "AN", "nhomalt"],
) -> None:
    """
    Perform validity checks on frequency data in the input Table.

    Check:
        - Number of sites where callset frequency is equal to a subset frequency (raw and adj)
            - e.g., t.info.AC-adj != t.info.AC-subset1-adj
        - Total number of sites where the raw allele count annotation is defined

    :param t: Input MatrixTable or Table.
    :param subsets: List of sample subsets.
    :param verbose: If True, show top values of annotations being checked, including checks that pass; if False, show only top values of annotations that fail checks.
    :param show_percent_sites: If True, show the percentage and count of overall sites that fail; if False, only show the number of sites that fail.
    :param delimiter: String to use as delimiter when making group label combinations. Default is "-".
    :param metric_first_field: If True, the metric precedes the subset in field names, e.g. AC-non_v2-adj; if False, the subset precedes the metric, e.g. non_v2-AC-adj. Default is True.
    :param metrics: List of metrics to compare between subset and entire callset. Default is ["AC", "AN", "nhomalt"].
    :return: None
    """
    t = t.rows() if isinstance(t, hl.MatrixTable) else t

    field_check_expr = {}
    for subset in subsets:
        if subset:
            for metric in metrics:
                for group in ["adj", "raw"]:
                    logger.info(
                        "Comparing the %s subset's %s %s to entire callset's %s %s",
                        subset,
                        group,
                        metric,
                        group,
                        metric,
                    )
                    check_field_left = f"{metric}{delimiter}{group}"
                    if metric_first_field:
                        check_field_right = (
                            f"{metric}{delimiter}{subset}{delimiter}{group}")
                    else:
                        check_field_right = (
                            f"{subset}{delimiter}{metric}{delimiter}{group}")

                    field_check_expr[
                        f"{check_field_left} != {check_field_right}"] = {
                            "expr":
                            hl.agg.count_where(t.info[check_field_left] ==
                                               t.info[check_field_right]),
                            "display_fields":
                            hl.struct(
                                **{
                                    check_field_left: t.info[check_field_left],
                                    check_field_right:
                                    t.info[check_field_right],
                                }),
                        }

    generic_field_check_loop(
        t,
        field_check_expr,
        verbose,
        show_percent_sites=show_percent_sites,
    )

    # Spot check the raw AC counts
    total_defined_raw_ac = t.aggregate(
        hl.agg.count_where(hl.is_defined(t.info[f"AC{delimiter}raw"])))
    logger.info("Total defined raw AC count: %s", total_defined_raw_ac)
Code example #49
0
File: test_matrix_table.py Project: tianyunwang/hail
 def test_literals_rebuild(self):
     mt = hl.utils.range_matrix_table(1, 1)
     mt = mt.annotate_rows(x=hl.cond(
         hl.len(hl.literal([1, 2, 3])) < hl.rand_unif(10, 11), mt.globals,
         hl.struct()))
     mt._force_count_rows()
Code example #50
0
def merge_sample_qc_expr(
    sample_qc_exprs: List[hl.expr.StructExpression],
) -> hl.expr.StructExpression:
    """
    Create an expression that merges results from non-overlapping strata of hail.sample_qc.

    E.g.:

    - Compute autosomes and sex chromosomes metrics separately, then merge results
    - Compute bi-allelic and multi-allelic metrics separately, then merge results

    Note regarding the merging of ``dp_stats`` and ``gq_stats``:
    Because ``n`` is needed to aggregate ``stdev``, ``n_called`` is used for this purpose.
    This should work very well on a standard GATK VCF and it essentially assumes that:

    - samples that are called have `DP` and `GQ` fields
    - samples that are not called do not have `DP` and `GQ` fields

    Even if these assumptions are broken for some genotypes, it shouldn't matter too much.

    :param sample_qc_exprs: List of sample QC struct expressions for each stratification
    :return: Combined sample QC results
    """
    # List of metrics that can be aggregated by summing
    additive_metrics = [
        "n_called",
        "n_not_called",
        "n_filtered",
        "n_hom_ref",
        "n_het",
        "n_hom_var",
        "n_snp",
        "n_insertion",
        "n_deletion",
        "n_singleton",
        "n_transition",
        "n_transversion",
        "n_star",
    ]

    # List of metrics that are ratios of summed metrics (name, numerator, denominator)
    ratio_metrics = [
        ("call_rate", "n_called", "n_not_called"),
        ("r_ti_tv", "n_transition", "n_transversion"),
        ("r_het_hom_var", "n_het", "n_hom_var"),
        ("r_insertion_deletion", "n_insertion", "n_deletion"),
    ]

    # List of metrics that are struct generated by a stats counter
    stats_metrics = ["gq_stats", "dp_stats"]

    # Gather metrics present in sample qc fields
    sample_qc_fields = set(sample_qc_exprs[0])
    for sample_qc_expr in sample_qc_exprs[1:]:
        sample_qc_fields = sample_qc_fields.union(set(sample_qc_expr))

    # Merge additive metrics in sample qc fields
    merged_exprs = {
        metric:
        hl.sum([sample_qc_expr[metric] for sample_qc_expr in sample_qc_exprs])
        for metric in additive_metrics if metric in sample_qc_fields
    }

    # Merge ratio metrics in sample qc fields
    merged_exprs.update({
        metric: hl.float64(divide_null(merged_exprs[nom], merged_exprs[denom]))
        for metric, nom, denom in ratio_metrics
        if nom in sample_qc_fields and denom in sample_qc_fields
    })

    # Merge stats counter metrics in sample qc fields
    # Use n_called as n for DP and GQ stats
    if "n_called" in sample_qc_fields:
        merged_exprs.update({
            metric: merge_stats_counters_expr([
                sample_qc_expr[metric].annotate(n=sample_qc_expr.n_called)
                for sample_qc_expr in sample_qc_exprs
            ]).drop("n")
            for metric in stats_metrics if metric in sample_qc_fields
        })

    return hl.struct(**merged_exprs)
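
A sketch of the stratified-merge pattern described above, computing autosomal and sex-chromosome metrics separately and then merging them (the input path is hypothetical):

mt = hl.read_matrix_table("gs://my-bucket/genotypes.mt")  # hypothetical input
auto_qc = hl.sample_qc(mt.filter_rows(mt.locus.in_autosome())).cols()
sex_qc = hl.sample_qc(mt.filter_rows(~mt.locus.in_autosome())).cols()
merged_qc = auto_qc.select(
    sample_qc=merge_sample_qc_expr(
        [auto_qc.sample_qc, sex_qc[auto_qc.key].sample_qc]))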
Code example #51
0
def shuffle_key_rows_by_mt(mt_path):
    mt = hl.read_matrix_table(mt_path)
    mt = mt.annotate_rows(reversed_position_locus=hl.struct(
        contig=mt.locus.contig, position=-mt.locus.position))
    mt = mt.key_rows_by(mt.reversed_position_locus)
    mt._force_count_rows()
Code example #52
0
def compute_qc_metrics_residuals(
    ht: hl.Table,
    pc_scores: hl.expr.ArrayNumericExpression,
    qc_metrics: Dict[str, hl.expr.NumericExpression],
    use_pc_square: bool = True,
    n_pcs: Optional[int] = None,
    regression_sample_inclusion_expr: hl.expr.BooleanExpression = hl.bool(
        True),
) -> hl.Table:
    """
    Compute QC metrics residuals after regressing out PCs (and optionally PC^2).

    .. note::

        The `regression_sample_inclusion_expr` can be used to select a subset of the samples to include in the regression calculation.
        Residuals are always computed for all samples.

    :param ht: Input sample QC metrics HT
    :param pc_scores: The expression in the input HT that stores the PC scores
    :param qc_metrics: A dictionary with the name of each QC metric to compute residuals for and their corresponding expression in the input HT.
    :param use_pc_square: Whether to use PC^2 in the regression or not
    :param n_pcs: Number of PCs to use. If not set, then all PCs in `pc_scores` are used.
    :param regression_sample_inclusion_expr: An optional expression to select samples to include in the regression calculation.
    :return: Table with QC metrics residuals
    """
    # Annotate QC HT with fields necessary for computation
    _sample_qc_ht = ht.select(**qc_metrics,
                              scores=pc_scores,
                              _keep=regression_sample_inclusion_expr)

    # If n_pcs wasn't provided, use all PCs
    if n_pcs is None:
        n_pcs = _sample_qc_ht.aggregate(
            hl.agg.min(hl.len(_sample_qc_ht.scores)))

    logger.info(
        "Computing regressed QC metrics filters using %d PCs for metrics: %s",
        n_pcs,
        ", ".join(qc_metrics),
    )

    # Prepare regression variables, adding 1.0 first for the intercept
    # Adds square of variables if use_pc_square is true
    x_expr = [1.0] + [_sample_qc_ht.scores[i] for i in range(0, n_pcs)]
    if use_pc_square:
        x_expr.extend([
            _sample_qc_ht.scores[i] * _sample_qc_ht.scores[i]
            for i in range(0, n_pcs)
        ])

    # Compute linear regressions
    lms = _sample_qc_ht.aggregate(
        hl.struct(
            **{
                metric: hl.agg.filter(
                    _sample_qc_ht._keep,
                    hl.agg.linreg(y=_sample_qc_ht[metric], x=x_expr),
                )
                for metric in qc_metrics
            }))

    _sample_qc_ht = _sample_qc_ht.annotate_globals(lms=lms).persist()

    # Compute residuals
    def get_lm_prediction_expr(metric: str):
        lm_pred_expr = _sample_qc_ht.lms[metric].beta[0] + hl.sum(
            hl.range(n_pcs).map(
                lambda i: _sample_qc_ht.lms[metric].beta[i + 1] * _sample_qc_ht.scores[i]))
        if use_pc_square:
            lm_pred_expr = lm_pred_expr + hl.sum(
                hl.range(n_pcs).map(
                    lambda i: _sample_qc_ht.lms[metric].beta[i + n_pcs + 1]
                    * _sample_qc_ht.scores[i] * _sample_qc_ht.scores[i]))
        return lm_pred_expr

    residuals_ht = _sample_qc_ht.select(
        **{
            f"{metric}_residual": _sample_qc_ht[metric] -
            get_lm_prediction_expr(metric)
            for metric in _sample_qc_ht.lms
        })

    return residuals_ht.persist()
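
A usage sketch, assuming a sample QC table that already carries a `scores` array of PCs and a nested `sample_qc` struct (all names below are hypothetical):

qc_ht = hl.read_table("gs://my-bucket/sample_qc_with_pcs.ht")  # hypothetical input
residuals_ht = compute_qc_metrics_residuals(
    ht=qc_ht,
    pc_scores=qc_ht.scores,
    qc_metrics={
        "n_snp": qc_ht.sample_qc.n_snp,
        "r_ti_tv": qc_ht.sample_qc.r_ti_tv,
    },
    n_pcs=10,
)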
Code example #53
0
def median_impute_features(
        ht: hl.Table,
        strata: Optional[Dict[str, hl.expr.Expression]] = None) -> hl.Table:
    """
    Impute the missing values of numerical features in the Table using Hail's `approx_median`.

    If a `strata` dict is given, imputation is done based on the median of each stratum.

    The annotations that are added to the Table are:
        - feature_imputed - A row annotation indicating if each numerical feature was imputed or not.
        - features_median - A global annotation containing the median of the numerical features. If `strata` is given,
          this struct will also be broken down by the given strata.
        - variants_by_strata - An additional global annotation with the variant counts by strata that will only be
          added if imputing by a given `strata`.

    :param ht: Table containing all samples and features for median imputation.
    :param strata: Optional dict whose keys name row fields to stratify by; if given, medians are computed within each stratum. Default is None.
    :return: Feature Table imputed using approximate median values.
    """

    logger.info(
        "Computing feature medians for imputation of missing numeric values")
    numerical_features = [
        k for k, v in ht.row.dtype.items() if v == hl.tint or v == hl.tfloat
    ]

    median_agg_expr = hl.struct(
        **{
            feature: hl.agg.approx_median(ht[feature])
            for feature in numerical_features
        })

    if strata:
        ht = ht.annotate_globals(
            feature_medians=ht.aggregate(
                hl.agg.group_by(hl.tuple([ht[x] for x in strata]),
                                median_agg_expr),
                _localize=False,
            ),
            variants_by_strata=ht.aggregate(hl.agg.counter(
                hl.tuple([ht[x] for x in strata])),
                                            _localize=False),
        )
        feature_median_expr = ht.feature_medians[hl.tuple(
            [ht[x] for x in strata])]
        logger.info("Variant count by strata:\n{}".format("\n".join([
            "{}: {}".format(k, v)
            for k, v in hl.eval(ht.variants_by_strata).items()
        ])))

    else:
        ht = ht.annotate_globals(
            feature_medians=ht.aggregate(median_agg_expr, _localize=False))
        feature_median_expr = ht.feature_medians

    ht = ht.annotate(
        **{
            f: hl.or_else(ht[f], feature_median_expr[f])
            for f in numerical_features
        },
        feature_imputed=hl.struct(
            **{f: hl.is_missing(ht[f])
               for f in numerical_features}),
    )

    return ht
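
A usage sketch (hypothetical path and strata field), imputing medians separately for SNVs and non-SNVs; note that the strata dict keys must name row fields of the Table:

ht = hl.read_table("gs://my-bucket/rf_features.ht")  # hypothetical feature table
ht = ht.annotate(snv=hl.is_snp(ht.alleles[0], ht.alleles[1]))
imputed_ht = median_impute_features(ht, strata={"snv": ht.snv})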
Code example #54
0
File: vcf_combiner.py Project: vedasha/hail
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        # Pick the longest reference allele, then rewrite each input's alt
        # alleles against it and collect the union as the global allele list
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], '')).fold(
                lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref: hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r: hl.array([ref]).extend(
                            al[1:].map(
                                lambda a: hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at: hl.cond(
                                        (_allele_ints['SNP'] == at)
                                        | (_allele_ints['Insertion'] == at)
                                        | (_allele_ints['Deletion'] == at)
                                        | (_allele_ints['MNP'] == at)
                                        | (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):], a)))))),
                lambda lal: hl.struct(
                    globl=hl.array([ref]).extend(
                        hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl: hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles: hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    info=hl.struct(
                        MQ_DP=hl.sum(row.data.map(lambda d: d.info.MQ_DP)),
                        QUALapprox=hl.sum(row.data.map(lambda d: d.info.QUALapprox)),
                        RAW_MQ=hl.sum(row.data.map(lambda d: d.info.RAW_MQ)),
                        VarDP=hl.sum(row.data.map(lambda d: d.info.VarDP)),
                        SB_TABLE=hl.array([
                            hl.sum(row.data.map(lambda d: d.info.SB_TABLE[0])),
                            hl.sum(row.data.map(lambda d: d.info.SB_TABLE[1])),
                            hl.sum(row.data.map(lambda d: d.info.SB_TABLE[2])),
                            hl.sum(row.data.map(lambda d: d.info.SB_TABLE[3]))
                        ])),
                    __entries=hl.bind(
                        # Remap each input's local entries onto the merged
                        # (global) allele indexing before concatenating them
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i: hl.cond(
                                hl.is_missing(row.data[i].__entries),
                                hl.range(0, hl.len(gbl.g[i].__cols)).map(
                                    lambda _: hl.null(
                                        row.data[i].__entries.dtype.element_type)),
                                hl.bind(
                                    lambda old_to_new: row.data[i].__entries.map(
                                        lambda e: renumber_entry(e, old_to_new)),
                                    hl.range(0, hl.len(alleles.local[i])).map(
                                        lambda j: combined_allele_index[
                                            alleles.local[i][j]])))),
                        hl.dict(
                            hl.range(0, hl.len(alleles.globl)).map(
                                lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(
        TableMapRows(
            ts._tir,
            Apply(merge_function._name, TopLevelReference('row'),
                  TopLevelReference('global'))))
    return ts.transmute_globals(
        __cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
Code example #55
0
def plot_score_distributions(data_type,
                             models: Union[Dict[str, str], List[str]],
                             snv: bool,
                             cut: int,
                             colors: Dict[str, str] = None) -> Tabs:
    """
    Generate plots of model score distributions:
    One tab per model.
    Within each tab, there is a 2x2 grid of plots:

    - One row showing the score distribution across the entire data
    - One row showing the score distribution across the release samples, adj data only (release_sample_AC_ADJ > 0)
    - One column showing the histogram of the score
    - One column showing the normalized cumulative histogram of the score

    The cutoff is highlighted by a dashed red line.

    :param str data_type: One of 'exomes' or 'genomes'
    :param list of str or dict of str -> str models: Which models to plot. Can either be a list of models or a dict mapping model id to the display name.
    :param bool snv: Whether to plot SNVs or Indels
    :param int cut: Bin cut on the entire data to highlight
    :param dict of str -> str colors: Optional colors to use (model name -> desired color)
    :return: Plots of the score distributions
    :rtype: Tabs
    """

    if not isinstance(models, dict):
        models = {m: m for m in models}

    if colors is None:
        colors = {m_name: "#033649" for m_name in models.values()}

    tabs = []
    for model_id, model_name in models.items():
        if model_id in ['vqsr', 'cnn', 'rf_2.0.2', 'rf_2.0.2_beta']:
            ht = hl.read_table(
                score_ranking_path(data_type, model_id, binned=False))
        else:
            ht = hl.read_table(
                rf_path(data_type, 'rf_result', run_hash=model_id))

        ht = ht.filter(hl.is_snp(ht.alleles[0], ht.alleles[1]), keep=snv)
        binned_ht = hl.read_table(
            score_ranking_path(data_type, model_id, binned=True))
        binned_ht = binned_ht.filter(binned_ht.snv, keep=snv)

        cut_value = binned_ht.aggregate(
            hl.agg.filter(
                (binned_ht.bin == cut) & (binned_ht.rank_id == 'rank'),
                hl.agg.min(binned_ht.min_score)))

        min_score, max_score = (-20, 30) if model_id == 'vqsr' else (0.0, 1.0)
        agg_values = ht.aggregate(
            hl.struct(score_hist=[
                hl.agg.hist(ht.score, min_score, max_score, 100),
                hl.agg.filter(ht.ac > 0,
                              hl.agg.hist(ht.score, min_score, max_score, 100))
            ],
                      adj_counts=hl.agg.filter(
                          ht.ac > 0, hl.agg.counter(ht.score >= cut_value))))
        score_hist = agg_values.score_hist
        adj_cut = '{0:.2f}'.format(
            100 * agg_values.adj_counts[True] /
            (agg_values.adj_counts[True] + agg_values.adj_counts[False]))

        rows = []
        x_range = DataRange1d()
        y_range = [DataRange1d(), DataRange1d()]
        for adj in [False, True]:
            title = '{0}, {1} cut (score = {2:.2f})'.format(
                'Adj' if adj else 'All', adj_cut if adj else cut, cut_value)
            p = plot_hail_hist(score_hist[adj],
                               title=title + "\n",
                               fill_color=colors[model_name])
            p.add_layout(
                Span(location=cut_value,
                     dimension='height',
                     line_color='red',
                     line_dash='dashed'))
            p.x_range = x_range
            p.y_range = y_range[0]
            set_plots_defaults(p)

            p_cumul = plot_hail_hist_cumulative(score_hist[adj],
                                                title=title + ', cumulative',
                                                line_color=colors[model_name])
            p_cumul.add_layout(
                Span(location=cut_value,
                     dimension='height',
                     line_color='red',
                     line_dash='dashed'))
            p_cumul.x_range = x_range
            p_cumul.y_range = y_range[1]
            set_plots_defaults(p_cumul)

            rows.append([p, p_cumul])

        tabs.append(Panel(child=gridplot(rows), title=model_name))
    return Tabs(tabs=tabs)
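
A usage sketch; the model id, display name, and bin cutoff are assumptions, and the resource helpers used inside (`score_ranking_path`, `rf_path`) are expected to resolve in the caller's environment:

from bokeh.io import show

tabs = plot_score_distributions(
    "exomes",
    models={"vqsr": "VQSR"},  # hypothetical model id -> display name
    snv=True,
    cut=90,  # hypothetical bin cutoff
)
show(tabs)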
Code example #56
0
def get_site_info_expr(
    mt: hl.MatrixTable,
    sum_agg_fields: Union[
        List[str], Dict[str, hl.expr.NumericExpression]
    ] = INFO_SUM_AGG_FIELDS,
    int32_sum_agg_fields: Union[
        List[str], Dict[str, hl.expr.NumericExpression]
    ] = INFO_INT32_SUM_AGG_FIELDS,
    median_agg_fields: Union[
        List[str], Dict[str, hl.expr.NumericExpression]
    ] = INFO_MEDIAN_AGG_FIELDS,
    array_sum_agg_fields: Union[
        List[str], Dict[str, hl.expr.ArrayNumericExpression]
    ] = INFO_ARRAY_SUM_AGG_FIELDS,
) -> hl.expr.StructExpression:
    """
    Create a site-level annotation Struct aggregating typical VCF INFO fields from GVCF INFO fields stored in the MT entries.

    .. note::

        - If `RAW_MQandDP` is specified in array_sum_agg_fields, it will be used for the `MQ` calculation and then dropped according to GATK recommendation.
        - If `RAW_MQ` and `MQ_DP` are given, they will be used for the `MQ` calculation and then dropped according to GATK recommendation.
        - If the fields to be aggregated (`sum_agg_fields`, `int32_sum_agg_fields`, `median_agg_fields`) are passed as
          lists of str, then they should correspond to entry fields in `mt` or in `mt.gvcf_info`.
        - Priority is given to entry fields in `mt` over those in `mt.gvcf_info` in case of a name clash.

    :param mt: Input Matrix Table
    :param sum_agg_fields: Fields to aggregate using sum.
    :param int32_sum_agg_fields: Fields to aggregate using sum using int32.
    :param median_agg_fields: Fields to aggregate using (approximate) median.
    :param array_sum_agg_fields: Fields to aggregate using element-wise array sum.
    :return: Expression containing the site-level info fields
    """
    if "DP" in list(sum_agg_fields) + list(int32_sum_agg_fields):
        logger.warning(
            "`DP` was included in site-level aggregation. This requires a densifying prior to running get_site_info_expr"
        )

    agg_expr = _get_info_agg_expr(
        mt=mt,
        sum_agg_fields=sum_agg_fields,
        int32_sum_agg_fields=int32_sum_agg_fields,
        median_agg_fields=median_agg_fields,
        array_sum_agg_fields=array_sum_agg_fields,
    )

    # Add FS and SOR if SB is present
    # This is done outside of _get_info_agg_expr as the behavior is different in site vs allele-specific versions
    if "SB" in agg_expr:
        agg_expr["FS"] = fs_from_sb(agg_expr["SB"])
        agg_expr["SOR"] = sor_from_sb(agg_expr["SB"])

    # Run aggregator on non-ref genotypes
    info = hl.agg.filter(
        mt.LGT.is_non_ref(),
        hl.struct(**{k: v for k, v in agg_expr.items() if k != "DP"}),
    )

    # Add DP, computed over both ref and non-ref genotypes, if present
    if "DP" in agg_expr:
        info = info.annotate(DP=agg_expr["DP"])

    return info
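
A usage sketch for materializing the aggregation as a sites table (the MatrixTable path is hypothetical; the MT is assumed to be a sparse MT with `LGT` and `gvcf_info` entry fields):

mt = hl.read_matrix_table("gs://my-bucket/sparse.mt")  # hypothetical input
info_ht = mt.select_rows(info=get_site_info_expr(mt)).rows()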
Code example #57
0
def get_as_info_expr(
    mt: hl.MatrixTable,
    sum_agg_fields: Union[List[str], Dict[
        str, hl.expr.NumericExpression]] = INFO_SUM_AGG_FIELDS,
    int32_sum_agg_fields: Union[List[str], Dict[
        str, hl.expr.NumericExpression]] = INFO_INT32_SUM_AGG_FIELDS,
    median_agg_fields: Union[List[str], Dict[
        str, hl.expr.NumericExpression]] = INFO_MEDIAN_AGG_FIELDS,
    array_sum_agg_fields: Union[List[str], Dict[
        str, hl.expr.ArrayNumericExpression]] = INFO_ARRAY_SUM_AGG_FIELDS,
    alt_alleles_range_array_field: str = "alt_alleles_range_array",
) -> hl.expr.StructExpression:
    """
    Return an allele-specific annotation Struct containing typical VCF INFO fields from GVCF INFO fields stored in the MT entries.

    Notes:

    1. If `SB` is specified in array_sum_agg_fields, it will be aggregated as `AS_SB_TABLE`, according to GATK standard nomenclature.
    2. If `RAW_MQandDP` is specified in array_sum_agg_fields, it will be used for the `MQ` calculation and then dropped according to GATK recommendation.
    3. If `RAW_MQ` and `MQ_DP` are given, they will be used for the `MQ` calculation and then dropped according to GATK recommendation.
    4. If the fields to be aggregated (`sum_agg_fields`, `int32_sum_agg_fields`, `median_agg_fields`) are passed as lists of str,
       then they should correspond to entry fields in `mt` or in `mt.gvcf_info`.
       Priority is given to entry fields in `mt` over those in `mt.gvcf_info` in case of a name clash.

    :param mt: Input Matrix Table
    :param sum_agg_fields: Fields to aggregate using sum.
    :param int32_sum_agg_fields: Fields to aggregate using sum using int32.
    :param median_agg_fields: Fields to aggregate using (approximate) median.
    :param array_sum_agg_fields: Fields to aggregate using array sum.
    :param alt_alleles_range_array_field: Annotation containing an array of the range of alternate alleles e.g., `hl.range(1, hl.len(mt.alleles))`
    :return: Expression containing the AS info fields
    """
    if "DP" in list(sum_agg_fields) + list(int32_sum_agg_fields):
        logger.warning(
            "`DP` was included in allele-specific aggregation, "
            "however `DP` is typically not aggregated by allele; `VarDP` is."
            "Note that the resulting `AS_DP` field will NOT include reference genotypes."
        )

    agg_expr = _get_info_agg_expr(
        mt=mt,
        sum_agg_fields=sum_agg_fields,
        int32_sum_agg_fields=int32_sum_agg_fields,
        median_agg_fields=median_agg_fields,
        array_sum_agg_fields=array_sum_agg_fields,
        prefix="AS_",
    )

    # Rename AS_SB to AS_SB_TABLE if present
    if "AS_SB" in agg_expr:
        agg_expr["AS_SB_TABLE"] = agg_expr.pop("AS_SB")

    if (alt_alleles_range_array_field not in mt.row
            or mt[alt_alleles_range_array_field].dtype != hl.dtype("array<int32>")):
        msg = f"'get_as_info_expr' expected a row field '{alt_alleles_range_array_field}' of type array<int32>"
        logger.error(msg)
        raise ValueError(msg)

    # Modify aggregations to aggregate per allele
    agg_expr = {
        f: hl.agg.array_agg(
            lambda ai: hl.agg.filter(mt.LA.contains(ai), expr),
            mt[alt_alleles_range_array_field],
        )
        for f, expr in agg_expr.items()
    }

    # Run aggregations
    info = hl.struct(**agg_expr)

    # Add SB Ax2 aggregation logic and FS if SB is present
    if "AS_SB_TABLE" in info:
        as_sb_table = hl.array([
            info.AS_SB_TABLE.filter(lambda x: hl.is_defined(x)).fold(
                lambda i, j: i[:2] + j[:2], [0, 0])  # ref
        ]).extend(
            info.AS_SB_TABLE.map(lambda x: x[2:])  # each alt
        )
        info = info.annotate(
            AS_SB_TABLE=as_sb_table,
            AS_FS=hl.range(1, hl.len(mt.alleles)).map(
                lambda i: fs_from_sb(as_sb_table[0].extend(as_sb_table[i]))),
            AS_SOR=hl.range(1, hl.len(mt.alleles)).map(
                lambda i: sor_from_sb(as_sb_table[0].extend(as_sb_table[i]))),
        )

    return info
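
A usage sketch mirroring the `alt_alleles_range_array_field` requirement above (the MatrixTable path is hypothetical):

mt = hl.read_matrix_table("gs://my-bucket/sparse.mt")  # hypothetical input
mt = mt.annotate_rows(alt_alleles_range_array=hl.range(1, hl.len(mt.alleles)))
mt = mt.annotate_rows(info=get_as_info_expr(mt))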
Code example #58
0
parser.add_argument('-b', required=True, choices=['GRCh37', 'GRCh38'], help='Ensembl reference genome build.')
args = parser.parse_args()

name = 'Ensembl_homo_sapiens_low_complexity_regions'
version = args.v
build = args.b

ht = hl.import_table(f'{raw_data_root}/Ensembl_homo_sapiens_low_complexity_regions_release{version}_{build}.tsv.bgz')

if build == 'GRCh37':
    ht = ht.annotate(interval=hl.locus_interval(ht['chromosome'], hl.int(ht['start']), hl.int(ht['end']), reference_genome='GRCh37'))
else:
    ht = ht.annotate(interval=hl.locus_interval('chr' + ht['chromosome'].replace('MT', 'M'), hl.int(ht['start']), hl.int(ht['end']), reference_genome='GRCh38'))

ht = ht.key_by('interval')
ht = ht.select()

n_rows = ht.count()
n_partitions = ht.n_partitions()

ht = ht.annotate_globals(metadata=hl.struct(name=name,
                                            version=f'release_{version}',
                                            reference_genome=build,
                                            n_rows=n_rows,
                                            n_partitions=n_partitions))

path = f'{hail_data_root}/{name}.release_{version}.{build}.ht'
ht.write(path, overwrite=True)
ht = hl.read_table(path)
ht.describe()
Code example #59
0
File: vcf_combiner.py Project: amarantolaw/hail
def transform_gvcf(mt, info_to_keep=[]) -> Table:
    """Transforms a gvcf into a sparse matrix table

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_gvcfs` with ``array_elements_required=False``.

    There is an assumption that this function will be called on a matrix table
    with one column (or a localized table version of the same).

    Parameters
    ----------
    mt : :obj:`Union[Table, MatrixTable]`
        The gvcf being transformed, if it is a table, then it must be a localized matrix table with
        the entries array named ``__entries``
    info_to_keep : :obj:`List[str]`
        Any ``INFO`` fields in the gvcf that are to be kept and put in the ``gvcf_info`` entry
        field. By default, all ``INFO`` fields except ``END`` and ``DP`` are kept.

    Returns
    -------
    :obj:`.Table`
        A localized matrix table that can be used as part of the input to :func:`.combine_gvcfs`

    Notes
    -----
    This function will parse the following allele specific annotations from
    pipe delimited strings into proper values. ::

        AS_QUALapprox
        AS_RAW_MQ
        AS_RAW_MQRankSum
        AS_RAW_ReadPosRankSum
        AS_SB_TABLE
        AS_VarDP

    """
    if not info_to_keep:
        info_to_keep = [name for name in mt.info if name not in ['END', 'DP']]
    mt = localize(mt)

    if mt.row.dtype not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles),
                '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e: hl.struct(
                            DP=e.DP,
                            END=row.info.END,
                            GQ=e.GQ,
                            # Local allele indices; drop the trailing <NON_REF>
                            LA=hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                            LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                            LGT=e.GT,
                            LPGT=e.PGT,
                            LPL=hl.cond(
                                has_non_ref,
                                hl.cond(alleles_len > 2,
                                        e.PL[:-alleles_len],
                                        hl.null(e.PL.dtype)),
                                hl.cond(alleles_len > 1,
                                        e.PL,
                                        hl.null(e.PL.dtype))),
                            MIN_DP=e.MIN_DP,
                            PID=e.PID,
                            # Reference genotype quality: the PL of the
                            # hom-ref/<NON_REF> genotype, when present
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                            SB=e.SB,
                            gvcf_info=hl.case().when(
                                hl.is_missing(row.info.END),
                                hl.struct(**parse_as_fields(
                                    row.info.select(*info_to_keep), has_non_ref))
                            ).or_missing()))),
            ), mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(
        TableMapRows(
            mt._tir,
            Apply(transform_row._name, transform_row._ret_type,
                  TopLevelReference('row'))))
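
A usage sketch under assumed inputs; the gVCF path is hypothetical, and `array_elements_required=False` is needed per the docstring above:

mt = hl.import_vcf(
    "gs://my-bucket/sample.g.vcf.bgz",  # hypothetical single-sample gVCF
    reference_genome="GRCh38",
    force_bgz=True,
    array_elements_required=False,
)
ts = transform_gvcf(mt)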
Code example #60
0
File: sparse_split_multi.py Project: tpoterba/hail
def sparse_split_multi(sparse_mt, *, filter_changed_loci=False):
    """Splits multiallelic variants on a sparse matrix table.

    Analogous to :func:`.split_multi_hts` (splits entry fields) for sparse
    representations.

    Takes a dataset formatted like the output of :func:`.vcf_combiner`. The
    splitting will add `was_split` and `a_index` fields, as :func:`.split_multi`
    does. This function drops the `LA` (local alleles) field, as it re-computes
    entry fields based on the new, split globals alleles.

    Variants are split thus:

    - A row with only one (reference) or two (reference and alternate) alleles
      is unchanged, as local and global alleles are the same.

    - A row with multiple alternate alleles will be split, with one row for
      each alternate allele, and each row will contain two alleles: ref and alt.
      The reference and alternate allele will be minrepped using
      :func:`.min_rep`.

    The split multi logic handles the following entry fields:

        .. code-block:: text

          struct {
            LGT: call
            LAD: array<int32>
            DP: int32
            GQ: int32
            LPL: array<int32>
            RGQ: int32
            LPGT: call
            LA: array<int32>
            END: int32
          }

    All fields except for `LA` are optional, and only handled if they exist.

    - `LA` is used to find the corresponding local allele index for the desired
      global `a_index`, and then dropped from the resulting dataset. If `LA`
      does not contain the global `a_index`, calls will be downcoded to hom ref
      and `PL` will be set to missing.

    - `LGT` and `LPGT` are downcoded using the corresponding local `a_index`.
      They are renamed to `GT` and `PGT` respectively, as the resulting call is
      no longer local.

    - `LAD` is used to create an `AD` field consisting of the allele depths
      corresponding to the reference and global `a_index` alleles.

    - `DP` is preserved unchanged.

    - `GQ` is recalculated from the updated `PL`, if it exists, but otherwise
      preserved unchanged.

    - `PL` array elements are calculated from the minimum `LPL` value for all
      allele pairs that downcode to the desired one. (This logic is identical to
      the `PL` logic in :func:`.split_multi_hts`.) If a row has an alternate
      allele but it is not present in `LA`, the `PL` field is set to missing.
      The `PL` for `ref/<NON_REF>` in that case can be drawn from `RGQ`.

    - `RGQ` (the reference genotype quality) is preserved unchanged.

    - `END` is untouched.

    Notes
    -----
    This version of split-multi doesn't deal with duplicate loci (in which
    case the explode could possibly result in out-of-order rows, although the
    actual split_multi function also doesn't handle that case).

    It also checks that min-repping will not change the locus and will error if
    it does.

    Unlike the normal split_multi function, sparse split multi will not filter
    ``*`` alleles. This is because a row with a bi-allelic spanning deletion
    may contain reference blocks that start at this position for other samples.

    Parameters
    ----------
    sparse_mt : :class:`.MatrixTable`
        Sparse MatrixTable to split.
    filter_changed_loci : :obj:`.bool`
        Rather than erroring if any REF/ALT pair changes locus under :func:`.min_rep`
        filter that variant instead.

    Returns
    -------
    :class:`.MatrixTable`
        The split MatrixTable in sparse format.

    """

    hl.methods.misc.require_row_key_variant(sparse_mt, "sparse_split_multi")

    entries = hl.utils.java.Env.get_uid()
    cols = hl.utils.java.Env.get_uid()
    ds = sparse_mt.localize_entries(entries, cols)
    new_id = hl.utils.java.Env.get_uid()

    def struct_from_min_rep(i):
        return hl.bind(
            lambda mr: (
                hl.case()
                .when(ds.locus == mr.locus,
                      hl.struct(locus=ds.locus,
                                alleles=[mr.alleles[0], mr.alleles[1]],
                                a_index=i,
                                was_split=True))
                .when(filter_changed_loci,
                      hl.null(hl.tstruct(locus=ds.locus.dtype,
                                         alleles=hl.tarray(hl.tstr),
                                         a_index=hl.tint,
                                         was_split=hl.tbool)))
                .or_error("Found non-left-aligned variant in sparse_split_multi\n"
                          + "old locus: " + hl.str(ds.locus) + "\n"
                          + "old ref  : " + ds.alleles[0] + "\n"
                          + "old alt  : " + ds.alleles[i] + "\n"
                          + "mr locus : " + hl.str(mr.locus) + "\n"
                          + "mr ref   : " + mr.alleles[0] + "\n"
                          + "mr alt   : " + mr.alleles[1])),
            hl.min_rep(ds.locus, [ds.alleles[0], ds.alleles[i]]))

    explode_structs = hl.cond(
        hl.len(ds.alleles) < 3,
        [hl.struct(locus=ds.locus, alleles=ds.alleles, a_index=1, was_split=False)],
        hl._sort_by(
            hl.cond(
                filter_changed_loci,
                hl.range(1, hl.len(ds.alleles)).map(struct_from_min_rep).filter(hl.is_defined),
                hl.range(1, hl.len(ds.alleles)).map(struct_from_min_rep)),
            lambda l, r: hl._compare(l.alleles, r.alleles) < 0))

    ds = ds.annotate(**{new_id: explode_structs}).explode(new_id)

    def transform_entries(old_entry):
        def with_local_a_index(local_a_index):
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(
                        old_entry.LGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(
                        old_entry.LPGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    non_ref_ad = hl.or_else(old_entry.LAD[local_a_index],
                                            0)  # zeroed if not in LAD
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [hl.sum(old_entry.LAD) - non_ref_ad, non_ref_ad])
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl),
                                                     old_entry.GQ)

                    dropped_fields.append('LPL')

                return (hl.case()
                        .when(hl.len(ds.alleles) == 1,
                              old_entry.annotate(**{
                                  f[1:]: old_entry[f]
                                  for f in ['LGT', 'LPGT', 'LAD', 'LPL']
                                  if f in fields
                              }).drop(*dropped_fields))
                        .when(hl.or_else(old_entry.LGT.is_hom_ref(), False),
                              old_entry.annotate(**{
                                  f: old_entry[f'L{f}'] if f in ['GT', 'PGT'] else e
                                  for f, e in new_exprs.items()
                              }).drop(*dropped_fields))
                        .default(old_entry.annotate(**new_exprs)
                                 .drop(*dropped_fields)))

            if 'LPL' in fields:
                # The new PL for genotype i is the minimum LPL over all local
                # genotype pairs that downcode to i
                new_pl = hl.or_missing(
                    hl.is_defined(old_entry.LPL),
                    hl.or_missing(
                        hl.is_defined(local_a_index),
                        hl.range(0, 3).map(
                            lambda i: hl.min(
                                hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                                .filter(lambda j: hl.downcode(
                                    hl.unphased_diploid_gt_index_call(j),
                                    local_a_index) == hl.unphased_diploid_gt_index_call(i))
                                .map(lambda idx: old_entry.LPL[idx])))))
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)

        # Find the local index of the current global a_index, if present
        lai = hl.fold(
            lambda accum, elt: hl.cond(
                old_entry.LA[elt] == ds[new_id].a_index, elt, accum),
            hl.null(hl.tint32),
            hl.range(0, hl.len(old_entry.LA)))
        return hl.bind(with_local_a_index, lai)

    new_row = ds.row.annotate(
        **{
            'locus': ds[new_id].locus,
            'alleles': ds[new_id].alleles,
            'a_index': ds[new_id].a_index,
            'was_split': ds[new_id].was_split,
            entries: ds[entries].map(transform_entries)
        }).drop(new_id)

    ds = hl.Table(
        hl.ir.TableKeyBy(
            hl.ir.TableMapRows(hl.ir.TableKeyBy(ds._tir, ['locus']), new_row._ir),
            ['locus', 'alleles'],
            is_sorted=True))
    return ds._unlocalize_entries(entries, cols, list(sparse_mt.col_key.keys()))
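
A usage sketch (hypothetical path), splitting a combiner-produced sparse MatrixTable while filtering variants whose locus would change under min-repping:

mt = hl.read_matrix_table("gs://my-bucket/combined_sparse.mt")  # hypothetical input
split_mt = sparse_split_multi(mt, filter_changed_loci=True)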