示例#1
0
def _update_model(call: APICall, company_id, model_id=None):
    """
    Update a single model with the field values carried by the API call.

    :param call: API call whose ``data`` holds the requested field values
    :param company_id: ID of the company the model is looked up under
    :param model_id: model to update; when falsy ``call.data["model"]`` is used
    :return: UpdateResponse holding the update count and the updated fields
    """
    target_model_id = model_id or call.data["model"]
    model = ModelBLL.get_company_model_by_id(
        company_id=company_id, model_id=target_model_id
    )

    data = prepare_update_fields(call, company_id, call.data)

    # An iteration that arrives together with a task is forwarded to the
    # task statistics before the model itself is touched
    task_id, iteration = data.get("task"), data.get("iteration")
    if task_id and iteration is not None:
        TaskBLL.update_statistics(
            task_id=task_id, company_id=company_id, last_iteration_max=iteration,
        )

    updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
    if updated_count:
        touches_last_update = any(
            name in updated_fields for name in last_update_fields
        )
        if touches_last_update:
            model.update(upsert=False, last_update=datetime.utcnow())

        new_project = updated_fields.get("project", model.project)
        if new_project == model.project:
            # same project: refresh its cached tags from the updated fields
            _update_cached_tags(
                company_id, project=model.project, fields=updated_fields
            )
        else:
            # the model moved: reset the cached tags of both projects
            _reset_cached_tags(company_id, projects=[new_project, model.project])

    conform_output_tags(call, updated_fields)
    unescape_metadata(call, updated_fields)
    return UpdateResponse(updated=updated_count, fields=updated_fields)
示例#2
0
def archive(call: APICall, company_id, request: ArchiveRequest):
    """
    Archive every task listed in the request.

    Each task is dequeued on a best-effort basis and then has the archived
    system tag added. The number of processed tasks is reported in the
    call result as an ArchiveResponse.
    """
    tasks = TaskBLL.assert_exists(
        company_id,
        task_ids=request.tasks,
        only=("id", "execution", "status", "project", "system_tags"),
    )

    archived = 0
    for task in tasks:
        try:
            TaskBLL.dequeue_and_change_status(
                task, company_id, request.status_message, request.status_reason,
            )
        except APIError:
            # dequeue may fail if the task was not enqueued
            pass

        # rebuild the tag list as a sorted, de-duplicated list that includes
        # the archived marker
        tags = set(task.system_tags)
        tags.add(EntityVisibility.archived.value)
        task.update(
            status_message=request.status_message,
            status_reason=request.status_reason,
            system_tags=sorted(tags),
            last_change=datetime.utcnow(),
        )
        archived += 1

    call.result.data_model = ArchiveResponse(archived=archived)
示例#3
0
def archive_task(
    task: Union[str, Task], company_id: str, status_message: str, status_reason: str,
) -> int:
    """
    Dequeue a task (best effort) and mark it as archived.

    :param task: a task object, or a task ID that is resolved with write access
    :param company_id: ID of the company owning the task
    :param status_message: status message stored on the task
    :param status_reason: status reason stored on the task
    :return: the result of the task update call
    """
    if isinstance(task, str):
        required_fields = (
            "id",
            "execution",
            "status",
            "project",
            "system_tags",
            "enqueue_status",
        )
        task = TaskBLL.get_task_with_access(
            task,
            company_id=company_id,
            only=required_fields,
            requires_write_access=True,
        )

    try:
        TaskBLL.dequeue_and_change_status(
            task, company_id, status_message, status_reason,
        )
    except APIError:
        # dequeue may fail if the task was not enqueued
        pass

    archive_fields = dict(
        status_message=status_message,
        status_reason=status_reason,
        add_to_set__system_tags=EntityVisibility.archived.value,
        last_change=datetime.utcnow(),
    )
    return task.update(**archive_fields)
示例#4
0
def delete_task(
    task_id: str,
    company_id: str,
    move_to_trash: bool,
    force: bool,
    return_file_urls: bool,
    delete_output_models: bool,
    status_message: str,
    status_reason: str,
) -> Tuple[int, Task, CleanupResult]:
    """
    Delete a task, optionally keeping a copy in the trash collection.

    A task may be deleted when it is in the created status, when it carries
    the archived system tag, or when ``force`` is set; otherwise
    TaskCannotBeDeleted is raised.

    :return: tuple of (1, the deleted task object, the cleanup result)
    """
    task = TaskBLL.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )

    can_delete = (
        task.status == TaskStatus.created
        or EntityVisibility.archived.value in task.system_tags
        or force
    )
    if not can_delete:
        raise errors.bad_request.TaskCannotBeDeleted(
            "due to status, use force=True",
            task=task.id,
            expected=TaskStatus.created,
            current=task.status,
        )

    try:
        TaskBLL.dequeue_and_change_status(
            task,
            company_id=company_id,
            status_message=status_message,
            status_reason=status_reason,
        )
    except APIError:
        # dequeue may fail if the task was not enqueued
        pass

    cleanup_res = cleanup_task(
        task,
        force=force,
        return_file_urls=return_file_urls,
        delete_output_models=delete_output_models,
    )

    if move_to_trash:
        live_collection = task._get_collection_name()
        trash_collection = "{}__trash".format(live_collection)
        task.switch_collection(trash_collection)
        try:
            # A plain save() would not write anything because of mongoengine
            # caching, so an insert is forced. If a document with this ID
            # already exists in the trash the failure is ignored.
            task.save(force_insert=True)
        except Exception:
            pass
        # point the object back at the live collection before deleting it
        task.switch_collection(live_collection)

    task.delete()
    update_project_time(task.project)
    return 1, task, cleanup_res
