def create(call, company, req_model):
    """Create and persist a new ``Model`` document (legacy variant).

    :param call: the API call being served; asserted to be an ``APICall``.
    :param company: id of the calling company; ignored in favour of ``""``
        when ``req_model.public`` is truthy.
    :param req_model: the request body; asserted to be a ``CreateModelRequest``.

    The referenced project (if any) is checked with ``validate_id`` and the
    referenced task (if any) with ``validate_task``. On success
    ``call.result.data_model`` is set to a ``CreateModelResponse`` holding the
    new model id.
    """
    assert isinstance(call, APICall)
    assert isinstance(req_model, CreateModelRequest)
    identity = call.identity
    # Public models are stored under the empty company id.
    owner = "" if req_model.public else company

    with translate_errors_context():
        project_id = req_model.project
        if project_id:
            validate_id(Project, company=owner, project=project_id)

        task_id = req_model.task
        payload = req_model.to_struct()
        if task_id:
            validate_task(call, payload)

        model_fields = filter_fields(Model, payload)
        conform_tag_fields(call, model_fields)

        new_model = Model(
            id=database.utils.id(),
            user=identity.user,
            company=owner,
            created=datetime.utcnow(),
            **model_fields,
        )
        new_model.save()
        call.result.data_model = CreateModelResponse(id=new_model.id, created=True)
def get_all(call: APICall):
    """List the queues of the caller's company that match ``call.data``.

    ``call.data`` is passed through ``conform_tag_fields`` before the lookup
    and the returned queues through ``conform_output_tags`` afterwards. The
    result is stored as ``{"queues": [...]}`` in ``call.result.data``.
    """
    conform_tag_fields(call, call.data)
    found_queues = queue_bll.get_all(
        company_id=call.identity.company,
        query_dict=call.data,
    )
    conform_output_tags(call, found_queues)
    call.result.data = {"queues": found_queues}
def prepare_update_fields(call, company_id, fields: dict):
    """Validate and normalize the fields of a model update request.

    Works on a shallow copy of ``fields``, so top-level keys added or removed
    here do not affect the caller's dict.

    :param call: the API call being served (used for tag handling).
    :param company_id: id of the calling company, passed to ``validate_task``.
    :param fields: the requested field updates.
    :return: the prepared copy of ``fields``.
    :raises errors.bad_request.ValidationError: if ``labels`` is present but is
        not a mapping, or has non-string keys or non-integer values.
    """
    from collections.abc import Mapping

    def find_other_types(iterable, type_):
        # Items that are not instances of ``type_``: a set when they are
        # hashable, otherwise the raw list (so the error can still report them).
        res = [x for x in iterable if not isinstance(x, type_)]
        try:
            return set(res)
        except TypeError:
            # Un-hashable, probably
            return res

    fields = fields.copy()
    if "uri" in fields:
        # clear UI cache if URI is provided (model updated): keep an explicitly
        # supplied ui_cache, otherwise reset it to an empty dict
        fields["ui_cache"] = fields.pop("ui_cache", {})

    if "task" in fields:
        validate_task(company_id, fields)

    if "labels" in fields:
        labels = fields["labels"]
        # A non-mapping value (e.g. null or a list) previously crashed on
        # ``labels.keys()``; report it as a validation error instead.
        if not isinstance(labels, Mapping):
            raise errors.bad_request.ValidationError(
                "labels must be a mapping of string keys to integer values",
                labels=labels,
            )

        invalid_keys = find_other_types(labels.keys(), str)
        if invalid_keys:
            raise errors.bad_request.ValidationError(
                "labels keys must be strings", keys=invalid_keys
            )

        # NOTE(review): bool is a subclass of int, so True/False values pass
        # this check — confirm whether that is acceptable for labels.
        invalid_values = find_other_types(labels.values(), int)
        if invalid_values:
            raise errors.bad_request.ValidationError(
                "labels values must be integers", values=invalid_values
            )

    conform_tag_fields(call, fields, validate=True)
    return fields
