def test_aggregate(self):
    """Smoke-test row/col/entry aggregation on the example dataset."""
    ds = self.get_vds()
    ds = ds.annotate_globals(foo=5)
    ds = ds.annotate_rows(x1=agg.count())
    ds = ds.annotate_cols(y1=agg.count())
    ds = ds.annotate_entries(z1=ds.DP)

    n_rows = ds.aggregate_rows(agg.count())
    n_cols = ds.aggregate_cols(agg.count())
    n_entries = ds.aggregate_entries(agg.count())

    # Taking a column field inside an entry aggregation must yield a value.
    self.assertIsNotNone(ds.aggregate_entries(hl.agg.take(ds.s, 1)[0]))

    self.assertEqual(n_rows, 346)
    self.assertEqual(n_cols, 100)
    self.assertEqual(n_entries, n_rows * n_cols)

    # Struct-valued aggregations mixing collect and filter; results are
    # not asserted, the queries only need to execute.
    row_result = ds.aggregate_rows(
        hl.Struct(x=agg.collect(ds.locus.contig), y=agg.collect(ds.x1)))
    col_result = ds.aggregate_cols(
        hl.Struct(x=agg.collect(ds.s), y=agg.collect(ds.y1)))
    entry_result = ds.aggregate_entries(
        hl.Struct(x=agg.filter(False, agg.collect(ds.y1)),
                  y=agg.filter(hl.rand_bool(0.1), agg.collect(ds.GT))))
def test_weird_names(self):
    """MatrixTable fields named with spaces, punctuation, and non-ASCII
    characters must survive annotate/select/explode/group/drop.

    Fix: the later field lookups used non-raw backslash string literals,
    which depend on the deprecated invalid-escape fallback (SyntaxWarning
    since Python 3.12, slated to become an error); they are now raw
    strings with identical runtime values.
    """
    ds = self.get_vds()
    exprs = {'a': 5, ' a ': 5, r'\%!^!@#&#&$%#$%': [5], '$': 5, 'ß': 5}
    ds.annotate_globals(**exprs)
    ds.select_globals(**exprs)
    ds.annotate_cols(**exprs)
    ds1 = ds.select_cols(**exprs)
    ds.annotate_rows(**exprs)
    ds2 = ds.select_rows(**exprs)
    ds.annotate_entries(**exprs)
    ds.select_entries(**exprs)
    ds1.explode_cols(r'\%!^!@#&#&$%#$%')
    ds1.explode_cols(ds1[r'\%!^!@#&#&$%#$%'])
    ds1.group_cols_by(ds1.a).aggregate(**{'*``81': agg.count()})
    ds1.drop(r'\%!^!@#&#&$%#$%')
    ds1.drop(ds1[r'\%!^!@#&#&$%#$%'])
    ds2.explode_rows(r'\%!^!@#&#&$%#$%')
    ds2.explode_rows(ds2[r'\%!^!@#&#&$%#$%'])
    ds2.group_rows_by(ds2.a).aggregate(**{'*``81': agg.count()})
def test_concordance(self):
    """Concordance of a dataset with itself: all mass on the diagonal."""
    dataset = get_dataset()
    glob_conc, cols_conc, rows_conc = hl.concordance(dataset, dataset)

    # Total concordance counts must equal the number of entries.
    self.assertEqual(sum([sum(glob_conc[i]) for i in range(5)]),
                     dataset.count_rows() * dataset.count_cols())

    counts = dataset.aggregate_entries(hl.Struct(
        n_het=agg.filter(dataset.GT.is_het(), agg.count()),
        n_hom_ref=agg.filter(dataset.GT.is_hom_ref(), agg.count()),
        n_hom_var=agg.filter(dataset.GT.is_hom_var(), agg.count()),
        nNoCall=agg.filter(hl.is_missing(dataset.GT), agg.count())))

    # Diagonal cells match the per-genotype-class counts.
    self.assertEqual(glob_conc[0][0], 0)
    self.assertEqual(glob_conc[1][1], counts.nNoCall)
    self.assertEqual(glob_conc[2][2], counts.n_hom_ref)
    self.assertEqual(glob_conc[3][3], counts.n_het)
    self.assertEqual(glob_conc[4][4], counts.n_hom_var)

    # Fix: was a list comprehension executed purely for its assertion side
    # effects (built and discarded a list); a plain loop states the intent.
    for i in range(5):
        for j in range(5):
            if i != j:
                self.assertEqual(glob_conc[i][j], 0)

    self.assertTrue(cols_conc.all(
        hl.sum(hl.flatten(cols_conc.concordance)) == dataset.count_rows()))
    self.assertTrue(rows_conc.all(
        hl.sum(hl.flatten(rows_conc.concordance)) == dataset.count_cols()))

    # Smoke-test that both result tables are writable.
    cols_conc.write('/tmp/foo.kt', overwrite=True)
    rows_conc.write('/tmp/foo.kt', overwrite=True)