示例#5
0
def enqueue_task(
    task_id: str,
    company_id: str,
    queue_id: str,
    status_message: str,
    status_reason: str,
    validate: bool = False,
) -> Tuple[int, dict]:
    """
    Move a task to the queued status and add it to a queue.

    :param queue_id: target queue; when falsy the company default queue is used
    :param validate: when set, the task is validated before the status change
    :return: tuple of (1, status change result with ``execution.queue`` added
        to its fields)
    :raise errors.bad_request.InvalidTaskId: if the task was not found
    """
    if not queue_id:
        # fall back to the default queue of the company
        queue_id = queue_bll.get_default(company_id).id

    query = {"id": task_id, "company": company_id}
    task = Task.get_for_writing(**query)
    if not task:
        raise errors.bad_request.InvalidTaskId(**query)

    if validate:
        TaskBLL.validate(task)

    # NOTE: task.status is read at each call site rather than cached in a
    # local, so the values passed are exactly those the task object holds at
    # that moment
    status_change = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.queued,
        status_reason=status_reason,
        status_message=status_message,
        allow_same_state_transition=False,
    )
    res = status_change.execute(enqueue_status=task.status)

    try:
        queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
    except Exception:
        # failed enqueueing, revert to previous state
        revert = ChangeStatusRequest(
            task=task,
            current_status_override=TaskStatus.queued,
            new_status=task.status,
            force=True,
            status_reason="failed enqueueing",
        )
        revert.execute(enqueue_status=None)
        raise

    # record the queue ID on the task, creating the execution document when
    # the task does not have one yet
    task_query = Task.objects(**query)
    if task.execution:
        task_query.update(execution__queue=queue_id, multi=False)
    else:
        task_query.update(execution=Execution(queue=queue_id), multi=False)

    nested_set(res, ("fields", "execution.queue"), queue_id)
    return 1, res
示例#6
0
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
    """
    Return a single task under the ``task`` key of the call result.

    The task is fetched with ``allow_public=True``, converted to a dict and
    passed through ``unprepare_from_saved`` before being returned.
    """
    task_dict = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, allow_public=True
    ).to_proper_dict()
    unprepare_from_saved(call, task_dict)
    call.result.data = dict(task=task_dict)
示例#7
0
def edit(call: APICall, company_id, _):
    """
    Edit the model named by ``call.data["model"]``.

    Dict values that target an embedded document are merged on top of the
    document's current content. The call result is an UpdateResponse; when
    there are no fields to apply it reports ``updated=0``.
    """
    model_id = call.data["model"]

    with translate_errors_context():
        model = ModelBLL.get_company_model_by_id(
            company_id=company_id, model_id=model_id
        )

        fields = prepare_update_fields(
            call, company_id, parse_model_fields(call, create_fields)
        )

        # Merge partial dicts into the existing embedded documents so keys
        # that were not passed keep their current values
        for key, value in fields.items():
            current = getattr(model, key, None)
            if (
                current
                and isinstance(value, dict)
                and isinstance(current, EmbeddedDocument)
            ):
                merged = current.to_mongo(use_db_field=False).to_dict()
                merged.update(value)
                fields[key] = merged

        iteration = call.data.get("iteration")
        task_id = model.task or fields.get("task")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id, company_id=company_id, last_iteration_max=iteration,
            )

        if not fields:
            call.result.data_model = UpdateResponse(updated=0)
            return

        if any(name in fields for name in last_update_fields):
            fields.update(last_update=datetime.utcnow())

        updated = model.update(upsert=False, **fields)
        if updated:
            new_project = fields.get("project", model.project)
            if new_project == model.project:
                _update_cached_tags(
                    company_id, project=model.project, fields=fields
                )
            else:
                _reset_cached_tags(
                    company_id, projects=[new_project, model.project]
                )

        conform_output_tags(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)
