def delete_params(
    cls, company_id: str, task_id: str, hyperparams: Sequence[HyperParamKey]
) -> int:
    """
    Remove hyper parameters (or whole hyper parameter sections) from a task.

    A key that carries a name removes that single parameter; a key with an
    empty name removes its entire section.

    :param company_id: company that the task is looked up under
    :param task_id: id of the task to update
    :param hyperparams: keys (section + optional name) of the items to delete
    :return: the result of the task update call
    :raise errors.bad_request.FieldsConflict: if a single parameter is
        requested for deletion from a section that is itself being deleted
    """
    # NOTE: 'hyperparams' used to be declared as 'hyperparams=Sequence[...]',
    # i.e. the typing object was the default *value*; it is an annotation now
    properties_only = cls._normalize_params(hyperparams)
    task = cls._get_task_for_update(
        company=company_id, id=task_id, allow_all_statuses=properties_only
    )

    # Keys with a name target one parameter, keys without it a whole section
    with_param, without_param = iterutils.partition(
        hyperparams, key=lambda p: bool(p.name)
    )
    sections_to_delete = {p.section for p in without_param}
    delete_cmds = {
        f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
        for section in sections_to_delete
    }

    for item in with_param:
        # Validate against the raw section name before doing any escaping
        if item.section in sections_to_delete:
            raise errors.bad_request.FieldsConflict(
                "Cannot delete section field if the whole section was scheduled for deletion"
            )
        section = ParameterKeyEscaper.escape(item.section)
        name = ParameterKeyEscaper.escape(item.name)
        delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1

    return task.update(**delete_cmds, last_update=datetime.utcnow())
def params_unprepare_from_saved(fields, copy_to_legacy=False):
    """
    Restore, in place, the unescaped section and param names of the
    "hyperparams" and "configuration" entries of ``fields``.

    If copy_to_legacy is set then the hyperparams and configuration data is
    also copied to the legacy locations ("execution/parameters" and
    "execution/model_desc") for the old clients.
    """

    def _unescape_entry(entry):
        # Only dict values have nested keys that need to be unescaped
        if not isinstance(entry, dict):
            return entry
        return {ParameterKeyEscaper.unescape(k): v for k, v in entry.items()}

    for param_field in ("hyperparams", "configuration"):
        stored = safe_get(fields, param_field)
        if not stored:
            continue
        restored = {
            ParameterKeyEscaper.unescape(key): _unescape_entry(value)
            for key, value in stored.items()
        }
        dpath.set(fields, param_field, restored)

    if not copy_to_legacy:
        return

    # (new location, legacy location, whether the new data has sections)
    legacy_layout = (
        ("hyperparams", "execution/parameters", True),
        ("configuration", "execution/model_desc", False),
    )
    for new_params_field, old_params_field, use_sections in legacy_layout:
        legacy_params = _get_legacy_params(
            safe_get(fields, new_params_field), with_sections=use_sections
        )
        if not legacy_params:
            continue
        legacy_values = {_get_full_param_name(p): p["value"] for p in legacy_params}
        dpath.new(fields, old_params_field, legacy_values)
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
    """
    Group the passed items by section and build a two level dictionary
    (escaped section name -> escaped param name -> ParamsItem).
    """
    by_section = iterutils.bucketize(items, key=attrgetter("section"))

    db_dicts = {}
    for section, members in by_section.items():
        escaped_section = ParameterKeyEscaper.escape(section)
        db_dicts[escaped_section] = {
            ParameterKeyEscaper.escape(member.name): ParamsItem(**member.to_struct())
            for member in members
        }
    return db_dicts
def _process_path(path: str):
    """
    The frontend only partially escapes the path: every '.' inside section
    and key names is escaped. Each path part is therefore unescaped and
    then re-escaped with the full mongo escaping.

    A path must consist of 2 or 3 dot-separated parts, otherwise a
    validation error is raised.
    """
    segments = path.split(".")
    if not 2 <= len(segments) <= 3:
        raise errors.bad_request.ValidationError("invalid task field", path=path)

    escaped_segments = [
        ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(segment))
        for segment in segments
    ]
    return ".".join(escaped_segments)
def params_prepare_for_save(fields: dict, previous_task: Task = None):
    """
    If legacy hyper params or configuration is passed then replace the corresponding section in the new structure
    Escape all the section and param names for hyper params and configuration to make it mongo safe

    :param fields: the task fields that are about to be saved; updated in place
    :param previous_task: optional stored task; when the new-style field is not
        passed in ``fields`` its stored data is used as the base that the legacy
        params are merged into
    """
    # (legacy location, new location, section used for legacy params)
    for old_params_field, new_params_field, default_section in (
        ("execution/parameters", "hyperparams", hyperparams_default_section),
        ("execution/model_desc", "configuration", None),
    ):
        legacy_params = safe_get(fields, old_params_field)
        if legacy_params is None:
            # the legacy field was not passed at all, nothing to convert
            continue

        if (
            not safe_get(fields, new_params_field)
            and previous_task
            and previous_task[new_params_field]
        ):
            # The new-style field is missing from the request: start from the
            # data stored on the previous task
            previous_data = previous_task.to_proper_dict().get(new_params_field)
            # NOTE(review): presumably strips the legacy-typed entries from
            # previous_data in place and reports whether any were removed - confirm
            removed = _remove_legacy_params(
                previous_data, with_sections=default_section is not None
            )
            if not legacy_params and not removed:
                # if we only need to delete legacy fields from the db
                # but they are not there then there is no point to proceed
                continue

            # Stored names are escaped; unescape them so that they get escaped
            # exactly once by the final loop of this function
            fields_update = {new_params_field: previous_data}
            params_unprepare_from_saved(fields_update)
            fields.update(fields_update)

        for full_name, value in legacy_params.items():
            section, name = split_param_name(full_name, default_section)
            # filter(None, ...) drops a missing section from the target path
            new_path = list(filter(None, (new_params_field, section, name)))
            new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
            if section is not None:
                new_param["section"] = section
            dpath.new(fields, new_path, new_param)
        # the legacy field was converted, it should not be saved as is
        dpath.delete(fields, old_params_field)

    # Escape the first level keys and, for dict values, the second level keys
    for param_field in ("hyperparams", "configuration"):
        params = safe_get(fields, param_field)
        if params:
            escaped_params = {
                ParameterKeyEscaper.escape(key): {
                    ParameterKeyEscaper.escape(k): v for k, v in value.items()
                }
                if isinstance(value, dict)
                else value
                for key, value in params.items()
            }
            dpath.set(fields, param_field, escaped_params)
