Example No. 1
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], '')).fold(
                lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref: hl.rbind(
                alleles.map(lambda al: hl.rbind(
                    al[0],
                    lambda r: hl.array([ref]).extend(
                        al[1:].map(lambda a: hl.rbind(
                            _num_allele_type(r, a),
                            lambda at: hl.cond(
                                (_allele_ints['SNP'] == at)
                                | (_allele_ints['Insertion'] == at)
                                | (_allele_ints['Deletion'] == at)
                                | (_allele_ints['MNP'] == at)
                                | (_allele_ints['Complex'] == at),
                                a + ref[hl.len(r):],
                                a)))))),
                lambda lal: hl.struct(
                    globl=hl.array([ref]).extend(
                        hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl: hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles: hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index: hl.range(0, hl.len(row.data)).flatmap(
                            lambda i: hl.cond(
                                hl.is_missing(row.data[i].__entries),
                                hl.range(0, hl.len(gbl.g[i].__cols)).map(
                                    lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                hl.bind(
                                    lambda old_to_new: row.data[i].__entries.map(
                                        lambda e: renumber_entry(e, old_to_new)),
                                    hl.range(0, hl.len(alleles.local[i])).map(
                                        lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(
                            hl.range(0, hl.len(alleles.globl)).map(
                                lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(
        TableMapRows(
            ts._tir,
            Apply(merge_function._name, merge_function._ret_type,
                  TopLevelReference('row'), TopLevelReference('global'))))
    return ts.transmute_globals(
        __cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
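For context, this helper expects the output of a multi-way zip join of localized gvcf tables, with per-table row data under `data` and per-table globals under `g`. A sketch of that calling convention (assuming `transformed_mts` holds MatrixTables already prepared by a transform step, and `import hail as hl`):

ts = hl.Table.multi_way_zip_join(
    [mt.localize_entries('__entries', '__cols') for mt in transformed_mts],
    'data', 'g')
combined = combine(ts)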
Example No. 2
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        filters=hl.set(hl.flatten(ts.data.map(lambda d: hl.array(d.filters)))),
        info=hl.struct(
            DP=hl.sum(ts.data.map(lambda d: d.info.DP)),
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                          .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(lambda e: renumber_entry(e, old_to_new)),
                            hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                lambda j: combined_allele_index[tmp.data[i].alleles[j]])))),
            hl.dict(hl.range(0, hl.len(tmp.alleles)).map(
                lambda j: hl.tuple([tmp.alleles[j], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))

    return tmp.drop('data', 'g')
Example No. 3
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        info=hl.struct(
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB_TABLE=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                          .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(lambda e: renumber_entry(e, old_to_new)),
                            hl.array([0]).extend(
                                hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                    lambda j: combined_allele_index[tmp.data[i].alleles[j]]))))),
            hl.dict(hl.range(1, hl.len(tmp.alleles) + 1).map(
                lambda j: hl.tuple([tmp.alleles[j - 1], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))

    return tmp.drop('data', 'g')
Example No. 4
def get_lm_prediction_expr(metric: str):
    lm_pred_expr = _sample_qc_ht.lms[metric].beta[0] + hl.sum(
        hl.range(n_pcs).map(lambda i: _sample_qc_ht.lms[metric].beta[i + 1]
                            * _sample_qc_ht.scores[i]))
    if use_pc_square:
        lm_pred_expr = lm_pred_expr + hl.sum(
            hl.range(n_pcs).map(
                lambda i: _sample_qc_ht.lms[metric].beta[i + n_pcs + 1] *
                _sample_qc_ht.scores[i] * _sample_qc_ht.scores[i]))
    return lm_pred_expr
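The prediction is just an intercept plus a dot product of betas with PC scores (optionally plus squared-PC terms). A minimal sketch of the same arithmetic with literal values, assuming `import hail as hl` (names are illustrative):

import hail as hl

beta = hl.literal([1.0, 0.5, -0.25])   # [intercept, beta_PC1, beta_PC2]
scores = hl.literal([2.0, 4.0])        # PC scores for one sample
n_pcs = 2

pred = beta[0] + hl.sum(hl.range(n_pcs).map(lambda i: beta[i + 1] * scores[i]))
assert hl.eval(pred) == 1.0 + 0.5 * 2.0 + (-0.25) * 4.0  # == 1.0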
Example No. 5
        def with_local_a_index(local_a_index):
            new_pl = hl.or_missing(
                hl.is_defined(old_entry.LPL),
                hl.or_missing(
                    hl.is_defined(local_a_index),
                    hl.range(0, 3).map(lambda i: hl.min(
                        hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                            .filter(lambda j: hl.downcode(
                                hl.unphased_diploid_gt_index_call(j),
                                local_a_index) == hl.unphased_diploid_gt_index_call(i))
                            .map(lambda idx: old_entry.LPL[idx])))))
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(
                        old_entry.LGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(
                        old_entry.LPGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD), [
                            old_entry.LAD[0],
                            hl.or_else(old_entry.LAD[local_a_index], 0)
                        ])  # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl),
                                                     old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(
                    hl.len(ds.alleles) == 1,
                    old_entry.annotate(
                        **{
                            f[1:]: old_entry[f]
                            for f in ['LGT', 'LPGT', 'LAD', 'LPL']
                            if f in fields
                        }).drop(*dropped_fields),
                    old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)
Example No. 6
def get_lowqual_expr(
    alleles: hl.expr.ArrayExpression,
    qual_approx_expr: Union[hl.expr.ArrayNumericExpression, hl.expr.NumericExpression],
    snv_phred_threshold: int = 30,
    snv_phred_het_prior: int = 30,  # 1/1000
    indel_phred_threshold: int = 30,
    indel_phred_het_prior: int = 39,  # 1/8,000
) -> Union[hl.expr.BooleanExpression, hl.expr.ArrayExpression]:
    """
    Compute a lowqual threshold expression for either split or unsplit alleles based on QUALapprox or AS_QUALapprox.

    .. note::

        When run using QUALapprox, this lowqual annotation differs from the GATK LowQual filter.
        This is because GATK computes the annotation at the site level, which uses the least stringent prior for mixed sites.
        When run using AS_QUALapprox, this implementation can thus be more stringent for certain alleles at mixed sites.

    :param alleles: Array of alleles
    :param qual_approx_expr: QUALapprox or AS_QUALapprox
    :param snv_phred_threshold: Phred-scaled SNV "emission" threshold (similar to GATK emission threshold)
    :param snv_phred_het_prior: Phred-scaled SNV heterozygosity prior (30 = 1/1,000 bases, GATK default)
    :param indel_phred_threshold: Phred-scaled indel "emission" threshold (similar to GATK emission threshold)
    :param indel_phred_het_prior: Phred-scaled indel heterozygosity prior (39 = 1/8,000 bases)
    :return: lowqual expression (BooleanExpression if `qual_approx_expr` is numeric, ArrayExpression of BooleanExpression if `qual_approx_expr` is array-numeric)
    """

    min_snv_qual = snv_phred_threshold + snv_phred_het_prior
    min_indel_qual = indel_phred_threshold + indel_phred_het_prior
    min_mixed_qual = max(min_snv_qual, min_indel_qual)

    if isinstance(qual_approx_expr, hl.expr.ArrayNumericExpression):
        return hl.range(1, hl.len(alleles)).map(
            lambda ai: hl.cond(
                hl.is_snp(alleles[0], alleles[ai]),
                qual_approx_expr[ai - 1] < min_snv_qual,
                qual_approx_expr[ai - 1] < min_indel_qual,
            )
        )
    else:
        return (
            hl.case()
            .when(
                hl.range(1, hl.len(alleles)).all(
                    lambda ai: hl.is_snp(alleles[0], alleles[ai])
                ),
                qual_approx_expr < min_snv_qual,
            )
            .when(
                hl.range(1, hl.len(alleles)).all(
                    lambda ai: hl.is_indel(alleles[0], alleles[ai])
                ),
                qual_approx_expr < min_indel_qual,
            )
            .default(qual_approx_expr < min_mixed_qual)
        )
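With the defaults, a biallelic SNV is flagged lowqual when QUALapprox < 60 (30 emission threshold + 30 het prior) and an indel when QUALapprox < 69. A quick check, assuming the function above is in scope and `import hail as hl`:

import hail as hl

assert hl.eval(get_lowqual_expr(hl.literal(['A', 'T']), hl.literal(59))) is True
assert hl.eval(get_lowqual_expr(hl.literal(['A', 'T']), hl.literal(60))) is False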
Example No. 7
def filter_samples(vds: 'VariantDataset', samples_table: 'Table', *,
                   keep: bool = True,
                   remove_dead_alleles: bool = False) -> 'VariantDataset':
    """Filter samples in a :class:`.VariantDataset`.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    samples_table : :class:`.Table`
        Samples to filter on.
    keep : :obj:`bool`
        Whether to keep (default), or filter out the samples from `samples_table`.
    remove_dead_alleles : :obj:`bool`
        If true, remove alleles observed in no samples. Alleles with AC == 0 will be
        removed, and LA values recalculated.

    Returns
    -------
    :class:`.VariantDataset`
    """
    if [samples_table[x].dtype for x in samples_table.key] != [hl.tstr]:
        raise TypeError(f'invalid key: {samples_table.key.dtype}')
    samples_to_keep = samples_table.aggregate(hl.agg.collect_as_set(samples_table.key[0]), _localize=False)._persist()
    reference_data = vds.reference_data.filter_cols(samples_to_keep.contains(vds.reference_data.col_key[0]), keep=keep)
    reference_data = reference_data.filter_rows(hl.agg.count() > 0)
    variant_data = vds.variant_data.filter_cols(samples_to_keep.contains(vds.variant_data.col_key[0]), keep=keep)

    if remove_dead_alleles:
        vd = variant_data
        vd = vd.annotate_rows(__allele_counts=hl.agg.explode(lambda x: hl.agg.counter(x), vd.LA), __n=hl.agg.count())
        vd = vd.filter_rows(vd.__n > 0)

        vd = vd.annotate_rows(__kept_indices=hl.dict(
            hl.enumerate(
                hl.range(hl.len(vd.alleles)).filter(lambda idx: (idx == 0) | (vd.__allele_counts.get(idx, 0) > 0)),
                index_first=False)))

        vd = vd.annotate_rows(
            __old_to_new_LA=hl.range(hl.len(vd.alleles)).map(lambda idx: vd.__kept_indices.get(idx, -1)))

        def new_la_index(old_idx):
            raw_idx = vd.__old_to_new_LA[old_idx]
            return hl.case().when(raw_idx >= 0, raw_idx) \
                .or_error("'filter_samples': unexpected local allele: old index=" + hl.str(old_idx))

        vd = vd.annotate_entries(LA=vd.LA.map(lambda la: new_la_index(la)))
        vd = vd.key_rows_by('locus')
        vd = vd.annotate_rows(alleles=vd.__kept_indices.keys().map(lambda i: vd.alleles[i]))
        vd = vd._key_rows_by_assert_sorted('locus', 'alleles')
        vd = vd.drop('__allele_counts', '__kept_indices', '__old_to_new_LA')
        return VariantDataset(reference_data, vd)

    variant_data = variant_data.filter_rows(hl.agg.count() > 0)
    return VariantDataset(reference_data, variant_data)
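A usage sketch (the path and sample IDs are illustrative), assuming `import hail as hl`:

import hail as hl

vds = hl.vds.read_vds('gs://my-bucket/dataset.vds')
samples = hl.Table.parallelize([{'s': 'NA12878'}, {'s': 'NA12891'}],
                               hl.tstruct(s=hl.tstr), key='s')
vds = hl.vds.filter_samples(vds, samples, keep=True, remove_dead_alleles=True)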
Example No. 8
    def test_explode_rows(self):
        mt = hl.utils.range_matrix_table(4, 4)
        mt = mt.annotate_entries(e=mt.row_idx * 10 + mt.col_idx)

        self.assertTrue(mt.annotate_rows(x=[1]).explode_rows('x').drop('x')._same(mt))

        self.assertEqual(mt.annotate_rows(x=hl.empty_array('int')).explode_rows('x').count_rows(), 0)
        self.assertEqual(mt.annotate_rows(x=hl.null('array<int>')).explode_rows('x').count_rows(), 0)
        self.assertEqual(mt.annotate_rows(x=hl.range(0, mt.row_idx)).explode_rows('x').count_rows(), 6)
        mt = mt.annotate_rows(x=hl.struct(y=hl.range(0, mt.row_idx)))
        self.assertEqual(mt.explode_rows(mt.x.y).count_rows(), 6)
Example No. 10
def get_group_to_counts_expr(k: hl.expr.StructExpression, counts: hl.expr.DictExpression) -> hl.expr.ArrayExpression:
    return hl.range(1, k.snv - 1, step=-1).flatmap(
        lambda snv: hl.range(0, k.all + 1).flatmap(
            lambda af: hl.range(0, k.csq + 1).map(
                lambda csq: hl.struct(snv=hl.bool(snv), all=hl.bool(af), csq=csq)
            )
        )
    ).filter(
        lambda key: counts.contains(key)
    ).map(
        lambda key: counts[key]
    )
Example No. 11
def mwzj_hts_by_tree(all_hts,
                     temp_dir,
                     globals_for_col_key,
                     debug=False,
                     inner_mode='overwrite',
                     repartition_final: int = None):
    chunk_size = int(len(all_hts)**0.5) + 1
    outer_hts = []

    checkpoint_kwargs = {inner_mode: True}
    if repartition_final is not None:
        intervals = get_n_even_intervals(repartition_final)
        checkpoint_kwargs['_intervals'] = intervals

    if debug: print(f'Running chunk size {chunk_size}...')
    for i in range(chunk_size):
        if i * chunk_size >= len(all_hts): break
        hts = all_hts[i * chunk_size:(i + 1) * chunk_size]
        if debug:
            print(
                f'Going from {i * chunk_size} to {(i + 1) * chunk_size} ({len(hts)} HTs)...'
            )
        try:
            if isinstance(hts[0], str):
                hts = list(map(lambda x: hl.read_table(x), hts))
            ht = hl.Table.multi_way_zip_join(hts, 'row_field_name',
                                             'global_field_name')
        except:
            if debug:
                print(
                    f'problem in range {i * chunk_size}-{i * chunk_size + chunk_size}'
                )
                _ = [ht.describe() for ht in hts]
            raise
        outer_hts.append(
            ht.checkpoint(f'{temp_dir}/temp_output_{i}.ht',
                          **checkpoint_kwargs))
    ht = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer',
                                     'global_field_name_outer')
    ht = ht.transmute(inner_row=hl.flatmap(
        lambda i: hl.cond(
            hl.is_missing(ht.row_field_name_outer[i].row_field_name),
            hl.range(0, hl.len(ht.global_field_name_outer[i].global_field_name)).map(
                lambda _: hl.null(
                    ht.row_field_name_outer[i].row_field_name.dtype.element_type)),
            ht.row_field_name_outer[i].row_field_name),
        hl.range(hl.len(ht.global_field_name_outer))))
    ht = ht.transmute_globals(inner_global=hl.flatmap(
        lambda x: x.global_field_name, ht.global_field_name_outer))
    mt = ht._unlocalize_entries('inner_row', 'inner_global',
                                globals_for_col_key)
    return mt
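A usage sketch, assuming the helper above is in scope (paths and the column-key field are illustrative): it zip-joins the tables in square-root-sized chunks, checkpointing each chunk before the final join.

mt = mwzj_hts_by_tree(
    all_hts=[f'gs://bucket/pheno_{i}.ht' for i in range(100)],
    temp_dir='gs://bucket/tmp/mwzj',
    globals_for_col_key=['phenotype'],
    repartition_final=1000)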
Example No. 12
def test_filter_alleles(self):
    # poor man's Gen
    paths = [resource('sample.vcf'),
             resource('multipleChromosomes.vcf'),
             resource('sample2.vcf')]
    for path in paths:
        ds = hl.import_vcf(path)
        self.assertEqual(
            hl.FilterAlleles(hl.range(0, ds.alleles.length() - 1).map(lambda i: False))
                .filter()
                .count_rows(), 0)
        self.assertEqual(
            hl.FilterAlleles(hl.range(0, ds.alleles.length() - 1).map(lambda i: True))
                .filter()
                .count_rows(), ds.count_rows())
Example No. 13
def apply_mito_artifact_filter(
    mt: hl.MatrixTable,
    artifact_prone_sites_path: str,
) -> hl.MatrixTable:
    """Add back in artifact_prone_site filter

    :param hl.MatrixTable mt: MatrixTable to use as input
    :param str artifact_prone_sites_path: path to BED file of artifact_prone_sites to flag in the filters column

    :return: MatrixTable with artifact_prone_sites filter
    :rtype: hl.MatrixTable

    """

    # apply "artifact_prone_site" filter to any SNP or deletion that spans a known problematic site
    mt = mt.annotate_rows(
        position_range=hl.range(mt.locus.position, mt.locus.position +
                                hl.len(mt.alleles[0])))

    artifact_sites = []
    with hl.hadoop_open(artifact_prone_sites_path) as f:
        for line in f:
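            # the third whitespace-delimited column of a BED line is the interval end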
            pos = line.split()[2]
            artifact_sites.append(int(pos))
    sites = hl.literal(set(artifact_sites))

    mt = mt.annotate_rows(filters=hl.if_else(
        hl.len(hl.set(mt.position_range).intersection(sites)) > 0,
        {"artifact_prone_site"},
        {"PASS"},
    ))

    mt = mt.drop("position_range")

    return mt
Example No. 14
def pop_max_expr(
    freq: hl.expr.ArrayExpression,
    freq_meta: hl.expr.ArrayExpression,
    pops_to_exclude: Optional[Set[str]] = None,
) -> hl.expr.StructExpression:
    """
    Creates an expression containing popmax: the frequency information about the population
    that has the highest AF from the populations provided in `freq_meta`,
    excluding those specified in `pops_to_exclude`.
    Only frequencies from adj populations are considered.

    The resulting struct contains the following fields:

        - AC: int32
        - AF: float64
        - AN: int32
        - homozygote_count: int32
        - pop: str

    :param freq: ArrayExpression of Structs with fields ['AC', 'AF', 'AN', 'homozygote_count']
    :param freq_meta: ArrayExpression of meta dictionaries corresponding to freq (as returned by annotate_freq)
    :param pops_to_exclude: Set of populations to skip for popmax calculation

    :return: Popmax struct
    """
    _pops_to_exclude = (
        hl.literal(pops_to_exclude) if pops_to_exclude else hl.empty_set(hl.tstr))
    popmax_freq_indices = hl.range(0, hl.len(freq_meta)).filter(
        lambda i: (hl.set(freq_meta[i].keys()) == {"group", "pop"})
        & (freq_meta[i]["group"] == "adj")
        & (~_pops_to_exclude.contains(freq_meta[i]["pop"])))
    freq_filtered = popmax_freq_indices.map(lambda i: freq[i].annotate(
        pop=freq_meta[i]["pop"])).filter(lambda f: f.AC > 0)

    sorted_freqs = hl.sorted(freq_filtered, key=lambda x: x.AF, reverse=True)
    return hl.or_missing(hl.len(sorted_freqs) > 0, sorted_freqs[0])
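A toy evaluation, assuming the function above is in scope and `import hail as hl`; the freq/freq_meta shapes mirror the output of annotate_freq:

import hail as hl

freq = hl.array([
    hl.struct(AC=10, AF=0.10, AN=100, homozygote_count=1),  # adj, pop=afr
    hl.struct(AC=2, AF=0.20, AN=10, homozygote_count=0),    # adj, pop=nfe
])
freq_meta = hl.literal([{'group': 'adj', 'pop': 'afr'},
                        {'group': 'adj', 'pop': 'nfe'}])
pm = hl.eval(pop_max_expr(freq, freq_meta, pops_to_exclude={'oth'}))
assert pm.pop == 'nfe' and pm.AF == 0.20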
Example No. 15
def transform_one(mt: MatrixTable) -> MatrixTable:
    """transforms a gvcf into a form suitable for combining"""
    mt = mt.annotate_entries(
        # local (alt) allele index into global (alt) alleles
        LA=hl.range(0, hl.len(mt.alleles)),
        END=mt.info.END,
        BaseQRankSum=mt.info['BaseQRankSum'],
        ClippingRankSum=mt.info['ClippingRankSum'],
        MQ=mt.info['MQ'],
        MQRankSum=mt.info['MQRankSum'],
        ReadPosRankSum=mt.info['ReadPosRankSum'],
    )
    mt = mt.annotate_rows(
        info=mt.info.annotate(
            SB_TABLE=hl.array([
                hl.agg.sum(mt.entry.SB[0]),
                hl.agg.sum(mt.entry.SB[1]),
                hl.agg.sum(mt.entry.SB[2]),
                hl.agg.sum(mt.entry.SB[3]),
            ])
        ).select(
            "MQ_DP",
            "QUALapprox",
            "RAW_MQ",
            "VarDP",
            "SB_TABLE",
        ))
    mt = mt.transmute_entries(
        LGT=mt.GT,
        LAD=mt.AD[0:],  # requiredness issues :'(
        LPL=mt.PL[0:],
        LPGT=mt.PGT)
    mt = mt.drop('SB', 'qual', 'filters')

    return mt
Example No. 16
def transform_one(mt: MatrixTable) -> MatrixTable:
    """transforms a gvcf into a form suitable for combining"""
    mt = mt.annotate_entries(
        # local (alt) allele index into global (alt) alleles
        LA=hl.range(0,
                    hl.len(mt.alleles) - 1),
        END=mt.info.END,
        PL=mt['PL'][0:],
        BaseQRankSum=mt.info['BaseQRankSum'],
        ClippingRankSum=mt.info['ClippingRankSum'],
        MQ=mt.info['MQ'],
        MQRankSum=mt.info['MQRankSum'],
        ReadPosRankSum=mt.info['ReadPosRankSum'],
    )
    mt = mt.annotate_rows(info=mt.info.annotate(
        DP=hl.agg.sum(mt.entry.DP),
        SB=hl.agg.array_sum(mt.entry.SB),
    ).select(
        "DP",
        "MQ_DP",
        "QUALapprox",
        "RAW_MQ",
        "VarDP",
        "SB",
    ))
    mt = mt.drop('SB', 'qual')

    return mt
Example No. 17
def full(shape, value):
    if isinstance(shape, Int64Expression):
        shape_product = shape
    else:
        shape_product = reduce(lambda a, b: a * b, shape)
    return array(hl.range(
        hl.int32(shape_product)).map(lambda x: value)).reshape(shape)
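This helper mirrors the public hl.nd.full; a quick check of the same behavior, assuming `import hail as hl` and `import numpy as np`:

import hail as hl
import numpy as np

assert np.array_equal(hl.eval(hl.nd.full((2, 3), 7)), np.full((2, 3), 7))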
Example No. 18
def compute_element(absolute):
    return hl.rbind(
        absolute % n_rows,
        absolute // n_rows,
        lambda row, col: hl.range(hl.int(n_inner)).map(
            lambda inner: multiply(left[row, inner], right[inner, col])
        ).fold(add, zero))
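The `absolute % n_rows` / `absolute // n_rows` pair decodes a flat column-major index into (row, col); a plain-Python check with n_rows = 2:

n_rows = 2
assert [(a % n_rows, a // n_rows) for a in range(4)] == [(0, 0), (1, 0), (0, 1), (1, 1)]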
Example No. 19
def set_female_y_metrics_to_na_expr(
        t: Union[hl.Table, hl.MatrixTable]) -> hl.expr.ArrayExpression:
    """
    Set Y-variant frequency callstats for female-specific metrics to missing structs.

    .. note:: Requires freq, freq_meta, and freq_index_dict annotations to be present in Table or MatrixTable

    :param t: Table or MatrixTable for which to adjust female metrics
    :return: Hail array expression to set female Y-variant metrics to missing values
    """
    female_idx = hl.map(
        lambda x: t.freq_index_dict[x],
        hl.filter(lambda x: x.contains("XX"), t.freq_index_dict.keys()),
    )
    freq_idx_range = hl.range(hl.len(t.freq_meta))

    new_freq_expr = hl.if_else(
        (t.locus.in_y_nonpar() | t.locus.in_y_par()),
        hl.map(
            lambda x: hl.if_else(female_idx.contains(x),
                                 missing_callstats_expr(), t.freq[x]),
            freq_idx_range,
        ),
        t.freq,
    )

    return new_freq_expr
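A usage sketch, assuming a Table `ht` carrying the required freq, freq_meta, and freq_index_dict annotations:

ht = ht.annotate(freq=set_female_y_metrics_to_na_expr(ht))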
Example No. 20
def test_ndarray_eval():
    data_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    nd_expr = hl._ndarray(data_list)
    evaled = hl.eval(nd_expr)
    np_equiv = np.array(data_list, dtype=np.int32)
    assert(np.array_equal(evaled, np_equiv))
    assert(evaled.strides == np_equiv.strides)

    assert hl.eval(hl._ndarray([[], []])).strides == (8, 8)
    assert np.array_equal(hl.eval(hl._ndarray([])), np.array([]))

    zero_array = np.zeros((10, 10), dtype=np.int64)
    evaled_zero_array = hl.eval(hl.literal(zero_array))

    assert np.array_equal(evaled_zero_array, zero_array)
    assert zero_array.dtype == evaled_zero_array.dtype

    # Testing from hail arrays
    assert np.array_equal(hl.eval(hl._ndarray(hl.range(6))), np.arange(6))
    assert np.array_equal(hl.eval(hl._ndarray(hl.int64(4))), np.array(4))

    # Testing missing data
    assert hl.eval(hl._ndarray(hl.null(hl.tarray(hl.tint32)))) is None

    with pytest.raises(ValueError) as exc:
        hl._ndarray([[4], [1, 2, 3], 5])
    assert "inner dimensions do not match" in str(exc.value)
Example No. 21
def multi_way_union_mts(mts: list, tmp_dir: str,
                        chunk_size: int) -> hl.MatrixTable:
    """Joins MatrixTables in the provided list

    :param list mts: list of MatrixTables to join together
    :param str tmp_dir: path to temporary directory for intermediate results
    :param int chunk_size: number of MatrixTables to join per chunk

    :return: joined MatrixTable
    :rtype: MatrixTable
    """

    staging = [mt.localize_entries("__entries", "__cols") for mt in mts]
    stage = 0
    while len(staging) > 1:
        n_jobs = int(math.ceil(len(staging) / chunk_size))
        info(f"multi_way_union_mts: stage {stage}: {n_jobs} total jobs")
        next_stage = []
        for i in range(n_jobs):
            to_merge = staging[chunk_size * i:chunk_size * (i + 1)]
            info(
                f"multi_way_union_mts: stage {stage} / job {i}: merging {len(to_merge)} inputs"
            )

            merged = hl.Table.multi_way_zip_join(to_merge, "__entries",
                                                 "__cols")
            merged = merged.annotate(__entries=hl.flatten(
                hl.range(hl.len(merged.__entries)).map(lambda i: hl.coalesce(
                    merged.__entries[i].__entries,
                    hl.range(hl.len(merged.__cols[i].__cols)).map(
                        lambda j: hl.null(
                            merged.__entries.__entries.dtype.element_type.element_type)),
                ))))
            merged = merged.annotate_globals(
                __cols=hl.flatten(merged.__cols.map(lambda x: x.__cols)))

            next_stage.append(
                merged.checkpoint(os.path.join(tmp_dir,
                                               f"stage_{stage}_job_{i}.ht"),
                                  overwrite=True))
        info(f"done stage {stage}")
        stage += 1
        staging.clear()
        staging.extend(next_stage)

    return (staging[0]._unlocalize_entries(
        "__entries", "__cols", list(mts[0].col_key)).unfilter_entries())
Example No. 22
def mwzj_hts_by_tree(all_hts, temp_dir, globals_for_col_key, 
                     debug=False, inner_mode = 'overwrite', repartition_final: int = None,
                     read_if_exists = False):
    r'''
    Adapted from ukb_common mwzj_hts_by_tree()
    Uses read_clump_ht() instead of read_table()
    '''
    chunk_size = int(len(all_hts) ** 0.5) + 1
    outer_hts = []
    
    if read_if_exists: print('\n\nWARNING: Intermediate tables will not be overwritten if they already exist\n\n')
    
    checkpoint_kwargs = {inner_mode: not read_if_exists,
                         '_read_if_exists': read_if_exists}
    if repartition_final is not None:
        intervals = ukb_common.get_n_even_intervals(repartition_final)
        checkpoint_kwargs['_intervals'] = intervals
    
    if debug: print(f'Running chunk size {chunk_size}...')
    for i in range(chunk_size):
        if i * chunk_size >= len(all_hts): break
        hts = all_hts[i * chunk_size:(i + 1) * chunk_size]
        if debug: print(f'Going from {i * chunk_size} to {(i + 1) * chunk_size} ({len(hts)} HTs)...')
        try:
            if isinstance(hts[0], str):
                def read_clump_ht(f):
                    ht = hl.read_table(f)
                    ht = ht.drop('idx')
                    return ht
                hts = list(map(read_clump_ht, hts))
            ht = hl.Table.multi_way_zip_join(hts, 'row_field_name', 'global_field_name')
        except:
            if debug:
                print(f'problem in range {i * chunk_size}-{i * chunk_size + chunk_size}')
                _ = [ht.describe() for ht in hts]
            raise
        outer_hts.append(ht.checkpoint(f'{temp_dir}/temp_output_{i}.ht', **checkpoint_kwargs))
    ht = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer', 'global_field_name_outer')
    ht = ht.transmute(inner_row=hl.flatmap(lambda i:
                                           hl.cond(hl.is_missing(ht.row_field_name_outer[i].row_field_name),
                                                   hl.range(0, hl.len(ht.global_field_name_outer[i].global_field_name))
                                                   .map(lambda _: hl.null(ht.row_field_name_outer[i].row_field_name.dtype.element_type)),
                                                   ht.row_field_name_outer[i].row_field_name),
                                           hl.range(hl.len(ht.global_field_name_outer))))
    ht = ht.transmute_globals(inner_global=hl.flatmap(lambda x: x.global_field_name, ht.global_field_name_outer))
    mt = ht._unlocalize_entries('inner_row', 'inner_global', globals_for_col_key)
    return mt
Example No. 23
def pre_process_subset_freq(subset: str,
                            global_ht: hl.Table,
                            test: bool = False) -> hl.Table:
    """
    Prepare subset frequency Table by filling in missing frequency fields for loci present only in the global cohort.

    .. note::

        The resulting final `freq` array will be as long as the subset `freq_meta` global (i.e., one `freq` entry for each `freq_meta` entry)

    :param subset: subset ID
    :param global_ht: Hail Table containing all variants discovered in the overall release cohort
    :param test: If True, filter to small region on chr20
    :return: Table containing subset frequencies with missing freq structs filled in
    """

    # Read in subset HTs
    subset_ht_path = get_freq(subset=subset).path
    subset_chr20_ht_path = qc_temp_prefix() + f"chr20_test_freq.{subset}.ht"

    if test:
        if file_exists(subset_chr20_ht_path):
            logger.info(
                "Loading chr20 %s subset frequency data for testing: %s",
                subset,
                subset_chr20_ht_path,
            )
            subset_ht = hl.read_table(subset_chr20_ht_path)

        elif file_exists(subset_ht_path):
            logger.info(
                "Loading %s subset frequency data for testing: %s",
                subset,
                subset_ht_path,
            )
            subset_ht = hl.read_table(subset_ht_path)
            subset_ht = hl.filter_intervals(
                subset_ht, [hl.parse_locus_interval("chr20:1-1000000")])

    elif file_exists(subset_ht_path):
        logger.info("Loading %s subset frequency data: %s", subset,
                    subset_ht_path)
        subset_ht = hl.read_table(subset_ht_path)

    else:
        raise DataException(
            f"Hail Table containing {subset} subset frequencies not found. You may need to run the script generate_freq_data.py to generate frequency annotations first."
        )

    # Fill in missing freq structs
    ht = subset_ht.join(global_ht.select().select_globals(), how="right")
    ht = ht.annotate(freq=hl.if_else(
        hl.is_missing(ht.freq),
        hl.map(lambda x: missing_callstats_expr(),
               hl.range(hl.len(ht.freq_meta))),
        ht.freq,
    ))

    return ht
Example No. 24
def default_compute_info(mt: hl.MatrixTable,
                         site_annotations: bool = False,
                         n_partitions: int = 5000) -> hl.Table:
    """
    Computes an HT with the typical GATK allele-specific (AS) info fields
    as well as ACs and lowqual fields.
    Note that this table doesn't split multi-allelic sites.

    :param mt: Input MatrixTable. Note that this table should be filtered to nonref sites.
    :param site_annotations: Whether to also generate site level info fields. Default is False.
    :param n_partitions: Number of desired partitions for output Table. Default is 5000.
    :return: Table with info fields
    :rtype: Table
    """
    # Move gvcf info entries out from nested struct
    mt = mt.transmute_entries(**mt.gvcf_info)

    # Compute AS info expr
    info_expr = get_as_info_expr(mt)

    if site_annotations:
        info_expr = info_expr.annotate(**get_site_info_expr(mt))

    # Add AC and AC_raw:
    # First compute ACs for each non-ref allele, grouped by adj
    grp_ac_expr = hl.agg.array_agg(
        lambda ai: hl.agg.filter(
            mt.LA.contains(ai),
            hl.agg.group_by(
                get_adj_expr(mt.LGT, mt.GQ, mt.DP, mt.LAD),
                hl.agg.sum(
                    mt.LGT.one_hot_alleles(mt.LA.map(lambda x: hl.str(x)))[
                        mt.LA.index(ai)]),
            ),
        ),
        hl.range(1, hl.len(mt.alleles)),
    )

    # Then, for each non-ref allele, compute
    # AC as the adj group
    # AC_raw as the sum of adj and non-adj groups
    info_expr = info_expr.annotate(
        AC_raw=grp_ac_expr.map(
            lambda i: hl.int32(i.get(True, 0) + i.get(False, 0))),
        AC=grp_ac_expr.map(lambda i: hl.int32(i.get(True, 0))),
    )

    info_ht = mt.select_rows(info=info_expr).rows()

    # Add AS lowqual flag
    info_ht = info_ht.annotate(AS_lowqual=get_lowqual_expr(
        info_ht.alleles, info_ht.info.AS_QUALapprox))

    if site_annotations:
        # Add lowqual flag
        info_ht = info_ht.annotate(
            lowqual=get_lowqual_expr(info_ht.alleles, info_ht.info.QUALapprox))

    return info_ht.naive_coalesce(n_partitions)
Example No. 25
def block_product(left, right):
    product = left @ right
    n_rows, n_cols = product.shape
    return hl.struct(
        shape=product.shape,
        block=hl.range(hl.int(n_rows * n_cols)).map(
            lambda absolute: product[absolute % n_rows, absolute // n_rows]))
Example No. 26
def pc_relate_5k_5k(mt_path):
    mt = hl.read_matrix_table(mt_path)
    mt = mt.annotate_cols(scores=hl.range(2).map(lambda x: hl.rand_unif(0, 1)))
    rel = hl.pc_relate(mt.GT,
                       0.05,
                       scores_expr=mt.scores,
                       statistics='kin',
                       min_kinship=0.05)
    rel._force_count()
Example No. 27
def project_max_expr(
    project_expr: hl.expr.StringExpression,
    gt_expr: hl.expr.CallExpression,
    alleles_expr: hl.expr.ArrayExpression,
    n_projects: int = 5,
) -> hl.expr.ArrayExpression:
    """
    Create an expression that computes allele frequency information by project for the `n_projects` with the largest AF at this row.

    Will return an array with one element per non-reference allele.

    Each of these elements is itself an array of structs with the following fields:

        - AC: int32
        - AF: float64
        - AN: int32
        - homozygote_count: int32
        - project: str

    .. note::

        Only projects with AF > 0 are returned.
        In case of ties, the project ordering is not guaranteed, and at most `n_projects` are returned.

    :param project_expr: column expression containing the project
    :param gt_expr: entry expression containing the genotype
    :param alleles_expr: row expression containing the alleles
    :param n_projects: Maximum number of projects to return for each row
    :return: projectmax expression
    """
    n_alleles = hl.len(alleles_expr)

    # compute call stats by project
    project_cs = hl.array(
        hl.agg.group_by(project_expr, hl.agg.call_stats(gt_expr,
                                                        alleles_expr)))

    return hl.or_missing(
        n_alleles > 1,  # Exclude monomorphic sites
        hl.range(1, n_alleles).map(lambda ai: hl.sorted(
            project_cs.filter(
                # filter to projects with AF > 0
                lambda x: x[1].AF[ai] > 0),
            # order the callstats computed by AF in decreasing order
            lambda x: -x[1].AF[ai]
            # take the n_projects projects with largest AF
        )[:n_projects].map(
            # add the project in the callstats struct
            lambda x: x[1].annotate(
                AC=x[1].AC[ai],
                AF=x[1].AF[ai],
                AN=x[1].AN,
                homozygote_count=x[1].homozygote_count[ai],
                project=x[0],
            ))),
    )
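Because the expression aggregates over entries, it belongs in an annotate_rows call; a sketch, assuming a MatrixTable with a `project` column field (field names illustrative):

mt = mt.annotate_rows(
    project_max=project_max_expr(mt.project, mt.GT, mt.alleles, n_projects=5))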
Example No. 28
def test_ndarray_eval():
    data_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    mishapen_data_list1 = [[4], [1, 2, 3]]
    mishapen_data_list2 = [[[1], [2, 3]]]
    mishapen_data_list3 = [[4], [1, 2, 3], 5]

    nd_expr = hl.nd.array(data_list)
    evaled = hl.eval(nd_expr)
    np_equiv = np.array(data_list, dtype=np.int32)
    np_equiv_fortran_style = np.asfortranarray(np_equiv)
    np_equiv_extra_dimension = np_equiv.reshape((3, 1, 3))
    assert (np.array_equal(evaled, np_equiv))
    assert (evaled.strides == np_equiv.strides)

    assert hl.eval(hl.nd.array([[], []])).strides == (8, 8)
    assert np.array_equal(hl.eval(hl.nd.array([])), np.array([]))

    zero_array = np.zeros((10, 10), dtype=np.int64)
    evaled_zero_array = hl.eval(hl.literal(zero_array))

    assert np.array_equal(evaled_zero_array, zero_array)
    assert zero_array.dtype == evaled_zero_array.dtype

    # Testing correct interpretation of numpy strides
    assert np.array_equal(hl.eval(hl.literal(np_equiv_fortran_style)),
                          np_equiv_fortran_style)
    assert np.array_equal(hl.eval(hl.literal(np_equiv_extra_dimension)),
                          np_equiv_extra_dimension)

    # Testing from hail arrays
    assert np.array_equal(hl.eval(hl.nd.array(hl.range(6))), np.arange(6))
    assert np.array_equal(hl.eval(hl.nd.array(hl.int64(4))), np.array(4))

    # Testing from nested hail arrays
    assert np.array_equal(
        hl.eval(hl.nd.array(hl.array([hl.array(x) for x in data_list]))),
        np.arange(9).reshape((3, 3)) + 1)

    # Testing missing data
    assert hl.eval(hl.nd.array(hl.null(hl.tarray(hl.tint32)))) is None

    with pytest.raises(ValueError) as exc:
        hl.nd.array(mishapen_data_list1)
    assert "inner dimensions do not match" in str(exc.value)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.nd.array(hl.array(mishapen_data_list1)))
    assert "inner dimensions do not match" in str(exc.value)

    with pytest.raises(FatalError) as exc:
        hl.eval(hl.nd.array(hl.array(mishapen_data_list2)))
    assert "inner dimensions do not match" in str(exc.value)

    with pytest.raises(ValueError) as exc:
        hl.nd.array(mishapen_data_list3)
    assert "inner dimensions do not match" in str(exc.value)
Example No. 29
def pc_relate_big():
    mt = hl.balding_nichols_model(3, 2 * 4096, 2 * 4096).checkpoint(
        hl.utils.new_temp_file(extension='mt'))
    mt = mt.annotate_cols(scores=hl.range(2).map(lambda x: hl.rand_unif(0, 1)))
    rel = hl.pc_relate(mt.GT,
                       0.05,
                       scores_expr=mt.scores,
                       statistics='kin',
                       min_kinship=0.05)
    rel._force_count()
Example No. 30
def transform_one(mt, info_to_keep=None) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    There is a strong assumption that this function will be called on a matrix
    table with one column.
    """
    if not info_to_keep:
        info_to_keep = [name for name in mt.info if name not in ['END', 'DP']]
    mt = localize(mt)

    if mt.row.dtype not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e:
                        hl.struct(
                            DP=e.DP,
                            END=row.info.END,
                            GQ=e.GQ,
                            LA=hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                            LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                            LGT=e.GT,
                            LPGT=e.PGT,
                            LPL=hl.cond(has_non_ref,
                                        hl.cond(alleles_len > 2,
                                                e.PL[:-alleles_len],
                                                hl.null(e.PL.dtype)),
                                        hl.cond(alleles_len > 1,
                                                e.PL,
                                                hl.null(e.PL.dtype))),
                            MIN_DP=e.MIN_DP,
                            PID=e.PID,
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                            SB=e.SB,
                            gvcf_info=hl.case()
                                .when(hl.is_missing(row.info.END),
                                      hl.struct(**(row.info.select(*info_to_keep))))
                                .or_missing()
                        ))),
            ),
            mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(TableMapRows(mt._tir, Apply(transform_row._name, transform_row._ret_type, TopLevelReference('row'))))
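A usage sketch grounded in the docstring's requirements (path illustrative): import a single-sample gvcf with array_elements_required=False, then transform it.

mt = hl.import_vcf('gs://bucket/sample.g.vcf.bgz',
                   reference_genome='GRCh38', array_elements_required=False)
ht = transform_one(mt)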
Example No. 31
def gt_to_gp(mt, location: str = 'GP'):
    return mt.annotate_entries(
        **{
            location:
            hl.or_missing(
                hl.is_defined(mt.GT),
                hl.map(
                    lambda i: hl.cond(mt.GT.unphased_diploid_gt_index() == i,
                                      1.0, 0.0),
                    hl.range(0, hl.triangle(hl.len(mt.alleles)))))
        })
Example No. 32
def resume_mwzj(temp_dir, globals_for_col_key):
    r'''
    For resuming multiway zip join if intermediate tables have already been written
    '''
    ls = hl.hadoop_ls(temp_dir)
    paths = [x['path'] for x in ls if 'temp_output' in x['path']]
    chunk_size = len(paths)
    outer_hts = []
    for i in range(chunk_size):
        outer_hts.append(hl.read_table(f'{temp_dir}/temp_output_{i}.ht'))
    ht = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer', 'global_field_name_outer')
    ht = ht.transmute(inner_row=hl.flatmap(lambda i:
                                           hl.cond(hl.is_missing(ht.row_field_name_outer[i].row_field_name),
                                                   hl.range(0, hl.len(ht.global_field_name_outer[i].global_field_name))
                                                   .map(lambda _: hl.null(ht.row_field_name_outer[i].row_field_name.dtype.element_type)),
                                                   ht.row_field_name_outer[i].row_field_name),
                                           hl.range(hl.len(ht.global_field_name_outer))))
    ht = ht.transmute_globals(inner_global=hl.flatmap(lambda x: x.global_field_name, ht.global_field_name_outer))
    mt = ht._unlocalize_entries('inner_row', 'inner_global', globals_for_col_key)
    return mt
Example No. 33
        def with_local_a_index(local_a_index):
            new_pl = hl.or_missing(
                hl.is_defined(old_entry.LPL),
                hl.or_missing(
                    hl.is_defined(local_a_index),
                    hl.range(0, 3).map(lambda i: hl.min(
                        hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                            .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), local_a_index) == hl.unphased_diploid_gt_index_call(i))
                            .map(lambda idx: old_entry.LPL[idx])))))
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                                   old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                                   old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)
Example No. 34
def densify(sparse_mt):
    """Convert sparse MatrixTable to a dense one.

    Parameters
    ----------
    sparse_mt : :class:`.MatrixTable`
        Sparse MatrixTable to densify.  The first row key field must
        be named ``locus`` and have type ``locus``.  Must have an
        ``END`` entry field of type ``int32``.

    Returns
    -------
    :class:`.MatrixTable`
        The densified MatrixTable.  The ``END`` entry field is dropped.

    """
    if list(sparse_mt.row_key)[0] != 'locus' or not isinstance(sparse_mt.locus.dtype, hl.tlocus):
        raise ValueError("first row key field must be named 'locus' and have type 'locus'")
    if 'END' not in sparse_mt.entry or sparse_mt.END.dtype != hl.tint32:
        raise ValueError("'densify' requires 'END' entry field of type 'int32'")
    col_key_fields = list(sparse_mt.col_key)

    mt = sparse_mt
    mt = sparse_mt.annotate_entries(__contig = mt.locus.contig)
    t = mt._localize_entries('__entries', '__cols')
    t = t.annotate(
        __entries = hl.rbind(
            hl.scan.array_agg(
                lambda entry: hl.scan._prev_nonnull(hl.or_missing(hl.is_defined(entry.END), entry)),
                t.__entries),
            lambda prev_entries: hl.map(
                lambda i:
                hl.rbind(
                    prev_entries[i], t.__entries[i],
                    lambda prev_entry, entry:
                    hl.cond(
                        (~hl.is_defined(entry) &
                         (prev_entry.END >= t.locus.position) &
                         (prev_entry.__contig == t.locus.contig)),
                        prev_entry,
                        entry)),
                hl.range(0, hl.len(t.__entries)))))
    mt = t._unlocalize_entries('__entries', '__cols', col_key_fields)
    mt = mt.drop('__contig', 'END')
    return mt
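A usage sketch (path illustrative): read a sparse MatrixTable produced by the VCF combiner and densify it; reference blocks are expanded and END is dropped.

sparse = hl.read_matrix_table('gs://bucket/combined_sparse.mt')
dense = densify(sparse)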
Example No. 35
def transform_one(mt: MatrixTable) -> MatrixTable:
    """transforms a gvcf into a form suitable for combining"""
    mt = mt.annotate_entries(
        # local (alt) allele index into global (alt) alleles
        LA=hl.range(0, hl.len(mt.alleles) - 1),
        END=mt.info.END,
        PL=mt['PL'][0:],
        BaseQRankSum=mt.info['BaseQRankSum'],
        ClippingRankSum=mt.info['ClippingRankSum'],
        MQ=mt.info['MQ'],
        MQRankSum=mt.info['MQRankSum'],
        ReadPosRankSum=mt.info['ReadPosRankSum'],
    )
    # This collects all fields with median combiners into arrays so we can calculate medians
    # when needed
    mt = mt.annotate_rows(
        # now minrep'ed (ref, alt) allele pairs
        alleles=hl.bind(lambda ref: mt.alleles[1:].map(lambda alt:
                                                       # minrep <NON_REF>
                                                       hl.struct(ref=hl.cond(alt == "<NON_REF>",
                                                                             ref[0:1],
                                                                             ref),
                                                                 alt=alt)),
                        mt.alleles[0]),
        info=mt.info.annotate(
            SB=hl.agg.array_sum(mt.entry.SB)
        ).select(
            "DP",
            "MQ_DP",
            "QUALapprox",
            "RAW_MQ",
            "VarDP",
            "SB",
        ))
    mt = mt.drop('SB', 'qual')

    return mt
Example No. 36
def sparse_split_multi(sparse_mt):
    """Splits multiallelic variants on a sparse MatrixTable.

    Takes a dataset formatted like the output of :func:`.vcf_combiner`. The
    splitting will add `was_split` and `a_index` fields, as :func:`.split_multi`
    does. This function drops the `LA` (local alleles) field, as it re-computes
    entry fields based on the new, split globals alleles.

    Variants are split thus:

    - A row with only one (reference) or two (reference and alternate) alleles
      is left unchanged.

    - A row with multiple alternate alleles will be split, with one row for
      each alternate allele, and each row will contain two alleles: ref and alt.
      The reference and alternate allele will be minrepped using
      :func:`.min_rep`.

    The split multi logic handles the following entry fields:

        .. code-block:: text

          struct {
            LGT: call
            LAD: array<int32>
            DP: int32
            GQ: int32
            LPL: array<int32>
            RGQ: int32
            LPGT: call
            LA: array<int32>
            END: int32
          }

    All fields except for `LA` are optional, and only handled if they exist.

    - `LA` is used to find the corresponding local allele index for the desired
      global `a_index`, and then dropped from the resulting dataset. If `LA`
      does not contain the global `a_index`, the index for the `<NON_REF>`
      allele is used to process the entry fields.

    - `LGT` and `LPGT` are downcoded using the corresponding local `a_index`.
      They are renamed to `GT` and `PGT` respectively, as the resulting call is
      no longer local.

    - `LAD` is used to create an `AD` field consisting of the allele depths
      corresponding to the reference, global `a_index` allele, and `<NON_REF>`
      allele.

    - `DP` is preserved unchanged.

    - `GQ` is recalculated from the updated `PL`, if it exists, but otherwise
      preserved unchanged.

    - `PL` array elements are calculated from the minimum `LPL` value for all
      allele pairs that downcode to the desired one. (This logic is identical to
      the `PL` logic in :func:`.split_multi_hts`; if a row has an alternate
      allele but it is not present in `LA`, the `PL` field is set to missing.
      The `PL` for `ref/<NON_REF>` in that case can be drawn from `RGQ`.)

    - `RGQ` (the ref genotype quality) is preserved unchanged.

    - `END` is untouched.

    Notes
    -----
    This version of split-multi doesn't deal with duplicate loci (in
    which case the explode could possibly result in out-of-order rows, although
    the actual split_multi function also doesn't handle that case).

    It also checks that min-repping will not change the locus and will error if
    it does. (I believe the VCF combiner checks that this holds true,
    currently.)

    Parameters
    ----------
    sparse_mt : :class:`.MatrixTable`
        Sparse MatrixTable to split.

    Returns
    -------
    :class:`.MatrixTable`
        The split MatrixTable in sparse format.

    """

    hl.methods.misc.require_row_key_variant(sparse_mt, "sparse_split_multi")

    entries = hl.utils.java.Env.get_uid()
    cols = hl.utils.java.Env.get_uid()
    ds = sparse_mt.localize_entries(entries, cols)
    new_id = hl.utils.java.Env.get_uid()

    def struct_from_min_rep(i):
        return hl.bind(lambda mr:
                       (hl.case()
                        .when(ds.locus == mr.locus,
                              hl.struct(
                                  locus=ds.locus,
                                  alleles=[mr.alleles[0], mr.alleles[1]],
                                  a_index=i,
                                  was_split=True))
                        .or_error("Found non-left-aligned variant in sparse_split_multi")),
                       hl.min_rep(ds.locus, [ds.alleles[0], ds.alleles[i]]))

    explode_structs = hl.cond(hl.len(ds.alleles) < 3,
                              [hl.struct(
                                  locus=ds.locus,
                                  alleles=ds.alleles,
                                  a_index=1,
                                  was_split=False)],
                              hl._sort_by(
                                  hl.range(1, hl.len(ds.alleles))
                                      .map(struct_from_min_rep),
                                  lambda l, r: hl._compare(l.alleles, r.alleles) < 0
                              ))

    ds = ds.annotate(**{new_id: explode_structs}).explode(new_id)

    def transform_entries(old_entry):
        def with_local_a_index(local_a_index):
            new_pl = hl.or_missing(
                hl.is_defined(old_entry.LPL),
                hl.or_missing(
                    hl.is_defined(local_a_index),
                    hl.range(0, 3).map(lambda i: hl.min(
                        hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                            .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), local_a_index) == hl.unphased_diploid_gt_index_call(i))
                            .map(lambda idx: old_entry.LPL[idx])))))
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                                   old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                                   old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)

        lai = hl.fold(lambda accum, elt:
                        hl.cond(old_entry.LA[elt] == ds[new_id].a_index,
                                elt, accum),
                        hl.null(hl.tint32),
                        hl.range(0, hl.len(old_entry.LA)))
        return hl.bind(with_local_a_index, lai)

    new_row = ds.row.annotate(**{
        'locus': ds[new_id].locus,
        'alleles': ds[new_id].alleles,
        'a_index': ds[new_id].a_index,
        'was_split': ds[new_id].was_split,
        entries: ds[entries].map(transform_entries)
    }).drop(new_id)

    ds = hl.Table(
        hl.ir.TableKeyBy(
            hl.ir.TableMapRows(
                hl.ir.TableKeyBy(ds._tir, ['locus']),
                new_row._ir),
            ['locus', 'alleles'],
            is_sorted=True))
    return ds._unlocalize_entries(entries, cols, list(sparse_mt.col_key.keys()))
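The `PL` downcoding rule referenced in the field list above can be sketched in
plain Python. This is an illustrative re-implementation under stated
assumptions, not Hail's own code: `gt_from_index` and `downcode_pl` are
hypothetical names, and the standard VCF linear ordering of unphased diploid
genotypes is assumed, where genotype (j, k) with j <= k has index
k * (k + 1) // 2 + j.

def gt_from_index(idx):
    # Invert the triangular genotype indexing: recover (j, k) with j <= k.
    k = 0
    while (k + 1) * (k + 2) // 2 <= idx:
        k += 1
    return idx - k * (k + 1) // 2, k

def downcode_pl(lpl, local_a_index):
    # Collapse a multiallelic local PL array to biallelic
    # [ref/ref, ref/alt, alt/alt] PLs by taking, for each target genotype,
    # the minimum PL over all genotypes that downcode to it.
    new_pl = [None, None, None]
    for idx, pl in enumerate(lpl):
        j, k = gt_from_index(idx)
        # Alleles equal to the kept local index downcode to "alt" (1);
        # every other allele downcodes to "ref" (0).
        n_alt = (j == local_a_index) + (k == local_a_index)
        if new_pl[n_alt] is None or pl < new_pl[n_alt]:
            new_pl[n_alt] = pl
    return new_pl

# Three local alleles, PLs ordered 0/0, 0/1, 1/1, 0/2, 1/2, 2/2:
assert downcode_pl([0, 30, 300, 40, 50, 400], local_a_index=1) == [0, 30, 300]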
Example #37
def ld_score_regression(weight_expr,
                        ld_score_expr,
                        chi_sq_exprs,
                        n_samples_exprs,
                        n_blocks=200,
                        two_step_threshold=30,
                        n_reference_panel_variants=None) -> Table:
    r"""Estimate SNP-heritability and level of confounding biases from
    GWAS summary statistics.

    Given a set or multiple sets of genome-wide association study (GWAS)
    summary statistics, :func:`.ld_score_regression` estimates the heritability
    of a trait or set of traits and the level of confounding biases present in
    the underlying studies by regressing chi-squared statistics on LD scores,
    leveraging the model:

    .. math::

        \mathrm{E}[\chi_j^2] = 1 + Na + \frac{Nh_g^2}{M}l_j

    *  :math:`\mathrm{E}[\chi_j^2]` is the expected chi-squared statistic
       for variant :math:`j` resulting from a test of association between
       variant :math:`j` and a trait.
    *  :math:`l_j = \sum_{k} r_{jk}^2` is the LD score of variant
       :math:`j`, calculated as the sum of squared correlation coefficients
       between variant :math:`j` and nearby variants. See :func:`ld_score`
       for further details.
    *  :math:`a` captures the contribution of confounding biases, such as
       cryptic relatedness and uncontrolled population structure, to the
       association test statistic.
    *  :math:`h_g^2` is the SNP-heritability, or the proportion of variation
       in the trait explained by the effects of variants included in the
       regression model above.
    *  :math:`M` is the number of variants used to estimate :math:`h_g^2`.
    *  :math:`N` is the number of samples in the underlying association study.

    For more details on the method implemented in this function, see:

    * `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__

    Examples
    --------

    Run the method on a matrix table of summary statistics, where the rows
    are variants and the columns are different phenotypes:

    >>> mt_gwas = hl.read_matrix_table('data/ld_score_regression.sumstats.mt')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=mt_gwas['ld_score'],
    ...     ld_score_expr=mt_gwas['ld_score'],
    ...     chi_sq_exprs=mt_gwas['chi_squared'],
    ...     n_samples_exprs=mt_gwas['n'])


    Run the method on a table with summary statistics for a single
    phenotype:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=ht_gwas['chi_squared_50_irnt'],
    ...     n_samples_exprs=ht_gwas['n_50_irnt'])

    Run the method on a table with summary statistics for multiple
    phenotypes:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=[ht_gwas['chi_squared_50_irnt'],
    ...                        ht_gwas['chi_squared_20160']],
    ...     n_samples_exprs=[ht_gwas['n_50_irnt'],
    ...                      ht_gwas['n_20160']])

    Notes
    -----
    The ``exprs`` provided as arguments to :func:`.ld_score_regression`
    must all be from the same object, either a :class:`Table` or a
    :class:`MatrixTable`.

    **If the arguments originate from a table:**

    *  The table must be keyed by fields ``locus`` of type
       :class:`.tlocus` and ``alleles``, a :py:data:`.tarray` of
       :py:data:`.tstr` elements.
    *  ``weight_expr``, ``ld_score_expr``, ``chi_sq_exprs``, and
       ``n_samples_exprs`` must be row-indexed fields.
    *  The number of expressions passed to ``n_samples_exprs`` must be
       equal to one or the number of expressions passed to
       ``chi_sq_exprs``. If just one expression is passed to
       ``n_samples_exprs``, that sample size expression is assumed to
       apply to all sets of statistics passed to ``chi_sq_exprs``.
       Otherwise, the expressions passed to ``chi_sq_exprs`` and
       ``n_samples_exprs`` are matched by index.
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have generic :obj:`int` values
       ``0``, ``1``, etc. corresponding to the ``0th``, ``1st``, etc.
       expressions passed to the ``chi_sq_exprs`` argument.

    **If the arguments originate from a matrix table:**

    *  The dimensions of the matrix table must be variants
       (rows) by phenotypes (columns).
    *  The rows of the matrix table must be keyed by fields
       ``locus`` of type :class:`.tlocus` and ``alleles``,
       a :py:data:`.tarray` of :py:data:`.tstr` elements.
    *  The columns of the matrix table must be keyed by a field
       of type :py:data:`.tstr` that uniquely identifies phenotypes
       represented in the matrix table. The column key must be a single
       expression; compound keys are not accepted.
    *  ``weight_expr`` and ``ld_score_expr`` must be row-indexed
       fields.
    *  ``chi_sq_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  ``n_samples_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have values corresponding to the
       column keys of the input matrix table.

    This function returns a :class:`Table` with one row per set of summary
    statistics passed to the ``chi_sq_exprs`` argument. The following
    row-indexed fields are included in the table:

    *  **phenotype** (:py:data:`.tstr`) -- The name of the phenotype. The
       returned table is keyed by this field. See the notes below for
       details on the possible values of this field.
    *  **mean_chi_sq** (:py:data:`.tfloat64`) -- The mean chi-squared
       test statistic for the given phenotype.
    *  **intercept** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          intercept :math:`1 + Na`.
       -  **standard_error**  (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    *  **snp_heritability** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          SNP-heritability :math:`h_g^2`.
       -  **standard_error** (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    Warning
    -------
    :func:`.ld_score_regression` considers only the rows for which both row
    fields ``weight_expr`` and ``ld_score_expr`` are defined. Rows with missing
    values in either field are removed prior to fitting the LD score
    regression model.

    Parameters
    ----------
    weight_expr : :class:`.Float64Expression`
                  Row-indexed expression for the LD scores used to derive
                  variant weights in the model.
    ld_score_expr : :class:`.Float64Expression`
                    Row-indexed expression for the LD scores used as covariates
                    in the model.
    chi_sq_exprs : :class:`.Float64Expression` or :obj:`list` of
                        :class:`.Float64Expression`
                        One or more row-indexed (if table) or entry-indexed
                        (if matrix table) expressions for chi-squared
                        statistics resulting from genome-wide association
                        studies.
    n_samples_exprs: :class:`.NumericExpression` or :obj:`list` of
                     :class:`.NumericExpression`
                     One or more row-indexed (if table) or entry-indexed
                     (if matrix table) expressions indicating the number of
                     samples used in the studies that generated the test
                     statistics supplied to ``chi_sq_exprs``.
    n_blocks : :obj:`int`
               The number of blocks used in the jackknife approach to
               estimating standard errors.
    two_step_threshold : :obj:`int`
                         Variants with chi-squared statistics greater than this
                         value are excluded in the first step of the two-step
                         procedure used to fit the model.
    n_reference_panel_variants : :obj:`int`, optional
                                 Number of variants used to estimate the
                                 SNP-heritability :math:`h_g^2`.

    Returns
    -------
    :class:`.Table`
        Table keyed by ``phenotype`` with intercept and heritability estimates
        for each phenotype passed to the function."""

    chi_sq_exprs = wrap_to_list(chi_sq_exprs)
    n_samples_exprs = wrap_to_list(n_samples_exprs)

    assert ((len(chi_sq_exprs) == len(n_samples_exprs)) or
            (len(n_samples_exprs) == 1))
    __k = 2  # number of covariates, including intercept

    ds = chi_sq_exprs[0]._indices.source

    analyze('ld_score_regression/weight_expr',
            weight_expr,
            ds._row_indices)
    analyze('ld_score_regression/ld_score_expr',
            ld_score_expr,
            ds._row_indices)

    # format input dataset
    if isinstance(ds, MatrixTable):
        if len(chi_sq_exprs) != 1:
            raise ValueError('Only one chi_sq_expr allowed if originating '
                             'from a matrix table.')
        if len(n_samples_exprs) != 1:
            raise ValueError('Only one n_samples_expr allowed if '
                             'originating from a matrix table.')

        col_key = list(ds.col_key)
        if len(col_key) != 1:
            raise ValueError('Matrix table must be keyed by a single '
                             'phenotype field.')

        analyze('ld_score_regression/chi_squared_expr',
                chi_sq_exprs[0],
                ds._entry_indices)
        analyze('ld_score_regression/n_samples_expr',
                n_samples_exprs[0],
                ds._entry_indices)

        ds = ds._select_all(row_exprs={'__locus': ds.locus,
                                       '__alleles': ds.alleles,
                                       '__w_initial': weight_expr,
                                       '__w_initial_floor': hl.max(weight_expr,
                                                                   1.0),
                                       '__x': ld_score_expr,
                                       '__x_floor': hl.max(ld_score_expr,
                                                           1.0)},
                            row_key=['__locus', '__alleles'],
                            col_exprs={'__y_name': ds[col_key[0]]},
                            col_key=['__y_name'],
                            entry_exprs={'__y': chi_sq_exprs[0],
                                         '__n': n_samples_exprs[0]})
        ds = ds.annotate_entries(**{'__w': ds.__w_initial})

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    else:
        assert isinstance(ds, Table)
        for y in chi_sq_exprs:
            analyze('ld_score_regression/chi_squared_expr', y, ds._row_indices)
        for n in n_samples_exprs:
            analyze('ld_score_regression/n_samples_expr', n, ds._row_indices)

        ys = ['__y{}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ws = ['__w{}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ns = ['__n{}'.format(i) for i, _ in enumerate(n_samples_exprs)]

        ds = ds.select(**dict(**{'__locus': ds.locus,
                                 '__alleles': ds.alleles,
                                 '__w_initial': weight_expr,
                                 '__x': ld_score_expr},
                              **{y: chi_sq_exprs[i]
                                 for i, y in enumerate(ys)},
                              **{w: weight_expr for w in ws},
                              **{n: n_samples_exprs[i]
                                 for i, n in enumerate(ns)}))
        ds = ds.key_by(ds.__locus, ds.__alleles)

        table_tmp_file = new_temp_file()
        ds.write(table_tmp_file)
        ds = hl.read_table(table_tmp_file)

        hts = [ds.select(**{'__w_initial': ds.__w_initial,
                            '__w_initial_floor': hl.max(ds.__w_initial,
                                                        1.0),
                            '__x': ds.__x,
                            '__x_floor': hl.max(ds.__x, 1.0),
                            '__y_name': i,
                            '__y': ds[ys[i]],
                            '__w': ds[ws[i]],
                            '__n': hl.int(ds[ns[i]])})
               for i, y in enumerate(ys)]

        mts = [ht.to_matrix_table(row_key=['__locus',
                                           '__alleles'],
                                  col_key=['__y_name'],
                                  row_fields=['__w_initial',
                                              '__w_initial_floor',
                                              '__x',
                                              '__x_floor'])
               for ht in hts]

        ds = mts[0]
        for i in range(1, len(ys)):
            ds = ds.union_cols(mts[i])

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    mt_tmp_file1 = new_temp_file()
    ds.write(mt_tmp_file1)
    mt = hl.read_matrix_table(mt_tmp_file1)

    if not n_reference_panel_variants:
        M = mt.count_rows()
    else:
        M = n_reference_panel_variants

    # block variants for each phenotype
    n_phenotypes = mt.count_cols()

    mt = mt.annotate_entries(__in_step1=(hl.is_defined(mt.__y) &
                                         (mt.__y < two_step_threshold)),
                             __in_step2=hl.is_defined(mt.__y))

    mt = mt.annotate_cols(__col_idx=hl.int(hl.scan.count()),
                          __m_step1=hl.agg.count_where(mt.__in_step1),
                          __m_step2=hl.agg.count_where(mt.__in_step2))

    col_keys = list(mt.col_key)

    ht = mt.localize_entries(entries_array_field_name='__entries',
                             columns_array_field_name='__cols')

    ht = ht.annotate(__entries=hl.rbind(
        hl.scan.array_agg(
            lambda entry: hl.scan.count_where(entry.__in_step1),
            ht.__entries),
        lambda step1_indices: hl.map(
            lambda i: hl.rbind(
                hl.int(hl.or_else(step1_indices[i], 0)),
                ht.__cols[i].__m_step1,
                ht.__entries[i],
                lambda step1_idx, m_step1, entry: hl.rbind(
                    hl.map(
                        lambda j: hl.int(hl.floor(j * (m_step1 / n_blocks))),
                        hl.range(0, n_blocks + 1)),
                    lambda step1_separators: hl.rbind(
                        hl.set(step1_separators).contains(step1_idx),
                        hl.sum(
                            hl.map(
                                lambda s1: step1_idx >= s1,
                                step1_separators)) - 1,
                        lambda is_separator, step1_block: entry.annotate(
                            __step1_block=step1_block,
                            __step2_block=hl.cond(~entry.__in_step1 & is_separator,
                                                  step1_block - 1,
                                                  step1_block))))),
            hl.range(0, hl.len(ht.__entries)))))

    mt = ht._unlocalize_entries('__entries', '__cols', col_keys)

    mt_tmp_file2 = new_temp_file()
    mt.write(mt_tmp_file2)
    mt = hl.read_matrix_table(mt_tmp_file2)
    
    # initial coefficient estimates
    mt = mt.annotate_cols(__initial_betas=[
        1.0, (hl.agg.mean(mt.__y) - 1.0) / hl.agg.mean(mt.__x)])
    mt = mt.annotate_cols(__step1_betas=mt.__initial_betas,
                          __step2_betas=mt.__initial_betas)

    # step 1 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step1,
            1.0/(mt.__w_initial_floor * 2.0 * (mt.__step1_betas[0] +
                                               mt.__step1_betas[1] *
                                               mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step1_betas=hl.agg.filter(
            mt.__in_step1,
            hl.agg.linreg(y=mt.__y,
                          x=[1.0, mt.__x],
                          weight=mt.__w).beta))
        mt = mt.annotate_cols(__step1_h2=hl.max(hl.min(
            mt.__step1_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step1_betas=[
            mt.__step1_betas[0],
            mt.__step1_h2 * hl.agg.mean(mt.__n) / M])

    # step 1 block jackknife
    mt = mt.annotate_cols(__step1_block_betas=[
        hl.agg.filter((mt.__step1_block != i) & mt.__in_step1,
                      hl.agg.linreg(y=mt.__y,
                                    x=[1.0, mt.__x],
                                    weight=mt.__w).beta)
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step1_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step1_betas - (n_blocks - 1) * x,
        mt.__step1_block_betas))

    mt = mt.annotate_cols(
        __step1_jackknife_mean=hl.map(
            lambda i: hl.mean(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected)),
            hl.range(0, __k)),
        __step1_jackknife_variance=hl.map(
            lambda i: (hl.sum(
                hl.map(lambda x: x[i]**2,
                       mt.__step1_block_betas_bias_corrected)) -
                       hl.sum(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected))**2 /
                       n_blocks) /
            (n_blocks - 1) / n_blocks,
            hl.range(0, __k)))

    # step 2 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step2,
            1.0/(mt.__w_initial_floor *
                 2.0 * (mt.__step2_betas[0] +
                        mt.__step2_betas[1] *
                        mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            hl.agg.filter(mt.__in_step2,
                          hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                        x=[mt.__x],
                                        weight=mt.__w).beta[0])])
        mt = mt.annotate_cols(__step2_h2=hl.max(hl.min(
            mt.__step2_betas[1] * M/hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            mt.__step2_h2 * hl.agg.mean(mt.__n)/M])

    # step 2 block jackknife
    mt = mt.annotate_cols(__step2_block_betas=[
        hl.agg.filter((mt.__step2_block != i) & mt.__in_step2,
                      hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                    x=[mt.__x],
                                    weight=mt.__w).beta[0])
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step2_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step2_betas[1] - (n_blocks - 1) * x,
        mt.__step2_block_betas))

    mt = mt.annotate_cols(
        __step2_jackknife_mean=hl.mean(
            mt.__step2_block_betas_bias_corrected),
        __step2_jackknife_variance=(
            hl.sum(mt.__step2_block_betas_bias_corrected**2) -
            hl.sum(mt.__step2_block_betas_bias_corrected)**2 /
            n_blocks) / (n_blocks - 1) / n_blocks)

    # combine step 1 and step 2 block jackknives
    mt = mt.annotate_entries(
        __step2_initial_w=1.0/(mt.__w_initial_floor *
                               2.0 * (mt.__initial_betas[0] +
                                      mt.__initial_betas[1] *
                                      mt.__x_floor)**2))

    mt = mt.annotate_cols(
        __final_betas=[
            mt.__step1_betas[0],
            mt.__step2_betas[1]],
        __c=(hl.agg.sum(mt.__step2_initial_w * mt.__x) /
             hl.agg.sum(mt.__step2_initial_w * mt.__x**2)))

    mt = mt.annotate_cols(__final_block_betas=hl.map(
        lambda i: (mt.__step2_block_betas[i] - mt.__c *
                   (mt.__step1_block_betas[i][0] - mt.__final_betas[0])),
        hl.range(0, n_blocks)))

    mt = mt.annotate_cols(
        __final_block_betas_bias_corrected=(n_blocks * mt.__final_betas[1] -
                                            (n_blocks - 1) *
                                            mt.__final_block_betas))

    mt = mt.annotate_cols(
        __final_jackknife_mean=[
            mt.__step1_jackknife_mean[0],
            hl.mean(mt.__final_block_betas_bias_corrected)],
        __final_jackknife_variance=[
            mt.__step1_jackknife_variance[0],
            (hl.sum(mt.__final_block_betas_bias_corrected**2) -
             hl.sum(mt.__final_block_betas_bias_corrected)**2 /
             n_blocks) / (n_blocks - 1) / n_blocks])

    # convert coefficient to heritability estimate
    mt = mt.annotate_cols(
        phenotype=mt.__y_name,
        mean_chi_sq=hl.agg.mean(mt.__y),
        intercept=hl.struct(
            estimate=mt.__final_betas[0],
            standard_error=hl.sqrt(mt.__final_jackknife_variance[0])),
        snp_heritability=hl.struct(
            estimate=(M/hl.agg.mean(mt.__n)) * mt.__final_betas[1],
            standard_error=hl.sqrt((M/hl.agg.mean(mt.__n))**2 *
                                   mt.__final_jackknife_variance[1])))

    # format and return results
    ht = mt.cols()
    ht = ht.key_by(ht.phenotype)
    ht = ht.select(ht.mean_chi_sq,
                   ht.intercept,
                   ht.snp_heritability)

    ht_tmp_file = new_temp_file()
    ht.write(ht_tmp_file)
    ht = hl.read_table(ht_tmp_file)
    
    return ht
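As an aside, the two numerical ingredients of the estimator above can be
sketched outside Hail with NumPy. This is a simplified illustration under
stated assumptions, not the function's actual implementation: `weighted_fit`
and `ldsc_sketch` are hypothetical names, the weights are not iteratively
refined, and the two-step intercept handling is omitted. It shows (1) the
weighted regression of chi-squared statistics on LD scores, whose slope
converts to SNP-heritability via h2 = slope * M / mean(N), and (2) the
block-jackknife variance, computed from pseudo-values
n_blocks * full_estimate - (n_blocks - 1) * leave-one-block-out estimate.

import numpy as np

def weighted_fit(x, y, w):
    # Weighted least squares of y on [1, x]; returns [intercept, slope].
    X = np.column_stack([np.ones_like(x), x])
    XtW = X.T * w  # scale each observation by its weight
    return np.linalg.solve(XtW @ X, XtW @ y)

def ldsc_sketch(ld_scores, chi_sq, n_samples, M, n_blocks=20):
    x, y = np.asarray(ld_scores, float), np.asarray(chi_sq, float)
    w = 1.0 / np.maximum(x, 1.0)  # crude one-shot weights
    full = weighted_fit(x, y, w)
    blocks = np.array_split(np.arange(len(x)), n_blocks)
    loo = np.array([weighted_fit(np.delete(x, b), np.delete(y, b),
                                 np.delete(w, b))[1] for b in blocks])
    pseudo = n_blocks * full[1] - (n_blocks - 1) * loo
    var = ((pseudo ** 2).sum() - pseudo.sum() ** 2 / n_blocks) \
        / ((n_blocks - 1) * n_blocks)
    scale = M / np.mean(n_samples)
    return {'intercept': full[0], 'h2': scale * full[1],
            'h2_se': scale * np.sqrt(var)}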
Example #38
File: qc.py Project: tpoterba/hail
def summarize_variants(mt: MatrixTable, show=True):
    """Summarize the variants present in a dataset and print the results.

    Examples
    --------
    >>> hl.summarize_variants(dataset)  # doctest: +SKIP
    ==============================
    Number of variants: 346
    ==============================
    Alleles per variant
    -------------------
      2 alleles: 346 variants
    ==============================
    Variants per contig
    -------------------
      20: 346 variants
    ==============================
    Allele type distribution
    ------------------------
            SNP: 301 alleles
       Deletion: 27 alleles
      Insertion: 18 alleles
    ==============================

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Matrix table with a variant (locus / alleles) row key.
    show : :obj:`bool`
        If ``True``, print results instead of returning them.

    Notes
    -----
    The result returned if `show` is ``False`` is a :class:`.Struct` with
    four fields:

    - `n_variants` (:obj:`int`): Number of variants present in the matrix table.
    - `allele_types` (:obj:`Dict[str, int]`): Number of alternate alleles in
      each allele category.
    - `contigs` (:obj:`Dict[str, int]`): Number of variants on each contig.
    - `allele_counts` (:obj:`Dict[int, int]`): Number of variants broken down
      by number of alleles (biallelic is 2, for example).

    Returns
    -------
    :obj:`None` or :class:`.Struct`
        Returns ``None`` if `show` is ``True``, or returns results as a struct.
    """
    require_row_key_variant(mt, 'summarize_variants')
    alleles_per_variant = hl.range(1, hl.len(mt.alleles)).map(lambda i: hl.allele_type(mt.alleles[0], mt.alleles[i]))
    allele_types, contigs, allele_counts, n_variants = mt.aggregate_rows(
        (hl.agg.explode(lambda elt: hl.agg.counter(elt), alleles_per_variant),
         hl.agg.counter(mt.locus.contig),
         hl.agg.counter(hl.len(mt.alleles)),
         hl.agg.count()))
    rg = mt.locus.dtype.reference_genome
    contig_idx = {contig: i for i, contig in enumerate(rg.contigs)}
    if show:
        max_contig_len = max(len(contig) for contig in contigs)
        contig_formatter = f'%{max_contig_len}s'

        max_allele_count_len = max(len(str(x)) for x in allele_counts)
        allele_count_formatter = f'%{max_allele_count_len}s'

        max_allele_type_len = max(len(x) for x in allele_types)
        allele_type_formatter = f'%{max_allele_type_len}s'

        line_break = '=============================='

        print(line_break)
        print(f'Number of variants: {n_variants}')
        print(line_break)
        print('Alleles per variant')
        print('-------------------')
        for n_alleles, count in sorted(allele_counts.items(), key=lambda x: x[0]):
            print(f'  {allele_count_formatter % n_alleles} alleles: {count} variants')
        print(line_break)
        print('Variants per contig')
        print('-------------------')
        for contig, count in sorted(contigs.items(), key=lambda x: contig_idx[x[0]]):
            print(f'  {contig_formatter % contig}: {count} variants')
        print(line_break)
        print('Allele type distribution')
        print('------------------------')
        for allele_type, count in Counter(allele_types).most_common():
            print(f'  {allele_type_formatter % allele_type}: {count} alternate alleles')
        print(line_break)
    else:
        return hl.Struct(allele_types=allele_types,
                         contigs=contigs,
                         allele_counts=allele_counts,
                         n_variants=n_variants)
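As a usage note, when `show` is ``False`` the returned struct can be consumed
programmatically. A small sketch (assuming `dataset` is a loaded matrix table
with a variant row key):

result = hl.summarize_variants(dataset, show=False)
print(f'{result.n_variants} variants across {len(result.contigs)} contigs')
print(f'{result.allele_counts.get(2, 0)} biallelic variants')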
Example #39
File: qc.py Project: tpoterba/hail
def sample_qc(mt, name='sample_qc') -> MatrixTable:
    """Compute per-sample metrics useful for quality control.

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    Compute sample QC metrics and remove low-quality samples:

    >>> dataset = hl.sample_qc(dataset, name='sample_qc')
    >>> filtered_dataset = dataset.filter_cols((dataset.sample_qc.dp_stats.mean > 20) & (dataset.sample_qc.r_ti_tv > 1.5))

    Notes
    -----

    This method computes summary statistics per sample from a genetic matrix and stores
    the results as a new column-indexed struct field in the matrix, named based on the
    `name` parameter.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `call_rate` (``float64``) -- Fraction of calls non-missing.
    - `n_called` (``int64``) -- Number of non-missing calls.
    - `n_not_called` (``int64``) -- Number of missing calls.
    - `n_hom_ref` (``int64``) -- Number of homozygous reference calls.
    - `n_het` (``int64``) -- Number of heterozygous calls.
    - `n_hom_var` (``int64``) -- Number of homozygous alternate calls.
    - `n_non_ref` (``int64``) -- Sum of ``n_het`` and ``n_hom_var``.
    - `n_snp` (``int64``) -- Number of SNP alternate alleles.
    - `n_insertion` (``int64``) -- Number of insertion alternate alleles.
    - `n_deletion` (``int64``) -- Number of deletion alternate alleles.
    - `n_singleton` (``int64``) -- Number of private alleles.
    - `n_transition` (``int64``) -- Number of transition (A-G, C-T) alternate alleles.
    - `n_transversion` (``int64``) -- Number of transversion alternate alleles.
    - `n_star` (``int64``) -- Number of star (upstream deletion) alleles.
    - `r_ti_tv` (``float64``) -- Transition/Transversion ratio.
    - `r_het_hom_var` (``float64``) -- Het/HomVar call ratio.
    - `r_insertion_deletion` (``float64``) -- Insertion/Deletion allele ratio.

    Missing values ``NA`` may result from division by zero.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
        Dataset with a new column-indexed field `name`.
    """

    require_row_key_variant(mt, 'sample_qc')

    from hail.expr.functions import _num_allele_type, _allele_types

    allele_types = _allele_types[:]
    allele_types.extend(['Transition', 'Transversion'])
    allele_enum = {i: v for i, v in enumerate(allele_types)}
    allele_ints = {v: k for k, v in allele_enum.items()}

    def allele_type(ref, alt):
        return hl.bind(lambda at: hl.cond(at == allele_ints['SNP'],
                                          hl.cond(hl.is_transition(ref, alt),
                                                  allele_ints['Transition'],
                                                  allele_ints['Transversion']),
                                          at),
                       _num_allele_type(ref, alt))

    variant_ac = Env.get_uid()
    variant_atypes = Env.get_uid()
    mt = mt.annotate_rows(**{variant_ac: hl.agg.call_stats(mt.GT, mt.alleles).AC,
                             variant_atypes: mt.alleles[1:].map(lambda alt: allele_type(mt.alleles[0], alt))})

    exprs = {}

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    if has_field_of_type('DP', hl.tint32):
        exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'sample_qc': expect an entry field 'GT' of type 'call'")

    exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT']))
    exprs['n_hom_ref'] = hl.agg.count_where(mt['GT'].is_hom_ref())
    exprs['n_het'] = hl.agg.count_where(mt['GT'].is_het())
    exprs['n_singleton'] = hl.agg.sum(hl.sum(hl.range(0, mt['GT'].ploidy).map(lambda i: mt[variant_ac][mt['GT'][i]] == 1)))

    def get_allele_type(allele_idx):
        return hl.cond(allele_idx > 0, mt[variant_atypes][allele_idx - 1], hl.null(hl.tint32))

    exprs['allele_type_counts'] = hl.agg.explode(
        lambda elt: hl.agg.counter(elt),
        hl.range(0, mt['GT'].ploidy).map(lambda i: get_allele_type(mt['GT'][i])))

    mt = mt.annotate_cols(**{name: hl.struct(**exprs)})

    zero = hl.int64(0)

    select_exprs = {}
    if 'dp_stats' in exprs:
        select_exprs['dp_stats'] = mt[name].dp_stats
    if 'gq_stats' in exprs:
        select_exprs['gq_stats'] = mt[name].gq_stats

    select_exprs = {
        **select_exprs,
        'call_rate': hl.float64(mt[name].n_called) / (mt[name].n_called + mt[name].n_not_called),
        'n_called': mt[name].n_called,
        'n_not_called': mt[name].n_not_called,
        'n_hom_ref': mt[name].n_hom_ref,
        'n_het': mt[name].n_het,
        'n_hom_var': mt[name].n_called - mt[name].n_hom_ref - mt[name].n_het,
        'n_non_ref': mt[name].n_called - mt[name].n_hom_ref,
        'n_singleton': mt[name].n_singleton,
        'n_snp': mt[name].allele_type_counts.get(allele_ints["Transition"], zero)
                 + mt[name].allele_type_counts.get(allele_ints["Transversion"], zero),
        'n_insertion': mt[name].allele_type_counts.get(allele_ints["Insertion"], zero),
        'n_deletion': mt[name].allele_type_counts.get(allele_ints["Deletion"], zero),
        'n_transition': mt[name].allele_type_counts.get(allele_ints["Transition"], zero),
        'n_transversion': mt[name].allele_type_counts.get(allele_ints["Transversion"], zero),
        'n_star': mt[name].allele_type_counts.get(allele_ints["Star"], zero)
    }

    mt = mt.annotate_cols(**{name: mt[name].select(**select_exprs)})

    mt = mt.annotate_cols(**{name: mt[name].annotate(
        r_ti_tv=divide_null(hl.float64(mt[name].n_transition), mt[name].n_transversion),
        r_het_hom_var=divide_null(hl.float64(mt[name].n_het), mt[name].n_hom_var),
        r_insertion_deletion=divide_null(hl.float64(mt[name].n_insertion), mt[name].n_deletion)
    )})

    mt = mt.drop(variant_ac, variant_atypes)

    return mt
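The Transition/Transversion refinement performed by `allele_type` above follows
the standard definitions: transitions swap within the purines (A, G) or within
the pyrimidines (C, T); every other single-base substitution is a transversion.
A pure-Python sketch of that classification (illustrative only; `snp_class` is
a hypothetical name):

PURINES = {'A', 'G'}
PYRIMIDINES = {'C', 'T'}

def snp_class(ref, alt):
    # Only single-base substitutions qualify as SNPs.
    if len(ref) != 1 or len(alt) != 1 or ref == alt:
        return None
    same_family = ((ref in PURINES and alt in PURINES) or
                   (ref in PYRIMIDINES and alt in PYRIMIDINES))
    return 'Transition' if same_family else 'Transversion'

assert snp_class('A', 'G') == 'Transition'
assert snp_class('A', 'C') == 'Transversion'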
Example #40
def transform_one(mt, vardp_outlier=100_000) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    There is a strong assumption that this function will be called on a matrix
    table with one column.
    """
    mt = localize(mt)
    if mt.row.dtype not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e:
                        hl.struct(
                            DP=e.DP,
                            END=row.info.END,
                            GQ=e.GQ,
                            LA=hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                            LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                            LGT=e.GT,
                            LPGT=e.PGT,
                            LPL=hl.cond(has_non_ref,
                                        hl.cond(alleles_len > 2,
                                                e.PL[:-alleles_len],
                                                hl.null(e.PL.dtype)),
                                        hl.cond(alleles_len > 1,
                                                e.PL,
                                                hl.null(e.PL.dtype))),
                            MIN_DP=e.MIN_DP,
                            PID=e.PID,
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                            SB=e.SB,
                            gvcf_info=hl.case()
                                .when(hl.is_missing(row.info.END),
                                      hl.struct(
                                          ClippingRankSum=row.info.ClippingRankSum,
                                          BaseQRankSum=row.info.BaseQRankSum,
                                          MQ=row.info.MQ,
                                          MQRankSum=row.info.MQRankSum,
                                          MQ_DP=row.info.MQ_DP,
                                          QUALapprox=row.info.QUALapprox,
                                          RAW_MQ=row.info.RAW_MQ,
                                          ReadPosRankSum=row.info.ReadPosRankSum,
                                          VarDP=hl.cond(row.info.VarDP > vardp_outlier,
                                                        row.info.DP, row.info.VarDP)))
                                .or_missing()
                        ))),
            ),
            mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(TableMapRows(mt._tir, Apply(transform_row._name, TopLevelReference('row'))))
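The `RGQ` extraction above indexes the PL of the ref/<NON_REF> genotype. Under
the standard VCF ordering of unphased diploid genotypes, where call (j, k) with
j <= k has linear index k * (k + 1) // 2 + j, that index has a closed form; a
one-function sketch (`rgq_pl_index` is a hypothetical name):

def rgq_pl_index(n_alleles):
    # <NON_REF> is the trailing allele, so the ref/<NON_REF> call is (0, k)
    # with k = n_alleles - 1, landing at index k * (k + 1) // 2.
    k = n_alleles - 1
    return k * (k + 1) // 2

assert rgq_pl_index(2) == 1  # PLs 0/0, 0/1, 1/1 -> ref/<NON_REF> at index 1
assert rgq_pl_index(3) == 3  # PLs 0/0, 0/1, 1/1, 0/2, 1/2, 2/2 -> index 3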
Example #41
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref:
            hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r:
                        hl.array([ref]).extend(
                            al[1:].map(
                                lambda a:
                                hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at:
                                    hl.cond(
                                        (_allele_ints['SNP'] == at) |
                                        (_allele_ints['Insertion'] == at) |
                                        (_allele_ints['Deletion'] == at) |
                                        (_allele_ints['MNP'] == at) |
                                        (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal:
                hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles:
                hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i:
                            hl.cond(hl.is_missing(row.data[i].__entries),
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                      .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
                                           TopLevelReference('row'),
                                           TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
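The allele-merging rule inside `merge_alleles` can be sketched in plain Python
for the common case of plain sequence alleles. This is an illustration, not the
combiner's code: `merge_alleles_py` is a hypothetical name, symbolic alleles
such as <NON_REF> (which the real code leaves unextended) are ignored, and the
global allele order here is first-seen rather than Hail's set-derived order.

def merge_alleles_py(allele_lists):
    # Use the longest reference across datasets as the shared reference.
    ref = max((alleles[0] for alleles in allele_lists), key=len)
    glob, local = [ref], []
    for alleles in allele_lists:
        suffix = ref[len(alleles[0]):]  # extra reference bases to append
        rewritten = [ref] + [a + suffix for a in alleles[1:]]
        local.append(rewritten)
        for a in rewritten[1:]:
            if a not in glob:
                glob.append(a)
    return glob, local

# Two datasets at the same locus with refs of different lengths:
glob, local = merge_alleles_py([['A', 'C'], ['AT', 'T']])
assert glob == ['AT', 'CT', 'T'] and local == [['AT', 'CT'], ['AT', 'T']]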
Example #42
def full_outer_join_mt(left: hl.MatrixTable, right: hl.MatrixTable) -> hl.MatrixTable:
    """Performs a full outer join on `left` and `right`.

    Replaces row, column, and entry fields with the following:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and right.

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`
    """

    if [x.dtype for x in left.row_key.values()] != [x.dtype for x in right.row_key.values()]:
        raise ValueError(f"row key types do not match:\n"
                         f"  left:  {list(left.row_key.values())}\n"
                         f"  right: {list(right.row_key.values())}")

    if [x.dtype for x in left.col_key.values()] != [x.dtype for x in right.col_key.values()]: 
        raise ValueError(f"column key types do not match:\n"
                         f"  left:  {list(left.col_key.values())}\n"
                         f"  right: {list(right.col_key.values())}")

    left = left.select_rows(left_row=left.row)
    left_t = left.localize_entries('left_entries', 'left_cols')
    right = right.select_rows(right_row=right.row)
    right_t = right.localize_entries('right_entries', 'right_cols')

    ht = left_t.join(right_t, how='outer')
    ht = ht.annotate_globals(
        left_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.left_cols.map(lambda x: hl.tuple([x[f] for f in left.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])),
        right_keys=hl.group_by(
            lambda t: t[0],
            hl.zip_with_index(
                ht.right_cols.map(lambda x: hl.tuple([x[f] for f in right.col_key])), index_first=False)).map_values(
            lambda elts: elts.map(lambda t: t[1])))
    ht = ht.annotate_globals(
        key_indices=hl.array(ht.left_keys.key_set().union(ht.right_keys.key_set()))
            .map(lambda k: hl.struct(k=k, left_indices=ht.left_keys.get(k), right_indices=ht.right_keys.get(k)))
            .flatmap(lambda s: hl.case()
                     .when(hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices),
                           hl.range(0, s.left_indices.length()).flatmap(
                               lambda i: hl.range(0, s.right_indices.length()).map(
                                   lambda j: hl.struct(k=s.k, left_index=s.left_indices[i],
                                                       right_index=s.right_indices[j]))))
                     .when(hl.is_defined(s.left_indices),
                           s.left_indices.map(
                               lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.null('int32'))))
                     .when(hl.is_defined(s.right_indices),
                           s.right_indices.map(
                               lambda elt: hl.struct(k=s.k, left_index=hl.null('int32'), right_index=elt)))
                     .or_error('assertion error')))
    ht = ht.annotate(__entries=ht.key_indices.map(lambda s: hl.struct(left_entry=ht.left_entries[s.left_index],
                                                                      right_entry=ht.right_entries[s.right_index])))
    ht = ht.annotate_globals(__cols=ht.key_indices.map(
        lambda s: hl.struct(**{f: s.k[i] for i, f in enumerate(left.col_key)},
                            left_col=ht.left_cols[s.left_index],
                            right_col=ht.right_cols[s.right_index])))
    ht = ht.drop('left_entries', 'left_cols', 'left_keys', 'right_entries', 'right_cols', 'right_keys', 'key_indices')
    return ht._unlocalize_entries('__entries', '__cols', list(left.col_key))
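The `key_indices` computation above pairs columns by key: keys present on both
sides form a cross product of their column indices, and keys present on only
one side pair with a missing index. A pure-Python sketch of the same pairing
(`pair_columns` is a hypothetical name):

from collections import defaultdict

def pair_columns(left_keys, right_keys):
    left_ix, right_ix = defaultdict(list), defaultdict(list)
    for i, k in enumerate(left_keys):
        left_ix[k].append(i)
    for j, k in enumerate(right_keys):
        right_ix[k].append(j)
    pairs = []
    for k in left_ix.keys() | right_ix.keys():
        if k in left_ix and k in right_ix:
            # cross product of matching columns
            pairs.extend((k, i, j) for i in left_ix[k] for j in right_ix[k])
        elif k in left_ix:
            pairs.extend((k, i, None) for i in left_ix[k])
        else:
            pairs.extend((k, None, j) for j in right_ix[k])
    return pairs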