def _update_model(call: APICall, model_id=None):
    """
    Apply the update fields carried by ``call`` to a model of the caller's company.

    :param call: API call. ``call.identity.company`` scopes the model lookup and
        ``call.data`` supplies the update fields (and the ``"model"`` id when
        ``model_id`` is falsy).
    :param model_id: Optional explicit model id; takes precedence over
        ``call.data["model"]``.
    :return: ``UpdateResponse`` with the updated count and the updated fields.
    :raises errors.bad_request.InvalidModelId: no model matches the id/company pair.
    """
    caller = call.identity
    target_id = model_id if model_id else call.data["model"]

    with translate_errors_context():
        # Look the model up within the caller's company only
        lookup = {"id": target_id, "company": caller.company}
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        update_data = prepare_update_fields(call, call.data)

        # When the update names a task and reports an iteration, push the
        # iteration to the task statistics
        related_task = update_data.get("task")
        reported_iteration = update_data.get("iteration")
        if reported_iteration is not None and related_task:
            TaskBLL.update_statistics(
                task_id=related_task,
                company_id=caller.company,
                last_iteration_max=reported_iteration,
            )

        outcome = Model.safe_update(call.identity.company, model.id, update_data)
        updated_count, updated_fields = outcome
        conform_output_tags(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
def _update_model(call: APICall, company_id, model_id=None):
    """
    Apply the update fields carried by ``call`` to a model of ``company_id`` and
    refresh the cached tags of the affected project(s).

    :param call: API call whose ``data`` supplies the update fields (and the
        ``"model"`` id when ``model_id`` is falsy).
    :param company_id: Company used to scope the model lookup and all updates.
    :param model_id: Optional explicit model id; takes precedence over
        ``call.data["model"]``.
    :return: ``UpdateResponse`` with the updated count and the updated fields.
    :raises errors.bad_request.InvalidModelId: no model matches the id/company pair.
    """
    target_id = model_id if model_id else call.data["model"]

    with translate_errors_context():
        # Look the model up within the given company only
        lookup = {"id": target_id, "company": company_id}
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        update_data = prepare_update_fields(call, company_id, call.data)

        related_task = update_data.get("task")
        reported_iteration = update_data.get("iteration")
        if reported_iteration is not None and related_task:
            TaskBLL.update_statistics(
                task_id=related_task,
                company_id=company_id,
                last_iteration_max=reported_iteration,
            )

        outcome = Model.safe_update(company_id, model.id, update_data)
        updated_count, updated_fields = outcome

        if updated_count:
            previous_project = model.project
            target_project = updated_fields.get("project", model.project)
            if target_project == previous_project:
                # Same project: only the updated fields need to be reflected
                _update_cached_tags(
                    company_id, project=model.project, fields=updated_fields
                )
            else:
                # Model moved between projects: reset the cache of both
                _reset_cached_tags(
                    company_id, projects=[target_project, model.project]
                )

        conform_output_tags(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
    """
    Return a single task (public tasks allowed) as ``{"task": <dict>}``.

    :param call: API call; the result is written to ``call.result.data``.
    :param company_id: Company used for the access check.
    :param req_model: Request carrying the task id in ``req_model.task``.
    """
    payload = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, allow_public=True
    ).to_proper_dict()
    # Adjust tag fields in the outgoing dict according to the call
    conform_output_tags(call, payload)
    call.result.data = dict(task=payload)
def delete(call: APICall, company_id, req_model: DeleteRequest):
    """
    Delete a task, optionally keeping a copy in the ``<collection>__trash`` collection.

    :param call: API call; the result is written to ``call.result.data``.
    :param company_id: Company used for the (write) access check and tags reset.
    :param req_model: Request carrying ``task``, ``move_to_trash`` and ``force``.
    :raises errors.bad_request.TaskCannotBeDeleted: the task status is not
        ``created`` and ``force`` was not requested.
    """
    task = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, requires_write_access=True
    )
    keep_in_trash = req_model.move_to_trash
    forced = req_model.force

    if not (task.status == TaskStatus.created or forced):
        raise errors.bad_request.TaskCannotBeDeleted(
            "due to status, use force=True",
            task=task.id,
            expected=TaskStatus.created,
            current=task.status,
        )

    with translate_errors_context():
        cleanup_result = cleanup_task(task, forced)

        if keep_in_trash:
            live_collection = task._get_collection_name()
            trash_collection = "{}__trash".format(live_collection)
            task.switch_collection(trash_collection)
            try:
                # mongoengine caching makes a plain save() a no-op here, so an
                # insert is forced. If a document with this ID already exists
                # in the trash, the failure is deliberately ignored so the
                # deletion can proceed.
                with TimingContext("mongo", "save_task"):
                    task.save(force_insert=True)
            except Exception:
                pass
            # Point the document back at its live collection before deleting
            task.switch_collection(live_collection)

        task.delete()
        org_bll.update_org_tags(company_id, reset=True)
        call.result.data = dict(deleted=True, **attr.asdict(cleanup_result))
