Exemplo n.º 1
0
def update(call: APICall, company_id, req_model: UpdateRequest):
    """
    Apply a partial update to a single task.

    Returns an UpdateResponse with the number of updated documents and the
    fields that were written (zero/empty when there was nothing to update).

    :raises errors.bad_request.InvalidTaskId: the task does not exist
    """
    task_id = req_model.task

    with translate_errors_context():
        # Minimal projection: we only need to confirm the task exists and is writable.
        existing = Task.get_for_writing(id=task_id, company=company_id, _only=["id"])
        if not existing:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        update_dict, _ = prepare_update_fields(call, existing, call.data)
        if not update_dict:
            # nothing to change
            return UpdateResponse(updated=0)

        count, fields = Task.safe_update(
            company_id=company_id,
            id=task_id,
            partial_update_dict=update_dict,
            injected_update=dict(last_update=datetime.utcnow()),
        )
        if count:
            _update_org_tags(company_id, fields)
            update_project_time(fields.get("project"))
        unprepare_from_saved(call, fields)
        return UpdateResponse(updated=count, fields=fields)
Exemplo n.º 2
0
def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
    """
    Move a task to the "queued" status and add it to a queue.

    If no queue is specified, the company's default queue is used.
    If adding the task to the queue fails, the status change is reverted.

    :raises errors.bad_request.InvalidTaskId: the task does not exist
    """
    task_id = req_model.task
    queue_id = req_model.queue
    status_message = req_model.status_message
    status_reason = req_model.status_reason

    if not queue_id:
        # try to get default queue
        queue_id = queue_bll.get_default(company_id).id

    with translate_errors_context():
        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            _only=("type", "script", "execution", "status", "project", "id"), **query
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        # Transition to "queued" first; disallowing same-state transition means
        # enqueueing an already-queued task fails at this point.
        res = EnqueueResponse(
            **ChangeStatusRequest(
                task=task,
                new_status=TaskStatus.queued,
                status_reason=status_reason,
                status_message=status_message,
                allow_same_state_transition=False,
            ).execute()
        )

        try:
            queue_bll.add_task(
                company_id=company_id, queue_id=queue_id, task_id=task.id
            )
        except Exception:
            # failed enqueueing, revert to previous state
            ChangeStatusRequest(
                task=task,
                current_status_override=TaskStatus.queued,
                new_status=task.status,
                force=True,
                status_reason="failed enqueueing",
            ).execute()
            raise

        # set the current queue ID in the task
        if task.execution:
            Task.objects(**query).update(execution__queue=queue_id, multi=False)
        else:
            Task.objects(**query).update(
                execution=Execution(queue=queue_id), multi=False
            )

        res.queued = 1
        res.fields.update(**{"execution.queue": queue_id})

        call.result.data_model = res
Exemplo n.º 3
0
    def clone_task(
        cls,
        company_id,
        user_id,
        task_id,
        name: Optional[str] = None,
        comment: Optional[str] = None,
        parent: Optional[str] = None,
        project: Optional[str] = None,
        tags: Optional[Sequence[str]] = None,
        system_tags: Optional[Sequence[str]] = None,
        execution_overrides: Optional[dict] = None,
    ) -> Task:
        """
        Create and save a copy of the task identified by task_id.

        Explicit arguments override the corresponding fields of the source
        task. System tags are NOT inherited: unless system_tags is given,
        the clone gets an empty list. Output-mode artifacts are stripped
        from the cloned execution, and parameter keys in execution_overrides
        are escaped before the override dict is deep-merged over the source
        task's execution.

        :return: the newly saved Task
        """
        task = cls.get_by_id(company_id=company_id, task_id=task_id)
        execution_dict = task.execution.to_proper_dict() if task.execution else {}
        if execution_overrides:
            parameters = execution_overrides.get("parameters")
            if parameters is not None:
                # escape parameter keys (presumably to make them safe as
                # mongo field names - see ParameterKeyEscaper)
                execution_overrides["parameters"] = {
                    ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
                }
            execution_dict = deep_merge(execution_dict, execution_overrides)
        artifacts = execution_dict.get("artifacts")
        if artifacts:
            # output artifacts belong to the source task's run, not the clone
            execution_dict["artifacts"] = [
                a for a in artifacts if a.get("mode") != ArtifactModes.output
            ]
        now = datetime.utcnow()

        with translate_errors_context():
            new_task = Task(
                id=create_id(),
                user=user_id,
                company=company_id,
                created=now,
                last_update=now,
                name=name or task.name,
                comment=comment or task.comment,
                parent=parent or task.parent,
                project=project or task.project,
                tags=tags or task.tags,
                system_tags=system_tags or [],
                type=task.type,
                script=task.script,
                output=Output(destination=task.output.destination)
                if task.output
                else None,
                execution=execution_dict,
            )
            cls.validate(new_task)
            new_task.save()

        return new_task
