예제 #1
0
    def get_parent_tasks(
        cls,
        company_id: str,
        projects: Sequence[str],
        state: Optional[EntityVisibility] = None,
    ) -> Sequence[dict]:
        """
        Get the unique parent tasks of the company tasks, sorted by task name.

        :param company_id: the company whose tasks are scanned
        :param projects: restrict the scan to the tasks of these projects.
            If None or empty then the parents of all the company tasks are returned
        :param state: EntityVisibility.archived scans only archived tasks,
            EntityVisibility.active scans only non-archived tasks,
            any other value (including None) applies no archive filtering
        :return: the parent tasks with the "id", "name" and "project.name"
            fields, sorted by "name"
        """
        query = Q(company=company_id)
        if projects:
            query &= Q(project__in=projects)
        if state == EntityVisibility.archived:
            query &= Q(system_tags__in=[EntityVisibility.archived.value])
        elif state == EntityVisibility.active:
            query &= Q(system_tags__nin=[EntityVisibility.archived.value])

        parent_ids = set(Task.objects(query).distinct("parent"))
        # tasks without a parent contribute None to the distinct values;
        # drop it so that no lookup is issued for a non-existing task
        parent_ids.discard(None)
        if not parent_ids:
            return []

        parents = Task.get_many_with_join(
            company_id,
            query=Q(id__in=parent_ids),
            allow_public=True,
            override_projection=("id", "name", "project.name"),
        )
        return sorted(parents, key=itemgetter("name"))
예제 #2
0
def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
    """
    Delete the tasks of the passed projects together with their events.

    Only the tasks themselves are deleted here: child models under the same
    projects are deleted separately, and child tasks in the same projects are
    expected to be deleted by the same api call. Tasks and models located in
    other projects that reference one of the deleted tasks get that reference
    (parent / task) set to None.

    :return: tuple of (number of deleted tasks, debug image and plot image urls
        collected from the tasks events, uris of the tasks output artifacts)
    """
    project_tasks = Task.objects(project__in=projects).only("id", "execution__artifacts")
    if not project_tasks:
        return 0, set(), set()

    ids = {t.id for t in project_tasks}
    with TimingContext("mongo", "delete_tasks_update_children"):
        # detach the children that live outside of the deleted projects
        Task.objects(parent__in=ids, project__nin=projects).update(parent=None)
        Model.objects(task__in=ids, project__nin=projects).update(task=None)

    event_urls = set()
    artifact_urls = set()
    for t in project_tasks:
        event_urls.update(collect_debug_image_urls(company, t.id))
        event_urls.update(collect_plot_image_urls(company, t.id))
        artifacts = t.execution.artifacts if t.execution else None
        if not artifacts:
            continue
        artifact_urls.update(
            artifact.uri
            for artifact in artifacts.values()
            if artifact.mode == ArtifactModes.output and artifact.uri
        )

    event_bll.delete_multi_task_events(company, list(ids))
    deleted_count = project_tasks.delete()
    return deleted_count, event_urls, artifact_urls
