def mwzj_hts_by_tree(all_hts, temp_dir, globals_for_col_key, debug=False, inner_mode='overwrite',
                     repartition_final: int = None):
    """
    Multi-way zip join many Tables in two levels and return the result as a MatrixTable.

    The inputs are split into chunks of ``int(sqrt(len(all_hts))) + 1`` tables. Each chunk is joined
    with ``hl.Table.multi_way_zip_join`` and checkpointed to ``{temp_dir}/temp_output_{i}.ht``. The
    checkpointed chunk tables are then joined again, the two nesting levels are flattened into
    ``inner_row`` / ``inner_global``, and the table is converted with ``_unlocalize_entries``.

    :param all_hts: Sequence of Tables, or of paths to Tables. A chunk whose first element is a
        ``str`` is treated as a list of paths and read with ``hl.read_table``.
    :param str temp_dir: Directory under which the per-chunk intermediate Tables are checkpointed
    :param globals_for_col_key: Third argument to ``_unlocalize_entries`` (presumably the column key
        fields -- confirm against the Hail version in use)
    :param bool debug: Whether to print progress and, on failure, describe the tables of the failing chunk
    :param str inner_mode: Name of the ``checkpoint`` keyword that is set to True for the intermediates
    :param int repartition_final: If not None, intermediates are checkpointed with ``_intervals`` set to
        ``get_n_even_intervals(repartition_final)``
    :return: Result of ``_unlocalize_entries`` on the flattened join
    :raises ValueError: If ``all_hts`` is empty
    """
    n_hts = len(all_hts)
    if n_hts == 0:
        # Fail fast with a clear message rather than erroring inside the final join on an empty list.
        raise ValueError('mwzj_hts_by_tree requires at least one Table in all_hts')

    chunk_size = int(n_hts ** 0.5) + 1
    outer_hts = []

    checkpoint_kwargs = {inner_mode: True}
    if repartition_final is not None:
        intervals = get_n_even_intervals(repartition_final)
        checkpoint_kwargs['_intervals'] = intervals

    if debug:
        print(f'Running chunk size {chunk_size}...')

    for i in range(chunk_size):
        if i * chunk_size >= n_hts:
            break
        hts = all_hts[i * chunk_size:(i + 1) * chunk_size]
        if debug:
            print(f'Going from {i * chunk_size} to {(i + 1) * chunk_size} ({len(hts)} HTs)...')
        try:
            if isinstance(hts[0], str):
                hts = [hl.read_table(path) for path in hts]
            ht = hl.Table.multi_way_zip_join(hts, 'row_field_name', 'global_field_name')
        except Exception:
            if debug:
                print(f'problem in range {i * chunk_size}-{i * chunk_size + chunk_size}')
                # If reading failed, hts still holds path strings; describing those would raise
                # AttributeError and mask the real error, so only describe actual Tables.
                for failed in hts:
                    if isinstance(failed, hl.Table):
                        failed.describe()
            raise
        outer_hts.append(ht.checkpoint(f'{temp_dir}/temp_output_{i}.ht', **checkpoint_kwargs))

    joined = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer', 'global_field_name_outer')

    def chunk_entries(i):
        # Row entries contributed by chunk i. When the chunk has no record for this row, substitute
        # an array of missing values as long as that chunk's global array.
        chunk_row = joined.row_field_name_outer[i].row_field_name
        n_chunk_globals = hl.len(joined.global_field_name_outer[i].global_field_name)
        padding = hl.range(0, n_chunk_globals).map(lambda _: hl.null(chunk_row.dtype.element_type))
        return hl.cond(hl.is_missing(chunk_row), padding, chunk_row)

    with_rows = joined.transmute(
        inner_row=hl.flatmap(chunk_entries, hl.range(hl.len(joined.global_field_name_outer))))
    flattened = with_rows.transmute_globals(
        inner_global=hl.flatmap(lambda x: x.global_field_name, with_rows.global_field_name_outer))
    return flattened._unlocalize_entries('inner_row', 'inner_global', globals_for_col_key)
