Example 1
def generate_allele_data(ht: hl.Table) -> hl.Table:
    """
    Returns bi-allelic sites HT with the following annotations:
     - allele_data (nonsplit_alleles, has_star, variant_type, and n_alt_alleles)

    :param Table ht: Full unsplit HT
    :return: Table with allele data annotations
    :rtype: Table
    """
    ht = ht.select()  # keep only the key fields (locus, alleles)
    allele_data = hl.struct(nonsplit_alleles=ht.alleles,
                            has_star=hl.any(lambda a: a == "*", ht.alleles))
    ht = ht.annotate(allele_data=allele_data.annotate(
        **add_variant_type(ht.alleles)))

    ht = hl.split_multi_hts(ht)
    ht = ht.filter(hl.len(ht.alleles) > 1)
    allele_type = (hl.case().when(
        hl.is_snp(ht.alleles[0], ht.alleles[1]),
        "snv").when(hl.is_insertion(ht.alleles[0], ht.alleles[1]),
                    "ins").when(hl.is_deletion(ht.alleles[0], ht.alleles[1]),
                                "del").default("complex"))
    ht = ht.annotate(allele_data=ht.allele_data.annotate(
        allele_type=allele_type,
        was_mixed=ht.allele_data.variant_type == "mixed"))
    return ht
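
A minimal usage sketch (hypothetical input path; assumes Hail is initialized and the input is the full, unsplit sites Table):

import hail as hl

raw_ht = hl.read_table("gs://my-bucket/raw_variants.ht")  # hypothetical unsplit HT
allele_ht = generate_allele_data(raw_ht)
allele_ht.allele_data.show(5)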
Example 2
def prepare_exomes(exome_ht: hl.Table, groupings: List[str], dataset: str = "gnomad",
                   impose_high_af_cutoff_upfront: bool = True) -> hl.Table:
    """
    Prepare an exomes Table for constraint calculations: explode VEP transcript
    consequences, annotate grouping variables, and filter by allele count/frequency.
    """
    # Manipulate VEP annotations and explode by them
    exome_ht = add_most_severe_csq_to_tc_within_ht(exome_ht)
    exome_ht = exome_ht.transmute(transcript_consequences=exome_ht.vep.transcript_consequences)
    exome_ht = exome_ht.explode(exome_ht.transcript_consequences)
    
    # Annotate variants with grouping variables. 
    exome_ht, groupings = annotate_constraint_groupings(exome_ht, groupings)  # This function needs to be adapted
    exome_ht = exome_ht.select(
        'context', 'ref', 'alt', 'methylation_level', 'freq', 'pass_filters', *groupings)

    # Filter by allele count
    # Likely to need to adapt this function as well
    af_cutoff = 0.001
    # `dataset` picks the entry in freq_index_dict; the "gnomad" default is an assumption
    freq_index = exome_ht.freq_index_dict.collect()[0][dataset]

    def keep_criteria(ht):
        crit = (ht.freq[freq_index].AC > 0) & ht.pass_filters & (ht.coverage > 0)
        if impose_high_af_cutoff_upfront:
            crit &= (ht.freq[freq_index].AF <= af_cutoff)
        return crit

    exome_ht = exome_ht.filter(keep_criteria(exome_ht))
    return exome_ht
Example 3
def filter_kin_ht(
    ht: hl.Table,
    out_summary: io.TextIOWrapper,
    first_degree_pi_hat: float = 0.40,
    grandparent_pi_hat: float = 0.20,
    grandparent_ibd1: float = 0.25,
    grandparent_ibd2: float = 0.15,
) -> hl.Table:
    """
    Filter the kinship table to relationships of grandparents and above.

    :param ht: hl.Table
    :param out_summary: Summary file with a summary statistics and notes
    :param first_degree_pi_hat: Minimum pi_hat threshold to use to filter the kinship table to first degree relatives
    :param grandparent_pi_hat: Minimum pi_hat threshold to use to filter the kinship table to grandparents
    :param grandparent_ibd1: Minimum IBD1 threshold to use to filter the kinship table to grandparents
    :param grandparent_ibd2: Maximum IBD2 threshold to use to filter the kinship table to grandparents
    :return: Table containing only relationships of grandparents and above
    """
    # Filter to anything above the relationship of a grandparent
    ht = ht.filter((ht.pi_hat > first_degree_pi_hat)
                   | ((ht.pi_hat > grandparent_pi_hat)
                      & (ht.ibd1 > grandparent_ibd1)
                      & (ht.ibd2 < grandparent_ibd2)))
    ht = ht.annotate(pair=hl.sorted([ht.i, ht.j]))

    out_summary.write(
        f"NOTE: kinship table was filtered to:\n(kin > {first_degree_pi_hat}) or (kin > {grandparent_pi_hat} and IBD1 > {grandparent_ibd1} and IBD2 < {grandparent_ibd2})\n"
    )
    out_summary.write(
        "Relationships not meeting these criteria were not evaluated\n\n")

    return ht
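
A usage sketch, assuming a kinship Table with PLINK-style `pi_hat`/`ibd1`/`ibd2` columns (paths are hypothetical):

import hail as hl

kin_ht = hl.read_table("gs://my-bucket/kinship.ht")
with hl.hadoop_open("gs://my-bucket/kinship_summary.txt", "w") as out_summary:
    close_kin_ht = filter_kin_ht(kin_ht, out_summary)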
Example 4
def generate_sib_stats(
    mt: hl.MatrixTable,
    relatedness_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    relationship_col: str = "relationship",
    autosomes_only: bool = True,
    bi_allelic_only: bool = True,
) -> hl.Table:
    """
    Default wrapper for `generate_sib_stats_expr`. Returns a Hail Table with counts of variants
    shared by pairs of siblings in `relatedness_ht`.

    This function takes a Hail Table with a row for each pair of related individuals i, j in the data (unrelated pairs may be present too).
    The `relationship_col` should be a column specifying the relationship between each pair of samples as defined by
    the constants in `gnomad.utils.relatedness`. This column is used to filter to only pairs of
    samples annotated as `SIBLINGS`.

    .. note::

        By default this pipeline function will filter `mt` to only autosomes and bi-allelic sites.

    :param mt: Input Matrix table
    :param relatedness_ht: Input relationship table
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param relationship_col: Column containing the relationship for the sample pair as defined in this module's constants.
    :param autosomes_only: If set, only autosomal intervals are used.
    :param bi_allelic_only: If set, only bi-allelic sites are used for the computation
    :return: A Table with the sibling shared variant counts
    """
    if autosomes_only:
        mt = filter_to_autosomes(mt)
    if bi_allelic_only:
        mt = mt.filter_rows(bi_allelic_expr(mt))

    sib_ht = relatedness_ht.filter(
        relatedness_ht[relationship_col] == SIBLINGS)
    s_to_keep = sib_ht.aggregate(
        hl.agg.explode(lambda s: hl.agg.collect_as_set(s),
                       [sib_ht[i_col].s, sib_ht[j_col].s]),
        _localize=False,
    )
    mt = mt.filter_cols(s_to_keep.contains(mt.s))
    if "adj" not in mt.entry:
        mt = annotate_adj(mt)

    sib_stats_ht = mt.select_rows(**generate_sib_stats_expr(
        mt,
        sib_ht,
        i_col=i_col,
        j_col=j_col,
        strata={
            "raw": True,
            "adj": mt.adj
        },
    )).rows()

    return sib_stats_ht
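
A usage sketch (hypothetical paths); `relatedness_ht` is expected to have struct keys `i`/`j` with an `s` field plus a `relationship` column:

import hail as hl

mt = hl.read_matrix_table("gs://my-bucket/genotypes.mt")
relatedness_ht = hl.read_table("gs://my-bucket/relatedness.ht")
sib_stats_ht = generate_sib_stats(mt, relatedness_ht)
sib_stats_ht.show(5)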
Example 5
def default_generate_sib_stats(
    mt: hl.MatrixTable,
    relatedness_ht: hl.Table,
    sex_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    relationship_col: str = "relationship",
) -> hl.Table:
    """
    Default wrapper for `generate_sib_stats_expr`. Returns a Hail Table with counts of variants
    shared by pairs of siblings in `relatedness_ht`.

    This function takes a Hail Table with a row for each pair of related individuals i, j in the data (unrelated pairs may be present too).
    The `relationship_col` should be a column specifying the relationship between each pair of samples as defined by
    the constants in `gnomad.utils.relatedness`. This column is used to filter to only pairs of
    samples annotated as `SIBLINGS`.

    :param mt: Input Matrix table
    :param relatedness_ht: Input relationship table
    :param sex_ht: A Table containing sex information for the samples
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param relationship_col: Column containing the relationship for the sample pair as defined in this module's constants.
    :return: A Table with the sibling shared variant counts
    """
    sex_ht = sex_ht.annotate(
        is_female=hl.case()
        .when(sex_ht.sex_karyotype == "XX", True)
        .when(sex_ht.sex_karyotype == "XY", False)
        .or_missing()
    )

    # TODO: Change to use SIBLINGS constant when relatedness PR goes in
    sib_ht = relatedness_ht.filter(relatedness_ht[relationship_col] == "Siblings")
    s_to_keep = sib_ht.aggregate(
        hl.agg.explode(
            lambda s: hl.agg.collect_as_set(s), [sib_ht[i_col].s, sib_ht[j_col].s]
        ),
        _localize=False,
    )
    mt = mt.filter_cols(s_to_keep.contains(mt.s))
    mt = annotate_adj(mt)

    mt = mt.annotate_cols(is_female=sex_ht[mt.s].is_female)

    sib_stats_ht = mt.select_rows(
        **generate_sib_stats_expr(
            mt,
            sib_ht,
            i_col=i_col,
            j_col=j_col,
            strata={"raw": True, "adj": mt.adj},
            is_female=mt.is_female,
        )
    ).rows()

    return sib_stats_ht
Example 6
def explode_phase_info(ht: hl.Table, remove_all_ref: bool = True) -> hl.Table:
    """
    Explode the `phase_info` dict so that each row carries phase information for a single population.

    :param ht: Table with a `phase_info` dict annotation keyed by population
    :param remove_all_ref: If True, remove rows with no non-ref genotype counts
    :return: Exploded Table with `pop` and `phase_info` annotations
    """
    ht = ht.transmute(phase_info=hl.array(ht.phase_info))
    ht = ht.explode('phase_info')
    ht = ht.transmute(pop=ht.phase_info[0], phase_info=ht.phase_info[1])

    if remove_all_ref:
        ht = ht.filter(hl.sum(ht.phase_info.gt_counts.raw[1:]) > 0)

    return ht
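
A usage sketch, assuming `phase_info` is a dict keyed by population name, as the transmute above implies (path is hypothetical):

import hail as hl

phase_ht = hl.read_table("gs://my-bucket/phase_info.ht")
per_pop_ht = explode_phase_info(phase_ht)  # one row per (variant pair, pop)
per_pop_ht.filter(per_pop_ht.pop == "nfe").show(5)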
Example 7
def filter_ht_for_plink(ht: hl.Table,
                        n_samples: int,
                        min_call_rate: float = 0.95,
                        variants_per_mac_category: int = 2000,
                        variants_per_maf_category: int = 10000) -> hl.Table:
    """
    Filter a sites Table to well-called autosomal variants and downsample to a
    roughly fixed number of variants per MAC/MAF category for PLINK export.
    """
    from gnomad.utils.filtering import filter_to_autosomes
    ht = filter_to_autosomes(ht)
    ht = ht.filter((ht.call_stats.AN >= n_samples * 2 * min_call_rate)
                   & (ht.call_stats.AC > 0))
    ht = ht.annotate(mac_category=mac_category_case_builder(ht.call_stats))
    category_counter = ht.aggregate(hl.agg.counter(ht.mac_category))
    print(category_counter)
    ht = ht.annotate_globals(category_counter=category_counter)
    # Downsample within each category: keep roughly `variants_per_mac_category`
    # variants per MAC category (category values >= 1) and
    # `variants_per_maf_category` per MAF category.
    n_per_category = hl.cond(ht.mac_category >= 1, variants_per_mac_category,
                             variants_per_maf_category)
    return ht.filter(
        hl.rand_unif(0, 1) < n_per_category / ht.category_counter[ht.mac_category])
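
A usage sketch (hypothetical path; `n_samples` is the callset's sample count, used to turn the call-rate cutoff into an AN threshold):

import hail as hl

sites_ht = hl.read_table("gs://my-bucket/variant_callstats.ht")
plink_ht = filter_ht_for_plink(sites_ht, n_samples=50000)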
Example 8
def generic_field_check(
    ht: hl.Table,
    check_description: str,
    display_fields: hl.expr.StructExpression,
    cond_expr: hl.expr.BooleanExpression = None,
    verbose: bool = False,
    show_percent_sites: bool = False,
    n_fail: Optional[int] = None,
    ht_count: Optional[int] = None,
) -> None:
    """
    Check a generic logical condition involving annotations in a Hail Table and print the results to stdout.

    Displays the number of rows (and percent of rows, if `show_percent_sites` is True) in the Table that fail the check: either the precomputed `n_fail`, or the count of rows matching `cond_expr`, i.e. rows that fail to meet the desired condition (`check_description`).
    If that number is 0, the Table passes the check; otherwise, it fails.

    .. note::

        `cond_expr` and `check_description` are opposites and should never be the same.
        E.g., If `cond_expr` filters for instances where the raw AC is less than adj AC,
        then it is checking sites that fail to be the desired condition (`check_description`)
        of having a raw AC greater than or equal to the adj AC.

    :param ht: Table containing annotations to be checked.
    :param check_description: String describing the condition being checked; is displayed in stdout summary message.
    :param display_fields: StructExpression containing annotations to be displayed in case of failure (for troubleshooting purposes); these fields are also displayed if verbose is True.
    :param cond_expr: Optional logical expression referring to annotations in ht to be checked.
    :param verbose: If True, show top values of annotations being checked, including checks that pass; if False, show only top values of annotations that fail checks.
    :param show_percent_sites: Show percentage of sites that fail checks. Default is False.
    :param n_fail: Optional number of sites that fail the conditional checks (previously computed). If not supplied, `cond_expr` is used to filter the Table and obtain the count of sites that fail the checks.
    :param ht_count: Optional number of sites within hail Table (previously computed). If not supplied, a count of sites in the Table is performed.
    :return: None
    """
    if (n_fail is None and cond_expr is None) or (
            n_fail is not None and cond_expr is not None):
        raise ValueError(
            "One and only one of n_fail or cond_expr must be defined!")

    if cond_expr is not None:
        n_fail = ht.filter(cond_expr).count()

    if show_percent_sites and (ht_count is None):
        ht_count = ht.count()

    if n_fail > 0:
        logger.info("Found %d sites that fail %s check:", n_fail,
                    check_description)
        if show_percent_sites:
            logger.info("Percentage of sites that fail: %.2f %%",
                        100 * (n_fail / ht_count))
        if cond_expr is not None:
            ht = ht.filter(cond_expr)
        ht.select(**display_fields).show()
    else:
        logger.info("PASSED %s check", check_description)
        if verbose:
            ht.select(**display_fields).show()
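
A usage sketch: checking that raw AC is never smaller than adj AC (the path and `info` annotation names are hypothetical; note `cond_expr` expresses the failure case, the opposite of `check_description`):

import hail as hl

ht = hl.read_table("gs://my-bucket/release_sites.ht")
generic_field_check(
    ht,
    check_description="AC_raw >= AC_adj",
    display_fields=hl.struct(AC_raw=ht.info.AC_raw, AC_adj=ht.info.AC_adj),
    cond_expr=ht.info.AC_raw < ht.info.AC_adj,
    show_percent_sites=True,
)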
Example 9
def densify_sites(
    mt: hl.MatrixTable,
    sites_ht: hl.Table,
    last_END_positions_ht: hl.Table,
    semi_join_rows: bool = True,
) -> hl.MatrixTable:
    """
    Create a dense version of the input sparse MT at the sites in `sites_ht`, reading the minimal amount of data required.

    Note that only rows that appear both in `mt` and `sites_ht` are returned.

    :param mt: Input sparse MT
    :param sites_ht: Desired sites to densify
    :param last_END_positions_ht: Table storing positions of the furthest ref block (END tag)
    :param semi_join_rows: Whether to filter the MT rows based on semi-join (default, better if sites_ht is large) or based on filter_intervals (better if sites_ht only contains a few sites)
    :return: Dense MT filtered to the sites in `sites_ht`
    """
    logger.info("Computing intervals to densify from sites Table.")
    sites_ht = sites_ht.key_by("locus")
    sites_ht = sites_ht.annotate(
        interval=hl.locus_interval(
            sites_ht.locus.contig,
            last_END_positions_ht[sites_ht.key].last_END_position,
            end=sites_ht.locus.position,
            includes_end=True,
            reference_genome=sites_ht.locus.dtype.reference_genome,
        )
    )
    sites_ht = sites_ht.filter(hl.is_defined(sites_ht.interval))

    if semi_join_rows:
        mt = mt.filter_rows(hl.is_defined(sites_ht.key_by("interval")[mt.locus]))
    else:
        logger.info("Collecting intervals to densify.")
        intervals = sites_ht.interval.collect()

        logger.info(
            "Found %d intervals, totalling %d bp in the dense Matrix.",
            len(intervals),
            sum(interval_length(interval)
                for interval in union_intervals(intervals)),
        )

        mt = hl.filter_intervals(mt, intervals)

    mt = hl.experimental.densify(mt)

    return mt.filter_rows(hl.is_defined(sites_ht[mt.locus]))
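
A usage sketch (hypothetical paths; the last-END Table is assumed precomputed from the same sparse MT, e.g. with gnomad's `compute_last_ref_block_end`):

import hail as hl

sparse_mt = hl.read_matrix_table("gs://my-bucket/sparse.mt")
sites_ht = hl.read_table("gs://my-bucket/qc_sites.ht")
last_end_ht = hl.read_table("gs://my-bucket/last_END_positions.ht")
dense_mt = densify_sites(sparse_mt, sites_ht, last_end_ht)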
Example 10
    def aggregate_contig(ht: hl.Table, contigs: Set[str] = None):
        """
        Aggregates all contigs together and computes number for bins accross the contigs.
        """
        if contigs:
            ht = ht.filter(hl.literal(contigs).contains(ht.contig))

        return ht.group_by(*[k for k in ht.key if k != 'contig']).aggregate(
            min_score=hl.agg.min(ht.min_score),
            max_score=hl.agg.max(ht.max_score),
            **{
                x: hl.agg.sum(ht[x])
                for x in ht.row_value if x not in ['min_score', 'max_score']
            })
Example 11
def get_platform_specific_intervals(platform_pc_loadings_ht: hl.Table,
                                    threshold: float) -> List[hl.Interval]:
    """
    This takes the platform PC loadings and returns a list of intervals where the sum of the loadings above the given threshold.
    The experimental / untested idea behind this, is that those intervals may be problematic on some platforms.

    :param Table platform_pc_loadings_ht: Platform PCA loadings indexed by interval
    :param float threshold: Minimal threshold
    :param str intervals_path: Path to the intervals file to use (default: b37 exome calling intervals)
    :return: List of intervals with PC loadings above the given threshold
    :rtype: list of Interval
    """
    platform_specific_intervals = platform_pc_loadings_ht.filter(
        hl.sum(hl.abs(platform_pc_loadings_ht.loadings)) >= threshold)
    return platform_specific_intervals.interval.collect()
Example 12
def annotate_related_pairs(related_pairs: hl.Table,
                           index_col: str) -> hl.Table:
    # Inner helper from `rank_related_samples` (see Example 17); it relies on
    # `case_parents`, `meta_ht` and `sample_qc_ht` from the enclosing scope.
    related_pairs = related_pairs.key_by(**related_pairs[index_col])
    related_pairs = related_pairs.filter(
        hl.is_missing(case_parents[related_pairs.key]))
    return related_pairs.annotate(
        **{
            index_col:
            related_pairs[index_col].annotate(
                case_rank=hl.or_else(
                    hl.int(meta_ht[related_pairs.key].is_case), -1),
                dp_mean=hl.or_else(
                    sample_qc_ht[
                        related_pairs.key].sample_qc.dp_stats.mean, -1.0))
        }).key_by()
Example 13
def compute_grouped_binned_ht(
    bin_ht: hl.Table,
    checkpoint_path: Optional[str] = None,
) -> hl.GroupedTable:
    """
    Group a Table that has been annotated with bins (`compute_ranked_bin` or `create_binned_ht`).

    The Table is grouped by bin_id (bin, biallelic, etc.), contig, snv, bi_allelic, singleton, release_adj and bin.

    .. note::

        If performing an aggregation following this grouping (such as `score_bin_agg`) then the aggregation
        function will need to use `ht._parent` to get the origin Table from the GroupedTable for the aggregation

    :param bin_ht: Input Table with a `bin_id` annotation
    :param checkpoint_path: If provided an intermediate checkpoint table is created with all required annotations before shuffling.
    :return: Table grouped by bins(s)
    """
    # Explode the rank table by bin_id
    bin_ht = bin_ht.annotate(
        bin_groups=hl.array(
            [
                hl.Struct(bin_id=bin_name, bin=bin_ht[bin_name])
                for bin_name in bin_ht.bin_group_variant_counts
            ]
        )
    )
    bin_ht = bin_ht.explode(bin_ht.bin_groups)
    bin_ht = bin_ht.transmute(
        bin_id=bin_ht.bin_groups.bin_id, bin=bin_ht.bin_groups.bin
    )
    bin_ht = bin_ht.filter(hl.is_defined(bin_ht.bin))

    if checkpoint_path is not None:
        bin_ht = bin_ht.checkpoint(checkpoint_path, overwrite=True)
    else:
        bin_ht = bin_ht.persist()

    # Group by bin_id, bin and additional stratification desired and compute QC metrics per bin
    return bin_ht.group_by(
        bin_id=bin_ht.bin_id,
        contig=bin_ht.locus.contig,
        snv=hl.is_snp(bin_ht.alleles[0], bin_ht.alleles[1]),
        bi_allelic=~bin_ht.was_split,
        singleton=bin_ht.singleton,
        release_adj=bin_ht.ac > 0,
        bin=bin_ht.bin,
    )._set_buffer_size(20000)
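
A usage sketch of the follow-up aggregation described in the note above (the input path and the `score` field are assumptions):

bin_ht = hl.read_table("gs://my-bucket/rf_result_binned.ht")
grouped_ht = compute_grouped_binned_ht(bin_ht, checkpoint_path="gs://my-bucket/tmp/binned.ht")
agg_ht = grouped_ht.aggregate(
    n=hl.agg.count(),
    # Per the note, aggregators reach fields through the grouped table's parent
    min_score=hl.agg.min(grouped_ht._parent.score),
    max_score=hl.agg.max(grouped_ht._parent.score),
)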
Example 14
def test_model(
    ht: hl.Table,
    rf_model: pyspark.ml.PipelineModel,
    features: List[str],
    label: str,
    prediction_col_name: str = "rf_prediction",
) -> List[hl.Struct]:
    """
    A wrapper to test a model on a set of examples with known labels.

    1) Runs the model on the data
    2) Prints confusion matrix and accuracy
    3) Returns confusion matrix as a list of struct

    :param ht: Input table
    :param rf_model: RF Model
    :param features: Columns containing features that were used in the model
    :param label: Column containing label to be predicted
    :param prediction_col_name: Where to store the prediction
    :return: A list containing structs with {label, prediction, n}
    """

    ht = apply_rf_model(
        ht.filter(hl.is_defined(ht[label])),
        rf_model,
        features,
        label,
        prediction_col_name=prediction_col_name,
    )

    test_results = (
        ht.group_by(ht[prediction_col_name], ht[label])
        .aggregate(n=hl.agg.count())
        .collect()
    )

    # Print results
    df = pd.DataFrame(test_results)
    df = df.pivot(index=label, columns=prediction_col_name, values="n")
    logger.info("Testing results:\n{}".format(pprint.pformat(df)))
    logger.info(
        "Accuracy: {}".format(
            sum([x.n for x in test_results if x[label] == x[prediction_col_name]])
            / sum([x.n for x in test_results])
        )
    )

    return test_results
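
A usage sketch (hypothetical names: `rf_model` comes from a prior `train_rf_model` call, and the feature/label columns are examples):

ht = hl.read_table("gs://my-bucket/rf_annotated.ht")
results = test_model(ht, rf_model, features=["QD", "FS", "MQ"], label="tp_label")
for struct in results:
    print(struct.tp_label, struct.rf_prediction, struct.n)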
Example 15
def generic_field_check(
    ht: hl.Table,
    cond_expr: hl.expr.BooleanExpression,
    check_description: str,
    display_fields: List[str],
    verbose: bool,
    show_percent_sites: bool = False,
) -> None:
    """
    Check a generic logical condition involving annotations in a Hail Table and print the results to terminal.

    Displays the number of rows (and percent of rows, if `show_percent_sites` is True) in the Table that match the `cond_expr` and fail to be the desired condition (`check_description`).
    If the number of rows that match the `cond_expr` is 0, then the Table passes that check; otherwise, it fails.

    .. note::

        `cond_expr` and `check_description` are opposites and should never be the same.
        E.g., If `cond_expr` filters for instances where the raw AC is less than adj AC,
        then it is checking sites that fail to be the desired condition (`check_description`)
        of having a raw AC greater than or equal to the adj AC.

    :param ht: Table containing annotations to be checked.
    :param cond_expr: Logical expression referring to annotations in ht to be checked.
    :param check_description: String describing the condition being checked; is displayed in terminal summary message.
    :param display_fields: List of names of ht annotations to be displayed in case of failure (for troubleshooting purposes);
        these fields are also displayed if verbose is True.
    :param verbose: If True, show top values of annotations being checked, including checks that pass; if False,
        show only top values of annotations that fail checks.
    :param show_percent_sites: Show percentage of sites that fail checks. Default is False.
    :return: None
    """
    ht_orig = ht
    ht = ht.filter(cond_expr)
    n_fail = ht.count()
    if n_fail > 0:
        logger.info("Found %d sites that fail %s check:", n_fail,
                    check_description)
        if show_percent_sites:
            logger.info("Percentage of sites that fail: %f",
                        n_fail / ht_orig.count())
        ht = ht.flatten()
        ht.select("locus", "alleles", *display_fields).show()
    else:
        logger.info("PASSED %s check", check_description)
        if verbose:
            ht_orig = ht_orig.flatten()
            ht_orig.select(*display_fields).show()
Example 16
def get_duplicated_samples_ibd(
        kin_ht: hl.Table,
        i_col: str = 'i',
        j_col: str = 'j',
        pi_hat_col: str = 'pi_hat',
        duplicate_threshold: float = 0.90) -> List[Set[str]]:
    """
    Given a ibd output Table, extract the list of duplicate samples. Returns a list of set of samples that are duplicates.
    :param Table kin_ht: ibd output table
    :param str i_col: Column containing the 1st sample
    :param str j_col: Column containing the 2nd sample
    :param str pi_hat_col: Column containing the pi_hat value
    :param float duplicate_threshold: pi_hat threshold to consider two samples duplicated
    :return: List of samples that are duplicates
    :rtype: list of set of str
    """
    def get_all_dups(s, dups, samples_duplicates
                     ):  # should add docstring (sample, empty set, duplicates)
        if s in samples_duplicates:
            dups.add(s)
            # pop returns the set of duplicates recorded for this sample
            s_dups = samples_duplicates.pop(s)
            for s_dup in s_dups:
                if s_dup not in dups:
                    dups = get_all_dups(s_dup, dups, samples_duplicates)
        return dups

    dup_rows = kin_ht.filter(
        kin_ht[pi_hat_col] > duplicate_threshold).collect()

    samples_duplicates = defaultdict(
        set)  # key is a sample, value is the set of that sample's duplicates
    for row in dup_rows:
        samples_duplicates[row[i_col]].add(row[j_col])
        samples_duplicates[row[j_col]].add(row[i_col])

    duplicated_samples = []
    while len(samples_duplicates) > 0:
        duplicated_samples.append(
            get_all_dups(
                list(samples_duplicates)[0], set(), samples_duplicates))

    return duplicated_samples
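
The clustering walk in `get_all_dups` amounts to extracting connected components of the duplicate graph. A pure-Python toy run of the same logic (made-up sample IDs):

from collections import defaultdict

samples_duplicates = defaultdict(set)
for a, b in [("s1", "s2"), ("s2", "s3"), ("s4", "s5")]:
    samples_duplicates[a].add(b)
    samples_duplicates[b].add(a)

clusters = []
while samples_duplicates:
    stack, comp = [next(iter(samples_duplicates))], set()
    while stack:
        s = stack.pop()
        if s in samples_duplicates:
            comp.add(s)
            stack.extend(samples_duplicates.pop(s))
    clusters.append(comp)

print(clusters)  # e.g. [{'s1', 's2', 's3'}, {'s4', 's5'}]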
Example 17
def rank_related_samples(
    relatedness_ht: hl.Table, meta_ht: hl.Table, sample_qc_ht: hl.Table,
    fam_ht: hl.Table
) -> Tuple[hl.Table, Callable[[hl.expr.Expression, hl.expr.Expression],
                              hl.expr.NumericExpression]]:
    # Load families and identify parents of cases, since they will be thrown away anyway
    fam_ht = fam_ht.transmute(trio=[
        hl.struct(s=fam_ht.id, is_parent=False),
        hl.struct(s=fam_ht.pat_id, is_parent=True),
        hl.struct(s=fam_ht.mat_id, is_parent=True)
    ])
    fam_ht = fam_ht.explode(fam_ht.trio)
    fam_ht = fam_ht.key_by(s=fam_ht.trio.s)
    case_parents = fam_ht.filter(meta_ht[fam_ht.key].is_case
                                 & fam_ht.trio.is_parent)

    def annotate_related_pairs(related_pairs: hl.Table,
                               index_col: str) -> hl.Table:
        related_pairs = related_pairs.key_by(**related_pairs[index_col])
        related_pairs = related_pairs.filter(
            hl.is_missing(case_parents[related_pairs.key]))
        return related_pairs.annotate(
            **{
                index_col:
                related_pairs[index_col].annotate(
                    case_rank=hl.or_else(
                        hl.int(meta_ht[related_pairs.key].is_case), -1),
                    dp_mean=hl.or_else(
                        sample_qc_ht[
                            related_pairs.key].sample_qc.dp_stats.mean, -1.0))
            }).key_by()

    relatedness_ht = annotate_related_pairs(relatedness_ht, "i")
    relatedness_ht = annotate_related_pairs(relatedness_ht, "j")

    def tie_breaker(l, r):
        return (hl.case().when(l.case_rank != r.case_rank,
                               r.case_rank - l.case_rank)  # smaller is better
                .default(l.dp_mean - r.dp_mean)  # larger is better
                )

    return relatedness_ht, tie_breaker
Example 18
def get_related_samples_to_drop(rank_table: hl.Table,
                                relatedness_ht: hl.Table) -> hl.Table:
    """
    Use the maximal independence function in Hail to intelligently prune clusters of related individuals, removing
    less desirable samples while maximizing the number of unrelated individuals kept in the sample set

    :param Table rank_table: Table with ranking annotations across exomes and genomes, computed via make_rank_file()
    :param Table relatedness_ht: Table with kinship coefficient annotations computed via pc_relate()
    :return: Table containing sample IDs ('s') to be pruned from the combined exome and genome sample set
    :rtype: Table
    """
    # Define maximal independent set, using rank list
    related_pairs = relatedness_ht.filter(
        relatedness_ht.kin > 0.08838835).select('i', 'j')
    n_related_samples = hl.eval(
        hl.len(
            related_pairs.aggregate(hl.agg.explode(
                lambda x: hl.agg.collect_as_set(x),
                [related_pairs.i, related_pairs.j]),
                                    _localize=False)))
    logger.info(
        '{} samples with at least 2nd-degree relatedness found in callset'.
        format(n_related_samples))
    max_rank = rank_table.count()
    related_pairs = related_pairs.annotate(
        id1_rank=hl.struct(id=related_pairs.i,
                           rank=rank_table[related_pairs.i].rank),
        id2_rank=hl.struct(id=related_pairs.j,
                           rank=rank_table[related_pairs.j].rank)).select(
                               'id1_rank', 'id2_rank')

    def tie_breaker(l, r):
        return hl.or_else(l.rank, max_rank + 1) - hl.or_else(
            r.rank, max_rank + 1)

    related_samples_to_drop_ranked = hl.maximal_independent_set(
        related_pairs.id1_rank,
        related_pairs.id2_rank,
        keep=False,
        tie_breaker=tie_breaker)
    return related_samples_to_drop_ranked.select(
        **related_samples_to_drop_ranked.node.id).key_by('data_type', 's')
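
A usage sketch (hypothetical paths): the rank Table is keyed the same way as the `i`/`j` structs of `relatedness_ht` (here including `data_type` and `s`) and carries a `rank` annotation:

import hail as hl

rank_ht = hl.read_table("gs://my-bucket/sample_ranks.ht")
relatedness_ht = hl.read_table("gs://my-bucket/pc_relate.ht")
samples_to_drop = get_related_samples_to_drop(rank_ht, relatedness_ht)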
Example 19
def filter_ped(raw_ped: hl.Pedigree, mendel: hl.Table, max_dnm: int,
               max_mendel: int) -> hl.Pedigree:
    """
    Filter a Pedigree to trios whose per-trio Mendel-error and de novo counts
    fall below `max_mendel` and `max_dnm`, respectively.
    """
    mendel = mendel.filter(mendel.fam_id.startswith("fake"))
    mendel_by_s = (
        mendel.group_by(mendel.s).aggregate(
            fam_id=hl.agg.take(mendel.fam_id, 1)[0],
            n_mendel=hl.agg.count(),
            n_de_novo=hl.agg.count_where(
                mendel.mendel_code ==
                2),  # Code 2 is parents are hom ref, child is het
        ).persist())

    good_trios = mendel_by_s.aggregate(
        hl.agg.filter(
            (mendel_by_s.n_mendel < max_mendel) &
            (mendel_by_s.n_de_novo < max_dnm),
            hl.agg.collect(mendel_by_s.s, ),
        ))
    logger.info(f"Found {len(good_trios)} trios passing filters")
    return hl.Pedigree(
        [trio for trio in raw_ped.trios if trio.s in good_trios])
Example 20
def get_duplicated_samples(kin_ht: hl.Table,
                           i_col: str = 'i',
                           j_col: str = 'j',
                           kin_col: str = 'kin',
                           duplicate_threshold: float = 0.4) -> List[Set[str]]:
    """
    Given a pc_relate output Table, extract the list of duplicate samples. Returns a list of sets of samples that are duplicates.

    :param Table kin_ht: pc_relate output table
    :param str i_col: Column containing the 1st sample
    :param str j_col: Column containing the 2nd sample
    :param str kin_col: Column containing the kinship value
    :param float duplicate_threshold: Kinship threshold to consider two samples duplicated
    :return: List of samples that are duplicates
    :rtype: list of set of str
    """
    def get_all_dups(s, dups, samples_duplicates):
        if s in samples_duplicates:
            dups.add(s)
            s_dups = samples_duplicates.pop(s)
            for s_dup in s_dups:
                if s_dup not in dups:
                    dups = get_all_dups(s_dup, dups, samples_duplicates)
        return dups

    dup_rows = kin_ht.filter(kin_ht[kin_col] > duplicate_threshold).collect()

    samples_duplicates = defaultdict(set)
    for row in dup_rows:
        samples_duplicates[row[i_col]].add(row[j_col])
        samples_duplicates[row[j_col]].add(row[i_col])

    duplicated_samples = []
    while len(samples_duplicates) > 0:
        duplicated_samples.append(
            get_all_dups(
                list(samples_duplicates)[0], set(), samples_duplicates))

    return duplicated_samples
Example 21
def subset_samples(
    input_mt: hl.MatrixTable,
    pedigree: hl.Table,
    sex_ht: hl.Table,
    output_dir: str,
    output_name: str,
) -> Tuple[hl.MatrixTable, hl.Table, list, list]:
    """
    Filter the MatrixTable and sex Table to only samples in the pedigree.

    :param input_mt: MatrixTable
    :param pedigree: Pedigree file from seqr loaded as a Hail Table
    :param sex_ht: Table of inferred sexes for each sample
    :param output_dir: Path to directory to output results
    :param output_name: Output prefix to use for results
    :return: MatrixTable and sex ht subsetted to the samples given in the pedigree, list of samples in the pedigree, list of samples in the VCF
    """
    # Get sample names to subset from the pedigree
    samples_to_subset = hl.set(pedigree.Individual_ID.collect())
    # Subset mt and ht to samples in the pedigree
    mt_subset = input_mt.filter_cols(samples_to_subset.contains(input_mt["s"]))
    sex_ht = sex_ht.filter(samples_to_subset.contains(sex_ht["s"]))

    # Filter to variants that have at least one alt call after the subsetting
    mt_subset = mt_subset.filter_rows(hl.agg.any(mt_subset.GT.is_non_ref()))

    # Check that the samples in the pedigree are present in the VCF subset and output samples that are missing
    out_missing_samples = hl.hadoop_open(
        f"{output_dir}/{output_name}_missing_samples_in_subset.txt", "w")
    expected_samples = pedigree.Individual_ID.collect()
    vcf_samples = mt_subset.s.collect()

    missing_samples = set(expected_samples) - set(vcf_samples)
    for sample in missing_samples:
        out_missing_samples.write(sample + "\n")
    out_missing_samples.close()

    return (mt_subset, sex_ht, expected_samples, vcf_samples)
Example 22
def infer_families(
    relationship_ht: hl.Table,
    sex: Union[hl.Table, Dict[str, bool]],
    duplicate_samples_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    relationship_col: str = "relationship",
) -> hl.Pedigree:
    """
    This function takes a Hail Table with a row for each pair of related individuals i, j in the data (unrelated pairs may be present too).
    The `relationship_col` should be a column specifying the relationship between each pair of samples as defined in this module's constants.

    This function returns a pedigree containing trios inferred from the data. Family ID can be the same for multiple
    trios if one or more members of the trios are related (e.g. sibs, multi-generational family). Trios are ordered by family ID.

    .. note::

        This function only returns complete trios defined as: one child, one father and one mother (sex is required for both parents).

    :param relationship_ht: Input relationship table
    :param sex: A Table or dict giving the sex for each sample (`True`=female, `False`=male). If a Table is given, it should have a field `is_female`.
    :param duplicate_samples_ht: All duplicated samples TO REMOVE (if not provided, this function won't work as it assumes that each child has exactly two parents)
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param relationship_col: Column containing the relationship for the sample pair as defined in this module's constants.
    :return: Pedigree of complete trios
    """
    def group_parent_child_pairs_by_fam(
        parent_child_pairs: Iterable[Tuple[str, str]]
    ) -> List[List[Tuple[str, str]]]:
        """
        Takes all parent-children pairs and groups them by family.
        A family here is defined as a list of sample-pairs which all share at least one sample with at least one other sample-pair in the list.

        :param parent_child_pairs: All the parent-children pairs
        :return: A list of families, where each element of the list is a list of the parent-children pairs
        """
        fam_id = 1  # stores the current family id
        s_fam = dict()  # stores the family id for each sample
        fams = defaultdict(list)  # stores fam_id -> sample-pairs
        for pair in parent_child_pairs:
            if pair[0] in s_fam:
                if pair[1] in s_fam:
                    if (
                            s_fam[pair[0]] != s_fam[pair[1]]
                    ):  # If both samples are in different families, merge the families
                        new_fam_id = s_fam[pair[0]]
                        fam_id_to_merge = s_fam[pair[1]]
                        for s in s_fam:
                            if s_fam[s] == fam_id_to_merge:
                                s_fam[s] = new_fam_id
                        fams[new_fam_id].extend(fams.pop(fam_id_to_merge))
                else:  # If only the 1st sample in the pair is already in a family, assign the 2nd sample in the pair to the same family
                    s_fam[pair[1]] = s_fam[pair[0]]
                fams[s_fam[pair[0]]].append(pair)
            elif (
                    pair[1] in s_fam
            ):  # If only the 2nd sample in the pair is already in a family, assign the 1st sample in the pair to the same family
                s_fam[pair[0]] = s_fam[pair[1]]
                fams[s_fam[pair[1]]].append(pair)
            else:  # If none of the samples in the pair is already in a family, create a new family
                s_fam[pair[0]] = fam_id
                s_fam[pair[1]] = fam_id
                fams[fam_id].append(pair)
                fam_id += 1

        return list(fams.values())

    def get_trios(
        fam_id: str,
        parent_child_pairs: List[Tuple[str, str]],
        related_pairs: Dict[Tuple[str, str], str],
    ) -> List[hl.Trio]:
        """
        Generate trios based on the list of parent-child pairs in the family and
        all related pairs in the data. Only complete parent/offspring trios are included in the results.

        The trios are assembled as follows:
        1. All pairs of unrelated samples with different sexes within the family are extracted as possible parent pairs
        2. For each possible parent pair, a list of all children is constructed (each child in the list has a parent-offspring pair with each parent)
        3. If there are multiple children for a given parent pair, all children should be siblings with each other
        4. Check that each child was only assigned a single pair of parents. If a child is found to have multiple parent pairs, they are ALL discarded.

        :param fam_id: The family ID
        :param parent_child_pairs: The parent-child pairs for this family
        :param related_pairs: All related sample pairs in the data
        :return: List of trios in the family
        """
        def get_possible_parents(samples: List[str]) -> List[Tuple[str, str]]:
            """
            1. All pairs of unrelated samples with different sexes within the family are extracted as possible parent pairs

            :param samples: All samples in the family
            :return: Possible parent pairs
            """
            possible_parents = []
            for i in range(len(samples)):
                for j in range(i + 1, len(samples)):
                    if (related_pairs.get(
                            tuple(sorted([samples[i], samples[j]]))) is None):
                        if sex.get(samples[i]) is False and sex.get(
                                samples[j]) is True:
                            possible_parents.append((samples[i], samples[j]))
                        elif (sex.get(samples[i]) is True
                              and sex.get(samples[j]) is False):
                            possible_parents.append((samples[j], samples[i]))
            return possible_parents

        def get_children(possible_parents: Tuple[str, str]) -> List[str]:
            """
            2. For a given possible parent pair, a list of all children is constructed (each child in the list has a parent-offspring pair with each parent)

            :param possible_parents: A pair of possible parents
            :return: The list of all children (if any) corresponding to the possible parents
            """
            possible_offsprings = defaultdict(
                set
            )  # stores sample -> set of parents in the possible_parents where (sample, parent) is found in possible_child_pairs
            for pair in parent_child_pairs:
                if possible_parents[0] == pair[0]:
                    possible_offsprings[pair[1]].add(possible_parents[0])
                elif possible_parents[0] == pair[1]:
                    possible_offsprings[pair[0]].add(possible_parents[0])
                elif possible_parents[1] == pair[0]:
                    possible_offsprings[pair[1]].add(possible_parents[1])
                elif possible_parents[1] == pair[1]:
                    possible_offsprings[pair[0]].add(possible_parents[1])

            return [
                s for s, parents in possible_offsprings.items()
                if len(parents) == 2
            ]

        def check_sibs(children: List[str]) -> bool:
            """
            3. If there are multiple children for a given parent pair, all children should be siblings with each other

            :param children: List of all children for a given parent pair
            :return: Whether all children in the list are siblings
            """
            for i in range(len(children)):
                for j in range(i + 1, len(children)):
                    if (related_pairs[tuple(sorted([children[i], children[j]
                                                    ]))] != SIBLINGS):
                        return False
            return True

        def discard_multi_parents_children(trios: List[hl.Trio]):
            """
            4. Check that each child was only assigned a single pair of parents. If a child is found to have multiple parent pairs, they are ALL discarded.

            :param trios: All trios formed for this family
            :return: The list of trios for which each child has a single parent pair.
            """
            children_trios = defaultdict(list)
            for trio in trios:
                children_trios[trio.s].append(trio)

            for s, s_trios in children_trios.items():
                if len(s_trios) > 1:
                    logger.warning(
                        "Discarded duplicated child {0} found in multiple trios: {1}"
                        .format(s, ", ".join([str(trio) for trio in s_trios])))

            return [
                trios[0] for trios in children_trios.values()
                if len(trios) == 1
            ]

        # Get all possible pairs of parents in (father, mother) order
        all_possible_parents = get_possible_parents(
            list({s
                  for pair in parent_child_pairs for s in pair}))

        trios = []
        for possible_parents in all_possible_parents:
            children = get_children(possible_parents)
            if check_sibs(children):
                trios.extend([
                    hl.Trio(
                        s=s,
                        fam_id=fam_id,
                        pat_id=possible_parents[0],
                        mat_id=possible_parents[1],
                        is_female=sex.get(s),
                    ) for s in children
                ])
            else:
                logger.warning(
                    "Discarded family with same parents, and multiple offspring that weren't siblings:"
                    "\nFather: {}\nMother: {}\nChildren: {}".format(
                        possible_parents[0], possible_parents[1],
                        ", ".join(children)))

        return discard_multi_parents_children(trios)

    # Get all the relations we care about:
    # => Remove unrelateds and duplicates
    dups = duplicate_samples_ht.aggregate(
        hl.agg.explode(lambda dup: hl.agg.collect_as_set(dup),
                       duplicate_samples_ht.filtered),
        _localize=False,
    )
    relationship_ht = relationship_ht.filter(
        ~dups.contains(relationship_ht[i_col])
        & ~dups.contains(relationship_ht[j_col])
        & (relationship_ht[relationship_col] != UNRELATED))

    # Check relatedness table format
    if not relationship_ht[i_col].dtype == relationship_ht[j_col].dtype:
        logger.error(
            "i_col and j_col of the relatedness table need to be of the same type."
        )

    # If i_col and j_col aren't str, then convert them
    if not isinstance(relationship_ht[i_col], hl.expr.StringExpression):
        logger.warning(
            f"Pedigrees can only be constructed from string IDs, but your relatedness_ht ID column is of type: {relationship_ht[i_col].dtype}. Expression will be converted to string in Pedigrees."
        )
        if isinstance(relationship_ht[i_col], hl.expr.StructExpression):
            logger.warning(
                f"Struct fields {list(relationship_ht[i_col])} will be joined by underscores to use as sample names in Pedigree."
            )
            relationship_ht = relationship_ht.key_by(
                **{
                    i_col:
                    hl.delimit(
                        hl.array([
                            hl.str(relationship_ht[i_col][x])
                            for x in relationship_ht[i_col]
                        ]),
                        "_",
                    ),
                    j_col:
                    hl.delimit(
                        hl.array([
                            hl.str(relationship_ht[j_col][x])
                            for x in relationship_ht[j_col]
                        ]),
                        "_",
                    ),
                })
        else:
            raise NotImplementedError(
                "The `i_col` and `j_col` columns of the `relationship_ht` argument passed to infer_families are not of type StringExpression or Struct."
            )

    # If sex is a Table, extract sex information as a Dict
    if isinstance(sex, hl.Table):
        sex = dict(hl.tuple([sex.s, sex.is_female]).collect())

    # Collect all related sample pairs and
    # create a dictionary with pairs as keys and relationships as values.
    # Sample pairs are tuples ordered by sample name.
    related_pairs = {
        tuple(sorted([i, j])): rel
        for i, j, rel in hl.tuple([
            relationship_ht.i, relationship_ht.j, relationship_ht.relationship
        ]).collect()
    }

    parent_child_pairs_by_fam = group_parent_child_pairs_by_fam(
        [pair for pair, rel in related_pairs.items() if rel == PARENT_CHILD])
    return hl.Pedigree([
        trio for fam_index, parent_child_pairs in enumerate(
            parent_child_pairs_by_fam) for trio in get_trios(
                str(fam_index), parent_child_pairs, related_pairs)
    ])
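
A usage sketch (hypothetical paths). Per the aggregation above, `duplicate_samples_ht` is assumed to carry a `filtered` array column listing duplicate samples to remove:

import hail as hl

relationship_ht = hl.read_table("gs://my-bucket/relatedness_annotated.ht")
sex_ht = hl.read_table("gs://my-bucket/sex.ht")  # keyed by s, with is_female
dups_ht = hl.read_table("gs://my-bucket/duplicates.ht")
ped = infer_families(relationship_ht, sex_ht, dups_ht)
ped.write("gs://my-bucket/inferred.fam")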
Example 23
def liftover_intervals(t: hl.Table,
                       keep_missing_interval: bool = False) -> hl.Table:
    """
    Liftover locus in intervals from one coordinate system (hg37) to another (hg38)

    # Example input table description
    #
    # ----------------------------------------
    # Global fields:
    #     None
    # ----------------------------------------
    # Row fields:
    #     'interval': interval<locus<GRCh37>>
    # ----------------------------------------
    # Key: ['interval']
    # ----------------------------------------


    :param t: Table of intervals on GRCh37
    :param keep_missing_interval: If True, keep missing (non-lifted) intervals in the output Table.
    :return: Table with intervals lifted over to GRCh38.
    """

    rg37 = hl.get_reference("GRCh37")
    rg38 = hl.get_reference("GRCh38")

    if not rg37.has_liftover("GRCh38"):
        rg37.add_liftover(
            f'{nfs_dir}/resources/liftover/grch37_to_grch38.over.chain.gz',
            rg38)

    t = t.annotate(
        start=hl.liftover(t.interval.start, "GRCh38"),
        end=hl.liftover(t.interval.end, "GRCh38"),
    )

    t = t.filter((t.start.contig == "chr" + t.interval.start.contig)
                 & (t.end.contig == "chr" + t.interval.end.contig))

    t = t.key_by()

    t = (t.select(interval=hl.locus_interval(t.start.contig,
                                             t.start.position,
                                             t.end.position,
                                             reference_genome=rg38,
                                             invalid_missing=True),
                  interval_hg37=t.interval))

    # Count intervals that failed to lift over
    missing = t.aggregate(hl.agg.counter(~hl.is_defined(t.interval)))
    n_missing = missing.get(True, 0)
    logger.info(
        f"Number of missing intervals: {n_missing} out of {t.count()}...")

    # update globals annotations
    global_ann_expr = {
        'date': current_date(),
        'reference_genome': 'GRCh38',
        'was_lifted': True
    }
    t = t.annotate_globals(**global_ann_expr)

    if not keep_missing_interval:
        logger.info(f"Filtering out {n_missing} missing intervals...")
        t = t.filter(hl.is_defined(t.interval), keep=True)

    return t.key_by("interval")
Example 24
def train_rf(
    ht: hl.Table,
    fp_to_tp: float = 1.0,
    num_trees: int = 500,
    max_depth: int = 5,
    no_transmitted_singletons: bool = False,
    no_inbreeding_coeff: bool = False,
    vqsr_training: bool = False,
    vqsr_model_id: Optional[str] = None,
    filter_centromere_telomere: bool = False,
    test_intervals: Union[str, List[str]] = "chr20",
):
    """
    Train random forest model using `train_rf_model`

    :param ht: Table containing annotations needed for RF training, built with `create_rf_ht`
    :param fp_to_tp: Ratio of FPs to TPs for creating the RF model. If set to 0, all training examples are used.
    :param num_trees: Number of trees in the RF model.
    :param max_depth: Maximum tree depth in the RF model.
    :param no_transmitted_singletons: Do not use transmitted singletons for training.
    :param no_inbreeding_coeff: Do not use inbreeding coefficient as a feature for training.
    :param vqsr_training: Use VQSR training sites to train the RF.
    :param vqsr_model_id: VQSR model to use for vqsr_training. `vqsr_training` must be True for this parameter to be used.
    :param filter_centromere_telomere: Filter centromeres and telomeres before training.
    :param test_intervals: Specified interval(s) will be held out for testing and evaluation only. Default is "chr20".
    :return: `ht` annotated with training information and the RF model
    """
    features = FEATURES.copy()  # copy so removing a feature doesn't mutate the module-level list
    if no_inbreeding_coeff:
        logger.info("Removing InbreedingCoeff from list of features...")
        features.remove("InbreedingCoeff")

    if vqsr_training:
        logger.info("Using VQSR training sites for RF training...")
        vqsr_ht = get_vqsr_filters(vqsr_model_id, split=True).ht()
        ht = ht.annotate(
            vqsr_POSITIVE_TRAIN_SITE=vqsr_ht[ht.key].info.POSITIVE_TRAIN_SITE,
            vqsr_NEGATIVE_TRAIN_SITE=vqsr_ht[ht.key].info.NEGATIVE_TRAIN_SITE,
        )
        tp_expr = ht.vqsr_POSITIVE_TRAIN_SITE
        fp_expr = ht.vqsr_NEGATIVE_TRAIN_SITE
    else:
        fp_expr = ht.fail_hard_filters
        tp_expr = ht.omni | ht.mills
        if not no_transmitted_singletons:
            tp_expr = tp_expr | ht.transmitted_singleton

    if test_intervals:
        if isinstance(test_intervals, str):
            test_intervals = [test_intervals]
        test_intervals = [
            hl.parse_locus_interval(x, reference_genome="GRCh38")
            for x in test_intervals
        ]

    ht = ht.annotate(tp=tp_expr, fp=fp_expr)

    if filter_centromere_telomere:
        logger.info("Filtering centromeres and telomeres from HT...")
        rf_ht = ht.filter(~hl.is_defined(telomeres_and_centromeres.ht()[ht.locus]))
    else:
        rf_ht = ht

    rf_ht, rf_model = train_rf_model(
        rf_ht,
        rf_features=features,
        tp_expr=rf_ht.tp,
        fp_expr=rf_ht.fp,
        fp_to_tp=fp_to_tp,
        num_trees=num_trees,
        max_depth=max_depth,
        test_expr=hl.literal(test_intervals).any(
            lambda interval: interval.contains(rf_ht.locus)
        ),
    )

    logger.info("Joining original RF Table with training information")
    ht = ht.join(rf_ht, how="left")

    return ht, rf_model
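
A usage sketch (hypothetical path; `FEATURES` and the resource helpers come from the surrounding module):

ht = hl.read_table("gs://my-bucket/rf_annotations.ht")  # built with create_rf_ht
rf_ht, rf_model = train_rf(ht, fp_to_tp=1.0, filter_centromere_telomere=True)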
Example 25
def create_binned_data_initial(ht: hl.Table, data: str, data_type: str, n_bins: int) -> hl.Table:
    # Count variants for ranking
    count_expr = {x: hl.agg.filter(hl.is_defined(ht[x]), hl.agg.counter(hl.cond(hl.is_snp(
        ht.alleles[0], ht.alleles[1]), 'snv', 'indel'))) for x in ht.row if x.endswith('rank')}
    rank_variant_counts = ht.aggregate(hl.Struct(**count_expr))
    logger.info(
        f"Found the following variant counts:\n {pformat(rank_variant_counts)}")
    ht_truth_data = hl.read_table(
        f"{temp_dir}/ddd-elgh-ukbb/variant_qc/truthset_table.ht")
    ht = ht.annotate_globals(rank_variant_counts=rank_variant_counts)
    ht = ht.annotate(
        **ht_truth_data[ht.key],
        # **fam_ht[ht.key],
        # **gnomad_ht[ht.key],
        # **denovo_ht[ht.key],
        # clinvar=hl.is_defined(clinvar_ht[ht.key]),
        indel_length=hl.abs(ht.alleles[0].length()-ht.alleles[1].length()),
        rank_bins=hl.array(
            [hl.Struct(
                rank_id=rank_name,
                bin=hl.int(hl.ceil(hl.float(ht[rank_name] + 1) / hl.floor(ht.globals.rank_variant_counts[rank_name][hl.cond(
                    hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv', 'indel')] / n_bins)))
            )
                for rank_name in rank_variant_counts]
        ),
        # lcr=hl.is_defined(lcr_intervals[ht.locus])
    )

    ht = ht.explode(ht.rank_bins)
    ht = ht.transmute(
        rank_id=ht.rank_bins.rank_id,
        bin=ht.rank_bins.bin
    )
    ht = ht.filter(hl.is_defined(ht.bin))

    ht = ht.checkpoint(
        f'{tmp_dir}/gnomad_score_binning_tmp.ht', overwrite=True)

    # Create binned data
    return (
        ht
        .group_by(
            rank_id=ht.rank_id,
            contig=ht.locus.contig,
            snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
            bi_allelic=hl.is_defined(ht.biallelic_rank),
            singleton=ht.transmitted_singleton,
            trans_singletons=hl.is_defined(ht.singleton_rank),
            de_novo_high_quality=ht.de_novo_high_quality_rank,
            de_novo_medium_quality=hl.is_defined(
                ht.de_novo_medium_quality_rank),
            de_novo_synonymous=hl.is_defined(ht.de_novo_synonymous_rank),
            # release_adj=ht.ac > 0,
            bin=ht.bin
        )._set_buffer_size(20000)
        .aggregate(
            min_score=hl.agg.min(ht.score),
            max_score=hl.agg.max(ht.score),
            n=hl.agg.count(),
            n_ins=hl.agg.count_where(
                hl.is_insertion(ht.alleles[0], ht.alleles[1])),
            n_del=hl.agg.count_where(
                hl.is_deletion(ht.alleles[0], ht.alleles[1])),
            n_ti=hl.agg.count_where(hl.is_transition(
                ht.alleles[0], ht.alleles[1])),
            n_tv=hl.agg.count_where(hl.is_transversion(
                ht.alleles[0], ht.alleles[1])),
            n_1bp_indel=hl.agg.count_where(ht.indel_length == 1),
            n_mod3bp_indel=hl.agg.count_where((ht.indel_length % 3) == 0),
            # n_clinvar=hl.agg.count_where(ht.clinvar),
            n_singleton=hl.agg.count_where(ht.transmitted_singleton),
            n_high_quality_de_novos=hl.agg.count_where(
                ht.de_novo_data.p_de_novo[0] > 0.99),
            n_validated_DDD_denovos=hl.agg.count_where(
                ht.inheritance.contains("De novo")),
            n_medium_quality_de_novos=hl.agg.count_where(
                ht.de_novo_data.p_de_novo[0] > 0.5),
            n_high_confidence_de_novos=hl.agg.count_where(
                ht.de_novo_data.confidence[0] == 'HIGH'),
            n_de_novo=hl.agg.filter(ht.family_stats.unrelated_qc_callstats.AC[0][1] == 0, hl.agg.sum(
                ht.family_stats.mendel[0].errors)),
            n_high_quality_de_novos_synonymous=hl.agg.count_where(
                (ht.de_novo_data.p_de_novo[0] > 0.99) & (ht.consequence == "synonymous_variant")),
            # n_de_novo_no_lcr=hl.agg.filter(~ht.lcr & (
            #    ht.family_stats.unrelated_qc_callstats.AC[1] == 0), hl.agg.sum(ht.family_stats.mendel.errors)),
            n_de_novo_sites=hl.agg.filter(ht.family_stats.unrelated_qc_callstats.AC[0][1] == 0, hl.agg.count_where(
                ht.family_stats.mendel[0].errors > 0)),
            # n_de_novo_sites_no_lcr=hl.agg.filter(~ht.lcr & (
            #    ht.family_stats.unrelated_qc_callstats.AC[1] == 0), hl.agg.count_where(ht.family_stats.mendel.errors > 0)),
            n_trans_singletons=hl.agg.filter((ht.ac_raw < 3) & (
                ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1), hl.agg.sum(ht.family_stats.tdt[0].t)),
            n_trans_singletons_synonymous=hl.agg.filter((ht.ac_raw < 3) & (ht.consequence == "synonymous_variant") & (
                ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1), hl.agg.sum(ht.family_stats.tdt[0].t)),
            n_untrans_singletons=hl.agg.filter((ht.ac_raw < 3) & (
                ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1), hl.agg.sum(ht.family_stats.tdt[0].u)),
            n_untrans_singletons_synonymous=hl.agg.filter((ht.ac_raw < 3) & (ht.consequence == "synonymous_variant") & (
                ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1), hl.agg.sum(ht.family_stats.tdt[0].u)),
            n_train_trans_singletons=hl.agg.count_where(
                (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1) & (ht.family_stats.tdt[0].t == 1)),
            n_omni=hl.agg.count_where(ht.omni),
            n_mills=hl.agg.count_where(ht.mills),
            n_hapmap=hl.agg.count_where(ht.hapmap),
            n_kgp_high_conf_snvs=hl.agg.count_where(
                ht.kgp_phase1_hc),
            fail_hard_filters=hl.agg.count_where(ht.fail_hard_filters),
            # n_vqsr_pos_train=hl.agg.count_where(ht.vqsr_positive_train_site),
            # n_vqsr_neg_train=hl.agg.count_where(ht.vqsr_negative_train_site)
        )
    )
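# Usage sketch, assuming the binned table produced above was written to the
# (hypothetical) path below: the transmitted/untransmitted singleton ratio per
# bin is a common precision proxy when evaluating a variant-QC score. Field
# names match the aggregate above.
import hail as hl

binned_ht = hl.read_table("gs://my-bucket/binned_scores.ht")  # assumed path
binned_ht = binned_ht.annotate(
    trans_untrans_ratio=binned_ht.n_trans_singletons
    / hl.max(1, binned_ht.n_untrans_singletons))
binned_ht.select("trans_untrans_ratio").show()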
Example no. 26
def get_summary_counts(
    ht: hl.Table,
    freq_field: str = "freq",
    filter_field: str = "filters",
    filter_decoy: bool = False,
    index: int = 0,
) -> hl.Table:
    """
    Generate a struct with summary counts across variant categories.

    Summary counts:
        - Number of variants
        - Number of indels
        - Number of SNVs
        - Number of LoF variants
        - Number of LoF variants that pass LOFTEE (including with LoF flags)
        - Number of LoF variants that pass LOFTEE without LoF flags
        - Number of OS (other splice) variants annotated by LOFTEE
        - Number of LoF variants that fail LOFTEE filters

    Also annotates Table's globals with total variant counts.

    Before calculating summary counts, function:
        - Filters out low confidence regions
        - Filters to canonical transcripts
        - Uses the most severe consequence

    Assumes that:
        - Input HT is annotated with VEP.
        - Multiallelic variants have been split and/or input HT contains bi-allelic variants only.
        - The field specified by `freq_field` was calculated with `annotate_freq`.
        - Frequency index 0 from `annotate_freq` is the frequency for all pops, calculated on adj genotypes only.

    :param ht: Input Table.
    :param freq_field: Name of field in HT containing frequency annotation (array of structs). Default is "freq".
    :param filter_field: Name of field in HT containing variant filter information. Default is "filters".
    :param filter_decoy: Whether to filter decoy regions. Default is False.
    :param index: Which index of freq_expr to use for annotation. Default is 0.
    :return: Table grouped by frequency bin and aggregated across summary count categories.
    """
    logger.info("Checking if multi-allelic variants have been split...")
    max_alleles = ht.aggregate(hl.agg.max(hl.len(ht.alleles)))
    if max_alleles > 2:
        logger.info(
            "Splitting multi-allelics and VEP transcript consequences...")
        ht = hl.split_multi_hts(ht)

    logger.info("Filtering to PASS variants in high confidence regions...")
    ht = ht.filter(hl.len(ht[filter_field]) == 0)
    ht = filter_low_conf_regions(ht, filter_decoy=filter_decoy)

    logger.info(
        "Filtering to canonical transcripts and getting VEP summary annotations..."
    )
    ht = filter_vep_to_canonical_transcripts(ht)
    ht = get_most_severe_consequence_for_summary(ht)

    logger.info("Annotating with frequency bin information...")
    ht = ht.annotate(freq_bin=freq_bin_expr(ht[freq_field], index))

    logger.info(
        "Annotating HT globals with total counts/total allele counts per variant category..."
    )
    summary_counts = ht.aggregate(
        hl.struct(**get_summary_counts_dict(
            ht.locus,
            ht.alleles,
            ht.lof,
            ht.no_lof_flags,
            ht.most_severe_csq,
            prefix_str="total_",
        )))
    summary_ac_counts = ht.aggregate(
        hl.struct(**get_summary_ac_dict(
            ht[freq_field][index].AC,
            ht.lof,
            ht.no_lof_flags,
            ht.most_severe_csq,
        )))
    ht = ht.annotate_globals(summary_counts=summary_counts.annotate(
        **summary_ac_counts))
    return ht.group_by("freq_bin").aggregate(**get_summary_counts_dict(
        ht.locus,
        ht.alleles,
        ht.lof,
        ht.no_lof_flags,
        ht.most_severe_csq,
    ))
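# Usage sketch, assuming a VEP-annotated, split sites Table at the
# (hypothetical) path below; get_summary_counts is defined above.
import hail as hl

sites_ht = hl.read_table("gs://my-bucket/release_sites.ht")  # assumed path
summary_ht = get_summary_counts(sites_ht, filter_decoy=True)
print(hl.eval(summary_ht.summary_counts))  # global totals across all variants
summary_ht.show()  # per-freq_bin summary counts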
def create_binned_data(ht: hl.Table, data: str, data_type: str,
                       n_bins: int) -> hl.Table:
    """
    Creates binned data from a rank Table, grouped by rank_id (rank, biallelic, etc.), contig, snv, bi_allelic and
    singleton, containing the information needed for evaluation plots.

    :param Table ht: Input rank table
    :param str data: Which data/run hash is being created
    :param str data_type: one of 'exomes' or 'genomes'
    :param int n_bins: Number of bins.
    :return: Binned Table
    :rtype: Table
    """

    # Count variants for ranking
    count_expr = {
        x: hl.agg.filter(
            hl.is_defined(ht[x]),
            hl.agg.counter(
                hl.cond(hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv',
                        'indel')))
        for x in ht.row if x.endswith('rank')
    }
    rank_variant_counts = ht.aggregate(hl.Struct(**count_expr))
    logger.info(
        f"Found the following variant counts:\n {pformat(rank_variant_counts)}"
    )
    ht = ht.annotate_globals(rank_variant_counts=rank_variant_counts)

    # Load external evaluation data
    clinvar_ht = hl.read_table(clinvar_ht_path)
    denovo_ht = get_validated_denovos_ht()
    if data_type == 'exomes':
        denovo_ht = denovo_ht.filter(denovo_ht.gnomad_exomes.high_quality)
    else:
        denovo_ht = denovo_ht.filter(denovo_ht.gnomad_genomes.high_quality)
    denovo_ht = denovo_ht.select(
        validated_denovo=denovo_ht.validated,
        high_confidence_denovo=denovo_ht.Confidence == 'HIGH')
    ht_truth_data = hl.read_table(annotations_ht_path(data_type, 'truth_data'))
    fam_ht = hl.read_table(annotations_ht_path(data_type, 'family_stats'))
    fam_ht = fam_ht.select(family_stats=fam_ht.family_stats[0])
    gnomad_ht = get_gnomad_data(data_type).rows()
    gnomad_ht = gnomad_ht.select(
        vqsr_negative_train_site=gnomad_ht.info.NEGATIVE_TRAIN_SITE,
        vqsr_positive_train_site=gnomad_ht.info.POSITIVE_TRAIN_SITE,
        fail_hard_filters=(gnomad_ht.info.QD < 2) | (gnomad_ht.info.FS > 60) |
        (gnomad_ht.info.MQ < 30))
    lcr_intervals = hl.import_locus_intervals(lcr_intervals_path)

    ht = ht.annotate(
        **ht_truth_data[ht.key],
        **fam_ht[ht.key],
        **gnomad_ht[ht.key],
        **denovo_ht[ht.key],
        clinvar=hl.is_defined(clinvar_ht[ht.key]),
        indel_length=hl.abs(ht.alleles[0].length() - ht.alleles[1].length()),
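        # Assign each variant to one of `n_bins` equal-size rank bins per
        # variant class: bin = ceil((rank + 1) / floor(n_variants / n_bins)),
        # where n_variants is the SNV or indel count for that rank_id.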
        rank_bins=hl.array([
            hl.Struct(
                rank_id=rank_name,
                bin=hl.int(
                    hl.ceil(
                        hl.float(ht[rank_name] + 1) / hl.floor(
                            ht.globals.rank_variant_counts[rank_name][hl.cond(
                                hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv',
                                'indel')] / n_bins))))
            for rank_name in rank_variant_counts
        ]),
        lcr=hl.is_defined(lcr_intervals[ht.locus]))

    ht = ht.explode(ht.rank_bins)
    ht = ht.transmute(rank_id=ht.rank_bins.rank_id, bin=ht.rank_bins.bin)
    ht = ht.filter(hl.is_defined(ht.bin))

    ht = ht.checkpoint(
        f'gs://gnomad-tmp/gnomad_score_binning_{data_type}_tmp_{data}.ht',
        overwrite=True)

    # Create binned data
    return (ht.group_by(
        rank_id=ht.rank_id,
        contig=ht.locus.contig,
        snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
        bi_allelic=hl.is_defined(ht.biallelic_rank),
        singleton=ht.singleton,
        release_adj=ht.ac > 0,
        bin=ht.bin)._set_buffer_size(20000).aggregate(
            min_score=hl.agg.min(ht.score),
            max_score=hl.agg.max(ht.score),
            n=hl.agg.count(),
            n_ins=hl.agg.count_where(
                hl.is_insertion(ht.alleles[0], ht.alleles[1])),
            n_del=hl.agg.count_where(
                hl.is_deletion(ht.alleles[0], ht.alleles[1])),
            n_ti=hl.agg.count_where(
                hl.is_transition(ht.alleles[0], ht.alleles[1])),
            n_tv=hl.agg.count_where(
                hl.is_transversion(ht.alleles[0], ht.alleles[1])),
            n_1bp_indel=hl.agg.count_where(ht.indel_length == 1),
            n_mod3bp_indel=hl.agg.count_where((ht.indel_length % 3) == 0),
            n_clinvar=hl.agg.count_where(ht.clinvar),
            n_singleton=hl.agg.count_where(ht.singleton),
            n_validated_de_novos=hl.agg.count_where(ht.validated_denovo),
            n_high_confidence_de_novos=hl.agg.count_where(
                ht.high_confidence_denovo),
            n_de_novo=hl.agg.filter(
                ht.family_stats.unrelated_qc_callstats.AC[1] == 0,
                hl.agg.sum(ht.family_stats.mendel.errors)),
            n_de_novo_no_lcr=hl.agg.filter(
                ~ht.lcr & (ht.family_stats.unrelated_qc_callstats.AC[1] == 0),
                hl.agg.sum(ht.family_stats.mendel.errors)),
            n_de_novo_sites=hl.agg.filter(
                ht.family_stats.unrelated_qc_callstats.AC[1] == 0,
                hl.agg.count_where(ht.family_stats.mendel.errors > 0)),
            n_de_novo_sites_no_lcr=hl.agg.filter(
                ~ht.lcr & (ht.family_stats.unrelated_qc_callstats.AC[1] == 0),
                hl.agg.count_where(ht.family_stats.mendel.errors > 0)),
            n_trans_singletons=hl.agg.filter(
                (ht.info_ac < 3) &
                (ht.family_stats.unrelated_qc_callstats.AC[1] == 1),
                hl.agg.sum(ht.family_stats.tdt.t)),
            n_untrans_singletons=hl.agg.filter(
                (ht.info_ac < 3) &
                (ht.family_stats.unrelated_qc_callstats.AC[1] == 1),
                hl.agg.sum(ht.family_stats.tdt.u)),
            n_train_trans_singletons=hl.agg.count_where(
                (ht.family_stats.unrelated_qc_callstats.AC[1] == 1)
                & (ht.family_stats.tdt.t == 1)),
            n_omni=hl.agg.count_where(ht.truth_data.omni),
            n_mills=hl.agg.count_where(ht.truth_data.mills),
            n_hapmap=hl.agg.count_where(ht.truth_data.hapmap),
            n_kgp_high_conf_snvs=hl.agg.count_where(
                ht.truth_data.kgp_high_conf_snvs),
            fail_hard_filters=hl.agg.count_where(ht.fail_hard_filters),
            n_vqsr_pos_train=hl.agg.count_where(ht.vqsr_positive_train_site),
            n_vqsr_neg_train=hl.agg.count_where(ht.vqsr_negative_train_site)))
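# Usage sketch: bin a rank Table into 100 score bins for the exomes evaluation
# plots. `rank_ht` and the run hash 'abc123' are illustrative assumptions.
binned_ht = create_binned_data(rank_ht, data="abc123", data_type="exomes", n_bins=100)
binned_ht.write("gs://my-bucket/binned_abc123.ht", overwrite=True)  # assumed path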
Example no. 28
def sample_training_examples(
    ht: hl.Table,
    tp_expr: hl.expr.BooleanExpression,
    fp_expr: hl.expr.BooleanExpression,
    fp_to_tp: float = 1.0,
    test_expr: Optional[hl.expr.BooleanExpression] = None,
) -> hl.Table:
    """
    Returns a Table of all positive and negative training examples in `ht` with an annotation indicating those that
    should be used for training given a true positive (TP) to false positive (FP) ratio.

    The returned Table has the following annotations:
        - train: indicates if the variant should be used for training. A row is given False for the annotation if True
          for `test_expr`, True for both `tp_expr` and `fp_expr`, or it is pruned out to obtain the desired `fp_to_tp` ratio.
        - label: indicates if a variant is a 'TP' or 'FP' and will also be labeled as such for variants defined by `test_expr`.

    .. note::

        - This function does not support multi-allelic variants.
        - The function will give some stats about the TPs/FPs provided (Ti, Tv, indels).

    :param ht: Input Table.
    :param tp_expr: Expression for TP examples.
    :param fp_expr: Expression for FP examples.
    :param fp_to_tp: FP to TP ratio. If set to <= 0, all training examples are used.
    :param test_expr: Optional expression to exclude a set of variants from training set. Still contains TP/FP label annotation.
    :return: Table subset with corresponding TP and FP examples with desired FP to TP ratio.
    """

    ht = ht.select(
        _tp=hl.or_else(tp_expr, False),
        _fp=hl.or_else(fp_expr, False),
        _exclude=False if test_expr is None else test_expr,
    )
    ht = ht.filter(ht._tp | ht._fp).persist()

    # Get stats about TP / FP sets
    def _get_train_counts(ht: hl.Table) -> Tuple[int, int]:
        """
        Determine the number of TP and FP variants in the input Table and report some stats on Ti, Tv, indels.

        :param ht: Input Table
        :return: Counts of TP and FP variants in the table
        """
        train_stats = hl.struct(n=hl.agg.count())

        if "alleles" in ht.row and ht.row.alleles.dtype == hl.tarray(hl.tstr):
            train_stats = train_stats.annotate(
                ti=hl.agg.count_where(
                    hl.expr.is_transition(ht.alleles[0], ht.alleles[1])
                ),
                tv=hl.agg.count_where(
                    hl.expr.is_transversion(ht.alleles[0], ht.alleles[1])
                ),
                indel=hl.agg.count_where(
                    hl.expr.is_indel(ht.alleles[0], ht.alleles[1])
                ),
            )

        # Aggregate per-contig TP/FP counts (plus the Ti/Tv/indel stats) into pandas
        pd_stats = (
            ht.group_by(**{"contig": ht.locus.contig, "tp": ht._tp, "fp": ht._fp})
            .aggregate(**train_stats)
            .to_pandas()
        )

        logger.info(pformat(pd_stats))
        pd_stats = pd_stats.fillna(False)

        # Number of true positive and false positive variants to be sampled for the training set
        n_tp = pd_stats[pd_stats["tp"] & ~pd_stats["fp"]]["n"].sum()
        n_fp = pd_stats[~pd_stats["tp"] & pd_stats["fp"]]["n"].sum()

        return n_tp, n_fp

    n_tp, n_fp = _get_train_counts(ht.filter(~ht._exclude))

    prob_tp = prob_fp = 1.0
    if fp_to_tp > 0:
        desired_fp = fp_to_tp * n_tp
        if desired_fp < n_fp:
            prob_fp = desired_fp / n_fp
        else:
            prob_tp = n_fp / desired_fp
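        # Worked example: with n_tp=1000, n_fp=10000 and fp_to_tp=1.0,
        # desired_fp = 1000 < n_fp, so prob_fp = 0.1 and every TP is kept;
        # if instead n_fp were 500, prob_tp = 500 / 1000 = 0.5, downsampling
        # TPs to match the available FPs.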

        logger.info(
            f"Training examples sampling: tp={prob_tp}*{n_tp}, fp={prob_fp}*{n_fp}"
        )

        train_expr = (
            hl.case(missing_false=True)
            .when(ht._fp & hl.or_else(~ht._tp, True), hl.rand_bool(prob_fp))
            .when(ht._tp & hl.or_else(~ht._fp, True), hl.rand_bool(prob_tp))
            .default(False)
        )
    else:
        train_expr = ~(ht._tp & ht._fp)
        logger.info(f"Using all {n_tp} TP and {n_fp} FP training examples.")

    label_expr = (
        hl.case(missing_false=True)
        .when(ht._tp & hl.or_else(~ht._fp, True), "TP")
        .when(ht._fp & hl.or_else(~ht._tp, True), "FP")
        .default(hl.null(hl.tstr))
    )

    return ht.select(train=train_expr & ~ht._exclude, label=label_expr)
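# Usage sketch: the boolean truth annotations below (`omni`, `mills`,
# `fail_hard_filters`) are assumed to exist on `ht`; any boolean expressions
# work for tp_expr/fp_expr.
ht = sample_training_examples(
    ht,
    tp_expr=ht.omni | ht.mills,
    fp_expr=ht.fail_hard_filters,
    fp_to_tp=1.0,
)
train_ht = ht.filter(ht.train)  # rows actually sampled for training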
Example no. 29
def generate_sib_stats_expr(
    mt: hl.MatrixTable,
    sib_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    strata: Dict[str, hl.expr.BooleanExpression] = {"raw": True},
    is_female: Optional[hl.expr.BooleanExpression] = None,
) -> hl.expr.StructExpression:
    """
    Generates a row-wise expression containing the number of alternate alleles in common between sibling pairs.

    The sibling sharing counts can be stratified by additional filters using `strata`.

    .. note::

        This function expects that the `mt` has either been split or filtered to only bi-allelics.
        If a sample has multiple sibling pairs, only one pair will be counted.

    :param mt: Input matrix table
    :param sib_ht: Table defining sibling pairs with one sample in a col (`i_col`) and the second in another col (`j_col`)
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param strata: Dict with additional strata to use when computing shared sibling variant counts
    :param is_female: An optional boolean expression on `mt` giving the sample's sex (True for female). If not given, counts are only computed for autosomes.
    :return: A Table with the sibling shared variant counts
    """
    def _get_alt_count(locus, gt, is_female):
        """
        Helper method to calculate the alt allele count, using sex info if present.
        """
        if is_female is None:
            return hl.or_missing(locus.in_autosome(), gt.n_alt_alleles())
        return (
            hl.case()
            .when(locus.in_autosome_or_par(), gt.n_alt_alleles())
            .when(
                ~is_female & (locus.in_x_nonpar() | locus.in_y_nonpar()),
                hl.min(1, gt.n_alt_alleles()),
            )
            .when(is_female & locus.in_y_nonpar(), 0)
            .default(0)
        )

    if is_female is None:
        logger.warning(
            "Since no sex expression was given to generate_sib_stats_expr, only variants in autosomes will be counted."
        )

    # If a sample is in sib_ht more than one time, keep only one of the sibling pairs
    # First filter to only samples found in mt to keep as many pairs as possible
    s_to_keep = mt.aggregate_cols(hl.agg.collect_as_set(mt.s), _localize=False)
    sib_ht = sib_ht.filter(
        s_to_keep.contains(sib_ht[i_col].s)
        & s_to_keep.contains(sib_ht[j_col].s))
    sib_ht = sib_ht.add_index("sib_idx")
    sib_ht = sib_ht.annotate(sibs=[sib_ht[i_col].s, sib_ht[j_col].s])
    sib_ht = sib_ht.explode("sibs")
    sib_ht = sib_ht.group_by("sibs").aggregate(
        sib_idx=(hl.agg.take(sib_ht.sib_idx, 1, ordering=sib_ht.sib_idx)[0]))
    sib_ht = sib_ht.group_by(
        sib_ht.sib_idx).aggregate(sibs=hl.agg.collect(sib_ht.sibs))
    sib_ht = sib_ht.filter(hl.len(sib_ht.sibs) == 2).persist()

    logger.info(
        f"Generating sibling variant sharing counts using {sib_ht.count()} pairs."
    )
    sib_ht = sib_ht.explode("sibs").key_by("sibs")[mt.s]

    # Create sibling sharing counters
    sib_stats = hl.struct(
        **{
            f"n_sib_shared_variants_{name}": hl.sum(
                hl.agg.filter(
                    expr,
                    hl.agg.group_by(
                        sib_ht.sib_idx,
                        hl.or_missing(
                            hl.agg.sum(hl.is_defined(mt.GT)) == 2,
                            hl.agg.min(
                                _get_alt_count(mt.locus, mt.GT, is_female)),
                        ),
                    ),
                ).values())
            for name, expr in strata.items()
        })

    sib_stats = sib_stats.annotate(
        **{
            f"ac_sibs_{name}": hl.agg.filter(
                expr & hl.is_defined(sib_ht.sib_idx),
                hl.agg.sum(mt.GT.n_alt_alleles()))
            for name, expr in strata.items()
        })

    return sib_stats
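# Usage sketch: annotate a bi-allelic MatrixTable's rows with sibling sharing
# counts. `sib_pairs_ht` (with struct fields `i` and `j`, each carrying `s`)
# and the `adj` entry flag are assumptions.
mt = mt.annotate_rows(
    sib_stats=generate_sib_stats_expr(
        mt,
        sib_pairs_ht,
        strata={"raw": True, "adj": mt.adj},
    )
)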
Example no. 30
def compute_related_samples_to_drop(
    relatedness_ht: hl.Table,
    rank_ht: hl.Table,
    kin_threshold: float,
    filtered_samples: Optional[hl.expr.SetExpression] = None,
    min_related_hard_filter: Optional[int] = None,
) -> hl.Table:
    """
    Computes a Table with the list of samples to drop (and their global rank) to get the maximal independent set of unrelated samples.

    .. note::

        - `relatedness_ht` should be keyed by exactly two fields of the same type, identifying the pair of samples for each row.
        - `rank_ht` should be keyed by a single key of the same type as a single sample identifier in `relatedness_ht`.

    :param relatedness_ht: relatedness HT, as produced by e.g. pc-relate
    :param rank_ht: Table with a global rank for each sample (smaller is preferred)
    :param kin_threshold: Kinship threshold above which two samples are considered related
    :param filtered_samples: An optional set of samples to exclude (e.g. samples that were hard-filtered). These samples will still appear in the resulting list of samples to drop.
    :param min_related_hard_filter: If provided, any sample related to more samples than this threshold will be filtered out prior to computing the maximal independent set, and will appear in the results.
    :return: A Table with the list of the samples to drop along with their rank.
    """

    # Make sure that the key types are valid
    assert len(list(relatedness_ht.key)) == 2
    assert relatedness_ht.key[0].dtype == relatedness_ht.key[1].dtype
    assert len(list(rank_ht.key)) == 1
    assert relatedness_ht.key[0].dtype == rank_ht.key[0].dtype

    logger.info(
        f"Filtering related samples using a kin threshold of {kin_threshold}")
    relatedness_ht = relatedness_ht.filter(relatedness_ht.kin > kin_threshold)

    filtered_samples_rel = set()
    if min_related_hard_filter is not None:
        logger.info(
            f"Computing samples related to too many individuals (>{min_related_hard_filter}) for exclusion"
        )
        gbi = relatedness_ht.annotate(s=list(relatedness_ht.key))
        gbi = gbi.explode(gbi.s)
        gbi = gbi.group_by(gbi.s).aggregate(n=hl.agg.count())
        filtered_samples_rel = gbi.aggregate(
            hl.agg.filter(gbi.n > min_related_hard_filter,
                          hl.agg.collect_as_set(gbi.s)))
        logger.info(
            f"Found {len(filtered_samples_rel)} samples with too many 1st/2nd degree relatives. These samples will be excluded."
        )

    if filtered_samples is not None:
        filtered_samples_rel = filtered_samples_rel.union(
            relatedness_ht.aggregate(
                hl.agg.explode(
                    lambda s: hl.agg.collect_as_set(s),
                    hl.array(list(relatedness_ht.key)).filter(
                        lambda s: filtered_samples.contains(s)),
                )))

    if len(filtered_samples_rel) > 0:
        filtered_samples_lit = hl.literal(filtered_samples_rel)
        relatedness_ht = relatedness_ht.filter(
            filtered_samples_lit.contains(relatedness_ht.key[0])
            | filtered_samples_lit.contains(relatedness_ht.key[1]),
            keep=False,
        )

    logger.info("Annotating related sample pairs with rank.")
    i, j = list(relatedness_ht.key)
    relatedness_ht = relatedness_ht.key_by(s=relatedness_ht[i])
    relatedness_ht = relatedness_ht.annotate(**{
        i:
        hl.struct(s=relatedness_ht.s, rank=rank_ht[relatedness_ht.key].rank)
    })
    relatedness_ht = relatedness_ht.key_by(s=relatedness_ht[j])
    relatedness_ht = relatedness_ht.annotate(**{
        j:
        hl.struct(s=relatedness_ht.s, rank=rank_ht[relatedness_ht.key].rank)
    })
    relatedness_ht = relatedness_ht.key_by(i, j)
    relatedness_ht = relatedness_ht.drop("s")
    relatedness_ht = relatedness_ht.persist()

    related_samples_to_drop_ht = hl.maximal_independent_set(
        relatedness_ht[i],
        relatedness_ht[j],
        keep=False,
        tie_breaker=lambda l, r: l.rank - r.rank,
    )
    related_samples_to_drop_ht = related_samples_to_drop_ht.key_by()
    related_samples_to_drop_ht = related_samples_to_drop_ht.select(
        **related_samples_to_drop_ht.node)
    related_samples_to_drop_ht = related_samples_to_drop_ht.key_by("s")

    if len(filtered_samples_rel) > 0:
        related_samples_to_drop_ht = related_samples_to_drop_ht.union(
            hl.Table.parallelize(
                [
                    hl.struct(s=s, rank=hl.null(hl.tint64))
                    for s in filtered_samples_rel
                ],
                key="s",
            ))

    return related_samples_to_drop_ht
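# Usage sketch: 0.088 is a commonly used second-degree kinship cutoff, shown
# here for illustration. `relatedness_ht` (e.g. from hl.pc_relate, keyed by
# the sample pair) and `rank_ht` (keyed by `s` with a `rank` field) are
# assumptions.
samples_to_drop = compute_related_samples_to_drop(
    relatedness_ht, rank_ht, kin_threshold=0.088
)
mt = mt.filter_cols(hl.is_defined(samples_to_drop[mt.col_key]), keep=False)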