Example #1
0
    def test_agg_cols_explode(self):
        """Check ``agg.explode`` over columns, alone and wrapping other aggregators.

        Every case pairs an aggregation over the columns of a 1 x 10 range
        matrix table with the value expected in its single row.
        """
        mt = hl.utils.range_matrix_table(1, 10)

        def exploded():
            # Fresh copy of the array every case explodes:
            # [col_idx, col_idx + 1] when col_idx > 7, otherwise empty.
            return hl.cond(mt.col_idx > 7,
                           [mt.col_idx, mt.col_idx + 1],
                           hl.empty_array(hl.tint32))

        plain = agg.explode(
            lambda elt: agg.collect(elt + 1).append(0),
            exploded())
        nested = agg.explode(
            lambda elt: agg.explode(
                lambda elt2: agg.collect(elt2 + 1).append(0),
                [elt, elt + 1]),
            exploded())
        filtered = agg.explode(
            lambda elt: agg.filter(elt > 8,
                                   agg.collect(elt + 1).append(0)),
            exploded())
        grouped = agg.explode(
            lambda elt: agg.group_by(elt % 3,
                                     agg.collect(elt + 1).append(0)),
            exploded())

        cases = [
            (plain, [9, 10, 10, 11, 0]),
            (nested, [9, 10, 10, 11, 10, 11, 11, 12, 0]),
            (filtered, [10, 10, 11, 0]),
            (grouped, {0: [10, 10, 0], 1: [11, 0], 2: [9, 0]}),
        ]
        for aggregation, expected in cases:
            actual = mt.select_rows(result=aggregation).result.collect()[0]
            self.assertEqual(actual, expected)
Example #2
0
    def test_joins(self):
        """Exercise joins written as indexing into tables and matrix tables.

        Covers a chain of four keyed table lookups, row joins against an
        imported VCF, a column join on a sample-name prefix, and reading
        another table's globals.  Comparisons use ``assertEqual`` /
        ``assertGreater`` so that a failure reports both operands.
        """
        def singleton():
            # range_table(1) with its key cleared and the 'idx' field dropped.
            return hl.utils.range_table(1).key_by().drop('idx')

        kt = singleton().annotate(a='foo')
        kt1 = singleton().annotate(a='foo', b='bar').key_by('a')
        kt2 = singleton().annotate(b='bar', c='baz').key_by('b')
        kt3 = singleton().annotate(c='baz', d='qux').key_by('c')
        kt4 = singleton().annotate(d='qux', e='quam').key_by('d')

        def chained_e():
            # Follow a -> b -> c -> d -> e through the four keyed tables.
            return kt4[kt3[kt2[kt1[kt.a].b].c].d].e

        ktr = kt.annotate(e=chained_e())
        self.assertEqual(ktr.aggregate(agg.collect(ktr.e)), ['quam'])

        ktr = kt.select(e=chained_e())
        self.assertEqual(ktr.aggregate(agg.collect(ktr.e)), ['quam'])

        self.assertEqual(kt.filter(chained_e() == 'quam').count(), 1)

        # Row joins: table <- matrix table, matrix table <- table, and a
        # matrix table joined with itself.
        m = hl.import_vcf(resource('sample.vcf'))
        vkt = m.rows()
        vkt = vkt.select(vkt.qual)
        vkt = vkt.annotate(qual2=m.index_rows(vkt.key).qual)
        self.assertEqual(vkt.filter(vkt.qual != vkt.qual2).count(), 0)

        m2 = m.annotate_rows(qual2=vkt.index(m.row_key).qual)
        self.assertEqual(
            m2.filter_rows(m2.qual != m2.qual2).count_rows(), 0)

        m3 = m.annotate_rows(qual2=m.index_rows(m.row_key).qual)
        self.assertEqual(
            m3.filter_rows(m3.qual != m3.qual2).count_rows(), 0)

        # Column join keyed on the first five characters of the sample id.
        kt5 = hl.utils.range_table(1).annotate(key='C1589').key_by('key')
        m4 = m.annotate_cols(foo=m.s[:5])
        m4 = m4.annotate_cols(idx=kt5[m4.foo].idx)
        n_C1589 = m.filter_cols(m.s[:5] == 'C1589').count_cols()
        self.assertGreater(n_C1589, 1)
        self.assertEqual(
            m4.filter_cols(hl.is_defined(m4.idx)).count_cols(), n_C1589)

        # Globals of one table read from another.
        kt = hl.utils.range_table(1)
        kt = kt.annotate_globals(foo=5)
        self.assertEqual(hl.eval(kt.foo), 5)

        kt2 = hl.utils.range_table(1)

        kt2 = kt2.annotate_globals(kt_foo=kt.index_globals().foo)
        self.assertEqual(hl.eval(kt2.globals.kt_foo), 5)