예제 #3
0
def update(call: APICall, company_id, req_model: UpdateRequest):
    """
    Update the task fields passed in the call data.

    The tag caches of the affected projects are refreshed afterwards: if the
    task moved to another project then the caches of both the old and the new
    project are reset, otherwise the cache of the task project is updated with
    the changed fields.

    :raises errors.bad_request.InvalidTaskId: Task.get_for_writing found no task
    :return: UpdateResponse with the number of updated tasks and the updated fields
    """
    task_id = req_model.task

    with translate_errors_context():
        task = Task.get_for_writing(
            id=task_id, company=company_id, _only=["id", "project"]
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        update_dict, _ = prepare_update_fields(call, task, call.data)
        if not update_dict:
            return UpdateResponse(updated=0)

        updated_count, updated_fields = Task.safe_update(
            company_id=company_id,
            id=task_id,
            partial_update_dict=update_dict,
            injected_update=dict(last_change=datetime.utcnow()),
        )
        if updated_count:
            old_project = task.project
            new_project = updated_fields.get("project", old_project)
            if new_project == old_project:
                _update_cached_tags(
                    company_id, project=old_project, fields=updated_fields
                )
            else:
                _reset_cached_tags(company_id, projects=[new_project, old_project])
            update_project_time(updated_fields.get("project"))

        unprepare_from_saved(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
예제 #4
0
def verify_task_children_and_ouptuts(task: Task,
                                     force: bool) -> TaskOutputs[Model]:
    """
    Collect the task models and verify that the task can be deleted.

    :param task: the task about to be deleted
    :param force: if False then published children tasks or ready output
        models of the task prevent the deletion
    :raises errors.bad_request.TaskCannotBeDeleted: force is False and the task
        has published children tasks or ready output models
    :return: the task models grouped into "published" and "draft" groups.
        Draft models that any task references as an input model are left out
        of the draft group
    """
    if not force:
        with TimingContext("mongo", "count_published_children"):
            published_children_count = Task.objects(
                parent=task.id, status=TaskStatus.published).count()
            if published_children_count:
                raise errors.bad_request.TaskCannotBeDeleted(
                    "has children, use force=True",
                    task=task.id,
                    children=published_children_count,
                )

    # models whose "task" field references this task, grouped by their "ready"
    # flag (presumably ready -> published, otherwise draft, as done explicitly below)
    with TimingContext("mongo", "get_task_models"):
        models = TaskOutputs(
            attrgetter("ready"),
            Model,
            Model.objects(task=task.id).only("id", "task", "ready"),
        )
        if not force and models.published:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has output models, use force=True",
                task=task.id,
                models=len(models.published),
            )

    # also take into account the models listed in the task output section.
    # NOTE(review): such a model may already be in one of the groups above,
    # in which case it is appended a second time
    if task.models and task.models.output:
        with TimingContext("mongo", "get_task_output_model"):
            model_ids = [m.model for m in task.models.output]
            for output_model in Model.objects(id__in=model_ids):
                if output_model.ready:
                    if not force:
                        raise errors.bad_request.TaskCannotBeDeleted(
                            "has output model, use force=True",
                            task=task.id,
                            model=output_model.id,
                        )
                    models.published.append(output_model)
                else:
                    models.draft.append(output_model)

    # draft models that are still referenced as input by some task
    # (any task, including this one) are removed from the draft group
    if models.draft:
        with TimingContext("mongo", "get_execution_models"):
            model_ids = models.draft.ids
            dependent_tasks = Task.objects(
                models__input__model__in=model_ids).only("id", "models")
            input_models = {
                m.model
                for m in chain.from_iterable(
                    t.models.input for t in dependent_tasks if t.models)
            }
            if input_models:
                models.draft = DocumentGroup(
                    Model,
                    (m for m in models.draft if m.id not in input_models))

    return models
예제 #5
0
def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
    """
    Move the task to the queued status and add it to the requested queue.

    If no queue is passed then the company default queue is used. If adding
    the task to the queue fails then the task status is forcibly reverted to
    the status it had before and the exception is re-raised. On success the
    queue id is stored in the task execution section and reported in the
    response under the "execution.queue" field.

    :raises errors.bad_request.InvalidTaskId: Task.get_for_writing found no task
    """
    task_id = req_model.task
    queue_id = req_model.queue
    status_message = req_model.status_message
    status_reason = req_model.status_reason

    if not queue_id:
        # try to get default queue
        queue_id = queue_bll.get_default(company_id).id

    with translate_errors_context():
        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            _only=("type", "script", "execution", "status", "project", "id"), **query
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        # the status is changed first; a failure here leaves the queue untouched
        res = EnqueueResponse(
            **ChangeStatusRequest(
                task=task,
                new_status=TaskStatus.queued,
                status_reason=status_reason,
                status_message=status_message,
                allow_same_state_transition=False,
            ).execute()
        )

        try:
            queue_bll.add_task(
                company_id=company_id, queue_id=queue_id, task_id=task.id
            )
        except Exception:
            # failed enqueueing, revert to previous state
            # (task.status still holds the status read before the change)
            ChangeStatusRequest(
                task=task,
                current_status_override=TaskStatus.queued,
                new_status=task.status,
                force=True,
                status_reason="failed enqueueing",
            ).execute()
            raise

        # set the current queue ID in the task
        if task.execution:
            Task.objects(**query).update(execution__queue=queue_id, multi=False)
        else:
            Task.objects(**query).update(
                execution=Execution(queue=queue_id), multi=False
            )

        res.queued = 1
        res.fields.update(**{"execution.queue": queue_id})

        call.result.data_model = res
예제 #6
0
def update_batch(call: APICall, company_id, _):
    """
    Update several tasks in one call.

    The batched call data is a list of items, each holding the "task" id and
    the fields to update for that task (a later item for the same task
    replaces an earlier one). If any of the tasks is not returned by
    Task.get_many_for_writing then nothing is updated. All the updates are
    sent in a single bulk write. The tag caches and update time of the
    projects affected by project / tags / system_tags changes are refreshed.

    :raises errors.bad_request.BatchContainsNoItems: no batched data
    :raises errors.bad_request.InvalidTaskId: some of the tasks were not found
    """
    items = call.batched_data
    if items is None:
        raise errors.bad_request.BatchContainsNoItems()

    with translate_errors_context():
        items = {i["task"]: i for i in items}
        tasks = {
            t.id: t
            for t in Task.get_many_for_writing(company=company_id,
                                               query=Q(id__in=list(items)))
        }

        if len(tasks) < len(items):
            missing = tuple(set(items).difference(tasks))
            raise errors.bad_request.InvalidTaskId(ids=missing)

        now = datetime.utcnow()

        bulk_ops = []
        updated_projects = set()
        for task_id, data in items.items():
            task = tasks[task_id]
            fields, _ = prepare_update_fields(call, data)
            partial_update_dict = Task.get_safe_update_dict(fields)
            if not partial_update_dict:
                continue
            partial_update_dict.update(last_change=now)
            bulk_ops.append(
                UpdateOne(
                    {"_id": task_id, "company": company_id},
                    {"$set": partial_update_dict},
                )
            )

            # collect the projects whose cached tags become stale
            new_project = partial_update_dict.get("project", task.project)
            if new_project != task.project:
                updated_projects.update({new_project, task.project})
            elif any(f in partial_update_dict for f in ("tags", "system_tags")):
                updated_projects.add(task.project)

        updated = 0
        if bulk_ops:
            res = Task._get_collection().bulk_write(bulk_ops)
            updated = res.modified_count

        if updated and updated_projects:
            projects = list(updated_projects)
            _reset_cached_tags(company_id, projects=projects)
            update_project_time(project_ids=projects)

        call.result.data = {"updated": updated}
예제 #7
0
def enqueue_task(
    task_id: str,
    company_id: str,
    queue_id: str,
    status_message: str,
    status_reason: str,
    validate: bool = False,
) -> Tuple[int, dict]:
    """
    Move the task to the queued status and add it to the queue.

    If queue_id is empty then the company default queue is used. If adding
    the task to the queue fails then the task status is forcibly set back to
    task.status and the exception is re-raised.

    :param validate: if set then the task is validated with TaskBLL.validate first
    :raises errors.bad_request.InvalidTaskId: Task.get_for_writing found no task
    :return: tuple of (1, the status change result with the "execution.queue"
        entry set under "fields")
    """
    queue_id = queue_id or queue_bll.get_default(company_id).id

    query = dict(id=task_id, company=company_id)
    task = Task.get_for_writing(**query)
    if not task:
        raise errors.bad_request.InvalidTaskId(**query)

    if validate:
        TaskBLL.validate(task)

    res = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.queued,
        status_reason=status_reason,
        status_message=status_message,
        allow_same_state_transition=False,
    ).execute(enqueue_status=task.status)

    try:
        queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
    except Exception:
        # enqueueing failed: force the task back to its previous status
        ChangeStatusRequest(
            task=task,
            current_status_override=TaskStatus.queued,
            new_status=task.status,
            force=True,
            status_reason="failed enqueueing",
        ).execute(enqueue_status=None)
        raise

    # record the queue in the task execution section, creating it if missing
    execution_update = (
        dict(execution__queue=queue_id)
        if task.execution
        else dict(execution=Execution(queue=queue_id))
    )
    Task.objects(**query).update(multi=False, **execution_update)

    nested_set(res, ("fields", "execution.queue"), queue_id)
    return 1, res
