Example #1
    def test_aggregate(self):
        vds = self.get_vds()

        vds = vds.annotate_globals(foo=5)
        vds = vds.annotate_rows(x1=agg.count())
        vds = vds.annotate_cols(y1=agg.count())
        vds = vds.annotate_entries(z1=vds.DP)

        qv = vds.aggregate_rows(agg.count())
        qs = vds.aggregate_cols(agg.count())
        qg = vds.aggregate_entries(agg.count())

        self.assertIsNotNone(vds.aggregate_entries(hl.agg.take(vds.s, 1)[0]))

        self.assertEqual(qv, 346)
        self.assertEqual(qs, 100)
        self.assertEqual(qg, qv * qs)

        qvs = vds.aggregate_rows(hl.Struct(x=agg.collect(vds.locus.contig),
                                           y=agg.collect(vds.x1)))

        qss = vds.aggregate_cols(hl.Struct(x=agg.collect(vds.s),
                                           y=agg.collect(vds.y1)))

        qgs = vds.aggregate_entries(hl.Struct(x=agg.filter(False, agg.collect(vds.y1)),
                                              y=agg.filter(hl.rand_bool(0.1), agg.collect(vds.GT))))
Example #2
    def test_weird_names(self):
        ds = self.get_vds()
        exprs = {'a': 5, '   a    ': 5, r'\%!^!@#&#&$%#$%': [5], '$': 5, 'ß': 5}

        ds.annotate_globals(**exprs)
        ds.select_globals(**exprs)

        ds.annotate_cols(**exprs)
        ds1 = ds.select_cols(**exprs)

        ds.annotate_rows(**exprs)
        ds2 = ds.select_rows(**exprs)

        ds.annotate_entries(**exprs)
        ds.select_entries(**exprs)

        ds1.explode_cols(r'\%!^!@#&#&$%#$%')
        ds1.explode_cols(ds1[r'\%!^!@#&#&$%#$%'])
        ds1.group_cols_by(ds1.a).aggregate(**{'*``81': agg.count()})

        ds1.drop(r'\%!^!@#&#&$%#$%')
        ds1.drop(ds1[r'\%!^!@#&#&$%#$%'])

        ds2.explode_rows(r'\%!^!@#&#&$%#$%')
        ds2.explode_rows(ds2[r'\%!^!@#&#&$%#$%'])
        ds2.group_rows_by(ds2.a).aggregate(**{'*``81': agg.count()})
Example #3
    def test_concordance(self):
        dataset = get_dataset()
        glob_conc, cols_conc, rows_conc = hl.concordance(dataset, dataset)

        self.assertEqual(sum([sum(glob_conc[i]) for i in range(5)]), dataset.count_rows() * dataset.count_cols())

        counts = dataset.aggregate_entries(hl.Struct(n_het=agg.filter(dataset.GT.is_het(), agg.count()),
                                                     n_hom_ref=agg.filter(dataset.GT.is_hom_ref(),
                                                                          agg.count()),
                                                     n_hom_var=agg.filter(dataset.GT.is_hom_var(),
                                                                          agg.count()),
                                                     nNoCall=agg.filter(hl.is_missing(dataset.GT),
                                                                        agg.count())))

        self.assertEqual(glob_conc[0][0], 0)
        self.assertEqual(glob_conc[1][1], counts.nNoCall)
        self.assertEqual(glob_conc[2][2], counts.n_hom_ref)
        self.assertEqual(glob_conc[3][3], counts.n_het)
        self.assertEqual(glob_conc[4][4], counts.n_hom_var)
        for i in range(5):
            for j in range(5):
                if i != j:
                    self.assertEqual(glob_conc[i][j], 0)

        self.assertTrue(cols_conc.all(hl.sum(hl.flatten(cols_conc.concordance)) == dataset.count_rows()))
        self.assertTrue(rows_conc.all(hl.sum(hl.flatten(rows_conc.concordance)) == dataset.count_cols()))

        cols_conc.write('/tmp/foo.kt', overwrite=True)
        rows_conc.write('/tmp/foo.kt', overwrite=True)
Example #4
    def test_annotate(self):
        vds = self.get_vds()
        vds = vds.annotate_globals(foo=5)

        self.assertEqual(vds.globals.dtype, hl.tstruct(foo=hl.tint32))

        vds = vds.annotate_rows(x1=agg.count(),
                                x2=agg.fraction(False),
                                x3=agg.count_where(True),
                                x4=vds.info.AC + vds.foo)

        vds = vds.annotate_cols(apple=6)
        vds = vds.annotate_cols(y1=agg.count(),
                                y2=agg.fraction(False),
                                y3=agg.count_where(True),
                                y4=vds.foo + vds.apple)

        expected_schema = hl.tstruct(s=hl.tstr, apple=hl.tint32, y1=hl.tint64, y2=hl.tfloat64, y3=hl.tint64,
                                     y4=hl.tint32)

        self.assertTrue(schema_eq(vds.col.dtype, expected_schema),
                        "expected: " + str(expected_schema) + "\nactual: " + str(vds.col.dtype))

        vds = vds.select_entries(z1=vds.x1 + vds.foo,
                                 z2=vds.x1 + vds.y1 + vds.foo)
        self.assertTrue(schema_eq(vds.entry.dtype, hl.tstruct(z1=hl.tint64, z2=hl.tint64)))
Example #5
    def test_filter(self):
        vds = self.get_vds()
        vds = vds.annotate_globals(foo=5)
        vds = vds.annotate_rows(x1=agg.count())
        vds = vds.annotate_cols(y1=agg.count())
        vds = vds.annotate_entries(z1=vds.DP)

        vds = vds.filter_rows((vds.x1 == 5) & (agg.count() == 3) & (vds.foo == 2))
        vds = vds.filter_cols((vds.y1 == 5) & (agg.count() == 3) & (vds.foo == 2))
        vds = vds.filter_entries((vds.z1 < 5) & (vds.y1 == 3) & (vds.x1 == 5) & (vds.foo == 2))
        vds.count_rows()
Example #6
    def test_filter(self):
        vds = self.get_vds()
        vds = vds.annotate_globals(foo=5)
        vds = vds.annotate_rows(x1=agg.count())
        vds = vds.annotate_cols(y1=agg.count())
        vds = vds.annotate_entries(z1=vds.DP)

        vds = vds.filter_rows((vds.x1 == 5) & (agg.count() == 3)
                              & (vds.foo == 2))
        vds = vds.filter_cols((vds.y1 == 5) & (agg.count() == 3)
                              & (vds.foo == 2))
        vds = vds.filter_entries((vds.z1 < 5) & (vds.y1 == 3) & (vds.x1 == 5)
                                 & (vds.foo == 2))
        vds.count_rows()
Example #7
def get_cnt_matrix(mnv_table, region="ALL", dist=1, minimum_cnt=0, PASS=True, part_size=1000, hom=False):
    # mnv_table: Hail table of MNVs
    # region: BED file defining the regions of interest (e.g. enhancer regions)
    # dist: distance between the two SNPs
    # PASS=True: restrict to pairs where both variants are PASS
    # indels are no longer considered
    # filter by region if a BED file path is given as region
    if region != "ALL":
        bed = hl.import_bed(region)
        mnv_table = mnv_table.filter(hl.is_defined(bed[mnv_table.locus]))
    if PASS == "NO":  # exclusively keep pairs with at least one non-PASS variant
        mnv_table = mnv_table.filter((mnv_table.filters.length() > 0) | (mnv_table.prev_filters.length() > 0))
    elif PASS == True:
        mnv_table = mnv_table.filter((mnv_table.filters.length() == 0) & (mnv_table.prev_filters.length() == 0))
    if hom:
        mnv_table = mnv_table.filter(mnv_table.n_homhom > 0)
    # count MNV occurrences -- restricting to SNPs at the requested distance
    mnv = mnv_table.filter((mnv_table.alleles[0].length() == 1) &
                           (mnv_table.alleles[1].length() == 1) &
                           (mnv_table.prev_alleles[0].length() == 1) &
                           (mnv_table.prev_alleles[1].length() == 1) &
                           ((mnv_table.locus.position - mnv_table.prev_locus.position) == dist))

    # repartition to a proper size
    mnv = mnv.repartition(part_size)

    mnv_cnt = mnv.group_by("alleles", "prev_alleles").aggregate(cnt=agg.count())  # count occurrences
    mnv_cnt = mnv_cnt.annotate(
        refs=mnv_cnt.prev_alleles[0] + "N" * (dist - 1) + mnv_cnt.alleles[0])  # annotate combined refs
    mnv_cnt = mnv_cnt.annotate(
        alts=mnv_cnt.prev_alleles[1] + "N" * (dist - 1) + mnv_cnt.alleles[1])  # annotate combined alts

    if minimum_cnt > 0:
        mnv_cnt = mnv_cnt.filter(mnv_cnt.cnt > minimum_cnt)  # remove trivial ones
    return mnv_cnt.select("refs", "alts", "cnt")
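
A minimal usage sketch for the function above; the bucket and table paths are hypothetical placeholders, and the input table is assumed to carry the fields the function references (locus, prev_locus, alleles, prev_alleles, filters, prev_filters, n_homhom):

import hail as hl
from hail import agg

hl.init()

# hypothetical input: a Hail table of MNV candidates
mnv_ht = hl.read_table("gs://my-bucket/mnv/mnv_table.ht")

# adjacent (dist=1) PASS SNP pairs seen more than 10 times,
# restricted to a hypothetical enhancer BED file
cnt = get_cnt_matrix(mnv_ht, region="gs://my-bucket/enhancers.bed",
                     dist=1, minimum_cnt=10, PASS=True)
cnt.show()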
Example #8
def get_cnt_matrix_alldist(mnv_table, region="ALL", dist_min=1, dist_max=10, minimum_cnt=0, PASS=True, part_size=1000):
    # same as get_cnt_matrix, but over a range of distances instead of a single one
    if region != "ALL":
        bed = hl.import_bed(region, skip_invalid_intervals=True)
        mnv_table = mnv_table.filter(hl.is_defined(bed[mnv_table.locus]))
    if PASS:
        mnv_table = mnv_table.filter((mnv_table.filters.length() == 0) & (mnv_table.prev_filters.length() == 0))

    # count MNV occurrences -- restricting to SNPs
    mnv_table = mnv_table.filter((mnv_table.alleles[0].length() == 1) &
                                 (mnv_table.alleles[1].length() == 1) &
                                 (mnv_table.prev_alleles[0].length() == 1) &
                                 (mnv_table.prev_alleles[1].length() == 1))
    pdall = {}
    for dist in range(dist_min, dist_max + 1):
        mnv = mnv_table.filter((mnv_table.locus.position - mnv_table.prev_locus.position) == dist)  # restrict to this distance

        # repartition to a proper size
        mnv = mnv.repartition(part_size)

        mnv_cnt = mnv.group_by("alleles", "prev_alleles").aggregate(cnt=agg.count())  # count occurrences
        mnv_cnt = mnv_cnt.annotate(
            refs=mnv_cnt.prev_alleles[0] + "N" * (dist - 1) + mnv_cnt.alleles[0])  # annotate combined refs
        mnv_cnt = mnv_cnt.annotate(
            alts=mnv_cnt.prev_alleles[1] + "N" * (dist - 1) + mnv_cnt.alleles[1])  # annotate combined alts

        if minimum_cnt > 0:
            mnv_cnt = mnv_cnt.filter(mnv_cnt.cnt > minimum_cnt)  # remove trivial ones
        pdall[dist] = ht_cnt_mat_to_pd(mnv_cnt.select("refs", "alts", "cnt"))  # save as a pandas DataFrame, keyed by distance
        print("done d={0}".format(dist))
        print(tm.ctime())
    return pdall  # dictionary of DataFrames, one per distance
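
A similar sketch for the ranged variant, under the same assumptions as above; it additionally assumes the helper ht_cnt_mat_to_pd and the alias tm (presumably import time as tm) from the original module are in scope:

# hypothetical input table, as in the sketch above
mnv_ht = hl.read_table("gs://my-bucket/mnv/mnv_table.ht")

# one count matrix per distance from 1 to 10 bp, as pandas DataFrames
pd_by_dist = get_cnt_matrix_alldist(mnv_ht, dist_min=1, dist_max=10)
print(pd_by_dist[2].head())  # SNP pairs exactly 2 bp apart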
Example #9
    def test_select_cols(self):
        mt = hl.utils.range_matrix_table(3, 5, n_partitions=4)
        mt = mt.annotate_entries(e=mt.col_idx * mt.row_idx)
        mt = mt.annotate_globals(g=1)
        mt = mt.annotate_cols(sum=agg.sum(mt.e + mt.col_idx + mt.row_idx + mt.g) + mt.col_idx + mt.g,
                              count=agg.count_where(mt.e % 2 == 0),
                              foo=agg.count())

        result = convert_struct_to_dict(mt.cols().collect()[-2])
        self.assertEqual(result, {'col_idx': 3, 'sum': 28, 'count': 2, 'foo': 3})
Example #11
    def test_concordance(self):
        dataset = get_dataset()
        glob_conc, cols_conc, rows_conc = hl.concordance(dataset, dataset)

        self.assertEqual(sum([sum(glob_conc[i]) for i in range(5)]),
                         dataset.count_rows() * dataset.count_cols())

        counts = dataset.aggregate_entries(
            hl.Struct(n_het=agg.count(
                agg.filter(dataset.GT.is_het(), dataset.GT)),
                      n_hom_ref=agg.count(
                          agg.filter(dataset.GT.is_hom_ref(), dataset.GT)),
                      n_hom_var=agg.count(
                          agg.filter(dataset.GT.is_hom_var(), dataset.GT)),
                      nNoCall=agg.count(
                          agg.filter(hl.is_missing(dataset.GT), dataset.GT))))

        self.assertEqual(glob_conc[0][0], 0)
        self.assertEqual(glob_conc[1][1], counts.nNoCall)
        self.assertEqual(glob_conc[2][2], counts.n_hom_ref)
        self.assertEqual(glob_conc[3][3], counts.n_het)
        self.assertEqual(glob_conc[4][4], counts.n_hom_var)
        for i in range(5):
            for j in range(5):
                if i != j:
                    self.assertEqual(glob_conc[i][j], 0)

        self.assertTrue(
            cols_conc.all(
                hl.sum(hl.flatten(cols_conc.concordance)) ==
                dataset.count_rows()))
        self.assertTrue(
            rows_conc.all(
                hl.sum(hl.flatten(rows_conc.concordance)) ==
                dataset.count_cols()))

        cols_conc.write('/tmp/foo.kt', overwrite=True)
        rows_conc.write('/tmp/foo.kt', overwrite=True)
Example #12
    def test_query(self):
        vds = self.get_vds()

        vds = vds.annotate_globals(foo=5)
        vds = vds.annotate_rows(x1=agg.count())
        vds = vds.annotate_cols(y1=agg.count())
        vds = vds.annotate_entries(z1=vds.DP)

        qv = vds.aggregate_rows(agg.count())
        qs = vds.aggregate_cols(agg.count())
        qg = vds.aggregate_entries(agg.count())

        self.assertEqual(qv, 346)
        self.assertEqual(qs, 100)
        self.assertEqual(qg, qv * qs)

        qvs = vds.aggregate_rows(hl.Struct(x=agg.collect(vds.locus.contig),
                                           y=agg.collect(vds.x1)))

        qss = vds.aggregate_cols(hl.Struct(x=agg.collect(vds.s),
                                           y=agg.collect(vds.y1)))

        qgs = vds.aggregate_entries(hl.Struct(x=agg.collect(agg.filter(False, vds.y1)),
                                              y=agg.collect(agg.filter(hl.rand_bool(0.1), vds.GT))))
Example #13
    def test_aggregate2(self):
        schema = hl.tstruct(status=hl.tint32, GT=hl.tcall, qPheno=hl.tint32)

        rows = [{'status': 0, 'GT': hl.Call([0, 0]), 'qPheno': 3},
                {'status': 0, 'GT': hl.Call([0, 1]), 'qPheno': 13}]

        kt = hl.Table.parallelize(rows, schema)

        result = convert_struct_to_dict(
            kt.group_by(status=kt.status)
                .aggregate(
                x1=agg.collect(kt.qPheno * 2),
                x2=agg.explode(lambda elt: agg.collect(elt), [kt.qPheno, kt.qPheno + 1]),
                x3=agg.min(kt.qPheno),
                x4=agg.max(kt.qPheno),
                x5=agg.sum(kt.qPheno),
                x6=agg.product(hl.int64(kt.qPheno)),
                x7=agg.count(),
                x8=agg.count_where(kt.qPheno == 3),
                x9=agg.fraction(kt.qPheno == 1),
                x10=agg.stats(hl.float64(kt.qPheno)),
                x11=agg.hardy_weinberg_test(kt.GT),
                x13=agg.inbreeding(kt.GT, 0.1),
                x14=agg.call_stats(kt.GT, ["A", "T"]),
                x15=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')))[0],
                x16=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')).c.banana)[0],
                x17=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tarray(hl.tint32))),
                x18=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tset(hl.tint32))),
                x19=agg.take(kt.GT, 1, ordering=-kt.qPheno)
            ).take(1)[0])

        expected = {u'status': 0,
                    u'x13': {u'n_called': 2, u'expected_homs': 1.64, u'f_stat': -1.777777777777777,
                             u'observed_homs': 1},
                    u'x14': {u'AC': [3, 1], u'AF': [0.75, 0.25], u'AN': 4, u'homozygote_count': [1, 0]},
                    u'x15': {u'a': 5, u'c': {u'banana': u'apple'}, u'b': u'foo'},
                    u'x10': {u'min': 3.0, u'max': 13.0, u'sum': 16.0, u'stdev': 5.0, u'n': 2, u'mean': 8.0},
                    u'x8': 1, u'x9': 0.0, u'x16': u'apple',
                    u'x11': {u'het_freq_hwe': 0.5, u'p_value': 0.5},
                    u'x2': [3, 4, 13, 14], u'x3': 3, u'x1': [6, 26], u'x6': 39, u'x7': 2, u'x4': 13, u'x5': 16,
                    u'x17': [],
                    u'x18': [],
                    u'x19': [hl.Call([0, 1])]}

        self.maxDiff = None

        self.assertDictEqual(result, expected)
Example #15
    def test_weird_names(self):
        df = hl.utils.range_table(10)
        exprs = {'a': 5, '   a    ': 5, r'\%!^!@#&#&$%#$%': [5], '$': 5, 'ß': 5}

        df.annotate_globals(**exprs)
        df.select_globals(**exprs)

        df.annotate(**exprs)
        df.select(**exprs)
        df = df.transmute(**exprs)

        df.explode(r'\%!^!@#&#&$%#$%')
        df.explode(df[r'\%!^!@#&#&$%#$%'])

        df.drop(r'\%!^!@#&#&$%#$%')
        df.drop(df[r'\%!^!@#&#&$%#$%'])
        df.group_by(**{'*``81': df.a}).aggregate(c=agg.count())
Example #16
    def test_query(self):
        schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32))

        rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]},
                {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []},
                {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}]

        kt = hl.Table.parallelize(rows, schema)
        results = kt.aggregate(hl.Struct(q1=agg.sum(kt.b),
                                         q2=agg.count(),
                                         q3=agg.collect(kt.e),
                                         q4=agg.collect(agg.filter((kt.d >= 5) | (kt.a == 0), kt.e))))

        self.assertEqual(results.q1, 8)
        self.assertEqual(results.q2, 3)
        self.assertEqual(set(results.q3), {"hello", "cat", "dog"})
        self.assertEqual(set(results.q4), {"hello", "cat"})
Example #17
    def test_weird_names(self):
        df = hl.utils.range_table(10)
        exprs = {'a': 5, '   a    ': 5, r'\%!^!@#&#&$%#$%': [5]}

        df.annotate_globals(**exprs)
        df.select_globals(**exprs)

        df.annotate(**exprs)
        df.select(**exprs)
        df = df.transmute(**exprs)

        df.explode(r'\%!^!@#&#&$%#$%')
        df.explode(df[r'\%!^!@#&#&$%#$%'])

        df.drop(r'\%!^!@#&#&$%#$%')
        df.drop(df[r'\%!^!@#&#&$%#$%'])
        df.group_by(**{'*``81': df.a}).aggregate(c=agg.count())
Example #18
    def test_aggregate1(self):
        schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32))

        rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]},
                {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []},
                {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}]

        kt = hl.Table.parallelize(rows, schema)
        results = kt.aggregate(hl.Struct(q1=agg.sum(kt.b),
                                         q2=agg.count(),
                                         q3=agg.collect(kt.e),
                                         q4=agg.filter((kt.d >= 5) | (kt.a == 0), agg.collect(kt.e)),
                                         q5=agg.explode(lambda elt: agg.mean(elt), kt.f)))

        self.assertEqual(results.q1, 8)
        self.assertEqual(results.q2, 3)
        self.assertEqual(set(results.q3), {"hello", "cat", "dog"})
        self.assertEqual(set(results.q4), {"hello", "cat"})
        self.assertAlmostEqual(results.q5, 4)
Example #19
def summary_ccr(ht_ccr: hl.Table,
                file_output: str,
                ccr_pct_start: int = 0,
                ccr_pct_end: int = 100,
                ccr_pct_bins: int = 10,
                cumulative_histogram: bool = False,
                ccr_pct_cutoffs=None) -> None:
    """
    Summarize Coding Constraint Region (CCR) information (as a histogram) per gene.

    :param ht_ccr: CCR Hail table.
    :param file_output: File output path.
    :param ccr_pct_start: Start of histogram range.
    :param ccr_pct_end: End of histogram range.
    :param ccr_pct_bins: Number of bins.
    :param cumulative_histogram: Generate a cumulative histogram (rather than using bins).
    :param ccr_pct_cutoffs: Cut-offs used to generate the cumulative histogram.
    :return: None
    """

    if ccr_pct_cutoffs is None:
        ccr_pct_cutoffs = [90, 95, 99]

    if cumulative_histogram:
        # generate cumulative counts histogram
        summary_tb = (ht_ccr
                      .group_by('gene')
                      .aggregate(**{'ccr_above_' + str(cutoff): agg.filter(ht_ccr.ccr_pct >= cutoff,
                                                                           agg.count())
                                    for cutoff in ccr_pct_cutoffs})
                      )
    else:
        summary_tb = (ht_ccr
                      .group_by('gene')
                      .aggregate(ccr_bins=agg.hist(ht_ccr.ccr_pct, ccr_pct_start, ccr_pct_end, ccr_pct_bins))
                      )

        # get bin edges as list (expected n_bins + 1)
        bin_edges = summary_tb.aggregate(agg.take(summary_tb.ccr_bins.bin_edges, 1))[0]

        # unpack the array structure and annotate bins as individual fields
        summary_tb = (summary_tb
                      .annotate(**{'ccr_bin_' + str(bin_edges[k]) + '_' + str(bin_edges[k + 1]):
                                   summary_tb.ccr_bins.bin_freq[k]
                                   for k in range(len(bin_edges) - 1)})
                      .flatten()
                      )

        # drop the original packed histogram fields
        summary_tb = summary_tb.drop('ccr_bins.bin_edges', 'ccr_bins.bin_freq')

    # Export the summarized table
    summary_tb.export(output=file_output)
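
A hedged usage sketch for summary_ccr; the path is a hypothetical placeholder, and the input table is assumed to have one row per CCR with 'gene' and 'ccr_pct' fields, as the function's aggregations imply:

ht = hl.read_table("gs://my-bucket/ccr/ccrs.ht")  # hypothetical CCR table

# per-gene histogram of CCR percentiles, 10 bins over [0, 100]
summary_ccr(ht, file_output="ccr_hist.tsv")

# per-gene cumulative counts of CCRs above the 90th/95th/99th percentile
summary_ccr(ht, file_output="ccr_cumulative.tsv", cumulative_histogram=True)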
    else:
        et_all = et_all.join(et_union, how="outer")
    print("done chr {0}".format(chr))
et_all.select("s").write("gs://gnomad-qingbowang/MNV/tmp_exome_snp_pairs.ht", overwrite=True)
final = hl.import_table("gs://gnomad-public/release/2.1/mnv/gnomad_mnv_coding.tsv", types={'n_indv_ex': hl.tint32})
final = final.filter(final.n_indv_ex > 0)

# et_all.show()
et_all = hl.read_table("gs://gnomad-qingbowang/MNV/tmp_exome_snp_pairs.ht")
final = final.key_by()
final = final.key_by("snp1", "snp2")
et_all = et_all.key_by()
et_all = et_all.key_by("snp1", "snp2")
et_all = et_all.annotate(categ=final[et_all.key].categ)

cnt = et_all.group_by("s", "categ").aggregate(n=agg.count())
sums = cnt.group_by("categ").aggregate(sums=hl.agg.sum(cnt.n))

sums.show()
# output:
"""
| "Changed missense"           |  5499378 |
| "Gained PTV"                 |     8079 |
| "Gained missense"            |      202 |
| "Lost missense"              |    10434 |
| "Partially changed missense" |   843375 |
| "Rescued PTV"                |   556142 |
| "Rescued stop loss"          |        1 |
| "Unchanged"                  |  4953574 |
| "gained_stop_loss"           |    22762 |
| NA                           | 57642684 |
"""