def get_by_id(
    company_id,
    task_id,
    required_status=None,
    only_fields=None,
    allow_public=False,
):
    """
    Fetch a single task by its id, scoped to a company.

    :param company_id: Company scope for the lookup
    :param task_id: Id of the task to fetch
    :param required_status: If provided, the task must be in this status,
        otherwise InvalidTaskStatus is raised
    :param only_fields: Optional field name or iterable of field names to
        project. "status" is always included so the status check can run.
    :param allow_public: Also match public tasks
    :raises errors.bad_request.InvalidTaskId: No task matched the id
    :raises errors.bad_request.InvalidTaskStatus: required_status was given
        and the task is in a different status
    :return: The matching Task document
    """
    if only_fields:
        # Normalize to a list we own (don't mutate the caller's sequence)
        if isinstance(only_fields, string_types):
            only_fields = [only_fields]
        else:
            only_fields = list(only_fields)
        # "status" is needed below for the required_status check;
        # avoid adding a duplicate entry if the caller already asked for it
        if "status" not in only_fields:
            only_fields = only_fields + ["status"]
    with TimingContext("mongo", "task_by_id_all"):
        tasks = Task.get_many(
            company=company_id,
            query=Q(id=task_id),
            allow_public=allow_public,
            override_projection=only_fields,
            return_dicts=False,
        )
    task = tasks[0] if tasks else None
    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)
    # Fixed: original used `not task.status == required_status`
    if required_status and task.status != required_status:
        raise errors.bad_request.InvalidTaskStatus(expected=required_status)
    return task
def assert_exists(
    company_id, task_ids, only=None, allow_public=False, return_tasks=True
) -> Optional[Sequence[Task]]:
    """
    Verify every given task id exists under the company.

    :param task_ids: A single task id or a sequence of ids
    :param only: Optional field names to project on the returned tasks
    :param allow_public: Also match public tasks
    :param return_tasks: When True, return the matching Task documents
    :raises errors.bad_request.InvalidTaskId: Some of the ids were not found
    """
    if isinstance(task_ids, six.string_types):
        task_ids = [task_ids]
    with translate_errors_context(), TimingContext("mongo", "task_exists"):
        unique_ids = set(task_ids)
        query = Task.get_many(
            company=company_id,
            query=Q(id__in=unique_ids),
            allow_public=allow_public,
            return_dicts=False,
        )
        if only:
            # Reset the default field filters first (some fields are excluded
            # by default) — this is an internal call with explicit fields.
            query = query.all_fields().only(*only)
        if query.count() != len(unique_ids):
            raise errors.bad_request.InvalidTaskId(ids=task_ids)
        if return_tasks:
            return list(query)
def get_all(call: APICall, company_id, _):
    """Return all tasks matching the query in the call's data."""
    conform_tag_fields(call, call.data)
    escape_execution_parameters(call)
    with translate_errors_context():
        with TimingContext("mongo", "task_get_all"):
            # allow_public is required in case a projection is requested
            # for a public dataset/version
            matching = Task.get_many(
                company=company_id,
                parameters=call.data,
                query_dict=call.data,
                allow_public=True,
            )
        unprepare_from_saved(call, matching)
        call.result.data = {"tasks": matching}
def get_all(call: APICall, company_id, _):
    """Return all tasks matching the call's (escaped) query data."""
    conform_tag_fields(call, call.data)
    call_data = escape_execution_parameters(call)
    with TimingContext("mongo", "task_get_all"):
        # get_many fills this dict with extra response parameters
        extra_params = {}
        found = Task.get_many(
            company=company_id,
            parameters=call_data,
            query_dict=call_data,
            allow_public=True,
            ret_params=extra_params,
        )
    unprepare_from_saved(call, found)
    call.result.data = {"tasks": found, **extra_params}
def compare_scalar_metrics_average_per_iter(
    self,
    company_id,
    task_ids: Sequence[str],
    samples,
    key: ScalarKeyEnum,
    allow_public=True,
):
    """
    Compare scalar metrics for different tasks per metric and variant.
    The amount of points in each histogram should not exceed the requested
    samples.

    :param company_id: Company scope for resolving the tasks
    :param task_ids: Ids of the tasks to compare
    :param samples: Maximum number of points per histogram
    :param key: How to key the scalar samples (e.g. by iteration)
    :param allow_public: Also match public tasks
    :raises errors.bad_request.InvalidTaskId: Some ids were not found, or
        the tasks span more than one company
    :return: Mapping of metric -> variant -> task id -> histogram data
        (each entry also carries the task's name under "name"); empty dict
        when there is no scalar event data for the company.
    """
    task_name_by_id = {}
    with translate_errors_context():
        # Only project the fields needed for naming and company resolution
        task_objs = Task.get_many(
            company=company_id,
            query=Q(id__in=task_ids),
            allow_public=allow_public,
            override_projection=("id", "name", "company", "company_origin"),
            return_dicts=False,
        )
        if len(task_objs) < len(task_ids):
            # Report exactly which requested ids did not resolve
            invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
            raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
    task_name_by_id = {t.id: t.name for t in task_objs}
    # All tasks must index into the same company's event data
    companies = {t.get_index_company() for t in task_objs}
    if len(companies) > 1:
        raise errors.bad_request.InvalidTaskId(
            "only tasks from the same company are supported"
        )
    event_type = EventType.metrics_scalar
    # NOTE: company_id is deliberately rebound to the tasks' index company
    # (may differ from the caller's company for public tasks)
    company_id = next(iter(companies))
    if check_empty_data(self.es, company_id=company_id, event_type=event_type):
        return {}
    # Bind the per-task invariants once; parallelism happens across tasks
    # below, so the inner call runs serially (run_parallel=False)
    get_scalar_average_per_iter = partial(
        self._get_scalar_average_per_iter_core,
        company_id=company_id,
        event_type=event_type,
        samples=samples,
        key=ScalarKey.resolve(key),
        run_parallel=False,
    )
    with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
        # zip keeps each result paired with the task id it was computed for
        task_metrics = zip(task_ids, pool.map(get_scalar_average_per_iter, task_ids))
    # Pivot from per-task results into metric -> variant -> task id
    res = defaultdict(lambda: defaultdict(dict))
    for task_id, task_data in task_metrics:
        task_name = task_name_by_id[task_id]
        for metric_key, metric_data in task_data.items():
            for variant_key, variant_data in metric_data.items():
                variant_data["name"] = task_name
                res[metric_key][variant_key][task_id] = variant_data
    return res