Ejemplo n.º 1
0
    def _parse_errors(self, ve, params):
        """
        Parse the error messages given by marshmallow.

        Marshamllow error structure:

        {
            "list_param": {
                0: {
                    "value": {
                        0: [err message for first item in value list]
                        i: [err message for i-th item in value list]
                    }
                },
                i-th value object: {
                    "value": {
                        0: [...],
                        ...
                    }
                },
            }
            "nonlist_param": {
                0: {
                    "value": [err message]
                },
                ...
            }
        }

        self._errors structure:
        {
            "messages": {
                "param": [
                    ["value": {0: [msg0, msg1, ...], other_bad_ix: ...},
                     "label0": {0: msg, ...} // if errors on label values.
                ],
                ...
            },
            "label": {
                "param": [
                    {label_name: label_value, other_label_name: other_label_value},
                    ...
                    // list indices correspond to the error messages' indices
                    // of the error messages caused by the value of this value
                    // object.
                ]
            }
        }

        """
        error_info = {
            "messages": defaultdict(dict),
            "labels": defaultdict(dict),
        }

        for pname, data in ve.messages.items():
            if pname == "_schema":
                error_info["messages"]["schema"] = [
                    f"Data format error: {data}"
                ]
                continue
            if data == ["Unknown field."]:
                error_info["messages"]["schema"] = [f"Unknown field: {pname}"]
                continue
            param_data = utils.ensure_value_object(params[pname])
            error_labels = []
            formatted_errors = []
            for ix, marshmessages in data.items():
                error_labels.append(
                    utils.filter_labels(param_data[ix], drop=["value"]))
                formatted_errors_ix = []
                for _, messages in marshmessages.items():
                    if messages:
                        if isinstance(messages, list):
                            formatted_errors_ix += messages
                        else:
                            for _, messagelist in messages.items():
                                formatted_errors_ix += messagelist
                formatted_errors.append(formatted_errors_ix)
            error_info["messages"][pname] = formatted_errors
            error_info["labels"][pname] = error_labels

        self._errors.update(dict(error_info))
