def apply_rf_model(
    ht: hl.Table,
    rf_model: pyspark.ml.PipelineModel,
    features: List[str],
    label: str,
    probability_col_name: str = "rf_probability",
    prediction_col_name: str = "rf_prediction",
) -> hl.Table:
    """
    Applies a Random Forest (RF) pipeline model to a Table and annotate the RF probabilities and predictions.

    :param Table ht: Input HT
    :param PipelineModel rf_model: Random Forest pipeline model
    :param list of str features: List of feature columns in the pipeline. !Should match the model list of features!
    :param str label: Column containing the labels. !Should match the model labels!
    :param str probability_col_name: Name of the column that will store the RF probabilities
    :param str prediction_col_name: Name of the column that will store the RF predictions
    :return: Table with RF columns
    :rtype: Table
    """
    # NOTE(review): this module defines `apply_rf_model` a second time further down;
    # that later definition replaces this one when the module is imported.
    logger.info("Applying RF model.")

    check_ht_fields_for_spark(ht, features + [label])

    # Pick a temporary row-index field name that does not collide with an existing row field.
    index_name = "rf_idx"
    while index_name in ht.row:
        index_name += "_tmp"
    ht = ht.add_index(name=index_name)

    # Remember the original key so it can be restored after joining back on the index.
    ht_keys = ht.key
    ht = ht.key_by(index_name)

    df = ht_to_rf_df(ht, features, label, index_name)

    rf_df = rf_model.transform(df)

    def to_array(vector_col):
        """Convert a Spark ML vector column into an array<double> column via a Python UDF."""

        def to_array_(v):
            return v.toArray().tolist()

        return udf(to_array_, ArrayType(DoubleType()))(vector_col)

    rf_ht = hl.Table.from_spark(
        rf_df.withColumn("probability", to_array(col("probability"))).select(
            [index_name, "probability", "predictedLabel"]
        )
    ).persist()

    rf_ht = rf_ht.key_by(index_name)

    # Build the lookup into the predictions once and reuse it for every output field,
    # instead of creating a separate lookup expression per label.
    rf_row = rf_ht[ht[index_name]]
    ht = ht.annotate(
        **{
            # Map each model label to its probability; the i-th probability entry
            # corresponds to the i-th label returned by get_labels.
            probability_col_name: {
                class_label: rf_row["probability"][i]
                for i, class_label in enumerate(get_labels(rf_model))
            },
            prediction_col_name: rf_row["predictedLabel"],
        }
    )

    # Restore the original key and drop the temporary index.
    ht = ht.key_by(*ht_keys)
    ht = ht.drop(index_name)

    return ht
def generate_sib_stats_expr(
    mt: hl.MatrixTable,
    sib_ht: hl.Table,
    i_col: str = "i",
    j_col: str = "j",
    strata: Optional[Dict[str, hl.expr.BooleanExpression]] = None,
    is_female: Optional[hl.expr.BooleanExpression] = None,
) -> hl.expr.StructExpression:
    """
    Generates a row-wise expression containing the number of alternate alleles in common between sibling pairs.

    The sibling sharing counts can be stratified using additional filters using `strata`.

    .. note::

        This function expects that the `mt` has either been split or filtered to only bi-allelics
        If a sample has multiple sibling pairs, only one pair will be counted

    :param mt: Input matrix table
    :param sib_ht: Table defining sibling pairs with one sample in a col (`i_col`) and the second in another col (`j_col`)
    :param i_col: Column containing the 1st sample of the pair in the relationship table
    :param j_col: Column containing the 2nd sample of the pair in the relationship table
    :param strata: Dict with additional strata to use when computing shared sibling variant counts.
        Defaults to ``{"raw": True}``.
    :param is_female: An optional column in mt giving the sample sex. If not given, counts are only computed for autosomes.
    :return: A row-wise StructExpression with, for each stratum `name`, the fields
        `n_sib_shared_variants_{name}` and `ac_sibs_{name}`
    """
    if strata is None:
        # Built per call instead of a shared mutable default argument.
        strata = {"raw": True}

    def _get_alt_count(locus, gt, is_female):
        """
        Helper method to calculate alt allele count with sex info if present.

        Without sex info, only autosomal loci get a non-missing count. With sex info,
        autosomes/PAR use the full alt allele count, male X/Y non-PAR counts are capped
        at 1, and all remaining cases yield 0.
        """
        if is_female is None:
            return hl.or_missing(locus.in_autosome(), gt.n_alt_alleles())
        # NOTE(review): female X non-PAR loci fall through to `.default(0)`, so they never
        # count as shared; confirm this is intended (diploid female X calls would normally
        # contribute `gt.n_alt_alleles()`).
        return (
            hl.case()
            .when(locus.in_autosome_or_par(), gt.n_alt_alleles())
            .when(
                ~is_female & (locus.in_x_nonpar() | locus.in_y_nonpar()),
                hl.min(1, gt.n_alt_alleles()),
            )
            .when(is_female & locus.in_y_nonpar(), 0)
            .default(0)
        )

    if is_female is None:
        logger.warning(
            "Since no sex expression was given to generate_sib_stats_expr, only variants in autosomes will be counted."
        )

    # If a sample is in sib_ht more than one time, keep only one of the sibling pairs
    # First filter to only samples found in mt to keep as many pairs as possible
    s_to_keep = mt.aggregate_cols(hl.agg.collect_as_set(mt.s), _localize=False)
    sib_ht = sib_ht.filter(
        s_to_keep.contains(sib_ht[i_col].s) & s_to_keep.contains(sib_ht[j_col].s)
    )
    sib_ht = sib_ht.add_index("sib_idx")
    sib_ht = sib_ht.annotate(sibs=[sib_ht[i_col].s, sib_ht[j_col].s])
    sib_ht = sib_ht.explode("sibs")
    # Each sample keeps only the lowest-indexed pair it appears in.
    sib_ht = sib_ht.group_by("sibs").aggregate(sib_idx=hl.agg.min(sib_ht.sib_idx))
    sib_ht = sib_ht.group_by(sib_ht.sib_idx).aggregate(
        sibs=hl.agg.collect(sib_ht.sibs)
    )
    # A pair that lost a member to a lower-indexed pair is dropped entirely.
    sib_ht = sib_ht.filter(hl.len(sib_ht.sibs) == 2).persist()

    logger.info(
        f"Generating sibling variant sharing counts using {sib_ht.count()} pairs."
    )
    # Per-column lookup of the retained pair index; missing for samples not in a retained pair.
    sib_expr = sib_ht.explode("sibs").key_by("sibs")[mt.s]

    # Only samples belonging to a retained sibling pair may contribute. Without this,
    # samples whose sib_idx is missing would form their own group in group_by below.
    in_sib_pair = hl.is_defined(sib_expr.sib_idx)

    # Create sibling sharing counters
    sib_stats = hl.struct(
        **{
            f"n_sib_shared_variants_{name}": hl.sum(
                hl.agg.filter(
                    expr & in_sib_pair,
                    hl.agg.group_by(
                        sib_expr.sib_idx,
                        # Count a pair only when both siblings have a called genotype;
                        # the shared count is the smaller of their alt counts.
                        hl.or_missing(
                            hl.agg.sum(hl.is_defined(mt.GT)) == 2,
                            hl.agg.min(_get_alt_count(mt.locus, mt.GT, is_female)),
                        ),
                    ),
                ).values()
            )
            for name, expr in strata.items()
        }
    )

    sib_stats = sib_stats.annotate(
        **{
            f"ac_sibs_{name}": hl.agg.filter(
                expr & in_sib_pair, hl.agg.sum(mt.GT.n_alt_alleles())
            )
            for name, expr in strata.items()
        }
    )

    return sib_stats