def add_or_update_artifacts(
    call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
    """
    Add or update the artifacts of a task and report the added/updated counts.

    :param call: API call; the response is written to ``call.result.data_model``.
    :param company_id: Company owning the task.
    :param request: Request carrying the task id and the artifacts.
    """
    outcome = TaskBLL.add_or_update_artifacts(
        task_id=request.task,
        company_id=company_id,
        artifacts=request.artifacts,
    )
    added, updated = outcome
    call.result.data_model = AddOrUpdateArtifactsResponse(
        added=added, updated=updated
    )
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
    """
    Return a single task (public tasks allowed) as ``{"task": <dict>}``.

    :param call: API call; the result is written to ``call.result.data``.
    :param company_id: Company used for the access check.
    :param req_model: Request carrying the task id in ``req_model.task``.
    """
    payload = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, allow_public=True
    ).to_proper_dict()
    # Convert the stored representation into the form returned to the caller
    unprepare_from_saved(call, payload)
    call.result.data = dict(task=payload)
def dequeue(call: APICall, company_id, req_model: UpdateRequest):
    """
    Remove a queued task from its queue and set its status back to ``created``.

    :param call: API call; the response is written to ``call.result.data_model``.
    :param company_id: Company used for the (write) access check.
    :param req_model: Request carrying ``task``, ``status_message`` and
        ``status_reason``.
    :raises errors.bad_request.InvalidTaskId: the task is not in ``queued`` status.
    """
    task = TaskBLL.get_task_with_access(
        req_model.task,
        company_id=company_id,
        only=("id", "execution", "status", "project"),
        requires_write_access=True,
    )

    if task.status != TaskStatus.queued:
        raise errors.bad_request.InvalidTaskId(
            status=task.status, expected=TaskStatus.queued
        )

    _dequeue(task, company_id)

    message = req_model.status_message
    reason = req_model.status_reason
    status_change = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        status_reason=reason,
        status_message=message,
    )
    # The queue reference is cleared as part of the status change
    change_result = status_change.execute(unset__execution__queue=1)

    res = DequeueResponse(**change_result)
    res.dequeued = 1
    call.result.data_model = res
def _update_task(self, company_id, task_id, now, iter_max=None, last_events=None):
    """
    Store aggregated results on a task after handling its event(s).

    The task receives the highest iteration seen in the last events update
    (``last_iteration_max``), the latest metric/variant scalar values
    (``last_values``) and ``now`` as its last update time.

    :return: ``False`` when there is nothing to store, otherwise whatever
        ``TaskBLL.update_statistics`` returns.
    """
    updates = {}

    if iter_max is not None:
        updates["last_iteration_max"] = iter_max

    if last_events:
        flattened = flatten_nested_items(
            last_events,
            nesting=2,
            include_leaves=["value", "metric", "variant"],
        )
        updates["last_values"] = list(flattened)

    if updates:
        return TaskBLL.update_statistics(
            task_id, company_id, last_update=now, **updates
        )

    return False
def edit(call: APICall, company_id, _):
    """
    Edit a model of ``company_id`` using the fields supplied in the call.

    Dict values targeting an embedded document are merged into the document's
    current content rather than replacing it. The response is written to
    ``call.result.data_model``.

    :raises errors.bad_request.InvalidModelId: no model matches the id/company pair.
    """
    model_id = call.data["model"]

    with translate_errors_context():
        lookup = {"id": model_id, "company": company_id}
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        fields = parse_model_fields(call, create_fields)
        fields = prepare_update_fields(call, company_id, fields)

        for key, incoming in fields.items():
            current = getattr(model, key, None)
            if not (
                current
                and isinstance(incoming, dict)
                and isinstance(current, EmbeddedDocument)
            ):
                continue
            # Overlay the incoming dict on the embedded document's current values
            merged = current.to_mongo(use_db_field=False).to_dict()
            merged.update(incoming)
            fields[key] = merged

        reported_iteration = call.data.get("iteration")
        related_task = model.task or fields.get("task")
        if reported_iteration is not None and related_task:
            TaskBLL.update_statistics(
                task_id=related_task,
                company_id=company_id,
                last_iteration_max=reported_iteration,
            )

        if not fields:
            call.result.data_model = UpdateResponse(updated=0)
        else:
            updated = model.update(upsert=False, **fields)
            if updated:
                target_project = fields.get("project", model.project)
                if target_project == model.project:
                    _update_cached_tags(
                        company_id, project=model.project, fields=fields
                    )
                else:
                    # Model moved between projects: reset the cache of both
                    _reset_cached_tags(
                        company_id, projects=[target_project, model.project]
                    )
            conform_output_tags(call, fields)
            call.result.data_model = UpdateResponse(updated=updated, fields=fields)