def test_annotate(self):
    """annotate_globals/rows/cols and select_entries yield expected schemas."""
    vds = self.get_vds()
    vds = vds.annotate_globals(foo=5)
    self.assertEqual(vds.globals.dtype, hl.tstruct(foo=hl.tint32))

    vds = vds.annotate_rows(x1=agg.count(),
                            x2=agg.fraction(False),
                            x3=agg.count_where(True),
                            x4=vds.info.AC + vds.foo)

    vds = vds.annotate_cols(apple=6)
    vds = vds.annotate_cols(y1=agg.count(),
                            y2=agg.fraction(False),
                            y3=agg.count_where(True),
                            y4=vds.foo + vds.apple)

    expected_schema = hl.tstruct(s=hl.tstr, apple=hl.tint32, y1=hl.tint64,
                                 y2=hl.tfloat64, y3=hl.tint64, y4=hl.tint32)

    # Fix: the failure message had expected/actual swapped — it printed the
    # actual dtype under the "expected" label and vice versa.
    self.assertTrue(schema_eq(vds.col.dtype, expected_schema),
                    "expected: " + str(expected_schema) +
                    "\nactual: " + str(vds.col.dtype))

    vds = vds.select_entries(z1=vds.x1 + vds.foo,
                             z2=vds.x1 + vds.y1 + vds.foo)
    self.assertTrue(schema_eq(vds.entry.dtype,
                              hl.tstruct(z1=hl.tint64, z2=hl.tint64)))
def test_filter(self):
    """filter_rows/cols/entries may reference annotations, aggregators, and globals."""
    ds = self.get_vds()
    ds = ds.annotate_globals(foo=5)
    ds = ds.annotate_rows(x1=agg.count())
    ds = ds.annotate_cols(y1=agg.count())
    ds = ds.annotate_entries(z1=ds.DP)

    # Predicates deliberately mix field references, aggregation, and globals;
    # the test only checks that the pipeline executes (via count_rows).
    ds = ds.filter_rows((ds.x1 == 5) & (agg.count() == 3) & (ds.foo == 2))
    ds = ds.filter_cols((ds.y1 == 5) & (agg.count() == 3) & (ds.foo == 2))
    ds = ds.filter_entries((ds.z1 < 5) & (ds.y1 == 3) & (ds.x1 == 5) & (ds.foo == 2))
    ds.count_rows()
def test_filter(self):
    """Exercise row, column, and entry filters whose predicates combine
    annotated fields, inline aggregators, and global fields.

    NOTE(review): this is byte-identical to a test_filter defined earlier in
    this file; when both live in one class, only the last definition runs.
    """
    vds = (self.get_vds()
           .annotate_globals(foo=5)
           .annotate_rows(x1=agg.count())
           .annotate_cols(y1=agg.count()))
    vds = vds.annotate_entries(z1=vds.DP)

    vds = vds.filter_rows((vds.x1 == 5) & (agg.count() == 3) & (vds.foo == 2))
    vds = vds.filter_cols((vds.y1 == 5) & (agg.count() == 3) & (vds.foo == 2))
    vds = vds.filter_entries(
        (vds.z1 < 5) & (vds.y1 == 3) & (vds.x1 == 5) & (vds.foo == 2))

    vds.count_rows()  # force execution of the lazy pipeline
