def _update_model(call: APICall, company_id, model_id=None):
    """
    Apply the update fields carried by ``call.data`` to a single company model.

    The model is identified by ``model_id`` when it is truthy, otherwise by
    ``call.data["model"]``. When the prepared fields contain both a task and an
    iteration, the task statistics are updated with that iteration as well.

    Returns an ``UpdateResponse`` holding the count and the fields reported by
    ``Model.safe_update``.
    """
    target_id = model_id if model_id else call.data["model"]
    model = ModelBLL.get_company_model_by_id(
        company_id=company_id, model_id=target_id
    )

    update_data = prepare_update_fields(call, company_id, call.data)

    task_id = update_data.get("task")
    iteration = update_data.get("iteration")
    if task_id and iteration is not None:
        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
        )

    updated_count, updated_fields = Model.safe_update(
        company_id, model.id, update_data
    )

    if updated_count:
        # Refresh the model's last_update only when one of the tracked fields changed
        if any(name in updated_fields for name in last_update_fields):
            model.update(upsert=False, last_update=datetime.utcnow())

        current_project = model.project
        new_project = updated_fields.get("project", current_project)
        if new_project == current_project:
            _update_cached_tags(
                company_id, project=current_project, fields=updated_fields
            )
        else:
            # The model moved between projects: reset the tags cache for both
            _reset_cached_tags(
                company_id, projects=[new_project, current_project]
            )

    # The response fields are post-processed regardless of the updated count
    conform_output_tags(call, updated_fields)
    unescape_metadata(call, updated_fields)
    return UpdateResponse(updated=updated_count, fields=updated_fields)
def edit(call: APICall, company_id, _):
    """
    Edit the model named by ``call.data["model"]`` using the create-fields
    parsed from the call.

    Dict values that target an existing embedded document are overlaid on the
    document's current contents before the update is issued. The outcome is
    stored in ``call.result.data_model`` as an ``UpdateResponse``; when there is
    nothing to update the response reports ``updated=0``.
    """
    model_id = call.data["model"]

    with translate_errors_context():
        model = ModelBLL.get_company_model_by_id(
            company_id=company_id, model_id=model_id
        )

        parsed = parse_model_fields(call, create_fields)
        fields = prepare_update_fields(call, company_id, parsed)

        for key in fields:
            current = getattr(model, key, None)
            incoming = fields[key]
            if not (
                current
                and isinstance(incoming, dict)
                and isinstance(current, EmbeddedDocument)
            ):
                continue
            # Start from the stored embedded document and overlay the incoming keys
            merged = current.to_mongo(use_db_field=False).to_dict()
            merged.update(incoming)
            fields[key] = merged

        iteration = call.data.get("iteration")
        # The task already set on the model takes precedence over the one in the fields
        task_id = model.task or fields.get("task")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=company_id,
                last_iteration_max=iteration,
            )

        if not fields:
            call.result.data_model = UpdateResponse(updated=0)
            return

        if any(name in fields for name in last_update_fields):
            fields["last_update"] = datetime.utcnow()

        updated = model.update(upsert=False, **fields)
        if updated:
            current_project = model.project
            new_project = fields.get("project", current_project)
            if new_project == current_project:
                _update_cached_tags(
                    company_id, project=current_project, fields=fields
                )
            else:
                # The model moved between projects: reset the tags cache for both
                _reset_cached_tags(
                    company_id, projects=[new_project, current_project]
                )

        # NOTE(review): unlike `_update_model`, the returned fields are not passed
        # through `unescape_metadata` here - confirm whether that is intentional.
        conform_output_tags(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)