Example #3
0
    def test_joins(self):
        """Check table and matrix-table joins expressed through indexing.

        The equality and ordering checks use ``assertEqual`` and
        ``assertGreater`` rather than ``assertTrue(a == b)`` so a failing run
        prints the values that differed.
        """
        def empty_row_table():
            # range_table(1) with its key cleared and the 'idx' field dropped.
            return hl.utils.range_table(1).key_by().drop('idx')

        kt = empty_row_table()
        kt = kt.annotate(a='foo')

        kt1 = empty_row_table()
        kt1 = kt1.annotate(a='foo', b='bar').key_by('a')

        kt2 = empty_row_table()
        kt2 = kt2.annotate(b='bar', c='baz').key_by('b')

        kt3 = empty_row_table()
        kt3 = kt3.annotate(c='baz', d='qux').key_by('c')

        kt4 = empty_row_table()
        kt4 = kt4.annotate(d='qux', e='quam').key_by('d')

        # Four chained lookups: a -> b -> c -> d -> e.
        ktr = kt.annotate(e=kt4[kt3[kt2[kt1[kt.a].b].c].d].e)
        self.assertEqual(ktr.aggregate(agg.collect(ktr.e)), ['quam'])

        ktr = kt.select(e=kt4[kt3[kt2[kt1[kt.a].b].c].d].e)
        self.assertEqual(ktr.aggregate(agg.collect(ktr.e)), ['quam'])

        n_matching = kt.filter(kt4[kt3[kt2[kt1[kt.a].b].c].d].e == 'quam').count()
        self.assertEqual(n_matching, 1)

        # Joining qual back onto the rows it came from must never disagree.
        m = hl.import_vcf(resource('sample.vcf'))
        vkt = m.rows()
        vkt = vkt.select(vkt.qual)
        vkt = vkt.annotate(qual2=m.index_rows(vkt.key).qual)
        self.assertEqual(vkt.filter(vkt.qual != vkt.qual2).count(), 0)

        m2 = m.annotate_rows(qual2=vkt.index(m.row_key).qual)
        self.assertEqual(m2.filter_rows(m2.qual != m2.qual2).count_rows(), 0)

        m3 = m.annotate_rows(qual2=m.index_rows(m.row_key).qual)
        self.assertEqual(m3.filter_rows(m3.qual != m3.qual2).count_rows(), 0)

        # Column join keyed on the first five characters of the sample id.
        kt5 = hl.utils.range_table(1).annotate(key='C1589').key_by('key')
        m4 = m.annotate_cols(foo=m.s[:5])
        m4 = m4.annotate_cols(idx=kt5[m4.foo].idx)
        n_C1589 = m.filter_cols(m.s[:5] == 'C1589').count_cols()
        self.assertGreater(n_C1589, 1)
        self.assertEqual(m4.filter_cols(hl.is_defined(m4.idx)).count_cols(), n_C1589)

        # Globals of one table read from another.
        kt = hl.utils.range_table(1)
        kt = kt.annotate_globals(foo=5)
        self.assertEqual(hl.eval(kt.foo), 5)

        kt2 = hl.utils.range_table(1)

        kt2 = kt2.annotate_globals(kt_foo=kt.index_globals().foo)
        self.assertEqual(hl.eval(kt2.globals.kt_foo), 5)
Example #4
0
 def test_agg_cols_filter(self):
     """Check ``agg.filter`` over columns wrapping collect, explode and group_by."""
     mt = hl.utils.range_matrix_table(1, 10)

     def past_seven():
         # The predicate shared by every case, rebuilt for each use.
         return mt.col_idx > 7

     collected = agg.filter(
         past_seven(),
         agg.collect(mt.col_idx + 1).append(0))
     exploded = agg.filter(
         past_seven(),
         agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                     [mt.col_idx, mt.col_idx + 1]))
     grouped = agg.filter(
         past_seven(),
         agg.group_by(mt.col_idx % 3,
                      hl.array(agg.collect_as_set(mt.col_idx + 1)).append(0)))

     cases = [
         (collected, [9, 10, 0]),
         (exploded, [9, 10, 10, 11, 0]),
         (grouped, {0: [10, 0], 2: [9, 0]}),
     ]
     for aggregation, expected in cases:
         actual = mt.select_rows(result=aggregation).result.collect()[0]
         self.assertEqual(actual, expected)
