示例#1
0
    def test_window_by_locus(self):
        """Check the rows and entries attached by ``hl.window_by_locus``.

        A 100 x 2 matrix is keyed by consecutive loci on contig '1' and
        windowed with a window size of 5. The test asserts that:

        * every row with ``row_idx >= 5`` has exactly five previous rows;
        * ``prev_rows[i]`` is the row ``i + 1`` positions earlier;
        * every previous entry comes from the same column as the current entry;
        * ``prev_entries[i]`` comes from the row ``i + 1`` positions earlier.
        """
        n_rows, n_cols, bp_window = 100, 2, 5

        matrix = hl.utils.range_matrix_table(n_rows, n_cols, n_partitions=10)
        matrix = matrix.annotate_rows(locus=hl.locus('1', matrix.row_idx + 1))
        matrix = matrix.key_rows_by('locus')
        # Copy the indices into the entries so they survive into prev_entries.
        matrix = matrix.annotate_entries(e_row_idx=matrix.row_idx,
                                         e_col_idx=matrix.col_idx)
        matrix = hl.window_by_locus(matrix, bp_window).cache()

        self.assertEqual(matrix.count_rows(), n_rows)

        row_table = matrix.rows()
        # Only the first `bp_window` rows are allowed a shorter window.
        self.assertTrue(row_table.all(
            (row_table.row_idx < bp_window)
            | (row_table.prev_rows.length() == bp_window)))
        # prev_rows is ordered nearest-first.
        self.assertTrue(row_table.all(hl.all(
            lambda pair: (row_table.row_idx - 1 - pair[0]) == pair[1].row_idx,
            hl.zip_with_index(row_table.prev_rows))))

        entry_table = matrix.entries()
        self.assertTrue(entry_table.all(hl.all(
            lambda prev: prev.e_col_idx == entry_table.col_idx,
            entry_table.prev_entries)))
        self.assertTrue(entry_table.all(hl.all(
            lambda pair: entry_table.row_idx - 1 - pair[0] == pair[1].e_row_idx,
            hl.zip_with_index(entry_table.prev_entries))))
示例#2
0
    def phase_haploid_proband_x_nonpar(
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
            mother_call: hl.expr.CallExpression) -> hl.expr.ArrayExpression:
        """
        Returns phased genotype calls in the case of a haploid proband in the non-PAR region of X

        The proband's only allele is searched for among the mother's first two
        alleles. The first matching maternal allele is taken as transmitted.
        When none matches, the trio cannot be phased.

        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :param CallExpression mother_call: Input mother genotype call
        :return: Array containing: phased proband call, phased father call
            (missing unless the father call is haploid), phased mother call.
            The whole array is missing when no maternal allele matches the proband.
        :rtype: ArrayExpression
        """
        maternal_alleles = hl.array([mother_call[0], mother_call[1]])
        # (index, allele) of the first maternal allele equal to the proband's allele.
        # Missing when there is no match.
        match = hl.zip_with_index(maternal_alleles).find(
            lambda indexed: indexed[1] == proband_call[0])

        phased_proband = hl.call(proband_call[0], phased=True)
        phased_father = hl.or_missing(father_call.is_haploid(),
                                      hl.call(father_call[0], phased=True))
        # phase_parent_call is defined elsewhere. It receives the index of the
        # transmitted maternal allele.
        phased_mother = phase_parent_call(mother_call, match[0])

        return hl.or_missing(hl.is_defined(match),
                             hl.array([phased_proband, phased_father, phased_mother]))
示例#3
0
    def phase_diploid_proband(
            locus: hl.expr.LocusExpression,
            alleles: hl.expr.ArrayExpression,
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
            mother_call: hl.expr.CallExpression
    ) -> hl.expr.ArrayExpression:
        """
        Returns phased genotype calls in the case of a diploid proband
        (autosomes, PAR regions of sex chromosomes or non-PAR regions of a female proband)

        All (paternal allele, maternal allele) pairs are enumerated whose one-hot
        allele vectors add up to the proband's one-hot allele vector. Phasing is
        only performed when exactly one such pair exists. Otherwise the result is
        missing.

        :param LocusExpression locus: Locus in the trio MatrixTable
        :param ArrayExpression alleles: Alleles in the trio MatrixTable
        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :param CallExpression mother_call: Input mother genotype call
        :return: Array containing: phased proband call, phased father call, phased mother call;
            missing when the trio cannot be phased unambiguously
        :rtype: ArrayExpression
        """

        proband_v = proband_call.one_hot_alleles(alleles)
        # In the non-PAR regions of X/Y the father contributes one allele.
        # He is only usable for phasing there when his call is haploid;
        # otherwise father_v is missing and so is the result.
        father_v = hl.if_else(
            locus.in_x_nonpar() | locus.in_y_nonpar(),
            hl.or_missing(father_call.is_haploid(), hl.array([father_call.one_hot_alleles(alleles)])),
            call_to_one_hot_alleles_array(father_call, alleles)
        )
        # call_to_one_hot_alleles_array is defined elsewhere. It presumably
        # yields one one-hot vector per allele of the call; confirm there.
        mother_v = call_to_one_hot_alleles_array(mother_call, alleles)

        # Every (mother allele index m, father allele index f) pair that explains the proband call.
        combinations = hl.flatmap(
            lambda f:
            hl.zip_with_index(mother_v)
                .filter(lambda m: m[1] + f[1] == proband_v)
                .map(lambda m: hl.struct(m=m[0], f=f[0])),
            hl.zip_with_index(father_v)
        )

        return (
            hl.or_missing(
                hl.is_defined(combinations) & (hl.len(combinations) == 1),
                hl.array([
                    hl.call(father_call[combinations[0].f], mother_call[combinations[0].m], phased=True),
                    hl.if_else(father_call.is_haploid(), hl.call(father_call[0], phased=True), phase_parent_call(father_call, combinations[0].f)),
                    phase_parent_call(mother_call, combinations[0].m)
                ])
            )
        )
示例#4
0
    def phase_diploid_proband(
            locus: hl.expr.LocusExpression,
            alleles: hl.expr.ArrayExpression,
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
            mother_call: hl.expr.CallExpression
    ) -> hl.expr.ArrayExpression:
        """
        Returns phased genotype calls in the case of a diploid proband
        (autosomes, PAR regions of sex chromosomes or non-PAR regions of a female proband)

        All (paternal allele, maternal allele) pairs are enumerated whose one-hot
        allele vectors add up to the proband's one-hot allele vector. Phasing is
        only performed when exactly one such pair exists. Otherwise the result is
        missing.

        :param LocusExpression locus: Locus in the trio MatrixTable
        :param ArrayExpression alleles: Alleles in the trio MatrixTable
        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :param CallExpression mother_call: Input mother genotype call
        :return: Array containing: phased proband call, phased father call, phased mother call;
            missing when the trio cannot be phased unambiguously
        :rtype: ArrayExpression
        """

        proband_v = proband_call.one_hot_alleles(alleles)
        # In the non-PAR regions of X/Y the father contributes one allele.
        # He is only usable for phasing there when his call is haploid;
        # otherwise father_v is missing and so is the result.
        father_v = hl.if_else(
            locus.in_x_nonpar() | locus.in_y_nonpar(),
            hl.or_missing(father_call.is_haploid(), hl.array([father_call.one_hot_alleles(alleles)])),
            call_to_one_hot_alleles_array(father_call, alleles)
        )
        # call_to_one_hot_alleles_array is defined elsewhere. It presumably
        # yields one one-hot vector per allele of the call; confirm there.
        mother_v = call_to_one_hot_alleles_array(mother_call, alleles)

        # Every (mother allele index m, father allele index f) pair that explains the proband call.
        combinations = hl.flatmap(
            lambda f:
            hl.zip_with_index(mother_v)
                .filter(lambda m: m[1] + f[1] == proband_v)
                .map(lambda m: hl.struct(m=m[0], f=f[0])),
            hl.zip_with_index(father_v)
        )

        return (
            hl.or_missing(
                hl.is_defined(combinations) & (hl.len(combinations) == 1),
                hl.array([
                    hl.call(father_call[combinations[0].f], mother_call[combinations[0].m], phased=True),
                    hl.if_else(father_call.is_haploid(), hl.call(father_call[0], phased=True), phase_parent_call(father_call, combinations[0].f)),
                    phase_parent_call(mother_call, combinations[0].m)
                ])
            )
        )
示例#5
0
def separate_results_mt_by_pop(mt):
    """Split per-population result arrays into one column per population.

    Each column's ``pheno_data`` array is exploded so that every element gets
    its own column. Each entry's ``summary_stats`` array is then reduced to the
    element at that column's position in the original ``pheno_data`` array.

    :param mt: MatrixTable with an array column field ``pheno_data`` and an
        array entry field ``summary_stats``. The two are presumably aligned by
        position; confirm with the producer.
    :return: MatrixTable with one ``pheno_data`` element per column and the
        matching ``summary_stats`` element per entry.
    """
    exploded = mt.annotate_cols(
        pheno_data=hl.zip_with_index(mt.pheno_data)
    ).explode_cols('pheno_data')
    # After exploding, each pheno_data value is an (index, element) tuple.
    exploded = exploded.annotate_cols(
        pop_index=exploded.pheno_data[0],
        pheno_data=exploded.pheno_data[1],
    )
    selected = exploded.annotate_entries(
        summary_stats=exploded.summary_stats[exploded.pop_index])
    return selected.drop('pop_index')
示例#6
0
文件: test_misc.py 项目: danking/hail
    def test_window_by_locus(self):
        """Verify ``hl.window_by_locus`` on a 100-row, 2-column toy matrix.

        Rows are keyed by consecutive loci on contig '1' and windowed with size 5.
        Previous rows/entries must be complete past the first five rows, ordered
        nearest-first, and (for entries) drawn from the same column.
        """
        base = hl.utils.range_matrix_table(100, 2, n_partitions=10)
        keyed = base.annotate_rows(locus=hl.locus('1', base.row_idx + 1)).key_rows_by('locus')
        tagged = keyed.annotate_entries(e_row_idx=keyed.row_idx, e_col_idx=keyed.col_idx)
        windowed = hl.window_by_locus(tagged, 5).cache()

        self.assertEqual(windowed.count_rows(), 100)

        rows = windowed.rows()
        row_checks = [
            (rows.row_idx < 5) | (rows.prev_rows.length() == 5),
            hl.all(lambda x: (rows.row_idx - 1 - x[0]) == x[1].row_idx,
                   hl.zip_with_index(rows.prev_rows)),
        ]
        for check in row_checks:
            self.assertTrue(rows.all(check))

        entries = windowed.entries()
        entry_checks = [
            hl.all(lambda x: x.e_col_idx == entries.col_idx, entries.prev_entries),
            hl.all(lambda x: entries.row_idx - 1 - x[0] == x[1].e_row_idx,
                   hl.zip_with_index(entries.prev_entries)),
        ]
        for check in entry_checks:
            self.assertTrue(entries.all(check))