def publish(call: APICall, company_id, req_model: PublishRequest):
    """
    Publish a task through ``TaskBLL.publish_task`` and return its result as a
    ``PublishResponse`` in ``call.result.data_model``.

    :param company_id: Company owning the task.
    :param req_model: Request carrying ``task``, ``publish_model``, ``force``,
        ``status_reason`` and ``status_message``.
    """
    publish_result = TaskBLL.publish_task(
        task_id=req_model.task,
        company_id=company_id,
        publish_model=req_model.publish_model,
        force=req_model.force,
        status_reason=req_model.status_reason,
        status_message=req_model.status_message,
    )
    call.result.data_model = PublishResponse(**publish_result)
def set_ready(call: APICall, company, req_model: PublishModelRequest):
    """
    Mark a model as ready through ``TaskBLL.model_set_ready`` and report the outcome.

    :param call: API call; the response is written to ``call.result.data_model``.
    :param company: Company owning the model.
    :param req_model: Request carrying ``model``, ``publish_task`` and
        ``force_publish_task``.
    """
    updated, published_task_data = TaskBLL.model_set_ready(
        model_id=req_model.model,
        company_id=company,
        publish_task=req_model.publish_task,
        force_publish_task=req_model.force_publish_task,
    )

    # Task publish details are included only when the call produced any
    if published_task_data:
        published_task = ModelTaskPublishResponse(**published_task_data)
    else:
        published_task = None

    call.result.data_model = PublishModelResponse(
        updated=updated, published_task=published_task
    )
def reset(call: APICall, company_id, req_model: UpdateRequest):
    """
    Reset a task: dequeue it if possible, clean it up, clear its outputs and
    non-input artifacts, and set its status back to ``created``.

    :param call: API call; the response is written to ``call.result.data_model``.
    :param company_id: Company used for the (write) access check.
    :param req_model: Request carrying ``task`` and ``force``.
    :raises errors.bad_request.InvalidTaskStatus: the task is published and
        ``force`` was not requested.
    """
    task = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, requires_write_access=True
    )
    force = req_model.force

    if task.status == TaskStatus.published and not force:
        raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)

    api_results = {}
    updates = {}

    try:
        dequeued = _dequeue(task, company_id, silent_fail=True)
    except APIError:
        # dequeue may fail if the task was not enqueued
        pass
    else:
        if dequeued:
            api_results["dequeued"] = dequeued
        updates["unset__execution__queue"] = 1

    cleanup_result = cleanup_task(task, force)
    api_results.update(attr.asdict(cleanup_result))

    updates.update(
        {
            "set__last_iteration": DEFAULT_LAST_ITERATION,
            "set__last_metrics": {},
            "unset__output__result": 1,
            "unset__output__model": 1,
            # Drop every artifact whose mode is not "input"
            "__raw__": {
                "$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}
            },
        }
    )

    status_change = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        force=force,
        status_reason="reset",
        status_message="reset",
    )
    res = ResetResponse(
        **status_change.execute(
            started=None, completed=None, published=None, **updates
        )
    )

    # Expose the dequeue/cleanup details on the response
    for key, value in api_results.items():
        setattr(res, key, value)

    call.result.data_model = res
def edit(call):
    """
    Edit a model of the caller's company using the fields supplied in the call.

    Dict values targeting an embedded document are merged into the document's
    current content rather than replacing it. The response is written to
    ``call.result.data_model``.

    :raises errors.bad_request.InvalidModelId: no model matches the id/company pair.
    """
    assert isinstance(call, APICall)
    caller = call.identity
    model_id = call.data["model"]

    with translate_errors_context():
        lookup = {"id": model_id, "company": caller.company}
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        fields = parse_model_fields(call, create_fields)
        fields = prepare_update_fields(call, fields)

        for key, incoming in fields.items():
            current = getattr(model, key, None)
            if not (
                current
                and isinstance(incoming, dict)
                and isinstance(current, EmbeddedDocument)
            ):
                continue
            # Overlay the incoming dict on the embedded document's current values
            merged = current.to_mongo(use_db_field=False).to_dict()
            merged.update(incoming)
            fields[key] = merged

        reported_iteration = call.data.get("iteration")
        related_task = model.task or fields.get("task")
        if reported_iteration is not None and related_task:
            TaskBLL.update_statistics(
                task_id=related_task,
                company_id=caller.company,
                last_iteration_max=reported_iteration,
            )

        if not fields:
            call.result.data_model = UpdateResponse(updated=0)
        else:
            updated = model.update(upsert=False, **fields)
            call.result.data_model = UpdateResponse(updated=updated, fields=fields)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
    """
    Return one page of aggregated execution parameters, optionally limited to
    a single project.

    :param call: API call; ``total``, ``remaining`` and ``parameters`` are
        written to ``call.result.data``.
    :param company_id: Company to aggregate for.
    :param request: Request carrying ``project``, ``page`` and ``page_size``.
    """
    # A falsy project means "no project filter"
    if request.project:
        project_ids = [request.project]
    else:
        project_ids = None

    aggregated = TaskBLL.get_aggregated_project_execution_parameters(
        company_id,
        project_ids=project_ids,
        page=request.page,
        page_size=request.page_size,
    )
    total, remaining, parameters = aggregated

    call.result.data = dict(total=total, remaining=remaining, parameters=parameters)
