def test_window_by_locus(self):
    """Check the ``prev_rows`` / ``prev_entries`` fields produced by ``hl.window_by_locus``."""
    # 100 rows x 2 columns; row i is keyed at locus 1:(i + 1).
    mt = hl.utils.range_matrix_table(100, 2, n_partitions=10)
    mt = mt.annotate_rows(locus=hl.locus('1', mt.row_idx + 1))
    mt = mt.key_rows_by('locus')
    # Copy the indices into the entries so they can be checked inside prev_entries.
    mt = mt.annotate_entries(e_row_idx=mt.row_idx, e_col_idx=mt.col_idx)
    mt = hl.window_by_locus(mt, 5).cache()

    self.assertEqual(mt.count_rows(), 100)

    rows = mt.rows()

    # Every row from index 5 onwards must carry exactly 5 previous rows.
    window_is_full = (rows.row_idx < 5) | (rows.prev_rows.length() == 5)
    self.assertTrue(rows.all(window_is_full))

    # The k-th previous row must be the row at index (row_idx - 1 - k).
    def prev_row_matches(indexed):
        return (rows.row_idx - 1 - indexed[0]) == indexed[1].row_idx

    indexed_prev_rows = hl.zip_with_index(rows.prev_rows)
    self.assertTrue(rows.all(hl.all(prev_row_matches, indexed_prev_rows)))

    entries = mt.entries()

    # Previous entries must come from the same column as the current entry.
    def same_column(prev_entry):
        return prev_entry.e_col_idx == entries.col_idx

    self.assertTrue(entries.all(hl.all(same_column, entries.prev_entries)))

    # The k-th previous entry must come from the row at index (row_idx - 1 - k).
    def prev_entry_matches(indexed):
        return entries.row_idx - 1 - indexed[0] == indexed[1].e_row_idx

    indexed_prev_entries = hl.zip_with_index(entries.prev_entries)
    self.assertTrue(entries.all(hl.all(prev_entry_matches, indexed_prev_entries)))
def phase_haploid_proband_x_nonpar(
        proband_call: hl.expr.CallExpression,
        father_call: hl.expr.CallExpression,
        mother_call: hl.expr.CallExpression) -> hl.expr.ArrayExpression:
    """
    Phases a trio when the proband is haploid in the non-PAR region of X.

    The result is missing when neither maternal allele equals the proband's allele.

    :param CallExpression proband_call: Input proband genotype call
    :param CallExpression father_call: Input father genotype call
    :param CallExpression mother_call: Input mother genotype call
    :return: Array containing: phased proband call, phased father call, phased mother call
    :rtype: ArrayExpression
    """
    # Find the first maternal allele (with its index) equal to the proband's allele.
    maternal_alleles = hl.array([mother_call[0], mother_call[1]])
    transmitted_allele = hl.zip_with_index(maternal_alleles).find(
        lambda indexed: indexed[1] == proband_call[0])

    phased_proband = hl.call(proband_call[0], phased=True)
    # The father's call is only reported when it is haploid.
    phased_father = hl.or_missing(
        father_call.is_haploid(),
        hl.call(father_call[0], phased=True))
    phased_mother = phase_parent_call(mother_call, transmitted_allele[0])

    return hl.or_missing(
        hl.is_defined(transmitted_allele),
        hl.array([phased_proband, phased_father, phased_mother]))
def phase_diploid_proband(
        locus: hl.expr.LocusExpression,
        alleles: hl.expr.ArrayExpression,
        proband_call: hl.expr.CallExpression,
        father_call: hl.expr.CallExpression,
        mother_call: hl.expr.CallExpression
) -> hl.expr.ArrayExpression:
    """
    Phases a trio when the proband is diploid (autosomes, PAR regions of sex
    chromosomes or non-PAR regions of a female proband).

    The result is missing unless exactly one (father allele, mother allele)
    combination adds up to the proband's one-hot allele vector.

    :param LocusExpression locus: Locus in the trio MatrixTable
    :param ArrayExpression alleles: Alleles in the trio MatrixTable
    :param CallExpression proband_call: Input proband genotype call
    :param CallExpression father_call: Input father genotype call
    :param CallExpression mother_call: Input mother genotype call
    :return: Array containing: phased proband call, phased father call, phased mother call
    :rtype: ArrayExpression
    """
    proband_v = proband_call.one_hot_alleles(alleles)

    # In non-PAR regions of X/Y the father only contributes when his call is haploid.
    in_nonpar = locus.in_x_nonpar() | locus.in_y_nonpar()
    haploid_father_v = hl.or_missing(
        father_call.is_haploid(),
        hl.array([father_call.one_hot_alleles(alleles)]))
    father_v = hl.cond(
        in_nonpar,
        haploid_father_v,
        call_to_one_hot_alleles_array(father_call, alleles))
    mother_v = call_to_one_hot_alleles_array(mother_call, alleles)

    def compatible_with_father_allele(f):
        # Maternal alleles that, added to this paternal allele, give the proband's vector.
        return (hl.zip_with_index(mother_v)
                .filter(lambda m: m[1] + f[1] == proband_v)
                .map(lambda m: hl.struct(m=m[0], f=f[0])))

    combinations = hl.flatmap(compatible_with_father_allele, hl.zip_with_index(father_v))
    chosen = combinations[0]

    phased_proband = hl.call(father_call[chosen.f], mother_call[chosen.m], phased=True)
    phased_father = hl.cond(
        father_call.is_haploid(),
        hl.call(father_call[0], phased=True),
        phase_parent_call(father_call, chosen.f))
    phased_mother = phase_parent_call(mother_call, chosen.m)

    is_unambiguous = hl.is_defined(combinations) & (hl.len(combinations) == 1)
    return hl.or_missing(
        is_unambiguous,
        hl.array([phased_proband, phased_father, phased_mother]))
def separate_results_mt_by_pop(mt):
    """Explode the ``pheno_data`` column array so that each element gets its own column.

    After the explode, each entry's ``summary_stats`` array is replaced by the
    single element at the same index as the column's ``pheno_data`` element.

    :param mt: MatrixTable with an array column field ``pheno_data`` and an
        array entry field ``summary_stats``.
    :return: MatrixTable with one column per original ``pheno_data`` element.
    """
    # Pair every pheno_data element with its index, then make one column per pair.
    indexed_pheno_data = hl.zip_with_index(mt.pheno_data)
    mt = mt.annotate_cols(pheno_data=indexed_pheno_data)
    mt = mt.explode_cols('pheno_data')

    # Split the (index, element) pair back into two column fields.
    element_index = mt.pheno_data[0]
    element = mt.pheno_data[1]
    mt = mt.annotate_cols(pop_index=element_index, pheno_data=element)

    # Keep only the summary statistics matching this column's index.
    mt = mt.annotate_entries(summary_stats=mt.summary_stats[mt.pop_index])
    return mt.drop('pop_index')
def test_window_by_locus(self):
    """Verify that ``hl.window_by_locus`` with a window of 5 yields the expected neighbours."""
    window = hl.utils.range_matrix_table(100, 2, n_partitions=10)
    window = window.annotate_rows(locus=hl.locus('1', window.row_idx + 1))
    window = window.key_rows_by('locus')
    window = window.annotate_entries(e_row_idx=window.row_idx, e_col_idx=window.col_idx)
    window = hl.window_by_locus(window, 5).cache()

    self.assertEqual(window.count_rows(), 100)

    # Row-level checks: window size, then identity of each previous row.
    rows = window.rows()
    has_five_prev = (rows.row_idx < 5) | (rows.prev_rows.length() == 5)
    self.assertTrue(rows.all(has_five_prev))

    rows_in_order = hl.all(
        lambda pair: (rows.row_idx - 1 - pair[0]) == pair[1].row_idx,
        hl.zip_with_index(rows.prev_rows))
    self.assertTrue(rows.all(rows_in_order))

    # Entry-level checks: same column, then identity of each previous entry's row.
    entries = window.entries()
    columns_agree = hl.all(
        lambda prev: prev.e_col_idx == entries.col_idx,
        entries.prev_entries)
    self.assertTrue(entries.all(columns_agree))

    entries_in_order = hl.all(
        lambda pair: entries.row_idx - 1 - pair[0] == pair[1].e_row_idx,
        hl.zip_with_index(entries.prev_entries))
    self.assertTrue(entries.all(entries_in_order))
def sum_mcnv_ac_or_af(alts, values):
    """Sum ``values`` over all alt alleles, skipping the one whose alt is ``"<CN=2>"``.

    If no alt equals ``"<CN=2>"``, every element of ``values`` is summed.

    :param alts: Array expression of alt allele strings.
    :param values: Array expression indexed like ``alts``.
    :return: Expression holding the sum.
    """
    def add_up(values_to_sum):
        return values_to_sum.fold(lambda acc, n: acc + n, 0)

    def sum_without(cn2_index):
        # Drop the element at cn2_index when it exists; otherwise keep everything.
        without_cn2 = values[0:cn2_index].extend(values[cn2_index + 1:])
        selected = hl.if_else(hl.is_defined(cn2_index), without_cn2, values)
        return hl.bind(add_up, selected)

    # Index of the "<CN=2>" alt, missing if there is none.
    cn2_position = hl.zip_with_index(alts).find(lambda t: t[1] == "<CN=2>")[0]
    return hl.bind(sum_without, cn2_position)
