def injest_vcf(input_vcfs, output_folder='db/Tables', error_folder='db/errors'):
    """Import VCF files with Hail and persist their genotype entries.

    Each VCF is imported without a reference genome, reduced to the ``GT``,
    ``DP`` and ``GQ`` entry fields, and split into rows whose locus is valid
    for GRCh37 and rows whose locus is not.  Entries of valid rows are written
    to ``<output_folder>/<vcf_name>.ht``; invalid rows are exported to
    ``<error_folder>/<vcf_name>.tsv``.

    A failure on one VCF does not stop the loop: the path is appended to
    ``invalid_vcf.txt`` and the path plus the exception text is appended to
    ``error_log.txt`` (both in the current working directory).

    :param input_vcfs: path of a single VCF, or an iterable of VCF paths
    :param output_folder: folder receiving one ``.ht`` entries table per VCF
    :param error_folder: folder receiving one ``.tsv`` of invalid rows per VCF
    :return: the string ``'done'``
    """
    # sanitize chr: map "chr1".."chr22", "chrX", "chrY" to bare contig names
    recode = {f"chr{i}": f"{i}" for i in (list(range(1, 23)) + ['X', 'Y'])}

    # isinstance (not type() ==) so that str subclasses count as one path
    if isinstance(input_vcfs, str):
        input_vcfs = [input_vcfs]

    for input_vcf in tqdm(input_vcfs):
        # file name up to the first dot, e.g. "a/b/sample.vcf.gz" -> "sample"
        vcf_name = (input_vcf.split('/')[-1]).split('.')[0]
        try:
            # load vcf
            mt = hl.methods.import_vcf(input_vcf,
                                       contig_recoding=recode,
                                       force_bgz=True,
                                       reference_genome=None)

            # clean information: keep only GT/DP/GQ entries and the keys
            mt = mt.select_entries(mt.GT, mt.DP, mt.GQ)
            mt = mt.select_rows()
            mt = mt.select_cols()

            valid = hl.is_valid_locus(mt.locus.contig,
                                      mt.locus.position,
                                      reference_genome='GRCh37')
            mt_wrong = mt.filter_rows(~valid)
            mt_correct = mt.filter_rows(valid)

            # GT becomes the sum of the two allele indices of the call;
            # missing values are replaced by -1 (GT) and 0 (DP, GQ).
            mt_correct = mt_correct.annotate_entries(
                GT=hl.coalesce(mt_correct.GT[0] + mt_correct.GT[1], -1),
                DP=hl.coalesce(mt_correct.DP, 0),
                GQ=hl.coalesce(mt_correct.GQ, 0))

            # store entries of correct variants as a Hail Table
            if mt_correct.rows().count() > 0:
                mt_correct.entries().write(f'{output_folder}/{vcf_name}.ht',
                                           overwrite=True)

            # store incorrect variants in tsv table
            if mt_wrong.rows().count() > 0:
                mt_wrong.rows().export(f'{error_folder}/{vcf_name}.tsv')
        except Exception as e:
            # Deliberate best effort: record the failing file and continue.
            with open('invalid_vcf.txt', 'a') as f:
                f.write(f'{input_vcf}\n')
            with open('error_log.txt', 'a') as log:
                log.write(f'{input_vcf}\n{e}\n')
    return 'done'
def coalesce_join(ref, var):
    """Merge a reference struct and a variant struct, preferring variant values.

    The call field (``GT`` if ``var`` has it, otherwise ``LGT``) falls back to
    a ``0/0`` call; every field of ``ref`` that ``var`` also has falls back to
    the ``ref`` value.  All remaining ``var`` fields are appended unchanged.
    """
    if 'GT' in var:
        call_field = 'GT'
    else:
        call_field = 'LGT'
    assert call_field in var, var.dtype

    # Call field goes first; a shared field named like the call field
    # overwrites it in place, exactly as repeated dict assignment would.
    merged = {call_field: hl.coalesce(var[call_field], hl.call(0, 0))}
    merged.update({
        name: hl.coalesce(var[name], ref[name])
        for name in ref.dtype
        if name in var
    })

    var_only = {name: var[name] for name in var if name not in merged}
    return hl.struct(**merged).annotate(**var_only)