def stop(call: APICall, company_id, req_model: UpdateRequest):
    """
    stop
    :summary: Stop a running task. Requires task status 'in_progress' and
              execution_progress 'running', or force=True.
              A development task is stopped immediately.
              A non-development task only has its status message set to 'stopping'.
    """
    stop_result = TaskBLL.stop_task(
        task_id=req_model.task,
        company_id=company_id,
        user_name=call.identity.user_name,
        status_reason=req_model.status_reason,
        force=req_model.force,
    )
    call.result.data_model = UpdateResponse(**stop_result)
def _update_task(self, company_id, task_id, now, iter=None, last_events=None, last_metrics=None):
    """
    Update task information in DB with aggregated results after handling event(s) related to this task.

    This updates the task with the highest iteration value encountered during the last events update, as well
    as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
    update time.

    :param company_id: Company owning the task.
    :param task_id: Task to update.
    :param now: Value stored as the task's last update time.
    :param iter: Iteration to store as ``last_iteration``; skipped when ``None``.
        (The name shadows the builtin but is kept for keyword-call compatibility.)
    :param last_events: Mapping of metric hash -> variant hash -> event dict
        holding the latest reported events; skipped when empty.
    :param last_metrics: Previously stored last metrics to start from. May be
        ``None`` or empty, in which case only ``last_events`` are used.
    :return: ``False`` when there is nothing to store, otherwise whatever
        ``TaskBLL.update_statistics`` returns.
    """
    fields = {}

    if iter is not None:
        fields["last_iteration"] = iter

    if last_events:

        def get_metric_event(ev):
            me = MetricEvent.from_dict(**ev)
            if "timestamp" in ev:
                # The event timestamp is divided by 1000 before conversion,
                # i.e. it is treated as milliseconds
                me.timestamp = datetime.utcfromtimestamp(ev["timestamp"] / 1000)
            return me

        new_last_metrics = nested_dict(2, MetricEvent)
        # last_metrics defaults to None; merging None would raise a TypeError,
        # so only merge when there is something to merge
        if last_metrics:
            new_last_metrics.update(last_metrics)

        for metric_hash, variants in last_events.items():
            for variant_hash, event in variants.items():
                new_last_metrics[metric_hash][variant_hash] = get_metric_event(event)

        fields["last_metrics"] = new_last_metrics.to_dict()

    if not fields:
        return False

    return TaskBLL.update_statistics(task_id, company_id, last_update=now, **fields)
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
    """
    Store the script requirements of a task and stamp its last update time.

    :param call: API call; the response is written to ``call.result.data_model``.
    :param company_id: Company used for the (write) access check.
    :param req_model: Request carrying ``task`` and ``requirements``.
    :raises errors.bad_request.MissingTaskFields: the task has no script field.
    """
    requirements = req_model.requirements

    with translate_errors_context():
        task = TaskBLL.get_task_with_access(
            req_model.task,
            company_id=company_id,
            only=("status", "script"),
            requires_write_access=True,
        )
        if not task.script:
            raise errors.bad_request.MissingTaskFields(
                "Task has no script field", task=task.id
            )

        update_result = task.update(
            script__requirements=requirements, last_update=datetime.utcnow()
        )
        call.result.data_model = UpdateResponse(updated=update_result)
        # The changed field is reported only when the update took effect
        if update_result:
            call.result.data_model.fields = {"script.requirements": requirements}