示例#8
0
def dequeue(call: APICall, company_id, request: UpdateRequest):
    """
    Dequeue a task and report the outcome as a DequeueResponse.

    The task is fetched with write access; the response is built from the
    status change result and its ``dequeued`` attribute is set to 1.
    """
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=("id", "execution", "status", "project"),
        requires_write_access=True,
    )
    status_change = TaskBLL.dequeue_and_change_status(
        task,
        company_id,
        status_message=request.status_message,
        status_reason=request.status_reason,
    )

    response = DequeueResponse(**status_change)
    response.dequeued = 1
    call.result.data_model = response
示例#9
0
def publish(call: APICall, company_id, req_model: PublishRequest):
    """
    Publish the task named in the request.

    The request values are passed to ``TaskBLL.publish_task`` and its result
    is returned as a PublishResponse.
    """
    publish_result = TaskBLL.publish_task(
        task_id=req_model.task,
        company_id=company_id,
        publish_model=req_model.publish_model,
        force=req_model.force,
        status_reason=req_model.status_reason,
        status_message=req_model.status_message,
    )
    call.result.data_model = PublishResponse(**publish_result)
示例#10
0
def stop_task(
    task_id: str,
    company_id: str,
    user_name: str,
    status_reason: str,
    force: bool,
) -> dict:
    """
    Stop a task.

    A task that carries the development system tag, or that has no worker
    that recently reported for it, is moved to the stopped status right away.
    Otherwise the status is kept and only the status message is set to
    'stopping'.

    :return: the result of executing the status change request
    """
    task = TaskBLL.get_task_with_access(
        task_id,
        company_id=company_id,
        only=(
            "status",
            "project",
            "tags",
            "system_tags",
            "last_worker",
            "last_update",
        ),
        requires_write_access=True,
    )

    def has_live_worker(candidate: Task) -> bool:
        """Tell whether a worker updated the task within the configured timeout"""
        timeout_sec = config.get("apiserver.workers.task_update_timeout", 600)
        if not candidate.last_worker or not candidate.last_update:
            return False
        idle_sec = (datetime.utcnow() - candidate.last_update).total_seconds()
        return idle_sec < timeout_sec

    stop_immediately = (
        TaskSystemTags.development in task.system_tags
        or not has_live_worker(task)
    )
    if stop_immediately:
        new_status = TaskStatus.stopped
        status_message = f"Stopped by {user_name}"
    else:
        new_status = task.status
        status_message = TaskStatusMessage.stopping

    request = ChangeStatusRequest(
        task=task,
        new_status=new_status,
        status_reason=status_reason,
        status_message=status_message,
        force=force,
    )
    return request.execute()
示例#11
0
def publish_task(
    task_id: str,
    company_id: str,
    force: bool,
    publish_model_func: Callable[[str, str], Any] = None,
    status_message: str = "",
    status_reason: str = "",
) -> dict:
    """
    Publish a task and, optionally, its latest output model.

    :param task_id: ID of the task to publish
    :param company_id: ID of the company owning the task
    :param force: skip the status transition validation
    :param publish_model_func: callable receiving (model id, company id); it is
        invoked for the last output model of the task when that model is not
        marked as ready
    :param status_message: status message stored with the status change
    :param status_reason: status reason stored with the status change
    :return: the result of executing the status change request
    :raise Exception: any error raised while publishing is re-raised after the
        previous task status was written back
    """
    task = TaskBLL.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )
    if not force:
        validate_status_change(task.status, TaskStatus.published)

    previous_task_status = task.status
    output = task.output or Output()

    try:
        # set state to publishing
        task.status = TaskStatus.publishing
        task.save()

        # publish task models
        if task.models and task.models.output and publish_model_func:
            model_id = task.models.output[-1].model
            model = (
                Model.objects(id=model_id, company=company_id)
                .only("id", "ready")
                .first()
            )
            if model and not model.ready:
                publish_model_func(model.id, company_id)

        # set task status to published, and update (or set) its new output
        return ChangeStatusRequest(
            task=task,
            new_status=TaskStatus.published,
            force=force,
            status_reason=status_reason,
            status_message=status_message,
        ).execute(published=datetime.utcnow(), output=output)
    except Exception:
        # roll the status back, then re-raise with the original traceback
        task.status = previous_task_status
        task.save()
        raise
示例#12
0
def unarchive_task(
    task: str, company_id: str, status_message: str, status_reason: str,
) -> int:
    """
    Remove the archived system tag from a task.

    :param task: ID of the task, resolved with write access
    :return: the result of the task update call
    """
    task_obj = TaskBLL.get_task_with_access(
        task, company_id=company_id, only=("id",), requires_write_access=True,
    )
    unarchive_fields = dict(
        status_message=status_message,
        status_reason=status_reason,
        pull__system_tags=EntityVisibility.archived.value,
        last_change=datetime.utcnow(),
    )
    return task_obj.update(**unarchive_fields)
