def get_by_task_id(call: APICall, company_id, _):
    """
    Return the output model of a task.

    Looks up the task ``call.data["task"]`` within ``company_id``, follows its
    ``output.model`` reference and sets ``call.result.data`` to
    ``{"model": <model dict>}``.

    :raises InvalidTaskId: the task does not exist for this company
    :raises MissingTaskFields: the task has no output, or no output model
    :raises InvalidModelId: the referenced model is not found under the
        company-or-public constraint
    """
    task_id = call.data["task"]
    with translate_errors_context():
        task_query = dict(id=task_id, company=company_id)
        task = Task.get(_only=["output"], **task_query)
        if not task:
            raise errors.bad_request.InvalidTaskId(**task_query)

        task_output = task.output
        if not task_output:
            raise errors.bad_request.MissingTaskFields(field="output")

        model_id = task_output.model
        if not model_id:
            raise errors.bad_request.MissingTaskFields(field="output.model")

        # company models or public ones (see the error message below)
        model_filter = Q(id=model_id) & get_company_or_none_constraint(company_id)
        model = Model.objects(model_filter).first()
        if not model:
            raise errors.bad_request.InvalidModelId(
                "no such public or company model",
                id=model_id,
                company=company_id,
            )

        as_dict = model.to_proper_dict()
        conform_output_tags(call, as_dict)
        call.result.data = {"model": as_dict}
def model_set_ready(
    cls,
    model_id: str,
    company_id: str,
    publish_task: bool,
    force_publish_task: bool = False,
) -> tuple:
    """
    Mark a model as ready, optionally publishing the task that created it.

    :param model_id: id of the model to mark as ready
    :param company_id: company that owns the model
    :param publish_task: when set and the model's creating task exists and is
        not yet published, publish that task (with ``publish_model=False``)
    :param force_publish_task: forwarded as ``force`` to ``cls.publish_task``
    :return: ``(updated, published_task_data)``. ``updated`` is the result of
        ``model.update``. ``published_task_data`` is empty, or holds ``"data"``
        (the publish result) and ``"id"`` (the task id).
    :raises InvalidModelId: no such model for the company
    :raises ModelIsReady: the model is already ready
    """
    with translate_errors_context():
        lookup = dict(id=model_id, company=company_id)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)
        if model.ready:
            raise errors.bad_request.ModelIsReady(**lookup)

        publish_info = {}
        creator_id = model.task
        if creator_id and publish_task:
            creator = (
                Task.objects(id=creator_id, company=company_id)
                .only("id", "status")
                .first()
            )
            if creator and creator.status != TaskStatus.published:
                publish_info["data"] = cls.publish_task(
                    task_id=creator_id,
                    company_id=company_id,
                    publish_model=False,
                    force=force_publish_task,
                )
                publish_info["id"] = creator_id

        updated = model.update(upsert=False, ready=True)
        return updated, publish_info
