def test_concordance(self):
    dataset = get_dataset()
    glob_conc, cols_conc, rows_conc = hl.concordance(dataset, dataset)

    self.assertEqual(sum([sum(glob_conc[i]) for i in range(5)]),
                     dataset.count_rows() * dataset.count_cols())

    counts = dataset.aggregate_entries(
        hl.Struct(n_het=agg.filter(dataset.GT.is_het(), agg.count()),
                  n_hom_ref=agg.filter(dataset.GT.is_hom_ref(), agg.count()),
                  n_hom_var=agg.filter(dataset.GT.is_hom_var(), agg.count()),
                  nNoCall=agg.filter(hl.is_missing(dataset.GT), agg.count())))

    self.assertEqual(glob_conc[0][0], 0)
    self.assertEqual(glob_conc[1][1], counts.nNoCall)
    self.assertEqual(glob_conc[2][2], counts.n_hom_ref)
    self.assertEqual(glob_conc[3][3], counts.n_het)
    self.assertEqual(glob_conc[4][4], counts.n_hom_var)

    # Off-diagonal entries must be zero when a dataset is compared with itself.
    for i in range(5):
        for j in range(5):
            if i != j:
                self.assertEqual(glob_conc[i][j], 0)

    self.assertTrue(cols_conc.all(hl.sum(hl.flatten(cols_conc.concordance)) == dataset.count_rows()))
    self.assertTrue(rows_conc.all(hl.sum(hl.flatten(rows_conc.concordance)) == dataset.count_cols()))

    cols_conc.write('/tmp/foo.kt', overwrite=True)
    rows_conc.write('/tmp/foo.kt', overwrite=True)
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        filters=hl.set(hl.flatten(ts.data.map(lambda d: hl.array(d.filters)))),
        info=hl.struct(
            DP=hl.sum(ts.data.map(lambda d: d.info.DP)),
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i: hl.cond(
                    hl.is_missing(tmp.data[i].__entries),
                    hl.range(0, hl.len(tmp.g[i].__cols))
                      .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                    hl.bind(
                        lambda old_to_new: tmp.data[i].__entries.map(
                            lambda e: renumber_entry(e, old_to_new)),
                        hl.range(0, hl.len(tmp.data[i].alleles)).map(
                            lambda j: combined_allele_index[tmp.data[i].alleles[j]])))),
            hl.dict(hl.range(0, hl.len(tmp.alleles)).map(
                lambda j: hl.tuple([tmp.alleles[j], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))
    return tmp.drop('data', 'g')
def merge_alleles(alleles):
    from hail.expr.functions import _num_allele_type, _allele_ints
    return hl.rbind(
        # The combined reference allele is the longest of the input refs.
        alleles.map(lambda a: hl.or_else(a[0], ''))
               .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
        lambda ref: hl.rbind(
            alleles.map(
                lambda al: hl.rbind(
                    al[0],
                    lambda r: hl.array([ref]).extend(
                        al[1:].map(
                            lambda a: hl.rbind(
                                _num_allele_type(r, a),
                                lambda at: hl.cond(
                                    (_allele_ints['SNP'] == at)
                                    | (_allele_ints['Insertion'] == at)
                                    | (_allele_ints['Deletion'] == at)
                                    | (_allele_ints['MNP'] == at)
                                    | (_allele_ints['Complex'] == at),
                                    # Right-pad the alt with the tail of the
                                    # combined ref so it is expressed against it.
                                    a + ref[hl.len(r):],
                                    a)))))),
            lambda lal: hl.struct(
                globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                local=lal)))
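# A hedged sketch of what the merge_alleles above computes, using hypothetical
# inputs: one gVCF record calls a SNP A>T, the other a deletion AT>A at the
# same locus. The ordering of the non-ref tail of `globl` comes from hl.set
# and is not otherwise specified.
import hail as hl

alleles = hl.literal([['A', 'T'], ['AT', 'A']])
result = hl.eval(merge_alleles(alleles))
# The longest ref ('AT') becomes the combined ref, and the SNP alt 'T' is
# right-padded to 'TT' so both records are expressed against that ref:
#   result.local == [['AT', 'TT'], ['AT', 'A']]
#   result.globl == ['AT', ...] with 'TT' and 'A' following in set order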
def sor_from_sb(
    sb: Union[hl.expr.ArrayNumericExpression, hl.expr.ArrayExpression]
) -> hl.expr.Float64Expression:
    """
    Computes `SOR` (Symmetric Odds Ratio test) annotation from the `SB` (strand balance table) field.

    .. note::

        This function can either take
        - an array of length four containing the forward and reverse strands' counts of ref and alt alleles: [ref fwd, ref rev, alt fwd, alt rev]
        - a two dimensional array with arrays of length two, containing the counts: [[ref fwd, ref rev], [alt fwd, alt rev]]

    GATK code here: https://github.com/broadinstitute/gatk/blob/master/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/StrandOddsRatio.java

    :param sb: Count of ref/alt reads on each strand
    :return: SOR value
    """
    if not isinstance(sb, hl.expr.ArrayNumericExpression):
        sb = hl.bind(lambda x: hl.flatten(x), sb)

    sb = sb.map(lambda x: hl.float64(x) + 1)

    ref_fw = sb[0]
    ref_rv = sb[1]
    alt_fw = sb[2]
    alt_rv = sb[3]
    symmetrical_ratio = ((ref_fw * alt_rv) / (alt_fw * ref_rv)) + (
        (alt_fw * ref_rv) / (ref_fw * alt_rv)
    )
    ref_ratio = hl.min(ref_rv, ref_fw) / hl.max(ref_rv, ref_fw)
    alt_ratio = hl.min(alt_fw, alt_rv) / hl.max(alt_fw, alt_rv)
    sor = hl.log(symmetrical_ratio) + hl.log(ref_ratio) - hl.log(alt_ratio)

    return sor
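# A hedged usage sketch: the dataset path and the `info.SB` field are
# placeholders; sor_from_sb itself only assumes a 4-element (or 2x2)
# strand-balance count array.
import hail as hl

mt = hl.read_matrix_table('gs://my-bucket/my-dataset.mt')  # placeholder path
# Annotate each row with SOR, assuming a GATK-style `info.SB` array
# [ref fwd, ref rev, alt fwd, alt rev] on the rows.
mt = mt.annotate_rows(SOR=sor_from_sb(mt.info.SB))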
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        info=hl.struct(
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB_TABLE=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i: hl.cond(
                    hl.is_missing(tmp.data[i].__entries),
                    hl.range(0, hl.len(tmp.g[i].__cols))
                      .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                    hl.bind(
                        lambda old_to_new: tmp.data[i].__entries.map(
                            lambda e: renumber_entry(e, old_to_new)),
                        hl.array([0]).extend(
                            hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                lambda j: combined_allele_index[tmp.data[i].alleles[j]]))))),
            hl.dict(hl.range(1, hl.len(tmp.alleles) + 1).map(
                lambda j: hl.tuple([tmp.alleles[j - 1], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))
    return tmp.drop('data', 'g')
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref: hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r: hl.array([ref]).extend(
                            al[1:].map(
                                lambda a: hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at: hl.cond(
                                        (_allele_ints['SNP'] == at)
                                        | (_allele_ints['Insertion'] == at)
                                        | (_allele_ints['Deletion'] == at)
                                        | (_allele_ints['MNP'] == at)
                                        | (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal: hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl: hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles: hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i: hl.cond(
                                hl.is_missing(row.data[i].__entries),
                                hl.range(0, hl.len(gbl.g[i].__cols))
                                  .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                hl.bind(
                                    lambda old_to_new: row.data[i].__entries.map(
                                        lambda e: renumber_entry(e, old_to_new)),
                                    hl.range(0, hl.len(alleles.local[i])).map(
                                        lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(
        ts._tir,
        Apply(merge_function._name,
              merge_function._ret_type,
              TopLevelReference('row'),
              TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
def multi_way_union_mts(mts: list, tmp_dir: str, chunk_size: int) -> hl.MatrixTable:
    """Joins MatrixTables in the provided list

    :param list mts: list of MatrixTables to join together
    :param str tmp_dir: path to temporary directory for intermediate results
    :param int chunk_size: number of MatrixTables to join per chunk

    :return: joined MatrixTable
    :rtype: MatrixTable
    """
    staging = [mt.localize_entries("__entries", "__cols") for mt in mts]
    stage = 0
    while len(staging) > 1:
        n_jobs = int(math.ceil(len(staging) / chunk_size))
        info(f"multi_way_union_mts: stage {stage}: {n_jobs} total jobs")
        next_stage = []
        for i in range(n_jobs):
            to_merge = staging[chunk_size * i:chunk_size * (i + 1)]
            info(f"multi_way_union_mts: stage {stage} / job {i}: merging {len(to_merge)} inputs")
            merged = hl.Table.multi_way_zip_join(to_merge, "__entries", "__cols")
            # Rows missing from an input are filled with a row of missing
            # entries, one per column of that input.
            merged = merged.annotate(__entries=hl.flatten(
                hl.range(hl.len(merged.__entries)).map(lambda i: hl.coalesce(
                    merged.__entries[i].__entries,
                    hl.range(hl.len(merged.__cols[i].__cols)).map(
                        lambda j: hl.null(merged.__entries.__entries.dtype.element_type.element_type)),
                ))))
            merged = merged.annotate_globals(
                __cols=hl.flatten(merged.__cols.map(lambda x: x.__cols)))
            next_stage.append(
                merged.checkpoint(os.path.join(tmp_dir, f"stage_{stage}_job_{i}.ht"),
                                  overwrite=True))
        info(f"done stage {stage}")
        stage += 1
        staging.clear()
        staging.extend(next_stage)

    return (staging[0]._unlocalize_entries(
        "__entries", "__cols", list(mts[0].col_key)).unfilter_entries())
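# A hedged usage sketch: the bucket paths and the number of inputs are
# placeholders.
import hail as hl

paths = ['gs://my-bucket/s1.mt', 'gs://my-bucket/s2.mt', 'gs://my-bucket/s3.mt']
mts = [hl.read_matrix_table(p) for p in paths]
# Union three MatrixTables two at a time; intermediates are checkpointed
# under tmp_dir at each stage.
combined = multi_way_union_mts(mts, tmp_dir='gs://my-bucket/tmp', chunk_size=2)
combined.write('gs://my-bucket/combined.mt', overwrite=True)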
def combine_r(ts):
    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl: hl.struct(
                locus=row.locus,
                ref_allele=hl.find(hl.is_defined, row.data.map(lambda d: d.ref_allele)),
                __entries=hl.range(0, hl.len(row.data)).flatmap(
                    lambda i: hl.if_else(
                        hl.is_missing(row.data[i]),
                        hl.range(0, hl.len(gbl.g[i].__cols))
                          .map(lambda _: hl.missing(row.data[i].__entries.dtype.element_type)),
                        row.data[i].__entries))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(
        ts._tir,
        Apply(merge_function._name,
              merge_function._ret_type,
              TopLevelReference('row'),
              TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref: hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r: hl.array([ref]).extend(
                            al[1:].map(
                                lambda a: hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at: hl.cond(
                                        (_allele_ints['SNP'] == at)
                                        | (_allele_ints['Insertion'] == at)
                                        | (_allele_ints['Deletion'] == at)
                                        | (_allele_ints['MNP'] == at)
                                        | (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal: hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl: hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles: hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i: hl.cond(
                                hl.is_missing(row.data[i].__entries),
                                hl.range(0, hl.len(gbl.g[i].__cols))
                                  .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                hl.bind(
                                    lambda old_to_new: row.data[i].__entries.map(
                                        lambda e: renumber_entry(e, old_to_new)),
                                    hl.range(0, hl.len(alleles.local[i])).map(
                                        lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(
        ts._tir,
        Apply(merge_function._name,
              merge_function._ret_type,
              TopLevelReference('row'),
              TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
    'f7': 'phase',
    'f8': 'attributes'
})

ht_transcripts = ht_transcripts.filter(ht_transcripts.feature_type == 'transcript')
ht_transcripts = ht_transcripts.annotate(interval=hl.interval(
    hl.locus(ht_transcripts.contig, ht_transcripts.start, 'GRCh37'),
    hl.locus(ht_transcripts.contig, ht_transcripts.end + 1, 'GRCh37')))
ht_transcripts = ht_transcripts.annotate(attributes=hl.dict(
    hl.map(lambda x: (x.split(' ')[0],
                      x.split(' ')[1].replace('"', '').replace(';$', '')),
           ht_transcripts.attributes.split('; '))))
attribute_cols = list(ht_transcripts.aggregate(
    hl.set(hl.flatten(hl.agg.collect(ht_transcripts.attributes.keys())))))
ht_transcripts = ht_transcripts.annotate(**{
    x: hl.or_missing(ht_transcripts.attributes.contains(x), ht_transcripts.attributes[x])
    for x in attribute_cols})
ht_transcripts = ht_transcripts.select(*[
    'transcript_id', 'transcript_name', 'transcript_type', 'strand',
    'transcript_status', 'havana_transcript', 'ccdsid', 'ont', 'gene_name',
    'interval', 'gene_type', 'annotation_source', 'havana_gene',
    'gene_status', 'tag'])
ht_transcripts = ht_transcripts.rename({
    'havana_transcript': 'havana_transcript_id',
    'havana_gene': 'havana_gene_id'
def merge_alleles(alleles) -> ArrayExpression:
    # alleles is tarray(tarray(tstruct(ref=tstr, alt=tstr)))
    return hl.rbind(
        hl.array(hl.set(hl.flatten(alleles))),
        lambda arr: hl.filter(lambda a: a.alt != '<NON_REF>', arr)
                      .extend(hl.filter(lambda a: a.alt == '<NON_REF>', arr)))
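# A hedged sketch of the reordering done by the merge_alleles above, on
# hypothetical ref/alt pairs. hl.set removes duplicates, and the '<NON_REF>'
# entries are moved to the end; the order within each block is otherwise
# unspecified.
import hail as hl

alleles = hl.array([
    hl.array([hl.struct(ref='A', alt='T'), hl.struct(ref='A', alt='<NON_REF>')]),
    hl.array([hl.struct(ref='A', alt='G'), hl.struct(ref='A', alt='<NON_REF>')]),
])
merged = hl.eval(merge_alleles(alleles))
# merged contains {A,T} and {A,G} (in some order) followed by {A,<NON_REF>}.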
def compute_from_vp_mt(chr20: bool, overwrite: bool):
    meta = get_gnomad_meta('exomes')
    vp_mt = hl.read_matrix_table(full_mt_path('exomes'))
    vp_mt = vp_mt.filter_cols(meta[vp_mt.col_key].release)
    ann_ht = hl.read_table(vp_ann_ht_path('exomes'))
    phase_ht = hl.read_table(phased_vp_count_ht_path('exomes'))

    if chr20:
        vp_mt, ann_ht, phase_ht = filter_to_chr20([vp_mt, ann_ht, phase_ht])

    vep1_expr = get_worst_gene_csq_code_expr(ann_ht.vep1)
    vep2_expr = get_worst_gene_csq_code_expr(ann_ht.vep2)
    ann_ht = ann_ht.select(
        'snv1',
        'snv2',
        is_singleton_vp=(ann_ht.freq1['all'].AC < 2) & (ann_ht.freq2['all'].AC < 2),
        pop_af=hl.dict(
            ann_ht.freq1.key_set().intersection(ann_ht.freq2.key_set())
            .map(lambda pop: hl.tuple([pop, hl.max(ann_ht.freq1[pop].AF, ann_ht.freq2[pop].AF)]))
        ),
        popmax_af=hl.max(ann_ht.popmax1.AF, ann_ht.popmax2.AF, filter_missing=False),
        filtered=(hl.len(ann_ht.filters1) > 0) | (hl.len(ann_ht.filters2) > 0),
        vep=vep1_expr.keys().filter(
            lambda k: vep2_expr.contains(k)
        ).map(
            lambda k: vep1_expr[k].annotate(
                csq=hl.max(vep1_expr[k].csq, vep2_expr[k].csq)
            )
        )
    )

    vp_mt = vp_mt.annotate_cols(pop=meta[vp_mt.col_key].pop)
    vp_mt = vp_mt.annotate_rows(
        **ann_ht[vp_mt.row_key],
        phase_info=phase_ht[vp_mt.row_key].phase_info
    )

    vp_mt = vp_mt.filter_rows(~vp_mt.filtered)
    vp_mt = vp_mt.filter_entries(
        vp_mt.GT1.is_het() & vp_mt.GT2.is_het() & vp_mt.adj1 & vp_mt.adj2
    )
    vp_mt = vp_mt.select_entries(x=True)

    vp_mt = vp_mt.annotate_cols(pop=['all', vp_mt.pop])
    vp_mt = vp_mt.explode_cols('pop')

    vp_mt = vp_mt.explode_rows('vep')
    vp_mt = vp_mt.transmute_rows(**vp_mt.vep)

    def get_grouped_phase_agg():
        return hl.agg.group_by(
            hl.case()
            .when(~vp_mt.is_singleton_vp & (vp_mt.phase_info[vp_mt.pop].em.adj.p_chet > CHET_THRESHOLD), 1)
            .when(~vp_mt.is_singleton_vp & (vp_mt.phase_info[vp_mt.pop].em.adj.p_chet < SAME_HAP_THRESHOLD), 2)
            .default(3),
            hl.agg.min(vp_mt.csq)
        )

    vp_mt = vp_mt.group_rows_by(
        'gene_id',
        'gene_symbol'
    ).aggregate(
        all=hl.agg.filter(
            vp_mt.x &
            hl.if_else(
                vp_mt.pop == 'all',
                hl.is_defined(vp_mt.popmax_af) & (vp_mt.popmax_af <= MAX_FREQ),
                vp_mt.pop_af[vp_mt.pop] <= MAX_FREQ
            ),
            get_grouped_phase_agg()
        ),
        af_le_0_001=hl.agg.filter(
            hl.if_else(
                vp_mt.pop == 'all',
                hl.is_defined(vp_mt.popmax_af) & (vp_mt.popmax_af <= 0.001),
                vp_mt.pop_af[vp_mt.pop] <= 0.001
            ) & vp_mt.x,
            get_grouped_phase_agg()
        )
    )

    vp_mt = vp_mt.checkpoint(
        'gs://gnomad-tmp/compound_hets/chet_per_gene{}.2.mt'.format('.chr20' if chr20 else ''),
        overwrite=True
    )

    gene_ht = vp_mt.annotate_rows(
        row_counts=hl.flatten([
            hl.array(
                hl.agg.group_by(
                    vp_mt.pop,
                    hl.struct(
                        csq=csq,
                        af=af,
                        # TODO: Review this
                        # These will only keep the worst csq -- now maybe it'd be better to keep either
                        # - the worst csq for chet or
                        # - the worst csq for both chet and same_hap
                        n_worst_chet=hl.agg.count_where(vp_mt[af].get(1) == csq_i),
                        n_chet=hl.agg.count_where((vp_mt[af].get(1) == csq_i) & (vp_mt[af].get(2, 9) >= csq_i) & (vp_mt[af].get(3, 9) >= csq_i)),
                        n_same_hap=hl.agg.count_where((vp_mt[af].get(2) == csq_i) & (vp_mt[af].get(1, 9) > csq_i) & (vp_mt[af].get(3, 9) >= csq_i)),
                        n_unphased=hl.agg.count_where((vp_mt[af].get(3) == csq_i) & (vp_mt[af].get(1, 9) > csq_i) & (vp_mt[af].get(2, 9) > csq_i))
                    )
                )
            ).filter(
                lambda x: (x[1].n_chet > 0) | (x[1].n_same_hap > 0) | (x[1].n_unphased > 0)
            ).map(
                lambda x: x[1].annotate(pop=x[0])
            )
            for csq_i, csq in enumerate(CSQ_CODES)
            for af in ['all', 'af_le_0_001']
        ])
    ).rows()

    gene_ht = gene_ht.explode('row_counts')
    gene_ht = gene_ht.select(**gene_ht.row_counts)

    gene_ht.describe()
    gene_ht = gene_ht.checkpoint(
        'gs://gnomad-lfran/compound_hets/chet_per_gene{}.ht'.format('.chr20' if chr20 else ''),
        overwrite=overwrite
    )

    gene_ht.flatten().export(
        'gs://gnomad-lfran/compound_hets/chet_per_gene{}.tsv.gz'.format('.chr20' if chr20 else '')
    )
def compute_from_full_mt(chr20: bool, overwrite: bool):
    mt = get_gnomad_data('exomes', adj=True, release_samples=True)
    freq_ht = hl.read_table(annotations_ht_path('exomes', 'frequencies'))
    vep_ht = hl.read_table(annotations_ht_path('exomes', 'vep'))
    rf_ht = hl.read_table(annotations_ht_path('exomes', 'rf'))

    if chr20:
        mt, freq_ht, vep_ht, rf_ht = filter_to_chr20([mt, freq_ht, vep_ht, rf_ht])

    vep_ht = vep_ht.annotate(
        vep=get_worst_gene_csq_code_expr(vep_ht.vep).values()
    )

    freq_ht = freq_ht.select(
        freq=freq_ht.freq[:10],
        popmax=freq_ht.popmax
    )
    freq_meta = hl.eval(freq_ht.globals.freq_meta)
    freq_dict = {f['pop']: i for i, f in enumerate(freq_meta[:10]) if 'pop' in f}
    freq_dict['all'] = 0
    freq_dict = hl.literal(freq_dict)

    mt = mt.annotate_rows(
        **freq_ht[mt.row_key],
        vep=vep_ht[mt.row_key].vep,
        filters=rf_ht[mt.row_key].filters
    )
    mt = mt.filter_rows(
        (mt.freq[0].AF <= MAX_FREQ) &
        (hl.len(mt.vep) > 0) &
        (hl.len(mt.filters) == 0)
    )

    mt = mt.filter_entries(mt.GT.is_non_ref())
    mt = mt.select_entries(is_het=mt.GT.is_het())

    mt = mt.explode_rows(mt.vep)
    mt = mt.transmute_rows(**mt.vep)

    mt = mt.annotate_cols(pop=['all', mt.meta.pop])
    mt = mt.explode_cols(mt.pop)

    mt = mt.group_rows_by('gene_id').aggregate_rows(
        gene_symbol=hl.agg.take(mt.gene_symbol, 1)[0]
    ).aggregate(
        counts=hl.agg.filter(
            hl.if_else(
                mt.pop == 'all',
                hl.is_defined(mt.popmax) & (mt.popmax.AF <= MAX_FREQ),
                mt.freq[freq_dict[mt.pop]].AF <= MAX_FREQ
            ),
            hl.agg.group_by(
                hl.if_else(
                    mt.pop == 'all',
                    mt.popmax.AF > 0.001,
                    mt.freq[freq_dict[mt.pop]].AF > 0.001
                ),
                hl.struct(
                    hom_csq=hl.agg.filter(~mt.is_het, hl.agg.min(mt.csq)),
                    het_csq=hl.agg.filter(mt.is_het, hl.agg.min(mt.csq)),
                    het_het_csq=hl.sorted(
                        hl.array(hl.agg.filter(mt.is_het, hl.agg.counter(mt.csq))),
                        key=lambda x: x[0]
                    ).scan(
                        lambda i, j: (j[0], i[1] + j[1]),
                        (0, 0)
                    ).find(
                        lambda x: x[1] > 1
                    )[0]
                )
            )
        )
    )

    mt = mt.annotate_entries(
        counts=hl.struct(
            all=hl.struct(
                hom_csq=hl.min(mt.counts.get(True).hom_csq, mt.counts.get(False).hom_csq),
                het_csq=hl.min(mt.counts.get(True).het_csq, mt.counts.get(False).het_csq),
                het_het_csq=hl.min(
                    mt.counts.get(True).het_het_csq,
                    mt.counts.get(False).het_het_csq,
                    hl.or_missing(
                        hl.is_defined(mt.counts.get(True).het_csq) &
                        hl.is_defined(mt.counts.get(False).het_csq),
                        hl.max(mt.counts.get(True).het_csq, mt.counts.get(False).het_csq)
                    )
                ),
            ),
            af_le_0_001=mt.counts.get(False)
        )
    )

    mt = mt.checkpoint(
        'gs://gnomad-tmp/compound_hets/het_and_hom_per_gene{}.1.mt'.format('.chr20' if chr20 else ''),
        overwrite=True
    )

    gene_ht = mt.annotate_rows(
        row_counts=hl.flatten([
            hl.array(
                hl.agg.group_by(
                    mt.pop,
                    hl.struct(
                        csq=csq,
                        af=af,
                        n_hom=hl.agg.count_where(mt.counts[af].hom_csq == csq_i),
                        n_het=hl.agg.count_where(mt.counts[af].het_csq == csq_i),
                        n_het_het=hl.agg.count_where(mt.counts[af].het_het_csq == csq_i)
                    )
                )
            ).filter(
                lambda x: (x[1].n_het > 0) | (x[1].n_hom > 0) | (x[1].n_het_het > 0)
            ).map(
                lambda x: x[1].annotate(pop=x[0])
            )
            for csq_i, csq in enumerate(CSQ_CODES)
            for af in ['all', 'af_le_0_001']
        ])
    ).rows()

    gene_ht = gene_ht.explode('row_counts')
    gene_ht = gene_ht.select(
        'gene_symbol',
        **gene_ht.row_counts
    )

    gene_ht.describe()
    gene_ht = gene_ht.checkpoint(
        'gs://gnomad-lfran/compound_hets/het_and_hom_per_gene{}.ht'.format('.chr20' if chr20 else ''),
        overwrite=overwrite
    )
    gene_ht.flatten().export(
        'gs://gnomad-lfran/compound_hets/het_and_hom_per_gene{}.tsv.gz'.format('.chr20' if chr20 else '')
    )
def segment_intervals(ht, points):
    """Segment the interval keys of `ht` at a given set of points.

    Parameters
    ----------
    ht : :class:`.Table`
        Table with interval keys.
    points : :class:`.Table` or :class:`.ArrayExpression`
        Points at which to segment the intervals, a table or an array.

    Returns
    -------
    :class:`.Table`
    """
    if len(ht.key) != 1 or not isinstance(ht.key[0].dtype, hl.tinterval):
        raise ValueError("'segment_intervals' expects a table with interval keys")
    point_type = ht.key[0].dtype.point_type
    if isinstance(points, Table):
        if len(points.key) != 1 or points.key[0].dtype != point_type:
            raise ValueError(
                "'segment_intervals' expects points to be a table with a single"
                " key of the same type as the intervals in 'ht', or an array of those points:"
                f"\n  expect {point_type}, found {list(points.key.dtype.values())}")
        points = hl.array(hl.set(points.collect(_localize=False)))
    if points.dtype.element_type != point_type:
        raise ValueError(
            f"'segment_intervals' expects points to be a table with a single"
            f" key of the same type as the intervals in 'ht', or an array of those points:"
            f"\n  expect {point_type}, found {points.dtype.element_type}")

    points = hl._sort_by(points, lambda l, r: hl._compare(l, r) < 0)

    ht = ht.annotate_globals(__points=points)

    interval = ht.key[0]
    points = ht.__points
    lower = hl.expr.functions._lower_bound(points, interval.start)
    higher = hl.expr.functions._lower_bound(points, interval.end)
    n_points = hl.len(points)
    lower = hl.if_else((lower < n_points) & (points[lower] == interval.start), lower + 1, lower)
    higher = hl.if_else((higher < n_points) & (points[higher] == interval.end), higher - 1, higher)
    interval_results = hl.rbind(
        lower,
        higher,
        lambda lower, higher: hl.cond(
            lower >= higher,
            [interval],
            hl.flatten([
                [hl.interval(interval.start,
                             points[lower],
                             includes_start=interval.includes_start,
                             includes_end=False)],
                hl.range(lower, higher - 1).map(
                    lambda x: hl.interval(points[x],
                                          points[x + 1],
                                          includes_start=True,
                                          includes_end=False)),
                [hl.interval(points[higher - 1],
                             interval.end,
                             includes_start=True,
                             includes_end=interval.includes_end)],
            ])))
    ht = ht.annotate(__new_intervals=interval_results,
                     lower=lower,
                     higher=higher).explode('__new_intervals')
    return ht.key_by(**{list(ht.key)[0]: ht.__new_intervals}).drop('__new_intervals')
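# A hedged sketch of calling segment_intervals on a toy table keyed by a
# single int32 interval; the field names here are illustrative.
import hail as hl

ht = hl.utils.range_table(1)
ht = ht.key_by(interval=hl.interval(0, 10)).drop('idx')

# Segmenting [0, 10) at points 3 and 7 yields [0, 3), [3, 7) and [7, 10).
segmented = segment_intervals(ht, hl.literal([3, 7]))
segmented.show()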
              'SMRRNANM', 'SMVQCFL', 'SMTRSCPT', 'SMMPPDPR', 'SMCGLGTH', 'SMUNPDRD',
              'SMMPPDUN', 'SME2ANTI', 'SMALTALG', 'SME2SNSE', 'SMMFLGTH', 'SMSPLTRD',
              'SME1ANTI', 'SME1SNSE', 'SMNUM5CD']

ht_samples = ht_samples.annotate(**{x: hl.float(ht_samples[x]) for x in float_cols})
ht_samples = ht_samples.annotate(**{x: hl.int(ht_samples[x].replace('.0$', '')) for x in int_cols})

ht = ht.filter(ht.feature_type == 'gene')
ht = ht.annotate(interval=hl.interval(hl.locus(ht['contig'], ht['start'], 'GRCh37'),
                                      hl.locus(ht['contig'], ht['end'] + 1, 'GRCh37')))
ht = ht.annotate(attributes=hl.dict(
    hl.map(lambda x: (x.split(' ')[0],
                      x.split(' ')[1].replace('"', '').replace(';$', '')),
           ht['attributes'].split('; '))))
attribute_cols = list(ht.aggregate(hl.set(hl.flatten(hl.agg.collect(ht.attributes.keys())))))
ht = ht.annotate(**{x: hl.or_missing(ht.attributes.contains(x), ht.attributes[x])
                    for x in attribute_cols})
ht = ht.select(*['gene_id', 'interval', 'gene_type', 'strand', 'annotation_source',
                 'havana_gene', 'gene_status', 'tag'])
ht = ht.rename({'havana_gene': 'havana_gene_id'})
ht = ht.key_by(ht.gene_id)
def import_gtf(path, key=None):
    """Import a GTF file.

    The GTF file format is identical to the GFF version 2 file format,
    and so this function can be used to import GFF version 2 files as well.

    See https://www.ensembl.org/info/website/upload/gff.html for more
    details on the GTF/GFF2 file format.

    The :class:`.Table` returned by this function will include the following
    row fields:

    .. code-block:: text

        'seqname': str
        'source': str
        'feature': str
        'start': int32
        'end': int32
        'score': float64
        'strand': str
        'frame': int32

    There will also be corresponding fields for every tag found in the
    attribute field of the GTF file.

    .. note::

        The "end" field in the table will be incremented by 1 in comparison
        to the value found in the GTF file, as the end coordinate in a GTF
        file is inclusive while the end coordinate in Hail is exclusive.

    Example
    -------
    >>> ht = hl.experimental.import_gtf('data/test.gtf', key='gene_id')
    >>> ht.describe()

    .. code-block:: text

        ----------------------------------------
        Global fields:
        None
        ----------------------------------------
        Row fields:
            'seqname': str
            'source': str
            'feature': str
            'start': int32
            'end': int32
            'score': float64
            'strand': str
            'frame': int32
            'havana_gene': str
            'exon_id': str
            'havana_transcript': str
            'transcript_name': str
            'gene_type': str
            'tag': str
            'transcript_status': str
            'exon_number': str
            'level': str
            'transcript_id': str
            'transcript_type': str
            'gene_id': str
            'gene_name': str
            'gene_status': str
        ----------------------------------------
        Key: ['gene_id']
        ----------------------------------------

    Parameters
    ----------
    path : :obj:`str`
        File to import.
    key : :obj:`str` or :obj:`list` of :obj:`str`
        Key field(s). Can be tag name(s) found in the attribute field
        of the GTF file.

    Returns
    -------
    :class:`.Table`
    """
    ht = hl.import_table(path,
                         comment='#',
                         no_header=True,
                         types={'f3': hl.tint,
                                'f4': hl.tint,
                                'f5': hl.tfloat,
                                'f7': hl.tint},
                         missing='.',
                         delimiter='\t')

    ht = ht.rename({'f0': 'seqname',
                    'f1': 'source',
                    'f2': 'feature',
                    'f3': 'start',
                    'f4': 'end',
                    'f5': 'score',
                    'f6': 'strand',
                    'f7': 'frame',
                    'f8': 'attribute'})

    ht = ht.annotate(end=ht['end'] + 1)
    ht = ht.annotate(attribute=hl.dict(
        hl.map(lambda x: (x.split(' ')[0],
                          x.split(' ')[1].replace('"', '').replace(';$', '')),
               ht['attribute'].split('; '))))

    attributes = list(ht.aggregate(
        hl.set(hl.flatten(hl.agg.collect(ht['attribute'].keys())))))

    ht = ht.annotate(**{x: hl.or_missing(ht['attribute'].contains(x),
                                         ht['attribute'][x])
                        for x in attributes})

    ht = ht.drop(ht['attribute'])

    if key:
        key = wrap_to_list(key)
        ht = ht.key_by(*key)

    return ht
def merge_alleles(alleles) -> ArrayExpression:
    return hl.array(hl.set(hl.flatten(alleles)))
def fs_from_sb(
    sb: Union[hl.expr.ArrayNumericExpression, hl.expr.ArrayExpression],
    normalize: bool = True,
    min_cell_count: int = 200,
    min_count: int = 4,
    min_p_value: float = 1e-320,
) -> hl.expr.Float64Expression:
    """
    Computes `FS` (Fisher strand balance) annotation from the `SB` (strand balance table) field.
    `FS` is the phred-scaled value of the double-sided Fisher exact test on strand balance.

    Using default values will have the same behavior as the GATK implementation, that is:
    - If sum(counts) > 2 * `min_cell_count` (default to GATK value of 200), they are normalized
    - If sum(counts) < `min_count` (default to GATK value of 4), returns missing
    - Any p-value < `min_p_value` (default to GATK value of 1e-320) is truncated to that value

    In addition to the default GATK behavior, setting `normalize` to `False` will perform a
    chi-squared test for large counts (> `min_cell_count`) instead of normalizing the cell values.

    .. note::

        This function can either take
        - an array of length four containing the forward and reverse strands' counts of ref and alt alleles: [ref fwd, ref rev, alt fwd, alt rev]
        - a two dimensional array with arrays of length two, containing the counts: [[ref fwd, ref rev], [alt fwd, alt rev]]

    GATK code here: https://github.com/broadinstitute/gatk/blob/master/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/FisherStrand.java

    :param sb: Count of ref/alt reads on each strand
    :param normalize: Whether to normalize counts if sum(counts) > min_cell_count (normalize=True), or use a chi-squared test instead of FET (normalize=False)
    :param min_cell_count: Maximum count for performing a FET
    :param min_count: Minimum total count to output FS (otherwise missing is output)
    :param min_p_value: Minimum p-value; any smaller p-value is truncated to this value
    :return: FS value
    """
    if not isinstance(sb, hl.expr.ArrayNumericExpression):
        sb = hl.bind(lambda x: hl.flatten(x), sb)

    sb_sum = hl.bind(lambda x: hl.sum(x), sb)

    # Normalize table if counts get too large
    if normalize:
        fs_expr = hl.bind(
            lambda sb, sb_sum: hl.cond(
                sb_sum <= 2 * min_cell_count,
                sb,
                sb.map(lambda x: hl.int(x / (sb_sum / min_cell_count))),
            ),
            sb,
            sb_sum,
        )

        # FET
        fs_expr = to_phred(
            hl.max(
                hl.fisher_exact_test(fs_expr[0], fs_expr[1], fs_expr[2], fs_expr[3]).p_value,
                min_p_value,
            )
        )
    else:
        fs_expr = to_phred(
            hl.max(
                hl.contingency_table_test(
                    sb[0], sb[1], sb[2], sb[3], min_cell_count=min_cell_count
                ).p_value,
                min_p_value,
            )
        )

    # Return missing if counts <= `min_count`
    return hl.or_missing(
        sb_sum > min_count,
        hl.max(0, fs_expr),  # Needed to avoid -0.0 values
    )
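# A hedged usage sketch mirroring the SOR example above; `info.SB` is a
# placeholder for whatever strand-balance counts the dataset actually carries.
import hail as hl

mt = hl.read_matrix_table('gs://my-bucket/my-dataset.mt')  # placeholder path
# Phred-scaled Fisher strand test per row, assuming a GATK-style `info.SB`
# array [ref fwd, ref rev, alt fwd, alt rev].
mt = mt.annotate_rows(FS=fs_from_sb(mt.info.SB))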
def fields_to_array(ds, fields):
    return hl.flatten(hl.array([field_to_array(ds, f) for f in fields]))
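# field_to_array is not defined in this snippet, so this hedged sketch
# supplies a hypothetical stand-in that wraps each scalar field in a
# one-element array, purely to make the flattening behavior concrete.
import hail as hl

def field_to_array(ds, f):
    # Hypothetical stand-in; a real implementation may expand array fields.
    return hl.array([ds[f]])

ht = hl.utils.range_table(3)
ht = ht.annotate(a=ht.idx, b=ht.idx * 10)
ht = ht.annotate(arr=fields_to_array(ht, ['a', 'b']))  # arr == [idx, idx * 10]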
, "TFBS_amplification" , "TF_binding_site_variant" , "regulatory_region_ablation" , "regulatory_region_amplification" , "feature_elongation" , "regulatory_region_variant" , "feature_truncation" , "intergenic_variant" ] mt4 = kt2.annotate(transcript_canonicals =kt2.vep.transcript_consequences.filter(lambda tc: tc.canonical == 1)) mt4 = mt4.annotate( all_transcript_terms = hl.set(hl.flatten(mt4.transcript_canonicals.map(lambda x: x.consequence_terms))) ) mt4 = mt4.annotate( Consequence = hl.coalesce( # coalesce means take first non-missing hl.literal(consequence_in_severity_order).filter(lambda cnsq: mt4.all_transcript_terms.contains(cnsq) ).head(), mt4.vep.most_severe_consequence ) ) mt4 = mt4.annotate( Gene = hl.if_else(mt4.transcript_canonicals.any(lambda tc: hl.set(tc.consequence_terms).contains(mt4.Consequence)), \ mt4.transcript_canonicals.find(lambda tc: (tc.canonical == 1) & (hl.set(tc.consequence_terms).contains(mt4.Consequence))).gene_symbol, \ mt4.vep.transcript_consequences.find(lambda tc: hl.set(tc.consequence_terms).contains(mt4.Consequence)).gene_symbol), \