Example #1
from typing import List

import hail as hl
import pyspark
from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, DoubleType

# Assumes module-level helpers from the surrounding codebase are in scope:
# logger, check_ht_fields_for_spark, ht_to_rf_df, get_labels.


def apply_rf_model(ht: hl.Table,
                   rf_model: pyspark.ml.PipelineModel,
                   features: List[str],
                   label: str,
                   probability_col_name: str = 'rf_probability',
                   prediction_col_name: str = 'rf_prediction') -> hl.Table:
    """
    Apply a Random Forest (RF) pipeline model to a Table and annotate it with the RF probabilities and predictions.

    :param Table ht: Input HT
    :param PipelineModel rf_model: Random Forest pipeline model
    :param list of str features: List of feature columns in the pipeline. Must match the features the model was trained on!
    :param str label: Column containing the labels. Must match the model's label column!
    :param str probability_col_name: Name of the column that will store the RF probabilities
    :param str prediction_col_name: Name of the column that will store the RF predictions
    :return: Table with RF columns
    :rtype: Table
    """

    logger.info("Applying RF model.")

    check_ht_fields_for_spark(ht, features + [label])

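    # Create a unique temporary row index so the Spark results can be joined
    # back onto the original Table after scoring.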
    index_name = 'rf_idx'
    while index_name in ht.row:
        index_name += '_tmp'
    ht = ht.add_index(name=index_name)

    ht_keys = ht.key
    ht = ht.key_by(index_name)

    df = ht_to_rf_df(ht, features, label, index_name)

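    # Running the fitted pipeline appends the model's output columns
    # (including 'probability' and 'predictedLabel') to the DataFrame.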
    rf_df = rf_model.transform(df)

    def to_array(col):
        def to_array_(v):
            return v.toArray().tolist()

        return udf(to_array_, ArrayType(DoubleType()))(col)

    rf_ht = hl.Table.from_spark(
        rf_df.withColumn("probability", to_array(col("probability"))).select(
            [index_name, 'probability', 'predictedLabel'])).persist()

    rf_ht = rf_ht.key_by(index_name)

    ht = ht.annotate(
        **{
            probability_col_name: {
                label: rf_ht[ht[index_name]]["probability"][i]
                for i, label in enumerate(get_labels(rf_model))
            },
            prediction_col_name: rf_ht[ht[index_name]]["predictedLabel"]
        })

    ht = ht.key_by(*ht_keys)
    ht = ht.drop(index_name)

    return ht
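
A minimal usage sketch, assuming a fitted `rf_model`, an input Table `ht` already annotated with the feature and label columns, and placeholder feature names:

rf_ht = apply_rf_model(ht, rf_model,
                       features=["info_QD", "info_FS", "info_MQ"],
                       label="rf_label")
rf_ht.select("rf_probability", "rf_prediction").show()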
Example #2
import hail as hl

# Assumes `logger` and the gnomAD helpers `filter_to_autosomes` and
# `bi_allelic_expr` are in scope.


def compute_callrate_mt(
    mt: hl.MatrixTable,
    intervals_ht: hl.Table,
    bi_allelic_only: bool = True,
    autosomes_only: bool = True,
    match: bool = True,
) -> hl.MatrixTable:
    """
    Compute a sample/interval MT with each entry containing the call rate for that sample/interval.

    This can be used as input for imputing exome sequencing platforms.

    .. note::

        The input interval HT should have a key of type Interval.
        The resulting MT will have a row key of the same type as the `intervals_ht` key and
        will contain an `interval_info` row field with all non-key fields of the `intervals_ht`.

    :param mt: Input MT
    :param intervals_ht: Table containing the intervals. This table has to be keyed by an interval.
    :param bi_allelic_only: If set, only bi-allelic sites are used for the computation
    :param autosomes_only: If set, only autosomal intervals are used.
    :param match: If set, returns all intervals in intervals_ht that overlap the locus in the input MT.
    :return: Callrate MT
    """
    logger.info("Computing call rate MatrixTable")

    if len(intervals_ht.key) != 1 or not isinstance(
            intervals_ht.key[0], hl.expr.IntervalExpression):
        logger.warning(
            "Call rate matrix computation expects `intervals_ht` with a key of type Interval. Found: %s",
            intervals_ht.key,
        )

    # Start from the unfiltered input MT so `callrate_mt` is always defined.
    callrate_mt = mt
    if autosomes_only:
        callrate_mt = filter_to_autosomes(callrate_mt)

    if bi_allelic_only:
        callrate_mt = callrate_mt.filter_rows(bi_allelic_expr(callrate_mt))

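    # Copy the interval key into a regular field so it survives the index join
    # and can later be used to re-key the aggregated rows.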
    intervals_ht = intervals_ht.annotate(_interval_key=intervals_ht.key)
    callrate_mt = callrate_mt.annotate_rows(_interval_key=intervals_ht.index(
        callrate_mt.locus, all_matches=match)._interval_key)

    if match:
        callrate_mt = callrate_mt.explode_rows("_interval_key")

    callrate_mt = callrate_mt.filter_rows(
        hl.is_defined(callrate_mt._interval_key.interval))
    callrate_mt = callrate_mt.select_entries(
        GT=hl.or_missing(hl.is_defined(callrate_mt.GT), hl.struct()))
    callrate_mt = callrate_mt.group_rows_by(
        **callrate_mt._interval_key).aggregate(
            callrate=hl.agg.fraction(hl.is_defined(callrate_mt.GT)))
    intervals_ht = intervals_ht.drop("_interval_key")
    callrate_mt = callrate_mt.annotate_rows(interval_info=hl.struct(
        **intervals_ht[callrate_mt.row_key]))
    return callrate_mt
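
A minimal usage sketch, assuming GRCh38 data and a placeholder BED file of capture intervals:

intervals_ht = hl.import_locus_intervals("capture_intervals.bed",
                                         reference_genome="GRCh38")
callrate_mt = compute_callrate_mt(mt, intervals_ht)
callrate_mt.entries().show()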
Example #3
from typing import Optional

import hail as hl

# Assumes a module-level `logger` is in scope.


def compute_related_samples_to_drop(
    relatedness_ht: hl.Table,
    rank_ht: hl.Table,
    kin_threshold: float,
    filtered_samples: Optional[hl.expr.SetExpression] = None,
    min_related_hard_filter: Optional[int] = None,
) -> hl.Table:
    """
    Compute a Table listing the samples to drop (and their global rank) in order to obtain a maximal independent set of unrelated samples.

    .. note::

        - `relatedness_ht` should be keyed by exactly two fields of the same type, identifying the pair of samples for each row.
        - `rank_ht` should be keyed by a single key of the same type as a single sample identifier in `relatedness_ht`.

    :param relatedness_ht: Relatedness HT, as produced e.g. by `hl.pc_relate`
    :param rank_ht: Table with a global rank for each sample (smaller is preferred)
    :param kin_threshold: Kinship threshold above which two samples are considered related
    :param filtered_samples: An optional set of samples to exclude (e.g. samples that were hard-filtered). These samples will appear in the resulting list of samples to drop.
    :param min_related_hard_filter: If provided, any sample related to more samples than this parameter is filtered out prior to computing the maximal independent set, and appears in the results.
    :return: A Table with the list of the samples to drop along with their rank.
    """

    # Make sure that the key types are valid
    assert len(list(relatedness_ht.key)) == 2
    assert relatedness_ht.key[0].dtype == relatedness_ht.key[1].dtype
    assert len(list(rank_ht.key)) == 1
    assert relatedness_ht.key[0].dtype == rank_ht.key[0].dtype

    logger.info(
        f"Filtering related samples using a kin threshold of {kin_threshold}")
    relatedness_ht = relatedness_ht.filter(relatedness_ht.kin > kin_threshold)

    filtered_samples_rel = set()
    if min_related_hard_filter is not None:
        logger.info(
            f"Computing samples related to too many individuals (>{min_related_hard_filter}) for exclusion"
        )
        gbi = relatedness_ht.annotate(s=list(relatedness_ht.key))
        gbi = gbi.explode(gbi.s)
        gbi = gbi.group_by(gbi.s).aggregate(n=hl.agg.count())
        filtered_samples_rel = gbi.aggregate(
            hl.agg.filter(gbi.n > min_related_hard_filter,
                          hl.agg.collect_as_set(gbi.s)))
        logger.info(
            f"Found {len(filtered_samples_rel)} samples with too many 1st/2nd degree relatives. These samples will be excluded."
        )

    if filtered_samples is not None:
        filtered_samples_rel = filtered_samples_rel.union(
            relatedness_ht.aggregate(
                hl.agg.explode(
                    lambda s: hl.agg.collect_as_set(s),
                    hl.array(list(relatedness_ht.key)).filter(
                        lambda s: filtered_samples.contains(s)),
                )))

    if len(filtered_samples_rel) > 0:
        filtered_samples_lit = hl.literal(filtered_samples_rel)
        relatedness_ht = relatedness_ht.filter(
            filtered_samples_lit.contains(relatedness_ht.key[0])
            | filtered_samples_lit.contains(relatedness_ht.key[1]),
            keep=False,
        )

    logger.info("Annotating related sample pairs with rank.")
    i, j = list(relatedness_ht.key)
    relatedness_ht = relatedness_ht.key_by(s=relatedness_ht[i])
    relatedness_ht = relatedness_ht.annotate(**{
        i:
        hl.struct(s=relatedness_ht.s, rank=rank_ht[relatedness_ht.key].rank)
    })
    relatedness_ht = relatedness_ht.key_by(s=relatedness_ht[j])
    relatedness_ht = relatedness_ht.annotate(**{
        j:
        hl.struct(s=relatedness_ht.s, rank=rank_ht[relatedness_ht.key].rank)
    })
    relatedness_ht = relatedness_ht.key_by(i, j)
    relatedness_ht = relatedness_ht.drop("s")
    relatedness_ht = relatedness_ht.persist()

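    # With keep=False, maximal_independent_set returns the complement of the
    # independent set, i.e. the samples that must be dropped; ties are broken
    # in favour of keeping the sample with the smaller global rank.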
    related_samples_to_drop_ht = hl.maximal_independent_set(
        relatedness_ht[i],
        relatedness_ht[j],
        keep=False,
        tie_breaker=lambda l, r: l.rank - r.rank,
    )
    related_samples_to_drop_ht = related_samples_to_drop_ht.key_by()
    related_samples_to_drop_ht = related_samples_to_drop_ht.select(
        **related_samples_to_drop_ht.node)
    related_samples_to_drop_ht = related_samples_to_drop_ht.key_by("s")

    if len(filtered_samples_rel) > 0:
        related_samples_to_drop_ht = related_samples_to_drop_ht.union(
            hl.Table.parallelize(
                [
                    hl.struct(s=s, rank=hl.null(hl.tint64))
                    for s in filtered_samples_rel
                ],
                key="s",
            ))

    return related_samples_to_drop_ht
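
A minimal usage sketch, assuming placeholder table paths; the 0.088 kinship cutoff (roughly second degree) is an assumption, not part of the function:

relatedness_ht = hl.read_table("relatedness.ht")  # e.g. hl.pc_relate output
rank_ht = hl.read_table("sample_rank.ht")  # one row per sample with a `rank` field
samples_to_drop_ht = compute_related_samples_to_drop(relatedness_ht,
                                                     rank_ht,
                                                     kin_threshold=0.088)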
Example #4
from typing import List

import hail as hl
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, DoubleType

# Assumes module-level helpers from the surrounding codebase are in scope:
# logger, check_ht_fields_for_spark, ht_to_rf_df, get_labels.


def apply_rf_model(
    ht: hl.Table,
    rf_model: pyspark.ml.PipelineModel,
    features: List[str],
    label: str,
    probability_col_name: str = "rf_probability",
    prediction_col_name: str = "rf_prediction",
) -> hl.Table:
    """
    Apply a Random Forest (RF) pipeline model to a Table and annotate it with the RF probabilities and predictions.

    :param ht: Input HT
    :param rf_model: Random Forest pipeline model
    :param features: List of feature columns in the pipeline. Must match the features the model was trained on!
    :param label: Column containing the labels. Must match the model's label column!
    :param probability_col_name: Name of the column that will store the RF probabilities
    :param prediction_col_name: Name of the column that will store the RF predictions
    :return: Table with RF columns
    """

    logger.info("Applying RF model.")

    check_ht_fields_for_spark(ht, features + [label])

    index_name = "rf_idx"
    while index_name in ht.row:
        index_name += "_tmp"
    ht = ht.add_index(name=index_name)

    ht_keys = ht.key
    ht = ht.key_by(index_name)

    df = ht_to_rf_df(ht, features, label, index_name)

    rf_df = rf_model.transform(df)

    def to_array(col):
        def to_array_(v):
            return v.toArray().tolist()

        return udf(to_array_, ArrayType(DoubleType()))(col)

    # Note: the DataFrame is written to disk and re-read before converting to a
    # Hail Table, as Hail currently fails without this intermediate write.
    spark = SparkSession.builder.getOrCreate()
    rf_df.withColumn("probability", to_array(col("probability"))).select(
        [index_name, "probability",
         "predictedLabel"]).write.mode("overwrite").save("rf_probs.parquet")
    rf_df = spark.read.format("parquet").load("rf_probs.parquet")
    rf_ht = hl.Table.from_spark(rf_df)
    rf_ht = rf_ht.checkpoint("/tmp/rf_raw_pred.ht", overwrite=True)
    rf_ht = rf_ht.key_by(index_name)

    ht = ht.annotate(
        **{
            probability_col_name: {
                label: rf_ht[ht[index_name]]["probability"][i]
                for i, label in enumerate(get_labels(rf_model))
            },
            prediction_col_name: rf_ht[ht[index_name]]["predictedLabel"],
        })

    ht = ht.key_by(*ht_keys)
    ht = ht.drop(index_name)

    return ht
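
This variant differs from Example #1 only in the intermediate Parquet write. For context, a hypothetical sketch of building the kind of `PipelineModel` this function expects; `df`, `features`, and `label` are placeholders, and the real training helper in this codebase may differ:

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import IndexToString, StringIndexer, VectorAssembler

assembler = VectorAssembler(inputCols=features, outputCol="features")
label_indexer = StringIndexer(inputCol=label, outputCol="indexedLabel").fit(df)
rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="features")
label_converter = IndexToString(inputCol="prediction",
                                outputCol="predictedLabel",
                                labels=label_indexer.labels)
rf_model = Pipeline(stages=[assembler, label_indexer, rf,
                            label_converter]).fit(df)

The `IndexToString` stage produces the `predictedLabel` column that `apply_rf_model` reads back, and `probability` comes from the classifier itself.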
Example #5
from typing import Optional

import hail as hl

# Assumes `logger`, the `INBREEDING_COEFF_HARD_CUTOFF` constant, and the helper
# `add_filters_expr` are in scope.


def generate_final_filter_ht(
    ht: hl.Table,
    model_name: str,
    score_name: str,
    ac0_filter_expr: hl.expr.BooleanExpression,
    ts_ac_filter_expr: hl.expr.BooleanExpression,
    mono_allelic_flag_expr: hl.expr.BooleanExpression,
    inbreeding_coeff_cutoff: float = INBREEDING_COEFF_HARD_CUTOFF,
    snp_bin_cutoff: Optional[int] = None,
    indel_bin_cutoff: Optional[int] = None,
    snp_score_cutoff: Optional[float] = None,
    indel_score_cutoff: Optional[float] = None,
    aggregated_bin_ht: Optional[hl.Table] = None,
    bin_id: Optional[str] = None,
    vqsr_ht: Optional[hl.Table] = None,
) -> hl.Table:
    """
    Prepare a finalized filtering model given a filtering HT from `rf.apply_rf_model` or VQSR and cutoffs for filtering.

    .. note::

        - `snp_bin_cutoff` and `snp_score_cutoff` are mutually exclusive, and one must be supplied.
        - `indel_bin_cutoff` and `indel_score_cutoff` are mutually exclusive, and one must be supplied.
        - If a `snp_bin_cutoff` or `indel_bin_cutoff` cutoff is supplied then an `aggregated_bin_ht` and `bin_id` must
          also be supplied to determine the SNP and indel scores to use as cutoffs from an aggregated bin Table like
          one created by `compute_grouped_binned_ht` in combination with `score_bin_agg`.

    :param ht: Filtering Table from `rf.apply_rf_model` or VQSR to prepare as the final filter Table
    :param model_name: Filtering model name to use in the 'filters' field (VQSR or RF)
    :param score_name: Name to use for the filtering score annotation. This will be used in place of 'score' in the
        release HT info struct and the INFO field of the VCF (e.g. RF or AS_VQSLOD)
    :param ac0_filter_expr: Expression that indicates if a variant should be filtered as allele count 0 (AC0)
    :param ts_ac_filter_expr: Allele count expression in `ht` to use as a filter for determining a transmitted singleton
    :param mono_allelic_flag_expr: Expression indicating if a variant is mono-allelic
    :param inbreeding_coeff_cutoff: InbreedingCoeff hard filter to use for variants
    :param snp_bin_cutoff: Bin cutoff to use for SNP variant QC filter. Can't be used with `snp_score_cutoff`
    :param indel_bin_cutoff: Bin cutoff to use for indel variant QC filter. Can't be used with `indel_score_cutoff`
    :param snp_score_cutoff: Score cutoff (e.g. RF probability or AS_VQSLOD) to use for SNP variant QC filter. Can't be used with `snp_bin_cutoff`
    :param indel_score_cutoff: Score cutoff (e.g. RF probability or AS_VQSLOD) to use for indel variant QC filter. Can't be used with `indel_bin_cutoff`
    :param aggregated_bin_ht: Table with aggregate counts of variants based on bins
    :param bin_id: Name of bin to use in 'bin_id' column of `aggregated_bin_ht` to use to determine probability cutoff
    :param vqsr_ht: If a VQSR HT is supplied a 'vqsr' annotation containing AS_VQSLOD, AS_culprit, NEGATIVE_TRAIN_SITE,
        and POSITIVE_TRAIN_SITE will be included in the returned Table
    :return: Finalized random forest Table annotated with variant filters
    """
    if snp_bin_cutoff is not None and snp_score_cutoff is not None:
        raise ValueError(
            "snp_bin_cutoff and snp_score_cutoff are mutually exclusive, please only supply one SNP filtering cutoff."
        )

    if indel_bin_cutoff is not None and indel_score_cutoff is not None:
        raise ValueError(
            "indel_bin_cutoff and indel_score_cutoff are mutually exclusive, please only supply one indel filtering cutoff."
        )

    if snp_bin_cutoff is None and snp_score_cutoff is None:
        raise ValueError(
            "One (and only one) of the parameters snp_bin_cutoff and snp_score_cutoff must be supplied."
        )

    if indel_bin_cutoff is None and indel_score_cutoff is None:
        raise ValueError(
            "One (and only one) of the parameters indel_bin_cutoff and indel_score_cutoff must be supplied."
        )

    if (snp_bin_cutoff is not None or indel_bin_cutoff is not None) and (
            aggregated_bin_ht is None or bin_id is None):
        raise ValueError(
            "If using snp_bin_cutoff or indel_bin_cutoff, both aggregated_bin_ht and bin_id must be supplied"
        )

    # Determine SNP and indel score cutoffs if given bin instead of score
    if snp_bin_cutoff is not None:
        snp_score_cutoff = aggregated_bin_ht.aggregate(
            hl.agg.filter(
                aggregated_bin_ht.snv
                & (aggregated_bin_ht.bin_id == bin_id)
                & (aggregated_bin_ht.bin == snp_bin_cutoff),
                hl.agg.min(aggregated_bin_ht.min_score),
            ))
        snp_cutoff_global = hl.struct(bin=snp_bin_cutoff,
                                      min_score=snp_score_cutoff)

    if indel_bin_cutoff is not None:
        indel_score_cutoff = aggregated_bin_ht.aggregate(
            hl.agg.filter(
                ~aggregated_bin_ht.snv
                & (aggregated_bin_ht.bin_id == bin_id)
                & (aggregated_bin_ht.bin == indel_bin_cutoff),
                hl.agg.min(aggregated_bin_ht.min_score),
            ))
        indel_cutoff_global = hl.struct(bin=indel_bin_cutoff,
                                        min_score=indel_score_cutoff)

    min_score = ht.aggregate(hl.agg.min(ht.score))
    max_score = ht.aggregate(hl.agg.max(ht.score))

    # Validate a user-supplied score cutoff; when a bin cutoff was supplied,
    # the score cutoff was derived above and the bin is kept in the global struct.
    if snp_bin_cutoff is None:
        if snp_score_cutoff < min_score or snp_score_cutoff > max_score:
            raise ValueError(
                "snp_score_cutoff is not within the range of score.")
        snp_cutoff_global = hl.struct(min_score=snp_score_cutoff)

    if indel_bin_cutoff is None:
        if indel_score_cutoff < min_score or indel_score_cutoff > max_score:
            raise ValueError(
                "indel_score_cutoff is not within the range of score.")
        indel_cutoff_global = hl.struct(min_score=indel_score_cutoff)

    logger.info(
        f"Using a SNP score cutoff of {snp_score_cutoff} and an indel score cutoff of {indel_score_cutoff}."
    )

    # Add filters to HT
    filters = dict()

    if ht.any(hl.is_missing(ht.score)):
        ht.filter(hl.is_missing(ht.score)).show()
        raise ValueError("Missing Score!")

    filters[model_name] = (hl.is_missing(ht.score)
                           | (hl.is_snp(ht.alleles[0], ht.alleles[1])
                              & (ht.score < snp_cutoff_global.min_score))
                           | (~hl.is_snp(ht.alleles[0], ht.alleles[1])
                              & (ht.score < indel_cutoff_global.min_score)))

    filters["InbreedingCoeff"] = hl.or_else(
        ht.InbreedingCoeff < inbreeding_coeff_cutoff, False)
    filters["AC0"] = ac0_filter_expr

    annotations_expr = dict()
    if model_name == "RF":
        # Fix annotations for release
        annotations_expr.update({
            "positive_train_site":
            hl.or_else(ht.positive_train_site, False),
            "rf_tp_probability":
            ht.rf_probability["TP"],
        })
    annotations_expr.update({
        "transmitted_singleton":
        hl.or_missing(ts_ac_filter_expr, ht.transmitted_singleton)
    })
    if "feature_imputed" in ht.row:
        annotations_expr.update({
            x: hl.or_missing(~ht.feature_imputed[x], ht[x])
            for x in ht.row.feature_imputed
        })

    ht = ht.transmute(
        filters=add_filters_expr(filters=filters),
        monoallelic=mono_allelic_flag_expr,
        **{score_name: ht.score},
        **annotations_expr,
    )

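    # Rename bin annotations: names containing 'adj_' have that prefix stripped,
    # all other bin names are prefixed with 'raw_'.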
    bin_names = [x for x in ht.row if x.endswith("bin")]
    bin_names = [(
        x,
        x.split("adj_")[0] +
        x.split("adj_")[1] if len(x.split("adj_")) == 2 else "raw_" + x,
    ) for x in bin_names]
    ht = ht.transmute(**{j: ht[i] for i, j in bin_names})

    ht = ht.annotate_globals(
        bin_stats=hl.struct(**{j: ht.bin_stats[i]
                               for i, j in bin_names}),
        filtering_model=hl.struct(
            model_name=model_name,
            score_name=score_name,
            snv_cutoff=snp_cutoff_global,
            indel_cutoff=indel_cutoff_global,
        ),
        inbreeding_coeff_cutoff=inbreeding_coeff_cutoff,
    )
    if vqsr_ht:
        vqsr = vqsr_ht[ht.key]
        ht = ht.annotate(
            vqsr=hl.struct(
                AS_VQSLOD=vqsr.info.AS_VQSLOD,
                AS_culprit=vqsr.info.AS_culprit,
                NEGATIVE_TRAIN_SITE=vqsr.info.NEGATIVE_TRAIN_SITE,
                POSITIVE_TRAIN_SITE=vqsr.info.POSITIVE_TRAIN_SITE,
            ),
            SOR=vqsr.info.SOR,  # NOTE: This was required for v3.1; we now compute this in `get_site_info_expr`
        )

    ht = ht.drop("AS_culprit")

    return ht
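
A minimal usage sketch, assuming bin cutoffs and placeholder annotation names (`ac_release`, `ac_raw`, and `af_release` are hypothetical):

final_ht = generate_final_filter_ht(
    ht,
    model_name="RF",
    score_name="RF",
    ac0_filter_expr=ht.ac_release == 0,
    ts_ac_filter_expr=ht.ac_raw == 1,
    mono_allelic_flag_expr=(ht.af_release == 0) | (ht.af_release == 1),
    snp_bin_cutoff=90,
    indel_bin_cutoff=80,
    aggregated_bin_ht=aggregated_bin_ht,
    bin_id="bin",
)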
Example #6
from math import ceil

import hail as hl

# Assumes `logger` and the codebase helpers `gnomad` (public release resource),
# `vep_genes_expr`, and `get_em_expr` are in scope.


def annotate_unphased_pairs(unphased_ht: hl.Table, n_variant_pairs: int,
                            least_consequence: str, max_af: float) -> hl.Table:
    """
    Annotate unphased variant pairs with gnomAD frequencies and EM-based phase estimates.

    Aggregates back to one row per variant pair (locus1/alleles1, locus2/alleles2) and
    filters on per-variant allele frequency (`max_af`) and on shared VEP genes.
    """
    # unphased_ht = vp_ht.filter(hl.is_missing(vp_ht.all_phase))
    # unphased_ht = unphased_ht.key_by()

    # Explode variant pairs
    unphased_ht = unphased_ht.annotate(las=[
        hl.tuple([unphased_ht.locus1, unphased_ht.alleles1]),
        hl.tuple([unphased_ht.locus2, unphased_ht.alleles2])
    ]).explode('las', name='la')

    unphased_ht = unphased_ht.key_by(
        locus=unphased_ht.la[0], alleles=unphased_ht.la[1]).persist(
        )  # .checkpoint('gs://gnomad-tmp/vp_ht_unphased.ht')

    # Annotate single variants with gnomAD freq
    gnomad_ht = gnomad.public_release('exomes').ht()
    gnomad_ht = gnomad_ht.semi_join(unphased_ht).repartition(
        ceil(n_variant_pairs / 10000), shuffle=True).persist()

    missing_freq = hl.struct(
        AC=0,
        AF=0.0,  # float literal, matching the type of the gnomAD freq struct
        AN=125748 * 2,  # assume no missing genotypes for now
        homozygote_count=0)

    logger.info(
        f"{gnomad_ht.count()}/{unphased_ht.count()} single variants from the unphased pairs found in gnomAD."
    )

    gnomad_indexed = gnomad_ht[unphased_ht.key]
    gnomad_freq = gnomad_indexed.freq
    unphased_ht = unphased_ht.annotate(
        adj_freq=hl.or_else(gnomad_freq[0], missing_freq),
        raw_freq=hl.or_else(gnomad_freq[1], missing_freq),
        vep_genes=vep_genes_expr(gnomad_indexed.vep, least_consequence),
        max_af_filter=gnomad_indexed.freq[0].AF <= max_af
        # pop_max_freq=hl.or_else(
        #     gnomad_exomes.popmax[0],
        #     missing_freq.annotate(
        #         pop=hl.null(hl.tstr)
        #     )
        # )
    )
    unphased_ht = unphased_ht.persist()
    # unphased_ht = unphased_ht.checkpoint('gs://gnomad-tmp/unphased_ann.ht', overwrite=True)

    loci_expr = hl.sorted(
        hl.agg.collect(
            hl.tuple([
                unphased_ht.locus,
                hl.struct(
                    adj_freq=unphased_ht.adj_freq,
                    raw_freq=unphased_ht.raw_freq,
                    # pop_max_freq=unphased_ht.pop_max_freq
                )
            ])),
        lambda x: x[0]  # sort by locus
    ).map(lambda x: x[1])  # keep only the freq struct, dropping the locus

    vp_freq_expr = hl.struct(v1=loci_expr[0], v2=loci_expr[1])

    # [AABB, AABb, AAbb, AaBB, AaBb, Aabb, aaBB, aaBb, aabb]
    def get_gt_counts(freq: str):
        return hl.array([
            hl.min(vp_freq_expr.v1[freq].AN, vp_freq_expr.v2[freq].AN),  # AABB
            vp_freq_expr.v2[freq].AC -
            (2 * vp_freq_expr.v2[freq].homozygote_count),  # AABb
            vp_freq_expr.v2[freq].homozygote_count,  # AAbb
            vp_freq_expr.v1[freq].AC -
            (2 * vp_freq_expr.v1[freq].homozygote_count),  # AaBB
            0,  # AaBb
            0,  # Aabb
            vp_freq_expr.v1[freq].homozygote_count,  # aaBB
            0,  # aaBb
            0  # aabb
        ])

    gt_counts_raw_expr = get_gt_counts('raw_freq')
    gt_counts_adj_expr = get_gt_counts('adj_freq')

    # gt_counts_pop_max_expr = get_gt_counts('pop_max_freq')
    unphased_ht = unphased_ht.group_by(
        unphased_ht.locus1, unphased_ht.alleles1, unphased_ht.locus2,
        unphased_ht.alleles2
    ).aggregate(
        pop='all',  # TODO Add option for multiple pops?
        phase_info=hl.struct(gt_counts=hl.struct(raw=gt_counts_raw_expr,
                                                 adj=gt_counts_adj_expr),
                             em=hl.struct(
                                 raw=get_em_expr(gt_counts_raw_expr),
                                 adj=get_em_expr(gt_counts_adj_expr))),
        vep_genes=hl.agg.collect(
            unphased_ht.vep_genes).filter(lambda x: hl.len(x) > 0),
        max_af_filter=hl.agg.all(unphased_ht.max_af_filter)

        # pop_max_gt_counts_adj=gt_counts_raw_expr,
        # pop_max_em_p_chet_adj=get_em_expr(gt_counts_raw_expr).p_chet,
    )  # .key_by()

    unphased_ht = unphased_ht.transmute(
        vep_filter=(hl.len(unphased_ht.vep_genes) > 1)
        & (hl.len(unphased_ht.vep_genes[0].intersection(
            unphased_ht.vep_genes[1])) > 0))

    max_af_filtered, vep_filtered = unphased_ht.aggregate([
        hl.agg.count_where(~unphased_ht.max_af_filter),
        hl.agg.count_where(~unphased_ht.vep_filter)
    ])
    if max_af_filtered > 0:
        logger.info(
            f"{max_af_filtered} variant-pairs excluded because the AF of at least one variant was > {max_af}"
        )
    if vep_filtered > 0:
        logger.info(
            f"{vep_filtered} variant-pairs excluded because the variants were not found within the same gene with a csq of at least {least_consequence}"
        )

    unphased_ht = unphased_ht.filter(unphased_ht.max_af_filter
                                     & unphased_ht.vep_filter)

    return unphased_ht.drop('max_af_filter', 'vep_filter')
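
A minimal usage sketch, assuming a placeholder Table of variant pairs with `locus1`/`alleles1`/`locus2`/`alleles2` row fields:

unphased_ht = hl.read_table("vp_unphased.ht")
phased_ht = annotate_unphased_pairs(
    unphased_ht,
    n_variant_pairs=unphased_ht.count(),
    least_consequence="missense_variant",
    max_af=0.05,
)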