示例#13
0
def get_hyper_parameters(call: APICall, company_id: str,
                         request: GetHyperParamRequest):
    """
    Return one page of the aggregated hyper parameters of a project.

    When the request names no project, ``None`` is passed as the project
    filter. The call result holds the ``total``, ``remaining`` and
    ``parameters`` values returned by the aggregation.
    """
    project_ids = [request.project] if request.project else None
    total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
        company_id,
        project_ids=project_ids,
        include_subprojects=request.include_subprojects,
        page=request.page,
        page_size=request.page_size,
    )

    call.result.data = dict(
        total=total,
        remaining=remaining,
        parameters=parameters,
    )
示例#14
0
def archive(call: APICall, company_id, request: ArchiveRequest):
    """
    Archive every task listed in the request through ``archive_task``.

    The values returned by ``archive_task`` are summed and reported in the
    call result as an ArchiveResponse.
    """
    tasks = TaskBLL.assert_exists(
        company_id,
        task_ids=request.tasks,
        only=(
            "id",
            "execution",
            "status",
            "project",
            "system_tags",
            "enqueue_status",
        ),
    )
    archived = sum(
        archive_task(
            company_id=company_id,
            task=task,
            status_message=request.status_message,
            status_reason=request.status_reason,
        )
        for task in tasks
    )

    call.result.data_model = ArchiveResponse(archived=archived)
示例#15
0
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
    """
    Set the script requirements of a task.

    :raise errors.bad_request.MissingTaskFields: if the task has no script
    """
    requirements = req_model.requirements
    with translate_errors_context():
        task = TaskBLL.get_task_with_access(
            req_model.task,
            company_id=company_id,
            only=("status", "script"),
            requires_write_access=True,
        )
        if not task.script:
            raise errors.bad_request.MissingTaskFields(
                "Task has no script field", task=task.id
            )

        updated = update_task(
            task, update_cmds={"script__requirements": requirements}
        )
        call.result.data_model = UpdateResponse(updated=updated)
        if not updated:
            return
        # the changed field is reported only when something was updated
        call.result.data_model.fields = {"script.requirements": requirements}
示例#16
0
    def _update_task(
        self,
        company_id,
        task_id,
        now,
        iter_max=None,
        last_scalar_events=None,
        last_events=None,
    ):
        """
        Update task information in DB with aggregated results after handling event(s) related to this task.

        This updates the task with the highest iteration value encountered during the last events update, as well
        as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
        update time.
        """
        fields = {}

        if iter_max is not None:
            fields["last_iteration_max"] = iter_max

        if last_scalar_events:
            fields["last_scalar_values"] = list(
                flatten_nested_items(
                    last_scalar_events,
                    nesting=2,
                    include_leaves=[
                        "value",
                        "min_value",
                        "max_value",
                        "metric",
                        "variant",
                    ],
                ))

        if last_events:
            fields["last_events"] = last_events

        if not fields:
            return False

        return TaskBLL.update_statistics(task_id,
                                         company_id,
                                         last_update=now,
                                         **fields)
示例#17
0
def dequeue_task(
    task_id: str,
    company_id: str,
    status_message: str,
    status_reason: str,
) -> Tuple[int, dict]:
    """
    Dequeue a task and change its status.

    :return: tuple of (1, result of the dequeue and status change)
    :raise errors.bad_request.InvalidTaskId: if the task was not found
    """
    query = {"id": task_id, "company": company_id}
    task = Task.get_for_writing(**query)
    if not task:
        raise errors.bad_request.InvalidTaskId(**query)

    status_change = TaskBLL.dequeue_and_change_status(
        task,
        company_id,
        status_message=status_message,
        status_reason=status_reason,
    )
    return 1, status_change
示例#18
0
def stop(call: APICall, company_id, req_model: UpdateRequest):
    """
    stop
    :summary: Stop the task named in the request on behalf of the calling
              user. The work is done by ``TaskBLL.stop_task`` and its result
              is returned as an UpdateResponse.

    """
    stop_result = TaskBLL.stop_task(
        task_id=req_model.task,
        company_id=company_id,
        user_name=call.identity.user_name,
        status_reason=req_model.status_reason,
        force=req_model.force,
    )
    call.result.data_model = UpdateResponse(**stop_result)