def explode_lambda_ht(ht, by='ac'):
    """Explode a table into one row per element of its ``{by}_cutoffs`` array.

    Only the ``sumstats_qc`` fields whose name contains ``_{by}`` are kept. After
    the explode, each of them is replaced by its element at the cutoff's index
    and promoted to a top-level field, and the cutoff value is stored under ``by``.

    :param ht: Table with a struct field ``sumstats_qc`` and an array ``{by}_cutoffs``.
    :param by: Name used to select fields and cutoffs (default ``'ac'``).
    :return: Exploded table.
    """
    marker = f'_{by}'
    selected_fields = [name for name in ht.sumstats_qc.keys() if marker in name]
    ac_ht = ht.annotate(sumstats_qc=ht.sumstats_qc.select(*selected_fields))

    # One row per (index, cutoff) pair.
    ac_ht = ac_ht.annotate(index_ac=hl.zip_with_index(ac_ht[f'{by}_cutoffs']))
    ac_ht = ac_ht.explode('index_ac')

    cutoff_index = ac_ht.index_ac[0]
    cutoff_value = ac_ht.index_ac[1]
    per_cutoff_stats = {
        name: ac_ht.sumstats_qc[name][cutoff_index] for name in ac_ht.sumstats_qc
    }
    return ac_ht.transmute(**{by: cutoff_value}, **per_cutoff_stats)
def separate_results_mt_by_pop(mt, col_field='pheno_data', entry_field='summary_stats', skip_drop: bool = False):
    """Explode an array column field so that each element gets its own column.

    The entry array ``entry_field`` is reduced to the element at the same index
    as the column's ``col_field`` element.

    :param mt: MatrixTable with array column field ``col_field`` and array entry
        field ``entry_field``.
    :param col_field: Name of the column array field to explode.
    :param entry_field: Name of the entry array field to index.
    :param skip_drop: If True, keep the helper column field ``pop_index``.
    :return: MatrixTable with one column per original ``col_field`` element.
    """
    # Pair every element with its index, then make one column per pair.
    indexed = hl.zip_with_index(mt[col_field])
    mt = mt.annotate_cols(col_array=indexed)
    mt = mt.explode_cols('col_array')

    # Split the pair into the index and the element (stored back under col_field).
    mt = mt.transmute_cols(pop_index=mt.col_array[0],
                           **{col_field: mt.col_array[1]})
    mt = mt.annotate_entries(**{entry_field: mt[entry_field][mt.pop_index]})

    if skip_drop:
        return mt
    return mt.drop('pop_index')
def load_final_sumstats_mt(filter_phenos: bool = True, filter_variants: bool = True,
                           filter_sumstats: bool = True, separate_columns_by_pop: bool = True,
                           annotate_with_nearest_gene: bool = True):
    """Load the full summary-statistics MatrixTable with QC annotations and optional filters.

    :param filter_phenos: Keep only ``pheno_data`` elements (and the matching
        ``summary_stats`` elements) whose ``lambda_gc`` passes ``filter_lambda_gc``,
        then drop columns left with no ``pheno_data``.
    :param filter_variants: Keep only rows where ``high_quality`` is true.
    :param filter_sumstats: Set ``low_confidence`` summary stats to missing and
        filter entries where every ``Pvalue`` is missing.
    :param separate_columns_by_pop: Pass the result through ``separate_results_mt_by_pop``.
    :param annotate_with_nearest_gene: Pass the result through ``annotate_nearest_gene``.
    :return: The annotated (and optionally filtered) MatrixTable.
    """
    mt = hl.read_matrix_table(get_variant_results_path('full', 'mt'))

    # Join variant-level QC onto rows and phenotype-level QC onto columns.
    variant_qc = hl.read_table(get_variant_results_qc_path())
    mt = mt.annotate_rows(**variant_qc[mt.row_key])
    lambdas_path = get_analysis_data_path('lambda', 'lambdas', 'full', 'ht')
    pheno_qc = hl.read_table(lambdas_path)
    mt = mt.annotate_cols(**pheno_qc[mt.col_key])

    if filter_phenos:
        def passes_lambda_gc(indexed):
            return filter_lambda_gc(indexed[1].lambda_gc)

        passing = hl.zip_with_index(mt.pheno_data).filter(passes_lambda_gc)
        mt = mt.annotate_cols(pheno_indices=passing.map(lambda indexed: indexed[0]),
                              pheno_data=passing.map(lambda indexed: indexed[1]))

        # Keep the summary stats at the indices that survived the phenotype filter.
        indexed_stats = hl.zip_with_index(mt.summary_stats)
        surviving = indexed_stats.filter(lambda indexed: mt.pheno_indices.contains(indexed[0]))
        mt = mt.annotate_entries(summary_stats=surviving.map(lambda indexed: indexed[1]))
        mt = mt.filter_cols(hl.len(mt.pheno_data) > 0)

    if filter_sumstats:
        def drop_low_confidence(stat):
            return hl.or_missing(~stat.low_confidence, stat)

        mt = mt.annotate_entries(summary_stats=mt.summary_stats.map(drop_low_confidence))
        all_missing = mt.summary_stats.all(lambda stat: hl.is_missing(stat.Pvalue))
        mt = mt.filter_entries(~all_missing)

    if filter_variants:
        mt = mt.filter_rows(mt.high_quality)

    if annotate_with_nearest_gene:
        mt = annotate_nearest_gene(mt)

    if separate_columns_by_pop:
        mt = separate_results_mt_by_pop(mt)

    return mt
def total_ac_or_af(variant, field):
    """Return the total AC/AF value for a variant.

    For variants whose ``type`` is ``"MCNV"``, the values in ``field`` are summed
    over all alt alleles except the one equal to ``"<CN=2>"`` (all values are
    summed when no such alt exists). For every other type, ``field[0]`` is returned.

    Uses ``hl.if_else`` rather than the deprecated ``hl.cond`` alias.

    :param variant: Struct expression with fields ``type`` and ``alts``.
    :param field: Array expression indexed like ``variant.alts``.
    :return: Expression holding the total.
    """
    def add_up(values_to_sum):
        return values_to_sum.fold(lambda acc, n: acc + n, 0)

    def sum_without(cn2_index):
        # Drop the element at cn2_index when it exists; otherwise keep everything.
        return hl.bind(
            add_up,
            hl.if_else(
                hl.is_defined(cn2_index),
                field[0:cn2_index].extend(field[cn2_index + 1:]),
                field,
            ),
        )

    # Index of the "<CN=2>" alt, missing if there is none.
    cn2_position = hl.zip_with_index(variant.alts).find(lambda t: t[1] == "<CN=2>")[0]

    return hl.if_else(
        variant.type == "MCNV",
        hl.bind(sum_without, cn2_position),
        field[0],
    )
def phase_haploid_proband_x_nonpar(
        proband_call: hl.expr.CallExpression,
        father_call: hl.expr.CallExpression,
        mother_call: hl.expr.CallExpression
) -> hl.expr.ArrayExpression:
    """
    Returns phased calls for a trio whose proband is haploid in the non-PAR region of X.

    Missing is returned when the proband's allele matches neither maternal allele.

    :param CallExpression proband_call: Input proband genotype call
    :param CallExpression father_call: Input father genotype call
    :param CallExpression mother_call: Input mother genotype call
    :return: Array containing: phased proband call, phased father call, phased mother call
    :rtype: ArrayExpression
    """
    def is_proband_allele(indexed_allele):
        return indexed_allele[1] == proband_call[0]

    # (index, allele) of the first maternal allele carried by the proband.
    indexed_maternal = hl.zip_with_index(hl.array([mother_call[0], mother_call[1]]))
    transmitted_allele = indexed_maternal.find(is_proband_allele)

    phased_trio = hl.array([
        hl.call(proband_call[0], phased=True),
        # The father's call is only reported when it is haploid.
        hl.or_missing(father_call.is_haploid(), hl.call(father_call[0], phased=True)),
        phase_parent_call(mother_call, transmitted_allele[0]),
    ])
    return hl.or_missing(hl.is_defined(transmitted_allele), phased_trio)