def update(call: APICall):
    """
    update

    :summary: Update project information.
              See `project.create` for parameters.
    :return: updated - `int` - number of projects updated
             fields - `[string]` - updated fields
    """
    project_id = call.data["project"]

    with translate_errors_context():
        target = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not target:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        # discard_none_values=False: explicit nulls in the request are kept,
        # so they reach project.update below.
        changes = parse_from_call(
            call.data,
            create_fields,
            Project.get_fields(),
            discard_none_values=False,
        )
        conform_tag_fields(call, changes, validate=True)
        changes["last_update"] = datetime.utcnow()

        with TimingContext("mongo", "projects_update"):
            update_count = target.update(upsert=False, **changes)

        conform_output_tags(call, changes)
        call.result.data_model = UpdateResponse(updated=update_count, fields=changes)
def create(call: APICall, company_id, req_model: CreateModelRequest):
    """Create and persist a new ``Model`` document.

    :param call: the API call being served.
    :param company_id: id of the calling company; replaced by ``""`` when
        ``req_model.public`` is truthy.
    :param req_model: the parsed create request.

    The project (if given) is checked with ``validate_id`` and the task (if
    given) with ``validate_task``. After saving, ``_update_cached_tags`` is
    called with the owning company id, and ``call.result.data_model`` is set
    to a ``CreateModelResponse`` with the new id.
    """
    # Public models are owned by the empty company id.
    owner_id = "" if req_model.public else company_id

    with translate_errors_context():
        project_ref = req_model.project
        if project_ref:
            validate_id(Project, company=owner_id, project=project_ref)

        task_ref = req_model.task
        request_data = req_model.to_struct()
        if task_ref:
            validate_task(owner_id, request_data)

        model_fields = filter_fields(Model, request_data)
        conform_tag_fields(call, model_fields, validate=True)

        new_model = Model(
            id=database.utils.id(),
            user=call.identity.user,
            company=owner_id,
            created=datetime.utcnow(),
            **model_fields,
        )
        new_model.save()
        _update_cached_tags(owner_id, project=new_model.project, fields=model_fields)

        call.result.data_model = CreateModelResponse(id=new_model.id, created=True)
def update(call: APICall, company_id, req_model: UpdateRequest):
    """Apply a partial update to the queue identified by ``req_model.queue``.

    The partial-update payload is passed through ``conform_tag_fields``
    before being handed to ``queue_bll.update``. The updated fields it
    returns are passed through ``conform_output_tags`` and reported in an
    ``UpdateResponse``.
    """
    partial = call.data_model_for_partial_update
    conform_tag_fields(call, partial)

    update_count, changed_fields = queue_bll.update(
        company_id=company_id,
        queue_id=req_model.queue,
        **partial,
    )

    conform_output_tags(call, changed_fields)
    call.result.data_model = UpdateResponse(updated=update_count, fields=changed_fields)
def get_all_ex(call: APICall, company_id, _):
    """Return models matching ``call.data`` via ``Model.get_many_with_join``.

    Public models are included (``allow_public=True``). The query is passed
    through ``conform_tag_fields`` first and the results through
    ``conform_output_tags``. The response is stored as ``{"models": [...]}``.
    """
    conform_tag_fields(call, call.data)

    with translate_errors_context():
        with TimingContext("mongo", "models_get_all_ex"):
            matched = Model.get_many_with_join(
                company=company_id,
                query_dict=call.data,
                allow_public=True,
            )
        conform_output_tags(call, matched)
        call.result.data = {"models": matched}
def prepare_update_fields(call, fields):
    """Return a prepared shallow copy of ``fields`` for a model update.

    When ``uri`` is being updated, ``ui_cache`` is set to the supplied value
    or to ``{}`` if none was supplied. A referenced ``task`` is checked with
    ``validate_task``, and tag fields go through ``conform_tag_fields``.
    """
    prepared = fields.copy()

    if "uri" in prepared:
        # clear UI cache if URI is provided (model updated)
        prepared["ui_cache"] = prepared.pop("ui_cache", {})

    if "task" in prepared:
        validate_task(call, prepared)

    conform_tag_fields(call, prepared)
    return prepared