示例#19
0
def delete(call: APICall, company_id, req_model: DeleteRequest):
    """
    Delete a task, optionally keeping a copy in the trash collection.

    A task that is not in the created status can only be deleted with
    ``force``; otherwise TaskCannotBeDeleted is raised. The call result holds
    ``deleted=True`` together with the cleanup result values.
    """
    task = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, requires_write_access=True
    )

    move_to_trash, force = req_model.move_to_trash, req_model.force

    if not (task.status == TaskStatus.created or force):
        raise errors.bad_request.TaskCannotBeDeleted(
            "due to status, use force=True",
            task=task.id,
            expected=TaskStatus.created,
            current=task.status,
        )

    with translate_errors_context():
        result = cleanup_task(task, force)

        if move_to_trash:
            live_collection = task._get_collection_name()
            trash_collection = "{}__trash".format(live_collection)
            task.switch_collection(trash_collection)
            try:
                # A plain save() would not write anything because of
                # mongoengine caching, so an insert is forced. If a document
                # with this ID already exists the failure is ignored.
                with TimingContext("mongo", "save_task"):
                    task.save(force_insert=True)
            except Exception:
                pass
            # point the object back at the live collection before deleting
            task.switch_collection(live_collection)

        task.delete()
        _reset_cached_tags(company_id, projects=[task.project])
        update_project_time(task.project)

        call.result.data = dict(deleted=True, **attr.asdict(result))
示例#20
0
def set_task_status_from_call(
    request: UpdateRequest, company_id, new_status=None, **set_fields
) -> dict:
    """
    Change the status of the task named in the request.

    Unless the caller supplies ``duration`` explicitly, a duration is derived
    for tasks that reach the completed, failed or stopped status: the number
    of whole seconds elapsed since ``task.started`` (0 when the task has no
    start time).

    :param request: update request carrying the task ID, status reason,
        status message and force flag
    :param company_id: ID of the company owning the task
    :param new_status: status to move to; when falsy the current status is kept
    :param set_fields: additional fields to set together with the status
    :return: the result of executing the status change request
    """
    fields_resolver = SetFieldsResolver(set_fields)
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=tuple(
            {"status", "project", "started", "duration"} | fields_resolver.get_names()
        ),
        requires_write_access=True,
    )

    if "duration" not in fields_resolver.get_names():
        # NOTE(review): this compares a status with the ``Task.started`` model
        # attribute, which looks like it can never match - confirm whether
        # ``TaskStatus.in_progress`` was intended. Left unchanged.
        if new_status == Task.started:
            fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
        elif new_status in (
            TaskStatus.completed,
            TaskStatus.failed,
            TaskStatus.stopped,
        ):
            if task.started:
                # elapsed time is "now - started"; the reverse order yields a
                # negative duration. Clamp at 0 to match the branch above.
                elapsed = (datetime.utcnow() - task.started).total_seconds()
                duration = max(0, int(elapsed))
            else:
                duration = 0
            fields_resolver.add_fields(duration=duration)

    status_reason = request.status_reason
    status_message = request.status_message
    force = request.force
    return ChangeStatusRequest(
        task=task,
        new_status=new_status or task.status,
        status_reason=status_reason,
        status_message=status_message,
        force=force,
    ).execute(**fields_resolver.get_fields(task))
示例#21
0
def add_or_update_model(_: APICall, company_id: str,
                        request: AddUpdateModelRequest):
    """
    Add a model entry to a task, or replace the entry with the same name.

    An in-place replacement of the entry named ``request.name`` in the
    ``models.<type>`` list is attempted first. When nothing was replaced the
    entry is pushed to the list as part of the statistics update.

    :return: dict with the result of the statistics update under ``updated``
    """
    get_task_for_update(
        company_id=company_id, task_id=request.task, force=True
    )

    models_field = f"models__{request.type}"
    model = ModelItem(
        name=request.name, model=request.model, updated=datetime.utcnow()
    )

    # try to overwrite the existing entry that has the same name
    query = {"id": request.task, f"{models_field}__name": request.name}
    replaced = Task.objects(**query).update_one(
        **{f"set__{models_field}__S": model}
    )

    # no such entry: push a new one together with the statistics update
    push_cmd = {} if replaced else {f"push__{models_field}": model}
    updated = TaskBLL.update_statistics(
        task_id=request.task,
        company_id=company_id,
        last_iteration_max=request.iteration,
        **push_cmd,
    )

    return {"updated": updated}
示例#22
0
import re

from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import enqueue_task
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint

# Module-level business-logic instances, created once at import time.
# None of them is referenced by the code visible in this excerpt.
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()


def _update_task_name(task: Task):
    """
    Rename a task after the last path segment of its project name.

    Counts the tasks of the same project that carry the "pipeline" system tag
    and whose name matches ``<segment>`` optionally followed by `` #<number>``.
    When the count is positive, `` #<count>`` is appended to the new name.
    Does nothing when the task, its project ID or the project is missing.
    """
    if not (task and task.project):
        return

    project = Project.objects(id=task.project).only("name").first()
    if not project:
        return

    name_prefix = project.name.rpartition("/")[2]
    name_mask = re.compile(rf"{re.escape(name_prefix)}( #\d+)?$")
    similar_count = Task.objects(
        project=task.project, system_tags__in=["pipeline"], name=name_mask
    ).count()

    if similar_count > 0:
        new_name = f"{name_prefix} #{similar_count}"
    else:
        new_name = name_prefix
    task.update(name=new_name)