예제 #8
0
def _update_task_name(task: Task):
    """
    Rename a pipeline task after the last part of its project name.

    The new name is "<prefix> #<N>", where <prefix> is the part of the project
    name after the last "/" and N is the number of "pipeline" tasks in the same
    project already named "<prefix>" or "<prefix> #<number>". If there are no
    such tasks then the new name is just "<prefix>".
    Nothing is done if the task has no project or the project is not found.
    """
    if not task or not task.project:
        return

    project = Project.objects(id=task.project).only("name").first()
    if not project:
        return

    _, _, name_prefix = project.name.rpartition("/")
    # anchor on both ends: mongo regex matching is unanchored by default, so
    # without "^" names like "other <prefix> #3" would also be counted
    name_mask = re.compile(rf"^{re.escape(name_prefix)}( #\d+)?$")
    count = Task.objects(project=task.project,
                         system_tags__in=["pipeline"],
                         name=name_mask).count()
    new_name = f"{name_prefix} #{count}" if count > 0 else name_prefix
    task.update(name=new_name)
예제 #9
0
    def get_by_id(
        company_id,
        task_id,
        required_status=None,
        only_fields=None,
        allow_public=False,
    ):
        """
        Return the task document with the given id.

        :param required_status: if passed then the task must be in this status
        :param only_fields: a single field name or an iterable of field names to
            fetch; "status" is always fetched in addition
        :param allow_public: passed to Task.get_many
        :raises errors.bad_request.InvalidTaskId: no task found
        :raises errors.bad_request.InvalidTaskStatus: the task status differs
            from required_status
        """
        if only_fields:
            requested = (
                [only_fields]
                if isinstance(only_fields, string_types)
                else list(only_fields)
            )
            only_fields = [*requested, "status"]

        with TimingContext("mongo", "task_by_id_all"):
            found = Task.get_many(
                company=company_id,
                query=Q(id=task_id),
                allow_public=allow_public,
                override_projection=only_fields,
                return_dicts=False,
            )
            task = found[0] if found else None

        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if required_status and task.status != required_status:
            raise errors.bad_request.InvalidTaskStatus(expected=required_status)

        return task
