Example #1
0
def update(call: APICall, company_id, req_model: UpdateRequest):
    """
    Partially update a single task and report which fields were changed.

    :param call: the API call; its ``data`` holds the requested field values
    :param company_id: company owning the task
    :param req_model: request model carrying the task ID in ``task``
    :return: an ``UpdateResponse`` with the number of updated documents and,
        when an update was attempted, the updated fields
    :raise errors.bad_request.InvalidTaskId: if the task lookup yields nothing
    """
    task_id = req_model.task

    with translate_errors_context():
        task = Task.get_for_writing(
            id=task_id, company=company_id, _only=["id", "project"]
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        # Only the update dict is needed here; the second item is not used
        changes, _valid_fields = prepare_update_fields(call, task, call.data)
        if not changes:
            # Nothing to write
            return UpdateResponse(updated=0)

        change_time = datetime.utcnow()
        updated_count, updated_fields = Task.safe_update(
            company_id=company_id,
            id=task_id,
            partial_update_dict=changes,
            injected_update={"last_change": change_time},
        )

        if updated_count:
            previous_project = task.project
            new_project = updated_fields.get("project", previous_project)
            if new_project != previous_project:
                # The task moved between projects: reset the tags cache for both
                _reset_cached_tags(
                    company_id, projects=[new_project, previous_project]
                )
            else:
                _update_cached_tags(
                    company_id, project=previous_project, fields=updated_fields
                )
            update_project_time(updated_fields.get("project"))

        unprepare_from_saved(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
Example #2
0
def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
    """
    Move a task to the ``queued`` status and add it to a queue.

    When the request does not name a queue, the company's default queue is
    used. If adding the task to the queue fails, the status change is reverted
    and the original exception is re-raised. The response is stored in
    ``call.result.data_model``.

    :raise errors.bad_request.InvalidTaskId: if the task lookup yields nothing
    """
    task_id = req_model.task
    # Fall back to the default queue when none was requested
    queue_id = req_model.queue or queue_bll.get_default(company_id).id

    with translate_errors_context():
        task_query = {"id": task_id, "company": company_id}
        task = Task.get_for_writing(
            _only=("type", "script", "execution", "status", "project", "id"),
            **task_query,
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**task_query)

        to_queued = ChangeStatusRequest(
            task=task,
            new_status=TaskStatus.queued,
            status_reason=req_model.status_reason,
            status_message=req_model.status_message,
            allow_same_state_transition=False,
        )
        res = EnqueueResponse(**to_queued.execute())

        try:
            queue_bll.add_task(
                company_id=company_id, queue_id=queue_id, task_id=task.id
            )
        except Exception:
            # Enqueueing failed: put the task back into the status it had
            # when it was loaded, then let the error propagate
            rollback = ChangeStatusRequest(
                task=task,
                current_status_override=TaskStatus.queued,
                new_status=task.status,
                force=True,
                status_reason="failed enqueueing",
            )
            rollback.execute()
            raise

        # Record the queue ID on the task: patch the existing execution
        # sub-document, or create one holding only the queue
        if task.execution:
            queue_update = {"execution__queue": queue_id}
        else:
            queue_update = {"execution": Execution(queue=queue_id)}
        Task.objects(**task_query).update(multi=False, **queue_update)

        res.queued = 1
        res.fields.update(**{"execution.queue": queue_id})

        call.result.data_model = res
Example #3
0
def enqueue_task(
    task_id: str,
    company_id: str,
    queue_id: str,
    status_message: str,
    status_reason: str,
    validate: bool = False,
) -> Tuple[int, dict]:
    """
    Move a task to the ``queued`` status and add it to a queue.

    :param task_id: ID of the task to enqueue
    :param company_id: company owning the task
    :param queue_id: target queue; a falsy value selects the company's
        default queue
    :param status_message: status message passed to the status change
    :param status_reason: status reason passed to the status change
    :param validate: when true, run ``TaskBLL.validate`` on the task first
    :return: ``(1, result)`` where ``result`` is the status change result with
        the queue ID stored under ``fields`` / ``execution.queue``
    :raise errors.bad_request.InvalidTaskId: if the task lookup yields nothing
    """
    # Fall back to the default queue when none was requested
    queue_id = queue_id or queue_bll.get_default(company_id).id

    task_query = {"id": task_id, "company": company_id}
    task = Task.get_for_writing(**task_query)
    if not task:
        raise errors.bad_request.InvalidTaskId(**task_query)

    if validate:
        TaskBLL.validate(task)

    to_queued = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.queued,
        status_reason=status_reason,
        status_message=status_message,
        allow_same_state_transition=False,
    )
    res = to_queued.execute(enqueue_status=task.status)

    try:
        queue_bll.add_task(
            company_id=company_id, queue_id=queue_id, task_id=task.id
        )
    except Exception:
        # Enqueueing failed: put the task back into the status it had when it
        # was loaded, then let the error propagate
        rollback = ChangeStatusRequest(
            task=task,
            current_status_override=TaskStatus.queued,
            new_status=task.status,
            force=True,
            status_reason="failed enqueueing",
        )
        rollback.execute(enqueue_status=None)
        raise

    # Record the queue ID on the task: patch the existing execution
    # sub-document, or create one holding only the queue
    if task.execution:
        queue_update = {"execution__queue": queue_id}
    else:
        queue_update = {"execution": Execution(queue=queue_id)}
    Task.objects(**task_query).update(multi=False, **queue_update)

    nested_set(res, ("fields", "execution.queue"), queue_id)
    return 1, res
Example #4
0
def dequeue_task(
    task_id: str,
    company_id: str,
    status_message: str,
    status_reason: str,
) -> Tuple[int, dict]:
    """
    Dequeue a task and change its status via ``TaskBLL``.

    :param task_id: ID of the task to dequeue
    :param company_id: company owning the task
    :param status_message: status message forwarded to the status change
    :param status_reason: status reason forwarded to the status change
    :return: ``(1, result)`` where ``result`` is whatever
        ``TaskBLL.dequeue_and_change_status`` returned
    :raise errors.bad_request.InvalidTaskId: if the task lookup yields nothing
    """
    task_filter = {"id": task_id, "company": company_id}
    task = Task.get_for_writing(**task_filter)
    if not task:
        raise errors.bad_request.InvalidTaskId(**task_filter)

    return 1, TaskBLL.dequeue_and_change_status(
        task,
        company_id,
        status_message=status_message,
        status_reason=status_reason,
    )
Example #5
0
def get_task_for_update(
    company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:
    """
    Load a task (only its id and status) and return it if it may be updated.

    A task is updatable when its status is 'created'. With ``force`` the
    'in_progress' status is accepted as well, and with ``allow_all_statuses``
    the status is not checked at all.

    :raise errors.bad_request.InvalidTaskId: if the task lookup yields nothing
    :raise errors.bad_request.InvalidTaskStatus: if the task status is not
        one of the accepted statuses
    """
    task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)

    if allow_all_statuses:
        return task

    accepted = [TaskStatus.created]
    if force:
        accepted.append(TaskStatus.in_progress)

    if task.status in accepted:
        return task

    raise errors.bad_request.InvalidTaskStatus(
        expected=TaskStatus.created, status=task.status
    )
Example #6
0
    def get_task_with_access(task_id,
                             company_id,
                             only=None,
                             allow_public=False,
                             requires_write_access=False) -> Task:
        """
        Fetch a task, optionally requiring that it can be written to.

        With ``requires_write_access`` the task is loaded through
        ``Task.get_for_writing``; otherwise through ``Task.get``, where
        ``allow_public`` is forwarded as ``include_public``.

        :param only: forwarded to the loader as ``_only``
        :except errors.bad_request.InvalidTaskId: if the task is not found
        :except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
            (not raised in this method itself; presumably by ``Task.get_for_writing``)
        """
        with translate_errors_context():
            task_query = {"id": task_id, "company": company_id}

            with TimingContext("mongo", "task_with_access"):
                task = (
                    Task.get_for_writing(_only=only, **task_query)
                    if requires_write_access
                    else Task.get(
                        _only=only, include_public=allow_public, **task_query
                    )
                )

            if task:
                return task

            raise errors.bad_request.InvalidTaskId(**task_query)
Example #7
0
def edit(call: APICall, company_id, req_model: UpdateRequest):
    """
    Edit the fields of an existing task.

    Loads the task for writing, rejects the edit unless the task status is
    'created' (or ``req_model.force`` is set), parses the requested fields,
    validates a task built from them and writes the fields back together with
    change timestamps. The outcome is stored in ``call.result.data_model`` as
    an ``UpdateResponse``.

    :raise errors.bad_request.InvalidTaskId: if the task lookup yields nothing
    :raise errors.bad_request.InvalidTaskStatus: if the task is not in the
        'created' status and ``force`` was not requested
    """
    task_id = req_model.task
    force = req_model.force

    with translate_errors_context():
        task = Task.get_for_writing(id=task_id, company=company_id)
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if not force and task.status != TaskStatus.created:
            raise errors.bad_request.InvalidTaskStatus(
                expected=TaskStatus.created, status=task.status
            )

        # editable fields: the creation fields plus "status"
        edit_fields = create_fields.copy()
        edit_fields.update(dict(status=None))

        with translate_errors_context(
            field_does_not_exist_cls=errors.bad_request.ValidationError
        ), TimingContext("code", "parse_and_validate"):
            fields = prepare_create_fields(
                call, valid_fields=edit_fields, output=task.output, previous_task=task
            )

        # when a dict was passed for a field that currently holds an embedded
        # document, overlay the dict on the stored document's values so that
        # sub-fields missing from the request keep their current values
        for key in fields:
            field = getattr(task, key, None)
            value = fields[key]
            if (
                field
                and isinstance(value, dict)
                and isinstance(field, EmbeddedDocument)
            ):
                d = field.to_mongo(use_db_field=False).to_dict()
                d.update(value)
                fields[key] = d

        # validate a task built from the merged fields before writing anything
        # (presumably task_bll.create does not persist the task -- TODO confirm)
        task_bll.validate(task_bll.create(call, fields))

        # make sure field names do not end in mongoengine comparison operators
        # (a key that is itself an operator name gets a trailing "__")
        fixed_fields = {
            (k if k not in COMPARISON_OPERATORS else "%s__" % k): v
            for k, v in fields.items()
        }
        if fixed_fields:
            now = datetime.utcnow()
            last_change = dict(last_change=now)
            # last_update is only bumped when the edit touches fields outside
            # the set returned by Task.user_set_allowed()
            if not set(fields).issubset(Task.user_set_allowed()):
                last_change.update(last_update=now)
            # timestamps go both into the reported fields and the written ones
            fields.update(**last_change)
            fixed_fields.update(**last_change)
            updated = task.update(upsert=False, **fixed_fields)
            if updated:
                # task.project is compared against the value just written, so
                # it is expected to still hold the previous project here
                new_project = fixed_fields.get("project", task.project)
                if new_project != task.project:
                    # the task moved between projects: reset both tag caches
                    _reset_cached_tags(company_id, projects=[new_project, task.project])
                else:
                    _update_cached_tags(
                        company_id, project=task.project, fields=fixed_fields
                    )
                update_project_time(fields.get("project"))
            unprepare_from_saved(call, fields)
            call.result.data_model = UpdateResponse(updated=updated, fields=fields)
        else:
            # nothing to edit
            call.result.data_model = UpdateResponse(updated=0)
Example #8
0
def validate_task(company_id, fields: dict):
    """
    Validate that the task referenced by ``fields["task"]`` can be written to.

    The task is looked up through ``Task.get_for_writing``. A lookup that
    yields nothing is reported as an invalid task ID, consistent with the
    other callers of ``Task.get_for_writing``, instead of being silently
    accepted.

    :param company_id: company owning the task
    :param fields: request fields; must contain the task ID under ``"task"``
    :raise errors.bad_request.InvalidTaskId: if the task lookup yields nothing
    """
    task_id = fields["task"]
    task = Task.get_for_writing(company=company_id, id=task_id, _only=["id"])
    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)