示例#23
0
def reset_task(
    task_id: str,
    company_id: str,
    force: bool,
    return_file_urls: bool,
    delete_output_models: bool,
    clear_all: bool,
) -> Tuple[dict, CleanupResult, dict]:
    """
    Reset a task back to the created status.

    The task is dequeued (best effort), cleaned up, and its run results are
    cleared. With ``clear_all`` the execution document is replaced by an
    empty one and the script is removed; otherwise only the queue is removed
    and, of the artifacts, only the input ones are kept.

    :return: tuple of (dequeue result, cleanup result, status change result)
    :raise errors.bad_request.InvalidTaskStatus: for a published task unless
        ``force`` is set
    """
    task = TaskBLL.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )

    if task.status == TaskStatus.published and not force:
        raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)

    dequeued = {}
    try:
        dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
    except APIError:
        # dequeue may fail if the task was not enqueued
        pass

    cleaned_up = cleanup_task(
        task,
        force=force,
        update_children=False,
        return_file_urls=return_file_urls,
        delete_output_models=delete_output_models,
    )

    # fields cleared on every reset
    updates = dict(
        set__last_iteration=DEFAULT_LAST_ITERATION,
        set__last_metrics={},
        set__metric_stats={},
        set__models__output=[],
        set__runtime={},
        unset__output__result=1,
        unset__output__error=1,
        unset__last_worker=1,
        unset__last_worker_report=1,
    )

    if clear_all:
        updates["set__execution"] = Execution()
        updates["unset__script"] = 1
    else:
        updates["unset__execution__queue"] = 1
        if task.execution and task.execution.artifacts:
            # keep the input artifacts only
            updates["set__execution__artifacts"] = {
                key: artifact
                for key, artifact in task.execution.artifacts.items()
                if artifact.mode == ArtifactModes.input
            }

    status_change = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        force=force,
        status_reason="reset",
        status_message="reset",
    )
    res = status_change.execute(
        started=None,
        completed=None,
        published=None,
        active_duration=None,
        enqueue_status=None,
        **updates,
    )

    return dequeued, cleaned_up, res
示例#24
0
def reset(call: APICall, company_id, request: ResetRequest):
    """
    Reset a task back to the created status and report what was done.

    The task is dequeued (best effort), cleaned up, and its run results are
    cleared. The call result is a ResetResponse built from the status change
    result, with the dequeue result (when truthy) and the cleanup result
    values set on it as attributes.

    :raise errors.bad_request.InvalidTaskStatus: for a published task unless
        ``request.force`` is set
    """
    task = TaskBLL.get_task_with_access(
        request.task, company_id=company_id, requires_write_access=True
    )

    force = request.force

    if task.status == TaskStatus.published and not force:
        raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)

    api_results = {}

    dequeued = None
    try:
        dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
    except APIError:
        # dequeue may fail if the task was not enqueued
        pass
    if dequeued:
        api_results["dequeued"] = dequeued

    cleaned_up = cleanup_task(task, force)
    api_results.update(attr.asdict(cleaned_up))

    # fields cleared on every reset
    updates = dict(
        set__last_iteration=DEFAULT_LAST_ITERATION,
        set__last_metrics={},
        set__metric_stats={},
        unset__output__result=1,
        unset__output__model=1,
        unset__output__error=1,
        unset__last_worker=1,
        unset__last_worker_report=1,
    )

    if request.clear_all:
        updates["set__execution"] = Execution()
        updates["unset__script"] = 1
    else:
        updates["unset__execution__queue"] = 1
        if task.execution and task.execution.artifacts:
            # keep the input artifacts only
            updates["set__execution__artifacts"] = {
                key: artifact
                for key, artifact in task.execution.artifacts.items()
                if artifact.mode == ArtifactModes.input
            }

    status_change = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        force=force,
        status_reason="reset",
        status_message="reset",
    ).execute(
        started=None,
        completed=None,
        published=None,
        active_duration=None,
        **updates,
    )
    res = ResetResponse(**status_change)

    # do not return artifacts since they are not serializable
    res.fields.pop("execution.artifacts", None)

    for key, value in api_results.items():
        setattr(res, key, value)

    call.result.data_model = res
