def get_domain_metrics_dict_by_name(metrics: Dict[Tuple, Any], metric_domain_kwargs: IDDict):
    """Select the metrics computed over a single domain, keyed by metric name.

    Each key of *metrics* is a metric edge-key tuple whose first element is
    the metric name and whose second element is the domain id. Entries whose
    domain id equals ``metric_domain_kwargs.to_id()`` are returned in a new
    dict keyed by metric name.
    """
    # Hoist the id computation: it is the same for every entry.
    target_domain_id = metric_domain_kwargs.to_id()
    selected = {}
    for edge_key, metric_value in metrics.items():
        if edge_key[1] == target_domain_id:
            selected[edge_key[0]] = metric_value
    return selected
def resolve_metric_bundle(
    self,
    metric_fn_bundle: Iterable[Tuple[MetricConfiguration, Callable, dict, dict, dict]],
) -> dict:
    """Resolve a bundle of aggregate metrics, one Spark action per domain.

    Metrics sharing a compute domain are grouped so that all of their
    aggregate expressions are evaluated in a single ``df.agg(...)`` pass.

    Args:
        metric_fn_bundle: An iterable of 5-tuples
            ``(metric_configuration, engine_fn, compute_domain_kwargs,
            accessor_domain_kwargs, metric_provider_kwargs)``. Only the first
            three elements are used here.

    Returns:
        A dictionary mapping each metric's id to its resolved single value.
    """
    resolved_metrics: dict = {}

    # Group aggregate expressions by compute-domain id so each distinct
    # domain requires only one pass over its DataFrame.
    # NOTE: keys are the string ids produced by IDDict.to_id(), not tuples.
    aggregates: Dict[str, dict] = {}
    for (
        metric_to_resolve,
        engine_fn,
        compute_domain_kwargs,
        accessor_domain_kwargs,
        metric_provider_kwargs,
    ) in metric_fn_bundle:
        if not isinstance(compute_domain_kwargs, IDDict):
            compute_domain_kwargs = IDDict(compute_domain_kwargs)
        domain_id = compute_domain_kwargs.to_id()
        if domain_id not in aggregates:
            aggregates[domain_id] = {
                "column_aggregates": [],
                "ids": [],
                "domain_kwargs": compute_domain_kwargs,
            }
        aggregates[domain_id]["column_aggregates"].append(engine_fn)
        aggregates[domain_id]["ids"].append(metric_to_resolve.id)

    for aggregate in aggregates.values():
        compute_domain_kwargs = aggregate["domain_kwargs"]
        df, _, _ = self.get_compute_domain(
            compute_domain_kwargs, domain_type="identity"
        )
        # Positional correspondence between expressions and metric ids is
        # relied on below when unpacking the result Row.
        assert len(aggregate["column_aggregates"]) == len(aggregate["ids"])
        # Evaluate every aggregate expression for this domain in a single
        # Spark action. (Previously this also generated per-expression UUIDs
        # and copied the expression list — both were dead code and removed.)
        res = df.agg(*aggregate["column_aggregates"]).collect()
        assert (
            len(res) == 1
        ), "all bundle-computed metrics must be single-value statistics"
        assert len(aggregate["ids"]) == len(
            res[0]
        ), "unexpected number of metrics returned"
        logger.debug(
            f"SparkDFExecutionEngine computed {len(res[0])} metrics on domain_id {IDDict(compute_domain_kwargs).to_id()}"
        )
        for idx, id in enumerate(aggregate["ids"]):
            resolved_metrics[id] = res[0][idx]
    return resolved_metrics