def get_cnt_matrix(mnv_table, region="ALL", dist=1, minimum_cnt=0, PASS=True, part_size=1000, hom=False):
    """Count MNV (multi-nucleotide variant) occurrences per allele combination.

    Indels are intentionally ignored throughout: only SNP pairs are counted.

    :param mnv_table: Hail table of MNVs, one row per variant pair; expects
        fields locus/alleles plus prev_locus/prev_alleles/filters/prev_filters
        (and n_homhom when ``hom`` is used).
    :param region: path to a BED file restricting to regions of interest
        (e.g. enhancer regions), or "ALL" for no restriction.
    :param dist: distance in bp between the two SNPs of the pair.
    :param minimum_cnt: drop allele combinations seen no more than this many times.
    :param PASS: True -> keep only pairs where both variants are PASS;
        the string "NO" -> keep only pairs with at least one non-PASS variant;
        any other value -> no filter-status filtering.
    :param part_size: number of partitions to repartition to before grouping.
    :param hom: if True, restrict to MNVs observed at least once in
        homozygous-homozygous form (n_homhom > 0).
    :return: table keyed by (alleles, prev_alleles) with fields refs, alts, cnt.
    """
    # Region restriction via BED file, if one was supplied.
    if region != "ALL":
        bed = hl.import_bed(region)
        mnv_table = mnv_table.filter(hl.is_defined(bed[mnv_table.locus]))
    if PASS=="NO":  # exclusively pairs with at least one non-PASS variant
        mnv_table = mnv_table.filter((mnv_table.filters.length() > 0) | (mnv_table.prev_filters.length() > 0))
    elif PASS==True:  # both variants must be PASS (empty filter sets)
        mnv_table = mnv_table.filter((mnv_table.filters.length() == 0) & (mnv_table.prev_filters.length() == 0))
    if hom:
        mnv_table = mnv_table.filter(mnv_table.n_homhom>0)
    # Restrict to SNP pairs (single-base ref and alt on both variants)
    # at exactly the requested distance.
    mnv = mnv_table.filter((mnv_table.alleles[0].length() == 1) &
                           (mnv_table.alleles[1].length() == 1) &
                           (mnv_table.prev_alleles[0].length() == 1) &
                           (mnv_table.prev_alleles[1].length() == 1) &
                           ((mnv_table.locus.position - mnv_table.prev_locus.position) == dist))
    # Repartition to a manageable size before the shuffle-heavy group-by.
    mnv = mnv.repartition(part_size)
    mnv_cnt = mnv.group_by("alleles", "prev_alleles").aggregate(cnt=agg.count())  # occurrence count
    # Combined reference / alternate strings, padding the gap between the
    # two SNPs with "N"s (dist - 1 of them).
    mnv_cnt = mnv_cnt.annotate(
        refs=mnv_cnt.prev_alleles[0] + "N" * (dist - 1) + mnv_cnt.alleles[0])
    mnv_cnt = mnv_cnt.annotate(
        alts=mnv_cnt.prev_alleles[1] + "N" * (dist - 1) + mnv_cnt.alleles[1])
    if minimum_cnt > 0:
        mnv_cnt = mnv_cnt.filter((mnv_cnt.cnt > minimum_cnt))  # drop trivially rare combinations
    return (mnv_cnt.select("refs", "alts", "cnt"))
