def compare_row_counts(ht1: hl.Table, ht2: hl.Table) -> bool:
    """
    Count the rows of two Tables, log both counts, and report whether they match.

    :param ht1: Table whose row count is reported as the "left" table
    :param ht2: Table whose row count is reported as the "right" table
    :return: True when both Tables have the same number of rows, False otherwise
    """
    # Tuple elements are evaluated left to right, so ht1 is counted before ht2.
    n_left, n_right = ht1.count(), ht2.count()
    logger.info(f"{n_left} rows in left table; {n_right} rows in right table")
    same_count = n_left == n_right
    return same_count
def pc_project(
    mt: hl.MatrixTable,
    loadings_ht: hl.Table,
    loading_location: str = "loadings",
    af_location: str = "pca_af",
) -> hl.Table:
    """
    Compute PC scores for the samples of `mt` from precomputed PCA loadings.

    Each genotype alt-allele count is centred by 2p and divided by
    sqrt(n * 2p(1 - p)), where p is the PCA allele frequency of the row and n is
    the total number of rows in `loadings_ht` (not only the rows retained in
    `mt`). The normalized genotypes are multiplied by the row loadings and summed
    per sample.

    :param mt: MatrixTable holding the samples to project; its `GT` entry field is used
    :param loadings_ht: Table indexable by the row key of `mt`, containing the PCA
        loadings and the allele frequencies used for the PCA
    :param loading_location: Name of the loadings field in `loadings_ht`
    :param af_location: Name of the allele-frequency field in `loadings_ht`
    :return: Table of the columns of `mt` with a single row field `scores`
    """
    n_variants = loadings_ht.count()

    # A single join expression supplies both row annotations.
    loadings_row = loadings_ht[mt.row_key]
    mt = mt.annotate_rows(
        pca_loadings=loadings_row[loading_location],
        pca_af=loadings_row[af_location],
    )

    # Drop rows without loadings/AF, and rows with p == 0 or p == 1, for which
    # the scaling factor 2p(1 - p) would be zero.
    has_inner_af = (mt.pca_af > 0) & (mt.pca_af < 1)
    mt = mt.filter_rows(
        hl.is_defined(mt.pca_loadings) & hl.is_defined(mt.pca_af) & has_inner_af
    )

    af = mt.pca_af
    scale = hl.sqrt(n_variants * 2 * af * (1 - af))
    gt_norm = (mt.GT.n_alt_alleles() - 2 * af) / scale

    mt = mt.annotate_cols(scores=hl.agg.array_sum(mt.pca_loadings * gt_norm))
    return mt.cols().select("scores")
def pc_hwe_gt(
    mt: hl.MatrixTable,
    loadings_ht: hl.Table,
    loading_location: str = "loadings",
    af_location: str = "pca_af",
) -> hl.MatrixTable:
    """
    Attach normalized genotypes (`GTN`) to `mt` for projection onto precomputed PCs.

    Rows are annotated with `pca_loadings` and `pca_af` taken from `loadings_ht`
    and restricted to rows where both are defined and 0 < `pca_af` < 1. The entry
    field `GTN` is (n_alt_alleles - 2p) / sqrt(n * 2p(1 - p)), with p = `pca_af`
    and n the total number of rows in `loadings_ht`.

    :param mt: MatrixTable whose `GT` entries are normalized
    :param loadings_ht: Table indexable by the row key of `mt`, containing the PCA
        loadings and allele frequencies
    :param loading_location: Name of the loadings field in `loadings_ht`
    :param af_location: Name of the allele-frequency field in `loadings_ht`
    :return: Row-filtered MatrixTable with row fields `pca_loadings`, `pca_af` and
        entry field `GTN`
    """
    n_variants = loadings_ht.count()

    loadings_row = loadings_ht[mt.row_key]
    mt = mt.annotate_rows(
        pca_loadings=loadings_row[loading_location],
        pca_af=loadings_row[af_location],
    )

    af = mt.pca_af
    keep_row = (
        hl.is_defined(mt.pca_loadings) & hl.is_defined(af) & (af > 0) & (af < 1)
    )
    mt = mt.filter_rows(keep_row)

    # Attach normalized entries to be used in projection
    centered = mt.GT.n_alt_alleles() - 2 * mt.pca_af
    scale = hl.sqrt(n_variants * 2 * mt.pca_af * (1 - mt.pca_af))
    return mt.annotate_entries(GTN=centered / scale)