Example #5
0
 def test_agg_cols_group_by(self):
     """Check ``agg.group_by`` over columns, alone and wrapping filter/explode."""
     mt = hl.utils.range_matrix_table(1, 10)

     def set_plus_one():
         # Distinct (col_idx + 1) values as an array with a trailing 0.
         return hl.array(agg.collect_as_set(mt.col_idx + 1)).append(0)

     by_parity = agg.group_by(mt.col_idx % 2, set_plus_one())
     by_mod3_filtered = agg.group_by(
         mt.col_idx % 3,
         agg.filter(mt.col_idx > 7, set_plus_one()))
     by_mod3_exploded = agg.group_by(
         mt.col_idx % 3,
         agg.explode(
             lambda elt: agg.collect(elt + 1).append(0),
             hl.cond(mt.col_idx > 7,
                     [mt.col_idx, mt.col_idx + 1],
                     hl.empty_array(hl.tint32))))

     cases = [
         (by_parity, {0: [1, 3, 5, 7, 9, 0], 1: [2, 4, 6, 8, 10, 0]}),
         (by_mod3_filtered, {0: [10, 0], 1: [0], 2: [9, 0]}),
         (by_mod3_exploded, {0: [10, 11, 0], 1: [0], 2: [9, 10, 0]}),
     ]
     for aggregation, expected in cases:
         actual = mt.select_rows(result=aggregation).result.collect()[0]
         self.assertEqual(actual, expected)
Example #6
0
 def test_agg_cols_filter(self):
     """Verify column aggregations nested inside ``agg.filter``."""
     mt = hl.utils.range_matrix_table(1, 10)

     cases = []
     # Plain collect under the filter.
     cases.append((
         agg.filter(mt.col_idx > 7,
                    agg.collect(mt.col_idx + 1).append(0)),
         [9, 10, 0],
     ))
     # Explode under the filter.
     cases.append((
         agg.filter(mt.col_idx > 7,
                    agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                                [mt.col_idx, mt.col_idx + 1])),
         [9, 10, 10, 11, 0],
     ))
     # Group-by under the filter.
     cases.append((
         agg.filter(mt.col_idx > 7,
                    agg.group_by(mt.col_idx % 3,
                                 hl.array(agg.collect_as_set(mt.col_idx + 1)).append(0))),
         {0: [10, 0], 2: [9, 0]},
     ))

     for aggregation, expected in cases:
         row_results = mt.select_rows(result=aggregation).result.collect()
         self.assertEqual(row_results[0], expected)
Example #7
0
    def test_query(self):
        """Aggregate a three-row table with sum, count, collect and a filtered collect."""
        schema = hl.tstruct(a=hl.tint32,
                            b=hl.tint32,
                            c=hl.tint32,
                            d=hl.tint32,
                            e=hl.tstr,
                            f=hl.tarray(hl.tint32))

        field_names = ('a', 'b', 'c', 'd', 'e', 'f')
        row_values = [
            (4, 1, 3, 5, "hello", [1, 2, 3]),
            (0, 5, 13, -1, "cat", []),
            (4, 2, 20, 3, "dog", [5, 6, 7]),
        ]
        rows = [dict(zip(field_names, values)) for values in row_values]

        kt = hl.Table.parallelize(rows, schema)

        total_b = agg.sum(kt.b)
        n_rows = agg.count()
        all_e = agg.collect(kt.e)
        # Only rows with d >= 5 or a == 0 contribute to q4.
        selected_e = agg.collect(agg.filter((kt.d >= 5) | (kt.a == 0), kt.e))
        results = kt.aggregate(
            hl.Struct(q1=total_b, q2=n_rows, q3=all_e, q4=selected_e))

        self.assertEqual(results.q1, 8)
        self.assertEqual(results.q2, 3)
        # Collected order is not asserted, so compare as sets.
        self.assertEqual(set(results.q3), {"hello", "cat", "dog"})
        self.assertEqual(set(results.q4), {"hello", "cat"})