def prepare_update_fields(call: APICall, task, call_data):
    """Parse the update request for ``task`` into a fields dict.

    Only ``create_fields`` entries that the task class lists in
    ``user_set_allowed()`` are accepted, plus ``output__error``.

    :param call: the API call being served (used for tag handling).
    :param task: the task being updated; only its class is consulted.
    :param call_data: the raw request data to parse.
    :return: ``(fields, valid_fields)`` — the parsed fields and a deep copy of
        the task class's ``user_set_allowed()`` result.
    """
    valid_fields = deepcopy(task.__class__.user_set_allowed())
    update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
    # Always accept output__error in updates (a None spec presumably means
    # "no type constraint" in parse_from_call — verify).
    update_fields["output__error"] = None

    # Copy before extending: task_fields is shared module-level state, and
    # adding to it in place would leak "output__error" to every other user.
    t_fields = set(task_fields)
    t_fields.add("output__error")

    fields = parse_from_call(call_data, update_fields, t_fields)
    conform_tag_fields(call, fields)
    return fields, valid_fields
def get_all(call: APICall):
    """Return the caller company's models (plus public ones) matching ``call.data``.

    ``call.data`` serves as both the query dict and the ``parameters`` argument
    of ``Model.get_many``. Tag fields are passed through ``conform_tag_fields``
    before the query and ``conform_output_tags`` after it.
    """
    conform_tag_fields(call, call.data)

    with translate_errors_context():
        with TimingContext("mongo", "models_get_all"):
            found_models = Model.get_many(
                company=call.identity.company,
                parameters=call.data,
                query_dict=call.data,
                allow_public=True,
                query_options=get_all_query_options,
            )
        conform_output_tags(call, found_models)
        call.result.data = {"models": found_models}
def get_all_ex(call: APICall):
    """Return tasks matching ``call.data`` via ``Task.get_many_with_join``.

    The query is passed through ``conform_tag_fields`` first and the results
    through ``conform_output_tags``. The response is stored as
    ``{"tasks": [...]}`` in ``call.result.data``.
    """
    conform_tag_fields(call, call.data)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_all_ex"):
            found_tasks = Task.get_many_with_join(
                company=call.identity.company,
                query_dict=call.data,
                query_options=get_all_query_options,
                # required in case projection is requested for public dataset/versions
                allow_public=True,
            )
        conform_output_tags(call, found_tasks)
        call.result.data = {"tasks": found_tasks}
def get_all(call: APICall):
    """Return the caller company's projects (plus public ones) matching ``call.data``.

    The query and the results are passed through ``conform_tag_fields`` and
    ``conform_output_tags`` respectively. The response is stored as
    ``{"projects": [...]}`` in ``call.result.data``.
    """
    conform_tag_fields(call, call.data)

    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        found_projects = Project.get_many(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            parameters=call.data,
            allow_public=True,
        )
        conform_output_tags(call, found_projects)
        call.result.data = {"projects": found_projects}
def get_all_ex(call: APICall, company_id, _):
    """Return tasks matching ``call.data`` via ``Task.get_many_with_join``.

    Before querying, tag fields go through ``conform_tag_fields`` and
    ``escape_execution_parameters`` is applied to the call. After querying,
    the results go through ``unprepare_from_saved``. The response is stored
    as ``{"tasks": [...]}`` in ``call.result.data``.
    """
    conform_tag_fields(call, call.data)
    escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_all_ex"):
            found_tasks = Task.get_many_with_join(
                company=company_id,
                query_dict=call.data,
                # required in case projection is requested for public dataset/versions
                allow_public=True,
            )
        unprepare_from_saved(call, found_tasks)
        call.result.data = {"tasks": found_tasks}
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
    """Prepare task ``fields`` in place for saving and return the same dict.

    Tag fields are validated via ``conform_tag_fields``, parameters are
    processed by ``params_prepare_for_save`` (with ``previous_task``), and
    every string-valued ``script/<field>`` entry (for each field in
    ``task_script_fields``) has surrounding whitespace stripped.
    """
    conform_tag_fields(call, fields, validate=True)
    params_prepare_for_save(fields, previous_task=previous_task)

    # Strip all script fields (remove leading and trailing whitespace chars)
    # to avoid unusable names and paths.
    for script_field in task_script_fields:
        script_path = f"script/{script_field}"
        try:
            current = dpath.get(fields, script_path)
            if isinstance(current, str):
                current = current.strip()
            dpath.set(fields, script_path, current)
        except KeyError:
            # Field not present in this request; nothing to strip.
            pass

    return fields
