def _update_model(call: APICall, company_id, model_id=None):
    """
    Update a single model from the fields supplied in ``call.data``.

    :param call: API call whose ``data`` carries the fields to update
    :param company_id: company that must own the model
    :param model_id: id of the model to update; when falsy, ``call.data["model"]``
        is used instead
    :return: ``UpdateResponse`` with the number of updated documents and the
        fields that were updated
    :raise InvalidModelId: if no model with this id exists for the company
    """
    model_id = model_id or call.data["model"]

    with translate_errors_context():
        lookup = dict(id=model_id, company=company_id)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        update_data = prepare_update_fields(call, company_id, call.data)

        # Only an update that itself names a task and reports an iteration
        # propagates that iteration to the task's statistics.
        task_id = update_data.get("task")
        iteration = update_data.get("iteration")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=company_id,
                last_iteration_max=iteration,
            )

        updated_count, updated_fields = Model.safe_update(
            company_id, model.id, update_data
        )
        if updated_count:
            # `model` was loaded before the update, so `model.project` is the
            # project the model belonged to before this call.
            old_project = model.project
            new_project = updated_fields.get("project", old_project)
            if new_project == old_project:
                _update_cached_tags(
                    company_id, project=old_project, fields=updated_fields
                )
            else:
                # The model moved: tag caches of both projects are now stale.
                _reset_cached_tags(company_id, projects=[new_project, old_project])

        conform_output_tags(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
def _update_model(call: APICall, model_id=None):
    """
    Update a single model from the fields supplied in ``call.data``.

    The owning company is taken from ``call.identity``.

    :param call: API call carrying the update fields and the caller's identity
    :param model_id: id of the model to update; when falsy,
        ``call.data["model"]`` is used instead
    :return: ``UpdateResponse`` with the number of updated documents and the
        fields that were updated
    :raise InvalidModelId: if no model with this id exists for the company
    """
    identity = call.identity
    target_id = model_id or call.data["model"]

    with translate_errors_context():
        company_id = identity.company
        lookup = dict(id=target_id, company=company_id)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        update_data = prepare_update_fields(call, call.data)

        # Propagate a reported iteration to the task named by the update.
        task_id = update_data.get("task")
        iteration = update_data.get("iteration")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=company_id,
                last_iteration_max=iteration,
            )

        updated_count, updated_fields = Model.safe_update(
            company_id, model.id, update_data
        )
        conform_output_tags(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
def _update_task(self, company_id, task_id, now, iter_max=None, last_events=None): """ Update task information in DB with aggregated results after handling event(s) related to this task. This updates the task with the highest iteration value encountered during the last events update, as well as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last update time. """ fields = {} if iter_max is not None: fields["last_iteration_max"] = iter_max if last_events: fields["last_values"] = list( flatten_nested_items( last_events, nesting=2, include_leaves=["value", "metric", "variant"], )) if not fields: return False return TaskBLL.update_statistics(task_id, company_id, last_update=now, **fields)
def edit(call: APICall, company_id, _):
    """
    Edit an existing model with the fields supplied in the request.

    Dict values sent for fields that currently hold an ``EmbeddedDocument``
    are merged over the document's current values instead of replacing it.
    When the request carries an ``iteration`` and a task is known (the model's
    own task, otherwise the one in the request), the task statistics are
    updated as well. The result is stored in ``call.result.data_model``.

    :param call: API call carrying the model id and the fields to edit
    :param company_id: company that must own the model
    :raise InvalidModelId: if no model with this id exists for the company
    """
    model_id = call.data["model"]

    with translate_errors_context():
        query = dict(id=model_id, company=company_id)
        model = Model.objects(**query).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)

        fields = prepare_update_fields(
            call, company_id, parse_model_fields(call, create_fields)
        )

        # Merge a supplied dict over an embedded document's current content so
        # that keys not present in the request are kept.
        for name in list(fields):
            current = getattr(model, name, None)
            incoming = fields[name]
            if not (
                current
                and isinstance(incoming, dict)
                and isinstance(current, EmbeddedDocument)
            ):
                continue
            merged = current.to_mongo(use_db_field=False).to_dict()
            merged.update(incoming)
            fields[name] = merged

        iteration = call.data.get("iteration")
        task_id = model.task or fields.get("task")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=company_id,
                last_iteration_max=iteration,
            )

        if not fields:
            call.result.data_model = UpdateResponse(updated=0)
            return

        updated = model.update(upsert=False, **fields)
        if updated:
            # `model` is the pre-update instance, so `model.project` is the
            # project the model belonged to before this edit.
            old_project = model.project
            new_project = fields.get("project", old_project)
            if new_project == old_project:
                _update_cached_tags(company_id, project=old_project, fields=fields)
            else:
                _reset_cached_tags(company_id, projects=[new_project, old_project])

        conform_output_tags(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)