Example #8
0
    def test_joins(self):
        """Exercise joins written as indexing into tables and matrix tables.

        Covers a chain of four keyed table lookups, row joins against an
        imported VCF, and reading another table's globals.  Comparisons use
        ``assertEqual`` instead of ``assertTrue(a == b)`` so that a failure
        reports both operands.
        """
        def singleton():
            # range_table(1) with the 'idx' field dropped.
            return hl.utils.range_table(1).drop('idx')

        kt = singleton().annotate(a='foo')
        kt1 = singleton().annotate(a='foo', b='bar').key_by('a')
        kt2 = singleton().annotate(b='bar', c='baz').key_by('b')
        kt3 = singleton().annotate(c='baz', d='qux').key_by('c')
        kt4 = singleton().annotate(d='qux', e='quam').key_by('d')

        # Four chained lookups: a -> b -> c -> d -> e.
        ktr = kt.annotate(e=kt4[kt3[kt2[kt1[kt.a].b].c].d].e)
        self.assertEqual(ktr.aggregate(agg.collect(ktr.e)), ['quam'])

        ktr = kt.select(e=kt4[kt3[kt2[kt1[kt.a].b].c].d].e)
        self.assertEqual(ktr.aggregate(agg.collect(ktr.e)), ['quam'])

        self.assertEqual(kt.filter(kt4[kt3[kt2[kt1[kt.a].b].c].d].e == 'quam').count(), 1)

        # Joining qual back onto the rows it came from must never disagree.
        m = hl.import_vcf(resource('sample.vcf'))
        vkt = m.rows()
        vkt = vkt.select(vkt.locus, vkt.alleles, vkt.qual)
        vkt = vkt.annotate(qual2=m[(vkt.locus, vkt.alleles), :].qual)
        self.assertEqual(vkt.filter(vkt.qual != vkt.qual2).count(), 0)

        m2 = m.annotate_rows(qual2=vkt[m.locus, m.alleles].qual)
        self.assertEqual(m2.filter_rows(m2.qual != m2.qual2).count_rows(), 0)

        m3 = m.annotate_rows(qual2=m[(m.locus, m.alleles), :].qual)
        self.assertEqual(m3.filter_rows(m3.qual != m3.qual2).count_rows(), 0)

        # Globals of one table read from another.
        kt = hl.utils.range_table(1)
        kt = kt.annotate_globals(foo=5)
        self.assertEqual(kt.foo.value, 5)

        kt2 = hl.utils.range_table(1)

        kt2 = kt2.annotate_globals(kt_foo=kt[:].foo)
        self.assertEqual(kt2.globals.kt_foo.value, 5)
Example #9
0
    def test_aggregate1(self):
        """Aggregate a three-row table with sum, count, collect, filter and explode."""
        schema = hl.tstruct(a=hl.tint32,
                            b=hl.tint32,
                            c=hl.tint32,
                            d=hl.tint32,
                            e=hl.tstr,
                            f=hl.tarray(hl.tint32))

        field_names = ('a', 'b', 'c', 'd', 'e', 'f')
        row_values = [
            (4, 1, 3, 5, "hello", [1, 2, 3]),
            (0, 5, 13, -1, "cat", []),
            (4, 2, 20, 3, "dog", [5, 6, 7]),
        ]
        rows = [dict(zip(field_names, values)) for values in row_values]

        kt = hl.Table.parallelize(rows, schema)

        queries = hl.Struct(
            q1=agg.sum(kt.b),
            q2=agg.count(),
            q3=agg.collect(kt.e),
            # Only rows with d >= 5 or a == 0 contribute to q4.
            q4=agg.filter((kt.d >= 5) | (kt.a == 0), agg.collect(kt.e)),
            # Mean over every element of every row's `f` array.
            q5=agg.explode(lambda elt: agg.mean(elt), kt.f),
        )
        results = kt.aggregate(queries)

        self.assertEqual(results.q1, 8)
        self.assertEqual(results.q2, 3)
        # Collected order is not asserted, so compare as sets.
        self.assertEqual(set(results.q3), {"hello", "cat", "dog"})
        self.assertEqual(set(results.q4), {"hello", "cat"})
        self.assertAlmostEqual(results.q5, 4)