예제 #10
0
    def model_set_ready(
        cls,
        model_id: str,
        company_id: str,
        publish_task: bool,
        force_publish_task: bool = False,
    ) -> tuple:
        """
        Mark a model as ready, optionally publishing the task that created it.

        :param publish_task: if set and the model task exists and is not
            published yet then it is published (without publishing its model)
            through cls.publish_task
        :param force_publish_task: passed as "force" to cls.publish_task
        :raises errors.bad_request.InvalidModelId: no such model for the company
        :raises errors.bad_request.ModelIsReady: the model is already ready
        :return: tuple of (model.update result, dict with the published task
            "data" and "id", empty if no task was published)
        """
        with translate_errors_context():
            query = dict(id=model_id, company=company_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
            if model.ready:
                raise errors.bad_request.ModelIsReady(**query)

            task = None
            if model.task and publish_task:
                task = (
                    Task.objects(id=model.task, company=company_id)
                    .only("id", "status")
                    .first()
                )

            published_task_data = {}
            if task and task.status != TaskStatus.published:
                published_task_data = {
                    "data": cls.publish_task(
                        task_id=model.task,
                        company_id=company_id,
                        publish_model=False,
                        force=force_publish_task,
                    ),
                    "id": model.task,
                }

            updated = model.update(upsert=False, ready=True)
            return updated, published_task_data
예제 #11
0
def get_by_task_id(call: APICall, company_id, _):
    """
    Return the last output model of the task (legacy endpoint).

    :raises errors.moved_permanently.NotSupported: the requested endpoint version
        is above ModelsBackwardsCompatibility.max_version
    :raises errors.bad_request.InvalidTaskId: no such task
    :raises errors.bad_request.MissingTaskFields: the task has no output models
    :raises errors.bad_request.InvalidModelId: the model is neither a company
        model nor a public one
    """
    if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
        raise errors.moved_permanently.NotSupported(
            "use models.get_by_id/get_all apis")

    task_id = call.data["task"]

    with translate_errors_context():
        query = dict(id=task_id, company=company_id)
        task = Task.get(_only=["models"], **query)
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        output_models = task.models.output if task.models else None
        if not output_models:
            raise errors.bad_request.MissingTaskFields(field="models.output")

        model_id = output_models[-1].model
        company_constraint = get_company_or_none_constraint(company_id)
        model = Model.objects(Q(id=model_id) & company_constraint).first()
        if not model:
            raise errors.bad_request.InvalidModelId(
                "no such public or company model",
                id=model_id,
                company=company_id,
            )

        result = model.to_proper_dict()
        conform_output_tags(call, result)
        call.result.data = {"model": result}
예제 #12
0
    def delete_task_events(self, company_id, task_id, allow_locked=False):
        """
        Delete all the events of the task.

        :param allow_locked: if not set then the task must not be in one of the
            LOCKED_TASK_STATUSES
        :raises errors.bad_request.InvalidTaskId: the task was not found
            (or is locked when allow_locked is not set)
        :return: the "deleted" count reported by delete_company_events (0 if missing)
        """
        with translate_errors_context():
            query = Q(id=task_id, company=company_id)
            extra_msg = None
            if not allow_locked:
                query &= Q(status__nin=LOCKED_TASK_STATUSES)
                extra_msg = "or task published"
            if not Task.objects(query).only("id").first():
                raise errors.bad_request.InvalidTaskId(
                    extra_msg, company=company_id, id=task_id
                )

        with translate_errors_context(), TimingContext("es", "delete_task_events"):
            es_res = delete_company_events(
                es=self.es,
                company_id=company_id,
                event_type=EventType.all,
                body={"query": {"term": {"task": task_id}}},
                refresh=True,
            )

        return es_res.get("deleted", 0)
예제 #13
0
    def publish_model(
        cls,
        model_id: str,
        company_id: str,
        force_publish_task: bool = False,
        publish_task_func: Callable[[str, str, bool], dict] = None,
    ) -> Tuple[int, ModelTaskPublishResponse]:
        """
        Mark the model ready and optionally publish the task that created it.

        :param force_publish_task: passed as the last argument of publish_task_func
        :param publish_task_func: if passed and the model task exists and is not
            published yet then it is called with (task id, company id,
            force_publish_task)
        :raises errors.bad_request.ModelIsReady: the model is already ready
        :return: tuple of (model.update result, ModelTaskPublishResponse for the
            published task or None if no task was published)
        """
        model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
        if model.ready:
            raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)

        published_task = None
        if model.task and publish_task_func:
            task = (
                Task.objects(id=model.task, company=company_id)
                .only("id", "status")
                .first()
            )
            if task and task.status != TaskStatus.published:
                publish_result = publish_task_func(
                    model.task, company_id, force_publish_task
                )
                published_task = ModelTaskPublishResponse(
                    id=model.task, data=publish_result
                )

        updated = model.update(
            upsert=False, ready=True, last_update=datetime.utcnow()
        )
        return updated, published_task