def get_cnt_matrix_alldist(mnv_table, region="ALL", dist_min=1, dist_max=10, minimum_cnt=0, PASS=True, part_size=1000):
    """Build per-distance MNV count matrices over a range of SNP-pair distances.

    Like get_cnt_matrix, but iterates distances [dist_min, dist_max] and
    collects one pandas DataFrame per distance.

    :param mnv_table: Hail table of MNVs; expects locus/alleles plus
        prev_locus/prev_alleles/filters/prev_filters fields.
    :param region: path to a BED file restricting regions, or "ALL" for none.
    :param dist_min: smallest SNP-pair distance in bp, inclusive.
    :param dist_max: largest SNP-pair distance in bp, inclusive.
    :param minimum_cnt: drop allele combinations seen no more than this many times.
    :param PASS: if truthy, keep only pairs where both variants are PASS.
    :param part_size: partitions to repartition to before each group-by.
    :return: dict mapping distance -> pandas DataFrame of (refs, alts, cnt).
    """
    # Region restriction via BED file, if one was supplied.
    if region != "ALL":
        bed = hl.import_bed(region, skip_invalid_intervals=True)
        mnv_table = mnv_table.filter(hl.is_defined(bed[mnv_table.locus]))
    if PASS:
        mnv_table = mnv_table.filter((mnv_table.filters.length() == 0) & (mnv_table.prev_filters.length() == 0))
    # Restrict to SNP pairs (single-base ref and alt on both variants).
    mnv_table = mnv_table.filter((mnv_table.alleles[0].length() == 1) &
                                 (mnv_table.alleles[1].length() == 1) &
                                 (mnv_table.prev_alleles[0].length() == 1) &
                                 (mnv_table.prev_alleles[1].length() == 1)
                                 )
    pdall = {}
    for dist in range(dist_min, (dist_max+1)):
        # Pairs at exactly this distance.
        mnv = mnv_table.filter((mnv_table.locus.position - mnv_table.prev_locus.position) == dist)
        # Repartition to a manageable size before the shuffle-heavy group-by.
        mnv = mnv.repartition(part_size)
        mnv_cnt = mnv.group_by("alleles", "prev_alleles").aggregate(cnt=agg.count())  # occurrence count
        # Combined reference / alternate strings, padding the gap with "N"s.
        mnv_cnt = mnv_cnt.annotate(
            refs=mnv_cnt.prev_alleles[0] + "N" * (dist - 1) + mnv_cnt.alleles[0])
        mnv_cnt = mnv_cnt.annotate(
            alts=mnv_cnt.prev_alleles[1] + "N" * (dist - 1) + mnv_cnt.alleles[1])
        if minimum_cnt > 0:
            mnv_cnt = mnv_cnt.filter((mnv_cnt.cnt > minimum_cnt))  # drop trivially rare combinations
        # ht_cnt_mat_to_pd (defined elsewhere) presumably converts the Hail
        # table to a pandas DataFrame -- confirm at its definition site.
        pdall[dist] = ht_cnt_mat_to_pd(mnv_cnt.select("refs", "alts", "cnt"))
        print ("done d={0}".format(dist))
        # NOTE(review): tm looks like an alias for the stdlib time module --
        # confirm at the import site.
        print(tm.ctime())
    return (pdall)  # dict of DataFrames keyed by distance
def test_select_cols(self):
    """Column annotations can combine entry aggregation, col fields, and globals."""
    ds = hl.utils.range_matrix_table(3, 5, n_partitions=4)
    ds = ds.annotate_globals(g=1)
    ds = ds.annotate_entries(e=ds.col_idx * ds.row_idx)
    ds = ds.annotate_cols(
        sum=agg.sum(ds.e + ds.col_idx + ds.row_idx + ds.g) + ds.col_idx + ds.g,
        count=agg.count_where(ds.e % 2 == 0),
        foo=agg.count())

    expected = {'col_idx': 3, 'sum': 28, 'count': 2, 'foo': 3}
    observed = convert_struct_to_dict(ds.cols().collect()[-2])
    self.assertEqual(observed, expected)
def test_select_cols(self):
    """Column annotations combining entry aggregation, col fields, and globals."""
    # NOTE(review): byte-identical to a test_select_cols defined earlier in
    # this file; when both live in one class, only the last definition runs.
    mt = hl.utils.range_matrix_table(3, 5, n_partitions=4)
    mt = mt.annotate_entries(e=mt.col_idx * mt.row_idx)
    mt = mt.annotate_globals(g=1)
    # sum mixes an entry aggregation with per-column and global terms.
    mt = mt.annotate_cols(sum=agg.sum(mt.e + mt.col_idx + mt.row_idx + mt.g) + mt.col_idx + mt.g,
                          count=agg.count_where(mt.e % 2 == 0),
                          foo=agg.count())
    # Check the second-to-last column (col_idx == 3) against hand-computed values.
    result = convert_struct_to_dict(mt.cols().collect()[-2])
    self.assertEqual(result, {'col_idx': 3, 'sum': 28, 'count': 2, 'foo': 3})
