Example #1
# Imports needed by this snippet; check_ht_fields_for_spark, ht_to_rf_df, and
# get_labels are helpers from the surrounding module and are assumed to be in scope.
import logging
from typing import List

import hail as hl
import pyspark.ml
from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, DoubleType

logger = logging.getLogger(__name__)


def apply_rf_model(ht: hl.Table,
                   rf_model: pyspark.ml.PipelineModel,
                   features: List[str],
                   label: str,
                   probability_col_name: str = 'rf_probability',
                   prediction_col_name: str = 'rf_prediction') -> hl.Table:
    """
    Applies a Random Forest (RF) pipeline model to a Table and annotates the RF probabilities and predictions.

    :param Table ht: Input Table
    :param PipelineModel rf_model: Random Forest pipeline model
    :param list of str features: List of feature columns in the pipeline. Must match the features the model was trained on.
    :param str label: Column containing the labels. Must match the labels the model was trained on.
    :param str probability_col_name: Name of the column that will store the RF probabilities
    :param str prediction_col_name: Name of the column that will store the RF predictions
    :return: Table with RF columns
    :rtype: Table
    """

    logger.info("Applying RF model.")

    check_ht_fields_for_spark(ht, features + [label])

    # Add a unique row index so the RF results can be joined back onto the
    # input, avoiding collisions with any existing row field.
    index_name = 'rf_idx'
    while index_name in ht.row:
        index_name += '_tmp'
    ht = ht.add_index(name=index_name)

    ht_keys = ht.key
    ht = ht.key_by(index_name)

    # Convert the HT to a Spark DataFrame and run the RF pipeline over it.
    df = ht_to_rf_df(ht, features, label, index_name)

    rf_df = rf_model.transform(df)

    # Spark ML emits probabilities as an ML Vector, which Hail cannot ingest
    # directly; convert it to a plain array of doubles first.
    def to_array(col):
        def to_array_(v):
            return v.toArray().tolist()

        return udf(to_array_, ArrayType(DoubleType()))(col)

    # Pull the probabilities and predictions back into Hail, keyed by the row index.
    rf_ht = hl.Table.from_spark(
        rf_df.withColumn("probability", to_array(col("probability"))).select(
            [index_name, 'probability', 'predictedLabel'])).persist()

    rf_ht = rf_ht.key_by(index_name)

    # Annotate each label's probability and the predicted label onto the input HT.
    ht = ht.annotate(
        **{
            probability_col_name: {
                label: rf_ht[ht[index_name]]["probability"][i]
                for i, label in enumerate(get_labels(rf_model))
            },
            prediction_col_name: rf_ht[ht[index_name]]["predictedLabel"]
        })

    ht = ht.key_by(*ht_keys)
    ht = ht.drop(index_name)

    return ht
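
A minimal usage sketch for this function. The table path, model path, and feature names below are hypothetical; the only real constraint is that `features` and `label` match what the pipeline was trained with:

import hail as hl
from pyspark.ml import PipelineModel

# Hypothetical paths and names, for illustration only; the input HT must
# contain the feature columns and the label column.
ht = hl.read_table("gs://my-bucket/variants.ht")
rf_model = PipelineModel.load("gs://my-bucket/rf.model")
features = ["feature1", "feature2"]  # must match the training features

ht = apply_rf_model(ht, rf_model, features, label="rf_label")
ht.select("rf_probability", "rf_prediction").show()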
Example #2
# Imports needed by this snippet.
import logging
from typing import Dict, Optional

import hail as hl

logger = logging.getLogger(__name__)


def generate_sib_stats_expr(
    mt: hl.MatrixTable,
    sib_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    strata: Dict[str, hl.expr.BooleanExpression] = {"raw": True},
    is_female: Optional[hl.expr.BooleanExpression] = None,
) -> hl.expr.StructExpression:
    """
    Generates a row-wise expression containing the number of alternate alleles in common between sibling pairs.

    The sibling sharing counts can be stratified with additional filters using `strata`.

    .. note::

        This function expects that the `mt` has either been split or filtered to only bi-allelic sites.
        If a sample appears in multiple sibling pairs, only one pair will be counted.

    :param mt: Input matrix table
    :param sib_ht: Table defining sibling pairs, with the first sample of each pair in one field (`i_col`) and the second in another (`j_col`)
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param strata: Dict with additional strata to use when computing shared sibling variant counts
    :param is_female: An optional column in mt giving the sample sex. If not given, counts are only computed for autosomes.
    :return: A struct expression with the sibling shared variant counts
    """
    def _get_alt_count(locus, gt, is_female):
        """
        Helper method to calculate the alt allele count, using sex information if present.
        """
        if is_female is None:
            # Without sex information, only autosomal genotypes are counted
            return hl.or_missing(locus.in_autosome(), gt.n_alt_alleles())
        # With sex information, treat male X/Y non-PAR genotypes as hemizygous
        # and ignore female Y non-PAR genotypes
        return (hl.case().when(
            locus.in_autosome_or_par(), gt.n_alt_alleles()).when(
                ~is_female & (locus.in_x_nonpar() | locus.in_y_nonpar()),
                hl.min(1, gt.n_alt_alleles()),
            ).when(is_female & locus.in_y_nonpar(), 0).default(0))

    if is_female is None:
        logger.warning(
            "Since no sex expression was given to generate_sib_stats_expr, only variants in autosomes will be counted."
        )

    # If a sample is in sib_ht more than one time, keep only one of the sibling pairs
    # First filter to only samples found in mt to keep as many pairs as possible
    s_to_keep = mt.aggregate_cols(hl.agg.collect_as_set(mt.s), _localize=False)
    sib_ht = sib_ht.filter(
        s_to_keep.contains(sib_ht[i_col].s)
        & s_to_keep.contains(sib_ht[j_col].s))
    sib_ht = sib_ht.add_index("sib_idx")
    sib_ht = sib_ht.annotate(sibs=[sib_ht[i_col].s, sib_ht[j_col].s])
    sib_ht = sib_ht.explode("sibs")
    # Keep only the first pair each sample appears in, then keep only pairs
    # where both members survived the de-duplication
    sib_ht = sib_ht.group_by("sibs").aggregate(
        sib_idx=(hl.agg.take(sib_ht.sib_idx, 1, ordering=sib_ht.sib_idx)[0]))
    sib_ht = sib_ht.group_by(
        sib_ht.sib_idx).aggregate(sibs=hl.agg.collect(sib_ht.sibs))
    sib_ht = sib_ht.filter(hl.len(sib_ht.sibs) == 2).persist()

    logger.info(
        f"Generating sibling variant sharing counts using {sib_ht.count()} pairs."
    )
    # Join the pair index onto the MT columns by sample ID
    sib_ht = sib_ht.explode("sibs").key_by("sibs")[mt.s]

    # Create sibling sharing counters
    sib_stats = hl.struct(
        **{
            f"n_sib_shared_variants_{name}": hl.sum(
                hl.agg.filter(
                    expr,
                    hl.agg.group_by(
                        sib_ht.sib_idx,
                        hl.or_missing(
                            hl.agg.sum(hl.is_defined(mt.GT)) == 2,
                            hl.agg.min(
                                _get_alt_count(mt.locus, mt.GT, is_female)),
                        ),
                    ),
                ).values())
            for name, expr in strata.items()
        })

    # Also compute, per stratum, the total allele count among samples that are in a sibling pair
    sib_stats = sib_stats.annotate(
        **{
            f"ac_sibs_{name}": hl.agg.filter(
                expr & hl.is_defined(sib_ht.sib_idx),
                hl.agg.sum(mt.GT.n_alt_alleles()))
            for name, expr in strata.items()
        })

    return sib_stats
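
A hypothetical usage sketch (the paths are illustrative, and `sib_pairs` is assumed to be a pc_relate-style table whose `i` and `j` fields are structs with a sample field `s`). Since the returned struct is an aggregation expression over entries, it is materialized with `annotate_rows`:

import hail as hl

mt = hl.read_matrix_table("gs://my-bucket/split.mt")      # split / bi-allelic MT
sib_pairs = hl.read_table("gs://my-bucket/sib_pairs.ht")  # fields i.s and j.s

sib_stats = generate_sib_stats_expr(mt, sib_pairs)
mt = mt.annotate_rows(**sib_stats)
mt.rows().show()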
Example #3
# Imports needed by this snippet; helper functions are assumed in scope as in Example #1.
import logging
from typing import List

import hail as hl
import pyspark.ml
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, DoubleType

logger = logging.getLogger(__name__)


def apply_rf_model(
    ht: hl.Table,
    rf_model: pyspark.ml.PipelineModel,
    features: List[str],
    label: str,
    probability_col_name: str = "rf_probability",
    prediction_col_name: str = "rf_prediction",
) -> hl.Table:
    """
    Applies a Random Forest (RF) pipeline model to a Table and annotates the RF probabilities and predictions.

    :param ht: Input HT
    :param rf_model: Random Forest pipeline model
    :param features: List of feature columns in the pipeline. Must match the features the model was trained on.
    :param label: Column containing the labels. Must match the labels the model was trained on.
    :param probability_col_name: Name of the column that will store the RF probabilities
    :param prediction_col_name: Name of the column that will store the RF predictions
    :return: Table with RF columns
    """

    logger.info("Applying RF model.")

    check_ht_fields_for_spark(ht, features + [label])

    index_name = "rf_idx"
    while index_name in ht.row:
        index_name += "_tmp"
    ht = ht.add_index(name=index_name)

    ht_keys = ht.key
    ht = ht.key_by(index_name)

    df = ht_to_rf_df(ht, features, label, index_name)

    rf_df = rf_model.transform(df)

    def to_array(col):
        def to_array_(v):
            return v.toArray().tolist()

        return udf(to_array_, ArrayType(DoubleType()))(col)

    # Note: a SparkSession is needed to write the DF to disk before converting
    # to a HT; Hail currently fails without the intermediate write
    spark = SparkSession.builder.getOrCreate()
    rf_df.withColumn("probability", to_array(col("probability"))).select(
        [index_name, "probability",
         "predictedLabel"]).write.mode("overwrite").save("rf_probs.parquet")
    rf_df = spark.read.format("parquet").load("rf_probs.parquet")
    rf_ht = hl.Table.from_spark(rf_df)
    rf_ht = rf_ht.checkpoint("/tmp/rf_raw_pred.ht", overwrite=True)
    rf_ht = rf_ht.key_by(index_name)

    ht = ht.annotate(
        **{
            probability_col_name: {
                label: rf_ht[ht[index_name]]["probability"][i]
                for i, label in enumerate(get_labels(rf_model))
            },
            prediction_col_name: rf_ht[ht[index_name]]["predictedLabel"],
        })

    ht = ht.key_by(*ht_keys)
    ht = ht.drop(index_name)

    return ht
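
On Spark 3.0+, the hand-rolled `to_array` UDF used in Examples #1 and #3 can be replaced by the built-in converter `pyspark.ml.functions.vector_to_array`. A minimal sketch, assuming `rf_df` is the output of `rf_model.transform`:

from pyspark.ml.functions import vector_to_array

# vector_to_array converts an ML Vector column to an array<double> column
rf_df = rf_df.withColumn("probability", vector_to_array(rf_df["probability"]))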