def _get(self, var_df, samples, field):
    """Collect ``field`` for the given variants and samples as a 2-D array.

    The variants in ``var_df`` (columns ``chrom``, ``pos``, ``ref``, ``alt``
    are read) are paired with every sample, left-joined against each stored
    Hail table returned by ``self._get_Tables_paths(samples)``, and missing
    ``GT``/``DP``/``GQ`` values are replaced by 0.

    :param var_df: pandas DataFrame describing the variants
    :param samples: a single sample id or a list of sample ids
    :param field: name of the table field to extract (e.g. ``'GT'``)
    :return: numpy array with one column per sample, rows ordered by the
        index assigned after pairing variants with samples
    :raises ValueError: if no Hail table is available for ``samples``
    """
    # array with samples in HAIL
    if isinstance(samples, str):
        samples = [samples]

    # create table with vars in HAIL
    ht = hl.Table.from_pandas(var_df)

    # create table with samples and pair it with the variants
    df = pd.DataFrame({'s': samples})
    ht_samples = hl.Table.from_pandas(df)
    ht = ht.join(ht_samples)

    ht = ht.annotate(pos=hl.int32(ht.pos))
    ht = ht.add_index()
    ht = ht.key_by(locus=hl.struct(contig=ht.chrom, position=ht.pos),
                   alleles=hl.array([ht.ref, ht.alt]),
                   s=ht.s)

    # all variants per sample
    res_table = None
    ht_paths = self._get_Tables_paths(samples)

    # iterate through ht_vcfs with samples
    for ht_path in ht_paths:
        ht_vcf = hl.read_table(ht_path)
        ht_n = ht.join(ht_vcf, how='left')
        if res_table is None:
            res_table = ht_n
            res_table = res_table.checkpoint('db/checkpoint/ht1.ht',
                                             overwrite=True)
        else:
            res_table = res_table.union(ht_n)

    if res_table is None:
        # Previously this fell through to an AttributeError on None.
        raise ValueError('no Hail tables found for the requested samples')

    # variants absent from the stored tables get 0 for every metric
    res_table = res_table.annotate(GT=hl.coalesce(res_table.GT, 0),
                                   DP=hl.coalesce(res_table.DP, 0),
                                   GQ=hl.coalesce(res_table.GQ, 0))
    res_table = res_table.order_by(res_table.idx)
    res_table = res_table.checkpoint('db/checkpoint/ht2.ht', overwrite=True)

    return np.column_stack([
        np.array(
            res_table.filter(res_table.s == sample)[field].collect()
        ).reshape(-1, 1)
        for sample in samples
    ])
def compute_n_cases(mt, data_type):
    """Return a dict of extra column expressions for a phenotype MatrixTable.

    For ``'icd'`` data the dict holds case/control counts over the ICD code
    entry fields.  For any other type it holds ``Field`` and, depending on
    whether the type is ``'categorical'``, either case/control counts plus
    ``meaning`` or per-sex counts of defined values.
    """
    if data_type == 'icd':
        any_code = (mt.primary_codes | mt.secondary_codes
                    | mt.external_codes | mt.cause_of_death_codes)
        return {
            'n_cases': hl.agg.count_where(mt.primary_codes),
            'n_controls': hl.agg.count_where(~mt.primary_codes),
            'n_cases_secondary': hl.agg.count_where(mt.secondary_codes),
            'n_cases_cause_of_death': hl.agg.count_where(
                mt.cause_of_death_codes),
            'n_cases_all': hl.agg.count_where(any_code),
            'n_controls_all': hl.agg.count_where(~any_code),
        }

    # First defined value across the both-sexes / female / male annotations.
    pheno_structs = (mt.both_sexes_pheno, mt.females_pheno, mt.males_pheno)
    extra_fields = {
        'Field': hl.coalesce(*[pheno.Field for pheno in pheno_structs]),
    }

    if data_type == 'categorical':
        extra_fields['n_cases'] = hl.agg.count_where(mt.both_sexes)
        extra_fields['n_controls'] = hl.agg.count_where(~mt.both_sexes)
        extra_fields['meaning'] = hl.coalesce(
            *[pheno.meaning for pheno in pheno_structs])
    else:
        extra_fields['n_defined'] = hl.agg.count_where(
            hl.is_defined(mt.both_sexes))
        extra_fields['n_defined_females'] = hl.agg.count_where(
            hl.is_defined(mt.females))
        extra_fields['n_defined_males'] = hl.agg.count_where(
            hl.is_defined(mt.males))
    return extra_fields
