Example #1
import hail as hl

def merge_alleles(alleles):
    from hail.expr.functions import _num_allele_type, _allele_ints
    return hl.rbind(
        alleles.map(lambda a: hl.or_else(a[0], ''))
               .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
        lambda ref:
        hl.rbind(
            alleles.map(
                lambda al: hl.rbind(
                    al[0],
                    lambda r:
                    hl.array([ref]).extend(
                        al[1:].map(
                            lambda a:
                            hl.rbind(
                                _num_allele_type(r, a),
                                lambda at:
                                hl.cond(
                                    (_allele_ints['SNP'] == at) |
                                    (_allele_ints['Insertion'] == at) |
                                    (_allele_ints['Deletion'] == at) |
                                    (_allele_ints['MNP'] == at) |
                                    (_allele_ints['Complex'] == at),
                                    a + ref[hl.len(r):],
                                    a)))))),
            lambda lal:
            hl.struct(
                globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                local=lal)))
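A minimal usage sketch (the values are hypothetical; assumes Hail has been initialized):

# Two gVCF allele lists for the same site. merge_alleles picks the longest
# reference ('CAT') and rewrites each alternate allele against it, returning
# a struct with the merged global allele list ('globl') and the per-input
# local lists ('local').
merged = merge_alleles(hl.literal([['CA', 'T'], ['CAT', 'C']]))
hl.eval(merged)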
Example #2
def summarize(mt):
    """Computes summary statistics

    Calls :func:`.quick_summary`. Calling both this and :func:`.quick_summary`, will lead
    to :func:`.quick_summary` being executed twice.

    Note
    ----
    You will not be able to run :func:`.combine_gvcfs` with the output of this
    function.
    """
    mt = quick_summary(mt)
    mt = hl.experimental.densify(mt)
    return mt.annotate_rows(info=hl.rbind(
        hl.agg.call_stats(lgt_to_gt(mt.LGT, mt.LA), mt.alleles),
        lambda gs: hl.struct(
            # here, we alphabetize the INFO fields by GATK convention
            AC=gs.AC[1:],  # The VCF spec indicates that AC and AF have Number=A, so we need
            AF=gs.AF[1:],  # to drop the first element from each of these.
            AN=gs.AN,
            BaseQRankSum=hl.median(hl.agg.collect(mt.entry.gvcf_info.BaseQRankSum)),
            ClippingRankSum=hl.median(hl.agg.collect(mt.entry.gvcf_info.ClippingRankSum)),
            DP=hl.agg.sum(mt.entry.DP),
            MQ=hl.median(hl.agg.collect(mt.entry.gvcf_info.MQ)),
            MQRankSum=hl.median(hl.agg.collect(mt.entry.gvcf_info.MQRankSum)),
            MQ_DP=mt.info.MQ_DP,
            QUALapprox=mt.info.QUALapprox,
            RAW_MQ=mt.info.RAW_MQ,
            ReadPosRankSum=hl.median(hl.agg.collect(mt.entry.gvcf_info.ReadPosRankSum)),
            SB_TABLE=mt.info.SB_TABLE,
            VarDP=mt.info.VarDP,
        )))
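A hypothetical usage sketch (the path is made up, and the helpers this function calls, such as quick_summary and lgt_to_gt, are assumed to be in scope):

mt = hl.read_matrix_table('data/combined.sparse.mt')  # hypothetical path
summarized = summarize(mt)
summarized.rows().select('info').show()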
Example #3
File: densify.py Project: jigold/hail
def densify(sparse_mt):
    """Convert sparse MatrixTable to a dense one.

    Parameters
    ----------
    sparse_mt : :class:`.MatrixTable`
        Sparse MatrixTable to densify.  The first row key field must
        be named ``locus`` and have type ``locus``.  Must have an
        ``END`` entry field of type ``int32``.

    Returns
    -------
    :class:`.MatrixTable`
        The densified MatrixTable.  The ``END`` entry field is dropped.

    """
    if list(sparse_mt.row_key)[0] != 'locus' or not isinstance(sparse_mt.locus.dtype, hl.tlocus):
        raise ValueError("first row key field must be named 'locus' and have type 'locus'")
    if 'END' not in sparse_mt.entry or sparse_mt.END.dtype != hl.tint32:
        raise ValueError("'densify' requires 'END' entry field of type 'int32'")
    col_key_fields = list(sparse_mt.col_key)

    mt = sparse_mt
    mt = sparse_mt.annotate_entries(__contig=mt.locus.contig)
    t = mt._localize_entries('__entries', '__cols')
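    # For each column, a scan carries forward the most recent entry that had
    # a defined END (i.e. a reference block); below, missing entries are
    # filled with that block whenever it still spans the current locus on
    # the same contig.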
    t = t.annotate(
        __entries = hl.rbind(
            hl.scan.array_agg(
                lambda entry: hl.scan._prev_nonnull(hl.or_missing(hl.is_defined(entry.END), entry)),
                t.__entries),
            lambda prev_entries: hl.map(
                lambda i:
                hl.rbind(
                    prev_entries[i], t.__entries[i],
                    lambda prev_entry, entry:
                    hl.cond(
                        (~hl.is_defined(entry) &
                         (prev_entry.END >= t.locus.position) &
                         (prev_entry.__contig == t.locus.contig)),
                        prev_entry,
                        entry)),
                hl.range(0, hl.len(t.__entries)))))
    mt = t._unlocalize_entries('__entries', '__cols', col_key_fields)
    mt = mt.drop('__contig', 'END')
    return mt
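A hypothetical usage sketch (the path is made up; the input must meet the key and 'END' requirements documented above):

sparse = hl.read_matrix_table('data/combined.sparse.mt')  # hypothetical path
dense = densify(sparse)  # reference blocks expanded into per-locus rows; 'END' dropped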
Example #4
File: misc.py Project: chrisvittal/hail
def segment_intervals(ht, points):
    """Segment the interval keys of `ht` at a given set of points.

    Parameters
    ----------
    ht : :class:`.Table`
        Table with interval keys.
    points : :class:`.Table` or :class:`.ArrayExpression`
        Points at which to segment the intervals, a table or an array.

    Returns
    -------
    :class:`.Table`
    """
    if len(ht.key) != 1 or not isinstance(ht.key[0].dtype, hl.tinterval):
        raise ValueError(
            "'segment_intervals' expects a table with interval keys")
    point_type = ht.key[0].dtype.point_type
    if isinstance(points, Table):
        if len(points.key) != 1 or points.key[0].dtype != point_type:
            raise ValueError(
                "'segment_intervals' expects points to be a table with a single"
                " key of the same type as the intervals in 'ht', or an array of those points:"
                f"\n  expect {point_type}, found {list(points.key.dtype.values())}"
            )
        points = hl.array(hl.set(points.collect(_localize=False)))
    if points.dtype.element_type != point_type:
        raise ValueError(
            f"'segment_intervals' expects points to be a table with a single"
            f" key of the same type as the intervals in 'ht', or an array of those points:"
            f"\n  expect {point_type}, found {points.dtype.element_type}")

    points = hl._sort_by(points, lambda l, r: hl._compare(l, r) < 0)

    ht = ht.annotate_globals(__points=points)

    interval = ht.key[0]
    points = ht.__points
    lower = hl.expr.functions._lower_bound(points, interval.start)
    higher = hl.expr.functions._lower_bound(points, interval.end)
    n_points = hl.len(points)
    lower = hl.if_else((lower < n_points) & (points[lower] == interval.start),
                       lower + 1, lower)
    higher = hl.if_else((higher < n_points) & (points[higher] == interval.end),
                        higher - 1, higher)
    interval_results = hl.rbind(
        lower, higher, lambda lower, higher: hl.cond(
            lower >= higher, [interval],
            hl.flatten([
                [
                    hl.interval(interval.start,
                                points[lower],
                                includes_start=interval.includes_start,
                                includes_end=False)
                ],
                hl.range(lower, higher - 1).map(lambda x: hl.interval(
                    points[x],
                    points[x + 1],
                    includes_start=True,
                    includes_end=False)),
                [
                    hl.interval(points[higher - 1],
                                interval.end,
                                includes_start=True,
                                includes_end=interval.includes_end)
                ],
            ])))
    ht = ht.annotate(__new_intervals=interval_results,
                     lower=lower,
                     higher=higher).explode('__new_intervals')
    return ht.key_by(**{
        list(ht.key)[0]: ht.__new_intervals
    }).drop('__new_intervals')
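A minimal sketch (the table is hypothetical; points are given as an array expression):

ht = hl.utils.range_table(1)
ht = ht.key_by(interval=hl.interval(0, 10))
# splits the key [0, 10) into [0, 3), [3, 7), and [7, 10)
segmented = segment_intervals(ht, hl.array([3, 7]))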
Example #5
def make_variants_matrix_table(mt: MatrixTable,
                               info_to_keep: Optional[Collection[str]] = None
                               ) -> MatrixTable:
    if info_to_keep is None:
        info_to_keep = []
    if not info_to_keep:
        info_to_keep = [name for name in mt.info if name not in ['END', 'DP']]
    info_key = tuple(sorted(info_to_keep))  # hashable stable value
    mt = localize(mt)
    mt = mt.filter(hl.is_missing(mt.info.END))

    if (mt.row.dtype, info_key) not in _transform_variant_function_map:
        def get_lgt(e, n_alleles, has_non_ref, row):
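            # Translate the global GT into a local LGT: genotypes touching
            # only real alleles pass through, genotypes touching the trailing
            # <NON_REF> allele become missing, and anything beyond the
            # genotype triangle is an error.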
            index = e.GT.unphased_diploid_gt_index()
            n_no_nonref = n_alleles - hl.int(has_non_ref)
            triangle_without_nonref = hl.triangle(n_no_nonref)
            return (hl.case()
                    .when(e.GT.is_haploid(),
                          hl.or_missing(e.GT[0] < n_no_nonref, e.GT))
                    .when(index < triangle_without_nonref, e.GT)
                    .when(index < hl.triangle(n_alleles), hl.missing('call'))
                    .or_error('invalid GT ' + hl.str(e.GT) + ' at site ' + hl.str(row.locus)))

        def make_entry_struct(e, alleles_len, has_non_ref, row):
            handled_fields = dict()
            handled_names = {'LA', 'gvcf_info',
                             'LAD', 'AD',
                             'LGT', 'GT',
                             'LPL', 'PL',
                             'LPGT', 'PGT'}

            if 'GT' not in e:
                raise hl.utils.FatalError("the Hail GVCF combiner expects GVCFs to have a 'GT' field in FORMAT.")

            handled_fields['LA'] = hl.range(0, alleles_len - hl.if_else(has_non_ref, 1, 0))
            handled_fields['LGT'] = get_lgt(e, alleles_len, has_non_ref, row)
            if 'AD' in e:
                handled_fields['LAD'] = hl.if_else(has_non_ref, e.AD[:-1], e.AD)
            if 'PGT' in e:
                handled_fields['LPGT'] = e.PGT
            if 'PL' in e:
                handled_fields['LPL'] = hl.if_else(has_non_ref,
                                                   hl.if_else(alleles_len > 2,
                                                              e.PL[:-alleles_len],
                                                              hl.missing(e.PL.dtype)),
                                                   hl.if_else(alleles_len > 1,
                                                              e.PL,
                                                              hl.missing(e.PL.dtype)))
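                # RGQ is only defined when <NON_REF> is present: it is the PL
                # of the genotype pairing the reference with <NON_REF> (the
                # last allele), or of the haploid <NON_REF> call.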
                handled_fields['RGQ'] = hl.if_else(
                    has_non_ref,
                    hl.if_else(e.GT.is_haploid(),
                               e.PL[alleles_len - 1],
                               e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()]),
                    hl.missing(e.PL.dtype.element_type))

            handled_fields['gvcf_info'] = (hl.case()
                                           .when(hl.is_missing(row.info.END),
                                                 hl.struct(**(
                                                     parse_as_fields(
                                                         row.info.select(*info_to_keep),
                                                         has_non_ref)
                                                 )))
                                           .or_missing())

            pass_through_fields = {k: v for k, v in e.items() if k not in handled_names}
            return hl.struct(**handled_fields, **pass_through_fields)

        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.if_else(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)))),
            mt.row.dtype)
        _transform_variant_function_map[mt.row.dtype, info_key] = f
    transform_row = _transform_variant_function_map[mt.row.dtype, info_key]
    return unlocalize(Table(TableMapRows(mt._tir, Apply(transform_row._name, transform_row._ret_type, TopLevelReference('row')))))
Example #6
def ld_score_regression(weight_expr,
                        ld_score_expr,
                        chi_sq_exprs,
                        n_samples_exprs,
                        n_blocks=200,
                        two_step_threshold=30,
                        n_reference_panel_variants=None) -> Table:
    r"""Estimate SNP-heritability and level of confounding biases from
    GWAS summary statistics.

    Given one or more sets of genome-wide association study (GWAS)
    summary statistics, :func:`.ld_score_regression` estimates the heritability
    of a trait or set of traits and the level of confounding biases present in
    the underlying studies by regressing chi-squared statistics on LD scores,
    leveraging the model:

    .. math::

        \mathrm{E}[\chi_j^2] = 1 + Na + \frac{Nh_g^2}{M}l_j

    *  :math:`\mathrm{E}[\chi_j^2]` is the expected chi-squared statistic
       for variant :math:`j` resulting from a test of association between
       variant :math:`j` and a trait.
    *  :math:`l_j = \sum_{k} r_{jk}^2` is the LD score of variant
       :math:`j`, calculated as the sum of squared correlation coefficients
       between variant :math:`j` and nearby variants. See :func:`ld_score`
       for further details.
    *  :math:`a` captures the contribution of confounding biases, such as
       cryptic relatedness and uncontrolled population structure, to the
       association test statistic.
    *  :math:`h_g^2` is the SNP-heritability, or the proportion of variation
       in the trait explained by the effects of variants included in the
       regression model above.
    *  :math:`M` is the number of variants used to estimate :math:`h_g^2`.
    *  :math:`N` is the number of samples in the underlying association study.

    For more details on the method implemented in this function, see:

    * `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__

    Examples
    --------

    Run the method on a matrix table of summary statistics, where the rows
    are variants and the columns are different phenotypes:

    >>> mt_gwas = hl.read_matrix_table('data/ld_score_regression.sumstats.mt')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=mt_gwas['ld_score'],
    ...     ld_score_expr=mt_gwas['ld_score'],
    ...     chi_sq_exprs=mt_gwas['chi_squared'],
    ...     n_samples_exprs=mt_gwas['n'])


    Run the method on a table with summary statistics for a single
    phenotype:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=ht_gwas['chi_squared_50_irnt'],
    ...     n_samples_exprs=ht_gwas['n_50_irnt'])

    Run the method on a table with summary statistics for multiple
    phenotypes:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=[ht_gwas['chi_squared_50_irnt'],
    ...                        ht_gwas['chi_squared_20160']],
    ...     n_samples_exprs=[ht_gwas['n_50_irnt'],
    ...                      ht_gwas['n_20160']])

    Notes
    -----
    The ``exprs`` provided as arguments to :func:`.ld_score_regression`
    must all be from the same object, either a :class:`Table` or a
    :class:`MatrixTable`.

    **If the arguments originate from a table:**

    *  The table must be keyed by fields ``locus`` of type
       :class:`.tlocus` and ``alleles``, a :py:data:`.tarray` of
       :py:data:`.tstr` elements.
    *  ``weight_expr``, ``ld_score_expr``, ``chi_sq_exprs``, and
       ``n_samples_exprs`` must be row-indexed fields.
    *  The number of expressions passed to ``n_samples_exprs`` must be
       equal to one or the number of expressions passed to
       ``chi_sq_exprs``. If just one expression is passed to
       ``n_samples_exprs``, that sample size expression is assumed to
       apply to all sets of statistics passed to ``chi_sq_exprs``.
       Otherwise, the expressions passed to ``chi_sq_exprs`` and
       ``n_samples_exprs`` are matched by index.
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have generic :obj:`int` values
       ``0``, ``1``, etc. corresponding to the ``0th``, ``1st``, etc.
       expressions passed to the ``chi_sq_exprs`` argument.

    **If the arguments originate from a matrix table:**

    *  The dimensions of the matrix table must be variants
       (rows) by phenotypes (columns).
    *  The rows of the matrix table must be keyed by fields
       ``locus`` of type :class:`.tlocus` and ``alleles``,
       a :py:data:`.tarray` of :py:data:`.tstr` elements.
    *  The columns of the matrix table must be keyed by a field
       of type :py:data:`.tstr` that uniquely identifies phenotypes
       represented in the matrix table. The column key must be a single
       expression; compound keys are not accepted.
    *  ``weight_expr`` and ``ld_score_expr`` must be row-indexed
       fields.
    *  ``chi_sq_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  ``n_samples_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have values corresponding to the
       column keys of the input matrix table.

    This function returns a :class:`Table` with one row per set of summary
    statistics passed to the ``chi_sq_exprs`` argument. The following
    row-indexed fields are included in the table:

    *  **phenotype** (:py:data:`.tstr`) -- The name of the phenotype. The
       returned table is keyed by this field. See the notes below for
       details on the possible values of this field.
    *  **mean_chi_sq** (:py:data:`.tfloat64`) -- The mean chi-squared
       test statistic for the given phenotype.
    *  **intercept** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          intercept :math:`1 + Na`.
       -  **standard_error**  (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    *  **snp_heritability** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          SNP-heritability :math:`h_g^2`.
       -  **standard_error** (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    Warning
    -------
    :func:`.ld_score_regression` considers only the rows for which both row
    fields ``weight_expr`` and ``ld_score_expr`` are defined. Rows with missing
    values in either field are removed prior to fitting the LD score
    regression model.

    Parameters
    ----------
    weight_expr : :class:`.Float64Expression`
                  Row-indexed expression for the LD scores used to derive
                  variant weights in the model.
    ld_score_expr : :class:`.Float64Expression`
                    Row-indexed expression for the LD scores used as covariates
                    in the model.
    chi_sq_exprs : :class:`.Float64Expression` or :obj:`list` of
                        :class:`.Float64Expression`
                        One or more row-indexed (if table) or entry-indexed
                        (if matrix table) expressions for chi-squared
                        statistics resulting from genome-wide association
                        studies.
    n_samples_exprs: :class:`.NumericExpression` or :obj:`list` of
                     :class:`.NumericExpression`
                     One or more row-indexed (if table) or entry-indexed
                     (if matrix table) expressions indicating the number of
                     samples used in the studies that generated the test
                     statistics supplied to ``chi_sq_exprs``.
    n_blocks : :obj:`int`
               The number of blocks used in the jackknife approach to
               estimating standard errors.
    two_step_threshold : :obj:`int`
                         Variants with chi-squared statistics greater than this
                         value are excluded in the first step of the two-step
                         procedure used to fit the model.
    n_reference_panel_variants : :obj:`int`, optional
                                 Number of variants used to estimate the
                                 SNP-heritability :math:`h_g^2`.

    Returns
    -------
    :class:`.Table`
        Table keyed by ``phenotype`` with intercept and heritability estimates
        for each phenotype passed to the function."""

    chi_sq_exprs = wrap_to_list(chi_sq_exprs)
    n_samples_exprs = wrap_to_list(n_samples_exprs)

    assert ((len(chi_sq_exprs) == len(n_samples_exprs)) or
            (len(n_samples_exprs) == 1))
    __k = 2  # number of covariates, including intercept

    ds = chi_sq_exprs[0]._indices.source

    analyze('ld_score_regression/weight_expr',
            weight_expr,
            ds._row_indices)
    analyze('ld_score_regression/ld_score_expr',
            ld_score_expr,
            ds._row_indices)

    # format input dataset
    if isinstance(ds, MatrixTable):
        if len(chi_sq_exprs) != 1:
            raise ValueError("""Only one chi_sq_expr allowed if originating
                from a matrix table.""")
        if len(n_samples_exprs) != 1:
            raise ValueError("""Only one n_samples_expr allowed if
                originating from a matrix table.""")

        col_key = list(ds.col_key)
        if len(col_key) != 1:
            raise ValueError("""Matrix table must be keyed by a single
                phenotype field.""")

        analyze('ld_score_regression/chi_squared_expr',
                chi_sq_exprs[0],
                ds._entry_indices)
        analyze('ld_score_regression/n_samples_expr',
                n_samples_exprs[0],
                ds._entry_indices)

        ds = ds._select_all(row_exprs={'__locus': ds.locus,
                                       '__alleles': ds.alleles,
                                       '__w_initial': weight_expr,
                                       '__w_initial_floor': hl.max(weight_expr,
                                                                   1.0),
                                       '__x': ld_score_expr,
                                       '__x_floor': hl.max(ld_score_expr,
                                                           1.0)},
                            row_key=['__locus', '__alleles'],
                            col_exprs={'__y_name': ds[col_key[0]]},
                            col_key=['__y_name'],
                            entry_exprs={'__y': chi_sq_exprs[0],
                                         '__n': n_samples_exprs[0]})
        ds = ds.annotate_entries(**{'__w': ds.__w_initial})

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    else:
        assert isinstance(ds, Table)
        for y in chi_sq_exprs:
            analyze('ld_score_regression/chi_squared_expr', y, ds._row_indices)
        for n in n_samples_exprs:
            analyze('ld_score_regression/n_samples_expr', n, ds._row_indices)

        ys = ['__y{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ws = ['__w{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ns = ['__n{:}'.format(i) for i, _ in enumerate(n_samples_exprs)]

        ds = ds.select(**dict(**{'__locus': ds.locus,
                                 '__alleles': ds.alleles,
                                 '__w_initial': weight_expr,
                                 '__x': ld_score_expr},
                              **{y: chi_sq_exprs[i]
                                 for i, y in enumerate(ys)},
                              **{w: weight_expr for w in ws},
                              **{n: n_samples_exprs[i]
                                 for i, n in enumerate(ns)}))
        ds = ds.key_by(ds.__locus, ds.__alleles)

        table_tmp_file = new_temp_file()
        ds.write(table_tmp_file)
        ds = hl.read_table(table_tmp_file)

        hts = [ds.select(**{'__w_initial': ds.__w_initial,
                            '__w_initial_floor': hl.max(ds.__w_initial,
                                                        1.0),
                            '__x': ds.__x,
                            '__x_floor': hl.max(ds.__x, 1.0),
                            '__y_name': i,
                            '__y': ds[ys[i]],
                            '__w': ds[ws[i]],
                            '__n': hl.int(ds[ns[i]])})
               for i, y in enumerate(ys)]

        mts = [ht.to_matrix_table(row_key=['__locus',
                                           '__alleles'],
                                  col_key=['__y_name'],
                                  row_fields=['__w_initial',
                                              '__w_initial_floor',
                                              '__x',
                                              '__x_floor'])
               for ht in hts]

        ds = mts[0]
        for i in range(1, len(ys)):
            ds = ds.union_cols(mts[i])

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    mt_tmp_file1 = new_temp_file()
    ds.write(mt_tmp_file1)
    mt = hl.read_matrix_table(mt_tmp_file1)

    if not n_reference_panel_variants:
        M = mt.count_rows()
    else:
        M = n_reference_panel_variants

    # block variants for each phenotype
    n_phenotypes = mt.count_cols()

    mt = mt.annotate_entries(__in_step1=(hl.is_defined(mt.__y) &
                                         (mt.__y < two_step_threshold)),
                             __in_step2=hl.is_defined(mt.__y))

    mt = mt.annotate_cols(__col_idx=hl.int(hl.scan.count()),
                          __m_step1=hl.agg.count_where(mt.__in_step1),
                          __m_step2=hl.agg.count_where(mt.__in_step2))

    col_keys = list(mt.col_key)

    ht = mt.localize_entries(entries_array_field_name='__entries',
                             columns_array_field_name='__cols')

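    # Assign each variant to one of n_blocks contiguous jackknife blocks per
    # phenotype: a scan counts the step-1 variants seen so far, and evenly
    # spaced separators over __m_step1 mark the block boundaries.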
    ht = ht.annotate(__entries=hl.rbind(
        hl.scan.array_agg(
            lambda entry: hl.scan.count_where(entry.__in_step1),
            ht.__entries),
        lambda step1_indices: hl.map(
            lambda i: hl.rbind(
                hl.int(hl.or_else(step1_indices[i], 0)),
                ht.__cols[i].__m_step1,
                ht.__entries[i],
                lambda step1_idx, m_step1, entry: hl.rbind(
                    hl.map(
                        lambda j: hl.int(hl.floor(j * (m_step1 / n_blocks))),
                        hl.range(0, n_blocks + 1)),
                    lambda step1_separators: hl.rbind(
                        hl.set(step1_separators).contains(step1_idx),
                        hl.sum(
                            hl.map(
                                lambda s1: step1_idx >= s1,
                                step1_separators)) - 1,
                        lambda is_separator, step1_block: entry.annotate(
                            __step1_block=step1_block,
                            __step2_block=hl.cond(~entry.__in_step1 & is_separator,
                                                  step1_block - 1,
                                                  step1_block))))),
            hl.range(0, hl.len(ht.__entries)))))

    mt = ht._unlocalize_entries('__entries', '__cols', col_keys)

    mt_tmp_file2 = new_temp_file()
    mt.write(mt_tmp_file2)
    mt = hl.read_matrix_table(mt_tmp_file2)
    
    # initial coefficient estimates
    mt = mt.annotate_cols(__initial_betas=[
        1.0, (hl.agg.mean(mt.__y) - 1.0) / hl.agg.mean(mt.__x)])
    mt = mt.annotate_cols(__step1_betas=mt.__initial_betas,
                          __step2_betas=mt.__initial_betas)

    # step 1 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step1,
            1.0/(mt.__w_initial_floor * 2.0 * (mt.__step1_betas[0] +
                                               mt.__step1_betas[1] *
                                               mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step1_betas=hl.agg.filter(
            mt.__in_step1,
            hl.agg.linreg(y=mt.__y,
                          x=[1.0, mt.__x],
                          weight=mt.__w).beta))
        mt = mt.annotate_cols(__step1_h2=hl.max(hl.min(
            mt.__step1_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step1_betas=[
            mt.__step1_betas[0],
            mt.__step1_h2 * hl.agg.mean(mt.__n) / M])

    # step 1 block jackknife
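    # Refit with one block held out at a time; the bias-corrected
    # pseudovalues below follow
    # n_blocks * full_estimate - (n_blocks - 1) * leave_one_out_estimate.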
    mt = mt.annotate_cols(__step1_block_betas=[
        hl.agg.filter((mt.__step1_block != i) & mt.__in_step1,
                      hl.agg.linreg(y=mt.__y,
                                    x=[1.0, mt.__x],
                                    weight=mt.__w).beta)
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step1_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step1_betas - (n_blocks - 1) * x,
        mt.__step1_block_betas))

    mt = mt.annotate_cols(
        __step1_jackknife_mean=hl.map(
            lambda i: hl.mean(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected)),
            hl.range(0, __k)),
        __step1_jackknife_variance=hl.map(
            lambda i: (hl.sum(
                hl.map(lambda x: x[i]**2,
                       mt.__step1_block_betas_bias_corrected)) -
                       hl.sum(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected))**2 /
                       n_blocks) /
            (n_blocks - 1) / n_blocks,
            hl.range(0, __k)))

    # step 2 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step2,
            1.0/(mt.__w_initial_floor *
                 2.0 * (mt.__step2_betas[0] +
                        mt.__step2_betas[1] *
                        mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            hl.agg.filter(mt.__in_step2,
                          hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                        x=[mt.__x],
                                        weight=mt.__w).beta[0])])
        mt = mt.annotate_cols(__step2_h2=hl.max(hl.min(
            mt.__step2_betas[1] * M/hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            mt.__step2_h2 * hl.agg.mean(mt.__n)/M])

    # step 2 block jackknife
    mt = mt.annotate_cols(__step2_block_betas=[
        hl.agg.filter((mt.__step2_block != i) & mt.__in_step2,
                      hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                    x=[mt.__x],
                                    weight=mt.__w).beta[0])
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step2_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step2_betas[1] - (n_blocks - 1) * x,
        mt.__step2_block_betas))

    mt = mt.annotate_cols(
        __step2_jackknife_mean=hl.mean(
            mt.__step2_block_betas_bias_corrected),
        __step2_jackknife_variance=(
            hl.sum(mt.__step2_block_betas_bias_corrected**2) -
            hl.sum(mt.__step2_block_betas_bias_corrected)**2 /
            n_blocks) / (n_blocks - 1) / n_blocks)

    # combine step 1 and step 2 block jackknifes
    mt = mt.annotate_entries(
        __step2_initial_w=1.0/(mt.__w_initial_floor *
                               2.0 * (mt.__initial_betas[0] +
                                      mt.__initial_betas[1] *
                                      mt.__x_floor)**2))

    mt = mt.annotate_cols(
        __final_betas=[
            mt.__step1_betas[0],
            mt.__step2_betas[1]],
        __c=(hl.agg.sum(mt.__step2_initial_w * mt.__x) /
             hl.agg.sum(mt.__step2_initial_w * mt.__x**2)))

    mt = mt.annotate_cols(__final_block_betas=hl.map(
        lambda i: (mt.__step2_block_betas[i] - mt.__c *
                   (mt.__step1_block_betas[i][0] - mt.__final_betas[0])),
        hl.range(0, n_blocks)))

    mt = mt.annotate_cols(
        __final_block_betas_bias_corrected=(n_blocks * mt.__final_betas[1] -
                                            (n_blocks - 1) *
                                            mt.__final_block_betas))

    mt = mt.annotate_cols(
        __final_jackknife_mean=[
            mt.__step1_jackknife_mean[0],
            hl.mean(mt.__final_block_betas_bias_corrected)],
        __final_jackknife_variance=[
            mt.__step1_jackknife_variance[0],
            (hl.sum(mt.__final_block_betas_bias_corrected**2) -
             hl.sum(mt.__final_block_betas_bias_corrected)**2 /
             n_blocks) / (n_blocks - 1) / n_blocks])

    # convert coefficient to heritability estimate
    mt = mt.annotate_cols(
        phenotype=mt.__y_name,
        mean_chi_sq=hl.agg.mean(mt.__y),
        intercept=hl.struct(
            estimate=mt.__final_betas[0],
            standard_error=hl.sqrt(mt.__final_jackknife_variance[0])),
        snp_heritability=hl.struct(
            estimate=(M/hl.agg.mean(mt.__n)) * mt.__final_betas[1],
            standard_error=hl.sqrt((M/hl.agg.mean(mt.__n))**2 *
                                   mt.__final_jackknife_variance[1])))

    # format and return results
    ht = mt.cols()
    ht = ht.key_by(ht.phenotype)
    ht = ht.select(ht.mean_chi_sq,
                   ht.intercept,
                   ht.snp_heritability)

    ht_tmp_file = new_temp_file()
    ht.write(ht_tmp_file)
    ht = hl.read_table(ht_tmp_file)
    
    return ht
Example #7
def segment_reference_blocks(ref: 'MatrixTable',
                             intervals: 'Table') -> 'MatrixTable':
    """Returns a matrix table of reference blocks segmented according to intervals.

    Loci outside the given intervals are discarded. Reference blocks that
    start before an interval but extend into it will appear at the interval
    start locus.

    Note
    ----
        Assumes disjoint intervals which do not span contigs.

        Requires start-inclusive intervals.

    Parameters
    ----------
    ref : :class:`.MatrixTable`
        MatrixTable of reference blocks.
    intervals : :class:`.Table`
        Table of intervals at which to segment reference blocks.

    Returns
    -------
    :class:`.MatrixTable`
    """
    interval_field = list(intervals.key)[0]
    if not intervals[interval_field].dtype == hl.tinterval(ref.locus.dtype):
        raise ValueError(
            f"expect intervals to be keyed by intervals of loci matching the VariantDataset:"
            f" found {intervals[interval_field].dtype} / {ref.locus.dtype}")
    intervals = intervals.select(_interval_dup=intervals[interval_field])

    if not intervals.aggregate(
            hl.agg.all(intervals[interval_field].includes_start
                       & (intervals[interval_field].start.contig ==
                          intervals[interval_field].end.contig))):
        raise ValueError("expect intervals to be start-inclusive")

    starts = intervals.key_by(_start_locus=intervals[interval_field].start)
    starts = starts.annotate(_include_locus=True)
    refl = ref.localize_entries('_ref_entries', '_ref_cols')
    joined = refl.join(starts, how='outer')
    rg = ref.locus.dtype.reference_genome
    contigs = rg.contigs
    contig_idx_map = hl.literal({contigs[i]: i
                                 for i in range(len(contigs))},
                                'dict<str, int32>')
    joined = joined.annotate(__contig_idx=contig_idx_map[joined.locus.contig])
    joined = joined.annotate(_ref_entries=joined._ref_entries.map(
        lambda e: e.annotate(__contig_idx=joined.__contig_idx)))
    dense = joined.annotate(dense_ref=hl.or_missing(
        joined._include_locus,
        hl.rbind(
            joined.locus.position, lambda pos: hl.enumerate(
                hl.scan._densify(hl.len(joined._ref_cols), joined._ref_entries)
            ).map(lambda idx_and_e: hl.rbind(
                idx_and_e[0], idx_and_e[1], lambda idx, e: hl.coalesce(
                    joined._ref_entries[idx],
                    hl.or_missing((e.__contig_idx == joined.__contig_idx) &
                                  (e.END >= pos), e))).drop('__contig_idx')))))
    dense = dense.filter(dense._include_locus).drop('_interval_dup',
                                                    '_include_locus',
                                                    '__contig_idx')

    # at this point, 'dense' is a table with dense rows of reference blocks, keyed by locus

    refl_filtered = refl.annotate(
        **{interval_field: intervals[refl.locus]._interval_dup})

    # remove rows that are not contained in an interval, and rows that are the start of an
    # interval (interval starts come from the 'dense' table)
    refl_filtered = refl_filtered.filter(
        hl.is_defined(refl_filtered[interval_field])
        & (refl_filtered.locus != refl_filtered[interval_field].start))

    # union dense interval starts with filtered table
    refl_filtered = refl_filtered.union(
        dense.transmute(_ref_entries=dense.dense_ref))

    # rewrite reference blocks to end at the first of (interval end, reference block end)
    refl_filtered = refl_filtered.annotate(
        interval_end=refl_filtered[interval_field].end.position -
        ~refl_filtered[interval_field].includes_end)
    refl_filtered = refl_filtered.annotate(
        _ref_entries=refl_filtered._ref_entries.map(
            lambda entry: entry.annotate(END=hl.min(entry.END, refl_filtered.
                                                    interval_end))))

    return refl_filtered._unlocalize_entries('_ref_entries', '_ref_cols',
                                             list(ref.col_key))
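A hypothetical usage sketch (both paths are made up):

ref_mt = hl.read_matrix_table('data/reference_blocks.mt')  # hypothetical path
capture = hl.import_locus_intervals('data/capture.interval_list')  # hypothetical path
segmented = segment_reference_blocks(ref_mt, capture)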
Example #8
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
            .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref:
            hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r:
                        hl.array([ref]).extend(
                            al[1:].map(
                                lambda a:
                                hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at:
                                    hl.cond(
                                        (_allele_ints['SNP'] == at)
                                        | (_allele_ints['Insertion'] == at)
                                        | (_allele_ints['Deletion'] == at)
                                        | (_allele_ints['MNP'] == at)
                                        | (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal:
                hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles:
                hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i:
                            hl.cond(hl.is_missing(row.data[i].__entries),
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                    .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
                                           merge_function._ret_type,
                                           TopLevelReference('row'),
                                           TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
Example #9
File: qc.py Project: jigold/hail
def sample_qc(mt, name='sample_qc') -> MatrixTable:
    """Compute per-sample metrics useful for quality control.

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    Compute sample QC metrics and remove low-quality samples:

    >>> dataset = hl.sample_qc(dataset, name='sample_qc')
    >>> filtered_dataset = dataset.filter_cols((dataset.sample_qc.dp_stats.mean > 20) & (dataset.sample_qc.r_ti_tv > 1.5))

    Notes
    -----

    This method computes summary statistics per sample from a genetic matrix and stores
    the results as a new column-indexed struct field in the matrix, named based on the
    `name` parameter.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `call_rate` (``float64``) -- Fraction of calls not missing or filtered.
       Equivalent to `n_called` divided by :meth:`.count_rows`.
    - `n_called` (``int64``) -- Number of non-missing calls.
    - `n_not_called` (``int64``) -- Number of missing calls.
    - `n_filtered` (``int64``) -- Number of filtered entries.
    - `n_hom_ref` (``int64``) -- Number of homozygous reference calls.
    - `n_het` (``int64``) -- Number of heterozygous calls.
    - `n_hom_var` (``int64``) -- Number of homozygous alternate calls.
    - `n_non_ref` (``int64``) -- Sum of ``n_het`` and ``n_hom_var``.
    - `n_snp` (``int64``) -- Number of SNP alternate alleles.
    - `n_insertion` (``int64``) -- Number of insertion alternate alleles.
    - `n_deletion` (``int64``) -- Number of deletion alternate alleles.
    - `n_singleton` (``int64``) -- Number of private alleles.
    - `n_transition` (``int64``) -- Number of transition (A-G, C-T) alternate alleles.
    - `n_transversion` (``int64``) -- Number of transversion alternate alleles.
    - `n_star` (``int64``) -- Number of star (upstream deletion) alleles.
    - `r_ti_tv` (``float64``) -- Transition/Transversion ratio.
    - `r_het_hom_var` (``float64``) -- Het/HomVar call ratio.
    - `r_insertion_deletion` (``float64``) -- Insertion/Deletion allele ratio.

    Missing values ``NA`` may result from division by zero.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
        Dataset with a new column-indexed field `name`.
    """

    require_row_key_variant(mt, 'sample_qc')

    from hail.expr.functions import _num_allele_type, _allele_types

    allele_types = _allele_types[:]
    allele_types.extend(['Transition', 'Transversion'])
    allele_enum = {i: v for i, v in enumerate(allele_types)}
    allele_ints = {v: k for k, v in allele_enum.items()}

    def allele_type(ref, alt):
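        # Refine the generic allele type: SNPs are split into transitions
        # and transversions using the extended enum built above; other
        # allele types pass through unchanged.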
        return hl.bind(lambda at: hl.cond(at == allele_ints['SNP'],
                                          hl.cond(hl.is_transition(ref, alt),
                                                  allele_ints['Transition'],
                                                  allele_ints['Transversion']),
                                          at),
                       _num_allele_type(ref, alt))

    variant_ac = Env.get_uid()
    variant_atypes = Env.get_uid()
    mt = mt.annotate_rows(**{variant_ac: hl.agg.call_stats(mt.GT, mt.alleles).AC,
                             variant_atypes: mt.alleles[1:].map(lambda alt: allele_type(mt.alleles[0], alt))})

    bound_exprs = {}
    gq_dp_exprs = {}

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    if has_field_of_type('DP', hl.tint32):
        gq_dp_exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        gq_dp_exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'sample_qc': expect an entry field 'GT' of type 'call'")

    bound_exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    bound_exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT']))
    bound_exprs['n_filtered'] = mt.count_rows(_localize=False) - hl.agg.count()
    bound_exprs['n_hom_ref'] = hl.agg.count_where(mt['GT'].is_hom_ref())
    bound_exprs['n_het'] = hl.agg.count_where(mt['GT'].is_het())
    bound_exprs['n_singleton'] = hl.agg.sum(hl.sum(hl.range(0, mt['GT'].ploidy).map(lambda i: mt[variant_ac][mt['GT'][i]] == 1)))

    def get_allele_type(allele_idx):
        return hl.cond(allele_idx > 0, mt[variant_atypes][allele_idx - 1], hl.null(hl.tint32))

    bound_exprs['allele_type_counts'] = hl.agg.explode(
        lambda elt: hl.agg.counter(elt),
        hl.range(0, mt['GT'].ploidy).map(lambda i: get_allele_type(mt['GT'][i])))

    zero = hl.int64(0)

    result_struct = hl.rbind(hl.struct(**bound_exprs),
        lambda x: hl.rbind(
            hl.struct(**{
                **gq_dp_exprs,
                'call_rate': hl.float64(x.n_called) / (x.n_called + x.n_not_called + x.n_filtered),
                'n_called': x.n_called,
                'n_not_called': x.n_not_called,
                'n_filtered': x.n_filtered,
                'n_hom_ref': x.n_hom_ref,
                'n_het': x.n_het,
                'n_hom_var': x.n_called - x.n_hom_ref - x.n_het,
                'n_non_ref': x.n_called - x.n_hom_ref,
                'n_singleton': x.n_singleton,
                'n_snp': x.allele_type_counts.get(allele_ints["Transition"], zero) + \
                         x.allele_type_counts.get(allele_ints["Transversion"], zero),
                'n_insertion': x.allele_type_counts.get(allele_ints["Insertion"], zero),
                'n_deletion': x.allele_type_counts.get(allele_ints["Deletion"], zero),
                'n_transition': x.allele_type_counts.get(allele_ints["Transition"], zero),
                'n_transversion': x.allele_type_counts.get(allele_ints["Transversion"], zero),
                'n_star': x.allele_type_counts.get(allele_ints["Star"], zero)
            }),
            lambda s: s.annotate(
                r_ti_tv=divide_null(hl.float64(s.n_transition), s.n_transversion),
                r_het_hom_var=divide_null(hl.float64(s.n_het), s.n_hom_var),
                r_insertion_deletion=divide_null(hl.float64(s.n_insertion), s.n_deletion)
            )))

    mt = mt.annotate_cols(**{name: result_struct})
    mt = mt.drop(variant_ac, variant_atypes)

    return mt
Example #10
def reannotate(mt, gatk_ht, summ_ht):
    """Re-annotate a sparse MT with annotations from certain GATK tools

    `gatk_ht` should be a table from the rows of a VCF, with `info` having at least
    the following fields.  Be aware that fields not present in this list will
    be dropped.
    ```
        struct {
            AC: array<int32>,
            AF: array<float64>,
            AN: int32,
            BaseQRankSum: float64,
            ClippingRankSum: float64,
            DP: int32,
            FS: float64,
            MQ: float64,
            MQRankSum: float64,
            MQ_DP: int32,
            NEGATIVE_TRAIN_SITE: bool,
            POSITIVE_TRAIN_SITE: bool,
            QD: float64,
            QUALapprox: int32,
            RAW_MQ: float64,
            ReadPosRankSum: float64,
            SB_TABLE: array<int32>,
            SOR: float64,
            VQSLOD: float64,
            VarDP: int32,
            culprit: str
        }
    ```
    `summ_ht` should be the output of :func:`.summarize` as a rows table.

    Note
    ----
    You will not be able to run :func:`.combine_gvcfs` with the output of this
    function.
    """
    def check(ht):
        keys = list(ht.key)
        if keys[0] != 'locus':
            raise TypeError(f'table inputs must have first key "locus", found {keys}')
        if keys != ['locus']:
            return hl.Table(TableKeyBy(ht._tir, ['locus'], is_sorted=True))
        return ht

    gatk_ht, summ_ht = [check(ht) for ht in (gatk_ht, summ_ht)]
    return mt.annotate_rows(
        info=hl.rbind(
            gatk_ht[mt.locus].info, summ_ht[mt.locus].info,
            lambda ginfo, hinfo: hl.struct(
                AC=hl.or_else(hinfo.AC, ginfo.AC),
                AF=hl.or_else(hinfo.AF, ginfo.AF),
                AN=hl.or_else(hinfo.AN, ginfo.AN),
                BaseQRankSum=hl.or_else(hinfo.BaseQRankSum, ginfo.BaseQRankSum),
                ClippingRankSum=hl.or_else(hinfo.ClippingRankSum, ginfo.ClippingRankSum),
                DP=hl.or_else(hinfo.DP, ginfo.DP),
                FS=ginfo.FS,
                MQ=hl.or_else(hinfo.MQ, ginfo.MQ),
                MQRankSum=hl.or_else(hinfo.MQRankSum, ginfo.MQRankSum),
                MQ_DP=hl.or_else(hinfo.MQ_DP, ginfo.MQ_DP),
                NEGATIVE_TRAIN_SITE=ginfo.NEGATIVE_TRAIN_SITE,
                POSITIVE_TRAIN_SITE=ginfo.POSITIVE_TRAIN_SITE,
                QD=ginfo.QD,
                QUALapprox=hl.or_else(hinfo.QUALapprox, ginfo.QUALapprox),
                RAW_MQ=hl.or_else(hinfo.RAW_MQ, ginfo.RAW_MQ),
                ReadPosRankSum=hl.or_else(hinfo.ReadPosRankSum, ginfo.ReadPosRankSum),
                SB_TABLE=hl.or_else(hinfo.SB_TABLE, ginfo.SB_TABLE),
                SOR=ginfo.SOR,
                VQSLOD=ginfo.VQSLOD,
                VarDP=hl.or_else(hinfo.VarDP, ginfo.VarDP),
                culprit=ginfo.culprit,
            )),
        qual=gatk_ht[mt.locus].qual,
        filters=gatk_ht[mt.locus].filters,
    )
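A minimal usage sketch for `reannotate`, assuming a sparse MatrixTable and hypothetical input paths; `summ_ht` is the rows table of the output of :func:`.summarize`:
```
import hail as hl

# hypothetical paths
mt = hl.read_matrix_table('gs://bucket/sparse.mt')
gatk_ht = hl.import_vcf('gs://bucket/gatk-genotyped.vcf.bgz').rows()
summ_ht = hl.read_matrix_table('gs://bucket/summarized.mt').rows()

# both tables are re-keyed to 'locus' internally if needed
mt = reannotate(mt, gatk_ht, summ_ht)
```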
Example #13
def sample_qc(mt, name='sample_qc') -> MatrixTable:
    """Compute per-sample metrics useful for quality control.

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    Compute sample QC metrics and remove low-quality samples:

    >>> dataset = hl.sample_qc(dataset, name='sample_qc')
    >>> filtered_dataset = dataset.filter_cols((dataset.sample_qc.dp_stats.mean > 20) & (dataset.sample_qc.r_ti_tv > 1.5))

    Notes
    -----

    This method computes summary statistics per sample from a genetic matrix and stores
    the results as a new column-indexed struct field in the matrix, named based on the
    `name` parameter.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `call_rate` (``float64``) -- Fraction of calls not missing or filtered.
      Equivalent to `n_called` divided by :meth:`.count_rows`.
    - `n_called` (``int64``) -- Number of non-missing calls.
    - `n_not_called` (``int64``) -- Number of missing calls.
    - `n_filtered` (``int64``) -- Number of filtered entries.
    - `n_hom_ref` (``int64``) -- Number of homozygous reference calls.
    - `n_het` (``int64``) -- Number of heterozygous calls.
    - `n_hom_var` (``int64``) -- Number of homozygous alternate calls.
    - `n_non_ref` (``int64``) -- Sum of `n_het` and `n_hom_var`.
    - `n_snp` (``int64``) -- Number of SNP alternate alleles.
    - `n_insertion` (``int64``) -- Number of insertion alternate alleles.
    - `n_deletion` (``int64``) -- Number of deletion alternate alleles.
    - `n_singleton` (``int64``) -- Number of private alleles.
    - `n_transition` (``int64``) -- Number of transition (A-G, C-T) alternate alleles.
    - `n_transversion` (``int64``) -- Number of transversion alternate alleles.
    - `n_star` (``int64``) -- Number of star (upstream deletion) alleles.
    - `r_ti_tv` (``float64``) -- Transition/Transversion ratio.
    - `r_het_hom_var` (``float64``) -- Het/HomVar call ratio.
    - `r_insertion_deletion` (``float64``) -- Insertion/Deletion allele ratio.

    Missing values ``NA`` may result from division by zero.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
        Dataset with a new column-indexed field `name`.
    """

    require_row_key_variant(mt, 'sample_qc')

    from hail.expr.functions import _num_allele_type, _allele_types

    allele_types = _allele_types[:]
    allele_types.extend(['Transition', 'Transversion'])
    allele_enum = {i: v for i, v in enumerate(allele_types)}
    allele_ints = {v: k for k, v in allele_enum.items()}

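    # Classify each alt allele against the ref; SNPs are further refined into
    # Transition/Transversion using the extended enum above.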
    def allele_type(ref, alt):
        return hl.bind(
            lambda at: hl.cond(
                at == allele_ints['SNP'],
                hl.cond(hl.is_transition(ref, alt), allele_ints['Transition'],
                        allele_ints['Transversion']), at),
            _num_allele_type(ref, alt))

    variant_ac = Env.get_uid()
    variant_atypes = Env.get_uid()
    mt = mt.annotate_rows(
        **{
            variant_ac:
            hl.agg.call_stats(mt.GT, mt.alleles).AC,
            variant_atypes:
            mt.alleles[1:].map(lambda alt: allele_type(mt.alleles[0], alt))
        })

    bound_exprs = {}
    gq_dp_exprs = {}

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    if has_field_of_type('DP', hl.tint32):
        gq_dp_exprs['dp_stats'] = hl.agg.stats(mt.DP).select(
            'mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        gq_dp_exprs['gq_stats'] = hl.agg.stats(mt.GQ).select(
            'mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError(
            "'sample_qc': expects an entry field 'GT' of type 'call'")

    bound_exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    bound_exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT']))

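    # 'n_rows' is a raw IR reference to the dataset's row count, so n_filtered
    # can be computed as (total rows) - (entries seen by the aggregator).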
    n_rows_ref = hl.expr.construct_expr(
        hl.ir.Ref('n_rows'), hl.tint64, mt._col_indices,
        hl.utils.LinkedList(hl.expr.Aggregation))
    bound_exprs['n_filtered'] = n_rows_ref - hl.agg.count()
    bound_exprs['n_hom_ref'] = hl.agg.count_where(mt['GT'].is_hom_ref())
    bound_exprs['n_het'] = hl.agg.count_where(mt['GT'].is_het())
    bound_exprs['n_singleton'] = hl.agg.sum(
        hl.sum(
            hl.range(0, mt['GT'].ploidy).map(
                lambda i: mt[variant_ac][mt['GT'][i]] == 1)))

    def get_allele_type(allele_idx):
        return hl.cond(allele_idx > 0, mt[variant_atypes][allele_idx - 1],
                       hl.null(hl.tint32))

    bound_exprs['allele_type_counts'] = hl.agg.explode(
        lambda elt: hl.agg.counter(elt),
        hl.range(0, mt['GT'].ploidy).map(lambda i: get_allele_type(mt['GT'][i])))

    zero = hl.int64(0)

    result_struct = hl.rbind(hl.struct(**bound_exprs),
        lambda x: hl.rbind(
            hl.struct(**{
                **gq_dp_exprs,
                'call_rate': hl.float64(x.n_called) / (x.n_called + x.n_not_called + x.n_filtered),
                'n_called': x.n_called,
                'n_not_called': x.n_not_called,
                'n_filtered': x.n_filtered,
                'n_hom_ref': x.n_hom_ref,
                'n_het': x.n_het,
                'n_hom_var': x.n_called - x.n_hom_ref - x.n_het,
                'n_non_ref': x.n_called - x.n_hom_ref,
                'n_singleton': x.n_singleton,
                'n_snp': x.allele_type_counts.get(allele_ints["Transition"], zero)
                         + x.allele_type_counts.get(allele_ints["Transversion"], zero),
                'n_insertion': x.allele_type_counts.get(allele_ints["Insertion"], zero),
                'n_deletion': x.allele_type_counts.get(allele_ints["Deletion"], zero),
                'n_transition': x.allele_type_counts.get(allele_ints["Transition"], zero),
                'n_transversion': x.allele_type_counts.get(allele_ints["Transversion"], zero),
                'n_star': x.allele_type_counts.get(allele_ints["Star"], zero)
            }),
            lambda s: s.annotate(
                r_ti_tv=divide_null(hl.float64(s.n_transition), s.n_transversion),
                r_het_hom_var=divide_null(hl.float64(s.n_het), s.n_hom_var),
                r_insertion_deletion=divide_null(hl.float64(s.n_insertion), s.n_deletion)
            )))

    mt = mt.annotate_cols(**{name: result_struct})
    mt = mt.drop(variant_ac, variant_atypes)

    return mt
Example #14
def import_exac_vcf(path):
    ds = hl.import_vcf(path, force_bgz=True, skip_invalid_loci=True).rows()

    ds = hl.split_multi(ds)

    ds = ds.repartition(5000, shuffle=True)

    # Get value corresponding to the split variant
    ds = ds.annotate(info=ds.info.annotate(
        **{
            field: hl.or_missing(hl.is_defined(ds.info[field]),
                                 ds.info[field][ds.a_index - 1])
            for field in PER_ALLELE_FIELDS
        }))

    # For DP_HIST and GQ_HIST, the first value in the array contains the histogram for all individuals,
    # which is the same in each alt allele's variant.
    ds = ds.annotate(info=ds.info.annotate(
        DP_HIST=hl.struct(all=ds.info.DP_HIST[0],
                          alt=ds.info.DP_HIST[ds.a_index]),
        GQ_HIST=hl.struct(all=ds.info.GQ_HIST[0],
                          alt=ds.info.GQ_HIST[ds.a_index]),
    ))

    ds = ds.cache()

    # Convert "NA" and empty strings into null values
    # Convert fields in chunks to avoid "Method code too large" errors
    for i in range(0, len(SELECT_INFO_FIELDS), 10):
        ds = ds.annotate(info=ds.info.annotate(
            **{
                field: hl.or_missing(
                    hl.is_defined(ds.info[field]),
                    hl.if_else(
                        (hl.str(ds.info[field]) == "")
                        | (hl.str(ds.info[field]) == "NA"),
                        hl.null(ds.info[field].dtype),
                        ds.info[field],
                    ),
                )
                for field in SELECT_INFO_FIELDS[i:i + 10]
            }))

    # Convert field types
    ds = ds.annotate(info=ds.info.annotate(
        **{
            field: hl.if_else(ds.info[field] == "", hl.null(hl.tint),
                              hl.int(ds.info[field]))
            for field in CONVERT_TO_INT_FIELDS
        }))
    ds = ds.annotate(info=ds.info.annotate(
        **{
            field: hl.if_else(ds.info[field] == "", hl.null(hl.tfloat),
                              hl.float(ds.info[field]))
            for field in CONVERT_TO_FLOAT_FIELDS
        }))

    # Format VEP annotations to mimic the output of hail.vep
    ds = ds.annotate(info=ds.info.annotate(CSQ=ds.info.CSQ.map(
        lambda s: s.replace("%3A", ":").replace("%3B", ";").replace(
            "%3D", "=").replace("%25", "%").replace("%2C", ","))))
    ds = ds.annotate(vep=hl.struct(
        transcript_consequences=ds.info.CSQ.map(lambda csq_str: hl.bind(
            lambda csq_values: hl.struct(
                **{
                    field: hl.if_else(csq_values[index] == "",
                                      hl.null(hl.tstr), csq_values[index])
                    for index, field in enumerate(VEP_FIELDS)
                }),
            csq_str.split(r"\|"),
        )).filter(lambda annotation: annotation.Feature.startswith("ENST"))
          .filter(lambda annotation: hl.int(annotation.ALLELE_NUM) == ds.a_index)
          .map(lambda annotation: annotation.select(
            amino_acids=annotation.Amino_acids,
            biotype=annotation.BIOTYPE,
            canonical=annotation.CANONICAL == "YES",
            # cDNA_position may contain either "start-end" or, when start == end, "start"
            cdna_start=split_position_start(annotation.cDNA_position),
            cdna_end=split_position_end(annotation.cDNA_position),
            codons=annotation.Codons,
            consequence_terms=annotation.Consequence.split("&"),
            distance=hl.int(annotation.DISTANCE),
            domains=hl.or_missing(
                hl.is_defined(annotation.DOMAINS),
                annotation.DOMAINS.split("&").map(lambda d: hl.struct(
                    db=d.split(":")[0], name=d.split(":")[1])),
            ),
            exon=annotation.EXON,
            gene_id=annotation.Gene,
            gene_symbol=annotation.SYMBOL,
            gene_symbol_source=annotation.SYMBOL_SOURCE,
            hgnc_id=annotation.HGNC_ID,
            hgvsc=annotation.HGVSc,
            hgvsp=annotation.HGVSp,
            lof=annotation.LoF,
            lof_filter=annotation.LoF_filter,
            lof_flags=annotation.LoF_flags,
            lof_info=annotation.LoF_info,
            # PolyPhen field contains "polyphen_prediction(polyphen_score)"
            polyphen_prediction=hl.or_missing(
                hl.is_defined(annotation.PolyPhen),
                annotation.PolyPhen.split(r"\(")[0]),
            protein_id=annotation.ENSP,
            # Protein_position may contain either "start-end" or, when start == end, "start"
            protein_start=split_position_start(annotation.Protein_position),
            protein_end=split_position_end(annotation.Protein_position),
            # SIFT field contains "sift_prediction(sift_score)"
            sift_prediction=hl.or_missing(hl.is_defined(annotation.SIFT),
                                          annotation.SIFT.split(r"\(")[0]),
            transcript_id=annotation.Feature,
        ))))

    ds = ds.annotate(vep=ds.vep.annotate(most_severe_consequence=hl.bind(
        lambda all_consequence_terms: hl.or_missing(
            all_consequence_terms.size() != 0,
            hl.sorted(all_consequence_terms, key=consequence_term_rank)[0]),
        ds.vep.transcript_consequences.flatmap(lambda c: c.consequence_terms),
    )))

    ds = ds.cache()

    QUALITY_METRIC_HISTOGRAM_BIN_EDGES = [i * 5 for i in range(21)]

    ds = ds.select(
        variant_id=variant_id(ds.locus, ds.alleles),
        reference_genome="GRCh37",
        chrom=normalized_contig(ds.locus.contig),
        pos=ds.locus.position,
        xpos=x_position(ds.locus),
        ref=ds.alleles[0],
        alt=ds.alleles[1],
        rsids=hl.or_missing(hl.is_defined(ds.rsid), hl.set([ds.rsid])),
        exome=hl.struct(
            ac=ds.info.AC_Adj,
            an=ds.info.AN_Adj,
            homozygote_count=ds.info.AC_Hom,
            hemizygote_count=hl.or_else(ds.info.AC_Hemi, 0),
            filters=hl.set(
                hl.if_else(ds.info.AC_Adj == 0, ds.filters.add("AC0"),
                           ds.filters)),
            populations=[
                hl.struct(
                    id=pop_id.lower(),
                    ac=ds.info[f"AC_{pop_id}"],
                    an=ds.info[f"AN_{pop_id}"],
                    hemizygote_count=hl.or_else(ds.info[f"Hemi_{pop_id}"], 0),
                    homozygote_count=ds.info[f"Hom_{pop_id}"],
                ) for pop_id in
                ["AFR", "AMR", "EAS", "FIN", "NFE", "OTH", "SAS"]
            ],
            age_distribution=hl.struct(
                het=hl.rbind(
                    hl.or_else(ds.info.AGE_HISTOGRAM_HET,
                               "0|0|0|0|0|0|0|0|0|0|0|0").split(r"\|").map(
                                   hl.float),
                    lambda bins: hl.struct(
                        bin_edges=[30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80],
                        bin_freq=bins[1:11],
                        n_smaller=bins[0],
                        n_larger=bins[11],
                    ),
                ),
                hom=hl.rbind(
                    hl.or_else(ds.info.AGE_HISTOGRAM_HOM,
                               "0|0|0|0|0|0|0|0|0|0|0|0").split(r"\|").map(
                                   hl.float),
                    lambda bins: hl.struct(
                        bin_edges=[30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80],
                        bin_freq=bins[1:11],
                        n_smaller=bins[0],
                        n_larger=bins[11],
                    ),
                ),
            ),
            quality_metrics=hl.struct(
                genotype_depth=hl.struct(
                    all_raw=hl.struct(
                        bin_edges=QUALITY_METRIC_HISTOGRAM_BIN_EDGES,
                        bin_freq=ds.info.DP_HIST.all.split(r"\|").map(hl.int),
                    ),
                    alt_raw=hl.struct(
                        bin_edges=QUALITY_METRIC_HISTOGRAM_BIN_EDGES,
                        bin_freq=ds.info.DP_HIST.alt.split(r"\|").map(hl.int),
                    ),
                ),
                genotype_quality=hl.struct(
                    all_raw=hl.struct(
                        bin_edges=QUALITY_METRIC_HISTOGRAM_BIN_EDGES,
                        bin_freq=ds.info.GQ_HIST.all.split(r"\|").map(hl.int),
                    ),
                    alt_raw=hl.struct(
                        bin_edges=QUALITY_METRIC_HISTOGRAM_BIN_EDGES,
                        bin_freq=ds.info.GQ_HIST.alt.split(r"\|").map(hl.int),
                    ),
                ),
                site_quality_metrics=[
                    hl.struct(metric="BaseQRankSum",
                              value=hl.float(ds.info.BaseQRankSum)),
                    hl.struct(metric="ClippingRankSum",
                              value=hl.float(ds.info.ClippingRankSum)),
                    hl.struct(metric="DP", value=hl.float(ds.info.DP)),
                    hl.struct(metric="FS", value=hl.float(ds.info.FS)),
                    hl.struct(metric="InbreedingCoeff",
                              value=hl.float(ds.info.InbreedingCoeff)),
                    hl.struct(metric="MQ", value=hl.float(ds.info.MQ)),
                    hl.struct(metric="MQRankSum",
                              value=hl.float(ds.info.MQRankSum)),
                    hl.struct(metric="QD", value=hl.float(ds.info.QD)),
                    hl.struct(metric="ReadPosRankSum",
                              value=hl.float(ds.info.ReadPosRankSum)),
                    hl.struct(metric="SiteQuality", value=hl.float(ds.qual)),
                    hl.struct(metric="VQSLOD", value=hl.float(ds.info.VQSLOD)),
                ],
            ),
        ),
        colocated_variants=hl.rbind(
            variant_id(ds.locus, ds.alleles),
            lambda this_variant_id: variant_ids(ds.old_locus, ds.old_alleles)
            .filter(lambda v_id: v_id != this_variant_id),
        ),
        vep=ds.vep,
    )

    ds = ds.annotate(genome=hl.null(ds.exome.dtype))

    return ds
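`import_exac_vcf` leans on helpers defined elsewhere in the project (`variant_id`, `variant_ids`, `x_position`, `split_position_start`, `split_position_end`, ...). As one illustration, a plausible sketch of the position-splitting helpers, based only on the "start-end" / "start" convention noted in the comments above; the project's actual implementations may differ:
```
import hail as hl

def split_position(position_str):
    # VEP positions are "start-end" or, when start == end, just "start"
    return hl.or_missing(hl.is_defined(position_str), position_str.split("-"))

def split_position_start(position_str):
    return hl.rbind(split_position(position_str),
                    lambda parts: hl.or_missing(hl.is_defined(parts),
                                                hl.int(parts[0])))

def split_position_end(position_str):
    return hl.rbind(split_position(position_str),
                    lambda parts: hl.or_missing(hl.is_defined(parts),
                                                hl.int(parts[-1])))
```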
Example #15
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gencode",
        action="append",
        default=[],
        metavar=("version", "gtf_path", "canonical_transcripts_path"),
        nargs=3,
        required=True,
    )
    parser.add_argument("--hgnc")
    parser.add_argument("--mane-select-transcripts")
    parser.add_argument("--min-partitions", type=int, default=32)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()
    genes = None
    all_gencode_versions = [
        gencode_version for gencode_version, _, _ in args.gencode
    ]
    for gencode_version, gtf_path, canonical_transcripts_path in args.gencode:
        gencode_genes = load_gencode_gene_models(
            gtf_path, min_partitions=args.min_partitions)
        canonical_transcripts = hl.import_table(
            canonical_transcripts_path,
            key="gene_id",
            min_partitions=args.min_partitions)
        gencode_genes = gencode_genes.annotate(
            canonical_transcript_id=canonical_transcripts[
                gencode_genes.gene_id].transcript_id)
        gencode_genes = gencode_genes.select(
            **{f"v{gencode_version}": gencode_genes.row_value})
        if not genes:
            genes = gencode_genes
        else:
            genes = genes.join(gencode_genes, "outer")
    genes = genes.select(gencode=genes.row_value)
    hgnc = hl.import_table(args.hgnc, missing="")
    hgnc = hgnc.select(
        hgnc_id=hgnc["HGNC ID"],
        symbol=hgnc["Approved symbol"],
        name=hgnc["Approved name"],
        previous_symbols=hgnc["Previous symbols"],
        alias_symbols=hgnc["Alias symbols"],
        omim_id=hgnc["OMIM ID(supplied by OMIM)"],
        gene_id=hl.or_else(hgnc["Ensembl gene ID"],
                           hgnc["Ensembl ID(supplied by Ensembl)"]),
    )
    hgnc = hgnc.filter(hl.is_defined(hgnc.gene_id)).key_by("gene_id")
    hgnc = hgnc.annotate(
        previous_symbols=hl.cond(
            hgnc.previous_symbols == "",
            hl.empty_array(hl.tstr),
            hgnc.previous_symbols.split(",").map(lambda s: s.strip()),
        ),
        alias_symbols=hl.cond(
            hgnc.alias_symbols == "", hl.empty_array(hl.tstr),
            hgnc.alias_symbols.split(",").map(lambda s: s.strip())),
    )
    genes = genes.annotate(**hgnc[genes.gene_id])
    genes = genes.annotate(symbol_source=hl.cond(hl.is_defined(genes.symbol),
                                                 "hgnc", hl.null(hl.tstr)))
    for gencode_version in all_gencode_versions:
        genes = genes.annotate(
            symbol=hl.or_else(
                genes.symbol,
                genes.gencode[f"v{gencode_version}"].gene_symbol),
            symbol_source=hl.cond(
                hl.is_missing(genes.symbol) & hl.is_defined(
                    genes.gencode[f"v{gencode_version}"].gene_symbol),
                f"gencode (v{gencode_version})",
                genes.symbol_source,
            ),
        )  # Collect all fields that can be used to search by gene name
    genes = genes.annotate(
        symbol_upper_case=genes.symbol.upper(),
        search_terms=hl.empty_array(hl.tstr).append(genes.symbol).extend(
            genes.previous_symbols).extend(genes.alias_symbols),
    )
    for gencode_version in all_gencode_versions:
        genes = genes.annotate(search_terms=hl.rbind(
            genes.gencode[f"v{gencode_version}"].gene_symbol,
            lambda symbol_in_gencode: hl.cond(
                hl.is_defined(symbol_in_gencode),
                genes.search_terms.append(symbol_in_gencode),
                genes.search_terms),
        ))
    genes = genes.annotate(
        search_terms=hl.set(genes.search_terms.map(lambda s: s.upper())))
    if args.mane_select_transcripts:
        mane_select_transcripts = hl.import_table(args.mane_select_transcripts,
                                                  force=True)
        mane_select_transcripts = mane_select_transcripts.select(
            gene_id=mane_select_transcripts.Ensembl_Gene.split("\\.")[0],
            matched_gene_version=mane_select_transcripts.Ensembl_Gene.split(
                "\\.")[1],
            ensembl_id=mane_select_transcripts.Ensembl_nuc.split("\\.")[0],
            ensembl_version=mane_select_transcripts.Ensembl_nuc.split("\\.")[1],
            refseq_id=mane_select_transcripts.RefSeq_nuc.split("\\.")[0],
            refseq_version=mane_select_transcripts.RefSeq_nuc.split("\\.")[1],
        )
        mane_select_transcripts = mane_select_transcripts.key_by("gene_id")
        ensembl_to_refseq_map = {}
        for transcript in mane_select_transcripts.collect():
            ensembl_to_refseq_map[transcript.ensembl_id] = {
                transcript.ensembl_version:
                hl.Struct(refseq_id=transcript.refseq_id,
                          refseq_version=transcript.refseq_version)
            }

        ensembl_to_refseq_map = hl.literal(ensembl_to_refseq_map)
        for gencode_version in ["19", "29"]:
            if int(gencode_version) >= 20:
                transcript_annotation = lambda transcript: transcript.annotate(
                    **ensembl_to_refseq_map.get(
                        transcript.transcript_id,
                        hl.empty_dict(
                            hl.tstr,
                            hl.tstruct(refseq_id=hl.tstr,
                                       refseq_version=hl.tstr)),
                    ).get(
                        transcript.transcript_version,
                        hl.struct(refseq_id=hl.null(hl.tstr),
                                  refseq_version=hl.null(hl.tstr)),
                    ))
            else:
                transcript_annotation = lambda transcript: transcript.annotate(
                    refseq_id=hl.null(hl.tstr),
                    refseq_version=hl.null(hl.tstr))
            genes = genes.annotate(gencode=genes.gencode.annotate(
                **{
                    f"v{gencode_version}":
                    genes.gencode[f"v{gencode_version}"].annotate(
                        transcripts=genes.gencode[f"v{gencode_version}"]
                        .transcripts.map(transcript_annotation))
                }))
        genes = genes.annotate(
            mane_select_transcript=mane_select_transcripts[genes.gene_id])
    genes.describe()
    genes.write(args.output, overwrite=True)
Example #16
# Fragment: assumes that `load_results`, `sex`, `dilution`, and an initial
# `ht_results` table with `columns` and `entries` globals are defined earlier.
for i in range(1, 7):
    ht = load_results(sex, i, dilution)
    new_phenotypes = list(ht['phenotypes'].collect()[0])

    ht_results = ht_results.annotate_globals(
        columns=ht_results['columns'].extend(
            hl.array([hl.struct(phenotype=x) for x in new_phenotypes])))

    ht_results = ht_results.annotate(entries=ht_results['entries'].extend(
        hl.rbind(
            ht[ht_results.key], lambda new_results: hl.array([
                hl.struct(n=new_results['n'][j],
                          sum_x=new_results['sum_x'][j],
                          y_transpose_x=new_results['y_transpose_x'][j][0],
                          beta=new_results['beta'][j][0],
                          standard_error=new_results['standard_error'][j][0],
                          t_stat=new_results['t_stat'][j][0],
                          p_value=new_results['p_value'][j][0])
                for j in range(len(new_phenotypes))  # avoid shadowing outer i
            ]))))

mt = ht_results._unlocalize_entries('entries', 'columns', ['phenotype'])

codes = hl.literal({
    'albumin': '30600',
    'alkaline_phosphatase': '30610',
    'alanine_aminotransferase': '30620',
    'apoliprotein_A': '30630',
    'apoliprotein_B': '30640',
    'aspartate_aminotransferase': '30650',
    # ... (remaining biomarker codes truncated in the source)
})
Example #17
def prepare_variant_results():
    annotations = None
    results = None

    for group in ("dn", "dbs", "swe"):
        group_annotations_path = pipeline_config.get("ASC", f"{group}_variant_annotations_path")
        group_results_path = pipeline_config.get("ASC", f"{group}_variant_results_path")

        group_annotations = hl.import_table(
            group_annotations_path,
            force=True,
            key="v",
            missing="NA",
            types={
                "v": hl.tstr,
                "in_analysis": hl.tbool,
                "gene_id": hl.tstr,
                "gene_name": hl.tstr,
                "transcript_id": hl.tstr,
                "hgvsc": hl.tstr,
                "hgvsp": hl.tstr,
                "csq_analysis": hl.tstr,
                "csq_worst": hl.tstr,
                "mpc": hl.tfloat,
                "polyphen": hl.tstr,
            },
        )

        group_annotations = group_annotations.repartition(100, shuffle=True)

        if annotations is None:
            annotations = group_annotations
        else:
            annotations = annotations.union(group_annotations)

        group_results = hl.import_table(
            group_results_path,
            force=True,
            min_partitions=100,
            key="v",
            missing="NA",
            types={
                "v": hl.tstr,
                "analysis_group": hl.tstr,
                "ac_case": hl.tint,
                "an_case": hl.tint,
                "af_case": hl.tstr,
                "ac_ctrl": hl.tint,
                "an_ctrl": hl.tint,
                "af_ctrl": hl.tstr,
            },
        )

        group_results = group_results.repartition(100, shuffle=True)

        group_results = group_results.drop("af_case", "af_ctrl")

        group_results = group_results.annotate(in_analysis=group_annotations[group_results.v].in_analysis)

        if results is None:
            results = group_results
        else:
            results = results.union(group_results)

    annotations = annotations.cache()
    results = results.cache()

    annotations = annotations.distinct()
    annotations = annotations.cache()

    annotations = annotations.select(
        "gene_id",
        consequence=hl.sorted(
            annotations.csq_analysis.split(","),
            lambda c: CONSEQUENCE_TERM_RANKS.get(c),  # pylint: disable=unnecessary-lambda
        )[0],
        hgvsc=annotations.hgvsc.split(":")[-1],
        hgvsp=annotations.hgvsp.split(":")[-1],
        info=hl.struct(mpc=annotations.mpc, polyphen=annotations.polyphen),
    )

    results = results.group_by("v").aggregate(group_results=hl.agg.collect(results.row_value))
    results = results.annotate(
        group_results=hl.dict(
            results.group_results.map(
                lambda group_result: (group_result.analysis_group, group_result.drop("analysis_group"))
            )
        )
    )

    variants = annotations.annotate(group_results=results[annotations.key].group_results)

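    # `v` is a "chrom:pos:ref:alt" string; recover locus and alleles from it.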
    variants = variants.annotate(
        locus=hl.rbind(variants.v.split(":"), lambda p: hl.locus(p[0], hl.int(p[1]), reference_genome="GRCh37")),
        alleles=hl.rbind(variants.v.split(":"), lambda p: [p[2], p[3]]),
    )

    variants = variants.key_by("locus", "alleles")

    return variants
Example #18
def transform_one(mt, vardp_outlier=100_000) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    There is a strong assumption that this function will be called on a matrix
    table with one column.
    """
    mt = localize(mt)
    if mt.row.dtype not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e:
                        hl.struct(
                            DP=e.DP,
                            END=row.info.END,
                            GQ=e.GQ,
                            LA=hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                            LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                            LGT=e.GT,
                            LPGT=e.PGT,
                            LPL=hl.cond(has_non_ref,
                                        hl.cond(alleles_len > 2,
                                                e.PL[:-alleles_len],
                                                hl.null(e.PL.dtype)),
                                        hl.cond(alleles_len > 1,
                                                e.PL,
                                                hl.null(e.PL.dtype))),
                            MIN_DP=e.MIN_DP,
                            PID=e.PID,
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                            SB=e.SB,
                            gvcf_info=hl.case()
                                .when(hl.is_missing(row.info.END),
                                      hl.struct(
                                          ClippingRankSum=row.info.ClippingRankSum,
                                          BaseQRankSum=row.info.BaseQRankSum,
                                          MQ=row.info.MQ,
                                          MQRankSum=row.info.MQRankSum,
                                          MQ_DP=row.info.MQ_DP,
                                          QUALapprox=row.info.QUALapprox,
                                          RAW_MQ=row.info.RAW_MQ,
                                          ReadPosRankSum=row.info.ReadPosRankSum,
                                          VarDP=hl.cond(row.info.VarDP > vardp_outlier,
                                                        row.info.DP, row.info.VarDP)))
                                .or_missing()
                        ))),
            ),
            mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(TableMapRows(mt._tir, Apply(transform_row._name, TopLevelReference('row'))))
Example #19
def merge_alleles(alleles) -> ArrayExpression:
    # alleles is tarray(tarray(tstruct(ref=tstr, alt=tstr)))
    return hl.rbind(hl.array(hl.set(hl.flatten(alleles))),
                    lambda arr:
                    hl.filter(lambda a: a.alt != '<NON_REF>', arr)
                      .extend(hl.filter(lambda a: a.alt == '<NON_REF>', arr)))
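A small worked example of the merge; note that the relative order of the real alleles comes from `hl.set` and is therefore unspecified:
```
import hail as hl

alleles = hl.literal([
    [hl.Struct(ref='A', alt='T'), hl.Struct(ref='A', alt='<NON_REF>')],
    [hl.Struct(ref='A', alt='G'), hl.Struct(ref='A', alt='<NON_REF>')],
])
# deduplicates across the inner lists and pushes '<NON_REF>' to the end,
# e.g. [Struct(ref='A', alt='T'), Struct(ref='A', alt='G'),
#       Struct(ref='A', alt='<NON_REF>')]
print(hl.eval(merge_alleles(alleles)))
```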
Example #20
def transform_gvcf(mt, info_to_keep=[]) -> Table:
    """Transforms a gvcf into a sparse matrix table

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_gvcfs` with ``array_elements_required=False``.

    There is an assumption that this function will be called on a matrix table
    with one column (or a localized table version of the same).

    Parameters
    ----------
    mt : :obj:`Union[Table, MatrixTable]`
        The gvcf being transformed, if it is a table, then it must be a localized matrix table with
        the entries array named ``__entries``
    info_to_keep : :obj:`List[str]`
        Any ``INFO`` fields in the gvcf that are to be kept and put in the ``gvcf_info`` entry
        field. By default, all ``INFO`` fields except ``END`` and ``DP`` are kept.

    Returns
    -------
    :obj:`.Table`
        A localized matrix table that can be used as part of the input to :func:`.combine_gvcfs`

    Notes
    -----
    This function will parse the following allele specific annotations from
    pipe delimited strings into proper values. ::

        AS_QUALapprox
        AS_RAW_MQ
        AS_RAW_MQRankSum
        AS_RAW_ReadPosRankSum
        AS_SB_TABLE
        AS_VarDP

    """
    if not info_to_keep:
        info_to_keep = [name for name in mt.info if name not in ['END', 'DP']]
    mt = localize(mt)

    if mt.row.dtype not in _transform_rows_function_map:

        def get_lgt(e, n_alleles, has_non_ref, row):
            index = e.GT.unphased_diploid_gt_index()
            n_no_nonref = n_alleles - hl.int(has_non_ref)
            triangle_without_nonref = hl.triangle(n_no_nonref)
            return (hl.case()
                    .when(index < triangle_without_nonref, e.GT)
                    .when(index < hl.triangle(n_alleles), hl.missing('call'))
                    .or_error('invalid GT ' + hl.str(e.GT)
                              + ' at site ' + hl.str(row.locus)))

        def make_entry_struct(e, alleles_len, has_non_ref, row):
            handled_fields = dict()
            handled_names = {
                'LA', 'gvcf_info', 'END', 'LAD', 'AD', 'LGT', 'GT', 'LPL',
                'PL', 'LPGT', 'PGT'
            }

            if 'END' not in row.info:
                raise hl.utils.FatalError(
                    "the Hail GVCF combiner expects GVCFs to have an 'END' field in INFO."
                )
            if 'GT' not in e:
                raise hl.utils.FatalError(
                    "the Hail GVCF combiner expects GVCFs to have a 'GT' field in FORMAT."
                )

            handled_fields['LA'] = hl.range(
                0, alleles_len - hl.if_else(has_non_ref, 1, 0))
            handled_fields['LGT'] = get_lgt(e, alleles_len, has_non_ref, row)
            if 'AD' in e:
                handled_fields['LAD'] = hl.if_else(has_non_ref, e.AD[:-1],
                                                   e.AD)
            if 'PGT' in e:
                handled_fields['LPGT'] = e.PGT
            if 'PL' in e:
                handled_fields['LPL'] = hl.if_else(
                    has_non_ref,
                    hl.if_else(alleles_len > 2, e.PL[:-alleles_len],
                               hl.missing(e.PL.dtype)),
                    hl.if_else(alleles_len > 1, e.PL, hl.missing(e.PL.dtype)))
                handled_fields['RGQ'] = hl.if_else(
                    has_non_ref,
                    e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                    hl.missing(e.PL.dtype.element_type))

            handled_fields['END'] = row.info.END
            handled_fields['gvcf_info'] = (
                hl.case()
                .when(hl.is_missing(row.info.END),
                      hl.struct(**parse_as_fields(
                          row.info.select(*info_to_keep), has_non_ref)))
                .or_missing())

            pass_through_fields = {
                k: v
                for k, v in e.items() if k not in handled_names
            }
            return hl.struct(**handled_fields, **pass_through_fields)

        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.if_else(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(lambda e: make_entry_struct(
                        e, alleles_len, has_non_ref, row)))),
            mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(
        TableMapRows(
            mt._tir,
            Apply(transform_row._name, transform_row._ret_type,
                  TopLevelReference('row'))))
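A usage sketch with a hypothetical path; `array_elements_required=False` matches the requirement stated in the docstring:
```
import hail as hl

mt = hl.import_vcf('gs://bucket/sample1.g.vcf.bgz',
                   force_bgz=True,
                   array_elements_required=False)
ht = transform_gvcf(mt)  # localized table, usable as input to combine_gvcfs
```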
Example #21
def ld_score_regression(weight_expr,
                        ld_score_expr,
                        chi_sq_exprs,
                        n_samples_exprs,
                        n_blocks=200,
                        two_step_threshold=30,
                        n_reference_panel_variants=None) -> Table:
    r"""Estimate SNP-heritability and level of confounding biases from genome-wide association study
    (GWAS) summary statistics.

    Given one or more sets of GWAS summary statistics, :func:`.ld_score_regression`
    estimates the SNP-heritability of each trait and the level of confounding
    bias present in the underlying studies by regressing chi-squared statistics
    on LD scores, using the model:

    .. math::

        \mathrm{E}[\chi_j^2] = 1 + Na + \frac{Nh_g^2}{M}l_j

    *  :math:`\mathrm{E}[\chi_j^2]` is the expected chi-squared statistic
       for variant :math:`j` resulting from a test of association between
       variant :math:`j` and a trait.
    *  :math:`l_j = \sum_{k} r_{jk}^2` is the LD score of variant
       :math:`j`, calculated as the sum of squared correlation coefficients
       between variant :math:`j` and nearby variants. See :func:`ld_score`
       for further details.
    *  :math:`a` captures the contribution of confounding biases, such as
       cryptic relatedness and uncontrolled population structure, to the
       association test statistic.
    *  :math:`h_g^2` is the SNP-heritability, or the proportion of variation
       in the trait explained by the effects of variants included in the
       regression model above.
    *  :math:`M` is the number of variants used to estimate :math:`h_g^2`.
    *  :math:`N` is the number of samples in the underlying association study.

    For more details on the method implemented in this function, see:

    * `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__

    Examples
    --------

    Run the method on a matrix table of summary statistics, where the rows
    are variants and the columns are different phenotypes:

    >>> mt_gwas = ld_score_all_phenos_sumstats
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=mt_gwas['ld_score'],
    ...     ld_score_expr=mt_gwas['ld_score'],
    ...     chi_sq_exprs=mt_gwas['chi_squared'],
    ...     n_samples_exprs=mt_gwas['n'])


    Run the method on a table with summary statistics for a single
    phenotype:

    >>> ht_gwas = ld_score_one_pheno_sumstats
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=ht_gwas['chi_squared_50_irnt'],
    ...     n_samples_exprs=ht_gwas['n_50_irnt'])

    Run the method on a table with summary statistics for multiple
    phenotypes:

    >>> ht_gwas = ld_score_one_pheno_sumstats
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=[ht_gwas['chi_squared_50_irnt'],
    ...                        ht_gwas['chi_squared_20160']],
    ...     n_samples_exprs=[ht_gwas['n_50_irnt'],
    ...                      ht_gwas['n_20160']])

    Notes
    -----
    The ``exprs`` provided as arguments to :func:`.ld_score_regression`
    must all be from the same object, either a :class:`Table` or a
    :class:`MatrixTable`.

    **If the arguments originate from a table:**

    *  The table must be keyed by fields ``locus`` of type
       :class:`.tlocus` and ``alleles``, a :py:data:`.tarray` of
       :py:data:`.tstr` elements.
    *  ``weight_expr``, ``ld_score_expr``, ``chi_sq_exprs``, and
       ``n_samples_exprs`` must be row-indexed fields.
    *  The number of expressions passed to ``n_samples_exprs`` must be
       equal to one or the number of expressions passed to
       ``chi_sq_exprs``. If just one expression is passed to
       ``n_samples_exprs``, that sample size expression is assumed to
       apply to all sets of statistics passed to ``chi_sq_exprs``.
       Otherwise, the expressions passed to ``chi_sq_exprs`` and
       ``n_samples_exprs`` are matched by index.
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have generic :obj:`int` values
       ``0``, ``1``, etc. corresponding to the ``0th``, ``1st``, etc.
       expressions passed to the ``chi_sq_exprs`` argument.

    **If the arguments originate from a matrix table:**

    *  The dimensions of the matrix table must be variants
       (rows) by phenotypes (columns).
    *  The rows of the matrix table must be keyed by fields
       ``locus`` of type :class:`.tlocus` and ``alleles``,
       a :py:data:`.tarray` of :py:data:`.tstr` elements.
    *  The columns of the matrix table must be keyed by a field
       of type :py:data:`.tstr` that uniquely identifies phenotypes
       represented in the matrix table. The column key must be a single
       expression; compound keys are not accepted.
    *  ``weight_expr`` and ``ld_score_expr`` must be row-indexed
       fields.
    *  ``chi_sq_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  ``n_samples_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have values corresponding to the
       column keys of the input matrix table.

    This function returns a :class:`Table` with one row per set of summary
    statistics passed to the ``chi_sq_exprs`` argument. The following
    row-indexed fields are included in the table:

    *  **phenotype** (:py:data:`.tstr`) -- The name of the phenotype. The
       returned table is keyed by this field. See the notes below for
       details on the possible values of this field.
    *  **mean_chi_sq** (:py:data:`.tfloat64`) -- The mean chi-squared
       test statistic for the given phenotype.
    *  **intercept** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          intercept :math:`1 + Na`.
       -  **standard_error**  (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    *  **snp_heritability** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          SNP-heritability :math:`h_g^2`.
       -  **standard_error** (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    Warning
    -------
    :func:`.ld_score_regression` considers only the rows for which both row
    fields ``weight_expr`` and ``ld_score_expr`` are defined. Rows with missing
    values in either field are removed prior to fitting the LD score
    regression model.

    Parameters
    ----------
    weight_expr : :class:`.Float64Expression`
                  Row-indexed expression for the LD scores used to derive
                  variant weights in the model.
    ld_score_expr : :class:`.Float64Expression`
                    Row-indexed expression for the LD scores used as covariates
                    in the model.
    chi_sq_exprs : :class:`.Float64Expression` or :obj:`list` of
                        :class:`.Float64Expression`
                        One or more row-indexed (if table) or entry-indexed
                        (if matrix table) expressions for chi-squared
                        statistics resulting from genome-wide association
                        studies (GWAS).
    n_samples_exprs: :class:`.NumericExpression` or :obj:`list` of
                     :class:`.NumericExpression`
                     One or more row-indexed (if table) or entry-indexed
                     (if matrix table) expressions indicating the number of
                     samples used in the studies that generated the test
                     statistics supplied to ``chi_sq_exprs``.
    n_blocks : :obj:`int`
               The number of blocks used in the jackknife approach to
               estimating standard errors.
    two_step_threshold : :obj:`int`
                         Variants with chi-squared statistics greater than this
                         value are excluded in the first step of the two-step
                         procedure used to fit the model.
    n_reference_panel_variants : :obj:`int`, optional
                                 Number of variants used to estimate the
                                 SNP-heritability :math:`h_g^2`.

    Returns
    -------
    :class:`.Table`
        Table keyed by ``phenotype`` with intercept and heritability estimates
        for each phenotype passed to the function."""

    chi_sq_exprs = wrap_to_list(chi_sq_exprs)
    n_samples_exprs = wrap_to_list(n_samples_exprs)

    assert ((len(chi_sq_exprs) == len(n_samples_exprs))
            or (len(n_samples_exprs) == 1))
    __k = 2  # number of covariates, including intercept

    ds = chi_sq_exprs[0]._indices.source

    analyze('ld_score_regression/weight_expr',
            weight_expr,
            ds._row_indices)
    analyze('ld_score_regression/ld_score_expr',
            ld_score_expr,
            ds._row_indices)

    # format input dataset
    if isinstance(ds, MatrixTable):
        if len(chi_sq_exprs) != 1:
            raise ValueError("""Only one chi_sq_expr allowed if originating
                from a matrix table.""")
        if len(n_samples_exprs) != 1:
            raise ValueError("""Only one n_samples_expr allowed if
                originating from a matrix table.""")

        col_key = list(ds.col_key)
        if len(col_key) != 1:
            raise ValueError("""Matrix table must be keyed by a single
                phenotype field.""")

        analyze('ld_score_regression/chi_squared_expr',
                chi_sq_exprs[0],
                ds._entry_indices)
        analyze('ld_score_regression/n_samples_expr',
                n_samples_exprs[0],
                ds._entry_indices)

        ds = ds._select_all(row_exprs={'__locus': ds.locus,
                                       '__alleles': ds.alleles,
                                       '__w_initial': weight_expr,
                                       '__w_initial_floor': hl.max(weight_expr,
                                                                   1.0),
                                       '__x': ld_score_expr,
                                       '__x_floor': hl.max(ld_score_expr,
                                                           1.0)},
                            row_key=['__locus', '__alleles'],
                            col_exprs={'__y_name': ds[col_key[0]]},
                            col_key=['__y_name'],
                            entry_exprs={'__y': chi_sq_exprs[0],
                                         '__n': n_samples_exprs[0]})
        ds = ds.annotate_entries(**{'__w': ds.__w_initial})

        ds = ds.filter_rows(hl.is_defined(ds.__locus)
                            & hl.is_defined(ds.__alleles)
                            & hl.is_defined(ds.__w_initial)
                            & hl.is_defined(ds.__x))

    else:
        assert isinstance(ds, Table)
        for y in chi_sq_exprs:
            analyze('ld_score_regression/chi_squared_expr', y, ds._row_indices)
        for n in n_samples_exprs:
            analyze('ld_score_regression/n_samples_expr', n, ds._row_indices)

        ys = ['__y{}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ws = ['__w{}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ns = ['__n{}'.format(i) for i, _ in enumerate(n_samples_exprs)]

        ds = ds.select(**dict(**{'__locus': ds.locus,
                                 '__alleles': ds.alleles,
                                 '__w_initial': weight_expr,
                                 '__x': ld_score_expr},
                              **{y: chi_sq_exprs[i]
                                 for i, y in enumerate(ys)},
                              **{w: weight_expr for w in ws},
                              **{n: n_samples_exprs[i]
                                 for i, n in enumerate(ns)}))
        ds = ds.key_by(ds.__locus, ds.__alleles)

        table_tmp_file = new_temp_file()
        ds.write(table_tmp_file)
        ds = hl.read_table(table_tmp_file)

        hts = [ds.select(**{'__w_initial': ds.__w_initial,
                            '__w_initial_floor': hl.max(ds.__w_initial,
                                                        1.0),
                            '__x': ds.__x,
                            '__x_floor': hl.max(ds.__x, 1.0),
                            '__y_name': i,
                            '__y': ds[ys[i]],
                            '__w': ds[ws[i]],
                            '__n': hl.int(ds[ns[i]])})
               for i, y in enumerate(ys)]

        mts = [ht.to_matrix_table(row_key=['__locus',
                                           '__alleles'],
                                  col_key=['__y_name'],
                                  row_fields=['__w_initial',
                                              '__w_initial_floor',
                                              '__x',
                                              '__x_floor'])
               for ht in hts]

        ds = mts[0]
        for i in range(1, len(ys)):
            ds = ds.union_cols(mts[i])

        ds = ds.filter_rows(hl.is_defined(ds.__locus)
                            & hl.is_defined(ds.__alleles)
                            & hl.is_defined(ds.__w_initial)
                            & hl.is_defined(ds.__x))

    mt_tmp_file1 = new_temp_file()
    ds.write(mt_tmp_file1)
    mt = hl.read_matrix_table(mt_tmp_file1)

    if not n_reference_panel_variants:
        M = mt.count_rows()
    else:
        M = n_reference_panel_variants

    mt = mt.annotate_entries(__in_step1=(hl.is_defined(mt.__y)
                                         & (mt.__y < two_step_threshold)),
                             __in_step2=hl.is_defined(mt.__y))

    mt = mt.annotate_cols(__col_idx=hl.int(hl.scan.count()),
                          __m_step1=hl.agg.count_where(mt.__in_step1),
                          __m_step2=hl.agg.count_where(mt.__in_step2))

    col_keys = list(mt.col_key)

    ht = mt.localize_entries(entries_array_field_name='__entries',
                             columns_array_field_name='__cols')

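    # Assign each (variant, phenotype) entry a jackknife block: rank each
    # step-1 variant within its phenotype, cut the ranks at intervals of
    # m_step1 / n_blocks ("separators"), and shift entries that land exactly
    # on a separator but are excluded from step 1 into the previous block.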
    ht = ht.annotate(__entries=hl.rbind(
        hl.scan.array_agg(
            lambda entry: hl.scan.count_where(entry.__in_step1),
            ht.__entries),
        lambda step1_indices: hl.map(
            lambda i: hl.rbind(
                hl.int(hl.or_else(step1_indices[i], 0)),
                ht.__cols[i].__m_step1,
                ht.__entries[i],
                lambda step1_idx, m_step1, entry: hl.rbind(
                    hl.map(
                        lambda j: hl.int(hl.floor(j * (m_step1 / n_blocks))),
                        hl.range(0, n_blocks + 1)),
                    lambda step1_separators: hl.rbind(
                        hl.set(step1_separators).contains(step1_idx),
                        hl.sum(
                            hl.map(
                                lambda s1: step1_idx >= s1,
                                step1_separators)) - 1,
                        lambda is_separator, step1_block: entry.annotate(
                            __step1_block=step1_block,
                            __step2_block=hl.cond(~entry.__in_step1 & is_separator,
                                                  step1_block - 1,
                                                  step1_block))))),
            hl.range(0, hl.len(ht.__entries)))))

    mt = ht._unlocalize_entries('__entries', '__cols', col_keys)

    mt_tmp_file2 = new_temp_file()
    mt.write(mt_tmp_file2)
    mt = hl.read_matrix_table(mt_tmp_file2)

    # initial coefficient estimates
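    # intercept starts at 1.0; the slope guess solves E[chi^2] = 1 + beta * l
    # at the means, i.e. beta = (mean(chi^2) - 1) / mean(ld score)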
    mt = mt.annotate_cols(__initial_betas=[
        1.0, (hl.agg.mean(mt.__y) - 1.0) / hl.agg.mean(mt.__x)])
    mt = mt.annotate_cols(__step1_betas=mt.__initial_betas,
                          __step2_betas=mt.__initial_betas)

    # step 1 iteratively reweighted least squares
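        # weights approximate 1 / Var(chi^2): under the model Var(chi^2_j) is
        # proportional to 2 * E[chi^2_j]^2, and the floored LD score and
        # initial weight guard against the reciprocal blowing up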
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step1,
            1.0 / (mt.__w_initial_floor * 2.0 * (mt.__step1_betas[0]
                                                 + mt.__step1_betas[1]
                                                 * mt.__x_floor) ** 2),
            0.0))
        mt = mt.annotate_cols(__step1_betas=hl.agg.filter(
            mt.__in_step1,
            hl.agg.linreg(y=mt.__y,
                          x=[1.0, mt.__x],
                          weight=mt.__w).beta))
        mt = mt.annotate_cols(__step1_h2=hl.max(hl.min(
            mt.__step1_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step1_betas=[
            mt.__step1_betas[0],
            mt.__step1_h2 * hl.agg.mean(mt.__n) / M])

    # step 1 block jackknife
    mt = mt.annotate_cols(__step1_block_betas=hl.agg.array_agg(
        lambda i: hl.agg.filter((mt.__step1_block != i) & mt.__in_step1,
                                hl.agg.linreg(y=mt.__y,
                                              x=[1.0, mt.__x],
                                              weight=mt.__w).beta),
        hl.range(n_blocks)))

    mt = mt.annotate_cols(__step1_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step1_betas - (n_blocks - 1) * x,
        mt.__step1_block_betas))

    mt = mt.annotate_cols(
        __step1_jackknife_mean=hl.map(
            lambda i: hl.mean(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected)),
            hl.range(0, __k)),
        __step1_jackknife_variance=hl.map(
            lambda i: (hl.sum(
                hl.map(lambda x: x[i]**2,
                       mt.__step1_block_betas_bias_corrected))
                       - hl.sum(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected)) ** 2
                       / n_blocks)
            / (n_blocks - 1) / n_blocks,
            hl.range(0, __k)))

    # step 2 iteratively reweighted least squares
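    # (the intercept stays fixed at its step 1 value; only the slope is
    # re-estimated, on all variants, against y minus the intercept)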
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step2,
            1.0 / (mt.__w_initial_floor
                   * 2.0 * (mt.__step2_betas[0]
                            + mt.__step2_betas[1]
                            * mt.__x_floor) ** 2),
            0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            hl.agg.filter(mt.__in_step2,
                          hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                        x=[mt.__x],
                                        weight=mt.__w).beta[0])])
        mt = mt.annotate_cols(__step2_h2=hl.max(hl.min(
            mt.__step2_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            mt.__step2_h2 * hl.agg.mean(mt.__n) / M])

    # step 2 block jackknife
    mt = mt.annotate_cols(__step2_block_betas=hl.agg.array_agg(
        lambda i: hl.agg.filter((mt.__step2_block != i) & mt.__in_step2,
                                hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                              x=[mt.__x],
                                              weight=mt.__w).beta[0]),
        hl.range(n_blocks)))

    mt = mt.annotate_cols(__step2_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step2_betas[1] - (n_blocks - 1) * x,
        mt.__step2_block_betas))

    mt = mt.annotate_cols(
        __step2_jackknife_mean=hl.mean(
            mt.__step2_block_betas_bias_corrected),
        __step2_jackknife_variance=(
            hl.sum(mt.__step2_block_betas_bias_corrected ** 2)
            - hl.sum(mt.__step2_block_betas_bias_corrected) ** 2
            / n_blocks) / (n_blocks - 1) / n_blocks)

    # combine step 1 and step 2 block jackknifes
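    # __c = sum(w*x) / sum(w*x^2) propagates intercept uncertainty from the
    # step 1 blocks into the step 2 slope blocks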
    mt = mt.annotate_entries(
        __step2_initial_w=1.0 / (mt.__w_initial_floor
                                 * 2.0 * (mt.__initial_betas[0]
                                          + mt.__initial_betas[1]
                                          * mt.__x_floor) ** 2))

    mt = mt.annotate_cols(
        __final_betas=[
            mt.__step1_betas[0],
            mt.__step2_betas[1]],
        __c=(hl.agg.sum(mt.__step2_initial_w * mt.__x)
             / hl.agg.sum(mt.__step2_initial_w * mt.__x**2)))

    mt = mt.annotate_cols(__final_block_betas=hl.map(
        lambda i: (mt.__step2_block_betas[i] - mt.__c
                   * (mt.__step1_block_betas[i][0] - mt.__final_betas[0])),
        hl.range(0, n_blocks)))

    mt = mt.annotate_cols(
        __final_block_betas_bias_corrected=(n_blocks * mt.__final_betas[1]
                                            - (n_blocks - 1)
                                            * mt.__final_block_betas))

    mt = mt.annotate_cols(
        __final_jackknife_mean=[
            mt.__step1_jackknife_mean[0],
            hl.mean(mt.__final_block_betas_bias_corrected)],
        __final_jackknife_variance=[
            mt.__step1_jackknife_variance[0],
            (hl.sum(mt.__final_block_betas_bias_corrected ** 2)
             - hl.sum(mt.__final_block_betas_bias_corrected) ** 2
             / n_blocks) / (n_blocks - 1) / n_blocks])

    # convert coefficient to heritability estimate
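    # h2 = (M / mean(N)) * slope; the jackknife variance is scaled by the
    # square of the same factor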
    mt = mt.annotate_cols(
        phenotype=mt.__y_name,
        mean_chi_sq=hl.agg.mean(mt.__y),
        intercept=hl.struct(
            estimate=mt.__final_betas[0],
            standard_error=hl.sqrt(mt.__final_jackknife_variance[0])),
        snp_heritability=hl.struct(
            estimate=(M / hl.agg.mean(mt.__n)) * mt.__final_betas[1],
            standard_error=hl.sqrt((M / hl.agg.mean(mt.__n)) ** 2
                                   * mt.__final_jackknife_variance[1])))

    # format and return results
    ht = mt.cols()
    ht = ht.key_by(ht.phenotype)
    ht = ht.select(ht.mean_chi_sq,
                   ht.intercept,
                   ht.snp_heritability)

    ht_tmp_file = new_temp_file()
    ht.write(ht_tmp_file)
    ht = hl.read_table(ht_tmp_file)

    return ht
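
A minimal usage sketch for the function above (Hail's experimental LD score
regression). The argument names follow the experimental API; the table and
field names here are hypothetical:

    result = hl.experimental.ld_score_regression(
        weight_expr=ht_scores['ld_score'],        # hypothetical LD score field
        ld_score_expr=ht_scores['ld_score'],
        chi_sq_exprs=ht_sumstats['chi_squared'],  # per-variant GWAS statistics
        n_samples_exprs=ht_sumstats['n'])         # per-variant sample sizes
    result.show()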
Example #22
File: qc.py Project: jigold/hail
def variant_qc(mt, name='variant_qc') -> MatrixTable:
    """Compute common variant statistics (quality control metrics).

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    >>> dataset_result = hl.variant_qc(dataset)

    Notes
    -----
    This method computes variant statistics from the genotype data, returning
    a new struct field `name` with the following metrics based on the fields
    present in the entry schema.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `AF` (``array<float64>``) -- Calculated allele frequency, one element
      per allele, including the reference. Sums to one. Equivalent to
      `AC` / `AN`.
    - `AC` (``array<int32>``) -- Calculated allele count, one element per
      allele, including the reference. Sums to `AN`.
    - `AN` (``int32``) -- Total number of called alleles.
    - `homozygote_count` (``array<int32>``) -- Number of homozygotes per
      allele. One element per allele, including the reference.
    - `call_rate` (``float64``) -- Fraction of calls neither missing nor filtered.
      Equivalent to `n_called` / :meth:`.count_cols`.
    - `n_called` (``int64``) -- Number of samples with a defined `GT`.
    - `n_not_called` (``int64``) -- Number of samples with a missing `GT`.
    - `n_filtered` (``int64``) -- Number of filtered entries.
    - `n_het` (``int64``) -- Number of heterozygous samples.
    - `n_non_ref` (``int64``) -- Number of samples with at least one called
      non-reference allele.
    - `het_freq_hwe` (``float64``) -- Expected frequency of heterozygous
      samples under Hardy-Weinberg equilibrium. See
      :func:`.functions.hardy_weinberg_test` for details.
    - `p_value_hwe` (``float64``) -- p-value from test of Hardy-Weinberg equilibrium.
      See :func:`.functions.hardy_weinberg_test` for details.

    Warning
    -------
    `het_freq_hwe` and `p_value_hwe` are calculated as in
    :func:`.functions.hardy_weinberg_test`, with non-diploid calls
    (``ploidy != 2``) ignored in the counts. As this test is only
    statistically rigorous in the biallelic setting, :func:`.variant_qc`
    sets both fields to missing for multiallelic variants. Consider using
    :func:`~hail.methods.split_multi` to split multi-allelic variants beforehand.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
    """
    require_row_key_variant(mt, 'variant_qc')

    bound_exprs = {}
    gq_dp_exprs = {}

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    if has_field_of_type('DP', hl.tint32):
        gq_dp_exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        gq_dp_exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'variant_qc': expect an entry field 'GT' of type 'call'")

    bound_exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    bound_exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT']))
    bound_exprs['n_filtered'] = mt.count_cols(_localize=False) - hl.agg.count()
    bound_exprs['call_stats'] = hl.agg.call_stats(mt.GT, mt.alleles)

    result = hl.rbind(hl.struct(**bound_exprs),
                      lambda e1: hl.rbind(
                          hl.case().when(hl.len(mt.alleles) == 2,
                                         hl.hardy_weinberg_test(e1.call_stats.homozygote_count[0],
                                                                e1.call_stats.AC[1] - 2 *
                                                                e1.call_stats.homozygote_count[1],
                                                                e1.call_stats.homozygote_count[1])
                                         ).or_missing(),
                          lambda hwe: hl.struct(**{
                              **gq_dp_exprs,
                              **e1.call_stats,
                              'call_rate': hl.float(e1.n_called) / (e1.n_called + e1.n_not_called + e1.n_filtered),
                              'n_called': e1.n_called,
                              'n_not_called': e1.n_not_called,
                              'n_filtered': e1.n_filtered,
                              'n_het': e1.n_called - hl.sum(e1.call_stats.homozygote_count),
                              'n_non_ref': e1.n_called - e1.call_stats.homozygote_count[0],
                              'het_freq_hwe': hwe.het_freq_hwe,
                              'p_value_hwe': hwe.p_value})))

    return mt.annotate_rows(**{name: result})
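
The annotated metrics can be used directly for downstream filtering. A minimal
sketch (the thresholds are illustrative, not recommendations):

    mt = hl.variant_qc(mt)
    mt = mt.filter_rows((mt.variant_qc.call_rate >= 0.95)
                        & (mt.variant_qc.AF[1] >= 0.01))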
Example #23
def sample_qc(vds: 'VariantDataset',
              *,
              gq_bins: 'Sequence[int]' = (0, 20, 60),
              dp_bins: 'Sequence[int]' = (0, 1, 10, 20, 30),
              dp_field=None) -> 'Table':
    """Compute sample quality metrics about a :class:`.VariantDataset`.

    If the `dp_field` parameter is not specified, the ``DP`` field is used
    for depth if present. If no ``DP`` field is present, the ``MIN_DP``
    field is used. If neither ``DP`` nor ``MIN_DP`` is present, no depth
    statistics will be calculated.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    gq_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for genotype quality (GQ) scores.
    dp_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for depth (DP) scores.
    dp_field : :obj:`str`
        Name of depth field. If not supplied, DP or MIN_DP will be used, in that order.

    Returns
    -------
    :class:`.Table`
        Hail Table of results, keyed by sample.
    """

    require_first_key_field_locus(vds.reference_data, 'sample_qc')
    require_first_key_field_locus(vds.variant_data, 'sample_qc')

    ref = vds.reference_data

    # an explicit dp_field overrides autodetection, as documented above
    if dp_field is not None:
        ref_dp_field_to_use = dp_field
    elif 'DP' in ref.entry:
        ref_dp_field_to_use = 'DP'
    elif 'MIN_DP' in ref.entry:
        ref_dp_field_to_use = 'MIN_DP'
    else:
        ref_dp_field_to_use = None

    from hail.expr.functions import _num_allele_type, _allele_types

    allele_types = _allele_types[:]
    allele_types.extend(['Transition', 'Transversion'])
    allele_enum = {i: v for i, v in enumerate(allele_types)}
    allele_ints = {v: k for k, v in allele_enum.items()}
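    # the enum above adds Transition/Transversion to the base allele types so
    # SNPs can be subclassified by allele_type below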

    def allele_type(ref, alt):
        return hl.bind(
            lambda at: hl.if_else(
                at == allele_ints['SNP'],
                hl.if_else(hl.is_transition(ref, alt),
                           allele_ints['Transition'],
                           allele_ints['Transversion']),
                at),
            _num_allele_type(ref, alt))

    variant_ac = Env.get_uid()
    variant_atypes = Env.get_uid()

    vmt = vds.variant_data
    if 'GT' not in vmt.entry:
        vmt = vmt.annotate_entries(
            GT=hl.experimental.lgt_to_gt(vmt.LGT, vmt.LA))

    vmt = vmt.annotate_rows(
        **{
            variant_ac:
            hl.agg.call_stats(vmt.GT, vmt.alleles).AC,
            variant_atypes:
            vmt.alleles[1:].map(lambda alt: allele_type(vmt.alleles[0], alt))
        })

    bound_exprs = {}

    bound_exprs['n_het'] = hl.agg.count_where(vmt['GT'].is_het())
    bound_exprs['n_hom_var'] = hl.agg.count_where(vmt['GT'].is_hom_var())
    bound_exprs['n_singleton'] = hl.agg.sum(
        hl.rbind(
            vmt['GT'], lambda gt: hl.sum(
                hl.range(0, gt.ploidy).map(lambda i: hl.rbind(
                    gt[i], lambda gti:
                    (gti != 0) & (vmt[variant_ac][gti] == 1))))))
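    # (a singleton is a non-reference allele whose callset-wide allele count,
    # vmt[variant_ac][gti], is exactly 1)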

    bound_exprs['allele_type_counts'] = hl.agg.explode(
        lambda allele_type: hl.tuple(
            hl.agg.count_where(allele_type == i)
            for i in range(len(allele_ints))),
        (hl.range(0, vmt['GT'].ploidy).map(lambda i: vmt['GT'][i]).filter(
            lambda allele_idx: allele_idx > 0).map(
                lambda allele_idx: vmt[variant_atypes][allele_idx - 1])))

    dp_exprs = {}
    if ref_dp_field_to_use is not None and 'DP' in vmt.entry:
        dp_exprs['dp'] = hl.tuple(
            hl.agg.count_where(vmt.DP >= x) for x in dp_bins)

    gq_dp_exprs = hl.struct(
        **{'gq': hl.tuple(hl.agg.count_where(vmt.GQ >= x) for x in gq_bins)},
        **dp_exprs)

    result_struct = hl.rbind(
        hl.struct(**bound_exprs),
        lambda x: hl.rbind(
            hl.struct(**{
                'gq_dp_exprs': gq_dp_exprs,
                'n_het': x.n_het,
                'n_hom_var': x.n_hom_var,
                'n_non_ref': x.n_het + x.n_hom_var,
                'n_singleton': x.n_singleton,
                'n_snp': (x.allele_type_counts[allele_ints['Transition']]
                          + x.allele_type_counts[allele_ints['Transversion']]),
                'n_insertion': x.allele_type_counts[allele_ints['Insertion']],
                'n_deletion': x.allele_type_counts[allele_ints['Deletion']],
                'n_transition': x.allele_type_counts[allele_ints['Transition']],
                'n_transversion': x.allele_type_counts[allele_ints['Transversion']],
                'n_star': x.allele_type_counts[allele_ints['Star']]
            }),
            lambda s: s.annotate(
                r_ti_tv=divide_null(hl.float64(s.n_transition), s.n_transversion),
                r_het_hom_var=divide_null(hl.float64(s.n_het), s.n_hom_var),
                r_insertion_deletion=divide_null(hl.float64(s.n_insertion), s.n_deletion))))
    variant_results = vmt.select_cols(**result_struct).cols()

    rmt = vds.reference_data

    ref_dp_expr = {}
    if ref_dp_field_to_use is not None:
        ref_dp_expr['ref_bases_over_dp_threshold'] = hl.tuple(
            hl.agg.filter(rmt[ref_dp_field_to_use] >= x,
                          hl.agg.sum(1 + rmt.END - rmt.locus.position))
            for x in dp_bins)
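    # a reference block spanning [position, END] covers 1 + END - position
    # bases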
    ref_results = rmt.select_cols(
        ref_bases_over_gq_threshold=hl.tuple(
            hl.agg.filter(rmt.GQ >= x,
                          hl.agg.sum(1 + rmt.END - rmt.locus.position))
            for x in gq_bins),
        **ref_dp_expr).cols()

    joined = ref_results[variant_results.key]

    joined_dp_expr = {}
    dp_bins_field = {}
    if ref_dp_field_to_use is not None:
        joined_dp_expr['bases_over_dp_threshold'] = hl.tuple(
            x + y for x, y in zip(variant_results.gq_dp_exprs.dp,
                                  joined.ref_bases_over_dp_threshold))
        dp_bins_field['dp_bins'] = hl.tuple(dp_bins)

    joined_results = variant_results.transmute(
        bases_over_gq_threshold=hl.tuple(
            x + y for x, y in zip(variant_results.gq_dp_exprs.gq,
                                  joined.ref_bases_over_gq_threshold)),
        **joined_dp_expr)

    joined_results = joined_results.annotate_globals(gq_bins=hl.tuple(gq_bins),
                                                     **dp_bins_field)
    return joined_results
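
A minimal usage sketch, assuming this function is exposed as hl.vds.sample_qc;
the dataset path is hypothetical:

    vds = hl.vds.read_vds('gs://my-bucket/dataset.vds')  # hypothetical path
    qc = hl.vds.sample_qc(vds, gq_bins=(0, 20, 60), dp_bins=(0, 1, 10, 20, 30))
    flagged = qc.filter(qc.r_ti_tv < 1.5)  # illustrative Ti/Tv threshold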