def _update_model(call: APICall, company_id, model_id=None):
    """
    Apply the update fields in ``call.data`` to a model.

    :param call: API call whose ``data`` holds the fields to update. It also
        holds the model id under ``"model"`` when ``model_id`` is not given.
    :param company_id: company owning the model
    :param model_id: explicit model id; falls back to ``call.data["model"]``
    :return: ``UpdateResponse`` with the update count and the updated fields
    :raises InvalidModelId: no such model for the company
    """
    model_id = model_id or call.data["model"]

    with translate_errors_context():
        lookup = dict(id=model_id, company=company_id)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        data = prepare_update_fields(call, company_id, call.data)

        # record the reported iteration on the referenced task's statistics
        task_id, iteration = data.get("task"), data.get("iteration")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=company_id,
                last_iteration_max=iteration,
            )

        updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
        if updated_count:
            old_project = model.project
            new_project = updated_fields.get("project", old_project)
            if new_project == old_project:
                _update_cached_tags(
                    company_id, project=old_project, fields=updated_fields
                )
            else:
                # the model moved between projects: refresh both caches
                _reset_cached_tags(company_id, projects=[new_project, old_project])

        conform_output_tags(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
def get_outputs_for_deletion(task, force=False):
    """
    Collect the output models and child tasks affected by deleting ``task``.

    :param task: task about to be deleted (``id`` and ``output`` are read)
    :param force: when False, refuse deletion if the task has ready
        (published) output models or published child tasks
    :return: ``(models, tasks)``.
        ``models`` is the ``TaskOutputs`` built with ``ready`` as the split
        key (exposes ``published`` / ``draft``), including the task's
        ``output.model`` if any.
        ``tasks`` is the queryset of the task's children.
    :raises TaskCannotBeDeleted: a published model or child exists and
        ``force`` is False
    """
    with TimingContext("mongo", "get_task_models"):
        models = TaskOutputs(
            attrgetter("ready"),
            Model,
            Model.objects(task=task.id).only("id", "task", "ready"),
        )
        if not force and models.published:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has output models, use force=True",
                task=task.id,
                models=len(models.published),
            )

    # task.output may be unset; guard it like the other output.model accesses
    if task.output and task.output.model:
        output_model = get_output_model(task, force)
        if output_model:
            if output_model.ready:
                models.published.append(output_model)
            else:
                models.draft.append(output_model)

    with TimingContext("mongo", "get_task_children"):
        tasks = Task.objects(parent=task.id).only("id", "parent", "status")
        # "child" (not "task") so the parent task name is not shadowed
        published_children = [
            child for child in tasks if child.status == TaskStatus.published
        ]
        if not force and published_children:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has children, use force=True",
                task=task.id,
                # report child ids rather than the documents themselves
                children=[child.id for child in published_children],
            )
    return models, tasks
def _update_model(call: APICall, model_id=None):
    """
    Apply the update fields in ``call.data`` to a model of the caller's company.

    :param call: API call; ``call.identity.company`` scopes the model and
        ``call.data`` holds the fields (and ``"model"`` when ``model_id`` is
        not given)
    :param model_id: explicit model id; falls back to ``call.data["model"]``
    :return: ``UpdateResponse`` with the update count and the updated fields
    :raises InvalidModelId: no such model for the company
    """
    identity = call.identity
    model_id = model_id or call.data["model"]

    with translate_errors_context():
        lookup = dict(id=model_id, company=identity.company)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        data = prepare_update_fields(call, call.data)

        # record the reported iteration on the referenced task's statistics
        task_id, iteration = data.get("task"), data.get("iteration")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=identity.company,
                last_iteration_max=iteration,
            )

        count, changed = Model.safe_update(identity.company, model.id, data)
        conform_output_tags(call, changed)
        return UpdateResponse(updated=count, fields=changed)
def get_by_task_id(call):
    """
    Return the output model of a task (legacy signature taking only ``call``).

    The task ``call.data["task"]`` is looked up within the caller's company.
    Its ``output.model`` must reference a model matching the company-or-public
    constraint. Sets ``call.result.data`` to ``{"model": <model dict>}``.

    :raises InvalidTaskId: the task does not exist for this company
    :raises MissingTaskFields: the task has no output, or no output model
    :raises InvalidModelId: the referenced model was not found
    """
    assert isinstance(call, APICall)
    task_id = call.data["task"]
    with translate_errors_context():
        company = call.identity.company
        task_query = dict(id=task_id, company=company)
        task = Task.get(_only=["output"], **task_query)
        if not task:
            raise errors.bad_request.InvalidTaskId(**task_query)

        task_output = task.output
        if not task_output:
            raise errors.bad_request.MissingTaskFields(field="output")
        model_id = task_output.model
        if not model_id:
            raise errors.bad_request.MissingTaskFields(field="output.model")

        model_filter = Q(id=model_id) & get_company_or_none_constraint(company)
        model = Model.objects(model_filter).first()
        if not model:
            raise errors.bad_request.InvalidModelId(
                "no such public or company model",
                id=model_id,
                company=company,
            )
        call.result.data = {"model": model.to_proper_dict()}