def edit(call):
    """
    Edit an existing model with the fields supplied in the request.

    The owning company is taken from ``call.identity``. Dict values sent for
    fields that currently hold an ``EmbeddedDocument`` are merged over the
    document's current values. When an ``iteration`` is supplied and a task is
    known, the task statistics are updated. The result is stored in
    ``call.result.data_model``.

    :param call: the ``APICall`` carrying the model id and fields to edit
    :raise InvalidModelId: if no model with this id exists for the company
    """
    # NOTE(review): an assert is stripped under `python -O`; kept as-is to
    # preserve behavior.
    assert isinstance(call, APICall)
    identity = call.identity
    model_id = call.data["model"]

    with translate_errors_context():
        company_id = identity.company
        query = dict(id=model_id, company=company_id)
        model = Model.objects(**query).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)

        fields = prepare_update_fields(call, parse_model_fields(call, create_fields))

        # Merge a supplied dict over an embedded document's current content so
        # that keys not present in the request are kept.
        for name in list(fields):
            current = getattr(model, name, None)
            incoming = fields[name]
            if not (
                current
                and isinstance(incoming, dict)
                and isinstance(current, EmbeddedDocument)
            ):
                continue
            merged = current.to_mongo(use_db_field=False).to_dict()
            merged.update(incoming)
            fields[name] = merged

        iteration = call.data.get("iteration")
        task_id = model.task or fields.get("task")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=company_id,
                last_iteration_max=iteration,
            )

        if not fields:
            call.result.data_model = UpdateResponse(updated=0)
            return

        updated = model.update(upsert=False, **fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)
def _update_task(self, company_id, task_id, now, iter=None, last_events=None, last_metrics=None): """ Update task information in DB with aggregated results after handling event(s) related to this task. This updates the task with the highest iteration value encountered during the last events update, as well as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last update time. """ fields = {} if iter is not None: fields["last_iteration"] = iter if last_events: def get_metric_event(ev): me = MetricEvent.from_dict(**ev) if "timestamp" in ev: me.timestamp = datetime.utcfromtimestamp(ev["timestamp"] / 1000) return me new_last_metrics = nested_dict(2, MetricEvent) new_last_metrics.update(last_metrics) for metric_hash, variants in last_events.items(): for variant_hash, event in variants.items(): new_last_metrics[metric_hash][ variant_hash] = get_metric_event(event) fields["last_metrics"] = new_last_metrics.to_dict() if not fields: return False return TaskBLL.update_statistics(task_id, company_id, last_update=now, **fields)
def update_for_task(call: APICall, company_id, _):
    """
    Set or update the output model of a task.

    Exactly one of ``uri`` or ``override_model_id`` must be given:

    * ``override_model_id``: an existing model of the company becomes the
      task's output model.
    * ``uri``: if the task already has an output model it is updated in place
      (``created`` is False); otherwise a new model is created from the
      request fields and the task's project/execution data.

    The task must be in the ``created`` or ``in_progress`` state. The result
    is stored in ``call.result.data``.

    :param call: API call carrying ``task`` and the model fields
    :param company_id: company owning the task (and the model)
    :raise MissingRequiredFields: if neither or both of ``uri`` and
        ``override_model_id`` are given
    :raise InvalidTaskId: if the task does not exist for the company
    :raise InvalidTaskStatus: if the task is not in an allowed state
    :raise InvalidModelId: if ``override_model_id`` does not exist
    """
    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")

    # Both missing or both present are equally invalid.
    if bool(uri) == bool(override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required", fields=("uri", "override_model_id")
        )

    with translate_errors_context():
        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["output", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            query = dict(company=company_id, id=override_model_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
        else:
            # Default the model's name/comment from the task. call.data is
            # mutated on purpose so that both _update_model() and
            # parse_model_fields() below see the defaults.
            if "name" not in call.data:
                call.data["name"] = task.name
            if "comment" not in call.data:
                call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

            existing_model_id = task.output.model if task.output else None
            if existing_model_id:
                # The task already has an output model: update it in place.
                response = _update_model(
                    call, company_id, model_id=existing_model_id
                ).to_struct()
                response.update({"id": existing_model_id, "created": False})
                call.result.data = response
                return

            fields = parse_model_fields(call, create_fields)
            model = Model(
                id=database.utils.id(),
                created=datetime.utcnow(),
                user=call.identity.user,
                company=company_id,
                project=task.project,
                framework=task.execution.framework,
                parent=task.execution.model,
                design=task.execution.model_desc,
                labels=task.execution.model_labels,
                # NOTE(review): allowed_states excludes `published`, so this
                # presumably always evaluates to False here.
                ready=(task.status == TaskStatus.published),
                **fields,
            )
            model.save()
            _update_cached_tags(company_id, project=model.project, fields=fields)

        # Reached for both a newly created model and an override model.
        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            output__model=model.id,
        )

        # NOTE(review): "created" is reported as True on the override_model_id
        # path too, although no model was created there — confirm intent.
        call.result.data = {"id": model.id, "created": True}