Ejemplo n.º 1
0
def test_select_eq(vos):
    # Scalar label values must match exactly.
    single = list(select_eq(vos, True, labels={"d0": 1, "d1": "hello"}))
    assert single == [{"d0": 1, "d1": "hello", "value": 1}]

    # A list of label values matches any of the listed values.
    multi = list(select_eq(vos, True, labels={"d0": [1, 2], "d1": "hello"}))
    assert multi == [
        {"d0": 1, "d1": "hello", "value": 1},
        {"d0": 2, "d1": "hello", "value": 1},
    ]
Ejemplo n.º 2
0
 def select_eq(self, param, exact_match, **labels):
     """Query the stored value objects of ``param`` by label equality.

     Thin wrapper around the module-level ``select_eq`` helper: the keyword
     arguments are gathered into a single ``labels`` dict, and the parameter's
     search tree (if one is registered) is forwarded as ``tree``.
     """
     value_objects = self._data[param]["value"]
     search_tree = self._search_trees.get(param)
     return select_eq(value_objects, exact_match, labels, tree=search_tree)
Ejemplo n.º 3
0
 def select_eq(self, param, strict=True, **labels):
     """Query the stored value objects of ``param`` by label equality.

     Delegates to the module-level ``select_eq`` helper with the keyword
     arguments collected into one ``labels`` dict. ``strict`` is forwarded
     unchanged; the parameter's search tree (if registered) is passed as
     ``tree``.
     """
     search_tree = self._search_trees.get
     return select_eq(self._data[param]["value"], strict, labels,
                      tree=search_tree(param))
Ejemplo n.º 4
0
def test_select_eq_strict(vos):
    # Strict query on scalar label values.
    exact = list(select_eq(vos, True, labels={"d0": 1, "d1": "hello"}))
    assert exact == [{"d0": 1, "d1": "hello", "value": 1}]

    # Strict query where one label lists several acceptable values.
    either = list(
        select_eq(vos, True, labels={"d0": [1, 2], "d1": "hello"})
    )
    assert either == [
        {"d0": 1, "d1": "hello", "value": 1},
        {"d0": 2, "d1": "hello", "value": 1},
    ]

    # Flag two value objects as auto-generated, then run a non-strict query
    # for ``_auto=False``; the expected result is the two objects below.
    for ix in (2, 3):
        vos[ix]["_auto"] = True
    not_auto = list(select_eq(vos, False, labels={"_auto": False}))
    assert not_auto == [
        {"d0": 1, "d1": "hello", "value": 1},
        {"d0": 1, "d1": "world", "value": 1},
    ]