def get_configuration_names(
    cls, company_id: str, task_ids: Sequence[str]
) -> Dict[str, Dict[str, list]]:
    """
    Collect the names of the configuration items of the requested tasks.

    :param company_id: tasks are matched if their company is this one,
        None or an empty string
    :param task_ids: ids of the tasks to inspect
    :return: a mapping of task id to ``{"names": [...]}`` where the list holds
        the sorted, unescaped configuration item names. Tasks that yield no
        configuration items in the aggregation are not part of the result
    """
    pipeline = [
        {
            "$match": {
                "company": {"$in": [None, "", company_id]},
                "_id": {"$in": task_ids},
            }
        },
        # Turn the configuration document into an array of {k, v} pairs
        {"$project": {"items": {"$objectToArray": "$configuration"}}},
        {"$unwind": "$items"},
        # Collect the distinct item keys per task
        {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
    ]

    tasks = Task.aggregate(pipeline)

    return {
        task["_id"]: {
            "names": sorted(
                ParameterKeyEscaper.unescape(name) for name in task["names"]
            )
        }
        for task in tasks
    }
def edit_configuration(
    cls,
    company_id: str,
    task_id: str,
    configuration: Sequence[Configuration],
    replace_configuration: bool,
) -> int:
    """
    Store the passed configuration items on a task.

    With replace_configuration set the whole task configuration is replaced
    by the passed items, otherwise each passed item is set individually under
    its escaped name. Returns the result of the task update call.
    """
    task = cls._get_task_for_update(company=company_id, id=task_id)

    escaped_items = {
        ParameterKeyEscaper.escape(entry.name): ConfigurationItem(**entry.to_struct())
        for entry in configuration
    }

    if replace_configuration:
        update_cmds = {"set__configuration": escaped_items}
    else:
        update_cmds = {
            f"set__configuration__{name}": item
            for name, item in escaped_items.items()
        }

    return task.update(**update_cmds, last_update=datetime.utcnow())
def get_aggregated_project_parameters(
    company_id,
    project_ids: Sequence[str] = None,
    page: int = 0,
    page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
    """
    Return one page of the distinct (section, name) hyper parameter pairs
    found on the company tasks, optionally limited to the given projects.

    :param company_id: company whose tasks are aggregated
    :param project_ids: when not empty only tasks of these projects are used
    :param page: zero based page number (negative values are treated as 0)
    :param page_size: page size (values below 1 are treated as 1)
    :return: a tuple of (total number of distinct pairs, number of pairs left
        after the returned page, the page as a list of dicts with the
        unescaped "section" and "name")
    """
    page = max(0, page)
    page_size = max(1, page_size)
    offset = page * page_size

    match_filter = {
        "company": company_id,
        "hyperparams": {"$exists": True, "$gt": {}},
    }
    if project_ids:
        match_filter["project"] = {"$in": project_ids}

    # Flatten hyperparams into one document per (section, name) occurrence
    flatten_stages = [
        {"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
        {"$unwind": "$sections"},
        {
            "$project": {
                "section": "$sections.k",
                "names": {"$objectToArray": "$sections.v"},
            }
        },
        {"$unwind": "$names"},
    ]
    # Deduplicate and order the pairs, then cut the requested page out of them
    paging_stages = [
        {"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
        {"$sort": OrderedDict([("_id.section", 1), ("_id.name", 1)])},
        {
            "$group": {
                "_id": 1,
                "total": {"$sum": 1},
                "results": {"$push": "$$ROOT"},
            }
        },
        {
            "$project": {
                "total": 1,
                "results": {"$slice": ["$results", offset, page_size]},
            }
        },
    ]
    pipeline = [{"$match": match_filter}, *flatten_stages, *paging_stages]

    with translate_errors_context():
        aggregated = next(Task.aggregate(pipeline), None)

    if not aggregated:
        return 0, 0, []

    total = int(aggregated.get("total", -1))
    results = [
        {
            "section": ParameterKeyEscaper.unescape(dpath.get(entry, "_id/section")),
            "name": ParameterKeyEscaper.unescape(dpath.get(entry, "_id/name")),
        }
        for entry in aggregated.get("results", [])
    ]
    remaining = max(0, total - (len(results) + offset))

    return total, remaining, results