Code example #1
File: task_bll.py Project: ainoam/trains-server
    def get_by_id(
        company_id,
        task_id,
        required_status=None,
        only_fields=None,
        allow_public=False,
    ):
        if only_fields:
            if isinstance(only_fields, string_types):
                only_fields = [only_fields]
            else:
                only_fields = list(only_fields)
            only_fields = only_fields + ["status"]

        with TimingContext("mongo", "task_by_id_all"):
            tasks = Task.get_many(
                company=company_id,
                query=Q(id=task_id),
                allow_public=allow_public,
                override_projection=only_fields,
                return_dicts=False,
            )
            task = None if not tasks else tasks[0]

        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if required_status and task.status != required_status:
            raise errors.bad_request.InvalidTaskStatus(
                expected=required_status)

        return task
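A minimal usage sketch for the method above, assuming it is exposed on a TaskBLL class in task_bll.py and called from a request handler; the class name, the status value, and the field list below are assumptions for illustration, not part of the excerpt:

# Hypothetical caller: fetch one task with a restricted projection and verify its status.
task = TaskBLL.get_by_id(
    company_id=company_id,            # company of the calling user
    task_id=request_task_id,          # id taken from the API request (hypothetical)
    required_status="created",        # assumed status value; raises InvalidTaskStatus on mismatch
    only_fields=("name",),            # "status" is appended automatically by the method
    allow_public=True,
)
print(task.name)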
Code example #2
File: task_bll.py Project: ainoam/trains-server
    def assert_exists(company_id,
                      task_ids,
                      only=None,
                      allow_public=False,
                      return_tasks=True) -> Optional[Sequence[Task]]:
        task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
        with translate_errors_context(), TimingContext("mongo", "task_exists"):
            ids = set(task_ids)
            q = Task.get_many(
                company=company_id,
                query=Q(id__in=ids),
                allow_public=allow_public,
                return_dicts=False,
            )
            if only:
                # Make sure to reset field filters (some fields are excluded by default),
                # since this is an internal call and specific fields were requested.
                q = q.all_fields().only(*only)

            if q.count() != len(ids):
                raise errors.bad_request.InvalidTaskId(ids=task_ids)

            if return_tasks:
                return list(q)
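A similar hedged sketch for assert_exists, again assuming a TaskBLL class and made-up task ids; per the code above it raises InvalidTaskId if any id is missing and returns the matching Task objects when return_tasks is True:

# Validate that all referenced tasks exist, fetching only the fields needed afterwards.
tasks = TaskBLL.assert_exists(
    company_id,
    task_ids=["aaaa1111", "bbbb2222"],   # hypothetical ids to validate
    only=("id", "status"),               # restrict the projection to the fields needed
    return_tasks=True,
)
status_by_id = {t.id: t.status for t in tasks}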
Code example #3
def get_all(call: APICall, company_id, _):
    conform_tag_fields(call, call.data)

    escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_all"):
            tasks = Task.get_many(
                company=company_id,
                parameters=call.data,
                query_dict=call.data,
                allow_public=True,  # required in case projection is requested for public dataset/versions
            )
        unprepare_from_saved(call, tasks)
        call.result.data = {"tasks": tasks}
Code example #4
def get_all(call: APICall, company_id, _):
    conform_tag_fields(call, call.data)

    call_data = escape_execution_parameters(call)

    with TimingContext("mongo", "task_get_all"):
        ret_params = {}
        tasks = Task.get_many(
            company=company_id,
            parameters=call_data,
            query_dict=call_data,
            allow_public=True,
            ret_params=ret_params,
        )
    unprepare_from_saved(call, tasks)
    call.result.data = {"tasks": tasks, **ret_params}
Code example #5
    def compare_scalar_metrics_average_per_iter(
        self,
        company_id,
        task_ids: Sequence[str],
        samples,
        key: ScalarKeyEnum,
        allow_public=True,
    ):
        """
        Compare scalar metrics for different tasks per metric and variant
        The amount of points in each histogram should not exceed the requested samples
        """
        task_name_by_id = {}
        with translate_errors_context():
            task_objs = Task.get_many(
                company=company_id,
                query=Q(id__in=task_ids),
                allow_public=allow_public,
                override_projection=("id", "name", "company",
                                     "company_origin"),
                return_dicts=False,
            )
            if len(task_objs) < len(task_ids):
                invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
                raise errors.bad_request.InvalidTaskId(company=company_id,
                                                       ids=invalid)
            task_name_by_id = {t.id: t.name for t in task_objs}

        companies = {t.get_index_company() for t in task_objs}
        if len(companies) > 1:
            raise errors.bad_request.InvalidTaskId(
                "only tasks from the same company are supported")

        event_type = EventType.metrics_scalar
        company_id = next(iter(companies))
        if check_empty_data(self.es,
                            company_id=company_id,
                            event_type=event_type):
            return {}

        get_scalar_average_per_iter = partial(
            self._get_scalar_average_per_iter_core,
            company_id=company_id,
            event_type=event_type,
            samples=samples,
            key=ScalarKey.resolve(key),
            run_parallel=False,
        )
        with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
            task_metrics = zip(task_ids,
                               pool.map(get_scalar_average_per_iter, task_ids))

        res = defaultdict(lambda: defaultdict(dict))
        for task_id, task_data in task_metrics:
            task_name = task_name_by_id[task_id]
            for metric_key, metric_data in task_data.items():
                for variant_key, variant_data in metric_data.items():
                    variant_data["name"] = task_name
                    res[metric_key][variant_key][task_id] = variant_data

        return res
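For orientation, the value returned by the loop above nests metric key, then variant key, then task id; a sketch of that shape with invented metric and variant names (the per-variant payload besides "name" comes from _get_scalar_average_per_iter_core and is not shown in this excerpt):

# Illustrative nesting only; keys and task names are invented examples.
res = {
    "loss": {                                          # metric key
        "total": {                                     # variant key
            "<task-id-1>": {"name": "experiment-a"},   # plus the sampled histogram fields
            "<task-id-2>": {"name": "experiment-b"},
        },
    },
}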