def mwzj_hts_by_tree(all_hts, temp_dir, globals_for_col_key, debug=False, inner_mode='overwrite',
                     repartition_final: int = None, read_if_exists=False):
    """
    Multi-way zip join many Tables in two levels and return the result as a MatrixTable.

    Adapted from ukb_common mwzj_hts_by_tree(). Path inputs are read with a local reader that
    drops the ``idx`` field instead of plain ``hl.read_table``.

    The inputs are split into chunks of ``int(sqrt(len(all_hts))) + 1`` tables. Each chunk is joined
    with ``hl.Table.multi_way_zip_join`` and checkpointed to ``{temp_dir}/temp_output_{i}.ht``. The
    checkpointed chunk tables are then joined again, the two nesting levels are flattened into
    ``inner_row`` / ``inner_global``, and the table is converted with ``_unlocalize_entries``.

    :param all_hts: Sequence of Tables, or of paths to Tables. A chunk whose first element is a
        ``str`` is treated as a list of paths.
    :param str temp_dir: Directory under which the per-chunk intermediate Tables are checkpointed
    :param globals_for_col_key: Third argument to ``_unlocalize_entries`` (presumably the column key
        fields -- confirm against the Hail version in use)
    :param bool debug: Whether to print progress and, on failure, describe the tables of the failing chunk
    :param str inner_mode: Name of the ``checkpoint`` keyword that is set to ``not read_if_exists``
    :param int repartition_final: If not None, intermediates are checkpointed with ``_intervals`` set to
        ``ukb_common.get_n_even_intervals(repartition_final)``
    :param bool read_if_exists: Whether to reuse intermediates that already exist instead of rewriting them
    :return: Result of ``_unlocalize_entries`` on the flattened join
    :raises ValueError: If ``all_hts`` is empty
    """
    n_hts = len(all_hts)
    if n_hts == 0:
        # Fail fast with a clear message rather than erroring inside the final join on an empty list.
        raise ValueError('mwzj_hts_by_tree requires at least one Table in all_hts')

    chunk_size = int(n_hts ** 0.5) + 1
    outer_hts = []

    if read_if_exists:
        print('\n\nWARNING: Intermediate tables will not be overwritten if they already exist\n\n')
    checkpoint_kwargs = {inner_mode: not read_if_exists,
                         '_read_if_exists': read_if_exists}
    if repartition_final is not None:
        intervals = ukb_common.get_n_even_intervals(repartition_final)
        checkpoint_kwargs['_intervals'] = intervals

    def read_clump_ht(f):
        # Read a table from path f and drop its 'idx' field before joining.
        clump_ht = hl.read_table(f)
        return clump_ht.drop('idx')

    if debug:
        print(f'Running chunk size {chunk_size}...')

    for i in range(chunk_size):
        if i * chunk_size >= n_hts:
            break
        hts = all_hts[i * chunk_size:(i + 1) * chunk_size]
        if debug:
            print(f'Going from {i * chunk_size} to {(i + 1) * chunk_size} ({len(hts)} HTs)...')
        try:
            if isinstance(hts[0], str):
                hts = [read_clump_ht(path) for path in hts]
            ht = hl.Table.multi_way_zip_join(hts, 'row_field_name', 'global_field_name')
        except Exception:
            if debug:
                print(f'problem in range {i * chunk_size}-{i * chunk_size + chunk_size}')
                # If reading failed, hts still holds path strings; describing those would raise
                # AttributeError and mask the real error, so only describe actual Tables.
                for failed in hts:
                    if isinstance(failed, hl.Table):
                        failed.describe()
            raise
        outer_hts.append(ht.checkpoint(f'{temp_dir}/temp_output_{i}.ht', **checkpoint_kwargs))

    joined = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer', 'global_field_name_outer')

    def chunk_entries(i):
        # Row entries contributed by chunk i. When the chunk has no record for this row, substitute
        # an array of missing values as long as that chunk's global array.
        chunk_row = joined.row_field_name_outer[i].row_field_name
        n_chunk_globals = hl.len(joined.global_field_name_outer[i].global_field_name)
        padding = hl.range(0, n_chunk_globals).map(lambda _: hl.null(chunk_row.dtype.element_type))
        return hl.cond(hl.is_missing(chunk_row), padding, chunk_row)

    with_rows = joined.transmute(
        inner_row=hl.flatmap(chunk_entries, hl.range(hl.len(joined.global_field_name_outer))))
    flattened = with_rows.transmute_globals(
        inner_global=hl.flatmap(lambda x: x.global_field_name, with_rows.global_field_name_outer))
    return flattened._unlocalize_entries('inner_row', 'inner_global', globals_for_col_key)