def test_concordance(self):
    """Self-concordance: all counts land on the diagonal of the global matrix."""
    dataset = get_dataset()
    glob_conc, cols_conc, rows_conc = hl.concordance(dataset, dataset)

    # Total concordance counts must equal the number of entries.
    self.assertEqual(sum([sum(glob_conc[i]) for i in range(5)]),
                     dataset.count_rows() * dataset.count_cols())

    counts = dataset.aggregate_entries(
        hl.Struct(n_het=agg.count(
            agg.filter(dataset.GT.is_het(), dataset.GT)),
                  n_hom_ref=agg.count(
                      agg.filter(dataset.GT.is_hom_ref(), dataset.GT)),
                  n_hom_var=agg.count(
                      agg.filter(dataset.GT.is_hom_var(), dataset.GT)),
                  nNoCall=agg.count(
                      agg.filter(hl.is_missing(dataset.GT), dataset.GT))))

    # Diagonal cells match the per-genotype-class counts.
    self.assertEqual(glob_conc[0][0], 0)
    self.assertEqual(glob_conc[1][1], counts.nNoCall)
    self.assertEqual(glob_conc[2][2], counts.n_hom_ref)
    self.assertEqual(glob_conc[3][3], counts.n_het)
    self.assertEqual(glob_conc[4][4], counts.n_hom_var)

    # Fix: was a list comprehension executed purely for its assertion side
    # effects (built and discarded a list); a plain loop states the intent.
    for i in range(5):
        for j in range(5):
            if i != j:
                self.assertEqual(glob_conc[i][j], 0)

    self.assertTrue(
        cols_conc.all(
            hl.sum(hl.flatten(cols_conc.concordance)) == dataset.count_rows()))
    self.assertTrue(
        rows_conc.all(
            hl.sum(hl.flatten(rows_conc.concordance)) == dataset.count_cols()))

    # Smoke-test that both result tables are writable.
    cols_conc.write('/tmp/foo.kt', overwrite=True)
    rows_conc.write('/tmp/foo.kt', overwrite=True)
def test_query(self):
    """Row/col/entry aggregate queries over the annotated example dataset."""
    ds = self.get_vds()
    ds = ds.annotate_globals(foo=5)
    ds = ds.annotate_rows(x1=agg.count())
    ds = ds.annotate_cols(y1=agg.count())
    ds = ds.annotate_entries(z1=ds.DP)

    row_count = ds.aggregate_rows(agg.count())
    col_count = ds.aggregate_cols(agg.count())
    entry_count = ds.aggregate_entries(agg.count())

    self.assertEqual(row_count, 346)
    self.assertEqual(col_count, 100)
    self.assertEqual(entry_count, row_count * col_count)

    # Struct-valued aggregations; results are not asserted, the queries
    # only need to execute.
    row_structs = ds.aggregate_rows(
        hl.Struct(x=agg.collect(ds.locus.contig), y=agg.collect(ds.x1)))
    col_structs = ds.aggregate_cols(
        hl.Struct(x=agg.collect(ds.s), y=agg.collect(ds.y1)))
    entry_structs = ds.aggregate_entries(
        hl.Struct(x=agg.collect(agg.filter(False, ds.y1)),
                  y=agg.collect(agg.filter(hl.rand_bool(0.1), ds.GT))))