def set_task_status_from_call(request: UpdateRequest, company_id, new_status=None, **kwargs) -> dict:
    """
    Change the status of the task named in ``request`` and return the result
    of the status change.

    :param request: Request carrying ``task``, ``status_reason``,
        ``status_message`` and ``force``.
    :param company_id: Company used for the (write) access check.
    :param new_status: Status to set; when falsy the task's current status is used.
    :param kwargs: Passed through to ``ChangeStatusRequest.execute``.
    """
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=("status", "project"),
        requires_write_access=True,
    )

    reason = request.status_reason
    message = request.status_message
    forced = request.force

    status_change = ChangeStatusRequest(
        task=task,
        new_status=new_status or task.status,
        status_reason=reason,
        status_message=message,
        force=forced,
    )
    return status_change.execute(**kwargs)
def set_task_status_from_call(
    request: UpdateRequest, company_id, new_status=None, **set_fields
) -> dict:
    """
    Change the status of the task named in ``request`` and return the result
    of the status change.

    :param request: Request carrying ``task``, ``status_reason``,
        ``status_message`` and ``force``.
    :param company_id: Company used for the (write) access check.
    :param new_status: Status to set; when falsy the task's current status is used.
    :param set_fields: Extra fields to set; they are resolved against the
        loaded task by ``SetFieldsResolver`` before execution.
    """
    fields_resolver = SetFieldsResolver(set_fields)

    # Load only the fields the status change and the resolver need
    needed_fields = {"status", "project"} | fields_resolver.get_names()
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=tuple(needed_fields),
        requires_write_access=True,
    )

    reason = request.status_reason
    message = request.status_message
    forced = request.force

    status_change = ChangeStatusRequest(
        task=task,
        new_status=new_status or task.status,
        status_reason=reason,
        status_message=message,
        force=forced,
    )
    return status_change.execute(**fields_resolver.get_fields(task))
def reset(call: APICall, company_id, req_model: UpdateRequest):
    """
    Reset a task: clean it up, clear its script requirements and outputs, and
    set its status back to ``created``.

    :param call: API call; the response is written to ``call.result.data_model``.
    :param company_id: Company used for the (write) access check.
    :param req_model: Request carrying ``task`` and ``force``.
    :raises errors.bad_request.InvalidTaskStatus: the task is published and
        ``force`` was not requested.
    """
    task = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, requires_write_access=True
    )
    force = req_model.force

    if task.status == TaskStatus.published and not force:
        raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)

    cleanup_result = cleanup_task(task, force)
    api_results = dict(attr.asdict(cleanup_result))

    updates = {
        "unset__script__requirements": 1,
        "set__last_iteration": DEFAULT_LAST_ITERATION,
        "set__last_metrics": {},
        "unset__output__result": 1,
        "unset__output__model": 1,
    }

    status_change = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        force=force,
        status_reason="reset",
        status_message="reset",
    )
    res = ResetResponse(
        **status_change.execute(
            started=None, completed=None, published=None, **updates
        )
    )

    # Expose the cleanup details on the response
    for key, value in api_results.items():
        setattr(res, key, value)

    call.result.data_model = res
def set_task_status_from_call(request: UpdateRequest, company_id, new_status=None, **set_fields) -> dict:
    """
    Change the status of the task named in ``request`` and return the result
    of the status change, maintaining the task's ``duration`` along the way.

    Unless the caller explicitly supplies a ``duration`` field, a task moving
    to ``completed``, ``failed`` or ``stopped`` gets its duration set to the
    number of whole seconds elapsed since ``task.started`` (0 when the task
    has no start time).

    :param request: Request carrying ``task``, ``status_reason``,
        ``status_message`` and ``force``.
    :param company_id: Company used for the (write) access check.
    :param new_status: Status to set; when falsy the task's current status is used.
    :param set_fields: Extra fields to set; they are resolved against the
        loaded task by ``SetFieldsResolver`` before execution.
    """
    fields_resolver = SetFieldsResolver(set_fields)
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=tuple(
            {"status", "project", "started", "duration"}
            | fields_resolver.get_names()
        ),
        requires_write_access=True,
    )

    if "duration" not in fields_resolver.get_names():
        # NOTE(review): this compares a status value against the `Task.started`
        # attribute rather than a TaskStatus member, so the branch looks
        # unreachable - confirm the intended status before changing it.
        if new_status == Task.started:
            fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
        elif new_status in (
            TaskStatus.completed,
            TaskStatus.failed,
            TaskStatus.stopped,
        ):
            # Elapsed time is "now - started"; the reverse order yields a
            # negative duration. A task that never started has no elapsed time.
            started = task.started
            elapsed = (
                int((datetime.utcnow() - started).total_seconds()) if started else 0
            )
            fields_resolver.add_fields(duration=elapsed)

    status_reason = request.status_reason
    status_message = request.status_message
    force = request.force
    return ChangeStatusRequest(
        task=task,
        new_status=new_status or task.status,
        status_reason=status_reason,
        status_message=status_message,
        force=force,
    ).execute(**fields_resolver.get_fields(task))