def compute_phase(
    variants_ht: hl.Table,
    least_consequence: str = LEAST_CONSEQUENCE,
    max_freq: float = MAX_FREQ,
) -> hl.Table:
    """
    Look up phase information for the variant pairs in `variants_ht`.

    Pairs present in the gnomAD exomes phased variant-pair table are taken from it,
    with `phase_info` restricted to its `gt_counts` and `em` fields. If some pairs
    are absent from that table, they are passed to `annotate_unphased_pairs` and the
    result is unioned (with `unify=True`) onto the phased pairs.

    :param variants_ht: Table of variant pairs to phase; it must be joinable with the
        gnomAD phased variant-pair table (used in `semi_join`/`anti_join`)
    :param least_consequence: Passed through to `annotate_unphased_pairs`
    :param max_freq: Passed through to `annotate_unphased_pairs`
    :return: Table of phase information for the requested pairs
    """
    n_variant_pairs = variants_ht.count()
    logger.info(f"Looking up phase for {n_variant_pairs} variant pair(s).")

    # Join with gnomad phased variants
    vp_ht = hl.read_table(phased_vp_count_ht_path('exomes'))
    phased_ht = vp_ht.semi_join(variants_ht)
    n_phased = phased_ht.count()
    phased_ht = explode_phase_info(phased_ht)  # explodes phase_info by pop

    # Target ~10,000 pairs per partition, but never request 0 partitions
    # (ceil(0 / 10000) == 0 when `variants_ht` is empty).
    n_partitions = max(1, ceil(n_variant_pairs / 10000))
    phased_ht = phased_ht.transmute(
        phase_info=phased_ht.phase_info.select('gt_counts', 'em')
    ).repartition(n_partitions, shuffle=True)
    phased_ht = phased_ht.persist()

    # If not all pairs had at least one carrier of both, then compute phase estimate from single variants
    logger.info(
        f"{n_phased}/{n_variant_pairs} variant pair(s) found with carriers of both in gnomAD."
    )

    if n_phased < n_variant_pairs:
        unphased_ht = variants_ht.anti_join(vp_ht)
        # NOTE(review): the total pair count is passed here, not the number of
        # unphased pairs (n_variant_pairs - n_phased); confirm which one
        # annotate_unphased_pairs expects.
        unphased_ht = annotate_unphased_pairs(
            unphased_ht, n_variant_pairs, least_consequence, max_freq
        )
        phased_ht = phased_ht.union(unphased_ht, unify=True)

    return phased_ht
