Example #1
0
def get_domain_metrics_dict_by_name(metrics: Dict[Tuple, Any],
                                    metric_domain_kwargs: IDDict) -> Dict[Any, Any]:
    """Select the metrics that belong to one domain, keyed by metric name.

    Each key of ``metrics`` is a tuple: element 0 becomes the key of the
    returned dict, and element 1 is compared against
    ``metric_domain_kwargs.to_id()`` to decide whether the entry is kept.

    Args:
        metrics: Mapping from metric key tuples to metric values.
        metric_domain_kwargs: Domain kwargs whose ``to_id()`` identifies the
            domain to keep.

    Returns:
        A dict mapping element 0 of every matching key to its value. If
        several matching keys share the same element 0, the last one
        iterated wins.
    """
    # The domain id does not depend on the entry being examined, so compute
    # it once instead of once per metric inside the comprehension filter.
    domain_id = metric_domain_kwargs.to_id()
    return {
        metric_key[0]: metric_value
        for metric_key, metric_value in metrics.items()
        if metric_key[1] == domain_id
    }
    def resolve_metric_bundle(
        self,
        metric_fn_bundle: Iterable[
            Tuple[MetricConfiguration, Callable, dict, dict, dict]
        ],
    ) -> dict:
        """Resolve a bundle of aggregate metrics with one ``agg`` call per domain.

        Bundle entries are grouped by the id of their compute domain kwargs.
        For each group the compute domain is obtained through
        ``self.get_compute_domain(..., domain_type="identity")``, all of the
        group's aggregate expressions are passed to a single
        ``df.agg(...).collect()``, and the values of the single returned row
        are mapped back, by position, onto the metric ids.

        Args:
            metric_fn_bundle: Iterable of 5-tuples
                ``(metric_to_resolve, engine_fn, compute_domain_kwargs,
                accessor_domain_kwargs, metric_provider_kwargs)``. Only the
                first three elements are used here; ``engine_fn`` is handed
                unchanged to ``df.agg``.

        Returns:
            A dictionary mapping each ``metric_to_resolve.id`` to its computed
            value. Empty when the bundle is empty.

        Raises:
            AssertionError: If the aggregation does not return exactly one
                row, or the row length differs from the number of metrics
                requested for that domain.
        """
        # Group the aggregate expressions and metric ids by compute domain so
        # that each domain is evaluated only once.
        aggregates: Dict[Tuple, dict] = {}
        for (
            metric_to_resolve,
            engine_fn,
            compute_domain_kwargs,
            _accessor_domain_kwargs,
            _metric_provider_kwargs,
        ) in metric_fn_bundle:
            if not isinstance(compute_domain_kwargs, IDDict):
                compute_domain_kwargs = IDDict(compute_domain_kwargs)
            domain_id = compute_domain_kwargs.to_id()
            if domain_id not in aggregates:
                aggregates[domain_id] = {
                    "column_aggregates": [],
                    "ids": [],
                    "domain_kwargs": compute_domain_kwargs,
                }
            # The two lists are always appended together, so position i in
            # "column_aggregates" corresponds to position i in "ids".
            aggregates[domain_id]["column_aggregates"].append(engine_fn)
            aggregates[domain_id]["ids"].append(metric_to_resolve.id)

        resolved_metrics = {}
        for domain_id, aggregate in aggregates.items():
            df, _, _ = self.get_compute_domain(
                aggregate["domain_kwargs"], domain_type="identity"
            )
            metric_ids = aggregate["ids"]
            res = df.agg(*aggregate["column_aggregates"]).collect()
            assert (
                len(res) == 1
            ), "all bundle-computed metrics must be single-value statistics"
            row = res[0]
            assert len(metric_ids) == len(
                row), "unexpected number of metrics returned"
            # Lazy %-formatting: the message is only built when DEBUG is
            # enabled. domain_id is the key these kwargs were grouped under.
            logger.debug(
                "SparkDFExecutionEngine computed %s metrics on domain_id %s",
                len(row),
                domain_id,
            )
            for idx, metric_id in enumerate(metric_ids):
                resolved_metrics[metric_id] = row[idx]

        return resolved_metrics