def _update_task(
    self,
    company_id,
    task_id,
    now,
    iter_max=None,
    last_scalar_events=None,
    last_events=None,
):
    """
    Push aggregated event results for a task into its statistics.

    Only the pieces that were actually provided are sent: the maximal iteration
    (``iter_max``, when not None), the flattened last scalar values (when
    ``last_scalar_events`` is non-empty) and the last events (when
    ``last_events`` is non-empty). ``now`` is passed as the task's last update
    time.

    Returns False without touching the task when there is nothing to update,
    otherwise the result of ``TaskBLL.update_statistics``.
    """
    stats = {}

    if iter_max is not None:
        stats["last_iteration_max"] = iter_max

    if last_scalar_events:
        scalar_leaves = ["value", "min_value", "max_value", "metric", "variant"]
        flattened = flatten_nested_items(
            last_scalar_events,
            nesting=2,
            include_leaves=scalar_leaves,
        )
        stats["last_scalar_values"] = list(flattened)

    if last_events:
        stats["last_events"] = last_events

    if stats:
        return TaskBLL.update_statistics(
            task_id, company_id, last_update=now, **stats
        )

    # Nothing was reported for this task
    return False
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
    """
    Set or add a named model entry in the task's ``models.<type>`` list.

    An entry with the same name is replaced in place; when no such entry was
    updated, the new entry is pushed to the list as part of the task statistics
    update. The task's maximal iteration is updated from ``request.iteration``.

    Returns ``{"updated": <result of TaskBLL.update_statistics>}``.
    """
    task_id = request.task
    get_task_for_update(company_id=company_id, task_id=task_id, force=True)

    models_field = f"models__{request.type}"
    item = ModelItem(
        name=request.name, model=request.model, updated=datetime.utcnow()
    )

    # Try to replace an existing entry carrying the same name (positional update)
    same_name_query = {"id": task_id, f"{models_field}__name": request.name}
    replaced = Task.objects(**same_name_query).update_one(
        **{f"set__{models_field}__S": item}
    )

    # No entry was replaced: append the new one together with the statistics update
    if replaced:
        push_new_entry = {}
    else:
        push_new_entry = {f"push__{models_field}": item}

    updated = TaskBLL.update_statistics(
        task_id=task_id,
        company_id=company_id,
        last_iteration_max=request.iteration,
        **push_new_entry,
    )
    return {"updated": updated}
def update_for_task(call: APICall, company_id, _):
    """
    Create or update the output model of a task (legacy endpoint).

    Exactly one of ``uri`` / ``override_model_id`` must be provided in the call
    data. With ``override_model_id`` the given model becomes the task's output
    model. With ``uri`` the task's last output model is updated when one exists,
    otherwise a new model is created from the call fields and the task's
    execution info.

    Raises ``NotSupported`` for endpoint versions above
    ``ModelsBackwardsCompatibility.max_version``, ``MissingRequiredFields`` when
    the uri/override_model_id requirement is not met, ``InvalidTaskId`` when the
    task is not found and ``InvalidTaskStatus`` when the task is neither created
    nor in progress.

    The result (model id and a ``created`` flag) is stored in ``call.result.data``.
    """
    if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
        raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")

    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")

    # Both missing or both present is an error: exactly one must be given
    if bool(uri) == bool(override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required", fields=("uri", "override_model_id")
        )

    with translate_errors_context():
        query = {"id": task_id, "company": company_id}
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["models", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if not override_model_id:
            # Default the model name and comment from the task when not provided
            if "name" not in call.data:
                call.data["name"] = task.name
            if "comment" not in call.data:
                call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

            if task.models and task.models.output:
                # The task already has an output model: update the last one and finish
                model_id = task.models.output[-1].model
                response = _update_model(
                    call, company_id, model_id=model_id
                ).to_struct()
                response.update(id=model_id, created=False)
                call.result.data = response
                return

            # No output model yet: create and save a new one
            fields = parse_model_fields(call, create_fields)

            if task.models and task.models.input:
                parent_model = task.models.input[0].model
            else:
                parent_model = None

            now = datetime.utcnow()
            model = Model(
                id=database.utils.id(),
                created=now,
                last_update=now,
                user=call.identity.user,
                company=company_id,
                project=task.project,
                framework=task.execution.framework,
                parent=parent_model,
                design=task.execution.model_desc,
                labels=task.execution.model_labels,
                # NOTE(review): the status check above only admits created/in_progress
                # tasks, so this looks like it is always False - confirm.
                ready=(task.status == TaskStatus.published),
                **fields,
            )
            model.save()
            _update_cached_tags(company_id, project=model.project, fields=fields)
        else:
            model = ModelBLL.get_company_model_by_id(
                company_id=company_id, model_id=override_model_id
            )

        # Register the model as the task's (single) output model
        output_item = ModelItem(
            model=model.id,
            name=TaskModelNames[TaskModelTypes.output],
            updated=datetime.utcnow(),
        )
        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            models__output=[output_item],
        )

        # NOTE(review): "created" is reported as True for the override_model_id
        # path as well, although no model is created there - confirm intent.
        call.result.data = {"id": model.id, "created": True}