def multi_way_union_mts(mts: list, tmp_dir: str, chunk_size: int) -> hl.MatrixTable:
    """Joins MatrixTables in the provided list

    The tables are merged hierarchically: at every stage up to ``chunk_size``
    tables are zip-joined on their row key, their entry arrays are
    concatenated (rows absent from one input get missing entries for that
    input's columns) and the result is checkpointed under ``tmp_dir``.
    Stages repeat until a single table remains.

    :param list mts: list of MatrixTables to join together
    :param str tmp_dir: path to temporary directory for intermediate results
    :param int chunk_size: number of MatrixTables to join per chunk; must be
        at least 2 when more than one MatrixTable is given
    :return: joined MatrixTable
    :rtype: MatrixTable
    :raises ValueError: if ``mts`` is empty, or if ``chunk_size`` is smaller
        than 2 while several MatrixTables have to be merged
    """
    if not mts:
        raise ValueError("multi_way_union_mts: 'mts' must not be empty")
    if len(mts) > 1 and chunk_size < 2:
        # With chunk_size == 1 every job "merges" one input, so the number of
        # staged tables never shrinks and the loop below would never finish;
        # chunk_size == 0 would divide by zero.
        raise ValueError(
            f"multi_way_union_mts: 'chunk_size' must be >= 2, got {chunk_size}")

    staging = [mt.localize_entries("__entries", "__cols") for mt in mts]
    stage = 0
    while len(staging) > 1:
        n_jobs = int(math.ceil(len(staging) / chunk_size))
        info(f"multi_way_union_mts: stage {stage}: {n_jobs} total jobs")
        next_stage = []

        for i in range(n_jobs):
            to_merge = staging[chunk_size * i:chunk_size * (i + 1)]
            info(
                f"multi_way_union_mts: stage {stage} / job {i}: merging {len(to_merge)} inputs"
            )

            merged = hl.Table.multi_way_zip_join(to_merge, "__entries", "__cols")
            # For each input k: use its entries if the row exists there,
            # otherwise one missing entry per column of that input.
            merged = merged.annotate(__entries=hl.flatten(
                hl.range(hl.len(merged.__entries)).map(lambda k: hl.coalesce(
                    merged.__entries[k].__entries,
                    hl.range(hl.len(merged.__cols[k].__cols)).map(
                        lambda j: hl.null(merged.__entries.__entries.dtype.
                                          element_type.element_type)),
                ))))

            # Concatenate the column arrays in the same input order.
            merged = merged.annotate_globals(
                __cols=hl.flatten(merged.__cols.map(lambda x: x.__cols)))

            next_stage.append(
                merged.checkpoint(os.path.join(tmp_dir,
                                               f"stage_{stage}_job_{i}.ht"),
                                  overwrite=True))
        info(f"done stage {stage}")
        stage += 1
        staging.clear()
        staging.extend(next_stage)

    return (staging[0]._unlocalize_entries(
        "__entries", "__cols", list(mts[0].col_key)).unfilter_entries())
def merge_arrays(r_array, v_array):
    """Combine reference and variant entry arrays into one array.

    Every element is reshaped to the field set of ``merged_schema`` (taken
    from the enclosing scope).  If one array is missing the other is used;
    otherwise the arrays are zipped and the variant entry wins when defined.
    """

    def _ref_value(r, name, typ):
        # Reference blocks get fixed LA / call values; other fields are
        # copied when present and typed-missing otherwise.
        if name == 'LA':
            return hl.literal([0])
        if name in ('LGT', 'GT'):
            return hl.call(0, 0)
        return r[name] if name in r else hl.missing(typ)

    def rewrite_ref(r):
        fields = {
            name: _ref_value(r, name, typ)
            for name, typ in merged_schema.items()
        }
        return r.select(**fields)

    def rewrite_var(v):
        fields = {}
        for name, typ in merged_schema.items():
            fields[name] = v[name] if name in v else hl.missing(typ)
        return v.select(**fields)

    def pick(pair):
        # pair is (reference entry, variant entry)
        return hl.coalesce(rewrite_var(pair[1]), rewrite_ref(pair[0]))

    only_variants = v_array.map(rewrite_var)
    only_references = r_array.map(rewrite_ref)
    both = hl.zip(r_array, v_array).map(pick)

    return (hl.case()
            .when(hl.is_missing(r_array), only_variants)
            .when(hl.is_missing(v_array), only_references)
            .default(both))
def compare_doubletons_to_related(tranche_data: Tuple[str, int] = TRANCHE_DATA,
                                  temp_path: str = TEMP_PATH) -> None:
    """
    Compare sample pairs sharing doubletons against the relatedness Table.

    The doubleton pairs are annotated with whether the pair (in either order)
    is present in the relatedness Table, plus its kinship value and
    relationship.  Summary aggregates are logged before and after dropping
    pairs whose relationship is "unrelated".

    :param tranche_data: UKB tranche data (data source and data freeze number). Default is TRANCHE_DATA.
    :param temp_path: Path to bucket to store Table and other temporary data. Default is TEMP_PATH.
    :return: None; function prints information to stdout.
    """

    def _summarize_pairs(pairs_ht: hl.Table) -> hl.expr.StructExpression:
        """
        Aggregate a pairs Table into summary statistics.

        Expects the annotations `rel_def`, `kin` and `relationship`.

        :param hl.Table pairs_ht: Input Table.
        :return: Struct with the count of pairs found in the relatedness
            Table, kinship stats, a counter of relationship types and the
            total number of pairs.
        """
        summary = hl.struct(
            pair_in_relatedness_ht=hl.agg.count_where(pairs_ht.rel_def),
            kin_stats=hl.agg.stats(pairs_ht.kin),
            rel_counter=hl.agg.counter(pairs_ht.relationship),
            total_pairs=hl.agg.count(),
        )
        return pairs_ht.aggregate(summary)

    ht = get_doubleton_samples()

    rel_ht = hl.read_table(relatedness_ht_path(*tranche_data))
    rel_ht = rel_ht.key_by(i=rel_ht.i.s, j=rel_ht.j.s)

    logger.info(
        "Annotating the doubleton sample pairs with relatedness information..."
    )
    # A pair may be stored as (s1, s2) or (s2, s1); look up both orders.
    forward = rel_ht[ht.s1, ht.s2]
    reverse = rel_ht[ht.s2, ht.s1]
    ht = ht.annotate(
        rel_def=hl.is_defined(forward) | hl.is_defined(reverse),
        kin=hl.coalesce(forward.kin, reverse.kin),
        relationship=hl.coalesce(forward.relationship, reverse.relationship),
    )
    ht = ht.checkpoint(f"{temp_path}/doubletons_uniq_rel.ht", overwrite=True)
    ht.show()

    logger.info(
        "Results from HT aggregate before removing 'unrelated' relationships: %s",
        _summarize_pairs(ht),
    )

    ht = ht.filter(ht.relationship != "unrelated")
    logger.info(
        "Results from HT aggregate after removing 'unrelated' and undefined relationships: %s",
        _summarize_pairs(ht),
    )