def generic_field_check(
    ht: hl.Table,
    check_description: str,
    display_fields: hl.expr.StructExpression,
    cond_expr: hl.expr.BooleanExpression = None,
    verbose: bool = False,
    show_percent_sites: bool = False,
    n_fail: Optional[int] = None,
    ht_count: Optional[int] = None,
) -> None:
    """
    Check a generic logical condition involving annotations in a Hail Table and report the result.

    The number of failing rows is either supplied as `n_fail` or obtained by
    counting the rows that match `cond_expr`. That number is logged, along with
    the percentage of failing rows if `show_percent_sites` is True. The check
    passes when the number of failing rows is 0; otherwise it fails.

    .. note::

        `cond_expr` and `check_description` are opposites and should never be the same.
        E.g., if `cond_expr` filters for instances where the raw AC is less than the
        adj AC, then it selects sites that fail the desired condition
        (`check_description`) of having a raw AC greater than or equal to the adj AC.

    :param ht: Table containing annotations to be checked.
    :param check_description: String describing the condition being checked; it is
        displayed in the log summary message.
    :param display_fields: StructExpression containing annotations to be displayed in
        case of failure (for troubleshooting purposes); these fields are also
        displayed if `verbose` is True.
    :param cond_expr: Optional logical expression referring to annotations in `ht`
        that selects the failing rows.
    :param verbose: If True, show top values of annotations being checked, including
        checks that pass; if False, show only top values of annotations that fail checks.
    :param show_percent_sites: Show percentage of sites that fail checks. Default is False.
    :param n_fail: Optional number of sites that fail the conditional checks
        (previously computed). If not supplied, `cond_expr` is used to filter the
        Table and obtain the count of sites that fail the checks.
    :param ht_count: Optional number of sites within the Table (previously computed).
        If not supplied and `show_percent_sites` is True, the Table is counted.
    :raises ValueError: If neither or both of `n_fail` and `cond_expr` are supplied.
    :return: None
    """
    # Explicit `is None` tests are required: a Hail expression cannot be converted
    # to a Python bool, and a precomputed `n_fail` of 0 is a legitimate value.
    if (n_fail is None) == (cond_expr is None):
        raise ValueError("One and only one of n_fail or cond_expr must be defined!")

    if cond_expr is not None:
        n_fail = ht.filter(cond_expr).count()

    if show_percent_sites and (ht_count is None):
        ht_count = ht.count()

    if n_fail > 0:
        logger.info("Found %d sites that fail %s check:", n_fail, check_description)
        if show_percent_sites:
            logger.info(
                "Percentage of sites that fail: %.2f %%", 100 * (n_fail / ht_count)
            )
        # NOTE(review): this shows the top rows of the whole Table, not only the
        # failing rows.
        ht.select(**display_fields).show()
    else:
        logger.info("PASSED %s check", check_description)
        if verbose:
            ht.select(**display_fields).show()
def generic_field_check(
    ht: hl.Table,
    cond_expr: hl.expr.BooleanExpression,
    check_description: str,
    display_fields: List[str],
    verbose: bool,
    show_percent_sites: bool = False,
) -> None:
    """
    Count the rows of `ht` matching `cond_expr`, log whether the check passed, and show sample rows.

    Rows matching `cond_expr` are the ones that FAIL the desired condition named by
    `check_description`. If none match, the check passes; otherwise the failing
    count is logged, optionally as a fraction of all rows, and the failing rows are
    shown.

    .. note::

        `cond_expr` and `check_description` are opposites and should never be the same.
        E.g., a `cond_expr` selecting sites where the raw AC is less than the adj AC
        checks the desired condition (`check_description`) that the raw AC is greater
        than or equal to the adj AC.

    :param ht: Table containing annotations to be checked.
    :param cond_expr: Logical expression over `ht` that selects failing rows.
    :param check_description: Text naming the checked condition, used in the log messages.
    :param display_fields: Names of annotations to show. They are resolved against the
        flattened Table, so nested fields use their flattened names. Failing rows are
        shown together with `locus` and `alleles`. On a passing check the Table's
        rows are shown only when `verbose` is True.
    :param verbose: If True, also show the annotations when the check passes.
    :param show_percent_sites: If True, log the fraction of rows that fail. Default is False.
    :return: None
    """
    failing_ht = ht.filter(cond_expr)
    n_fail = failing_ht.count()

    if n_fail == 0:
        logger.info("PASSED %s check", check_description)
        if verbose:
            ht.flatten().select(*display_fields).show()
        return

    logger.info("Found %d sites that fail %s check:", n_fail, check_description)
    if show_percent_sites:
        logger.info("Percentage of sites that fail: %f", n_fail / ht.count())
    failing_ht.flatten().select("locus", "alleles", *display_fields).show()
