Beispiel #1
0
 def _db_dicts_from_list(
         cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
     """
     Group hyper param items by section into a db-ready nested dict.

     The result maps each section name to a dict of
     parameter name -> ParamsItem built from the item's ``to_struct()``.
     Both section and parameter names are escaped with ParameterKeyEscaper.
     """
     grouped = iterutils.bucketize(items, key=attrgetter("section"))
     db_dicts = {}
     for section_name, section_items in grouped.items():
         escaped_section = ParameterKeyEscaper.escape(section_name)
         db_dicts[escaped_section] = {
             ParameterKeyEscaper.escape(item.name):
             ParamsItem(**item.to_struct())
             for item in section_items
         }
     return db_dicts
Beispiel #2
0
def params_prepare_for_save(fields: dict, previous_task: Task = None):
    """
    Prepare task fields for saving: convert legacy params and escape keys.

    For each legacy field present in ``fields`` (``execution.parameters``
    for hyper params, ``execution.model_desc`` for configuration), every
    entry is written as an item into the new structure (``hyperparams`` /
    ``configuration``). The legacy field is then removed from ``fields``
    with ``nested_delete``. When ``fields`` has no value for the new field
    but ``previous_task`` has one, the previously stored data (after passing
    through ``_remove_legacy_params``) is used as the base that the legacy
    entries are merged into.

    Finally, all section and param names of ``hyperparams`` and
    ``configuration`` are escaped with ``ParameterKeyEscaper`` so that they
    are mongo safe. ``fields`` is modified in place.
    """
    legacy_mappings = (
        (("execution", "parameters"), "hyperparams",
         hyperparams_default_section),
        (("execution", "model_desc"), "configuration", None),
    )
    for legacy_path, target_field, default_section in legacy_mappings:
        legacy_values = nested_get(fields, legacy_path)
        if legacy_values is None:
            continue

        if (
            not fields.get(target_field)
            and previous_task
            and previous_task[target_field]
        ):
            stored = previous_task.to_proper_dict().get(target_field)
            had_legacy = _remove_legacy_params(
                stored, with_sections=default_section is not None
            )
            # Nothing new to add and no legacy leftovers in the stored data:
            # there is no point in rewriting the new field
            if not (legacy_values or had_legacy):
                continue

            base_update = {target_field: stored}
            params_unprepare_from_saved(base_update)
            fields.update(base_update)

        for full_name, raw_value in legacy_values.items():
            section, name = split_param_name(full_name, default_section)
            # drop empty path parts (e.g. no section for configuration)
            path = [part for part in (target_field, section, name) if part]
            param = {
                "name": name,
                "type": hyperparams_legacy_type,
                "value": str(raw_value),
            }
            if section is not None:
                param["section"] = section
            nested_set(fields, path, param)
        nested_delete(fields, legacy_path)

    for field_name in ("hyperparams", "configuration"):
        current = fields.get(field_name)
        if not current:
            continue
        fields[field_name] = {
            ParameterKeyEscaper.escape(outer_key): (
                {
                    ParameterKeyEscaper.escape(inner_key): inner_value
                    for inner_key, inner_value in outer_value.items()
                }
                if isinstance(outer_value, dict)
                else outer_value
            )
            for outer_key, outer_value in current.items()
        }
Beispiel #3
0
    def edit_configuration(
        cls,
        company_id: str,
        task_id: str,
        configuration: Sequence[Configuration],
        replace_configuration: bool,
        force: bool,
    ) -> int:
        """
        Set configuration items on a task.

        Each passed item is converted to a ConfigurationItem and keyed by its
        name escaped with ParameterKeyEscaper. With ``replace_configuration``
        a single ``set__configuration`` command carries all items. Otherwise
        every item gets its own ``set__configuration__<name>`` command, with
        the name additionally passed through ``mongoengine_safe``.
        Returns the result of ``update_task``.
        """
        with TimingContext("mongo", "edit_configuration"):
            task = get_task_for_update(
                company_id=company_id, task_id=task_id, force=force
            )

            items_by_name = {
                ParameterKeyEscaper.escape(c.name):
                ConfigurationItem(**c.to_struct())
                for c in configuration
            }

            if replace_configuration:
                update_cmds = {"set__configuration": items_by_name}
            else:
                update_cmds = {
                    f"set__configuration__{mongoengine_safe(name)}": item
                    for name, item in items_by_name.items()
                }

            return update_task(task, update_cmds=update_cmds)
Beispiel #4
0
    def metadata_from_api(
        api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
    ) -> dict:
        """
        Convert metadata received from the API into a db-ready dict.

        ``api_data`` is either a mapping of key -> MetadataItem or a sequence
        of MetadataItem objects that each carry their own ``key``. In both
        cases the result maps the key, escaped with ParameterKeyEscaper, to
        the item's ``to_struct()`` representation. Empty or missing input
        yields an empty dict.
        """
        if not api_data:
            return {}

        # Check against Mapping (not only dict) so every mapping type accepted
        # by the signature lands here; iterating a non-dict mapping in the
        # sequence branch below would yield its keys and fail on ``item.key``.
        if isinstance(api_data, Mapping):
            return {
                ParameterKeyEscaper.escape(k): v.to_struct()
                for k, v in api_data.items()
            }

        return {
            ParameterKeyEscaper.escape(item.key): item.to_struct()
            for item in api_data
        }
Beispiel #5
0
 def _upgrade_model_data(model_data: dict) -> dict:
     metadata_key = "metadata"
     metadata = model_data.get(metadata_key)
     if isinstance(metadata, list):
         metadata = {
             ParameterKeyEscaper.escape(item["key"]): item
             for item in metadata
             if isinstance(item, dict) and "key" in item
         }
         model_data[metadata_key] = metadata
     return model_data
Beispiel #6
0
def escape_metadata(document: dict):
    """
    Escape the keys of ``document["metadata"]`` in place.

    Every metadata key is passed through ParameterKeyEscaper.escape; the
    values are kept as is. Nothing happens when metadata is missing or empty.
    """
    original = document.get("metadata")
    if original:
        escape = ParameterKeyEscaper.escape
        document["metadata"] = {
            escape(key): value for key, value in original.items()
        }
Beispiel #7
0
def _process_path(path: str):
    """
    Re-escape a dotted task field path received from the frontend.

    The frontend escapes the path only partially: the '.' characters inside
    section and key names are escaped, so splitting on '.' yields the path
    parts. Each part is then unescaped and fully escaped again with
    ParameterKeyEscaper, which makes it mongo safe.

    Raises errors.bad_request.ValidationError unless the path has between
    2 and 4 parts.
    """
    parts = path.split(".")
    if not 2 <= len(parts) <= 4:
        raise errors.bad_request.ValidationError("invalid task field",
                                                 path=path)

    escaped_parts = [
        ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(part))
        for part in parts
    ]
    return ".".join(escaped_parts)