def get_idx(struct):
    """Map an entry struct to an integer index.

    Returns 0 when ``struct`` is missing, 1 when its ``GT`` alt-allele count
    is missing, and ``2 + n_alt_alleles`` otherwise.
    """
    struct_is_missing = hl.is_missing(struct)
    idx_when_present = hl.coalesce(2 + struct.GT.n_alt_alleles(), 1)
    return hl.cond(struct_is_missing, 0, idx_when_present)
def main(args):
    """Build the phecode case/control MatrixTable from the UKB ICD MatrixTable.

    Reads the Phecode <-> ICD mapping, combines ICD code columns into phecode
    columns, derives sex-based and phecode-based exclusions, and writes the
    result keyed by ``phecode``.  ``args.overwrite`` controls whether existing
    outputs are overwritten.
    """
    hl.init(log='/assign_phecodes.log')

    # Load the Phecode (v1.2b1) <-> ICD 9/10 mapping as strings.
    with hadoop_open(
            'gs://ukb-diverse-pops/phecode/UKB_Phecode_v1.2b1_ICD_Mapping.txt',
            'r') as f:
        df = pd.read_csv(f, delimiter='\t', dtype=str)

    # Comma-separated code lists become Python lists, one per mapping row.
    list_of_icd_codes_to_include = [
        codes.split(',') for codes in df['icd_codes']
    ]
    list_of_phecodes_to_exclude = [
        codes.split(',') for codes in df['exclude_phecodes']
    ]
    df['icd_codes'] = list_of_icd_codes_to_include
    df['exclude_phecodes'] = list_of_phecodes_to_exclude

    # Store the mapping as a Hail Table keyed by its ICD code list.
    phecode_ht = hl.Table.from_pandas(df).key_by('icd_codes')
    phecode_ht = phecode_ht.checkpoint(
        'gs://ukb-diverse-pops/phecode/UKB_Phecode_v1.2b1_ICD_Mapping.ht',
        overwrite=args.overwrite)

    # Combine the UKB ICD columns according to the phecode definitions.
    icd_all = hl.read_matrix_table(get_ukb_pheno_mt_path('icd_all'))
    mt = combine_phenotypes(icd_all,
                            icd_all.icd_code,
                            icd_all.any_codes,
                            list_of_icd_codes_to_include,
                            new_col_name='icd_codes',
                            new_entry_name='include_to_cases')
    mapping_row = phecode_ht[mt.icd_codes]
    mt = mt.annotate_cols(phecode=mapping_row.phecode,
                          phecode_sex=mapping_row.sex,
                          phecode_description=mapping_row.description,
                          phecode_group=mapping_row.group,
                          exclude_phecodes=mapping_row.exclude_phecodes)

    # Sex annotation, needed for sex-specific phenotypes.
    ukb_pheno_ht = hl.read_table(get_ukb_pheno_ht_path())
    mt = mt.annotate_rows(isFemale=ukb_pheno_ht[mt.userId].sex == 0)
    mt = checkpoint_tmp(mt)

    # Phecodes whose cases must be excluded from the controls.
    mt = mt.key_cols_by()
    exclude_mt = combine_phenotypes(mt,
                                    mt.phecode,
                                    mt.include_to_cases,
                                    list_of_phecodes_to_exclude,
                                    new_entry_name='exclude_from_controls')
    exclude_mt = checkpoint_tmp(exclude_mt)

    # Annotate both kinds of exclusion on the entries.
    mt = mt.key_cols_by('exclude_phecodes')
    sex_mismatch = (hl.switch(mt.phecode_sex)
                    .when("males", mt.isFemale)
                    .when("females", ~mt.isFemale)
                    .default(False))
    excluded_control = hl.coalesce(
        exclude_mt[mt.userId, mt.exclude_phecodes].exclude_from_controls,
        False)
    mt = mt.annotate_entries(exclude_sex=sex_mismatch,
                             exclude_from_controls=excluded_control)

    # Final status: missing when the sample is excluded because of sex, or
    # is not a case but is excluded from controls.
    drop_sample = mt.exclude_sex | (~mt.include_to_cases
                                    & mt.exclude_from_controls)
    mt = mt.annotate_entries(case_control=hl.if_else(
        drop_sample, hl.null(hl.tbool), mt.include_to_cases))

    mt = mt.key_cols_by('phecode')
    mt.describe()
    mt.write(get_ukb_pheno_mt_path('phecode'), overwrite=args.overwrite)