Ejemplo n.º 5
0
    def extend(
        self,
        label_to_extend=None,
        label_to_extend_values=None,
        params=None,
        raise_errors=True,
    ):
        """
        Extend parameters along label_to_extend.

        For each selected parameter whose value objects use
        ``label_to_extend``, every value of the extension grid that has no
        value object is filled by copying (through ``self.extend_func``) a
        neighbouring value object. The first grid value copies the earliest
        defined value; every other missing value copies the value object at
        the preceding grid position. The generated value objects are then
        applied with ``self._adjust(..., extend_adj=False)``.

        Args:
            label_to_extend: Label to extend along. Defaults to
                ``self.label_to_extend``.
            label_to_extend_values: Grid of label values to extend over.
                Falls back to ``self._stateless_label_grid[label_to_extend]``
                when not given (or when falsy, because of the ``or``).
            params: Optional collection of parameter names to restrict the
                extension to. Defaults to every parameter in the
                specification.
            raise_errors: Forwarded to ``self._adjust``.

        Raises:
            InconsistentLabelsException: Value objects do not have consistent
                labels. (NOTE(review): not raised directly here; presumably
                raised by a helper or by ``_adjust`` — confirm.)
        """
        if label_to_extend is None:
            label_to_extend = self.label_to_extend

        spec = self.specification(meta_data=True)
        if params is not None:
            # Restrict to the requested parameters, taking their entries from
            # self._data.
            spec = {
                param: self._data[param]
                for param, data in spec.items() if param in params
            }
        extend_grid = (label_to_extend_values
                       or self._stateless_label_grid[label_to_extend])
        # param name -> list of generated value objects to apply at the end.
        adjustment = defaultdict(list)
        for param, data in spec.items():
            # Parameters that never use the label have nothing to extend.
            if not any(label_to_extend in vo for vo in data["value"]):
                continue
            # Hashable forms of value objects already handled, so that each
            # group of related value objects is processed only once.
            extended_vos = set()
            # Walk value objects in extension-grid order.
            # NOTE(review): the key raises KeyError if some value object of
            # this param lacks the label, and ValueError if its label value is
            # not in extend_grid.
            for vo in sorted(
                    data["value"],
                    key=lambda val: extend_grid.index(val[label_to_extend]),
            ):
                hashable_vo = utils.hashable_value_object(vo)
                if hashable_vo in extended_vos:
                    continue
                else:
                    extended_vos.add(hashable_vo)
                # Stored value objects selected by select_gt_ix relative to
                # vo's label value (presumably those later in extend_grid).
                gt = select_gt_ix(
                    self._data[param]["value"],
                    False,
                    {label_to_extend: vo[label_to_extend]},
                    extend_grid,
                    tree=self._search_trees.get(param),
                )
                # Of those, keep the ones that share vo's other labels; vo
                # plus these form one group to extend together.
                eq = select_eq(
                    gt,
                    False,
                    utils.filter_labels(vo, drop=["value", label_to_extend]),
                )
                extended_vos.update(map(utils.hashable_value_object, eq))
                eq += [vo]

                # Label values already present in this group.
                defined_vals = {eq_vo[label_to_extend] for eq_vo in eq}

                # Grid values with no value object, in grid order.
                missing_vals = sorted(
                    set(extend_grid) - defined_vals,
                    key=lambda val: extend_grid.index(val),
                )

                if not missing_vals:
                    continue

                # grid value -> value objects generated for it in this group.
                extended = defaultdict(list)

                for val in missing_vals:
                    eg_ix = extend_grid.index(val)
                    if eg_ix == 0:
                        # Nothing precedes the first grid value: copy from the
                        # earliest defined value instead.
                        first_defined_value = min(
                            defined_vals,
                            key=lambda val: extend_grid.index(val),
                        )
                        value_objects = select_eq(
                            eq, True, {label_to_extend: first_defined_value})
                    elif extend_grid[eg_ix - 1] in extended:
                        # The preceding grid value was itself generated in
                        # this loop: chain from it (popped so it is used once).
                        value_objects = extended.pop(extend_grid[eg_ix - 1])
                    else:
                        # The preceding grid value is defined in the group.
                        prev_defined_value = extend_grid[eg_ix - 1]
                        value_objects = select_eq(
                            eq, True, {label_to_extend: prev_defined_value})
                    # In practice, value_objects has length one.
                    # Theoretically, there could be multiple if the initial
                    # value object had fewer labels than later value objects
                    # and thus matched multiple value objects.
                    for value_object in value_objects:
                        ext = dict(value_object, **{label_to_extend: val})
                        ext = self.extend_func(
                            param,
                            ext,
                            value_object,
                            extend_grid,
                            label_to_extend,
                        )
                        extended_vos.add(
                            utils.hashable_value_object(value_object))
                        extended[val].append(ext)
                        adjustment[param].append(ext)
        # Ensure that the adjust method of paramtools.Parameter is used
        # in case the child class also implements adjust.
        self._adjust(adjustment, extend_adj=False, raise_errors=raise_errors)