Exemplo n.º 4
0
def update_batch(call: APICall, company_id, _):
    """
    Update multiple tasks in a single bulk write.

    Each batched item must carry a "task" ID plus the fields to update.
    Cached project tags are reset for any project whose tasks changed
    project / tags / system_tags.

    :raises errors.bad_request.BatchContainsNoItems: no items were sent
    :raises errors.bad_request.InvalidTaskId: a referenced task is missing
    """
    items = call.batched_data
    if items is None:
        raise errors.bad_request.BatchContainsNoItems()

    with translate_errors_context():
        items = {i["task"]: i for i in items}
        tasks = {
            t.id: t
            for t in Task.get_many_for_writing(company=company_id,
                                               query=Q(id__in=list(items)))
        }

        if len(tasks) < len(items):
            missing = tuple(set(items).difference(tasks))
            raise errors.bad_request.InvalidTaskId(ids=missing)

        now = datetime.utcnow()

        bulk_ops = []
        updated_projects = set()
        # `task_id` (not `id`) to avoid shadowing the builtin
        for task_id, data in items.items():
            task = tasks[task_id]
            fields, valid_fields = prepare_update_fields(call, task, data)
            partial_update_dict = Task.get_safe_update_dict(fields)
            if not partial_update_dict:
                continue
            partial_update_dict.update(last_update=now)
            update_op = UpdateOne({
                "_id": task_id,
                "company": company_id
            }, {"$set": partial_update_dict})
            bulk_ops.append(update_op)

            # track projects whose cached tags must be invalidated
            new_project = partial_update_dict.get("project", task.project)
            if new_project != task.project:
                updated_projects.update({new_project, task.project})
            elif any(f in partial_update_dict
                     for f in ("tags", "system_tags")):
                updated_projects.add(task.project)

        updated = 0
        if bulk_ops:
            res = Task._get_collection().bulk_write(bulk_ops)
            updated = res.modified_count

        if updated and updated_projects:
            _reset_cached_tags(company_id, projects=list(updated_projects))

        call.result.data = {"updated": updated}
Exemplo n.º 5
0
def get_by_task_id(call):
    """
    Return the output model of the task specified in the call data.

    The model may belong to the caller's company or be a public model.

    :raises errors.bad_request.InvalidTaskId: task not found
    :raises errors.bad_request.MissingTaskFields: task has no output model
    :raises errors.bad_request.InvalidModelId: model not found
    """
    assert isinstance(call, APICall)
    task_id = call.data["task"]

    with translate_errors_context():
        query = dict(id=task_id, company=call.identity.company)
        # use distinct names for the task and the model instead of reusing
        # a single `res` variable for two unrelated documents
        task = Task.get(_only=["output"], **query)
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)
        if not task.output:
            raise errors.bad_request.MissingTaskFields(field="output")
        if not task.output.model:
            raise errors.bad_request.MissingTaskFields(field="output.model")

        model_id = task.output.model
        model = Model.objects(
            Q(id=model_id)
            & get_company_or_none_constraint(call.identity.company)).first()
        if not model:
            raise errors.bad_request.InvalidModelId(
                "no such public or company model",
                id=model_id,
                company=call.identity.company,
            )
        call.result.data = {"model": model.to_proper_dict()}
Exemplo n.º 6
0
 def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
     """
     Build content array that contains fixed tasks from the passed file
     For each task the old execution.parameters and model.design are
     converted to the new structure.
     The fix is done on Task objects (not the dictionary) so that
     the fields are serialized back in the same order as they were in the original file
     """
     with BytesIO() as temp:
         with cls.JsonLinesWriter(temp) as w:
             for line in cls.json_lines(f):
                 # round-trip through a Task object so field order is preserved
                 task_data = Task.from_json(line).to_proper_dict()
                 cls._upgrade_task_data(task_data)
                 new_task = Task(**task_data)
                 w.write(new_task.to_json())
         # the upgraded file content, one JSON line per task
         return temp.getvalue()
Exemplo n.º 7
0
    def get_by_id(
        company_id,
        task_id,
        required_status=None,
        required_dataset=None,
        only_fields=None,
    ):
        """
        Fetch a single task by ID, optionally validating its status and
        input dataset.

        :param required_status: when given, the task must be in this status
        :param required_dataset: when given, this dataset must appear in the
            task's input view entries
        :param only_fields: a field name or sequence of field names to project
        :raises errors.bad_request.InvalidTaskId: task not found
        :raises errors.bad_request.InvalidTaskStatus: status mismatch
        :raises errors.bad_request.InvalidId: dataset not in the input view
        """
        with TimingContext("mongo", "task_by_id_all"):
            qs = Task.objects(id=task_id, company=company_id)
            if only_fields:
                qs = (qs.only(only_fields) if isinstance(
                    only_fields, string_types) else qs.only(*only_fields))
                qs = qs.only(
                    "status", "input"
                )  # make sure all fields we rely on here are also returned
            task = qs.first()

        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        # `!=` instead of `not ... ==` for readability
        if required_status and task.status != required_status:
            raise errors.bad_request.InvalidTaskStatus(
                expected=required_status)

        if required_dataset and required_dataset not in (
                entry.dataset for entry in task.input.view.entries):
            raise errors.bad_request.InvalidId("not in input view",
                                               dataset=required_dataset)

        return task