def test_aggregate2(self):
    """group_by + aggregate exercising a wide spread of aggregators against
    hand-computed expected values on a two-row table."""
    schema = hl.tstruct(status=hl.tint32, GT=hl.tcall, qPheno=hl.tint32)
    rows = [{'status': 0, 'GT': hl.Call([0, 0]), 'qPheno': 3},
            {'status': 0, 'GT': hl.Call([0, 1]), 'qPheno': 13}]
    kt = hl.Table.parallelize(rows, schema)

    result = convert_struct_to_dict(
        kt.group_by(status=kt.status)
        .aggregate(
            x1=agg.collect(kt.qPheno * 2),
            x2=agg.explode(lambda elt: agg.collect(elt), [kt.qPheno, kt.qPheno + 1]),
            x3=agg.min(kt.qPheno),
            x4=agg.max(kt.qPheno),
            x5=agg.sum(kt.qPheno),
            x6=agg.product(hl.int64(kt.qPheno)),
            x7=agg.count(),
            x8=agg.count_where(kt.qPheno == 3),
            x9=agg.fraction(kt.qPheno == 1),
            x10=agg.stats(hl.float64(kt.qPheno)),
            x11=agg.hardy_weinberg_test(kt.GT),
            x13=agg.inbreeding(kt.GT, 0.1),
            x14=agg.call_stats(kt.GT, ["A", "T"]),
            x15=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')))[0],
            x16=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')).c.banana)[0],
            # exploding a missing collection must aggregate to an empty result
            x17=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tarray(hl.tint32))),
            x18=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tset(hl.tint32))),
            x19=agg.take(kt.GT, 1, ordering=-kt.qPheno)
        ).take(1)[0])

    expected = {u'status': 0,
                u'x13': {u'n_called': 2, u'expected_homs': 1.64, u'f_stat': -1.777777777777777, u'observed_homs': 1},
                u'x14': {u'AC': [3, 1], u'AF': [0.75, 0.25], u'AN': 4, u'homozygote_count': [1, 0]},
                u'x15': {u'a': 5, u'c': {u'banana': u'apple'}, u'b': u'foo'},
                u'x10': {u'min': 3.0, u'max': 13.0, u'sum': 16.0, u'stdev': 5.0, u'n': 2, u'mean': 8.0},
                u'x8': 1, u'x9': 0.0, u'x16': u'apple',
                u'x11': {u'het_freq_hwe': 0.5, u'p_value': 0.5},
                u'x2': [3, 4, 13, 14], u'x3': 3, u'x1': [6, 26],
                u'x6': 39, u'x7': 2, u'x4': 13, u'x5': 16,
                u'x17': [], u'x18': [],
                u'x19': [hl.Call([0, 1])]}
    self.maxDiff = None
    self.assertDictEqual(result, expected)
def test_aggregate2(self):
    """group_by + aggregate exercising a wide spread of aggregators."""
    # NOTE(review): byte-identical to a test_aggregate2 defined earlier in
    # this file; when both live in one class, only the last definition runs.
    schema = hl.tstruct(status=hl.tint32, GT=hl.tcall, qPheno=hl.tint32)
    rows = [{'status': 0, 'GT': hl.Call([0, 0]), 'qPheno': 3},
            {'status': 0, 'GT': hl.Call([0, 1]), 'qPheno': 13}]
    kt = hl.Table.parallelize(rows, schema)

    result = convert_struct_to_dict(
        kt.group_by(status=kt.status)
        .aggregate(
            x1=agg.collect(kt.qPheno * 2),
            x2=agg.explode(lambda elt: agg.collect(elt), [kt.qPheno, kt.qPheno + 1]),
            x3=agg.min(kt.qPheno),
            x4=agg.max(kt.qPheno),
            x5=agg.sum(kt.qPheno),
            x6=agg.product(hl.int64(kt.qPheno)),
            x7=agg.count(),
            x8=agg.count_where(kt.qPheno == 3),
            x9=agg.fraction(kt.qPheno == 1),
            x10=agg.stats(hl.float64(kt.qPheno)),
            x11=agg.hardy_weinberg_test(kt.GT),
            x13=agg.inbreeding(kt.GT, 0.1),
            x14=agg.call_stats(kt.GT, ["A", "T"]),
            x15=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')))[0],
            x16=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')).c.banana)[0],
            # exploding a missing collection must aggregate to an empty result
            x17=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tarray(hl.tint32))),
            x18=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tset(hl.tint32))),
            x19=agg.take(kt.GT, 1, ordering=-kt.qPheno)
        ).take(1)[0])

    expected = {u'status': 0,
                u'x13': {u'n_called': 2, u'expected_homs': 1.64, u'f_stat': -1.777777777777777, u'observed_homs': 1},
                u'x14': {u'AC': [3, 1], u'AF': [0.75, 0.25], u'AN': 4, u'homozygote_count': [1, 0]},
                u'x15': {u'a': 5, u'c': {u'banana': u'apple'}, u'b': u'foo'},
                u'x10': {u'min': 3.0, u'max': 13.0, u'sum': 16.0, u'stdev': 5.0, u'n': 2, u'mean': 8.0},
                u'x8': 1, u'x9': 0.0, u'x16': u'apple',
                u'x11': {u'het_freq_hwe': 0.5, u'p_value': 0.5},
                u'x2': [3, 4, 13, 14], u'x3': 3, u'x1': [6, 26],
                u'x6': 39, u'x7': 2, u'x4': 13, u'x5': 16,
                u'x17': [], u'x18': [],
                u'x19': [hl.Call([0, 1])]}
    self.maxDiff = None
    self.assertDictEqual(result, expected)