def resume_mwzj(temp_dir, globals_for_col_key):
    """
    Resume a multi-way zip join from intermediate tables that have already been written.

    Counts the entries of ``temp_dir`` whose path contains ``temp_output``, reads
    ``{temp_dir}/temp_output_{i}.ht`` for ``i`` from 0 up to that count, joins them and flattens
    the two nesting levels exactly like the final stage of ``mwzj_hts_by_tree``.

    Because the tables are read by index, the intermediates must be numbered contiguously from 0.

    :param str temp_dir: Directory containing the ``temp_output_{i}.ht`` intermediate tables
    :param globals_for_col_key: Third argument to ``_unlocalize_entries`` (presumably the column key
        fields -- confirm against the Hail version in use)
    :return: Result of ``_unlocalize_entries`` on the flattened join
    :raises ValueError: If no path containing ``temp_output`` is found in ``temp_dir``
    """
    listing = hl.hadoop_ls(temp_dir)
    paths = [entry['path'] for entry in listing if 'temp_output' in entry['path']]
    n_chunks = len(paths)
    if n_chunks == 0:
        # Nothing to resume from; say so explicitly instead of failing inside the join.
        raise ValueError(f'No temp_output intermediate tables found in {temp_dir}')

    outer_hts = [hl.read_table(f'{temp_dir}/temp_output_{i}.ht') for i in range(n_chunks)]
    joined = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer', 'global_field_name_outer')

    def chunk_entries(i):
        # Row entries contributed by chunk i. When the chunk has no record for this row, substitute
        # an array of missing values as long as that chunk's global array.
        chunk_row = joined.row_field_name_outer[i].row_field_name
        n_chunk_globals = hl.len(joined.global_field_name_outer[i].global_field_name)
        padding = hl.range(0, n_chunk_globals).map(lambda _: hl.null(chunk_row.dtype.element_type))
        return hl.cond(hl.is_missing(chunk_row), padding, chunk_row)

    with_rows = joined.transmute(
        inner_row=hl.flatmap(chunk_entries, hl.range(hl.len(joined.global_field_name_outer))))
    flattened = with_rows.transmute_globals(
        inner_global=hl.flatmap(lambda x: x.global_field_name, with_rows.global_field_name_outer))
    return flattened._unlocalize_entries('inner_row', 'inner_global', globals_for_col_key)
def phase_diploid_proband(
        locus: hl.expr.LocusExpression,
        alleles: hl.expr.ArrayExpression,
        proband_call: hl.expr.CallExpression,
        father_call: hl.expr.CallExpression,
        mother_call: hl.expr.CallExpression
) -> hl.expr.ArrayExpression:
    """
    Phase a trio's genotype calls when the proband is diploid at this locus
    (autosomes, PAR regions of sex chromosomes or non-PAR regions of a female proband).

    The result is missing unless exactly one (father allele, mother allele) pairing
    reproduces the proband's one-hot allele vector.

    :param LocusExpression locus: Locus in the trio MatrixTable
    :param ArrayExpression alleles: Alleles in the trio MatrixTable
    :param CallExpression proband_call: Input proband genotype call
    :param CallExpression father_call: Input father genotype call
    :param CallExpression mother_call: Input mother genotype call
    :return: Array containing: phased proband call, phased father call, phased mother call
    :rtype: ArrayExpression
    """
    proband_v = proband_call.one_hot_alleles(alleles)

    # In non-PAR regions of X/Y the father contributes a single vector, and only if his call is haploid.
    in_sex_nonpar = locus.in_x_nonpar() | locus.in_y_nonpar()
    haploid_father_v = hl.or_missing(
        father_call.is_haploid(),
        hl.array([father_call.one_hot_alleles(alleles)]))
    diploid_father_v = call_to_one_hot_alleles_array(father_call, alleles)
    father_v = hl.cond(in_sex_nonpar, haploid_father_v, diploid_father_v)
    mother_v = call_to_one_hot_alleles_array(mother_call, alleles)

    def pairings_for_father_allele(f):
        # f is an (index, vector) pair from the father. Keep the mother (index, vector) pairs
        # whose vector summed with the father's equals the proband vector.
        def sums_to_proband(m):
            return m[1] + f[1] == proband_v

        def as_index_pair(m):
            return hl.struct(m=m[0], f=f[0])

        return hl.zip_with_index(mother_v).filter(sums_to_proband).map(as_index_pair)

    combinations = hl.flatmap(pairings_for_father_allele, hl.zip_with_index(father_v))

    is_unambiguous = hl.is_defined(combinations) & (hl.len(combinations) == 1)
    chosen = combinations[0]

    # Proband call is built as (allele from father, allele from mother).
    phased_proband = hl.call(father_call[chosen.f], mother_call[chosen.m], phased=True)
    phased_father = hl.cond(
        father_call.is_haploid(),
        hl.call(father_call[0], phased=True),
        phase_parent_call(father_call, chosen.f))
    phased_mother = phase_parent_call(mother_call, chosen.m)

    return hl.or_missing(is_unambiguous, hl.array([phased_proband, phased_father, phased_mother]))