def pc_project(
    # reference: https://github.com/macarthur-lab/gnomad_hail/blob/master/utils/generic.py#L131
    mt: hl.MatrixTable,
    loadings_ht: hl.Table,
    loading_location: str = "loadings",
    af_location: str = "pca_af",
) -> hl.Table:
    """
    Project the samples of `mt` onto precomputed principal components.

    Genotypes are normalized as (n_alt_alleles - 2p) / sqrt(n * 2p(1 - p)), where p
    is the row's PCA allele frequency and n is the number of rows in `loadings_ht`.
    Each sample's score vector is the sum over rows of loadings * normalized genotype.

    :param mt: MatrixTable holding the samples to project; its `GT` entry field is used
    :param loadings_ht: Table indexable by the row key of `mt`, containing loadings and
        allele frequencies
    :param loading_location: Name of the loadings field in `loadings_ht`
    :param af_location: Name of the allele-frequency field in `loadings_ht`
    :return: Table of the columns of `mt` with a single row field `scores`
    """
    n_variants = loadings_ht.count()

    looked_up = loadings_ht[mt.row_key]
    mt = mt.annotate_rows(
        pca_loadings=looked_up[loading_location],
        pca_af=looked_up[af_location],
    )

    p = mt.pca_af
    usable = hl.is_defined(mt.pca_loadings) & hl.is_defined(p) & (p > 0) & (p < 1)
    mt = mt.filter_rows(usable)

    p = mt.pca_af
    normalized_gt = (mt.GT.n_alt_alleles() - 2 * p) / hl.sqrt(
        n_variants * 2 * p * (1 - p)
    )
    weighted = mt.pca_loadings * normalized_gt

    mt = mt.annotate_cols(scores=hl.agg.array_sum(weighted))
    return mt.cols().select('scores')
def get_related_samples_to_drop(rank_table: hl.Table, relatedness_ht: hl.Table) -> hl.Table:
    """
    Choose related samples to remove so that as many unrelated samples as possible are kept.

    Pairs with kinship above the threshold form a graph. Hail's
    `maximal_independent_set` (with `keep=False`) then returns the samples to remove.
    Ties are broken by the `rank` looked up in `rank_table`; samples without a rank
    are treated as ranked one past the number of rows in `rank_table`.

    :param Table rank_table: Table with ranking annotations across exomes and genomes,
        computed via make_rank_file(); indexed by the `i`/`j` sample values
    :param Table relatedness_ht: Table with kinship coefficient annotations (`kin`,
        `i`, `j`) computed via pc_relate()
    :return: Table of samples to prune, keyed by ('data_type', 's')
    :rtype: Table
    """
    # Kinship threshold 0.08838835 ~= 2 ** -3.5, logged below as
    # "at least 2nd-degree relatedness".
    related_pairs = relatedness_ht.filter(relatedness_ht.kin > 0.08838835).select('i', 'j')

    # Distinct samples appearing on either side of a related pair.
    related_sample_set = related_pairs.aggregate(
        hl.agg.explode(hl.agg.collect_as_set, [related_pairs.i, related_pairs.j]),
        _localize=False,
    )
    n_related_samples = hl.eval(hl.len(related_sample_set))
    logger.info(
        '{} samples with at least 2nd-degree relatedness found in callset'.format(
            n_related_samples
        )
    )

    max_rank = rank_table.count()
    unranked_value = max_rank + 1

    def _with_rank(sample):
        """Wrap a sample id together with its rank from `rank_table`."""
        return hl.struct(id=sample, rank=rank_table[sample].rank)

    related_pairs = related_pairs.annotate(
        id1_rank=_with_rank(related_pairs.i),
        id2_rank=_with_rank(related_pairs.j),
    ).select('id1_rank', 'id2_rank')

    def _rank_tie_breaker(left, right):
        """Compare two nodes by rank, substituting `unranked_value` for missing ranks."""
        return hl.or_else(left.rank, unranked_value) - hl.or_else(right.rank, unranked_value)

    # Define maximal independent set, using rank list
    dropped = hl.maximal_independent_set(
        related_pairs.id1_rank,
        related_pairs.id2_rank,
        keep=False,
        tie_breaker=_rank_tie_breaker,
    )
    return dropped.select(**dropped.node.id).key_by('data_type', 's')