sample_file=ukb_sf, index_file_map=file_map, _row_fields=['rsid']) # Extracting SNPs of interest mt_f = hl.filter_intervals(mt, ploci) mt_f = hl.variant_qc(mt_f) chromdat['chrompos'] = chromdat['chrom'] + ':' + chromdat[ 'hg19_pos'].astype(str) chromdat_hl = hl.Table.from_pandas(chromdat) chromdat_hl = chromdat_hl.annotate( locus=hl.parse_locus(chromdat_hl.chrompos, reference_genome='GRCh37')) chromdat_hl = chromdat_hl.key_by('locus') mt_f = mt_f.annotate_rows(**chromdat_hl[mt_f.locus]) flip = hl.case().when(mt_f.ea == mt_f.alleles[0], True).when(mt_f.ea == mt_f.alleles[1], False).or_missing() mt_f = mt_f.annotate_rows(flip=flip) mt_f = mt_f.annotate_rows( prior=2 * hl.if_else(mt_f.flip, mt_f.variant_qc.AF[0], mt_f.variant_qc.AF[1])) mt_f = mt_f.select_entries(G=hl.coalesce( hl.if_else(mt_f.flip, 2 - mt_f.GT.n_alt_alleles(), mt_f.GT.n_alt_alleles()), mt_f.prior)) ## Exporting result output = '/ludc/Home/daniel_c/dva/files/ukbgeno/chrom{}.vcf.bgz'.format(ch) hl.export_vcf(mt_f, output) ## Removing log file logfile = glob.glob('*.log') os.remove(logfile[0])
def to_merged_sparse_mt(vds: 'VariantDataset') -> 'MatrixTable':
    """Creates a single, merged sparse :class:`.MatrixTable` from the split
    :class:`.VariantDataset` representation.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.

    Returns
    -------
    :class:`.MatrixTable`
        Dataset in the merged sparse MatrixTable representation.

    Raises
    ------
    TypeError
        If an entry field exists in both components with different types.
    """
    reference_mt = vds.reference_data
    variant_mt = vds.variant_data

    rht = reference_mt.localize_entries('_ref_entries', '_ref_cols')
    # 'alleles' is dropped from the variant key so both sides join on locus
    vht = variant_mt.localize_entries('_var_entries', '_var_cols').key_by('locus')

    # Union of entry fields: variant fields first, then reference-only ones.
    merged_schema = {e: variant_mt[e].dtype for e in variant_mt.entry}
    for e in reference_mt.entry:
        if e not in merged_schema:
            merged_schema[e] = vds.reference_data[e].dtype
        elif not merged_schema[e] == vds.reference_data[e].dtype:
            raise TypeError(f"cannot unify field {e!r}: {merged_schema[e]}, {vds.reference_data[e].dtype}")

    ht = rht.join(vht, how='outer').drop('_ref_cols')

    def merge_arrays(r_array, v_array):

        def _ref_value(r, name, typ):
            # fixed LA / call values for reference blocks, otherwise copy the
            # field when present or fill with a typed missing value
            if name == 'LA':
                return hl.literal([0])
            if name in ('LGT', 'GT'):
                return hl.call(0, 0)
            return r[name] if name in r else hl.missing(typ)

        def rewrite_ref(r):
            return r.select(**{
                name: _ref_value(r, name, typ)
                for name, typ in merged_schema.items()
            })

        def rewrite_var(v):
            fields = {}
            for name, typ in merged_schema.items():
                fields[name] = v[name] if name in v else hl.missing(typ)
            return v.select(**fields)

        def pick(pair):
            # pair is (reference entry, variant entry); variant wins
            return hl.coalesce(rewrite_var(pair[1]), rewrite_ref(pair[0]))

        return (hl.case()
                .when(hl.is_missing(r_array), v_array.map(rewrite_var))
                .when(hl.is_missing(v_array), r_array.map(rewrite_ref))
                .default(hl.zip(r_array, v_array).map(pick)))

    variant_row_fields = {
        k: ht[k] for k in vds.variant_data.row_value if k != 'alleles'
    }
    ht = ht.select(
        # rows present only in the reference data have no 'alleles'
        alleles=hl.coalesce(ht['alleles'], hl.array([ht['ref_allele']])),
        **variant_row_fields,
        _entries=merge_arrays(ht['_ref_entries'], ht['_var_entries'])
    )
    ht = ht._key_by_assert_sorted('locus', 'alleles')
    col_key = list(vds.variant_data.col_key)
    return ht._unlocalize_entries('_entries', '_var_cols', col_key)