Exemplo n.º 8
0
    def get_by_id(
        company_id,
        task_id,
        required_status=None,
        only_fields=None,
        allow_public=False,
    ):
        """
        Fetch a single task by ID, optionally validating its status.

        :param only_fields: a field name or sequence of field names to
            project; "status" is always added since it is checked below
        :param allow_public: also match public tasks
        :raises errors.bad_request.InvalidTaskId: task not found
        :raises errors.bad_request.InvalidTaskStatus: status mismatch
        """
        if only_fields:
            if isinstance(only_fields, string_types):
                only_fields = [only_fields]
            else:
                only_fields = list(only_fields)
            only_fields = only_fields + ["status"]

        with TimingContext("mongo", "task_by_id_all"):
            tasks = Task.get_many(
                company=company_id,
                query=Q(id=task_id),
                allow_public=allow_public,
                override_projection=only_fields,
                return_dicts=False,
            )
            task = None if not tasks else tasks[0]

        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        # `!=` instead of `not ... ==` for readability
        if required_status and task.status != required_status:
            raise errors.bad_request.InvalidTaskStatus(
                expected=required_status)

        return task
Exemplo n.º 9
0
    def assert_exists(company_id,
                      task_ids,
                      only=None,
                      allow_public=False,
                      return_tasks=True) -> Optional[Sequence[Task]]:
        """
        Assert that all the given task IDs exist for the company.

        :param task_ids: a single task ID or a sequence of IDs
        :param only: optional field names to project; when given, the lazy
            query result (not a materialized list) is returned
        :param allow_public: also match public tasks
        :param return_tasks: when True return the matching tasks, otherwise
            only perform the existence check and return None
        :raises errors.bad_request.InvalidTaskId: some ID does not exist
        """
        task_ids = [task_ids] if isinstance(task_ids,
                                            six.string_types) else task_ids
        with translate_errors_context(), TimingContext("mongo", "task_exists"):
            ids = set(task_ids)
            q = Task.get_many(
                company=company_id,
                query=Q(id__in=ids),
                allow_public=allow_public,
                return_dicts=False,
            )
            res = None
            if only:
                res = q.only(*only)
            elif return_tasks:
                res = list(q)

            # when no documents are needed, count server-side instead of fetching
            count = len(res) if res is not None else q.count()
            if count != len(ids):
                raise errors.bad_request.InvalidTaskId(ids=task_ids)

            if return_tasks:
                return res
Exemplo n.º 10
0
    def _resolve_entities(
        cls, experiments: List[str] = None, projects: List[str] = None
    ) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
        """
        Resolve the given experiment and project IDs into document objects,
        pulling in each project's experiments and each experiment's project.

        :return: mapping of document class (Project/Task) to the set of
            resolved documents of that class
        """
        from database.model.project import Project
        from database.model.task.task import Task

        entities = defaultdict(set)

        if projects:
            print("Reading projects...")
            entities[Project].update(cls._resolve_type(Project, projects))
            print("--> Reading project experiments...")
            objs = Task.objects(
                project__in=list(set(filter(None, (p.id for p in entities[Project]))))
            )
            # hoist a set for O(1) membership tests instead of scanning a
            # list once per task
            explicit_experiments = set(experiments or ())
            entities[Task].update(o for o in objs if o.id not in explicit_experiments)

        if experiments:
            print("Reading experiments...")
            entities[Task].update(cls._resolve_type(Task, experiments))
            print("--> Reading experiments projects...")
            objs = Project.objects(
                id__in=list(set(filter(None, (p.project for p in entities[Task]))))
            )
            project_ids = {p.id for p in entities[Project]}
            entities[Project].update(o for o in objs if o.id not in project_ids)

        return entities
Exemplo n.º 11
0
def get_by_task_id(call: APICall, company_id, _):
    """
    Return the output model of the given task (a company or public model),
    with output tags conformed for the caller.
    """
    task_id = call.data["task"]

    with translate_errors_context():
        task_query = dict(id=task_id, company=company_id)
        task = Task.get(_only=["output"], **task_query)
        if not task:
            raise errors.bad_request.InvalidTaskId(**task_query)

        output = task.output
        if not output:
            raise errors.bad_request.MissingTaskFields(field="output")
        model_id = output.model
        if not model_id:
            raise errors.bad_request.MissingTaskFields(field="output.model")

        model = Model.objects(
            Q(id=model_id) & get_company_or_none_constraint(company_id)
        ).first()
        if not model:
            raise errors.bad_request.InvalidModelId(
                "no such public or company model",
                id=model_id,
                company=company_id,
            )

        model_dict = model.to_proper_dict()
        conform_output_tags(call, model_dict)
        call.result.data = {"model": model_dict}