示例#25
0
def update_for_task(call: APICall, company_id, _):
    """
    Create or update the output model of a task.

    Exactly one of ``uri`` and ``override_model_id`` must be present in
    ``call.data``:

    - with ``override_model_id`` the named model is looked up and registered
      as the task's output model;
    - with ``uri`` the last output model of the task is updated if the task
      has one, otherwise a new model is created and registered.

    The call result holds the model ``id`` and a ``created`` flag, which is
    False only when an existing output model was updated.

    :raise errors.moved_permanently.NotSupported: if the requested endpoint
        version is above ``ModelsBackwardsCompatibility.max_version``
    :raise errors.bad_request.MissingRequiredFields: unless exactly one of
        ``uri`` / ``override_model_id`` is passed
    :raise errors.bad_request.InvalidTaskId: if the task was not found
    :raise errors.bad_request.InvalidTaskStatus: if the task is neither in
        the created nor in the in_progress status
    """
    if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
        raise errors.moved_permanently.NotSupported(
            "use tasks.add_or_update_model")

    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")
    # reject both "neither passed" and "both passed"
    if not (uri or override_model_id) or (uri and override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required",
            fields=("uri", "override_model_id"))

    with translate_errors_context():

        # query is only used for error reporting; the lookup below passes the
        # same values explicitly
        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["models", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            model = ModelBLL.get_company_model_by_id(
                company_id=company_id, model_id=override_model_id)
        else:
            # fill in defaults in the call data itself; both the update and
            # the create paths below read the fields from the call
            if "name" not in call.data:
                # use task name if name not provided
                call.data["name"] = task.name

            if "comment" not in call.data:
                call.data[
                    "comment"] = f"Created by task `{task.name}` ({task.id})"

            if task.models and task.models.output:
                # model exists, update
                model_id = task.models.output[-1].model
                res = _update_model(call, company_id,
                                    model_id=model_id).to_struct()
                res.update({"id": model_id, "created": False})
                call.result.data = res
                # early exit: the task statistics update at the end of this
                # function is not executed on this path
                return

            # new model, create
            fields = parse_model_fields(call, create_fields)

            # create and save model
            # NOTE(review): task.execution is dereferenced without a None
            # check - confirm a task always has an execution document here
            now = datetime.utcnow()
            model = Model(
                id=database.utils.id(),
                created=now,
                last_update=now,
                user=call.identity.user,
                company=company_id,
                project=task.project,
                framework=task.execution.framework,
                parent=task.models.input[0].model
                if task.models and task.models.input else None,
                design=task.execution.model_desc,
                labels=task.execution.model_labels,
                ready=(task.status == TaskStatus.published),
                **fields,
            )
            model.save()
            _update_cached_tags(company_id,
                                project=model.project,
                                fields=fields)

        # register the model as the single output model of the task
        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            models__output=[
                ModelItem(
                    model=model.id,
                    name=TaskModelNames[TaskModelTypes.output],
                    updated=datetime.utcnow(),
                )
            ],
        )

        call.result.data = {"id": model.id, "created": True}
示例#26
0
def ping(_, company_id, request: PingRequest):
    """
    Mark the task named in the request as updated "now".

    Sets the last-update time of ``request.task`` (within ``company_id``)
    to the current UTC time via ``TaskBLL.set_last_update``. The first
    positional argument is unused.
    """
    current_time = datetime.utcnow()
    pinged_task_ids = [request.task]
    TaskBLL.set_last_update(
        task_ids=pinged_task_ids,
        company_id=company_id,
        last_update=current_time,
    )