Ejemplo n.º 6
0
    def _adjust(self, params_or_path, raise_errors=True, extend_adj=True):
        """
        Internal method for performing adjustments.

        Reads and validates ``params_or_path``. If validation recorded no
        errors and ``self.label_to_extend`` is set (and ``extend_adj`` is
        true), then:

        - existing value objects that each adjustment supersedes are deleted
          (adjusted to ``None``);
        - the adjustment itself is applied;
        - the adjusted parameters are re-extended.

        On ``ValidationError`` the previous values of every adjusted
        parameter are restored. Otherwise each parsed parameter is applied
        directly with ``self._update_param``.

        Args:
            params_or_path: Adjustment data, passed to ``self.read_params``.
            raise_errors: If true, raise ``self.validation_error`` when any
                errors were recorded.
            extend_adj: If false, skip the delete/extend handling and apply
                the parsed values directly.

        Returns:
            The parsed, validated parameters.
        """
        params = self.read_params(params_or_path)

        # Validate user adjustments.
        parsed_params = {}
        try:
            parsed_params = self._validator_schema.load(params)
        except MarshmallowValidationError as ve:
            # Errors are recorded on self rather than raised immediately.
            self._parse_errors(ve, params)

        if not self._errors:
            if self.label_to_extend is not None and extend_adj:
                extend_grid = self._stateless_label_grid[self.label_to_extend]
                # param name -> value objects to delete (value set to None).
                to_delete = defaultdict(list)
                # param name -> deep copy of the stored values, for rollback.
                backup = {}
                for param, vos in parsed_params.items():
                    for vo in utils.grid_sort(vos, self.label_to_extend,
                                              extend_grid):
                        if self.label_to_extend in vo:
                            # Warn when adjusting a label value outside the
                            # currently active state.
                            if (vo[self.label_to_extend] not in
                                    self.label_grid[self.label_to_extend]):
                                msg = (
                                    f"{param}[{self.label_to_extend}={vo[self.label_to_extend]}] "
                                    f"is not active in the current state: "
                                    f"{self.label_to_extend}= "
                                    f"{self.label_grid[self.label_to_extend]}."
                                )
                                warnings.warn(msg)
                            # Stored value objects selected by select_gt_ix
                            # relative to this label value (presumably later
                            # in extend_grid)...
                            gt = select_gt_ix(
                                self._data[param]["value"],
                                True,
                                {
                                    self.label_to_extend:
                                    vo[self.label_to_extend]
                                },
                                extend_grid,
                                tree=self._search_trees.get(param),
                            )
                            # ...that also share the adjustment's other
                            # labels; these get queued for deletion.
                            eq = select_eq(
                                gt,
                                True,
                                utils.filter_labels(
                                    vo, drop=[self.label_to_extend, "value"]),
                            )
                            to_delete[param] += [
                                dict(td, **{"value": None}) for td in eq
                            ]
                    # make copy of value objects since they
                    # are about to be modified
                    backup[param] = copy.deepcopy(self._data[param]["value"])
                try:
                    array_first = self.array_first
                    self.array_first = False

                    # delete params that will be overwritten out by extend.
                    self._adjust(to_delete,
                                 extend_adj=False,
                                 raise_errors=True)

                    # set user adjustments.
                    self._adjust(parsed_params,
                                 extend_adj=False,
                                 raise_errors=True)
                    self.extend(params=parsed_params.keys(), raise_errors=True)
                except ValidationError:
                    # Roll back every adjusted parameter; the recorded errors
                    # are still handled by the raise_errors check below.
                    for param in backup:
                        self._data[param]["value"] = backup[param]
                finally:
                    self.array_first = array_first
            else:
                for param, value in parsed_params.items():
                    self._update_param(param, value)

        self._validator_schema.context["spec"] = self

        if raise_errors and self._errors:
            raise self.validation_error

        # Update attrs for params that were adjusted.
        self._set_state(params=parsed_params.keys())

        return parsed_params
