Example #1
import hail as hl  # all examples below assume Hail is imported as hl

def mwzj_hts_by_tree(all_hts,
                     temp_dir,
                     globals_for_col_key,
                     debug=False,
                     inner_mode='overwrite',
                     repartition_final: int = None):
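    # Tree join: break the n input tables into ~sqrt(n) chunks of ~sqrt(n) tables,
    # checkpoint each chunk's join, then zip-join the checkpoints in a second pass.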
    chunk_size = int(len(all_hts)**0.5) + 1
    outer_hts = []

    checkpoint_kwargs = {inner_mode: True}
    if repartition_final is not None:
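        # get_n_even_intervals is a helper from ukb_common (the adapted version
        # below calls it as ukb_common.get_n_even_intervals)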
        intervals = get_n_even_intervals(repartition_final)
        checkpoint_kwargs['_intervals'] = intervals

    if debug: print(f'Running chunk size {chunk_size}...')
    for i in range(chunk_size):
        if i * chunk_size >= len(all_hts): break
        hts = all_hts[i * chunk_size:(i + 1) * chunk_size]
        if debug:
            print(
                f'Going from {i * chunk_size} to {(i + 1) * chunk_size} ({len(hts)} HTs)...'
            )
        try:
            if isinstance(hts[0], str):
                hts = list(map(hl.read_table, hts))
            ht = hl.Table.multi_way_zip_join(hts, 'row_field_name',
                                             'global_field_name')
        except Exception:
            if debug:
                print(
                    f'problem in range {i * chunk_size}-{i * chunk_size + chunk_size}'
                )
                _ = [ht.describe() for ht in hts]
            raise
        outer_hts.append(
            ht.checkpoint(f'{temp_dir}/temp_output_{i}.ht',
                          **checkpoint_kwargs))
    ht = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer',
                                     'global_field_name_outer')
    ht = ht.transmute(inner_row=hl.flatmap(
        lambda i: hl.cond(
            hl.is_missing(ht.row_field_name_outer[i].row_field_name),
            hl.range(0, hl.len(ht.global_field_name_outer[i].global_field_name))
            .map(lambda _: hl.null(ht.row_field_name_outer[i].row_field_name.dtype.element_type)),
            ht.row_field_name_outer[i].row_field_name),
        hl.range(hl.len(ht.global_field_name_outer))))
    ht = ht.transmute_globals(inner_global=hl.flatmap(
        lambda x: x.global_field_name, ht.global_field_name_outer))
    mt = ht._unlocalize_entries('inner_row', 'inner_global',
                                globals_for_col_key)
    return mt
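
A minimal usage sketch (hedged: the bucket paths and the 'phenotype_id' global are hypothetical, and the input tables are assumed to share a row key and carry the column-key fields in their globals). With 100 tables, chunk_size is 11, so this writes 10 intermediate checkpoints before the final join:

ht_paths = [f'gs://my-bucket/results/pheno_{i}.ht' for i in range(100)]
mt = mwzj_hts_by_tree(ht_paths,
                      temp_dir='gs://my-bucket/tmp/mwzj',
                      globals_for_col_key=['phenotype_id'],
                      debug=True)
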
def mwzj_hts_by_tree(all_hts, temp_dir, globals_for_col_key,
                     debug=False, inner_mode='overwrite', repartition_final: int = None,
                     read_if_exists=False):
    r'''
    Adapted from ukb_common mwzj_hts_by_tree()
    Uses read_clump_ht() instead of read_table()
    '''
    chunk_size = int(len(all_hts) ** 0.5) + 1
    outer_hts = []
    
    if read_if_exists: print('\n\nWARNING: Intermediate tables will not be overwritten if they already exist\n\n')
    
    checkpoint_kwargs = {inner_mode: not read_if_exists,
                         '_read_if_exists': read_if_exists}
    if repartition_final is not None:
        intervals = ukb_common.get_n_even_intervals(repartition_final)
        checkpoint_kwargs['_intervals'] = intervals
    
    if debug: print(f'Running chunk size {chunk_size}...')
    for i in range(chunk_size):
        if i * chunk_size >= len(all_hts): break
        hts = all_hts[i * chunk_size:(i + 1) * chunk_size]
        if debug: print(f'Going from {i * chunk_size} to {(i + 1) * chunk_size} ({len(hts)} HTs)...')
        try:
            if isinstance(hts[0], str):
                def read_clump_ht(f):
                    ht = hl.read_table(f)
                    ht = ht.drop('idx')
                    return ht
                hts = list(map(read_clump_ht, hts))
            ht = hl.Table.multi_way_zip_join(hts, 'row_field_name', 'global_field_name')
        except Exception:
            if debug:
                print(f'problem in range {i * chunk_size}-{i * chunk_size + chunk_size}')
                _ = [ht.describe() for ht in hts]
            raise
        outer_hts.append(ht.checkpoint(f'{temp_dir}/temp_output_{i}.ht', **checkpoint_kwargs))
    ht = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer', 'global_field_name_outer')
    ht = ht.transmute(inner_row=hl.flatmap(lambda i:
                                           hl.cond(hl.is_missing(ht.row_field_name_outer[i].row_field_name),
                                                   hl.range(0, hl.len(ht.global_field_name_outer[i].global_field_name))
                                                   .map(lambda _: hl.null(ht.row_field_name_outer[i].row_field_name.dtype.element_type)),
                                                   ht.row_field_name_outer[i].row_field_name),
                                           hl.range(hl.len(ht.global_field_name_outer))))
    ht = ht.transmute_globals(inner_global=hl.flatmap(lambda x: x.global_field_name, ht.global_field_name_outer))
    mt = ht._unlocalize_entries('inner_row', 'inner_global', globals_for_col_key)
    return mt
def resume_mwzj(temp_dir, globals_for_col_key):
    r'''
    For resuming multiway zip join if intermediate tables have already been written
    '''
    ls = hl.hadoop_ls(temp_dir)
    paths = [x['path'] for x in ls if 'temp_output' in x['path']]
    n_chunks = len(paths)  # number of temp_output_*.ht chunk tables written earlier
    outer_hts = []
    for i in range(n_chunks):
        outer_hts.append(hl.read_table(f'{temp_dir}/temp_output_{i}.ht'))
    ht = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer', 'global_field_name_outer')
    ht = ht.transmute(inner_row=hl.flatmap(lambda i:
                                           hl.cond(hl.is_missing(ht.row_field_name_outer[i].row_field_name),
                                                   hl.range(0, hl.len(ht.global_field_name_outer[i].global_field_name))
                                                   .map(lambda _: hl.null(ht.row_field_name_outer[i].row_field_name.dtype.element_type)),
                                                   ht.row_field_name_outer[i].row_field_name),
                                           hl.range(hl.len(ht.global_field_name_outer))))
    ht = ht.transmute_globals(inner_global=hl.flatmap(lambda x: x.global_field_name, ht.global_field_name_outer))
    mt = ht._unlocalize_entries('inner_row', 'inner_global', globals_for_col_key)
    return mt
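
A resumption sketch (hedged; paths are hypothetical): if a job died after the temp_output_*.ht checkpoints were written, resume_mwzj rebuilds the final MatrixTable from those files alone; the read_if_exists flag on the adapted mwzj_hts_by_tree above achieves the same without a separate entry point.

mt = resume_mwzj(temp_dir='gs://my-bucket/tmp/mwzj',
                 globals_for_col_key=['phenotype_id'])
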
def phase_diploid_proband(
        locus: hl.expr.LocusExpression,
        alleles: hl.expr.ArrayExpression,
        proband_call: hl.expr.CallExpression,
        father_call: hl.expr.CallExpression,
        mother_call: hl.expr.CallExpression
) -> hl.expr.ArrayExpression:
    """
    Returns phased genotype calls in the case of a diploid proband
    (autosomes, PAR regions of sex chromosomes or non-PAR regions of a female proband)

    :param LocusExpression locus: Locus in the trio MatrixTable
    :param ArrayExpression alleles: Alleles in the trio MatrixTable
    :param CallExpression proband_call: Input proband genotype call
    :param CallExpression father_call: Input father genotype call
    :param CallExpression mother_call: Input mother genotype call
    :return: Array containing: phased proband call, phased father call, phased mother call
    :rtype: ArrayExpression
    """

    proband_v = proband_call.one_hot_alleles(alleles)
    father_v = hl.cond(
        locus.in_x_nonpar() | locus.in_y_nonpar(),
        hl.or_missing(father_call.is_haploid(), hl.array([father_call.one_hot_alleles(alleles)])),
        call_to_one_hot_alleles_array(father_call, alleles)
    )
    mother_v = call_to_one_hot_alleles_array(mother_call, alleles)

    combinations = hl.flatmap(
        lambda f:
        hl.zip_with_index(mother_v)
            .filter(lambda m: m[1] + f[1] == proband_v)
            .map(lambda m: hl.struct(m=m[0], f=f[0])),
        hl.zip_with_index(father_v)
    )

    return (
        hl.or_missing(
            hl.is_defined(combinations) & (hl.len(combinations) == 1),
            hl.array([
                hl.call(father_call[combinations[0].f], mother_call[combinations[0].m], phased=True),
                hl.cond(father_call.is_haploid(), hl.call(father_call[0], phased=True), phase_parent_call(father_call, combinations[0].f)),
                phase_parent_call(mother_call, combinations[0].m)
            ])
        )
    )
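
A minimal application sketch (hedged; inputs are hypothetical, and the helpers call_to_one_hot_alleles_array and phase_parent_call used above are assumed to be defined in the same module). A trio MatrixTable from hl.trio_matrix exposes the proband_entry/father_entry/mother_entry fields used here:

pedigree = hl.Pedigree.read('gs://my-bucket/trios.fam')  # hypothetical path
trio_mt = hl.trio_matrix(mt, pedigree, complete_trios=True)
trio_mt = trio_mt.annotate_entries(
    phased_calls=phase_diploid_proband(trio_mt.locus, trio_mt.alleles,
                                       trio_mt.proband_entry.GT,
                                       trio_mt.father_entry.GT,
                                       trio_mt.mother_entry.GT))
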
Example #6
def annotate_fields(mt, gencode_release, gencode_path):
    genotypes = hl.agg.collect(
        hl.struct(sample_id=mt.s,
                  gq=mt.GQ,
                  cn=mt.RD_CN,
                  num_alt=hl.if_else(hl.is_defined(mt.GT),
                                     mt.GT.n_alt_alleles(), -1)))
    rows = mt.annotate_rows(genotypes=genotypes).rows()

    rows = rows.annotate(**{k: v(rows) for k, v in CORE_FIELDS.items()})

    gene_id_mapping = hl.literal(
        load_gencode(gencode_release, download_path=gencode_path))

    protein_coding_cols = [
        gene_col for gene_col in rows.info
        if gene_col.startswith('PROTEIN_CODING__')
        and rows.info[gene_col].dtype == hl.dtype('array<str>')
    ]
    rows = rows.annotate(
        sortedTranscriptConsequences=hl.flatmap(
            lambda x: x,
            hl.filter(lambda x: hl.is_defined(x), [
                rows.info[col].map(lambda gene: hl.struct(
                    gene_symbol=gene,
                    gene_id=gene_id_mapping[gene],
                    predicted_consequence=col.split('__')[-1]))
                for col in protein_coding_cols
            ])),
        sv_type=rows.alleles[1].replace('[<>]', '').split(':', 2),
    )

    DERIVED_FIELDS.update({
        'filters':
        lambda rows: hl.if_else(
            hl.len(rows.filters) > 0, rows.filters,
            hl.missing(hl.dtype('array<str>')))
    })
    rows = rows.annotate(**{k: v(rows) for k, v in DERIVED_FIELDS.items()})

    rows = rows.rename({'rsid': 'variantId'})

    return rows.key_by().select(*FIELDS)
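
A hedged invocation sketch (paths and the release value are hypothetical; CORE_FIELDS, DERIVED_FIELDS, FIELDS and load_gencode are module-level definitions the function depends on):

mt = hl.import_vcf('gs://my-bucket/sv_callset.vcf.bgz', reference_genome='GRCh38')
rows = annotate_fields(mt, gencode_release=29, gencode_path='gs://my-bucket/gencode')
rows.write('gs://my-bucket/sv_annotated.ht')
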
Example #7
def load_icd_data(pre_phesant_data_path,
                  icd_codings_path,
                  temp_directory,
                  force_overwrite_intermediate: bool = False,
                  include_dates: bool = False,
                  icd9: bool = False):
    """
    Load raw (pre-PHESANT) phenotype data and extract ICD codes into hail MatrixTable with booleans as entries

    :param str pre_phesant_data_path: Input phenotype file
    :param str icd_codings_path: Input coding metadata
    :param str temp_directory: Temp bucket/directory to write intermediate file
    :param bool force_overwrite_intermediate: Whether to overwrite intermediate loaded file
    :param bool include_dates: Whether to also load date data (not implemented yet)
    :param bool icd9: Whether to load ICD9 data
    :return: MatrixTable with ICD codes
    :rtype: MatrixTable
    """
    if icd9:
        code_locations = {'primary_codes': '41203', 'secondary_codes': '41205'}
    else:
        code_locations = {
            'primary_codes': '41202',
            'secondary_codes': '41204',
            'external_codes': '41201',
            'cause_of_death_codes': '40001'
        }
    date_locations = {'primary_codes': '41262'}
    ht = hl.import_table(pre_phesant_data_path,
                         impute=not icd9,
                         min_partitions=100,
                         missing='',
                         key='userId',
                         types={'userId': hl.tint32})
    ht = ht.checkpoint(f'{temp_directory}/pre_phesant.ht',
                       _read_if_exists=not force_overwrite_intermediate)
    all_phenos = list(ht.row_value)
    fields_to_select = {
        code: [ht[x] for x in all_phenos if x.startswith(f'x{loc}')]
        for code, loc in code_locations.items()
    }
    if include_dates:
        fields_to_select.update({
            f'date_{code}':
            [ht[x] for x in all_phenos if x.startswith(f'x{loc}')]
            for code, loc in date_locations.items()
        })
    ht = ht.select(**fields_to_select)
    ht = ht.annotate(
        **{
            code: ht[code].filter(lambda x: hl.is_defined(x))
            for code in code_locations
        },
        # **{f'date_{code}': ht[code].filter(lambda x: hl.is_defined(x)) for code in date_locations}
    )
    # ht = ht.annotate(primary_codes_with_date=hl.dict(hl.zip(ht.primary_codes, ht.date_primary_codes)))
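    # Aggregate the set of codes seen in each coding field (localized to Python),
    # then flatten, deduplicate and sort into one array of all observed ICD codes.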
    all_codes = hl.sorted(hl.array(hl.set(hl.flatmap(
        lambda x: hl.array(x),
        ht.aggregate(
            [hl.agg.explode(lambda c: hl.agg.collect_as_set(c), ht[code])
             for code in code_locations],
            _localize=True)))))
    ht = ht.select(bool_codes=all_codes.map(lambda x: hl.struct(
        **{code: ht[code].contains(x)
           for code in code_locations})))
    ht = ht.annotate_globals(
        all_codes=all_codes.map(lambda x: hl.struct(icd_code=x)))
    mt = ht._unlocalize_entries('bool_codes', 'all_codes', ['icd_code'])
    mt = mt.annotate_entries(
        any_codes=hl.any(lambda x: x, list(mt.entry.values())))
    # mt = mt.annotate_entries(date=hl.cond(mt.primary_codes, mt.primary_codes_with_date[mt.icd_code], hl.null(hl.tstr)))
    mt = mt.annotate_cols(truncated=False).annotate_globals(
        code_locations=code_locations)
    mt = mt.checkpoint(f'{temp_directory}/raw_icd.mt',
                       _read_if_exists=not force_overwrite_intermediate)
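    # Build truncated (3-character) versions of the codes, OR-ing entries across
    # all finer-grained codes sharing the prefix, and append them as extra columns.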
    trunc_mt = mt.filter_cols((hl.len(mt.icd_code) == 3)
                              | (hl.len(mt.icd_code) == 4))
    trunc_mt = trunc_mt.key_cols_by(icd_code=trunc_mt.icd_code[:3])
    trunc_mt = trunc_mt.group_cols_by('icd_code').aggregate_entries(
        **{
            code: hl.agg.any(trunc_mt[code])
            for code in list(code_locations.keys()) + ['any_codes']
        }).aggregate_cols(n_phenos_truncated=hl.agg.count()).result()
    trunc_mt = trunc_mt.filter_cols(trunc_mt.n_phenos_truncated > 1)
    trunc_mt = trunc_mt.annotate_cols(
        **mt.cols().drop('truncated', 'code_locations')[trunc_mt.icd_code],
        truncated=True).drop('n_phenos_truncated')
    mt = mt.union_cols(trunc_mt)
    coding_ht = hl.read_table(icd_codings_path)
    return mt.annotate_cols(**coding_ht[mt.col_key])
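
A hedged call sketch (paths are hypothetical): build the ICD MatrixTable once, then use the any_codes entry field to pull cases for a given (truncated or full) code.

mt = load_icd_data('gs://my-bucket/ukb/pre_phesant.tsv.bgz',
                   'gs://my-bucket/ukb/icd10_codings.ht',
                   temp_directory='gs://my-bucket/tmp/icd')
e10 = mt.filter_cols(mt.icd_code == 'E10')  # one truncated ICD-10 code
case_ids = e10.filter_entries(e10.any_codes).entries().userId.collect()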