def combine_pheno_files_multi_sex(pheno_file_dict: dict,
                                  cov_ht: hl.Table,
                                  truncated_codes_only: bool = True,
                                  custom_data_categorical: bool = True):
    """Combine per-data-type phenotype MatrixTables into one MatrixTable.

    Each MatrixTable in ``pheno_file_dict`` is reduced to a common set of
    column fields (case counts, ``description``, ``description_more``,
    ``coding_description``, ``category``) and to one entry field per sex
    (``both_sexes``, ``females``, ``males``).  Most branches also re-key the
    columns by ``trait_type``, ``phenocode``, ``pheno_sex``, ``coding`` and
    ``modifier``.  The harmonised tables are checkpointed and unioned by
    columns.

    :param pheno_file_dict: mapping of data type name to MatrixTable
    :param cov_ht: Table indexed by the MatrixTable row key; its fields
        replace the row fields of every MatrixTable
    :param truncated_codes_only: for ICD code tables, keep only 3-character
        codes
    :param custom_data_categorical: not used by this function
    :return: combined MatrixTable with one float entry field per sex
    """
    full_mt: hl.MatrixTable = None
    sexes = ('both_sexes', 'females', 'males')

    for data_type, mt in pheno_file_dict.items():
        mt = mt.select_rows(**cov_ht[mt.row_key])
        print(data_type)
        if data_type == 'phecode':
            mt = mt.key_cols_by(trait_type=data_type,
                                phenocode=mt.phecode,
                                pheno_sex=mt.phecode_sex,
                                coding=NULL_STR_KEY,
                                modifier=NULL_STR_KEY)
            mt = mt.select_cols(**compute_cases_binary(mt.case_control, mt.sex),
                                description=mt.phecode_description,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.phecode_group)
            mt = mt.select_entries(**format_entries(mt.case_control, mt.sex))
        elif data_type == 'prescriptions':

            def format_prescription_name(pheno):
                return pheno.replace(',', '|').replace('/', '_')

            mt = mt.select_entries(
                value=hl.or_else(hl.len(mt.values) > 0, False))
            # One extra column per drug category: True if any drug in the
            # category has a True value.
            mt2 = mt.group_cols_by(
                trait_type=data_type,
                phenocode=format_prescription_name(
                    mt.Drug_Category_and_Indication),
                pheno_sex='both_sexes',
                coding=NULL_STR_KEY,
                modifier=NULL_STR_KEY,
            ).aggregate(value=hl.agg.any(mt.value)).select_cols(
                category=NULL_STR)
            mt = mt.key_cols_by(trait_type=data_type,
                                phenocode=format_prescription_name(
                                    mt.Generic_Name),
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=NULL_STR_KEY)
            mt = mt.select_cols(category=mt.Drug_Category_and_Indication)
            mt = mt.union_cols(mt2)
            mt = mt.select_cols(**compute_cases_binary(mt.value, mt.sex),
                                description=NULL_STR,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.category)
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
        elif data_type == 'custom':
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            # hl.if_else replaces the deprecated hl.cond (same semantics).
            mt = mt.select_cols(**{
                f'n_cases_{sex}': hl.agg.count_where(
                    hl.if_else(mt.trait_type == 'categorical',
                               mt[sex] == 1.0,
                               hl.is_defined(mt[sex])))
                for sex in sexes
            },
                                description=NULL_STR,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.category)
        elif data_type == 'additional':
            mt = mt.key_cols_by(trait_type='continuous',
                                phenocode=mt.pheno,
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=mt.coding)
            mt = mt.select_cols(
                **{
                    f'n_cases_{sex}': hl.agg.count_where(hl.is_defined(mt[sex]))
                    for sex in sexes
                },
                description=hl.coalesce(
                    *[mt[f'{sex}_pheno'].meaning for sex in sexes]),
                description_more=hl.coalesce(
                    *[mt[f'{sex}_pheno'].description for sex in sexes]),
                coding_description=NULL_STR,
                category=NULL_STR)
        elif data_type in ('categorical', 'continuous'):
            mt = mt.key_cols_by(trait_type=data_type,
                                phenocode=hl.str(mt.pheno),
                                pheno_sex='both_sexes',
                                coding=mt.coding
                                if data_type == 'categorical' else NULL_STR_KEY,
                                modifier=NULL_STR_KEY
                                if data_type == 'categorical' else mt.coding)

            def check_func(x):
                return x if data_type == 'categorical' else hl.is_defined(x)

            mt = mt.select_cols(
                **{
                    f'n_cases_{sex}': hl.agg.count_where(check_func(mt[sex]))
                    for sex in sexes
                },
                description=hl.coalesce(
                    *[mt[f'{sex}_pheno'].Field for sex in sexes]),
                description_more=hl.coalesce(
                    *[mt[f'{sex}_pheno'].Notes for sex in sexes]),
                coding_description=hl.coalesce(
                    *[mt[f'{sex}_pheno'].meaning for sex in sexes])
                if data_type == 'categorical' else NULL_STR,
                category=hl.coalesce(
                    *[mt[f'{sex}_pheno'].Path for sex in sexes]))
            mt = mt.select_entries(
                **{sex: hl.float64(mt[sex]) for sex in sexes})
        elif 'icd_code' in list(mt.col_key):
            icd_version = mt.icd_version if 'icd_version' in list(
                mt.col) else ''
            mt = mt.key_cols_by(trait_type=icd_version,
                                phenocode=mt.icd_code,
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=NULL_STR_KEY)
            if truncated_codes_only:
                mt = mt.filter_cols(hl.len(mt.icd_code) == 3)
                # Columns sharing a key are collapsed; when several exist the
                # first one flagged `truncated` is kept.
                mt = mt.collect_cols_by_key()
                mt = mt.annotate_cols(keep=hl.if_else(
                    hl.len(mt.truncated) == 1, 0,
                    hl.zip_with_index(mt.truncated).filter(
                        lambda x: x[1]).map(lambda x: x[0])[0]))
                mt = mt.select_entries(**{x: mt[x][mt.keep] for x in mt.entry})
                mt = mt.select_cols(
                    **{x: mt[x][mt.keep] for x in mt.col_value if x != 'keep'})
            mt = mt.select_cols(**compute_cases_binary(mt.any_codes, mt.sex),
                                description=mt.short_meaning,
                                description_more="truncated: " +
                                hl.str(mt.truncated) if 'truncated' in list(
                                    mt.col_value) else NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.meaning)
            mt = mt.select_entries(**format_entries(mt.any_codes, mt.sex))
        elif data_type == 'icd_first_occurrence':
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            mt = mt.select_cols(**compute_cases_binary(
                hl.is_defined(mt.both_sexes), mt.sex),
                                description=mt.Field,
                                description_more=mt.Notes,
                                coding_description=NULL_STR,
                                category=mt.Path)
        else:  # 'biomarkers', 'activity_monitor'
            mt = mt.key_cols_by(trait_type=mt.trait_type
                                if 'trait_type' in list(mt.col) else data_type,
                                phenocode=hl.str(mt.pheno),
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=NULL_STR_KEY)
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            mt = mt.select_cols(**{
                f'n_cases_{sex}': hl.agg.count_where(hl.is_defined(mt[sex]))
                for sex in sexes
            },
                                description=mt.Field,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.Path)

        # NOTE(review): tempfile.mktemp is deprecated; it is kept because the
        # checkpoint target must be a path that does not exist yet.
        mt = mt.checkpoint(
            tempfile.mktemp(prefix=f'/tmp/{data_type}_', suffix='.mt'))
        if full_mt is None:
            full_mt = mt
        else:
            full_mt = full_mt.union_cols(
                mt,
                row_join_type='outer'
                if data_type == 'prescriptions' else 'inner')

    full_mt = full_mt.unfilter_entries()

    # Here because prescription data was smaller than the others (so need to
    # set the missing samples to 0)
    return full_mt.select_entries(
        **{
            sex: hl.if_else(full_mt.trait_type == 'prescriptions',
                            hl.or_else(full_mt[sex], hl.float64(0.0)),
                            full_mt[sex])
            for sex in sexes
        })