def create(call):
    """Create a new ``Project`` for the calling identity and return its id.

    Fields are parsed from ``call.data`` against ``create_fields`` and tag
    fields are validated. ``created`` and ``last_update`` share one
    timestamp. The new id is stored as ``{"id": ...}`` in
    ``call.result.data``.
    """
    assert isinstance(call, APICall)
    identity = call.identity

    with translate_errors_context():
        project_fields = parse_from_call(call.data, create_fields, Project.get_fields())
        conform_tag_fields(call, project_fields, validate=True)

        timestamp = datetime.utcnow()
        new_project = Project(
            id=database.utils.id(),
            user=identity.user,
            company=identity.company,
            created=timestamp,
            last_update=timestamp,
            **project_fields,
        )
        with TimingContext("mongo", "projects_save"):
            new_project.save()

        call.result.data = {"id": new_project.id}
def prepare_for_save(call: APICall, fields: dict):
    """Prepare task ``fields`` in place for saving and return the same dict.

    The fields are passed through ``conform_tag_fields``. String values at
    ``script/<field>`` are stripped. Keys of ``execution/parameters`` (when
    present) are escaped with ``ParameterKeyEscaper.escape``.
    """
    conform_tag_fields(call, fields)

    # Strip all script fields (remove leading and trailing whitespace chars)
    # to avoid unusable names and paths.
    for script_field in task_script_fields:
        script_path = f"script/{script_field}"
        try:
            current = dpath.get(fields, script_path)
            if isinstance(current, str):
                current = current.strip()
            dpath.set(fields, script_path, current)
        except KeyError:
            pass

    raw_params = safe_get(fields, "execution/parameters")
    if raw_params is not None:
        # Escape keys to make them mongo-safe
        escaped = {
            ParameterKeyEscaper.escape(name): val for name, val in raw_params.items()
        }
        dpath.set(fields, "execution/parameters", escaped)

    return fields
def prepare_create_fields(call: APICall, valid_fields=None, output=None, previous_task: Task = None):
    """Parse and normalize the fields of a task-create request.

    :param call: the API call being served; fields are parsed from ``call.data``.
    :param valid_fields: field spec for ``parse_from_call``; defaults to
        ``create_fields``.
    :param output: optional existing ``Output`` whose destination is updated
        when ``output_dest`` is supplied; otherwise a new ``Output`` is built.
    :param previous_task: accepted for interface compatibility; unused here.
    :return: the parsed fields dict. A non-None ``output_dest`` is moved to
        ``fields["output"]``, string ``script/*`` values are stripped, and
        ``execution/parameters`` keys are stripped.
    """
    valid_fields = valid_fields if valid_fields is not None else create_fields
    # Copy before extending: task_fields is shared module-level state, and
    # adding to it in place would leak "output_dest" to every other user.
    t_fields = set(task_fields)
    t_fields.add("output_dest")
    fields = parse_from_call(call.data, valid_fields, t_fields)

    # Move output_dest to output.destination
    output_dest = fields.get("output_dest")
    if output_dest is not None:
        fields.pop("output_dest")
        if output:
            output.destination = output_dest
        else:
            output = Output(destination=output_dest)
        fields["output"] = output

    conform_tag_fields(call, fields)

    # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
    for field in task_script_fields:
        try:
            path = "script/%s" % field
            value = dpath.get(fields, path)
            if isinstance(value, six.string_types):
                value = value.strip()
            dpath.set(fields, path, value)
        except KeyError:
            pass

    parameters = safe_get(fields, "execution/parameters")
    if parameters is not None:
        # Strip whitespace from parameter names.
        parameters = {k.strip(): v for k, v in parameters.items()}
        dpath.set(fields, "execution/parameters", parameters)

    return fields