Example #9
0
def update_for_task(call: APICall, company_id, _):
    """
    Create or update the output model of a task.

    Exactly one of ``uri`` and ``override_model_id`` must be present in
    ``call.data``:

    - with ``override_model_id`` the referenced model is fetched and
      registered as the task's output model;
    - with ``uri`` the last output model of the task is updated if the task
      already has one (the call then returns early with ``created: False``),
      otherwise a new model is created from the request fields and the task
      properties and registered as the task's output model.

    The result is written to ``call.result.data``.

    :raise errors.moved_permanently.NotSupported: if the requested endpoint
        version is above ``ModelsBackwardsCompatibility.max_version``
    :raise errors.bad_request.MissingRequiredFields: if neither or both of
        ``uri`` and ``override_model_id`` were passed
    :raise errors.bad_request.InvalidTaskId: if the task lookup yields nothing
    :raise errors.bad_request.InvalidTaskStatus: if the task status is not
        'created' or 'in_progress'
    """
    if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
        raise errors.moved_permanently.NotSupported(
            "use tasks.add_or_update_model")

    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")
    # reject both "none given" and "both given"
    if not (uri or override_model_id) or (uri and override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required",
            fields=("uri", "override_model_id"))

    with translate_errors_context():

        # query is only used for the error payloads below
        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["models", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            # use an existing model as the task output model
            model = ModelBLL.get_company_model_by_id(
                company_id=company_id, model_id=override_model_id)
        else:
            if "name" not in call.data:
                # use task name if name not provided
                call.data["name"] = task.name

            if "comment" not in call.data:
                # NOTE(review): "id" is not listed in _only above; presumably
                # the task id is always loaded -- confirm
                call.data[
                    "comment"] = f"Created by task `{task.name}` ({task.id})"

            if task.models and task.models.output:
                # model exists, update
                # (the last entry of the task's output models is the one updated)
                model_id = task.models.output[-1].model
                res = _update_model(call, company_id,
                                    model_id=model_id).to_struct()
                res.update({"id": model_id, "created": False})
                call.result.data = res
                # early return: task statistics below are not touched on this path
                return

            # new model, create
            fields = parse_model_fields(call, create_fields)

            # create and save model
            # NOTE(review): task.execution is dereferenced without a None
            # check -- assumes the task always has an execution section
            # NOTE(review): the status was restricted to created/in_progress
            # above, so the `ready` expression evaluates to False whenever
            # TaskStatus.published differs from those two
            now = datetime.utcnow()
            model = Model(
                id=database.utils.id(),
                created=now,
                last_update=now,
                user=call.identity.user,
                company=company_id,
                project=task.project,
                framework=task.execution.framework,
                parent=task.models.input[0].model
                if task.models and task.models.input else None,
                design=task.execution.model_desc,
                labels=task.execution.model_labels,
                ready=(task.status == TaskStatus.published),
                **fields,
            )
            model.save()
            _update_cached_tags(company_id,
                                project=model.project,
                                fields=fields)

        # register the model (fetched or newly created) as the task's output
        # model and pass the reported iteration along
        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            models__output=[
                ModelItem(
                    model=model.id,
                    name=TaskModelNames[TaskModelTypes.output],
                    updated=datetime.utcnow(),
                )
            ],
        )

        # "created" is reported as True on the override_model_id path as well
        call.result.data = {"id": model.id, "created": True}