def test_weird_names(self):
    """Table fields named with spaces, punctuation, and non-ASCII characters
    must survive annotate/select/transmute/explode/group/drop.

    Fix: the later field lookups used non-raw backslash string literals,
    which depend on the deprecated invalid-escape fallback (SyntaxWarning
    since Python 3.12, slated to become an error); they are now raw
    strings with identical runtime values.
    """
    df = hl.utils.range_table(10)
    exprs = {'a': 5, ' a ': 5, r'\%!^!@#&#&$%#$%': [5], '$': 5, 'ß': 5}
    df.annotate_globals(**exprs)
    df.select_globals(**exprs)
    df.annotate(**exprs)
    df.select(**exprs)
    df = df.transmute(**exprs)
    df.explode(r'\%!^!@#&#&$%#$%')
    df.explode(df[r'\%!^!@#&#&$%#$%'])
    df.drop(r'\%!^!@#&#&$%#$%')
    df.drop(df[r'\%!^!@#&#&$%#$%'])
    df.group_by(**{'*``81': df.a}).aggregate(c=agg.count())
def test_query(self):
    """Table.aggregate with sum/count/collect and a filtered collect."""
    schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32,
                        e=hl.tstr, f=hl.tarray(hl.tint32))
    rows = [
        {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]},
        {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []},
        {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]},
    ]
    t = hl.Table.parallelize(rows, schema)

    res = t.aggregate(hl.Struct(
        q1=agg.sum(t.b),
        q2=agg.count(),
        q3=agg.collect(t.e),
        q4=agg.collect(agg.filter((t.d >= 5) | (t.a == 0), t.e))))

    self.assertEqual(res.q1, 8)
    self.assertEqual(res.q2, 3)
    # Collection order is not guaranteed, so compare as sets.
    self.assertEqual(set(res.q3), {"hello", "cat", "dog"})
    self.assertEqual(set(res.q4), {"hello", "cat"})
def test_weird_names(self):
    """Table fields named with spaces and punctuation must survive
    annotate/select/transmute/explode/group/drop.

    Fix: the later field lookups used non-raw backslash string literals,
    which depend on the deprecated invalid-escape fallback (SyntaxWarning
    since Python 3.12, slated to become an error); they are now raw
    strings with identical runtime values.
    """
    df = hl.utils.range_table(10)
    exprs = {'a': 5, ' a ': 5, r'\%!^!@#&#&$%#$%': [5]}
    df.annotate_globals(**exprs)
    df.select_globals(**exprs)
    df.annotate(**exprs)
    df.select(**exprs)
    df = df.transmute(**exprs)
    df.explode(r'\%!^!@#&#&$%#$%')
    df.explode(df[r'\%!^!@#&#&$%#$%'])
    df.drop(r'\%!^!@#&#&$%#$%')
    df.drop(df[r'\%!^!@#&#&$%#$%'])
    df.group_by(**{'*``81': df.a}).aggregate(c=agg.count())
def test_aggregate1(self):
    """Table.aggregate: sum, count, collect, filtered collect, explode+mean."""
    schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32,
                        e=hl.tstr, f=hl.tarray(hl.tint32))
    rows = [
        {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]},
        {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []},
        {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]},
    ]
    table = hl.Table.parallelize(rows, schema)

    res = table.aggregate(hl.Struct(
        q1=agg.sum(table.b),
        q2=agg.count(),
        q3=agg.collect(table.e),
        q4=agg.filter((table.d >= 5) | (table.a == 0), agg.collect(table.e)),
        q5=agg.explode(lambda elt: agg.mean(elt), table.f)))

    self.assertEqual(res.q1, 8)
    self.assertEqual(res.q2, 3)
    # Collection order is not guaranteed, so compare as sets.
    self.assertEqual(set(res.q3), {"hello", "cat", "dog"})
    self.assertEqual(set(res.q4), {"hello", "cat"})
    self.assertAlmostEqual(res.q5, 4)