Ejemplo n.º 7
0
    def extend(
        self,
        label: Optional[str] = None,
        label_values: Optional[List[Any]] = None,
        params: Optional[List[str]] = None,
        raise_errors: bool = True,
        ignore_warnings: bool = False,
    ):
        """
        Extend parameters along `label`.

        For each parameter whose value objects use `label`, every value of the
        extension grid that has no value object gets one. It is copied
        (through `self.extend_func`) from the value object at the closest
        known value at or below it. "Known" means user-defined or already
        extended. When no such value exists below, the closest one above is
        used. Generated value objects are tagged `_auto=True` and applied
        with `self._adjust(..., extend_adj=False)`.

        **Parameters**

        - `label`: Label to extend values along. By default, `label_to_extend`
          is used.
        - `label_values`: values of `label` to extend. By default, this is a grid
          created from the valid values of `label_to_extend`.
        - `params`: Parameters to extend. By default, all parameters are extended.
        - `raise_errors`: Whether `adjust` should raise or store errors.
        - `ignore_warnings`: Whether `adjust` should raise or ignore warnings.

        **Raises**

          - `InconsistentLabelsException`: Value objects do not have consistent
            labels.
        """
        if label is None:
            label = self.label_to_extend

        spec = self.specification(meta_data=True)
        if params is not None:
            # Restrict to the requested parameters, taking their entries from
            # self._data.
            spec = {
                param: self._data[param]
                for param in spec if param in params
            }
        # The full grid is always what extend_func receives, even when only a
        # subset of label values is being extended.
        full_extend_grid = self._stateless_label_grid[label]
        if label_values is not None:
            labels = self.parse_labels(**{label: label_values})
            extend_grid = labels[label]
        else:
            extend_grid = full_extend_grid

        # cmp_funcs["key"] orders label values; cmp_funcs["gt"] is wrapped
        # into a predicate that must hold for all compared labels.
        cmp_funcs = self.label_validators[label].cmp_funcs(choices=extend_grid)
        gt_cmp_func = make_cmp_func(cmp_funcs["gt"], all_or_any=all)

        # param name -> list of generated value objects to apply at the end.
        adjustment = defaultdict(list)
        for param, data in spec.items():
            # Parameters that never use `label` have nothing to extend.
            if not any(label in vo for vo in data["value"]):
                continue
            # Hashable forms of value objects already covered, so each group
            # of related value objects is processed only once.
            extended_vos = set()
            for vo in sorted(data["value"],
                             key=lambda val: cmp_funcs["key"](val[label])):
                hashable_vo = utils.hashable_value_object(vo)
                if hashable_vo in extended_vos:
                    continue
                else:
                    extended_vos.add(hashable_vo)
                # Stored value objects whose `label` compares greater than
                # vo's (per the validator's "gt" comparison)...
                gt = select(
                    self._data[param]["value"],
                    False,
                    gt_cmp_func,
                    {label: vo[label]},
                    tree=self._search_trees.get(param),
                )
                # ...that share vo's other labels; vo plus these form the
                # group extended together.
                eq = select_eq(
                    gt,
                    False,
                    utils.filter_labels(vo, drop=["value", label, "_auto"]),
                )
                extended_vos.update(map(utils.hashable_value_object, eq))
                eq += [vo]

                defined_vals = {eq_vo[label] for eq_vo in eq}

                missing_vals = sorted(set(extend_grid) - defined_vals,
                                      key=cmp_funcs["key"])

                if not missing_vals:
                    continue

                # label value -> value objects known for it (defined or
                # generated so far). A distinct loop name keeps the outer `vo`
                # from being rebound.
                extended = defaultdict(list)
                for eq_vo in eq:
                    extended[eq_vo[label]].append(eq_vo)

                # Sorted index of known label values for nearest lookups.
                skl = utils.SortedKeyList(extended.keys(), cmp_funcs["key"])

                for val in missing_vals:
                    # Prefer the closest known value at or below `val`; fall
                    # back to the closest one above it.
                    lte_val = skl.lte(val)
                    if lte_val is not None:
                        closest_val = lte_val
                    else:
                        closest_val = skl.gte(val)

                    if closest_val in extended:
                        value_objects = extended.pop(closest_val)
                    else:
                        # Already popped by an earlier missing value: look the
                        # source value objects up in the group again.
                        value_objects = select_eq(eq, False,
                                                  {label: closest_val})
                    # In practice, value_objects has length one.
                    # Theoretically, there could be multiple if the initial
                    # value object had fewer labels than later value objects
                    # and thus matched multiple value objects.
                    for value_object in value_objects:
                        ext = dict(value_object, **{label: val})
                        ext = self.extend_func(param, ext, value_object,
                                               full_extend_grid, label)
                        extended_vos.add(
                            utils.hashable_value_object(value_object))
                        extended[val].append(ext)
                        skl.insert(val)
                        adjustment[param].append(dict(ext, _auto=True))
        # Ensure that the adjust method of paramtools.Parameter is used
        # in case the child class also implements adjust.
        self._adjust(
            adjustment,
            extend_adj=False,
            ignore_warnings=ignore_warnings,
            raise_errors=raise_errors,
            is_deserialized=True,
        )