def explode_trio_matrix(tm: hl.MatrixTable, col_keys: List[str] = ['s'], keep_trio_cols: bool = True, keep_trio_entries: bool = False) -> hl.MatrixTable:
    """Splits a trio MatrixTable back into a sample MatrixTable.

    Example
    -------
    >>> # Create a trio matrix from a sample matrix
    >>> pedigree = hl.Pedigree.read('data/case_control_study.fam')
    >>> trio_dataset = hl.trio_matrix(dataset, pedigree, complete_trios=True)

    >>> # Explode trio matrix back into a sample matrix
    >>> exploded_trio_dataset = explode_trio_matrix(trio_dataset)

    Notes
    -----
    The resulting MatrixTable column schema is the same as the
    `proband`/`father`/`mother` schema, and the resulting entry schema is the
    same as the `proband_entry`/`father_entry`/`mother_entry` schema.
    If `keep_trio_cols` is set, an additional `source_trio` column field is
    added with the trio column data.
    If `keep_trio_entries` is set, an additional `source_trio_entry` entry
    field is added with the trio entry data.

    Note
    ----
    This assumes that the input MatrixTable is a trio MatrixTable (similar to
    the result of :meth:`.methods.trio_matrix`).
    Its entry schema has to contain `proband_entry`, `father_entry` and
    `mother_entry`, all with the same type.
    Its column schema has to contain `proband`, `father` and `mother`, all with
    the same type.

    Parameters
    ----------
    tm : :class:`.MatrixTable`
        Trio MatrixTable (entries have to be a Struct with `proband_entry`,
        `mother_entry` and `father_entry` present)
    col_keys : :obj:`list` of str
        Column key(s) for the resulting sample MatrixTable; when empty the
        result is left without a column key
    keep_trio_cols : bool
        Whether to add a `source_trio` column field with the trio column data
        (default `True`)
    keep_trio_entries : bool
        Whether to add a `source_trio_entry` entry field with the trio entry
        data (default `False`)

    Returns
    -------
    :class:`.MatrixTable`
        Sample MatrixTable
    """
    # Pack the three per-member entries into one array, in proband/father/mother order.
    trio_entries = hl.array([tm.proband_entry, tm.father_entry, tm.mother_entry])
    entry_fields = {'__trio_entries': trio_entries}
    if keep_trio_entries:
        entry_fields['source_trio_entry'] = hl.struct(**tm.entry)
    tm = tm.select_entries(**entry_fields)

    # Unkey the columns, then pack the members in the same order with their index.
    tm = tm.key_cols_by()
    trio_members = hl.zip_with_index(hl.array([tm.proband, tm.father, tm.mother]))
    col_fields = {'__trio_members': trio_members}
    if keep_trio_cols:
        col_fields['source_trio'] = hl.struct(**tm.col)
    tm = tm.select_cols(**col_fields)

    # One column per trio member; each picks its own entry via the member index.
    mt = tm.explode_cols(tm['__trio_members'])
    member_index = mt['__trio_members'][0]
    mt = mt.transmute_entries(**mt['__trio_entries'][member_index])

    mt = mt.key_cols_by()
    mt = mt.transmute_cols(**mt['__trio_members'][1])

    if not col_keys:
        return mt
    return mt.key_cols_by(*col_keys)
def explode_trio_matrix(tm: hl.MatrixTable, col_keys: List[str] = ['s']) -> hl.MatrixTable:
    """Splits a trio MatrixTable back into a sample MatrixTable.

    Example
    -------
    >>> # Create a trio matrix from a sample matrix
    >>> pedigree = hl.Pedigree.read('data/case_control_study.fam')
    >>> trio_dataset = hl.trio_matrix(dataset, pedigree, complete_trios=True)

    >>> # Explode trio matrix back into a sample matrix
    >>> exploded_trio_dataset = explode_trio_matrix(trio_dataset)

    Notes
    -----
    This assumes that the input MatrixTable is a trio MatrixTable (similar to
    the result of :meth:`.methods.trio_matrix`).
    In particular, it should have the following entry schema:

    - proband_entry
    - father_entry
    - mother_entry

    And the following column schema:

    - proband
    - father
    - mother

    Note
    ----
    Only the contents of `proband_entry`, `father_entry` and `mother_entry` are
    kept as entry fields; all other entry fields are dropped.
    Only the contents of `proband`, `father` and `mother` are kept as column
    fields.

    Parameters
    ----------
    tm : :class:`.MatrixTable`
        Trio MatrixTable (entries have to be a Struct with `proband_entry`,
        `mother_entry` and `father_entry` present)
    col_keys : :obj:`list` of str
        Column key(s) for the resulting sample MatrixTable; when empty the
        result is left without a column key

    Returns
    -------
    :class:`.MatrixTable`
        Sample MatrixTable
    """
    # Pack the three per-member entries into one array, in proband/father/mother order.
    trio_entries = hl.array([tm.proband_entry, tm.father_entry, tm.mother_entry])
    tm = tm.select_entries(**{'__trio_entries': trio_entries})

    # Pack the members in the same order, each paired with its index.
    trio_members = hl.zip_with_index(hl.array([tm.proband, tm.father, tm.mother]))
    tm = tm.select_cols(**{'__trio_members': trio_members})

    # One column per trio member; each picks its own entry via the member index.
    mt = tm.explode_cols(tm['__trio_members'])
    member_index = mt['__trio_members'][0]
    mt = mt.select_entries(**mt['__trio_entries'][member_index])

    mt = mt.key_cols_by()
    mt = mt.select_cols(**mt['__trio_members'][1])

    if not col_keys:
        return mt
    return mt.key_cols_by(*col_keys)
def combine_pheno_files_multi_sex(pheno_file_dict: dict, cov_ht: hl.Table,
                                  truncated_codes_only: bool = True,
                                  custom_data_categorical: bool = True):
    """Combine per-data-type phenotype MatrixTables into one MatrixTable with per-sex entries.

    For every ``(data_type, mt)`` pair, the row fields are replaced by the fields
    of ``cov_ht`` looked up by row key, the column fields are reduced to case
    counts plus ``description``, ``description_more``, ``coding_description`` and
    ``category``, and the entries are reduced to per-sex fields. Most branches
    also re-key the columns by ``trait_type``, ``phenocode``, ``pheno_sex``,
    ``coding`` and ``modifier``; the ``'custom'`` and ``'icd_first_occurrence'``
    branches leave the column key untouched. Each result is checkpointed to a
    temporary path and unioned column-wise onto the running result.

    :param pheno_file_dict: Mapping from data type name to its MatrixTable.
    :param cov_ht: Table indexed by each MatrixTable's row key; its fields become the row fields.
    :param truncated_codes_only: In the ``icd_code`` branch, keep only columns whose
        ``icd_code`` has length 3 and collapse columns sharing a key to a single one.
    :param custom_data_categorical: Not referenced in the function body.
    :return: MatrixTable whose entry fields are ``both_sexes``, ``females`` and ``males``.
    """
    full_mt: hl.MatrixTable = None
    sexes = ('both_sexes', 'females', 'males')
    for data_type, mt in pheno_file_dict.items():
        # Replace the row fields with the covariates for that row key.
        mt = mt.select_rows(**cov_ht[mt.row_key])
        print(data_type)
        if data_type == 'phecode':
            mt = mt.key_cols_by(trait_type=data_type, phenocode=mt.phecode, pheno_sex=mt.phecode_sex,
                                coding=NULL_STR_KEY, modifier=NULL_STR_KEY)
            mt = mt.select_cols(**compute_cases_binary(mt.case_control, mt.sex),
                                description=mt.phecode_description,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.phecode_group)
            mt = mt.select_entries(**format_entries(mt.case_control, mt.sex))
        elif data_type == 'prescriptions':
            def format_prescription_name(pheno):
                return pheno.replace(',', '|').replace('/', '_')

            # An entry is True when it has at least one value; missing becomes False.
            mt = mt.select_entries(
                value=hl.or_else(hl.len(mt.values) > 0, False))
            # mt2: one column per drug category, True if any drug in the category is True.
            mt2 = mt.group_cols_by(
                trait_type=data_type,
                phenocode=format_prescription_name(
                    mt.Drug_Category_and_Indication),
                pheno_sex='both_sexes',
                coding=NULL_STR_KEY,
                modifier=NULL_STR_KEY,
            ).aggregate(value=hl.agg.any(mt.value)).select_cols(
                category=NULL_STR)
            # mt: one column per generic drug name; the category columns are appended below.
            mt = mt.key_cols_by(trait_type=data_type,
                                phenocode=format_prescription_name(
                                    mt.Generic_Name),
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=NULL_STR_KEY)
            mt = mt.select_cols(category=mt.Drug_Category_and_Indication)
            mt = mt.union_cols(mt2)
            mt = mt.select_cols(**compute_cases_binary(mt.value, mt.sex),
                                description=NULL_STR,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.category)
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
        elif data_type == 'custom':
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            # Categorical traits count entries equal to 1.0; others count defined entries.
            mt = mt.select_cols(**{
                f'n_cases_{sex}': hl.agg.count_where(
                    hl.cond(mt.trait_type == 'categorical', mt[sex] == 1.0,
                            hl.is_defined(mt[sex])))
                for sex in sexes
            },
                                description=NULL_STR,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.category)
        elif data_type == 'additional':
            mt = mt.key_cols_by(trait_type='continuous',
                                phenocode=mt.pheno,
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=mt.coding)
            mt = mt.select_cols(
                **{
                    f'n_cases_{sex}':
                    hl.agg.count_where(hl.is_defined(mt[sex]))
                    for sex in sexes
                },
                description=hl.coalesce(
                    *[mt[f'{sex}_pheno'].meaning for sex in sexes]),
                description_more=hl.coalesce(
                    *[mt[f'{sex}_pheno'].description for sex in sexes]),
                coding_description=NULL_STR,
                category=NULL_STR)
        elif data_type in ('categorical', 'continuous'):
            # The source `coding` field goes into the `coding` key for categorical
            # traits and into the `modifier` key for continuous ones.
            mt = mt.key_cols_by(trait_type=data_type,
                                phenocode=hl.str(mt.pheno),
                                pheno_sex='both_sexes',
                                coding=mt.coding if data_type == 'categorical' else NULL_STR_KEY,
                                modifier=NULL_STR_KEY if data_type == 'categorical' else mt.coding)

            def check_func(x):
                return x if data_type == 'categorical' else hl.is_defined(x)

            mt = mt.select_cols(
                **{
                    f'n_cases_{sex}':
                    hl.agg.count_where(check_func(mt[sex]))
                    for sex in sexes
                },
                description=hl.coalesce(
                    *[mt[f'{sex}_pheno'].Field for sex in sexes]),
                description_more=hl.coalesce(
                    *[mt[f'{sex}_pheno'].Notes for sex in sexes]),
                coding_description=hl.coalesce(
                    *[mt[f'{sex}_pheno'].meaning for sex in sexes])
                if data_type == 'categorical' else NULL_STR,
                category=hl.coalesce(
                    *[mt[f'{sex}_pheno'].Path for sex in sexes]))
            mt = mt.select_entries(
                **{sex: hl.float64(mt[sex]) for sex in sexes})
        elif 'icd_code' in list(mt.col_key):
            icd_version = mt.icd_version if 'icd_version' in list(
                mt.col) else ''
            mt = mt.key_cols_by(trait_type=icd_version,
                                phenocode=mt.icd_code,
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=NULL_STR_KEY)
            if truncated_codes_only:
                mt = mt.filter_cols(hl.len(mt.icd_code) == 3)
                mt = mt.collect_cols_by_key()
                # Among columns sharing a key, keep the single one if there is only
                # one, otherwise the first whose `truncated` flag is true.
                mt = mt.annotate_cols(keep=hl.if_else(
                    hl.len(mt.truncated) == 1, 0,
                    hl.zip_with_index(mt.truncated).filter(lambda x: x[1]).map(
                        lambda x: x[0])[0]))
                mt = mt.select_entries(**{x: mt[x][mt.keep] for x in mt.entry})
                mt = mt.select_cols(
                    **{x: mt[x][mt.keep] for x in mt.col_value if x != 'keep'})
            mt = mt.select_cols(**compute_cases_binary(mt.any_codes, mt.sex),
                                description=mt.short_meaning,
                                description_more="truncated: " + hl.str(mt.truncated)
                                if 'truncated' in list(mt.col_value) else NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.meaning)
            mt = mt.select_entries(**format_entries(mt.any_codes, mt.sex))
        elif data_type == 'icd_first_occurrence':
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            mt = mt.select_cols(**compute_cases_binary(
                hl.is_defined(mt.both_sexes), mt.sex),
                                description=mt.Field,
                                description_more=mt.Notes,
                                coding_description=NULL_STR,
                                category=mt.Path)
        else:  # 'biomarkers', 'activity_monitor'
            mt = mt.key_cols_by(trait_type=mt.trait_type if 'trait_type' in list(mt.col) else data_type,
                                phenocode=hl.str(mt.pheno),
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=NULL_STR_KEY)
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            mt = mt.select_cols(**{
                f'n_cases_{sex}': hl.agg.count_where(hl.is_defined(mt[sex]))
                for sex in sexes
            },
                                description=mt.Field,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.Path)
        # else:
        #     raise ValueError('pheno or icd_code not in column key. New data type?')
        # Materialize this data type before unioning it onto the running result.
        mt = mt.checkpoint(
            tempfile.mktemp(prefix=f'/tmp/{data_type}_', suffix='.mt'))
        if full_mt is None:
            full_mt = mt
        else:
            # Prescriptions are outer-joined on rows; every other data type is inner-joined.
            full_mt = full_mt.union_cols(mt,
                                         row_join_type='outer' if data_type == 'prescriptions' else 'inner')
    full_mt = full_mt.unfilter_entries()

    # Here because prescription data was smaller than the others (so need to set the missing samples to 0)
    return full_mt.select_entries(
        **{
            sex: hl.cond(full_mt.trait_type == 'prescriptions',
                         hl.or_else(full_mt[sex], hl.float64(0.0)),
                         full_mt[sex])
            for sex in sexes
        })