Ejemplo n.º 2
0
    def extend(
        self,
        label_to_extend=None,
        label_to_extend_values=None,
        params=None,
        raise_errors=True,
    ):
        """
        Extend parameters along label_to_extend.

        For every group of value objects of a parameter, grid values of
        ``label_to_extend`` that the group does not define are filled in:
        - the first grid value is copied from the group's earliest defined
          value (in grid order);
        - any other missing value is copied from the value objects at the
          previous grid value, either those defined in the group or those
          just created for it.
        Each copy is passed through ``self.extend_func``. All new value
        objects are then applied with ``self._adjust(..., extend_adj=False)``.

        Args:
            label_to_extend: label to extend along. Defaults to
                ``self.label_to_extend``.
            label_to_extend_values: grid of label values to fill. Defaults to
                ``self._stateless_label_grid[label_to_extend]``; because of
                the ``or``, an empty value also falls back to the default.
                Existing label values are looked up with
                ``extend_grid.index``, so they must appear in this grid.
            params: optional collection of parameter names to restrict the
                extension to. By default all specification parameters are
                considered.
            raise_errors: forwarded to ``self._adjust``.

        Raises:
            InconsistentLabelsException: Value objects do not have consistent
                labels. (Not raised directly in this method; presumably from
                a helper it calls — verify.)
        """
        if label_to_extend is None:
            label_to_extend = self.label_to_extend

        spec = self.specification(meta_data=True)
        if params is not None:
            # Restrict to the requested parameters, taking their data from
            # self._data rather than from the specification result.
            spec = {
                param: self._data[param]
                for param, data in spec.items() if param in params
            }
        extend_grid = (label_to_extend_values
                       or self._stateless_label_grid[label_to_extend])
        # param name -> value objects created by this call.
        adjustment = defaultdict(list)
        for param, data in spec.items():
            # Parameters that never use the label have nothing to extend.
            if not any(label_to_extend in vo for vo in data["value"]):
                continue
            # Hashes of value objects already covered for this parameter, so
            # each group is processed only once.
            extended_vos = set()
            # Visit value objects in grid order of label_to_extend.
            for vo in sorted(
                    data["value"],
                    key=lambda val: extend_grid.index(val[label_to_extend]),
            ):
                hashable_vo = utils.hashable_value_object(vo)
                if hashable_vo in extended_vos:
                    continue
                else:
                    extended_vos.add(hashable_vo)
                # Presumably the value objects whose label_to_extend value
                # comes after vo's in the grid — verify select_gt_ix.
                gt = select_gt_ix(
                    self._data[param]["value"],
                    False,
                    {label_to_extend: vo[label_to_extend]},
                    extend_grid,
                    tree=self._search_trees.get(param),
                )
                # Narrow those to vo's group: same values for every label
                # other than label_to_extend.
                eq = select_eq(
                    gt,
                    False,
                    utils.filter_labels(vo, drop=["value", label_to_extend]),
                )
                extended_vos.update(map(utils.hashable_value_object, eq))
                eq += [vo]

                # Grid values already defined by this group.
                defined_vals = {eq_vo[label_to_extend] for eq_vo in eq}

                # Grid values still to be filled, in grid order.
                missing_vals = sorted(
                    set(extend_grid) - defined_vals,
                    key=lambda val: extend_grid.index(val),
                )

                if not missing_vals:
                    continue

                # grid value -> value objects created for it in this group.
                extended = defaultdict(list)

                for val in missing_vals:
                    eg_ix = extend_grid.index(val)
                    if eg_ix == 0:
                        # Nothing precedes the first grid value, so copy from
                        # the earliest value the group defines.
                        first_defined_value = min(
                            defined_vals,
                            key=lambda val: extend_grid.index(val),
                        )
                        value_objects = select_eq(
                            eq, True, {label_to_extend: first_defined_value})
                    elif extend_grid[eg_ix - 1] in extended:
                        # The previous grid value was just filled in; chain
                        # from the value objects created for it.
                        value_objects = extended.pop(extend_grid[eg_ix - 1])
                    else:
                        # Otherwise copy from the group's value objects at the
                        # previous grid value.
                        prev_defined_value = extend_grid[eg_ix - 1]
                        value_objects = select_eq(
                            eq, True, {label_to_extend: prev_defined_value})
                    # In practice, value_objects has length one.
                    # Theoretically, there could be multiple if the inital value
                    # object had less labels than later value objects and thus
                    # matched multiple value objects.
                    for value_object in value_objects:
                        ext = dict(value_object, **{label_to_extend: val})
                        ext = self.extend_func(
                            param,
                            ext,
                            value_object,
                            extend_grid,
                            label_to_extend,
                        )
                        extended_vos.add(
                            utils.hashable_value_object(value_object))
                        extended[val].append(ext)
                        adjustment[param].append(ext)
        # Ensure that the adjust method of paramtools.Parameter is used
        # in case the child class also implements adjust.
        self._adjust(adjustment, extend_adj=False, raise_errors=raise_errors)
Ejemplo n.º 3
0
    def _adjust(self, params_or_path, raise_errors=True, extend_adj=True):
        """
        Internal method for performing adjustments.
        """
        params = self.read_params(params_or_path)

        # Validate user adjustments.
        parsed_params = {}
        try:
            parsed_params = self._validator_schema.load(params)
        except MarshmallowValidationError as ve:
            self._parse_errors(ve, params)

        if not self._errors:
            if self.label_to_extend is not None and extend_adj:
                extend_grid = self._stateless_label_grid[self.label_to_extend]
                to_delete = defaultdict(list)
                backup = {}
                for param, vos in parsed_params.items():
                    for vo in utils.grid_sort(vos, self.label_to_extend,
                                              extend_grid):
                        if self.label_to_extend in vo:
                            if (vo[self.label_to_extend] not in
                                    self.label_grid[self.label_to_extend]):
                                msg = (
                                    f"{param}[{self.label_to_extend}={vo[self.label_to_extend]}] "
                                    f"is not active in the current state: "
                                    f"{self.label_to_extend}= "
                                    f"{self.label_grid[self.label_to_extend]}."
                                )
                                warnings.warn(msg)
                            gt = select_gt_ix(
                                self._data[param]["value"],
                                True,
                                {
                                    self.label_to_extend:
                                    vo[self.label_to_extend]
                                },
                                extend_grid,
                                tree=self._search_trees.get(param),
                            )
                            eq = select_eq(
                                gt,
                                True,
                                utils.filter_labels(
                                    vo, drop=[self.label_to_extend, "value"]),
                            )
                            to_delete[param] += [
                                dict(td, **{"value": None}) for td in eq
                            ]
                    # make copy of value objects since they
                    # are about to be modified
                    backup[param] = copy.deepcopy(self._data[param]["value"])
                try:
                    array_first = self.array_first
                    self.array_first = False

                    # delete params that will be overwritten out by extend.
                    self._adjust(to_delete,
                                 extend_adj=False,
                                 raise_errors=True)

                    # set user adjustments.
                    self._adjust(parsed_params,
                                 extend_adj=False,
                                 raise_errors=True)
                    self.extend(params=parsed_params.keys(), raise_errors=True)
                except ValidationError:
                    for param in backup:
                        self._data[param]["value"] = backup[param]
                finally:
                    self.array_first = array_first
            else:
                for param, value in parsed_params.items():
                    self._update_param(param, value)

        self._validator_schema.context["spec"] = self

        if raise_errors and self._errors:
            raise self.validation_error

        # Update attrs for params that were adjusted.
        self._set_state(params=parsed_params.keys())

        return parsed_params