def get_output_model(task, force=False):
    """
    Fetch the model referenced by ``task.output.model``.

    :param task: task whose ``output.model`` is set (accessed unguarded)
    :param force: allow returning a model that is already ready
    :return: the model document, or None when no such model exists
    :raises TaskCannotBeDeleted: the model is ready and ``force`` is False
    """
    model_id = task.output.model
    with TimingContext("mongo", "get_task_output_model"):
        found = Model.objects(id=model_id).first()
        blocked = found and found.ready and not force
        if blocked:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has output model, use force=True", task=task.id, model=model_id
            )
    return found
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
    """
    List the distinct frameworks of models visible to ``company``.

    Visibility follows ``get_company_or_none_constraint(company)`` (the
    company's models and public ones).

    :param company: company whose models are considered
    :param project_ids: when non-empty, only models in these projects count
    :return: distinct values of the models' ``framework`` field
    """
    constraint = get_company_or_none_constraint(company)
    if project_ids:
        constraint = constraint & Q(project__in=project_ids)
    frameworks = Model.objects(constraint).distinct(field="framework")
    return frameworks
def update(call: APICall, company_id, _):
    """
    Delete the model ``call.data["model"]`` of ``company_id``.

    NOTE(review): despite its name this handler deletes the model; confirm the
    endpoint it is registered under.

    References to the deleted model are replaced with
    ``__DELETED__<model id>``:

    - in ``execution.model`` of tasks that use it;
    - in ``output.model`` of its creating task when that task is published.
      An ``output.error`` note with the deletion time is also set.

    Sets ``call.result.data`` to ``{"deleted": <bool>}``.

    :raises InvalidModelId: no such model for the company
    :raises ModelInUse: tasks use the model for execution and
        ``call.data["force"]`` is not set
    :raises ModelCreatingTaskExists: the creating task is published and
        ``force`` is not set
    """
    model_id = call.data["model"]
    force = call.data.get("force", False)

    with translate_errors_context():
        lookup = dict(id=model_id, company=company_id)
        model = Model.objects(**lookup).only("id", "task", "project").first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        placeholder_id = f"__DELETED__{model_id}"

        dependents = Task.objects(execution__model=model_id).only("id")
        if dependents:
            if not force:
                raise errors.bad_request.ModelInUse(
                    "as execution model, use force=True to delete",
                    num_tasks=len(dependents),
                )
            # repoint the dependent tasks at the placeholder id
            dependents.update(
                execution__model=placeholder_id, upsert=False, multi=True
            )

        if model.task:
            creator = Task.objects(id=model.task).first()
            if creator and creator.status == TaskStatus.published:
                if not force:
                    raise errors.bad_request.ModelCreatingTaskExists(
                        "and published, use force=True to delete",
                        task=model.task,
                    )
                creator.update(
                    output__model=placeholder_id,
                    output__error=f"model deleted on {datetime.utcnow().isoformat()}",
                    upsert=False,
                )

        removed = Model.objects(**lookup).delete()
        if removed:
            _reset_cached_tags(company_id, projects=[model.project])
        call.result.data = dict(deleted=removed > 0)
def publish_task(
    cls,
    task_id: str,
    company_id: str,
    publish_model: bool,
    force: bool,
    status_reason: str = "",
    status_message: str = "",
) -> dict:
    """
    Publish a task, optionally marking its output model as ready first.

    The task is moved to ``publishing`` and saved. If anything fails after
    that, its previous status is restored and saved, and the error is
    re-raised.

    :param task_id: id of the task to publish
    :param company_id: caller's company (write access to the task is required)
    :param publish_model: also mark the task's output model as ready when it
        exists and is not ready yet
    :param force: skip ``validate_status_change``; also forwarded to the
        status change request
    :param status_reason: reason recorded with the status change
    :param status_message: message recorded with the status change
    :return: the result of ``ChangeStatusRequest.execute``
    """
    task = cls.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True
    )
    if not force:
        validate_status_change(task.status, TaskStatus.published)

    previous_task_status = task.status
    # task.output may be unset. Use this non-None value everywhere below;
    # dereferencing task.output.model directly raised AttributeError.
    output = task.output or Output()

    try:
        # set state to publishing
        task.status = TaskStatus.publishing
        task.save()

        # publish the task's output model
        if output.model and publish_model:
            output_model = (
                Model.objects(id=output.model)
                .only("id", "task", "ready")
                .first()
            )
            if output_model and not output_model.ready:
                cls.model_set_ready(
                    model_id=output.model,
                    company_id=company_id,
                    publish_task=False,
                )

        # set task status to published, and update (or set) its new output
        return ChangeStatusRequest(
            task=task,
            new_status=TaskStatus.published,
            force=force,
            status_reason=status_reason,
            status_message=status_message,
        ).execute(published=datetime.utcnow(), output=output)
    except Exception:
        # roll back the intermediate "publishing" status, then propagate
        task.status = previous_task_status
        task.save()
        raise