from boltons import iterutils from apierrors import errors from apimodels.tasks import ( HyperParamKey, HyperParamItem, ReplaceHyperparams, Configuration, ) from bll.task import TaskBLL from config import config from database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus from utilities.parameter_key_escaper import ParameterKeyEscaper log = config.logger(__file__) task_bll = TaskBLL() class HyperParams: _properties_section = "properties" @classmethod def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]: only = ("id", "hyperparams") tasks = task_bll.assert_exists( company_id=company_id, task_ids=task_ids, only=only, allow_public=True, )
def add_events(self, company_id, events, worker, allow_locked_tasks=False):
    """
    Normalize a batch of events, index them in bulk and update the statistics
    of the tasks they reference.

    Each event dict is normalized in place (type, timestamps, worker, iter,
    value) before being indexed.

    :param company_id: Company owning the events; selects the target index and
        scopes the task validation.
    :param events: Iterable of event dicts.
    :param worker: Worker stored on events that do not carry one.
    :param allow_locked_tasks: When false, tasks whose status is in
        ``LOCKED_TASK_STATUSES`` are rejected.
    :return: ``(added, errors_in_bulk)`` - the number of successfully indexed
        events and the error infos reported for the ones that failed.
    :raises errors.BadRequest: an event has no type or an unknown type.
    :raises errors.bad_request.InvalidTaskId: a referenced task is missing
        (or locked, unless ``allow_locked_tasks``).
    """
    actions = []
    task_ids = set()
    task_iteration = defaultdict(lambda: 0)
    task_last_events = nested_dict(
        3, dict
    )  # task_id -> metric_hash -> variant_hash -> MetricEvent

    for event in events:
        if "type" not in event:
            raise errors.BadRequest("Event must have a 'type' field", event=event)

        # remove spaces from event type
        event_type = event["type"].replace(" ", "_")
        if event_type not in EVENT_TYPES:
            raise errors.BadRequest(
                "Invalid event type {}".format(event_type),
                event=event,
                types=EVENT_TYPES,
            )

        event["type"] = event_type

        # @timestamp indicates the time the event is written, not when it happened
        event["@timestamp"] = es_factory.get_es_timestamp_str()

        # for backward compatibility: "ts" is the old name of "timestamp"
        if "ts" in event:
            event["timestamp"] = event.pop("ts")

        # set timestamp and worker if not sent
        if "timestamp" not in event:
            event["timestamp"] = es_factory.get_timestamp_millis()

        if "worker" not in event:
            event["worker"] = worker

        # force iter to be an int
        iteration = event.get("iter")
        if iteration is not None:
            iteration = int(iteration)
            event["iter"] = iteration

        # "values" used to indicate an array; it is now stored under "value"
        if "values" in event:
            event["value"] = event["values"]
            del event["values"]

        index_name = EventMetrics.get_index_name(company_id, event_type)
        es_action = {
            "_op_type": "index",  # overwrite if exists with same ID
            "_index": index_name,
            "_type": "event",
            "_source": event,
        }

        # "log" events always get a fresh id (whatever is sent is written, never
        # overwritten); other events get an id derived from the event itself
        if event_type != "log":
            es_action["_id"] = self._get_event_id(event)
        else:
            es_action["_id"] = dbutils.id()

        task_id = event.get("task")
        if task_id is not None:
            es_action["_routing"] = task_id
            task_ids.add(task_id)
            if (
                iteration is not None
                and event.get("metric") not in self._skip_iteration_for_metric
            ):
                task_iteration[task_id] = max(iteration, task_iteration[task_id])

            if event_type == EventType.metrics_scalar.value:
                self._update_last_metric_event_for_task(
                    task_last_events=task_last_events, task_id=task_id, event=event
                )
        else:
            es_action["_routing"] = task_id

        actions.append(es_action)

    if task_ids:
        # verify task_ids
        with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
            extra_msg = None
            query = Q(id__in=task_ids, company=company_id)
            if not allow_locked_tasks:
                query &= Q(status__nin=LOCKED_TASK_STATUSES)
                extra_msg = "or task published"
            res = Task.objects(query).only("id")
            if len(res) < len(task_ids):
                invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
                raise errors.bad_request.InvalidTaskId(
                    extra_msg, company=company_id, ids=invalid_task_ids
                )

    errors_in_bulk = []
    added = 0
    chunk_size = 500
    with translate_errors_context(), TimingContext("es", "events_add_batch"):
        # TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
        with closing(
            helpers.streaming_bulk(
                self.es,
                actions,
                chunk_size=chunk_size,
                # thread_count=8,
                refresh=True,
            )
        ) as it:
            # streaming_bulk yields one (success, info) pair per action, not
            # per chunk, so each success counts for exactly one event
            for success, info in it:
                if success:
                    added += 1
                else:
                    errors_in_bulk.append(info)

        remaining_tasks = set()
        now = datetime.utcnow()
        for task_id in task_ids:
            # Update related tasks. For reasons of performance, we prefer to
            # update all of them and not only those whose events were successful
            updated = self._update_task(
                company_id=company_id,
                task_id=task_id,
                now=now,
                iter_max=task_iteration.get(task_id),
                last_events=task_last_events.get(task_id),
            )

            if not updated:
                remaining_tasks.add(task_id)

        # Tasks that had nothing to store still get their last update time set
        if remaining_tasks:
            TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)

    return added, errors_in_bulk