Ejemplo n.º 4
0
    def extend(
        self,
        label: Optional[str] = None,
        label_values: Optional[List[Any]] = None,
        params: Optional[List[str]] = None,
        raise_errors: bool = True,
        ignore_warnings: bool = False,
    ):
        """
        Extend parameters along `label`.

        For every group of value objects of a parameter (value objects that
        agree on all labels except `label` and `_auto`), each value of the
        extend grid that the group does not define is filled in. The fill is
        copied from the closest key found by `SortedKeyList.lte`, or by
        `SortedKeyList.gte` if `lte` finds nothing, over the label's
        `cmp_funcs["key"]` ordering. The copy is passed through
        `self.extend_func` and tagged `_auto=True`. All new value objects are
        applied with `self._adjust(..., extend_adj=False,
        is_deserialized=True)`.

        **Parameters**

        - `label`: Label to extend values along. By default, `label_to_extend`
          is used.
        - `label_values`: values of `label` to extend. They are parsed with
          `self.parse_labels`. By default, this is the stateless label grid
          of `label`.
        - `params`: Parameters to extend. By default, all parameters are extended.
        - `raise_errors`: Whether `adjust` should raise or store errors.
        - `ignore_warnings`: Whether `adjust` should raise or ignore warnings.

        **Raises**

          - `InconsistentLabelsException`: Value objects do not have consistent
            labels. (Not raised directly here; presumably from a helper —
            verify.)
        """
        if label is None:
            label = self.label_to_extend

        spec = self.specification(meta_data=True)
        if params is not None:
            # Restrict to the requested parameters, taking their data from
            # self._data rather than from the specification result.
            spec = {
                param: self._data[param] for param in spec if param in params
            }
        # extend_func always receives the full grid, even when only a subset
        # of label values is being extended.
        full_extend_grid = self._stateless_label_grid[label]
        if label_values is not None:
            labels = self.parse_labels(**{label: label_values})
            extend_grid = labels[label]
        else:
            extend_grid = full_extend_grid

        cmp_funcs = self.label_validators[label].cmp_funcs(choices=extend_grid)
        gt_cmp_func = make_cmp_func(cmp_funcs["gt"], all_or_any=all)

        # param name -> value objects created by this call.
        adjustment = defaultdict(list)
        for param, data in spec.items():
            # Parameters that never use the label have nothing to extend.
            if not any(label in vo for vo in data["value"]):
                continue
            # Hashes of value objects already covered, so each group is
            # processed only once.
            extended_vos = set()
            for vo in sorted(data["value"],
                             key=lambda val: cmp_funcs["key"](val[label])):
                hashable_vo = utils.hashable_value_object(vo)
                if hashable_vo in extended_vos:
                    continue
                extended_vos.add(hashable_vo)
                # Value objects "greater" than vo along `label` (per
                # cmp_funcs["gt"]) that belong to vo's group.
                gt = select(
                    self._data[param]["value"],
                    False,
                    gt_cmp_func,
                    {label: vo[label]},
                    tree=self._search_trees.get(param),
                )
                eq = select_eq(
                    gt,
                    False,
                    utils.filter_labels(vo, drop=["value", label, "_auto"]),
                )
                extended_vos.update(map(utils.hashable_value_object, eq))
                eq += [vo]

                defined_vals = {eq_vo[label] for eq_vo in eq}

                missing_vals = sorted(set(extend_grid) - defined_vals,
                                      key=cmp_funcs["key"])

                if not missing_vals:
                    continue

                # label value -> value objects available to copy from.
                extended = defaultdict(list)
                for eq_vo in eq:
                    extended[eq_vo[label]].append(eq_vo)

                skl = utils.SortedKeyList(extended.keys(), cmp_funcs["key"])

                for val in missing_vals:
                    lte_val = skl.lte(val)
                    if lte_val is not None:
                        closest_val = lte_val
                    else:
                        closest_val = skl.gte(val)

                    if closest_val in extended:
                        value_objects = extended.pop(closest_val)
                    else:
                        value_objects = select_eq(eq, False,
                                                  {label: closest_val})
                    # In practice, value_objects has length one.
                    # Theoretically, there could be multiple if the initial
                    # value object had fewer labels than later value objects
                    # and thus matched multiple value objects.
                    for value_object in value_objects:
                        ext = dict(value_object, **{label: val})
                        ext = self.extend_func(param, ext, value_object,
                                               full_extend_grid, label)
                        extended_vos.add(
                            utils.hashable_value_object(value_object))
                        extended[val].append(ext)
                        skl.insert(val)
                        adjustment[param].append(dict(ext, _auto=True))
        # Ensure that the adjust method of paramtools.Parameter is used
        # in case the child class also implements adjust.
        self._adjust(
            adjustment,
            extend_adj=False,
            ignore_warnings=ignore_warnings,
            raise_errors=raise_errors,
            is_deserialized=True,
        )
