def test_agg_explode(self):
    t = hl.Table.parallelize([
        hl.struct(a=[1, 2]),
        hl.struct(a=hl.empty_array(hl.tint32)),
        hl.struct(a=hl.null(hl.tarray(hl.tint32))),
        hl.struct(a=[3]),
        hl.struct(a=[hl.null(hl.tint32)])
    ])
    self.assertCountEqual(t.aggregate(hl.agg.explode(lambda elt: hl.agg.collect(elt), t.a)),
                          [1, 2, None, 3])
def test_multi_way_zip_join_globals(self):
    t1 = hl.utils.range_table(1).annotate_globals(x=hl.null(hl.tint32))
    t2 = hl.utils.range_table(1).annotate_globals(x=5)
    t3 = hl.utils.range_table(1).annotate_globals(x=0)
    expected = hl.struct(__globals=hl.array([
        hl.struct(x=hl.null(hl.tint32)),
        hl.struct(x=5),
        hl.struct(x=0)]))
    joined = hl.Table._multi_way_zip_join([t1, t2, t3], '__data', '__globals')
    self.assertEqual(hl.eval(joined.globals), hl.eval(expected))
def test_reference_genome_liftover(self):
    grch37 = hl.get_reference('GRCh37')
    grch38 = hl.get_reference('GRCh38')

    self.assertTrue(not grch37.has_liftover('GRCh38') and not grch38.has_liftover('GRCh37'))
    grch37.add_liftover(resource('grch37_to_grch38_chr20.over.chain.gz'), 'GRCh38')
    grch38.add_liftover(resource('grch38_to_grch37_chr20.over.chain.gz'), 'GRCh37')
    self.assertTrue(grch37.has_liftover('GRCh38') and grch38.has_liftover('GRCh37'))

    ds = hl.import_vcf(resource('sample.vcf'))
    t = ds.annotate_rows(liftover=hl.liftover(hl.liftover(ds.locus, 'GRCh38'), 'GRCh37')).rows()
    self.assertTrue(t.all(t.locus == t.liftover))

    null_locus = hl.null(hl.tlocus('GRCh38'))

    rows = [
        {'l37': hl.locus('20', 1, 'GRCh37'), 'l38': null_locus},
        {'l37': hl.locus('20', 60000, 'GRCh37'), 'l38': null_locus},
        {'l37': hl.locus('20', 60001, 'GRCh37'), 'l38': hl.locus('chr20', 79360, 'GRCh38')},
        {'l37': hl.locus('20', 278686, 'GRCh37'), 'l38': hl.locus('chr20', 298045, 'GRCh38')},
        {'l37': hl.locus('20', 278687, 'GRCh37'), 'l38': hl.locus('chr20', 298046, 'GRCh38')},
        {'l37': hl.locus('20', 278688, 'GRCh37'), 'l38': null_locus},
        {'l37': hl.locus('20', 278689, 'GRCh37'), 'l38': null_locus},
        {'l37': hl.locus('20', 278690, 'GRCh37'), 'l38': null_locus},
        {'l37': hl.locus('20', 278691, 'GRCh37'), 'l38': hl.locus('chr20', 298047, 'GRCh38')},
        {'l37': hl.locus('20', 37007586, 'GRCh37'), 'l38': hl.locus('chr12', 32563117, 'GRCh38')},
        {'l37': hl.locus('20', 62965520, 'GRCh37'), 'l38': hl.locus('chr20', 64334167, 'GRCh38')},
        {'l37': hl.locus('20', 62965521, 'GRCh37'), 'l38': null_locus}
    ]
    schema = hl.tstruct(l37=hl.tlocus(grch37), l38=hl.tlocus(grch38))
    t = hl.Table.parallelize(rows, schema)
    self.assertTrue(t.all(hl.cond(hl.is_defined(t.l38),
                                  hl.liftover(t.l37, 'GRCh38') == t.l38,
                                  hl.is_missing(hl.liftover(t.l37, 'GRCh38')))))

    t = t.filter(hl.is_defined(t.l38))
    self.assertTrue(t.count() == 6)
    t = t.key_by('l38')
    t.count()
    self.assertTrue(list(t.key) == ['l38'])

    null_locus_interval = hl.null(hl.tinterval(hl.tlocus('GRCh38')))
    rows = [
        {'i37': hl.locus_interval('20', 1, 60000, True, False, 'GRCh37'),
         'i38': null_locus_interval},
        {'i37': hl.locus_interval('20', 60001, 82456, True, True, 'GRCh37'),
         'i38': hl.locus_interval('chr20', 79360, 101815, True, True, 'GRCh38')}
    ]
    schema = hl.tstruct(i37=hl.tinterval(hl.tlocus(grch37)), i38=hl.tinterval(hl.tlocus(grch38)))
    t = hl.Table.parallelize(rows, schema)
    self.assertTrue(t.all(hl.liftover(t.i37, 'GRCh38') == t.i38))

    grch37.remove_liftover("GRCh38")
    grch38.remove_liftover("GRCh37")
def test_refs_with_process_joins(self):
    mt = hl.utils.range_matrix_table(10, 10)
    mt = mt.annotate_entries(
        a_literal=hl.literal(['a']),
        a_col_join=hl.is_defined(mt.cols()[mt.col_key]),
        a_row_join=hl.is_defined(mt.rows()[mt.row_key]),
        an_entry_join=hl.is_defined(mt[mt.row_key, mt.col_key]),
        the_global_failure=hl.cond(True, mt.globals, hl.null(mt.globals.dtype)),
        the_row_failure=hl.cond(True, mt.row, hl.null(mt.row.dtype)),
        the_col_failure=hl.cond(True, mt.col, hl.null(mt.col.dtype)),
        the_entry_failure=hl.cond(True, mt.entry, hl.null(mt.entry.dtype)),
    )
    mt.count()
def test_export_import_plink_same(self):
    mt = get_dataset()
    mt = mt.select_rows(rsid=hl.delimit([mt.locus.contig, hl.str(mt.locus.position),
                                         mt.alleles[0], mt.alleles[1]], ':'),
                        cm_position=15.0)
    mt = mt.select_cols(fam_id=hl.null(hl.tstr), pat_id=hl.null(hl.tstr),
                        mat_id=hl.null(hl.tstr), is_female=hl.null(hl.tbool),
                        is_case=hl.null(hl.tbool))
    mt = mt.select_entries('GT')

    bfile = '/tmp/test_import_export_plink'
    hl.export_plink(mt, bfile, ind_id=mt.s, cm_position=mt.cm_position)

    mt_imported = hl.import_plink(bfile + '.bed', bfile + '.bim', bfile + '.fam',
                                  a2_reference=True, reference_genome='GRCh37')
    self.assertTrue(mt._same(mt_imported))
    self.assertTrue(mt.aggregate_rows(hl.agg.all(mt.cm_position == 15.0)))
def test_from_entry_expr_options(self):
    def build_mt(a):
        data = [{'v': 0, 's': 0, 'x': a[0]},
                {'v': 0, 's': 1, 'x': a[1]},
                {'v': 0, 's': 2, 'x': a[2]}]
        ht = hl.Table.parallelize(data, hl.dtype('struct{v: int32, s: int32, x: float64}'))
        mt = ht.to_matrix_table(['v'], ['s'])
        ids = mt.key_cols_by()['s'].collect()
        return mt.choose_cols([ids.index(0), ids.index(1), ids.index(2)])

    def check(expr, mean_impute, center, normalize, expected):
        actual = np.squeeze(BlockMatrix.from_entry_expr(expr,
                                                        mean_impute=mean_impute,
                                                        center=center,
                                                        normalize=normalize).to_numpy())
        assert np.allclose(actual, expected)

    a = np.array([0.0, 1.0, 2.0])
    mt = build_mt(a)
    check(mt.x, False, False, False, a)
    check(mt.x, False, True, False, a - 1.0)
    check(mt.x, False, False, True, a / np.sqrt(5))
    check(mt.x, False, True, True, (a - 1.0) / np.sqrt(2))
    check(mt.x + 1 - 1, False, False, False, a)

    mt = build_mt([0.0, hl.null('float64'), 2.0])
    check(mt.x, True, False, False, a)
    check(mt.x, True, True, False, a - 1.0)
    check(mt.x, True, False, True, a / np.sqrt(5))
    check(mt.x, True, True, True, (a - 1.0) / np.sqrt(2))
    with self.assertRaises(Exception):
        BlockMatrix.from_entry_expr(mt.x)
def test_annotate_intervals(self):
    ds = get_dataset()

    bed1 = hl.import_bed(resource('example1.bed'), reference_genome='GRCh37')
    bed2 = hl.import_bed(resource('example2.bed'), reference_genome='GRCh37')
    bed3 = hl.import_bed(resource('example3.bed'), reference_genome='GRCh37')
    self.assertTrue(list(bed2.key.dtype) == ['interval'])
    self.assertTrue(list(bed2.row.dtype) == ['interval', 'target'])

    interval_list1 = hl.import_locus_intervals(resource('exampleAnnotation1.interval_list'))
    interval_list2 = hl.import_locus_intervals(resource('exampleAnnotation2.interval_list'))
    self.assertTrue(list(interval_list2.key.dtype) == ['interval'])
    self.assertTrue(list(interval_list2.row.dtype) == ['interval', 'target'])

    ann = ds.annotate_rows(in_interval=bed1[ds.locus]).rows()
    self.assertTrue(ann.all((ann.locus.position <= 14000000) |
                            (ann.locus.position >= 17000000) |
                            (hl.is_missing(ann.in_interval))))

    for bed in [bed2, bed3]:
        ann = ds.annotate_rows(target=bed[ds.locus].target).rows()
        expr = (hl.case()
                .when(ann.locus.position <= 14000000, ann.target == 'gene1')
                .when(ann.locus.position >= 17000000, ann.target == 'gene2')
                .default(ann.target == hl.null(hl.tstr)))
        self.assertTrue(ann.all(expr))

    self.assertTrue(ds.annotate_rows(in_interval=interval_list1[ds.locus]).rows()
                    ._same(ds.annotate_rows(in_interval=bed1[ds.locus]).rows()))

    self.assertTrue(ds.annotate_rows(target=interval_list2[ds.locus].target).rows()
                    ._same(ds.annotate_rows(target=bed2[ds.locus].target).rows()))
def test_aggregate_ir(self):
    ds = (hl.utils.range_matrix_table(5, 5)
          .annotate_globals(g1=5)
          .annotate_entries(e1=3))

    x = [("col_idx", lambda e: ds.aggregate_cols(e)),
         ("row_idx", lambda e: ds.aggregate_rows(e))]

    for name, f in x:
        r = f(hl.struct(x=agg.sum(ds[name]) + ds.g1,
                        y=agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1,
                        z=agg.sum(ds.g1 + ds[name]) + ds.g1,
                        mean=agg.mean(ds[name])))
        self.assertEqual(convert_struct_to_dict(r), {u'x': 15, u'y': 13, u'z': 40, u'mean': 2.0})

        r = f(5)
        self.assertEqual(r, 5)

        r = f(hl.null(hl.tint32))
        self.assertEqual(r, None)

        r = f(agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1)
        self.assertEqual(r, 13)

    r = ds.aggregate_entries(
        agg.filter((ds.row_idx % 2 != 0) & (ds.col_idx % 2 != 0),
                   agg.sum(ds.e1 + ds.g1 + ds.row_idx + ds.col_idx)) + ds.g1)
    # assertTrue(r, 48) would treat 48 as the failure message; assert the value instead.
    self.assertEqual(r, 48)
def downsample(x, y, label=None, n_divisions=500) -> ArrayExpression:
    """Downsample (x, y) coordinate datapoints.

    Parameters
    ----------
    x : :class:`.NumericExpression`
        X-values to be downsampled.
    y : :class:`.NumericExpression`
        Y-values to be downsampled.
    label : :class:`.StringExpression` or :class:`.ArrayExpression`
        Additional data for each (x, y) coordinate. Can pass in multiple fields in an
        :class:`.ArrayExpression`.
    n_divisions : :obj:`int`
        Factor by which to downsample (default value = 500). A lower input results in
        fewer output datapoints.

    Returns
    -------
    :class:`.ArrayExpression`
        Expression for downsampled coordinate points (x, y). The element type of the array
        is :py:data:`.ttuple` of :py:data:`.tfloat64`, :py:data:`.tfloat64`, and
        :py:data:`.tarray` of :py:data:`.tstr`.
    """
    if label is None:
        label = hl.null(hl.tarray(hl.tstr))
    elif isinstance(label, StringExpression):
        label = hl.array([label])
    return _agg_func('downsample', [x, y, label],
                     tarray(ttuple(tfloat64, tfloat64, tarray(tstr))),
                     constructor_args=[n_divisions])
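# Usage sketch: this aggregator is exposed to users as hl.agg.downsample. The table
# and field names below are assumptions for illustration, not from the source.
ht = hl.utils.range_table(1000)
ht = ht.annotate(x=hl.float64(ht.idx), y=hl.float64(ht.idx) ** 2)
# Aggregate into a reduced set of representative (x, y, label) tuples.
pts = ht.aggregate(hl.agg.downsample(ht.x, ht.y, n_divisions=100))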
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        filters=hl.set(hl.flatten(ts.data.map(lambda d: hl.array(d.filters)))),
        info=hl.struct(
            DP=hl.sum(ts.data.map(lambda d: d.info.DP)),
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                        .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(
                                lambda e: renumber_entry(e, old_to_new)),
                            hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                lambda j: combined_allele_index[tmp.data[i].alleles[j]])))),
            hl.dict(hl.range(0, hl.len(tmp.alleles)).map(
                lambda j: hl.tuple([tmp.alleles[j], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))
    return tmp.drop('data', 'g')
def test_aggregate2(self):
    schema = hl.tstruct(status=hl.tint32, GT=hl.tcall, qPheno=hl.tint32)

    rows = [{'status': 0, 'GT': hl.Call([0, 0]), 'qPheno': 3},
            {'status': 0, 'GT': hl.Call([0, 1]), 'qPheno': 13}]
    kt = hl.Table.parallelize(rows, schema)

    result = convert_struct_to_dict(
        kt.group_by(status=kt.status)
        .aggregate(
            x1=agg.collect(kt.qPheno * 2),
            x2=agg.explode(lambda elt: agg.collect(elt), [kt.qPheno, kt.qPheno + 1]),
            x3=agg.min(kt.qPheno),
            x4=agg.max(kt.qPheno),
            x5=agg.sum(kt.qPheno),
            x6=agg.product(hl.int64(kt.qPheno)),
            x7=agg.count(),
            x8=agg.count_where(kt.qPheno == 3),
            x9=agg.fraction(kt.qPheno == 1),
            x10=agg.stats(hl.float64(kt.qPheno)),
            x11=agg.hardy_weinberg_test(kt.GT),
            x13=agg.inbreeding(kt.GT, 0.1),
            x14=agg.call_stats(kt.GT, ["A", "T"]),
            x15=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')))[0],
            x16=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')).c.banana)[0],
            x17=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tarray(hl.tint32))),
            x18=agg.explode(lambda elt: agg.collect(elt), hl.null(hl.tset(hl.tint32))),
            x19=agg.take(kt.GT, 1, ordering=-kt.qPheno)
        ).take(1)[0])

    expected = {u'status': 0,
                u'x13': {u'n_called': 2, u'expected_homs': 1.64,
                         u'f_stat': -1.777777777777777, u'observed_homs': 1},
                u'x14': {u'AC': [3, 1], u'AF': [0.75, 0.25], u'AN': 4,
                         u'homozygote_count': [1, 0]},
                u'x15': {u'a': 5, u'c': {u'banana': u'apple'}, u'b': u'foo'},
                u'x10': {u'min': 3.0, u'max': 13.0, u'sum': 16.0,
                         u'stdev': 5.0, u'n': 2, u'mean': 8.0},
                u'x8': 1, u'x9': 0.0, u'x16': u'apple',
                u'x11': {u'het_freq_hwe': 0.5, u'p_value': 0.5},
                u'x2': [3, 4, 13, 14], u'x3': 3, u'x1': [6, 26], u'x6': 39, u'x7': 2,
                u'x4': 13, u'x5': 16, u'x17': [], u'x18': [],
                u'x19': [hl.Call([0, 1])]}
    self.maxDiff = None
    self.assertDictEqual(result, expected)
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref:
            hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r:
                        hl.array([ref]).extend(
                            al[1:].map(
                                lambda a:
                                hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at:
                                    hl.cond((_allele_ints['SNP'] == at)
                                            | (_allele_ints['Insertion'] == at)
                                            | (_allele_ints['Deletion'] == at)
                                            | (_allele_ints['MNP'] == at)
                                            | (_allele_ints['Complex'] == at),
                                            a + ref[hl.len(r):],
                                            a)))))),
                lambda lal:
                hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles:
                hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i:
                            hl.cond(hl.is_missing(row.data[i].__entries),
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                    .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(
        ts._tir,
        Apply(merge_function._name,
              merge_function._ret_type,
              TopLevelReference('row'),
              TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
def unphase_mt(mt: hl.MatrixTable) -> hl.MatrixTable:
    """
    Generate unphased version of MatrixTable (assumes call is in mt.GT and is diploid or haploid only)
    """
    return mt.annotate_entries(GT=hl.case()
                               .when(mt.GT.is_diploid(), hl.call(mt.GT[0], mt.GT[1], phased=False))
                               .when(mt.GT.is_haploid(), hl.call(mt.GT[0], phased=False))
                               .default(hl.null(hl.tcall)))
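# Usage sketch for unphase_mt above: a hypothetical dataset with phased GT calls,
# purely illustrative (the generated data is not from the source).
mt = hl.balding_nichols_model(1, n_samples=2, n_variants=2)
mt = mt.annotate_entries(GT=hl.call(mt.GT[0], mt.GT[1], phased=True))
unphased = unphase_mt(mt)
unphased.aggregate_entries(hl.agg.all(~unphased.GT.phased))  # expect True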
def unify_saige_ht_variant_schema(ht):
    shared = ('markerID', 'AC', 'AF', 'N', 'BETA', 'SE', 'Tstat', 'varT', 'varTstar')
    new_floats = ('AF.Cases', 'AF.Controls')
    new_ints = ('N.Cases', 'N.Controls')
    shared_end = ('Pvalue', 'gene', 'annotation')
    if 'AF.Cases' not in list(ht.row):
        ht = ht.select(*shared,
                       **{field: hl.null(hl.tfloat64) for field in new_floats},
                       **{field: hl.null(hl.tint32) for field in new_ints},
                       **{field: ht[field] for field in shared_end})
    else:
        ht = ht.select(*shared, *new_floats, *new_ints, *shared_end)
    return ht.annotate(SE=hl.float64(ht.SE), AC=hl.int32(ht.AC))
def test_call_fields(self):
    expected = hl.Table.parallelize(
        [hl.struct(locus=hl.locus("X", 16050036), s="C1046::HG02024",
                   GT=hl.call(0, 0), GTA=hl.null(hl.tcall), GTZ=hl.call(0, 1)),
         hl.struct(locus=hl.locus("X", 16050036), s="C1046::HG02025",
                   GT=hl.call(1), GTA=hl.null(hl.tcall), GTZ=hl.call(0)),
         hl.struct(locus=hl.locus("X", 16061250), s="C1046::HG02024",
                   GT=hl.call(2, 2), GTA=hl.call(2, 1), GTZ=hl.call(1, 1)),
         hl.struct(locus=hl.locus("X", 16061250), s="C1046::HG02025",
                   GT=hl.call(2), GTA=hl.null(hl.tcall), GTZ=hl.call(1))],
        key=['locus', 's'])

    mt = hl.import_vcf(resource('generic.vcf'), call_fields=['GT', 'GTA', 'GTZ'])
    entries = mt.entries()
    entries = entries.key_by('locus', 's')
    entries = entries.select('GT', 'GTA', 'GTZ')
    self.assertTrue(entries._same(expected))
def split_position_end(position):
    return hl.or_missing(
        hl.is_defined(position),
        hl.bind(
            lambda start: hl.cond(start == "?", hl.null(hl.tint), hl.int(start)),
            position.split("-")[-1]),
    )
def load_gene_data(directory: str, pheno_key_dict, gene_ht_map_path: str,
                   n_cases: int = -1, n_controls: int = -1, heritability: float = -1.0,
                   saige_version: str = 'NA', inv_normalized: str = 'NA',
                   overwrite: bool = False):
    output_ht_path = f'{directory}/gene_results.ht'
    print(f'Loading: {directory}/*.gene.txt ...')
    types = {f'Nmarker_MACCate_{i}': hl.tint32 for i in range(1, 9)}
    types.update({x: hl.tfloat64 for x in
                  ('Pvalue', 'Pvalue_Burden', 'Pvalue_SKAT', 'Pvalue_skato_NA',
                   'Pvalue_burden_NA', 'Pvalue_skat_NA')})
    ht = hl.import_table(f'{directory}/*.gene.txt', delimiter=' ', impute=True, types=types)

    if n_cases == -1:
        n_cases = hl.null(hl.tint)
    if n_controls == -1:
        n_controls = hl.null(hl.tint)
    if heritability == -1.0:
        heritability = hl.null(hl.tfloat)
    if saige_version == 'NA':
        saige_version = hl.null(hl.tstr)
    if inv_normalized == 'NA':
        inv_normalized = hl.null(hl.tstr)

    fields = ht.Gene.split('_')
    gene_ht = hl.read_table(gene_ht_map_path).select('interval').distinct()
    ht = ht.key_by(gene_id=fields[0], gene_symbol=fields[1], annotation=fields[2],
                   **pheno_key_dict).drop('Gene').naive_coalesce(10).annotate_globals(
        n_cases=n_cases, n_controls=n_controls, heritability=heritability,
        saige_version=saige_version, inv_normalized=inv_normalized)
    ht = ht.annotate(
        total_variants=hl.sum([v for k, v in list(ht.row_value.items()) if 'Nmarker' in k]),
        interval=gene_ht.key_by('gene_id')[ht.gene_id].interval)
    ht = ht.checkpoint(output_ht_path, overwrite=overwrite,
                       _read_if_exists=not overwrite).drop('n_cases', 'n_controls')
def make_entry_struct(e, alleles_len, has_non_ref, row):
    handled_fields = dict()
    handled_names = {'LA', 'gvcf_info',
                     'END',
                     'LAD', 'AD',
                     'LGT', 'GT',
                     'LPL', 'PL',
                     'LPGT', 'PGT'}

    if 'END' not in row.info:
        raise hl.utils.FatalError(
            "the Hail GVCF combiner expects GVCFs to have an 'END' field in INFO.")
    if 'GT' not in e:
        raise hl.utils.FatalError(
            "the Hail GVCF combiner expects GVCFs to have a 'GT' field in FORMAT.")

    handled_fields['LA'] = hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0))
    handled_fields['LGT'] = get_lgt(e, alleles_len, has_non_ref, row)
    if 'AD' in e:
        handled_fields['LAD'] = hl.cond(has_non_ref, e.AD[:-1], e.AD)
    if 'PGT' in e:
        handled_fields['LPGT'] = e.PGT
    if 'PL' in e:
        handled_fields['LPL'] = hl.cond(has_non_ref,
                                        hl.cond(alleles_len > 2,
                                                e.PL[:-alleles_len],
                                                hl.null(e.PL.dtype)),
                                        hl.cond(alleles_len > 1,
                                                e.PL,
                                                hl.null(e.PL.dtype)))
        handled_fields['RGQ'] = hl.cond(
            has_non_ref,
            e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
            hl.null(e.PL.dtype.element_type))

    handled_fields['END'] = row.info.END
    handled_fields['gvcf_info'] = (hl.case()
                                   .when(hl.is_missing(row.info.END),
                                         hl.struct(**(parse_as_fields(
                                             row.info.select(*info_to_keep),
                                             has_non_ref))))
                                   .or_missing())

    pass_through_fields = {k: v for k, v in e.items() if k not in handled_names}
    return hl.struct(**handled_fields, **pass_through_fields)
def get_lgt(e, n_alleles, has_non_ref, row):
    index = e.GT.unphased_diploid_gt_index()
    n_no_nonref = n_alleles - hl.int(has_non_ref)
    triangle_without_nonref = hl.triangle(n_no_nonref)
    return (hl.case()
            .when(index < triangle_without_nonref, e.GT)
            .when(index < hl.triangle(n_alleles), hl.null('call'))
            .or_error('invalid GT ' + hl.str(e.GT) + ' at site ' + hl.str(row.locus)))
def make_pheno_manifest(export=True):
    mt0 = load_final_sumstats_mt(filter_sumstats=False,
                                 filter_variants=False,
                                 separate_columns_by_pop=False,
                                 annotate_with_nearest_gene=False)

    ht = mt0.cols()
    annotate_dict = {}
    annotate_dict.update({'pops': hl.delimit(ht.pheno_data.pop),
                          'num_pops': hl.len(ht.pheno_data.pop)})

    for field in ['n_cases', 'n_controls', 'heritability', 'lambda_gc']:
        for pop in ['AFR', 'AMR', 'CSA', 'EAS', 'EUR', 'MID']:
            # new field name (only applicable to saige heritability)
            new_field = field if field != 'heritability' else 'saige_heritability'
            idx = ht.pheno_data.pop.index(pop)
            field_expr = ht.pheno_data[field]
            annotate_dict.update({
                f'{new_field}_{pop}': hl.if_else(hl.is_nan(idx),
                                                 hl.null(field_expr[0].dtype),
                                                 field_expr[idx])})
    annotate_dict.update({'filename': get_pheno_id(tb=ht) + '.tsv.bgz'})
    ht = ht.annotate(**annotate_dict)

    dropbox_manifest = hl.import_table(
        f'{ldprune_dir}/UKBB_Pan_Populations-Manifest_20200615-manifest_info.tsv',
        impute=True, key='File')
    dropbox_manifest = dropbox_manifest.filter(dropbox_manifest['is_old_file'] != '1')
    bgz = dropbox_manifest.filter(~dropbox_manifest.File.contains('.tbi'))
    bgz = bgz.rename({'File': 'filename'})
    tbi = dropbox_manifest.filter(dropbox_manifest.File.contains('.tbi'))
    tbi = tbi.annotate(filename=tbi.File.replace('.tbi', '')).key_by('filename')

    dropbox_annotate_dict = {}
    rename_dict = {'dbox link': 'dropbox_link', 'size (bytes)': 'size_in_bytes'}

    dropbox_annotate_dict.update({'filename_tabix': tbi[ht.filename].File})
    for field in ['dbox link', 'wget', 'size (bytes)', 'md5 hex']:
        for tb, suffix in [(bgz, ''), (tbi, '_tabix')]:
            dropbox_annotate_dict.update({
                (rename_dict[field] if field in rename_dict
                 else field.replace(' ', '_')) + suffix: tb[ht.filename][field]})
    ht = ht.annotate(**dropbox_annotate_dict)
    ht = ht.drop('pheno_data')
    ht.describe()
    ht.show()
def load_variant_data(directory: str, pheno_key_dict, ukb_vep_path: str,
                      extension: str = 'single.txt', n_cases: int = -1,
                      n_controls: int = -1, heritability: float = -1.0,
                      saige_version: str = 'NA', inv_normalized: str = 'NA',
                      overwrite: bool = False, legacy_annotations: bool = False,
                      num_partitions: int = 1000):
    output_ht_path = f'{directory}/variant_results.ht'
    ht = hl.import_table(f'{directory}/*.{extension}', delimiter=' ', impute=True)
    print(f'Loading: {directory}/*.{extension} ...')
    marker_id_col = 'markerID' if extension == 'single.txt' else 'SNPID'
    locus_alleles = ht[marker_id_col].split('_')

    if n_cases == -1:
        n_cases = hl.null(hl.tint)
    if n_controls == -1:
        n_controls = hl.null(hl.tint)
    if heritability == -1.0:
        heritability = hl.null(hl.tfloat)
    if saige_version == 'NA':
        saige_version = hl.null(hl.tstr)
    if inv_normalized == 'NA':
        inv_normalized = hl.null(hl.tstr)

    ht = ht.key_by(locus=hl.parse_locus(locus_alleles[0]),
                   alleles=locus_alleles[1].split('/'),
                   **pheno_key_dict).distinct().naive_coalesce(num_partitions)
    if marker_id_col == 'SNPID':
        ht = ht.drop('CHR', 'POS', 'rsid', 'Allele1', 'Allele2')
    ht = ht.transmute(Pvalue=ht['p.value']).annotate_globals(
        n_cases=n_cases, n_controls=n_controls, heritability=heritability,
        saige_version=saige_version, inv_normalized=inv_normalized)
    ht = ht.drop('varT', 'varTstar', 'N', 'Tstat')
    # TODO: fix this for variants that overlap multiple genes
    ht = ht.annotate(**get_vep_formatted_data(
        ukb_vep_path, legacy_annotations=legacy_annotations)[
        hl.struct(locus=ht.locus, alleles=ht.alleles)])
    ht = ht.checkpoint(output_ht_path, overwrite=overwrite,
                       _read_if_exists=not overwrite).drop(
        'n_cases', 'n_controls', 'heritability')
def load_prescription_data(prescription_data_tsv_path: str, prescription_mapping_tsv_path):
    ht = hl.import_table(prescription_data_tsv_path,
                         types={'eid': hl.tint, 'data_provider': hl.tint}, key='eid')
    mapping_ht = hl.import_table(prescription_mapping_tsv_path, impute=True,
                                 key='Original_Prescription')
    ht = ht.annotate(issue_date=hl.cond(hl.len(ht.issue_date) == 0,
                                        hl.null(hl.tint64),
                                        hl.experimental.strptime(ht.issue_date + ' 00:00:00',
                                                                 '%d/%m/%Y %H:%M:%S', 'GMT')),
                     **mapping_ht[ht.drug_name])
    ht = ht.filter(ht.Generic_Name != '').key_by(
        'eid', 'Generic_Name', 'Drug_Category_and_Indication').collect_by_key()
    ht = ht.annotate(values=hl.sorted(ht.values, key=lambda x: x.issue_date))
    return ht.to_matrix_table(row_key=['eid'], col_key=['Generic_Name'],
                              col_fields=['Drug_Category_and_Indication'])
def test_explode_cols(self):
    mt = hl.utils.range_matrix_table(4, 4)
    mt = mt.annotate_entries(e=mt.row_idx * 10 + mt.col_idx)

    self.assertTrue(mt.annotate_cols(x=[1]).explode_cols('x').drop('x')._same(mt))

    self.assertEqual(mt.annotate_cols(x=hl.empty_array('int')).explode_cols('x').count_cols(), 0)
    self.assertEqual(mt.annotate_cols(x=hl.null('array<int>')).explode_cols('x').count_cols(), 0)
    self.assertEqual(mt.annotate_cols(x=hl.range(0, mt.col_idx)).explode_cols('x').count_cols(), 6)
def test_ndarray_eval():
    data_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    mishapen_data_list1 = [[4], [1, 2, 3]]
    mishapen_data_list2 = [[[1], [2, 3]]]
    mishapen_data_list3 = [[4], [1, 2, 3], 5]

    nd_expr = hl.nd.array(data_list)
    evaled = hl.eval(nd_expr)
    np_equiv = np.array(data_list, dtype=np.int32)
    np_equiv_fortran_style = np.asfortranarray(np_equiv)
    np_equiv_extra_dimension = np_equiv.reshape((3, 1, 3))
    assert np.array_equal(evaled, np_equiv)
    assert evaled.strides == np_equiv.strides

    assert hl.eval(hl.nd.array([[], []])).strides == (8, 8)
    assert np.array_equal(hl.eval(hl.nd.array([])), np.array([]))

    zero_array = np.zeros((10, 10), dtype=np.int64)
    evaled_zero_array = hl.eval(hl.literal(zero_array))
    assert np.array_equal(evaled_zero_array, zero_array)
    assert zero_array.dtype == evaled_zero_array.dtype

    # Testing correct interpretation of numpy strides
    assert np.array_equal(hl.eval(hl.literal(np_equiv_fortran_style)), np_equiv_fortran_style)
    assert np.array_equal(hl.eval(hl.literal(np_equiv_extra_dimension)), np_equiv_extra_dimension)

    # Testing from hail arrays
    assert np.array_equal(hl.eval(hl.nd.array(hl.range(6))), np.arange(6))
    assert np.array_equal(hl.eval(hl.nd.array(hl.int64(4))), np.array(4))

    # Testing from nested hail arrays
    assert np.array_equal(
        hl.eval(hl.nd.array(hl.array([hl.array(x) for x in data_list]))),
        np.arange(9).reshape((3, 3)) + 1)

    # Testing missing data
    assert hl.eval(hl.nd.array(hl.null(hl.tarray(hl.tint32)))) is None

    with pytest.raises(ValueError) as exc:
        hl.nd.array(mishapen_data_list1)
    assert "inner dimensions do not match" in str(exc.value)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.nd.array(hl.array(mishapen_data_list1)))
    assert "inner dimensions do not match" in str(exc.value)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.nd.array(hl.array(mishapen_data_list2)))
    assert "inner dimensions do not match" in str(exc.value)

    with pytest.raises(ValueError) as exc:
        hl.nd.array(mishapen_data_list3)
    assert "inner dimensions do not match" in str(exc.value)
def get_expr_for_worst_transcript_consequence_annotations_struct(
        vep_sorted_transcript_consequences_root, include_coding_annotations=True):
    """Retrieves the top-ranked transcript annotation based on the ranking computed by
    get_expr_for_vep_sorted_transcript_consequences_array(..)

    Args:
        vep_sorted_transcript_consequences_root (ArrayExpression):
        include_coding_annotations (bool):
    """
    transcript_consequences = {
        "biotype": hl.tstr,
        "canonical": hl.tint,
        "category": hl.tstr,
        "cdna_start": hl.tint,
        "cdna_end": hl.tint,
        "codons": hl.tstr,
        "gene_id": hl.tstr,
        "gene_symbol": hl.tstr,
        "hgvs": hl.tstr,
        "hgvsc": hl.tstr,
        "major_consequence": hl.tstr,
        "major_consequence_rank": hl.tint,
        "transcript_id": hl.tstr,
    }

    if include_coding_annotations:
        transcript_consequences.update({
            "amino_acids": hl.tstr,
            "domains": hl.tstr,
            "hgvsp": hl.tstr,
            "lof": hl.tstr,
            "lof_flags": hl.tstr,
            "lof_filter": hl.tstr,
            "lof_info": hl.tstr,
            "polyphen_prediction": hl.tstr,
            "protein_id": hl.tstr,
            "sift_prediction": hl.tstr,
        })

    return hl.cond(
        vep_sorted_transcript_consequences_root.size() == 0,
        hl.struct(**{field: hl.null(field_type)
                     for field, field_type in transcript_consequences.items()}),
        hl.bind(
            lambda worst_transcript_consequence: (
                worst_transcript_consequence
                .annotate(domains=hl.delimit(hl.set(worst_transcript_consequence.domains)))
                .select(*transcript_consequences.keys())),
            vep_sorted_transcript_consequences_root[0],
        ),
    )
def test_ndarray_map():
    a = hl.nd.array([[2, 3, 4], [5, 6, 7]])
    b = hl.map(lambda x: -x, a)
    c = hl.map(lambda x: True, a)

    assert_ndarrays_eq(
        (b, [[-2, -3, -4], [-5, -6, -7]]),
        (c, [[True, True, True], [True, True, True]]))

    assert hl.eval(hl.null(hl.tndarray(hl.tfloat, 1)).map(lambda x: x * 2)) is None
def create_frequency_bins_expr(
        AC: hl.expr.NumericExpression,
        AF: hl.expr.NumericExpression) -> hl.expr.StringExpression:
    """
    Create bins for frequencies in preparation for aggregating QUAL by frequency bin.

    Bins:
        - singleton
        - doubleton
        - 0.00005
        - 0.0001
        - 0.0002
        - 0.0005
        - 0.001
        - 0.002
        - 0.005
        - 0.01
        - 0.02
        - 0.05
        - 0.1
        - 0.2
        - 0.5
        - 1

    NOTE: Frequencies should be frequencies from raw data.
    Used when creating site quality distribution json files.

    :param AC: Field in input that contains the allele count information
    :param AF: Field in input that contains the allele frequency information
    :return: Expression containing bin name
    :rtype: hl.expr.StringExpression
    """
    bin_expr = (
        hl.case()
        .when(AC == 1, "binned_singleton")
        .when(AC == 2, "binned_doubleton")
        .when((AC > 2) & (AF < 0.00005), "binned_0.00005")
        .when((AF >= 0.00005) & (AF < 0.0001), "binned_0.0001")
        .when((AF >= 0.0001) & (AF < 0.0002), "binned_0.0002")
        .when((AF >= 0.0002) & (AF < 0.0005), "binned_0.0005")
        .when((AF >= 0.0005) & (AF < 0.001), "binned_0.001")
        .when((AF >= 0.001) & (AF < 0.002), "binned_0.002")
        .when((AF >= 0.002) & (AF < 0.005), "binned_0.005")
        .when((AF >= 0.005) & (AF < 0.01), "binned_0.01")
        .when((AF >= 0.01) & (AF < 0.02), "binned_0.02")
        .when((AF >= 0.02) & (AF < 0.05), "binned_0.05")
        .when((AF >= 0.05) & (AF < 0.1), "binned_0.1")
        .when((AF >= 0.1) & (AF < 0.2), "binned_0.2")
        .when((AF >= 0.2) & (AF < 0.5), "binned_0.5")
        .when((AF >= 0.5) & (AF <= 1), "binned_1")
        .default(hl.null(hl.tstr))
    )
    return bin_expr
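# Usage sketch for create_frequency_bins_expr above: a tiny hypothetical sites table,
# with illustrative AC/AF values that are not from the source.
ht = hl.Table.parallelize(
    [{'AC': 1, 'AF': 0.00001}, {'AC': 40, 'AF': 0.03}],
    hl.tstruct(AC=hl.tint32, AF=hl.tfloat64))
ht = ht.annotate(freq_bin=create_frequency_bins_expr(AC=ht.AC, AF=ht.AF))
ht.show()  # expect 'binned_singleton' and 'binned_0.05'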
def _genotype_fields(self):
    # Convert the mt genotype entries into num_alt, gq, ab, dp, and sample_id.
    is_called = hl.is_defined(self.mt.GT)
    return {
        'num_alt': hl.cond(is_called, self.mt.GT.n_alt_alleles(), -1),
        'gq': hl.cond(is_called, self.mt.GQ, hl.null(hl.tint)),
        'ab': hl.bind(
            lambda total: hl.cond((is_called) & (total != 0) & (hl.len(self.mt.AD) > 1),
                                  hl.float(self.mt.AD[1] / total),
                                  hl.null(hl.tfloat)),
            hl.sum(self.mt.AD)),
        'dp': hl.cond(is_called, hl.int(hl.min(self.mt.DP, 32000)), hl.null(hl.tfloat)),
        'sample_id': self.mt.s
    }
def parse_first_occurrence(x):
    return (hl.case(missing_false=True)
            # Source of the first code ...
            .when(hl.is_defined(hl.parse_float(x)), hl.float64(x))
            # Setting past and future dates to missing
            .when(hl.literal(pseudo_dates).contains(hl.str(x)), hl.null(hl.tfloat64))
            # Matches DOB
            .when(hl.str(x) == '1902-02-02', 0.0)
            # Within year of birth (taking midpoint between month of birth and EOY)
            .when(hl.str(x) == '1903-03-03',
                  (hl.experimental.strptime('1970-12-31 00:00:00', '%Y-%m-%d %H:%M:%S', 'GMT')
                   - hl.experimental.strptime('1970-' + month + '-15 00:00:00',
                                              '%Y-%m-%d %H:%M:%S', 'GMT')) / 2)
            .default(hl.experimental.strptime(hl.str(x) + ' 00:00:00', '%Y-%m-%d %H:%M:%S', 'GMT')
                     - dob))
def test_ndarray_slice():
    np_rect_prism = np.arange(24).reshape((2, 3, 4))
    rect_prism = hl.nd.array(np_rect_prism)
    np_mat = np.arange(8).reshape((2, 4))
    mat = hl.nd.array(np_mat)
    np_flat = np.arange(20)
    flat = hl.nd.array(np_flat)

    assert_ndarrays_eq(
        (rect_prism[:, :, :], np_rect_prism[:, :, :]),
        (rect_prism[:, :, 1], np_rect_prism[:, :, 1]),
        (rect_prism[0:1, 1:3, 0:2], np_rect_prism[0:1, 1:3, 0:2]),
        (rect_prism[:, :, 1:4:2], np_rect_prism[:, :, 1:4:2]),
        (rect_prism[:, 2, 1:4:2], np_rect_prism[:, 2, 1:4:2]),
        (rect_prism[0, 2, 1:4:2], np_rect_prism[0, 2, 1:4:2]),
        (rect_prism[0, :, 1:4:2] + rect_prism[:, :1, 1:4:2],
         np_rect_prism[0, :, 1:4:2] + np_rect_prism[:, :1, 1:4:2]),
        (rect_prism[0:, :, 1:4:2] + rect_prism[:, :1, 1:4:2],
         np_rect_prism[0:, :, 1:4:2] + np_rect_prism[:, :1, 1:4:2]),
        (mat[0, 1:4:2] + mat[:, 1:4:2], np_mat[0, 1:4:2] + np_mat[:, 1:4:2]),
        (rect_prism[0, 0, -3:-1], np_rect_prism[0, 0, -3:-1]),
        (flat[15:5:-1], np_flat[15:5:-1]),
        (flat[::-1], np_flat[::-1]),
        (flat[::22], np_flat[::22]),
        (flat[::-22], np_flat[::-22]),
        (flat[15:5], np_flat[15:5]),
        (flat[3:12:-1], np_flat[3:12:-1]),
        (flat[12:3:1], np_flat[12:3:1]),
        (mat[::-1, :], np_mat[::-1, :]),
        (flat[4:1:-2], np_flat[4:1:-2]),
        (flat[0:0:1], np_flat[0:0:1]),
        (flat[-4:-1:2], np_flat[-4:-1:2])
    )

    assert hl.eval(flat[hl.null(hl.tint32):4:1]) is None
    assert hl.eval(flat[4:hl.null(hl.tint32)]) is None
    assert hl.eval(flat[4:10:hl.null(hl.tint32)]) is None
    assert hl.eval(rect_prism[:, :, 0:hl.null(hl.tint32):1]) is None
    assert hl.eval(rect_prism[hl.null(hl.tint32), :, :]) is None

    with pytest.raises(FatalError) as exc:
        hl.eval(flat[::0])
    assert "Slice step cannot be zero" in str(exc)
def get_codings():
    """
    Read codings data from Duncan's repo and load into hail Table

    :return: Hail table with codings
    :rtype: Table
    """
    root = f'{tempfile.gettempdir()}/PHESANT'
    if subprocess.check_call(['git', 'clone', 'https://github.com/astheeggeggs/PHESANT.git', root]):
        raise Exception('Could not clone repo')
    hts = []
    coding_dir = f'{root}/WAS/codings'
    for coding_file in os.listdir(f'{coding_dir}'):
        hl.hadoop_copy(f'file://{coding_dir}/{coding_file}', f'{coding_dir}/{coding_file}')
        ht = hl.import_table(f'{coding_dir}/{coding_file}')
        if 'node_id' not in ht.row:
            ht = ht.annotate(node_id=hl.null(hl.tstr),
                             parent_id=hl.null(hl.tstr),
                             selectable=hl.null(hl.tstr))
        ht = ht.annotate(coding_id=hl.int(coding_file.split('.')[0].replace('coding', '')))
        hts.append(ht)

    full_ht = hts[0].union(*hts[1:]).key_by('coding_id', 'coding')
    return full_ht.repartition(10)
def create_all_values():
    return hl.struct(
        f32=hl.float32(3.14),
        i64=hl.int64(-9),
        m=hl.null(hl.tfloat64),
        astruct=hl.struct(a=hl.null(hl.tint32), b=5.5),
        mstruct=hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr)),
        aset=hl.set(['foo', 'bar', 'baz']),
        mset=hl.null(hl.tset(hl.tfloat64)),
        d=hl.dict({hl.array(['a', 'b']): 0.5, hl.array(['x', hl.null(hl.tstr), 'z']): 0.3}),
        md=hl.null(hl.tdict(hl.tint32, hl.tstr)),
        h38=hl.locus('chr22', 33878978, 'GRCh38'),
        ml=hl.null(hl.tlocus('GRCh37')),
        i=hl.interval(hl.locus('1', 999), hl.locus('1', 1001)),
        c=hl.call(0, 1),
        mc=hl.null(hl.tcall),
        t=hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.null(hl.tstr)]),
        mt=hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool))
    )
def finalize_annotated_table_for_seqr_variants(mt: hl.MatrixTable) -> hl.MatrixTable:
    """Given a messily-but-completely annotated Hail MatrixTable of variants,
    return a new MatrixTable with appropriate formatting to export to Elasticsearch
    and consume by Seqr.

    TO-EXTREMELY-DO: Create a app/common Python 3 module with code for SeqrAnnotatedVariant,
    with methods to im/export to/from Hail/Elasticsearch.

    :param mt: A VCF loaded into hail 0.2, VEP has been run, and reference/computed
        fields have been added.
    :type mt: hl.MatrixTable
    :return: A hail matrix table of variants and VEP annotations with proper formatting
        to be consumed by Seqr.
    :rtype: hl.MatrixTable
    """
    mt = mt.annotate_rows(
        sortedTranscriptConsequences=get_expr_for_vep_sorted_transcript_consequences_array(
            vep_root=mt.vep))

    mt = mt.annotate_rows(
        mainTranscript=hl.cond(
            hl.len(mt.sortedTranscriptConsequences) > 0,
            mt.sortedTranscriptConsequences[0],
            hl.null(
                "struct {biotype: str,canonical: int32,cdna_start: int32,cdna_end: int32,codons: str,gene_id: str,gene_symbol: str,hgvsc: str,hgvsp: str,transcript_id: str,amino_acids: str,lof: str,lof_filter: str,lof_flags: str,lof_info: str,polyphen_prediction: str,protein_id: str,protein_start: int32,sift_prediction: str,consequence_terms: array<str>,domains: array<str>,major_consequence: str,category: str,hgvs: str,major_consequence_rank: int32,transcript_rank: int32}")),
        # allele_id=clinvar_mt.index_rows(mt.row_key).vep.id,
        alt=get_expr_for_alt_allele(mt),
        chrom=get_expr_for_contig(mt.locus),
        # clinvar_clinical_significance=clinvar_mt.index_rows(mt.row_key).clinical_significance,
        domains=get_expr_for_vep_protein_domains_set(
            vep_transcript_consequences_root=mt.vep.transcript_consequences),
        geneIds=hl.set(mt.vep.transcript_consequences.map(lambda c: c.gene_id)),
        # gene_id_to_consequence_json=get_expr_for_vep_gene_id_to_consequence_map(
        #     vep_sorted_transcript_consequences_root=mt.sortedTranscriptConsequences,
        #     gene_ids=clinvar_mt.gene_ids
        # ),
        # gold_stars=clinvar_mt.index_entries(mt.row_key, mt.col_key).gold_stars,
        pos=get_expr_for_start_pos(mt),
        ref=get_expr_for_ref_allele(mt),
        # review_status=clinvar_mt.index_rows(mt.locus, mt.alleles).review_status,
        transcript_consequence_terms=get_expr_for_vep_consequence_terms_set(
            vep_transcript_consequences_root=mt.sortedTranscriptConsequences),
        transcript_ids=get_expr_for_vep_transcript_ids_set(
            vep_transcript_consequences_root=mt.sortedTranscriptConsequences),
        transcript_id_to_consequence_json=get_expr_for_vep_transcript_id_to_consequence_map(
            vep_transcript_consequences_root=mt.sortedTranscriptConsequences),
        variant_id=get_expr_for_variant_id(mt),
        xpos=get_expr_for_xpos(mt.locus))

    return mt
def test_trio_matrix_null_keys(self):
    ped = hl.Pedigree.read(resource('triomatrix.fam'))
    ht = hl.import_fam(resource('triomatrix.fam'))
    mt = hl.import_vcf(resource('triomatrix.vcf'))
    mt = mt.annotate_cols(fam=ht[mt.s].fam_id)

    # Make keys all null
    mt = mt.key_cols_by(s=hl.null(hl.tstr))

    tt = hl.trio_matrix(mt, ped, complete_trios=True)
    self.assertEqual(tt.count_cols(), 0)
def mwzj_hts_by_tree(all_hts, temp_dir, globals_for_col_key, debug=False,
                     inner_mode='overwrite', repartition_final: int = None):
    chunk_size = int(len(all_hts) ** 0.5) + 1
    outer_hts = []

    checkpoint_kwargs = {inner_mode: True}
    if repartition_final is not None:
        intervals = get_n_even_intervals(repartition_final)
        checkpoint_kwargs['_intervals'] = intervals

    if debug:
        print(f'Running chunk size {chunk_size}...')
    for i in range(chunk_size):
        if i * chunk_size >= len(all_hts):
            break
        hts = all_hts[i * chunk_size:(i + 1) * chunk_size]
        if debug:
            print(f'Going from {i * chunk_size} to {(i + 1) * chunk_size} ({len(hts)} HTs)...')
        try:
            if isinstance(hts[0], str):
                hts = list(map(lambda x: hl.read_table(x), hts))
            ht = hl.Table.multi_way_zip_join(hts, 'row_field_name', 'global_field_name')
        except:
            if debug:
                print(f'problem in range {i * chunk_size}-{i * chunk_size + chunk_size}')
                _ = [ht.describe() for ht in hts]
            raise
        outer_hts.append(ht.checkpoint(f'{temp_dir}/temp_output_{i}.ht', **checkpoint_kwargs))

    ht = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer', 'global_field_name_outer')
    ht = ht.transmute(inner_row=hl.flatmap(
        lambda i: hl.cond(
            hl.is_missing(ht.row_field_name_outer[i].row_field_name),
            hl.range(0, hl.len(ht.global_field_name_outer[i].global_field_name))
            .map(lambda _: hl.null(ht.row_field_name_outer[i].row_field_name.dtype.element_type)),
            ht.row_field_name_outer[i].row_field_name),
        hl.range(hl.len(ht.global_field_name_outer))))
    ht = ht.transmute_globals(inner_global=hl.flatmap(
        lambda x: x.global_field_name, ht.global_field_name_outer))
    mt = ht._unlocalize_entries('inner_row', 'inner_global', globals_for_col_key)
    return mt
def test_nulls_in_distinct_joins(self):
    # MatrixAnnotateRowsTable uses left distinct join
    mr = hl.utils.range_matrix_table(7, 3, 4)
    matrix1 = mr.key_rows_by(new_key=hl.cond((mr.row_idx == 3) | (mr.row_idx == 5),
                                             hl.null(hl.tint32), mr.row_idx))
    matrix2 = mr.key_rows_by(new_key=hl.cond((mr.row_idx == 4) | (mr.row_idx == 6),
                                             hl.null(hl.tint32), mr.row_idx))

    joined = matrix1.select_rows(idx1=matrix1.row_idx,
                                 idx2=matrix2.rows()[matrix1.new_key].row_idx)

    def row(new_key, idx1, idx2):
        return hl.Struct(new_key=new_key, idx1=idx1, idx2=idx2)

    expected = [row(0, 0, 0),
                row(1, 1, 1),
                row(2, 2, 2),
                row(4, 4, None),
                row(6, 6, None),
                row(None, 3, None),
                row(None, 5, None)]
    self.assertEqual(joined.rows().collect(), expected)

    # union_cols uses inner distinct join
    matrix1 = matrix1.annotate_entries(ridx=matrix1.row_idx, cidx=matrix1.col_idx)
    matrix2 = matrix2.annotate_entries(ridx=matrix2.row_idx, cidx=matrix2.col_idx)
    matrix2 = matrix2.key_cols_by(col_idx=matrix2.col_idx + 3)

    expected = hl.utils.range_matrix_table(3, 6, 1)
    expected = expected.key_rows_by(new_key=expected.row_idx)
    expected = expected.annotate_entries(ridx=expected.row_idx, cidx=expected.col_idx % 3)

    self.assertTrue(matrix1.union_cols(matrix2)._same(expected))
def test_null_joins_2(self):
    tr = hl.utils.range_table(7, 1)
    table1 = tr.key_by(new_key=hl.cond((tr.idx == 3) | (tr.idx == 5),
                                       hl.null(hl.tint32), tr.idx),
                       key2=tr.idx)
    table1 = table1.select(idx1=table1.idx)
    table2 = tr.key_by(new_key=hl.cond((tr.idx == 4) | (tr.idx == 6),
                                       hl.null(hl.tint32), tr.idx),
                       key2=tr.idx)
    table2 = table2.select(idx2=table2.idx)

    left_join = table1.join(table2, 'left')
    right_join = table1.join(table2, 'right')
    inner_join = table1.join(table2, 'inner')
    outer_join = table1.join(table2, 'outer')

    def row(new_key, key2, idx1, idx2):
        return hl.Struct(new_key=new_key, key2=key2, idx1=idx1, idx2=idx2)

    left_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2),
                          row(4, 4, 4, None), row(6, 6, 6, None),
                          row(None, 3, 3, None), row(None, 5, 5, None)]

    right_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2),
                           row(3, 3, None, 3), row(5, 5, None, 5),
                           row(None, 4, None, 4), row(None, 6, None, 6)]

    inner_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2)]

    outer_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2),
                           row(3, 3, None, 3), row(4, 4, 4, None),
                           row(5, 5, None, 5), row(6, 6, 6, None),
                           row(None, 3, 3, None), row(None, 4, None, 4),
                           row(None, 5, 5, None), row(None, 6, None, 6)]

    self.assertEqual(left_join.collect(), left_join_expected)
    self.assertEqual(right_join.collect(), right_join_expected)
    self.assertEqual(inner_join.collect(), inner_join_expected)
    self.assertEqual(outer_join.collect(), outer_join_expected)
def unphase_call_expr(call_expr: hl.expr.CallExpression) -> hl.expr.CallExpression:
    """
    Generate unphased version of a call expression (which can be phased or not)

    :param call_expr: Input call expression
    :return: unphased call expression
    """
    return (
        hl.case()
        .when(call_expr.is_diploid(), hl.call(call_expr[0], call_expr[1], phased=False))
        .when(call_expr.is_haploid(), hl.call(call_expr[0], phased=False))
        .default(hl.null(hl.tcall))
    )
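# Minimal evaluation sketch for unphase_call_expr above (a literal call, illustrative only).
phased = hl.call(0, 1, phased=True)
unphased = hl.eval(unphase_call_expr(phased))  # expect an unphased 0/1 call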
def add_default_plink_fields(mt):
    """Add fields to PLINK"""
    return mt \
        .annotate_rows(rsid=hl.null(hl.tstr)) \
        .annotate_cols(
            fam_id=hl.null(hl.tstr),
            pat_id=hl.null(hl.tstr),
            mat_id=hl.null(hl.tstr),
            is_female=hl.null(hl.tbool),
            is_case=hl.null(hl.tbool)
        )
def make_pheno_manifest():
    mt0 = load_final_sumstats_mt(filter_sumstats=False,
                                 filter_variants=False,
                                 separate_columns_by_pop=False,
                                 annotate_with_nearest_gene=False)

    ht = mt0.cols()
    annotate_dict = {}
    annotate_dict.update({'pops': hl.delimit(ht.pheno_data.pop),
                          'num_pops': hl.len(ht.pheno_data.pop)})

    for field in ['n_cases', 'n_controls', 'heritability', 'lambda_gc']:
        for pop in ['AFR', 'AMR', 'CSA', 'EAS', 'EUR', 'MID']:
            # new field name (only applicable to saige heritability)
            new_field = field if field != 'heritability' else 'saige_heritability'
            idx = ht.pheno_data.pop.index(pop)
            field_expr = ht.pheno_data[field]
            annotate_dict.update({
                f'{new_field}_{pop}': hl.if_else(hl.is_nan(idx),
                                                 hl.null(field_expr[0].dtype),
                                                 field_expr[idx])})

    annotate_dict.update({
        'filename': (ht.trait_type + '-' + ht.phenocode + '-' + ht.pheno_sex
                     + hl.if_else(hl.len(ht.coding) > 0, '-' + ht.coding, '')
                     + hl.if_else(hl.len(ht.modifier) > 0, '-' + ht.modifier, '')
                     ).replace(' ', '_').replace('/', '_') + '.tsv.bgz'})
    ht = ht.annotate(**annotate_dict)

    aws_bucket = 'https://pan-ukb-us-east-1.s3.amazonaws.com/sumstats_release'
    ht = ht.annotate(aws_link=aws_bucket + '/' + ht.filename,
                     aws_link_tabix=aws_bucket + '_tabix/' + ht.filename + '.tbi')

    other_fields_ht = hl.import_table(f'{ldprune_dir}/release/md5_hex_and_file_size.tsv.bgz',
                                      force_bgz=True, key=PHENO_KEY_FIELDS)
    other_fields = ['size_in_bytes', 'size_in_bytes_tabix', 'md5_hex', 'md5_hex_tabix']

    ht = ht.annotate(wget='wget ' + ht.aws_link,
                     wget_tabix='wget ' + ht.aws_link_tabix,
                     **{f: other_fields_ht[ht.key][f] for f in other_fields})

    ht = ht.drop('pheno_data', 'pheno_indices')
    ht.export(f'{bucket}/combined_results/phenotype_manifest.tsv.bgz')
def create_all_values_datasets():
    all_values = hl.struct(
        f32=hl.float32(3.14),
        i64=hl.int64(-9),
        m=hl.null(hl.tfloat64),
        astruct=hl.struct(a=hl.null(hl.tint32), b=5.5),
        mstruct=hl.null(hl.tstruct(x=hl.tint32, y=hl.tstr)),
        aset=hl.set(['foo', 'bar', 'baz']),
        mset=hl.null(hl.tset(hl.tfloat64)),
        d=hl.dict({hl.array(['a', 'b']): 0.5, hl.array(['x', hl.null(hl.tstr), 'z']): 0.3}),
        md=hl.null(hl.tdict(hl.tint32, hl.tstr)),
        h38=hl.locus('chr22', 33878978, 'GRCh38'),
        ml=hl.null(hl.tlocus('GRCh37')),
        i=hl.interval(hl.locus('1', 999), hl.locus('1', 1001)),
        c=hl.call(0, 1),
        mc=hl.null(hl.tcall),
        t=hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.null(hl.tstr)]),
        mt=hl.null(hl.ttuple(hl.tlocus('GRCh37'), hl.tbool)))

    def prefix(s, p):
        return hl.struct(**{p + k: s[k] for k in s})

    all_values_table = (hl.utils.range_table(5, n_partitions=3)
                        .annotate_globals(**prefix(all_values, 'global_'))
                        .annotate(**all_values)
                        .cache())

    all_values_matrix_table = (hl.utils.range_matrix_table(3, 2, n_partitions=2)
                               .annotate_globals(**prefix(all_values, 'global_'))
                               .annotate_rows(**prefix(all_values, 'row_'))
                               .annotate_cols(**prefix(all_values, 'col_'))
                               .annotate_entries(**prefix(all_values, 'entry_'))
                               .cache())

    return all_values_table, all_values_matrix_table
def test_aggregate_ir(self):
    kt = hl.utils.range_table(10).annotate_globals(g1=5)
    r = kt.aggregate(hl.struct(x=agg.sum(kt.idx) + kt.g1,
                               y=agg.filter(kt.idx % 2 != 0, agg.sum(kt.idx + 2)) + kt.g1,
                               z=agg.sum(kt.g1 + kt.idx) + kt.g1))
    self.assertEqual(convert_struct_to_dict(r), {u'x': 50, u'y': 40, u'z': 100})

    r = kt.aggregate(5)
    self.assertEqual(r, 5)

    r = kt.aggregate(hl.null(hl.tint32))
    self.assertEqual(r, None)

    r = kt.aggregate(agg.filter(kt.idx % 2 != 0, agg.sum(kt.idx + 2)) + kt.g1)
    self.assertEqual(r, 40)
def transform_entries(old_entry):
    def with_local_a_index(local_a_index):
        new_pl = hl.or_missing(
            hl.is_defined(old_entry.LPL),
            hl.or_missing(
                hl.is_defined(local_a_index),
                hl.range(0, 3).map(lambda i: hl.min(
                    hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                    .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j),
                                                  local_a_index)
                            == hl.unphased_diploid_gt_index_call(i))
                    .map(lambda idx: old_entry.LPL[idx])))))
        fields = set(old_entry.keys())

        def with_pl(pl):
            new_exprs = {}
            dropped_fields = ['LA']
            if 'LGT' in fields:
                new_exprs['GT'] = hl.downcode(old_entry.LGT,
                                              hl.or_else(local_a_index, hl.len(old_entry.LA)))
                dropped_fields.append('LGT')
            if 'LPGT' in fields:
                new_exprs['PGT'] = hl.downcode(old_entry.LPGT,
                                               hl.or_else(local_a_index, hl.len(old_entry.LA)))
                dropped_fields.append('LPGT')
            if 'LAD' in fields:
                # second entry zeroed for lack of non-ref AD
                new_exprs['AD'] = hl.or_missing(
                    hl.is_defined(old_entry.LAD),
                    [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)])
                dropped_fields.append('LAD')
            if 'LPL' in fields:
                new_exprs['PL'] = pl
                if 'GQ' in fields:
                    new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)
                dropped_fields.append('LPL')

            return hl.cond(hl.len(ds.alleles) == 1,
                           old_entry.annotate(**{f[1:]: old_entry[f]
                                                 for f in ['LGT', 'LPGT', 'LAD', 'LPL']
                                                 if f in fields}).drop(*dropped_fields),
                           old_entry.annotate(**new_exprs).drop(*dropped_fields))

        if 'LPL' in fields:
            return hl.bind(with_pl, new_pl)
        else:
            return with_pl(None)

    lai = hl.fold(lambda accum, elt: hl.cond(old_entry.LA[elt] == ds[new_id].a_index,
                                             elt, accum),
                  hl.null(hl.tint32),
                  hl.range(0, hl.len(old_entry.LA)))
    return hl.bind(with_local_a_index, lai)
def phase_y_nonpar(
    proband_call: hl.expr.CallExpression,
    father_call: hl.expr.CallExpression,
) -> hl.expr.ArrayExpression:
    """
    Returns phased genotype calls in the non-PAR region of Y (requires both father and proband
    to be haploid to return phase)

    :param CallExpression proband_call: Input proband genotype call
    :param CallExpression father_call: Input father genotype call
    :return: Array containing: phased proband call, phased father call, phased mother call
    :rtype: ArrayExpression
    """
    return hl.or_missing(
        proband_call.is_haploid() & father_call.is_haploid() & (father_call[0] == proband_call[0]),
        hl.array([
            hl.call(proband_call[0], phased=True),
            hl.call(father_call[0], phased=True),
            hl.null(hl.tcall)
        ])
    )
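# Sanity-check sketch for phase_y_nonpar above (literal haploid calls, illustrative only).
print(hl.eval(phase_y_nonpar(hl.call(1), hl.call(1))))  # two phased haploid calls plus a missing mother call
print(hl.eval(phase_y_nonpar(hl.call(0), hl.call(1))))  # None: alleles disagree, so no phase is returned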
def test_filter_na(self):
    mt = hl.utils.range_matrix_table(1, 1)

    self.assertEqual(mt.filter_rows(hl.null(hl.tbool)).count_rows(), 0)
    self.assertEqual(mt.filter_cols(hl.null(hl.tbool)).count_cols(), 0)
    self.assertEqual(mt.filter_entries(hl.null(hl.tbool)).entries().count(), 0)
def variant_qc(mt, name='variant_qc') -> MatrixTable:
    """Compute common variant statistics (quality control metrics).

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    >>> dataset_result = hl.variant_qc(dataset)

    Notes
    -----
    This method computes variant statistics from the genotype data, returning
    a new struct field `name` with the following metrics based on the fields
    present in the entry schema.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are
    always computed from `GT`:

    - `AF` (``array<float64>``) -- Calculated allele frequency, one element
      per allele, including the reference. Sums to one. Equivalent to
      `AC` / `AN`.
    - `AC` (``array<int32>``) -- Calculated allele count, one element per
      allele, including the reference. Sums to `AN`.
    - `AN` (``int32``) -- Total number of called alleles.
    - `homozygote_count` (``array<int32>``) -- Number of homozygotes per
      allele. One element per allele, including the reference.
    - `n_called` (``int64``) -- Number of samples with a defined `GT`.
    - `n_not_called` (``int64``) -- Number of samples with a missing `GT`.
    - `call_rate` (``float32``) -- Fraction of samples with a defined `GT`.
      Equivalent to `n_called` / :meth:`.count_cols`.
    - `n_het` (``int64``) -- Number of heterozygous samples.
    - `n_non_ref` (``int64``) -- Number of samples with at least one called
      non-reference allele.
    - `het_freq_hwe` (``float64``) -- Expected frequency of heterozygous
      samples under Hardy-Weinberg equilibrium. See
      :func:`.functions.hardy_weinberg_test` for details.
    - `p_value_hwe` (``float64``) -- p-value from test of Hardy-Weinberg
      equilibrium. See :func:`.functions.hardy_weinberg_test` for details.

    Warning
    -------
    `het_freq_hwe` and `p_value_hwe` are calculated as in
    :func:`.functions.hardy_weinberg_test`, with non-diploid calls
    (``ploidy != 2``) ignored in the counts. As this test is only
    statistically rigorous in the biallelic setting, :func:`.variant_qc`
    sets both fields to missing for multiallelic variants. Consider using
    :func:`~hail.methods.split_multi` to split multi-allelic variants
    beforehand.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
    """
    require_row_key_variant(mt, 'variant_qc')

    exprs = {}
    struct_exprs = []

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    n_samples = mt.count_cols()

    if has_field_of_type('DP', hl.tint32):
        exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError(f"'variant_qc': expect an entry field 'GT' of type 'call'")
    exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    struct_exprs.append(hl.agg.call_stats(mt.GT, mt.alleles))

    # the structure of this function makes it easy to add new nested computations
    def flatten_struct(*struct_exprs):
        flat = {}
        for struct in struct_exprs:
            for k, v in struct.items():
                flat[k] = v
        return hl.struct(
            **flat,
            **exprs,
        )

    mt = mt.annotate_rows(**{name: hl.bind(flatten_struct, *struct_exprs)})

    hwe = hl.hardy_weinberg_test(mt[name].homozygote_count[0],
                                 mt[name].AC[1] - 2 * mt[name].homozygote_count[1],
                                 mt[name].homozygote_count[1])
    hwe = hwe.select(het_freq_hwe=hwe.het_freq_hwe, p_value_hwe=hwe.p_value)
    mt = mt.annotate_rows(**{name: mt[name].annotate(
        n_not_called=n_samples - mt[name].n_called,
        call_rate=mt[name].n_called / n_samples,
        n_het=mt[name].n_called - hl.sum(mt[name].homozygote_count),
        n_non_ref=mt[name].n_called - mt[name].homozygote_count[0],
        **hl.cond(hl.len(mt.alleles) == 2, hwe, hl.null(hwe.dtype)))})

    return mt
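# Hedged follow-up sketch: `dataset` is the MatrixTable from the docstring example above;
# the filter thresholds below are illustrative only, not from the source.
qc = hl.variant_qc(dataset, name='variant_qc')
common_well_called = qc.filter_rows(
    (qc.variant_qc.call_rate >= 0.95) & (qc.variant_qc.AF[1] >= 0.01))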
def get_allele_type(allele_idx):
    return hl.cond(allele_idx > 0,
                   mt[variant_atypes][allele_idx - 1],
                   hl.null(hl.tint32))
def test_locus_windows(self):
    def assert_eq(a, b):
        self.assertTrue(np.array_equal(a, np.array(b)))

    centimorgans = hl.literal([0.1, 1.0, 1.0, 1.5, 1.9])

    mt = hl.balding_nichols_model(1, 5, 5).add_row_index()
    mt = mt.annotate_rows(cm=centimorgans[hl.int32(mt.row_idx)]).cache()

    starts, stops = hl.linalg.utils.locus_windows(mt.locus, 2)
    assert_eq(starts, [0, 0, 0, 1, 2])
    assert_eq(stops, [3, 4, 5, 5, 5])

    starts, stops = hl.linalg.utils.locus_windows(mt.locus, 0.5, coord_expr=mt.cm)
    assert_eq(starts, [0, 1, 1, 1, 3])
    assert_eq(stops, [1, 4, 4, 5, 5])

    starts, stops = hl.linalg.utils.locus_windows(
        mt.locus, 1.0, coord_expr=2 * centimorgans[hl.int32(mt.row_idx)])
    assert_eq(starts, [0, 1, 1, 1, 3])
    assert_eq(stops, [1, 4, 4, 5, 5])

    rows = [{'locus': hl.Locus('1', 1), 'cm': 1.0},
            {'locus': hl.Locus('1', 2), 'cm': 3.0},
            {'locus': hl.Locus('1', 4), 'cm': 4.0},
            {'locus': hl.Locus('2', 1), 'cm': 2.0},
            {'locus': hl.Locus('2', 1), 'cm': 2.0},
            {'locus': hl.Locus('3', 3), 'cm': 5.0}]

    ht = hl.Table.parallelize(rows,
                              hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64),
                              key=['locus'])

    starts, stops = hl.linalg.utils.locus_windows(ht.locus, 1)
    assert_eq(starts, [0, 0, 2, 3, 3, 5])
    assert_eq(stops, [2, 2, 3, 5, 5, 6])

    starts, stops = hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
    assert_eq(starts, [0, 1, 1, 3, 3, 5])
    assert_eq(stops, [1, 3, 3, 5, 5, 6])

    with self.assertRaises(ValueError) as cm:
        hl.linalg.utils.locus_windows(ht.order_by(ht.cm).locus, 1.0)
    self.assertTrue('ascending order' in str(cm.exception))

    with self.assertRaises(ExpressionException) as cm:
        hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=hl.utils.range_table(1).idx)
    self.assertTrue('different source' in str(cm.exception))

    with self.assertRaises(ExpressionException) as cm:
        hl.linalg.utils.locus_windows(hl.locus('1', 1), 1.0)
    self.assertTrue("no source" in str(cm.exception))

    with self.assertRaises(ExpressionException) as cm:
        hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=0.0)
    self.assertTrue("no source" in str(cm.exception))

    ht = ht.annotate_globals(x=hl.locus('1', 1), y=1.0)

    with self.assertRaises(ExpressionException) as cm:
        hl.linalg.utils.locus_windows(ht.x, 1.0)
    self.assertTrue("row-indexed" in str(cm.exception))

    with self.assertRaises(ExpressionException) as cm:
        hl.linalg.utils.locus_windows(ht.locus, 1.0, ht.y)
    self.assertTrue("row-indexed" in str(cm.exception))

    ht = hl.Table.parallelize([{'locus': hl.null(hl.tlocus()), 'cm': 1.0}],
                              hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64),
                              key=['locus'])
    with self.assertRaises(ValueError) as cm:
        hl.linalg.utils.locus_windows(ht.locus, 1.0)
    self.assertTrue("missing value for 'locus_expr'" in str(cm.exception))
    with self.assertRaises(ValueError) as cm:
        hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
    self.assertTrue("missing value for 'locus_expr'" in str(cm.exception))

    ht = hl.Table.parallelize([{'locus': hl.Locus('1', 1), 'cm': hl.null(hl.tfloat64)}],
                              hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64),
                              key=['locus'])
    with self.assertRaises(ValueError) as cm:
        hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm)
    self.assertTrue("missing value for 'coord_expr'" in str(cm.exception))
def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

    - `left_row` / `right_row`: structs of row fields from left and right.
    - `left_col` / `right_col`: structs of column fields from left and right.
    - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`
    """
    if [x.dtype for x in left.row_key.values()] != [x.dtype for x in right.row_key.values()]:
        raise ValueError(f"row key types do not match:\n"
                         f" left: {list(left.row_key.values())}\n"
                         f" right: {list(right.row_key.values())}")

    if [x.dtype for x in left.col_key.values()] != [x.dtype for x in right.col_key.values()]:
        raise ValueError(f"column key types do not match:\n"
                         f" left: {list(left.col_key.values())}\n"
                         f" right: {list(right.col_key.values())}")

    left = left.select_rows(left_row=left.row)
    left_t = left.localize_entries('left_entries', 'left_cols')
    right = right.select_rows(right_row=right.row)
    right_t = right.localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')
    ht = ht.annotate_globals(
        left_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.left_cols.map(lambda x: hl.tuple([x[f] for f in left.col_key])),
                index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])),
        right_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.right_cols.map(lambda x: hl.tuple([x[f] for f in right.col_key])),
                index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])))
    ht = ht.annotate_globals(
        key_indices=hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
        .map(lambda k: hl.struct(k=k,
                                 left_indices=ht.left_keys.get(k),
                                 right_indices=ht.right_keys.get(k)))
        .flatmap(lambda s: hl.case()
                 .when(hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices),
                       hl.range(0, s.left_indices.length()).flatmap(
                           lambda i: hl.range(0, s.right_indices.length()).map(
                               lambda j: hl.struct(k=s.k,
                                                   left_index=s.left_indices[i],
                                                   right_index=s.right_indices[j]))))
                 .when(hl.is_defined(s.left_indices),
                       s.left_indices.map(
                           lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.null('int32'))))
                 .when(hl.is_defined(s.right_indices),
                       s.right_indices.map(
                           lambda elt: hl.struct(k=s.k, left_index=hl.null('int32'), right_index=elt)))
                 .or_error('assertion error')))
    ht = ht.annotate(__entries=ht.key_indices.map(
        lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                            right_entry=ht.right_entries[s.right_index])))
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i] for i, f in enumerate(left.col_key)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))
    ht = ht.drop('left_entries', 'left_cols', 'left_keys',
                 'right_entries', 'right_cols', 'right_keys',
                 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', list(left.col_key))
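# --- Usage sketch for full_outer_join_mt (hypothetical inputs) ---
# Rows or columns present in only one input come through with a missing
# left_/right_ struct on the other side, so the joined matrix table can be
# used to compare two call sets entry by entry.
mt1 = hl.balding_nichols_model(1, 4, 10)
mt2 = hl.balding_nichols_model(1, 4, 10)
joined = full_outer_join_mt(mt1, mt2)
n_discordant = joined.aggregate_entries(
    hl.agg.count_where(joined.left_entry.GT != joined.right_entry.GT))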
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref: hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r: hl.array([ref]).extend(
                            al[1:].map(
                                lambda a: hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at: hl.cond(
                                        (_allele_ints['SNP'] == at)
                                        | (_allele_ints['Insertion'] == at)
                                        | (_allele_ints['Deletion'] == at)
                                        | (_allele_ints['MNP'] == at)
                                        | (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal: hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl: hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles: hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                            hl.range(0, hl.len(row.data)).flatmap(
                                lambda i: hl.cond(
                                    hl.is_missing(row.data[i].__entries),
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                      .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
                                           TopLevelReference('row'),
                                           TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
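# --- Usage sketch (assumptions noted below) ---
# `combine` expects a table whose row field 'data' holds one struct per input
# gvcf table and whose global field 'g' holds the matching globals, i.e. the
# shape produced by a multi-way zip join of the tables returned by
# transform_one. Both the helper name and `transformed_tables` below are
# assumptions for illustration; the join helper has been exposed under
# slightly different names across Hail versions.
ts = hl.Table.multi_way_zip_join(transformed_tables, 'data', 'g')  # assumed helper name
combined = combine(ts)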
def test_annotate(self):
    schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr,
                        f=hl.tarray(hl.tint32))

    rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]},
            {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []},
            {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}]

    kt = hl.Table.parallelize(rows, schema)

    self.assertTrue(kt.annotate()._same(kt))

    result1 = convert_struct_to_dict(kt.annotate(foo=kt.a + 1, foo2=kt.a).take(1)[0])
    self.assertDictEqual(result1, {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3],
                                   'foo': 5, 'foo2': 4})

    result3 = convert_struct_to_dict(kt.annotate(
        x1=kt.f.map(lambda x: x * 2),
        x2=kt.f.map(lambda x: [x, x + 1]).flatmap(lambda x: x),
        x3=hl.min(kt.f),
        x4=hl.max(kt.f),
        x5=hl.sum(kt.f),
        x6=hl.product(kt.f),
        x7=kt.f.length(),
        x8=kt.f.filter(lambda x: x == 3),
        x9=kt.f[1:],
        x10=kt.f[:],
        x11=kt.f[1:2],
        x12=kt.f.map(lambda x: [x, x + 1]),
        x13=kt.f.map(lambda x: [[x, x + 1], [x + 2]]).flatmap(lambda x: x),
        x14=hl.cond(kt.a < kt.b, kt.c, hl.null(hl.tint32)),
        x15={1, 2, 3}
    ).take(1)[0])

    self.assertDictEqual(result3, {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3],
                                   'x1': [2, 4, 6], 'x2': [1, 2, 2, 3, 3, 4],
                                   'x3': 1, 'x4': 3, 'x5': 6, 'x6': 6, 'x7': 3, 'x8': [3],
                                   'x9': [2, 3], 'x10': [1, 2, 3], 'x11': [2],
                                   'x12': [[1, 2], [2, 3], [3, 4]],
                                   'x13': [[1, 2], [3], [2, 3], [4], [3, 4], [5]],
                                   'x14': None, 'x15': set([1, 2, 3])})
    kt.annotate(
        x1=kt.a + 5,
        x2=5 + kt.a,
        x3=kt.a + kt.b,
        x4=kt.a - 5,
        x5=5 - kt.a,
        x6=kt.a - kt.b,
        x7=kt.a * 5,
        x8=5 * kt.a,
        x9=kt.a * kt.b,
        x10=kt.a / 5,
        x11=5 / kt.a,
        x12=kt.a / kt.b,
        x13=-kt.a,
        x14=+kt.a,
        x15=kt.a == kt.b,
        x16=kt.a == 5,
        x17=5 == kt.a,
        x18=kt.a != kt.b,
        x19=kt.a != 5,
        x20=5 != kt.a,
        x21=kt.a > kt.b,
        x22=kt.a > 5,
        x23=5 > kt.a,
        x24=kt.a >= kt.b,
        x25=kt.a >= 5,
        x26=5 >= kt.a,
        x27=kt.a < kt.b,
        x28=kt.a < 5,
        x29=5 < kt.a,
        x30=kt.a <= kt.b,
        x31=kt.a <= 5,
        x32=5 <= kt.a,
        x33=(kt.a == 0) & (kt.b == 5),
        x34=(kt.a == 0) | (kt.b == 5),
        x35=False,
        x36=True
    )
def test_filter_missing(self):
    ht = hl.utils.range_table(1, 1)
    self.assertEqual(ht.filter(hl.null(hl.tbool)).count(), 0)
def transform_one(mt, vardp_outlier=100_000) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    There is a strong assumption that this function will be called on a matrix
    table with one column.
    """
    mt = localize(mt)
    if mt.row.dtype not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e: hl.struct(
                            DP=e.DP,
                            END=row.info.END,
                            GQ=e.GQ,
                            LA=hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                            LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                            LGT=e.GT,
                            LPGT=e.PGT,
                            LPL=hl.cond(has_non_ref,
                                        hl.cond(alleles_len > 2,
                                                e.PL[:-alleles_len],
                                                hl.null(e.PL.dtype)),
                                        hl.cond(alleles_len > 1,
                                                e.PL,
                                                hl.null(e.PL.dtype))),
                            MIN_DP=e.MIN_DP,
                            PID=e.PID,
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                            SB=e.SB,
                            gvcf_info=hl.case()
                                .when(hl.is_missing(row.info.END),
                                      hl.struct(
                                          ClippingRankSum=row.info.ClippingRankSum,
                                          BaseQRankSum=row.info.BaseQRankSum,
                                          MQ=row.info.MQ,
                                          MQRankSum=row.info.MQRankSum,
                                          MQ_DP=row.info.MQ_DP,
                                          QUALapprox=row.info.QUALapprox,
                                          RAW_MQ=row.info.RAW_MQ,
                                          ReadPosRankSum=row.info.ReadPosRankSum,
                                          VarDP=hl.cond(row.info.VarDP > vardp_outlier,
                                                        row.info.DP,
                                                        row.info.VarDP)))
                                .or_missing()))),
            ),
            mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(TableMapRows(mt._tir, Apply(transform_row._name, TopLevelReference('row'))))
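# --- Usage sketch for transform_one (hypothetical path) ---
# A single-sample GVCF is imported with array_elements_required=False, as the
# docstring requires, and then reshaped into the localized form that the
# combiner consumes.
mt = hl.import_vcf('gs://my-bucket/sample1.g.vcf.bgz',  # hypothetical path
                   reference_genome='GRCh38',
                   force_bgz=True,
                   array_elements_required=False)
ts = transform_one(mt)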