def get_gtex_summary(gtex_rsem_path,
                     gtex_tx_summary_out_path,
                     get_medians=True):
    """
    Get GTEx RSEM table with ENSTs and ENSGs as rows and GTEx samples as columns (e.g. Muscle-Skeletal.12,
    Adipose.27 etc.) and write out a table with same rows, and tissues as columns (Muscle-Skeletal, Adipose etc.)
    with cells representing summary expression of transcripts across tissues (ie. mean or median).

    The per-tissue summary is stored in the entry field ``median_tx_expr`` when ``get_medians`` is True and in
    ``mean_tx_expr`` otherwise. In both cases the per-transcript array of summaries across tissues is stored in
    the row field ``agg_expression``.

    :param str gtex_rsem_path: Output of RSEM quantifications from GTEx
    Example: "gs://gnomad-berylc/reheadered.GTEx_Analysis_2016-09-07_RSEMv1.2.22_transcript_tpm.txt.bgz"
    :param str gtex_tx_summary_out_path: Path to write out (any existing output is overwritten).
    Example: "gs://gnomad-berylc/tx-annotation/hail2/GTEx.V7.tx_medians.030818.mt"
    :param bool get_medians: Default True. If False, returns mean transcript expression per tissue
    :return: Writes out the tissue-grouped GTEx transcript expression matrix table.
    :rtype: None
    """

    gtex = hl.import_matrix_table(gtex_rsem_path,
                                  row_key='transcript_id',
                                  row_fields={
                                      'transcript_id': hl.tstr,
                                      'gene_id': hl.tstr
                                  },
                                  entry_type=hl.tfloat64)

    # Column ids look like "<tissue>.<n>"; keep the part before the first dot
    gtex = gtex.annotate_cols(tissue=gtex.col_id.split("\\.")[0])

    # Summarise per tissue and remember which entry field holds the summary,
    # so that the mean branch does not look for the median field.
    if get_medians:
        gtex = gtex.group_cols_by(gtex.tissue).aggregate(
            median_tx_expr=hl.median(agg.collect(gtex.x)))
        summary_expr = gtex.median_tx_expr
    else:
        gtex = gtex.group_cols_by(
            gtex.tissue).aggregate(mean_tx_expr=hl.mean(agg.collect(gtex.x)))
        summary_expr = gtex.mean_tx_expr

    # Make a new column as an array of the values across tissues (per transcript)
    gtex = gtex.annotate_rows(agg_expression=agg.collect(summary_expr))

    # Modify the gtex table to remove version numbers
    gtex = gtex.annotate_rows(transcript_id=gtex.transcript_id.split("\\.")[0])
    gtex = gtex.annotate_rows(gene_id=gtex.gene_id.split("\\.")[0])

    gtex.write(gtex_tx_summary_out_path, overwrite=True)
Example #11
0
    def test_aggregate(self):
        """Run row, column and entry aggregations over the fixture dataset."""
        mt = self.get_vds()

        mt = mt.annotate_globals(foo=5)
        mt = mt.annotate_rows(x1=agg.count())
        mt = mt.annotate_cols(y1=agg.count())
        mt = mt.annotate_entries(z1=mt.DP)

        n_rows = mt.aggregate_rows(agg.count())
        n_cols = mt.aggregate_cols(agg.count())
        n_entries = mt.aggregate_entries(agg.count())

        first_sample = mt.aggregate_entries(hl.agg.take(mt.s, 1)[0])
        self.assertIsNotNone(first_sample)

        self.assertEqual(n_rows, 346)
        self.assertEqual(n_cols, 100)
        self.assertEqual(n_entries, n_rows * n_cols)

        # The remaining aggregations only need to run; their results are not
        # inspected.
        row_query = hl.Struct(x=agg.collect(mt.locus.contig),
                              y=agg.collect(mt.x1))
        mt.aggregate_rows(row_query)

        col_query = hl.Struct(x=agg.collect(mt.s),
                              y=agg.collect(mt.y1))
        mt.aggregate_cols(col_query)

        entry_query = hl.Struct(
            x=agg.collect(agg.filter(False, mt.y1)),
            y=agg.collect(agg.filter(hl.rand_bool(0.1), mt.GT)))
        mt.aggregate_entries(entry_query)
Example #12
0
    def test_aggregate(self):
        """Run row, column and entry aggregations over the fixture dataset."""
        mt = self.get_vds()

        mt = mt.annotate_globals(foo=5)
        mt = mt.annotate_rows(x1=agg.count())
        mt = mt.annotate_cols(y1=agg.count())
        mt = mt.annotate_entries(z1=mt.DP)

        row_count = mt.aggregate_rows(agg.count())
        col_count = mt.aggregate_cols(agg.count())
        entry_count = mt.aggregate_entries(agg.count())

        self.assertIsNotNone(mt.aggregate_entries(hl.agg.take(mt.s, 1)[0]))

        self.assertEqual(row_count, 346)
        self.assertEqual(col_count, 100)
        self.assertEqual(entry_count, row_count * col_count)

        # The remaining aggregations only need to run; their results are not
        # inspected.
        mt.aggregate_rows(
            hl.Struct(x=agg.collect(mt.locus.contig),
                      y=agg.collect(mt.x1)))

        mt.aggregate_cols(
            hl.Struct(x=agg.collect(mt.s),
                      y=agg.collect(mt.y1)))

        mt.aggregate_entries(
            hl.Struct(x=agg.filter(False, agg.collect(mt.y1)),
                      y=agg.filter(hl.rand_bool(0.1), agg.collect(mt.GT))))