Ejemplo n.º 5
0
    def _adjust(
        self,
        params_or_path,
        ignore_warnings=False,
        raise_errors=True,
        extend_adj=True,
        is_deserialized=False,
        clobber=True,
    ):
        """
        Internal method for performing adjustments.
        """
        # Validate user adjustments.
        if is_deserialized:
            parsed_params = {}
            try:
                parsed_params = self._validator_schema.load(
                    params_or_path, ignore_warnings, is_deserialized=True)
            except MarshmallowValidationError as ve:
                self._parse_validation_messages(ve.messages, params_or_path)
        else:
            params = self.read_params(params_or_path)
            parsed_params = {}
            try:
                parsed_params = self._validator_schema.load(
                    params, ignore_warnings)
            except MarshmallowValidationError as ve:
                self._parse_validation_messages(ve.messages, params)

        if not self._errors:
            if self.label_to_extend is not None and extend_adj:
                extend_grid = self._stateless_label_grid[self.label_to_extend]
                to_delete = defaultdict(list)
                backup = {}
                cmp_funcs = self.label_validators[
                    self.label_to_extend].cmp_funcs()
                gt_cmp_func = make_cmp_func(cmp_funcs["gt"], all_or_any=all)
                for param, vos in parsed_params.items():
                    for vo in utils.grid_sort(vos, self.label_to_extend,
                                              extend_grid):

                        if self.label_to_extend in vo:
                            query_args = {
                                self.label_to_extend: vo[self.label_to_extend]
                            }
                            if clobber:
                                queryset = self._data[param]["value"]
                                tree = self._search_trees.get(param)
                            else:
                                queryset = self.select_eq(param,
                                                          strict=True,
                                                          _auto=True)
                                tree = None
                            gt = select(
                                queryset,
                                strict=False,
                                cmp_func=gt_cmp_func,
                                labels=query_args,
                                tree=tree,
                            )
                            to_delete[param] += select_eq(
                                gt,
                                strict=False,
                                labels=utils.filter_labels(
                                    vo,
                                    drop=[
                                        self.label_to_extend,
                                        "value",
                                        "_auto",
                                    ],
                                ),
                            )
                    # make copy of value objects since they
                    # are about to be modified
                    backup[param] = copy.deepcopy(self._data[param]["value"])
                try:
                    array_first = self.array_first
                    self.array_first = False

                    # delete params that will be overwritten out by extend.
                    self.delete(
                        to_delete,
                        extend_adj=False,
                        raise_errors=True,
                        ignore_warnings=ignore_warnings,
                    )

                    # set user adjustments.
                    self._adjust(
                        parsed_params,
                        extend_adj=False,
                        raise_errors=True,
                        ignore_warnings=ignore_warnings,
                    )
                    self.extend(
                        params=parsed_params.keys(),
                        ignore_warnings=ignore_warnings,
                        raise_errors=True,
                    )
                except ValidationError:
                    for param in backup:
                        self._data[param]["value"] = backup[param]
                finally:
                    self.array_first = array_first
            else:
                for param, value in parsed_params.items():
                    self._update_param(param, value)

        self._validator_schema.context["spec"] = self

        has_errors = bool(self._errors.get("messages"))
        has_warnings = bool(self._warnings.get("messages"))
        # throw error if raise_errors is True or ignore_warnings is False
        if (raise_errors and has_errors) or (not ignore_warnings
                                             and has_warnings):
            raise self.validation_error

        # Update attrs for params that were adjusted.
        self._set_state(params=parsed_params.keys())

        return parsed_params