Ejemplo n.º 8
0
    def _adjust(
        self,
        params_or_path,
        ignore_warnings=False,
        raise_errors=True,
        extend_adj=True,
        is_deserialized=False,
        clobber=True,
    ):
        """
        Internal method for performing adjustments.

        Args:
            params_or_path: Adjustment data. Loaded directly when
                ``is_deserialized`` is true; otherwise passed through
                ``self.read_params`` first.
            ignore_warnings: Forwarded to the schema loader and to nested
                calls. When false, recorded warnings cause
                ``self.validation_error`` to be raised.
            raise_errors: When true, recorded errors cause
                ``self.validation_error`` to be raised.
            extend_adj: When false (or when ``self.label_to_extend`` is None),
                parsed values are applied directly via ``self._update_param``
                with no delete/extend handling.
            is_deserialized: Whether ``params_or_path`` is already
                deserialized.
            clobber: Controls which stored value objects may be deleted
                before re-extension. When true, any stored value object whose
                label compares greater (per the label validator's ``gt``
                comparison) and that matches the adjustment's other labels is
                a candidate. When false, only value objects selected with
                ``_auto=True`` are candidates.

        Returns:
            The parsed, validated parameters.

        Raises:
            ``self.validation_error`` as governed by ``raise_errors`` and
            ``ignore_warnings``.
        """
        # Validate user adjustments.
        if is_deserialized:
            parsed_params = {}
            try:
                parsed_params = self._validator_schema.load(
                    params_or_path, ignore_warnings, is_deserialized=True)
            except MarshmallowValidationError as ve:
                self._parse_validation_messages(ve.messages, params_or_path)
        else:
            params = self.read_params(params_or_path)
            parsed_params = {}
            try:
                parsed_params = self._validator_schema.load(
                    params, ignore_warnings)
            except MarshmallowValidationError as ve:
                self._parse_validation_messages(ve.messages, params)

        if not self._errors:
            if self.label_to_extend is not None and extend_adj:
                extend_grid = self._stateless_label_grid[self.label_to_extend]
                # param name -> stored value objects to delete.
                to_delete = defaultdict(list)
                # param name -> deep copy of stored values, for rollback.
                backup = {}
                cmp_funcs = self.label_validators[
                    self.label_to_extend].cmp_funcs()
                gt_cmp_func = make_cmp_func(cmp_funcs["gt"], all_or_any=all)
                for param, vos in parsed_params.items():
                    for vo in utils.grid_sort(vos, self.label_to_extend,
                                              extend_grid):

                        if self.label_to_extend in vo:
                            query_args = {
                                self.label_to_extend: vo[self.label_to_extend]
                            }
                            if clobber:
                                # Consider every stored value object.
                                queryset = self._data[param]["value"]
                                tree = self._search_trees.get(param)
                            else:
                                # Only consider auto-generated value objects.
                                queryset = self.select_eq(param,
                                                          strict=True,
                                                          _auto=True)
                                tree = None
                            # Candidates whose label compares greater than the
                            # adjusted value...
                            gt = select(
                                queryset,
                                strict=False,
                                cmp_func=gt_cmp_func,
                                labels=query_args,
                                tree=tree,
                            )
                            # ...and that share the adjustment's other labels.
                            to_delete[param] += select_eq(
                                gt,
                                strict=False,
                                labels=utils.filter_labels(
                                    vo,
                                    drop=[
                                        self.label_to_extend,
                                        "value",
                                        "_auto",
                                    ],
                                ),
                            )
                    # make copy of value objects since they
                    # are about to be modified
                    backup[param] = copy.deepcopy(self._data[param]["value"])
                try:
                    array_first = self.array_first
                    self.array_first = False

                    # delete params that will be overwritten out by extend.
                    self.delete(
                        to_delete,
                        extend_adj=False,
                        raise_errors=True,
                        ignore_warnings=ignore_warnings,
                    )

                    # set user adjustments.
                    self._adjust(
                        parsed_params,
                        extend_adj=False,
                        raise_errors=True,
                        ignore_warnings=ignore_warnings,
                    )
                    self.extend(
                        params=parsed_params.keys(),
                        ignore_warnings=ignore_warnings,
                        raise_errors=True,
                    )
                except ValidationError:
                    # Roll back every adjusted parameter; the recorded errors
                    # are handled by the raise check below.
                    for param in backup:
                        self._data[param]["value"] = backup[param]
                finally:
                    self.array_first = array_first
            else:
                for param, value in parsed_params.items():
                    self._update_param(param, value)

        self._validator_schema.context["spec"] = self

        has_errors = bool(self._errors.get("messages"))
        has_warnings = bool(self._warnings.get("messages"))
        # throw error if raise_errors is True or ignore_warnings is False
        if (raise_errors and has_errors) or (not ignore_warnings
                                             and has_warnings):
            raise self.validation_error

        # Update attrs for params that were adjusted.
        self._set_state(params=parsed_params.keys())

        return parsed_params