예제 #14
0
    def assert_exists(company_id,
                      task_ids,
                      only=None,
                      allow_public=False,
                      return_tasks=True) -> Optional[Sequence[Task]]:
        """
        Verify that all the passed task ids exist for the company.

        :param task_ids: a single task id or a collection of task ids
        :param only: if passed then only these fields are fetched (the default
            field exclusions are reset first)
        :param allow_public: passed to Task.get_many
        :param return_tasks: if set then the found tasks are returned
        :raises errors.bad_request.InvalidTaskId: not all of the tasks were found
        :return: list of the found tasks if return_tasks is set, otherwise None
        """
        if isinstance(task_ids, six.string_types):
            task_ids = [task_ids]

        with translate_errors_context(), TimingContext("mongo", "task_exists"):
            unique_ids = set(task_ids)
            found = Task.get_many(
                company=company_id,
                query=Q(id__in=unique_ids),
                allow_public=allow_public,
                return_dicts=False,
            )
            if only:
                # internal call with explicitly requested fields: drop the
                # default field exclusions before restricting the projection
                found = found.all_fields().only(*only)

            if found.count() != len(unique_ids):
                raise errors.bad_request.InvalidTaskId(ids=task_ids)

            if return_tasks:
                return list(found)
예제 #15
0
    def validate(
        cls,
        task: Task,
        validate_models=True,
        validate_parent=True,
        validate_project=True,
    ):
        """
        Validate the task properties according to the passed flags.

        The task project (if any) is always fetched with Project.get_for_writing
        in order to disable the modification of public projects; a missing
        project is reported only when validate_project is set.

        :param validate_parent: verify that the task parent exists for the task
            company or is public (parents starting with deleted_prefix are skipped)
        :param validate_project: verify that the task project was found
        :param validate_models: validate the input models with cls.validate_input_models
        :raises errors.bad_request.InvalidTaskId: the parent task was not found
        :raises errors.bad_request.InvalidProjectId: the project was not found
        """
        check_parent = (
            validate_parent
            and task.parent
            and not task.parent.startswith(deleted_prefix)
        )
        if check_parent:
            parent_task = Task.get(
                company=task.company,
                id=task.parent,
                _only=("id",),
                include_public=True,
            )
            if not parent_task:
                raise errors.bad_request.InvalidTaskId(
                    "invalid parent", parent=task.parent
                )

        if task.project:
            project = Project.get_for_writing(company=task.company, id=task.project)
            if validate_project and not project:
                raise errors.bad_request.InvalidProjectId(id=task.project)

        if validate_models:
            cls.validate_input_models(task)
예제 #16
0
def prepare_update_fields(call: APICall, task, call_data):
    """
    Parse the updatable task fields out of the call data and prepare them for saving.

    The field spec passed to parse_from_call consists of the create_fields
    entries allowed by Task.user_set_allowed() plus "output__error".

    :param task: the task being updated (not used here, kept for the callers)
    :param call_data: the raw call data to parse
    :return: tuple of (fields prepared for saving, copy of Task.user_set_allowed())
    """
    valid_fields = deepcopy(Task.user_set_allowed())
    update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
    update_fields["output__error"] = None
    # work on a copy: adding to the module level task_fields would permanently
    # leak "output__error" into every other user of that set
    t_fields = set(task_fields)
    t_fields.add("output__error")
    fields = parse_from_call(call_data, update_fields, t_fields)
    return prepare_for_save(call, fields), valid_fields
