def generate_allele_data(ht: hl.Table) -> hl.Table:
    """
    Returns bi-allelic sites HT with the following annotations:
     - allele_data (nonsplit_alleles, has_star, variant_type, and n_alt_alleles)

    :param Table ht: Full unsplit HT
    :return: Table with allele data annotations
    :rtype: Table
    """
    ht = ht.select()
    allele_data = hl.struct(
        nonsplit_alleles=ht.alleles,
        has_star=hl.any(lambda a: a == "*", ht.alleles),
    )
    ht = ht.annotate(
        allele_data=allele_data.annotate(**add_variant_type(ht.alleles))
    )

    ht = hl.split_multi_hts(ht)
    ht = ht.filter(hl.len(ht.alleles) > 1)
    allele_type = (
        hl.case()
        .when(hl.is_snp(ht.alleles[0], ht.alleles[1]), "snv")
        .when(hl.is_insertion(ht.alleles[0], ht.alleles[1]), "ins")
        .when(hl.is_deletion(ht.alleles[0], ht.alleles[1]), "del")
        .default("complex")
    )
    ht = ht.annotate(
        allele_data=ht.allele_data.annotate(
            allele_type=allele_type,
            was_mixed=ht.allele_data.variant_type == "mixed",
        )
    )
    return ht
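# Usage sketch for generate_allele_data (the input path is a hypothetical
# placeholder; `add_variant_type` is assumed to be imported in this module):
#
# ht = hl.read_table("gs://my-bucket/raw_unsplit_sites.ht")  # unsplit, keyed by locus/alleles
# allele_ht = generate_allele_data(ht)
# allele_ht.allele_data.show()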
def prepare_exomes(
    exome_ht: hl.Table,
    groupings: List,
    dataset: str,
    impose_high_af_cutoff_upfront: bool = True,
) -> hl.Table:
    # Manipulate VEP annotations and explode by them
    exome_ht = add_most_severe_csq_to_tc_within_ht(exome_ht)
    exome_ht = exome_ht.transmute(
        transcript_consequences=exome_ht.vep.transcript_consequences
    )
    exome_ht = exome_ht.explode(exome_ht.transcript_consequences)

    # Annotate variants with grouping variables.
    # This function needs to be adapted
    exome_ht, grouping = annotate_constraint_groupings(exome_ht, groupings)
    exome_ht = exome_ht.select(
        'context', 'ref', 'alt', 'methylation_level', 'freq', 'pass_filters',
        *grouping,
    )

    # Filter by allele count
    # Likely to need to adapt this function as well
    af_cutoff = 0.001

    # `dataset` names the entry in freq_index_dict whose frequencies are used for filtering
    freq_index = exome_ht.freq_index_dict.collect()[0][dataset]

    def keep_criteria(ht):
        crit = (ht.freq[freq_index].AC > 0) & ht.pass_filters & (ht.coverage > 0)
        if impose_high_af_cutoff_upfront:
            crit &= (ht.freq[freq_index].AF <= af_cutoff)
        return crit

    exome_ht = exome_ht.filter(keep_criteria(exome_ht))
    return exome_ht
def filter_kin_ht(
    ht: hl.Table,
    out_summary: io.TextIOWrapper,
    first_degree_pi_hat: float = 0.40,
    grandparent_pi_hat: float = 0.20,
    grandparent_ibd1: float = 0.25,
    grandparent_ibd2: float = 0.15,
) -> hl.Table:
    """
    Filter the kinship table to relationships of grandparents and above.

    :param ht: hl.Table
    :param out_summary: Summary file with summary statistics and notes
    :param first_degree_pi_hat: Minimum pi_hat threshold to use to filter the kinship table to first degree relatives
    :param grandparent_pi_hat: Minimum pi_hat threshold to use to filter the kinship table to grandparents
    :param grandparent_ibd1: Minimum IBD1 threshold to use to filter the kinship table to grandparents
    :param grandparent_ibd2: Maximum IBD2 threshold to use to filter the kinship table to grandparents
    :return: Table containing only relationships of grandparents and above
    """
    # Filter to anything above the relationship of a grandparent
    ht = ht.filter(
        (ht.pi_hat > first_degree_pi_hat)
        | (
            (ht.pi_hat > grandparent_pi_hat)
            & (ht.ibd1 > grandparent_ibd1)
            & (ht.ibd2 < grandparent_ibd2)
        )
    )
    ht = ht.annotate(pair=hl.sorted([ht.i, ht.j]))

    out_summary.write(
        f"NOTE: kinship table was filtered to:\n"
        f"(kin > {first_degree_pi_hat}) or "
        f"(kin > {grandparent_pi_hat} and IBD1 > {grandparent_ibd1} and IBD2 < {grandparent_ibd2})\n"
    )
    out_summary.write("relationships not meeting these criteria were not evaluated\n\n")

    return ht
def generate_sib_stats(
    mt: hl.MatrixTable,
    relatedness_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    relationship_col: str = "relationship",
    autosomes_only: bool = True,
    bi_allelic_only: bool = True,
) -> hl.Table:
    """
    This is meant as a default wrapper for `generate_sib_stats_expr`. It returns a hail table with counts of variants
    shared by pairs of siblings in `relatedness_ht`.

    This function takes a hail Table with a row for each pair of individuals i,j in the data that are related
    (it's OK to have unrelated samples too). The `relationship_col` should be a column specifying the relationship
    between each pair of samples as defined by the constants in `gnomad.utils.relatedness`. This relationship_col
    will be used to filter to only pairs of samples that are annotated as `SIBLINGS`.

    .. note::

        By default this pipeline function will filter `mt` to only autosomes and bi-allelic sites.

    :param mt: Input Matrix table
    :param relatedness_ht: Input relationship table
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param relationship_col: Column containing the relationship for the sample pair as defined in this module constants.
    :param autosomes_only: If set, only autosomal intervals are used.
    :param bi_allelic_only: If set, only bi-allelic sites are used for the computation
    :return: A Table with the sibling shared variant counts
    """
    if autosomes_only:
        mt = filter_to_autosomes(mt)
    if bi_allelic_only:
        mt = mt.filter_rows(bi_allelic_expr(mt))

    sib_ht = relatedness_ht.filter(relatedness_ht[relationship_col] == SIBLINGS)
    s_to_keep = sib_ht.aggregate(
        hl.agg.explode(
            lambda s: hl.agg.collect_as_set(s), [sib_ht[i_col].s, sib_ht[j_col].s]
        ),
        _localize=False,
    )
    mt = mt.filter_cols(s_to_keep.contains(mt.s))
    if "adj" not in mt.entry:
        mt = annotate_adj(mt)

    sib_stats_ht = mt.select_rows(
        **generate_sib_stats_expr(
            mt,
            sib_ht,
            i_col=i_col,
            j_col=j_col,
            strata={"raw": True, "adj": mt.adj},
        )
    ).rows()

    return sib_stats_ht
def default_generate_sib_stats(
    mt: hl.MatrixTable,
    relatedness_ht: hl.Table,
    sex_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    relationship_col: str = "relationship",
) -> hl.Table:
    """
    This is meant as a default wrapper for `generate_sib_stats_expr`. It returns a hail table with counts of variants
    shared by pairs of siblings in `relatedness_ht`.

    This function takes a hail Table with a row for each pair of individuals i,j in the data that are related
    (it's OK to have unrelated samples too). The `relationship_col` should be a column specifying the relationship
    between each pair of samples as defined by the constants in `gnomad.utils.relatedness`. This relationship_col
    will be used to filter to only pairs of samples that are annotated as `SIBLINGS`.

    :param mt: Input Matrix table
    :param relatedness_ht: Input relationship table
    :param sex_ht: A Table containing sex information for the samples
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param relationship_col: Column containing the relationship for the sample pair as defined in this module constants.
    :return: A Table with the sibling shared variant counts
    """
    sex_ht = sex_ht.annotate(
        is_female=hl.case()
        .when(sex_ht.sex_karyotype == "XX", True)
        .when(sex_ht.sex_karyotype == "XY", False)
        .or_missing()
    )

    # TODO: Change to use SIBLINGS constant when relatedness PR goes in
    sib_ht = relatedness_ht.filter(relatedness_ht[relationship_col] == "Siblings")
    s_to_keep = sib_ht.aggregate(
        hl.agg.explode(
            lambda s: hl.agg.collect_as_set(s), [sib_ht[i_col].s, sib_ht[j_col].s]
        ),
        _localize=False,
    )
    mt = mt.filter_cols(s_to_keep.contains(mt.s))
    mt = annotate_adj(mt)

    mt = mt.annotate_cols(is_female=sex_ht[mt.s].is_female)

    sib_stats_ht = mt.select_rows(
        **generate_sib_stats_expr(
            mt,
            sib_ht,
            i_col=i_col,
            j_col=j_col,
            strata={"raw": True, "adj": mt.adj},
            is_female=mt.is_female,
        )
    ).rows()

    return sib_stats_ht
def explode_phase_info(ht: hl.Table, remove_all_ref: bool = True) -> hl.Table:
    # Convert the phase_info dict to an array of (pop, phase_info) tuples and
    # explode, yielding one row per (variant pair, population)
    ht = ht.transmute(phase_info=hl.array(ht.phase_info))
    ht = ht.explode('phase_info')
    ht = ht.transmute(pop=ht.phase_info[0], phase_info=ht.phase_info[1])

    if remove_all_ref:
        ht = ht.filter(hl.sum(ht.phase_info.gt_counts.raw[1:]) > 0)

    return ht
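# Shape sketch for explode_phase_info (the exact gt_counts layout is an
# assumption for illustration; field names follow the function above):
#
#   before: phase_info: dict<str, struct { gt_counts: struct { raw: array<int64>, ... }, ... }>
#   after:  pop: str, phase_info: struct { gt_counts: struct { raw: array<int64>, ... }, ... }
#
# With remove_all_ref=True, rows whose raw genotype counts beyond the first
# (all-ref) entry sum to zero are dropped.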
def filter_ht_for_plink(
    ht: hl.Table,
    n_samples: int,
    min_call_rate: float = 0.95,
    variants_per_mac_category: int = 2000,
    variants_per_maf_category: int = 10000,
):
    from gnomad.utils.filtering import filter_to_autosomes

    ht = filter_to_autosomes(ht)
    ht = ht.filter(
        (ht.call_stats.AN >= n_samples * 2 * min_call_rate)
        & (ht.call_stats.AC > 0)
    )
    ht = ht.annotate(mac_category=mac_category_case_builder(ht.call_stats))

    # Downsample variants within each MAC/MAF category
    category_counter = ht.aggregate(hl.agg.counter(ht.mac_category))
    print(category_counter)
    ht = ht.annotate_globals(category_counter=category_counter)
    return ht.filter(
        hl.rand_unif(0, 1)
        < hl.cond(
            ht.mac_category >= 1,
            variants_per_mac_category,
            variants_per_maf_category,
        )
        / ht.category_counter[ht.mac_category]
    )
def generic_field_check(
    ht: hl.Table,
    check_description: str,
    display_fields: hl.expr.StructExpression,
    cond_expr: hl.expr.BooleanExpression = None,
    verbose: bool = False,
    show_percent_sites: bool = False,
    n_fail: Optional[int] = None,
    ht_count: Optional[int] = None,
) -> None:
    """
    Check a generic logical condition involving annotations in a Hail Table and print the results to stdout.

    Displays the number of rows (and percent of rows, if `show_percent_sites` is True) in the Table that fail,
    either previously computed as `n_fail` or that match the `cond_expr`, and fail to be the desired condition
    (`check_description`). If the number of rows that match the `cond_expr` or `n_fail` is 0, then the Table
    passes that check; otherwise, it fails.

    .. note::

        `cond_expr` and `check_description` are opposites and should never be the same.
        E.g., If `cond_expr` filters for instances where the raw AC is less than adj AC,
        then it is checking sites that fail to be the desired condition (`check_description`)
        of having a raw AC greater than or equal to the adj AC.

    :param ht: Table containing annotations to be checked.
    :param check_description: String describing the condition being checked; is displayed in stdout summary message.
    :param display_fields: StructExpression containing annotations to be displayed in case of failure (for troubleshooting purposes); these fields are also displayed if verbose is True.
    :param cond_expr: Optional logical expression referring to annotations in ht to be checked.
    :param verbose: If True, show top values of annotations being checked, including checks that pass; if False, show only top values of annotations that fail checks.
    :param show_percent_sites: Show percentage of sites that fail checks. Default is False.
    :param n_fail: Optional number of sites that fail the conditional checks (previously computed). If not supplied, `cond_expr` is used to filter the Table and obtain the count of sites that fail the checks.
    :param ht_count: Optional number of sites within hail Table (previously computed). If not supplied, a count of sites in the Table is performed.
    :return: None
    """
    if (n_fail is None and cond_expr is None) or (
        n_fail is not None and cond_expr is not None
    ):
        raise ValueError("One and only one of n_fail or cond_expr must be defined!")

    if cond_expr is not None:
        n_fail = ht.filter(cond_expr).count()

    if show_percent_sites and (ht_count is None):
        ht_count = ht.count()

    if n_fail > 0:
        logger.info("Found %d sites that fail %s check:", n_fail, check_description)
        if show_percent_sites:
            logger.info(
                "Percentage of sites that fail: %.2f %%", 100 * (n_fail / ht_count)
            )
        ht.select(**display_fields).show()
    else:
        logger.info("PASSED %s check", check_description)
        if verbose:
            ht.select(**display_fields).show()
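# Example call (a sketch; `info.AC_raw`/`info.AC_adj` are hypothetical annotation
# names, mirroring the raw-vs-adj AC example from the docstring above):
#
# generic_field_check(
#     ht,
#     check_description="raw AC >= adj AC",
#     display_fields=hl.struct(AC_raw=ht.info.AC_raw, AC_adj=ht.info.AC_adj),
#     cond_expr=ht.info.AC_raw < ht.info.AC_adj,
#     show_percent_sites=True,
# )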
def densify_sites(
    mt: hl.MatrixTable,
    sites_ht: hl.Table,
    last_END_positions_ht: hl.Table,
    semi_join_rows: bool = True,
) -> hl.MatrixTable:
    """
    Creates a dense version of the input sparse MT at the sites in `sites_ht`, reading the minimal amount of data required.

    Note that only rows that appear both in `mt` and `sites_ht` are returned.

    :param mt: Input sparse MT
    :param sites_ht: Desired sites to densify
    :param last_END_positions_ht: Table storing positions of the furthest ref block (END tag)
    :param semi_join_rows: Whether to filter the MT rows based on semi-join (default, better if sites_ht is large) or based on filter_intervals (better if sites_ht only contains a few sites)
    :return: Dense MT filtered to the sites in `sites_ht`
    """
    logger.info("Computing intervals to densify from sites Table.")
    sites_ht = sites_ht.key_by("locus")
    sites_ht = sites_ht.annotate(
        interval=hl.locus_interval(
            sites_ht.locus.contig,
            last_END_positions_ht[sites_ht.key].last_END_position,
            end=sites_ht.locus.position,
            includes_end=True,
            reference_genome=sites_ht.locus.dtype.reference_genome,
        )
    )
    sites_ht = sites_ht.filter(hl.is_defined(sites_ht.interval))

    if semi_join_rows:
        mt = mt.filter_rows(hl.is_defined(sites_ht.key_by("interval")[mt.locus]))
    else:
        logger.info("Collecting intervals to densify.")
        intervals = sites_ht.interval.collect()

        print(
            "Found {0} intervals, totalling {1} bp in the dense Matrix.".format(
                len(intervals),
                sum(
                    [
                        interval_length(interval)
                        for interval in union_intervals(intervals)
                    ]
                ),
            )
        )

        mt = hl.filter_intervals(mt, intervals)

    mt = hl.experimental.densify(mt)

    return mt.filter_rows(hl.is_defined(sites_ht[mt.locus]))
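# Usage sketch for densify_sites (paths are hypothetical; `last_END_positions_ht`
# must be keyed by locus and carry a `last_END_position` field, per the code above):
#
# sites_ht = hl.read_table("gs://my-bucket/qc_sites.ht")
# last_end_ht = hl.read_table("gs://my-bucket/last_END_positions.ht")
# dense_mt = densify_sites(sparse_mt, sites_ht, last_end_ht)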
def aggregate_contig(ht: hl.Table, contigs: Optional[Set[str]] = None):
    """
    Aggregates all contigs together and computes the number of variants per bin across the contigs.
    """
    if contigs:
        ht = ht.filter(hl.literal(contigs).contains(ht.contig))

    return ht.group_by(*[k for k in ht.key if k != 'contig']).aggregate(
        min_score=hl.agg.min(ht.min_score),
        max_score=hl.agg.max(ht.max_score),
        **{
            x: hl.agg.sum(ht[x])
            for x in ht.row_value
            if x not in ['min_score', 'max_score']
        },
    )
def get_platform_specific_intervals(
    platform_pc_loadings_ht: hl.Table, threshold: float
) -> List[hl.Interval]:
    """
    This takes the platform PC loadings and returns a list of intervals where the sum of the absolute loadings is
    above the given threshold. The experimental / untested idea behind this is that those intervals may be
    problematic on some platforms.

    :param Table platform_pc_loadings_ht: Platform PCA loadings indexed by interval
    :param float threshold: Minimal threshold
    :return: List of intervals with PC loadings above the given threshold
    :rtype: list of Interval
    """
    platform_specific_intervals = platform_pc_loadings_ht.filter(
        hl.sum(hl.abs(platform_pc_loadings_ht.loadings)) >= threshold
    )
    return platform_specific_intervals.interval.collect()
def annotate_related_pairs(related_pairs: hl.Table, index_col: str) -> hl.Table:
    # NOTE: expects `case_parents`, `meta_ht` and `sample_qc_ht` to be available in
    # the enclosing scope (see the inline version in rank_related_samples below)
    related_pairs = related_pairs.key_by(**related_pairs[index_col])
    related_pairs = related_pairs.filter(
        hl.is_missing(case_parents[related_pairs.key])
    )
    return related_pairs.annotate(
        **{
            index_col: related_pairs[index_col].annotate(
                case_rank=hl.or_else(
                    hl.int(meta_ht[related_pairs.key].is_case), -1
                ),
                dp_mean=hl.or_else(
                    sample_qc_ht[related_pairs.key].sample_qc.dp_stats.mean, -1.0
                ),
            )
        }
    ).key_by()
def compute_grouped_binned_ht(
    bin_ht: hl.Table,
    checkpoint_path: Optional[str] = None,
) -> hl.GroupedTable:
    """
    Group a Table that has been annotated with bins (`compute_ranked_bin` or `create_binned_ht`).

    The table will be grouped by bin_id (bin, biallelic, etc.), contig, snv, bi_allelic and singleton.

    .. note::

        If performing an aggregation following this grouping (such as `score_bin_agg`) then the aggregation
        function will need to use `ht._parent` to get the origin Table from the GroupedTable for the aggregation

    :param bin_ht: Input Table with a `bin_id` annotation
    :param checkpoint_path: If provided an intermediate checkpoint table is created with all required annotations before shuffling.
    :return: Table grouped by bins(s)
    """
    # Explode the rank table by bin_id
    bin_ht = bin_ht.annotate(
        bin_groups=hl.array(
            [
                hl.Struct(bin_id=bin_name, bin=bin_ht[bin_name])
                for bin_name in bin_ht.bin_group_variant_counts
            ]
        )
    )
    bin_ht = bin_ht.explode(bin_ht.bin_groups)
    bin_ht = bin_ht.transmute(
        bin_id=bin_ht.bin_groups.bin_id, bin=bin_ht.bin_groups.bin
    )
    bin_ht = bin_ht.filter(hl.is_defined(bin_ht.bin))

    if checkpoint_path is not None:
        bin_ht = bin_ht.checkpoint(checkpoint_path, overwrite=True)
    else:
        bin_ht = bin_ht.persist()

    # Group by bin_id, bin and additional stratification desired and compute QC metrics per bin
    return bin_ht.group_by(
        bin_id=bin_ht.bin_id,
        contig=bin_ht.locus.contig,
        snv=hl.is_snp(bin_ht.alleles[0], bin_ht.alleles[1]),
        bi_allelic=~bin_ht.was_split,
        singleton=bin_ht.singleton,
        release_adj=bin_ht.ac > 0,
        bin=bin_ht.bin,
    )._set_buffer_size(20000)
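# Follow-up aggregation sketch: per the note in the docstring, aggregations over
# the returned GroupedTable must reach the origin Table through `_parent`
# (`score_bin_agg` does this internally). A minimal inline equivalent, assuming
# the binned Table carries a `score` field:
#
# grouped = compute_grouped_binned_ht(bin_ht)
# agg_ht = grouped.aggregate(
#     n=hl.agg.count(),
#     min_score=hl.agg.min(grouped._parent.score),
# )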
def test_model(
    ht: hl.Table,
    rf_model: pyspark.ml.PipelineModel,
    features: List[str],
    label: str,
    prediction_col_name: str = "rf_prediction",
) -> List[hl.tstruct]:
    """
    A wrapper to test a model on a set of examples with known labels.

    1) Runs the model on the data
    2) Prints confusion matrix and accuracy
    3) Returns confusion matrix as a list of struct

    :param ht: Input table
    :param rf_model: RF Model
    :param features: Columns containing features that were used in the model
    :param label: Column containing label to be predicted
    :param prediction_col_name: Where to store the prediction
    :return: A list containing structs with {label, prediction, n}
    """
    ht = apply_rf_model(
        ht.filter(hl.is_defined(ht[label])),
        rf_model,
        features,
        label,
        prediction_col_name=prediction_col_name,
    )

    test_results = (
        ht.group_by(ht[prediction_col_name], ht[label])
        .aggregate(n=hl.agg.count())
        .collect()
    )

    # Print results
    df = pd.DataFrame(test_results)
    df = df.pivot(index=label, columns=prediction_col_name, values="n")
    logger.info("Testing results:\n{}".format(pprint.pformat(df)))
    logger.info(
        "Accuracy: {}".format(
            sum([x.n for x in test_results if x[label] == x[prediction_col_name]])
            / sum([x.n for x in test_results])
        )
    )

    return test_results
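# Usage sketch for test_model (feature and label names are hypothetical):
#
# results = test_model(ht, rf_model, features=["QD", "FS", "MQ"], label="rf_label")
#
# Each element of `results` is one cell of the confusion matrix, e.g.
# Struct(rf_label="TP", rf_prediction="TP", n=12345).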
def generic_field_check(
    ht: hl.Table,
    cond_expr: hl.expr.BooleanExpression,
    check_description: str,
    display_fields: List[str],
    verbose: bool,
    show_percent_sites: bool = False,
) -> None:
    """
    Check a generic logical condition involving annotations in a Hail Table and print the results to terminal.

    Displays the number of rows (and percent of rows, if `show_percent_sites` is True) in the Table that match the
    `cond_expr` and fail to be the desired condition (`check_description`). If the number of rows that match the
    `cond_expr` is 0, then the Table passes that check; otherwise, it fails.

    .. note::
        `cond_expr` and `check_description` are opposites and should never be the same.
        E.g., If `cond_expr` filters for instances where the raw AC is less than adj AC,
        then it is checking sites that fail to be the desired condition (`check_description`)
        of having a raw AC greater than or equal to the adj AC.

    :param ht: Table containing annotations to be checked.
    :param cond_expr: Logical expression referring to annotations in ht to be checked.
    :param check_description: String describing the condition being checked; is displayed in terminal summary message.
    :param display_fields: List of names of ht annotations to be displayed in case of failure (for troubleshooting purposes); these fields are also displayed if verbose is True.
    :param verbose: If True, show top values of annotations being checked, including checks that pass; if False, show only top values of annotations that fail checks.
    :param show_percent_sites: Show percentage of sites that fail checks. Default is False.
    :return: None
    """
    ht_orig = ht
    ht = ht.filter(cond_expr)
    n_fail = ht.count()
    if n_fail > 0:
        logger.info("Found %d sites that fail %s check:", n_fail, check_description)
        if show_percent_sites:
            logger.info(
                "Percentage of sites that fail: %.2f %%",
                100 * (n_fail / ht_orig.count()),
            )
        ht = ht.flatten()
        ht.select("locus", "alleles", *display_fields).show()
    else:
        logger.info("PASSED %s check", check_description)
        if verbose:
            ht_orig = ht_orig.flatten()
            ht_orig.select(*display_fields).show()
def get_duplicated_samples_ibd(
    kin_ht: hl.Table,
    i_col: str = 'i',
    j_col: str = 'j',
    pi_hat_col: str = 'pi_hat',
    duplicate_threshold: float = 0.90,
) -> List[Set[str]]:
    """
    Given an ibd output Table, extract the list of duplicate samples.
    Returns a list of sets of samples that are duplicates.

    :param Table kin_ht: ibd output table
    :param str i_col: Column containing the 1st sample
    :param str j_col: Column containing the 2nd sample
    :param str pi_hat_col: Column containing the pi_hat value
    :param float duplicate_threshold: pi_hat threshold to consider two samples duplicated
    :return: List of samples that are duplicates
    :rtype: list of set of str
    """

    def get_all_dups(s, dups, samples_duplicates):
        """
        Collect the full set of duplicates connected to sample `s`, accumulating
        them in `dups` and consuming visited entries from `samples_duplicates`.
        """
        if s in samples_duplicates:
            dups.add(s)
            s_dups = samples_duplicates.pop(s)  # the set of duplicates recorded for s
            for s_dup in s_dups:
                if s_dup not in dups:
                    dups = get_all_dups(s_dup, dups, samples_duplicates)
        return dups

    dup_rows = kin_ht.filter(kin_ht[pi_hat_col] > duplicate_threshold).collect()

    # Maps each sample to the set of that sample's duplicates
    samples_duplicates = defaultdict(set)
    for row in dup_rows:
        samples_duplicates[row[i_col]].add(row[j_col])
        samples_duplicates[row[j_col]].add(row[i_col])

    duplicated_samples = []
    while len(samples_duplicates) > 0:
        duplicated_samples.append(
            get_all_dups(list(samples_duplicates)[0], set(), samples_duplicates)
        )

    return duplicated_samples
def rank_related_samples(
    relatedness_ht: hl.Table,
    meta_ht: hl.Table,
    sample_qc_ht: hl.Table,
    fam_ht: hl.Table,
) -> Tuple[hl.Table, Callable[[hl.expr.Expression, hl.expr.Expression], hl.expr.NumericExpression]]:
    # Load families and identify parents from cases as they will be thrown away anyways
    fam_ht = fam_ht.transmute(
        trio=[
            hl.struct(s=fam_ht.id, is_parent=False),
            hl.struct(s=fam_ht.pat_id, is_parent=True),
            hl.struct(s=fam_ht.mat_id, is_parent=True),
        ]
    )
    fam_ht = fam_ht.explode(fam_ht.trio)
    fam_ht = fam_ht.key_by(s=fam_ht.trio.s)

    case_parents = fam_ht.filter(meta_ht[fam_ht.key].is_case & fam_ht.trio.is_parent)

    def annotate_related_pairs(related_pairs: hl.Table, index_col: str) -> hl.Table:
        related_pairs = related_pairs.key_by(**related_pairs[index_col])
        related_pairs = related_pairs.filter(
            hl.is_missing(case_parents[related_pairs.key])
        )
        return related_pairs.annotate(
            **{
                index_col: related_pairs[index_col].annotate(
                    case_rank=hl.or_else(
                        hl.int(meta_ht[related_pairs.key].is_case), -1
                    ),
                    dp_mean=hl.or_else(
                        sample_qc_ht[related_pairs.key].sample_qc.dp_stats.mean, -1.0
                    ),
                )
            }
        ).key_by()

    relatedness_ht = annotate_related_pairs(relatedness_ht, "i")
    relatedness_ht = annotate_related_pairs(relatedness_ht, "j")

    def tie_breaker(l, r):
        return (
            hl.case()
            .when(l.case_rank != r.case_rank, r.case_rank - l.case_rank)  # smaller is better
            .default(l.dp_mean - r.dp_mean)  # larger is better
        )

    return relatedness_ht, tie_breaker
def get_related_samples_to_drop(rank_table: hl.Table, relatedness_ht: hl.Table) -> hl.Table:
    """
    Use the maximal independence function in Hail to intelligently prune clusters of related individuals, removing
    less desirable samples while maximizing the number of unrelated individuals kept in the sample set.

    :param Table rank_table: Table with ranking annotations across exomes and genomes, computed via make_rank_file()
    :param Table relatedness_ht: Table with kinship coefficient annotations computed via pc_relate()
    :return: Table containing sample IDs ('s') to be pruned from the combined exome and genome sample set
    :rtype: Table
    """
    # Define maximal independent set, using rank list
    related_pairs = relatedness_ht.filter(relatedness_ht.kin > 0.08838835).select('i', 'j')
    n_related_samples = hl.eval(
        hl.len(
            related_pairs.aggregate(
                hl.agg.explode(
                    lambda x: hl.agg.collect_as_set(x),
                    [related_pairs.i, related_pairs.j],
                ),
                _localize=False,
            )
        )
    )
    logger.info(
        '{} samples with at least 2nd-degree relatedness found in callset'.format(
            n_related_samples
        )
    )
    max_rank = rank_table.count()
    related_pairs = related_pairs.annotate(
        id1_rank=hl.struct(id=related_pairs.i, rank=rank_table[related_pairs.i].rank),
        id2_rank=hl.struct(id=related_pairs.j, rank=rank_table[related_pairs.j].rank),
    ).select('id1_rank', 'id2_rank')

    def tie_breaker(l, r):
        return hl.or_else(l.rank, max_rank + 1) - hl.or_else(r.rank, max_rank + 1)

    related_samples_to_drop_ranked = hl.maximal_independent_set(
        related_pairs.id1_rank,
        related_pairs.id2_rank,
        keep=False,
        tie_breaker=tie_breaker,
    )
    return related_samples_to_drop_ranked.select(
        **related_samples_to_drop_ranked.node.id
    ).key_by('data_type', 's')
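# Note on the kinship cutoff used above: 0.08838835 is 2 ** (-7 / 2), the
# conventional boundary between 2nd- and 3rd-degree relatives. Quick check:
#
# >>> 2 ** -3.5
# 0.08838834764831845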
def filter_ped(
    raw_ped: hl.Pedigree, mendel: hl.Table, max_dnm: int, max_mendel: int
) -> hl.Pedigree:
    mendel = mendel.filter(mendel.fam_id.startswith("fake"))
    mendel_by_s = (
        mendel.group_by(mendel.s)
        .aggregate(
            fam_id=hl.agg.take(mendel.fam_id, 1)[0],
            n_mendel=hl.agg.count(),
            n_de_novo=hl.agg.count_where(
                mendel.mendel_code == 2  # Code 2 is parents are hom ref, child is het
            ),
        )
        .persist()
    )

    good_trios = mendel_by_s.aggregate(
        hl.agg.filter(
            (mendel_by_s.n_mendel < max_mendel) & (mendel_by_s.n_de_novo < max_dnm),
            hl.agg.collect(mendel_by_s.s),
        )
    )
    logger.info(f"Found {len(good_trios)} trios passing filters")
    return hl.Pedigree([trio for trio in raw_ped.trios if trio.s in good_trios])
def get_duplicated_samples(
    kin_ht: hl.Table,
    i_col: str = 'i',
    j_col: str = 'j',
    kin_col: str = 'kin',
    duplicate_threshold: float = 0.4,
) -> List[Set[str]]:
    """
    Given a pc_relate output Table, extract the list of duplicate samples.
    Returns a list of sets of samples that are duplicates.

    :param Table kin_ht: pc_relate output table
    :param str i_col: Column containing the 1st sample
    :param str j_col: Column containing the 2nd sample
    :param str kin_col: Column containing the kinship value
    :param float duplicate_threshold: Kinship threshold to consider two samples duplicated
    :return: List of samples that are duplicates
    :rtype: list of set of str
    """

    def get_all_dups(s, dups, samples_duplicates):
        if s in samples_duplicates:
            dups.add(s)
            s_dups = samples_duplicates.pop(s)
            for s_dup in s_dups:
                if s_dup not in dups:
                    dups = get_all_dups(s_dup, dups, samples_duplicates)
        return dups

    dup_rows = kin_ht.filter(kin_ht[kin_col] > duplicate_threshold).collect()

    samples_duplicates = defaultdict(set)
    for row in dup_rows:
        samples_duplicates[row[i_col]].add(row[j_col])
        samples_duplicates[row[j_col]].add(row[i_col])

    duplicated_samples = []
    while len(samples_duplicates) > 0:
        duplicated_samples.append(
            get_all_dups(list(samples_duplicates)[0], set(), samples_duplicates)
        )

    return duplicated_samples
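# Worked example of the transitive grouping above (pure Python, no Hail needed):
# pairs ("s1","s2"), ("s2","s3"), ("s4","s5") populate samples_duplicates as
# {"s1": {"s2"}, "s2": {"s1", "s3"}, "s3": {"s2"}, "s4": {"s5"}, "s5": {"s4"}},
# and get_all_dups links s1/s2/s3 through s2, yielding
# [{"s1", "s2", "s3"}, {"s4", "s5"}].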
def subset_samples(
    input_mt: hl.MatrixTable,
    pedigree: hl.Table,
    sex_ht: hl.Table,
    output_dir: str,
    output_name: str,
) -> Tuple[hl.MatrixTable, hl.Table, list, list]:
    """
    Filter the MatrixTable and sex Table to only samples in the pedigree.

    :param input_mt: MatrixTable
    :param pedigree: Pedigree file from seqr loaded as a Hail Table
    :param sex_ht: Table of inferred sexes for each sample
    :param output_dir: Path to directory to output results
    :param output_name: Output prefix to use for results
    :return: MatrixTable and sex ht subsetted to the samples given in the pedigree, list of samples in the pedigree, list of samples in the VCF
    """
    # Get sample names to subset from the pedigree
    samples_to_subset = hl.set(pedigree.Individual_ID.collect())
    # Subset mt and ht to samples in the pedigree
    mt_subset = input_mt.filter_cols(samples_to_subset.contains(input_mt["s"]))
    sex_ht = sex_ht.filter(samples_to_subset.contains(sex_ht["s"]))

    # Filter to variants that have at least one alt call after the subsetting
    mt_subset = mt_subset.filter_rows(hl.agg.any(mt_subset.GT.is_non_ref()))

    # Check that the samples in the pedigree are present in the VCF subset and output samples that are missing
    out_missing_samples = hl.hadoop_open(
        f"{output_dir}/{output_name}_missing_samples_in_subset.txt", "w"
    )
    expected_samples = pedigree.Individual_ID.collect()
    vcf_samples = mt_subset.s.collect()

    missings = set(expected_samples) - set(vcf_samples)

    for i in missings:
        out_missing_samples.write(i + "\n")
    out_missing_samples.close()

    return (mt_subset, sex_ht, expected_samples, vcf_samples)
def infer_families(
    relationship_ht: hl.Table,
    sex: Union[hl.Table, Dict[str, bool]],
    duplicate_samples_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    relationship_col: str = "relationship",
) -> hl.Pedigree:
    """
    This function takes a hail Table with a row for each pair of individuals i,j in the data that are related
    (it's OK to have unrelated samples too).

    The `relationship_col` should be a column specifying the relationship between each pair of samples as defined in
    this module's constants.

    This function returns a pedigree containing trios inferred from the data. Family ID can be the same for multiple
    trios if one or more members of the trios are related (e.g. sibs, multi-generational family). Trios are ordered
    by family ID.

    .. note::

        This function only returns complete trios defined as: one child, one father and one mother
        (sex is required for both parents).

    :param relationship_ht: Input relationship table
    :param sex: A Table or dict giving the sex for each sample (`TRUE`=female, `FALSE`=male). If a Table is given, it should have a field `is_female`.
    :param duplicate_samples_ht: All duplicated samples TO REMOVE (If not provided, this function won't work as it assumes that each child has exactly two parents)
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param relationship_col: Column containing the relationship for the sample pair as defined in this module constants.
    :return: Pedigree of complete trios
    """

    def group_parent_child_pairs_by_fam(
        parent_child_pairs: Iterable[Tuple[str, str]]
    ) -> List[List[Tuple[str, str]]]:
        """
        Takes all parent-children pairs and groups them by family.
        A family here is defined as a list of sample-pairs which all share at least one sample with at least one
        other sample-pair in the list.

        :param parent_child_pairs: All the parent-children pairs
        :return: A list of families, where each element of the list is a list of the parent-children pairs
        """
        fam_id = 1  # stores the current family id
        s_fam = dict()  # stores the family id for each sample
        fams = defaultdict(list)  # stores fam_id -> sample-pairs
        for pair in parent_child_pairs:
            if pair[0] in s_fam:
                if pair[1] in s_fam:
                    # If both samples are in different families, merge the families
                    if s_fam[pair[0]] != s_fam[pair[1]]:
                        new_fam_id = s_fam[pair[0]]
                        fam_id_to_merge = s_fam[pair[1]]
                        for s in s_fam:
                            if s_fam[s] == fam_id_to_merge:
                                s_fam[s] = new_fam_id
                        fams[new_fam_id].extend(fams.pop(fam_id_to_merge))
                else:
                    # If only the 1st sample in the pair is already in a family,
                    # assign the 2nd sample in the pair to the same family
                    s_fam[pair[1]] = s_fam[pair[0]]
                fams[s_fam[pair[0]]].append(pair)
            elif pair[1] in s_fam:
                # If only the 2nd sample in the pair is already in a family,
                # assign the 1st sample in the pair to the same family
                s_fam[pair[0]] = s_fam[pair[1]]
                fams[s_fam[pair[1]]].append(pair)
            else:
                # If neither sample in the pair is already in a family, create a new family
                s_fam[pair[0]] = fam_id
                s_fam[pair[1]] = fam_id
                fams[fam_id].append(pair)
                fam_id += 1

        return list(fams.values())

    def get_trios(
        fam_id: str,
        parent_child_pairs: List[Tuple[str, str]],
        related_pairs: Dict[Tuple[str, str], str],
    ) -> List[hl.Trio]:
        """
        Generates trios from the list of parent-child pairs in the family and all related pairs in the data.
        Only complete parent/offspring trios are included in the results.

        The trios are assembled as follows:
            1. All pairs of unrelated samples with different sexes within the family are extracted as possible parent pairs
            2. For each possible parent pair, a list of all children is constructed (each child in the list has a parent-offspring pair with each parent)
            3. If there are multiple children for a given parent pair, all children should be siblings with each other
            4. Check that each child was only assigned a single pair of parents. If a child is found to have multiple parent pairs, they are ALL discarded.

        :param fam_id: The family ID
        :param parent_child_pairs: The parent-child pairs for this family
        :param related_pairs: All related sample pairs in the data
        :return: List of trios in the family
        """

        def get_possible_parents(samples: List[str]) -> List[Tuple[str, str]]:
            """
            1. All pairs of unrelated samples with different sexes within the family are extracted as possible parent pairs

            :param samples: All samples in the family
            :return: Possible parent pairs
            """
            possible_parents = []
            for i in range(len(samples)):
                for j in range(i + 1, len(samples)):
                    if related_pairs.get(tuple(sorted([samples[i], samples[j]]))) is None:
                        if sex.get(samples[i]) is False and sex.get(samples[j]) is True:
                            possible_parents.append((samples[i], samples[j]))
                        elif sex.get(samples[i]) is True and sex.get(samples[j]) is False:
                            possible_parents.append((samples[j], samples[i]))
            return possible_parents

        def get_children(possible_parents: Tuple[str, str]) -> List[str]:
            """
            2. For a given possible parent pair, a list of all children is constructed
            (each child in the list has a parent-offspring pair with each parent)

            :param possible_parents: A pair of possible parents
            :return: The list of all children (if any) corresponding to the possible parents
            """
            # Stores sample -> set of parents in possible_parents where
            # (sample, parent) is found in parent_child_pairs
            possible_offsprings = defaultdict(set)
            for pair in parent_child_pairs:
                if possible_parents[0] == pair[0]:
                    possible_offsprings[pair[1]].add(possible_parents[0])
                elif possible_parents[0] == pair[1]:
                    possible_offsprings[pair[0]].add(possible_parents[0])
                elif possible_parents[1] == pair[0]:
                    possible_offsprings[pair[1]].add(possible_parents[1])
                elif possible_parents[1] == pair[1]:
                    possible_offsprings[pair[0]].add(possible_parents[1])

            return [s for s, parents in possible_offsprings.items() if len(parents) == 2]

        def check_sibs(children: List[str]) -> bool:
            """
            3. If there are multiple children for a given parent pair, all children should be siblings with each other

            :param children: List of all children for a given parent pair
            :return: Whether all children in the list are siblings
            """
            for i in range(len(children)):
                for j in range(i + 1, len(children)):
                    if related_pairs[tuple(sorted([children[i], children[j]]))] != SIBLINGS:
                        return False
            return True

        def discard_multi_parents_children(trios: List[hl.Trio]):
            """
            4. Check that each child was only assigned a single pair of parents.
            If a child is found to have multiple parent pairs, they are ALL discarded.

            :param trios: All trios formed for this family
            :return: The list of trios for which each child has a single parent pair
            """
            children_trios = defaultdict(list)
            for trio in trios:
                children_trios[trio.s].append(trio)

            for s, s_trios in children_trios.items():
                if len(s_trios) > 1:
                    logger.warning(
                        "Discarded duplicated child {0} found in multiple trios: {1}".format(
                            s, ", ".join([str(trio) for trio in s_trios])
                        )
                    )

            return [trios[0] for trios in children_trios.values() if len(trios) == 1]

        # Get all possible pairs of parents in (father, mother) order
        all_possible_parents = get_possible_parents(
            list({s for pair in parent_child_pairs for s in pair})
        )

        trios = []
        for possible_parents in all_possible_parents:
            children = get_children(possible_parents)
            if check_sibs(children):
                trios.extend(
                    [
                        hl.Trio(
                            s=s,
                            fam_id=fam_id,
                            pat_id=possible_parents[0],
                            mat_id=possible_parents[1],
                            is_female=sex.get(s),
                        )
                        for s in children
                    ]
                )
            else:
                logger.warning(
                    "Discarded family with same parents, and multiple offspring that weren't siblings:"
                    "\nMother: {}\nFather:{}\nChildren:{}".format(
                        possible_parents[0], possible_parents[1], ", ".join(children)
                    )
                )

        return discard_multi_parents_children(trios)

    # Get all the relations we care about:
    # => Remove unrelateds and duplicates
    dups = duplicate_samples_ht.aggregate(
        hl.agg.explode(
            lambda dup: hl.agg.collect_as_set(dup), duplicate_samples_ht.filtered
        ),
        _localize=False,
    )
    relationship_ht = relationship_ht.filter(
        ~dups.contains(relationship_ht[i_col])
        & ~dups.contains(relationship_ht[j_col])
        & (relationship_ht[relationship_col] != UNRELATED)
    )

    # Check relatedness table format
    if not relationship_ht[i_col].dtype == relationship_ht[j_col].dtype:
        logger.error(
            "i_col and j_col of the relatedness table need to be of the same type."
        )

    # If i_col and j_col aren't str, then convert them
    if not isinstance(relationship_ht[i_col], hl.expr.StringExpression):
        logger.warning(
            f"Pedigrees can only be constructed from string IDs, but your relatedness_ht ID column is of type: "
            f"{relationship_ht[i_col].dtype}. Expression will be converted to string in Pedigrees."
        )
        if isinstance(relationship_ht[i_col], hl.expr.StructExpression):
            logger.warning(
                f"Struct fields {list(relationship_ht[i_col])} will be joined by underscores to use as sample names in Pedigree."
            )
            relationship_ht = relationship_ht.key_by(
                **{
                    i_col: hl.delimit(
                        hl.array(
                            [
                                hl.str(relationship_ht[i_col][x])
                                for x in relationship_ht[i_col]
                            ]
                        ),
                        "_",
                    ),
                    j_col: hl.delimit(
                        hl.array(
                            [
                                hl.str(relationship_ht[j_col][x])
                                for x in relationship_ht[j_col]
                            ]
                        ),
                        "_",
                    ),
                }
            )
        else:
            raise NotImplementedError(
                "The `i_col` and `j_col` columns of the `relationship_ht` argument passed to infer_families are not of type StringExpression or Struct."
            )

    # If sex is a Table, extract sex information as a Dict
    if isinstance(sex, hl.Table):
        sex = dict(hl.tuple([sex.s, sex.is_female]).collect())

    # Collect all related sample pairs and create a dictionary with pairs as keys
    # and relationships as values. Sample-pairs are tuples ordered by sample name.
    related_pairs = {
        tuple(sorted([i, j])): rel
        for i, j, rel in hl.tuple(
            [relationship_ht.i, relationship_ht.j, relationship_ht.relationship]
        ).collect()
    }

    parent_child_pairs_by_fam = group_parent_child_pairs_by_fam(
        [pair for pair, rel in related_pairs.items() if rel == PARENT_CHILD]
    )

    return hl.Pedigree(
        [
            trio
            for fam_index, parent_child_pairs in enumerate(parent_child_pairs_by_fam)
            for trio in get_trios(str(fam_index), parent_child_pairs, related_pairs)
        ]
    )
def liftover_intervals(t: hl.Table, keep_missing_interval: bool = False) -> hl.Table:
    """
    Liftover locus in intervals from one coordinate system (hg37) to another (hg38)

    # Example input table description
    #
    # ----------------------------------------
    # Global fields:
    #     None
    # ----------------------------------------
    # Row fields:
    #     'interval': interval<locus<GRCh37>>
    # ----------------------------------------
    # Key: ['interval']
    # ----------------------------------------

    :param t: Table of intervals on GRCh37
    :param keep_missing_interval: If True, keep missing (non-lifted) intervals in the output Table.
    :return: Table with intervals lifted over to GRCh38.
    """
    rg37 = hl.get_reference("GRCh37")
    rg38 = hl.get_reference("GRCh38")

    if not rg37.has_liftover("GRCh38"):
        rg37.add_liftover(
            f'{nfs_dir}/resources/liftover/grch37_to_grch38.over.chain.gz', rg38
        )

    t = t.annotate(
        start=hl.liftover(t.interval.start, "GRCh38"),
        end=hl.liftover(t.interval.end, "GRCh38"),
    )
    t = t.filter(
        (t.start.contig == "chr" + t.interval.start.contig)
        & (t.end.contig == "chr" + t.interval.end.contig)
    )
    t = t.key_by()
    t = t.select(
        interval=hl.locus_interval(
            t.start.contig,
            t.start.position,
            t.end.position,
            reference_genome=rg38,
            invalid_missing=True,
        ),
        interval_hg37=t.interval,
    )

    # Count intervals that failed to lift over
    missing = t.aggregate(hl.agg.counter(~hl.is_defined(t.interval)))
    logger.info(
        f"Number of missing intervals: {missing.get(True, 0)} out of {t.count()}..."
    )

    # Update global annotations
    global_ann_expr = {
        'date': current_date(),
        'reference_genome': 'GRCh38',
        'was_lifted': True,
    }
    t = t.annotate_globals(**global_ann_expr)

    if not keep_missing_interval:
        logger.info(f"Filtering out {missing.get(True, 0)} missing intervals...")
        t = t.filter(hl.is_defined(t.interval), keep=True)

    return t.key_by("interval")
def train_rf(
    ht: hl.Table,
    fp_to_tp: float = 1.0,
    num_trees: int = 500,
    max_depth: int = 5,
    no_transmitted_singletons: bool = False,
    no_inbreeding_coeff: bool = False,
    vqsr_training: bool = False,
    vqsr_model_id: Optional[str] = None,
    filter_centromere_telomere: bool = False,
    test_intervals: Union[str, List[str]] = "chr20",
):
    """
    Train random forest model using `train_rf_model`.

    :param ht: Table containing annotations needed for RF training, built with `create_rf_ht`
    :param fp_to_tp: Ratio of FPs to TPs for creating the RF model. If set to 0, all training examples are used.
    :param num_trees: Number of trees in the RF model.
    :param max_depth: Maximum tree depth in the RF model.
    :param no_transmitted_singletons: Do not use transmitted singletons for training.
    :param no_inbreeding_coeff: Do not use inbreeding coefficient as a feature for training.
    :param vqsr_training: Use VQSR training sites to train the RF.
    :param vqsr_model_id: VQSR model to use for vqsr_training. `vqsr_training` must be True for this parameter to be used.
    :param filter_centromere_telomere: Filter centromeres and telomeres before training.
    :param test_intervals: Specified interval(s) will be held out for testing and evaluation only. (default to "chr20")
    :return: `ht` annotated with training information and the RF model
    """
    features = list(FEATURES)  # copy so the module-level feature list is not mutated

    if no_inbreeding_coeff:
        logger.info("Removing InbreedingCoeff from list of features...")
        features.remove("InbreedingCoeff")

    if vqsr_training:
        logger.info("Using VQSR training sites for RF training...")
        vqsr_ht = get_vqsr_filters(vqsr_model_id, split=True).ht()
        ht = ht.annotate(
            vqsr_POSITIVE_TRAIN_SITE=vqsr_ht[ht.key].info.POSITIVE_TRAIN_SITE,
            vqsr_NEGATIVE_TRAIN_SITE=vqsr_ht[ht.key].info.NEGATIVE_TRAIN_SITE,
        )
        tp_expr = ht.vqsr_POSITIVE_TRAIN_SITE
        fp_expr = ht.vqsr_NEGATIVE_TRAIN_SITE
    else:
        fp_expr = ht.fail_hard_filters
        tp_expr = ht.omni | ht.mills
        if not no_transmitted_singletons:
            tp_expr = tp_expr | ht.transmitted_singleton

    if test_intervals:
        if isinstance(test_intervals, str):
            test_intervals = [test_intervals]
        test_intervals = [
            hl.parse_locus_interval(x, reference_genome="GRCh38")
            for x in test_intervals
        ]

    ht = ht.annotate(tp=tp_expr, fp=fp_expr)

    if filter_centromere_telomere:
        logger.info("Filtering centromeres and telomeres from HT...")
        rf_ht = ht.filter(~hl.is_defined(telomeres_and_centromeres.ht()[ht.locus]))
    else:
        rf_ht = ht

    rf_ht, rf_model = train_rf_model(
        rf_ht,
        rf_features=features,
        tp_expr=rf_ht.tp,
        fp_expr=rf_ht.fp,
        fp_to_tp=fp_to_tp,
        num_trees=num_trees,
        max_depth=max_depth,
        test_expr=hl.literal(test_intervals).any(
            lambda interval: interval.contains(rf_ht.locus)
        ),
    )

    logger.info("Joining original RF Table with training information")
    ht = ht.join(rf_ht, how="left")

    return ht, rf_model
def create_binned_data_initial(ht: hl.Table, data: str, data_type: str, n_bins: int) -> hl.Table:
    # Count variants for ranking
    count_expr = {
        x: hl.agg.filter(
            hl.is_defined(ht[x]),
            hl.agg.counter(
                hl.cond(hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv', 'indel')
            ),
        )
        for x in ht.row
        if x.endswith('rank')
    }
    rank_variant_counts = ht.aggregate(hl.Struct(**count_expr))
    logger.info(
        f"Found the following variant counts:\n {pformat(rank_variant_counts)}"
    )
    ht_truth_data = hl.read_table(
        f"{temp_dir}/ddd-elgh-ukbb/variant_qc/truthset_table.ht"
    )
    ht = ht.annotate_globals(rank_variant_counts=rank_variant_counts)
    ht = ht.annotate(
        **ht_truth_data[ht.key],
        # **fam_ht[ht.key],
        # **gnomad_ht[ht.key],
        # **denovo_ht[ht.key],
        # clinvar=hl.is_defined(clinvar_ht[ht.key]),
        indel_length=hl.abs(ht.alleles[0].length() - ht.alleles[1].length()),
        rank_bins=hl.array(
            [
                hl.Struct(
                    rank_id=rank_name,
                    bin=hl.int(
                        hl.ceil(
                            hl.float(ht[rank_name] + 1)
                            / hl.floor(
                                ht.globals.rank_variant_counts[rank_name][
                                    hl.cond(
                                        hl.is_snp(ht.alleles[0], ht.alleles[1]),
                                        'snv',
                                        'indel',
                                    )
                                ]
                                / n_bins
                            )
                        )
                    ),
                )
                for rank_name in rank_variant_counts
            ]
        ),
        # lcr=hl.is_defined(lcr_intervals[ht.locus])
    )

    ht = ht.explode(ht.rank_bins)
    ht = ht.transmute(rank_id=ht.rank_bins.rank_id, bin=ht.rank_bins.bin)
    ht = ht.filter(hl.is_defined(ht.bin))
    ht = ht.checkpoint(f'{tmp_dir}/gnomad_score_binning_tmp.ht', overwrite=True)

    # Create binned data
    return (
        ht.group_by(
            rank_id=ht.rank_id,
            contig=ht.locus.contig,
            snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
            bi_allelic=hl.is_defined(ht.biallelic_rank),
            singleton=ht.transmitted_singleton,
            trans_singletons=hl.is_defined(ht.singleton_rank),
            de_novo_high_quality=ht.de_novo_high_quality_rank,
            de_novo_medium_quality=hl.is_defined(ht.de_novo_medium_quality_rank),
            de_novo_synonymous=hl.is_defined(ht.de_novo_synonymous_rank),
            # release_adj=ht.ac > 0,
            bin=ht.bin,
        )
        ._set_buffer_size(20000)
        .aggregate(
            min_score=hl.agg.min(ht.score),
            max_score=hl.agg.max(ht.score),
            n=hl.agg.count(),
            n_ins=hl.agg.count_where(hl.is_insertion(ht.alleles[0], ht.alleles[1])),
            n_del=hl.agg.count_where(hl.is_deletion(ht.alleles[0], ht.alleles[1])),
            n_ti=hl.agg.count_where(hl.is_transition(ht.alleles[0], ht.alleles[1])),
            n_tv=hl.agg.count_where(hl.is_transversion(ht.alleles[0], ht.alleles[1])),
            n_1bp_indel=hl.agg.count_where(ht.indel_length == 1),
            n_mod3bp_indel=hl.agg.count_where((ht.indel_length % 3) == 0),
            # n_clinvar=hl.agg.count_where(ht.clinvar),
            n_singleton=hl.agg.count_where(ht.transmitted_singleton),
            n_high_quality_de_novos=hl.agg.count_where(
                ht.de_novo_data.p_de_novo[0] > 0.99
            ),
            n_validated_DDD_denovos=hl.agg.count_where(
                ht.inheritance.contains("De novo")
            ),
            n_medium_quality_de_novos=hl.agg.count_where(
                ht.de_novo_data.p_de_novo[0] > 0.5
            ),
            n_high_confidence_de_novos=hl.agg.count_where(
                ht.de_novo_data.confidence[0] == 'HIGH'
            ),
            n_de_novo=hl.agg.filter(
                ht.family_stats.unrelated_qc_callstats.AC[0][1] == 0,
                hl.agg.sum(ht.family_stats.mendel[0].errors),
            ),
            n_high_quality_de_novos_synonymous=hl.agg.count_where(
                (ht.de_novo_data.p_de_novo[0] > 0.99)
                & (ht.consequence == "synonymous_variant")
            ),
            # n_de_novo_no_lcr=hl.agg.filter(~ht.lcr & (
            #     ht.family_stats.unrelated_qc_callstats.AC[1] == 0),
            #     hl.agg.sum(ht.family_stats.mendel.errors)),
            n_de_novo_sites=hl.agg.filter(
                ht.family_stats.unrelated_qc_callstats.AC[0][1] == 0,
                hl.agg.count_where(ht.family_stats.mendel[0].errors > 0),
            ),
            # n_de_novo_sites_no_lcr=hl.agg.filter(~ht.lcr & (
            #     ht.family_stats.unrelated_qc_callstats.AC[1] == 0),
            #     hl.agg.count_where(ht.family_stats.mendel.errors > 0)),
            n_trans_singletons=hl.agg.filter(
                (ht.ac_raw < 3)
                & (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1),
                hl.agg.sum(ht.family_stats.tdt[0].t),
            ),
            n_trans_singletons_synonymous=hl.agg.filter(
                (ht.ac_raw < 3)
                & (ht.consequence == "synonymous_variant")
                & (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1),
                hl.agg.sum(ht.family_stats.tdt[0].t),
            ),
            n_untrans_singletons=hl.agg.filter(
                (ht.ac_raw < 3)
                & (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1),
                hl.agg.sum(ht.family_stats.tdt[0].u),
            ),
            n_untrans_singletons_synonymous=hl.agg.filter(
                (ht.ac_raw < 3)
                & (ht.consequence == "synonymous_variant")
                & (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1),
                hl.agg.sum(ht.family_stats.tdt[0].u),
            ),
            n_train_trans_singletons=hl.agg.count_where(
                (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1)
                & (ht.family_stats.tdt[0].t == 1)
            ),
            n_omni=hl.agg.count_where(ht.omni),
            n_mills=hl.agg.count_where(ht.mills),
            n_hapmap=hl.agg.count_where(ht.hapmap),
            n_kgp_high_conf_snvs=hl.agg.count_where(ht.kgp_phase1_hc),
            fail_hard_filters=hl.agg.count_where(ht.fail_hard_filters),
            # n_vqsr_pos_train=hl.agg.count_where(ht.vqsr_positive_train_site),
            # n_vqsr_neg_train=hl.agg.count_where(ht.vqsr_negative_train_site)
        )
    )
def get_summary_counts(
    ht: hl.Table,
    freq_field: str = "freq",
    filter_field: str = "filters",
    filter_decoy: bool = False,
    index: int = 0,
) -> hl.Table:
    """
    Generate a struct with summary counts across variant categories.

    Summary counts:
        - Number of variants
        - Number of indels
        - Number of SNVs
        - Number of LoF variants
        - Number of LoF variants that pass LOFTEE (including with LoF flags)
        - Number of LoF variants that pass LOFTEE without LoF flags
        - Number of OS (other splice) variants annotated by LOFTEE
        - Number of LoF variants that fail LOFTEE filters

    Also annotates Table's globals with total variant counts.

    Before calculating summary counts, function:
        - Filters out low confidence regions
        - Filters to canonical transcripts
        - Uses the most severe consequence

    Assumes that:
        - Input HT is annotated with VEP.
        - Multiallelic variants have been split and/or input HT contains bi-allelic variants only.
        - freq_expr was calculated with `annotate_freq`.
        - (Frequency index 0 from `annotate_freq` is frequency for all pops calculated on adj genotypes only.)

    :param ht: Input Table.
    :param freq_field: Name of field in HT containing frequency annotation (array of structs). Default is "freq".
    :param filter_field: Name of field in HT containing variant filter information. Default is "filters".
    :param filter_decoy: Whether to filter decoy regions. Default is False.
    :param index: Which index of freq_expr to use for annotation. Default is 0.
    :return: Table grouped by frequency bin and aggregated across summary count categories.
    """
    logger.info("Checking if multi-allelic variants have been split...")
    max_alleles = ht.aggregate(hl.agg.max(hl.len(ht.alleles)))
    if max_alleles > 2:
        logger.info("Splitting multi-allelics and VEP transcript consequences...")
        ht = hl.split_multi_hts(ht)

    logger.info("Filtering to PASS variants in high confidence regions...")
    ht = ht.filter((hl.len(ht[filter_field]) == 0))
    ht = filter_low_conf_regions(ht, filter_decoy=filter_decoy)

    logger.info(
        "Filtering to canonical transcripts and getting VEP summary annotations..."
    )
    ht = filter_vep_to_canonical_transcripts(ht)
    ht = get_most_severe_consequence_for_summary(ht)

    logger.info("Annotating with frequency bin information...")
    ht = ht.annotate(freq_bin=freq_bin_expr(ht[freq_field], index))

    logger.info(
        "Annotating HT globals with total counts/total allele counts per variant category..."
    )
    summary_counts = ht.aggregate(
        hl.struct(
            **get_summary_counts_dict(
                ht.locus,
                ht.alleles,
                ht.lof,
                ht.no_lof_flags,
                ht.most_severe_csq,
                prefix_str="total_",
            )
        )
    )
    summary_ac_counts = ht.aggregate(
        hl.struct(
            **get_summary_ac_dict(
                ht[freq_field][index].AC,
                ht.lof,
                ht.no_lof_flags,
                ht.most_severe_csq,
            )
        )
    )
    ht = ht.annotate_globals(
        summary_counts=summary_counts.annotate(**summary_ac_counts)
    )

    return ht.group_by("freq_bin").aggregate(
        **get_summary_counts_dict(
            ht.locus,
            ht.alleles,
            ht.lof,
            ht.no_lof_flags,
            ht.most_severe_csq,
        )
    )
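# Usage sketch for get_summary_counts (path is hypothetical; the Table is assumed
# to be VEP-annotated with `freq` from annotate_freq, per the docstring above):
#
# ht = hl.read_table("gs://my-bucket/release_sites.ht")
# summary_ht = get_summary_counts(ht)
# summary_ht.show()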
def create_binned_data(ht: hl.Table, data: str, data_type: str, n_bins: int) -> hl.Table:
    """
    Creates binned data from a rank Table grouped by rank_id (rank, biallelic, etc.), contig, snv, bi_allelic and
    singleton containing the information needed for evaluation plots.

    :param Table ht: Input rank table
    :param str data: Which data/run hash is being created
    :param str data_type: one of 'exomes' or 'genomes'
    :param int n_bins: Number of bins.
    :return: Binned Table
    :rtype: Table
    """
    # Count variants for ranking
    count_expr = {
        x: hl.agg.filter(
            hl.is_defined(ht[x]),
            hl.agg.counter(
                hl.cond(hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv', 'indel')
            ),
        )
        for x in ht.row
        if x.endswith('rank')
    }
    rank_variant_counts = ht.aggregate(hl.Struct(**count_expr))
    logger.info(
        f"Found the following variant counts:\n {pformat(rank_variant_counts)}"
    )
    ht = ht.annotate_globals(rank_variant_counts=rank_variant_counts)

    # Load external evaluation data
    clinvar_ht = hl.read_table(clinvar_ht_path)
    denovo_ht = get_validated_denovos_ht()
    if data_type == 'exomes':
        denovo_ht = denovo_ht.filter(denovo_ht.gnomad_exomes.high_quality)
    else:
        denovo_ht = denovo_ht.filter(denovo_ht.gnomad_genomes.high_quality)
    denovo_ht = denovo_ht.select(
        validated_denovo=denovo_ht.validated,
        high_confidence_denovo=denovo_ht.Confidence == 'HIGH',
    )
    ht_truth_data = hl.read_table(annotations_ht_path(data_type, 'truth_data'))
    fam_ht = hl.read_table(annotations_ht_path(data_type, 'family_stats'))
    fam_ht = fam_ht.select(family_stats=fam_ht.family_stats[0])
    gnomad_ht = get_gnomad_data(data_type).rows()
    gnomad_ht = gnomad_ht.select(
        vqsr_negative_train_site=gnomad_ht.info.NEGATIVE_TRAIN_SITE,
        vqsr_positive_train_site=gnomad_ht.info.POSITIVE_TRAIN_SITE,
        fail_hard_filters=(gnomad_ht.info.QD < 2)
        | (gnomad_ht.info.FS > 60)
        | (gnomad_ht.info.MQ < 30),
    )
    lcr_intervals = hl.import_locus_intervals(lcr_intervals_path)

    ht = ht.annotate(
        **ht_truth_data[ht.key],
        **fam_ht[ht.key],
        **gnomad_ht[ht.key],
        **denovo_ht[ht.key],
        clinvar=hl.is_defined(clinvar_ht[ht.key]),
        indel_length=hl.abs(ht.alleles[0].length() - ht.alleles[1].length()),
        rank_bins=hl.array(
            [
                hl.Struct(
                    rank_id=rank_name,
                    bin=hl.int(
                        hl.ceil(
                            hl.float(ht[rank_name] + 1)
                            / hl.floor(
                                ht.globals.rank_variant_counts[rank_name][
                                    hl.cond(
                                        hl.is_snp(ht.alleles[0], ht.alleles[1]),
                                        'snv',
                                        'indel',
                                    )
                                ]
                                / n_bins
                            )
                        )
                    ),
                )
                for rank_name in rank_variant_counts
            ]
        ),
        lcr=hl.is_defined(lcr_intervals[ht.locus]),
    )

    ht = ht.explode(ht.rank_bins)
    ht = ht.transmute(rank_id=ht.rank_bins.rank_id, bin=ht.rank_bins.bin)
    ht = ht.filter(hl.is_defined(ht.bin))
    ht = ht.checkpoint(
        f'gs://gnomad-tmp/gnomad_score_binning_{data_type}_tmp_{data}.ht',
        overwrite=True,
    )

    # Create binned data
    return (
        ht.group_by(
            rank_id=ht.rank_id,
            contig=ht.locus.contig,
            snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
            bi_allelic=hl.is_defined(ht.biallelic_rank),
            singleton=ht.singleton,
            release_adj=ht.ac > 0,
            bin=ht.bin,
        )
        ._set_buffer_size(20000)
        .aggregate(
            min_score=hl.agg.min(ht.score),
            max_score=hl.agg.max(ht.score),
            n=hl.agg.count(),
            n_ins=hl.agg.count_where(hl.is_insertion(ht.alleles[0], ht.alleles[1])),
            n_del=hl.agg.count_where(hl.is_deletion(ht.alleles[0], ht.alleles[1])),
            n_ti=hl.agg.count_where(hl.is_transition(ht.alleles[0], ht.alleles[1])),
            n_tv=hl.agg.count_where(hl.is_transversion(ht.alleles[0], ht.alleles[1])),
            n_1bp_indel=hl.agg.count_where(ht.indel_length == 1),
            n_mod3bp_indel=hl.agg.count_where((ht.indel_length % 3) == 0),
            n_clinvar=hl.agg.count_where(ht.clinvar),
            n_singleton=hl.agg.count_where(ht.singleton),
            n_validated_de_novos=hl.agg.count_where(ht.validated_denovo),
            n_high_confidence_de_novos=hl.agg.count_where(ht.high_confidence_denovo),
            n_de_novo=hl.agg.filter(
                ht.family_stats.unrelated_qc_callstats.AC[1] == 0,
                hl.agg.sum(ht.family_stats.mendel.errors),
            ),
            n_de_novo_no_lcr=hl.agg.filter(
                ~ht.lcr & (ht.family_stats.unrelated_qc_callstats.AC[1] == 0),
                hl.agg.sum(ht.family_stats.mendel.errors),
            ),
            n_de_novo_sites=hl.agg.filter(
                ht.family_stats.unrelated_qc_callstats.AC[1] == 0,
                hl.agg.count_where(ht.family_stats.mendel.errors > 0),
            ),
            n_de_novo_sites_no_lcr=hl.agg.filter(
                ~ht.lcr & (ht.family_stats.unrelated_qc_callstats.AC[1] == 0),
                hl.agg.count_where(ht.family_stats.mendel.errors > 0),
            ),
            n_trans_singletons=hl.agg.filter(
                (ht.info_ac < 3)
                & (ht.family_stats.unrelated_qc_callstats.AC[1] == 1),
                hl.agg.sum(ht.family_stats.tdt.t),
            ),
            n_untrans_singletons=hl.agg.filter(
                (ht.info_ac < 3)
                & (ht.family_stats.unrelated_qc_callstats.AC[1] == 1),
                hl.agg.sum(ht.family_stats.tdt.u),
            ),
            n_train_trans_singletons=hl.agg.count_where(
                (ht.family_stats.unrelated_qc_callstats.AC[1] == 1)
                & (ht.family_stats.tdt.t == 1)
            ),
            n_omni=hl.agg.count_where(ht.truth_data.omni),
            n_mills=hl.agg.count_where(ht.truth_data.mills),
            n_hapmap=hl.agg.count_where(ht.truth_data.hapmap),
            n_kgp_high_conf_snvs=hl.agg.count_where(ht.truth_data.kgp_high_conf_snvs),
            fail_hard_filters=hl.agg.count_where(ht.fail_hard_filters),
            n_vqsr_pos_train=hl.agg.count_where(ht.vqsr_positive_train_site),
            n_vqsr_neg_train=hl.agg.count_where(ht.vqsr_negative_train_site),
        )
    )
def sample_training_examples(
    ht: hl.Table,
    tp_expr: hl.BooleanExpression,
    fp_expr: hl.BooleanExpression,
    fp_to_tp: float = 1.0,
    test_expr: Optional[hl.expr.BooleanExpression] = None,
) -> hl.Table:
    """
    Returns a Table of all positive and negative training examples in `ht` with an annotation indicating those that
    should be used for training given a true positive (TP) to false positive (FP) ratio.

    The returned Table has the following annotations:
        - train: indicates if the variant should be used for training. A row is given False for the annotation if True
          for `test_expr`, True for both `tp_expr` and `fp_expr`, or it is pruned out to obtain the desired `fp_to_tp` ratio.
        - label: indicates if a variant is a 'TP' or 'FP' and will also be labeled as such for variants defined by `test_expr`.

    .. note::

        - This function does not support multi-allelic variants.
        - The function will give some stats about the TPs/FPs provided (Ti, Tv, indels).

    :param ht: Input Table.
    :param tp_expr: Expression for TP examples.
    :param fp_expr: Expression for FP examples.
    :param fp_to_tp: FP to TP ratio. If set to <= 0, all training examples are used.
    :param test_expr: Optional expression to exclude a set of variants from training set. Still contains TP/FP label annotation.
    :return: Table subset with corresponding TP and FP examples with desired FP to TP ratio.
    """
    ht = ht.select(
        _tp=hl.or_else(tp_expr, False),
        _fp=hl.or_else(fp_expr, False),
        _exclude=False if test_expr is None else test_expr,
    )
    ht = ht.filter(ht._tp | ht._fp).persist()

    # Get stats about TP / FP sets
    def _get_train_counts(ht: hl.Table) -> Tuple[int, int]:
        """
        Determine the number of TP and FP variants in the input Table and report some stats on Ti, Tv, indels.

        :param ht: Input Table
        :return: Counts of TP and FP variants in the table
        """
        train_stats = hl.struct(n=hl.agg.count())

        if "alleles" in ht.row and ht.row.alleles.dtype == hl.tarray(hl.tstr):
            train_stats = train_stats.annotate(
                ti=hl.agg.count_where(
                    hl.expr.is_transition(ht.alleles[0], ht.alleles[1])
                ),
                tv=hl.agg.count_where(
                    hl.expr.is_transversion(ht.alleles[0], ht.alleles[1])
                ),
                indel=hl.agg.count_where(
                    hl.expr.is_indel(ht.alleles[0], ht.alleles[1])
                ),
            )

        # Sample training examples
        pd_stats = (
            ht.group_by(**{"contig": ht.locus.contig, "tp": ht._tp, "fp": ht._fp})
            .aggregate(**train_stats)
            .to_pandas()
        )
        logger.info(pformat(pd_stats))
        pd_stats = pd_stats.fillna(False)

        # Number of true positive and false positive variants to be sampled for the training set
        n_tp = pd_stats[pd_stats["tp"] & ~pd_stats["fp"]]["n"].sum()
        n_fp = pd_stats[~pd_stats["tp"] & pd_stats["fp"]]["n"].sum()

        return n_tp, n_fp

    n_tp, n_fp = _get_train_counts(ht.filter(~ht._exclude))

    prob_tp = prob_fp = 1.0
    if fp_to_tp > 0:
        desired_fp = fp_to_tp * n_tp
        if desired_fp < n_fp:
            prob_fp = desired_fp / n_fp
        else:
            prob_tp = n_fp / desired_fp

        logger.info(
            f"Training examples sampling: tp={prob_tp}*{n_tp}, fp={prob_fp}*{n_fp}"
        )

        train_expr = (
            hl.case(missing_false=True)
            .when(ht._fp & hl.or_else(~ht._tp, True), hl.rand_bool(prob_fp))
            .when(ht._tp & hl.or_else(~ht._fp, True), hl.rand_bool(prob_tp))
            .default(False)
        )
    else:
        train_expr = ~(ht._tp & ht._fp)
        logger.info(f"Using all {n_tp} TP and {n_fp} FP training examples.")

    label_expr = (
        hl.case(missing_false=True)
        .when(ht._tp & hl.or_else(~ht._fp, True), "TP")
        .when(ht._fp & hl.or_else(~ht._tp, True), "FP")
        .default(hl.null(hl.tstr))
    )

    return ht.select(train=train_expr & ~ht._exclude, label=label_expr)
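# Worked example of the fp_to_tp sampling above: with n_tp = 1_000, n_fp = 10_000
# and fp_to_tp = 1.0, desired_fp = 1.0 * 1_000 = 1_000 < 10_000, so FPs are kept
# with prob_fp = 1_000 / 10_000 = 0.1 while all TPs are kept (prob_tp = 1.0).
# With fp_to_tp <= 0, both sets are used in full, minus variants that are both
# TP and FP.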
def generate_sib_stats_expr(
    mt: hl.MatrixTable,
    sib_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    strata: Dict[str, hl.expr.BooleanExpression] = {"raw": True},
    is_female: Optional[hl.expr.BooleanExpression] = None,
) -> hl.expr.StructExpression:
    """
    Generates a row-wise expression containing the number of alternate alleles in common between sibling pairs.

    The sibling sharing counts can be stratified with additional filters using `strata`.

    .. note::

        This function expects that the `mt` has either been split or filtered to only bi-allelics.
        If a sample has multiple sibling pairs, only one pair will be counted.

    :param mt: Input matrix table
    :param sib_ht: Table defining sibling pairs with one sample in a col (`i_col`) and the second in another col (`j_col`)
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param strata: Dict with additional strata to use when computing shared sibling variant counts
    :param is_female: An optional column in mt giving the sample sex. If not given, counts are only computed for autosomes.
    :return: A struct expression with the sibling shared variant counts
    """

    def _get_alt_count(locus, gt, is_female):
        """Calculate the alt allele count, taking sample sex into account when available."""
        if is_female is None:
            return hl.or_missing(locus.in_autosome(), gt.n_alt_alleles())
        return (
            hl.case()
            .when(locus.in_autosome_or_par(), gt.n_alt_alleles())
            .when(
                ~is_female & (locus.in_x_nonpar() | locus.in_y_nonpar()),
                hl.min(1, gt.n_alt_alleles()),
            )
            .when(is_female & locus.in_y_nonpar(), 0)
            .default(0)
        )

    if is_female is None:
        logger.warning(
            "Since no sex expression was given to generate_sib_stats_expr, only variants in autosomes will be counted."
        )

    # If a sample is in sib_ht more than one time, keep only one of the sibling pairs
    # First filter to only samples found in mt to keep as many pairs as possible
    s_to_keep = mt.aggregate_cols(hl.agg.collect_as_set(mt.s), _localize=False)
    sib_ht = sib_ht.filter(
        s_to_keep.contains(sib_ht[i_col].s) & s_to_keep.contains(sib_ht[j_col].s)
    )
    sib_ht = sib_ht.add_index("sib_idx")
    sib_ht = sib_ht.annotate(sibs=[sib_ht[i_col].s, sib_ht[j_col].s])
    sib_ht = sib_ht.explode("sibs")
    sib_ht = sib_ht.group_by("sibs").aggregate(
        sib_idx=(hl.agg.take(sib_ht.sib_idx, 1, ordering=sib_ht.sib_idx)[0])
    )
    sib_ht = sib_ht.group_by(sib_ht.sib_idx).aggregate(
        sibs=hl.agg.collect(sib_ht.sibs)
    )
    sib_ht = sib_ht.filter(hl.len(sib_ht.sibs) == 2).persist()

    logger.info(
        f"Generating sibling variant sharing counts using {sib_ht.count()} pairs."
    )
    sib_ht = sib_ht.explode("sibs").key_by("sibs")[mt.s]

    # Create sibling sharing counters
    sib_stats = hl.struct(
        **{
            f"n_sib_shared_variants_{name}": hl.sum(
                hl.agg.filter(
                    expr,
                    hl.agg.group_by(
                        sib_ht.sib_idx,
                        hl.or_missing(
                            hl.agg.sum(hl.is_defined(mt.GT)) == 2,
                            hl.agg.min(_get_alt_count(mt.locus, mt.GT, is_female)),
                        ),
                    ),
                ).values()
            )
            for name, expr in strata.items()
        }
    )

    sib_stats = sib_stats.annotate(
        **{
            f"ac_sibs_{name}": hl.agg.filter(
                expr & hl.is_defined(sib_ht.sib_idx),
                hl.agg.sum(mt.GT.n_alt_alleles()),
            )
            for name, expr in strata.items()
        }
    )

    return sib_stats
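# --- Hedged usage sketch (not part of the original module) ---
# Supplying a sex annotation so that X/Y non-PAR variants are counted with
# male genotypes capped at one alt allele. This assumes `mt` is bi-allelic
# and carries a boolean `is_female` column annotation, and that
# `sib_pairs_ht` follows the `i`/`j` struct layout described above; both
# are assumptions of this example.
def _example_sib_sharing_with_sex(
    mt: hl.MatrixTable, sib_pairs_ht: hl.Table
) -> hl.Table:
    """Illustrative only: sibling sharing counts including sex chromosomes."""
    return mt.select_rows(
        **generate_sib_stats_expr(
            mt,
            sib_pairs_ht,
            is_female=mt.is_female,
        )
    ).rows()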
def compute_related_samples_to_drop(
    relatedness_ht: hl.Table,
    rank_ht: hl.Table,
    kin_threshold: float,
    filtered_samples: Optional[hl.expr.SetExpression] = None,
    min_related_hard_filter: Optional[int] = None,
) -> hl.Table:
    """
    Computes a Table with the list of samples to drop (and their global rank) to get the maximal independent set of unrelated samples.

    .. note::

        - `relatedness_ht` should be keyed by exactly two fields of the same type, identifying the pair of samples for each row.
        - `rank_ht` should be keyed by a single key of the same type as a single sample identifier in `relatedness_ht`.

    :param relatedness_ht: Relatedness HT, as produced by e.g. pc-relate
    :param rank_ht: Table with a global rank for each sample (smaller is preferred)
    :param kin_threshold: Kinship threshold above which two samples are considered related
    :param filtered_samples: An optional set of samples to exclude (e.g. samples that were hard-filtered). These samples will then appear in the resulting samples to drop.
    :param min_related_hard_filter: If provided, any sample related to more samples than this parameter will be filtered prior to computing the maximal independent set and will appear in the results.
    :return: A Table with the list of the samples to drop along with their rank.
    """
    # Make sure that the key types are valid
    assert len(list(relatedness_ht.key)) == 2
    assert relatedness_ht.key[0].dtype == relatedness_ht.key[1].dtype
    assert len(list(rank_ht.key)) == 1
    assert relatedness_ht.key[0].dtype == rank_ht.key[0].dtype

    logger.info(
        f"Filtering related samples using a kin threshold of {kin_threshold}")
    relatedness_ht = relatedness_ht.filter(relatedness_ht.kin > kin_threshold)

    filtered_samples_rel = set()
    if min_related_hard_filter is not None:
        logger.info(
            f"Computing samples related to too many individuals (>{min_related_hard_filter}) for exclusion"
        )
        gbi = relatedness_ht.annotate(s=list(relatedness_ht.key))
        gbi = gbi.explode(gbi.s)
        gbi = gbi.group_by(gbi.s).aggregate(n=hl.agg.count())
        filtered_samples_rel = gbi.aggregate(
            hl.agg.filter(gbi.n > min_related_hard_filter,
                          hl.agg.collect_as_set(gbi.s)))
        logger.info(
            f"Found {len(filtered_samples_rel)} samples with too many 1st/2nd degree relatives. These samples will be excluded."
) if filtered_samples is not None: filtered_samples_rel = filtered_samples_rel.union( relatedness_ht.aggregate( hl.agg.explode( lambda s: hl.agg.collect_as_set(s), hl.array(list(relatedness_ht.key)).filter( lambda s: filtered_samples.contains(s)), ))) if len(filtered_samples_rel) > 0: filtered_samples_lit = hl.literal(filtered_samples_rel) relatedness_ht = relatedness_ht.filter( filtered_samples_lit.contains(relatedness_ht.key[0]) | filtered_samples_lit.contains(relatedness_ht.key[1]), keep=False, ) logger.info("Annotating related sample pairs with rank.") i, j = list(relatedness_ht.key) relatedness_ht = relatedness_ht.key_by(s=relatedness_ht[i]) relatedness_ht = relatedness_ht.annotate(**{ i: hl.struct(s=relatedness_ht.s, rank=rank_ht[relatedness_ht.key].rank) }) relatedness_ht = relatedness_ht.key_by(s=relatedness_ht[j]) relatedness_ht = relatedness_ht.annotate(**{ j: hl.struct(s=relatedness_ht.s, rank=rank_ht[relatedness_ht.key].rank) }) relatedness_ht = relatedness_ht.key_by(i, j) relatedness_ht = relatedness_ht.drop("s") relatedness_ht = relatedness_ht.persist() related_samples_to_drop_ht = hl.maximal_independent_set( relatedness_ht[i], relatedness_ht[j], keep=False, tie_breaker=lambda l, r: l.rank - r.rank, ) related_samples_to_drop_ht = related_samples_to_drop_ht.key_by() related_samples_to_drop_ht = related_samples_to_drop_ht.select( **related_samples_to_drop_ht.node) related_samples_to_drop_ht = related_samples_to_drop_ht.key_by("s") if len(filtered_samples_rel) > 0: related_samples_to_drop_ht = related_samples_to_drop_ht.union( hl.Table.parallelize( [ hl.struct(s=s, rank=hl.null(hl.tint64)) for s in filtered_samples_rel ], key="s", )) return related_samples_to_drop_ht
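# --- Hedged usage sketch (not part of the original module) ---
# Ranking samples by mean depth (higher depth ranks better) and dropping one
# sample from each related pair at a roughly 2nd-degree kinship cutoff
# (2^-4.5 ~= 0.088, a commonly used pc-relate threshold). The field names
# `s` and `chr20_mean_dp` on `meta_ht` are assumptions of this example.
def _example_related_drop(
    relatedness_ht: hl.Table, meta_ht: hl.Table
) -> hl.Table:
    """Illustrative only: build a depth-based rank table and drop related samples."""
    rank_ht = meta_ht.order_by(hl.desc(meta_ht.chr20_mean_dp)).add_index("rank")
    rank_ht = rank_ht.key_by("s").select("rank")
    return compute_related_samples_to_drop(
        relatedness_ht,
        rank_ht,
        kin_threshold=0.088,
    )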