Ejemplo n.º 6
0
    def adjust(self, params_or_path, raise_errors=True, extend_adj=True):
        """
        Deserialize and validate parameter adjustments. `params_or_path`
        can be a file path or a `dict` that has not been fully deserialized.
        The adjusted values replace the current values stored in the
        corresponding parameter attributes.

        When ``self.label_to_extend`` is set and ``extend_adj`` is True, each
        adjusted parameter is processed in three steps:
        1. value objects the adjustment supersedes along the extend label are
           deleted;
        2. the adjustment is applied (both steps with ``array_first`` False);
        3. the parameter is re-extended.
        If any step raises ``ValidationError``, that parameter's previous
        values are restored. ``self.array_first`` is always restored, even
        when an adjustment fails.

        Returns: parsed, validated parameters.

        Raises:
            marshmallow.exceptions.ValidationError if data is not valid.

            ParameterUpdateException if label values do not match at
                least one existing value item's corresponding label values.
                (Not raised directly here; presumably from _update_param —
                verify.)
        """
        params = self.read_params(params_or_path)

        # Validate user adjustments.
        parsed_params = {}
        try:
            parsed_params = self._validator_schema.load(params)
        except MarshmallowValidationError as ve:
            self._parse_errors(ve, params)

        if not self._errors:
            if self.label_to_extend is not None and extend_adj:
                extend_grid = self._stateless_label_grid[self.label_to_extend]
                for param, vos in parsed_params.items():
                    to_delete = []
                    for vo in utils.grid_sort(vos, self.label_to_extend,
                                              extend_grid):
                        if self.label_to_extend in vo:
                            if (vo[self.label_to_extend] not in
                                    self.label_grid[self.label_to_extend]):
                                msg = (
                                    f"{param}[{self.label_to_extend}={vo[self.label_to_extend]}] "
                                    f"is not active in the current state: "
                                    f"{self.label_to_extend}= "
                                    f"{self.label_grid[self.label_to_extend]}."
                                )
                                warnings.warn(msg)
                            gt = select_gt_ix(
                                self._data[param]["value"],
                                True,
                                {
                                    self.label_to_extend:
                                    vo[self.label_to_extend]
                                },
                                extend_grid,
                            )
                            eq = select_eq(
                                gt,
                                True,
                                utils.filter_labels(
                                    vo, drop=[self.label_to_extend, "value"]),
                            )
                            to_delete += eq
                    # A value of None marks a value object for deletion.
                    to_delete = [
                        dict(td, **{"value": None}) for td in to_delete
                    ]
                    # make copy of value objects since they
                    # are about to be modified
                    backup = copy.deepcopy(self._data[param]["value"])
                    array_first = self.array_first
                    try:
                        self.array_first = False
                        try:
                            # delete params that will be overwritten out by
                            # extend.
                            self.adjust(
                                {param: to_delete},
                                extend_adj=False,
                                raise_errors=True,
                            )
                            # set user adjustments.
                            self.adjust({param: vos},
                                        extend_adj=False,
                                        raise_errors=True)
                        finally:
                            # Restore the flag even when an adjustment fails;
                            # previously it stayed False after an error.
                            self.array_first = array_first
                        # extend user adjustments (with the caller's
                        # array_first setting, as before).
                        self.extend(params=[param], raise_errors=True)
                    except ValidationError:
                        self._data[param]["value"] = backup
            else:
                for param, value in parsed_params.items():
                    self._update_param(param, value)

        self._validator_schema.context["spec"] = self

        if raise_errors and self._errors:
            raise self.validation_error

        # Update attrs for params that were adjusted.
        self._set_state(params=parsed_params.keys())

        return parsed_params