Example #13
0
    def test_aggregate(self):
        """Run a battery of aggregators over a two-row table grouped by ``status``."""
        schema = hl.tstruct(status=hl.tint32, GT=hl.tcall, qPheno=hl.tint32)

        rows = [
            {'status': 0, 'GT': hl.Call([0, 0]), 'qPheno': 3},
            {'status': 0, 'GT': hl.Call([0, 1]), 'qPheno': 13},
        ]

        kt = hl.Table.parallelize(rows, schema)

        by_status = kt.group_by(status=kt.status)
        aggregations = dict(
            x1=agg.collect(kt.qPheno * 2),
            x2=agg.collect(agg.explode([kt.qPheno, kt.qPheno + 1])),
            x3=agg.min(kt.qPheno),
            x4=agg.max(kt.qPheno),
            x5=agg.sum(kt.qPheno),
            x6=agg.product(hl.int64(kt.qPheno)),
            x7=agg.count(),
            x8=agg.count_where(kt.qPheno == 3),
            x9=agg.fraction(kt.qPheno == 1),
            x10=agg.stats(hl.float64(kt.qPheno)),
            x11=agg.hardy_weinberg(kt.GT),
            x13=agg.inbreeding(kt.GT, 0.1),
            x14=agg.call_stats(kt.GT, ["A", "T"]),
            x15=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')))[0],
            x16=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')).c.banana)[0],
        )
        # Both rows share status 0, so the first group is the one under test.
        first_group = by_status.aggregate(**aggregations).take(1)[0]
        result = convert_struct_to_dict(first_group)

        expected = {
            'status': 0,
            'x1': [6, 26],
            'x2': [3, 4, 13, 14],
            'x3': 3,
            'x4': 13,
            'x5': 16,
            'x6': 39,
            'x7': 2,
            'x8': 1,
            'x9': 0.0,
            'x10': {'min': 3.0, 'max': 13.0, 'sum': 16.0, 'stdev': 5.0, 'n': 2, 'mean': 8.0},
            'x11': {'r_expected_het_freq': 0.5, 'p_hwe': 0.5},
            'x13': {'n_called': 2, 'expected_homs': 1.64, 'f_stat': -1.777777777777777,
                    'observed_homs': 1},
            'x14': {'AC': [3, 1], 'AF': [0.75, 0.25], 'GC': [1, 1, 0], 'AN': 4},
            'x15': {'a': 5, 'c': {'banana': 'apple'}, 'b': 'foo'},
            'x16': 'apple',
        }

        self.assertDictEqual(result, expected)
Example #14
0
    def test_agg_cols_explode(self):
        """Verify column aggregations nested inside ``agg.explode``."""
        mt = hl.utils.range_matrix_table(1, 10)

        def pair_past_seven():
            # [col_idx, col_idx + 1] for col_idx > 7, otherwise an empty array.
            return hl.cond(mt.col_idx > 7,
                           [mt.col_idx, mt.col_idx + 1],
                           hl.empty_array(hl.tint32))

        cases = []
        # Plain collect over the exploded elements.
        cases.append((
            agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                        pair_past_seven()),
            [9, 10, 10, 11, 0],
        ))
        # Explode nested inside explode.
        cases.append((
            agg.explode(
                lambda elt: agg.explode(
                    lambda elt2: agg.collect(elt2 + 1).append(0),
                    [elt, elt + 1]),
                pair_past_seven()),
            [9, 10, 10, 11, 10, 11, 11, 12, 0],
        ))
        # Filter applied to the exploded elements.
        cases.append((
            agg.explode(
                lambda elt: agg.filter(elt > 8, agg.collect(elt + 1).append(0)),
                pair_past_seven()),
            [10, 10, 11, 0],
        ))
        # Group-by applied to the exploded elements.
        cases.append((
            agg.explode(
                lambda elt: agg.group_by(elt % 3, agg.collect(elt + 1).append(0)),
                pair_past_seven()),
            {0: [10, 10, 0], 1: [11, 0], 2: [9, 0]},
        ))

        for aggregation, expected in cases:
            row_results = mt.select_rows(result=aggregation).result.collect()
            self.assertEqual(row_results[0], expected)