Ejemplo n.º 9
0
    def sort_values(self, data=None, has_meta_data=True):
        """
        Sort value objects for all parameters in `data` according
        to the order specified in `schema`.

        Sorting happens in place on `data` (which is also returned). Within
        each label, value objects follow the order of that label's values in
        the label grid. Value objects that lack the label come first.

        **Parameters**

          - `data`: Parameter data to be sorted. This should be a
            `dict` of parameter names and values. If `data` is `None`,
            the current values will be sorted.
          - `has_meta_data`: Whether parameter values should be accessed
            directly or through the "value" attribute.

        **Returns**

          - Sorted data.
        """
        def make_key(label, label_values):
            # Build a sort key ranking value objects by their position in
            # `label_values`; objects without `label` rank at -1 (first).
            def key(vo):
                return label_values.index(vo[label]) if label in vo else -1
            return key

        update_attrs = data is None
        if update_attrs:
            if not has_meta_data:
                raise ParamToolsError(
                    "has_meta_data must be True if data is not specified.")
            data = self._data

        label_grid = self._stateless_label_grid
        # nothing to do if labels aren't specified
        if not label_grid:
            return data

        # `sorted` is stable, so sorting by the last label first and the first
        # label last makes the first label's order take precedence.
        labels_last_first = list(reversed(label_grid))

        for param in data:
            for label in labels_last_first:
                key = make_key(label, label_grid[label])
                if has_meta_data:
                    data[param]["value"] = sorted(data[param]["value"],
                                                  key=key)
                else:
                    data[param] = sorted(data[param], key=key)

            # Only update attributes when array first is off, since
            # value order will not affect how arrays are constructed.
            if update_attrs and has_meta_data and not self.array_first:
                attr_vals = select_eq(data[param]["value"],
                                      strict=True,
                                      labels=self._state)
                sorted_values = self.sort_values({param: attr_vals},
                                                 has_meta_data=False)[param]
                setattr(self, param, sorted_values)
        return data