def ping(_, company_id, request: PingRequest):
    """
    Set the last update time of the task named in ``request`` to the current
    UTC time.

    :param company_id: Company owning the task.
    :param request: Request carrying the task id in ``request.task``.
    """
    pinged_tasks = [request.task]
    TaskBLL.set_last_update(
        task_ids=pinged_tasks,
        company_id=company_id,
        last_update=datetime.utcnow(),
    )
def add_events(self, company_id, events, worker, allow_locked_tasks=False) -> Tuple[int, int, dict]:
    """
    Normalize a batch of events, index the valid ones in bulk and update the
    statistics of the tasks they reference.

    Invalid events (no type, unknown type, no task, invalid task) are skipped
    and counted per error message instead of failing the whole batch. Valid
    events are normalized in place before being indexed.

    :param company_id: Company owning the events; selects the target index and
        scopes the task validation.
    :param events: Collection of event dicts (iterated twice: once to collect
        task ids, once to process the events).
    :param worker: Worker stored on events that do not carry one.
    :param allow_locked_tasks: Passed to the task validation.
    :return: ``(added, errors_count, errors_per_type)`` - the number of
        successfully indexed events, the total number of errors and a mapping
        of error message -> count.
    :raises errors.bad_request.EventsNotAdded: no event was added.
    """
    actions = []
    task_ids = set()
    task_iteration = defaultdict(lambda: 0)
    task_last_scalar_events = nested_dict(
        3, dict
    )  # task_id -> metric_hash -> variant_hash -> MetricEvent
    task_last_events = nested_dict(
        3, dict
    )  # task_id -> metric_hash -> event_type -> MetricEvent
    errors_per_type = defaultdict(int)
    valid_tasks = self._get_valid_tasks(
        company_id,
        task_ids={
            event["task"] for event in events if event.get("task") is not None
        },
        allow_locked_tasks=allow_locked_tasks,
    )

    for event in events:
        event_type = event.get("type")
        if event_type is None:
            errors_per_type["Event must have a 'type' field"] += 1
            continue

        # remove spaces from event type
        event_type = event_type.replace(" ", "_")
        if event_type not in EVENT_TYPES:
            errors_per_type[f"Invalid event type {event_type}"] += 1
            continue

        task_id = event.get("task")
        if task_id is None:
            errors_per_type["Event must have a 'task' field"] += 1
            continue

        if task_id not in valid_tasks:
            errors_per_type["Invalid task id"] += 1
            continue

        event["type"] = event_type

        # @timestamp indicates the time the event is written, not when it happened
        event["@timestamp"] = es_factory.get_es_timestamp_str()

        # for backward compatibility: "ts" is the old name of "timestamp"
        if "ts" in event:
            event["timestamp"] = event.pop("ts")

        # set timestamp and worker if not sent
        if "timestamp" not in event:
            event["timestamp"] = es_factory.get_timestamp_millis()

        if "worker" not in event:
            event["worker"] = worker

        # force iter to be an int
        iteration = event.get("iter")
        if iteration is not None:
            iteration = int(iteration)
            event["iter"] = iteration

        # "values" used to indicate an array; it is now stored under "value"
        if "values" in event:
            event["value"] = event["values"]
            del event["values"]

        # metric and variant are always stored, as empty strings when missing
        event["metric"] = event.get("metric") or ""
        event["variant"] = event.get("variant") or ""

        index_name = EventMetrics.get_index_name(company_id, event_type)
        es_action = {
            "_op_type": "index",  # overwrite if exists with same ID
            "_index": index_name,
            "_type": "event",
            "_source": event,
        }

        # "log" events always get a fresh id (whatever is sent is written, never
        # overwritten); other events get an id derived from the event itself
        if event_type != "log":
            es_action["_id"] = self._get_event_id(event)
        else:
            es_action["_id"] = dbutils.id()

        es_action["_routing"] = task_id
        task_ids.add(task_id)
        if (
            iteration is not None
            and event.get("metric") not in self._skip_iteration_for_metric
        ):
            task_iteration[task_id] = max(iteration, task_iteration[task_id])

        self._update_last_metric_events_for_task(
            last_events=task_last_events[task_id], event=event,
        )
        if event_type == EventType.metrics_scalar.value:
            self._update_last_scalar_events_for_task(
                last_events=task_last_scalar_events[task_id], event=event
            )

        actions.append(es_action)

    added = 0
    if actions:
        chunk_size = 500
        with translate_errors_context(), TimingContext("es", "events_add_batch"):
            # TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
            with closing(
                helpers.streaming_bulk(
                    self.es,
                    actions,
                    chunk_size=chunk_size,
                    # thread_count=8,
                    refresh=True,
                )
            ) as it:
                # streaming_bulk yields one (success, info) pair per action,
                # not per chunk, so each success counts for exactly one event
                for success, info in it:
                    if success:
                        added += 1
                    else:
                        errors_per_type["Error when indexing events batch"] += 1

            remaining_tasks = set()
            now = datetime.utcnow()
            for task_id in task_ids:
                # Update related tasks. For reasons of performance, we prefer to
                # update all of them and not only those whose events were successful
                updated = self._update_task(
                    company_id=company_id,
                    task_id=task_id,
                    now=now,
                    iter_max=task_iteration.get(task_id),
                    last_scalar_events=task_last_scalar_events.get(task_id),
                    last_events=task_last_events.get(task_id),
                )

                if not updated:
                    remaining_tasks.add(task_id)

            # Tasks that had nothing to store still get their last update time set
            if remaining_tasks:
                TaskBLL.set_last_update(
                    remaining_tasks, company_id, last_update=now
                )

    if not added:
        raise errors.bad_request.EventsNotAdded(**errors_per_type)

    errors_count = sum(errors_per_type.values())
    return added, errors_count, errors_per_type