Example #15
0
 def test_agg_cols_group_by(self):
     """Verify column aggregations nested inside ``agg.group_by``."""
     mt = hl.utils.range_matrix_table(1, 10)

     cases = []
     # Group by parity.
     cases.append((
         agg.group_by(mt.col_idx % 2,
                      hl.array(agg.collect_as_set(mt.col_idx + 1)).append(0)),
         {0: [1, 3, 5, 7, 9, 0], 1: [2, 4, 6, 8, 10, 0]},
     ))
     # Group by col_idx % 3 with a filter inside each group.
     cases.append((
         agg.group_by(
             mt.col_idx % 3,
             agg.filter(mt.col_idx > 7,
                        hl.array(agg.collect_as_set(mt.col_idx + 1)).append(0))),
         {0: [10, 0], 1: [0], 2: [9, 0]},
     ))
     # Group by col_idx % 3 with an explode inside each group.
     cases.append((
         agg.group_by(
             mt.col_idx % 3,
             agg.explode(lambda elt: agg.collect(elt + 1).append(0),
                         hl.cond(mt.col_idx > 7,
                                 [mt.col_idx, mt.col_idx + 1],
                                 hl.empty_array(hl.tint32)))),
         {0: [10, 11, 0], 1: [0], 2: [9, 10, 0]},
     ))

     for aggregation, expected in cases:
         row_results = mt.select_rows(result=aggregation).result.collect()
         self.assertEqual(row_results[0], expected)
Example #16
0
    def test_aggregate2(self):
        """Run a battery of aggregators over a two-row table grouped by ``status``."""
        schema = hl.tstruct(status=hl.tint32, GT=hl.tcall, qPheno=hl.tint32)

        rows = [
            {'status': 0, 'GT': hl.Call([0, 0]), 'qPheno': 3},
            {'status': 0, 'GT': hl.Call([0, 1]), 'qPheno': 13},
        ]

        kt = hl.Table.parallelize(rows, schema)

        by_status = kt.group_by(status=kt.status)
        aggregations = dict(
            x1=agg.collect(kt.qPheno * 2),
            x2=agg.explode(lambda elt: agg.collect(elt), [kt.qPheno, kt.qPheno + 1]),
            x3=agg.min(kt.qPheno),
            x4=agg.max(kt.qPheno),
            x5=agg.sum(kt.qPheno),
            x6=agg.product(hl.int64(kt.qPheno)),
            x7=agg.count(),
            x8=agg.count_where(kt.qPheno == 3),
            x9=agg.fraction(kt.qPheno == 1),
            x10=agg.stats(hl.float64(kt.qPheno)),
            x11=agg.hardy_weinberg_test(kt.GT),
            x13=agg.inbreeding(kt.GT, 0.1),
            x14=agg.call_stats(kt.GT, ["A", "T"]),
            x15=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')))[0],
            x16=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')).c.banana)[0],
            # Exploding a missing array or set is expected to collect nothing.
            x17=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tarray(hl.tint32))),
            x18=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tset(hl.tint32))),
            x19=agg.take(kt.GT, 1, ordering=-kt.qPheno),
        )
        # Both rows share status 0, so the first group is the one under test.
        first_group = by_status.aggregate(**aggregations).take(1)[0]
        result = convert_struct_to_dict(first_group)

        expected = {
            'status': 0,
            'x1': [6, 26],
            'x2': [3, 4, 13, 14],
            'x3': 3,
            'x4': 13,
            'x5': 16,
            'x6': 39,
            'x7': 2,
            'x8': 1,
            'x9': 0.0,
            'x10': {'min': 3.0, 'max': 13.0, 'sum': 16.0, 'stdev': 5.0, 'n': 2, 'mean': 8.0},
            'x11': {'het_freq_hwe': 0.5, 'p_value': 0.5},
            'x13': {'n_called': 2, 'expected_homs': 1.64, 'f_stat': -1.777777777777777,
                    'observed_homs': 1},
            'x14': {'AC': [3, 1], 'AF': [0.75, 0.25], 'AN': 4, 'homozygote_count': [1, 0]},
            'x15': {'a': 5, 'c': {'banana': 'apple'}, 'b': 'foo'},
            'x16': 'apple',
            'x17': [],
            'x18': [],
            'x19': [hl.Call([0, 1])],
        }

        # Show the complete dictionary diff if the comparison fails.
        self.maxDiff = None

        self.assertDictEqual(result, expected)
