def update(call: APICall, company_id, req_model: UpdateRequest):
    """
    Apply a partial update to a single task and report what was changed.

    :param call: the API call; ``call.data`` supplies the raw update fields
    :param company_id: company used for the task lookup and the update
    :param req_model: request model; only ``req_model.task`` (the task ID) is read here
    :return: ``UpdateResponse`` carrying the updated count and the updated fields,
        or ``UpdateResponse(updated=0)`` when there is nothing to update
    :except errors.bad_request.InvalidTaskId: if the task lookup returns nothing
    """
    task_id = req_model.task

    with translate_errors_context():
        task = Task.get_for_writing(
            id=task_id, company=company_id, _only=["id", "project"]
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        # the second returned value (the valid fields) is not used by this handler
        partial_update_dict, _valid_fields = prepare_update_fields(
            call, task, call.data
        )
        if not partial_update_dict:
            return UpdateResponse(updated=0)

        updated_count, updated_fields = Task.safe_update(
            company_id=company_id,
            id=task_id,
            partial_update_dict=partial_update_dict,
            injected_update=dict(last_change=datetime.utcnow()),
        )

        if updated_count:
            previous_project = task.project
            new_project = updated_fields.get("project", previous_project)
            if new_project == previous_project:
                # same project: update its cached tags from the updated fields
                _update_cached_tags(
                    company_id, project=previous_project, fields=updated_fields
                )
            else:
                # task moved between projects: reset the cached tags of both
                _reset_cached_tags(
                    company_id, projects=[new_project, previous_project]
                )
            update_project_time(updated_fields.get("project"))

        unprepare_from_saved(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
    """
    Move a task to the 'queued' status and add it to a queue.

    When the request names no queue, the company's default queue is used. If adding
    the task to the queue fails, the task status is forced back to the status it had
    when it was loaded and the original exception is re-raised.

    :param call: the API call; the response is stored in ``call.result.data_model``
    :param company_id: company used for the task and queue lookups
    :param req_model: request model supplying task, queue, status_message and status_reason
    :except errors.bad_request.InvalidTaskId: if the task lookup returns nothing
    """
    task_id = req_model.task
    status_message = req_model.status_message
    status_reason = req_model.status_reason
    # fall back to the default queue when none was requested
    queue_id = req_model.queue or queue_bll.get_default(company_id).id

    with translate_errors_context():
        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            _only=("type", "script", "execution", "status", "project", "id"), **query
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        status_change = ChangeStatusRequest(
            task=task,
            new_status=TaskStatus.queued,
            status_reason=status_reason,
            status_message=status_message,
            allow_same_state_transition=False,
        ).execute()
        res = EnqueueResponse(**status_change)

        try:
            queue_bll.add_task(
                company_id=company_id, queue_id=queue_id, task_id=task.id
            )
        except Exception:
            # failed enqueueing, revert to previous state
            # (task.status still holds the value that was loaded above)
            ChangeStatusRequest(
                task=task,
                current_status_override=TaskStatus.queued,
                new_status=task.status,
                force=True,
                status_reason="failed enqueueing",
            ).execute()
            raise

        # set the current queue ID in the task: update the nested field when the
        # loaded task has an execution value, otherwise write a new Execution
        if task.execution:
            queue_update = {"execution__queue": queue_id}
        else:
            queue_update = {"execution": Execution(queue=queue_id)}
        Task.objects(**query).update(multi=False, **queue_update)

        res.queued = 1
        res.fields.update(**{"execution.queue": queue_id})

        call.result.data_model = res
def enqueue_task(
    task_id: str,
    company_id: str,
    queue_id: str,
    status_message: str,
    status_reason: str,
    validate: bool = False,
) -> Tuple[int, dict]:
    """
    Move a task to the 'queued' status and add it to a queue.

    If adding the task to the queue fails, the task status is forced back to the
    status it had when it was loaded and the original exception is re-raised.

    :param task_id: ID of the task to enqueue
    :param company_id: company used for the task and queue lookups
    :param queue_id: target queue; when falsy the company's default queue is used
    :param status_message: passed on to the status change request
    :param status_reason: passed on to the status change request
    :param validate: when set, ``TaskBLL.validate`` is called on the task first
    :return: ``(1, res)`` where ``res`` is the status change result with the queue ID
        stored through ``nested_set`` under ``("fields", "execution.queue")``
    :except errors.bad_request.InvalidTaskId: if the task lookup returns nothing
    """
    # fall back to the default queue when none was passed
    queue_id = queue_id or queue_bll.get_default(company_id).id

    query = dict(id=task_id, company=company_id)
    task = Task.get_for_writing(**query)
    if not task:
        raise errors.bad_request.InvalidTaskId(**query)

    if validate:
        TaskBLL.validate(task)

    status_request = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.queued,
        status_reason=status_reason,
        status_message=status_message,
        allow_same_state_transition=False,
    )
    # the status held before queueing is passed along as enqueue_status
    res = status_request.execute(enqueue_status=task.status)

    try:
        queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
    except Exception:
        # failed enqueueing, revert to previous state
        revert_request = ChangeStatusRequest(
            task=task,
            current_status_override=TaskStatus.queued,
            new_status=task.status,
            force=True,
            status_reason="failed enqueueing",
        )
        revert_request.execute(enqueue_status=None)
        raise

    # set the current queue ID in the task: update the nested field when the
    # loaded task has an execution value, otherwise write a new Execution
    if task.execution:
        queue_update = {"execution__queue": queue_id}
    else:
        queue_update = {"execution": Execution(queue=queue_id)}
    Task.objects(**query).update(multi=False, **queue_update)

    nested_set(res, ("fields", "execution.queue"), queue_id)
    return 1, res
def dequeue_task(
    task_id: str, company_id: str, status_message: str, status_reason: str,
) -> Tuple[int, dict]:
    """
    Load a task for writing and hand it to ``TaskBLL.dequeue_and_change_status``.

    :param task_id: ID of the task to dequeue
    :param company_id: company used for the task lookup
    :param status_message: passed on to ``TaskBLL.dequeue_and_change_status``
    :param status_reason: passed on to ``TaskBLL.dequeue_and_change_status``
    :return: ``(1, res)`` where ``res`` is whatever
        ``TaskBLL.dequeue_and_change_status`` returned
    :except errors.bad_request.InvalidTaskId: if the task lookup returns nothing
    """
    lookup = {"id": task_id, "company": company_id}
    task = Task.get_for_writing(**lookup)
    if not task:
        raise errors.bad_request.InvalidTaskId(**lookup)

    result = TaskBLL.dequeue_and_change_status(
        task,
        company_id,
        status_message=status_message,
        status_reason=status_reason,
    )
    return 1, result
def get_task_for_update(
    company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:
    """
    Loads the task (id and status only) for writing and returns it if its status
    permits an update.

    Any status is accepted when allow_all_statuses is set. Otherwise the status must
    be TaskStatus.created or, when force is set, TaskStatus.created or
    TaskStatus.in_progress.

    :except errors.bad_request.InvalidTaskId: if the task lookup returns nothing
    :except errors.bad_request.InvalidTaskStatus: if the task status is not allowed
    """
    task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)

    if allow_all_statuses:
        return task

    allowed_statuses = [TaskStatus.created]
    if force:
        allowed_statuses.append(TaskStatus.in_progress)

    if task.status in allowed_statuses:
        return task

    # the reported expected status is always 'created', even when force also
    # allows 'in_progress'
    raise errors.bad_request.InvalidTaskStatus(
        expected=TaskStatus.created, status=task.status
    )
def get_task_with_access(
    task_id, company_id, only=None, allow_public=False, requires_write_access=False
) -> Task:
    """
    Gets a task, loading it for writing when write access is required

    :param task_id: ID of the task to load
    :param company_id: company used for the lookup
    :param only: passed to the lookup as ``_only``
    :param allow_public: passed as ``include_public`` to the read-only lookup;
        not used when write access is required
    :param requires_write_access: selects ``Task.get_for_writing`` over ``Task.get``
    :except errors.bad_request.InvalidTaskId: if the task is not found
    :except errors.forbidden.NoWritePermission: per the original documentation, if
        write access was required and the task cannot be modified (presumably raised
        by Task.get_for_writing - not visible in this function)
    """
    with translate_errors_context():
        query = dict(id=task_id, company=company_id)

        with TimingContext("mongo", "task_with_access"):
            if not requires_write_access:
                task = Task.get(_only=only, **query, include_public=allow_public)
            else:
                task = Task.get_for_writing(_only=only, **query)

        if task:
            return task

        raise errors.bad_request.InvalidTaskId(**query)
def edit(call: APICall, company_id, req_model: UpdateRequest):
    """
    Edit a task using the fields parsed from the call and store the outcome in
    ``call.result.data_model``.

    Unless ``req_model.force`` is set, only tasks in the 'created' status can be
    edited. Dictionary values aimed at embedded-document fields are merged on top of
    the task's current value for that field. The resulting field set is validated by
    building a task from it (``task_bll.create``) and validating that task.

    :param call: the API call; source of the fields and target of the response
    :param company_id: company used for the task lookup and tags cache updates
    :param req_model: request model; ``task`` and ``force`` are read here
    :except errors.bad_request.InvalidTaskId: if the task lookup returns nothing
    :except errors.bad_request.InvalidTaskStatus: if the task is not in the 'created'
        status and force was not requested
    """
    task_id = req_model.task
    force = req_model.force

    with translate_errors_context():
        task = Task.get_for_writing(id=task_id, company=company_id)
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if task.status != TaskStatus.created and not force:
            raise errors.bad_request.InvalidTaskStatus(
                expected=TaskStatus.created, status=task.status
            )

        # editing accepts the creation fields plus "status"
        edit_fields = create_fields.copy()
        edit_fields["status"] = None

        with translate_errors_context(
            field_does_not_exist_cls=errors.bad_request.ValidationError
        ):
            with TimingContext("code", "parse_and_validate"):
                fields = prepare_create_fields(
                    call,
                    valid_fields=edit_fields,
                    output=task.output,
                    previous_task=task,
                )

        # for embedded documents, overlay the passed dictionary on the current value
        for key, value in list(fields.items()):
            current = getattr(task, key, None)
            if not (current and isinstance(value, dict)):
                continue
            if not isinstance(current, EmbeddedDocument):
                continue
            merged = current.to_mongo(use_db_field=False).to_dict()
            merged.update(value)
            fields[key] = merged

        task_bll.validate(task_bll.create(call, fields))

        # make sure field names do not end in mongoengine comparison operators
        fixed_fields = {
            ("%s__" % name if name in COMPARISON_OPERATORS else name): value
            for name, value in fields.items()
        }

        if not fixed_fields:
            call.result.data_model = UpdateResponse(updated=0)
            return

        now = datetime.utcnow()
        timestamps = {"last_change": now}
        # last_update is stamped only when the edit includes fields outside of
        # Task.user_set_allowed()
        if not set(fields).issubset(Task.user_set_allowed()):
            timestamps["last_update"] = now
        fields.update(timestamps)
        fixed_fields.update(timestamps)

        updated = task.update(upsert=False, **fixed_fields)
        if updated:
            previous_project = task.project
            new_project = fixed_fields.get("project", previous_project)
            if new_project == previous_project:
                _update_cached_tags(
                    company_id, project=previous_project, fields=fixed_fields
                )
            else:
                # task moved between projects: reset the cached tags of both
                _reset_cached_tags(
                    company_id, projects=[new_project, previous_project]
                )
            update_project_time(fields.get("project"))

        unprepare_from_saved(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)
def validate_task(company_id, fields: dict):
    """
    Look up, for writing, the task referenced by ``fields["task"]``.

    The lookup result is discarded, so the only failures come from the lookup itself
    (or from a missing "task" key).

    NOTE(review): a falsy lookup result is not rejected here, unlike other callers of
    Task.get_for_writing that raise InvalidTaskId - confirm this is intended.

    :param company_id: company used for the task lookup
    :param fields: dictionary holding the task ID under the "task" key
    """
    task_id = fields["task"]
    Task.get_for_writing(company=company_id, id=task_id, _only=["id"])
def update_for_task(call: APICall, company_id, _):
    """
    Create or update the output model of a task and register it on the task.

    Exactly one of "uri" and "override_model_id" must be passed in ``call.data``:
      - with "override_model_id" the referenced model is looked up and used as is
      - with "uri" the task's latest output model is updated if the task has one
        (the call then returns right after that update), otherwise a new model is
        created and saved

    For an overriding or a newly created model the task statistics are updated with
    the model as its output and the result reports ``{"id": ..., "created": True}``.

    :param call: the API call; source of the request data and target of the result
    :param company_id: company used for the task and model lookups
    :except errors.moved_permanently.NotSupported: if the requested endpoint version
        is above ModelsBackwardsCompatibility.max_version
    :except errors.bad_request.MissingRequiredFields: unless exactly one of "uri" and
        "override_model_id" was passed
    :except errors.bad_request.InvalidTaskId: if the task lookup returns nothing
    :except errors.bad_request.InvalidTaskStatus: if the task status is neither
        'created' nor 'in_progress'
    """
    if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
        raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")

    data = call.data
    task_id = data["task"]
    uri = data.get("uri")
    iteration = data.get("iteration")
    override_model_id = data.get("override_model_id")

    # both passed or both missing: the two truth values are equal
    if bool(uri) == bool(override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required", fields=("uri", "override_model_id")
        )

    with translate_errors_context():
        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            _only=["models", "execution", "name", "status", "project"], **query
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            model = ModelBLL.get_company_model_by_id(
                company_id=company_id, model_id=override_model_id
            )
        else:
            if "name" not in call.data:
                # use task name if name not provided
                call.data["name"] = task.name

            if "comment" not in call.data:
                call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

            task_models = task.models
            if task_models and task_models.output:
                # model exists, update
                model_id = task_models.output[-1].model
                res = _update_model(call, company_id, model_id=model_id).to_struct()
                res.update({"id": model_id, "created": False})
                call.result.data = res
                return

            # new model, create
            fields = parse_model_fields(call, create_fields)

            # create and save model
            now = datetime.utcnow()
            model = Model(
                id=database.utils.id(),
                created=now,
                last_update=now,
                user=call.identity.user,
                company=company_id,
                project=task.project,
                framework=task.execution.framework,
                # the first input model of the task, if any, becomes the parent
                parent=(
                    task.models.input[0].model
                    if task.models and task.models.input
                    else None
                ),
                design=task.execution.model_desc,
                labels=task.execution.model_labels,
                # NOTE(review): always False here if the status values are distinct,
                # since only 'created' and 'in_progress' pass the check above
                ready=(task.status == TaskStatus.published),
                **fields,
            )
            model.save()
            _update_cached_tags(company_id, project=model.project, fields=fields)

        output_model = ModelItem(
            model=model.id,
            name=TaskModelNames[TaskModelTypes.output],
            updated=datetime.utcnow(),
        )
        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            models__output=[output_model],
        )

        # "created" is reported as True on the override path as well
        call.result.data = {"id": model.id, "created": True}