def _get_scalar_average_per_task(
    self,
    metrics_interval: MetricInterval,
    task_ids: Sequence[str],
    es_index: str,
    key: ScalarKey,
) -> Sequence[MetricData]:
    """
    Retrieve scalar histograms for several metric variants that share the same interval
    """
    interval, task_metrics = metrics_interval
    aggregation = self._add_aggregation_average(key.get_aggregation(interval))
    aggs = {
        "metrics": {
            "terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
            "aggs": {
                "variants": {
                    "terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
                    "aggs": {
                        "tasks": {
                            "terms": {"field": "task", "size": self.MAX_TASKS_COUNT},
                            "aggs": aggregation,
                        }
                    },
                }
            },
        }
    }

    aggs_result = self._query_aggregation_for_metrics_and_tasks(
        es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
    )
    if not aggs_result:
        # return an empty sequence to match the declared return type
        return []

    metrics = [
        (
            metric["key"],
            {
                variant["key"]: {
                    task["key"]: key.get_iterations_data(task)
                    for task in variant["tasks"]["buckets"]
                }
                for variant in metric["variants"]["buckets"]
            },
        )
        for metric in aggs_result["metrics"]["buckets"]
    ]
    return metrics

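# A hedged sketch of the structure the comprehension above builds, assuming a
# ScalarKey whose get_iterations_data() returns {"x": [...], "y": [...]} (the
# actual payload depends on the resolved key; all names below are illustrative):
#
#     [
#         ("loss", {                                  # metric
#             "total": {                              # variant
#                 "task-1": {"x": [0, 10], "y": [0.9, 0.4]},
#                 "task-2": {"x": [0, 10], "y": [1.1, 0.5]},
#             },
#         }),
#     ]
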
def _get_scalar_average(
    self,
    metrics_interval: MetricInterval,
    task_ids: Sequence[str],
    es_index: str,
    key: ScalarKey,
) -> Sequence[MetricData]:
    """
    Retrieve scalar histograms for several metric variants that share the same interval
    Note: the function works with a single task only
    """
    assert len(task_ids) == 1
    interval, task_metrics = metrics_interval
    aggregation = self._add_aggregation_average(key.get_aggregation(interval))
    aggs = {
        "metrics": {
            "terms": {
                "field": "metric",
                "size": self.MAX_METRICS_COUNT,
                "order": {"_term": "desc"},
            },
            "aggs": {
                "variants": {
                    "terms": {
                        "field": "variant",
                        "size": self.MAX_VARIANTS_COUNT,
                        "order": {"_term": "desc"},
                    },
                    "aggs": aggregation,
                }
            },
        }
    }

    aggs_result = self._query_aggregation_for_metrics_and_tasks(
        es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
    )
    if not aggs_result:
        # return an empty sequence to match the declared return type
        return []

    metrics = [
        (
            metric["key"],
            {
                variant["key"]: {
                    "name": variant["key"],
                    **key.get_iterations_data(variant),
                }
                for variant in metric["variants"]["buckets"]
            },
        )
        for metric in aggs_result["metrics"]["buckets"]
    ]
    return metrics

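# For comparison, a hedged sketch of the single-task variant's output: the task
# level is gone and each variant carries its own "name" next to the iterations
# data (payload shape assumed as above, names illustrative):
#
#     [
#         ("loss", {
#             "total": {"name": "total", "x": [0, 10], "y": [0.9, 0.4]},
#         }),
#     ]
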
def compare_scalar_metrics_average_per_iter(
    self,
    company_id,
    task_ids: Sequence[str],
    samples,
    key: ScalarKeyEnum,
    allow_public=True,
):
    """
    Compare scalar metrics for different tasks per metric and variant
    The number of points in each histogram should not exceed the requested samples
    """
    task_name_by_id = {}
    with translate_errors_context():
        task_objs = Task.get_many(
            company=company_id,
            query=Q(id__in=task_ids),
            allow_public=allow_public,
            override_projection=("id", "name", "company"),
            return_dicts=False,
        )
        if len(task_objs) < len(task_ids):
            invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
            raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
        task_name_by_id = {t.id: t.name for t in task_objs}

    companies = {t.company for t in task_objs}
    if len(companies) > 1:
        raise errors.bad_request.InvalidTaskId(
            "only tasks from the same company are supported"
        )

    es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
    if not self.es.indices.exists(es_index):
        return {}

    get_scalar_average_per_iter = partial(
        self._get_scalar_average_per_iter_core,
        es_index=es_index,
        samples=samples,
        key=ScalarKey.resolve(key),
        run_parallel=False,
    )
    with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
        task_metrics = zip(
            task_ids, pool.map(get_scalar_average_per_iter, task_ids)
        )

    res = defaultdict(lambda: defaultdict(dict))
    for task_id, task_data in task_metrics:
        task_name = task_name_by_id[task_id]
        for metric_key, metric_data in task_data.items():
            for variant_key, variant_data in metric_data.items():
                variant_data["name"] = task_name
                res[metric_key][variant_key][task_id] = variant_data

    return res

def compare_scalar_metrics_average_per_iter(
    self,
    company_id,
    task_ids: Sequence[str],
    samples,
    key: ScalarKeyEnum,
    allow_public=True,
):
    """
    Compare scalar metrics for different tasks per metric and variant
    The number of points in each histogram should not exceed the requested samples
    """
    if len(task_ids) > self.MAX_TASKS_COUNT:
        raise errors.BadRequest(
            f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
            len(task_ids),
        )

    task_name_by_id = {}
    with translate_errors_context():
        task_objs = Task.get_many(
            company=company_id,
            query=Q(id__in=task_ids),
            allow_public=allow_public,
            override_projection=("id", "name", "company"),
            return_dicts=False,
        )
        if len(task_objs) < len(task_ids):
            invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
            raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
        task_name_by_id = {t.id: t.name for t in task_objs}

    companies = {t.company for t in task_objs}
    if len(companies) > 1:
        raise errors.bad_request.InvalidTaskId(
            "only tasks from the same company are supported"
        )

    ret = self._run_get_scalar_metrics_as_parallel(
        next(iter(companies)),
        task_ids=task_ids,
        samples=samples,
        key=ScalarKey.resolve(key),
        get_func=self._get_scalar_average_per_task,
    )

    for metric_data in ret.values():
        for variant_data in metric_data.values():
            for task_id, task_data in variant_data.items():
                task_data["name"] = task_name_by_id[task_id]

    return ret

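# A hedged usage sketch for the comparison flow (instance construction and the
# ScalarKeyEnum member are assumptions, not taken from this code):
#
#     histograms = event_metrics.compare_scalar_metrics_average_per_iter(
#         company_id="company-1",
#         task_ids=["task-1", "task-2"],
#         samples=500,
#         key=ScalarKeyEnum.iter,
#     )
#     # histograms["loss"]["total"]["task-1"]["name"] -> the task's display name
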
def get_scalar_metrics_average_per_iter(
    self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
) -> dict:
    """
    Get scalar metric histogram per metric and variant
    The number of points in each histogram should not exceed the requested samples
    """
    es_index = self.get_index_name(company_id, "training_stats_scalar")
    if not self.es.indices.exists(es_index):
        return {}

    return self._get_scalar_average_per_iter_core(
        task_id, es_index, samples, ScalarKey.resolve(key)
    )

def get_scalar_metrics_average_per_iter(
    self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
) -> dict:
    """
    Get scalar metric histogram per metric and variant
    The number of points in each histogram should not exceed the requested samples
    """
    return self._run_get_scalar_metrics_as_parallel(
        company_id,
        task_ids=[task_id],
        samples=samples,
        key=ScalarKey.resolve(key),
        get_func=self._get_scalar_average,
    )

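# Single-task counterpart, sketched under the same assumptions as the
# comparison example above:
#
#     histograms = event_metrics.get_scalar_metrics_average_per_iter(
#         company_id="company-1",
#         task_id="task-1",
#         samples=500,
#         key=ScalarKeyEnum.iter,
#     )
#     # histograms["loss"]["total"] -> {"name": "total", ...iteration data...}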