Exemplo n.º 12
0
def get_outputs_for_deletion(task, force=False):
    """
    Collect the models and child tasks that must be handled when deleting
    a task.

    :param force: when False, raise if the task has published output models
        or published children
    :return: (TaskOutputs of models split by readiness, child-task queryset)
    :raises errors.bad_request.TaskCannotBeDeleted: published outputs or
        children exist and force is not set
    """
    with TimingContext("mongo", "get_task_models"):
        models = TaskOutputs(
            attrgetter("ready"),
            Model,
            Model.objects(task=task.id).only("id", "task", "ready"),
        )
        if not force and models.published:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has output models, use force=True",
                task=task.id,
                models=len(models.published),
            )

    # guard `task.output` as well: the sibling handlers treat output as
    # optional, so dereferencing .model directly could raise AttributeError
    if task.output and task.output.model:
        output_model = get_output_model(task, force)
        if output_model:
            if output_model.ready:
                models.published.append(output_model)
            else:
                models.draft.append(output_model)

    with TimingContext("mongo", "get_task_children"):
        tasks = Task.objects(parent=task.id).only("id", "parent", "status")
        # `child`, not `task`, to avoid shadowing the function parameter
        published_tasks = [
            child for child in tasks if child.status == TaskStatus.published
        ]
        if not force and published_tasks:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has children, use force=True",
                task=task.id,
                children=published_tasks)
    return models, tasks
Exemplo n.º 13
0
    def model_set_ready(
        cls,
        model_id: str,
        company_id: str,
        publish_task: bool,
        force_publish_task: bool = False,
    ) -> tuple:
        """
        Mark a model as ready (published).

        :param publish_task: when True, also publish the model's creating
            task if it is not already published
        :param force_publish_task: force-publish the creating task
        :return: (number of updated model documents, dict with the published
            task's "id" and "data" - empty if no task was published)
        :raises errors.bad_request.InvalidModelId: the model does not exist
        :raises errors.bad_request.ModelIsReady: the model is already ready
        """
        with translate_errors_context():
            query = dict(id=model_id, company=company_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
            elif model.ready:
                raise errors.bad_request.ModelIsReady(**query)

            published_task_data = {}
            if model.task and publish_task:
                # publish the creating task first, so the model is never
                # ready while its task is still unpublished
                task = (Task.objects(id=model.task, company=company_id).only(
                    "id", "status").first())
                if task and task.status != TaskStatus.published:
                    published_task_data["data"] = cls.publish_task(
                        task_id=model.task,
                        company_id=company_id,
                        publish_model=False,
                        force=force_publish_task,
                    )
                    published_task_data["id"] = model.task

            updated = model.update(upsert=False, ready=True)
            return updated, published_task_data
Exemplo n.º 14
0
    def validate(
        cls,
        task: Task,
        validate_model=True,
        validate_parent=True,
        validate_project=True,
    ):
        """
        Validate a task's references: parent task, project and execution
        model. Each check can be toggled off via the corresponding flag.

        :raises errors.bad_request.InvalidTaskId: parent does not exist
        :raises errors.bad_request.InvalidProjectId: project is not writable
        """
        if validate_parent and task.parent:
            parent = Task.get(
                company=task.company, id=task.parent, _only=("id",), include_public=True
            )
            if not parent:
                raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)

        if validate_project and task.project:
            project = Project.get_for_writing(company=task.company, id=task.project)
            if not project:
                raise errors.bad_request.InvalidProjectId(id=task.project)

        if validate_model:
            cls.validate_execution_model(task)
Exemplo n.º 15
0
    def validate(cls, task: Task, force=False):
        """
        Validate a task before saving: parent task, project, execution
        model (and its readiness unless force), execution parameters and
        output destination scheme.
        """
        assert isinstance(task, Task)

        if task.parent:
            parent = Task.get(
                company=task.company,
                id=task.parent,
                _only=("id",),
                include_public=True,
            )
            if not parent:
                raise errors.bad_request.InvalidTaskId("invalid parent",
                                                       parent=task.parent)

        if task.project:
            Project.get_for_writing(company=task.company, id=task.project)

        model = cls.validate_execution_model(task)
        if model and not force and not model.ready:
            raise errors.bad_request.ModelNotReady("can't be used in a task",
                                                   model=model.id)

        if task.execution and task.execution.parameters:
            cls._validate_execution_parameters(task.execution.parameters)

        if task.output and task.output.destination:
            scheme = urlparse(task.output.destination).scheme
            if scheme not in OutputDestinationField.schemes:
                raise errors.bad_request.FieldsValueError(
                    "unsupported scheme for output destination",
                    dest=task.output.destination,
                )
    def cleanup_tasks(cls, threshold_sec):
        """
        Force-stop in-progress tasks whose last update is older than
        threshold_sec seconds.

        :return: the number of tasks successfully stopped
        """
        ref_time = datetime.utcnow() - timedelta(seconds=threshold_sec)
        log.info(
            f"Starting cleanup cycle for running tasks last updated before {ref_time}"
        )

        stale_tasks = list(
            Task.objects(
                status__in=(TaskStatus.in_progress,), last_update__lt=ref_time
            ).only("id", "name", "status", "project", "last_update")
        )
        log.info(f"{len(stale_tasks)} non-responsive tasks found")
        if not stale_tasks:
            return 0

        stopped = 0
        for task in stale_tasks:
            log.info(
                f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
            )
            try:
                ChangeStatusRequest(
                    task=task,
                    new_status=TaskStatus.stopped,
                    status_reason="Forced stop (non-responsive)",
                    status_message="Forced stop (non-responsive)",
                    force=True,
                ).execute()
            except errors.bad_request.FailedChangingTaskStatus:
                # count successes instead of failures; any other exception
                # propagates, as before
                continue
            stopped += 1

        return stopped