Beispiel #8
0
    def delete_params(
        cls,
        company_id: str,
        task_id: str,
        hyperparams: Sequence[HyperParamKey],
        force: bool,
    ) -> int:
        """
        Delete hyper params (whole sections or single params) from a task.

        A key without a name removes its whole section. A key with a name
        removes only that parameter. Section and parameter names are escaped
        with ParameterKeyEscaper when the unset commands are built.

        Raises errors.bad_request.FieldsConflict when a single parameter is
        requested from a section that is also being deleted as a whole.
        Returns the result of ``update_task``.
        """
        with TimingContext("mongo", "delete_hyperparams"):
            properties_only = cls._normalize_params(hyperparams)
            task = get_task_for_update(
                company_id=company_id,
                task_id=task_id,
                allow_all_statuses=properties_only,
                force=force,
            )

            named_keys, section_keys = iterutils.partition(
                hyperparams, key=lambda hp: bool(hp.name)
            )
            whole_sections = set(hp.section for hp in section_keys)

            unset_cmds = {}
            for section_name in whole_sections:
                escaped = ParameterKeyEscaper.escape(section_name)
                unset_cmds[f"unset__hyperparams__{escaped}"] = 1

            for hp in named_keys:
                escaped_section = ParameterKeyEscaper.escape(hp.section)
                # deleting a single param from a section that is removed
                # entirely is ambiguous, so it is rejected
                if hp.section in whole_sections:
                    raise errors.bad_request.FieldsConflict(
                        "Cannot delete section field if the whole section was scheduled for deletion"
                    )
                escaped_name = ParameterKeyEscaper.escape(hp.name)
                unset_cmds[
                    f"unset__hyperparams__{escaped_section}__{escaped_name}"
                ] = 1

            return update_task(
                task,
                update_cmds=unset_cmds,
                set_last_update=not properties_only,
            )
Beispiel #9
0
def escape_dict(data: dict) -> dict:
    """
    Return a new dict with every top-level key escaped by ParameterKeyEscaper.

    Values are not modified. Falsy input (e.g. an empty dict or None) is
    returned as is.
    """
    if data:
        return {
            ParameterKeyEscaper.escape(key): value
            for key, value in data.items()
        }
    return data