Ejemplo n.º 10
0
    def adjust(self, params_or_path, raise_errors=True, extend_adj=True):
        """
        Deserialize and validate parameter adjustments. `params_or_path`
        can be a file path or a `dict` that has not been fully deserialized.
        The adjusted values replace the current values stored in the
        corresponding parameter attributes.

        When ``label_to_extend`` is set and ``extend_adj`` is true, each
        adjusted parameter is handled in turn:

        - stored value objects that the adjustment supersedes are deleted
          (adjusted to ``None``);
        - the user's value objects are applied;
        - the parameter is re-extended along ``label_to_extend``.

        If one of these steps raises ``ValidationError``, that parameter's
        previous values are restored. ``self.array_first`` is temporarily
        set to ``False`` for the nested adjustments. It is always restored,
        even if a nested adjustment raises.

        Returns: parsed, validated parameters.

        Raises:
            marshmallow.exceptions.ValidationError if data is not valid.

            ParameterUpdateException if label values do not match at
                least one existing value item's corresponding label values.
        """
        params = self.read_params(params_or_path)

        # Validate user adjustments.
        parsed_params = {}
        try:
            parsed_params = self._validator_schema.load(params)
        except MarshmallowValidationError as ve:
            # Errors are recorded on self and raised below if requested.
            self._parse_errors(ve, params)

        if not self._errors:
            if self.label_to_extend is not None and extend_adj:
                extend_grid = self._stateless_label_grid[self.label_to_extend]
                for param, vos in parsed_params.items():
                    to_delete = []
                    for vo in utils.grid_sort(vos, self.label_to_extend,
                                              extend_grid):
                        if self.label_to_extend in vo:
                            # Warn when adjusting a label value outside the
                            # currently active state.
                            if (vo[self.label_to_extend] not in
                                    self.label_grid[self.label_to_extend]):
                                msg = (
                                    f"{param}[{self.label_to_extend}={vo[self.label_to_extend]}] "
                                    f"is not active in the current state: "
                                    f"{self.label_to_extend}= "
                                    f"{self.label_grid[self.label_to_extend]}."
                                )
                                warnings.warn(msg)
                            # Stored value objects selected by select_gt_ix
                            # relative to this label value that share the
                            # adjustment's other labels are queued for
                            # deletion, since extension will overwrite them.
                            gt = select_gt_ix(
                                self._data[param]["value"],
                                True,
                                {
                                    self.label_to_extend:
                                    vo[self.label_to_extend]
                                },
                                extend_grid,
                            )
                            eq = select_eq(
                                gt,
                                True,
                                utils.filter_labels(
                                    vo, drop=[self.label_to_extend, "value"]),
                            )
                            to_delete += eq
                    to_delete = [
                        dict(td, **{"value": None}) for td in to_delete
                    ]
                    # make copy of value objects since they
                    # are about to be modified
                    backup = copy.deepcopy(self._data[param]["value"])
                    array_first = self.array_first
                    try:
                        self.array_first = False
                        # delete params that will be overwritten out by extend.
                        self.adjust(
                            {param: to_delete},
                            extend_adj=False,
                            raise_errors=True,
                        )
                        # set user adjustments.
                        self.adjust({param: vos},
                                    extend_adj=False,
                                    raise_errors=True)
                        # Restore before extending so extend sees the
                        # caller's array_first setting.
                        self.array_first = array_first
                        # extend user adjustments.
                        self.extend(params=[param], raise_errors=True)
                    except ValidationError:
                        self._data[param]["value"] = backup
                    finally:
                        # Previously array_first stayed False if a nested
                        # adjust raised; always restore it.
                        self.array_first = array_first
            else:
                for param, value in parsed_params.items():
                    self._update_param(param, value)

        self._validator_schema.context["spec"] = self

        if raise_errors and self._errors:
            raise self.validation_error

        # Update attrs for params that were adjusted.
        self._set_state(params=parsed_params.keys())

        return parsed_params
Ejemplo n.º 11
0
 def select_eq(self, param, exact_match, **labels):
     """Query ``param``'s stored value objects by label equality.

     The keyword arguments are collected into one ``labels`` dict and handed,
     together with ``exact_match``, to the module-level ``select_eq`` helper.
     """
     value_objects = self._data[param]["value"]
     return select_eq(value_objects, exact_match, labels)