Exemplo n.º 17
0
    def update_statistics(
        task_id: str,
        company_id: str,
        last_update: datetime = None,
        last_iteration: int = None,
        last_iteration_max: int = None,
        last_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
        **extra_updates,
    ):
        """
        Update task statistics
        :param task_id: Task's ID.
        :param company_id: Task's company ID.
        :param last_update: Last update time. If not provided, defaults to datetime.utcnow().
        :param last_iteration: Last reported iteration. Use this to set a value regardless of current
            task's last iteration value.
        :param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
            if the current task's last iteration value is smaller than the provided value.
        :param last_values: Last reported metrics summary (value, metric, variant).
        :param extra_updates: Extra task updates to include in this update call.
        :return:
        """
        if last_update is None:
            last_update = datetime.utcnow()

        if last_iteration is not None:
            # unconditional set
            extra_updates["last_iteration"] = last_iteration
        elif last_iteration_max is not None:
            # conditional: only raises the stored value
            extra_updates["max__last_iteration"] = last_iteration_max

        if last_values is not None:

            def metrics_key(op, *path):
                # e.g. ("set", "m", "v", "value") -> "set__last_metrics__m__v__value"
                return "__".join((op, "last_metrics") + path)

            for path, value in last_values:
                extra_updates[metrics_key("set", *path)] = value
                if path[-1] == "value":
                    # track running min/max alongside the latest value
                    extra_updates[metrics_key("min", *path[:-1], "min_value")] = value
                    extra_updates[metrics_key("max", *path[:-1], "max_value")] = value

        Task.objects(id=task_id, company=company_id).update(
            upsert=False, last_update=last_update, **extra_updates
        )
Exemplo n.º 18
0
def edit(call: APICall, company_id, req_model: UpdateRequest):
    """
    Edit a task's creation fields in place.

    Unless force is set, only tasks in "created" status may be edited.
    Dict values for embedded-document fields are merged over the task's
    current values rather than replacing them wholesale.

    :raises errors.bad_request.InvalidTaskId: the task does not exist
    :raises errors.bad_request.InvalidTaskStatus: task is not editable
    """
    task_id = req_model.task
    force = req_model.force

    with translate_errors_context():
        task = Task.get_for_writing(id=task_id, company=company_id)
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if not force and task.status != TaskStatus.created:
            raise errors.bad_request.InvalidTaskStatus(
                expected=TaskStatus.created, status=task.status)

        # editable fields are the creation fields plus "status"
        edit_fields = create_fields.copy()
        edit_fields.update(dict(status=None))

        with translate_errors_context(
                field_does_not_exist_cls=errors.bad_request.ValidationError
        ), TimingContext("code", "parse_and_validate"):
            fields = prepare_create_fields(call,
                                           valid_fields=edit_fields,
                                           output=task.output,
                                           previous_task=task)

        # merge dict values into existing embedded documents instead of
        # overwriting the embedded document as a whole
        for key in fields:
            field = getattr(task, key, None)
            value = fields[key]
            if (field and isinstance(value, dict)
                    and isinstance(field, EmbeddedDocument)):
                d = field.to_mongo(use_db_field=False).to_dict()
                d.update(value)
                fields[key] = d

        # validate the resulting task before writing anything
        task_bll.validate(task_bll.create(call, fields))

        # make sure field names do not end in mongoengine comparison operators
        fixed_fields = {(k if k not in COMPARISON_OPERATORS else "%s__" % k): v
                        for k, v in fields.items()}
        if fixed_fields:
            now = datetime.utcnow()
            fields.update(last_update=now)
            fixed_fields.update(last_update=now)
            updated = task.update(upsert=False, **fixed_fields)
            if updated:
                # keep per-project tag caches in sync with the change
                new_project = fixed_fields.get("project", task.project)
                if new_project != task.project:
                    _reset_cached_tags(company_id,
                                       projects=[new_project, task.project])
                else:
                    _update_cached_tags(company_id,
                                        project=task.project,
                                        fields=fixed_fields)
                update_project_time(fields.get("project"))
            unprepare_from_saved(call, fields)
            call.result.data_model = UpdateResponse(updated=updated,
                                                    fields=fields)
        else:
            call.result.data_model = UpdateResponse(updated=0)