示例#7
0
def sum_mcnv_ac_or_af(alts, values):
    """Sum ``values`` over every alternate allele except ``<CN=2>``.

    :param alts: Array expression of alternate-allele strings.
    :param values: Array expression of per-alt values (e.g. AC or AF),
        presumably aligned with ``alts``.
    :return: Fold-sum (starting from 0) of ``values`` with the element at the
        first ``<CN=2>`` position removed. If ``<CN=2>`` is absent, every
        element is summed.
    """
    def _sum_without(cn2_index):
        # Drop the slot of the copy-number-2 allele when it exists.
        kept = hl.if_else(
            hl.is_defined(cn2_index),
            values[0:cn2_index].extend(values[cn2_index + 1:]),
            values,
        )
        return hl.bind(lambda to_sum: to_sum.fold(lambda acc, n: acc + n, 0), kept)

    # find() yields a missing tuple when nothing matches, so the index is missing too.
    cn2_position = hl.zip_with_index(alts).find(lambda t: t[1] == "<CN=2>")[0]
    return hl.bind(_sum_without, cn2_position)
示例#8
0
def explode_lambda_ht(ht, by='ac'):
    """Explode a lambda table along its ``{by}_cutoffs`` array.

    Only the ``sumstats_qc`` fields whose names contain ``_{by}`` are kept.
    The table then gets one row per cutoff:

    * the cutoff value goes into a top-level field named ``by``;
    * each kept ``sumstats_qc`` field becomes a top-level field holding its
      element at the cutoff's position.

    ``sumstats_qc`` and the temporary ``index_ac`` field are removed.

    :param ht: Table with a struct field ``sumstats_qc`` and an array field
        ``{by}_cutoffs``.
    :param by: Name of the cutoff dimension (default ``'ac'``).
    :return: The exploded Table.
    """
    marker = f'_{by}'
    kept_fields = [name for name in ht.sumstats_qc.keys() if marker in name]
    result = ht.annotate(sumstats_qc=ht.sumstats_qc.select(*kept_fields))
    result = result.annotate(
        index_ac=hl.zip_with_index(result[f'{by}_cutoffs'])
    ).explode('index_ac')

    position = result.index_ac[0]
    per_field = {name: result.sumstats_qc[name][position] for name in result.sumstats_qc}
    return result.transmute(**{by: result.index_ac[1]}, **per_field)
示例#9
0
def separate_results_mt_by_pop(mt,
                               col_field='pheno_data',
                               entry_field='summary_stats',
                               skip_drop: bool = False):
    """Split per-population result arrays into one column per population.

    :param mt: MatrixTable whose ``col_field`` column field and ``entry_field``
        entry field are arrays, presumably aligned by position.
    :param col_field: Name of the array column field to explode.
    :param entry_field: Name of the array entry field to index per column.
    :param skip_drop: When True, keep the ``pop_index`` column field recording
        each column's position in the original ``col_field`` array.
    :return: MatrixTable with one ``col_field`` element per column and the
        matching ``entry_field`` element per entry.
    """
    indexed = mt.annotate_cols(col_array=hl.zip_with_index(mt[col_field]))
    indexed = indexed.explode_cols('col_array')
    # Unpack the (index, element) tuple and drop the temporary col_array.
    indexed = indexed.transmute_cols(pop_index=indexed.col_array[0],
                                     **{col_field: indexed.col_array[1]})
    result = indexed.annotate_entries(
        **{entry_field: indexed[entry_field][indexed.pop_index]})
    return result if skip_drop else result.drop('pop_index')
示例#10
0
def load_final_sumstats_mt(filter_phenos: bool = True,
                           filter_variants: bool = True,
                           filter_sumstats: bool = True,
                           separate_columns_by_pop: bool = True,
                           annotate_with_nearest_gene: bool = True):
    """Load the full variant-results MatrixTable with QC annotations and optional filters.

    The table always gets the variant-QC table fields on its rows and the
    lambda table fields on its columns. Then, in this order:

    1. ``filter_phenos``: keep only ``pheno_data`` elements whose
       ``lambda_gc`` passes ``filter_lambda_gc`` and the matching
       ``summary_stats`` elements, and drop columns left with no phenotypes.
       The kept positions remain in a ``pheno_indices`` column field.
    2. ``filter_sumstats``: set low-confidence summary stats to missing and
       drop entries whose summary stats all have a missing ``Pvalue``.
    3. ``filter_variants``: keep only rows with ``high_quality`` set.
    4. ``annotate_with_nearest_gene``: apply ``annotate_nearest_gene``.
    5. ``separate_columns_by_pop``: apply ``separate_results_mt_by_pop``.

    :return: The resulting MatrixTable.
    """
    results = hl.read_matrix_table(get_variant_results_path('full', 'mt'))
    variant_qc = hl.read_table(get_variant_results_qc_path())
    results = results.annotate_rows(**variant_qc[results.row_key])
    lambdas = hl.read_table(
        get_analysis_data_path('lambda', 'lambdas', 'full', 'ht'))
    results = results.annotate_cols(**lambdas[results.col_key])

    if filter_phenos:
        passing = hl.zip_with_index(results.pheno_data).filter(
            lambda pair: filter_lambda_gc(pair[1].lambda_gc))
        results = results.annotate_cols(
            pheno_indices=passing.map(lambda pair: pair[0]),
            pheno_data=passing.map(lambda pair: pair[1]))
        # Keep the summary_stats elements at the same positions as the kept phenotypes.
        results = results.annotate_entries(
            summary_stats=hl.zip_with_index(results.summary_stats)
            .filter(lambda pair: results.pheno_indices.contains(pair[0]))
            .map(lambda pair: pair[1]))
        results = results.filter_cols(hl.len(results.pheno_data) > 0)

    if filter_sumstats:
        results = results.annotate_entries(
            summary_stats=results.summary_stats.map(
                lambda stats: hl.or_missing(~stats.low_confidence, stats)))
        results = results.filter_entries(
            ~results.summary_stats.all(lambda stats: hl.is_missing(stats.Pvalue)))

    if filter_variants:
        results = results.filter_rows(results.high_quality)

    if annotate_with_nearest_gene:
        results = annotate_nearest_gene(results)

    return separate_results_mt_by_pop(results) if separate_columns_by_pop else results
def total_ac_or_af(variant, field):
    """Return the total allele count (or frequency) of a variant.

    For ``MCNV`` variants, every element of ``field`` is summed except the one
    at the position of the ``<CN=2>`` alternate allele (if present). For all
    other variant types, ``field[0]`` is returned.

    Uses ``hl.if_else``, the current name of the deprecated ``hl.cond``.

    :param variant: Struct expression with a ``type`` string and an ``alts``
        array of alternate-allele strings.
    :param field: Array expression of per-alt values (e.g. AC or AF),
        presumably aligned with ``variant.alts``.
    :return: Numeric expression with the total.
    """
    return hl.if_else(
        variant.type == "MCNV",
        hl.bind(
            lambda cn2_index: hl.bind(
                lambda values_to_sum: values_to_sum.fold(lambda acc, n: acc + n, 0),
                # Remove the copy-number-2 slot when it exists.
                hl.if_else(
                    hl.is_defined(cn2_index),
                    field[0:cn2_index].extend(field[cn2_index + 1 :]),
                    field,
                ),
            ),
            # Missing when no alt is "<CN=2>".
            hl.zip_with_index(variant.alts).find(lambda t: t[1] == "<CN=2>")[0],
        ),
        field[0],
    )
示例#12
0
    def phase_haploid_proband_x_nonpar(
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
            mother_call: hl.expr.CallExpression
    ) -> hl.expr.ArrayExpression:
        """
        Phase a trio whose proband is haploid in the non-PAR region of X.

        The first of the mother's two alleles that equals the proband's allele
        is considered transmitted. Without such an allele the result is missing.

        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :param CallExpression mother_call: Input mother genotype call
        :return: Array containing: phased proband call, phased father call
            (missing unless the father call is haploid), phased mother call
        :rtype: ArrayExpression
        """
        proband_allele = proband_call[0]

        def _matches_proband(indexed_allele):
            # indexed_allele is an (index, allele) tuple.
            return indexed_allele[1] == proband_allele

        maternal = hl.zip_with_index(hl.array([mother_call[0], mother_call[1]]))
        transmitted = maternal.find(_matches_proband)

        father_phased = hl.or_missing(
            father_call.is_haploid(),
            hl.call(father_call[0], phased=True),
        )
        trio_calls = hl.array([
            hl.call(proband_allele, phased=True),
            father_phased,
            phase_parent_call(mother_call, transmitted[0]),
        ])
        return hl.or_missing(hl.is_defined(transmitted), trio_calls)