예제 #17
0
    def get_unique_metric_variants(cls, company_id, project_ids: Sequence[str],
                                   include_subprojects: bool):
        """
        Return the unique metric/variant pairs found in the tasks' last metrics.

        Only the tasks matching the company constraint and the project
        constraint (optionally including subprojects) are considered.

        :return: one dict per unique (metric name, variant name) pair holding the
            "metric", "metric_hash", "variant" and "variant_hash" keys, sorted by
            metric and then by variant name
        """
        match_tasks = {
            "$match": {
                **cls._get_company_constraint(company_id),
                **cls._get_project_constraint(project_ids, include_subprojects),
            }
        }
        # last_metrics is a two level dict (keyed by metric hash and then by
        # variant hash, judging by the output keys); flatten it into one
        # document per metric/variant
        flatten_metrics = [
            {"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
            {"$unwind": "$metrics"},
            {
                "$project": {
                    "metric": "$metrics.k",
                    "variants": {"$objectToArray": "$metrics.v"},
                }
            },
            {"$unwind": "$variants"},
        ]
        group_by_names = {
            "$group": {
                "_id": {
                    "metric": "$variants.v.metric",
                    "variant": "$variants.v.variant",
                },
                "metrics": {
                    "$addToSet": {
                        "metric": "$variants.v.metric",
                        "metric_hash": "$metric",
                        "variant": "$variants.v.variant",
                        "variant_hash": "$variants.k",
                    }
                },
            }
        }
        sort_by_names = {
            "$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})
        }
        pipeline = [match_tasks, *flatten_metrics, group_by_names, sort_by_names]

        # a single entry is returned for every metric/variant names pair
        return [group["metrics"][0] for group in Task.aggregate(pipeline)]
예제 #18
0
 def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
     """
     Return the unique external task types used by the company and public tasks.

     :param project_ids: if passed then only the tasks of these projects are considered
     :return: the used task types that also appear in external_task_types
     """
     query = get_company_or_none_constraint(company)
     if project_ids:
         query &= Q(project__in=project_ids)
     used_types = set(Task.objects(query).distinct(field="type"))
     return used_types.intersection(external_task_types)
예제 #19
0
    def cleanup_tasks(cls, threshold_sec):
        """
        Stop the running tasks that were not updated for threshold_sec seconds.

        Each task is stopped with an update conditioned on the status read
        here, so a task whose status changed meanwhile is left untouched and
        counted as a failure. The update time of the projects of the stopped
        tasks is refreshed at the end.

        :param threshold_sec: period (in seconds) without updates after which
            an in-progress task is considered non-responsive
        :return: number of tasks that were actually stopped
        """
        relevant_status = (TaskStatus.in_progress,)
        threshold = timedelta(seconds=threshold_sec)
        ref_time = datetime.utcnow() - threshold
        log.info(
            f"Starting cleanup cycle for running tasks last updated before {ref_time}"
        )

        tasks = list(
            Task.objects(status__in=relevant_status,
                         last_update__lt=ref_time).only(
                             "id", "name", "status", "project", "last_update"))
        log.info(f"{len(tasks)} non-responsive tasks found")
        if not tasks:
            return 0

        err_count = 0
        project_ids = set()
        now = datetime.utcnow()
        for task in tasks:
            log.info(
                f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
            )
            # noinspection PyBroadException
            try:
                updated = Task.objects(id=task.id, status=task.status).update(
                    status=TaskStatus.stopped,
                    status_reason="Forced stop (non-responsive)",
                    status_message="Forced stop (non-responsive)",
                    status_changed=now,
                    last_update=now,
                    last_change=now,
                )
            except Exception as ex:
                # keep cleaning up the other tasks, but this one was not
                # stopped and must not be counted as such
                err_count += 1
                log.error("Failed setting status: %s", str(ex))
                continue

            if updated:
                project_ids.add(task.project)
            else:
                err_count += 1

        update_project_time(list(project_ids))

        return len(tasks) - err_count
예제 #20
0
    def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
        """Verify that task exists and can be updated"""
        if not task_ids:
            return set()

        with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
            query = Q(id__in=task_ids, company=company_id)
            if not allow_locked_tasks:
                query &= Q(status__nin=LOCKED_TASK_STATUSES)
            res = Task.objects(query).only("id")
            return {r.id for r in res}
예제 #21
0
def get_outputs_for_deletion(task, force=False):
    """
    Collect the outputs of the task (models and children tasks) before its deletion.

    :param task: the task about to be deleted
    :param force: if False then published output models or published children
        tasks prevent the deletion
    :raises errors.bad_request.TaskCannotBeDeleted: force is False and the task
        has published output models or published children tasks
    :return: tuple of (the task models grouped into "published" and "draft",
        the children tasks). Draft models that some task uses as its
        execution model are left out of the draft group
    """
    with TimingContext("mongo", "get_task_models"):
        models = TaskOutputs(
            attrgetter("ready"),
            Model,
            Model.objects(task=task.id).only("id", "task", "ready"),
        )
        if not force and models.published:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has output models, use force=True",
                task=task.id,
                models=len(models.published),
            )

    # guard against tasks that have no output section at all
    if task.output and task.output.model:
        output_model = get_output_model(task, force)
        if output_model:
            if output_model.ready:
                models.published.append(output_model)
            else:
                models.draft.append(output_model)

    if models.draft:
        with TimingContext("mongo", "get_execution_models"):
            model_ids = [m.id for m in models.draft]
            dependent_tasks = Task.objects(execution__model__in=model_ids).only(
                "id", "execution.model"
            )
            busy_models = {t.execution.model for t in dependent_tasks}
            models.draft[:] = [m for m in models.draft if m.id not in busy_models]

    with TimingContext("mongo", "get_task_children"):
        children = Task.objects(parent=task.id).only("id", "parent", "status")
        published_children = [
            child for child in children if child.status == TaskStatus.published
        ]
        if not force and published_children:
            # report the count (as done for output models) rather than the
            # task documents themselves in the error data
            raise errors.bad_request.TaskCannotBeDeleted(
                "has children, use force=True",
                task=task.id,
                children=len(published_children),
            )
    return models, children
