Example #1
0
def params_unprepare_from_saved(fields, copy_to_legacy=False):
    """
    Undo the mongo-safe escaping of section and param names in the
    "hyperparams" and "configuration" entries of ``fields`` (in place).

    Top-level keys are always unescaped; for values that are dicts the nested
    keys are unescaped too.

    :param fields: task fields dict, modified in place
    :param copy_to_legacy: when set, the hyperparams and configuration data are
        additionally written to the legacy "execution.parameters" and
        "execution.model_desc" locations for old clients
    """
    for field_name in ("hyperparams", "configuration"):
        saved = fields.get(field_name)
        if not saved:
            continue
        restored = {}
        for outer_key, outer_value in saved.items():
            restored_key = ParameterKeyEscaper.unescape(outer_key)
            if isinstance(outer_value, dict):
                outer_value = {
                    ParameterKeyEscaper.unescape(inner_key): inner_value
                    for inner_key, inner_value in outer_value.items()
                }
            restored[restored_key] = outer_value
        fields[field_name] = restored

    if not copy_to_legacy:
        return

    # (new field, legacy path, whether the new field is organized in sections)
    legacy_targets = (
        ("hyperparams", ("execution", "parameters"), True),
        ("configuration", ("execution", "model_desc"), False),
    )
    for source_field, legacy_path, has_sections in legacy_targets:
        legacy_items = _get_legacy_params(
            fields.get(source_field), with_sections=has_sections
        )
        if not legacy_items:
            continue
        nested_set(
            fields,
            legacy_path,
            {_get_full_param_name(item): item["value"] for item in legacy_items},
        )
Example #2
0
 def _db_dicts_from_list(
         cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
     """
     Convert a flat list of hyper parameter items into the nested
     {section: {param name: ParamsItem}} structure.

     Items are grouped by their ``section`` attribute; both the section names
     and the param names are escaped with ParameterKeyEscaper.
     """
     grouped = iterutils.bucketize(items, key=attrgetter("section"))
     db_dicts = {}
     for section_name, section_items in grouped.items():
         escaped_section = ParameterKeyEscaper.escape(section_name)
         db_dicts[escaped_section] = {
             ParameterKeyEscaper.escape(item.name): ParamsItem(**item.to_struct())
             for item in section_items
         }
     return db_dicts
Example #3
0
def _process_path(path: str):
    """
    Normalize a task field path received from the frontend.

    The frontend escapes only the '.' characters inside section and key names,
    so the remaining dots separate the path parts. Each part is unescaped and
    then re-escaped with the full mongo-safe ParameterKeyEscaper escaping.

    :raises errors.bad_request.ValidationError: if the path does not consist
        of 2 to 4 dot-separated parts
    """
    segments = path.split(".")
    if not 2 <= len(segments) <= 4:
        raise errors.bad_request.ValidationError("invalid task field", path=path)

    normalized = (
        ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(segment))
        for segment in segments
    )
    return ".".join(normalized)