示例#13
0
def explode_trio_matrix(tm: hl.MatrixTable, col_keys: List[str] = ['s'], keep_trio_cols: bool = True, keep_trio_entries: bool = False) -> hl.MatrixTable:
    """Splits a trio MatrixTable back into a sample MatrixTable.

    Example
    -------
    >>> # Create a trio matrix from a sample matrix
    >>> pedigree = hl.Pedigree.read('data/case_control_study.fam')
    >>> trio_dataset = hl.trio_matrix(dataset, pedigree, complete_trios=True)

    >>> # Explode trio matrix back into a sample matrix
    >>> exploded_trio_dataset = explode_trio_matrix(trio_dataset)

    Notes
    -----
    Each trio column becomes three sample columns, in the order proband, father, mother.
    The resulting MatrixTable column schema is the same as the proband/father/mother schema,
    and the resulting entry schema is the same as the proband_entry/father_entry/mother_entry schema.
    If the `keep_trio_cols` option is set, then an additional `source_trio` column field is added with the trio column data.
    If the `keep_trio_entries` option is set, then an additional `source_trio_entry` entry field is added with the trio entry data.

    Note
    ----
    This assumes that the input MatrixTable is a trio MatrixTable (similar to the result of :meth:`.methods.trio_matrix`).
    Its entry schema has to contain `proband_entry`, `father_entry` and `mother_entry`, all with the same type.
    Its column schema has to contain `proband`, `father` and `mother`, all with the same type.

    Parameters
    ----------
    tm : :class:`.MatrixTable`
        Trio MatrixTable (entries have to be a Struct with `proband_entry`, `mother_entry` and `father_entry` present)
    col_keys : :obj:`list` of str
        Column key(s) for the resulting sample MatrixTable; if empty, the result has no column key
    keep_trio_cols: bool
        Whether to add a `source_trio` column field with the trio column data (default `True`)
    keep_trio_entries: bool
        Whether to add a `source_trio_entry` entry field with the trio entry data (default `False`)

    Returns
    -------
    :class:`.MatrixTable`
        Sample MatrixTable"""

    # Pack the three per-member entries into one array (requires identical types).
    select_entries_expr = {'__trio_entries': hl.array([tm.proband_entry, tm.father_entry, tm.mother_entry])}
    if keep_trio_entries:
        select_entries_expr['source_trio_entry'] = hl.struct(**tm.entry)
    tm = tm.select_entries(**select_entries_expr)

    # Unkey first so select_cols keeps only the fields listed below.
    tm = tm.key_cols_by()
    # (member index, member struct): the index selects the matching entry after exploding.
    select_cols_expr = {'__trio_members': hl.zip_with_index(hl.array([tm.proband, tm.father, tm.mother]))}
    if keep_trio_cols:
        select_cols_expr['source_trio'] = hl.struct(**tm.col)
    tm = tm.select_cols(**select_cols_expr)

    mt = tm.explode_cols(tm.__trio_members)

    mt = mt.transmute_entries(
        **mt.__trio_entries[mt.__trio_members[0]]
    )

    mt = mt.key_cols_by()
    mt = mt.transmute_cols(**mt.__trio_members[1])

    if col_keys:
        mt = mt.key_cols_by(*col_keys)

    return mt
示例#14
0
def explode_trio_matrix(tm: hl.MatrixTable, col_keys: List[str] = ['s']) -> hl.MatrixTable:
    """Splits a trio MatrixTable back into a sample MatrixTable.

    Example
    -------
    >>> # Create a trio matrix from a sample matrix
    >>> pedigree = hl.Pedigree.read('data/case_control_study.fam')
    >>> trio_dataset = hl.trio_matrix(dataset, pedigree, complete_trios=True)

    >>> # Explode trio matrix back into a sample matrix
    >>> exploded_trio_dataset = explode_trio_matrix(trio_dataset)

    Notes
    -----
    This assumes that the input MatrixTable is a trio MatrixTable (similar to the result of :meth:`.methods.trio_matrix`).
    In particular, it should have the following entry schema (all with the same type):

    - proband_entry
    - father_entry
    - mother_entry

    And the following column schema (all with the same type):

    - proband
    - father
    - mother

    Each trio column becomes three sample columns, in the order proband, father, mother.

    Note
    ----
    Only the fields inside `proband_entry`, `father_entry` and `mother_entry` are kept as entry fields;
    all other trio entry fields are dropped.
    Only the fields inside `proband`, `father` and `mother` are kept as column fields;
    all other trio column fields (including the original column key) are dropped.

    Parameters
    ----------
    tm : :class:`.MatrixTable`
        Trio MatrixTable (entries have to be a Struct with `proband_entry`, `mother_entry` and `father_entry` present)
    col_keys : :obj:`list` of str
        Column key(s) for the resulting sample MatrixTable; if empty, the result has no column key

    Returns
    -------
    :class:`.MatrixTable`
        Sample MatrixTable"""

    # Pack the three per-member entries into one array (requires identical types).
    tm = tm.select_entries(
        __trio_entries=hl.array([tm.proband_entry, tm.father_entry, tm.mother_entry])
    )

    # (member index, member struct): the index selects the matching entry after exploding.
    tm = tm.select_cols(
        __trio_members=hl.zip_with_index(hl.array([tm.proband, tm.father, tm.mother]))
    )
    mt = tm.explode_cols(tm.__trio_members)

    mt = mt.select_entries(
        **mt.__trio_entries[mt.__trio_members[0]]
    )

    # Unkey so select_cols drops the original trio column key.
    mt = mt.key_cols_by()
    mt = mt.select_cols(**mt.__trio_members[1])

    if col_keys:
        mt = mt.key_cols_by(*col_keys)

    return mt
示例#15
0
def combine_pheno_files_multi_sex(pheno_file_dict: dict,
                                  cov_ht: hl.Table,
                                  truncated_codes_only: bool = True,
                                  custom_data_categorical: bool = True):
    """Combine per-data-type phenotype MatrixTables into a single MatrixTable.

    Processing per input:

    * Its row fields are replaced by ``cov_ht`` fields.
    * Most branches re-key its columns to
      ``(trait_type, phenocode, pheno_sex, coding, modifier)``. The ``custom``
      and ``icd_first_occurrence`` branches do not re-key; presumably those
      inputs are already keyed this way (confirm with the producer).
    * Its columns are reduced to per-sex case counts plus ``description``,
      ``description_more``, ``coding_description`` and ``category``.
    * It is checkpointed to a temporary path.
    * It is column-unioned into the result.

    :param pheno_file_dict: Mapping from data-type name (``'phecode'``,
        ``'prescriptions'``, ``'custom'``, ``'additional'``, ``'categorical'``,
        ``'continuous'``, ``'icd_first_occurrence'``, ...) to its phenotype
        MatrixTable. Inputs with an ``icd_code`` column-key field, and any
        other type, go to the last two branches.
    :param cov_ht: Table looked up by each input's row key; its fields replace
        the input's non-key row fields.
    :param truncated_codes_only: For ICD inputs, keep only 3-character codes
        and collapse columns sharing a key to a single column.
    :param custom_data_categorical: NOTE(review): not referenced in this body.
    :return: Combined MatrixTable whose entry fields are ``both_sexes``,
        ``females`` and ``males``. Missing values in prescription columns are
        replaced with 0.0.
    """
    full_mt: hl.MatrixTable = None
    sexes = ('both_sexes', 'females', 'males')

    for data_type, mt in pheno_file_dict.items():
        mt = mt.select_rows(**cov_ht[mt.row_key])
        # Progress output.
        print(data_type)
        if data_type == 'phecode':
            mt = mt.key_cols_by(trait_type=data_type,
                                phenocode=mt.phecode,
                                pheno_sex=mt.phecode_sex,
                                coding=NULL_STR_KEY,
                                modifier=NULL_STR_KEY)
            mt = mt.select_cols(**compute_cases_binary(mt.case_control,
                                                       mt.sex),
                                description=mt.phecode_description,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.phecode_group)
            mt = mt.select_entries(**format_entries(mt.case_control, mt.sex))
        elif data_type == 'prescriptions':

            def format_prescription_name(pheno):
                return pheno.replace(',', '|').replace('/', '_')

            # A sample is a case for a drug when its `values` array is non-empty
            # (a missing array counts as False).
            mt = mt.select_entries(
                value=hl.or_else(hl.len(mt.values) > 0, False))
            # Category-level phenotypes: a case for any drug in the category.
            mt2 = mt.group_cols_by(
                trait_type=data_type,
                phenocode=format_prescription_name(
                    mt.Drug_Category_and_Indication),
                pheno_sex='both_sexes',
                coding=NULL_STR_KEY,
                modifier=NULL_STR_KEY,
            ).aggregate(value=hl.agg.any(mt.value)).select_cols(
                category=NULL_STR)
            # Drug-level phenotypes, keyed by generic name.
            mt = mt.key_cols_by(trait_type=data_type,
                                phenocode=format_prescription_name(
                                    mt.Generic_Name),
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=NULL_STR_KEY)
            mt = mt.select_cols(category=mt.Drug_Category_and_Indication)
            mt = mt.union_cols(mt2)
            mt = mt.select_cols(**compute_cases_binary(mt.value, mt.sex),
                                description=NULL_STR,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.category)
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
        elif data_type == 'custom':
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            # Categorical traits count entries equal to 1.0; other traits count defined entries.
            mt = mt.select_cols(**{
                f'n_cases_{sex}': hl.agg.count_where(
                    hl.cond(mt.trait_type == 'categorical', mt[sex] == 1.0,
                            hl.is_defined(mt[sex])))
                for sex in sexes
            },
                                description=NULL_STR,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.category)
        elif data_type == 'additional':
            # The input's `coding` becomes the `modifier` key field.
            mt = mt.key_cols_by(trait_type='continuous',
                                phenocode=mt.pheno,
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=mt.coding)
            mt = mt.select_cols(
                **{
                    f'n_cases_{sex}':
                    hl.agg.count_where(hl.is_defined(mt[sex]))
                    for sex in sexes
                },
                description=hl.coalesce(
                    *[mt[f'{sex}_pheno'].meaning for sex in sexes]),
                description_more=hl.coalesce(
                    *[mt[f'{sex}_pheno'].description for sex in sexes]),
                coding_description=NULL_STR,
                category=NULL_STR)
        elif data_type in ('categorical', 'continuous'):
            # `coding` fills the `coding` key for categorical traits and the
            # `modifier` key for continuous ones.
            mt = mt.key_cols_by(trait_type=data_type,
                                phenocode=hl.str(mt.pheno),
                                pheno_sex='both_sexes',
                                coding=mt.coding if data_type == 'categorical'
                                else NULL_STR_KEY,
                                modifier=NULL_STR_KEY
                                if data_type == 'categorical' else mt.coding)

            # Categorical entries are counted by truthiness, continuous ones by definedness.
            def check_func(x):
                return x if data_type == 'categorical' else hl.is_defined(x)

            mt = mt.select_cols(
                **{
                    f'n_cases_{sex}': hl.agg.count_where(check_func(mt[sex]))
                    for sex in sexes
                },
                description=hl.coalesce(
                    *[mt[f'{sex}_pheno'].Field for sex in sexes]),
                description_more=hl.coalesce(
                    *[mt[f'{sex}_pheno'].Notes for sex in sexes]),
                coding_description=hl.coalesce(
                    *[mt[f'{sex}_pheno'].meaning for sex in sexes])
                if data_type == 'categorical' else NULL_STR,
                category=hl.coalesce(
                    *[mt[f'{sex}_pheno'].Path for sex in sexes]))
            mt = mt.select_entries(
                **{sex: hl.float64(mt[sex])
                   for sex in sexes})

        elif 'icd_code' in list(mt.col_key):
            # trait_type is the icd_version field when present, else ''.
            icd_version = mt.icd_version if 'icd_version' in list(
                mt.col) else ''
            mt = mt.key_cols_by(trait_type=icd_version,
                                phenocode=mt.icd_code,
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=NULL_STR_KEY)
            if truncated_codes_only:
                mt = mt.filter_cols(hl.len(mt.icd_code) == 3)
                # Columns sharing a key are collected into arrays; `keep` picks one:
                # index 0 if there is a single column, otherwise the first one whose
                # `truncated` flag is set.
                # NOTE(review): with several columns and none truncated, the [0] below
                # indexes an empty array and will fail at runtime.
                mt = mt.collect_cols_by_key()
                mt = mt.annotate_cols(keep=hl.if_else(
                    hl.len(mt.truncated) == 1, 0,
                    hl.zip_with_index(mt.truncated).filter(lambda x: x[1]).map(
                        lambda x: x[0])[0]))
                mt = mt.select_entries(**{x: mt[x][mt.keep] for x in mt.entry})
                mt = mt.select_cols(
                    **{x: mt[x][mt.keep]
                       for x in mt.col_value if x != 'keep'})
            # The conditional binds loosest:
            # ("truncated: " + str) if 'truncated' present else NULL_STR.
            mt = mt.select_cols(**compute_cases_binary(mt.any_codes, mt.sex),
                                description=mt.short_meaning,
                                description_more="truncated: " +
                                hl.str(mt.truncated) if 'truncated' in list(
                                    mt.col_value) else NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.meaning)
            mt = mt.select_entries(**format_entries(mt.any_codes, mt.sex))
        elif data_type == 'icd_first_occurrence':
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            mt = mt.select_cols(**compute_cases_binary(
                hl.is_defined(mt.both_sexes), mt.sex),
                                description=mt.Field,
                                description_more=mt.Notes,
                                coding_description=NULL_STR,
                                category=mt.Path)
        else:  # 'biomarkers', 'activity_monitor'
            mt = mt.key_cols_by(trait_type=mt.trait_type
                                if 'trait_type' in list(mt.col) else data_type,
                                phenocode=hl.str(mt.pheno),
                                pheno_sex='both_sexes',
                                coding=NULL_STR_KEY,
                                modifier=NULL_STR_KEY)
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            mt = mt.select_cols(**{
                f'n_cases_{sex}': hl.agg.count_where(hl.is_defined(mt[sex]))
                for sex in sexes
            },
                                description=mt.Field,
                                description_more=NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.Path)
        # else:
        #     raise ValueError('pheno or icd_code not in column key. New data type?')
        # NOTE(review): tempfile.mktemp only generates a name (race-prone, deprecated);
        # it is used here because checkpoint needs a path that does not exist yet.
        mt = mt.checkpoint(
            tempfile.mktemp(prefix=f'/tmp/{data_type}_', suffix='.mt'))
        if full_mt is None:
            full_mt = mt
        else:
            # Rows are outer-joined only when the MatrixTable being added is
            # prescriptions; every other union is an inner join on rows.
            # NOTE(review): if prescriptions is the first item in the dict, it becomes
            # full_mt and later unions are inner, so the result depends on dict order.
            full_mt = full_mt.union_cols(mt,
                                         row_join_type='outer' if data_type
                                         == 'prescriptions' else 'inner')
    full_mt = full_mt.unfilter_entries()

    # Prescription data covers fewer samples than the other inputs, so missing
    # prescription entries are set to 0.0 (not a case).
    return full_mt.select_entries(
        **{
            sex: hl.cond(full_mt.trait_type == 'prescriptions',
                         hl.or_else(full_mt[sex], hl.float64(0.0)),
                         full_mt[sex])
            for sex in sexes
        })