, "feature_truncation" , "intergenic_variant" ] mt4 = kt2.annotate(transcript_canonicals =kt2.vep.transcript_consequences.filter(lambda tc: tc.canonical == 1)) mt4 = mt4.annotate( all_transcript_terms = hl.set(hl.flatten(mt4.transcript_canonicals.map(lambda x: x.consequence_terms))) ) mt4 = mt4.annotate( Consequence = hl.coalesce( # coalesce means take first non-missing hl.literal(consequence_in_severity_order).filter(lambda cnsq: mt4.all_transcript_terms.contains(cnsq) ).head(), mt4.vep.most_severe_consequence ) ) mt4 = mt4.annotate( Gene = hl.if_else(mt4.transcript_canonicals.any(lambda tc: hl.set(tc.consequence_terms).contains(mt4.Consequence)), \ mt4.transcript_canonicals.find(lambda tc: (tc.canonical == 1) & (hl.set(tc.consequence_terms).contains(mt4.Consequence))).gene_symbol, \ mt4.vep.transcript_consequences.find(lambda tc: hl.set(tc.consequence_terms).contains(mt4.Consequence)).gene_symbol), \ hgvsp = hl.if_else(mt4.transcript_canonicals.any(lambda tc: hl.set(tc.consequence_terms).contains(mt4.Consequence)), \ mt4.transcript_canonicals.find(lambda tc: (tc.canonical == 1) & (hl.set(tc.consequence_terms).contains(mt4.Consequence))).hgvsp, \ mt4.vep.transcript_consequences.find(lambda tc: hl.set(tc.consequence_terms).contains(mt4.Consequence)).hgvsp), hgvsc = hl.if_else(mt4.transcript_canonicals.any(lambda tc: hl.set(tc.consequence_terms).contains(mt4.Consequence)), \ mt4.transcript_canonicals.find(lambda tc: (tc.canonical == 1) & (hl.set(tc.consequence_terms).contains(mt4.Consequence))).hgvsc, \
def get_non_missing_field(mt, field_name):
    """Return the first defined ``field_name`` across the per-sex pheno structs.

    Looks up ``mt['<sex>_pheno'][field_name]`` for every entry of ``sexes``
    (taken from the enclosing scope), in that order.
    """
    candidates = []
    for sex in sexes:
        pheno_struct = mt[f'{sex}_pheno']
        candidates.append(pheno_struct[field_name])
    return hl.coalesce(*candidates)
def segment_reference_blocks(ref: 'MatrixTable', intervals: 'Table') -> 'MatrixTable':
    """Segment reference blocks according to a table of intervals.

    Loci that fall outside every interval are dropped.  A reference block
    that begins before an interval but spans it is emitted at the interval's
    start locus, and block ends are clipped to the interval end.

    Note
    ----
    Assumes disjoint intervals which do not span contigs.
    Requires start-inclusive intervals.

    Parameters
    ----------
    ref : :class:`.MatrixTable`
        MatrixTable of reference blocks.
    intervals : :class:`.Table`
        Table of intervals at which to segment reference blocks.

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the interval key type does not match the locus type of ``ref``,
        or if an interval is not start-inclusive or spans contigs.
    """
    interval_field = list(intervals.key)[0]
    if not intervals[interval_field].dtype == hl.tinterval(ref.locus.dtype):
        raise ValueError(
            f"expect intervals to be keyed by intervals of loci matching the VariantDataset:"
            f" found {intervals[interval_field].dtype} / {ref.locus.dtype}")
    intervals = intervals.select(_interval_dup=intervals[interval_field])

    key_interval = intervals[interval_field]
    well_formed = key_interval.includes_start & (
        key_interval.start.contig == key_interval.end.contig)
    if not intervals.aggregate(hl.agg.all(well_formed)):
        raise ValueError("expect intervals to be start-inclusive")

    # One marker row per interval start locus.
    starts = intervals.key_by(_start_locus=key_interval.start)
    starts = starts.annotate(_include_locus=True)

    refl = ref.localize_entries('_ref_entries', '_ref_cols')
    joined = refl.join(starts, how='outer')

    # Tag every row and entry with the contig's index in the reference
    # genome so carried-forward blocks can be matched to the current contig.
    contig_names = ref.locus.dtype.reference_genome.contigs
    contig_idx_map = hl.literal(
        {name: idx for idx, name in enumerate(contig_names)},
        'dict<str, int32>')
    joined = joined.annotate(__contig_idx=contig_idx_map[joined.locus.contig])
    joined = joined.annotate(_ref_entries=joined._ref_entries.map(
        lambda e: e.annotate(__contig_idx=joined.__contig_idx)))

    def _dense_entries(pos):
        def _choose(idx, e):
            # Prefer the entry at this row; otherwise reuse the scanned
            # entry if it is on the same contig and still covers `pos`.
            covers_pos = (e.__contig_idx == joined.__contig_idx) & (e.END >= pos)
            return hl.coalesce(joined._ref_entries[idx],
                               hl.or_missing(covers_pos, e))

        scanned = hl.scan._densify(hl.len(joined._ref_cols),
                                   joined._ref_entries)
        return hl.enumerate(scanned).map(
            lambda idx_and_e: hl.rbind(idx_and_e[0], idx_and_e[1],
                                       _choose).drop('__contig_idx'))

    dense = joined.annotate(dense_ref=hl.or_missing(
        joined._include_locus,
        hl.rbind(joined.locus.position, _dense_entries)))
    dense = dense.filter(dense._include_locus)
    dense = dense.drop('_interval_dup', '_include_locus', '__contig_idx')

    # at this point, 'dense' is a table with dense rows of reference blocks,
    # keyed by locus
    segmented = refl.annotate(
        **{interval_field: intervals[refl.locus]._interval_dup})

    # keep rows inside an interval, except interval starts (those rows come
    # from the 'dense' table)
    inside_interval = hl.is_defined(segmented[interval_field])
    not_interval_start = segmented.locus != segmented[interval_field].start
    segmented = segmented.filter(inside_interval & not_interval_start)

    # union dense interval starts with filtered table
    segmented = segmented.union(dense.transmute(_ref_entries=dense.dense_ref))

    # rewrite reference blocks to end at the first of (interval end,
    # reference block end)
    segmented = segmented.annotate(
        interval_end=segmented[interval_field].end.position -
        ~segmented[interval_field].includes_end)
    segmented = segmented.annotate(_ref_entries=segmented._ref_entries.map(
        lambda entry: entry.annotate(
            END=hl.min(entry.END, segmented.interval_end))))

    return segmented._unlocalize_entries('_ref_entries', '_ref_cols',
                                         list(ref.col_key))