from service_repo import APICall, endpoint from services.utils import conform_tag_fields, conform_output_tags from timing_context import TimingContext from utilities import safe_get task_fields = set(Task.get_fields()) task_script_fields = set(get_fields(Script)) get_all_query_options = Task.QueryParameterOptions( list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"), datetime_fields=("status_changed", ), pattern_fields=("name", "comment"), fields=("parent", ), ) task_bll = TaskBLL() event_bll = EventBLL() queue_bll = QueueBLL() TaskBLL.start_non_responsive_tasks_watchdog() def set_task_status_from_call(request: UpdateRequest, company_id, new_status=None, **set_fields) -> dict: fields_resolver = SetFieldsResolver(set_fields) task = TaskBLL.get_task_with_access( request.task, company_id=company_id, only=tuple({"status", "project"} | fields_resolver.get_names()),
def update_for_task(call: APICall, company_id, _):
    """
    Set the output model of a task: either an existing model
    (``override_model_id``) or a model described by ``uri``, which is updated
    when the task already has an output model and created otherwise.

    The result is written to ``call.result.data`` as the model id plus a
    ``created`` flag.

    :raises errors.bad_request.MissingRequiredFields: not exactly one of
        ``uri`` / ``override_model_id`` was provided.
    :raises errors.bad_request.InvalidTaskId: the task was not found.
    :raises errors.bad_request.InvalidTaskStatus: the task is neither
        ``created`` nor ``in_progress``.
    :raises errors.bad_request.InvalidModelId: the override model was not found.
    """
    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")

    neither_given = not (uri or override_model_id)
    both_given = uri and override_model_id
    if neither_given or both_given:
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required", fields=("uri", "override_model_id")
        )

    with translate_errors_context():
        query = {"id": task_id, "company": company_id}
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["output", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            # Use an existing model as the task output
            query = {"company": company_id, "id": override_model_id}
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
        else:
            # use task name if name not provided
            if "name" not in call.data:
                call.data["name"] = task.name

            if "comment" not in call.data:
                call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

            if task.output and task.output.model:
                # model exists, update
                existing_id = task.output.model
                res = _update_model(
                    call, company_id, model_id=existing_id
                ).to_struct()
                res.update({"id": task.output.model, "created": False})
                call.result.data = res
                return

            # new model, create
            fields = parse_model_fields(call, create_fields)

            # create and save model
            model = Model(
                id=database.utils.id(),
                created=datetime.utcnow(),
                user=call.identity.user,
                company=company_id,
                project=task.project,
                framework=task.execution.framework,
                parent=task.execution.model,
                design=task.execution.model_desc,
                labels=task.execution.model_labels,
                ready=(task.status == TaskStatus.published),
                **fields,
            )
            model.save()
            _update_cached_tags(company_id, project=model.project, fields=fields)

        # Point the task at the model and record the reported iteration
        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            output__model=model.id,
        )

        call.result.data = {"id": model.id, "created": True}