def validate_execution_model(task, allow_only_public=False):
    """
    Check that the task's execution model exists and is accessible.

    :param task: task whose ``execution.model`` is checked
    :param allow_only_public: when True the constraint is built with company
        ``None``; otherwise it is built with ``task.company``
    :return: the model document, or None when the task has no execution model
    :raises InvalidModelId: no model matches the id and constraint
    """
    execution = task.execution
    if not (execution and execution.model):
        return None

    model_id = execution.model
    owner = None if allow_only_public else task.company
    access = get_company_or_none_constraint(owner)
    model = Model.objects(Q(id=model_id) & access).first()
    if not model:
        raise errors.bad_request.InvalidModelId(model=model_id)
    return model
def edit(call: APICall, company_id, _):
    """
    Edit the fields of model ``call.data["model"]`` of ``company_id``.

    Behavior:

    - A field whose request value is a dict and whose current value is an
      embedded document is merged into that document's current content
      instead of replacing it.
    - When ``iteration`` is given and a task is referenced (by the model, or
      else by the request), the task's statistics are updated.
    - After a successful update the cached tags of the affected project(s)
      are refreshed.

    Sets ``call.result.data_model`` to an ``UpdateResponse``.

    :raises InvalidModelId: no such model for the company
    """
    model_id = call.data["model"]

    with translate_errors_context():
        lookup = dict(id=model_id, company=company_id)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        fields = prepare_update_fields(
            call, company_id, parse_model_fields(call, create_fields)
        )

        # merge partial embedded-document updates into the stored value
        for name, value in list(fields.items()):
            current = getattr(model, name, None)
            if (
                current
                and isinstance(value, dict)
                and isinstance(current, EmbeddedDocument)
            ):
                merged = current.to_mongo(use_db_field=False).to_dict()
                merged.update(value)
                fields[name] = merged

        iteration = call.data.get("iteration")
        task_id = model.task or fields.get("task")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=company_id,
                last_iteration_max=iteration,
            )

        if not fields:
            call.result.data_model = UpdateResponse(updated=0)
            return

        updated = model.update(upsert=False, **fields)
        if updated:
            old_project = model.project
            new_project = fields.get("project", old_project)
            if new_project == old_project:
                _update_cached_tags(company_id, project=old_project, fields=fields)
            else:
                # the model moved between projects: refresh both caches
                _reset_cached_tags(company_id, projects=[new_project, old_project])
        conform_output_tags(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)