def annotate_fields(mt, gencode_release, gencode_path):
    """
    Build the row-level output table from ``mt``.

    Collects per-sample genotype structs onto each row, applies the ``CORE_FIELDS`` and
    ``DERIVED_FIELDS`` annotation builders, adds ``sortedTranscriptConsequences`` and ``sv_type``,
    renames ``rsid`` to ``variantId`` and returns the un-keyed table restricted to ``FIELDS``.

    :param mt: Input MatrixTable; the code reads ``s``, ``GQ``, ``RD_CN``, ``GT`` and the row
        fields ``info``, ``alleles``, ``filters`` and ``rsid``
    :param gencode_release: Passed through to ``load_gencode``
    :param gencode_path: Passed to ``load_gencode`` as ``download_path``
    :return: Un-keyed Table with the fields listed in ``FIELDS``
    """
    # -1 marks samples whose GT is not defined.
    num_alt = hl.if_else(hl.is_defined(mt.GT), mt.GT.n_alt_alleles(), -1)
    sample_struct = hl.struct(sample_id=mt.s, gq=mt.GQ, cn=mt.RD_CN, num_alt=num_alt)
    rows = mt.annotate_rows(genotypes=hl.agg.collect(sample_struct)).rows()

    core_annotations = {name: build(rows) for name, build in CORE_FIELDS.items()}
    rows = rows.annotate(**core_annotations)

    gene_id_mapping = hl.literal(load_gencode(gencode_release, download_path=gencode_path))

    # info fields named PROTEIN_CODING__<consequence> that hold arrays of gene symbols
    gene_cols = [
        name for name in rows.info
        if name.startswith('PROTEIN_CODING__') and rows.info[name].dtype == hl.dtype('array<str>')
    ]

    def consequences_for(col):
        # One struct per gene symbol in this info field; the consequence is the text after the last '__'.
        return rows.info[col].map(
            lambda gene: hl.struct(gene_symbol=gene,
                                   gene_id=gene_id_mapping[gene],
                                   predicted_consequence=col.split('__')[-1]))

    per_column = [consequences_for(col) for col in gene_cols]
    defined_columns = hl.filter(lambda x: hl.is_defined(x), per_column)
    rows = rows.annotate(
        sortedTranscriptConsequences=hl.flatmap(lambda x: x, defined_columns),
        sv_type=rows.alleles[1].replace('[<>]', '').split(':', 2),
    )

    def filters_or_missing(r):
        # An empty filters collection is stored as missing rather than as an empty array.
        return hl.if_else(hl.len(r.filters) > 0, r.filters, hl.missing(hl.dtype('array<str>')))

    # NOTE(review): this mutates the module-level DERIVED_FIELDS in place on every call.
    DERIVED_FIELDS['filters'] = filters_or_missing
    derived_annotations = {name: build(rows) for name, build in DERIVED_FIELDS.items()}
    rows = rows.annotate(**derived_annotations)

    rows = rows.rename({'rsid': 'variantId'})
    return rows.key_by().select(*FIELDS)