def generate_sib_stats_expr(
    mt: hl.MatrixTable,
    sib_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    strata: Optional[Dict[str, hl.expr.BooleanExpression]] = None,
    is_female: Optional[hl.expr.BooleanExpression] = None,
) -> hl.expr.StructExpression:
    """
    Generate a row-wise expression counting alternate alleles shared between sibling pairs.

    The sibling sharing counts can be stratified with additional filters using `strata`.

    .. note::

        This function expects that `mt` has either been split or filtered to
        biallelic variants only. If a sample appears in multiple sibling pairs,
        only one of its pairs (the one with the smallest index in `sib_ht`) is
        counted.

    :param mt: Input matrix table
    :param sib_ht: Table defining sibling pairs, with one sample in a column (`i_col`)
        and the second in another column (`j_col`); both columns must have an `s` field
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param strata: Dict mapping a stratum name to an additional filter expression used
        when computing shared sibling variant counts. Defaults to ``{"raw": True}``.
    :param is_female: An optional column in mt giving the sample sex. If not given,
        counts are only computed for autosomes.
    :return: Row-wise StructExpression with, for each stratum `name`, the fields
        `n_sib_shared_variants_{name}` and `ac_sibs_{name}`
    """
    # Avoid a shared mutable default argument; same default as before.
    if strata is None:
        strata = {"raw": True}

    def _get_alt_count(locus, gt, is_female):
        """Calculate alt allele count with sex info if present."""
        if is_female is None:
            return hl.or_missing(locus.in_autosome(), gt.n_alt_alleles())
        return (
            hl.case()
            .when(locus.in_autosome_or_par(), gt.n_alt_alleles())
            # Outside the PARs, males carry at most one copy of X/Y alleles.
            .when(
                ~is_female & (locus.in_x_nonpar() | locus.in_y_nonpar()),
                hl.min(1, gt.n_alt_alleles()),
            )
            # Females are diploid on X outside the PARs. Previously these loci fell
            # through to the default and were always counted as 0.
            .when(is_female & locus.in_x_nonpar(), gt.n_alt_alleles())
            .when(is_female & locus.in_y_nonpar(), 0)
            .default(0)
        )

    if is_female is None:
        logger.warning(
            "Since no sex expression was given to generate_sib_stats_expr, only variants in autosomes will be counted."
        )

    # If a sample is in sib_ht more than one time, keep only one of the sibling pairs
    # First filter to only samples found in mt to keep as many pairs as possible
    s_to_keep = mt.aggregate_cols(hl.agg.collect_as_set(mt.s), _localize=False)
    sib_ht = sib_ht.filter(
        s_to_keep.contains(sib_ht[i_col].s) & s_to_keep.contains(sib_ht[j_col].s)
    )
    sib_ht = sib_ht.add_index("sib_idx")
    sib_ht = sib_ht.annotate(sibs=[sib_ht[i_col].s, sib_ht[j_col].s])
    sib_ht = sib_ht.explode("sibs")
    # For every sample, keep only the smallest pair index it belongs to.
    sib_ht = sib_ht.group_by("sibs").aggregate(
        sib_idx=(hl.agg.take(sib_ht.sib_idx, 1, ordering=sib_ht.sib_idx)[0])
    )
    sib_ht = sib_ht.group_by(sib_ht.sib_idx).aggregate(sibs=hl.agg.collect(sib_ht.sibs))
    # Pairs that lost a member to an earlier pair are dropped.
    sib_ht = sib_ht.filter(hl.len(sib_ht.sibs) == 2).persist()

    logger.info(
        "Generating sibling variant sharing counts using %d pairs.", sib_ht.count()
    )
    # Per-column lookup: the kept pair index of each sample in mt (missing if none).
    sib_ht = sib_ht.explode("sibs").key_by("sibs")[mt.s]

    # Create sibling sharing counters
    sib_stats = hl.struct(
        **{
            f"n_sib_shared_variants_{name}": hl.sum(
                hl.agg.filter(
                    expr,
                    hl.agg.group_by(
                        sib_ht.sib_idx,
                        # Only count a pair when both siblings have a called GT;
                        # the shared count is the smaller of the two alt counts.
                        hl.or_missing(
                            hl.agg.sum(hl.is_defined(mt.GT)) == 2,
                            hl.agg.min(_get_alt_count(mt.locus, mt.GT, is_female)),
                        ),
                    ),
                ).values()
            )
            for name, expr in strata.items()
        }
    )

    sib_stats = sib_stats.annotate(
        **{
            f"ac_sibs_{name}": hl.agg.filter(
                expr & hl.is_defined(sib_ht.sib_idx),
                hl.agg.sum(mt.GT.n_alt_alleles()),
            )
            for name, expr in strata.items()
        }
    )

    return sib_stats