예제 #22
0
 def create(call: APICall, fields: dict):
     """
     Construct a new Task document owned by the calling user and company.

     The id is generated with create_id() and the created / last_update /
     last_change timestamps are all set to the current UTC time. The document
     is not saved here.

     :param fields: additional Task fields passed as is to the constructor
     """
     identity = call.identity
     timestamp = datetime.utcnow()
     task_kwargs = dict(
         id=create_id(),
         user=identity.user,
         company=identity.company,
         created=timestamp,
         last_update=timestamp,
         last_change=timestamp,
     )
     return Task(**task_kwargs, **fields)
예제 #23
0
def params_prepare_for_save(fields: dict, previous_task: Task = None):
    """
    Prepare the task hyper params and configuration for saving.

    If legacy hyper params ("execution.parameters") or legacy configuration
    ("execution.model_desc") are passed, they are converted into the
    corresponding part of the new structure ("hyperparams" or "configuration")
    and removed from the fields.
    Then all the section and param names of the hyper params and the
    configuration are escaped to make them mongo safe.

    :param fields: the task fields to save; modified in place
    :param previous_task: the task before the update. If passed and the new
        structure is not among the fields, then the previous task data is used
        as the base for the conversion
    """
    for old_params_field, new_params_field, default_section in (
        (("execution", "parameters"), "hyperparams",
         hyperparams_default_section),
        (("execution", "model_desc"), "configuration", None),
    ):
        legacy_params = nested_get(fields, old_params_field)
        if legacy_params is None:
            continue

        # the new structure was not passed: start from the previous task data
        # NOTE(review): _remove_legacy_params presumably strips the previously
        # converted legacy params from previous_data and reports whether any
        # were found — confirm
        if (not fields.get(new_params_field) and previous_task
                and previous_task[new_params_field]):
            previous_data = previous_task.to_proper_dict().get(
                new_params_field)
            removed = _remove_legacy_params(previous_data,
                                            with_sections=default_section
                                            is not None)
            if not legacy_params and not removed:
                # if we only need to delete legacy fields from the db
                # but they are not there then there is no point to proceed
                continue

            fields_update = {new_params_field: previous_data}
            params_unprepare_from_saved(fields_update)
            fields.update(fields_update)

        # every legacy param is stored under <new field>/<section>/<name>
        # (a None section is dropped from the path) with the legacy type
        # and its value converted to a string
        for full_name, value in legacy_params.items():
            section, name = split_param_name(full_name, default_section)
            new_path = list(filter(None, (new_params_field, section, name)))
            new_param = dict(name=name,
                             type=hyperparams_legacy_type,
                             value=str(value))
            if section is not None:
                new_param["section"] = section
            nested_set(fields, new_path, new_param)
        nested_delete(fields, old_params_field)

    # escape the top level keys and, for dict values, their keys as well
    for param_field in ("hyperparams", "configuration"):
        params = fields.get(param_field)
        if params:
            escaped_params = {
                ParameterKeyEscaper.escape(key):
                {ParameterKeyEscaper.escape(k): v
                 for k, v in value.items()}
                if isinstance(value, dict) else value
                for key, value in params.items()
            }
            fields[param_field] = escaped_params
예제 #24
0
 def set_last_update(
     task_ids: Collection[str],
     company_id: str,
     last_update: datetime,
     **extra_updates,
 ):
     """
     Set the last update / last change time of the company tasks.

     For the in-progress tasks with a start time, "active_duration" is also set
     to the number of seconds elapsed since the start (an "active_duration"
     entry in extra_updates takes precedence).

     :param last_update: the time written to both last_update and last_change
     :param extra_updates: additional update fields applied to every task
     """
     matching = Task.objects(id__in=task_ids, company=company_id).only(
         "status", "started"
     )
     for task in matching:
         is_running = task.status == TaskStatus.in_progress and task.started
         if is_running:
             elapsed = (datetime.utcnow() - task.started).total_seconds()
             task_updates = {"active_duration": elapsed, **extra_updates}
         else:
             task_updates = extra_updates
         Task.objects(id=task.id, company=company_id).update(
             upsert=False,
             last_update=last_update,
             last_change=last_update,
             **task_updates,
         )
예제 #25
0
def get_by_id_ex(call: APICall, company_id, _):
    """
    Return the tasks matching the query in the call data, public tasks included.

    Tag fields are conformed and the projection / ordering paths are escaped in
    the call data before querying; the found tasks are passed through
    unprepare_from_saved before being returned under "tasks".
    """
    conform_tag_fields(call, call.data)
    escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_by_id_ex"):
            found = Task.get_many_with_join(
                company=company_id,
                query_dict=call.data,
                allow_public=True,
            )

        unprepare_from_saved(call, found)
        call.result.data = {"tasks": found}
