Example 1
    def _get_train_counts(ht: hl.Table) -> Tuple[int, int]:
        """
        Determine the number of TP and FP variants in the input Table and report some stats on Ti, Tv, indels.

        :param ht: Input Table
        :return: Counts of TP and FP variants in the table
        """
        train_stats = hl.struct(n=hl.agg.count())

        if "alleles" in ht.row and ht.row.alleles.dtype == hl.tarray(hl.tstr):
            train_stats = train_stats.annotate(
                ti=hl.agg.count_where(
                    hl.expr.is_transition(ht.alleles[0], ht.alleles[1])),
                tv=hl.agg.count_where(
                    hl.expr.is_transversion(ht.alleles[0], ht.alleles[1])),
                indel=hl.agg.count_where(
                    hl.expr.is_indel(ht.alleles[0], ht.alleles[1])),
            )

        # Sample training examples
        pd_stats = (ht.group_by(**{
            "contig": ht.locus.contig,
            "tp": ht._tp,
            "fp": ht._fp
        }).aggregate(**train_stats).to_pandas())

        logger.info(pformat(pd_stats))
        pd_stats = pd_stats.fillna(False)

        # Number of true positive and false positive variants to be sampled for the training set
        n_tp = pd_stats[pd_stats["tp"] & ~pd_stats["fp"]]["n"].sum()
        n_fp = pd_stats[~pd_stats["tp"] & pd_stats["fp"]]["n"].sum()

        return n_tp, n_fp
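A minimal usage sketch: the helper reads the hidden `_tp`/`_fp` boolean fields (as annotated by `train_rf_model` in Example 11), so they must be present first. The truth annotations used to define them below (`omni`, `mills`, `fail_hard_filters`) are placeholders:

# Hypothetical usage; `_tp`/`_fp` derived from placeholder truth annotations.
ht = ht.annotate(_tp=ht.omni | ht.mills, _fp=ht.fail_hard_filters)
n_tp, n_fp = _get_train_counts(ht)
logger.info(f"Candidate training variants: {n_tp} TP, {n_fp} FP")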
Example 2
    def aggregate_contig(ht: hl.Table, contigs: Optional[Set[str]] = None):
        """
        Aggregate all contigs together and compute the numbers per bin across the contigs.
        """
        if contigs:
            ht = ht.filter(hl.literal(contigs).contains(ht.contig))

        return ht.group_by(*[k for k in ht.key if k != 'contig']).aggregate(
            min_score=hl.agg.min(ht.min_score),
            max_score=hl.agg.max(ht.max_score),
            **{
                x: hl.agg.sum(ht[x])
                for x in ht.row_value if x not in ['min_score', 'max_score']
            })
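A brief usage sketch, assuming a per-contig binned Table whose key includes 'contig' and whose contig names follow GRCh37 conventions:

# Hypothetical usage: collapse per-contig bins over the autosomes only.
autosomes = {str(chrom) for chrom in range(1, 23)}
agg_ht = aggregate_contig(binned_ht, contigs=autosomes)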
Example 3
def generate_clusters_map(ht: hl.Table) -> hl.Table:
    """
    Generate a table mapping genes/features -> cluster_ids.
    Expects as input a two-field Hail Table: f0: cluster id, f1: gene/feature symbol.

    :param ht: Input hl.Table
    :return: hl.Table keyed by gene
    """

    # rename fields
    ht = (ht.rename({'f0': 'cluster_id', 'f1': 'gene'}))
    clusters = (ht.group_by('gene').aggregate(cluster_id=hl.agg.collect_as_set(
        ht.cluster_id)).repartition(50).key_by('gene'))

    return clusters
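A usage sketch with a placeholder path; importing a headerless two-column TSV yields the `f0`/`f1` fields the function expects:

# Hypothetical usage: build the gene -> cluster_ids map from a TSV.
ht = hl.import_table('gs://my-bucket/clusters.tsv', no_header=True)
clusters = generate_clusters_map(ht)
clusters.show()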
Example 4
def compute_grouped_binned_ht(
    bin_ht: hl.Table,
    checkpoint_path: Optional[str] = None,
) -> hl.GroupedTable:
    """
    Group a Table that has been annotated with bins (`compute_ranked_bin` or `create_binned_ht`).

    The table will be grouped by bin_id (bin, biallelic, etc.), contig, snv, bi_allelic and singleton.

    .. note::

        If performing an aggregation following this grouping (such as `score_bin_agg`), the aggregation
        function will need to use `ht._parent` to get the original Table from the GroupedTable.

    :param bin_ht: Input Table with a `bin_id` annotation
    :param checkpoint_path: If provided, an intermediate checkpoint table is created with all required annotations before shuffling.
    :return: Table grouped by bin(s)
    """
    # Explode the rank table by bin_id
    bin_ht = bin_ht.annotate(
        bin_groups=hl.array(
            [
                hl.Struct(bin_id=bin_name, bin=bin_ht[bin_name])
                for bin_name in bin_ht.bin_group_variant_counts
            ]
        )
    )
    bin_ht = bin_ht.explode(bin_ht.bin_groups)
    bin_ht = bin_ht.transmute(
        bin_id=bin_ht.bin_groups.bin_id, bin=bin_ht.bin_groups.bin
    )
    bin_ht = bin_ht.filter(hl.is_defined(bin_ht.bin))

    if checkpoint_path is not None:
        # `checkpoint` returns the Table read back from disk, so assign the result
        bin_ht = bin_ht.checkpoint(checkpoint_path, overwrite=True)
    else:
        bin_ht = bin_ht.persist()

    # Group by bin_id, bin and additional stratification desired and compute QC metrics per bin
    return bin_ht.group_by(
        bin_id=bin_ht.bin_id,
        contig=bin_ht.locus.contig,
        snv=hl.is_snp(bin_ht.alleles[0], bin_ht.alleles[1]),
        bi_allelic=~bin_ht.was_split,
        singleton=bin_ht.singleton,
        release_adj=bin_ht.ac > 0,
        bin=bin_ht.bin,
    )._set_buffer_size(20000)
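Following the note in the docstring, a sketch of aggregating the returned GroupedTable through `_parent`, assuming `bin_ht` was produced by `compute_ranked_bin` and carries a `score` row field:

# Hypothetical usage: group by bin, then aggregate via the parent Table.
grouped_ht = compute_grouped_binned_ht(bin_ht)
parent_ht = grouped_ht._parent
agg_ht = grouped_ht.aggregate(
    n=hl.agg.count(),
    min_score=hl.agg.min(parent_ht.score),  # assumes a `score` row field
)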
Example 5
def test_model(
    ht: hl.Table,
    rf_model: pyspark.ml.PipelineModel,
    features: List[str],
    label: str,
    prediction_col_name: str = "rf_prediction",
) -> List[hl.Struct]:
    """
    A wrapper to test a model on a set of examples with known labels.

    1) Runs the model on the data
    2) Prints confusion matrix and accuracy
    3) Returns the confusion matrix as a list of structs

    :param ht: Input table
    :param rf_model: RF Model
    :param features: Columns containing features that were used in the model
    :param label: Column containing label to be predicted
    :param prediction_col_name: Where to store the prediction
    :return: A list containing structs with {label, prediction, n}
    """

    ht = apply_rf_model(
        ht.filter(hl.is_defined(ht[label])),
        rf_model,
        features,
        label,
        prediction_col_name=prediction_col_name,
    )

    test_results = (
        ht.group_by(ht[prediction_col_name], ht[label])
        .aggregate(n=hl.agg.count())
        .collect()
    )

    # Print results
    df = pd.DataFrame(test_results)
    df = df.pivot(index=label, columns=prediction_col_name, values="n")
    logger.info("Testing results:\n{}".format(pprint.pformat(df)))
    logger.info(
        "Accuracy: {}".format(
            sum([x.n for x in test_results if x[label] == x[prediction_col_name]])
            / sum([x.n for x in test_results])
        )
    )

    return test_results
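A usage sketch; the feature column names are placeholders, and the `rf_test`/`rf_label` fields are those produced by `train_rf_model` (Example 11):

# Hypothetical usage: evaluate the model on held-out variants.
test_ht = ht.filter(hl.is_defined(ht.rf_label) & ht.rf_test)
results = test_model(test_ht, rf_model, features=['qd', 'fs', 'mq'], label='rf_label')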
Example 6
def filter_ped(raw_ped: hl.Pedigree, mendel: hl.Table, max_dnm: int,
               max_mendel: int) -> hl.Pedigree:
    mendel = mendel.filter(mendel.fam_id.startswith("fake"))
    mendel_by_s = (
        mendel.group_by(mendel.s).aggregate(
            fam_id=hl.agg.take(mendel.fam_id, 1)[0],
            n_mendel=hl.agg.count(),
            n_de_novo=hl.agg.count_where(
                # Mendel code 2: parents are hom ref, child is het (de novo)
                mendel.mendel_code == 2),
        ).persist())

    good_trios = mendel_by_s.aggregate(
        hl.agg.filter(
            (mendel_by_s.n_mendel < max_mendel) &
            (mendel_by_s.n_de_novo < max_dnm),
            hl.agg.collect(mendel_by_s.s),
        ))
    logger.info(f"Found {len(good_trios)} trios passing filters")
    return hl.Pedigree(
        [trio for trio in raw_ped.trios if trio.s in good_trios])
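A usage sketch; the per-error Table is the first of the four Tables returned by `hl.mendel_errors`, and the thresholds are illustrative:

# Hypothetical usage: drop trios with excess Mendel errors or de novos.
all_errors, _, _, _ = hl.mendel_errors(mt['GT'], raw_ped)
ped = filter_ped(raw_ped, all_errors, max_dnm=5, max_mendel=10)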
Example 7
def get_summary_counts(
    ht: hl.Table,
    freq_field: str = "freq",
    filter_field: str = "filters",
    filter_decoy: bool = False,
    index: int = 0,
) -> hl.Table:
    """
    Generate a struct with summary counts across variant categories.

    Summary counts:
        - Number of variants
        - Number of indels
        - Number of SNVs
        - Number of LoF variants
        - Number of LoF variants that pass LOFTEE (including with LoF flags)
        - Number of LoF variants that pass LOFTEE without LoF flags
        - Number of OS (other splice) variants annotated by LOFTEE
        - Number of LoF variants that fail LOFTEE filters

    Also annotates Table's globals with total variant counts.

    Before calculating summary counts, this function:
        - Filters out low confidence regions
        - Filters to canonical transcripts
        - Uses the most severe consequence

    Assumes that:
        - Input HT is annotated with VEP.
        - Multiallelic variants have been split and/or input HT contains bi-allelic variants only.
        - The field referenced by `freq_field` was calculated with `annotate_freq`.
        - (Frequency index 0 from `annotate_freq` is frequency for all pops calculated on adj genotypes only.)

    :param ht: Input Table.
    :param freq_field: Name of field in HT containing frequency annotation (array of structs). Default is "freq".
    :param filter_field: Name of field in HT containing variant filter information. Default is "filters".
    :param filter_decoy: Whether to filter decoy regions. Default is False.
    :param index: Which index of freq_expr to use for annotation. Default is 0.
    :return: Table grouped by frequency bin and aggregated across summary count categories.
    """
    logger.info("Checking if multi-allelic variants have been split...")
    max_alleles = ht.aggregate(hl.agg.max(hl.len(ht.alleles)))
    if max_alleles > 2:
        logger.info(
            "Splitting multi-allelics and VEP transcript consequences...")
        ht = hl.split_multi_hts(ht)

    logger.info("Filtering to PASS variants in high confidence regions...")
    ht = ht.filter((hl.len(ht[filter_field]) == 0))
    ht = filter_low_conf_regions(ht, filter_decoy=filter_decoy)

    logger.info(
        "Filtering to canonical transcripts and getting VEP summary annotations..."
    )
    ht = filter_vep_to_canonical_transcripts(ht)
    ht = get_most_severe_consequence_for_summary(ht)

    logger.info("Annotating with frequency bin information...")
    ht = ht.annotate(freq_bin=freq_bin_expr(ht[freq_field], index))

    logger.info(
        "Annotating HT globals with total counts/total allele counts per variant category..."
    )
    summary_counts = ht.aggregate(
        hl.struct(**get_summary_counts_dict(
            ht.locus,
            ht.alleles,
            ht.lof,
            ht.no_lof_flags,
            ht.most_severe_csq,
            prefix_str="total_",
        )))
    summary_ac_counts = ht.aggregate(
        hl.struct(**get_summary_ac_dict(
            ht[freq_field][index].AC,
            ht.lof,
            ht.no_lof_flags,
            ht.most_severe_csq,
        )))
    ht = ht.annotate_globals(summary_counts=summary_counts.annotate(
        **summary_ac_counts))
    return ht.group_by("freq_bin").aggregate(**get_summary_counts_dict(
        ht.locus,
        ht.alleles,
        ht.lof,
        ht.no_lof_flags,
        ht.most_severe_csq,
    ))
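A usage sketch, assuming a VEP-annotated release Table with the `freq` and `filters` fields described above:

# Hypothetical usage: per-frequency-bin summary counts.
summary_ht = get_summary_counts(release_ht, filter_decoy=True)
summary_ht.show()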
Example 8
def create_binned_data(ht: hl.Table, data: str, data_type: str,
                       n_bins: int) -> hl.Table:
    """
    Create binned data from a rank Table, grouped by rank_id (rank, biallelic, etc.), contig, snv, bi_allelic
    and singleton, containing the information needed for evaluation plots.

    :param Table ht: Input rank table
    :param str data: Which data/run hash is being created
    :param str data_type: one of 'exomes' or 'genomes'
    :param int n_bins: Number of bins.
    :return: Binned Table
    :rtype: Table
    """

    # Count variants for ranking
    count_expr = {
        x: hl.agg.filter(
            hl.is_defined(ht[x]),
            hl.agg.counter(
                hl.cond(hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv',
                        'indel')))
        for x in ht.row if x.endswith('rank')
    }
    rank_variant_counts = ht.aggregate(hl.Struct(**count_expr))
    logger.info(
        f"Found the following variant counts:\n {pformat(rank_variant_counts)}"
    )
    ht = ht.annotate_globals(rank_variant_counts=rank_variant_counts)

    # Load external evaluation data
    clinvar_ht = hl.read_table(clinvar_ht_path)
    denovo_ht = get_validated_denovos_ht()
    if data_type == 'exomes':
        denovo_ht = denovo_ht.filter(denovo_ht.gnomad_exomes.high_quality)
    else:
        denovo_ht = denovo_ht.filter(denovo_ht.gnomad_genomes.high_quality)
    denovo_ht = denovo_ht.select(
        validated_denovo=denovo_ht.validated,
        high_confidence_denovo=denovo_ht.Confidence == 'HIGH')
    ht_truth_data = hl.read_table(annotations_ht_path(data_type, 'truth_data'))
    fam_ht = hl.read_table(annotations_ht_path(data_type, 'family_stats'))
    fam_ht = fam_ht.select(family_stats=fam_ht.family_stats[0])
    gnomad_ht = get_gnomad_data(data_type).rows()
    gnomad_ht = gnomad_ht.select(
        vqsr_negative_train_site=gnomad_ht.info.NEGATIVE_TRAIN_SITE,
        vqsr_positive_train_site=gnomad_ht.info.POSITIVE_TRAIN_SITE,
        fail_hard_filters=(gnomad_ht.info.QD < 2) | (gnomad_ht.info.FS > 60) |
        (gnomad_ht.info.MQ < 30))
    lcr_intervals = hl.import_locus_intervals(lcr_intervals_path)

    ht = ht.annotate(
        **ht_truth_data[ht.key],
        **fam_ht[ht.key],
        **gnomad_ht[ht.key],
        **denovo_ht[ht.key],
        clinvar=hl.is_defined(clinvar_ht[ht.key]),
        indel_length=hl.abs(ht.alleles[0].length() - ht.alleles[1].length()),
        rank_bins=hl.array([
            hl.Struct(
                rank_id=rank_name,
                bin=hl.int(
                    hl.ceil(
                        hl.float(ht[rank_name] + 1) / hl.floor(
                            ht.globals.rank_variant_counts[rank_name][hl.cond(
                                hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv',
                                'indel')] / n_bins))))
            for rank_name in rank_variant_counts
        ]),
        lcr=hl.is_defined(lcr_intervals[ht.locus]))

    ht = ht.explode(ht.rank_bins)
    ht = ht.transmute(rank_id=ht.rank_bins.rank_id, bin=ht.rank_bins.bin)
    ht = ht.filter(hl.is_defined(ht.bin))

    ht = ht.checkpoint(
        f'gs://gnomad-tmp/gnomad_score_binning_{data_type}_tmp_{data}.ht',
        overwrite=True)

    # Create binned data
    return (ht.group_by(
        rank_id=ht.rank_id,
        contig=ht.locus.contig,
        snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
        bi_allelic=hl.is_defined(ht.biallelic_rank),
        singleton=ht.singleton,
        release_adj=ht.ac > 0,
        bin=ht.bin)._set_buffer_size(20000).aggregate(
            min_score=hl.agg.min(ht.score),
            max_score=hl.agg.max(ht.score),
            n=hl.agg.count(),
            n_ins=hl.agg.count_where(
                hl.is_insertion(ht.alleles[0], ht.alleles[1])),
            n_del=hl.agg.count_where(
                hl.is_deletion(ht.alleles[0], ht.alleles[1])),
            n_ti=hl.agg.count_where(
                hl.is_transition(ht.alleles[0], ht.alleles[1])),
            n_tv=hl.agg.count_where(
                hl.is_transversion(ht.alleles[0], ht.alleles[1])),
            n_1bp_indel=hl.agg.count_where(ht.indel_length == 1),
            n_mod3bp_indel=hl.agg.count_where((ht.indel_length % 3) == 0),
            n_clinvar=hl.agg.count_where(ht.clinvar),
            n_singleton=hl.agg.count_where(ht.singleton),
            n_validated_de_novos=hl.agg.count_where(ht.validated_denovo),
            n_high_confidence_de_novos=hl.agg.count_where(
                ht.high_confidence_denovo),
            n_de_novo=hl.agg.filter(
                ht.family_stats.unrelated_qc_callstats.AC[1] == 0,
                hl.agg.sum(ht.family_stats.mendel.errors)),
            n_de_novo_no_lcr=hl.agg.filter(
                ~ht.lcr & (ht.family_stats.unrelated_qc_callstats.AC[1] == 0),
                hl.agg.sum(ht.family_stats.mendel.errors)),
            n_de_novo_sites=hl.agg.filter(
                ht.family_stats.unrelated_qc_callstats.AC[1] == 0,
                hl.agg.count_where(ht.family_stats.mendel.errors > 0)),
            n_de_novo_sites_no_lcr=hl.agg.filter(
                ~ht.lcr & (ht.family_stats.unrelated_qc_callstats.AC[1] == 0),
                hl.agg.count_where(ht.family_stats.mendel.errors > 0)),
            n_trans_singletons=hl.agg.filter(
                (ht.info_ac < 3) &
                (ht.family_stats.unrelated_qc_callstats.AC[1] == 1),
                hl.agg.sum(ht.family_stats.tdt.t)),
            n_untrans_singletons=hl.agg.filter(
                (ht.info_ac < 3) &
                (ht.family_stats.unrelated_qc_callstats.AC[1] == 1),
                hl.agg.sum(ht.family_stats.tdt.u)),
            n_train_trans_singletons=hl.agg.count_where(
                (ht.family_stats.unrelated_qc_callstats.AC[1] == 1)
                & (ht.family_stats.tdt.t == 1)),
            n_omni=hl.agg.count_where(ht.truth_data.omni),
            n_mills=hl.agg.count_where(ht.truth_data.mills),
            n_hapmap=hl.agg.count_where(ht.truth_data.hapmap),
            n_kgp_high_conf_snvs=hl.agg.count_where(
                ht.truth_data.kgp_high_conf_snvs),
            fail_hard_filters=hl.agg.count_where(ht.fail_hard_filters),
            n_vqsr_pos_train=hl.agg.count_where(ht.vqsr_positive_train_site),
            n_vqsr_neg_train=hl.agg.count_where(ht.vqsr_negative_train_site)))
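A usage sketch; the run hash is a placeholder and the rank Table is assumed to carry the `*_rank`, `score`, `singleton` and `ac` annotations referenced above:

# Hypothetical usage: bin a rank table for one RF run.
binned_ht = create_binned_data(rank_ht, data='rf_run_abc123', data_type='exomes', n_bins=100)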
Example 9
def compute_binned_truth_sample_concordance(
    ht: hl.Table,
    binned_score_ht: hl.Table,
    n_bins: int = 100,
    add_bins: Dict[str, hl.expr.BooleanExpression] = {},
) -> hl.Table:
    """
    Determine the concordance (TP, FP, FN) between a truth sample within the callset and the sample's truth data grouped by bins computed using `compute_ranked_bin`.

    .. note::
        The input `ht` should contain three row fields:
            - score: value to use for binning
            - GT: a CallExpression containing the genotype of the evaluation data for the sample
            - truth_GT: a CallExpression containing the genotype of the truth sample
        The input `binned_score_ht` should contain:
             - score: value used to bin the full callset
             - bin: the full callset bin

    `add_bins` can be used to add additional global and truth sample binning to the final binned truth sample
    concordance HT. The keys in `add_bins` must be present in `binned_score_ht` and the values in `add_bins`
    should be expressions on `ht` that define a subset of variants to bin in the truth sample. For example, to
    look at the global and truth sample binning on only bi-allelic variants, `add_bins` could be set to
    {'biallelic_bin': ht.biallelic}.

    The table is grouped by global/truth sample bin and variant type and contains TP, FP and FN.

    :param ht: Input HT
    :param binned_score_ht: Table with the bin annotation for each variant
    :param n_bins: Number of bins to bin the data into
    :param add_bins: Dictionary of additional global bin columns (key) and the expr to use for binning the truth sample (value)
    :return: Binned truth sample concordance HT
    """
    # Annotate score and global bin
    indexed_binned_score_ht = binned_score_ht[ht.key]
    ht = ht.annotate(
        **{
            f"global_{bin_id}": indexed_binned_score_ht[bin_id]
            for bin_id in add_bins
        },
        **{f"_{bin_id}": bin_expr
           for bin_id, bin_expr in add_bins.items()},
        score=indexed_binned_score_ht.score,
        global_bin=indexed_binned_score_ht.bin,
    )

    # Annotate the truth sample bin
    bin_ht = compute_ranked_bin(
        ht,
        score_expr=ht.score,
        bin_expr={
            "truth_sample_bin": hl.expr.bool(True),
            **{
                f"truth_sample_{bin_id}": ht[f"_{bin_id}"]
                for bin_id in add_bins
            },
        },
        n_bins=n_bins,
    )
    ht = ht.join(bin_ht, how="left")

    bin_list = [
        hl.tuple(["global_bin", ht.global_bin]),
        hl.tuple(["truth_sample_bin", ht.truth_sample_bin]),
    ]
    bin_list.extend([
        hl.tuple([f"global_{bin_id}", ht[f"global_{bin_id}"]])
        for bin_id in add_bins
    ])
    bin_list.extend([
        hl.tuple([f"truth_sample_{bin_id}", ht[f"truth_sample_{bin_id}"]])
        for bin_id in add_bins
    ])

    # Explode the global and truth sample bins
    ht = ht.annotate(bin=bin_list)

    ht = ht.explode(ht.bin)
    ht = ht.annotate(bin_id=ht.bin[0], bin=hl.int(ht.bin[1]))

    # Compute TP, FP and FN by bin_id, variant type and bin
    return (ht.group_by("bin_id", "snv", "bin").aggregate(
        # TP => allele is found in both data sets
        tp=hl.agg.count_where(ht.GT.is_non_ref() & ht.truth_GT.is_non_ref()),
        # FP => allele is found only in test data set
        fp=hl.agg.count_where(ht.GT.is_non_ref()
                              & hl.or_else(ht.truth_GT.is_hom_ref(), True)),
        # FN => allele is found in truth data only
        fn=hl.agg.count_where(
            hl.or_else(ht.GT.is_hom_ref(), True) & ht.truth_GT.is_non_ref()),
        min_score=hl.agg.min(ht.score),
        max_score=hl.agg.max(ht.score),
        n_alleles=hl.agg.count(),
    ).repartition(5))
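A usage sketch following the docstring's bi-allelic example; the truth-sample Table name and its `biallelic` flag are assumptions:

# Hypothetical usage: add a bi-allelic-only binning on top of the defaults.
conc_ht = compute_binned_truth_sample_concordance(
    na12878_ht,  # placeholder HT with score, GT and truth_GT row fields
    binned_score_ht,
    n_bins=100,
    add_bins={'biallelic_bin': na12878_ht.biallelic},
)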
Example 10
def generate_sib_stats_expr(
    mt: hl.MatrixTable,
    sib_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    strata: Dict[str, hl.expr.BooleanExpression] = {"raw": True},
    is_female: Optional[hl.expr.BooleanExpression] = None,
) -> hl.expr.StructExpression:
    """
    Generates a row-wise expression containing the number of alternate alleles in common between sibling pairs.

    The sibling sharing counts can be stratified with additional filters using `strata`.

    .. note::

        This function expects that the `mt` has either been split or filtered to only bi-allelics.
        If a sample has multiple sibling pairs, only one pair will be counted.

    :param mt: Input matrix table
    :param sib_ht: Table defining sibling pairs with one sample in a col (`i_col`) and the second in another col (`j_col`)
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param strata: Dict with additional strata to use when computing shared sibling variant counts
    :param is_female: An optional column in mt giving the sample sex. If not given, counts are only computed for autosomes.
    :return: A struct expression with the sibling shared variant counts
    """
    def _get_alt_count(locus, gt, is_female):
        """
        Helper method to calculate alt allele count with sex info if present
        """
        if is_female is None:
            return hl.or_missing(locus.in_autosome(), gt.n_alt_alleles())
        return (hl.case().when(
            locus.in_autosome_or_par(), gt.n_alt_alleles()).when(
                ~is_female & (locus.in_x_nonpar() | locus.in_y_nonpar()),
                hl.min(1, gt.n_alt_alleles()),
            ).when(is_female & locus.in_y_nonpar(), 0).default(0))

    if is_female is None:
        logger.warning(
            "Since no sex expression was given to generate_sib_stats_expr, only variants in autosomes will be counted."
        )

    # If a sample is in sib_ht more than one time, keep only one of the sibling pairs
    # First filter to only samples found in mt to keep as many pairs as possible
    s_to_keep = mt.aggregate_cols(hl.agg.collect_as_set(mt.s), _localize=False)
    sib_ht = sib_ht.filter(
        s_to_keep.contains(sib_ht[i_col].s)
        & s_to_keep.contains(sib_ht[j_col].s))
    sib_ht = sib_ht.add_index("sib_idx")
    sib_ht = sib_ht.annotate(sibs=[sib_ht[i_col].s, sib_ht[j_col].s])
    sib_ht = sib_ht.explode("sibs")
    sib_ht = sib_ht.group_by("sibs").aggregate(
        sib_idx=(hl.agg.take(sib_ht.sib_idx, 1, ordering=sib_ht.sib_idx)[0]))
    sib_ht = sib_ht.group_by(
        sib_ht.sib_idx).aggregate(sibs=hl.agg.collect(sib_ht.sibs))
    sib_ht = sib_ht.filter(hl.len(sib_ht.sibs) == 2).persist()

    logger.info(
        f"Generating sibling variant sharing counts using {sib_ht.count()} pairs."
    )
    sib_ht = sib_ht.explode("sibs").key_by("sibs")[mt.s]

    # Create sibling sharing counters
    sib_stats = hl.struct(
        **{
            f"n_sib_shared_variants_{name}": hl.sum(
                hl.agg.filter(
                    expr,
                    hl.agg.group_by(
                        sib_ht.sib_idx,
                        hl.or_missing(
                            hl.agg.sum(hl.is_defined(mt.GT)) == 2,
                            hl.agg.min(
                                _get_alt_count(mt.locus, mt.GT, is_female)),
                        ),
                    ),
                ).values())
            for name, expr in strata.items()
        })

    sib_stats = sib_stats.annotate(
        **{
            f"ac_sibs_{name}": hl.agg.filter(
                expr & hl.is_defined(sib_ht.sib_idx),
                hl.agg.sum(mt.GT.n_alt_alleles()))
            for name, expr in strata.items()
        })

    return sib_stats
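Since the function returns a struct of aggregation expressions, it is typically evaluated inside a row aggregation; a sketch, where the `is_female` column and the pair Table are assumptions:

# Hypothetical usage: sib_pairs_ht has `i`/`j` struct fields carrying an `s` id,
# e.g. pc_relate output filtered to inferred sibling pairs.
sib_stats_ht = mt.select_rows(
    **generate_sib_stats_expr(mt, sib_pairs_ht, is_female=mt.is_female)
).rows()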
Example 11
def train_rf_model(
    ht: hl.Table,
    rf_features: List[str],
    tp_expr: hl.expr.BooleanExpression,
    fp_expr: hl.expr.BooleanExpression,
    fp_to_tp: float = 1.0,
    num_trees: int = 500,
    max_depth: int = 5,
    test_expr: hl.expr.BooleanExpression = False,
) -> Tuple[hl.Table, pyspark.ml.PipelineModel]:
    """
    Perform random forest (RF) training using a Table annotated with features and training data.

    .. note::

        This function uses `train_rf` and extends it by:
            - Adding an option to apply the resulting model to test variants which are withheld from training.
            - Using a false positive (FP) to true positive (TP) ratio to determine what variants to use for RF training.

    The returned Table includes the following annotations:
        - rf_train: indicates if the variant was used for training of the RF model.
        - rf_label: indicates if the variant is a TP or FP.
        - rf_test: indicates if the variant was used in testing of the RF model.
        - features: global annotation of the features used for the RF model.
        - features_importance: global annotation of the importance of each feature in the model.
        - test_results: results from testing the model on variants defined by `test_expr`.

    :param ht: Table annotated with features for the RF model and the positive and negative training data.
    :param rf_features: List of column names to use as features in the RF training.
    :param tp_expr: TP training expression.
    :param fp_expr: FP training expression.
    :param fp_to_tp: Ratio of FPs to TPs for creating the RF model. If set to 0, all training examples are used.
    :param num_trees: Number of trees in the RF model.
    :param max_depth: Maximum tree depth in the RF model.
    :param test_expr: An expression specifying variants to hold out for testing and use for evaluation only.
    :return: Table with TP and FP training sets used in the RF training and the resulting RF model.
    """

    ht = ht.annotate(_tp=tp_expr, _fp=fp_expr, rf_test=test_expr)

    rf_ht = sample_training_examples(
        ht, tp_expr=ht._tp, fp_expr=ht._fp, fp_to_tp=fp_to_tp, test_expr=ht.rf_test
    )
    ht = ht.annotate(rf_train=rf_ht[ht.key].train, rf_label=rf_ht[ht.key].label)

    summary = ht.group_by("_tp", "_fp", "rf_train", "rf_label", "rf_test").aggregate(
        n=hl.agg.count()
    )
    logger.info("Summary of TP/FP and RF training data:")
    summary.show(n=20)

    logger.info(
        "Training RF model:\nfeatures: {}\nnum_tree: {}\nmax_depth:{}".format(
            ",".join(rf_features), num_trees, max_depth
        )
    )

    rf_model = train_rf(
        ht.filter(ht.rf_train),
        features=rf_features,
        label="rf_label",
        num_trees=num_trees,
        max_depth=max_depth,
    )

    test_results = None
    if test_expr is not None:
        logger.info(f"Testing model on specified variants or intervals...")
        test_ht = ht.filter(hl.is_defined(ht.rf_label) & ht.rf_test)
        test_results = test_model(
            test_ht, rf_model, features=rf_features, label="rf_label"
        )

    features_importance = get_features_importance(rf_model)
    ht = ht.select_globals(
        features_importance=features_importance,
        features=rf_features,
        test_results=test_results,
    )

    return ht.select("rf_train", "rf_label", "rf_test"), rf_model
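A usage sketch with placeholder feature columns and truth criteria, holding out chromosome 20 for testing:

# Hypothetical usage: train an RF model and keep chr20 for evaluation.
rf_ht, rf_model = train_rf_model(
    ht,
    rf_features=['qd', 'fs', 'mq'],  # placeholder feature columns
    tp_expr=ht.omni | ht.mills,      # placeholder TP definition
    fp_expr=ht.fail_hard_filters,    # placeholder FP definition
    test_expr=ht.locus.contig == '20',
)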
Example 12
def compute_binned_truth_sample_concordance(ht: hl.Table,
                                            binned_score_ht: hl.Table,
                                            n_bins: int = 100) -> hl.Table:
    """
    Determine the concordance (TP, FP, FN) between a truth sample within the callset and the sample's truth data
    grouped by bins computed using `compute_quantile_bin`.

    .. note::

        The input `ht` should contain three row fields:
            - score: value to use for quantile binning
            - GT: a CallExpression containing the genotype of the evaluation data for the sample
            - truth_GT: a CallExpression containing the genotype of the truth sample

        The input `binned_score_ht` should contain:
             - score: value used to bin the full callset
             - bin: the full callset quantile bin


    The table is grouped by global/truth sample bin and variant type and contains TP, FP and FN.

    :param ht: Input HT
    :param binned_score_ht: Table with the quantile bin annotation for each variant
    :param n_bins: Number of bins to bin the data into
    :return: Binned truth sample concordance HT
    """
    # Annotate score and global bin
    indexed_binned_score_ht = binned_score_ht[ht.key]
    ht = ht.annotate(score=indexed_binned_score_ht.score,
                     global_bin=indexed_binned_score_ht.bin)

    # Annotate the truth sample quantile bin
    bin_ht = compute_quantile_bin(
        ht,
        score_expr=ht.score,
        bin_expr={"truth_sample_bin": hl.expr.bool(True)},
        n_bins=n_bins,
    )
    ht = ht.join(bin_ht, how="left")

    # Explode the global and truth sample bins
    ht = ht.annotate(bin=[
        hl.tuple(["global_bin", ht.global_bin]),
        hl.tuple(["truth_sample_bin", ht.truth_sample_bin]),
    ])

    ht = ht.explode(ht.bin)
    ht = ht.annotate(bin_id=ht.bin[0], bin=hl.int(ht.bin[1]))

    # Compute TP, FP and FN by bin_id, variant type and bin
    return (ht.group_by("bin_id", "snv", "bin").aggregate(
        # TP => allele is found in both data sets
        tp=hl.agg.count_where(ht.GT.is_non_ref() & ht.truth_GT.is_non_ref()),
        # FP => allele is found only in test data set
        fp=hl.agg.count_where(ht.GT.is_non_ref()
                              & hl.or_else(ht.truth_GT.is_hom_ref(), True)),
        # FN => allele is found in truth data only
        fn=hl.agg.count_where(
            hl.or_else(ht.GT.is_hom_ref(), True) & ht.truth_GT.is_non_ref()),
        min_score=hl.agg.min(ht.score),
        max_score=hl.agg.max(ht.score),
        n_alleles=hl.agg.count(),
    ).repartition(5))
Example 13
def create_binned_data_initial(ht: hl.Table, data: str, data_type: str,
                               n_bins: int) -> hl.Table:
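    """
    Variant of `create_binned_data` (Example 8) that bins a rank Table using only
    the truth-set annotations; the other annotation sources are commented out below.
    """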
    # Count variants for ranking
    count_expr = {
        x: hl.agg.filter(
            hl.is_defined(ht[x]),
            hl.agg.counter(
                hl.cond(hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv',
                        'indel')))
        for x in ht.row if x.endswith('rank')
    }
    rank_variant_counts = ht.aggregate(hl.Struct(**count_expr))
    logger.info(
        f"Found the following variant counts:\n {pformat(rank_variant_counts)}"
    )
    ht_truth_data = hl.read_table(
        f"{temp_dir}/ddd-elgh-ukbb/variant_qc/truthset_table.ht")
    ht = ht.annotate_globals(rank_variant_counts=rank_variant_counts)
    ht = ht.annotate(
        **ht_truth_data[ht.key],
        # **fam_ht[ht.key],
        # **gnomad_ht[ht.key],
        # **denovo_ht[ht.key],
        # clinvar=hl.is_defined(clinvar_ht[ht.key]),
        indel_length=hl.abs(ht.alleles[0].length() - ht.alleles[1].length()),
        rank_bins=hl.array([
            hl.Struct(
                rank_id=rank_name,
                bin=hl.int(
                    hl.ceil(
                        hl.float(ht[rank_name] + 1) / hl.floor(
                            ht.globals.rank_variant_counts[rank_name][hl.cond(
                                hl.is_snp(ht.alleles[0], ht.alleles[1]), 'snv',
                                'indel')] / n_bins))))
            for rank_name in rank_variant_counts
        ]),
        # lcr=hl.is_defined(lcr_intervals[ht.locus])
    )

    ht = ht.explode(ht.rank_bins)
    ht = ht.transmute(rank_id=ht.rank_bins.rank_id, bin=ht.rank_bins.bin)
    ht = ht.filter(hl.is_defined(ht.bin))

    ht = ht.checkpoint(f'{tmp_dir}/gnomad_score_binning_tmp.ht',
                       overwrite=True)

    # Create binned data
    return (ht.group_by(
        rank_id=ht.rank_id,
        contig=ht.locus.contig,
        snv=hl.is_snp(ht.alleles[0], ht.alleles[1]),
        bi_allelic=hl.is_defined(ht.biallelic_rank),
        singleton=ht.transmitted_singleton,
        trans_singletons=hl.is_defined(ht.singleton_rank),
        de_novo_high_quality=ht.de_novo_high_quality_rank,
        de_novo_medium_quality=hl.is_defined(ht.de_novo_medium_quality_rank),
        de_novo_synonymous=hl.is_defined(ht.de_novo_synonymous_rank),
        # release_adj=ht.ac > 0,
        bin=ht.bin
    )._set_buffer_size(20000).aggregate(
        min_score=hl.agg.min(ht.score),
        max_score=hl.agg.max(ht.score),
        n=hl.agg.count(),
        n_ins=hl.agg.count_where(hl.is_insertion(ht.alleles[0],
                                                 ht.alleles[1])),
        n_del=hl.agg.count_where(hl.is_deletion(ht.alleles[0], ht.alleles[1])),
        n_ti=hl.agg.count_where(hl.is_transition(ht.alleles[0],
                                                 ht.alleles[1])),
        n_tv=hl.agg.count_where(
            hl.is_transversion(ht.alleles[0], ht.alleles[1])),
        n_1bp_indel=hl.agg.count_where(ht.indel_length == 1),
        n_mod3bp_indel=hl.agg.count_where((ht.indel_length % 3) == 0),
        # n_clinvar=hl.agg.count_where(ht.clinvar),
        n_singleton=hl.agg.count_where(ht.transmitted_singleton),
        n_high_quality_de_novos=hl.agg.count_where(
            ht.de_novo_data.p_de_novo[0] > 0.99),
        n_medium_quality_de_novos=hl.agg.count_where(
            ht.de_novo_data.p_de_novo[0] > 0.5),
        n_high_confidence_de_novos=hl.agg.count_where(
            ht.de_novo_data.confidence[0] == 'HIGH'),
        n_de_novo=hl.agg.filter(
            ht.family_stats.unrelated_qc_callstats.AC[0][1] == 0,
            hl.agg.sum(ht.family_stats.mendel[0].errors)),
        n_high_quality_de_novos_synonymous=hl.agg.count_where(
            (ht.de_novo_data.p_de_novo[0] > 0.99)
            & (ht.consequence == "synonymous_variant")),
        # n_de_novo_no_lcr=hl.agg.filter(~ht.lcr & (
        #    ht.family_stats.unrelated_qc_callstats.AC[1] == 0), hl.agg.sum(ht.family_stats.mendel.errors)),
        n_de_novo_sites=hl.agg.filter(
            ht.family_stats.unrelated_qc_callstats.AC[0][1] == 0,
            hl.agg.count_where(ht.family_stats.mendel[0].errors > 0)),
        # n_de_novo_sites_no_lcr=hl.agg.filter(~ht.lcr & (
        #    ht.family_stats.unrelated_qc_callstats.AC[1] == 0), hl.agg.count_where(ht.family_stats.mendel.errors > 0)),
        n_trans_singletons=hl.agg.filter(
            (ht.ac_raw < 3) &
            (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1),
            hl.agg.sum(ht.family_stats.tdt[0].t)),
        n_trans_singletons_synonymous=hl.agg.filter(
            (ht.ac_raw < 3) & (ht.consequence == "synonymous_variant") &
            (ht.AC[0] == 2), hl.agg.sum(ht.family_stats.tdt[0].t)),
        n_untrans_singletons=hl.agg.filter(
            (ht.ac_raw < 3) &
            (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1),
            hl.agg.sum(ht.family_stats.tdt[0].u)),
        n_untrans_singletons_synonymous=hl.agg.filter(
            (ht.ac_raw < 3) & (ht.consequence == "synonymous_variant") &
            (ht.AC[0] == 1), hl.agg.sum(ht.family_stats.tdt[0].u)),
        n_train_trans_singletons=hl.agg.count_where(
            (ht.family_stats.unrelated_qc_callstats.AC[0][1] == 1)
            & (ht.family_stats.tdt[0].t == 1)),
        n_omni=hl.agg.count_where(ht.omni),
        n_mills=hl.agg.count_where(ht.mills),
        n_hapmap=hl.agg.count_where(ht.hapmap),
        n_kgp_high_conf_snvs=hl.agg.count_where(ht.kgp_phase1_hc),
        fail_hard_filters=hl.agg.count_where(ht.fail_hard_filters),
        # n_vqsr_pos_train=hl.agg.count_where(ht.vqsr_positive_train_site),
        # n_vqsr_neg_train=hl.agg.count_where(ht.vqsr_negative_train_site)
    ))
Example 14
def annotate_unphased_pairs(unphased_ht: hl.Table, n_variant_pairs: int,
                            least_consequence: str, max_af: float):
    # unphased_ht = vp_ht.filter(hl.is_missing(vp_ht.all_phase))
    # unphased_ht = unphased_ht.key_by()

    # Explode variant pairs
    unphased_ht = unphased_ht.annotate(las=[
        hl.tuple([unphased_ht.locus1, unphased_ht.alleles1]),
        hl.tuple([unphased_ht.locus2, unphased_ht.alleles2])
    ]).explode('las', name='la')

    unphased_ht = unphased_ht.key_by(
        locus=unphased_ht.la[0], alleles=unphased_ht.la[1]).persist(
        )  # .checkpoint('gs://gnomad-tmp/vp_ht_unphased.ht')

    # Annotate single variants with gnomAD freq
    gnomad_ht = gnomad.public_release('exomes').ht()
    gnomad_ht = gnomad_ht.semi_join(unphased_ht).repartition(
        ceil(n_variant_pairs / 10000), shuffle=True).persist()

    missing_freq = hl.struct(
        AC=0,
        AF=0,
        AN=125748 * 2,  # 125,748 gnomAD v2 exomes; assume no missing genotypes for now
        homozygote_count=0)

    logger.info(
        f"{gnomad_ht.count()}/{unphased_ht.count()} single variants from the unphased pairs found in gnomAD."
    )

    gnomad_indexed = gnomad_ht[unphased_ht.key]
    gnomad_freq = gnomad_indexed.freq
    unphased_ht = unphased_ht.annotate(
        adj_freq=hl.or_else(gnomad_freq[0], missing_freq),
        raw_freq=hl.or_else(gnomad_freq[1], missing_freq),
        vep_genes=vep_genes_expr(gnomad_indexed.vep, least_consequence),
        max_af_filter=gnomad_indexed.freq[0].AF <= max_af
        # pop_max_freq=hl.or_else(
        #     gnomad_exomes.popmax[0],
        #     missing_freq.annotate(
        #         pop=hl.null(hl.tstr)
        #     )
        # )
    )
    unphased_ht = unphased_ht.persist()
    # unphased_ht = unphased_ht.checkpoint('gs://gnomad-tmp/unphased_ann.ht', overwrite=True)

    loci_expr = hl.sorted(
        hl.agg.collect(
            hl.tuple([
                unphased_ht.locus,
                hl.struct(
                    adj_freq=unphased_ht.adj_freq,
                    raw_freq=unphased_ht.raw_freq,
                    # pop_max_freq=unphased_ht.pop_max_freq
                )
            ])),
        lambda x: x[0]  # sort by locus
    ).map(lambda x: x[1])  # keep only the freq struct, dropping the locus

    vp_freq_expr = hl.struct(v1=loci_expr[0], v2=loci_expr[1])

    # [AABB, AABb, AAbb, AaBB, AaBb, Aabb, aaBB, aaBb, aabb]
    def get_gt_counts(freq: str):
        return hl.array([
            hl.min(vp_freq_expr.v1[freq].AN, vp_freq_expr.v2[freq].AN),  # AABB
            vp_freq_expr.v2[freq].AC -
            (2 * vp_freq_expr.v2[freq].homozygote_count),  # AABb
            vp_freq_expr.v2[freq].homozygote_count,  # AAbb
            vp_freq_expr.v1[freq].AC -
            (2 * vp_freq_expr.v1[freq].homozygote_count),  # AaBB
            0,  # AaBb
            0,  # Aabb
            vp_freq_expr.v1[freq].homozygote_count,  # aaBB
            0,  # aaBb
            0  # aabb
        ])

    gt_counts_raw_expr = get_gt_counts('raw_freq')
    gt_counts_adj_expr = get_gt_counts('adj_freq')

    # gt_counts_pop_max_expr = get_gt_counts('pop_max_freq')
    unphased_ht = unphased_ht.group_by(
        unphased_ht.locus1, unphased_ht.alleles1, unphased_ht.locus2,
        unphased_ht.alleles2
    ).aggregate(
        pop='all',  # TODO Add option for multiple pops?
        phase_info=hl.struct(gt_counts=hl.struct(raw=gt_counts_raw_expr,
                                                 adj=gt_counts_adj_expr),
                             em=hl.struct(
                                 raw=get_em_expr(gt_counts_raw_expr),
                                 adj=get_em_expr(gt_counts_adj_expr))),
        vep_genes=hl.agg.collect(
            unphased_ht.vep_genes).filter(lambda x: hl.len(x) > 0),
        max_af_filter=hl.agg.all(unphased_ht.max_af_filter)

        # pop_max_gt_counts_adj=gt_counts_raw_expr,
        # pop_max_em_p_chet_adj=get_em_expr(gt_counts_raw_expr).p_chet,
    )  # .key_by()

    unphased_ht = unphased_ht.transmute(
        vep_filter=(hl.len(unphased_ht.vep_genes) > 1)
        & (hl.len(unphased_ht.vep_genes[0].intersection(
            unphased_ht.vep_genes[1])) > 0))

    max_af_filtered, vep_filtered = unphased_ht.aggregate([
        hl.agg.count_where(~unphased_ht.max_af_filter),
        hl.agg.count_where(~unphased_ht.vep_filter)
    ])
    if max_af_filtered > 0:
        logger.info(
            f"{max_af_filtered} variant-pairs excluded because the AF of at least one variant was > {max_af}"
        )
    if vep_filtered > 0:
        logger.info(
            f"{vep_filtered} variant-pairs excluded because the variants were not found within the same gene with a csq of at least {least_consequence}"
        )

    unphased_ht = unphased_ht.filter(unphased_ht.max_af_filter
                                     & unphased_ht.vep_filter)

    return unphased_ht.drop('max_af_filter', 'vep_filter')
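A usage sketch mirroring the commented-out filter at the top of the function; the argument values are illustrative:

# Hypothetical usage: phase-annotate the variant pairs that lack phase info.
unphased_ht = annotate_unphased_pairs(
    vp_ht.filter(hl.is_missing(vp_ht.all_phase)),
    n_variant_pairs=1_000_000,
    least_consequence='missense_variant',
    max_af=0.01,
)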