Example #1
def _update_model(call: APICall, model_id=None):
    identity = call.identity
    model_id = model_id or call.data["model"]

    with translate_errors_context():
        # get model by id
        query = dict(id=model_id, company=identity.company)
        model = Model.objects(**query).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)

        data = prepare_update_fields(call, call.data)

        task_id = data.get("task")
        iteration = data.get("iteration")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=identity.company,
                last_iteration_max=iteration,
            )

        updated_count, updated_fields = Model.safe_update(
            identity.company, model.id, data
        )
        conform_output_tags(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
Example #2
def _update_model(call: APICall, company_id, model_id=None):
    model_id = model_id or call.data["model"]

    with translate_errors_context():
        # get model by id
        query = dict(id=model_id, company=company_id)
        model = Model.objects(**query).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)

        data = prepare_update_fields(call, company_id, call.data)

        task_id = data.get("task")
        iteration = data.get("iteration")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=company_id,
                last_iteration_max=iteration,
            )

        updated_count, updated_fields = Model.safe_update(
            company_id, model.id, data)
        if updated_count:
            new_project = updated_fields.get("project", model.project)
            if new_project != model.project:
                _reset_cached_tags(company_id,
                                   projects=[new_project, model.project])
            else:
                _update_cached_tags(company_id,
                                    project=model.project,
                                    fields=updated_fields)
        conform_output_tags(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
Example #3
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
    task = TaskBLL.get_task_with_access(req_model.task,
                                        company_id=company_id,
                                        allow_public=True)
    task_dict = task.to_proper_dict()
    conform_output_tags(call, task_dict)
    call.result.data = {"task": task_dict}
Example #4
def delete(call: APICall, company_id, req_model: DeleteRequest):
    task = TaskBLL.get_task_with_access(req_model.task,
                                        company_id=company_id,
                                        requires_write_access=True)

    move_to_trash = req_model.move_to_trash
    force = req_model.force

    if task.status != TaskStatus.created and not force:
        raise errors.bad_request.TaskCannotBeDeleted(
            "due to status, use force=True",
            task=task.id,
            expected=TaskStatus.created,
            current=task.status,
        )

    with translate_errors_context():
        result = cleanup_task(task, force)

        if move_to_trash:
            collection_name = task._get_collection_name()
            archived_collection = "{}__trash".format(collection_name)
            task.switch_collection(archived_collection)
            try:
                # A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
                # an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
                with TimingContext("mongo", "save_task"):
                    task.save(force_insert=True)
            except Exception:
                pass
            task.switch_collection(collection_name)

        task.delete()
        org_bll.update_org_tags(company_id, reset=True)
        call.result.data = dict(deleted=True, **attr.asdict(result))
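The force_insert trick in the comment above can be shown in isolation. A minimal sketch of the move-to-trash pattern, using a hypothetical mongoengine Document and connection (names below are not from the original snippet):

# Minimal sketch of the trash-move pattern above, assuming a local MongoDB.
from mongoengine import Document, StringField, connect

connect("example_db")


class Note(Document):
    text = StringField()


note = Note(text="hello").save()

trash_collection = "{}__trash".format(Note._get_collection_name())
note.switch_collection(trash_collection)
try:
    # A plain save() can be a no-op here because mongoengine considers the
    # document unchanged, so force an insert; tolerate a pre-existing ID.
    note.save(force_insert=True)
except Exception:
    pass
note.switch_collection(Note._get_collection_name())
note.delete()  # remove the document from the original collection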
Example #5
def add_or_update_artifacts(
    call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
    added, updated = TaskBLL.add_or_update_artifacts(
        task_id=request.task, company_id=company_id, artifacts=request.artifacts
    )
    call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)
Example #6
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
    task = TaskBLL.get_task_with_access(req_model.task,
                                        company_id=company_id,
                                        allow_public=True)
    task_dict = task.to_proper_dict()
    unprepare_from_saved(call, task_dict)
    call.result.data = {"task": task_dict}
Example #7
def dequeue(call: APICall, company_id, req_model: UpdateRequest):
    task = TaskBLL.get_task_with_access(
        req_model.task,
        company_id=company_id,
        only=("id", "execution", "status", "project"),
        requires_write_access=True,
    )
    if task.status not in (TaskStatus.queued,):
        raise errors.bad_request.InvalidTaskId(
            status=task.status, expected=TaskStatus.queued
        )

    _dequeue(task, company_id)

    status_message = req_model.status_message
    status_reason = req_model.status_reason
    res = DequeueResponse(
        **ChangeStatusRequest(
            task=task,
            new_status=TaskStatus.created,
            status_reason=status_reason,
            status_message=status_message,
        ).execute(unset__execution__queue=1)
    )
    res.dequeued = 1

    call.result.data_model = res
Example #8
    def _update_task(self,
                     company_id,
                     task_id,
                     now,
                     iter_max=None,
                     last_events=None):
        """
        Update task information in DB with aggregated results after handling event(s) related to this task.

        This updates the task with the highest iteration value encountered during the last events update, as well
        as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
        update time.
        """
        fields = {}

        if iter_max is not None:
            fields["last_iteration_max"] = iter_max

        if last_events:
            fields["last_values"] = list(
                flatten_nested_items(
                    last_events,
                    nesting=2,
                    include_leaves=["value", "metric", "variant"],
                ))

        if not fields:
            return False

        return TaskBLL.update_statistics(task_id,
                                         company_id,
                                         last_update=now,
                                         **fields)
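flatten_nested_items is not shown in these snippets; given the call above, it is assumed to walk a metric_hash -> variant_hash mapping and yield key paths with the leaf dict filtered to the listed keys. A rough sketch of that assumption (a hypothetical re-implementation, not the project's utility):

# Hypothetical sketch of the assumed flatten_nested_items behavior.
def flatten_nested_items(d, nesting, include_leaves):
    def _flatten(node, path, depth):
        if depth == nesting:
            # keep only the whitelisted keys of the leaf dict
            yield path, {k: v for k, v in node.items() if k in include_leaves}
            return
        for key, child in node.items():
            yield from _flatten(child, path + (key,), depth + 1)

    yield from _flatten(d, (), 0)


last_events = {
    "m1": {"v1": {"metric": "loss", "variant": "total", "value": 0.5, "ts": 1}},
}
print(list(flatten_nested_items(last_events, 2, ["value", "metric", "variant"])))
# [(('m1', 'v1'), {'metric': 'loss', 'variant': 'total', 'value': 0.5})]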
Example #9
def edit(call: APICall, company_id, _):
    model_id = call.data["model"]

    with translate_errors_context():
        query = dict(id=model_id, company=company_id)
        model = Model.objects(**query).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)

        fields = parse_model_fields(call, create_fields)
        fields = prepare_update_fields(call, company_id, fields)

        for key in fields:
            field = getattr(model, key, None)
            value = fields[key]
            if (field and isinstance(value, dict)
                    and isinstance(field, EmbeddedDocument)):
                d = field.to_mongo(use_db_field=False).to_dict()
                d.update(value)
                fields[key] = d

        iteration = call.data.get("iteration")
        task_id = model.task or fields.get("task")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=company_id,
                last_iteration_max=iteration,
            )

        if fields:
            updated = model.update(upsert=False, **fields)
            if updated:
                new_project = fields.get("project", model.project)
                if new_project != model.project:
                    _reset_cached_tags(company_id,
                                       projects=[new_project, model.project])
                else:
                    _update_cached_tags(company_id,
                                        project=model.project,
                                        fields=fields)
            conform_output_tags(call, fields)
            call.result.data_model = UpdateResponse(updated=updated,
                                                    fields=fields)
        else:
            call.result.data_model = UpdateResponse(updated=0)
Example #10
def publish(call: APICall, company_id, req_model: PublishRequest):
    call.result.data_model = PublishResponse(**TaskBLL.publish_task(
        task_id=req_model.task,
        company_id=company_id,
        publish_model=req_model.publish_model,
        force=req_model.force,
        status_reason=req_model.status_reason,
        status_message=req_model.status_message,
    ))
Example #11
def set_ready(call: APICall, company, req_model: PublishModelRequest):
    updated, published_task_data = TaskBLL.model_set_ready(
        model_id=req_model.model,
        company_id=company,
        publish_task=req_model.publish_task,
        force_publish_task=req_model.force_publish_task)

    call.result.data_model = PublishModelResponse(
        updated=updated,
        published_task=ModelTaskPublishResponse(
            **published_task_data) if published_task_data else None)
Example #12
def reset(call: APICall, company_id, req_model: UpdateRequest):
    task = TaskBLL.get_task_with_access(req_model.task,
                                        company_id=company_id,
                                        requires_write_access=True)

    force = req_model.force

    if not force and task.status == TaskStatus.published:
        raise errors.bad_request.InvalidTaskStatus(task_id=task.id,
                                                   status=task.status)

    api_results = {}
    updates = {}

    try:
        dequeued = _dequeue(task, company_id, silent_fail=True)
    except APIError:
        # dequeue may fail if the task was not enqueued
        pass
    else:
        if dequeued:
            api_results.update(dequeued=dequeued)
        updates.update(unset__execution__queue=1)

    cleaned_up = cleanup_task(task, force)
    api_results.update(attr.asdict(cleaned_up))

    updates.update(
        set__last_iteration=DEFAULT_LAST_ITERATION,
        set__last_metrics={},
        unset__output__result=1,
        unset__output__model=1,
        __raw__={"$pull": {
            "execution.artifacts": {
                "mode": {
                    "$ne": "input"
                }
            }
        }},
    )

    res = ResetResponse(**ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        force=force,
        status_reason="reset",
        status_message="reset",
    ).execute(started=None, completed=None, published=None, **updates))

    for key, value in api_results.items():
        setattr(res, key, value)

    call.result.data_model = res
Example #13
def edit(call):
    assert isinstance(call, APICall)
    identity = call.identity
    model_id = call.data["model"]

    with translate_errors_context():
        query = dict(id=model_id, company=identity.company)
        model = Model.objects(**query).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)

        fields = parse_model_fields(call, create_fields)
        fields = prepare_update_fields(call, fields)

        for key in fields:
            field = getattr(model, key, None)
            value = fields[key]
            if (field and isinstance(value, dict)
                    and isinstance(field, EmbeddedDocument)):
                d = field.to_mongo(use_db_field=False).to_dict()
                d.update(value)
                fields[key] = d

        iteration = call.data.get("iteration")
        task_id = model.task or fields.get('task')
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=identity.company,
                last_iteration_max=iteration,
            )

        if fields:
            updated = model.update(upsert=False, **fields)
            call.result.data_model = UpdateResponse(updated=updated,
                                                    fields=fields)
        else:
            call.result.data_model = UpdateResponse(updated=0)
Example #14
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):

    total, remaining, parameters = TaskBLL.get_aggregated_project_execution_parameters(
        company_id,
        project_ids=[request.project] if request.project else None,
        page=request.page,
        page_size=request.page_size,
    )

    call.result.data = {
        "total": total,
        "remaining": remaining,
        "parameters": parameters,
    }
Example #15
def stop(call: APICall, company_id, req_model: UpdateRequest):
    """
    stop
    :summary: Stop a running task. Requires task status 'in_progress' and
              execution_progress 'running', or force=True.
              Development task is stopped immediately. For a non-development task
              only its status message is set to 'stopping'

    """
    call.result.data_model = UpdateResponse(**TaskBLL.stop_task(
        task_id=req_model.task,
        company_id=company_id,
        user_name=call.identity.user_name,
        status_reason=req_model.status_reason,
        force=req_model.force,
    ))
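For reference, a hypothetical request body for this handler, using only the fields the snippet reads from UpdateRequest (the ID and transport details are assumptions):

# Hypothetical request body for the stop handler above (ID is made up).
stop_request = {
    "task": "a1b2c3d4",               # task to stop
    "force": False,                   # set True to bypass the status checks
    "status_reason": "stopped by user",
}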
Example #16
    def _update_task(self,
                     company_id,
                     task_id,
                     now,
                     iter=None,
                     last_events=None,
                     last_metrics=None):
        """
        Update task information in DB with aggregated results after handling event(s) related to this task.

        This updates the task with the highest iteration value encountered during the last events update, as well
        as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
        update time.
        """
        fields = {}

        if iter is not None:
            fields["last_iteration"] = iter

        if last_events:

            def get_metric_event(ev):
                me = MetricEvent.from_dict(**ev)
                if "timestamp" in ev:
                    me.timestamp = datetime.utcfromtimestamp(ev["timestamp"] /
                                                             1000)
                return me

            new_last_metrics = nested_dict(2, MetricEvent)
            new_last_metrics.update(last_metrics)

            for metric_hash, variants in last_events.items():
                for variant_hash, event in variants.items():
                    new_last_metrics[metric_hash][
                        variant_hash] = get_metric_event(event)

            fields["last_metrics"] = new_last_metrics.to_dict()

        if not fields:
            return False

        return TaskBLL.update_statistics(task_id,
                                         company_id,
                                         last_update=now,
                                         **fields)
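nested_dict is likewise not shown; from its use above it appears to build an auto-vivifying mapping of fixed depth. A minimal sketch of that assumption (the project's helper evidently also exposes a to_dict() method, omitted here):

# Hypothetical sketch of the assumed nested_dict helper.
from collections import defaultdict


def nested_dict(depth, leaf_type):
    if depth == 1:
        return defaultdict(leaf_type)
    return defaultdict(lambda: nested_dict(depth - 1, leaf_type))


d = nested_dict(2, dict)
d["metric_hash"]["variant_hash"] = {"value": 0.5}  # intermediate levels auto-create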
Example #17
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
    requirements = req_model.requirements
    with translate_errors_context():
        task = TaskBLL.get_task_with_access(
            req_model.task,
            company_id=company_id,
            only=("status", "script"),
            requires_write_access=True,
        )
        if not task.script:
            raise errors.bad_request.MissingTaskFields(
                "Task has no script field", task=task.id
            )
        res = task.update(
            script__requirements=requirements, last_update=datetime.utcnow()
        )
        call.result.data_model = UpdateResponse(updated=res)
        if res:
            call.result.data_model.fields = {"script.requirements": requirements}
Example #18
def set_task_status_from_call(request: UpdateRequest,
                              company_id,
                              new_status=None,
                              **kwargs) -> dict:
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=("status", "project"),
        requires_write_access=True,
    )
    status_reason = request.status_reason
    status_message = request.status_message
    force = request.force
    return ChangeStatusRequest(
        task=task,
        new_status=new_status or task.status,
        status_reason=status_reason,
        status_message=status_message,
        force=force,
    ).execute(**kwargs)
Example #19
def set_task_status_from_call(
    request: UpdateRequest, company_id, new_status=None, **set_fields
) -> dict:
    fields_resolver = SetFieldsResolver(set_fields)
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=tuple({"status", "project"} | fields_resolver.get_names()),
        requires_write_access=True,
    )

    status_reason = request.status_reason
    status_message = request.status_message
    force = request.force
    return ChangeStatusRequest(
        task=task,
        new_status=new_status or task.status,
        status_reason=status_reason,
        status_message=status_message,
        force=force,
    ).execute(**fields_resolver.get_fields(task))
Example #20
def reset(call: APICall, company_id, req_model: UpdateRequest):
    task = TaskBLL.get_task_with_access(req_model.task,
                                        company_id=company_id,
                                        requires_write_access=True)

    force = req_model.force

    if not force and task.status == TaskStatus.published:
        raise errors.bad_request.InvalidTaskStatus(task_id=task.id,
                                                   status=task.status)

    api_results = {}
    updates = {}

    cleaned_up = cleanup_task(task, force)
    api_results.update(attr.asdict(cleaned_up))

    updates.update(
        unset__script__requirements=1,
        set__last_iteration=DEFAULT_LAST_ITERATION,
        set__last_metrics={},
        unset__output__result=1,
        unset__output__model=1,
    )

    res = ResetResponse(**ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        force=force,
        status_reason="reset",
        status_message="reset",
    ).execute(started=None, completed=None, published=None, **updates))

    for key, value in api_results.items():
        setattr(res, key, value)

    call.result.data_model = res
Example #21
def set_task_status_from_call(request: UpdateRequest,
                              company_id,
                              new_status=None,
                              **set_fields) -> dict:
    fields_resolver = SetFieldsResolver(set_fields)
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=tuple({"status", "project", "started", "duration"}
                   | fields_resolver.get_names()),
        requires_write_access=True,
    )

    if "duration" not in fields_resolver.get_names():
        if new_status == Task.started:
            fields_resolver.add_fields(
                min__duration=max(0, task.duration or 0))
        elif new_status in (
                TaskStatus.completed,
                TaskStatus.failed,
                TaskStatus.stopped,
        ):
            fields_resolver.add_fields(
                duration=int((task.started -
                              datetime.utcnow()).total_seconds()))

    status_reason = request.status_reason
    status_message = request.status_message
    force = request.force
    return ChangeStatusRequest(
        task=task,
        new_status=new_status or task.status,
        status_reason=status_reason,
        status_message=status_message,
        force=force,
    ).execute(**fields_resolver.get_fields(task))
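SetFieldsResolver itself does not appear in these snippets. From its use above, it seems to collect caller-supplied field updates, expose the bare field names so they can be included in the DB projection, and hand the fields back when the status change executes. A speculative sketch of that contract (guesses, not the project's implementation):

# Speculative sketch of the SetFieldsResolver contract inferred from usage.
class SetFieldsResolver:
    def __init__(self, set_fields: dict):
        self._fields = dict(set_fields)

    def add_fields(self, **fields):
        self._fields.update(fields)

    def get_names(self):
        # bare field names for the DB projection; operator prefixes such as
        # "min__" are stripped
        return {name.rpartition("__")[-1] for name in self._fields}

    def get_fields(self, task):
        # task is accepted for parity with the usage above; in this sketch the
        # values pass through unchanged, and operator-prefixed keys (e.g.
        # min__duration) are assumed to be resolved by the DB update itself
        return dict(self._fields)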
Example #22
from typing import Dict, Sequence

from boltons import iterutils

from apierrors import errors
from apimodels.tasks import (
    HyperParamKey,
    HyperParamItem,
    ReplaceHyperparams,
    Configuration,
)
from bll.task import TaskBLL
from config import config
from database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
from utilities.parameter_key_escaper import ParameterKeyEscaper

log = config.logger(__file__)
task_bll = TaskBLL()


class HyperParams:
    _properties_section = "properties"

    @classmethod
    def get_params(cls, company_id: str,
                   task_ids: Sequence[str]) -> Dict[str, dict]:
        only = ("id", "hyperparams")
        tasks = task_bll.assert_exists(
            company_id=company_id,
            task_ids=task_ids,
            only=only,
            allow_public=True,
        )
Example #23
    def add_events(self, company_id, events, worker, allow_locked_tasks=False):
        actions = []
        task_ids = set()
        task_iteration = defaultdict(lambda: 0)
        task_last_events = nested_dict(
            3, dict)  # task_id -> metric_hash -> variant_hash -> MetricEvent

        for event in events:
            # remove spaces from event type
            if "type" not in event:
                raise errors.BadRequest("Event must have a 'type' field",
                                        event=event)

            event_type = event["type"].replace(" ", "_")
            if event_type not in EVENT_TYPES:
                raise errors.BadRequest(
                    "Invalid event type {}".format(event_type),
                    event=event,
                    types=EVENT_TYPES,
                )

            event["type"] = event_type

            # @timestamp indicates the time the event is written, not when it happened
            event["@timestamp"] = es_factory.get_es_timestamp_str()

            # for backward compatibility: older clients may send "ts" instead of "timestamp"
            if "ts" in event:
                event["timestamp"] = event.pop("ts")

            # set timestamp and worker if not sent
            if "timestamp" not in event:
                event["timestamp"] = es_factory.get_timestamp_millis()

            if "worker" not in event:
                event["worker"] = worker

            # force iter to be a long int
            iter = event.get("iter")
            if iter is not None:
                iter = int(iter)
                event["iter"] = iter

            # used to have "values" to indicate array. no need anymore
            if "values" in event:
                event["value"] = event["values"]
                del event["values"]

            index_name = EventMetrics.get_index_name(company_id, event_type)
            es_action = {
                "_op_type": "index",  # overwrite if exists with same ID
                "_index": index_name,
                "_type": "event",
                "_source": event,
            }

            # for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten)
            if event_type != "log":
                es_action["_id"] = self._get_event_id(event)
            else:
                es_action["_id"] = dbutils.id()

            task_id = event.get("task")
            if task_id is not None:
                es_action["_routing"] = task_id
                task_ids.add(task_id)
                if (iter is not None and event.get("metric")
                        not in self._skip_iteration_for_metric):
                    task_iteration[task_id] = max(iter,
                                                  task_iteration[task_id])

                if event_type == EventType.metrics_scalar.value:
                    self._update_last_metric_event_for_task(
                        task_last_events=task_last_events,
                        task_id=task_id,
                        event=event)
            else:
                es_action["_routing"] = task_id

            actions.append(es_action)

        if task_ids:
            # verify task_ids
            with translate_errors_context(), TimingContext(
                    "mongo", "task_by_ids"):
                extra_msg = None
                query = Q(id__in=task_ids, company=company_id)
                if not allow_locked_tasks:
                    query &= Q(status__nin=LOCKED_TASK_STATUSES)
                    extra_msg = "or task published"
                res = Task.objects(query).only("id")
                if len(res) < len(task_ids):
                    invalid_task_ids = tuple(
                        set(task_ids) - set(r.id for r in res))
                    raise errors.bad_request.InvalidTaskId(
                        extra_msg, company=company_id, ids=invalid_task_ids)

        errors_in_bulk = []
        added = 0
        chunk_size = 500
        with translate_errors_context(), TimingContext("es",
                                                       "events_add_batch"):
            # TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
            with closing(
                    helpers.streaming_bulk(
                        self.es,
                        actions,
                        chunk_size=chunk_size,
                        # thread_count=8,
                        refresh=True,
                    )) as it:
                for success, info in it:
                    if success:
                        added += chunk_size
                    else:
                        errors_in_bulk.append(info)

            remaining_tasks = set()
            now = datetime.utcnow()
            for task_id in task_ids:
                # Update related tasks. For performance reasons, we prefer to update
                # all of them, not only those whose events were successful

                updated = self._update_task(
                    company_id=company_id,
                    task_id=task_id,
                    now=now,
                    iter_max=task_iteration.get(task_id),
                    last_events=task_last_events.get(task_id),
                )

                if not updated:
                    remaining_tasks.add(task_id)

            if remaining_tasks:
                TaskBLL.set_last_update(remaining_tasks,
                                        company_id,
                                        last_update=now)

        # Compensate for always adding chunk_size on success (last chunk is probably smaller)
        added = min(added, len(actions))

        return added, errors_in_bulk
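A hypothetical invocation of this method with a single scalar event. The event type string follows EventType.metrics_scalar as referenced in the code, and all IDs plus the event_bll instance (see Example #26) are assumptions:

# Hypothetical usage of add_events (IDs and the event_bll instance are assumed).
added, errors_in_bulk = event_bll.add_events(
    company_id="company-1",
    events=[{
        "type": "training_stats_scalar",  # must be one of EVENT_TYPES
        "task": "task-123",
        "metric": "loss",
        "variant": "total",
        "iter": 10,                       # coerced to int by the handler
        "value": 0.42,
        # "timestamp" and "worker" are filled in when omitted
    }],
    worker="worker-7",
)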
Example #24
def ping(_, company_id, request: PingRequest):
    TaskBLL.set_last_update(task_ids=[request.task],
                            company_id=company_id,
                            last_update=datetime.utcnow())
Example #25
    def add_events(self,
                   company_id,
                   events,
                   worker,
                   allow_locked_tasks=False) -> Tuple[int, int, dict]:
        actions = []
        task_ids = set()
        task_iteration = defaultdict(lambda: 0)
        task_last_scalar_events = nested_dict(
            3, dict)  # task_id -> metric_hash -> variant_hash -> MetricEvent
        task_last_events = nested_dict(
            3, dict)  # task_id -> metric_hash -> event_type -> MetricEvent
        errors_per_type = defaultdict(int)
        valid_tasks = self._get_valid_tasks(
            company_id,
            task_ids={
                event["task"]
                for event in events if event.get("task") is not None
            },
            allow_locked_tasks=allow_locked_tasks,
        )
        for event in events:
            # remove spaces from event type
            event_type = event.get("type")
            if event_type is None:
                errors_per_type["Event must have a 'type' field"] += 1
                continue

            event_type = event_type.replace(" ", "_")
            if event_type not in EVENT_TYPES:
                errors_per_type[f"Invalid event type {event_type}"] += 1
                continue

            task_id = event.get("task")
            if task_id is None:
                errors_per_type["Event must have a 'task' field"] += 1
                continue

            if task_id not in valid_tasks:
                errors_per_type["Invalid task id"] += 1
                continue

            event["type"] = event_type

            # @timestamp indicates the time the event is written, not when it happened
            event["@timestamp"] = es_factory.get_es_timestamp_str()

            # for backward compatibility: older clients may send "ts" instead of "timestamp"
            if "ts" in event:
                event["timestamp"] = event.pop("ts")

            # set timestamp and worker if not sent
            if "timestamp" not in event:
                event["timestamp"] = es_factory.get_timestamp_millis()

            if "worker" not in event:
                event["worker"] = worker

            # force iter to be a long int
            iter = event.get("iter")
            if iter is not None:
                iter = int(iter)
                event["iter"] = iter

            # used to have "values" to indicate array. no need anymore
            if "values" in event:
                event["value"] = event["values"]
                del event["values"]

            event["metric"] = event.get("metric") or ""
            event["variant"] = event.get("variant") or ""

            index_name = EventMetrics.get_index_name(company_id, event_type)
            es_action = {
                "_op_type": "index",  # overwrite if exists with same ID
                "_index": index_name,
                "_type": "event",
                "_source": event,
            }

            # for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten)
            if event_type != "log":
                es_action["_id"] = self._get_event_id(event)
            else:
                es_action["_id"] = dbutils.id()

            es_action["_routing"] = task_id
            task_ids.add(task_id)
            if (iter is not None and event.get("metric")
                    not in self._skip_iteration_for_metric):
                task_iteration[task_id] = max(iter, task_iteration[task_id])

            self._update_last_metric_events_for_task(
                last_events=task_last_events[task_id],
                event=event,
            )
            if event_type == EventType.metrics_scalar.value:
                self._update_last_scalar_events_for_task(
                    last_events=task_last_scalar_events[task_id], event=event)

            actions.append(es_action)

        added = 0
        if actions:
            chunk_size = 500
            with translate_errors_context(), TimingContext(
                    "es", "events_add_batch"):
                # TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
                with closing(
                        helpers.streaming_bulk(
                            self.es,
                            actions,
                            chunk_size=chunk_size,
                            # thread_count=8,
                            refresh=True,
                        )) as it:
                    for success, info in it:
                        if success:
                            added += chunk_size
                        else:
                            errors_per_type[
                                "Error when indexing events batch"] += 1

                remaining_tasks = set()
                now = datetime.utcnow()
                for task_id in task_ids:
                    # Update related tasks. For performance reasons, we prefer to
                    # update all of them, not only those whose events were successful
                    updated = self._update_task(
                        company_id=company_id,
                        task_id=task_id,
                        now=now,
                        iter_max=task_iteration.get(task_id),
                        last_scalar_events=task_last_scalar_events.get(
                            task_id),
                        last_events=task_last_events.get(task_id),
                    )

                    if not updated:
                        remaining_tasks.add(task_id)

                if remaining_tasks:
                    TaskBLL.set_last_update(remaining_tasks,
                                            company_id,
                                            last_update=now)

        # Compensate for always adding chunk_size on success (last chunk is probably smaller)
        added = min(added, len(actions))

        if not added:
            raise errors.bad_request.EventsNotAdded(**errors_per_type)

        errors_count = sum(errors_per_type.values())
        return added, errors_count, errors_per_type
Example #26
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext
from utilities import safe_get

task_fields = set(Task.get_fields())
task_script_fields = set(get_fields(Script))
get_all_query_options = Task.QueryParameterOptions(
    list_fields=("id", "user", "tags", "system_tags", "type", "status",
                 "project"),
    datetime_fields=("status_changed", ),
    pattern_fields=("name", "comment"),
    fields=("parent", ),
)

task_bll = TaskBLL()
event_bll = EventBLL()
queue_bll = QueueBLL()

TaskBLL.start_non_responsive_tasks_watchdog()


def set_task_status_from_call(request: UpdateRequest,
                              company_id,
                              new_status=None,
                              **set_fields) -> dict:
    fields_resolver = SetFieldsResolver(set_fields)
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=tuple({"status", "project"} | fields_resolver.get_names()),
Example #27
def update_for_task(call: APICall, company_id, _):
    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")
    if not (uri or override_model_id) or (uri and override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required",
            fields=("uri", "override_model_id"))

    with translate_errors_context():

        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["output", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            query = dict(company=company_id, id=override_model_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
        else:
            if "name" not in call.data:
                # use task name if name not provided
                call.data["name"] = task.name

            if "comment" not in call.data:
                call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

            if task.output and task.output.model:
                # model exists, update
                res = _update_model(call,
                                    company_id,
                                    model_id=task.output.model).to_struct()
                res.update({"id": task.output.model, "created": False})
                call.result.data = res
                return

            # new model, create
            fields = parse_model_fields(call, create_fields)

            # create and save model
            model = Model(
                id=database.utils.id(),
                created=datetime.utcnow(),
                user=call.identity.user,
                company=company_id,
                project=task.project,
                framework=task.execution.framework,
                parent=task.execution.model,
                design=task.execution.model_desc,
                labels=task.execution.model_labels,
                ready=(task.status == TaskStatus.published),
                **fields,
            )
            model.save()
            _update_cached_tags(company_id,
                                project=model.project,
                                fields=fields)

        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            output__model=model.id,
        )

        call.result.data = {"id": model.id, "created": True}