def tx_annotate_mt(mt,
                   gtex,
                   filter_to_genes=None,
                   gene_column_in_mt=None,
                   filter_to_csqs=None,
                   out_tx_annotation_tsv=None,
                   out_tx_annotation_kt=None,
                   filter_to_homs=False):
    """
    Annotate variants in the input MatrixTable with transcript-based expression values across GTEx. Returns Table.

    :param MatrixTable mt: Input MatrixTable; its rows are exploded on vep.transcript_consequences
    :param MatrixTable gtex: Input GTEx summary MatrixTable, must have transcript_id column to key by
    :param None or set filter_to_genes: Default None. If you'd like to filter the mt before annotating
    (decreases time) feed in a list or set of genes
    :param str gene_column_in_mt: Must be set if filter_to_genes != None.
    Column in matrix table that contains gene information within vep.transcript_consequences.
    often ["gene_id", "gene_symbol"]
    :param None or list filter_to_csqs: Default None. If you'd like to filter the mt before annotating
    (decreases time) feed in a list or set of consequence terms.
    Example = ["stop_gained","splice_donor_variant", "splice_acceptor_variant","frameshift_variant"]
    :param None or str out_tx_annotation_tsv: Default None.
    If you'd like to write out the results table as a tsv, provide a tsv path
    :param None or str out_tx_annotation_kt: Default None.
    If you'd like to write out the results table as a Hail 0.2 table, provide a .kt path
    :param bool filter_to_homs: Default False
    If True, filter to variants with at least one homozygote in dataset
    :return: Table with columns: variant, worst_csq, ensg, LOFTEE LOF, LOFTEE LOF Flag, transcript-aware expression
    by GTEx Tissue
    :rtype: Table with variants annotated with transcript-aware tissue expression
    """

    # GTEx rows as a Table keyed by transcript_id, used as the join target below
    gtex_by_transcript = gtex.rows().key_by("transcript_id")

    # One row per transcript consequence, so each row carries a single transcript ID
    exploded = mt.explode_rows(mt.vep.transcript_consequences)

    # Add worst csq
    exploded = simplify_worst_csq(exploded)

    variants = exploded.rows()

    if filter_to_genes:
        print("Filtering to genes of interest")
        variants = filter_table_to_gene_list(variants, filter_to_genes,
                                             gene_column_in_mt)

    if filter_to_csqs:
        print("Filtering to csqs in %s" % (",".join(filter_to_csqs)))
        variants = filter_table_to_csqs(variants, filter_to_csqs)

    if filter_to_homs:
        print(
            "Filtering to variants with at least 1 homozygote sample in dataset"
        )
        variants = variants.filter(variants.info.Hom[variants.a_index - 1] > 0)

    # Join the GTEx values onto each row through the transcript ID
    variants = variants.annotate(
        tx_data=gtex_by_transcript[
            variants.vep.transcript_consequences.transcript_id])

    # Group by gene, worst_csq and variant, and do a pairwise-sum
    consequence = variants.vep.transcript_consequences
    by_variant = variants.group_by(
        worst_csq=variants.worst_csq,
        ensg=consequence.gene_id,
        locus=variants.locus,
        alleles=variants.alleles,
        lof=consequence.lof,
        lof_flag=consequence.lof_flags)
    grouped_table = by_variant.aggregate(
        tx_annotation=agg.array_sum(variants.tx_data.agg_expression))

    # Expand the columns from the arrays and add tissues as headers
    tissue_ids = gtex.aggregate_cols(agg.collect(gtex.tissue))
    index_of = {}
    for position, tissue in enumerate(tissue_ids):
        index_of[tissue] = position

    def as_field_name(tissue):
        # This is currently a hack but Tim will fix it at which point the replacements can be removed
        for char in ("-", " ", "(", ")"):
            tissue = tissue.replace(char, "_")
        return tissue

    tissue_fields = {}
    for tissue in tissue_ids:
        tissue_fields[as_field_name(tissue)] = (
            grouped_table.tx_annotation[index_of[tissue]])
    tx_annotation_table = grouped_table.annotate(**tissue_fields)

    if out_tx_annotation_tsv:
        print("Writing tsv file to %s" % out_tx_annotation_tsv)
        tx_annotation_table.export(out_tx_annotation_tsv)

    if out_tx_annotation_kt:
        print("Writing Table to %s" % out_tx_annotation_kt)
        tx_annotation_table.write(out_tx_annotation_kt)

    return tx_annotation_table