def _resolve_entities(
    cls,
    experiments: Sequence[str] = None,
    projects: Sequence[str] = None,
    task_statuses: Sequence[str] = None,
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
    """
    Resolve experiment/project identifiers into the documents they involve.

    :param experiments: task identifiers, resolved via ``cls._resolve_type``.
        Their projects are added too.
    :param projects: project identifiers, resolved via ``cls._resolve_type``.
        Their non-archived tasks are added too.
    :param task_statuses: when given, only project tasks with one of these
        statuses are added (tasks from ``experiments`` are not filtered)
    :return: mapping of document class (Project / Task / Model) to the set of
        resolved documents. Models are those referenced by the resolved tasks
        as output or execution model.
    """
    entities = defaultdict(set)
    # built once so each membership test below is O(1) instead of a list scan
    explicit_experiments = set(experiments or [])

    if projects:
        print("Reading projects...")
        entities[Project].update(cls._resolve_type(Project, projects))
        print("--> Reading project experiments...")
        query = Q(
            project__in=list(set(filter(None, (p.id for p in entities[Project])))),
            system_tags__nin=[EntityVisibility.archived.value],
        )
        if task_statuses:
            query &= Q(status__in=list(set(task_statuses)))
        objs = Task.objects(query)
        entities[Task].update(o for o in objs if o.id not in explicit_experiments)

    if experiments:
        print("Reading experiments...")
        entities[Task].update(cls._resolve_type(Task, experiments))
        print("--> Reading experiments projects...")
        objs = Project.objects(
            id__in=list(set(filter(None, (t.project for t in entities[Task]))))
        )
        project_ids = {p.id for p in entities[Project]}
        entities[Project].update(o for o in objs if o.id not in project_ids)

    model_ids = set()
    for task in entities[Task]:
        # output / execution may be unset on a task; getattr tolerates None
        for section in (task.output, task.execution):
            model_id = getattr(section, "model", None)
            if model_id:
                model_ids.add(model_id)

    if model_ids:
        print("Reading models...")
        entities[Model] = set(Model.objects(id__in=list(model_ids)))

    return entities
def edit(call):
    """
    Edit a model's fields (legacy signature taking only ``call``).

    The model ``call.data["model"]`` is looked up within the caller's company.

    - A field whose request value is a dict and whose current value is an
      embedded document is merged into that document's current content.
    - When ``iteration`` is given and a task is referenced (by the model, or
      else by the request), the task's statistics are updated.

    Sets ``call.result.data_model`` to an ``UpdateResponse``.

    :raises InvalidModelId: no such model for the company
    """
    assert isinstance(call, APICall)
    company = call.identity.company
    model_id = call.data["model"]

    with translate_errors_context():
        lookup = dict(id=model_id, company=company)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        fields = prepare_update_fields(call, parse_model_fields(call, create_fields))

        # merge partial embedded-document updates into the stored value
        for name, value in list(fields.items()):
            current = getattr(model, name, None)
            if (
                current
                and isinstance(value, dict)
                and isinstance(current, EmbeddedDocument)
            ):
                merged = current.to_mongo(use_db_field=False).to_dict()
                merged.update(value)
                fields[name] = merged

        iteration = call.data.get("iteration")
        task_id = model.task or fields.get("task")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=company,
                last_iteration_max=iteration,
            )

        if not fields:
            call.result.data_model = UpdateResponse(updated=0)
            return
        updated = model.update(upsert=False, **fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)
def update_for_task(call: APICall, company_id, _):
    """
    Create, update or override the output model of a task.

    Exactly one of ``uri`` or ``override_model_id`` must be in ``call.data``:

    - ``override_model_id``: that existing company model becomes the task's
      output model.
    - ``uri``: if the task already has an output model, it is updated from
      ``call.data`` and the response has ``created`` False. Otherwise a new
      model is built from the task's project/execution settings and
      ``call.data``.

    Only tasks in ``created`` / ``in_progress`` status are accepted. Unless
    the existing model was updated, the task statistics get ``iteration`` and
    the output model id. ``call.result.data`` is then
    ``{"id": ..., "created": True}``.

    :raises MissingRequiredFields: neither or both of uri/override_model_id
    :raises InvalidTaskId: no such task for the company
    :raises InvalidTaskStatus: the task status is not allowed
    :raises InvalidModelId: the override model does not exist for the company
    """
    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")
    if not (uri or override_model_id) or (uri and override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required", fields=("uri", "override_model_id")
        )

    with translate_errors_context():
        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["output", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            query = dict(company=company_id, id=override_model_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
            # NOTE(review): the response below reports created=True on this
            # path although no model is created here; confirm with API clients
        else:
            if "name" not in call.data:
                # use task name if name not provided
                call.data["name"] = task.name
            if "comment" not in call.data:
                call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

            if task.output and task.output.model:
                # model exists, update
                res = _update_model(
                    call, company_id, model_id=task.output.model
                ).to_struct()
                res.update({"id": task.output.model, "created": False})
                call.result.data = res
                return

            # new model, create
            fields = parse_model_fields(call, create_fields)
            # task.execution may be unset; getattr then yields None instead of
            # raising AttributeError
            execution = task.execution
            model = Model(
                id=database.utils.id(),
                created=datetime.utcnow(),
                user=call.identity.user,
                company=company_id,
                project=task.project,
                framework=getattr(execution, "framework", None),
                parent=getattr(execution, "model", None),
                design=getattr(execution, "model_desc", None),
                labels=getattr(execution, "model_labels", None),
                ready=(task.status == TaskStatus.published),
                **fields,
            )
            model.save()
            _update_cached_tags(company_id, project=model.project, fields=fields)

        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            output__model=model.id,
        )
        call.result.data = {"id": model.id, "created": True}