def summary_ccr(ht_ccr: hl.Table,
                file_output: str,
                ccr_pct_start: int = 0,
                ccr_pct_end: int = 100,
                ccr_pct_bins: int = 10,
                cumulative_histogram: bool = False,
                ccr_pct_cutoffs=None) -> None:
    """
    Summarize Coding Constraint Region information (as histogram) per gene.

    :param ht_ccr: CCR Hail table (expects fields ``gene`` and ``ccr_pct``)
    :param file_output: File output path
    :param ccr_pct_start: Start of histogram range.
    :param ccr_pct_end: End of histogram range
    :param ccr_pct_bins: Number of bins
    :param cumulative_histogram: Generate a cumulative histogram (rather than to use bins)
    :param ccr_pct_cutoffs: Cut-offs used to generate the cumulative histogram
    :return: None
    """
    if ccr_pct_cutoffs is None:
        ccr_pct_cutoffs = [90, 95, 99]

    if cumulative_histogram:
        # One count per cutoff: genes' CCR percentiles at or above the cutoff.
        summary_tb = (ht_ccr
                      .group_by('gene')
                      .aggregate(**{'ccr_above_' + str(ccr_pct_cutoffs[k]):
                                    agg.filter(ht_ccr.ccr_pct >= ccr_pct_cutoffs[k], agg.count())
                                    for k in range(0, len(ccr_pct_cutoffs))})
                      )
    else:
        summary_tb = (ht_ccr
                      .group_by('gene')
                      .aggregate(ccr_bins=agg.hist(ht_ccr.ccr_pct, ccr_pct_start,
                                                   ccr_pct_end, ccr_pct_bins))
                      )

        # BUG FIX: this unpacking block previously ran unconditionally, but the
        # 'ccr_bins' field only exists in the non-cumulative branch — in
        # cumulative mode the function crashed on a missing field. It now runs
        # only when the binned histogram was computed.

        # get bin edges as list (expected n_bins + 1)
        bin_edges = summary_tb.aggregate(agg.take(summary_tb.ccr_bins.bin_edges, 1))[0]

        # unpack array structure and annotate as individual fields
        summary_tb = (summary_tb
                      .annotate(**{'ccr_bin_' + str(bin_edges[k]) + '_' + str(bin_edges[k + 1]):
                                   summary_tb.ccr_bins.bin_freq[k]
                                   for k in range(0, len(bin_edges) - 1)})
                      .flatten()
                      )

        # drop the now-redundant packed histogram fields
        fields_to_drop = ['ccr_bins.bin_edges', 'ccr_bins.bin_freq']
        summary_tb = summary_tb.drop(*fields_to_drop)

    # Export summarized table
    summary_tb.export(output=file_output)
else: et_all = et_all.join(et_union, how="outer") print ("done chr {0}".format(chr)) et_all.select("s").write("gs://gnomad-qingbowang/MNV/tmp_exome_snp_pairs.ht", overwrite=True) final = hl.import_table("gs://gnomad-public/release/2.1/mnv/gnomad_mnv_coding.tsv", types={'n_indv_ex': hl.tint32}) final = final.filter(final.n_indv_ex>0) #et_all.show() et_all= hl.read_table("gs://gnomad-qingbowang/MNV/tmp_exome_snp_pairs.ht") final = final.key_by() final = final.key_by("snp1","snp2") et_all = et_all.key_by() et_all = et_all.key_by("snp1","snp2") et_all = et_all.annotate(categ = final[et_all.key].categ) cnt = et_all.group_by("s","categ").aggregate(n = agg.count()) sums = cnt.group_by("categ").aggregate(sums=hl.agg.sum(cnt.n)) sums.show() #output: """ | "Changed missense" | 5499378 | | "Gained PTV" | 8079 | | "Gained missense" | 202 | | "Lost missense" | 10434 | | "Partially changed missense" | 843375 | | "Rescued PTV" | 556142 | | "Rescued stop loss" | 1 | | "Unchanged" | 4953574 | | "gained_stop_loss" | 22762 | | NA | 57642684 |