Example 1
def merge_alleles(alleles):
    from hail.expr.functions import _num_allele_type, _allele_ints
    return hl.rbind(
        alleles.map(lambda a: hl.or_else(a[0], ''))
               .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
        lambda ref:
        hl.rbind(
            alleles.map(
                lambda al: hl.rbind(
                    al[0],
                    lambda r:
                    hl.array([ref]).extend(
                        al[1:].map(
                            lambda a:
                            hl.rbind(
                                _num_allele_type(r, a),
                                lambda at:
                                hl.cond(
                                    (_allele_ints['SNP'] == at) |
                                    (_allele_ints['Insertion'] == at) |
                                    (_allele_ints['Deletion'] == at) |
                                    (_allele_ints['MNP'] == at) |
                                    (_allele_ints['Complex'] == at),
                                    a + ref[hl.len(r):],
                                    a)))))),
            lambda lal:
            hl.struct(
                globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                local=lal)))
Example 2
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        filters=hl.set(hl.flatten(ts.data.map(lambda d: hl.array(d.filters)))),
        info=hl.struct(
            DP=hl.sum(ts.data.map(lambda d: d.info.DP)),
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                          .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(lambda e: renumber_entry(e, old_to_new)),
                            hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                lambda j: combined_allele_index[tmp.data[i].alleles[j]])))),
            hl.dict(hl.range(0, hl.len(tmp.alleles)).map(
                lambda j: hl.tuple([tmp.alleles[j], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))

    return tmp.drop('data', 'g')
Example 3
    def recur_expr(expr, path):
        d = {}
        missingness = append_agg(hl.agg.count_where(hl.is_missing(expr)))
        d['type'] = lambda _: str(expr.dtype)
        d['missing'] = lambda results: (
            f'{results[missingness]} values ({pct(results[missingness] / results[count])})')

        t = expr.dtype

        if t in (hl.tint32, hl.tint64, hl.tfloat32, hl.tfloat64):
            stats = append_agg(hl.agg.stats(expr))
            if t in (hl.tint32, hl.tint64):
                d['minimum'] = lambda results: format(map_int(results[stats]['min']))
                d['maximum'] = lambda results: format(map_int(results[stats]['max']))
                d['sum'] = lambda results: format(map_int(results[stats]['sum']))
            else:
                d['minimum'] = lambda results: format(results[stats]['min'])
                d['maximum'] = lambda results: format(results[stats]['max'])
                d['sum'] = lambda results: format(results[stats]['sum'])
            d['mean'] = lambda results: format(results[stats]['mean'])
            d['stdev'] = lambda results: format(results[stats]['stdev'])
        elif t == hl.tbool:
            counter = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr)))
            d['counts'] = lambda results: format(results[counter])
        elif t == hl.tstr:
            size = append_agg(hl.agg.stats(hl.len(expr)))
            take = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.take(expr, 5)))
            d['minimum size'] = lambda results: format(map_int(results[size]['min']))
            d['maximum size'] = lambda results: format(map_int(results[size]['max']))
            d['mean size'] = lambda results: format(results[size]['mean'])
            d['sample values'] = lambda results: format(results[take])
        elif t == hl.tcall:
            ploidy_counts = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr.ploidy)))
            phased_counts = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr.phased)))
            n_hom_ref = append_agg(hl.agg.count_where(expr.is_hom_ref()))
            n_hom_var = append_agg(hl.agg.count_where(expr.is_hom_var()))
            n_het = append_agg(hl.agg.count_where(expr.is_het()))
            d['homozygous reference'] = lambda results: format(results[n_hom_ref])
            d['heterozygous'] = lambda results: format(results[n_het])
            d['homozygous variant'] = lambda results: format(results[n_hom_var])
            d['ploidy'] = lambda results: format(results[ploidy_counts])
            d['phased'] = lambda results: format(results[phased_counts])
        elif isinstance(t, hl.tlocus):
            contig_counts = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr.contig)))
            d['contig counts'] = lambda results: format(results[contig_counts])
        elif isinstance(t, (hl.tset, hl.tdict, hl.tarray)):
            size = append_agg(hl.agg.stats(hl.len(expr)))
            d['minimum size'] = lambda results: format(map_int(results[size]['min']))
            d['maximum size'] = lambda results: format(map_int(results[size]['max']))
            d['mean size'] = lambda results: format(results[size]['mean'])
        to_print.append((path, d))
        if isinstance(t, hl.ttuple):
            for i in range(len(expr)):
                recur_expr(expr[i], f'{path} / {i}')
        if isinstance(t, hl.tstruct):
            for k, v in expr.items():
                recur_expr(v, f'{path} / {repr(k)[1:-1]}')
Example 4
def set_female_y_metrics_to_na_expr(
        t: Union[hl.Table, hl.MatrixTable]) -> hl.expr.ArrayExpression:
    """
    Set Y-variant frequency callstats for female-specific metrics to missing structs.

    .. note:: Requires freq, freq_meta, and freq_index_dict annotations to be present in Table or MatrixTable

    :param t: Table or MatrixTable for which to adjust female metrics
    :return: Hail array expression to set female Y-variant metrics to missing values
    """
    female_idx = hl.map(
        lambda x: t.freq_index_dict[x],
        hl.filter(lambda x: x.contains("XX"), t.freq_index_dict.keys()),
    )
    freq_idx_range = hl.range(hl.len(t.freq_meta))

    new_freq_expr = hl.if_else(
        (t.locus.in_y_nonpar() | t.locus.in_y_par()),
        hl.map(
            lambda x: hl.if_else(female_idx.contains(x),
                                 missing_callstats_expr(), t.freq[x]),
            freq_idx_range,
        ),
        t.freq,
    )

    return new_freq_expr
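A minimal usage sketch, assuming a Table ht that already carries the required freq, freq_meta, and freq_index_dict annotations: the returned array expression is simply written back over freq.

ht = ht.annotate(freq=set_female_y_metrics_to_na_expr(ht))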
Example 5
def create_gene_map_ht(ht, check_gene_contigs=False):
    from gnomad.utils.vep import process_consequences

    ht = process_consequences(ht)
    ht = ht.explode(ht.vep.worst_csq_by_gene_canonical)
    ht = ht.annotate(
        variant_id=ht.locus.contig + ':' + hl.str(ht.locus.position) + '_' +
        ht.alleles[0] + '/' + ht.alleles[1],
        annotation=annotation_case_builder(ht.vep.worst_csq_by_gene_canonical))
    if check_gene_contigs:
        gene_contigs = ht.group_by(
            gene_id=ht.vep.worst_csq_by_gene_canonical.gene_id,
            gene_symbol=ht.vep.worst_csq_by_gene_canonical.gene_symbol,
        ).aggregate(contigs=hl.agg.collect_as_set(ht.locus.contig))
        assert gene_contigs.all(hl.len(gene_contigs.contigs) == 1)

    gene_map_ht = ht.group_by(
        gene_id=ht.vep.worst_csq_by_gene_canonical.gene_id,
        gene_symbol=ht.vep.worst_csq_by_gene_canonical.gene_symbol,
    ).partition_hint(100).aggregate(
        interval=hl.interval(
            start=hl.locus(hl.agg.take(ht.locus.contig, 1)[0], hl.agg.min(ht.locus.position)),
            end=hl.locus(hl.agg.take(ht.locus.contig, 1)[0], hl.agg.max(ht.locus.position))),
        variants=hl.agg.group_by(ht.annotation, hl.agg.collect(ht.variant_id)),
    )
    return gene_map_ht
Example 6
def annotate_variants_with_mnvs(variants_path, mnvs_path):
    ds = hl.read_table(mnvs_path)

    ds = ds.select("changes_amino_acids_for_snvs", "constituent_snvs", "constituent_snv_ids", "n_individuals")

    ds = ds.explode(ds.constituent_snvs, "snv")
    ds = ds.annotate(
        locus=hl.locus(ds.snv.chrom, ds.snv.pos, reference_genome="GRCh37"), alleles=[ds.snv.ref, ds.snv.alt]
    )
    ds = ds.group_by(ds.locus, ds.alleles).aggregate(multi_nucleotide_variants=hl.agg.collect(ds.row.drop("snv")))

    variants = hl.read_table(variants_path)

    variants = variants.annotate(multi_nucleotide_variants=ds[variants.key].multi_nucleotide_variants)
    variants = variants.annotate(
        flags=hl.if_else(
            hl.len(variants.multi_nucleotide_variants) > 0,
            variants.flags.add("mnv"),
            variants.flags,
            missing_false=True,
        ),
        multi_nucleotide_variants=variants.multi_nucleotide_variants.map(
            lambda mnv: mnv.select(
                combined_variant_id=mnv.variant_id,
                changes_amino_acids=mnv.changes_amino_acids_for_snvs.contains(variants.variant_id),
                n_individuals=mnv.n_individuals,
                other_constituent_snvs=mnv.constituent_snv_ids.filter(lambda snv_id: snv_id != variants.variant_id),
            )
        ),
    )

    return variants
Example 7
def load_dob_ht(pre_phesant_tsv_path,
                key_name='userId',
                year_field='x34_0_0',
                month_field='x52_0_0',
                recruitment_center_field='x54_0_0',
                quote=None):
    dob_ht = hl.import_table(pre_phesant_tsv_path,
                             impute=False,
                             min_partitions=100,
                             missing='',
                             key=key_name,
                             quote=quote,
                             types={
                                 key_name: hl.tint32,
                                 recruitment_center_field: hl.tint32
                             })
    year_field, month_field = dob_ht[year_field], dob_ht[month_field]
    month_field = hl.cond(
        hl.len(month_field) == 1, '0' + month_field, month_field)
    dob_ht = dob_ht.select(
        date_of_birth=hl.experimental.strptime(
            year_field + month_field + '15 00:00:00', '%Y%m%d %H:%M:%S', 'GMT'),
        month_of_birth=month_field,
        year_of_birth=year_field,
        recruitment_center=dob_ht[recruitment_center_field])
    return dob_ht
Example 8
def transform_one(mt: MatrixTable) -> MatrixTable:
    """transforms a gvcf into a form suitable for combining"""
    mt = mt.annotate_entries(
        # local (alt) allele index into global (alt) alleles
        LA=hl.range(0, hl.len(mt.alleles)),
        END=mt.info.END,
        BaseQRankSum=mt.info['BaseQRankSum'],
        ClippingRankSum=mt.info['ClippingRankSum'],
        MQ=mt.info['MQ'],
        MQRankSum=mt.info['MQRankSum'],
        ReadPosRankSum=mt.info['ReadPosRankSum'],
    )
    mt = mt.annotate_rows(
        info=mt.info.annotate(
            SB_TABLE=hl.array([
                hl.agg.sum(mt.entry.SB[0]),
                hl.agg.sum(mt.entry.SB[1]),
                hl.agg.sum(mt.entry.SB[2]),
                hl.agg.sum(mt.entry.SB[3]),
            ])
        ).select(
            "MQ_DP",
            "QUALapprox",
            "RAW_MQ",
            "VarDP",
            "SB_TABLE",
        ))
    mt = mt.transmute_entries(
        LGT=mt.GT,
        LAD=mt.AD[0:],  # requiredness issues :'(
        LPL=mt.PL[0:],
        LPGT=mt.PGT)
    mt = mt.drop('SB', 'qual', 'filters')

    return mt
Example 9
def remove_FT_values(
    mt: hl.MatrixTable,
    filters_to_remove: list = [
        'possible_numt', 'mt_many_low_hets', 'FAIL', 'blacklisted_site'
    ]
) -> hl.MatrixTable:
    """Removes the FT filters specified in filters_to_remove
    
    By default, this function removes the 'possible_numt', 'mt_many_low_hets', and 'FAIL' filters (because these filters were found to have low performance), 
    and the 'blacklisted_site' filter because this filter did not always behave as expected in early GATK versions (can be replaced with apply_mito_artifact_filter function)

    :param hl.MatrixTable mt:  MatrixTable
    :param list filters_to_remove: list of FT filters that should be removed from the entries
    
    :return: MatrixTable with certain FT filters removed
    :rtype: MatrixTable
    """

    filters_to_remove = hl.set(filters_to_remove)
    mt = mt.annotate_entries(
        FT=hl.array((mt.FT).difference(filters_to_remove)))

    # if no filters exist after removing those specified above, set the FT field to PASS
    mt = mt.annotate_entries(
        FT=hl.if_else(hl.len(mt.FT) == 0, ["PASS"], mt.FT))

    return mt
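A hedged usage sketch, assuming mt carries an FT entry field of type set<str>; the custom filter name is hypothetical.

mt = remove_FT_values(mt, filters_to_remove=['possible_numt', 'FAIL', 'my_custom_filter'])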
Example 10
def _import_clinvar(**kwargs) -> hl.Table:
    clinvar = import_sites_vcf(**kwargs)
    clinvar = clinvar.filter(
        hl.len(clinvar.alleles) > 1
    )  # Get around problematic single entry in alleles array in the clinvar vcf
    clinvar = vep_or_lookup_vep(clinvar, reference="GRCh38")
    return clinvar
Example 11
def main(args):

    # Read mt
    mt = hl.read_matrix_table(args.matrixtable)
    # pca_scores_pop
    pca_scores_pop = hl.read_table(args.pca_scores_population)

    # annotate mt with pop and superpop
    mt = mt.annotate_cols(assigned_pop=pca_scores_pop[mt.s].pop)

    # do sample_qc
    # calculate and annotate with metric heterozygosity
    mt_with_sampleqc = hl.sample_qc(mt, name='sample_qc')

    mt_with_sampleqc = mt_with_sampleqc.annotate_cols(sample_qc=mt_with_sampleqc.sample_qc.annotate(
        heterozygosity_rate=mt_with_sampleqc.sample_qc.n_het/mt_with_sampleqc.sample_qc.n_called))
    # save sample_qc and heterozygosity table as ht table
    mt_with_sampleqc.write(
        f"{args.output_dir}/ddd-elgh-ukbb/mt_pops_superpops_sampleqc.mt", overwrite=True)
    mt_with_sampleqc.cols().write(
        f"{args.output_dir}/ddd-elgh-ukbb/mt_pops_superpops_sampleqc.ht",  overwrite=True)
    pop_ht = hl.read_table(
        f"{args.output_dir}/ddd-elgh-ukbb/mt_pops_superpops_sampleqc.ht")
    # run function on metrics including heterozygosity first for pops:
    qc_metrics = ['heterozygosity_rate', 'n_snp', 'r_ti_tv',
                  'r_insertion_deletion', 'n_insertion', 'n_deletion', 'r_het_hom_var']
    pop_filter_ht = compute_stratified_metrics_filter(
        pop_ht, qc_metrics, ['assigned_pop'])
    pop_ht = pop_ht.annotate_globals(**hl.eval(pop_filter_ht.globals))
    pop_ht = pop_ht.annotate(**pop_filter_ht[pop_ht.key]).persist()

    checkpoint = pop_ht.aggregate(hl.agg.count_where(
        hl.len(pop_ht.qc_metrics_filters) == 0))
    logger.info(f'{checkpoint} exome samples found passing pop filtering')
    pop_ht.write(f"{args.output_dir}/ddd-elgh-ukbb/mt_pops_QC_filters.ht")
Example 12
def default_generate_trio_stats(mt: hl.MatrixTable) -> hl.Table:
    """
    Default function to run `generate_trio_stats_expr` to get trio stats stratified by raw and adj

    .. note::

        Expects that `mt` is a trio matrix table that was annotated with adj. If dealing with
        a sparse MT, `hl.experimental.densify` must be run first.

    :param mt: A Trio Matrix Table returned from `hl.trio_matrix`. Must be dense
    :return: Table with trio stats
    """
    mt = mt.filter_rows(hl.len(mt.alleles) == 2)
    logger.info(f"Generating trio stats using {mt.count_cols()} trios.")
    trio_adj = mt.proband_entry.adj & mt.father_entry.adj & mt.mother_entry.adj

    ht = mt.select_rows(
        **generate_trio_stats_expr(
            mt,
            transmitted_strata={"raw": True, "adj": trio_adj},
            de_novo_strata={"raw": True, "adj": trio_adj},
            ac_strata={"raw": True, "adj": trio_adj},
            proband_is_female_expr=mt.is_female,
        )
    ).rows()

    return ht
Example 13
def import_clinvar_vcf(vcf_path, reference_genome):
    if reference_genome not in ("GRCh37", "GRCh38"):
        raise ValueError("Unsupported reference genome: " + str(reference_genome))

    clinvar_release_date = _parse_clinvar_release_date(vcf_path)

    # contigs in the ClinVar GRCh38 VCF are not prefixed with "chr"
    contig_recoding = None
    if reference_genome == "GRCh38":
        ref = hl.get_reference("GRCh38")
        contig_recoding = {
            ref_contig.replace("chr", ""): ref_contig for ref_contig in ref.contigs if "chr" in ref_contig
        }

    ds = hl.import_vcf(
        vcf_path,
        reference_genome=reference_genome,
        contig_recoding=contig_recoding,
        min_partitions=2000,
        force_bgz=True,
        drop_samples=True,
        skip_invalid_loci=True,
    ).rows()

    ds = ds.annotate_globals(version=clinvar_release_date)

    # Verify assumption that there are no multi-allelic variants and that splitting is not necessary.
    n_multiallelic_variants = ds.aggregate(hl.agg.filter(hl.len(ds.alleles) > 2, hl.agg.count()))
    assert n_multiallelic_variants == 0, "ClinVar VCF contains multi-allelic variants"

    return ds
Example 14
def vep_protein_domain_ann_expr(
        s: hl.expr.StringExpression) -> hl.expr.DictExpression:
    """
    Parse and annotate protein domain(s) from VEP annotation.
    Expected StringExpression as input (e.g. 'Pfam:PF13853&Prints:PR00237&PROSITE_profiles:PS50262')
    It will generate a dict<k,v> where keys (k) represent source/database and values (v) the annotated domain_id.

    :param s: hl.expr.StringExpression
    :return: hl.expr.DictExpression
    """
    a1 = s.split(delim="&")

    # keep only well-annotated domain(s) (i.e. <source:domain_id>)
    a2 = a1.map(lambda x: x.split(delim=":"))
    a2 = a2.filter(lambda x: x.length() == 2)

    d = (hl.case()
         .when(hl.len(a2) > 0,
               hl.dict(hl.zip(
                   a2.map(lambda x: x[0]),  # TODO: optimize by scanning the array just once
                   a2.map(lambda x: x[1]))))
         .or_missing())

    return d
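A small sketch of how the parser might be exercised on the docstring's example input, assuming an initialized Hail session:

import hail as hl

s = hl.literal('Pfam:PF13853&Prints:PR00237&PROSITE_profiles:PS50262')
# Expected result: {'Pfam': 'PF13853', 'Prints': 'PR00237', 'PROSITE_profiles': 'PS50262'}
print(hl.eval(vep_protein_domain_ann_expr(s)))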
Example 15
def add_popmax_expr(freq: hl.expr.ArrayExpression,
                    freq_meta: hl.expr.ArrayExpression,
                    populations: Set[str]) -> hl.expr.StructExpression:
    """
    Calculates popmax (the frequency entry with the highest AF among the specified populations)

    :param ArrayExpression freq: ArrayExpression of Structs with fields ['AC', 'AF', 'AN', 'homozygote_count']
    :param ArrayExpression freq_meta: ArrayExpression of meta dictionaries corresponding to freq
    :param set of str populations: Set of populations over which to calculate popmax
    :return: Struct containing popmax frequency data and the popmax population
    :rtype: StructExpression
    """
    pops_to_use = hl.literal(populations)
    freq = hl.map(lambda x: x[0].annotate(meta=x[1]), hl.zip(freq, freq_meta))
    freq_filtered = hl.filter(
        lambda f: (f.meta.size() == 2) & (f.meta.get('group') == 'adj') &
        pops_to_use.contains(f.meta.get('pop')) & (f.AC > 0), freq)
    sorted_freqs = hl.sorted(freq_filtered, key=lambda x: x.AF, reverse=True)
    return hl.or_missing(
        hl.len(sorted_freqs) > 0,
        hl.struct(AC=sorted_freqs[0].AC,
                  AF=sorted_freqs[0].AF,
                  AN=sorted_freqs[0].AN,
                  homozygote_count=sorted_freqs[0].homozygote_count,
                  pop=sorted_freqs[0].meta['pop']))
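A hedged usage sketch; the population labels below are illustrative, not prescribed by the function.

ht = ht.annotate(popmax=add_popmax_expr(ht.freq, ht.freq_meta,
                                        populations={'afr', 'amr', 'eas', 'nfe'}))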
Example 16
    def find_worst_transcript_consequence(
            tcl: hl.expr.ArrayExpression) -> hl.expr.StructExpression:
        """
        Gets worst transcript_consequence from an array of em
        """
        flag_score = 500
        no_flag_score = flag_score * (1 + penalize_flags)

        def csq_score(tc):
            return csq_dict[csqs.find(
                lambda x: x == tc.most_severe_consequence)]

        tcl = tcl.map(lambda tc: tc.annotate(
            csq_score=hl.case(missing_false=True)
            .when((tc.lof == 'HC') & (tc.lof_flags == ''), csq_score(tc) - no_flag_score)
            .when((tc.lof == 'HC') & (tc.lof_flags != ''), csq_score(tc) - flag_score)
            .when(tc.lof == 'LC', csq_score(tc) - 10)
            .when(tc.polyphen_prediction == 'probably_damaging', csq_score(tc) - 0.5)
            .when(tc.polyphen_prediction == 'possibly_damaging', csq_score(tc) - 0.25)
            .when(tc.polyphen_prediction == 'benign', csq_score(tc) - 0.1)
            .default(csq_score(tc))))
        return hl.or_missing(
            hl.len(tcl) > 0,
            hl.sorted(tcl, lambda x: x.csq_score)[0])
Example 17
def prepare_clinvar_variants(vcf_path, reference_genome):
    ds = import_clinvar_vcf(vcf_path, reference_genome)

    # There are some variants with only one entry in alleles, ignore them for now.
    # These could be displayed in the ClinVar track even though they will never match a gnomAD variant.
    ds = ds.filter(hl.len(ds.alleles) == 2)

    ds = hl.vep(ds)

    ds = ds.select(
        clinical_significance=hl.sorted(ds.info.CLNSIG, key=lambda s: s.replace("^_", "z")).map(
            lambda s: s.replace("^_", "")
        ),
        clinvar_variation_id=ds.rsid,
        gold_stars=get_gold_stars(ds.info.CLNREVSTAT),
        review_status=hl.sorted(ds.info.CLNREVSTAT, key=lambda s: s.replace("^_", "z")).map(
            lambda s: s.replace("^_", "")
        ),
        vep=ds.vep,
    )

    ds = ds.annotate(
        chrom=normalized_contig(ds.locus.contig), variant_id=variant_id(ds.locus, ds.alleles), xpos=x_position(ds.locus)
    )

    return ds
Example 18
def generate_allele_data(ht: hl.Table) -> hl.Table:
    """
    Returns bi-allelic sites HT with the following annotations:
     - allele_data (nonsplit_alleles, has_star, variant_type, and n_alt_alleles)

    :param Table ht: Full unsplit HT
    :return: Table with allele data annotations
    :rtype: Table
    """
    ht = ht.select()
    allele_data = hl.struct(nonsplit_alleles=ht.alleles,
                            has_star=hl.any(lambda a: a == "*", ht.alleles))
    ht = ht.annotate(allele_data=allele_data.annotate(
        **add_variant_type(ht.alleles)))

    ht = hl.split_multi_hts(ht)
    ht = ht.filter(hl.len(ht.alleles) > 1)
    allele_type = (hl.case()
                   .when(hl.is_snp(ht.alleles[0], ht.alleles[1]), "snv")
                   .when(hl.is_insertion(ht.alleles[0], ht.alleles[1]), "ins")
                   .when(hl.is_deletion(ht.alleles[0], ht.alleles[1]), "del")
                   .default("complex"))
    ht = ht.annotate(allele_data=ht.allele_data.annotate(
        allele_type=allele_type,
        was_mixed=ht.allele_data.variant_type == "mixed"))
    return ht
Example 19
def transform_one(mt: MatrixTable) -> MatrixTable:
    """transforms a gvcf into a form suitable for combining"""
    mt = mt.annotate_entries(
        # local (alt) allele index into global (alt) alleles
        LA=hl.range(0,
                    hl.len(mt.alleles) - 1),
        END=mt.info.END,
        PL=mt['PL'][0:],
        BaseQRankSum=mt.info['BaseQRankSum'],
        ClippingRankSum=mt.info['ClippingRankSum'],
        MQ=mt.info['MQ'],
        MQRankSum=mt.info['MQRankSum'],
        ReadPosRankSum=mt.info['ReadPosRankSum'],
    )
    mt = mt.annotate_rows(info=mt.info.annotate(
        DP=hl.agg.sum(mt.entry.DP),
        SB=hl.agg.array_sum(mt.entry.SB),
    ).select(
        "DP",
        "MQ_DP",
        "QUALapprox",
        "RAW_MQ",
        "VarDP",
        "SB",
    ))
    mt = mt.drop('SB', 'qual')

    return mt
Example 20
def multi_way_union_mts(mts: list, tmp_dir: str,
                        chunk_size: int) -> hl.MatrixTable:
    """Joins MatrixTables in the provided list

    :param list mts: list of MatrixTables to join together
    :param str tmp_dir: path to temporary directory for intermediate results
    :param int chunk_size: number of MatrixTables to join per chunk

    :return: joined MatrixTable
    :rtype: MatrixTable
    """

    staging = [mt.localize_entries("__entries", "__cols") for mt in mts]
    stage = 0
    while len(staging) > 1:
        n_jobs = int(math.ceil(len(staging) / chunk_size))
        info(f"multi_way_union_mts: stage {stage}: {n_jobs} total jobs")
        next_stage = []
        for i in range(n_jobs):
            to_merge = staging[chunk_size * i:chunk_size * (i + 1)]
            info(
                f"multi_way_union_mts: stage {stage} / job {i}: merging {len(to_merge)} inputs"
            )

            merged = hl.Table.multi_way_zip_join(to_merge, "__entries",
                                                 "__cols")
            merged = merged.annotate(__entries=hl.flatten(
                hl.range(hl.len(merged.__entries)).map(
                    lambda i: hl.coalesce(
                        merged.__entries[i].__entries,
                        hl.range(hl.len(merged.__cols[i].__cols)).map(
                            lambda j: hl.null(merged.__entries.__entries.dtype.element_type.element_type))))))
            merged = merged.annotate_globals(
                __cols=hl.flatten(merged.__cols.map(lambda x: x.__cols)))

            next_stage.append(
                merged.checkpoint(os.path.join(tmp_dir,
                                               f"stage_{stage}_job_{i}.ht"),
                                  overwrite=True))
        info(f"done stage {stage}")
        stage += 1
        staging.clear()
        staging.extend(next_stage)

    return (staging[0]._unlocalize_entries(
        "__entries", "__cols", list(mts[0].col_key)).unfilter_entries())
Example 21
def get_lowqual_expr(
        alleles: hl.expr.ArrayExpression,
        qual_approx_expr: Union[hl.expr.ArrayNumericExpression,
                                hl.expr.NumericExpression],
        snv_phred_threshold: int = 30,
        snv_phred_het_prior: int = 30,  # 1/1000
        indel_phred_threshold: int = 30,
        indel_phred_het_prior: int = 39,  # 1/8,000
) -> Union[hl.expr.BooleanExpression, hl.expr.ArrayExpression]:
    """
    Computes lowqual threshold expression for either split or unsplit alleles based on QUALapprox or AS_QUALapprox

    .. note::

        When running this lowqual annotation using QUALapprox, it differs from the GATK LowQual filter.
        This is because GATK computes this annotation at the site level, which uses the least stringent prior for mixed sites.
        When run using AS_QUALapprox, this implementation can thus be more stringent for certain alleles at mixed sites.

    :param alleles: Array of alleles
    :param qual_approx_expr: QUALapprox or AS_QUALapprox
    :param snv_phred_threshold: Phred-scaled SNV "emission" threshold (similar to GATK emission threshold)
    :param snv_phred_het_prior: Phred-scaled SNV heterozygosity prior (30 = 1/1000 bases, GATK default)
    :param indel_phred_threshold: Phred-scaled indel "emission" threshold (similar to GATK emission threshold)
    :param indel_phred_het_prior: Phred-scaled indel heterozygosity prior (39 = 1/8,000 bases, GATK default)
    :return: lowqual expression (BooleanExpression if `qual_approx_expr` is Numeric, Array[BooleanExpression] if `qual_approx_expr` is ArrayNumeric)
    """

    min_snv_qual = snv_phred_threshold + snv_phred_het_prior
    min_indel_qual = indel_phred_threshold + indel_phred_het_prior
    min_mixed_qual = max(min_snv_qual, min_indel_qual)

    if isinstance(qual_approx_expr, hl.expr.ArrayNumericExpression):
        return hl.range(1, hl.len(alleles)).map(lambda ai: hl.cond(
            hl.is_snp(alleles[0], alleles[ai]),
            qual_approx_expr[ai - 1] < min_snv_qual,
            qual_approx_expr[ai - 1] < min_indel_qual,
        ))
    else:
        return (hl.case().when(
            hl.range(1, hl.len(alleles)).all(
                lambda ai: hl.is_snp(alleles[0], alleles[ai])),
            qual_approx_expr < min_snv_qual,
        ).when(
            hl.range(1, hl.len(alleles)).all(
                lambda ai: hl.is_indel(alleles[0], alleles[ai])),
            qual_approx_expr < min_indel_qual,
        ).default(qual_approx_expr < min_mixed_qual))
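A minimal usage sketch on a sites Table carrying alleles and info.QUALapprox; this mirrors how Example 24 below applies the same helper.

ht = ht.annotate(lowqual=get_lowqual_expr(ht.alleles, ht.info.QUALapprox))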
Example 22
def mwzj_hts_by_tree(all_hts, temp_dir, globals_for_col_key,
                     debug=False, inner_mode='overwrite', repartition_final: int = None,
                     read_if_exists=False):
    r'''
    Adapted from ukb_common mwzj_hts_by_tree()
    Uses read_clump_ht() instead of read_table()
    '''
    chunk_size = int(len(all_hts) ** 0.5) + 1
    outer_hts = []
    
    if read_if_exists:
        print('\n\nWARNING: Intermediate tables will not be overwritten if they already exist\n\n')

    checkpoint_kwargs = {inner_mode: not read_if_exists,
                         '_read_if_exists': read_if_exists}
    if repartition_final is not None:
        intervals = ukb_common.get_n_even_intervals(repartition_final)
        checkpoint_kwargs['_intervals'] = intervals

    if debug:
        print(f'Running chunk size {chunk_size}...')
    for i in range(chunk_size):
        if i * chunk_size >= len(all_hts):
            break
        hts = all_hts[i * chunk_size:(i + 1) * chunk_size]
        if debug:
            print(f'Going from {i * chunk_size} to {(i + 1) * chunk_size} ({len(hts)} HTs)...')
        try:
            if isinstance(hts[0], str):
                def read_clump_ht(f):
                    ht = hl.read_table(f)
                    ht = ht.drop('idx')
                    return ht
                hts = list(map(read_clump_ht, hts))
            ht = hl.Table.multi_way_zip_join(hts, 'row_field_name', 'global_field_name')
        except Exception:
            if debug:
                print(f'problem in range {i * chunk_size}-{i * chunk_size + chunk_size}')
                _ = [ht.describe() for ht in hts]
            raise
        outer_hts.append(ht.checkpoint(f'{temp_dir}/temp_output_{i}.ht', **checkpoint_kwargs))
    ht = hl.Table.multi_way_zip_join(outer_hts, 'row_field_name_outer', 'global_field_name_outer')
    ht = ht.transmute(inner_row=hl.flatmap(lambda i:
                                           hl.cond(hl.is_missing(ht.row_field_name_outer[i].row_field_name),
                                                   hl.range(0, hl.len(ht.global_field_name_outer[i].global_field_name))
                                                   .map(lambda _: hl.null(ht.row_field_name_outer[i].row_field_name.dtype.element_type)),
                                                   ht.row_field_name_outer[i].row_field_name),
                                           hl.range(hl.len(ht.global_field_name_outer))))
    ht = ht.transmute_globals(inner_global=hl.flatmap(lambda x: x.global_field_name, ht.global_field_name_outer))
    mt = ht._unlocalize_entries('inner_row', 'inner_global', globals_for_col_key)
    return mt
Example 23
def pre_process_subset_freq(subset: str,
                            global_ht: hl.Table,
                            test: bool = False) -> hl.Table:
    """
    Prepare subset frequency Table by filling in missing frequency fields for loci present only in the global cohort.

    .. note::

        The resulting final `freq` array will be as long as the subset `freq_meta` global (i.e., one `freq` entry for each `freq_meta` entry)

    :param subset: subset ID
    :param global_ht: Hail Table containing all variants discovered in the overall release cohort
    :param test: If True, filter to small region on chr20
    :return: Table containing subset frequencies with missing freq structs filled in
    """

    # Read in subset HTs
    subset_ht_path = get_freq(subset=subset).path
    subset_chr20_ht_path = qc_temp_prefix() + f"chr20_test_freq.{subset}.ht"

    if test:
        if file_exists(subset_chr20_ht_path):
            logger.info(
                "Loading chr20 %s subset frequency data for testing: %s",
                subset,
                subset_chr20_ht_path,
            )
            subset_ht = hl.read_table(subset_chr20_ht_path)

        elif file_exists(subset_ht_path):
            logger.info(
                "Loading %s subset frequency data for testing: %s",
                subset,
                subset_ht_path,
            )
            subset_ht = hl.read_table(subset_ht_path)
            subset_ht = hl.filter_intervals(
                subset_ht, [hl.parse_locus_interval("chr20:1-1000000")])

    elif file_exists(subset_ht_path):
        logger.info("Loading %s subset frequency data: %s", subset,
                    subset_ht_path)
        subset_ht = hl.read_table(subset_ht_path)

    else:
        raise DataException(
            f"Hail Table containing {subset} subset frequencies not found. You may need to run the script generate_freq_data.py to generate frequency annotations first."
        )

    # Fill in missing freq structs
    ht = subset_ht.join(global_ht.select().select_globals(), how="right")
    ht = ht.annotate(freq=hl.if_else(
        hl.is_missing(ht.freq),
        hl.map(lambda x: missing_callstats_expr(),
               hl.range(hl.len(ht.freq_meta))),
        ht.freq,
    ))

    return ht
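A hedged usage sketch; the subset ID below is illustrative.

subset_ht = pre_process_subset_freq('hgdp', global_ht)  # 'hgdp' is a hypothetical subset ID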
Example 24
def default_compute_info(mt: hl.MatrixTable,
                         site_annotations: bool = False,
                         n_partitions: int = 5000) -> hl.Table:
    """
    Computes a HT with the typical GATK allele-specific (AS) info fields 
    as well as ACs and lowqual fields.
    Note that this table doesn't split multi-allelic sites.

    :param mt: Input MatrixTable. Note that this table should be filtered to nonref sites.
    :param site_annotations: Whether to also generate site level info fields. Default is False.
    :param n_partitions: Number of desired partitions for output Table. Default is 5000.
    :return: Table with info fields
    :rtype: Table
    """
    # Move gvcf info entries out from nested struct
    mt = mt.transmute_entries(**mt.gvcf_info)

    # Compute AS info expr
    info_expr = get_as_info_expr(mt)

    if site_annotations:
        info_expr = info_expr.annotate(**get_site_info_expr(mt))

    # Add AC and AC_raw:
    # First compute ACs for each non-ref allele, grouped by adj
    grp_ac_expr = hl.agg.array_agg(
        lambda ai: hl.agg.filter(
            mt.LA.contains(ai),
            hl.agg.group_by(
                get_adj_expr(mt.LGT, mt.GQ, mt.DP, mt.LAD),
                hl.agg.sum(
                    mt.LGT.one_hot_alleles(mt.LA.map(lambda x: hl.str(x)))[
                        mt.LA.index(ai)]),
            ),
        ),
        hl.range(1, hl.len(mt.alleles)),
    )

    # Then, for each non-ref allele, compute
    # AC as the adj group
    # AC_raw as the sum of adj and non-adj groups
    info_expr = info_expr.annotate(
        AC_raw=grp_ac_expr.map(
            lambda i: hl.int32(i.get(True, 0) + i.get(False, 0))),
        AC=grp_ac_expr.map(lambda i: hl.int32(i.get(True, 0))),
    )

    info_ht = mt.select_rows(info=info_expr).rows()

    # Add AS lowqual flag
    info_ht = info_ht.annotate(AS_lowqual=get_lowqual_expr(
        info_ht.alleles, info_ht.info.AS_QUALapprox))

    if site_annotations:
        # Add lowqual flag
        info_ht = info_ht.annotate(
            lowqual=get_lowqual_expr(info_ht.alleles, info_ht.info.QUALapprox))

    return info_ht.naive_coalesce(n_partitions)
Example 25
def export_loo(batch_size=256):
    r'''
    For exporting p-values of meta-analysis of leave-one-out population sets
    '''
    meta_mt0 = hl.read_matrix_table(get_meta_analysis_results_path())

    meta_mt0 = meta_mt0.filter_cols(hl.len(meta_mt0.pheno_data.pop) == 6)

    meta_mt0 = meta_mt0.annotate_cols(pheno_id=(
        meta_mt0.trait_type + '-' + meta_mt0.phenocode + '-' +
        meta_mt0.pheno_sex +
        hl.if_else(hl.len(meta_mt0.coding) > 0, '-' + meta_mt0.coding, '') +
        hl.if_else(hl.len(meta_mt0.modifier) > 0, '-' +
                   meta_mt0.modifier, '')).replace(' ', '_').replace('/', '_'))

    meta_mt0 = meta_mt0.annotate_rows(
        SNP=(meta_mt0.locus.contig + ':' + hl.str(meta_mt0.locus.position) +
             ':' + meta_mt0.alleles[0] + ':' + meta_mt0.alleles[1]))

    all_pops = sorted(['AFR', 'AMR', 'CSA', 'EAS', 'EUR', 'MID'])

    annotate_dict = {}
    # pop_idx corresponds to the alphabetical ordering of the pops
    # (the entry with idx=0 is the 6-pop meta-analysis)
    for pop_idx, pop in enumerate(all_pops, 1):
        annotate_dict.update(
            {f'pval_not_{pop}': meta_mt0.meta_analysis.Pvalue[pop_idx]})
    meta_mt1 = meta_mt0.annotate_entries(**annotate_dict)

    meta_mt1 = meta_mt1.key_cols_by('pheno_id')
    meta_mt1 = meta_mt1.key_rows_by().drop('locus', 'alleles', 'gene',
                                           'annotation', 'meta_analysis')

    print(meta_mt1.describe())

    def get_export_path(batch_idx):
        return f'{ldprune_dir}/loo/sumstats/batch{batch_idx}'

    batch_idx = 1
    while hl.hadoop_is_dir(get_export_path(batch_idx)):
        batch_idx += 1
    print(f'\nExporting to: {get_export_path(batch_idx)}\n')
    hl.experimental.export_entries_by_col(mt=meta_mt1,
                                          path=get_export_path(batch_idx),
                                          bgzip=True,
                                          batch_size=batch_size,
                                          use_string_key_as_file_name=True,
                                          header_json_in_file=False)
Example 26
def make_pheno_manifest(export=True):
    mt0 = load_final_sumstats_mt(filter_sumstats=False,
                                 filter_variants=False,
                                 separate_columns_by_pop=False,
                                 annotate_with_nearest_gene=False)

    ht = mt0.cols()
    annotate_dict = {}

    annotate_dict.update({
        'pops': hl.delimit(ht.pheno_data.pop),
        'num_pops': hl.len(ht.pheno_data.pop)
    })

    for field in ['n_cases', 'n_controls', 'heritability', 'lambda_gc']:
        for pop in ['AFR', 'AMR', 'CSA', 'EAS', 'EUR', 'MID']:
            new_field = field if field != 'heritability' else 'saige_heritability'  # new field name (only applicable to saige heritability)
            idx = ht.pheno_data.pop.index(pop)
            field_expr = ht.pheno_data[field]
            annotate_dict.update({
                f'{new_field}_{pop}':
                hl.if_else(hl.is_missing(idx), hl.null(field_expr[0].dtype),
                           field_expr[idx])
            })
    annotate_dict.update({'filename': get_pheno_id(tb=ht) + '.tsv.bgz'})
    ht = ht.annotate(**annotate_dict)

    dropbox_manifest = hl.import_table(
        f'{ldprune_dir}/UKBB_Pan_Populations-Manifest_20200615-manifest_info.tsv',
        impute=True,
        key='File')
    dropbox_manifest = dropbox_manifest.filter(
        dropbox_manifest['is_old_file'] != '1')
    bgz = dropbox_manifest.filter(~dropbox_manifest.File.contains('.tbi'))
    bgz = bgz.rename({'File': 'filename'})
    tbi = dropbox_manifest.filter(dropbox_manifest.File.contains('.tbi'))
    tbi = tbi.annotate(
        filename=tbi.File.replace('.tbi', '')).key_by('filename')

    dropbox_annotate_dict = {}

    rename_dict = {
        'dbox link': 'dropbox_link',
        'size (bytes)': 'size_in_bytes'
    }

    dropbox_annotate_dict.update({'filename_tabix': tbi[ht.filename].File})
    for field in ['dbox link', 'wget', 'size (bytes)', 'md5 hex']:
        for tb, suffix in [(bgz, ''), (tbi, '_tabix')]:
            dropbox_annotate_dict.update({
                (rename_dict[field] if field in rename_dict else field.replace(
                     ' ', '_')) + suffix:
                tb[ht.filename][field]
            })
    ht = ht.annotate(**dropbox_annotate_dict)
    ht = ht.drop('pheno_data')
    ht.describe()
    ht.show()
Example 27
            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(
                        old_entry.LGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(
                        old_entry.LPGT,
                        hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    non_ref_ad = hl.or_else(old_entry.LAD[local_a_index],
                                            0)  # zeroed if not in LAD
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [hl.sum(old_entry.LAD) - non_ref_ad, non_ref_ad])
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl),
                                                     old_entry.GQ)

                    dropped_fields.append('LPL')

                return (hl.case()
                        .when(hl.len(ds.alleles) == 1,
                              old_entry.annotate(
                                  **{f[1:]: old_entry[f]
                                     for f in ['LGT', 'LPGT', 'LAD', 'LPL']
                                     if f in fields}).drop(*dropped_fields))
                        .when(hl.or_else(old_entry.LGT.is_hom_ref(), False),
                              old_entry.annotate(
                                  **{f: old_entry[f'L{f}'] if f in ['GT', 'PGT'] else e
                                     for f, e in new_exprs.items()}).drop(*dropped_fields))
                        .default(old_entry.annotate(**new_exprs).drop(*dropped_fields)))
Example 28
def load_prescription_data(prescription_data_tsv_path: str, prescription_mapping_tsv_path: str):
    ht = hl.import_table(prescription_data_tsv_path, types={'eid': hl.tint, 'data_provider': hl.tint}, key='eid')
    mapping_ht = hl.import_table(prescription_mapping_tsv_path, impute=True, key='Original_Prescription')
    ht = ht.annotate(
        issue_date=hl.cond(hl.len(ht.issue_date) == 0,
                           hl.null(hl.tint64),
                           hl.experimental.strptime(ht.issue_date + ' 00:00:00', '%d/%m/%Y %H:%M:%S', 'GMT')),
        **mapping_ht[ht.drug_name])
    ht = ht.filter(ht.Generic_Name != '').key_by('eid', 'Generic_Name', 'Drug_Category_and_Indication').collect_by_key()
    ht = ht.annotate(values=hl.sorted(ht.values, key=lambda x: x.issue_date))
    return ht.to_matrix_table(row_key=['eid'], col_key=['Generic_Name'], col_fields=['Drug_Category_and_Indication'])
Example 29
def add_variant_type(
        alt_alleles: hl.expr.ArrayExpression) -> hl.expr.StructExpression:
    """Get Struct of variant_type and n_alt_alleles from ArrayExpression of Strings (all alleles)."""
    ref = alt_alleles[0]
    alts = alt_alleles[1:]
    non_star_alleles = hl.filter(lambda a: a != "*", alts)
    return hl.struct(
        variant_type=hl.cond(
            hl.all(lambda a: hl.is_snp(ref, a), non_star_alleles),
            hl.cond(hl.len(non_star_alleles) > 1, "multi-snv", "snv"),
            hl.cond(
                hl.all(lambda a: hl.is_indel(ref, a), non_star_alleles),
                hl.cond(hl.len(non_star_alleles) > 1, "multi-indel", "indel"),
                "mixed",
            ),
        ),
        n_alt_alleles=hl.len(non_star_alleles),
    )
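A small worked sketch, assuming an initialized Hail session; the alleles arrays are ref-first.

import hail as hl

# Two SNV alts -> Struct(variant_type='multi-snv', n_alt_alleles=2).
print(hl.eval(add_variant_type(hl.literal(['A', 'T', 'G']))))
# One SNV alt plus one insertion alt -> variant_type='mixed'.
print(hl.eval(add_variant_type(hl.literal(['A', 'T', 'ATT']))))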
Example 30
def test_blanczos_against_numpy():

    def concatToNumpy(field, horizontal=True):
        blocks = field.collect()
        if horizontal:
            return np.concatenate(blocks, axis=0)
        else:
            return np.concatenate(blocks, axis=1)

    mt = hl.import_vcf(resource('tiny_m.vcf'))
    mt = mt.filter_rows(hl.len(mt.alleles) == 2)
    mt = mt.annotate_rows(AC=hl.agg.sum(mt.GT.n_alt_alleles()),
                          n_called=hl.agg.count_where(hl.is_defined(mt.GT)))
    mt = mt.filter_rows((mt.AC > 0) & (mt.AC < 2 * mt.n_called)).persist()
    n_rows = mt.count_rows()

    def make_expr(mean):
        return hl.if_else(hl.is_defined(mt.GT),
                          (mt.GT.n_alt_alleles() - mean) / hl.sqrt(mean * (2 - mean) * n_rows / 2),
                          0)

    k = 3

    float_expr = make_expr(mt.AC / mt.n_called)

    eigens, scores_t, loadings_t = hl._blanczos_pca(float_expr, k=k, q_iterations=7, compute_loadings=True)
    A = np.array(float_expr.collect()).reshape((3, 4)).T
    scores = concatToNumpy(scores_t.scores)
    loadings = concatToNumpy(loadings_t.loadings)
    scores = np.reshape(scores, (len(scores) // k, k))
    loadings = np.reshape(loadings, (len(loadings) // k, k))

    assert len(eigens) == 3
    assert scores_t.count() == mt.count_cols()
    assert loadings_t.count() == n_rows
    np.testing.assert_almost_equal(A @ loadings, scores)

    assert len(scores_t.globals) == 0
    assert len(loadings_t.globals) == 0

    # compute PCA with numpy
    def normalize(a):
        ms = np.mean(a, axis=0, keepdims=True)
        return np.divide(np.subtract(a, ms), np.sqrt(2.0 * np.multiply(ms / 2.0, 1 - ms / 2.0) * a.shape[1]))

    g = np.pad(np.diag([1.0, 1, 2]), ((0, 1), (0, 0)), mode='constant')
    g[1, 0] = 1.0 / 3
    n = normalize(g)
    U, s, V = np.linalg.svd(n, full_matrices=0)
    np_loadings = V.transpose()
    np_eigenvalues = np.multiply(s, s)

    def bound(vs, us):  # equation 12 from https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4827102/pdf/main.pdf
        return 1 / k * sum([np.linalg.norm(us.T @ vs[:, i]) for i in range(k)])

    np.testing.assert_allclose(eigens, np_eigenvalues, rtol=0.05)
    assert bound(np_loadings, loadings) > 0.9
Example 31
def bi_allelic_expr(t: Union[hl.Table, hl.MatrixTable]) -> hl.expr.BooleanExpression:
    """
    Returns a boolean expression selecting bi-allelic sites only,
    accounting for whether the input MT/HT was split.

    :param t: Input HT/MT
    :return: Boolean expression selecting only bi-allelic sites
    """
    return ~t.was_split if "was_split" in t.row else (hl.len(t.alleles) == 2)
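A one-line usage sketch, assuming ht is a (possibly split) variant Table:

ht = ht.filter(bi_allelic_expr(ht))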
Example 32
def merge_alleles(alleles):
    from hail.expr.functions import _num_allele_type, _allele_ints
    return hl.rbind(
        alleles.map(lambda a: hl.or_else(a[0], ''))
               .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
        lambda ref:
        hl.rbind(
            alleles.map(
                lambda al: hl.rbind(
                    al[0],
                    lambda r:
                    hl.array([ref]).extend(
                        al[1:].map(
                            lambda a:
                            hl.rbind(
                                _num_allele_type(r, a),
                                lambda at:
                                hl.cond(
                                    (_allele_ints['SNP'] == at) |
                                    (_allele_ints['Insertion'] == at) |
                                    (_allele_ints['Deletion'] == at) |
                                    (_allele_ints['MNP'] == at) |
                                    (_allele_ints['Complex'] == at),
                                    a + ref[hl.len(r):],
                                    a)))))),
            lambda lal:
            hl.struct(
                globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                local=lal)))
Example 33
def spread(ht, field, value, key=None) -> Table:
    """Spread a key-value pair of fields across multiple fields.

    :func:`.spread` mimics the functionality of the `spread()` function in R's
    `tidyr` package. This is a way to turn "long" format data into "wide"
    format data.

    Given a ``field``, :func:`.spread` will create a new table by grouping
    ``ht`` by its row key and, optionally, any additional fields passed to the
    ``key`` argument.

    After collapsing ``ht`` by these keys, :func:`.spread` creates a new row field
    for each unique value of ``field``, where the row field values are given by the
    corresponding ``value`` in the original ``ht``.


    Parameters
    ----------
    ht : :class:`.Table`
        A Hail table.
    field : :obj:`str`
        The name of the factor field in `ht`.
    value : :obj:`str`
        The name of the value field in `ht`.
    key : :obj:`str` or list of :obj:`str`, optional
        The name of any fields to group by, in addition to the
        row key fields of ``ht``.

    Returns
    -------
    :class:`.Table`
        Table with original ``key`` and ``value`` fields spread across multiple columns."""

    if key is None:
        key = list(ht.key)
    else:
        key = wrap_to_list(key)
        key = list(ht.key) + key

    field_vals = list(ht.aggregate(hl.agg.collect_as_set(ht[field])))
    ht = (ht.group_by(*key).aggregate(
        **{
            rv: hl.agg.take(ht[rv], 1)[0]
            for rv in ht.row_value if rv not in set(key + [field, value])
        }, **{
            fv: hl.agg.filter(
                ht[field] == fv,
                hl.rbind(
                    hl.agg.take(ht[value], 1),
                    lambda take: hl.cond(hl.len(take) > 0, take[0], 'NA')))
            for fv in field_vals
        }))

    # Checkpoint so the aggregation is computed once rather than re-run downstream.
    ht = ht.checkpoint(new_temp_file())

    return ht
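A hedged sketch of the long-to-wide reshape on a toy table, assuming an initialized Hail session:

import hail as hl

ht = hl.Table.parallelize(
    [{'id': 1, 'field': 'x', 'value': 'a'},
     {'id': 1, 'field': 'y', 'value': 'b'},
     {'id': 2, 'field': 'x', 'value': 'c'}],
    hl.tstruct(id=hl.tint32, field=hl.tstr, value=hl.tstr),
    key='id')
wide = spread(ht, 'field', 'value')
# wide now has one row per id with row fields 'x' and 'y';
# the missing (id=2, 'y') cell is filled with 'NA'.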
Example 34
    def transform_entries(old_entry):
        def with_local_a_index(local_a_index):
            new_pl = hl.or_missing(
                hl.is_defined(old_entry.LPL),
                hl.or_missing(
                    hl.is_defined(local_a_index),
                    hl.range(0, 3).map(lambda i: hl.min(
                        hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                            .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), local_a_index) == hl.unphased_diploid_gt_index_call(i))
                            .map(lambda idx: old_entry.LPL[idx])))))
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                                   old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                                   old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)

        lai = hl.fold(lambda accum, elt:
                        hl.cond(old_entry.LA[elt] == ds[new_id].a_index,
                                elt, accum),
                        hl.null(hl.tint32),
                        hl.range(0, hl.len(old_entry.LA)))
        return hl.bind(with_local_a_index, lai)
Example 35
    def phase_diploid_proband(
            locus: hl.expr.LocusExpression,
            alleles: hl.expr.ArrayExpression,
            proband_call: hl.expr.CallExpression,
            father_call: hl.expr.CallExpression,
            mother_call: hl.expr.CallExpression
    ) -> hl.expr.ArrayExpression:
        """
        Returns phased genotype calls in the case of a diploid proband
        (autosomes, PAR regions of sex chromosomes or non-PAR regions of a female proband)

        :param LocusExpression locus: Locus in the trio MatrixTable
        :param ArrayExpression alleles: Alleles in the trio MatrixTable
        :param CallExpression proband_call: Input proband genotype call
        :param CallExpression father_call: Input father genotype call
        :param CallExpression mother_call: Input mother genotype call
        :return: Array containing: phased proband call, phased father call, phased mother call
        :rtype: ArrayExpression
        """

        proband_v = proband_call.one_hot_alleles(alleles)
        father_v = hl.cond(
            locus.in_x_nonpar() | locus.in_y_nonpar(),
            hl.or_missing(father_call.is_haploid(), hl.array([father_call.one_hot_alleles(alleles)])),
            call_to_one_hot_alleles_array(father_call, alleles)
        )
        mother_v = call_to_one_hot_alleles_array(mother_call, alleles)

        combinations = hl.flatmap(
            lambda f:
            hl.zip_with_index(mother_v)
                .filter(lambda m: m[1] + f[1] == proband_v)
                .map(lambda m: hl.struct(m=m[0], f=f[0])),
            hl.zip_with_index(father_v)
        )

        return (
            hl.or_missing(
                hl.is_defined(combinations) & (hl.len(combinations) == 1),
                hl.array([
                    hl.call(father_call[combinations[0].f], mother_call[combinations[0].m], phased=True),
                    hl.cond(father_call.is_haploid(), hl.call(father_call[0], phased=True), phase_parent_call(father_call, combinations[0].f)),
                    phase_parent_call(mother_call, combinations[0].m)
                ])
            )
        )
Example 36
def densify(sparse_mt):
    """Convert sparse MatrixTable to a dense one.

    Parameters
    ----------
    sparse_mt : :class:`.MatrixTable`
        Sparse MatrixTable to densify.  The first row key field must
        be named ``locus`` and have type ``locus``.  Must have an
        ``END`` entry field of type ``int32``.

    Returns
    -------
    :class:`.MatrixTable`
        The densified MatrixTable.  The ``END`` entry field is dropped.

    """
    if list(sparse_mt.row_key)[0] != 'locus' or not isinstance(sparse_mt.locus.dtype, hl.tlocus):
        raise ValueError("first row key field must be named 'locus' and have type 'locus'")
    if 'END' not in sparse_mt.entry or sparse_mt.END.dtype != hl.tint32:
        raise ValueError("'densify' requires 'END' entry field of type 'int32'")
    col_key_fields = list(sparse_mt.col_key)

    mt = sparse_mt
    mt = mt.annotate_entries(__contig=mt.locus.contig)
    t = mt._localize_entries('__entries', '__cols')
    t = t.annotate(
        __entries=hl.rbind(
            hl.scan.array_agg(
                lambda entry: hl.scan._prev_nonnull(hl.or_missing(hl.is_defined(entry.END), entry)),
                t.__entries),
            lambda prev_entries: hl.map(
                lambda i:
                hl.rbind(
                    prev_entries[i], t.__entries[i],
                    lambda prev_entry, entry:
                    hl.cond(
                        (~hl.is_defined(entry) &
                         (prev_entry.END >= t.locus.position) &
                         (prev_entry.__contig == t.locus.contig)),
                        prev_entry,
                        entry)),
                hl.range(0, hl.len(t.__entries)))))
    mt = t._unlocalize_entries('__entries', '__cols', col_key_fields)
    mt = mt.drop('__contig', 'END')
    return mt
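
A hedged usage sketch; 'sparse.mt' is a hypothetical path to a sparse MatrixTable (for example, one written by the gvcf combiner) that satisfies the requirements above.

import hail as hl

sparse_mt = hl.read_matrix_table('sparse.mt')  # hypothetical path
dense_mt = densify(sparse_mt)  # reference blocks filled in; 'END' dropped
dense_mt.describe()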
Example n. 37
    def test_distinct(self):
        t1 = hl.Table.parallelize([
            {'a': 'foo', 'b': 1},
            {'a': 'bar', 'b': 2},
            {'a': 'bar', 'b': 2},
            {'a': 'bar', 'b': 3},
            {'a': 'bar', 'b': 3},
            {'a': 'baz', 'b': 2},
            {'a': 'baz', 'b': 0},
            {'a': 'baz', 'b': 0},
            {'a': 'foo', 'b': 0},
            {'a': '1', 'b': 0},
            {'a': '2', 'b': 0},
            {'a': '3', 'b': 0}],
            hl.tstruct(a=hl.tstr, b=hl.tint32),
            key='a',
            n_partitions=4)

        dist = t1.distinct().collect_by_key()
        self.assertTrue(dist.all(hl.len(dist.values) == 1))
        self.assertEqual(dist.count(), len(t1.aggregate(hl.agg.collect_as_set(t1.a))))
Example n. 38
def transform_one(mt: MatrixTable) -> MatrixTable:
    """transforms a gvcf into a form suitable for combining"""
    mt = mt.annotate_entries(
        # local (alt) allele index into global (alt) alleles
        LA=hl.range(0, hl.len(mt.alleles) - 1),
        END=mt.info.END,
        PL=mt['PL'][0:],
        BaseQRankSum=mt.info['BaseQRankSum'],
        ClippingRankSum=mt.info['ClippingRankSum'],
        MQ=mt.info['MQ'],
        MQRankSum=mt.info['MQRankSum'],
        ReadPosRankSum=mt.info['ReadPosRankSum'],
    )
    # This collects all fields with median combiners into arrays so we can calculate medians
    # when needed
    mt = mt.annotate_rows(
        # now minrep'ed (ref, alt) allele pairs
        alleles=hl.bind(lambda ref: mt.alleles[1:].map(lambda alt:
                                                       # minrep <NON_REF>
                                                       hl.struct(ref=hl.cond(alt == "<NON_REF>",
                                                                             ref[0:1],
                                                                             ref),
                                                                 alt=alt)),
                        mt.alleles[0]),
        info=mt.info.annotate(
            SB=hl.agg.array_sum(mt.entry.SB)
        ).select(
            "DP",
            "MQ_DP",
            "QUALapprox",
            "RAW_MQ",
            "VarDP",
            "SB",
        ))
    mt = mt.drop('SB', 'qual')

    return mt
Example n. 39
            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                               old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                               old_entry.annotate(**new_exprs).drop(*dropped_fields))
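
This fragment has free variables (`fields`, `old_entry`, `local_a_index`, `ds`) bound by its enclosing splitting function. The key primitive is `hl.downcode`, which keeps one alternate allele and folds every other allele into the reference; a hedged toy check:

import hail as hl

# downcode keeps alt allele 2 and maps all other alleles to ref,
# which is how the local LGT/LPGT become biallelic GT/PGT above
print(hl.eval(hl.downcode(hl.call(1, 2), 2)))  # 0/1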
Example n. 40
def call_stats(call, alleles) -> StructExpression:
    """Compute useful call statistics.

    Examples
    --------
    Compute call statistics per row:

    >>> dataset_result = dataset.annotate_rows(gt_stats = agg.call_stats(dataset.GT, dataset.alleles))
    >>> dataset_result.rows().key_by('locus').select('gt_stats').show()
    +---------------+--------------+----------------+-------------+---------------------------+
    | locus         | gt_stats.AC  | gt_stats.AF    | gt_stats.AN | gt_stats.homozygote_count |
    +---------------+--------------+----------------+-------------+---------------------------+
    | locus<GRCh37> | array<int32> | array<float64> |       int32 | array<int32>              |
    +---------------+--------------+----------------+-------------+---------------------------+
    | 20:10579373   | [199,1]      | [0.995,0.005]  |         200 | [99,0]                    |
    | 20:13695607   | [177,23]     | [0.885,0.115]  |         200 | [77,0]                    |
    | 20:13698129   | [198,2]      | [0.99,0.01]    |         200 | [98,0]                    |
    | 20:14306896   | [142,58]     | [0.71,0.29]    |         200 | [51,9]                    |
    | 20:14306953   | [121,79]     | [0.605,0.395]  |         200 | [38,17]                   |
    | 20:15948325   | [172,2]      | [0.989,0.012]  |         174 | [85,0]                    |
    | 20:15948326   | [174,8]      | [0.956,0.043]  |         182 | [83,0]                    |
    | 20:17479423   | [199,1]      | [0.995,0.005]  |         200 | [99,0]                    |
    | 20:17600357   | [79,121]     | [0.395,0.605]  |         200 | [24,45]                   |
    | 20:17640833   | [193,3]      | [0.985,0.015]  |         196 | [95,0]                    |
    +---------------+--------------+----------------+-------------+---------------------------+

    Notes
    -----
    This method is meaningful for computing call metrics per variant, but not
    especially meaningful for computing metrics per sample.

    This method returns a struct expression with four fields:

     - `AC` (:class:`.tarray` of :py:data:`.tint32`) - Allele counts. One element
       for each allele, including the reference.
     - `AF` (:class:`.tarray` of :py:data:`.tfloat64`) - Allele frequencies. One
       element for each allele, including the reference.
     - `AN` (:py:data:`.tint32`) - Allele number. The total number of called
       alleles; for diploid calls, twice the number of non-missing calls.
     - `homozygote_count` (:class:`.tarray` of :py:data:`.tint32`) - Homozygote
       genotype counts for each allele, including the reference. Only **diploid**
       genotype calls are counted.

    Parameters
    ----------
    call : :class:`.CallExpression`
    alleles : :class:`.ArrayStringExpression`
        Variant alleles.

    Returns
    -------
    :class:`.StructExpression`
        Struct expression with fields `AC`, `AF`, `AN`, and `homozygote_count`.
    """
    n_alleles = hl.len(alleles)
    t = tstruct(AC=tarray(tint32),
                AF=tarray(tfloat64),
                AN=tint32,
                homozygote_count=tarray(tint32))

    return _agg_func('CallStats', [call], t, [], init_op_args=[n_alleles])
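
A small worked check of the field relationships described above, using the first doctest row (AC = [199, 1]):

AC = [199, 1]
AN = sum(AC)                 # allele number: 200
AF = [c / AN for c in AC]    # allele frequency is AC / AN
assert AN == 200 and AF == [0.995, 0.005]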
Example n. 41
def transform_one(mt, vardp_outlier=100_000) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    There is a strong assumption that this function will be called on a matrix
    table with one column.
    """
    mt = localize(mt)
    if mt.row.dtype not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e:
                        hl.struct(
                            DP=e.DP,
                            END=row.info.END,
                            GQ=e.GQ,
                            LA=hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                            LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                            LGT=e.GT,
                            LPGT=e.PGT,
                            LPL=hl.cond(has_non_ref,
                                        hl.cond(alleles_len > 2,
                                                e.PL[:-alleles_len],
                                                hl.null(e.PL.dtype)),
                                        hl.cond(alleles_len > 1,
                                                e.PL,
                                                hl.null(e.PL.dtype))),
                            MIN_DP=e.MIN_DP,
                            PID=e.PID,
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                            SB=e.SB,
                            gvcf_info=hl.case()
                                .when(hl.is_missing(row.info.END),
                                      hl.struct(
                                          ClippingRankSum=row.info.ClippingRankSum,
                                          BaseQRankSum=row.info.BaseQRankSum,
                                          MQ=row.info.MQ,
                                          MQRankSum=row.info.MQRankSum,
                                          MQ_DP=row.info.MQ_DP,
                                          QUALapprox=row.info.QUALapprox,
                                          RAW_MQ=row.info.RAW_MQ,
                                          ReadPosRankSum=row.info.ReadPosRankSum,
                                          VarDP=hl.cond(row.info.VarDP > vardp_outlier,
                                                        row.info.DP, row.info.VarDP)))
                                .or_missing()
                        ))),
            ),
            mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(TableMapRows(mt._tir, Apply(transform_row._name, TopLevelReference('row'))))
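
A hedged usage sketch; 'sample1.g.vcf.bgz' is a hypothetical single-sample GVCF path, and `array_elements_required=False` matches the requirement in the docstring.

import hail as hl

mt = hl.import_vcf('sample1.g.vcf.bgz',  # hypothetical path
                   reference_genome='GRCh38',
                   array_elements_required=False)
ht = transform_one(mt)  # localized Table, ready for combining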
Example n. 42
File: qc.py Project: tpoterba/hail
def summarize_variants(mt: MatrixTable, show=True):
    """Summarize the variants present in a dataset and print the results.

    Examples
    --------
    >>> hl.summarize_variants(dataset)  # doctest: +SKIP
    ==============================
    Number of variants: 346
    ==============================
    Alleles per variant
    -------------------
      2 alleles: 346 variants
    ==============================
    Variants per contig
    -------------------
      20: 346 variants
    ==============================
    Allele type distribution
    ------------------------
            SNP: 301 alternate alleles
       Deletion: 27 alternate alleles
      Insertion: 18 alternate alleles
    ==============================

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Matrix table with a variant (locus / alleles) row key.
    show : :obj:`bool`
        If ``True``, print results instead of returning them.

    Notes
    -----
    The result returned if `show` is ``False`` is a :class:`.Struct` with
    four fields:

    - `n_variants` (:obj:`int`): Number of variants present in the matrix table.
    - `allele_types` (:obj:`Dict[str, int]`): Number of alternate alleles in
      each allele category.
    - `contigs` (:obj:`Dict[str, int]`): Number of variants on each contig.
    - `allele_counts` (:obj:`Dict[int, int]`): Number of variants broken down
      by number of alleles (biallelic is 2, for example).

    Returns
    -------
    :obj:`None` or :class:`.Struct`
        Returns ``None`` if `show` is ``True``, or returns results as a struct.
    """
    require_row_key_variant(mt, 'summarize_variants')
    alleles_per_variant = hl.range(1, hl.len(mt.alleles)).map(lambda i: hl.allele_type(mt.alleles[0], mt.alleles[i]))
    allele_types, contigs, allele_counts, n_variants = mt.aggregate_rows(
        (hl.agg.explode(lambda elt: hl.agg.counter(elt), alleles_per_variant),
         hl.agg.counter(mt.locus.contig),
         hl.agg.counter(hl.len(mt.alleles)),
         hl.agg.count()))
    rg = mt.locus.dtype.reference_genome
    contig_idx = {contig: i for i, contig in enumerate(rg.contigs)}
    if show:
        max_contig_len = max(len(contig) for contig in contigs)
        contig_formatter = f'%{max_contig_len}s'

        max_allele_count_len = max(len(str(x)) for x in allele_counts)
        allele_count_formatter = f'%{max_allele_count_len}s'

        max_allele_type_len = max(len(x) for x in allele_types)
        allele_type_formatter = f'%{max_allele_type_len}s'

        line_break = '=============================='

        print(line_break)
        print(f'Number of variants: {n_variants}')
        print(line_break)
        print('Alleles per variant')
        print('-------------------')
        for n_alleles, count in sorted(allele_counts.items(), key=lambda x: x[0]):
            print(f'  {allele_count_formatter % n_alleles} alleles: {count} variants')
        print(line_break)
        print('Variants per contig')
        print('-------------------')
        for contig, count in sorted(contigs.items(), key=lambda x: contig_idx[x[0]]):
            print(f'  {contig_formatter % contig}: {count} variants')
        print(line_break)
        print('Allele type distribution')
        print('------------------------')
        for allele_type, count in Counter(allele_types).most_common():
            print(f'  {allele_type_formatter % allele_type}: {count} alternate alleles')
        print(line_break)
    else:
        return hl.Struct(allele_types=allele_types,
                         contigs=contigs,
                         allele_counts=allele_counts,
                         n_variants=n_variants)
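
A hedged sketch of the `show=False` path, reusing the doctest's `dataset`:

res = summarize_variants(dataset, show=False)
print(res.n_variants)     # 346 for the doctest dataset
print(res.allele_counts)  # {2: 346}: every variant is biallelic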
Example n. 43
File: qc.py Project: tpoterba/hail
def variant_qc(mt, name='variant_qc') -> MatrixTable:
    """Compute common variant statistics (quality control metrics).

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    >>> dataset_result = hl.variant_qc(dataset)

    Notes
    -----
    This method computes variant statistics from the genotype data, returning
    a new struct field `name` with the following metrics based on the fields
    present in the entry schema.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `AF` (``array<float64>``) -- Calculated allele frequency, one element
      per allele, including the reference. Sums to one. Equivalent to
      `AC` / `AN`.
    - `AC` (``array<int32>``) -- Calculated allele count, one element per
      allele, including the reference. Sums to `AN`.
    - `AN` (``int32``) -- Total number of called alleles.
    - `homozygote_count` (``array<int32>``) -- Number of homozygotes per
      allele. One element per allele, including the reference.
    - `n_called` (``int64``) -- Number of samples with a defined `GT`.
    - `n_not_called` (``int64``) -- Number of samples with a missing `GT`.
    - `call_rate` (``float64``) -- Fraction of samples with a defined `GT`.
      Equivalent to `n_called` / :meth:`.count_cols`.
    - `n_het` (``int64``) -- Number of heterozygous samples.
    - `n_non_ref` (``int64``) -- Number of samples with at least one called
      non-reference allele.
    - `het_freq_hwe` (``float64``) -- Expected frequency of heterozygous
      samples under Hardy-Weinberg equilibrium. See
      :func:`.functions.hardy_weinberg_test` for details.
    - `p_value_hwe` (``float64``) -- p-value from test of Hardy-Weinberg equilibrium.
      See :func:`.functions.hardy_weinberg_test` for details.

    Warning
    -------
    `het_freq_hwe` and `p_value_hwe` are calculated as in
    :func:`.functions.hardy_weinberg_test`, with non-diploid calls
    (``ploidy != 2``) ignored in the counts. As this test is only
    statistically rigorous in the biallelic setting, :func:`.variant_qc`
    sets both fields to missing for multiallelic variants. Consider using
    :func:`~hail.methods.split_multi` to split multi-allelic variants beforehand.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
    """
    require_row_key_variant(mt, 'variant_qc')

    exprs = {}
    struct_exprs = []

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    n_samples = mt.count_cols()

    if has_field_of_type('DP', hl.tint32):
        exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'variant_qc': expect an entry field 'GT' of type 'call'")
    exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    struct_exprs.append(hl.agg.call_stats(mt.GT, mt.alleles))


    # the structure of this function makes it easy to add new nested computations
    def flatten_struct(*struct_exprs):
        flat = {}
        for struct in struct_exprs:
            for k, v in struct.items():
                flat[k] = v
        return hl.struct(
            **flat,
            **exprs,
        )

    mt = mt.annotate_rows(**{name: hl.bind(flatten_struct, *struct_exprs)})

    hwe = hl.hardy_weinberg_test(mt[name].homozygote_count[0],
                                 mt[name].AC[1] - 2 * mt[name].homozygote_count[1],
                                 mt[name].homozygote_count[1])
    hwe = hwe.select(het_freq_hwe=hwe.het_freq_hwe, p_value_hwe=hwe.p_value)
    mt = mt.annotate_rows(**{name: mt[name].annotate(n_not_called=n_samples - mt[name].n_called,
                                                     call_rate=mt[name].n_called / n_samples,
                                                     n_het=mt[name].n_called - hl.sum(mt[name].homozygote_count),
                                                     n_non_ref=mt[name].n_called - mt[name].homozygote_count[0],
                                                     **hl.cond(hl.len(mt.alleles) == 2,
                                                               hwe,
                                                               hl.null(hwe.dtype)))})
    return mt
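
A hedged usage sketch of filtering on the computed metrics; the thresholds are illustrative, not recommendations.

mt_qc = hl.variant_qc(dataset)  # 'dataset' as in the example above
mt_pass = mt_qc.filter_rows((mt_qc.variant_qc.call_rate >= 0.95) &
                            (mt_qc.variant_qc.AF[1] >= 0.01))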
Example n. 44
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref:
            hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r:
                        hl.array([ref]).extend(
                            al[1:].map(
                                lambda a:
                                hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at:
                                    hl.cond(
                                        (_allele_ints['SNP'] == at) |
                                        (_allele_ints['Insertion'] == at) |
                                        (_allele_ints['Deletion'] == at) |
                                        (_allele_ints['MNP'] == at) |
                                        (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal:
                hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles:
                hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i:
                            hl.cond(hl.is_missing(row.data[i].__entries),
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                      .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
                                           TopLevelReference('row'),
                                           TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
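
A hedged toy check of `renumber_entry`'s index remapping, with a hypothetical old-to-new array in place of the one built from `combined_allele_index`:

import hail as hl

entry = hl.struct(LA=hl.literal([0, 2]))  # local alt-allele indices
old_to_new = hl.literal([0, 3, 1])        # hypothetical remapping
renumbered = entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))
print(hl.eval(renumbered.LA))  # [0, 1]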
Example n. 45
    def test_trio_matrix(self):
        """
        This test depends on certain properties of the trio matrix VCF and
        pedigree structure. This test is NOT a valid test if the pedigree
        includes quads: the trio_matrix method will duplicate the parents
        appropriately, but the genotypes_table and samples_table orthogonal
        paths would require another duplication/explode that we haven't written.
        """
        ped = hl.Pedigree.read(resource('triomatrix.fam'))
        ht = hl.import_fam(resource('triomatrix.fam'))

        mt = hl.import_vcf(resource('triomatrix.vcf'))
        mt = mt.annotate_cols(fam=ht[mt.s].fam_id)

        dads = ht.filter(hl.is_defined(ht.pat_id))
        dads = dads.select(dads.pat_id, is_dad=True).key_by('pat_id')

        moms = ht.filter(hl.is_defined(ht.mat_id))
        moms = moms.select(moms.mat_id, is_mom=True).key_by('mat_id')

        et = (mt.entries()
              .key_by('s')
              .join(dads, how='left')
              .join(moms, how='left'))
        et = et.annotate(is_dad=hl.is_defined(et.is_dad),
                         is_mom=hl.is_defined(et.is_mom))

        et = (et
            .group_by(et.locus, et.alleles, fam=et.fam)
            .aggregate(data=hl.agg.collect(hl.struct(
            role=hl.case().when(et.is_dad, 1).when(et.is_mom, 2).default(0),
            g=hl.struct(GT=et.GT, AD=et.AD, DP=et.DP, GQ=et.GQ, PL=et.PL)))))

        et = et.filter(hl.len(et.data) == 3)
        et = et.select('data').explode('data')

        tt = hl.trio_matrix(mt, ped, complete_trios=True).entries().key_by('locus', 'alleles')
        tt = tt.annotate(fam=tt.proband.fam,
                         data=[hl.struct(role=0, g=tt.proband_entry.select('GT', 'AD', 'DP', 'GQ', 'PL')),
                               hl.struct(role=1, g=tt.father_entry.select('GT', 'AD', 'DP', 'GQ', 'PL')),
                               hl.struct(role=2, g=tt.mother_entry.select('GT', 'AD', 'DP', 'GQ', 'PL'))])
        tt = tt.select('fam', 'data').explode('data')
        tt = tt.filter(hl.is_defined(tt.data.g)).key_by('locus', 'alleles', 'fam')

        self.assertEqual(et.key.dtype, tt.key.dtype)
        self.assertEqual(et.row.dtype, tt.row.dtype)
        self.assertTrue(et._same(tt))

        # test annotations
        e_cols = (mt.cols()
                  .join(dads, how='left')
                  .join(moms, how='left'))
        e_cols = e_cols.annotate(is_dad=hl.is_defined(e_cols.is_dad),
                                 is_mom=hl.is_defined(e_cols.is_mom))
        e_cols = (e_cols.group_by(fam=e_cols.fam)
                  .aggregate(data=hl.agg.collect(hl.struct(role=hl.case()
                                                           .when(e_cols.is_dad, 1).when(e_cols.is_mom, 2).default(0),
                                                           sa=hl.struct(**e_cols.row.select(*mt.col))))))
        e_cols = e_cols.filter(hl.len(e_cols.data) == 3).select('data').explode('data')

        t_cols = hl.trio_matrix(mt, ped, complete_trios=True).cols()
        t_cols = t_cols.annotate(fam=t_cols.proband.fam,
                                 data=[
                                     hl.struct(role=0, sa=t_cols.proband),
                                     hl.struct(role=1, sa=t_cols.father),
                                     hl.struct(role=2, sa=t_cols.mother)]).key_by('fam').select('data').explode('data')
        t_cols = t_cols.filter(hl.is_defined(t_cols.data.sa))

        self.assertEqual(e_cols.key.dtype, t_cols.key.dtype)
        self.assertEqual(e_cols.row.dtype, t_cols.row.dtype)
        self.assertTrue(e_cols._same(t_cols))
Example n. 46
    def test_literals_rebuild(self):
        mt = hl.utils.range_matrix_table(1, 1)
        mt = mt.annotate_rows(x=hl.cond(hl.len(hl.literal([1, 2, 3])) < hl.rand_unif(10, 11), mt.globals, hl.struct()))
        mt._force_count_rows()
Example n. 47
    def test_export_plink_exprs(self):
        ds = get_dataset()
        fam_mapping = {'f0': 'fam_id', 'f1': 'ind_id', 'f2': 'pat_id', 'f3': 'mat_id',
                       'f4': 'is_female', 'f5': 'pheno'}
        bim_mapping = {'f0': 'contig', 'f1': 'varid', 'f2': 'cm_position',
                       'f3': 'position', 'f4': 'a1', 'f5': 'a2'}

        # Test default arguments
        out1 = new_temp_file()
        hl.export_plink(ds, out1)
        fam1 = (hl.import_table(out1 + '.fam', no_header=True, impute=False, missing="")
                .rename(fam_mapping))
        bim1 = (hl.import_table(out1 + '.bim', no_header=True, impute=False)
                .rename(bim_mapping))

        self.assertTrue(fam1.all((fam1.fam_id == "0") & (fam1.pat_id == "0") &
                                 (fam1.mat_id == "0") & (fam1.is_female == "0") &
                                 (fam1.pheno == "NA")))
        self.assertTrue(bim1.all((bim1.varid == bim1.contig + ":" + bim1.position + ":" + bim1.a2 + ":" + bim1.a1) &
                                 (bim1.cm_position == "0.0")))

        # Test non-default FAM arguments
        out2 = new_temp_file()
        hl.export_plink(ds, out2, ind_id=ds.s, fam_id=ds.s, pat_id="nope",
                        mat_id="nada", is_female=True, pheno=False)
        fam2 = (hl.import_table(out2 + '.fam', no_header=True, impute=False, missing="")
                .rename(fam_mapping))

        self.assertTrue(fam2.all((fam2.fam_id == fam2.ind_id) & (fam2.pat_id == "nope") &
                                 (fam2.mat_id == "nada") & (fam2.is_female == "2") &
                                 (fam2.pheno == "1")))

        # Test quantitative phenotype
        out3 = new_temp_file()
        hl.export_plink(ds, out3, ind_id=ds.s, pheno=hl.float64(hl.len(ds.s)))
        fam3 = (hl.import_table(out3 + '.fam', no_header=True, impute=False, missing="")
                .rename(fam_mapping))

        self.assertTrue(fam3.all((fam3.fam_id == "0") & (fam3.pat_id == "0") &
                                 (fam3.mat_id == "0") & (fam3.is_female == "0") &
                                 (fam3.pheno != "0") & (fam3.pheno != "NA")))

        # Test non-default BIM arguments
        out4 = new_temp_file()
        hl.export_plink(ds, out4, varid="hello", cm_position=100)
        bim4 = (hl.import_table(out4 + '.bim', no_header=True, impute=False)
                .rename(bim_mapping))

        self.assertTrue(bim4.all((bim4.varid == "hello") & (bim4.cm_position == "100.0")))

        # Test call expr
        out5 = new_temp_file()
        ds_call = ds.annotate_entries(gt_fake=hl.call(0, 0))
        hl.export_plink(ds_call, out5, call=ds_call.gt_fake)
        ds_all_hom_ref = hl.import_plink(out5 + '.bed', out5 + '.bim', out5 + '.fam')
        nerrors = ds_all_hom_ref.aggregate_entries(hl.agg.count_where(~ds_all_hom_ref.GT.is_hom_ref()))
        self.assertTrue(nerrors == 0)

        # Test white-space in FAM id expr raises error
        with self.assertRaisesRegex(TypeError, "has spaces in the following values:"):
            hl.export_plink(ds, new_temp_file(), mat_id="hello world")

        # Test white-space in varid expr raises error
        with self.assertRaisesRegex(FatalError, "no white space allowed:"):
            hl.export_plink(ds, new_temp_file(), varid="hello world")
Example n. 48
def ld_score_regression(weight_expr,
                        ld_score_expr,
                        chi_sq_exprs,
                        n_samples_exprs,
                        n_blocks=200,
                        two_step_threshold=30,
                        n_reference_panel_variants=None) -> Table:
    r"""Estimate SNP-heritability and level of confounding biases from
    GWAS summary statistics.

    Given a set or multiple sets of genome-wide association study (GWAS)
    summary statistics, :func:`.ld_score_regression` estimates the heritability
    of a trait or set of traits and the level of confounding biases present in
    the underlying studies by regressing chi-squared statistics on LD scores,
    leveraging the model:

    .. math::

        \mathrm{E}[\chi_j^2] = 1 + Na + \frac{Nh_g^2}{M}l_j

    *  :math:`\mathrm{E}[\chi_j^2]` is the expected chi-squared statistic
       for variant :math:`j` resulting from a test of association between
       variant :math:`j` and a trait.
    *  :math:`l_j = \sum_{k} r_{jk}^2` is the LD score of variant
       :math:`j`, calculated as the sum of squared correlation coefficients
       between variant :math:`j` and nearby variants. See :func:`ld_score`
       for further details.
    *  :math:`a` captures the contribution of confounding biases, such as
       cryptic relatedness and uncontrolled population structure, to the
       association test statistic.
    *  :math:`h_g^2` is the SNP-heritability, or the proportion of variation
       in the trait explained by the effects of variants included in the
       regression model above.
    *  :math:`M` is the number of variants used to estimate :math:`h_g^2`.
    *  :math:`N` is the number of samples in the underlying association study.

    For more details on the method implemented in this function, see:

    * `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__

    Examples
    --------

    Run the method on a matrix table of summary statistics, where the rows
    are variants and the columns are different phenotypes:

    >>> mt_gwas = hl.read_matrix_table('data/ld_score_regression.sumstats.mt')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=mt_gwas['ld_score'],
    ...     ld_score_expr=mt_gwas['ld_score'],
    ...     chi_sq_exprs=mt_gwas['chi_squared'],
    ...     n_samples_exprs=mt_gwas['n'])


    Run the method on a table with summary statistics for a single
    phenotype:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=ht_gwas['chi_squared_50_irnt'],
    ...     n_samples_exprs=ht_gwas['n_50_irnt'])

    Run the method on a table with summary statistics for multiple
    phenotypes:

    >>> ht_gwas = hl.read_table('data/ld_score_regression.sumstats.ht')
    >>> ht_results = hl.experimental.ld_score_regression(
    ...     weight_expr=ht_gwas['ld_score'],
    ...     ld_score_expr=ht_gwas['ld_score'],
    ...     chi_sq_exprs=[ht_gwas['chi_squared_50_irnt'],
    ...                        ht_gwas['chi_squared_20160']],
    ...     n_samples_exprs=[ht_gwas['n_50_irnt'],
    ...                      ht_gwas['n_20160']])

    Notes
    -----
    The ``exprs`` provided as arguments to :func:`.ld_score_regression`
    must all be from the same object, either a :class:`Table` or a
    :class:`MatrixTable`.

    **If the arguments originate from a table:**

    *  The table must be keyed by fields ``locus`` of type
       :class:`.tlocus` and ``alleles``, a :py:data:`.tarray` of
       :py:data:`.tstr` elements.
    *  ``weight_expr``, ``ld_score_expr``, ``chi_sq_exprs``, and
       ``n_samples_exprs`` must be row-indexed fields.
    *  The number of expressions passed to ``n_samples_exprs`` must be
       equal to one or the number of expressions passed to
       ``chi_sq_exprs``. If just one expression is passed to
       ``n_samples_exprs``, that sample size expression is assumed to
       apply to all sets of statistics passed to ``chi_sq_exprs``.
       Otherwise, the expressions passed to ``chi_sq_exprs`` and
       ``n_samples_exprs`` are matched by index.
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have generic :obj:`int` values
       ``0``, ``1``, etc. corresponding to the ``0th``, ``1st``, etc.
       expressions passed to the ``chi_sq_exprs`` argument.

    **If the arguments originate from a matrix table:**

    *  The dimensions of the matrix table must be variants
       (rows) by phenotypes (columns).
    *  The rows of the matrix table must be keyed by fields
       ``locus`` of type :class:`.tlocus` and ``alleles``,
       a :py:data:`.tarray` of :py:data:`.tstr` elements.
    *  The columns of the matrix table must be keyed by a field
       of type :py:data:`.tstr` that uniquely identifies phenotypes
       represented in the matrix table. The column key must be a single
       expression; compound keys are not accepted.
    *  ``weight_expr`` and ``ld_score_expr`` must be row-indexed
       fields.
    *  ``chi_sq_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  ``n_samples_exprs`` must be a single entry-indexed field
       (not a list of fields).
    *  The ``phenotype`` field that keys the table returned by
       :func:`.ld_score_regression` will have values corresponding to the
       column keys of the input matrix table.

    This function returns a :class:`Table` with one row per set of summary
    statistics passed to the ``chi_sq_exprs`` argument. The following
    row-indexed fields are included in the table:

    *  **phenotype** (:py:data:`.tstr`) -- The name of the phenotype. The
       returned table is keyed by this field. See the notes below for
       details on the possible values of this field.
    *  **mean_chi_sq** (:py:data:`.tfloat64`) -- The mean chi-squared
       test statistic for the given phenotype.
    *  **intercept** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          intercept :math:`1 + Na`.
       -  **standard_error**  (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    *  **snp_heritability** (`Struct`) -- Contains fields:

       -  **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
          SNP-heritability :math:`h_g^2`.
       -  **standard_error** (:py:data:`.tfloat64`) -- An estimate of
          the standard error of this point estimate.

    Warning
    -------
    :func:`.ld_score_regression` considers only the rows for which both row
    fields ``weight_expr`` and ``ld_score_expr`` are defined. Rows with missing
    values in either field are removed prior to fitting the LD score
    regression model.

    Parameters
    ----------
    weight_expr : :class:`.Float64Expression`
                  Row-indexed expression for the LD scores used to derive
                  variant weights in the model.
    ld_score_expr : :class:`.Float64Expression`
                    Row-indexed expression for the LD scores used as covariates
                    in the model.
    chi_sq_exprs : :class:`.Float64Expression` or :obj:`list` of
                        :class:`.Float64Expression`
                        One or more row-indexed (if table) or entry-indexed
                        (if matrix table) expressions for chi-squared
                        statistics resulting from genome-wide association
                        studies.
    n_samples_exprs : :class:`.NumericExpression` or :obj:`list` of
                     :class:`.NumericExpression`
                     One or more row-indexed (if table) or entry-indexed
                     (if matrix table) expressions indicating the number of
                     samples used in the studies that generated the test
                     statistics supplied to ``chi_sq_exprs``.
    n_blocks : :obj:`int`
               The number of blocks used in the jackknife approach to
               estimating standard errors.
    two_step_threshold : :obj:`int`
                         Variants with chi-squared statistics greater than this
                         value are excluded in the first step of the two-step
                         procedure used to fit the model.
    n_reference_panel_variants : :obj:`int`, optional
                                 Number of variants used to estimate the
                                 SNP-heritability :math:`h_g^2`.

    Returns
    -------
    :class:`.Table`
        Table keyed by ``phenotype`` with intercept and heritability estimates
        for each phenotype passed to the function."""

    chi_sq_exprs = wrap_to_list(chi_sq_exprs)
    n_samples_exprs = wrap_to_list(n_samples_exprs)

    assert ((len(chi_sq_exprs) == len(n_samples_exprs)) or
            (len(n_samples_exprs) == 1))
    __k = 2  # number of covariates, including intercept

    ds = chi_sq_exprs[0]._indices.source

    analyze('ld_score_regression/weight_expr',
            weight_expr,
            ds._row_indices)
    analyze('ld_score_regression/ld_score_expr',
            ld_score_expr,
            ds._row_indices)

    # format input dataset
    if isinstance(ds, MatrixTable):
        if len(chi_sq_exprs) != 1:
            raise ValueError("""Only one chi_sq_expr allowed if originating
                from a matrix table.""")
        if len(n_samples_exprs) != 1:
            raise ValueError("""Only one n_samples_expr allowed if
                originating from a matrix table.""")

        col_key = list(ds.col_key)
        if len(col_key) != 1:
            raise ValueError("""Matrix table must be keyed by a single
                phenotype field.""")

        analyze('ld_score_regression/chi_squared_expr',
                chi_sq_exprs[0],
                ds._entry_indices)
        analyze('ld_score_regression/n_samples_expr',
                n_samples_exprs[0],
                ds._entry_indices)

        ds = ds._select_all(row_exprs={'__locus': ds.locus,
                                       '__alleles': ds.alleles,
                                       '__w_initial': weight_expr,
                                       '__w_initial_floor': hl.max(weight_expr,
                                                                   1.0),
                                       '__x': ld_score_expr,
                                       '__x_floor': hl.max(ld_score_expr,
                                                           1.0)},
                            row_key=['__locus', '__alleles'],
                            col_exprs={'__y_name': ds[col_key[0]]},
                            col_key=['__y_name'],
                            entry_exprs={'__y': chi_sq_exprs[0],
                                         '__n': n_samples_exprs[0]})
        ds = ds.annotate_entries(**{'__w': ds.__w_initial})

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    else:
        assert isinstance(ds, Table)
        for y in chi_sq_exprs:
            analyze('ld_score_regression/chi_squared_expr', y, ds._row_indices)
        for n in n_samples_exprs:
            analyze('ld_score_regression/n_samples_expr', n, ds._row_indices)

        ys = ['__y{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ws = ['__w{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
        ns = ['__n{:}'.format(i) for i, _ in enumerate(n_samples_exprs)]

        ds = ds.select(**dict(**{'__locus': ds.locus,
                                 '__alleles': ds.alleles,
                                 '__w_initial': weight_expr,
                                 '__x': ld_score_expr},
                              **{y: chi_sq_exprs[i]
                                 for i, y in enumerate(ys)},
                              **{w: weight_expr for w in ws},
                              **{n: n_samples_exprs[i]
                                 for i, n in enumerate(ns)}))
        ds = ds.key_by(ds.__locus, ds.__alleles)

        table_tmp_file = new_temp_file()
        ds.write(table_tmp_file)
        ds = hl.read_table(table_tmp_file)

        hts = [ds.select(**{'__w_initial': ds.__w_initial,
                            '__w_initial_floor': hl.max(ds.__w_initial,
                                                        1.0),
                            '__x': ds.__x,
                            '__x_floor': hl.max(ds.__x, 1.0),
                            '__y_name': i,
                            '__y': ds[ys[i]],
                            '__w': ds[ws[i]],
                            '__n': hl.int(ds[ns[i]])})
               for i, y in enumerate(ys)]

        mts = [ht.to_matrix_table(row_key=['__locus',
                                           '__alleles'],
                                  col_key=['__y_name'],
                                  row_fields=['__w_initial',
                                              '__w_initial_floor',
                                              '__x',
                                              '__x_floor'])
               for ht in hts]

        ds = mts[0]
        for i in range(1, len(ys)):
            ds = ds.union_cols(mts[i])

        ds = ds.filter_rows(hl.is_defined(ds.__locus) &
                            hl.is_defined(ds.__alleles) &
                            hl.is_defined(ds.__w_initial) &
                            hl.is_defined(ds.__x))

    mt_tmp_file1 = new_temp_file()
    ds.write(mt_tmp_file1)
    mt = hl.read_matrix_table(mt_tmp_file1)

    if not n_reference_panel_variants:
        M = mt.count_rows()
    else:
        M = n_reference_panel_variants

    # block variants for each phenotype
    n_phenotypes = mt.count_cols()

    mt = mt.annotate_entries(__in_step1=(hl.is_defined(mt.__y) &
                                         (mt.__y < two_step_threshold)),
                             __in_step2=hl.is_defined(mt.__y))

    mt = mt.annotate_cols(__col_idx=hl.int(hl.scan.count()),
                          __m_step1=hl.agg.count_where(mt.__in_step1),
                          __m_step2=hl.agg.count_where(mt.__in_step2))

    col_keys = list(mt.col_key)

    ht = mt.localize_entries(entries_array_field_name='__entries',
                             columns_array_field_name='__cols')

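    # assign each step-1 entry to one of n_blocks contiguous jackknife
    # blocks via a scan over per-column counts of step-1 entries; entries
    # outside step 1 that land exactly on a block boundary are pushed into
    # the previous step-2 block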
    ht = ht.annotate(__entries=hl.rbind(
        hl.scan.array_agg(
            lambda entry: hl.scan.count_where(entry.__in_step1),
            ht.__entries),
        lambda step1_indices: hl.map(
            lambda i: hl.rbind(
                hl.int(hl.or_else(step1_indices[i], 0)),
                ht.__cols[i].__m_step1,
                ht.__entries[i],
                lambda step1_idx, m_step1, entry: hl.rbind(
                    hl.map(
                        lambda j: hl.int(hl.floor(j * (m_step1 / n_blocks))),
                        hl.range(0, n_blocks + 1)),
                    lambda step1_separators: hl.rbind(
                        hl.set(step1_separators).contains(step1_idx),
                        hl.sum(
                            hl.map(
                                lambda s1: step1_idx >= s1,
                                step1_separators)) - 1,
                        lambda is_separator, step1_block: entry.annotate(
                            __step1_block=step1_block,
                            __step2_block=hl.cond(~entry.__in_step1 & is_separator,
                                                  step1_block - 1,
                                                  step1_block))))),
            hl.range(0, hl.len(ht.__entries)))))

    mt = ht._unlocalize_entries('__entries', '__cols', col_keys)

    mt_tmp_file2 = new_temp_file()
    mt.write(mt_tmp_file2)
    mt = hl.read_matrix_table(mt_tmp_file2)
    
    # initial coefficient estimates
    mt = mt.annotate_cols(__initial_betas=[
        1.0, (hl.agg.mean(mt.__y) - 1.0) / hl.agg.mean(mt.__x)])
    mt = mt.annotate_cols(__step1_betas=mt.__initial_betas,
                          __step2_betas=mt.__initial_betas)

    # step 1 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step1,
            1.0/(mt.__w_initial_floor * 2.0 * (mt.__step1_betas[0] +
                                               mt.__step1_betas[1] *
                                               mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step1_betas=hl.agg.filter(
            mt.__in_step1,
            hl.agg.linreg(y=mt.__y,
                          x=[1.0, mt.__x],
                          weight=mt.__w).beta))
        mt = mt.annotate_cols(__step1_h2=hl.max(hl.min(
            mt.__step1_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step1_betas=[
            mt.__step1_betas[0],
            mt.__step1_h2 * hl.agg.mean(mt.__n) / M])

    # step 1 block jackknife
    mt = mt.annotate_cols(__step1_block_betas=[
        hl.agg.filter((mt.__step1_block != i) & mt.__in_step1,
                      hl.agg.linreg(y=mt.__y,
                                    x=[1.0, mt.__x],
                                    weight=mt.__w).beta)
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step1_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step1_betas - (n_blocks - 1) * x,
        mt.__step1_block_betas))
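    # these are delete-one-block jackknife pseudo-values,
    # B * beta_hat - (B - 1) * beta_hat_(-b) with B = n_blocks; their mean
    # and scaled variance below give the jackknife estimate and its
    # standard error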

    mt = mt.annotate_cols(
        __step1_jackknife_mean=hl.map(
            lambda i: hl.mean(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected)),
            hl.range(0, __k)),
        __step1_jackknife_variance=hl.map(
            lambda i: (hl.sum(
                hl.map(lambda x: x[i]**2,
                       mt.__step1_block_betas_bias_corrected)) -
                       hl.sum(
                hl.map(lambda x: x[i],
                       mt.__step1_block_betas_bias_corrected))**2 /
                       n_blocks) /
            (n_blocks - 1) / n_blocks,
            hl.range(0, __k)))

    # step 2 iteratively reweighted least squares
    for i in range(3):
        mt = mt.annotate_entries(__w=hl.cond(
            mt.__in_step2,
            1.0/(mt.__w_initial_floor *
                 2.0 * (mt.__step2_betas[0] +
                        mt.__step2_betas[1] *
                        mt.__x_floor)**2),
            0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            hl.agg.filter(mt.__in_step2,
                          hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                        x=[mt.__x],
                                        weight=mt.__w).beta[0])])
        mt = mt.annotate_cols(__step2_h2=hl.max(hl.min(
            mt.__step2_betas[1] * M/hl.agg.mean(mt.__n), 1.0), 0.0))
        mt = mt.annotate_cols(__step2_betas=[
            mt.__step1_betas[0],
            mt.__step2_h2 * hl.agg.mean(mt.__n)/M])

    # step 2 block jackknife
    mt = mt.annotate_cols(__step2_block_betas=[
        hl.agg.filter((mt.__step2_block != i) & mt.__in_step2,
                      hl.agg.linreg(y=mt.__y - mt.__step1_betas[0],
                                    x=[mt.__x],
                                    weight=mt.__w).beta[0])
        for i in range(n_blocks)])

    mt = mt.annotate_cols(__step2_block_betas_bias_corrected=hl.map(
        lambda x: n_blocks * mt.__step2_betas[1] - (n_blocks - 1) * x,
        mt.__step2_block_betas))

    mt = mt.annotate_cols(
        __step2_jackknife_mean=hl.mean(
            mt.__step2_block_betas_bias_corrected),
        __step2_jackknife_variance=(
            hl.sum(mt.__step2_block_betas_bias_corrected**2) -
            hl.sum(mt.__step2_block_betas_bias_corrected)**2 /
            n_blocks) / (n_blocks - 1) / n_blocks)

    # combine step 1 and step 2 block jackknifes
    mt = mt.annotate_entries(
        __step2_initial_w=1.0/(mt.__w_initial_floor *
                               2.0 * (mt.__initial_betas[0] +
                                      mt.__initial_betas[1] *
                                      mt.__x_floor)**2))

    mt = mt.annotate_cols(
        __final_betas=[
            mt.__step1_betas[0],
            mt.__step2_betas[1]],
        __c=(hl.agg.sum(mt.__step2_initial_w * mt.__x) /
             hl.agg.sum(mt.__step2_initial_w * mt.__x**2)))

    mt = mt.annotate_cols(__final_block_betas=hl.map(
        lambda i: (mt.__step2_block_betas[i] - mt.__c *
                   (mt.__step1_block_betas[i][0] - mt.__final_betas[0])),
        hl.range(0, n_blocks)))

    mt = mt.annotate_cols(
        __final_block_betas_bias_corrected=(n_blocks * mt.__final_betas[1] -
                                            (n_blocks - 1) *
                                            mt.__final_block_betas))

    mt = mt.annotate_cols(
        __final_jackknife_mean=[
            mt.__step1_jackknife_mean[0],
            hl.mean(mt.__final_block_betas_bias_corrected)],
        __final_jackknife_variance=[
            mt.__step1_jackknife_variance[0],
            (hl.sum(mt.__final_block_betas_bias_corrected**2) -
             hl.sum(mt.__final_block_betas_bias_corrected)**2 /
             n_blocks) / (n_blocks - 1) / n_blocks])

    # convert coefficient to heritability estimate
    mt = mt.annotate_cols(
        phenotype=mt.__y_name,
        mean_chi_sq=hl.agg.mean(mt.__y),
        intercept=hl.struct(
            estimate=mt.__final_betas[0],
            standard_error=hl.sqrt(mt.__final_jackknife_variance[0])),
        snp_heritability=hl.struct(
            estimate=(M/hl.agg.mean(mt.__n)) * mt.__final_betas[1],
            standard_error=hl.sqrt((M/hl.agg.mean(mt.__n))**2 *
                                   mt.__final_jackknife_variance[1])))

    # format and return results
    ht = mt.cols()
    ht = ht.key_by(ht.phenotype)
    ht = ht.select(ht.mean_chi_sq,
                   ht.intercept,
                   ht.snp_heritability)

    ht_tmp_file = new_temp_file()
    ht.write(ht_tmp_file)
    ht = hl.read_table(ht_tmp_file)
    
    return ht
Example n. 49
def mendel_errors(call, pedigree) -> Tuple[Table, Table, Table, Table]:
    r"""Find Mendel errors; count per variant, individual and nuclear family.

    .. include:: ../_templates/req_tstring.rst

    .. include:: ../_templates/req_tvariant.rst

    .. include:: ../_templates/req_biallelic.rst

    Examples
    --------

    Find all violations of Mendelian inheritance in each (dad, mom, kid) trio in
    a pedigree and return four tables (all errors, errors by family, errors by
    individual, errors by variant):

    >>> ped = hl.Pedigree.read('data/trios.fam')
    >>> all_errors, per_fam, per_sample, per_variant = hl.mendel_errors(dataset['GT'], ped)

    Export all mendel errors to a text file:

    >>> all_errors.export('output/all_mendel_errors.tsv')

    Annotate columns with the number of Mendel errors:

    >>> annotated_samples = dataset.annotate_cols(mendel=per_sample[dataset.s])

    Annotate rows with the number of Mendel errors:

    >>> annotated_variants = dataset.annotate_rows(mendel=per_variant[dataset.locus, dataset.alleles])

    Notes
    -----

    The example above returns four tables, which contain Mendelian violations
    grouped in various ways. These tables are modeled after the `PLINK mendel
    formats <https://www.cog-genomics.org/plink2/formats#mendel>`_, resembling
    the ``.mendel``, ``.fmendel``, ``.imendel``, and ``.lmendel`` formats,
    respectively.

    **First table:** all Mendel errors. This table contains one row per Mendel
    error, keyed by the variant and proband id.

        - `locus` (:class:`.tlocus`) -- Variant locus, key field.
        - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Variant alleles, key field.
        - (column key of `dataset`) (:py:data:`.tstr`) -- Proband ID, key field.
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `mendel_code` (:py:data:`.tint32`) -- Mendel error code, see below.

    **Second table:** errors per nuclear family. This table contains one row
    per nuclear family, keyed by the parents.

        - `pat_id` (:py:data:`.tstr`) -- Paternal ID. (key field)
        - `mat_id` (:py:data:`.tstr`) -- Maternal ID. (key field)
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `children` (:py:data:`.tint32`) -- Number of children in this nuclear family.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors in this nuclear family.
        - `snp_errors` (:py:data:`.tint64`) -- Number of Mendel errors at SNPs in this
          nuclear family.

    **Third table:** errors per individual. This table contains one row per
    individual. Each error is counted toward the proband, father, and mother
    according to the `Implicated` column in the table below.

        - (column key of `dataset`) (:py:data:`.tstr`) -- Sample ID (key field).
        - `fam_id` (:py:data:`.tstr`) -- Family ID.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors involving this
          individual.
        - `snp_errors` (:py:data:`.tint64`) -- Number of Mendel errors involving this
          individual at SNPs.

    **Fourth table:** errors per variant.

        - `locus` (:class:`.tlocus`) -- Variant locus, key field.
        - `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Variant alleles, key field.
        - `errors` (:py:data:`.tint64`) -- Number of Mendel errors in this variant.

    This method only considers complete trios (two parents and proband with
    defined sex). The code of each Mendel error is determined by the table
    below, extending the
    `Plink classification <https://www.cog-genomics.org/plink2/basic_stats#mendel>`__.

    In the table, the copy state of a locus with respect to a trio is defined
    as follows, where PAR is the `pseudoautosomal region
    <https://en.wikipedia.org/wiki/Pseudoautosomal_region>`__ of X and Y
    defined by the reference genome and the autosome is defined by
    :meth:`~hail.genetics.Locus.in_autosome`.

    - Auto -- in autosome or in PAR or female child
    - HemiX -- in non-PAR of X and male child
    - HemiY -- in non-PAR of Y and male child

    `Any` refers to the set \{ HomRef, Het, HomVar, NoCall \} and `~`
    denotes complement in this set.

    +------+---------+---------+--------+------------+---------------+
    | Code | Dad     | Mom     | Kid    | Copy State | Implicated    |
    +======+=========+=========+========+============+===============+
    |    1 | HomVar  | HomVar  | Het    | Auto       | Dad, Mom, Kid |
    +------+---------+---------+--------+------------+---------------+
    |    2 | HomRef  | HomRef  | Het    | Auto       | Dad, Mom, Kid |
    +------+---------+---------+--------+------------+---------------+
    |    3 | HomRef  | ~HomRef | HomVar | Auto       | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    4 | ~HomRef | HomRef  | HomVar | Auto       | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    5 | HomRef  | HomRef  | HomVar | Auto       | Kid           |
    +------+---------+---------+--------+------------+---------------+
    |    6 | HomVar  | ~HomVar | HomRef | Auto       | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    7 | ~HomVar | HomVar  | HomRef | Auto       | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |    8 | HomVar  | HomVar  | HomRef | Auto       | Kid           |
    +------+---------+---------+--------+------------+---------------+
    |    9 | Any     | HomVar  | HomRef | HemiX      | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   10 | Any     | HomRef  | HomVar | HemiX      | Mom, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   11 | HomVar  | Any     | HomRef | HemiY      | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
    |   12 | HomRef  | Any     | HomVar | HemiY      | Dad, Kid      |
    +------+---------+---------+--------+------------+---------------+
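
    As a concrete check of the table above, an autosomal locus where both
    parents are homozygous reference and the child is heterozygous yields
    code 2:

    >>> hl.eval(hl.mendel_error_code(hl.locus('1', 1000), True,
    ...                              hl.call(0, 0), hl.call(0, 0), hl.call(0, 1)))
    2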

    See Also
    --------
    :func:`.mendel_error_code`

    Parameters
    ----------
    dataset : :class:`.MatrixTable`
    pedigree : :class:`.Pedigree`

    Returns
    -------
    (:class:`.Table`, :class:`.Table`, :class:`.Table`, :class:`.Table`)
    """
    source = call._indices.source
    if not isinstance(source, MatrixTable):
        raise ValueError("'mendel_errors': expected 'call' to be an expression of 'MatrixTable', found {}".format(
            "expression of '{}'".format(source.__class__) if source is not None else 'scalar expression'))

    source = source.select_entries(__GT=call)
    dataset = require_biallelic(source, 'mendel_errors')
    tm = trio_matrix(dataset, pedigree, complete_trios=True)
    tm = tm.select_entries(mendel_code=hl.mendel_error_code(
        tm.locus,
        tm.is_female,
        tm.father_entry['__GT'],
        tm.mother_entry['__GT'],
        tm.proband_entry['__GT']
    ))
    ck_name = next(iter(source.col_key))
    tm = tm.filter_entries(hl.is_defined(tm.mendel_code))
    tm = tm.rename({'id': ck_name})

    entries = tm.entries()

    table1 = entries.select('fam_id', 'mendel_code')

    fam_counts = (
        entries
            .group_by(pat_id=entries.father[ck_name], mat_id=entries.mother[ck_name])
            .partition_hint(min(entries.n_partitions(), 8))
            .aggregate(children=hl.len(hl.agg.collect_as_set(entries[ck_name])),
                       errors=hl.agg.count_where(hl.is_defined(entries.mendel_code)),
                       snp_errors=hl.agg.count_where(hl.is_snp(entries.alleles[0], entries.alleles[1]) &
                                                     hl.is_defined(entries.mendel_code)))
    )
    table2 = tm.key_cols_by().cols()
    table2 = table2.select(pat_id=table2.father[ck_name],
                           mat_id=table2.mother[ck_name],
                           fam_id=table2.fam_id,
                           **fam_counts[table2.father[ck_name], table2.mother[ck_name]])
    table2 = table2.key_by('pat_id', 'mat_id').distinct()
    table2 = table2.annotate(errors=hl.or_else(table2.errors, hl.int64(0)),
                             snp_errors=hl.or_else(table2.snp_errors, hl.int64(0)))

    # in implicated, idx 0 is dad, idx 1 is mom, idx 2 is child
    implicated = hl.literal([
        [0, 0, 0],  # dummy
        [1, 1, 1],
        [1, 1, 1],
        [1, 0, 1],
        [0, 1, 1],
        [0, 0, 1],
        [1, 0, 1],
        [0, 1, 1],
        [0, 0, 1],
        [0, 1, 1],
        [0, 1, 1],
        [1, 0, 1],
        [1, 0, 1],
    ], dtype=hl.tarray(hl.tarray(hl.tint64)))

    table3 = tm.annotate_cols(all_errors=hl.or_else(hl.agg.array_sum(implicated[tm.mendel_code]), [0, 0, 0]),
                              snp_errors=hl.or_else(
                                  hl.agg.filter(hl.is_snp(tm.alleles[0], tm.alleles[1]),
                                                hl.agg.array_sum(implicated[tm.mendel_code])),
                                  [0, 0, 0])).key_cols_by().cols()

    table3 = table3.select(xs=[
        hl.struct(**{ck_name: table3.father[ck_name],
                     'fam_id': table3.fam_id,
                     'errors': table3.all_errors[0],
                     'snp_errors': table3.snp_errors[0]}),
        hl.struct(**{ck_name: table3.mother[ck_name],
                     'fam_id': table3.fam_id,
                     'errors': table3.all_errors[1],
                     'snp_errors': table3.snp_errors[1]}),
        hl.struct(**{ck_name: table3.proband[ck_name],
                     'fam_id': table3.fam_id,
                     'errors': table3.all_errors[2],
                     'snp_errors': table3.snp_errors[2]}),
    ])
    table3 = table3.explode('xs')
    table3 = table3.select(**table3.xs)
    table3 = (table3.group_by(ck_name, 'fam_id')
              .aggregate(errors=hl.agg.sum(table3.errors),
                         snp_errors=hl.agg.sum(table3.snp_errors))
              .key_by(ck_name))

    table4 = tm.select_rows(errors=hl.agg.count_where(hl.is_defined(tm.mendel_code))).rows()

    return table1, table2, table3, table4
Example no. 50
def call_stats(call, alleles) -> StructExpression:
    """Compute useful call statistics.

    Examples
    --------
    Compute call statistics per row:

    >>> dataset_result = dataset.annotate_rows(gt_stats = agg.call_stats(dataset.GT, dataset.alleles))
    >>> dataset_result.rows().key_by('locus').select('gt_stats').show()
    +---------------+--------------+---------------------+-------------+
    | locus         | gt_stats.AC  | gt_stats.AF         | gt_stats.AN |
    +---------------+--------------+---------------------+-------------+
    | locus<GRCh37> | array<int32> | array<float64>      |       int32 |
    +---------------+--------------+---------------------+-------------+
    | 20:12990057   | [148,52]     | [7.40e-01,2.60e-01] |         200 |
    | 20:13029862   | [0,198]      | [0.00e+00,1.00e+00] |         198 |
    | 20:13074235   | [13,187]     | [6.50e-02,9.35e-01] |         200 |
    | 20:13140720   | [194,6]      | [9.70e-01,3.00e-02] |         200 |
    | 20:13695498   | [175,25]     | [8.75e-01,1.25e-01] |         200 |
    | 20:13714384   | [199,1]      | [9.95e-01,5.00e-03] |         200 |
    | 20:13765944   | [132,2]      | [9.85e-01,1.49e-02] |         134 |
    | 20:13765954   | [180,2]      | [9.89e-01,1.10e-02] |         182 |
    | 20:13845987   | [2,198]      | [1.00e-02,9.90e-01] |         200 |
    | 20:16223957   | [145,45]     | [7.63e-01,2.37e-01] |         190 |
    +---------------+--------------+---------------------+-------------+
    <BLANKLINE>
    +---------------------------+
    | gt_stats.homozygote_count |
    +---------------------------+
    | array<int32>              |
    +---------------------------+
    | [57,9]                    |
    | [0,99]                    |
    | [1,88]                    |
    | [95,1]                    |
    | [75,0]                    |
    | [99,0]                    |
    | [65,0]                    |
    | [89,0]                    |
    | [0,98]                    |
    | [64,14]                   |
    +---------------------------+
    showing top 10 rows
    <BLANKLINE>

    Notes
    -----
    This method is meaningful for computing call metrics per variant, but not
    especially meaningful for computing metrics per sample.

    This method returns a struct expression with four fields:

     - `AC` (:class:`.tarray` of :py:data:`.tint32`) - Allele counts. One element
       for each allele, including the reference.
     - `AF` (:class:`.tarray` of :py:data:`.tfloat64`) - Allele frequencies. One
       element for each allele, including the reference.
     - `AN` (:py:data:`.tint32`) - Allele number. The total number of called
       alleles; for diploid calls, this is twice the number of non-missing
       calls.
     - `homozygote_count` (:class:`.tarray` of :py:data:`.tint32`) - Homozygote
       genotype counts for each allele, including the reference. Only **diploid**
       genotype calls are counted.
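
    A common downstream use is filtering on the computed statistics, e.g.
    keeping rows whose first alternate allele is rare (the threshold is
    illustrative):

    >>> rare = dataset_result.filter_rows(dataset_result.gt_stats.AF[1] < 0.01)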

    Parameters
    ----------
    call : :class:`.CallExpression`
    alleles : :class:`.ArrayStringExpression`
        Variant alleles.

    Returns
    -------
    :class:`.StructExpression`
        Struct expression with fields `AC`, `AF`, `AN`, and `homozygote_count`.
    """
    n_alleles = hl.len(alleles)
    t = tstruct(AC=tarray(tint32),
                AF=tarray(tfloat64),
                AN=tint32,
                homozygote_count=tarray(tint32))

    return _agg_func('CallStats', [call], t, [], init_op_args=[n_alleles])
Example no. 51
def sparse_split_multi(sparse_mt):
    """Splits multiallelic variants on a sparse MatrixTable.

    Takes a dataset formatted like the output of :func:`.vcf_combiner`. The
    splitting will add `was_split` and `a_index` fields, as :func:`.split_multi`
    does. This function drops the `LA` (local alleles) field, as it re-computes
    entry fields based on the new, split global alleles.

    Variants are split thus:

    - A row with only one (reference) or two (reference and alternate) alleles
      passes through unsplit, with `a_index` set to 1 and `was_split` to
      ``False``.

    - A row with multiple alternate alleles will be split, with one row for
      each alternate allele, and each row will contain two alleles: ref and alt.
      The reference and alternate allele will be minrepped using
      :func:`.min_rep`.
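
    For instance, :func:`.min_rep` trims a shared suffix, so an allele pair
    like ``['TAA', 'TA']`` is reduced to ``['TA', 'T']`` at the same position:

    >>> hl.eval(hl.min_rep(hl.locus('1', 100), ['TAA', 'TA']).alleles)
    ['TA', 'T']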

    The split multi logic handles the following entry fields:

        .. code-block:: text

          struct {
            LGT: call
            LAD: array<int32>
            DP: int32
            GQ: int32
            LPL: array<int32>
            RGQ: int32
            LPGT: call
            LA: array<int32>
            END: int32
          }

    All fields except for `LA` are optional, and only handled if they exist.

    - `LA` is used to find the corresponding local allele index for the desired
      global `a_index`, and then dropped from the resulting dataset. If `LA`
      does not contain the global `a_index`, the index for the `<NON_REF>`
      allele is used to process the entry fields.

    - `LGT` and `LPGT` are downcoded using the corresponding local `a_index`.
      They are renamed to `GT` and `PGT` respectively, as the resulting call is
      no longer local (see the downcoding sketch after this list).

    - `LAD` is used to create an `AD` field consisting of the allele depths
      corresponding to the reference, global `a_index` allele, and `<NON_REF>`
      allele.

    - `DP` is preserved unchanged.

    - `GQ` is recalculated from the updated `PL`, if it exists, but otherwise
      preserved unchanged.

    - `PL` array elements are calculated from the minimum `LPL` value for all
      allele pairs that downcode to the desired one. (This logic is identical to
      the `PL` logic in :func:`.split_multi_hts`.) If a row has an alternate
      allele but it is not present in `LA`, the `PL` field is set to missing;
      the `PL` for `ref/<NON_REF>` in that case can be drawn from `RGQ`.

    - `RGQ` (the ref genotype quality) is preserved unchanged.

    - `END` is untouched.
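
    The downcoding used for `LGT` and `LPGT` can be illustrated directly:
    downcoding sets every allele other than the kept one to reference, so a
    ``0/2`` call downcoded to allele 1 becomes homozygous reference:

    >>> hl.eval(hl.downcode(hl.call(0, 2), 1))
    Call(alleles=[0, 0], phased=False)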

    Notes
    -----
    This version of split-multi doesn't handle duplicate loci, in which case
    the explode could result in out-of-order rows (the standard
    :func:`.split_multi` doesn't handle that case either).

    It also checks that min-repping will not change the locus, and raises an
    error if it would. (The VCF combiner is believed to enforce this invariant
    currently.)

    Parameters
    ----------
    sparse_mt : :class:`.MatrixTable`
        Sparse MatrixTable to split.

    Returns
    -------
    :class:`.MatrixTable`
        The split MatrixTable in sparse format.

    """

    hl.methods.misc.require_row_key_variant(sparse_mt, "sparse_split_multi")

    entries = hl.utils.java.Env.get_uid()
    cols = hl.utils.java.Env.get_uid()
    ds = sparse_mt.localize_entries(entries, cols)
    new_id = hl.utils.java.Env.get_uid()

    def struct_from_min_rep(i):
        return hl.bind(lambda mr:
                       (hl.case()
                        .when(ds.locus == mr.locus,
                              hl.struct(
                                  locus=ds.locus,
                                  alleles=[mr.alleles[0], mr.alleles[1]],
                                  a_index=i,
                                  was_split=True))
                        .or_error("Found non-left-aligned variant in sparse_split_multi")),
                       hl.min_rep(ds.locus, [ds.alleles[0], ds.alleles[i]]))

    explode_structs = hl.cond(hl.len(ds.alleles) < 3,
                              [hl.struct(
                                  locus=ds.locus,
                                  alleles=ds.alleles,
                                  a_index=1,
                                  was_split=False)],
                              hl._sort_by(
                                  hl.range(1, hl.len(ds.alleles))
                                      .map(struct_from_min_rep),
                                  lambda l, r: hl._compare(l.alleles, r.alleles) < 0
                              ))

    ds = ds.annotate(**{new_id: explode_structs}).explode(new_id)

    def transform_entries(old_entry):
        def with_local_a_index(local_a_index):
            new_pl = hl.or_missing(
                hl.is_defined(old_entry.LPL),
                hl.or_missing(
                    hl.is_defined(local_a_index),
                    hl.range(0, 3).map(lambda i: hl.min(
                        hl.range(0, hl.triangle(hl.len(old_entry.LA)))
                            .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), local_a_index) == hl.unphased_diploid_gt_index_call(i))
                            .map(lambda idx: old_entry.LPL[idx])))))
            fields = set(old_entry.keys())

            def with_pl(pl):
                new_exprs = {}
                dropped_fields = ['LA']
                if 'LGT' in fields:
                    new_exprs['GT'] = hl.downcode(old_entry.LGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LGT')
                if 'LPGT' in fields:
                    new_exprs['PGT'] = hl.downcode(old_entry.LPGT, hl.or_else(local_a_index, hl.len(old_entry.LA)))
                    dropped_fields.append('LPGT')
                if 'LAD' in fields:
                    new_exprs['AD'] = hl.or_missing(
                        hl.is_defined(old_entry.LAD),
                        [old_entry.LAD[0], hl.or_else(old_entry.LAD[local_a_index], 0)]) # second entry zeroed for lack of non-ref AD
                    dropped_fields.append('LAD')
                if 'LPL' in fields:
                    new_exprs['PL'] = pl
                    if 'GQ' in fields:
                        new_exprs['GQ'] = hl.or_else(hl.gq_from_pl(pl), old_entry.GQ)

                    dropped_fields.append('LPL')

                return hl.cond(hl.len(ds.alleles) == 1,
                               old_entry.annotate(**{f[1:]: old_entry[f] for f in ['LGT', 'LPGT', 'LAD', 'LPL'] if f in fields}).drop(*dropped_fields),
                               old_entry.annotate(**new_exprs).drop(*dropped_fields))

            if 'LPL' in fields:
                return hl.bind(with_pl, new_pl)
            else:
                return with_pl(None)

        # find the local allele index corresponding to the desired global
        # a_index, or missing if the global allele is absent from LA
        lai = hl.fold(lambda accum, elt:
                      hl.cond(old_entry.LA[elt] == ds[new_id].a_index,
                              elt, accum),
                      hl.null(hl.tint32),
                      hl.range(0, hl.len(old_entry.LA)))
        return hl.bind(with_local_a_index, lai)

    new_row = ds.row.annotate(**{
        'locus': ds[new_id].locus,
        'alleles': ds[new_id].alleles,
        'a_index': ds[new_id].a_index,
        'was_split': ds[new_id].was_split,
        entries: ds[entries].map(transform_entries)
    }).drop(new_id)

    ds = hl.Table(
        hl.ir.TableKeyBy(
            hl.ir.TableMapRows(
                hl.ir.TableKeyBy(ds._tir, ['locus']),
                new_row._ir),
            ['locus', 'alleles'],
            is_sorted=True))
    return ds._unlocalize_entries(entries, cols, list(sparse_mt.col_key.keys()))
Example no. 52
File: qc.py Project: jigold/hail
def variant_qc(mt, name='variant_qc') -> MatrixTable:
    """Compute common variant statistics (quality control metrics).

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    >>> dataset_result = hl.variant_qc(dataset)

    Notes
    -----
    This method computes variant statistics from the genotype data, returning
    a new struct field `name` with the following metrics based on the fields
    present in the entry schema.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `AF` (``array<float64>``) -- Calculated allele frequency, one element
      per allele, including the reference. Sums to one. Equivalent to
      `AC` / `AN`.
    - `AC` (``array<int32>``) -- Calculated allele count, one element per
      allele, including the reference. Sums to `AN`.
    - `AN` (``int32``) -- Total number of called alleles.
    - `homozygote_count` (``array<int32>``) -- Number of homozygotes per
      allele. One element per allele, including the reference.
    - `call_rate` (``float64``) -- Fraction of calls neither missing nor filtered.
      Equivalent to `n_called` / :meth:`.count_cols`.
    - `n_called` (``int64``) -- Number of samples with a defined `GT`.
    - `n_not_called` (``int64``) -- Number of samples with a missing `GT`.
    - `n_filtered` (``int64``) -- Number of filtered entries.
    - `n_het` (``int64``) -- Number of heterozygous samples.
    - `n_non_ref` (``int64``) -- Number of samples with at least one called
      non-reference allele.
    - `het_freq_hwe` (``float64``) -- Expected frequency of heterozygous
      samples under Hardy-Weinberg equilibrium. See
      :func:`.functions.hardy_weinberg_test` for details.
    - `p_value_hwe` (``float64``) -- p-value from test of Hardy-Weinberg equilibrium.
      See :func:`.functions.hardy_weinberg_test` for details.
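
    These fields are typically used for row filtering, e.g. keeping well-called,
    polymorphic variants (thresholds are illustrative):

    >>> dataset_result = hl.variant_qc(dataset)
    >>> filtered = dataset_result.filter_rows(
    ...     (dataset_result.variant_qc.call_rate > 0.95) &
    ...     (dataset_result.variant_qc.AF[1] > 0))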

    Warning
    -------
    `het_freq_hwe` and `p_value_hwe` are calculated as in
    :func:`.functions.hardy_weinberg_test`, with non-diploid calls
    (``ploidy != 2``) ignored in the counts. As this test is only
    statistically rigorous in the biallelic setting, :func:`.variant_qc`
    sets both fields to missing for multiallelic variants. Consider using
    :func:`~hail.methods.split_multi` to split multi-allelic variants beforehand.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
    """
    require_row_key_variant(mt, 'variant_qc')

    bound_exprs = {}
    gq_dp_exprs = {}

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    if has_field_of_type('DP', hl.tint32):
        gq_dp_exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        gq_dp_exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'variant_qc': expected an entry field 'GT' of type 'call'")

    bound_exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    bound_exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT']))
    bound_exprs['n_filtered'] = mt.count_cols(_localize=False) - hl.agg.count()
    bound_exprs['call_stats'] = hl.agg.call_stats(mt.GT, mt.alleles)

    result = hl.rbind(hl.struct(**bound_exprs),
                      lambda e1: hl.rbind(
                          hl.case().when(hl.len(mt.alleles) == 2,
                                         hl.hardy_weinberg_test(e1.call_stats.homozygote_count[0],
                                                                e1.call_stats.AC[1] - 2 *
                                                                e1.call_stats.homozygote_count[1],
                                                                e1.call_stats.homozygote_count[1])
                                         ).or_missing(),
                          lambda hwe: hl.struct(**{
                              **gq_dp_exprs,
                              **e1.call_stats,
                              'call_rate': hl.float(e1.n_called) / (e1.n_called + e1.n_not_called + e1.n_filtered),
                              'n_called': e1.n_called,
                              'n_not_called': e1.n_not_called,
                              'n_filtered': e1.n_filtered,
                              'n_het': e1.n_called - hl.sum(e1.call_stats.homozygote_count),
                              'n_non_ref': e1.n_called - e1.call_stats.homozygote_count[0],
                              'het_freq_hwe': hwe.het_freq_hwe,
                              'p_value_hwe': hwe.p_value})))

    return mt.annotate_rows(**{name: result})