Example 1
def stats_split_mt(
        mt: hl.MatrixTable
) -> Tuple[List[int], hl.MatrixTable, hl.MatrixTable]:
    """
    Collect basic stat counts and split the MatrixTable by phenotype status (for plots)
    :param mt: Hail MatrixTable
    :return: basic stat counts, cases MatrixTable, controls MatrixTable
    """
    # 1. Sex
    female_ids: List[str] = mt.filter_cols(mt.is_female).s.collect()
    male_ids: List[str] = mt.filter_cols(~mt.is_female).s.collect()
    sex_missing_ids: List[str] = mt.filter_cols(
        hl.is_missing(mt.is_female)).s.collect()

    # 2. Phenotype status
    mt_cases: hl.MatrixTable = mt.filter_cols(mt.is_case)
    n_cases: int = mt_cases.count_cols()
    mt_controls: hl.MatrixTable = mt.filter_cols(~mt.is_case)
    n_controls: int = mt_controls.count_cols()
    unknown_pheno_ids: List[str] = mt.filter_cols(
        hl.is_missing(mt.is_case)).s.collect()

    # 3. Number of SNPs
    n_snps: int = mt.count_rows()

    counts: List[int] = [
        len(male_ids),
        len(female_ids),
        len(sex_missing_ids),
        n_cases,
        n_controls,
        len(unknown_pheno_ids),
        n_snps,
    ]

    return counts, mt_cases, mt_controls
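
A minimal usage sketch (hypothetical input path; any MatrixTable with an `s` sample key and `is_female` / `is_case` column fields works):

import hail as hl

mt = hl.read_matrix_table("data/example_qc.mt")  # hypothetical path
counts, mt_cases, mt_controls = stats_split_mt(mt)
n_males, n_females, n_sex_missing, n_cases, n_controls, n_unknown_pheno, n_snps = counts
print(f"{n_cases} cases, {n_controls} controls, {n_snps} SNPs")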
Example 2
def format_gnomad_mt(chrom, out_dir="/gpfs/ycga/scratch60/kahle/sp2349/datasets/gnomad/"):
    # Import the gnomAD exomes table and modify it for easier manipulation
    gnomad_e = hl.read_table('/gpfs/ycga/scratch60/kahle/sp2349/combined_weilai_mts/combined.filtered.gnomad.r2.1.1.sites.{}.mt'.format(chrom))

    # If the bravo annotation is NA, set the value to 0
    gnomad_e = gnomad_e.annotate(bravo_freeze8=hl.if_else(hl.is_missing(gnomad_e.bravo_freeze8), 0.0, gnomad_e.bravo_freeze8))

    # Check for missing MetaSVM, CADD, or Bravo frequencies. If the exome value is missing, use the genome value and vice versa.
    # Joining the tables produces NA values for positions missing in either dataset (genomes have some positions missing in exomes, etc.).
    # NOTE: `combined`, the exome/genome join with exome fields in `info` and genome fields in `info_1`, is assumed to be built in code not shown in this snippet.
    combined = combined.annotate(MetaSVM_pred=hl.if_else(hl.is_missing(combined.info.MetaSVM_pred), combined.info_1.MetaSVM_pred, combined.info.MetaSVM_pred))
    combined = combined.annotate(CADD16snv_PHRED=hl.if_else(hl.is_missing(combined.info.CADD16snv_PHRED), hl.if_else(hl.is_missing(combined.info_1.CADD16snv_PHRED), hl.null('float64'), hl.float64(combined.info_1.CADD16snv_PHRED)), hl.float64(combined.info.CADD16snv_PHRED)))
    combined = combined.annotate(CADD_phred=hl.if_else(hl.is_missing(combined.info.CADD13_PHRED), hl.if_else(hl.is_missing(combined.info_1.CADD13_PHRED), hl.null('float64'), hl.float64(combined.info_1.CADD13_PHRED)), hl.float64(combined.info.CADD13_PHRED)))
    combined = combined.annotate(MPC=hl.if_else(hl.is_missing(combined.info.MPC), hl.if_else(hl.is_missing(combined.info_1.MPC_score), hl.null('float64'), hl.float64(combined.info_1.MPC_score)), hl.float64(combined.info.MPC)))
    combined = combined.annotate(bravo=hl.if_else(hl.is_missing(combined.info.bravo), combined.info_1.bravo, combined.info.bravo))

    # Add genes and function (splicing/non-synonymous).
    # If the exome annotation is missing (NA), use the genome annotation.
    combined = combined.annotate(Exonic_refGene=hl.if_else(hl.is_missing(combined.info["ExonicFunc.refGene"][0]), hl.if_else(hl.is_missing(combined.info_1["ExonicFunc.refGene"][0]), hl.null('str'), combined.info_1["ExonicFunc.refGene"][0]), combined.info["ExonicFunc.refGene"][0]))
    combined = combined.annotate(Func_refGene=hl.if_else(hl.is_missing(combined.info["Func.refGene"][0]), hl.if_else(hl.is_missing(combined.info_1["Func.refGene"][0]), hl.null('str'), combined.info_1["Func.refGene"][0]), combined.info["Func.refGene"][0]))

    # Make NA values 0 if the position is missing in exomes
    combined = combined.annotate(info=combined.info.annotate(non_topmed_AC=hl.if_else(hl.is_missing(combined.info.non_topmed_AC), 0, combined.info.non_topmed_AC)))
    combined = combined.annotate(info=combined.info.annotate(non_topmed_AN=hl.if_else(hl.is_missing(combined.info.non_topmed_AN), 0, combined.info.non_topmed_AN)))
    combined = combined.annotate(info=combined.info.annotate(non_topmed_nhomalt=hl.if_else(hl.is_missing(combined.info.non_topmed_nhomalt), 0, combined.info.non_topmed_nhomalt)))

    # Make NA values 0 if the position is missing in genomes
    combined = combined.annotate(info_1=combined.info_1.annotate(non_topmed_AC=hl.if_else(hl.is_missing(combined.info_1.non_topmed_AC), 0, combined.info_1.non_topmed_AC)))
Example 3
def load_cmg(cmg_csv: str) -> hl.Table:
    cmg_ht = hl.import_table(cmg_csv, impute=True, delimiter=",", quote='"')

    cmg_ht = cmg_ht.transmute(
        locus1_b38=hl.locus("chr" + hl.str(cmg_ht.chrom_1), cmg_ht.pos_1, reference_genome='GRCh38'),
        alleles1_b38=[cmg_ht.ref_1, cmg_ht.alt_1],
        locus2_b38=hl.locus("chr" + hl.str(cmg_ht.chrom_2), cmg_ht.pos_2, reference_genome='GRCh38'),
        alleles2_b38=[cmg_ht.ref_2, cmg_ht.alt_2]
    )

    liftover_references = get_liftover_genome(cmg_ht.rename({'locus1_b38': 'locus'}))
    lifted_over_variants = hl.sorted(
        hl.array([
            liftover_expr(cmg_ht.locus1_b38, cmg_ht.alleles1_b38, liftover_references[1]),
            liftover_expr(cmg_ht.locus2_b38, cmg_ht.alleles2_b38, liftover_references[1])
        ]),
        lambda x: x.locus
    )

    cmg_ht = cmg_ht.key_by(
        locus1=lifted_over_variants[0].locus,
        alleles1=lifted_over_variants[0].alleles,
        locus2=lifted_over_variants[1].locus,
        alleles2=lifted_over_variants[1].alleles
    )

    return cmg_ht.annotate(
        bad_liftover=(
                hl.is_missing(cmg_ht.locus1) |
                hl.is_missing(cmg_ht.locus2) |
                (cmg_ht.locus1.sequence_context() != cmg_ht.alleles1[0][0]) |
                (cmg_ht.locus2.sequence_context() != cmg_ht.alleles2[0][0])
        )
    )
Example 4
def filter_low_conf_regions(
    mt: Union[hl.MatrixTable, hl.Table],
    filter_lcr: bool = True,
    filter_decoy: bool = True,
    filter_segdup: bool = True,
    filter_exome_low_coverage_regions: bool = False,
    high_conf_regions: Optional[List[str]] = None,
) -> Union[hl.MatrixTable, hl.Table]:
    """
    Filters low-confidence regions

    :param mt: MatrixTable or Table to filter
    :param filter_lcr: Whether to filter LCR regions
    :param filter_decoy: Whether to filter decoy regions
    :param filter_segdup: Whether to filter Segdup regions
    :param filter_exome_low_coverage_regions: Whether to filter exome low confidence regions
    :param high_conf_regions: Paths to set of high confidence regions to restrict to (union of regions)
    :return: MatrixTable or Table with low confidence regions removed
    """
    build = get_reference_genome(mt.locus).name
    if build == "GRCh37":
        import gnomad.resources.grch37.reference_data as resources
    elif build == "GRCh38":
        import gnomad.resources.grch38.reference_data as resources

    criteria = []
    if filter_lcr:
        lcr = resources.lcr_intervals.ht()
        criteria.append(hl.is_missing(lcr[mt.locus]))

    if filter_decoy:
        decoy = resources.decoy_intervals.ht()
        criteria.append(hl.is_missing(decoy[mt.locus]))

    if filter_segdup:
        segdup = resources.seg_dup_intervals.ht()
        criteria.append(hl.is_missing(segdup[mt.locus]))

    if filter_exome_low_coverage_regions:
        high_cov = resources.high_coverage_intervals.ht()
        criteria.append(hl.is_missing(high_cov[mt.locus]))

    if high_conf_regions is not None:
        for region in high_conf_regions:
            region = hl.import_locus_intervals(region)
            criteria.append(hl.is_defined(region[mt.locus]))

    if criteria:
        filter_criteria = functools.reduce(operator.iand, criteria)
        if isinstance(mt, hl.MatrixTable):
            mt = mt.filter_rows(filter_criteria)
        else:
            mt = mt.filter(filter_criteria)

    return mt
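
A usage sketch, assuming the gnomad package's reference resources are available for the dataset's build; paths are hypothetical:

import hail as hl

mt = hl.read_matrix_table("data/raw_variants.mt")  # hypothetical path
mt = filter_low_conf_regions(
    mt,
    filter_lcr=True,
    filter_decoy=True,
    filter_segdup=True,
    high_conf_regions=["data/high_conf.interval_list"],  # hypothetical path
)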
Example 5
    def test_import_bgen_dosage_and_gp_dosage_function_agree(self):
        recoding = {'0{}'.format(i): str(i) for i in range(1, 10)}

        sample_file = resource('example.sample')
        bgen_file = resource('example.8bits.bgen')
        hl.index_bgen(bgen_file, contig_recoding=recoding)

        bgenmt = hl.import_bgen(bgen_file, ['GP', 'dosage'], sample_file)
        et = bgenmt.entries()
        et = et.transmute(gp_dosage=hl.gp_dosage(et.GP))
        self.assertTrue(
            et.all((hl.is_missing(et.dosage) & hl.is_missing(et.gp_dosage))
                   | (hl.abs(et.dosage - et.gp_dosage) < 1e-6)))
Example 6
def export_updated_phenos(num_pops=None):
    old_manifest = hl.import_table(get_pheno_manifest_path(),
                                   key=['trait_type','phenocode','pheno_sex',
                                        'coding','modifier'],
                                   impute=True)
    new_manifest = make_pheno_manifest(export=False)
    joined_manifest = old_manifest.join(new_manifest, how='outer')
    joined_manifest = joined_manifest.annotate(pheno_id = get_pheno_id(tb=joined_manifest))
    
    pheno_ids_new_phenos = joined_manifest.filter(hl.is_missing(joined_manifest.pops)).pheno_id.collect()
    pheno_ids_to_update = joined_manifest.filter(joined_manifest.pops!=joined_manifest.pops_1).pheno_id.collect()
    pheno_ids_new_phenos_str = '\n'.join(pheno_ids_new_phenos)
    pheno_ids_to_update_str = '\n'.join(pheno_ids_to_update)
    print(f'\n\nNew phenotypes to be exported:\n{pheno_ids_new_phenos_str}')
    print(f'\nUpdated phenotypes to be exported:\n{pheno_ids_to_update_str}')
    print(f'\n> Number of new phenotypes to be exported: {len(pheno_ids_new_phenos)}')
    print(f'\n> Number of phenotypes to be updated: {len(pheno_ids_to_update)}')
    print(f'\n> Total number of phenotypes to be exported: {len(pheno_ids_to_update+pheno_ids_new_phenos)}\n')
    
    # identify phenotypes that should be removed from previous results
    to_remove = joined_manifest.filter((hl.is_missing(joined_manifest.pops_1)))
    pheno_ids_to_remove = get_pheno_id(tb=to_remove).collect()
    pheno_ids_to_remove_str = '\n'.join(pheno_ids_to_remove)
    print(f'\nPhenotypes to remove:\n{pheno_ids_to_remove_str}')
    print(f'\n> Number of phenotypes to be removed: {len(pheno_ids_to_remove)}\n')
    
    # filtered to phenotypes that need to be updated (either completely new or a different set of populations)
    to_export = joined_manifest.filter((joined_manifest.pops!=joined_manifest.pops_1)|
                                       (hl.is_missing(joined_manifest.pops)))
    to_export.select('pops', 'pops_1').show(int(1e6))

    mt0 = get_final_sumstats_mt_for_export()
    mt0 = mt0.filter_cols(hl.is_defined(to_export[mt0.col_key]))

    if num_pops is None:
        num_pops_set = set(to_export.num_pops_1.collect()) # get set of num_pops to run
        print(f'num_pops set: {num_pops_set}')
        for num_pops in num_pops_set:
            export_results(num_pops=num_pops, 
                           trait_types='all', 
                           batch_size=256, 
                           mt=mt0, 
                           export_path_str='update',
                           skip_binary_eur=False)
    else: # useful if parallelizing num_pops over multiple clusters
        export_results(num_pops=num_pops, 
                       trait_types='all', 
                       batch_size=256, 
                       mt=mt0, 
                       export_path_str='update',
                       skip_binary_eur=False)
Example 8
def remap_samples(
    original_mt_path: str,
    input_mt: hl.MatrixTable,
    pedigree: hl.Table,
    inferred_sex: str,
) -> Tuple[hl.MatrixTable, hl.Table]:
    """
    Rename `s` col in the MatrixTable and inferred sex ht.

    :param original_mt_path: Path to original MatrixTable location
    :param input_mt: MatrixTable 
    :param pedigree: Pedigree file from seqr loaded as a Hail Table
    :param inferred_sex: Path to text file of inferred sexes
    :return: mt and sex ht with sample names remapped
    """
    base_path = "/".join(
        dirname(original_mt_path).split("/")[:-1]) + ("/base/projects")
    project_list = list(set(pedigree.Project_GUID.collect()))

    # Get the list of hts containing sample remapping information for each project
    remap_hts = []
    sex_ht = hl.import_table(inferred_sex)

    for i in project_list:
        remap = f"{base_path}/{i}/{i}_remap.tsv"
        if hl.hadoop_is_file(remap):
            remap_ht = hl.import_table(remap)
            remap_ht = remap_ht.key_by("s", "seqr_id")
            remap_hts.append(remap_ht)

    logger.info("Found %d projects that need to be remapped.", len(remap_hts))

    if len(remap_hts) > 0:
        ht = remap_hts[0]
        for next_ht in remap_hts[1:]:
            ht = ht.join(next_ht, how="outer")

        # If a sample has a non-missing value for seqr_id, rename it to the sample name for the mt and sex ht
        ht = ht.key_by("s")
        input_mt = input_mt.annotate_cols(seqr_id=ht[input_mt.s].seqr_id)
        input_mt = input_mt.key_cols_by(s=hl.if_else(
            hl.is_missing(input_mt.seqr_id), input_mt.s, input_mt.seqr_id))

        sex_ht = sex_ht.annotate(seqr_id=ht[sex_ht.s].seqr_id).key_by("s")
        sex_ht = sex_ht.key_by(s=hl.if_else(hl.is_missing(sex_ht.seqr_id),
                                            sex_ht.s, sex_ht.seqr_id))
    else:
        sex_ht = sex_ht.key_by("s")

    return input_mt, sex_ht
Example 9
def compute_missingness(
    t: Union[hl.MatrixTable, hl.Table],
    info_metrics: List[str],
    non_info_metrics: List[str],
    n_sites: int,
    missingness_threshold: float,
) -> None:
    """
    Check amount of missingness in all row annotations.

    Print a metric to stdout if the fraction of missing values for that metric exceeds missingness_threshold.

    :param t: Input MatrixTable or Table.
    :param info_metrics: List of metrics in info struct of input Table.
    :param non_info_metrics: List of row annotations minus info struct from input Table.
    :param n_sites: Number of sites in input Table.
    :param missingness_threshold: Upper cutoff for allowed amount of missingness.
    :return: None
    """
    t = t.rows() if isinstance(t, hl.MatrixTable) else t

    logger.info(
        "Missingness threshold (upper cutoff for what is allowed for missingness checks): %.2f",
        missingness_threshold,
    )
    metrics_missing = {}
    for x in info_metrics:
        metrics_missing[x] = hl.agg.sum(hl.is_missing(t.info[x]))
    for x in non_info_metrics:
        metrics_missing[x] = hl.agg.sum(hl.is_missing(t[x]))
    output = dict(t.aggregate(hl.struct(**metrics_missing)))

    n_fail = 0
    for metric, n_missing in output.items():
        if n_missing / n_sites > missingness_threshold:
            logger.info(
                "FAILED missingness check for %s: %d sites or %.2f%% missing",
                metric,
                n_missing,
                (100 * n_missing / n_sites),
            )
            n_fail += 1
        else:
            logger.info(
                "Passed missingness check for %s: %d sites or %.2f%% missing",
                metric,
                n_missing,
                (100 * n_missing / n_sites),
            )
    logger.info("%d missing metrics checks failed", n_fail)
Example 10
def remove_telomeres_centromes(
        t: Union[hl.Table, hl.MatrixTable]) -> Union[hl.Table, hl.MatrixTable]:
    """
    Remove sites overlapping telomere and centromere regions

    :param t: MatrixTable or Table to filter
    :return: MatrixTable or Table
    """
    tc_intervals = get_telomeres_and_centromeres_ht(overwrite=False)

    if isinstance(t, hl.MatrixTable):
        t = t.filter_rows(hl.is_missing(tc_intervals[t.locus]))
    else:
        t = t.filter(hl.is_missing(tc_intervals[t.locus]))
    return t
Example 11
    def _get_most_severe_csq(csq_list: hl.expr.ArrayExpression,
                             protein_coding: bool) -> hl.expr.StructExpression:
        """
        Processes VEP consequences to generate summary annotations.

        :param csq_list: VEP consequences list to be processed.
        :param protein_coding: Whether variant is in a protein-coding transcript.
        :return: Struct containing summary annotations.
        """
        lof = hl.null(hl.tstr)
        no_lof_flags = hl.null(hl.tbool)
        if protein_coding:
            all_lofs = csq_list.map(lambda x: x.lof)
            lof = hl.literal(loftee_labels).find(
                lambda x: all_lofs.contains(x))
            csq_list = hl.if_else(hl.is_defined(lof),
                                  csq_list.filter(lambda x: x.lof == lof),
                                  csq_list)
            no_lof_flags = hl.or_missing(
                hl.is_defined(lof),
                csq_list.any(lambda x:
                             (x.lof == lof) & hl.is_missing(x.lof_flags)),
            )
        all_csq_terms = csq_list.flatmap(lambda x: x.consequence_terms)
        most_severe_csq = hl.literal(csq_order).find(
            lambda x: all_csq_terms.contains(x))
        return hl.struct(
            most_severe_csq=most_severe_csq,
            protein_coding=protein_coding,
            lof=lof,
            no_lof_flags=no_lof_flags,
        )
Example 12
def hwe_normalized_pca(
        qc_mt: hl.MatrixTable,
        related_samples_to_drop: Optional[hl.Table] = None,
        n_pcs: int = 10
) -> Tuple[List[float], hl.Table, hl.Table]:
    """
    First runs PCA excluding the given related samples,
    then projects these samples in the PC space to return scores for all samples.
    The `related_samples_to_drop` Table has to be keyed by the sample ID and all samples present in this
    table will be excluded from the PCA.
    The loadings Table returned also contains a `pca_af` annotation which is the allele frequency
    used for PCA. This is useful to project other samples in the PC space.
    :param qc_mt: Input QC MT
    :param related_samples_to_drop: Optional table of related samples to drop
    :param n_pcs: Number of PCs to compute
    :return: eigenvalues, scores and loadings
    """
    unrelated_mt = qc_mt

    if related_samples_to_drop:
        unrelated_mt = qc_mt.filter_cols(hl.is_missing(related_samples_to_drop[qc_mt.col_key]))

    pca_evals, pca_scores, pca_loadings = hl.hwe_normalized_pca(unrelated_mt.GT, k=n_pcs, compute_loadings=True)
    pca_af_ht = unrelated_mt.annotate_rows(pca_af=hl.agg.mean(unrelated_mt.GT.n_alt_alleles()) / 2).rows()
    pca_loadings = pca_loadings.annotate(pca_af=pca_af_ht[pca_loadings.key].pca_af)

    if not related_samples_to_drop:
        return pca_evals, pca_scores, pca_loadings
    else:
        related_mt = qc_mt.filter_cols(hl.is_defined(related_samples_to_drop[qc_mt.col_key]))
        related_scores = pc_project(related_mt, pca_loadings)
        pca_scores = pca_scores.union(related_scores)
        return pca_evals, pca_scores, pca_loadings
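
A calling sketch; `related_ht` is assumed to be keyed by sample ID (e.g. built with hl.maximal_independent_set over inferred relative pairs), and paths are hypothetical:

import hail as hl

qc_mt = hl.read_matrix_table("data/qc_variants.mt")  # hypothetical path
related_ht = hl.read_table("data/related_samples_to_drop.ht")  # hypothetical path
eigenvalues, scores_ht, loadings_ht = hwe_normalized_pca(
    qc_mt, related_samples_to_drop=related_ht, n_pcs=10)
scores_ht.export("data/pca_scores.tsv")  # hypothetical output path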
Example 13
def ref_filtering(ref_mt,
                  pass_mt,
                  unrel,
                  outliers,
                  pass_unrel_mt,
                  overwrite: bool = False):
    mt = hl.read_matrix_table(ref_mt)
    all_sample_filters = set(mt['sample_filters'])
    bad_sample_filters = {
        re.sub('fail_', '', x)
        for x in all_sample_filters if x.startswith('fail_')
    }
    mt_filt = mt.filter_cols(mt['sample_filters']['qc_metrics_filters'].
                             difference(bad_sample_filters).length() == 0)
    mt_filt = mt_filt.checkpoint(pass_mt,
                                 overwrite=False,
                                 _read_if_exists=True)

    mt_unrel = hl.read_matrix_table(unrel)

    mt_filt = mt_filt.filter_rows(
        mt_filt.filters.length() == 0)  # gnomAD QC pass variants
    mt_filt = mt_filt.filter_cols(hl.is_defined(
        mt_unrel.cols()[mt_filt.s]))  # only unrelated

    # remove outliers
    pca_outliers = hl.import_table(outliers).key_by('s')
    mt_filt = mt_filt.filter_cols(hl.is_missing(pca_outliers[mt_filt.s]))

    mt_filt.write(pass_unrel_mt, overwrite)
Example 14
def pull_out_worst_from_tx_annotate(mt):
    csq_order = []
    for loftee_filter in ["HC", "LC"]:
        for no_flag in [True, False]:
            for consequence in CSQ_CODING_HIGH_IMPACT:
                csq_order.append((loftee_filter, no_flag, consequence))

    # prioritization of missense and synonymous variants on protein-coding transcripts
    csq_order.extend([(hl.null(hl.tstr), True, x)
                      for x in CSQ_CODING_MEDIUM_IMPACT + CSQ_CODING_LOW_IMPACT
                      ])

    # Any variant on a non-protein-coding transcript (i.e. where LOF is null)
    csq_order.extend([(hl.null(hl.tstr), True, x)
                      for x in CSQ_CODING_HIGH_IMPACT +
                      CSQ_CODING_MEDIUM_IMPACT + CSQ_CODING_LOW_IMPACT])

    csq_order = hl.literal({x: i for i, x in enumerate(csq_order)})

    mt = mt.annotate_rows(**hl.sorted(
        mt.tx_annotation,
        key=lambda x: csq_order[
            (x.lof, hl.is_missing(x.lof_flag), x.csq)])[0])

    return mt
Example 15
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        info=hl.struct(
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB_TABLE=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB_TABLE[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                          .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(lambda e: renumber_entry(e, old_to_new)),
                            hl.array([0]).extend(
                                hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                    lambda j: combined_allele_index[tmp.data[i].alleles[j]]))))),
            hl.dict(hl.range(1, hl.len(tmp.alleles) + 1).map(
                lambda j: hl.tuple([tmp.alleles[j - 1], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))

    return tmp.drop('data', 'g')
Example 16
def remap_sample_ids(mt: hl.MatrixTable, remap_path: str):
    """
    Maps the MatrixTable's sample ID field ('s') to the sample ID used within seqr ('seqr_id').

    If the sample does not have a mapping in the remap file, their ID becomes their seqr ID.

    :param hl.MatrixTable mt: Input MatrixTable.
    :param str remap_path: Path to a file with two columns: 's' and 'seqr_id'.
    :return: MatrixTable with VCF sample IDs mapped to seqr IDs and keyed with seqr ID.
    :rtype: hl.MatrixTable
    """
    remap_ht = hl.import_table(remap_path, key="s")
    missing_samples = remap_ht.anti_join(mt.cols()).collect()
    remap_count = remap_ht.count()

    if len(missing_samples) != 0:
        logger.error(
            f"Only {remap_ht.semi_join(mt.cols()).count()} out of {remap_count} "
            "remap IDs matched IDs in the variant callset.\n"
            f"IDs that aren't in the callset: {missing_samples}\n"
            f"All callset sample IDs: {mt.s.collect()}"
        )

    mt = mt.annotate_cols(**remap_ht[mt.s])
    remap_expr = hl.cond(hl.is_missing(mt.seqr_id), mt.s, mt.seqr_id)
    mt = mt.annotate_cols(seqr_id=remap_expr, vcf_id=mt.s)
    mt = mt.key_cols_by(s=mt.seqr_id)
    logger.info(f"Remapped {remap_count} sample ids...")
    return mt
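
A usage sketch, assuming a tab-delimited remap file with header columns `s` and `seqr_id`; paths are hypothetical:

import hail as hl

mt = hl.import_vcf("data/callset.vcf.bgz", reference_genome="GRCh38")  # hypothetical path
mt = remap_sample_ids(mt, "data/remap.tsv")  # hypothetical path
# Columns are now keyed by seqr ID; the original VCF sample ID is preserved in `vcf_id`.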
Example 17
    def test_annotate_intervals(self):
        ds = get_dataset()

        bed1 = hl.import_bed(resource('example1.bed'), reference_genome='GRCh37')
        bed2 = hl.import_bed(resource('example2.bed'), reference_genome='GRCh37')
        bed3 = hl.import_bed(resource('example3.bed'), reference_genome='GRCh37')
        self.assertTrue(list(bed2.key.dtype) == ['interval'])
        self.assertTrue(list(bed2.row.dtype) == ['interval', 'target'])

        interval_list1 = hl.import_locus_intervals(resource('exampleAnnotation1.interval_list'))
        interval_list2 = hl.import_locus_intervals(resource('exampleAnnotation2.interval_list'))
        self.assertTrue(list(interval_list2.key.dtype) == ['interval'])
        self.assertTrue(list(interval_list2.row.dtype) == ['interval', 'target'])

        ann = ds.annotate_rows(in_interval=bed1[ds.locus]).rows()
        self.assertTrue(ann.all((ann.locus.position <= 14000000) |
                                (ann.locus.position >= 17000000) |
                                (hl.is_missing(ann.in_interval))))

        for bed in [bed2, bed3]:
            ann = ds.annotate_rows(target=bed[ds.locus].target).rows()
            expr = (hl.case()
                    .when(ann.locus.position <= 14000000, ann.target == 'gene1')
                    .when(ann.locus.position >= 17000000, ann.target == 'gene2')
                    .default(ann.target == hl.null(hl.tstr)))
            self.assertTrue(ann.all(expr))

        self.assertTrue(ds.annotate_rows(in_interval=interval_list1[ds.locus]).rows()
                        ._same(ds.annotate_rows(in_interval=bed1[ds.locus]).rows()))

        self.assertTrue(ds.annotate_rows(target=interval_list2[ds.locus].target).rows()
                        ._same(ds.annotate_rows(target=bed2[ds.locus].target).rows()))
Example 18
def add_coding_information(
        mt: hl.MatrixTable,
        coding_ht: hl.Table,
        phesant_phenotype_info_path: str,
        download_missing_codings: bool = False) -> hl.MatrixTable:
    """
    Add coding information from coding_ht as column annotations into mt

    :param MatrixTable mt: Input MT
    :param Table coding_ht: HT with coding information
    :param str phesant_phenotype_info_path: PHESANT phenotype metadata path
    :param bool download_missing_codings: Whether to download missing coding data
    :return: MT with coding information in column data
    :rtype: MatrixTable
    """
    mt = mt.annotate_cols(**coding_ht[(mt.coding_id, hl.str(mt.coding))])
    if download_missing_codings:
        get_missing_codings(mt.cols())
    phesant_summary = hl.import_table(phesant_phenotype_info_path,
                                      impute=True,
                                      missing='',
                                      key='FieldID')
    phesant_reassign = get_phesant_reassignments(phesant_summary)
    mt = mt.annotate_cols(recoding=hl.or_missing(
        hl.is_missing(mt.meaning), phesant_reassign[mt.col_key.select(
            'phenocode', 'coding')].reassign_from))
    return mt.annotate_cols(
        **hl.cond(hl.is_defined(mt.meaning),
                  hl.struct(**{x: mt[x]
                               for x in list(coding_ht.row_value)}),
                  coding_ht[(mt.coding_id, hl.str(mt.recoding))]))
Example 19
def impute_sex_aggregator(call,
                          aaf,
                          aaf_threshold=0.0,
                          include_par=False,
                          female_threshold=0.4,
                          male_threshold=0.8) -> hl.expr.StructExpression:
    """:func:`.impute_sex` as an aggregator; returns an aggregation expression rather than a Table."""
    mt = call._indices.source
    rg = mt.locus.dtype.reference_genome
    x_contigs = hl.literal(
        hl.eval(
            hl.map(lambda x_contig: hl.parse_locus_interval(x_contig, rg),
                   rg.x_contigs)))
    inbreeding = hl.agg.inbreeding(call, aaf)
    is_female = hl.if_else(
        inbreeding.f_stat < female_threshold, True,
        hl.if_else(inbreeding.f_stat > male_threshold, False,
                   hl.null(hl.tbool)))
    expression = hl.struct(is_female=is_female, **inbreeding)
    if not include_par:
        interval_type = hl.tarray(hl.tinterval(hl.tlocus(rg)))
        par_intervals = hl.literal(rg.par, interval_type)
        expression = hl.agg.filter(
            ~par_intervals.any(
                lambda par_interval: par_interval.contains(mt.locus)),
            expression)
    expression = hl.agg.filter(
        (aaf > aaf_threshold) & (aaf < (1 - aaf_threshold)), expression)
    expression = hl.agg.filter(
        x_contigs.any(lambda contig: contig.contains(mt.locus)), expression)

    return expression
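
Because the return value is an aggregation expression, it composes directly with annotate_cols; a minimal sketch, assuming biallelic calls in `mt.GT` and a hypothetical input path:

import hail as hl

mt = hl.read_matrix_table("data/example.mt")  # hypothetical path
mt = mt.annotate_rows(aaf=hl.agg.call_stats(mt.GT, mt.alleles).AF[1])
mt = mt.annotate_cols(imputed_sex=impute_sex_aggregator(mt.GT, mt.aaf))
mt.cols().select('imputed_sex').show()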
Example 20
def filter_to_clinvar_pathogenic(
    t: Union[hl.MatrixTable, hl.Table],
    clnrevstat_field: str = "CLNREVSTAT",
    clnsig_field: str = "CLNSIG",
    clnsigconf_field: str = "CLNSIGCONF",
    remove_no_assertion: bool = True,
    remove_conflicting: bool = True,
) -> Union[hl.MatrixTable, hl.Table]:
    """
    Return a MatrixTable or Table that filters the clinvar data to pathogenic and likely pathogenic variants.

    Example use:

    .. code-block:: python

        from gnomad.resources.grch38.reference_data import clinvar
        clinvar_ht = clinvar.ht()
        clinvar_ht = filter_to_clinvar_pathogenic(clinvar_ht)

    :param t: Input dataset that contains clinvar data, could either be a MatrixTable or Table.
    :param clnrevstat_field: The field string for the expression that contains the review status of the clinical significance of clinvar variants.
    :param clnsig_field: The field string for the expression that contains the clinical significance of the clinvar variant.
    :param clnsigconf_field: The field string for the expression that contains the conflicting clinical significance values for the variant. For variants with no conflicting significance, this field should be undefined.
    :param remove_no_assertion: Flag for removing entries in which the clnrevstat (clinical significance) has no assertions (zero stars).
    :param remove_conflicting: Flag for removing entries with conflicting clinical interpretations.
    :return: Filtered MatrixTable or Table
    """
    logger.info(
        "Found %d variants before filtering",
        t.count_rows() if isinstance(t, hl.MatrixTable) else t.count(),
    )
    path_expr = (t.info[clnsig_field].map(lambda x: x.lower()).map(
        lambda x: x.contains("pathogenic")).any(lambda x: x))

    if remove_no_assertion:
        logger.info("Variants without assertions will be removed.")
        no_star_assertions = hl.literal({
            "no_assertion_provided",
            "no_assertion_criteria_provided",
            "no_interpretation_for_the_individual_variant",
        })
        path_expr = path_expr & (hl.set(t.info[clnrevstat_field]).intersection(
            no_star_assertions).length() == 0)

    if remove_conflicting:
        logger.info(
            "Variants with conflicting clinical interpretations will be removed."
        )
        path_expr = path_expr & hl.is_missing(t.info[clnsigconf_field])

    if isinstance(t, hl.MatrixTable):
        t = t.filter_rows(path_expr)
    else:
        t = t.filter(path_expr)

    logger.info(
        "Found %d variants after filtering to clinvar pathogenic variants.",
        t.count_rows() if isinstance(t, hl.MatrixTable) else t.count(),
    )
    return t
Example 21
    def filter(self, mt):
        mt = mt.annotate_rows(aaf=hl.agg.call_stats(mt.GT, mt.alleles).AF[1])

        row_filter = mt[self._row_filter].filters if self._row_filter else mt.exclude_row
        col_filter = mt[self._col_filter].filters if self._col_filter else mt.exclude_col

        pre_filter = row_filter | col_filter

        # sex warnings are for ambiguous genotypes (F_male < 0.8, F_female > 0.2) and undefined phenotypes
        mt = mt.annotate_cols(**{
            'sex_ambiguous': hl.struct(
                filters=hl.agg.filter(~pre_filter,
                                      ((hl.agg.filter(mt.is_female,
                                                      impute_sex_aggregator(mt.GT, mt.aaf).f_stat)) > self._fstat_y) |
                                      ((hl.agg.filter(~mt.is_female,
                                                      impute_sex_aggregator(mt.GT, mt.aaf).f_stat)) < self._fstat_x))
            )})

        if 'is_case' in mt.col:
            mt = mt.annotate_cols(**{
                'sex_warnings': hl.struct(
                    filters=(hl.agg.any(mt['sex_ambiguous'].filters) |
                             hl.agg.any(hl.is_missing(mt.is_case))))})
        else:
            mt = mt.annotate_cols(**{
                'sex_warnings': hl.struct(
                    filters=hl.agg.any(mt['sex_ambiguous'].filters))})

        return mt
Example 22
    def test_concordance(self):
        dataset = get_dataset()
        glob_conc, cols_conc, rows_conc = hl.concordance(dataset, dataset)

        self.assertEqual(sum([sum(glob_conc[i]) for i in range(5)]), dataset.count_rows() * dataset.count_cols())

        counts = dataset.aggregate_entries(hl.Struct(n_het=agg.filter(dataset.GT.is_het(), agg.count()),
                                                     n_hom_ref=agg.filter(dataset.GT.is_hom_ref(),
                                                                          agg.count()),
                                                     n_hom_var=agg.filter(dataset.GT.is_hom_var(),
                                                                          agg.count()),
                                                     nNoCall=agg.filter(hl.is_missing(dataset.GT),
                                                                        agg.count())))

        self.assertEqual(glob_conc[0][0], 0)
        self.assertEqual(glob_conc[1][1], counts.nNoCall)
        self.assertEqual(glob_conc[2][2], counts.n_hom_ref)
        self.assertEqual(glob_conc[3][3], counts.n_het)
        self.assertEqual(glob_conc[4][4], counts.n_hom_var)
        [self.assertEqual(glob_conc[i][j], 0) for i in range(5) for j in range(5) if i != j]

        self.assertTrue(cols_conc.all(hl.sum(hl.flatten(cols_conc.concordance)) == dataset.count_rows()))
        self.assertTrue(rows_conc.all(hl.sum(hl.flatten(rows_conc.concordance)) == dataset.count_cols()))

        cols_conc.write('/tmp/foo.kt', overwrite=True)
        rows_conc.write('/tmp/foo.kt', overwrite=True)
Example 23
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], '')).fold(
                lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref: hl.rbind(
                alleles.map(lambda al: hl.rbind(
                    al[0], lambda r: hl.array([ref]).
                    extend(al[1:].map(lambda a: hl.rbind(
                        _num_allele_type(r, a), lambda at: hl.cond(
                            (_allele_ints['SNP'] == at)
                            | (_allele_ints['Insertion'] == at)
                            | (_allele_ints['Deletion'] == at)
                            | (_allele_ints['MNP'] == at)
                            | (_allele_ints['Complex'] == at), a + ref[hl.len(
                                r):], a)))))), lambda lal: hl.
                struct(globl=hl.array([ref]).extend(
                    hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                       local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl: hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)), lambda
                alleles: hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)
                                 ),
                    __entries=hl.bind(
                        lambda combined_allele_index: hl.
                        range(0, hl.len(row.data)).flatmap(lambda i: hl.cond(
                            hl.is_missing(row.data[i].__entries),
                            hl.range(0, hl.len(gbl.g[i].__cols)).map(
                                lambda _: hl.null(row.data[i].__entries.dtype.
                                                  element_type)),
                            hl.bind(
                                lambda old_to_new: row.data[i].__entries.map(
                                    lambda e: renumber_entry(e, old_to_new)),
                                hl.range(0, hl.len(alleles.local[i])).map(
                                    lambda j: combined_allele_index[
                                        alleles.local[i][j]])))),
                        hl.dict(
                            hl.range(0, hl.len(alleles.globl)).map(
                                lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(
        TableMapRows(
            ts._tir,
            Apply(merge_function._name, merge_function._ret_type,
                  TopLevelReference('row'), TopLevelReference('global'))))
    return ts.transmute_globals(
        __cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
Example 24
def combine(ts):
    # pylint: disable=protected-access
    tmp = ts.annotate(
        alleles=merge_alleles(ts.data.map(lambda d: d.alleles)),
        rsid=hl.find(hl.is_defined, ts.data.map(lambda d: d.rsid)),
        filters=hl.set(hl.flatten(ts.data.map(lambda d: hl.array(d.filters)))),
        info=hl.struct(
            DP=hl.sum(ts.data.map(lambda d: d.info.DP)),
            MQ_DP=hl.sum(ts.data.map(lambda d: d.info.MQ_DP)),
            QUALapprox=hl.sum(ts.data.map(lambda d: d.info.QUALapprox)),
            RAW_MQ=hl.sum(ts.data.map(lambda d: d.info.RAW_MQ)),
            VarDP=hl.sum(ts.data.map(lambda d: d.info.VarDP)),
            SB=hl.array([
                hl.sum(ts.data.map(lambda d: d.info.SB[0])),
                hl.sum(ts.data.map(lambda d: d.info.SB[1])),
                hl.sum(ts.data.map(lambda d: d.info.SB[2])),
                hl.sum(ts.data.map(lambda d: d.info.SB[3]))
            ])))
    tmp = tmp.annotate(
        __entries=hl.bind(
            lambda combined_allele_index:
            hl.range(0, hl.len(tmp.data)).flatmap(
                lambda i:
                hl.cond(hl.is_missing(tmp.data[i].__entries),
                        hl.range(0, hl.len(tmp.g[i].__cols))
                          .map(lambda _: hl.null(tmp.data[i].__entries.dtype.element_type)),
                        hl.bind(
                            lambda old_to_new: tmp.data[i].__entries.map(lambda e: renumber_entry(e, old_to_new)),
                            hl.range(0, hl.len(tmp.data[i].alleles)).map(
                                lambda j: combined_allele_index[tmp.data[i].alleles[j]])))),
            hl.dict(hl.range(0, hl.len(tmp.alleles)).map(
                lambda j: hl.tuple([tmp.alleles[j], j])))))
    tmp = tmp.annotate_globals(__cols=hl.flatten(tmp.g.map(lambda g: g.__cols)))

    return tmp.drop('data', 'g')
Example 25
    def remap_sample_ids(mt, remap_path):
        """
        Remap the MatrixTable's sample ID, 's', field to the sample ID used within seqr, 'seqr_id'
        If the sample 's' does not have a 'seqr_id' in the remap file, 's' becomes 'seqr_id'
        :param mt: MatrixTable from VCF
        :param remap_path: Path to a file with two columns 's' and 'seqr_id'
        :return: MatrixTable remapped and keyed to use seqr_id
        """
        remap_ht = hl.import_table(remap_path, key='s')
        missing_samples = remap_ht.anti_join(mt.cols()).collect()
        remap_count = remap_ht.count()

        if len(missing_samples) != 0:
            raise MatrixTableSampleSetError(
                f'Only {remap_ht.semi_join(mt.cols()).count()} out of {remap_count} '
                'remap IDs matched IDs in the variant callset.\n'
                f'IDs that aren\'t in the callset: {missing_samples}\n'
                f'All callset sample IDs: {mt.s.collect()}', missing_samples)

        mt = mt.annotate_cols(**remap_ht[mt.s])
        remap_expr = hl.cond(hl.is_missing(mt.seqr_id), mt.s, mt.seqr_id)
        mt = mt.annotate_cols(seqr_id=remap_expr, vcf_id=mt.s)
        mt = mt.key_cols_by(s=mt.seqr_id)
        logger.info(f'Remapped {remap_count} sample ids...')
        return mt
Example 26
    def get_contig_size(contig: str) -> int:
        logger.info(f"Working on {contig}")
        contig_ht = hl.utils.range_table(
            ref.contig_length(contig),
            n_partitions=int(ref.contig_length(contig) / 500_000),
        )
        contig_ht = contig_ht.annotate(
            locus=hl.locus(contig=contig, pos=contig_ht.idx + 1, reference_genome=ref)
        )
        contig_ht = contig_ht.filter(contig_ht.locus.sequence_context().lower() != "n")

        if contig in ref.x_contigs:
            contig_ht = contig_ht.filter(contig_ht.locus.in_x_nonpar())
        if contig in ref.y_contigs:
            contig_ht = contig_ht.filter(contig_ht.locus.in_y_nonpar())

        contig_ht = contig_ht.key_by("locus")
        if included_calling_intervals is not None:
            contig_ht = contig_ht.filter(
                hl.is_defined(included_calling_intervals[contig_ht.key])
            )
        if excluded_calling_intervals is not None:
            contig_ht = contig_ht.filter(
                hl.is_missing(excluded_calling_intervals[contig_ht.key])
            )
        contig_size = contig_ht.count()
        logger.info(f"Contig {contig} has {contig_size} bases for coverage.")
        return contig_size
Example 28
def conditional_phenotypes(mt: hl.MatrixTable,
                           column_field,
                           entry_field,
                           lists_of_columns,
                           new_col_name='grouping',
                           new_entry_name='new_entry'):
    """
    Create a conditional phenotype by setting phenotype1 to missing for any individual without phenotype2.

    Pheno1 Pheno2 new_pheno
    T      T      T
    T      F      NA
    F      F      NA
    F      T      F

    `lists_of_columns` should be a list of lists (of length 2 for the inner list).
    The first element corresponds to the phenotype to maintain, except for setting to missing when the
    phenotype coded by the second element is False.

    new_entry = Pheno1 conditioned on having Pheno2

    Example:

    mt = hl.balding_nichols_model(1, 3, 10).drop('GT')
    mt = mt.annotate_entries(pheno=hl.rand_bool(0.5))
    lists_of_columns = [[0, 1], [2, 1]]
    entry_field = mt.pheno
    column_field = mt.sample_idx

    :param MatrixTable mt: Input MatrixTable
    :param Expression column_field: Column-indexed Expression to group by
    :param Expression entry_field: Entry-indexed Expression to which to apply `grouping_function`
    :param list of list lists_of_columns: Entry in this list should be the same type as `column_field`
    :param str new_col_name: Name for new column key (default 'grouping')
    :param str new_entry_name: Name for new entry expression (default 'new_entry')
    :return: Re-grouped MatrixTable
    :rtype: MatrixTable
    """
    assert all([len(x) == 2 for x in lists_of_columns])
    lists_of_columns = hl.literal(lists_of_columns)
    mt = mt._annotate_all(col_exprs={'_col_expr': column_field},
                          entry_exprs={'_entry_expr': entry_field})
    mt = mt.annotate_cols(
        _col_expr=lists_of_columns.filter(lambda x: x.contains(
            mt._col_expr)).map(lambda y: (y, y[0] == mt._col_expr)))
    mt = mt.explode_cols('_col_expr')
    # If this column is the conditioning phenotype (second element of the pair, so _col_expr[1]
    # is False) and its value is False, return missing; otherwise keep the actual entry value.
    bool_array = hl.agg.collect(
        hl.if_else(~mt._col_expr[1] & ~mt._entry_expr, hl.null(hl.tbool),
                   mt._entry_expr))
    # if any element is missing, return missing. otherwise return first element
    return mt.group_cols_by(**{
        new_col_name: mt._col_expr[0]
    }).aggregate(
        **{
            new_entry_name:
            hl.if_else(hl.any(lambda x: hl.is_missing(x), bool_array),
                       hl.null(hl.tbool), bool_array[0] & bool_array[1])
        })
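
Completing the docstring's example as a runnable sketch (phenotypes for sample_idx 0 and 2 are each conditioned on sample_idx 1):

import hail as hl

mt = hl.balding_nichols_model(1, 3, 10).drop('GT')
mt = mt.annotate_entries(pheno=hl.rand_bool(0.5))
out = conditional_phenotypes(mt, mt.sample_idx, mt.pheno,
                             lists_of_columns=[[0, 1], [2, 1]])
out.new_entry.show()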
Example 29
def pre_process_subset_freq(subset: str,
                            global_ht: hl.Table,
                            test: bool = False) -> hl.Table:
    """
    Prepare subset frequency Table by filling in missing frequency fields for loci present only in the global cohort.

    .. note::

        The resulting final `freq` array will be as long as the subset `freq_meta` global (i.e., one `freq` entry for each `freq_meta` entry)

    :param subset: subset ID
    :param global_ht: Hail Table containing all variants discovered in the overall release cohort
    :param test: If True, filter to small region on chr20
    :return: Table containing subset frequencies with missing freq structs filled in
    """

    # Read in subset HTs
    subset_ht_path = get_freq(subset=subset).path
    subset_chr20_ht_path = qc_temp_prefix() + f"chr20_test_freq.{subset}.ht"

    if test:
        if file_exists(subset_chr20_ht_path):
            logger.info(
                "Loading chr20 %s subset frequency data for testing: %s",
                subset,
                subset_chr20_ht_path,
            )
            subset_ht = hl.read_table(subset_chr20_ht_path)

        elif file_exists(subset_ht_path):
            logger.info(
                "Loading %s subset frequency data for testing: %s",
                subset,
                subset_ht_path,
            )
            subset_ht = hl.read_table(subset_ht_path)
            subset_ht = hl.filter_intervals(
                subset_ht, [hl.parse_locus_interval("chr20:1-1000000")])

    elif file_exists(subset_ht_path):
        logger.info("Loading %s subset frequency data: %s", subset,
                    subset_ht_path)
        subset_ht = hl.read_table(subset_ht_path)

    else:
        raise DataException(
            f"Hail Table containing {subset} subset frequencies not found. You may need to run the script generate_freq_data.py to generate frequency annotations first."
        )

    # Fill in missing freq structs
    ht = subset_ht.join(global_ht.select().select_globals(), how="right")
    ht = ht.annotate(freq=hl.if_else(
        hl.is_missing(ht.freq),
        hl.map(lambda x: missing_callstats_expr(),
               hl.range(hl.len(ht.freq_meta))),
        ht.freq,
    ))

    return ht
Example 30
def filter_low_conf_regions(
        mt: Union[hl.MatrixTable, hl.Table],
        filter_lcr: bool = True,
        filter_segdup: bool = True) -> Union[hl.MatrixTable, hl.Table]:
    """
    Filters low-confidence regions

    :param mt: MatrixTable or Table to filter
    :param filter_lcr: Whether to filter LCR regions
    # :param filter_decoy: Whether to filter decoy regions
    :param filter_segdup: Whether to filter Segdup regions
    # :param filter_exome_low_coverage_regions: Whether to filter exome low confidence regions
    # :param high_conf_regions: Paths to set of high confidence regions to restrict to (union of regions)
    :return: MatrixTable or Table with low confidence regions removed
    """

    criteria = []
    if filter_lcr:
        lcr = get_lcr_ht()
        criteria.append(hl.is_missing(lcr[mt.locus]))

    if filter_segdup:
        segdup = get_segdups_ht()
        criteria.append(hl.is_missing(segdup[mt.locus]))

    # if filter_decoy:
    #    decoy = resources.decoy_intervals.ht()
    #    criteria.append(hl.is_missing(decoy[mt.locus]))

    # if filter_exome_low_coverage_regions:
    #    high_cov = resources.high_coverage_intervals.ht()
    #    criteria.append(hl.is_missing(high_cov[mt.locus]))

    # if high_conf_regions is not None:
    #    for region in high_conf_regions:
    #        region = hl.import_locus_intervals(region)
    #        criteria.append(hl.is_defined(region[mt.locus]))

    if criteria:
        filter_criteria = functools.reduce(operator.iand, criteria)
        if isinstance(mt, hl.MatrixTable):
            mt = mt.filter_rows(filter_criteria)
        else:
            mt = mt.filter(filter_criteria)

    return mt
Example 31
    def test_interval_join(self):
        left = hl.utils.range_table(50, n_partitions=10)
        intervals = hl.utils.range_table(4)
        intervals = intervals.key_by(interval=hl.interval(intervals.idx * 10, intervals.idx * 10 + 5))
        left = left.annotate(interval_matches=intervals.index(left.key))
        self.assertTrue(left.all(hl.case()
                                 .when(left.idx % 10 < 5, left.interval_matches.idx == left.idx // 10)
                                 .default(hl.is_missing(left.interval_matches))))
Example 32
def write_ldsc_hm3_snplist(info_threshold=0.9,
                           maf_threshold=0.01,
                           overwrite=False):
    # Filter variants
    ht = hl.read_table(get_variant_results_qc_path())
    # in autosomes
    ht = ht.filter(ht.locus.in_autosome())
    # no MHC
    ht = ht.filter(
        ~hl.parse_locus_interval('6:28477797-33448354').contains(ht.locus))
    # info > 0.9
    ht = ht.filter(ht.info > info_threshold)
    # SNP only
    ht = ht.filter(hl.is_snp(ht.alleles[0], ht.alleles[1]))
    # no multi-allelic sites
    loc_count = ht.group_by(ht.locus).aggregate(nloc=hl.agg.count())
    loc_count = loc_count.filter(loc_count.nloc > 1)
    multi_sites = loc_count.aggregate(hl.agg.collect_as_set(loc_count.locus),
                                      _localize=False)
    ht = ht.filter(~multi_sites.contains(ht.locus))

    # in HM3
    hm3_snps = hl.read_table(
        'gs://ukbb-ldsc-dev/ukb_hm3_snplist/hm3.r3.b37.auto_bi_af.ht')
    hm3_snps = hm3_snps.select()
    ht = ht.join(hm3_snps, 'right')
    # no strand ambiguity
    ht = ht.filter(~hl.is_strand_ambiguous(ht.alleles[0], ht.alleles[1]))

    ht = checkpoint_tmp(ht)

    def get_maf(af):
        return 0.5 - hl.abs(0.5 - af)

    # MAF > 1% in UKB & gnomad genome/exome (if defined) for each population
    for pop in POPS:
        snplist = ht.filter(
            hl.rbind(
                ht.freq[ht.freq.index(lambda x: x.pop == pop)], lambda y:
                (get_maf(y.af) > maf_threshold) &
                (hl.is_missing(y.gnomad_genomes_af) |
                 (get_maf(y.gnomad_genomes_af) > maf_threshold)) &
                (hl.is_missing(y.gnomad_exomes_af) |
                 (get_maf(y.gnomad_exomes_af) > maf_threshold))))
        snplist = snplist.select('rsid')
        snplist.write(get_hm3_snplist_path(pop), overwrite=overwrite)
Example 33
def filter_out_segdups(mt, genome_version="GRCh38"):

    if genome_version == "GRCh38":
        segdup_regions = hl.import_locus_intervals("gs://broad-dsp-spec-ops/scratch/weisburd/ref/GRCh38/GRCh38GenomicSuperDup.without_decoys.bed", reference_genome="GRCh38")
    else:
        raise ValueError(f"Invalid genome version: {genome_version}")

    return mt.filter_rows(hl.is_missing(segdup_regions[mt.locus]))
Example 34
def test_segfault():
    t = hl.utils.range_table(1)
    t2 = hl.utils.range_table(3)
    t = t.annotate(foo=[0])
    t2 = t2.annotate(foo=[0])
    joined = t.key_by('foo').join(t2.key_by('foo'))
    joined = joined.filter(hl.is_missing(joined.idx))
    assert joined.collect() == []
Example 36
    def recur_expr(expr, path):
        d = {}
        missingness = append_agg(hl.agg.count_where(hl.is_missing(expr)))
        d['type'] = lambda _: str(expr.dtype)
        d['missing'] = lambda results: (
            f'{results[missingness]} values ({pct(results[missingness] / results[count])})')

        t = expr.dtype

        if t in (hl.tint32, hl.tint64, hl.tfloat32, hl.tfloat64):
            stats = append_agg(hl.agg.stats(expr))
            if t in (hl.tint32, hl.tint64):
                d['minimum'] = lambda results: format(map_int(results[stats]['min']))
                d['maximum'] = lambda results: format(map_int(results[stats]['max']))
                d['sum'] = lambda results: format(map_int(results[stats]['sum']))
            else:
                d['minimum'] = lambda results: format(results[stats]['min'])
                d['maximum'] = lambda results: format(results[stats]['max'])
                d['sum'] = lambda results: format(results[stats]['sum'])
            d['mean'] = lambda results: format(results[stats]['mean'])
            d['stdev'] = lambda results: format(results[stats]['stdev'])
        elif t == hl.tbool:
            counter = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr)))
            d['counts'] = lambda results: format(results[counter])
        elif t == hl.tstr:
            size = append_agg(hl.agg.stats(hl.len(expr)))
            take = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.take(expr, 5)))
            d['minimum size'] = lambda results: format(map_int(results[size]['min']))
            d['maximum size'] = lambda results: format(map_int(results[size]['max']))
            d['mean size'] = lambda results: format(results[size]['mean'])
            d['sample values'] = lambda results: format(results[take])
        elif t == hl.tcall:
            ploidy_counts = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr.ploidy)))
            phased_counts = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr.phased)))
            n_hom_ref = append_agg(hl.agg.count_where(expr.is_hom_ref()))
            n_hom_var = append_agg(hl.agg.count_where(expr.is_hom_var()))
            n_het = append_agg(hl.agg.count_where(expr.is_het()))
            d['homozygous reference'] = lambda results: format(results[n_hom_ref])
            d['heterozygous'] = lambda results: format(results[n_het])
            d['homozygous variant'] = lambda results: format(results[n_hom_var])
            d['ploidy'] = lambda results: format(results[ploidy_counts])
            d['phased'] = lambda results: format(results[phased_counts])
        elif isinstance(t, hl.tlocus):
            contig_counts = append_agg(hl.agg.filter(hl.is_defined(expr), hl.agg.counter(expr.contig)))
            d['contig counts'] = lambda results: format(results[contig_counts])
        elif isinstance(t, (hl.tset, hl.tdict, hl.tarray)):
            size = append_agg(hl.agg.stats(hl.len(expr)))
            d['minimum size'] = lambda results: format(map_int(results[size]['min']))
            d['maximum size'] = lambda results: format(map_int(results[size]['max']))
            d['mean size'] = lambda results: format(results[size]['mean'])
        to_print.append((path, d))
        if isinstance(t, hl.ttuple):
            for i in range(len(expr)):
                recur_expr(expr[i], f'{path} / {i}')
        if isinstance(t, hl.tstruct):
            for k, v in expr.items():
                recur_expr(v, f'{path} / {repr(k)[1:-1]}')
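
The recursion above batches one aggregator per summary into a single pass. A hedged, freestanding sketch of the same idea for one numeric field:

import hail as hl

t = hl.utils.range_table(10)
t = t.annotate(x=hl.if_else(t.idx % 3 == 0, hl.null(hl.tint32), t.idx))
# Count missing values and compute numeric stats in one aggregation.
res = t.aggregate(hl.struct(missing=hl.agg.count_where(hl.is_missing(t.x)),
                            stats=hl.agg.stats(t.x)))
print(res.missing, res.stats.mean)  # 4 4.5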
Example n. 37
    def test_reference_genome_liftover(self):
        grch37 = hl.get_reference('GRCh37')
        grch38 = hl.get_reference('GRCh38')

        self.assertTrue(not grch37.has_liftover('GRCh38') and not grch38.has_liftover('GRCh37'))
        grch37.add_liftover(resource('grch37_to_grch38_chr20.over.chain.gz'), 'GRCh38')
        grch38.add_liftover(resource('grch38_to_grch37_chr20.over.chain.gz'), 'GRCh37')
        self.assertTrue(grch37.has_liftover('GRCh38') and grch38.has_liftover('GRCh37'))

        ds = hl.import_vcf(resource('sample.vcf'))
        t = ds.annotate_rows(liftover=hl.liftover(hl.liftover(ds.locus, 'GRCh38'), 'GRCh37')).rows()
        self.assertTrue(t.all(t.locus == t.liftover))

        null_locus = hl.null(hl.tlocus('GRCh38'))

        rows = [
            {'l37': hl.locus('20', 1, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60000, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 60001, 'GRCh37'), 'l38': hl.locus('chr20', 79360, 'GRCh38')},
            {'l37': hl.locus('20', 278686, 'GRCh37'), 'l38': hl.locus('chr20', 298045, 'GRCh38')},
            {'l37': hl.locus('20', 278687, 'GRCh37'), 'l38': hl.locus('chr20', 298046, 'GRCh38')},
            {'l37': hl.locus('20', 278688, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278689, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278690, 'GRCh37'), 'l38': null_locus},
            {'l37': hl.locus('20', 278691, 'GRCh37'), 'l38': hl.locus('chr20', 298047, 'GRCh38')},
            {'l37': hl.locus('20', 37007586, 'GRCh37'), 'l38': hl.locus('chr12', 32563117, 'GRCh38')},
            {'l37': hl.locus('20', 62965520, 'GRCh37'), 'l38': hl.locus('chr20', 64334167, 'GRCh38')},
            {'l37': hl.locus('20', 62965521, 'GRCh37'), 'l38': null_locus}
        ]
        schema = hl.tstruct(l37=hl.tlocus(grch37), l38=hl.tlocus(grch38))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.cond(hl.is_defined(t.l38),
                                      hl.liftover(t.l37, 'GRCh38') == t.l38,
                                      hl.is_missing(hl.liftover(t.l37, 'GRCh38')))))

        t = t.filter(hl.is_defined(t.l38))
        self.assertTrue(t.count() == 6)

        t = t.key_by('l38')
        t.count()
        self.assertTrue(list(t.key) == ['l38'])

        null_locus_interval = hl.null(hl.tinterval(hl.tlocus('GRCh38')))
        rows = [
            {'i37': hl.locus_interval('20', 1, 60000, True, False, 'GRCh37'), 'i38': null_locus_interval},
            {'i37': hl.locus_interval('20', 60001, 82456, True, True, 'GRCh37'),
             'i38': hl.locus_interval('chr20', 79360, 101815, True, True, 'GRCh38')}
        ]
        schema = hl.tstruct(i37=hl.tinterval(hl.tlocus(grch37)), i38=hl.tinterval(hl.tlocus(grch38)))
        t = hl.Table.parallelize(rows, schema)
        self.assertTrue(t.all(hl.liftover(t.i37, 'GRCh38') == t.i38))

        grch37.remove_liftover("GRCh38")
        grch38.remove_liftover("GRCh37")
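
Distilled from the test, the core liftover workflow is short. A hedged sketch, assuming a GRCh37-keyed table `t` with a `locus` field (the chain-file path is hypothetical):

import hail as hl

grch37 = hl.get_reference('GRCh37')
grch37.add_liftover('gs://my-bucket/grch37_to_grch38.over.chain.gz', 'GRCh38')
t = t.annotate(locus38=hl.liftover(t.locus, 'GRCh38'))
# Loci that fail to lift over come back missing, so drop them explicitly.
t = t.filter(hl.is_defined(t.locus38))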
Example n. 38
    def test_entry_join_missingness(self):
        mt1 = hl.utils.range_matrix_table(10, 10, n_partitions=4)
        mt1 = mt1.annotate_entries(x=mt1.row_idx + mt1.col_idx)

        mt2 = mt1.filter_cols(mt1.col_idx % 2 == 0)
        mt2 = mt2.filter_rows(mt2.row_idx % 2 == 0)
        mt_join = mt1.annotate_entries(x2=mt2[mt1.row_idx, mt1.col_idx].x * 10)
        mt_join_entries = mt_join.entries()

        kept = mt_join_entries.filter((mt_join_entries.row_idx % 2 == 0) & (mt_join_entries.col_idx % 2 == 0))
        removed = mt_join_entries.filter(~((mt_join_entries.row_idx % 2 == 0) & (mt_join_entries.col_idx % 2 == 0)))

        self.assertTrue(kept.all(hl.is_defined(kept.x2) & (kept.x2 == kept.x * 10)))
        self.assertTrue(removed.all(hl.is_missing(removed.x2)))
Example n. 39
    def test_field_groups(self):
        ds = self.get_vds()

        df = ds.annotate_rows(row_struct=ds.row).rows()
        self.assertTrue(df.all((df.info == df.row_struct.info) & (df.qual == df.row_struct.qual)))

        ds2 = ds.add_col_index()
        df = ds2.annotate_cols(col_struct=ds2.col).cols()
        self.assertTrue(df.all((df.col_idx == df.col_struct.col_idx)))

        df = ds.annotate_entries(entry_struct=ds.entry).entries()
        self.assertTrue(df.all(
            ((hl.is_missing(df.GT) |
              (df.GT == df.entry_struct.GT)) &
             (df.AD == df.entry_struct.AD))))
Example n. 40
    def test_computed_key_join_3(self):
        # duplicate row keys
        ds = self.get_vds()
        kt = hl.Table.parallelize(
            [{'culprit': 'InbreedingCoeff', 'foo': 'bar', 'value': 'IB'}],
            hl.tstruct(culprit=hl.tstr, foo=hl.tstr, value=hl.tstr),
            key=['culprit', 'foo'])
        ds = ds.annotate_rows(
            dsfoo='bar',
            info=ds.info.annotate(culprit=[ds.info.culprit, "foo"]))
        ds = ds.explode_rows(ds.info.culprit)
        ds = ds.annotate_rows(value=kt[ds.info.culprit, ds.dsfoo]['value'])
        rt = ds.rows()
        self.assertTrue(
            rt.all(hl.cond(
                rt.info.culprit == "InbreedingCoeff",
                rt['value'] == "IB",
                hl.is_missing(rt['value']))))
Example n. 41
File: qc.py Project: tpoterba/hail
def sample_qc(mt, name='sample_qc') -> MatrixTable:
    """Compute per-sample metrics useful for quality control.

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    Compute sample QC metrics and remove low-quality samples:

    >>> dataset = hl.sample_qc(dataset, name='sample_qc')
    >>> filtered_dataset = dataset.filter_cols((dataset.sample_qc.dp_stats.mean > 20) & (dataset.sample_qc.r_ti_tv > 1.5))

    Notes
    -----

    This method computes summary statistics per sample from a genetic matrix and stores
    the results as a new column-indexed struct field in the matrix, named based on the
    `name` parameter.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `call_rate` (``float64``) -- Fraction of calls non-missing.
    - `n_called` (``int64``) -- Number of non-missing calls.
    - `n_not_called` (``int64``) -- Number of missing calls.
    - `n_hom_ref` (``int64``) -- Number of homozygous reference calls.
    - `n_het` (``int64``) -- Number of heterozygous calls.
    - `n_hom_var` (``int64``) -- Number of homozygous alternate calls.
    - `n_non_ref` (``int64``) -- Sum of ``n_het`` and ``n_hom_var``.
    - `n_snp` (``int64``) -- Number of SNP alternate alleles.
    - `n_insertion` (``int64``) -- Number of insertion alternate alleles.
    - `n_deletion` (``int64``) -- Number of deletion alternate alleles.
    - `n_singleton` (``int64``) -- Number of private alleles.
    - `n_transition` (``int64``) -- Number of transition (A-G, C-T) alternate alleles.
    - `n_transversion` (``int64``) -- Number of transversion alternate alleles.
    - `n_star` (``int64``) -- Number of star (upstream deletion) alleles.
    - `r_ti_tv` (``float64``) -- Transition/Transversion ratio.
    - `r_het_hom_var` (``float64``) -- Het/HomVar call ratio.
    - `r_insertion_deletion` (``float64``) -- Insertion/Deletion allele ratio.

    Missing values ``NA`` may result from division by zero.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
        Dataset with a new column-indexed field `name`.
    """

    require_row_key_variant(mt, 'sample_qc')

    from hail.expr.functions import _num_allele_type, _allele_types

    allele_types = _allele_types[:]
    allele_types.extend(['Transition', 'Transversion'])
    allele_enum = {i: v for i, v in enumerate(allele_types)}
    allele_ints = {v: k for k, v in allele_enum.items()}

    def allele_type(ref, alt):
        return hl.bind(lambda at: hl.cond(at == allele_ints['SNP'],
                                          hl.cond(hl.is_transition(ref, alt),
                                                  allele_ints['Transition'],
                                                  allele_ints['Transversion']),
                                          at),
                       _num_allele_type(ref, alt))

    variant_ac = Env.get_uid()
    variant_atypes = Env.get_uid()
    mt = mt.annotate_rows(**{variant_ac: hl.agg.call_stats(mt.GT, mt.alleles).AC,
                             variant_atypes: mt.alleles[1:].map(lambda alt: allele_type(mt.alleles[0], alt))})

    exprs = {}

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    if has_field_of_type('DP', hl.tint32):
        exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'sample_qc': expect an entry field 'GT' of type 'call'")

    exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT']))
    exprs['n_hom_ref'] = hl.agg.count_where(mt['GT'].is_hom_ref())
    exprs['n_het'] = hl.agg.count_where(mt['GT'].is_het())
    exprs['n_singleton'] = hl.agg.sum(hl.sum(hl.range(0, mt['GT'].ploidy).map(lambda i: mt[variant_ac][mt['GT'][i]] == 1)))

    def get_allele_type(allele_idx):
        return hl.cond(allele_idx > 0, mt[variant_atypes][allele_idx - 1], hl.null(hl.tint32))

    exprs['allele_type_counts'] = hl.agg.explode(
        lambda elt: hl.agg.counter(elt),
        hl.range(0, mt['GT'].ploidy).map(lambda i: get_allele_type(mt['GT'][i])))

    mt = mt.annotate_cols(**{name: hl.struct(**exprs)})

    zero = hl.int64(0)

    select_exprs = {}
    if 'dp_stats' in exprs:
        select_exprs['dp_stats'] = mt[name].dp_stats
    if 'gq_stats' in exprs:
        select_exprs['gq_stats'] = mt[name].gq_stats

    select_exprs = {
        **select_exprs,
        'call_rate': hl.float64(mt[name].n_called) / (mt[name].n_called + mt[name].n_not_called),
        'n_called': mt[name].n_called,
        'n_not_called': mt[name].n_not_called,
        'n_hom_ref': mt[name].n_hom_ref,
        'n_het': mt[name].n_het,
        'n_hom_var': mt[name].n_called - mt[name].n_hom_ref - mt[name].n_het,
        'n_non_ref': mt[name].n_called - mt[name].n_hom_ref,
        'n_singleton': mt[name].n_singleton,
        'n_snp': mt[name].allele_type_counts.get(allele_ints["Transition"], zero)
                 + mt[name].allele_type_counts.get(allele_ints["Transversion"], zero),
        'n_insertion': mt[name].allele_type_counts.get(allele_ints["Insertion"], zero),
        'n_deletion': mt[name].allele_type_counts.get(allele_ints["Deletion"], zero),
        'n_transition': mt[name].allele_type_counts.get(allele_ints["Transition"], zero),
        'n_transversion': mt[name].allele_type_counts.get(allele_ints["Transversion"], zero),
        'n_star': mt[name].allele_type_counts.get(allele_ints["Star"], zero)
    }

    mt = mt.annotate_cols(**{name: mt[name].select(**select_exprs)})

    mt = mt.annotate_cols(**{name: mt[name].annotate(
        r_ti_tv=divide_null(hl.float64(mt[name].n_transition), mt[name].n_transversion),
        r_het_hom_var=divide_null(hl.float64(mt[name].n_het), mt[name].n_hom_var),
        r_insertion_deletion=divide_null(hl.float64(mt[name].n_insertion), mt[name].n_deletion)
    )})

    mt = mt.drop(variant_ac, variant_atypes)

    return mt
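
As the docstring warns, ratio fields such as `r_ti_tv` may be missing when the denominator is zero, so downstream filters should handle `NA` explicitly. A hedged sketch, assuming a MatrixTable `mt` with a `GT` entry field (the threshold is illustrative only):

import hail as hl

mt = hl.sample_qc(mt, name='sample_qc')
# Treat a missing Ti/Tv ratio (no transversions observed) as a failing sample.
mt = mt.filter_cols(hl.or_else(mt.sample_qc.r_ti_tv > 1.5, False))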
Example n. 42
File: qc.py Project: jigold/hail
def variant_qc(mt, name='variant_qc') -> MatrixTable:
    """Compute common variant statistics (quality control metrics).

    .. include:: ../_templates/req_tvariant.rst

    Examples
    --------

    >>> dataset_result = hl.variant_qc(dataset)

    Notes
    -----
    This method computes variant statistics from the genotype data, returning
    a new struct field `name` with the following metrics based on the fields
    present in the entry schema.

    If `mt` contains an entry field `DP` of type :py:data:`.tint32`, then the
    field `dp_stats` is computed. If `mt` contains an entry field `GQ` of type
    :py:data:`.tint32`, then the field `gq_stats` is computed. Both `dp_stats`
    and `gq_stats` are structs with four fields:

    - `mean` (``float64``) -- Mean value.
    - `stdev` (``float64``) -- Standard deviation (zero degrees of freedom).
    - `min` (``int32``) -- Minimum value.
    - `max` (``int32``) -- Maximum value.

    If the dataset does not contain an entry field `GT` of type
    :py:data:`.tcall`, then an error is raised. The following fields are always
    computed from `GT`:

    - `AF` (``array<float64>``) -- Calculated allele frequency, one element
      per allele, including the reference. Sums to one. Equivalent to
      `AC` / `AN`.
    - `AC` (``array<int32>``) -- Calculated allele count, one element per
      allele, including the reference. Sums to `AN`.
    - `AN` (``int32``) -- Total number of called alleles.
    - `homozygote_count` (``array<int32>``) -- Number of homozygotes per
      allele. One element per allele, including the reference.
    - `call_rate` (``float64``) -- Fraction of calls neither missing nor filtered.
      Equivalent to `n_called` / :meth:`.count_cols`.
    - `n_called` (``int64``) -- Number of samples with a defined `GT`.
    - `n_not_called` (``int64``) -- Number of samples with a missing `GT`.
    - `n_filtered` (``int64``) -- Number of filtered entries.
    - `n_het` (``int64``) -- Number of heterozygous samples.
    - `n_non_ref` (``int64``) -- Number of samples with at least one called
      non-reference allele.
    - `het_freq_hwe` (``float64``) -- Expected frequency of heterozygous
      samples under Hardy-Weinberg equilibrium. See
      :func:`.functions.hardy_weinberg_test` for details.
    - `p_value_hwe` (``float64``) -- p-value from test of Hardy-Weinberg equilibrium.
      See :func:`.functions.hardy_weinberg_test` for details.

    Warning
    -------
    `het_freq_hwe` and `p_value_hwe` are calculated as in
    :func:`.functions.hardy_weinberg_test`, with non-diploid calls
    (``ploidy != 2``) ignored in the counts. As this test is only
    statistically rigorous in the biallelic setting, :func:`.variant_qc`
    sets both fields to missing for multiallelic variants. Consider using
    :func:`~hail.methods.split_multi` to split multi-allelic variants beforehand.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Dataset.
    name : :obj:`str`
        Name for resulting field.

    Returns
    -------
    :class:`.MatrixTable`
    """
    require_row_key_variant(mt, 'variant_qc')

    bound_exprs = {}
    gq_dp_exprs = {}

    def has_field_of_type(name, dtype):
        return name in mt.entry and mt[name].dtype == dtype

    if has_field_of_type('DP', hl.tint32):
        gq_dp_exprs['dp_stats'] = hl.agg.stats(mt.DP).select('mean', 'stdev', 'min', 'max')

    if has_field_of_type('GQ', hl.tint32):
        gq_dp_exprs['gq_stats'] = hl.agg.stats(mt.GQ).select('mean', 'stdev', 'min', 'max')

    if not has_field_of_type('GT', hl.tcall):
        raise ValueError("'variant_qc': expect an entry field 'GT' of type 'call'")

    bound_exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT']))
    bound_exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT']))
    bound_exprs['n_filtered'] = mt.count_cols(_localize=False) - hl.agg.count()
    bound_exprs['call_stats'] = hl.agg.call_stats(mt.GT, mt.alleles)

    result = hl.rbind(hl.struct(**bound_exprs),
                      lambda e1: hl.rbind(
                          hl.case().when(hl.len(mt.alleles) == 2,
                                         hl.hardy_weinberg_test(e1.call_stats.homozygote_count[0],
                                                                e1.call_stats.AC[1] - 2 *
                                                                e1.call_stats.homozygote_count[1],
                                                                e1.call_stats.homozygote_count[1])
                                         ).or_missing(),
                          lambda hwe: hl.struct(**{
                              **gq_dp_exprs,
                              **e1.call_stats,
                              'call_rate': hl.float(e1.n_called) / (e1.n_called + e1.n_not_called + e1.n_filtered),
                              'n_called': e1.n_called,
                              'n_not_called': e1.n_not_called,
                              'n_filtered': e1.n_filtered,
                              'n_het': e1.n_called - hl.sum(e1.call_stats.homozygote_count),
                              'n_non_ref': e1.n_called - e1.call_stats.homozygote_count[0],
                              'het_freq_hwe': hwe.het_freq_hwe,
                              'p_value_hwe': hwe.p_value})))

    return mt.annotate_rows(**{name: result})
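
A hedged usage sketch for the metrics above, assuming a MatrixTable `mt` with a `GT` entry field; the thresholds are illustrative, and since `p_value_hwe` is missing for multiallelic variants, `hl.or_else` decides their fate explicitly:

import hail as hl

mt = hl.variant_qc(mt, name='variant_qc')
mt = mt.filter_rows((mt.variant_qc.call_rate > 0.95) &
                    hl.or_else(mt.variant_qc.p_value_hwe > 1e-6, False))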
Example n. 43
def combine(ts):
    def merge_alleles(alleles):
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
                   .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref:
            hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r:
                        hl.array([ref]).extend(
                            al[1:].map(
                                lambda a:
                                hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at:
                                    hl.cond(
                                        (_allele_ints['SNP'] == at) |
                                        (_allele_ints['Insertion'] == at) |
                                        (_allele_ints['Deletion'] == at) |
                                        (_allele_ints['MNP'] == at) |
                                        (_allele_ints['Complex'] == at),
                                        a + ref[hl.len(r):],
                                        a)))))),
                lambda lal:
                hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    if (ts.row.dtype, ts.globals.dtype) not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles:
                hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i:
                            hl.cond(hl.is_missing(row.data[i].__entries),
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                      .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[(ts.row.dtype, ts.globals.dtype)] = f
    merge_function = _merge_function_map[(ts.row.dtype, ts.globals.dtype)]
    ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
                                           TopLevelReference('row'),
                                           TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))
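
To make `merge_alleles` concrete, here is a pure-Python restatement of its core rule on one toy case (a simplification: the Hail version only pads SNP/insertion/deletion/MNP/complex alternates, and builds the global list from an unordered set):

# Two inputs saw the same site as A>G and AT>A.
local_alleles = [["A", "G"], ["AT", "A"]]
ref = max((a[0] for a in local_alleles), key=len)  # the longest reference wins: "AT"
# Each alternate gains the reference suffix its own ref lacked: "G" -> "GT".
extended = [[ref] + [alt + ref[len(a[0]):] for alt in a[1:]] for a in local_alleles]
globl = [ref] + sorted({alt for al in extended for alt in al} - {ref})
print(globl)  # ['AT', 'A', 'GT']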
Example n. 44
def transform_one(mt, vardp_outlier=100_000) -> Table:
    """transforms a gvcf into a form suitable for combining

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_vcfs` with `array_elements_required=False`.

    There is a strong assumption that this function will be called on a matrix
    table with one column.
    """
    mt = localize(mt)
    if mt.row.dtype not in _transform_rows_function_map:
        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e:
                        hl.struct(
                            DP=e.DP,
                            END=row.info.END,
                            GQ=e.GQ,
                            LA=hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0)),
                            LAD=hl.cond(has_non_ref, e.AD[:-1], e.AD),
                            LGT=e.GT,
                            LPGT=e.PGT,
                            LPL=hl.cond(has_non_ref,
                                        hl.cond(alleles_len > 2,
                                                e.PL[:-alleles_len],
                                                hl.null(e.PL.dtype)),
                                        hl.cond(alleles_len > 1,
                                                e.PL,
                                                hl.null(e.PL.dtype))),
                            MIN_DP=e.MIN_DP,
                            PID=e.PID,
                            RGQ=hl.cond(
                                has_non_ref,
                                e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                                hl.null(e.PL.dtype.element_type)),
                            SB=e.SB,
                            gvcf_info=hl.case()
                                .when(hl.is_missing(row.info.END),
                                      hl.struct(
                                          ClippingRankSum=row.info.ClippingRankSum,
                                          BaseQRankSum=row.info.BaseQRankSum,
                                          MQ=row.info.MQ,
                                          MQRankSum=row.info.MQRankSum,
                                          MQ_DP=row.info.MQ_DP,
                                          QUALapprox=row.info.QUALapprox,
                                          RAW_MQ=row.info.RAW_MQ,
                                          ReadPosRankSum=row.info.ReadPosRankSum,
                                          VarDP=hl.cond(row.info.VarDP > vardp_outlier,
                                                        row.info.DP, row.info.VarDP)))
                                .or_missing()
                        ))),
            ),
            mt.row.dtype)
        _transform_rows_function_map[mt.row.dtype] = f
    transform_row = _transform_rows_function_map[mt.row.dtype]
    return Table(TableMapRows(mt._tir, Apply(transform_row._name, TopLevelReference('row'))))
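
A hedged sketch of how `transform_one` slots in, following its docstring's precondition of a single-column gVCF import (the path is hypothetical):

import hail as hl

mt = hl.import_vcf('gs://my-bucket/sample1.g.vcf.gz',
                   reference_genome='GRCh38',
                   force_bgz=True,
                   array_elements_required=False)
ts = transform_one(mt)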