def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the row key types or the column key types of the two inputs differ.
    """
    left_row_key_types = [x.dtype for x in left.row_key.values()]
    right_row_key_types = [x.dtype for x in right.row_key.values()]
    if left_row_key_types != right_row_key_types:
        raise ValueError(f"row key types do not match:\n"
                         f" left: {list(left.row_key.values())}\n"
                         f" right: {list(right.row_key.values())}")

    left_col_key_types = [x.dtype for x in left.col_key.values()]
    right_col_key_types = [x.dtype for x in right.col_key.values()]
    if left_col_key_types != right_col_key_types:
        raise ValueError(f"column key types do not match:\n"
                         f" left: {list(left.col_key.values())}\n"
                         f" right: {list(right.col_key.values())}")

    # Localize both sides to tables and outer-join them on the row key.
    left = left.select_rows(left_row=left.row)
    left_t = left.localize_entries('left_entries', 'left_cols')
    right = right.select_rows(right_row=right.row)
    right_t = right.localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')

    def col_indices_by_key(cols, key_fields):
        # Map each column-key tuple to the list of column indices carrying it.
        key_tuples = cols.map(lambda x: hl.tuple([x[f] for f in key_fields]))
        keyed = hl.zip_with_index(key_tuples, index_first=False)
        grouped = hl.group_by(lambda t: t[0], keyed)
        return grouped.map_values(lambda elts: elts.map(lambda t: t[1]))

    ht = ht.annotate_globals(
        left_keys=col_indices_by_key(ht.left_cols, left.col_key),
        right_keys=col_indices_by_key(ht.right_cols, right.col_key))

    def index_pairs(s):
        # Key on both sides: every (left index, right index) combination.
        on_both_sides = hl.range(0, s.left_indices.length()).flatmap(
            lambda i: hl.range(0, s.right_indices.length()).map(
                lambda j: hl.struct(k=s.k,
                                    left_index=s.left_indices[i],
                                    right_index=s.right_indices[j])))
        # Key on one side only: the other index is missing.
        left_only = s.left_indices.map(
            lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.null('int32')))
        right_only = s.right_indices.map(
            lambda elt: hl.struct(k=s.k, left_index=hl.null('int32'), right_index=elt))
        return (hl.case()
                .when(hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices), on_both_sides)
                .when(hl.is_defined(s.left_indices), left_only)
                .when(hl.is_defined(s.right_indices), right_only)
                .or_error('assertion error'))

    all_keys = hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
    indices_per_key = all_keys.map(
        lambda k: hl.struct(k=k,
                            left_indices=ht.left_keys.get(k),
                            right_indices=ht.right_keys.get(k)))
    ht = ht.annotate_globals(key_indices=indices_per_key.flatmap(index_pairs))

    # Build the joined entries (per row) and columns (global) from the index pairs.
    joined_entries = ht.key_indices.map(
        lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                            right_entry=ht.right_entries[s.right_index]))
    ht = ht.annotate(**{'__entries': joined_entries})

    joined_cols = ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i] for i, f in enumerate(left.col_key)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index]))
    ht = ht.annotate_globals(**{'__cols': joined_cols})

    ht = ht.drop('left_entries', 'left_cols', 'left_keys',
                 'right_entries', 'right_cols', 'right_keys', 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', list(left.col_key))
def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the row key types or the column key types of the two inputs differ.
    """
    def key_types(key_struct):
        return [x.dtype for x in key_struct.values()]

    if key_types(left.row_key) != key_types(right.row_key):
        raise ValueError(f"row key types do not match:\n"
                         f" left: {list(left.row_key.values())}\n"
                         f" right: {list(right.row_key.values())}")
    if key_types(left.col_key) != key_types(right.col_key):
        raise ValueError(f"column key types do not match:\n"
                         f" left: {list(left.col_key.values())}\n"
                         f" right: {list(right.col_key.values())}")

    # Turn each side into a table with localized entries/columns, then join on row key.
    left = left.select_rows(left_row=left.row)
    left_t = left.localize_entries('left_entries', 'left_cols')
    right = right.select_rows(right_row=right.row)
    right_t = right.localize_entries('right_entries', 'right_cols')
    ht = left_t.join(right_t, how='outer')

    # For each side: column-key tuple -> indices of the columns with that key.
    left_key_tuples = ht.left_cols.map(lambda x: hl.tuple([x[f] for f in left.col_key]))
    right_key_tuples = ht.right_cols.map(lambda x: hl.tuple([x[f] for f in right.col_key]))
    left_grouped = hl.group_by(lambda t: t[0],
                               hl.zip_with_index(left_key_tuples, index_first=False))
    right_grouped = hl.group_by(lambda t: t[0],
                                hl.zip_with_index(right_key_tuples, index_first=False))
    ht = ht.annotate_globals(
        left_keys=left_grouped.map_values(lambda elts: elts.map(lambda t: t[1])),
        right_keys=right_grouped.map_values(lambda elts: elts.map(lambda t: t[1])))

    def expand(s):
        # One struct per output column: all left/right index combinations when the
        # key is on both sides, otherwise the present side's index with the other missing.
        present_left = hl.is_defined(s.left_indices)
        present_right = hl.is_defined(s.right_indices)
        return (hl.case()
                .when(present_left & present_right,
                      hl.range(0, s.left_indices.length()).flatmap(
                          lambda i: hl.range(0, s.right_indices.length()).map(
                              lambda j: hl.struct(k=s.k,
                                                  left_index=s.left_indices[i],
                                                  right_index=s.right_indices[j]))))
                .when(present_left,
                      s.left_indices.map(
                          lambda elt: hl.struct(k=s.k, left_index=elt,
                                                right_index=hl.null('int32'))))
                .when(present_right,
                      s.right_indices.map(
                          lambda elt: hl.struct(k=s.k, left_index=hl.null('int32'),
                                                right_index=elt)))
                .or_error('assertion error'))

    union_of_keys = hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
    ht = ht.annotate_globals(
        key_indices=union_of_keys
        .map(lambda k: hl.struct(k=k,
                                 left_indices=ht.left_keys.get(k),
                                 right_indices=ht.right_keys.get(k)))
        .flatmap(expand))

    def paired_entry(s):
        return hl.struct(left_entry=ht.left_entries[s.left_index],
                         right_entry=ht.right_entries[s.right_index])

    ht = ht.annotate(__entries=ht.key_indices.map(paired_entry))

    def paired_col(s):
        return hl.struct(**{f: s.k[i] for i, f in enumerate(left.col_key)},
                         left_col=ht.left_cols[s.left_index],
                         right_col=ht.right_cols[s.right_index])

    ht = ht.annotate_globals(__cols=ht.key_indices.map(paired_col))

    ht = ht.drop('left_entries', 'left_cols', 'left_keys',
                 'right_entries', 'right_cols', 'right_keys', 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', list(left.col_key))
def import_structural_variants(vcf_path):
    """Import a structural-variant VCF and reshape its rows into a table keyed by ``variant_id``.

    Only the rows of the VCF are used. INFO fields are lifted to top-level
    fields, consequences and per-population frequencies are grouped into
    structs, MCNV frequencies are summarised, and the original ``locus``,
    ``alleles``, ``info`` and ``rsid`` fields are dropped at the end.

    :param vcf_path: Path passed to ``hl.import_vcf`` (read with ``force_bgz=True``).
    :return: Hail Table keyed by ``variant_id``.
    """
    ds = hl.import_vcf(vcf_path, force_bgz=True, min_partitions=32).rows()

    # Lift selected INFO fields to the top level under lower-cased names.
    ds = ds.annotate(
        **{field.lower(): ds.info[field] for field in TOP_LEVEL_INFO_FIELDS})

    ds = ds.annotate(
        variant_id=ds.rsid.replace("^gnomAD-SV_v2.1_", ""),
        reference_genome="GRCh37",
        # Start
        chrom=ds.locus.contig,
        pos=ds.locus.position,
        xpos=x_position(ds.locus.contig, ds.locus.position),
        # End
        end=ds.info.END,
        xend=x_position(ds.locus.contig, ds.info.END),
        # Start 2
        chrom2=ds.info.CHR2,
        pos2=ds.info.POS2,
        xpos2=x_position(ds.info.CHR2, ds.info.POS2),
        # End 2
        end2=ds.info.END2,
        xend2=x_position(ds.info.CHR2, ds.info.END2),
        # Other
        length=ds.info.SVLEN,
        type=ds.info.SVTYPE,
        alts=ds.alleles[1:],
    )

    # MULTIALLELIC should not be used as a quality filter in the browser
    ds = ds.annotate(filters=ds.filters.difference(hl.set(["MULTIALLELIC"])))

    # Group gene lists for all consequences in one field
    # (consequences with an empty gene list are removed by the trailing filter)
    ds = ds.annotate(consequences=hl.array([
        hl.struct(
            consequence=csq.lower(),
            genes=hl.or_else(ds.info[f"PROTEIN_CODING__{csq}"],
                             hl.empty_array(hl.tstr)),
        ) for csq in RANKED_CONSEQUENCES
        if csq not in ("INTERGENIC", "NEAREST_TSS")
    ]).filter(lambda csq: hl.len(csq.genes) > 0))

    ds = ds.annotate(intergenic=ds.info.PROTEIN_CODING__INTERGENIC)

    # Major consequence: the first consequence with genes, falling back to
    # "intergenic" when the intergenic flag is set (missing otherwise).
    ds = ds.annotate(major_consequence=hl.rbind(
        ds.consequences.find(lambda csq: hl.len(csq.genes) > 0),
        lambda csq: hl.or_else(csq.consequence,
                               hl.or_missing(ds.intergenic, "intergenic")),
    ))

    # Collect set of all genes for which a variant has a consequence
    ds = ds.annotate(genes=hl.set(ds.consequences.flatmap(lambda c: c.genes)))

    # Group per-population frequency values
    ds = ds.annotate(freq=hl.struct(
        **{field.lower(): ds.info[field] for field in FREQ_FIELDS},
        populations=[
            hl.struct(id=pop, **{
                field.lower(): ds.info[f"{pop}_{field}"]
                for field in FREQ_FIELDS
            }) for pop in DIVISIONS
        ],
    ))

    # For MCNVs, store per-copy number allele counts
    ds = ds.annotate(freq=ds.freq.annotate(copy_numbers=hl.or_missing(
        ds.type == "MCNV",
        hl.zip_with_index(ds.alts).map(lambda pair: hl.rbind(
            pair[0], pair[1],
            lambda index, alt: hl.struct(
                # Extract copy number. Example, get 2 from "<CN=2>"
                copy_number=hl.int(alt[4:-1]),
                ac=ds.freq.ac[index],
            ),
        )),
    )))

    # For MCNVs, sum AC/AF for all alt alleles except CN=2
    # (non-MCNV variants take the first element of each array instead)
    ds = ds.annotate(freq=ds.freq.annotate(
        ac=hl.if_else(ds.type == "MCNV",
                      sum_mcnv_ac_or_af(ds.alts, ds.freq.ac),
                      ds.freq.ac[0]),
        af=hl.if_else(ds.type == "MCNV",
                      sum_mcnv_ac_or_af(ds.alts, ds.freq.af),
                      ds.freq.af[0]),
        populations=hl.if_else(
            ds.type == "MCNV",
            ds.freq.populations.map(lambda pop: pop.annotate(
                ac=sum_mcnv_ac_or_af(ds.alts, pop.ac),
                af=sum_mcnv_ac_or_af(ds.alts, pop.af),
            )),
            ds.freq.populations.map(
                lambda pop: pop.annotate(ac=pop.ac[0], af=pop.af[0])),
        ),
    ))

    # Add hemizygous frequencies
    # Counts are taken from the MALE_N_HEMIALT INFO fields on X/Y where `par` is
    # false, and are 0 elsewhere; all FEMALE keys are 0.
    # NOTE(review): `ds.par` is presumably one of the lifted TOP_LEVEL_INFO_FIELDS - confirm.
    ds = ds.annotate(hemizygote_count=hl.dict(
        [(
            pop_id,
            hl.if_else(((ds.chrom == "X") | (ds.chrom == "Y")) & ~ds.par,
                       ds.info[f"{pop_id}_MALE_N_HEMIALT"], 0),
        ) for pop_id in POPULATIONS] +
        [(f"{pop_id}_FEMALE", 0) for pop_id in POPULATIONS] + [(
            f"{pop_id}_MALE",
            hl.if_else(((ds.chrom == "X") | (ds.chrom == "Y")) & ~ds.par,
                       ds.info[f"{pop_id}_MALE_N_HEMIALT"], 0),
        ) for pop_id in POPULATIONS] + [("FEMALE", 0)] +
        [("MALE",
          hl.if_else(((ds.chrom == "X") | (ds.chrom == "Y")) & ~ds.par,
                     ds.info.MALE_N_HEMIALT, 0))]))

    # Hemizygote counts are missing for MCNVs, both overall and per population.
    ds = ds.annotate(freq=ds.freq.annotate(
        hemizygote_count=hl.or_missing(
            ds.type != "MCNV",
            hl.if_else(((ds.chrom == "X") | (ds.chrom == "Y")) & ~ds.par,
                       ds.info.MALE_N_HEMIALT, 0),
        ),
        populations=hl.if_else(
            ds.type != "MCNV",
            ds.freq.populations.map(lambda pop: pop.annotate(
                hemizygote_count=ds.hemizygote_count[pop.id])),
            ds.freq.populations.map(
                lambda pop: pop.annotate(hemizygote_count=hl.null(hl.tint))),
        ),
    ))
    ds = ds.drop("hemizygote_count")

    # Rename n_homalt
    ds = ds.annotate(freq=ds.freq.annotate(
        homozygote_count=ds.freq.n_homalt,
        populations=ds.freq.populations.map(lambda pop: pop.annotate(
            homozygote_count=pop.n_homalt).drop("n_homalt")),
    ).drop("n_homalt"))

    # Re-key
    ds = ds.key_by("variant_id")

    ds = ds.drop("locus", "alleles", "info", "rsid")

    return ds
def combine_pheno_files_multi_sex(pheno_file_dict: dict, cov_ht: hl.Table, truncated_codes_only: bool = True):
    """Combine per-data-type phenotype MatrixTables into one MatrixTable with per-sex entries.

    For every ``(data_type, mt)`` pair, the row fields are replaced by the fields
    of ``cov_ht`` looked up by row key, the column fields are reduced to case
    counts plus descriptive fields, and the entries are reduced to per-sex
    fields. Each result is checkpointed to a temporary path and unioned
    column-wise (outer join on rows) onto the running result.

    :param pheno_file_dict: Mapping from data type name to its MatrixTable.
    :param cov_ht: Table indexed by each MatrixTable's row key; its fields become the row fields.
    :param truncated_codes_only: In the ``icd_code`` branch, keep only columns whose
        ``icd_code`` has length 3 and collapse columns sharing a key to a single one.
    :return: MatrixTable whose entry fields are ``both_sexes``, ``females`` and ``males``.
    """
    full_mt: hl.MatrixTable = None
    sexes = ('both_sexes', 'females', 'males')

    def counting_func(value, trait_type):
        # Categorical values are used as-is; anything else counts as a case when defined.
        return value if trait_type == 'categorical' else hl.is_defined(value)

    for data_type, mt in pheno_file_dict.items():
        # Replace the row fields with the covariates for that row key.
        mt = mt.select_rows(**cov_ht[mt.row_key])
        print(data_type)
        if data_type == 'custom':
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            # Categorical traits count entries equal to 1.0; others count defined
            # entries. Description fields absent from the columns are filled with NULL_STR.
            mt = mt.select_cols(**{f'n_cases_{sex}': hl.agg.count_where(
                hl.cond(mt.trait_type == 'categorical', mt[sex] == 1.0, hl.is_defined(mt[sex]))
            ) for sex in sexes},
                                **{extra_col: mt[extra_col] if extra_col in list(mt.col) else NULL_STR
                                   for extra_col in PHENO_DESCRIPTION_FIELDS})
        elif data_type in ('categorical', 'continuous'):
            def get_non_missing_field(mt, field_name):
                # First non-missing value of the field across the per-sex phenotype structs.
                return hl.coalesce(*[mt[f'{sex}_pheno'][field_name] for sex in sexes])

            mt = mt.select_cols(**compute_cases_binary(counting_func(mt.both_sexes, data_type), mt.sex),
                                # **{f'n_cases_{sex}': hl.agg.count_where(counting_func(mt[sex], mt.trait_type)) for sex in sexes},
                                description=get_non_missing_field(mt, 'Field'),
                                description_more=get_non_missing_field(mt, 'Notes'),
                                coding_description=get_non_missing_field(mt, 'meaning') if data_type == 'categorical' else NULL_STR,
                                category=get_non_missing_field(mt, 'Path'))
            mt = mt.select_entries(**{sex: hl.float64(mt[sex]) for sex in sexes})
            # TODO: got here - move some of this to ICD load (get icd_version as icd10 and move truncation steps there)
        elif 'icd_code' in list(mt.col_key):
            icd_version = mt.icd_version if 'icd_version' in list(mt.col) else 'icd10'
            mt = mt.key_cols_by(trait_type=icd_version, phenocode=mt.icd_code, pheno_sex='both_sexes',
                                coding=NULL_STR_KEY, modifier=NULL_STR_KEY)
            if truncated_codes_only:
                mt = mt.filter_cols(hl.len(mt.icd_code) == 3)
                mt = mt.collect_cols_by_key()
                # Among columns sharing a key, keep the single one if there is only
                # one, otherwise the first whose `truncated` flag is true.
                mt = mt.annotate_cols(keep=hl.if_else(
                    hl.len(mt.truncated) == 1,
                    0,
                    hl.zip_with_index(mt.truncated).filter(lambda x: x[1]).map(lambda x: x[0])[0]))
                mt = mt.select_entries(**{x: mt[x][mt.keep] for x in mt.entry})
                mt = mt.select_cols(**{x: mt[x][mt.keep] for x in mt.col_value if x != 'keep'})
            mt = mt.select_cols(**compute_cases_binary(mt.any_codes, mt.sex),
                                description=mt.short_meaning,
                                description_more="truncated: " + hl.str(mt.truncated) if 'truncated' in list(mt.col_value) else NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.meaning)
            mt = mt.select_entries(**format_entries(mt.any_codes, mt.sex))
        elif data_type == 'icd_first_occurrence':
            # An entry is a case when its value is defined.
            mt = mt.select_entries(**format_entries(hl.is_defined(mt.value), mt.sex))
            mt = mt.select_cols(**compute_cases_binary(mt.both_sexes == 1.0, mt.sex),
                                description=mt.Field,
                                description_more=mt.Notes,
                                coding_description=NULL_STR,
                                category=mt.Path)
        else:  # 'biomarkers', 'activity_monitor'
            mt = mt.key_cols_by(trait_type=mt.trait_type if 'trait_type' in list(mt.col) else data_type,
                                phenocode=hl.str(mt.pheno), pheno_sex='both_sexes',
                                coding=NULL_STR_KEY, modifier=NULL_STR_KEY)
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            mt = mt.select_cols(**{f'n_cases_{sex}': hl.agg.count_where(hl.is_defined(mt[sex])) for sex in sexes},
                                description=mt.Field,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.Path)
        # else:
        #     raise ValueError('pheno or icd_code not in column key. New data type?')
        # Materialize this data type before unioning it onto the running result.
        mt = mt.checkpoint(tempfile.mktemp(prefix=f'/tmp/{data_type}_', suffix='.mt'))
        if full_mt is None:
            full_mt = mt
        else:
            full_mt = full_mt.union_cols(mt, row_join_type='outer')
    full_mt = full_mt.unfilter_entries()

    # Here because prescription data was smaller than the others (so need to set the missing samples to 0)
    return full_mt.select_entries(**{sex: hl.cond(
        full_mt.trait_type == 'prescriptions',
        hl.or_else(full_mt[sex], hl.float64(0.0)),
        full_mt[sex]) for sex in sexes})
def main(args):
    """
    Run the cross-population fixed-effect meta-analysis and write the results.

    Loads the final summary statistics MatrixTable, computes inverse-variance-weighted meta-analysis
    statistics (for all populations and for each leave-one-out subset, via ``all_and_leave_one_out``),
    a Cochran's Q heterogeneity test and pooled allele frequencies, packs them into
    ``meta_analysis`` (entries) and ``meta_analysis_data`` (columns), and writes the MatrixTable to
    ``get_meta_analysis_results_path()``.

    :param args: Parsed command-line arguments; ``keep_low_confidence_variants`` and ``overwrite`` are read
    """
    hl.init()

    # Read in all sumstats
    mt = load_final_sumstats_mt(
        filter_phenos=True,
        filter_variants=False,
        filter_sumstats=True,
        separate_columns_by_pop=False,
        annotate_with_nearest_gene=False,
    )

    # Annotate per-entry sample size
    def get_n(pheno_data, i):
        per_pop = pheno_data[i]
        return per_pop.n_cases + hl.or_else(per_pop.n_controls, 0)

    def with_sample_size(indexed_stat):
        idx, stat = indexed_stat[0], indexed_stat[1]
        return stat.annotate(N=hl.or_missing(hl.is_defined(stat), get_n(mt.pheno_data, idx)))

    mt = mt.annotate_entries(
        summary_stats=hl.map(with_sample_size, hl.zip_with_index(mt.summary_stats)))

    # Exclude entries with low confidence flag.
    if not args.keep_low_confidence_variants:
        mt = mt.annotate_entries(
            summary_stats=hl.map(lambda stat: hl.or_missing(~stat.low_confidence, stat),
                                 mt.summary_stats))

    # Run fixed-effect meta-analysis (all + leave-one-out)
    mt = mt.annotate_entries(
        unnorm_beta=mt.summary_stats.BETA / (mt.summary_stats.SE**2),
        inv_se2=1 / (mt.summary_stats.SE**2),
    )
    mt = mt.annotate_entries(
        sum_unnorm_beta=all_and_leave_one_out(mt.unnorm_beta, mt.pheno_data.pop),
        sum_inv_se2=all_and_leave_one_out(mt.inv_se2, mt.pheno_data.pop),
    )
    mt = mt.transmute_entries(
        META_BETA=mt.sum_unnorm_beta / mt.sum_inv_se2,
        META_SE=hl.map(lambda total_inv_se2: hl.sqrt(1 / total_inv_se2), mt.sum_inv_se2),
    )
    neg_abs_z = -hl.abs(mt.META_BETA / mt.META_SE)
    mt = mt.annotate_entries(META_Pvalue=hl.map(lambda z: 2 * hl.pnorm(z), neg_abs_z))

    # Run heterogeneity test (Cochran's Q)
    mt = mt.annotate_entries(
        META_Q=hl.map(
            lambda meta_beta: hl.sum((mt.summary_stats.BETA - meta_beta)**2 * mt.inv_se2),
            mt.META_BETA),
        variant_exists=hl.map(lambda beta: ~hl.is_missing(beta), mt.summary_stats.BETA),
    )
    mt = mt.annotate_entries(
        META_N_pops=all_and_leave_one_out(mt.variant_exists, mt.pheno_data.pop))
    mt = mt.annotate_entries(
        META_Pvalue_het=hl.map(
            lambda i: hl.pchisqtail(mt.META_Q[i], mt.META_N_pops[i] - 1),
            hl.range(hl.len(mt.META_Q))))

    # Add other annotations
    mt = mt.annotate_entries(
        ac_cases=hl.map(lambda stat: stat["AF.Cases"] * stat.N, mt.summary_stats),
        ac_controls=hl.map(lambda stat: stat["AF.Controls"] * stat.N, mt.summary_stats),
        META_AC_Allele2=all_and_leave_one_out(
            mt.summary_stats.AF_Allele2 * mt.summary_stats.N, mt.pheno_data.pop),
        META_N=all_and_leave_one_out(mt.summary_stats.N, mt.pheno_data.pop),
    )
    mt = mt.annotate_entries(
        META_AF_Allele2=mt.META_AC_Allele2 / mt.META_N,
        META_AF_Cases=all_and_leave_one_out(mt.ac_cases, mt.pheno_data.pop) / mt.META_N,
        META_AF_Controls=all_and_leave_one_out(mt.ac_controls, mt.pheno_data.pop) / mt.META_N,
    )
    mt = mt.drop('unnorm_beta', 'inv_se2', 'variant_exists', 'ac_cases', 'ac_controls',
                 'summary_stats', 'META_AC_Allele2')

    # Format everything into array<struct>
    def is_finite_or_missing(x):
        return hl.or_missing(hl.is_finite(x), x)

    meta_fields = [
        'BETA', 'SE', 'Pvalue', 'Q', 'Pvalue_het', 'N', 'N_pops',
        'AF_Allele2', 'AF_Cases', 'AF_Controls',
    ]

    def meta_struct_at(i):
        # One struct per meta-analysis (all populations, then each leave-one-out subset).
        return hl.struct(
            **{field: is_finite_or_missing(mt[f'META_{field}'][i]) for field in meta_fields})

    mt = mt.transmute_entries(
        meta_analysis=hl.map(meta_struct_at, hl.range(hl.len(mt.META_BETA))))

    col_fields = ['n_cases', 'n_controls']
    mt = mt.annotate_cols(
        **{field: all_and_leave_one_out(mt.pheno_data[field], mt.pheno_data.pop)
           for field in col_fields})
    col_fields += ['pop']
    mt = mt.annotate_cols(
        pop=all_and_leave_one_out(
            mt.pheno_data.pop,
            mt.pheno_data.pop,
            all_f=lambda pops: pops,
            loo_f=lambda i, pops: hl.filter(lambda p: p != pops[i], pops),
        ))

    def col_struct_at(i):
        return hl.struct(**{field: mt[field][i] for field in col_fields})

    mt = mt.transmute_cols(
        meta_analysis_data=hl.map(col_struct_at, hl.range(hl.len(mt.pop))))

    mt.describe()
    mt.write(get_meta_analysis_results_path(), overwrite=args.overwrite)
    hl.copy_log('gs://ukb-diverse-pops/combined_results/meta_analysis.log')
def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Examples
    --------

    The following creates and joins two random datasets with disjoint sample ids
    but non-disjoint variant sets. We use :func:`.or_else` to attempt to find a
    non-missing genotype. If neither genotype is non-missing, then the genotype
    is set to missing. In particular, note that Samples `2` and `3` have missing
    genotypes for loci 1:1 and 1:2 because those loci are not present in `mt2`
    and these samples are not present in `mt1`

    >>> hl.set_global_seed(0)
    >>> mt1 = hl.balding_nichols_model(1, 2, 3)
    >>> mt2 = hl.balding_nichols_model(1, 2, 3)
    >>> mt2 = mt2.key_rows_by(locus=hl.locus(mt2.locus.contig,
    ...                                      mt2.locus.position+2),
    ...                       alleles=mt2.alleles)
    >>> mt2 = mt2.key_cols_by(sample_idx=mt2.sample_idx+2)
    >>> mt1.show()
    +---------------+------------+------+------+
    | locus         | alleles    | 0.GT | 1.GT |
    +---------------+------------+------+------+
    | locus<GRCh37> | array<str> | call | call |
    +---------------+------------+------+------+
    | 1:1           | ["A","C"]  | 0/1  | 0/1  |
    | 1:2           | ["A","C"]  | 1/1  | 1/1  |
    | 1:3           | ["A","C"]  | 0/0  | 0/0  |
    +---------------+------------+------+------+
    <BLANKLINE>
    >>> mt2.show()  # doctest: +SKIP_OUTPUT_CHECK
    +---------------+------------+------+------+
    | locus         | alleles    | 0.GT | 1.GT |
    +---------------+------------+------+------+
    | locus<GRCh37> | array<str> | call | call |
    +---------------+------------+------+------+
    | 1:3           | ["A","C"]  | 0/1  | 1/1  |
    | 1:4           | ["A","C"]  | 0/1  | 0/1  |
    | 1:5           | ["A","C"]  | 1/1  | 0/0  |
    +---------------+------------+------+------+
    <BLANKLINE>
    >>> mt3 = hl.experimental.full_outer_join_mt(mt1, mt2)
    >>> mt3 = mt3.select_entries(GT=hl.or_else(mt3.left_entry.GT, mt3.right_entry.GT))
    >>> mt3.show()
    +---------------+------------+------+------+------+------+
    | locus         | alleles    | 0.GT | 1.GT | 2.GT | 3.GT |
    +---------------+------------+------+------+------+------+
    | locus<GRCh37> | array<str> | call | call | call | call |
    +---------------+------------+------+------+------+------+
    | 1:1           | ["A","C"]  | 0/1  | 0/1  | NA   | NA   |
    | 1:2           | ["A","C"]  | 1/1  | 1/1  | NA   | NA   |
    | 1:3           | ["A","C"]  | 0/0  | 0/0  | 0/1  | 1/1  |
    | 1:4           | ["A","C"]  | NA   | NA   | 0/1  | 0/1  |
    | 1:5           | ["A","C"]  | NA   | NA   | 1/1  | 0/0  |
    +---------------+------------+------+------+------+------+
    <BLANKLINE>

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the row key types or the column key types of `left` and `right` differ.
    """

    def key_types(key_struct):
        return [field.dtype for field in key_struct.values()]

    if key_types(left.row_key) != key_types(right.row_key):
        raise ValueError(f"row key types do not match:\n"
                         f" left: {list(left.row_key.values())}\n"
                         f" right: {list(right.row_key.values())}")
    if key_types(left.col_key) != key_types(right.col_key):
        raise ValueError(f"column key types do not match:\n"
                         f" left: {list(left.col_key.values())}\n"
                         f" right: {list(right.col_key.values())}")

    # Wrap each side's row fields in a struct and move entries/columns into table fields.
    left = left.select_rows(left_row=left.row)
    left_t = left.localize_entries('left_entries', 'left_cols')
    right = right.select_rows(right_row=right.row)
    right_t = right.localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')

    def indices_by_col_key(cols, key_fields):
        # Dict from column-key tuple to the array of column positions carrying that key.
        key_tuples = cols.map(lambda col: hl.tuple([col[f] for f in key_fields]))
        key_and_index = hl.zip_with_index(key_tuples, index_first=False)
        grouped = hl.group_by(lambda pair: pair[0], key_and_index)
        return grouped.map_values(lambda pairs: pairs.map(lambda pair: pair[1]))

    ht = ht.annotate_globals(
        left_keys=indices_by_col_key(ht.left_cols, left.col_key),
        right_keys=indices_by_col_key(ht.right_cols, right.col_key))

    def lookup_indices(k):
        return hl.struct(k=k,
                         left_indices=ht.left_keys.get(k),
                         right_indices=ht.right_keys.get(k))

    def pair_indices(s):
        # Key on both sides: every (left, right) index combination.
        # Key on one side only: that side's indices paired with a missing index.
        both_sides = hl.range(0, s.left_indices.length()).flatmap(
            lambda i: hl.range(0, s.right_indices.length()).map(
                lambda j: hl.struct(k=s.k,
                                    left_index=s.left_indices[i],
                                    right_index=s.right_indices[j])))
        left_only = s.left_indices.map(
            lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.null('int32')))
        right_only = s.right_indices.map(
            lambda elt: hl.struct(k=s.k, left_index=hl.null('int32'), right_index=elt))
        return (hl.case()
                .when(hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices), both_sides)
                .when(hl.is_defined(s.left_indices), left_only)
                .when(hl.is_defined(s.right_indices), right_only)
                .or_error('assertion error'))

    all_keys = hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
    ht = ht.annotate_globals(key_indices=all_keys.map(lookup_indices).flatmap(pair_indices))

    ht = ht.annotate(__entries=ht.key_indices.map(
        lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                            right_entry=ht.right_entries[s.right_index])))
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i] for i, f in enumerate(left.col_key)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))
    ht = ht.drop('left_entries', 'left_cols', 'left_keys',
                 'right_entries', 'right_cols', 'right_keys',
                 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', list(left.col_key))
def sorted_transcript_consequences_v2(vep_root):
    """Sort transcripts by 3 properties:

        1. coding > non-coding
        2. transcript consequence severity
        3. canonical > non-canonical

    so that the 1st array entry will be for the coding, most-severe, canonical transcript (assuming
    one exists).

    Consequence terms in OMIT_CONSEQUENCE_TERMS are removed first, and transcripts left without any
    term are dropped.

    Also, for each transcript in the array, computes these additional fields:
        domains: converts structs with db/name fields to string db:name
        hgvs: hgvsp (formatted for synonymous variants) if it exists, otherwise hgvsc
        major_consequence: set to most severe consequence for that transcript (
            VEP sometimes provides multiple consequences for a single transcript)
        major_consequence_rank: major_consequence rank based on VEP SO ontology (most severe = 1)
            (see http://www.ensembl.org/info/genome/variation/predicted_data.html)
        category: set to one of: "lof", "missense", "synonymous", "other" based on the value of major_consequence.
        transcript_rank: position of the transcript in the sorted array

    Args:
        vep_root (StructExpression): root path of the VEP struct in the MT
    """
    # Least severe consequence that still belongs to each category, most severe category first.
    category_thresholds = (
        ("frameshift_variant", "lof"),
        ("missense_variant", "missense"),
        ("synonymous_variant", "synonymous"),
    )

    output_fields = (
        "amino_acids",
        "biotype",
        "canonical",
        "category",
        "cdna_end",
        "cdna_start",
        "codons",
        "consequence_terms",
        "domains",
        "gene_id",
        "gene_symbol",
        "hgvs",
        "hgvsc",
        "hgvsp",
        "lof_filter",
        "lof_flags",
        "lof_info",
        "lof",
        "major_consequence",
        "major_consequence_rank",
        "polyphen_prediction",
        "protein_id",
        "protein_start",
        "sift_prediction",
        "transcript_id",
        "transcript_rank",
    )

    def drop_omitted_terms(csq):
        kept_terms = csq.consequence_terms.filter(
            lambda term: ~OMIT_CONSEQUENCE_TERMS.contains(term))
        return csq.annotate(consequence_terms=kept_terms)

    def has_consequence_terms(csq):
        return csq.consequence_terms.size() > 0

    def add_major_consequence(csq):
        most_severe_first = hl.sorted(csq.consequence_terms, key=consequence_term_rank)
        return csq.annotate(major_consequence=most_severe_first[0])

    def add_derived_fields(csq):
        category_case = hl.case()
        for threshold_term, label in category_thresholds:
            category_case = category_case.when(
                consequence_term_rank(csq.major_consequence) <= consequence_term_rank(threshold_term),
                label)
        category = category_case.default("other")

        domains = csq.domains.map(lambda domain: domain.db + ":" + domain.name)

        # Fall back to the hgvsc suffix when hgvsp is missing or the major consequence is a splice one.
        use_hgvsc = hl.is_missing(csq.hgvsp) | SPLICE_CONSEQUENCES.contains(csq.major_consequence)
        hgvs = hl.cond(
            use_hgvsc,
            csq.hgvsc.split(":")[-1],
            hgvsp_from_consequence_amino_acids(csq),
        )

        return csq.annotate(
            category=category,
            domains=domains,
            hgvs=hgvs,
            major_consequence_rank=consequence_term_rank(csq.major_consequence),
        )

    def sort_bucket(is_coding, is_most_severe, is_canonical):
        # 1 (coding, most severe, canonical) ... 8 (non-coding, not most severe, non-canonical)
        return hl.cond(
            is_coding,
            hl.cond(is_most_severe, hl.cond(is_canonical, 1, 2), hl.cond(is_canonical, 3, 4)),
            hl.cond(is_most_severe, hl.cond(is_canonical, 5, 6), hl.cond(is_canonical, 7, 8)),
        )

    def sort_key(csq):
        return hl.bind(
            sort_bucket,
            hl.or_else(csq.biotype, "") == "protein_coding",
            hl.set(csq.consequence_terms).contains(vep_root.most_severe_consequence),
            hl.or_else(csq.canonical, 0) == 1,
        )

    def add_transcript_rank(indexed_csq):
        return indexed_csq[1].annotate(transcript_rank=indexed_csq[0])

    consequences = (
        vep_root.transcript_consequences
        .map(drop_omitted_terms)
        .filter(has_consequence_terms)
        .map(add_major_consequence)
        .map(add_derived_fields)
    )
    consequences = hl.sorted(consequences, sort_key)
    consequences = hl.zip_with_index(consequences).map(add_transcript_rank)

    # TODO: Discard most of lof_info field
    # Keep whether lof_info contains DONOR_DISRUPTION, ACCEPTOR_DISRUPTION, or DE_NOVO_DONOR
    consequences = consequences.map(lambda csq: csq.select(*output_fields))

    return consequences
def get_expr_for_vep_sorted_transcript_consequences_array(
        vep_root, include_coding_annotations=True, omit_consequences=OMIT_CONSEQUENCE_TERMS):
    """Sort transcripts by 3 properties:

        1. coding > non-coding
        2. transcript consequence severity
        3. canonical > non-canonical

    so that the 1st array entry will be for the coding, most-severe, canonical transcript (assuming
    one exists).

    Also, for each transcript in the array, computes these additional fields:
        domains: converts each domain struct to a "db:name" string
        hgvs: set to hgvsp if it exists, or else hgvsc. formats hgvsp for synonymous variants.
        major_consequence: set to most severe consequence for that transcript (
            VEP sometimes provides multiple consequences for a single transcript)
        major_consequence_rank: major_consequence rank based on VEP SO ontology (most severe = 1)
            (see http://www.ensembl.org/info/genome/variation/predicted_data.html)
        category: set to one of: "lof", "missense", "synonymous", "other" based on the value of major_consequence.
        transcript_rank: position of the transcript in the sorted array

    Args:
        vep_root (StructExpression): root path of the VEP struct in the MT
        include_coding_annotations (bool): if True, fields relevant to protein-coding variants will be included
        omit_consequences: consequence terms removed from each transcript's consequence_terms;
            transcripts left without any term are dropped. A falsy value omits nothing.
    """
    base_annotations = [
        "biotype",
        "canonical",
        "cdna_start",
        "cdna_end",
        "codons",
        "gene_id",
        "gene_symbol",
        "hgvsc",
        "hgvsp",
        "transcript_id",
    ]
    coding_annotations = [
        "amino_acids",
        "lof",
        "lof_filter",
        "lof_flags",
        "lof_info",
        "polyphen_prediction",
        "protein_id",
        "protein_start",
        "sift_prediction",
    ]
    if include_coding_annotations:
        selected_annotations = base_annotations + coding_annotations
    else:
        selected_annotations = base_annotations

    if omit_consequences:
        omit_consequence_terms = hl.set(omit_consequences)
    else:
        omit_consequence_terms = hl.empty_set(hl.tstr)

    # Least severe consequence that still belongs to each category, most severe category first.
    category_thresholds = (
        ("frameshift_variant", "lof"),
        ("missense_variant", "missense"),
        ("synonymous_variant", "synonymous"),
    )

    def select_fields(csq):
        # NOTE(review): major_consequence is taken from the unfiltered consequence_terms, whereas the
        # selected consequence_terms field has the omitted terms removed - confirm this is intended.
        return csq.select(
            *selected_annotations,
            consequence_terms=csq.consequence_terms.filter(
                lambda term: ~omit_consequence_terms.contains(term)),
            domains=csq.domains.map(lambda domain: domain.db + ":" + domain.name),
            major_consequence=hl.cond(
                csq.consequence_terms.size() > 0,
                hl.sorted(csq.consequence_terms,
                          key=lambda term: CONSEQUENCE_TERM_RANK_LOOKUP.get(term))[0],
                hl.null(hl.tstr),
            ),
        )

    def has_consequence_terms(csq):
        return csq.consequence_terms.size() > 0

    def add_derived_fields(csq):
        category_case = hl.case()
        for threshold_term, label in category_thresholds:
            category_case = category_case.when(
                CONSEQUENCE_TERM_RANK_LOOKUP.get(csq.major_consequence)
                <= CONSEQUENCE_TERM_RANK_LOOKUP.get(threshold_term),
                label,
            )
        return csq.annotate(
            category=category_case.default("other"),
            hgvs=get_expr_for_formatted_hgvs(csq),
            major_consequence_rank=CONSEQUENCE_TERM_RANK_LOOKUP.get(csq.major_consequence),
        )

    def sort_bucket(is_coding, is_most_severe, is_canonical):
        # 1 (coding, most severe, canonical) ... 8 (non-coding, not most severe, non-canonical)
        return hl.cond(
            is_coding,
            hl.cond(is_most_severe, hl.cond(is_canonical, 1, 2), hl.cond(is_canonical, 3, 4)),
            hl.cond(is_most_severe, hl.cond(is_canonical, 5, 6), hl.cond(is_canonical, 7, 8)),
        )

    def sort_key(csq):
        return hl.bind(
            sort_bucket,
            hl.or_else(csq.biotype, "") == "protein_coding",
            hl.set(csq.consequence_terms).contains(vep_root.most_severe_consequence),
            hl.or_else(csq.canonical, 0) == 1,
        )

    annotated = (
        vep_root.transcript_consequences
        .map(select_fields)
        .filter(has_consequence_terms)
        .map(add_derived_fields)
    )
    result = hl.sorted(annotated, sort_key)

    if not include_coding_annotations:
        # for non-coding variants, drop fields here that are hard to exclude in the above code
        result = result.map(lambda csq: csq.drop("domains", "hgvsp"))

    return hl.zip_with_index(result).map(
        lambda indexed_csq: indexed_csq[1].annotate(transcript_rank=indexed_csq[0]))