Example #1
def sor_from_sb(
    sb: Union[hl.expr.ArrayNumericExpression, hl.expr.ArrayExpression]
) -> hl.expr.Float64Expression:
    """
    Computes `SOR` (Symmetric Odds Ratio test) annotation from the `SB` (strand balance table) field.

    .. note::

        This function can either take
        - an array of length four containing the forward and reverse strands' counts of ref and alt alleles: [ref fwd, ref rev, alt fwd, alt rev]
        - a two dimensional array with arrays of length two, containing the counts: [[ref fwd, ref rev], [alt fwd, alt rev]]

    GATK code here: https://github.com/broadinstitute/gatk/blob/master/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/StrandOddsRatio.java

    :param sb: Count of ref/alt reads on each strand
    :return: SOR value
    """

    if not isinstance(sb, hl.expr.ArrayNumericExpression):
        sb = hl.bind(lambda x: hl.flatten(x), sb)

    sb = sb.map(lambda x: hl.float64(x) + 1)

    ref_fw = sb[0]
    ref_rv = sb[1]
    alt_fw = sb[2]
    alt_rv = sb[3]
    symmetrical_ratio = ((ref_fw * alt_rv) / (alt_fw * ref_rv)) + (
        (alt_fw * ref_rv) / (ref_fw * alt_rv)
    )
    ref_ratio = hl.min(ref_rv, ref_fw) / hl.max(ref_rv, ref_fw)
    alt_ratio = hl.min(alt_fw, alt_rv) / hl.max(alt_fw, alt_rv)
    sor = hl.log(symmetrical_ratio) + hl.log(ref_ratio) - hl.log(alt_ratio)

    return sor
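A minimal usage sketch (added for illustration, not from the original source; the literal counts are hypothetical): both accepted shapes of `SB` evaluate to the same SOR value.

# Hypothetical usage of sor_from_sb defined above.
import hail as hl

sb_flat = hl.literal([10, 12, 3, 1])        # [ref fwd, ref rev, alt fwd, alt rev]
sb_nested = hl.literal([[10, 12], [3, 1]])  # [[ref fwd, ref rev], [alt fwd, alt rev]]

assert hl.eval(sor_from_sb(sb_flat)) == hl.eval(sor_from_sb(sb_nested))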
Example #2
    def compute_same_hap_log_like(n, p, q, x):
        res = (
            hl.cond(
                q > 0,
                hl.fold(
                    lambda i, j: i + j[0] * j[1], 0.0,
                    hl.zip(gt_counts, [
                        hl.log10(x) * 2,
                        hl.log10(2 * x * e),
                        hl.log10(e) * 2,
                        hl.log10(2 * x * p),
                        hl.log10(2 * (p * e + x * q)),
                        hl.log10(2 * q * e),
                        hl.log10(p) * 2,
                        hl.log10(2 * p * q),
                        hl.log10(q) * 2
                    ])),
                -1e31  # Very large negative value if no q is present
            ))

        # If desired, add distance posterior based on value derived from regression
        if distance is not None:
            res = res + hl.max(-6,
                               hl.log10(0.97 - 0.03 * hl.log(distance + 1)))

        return res
Example #3
def annotate_mt(mt):
    # Annotate POPMAX_AF, which is max of respective fields using a_index for multi-allelics.
    return mt.annotate_rows(
        POPMAX_AF=hl.max(
            mt.info.AFR_AF[mt.a_index - 1],
            mt.info.AMR_AF[mt.a_index - 1],
            mt.info.EAS_AF[mt.a_index - 1],
            mt.info.EUR_AF[mt.a_index - 1],
            mt.info.SAS_AF[mt.a_index - 1],
        )
    )
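A context sketch (assumptions: the input VCF carries per-alt-allele Number=A INFO arrays such as AFR_AF, and the path is a placeholder): `a_index` comes from `hl.split_multi_hts` and is 1-based over alt alleles, hence the `- 1` when indexing.

import hail as hl

mt = hl.import_vcf('data/cohort_with_pop_afs.vcf.bgz')  # placeholder path
mt = hl.split_multi_hts(mt)  # adds mt.a_index: 1 for the first alt allele, 2 for the second, ...
mt = annotate_mt(mt)         # POPMAX_AF = max per-population AF for this row's alt allele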
Example #4
def vcf_to_mt(splice_ai_snvs_path, splice_ai_indels_path, genome_version):
    """
    Loads the snv path and indels source path to a matrix table and returns the table.

    :param splice_ai_snvs_path: source location
    :param splice_ai_indels_path: source location
    :return: matrix table
    """

    logger.info("==> reading in splice_ai vcfs: %s, %s" %
                (splice_ai_snvs_path, splice_ai_indels_path))

    # For GRCh37 the interval runs through MT; for GRCh38, MT is not included.
    interval = "1-MT" if genome_version == "37" else "chr1-chrY"
    contig_dict = None
    if genome_version == "38":
        contig_dict = NO_CHR_TO_CHR_CONTIG_RECODING

    mt = hl.import_vcf(
        [splice_ai_snvs_path, splice_ai_indels_path],
        reference_genome=f"GRCh{genome_version}",
        contig_recoding=contig_dict,
        force_bgz=True,
        min_partitions=10000,
        skip_invalid_loci=True,
    )
    interval = [
        hl.parse_locus_interval(interval,
                                reference_genome=f"GRCh{genome_version}")
    ]
    mt = hl.filter_intervals(mt, interval)

    # Split SpliceAI field by | delimiter. Capture delta score entries and map to floats
    delta_scores = mt.info.SpliceAI[0].split(delim="\\|")[2:6]
    splice_split = mt.info.annotate(
        SpliceAI=hl.map(lambda x: hl.float32(x), delta_scores))
    mt = mt.annotate_rows(info=splice_split)

    # Annotate info.max_DS with the max of DS_AG, DS_AL, DS_DG, DS_DL in info.
    # delta_score array is |DS_AG|DS_AL|DS_DG|DS_DL
    consequences = hl.literal(
        ["Acceptor gain", "Acceptor loss", "Donor gain", "Donor loss"])
    mt = mt.annotate_rows(info=mt.info.annotate(
        max_DS=hl.max(mt.info.SpliceAI)))
    mt = mt.annotate_rows(info=mt.info.annotate(splice_consequence=hl.if_else(
        mt.info.max_DS > 0,
        consequences[mt.info.SpliceAI.index(mt.info.max_DS)],
        "No consequence",
    )))
    return mt
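A hedged invocation sketch (paths are placeholders; assumes `logger` and `NO_CHR_TO_CHR_CONTIG_RECODING` are defined in the surrounding module, as the function requires):

mt = vcf_to_mt(
    'gs://bucket/spliceai.snvs.vcf.gz',    # placeholder path
    'gs://bucket/spliceai.indels.vcf.gz',  # placeholder path
    '38',
)
# After the transformations above:
#   info.SpliceAI           -> array<float32> of [DS_AG, DS_AL, DS_DG, DS_DL]
#   info.max_DS             -> the maximum delta score
#   info.splice_consequence -> e.g. 'Acceptor gain', or 'No consequence' when max_DS is 0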
Example #5
def format_regional_missense_constraint(ds):
    ds = ds.annotate(obs_mis=hl.int(ds.obs_mis))

    ds = ds.annotate(start=hl.min(ds.genomic_start, ds.genomic_end), stop=hl.max(ds.genomic_start, ds.genomic_end))

    ds = ds.drop("amino_acids", "chr", "gene", "genomic_start", "genomic_end", "region_name")

    ds = ds.transmute(transcript_id=ds.transcript.split("\\.")[0])

    ds = ds.group_by("transcript_id").aggregate(regions=hl.agg.collect(ds.row_value))

    ds = ds.annotate(regions=hl.sorted(ds.regions, lambda region: region.start))

    return ds
Example #6
def merge_overlapping_regions(regions):
    return hl.cond(
        hl.len(regions) > 1,
        hl.rbind(
            hl.sorted(regions, lambda region: region.start),
            lambda sorted_regions: sorted_regions[1:].fold(
                lambda acc, region: hl.cond(
                    region.start <= acc[-1].stop + 1,
                    acc[:-1].append(acc[-1].annotate(stop=hl.max(
                        region.stop, acc[-1].stop))),
                    acc.append(region),
                ),
                [sorted_regions[0]],
            ),
        ),
        regions,
    )
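A usage sketch (hypothetical values; assumes regions are structs with integer `start`/`stop` fields, as the fold above requires):

import hail as hl

regions = hl.literal([
    hl.Struct(start=10, stop=20),
    hl.Struct(start=15, stop=30),  # overlaps the first region, so the two merge
    hl.Struct(start=40, stop=50),  # disjoint, kept as-is
])
print(hl.eval(merge_overlapping_regions(regions)))
# expected: [Struct(start=10, stop=30), Struct(start=40, stop=50)]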
Example #7
def vcf_to_mt(splice_ai_snvs_path, splice_ai_indels_path, genome_version):
    '''
    Loads the SpliceAI SNV and indel VCFs into a matrix table and returns it.

    :param splice_ai_snvs_path: path to the SpliceAI SNV VCF
    :param splice_ai_indels_path: path to the SpliceAI indel VCF
    :param genome_version: '37' or '38'
    :return: matrix table
    '''

    logger.info('==> reading in splice_ai vcfs: %s, %s' %
                (splice_ai_snvs_path, splice_ai_indels_path))

    # For GRCh37 the interval runs through MT; for GRCh38, MT is not included.
    interval = '1-MT' if genome_version == '37' else 'chr1-chrY'
    contig_dict = None
    if genome_version == '38':
        rg = hl.get_reference('GRCh37')
        grch37_contigs = [
            x for x in rg.contigs
            if not x.startswith('GL') and not x.startswith('M')
        ]
        contig_dict = dict(
            zip(grch37_contigs, ['chr' + x for x in grch37_contigs]))

    mt = hl.import_vcf([splice_ai_snvs_path, splice_ai_indels_path],
                       reference_genome=f"GRCh{genome_version}",
                       contig_recoding=contig_dict,
                       force_bgz=True,
                       min_partitions=10000,
                       skip_invalid_loci=True)
    interval = [
        hl.parse_locus_interval(interval,
                                reference_genome=f"GRCh{genome_version}")
    ]
    mt = hl.filter_intervals(mt, interval)

    # Annotate info.max_DS with the max of DS_AG, DS_AL, DS_DG, DS_DL in info.
    info = mt.info.annotate(max_DS=hl.max(
        [mt.info.DS_AG, mt.info.DS_AL, mt.info.DS_DG, mt.info.DS_DL]))
    mt = mt.annotate_rows(info=info)

    return mt
Example #8
    def add_stats(
        i: hl.expr.StructExpression, j: hl.expr.StructExpression
    ) -> hl.expr.StructExpression:
        """
        This merges two stats counters together. It assumes that all stats counter fields are present in the struct.

        :param i: accumulator: struct with mean, n and variance
        :param j: new element: stats_struct -- needs to contain mean, n and variance
        :return: Accumulation over all elements: struct with mean, n and variance
        """
        delta = j.mean - i.mean
        n_tot = i.n + j.n
        return hl.struct(
            min=hl.min(i.min, j.min),
            max=hl.max(i.max, j.max),
            mean=(i.mean * i.n + j.mean * j.n) / n_tot,
            variance=i.variance + j.variance + (delta * delta * i.n * j.n) / n_tot,
            n=n_tot,
            sum=i.sum + j.sum,
        )
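An illustration of the merge on hypothetical inputs (assumes add_stats is in scope; note the n-weighted mean):

import hail as hl

s1 = hl.struct(min=1.0, max=4.0, mean=2.5, n=10, variance=1.2, sum=25.0)
s2 = hl.struct(min=0.5, max=6.0, mean=3.0, n=20, variance=0.8, sum=60.0)

merged = add_stats(s1, s2)
print(hl.eval(merged.n), hl.eval(merged.mean))  # 30 2.8333...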
Example #9
    def compute_chet_log_like(n, p, q, x):
        res = (hl.cond((p > 0) & (q > 0),
                       hl.fold(
                           lambda i, j: i + j[0] * j[1], 0.0,
                           hl.zip(gt_counts, [
                               hl.log10(x) * 2,
                               hl.log10(2 * x * q),
                               hl.log10(q) * 2,
                               hl.log10(2 * x * p),
                               hl.log10(2 * (p * q + x * e)),
                               hl.log10(2 * q * e),
                               hl.log10(p) * 2,
                               hl.log10(2 * p * e),
                               hl.log10(e) * 2
                           ])), -1e31))  # Very large negative value if p or q is absent
        # If desired, add distance posterior based on value derived from regression
        if distance is not None:
            res = res + hl.max(-6,
                               hl.log10(0.03 + 0.03 * hl.log(distance + 1)))

        return res
Example #10
def prepare_exac_regional_missense_constraint(path):
    ds = hl.import_table(
        path,
        missing="",
        types={
            "transcript": hl.tstr,
            "gene": hl.tstr,
            "chr": hl.tstr,
            "amino_acids": hl.tstr,
            "genomic_start": hl.tint,
            "genomic_end": hl.tint,
            "obs_mis": hl.tfloat,
            "exp_mis": hl.tfloat,
            "obs_exp": hl.tfloat,
            "chisq_diff_null": hl.tfloat,
            "region_name": hl.tstr,
        },
    )

    ds = ds.annotate(obs_mis=hl.int(ds.obs_mis))

    ds = ds.annotate(start=hl.min(ds.genomic_start, ds.genomic_end),
                     stop=hl.max(ds.genomic_start, ds.genomic_end))

    ds = ds.drop("amino_acids", "chr", "gene", "genomic_start", "genomic_end",
                 "region_name")

    ds = ds.transmute(transcript_id=ds.transcript.split("\\.")[0])

    ds = ds.group_by("transcript_id").aggregate(
        regions=hl.agg.collect(ds.row_value))

    ds = ds.annotate(
        regions=hl.sorted(ds.regions, lambda region: region.start))

    ds = ds.select(exac_regional_missense_constraint_regions=ds.regions)

    return ds
Example #11
def de_novo(mt: MatrixTable,
            pedigree: Pedigree,
            pop_frequency_prior,
            *,
            min_gq: int = 20,
            min_p: float = 0.05,
            max_parent_ab: float = 0.05,
            min_child_ab: float = 0.20,
            min_dp_ratio: float = 0.10) -> Table:
    r"""Call putative *de novo* events from trio data.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------

    Call de novo events:

    >>> pedigree = hl.Pedigree.read('data/trios.fam')
    >>> priors = hl.import_table('data/gnomadFreq.tsv', impute=True)
    >>> priors = priors.transmute(**hl.parse_variant(priors.Variant)).key_by('locus', 'alleles')
    >>> de_novo_results = hl.de_novo(dataset, pedigree, pop_frequency_prior=priors[dataset.row_key].AF)

    Notes
    -----
    This method assumes the GATK high-throughput sequencing fields exist:
    `GT`, `AD`, `DP`, `GQ`, `PL`.

    This method replicates the functionality of `Kaitlin Samocha's de novo
    caller <https://github.com/ksamocha/de_novo_scripts>`__. The version
    corresponding to git commit ``bde3e40`` is implemented in Hail with her
    permission and assistance.

    This method produces a :class:`.Table` with the following fields:

     - `locus` (``locus``) -- Variant locus.
     - `alleles` (``array<str>``) -- Variant alleles.
     - `id` (``str``) -- Proband sample ID.
     - `prior` (``float64``) -- Site frequency prior. It is the maximum of:
       the computed dataset alternate allele frequency, the
       `pop_frequency_prior` parameter, and the global prior
       ``1 / 3e7``.
     - `proband` (``struct``) -- Proband column fields from `mt`.
     - `father` (``struct``) -- Father column fields from `mt`.
     - `mother` (``struct``) -- Mother column fields from `mt`.
     - `proband_entry` (``struct``) -- Proband entry fields from `mt`.
     - `father_entry` (``struct``) -- Father entry fields from `mt`.
     - `mother_entry` (``struct``) -- Mother entry fields from `mt`.
     - `is_female` (``bool``) -- ``True`` if proband is female.
     - `p_de_novo` (``float64``) -- Unfiltered posterior probability
       that the event is *de novo* rather than a missed heterozygous
       event in a parent.
     - `confidence` (``str``) -- Validation confidence. One of: ``'HIGH'``,
       ``'MEDIUM'``, ``'LOW'``.

    The key of the table is ``['locus', 'alleles', 'id']``.

    The model looks for de novo events in which both parents are homozygous
    reference and the proband is heterozygous. The model makes the simplifying
    assumption that when this configuration ``x = (AA, AA, AB)`` of calls
    occurs, exactly one of the following is true:

     - ``d``: a de novo mutation occurred in the proband and all calls are
       accurate.
     - ``m``: at least one parental allele is actually heterozygous and
       the proband call is accurate.

    We can then estimate the posterior probability of a de novo mutation as:

    .. math::

        \mathrm{P_{\text{de novo}}} = \frac{\mathrm{P}(d\,|\,x)}{\mathrm{P}(d\,|\,x) + \mathrm{P}(m\,|\,x)}

    Applying Bayes rule to the numerator and denominator yields

    .. math::

        \frac{\mathrm{P}(x\,|\,d)\,\mathrm{P}(d)}{\mathrm{P}(x\,|\,d)\,\mathrm{P}(d) +
        \mathrm{P}(x\,|\,m)\,\mathrm{P}(m)}

    The prior on de novo mutation is estimated from the rate in the literature:

    .. math::

        \mathrm{P}(d) = \frac{1\,\text{mutation}}{30{,}000{,}000\,\text{bases}}

    The prior used for at least one alternate allele between the parents
    depends on the alternate allele frequency:

    .. math::

        \mathrm{P}(m) = 1 - (1 - AF)^4

    The likelihoods :math:`\mathrm{P}(x\,|\,d)` and :math:`\mathrm{P}(x\,|\,m)`
    are computed from the PL (genotype likelihood) fields using these
    factorizations:

    .. math::

        \mathrm{P}(x = (AA, AA, AB) \,|\,d) = \Big(
        &\mathrm{P}(x_{\mathrm{father}} = AA \,|\, \mathrm{father} = AA) \\
        \cdot &\mathrm{P}(x_{\mathrm{mother}} = AA \,|\, \mathrm{mother} =
        AA) \\ \cdot &\mathrm{P}(x_{\mathrm{proband}} = AB \,|\,
        \mathrm{proband} = AB) \Big)

    .. math::

        \mathrm{P}(x = (AA, AA, AB) \,|\,m) = \Big( &
        \mathrm{P}(x_{\mathrm{father}} = AA \,|\, \mathrm{father} = AB)
        \cdot \mathrm{P}(x_{\mathrm{mother}} = AA \,|\, \mathrm{mother} =
        AA) \\ + \, &\mathrm{P}(x_{\mathrm{father}} = AA \,|\,
        \mathrm{father} = AA) \cdot \mathrm{P}(x_{\mathrm{mother}} = AA
        \,|\, \mathrm{mother} = AB) \Big) \\ \cdot \,
        &\mathrm{P}(x_{\mathrm{proband}} = AB \,|\, \mathrm{proband} = AB)

    (Technically, the second factorization assumes there is exactly (rather
    than at least) one alternate allele among the parents, which may be
    justified on the grounds that it is typically the most likely case by far.)

    While this posterior probability is a good metric for grouping putative de
    novo mutations by validation likelihood, there exist error modes in
    high-throughput sequencing data that are not appropriately accounted for by
    the phred-scaled genotype likelihoods. To this end, a number of hard filters
    are applied in order to assign validation likelihood.

    These filters are different for SNPs and insertions/deletions. In the below
    rules, the following variables are used:

     - ``DR`` refers to the ratio of the read depth in the proband to the
       combined read depth in the parents.
     - ``AB`` refers to the read allele balance of the proband (number of
       alternate reads divided by total reads).
     - ``AC`` refers to the count of alternate alleles across all individuals
       in the dataset at the site.
     - ``p`` refers to :math:`\mathrm{P_{\text{de novo}}}`.
     - ``min_p`` refers to the ``min_p`` function parameter.

    HIGH-quality SNV:

    .. code-block:: text

        p > 0.99 && AB > 0.3 && DR > 0.2
            or
        p > 0.99 && AB > 0.3 && AC == 1

    MEDIUM-quality SNV:

    .. code-block:: text

        p > 0.5 && AB > 0.3
            or
        p > 0.5 && AB > 0.2 && AC == 1

    LOW-quality SNV:

    .. code-block:: text

        p > min_p && AB > 0.2

    HIGH-quality indel:

    .. code-block:: text

        p > 0.99 && AB > 0.3 && DR > 0.2
            or
        p > 0.99 && AB > 0.3 && AC == 1

    MEDIUM-quality indel:

    .. code-block:: text

        p > 0.5 && AB > 0.3
            or
        p > 0.5 && AB > 0.2 && AC == 1

    LOW-quality indel:

    .. code-block:: text

        p > min_p && AB > 0.2

    Additionally, de novo candidates are not considered if the proband GQ is
    smaller than the ``min_gq`` parameter, if the proband allele balance is
    lower than the ``min_child_ab`` parameter, if the depth ratio between the
    proband and parents is smaller than the ``min_dp_ratio`` parameter, or if
    the allele balance in a parent is above the ``max_parent_ab`` parameter.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        High-throughput sequencing dataset.
    pedigree : :class:`.Pedigree`
        Sample pedigree.
    pop_frequency_prior : :class:`.Float64Expression`
        Expression for population alternate allele frequency prior.
    min_gq
        Minimum proband GQ to be considered for *de novo* calling.
    min_p
        Minimum posterior probability to be considered for *de novo* calling.
    max_parent_ab
        Maximum parent allele balance.
    min_child_ab
        Minimum proband allele balance.
    min_dp_ratio
        Minimum ratio between proband read depth and parental read depth.

    Returns
    -------
    :class:`.Table`
    """
    DE_NOVO_PRIOR = 1 / 30000000
    MIN_POP_PRIOR = 100 / 30000000

    required_entry_fields = {'GT', 'AD', 'DP', 'GQ', 'PL'}
    missing_fields = required_entry_fields - set(mt.entry)
    if missing_fields:
        raise ValueError(
            f"'de_novo': expected 'MatrixTable' to have at least {required_entry_fields}, "
            f"missing {missing_fields}")

    mt = mt.annotate_rows(__prior=pop_frequency_prior,
                          __alt_alleles=hl.agg.sum(mt.GT.n_alt_alleles()),
                          __total_alleles=2 * hl.agg.sum(hl.is_defined(mt.GT)))
    # subtract 1 from __alt_alleles to correct for the observed genotype
    mt = mt.annotate_rows(
        __site_freq=hl.max((mt.__alt_alleles - 1) /
                           mt.__total_alleles, mt.__prior, MIN_POP_PRIOR))
    mt = require_biallelic(mt, 'de_novo')

    # FIXME check that __site_freq is between 0 and 1 when possible in expr
    tm = trio_matrix(mt, pedigree, complete_trios=True)

    autosomal = tm.locus.in_autosome_or_par() | (tm.locus.in_x_nonpar()
                                                 & tm.is_female)
    hemi_x = tm.locus.in_x_nonpar() & ~tm.is_female
    hemi_y = tm.locus.in_y_nonpar() & ~tm.is_female
    hemi_mt = tm.locus.in_mito() & tm.is_female

    is_snp = hl.is_snp(tm.alleles[0], tm.alleles[1])
    n_alt_alleles = tm.__alt_alleles
    prior = tm.__site_freq
    het_hom_hom = (tm.proband_entry.GT.is_het()
                   & tm.father_entry.GT.is_hom_ref()
                   & tm.mother_entry.GT.is_hom_ref())
    kid_ad_fail = (tm.proband_entry.AD[1] / hl.sum(tm.proband_entry.AD)
                   < min_child_ab)

    failure = hl.null(hl.tstruct(p_de_novo=hl.tfloat64, confidence=hl.tstr))

    kid = tm.proband_entry
    dad = tm.father_entry
    mom = tm.mother_entry

    kid_linear_pl = 10**(-kid.PL / 10)
    kid_pp = hl.bind(lambda x: x / hl.sum(x), kid_linear_pl)

    dad_linear_pl = 10**(-dad.PL / 10)
    dad_pp = hl.bind(lambda x: x / hl.sum(x), dad_linear_pl)

    mom_linear_pl = 10**(-mom.PL / 10)
    mom_pp = hl.bind(lambda x: x / hl.sum(x), mom_linear_pl)

    kid_ad_ratio = kid.AD[1] / hl.sum(kid.AD)
    dp_ratio = kid.DP / (dad.DP + mom.DP)

    def call_auto(kid_pp, dad_pp, mom_pp, kid_ad_ratio):
        p_data_given_dn = dad_pp[0] * mom_pp[0] * kid_pp[1] * DE_NOVO_PRIOR
        p_het_in_parent = 1 - (1 - prior)**4
        p_data_given_missed_het = (dad_pp[1] * mom_pp[0] + dad_pp[0] *
                                   mom_pp[1]) * kid_pp[1] * p_het_in_parent
        p_de_novo = p_data_given_dn / (p_data_given_dn +
                                       p_data_given_missed_het)

        def solve(p_de_novo):
            return (
                hl.case()
                .when(kid.GQ < min_gq, failure)
                .when((kid.DP / (dad.DP + mom.DP) < min_dp_ratio)
                      | ~(kid_ad_ratio >= min_child_ab), failure)
                .when((hl.sum(mom.AD) == 0) | (hl.sum(dad.AD) == 0), failure)
                .when((mom.AD[1] / hl.sum(mom.AD) > max_parent_ab)
                      | (dad.AD[1] / hl.sum(dad.AD) > max_parent_ab), failure)
                .when(p_de_novo < min_p, failure)
                .when(~is_snp,
                      hl.case()
                      .when((p_de_novo > 0.99) & (kid_ad_ratio > 0.3)
                            & (n_alt_alleles == 1),
                            hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                      .when((p_de_novo > 0.5) & (kid_ad_ratio > 0.3)
                            & (n_alt_alleles <= 5),
                            hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                      .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.2),
                            hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                      .or_missing())
                .default(hl.case()
                         .when(((p_de_novo > 0.99) & (kid_ad_ratio > 0.3)
                                & (dp_ratio > 0.2))
                               | ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3)
                                  & (n_alt_alleles == 1))
                               | ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3)
                                  & (n_alt_alleles < 10) & (kid.DP > 10)),
                               hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                         .when((p_de_novo > 0.5)
                               & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)),
                               hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                         .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.2),
                               hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                         .or_missing()))

        return hl.bind(solve, p_de_novo)

    def call_hemi(kid_pp, parent, parent_pp, kid_ad_ratio):
        p_data_given_dn = parent_pp[0] * kid_pp[1] * DE_NOVO_PRIOR
        p_het_in_parent = 1 - (1 - prior)**4
        p_data_given_missed_het = (parent_pp[1] +
                                   parent_pp[2]) * kid_pp[2] * p_het_in_parent
        p_de_novo = p_data_given_dn / (p_data_given_dn +
                                       p_data_given_missed_het)

        def solve(p_de_novo):
            return (
                hl.case()
                .when(kid.GQ < min_gq, failure)
                .when((kid.DP / parent.DP < min_dp_ratio)
                      | (kid_ad_ratio < min_child_ab), failure)
                .when(hl.sum(parent.AD) == 0, failure)
                .when(parent.AD[1] / hl.sum(parent.AD) > max_parent_ab, failure)
                .when(p_de_novo < min_p, failure)
                .when(~is_snp,
                      hl.case()
                      .when((p_de_novo > 0.99) & (kid_ad_ratio > 0.3)
                            & (n_alt_alleles == 1),
                            hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                      .when((p_de_novo > 0.5) & (kid_ad_ratio > 0.3)
                            & (n_alt_alleles <= 5),
                            hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                      .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.3),
                            hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                      .or_missing())
                .default(hl.case()
                         .when(((p_de_novo > 0.99) & (kid_ad_ratio > 0.3)
                                & (dp_ratio > 0.2))
                               | ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3)
                                  & (n_alt_alleles == 1))
                               | ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3)
                                  & (n_alt_alleles < 10) & (kid.DP > 10)),
                               hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                         .when((p_de_novo > 0.5)
                               & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)),
                               hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                         .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.2),
                               hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                         .or_missing()))

        return hl.bind(solve, p_de_novo)

    de_novo_call = (
        hl.case()
        .when(~het_hom_hom | kid_ad_fail, failure)
        .when(autosomal,
              hl.bind(call_auto, kid_pp, dad_pp, mom_pp, kid_ad_ratio))
        .when(hemi_x | hemi_mt,
              hl.bind(call_hemi, kid_pp, mom, mom_pp, kid_ad_ratio))
        .when(hemi_y,
              hl.bind(call_hemi, kid_pp, dad, dad_pp, kid_ad_ratio))
        .or_missing())

    tm = tm.annotate_entries(__call=de_novo_call)
    tm = tm.filter_entries(hl.is_defined(tm.__call))
    entries = tm.entries()
    return (entries.select('__site_freq', 'proband', 'father', 'mother',
                           'proband_entry', 'father_entry', 'mother_entry',
                           'is_female',
                           **entries.__call).rename({'__site_freq': 'prior'}))
Example #12
import os
import sys
import hail as hl
from typing import *

arrs = [
    'info_Axiom_KP_UCSF_EUR',
    'info_Affymetrix_6.0',
    'info_Human550',
    'info_Human610660',
    'info_HumanOmni',
]

mt = hl.read_matrix_table(
    'gs://unicorn-qc/post-imputation-qc/merged/mt/pre_qc.mt')
mt = mt.annotate_entries(MGP=hl.max(mt.GP))

# Filter to allele frequencies > 0.005
for arr in arrs:
    mt = mt.filter_rows(mt[arr].MAF > 0.005, keep=True)

print('There are {n_cols} samples '\
      'and {n_rows} variants '\
      'in post-imputation data'.format(
            n_cols = mt.count_cols(),
            n_rows = mt.count_rows()))
# There are 54125 samples and 8604906 variants in post-imputation data
mt.write('gs://unicorn-qc/post-imputation-qc/merged/mt/qc-af_0.005.mt',
         overwrite=True)
Example #13
def fs_from_sb(
    sb: Union[hl.expr.ArrayNumericExpression, hl.expr.ArrayExpression],
    normalize: bool = True,
    min_cell_count: int = 200,
    min_count: int = 4,
    min_p_value: float = 1e-320,
) -> hl.expr.Int64Expression:
    """
    Computes `FS` (Fisher strand balance) annotation from the `SB` (strand balance table) field.
    `FS` is the phred-scaled value of the double-sided Fisher exact test on strand balance.

    Using the default values gives the same behavior as the GATK implementation, that is:
    - If sum(counts) > 2*`min_cell_count` (defaults to the GATK value of 200), the counts are normalized
    - If sum(counts) < `min_count` (defaults to the GATK value of 4), returns missing
    - Any p-value < `min_p_value` (defaults to the GATK value of 1e-320) is truncated to that value

    In addition to the default GATK behavior, setting `normalize` to `False` will perform a chi-squared test
    for large counts (> `min_cell_count`) instead of normalizing the cell values.

    .. note::

        This function can either take
        - an array of length four containing the forward and reverse strands' counts of ref and alt alleles: [ref fwd, ref rev, alt fwd, alt rev]
        - a two dimensional array with arrays of length two, containing the counts: [[ref fwd, ref rev], [alt fwd, alt rev]]

    GATK code here: https://github.com/broadinstitute/gatk/blob/master/src/main/java/org/broadinstitute/hellbender/tools/walkers/annotator/FisherStrand.java

    :param sb: Count of ref/alt reads on each strand
    :param normalize: Whether to normalize counts if sum(counts) > min_cell_count (normalize=True), or to use a chi-squared test instead of FET (normalize=False)
    :param min_cell_count: Maximum count for performing a FET
    :param min_count: Minimum total count to output FS (otherwise null is output)
    :param min_p_value: Minimum p-value; smaller p-values are truncated to this value
    :return: FS value
    """
    if not isinstance(sb, hl.expr.ArrayNumericExpression):
        sb = hl.bind(lambda x: hl.flatten(x), sb)

    sb_sum = hl.bind(lambda x: hl.sum(x), sb)

    # Normalize table if counts get too large
    if normalize:
        fs_expr = hl.bind(
            lambda sb, sb_sum: hl.cond(
                sb_sum <= 2 * min_cell_count,
                sb,
                sb.map(lambda x: hl.int(x / (sb_sum / min_cell_count))),
            ),
            sb,
            sb_sum,
        )

        # FET
        fs_expr = to_phred(
            hl.max(
                hl.fisher_exact_test(
                    fs_expr[0], fs_expr[1], fs_expr[2], fs_expr[3]
                ).p_value,
                min_p_value,
            )
        )
    else:
        fs_expr = to_phred(
            hl.max(
                hl.contingency_table_test(
                    sb[0], sb[1], sb[2], sb[3], min_cell_count=min_cell_count
                ).p_value,
                min_p_value,
            )
        )

    # Return null if counts <= `min_count`
    return hl.or_missing(
        sb_sum > min_count, hl.max(0, fs_expr)  # Needed to avoid -0.0 values
    )
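A usage sketch (hypothetical counts; assumes `to_phred`, referenced above, is in scope from the same module):

import hail as hl

sb = hl.literal([[40, 36], [2, 11]])  # [[ref fwd, ref rev], [alt fwd, alt rev]]
print(hl.eval(fs_from_sb(sb)))        # phred-scaled two-sided Fisher exact test p-value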
Example #14
def compute_from_vp_mt(chr20: bool, overwrite: bool):
    meta = get_gnomad_meta('exomes')
    vp_mt = hl.read_matrix_table(full_mt_path('exomes'))
    vp_mt = vp_mt.filter_cols(meta[vp_mt.col_key].release)
    ann_ht = hl.read_table(vp_ann_ht_path('exomes'))
    phase_ht = hl.read_table(phased_vp_count_ht_path('exomes'))

    if chr20:
        vp_mt, ann_ht, phase_ht = filter_to_chr20([vp_mt, ann_ht, phase_ht])

    vep1_expr = get_worst_gene_csq_code_expr(ann_ht.vep1)
    vep2_expr = get_worst_gene_csq_code_expr(ann_ht.vep2)
    ann_ht = ann_ht.select(
        'snv1',
        'snv2',
        is_singleton_vp=(ann_ht.freq1['all'].AC < 2) & (ann_ht.freq2['all'].AC < 2),
        pop_af=hl.dict(
            ann_ht.freq1.key_set().intersection(ann_ht.freq2.key_set())
                .map(
                lambda pop: hl.tuple([pop, hl.max(ann_ht.freq1[pop].AF, ann_ht.freq2[pop].AF)])
            )
        ),
        popmax_af=hl.max(ann_ht.popmax1.AF, ann_ht.popmax2.AF, filter_missing=False),
        filtered=(hl.len(ann_ht.filters1) > 0) | (hl.len(ann_ht.filters2) > 0),
        vep=vep1_expr.keys().filter(
            lambda k: vep2_expr.contains(k)
        ).map(
            lambda k: vep1_expr[k].annotate(
                csq=hl.max(vep1_expr[k].csq, vep2_expr[k].csq)
            )
        )
    )

    vp_mt = vp_mt.annotate_cols(
        pop=meta[vp_mt.col_key].pop
    )
    vp_mt = vp_mt.annotate_rows(
        **ann_ht[vp_mt.row_key],
        phase_info=phase_ht[vp_mt.row_key].phase_info
    )

    vp_mt = vp_mt.filter_rows(
        ~vp_mt.filtered
    )

    vp_mt = vp_mt.filter_entries(
        vp_mt.GT1.is_het() & vp_mt.GT2.is_het() & vp_mt.adj1 & vp_mt.adj2
    )

    vp_mt = vp_mt.select_entries(
        x=True
    )

    vp_mt = vp_mt.annotate_cols(
        pop=['all', vp_mt.pop]
    )
    vp_mt = vp_mt.explode_cols('pop')

    vp_mt = vp_mt.explode_rows('vep')
    vp_mt = vp_mt.transmute_rows(
        **vp_mt.vep
    )

    def get_grouped_phase_agg():
        return hl.agg.group_by(
            hl.case()
                .when(~vp_mt.is_singleton_vp & (vp_mt.phase_info[vp_mt.pop].em.adj.p_chet > CHET_THRESHOLD), 1)
                .when(~vp_mt.is_singleton_vp & (vp_mt.phase_info[vp_mt.pop].em.adj.p_chet < SAME_HAP_THRESHOLD), 2)
                .default(3),
            hl.agg.min(vp_mt.csq)
        )

    vp_mt = vp_mt.group_rows_by(
        'gene_id',
        'gene_symbol'
    ).aggregate(
        all=hl.agg.filter(
            vp_mt.x &
            hl.if_else(
                vp_mt.pop == 'all',
                hl.is_defined(vp_mt.popmax_af) &
                (vp_mt.popmax_af <= MAX_FREQ),
                vp_mt.pop_af[vp_mt.pop] <= MAX_FREQ
            ),
            get_grouped_phase_agg()
        ),
        af_le_0_001=hl.agg.filter(
            hl.if_else(
                vp_mt.pop == 'all',
                hl.is_defined(vp_mt.popmax_af) &
                (vp_mt.popmax_af <= 0.001),
                vp_mt.pop_af[vp_mt.pop] <= 0.001
            )
            & vp_mt.x,
            get_grouped_phase_agg()
        )
    )

    vp_mt = vp_mt.checkpoint('gs://gnomad-tmp/compound_hets/chet_per_gene{}.2.mt'.format(
        '.chr20' if chr20 else ''
    ), overwrite=True)

    gene_ht = vp_mt.annotate_rows(
        row_counts=hl.flatten([
            hl.array(
                hl.agg.group_by(
                    vp_mt.pop,
                    hl.struct(
                        csq=csq,
                        af=af,
                        # TODO: Review this
                        # These will only keep the worst csq -- now maybe it'd be better to keep either
                        # - the worst csq for chet or
                        # - the worst csq for both chet and same_hap
                        n_worst_chet=hl.agg.count_where(vp_mt[af].get(1) == csq_i),
                        n_chet=hl.agg.count_where((vp_mt[af].get(1) == csq_i) & (vp_mt[af].get(2, 9) >= csq_i) & (vp_mt[af].get(3, 9) >= csq_i)),
                        n_same_hap=hl.agg.count_where((vp_mt[af].get(2) == csq_i) & (vp_mt[af].get(1, 9) > csq_i) & (vp_mt[af].get(3, 9) >= csq_i)),
                        n_unphased=hl.agg.count_where((vp_mt[af].get(3) == csq_i) & (vp_mt[af].get(1, 9) > csq_i) & (vp_mt[af].get(2, 9) > csq_i))
                    )
                )
            ).filter(
                lambda x: (x[1].n_chet > 0) | (x[1].n_same_hap > 0) | (x[1].n_unphased > 0)
            ).map(
                lambda x: x[1].annotate(
                    pop=x[0]
                )
            )
            for csq_i, csq in enumerate(CSQ_CODES)
            for af in ['all', 'af_le_0_001']
        ])
    ).rows()

    gene_ht = gene_ht.explode('row_counts')
    gene_ht = gene_ht.select(
        **gene_ht.row_counts
    )

    gene_ht.describe()
    gene_ht = gene_ht.checkpoint(
        'gs://gnomad-lfran/compound_hets/chet_per_gene{}.ht'.format(
            '.chr20' if chr20 else ''
        ),
        overwrite=overwrite
    )

    gene_ht.flatten().export(
        'gs://gnomad-lfran/compound_hets/chet_per_gene{}.tsv.gz'.format(
            '.chr20' if chr20 else ''
        )
    )
Example #15
    joined_gnomad_exomes_ht = gnomad_exomes[(denovos.locus, denovos.alleles)]

    denovos = denovos.annotate(
        dataset=data_label,
        variant_type=get_expr_for_variant_type(denovos),
        in_LCR=joined_mt_rows.info.in_LCR,
        in_segdup=joined_mt_rows.info.in_segdup,
        filters=hl.cond(hl.len(joined_mt_rows.filters) > 0, hl.delimit(joined_mt_rows.filters, ','), 'PASS'),
        AC=joined_mt_rows.info.AC,
        AF=joined_mt_rows.info.AF, 
        QD=joined_mt_rows.info.QD,
        gnomAD_genomes_AF = joined_gnomad_genomes_ht.freq[0].AF,
        gnomAD_genomes_AC = joined_gnomad_genomes_ht.freq[0].AC,
        gnomAD_exomes_AF = joined_gnomad_exomes_ht.freq[0].AF,
        gnomAD_exomes_AC = joined_gnomad_exomes_ht.freq[0].AC,
        gnomAD_popmax_AF = hl.max(joined_gnomad_exomes_ht.popmax[0].AF, joined_gnomad_genomes_ht.popmax[0].AF),
    )

    clinvar_ht = hl.read_table("gs://seqr-bw2/ref/GRCh38/clinvar_20190715.pathogenic.ht")
    clinvar_ht = hl.split_multi_hts(clinvar_ht)
    joined_clinvar_ht = clinvar_ht[(denovos.locus, denovos.alleles)]
    denovos = denovos.annotate(
        clinvar_alleleid = joined_clinvar_ht.info.ALLELEID,
        clinvar_clnsig = hl.delimit(joined_clinvar_ht.info.CLNSIG, delimiter=','),
        clinvar_revstat = hl.delimit(joined_clinvar_ht.info.CLNREVSTAT, delimiter=','),
    )

    #if args.mendelian:
    #    denovos = denovos.annotate(
    #        s = hl.delimit(denovos.s.split("").map(lambda l: OBFUSCATE[l]), ""),
    #    )
    "obs_mis": hl.tfloat,
    "exp_mis": hl.tfloat,
    "obs_exp": hl.tfloat,
    "chisq_diff_null": hl.tfloat,
    "region_name": hl.tstr,
}

ds = hl.import_table(args.input_url, missing="", types=column_types)

###########
# Prepare #
###########

ds = ds.annotate(
    start=hl.min(ds.genomic_start, ds.genomic_end),
    stop=hl.max(ds.genomic_start, ds.genomic_end),
)

ds = ds.annotate(
    xstart=get_expr_for_xpos(hl.locus(ds.chr, ds.start)),
    xstop=get_expr_for_xpos(hl.locus(ds.chr, ds.stop)),
)

ds = ds.drop("genomic_start", "genomic_end")

ds = ds.transmute(
    chrom=ds.chr, gene_name=ds.gene, transcript_id=ds.transcript.split("\\.")[0]
)

ds = ds.drop("region_name")
Example #17
    def test_annotate(self):
        schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32))

        rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]},
                {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []},
                {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}]

        kt = hl.Table.parallelize(rows, schema)

        self.assertTrue(kt.annotate()._same(kt))

        result1 = convert_struct_to_dict(kt.annotate(foo=kt.a + 1,
                                                     foo2=kt.a).take(1)[0])

        self.assertDictEqual(result1, {'a': 4,
                                       'b': 1,
                                       'c': 3,
                                       'd': 5,
                                       'e': "hello",
                                       'f': [1, 2, 3],
                                       'foo': 5,
                                       'foo2': 4})

        result3 = convert_struct_to_dict(kt.annotate(
            x1=kt.f.map(lambda x: x * 2),
            x2=kt.f.map(lambda x: [x, x + 1]).flatmap(lambda x: x),
            x3=hl.min(kt.f),
            x4=hl.max(kt.f),
            x5=hl.sum(kt.f),
            x6=hl.product(kt.f),
            x7=kt.f.length(),
            x8=kt.f.filter(lambda x: x == 3),
            x9=kt.f[1:],
            x10=kt.f[:],
            x11=kt.f[1:2],
            x12=kt.f.map(lambda x: [x, x + 1]),
            x13=kt.f.map(lambda x: [[x, x + 1], [x + 2]]).flatmap(lambda x: x),
            x14=hl.cond(kt.a < kt.b, kt.c, hl.null(hl.tint32)),
            x15={1, 2, 3}
        ).take(1)[0])

        self.assertDictEqual(result3, {'a': 4,
                                       'b': 1,
                                       'c': 3,
                                       'd': 5,
                                       'e': "hello",
                                       'f': [1, 2, 3],
                                       'x1': [2, 4, 6], 'x2': [1, 2, 2, 3, 3, 4],
                                       'x3': 1, 'x4': 3, 'x5': 6, 'x6': 6, 'x7': 3, 'x8': [3],
                                       'x9': [2, 3], 'x10': [1, 2, 3], 'x11': [2],
                                       'x12': [[1, 2], [2, 3], [3, 4]],
                                       'x13': [[1, 2], [3], [2, 3], [4], [3, 4], [5]],
                                       'x14': None, 'x15': set([1, 2, 3])})
        kt.annotate(
            x1=kt.a + 5,
            x2=5 + kt.a,
            x3=kt.a + kt.b,
            x4=kt.a - 5,
            x5=5 - kt.a,
            x6=kt.a - kt.b,
            x7=kt.a * 5,
            x8=5 * kt.a,
            x9=kt.a * kt.b,
            x10=kt.a / 5,
            x11=5 / kt.a,
            x12=kt.a / kt.b,
            x13=-kt.a,
            x14=+kt.a,
            x15=kt.a == kt.b,
            x16=kt.a == 5,
            x17=5 == kt.a,
            x18=kt.a != kt.b,
            x19=kt.a != 5,
            x20=5 != kt.a,
            x21=kt.a > kt.b,
            x22=kt.a > 5,
            x23=5 > kt.a,
            x24=kt.a >= kt.b,
            x25=kt.a >= 5,
            x26=5 >= kt.a,
            x27=kt.a < kt.b,
            x28=kt.a < 5,
            x29=5 < kt.a,
            x30=kt.a <= kt.b,
            x31=kt.a <= 5,
            x32=5 <= kt.a,
            x33=(kt.a == 0) & (kt.b == 5),
            x34=(kt.a == 0) | (kt.b == 5),
            x35=False,
            x36=True
        )
Example #18
def ld_score_regression(weight_expr,
                        ld_score_expr,
                        chi_sq_exprs,
                        n_samples_exprs,
                        n_blocks=200,
                        two_step_threshold=30,
                        n_reference_panel_variants=None) -> Table:
    r"""Estimate SNP-heritability and level of confounding biases from
    GWAS summary statistics.

    Given a set or multiple sets of genome-wide association study (GWAS)
    summary statistics, :func:`.ld_score_regression` estimates the heritability
    of a trait or set of traits and the level of confounding biases present in
    the underlying studies by regressing chi-squared statistics on LD scores,
    leveraging the model:

    .. math::

        \mathrm{E}[\chi_j^2] = 1 + Na + \frac{Nh_g^2}{M}l_j

    *  :math:`\mathrm{E}[\chi_j^2]` is the expected chi-squared statistic
       for variant :math:`j` resulting from a test of association between
       variant :math:`j` and a trait.
    *  :math:`l_j = \sum_{k} r_{jk}^2` is the LD score of variant
       :math:`j`, calculated as the sum of squared correlation coefficients
       between variant :math:`j` and nearby variants. See :func:`ld_score`
       for further details.
    *  :math:`a` captures the contribution of confounding biases, such as
       cryptic relatedness and uncontrolled population structure, to the
       association test statistic.
    *  :math:`h_g^2` is the SNP-heritability, or the proportion of variation
       in the trait explained by the effects of variants included in the
       regression model above.
    *  :math:`M` is the number of variants used to estimate :math:`h_g^2`.
    *  :math:`N` is the number of samples in the underlying association study.

    For more details on the method implemented in this function, see:

    * `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__

    Examples
    --------

    Run the method on a matrix table of summary statistics, where the rows
    are variants and the columns are different phenotypes:

    >>> mt_gwas = hl.read_matrix_table('data/ld_score_regression.sumstats.mt')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=mt_gwas['ld_score'],
    ...     ld_score_expr=mt_gwas['ld_score'],
    ...     chi_sq_exprs=mt_gwas['chi_squared'],
    ...     n_samples_exprs=mt_gwas['n'])


    Run the method on a table with summary statistics for a single
    phenotype:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=ht_gwas['chi_squared_50_irnt'],
    ...     n_samples_exprs=ht_gwas['n_50_irnt'])

    Run the method on a table with summary statistics for multiple
    phenotypes:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=[ht_gwas['chi_squared_50_irnt'],
    ...                        ht_gwas['chi_squared_20160']],
    ...     n_samples_exprs=[ht_gwas['n_50_irnt'],
    ...                      ht_gwas['n_20160']])

    Notes
    -----
    The ``exprs`` provided as arguments to :func:`.ld_score_regression`
    must all be from the same object, either a :class:`Table` or a
    :class:`MatrixTable`.

    **If the arguments originate from a table:**

    *  The table must be keyed by fields ``locus`` of type
       :class:`.tlocus` and ``alleles``, a :py:data:`.tarray` of
       :py:data:`.tstr` elements.
    *  ``weight_expr``, ``ld_score_expr``, ``chi_sq_exprs``, and
       ``n_samples_exprs`` must be row-indexed fields.
    *  The number of expressions passed to ``n_samples_exprs`` must be
       equal to one or the number of expressions passed to
       ``chi_sq_exprs``. If just one expression is passed to
       ``n_samples_exprs``, that sample size expression is assumed to
       apply to all sets of statistics passed to ``chi_sq_exprs``.
       Otherwise, the expressions passed to ``chi_sq_exprs`` and
       ``n_samples_exprs`` are matched by index.
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have generic :obj:`int` values
       ``0``, ``1``, etc. corresponding to the ``0th``, ``1st``, etc.
       expressions passed to the ``chi_sq_exprs`` argument.

    **If the arguments originate from a matrix table:**

    *  The dimensions of the matrix table must be variants
       (rows) by phenotypes (columns).
    *  The rows of the matrix table must be keyed by fields
       ``locus`` of type :class:`.tlocus` and ``alleles``,
       a :py:data:`.tarray` of :py:data:`.tstr` elements.
    *  The columns of the matrix table must be keyed by a field
       of type :py:data:`.tstr` that uniquely identifies phenotypes
       represented in the matrix table. The column key must be a single
       expression; compound keys are not accepted.
    *  ``weight_expr`` and ``ld_score_expr`` must be row-indexed
       fields.
    *  ``chi_sq_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  ``n_samples_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have values corresponding to the
       column keys of the input matrix table.

    This function returns a :class:`Table` with one row per set of summary
    statistics passed to the ``chi_sq_exprs`` argument. The following
    row-indexed fields are included in the table:

    *  **phenotype** (:py:data:`.tstr`) -- The name of the phenotype. The
       returned table is keyed by this field. See the notes below for
       details on the possible values of this field.
    *  **mean_chi_sq** (:py:data:`.tfloat64`) -- The mean chi-squared
       test statistic for the given phenotype.
    *  **intercept** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          intercept :math:`1 + Na`.
       -  **standard_error**  (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    *  **snp_heritability** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          SNP-heritability :math:`h_g^2`.
       -  **standard_error** (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    Warning
    -------
    :func:`.ld_score_regression` considers only the rows for which both row
    fields ``weight_expr`` and ``ld_score_expr`` are defined. Rows with missing
    values in either field are removed prior to fitting the LD score
    regression model.

    Parameters
    ----------
    weight_expr : :class:`.Float64Expression`
                  Row-indexed expression for the LD scores used to derive
                  variant weights in the model.
    ld_score_expr : :class:`.Float64Expression`
                    Row-indexed expression for the LD scores used as covariates
                    in the model.
    chi_sq_exprs : :class:`.Float64Expression` or :obj:`list` of
                        :class:`.Float64Expression`
                        One or more row-indexed (if table) or entry-indexed
                        (if matrix table) expressions for chi-squared
                        statistics resulting from genome-wide association
                        studies.
    n_samples_exprs: :class:`.NumericExpression` or :obj:`list` of
                     :class:`.NumericExpression`
                     One or more row-indexed (if table) or entry-indexed
                     (if matrix table) expressions indicating the number of
                     samples used in the studies that generated the test
                     statistics supplied to ``chi_sq_exprs``.
    n_blocks : :obj:`int`
               The number of blocks used in the jackknife approach to
               estimating standard errors.
    two_step_threshold : :obj:`int`
                         Variants with chi-squared statistics greater than this
                         value are excluded in the first step of the two-step
                         procedure used to fit the model.
    n_reference_panel_variants : :obj:`int`, optional
                                 Number of variants used to estimate the
                                 SNP-heritability :math:`h_g^2`.

    Returns
    -------
    :class:`.Table`
        Table keyed by ``phenotype`` with intercept and heritability estimates
        for each phenotype passed to the function."""

    chi_sq_exprs = wrap_to_list(chi_sq_exprs)
    n_samples_exprs = wrap_to_list(n_samples_exprs)

    assert ((len(chi_sq_exprs) == len(n_samples_exprs)) or
            (len(n_samples_exprs) == 1))
    __k = 2  # number of covariates, including intercept

    ds = chi_sq_exprs[0]._indices.source

    analyze('ld_score_regression/weight_expr',
            weight_expr,
            ds._row_indices)
    analyze('ld_score_regression/ld_score_expr',
            ld_score_expr,
            ds._row_indices)

    # format input dataset
    if isinstance(ds, MatrixTable):
        if len(chi_sq_exprs) != 1:
            raise ValueError("""Only one chi_sq_expr allowed if originating
                from a matrix table.""")
        if len(n_samples_exprs) != 1:
            raise ValueError("""Only one n_samples_expr allowed if
                originating from a matrix table.""")

        col_key = list(ds.col_key)
        if len(col_key) != 1:
            raise ValueError("""Matrix table must be keyed by a single
                phenotype field.""")

        analyze('ld_score_regression/chi_squared_expr',
                chi_sq_exprs[0],
                ds._entry_indices)
        analyze('ld_score_regression/n_samples_expr',
                n_samples_exprs[0],
                ds._entry_indices)

        ds = ds._select_all(row_exprs={'__locus': ds.locus,
                                       '__alleles': ds.alleles,
                                       '__w_initial': weight_expr,
                                       '__w_initial_floor': hl.max(weight_expr,
                                                                   1.0),
                                       '__x': ld_score_expr,
                                       '__x_floor': hl.max(ld_score_expr,
                                                           1.0)},
                            row_key=['__locus', '__alleles'],
                            col_exprs={'__y_name': ds[col_key[0]]},
                            col_key=['__y_name'],
                            entry_exprs={'__y': chi_sq_exprs[0],
                                         '__n': n_samples_exprs[0]})
        ds = ds.annotate_entries(**{'__w': ds.__w_initial})

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    else:
        assert isinstance(ds, Table)
        for y in chi_sq_exprs:
            analyze('ld_score_regression/chi_squared_expr', y, ds._row_indices)
        for n in n_samples_exprs:
            analyze('ld_score_regression/n_samples_expr', n, ds._row_indices)

        ys = ['__y{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ws = ['__w{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ns = ['__n{:}'.format(i) for i, _ in enumerate(n_samples_exprs)]

        ds = ds.select(**dict(**{'__locus': ds.locus,
                                 '__alleles': ds.alleles,
                                 '__w_initial': weight_expr,
                                 '__x': ld_score_expr},
                              **{y: chi_sq_exprs[i]
                                 for i, y in enumerate(ys)},
                              **{w: weight_expr for w in ws},
                              **{n: n_samples_exprs[i]
                                 for i, n in enumerate(ns)}))
        ds = ds.key_by(ds.__locus, ds.__alleles)

        table_tmp_file = new_temp_file()
        ds.write(table_tmp_file)
        ds = hl.read_table(table_tmp_file)

        hts = [ds.select(**{'__w_initial': ds.__w_initial,
                            '__w_initial_floor': hl.max(ds.__w_initial,
                                                        1.0),
                            '__x': ds.__x,
                            '__x_floor': hl.max(ds.__x, 1.0),
                            '__y_name': i,
                            '__y': ds[ys[i]],
                            '__w': ds[ws[i]],
                            '__n': hl.int(ds[ns[i]])})
               for i, y in enumerate(ys)]

        mts = [ht.to_matrix_table(row_key=['__locus',
                                           '__alleles'],
                                  col_key=['__y_name'],
                                  row_fields=['__w_initial',
                                              '__w_initial_floor',
                                              '__x',
                                              '__x_floor'])
               for ht in hts]

        ds = mts[0]
        for i in range(1, len(ys)):
            ds = ds.union_cols(mts[i])

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    mt_tmp_file1 = new_temp_file()
    ds.write(mt_tmp_file1)
    mt = hl.read_matrix_table(mt_tmp_file1)

    if not n_reference_panel_variants:
        M = mt.count_rows()
    else:
        M = n_reference_panel_variants

    # block variants for each phenotype
    n_phenotypes = mt.count_cols()

    mt = mt.annotate_entries(__in_step1=(hl.is_defined(mt.__y) &
                                         (mt.__y < two_step_threshold)),
                             __in_step2=hl.is_defined(mt.__y))

    mt = mt.annotate_cols(__col_idx=hl.int(hl.scan.count()),
                          __m_step1=hl.agg.count_where(mt.__in_step1),
                          __m_step2=hl.agg.count_where(mt.__in_step2))

    col_keys = list(mt.col_key)

    ht = mt.localize_entries(entries_array_field_name='__entries',
                             columns_array_field_name='__cols')

    ht = ht.annotate(__entries=hl.rbind(
        hl.scan.array_agg(
            lambda entry: hl.scan.count_where(entry.__in_step1),
            ht.__entries),
        lambda step1_indices: hl.map(
            lambda i: hl.rbind(
                hl.int(hl.or_else(step1_indices[i], 0)),
                ht.__cols[i].__m_step1,
                ht.__entries[i],
                lambda step1_idx, m_step1, entry: hl.rbind(
                    hl.map(
                        lambda j: hl.int(hl.floor(j * (m_step1 / n_blocks))),
                        hl.range(0, n_blocks + 1)),
                    lambda step1_separators: hl.rbind(
                        hl.set(step1_separators).contains(step1_idx),
                        hl.sum(
                            hl.map(
                                lambda s1: step1_idx >= s1,
                                step1_separators)) - 1,
                        lambda is_separator, step1_block: entry.annotate(
                            __step1_block=step1_block,
                            __step2_block=hl.cond(~entry.__in_step1 & is_separator,
                                                  step1_block - 1,
                                                  step1_block))))),
            hl.range(0, hl.len(ht.__entries)))))

    mt = ht._unlocalize_entries('__entries', '__cols', col_keys)

    mt_tmp_file2 = new_temp_file()
    mt.write(mt_tmp_file2)
    mt = hl.read_matrix_table(mt_tmp_file2)
    
    # initial coefficient estimates
    mt = mt.annotate_cols(__initial_betas=[
        1.0, (hl.agg.mean(mt.__y) - 1.0) / hl.agg.mean(mt.__x)])
    mt = mt.annotate_cols(__step1_betas=mt.__initial_betas,
                          __step2_betas=mt.__initial_betas)

    # step 1 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step1,
            1.0/(mt.__w_initial_floor * 2.0 * (mt.__step1_betas[0] +
                                               mt.__step1_betas[1] *
                                               mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step1_betas=hl.agg.filter(
            mt.__in_step1,
            hl.agg.linreg(y=mt.__y,
                          x=[1.0, mt.__x],
                          weight=mt.__w).beta))
        mt = mt.annotate_cols(__step1_h2=hl.max(hl.min(
            mt.__step1_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step1_betas=[
            mt.__step1_betas[0],
            mt.__step1_h2 * hl.agg.mean(mt.__n) / M])

    # step 1 block jackknife
    mt = mt.annotate_cols(__step1_block_betas=[
        hl.agg.filter((mt.__step1_block != i) & mt.__in_step1,
                      hl.agg.linreg(y=mt.__y,
                                    x=[1.0, mt.__x],
                                    weight=mt.__w).beta)
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step1_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step1_betas - (n_blocks - 1) * x,
        mt.__step1_block_betas))

    mt = mt.annotate_cols(
        __step1_jackknife_mean=hl.map(
            lambda i: hl.mean(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected)),
            hl.range(0, __k)),
        __step1_jackknife_variance=hl.map(
            lambda i: (hl.sum(
                hl.map(lambda x: x[i]**2,
                       mt.__step1_block_betas_bias_corrected)) -
                       hl.sum(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected))**2 /
                       n_blocks) /
            (n_blocks - 1) / n_blocks,
            hl.range(0, __k)))

    # step 2 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step2,
            1.0/(mt.__w_initial_floor *
                 2.0 * (mt.__step2_betas[0] +
                        mt.__step2_betas[1] *
                        mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            hl.agg.filter(mt.__in_step2,
                          hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                        x=[mt.__x],
                                        weight=mt.__w).beta[0])])
        mt = mt.annotate_cols(__step2_h2=hl.max(hl.min(
            mt.__step2_betas[1] * M/hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            mt.__step2_h2 * hl.agg.mean(mt.__n)/M])

    # step 2 block jackknife
    mt = mt.annotate_cols(__step2_block_betas=[
        hl.agg.filter((mt.__step2_block != i) & mt.__in_step2,
                      hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                    x=[mt.__x],
                                    weight=mt.__w).beta[0])
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step2_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step2_betas[1] - (n_blocks - 1) * x,
        mt.__step2_block_betas))

    mt = mt.annotate_cols(
        __step2_jackknife_mean=hl.mean(
            mt.__step2_block_betas_bias_corrected),
        __step2_jackknife_variance=(
            hl.sum(mt.__step2_block_betas_bias_corrected**2) -
            hl.sum(mt.__step2_block_betas_bias_corrected)**2 /
            n_blocks) / (n_blocks - 1) / n_blocks)

    # combine step 1 and step 2 block jackknifes
    mt = mt.annotate_entries(
        __step2_initial_w=1.0/(mt.__w_initial_floor *
                               2.0 * (mt.__initial_betas[0] +
                                      mt.__initial_betas[1] *
                                      mt.__x_floor)**2))

    mt = mt.annotate_cols(
        __final_betas=[
            mt.__step1_betas[0],
            mt.__step2_betas[1]],
        __c=(hl.agg.sum(mt.__step2_initial_w * mt.__x) /
             hl.agg.sum(mt.__step2_initial_w * mt.__x**2)))

    mt = mt.annotate_cols(__final_block_betas=hl.map(
        lambda i: (mt.__step2_block_betas[i] - mt.__c *
                   (mt.__step1_block_betas[i][0] - mt.__final_betas[0])),
        hl.range(0, n_blocks)))

    mt = mt.annotate_cols(
        __final_block_betas_bias_corrected=(n_blocks * mt.__final_betas[1] -
                                            (n_blocks - 1) *
                                            mt.__final_block_betas))

    mt = mt.annotate_cols(
        __final_jackknife_mean=[
            mt.__step1_jackknife_mean[0],
            hl.mean(mt.__final_block_betas_bias_corrected)],
        __final_jackknife_variance=[
            mt.__step1_jackknife_variance[0],
            (hl.sum(mt.__final_block_betas_bias_corrected**2) -
             hl.sum(mt.__final_block_betas_bias_corrected)**2 /
             n_blocks) / (n_blocks - 1) / n_blocks])

    # convert coefficient to heritability estimate
    mt = mt.annotate_cols(
        phenotype=mt.__y_name,
        mean_chi_sq=hl.agg.mean(mt.__y),
        intercept=hl.struct(
            estimate=mt.__final_betas[0],
            standard_error=hl.sqrt(mt.__final_jackknife_variance[0])),
        snp_heritability=hl.struct(
            estimate=(M/hl.agg.mean(mt.__n)) * mt.__final_betas[1],
            standard_error=hl.sqrt((M/hl.agg.mean(mt.__n))**2 *
                                   mt.__final_jackknife_variance[1])))

    # format and return results
    ht = mt.cols()
    ht = ht.key_by(ht.phenotype)
    ht = ht.select(ht.mean_chi_sq,
                   ht.intercept,
                   ht.snp_heritability)

    ht_tmp_file = new_temp_file()
    ht.write(ht_tmp_file)
    ht = hl.read_table(ht_tmp_file)
    
    return ht
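
A note on the block jackknife used above: the per-block estimates are turned into bias-corrected pseudovalues, whose mean and scaled variance give the point estimate and standard error. A minimal numpy sketch of the same arithmetic (toy numbers; `theta_hat` and `theta_loo` are hypothetical stand-ins for `__step1_betas[1]` and the per-block `__step1_block_betas`):

import numpy as np

rng = np.random.default_rng(0)
n_blocks = 200
theta_hat = 0.5                                        # full-data estimate
theta_loo = theta_hat + rng.normal(0, 0.01, n_blocks)  # leave-one-block-out estimates

# bias-corrected pseudovalues, as in __step1_block_betas_bias_corrected
pseudovalues = n_blocks * theta_hat - (n_blocks - 1) * theta_loo

jackknife_mean = pseudovalues.mean()
# same (sum(x^2) - sum(x)^2 / n) / (n - 1) / n form as __step1_jackknife_variance
jackknife_variance = (np.sum(pseudovalues**2)
                      - np.sum(pseudovalues)**2 / n_blocks) / (n_blocks - 1) / n_blocks
print(jackknife_mean, np.sqrt(jackknife_variance))
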
Example #19
def ld_score_regression(weight_expr,
                        ld_score_expr,
                        chi_sq_exprs,
                        n_samples_exprs,
                        n_blocks=200,
                        two_step_threshold=30,
                        n_reference_panel_variants=None) -> Table:
    r"""Estimate SNP-heritability and level of confounding biases from
    GWAS summary statistics.

    Given a set or multiple sets of genome-wide association study (GWAS)
    summary statistics, :func:`.ld_score_regression` estimates the heritability
    of a trait or set of traits and the level of confounding biases present in
    the underlying studies by regressing chi-squared statistics on LD scores,
    leveraging the model:

    .. math::

        \mathrm{E}[\chi_j^2] = 1 + Na + \frac{Nh_g^2}{M}l_j

    *  :math:`\mathrm{E}[\chi_j^2]` is the expected chi-squared statistic
       for variant :math:`j` resulting from a test of association between
       variant :math:`j` and a trait.
    *  :math:`l_j = \sum_{k} r_{jk}^2` is the LD score of variant
       :math:`j`, calculated as the sum of squared correlation coefficients
       between variant :math:`j` and nearby variants. See :func:`ld_score`
       for further details.
    *  :math:`a` captures the contribution of confounding biases, such as
       cryptic relatedness and uncontrolled population structure, to the
       association test statistic.
    *  :math:`h_g^2` is the SNP-heritability, or the proportion of variation
       in the trait explained by the effects of variants included in the
       regression model above.
    *  :math:`M` is the number of variants used to estimate :math:`h_g^2`.
    *  :math:`N` is the number of samples in the underlying association study.

    For more details on the method implemented in this function, see:

    * `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__

    Examples
    --------

    Run the method on a matrix table of summary statistics, where the rows
    are variants and the columns are different phenotypes:

    >>> mt_gwas = ld_score_all_phenos_sumstats
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=mt_gwas['ld_score'],
    ...     ld_score_expr=mt_gwas['ld_score'],
    ...     chi_sq_exprs=mt_gwas['chi_squared'],
    ...     n_samples_exprs=mt_gwas['n'])


    Run the method on a table with summary statistics for a single
    phenotype:

    >>> ht_gwas = ld_score_one_pheno_sumstats
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=ht_gwas['chi_squared_50_irnt'],
    ...     n_samples_exprs=ht_gwas['n_50_irnt'])

    Run the method on a table with summary statistics for multiple
    phenotypes:

    >>> ht_gwas = ld_score_one_pheno_sumstats
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=[ht_gwas['chi_squared_50_irnt'],
    ...                        ht_gwas['chi_squared_20160']],
    ...     n_samples_exprs=[ht_gwas['n_50_irnt'],
    ...                      ht_gwas['n_20160']])

    Notes
    -----
    The ``exprs`` provided as arguments to :func:`.ld_score_regression`
    must all be from the same object, either a :class:`Table` or a
    :class:`MatrixTable`.

    **If the arguments originate from a table:**

    *  The table must be keyed by fields ``locus`` of type
       :class:`.tlocus` and ``alleles``, a :py:data:`.tarray` of
       :py:data:`.tstr` elements.
    *  ``weight_expr``, ``ld_score_expr``, ``chi_sq_exprs``, and
       ``n_samples_exprs`` must be row-indexed fields.
    *  The number of expressions passed to ``n_samples_exprs`` must be
       equal to one or the number of expressions passed to
       ``chi_sq_exprs``. If just one expression is passed to
       ``n_samples_exprs``, that sample size expression is assumed to
       apply to all sets of statistics passed to ``chi_sq_exprs``.
       Otherwise, the expressions passed to ``chi_sq_exprs`` and
       ``n_samples_exprs`` are matched by index.
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have generic :obj:`int` values
       ``0``, ``1``, etc. corresponding to the ``0th``, ``1st``, etc.
       expressions passed to the ``chi_sq_exprs`` argument.

    **If the arguments originate from a matrix table:**

    *  The dimensions of the matrix table must be variants
       (rows) by phenotypes (columns).
    *  The rows of the matrix table must be keyed by fields
       ``locus`` of type :class:`.tlocus` and ``alleles``,
       a :py:data:`.tarray` of :py:data:`.tstr` elements.
    *  The columns of the matrix table must be keyed by a field
       of type :py:data:`.tstr` that uniquely identifies phenotypes
       represented in the matrix table. The column key must be a single
       expression; compound keys are not accepted.
    *  ``weight_expr`` and ``ld_score_expr`` must be row-indexed
       fields.
    *  ``chi_sq_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  ``n_samples_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have values corresponding to the
       column keys of the input matrix table.

    This function returns a :class:`Table` with one row per set of summary
    statistics passed to the ``chi_sq_exprs`` argument. The following
    row-indexed fields are included in the table:

    *  **phenotype** (:py:data:`.tstr`) -- The name of the phenotype. The
       returned table is keyed by this field. See the notes below for
       details on the possible values of this field.
    *  **mean_chi_sq** (:py:data:`.tfloat64`) -- The mean chi-squared
       test statistic for the given phenotype.
    *  **intercept** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          intercept :math:`1 + Na`.
       -  **standard_error**  (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    *  **snp_heritability** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          SNP-heritability :math:`h_g^2`.
       -  **standard_error** (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    Warning
    -------
    :func:`.ld_score_regression` considers only the rows for which both row
    fields ``weight_expr`` and ``ld_score_expr`` are defined. Rows with missing
    values in either field are removed prior to fitting the LD score
    regression model.

    Parameters
    ----------
    weight_expr : :class:`.Float64Expression`
                  Row-indexed expression for the LD scores used to derive
                  variant weights in the model.
    ld_score_expr : :class:`.Float64Expression`
                    Row-indexed expression for the LD scores used as covariates
                    in the model.
    chi_sq_exprs : :class:`.Float64Expression` or :obj:`list` of
                        :class:`.Float64Expression`
                        One or more row-indexed (if table) or entry-indexed
                        (if matrix table) expressions for chi-squared
                        statistics resulting from genome-wide association
                        studies.
    n_samples_exprs: :class:`.NumericExpression` or :obj:`list` of
                     :class:`.NumericExpression`
                     One or more row-indexed (if table) or entry-indexed
                     (if matrix table) expressions indicating the number of
                     samples used in the studies that generated the test
                     statistics supplied to ``chi_sq_exprs``.
    n_blocks : :obj:`int`
               The number of blocks used in the jackknife approach to
               estimating standard errors.
    two_step_threshold : :obj:`int`
                         Variants with chi-squared statistics greater than this
                         value are excluded in the first step of the two-step
                         procedure used to fit the model.
    n_reference_panel_variants : :obj:`int`, optional
                                 Number of variants used to estimate the
                                 SNP-heritability :math:`h_g^2`.

    Returns
    -------
    :class:`.Table`
        Table keyed by ``phenotype`` with intercept and heritability estimates
        for each phenotype passed to the function."""

    chi_sq_exprs = wrap_to_list(chi_sq_exprs)
    n_samples_exprs = wrap_to_list(n_samples_exprs)

    assert ((len(chi_sq_exprs) == len(n_samples_exprs))
            or (len(n_samples_exprs) == 1))
    __k = 2  # number of covariates, including intercept

    ds = chi_sq_exprs[0]._indices.source

    analyze('ld_score_regression/weight_expr', weight_expr, ds._row_indices)
    analyze('ld_score_regression/ld_score_expr', ld_score_expr,
            ds._row_indices)

    # format input dataset
    if isinstance(ds, MatrixTable):
        if len(chi_sq_exprs) != 1:
            raise ValueError("""Only one chi_sq_expr allowed if originating
                from a matrix table.""")
        if len(n_samples_exprs) != 1:
            raise ValueError("""Only one n_samples_expr allowed if
                originating from a matrix table.""")

        col_key = list(ds.col_key)
        if len(col_key) != 1:
            raise ValueError("""Matrix table must be keyed by a single
                phenotype field.""")

        analyze('ld_score_regression/chi_squared_expr', chi_sq_exprs[0],
                ds._entry_indices)
        analyze('ld_score_regression/n_samples_expr', n_samples_exprs[0],
                ds._entry_indices)

        ds = ds._select_all(row_exprs={
            '__locus': ds.locus,
            '__alleles': ds.alleles,
            '__w_initial': weight_expr,
            '__w_initial_floor': hl.max(weight_expr, 1.0),
            '__x': ld_score_expr,
            '__x_floor': hl.max(ld_score_expr, 1.0)
        },
                            row_key=['__locus', '__alleles'],
                            col_exprs={'__y_name': ds[col_key[0]]},
                            col_key=['__y_name'],
                            entry_exprs={
                                '__y': chi_sq_exprs[0],
                                '__n': n_samples_exprs[0]
                            })
        ds = ds.annotate_entries(**{'__w': ds.__w_initial})

        ds = ds.filter_rows(
            hl.is_defined(ds.__locus)
            & hl.is_defined(ds.__alleles)
            & hl.is_defined(ds.__w_initial)
            & hl.is_defined(ds.__x))

    else:
        assert isinstance(ds, Table)
        for y in chi_sq_exprs:
            analyze('ld_score_regression/chi_squared_expr', y, ds._row_indices)
        for n in n_samples_exprs:
            analyze('ld_score_regression/n_samples_expr', n, ds._row_indices)

        ys = ['__y{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ws = ['__w{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ns = ['__n{:}'.format(i) for i, _ in enumerate(n_samples_exprs)]

        ds = ds.select(**dict(**{'__locus': ds.locus,
                                 '__alleles': ds.alleles,
                                 '__w_initial': weight_expr,
                                 '__x': ld_score_expr},
                              **{y: chi_sq_exprs[i]
                                 for i, y in enumerate(ys)},
                              **{w: weight_expr for w in ws},
                              **{n: n_samples_exprs[i]
                                 for i, n in enumerate(ns)}))
        ds = ds.key_by(ds.__locus, ds.__alleles)

        table_tmp_file = new_temp_file()
        ds.write(table_tmp_file)
        ds = hl.read_table(table_tmp_file)

        hts = [
            ds.select(
                **{
                    '__w_initial': ds.__w_initial,
                    '__w_initial_floor': hl.max(ds.__w_initial, 1.0),
                    '__x': ds.__x,
                    '__x_floor': hl.max(ds.__x, 1.0),
                    '__y_name': i,
                    '__y': ds[ys[i]],
                    '__w': ds[ws[i]],
                    '__n': hl.int(ds[ns[i]])
                }) for i, y in enumerate(ys)
        ]

        mts = [
            ht.to_matrix_table(row_key=['__locus', '__alleles'],
                               col_key=['__y_name'],
                               row_fields=[
                                   '__w_initial', '__w_initial_floor', '__x',
                                   '__x_floor'
                               ]) for ht in hts
        ]

        ds = mts[0]
        for i in range(1, len(ys)):
            ds = ds.union_cols(mts[i])

        ds = ds.filter_rows(
            hl.is_defined(ds.__locus)
            & hl.is_defined(ds.__alleles)
            & hl.is_defined(ds.__w_initial)
            & hl.is_defined(ds.__x))

    mt_tmp_file1 = new_temp_file()
    ds.write(mt_tmp_file1)
    mt = hl.read_matrix_table(mt_tmp_file1)

    if not n_reference_panel_variants:
        M = mt.count_rows()
    else:
        M = n_reference_panel_variants

    mt = mt.annotate_entries(__in_step1=(hl.is_defined(mt.__y)
                                         & (mt.__y < two_step_threshold)),
                             __in_step2=hl.is_defined(mt.__y))

    mt = mt.annotate_cols(__col_idx=hl.int(hl.scan.count()),
                          __m_step1=hl.agg.count_where(mt.__in_step1),
                          __m_step2=hl.agg.count_where(mt.__in_step2))

    col_keys = list(mt.col_key)

    ht = mt.localize_entries(entries_array_field_name='__entries',
                             columns_array_field_name='__cols')

    ht = ht.annotate(__entries=hl.rbind(
        hl.scan.array_agg(
            lambda entry: hl.scan.count_where(entry.__in_step1),
            ht.__entries),
        lambda step1_indices: hl.map(
            lambda i: hl.rbind(
                hl.int(hl.or_else(step1_indices[i], 0)),
                ht.__cols[i].__m_step1,
                ht.__entries[i],
                lambda step1_idx, m_step1, entry: hl.rbind(
                    hl.map(
                        lambda j: hl.int(hl.floor(j * (m_step1 / n_blocks))),
                        hl.range(0, n_blocks + 1)),
                    lambda step1_separators: hl.rbind(
                        hl.set(step1_separators).contains(step1_idx),
                        hl.sum(
                            hl.map(
                                lambda s1: step1_idx >= s1,
                                step1_separators)) - 1,
                        lambda is_separator, step1_block: entry.annotate(
                            __step1_block=step1_block,
                            __step2_block=hl.cond(~entry.__in_step1 & is_separator,
                                                  step1_block - 1,
                                                  step1_block))))),
            hl.range(0, hl.len(ht.__entries)))))

    mt = ht._unlocalize_entries('__entries', '__cols', col_keys)

    mt_tmp_file2 = new_temp_file()
    mt.write(mt_tmp_file2)
    mt = hl.read_matrix_table(mt_tmp_file2)

    # initial coefficient estimates
    mt = mt.annotate_cols(__initial_betas=[
        1.0, (hl.agg.mean(mt.__y) - 1.0) / hl.agg.mean(mt.__x)
    ])
    mt = mt.annotate_cols(__step1_betas=mt.__initial_betas,
                          __step2_betas=mt.__initial_betas)

    # step 1 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step1, 1.0 /
            (mt.__w_initial_floor * 2.0 *
             (mt.__step1_betas[0] + mt.__step1_betas[1] * mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step1_betas=hl.agg.filter(
            mt.__in_step1,
            hl.agg.linreg(y=mt.__y, x=[1.0, mt.__x], weight=mt.__w).beta))
        mt = mt.annotate_cols(__step1_h2=hl.max(
            hl.min(mt.__step1_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step1_betas=[
            mt.__step1_betas[0], mt.__step1_h2 * hl.agg.mean(mt.__n) / M
        ])

    # step 1 block jackknife
    mt = mt.annotate_cols(__step1_block_betas=hl.agg.array_agg(
        lambda i: hl.agg.filter(
            (mt.__step1_block != i) & mt.__in_step1,
            hl.agg.linreg(y=mt.__y, x=[1.0, mt.__x], weight=mt.__w).beta),
        hl.range(n_blocks)))

    mt = mt.annotate_cols(__step1_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step1_betas - (n_blocks - 1) * x,
        mt.__step1_block_betas))

    mt = mt.annotate_cols(
        __step1_jackknife_mean=hl.map(
            lambda i: hl.mean(
                hl.map(lambda x: x[i], mt.__step1_block_betas_bias_corrected)),
            hl.range(0, __k)),
        __step1_jackknife_variance=hl.map(
            lambda i: (hl.sum(
                hl.map(lambda x: x[i]**2,
                       mt.__step1_block_betas_bias_corrected)) -
                       hl.sum(hl.map(lambda x: x[i],
                                     mt.__step1_block_betas_bias_corrected))**2 /
                       n_blocks) / (n_blocks - 1) / n_blocks,
            hl.range(0, __k)))

    # step 2 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step2,
            1.0 / (mt.__w_initial_floor * 2.0 *
                   (mt.__step2_betas[0] +
                    mt.__step2_betas[1] * mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            hl.agg.filter(
                mt.__in_step2,
                hl.agg.linreg(
                    y=mt.__y -
                    mt.__step1_betas[0], x=[mt.__x], weight=mt.__w).beta[0])
        ])
        mt = mt.annotate_cols(__step2_h2=hl.max(
            hl.min(mt.__step2_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0], mt.__step2_h2 * hl.agg.mean(mt.__n) / M
        ])

    # step 2 block jackknife
    mt = mt.annotate_cols(__step2_block_betas=hl.agg.array_agg(
        lambda i: hl.agg.filter((mt.__step2_block != i) & mt.__in_step2,
                                hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                              x=[mt.__x],
                                              weight=mt.__w).beta[0]),
        hl.range(n_blocks)))

    mt = mt.annotate_cols(__step2_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step2_betas[1] - (n_blocks - 1) * x,
        mt.__step2_block_betas))

    mt = mt.annotate_cols(
        __step2_jackknife_mean=hl.mean(mt.__step2_block_betas_bias_corrected),
        __step2_jackknife_variance=(
            hl.sum(mt.__step2_block_betas_bias_corrected**2) -
            hl.sum(mt.__step2_block_betas_bias_corrected)**2 / n_blocks) /
        (n_blocks - 1) / n_blocks)

    # combine step 1 and step 2 block jackknifes
    mt = mt.annotate_entries(
        __step2_initial_w=1.0 / (mt.__w_initial_floor * 2.0 *
                                 (mt.__initial_betas[0] +
                                  mt.__initial_betas[1] * mt.__x_floor)**2))

    mt = mt.annotate_cols(
        __final_betas=[mt.__step1_betas[0], mt.__step2_betas[1]],
        __c=(hl.agg.sum(mt.__step2_initial_w * mt.__x) /
             hl.agg.sum(mt.__step2_initial_w * mt.__x**2)))

    mt = mt.annotate_cols(__final_block_betas=hl.map(
        lambda i: (mt.__step2_block_betas[i] - mt.__c *
                   (mt.__step1_block_betas[i][0] - mt.__final_betas[0])),
        hl.range(0, n_blocks)))

    mt = mt.annotate_cols(__final_block_betas_bias_corrected=(
        n_blocks * mt.__final_betas[1] -
        (n_blocks - 1) * mt.__final_block_betas))

    mt = mt.annotate_cols(
        __final_jackknife_mean=[
            mt.__step1_jackknife_mean[0],
            hl.mean(mt.__final_block_betas_bias_corrected)
        ],
        __final_jackknife_variance=[
            mt.__step1_jackknife_variance[0],
            (hl.sum(mt.__final_block_betas_bias_corrected**2) -
             hl.sum(mt.__final_block_betas_bias_corrected)**2 / n_blocks) /
            (n_blocks - 1) / n_blocks
        ])

    # convert coefficient to heritability estimate
    mt = mt.annotate_cols(
        phenotype=mt.__y_name,
        mean_chi_sq=hl.agg.mean(mt.__y),
        intercept=hl.struct(estimate=mt.__final_betas[0],
                            standard_error=hl.sqrt(
                                mt.__final_jackknife_variance[0])),
        snp_heritability=hl.struct(
            estimate=(M / hl.agg.mean(mt.__n)) * mt.__final_betas[1],
            standard_error=hl.sqrt((M / hl.agg.mean(mt.__n))**2 *
                                   mt.__final_jackknife_variance[1])))

    # format and return results
    ht = mt.cols()
    ht = ht.key_by(ht.phenotype)
    ht = ht.select(ht.mean_chi_sq, ht.intercept, ht.snp_heritability)

    ht_tmp_file = new_temp_file()
    ht.write(ht_tmp_file)
    ht = hl.read_table(ht_tmp_file)

    return ht
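
The model in the docstring can also be checked numerically. A self-contained sketch with simulated (entirely made-up) summary statistics: chi-squared values are drawn with mean 1 + Na + (N * h2 / M) * l_j, then a plain OLS fit recovers the intercept and, from the slope, the heritability (a simplification of the iteratively reweighted two-step fit implemented above):

import numpy as np

rng = np.random.default_rng(42)
M = 100_000         # number of variants
N = 50_000          # GWAS sample size
h2 = 0.4            # true SNP-heritability
intercept = 1.05    # 1 + Na, i.e. confounding inflation

ld_scores = rng.gamma(shape=2.0, scale=20.0, size=M)
expected_chi_sq = intercept + N * h2 / M * ld_scores
# multiplying by a 1-df chi-squared draw (mean 1) is one simple way
# to add noise with the right expectation
chi_sq = expected_chi_sq * rng.chisquare(df=1, size=M)

X = np.column_stack([np.ones(M), ld_scores])
beta, *_ = np.linalg.lstsq(X, chi_sq, rcond=None)
print("intercept:", beta[0])      # roughly 1.05
print("h2:", beta[1] * M / N)     # roughly 0.4
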
Example #20
    def test_annotate(self):
        schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32))

        rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]},
                {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []},
                {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}]

        kt = hl.Table.parallelize(rows, schema)

        self.assertTrue(kt.annotate()._same(kt))

        result1 = convert_struct_to_dict(kt.annotate(foo=kt.a + 1,
                                                     foo2=kt.a).take(1)[0])

        self.assertDictEqual(result1, {'a': 4,
                                       'b': 1,
                                       'c': 3,
                                       'd': 5,
                                       'e': "hello",
                                       'f': [1, 2, 3],
                                       'foo': 5,
                                       'foo2': 4})

        result3 = convert_struct_to_dict(kt.annotate(
            x1=kt.f.map(lambda x: x * 2),
            x2=kt.f.map(lambda x: [x, x + 1]).flatmap(lambda x: x),
            x3=hl.min(kt.f),
            x4=hl.max(kt.f),
            x5=hl.sum(kt.f),
            x6=hl.product(kt.f),
            x7=kt.f.length(),
            x8=kt.f.filter(lambda x: x == 3),
            x9=kt.f[1:],
            x10=kt.f[:],
            x11=kt.f[1:2],
            x12=kt.f.map(lambda x: [x, x + 1]),
            x13=kt.f.map(lambda x: [[x, x + 1], [x + 2]]).flatmap(lambda x: x),
            x14=hl.cond(kt.a < kt.b, kt.c, hl.null(hl.tint32)),
            x15={1, 2, 3}
        ).take(1)[0])

        self.assertDictEqual(result3, {'a': 4,
                                       'b': 1,
                                       'c': 3,
                                       'd': 5,
                                       'e': "hello",
                                       'f': [1, 2, 3],
                                       'x1': [2, 4, 6], 'x2': [1, 2, 2, 3, 3, 4],
                                       'x3': 1, 'x4': 3, 'x5': 6, 'x6': 6, 'x7': 3, 'x8': [3],
                                       'x9': [2, 3], 'x10': [1, 2, 3], 'x11': [2],
                                       'x12': [[1, 2], [2, 3], [3, 4]],
                                       'x13': [[1, 2], [3], [2, 3], [4], [3, 4], [5]],
                                       'x14': None, 'x15': set([1, 2, 3])})
        kt.annotate(
            x1=kt.a + 5,
            x2=5 + kt.a,
            x3=kt.a + kt.b,
            x4=kt.a - 5,
            x5=5 - kt.a,
            x6=kt.a - kt.b,
            x7=kt.a * 5,
            x8=5 * kt.a,
            x9=kt.a * kt.b,
            x10=kt.a / 5,
            x11=5 / kt.a,
            x12=kt.a / kt.b,
            x13=-kt.a,
            x14=+kt.a,
            x15=kt.a == kt.b,
            x16=kt.a == 5,
            x17=5 == kt.a,
            x18=kt.a != kt.b,
            x19=kt.a != 5,
            x20=5 != kt.a,
            x21=kt.a > kt.b,
            x22=kt.a > 5,
            x23=5 > kt.a,
            x24=kt.a >= kt.b,
            x25=kt.a >= 5,
            x26=5 >= kt.a,
            x27=kt.a < kt.b,
            x28=kt.a < 5,
            x29=5 < kt.a,
            x30=kt.a <= kt.b,
            x31=kt.a <= 5,
            x32=5 <= kt.a,
            x33=(kt.a == 0) & (kt.b == 5),
            x34=(kt.a == 0) | (kt.b == 5),
            x35=False,
            x36=True
        )
Example #21
def main():
    args = parse_args()

    tables = []
    for i, path in enumerate(args.paths):

        ht = import_SJ_out_tab(path)
        ht = ht.key_by("chrom", "start_1based", "end_1based")

        if args.normalize_read_counts:
            ht = ht.annotate_globals(
                unique_reads_in_sample=ht.aggregate(hl.agg.sum(
                    ht.unique_reads)),
                multi_mapped_reads_in_sample=ht.aggregate(
                    hl.agg.sum(ht.multi_mapped_reads)),
            )

        # add 'interval' column
        #ht = ht.annotate(interval=hl.interval(
        #    hl.locus(ht.chrom, ht.start_1based, reference_genome=reference_genome),
        #    hl.locus(ht.chrom, ht.end_1based, reference_genome=reference_genome),))

        tables.append(ht)

    # compute mean
    if args.normalize_read_counts:
        mean_unique_reads_in_sample = sum(
            [hl.eval(ht.unique_reads_in_sample)
             for ht in tables]) / float(len(tables))
        mean_multi_mapped_reads_in_sample = sum(
            [hl.eval(ht.multi_mapped_reads_in_sample)
             for ht in tables]) / float(len(tables))
        print(
            f"mean_unique_reads_in_sample: {mean_unique_reads_in_sample:01f}, mean_multi_mapped_reads_in_sample: {mean_multi_mapped_reads_in_sample:01f}"
        )

    combined_ht = None
    for i, (path, ht) in enumerate(zip(args.paths, tables)):
        print(f"Processing table #{i} out of {len(tables)}")

        if args.normalize_read_counts:
            unique_reads_multiplier = mean_unique_reads_in_sample / float(
                hl.eval(ht.unique_reads_in_sample))
            multi_mapped_reads_multiplier = mean_multi_mapped_reads_in_sample / float(
                hl.eval(ht.multi_mapped_reads_in_sample))
            print(
                f"unique_reads_multiplier: {unique_reads_multiplier:01f}, multi_mapped_reads_multiplier: {multi_mapped_reads_multiplier:01f}"
            )

        ht = ht.annotate(
            strand_counter=hl.or_else(
                hl.switch(ht.strand).when(1, 1).when(2, -1).or_missing(), 0),
            num_samples_with_this_junction=1,
        )

        if args.normalize_read_counts:
            ht = ht.annotate(
                unique_reads=hl.int32(ht.unique_reads *
                                      unique_reads_multiplier),
                multi_mapped_reads=hl.int32(ht.multi_mapped_reads *
                                            multi_mapped_reads_multiplier),
            )

        if combined_ht is None:
            combined_ht = ht
            continue

        print("----")
        print_stats(path, ht)

        combined_ht = combined_ht.join(ht, how="outer")
        combined_ht = combined_ht.transmute(
            strand=hl.or_else(
                combined_ht.strand, combined_ht.strand_1
            ),  ## in rare cases, the strand for the same junction may differ across samples, so use a 2-step process that assigns strand based on majority of samples
            strand_counter=hl.sum([
                combined_ht.strand_counter, combined_ht.strand_counter_1
            ]),  # samples vote on whether strand = 1 (e.g. '+') or 2 (e.g. '-')
            intron_motif=hl.or_else(combined_ht.intron_motif,
                                    combined_ht.intron_motif_1
                                    ),  ## double-check that left == right?
            known_splice_junction=hl.or_else(
                hl.cond((combined_ht.known_splice_junction == 1) |
                        (combined_ht.known_splice_junction_1 == 1), 1, 0),
                0),  ## double-check that left == right?
            unique_reads=hl.sum(
                [combined_ht.unique_reads, combined_ht.unique_reads_1]),
            multi_mapped_reads=hl.sum([
                combined_ht.multi_mapped_reads,
                combined_ht.multi_mapped_reads_1
            ]),
            maximum_overhang=hl.max(
                [combined_ht.maximum_overhang,
                 combined_ht.maximum_overhang_1]),
            num_samples_with_this_junction=hl.sum([
                combined_ht.num_samples_with_this_junction,
                combined_ht.num_samples_with_this_junction_1
            ]),
        )

        combined_ht = combined_ht.checkpoint(
            f"checkpoint{i % 2}.ht", overwrite=True)  #, _read_if_exists=True)

    total_junctions_count = combined_ht.count()
    strand_conflicts_count = combined_ht.filter(
        hl.abs(combined_ht.strand_counter) /
        hl.float(combined_ht.num_samples_with_this_junction) < 0.1,
        keep=True).count()

    # set final strand value to 1 (e.g. '+'), 2 (e.g. '-'), or 0 (unknown) based on the strand seen in the majority of samples
    combined_ht = combined_ht.annotate(
        strand=hl.case().when(combined_ht.strand_counter > 0, 1).when(
            combined_ht.strand_counter < 0, 2).default(0))

    combined_ht = combined_ht.annotate_globals(combined_tables=args.paths,
                                               n_combined_tables=len(
                                                   args.paths))

    if strand_conflicts_count:
        print(
            f"WARNING: Found {strand_conflicts_count} strand_conflicts out of {total_junctions_count} total_junctions"
        )

    # write as HT
    combined_ht = combined_ht.checkpoint(
        f"combined.SJ.out.ht", overwrite=True)  #, _read_if_exists=True)

    ## write as tsv
    output_prefix = f"combined.{len(tables)}_samples{'.normalized_counts' if args.normalize_read_counts else ''}"
    combined_ht = combined_ht.key_by()
    combined_ht.export(f"{output_prefix}.with_header.combined.SJ.out.tab",
                       header=True)
    combined_ht = combined_ht.select(
        "chrom",
        "start_1based",
        "end_1based",
        "strand",
        "intron_motif",
        "known_splice_junction",
        "unique_reads",
        "multi_mapped_reads",
        "maximum_overhang",
    )
    combined_ht.export(f"{output_prefix}.SJ.out.tab", header=False)

    print(
        f"unique_reads_in combined table: {combined_ht.aggregate(hl.agg.sum(combined_ht.unique_reads))}"
    )
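
The strand assignment above reduces to a majority vote: each sample contributes +1 for strand 1 ('+'), -1 for strand 2 ('-'), and 0 otherwise, and the sign of the total decides. A pure-Python sketch with toy values:

# per-sample strand calls for one junction (toy data)
sample_strands = [1, 1, 2, 1, 0]
strand_counter = sum(+1 if s == 1 else -1 if s == 2 else 0
                     for s in sample_strands)

# mirrors the hl.case() above: majority wins, ties stay unknown
if strand_counter > 0:
    strand = 1   # '+'
elif strand_counter < 0:
    strand = 2   # '-'
else:
    strand = 0   # unknown
print(strand_counter, strand)  # 2 1
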
Example #22
def compute_from_full_mt(chr20: bool, overwrite: bool):
    mt = get_gnomad_data('exomes', adj=True, release_samples=True)
    freq_ht = hl.read_table(annotations_ht_path('exomes', 'frequencies'))
    vep_ht = hl.read_table(annotations_ht_path('exomes', 'vep'))
    rf_ht = hl.read_table(annotations_ht_path('exomes', 'rf'))

    if chr20:
        mt, freq_ht, vep_ht, rf_ht = filter_to_chr20([mt, freq_ht, vep_ht, rf_ht])

    vep_ht = vep_ht.annotate(
        vep=get_worst_gene_csq_code_expr(vep_ht.vep).values()
    )

    freq_ht = freq_ht.select(
        freq=freq_ht.freq[:10],
        popmax=freq_ht.popmax
    )
    freq_meta = hl.eval(freq_ht.globals.freq_meta)
    freq_dict = {f['pop']: i for i, f in enumerate(freq_meta[:10]) if 'pop' in f}
    freq_dict['all'] = 0
    freq_dict = hl.literal(freq_dict)
    mt = mt.annotate_rows(
        **freq_ht[mt.row_key],
        vep=vep_ht[mt.row_key].vep,
        filters=rf_ht[mt.row_key].filters
    )
    mt = mt.filter_rows(
        (mt.freq[0].AF <= MAX_FREQ) &
        (hl.len(mt.vep) > 0) &
        (hl.len(mt.filters) == 0)
    )

    mt = mt.filter_entries(mt.GT.is_non_ref())
    mt = mt.select_entries(
        is_het=mt.GT.is_het()
    )

    mt = mt.explode_rows(mt.vep)
    mt = mt.transmute_rows(**mt.vep)

    mt = mt.annotate_cols(
        pop=['all', mt.meta.pop]
    )
    mt = mt.explode_cols(mt.pop)

    mt = mt.group_rows_by(
        'gene_id'
    ).aggregate_rows(
        gene_symbol=hl.agg.take(mt.gene_symbol, 1)[0]
    ).aggregate(
        counts=hl.agg.filter(
            hl.if_else(
                mt.pop == 'all',
                hl.is_defined(mt.popmax) & (mt.popmax.AF <= MAX_FREQ),
                mt.freq[freq_dict[mt.pop]].AF <= MAX_FREQ
            ),
            hl.agg.group_by(
                hl.if_else(
                    mt.pop == 'all',
                    mt.popmax.AF > 0.001,
                    mt.freq[freq_dict[mt.pop]].AF > 0.001
                ),
                hl.struct(
                    hom_csq=hl.agg.filter(~mt.is_het, hl.agg.min(mt.csq)),
                    het_csq=hl.agg.filter(mt.is_het, hl.agg.min(mt.csq)),
                    het_het_csq=hl.sorted(
                        hl.array(
                            hl.agg.filter(mt.is_het, hl.agg.counter(mt.csq))
                        ),
                        key=lambda x: x[0]
                    ).scan(
                        lambda i, j: (j[0], i[1] + j[1]),
                        (0, 0)
                    ).find(
                        lambda x: x[1] > 1
                    )[0]
                )
            )
        )
    )

    mt = mt.annotate_entries(
        counts=hl.struct(
            all=hl.struct(
                hom_csq=hl.min(mt.counts.get(True).hom_csq, mt.counts.get(False).hom_csq),
                het_csq=hl.min(mt.counts.get(True).het_csq, mt.counts.get(False).het_csq),
                het_het_csq=hl.min(
                    mt.counts.get(True).het_het_csq,
                    mt.counts.get(False).het_het_csq,
                    hl.or_missing(
                        hl.is_defined(mt.counts.get(True).het_csq) & hl.is_defined(mt.counts.get(False).het_csq),
                        hl.max(mt.counts.get(True).het_csq, mt.counts.get(False).het_csq)
                    )
                ),
            ),
            af_le_0_001=mt.counts.get(False)
        )
    )

    mt = mt.checkpoint('gs://gnomad-tmp/compound_hets/het_and_hom_per_gene{}.1.mt'.format(
        '.chr20' if chr20 else ''
    ), overwrite=True)

    gene_ht = mt.annotate_rows(
        row_counts=hl.flatten([
            hl.array(
                hl.agg.group_by(
                    mt.pop,
                    hl.struct(
                        csq=csq,
                        af=af,
                        n_hom=hl.agg.count_where(mt.counts[af].hom_csq == csq_i),
                        n_het=hl.agg.count_where(mt.counts[af].het_csq == csq_i),
                        n_het_het=hl.agg.count_where(mt.counts[af].het_het_csq == csq_i)
                    )
                )
            ).filter(
                lambda x: (x[1].n_het > 0) | (x[1].n_hom > 0) | (x[1].n_het_het > 0)
            ).map(
                lambda x: x[1].annotate(
                    pop=x[0]
                )
            )
            for csq_i, csq in enumerate(CSQ_CODES)
            for af in ['all', 'af_le_0_001']
        ])
    ).rows()

    gene_ht = gene_ht.explode('row_counts')
    gene_ht = gene_ht.select(
        'gene_symbol',
        **gene_ht.row_counts
    )

    gene_ht.describe()

    gene_ht = gene_ht.checkpoint(
        'gs://gnomad-lfran/compound_hets/het_and_hom_per_gene{}.ht'.format(
            '.chr20' if chr20 else ''
        ),
        overwrite=overwrite
    )

    gene_ht.flatten().export('gs://gnomad-lfran/compound_hets/het_and_hom_per_gene{}.tsv.gz'.format(
        '.chr20' if chr20 else ''
    ))
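
The least obvious expression above is `het_het_csq`: it sorts the per-consequence counts of heterozygous variants, takes a running total, and returns the first consequence code at which the cumulative count exceeds one, i.e. the consequence code of the second-worst het variant, which bounds what a potential compound-het pair can achieve. A pure-Python sketch of the same idea with made-up counts (lower codes meaning more damaging consequences):

from itertools import accumulate

het_csq_counts = {0: 1, 2: 3}           # csq code -> number of het variants
pairs = sorted(het_csq_counts.items())  # [(0, 1), (2, 3)]

# running (csq, cumulative count), starting from (0, 0) as in the Hail scan
cumulative = accumulate(pairs,
                        lambda acc, kv: (kv[0], acc[1] + kv[1]),
                        initial=(0, 0))
het_het_csq = next((csq for csq, total in cumulative if total > 1), None)
print(het_het_csq)  # 2: the second-worst het variant has csq code 2
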
Example #23
def de_novo(mt: MatrixTable,
            pedigree: Pedigree,
            pop_frequency_prior,
            *,
            min_gq: int = 20,
            min_p: float = 0.05,
            max_parent_ab: float = 0.05,
            min_child_ab: float = 0.20,
            min_dp_ratio: float = 0.10) -> Table:
    r"""Call putative *de novo* events from trio data.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------

    Call de novo events:

    >>> pedigree = hl.Pedigree.read('data/trios.fam')
    >>> priors = hl.import_table('data/gnomadFreq.tsv', impute=True)
    >>> priors = priors.transmute(**hl.parse_variant(priors.Variant)).key_by('locus', 'alleles')
    >>> de_novo_results = hl.de_novo(dataset, pedigree, pop_frequency_prior=priors[dataset.row_key].AF)

    Notes
    -----
    This method assumes the GATK high-throughput sequencing fields exist:
    `GT`, `AD`, `DP`, `GQ`, `PL`.

    This method replicates the functionality of `Kaitlin Samocha's de novo
    caller <https://github.com/ksamocha/de_novo_scripts>`__. The version
    corresponding to git commit ``bde3e40`` is implemented in Hail with her
    permission and assistance.

    This method produces a :class:`.Table` with the following fields:

     - `locus` (``locus``) -- Variant locus.
     - `alleles` (``array<str>``) -- Variant alleles.
     - `id` (``str``) -- Proband sample ID.
     - `prior` (``float64``) -- Site frequency prior. It is the maximum of:
       the computed dataset alternate allele frequency, the
       `pop_frequency_prior` parameter, and the global prior
       ``100 / 3e7``.
     - `proband` (``struct``) -- Proband column fields from `mt`.
     - `father` (``struct``) -- Father column fields from `mt`.
     - `mother` (``struct``) -- Mother column fields from `mt`.
     - `proband_entry` (``struct``) -- Proband entry fields from `mt`.
     - `father_entry` (``struct``) -- Father entry fields from `mt`.
     - `mother_entry` (``struct``) -- Mother entry fields from `mt`.
     - `is_female` (``bool``) -- ``True`` if proband is female.
     - `p_de_novo` (``float64``) -- Unfiltered posterior probability
       that the event is *de novo* rather than a missed heterozygous
       event in a parent.
     - `confidence` (``str``) -- Validation confidence. One of: ``'HIGH'``,
       ``'MEDIUM'``, ``'LOW'``.

    The key of the table is ``['locus', 'alleles', 'id']``.

    The model looks for de novo events in which both parents are homozygous
    reference and the proband is heterozygous. The model makes the simplifying
    assumption that when this configuration ``x = (AA, AA, AB)`` of calls
    occurs, exactly one of the following is true:

     - ``d``: a de novo mutation occurred in the proband and all calls are
       accurate.
     - ``m``: at least one parental allele is actually heterozygous and
       the proband call is accurate.

    We can then estimate the posterior probability of a de novo mutation as:

    .. math::

        \mathrm{P_{\text{de novo}}} = \frac{\mathrm{P}(d\,|\,x)}{\mathrm{P}(d\,|\,x) + \mathrm{P}(m\,|\,x)}

    Applying Bayes rule to the numerator and denominator yields

    .. math::

        \frac{\mathrm{P}(x\,|\,d)\,\mathrm{P}(d)}{\mathrm{P}(x\,|\,d)\,\mathrm{P}(d) +
        \mathrm{P}(x\,|\,m)\,\mathrm{P}(m)}

    The prior on de novo mutation is estimated from the rate in the literature:

    .. math::

        \mathrm{P}(d) = \frac{1\,\text{mutation}}{30{,}000{,}000\,\text{bases}}

    The prior used for at least one alternate allele between the parents
    depends on the alternate allele frequency:

    .. math::

        \mathrm{P}(m) = 1 - (1 - AF)^4

    The likelihoods :math:`\mathrm{P}(x\,|\,d)` and :math:`\mathrm{P}(x\,|\,m)`
    are computed from the PL (genotype likelihood) fields using these
    factorizations:

    .. math::

        \mathrm{P}(x = (AA, AA, AB) \,|\,d) = \Big(
        &\mathrm{P}(x_{\mathrm{father}} = AA \,|\, \mathrm{father} = AA) \\
        \cdot &\mathrm{P}(x_{\mathrm{mother}} = AA \,|\, \mathrm{mother} =
        AA) \\ \cdot &\mathrm{P}(x_{\mathrm{proband}} = AB \,|\,
        \mathrm{proband} = AB) \Big)

    .. math::

        \mathrm{P}(x = (AA, AA, AB) \,|\,m) = \Big( &
        \mathrm{P}(x_{\mathrm{father}} = AA \,|\, \mathrm{father} = AB)
        \cdot \mathrm{P}(x_{\mathrm{mother}} = AA \,|\, \mathrm{mother} =
        AA) \\ + \, &\mathrm{P}(x_{\mathrm{father}} = AA \,|\,
        \mathrm{father} = AA) \cdot \mathrm{P}(x_{\mathrm{mother}} = AA
        \,|\, \mathrm{mother} = AB) \Big) \\ \cdot \,
        &\mathrm{P}(x_{\mathrm{proband}} = AB \,|\, \mathrm{proband} = AB)

    (Technically, the second factorization assumes there is exactly (rather
    than at least) one alternate allele among the parents, which may be
    justified on the grounds that it is typically the most likely case by far.)

    While this posterior probability is a good metric for grouping putative de
    novo mutations by validation likelihood, there exist error modes in
    high-throughput sequencing data that are not appropriately accounted for by
    the phred-scaled genotype likelihoods. To this end, a number of hard filters
    are applied in order to assign validation likelihood.

    These filters are different for SNPs and insertions/deletions. In the below
    rules, the following variables are used:

     - ``DR`` refers to the ratio of the read depth in the proband to the
       combined read depth in the parents.
     - ``AB`` refers to the read allele balance of the proband (number of
       alternate reads divided by total reads).
     - ``AC`` refers to the count of alternate alleles across all individuals
       in the dataset at the site.
     - ``p`` refers to :math:`\mathrm{P_{\text{de novo}}}`.
     - ``min_p`` refers to the ``min_p`` function parameter.

    HIGH-quality SNV:

    .. code-block:: text

        p > 0.99 && AB > 0.3 && DR > 0.2
            or
        p > 0.99 && AB > 0.3 && AC == 1

    MEDIUM-quality SNV:

    .. code-block:: text

        p > 0.5 && AB > 0.3
            or
        p > 0.5 && AB > 0.2 && AC == 1

    LOW-quality SNV:

    .. code-block:: text

        p > min_p && AB > 0.2

    HIGH-quality indel:

    .. code-block:: text

        p > 0.99 && AB > 0.3 && DR > 0.2
            or
        p > 0.99 && AB > 0.3 && AC == 1

    MEDIUM-quality indel:

    .. code-block:: text

        p > 0.5 && AB > 0.3
            or
        p > 0.5 && AB > 0.2 && AC == 1

    LOW-quality indel:

    .. code-block:: text

        p > min_p && AB > 0.2

    Additionally, de novo candidates are not considered if the proband GQ is
    smaller than the ``min_gq`` parameter, if the proband allele balance is
    lower than the ``min_child_ab`` parameter, if the depth ratio between the
    proband and parents is smaller than the ``min_dp_ratio`` parameter, or if
    the allele balance in a parent is above the ``max_parent_ab`` parameter.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        High-throughput sequencing dataset.
    pedigree : :class:`.Pedigree`
        Sample pedigree.
    pop_frequency_prior : :class:`.Float64Expression`
        Expression for population alternate allele frequency prior.
    min_gq
        Minimum proband GQ to be considered for *de novo* calling.
    min_p
        Minimum posterior probability to be considered for *de novo* calling.
    max_parent_ab
        Maximum parent allele balance.
    min_child_ab
        Minimum proband allele balance.
    min_dp_ratio
        Minimum ratio between proband read depth and parental read depth.

    Returns
    -------
    :class:`.Table`
    """
    DE_NOVO_PRIOR = 1 / 30000000
    MIN_POP_PRIOR = 100 / 30000000

    required_entry_fields = {'GT', 'AD', 'DP', 'GQ', 'PL'}
    missing_fields = required_entry_fields - set(mt.entry)
    if missing_fields:
        raise ValueError(f"'de_novo': expected 'MatrixTable' to have at least {required_entry_fields}, "
                         f"missing {missing_fields}")

    mt = mt.annotate_rows(__prior=pop_frequency_prior,
                          __alt_alleles=hl.agg.sum(mt.GT.n_alt_alleles()),
                          __total_alleles=2 * hl.agg.sum(hl.is_defined(mt.GT)))
    # subtract 1 from __alt_alleles to correct for the observed genotype
    mt = mt.annotate_rows(__site_freq=hl.max((mt.__alt_alleles - 1) / mt.__total_alleles, mt.__prior, MIN_POP_PRIOR))
    mt = require_biallelic(mt, 'de_novo')

    # FIXME check that __site_freq is between 0 and 1 when possible in expr
    tm = trio_matrix(mt, pedigree, complete_trios=True)

    autosomal = tm.locus.in_autosome_or_par() | (tm.locus.in_x_nonpar() & tm.is_female)
    hemi_x = tm.locus.in_x_nonpar() & ~tm.is_female
    hemi_y = tm.locus.in_y_nonpar() & ~tm.is_female
    hemi_mt = tm.locus.in_mito() & tm.is_female

    is_snp = hl.is_snp(tm.alleles[0], tm.alleles[1])
    n_alt_alleles = tm.__alt_alleles
    prior = tm.__site_freq
    het_hom_hom = tm.proband_entry.GT.is_het() & tm.father_entry.GT.is_hom_ref() & tm.mother_entry.GT.is_hom_ref()
    kid_ad_fail = tm.proband_entry.AD[1] / hl.sum(tm.proband_entry.AD) < min_child_ab

    failure = hl.null(hl.tstruct(p_de_novo=hl.tfloat64, confidence=hl.tstr))

    kid = tm.proband_entry
    dad = tm.father_entry
    mom = tm.mother_entry

    kid_linear_pl = 10 ** (-kid.PL / 10)
    kid_pp = hl.bind(lambda x: x / hl.sum(x), kid_linear_pl)

    dad_linear_pl = 10 ** (-dad.PL / 10)
    dad_pp = hl.bind(lambda x: x / hl.sum(x), dad_linear_pl)

    mom_linear_pl = 10 ** (-mom.PL / 10)
    mom_pp = hl.bind(lambda x: x / hl.sum(x), mom_linear_pl)

    kid_ad_ratio = kid.AD[1] / hl.sum(kid.AD)
    dp_ratio = kid.DP / (dad.DP + mom.DP)

    def call_auto(kid_pp, dad_pp, mom_pp, kid_ad_ratio):
        p_data_given_dn = dad_pp[0] * mom_pp[0] * kid_pp[1] * DE_NOVO_PRIOR
        p_het_in_parent = 1 - (1 - prior) ** 4
        p_data_given_missed_het = (dad_pp[1] * mom_pp[0] + dad_pp[0] * mom_pp[1]) * kid_pp[1] * p_het_in_parent
        p_de_novo = p_data_given_dn / (p_data_given_dn + p_data_given_missed_het)

        def solve(p_de_novo):
            return (
                hl.case()
                    .when(kid.GQ < min_gq, failure)
                    .when((kid.DP / (dad.DP + mom.DP) < min_dp_ratio) |
                          ~(kid_ad_ratio >= min_child_ab), failure)
                    .when((hl.sum(mom.AD) == 0) | (hl.sum(dad.AD) == 0), failure)
                    .when((mom.AD[1] / hl.sum(mom.AD) > max_parent_ab) |
                          (dad.AD[1] / hl.sum(dad.AD) > max_parent_ab), failure)
                    .when(p_de_novo < min_p, failure)
                    .when(~is_snp, hl.case()
                          .when((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1),
                                hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                          .when((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles <= 5),
                                hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                          .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.2),
                                hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                          .or_missing())
                    .default(hl.case()
                             .when(((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (dp_ratio > 0.2)) |
                                   ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1)) |
                                   ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles < 10) & (kid.DP > 10)),
                                   hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                             .when((p_de_novo > 0.5) & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)),
                                   hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                             .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.2),
                                   hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                             .or_missing()
                             )
            )

        return hl.bind(solve, p_de_novo)

    def call_hemi(kid_pp, parent, parent_pp, kid_ad_ratio):
        p_data_given_dn = parent_pp[0] * kid_pp[1] * DE_NOVO_PRIOR
        p_het_in_parent = 1 - (1 - prior) ** 4
        p_data_given_missed_het = (parent_pp[1] + parent_pp[2]) * kid_pp[2] * p_het_in_parent
        p_de_novo = p_data_given_dn / (p_data_given_dn + p_data_given_missed_het)

        def solve(p_de_novo):
            return (
                hl.case()
                    .when(kid.GQ < min_gq, failure)
                    .when((kid.DP / (parent.DP) < min_dp_ratio) |
                          (kid_ad_ratio < min_child_ab), failure)
                    .when((hl.sum(parent.AD) == 0), failure)
                    .when(parent.AD[1] / hl.sum(parent.AD) > max_parent_ab, failure)
                    .when(p_de_novo < min_p, failure)
                    .when(~is_snp, hl.case()
                          .when((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1),
                                hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                          .when((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles <= 5),
                                hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                          .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.3),
                                hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                          .or_missing())
                    .default(hl.case()
                             .when(((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (dp_ratio > 0.2)) |
                                   ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1)) |
                                   ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles < 10) & (kid.DP > 10)),
                                   hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                             .when((p_de_novo > 0.5) & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)),
                                   hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                             .when((p_de_novo > 0.05) & (kid_ad_ratio > 0.2),
                                   hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                             .or_missing()
                             )
            )

        return hl.bind(solve, p_de_novo)

    de_novo_call = (
        hl.case()
            .when(~het_hom_hom | kid_ad_fail, failure)
            .when(autosomal, hl.bind(call_auto, kid_pp, dad_pp, mom_pp, kid_ad_ratio))
            .when(hemi_x | hemi_mt, hl.bind(call_hemi, kid_pp, mom, mom_pp, kid_ad_ratio))
            .when(hemi_y, hl.bind(call_hemi, kid_pp, dad, dad_pp, kid_ad_ratio))
            .or_missing()
    )

    tm = tm.annotate_entries(__call=de_novo_call)
    tm = tm.filter_entries(hl.is_defined(tm.__call))
    entries = tm.entries()
    return (entries.select('__site_freq',
                           'proband',
                           'father',
                           'mother',
                           'proband_entry',
                           'father_entry',
                           'mother_entry',
                           'is_female',
                           **entries.__call)
            .rename({'__site_freq': 'prior'}))
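
The posterior at the heart of `call_auto` can be reproduced for a single trio in a few lines. A minimal numpy sketch with made-up PL values and site prior; it covers only the probability model, not the hard filters applied above:

import numpy as np

DE_NOVO_PRIOR = 1 / 30000000

def pl_to_pp(pl):
    """Normalized linear-scale genotype probabilities from phred-scaled PLs."""
    linear = 10 ** (-np.asarray(pl, dtype=float) / 10)
    return linear / linear.sum()

dad_pp = pl_to_pp([0, 60, 600])   # confidently hom-ref father
mom_pp = pl_to_pp([0, 60, 600])   # confidently hom-ref mother
kid_pp = pl_to_pp([90, 0, 900])   # confidently het proband
prior = 1e-4                      # site alternate allele frequency prior

p_data_given_dn = dad_pp[0] * mom_pp[0] * kid_pp[1] * DE_NOVO_PRIOR
p_het_in_parent = 1 - (1 - prior) ** 4
p_data_given_missed_het = ((dad_pp[1] * mom_pp[0] + dad_pp[0] * mom_pp[1])
                           * kid_pp[1] * p_het_in_parent)
p_de_novo = p_data_given_dn / (p_data_given_dn + p_data_given_missed_het)
print(p_de_novo)  # close to 1 for this confident configuration
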