def get_all_ex(call: APICall):
    """Return projects matching ``call.data``, optionally with task statistics.

    Request keys read here (the whole ``call.data`` is also passed as the
    ``query_dict`` of ``Project.get_many_with_join``):
      - ``include_stats``: when falsy, only the projects are returned.
      - ``stats_for_state``: an ``EntityVisibility`` value restricting which
        state the stats are reported for; defaults to
        ``EntityVisibility.active.value``. A falsy value reports every
        ``EntityVisibility`` member.
      - ``non_public``: when truthy, public projects are not included.

    :raises errors.bad_request.FieldsValueError: if ``stats_for_state`` is not
        a valid ``EntityVisibility`` value.

    The response is stored as ``{"projects": [...]}`` in ``call.result.data``.
    With stats enabled, each project dict gets a ``"stats"`` key mapping each
    reported state value to ``{"total_runtime": ..., "status_count": {...}}``.
    """
    include_stats = call.data.get("include_stats")
    stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
    allow_public = not call.data.get("non_public", False)

    if stats_for_state:
        try:
            specific_state = EntityVisibility(stats_for_state)
        except ValueError:
            raise errors.bad_request.FieldsValueError(
                stats_for_state=stats_for_state)
    else:
        # No state filter: stats are reported for every EntityVisibility member.
        specific_state = None

    conform_tag_fields(call, call.data)
    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        projects = Project.get_many_with_join(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            allow_public=allow_public,
        )
        conform_output_tags(call, projects)

        if not include_stats:
            call.result.data = {"projects": projects}
            return

        ids = [project["id"] for project in projects]
        # Two aggregation pipelines over Task: one feeding status_count,
        # one feeding runtime (see below).
        status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
            call.identity.company, ids, specific_state=specific_state)

        # Zero for every option returned by get_options(TaskStatus), so
        # statuses with no tasks still appear in the counts.
        default_counts = dict.fromkeys(get_options(TaskStatus), 0)

        def set_default_count(entry):
            # Aggregated counts override the zero-filled defaults.
            return dict(default_counts, **entry)

        # status_count[project_id][section] -> {status: count}
        status_count = defaultdict(lambda: {})
        # Entries of result["counts"] are read by the archived-visibility key
        # and by "status"/"count" below (shape inferred from these accesses —
        # confirm against make_projects_get_all_pipelines).
        key = itemgetter(EntityVisibility.archived.value)
        for result in Task.aggregate(status_count_pipeline):
            # groupby only merges adjacent items, hence the sort by the same key.
            for k, group in groupby(sorted(result["counts"], key=key), key):
                section = (EntityVisibility.archived if k else EntityVisibility.active).value
                status_count[result["_id"]][section] = set_default_count({
                    count_entry["status"]: count_entry["count"]
                    for count_entry in group
                })

        # runtime[project_id] -> all non-"_id" fields of the aggregation row;
        # presumably keyed by section value, given the "<project_id>/<section>"
        # lookup in get_status_counts — TODO confirm.
        runtime = {
            result["_id"]: {k: v for k, v in result.items() if k != "_id"}
            for result in Task.aggregate(runtime_pipeline)
        }

    # Local helper (shadows any module-level safe_get): dpath lookup that
    # returns ``default`` instead of raising KeyError for a missing path.
    def safe_get(obj, path, default=None):
        try:
            return dpath.get(obj, path)
        except KeyError:
            return default

    def get_status_counts(project_id, section):
        # dpath path of the form "<project_id>/<section>"
        path = "/".join((project_id, section))
        return {
            "total_runtime": safe_get(runtime, path, 0),
            # NOTE(review): the same default_counts dict object is returned for
            # every project/section with no data; it must not be mutated later.
            "status_count": safe_get(status_count, path, default_counts),
        }

    # All EntityVisibility members, or only the requested one.
    report_for_states = [
        s for s in EntityVisibility if not specific_state or specific_state == s
    ]

    for project in projects:
        project["stats"] = {
            task_state.value: get_status_counts(project["id"], task_state.value)
            for task_state in report_for_states
        }

    call.result.data = {"projects": projects}
def parse_model_fields(call, valid_fields):
    """Parse ``call.data`` into model fields and validate their tag fields.

    :param call: the API call whose ``data`` is parsed.
    :param valid_fields: the field spec handed to ``parse_from_call``.
    :return: the parsed fields dict, after ``conform_tag_fields`` with
        ``validate=True``.
    """
    parsed = parse_from_call(call.data, valid_fields, Model.get_fields())
    conform_tag_fields(call, parsed, validate=True)
    return parsed