def apply_rf_model(
    ht: hl.Table,
    rf_model: pyspark.ml.PipelineModel,
    features: List[str],
    label: str,
    probability_col_name: str = "rf_probability",
    prediction_col_name: str = "rf_prediction",
    tmp_parquet_path: str = "rf_probs.parquet",
    tmp_ht_path: str = "/tmp/rf_raw_pred.ht",
) -> hl.Table:
    """
    Applies a Random Forest (RF) pipeline model to a Table and annotate the RF probabilities and predictions.

    :param ht: Input HT
    :param rf_model: Random Forest pipeline model
    :param features: List of feature columns in the pipeline. !Should match the model list of features!
    :param label: Column containing the labels. !Should match the model labels!
    :param probability_col_name: Name of the column that will store the RF probabilities
    :param prediction_col_name: Name of the column that will store the RF predictions
    :param tmp_parquet_path: Path the RF predictions are written to as Parquet (overwritten) before
        being loaded into Hail. Pass a unique path when several runs may execute concurrently.
    :param tmp_ht_path: Path the raw RF predictions Table is checkpointed to (overwritten).
        Pass a unique path when several runs may execute concurrently.
    :return: Table with RF columns
    """
    logger.info("Applying RF model.")

    check_ht_fields_for_spark(ht, features + [label])

    # Pick a temporary row-index field name that does not collide with an existing row field.
    index_name = "rf_idx"
    while index_name in ht.row:
        index_name += "_tmp"
    ht = ht.add_index(name=index_name)

    # Remember the original key so it can be restored after joining back on the index.
    ht_keys = ht.key
    ht = ht.key_by(index_name)

    df = ht_to_rf_df(ht, features, label, index_name)

    rf_df = rf_model.transform(df)

    def to_array(vector_col):
        """Convert a Spark ML vector column into an array<double> column via a Python UDF."""

        def to_array_(v):
            return v.toArray().tolist()

        return udf(to_array_, ArrayType(DoubleType()))(vector_col)

    # Note: SparkSession is needed to write DF to disk before converting to HT; hail currently fails without intermediate write
    spark = SparkSession.builder.getOrCreate()
    pred_df = rf_df.withColumn("probability", to_array(col("probability"))).select(
        [index_name, "probability", "predictedLabel"]
    )
    # Write explicitly as Parquet so the on-disk format matches the Parquet reader below
    # (`save()` without a format uses whatever `spark.sql.sources.default` is set to).
    pred_df.write.mode("overwrite").parquet(tmp_parquet_path)
    pred_df = spark.read.parquet(tmp_parquet_path)

    rf_ht = hl.Table.from_spark(pred_df)
    rf_ht = rf_ht.checkpoint(tmp_ht_path, overwrite=True)
    rf_ht = rf_ht.key_by(index_name)

    # Build the lookup into the predictions once and reuse it for every output field,
    # instead of creating a separate lookup expression per label.
    rf_row = rf_ht[ht[index_name]]
    ht = ht.annotate(
        **{
            # Map each model label to its probability; the i-th probability entry
            # corresponds to the i-th label returned by get_labels.
            probability_col_name: {
                class_label: rf_row["probability"][i]
                for i, class_label in enumerate(get_labels(rf_model))
            },
            prediction_col_name: rf_row["predictedLabel"],
        }
    )

    # Restore the original key and drop the temporary index.
    ht = ht.key_by(*ht_keys)
    ht = ht.drop(index_name)

    return ht