示例#16
0
def full_outer_join_mt(left: hl.MatrixTable,
                       right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Rows are outer-joined on the row key. Columns are matched by column key:

    * a key present on both sides yields one output column per
      (left column, right column) pair with that key;
    * a key present on one side only yields columns whose other-side
      fields are missing.

    Output column key fields take the names of `left`'s column key.

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the row key types or the column key types of `left` and `right` differ.
    """

    def _key_dtypes(key):
        return [field.dtype for field in key.values()]

    if _key_dtypes(left.row_key) != _key_dtypes(right.row_key):
        raise ValueError(f"row key types do not match:\n"
                         f"  left:  {list(left.row_key.values())}\n"
                         f"  right: {list(right.row_key.values())}")

    if _key_dtypes(left.col_key) != _key_dtypes(right.col_key):
        raise ValueError(f"column key types do not match:\n"
                         f"  left:  {list(left.col_key.values())}\n"
                         f"  right: {list(right.col_key.values())}")

    left_key_fields = list(left.col_key)
    right_key_fields = list(right.col_key)

    left_t = left.select_rows(left_row=left.row).localize_entries('left_entries', 'left_cols')
    right_t = right.select_rows(right_row=right.row).localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')

    def _positions_by_key(cols, key_fields):
        # Map each column-key tuple to the list of positions where it occurs in `cols`.
        keyed_positions = hl.zip_with_index(
            cols.map(lambda col: hl.tuple([col[f] for f in key_fields])),
            index_first=False)
        return hl.group_by(lambda t: t[0], keyed_positions).map_values(
            lambda group: group.map(lambda t: t[1]))

    ht = ht.annotate_globals(
        left_keys=_positions_by_key(ht.left_cols, left_key_fields),
        right_keys=_positions_by_key(ht.right_cols, right_key_fields))

    def _pairings(s):
        # Expand one key's left/right positions into output-column descriptors.
        missing_index = hl.null('int32')
        both_sides = hl.range(0, s.left_indices.length()).flatmap(
            lambda i: hl.range(0, s.right_indices.length()).map(
                lambda j: hl.struct(k=s.k,
                                    left_index=s.left_indices[i],
                                    right_index=s.right_indices[j])))
        left_only = s.left_indices.map(
            lambda elt: hl.struct(k=s.k, left_index=elt, right_index=missing_index))
        right_only = s.right_indices.map(
            lambda elt: hl.struct(k=s.k, left_index=missing_index, right_index=elt))
        return (hl.case()
                .when(hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices), both_sides)
                .when(hl.is_defined(s.left_indices), left_only)
                .when(hl.is_defined(s.right_indices), right_only)
                .or_error('assertion error'))

    all_keys = hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
    ht = ht.annotate_globals(key_indices=all_keys.map(
        lambda k: hl.struct(k=k,
                            left_indices=ht.left_keys.get(k),
                            right_indices=ht.right_keys.get(k))).flatmap(_pairings))

    # A missing index yields a missing entry / column struct on that side.
    ht = ht.annotate(__entries=ht.key_indices.map(
        lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                            right_entry=ht.right_entries[s.right_index])))
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i] for i, f in enumerate(left_key_fields)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))
    ht = ht.drop('left_entries', 'left_cols', 'left_keys', 'right_entries',
                 'right_cols', 'right_keys', 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', left_key_fields)
示例#17
0
def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.MatrixTable:
    """Full outer join of two MatrixTables on both their row and column keys.

    Row, column, and entry fields of the result are:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    A column key present on only one side yields columns whose struct for the
    other side is missing. A key present m times on the left and n times on the
    right yields m * n output columns (every left/right pairing).

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the row key types, or the column key types, of `left` and `right`
        differ.
    """
    left_row_key = list(left.row_key.values())
    right_row_key = list(right.row_key.values())
    if [expr.dtype for expr in left_row_key] != [expr.dtype for expr in right_row_key]:
        raise ValueError(f"row key types do not match:\n"
                         f"  left:  {left_row_key}\n"
                         f"  right: {right_row_key}")

    left_col_key = list(left.col_key.values())
    right_col_key = list(right.col_key.values())
    if [expr.dtype for expr in left_col_key] != [expr.dtype for expr in right_col_key]:
        raise ValueError(f"column key types do not match:\n"
                         f"  left:  {left_col_key}\n"
                         f"  right: {right_col_key}")

    # Names of the column key fields (iterating a StructExpression yields names).
    left_col_fields = list(left.col_key)
    right_col_fields = list(right.col_key)

    def positions_by_key(cols, key_fields):
        # Dict from column-key tuple to the list of positions in `cols` having that key.
        keyed_positions = hl.zip_with_index(
            cols.map(lambda col: hl.tuple([col[f] for f in key_fields])),
            index_first=False)
        return hl.group_by(lambda pair: pair[0], keyed_positions).map_values(
            lambda pairs: pairs.map(lambda pair: pair[1]))

    def pairings(s):
        # Expand one key's (left positions, right positions) into output column records.
        left_idx = s.left_indices
        right_idx = s.right_indices
        cross = hl.range(0, left_idx.length()).flatmap(
            lambda i: hl.range(0, right_idx.length()).map(
                lambda j: hl.struct(k=s.k, left_index=left_idx[i], right_index=right_idx[j])))
        left_only = left_idx.map(
            lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.null('int32')))
        right_only = right_idx.map(
            lambda elt: hl.struct(k=s.k, left_index=hl.null('int32'), right_index=elt))
        return (hl.case()
                .when(hl.is_defined(left_idx) & hl.is_defined(right_idx), cross)
                .when(hl.is_defined(left_idx), left_only)
                .when(hl.is_defined(right_idx), right_only)
                .or_error('assertion error'))

    # Nest row fields under left_row/right_row and localize entries/columns into arrays.
    left_t = left.select_rows(left_row=left.row).localize_entries('left_entries', 'left_cols')
    right_t = right.select_rows(right_row=right.row).localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')
    ht = ht.annotate_globals(
        left_keys=positions_by_key(ht.left_cols, left_col_fields),
        right_keys=positions_by_key(ht.right_cols, right_col_fields))

    all_keys = hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
    ht = ht.annotate_globals(key_indices=all_keys.map(
        lambda k: hl.struct(k=k,
                            left_indices=ht.left_keys.get(k),
                            right_indices=ht.right_keys.get(k))).flatmap(pairings))

    # A missing index yields a missing entry / column struct for that side.
    ht = ht.annotate(__entries=ht.key_indices.map(
        lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                            right_entry=ht.right_entries[s.right_index])))
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i] for i, f in enumerate(left_col_fields)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))
    ht = ht.drop('left_entries', 'left_cols', 'left_keys',
                 'right_entries', 'right_cols', 'right_keys', 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', left_col_fields)
示例#18
0
def import_structural_variants(vcf_path):
    """Import a structural-variant VCF into a Hail Table of browser-ready sites.

    The VCF's rows are read, the INFO fields named in TOP_LEVEL_INFO_FIELDS are
    copied to top-level lowercase fields, and derived fields are added:
    position fields (chrom/pos/end and the *2 counterparts plus x_position
    encodings), per-consequence gene lists, major_consequence, a genes set and
    a freq struct with overall and per-population values (MCNV alleles combined
    with sum_mcnv_ac_or_af, hemizygote counts on X/Y where par is false).

    Args:
        vcf_path: path given to ``hl.import_vcf`` (read with ``force_bgz=True``
            and ``min_partitions=32``).

    Returns:
        hl.Table keyed by ``variant_id``, without ``locus``, ``alleles``,
        ``info`` or ``rsid``.
    """

    def in_hemizygous_region(t):
        # Chromosome X or Y where the `par` field is false. Takes the current
        # table so the expression always belongs to the table being annotated.
        return ((t.chrom == "X") | (t.chrom == "Y")) & ~t.par

    def hemizygous_value(t, info_field):
        # The given INFO value inside the hemizygous region, 0 elsewhere.
        return hl.if_else(in_hemizygous_region(t), t.info[info_field], 0)

    ds = hl.import_vcf(vcf_path, force_bgz=True, min_partitions=32).rows()

    ds = ds.annotate(**{name.lower(): ds.info[name] for name in TOP_LEVEL_INFO_FIELDS})

    info = ds.info
    contig = ds.locus.contig
    position = ds.locus.position
    ds = ds.annotate(
        variant_id=ds.rsid.replace("^gnomAD-SV_v2.1_", ""),
        reference_genome="GRCh37",
        # Start
        chrom=contig,
        pos=position,
        xpos=x_position(contig, position),
        # End
        end=info.END,
        xend=x_position(contig, info.END),
        # Start 2
        chrom2=info.CHR2,
        pos2=info.POS2,
        xpos2=x_position(info.CHR2, info.POS2),
        # End 2
        end2=info.END2,
        xend2=x_position(info.CHR2, info.END2),
        # Other
        length=info.SVLEN,
        type=info.SVTYPE,
        alts=ds.alleles[1:],
    )

    # The browser should not treat MULTIALLELIC as a quality filter.
    ds = ds.annotate(filters=ds.filters.difference(hl.set(["MULTIALLELIC"])))

    # One {consequence, genes} struct per genic consequence, keeping only those with genes.
    genic_consequences = [name for name in RANKED_CONSEQUENCES
                          if name not in ("INTERGENIC", "NEAREST_TSS")]
    consequence_structs = [
        hl.struct(
            consequence=name.lower(),
            genes=hl.or_else(ds.info[f"PROTEIN_CODING__{name}"], hl.empty_array(hl.tstr)),
        )
        for name in genic_consequences
    ]
    ds = ds.annotate(consequences=hl.array(consequence_structs).filter(
        lambda entry: hl.len(entry.genes) > 0))
    ds = ds.annotate(intergenic=ds.info.PROTEIN_CODING__INTERGENIC)

    # First consequence (in RANKED_CONSEQUENCES order) that has genes; otherwise
    # "intergenic" when flagged, else missing.
    ds = ds.annotate(major_consequence=hl.rbind(
        ds.consequences.find(lambda entry: hl.len(entry.genes) > 0),
        lambda top: hl.or_else(top.consequence, hl.or_missing(ds.intergenic, "intergenic")),
    ))

    ds = ds.annotate(genes=hl.set(ds.consequences.flatmap(lambda entry: entry.genes)))

    # Overall frequency fields plus one struct per division.
    per_division = [
        hl.struct(id=pop, **{field.lower(): ds.info[f"{pop}_{field}"] for field in FREQ_FIELDS})
        for pop in DIVISIONS
    ]
    ds = ds.annotate(freq=hl.struct(
        **{field.lower(): ds.info[field] for field in FREQ_FIELDS},
        populations=per_division,
    ))

    # MCNVs: per-copy-number allele counts. alt[4:-1] drops a 4-character prefix
    # and the final character, e.g. 2 from "CN=<2>".
    ds = ds.annotate(freq=ds.freq.annotate(copy_numbers=hl.or_missing(
        ds.type == "MCNV",
        hl.zip_with_index(ds.alts).map(lambda pair: hl.struct(
            copy_number=hl.int(pair[1][4:-1]),
            ac=ds.freq.ac[pair[0]],
        )),
    )))

    # MCNVs: combine per-alt AC/AF with sum_mcnv_ac_or_af (intended to exclude CN=2);
    # everything else: take the single alt's value.
    is_mcnv = ds.type == "MCNV"
    ds = ds.annotate(freq=ds.freq.annotate(
        ac=hl.if_else(is_mcnv, sum_mcnv_ac_or_af(ds.alts, ds.freq.ac), ds.freq.ac[0]),
        af=hl.if_else(is_mcnv, sum_mcnv_ac_or_af(ds.alts, ds.freq.af), ds.freq.af[0]),
        populations=hl.if_else(
            is_mcnv,
            ds.freq.populations.map(lambda pop: pop.annotate(
                ac=sum_mcnv_ac_or_af(ds.alts, pop.ac),
                af=sum_mcnv_ac_or_af(ds.alts, pop.af),
            )),
            ds.freq.populations.map(lambda pop: pop.annotate(ac=pop.ac[0], af=pop.af[0])),
        ),
    ))

    # Hemizygote counts keyed by population / sex label (females are always 0).
    hemi_entries = [(pop_id, hemizygous_value(ds, f"{pop_id}_MALE_N_HEMIALT"))
                    for pop_id in POPULATIONS]
    hemi_entries += [(f"{pop_id}_FEMALE", 0) for pop_id in POPULATIONS]
    hemi_entries += [(f"{pop_id}_MALE", hemizygous_value(ds, f"{pop_id}_MALE_N_HEMIALT"))
                     for pop_id in POPULATIONS]
    hemi_entries += [("FEMALE", 0), ("MALE", hemizygous_value(ds, "MALE_N_HEMIALT"))]
    ds = ds.annotate(hemizygote_count=hl.dict(hemi_entries))

    not_mcnv = ds.type != "MCNV"
    ds = ds.annotate(freq=ds.freq.annotate(
        hemizygote_count=hl.or_missing(not_mcnv, hemizygous_value(ds, "MALE_N_HEMIALT")),
        populations=hl.if_else(
            not_mcnv,
            ds.freq.populations.map(lambda pop: pop.annotate(
                hemizygote_count=ds.hemizygote_count[pop.id])),
            ds.freq.populations.map(
                lambda pop: pop.annotate(hemizygote_count=hl.null(hl.tint))),
        ),
    ))
    ds = ds.drop("hemizygote_count")

    # n_homalt -> homozygote_count, overall and per population.
    ds = ds.annotate(freq=ds.freq.annotate(
        homozygote_count=ds.freq.n_homalt,
        populations=ds.freq.populations.map(
            lambda pop: pop.annotate(homozygote_count=pop.n_homalt).drop("n_homalt")),
    ).drop("n_homalt"))

    ds = ds.key_by("variant_id")
    return ds.drop("locus", "alleles", "info", "rsid")
示例#19
0
def combine_pheno_files_multi_sex(pheno_file_dict: dict, cov_ht: hl.Table, truncated_codes_only: bool = True):
    """Combine per-data-type phenotype MatrixTables into a single MatrixTable.

    Each MatrixTable in `pheno_file_dict` has its row fields replaced by the
    matching row of `cov_ht`. It is then reshaped to per-sex entry fields
    (``both_sexes``, ``females``, ``males``) and standardized column fields
    (case counts, ``description``, ``description_more``,
    ``coding_description``, ``category``). Finally it is checkpointed to a
    temporary path and column-unioned into the result with an outer join on
    rows.

    Args:
        pheno_file_dict: mapping from data type name (``'custom'``,
            ``'categorical'``, ``'continuous'``, ``'icd_first_occurrence'``,
            or e.g. ``'biomarkers'``) to its phenotype MatrixTable.
            MatrixTables not named ``custom``/``categorical``/``continuous``
            whose column key contains ``icd_code`` take the ICD path.
        cov_ht: per-sample covariate Table, indexed by each MatrixTable's row key.
        truncated_codes_only: for ICD MatrixTables, keep only 3-character
            codes and, per column key, one column: index 0 when there is only
            one, otherwise the first one whose ``truncated`` flag is true.

    Returns:
        The combined MatrixTable after ``unfilter_entries``. For columns with
        ``trait_type == 'prescriptions'``, missing entries are set to 0.0.

    Raises:
        ValueError: if `pheno_file_dict` is empty.
    """
    from typing import Optional

    if not pheno_file_dict:
        # Without this guard the loop never runs and full_mt stays None,
        # failing later with an opaque AttributeError.
        raise ValueError('pheno_file_dict is empty; there is nothing to combine')

    full_mt: Optional[hl.MatrixTable] = None
    sexes = ('both_sexes', 'females', 'males')

    def counting_func(value, trait_type):
        # Categorical values are passed through unchanged; for other trait
        # types only whether the value is defined is passed on.
        return value if trait_type == 'categorical' else hl.is_defined(value)

    for data_type, mt in pheno_file_dict.items():
        mt = mt.select_rows(**cov_ht[mt.row_key])
        print(data_type)
        if data_type == 'custom':
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            mt = mt.select_cols(**{f'n_cases_{sex}': hl.agg.count_where(
                hl.if_else(mt.trait_type == 'categorical', mt[sex] == 1.0, hl.is_defined(mt[sex]))
            ) for sex in sexes}, **{extra_col: mt[extra_col] if extra_col in list(mt.col) else NULL_STR
                                    for extra_col in PHENO_DESCRIPTION_FIELDS})
        elif data_type in ('categorical', 'continuous'):

            def get_non_missing_field(mt, field_name):
                # First defined value of `field_name` across the per-sex pheno structs.
                return hl.coalesce(*[mt[f'{sex}_pheno'][field_name] for sex in sexes])

            mt = mt.select_cols(**compute_cases_binary(counting_func(mt.both_sexes, data_type), mt.sex),
                                description=get_non_missing_field(mt, 'Field'),
                                description_more=get_non_missing_field(mt, 'Notes'),
                                coding_description=get_non_missing_field(mt, 'meaning') if
                                data_type == 'categorical' else NULL_STR,
                                category=get_non_missing_field(mt, 'Path'))
            mt = mt.select_entries(**{sex: hl.float64(mt[sex]) for sex in sexes})

        # TODO: got here - move some of this to ICD load (get icd_version as icd10 and move truncation steps there)
        elif 'icd_code' in list(mt.col_key):
            icd_version = mt.icd_version if 'icd_version' in list(mt.col) else 'icd10'
            mt = mt.key_cols_by(trait_type=icd_version, phenocode=mt.icd_code, pheno_sex='both_sexes',
                                coding=NULL_STR_KEY, modifier=NULL_STR_KEY)
            if truncated_codes_only:
                mt = mt.filter_cols(hl.len(mt.icd_code) == 3)
                mt = mt.collect_cols_by_key()
                # Index of the column to keep within each collected group.
                mt = mt.annotate_cols(keep=hl.if_else(
                    hl.len(mt.truncated) == 1, 0,
                    hl.zip_with_index(mt.truncated).filter(lambda x: x[1]).map(lambda x: x[0])[0]))
                mt = mt.select_entries(**{x: mt[x][mt.keep] for x in mt.entry})
                mt = mt.select_cols(**{x: mt[x][mt.keep] for x in mt.col_value if x != 'keep'})
            mt = mt.select_cols(**compute_cases_binary(mt.any_codes, mt.sex),
                                description=mt.short_meaning,
                                description_more="truncated: " + hl.str(mt.truncated) if 'truncated' in list(mt.col_value) else NULL_STR,
                                coding_description=NULL_STR,
                                category=mt.meaning)
            mt = mt.select_entries(**format_entries(mt.any_codes, mt.sex))
        elif data_type == 'icd_first_occurrence':
            mt = mt.select_entries(**format_entries(hl.is_defined(mt.value), mt.sex))
            mt = mt.select_cols(**compute_cases_binary(mt.both_sexes == 1.0, mt.sex),
                                description=mt.Field, description_more=mt.Notes,
                                coding_description=NULL_STR, category=mt.Path)
        else:
            # Remaining data types, e.g. 'biomarkers', 'activity_monitor'.
            mt = mt.key_cols_by(trait_type=mt.trait_type if 'trait_type' in list(mt.col) else data_type,
                                phenocode=hl.str(mt.pheno), pheno_sex='both_sexes', coding=NULL_STR_KEY, modifier=NULL_STR_KEY)
            mt = mt.select_entries(**format_entries(mt.value, mt.sex))
            mt = mt.select_cols(**{f'n_cases_{sex}': hl.agg.count_where(hl.is_defined(mt[sex])) for sex in sexes},
                                description=mt.Field, description_more=NULL_STR, coding_description=NULL_STR, category=mt.Path)
        # tempfile.mktemp only produces a unique name; Hail's checkpoint writes the
        # MatrixTable there (on whatever filesystem Hail resolves /tmp to -- TODO confirm).
        mt = mt.checkpoint(tempfile.mktemp(prefix=f'/tmp/{data_type}_', suffix='.mt'))
        if full_mt is None:
            full_mt = mt
        else:
            full_mt = full_mt.union_cols(mt, row_join_type='outer')
    full_mt = full_mt.unfilter_entries()

    # Prescription data covers fewer samples than the others, so its missing entries are set to 0.
    return full_mt.select_entries(**{sex: hl.if_else(
        full_mt.trait_type == 'prescriptions',
        hl.or_else(full_mt[sex], hl.float64(0.0)),
        full_mt[sex]) for sex in sexes})
示例#20
0
def main(args):
    """Run a fixed-effect meta-analysis across populations and write the results.

    Loads the final summary-statistics MatrixTable. Per-population summary
    stats are combined with inverse-variance weights into "all" and
    leave-one-out results via `all_and_leave_one_out`, a Cochran's Q
    heterogeneity test is run, and the result is written to
    `get_meta_analysis_results_path()`.

    Args:
        args: parsed command-line arguments. Reads
            `keep_low_confidence_variants` (when false, low-confidence
            per-population stats are set to missing) and `overwrite`
            (forwarded to `MatrixTable.write`).
    """
    hl.init()

    # Read in all sumstats
    mt = load_final_sumstats_mt(filter_phenos=True,
                                filter_variants=False,
                                filter_sumstats=True,
                                separate_columns_by_pop=False,
                                annotate_with_nearest_gene=False)

    # Annotate per-entry sample size
    def get_n(pheno_data, i):
        # N = n_cases + n_controls, with a missing n_controls counted as 0.
        return pheno_data[i].n_cases + hl.or_else(pheno_data[i].n_controls, 0)

    # x is (population index, summary-stats struct); N stays missing when the struct is missing.
    mt = mt.annotate_entries(summary_stats=hl.map(
        lambda x: x[1].annotate(N=hl.or_missing(hl.is_defined(x[1]),
                                                get_n(mt.pheno_data, x[0]))),
        hl.zip_with_index(mt.summary_stats)))

    # Exclude entries with low confidence flag.
    if not args.keep_low_confidence_variants:
        # Flagged per-population structs are replaced with missing.
        mt = mt.annotate_entries(summary_stats=hl.map(
            lambda x: hl.or_missing(~x.low_confidence, x), mt.summary_stats))

    # Run fixed-effect meta-analysis (all + leave-one-out)
    # Per population: unnorm_beta = BETA / SE^2 and inv_se2 = 1 / SE^2.
    mt = mt.annotate_entries(unnorm_beta=mt.summary_stats.BETA /
                             (mt.summary_stats.SE**2),
                             inv_se2=1 / (mt.summary_stats.SE**2))
    mt = mt.annotate_entries(
        sum_unnorm_beta=all_and_leave_one_out(mt.unnorm_beta,
                                              mt.pheno_data.pop),
        sum_inv_se2=all_and_leave_one_out(mt.inv_se2, mt.pheno_data.pop))
    # META_BETA = sum(BETA / SE^2) / sum(1 / SE^2); META_SE = sqrt(1 / sum(1 / SE^2)).
    mt = mt.transmute_entries(META_BETA=mt.sum_unnorm_beta / mt.sum_inv_se2,
                              META_SE=hl.map(lambda x: hl.sqrt(1 / x),
                                             mt.sum_inv_se2))
    # Two-sided p-value from z = META_BETA / META_SE: 2 * pnorm(-|z|).
    mt = mt.annotate_entries(
        META_Pvalue=hl.map(lambda x: 2 * hl.pnorm(x), -hl.abs(mt.META_BETA /
                                                              mt.META_SE)))

    # Run heterogeneity test (Cochran's Q)
    # NOTE(review): every element of META_BETA (the "all" estimate and each
    # leave-one-out estimate) is compared against the BETA/inv_se2 arrays of
    # ALL populations, including any population a leave-one-out estimate
    # excluded. The degrees of freedom below use the per-estimate META_N_pops.
    # Confirm this is intended.
    mt = mt.annotate_entries(META_Q=hl.map(
        lambda x: hl.sum((mt.summary_stats.BETA - x)**2 * mt.inv_se2),
        mt.META_BETA),
                             variant_exists=hl.map(lambda x: ~hl.is_missing(x),
                                                   mt.summary_stats.BETA))
    # Number of populations with a defined BETA, for "all" and each leave-one-out set.
    mt = mt.annotate_entries(META_N_pops=all_and_leave_one_out(
        mt.variant_exists, mt.pheno_data.pop))
    # Q ~ chi-squared with (N_pops - 1) df. Non-finite results are turned into
    # missing later by is_finite_or_missing.
    mt = mt.annotate_entries(META_Pvalue_het=hl.map(
        lambda i: hl.pchisqtail(mt.META_Q[i], mt.META_N_pops[i] - 1),
        hl.range(hl.len(mt.META_Q))))

    # Add other annotations
    # Allele counts are reconstructed as AF * N per population, summed, then divided by META_N.
    mt = mt.annotate_entries(
        ac_cases=hl.map(lambda x: x["AF.Cases"] * x.N, mt.summary_stats),
        ac_controls=hl.map(lambda x: x["AF.Controls"] * x.N, mt.summary_stats),
        META_AC_Allele2=all_and_leave_one_out(
            mt.summary_stats.AF_Allele2 * mt.summary_stats.N,
            mt.pheno_data.pop),
        META_N=all_and_leave_one_out(mt.summary_stats.N, mt.pheno_data.pop))
    mt = mt.annotate_entries(
        META_AF_Allele2=mt.META_AC_Allele2 / mt.META_N,
        META_AF_Cases=all_and_leave_one_out(mt.ac_cases, mt.pheno_data.pop) /
        mt.META_N,
        META_AF_Controls=all_and_leave_one_out(mt.ac_controls,
                                               mt.pheno_data.pop) / mt.META_N)
    mt = mt.drop('unnorm_beta', 'inv_se2', 'variant_exists', 'ac_cases',
                 'ac_controls', 'summary_stats', 'META_AC_Allele2')

    # Format everything into array<struct>
    def is_finite_or_missing(x):
        # NaN / infinite values become missing.
        return (hl.or_missing(hl.is_finite(x), x))

    meta_fields = [
        'BETA', 'SE', 'Pvalue', 'Q', 'Pvalue_het', 'N', 'N_pops', 'AF_Allele2',
        'AF_Cases', 'AF_Controls'
    ]
    # Zip the parallel META_* arrays into one struct per meta-analysis (index i).
    mt = mt.transmute_entries(meta_analysis=hl.map(
        lambda i: hl.struct(
            **{
                field: is_finite_or_missing(mt[f'META_{field}'][i])
                for field in meta_fields
            }), hl.range(hl.len(mt.META_BETA))))

    # Column-level n_cases / n_controls, combined the same way as the entries.
    col_fields = ['n_cases', 'n_controls']
    mt = mt.annotate_cols(
        **{
            field: all_and_leave_one_out(mt.pheno_data[field],
                                         mt.pheno_data.pop)
            for field in col_fields
        })
    col_fields += ['pop']
    # pop: the full population list, and for leave-one-out i the list without pop i.
    mt = mt.annotate_cols(pop=all_and_leave_one_out(
        mt.pheno_data.pop,
        mt.pheno_data.pop,
        all_f=lambda x: x,
        loo_f=lambda i, x: hl.filter(lambda y: y != x[i], x),
    ))
    mt = mt.transmute_cols(meta_analysis_data=hl.map(
        lambda i: hl.struct(**{field: mt[field][i]
                               for field in col_fields}),
        hl.range(hl.len(mt.pop))))

    mt.describe()
    mt.write(get_meta_analysis_results_path(), overwrite=args.overwrite)

    hl.copy_log('gs://ukb-diverse-pops/combined_results/meta_analysis.log')
示例#21
0
def full_outer_join_mt(left: hl.MatrixTable,
                       right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Columns are matched on their key values. A column key present on only one
    side produces columns whose struct for the other side is missing. A key
    that occurs m times on the left and n times on the right produces m * n
    output columns (every left/right pairing). The result's column key fields
    are named after `left`'s column key.

    Examples
    --------

    The following creates and joins two random datasets with disjoint sample ids
    but non-disjoint variant sets. We use :func:`.or_else` to attempt to find a
    non-missing genotype. If neither genotype is non-missing, then the genotype
    is set to missing. In particular, note that Samples `2` and `3` have missing
    genotypes for loci 1:1 and 1:2 because those loci are not present in `mt2`
    and these samples are not present in `mt1`

    >>> hl.set_global_seed(0)
    >>> mt1 = hl.balding_nichols_model(1, 2, 3)
    >>> mt2 = hl.balding_nichols_model(1, 2, 3)
    >>> mt2 = mt2.key_rows_by(locus=hl.locus(mt2.locus.contig,
    ...                                      mt2.locus.position+2),
    ...                       alleles=mt2.alleles)
    >>> mt2 = mt2.key_cols_by(sample_idx=mt2.sample_idx+2)
    >>> mt1.show()
    +---------------+------------+------+------+
    | locus         | alleles    | 0.GT | 1.GT |
    +---------------+------------+------+------+
    | locus<GRCh37> | array<str> | call | call |
    +---------------+------------+------+------+
    | 1:1           | ["A","C"]  | 0/1  | 0/1  |
    | 1:2           | ["A","C"]  | 1/1  | 1/1  |
    | 1:3           | ["A","C"]  | 0/0  | 0/0  |
    +---------------+------------+------+------+
    <BLANKLINE>
    >>> mt2.show()  # doctest: +SKIP_OUTPUT_CHECK
    +---------------+------------+------+------+
    | locus         | alleles    | 0.GT | 1.GT |
    +---------------+------------+------+------+
    | locus<GRCh37> | array<str> | call | call |
    +---------------+------------+------+------+
    | 1:3           | ["A","C"]  | 0/1  | 1/1  |
    | 1:4           | ["A","C"]  | 0/1  | 0/1  |
    | 1:5           | ["A","C"]  | 1/1  | 0/0  |
    +---------------+------------+------+------+
    <BLANKLINE>
    >>> mt3 = hl.experimental.full_outer_join_mt(mt1, mt2)
    >>> mt3 = mt3.select_entries(GT=hl.or_else(mt3.left_entry.GT, mt3.right_entry.GT))
    >>> mt3.show()
    +---------------+------------+------+------+------+------+
    | locus         | alleles    | 0.GT | 1.GT | 2.GT | 3.GT |
    +---------------+------------+------+------+------+------+
    | locus<GRCh37> | array<str> | call | call | call | call |
    +---------------+------------+------+------+------+------+
    | 1:1           | ["A","C"]  | 0/1  | 0/1  | NA   | NA   |
    | 1:2           | ["A","C"]  | 1/1  | 1/1  | NA   | NA   |
    | 1:3           | ["A","C"]  | 0/0  | 0/0  | 0/1  | 1/1  |
    | 1:4           | ["A","C"]  | NA   | NA   | 0/1  | 0/1  |
    | 1:5           | ["A","C"]  | NA   | NA   | 1/1  | 0/0  |
    +---------------+------------+------+------+------+------+
    <BLANKLINE>

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the row key types, or the column key types, of `left` and `right`
        differ.
    """

    # Both sides must agree on row key types and on column key types.
    if [x.dtype for x in left.row_key.values()
        ] != [x.dtype for x in right.row_key.values()]:
        raise ValueError(f"row key types do not match:\n"
                         f"  left:  {list(left.row_key.values())}\n"
                         f"  right: {list(right.row_key.values())}")

    if [x.dtype for x in left.col_key.values()
        ] != [x.dtype for x in right.col_key.values()]:
        raise ValueError(f"column key types do not match:\n"
                         f"  left:  {list(left.col_key.values())}\n"
                         f"  right: {list(right.col_key.values())}")

    # Nest each side's row fields under left_row / right_row, then localize:
    # entries become a per-row array field and column values a global array.
    left = left.select_rows(left_row=left.row)
    left_t = left.localize_entries('left_entries', 'left_cols')
    right = right.select_rows(right_row=right.row)
    right_t = right.localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')
    # left_keys / right_keys: dict from column-key tuple to the list of
    # positions in left_cols / right_cols that carry that key.
    ht = ht.annotate_globals(left_keys=hl.group_by(
        lambda t: t[0],
        hl.zip_with_index(
            ht.left_cols.map(lambda x: hl.tuple([x[f] for f in left.col_key])),
            index_first=False)).map_values(
                lambda elts: elts.map(lambda t: t[1])),
                             right_keys=hl.group_by(
                                 lambda t: t[0],
                                 hl.zip_with_index(
                                     ht.right_cols.map(lambda x: hl.tuple(
                                         [x[f] for f in right.col_key])),
                                     index_first=False)).
                             map_values(lambda elts: elts.map(lambda t: t[1])))
    # key_indices: one {k, left_index, right_index} record per output column.
    # Keys on both sides produce every (left, right) position pair; keys on one
    # side get a missing index for the other side.
    ht = ht.annotate_globals(key_indices=hl.array(ht.left_keys.key_set(
    ).union(ht.right_keys.key_set())).map(lambda k: hl.struct(
        k=k,
        left_indices=ht.left_keys.get(k),
        right_indices=ht.right_keys.get(k))).flatmap(lambda s: hl.case().when(
            hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices),
            hl.range(0, s.left_indices.length()).flatmap(lambda i: hl.range(
                0, s.right_indices.length()).map(lambda j: hl.struct(
                    k=s.k,
                    left_index=s.left_indices[i],
                    right_index=s.right_indices[j])))
        ).when(
            hl.is_defined(s.left_indices),
            s.left_indices.map(lambda elt: hl.struct(
                k=s.k, left_index=elt, right_index=hl.null('int32')))).when(
                    hl.is_defined(s.right_indices),
                    s.right_indices.map(lambda elt: hl.struct(
                        k=s.k, left_index=hl.null('int32'), right_index=elt))
                ).or_error('assertion error')))
    # Per row, gather the left/right entry for each output column; a missing
    # index (or a row absent on that side) yields a missing entry.
    ht = ht.annotate(__entries=ht.key_indices.map(
        lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                            right_entry=ht.right_entries[s.right_index])))
    # Output column values: key fields (named after left's column key) plus
    # the left/right column structs.
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i]
                               for i, f in enumerate(left.col_key)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))
    # Drop the intermediates and rebuild a MatrixTable from the arrays.
    ht = ht.drop('left_entries', 'left_cols', 'left_keys', 'right_entries',
                 'right_cols', 'right_keys', 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', list(left.col_key))
示例#22
0
def sorted_transcript_consequences_v2(vep_root):
    """Annotate and rank VEP transcript consequences.

    Consequence terms listed in OMIT_CONSEQUENCE_TERMS are removed, and
    transcripts left with no terms are dropped. Each remaining transcript is
    annotated with:
        major_consequence: its term with the lowest consequence_term_rank
        major_consequence_rank: consequence_term_rank of that term
        category: "lof" / "missense" / "synonymous" / "other", by comparing the
            rank against frameshift_variant / missense_variant / synonymous_variant
        domains: each domain struct rendered as "db:name"
        hgvs: the protein part of hgvsc (text after the last ":") when hgvsp is
            missing or the major consequence is in SPLICE_CONSEQUENCES,
            otherwise hgvsp_from_consequence_amino_acids(...)
        transcript_rank: position after sorting

    Sort order is protein-coding before non-coding, then transcripts whose
    terms include vep_root.most_severe_consequence, then canonical (==1)
    before non-canonical.

    Args:
        vep_root (StructExpression): root path of the VEP struct in the MT

    Returns:
        ArrayExpression of transcript structs restricted to the selected fields.
    """

    def drop_omitted_terms(csq):
        return csq.annotate(consequence_terms=csq.consequence_terms.filter(
            lambda term: ~OMIT_CONSEQUENCE_TERMS.contains(term)))

    def add_major_consequence(csq):
        return csq.annotate(major_consequence=hl.sorted(
            csq.consequence_terms, key=consequence_term_rank)[0])

    def add_derived_fields(csq):
        rank = consequence_term_rank(csq.major_consequence)
        category = (hl.case()
                    .when(rank <= consequence_term_rank("frameshift_variant"), "lof")
                    .when(rank <= consequence_term_rank("missense_variant"), "missense")
                    .when(rank <= consequence_term_rank("synonymous_variant"), "synonymous")
                    .default("other"))
        use_hgvsc = (hl.is_missing(csq.hgvsp)
                     | SPLICE_CONSEQUENCES.contains(csq.major_consequence))
        return csq.annotate(
            category=category,
            domains=csq.domains.map(lambda dom: dom.db + ":" + dom.name),
            hgvs=hl.cond(use_hgvsc,
                         csq.hgvsc.split(":")[-1],
                         hgvsp_from_consequence_amino_acids(csq)),
            major_consequence_rank=rank,
        )

    def priority(is_coding, is_most_severe, is_canonical):
        # 1 (best) .. 8 (worst): coding first, then most severe, then canonical.
        coding_rank = hl.cond(is_most_severe,
                              hl.cond(is_canonical, 1, 2),
                              hl.cond(is_canonical, 3, 4))
        noncoding_rank = hl.cond(is_most_severe,
                                 hl.cond(is_canonical, 5, 6),
                                 hl.cond(is_canonical, 7, 8))
        return hl.cond(is_coding, coding_rank, noncoding_rank)

    def sort_key(csq):
        return hl.bind(
            priority,
            hl.or_else(csq.biotype, "") == "protein_coding",
            hl.set(csq.consequence_terms).contains(vep_root.most_severe_consequence),
            hl.or_else(csq.canonical, 0) == 1,
        )

    annotated = (
        vep_root.transcript_consequences
        .map(drop_omitted_terms)
        .filter(lambda csq: csq.consequence_terms.size() > 0)
        .map(add_major_consequence)
        .map(add_derived_fields)
    )

    ranked = hl.sorted(annotated, key=sort_key)
    ranked = hl.zip_with_index(ranked).map(
        lambda pair: pair[1].annotate(transcript_rank=pair[0]))

    # TODO: Discard most of lof_info field
    # Keep whether lof_info contains DONOR_DISRUPTION, ACCEPTOR_DISRUPTION, or DE_NOVO_DONOR
    kept_fields = (
        "amino_acids",
        "biotype",
        "canonical",
        "category",
        "cdna_end",
        "cdna_start",
        "codons",
        "consequence_terms",
        "domains",
        "gene_id",
        "gene_symbol",
        "hgvs",
        "hgvsc",
        "hgvsp",
        "lof_filter",
        "lof_flags",
        "lof_info",
        "lof",
        "major_consequence",
        "major_consequence_rank",
        "polyphen_prediction",
        "protein_id",
        "protein_start",
        "sift_prediction",
        "transcript_id",
        "transcript_rank",
    )
    return ranked.map(lambda csq: csq.select(*kept_fields))
示例#23
0
def explode_trio_matrix(tm: hl.MatrixTable, col_keys: List[str] = ['s'], keep_trio_cols: bool = True, keep_trio_entries: bool = False) -> hl.MatrixTable:
    """Splits a trio MatrixTable back into a sample MatrixTable.

    Example
    -------
    >>> # Create a trio matrix from a sample matrix
    >>> pedigree = hl.Pedigree.read('data/case_control_study.fam')
    >>> trio_dataset = hl.trio_matrix(dataset, pedigree, complete_trios=True)

    >>> # Explode trio matrix back into a sample matrix
    >>> exploded_trio_dataset = explode_trio_matrix(trio_dataset)

    Notes
    -----
    The resulting MatrixTable column schema is the same as the proband/father/mother schema,
    and the resulting entry schema is the same as the proband_entry/father_entry/mother_entry schema.
    If the `keep_trio_cols` option is set, then an additional `source_trio` column field is added with the trio column data.
    If the `keep_trio_entries` option is set, then an additional `source_trio_entry` entry field is added with the trio entry data.

    Each trio column becomes three consecutive sample columns, in the order
    proband, father, mother. A sample that belongs to several trios therefore
    appears once per trio, so the resulting column keys need not be unique.

    Note
    ----
    This assumes that the input MatrixTable is a trio MatrixTable (similar to the result of :meth:`.methods.trio_matrix`).
    Its entry schema has to contain `proband_entry`, `father_entry` and `mother_entry`, all with the same type.
    Its column schema has to contain `proband`, `father` and `mother`, all with the same type.

    Parameters
    ----------
    tm : :class:`.MatrixTable`
        Trio MatrixTable (entries have to be a Struct with `proband_entry`, `mother_entry` and `father_entry` present)
    col_keys : :obj:`list` of str
        Column key(s) for the resulting sample MatrixTable. If empty, the
        result's columns are left unkeyed. (The list default is only read,
        never mutated.)
    keep_trio_cols: bool
        Whether to add a `source_trio` column field with the trio column data (default `True`)
    keep_trio_entries: bool
        Whether to add a `source_trio_entry` entry field with the trio entry data (default `False`)

    Returns
    -------
    :class:`.MatrixTable`
        Sample MatrixTable"""

    # Pack the three member entries into one array, so that entry i matches
    # member i after exploding columns.
    select_entries_expr = {'__trio_entries': hl.array([tm.proband_entry, tm.father_entry, tm.mother_entry])}
    if keep_trio_entries:
        select_entries_expr['source_trio_entry'] = hl.struct(**tm.entry)
    tm = tm.select_entries(**select_entries_expr)

    # Unkey the columns; they stay unkeyed until col_keys are applied below.
    tm = tm.key_cols_by()
    # (index, member struct) pairs: index 0 = proband, 1 = father, 2 = mother.
    select_cols_expr = {'__trio_members': hl.zip_with_index(hl.array([tm.proband, tm.father, tm.mother]))}
    if keep_trio_cols:
        select_cols_expr['source_trio'] = hl.struct(**tm.col)
    tm = tm.select_cols(**select_cols_expr)

    mt = tm.explode_cols(tm.__trio_members)

    # Pick the entry matching this column's member index and splat its fields.
    mt = mt.transmute_entries(
        **mt.__trio_entries[mt.__trio_members[0]]
    )

    # Splat the member's column struct into top-level column fields.
    mt = mt.transmute_cols(**mt.__trio_members[1])

    if col_keys:
        mt = mt.key_cols_by(*col_keys)

    return mt
示例#24
0
def get_expr_for_vep_sorted_transcript_consequences_array(
        vep_root,
        include_coding_annotations=True,
        omit_consequences=OMIT_CONSEQUENCE_TERMS):
    """Sort transcripts by 3 properties:

        1. coding > non-coding
        2. transcripts whose consequence terms include the variant's most severe
           consequence (``vep_root.most_severe_consequence``) > those that don't
        3. canonical > non-canonical

    so that the 1st array entry will be for the coding, most-severe, canonical transcript (assuming
    one exists).

    Consequence terms listed in ``omit_consequences`` are removed from every transcript, and
    transcripts left with no consequence terms are dropped from the array.

    Also, for each transcript in the array, computes these additional fields:
        domains: converts the Array[Struct] of domains to an array of "db:name" strings
        hgvs: computed by get_expr_for_formatted_hgvs (hgvsp if it exists, or else hgvsc;
            hgvsp is formatted for synonymous variants)
        major_consequence: most severe of the transcript's remaining (non-omitted) consequence
            terms (VEP sometimes provides multiple consequences for a single transcript)
        major_consequence_rank: rank of major_consequence in CONSEQUENCE_TERM_RANK_LOOKUP
            (lower = more severe; see http://www.ensembl.org/info/genome/variation/predicted_data.html)
        category: one of "lof", "missense", "synonymous", "other" based on the rank of
            major_consequence; terms without a rank in CONSEQUENCE_TERM_RANK_LOOKUP map to "other"
        transcript_rank: 0-based position of the transcript in the sorted array

    Args:
        vep_root (StructExpression): root path of the VEP struct in the MT
        include_coding_annotations (bool): if True, fields relevant to protein-coding variants will be included
        omit_consequences (iterable of str): consequence terms to remove from every transcript;
            a falsy value omits nothing

    Returns:
        ArrayExpression: the sorted array of transcript consequence structs
    """

    selected_annotations = [
        "biotype",
        "canonical",
        "cdna_start",
        "cdna_end",
        "codons",
        "gene_id",
        "gene_symbol",
        "hgvsc",
        "hgvsp",
        "transcript_id",
    ]

    if include_coding_annotations:
        selected_annotations.extend([
            "amino_acids",
            "lof",
            "lof_filter",
            "lof_flags",
            "lof_info",
            "polyphen_prediction",
            "protein_id",
            "protein_start",
            "sift_prediction",
        ])

    omit_consequence_terms = hl.set(
        omit_consequences) if omit_consequences else hl.empty_set(hl.tstr)

    def _select_transcript_fields(c):
        """Keep the selected fields, drop omitted terms, and derive domains/major_consequence."""
        # Filter once and reuse the result, so major_consequence is always chosen
        # from the same terms that end up in consequence_terms.
        kept_terms = c.consequence_terms.filter(
            lambda t: ~omit_consequence_terms.contains(t))
        return c.select(
            *selected_annotations,
            consequence_terms=kept_terms,
            domains=c.domains.map(
                lambda domain: domain.db + ":" + domain.name),
            major_consequence=hl.cond(
                kept_terms.size() > 0,
                # Ascending rank puts the most severe term first.
                hl.sorted(
                    kept_terms,
                    key=lambda t: CONSEQUENCE_TERM_RANK_LOOKUP.get(t))[0],
                hl.null(hl.tstr),
            ),
        )

    def _annotate_category_and_rank(c):
        """Add category, hgvs and major_consequence_rank to a selected transcript struct."""
        major_rank = CONSEQUENCE_TERM_RANK_LOOKUP.get(c.major_consequence)
        category = (
            hl.case()
            .when(major_rank <= CONSEQUENCE_TERM_RANK_LOOKUP.get("frameshift_variant"),
                  "lof")
            .when(major_rank <= CONSEQUENCE_TERM_RANK_LOOKUP.get("missense_variant"),
                  "missense")
            .when(major_rank <= CONSEQUENCE_TERM_RANK_LOOKUP.get("synonymous_variant"),
                  "synonymous")
            .default("other")
        )
        return c.annotate(
            # An unranked term makes every predicate missing, which would make
            # hl.case() return missing; fall back to "other" as documented.
            category=hl.or_else(category, "other"),
            hgvs=get_expr_for_formatted_hgvs(c),
            major_consequence_rank=major_rank,
        )

    def _transcript_sort_key(c):
        """Return 1 (best) .. 8 (worst): coding first, then most-severe, then canonical."""
        return hl.bind(
            lambda is_coding, is_most_severe, is_canonical: hl.cond(
                is_coding,
                hl.cond(is_most_severe, hl.cond(is_canonical, 1, 2),
                        hl.cond(is_canonical, 3, 4)),
                hl.cond(is_most_severe, hl.cond(is_canonical, 5, 6),
                        hl.cond(is_canonical, 7, 8)),
            ),
            hl.or_else(c.biotype, "") == "protein_coding",
            hl.set(c.consequence_terms).contains(
                vep_root.most_severe_consequence),
            hl.or_else(c.canonical, 0) == 1,
        )

    result = hl.sorted(
        vep_root.transcript_consequences
        .map(_select_transcript_fields)
        .filter(lambda c: c.consequence_terms.size() > 0)
        .map(_annotate_category_and_rank),
        key=_transcript_sort_key,
    )

    if not include_coding_annotations:
        # for non-coding variants, drop fields here that are hard to exclude in the above code
        result = result.map(lambda c: c.drop("domains", "hgvsp"))

    # zip_with_index yields (index, struct) pairs; record the index as transcript_rank.
    return hl.zip_with_index(result).map(
        lambda csq_with_index: csq_with_index[1].annotate(
            transcript_rank=csq_with_index[0]))