Example #4
0
def params_prepare_for_save(fields: dict, previous_task: Task = None):
    """
    Prepare the hyper params and configuration of ``fields`` for saving
    (``fields`` is modified in place).

    If legacy hyper params ("execution.parameters") or configuration
    ("execution.model_desc") are passed, they are converted into the
    corresponding section of the new structure ("hyperparams" /
    "configuration") and the legacy field is removed from ``fields``.
    Afterwards all the section and param names of "hyperparams" and
    "configuration" are escaped to make them mongo safe.

    :param fields: the task fields about to be saved
    :param previous_task: the currently stored task, if any. Its data is used as
        the base when legacy params are passed without the new field.
    """
    # (legacy field path, new field name, default section for legacy names)
    for old_params_field, new_params_field, default_section in (
        (("execution", "parameters"), "hyperparams",
         hyperparams_default_section),
        (("execution", "model_desc"), "configuration", None),
    ):
        legacy_params = nested_get(fields, old_params_field)
        if legacy_params is None:
            continue

        # The new field was not passed, so start from the previously stored data
        if (not fields.get(new_params_field) and previous_task
                and previous_task[new_params_field]):
            previous_data = previous_task.to_proper_dict().get(
                new_params_field)
            # NOTE(review): presumably strips the params that were previously
            # created from legacy data, returning whether anything was removed
            # — confirm against _remove_legacy_params
            removed = _remove_legacy_params(previous_data,
                                            with_sections=default_section
                                            is not None)
            if not legacy_params and not removed:
                # if we only need to delete legacy fields from the db
                # but they are not there then there is no point to proceed
                # NOTE(review): this `continue` also skips the nested_delete
                # below, so the (empty) legacy field stays in `fields`
                continue

            # The stored data is already escaped; unescape it here because the
            # escaping loop at the end of this function escapes everything again
            fields_update = {new_params_field: previous_data}
            params_unprepare_from_saved(fields_update)
            fields.update(fields_update)

        for full_name, value in legacy_params.items():
            section, name = split_param_name(full_name, default_section)
            # filter(None, ...) drops a missing section so that the param is
            # set directly under the new field
            new_path = list(filter(None, (new_params_field, section, name)))
            # legacy values are stored as strings with the legacy type marker
            new_param = dict(name=name,
                             type=hyperparams_legacy_type,
                             value=str(value))
            if section is not None:
                new_param["section"] = section
            nested_set(fields, new_path, new_param)
        nested_delete(fields, old_params_field)

    # Escape the top-level names and, for dict values, the nested names too
    for param_field in ("hyperparams", "configuration"):
        params = fields.get(param_field)
        if params:
            escaped_params = {
                ParameterKeyEscaper.escape(key):
                {ParameterKeyEscaper.escape(k): v
                 for k, v in value.items()}
                if isinstance(value, dict) else value
                for key, value in params.items()
            }
            fields[param_field] = escaped_params
Example #5
0
    def edit_configuration(
        cls,
        company_id: str,
        task_id: str,
        configuration: Sequence[Configuration],
        replace_configuration: bool,
        force: bool,
    ) -> int:
        """
        Add or replace configuration items of a task.

        Item names are escaped with ParameterKeyEscaper. With
        ``replace_configuration`` the whole "configuration" field is set to the
        passed items. Otherwise each passed item is set under its own key (the
        escaped name additionally goes through ``mongoengine_safe``).

        :return: the result of ``update_task``
        """
        with TimingContext("mongo", "edit_configuration"):
            task = get_task_for_update(
                company_id=company_id, task_id=task_id, force=force
            )

            items_by_name = {}
            for config in configuration:
                escaped_name = ParameterKeyEscaper.escape(config.name)
                items_by_name[escaped_name] = ConfigurationItem(**config.to_struct())

            if replace_configuration:
                update_cmds = {"set__configuration": items_by_name}
            else:
                update_cmds = {
                    f"set__configuration__{mongoengine_safe(item_name)}": item
                    for item_name, item in items_by_name.items()
                }

            return update_task(task, update_cmds=update_cmds)
Example #6
0
    def metadata_from_api(
        api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
    ) -> dict:
        """
        Convert metadata received from the API into a dict keyed by escaped keys.

        :param api_data: either a mapping of key -> metadata item, or a sequence
            of metadata items, each carrying its own ``key``
        :return: {ParameterKeyEscaper-escaped key: item.to_struct()}; an empty
            dict for empty input
        """
        from collections import abc

        if not api_data:
            return {}

        # Test against the abstract Mapping rather than dict: any mapping
        # allowed by the annotation must be handled here. Otherwise it would
        # fall through to the sequence branch, iterating its keys and failing
        # on `item.key`.
        if isinstance(api_data, abc.Mapping):
            pairs = api_data.items()
        else:
            pairs = ((item.key, item) for item in api_data)

        return {
            ParameterKeyEscaper.escape(key): item.to_struct()
            for key, item in pairs
        }
Example #7
0
 def _upgrade_model_data(model_data: dict) -> dict:
     metadata_key = "metadata"
     metadata = model_data.get(metadata_key)
     if isinstance(metadata, list):
         metadata = {
             ParameterKeyEscaper.escape(item["key"]): item
             for item in metadata
             if isinstance(item, dict) and "key" in item
         }
         model_data[metadata_key] = metadata
     return model_data
