Example #1
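These snippets are excerpts from a larger Hail/gnomAD codebase, so they rely on shared imports and module-level names (constants such as DUPLICATE_OR_TWINS, PARENT_CHILD, SIBLINGS, UNRELATED, SUBPOPS, and INBREEDING_COEFF_HARD_CUTOFF; path globals like temp_dir/tmp_dir; and helpers like generic_field_check or get_median_and_mad_expr). A minimal, assumed preamble for running them would look like:

# Assumed shared preamble for the snippets below (not part of the original source);
# module constants and helper functions come from the surrounding codebase.
import logging
from collections import defaultdict
from pprint import pformat
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

import hail as hl

logger = logging.getLogger(__name__)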
def compute_stratified_metrics_filter(ht: hl.Table,
                                      qc_metrics: List[str],
                                      strata: Optional[List[str]] = None) -> hl.Table:
    """
    Compute median, MAD, and upper and lower thresholds for each metric used in pop- and platform-specific outlier filtering

    :param Table ht: Table containing relevant sample QC metric annotations
    :param list qc_metrics: list of metrics for which to compute the critical values for filtering outliers
    :param list of str strata: List of annotations used for stratification. These annotations should be discrete types!
    :return: Table annotated with per-metric fail flags and a `pop_platform_filters` set per sample, with upper and lower threshold values stored in the `metrics_stats` global
    :rtype: Table
    """
    def make_pop_filters_expr(ht: hl.Table,
                              qc_metrics: List[str]) -> hl.expr.SetExpression:
        return hl.set(
            hl.filter(lambda x: hl.is_defined(x), [
                hl.or_missing(ht[f'fail_{metric}'], metric)
                for metric in qc_metrics
            ]))

    if strata is None:
        strata = []

    ht = ht.select(*strata,
                   **ht.sample_qc.select(*qc_metrics)).key_by('s').persist()

    def get_metric_expr(ht, metric):
        metric_values = hl.agg.collect(ht[metric])
        metric_median = hl.median(metric_values)
        metric_mad = 1.4826 * hl.median(hl.abs(metric_values - metric_median))
        return hl.struct(median=metric_median,
                         mad=metric_mad,
                         upper=(metric_median + 4 * metric_mad
                                if metric != 'callrate' else 1),
                         lower=(metric_median - 4 * metric_mad
                                if metric != 'callrate' else 0.99))

    agg_expr = hl.struct(
        **{metric: get_metric_expr(ht, metric)
           for metric in qc_metrics})
    if strata:
        ht = ht.annotate_globals(metrics_stats=ht.aggregate(
            hl.agg.group_by(hl.tuple([ht[x] for x in strata]), agg_expr)))
    else:
        ht = ht.annotate_globals(metrics_stats={(): ht.aggregate(agg_expr)})

    strata_exp = hl.tuple([ht[x] for x in strata]) if strata else hl.tuple([])

    fail_exprs = {
        f'fail_{metric}':
        (ht[metric] >= ht.metrics_stats[strata_exp][metric].upper) |
        (ht[metric] <= ht.metrics_stats[strata_exp][metric].lower)
        for metric in qc_metrics
    }
    ht = ht.transmute(**fail_exprs)
    pop_platform_filters = make_pop_filters_expr(ht, qc_metrics)
    return ht.annotate(pop_platform_filters=pop_platform_filters)
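A brief usage sketch; the input Table and strata column names (`sample_qc_ht`, `qc_pop`, `qc_platform`) are illustrative assumptions, not from the source:

# Hypothetical call: flag per-population/platform outliers on two QC metrics.
pop_platform_ht = compute_stratified_metrics_filter(
    sample_qc_ht,                      # assumed Table with a `sample_qc` struct and key 's'
    qc_metrics=['n_snp', 'r_ti_tv'],   # metric names inside `sample_qc`
    strata=['qc_pop', 'qc_platform'],  # assumed discrete stratification columns
)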
def generic_field_check_loop(
    ht: hl.Table,
    field_check_expr: Dict[str, Dict[str, Union[hl.expr.Int64Expression,
                                                hl.expr.StructExpression]]],
    verbose: bool,
    show_percent_sites: bool = False,
    ht_count: Optional[int] = None,
) -> None:
    """
    Loop through all conditional checks for a given hail Table.

    This loop allows aggregation across the hail Table once, as opposed to aggregating during every conditional check.

    :param ht: Table containing annotations to be checked.
    :param field_check_expr: Dictionary whose keys are conditions being checked and values are the expressions for filtering to condition.
    :param verbose: If True, show top values of annotations being checked, including checks that pass; if False, show only top values of annotations that fail checks.
    :param show_percent_sites: Show percentage of sites that fail checks. Default is False.
    :param ht_count: Previously computed sum of sites within hail Table. Default is None.
    :return: None
    """
    ht_field_check_counts = ht.aggregate(
        hl.struct(**{k: v["expr"]
                     for k, v in field_check_expr.items()}))
    for check_description, n_fail in ht_field_check_counts.items():
        generic_field_check(
            ht,
            check_description=check_description,
            n_fail=n_fail,
            display_fields=field_check_expr[check_description]
            ["display_fields"],
            verbose=verbose,
            show_percent_sites=show_percent_sites,
            ht_count=ht_count,
        )
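A sketch of the expected `field_check_expr` shape, assuming a site Table `ht` with an `info` struct (field names here are illustrative):

# Each entry maps a check description to a failure-count aggregation ("expr")
# and the annotations to display for failing rows ("display_fields").
field_check_expr = {
    "AC is 0 when AF > 0": {
        "expr": hl.agg.count_where((ht.info.AF > 0) & (ht.info.AC == 0)),
        "display_fields": hl.struct(AC=ht.info.AC, AF=ht.info.AF),
    },
}
generic_field_check_loop(ht, field_check_expr, verbose=True)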
Example #3
def get_duplicated_samples(
    relationship_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    rel_col: str = "relationship",
) -> List[Set[str]]:
    """
    Extract the list of duplicate samples using a Table output from pc_relate.

    :param relationship_ht: Table with relationships between pairs of samples
    :param i_col: Column containing the 1st sample
    :param j_col: Column containing the 2nd sample
    :param rel_col: Column containing the sample pair relationship annotated with get_relationship_expr
    :return: List of sets of samples that are duplicates
    """
    def get_all_dups(
        s: str, dups: Set[str], samples_duplicates: Dict[str, Set[str]]
    ) -> Tuple[Set[str], Dict[str, Set[str]]]:
        """
        Create the set of all duplicated samples corresponding to `s` that are found in `samples_duplicates`.

        Also return the remaining sample duplicates after removing all duplicated samples corresponding to `s`.

        Works by recursively adding duplicated samples to the set.

        :param s: sample to identify duplicates for
        :param dups: set of corresponding samples already identified
        :param samples_duplicates: dict of sample -> duplicate-pair left to assign
        :return: (set of duplicates corresponding to s found in samples_duplicates, remaining samples_duplicates)
        """
        if s in samples_duplicates:
            dups.add(s)
            s_dups = samples_duplicates.pop(s)
            for s_dup in s_dups:
                if s_dup not in dups:
                    dups, samples_duplicates = get_all_dups(
                        s_dup, dups, samples_duplicates)
        return dups, samples_duplicates

    logger.info("Computing duplicate sets")
    dup_pairs = relationship_ht.aggregate(
        hl.agg.filter(
            relationship_ht[rel_col] == DUPLICATE_OR_TWINS,
            hl.agg.collect(
                hl.tuple([relationship_ht[i_col], relationship_ht[j_col]])),
        ))

    samples_duplicates = defaultdict(set)
    for i, j in dup_pairs:
        samples_duplicates[i].add(j)
        samples_duplicates[j].add(i)

    duplicated_samples = []
    while len(samples_duplicates) > 0:
        dup_set, samples_duplicates = get_all_dups(
            list(samples_duplicates)[0], set(), samples_duplicates)
        duplicated_samples.append(dup_set)

    return duplicated_samples
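A minimal sketch on a toy relationship Table (`DUPLICATE_OR_TWINS` is the module constant used above):

# Pairs (a, b) and (b, c) collapse into a single duplicate set {a, b, c}.
rel_ht = hl.Table.parallelize([
    hl.struct(i="a", j="b", relationship=DUPLICATE_OR_TWINS),
    hl.struct(i="b", j="c", relationship=DUPLICATE_OR_TWINS),
])
print(get_duplicated_samples(rel_ht))  # [{'a', 'b', 'c'}]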
Example #4
def collapse_small_pops(ht: hl.Table, min_pop_size: int) -> hl.Table:
    """

    Collapses (sub)populations that are too small for release into others.
    When collapsing subpops, the name for the other category is composed of "o" +  2 first letters of the superpop
    The original RF population assignments are kept in the `rf_pop` and `rf_subpop` columns.

    :param ht: Input Table
    :return: Table with small populations collapsed
    :rtype: Table
    """
    def get_subpop_oth(pop: str):
        for superpop, subpops in SUBPOPS.items():
            if pop.upper() in subpops:
                return "o" + superpop[:2].lower()

        raise ValueError(
            f"Subpopulation {pop} not found in possible subpopulations.")

    ht = ht.persist()
    pop_counts = ht.aggregate(hl.agg.filter(ht.release,
                                            hl.agg.counter(ht.pop)))
    pop_collapse = {
        pop: "oth"
        for pop, n in pop_counts.items() if n < min_pop_size
    }
    pop_collapse = hl.literal(pop_collapse) if pop_collapse else hl.empty_dict(
        hl.tstr, hl.tstr)

    subpop_counts = ht.aggregate(
        hl.agg.filter(ht.release, hl.agg.counter(ht.subpop)))
    subpop_collapse = {
        subpop: get_subpop_oth(subpop)
        for subpop, n in subpop_counts.items() if n < min_pop_size
    }
    subpop_collapse = hl.literal(
        subpop_collapse) if subpop_collapse else hl.empty_dict(
            hl.tstr, hl.tstr)

    return ht.annotate(pop=pop_collapse.get(ht.pop, ht.pop),
                       subpop=subpop_collapse.get(ht.subpop, ht.subpop),
                       rf_pop=ht.pop,
                       rf_subpop=ht.subpop)
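The `SUBPOPS` constant is assumed to map each superpopulation to its subpopulation codes; an illustrative shape (the real mapping lives in the surrounding codebase):

# Illustrative shape only, not the actual gnomAD mapping.
SUBPOPS = {
    "NFE": ["BGR", "EST", "NWE", "SEU", "SWE"],
    "EAS": ["KOR", "JPN"],
}
# e.g. a too-small "bgr" subpop would collapse to "onf" ("o" + "NF".lower())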
Example #5
def check_mismatch(ht: hl.Table) -> hl.Struct:
    """
    Check for mismatches between the reference allele and the allele in the reference fasta.

    :param ht: Table to be checked
    :return: Struct containing mismatch counts and counts of variants on the negative strand
    """

    mismatch = ht.aggregate(
        hl.struct(
            total_variants=hl.agg.count(),
            total_mismatch=hl.agg.count_where(ht.reference_mismatch),
            negative_strand=hl.agg.count_where(
                ht.new_locus.is_negative_strand),
            negative_strand_mismatch=hl.agg.count_where(
                ht.new_locus.is_negative_strand & ht.reference_mismatch)))
    return mismatch
Example #6
def filter_ht_for_plink(ht: hl.Table,
                        n_samples: int,
                        min_call_rate: float = 0.95,
                        variants_per_mac_category: int = 2000,
                        variants_per_maf_category: int = 10000) -> hl.Table:
    """
    Downsample an autosomal variant Table for PLINK export, targeting a roughly
    fixed number of variants per MAC/MAF category.

    :param ht: Input Table with a `call_stats` annotation
    :param n_samples: Total number of samples in the dataset
    :param min_call_rate: Minimum call rate for a variant to be kept
    :param variants_per_mac_category: Target number of variants per MAC category
    :param variants_per_maf_category: Target number of variants per MAF category
    :return: Filtered Table
    """
    from gnomad.utils.filtering import filter_to_autosomes
    ht = filter_to_autosomes(ht)
    # Keep non-monomorphic variants that pass the call rate threshold
    ht = ht.filter((ht.call_stats.AN >= n_samples * 2 * min_call_rate)
                   & (ht.call_stats.AC > 0))
    ht = ht.annotate(mac_category=mac_category_case_builder(ht.call_stats))
    category_counter = ht.aggregate(hl.agg.counter(ht.mac_category))
    print(category_counter)
    ht = ht.annotate_globals(category_counter=category_counter)
    # Keep each variant with probability target / category_count, so each category
    # ends up with roughly the target number of variants; categories encoded as
    # MAC values (>= 1) use the per-MAC target, MAF categories the per-MAF target
    return ht.filter(
        hl.rand_unif(0, 1) <
        hl.cond(ht.mac_category >= 1, variants_per_mac_category,
                variants_per_maf_category) /
        ht.category_counter[ht.mac_category])
def run_sanity_checks(ht: hl.Table) -> None:
    """
    Runs and prints sanity checks on rank table.

    :param Table ht: input ranks Table
    :return: Nothing
    :rtype: None
    """
    print(
        ht.aggregate(
            hl.struct(was_split=hl.agg.counter(ht.was_split),
                      has_biallelic_rank=hl.agg.counter(
                          hl.is_defined(ht.biallelic_rank)),
                      was_singleton=hl.agg.counter(ht._singleton),
                      has_singleton_rank=hl.agg.counter(
                          hl.is_defined(ht.singleton_rank)),
                      was_split_singleton=hl.agg.counter(ht._singleton
                                                         & ~ht.was_split),
                      has_biallelic_singleton_rank=hl.agg.counter(
                          hl.is_defined(ht.biallelic_singleton_rank)))))
def create_grouped_bin_ht(ht: hl.Table,
                          model_id: str,
                          overwrite: bool = False) -> hl.Table:
    """
    Create binned data from a quantile bin annotated Table grouped by bin_id (rank, bi-allelic, etc.), contig, snv,
    bi_allelic and singleton, containing the information needed for evaluation plots.

    :param Table ht: Input quantile bin annotated Table
    :param str model_id: Which data/run hash is being created
    :param bool overwrite: Should output files be overwritten if present
    :return: Binned Table
    :rtype: Table
    """
    trio_stats_ht = hl.read_table(
        f'{temp_dir}/ddd-elgh-ukbb/variant_qc/Sanger_cohorts_trios_stats.ht')
    # Count variants for ranking
    count_expr = {
        x: hl.agg.filter(
            hl.is_defined(ht[x]),
            hl.agg.counter(
                hl.cond(hl.is_snp(ht.alleles[0], ht.alleles[1]), "snv",
                        "indel")),
        )
        for x in ht.row if x.endswith("bin")
    }
    bin_variant_counts = ht.aggregate(hl.struct(**count_expr))
    logger.info(
        f"Found the following variant counts:\n {pformat(bin_variant_counts)}")
    ht = ht.annotate_globals(bin_variant_counts=bin_variant_counts)

    logger.info(f"Creating grouped bin table...")
    grouped_binned_ht = compute_grouped_binned_ht(
        ht,
        checkpoint_path=(f'{tmp_dir}/ddd-elgh-ukbb/{model_id}_grouped_bin.ht'),
    )

    logger.info(f"Aggregating grouped bin table...")
    agg_ht = grouped_binned_ht.aggregate(
        **score_bin_agg(grouped_binned_ht, fam_stats_ht=trio_stats_ht))

    return agg_ht
Example #9
    def _get_agg_struct(ht: hl.Table) -> hl.expr.StructExpression:
        """
        Aggregate input Table and return StructExpression describing doubleton pairs.

        Return count of pairs present in relatedness HT, kinship distribution stats, and
        dictionary counting relationship types.

        Assumes Table is annotated with:
            - `rel_def`: Boolean for whether pair was present in relatedness Table.
            - `kin`: Kinship value for sample pair.
            - `relationship`: Relationship of sample pair (if found in relatedness Table).

        :param hl.Table ht: Input Table.
        :return: StructExpression describing doubleton pairs.
        """
        return ht.aggregate(
            hl.struct(
                pair_in_relatedness_ht=hl.agg.count_where(ht.rel_def),
                kin_stats=hl.agg.stats(ht.kin),
                rel_counter=hl.agg.counter(ht.relationship),
                total_pairs=hl.agg.count(),
            ))
def create_binned_data_initial(ht: hl.Table, data: str, data_type: str, n_bins: int) -> hl.Table:
    """
    Create binned data from a rank Table grouped by rank_id, contig, snv, bi_allelic, singleton
    and de novo categories, containing the information needed for evaluation plots.

    :param Table ht: Input rank Table
    :param str data: Which data/run hash is being created
    :param str data_type: One of 'exomes' or 'genomes'
    :param int n_bins: Number of bins
    :return: Binned Table
    :rtype: Table
    """
    # Count variants for ranking
    count_expr = {x: hl.agg.filter(hl.is_defined(ht[x]), hl.agg.counter(hl.cond(hl.is_snp(
        ht.alleles[0], ht.alleles[1]), 'snv', 'indel'))) for x in ht.row if x.endswith('rank')}
    rank_variant_counts = ht.aggregate(hl.Struct(**count_expr))
    logger.info(
        f"Found the following variant counts:\n {pformat(rank_variant_counts)}")
    ht_truth_data = hl.read_table(
        f"{temp_dir}/ddd-elgh-ukbb/variant_qc/truthset_table.ht")
    ht = ht.annotate_globals(rank_variant_counts=rank_variant_counts)
    ht = ht.annotate(
        **ht_truth_data[ht.key],
        # **fam_ht[ht.key],
        # **gnomad_ht[ht.key],
        # **denovo_ht[ht.key],
        # clinvar=hl.is_defined(clinvar_ht[ht.key]),
        indel_length=hl.abs(ht.alleles[0].length()-ht.alleles[1].length()),
        rank_bins=hl.array(
            [hl.Struct(
                rank_id=rank_name,
                bin=hl.int(hl.ceil(hl.float(ht[rank_name] + 1) / hl.floor(ht.globals.rank_variant_counts[rank_name][hl.cond(
                    hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv', 'indel')] / n_bins)))
            )
                for rank_name in rank_variant_counts]
        ),
        # lcr=hl.is_defined(lcr_intervals[ht.locus])
    )

    ht = ht.explode(ht.rank_bins)
    ht = ht.transmute(
        rank_id=ht.rank_bins.rank_id,
        bin=ht.rank_bins.bin
    )
    ht = ht.filter(hl.is_defined(ht.bin))

    ht = ht.checkpoint(
        f'{tmp_dir}/gnomad_score_binning_tmp.ht', overwrite=True)

    # Create binned data
    return (
        ht
        .group_by(
            rank_id=ht.rank_id,
            contig=ht.locus.contig,
            snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
            bi_allelic=hl.is_defined(ht.biallelic_rank),
            singleton=ht.transmitted_singleton,
            trans_singletons=hl.is_defined(ht.singleton_rank),
            de_novo_high_quality=ht.de_novo_high_quality_rank,
            de_novo_medium_quality=hl.is_defined(
                ht.de_novo_medium_quality_rank),
            de_novo_synonymous=hl.is_defined(ht.de_novo_synonymous_rank),
            # release_adj=ht.ac > 0,
            bin=ht.bin
        )._set_buffer_size(20000)
        .aggregate(
            min_score=hl.agg.min(ht.score),
            max_score=hl.agg.max(ht.score),
            n=hl.agg.count(),
            n_ins=hl.agg.count_where(
                hl.is_insertion(ht.alleles[0], ht.alleles[1])),
            n_del=hl.agg.count_where(
                hl.is_deletion(ht.alleles[0], ht.alleles[1])),
            n_ti=hl.agg.count_where(hl.is_transition(
                ht.alleles[0], ht.alleles[1])),
            n_tv=hl.agg.count_where(hl.is_transversion(
                ht.alleles[0], ht.alleles[1])),
            n_1bp_indel=hl.agg.count_where(ht.indel_length == 1),
            n_mod3bp_indel=hl.agg.count_where((ht.indel_length % 3) == 0),
            # n_clinvar=hl.agg.count_where(ht.clinvar),
            n_singleton=hl.agg.count_where(ht.transmitted_singleton),
            n_high_quality_de_novos=hl.agg.count_where(
                ht.de_novo_data.p_de_novo[0] > 0.99),
            n_validated_DDD_denovos=hl.agg.count_where(
                ht.inheritance.contains("De novo")),
            n_medium_quality_de_novos=hl.agg.count_where(
                ht.de_novo_data.p_de_novo[0] > 0.5),
            n_high_confidence_de_novos=hl.agg.count_where(
                ht.de_novo_data.confidence[0] == 'HIGH'),
            n_de_novo=hl.agg.filter(ht.family_stats.unrelated_qc_callstats.AC[0][1] == 0, hl.agg.sum(
                ht.family_stats.mendel[0].errors)),
            n_high_quality_de_novos_synonymous=hl.agg.count_where(
                (ht.de_novo_data.p_de_novo[0] > 0.99) & (ht.consequence == "synonymous_variant")),
            # n_de_novo_no_lcr=hl.agg.filter(~ht.lcr & (
            #    ht.family_stats.unrelated_qc_callstats.AC[1] == 0), hl.agg.sum(ht.family_stats.mendel.errors)),
            n_de_novo_sites=hl.agg.filter(ht.family_stats.unrelated_qc_callstats.AC[0][1] == 0, hl.agg.count_where(
                ht.family_stats.mendel[0].errors > 0)),
            # n_de_novo_sites_no_lcr=hl.agg.filter(~ht.lcr & (
            #    ht.family_stats.unrelated_qc_callstats.AC[1] == 0), hl.agg.count_where(ht.family_stats.mendel.errors > 0)),
            n_trans_singletons=hl.agg.filter((ht.ac_raw < 3) & (
                ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1), hl.agg.sum(ht.family_stats.tdt[0].t)),
            n_trans_singletons_synonymous=hl.agg.filter((ht.ac_raw < 3) & (ht.consequence == "synonymous_variant") & (
                ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1), hl.agg.sum(ht.family_stats.tdt[0].t)),
            n_untrans_singletons=hl.agg.filter((ht.ac_raw < 3) & (
                ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1), hl.agg.sum(ht.family_stats.tdt[0].u)),
            n_untrans_singletons_synonymous=hl.agg.filter((ht.ac_raw < 3) & (ht.consequence == "synonymous_variant") & (
                ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1), hl.agg.sum(ht.family_stats.tdt[0].u)),
            n_train_trans_singletons=hl.agg.count_where(
                (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1) & (ht.family_stats.tdt[0].t == 1)),
            n_omni=hl.agg.count_where(ht.omni),
            n_mills=hl.agg.count_where(ht.mills),
            n_hapmap=hl.agg.count_where(ht.hapmap),
            n_kgp_high_conf_snvs=hl.agg.count_where(
                ht.kgp_phase1_hc),
            fail_hard_filters=hl.agg.count_where(ht.fail_hard_filters),
            # n_vqsr_pos_train=hl.agg.count_where(ht.vqsr_positive_train_site),
            # n_vqsr_neg_train=hl.agg.count_where(ht.vqsr_negative_train_site)
        )
    )
Example #11
def infer_families(
    relationship_ht: hl.Table,
    sex: Union[hl.Table, Dict[str, bool]],
    duplicate_samples_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    relationship_col: str = "relationship",
) -> hl.Pedigree:
    """
    This function takes a hail Table with a row for each pair of individuals i,j in the data that are related (it's OK to have unrelated samples too).
    The `relationship_col` should be a column specifying the relationship between the two samples, as defined in this module's constants.

    This function returns a pedigree containing trios inferred from the data. Family ID can be the same for multiple
    trios if one or more members of the trios are related (e.g. sibs, multi-generational family). Trios are ordered by family ID.

    .. note::

        This function only returns complete trios defined as: one child, one father and one mother (sex is required for both parents).

    :param relationship_ht: Input relationship table
    :param sex: A Table or dict giving the sex for each sample (`TRUE`=female, `FALSE`=male). If a Table is given, it should have a field `is_female`.
    :param duplicate_samples_ht: Table of all duplicated samples TO REMOVE (required, as this function assumes that each child has exactly two parents)
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param relationship_col: Column containing the relationship for the sample pair as defined in this module's constants.
    :return: Pedigree of complete trios
    """
    def group_parent_child_pairs_by_fam(
        parent_child_pairs: Iterable[Tuple[str, str]]
    ) -> List[List[Tuple[str, str]]]:
        """
        Take all parent-child pairs and group them by family.

        A family here is defined as a list of sample pairs which all share at least one sample with at least one other sample pair in the list.

        :param parent_child_pairs: All the parent-child pairs
        :return: A list of families, where each element of the list is a list of the parent-children pairs
        """
        fam_id = 1  # stores the current family id
        s_fam = dict()  # stores the family id for each sample
        fams = defaultdict(list)  # stores fam_id -> sample-pairs
        for pair in parent_child_pairs:
            if pair[0] in s_fam:
                if pair[1] in s_fam:
                    if (
                            s_fam[pair[0]] != s_fam[pair[1]]
                    ):  # If both samples are in different families, merge the families
                        new_fam_id = s_fam[pair[0]]
                        fam_id_to_merge = s_fam[pair[1]]
                        for s in s_fam:
                            if s_fam[s] == fam_id_to_merge:
                                s_fam[s] = new_fam_id
                        fams[new_fam_id].extend(fams.pop(fam_id_to_merge))
                else:  # If only the 1st sample in the pair is already in a family, assign the 2nd sample in the pair to the same family
                    s_fam[pair[1]] = s_fam[pair[0]]
                fams[s_fam[pair[0]]].append(pair)
            elif (
                    pair[1] in s_fam
            ):  # If only the 2nd sample in the pair is already in a family, assign the 1st sample in the pair to the same family
                s_fam[pair[0]] = s_fam[pair[1]]
                fams[s_fam[pair[1]]].append(pair)
            else:  # If none of the samples in the pair is already in a family, create a new family
                s_fam[pair[0]] = fam_id
                s_fam[pair[1]] = fam_id
                fams[fam_id].append(pair)
                fam_id += 1

        return list(fams.values())

    def get_trios(
        fam_id: str,
        parent_child_pairs: List[Tuple[str, str]],
        related_pairs: Dict[Tuple[str, str], str],
    ) -> List[hl.Trio]:
        """
        Generate trios from the list of parent-child pairs in the family and
        all related pairs in the data. Only complete parent/offspring trios are included in the results.

        The trios are assembled as follows:
        1. All pairs of unrelated samples with different sexes within the family are extracted as possible parent pairs
        2. For each possible parent pair, a list of all children is constructed (each child in the list has a parent-offspring pair with each parent)
        3. If there are multiple children for a given parent pair, all children should be siblings with each other
        4. Check that each child was only assigned a single pair of parents. If a child is found to have multiple parent pairs, they are ALL discarded.

        :param fam_id: The family ID
        :param parent_child_pairs: The parent-child pairs for this family
        :param related_pairs: All related sample pairs in the data
        :return: List of trios in the family
        """
        def get_possible_parents(samples: List[str]) -> List[Tuple[str, str]]:
            """
            1. All pairs of unrelated samples with different sexes within the family are extracted as possible parent pairs

            :param samples: All samples in the family
            :return: Possible parent pairs
            """
            possible_parents = []
            for i in range(len(samples)):
                for j in range(i + 1, len(samples)):
                    if (related_pairs.get(
                            tuple(sorted([samples[i], samples[j]]))) is None):
                        if sex.get(samples[i]) is False and sex.get(
                                samples[j]) is True:
                            possible_parents.append((samples[i], samples[j]))
                        elif (sex.get(samples[i]) is True
                              and sex.get(samples[j]) is False):
                            possible_parents.append((samples[j], samples[i]))
            return possible_parents

        def get_children(possible_parents: Tuple[str, str]) -> List[str]:
            """
            2. For a given possible parent pair, a list of all children is constructed (each child in the list has a parent-offspring pair with each parent)

            :param possible_parents: A pair of possible parents
            :return: The list of all children (if any) corresponding to the possible parents
            """
            possible_offsprings = defaultdict(
                set
            )  # stores sample -> set of parents in possible_parents where (sample, parent) is found in parent_child_pairs
            for pair in parent_child_pairs:
                if possible_parents[0] == pair[0]:
                    possible_offsprings[pair[1]].add(possible_parents[0])
                elif possible_parents[0] == pair[1]:
                    possible_offsprings[pair[0]].add(possible_parents[0])
                elif possible_parents[1] == pair[0]:
                    possible_offsprings[pair[1]].add(possible_parents[1])
                elif possible_parents[1] == pair[1]:
                    possible_offsprings[pair[0]].add(possible_parents[1])

            return [
                s for s, parents in possible_offsprings.items()
                if len(parents) == 2
            ]

        def check_sibs(children: List[str]) -> bool:
            """
            3. If there are multiple children for a given parent pair, all children should be siblings with each other

            :param children: List of all children for a given parent pair
            :return: Whether all children in the list are siblings
            """
            for i in range(len(children)):
                for j in range(i + 1, len(children)):
                    if (related_pairs[tuple(sorted([children[i], children[j]
                                                    ]))] != SIBLINGS):
                        return False
            return True

        def discard_multi_parents_children(trios: List[hl.Trio]):
            """
            4. Check that each child was only assigned a single pair of parents. If a child is found to have multiple parent pairs, they are ALL discarded.

            :param trios: All trios formed for this family
            :return: The list of trios for which each child has a single pair of parents.
            """
            children_trios = defaultdict(list)
            for trio in trios:
                children_trios[trio.s].append(trio)

            for s, s_trios in children_trios.items():
                if len(s_trios) > 1:
                    logger.warning(
                        "Discarded duplicated child {0} found in multiple trios: {1}"
                        .format(s, ", ".join([str(trio) for trio in s_trios])))

            return [
                trios[0] for trios in children_trios.values()
                if len(trios) == 1
            ]

        # Get all possible pairs of parents in (father, mother) order
        all_possible_parents = get_possible_parents(
            list({s
                  for pair in parent_child_pairs for s in pair}))

        trios = []
        for possible_parents in all_possible_parents:
            children = get_children(possible_parents)
            if check_sibs(children):
                trios.extend([
                    hl.Trio(
                        s=s,
                        fam_id=fam_id,
                        pat_id=possible_parents[0],
                        mat_id=possible_parents[1],
                        is_female=sex.get(s),
                    ) for s in children
                ])
            else:
                logger.warning(
                    "Discarded family with same parents, and multiple offspring that weren't siblings:"
                    "\nFather: {}\nMother: {}\nChildren: {}".format(
                        possible_parents[0], possible_parents[1],
                        ", ".join(children)))

        return discard_multi_parents_children(trios)

    # Get all the relations we care about:
    # => Remove unrelateds and duplicates
    dups = duplicate_samples_ht.aggregate(
        hl.agg.explode(lambda dup: hl.agg.collect_as_set(dup),
                       duplicate_samples_ht.filtered),
        _localize=False,
    )
    relationship_ht = relationship_ht.filter(
        ~dups.contains(relationship_ht[i_col])
        & ~dups.contains(relationship_ht[j_col])
        & (relationship_ht[relationship_col] != UNRELATED))

    # Check relatedness table format
    if not relationship_ht[i_col].dtype == relationship_ht[j_col].dtype:
        logger.error(
            "i_col and j_col of the relatedness table need to be of the same type."
        )

    # If i_col and j_col aren't str, then convert them
    if not isinstance(relationship_ht[i_col], hl.expr.StringExpression):
        logger.warning(
            f"Pedigrees can only be constructed from string IDs, but your relatedness_ht ID column is of type: {relationship_ht[i_col].dtype}. Expression will be converted to string in Pedigrees."
        )
        if isinstance(relationship_ht[i_col], hl.expr.StructExpression):
            logger.warning(
                f"Struct fields {list(relationship_ht[i_col])} will be joined by underscores to use as sample names in Pedigree."
            )
            relationship_ht = relationship_ht.key_by(
                **{
                    i_col:
                    hl.delimit(
                        hl.array([
                            hl.str(relationship_ht[i_col][x])
                            for x in relationship_ht[i_col]
                        ]),
                        "_",
                    ),
                    j_col:
                    hl.delimit(
                        hl.array([
                            hl.str(relationship_ht[j_col][x])
                            for x in relationship_ht[j_col]
                        ]),
                        "_",
                    ),
                })
        else:
            raise NotImplementedError(
                "The `i_col` and `j_col` columns of the `relationship_ht` argument passed to infer_families are not of type StringExpression or Struct."
            )

    # If sex is a Table, extract sex information as a Dict
    if isinstance(sex, hl.Table):
        sex = dict(hl.tuple([sex.s, sex.is_female]).collect())

    # Collect all related sample pairs and
    # create a dictionary with pairs as keys and relationships as values
    # Sample-pairs are tuples ordered by sample name
    related_pairs = {
        tuple(sorted([i, j])): rel
        for i, j, rel in hl.tuple([
            relationship_ht.i, relationship_ht.j, relationship_ht.relationship
        ]).collect()
    }

    parent_child_pairs_by_fam = group_parent_child_pairs_by_fam(
        [pair for pair, rel in related_pairs.items() if rel == PARENT_CHILD])
    return hl.Pedigree([
        trio for fam_index, parent_child_pairs in enumerate(
            parent_child_pairs_by_fam) for trio in get_trios(
                str(fam_index), parent_child_pairs, related_pairs)
    ])
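A hedged usage sketch; the three input Tables are assumed outputs of upstream relatedness, sex-inference, and duplicate-detection steps:

# Build a pedigree of complete trios and write it out as a .fam file.
ped = infer_families(relationship_ht, sex_ht, duplicate_samples_ht)
ped.write("inferred_trios.fam")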
def create_binned_data(ht: hl.Table, data: str, data_type: str,
                       n_bins: int) -> hl.Table:
    """
    Creates binned data from a rank Table grouped by rank_id (rank, biallelic, etc.), contig, snv, bi_allelic and singleton
    containing the information needed for evaluation plots.

    :param Table ht: Input rank table
    :param str data: Which data/run hash is being created
    :param str data_type: one of 'exomes' or 'genomes'
    :param int n_bins: Number of bins.
    :return: Binned Table
    :rtype: Table
    """

    # Count variants for ranking
    count_expr = {
        x: hl.agg.filter(
            hl.is_defined(ht[x]),
            hl.agg.counter(
                hl.cond(hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv',
                        'indel')))
        for x in ht.row if x.endswith('rank')
    }
    rank_variant_counts = ht.aggregate(hl.Struct(**count_expr))
    logger.info(
        f"Found the following variant counts:\n {pformat(rank_variant_counts)}"
    )
    ht = ht.annotate_globals(rank_variant_counts=rank_variant_counts)

    # Load external evaluation data
    clinvar_ht = hl.read_table(clinvar_ht_path)
    denovo_ht = get_validated_denovos_ht()
    if data_type == 'exomes':
        denovo_ht = denovo_ht.filter(denovo_ht.gnomad_exomes.high_quality)
    else:
        denovo_ht = denovo_ht.filter(denovo_ht.gnomad_genomes.high_quality)
    denovo_ht = denovo_ht.select(
        validated_denovo=denovo_ht.validated,
        high_confidence_denovo=denovo_ht.Confidence == 'HIGH')
    ht_truth_data = hl.read_table(annotations_ht_path(data_type, 'truth_data'))
    fam_ht = hl.read_table(annotations_ht_path(data_type, 'family_stats'))
    fam_ht = fam_ht.select(family_stats=fam_ht.family_stats[0])
    gnomad_ht = get_gnomad_data(data_type).rows()
    gnomad_ht = gnomad_ht.select(
        vqsr_negative_train_site=gnomad_ht.info.NEGATIVE_TRAIN_SITE,
        vqsr_positive_train_site=gnomad_ht.info.POSITIVE_TRAIN_SITE,
        fail_hard_filters=(gnomad_ht.info.QD < 2) | (gnomad_ht.info.FS > 60) |
        (gnomad_ht.info.MQ < 30))
    lcr_intervals = hl.import_locus_intervals(lcr_intervals_path)

    ht = ht.annotate(
        **ht_truth_data[ht.key],
        **fam_ht[ht.key],
        **gnomad_ht[ht.key],
        **denovo_ht[ht.key],
        clinvar=hl.is_defined(clinvar_ht[ht.key]),
        indel_length=hl.abs(ht.alleles[0].length() - ht.alleles[1].length()),
        rank_bins=hl.array([
            hl.Struct(
                rank_id=rank_name,
                bin=hl.int(
                    hl.ceil(
                        hl.float(ht[rank_name] + 1) / hl.floor(
                            ht.globals.rank_variant_counts[rank_name][hl.cond(
                                hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv',
                                'indel')] / n_bins))))
            for rank_name in rank_variant_counts
        ]),
        lcr=hl.is_defined(lcr_intervals[ht.locus]))

    ht = ht.explode(ht.rank_bins)
    ht = ht.transmute(rank_id=ht.rank_bins.rank_id, bin=ht.rank_bins.bin)
    ht = ht.filter(hl.is_defined(ht.bin))

    ht = ht.checkpoint(
        f'gs://gnomad-tmp/gnomad_score_binning_{data_type}_tmp_{data}.ht',
        overwrite=True)

    # Create binned data
    return (ht.group_by(
        rank_id=ht.rank_id,
        contig=ht.locus.contig,
        snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
        bi_allelic=hl.is_defined(ht.biallelic_rank),
        singleton=ht.singleton,
        release_adj=ht.ac > 0,
        bin=ht.bin)._set_buffer_size(20000).aggregate(
            min_score=hl.agg.min(ht.score),
            max_score=hl.agg.max(ht.score),
            n=hl.agg.count(),
            n_ins=hl.agg.count_where(
                hl.is_insertion(ht.alleles[0], ht.alleles[1])),
            n_del=hl.agg.count_where(
                hl.is_deletion(ht.alleles[0], ht.alleles[1])),
            n_ti=hl.agg.count_where(
                hl.is_transition(ht.alleles[0], ht.alleles[1])),
            n_tv=hl.agg.count_where(
                hl.is_transversion(ht.alleles[0], ht.alleles[1])),
            n_1bp_indel=hl.agg.count_where(ht.indel_length == 1),
            n_mod3bp_indel=hl.agg.count_where((ht.indel_length % 3) == 0),
            n_clinvar=hl.agg.count_where(ht.clinvar),
            n_singleton=hl.agg.count_where(ht.singleton),
            n_validated_de_novos=hl.agg.count_where(ht.validated_denovo),
            n_high_confidence_de_novos=hl.agg.count_where(
                ht.high_confidence_denovo),
            n_de_novo=hl.agg.filter(
                ht.family_stats.unrelated_qc_callstats.AC[1] == 0,
                hl.agg.sum(ht.family_stats.mendel.errors)),
            n_de_novo_no_lcr=hl.agg.filter(
                ~ht.lcr & (ht.family_stats.unrelated_qc_callstats.AC[1] == 0),
                hl.agg.sum(ht.family_stats.mendel.errors)),
            n_de_novo_sites=hl.agg.filter(
                ht.family_stats.unrelated_qc_callstats.AC[1] == 0,
                hl.agg.count_where(ht.family_stats.mendel.errors > 0)),
            n_de_novo_sites_no_lcr=hl.agg.filter(
                ~ht.lcr & (ht.family_stats.unrelated_qc_callstats.AC[1] == 0),
                hl.agg.count_where(ht.family_stats.mendel.errors > 0)),
            n_trans_singletons=hl.agg.filter(
                (ht.info_ac < 3) &
                (ht.family_stats.unrelated_qc_callstats.AC[1] == 1),
                hl.agg.sum(ht.family_stats.tdt.t)),
            n_untrans_singletons=hl.agg.filter(
                (ht.info_ac < 3) &
                (ht.family_stats.unrelated_qc_callstats.AC[1] == 1),
                hl.agg.sum(ht.family_stats.tdt.u)),
            n_train_trans_singletons=hl.agg.count_where(
                (ht.family_stats.unrelated_qc_callstats.AC[1] == 1)
                & (ht.family_stats.tdt.t == 1)),
            n_omni=hl.agg.count_where(ht.truth_data.omni),
            n_mills=hl.agg.count_where(ht.truth_data.mills),
            n_hapmap=hl.agg.count_where(ht.truth_data.hapmap),
            n_kgp_high_conf_snvs=hl.agg.count_where(
                ht.truth_data.kgp_high_conf_snvs),
            fail_hard_filters=hl.agg.count_where(ht.fail_hard_filters),
            n_vqsr_pos_train=hl.agg.count_where(ht.vqsr_positive_train_site),
            n_vqsr_neg_train=hl.agg.count_where(ht.vqsr_negative_train_site)))
Example #13
def get_summary_counts(
    ht: hl.Table,
    freq_field: str = "freq",
    filter_field: str = "filters",
    filter_decoy: bool = False,
    index: int = 0,
) -> hl.Table:
    """
    Generate a struct with summary counts across variant categories.

    Summary counts:
        - Number of variants
        - Number of indels
        - Number of SNVs
        - Number of LoF variants
        - Number of LoF variants that pass LOFTEE (including with LoF flags)
        - Number of LoF variants that pass LOFTEE without LoF flags
        - Number of OS (other splice) variants annotated by LOFTEE
        - Number of LoF variants that fail LOFTEE filters

    Also annotates Table's globals with total variant counts.

    Before calculating summary counts, function:
        - Filters out low confidence regions
        - Filters to canonical transcripts
        - Uses the most severe consequence

    Assumes that:
        - Input HT is annotated with VEP.
        - Multiallelic variants have been split and/or input HT contains bi-allelic variants only.
        - freq_expr was calculated with `annotate_freq`.
        - (Frequency index 0 from `annotate_freq` is frequency for all pops calculated on adj genotypes only.)

    :param ht: Input Table.
    :param freq_field: Name of field in HT containing frequency annotation (array of structs). Default is "freq".
    :param filter_field: Name of field in HT containing variant filter information. Default is "filters".
    :param filter_decoy: Whether to filter decoy regions. Default is False.
    :param index: Which index of freq_expr to use for annotation. Default is 0.
    :return: Table grouped by frequency bin and aggregated across summary count categories.
    """
    logger.info("Checking if multi-allelic variants have been split...")
    max_alleles = ht.aggregate(hl.agg.max(hl.len(ht.alleles)))
    if max_alleles > 2:
        logger.info(
            "Splitting multi-allelics and VEP transcript consequences...")
        ht = hl.split_multi_hts(ht)

    logger.info("Filtering to PASS variants in high confidence regions...")
    ht = ht.filter((hl.len(ht[filter_field]) == 0))
    ht = filter_low_conf_regions(ht, filter_decoy=filter_decoy)

    logger.info(
        "Filtering to canonical transcripts and getting VEP summary annotations..."
    )
    ht = filter_vep_to_canonical_transcripts(ht)
    ht = get_most_severe_consequence_for_summary(ht)

    logger.info("Annotating with frequency bin information...")
    ht = ht.annotate(freq_bin=freq_bin_expr(ht[freq_field], index))

    logger.info(
        "Annotating HT globals with total counts/total allele counts per variant category..."
    )
    summary_counts = ht.aggregate(
        hl.struct(**get_summary_counts_dict(
            ht.locus,
            ht.alleles,
            ht.lof,
            ht.no_lof_flags,
            ht.most_severe_csq,
            prefix_str="total_",
        )))
    summary_ac_counts = ht.aggregate(
        hl.struct(**get_summary_ac_dict(
            ht[freq_field][index].AC,
            ht.lof,
            ht.no_lof_flags,
            ht.most_severe_csq,
        )))
    ht = ht.annotate_globals(summary_counts=summary_counts.annotate(
        **summary_ac_counts))
    return ht.group_by("freq_bin").aggregate(**get_summary_counts_dict(
        ht.locus,
        ht.alleles,
        ht.lof,
        ht.no_lof_flags,
        ht.most_severe_csq,
    ))
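A short usage sketch (the release Table name is an assumption):

summary_ht = get_summary_counts(release_ht)
summary_ht.show()                          # per-frequency-bin summary counts
print(hl.eval(summary_ht.summary_counts))  # total counts stored in globals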
Example #14
def generate_final_filter_ht(
    ht: hl.Table,
    model_name: str,
    score_name: str,
    ac0_filter_expr: hl.expr.BooleanExpression,
    ts_ac_filter_expr: hl.expr.BooleanExpression,
    mono_allelic_flag_expr: hl.expr.BooleanExpression,
    inbreeding_coeff_cutoff: float = INBREEDING_COEFF_HARD_CUTOFF,
    snp_bin_cutoff: Optional[int] = None,
    indel_bin_cutoff: Optional[int] = None,
    snp_score_cutoff: Optional[float] = None,
    indel_score_cutoff: Optional[float] = None,
    aggregated_bin_ht: Optional[hl.Table] = None,
    bin_id: Optional[str] = None,
    vqsr_ht: Optional[hl.Table] = None,
) -> hl.Table:
    """
    Prepares finalized filtering model given a filtering HT from `rf.apply_rf_model` or VQSR and cutoffs for filtering.

    .. note::

        - `snp_bin_cutoff` and `snp_score_cutoff` are mutually exclusive, and one must be supplied.
        - `indel_bin_cutoff` and `indel_score_cutoff` are mutually exclusive, and one must be supplied.
        - If a `snp_bin_cutoff` or `indel_bin_cutoff` cutoff is supplied then an `aggregated_bin_ht` and `bin_id` must
          also be supplied to determine the SNP and indel scores to use as cutoffs from an aggregated bin Table like
          one created by `compute_grouped_binned_ht` in combination with `score_bin_agg`.

    :param ht: Filtering Table from `rf.apply_rf_model` or VQSR to prepare as the final filter Table
    :param model_name: Filtering model name to use in the 'filters' field (VQSR or RF)
    :param score_name: Name to use for the filtering score annotation. This will be used in place of 'score' in the
        release HT info struct and the INFO field of the VCF (e.g. RF or AS_VQSLOD)
    :param ac0_filter_expr: Expression that indicates if a variant should be filtered as allele count 0 (AC0)
    :param ts_ac_filter_expr: Allele count expression in `ht` to use as a filter for determining a transmitted singleton
    :param mono_allelic_flag_expr: Expression indicating if a variant is mono-allelic
    :param inbreeding_coeff_cutoff: InbreedingCoeff hard filter to use for variants
    :param snp_bin_cutoff: Bin cutoff to use for SNP variant QC filter. Can't be used with `snp_score_cutoff`
    :param indel_bin_cutoff: Bin cutoff to use for indel variant QC filter. Can't be used with `indel_score_cutoff`
    :param snp_score_cutoff: Score cutoff (e.g. RF probability or AS_VQSLOD) to use for SNP variant QC filter. Can't be used with `snp_bin_cutoff`
    :param indel_score_cutoff: Score cutoff (e.g. RF probability or AS_VQSLOD) to use for indel variant QC filter. Can't be used with `indel_bin_cutoff`
    :param aggregated_bin_ht: Table with aggregate counts of variants based on bins
    :param bin_id: Name of bin to use in 'bin_id' column of `aggregated_bin_ht` to use to determine probability cutoff
    :param vqsr_ht: If a VQSR HT is supplied a 'vqsr' annotation containing AS_VQSLOD, AS_culprit, NEGATIVE_TRAIN_SITE,
        and POSITIVE_TRAIN_SITE will be included in the returned Table
    :return: Finalized random forest Table annotated with variant filters
    """
    if snp_bin_cutoff is not None and snp_score_cutoff is not None:
        raise ValueError(
            "snp_bin_cutoff and snp_score_cutoff are mutually exclusive, please only supply one SNP filtering cutoff."
        )

    if indel_bin_cutoff is not None and indel_score_cutoff is not None:
        raise ValueError(
            "indel_bin_cutoff and indel_score_cutoff are mutually exclusive, please only supply one indel filtering cutoff."
        )

    if snp_bin_cutoff is None and snp_score_cutoff is None:
        raise ValueError(
            "One (and only one) of the parameters snp_bin_cutoff and snp_score_cutoff must be supplied."
        )

    if indel_bin_cutoff is None and indel_score_cutoff is None:
        raise ValueError(
            "One (and only one) of the parameters indel_bin_cutoff and indel_score_cutoff must be supplied."
        )

    if (snp_bin_cutoff is not None or indel_bin_cutoff
            is not None) and (aggregated_bin_ht is None or bin_id is None):
        raise ValueError(
            "If using snp_bin_cutoff or indel_bin_cutoff, both aggregated_bin_ht and bin_id must be supplied"
        )

    # Determine SNP and indel score cutoffs if given bin instead of score
    if snp_bin_cutoff is not None:
        snp_score_cutoff = aggregated_bin_ht.aggregate(
            hl.agg.filter(
                aggregated_bin_ht.snv
                & (aggregated_bin_ht.bin_id == bin_id)
                & (aggregated_bin_ht.bin == snp_bin_cutoff),
                hl.agg.min(aggregated_bin_ht.min_score),
            ))
        snp_cutoff_global = hl.struct(bin=snp_bin_cutoff,
                                      min_score=snp_score_cutoff)

    if indel_bin_cutoff is not None:
        indel_score_cutoff = aggregated_bin_ht.aggregate(
            hl.agg.filter(
                ~aggregated_bin_ht.snv
                & (aggregated_bin_ht.bin_id == bin_id)
                & (aggregated_bin_ht.bin == indel_bin_cutoff),
                hl.agg.min(aggregated_bin_ht.min_score),
            ))
        indel_cutoff_global = hl.struct(bin=indel_bin_cutoff,
                                        min_score=indel_score_cutoff)

    min_score = ht.aggregate(hl.agg.min(ht.score))
    max_score = ht.aggregate(hl.agg.max(ht.score))

    if snp_score_cutoff is not None:
        if snp_score_cutoff < min_score or snp_score_cutoff > max_score:
            raise ValueError(
                "snp_score_cutoff is not within the range of score.")
        # Only build the global here if the score cutoff was supplied directly;
        # otherwise the bin-derived struct above (which keeps the bin) would be overwritten
        if snp_bin_cutoff is None:
            snp_cutoff_global = hl.struct(min_score=snp_score_cutoff)

    if indel_score_cutoff is not None:
        if indel_score_cutoff < min_score or indel_score_cutoff > max_score:
            raise ValueError(
                "indel_score_cutoff is not within the range of score.")
        if indel_bin_cutoff is None:
            indel_cutoff_global = hl.struct(min_score=indel_score_cutoff)

    logger.info(
        f"Using a SNP score cutoff of {snp_score_cutoff} and an indel score cutoff of {indel_score_cutoff}."
    )

    # Add filters to HT
    filters = dict()

    if ht.any(hl.is_missing(ht.score)):
        ht.filter(hl.is_missing(ht.score)).show()
        raise ValueError("Missing Score!")

    filters[model_name] = (hl.is_missing(ht.score)
                           | (hl.is_snp(ht.alleles[0], ht.alleles[1])
                              & (ht.score < snp_cutoff_global.min_score))
                           | (~hl.is_snp(ht.alleles[0], ht.alleles[1])
                              & (ht.score < indel_cutoff_global.min_score)))

    filters["InbreedingCoeff"] = hl.or_else(
        ht.InbreedingCoeff < inbreeding_coeff_cutoff, False)
    filters["AC0"] = ac0_filter_expr

    annotations_expr = dict()
    if model_name == "RF":
        # Fix annotations for release
        # dict.update mutates in place and returns None, so don't reassign
        annotations_expr.update({
            "positive_train_site":
            hl.or_else(ht.positive_train_site, False),
            "rf_tp_probability":
            ht.rf_probability["TP"],
        })
    annotations_expr.update({
        "transmitted_singleton":
        hl.or_missing(ts_ac_filter_expr, ht.transmitted_singleton)
    })
    if "feature_imputed" in ht.row:
        annotations_expr.update({
            x: hl.or_missing(~ht.feature_imputed[x], ht[x])
            for x in [f for f in ht.row.feature_imputed]
        })

    ht = ht.transmute(
        filters=add_filters_expr(filters=filters),
        monoallelic=mono_allelic_flag_expr,
        **{score_name: ht.score},
        **annotations_expr,
    )

    bin_names = [x for x in ht.row if x.endswith("bin")]
    bin_names = [(
        x,
        (x.split("adj_")[0] + x.split("adj_")[1]
         if len(x.split("adj_")) == 2 else "raw_" + x),
    ) for x in bin_names]
    ht = ht.transmute(**{j: ht[i] for i, j in bin_names})

    ht = ht.annotate_globals(
        bin_stats=hl.struct(**{j: ht.bin_stats[i]
                               for i, j in bin_names}),
        filtering_model=hl.struct(
            model_name=model_name,
            score_name=score_name,
            snv_cutoff=snp_cutoff_global,
            indel_cutoff=indel_cutoff_global,
        ),
        inbreeding_coeff_cutoff=inbreeding_coeff_cutoff,
    )
    if vqsr_ht:
        vqsr = vqsr_ht[ht.key]
        ht = ht.annotate(
            vqsr=hl.struct(
                AS_VQSLOD=vqsr.info.AS_VQSLOD,
                AS_culprit=vqsr.info.AS_culprit,
                NEGATIVE_TRAIN_SITE=vqsr.info.NEGATIVE_TRAIN_SITE,
                POSITIVE_TRAIN_SITE=vqsr.info.POSITIVE_TRAIN_SITE,
            ),
            SOR=vqsr.info.SOR,  # NOTE: This was required for v3.1; we now compute this in `get_site_info_expr`
        )

    ht = ht.drop("AS_culprit")

    return ht
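A hedged sketch of finalizing an RF filtering model with bin cutoffs; the Table names, annotation fields (`ac`, `ac_raw`, `af`), and cutoff values are illustrative assumptions:

final_ht = generate_final_filter_ht(
    rf_ht,
    model_name="RF",
    score_name="RF",
    ac0_filter_expr=rf_ht.ac == 0,
    ts_ac_filter_expr=rf_ht.ac_raw == 2,
    mono_allelic_flag_expr=(rf_ht.ac_raw > 0) & (rf_ht.af == 1),
    snp_bin_cutoff=90,
    indel_bin_cutoff=80,
    aggregated_bin_ht=agg_bin_ht,  # e.g. from compute_grouped_binned_ht + score_bin_agg
    bin_id="bin",
)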
Example #15
def compute_stratified_metrics_filter(
    ht: hl.Table,
    qc_metrics: Dict[str, hl.expr.NumericExpression],
    strata: Optional[Dict[str, hl.expr.Expression]] = None,
    lower_threshold: float = 4.0,
    upper_threshold: float = 4.0,
    metric_threshold: Optional[Dict[str, Tuple[float, float]]] = None,
    filter_name: str = "qc_metrics_filters",
) -> hl.Table:
    """
    Compute median, MAD, and upper and lower thresholds for each metric used in outlier filtering.

    :param ht: HT containing relevant sample QC metric annotations
    :param qc_metrics: Dictionary of metrics (name -> expression) for which to compute the critical values for filtering outliers
    :param strata: Dictionary of annotations (name -> expression) used for stratification. These annotations should be discrete types!
    :param lower_threshold: Lower MAD threshold
    :param upper_threshold: Upper MAD threshold
    :param metric_threshold: Can be used to specify different (lower, upper) thresholds for one or more metrics
    :param filter_name: Name of resulting filters annotation
    :return: Table annotated with per-metric fail flags and a `filter_name` set per sample, with upper and lower threshold values stored in the `qc_metrics_stats` global
    """
    _metric_threshold = {
        metric: (lower_threshold, upper_threshold)
        for metric in qc_metrics
    }
    if metric_threshold is not None:
        _metric_threshold.update(metric_threshold)

    def make_filters_expr(ht: hl.Table,
                          qc_metrics: Iterable[str]) -> hl.expr.SetExpression:
        return hl.set(
            hl.filter(
                lambda x: hl.is_defined(x),
                [
                    hl.or_missing(ht[f"fail_{metric}"], metric)
                    for metric in qc_metrics
                ],
            ))

    if strata is None:
        strata = {}

    ht = ht.select(**qc_metrics, **strata).key_by("s").persist()

    agg_expr = hl.struct(
        **{
            metric: hl.bind(
                lambda x: x.annotate(
                    lower=x.median - _metric_threshold[metric][0] * x.mad,
                    upper=x.median + _metric_threshold[metric][1] * x.mad,
                ),
                get_median_and_mad_expr(ht[metric]),
            )
            for metric in qc_metrics
        })

    if strata:
        ht = ht.annotate_globals(qc_metrics_stats=ht.aggregate(
            hl.agg.group_by(hl.tuple([ht[x] for x in strata]), agg_expr),
            _localize=False,
        ))
        metrics_stats_expr = ht.qc_metrics_stats[hl.tuple(
            [ht[x] for x in strata])]
    else:
        ht = ht.annotate_globals(
            qc_metrics_stats=ht.aggregate(agg_expr, _localize=False))
        metrics_stats_expr = ht.qc_metrics_stats

    fail_exprs = {
        f"fail_{metric}": (ht[metric] <= metrics_stats_expr[metric].lower)
        | (ht[metric] >= metrics_stats_expr[metric].upper)
        for metric in qc_metrics
    }
    ht = ht.transmute(**fail_exprs)
    stratified_filters = make_filters_expr(ht, qc_metrics)
    return ht.annotate(**{filter_name: stratified_filters})
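Usage sketch for this dict-based API (input names are assumptions); unlike the first example, metrics and strata are passed as expressions rather than column names:

filtered_ht = compute_stratified_metrics_filter(
    sample_qc_ht,
    qc_metrics={m: sample_qc_ht.sample_qc[m] for m in ("n_snp", "r_ti_tv")},
    strata={"pop": sample_qc_ht.pop, "platform": sample_qc_ht.platform},
    metric_threshold={"r_ti_tv": (4.0, 8.0)},  # custom (lower, upper) MADs for one metric
)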
Example #16
def compute_related_samples_to_drop(
    relatedness_ht: hl.Table,
    rank_ht: hl.Table,
    kin_threshold: float,
    filtered_samples: Optional[hl.expr.SetExpression] = None,
    min_related_hard_filter: Optional[int] = None,
) -> hl.Table:
    """
    Computes a Table with the list of samples to drop (and their global rank) to get the maximal independent set of unrelated samples.

    .. note::

        - `relatedness_ht` should be keyed by exactly two fields of the same type, identifying the pair of samples for each row.
        - `rank_ht` should be keyed by a single key of the same type as a single sample identifier in `relatedness_ht`.

    :param relatedness_ht: relatedness HT, as produced by e.g. pc-relate
    :param rank_ht: Table with a global rank for each sample (smaller is preferred)
    :param kin_threshold: Kinship threshold to consider two samples as related
    :param filtered_samples: An optional set of samples to exclude (e.g. these samples were hard-filtered). These samples will then appear in the resulting samples to drop.
    :param min_related_hard_filter: If provided, any sample that is related to more samples than this parameter will be filtered prior to computing the maximal independent set and appear in the results.
    :return: A Table with the list of the samples to drop along with their rank.
    """

    # Make sure that the key types are valid
    assert len(list(relatedness_ht.key)) == 2
    assert relatedness_ht.key[0].dtype == relatedness_ht.key[1].dtype
    assert len(list(rank_ht.key)) == 1
    assert relatedness_ht.key[0].dtype == rank_ht.key[0].dtype

    logger.info(
        f"Filtering related samples using a kin threshold of {kin_threshold}")
    relatedness_ht = relatedness_ht.filter(relatedness_ht.kin > kin_threshold)

    filtered_samples_rel = set()
    if min_related_hard_filter is not None:
        logger.info(
            f"Computing samples related to too many individuals (>{min_related_hard_filter}) for exclusion"
        )
        gbi = relatedness_ht.annotate(s=list(relatedness_ht.key))
        gbi = gbi.explode(gbi.s)
        gbi = gbi.group_by(gbi.s).aggregate(n=hl.agg.count())
        filtered_samples_rel = gbi.aggregate(
            hl.agg.filter(gbi.n > min_related_hard_filter,
                          hl.agg.collect_as_set(gbi.s)))
        logger.info(
            f"Found {len(filtered_samples_rel)} samples with too many 1st/2nd degree relatives. These samples will be excluded."
        )

    if filtered_samples is not None:
        filtered_samples_rel = filtered_samples_rel.union(
            relatedness_ht.aggregate(
                hl.agg.explode(
                    lambda s: hl.agg.collect_as_set(s),
                    hl.array(list(relatedness_ht.key)).filter(
                        lambda s: filtered_samples.contains(s)),
                )))

    if len(filtered_samples_rel) > 0:
        filtered_samples_lit = hl.literal(filtered_samples_rel)
        relatedness_ht = relatedness_ht.filter(
            filtered_samples_lit.contains(relatedness_ht.key[0])
            | filtered_samples_lit.contains(relatedness_ht.key[1]),
            keep=False,
        )

    logger.info("Annotating related sample pairs with rank.")
    i, j = list(relatedness_ht.key)
    relatedness_ht = relatedness_ht.key_by(s=relatedness_ht[i])
    relatedness_ht = relatedness_ht.annotate(**{
        i:
        hl.struct(s=relatedness_ht.s, rank=rank_ht[relatedness_ht.key].rank)
    })
    relatedness_ht = relatedness_ht.key_by(s=relatedness_ht[j])
    relatedness_ht = relatedness_ht.annotate(**{
        j:
        hl.struct(s=relatedness_ht.s, rank=rank_ht[relatedness_ht.key].rank)
    })
    relatedness_ht = relatedness_ht.key_by(i, j)
    relatedness_ht = relatedness_ht.drop("s")
    relatedness_ht = relatedness_ht.persist()

    related_samples_to_drop_ht = hl.maximal_independent_set(
        relatedness_ht[i],
        relatedness_ht[j],
        keep=False,
        tie_breaker=lambda l, r: l.rank - r.rank,
    )
    related_samples_to_drop_ht = related_samples_to_drop_ht.key_by()
    related_samples_to_drop_ht = related_samples_to_drop_ht.select(
        **related_samples_to_drop_ht.node)
    related_samples_to_drop_ht = related_samples_to_drop_ht.key_by("s")

    if len(filtered_samples_rel) > 0:
        related_samples_to_drop_ht = related_samples_to_drop_ht.union(
            hl.Table.parallelize(
                [
                    hl.struct(s=s, rank=hl.null(hl.tint64))
                    for s in filtered_samples_rel
                ],
                key="s",
            ))

    return related_samples_to_drop_ht
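
A hedged usage sketch for the function above. `sample_ht` and the kinship threshold are illustrative; `relatedness_ht` is assumed to be keyed by a sample pair (e.g. the output of `hl.pc_relate`), and `rank_ht` must share its key type with a single member of that pair, as the docstring requires:

# Rank samples by row order (smaller rank = preferred to keep); any
# quality-based ranking works as long as the key types line up.
rank_ht = sample_ht.select(rank=hl.scan.count())
drop_ht = compute_related_samples_to_drop(
    relatedness_ht=relatedness_ht,
    rank_ht=rank_ht,
    kin_threshold=0.088,  # ~2nd-degree kinship cutoff, chosen for illustration
)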
Example #17
def median_impute_features(
        ht: hl.Table,
        strata: Optional[Dict[str, hl.expr.Expression]] = None) -> hl.Table:
    """
    Numerical features in the Table are median-imputed using Hail's `approx_median`.

    If a `strata` dict is given, imputation is based on the median of each stratum.

    The annotations that are added to the Table are
        - feature_imputed - A row annotation indicating if each numerical feature was imputed or not.
        - features_median - A global annotation containing the median of the numerical features. If `strata` is given,
          this struct will also be broken down by the given strata.
        - variants_by_strata - An additional global annotation with the variant counts by strata that will only be
          added if imputing by a given `strata`.

    :param ht: Table containing all samples and features for median imputation.
    :param strata: Optional dict of expressions to stratify by when computing feature medians (default None, no stratification).
    :return: Feature Table imputed using approximate median values.
    """

    logger.info(
        "Computing feature medians for imputation of missing numeric values")
    numerical_features = [
        k for k, v in ht.row.dtype.items() if v == hl.tint or v == hl.tfloat
    ]

    median_agg_expr = hl.struct(
        **{
            feature: hl.agg.approx_median(ht[feature])
            for feature in numerical_features
        })

    if strata:
        ht = ht.annotate_globals(
            feature_medians=ht.aggregate(
                hl.agg.group_by(hl.tuple([ht[x] for x in strata]),
                                median_agg_expr),
                _localize=False,
            ),
            variants_by_strata=ht.aggregate(hl.agg.counter(
                hl.tuple([ht[x] for x in strata])),
                                            _localize=False),
        )
        feature_median_expr = ht.feature_medians[hl.tuple(
            [ht[x] for x in strata])]
        logger.info("Variant count by strata:\n{}".format("\n".join([
            "{}: {}".format(k, v)
            for k, v in hl.eval(ht.variants_by_strata).items()
        ])))

    else:
        ht = ht.annotate_globals(
            feature_medians=ht.aggregate(median_agg_expr, _localize=False))
        feature_median_expr = ht.feature_medians

    ht = ht.annotate(
        **{
            f: hl.or_else(ht[f], feature_median_expr[f])
            for f in numerical_features
        },
        feature_imputed=hl.struct(
            **{f: hl.is_missing(ht[f])
               for f in numerical_features}),
    )

    return ht
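
A short usage sketch; `feature_ht` and the `variant_type` annotation are assumed names:

# Impute missing numeric features, computing medians separately for
# each variant type.
ht = median_impute_features(
    feature_ht,
    strata={"variant_type": feature_ht.variant_type},
)
# ht.feature_imputed.<feature> is True wherever a value was filled in,
# and hl.eval(ht.feature_medians) exposes the per-stratum medians.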
Example #18
def liftover_intervals(t: hl.Table,
                       keep_missing_interval: bool = False) -> hl.Table:
    """
    Lift over loci in intervals from one reference genome (GRCh37) to another (GRCh38)

    # Example input table description
    #
    # ----------------------------------------
    # Global fields:
    #     None
    # ----------------------------------------
    # Row fields:
    #     'interval': interval<locus<GRCh37>>
    # ----------------------------------------
    # Key: ['interval']
    # ----------------------------------------


    :param t: Table of intervals on GRCh37
    :param keep_missing_interval: If True, keep missing (non-lifted) intervals in the output Table.
    :return: Table with intervals lifted over to GRCh38.
    """

    rg37 = hl.get_reference("GRCh37")
    rg38 = hl.get_reference("GRCh38")

    if not rg37.has_liftover("GRCh38"):
        rg37.add_liftover(
            f'{nfs_dir}/resources/liftover/grch37_to_grch38.over.chain.gz',
            rg38)

    t = t.annotate(
        start=hl.liftover(t.interval.start, "GRCh38"),
        end=hl.liftover(t.interval.end, "GRCh38"),
    )

    t = t.filter((t.start.contig == "chr" + t.interval.start.contig)
                 & (t.end.contig == "chr" + t.interval.end.contig))

    t = t.key_by()

    t = (t.select(interval=hl.locus_interval(t.start.contig,
                                             t.start.position,
                                             t.end.position,
                                             reference_genome=rg38,
                                             invalid_missing=True),
                  interval_hg37=t.interval))

    # Count intervals that failed to lift over; use .get(True, 0) so that
    # having zero missing intervals does not raise a KeyError.
    missing = t.aggregate(hl.agg.counter(hl.is_missing(t.interval)))
    n_missing = missing.get(True, 0)
    logger.info(
        f"Number of missing intervals: {n_missing} out of {t.count()}...")

    # update globals annotations
    global_ann_expr = {
        'date': current_date(),
        'reference_genome': 'GRCh38',
        'was_lifted': True
    }
    t = t.annotate_globals(**global_ann_expr)

    if not keep_missing_interval:
        logger.info(f"Filtering out {n_missing} missing intervals...")
        t = t.filter(hl.is_defined(t.interval), keep=True)

    return t.key_by("interval")
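
A usage sketch, assuming the chain file referenced above is in place; the interval file path is hypothetical:

intervals37 = hl.import_locus_intervals(
    "data/capture_regions.interval_list",  # assumed path
    reference_genome="GRCh37",
)
# Lift to GRCh38, dropping intervals that fail to lift over cleanly.
intervals38 = liftover_intervals(intervals37, keep_missing_interval=False)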
Example #19
def get_ploidy_cutoffs(
    ht: hl.Table,
    f_stat_cutoff: float,
    normal_ploidy_cutoff: int = 5,
    aneuploidy_cutoff: int = 6,
) -> Tuple[Tuple[float, Tuple[float, float], float], Tuple[Tuple[float, float],
                                                           float]]:
    """
    Gets chromosome X and Y ploidy cutoffs for XY and XX samples.

    Note this assumes the input hail Table has the fields f_stat, chrX_ploidy, and chrY_ploidy.
    Returns a tuple of sex chromosome ploidy cutoffs: ((x_ploidy_cutoffs), (y_ploidy_cutoffs)).
    x_ploidy_cutoffs: (upper cutoff for single X, (lower cutoff for double X, upper cutoff for double X), lower cutoff for triple X)
    y_ploidy_cutoffs: ((lower cutoff for single Y, upper cutoff for single Y), lower cutoff for double Y)

    Uses the normal_ploidy_cutoff parameter to determine the ploidy cutoffs for XX and XY karyotypes.
    Uses the aneuploidy_cutoff parameter to determine the cutoffs for sex aneuploidies.

    Note that f-stat is used only to split the samples into roughly 'XX' and 'XY' categories and is not used in the final karyotype annotation.

    :param ht: Table with f_stat and sex chromosome ploidies
    :param f_stat_cutoff: f-stat to roughly divide 'XX' from 'XY' samples. Assumes XX samples are below cutoff and XY are above cutoff.
    :param normal_ploidy_cutoff: Number of standard deviations to use when determining sex chromosome ploidy cutoffs for XX, XY karyotypes.
    :param aneuploidy_cutoff: Number of standard deviations to use when determining sex chromosome ploidy cutoffs for aneuploidies.
    :return: Tuple of ploidy cutoff tuples: ((x_ploidy_cutoffs), (y_ploidy_cutoffs))
    """
    # Group sex chromosome ploidy table by f_stat cutoff and get mean/stdev for chrX/Y ploidies
    sex_stats = ht.aggregate(
        hl.agg.group_by(
            hl.if_else(ht.f_stat < f_stat_cutoff, "xx", "xy"),
            hl.struct(x=hl.agg.stats(ht.chrX_ploidy),
                      y=hl.agg.stats(ht.chrY_ploidy)),
        ))
    logger.info(f"XX stats: {sex_stats['xx']}")
    logger.info(f"XY stats: {sex_stats['xy']}")

    cutoffs = (
        (
            sex_stats["xy"].x.mean +
            (normal_ploidy_cutoff * sex_stats["xy"].x.stdev),
            (
                sex_stats["xx"].x.mean -
                (normal_ploidy_cutoff * sex_stats["xx"].x.stdev),
                sex_stats["xx"].x.mean +
                (normal_ploidy_cutoff * sex_stats["xx"].x.stdev),
            ),
            sex_stats["xx"].x.mean +
            (aneuploidy_cutoff * sex_stats["xx"].x.stdev),
        ),
        (
            (
                sex_stats["xx"].y.mean +
                (normal_ploidy_cutoff * sex_stats["xx"].y.stdev),
                sex_stats["xy"].y.mean +
                (normal_ploidy_cutoff * sex_stats["xy"].y.stdev),
            ),
            sex_stats["xy"].y.mean +
            (aneuploidy_cutoff * sex_stats["xy"].y.stdev),
        ),
    )

    logger.info(f"X ploidy cutoffs: {cutoffs[0]}")
    logger.info(f"Y ploidy cutoffs: {cutoffs[1]}")
    return cutoffs
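
A usage sketch unpacking the nested cutoff tuples; `sex_ht` and the f-stat cutoff value are illustrative:

x_ploidy_cutoffs, y_ploidy_cutoffs = get_ploidy_cutoffs(sex_ht, f_stat_cutoff=0.5)
# (upper single X, (lower double X, upper double X), lower triple X)
upper_x, (lower_xx, upper_xx), lower_xxx = x_ploidy_cutoffs
# ((lower single Y, upper single Y), lower double Y)
(lower_y, upper_y), lower_yy = y_ploidy_cutoffs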
Example #20
def annotate_unphased_pairs(unphased_ht: hl.Table, n_variant_pairs: int,
                            least_consequence: str,
                            max_af: float) -> hl.Table:
    """
    Annotate unphased variant pairs with gnomAD frequencies and EM-based phase estimates.

    :param unphased_ht: Table of variant pairs lacking phase information
    :param n_variant_pairs: Number of variant pairs; used to size the repartition below
    :param least_consequence: Least severe VEP consequence considered when matching both variants to the same gene
    :param max_af: Maximum gnomAD AF allowed for either variant of a pair
    :return: Aggregated Table of variant pairs with phase annotations
    """
    # unphased_ht = vp_ht.filter(hl.is_missing(vp_ht.all_phase))
    # unphased_ht = unphased_ht.key_by()

    # Explode variant pairs
    unphased_ht = unphased_ht.annotate(las=[
        hl.tuple([unphased_ht.locus1, unphased_ht.alleles1]),
        hl.tuple([unphased_ht.locus2, unphased_ht.alleles2])
    ]).explode('las', name='la')

    unphased_ht = unphased_ht.key_by(
        locus=unphased_ht.la[0],
        alleles=unphased_ht.la[1],
    ).persist()  # .checkpoint('gs://gnomad-tmp/vp_ht_unphased.ht')

    # Annotate single variants with gnomAD freq
    gnomad_ht = gnomad.public_release('exomes').ht()
    gnomad_ht = gnomad_ht.semi_join(unphased_ht).repartition(
        ceil(n_variant_pairs / 10000), shuffle=True).persist()

    missing_freq = hl.struct(
        AC=0,
        AF=0,
        AN=125748 * 2,  # 125,748 gnomAD exome samples; assume no missing calls for now
        homozygote_count=0)

    logger.info(
        f"{gnomad_ht.count()}/{unphased_ht.count()} single variants from the unphased pairs found in gnomAD."
    )

    gnomad_indexed = gnomad_ht[unphased_ht.key]
    gnomad_freq = gnomad_indexed.freq
    unphased_ht = unphased_ht.annotate(
        adj_freq=hl.or_else(gnomad_freq[0], missing_freq),
        raw_freq=hl.or_else(gnomad_freq[1], missing_freq),
        vep_genes=vep_genes_expr(gnomad_indexed.vep, least_consequence),
        max_af_filter=gnomad_indexed.freq[0].AF <= max_af
        # pop_max_freq=hl.or_else(
        #     gnomad_exomes.popmax[0],
        #     missing_freq.annotate(
        #         pop=hl.null(hl.tstr)
        #     )
        # )
    )
    unphased_ht = unphased_ht.persist()
    # unphased_ht = unphased_ht.checkpoint('gs://gnomad-tmp/unphased_ann.ht', overwrite=True)

    loci_expr = hl.sorted(
        hl.agg.collect(
            hl.tuple([
                unphased_ht.locus,
                hl.struct(
                    adj_freq=unphased_ht.adj_freq,
                    raw_freq=unphased_ht.raw_freq,
                    # pop_max_freq=unphased_ht.pop_max_freq
                )
            ])),
        lambda x: x[0]  # sort by locus
    ).map(lambda x: x[1])  # drop the locus, keep only the freq struct

    vp_freq_expr = hl.struct(v1=loci_expr[0], v2=loci_expr[1])

    # [AABB, AABb, AAbb, AaBB, AaBb, Aabb, aaBB, aaBb, aabb]
    def get_gt_counts(freq: str):
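        # Reconstruct the 9 two-locus genotype counts from per-variant
        # marginals only: het counts are AC - 2 * homozygote_count, hom
        # counts come straight from homozygote_count, the AABB cell is
        # bounded by min(AN1, AN2), and every cell that would require a
        # carrier of both variants (AaBb, Aabb, aaBb, aabb) is set to 0,
        # since no joint carrier information exists for unphased pairs.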
        return hl.array([
            hl.min(vp_freq_expr.v1[freq].AN, vp_freq_expr.v2[freq].AN),  # AABB
            vp_freq_expr.v2[freq].AC -
            (2 * vp_freq_expr.v2[freq].homozygote_count),  # AABb
            vp_freq_expr.v2[freq].homozygote_count,  # AAbb
            vp_freq_expr.v1[freq].AC -
            (2 * vp_freq_expr.v1[freq].homozygote_count),  # AaBB
            0,  # AaBb
            0,  # Aabb
            vp_freq_expr.v1[freq].homozygote_count,  # aaBB
            0,  # aaBb
            0  # aabb
        ])

    gt_counts_raw_expr = get_gt_counts('raw_freq')
    gt_counts_adj_expr = get_gt_counts('adj_freq')

    # gt_counts_pop_max_expr = get_gt_counts('pop_max_freq')
    unphased_ht = unphased_ht.group_by(
        unphased_ht.locus1, unphased_ht.alleles1, unphased_ht.locus2,
        unphased_ht.alleles2
    ).aggregate(
        pop='all',  # TODO Add option for multiple pops?
        phase_info=hl.struct(gt_counts=hl.struct(raw=gt_counts_raw_expr,
                                                 adj=gt_counts_adj_expr),
                             em=hl.struct(
                                 raw=get_em_expr(gt_counts_raw_expr),
                                 adj=get_em_expr(gt_counts_adj_expr))),
        vep_genes=hl.agg.collect(
            unphased_ht.vep_genes).filter(lambda x: hl.len(x) > 0),
        max_af_filter=hl.agg.all(unphased_ht.max_af_filter)

        # pop_max_gt_counts_adj=gt_counts_raw_expr,
        # pop_max_em_p_chet_adj=get_em_expr(gt_counts_raw_expr).p_chet,
    )  # .key_by()

    unphased_ht = unphased_ht.transmute(
        vep_filter=(hl.len(unphased_ht.vep_genes) > 1)
        & (hl.len(unphased_ht.vep_genes[0].intersection(
            unphased_ht.vep_genes[1])) > 0))

    max_af_filtered, vep_filtered = unphased_ht.aggregate([
        hl.agg.count_where(~unphased_ht.max_af_filter),
        hl.agg.count_where(~unphased_ht.vep_filter)
    ])
    if max_af_filtered > 0:
        logger.info(
            f"{max_af_filtered} variant-pairs excluded because the AF of at least one variant was > {max_af}"
        )
    if vep_filtered > 0:
        logger.info(
            f"{vep_filtered} variant-pairs excluded because the variants were not found within the same gene with a csq of at least {least_consequence}"
        )

    unphased_ht = unphased_ht.filter(unphased_ht.max_af_filter
                                     & unphased_ht.vep_filter)

    return unphased_ht.drop('max_af_filter', 'vep_filter')
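
An illustrative call for the function above; all argument values here are assumptions rather than values from the source pipeline:

phased_ht = annotate_unphased_pairs(
    unphased_ht,
    n_variant_pairs=unphased_ht.count(),   # drives the repartition count
    least_consequence="missense_variant",  # assumed consequence threshold
    max_af=0.05,                           # assumed AF cap
)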