Exemplo n.º 19
0
    def status_report(
        self, company_id: str, user_id: str, ip: str, report: StatusReportRequest
    ) -> None:
        """
        Write worker status report
        :param company_id: worker's company ID
        :param user_id: user_id ID under which this worker is running
        :param ip: the worker's IP address as seen by the server
        :param report: the status report payload (worker, queues, task, machine stats)
        :raise bad_request.InvalidTaskId: the reported task was not found
        :return: None - the updated worker entry is saved, not returned
        """
        entry = self._get_worker(company_id, user_id, report.worker)

        try:
            entry.ip = ip
            now = datetime.utcnow()
            entry.last_activity_time = now

            if report.machine_stats:
                self._log_stats_to_es(
                    company_id=company_id,
                    company_name=entry.company.name,
                    worker=report.worker,
                    timestamp=report.timestamp,
                    task=report.task,
                    machine_stats=report.machine_stats,
                )

            entry.queue = report.queue

            if report.queues:
                entry.queues = report.queues

            if not report.task:
                entry.task = None
            else:
                with translate_errors_context():
                    query = dict(id=report.task, company=company_id)
                    update = dict(
                        last_worker=report.worker,
                        last_worker_report=now,
                        last_update=now,
                    )
                    # modify(new=True, ...) returns the modified object
                    task = Task.objects(**query).modify(new=True, **update)
                    if not task:
                        raise bad_request.InvalidTaskId(**query)
                    entry.task = IdNameEntry(id=task.id, name=task.name)

            entry.last_report_time = now
        except APIError:
            raise
        except Exception as e:
            msg = "Failed processing worker status report"
            log.exception(msg)
            raise server_error.DataError(msg, err=e.args[0])
        finally:
            # persist the worker entry even when processing the report failed
            self._save_worker(entry)
Exemplo n.º 20
0
    def compare_scalar_metrics_average_per_iter(
        self,
        company_id,
        task_ids: Sequence[str],
        samples,
        key: ScalarKeyEnum,
        allow_public=True,
    ):
        """
        Compare scalar metrics for different tasks per metric and variant.
        The amount of points in each histogram should not exceed the
        requested samples.

        :return: res[metric][variant][task_id] -> histogram data tagged with
            the task name (empty dict when no stats index exists)
        :raises errors.bad_request.InvalidTaskId: some task is missing, or
            the tasks span more than one company
        """
        with translate_errors_context():
            task_objs = Task.get_many(
                company=company_id,
                query=Q(id__in=task_ids),
                allow_public=allow_public,
                override_projection=("id", "name", "company"),
                return_dicts=False,
            )
            if len(task_objs) < len(task_ids):
                invalid = tuple(set(task_ids) - {t.id for t in task_objs})
                raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
            # removed the dead `task_name_by_id = {}` pre-initialization:
            # the dict is always assigned here before any use
            task_name_by_id = {t.id: t.name for t in task_objs}

        # cross-company comparison is not supported
        companies = {t.company for t in task_objs}
        if len(companies) > 1:
            raise errors.bad_request.InvalidTaskId(
                "only tasks from the same company are supported"
            )

        es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
        if not self.es.indices.exists(es_index):
            # no scalar stats were ever reported for this company
            return {}

        get_scalar_average_per_iter = partial(
            self._get_scalar_average_per_iter_core,
            es_index=es_index,
            samples=samples,
            key=ScalarKey.resolve(key),
            run_parallel=False,
        )
        # fetch all task histograms concurrently
        with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
            task_metrics = zip(
                task_ids, pool.map(get_scalar_average_per_iter, task_ids)
            )

        res = defaultdict(lambda: defaultdict(dict))
        for task_id, task_data in task_metrics:
            task_name = task_name_by_id[task_id]
            for metric_key, metric_data in task_data.items():
                for variant_key, variant_data in metric_data.items():
                    variant_data["name"] = task_name
                    res[metric_key][variant_key][task_id] = variant_data

        return res
Exemplo n.º 21
0
 def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
     """
     Return the set of unique task types used by company and public tasks.

     When project ids are passed, only tasks from those projects are
     considered.
     """
     constraint = get_company_or_none_constraint(company)
     if project_ids:
         constraint &= Q(project__in=project_ids)
     found_types = Task.objects(constraint).distinct(field="type")
     # keep only types that are exposed externally
     return set(found_types).intersection(external_task_types)
Exemplo n.º 22
0
def update(call: APICall, company_id, _):
    """
    Delete a model. If the model is referenced by tasks (as an execution model
    or as the output of a published task), deletion requires force=True and the
    references are replaced with a tombstone id.
    """
    model_id = call.data["model"]
    force = call.data.get("force", False)

    with translate_errors_context():
        query = dict(id=model_id, company=company_id)
        model = Model.objects(**query).only("id", "task", "project").first()
        if model is None:
            raise errors.bad_request.InvalidModelId(**query)

        deleted_model_id = f"__DELETED__{model_id}"

        # tasks that use this model as their execution model
        dependent_tasks = Task.objects(execution__model=model_id).only("id")
        if dependent_tasks:
            if not force:
                raise errors.bad_request.ModelInUse(
                    "as execution model, use force=True to delete",
                    num_tasks=len(dependent_tasks),
                )
            # replace the reference with a tombstone id in all dependent tasks
            dependent_tasks.update(
                execution__model=deleted_model_id, upsert=False, multi=True
            )

        if model.task:
            # the task that created this model may still reference it as output
            creating_task = Task.objects(id=model.task).first()
            if creating_task and creating_task.status == TaskStatus.published:
                if not force:
                    raise errors.bad_request.ModelCreatingTaskExists(
                        "and published, use force=True to delete", task=model.task
                    )
                creating_task.update(
                    output__model=deleted_model_id,
                    output__error=f"model deleted on {datetime.utcnow().isoformat()}",
                    upsert=False,
                )

        deleted = Model.objects(**query).delete()
        if deleted:
            # project tag caches may now be stale
            _reset_cached_tags(company_id, projects=[model.project])
        call.result.data = dict(deleted=deleted > 0)