Example #8
0
def escape_metadata(document: dict):
    """
    Escape special characters in the keys of ``document["metadata"]``
    (in place) with ParameterKeyEscaper.

    Documents whose metadata is missing or empty are left untouched.
    """
    metadata = document.get("metadata")
    if metadata:
        document["metadata"] = {
            ParameterKeyEscaper.escape(key): value
            for key, value in metadata.items()
        }
Example #9
0
    def delete_params(
        cls,
        company_id: str,
        task_id: str,
        hyperparams: Sequence[HyperParamKey],
        force: bool,
    ) -> int:
        """
        Delete hyper parameters of a task.

        Keys without a name schedule the deletion of their whole section; keys
        with a name delete that single parameter. Section and parameter names
        are escaped with ParameterKeyEscaper.

        :raises errors.bad_request.FieldsConflict: if a named parameter belongs
            to a section that is scheduled for deletion as a whole
        :return: the result of ``update_task``
        """
        with TimingContext("mongo", "delete_hyperparams"):
            properties_only = cls._normalize_params(hyperparams)
            task = get_task_for_update(
                company_id=company_id,
                task_id=task_id,
                allow_all_statuses=properties_only,
                force=force,
            )

            named_keys, section_keys = iterutils.partition(
                hyperparams, key=lambda key: bool(key.name)
            )
            whole_sections = {key.section for key in section_keys}

            delete_cmds = {}
            for section_name in whole_sections:
                escaped = ParameterKeyEscaper.escape(section_name)
                delete_cmds[f"unset__hyperparams__{escaped}"] = 1

            for key in named_keys:
                escaped_section = ParameterKeyEscaper.escape(key.section)
                if key.section in whole_sections:
                    raise errors.bad_request.FieldsConflict(
                        "Cannot delete section field if the whole section was scheduled for deletion"
                    )
                escaped_name = ParameterKeyEscaper.escape(key.name)
                delete_cmds[
                    f"unset__hyperparams__{escaped_section}__{escaped_name}"
                ] = 1

            return update_task(
                task,
                update_cmds=delete_cmds,
                set_last_update=not properties_only,
            )
Example #10
0
    def get_configuration_names(cls, company_id: str,
                                task_ids: Sequence[str]) -> Dict[str, dict]:
        """
        Return the configuration item names of the given tasks.

        Only tasks that belong to the company, or that have no company
        (None / ""), are considered. The names are unescaped with
        ParameterKeyEscaper and sorted. Tasks without any configuration item
        produce no entry, because $unwind drops them.

        :param company_id: company of the tasks
        :param task_ids: ids of the tasks to query
        :return: {task id: {"names": [sorted configuration item names]}}
        """
        with TimingContext("mongo", "get_configuration_names"):
            pipeline = [
                {
                    "$match": {
                        "company": {
                            "$in": [None, "", company_id]
                        },
                        "_id": {
                            "$in": task_ids
                        },
                    }
                },
                {
                    # configuration is a dict keyed by the (escaped) item name
                    "$project": {
                        "items": {
                            "$objectToArray": "$configuration"
                        }
                    }
                },
                {
                    "$unwind": "$items"
                },
                {
                    "$group": {
                        "_id": "$_id",
                        "names": {
                            "$addToSet": "$items.k"
                        }
                    }
                },
            ]

            tasks = Task.aggregate(pipeline)

            return {
                task["_id"]: {
                    "names": sorted(
                        ParameterKeyEscaper.unescape(name)
                        for name in task["names"]
                    )
                }
                for task in tasks
            }
Example #11
0
def unescape_metadata(call: APICall, documents: Union[dict, Sequence[dict]]):
    """
    Unescape special characters in the metadata keys of the given document(s),
    in place.

    For calls whose requested endpoint version is 2.16 or lower, the
    "metadata" field of every document that has one is replaced with an empty
    list instead.
    """
    docs = [documents] if isinstance(documents, dict) else documents
    is_old_client = call.requested_endpoint_version <= PartialVersion("2.16")

    for doc in docs:
        if is_old_client and "metadata" in doc:
            doc["metadata"] = []
        else:
            metadata = doc.get("metadata")
            if metadata:
                doc["metadata"] = {
                    ParameterKeyEscaper.unescape(key): value
                    for key, value in metadata.items()
                }