示例#27
0
    def add_events(
        self, company_id, events, worker, allow_locked_tasks=False
    ) -> Tuple[int, int, dict]:
        """
        Validate, normalize and bulk-index a batch of task events.

        Each event dict is validated and normalized IN PLACE (the caller's
        dicts are modified): the event type has spaces replaced with
        underscores, "ts" is renamed to "timestamp", "values" is renamed to
        "value", and missing "timestamp"/"worker"/"metric"/"variant" fields
        are filled in. Events that fail validation are skipped and counted
        in the per-error-message counters.

        After indexing, every task that had at least one accepted event is
        updated through ``self._update_task``; tasks for which that call
        returns a falsy value only get their last-update time set.

        :param company_id: company the events (and their tasks) belong to
        :param events: iterable of event dicts; each must carry "type" and "task"
        :param worker: value stored as the event's "worker" when none was sent
        :param allow_locked_tasks: forwarded to ``self._get_valid_tasks``
        :return: tuple of (number of events indexed, total number of errors,
            mapping of error message -> count)
        :raise BulkIndexError: if any event had an iteration outside
            [MIN_LONG, MAX_LONG]; raised inside ``translate_errors_context``,
            which may translate it
        :raise errors.bad_request.EventsNotAdded: if no event was indexed
        """
        actions: List[dict] = []
        task_ids = set()
        task_iteration = defaultdict(int)
        task_last_scalar_events = nested_dict(
            3, dict
        )  # task_id -> metric_hash -> variant_hash -> MetricEvent
        task_last_events = nested_dict(
            3, dict
        )  # task_id -> metric_hash -> event_type -> MetricEvent
        errors_per_type = defaultdict(int)
        invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
        # resolve once, up front, which of the referenced tasks may receive events
        valid_tasks = self._get_valid_tasks(
            company_id,
            task_ids={
                event["task"] for event in events if event.get("task") is not None
            },
            allow_locked_tasks=allow_locked_tasks,
        )

        for event in events:
            # remove spaces from event type
            event_type = event.get("type")
            if event_type is None:
                errors_per_type["Event must have a 'type' field"] += 1
                continue

            event_type = event_type.replace(" ", "_")
            if event_type not in EVENT_TYPES:
                errors_per_type[f"Invalid event type {event_type}"] += 1
                continue

            task_id = event.get("task")
            if task_id is None:
                errors_per_type["Event must have a 'task' field"] += 1
                continue

            if task_id not in valid_tasks:
                errors_per_type["Invalid task id"] += 1
                continue

            event["type"] = event_type

            # @timestamp indicates the time the event is written, not when it happened
            event["@timestamp"] = es_factory.get_es_timestamp_str()

            # for backward compatibility
            if "ts" in event:
                event["timestamp"] = event.pop("ts")

            # set timestamp and worker if not sent
            if "timestamp" not in event:
                event["timestamp"] = es_factory.get_timestamp_millis()

            if "worker" not in event:
                event["worker"] = worker

            # force iter to be a long int
            # NOTE(review): a non-numeric "iter" makes int() raise and aborts the
            # whole batch instead of being counted as a per-event error - confirm
            # that this is intended
            iteration = event.get("iter")
            if iteration is not None:
                iteration = int(iteration)
                if iteration > MAX_LONG or iteration < MIN_LONG:
                    errors_per_type[invalid_iteration_error] += 1
                    continue
                event["iter"] = iteration

            # used to have "values" to indicate array. no need anymore
            if "values" in event:
                event["value"] = event.pop("values")

            event["metric"] = event.get("metric") or ""
            event["variant"] = event.get("variant") or ""

            index_name = get_index_name(company_id, event_type)
            es_action = {
                "_op_type": "index",  # overwrite if exists with same ID
                "_index": index_name,
                "_source": event,
            }

            # "log" events get a freshly generated _id, so whatever is sent is
            # written (not overwritten); all other events get the _id computed
            # by self._get_event_id
            if event_type != EventType.task_log.value:
                es_action["_id"] = self._get_event_id(event)
            else:
                es_action["_id"] = dbutils.id()

            task_ids.add(task_id)
            if (
                iteration is not None
                and event.get("metric") not in self._skip_iteration_for_metric
            ):
                task_iteration[task_id] = max(iteration, task_iteration[task_id])

            self._update_last_metric_events_for_task(
                last_events=task_last_events[task_id], event=event,
            )
            if event_type == EventType.metrics_scalar.value:
                self._update_last_scalar_events_for_task(
                    last_events=task_last_scalar_events[task_id], event=event
                )

            actions.append(es_action)

        # plot events are passed (as the event dicts themselves) for validation
        # and compression before indexing
        plot_actions = [
            action["_source"]
            for action in actions
            if action["_source"]["type"] == EventType.metrics_plot.value
        ]
        if plot_actions:
            self.validate_and_compress_plots(
                plot_actions,
                validate_json=config.get("services.events.validate_plot_str", False),
                compression_threshold=config.get(
                    "services.events.plot_compression_threshold", 100_000
                ),
            )

        added = 0
        with translate_errors_context():
            if actions:
                chunk_size = 500
                with TimingContext("es", "events_add_batch"):
                    # TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
                    with closing(
                        helpers.streaming_bulk(
                            self.es,
                            actions,
                            chunk_size=chunk_size,
                            # thread_count=8,
                            refresh=True,
                        )
                    ) as it:
                        for success, info in it:
                            if success:
                                added += 1
                            else:
                                errors_per_type["Error when indexing events batch"] += 1

                    remaining_tasks = set()
                    now = datetime.utcnow()
                    for task_id in task_ids:
                        # Update related tasks. For reasons of performance, we prefer to update
                        # all of them and not only those who's events were successful
                        updated = self._update_task(
                            company_id=company_id,
                            task_id=task_id,
                            now=now,
                            iter_max=task_iteration.get(task_id),
                            last_scalar_events=task_last_scalar_events.get(task_id),
                            last_events=task_last_events.get(task_id),
                        )

                        if not updated:
                            # nothing was updated by _update_task; still record
                            # the last-update time for this task below
                            remaining_tasks.add(task_id)

                    if remaining_tasks:
                        TaskBLL.set_last_update(
                            remaining_tasks, company_id, last_update=now
                        )

            # this is for backwards compatibility with streaming bulk throwing exception on those
            invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
            if invalid_iterations_count:
                raise BulkIndexError(
                    f"{invalid_iterations_count} document(s) failed to index.",
                    [invalid_iteration_error],
                )

        if not added:
            raise errors.bad_request.EventsNotAdded(**errors_per_type)

        errors_count = sum(errors_per_type.values())
        return added, errors_count, errors_per_type