Exemplo n.º 23
0
    def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
        """Return the subset of task_ids that exist and can be updated"""
        if not task_ids:
            return set()

        with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
            constraint = Q(id__in=task_ids, company=company_id)
            if not allow_locked_tasks:
                # locked tasks (e.g. published ones) must not be modified
                constraint &= Q(status__nin=LOCKED_TASK_STATUSES)
            return {task.id for task in Task.objects(constraint).only("id")}
Exemplo n.º 24
0
 def create(call: APICall, fields: dict):
     """Build (without saving) a new Task owned by the calling identity."""
     identity = call.identity
     timestamp = datetime.utcnow()
     return Task(
         id=create_id(),
         user=identity.user,
         company=identity.company,
         created=timestamp,
         last_update=timestamp,
         **fields,
     )
Exemplo n.º 25
0
    def get_task_with_access(
        task_id, company_id, only=None, allow_public=False, requires_write_access=False
    ) -> Task:
        """
        Fetch a task, optionally requiring write access to it.
        :except errors.bad_request.InvalidTaskId: if the task is not found
        :except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
        """
        query = dict(id=task_id, company=company_id)
        with translate_errors_context():
            with TimingContext("mongo", "task_with_access"):
                task = (
                    Task.get_for_writing(_only=only, **query)
                    if requires_write_access
                    else Task.get(_only=only, **query, include_public=allow_public)
                )

            if task:
                return task
            raise errors.bad_request.InvalidTaskId(**query)
Exemplo n.º 26
0
    def compare_scalar_metrics_average_per_iter(
        self,
        company_id,
        task_ids: Sequence[str],
        samples,
        key: ScalarKeyEnum,
        allow_public=True,
    ):
        """
        Compare scalar metrics of several tasks, per metric and variant.
        The amount of points in each histogram should not exceed the requested samples
        """
        if len(task_ids) > self.MAX_TASKS_COUNT:
            raise errors.BadRequest(
                f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
                len(task_ids),
            )

        with translate_errors_context():
            tasks = Task.get_many(
                company=company_id,
                query=Q(id__in=task_ids),
                allow_public=allow_public,
                override_projection=("id", "name", "company"),
                return_dicts=False,
            )
            if len(tasks) < len(task_ids):
                # at least one requested id did not resolve to an accessible task
                missing = tuple(set(task_ids) - {t.id for t in tasks})
                raise errors.bad_request.InvalidTaskId(company=company_id, ids=missing)

            names_by_id = {t.id: t.name for t in tasks}

        task_companies = {t.company for t in tasks}
        if len(task_companies) > 1:
            # mixing tasks that belong to different companies is not supported
            raise errors.bad_request.InvalidTaskId(
                "only tasks from the same company are supported"
            )

        result = self._run_get_scalar_metrics_as_parallel(
            next(iter(task_companies)),
            task_ids=task_ids,
            samples=samples,
            key=ScalarKey.resolve(key),
            get_func=self._get_scalar_average_per_task,
        )

        # attach the task name to every per-task entry in the comparison
        for metric_data in result.values():
            for variant_data in metric_data.values():
                for task_id, task_data in variant_data.items():
                    task_data["name"] = names_by_id[task_id]

        return result
Exemplo n.º 27
0
def get_all_ex(call: APICall):
    """Return the tasks matching the call's query, with joined data."""
    with translate_errors_context(), TimingContext("mongo", "task_get_all_ex"):
        tasks = Task.get_many_with_join(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            # required in case projection is requested for public dataset/versions
            allow_public=True,
        )

    call.result.data = {"tasks": tasks}