Example #12
0
    def get_aggregated_project_parameters(
        cls,
        company_id,
        project_ids: Sequence[str],
        include_subprojects: bool,
        page: int = 0,
        page_size: int = 500,
    ) -> Tuple[int, int, Sequence[dict]]:
        """
        Return a page of the distinct hyper parameter (section, name) pairs
        used by the tasks of the given projects, ordered by section and name.

        :param company_id: company whose tasks are scanned (the exact filter is
            built by ``cls._get_company_constraint``)
        :param project_ids: projects whose tasks are scanned
        :param include_subprojects: forwarded to ``cls._get_project_constraint``
        :param page: zero-based page number; negative values are treated as 0
        :param page_size: pairs per page; values below 1 are treated as 1
        :return: (total number of distinct pairs, number of pairs after this
            page, list of {"section": ..., "name": ...} dicts with unescaped
            names)
        """
        page = max(0, page)
        page_size = max(1, page_size)
        pipeline = [
            {
                "$match": {
                    **cls._get_company_constraint(company_id),
                    **cls._get_project_constraint(project_ids, include_subprojects),
                    "hyperparams": {
                        "$exists": True,
                        "$gt": {}
                    },
                }
            },
            {
                "$project": {
                    "sections": {
                        "$objectToArray": "$hyperparams"
                    }
                }
            },
            {
                "$unwind": "$sections"
            },
            {
                "$project": {
                    "section": "$sections.k",
                    "names": {
                        "$objectToArray": "$sections.v"
                    },
                }
            },
            {
                "$unwind": "$names"
            },
            {
                "$group": {
                    "_id": {
                        "section": "$section",
                        "name": "$names.k"
                    }
                }
            },
            {
                # Count all the distinct pairs and extract the requested page
                # in one pass. Counting after $skip/$limit would count only
                # the current page, so "total" would equal the page length
                # and "remaining" would always be 0. Only the page itself is
                # materialized, which keeps the output document small.
                "$facet": {
                    "total": [
                        {
                            "$count": "count"
                        }
                    ],
                    "results": [
                        {
                            "$sort": OrderedDict({
                                "_id.section": 1,
                                "_id.name": 1
                            })
                        },
                        {
                            "$skip": page * page_size
                        },
                        {
                            "$limit": page_size
                        },
                    ],
                }
            },
        ]

        # $facet emits a single document; `or {}` guards against no output
        result = next(Task.aggregate(pipeline), None) or {}

        counts = result.get("total") or []
        total = int(counts[0].get("count", 0)) if counts else 0
        results = [
            {
                "section": ParameterKeyEscaper.unescape(
                    nested_get(r, ("_id", "section"))
                ),
                "name": ParameterKeyEscaper.unescape(
                    nested_get(r, ("_id", "name"))
                ),
            }
            for r in result.get("results") or []
        ]
        remaining = max(0, total - (len(results) + page * page_size))

        return total, remaining, results
Example #13
0
    def get_model_metadata_keys(
        cls,
        company_id,
        project_ids: Sequence[str],
        include_subprojects: bool,
        page: int = 0,
        page_size: int = 500,
    ) -> Tuple[int, int, Sequence[dict]]:
        """
        Return a page of the distinct metadata keys used by the models of the
        given projects, in ascending order.

        :param company_id: company whose models are scanned (the exact filter
            is built by ``cls._get_company_constraint``)
        :param project_ids: projects whose models are scanned
        :param include_subprojects: forwarded to ``cls._get_project_constraint``
        :param page: zero-based page number; negative values are treated as 0
        :param page_size: keys per page; values below 1 are treated as 1
        :return: (total number of distinct keys, number of keys after this
            page, list of keys unescaped with ParameterKeyEscaper)
        """
        page = max(0, page)
        page_size = max(1, page_size)
        pipeline = [
            {
                "$match": {
                    **cls._get_company_constraint(company_id),
                    **cls._get_project_constraint(project_ids, include_subprojects),
                    "metadata": {
                        "$exists": True,
                        "$gt": {}
                    },
                }
            },
            {
                "$project": {
                    "metadata": {
                        "$objectToArray": "$metadata"
                    }
                }
            },
            {
                "$unwind": "$metadata"
            },
            {
                "$group": {
                    "_id": "$metadata.k"
                }
            },
            {
                # Count all the distinct keys and extract the requested page
                # in one pass. Counting after $skip/$limit would count only
                # the current page, so "total" would equal the page length
                # and "remaining" would always be 0.
                "$facet": {
                    "total": [
                        {
                            "$count": "count"
                        }
                    ],
                    "results": [
                        {
                            "$sort": {
                                "_id": 1
                            }
                        },
                        {
                            "$skip": page * page_size
                        },
                        {
                            "$limit": page_size
                        },
                    ],
                }
            },
        ]

        # $facet emits a single document; `or {}` guards against no output
        result = next(Model.aggregate(pipeline), None) or {}

        counts = result.get("total") or []
        total = int(counts[0].get("count", 0)) if counts else 0
        results = [
            ParameterKeyEscaper.unescape(r.get("_id"))
            for r in result.get("results") or []
        ]
        remaining = max(0, total - (len(results) + page * page_size))

        return total, remaining, results