예제 #26
0
    def execute(self, **kwargs):
        """
        Atomically change the task status to self.new_status.

        The current status is self.current_status_override if set, otherwise
        the task status. The db update only matches the task if its status still
        equals the current status. Moving to the queued status also pulls the
        development system tag.

        :param kwargs: additional update fields. Keys that collide with the
            update control argument names (upsert, multi, write_concern,
            full_result) are prefixed with "__"
        :raises errors.bad_request.FailedChangingTaskStatus: no task matched the
            expected status (or validate_transition may raise first)
        :return: dict with the "updated" result and the update "fields"
            (without any "__raw__" entry)
        """
        current_status = self.current_status_override or self.task.status
        project_id = self.task.project

        # Verify new status is allowed from current status (will throw exception if not valid)
        self.validate_transition(current_status)

        # update() control arguments, merged into the update params below
        control = dict(upsert=False,
                       multi=False,
                       write_concern=None,
                       full_result=False)

        now = datetime.utcnow()
        fields = dict(
            status=self.new_status,
            status_reason=self.status_reason,
            status_message=self.status_message,
            status_changed=now,
            last_update=now,
            last_change=now,
        )

        if self.new_status == TaskStatus.queued:
            fields["pull__system_tags"] = TaskSystemTags.development

        def safe_mongoengine_key(key):
            # NOTE(review): the "__" prefix presumably lets mongoengine treat a
            # field named like a control argument as a field — confirm
            return f"__{key}" if key in control else key

        fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})

        with translate_errors_context(), TimingContext("mongo", "task_status"):
            # atomic change of task status by querying the task with the EXPECTED status before modifying it
            params = fields.copy()
            params.update(control)
            updated = Task.objects(id=self.task.id,
                                   status=current_status).update(**params)

        if not updated:
            # failed to change status (someone else beat us to it?)
            raise errors.bad_request.FailedChangingTaskStatus(
                task_id=self.task.id,
                current_status=current_status,
                new_status=self.new_status,
            )

        update_project_time(project_id)

        # make sure that _raw_ queries are not returned back to the client
        fields.pop("__raw__", None)

        return dict(updated=updated, fields=fields)
예제 #27
0
def escape_execution_parameters(call: APICall):
    """
    Escape the field paths of the projection and the ordering in the call data.

    Each of them, when present and non-empty, is replaced in call.data by the
    result of escape_paths; the projection is handled before the ordering.
    """
    accessors = (
        (Task.get_projection, Task.set_projection),
        (Task.get_ordering, Task.set_ordering),
    )
    for read_paths, write_paths in accessors:
        paths = read_paths(call.data)
        if paths:
            write_paths(call.data, escape_paths(paths))
예제 #28
0
def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
    """
    Delete the models of the given projects.

    Tasks outside these projects that reference a deleted model get that
    reference set to None: in their input models, and in their output models
    for the tasks that created the deleted models.

    :return: tuple of (number of deleted models, uris of the deleted models)
    """
    with TimingContext("mongo", "delete_models"):
        models = Model.objects(project__in=projects).only("task", "id", "uri")
        if not models:
            return 0, set()

        model_ids = list({m.id for m in models})

        def _nullify_refs(section: str, task_filter: dict):
            # set models.<section>[].model to None wherever it is a deleted model
            Task._get_collection().update_many(
                filter={
                    **task_filter,
                    "project": {"$nin": projects},
                    f"models.{section}.model": {"$in": model_ids},
                },
                update={"$set": {f"models.{section}.$[elem].model": None}},
                array_filters=[{"elem.model": {"$in": model_ids}}],
                upsert=False,
            )

        _nullify_refs("input", {})

        creating_tasks = list({m.task for m in models if m.task})
        if creating_tasks:
            _nullify_refs("output", {"_id": {"$in": creating_tasks}})

        urls = {m.uri for m in models if m.uri}
        return models.delete(), urls
예제 #29
0
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
    """
    Return the next queue entry obtained from queue_bll.get_next_task.

    If the request asks for the task info and the task exists, then its
    company and user are returned as well. Nothing is set on the call result
    when no entry is returned.
    """
    entry = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
    if not entry:
        return

    data = {"entry": entry.to_proper_dict()}
    if req_model.get_task_info:
        task = Task.objects(id=entry.task).first()
        if task:
            data["task_info"] = {"company": task.company, "user": task.user}

    call.result.data = data
예제 #30
0
def prepare_update_fields(call: APICall, call_data):
    """
    Parse the updatable task fields out of the call data and prepare them for saving.

    The field spec passed to parse_from_call consists of the create_fields
    entries allowed by Task.user_set_allowed() plus the status, status_reason,
    status_message and output__error fields.

    :param call_data: the raw data of a single task update
    :return: tuple of (fields prepared for saving, copy of Task.user_set_allowed())
    """
    valid_fields = deepcopy(Task.user_set_allowed())
    update_fields = {
        k: v
        for k, v in create_fields.items() if k in valid_fields
    }
    update_fields.update(status=None,
                         status_reason=None,
                         status_message=None,
                         output__error=None)
    # work on a copy: adding to the module level task_fields would permanently
    # leak "output__error" into every other user of that set
    t_fields = set(task_fields)
    t_fields.add("output__error")
    fields = parse_from_call(call_data, update_fields, t_fields)
    return prepare_for_save(call, fields), valid_fields