Exemplo n.º 28
0
    def get_aggregated_project_execution_parameters(
        company_id,
        project_ids: Sequence[str] = None,
        page: int = 0,
        page_size: int = 500,
    ) -> Tuple[int, int, Sequence[str]]:
        """
        Aggregate the distinct execution parameter names used by the company's
        tasks (optionally restricted to the given projects), sorted and paged.
        Returns (total, remaining count after this page, names for the page).
        """
        page = max(0, page)
        page_size = max(1, page_size)

        pipeline = [
            # only tasks of this company that actually have (non-empty) parameters
            {
                "$match": {
                    "company": company_id,
                    "execution.parameters": {"$exists": True, "$gt": {}},
                    **({"project": {"$in": project_ids}} if project_ids else {}),
                }
            },
            # turn the parameters dict into an array and emit one doc per key
            {"$project": {"parameters": {"$objectToArray": "$execution.parameters"}}},
            {"$unwind": "$parameters"},
            {"$group": {"_id": "$parameters.k"}},
            {"$sort": {"_id": 1}},
            # collapse into a single doc carrying the total and the sorted keys
            {
                "$group": {
                    "_id": 1,
                    "total": {"$sum": 1},
                    "results": {"$push": "$$ROOT"},
                }
            },
            # slice out the requested page
            {
                "$project": {
                    "total": 1,
                    "results": {"$slice": ["$results", page * page_size, page_size]},
                }
            },
        ]

        with translate_errors_context():
            result = next(Task.aggregate(*pipeline), None)

        if not result:
            return 0, 0, []

        total = int(result.get("total", -1))
        # parameter keys are stored escaped for mongo; unescape for the caller
        page_params = [
            ParameterKeyEscaper.unescape(doc["_id"])
            for doc in result.get("results", [])
        ]
        remaining = max(0, total - (len(page_params) + page * page_size))
        return total, remaining, page_params
Exemplo n.º 29
0
def update_batch(call: APICall):
    """
    Update multiple tasks in a single mongo bulk write.

    Each batched item must carry a "task" id plus the fields to update.
    :raises errors.bad_request.BatchContainsNoItems: if no items were sent
    :raises errors.bad_request.InvalidTaskId: if any task id cannot be written
    Responds with the number of tasks actually modified.
    """
    identity = call.identity

    items = call.batched_data
    if items is None:
        raise errors.bad_request.BatchContainsNoItems()

    with translate_errors_context():
        items = {i["task"]: i for i in items}
        tasks = {
            t.id: t
            for t in Task.get_many_for_writing(
                company=identity.company, query=Q(id__in=list(items))
            )
        }

        if len(tasks) < len(items):
            missing = tuple(set(items).difference(tasks))
            raise errors.bad_request.InvalidTaskId(ids=missing)

        now = datetime.utcnow()

        bulk_ops = []
        # `task_id` instead of `id` to avoid shadowing the builtin
        for task_id, data in items.items():
            # valid_fields (second return value) is not needed here
            fields, _ = prepare_update_fields(call, tasks[task_id], data)
            partial_update_dict = Task.get_safe_update_dict(fields)
            if not partial_update_dict:
                continue
            partial_update_dict.update(last_update=now)
            bulk_ops.append(
                UpdateOne(
                    {"_id": task_id, "company": identity.company},
                    {"$set": partial_update_dict},
                )
            )

        updated = 0
        if bulk_ops:
            res = Task._get_collection().bulk_write(bulk_ops)
            updated = res.modified_count

        call.result.data = {"updated": updated}
Exemplo n.º 30
0
def params_prepare_for_save(fields: dict, previous_task: Task = None):
    """
    If legacy hyper params or configuration is passed then replace the corresponding section in the new structure
    Escape all the section and param names for hyper params and configuration to make it mongo safe
    """
    # Each tuple maps a legacy dotted path in `fields` to its new field name
    # and (for hyper params only) the default section for unsectioned params.
    for old_params_field, new_params_field, default_section in (
        ("execution/parameters", "hyperparams", hyperparams_default_section),
        ("execution/model_desc", "configuration", None),
    ):
        legacy_params = safe_get(fields, old_params_field)
        if legacy_params is None:
            # no legacy data sent for this field - nothing to migrate
            continue

        if (
            not safe_get(fields, new_params_field)
            and previous_task
            and previous_task[new_params_field]
        ):
            # The caller sent only legacy data but the task already holds data
            # in the new structure: start from the stored data with the
            # previously-migrated legacy entries stripped out.
            previous_data = previous_task.to_proper_dict().get(new_params_field)
            removed = _remove_legacy_params(
                previous_data, with_sections=default_section is not None
            )
            if not legacy_params and not removed:
                # if we only need to delete legacy fields from the db
                # but they are not there then there is no point to proceed
                continue

            fields_update = {new_params_field: previous_data}
            params_unprepare_from_saved(fields_update)
            fields.update(fields_update)

        # Convert every legacy "section/name" entry into a param dict in the
        # new structure (value is always stored as a string).
        for full_name, value in legacy_params.items():
            section, name = split_param_name(full_name, default_section)
            new_path = list(filter(None, (new_params_field, section, name)))
            new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
            if section is not None:
                new_param["section"] = section
            dpath.new(fields, new_path, new_param)
        # drop the legacy path now that its data lives in the new structure
        dpath.delete(fields, old_params_field)

    # Escape section and param names (one level deep) into mongo-safe keys
    for param_field in ("hyperparams", "configuration"):
        params = safe_get(fields, param_field)
        if params:
            escaped_params = {
                ParameterKeyEscaper.escape(key): {
                    ParameterKeyEscaper.escape(k): v for k, v in value.items()
                }
                if isinstance(value, dict)
                else value
                for key, value in params.items()
            }
            dpath.set(fields, param_field, escaped_params)