Example #14
0
def unescape_dict(data: dict) -> dict:
    """
    Return a new dict whose top-level keys are unescaped with
    ParameterKeyEscaper. Falsy input (None or an empty dict) is returned
    unchanged, as the same object.
    """
    if data:
        return {
            ParameterKeyEscaper.unescape(key): value
            for key, value in data.items()
        }
    return data
Example #15
0
    def get_aggregated_project_parameters(
        company_id,
        project_ids: Sequence[str] = None,
        page: int = 0,
        page_size: int = 500,
    ) -> Tuple[int, int, Sequence[dict]]:
        """
        Return a page of the distinct hyper parameter (section, name) pairs
        used by the tasks of the company (or public tasks), ordered by section
        and name.

        :param company_id: company whose tasks (plus tasks with no company)
            are scanned
        :param project_ids: restrict the scan to these projects; when empty or
            None no project filter is applied
        :param page: zero-based page number; negative values are treated as 0
        :param page_size: pairs per page; values below 1 are treated as 1
        :return: (total number of distinct pairs, number of pairs after this
            page, list of {"section": ..., "name": ...} dicts with unescaped
            names)
        """
        page = max(0, page)
        page_size = max(1, page_size)
        pipeline = [
            {
                "$match": {
                    "company": {
                        "$in": [None, "", company_id]
                    },
                    "hyperparams": {
                        "$exists": True,
                        "$gt": {}
                    },
                    **({
                        "project": {
                            "$in": project_ids
                        }
                    } if project_ids else {}),
                }
            },
            {
                "$project": {
                    "sections": {
                        "$objectToArray": "$hyperparams"
                    }
                }
            },
            {
                "$unwind": "$sections"
            },
            {
                "$project": {
                    "section": "$sections.k",
                    "names": {
                        "$objectToArray": "$sections.v"
                    },
                }
            },
            {
                "$unwind": "$names"
            },
            {
                "$group": {
                    "_id": {
                        "section": "$section",
                        "name": "$names.k"
                    }
                }
            },
            {
                # Count all the distinct pairs and extract only the requested
                # page. Pushing every pair into a single grouped document and
                # slicing it afterwards can exceed MongoDB's 16MB document
                # size limit for projects with many parameters.
                "$facet": {
                    "total": [
                        {
                            "$count": "count"
                        }
                    ],
                    "results": [
                        {
                            "$sort": OrderedDict({
                                "_id.section": 1,
                                "_id.name": 1
                            })
                        },
                        {
                            "$skip": page * page_size
                        },
                        {
                            "$limit": page_size
                        },
                    ],
                }
            },
        ]

        with translate_errors_context():
            # $facet emits a single document; `or {}` guards against no output
            result = next(Task.aggregate(pipeline), None) or {}

        counts = result.get("total") or []
        total = int(counts[0].get("count", 0)) if counts else 0
        results = [
            {
                "section": ParameterKeyEscaper.unescape(dpath.get(r, "_id/section")),
                "name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
            }
            for r in result.get("results") or []
        ]
        remaining = max(0, total - (len(results) + page * page_size))

        return total, remaining, results