def load_icd_data(pre_phesant_data_path, icd_codings_path, temp_directory,
                  force_overwrite_intermediate: bool = False, include_dates: bool = False, icd9: bool = False):
    """
    Load raw (pre-PHESANT) phenotype data and extract ICD codes into hail MatrixTable with booleans as entries

    Columns are ICD codes. In addition to the codes observed in the data, codes of length 3 or 4
    are grouped by their first three characters; groups with more than one member are appended as
    extra columns with ``truncated=True``.

    :param str pre_phesant_data_path: Input phenotype file
    :param str icd_codings_path: Input coding metadata
    :param str temp_directory: Temp bucket/directory to write intermediate file
    :param bool force_overwrite_intermediate: Whether to overwrite intermediate loaded file
    :param bool include_dates: Whether to also load date data (not implemented yet)
    :param bool icd9: Whether to load ICD9 data
    :return: MatrixTable with ICD codes
    :rtype: MatrixTable
    """
    if icd9:
        code_locations = {'primary_codes': '41203', 'secondary_codes': '41205'}
    else:
        code_locations = {
            'primary_codes': '41202',
            'secondary_codes': '41204',
            'external_codes': '41201',
            'cause_of_death_codes': '40001'
        }
    date_locations = {'primary_codes': '41262'}

    # Reuse existing intermediates unless forced. When forced, overwrite must be set explicitly,
    # otherwise the checkpoint write fails if the intermediate already exists.
    checkpoint_kwargs = {
        'overwrite': force_overwrite_intermediate,
        '_read_if_exists': not force_overwrite_intermediate,
    }

    ht = hl.import_table(pre_phesant_data_path, impute=not icd9, min_partitions=100, missing='',
                         key='userId', types={'userId': hl.tint32})
    ht = ht.checkpoint(f'{temp_directory}/pre_phesant.ht', **checkpoint_kwargs)

    # Gather, per code category, every column whose name starts with x<field id>.
    all_phenos = list(ht.row_value)
    fields_to_select = {
        code: [ht[x] for x in all_phenos if x.startswith(f'x{loc}')]
        for code, loc in code_locations.items()
    }
    if include_dates:
        fields_to_select.update({
            f'date_{code}': [ht[x] for x in all_phenos if x.startswith(f'x{loc}')]
            for code, loc in date_locations.items()
        })
    ht = ht.select(**fields_to_select)
    ht = ht.annotate(
        **{code: ht[code].filter(lambda x: hl.is_defined(x)) for code in code_locations},
        # **{f'date_{code}': ht[code].filter(lambda x: hl.is_defined(x)) for code in date_locations}
    )
    # ht = ht.annotate(primary_codes_with_date=hl.dict(hl.zip(ht.primary_codes, ht.date_primary_codes)))

    # Sorted, de-duplicated list of every code seen in any category.
    per_category_codes = ht.aggregate(
        [hl.agg.explode(lambda c: hl.agg.collect_as_set(c), ht[code]) for code in code_locations],
        _localize=True)
    all_codes = hl.sorted(hl.array(hl.set(hl.flatmap(lambda x: hl.array(x), per_category_codes))))

    # One struct of per-category booleans for each code, then pivot codes into columns.
    ht = ht.select(bool_codes=all_codes.map(
        lambda x: hl.struct(**{code: ht[code].contains(x) for code in code_locations})))
    ht = ht.annotate_globals(all_codes=all_codes.map(lambda x: hl.struct(icd_code=x)))
    mt = ht._unlocalize_entries('bool_codes', 'all_codes', ['icd_code'])
    mt = mt.annotate_entries(any_codes=hl.any(lambda x: x, list(mt.entry.values())))
    # mt = mt.annotate_entries(date=hl.cond(mt.primary_codes, mt.primary_codes_with_date[mt.icd_code], hl.null(hl.tstr)))
    mt = mt.annotate_cols(truncated=False).annotate_globals(code_locations=code_locations)
    mt = mt.checkpoint(f'{temp_directory}/raw_icd.mt', **checkpoint_kwargs)

    # Group 3- and 4-character codes by their first three characters; keep groups with > 1 member.
    trunc_mt = mt.filter_cols((hl.len(mt.icd_code) == 3) | (hl.len(mt.icd_code) == 4))
    trunc_mt = trunc_mt.key_cols_by(icd_code=trunc_mt.icd_code[:3])
    trunc_mt = trunc_mt.group_cols_by('icd_code').aggregate_entries(
        **{code: hl.agg.any(trunc_mt[code]) for code in list(code_locations.keys()) + ['any_codes']}
    ).aggregate_cols(n_phenos_truncated=hl.agg.count()).result()
    trunc_mt = trunc_mt.filter_cols(trunc_mt.n_phenos_truncated > 1)
    trunc_mt = trunc_mt.annotate_cols(
        **mt.cols().drop('truncated', 'code_locations')[trunc_mt.icd_code],
        truncated=True).drop('n_phenos_truncated')

    mt = mt.union_cols(trunc_mt)
    coding_ht = hl.read_table(icd_codings_path)
    return mt.annotate_cols(**coding_ht[mt.col_key])