예제 #1
0
def _unit_validator(instance: Mapping, expected_dimensionality: str,
                    position: list[str]) -> Iterator[ValidationError]:
    """Validate the 'units' key of the instance against the given dimensionality.

    Parameters
    ----------
    instance:
        Tree serialization with a 'units' key to validate.
    expected_dimensionality:
        String representation of the unit dimensionality to test against.
    position:
        Current position in nested structure for debugging. If empty, the
        instance itself is used in the error message instead.

    Yields
    ------
    asdf.ValidationError
        A single error if the unit stored under 'units' is not compatible with
        ``expected_dimensionality``.

    Raises
    ------
    KeyError
        If ``instance`` has no 'units' key.

    """
    if not position:
        position = instance

    unit = instance["units"]
    unit = str(unit)  # catch TaggedString
    valid = U_(unit).is_compatible_with(U_(expected_dimensionality))
    if not valid:
        yield ValidationError(
            f"Error validating unit dimension for property '{position}'. "
            f"Expected unit of dimension '{expected_dimensionality}' "
            f"but got unit '{unit}'")
예제 #2
0
파일: xarray.py 프로젝트: CagtayFabry/weldx
def xr_check_dimensionality(da: xr.DataArray, units_ref: Union[str,
                                                               pint.Unit]):
    """Ensure the units of a ``DataArray`` are compatible with a reference unit.

    Passing ``None`` as the reference unit disables the check entirely.

    Parameters
    ----------
    da:
        Data array whose units (as reported by ``da.weldx.units``) are checked.
    units_ref:
        The reference unit or its string representation.

    Raises
    ------
    pint.DimensionalityError
        If ``da`` has no units or its units are incompatible with ``units_ref``.

    """
    if units_ref is not None:
        reference = U_(units_ref)
        actual = da.weldx.units

        # a DataArray without any unit never satisfies a dimensionality check
        compatible = actual is not None and actual.is_compatible_with(reference)
        if not compatible:
            message = (
                f"\nDataArray units are '{actual}'.  This is incompatible with "
                f"the expected dimensionality '{reference.dimensionality}'"
            )
            raise DimensionalityError(actual, reference, extra_msg=message)
예제 #3
0
    def _determine_output_signal_unit(
        func: MathematicalExpression, input_unit: Union[str, pint.Unit]
    ) -> pint.Unit:
        """Determine the unit of a transformations' output signal.

        If a function is given, it is evaluated once with a test quantity of
        magnitude 1 in ``input_unit`` and the unit of the result is returned.

        Parameters
        ----------
        func :
            The function describing the transformation. If `None`, the input unit
            is returned unchanged.
        input_unit :
            The unit of the input signal

        Returns
        -------
        pint.Unit:
            Unit of the transformations' output signal

        Raises
        ------
        ValueError
            If ``func`` does not have exactly one variable, or if evaluating it
            with a quantity in ``input_unit`` fails.

        """
        input_unit = U_(input_unit)

        if func is None:
            return input_unit

        variables = func.get_variable_names()
        if len(variables) != 1:
            raise ValueError("The provided function must have exactly 1 parameter")

        try:
            test_output = func.evaluate(**{variables[0]: Q_(1, input_unit)})
        except Exception as e:
            # Any evaluation failure (e.g. a dimensionality mismatch between the
            # variable and the function parameters) means the function cannot
            # process signals of this unit. Keep the original error as cause.
            raise ValueError(
                "The provided function is incompatible with the input signals unit."
                f" \nThe test raised the following exception:\n{e}"
            ) from e
        return test_output.data.units
예제 #4
0
def test_pint_default_ureg():
    """Check that the weldx unit registry acts as pint's default registry.

    Repeated dequantify/quantify round trips through pint-xarray only succeed
    if the units created via ``Q_``/``U_`` belong to the default registry.
    """
    coord_attrs = {"units": U_("s")}
    data_array = xr.DataArray(
        Q_([1, 2, 3, 4], "mm"),
        dims=["a"],
        coords={"a": ("a", [1, 2, 3, 4], coord_attrs)},
    )
    first_trip = data_array.pint.dequantify().pint.quantify()
    first_trip.pint.dequantify().pint.quantify()
예제 #5
0
파일: xarray.py 프로젝트: CagtayFabry/weldx
    def units(self) -> Union[pint.Unit, None]:
        """Return the unit of the data array values, or ``None`` if there is none.

        Unlike the pint-xarray accessor ``.pint.units``, this falls back to a unit
        stored in the ``UNITS_KEY`` attribute of the data array.
        """
        data_array = self._obj
        found = data_array.pint.units
        if found is None:
            # not quantified: look for a unit stored as attribute instead
            found = data_array.attrs.get(UNITS_KEY, None)
        if found is None:
            return None
        return U_(found)
예제 #6
0
    def test_construction_expression(data, shape_exp, unit_exp):
        """Build a TimeSeries from ``data`` only and verify its attributes."""
        series = TimeSeries(data=data)

        assert series.data == data
        for attr_name in ("time", "interpolation"):
            assert getattr(series, attr_name) is None
        assert series.shape == shape_exp
        assert series.data_array is None
        expected_unit = U_(unit_exp)
        assert expected_unit.is_compatible_with(series.units)
예제 #7
0
    def test_interp_time(ts, time, magnitude_exp, unit_exp):
        """Check values, unit and time axis of the result of ``interp_time``."""
        interpolated = ts.interp_time(time)

        assert np.all(np.isclose(interpolated.data.magnitude, magnitude_exp))
        assert interpolated.units == U_(unit_exp)

        expected_time = Time(time)
        if len(expected_time) != 1:
            actual_time = Time(interpolated.time, interpolated._reference_time)
            assert np.all(actual_time == expected_time)
        else:
            # a single interpolation point yields a result without time axis
            assert interpolated.time is None
예제 #8
0
    def from_parameters(
        cls,
        name: str,
        source_name: str,
        source_error: Error,
        output_signal_type: str,
        output_signal_unit: Union[str, Unit],
        signal_data: TimeSeries = None,
    ) -> "MeasurementChain":
        """Build a measurement chain whose `SignalSource` is created on the fly.

        Parameters
        ----------
        name :
            Name of the measurement chain
        source_name :
            Name of the source
        source_error :
            Error of the source
        output_signal_type :
            Type of the output signal ('analog' or 'digital')
        output_signal_unit :
            Unit of the output signal (converted with ``U_``)
        signal_data :
            Measured data of the sources' signal

        Returns
        -------
        MeasurementChain :
            New measurement chain

        Examples
        --------
        >>> from weldx import Q_
        >>> from weldx.measurement import Error, MeasurementChain

        >>> mc = MeasurementChain.from_parameters(
        ...          name="Current measurement chain",
        ...          source_error=Error(deviation=Q_(0.5, "percent")),
        ...          source_name="Current sensor",
        ...          output_signal_type="analog",
        ...          output_signal_unit="V"
        ...      )

        """
        output_signal = Signal(output_signal_type, U_(output_signal_unit), signal_data)
        signal_source = SignalSource(source_name, output_signal, source_error)
        return cls(name, signal_source)
예제 #9
0
    def test_add_transformation(self, tf_kwargs, exp_signal_type, exp_signal_unit):
        """Check the output signal after calling `MeasurementChain.add_transformation`.

        Parameters
        ----------
        tf_kwargs:
            Keyword arguments for the `SignalTransformation` passed to
            `add_transformation`. Missing arguments are filled with defaults.
        exp_signal_type :
            The expected signal type after the transformation
        exp_signal_unit :
            The expected unit after the transformation

        """
        chain = MeasurementChain(**self._default_init_kwargs())
        transformation = self._default_transformation(tf_kwargs)
        chain.add_transformation(transformation)

        output = chain.output_signal
        assert output.signal_type == exp_signal_type
        assert U_(output.units) == exp_signal_unit
예제 #10
0
 def from_yaml_tree(self, node: str, tag: str, ctx) -> pint.Unit:
     """Rebuild a unit object from its serialized string ``node``."""
     reconstructed = U_(node)
     return reconstructed
예제 #11
0
class TestMeasurementChain:
    """Test the `MeasurementChain` class."""

    # helper functions -----------------------------------------------------------------

    @staticmethod
    def _default_source_kwargs(kwargs: dict = None) -> dict:
        """Update a dict with default keyword arguments to create a `SignalSource`.

        Entries of ``kwargs`` (if provided) override the defaults.
        """
        default_kwargs = dict(
            name="source", output_signal=Signal("analog", "V"), error=Error(0.01)
        )

        if kwargs is not None:
            default_kwargs.update(kwargs)

        return default_kwargs

    @classmethod
    def _default_init_kwargs(
        cls, kwargs: dict = None, source_kwargs: dict = None
    ) -> dict:
        """Return a dictionary of keyword arguments required by the `__init__` method.

        Parameters
        ----------
        kwargs :
            A dictionary containing some keyword arguments that should replace the
            default ones.
        source_kwargs :
            A dictionary of keyword arguments that should replace the arguments used
            for the default source.

        Returns
        -------
        Dict :
            Dictionary with keyword arguments for the `__init__` method

        """
        source_kwargs = cls._default_source_kwargs(source_kwargs)

        default_kwargs = dict(
            name="name",
            source=SignalSource(**source_kwargs),
            signal_data=[1, 3, 5],
        )
        if kwargs is not None:
            default_kwargs.update(kwargs)

        return default_kwargs

    @staticmethod
    def _default_transformation(kwargs: dict = None) -> SignalTransformation:
        """Return a default `SignalTransformation`.

        Use the kwargs parameter to modify the default values.

        """
        default_kwargs = dict(
            name="transformation",
            error=Error(0.1),
            func=MathematicalExpression("a*x", parameters={"a": Q_(1, "1/V")}),
            type_transformation="AD",
        )
        if kwargs is not None:
            default_kwargs.update(kwargs)

        return SignalTransformation(**default_kwargs)

    @classmethod
    def _default_add_transformation_kwargs(cls, kwargs: dict = None) -> dict:
        """Update a dict with default keyword arguments to call `add_transformation`."""
        default_kwargs = dict(
            transformation=cls._default_transformation(),
            error=Error(0.02),
            output_signal_type="digital",
            output_signal_unit="",
        )
        if kwargs is not None:
            default_kwargs.update(kwargs)

        return default_kwargs

    # test_init ------------------------------------------------------------------------

    @staticmethod
    @pytest.mark.parametrize(
        "kwargs, source_kwargs",
        [
            ({}, {}),
            (dict(signal_data=None), dict(output_signal=Signal("analog", "V", [1]))),
        ],
    )
    def test_init(kwargs: dict, source_kwargs: dict):
        """Test the `__init__` method of the `MeasurementChain`.

        Parameters
        ----------
        kwargs:
            A dictionary with keyword arguments that are passed to the `__init__`
            method. Missing arguments are added.
        source_kwargs :
            A dictionary with keyword arguments that are used to construct the
            `SignalSource` that is passed to the `__init__` method. Missing arguments
            are added.

        """
        kwargs = TestMeasurementChain._default_init_kwargs(kwargs, source_kwargs)
        MeasurementChain(**kwargs)

    # test_init_exceptions -------------------------------------------------------------

    @staticmethod
    @pytest.mark.parametrize(
        "kwargs, source_kwargs, exception_type, test_name",
        [({}, {"output_signal": Signal("analog", "V", [1])}, KeyError, "# 2x data")],
        ids=get_test_name,
    )
    def test_init_exceptions(
        kwargs: dict, source_kwargs: dict, exception_type, test_name: str
    ):
        """Test the exceptions of the `__init__` method.

        Parameters
        ----------
        kwargs :
            A dictionary with keyword arguments that are passed to the `__init__`
            method. Missing arguments are added.
        source_kwargs :
            A dictionary with keyword arguments that are used to construct the
            `SignalSource` that is passed to the `__init__` method. Missing arguments
            are added.
        exception_type :
            The expected exception type
        test_name :
            Name of the test

        """
        kwargs = TestMeasurementChain._default_init_kwargs(kwargs, source_kwargs)
        with pytest.raises(exception_type):
            MeasurementChain(**kwargs)

    # test_from_equipment --------------------------------------------------------------

    @pytest.mark.parametrize(
        "num_sources, source_name, exception",
        [
            (0, None, ValueError),
            (1, None, None),
            (2, "source_1", None),
            (2, "wrong name", KeyError),
            (2, None, ValueError),
        ],
    )
    def test_from_equipment(
        self, num_sources: int, source_name: str, exception: Exception
    ):
        """Test the `from_equipment` factory and its exceptions.

        Parameters
        ----------
        num_sources :
            Number of sources of the generated equipment
        source_name :
            Corresponding parameter of `from_equipment`
        exception :
            Expected exception type, or `None` if the call should succeed

        """
        sources = [
            SignalSource(**self._default_source_kwargs({"name": f"source_{i}"}))
            for i in range(num_sources)
        ]
        equipment = MeasurementEquipment("Equipment", sources=sources)

        if exception is not None:
            with pytest.raises(exception):
                MeasurementChain.from_equipment(
                    name="name", equipment=equipment, source_name=source_name
                )
        else:
            MeasurementChain.from_equipment(
                name="name", equipment=equipment, source_name=source_name
            )

    # test_add_transformations ---------------------------------------------------------

    @pytest.mark.parametrize(
        "tf_kwargs, exp_signal_type, exp_signal_unit",
        [
            ({}, "digital", U_("")),
            (dict(type_transformation="AA"), "analog", U_("")),
            (dict(type_transformation=None), "analog", U_("")),
            (dict(func=None), "digital", U_("V")),
        ],
    )
    def test_add_transformation(self, tf_kwargs, exp_signal_type, exp_signal_unit):
        """Test the `add_transformation` method of the `MeasurementChain`.

        Parameters
        ----------
        tf_kwargs:
            A dictionary with keyword arguments that are used to construct the
            `SignalTransformation` that is passed to the `add_transformation` method.
            Missing arguments are added.
        exp_signal_type :
            The expected signal type after the transformation
        exp_signal_unit :
            The expected unit after the transformation

        """
        mc = MeasurementChain(**self._default_init_kwargs())

        mc.add_transformation(self._default_transformation(tf_kwargs))

        signal = mc.output_signal
        assert signal.signal_type == exp_signal_type
        assert U_(signal.units) == exp_signal_unit

    # test_add_transformation_exceptions -----------------------------------------------

    @pytest.mark.parametrize(
        "tf_kwargs, input_signal_source, exception_type, test_name",
        [
            (dict(type_transformation="DA"), None, ValueError, "# inv. signal type #1"),
            (dict(type_transformation="DD"), None, ValueError, "# inv. signal type #2"),
            ({}, "not found", KeyError, "# invalid input signal source"),
            (dict(name="source"), None, KeyError, "# Name does already exist"),
            (
                dict(func=MathematicalExpression("x+a", parameters={"a": Q_(1, "A")})),
                None,
                ValueError,
                "# incompatible function",
            ),
        ],
        ids=get_test_name,
    )
    def test_add_transformation_exceptions(
        self, tf_kwargs: dict, input_signal_source: str, exception_type, test_name: str
    ):
        """Test the exceptions of the `add_transformation` method.

        Parameters
        ----------
        tf_kwargs:
            A dictionary with keyword arguments that are used to construct the
            `SignalTransformation` that is passed to the `add_transformation` method.
            Missing arguments are added.
        input_signal_source :
            The value of the corresponding parameter of 'add_transformation'
        exception_type :
            The expected exception type
        test_name :
            Name of the test

        """
        mc = MeasurementChain(**self._default_init_kwargs())

        tf = self._default_transformation(tf_kwargs)

        with pytest.raises(exception_type):
            mc.add_transformation(tf, input_signal_source=input_signal_source)

    # test_add_transformation_from_equipment -------------------------------------------

    @pytest.mark.parametrize(
        "num_transformations, transformation_name, exception",
        [
            (0, None, ValueError),
            (1, None, None),
            (2, "transformation_1", None),
            (2, "wrong name", KeyError),
            (2, None, ValueError),
        ],
    )
    def test_add_transformation_from_equipment(
        self, num_transformations: int, transformation_name: str, exception
    ):
        """Test `add_transformation_from_equipment` and its exceptions.

        Parameters
        ----------
        num_transformations :
            Number of transformations of the generated equipment
        transformation_name :
            Corresponding parameter of `add_transformation_from_equipment`
        exception :
            Expected exception type, or `None` if the call should succeed

        """
        mc = MeasurementChain(**self._default_init_kwargs())
        transformations = [
            self._default_transformation({"name": f"transformation_{i}"})
            for i in range(num_transformations)
        ]
        equipment = MeasurementEquipment(name="name", transformations=transformations)

        if exception is not None:
            with pytest.raises(exception):
                mc.add_transformation_from_equipment(
                    equipment=equipment, transformation_name=transformation_name
                )
        else:
            mc.add_transformation_from_equipment(
                equipment=equipment, transformation_name=transformation_name
            )

    # test_add_signal_data -------------------------------------------------------------

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data=xr.DataArray([2, 3])),
            dict(signal_source="source"),
        ],
    )
    def test_add_signal_data(self, kwargs):
        """Test the `add_signal_data` method of the `MeasurementChain`.

        Parameters
        ----------
        kwargs:
            A dictionary with keyword arguments that are passed to the
            `add_signal_data` method. If no ``data`` is in the kwargs, a default
            one is added.

        """
        mc = MeasurementChain(**self._default_init_kwargs({"signal_data": None}))
        mc.add_transformation(self._default_transformation())

        full_kwargs = dict(data=xr.DataArray([1, 2]))
        full_kwargs.update(kwargs)

        mc.add_signal_data(**full_kwargs)

    # test_add_signal_data_exceptions --------------------------------------------------

    @pytest.mark.parametrize(
        "kwargs, exception_type, test_name",
        [
            (dict(signal_source="what"), KeyError, "# invalid signal source"),
            (dict(signal_source="source"), KeyError, "# already has data #1"),
            (dict(signal_source="transformation"), KeyError, "# already has data #2"),
        ],
        ids=get_test_name,
    )
    def test_add_signal_data_exceptions(
        self, kwargs: dict, exception_type, test_name: str
    ):
        """Test the exceptions of the `add_signal_data` method.

        Parameters
        ----------
        kwargs :
            A dictionary with keyword arguments that are passed to the `add_signal_data`
            method. Missing arguments are added.
        exception_type :
            The expected exception type
        test_name :
            Name of the test

        """
        mc = MeasurementChain(**self._default_init_kwargs())
        mc.add_transformation(self._default_transformation(), data=[1, 2, 3])
        mc.add_transformation(
            self._default_transformation(
                dict(name="transformation 2", type_transformation="DA")
            )
        )

        full_kwargs = dict(data=xr.DataArray([1, 2]))
        full_kwargs.update(kwargs)

        with pytest.raises(exception_type):
            mc.add_signal_data(**full_kwargs)

    # test_get_equipment ---------------------------------------------------------------

    @pytest.mark.parametrize(
        "signal_source, exception",
        [
            ("source", None),
            ("transformation_1", None),
            ("transformation_2", None),
            ("transformation_3", KeyError),
        ],
    )
    def test_get_equipment(self, signal_source, exception):
        """Test the `get_equipment` method and its exceptions.

        Parameters
        ----------
        signal_source :
            Corresponding function parameter
        exception :
            Expected exception type, or `None` if the call should succeed

        """
        src_eq = MeasurementEquipment(
            "Source Eq", sources=[SignalSource(**self._default_source_kwargs())]
        )
        tf_eq = MeasurementEquipment(
            "Transformation_eq",
            transformations=[
                self._default_transformation({"name": "transformation_1"})
            ],
        )

        mc = MeasurementChain.from_equipment("Chain", src_eq)
        mc.add_transformation_from_equipment(tf_eq)
        mc.create_transformation("transformation_2", None, output_signal_unit="A")

        if exception is not None:
            with pytest.raises(exception):
                mc.get_equipment(signal_source=signal_source)
        else:
            mc.get_equipment(signal_source=signal_source)

    # test_get_signal_data -------------------------------------------------------------

    def test_get_signal_data(self):
        """Test the `get_signal_data` method.

        This test assures that the returned data is identical to the one passed
        to the measurement chain and that a `KeyError` is raised if the requested
        data is not present.

        """
        data = xr.DataArray([1, 2, 3])

        mc = MeasurementChain(**self._default_init_kwargs())
        mc.add_transformation(self._default_transformation(), data=data)
        mc.create_transformation("transformation_2", None, output_signal_unit="A")

        assert np.all(mc.get_signal_data("transformation") == data)

        # no data
        with pytest.raises(KeyError):
            mc.get_signal_data("transformation_2")

        # source not present
        with pytest.raises(KeyError):
            mc.get_signal_data("not found")

    # test_get_transformation ----------------------------------------------------------

    def test_get_transformation(self):
        """Test the `get_transformation` method."""
        mc = MeasurementChain(**self._default_init_kwargs())
        mc.add_transformation(self._default_transformation())

        transformation = mc.get_transformation("transformation")

        assert transformation == self._default_transformation()

    # test_get_transformation_exception ------------------------------------------------

    def test_get_transformation_exception(self):
        """Test that a `KeyError` is raised if the transformation does not exist."""
        mc = MeasurementChain(**self._default_init_kwargs())
        mc.add_transformation(self._default_transformation())

        with pytest.raises(KeyError):
            mc.get_transformation("not found")
예제 #12
0
    def create_transformation(
        self,
        name: str,
        error: Error,
        output_signal_type: str = None,
        output_signal_unit: Union[str, Unit] = None,
        func: MathematicalExpression = None,
        data: TimeSeries = None,
        input_signal_source: str = None,
    ):
        """Create and add a transformation to the measurement chain.

        Parameters
        ----------
        name :
            Name of the transformation
        error :
            The error of the transformation
        output_signal_type :
            Type of the output signal (analog or digital). If `None`, the type of
            the input signal is kept.
        output_signal_unit :
            Unit of the output signal. If a function is provided, it is not necessary to
            provide this parameter since it can be derived from the function. In case
            both, the function and the unit are provided, an exception is raised if a
            mismatch in dimensionality is detected. This functionality may be used as
            extra safety layer. If no function is provided, a simple unit conversion
            function is created.
        func :
            A function describing the transformation. The provided value interacts
            with the 'output_signal_unit' parameter as described in its documentation
        data :
            A set of measurement data that is associated with the output signal of the
            transformation
        input_signal_source :
            The source of the signal that should be used as input of the transformation.
            If `None` is provided, the name of the last added transformation (or the
            source, if no transformation was added to the chain) is used.

        Raises
        ------
        ValueError
            If both ``func`` and ``output_signal_unit`` are given and the unit of the
            function's output is not compatible with ``output_signal_unit``.

        Warns
        -----
        UserWarning
            If neither ``output_signal_type``, ``output_signal_unit`` nor ``func`` is
            provided, since the transformation would then not transform anything.

        Examples
        --------
        >>> from weldx import Q_
        >>> from weldx.core import MathematicalExpression
        >>> from weldx.measurement import Error, MeasurementChain, SignalTransformation

        >>> mc = MeasurementChain.from_parameters(
        ...          name="Current measurement chain",
        ...          source_error=Error(deviation=Q_(0.5, "percent")),
        ...          source_name="Current sensor",
        ...          output_signal_type="analog",
        ...          output_signal_unit="V"
        ...      )

        Create a mathematical expression that accepts a quantity with volts as unit and
        that returns a dimensionless quantity.

        >>> func = MathematicalExpression(expression="a*x + b",
        ...                               parameters=dict(a=Q_(5, "1/V"), b=Q_(1, ""))
        ...                               )

        Use the mathematical expression to create a new transformation which also
        performs an analog-digital conversion.

        >>> mc.create_transformation(name="Current AD conversion",
        ...                          error=Error(deviation=Q_(1,"percent")),
        ...                          func=func,
        ...                          output_signal_type="digital"
        ...                          )

        """
        if output_signal_unit is not None:
            output_signal_unit = U_(output_signal_unit)

        if output_signal_type is None and output_signal_unit is None and func is None:
            warn("The created transformation does not perform any transformations.")

        input_signal_source = self._check_and_get_node_name(input_signal_source)
        input_signal: Signal = self._graph.nodes[input_signal_source]["signal"]
        if output_signal_type is None:
            output_signal_type = input_signal.signal_type
        # e.g. "analog" -> "digital" gives the transformation type "AD"
        type_tf = f"{input_signal.signal_type[0]}{output_signal_type[0]}".upper()
        if output_signal_unit is not None:
            if func is not None:
                if not output_signal_unit.is_compatible_with(
                    self._determine_output_signal_unit(func, input_signal.units),
                ):
                    raise ValueError(
                        "The unit of the provided functions output has not the same "
                        f"dimensionality as {output_signal_unit}"
                    )
            else:
                # no function given: build a pure unit conversion a*x with |a| = 1
                unit_conversion = output_signal_unit / input_signal.units
                func = MathematicalExpression(
                    "a*x",
                    parameters={"a": Q_(1, unit_conversion)},
                )

        transformation = SignalTransformation(name, error, func, type_tf)
        self.add_transformation(transformation, data, input_signal_source)
예제 #13
0
 class _DerivedSeries(GenericSeries):
     """``GenericSeries`` subclass with ``_required_unit_dimensionality`` set to meters.

     NOTE(review): presumably this restricts the series to length-compatible
     units; confirm against how ``GenericSeries`` uses this attribute.
     """

     _required_unit_dimensionality = U_("m")
예제 #14
0
 def __post_init__(self):
     """Validate ``signal_type`` and convert ``units`` into a unit object via ``U_``.

     Raises
     ------
     ValueError
         If ``signal_type`` is neither ``"analog"`` nor ``"digital"``.
     """
     allowed_signal_types = ("analog", "digital")
     if self.signal_type not in allowed_signal_types:
         raise ValueError(f"{self.signal_type} is an invalid signal type.")
     self.units = U_(self.units)
예제 #15
0
    def test_xr_interp_like_units(fmt, broadcast_missing, quantified):
        """Check that ``xr_interp_like`` respects units of data and coordinates.

        Parameters
        ----------
        fmt
            Format of the indexer: ``"dict"`` for a plain mapping, anything else
            for a DataArray.
        broadcast_missing
            Whether coordinates that are missing in the data get broadcast.
        quantified
            If True, the DataArray indexer is used in quantified form; otherwise it
            is dequantified first. Only relevant for the DataArray format.

        """
        a = Q_([0.0, 1.0], "m")
        t = Q_([-1.0, 0.0, 1.0], "s")
        t_interp = Q_([-100.0, 0.0, 200.0], "ms")
        b_interp = Q_([10.0, 20.0], "V")

        data_units = U_("A")
        data = Q_([[1, 2, 3], [4, 5, 6]], data_units)
        expected = Q_([[1.9, 2, 2.2], [4.9, 5, 5.2]], data_units)

        def _unit_coord(dim, quantity):
            # coordinate tuple with plain magnitudes and the unit as attribute
            return (dim, quantity.m, {UNITS_KEY: quantity.u})

        da = xr.DataArray(
            data,
            dims=["a", "t"],
            coords={"t": _unit_coord("t", t), "a": _unit_coord("a", a)},
            attrs={META_ATTR: "meta"},
        )

        if fmt == "dict":
            indexer = {"t": t_interp, "b": b_interp}
        else:
            indexer = xr.DataArray(
                dims=["t", "b"],
                coords={
                    "t": _unit_coord("t", t_interp),
                    "b": _unit_coord("b", b_interp),
                },
            )
            if not quantified:
                indexer = indexer.pint.dequantify()

        da2 = ut.xr_interp_like(da, indexer, broadcast_missing=broadcast_missing)

        if broadcast_missing:
            assert da2.b.attrs.get(UNITS_KEY, None) == b_interp.units
            da2 = da2.isel(b=0)

        for row in range(len(da.a)):
            assert np.all(da2.sel(a=row) == expected[row, :])
        assert da2.pint.units == data_units
        assert da2.attrs[META_ATTR] == "meta"

        assert da2.t.attrs.get(UNITS_KEY, None) == t_interp.units
        assert da2.a.attrs.get(UNITS_KEY, None) == a.units
예제 #16
0
파일: xarray.py 프로젝트: CagtayFabry/weldx
 def dequantify_coords(self):
     """Return a copy of the data array with coordinate units stored as strings.

     Every coordinate that carries a ``UNITS_KEY`` attribute gets that attribute
     normalized through ``U_`` and replaced by its string form. The work is done
     on ``self._obj.copy()``; the modified copy is returned.
     """
     da = self._obj.copy()
     for c, v in da.coords.items():
         if (units := v.attrs.get(UNITS_KEY, None)) is not None:
             da[c].attrs[UNITS_KEY] = str(U_(units))
     # without returning the copy, the formatting above would be lost
     return da
예제 #17
0
파일: xarray.py 프로젝트: CagtayFabry/weldx
def xr_check_coords(coords: Union[xr.DataArray, Mapping[str, Any]],
                    ref: dict) -> bool:
    """Validate the coordinates of the DataArray against a reference dictionary.

    The reference dictionary should have the dimensions as keys and those contain
    dictionaries with the following keywords (all optional):

    ``values``
        Specify exact coordinate values to match.

    ``dtype`` : str or type
        Ensure coordinate dtype matches at least one of the given dtypes.

    ``optional`` : boolean
        default ``False`` - if ``True``, the dimension does not have to be present
        in the coordinates; a missing optional dimension is skipped.

    ``dimensionality`` : str or pint.Unit
        Check if ``.attrs["units"]`` is compatible with the requested dimensionality

    ``units`` : str or pint.Unit
        Check if ``.attrs["units"]`` matches the requested unit

    Parameters
    ----------
    coords
        xarray object (``DataArray`` or ``Dataset``) or coordinate mapping that
        should be validated
    ref
        reference dictionary

    Returns
    -------
    bool
        True, if the test was a success, else an exception is raised

    Raises
    ------
    KeyError
        If a required (non-optional) coordinate is missing.
    ValueError
        If the coordinate values or units do not match the reference.
    TypeError
        If the coordinate dtype matches none of the requested dtypes.
    pint.DimensionalityError
        If the coordinate has no unit or a unit of the wrong dimensionality.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> import xarray as xr
    >>> import weldx as wx
    >>> dax = xr.DataArray(
    ...     data=np.ones((3, 2, 3)),
    ...     dims=["d1", "d2", "d3"],
    ...     coords={
    ...         "d1": np.array([-1, 0, 2], dtype=int),
    ...         "d2": pd.DatetimeIndex(["2020-05-01", "2020-05-03"]),
    ...         "d3": ["x", "y", "z"],
    ...     }
    ... )
    >>> ref = dict(
    ...     d1={"optional": True, "values": np.array([-1, 0, 2], dtype=int)},
    ...     d2={
    ...         "values": pd.DatetimeIndex(["2020-05-01", "2020-05-03"]),
    ...         "dtype": ["datetime64[ns]", "timedelta64[ns]"],
    ...     },
    ...     d3={"values": ["x", "y", "z"], "dtype": "<U1"},
    ... )
    >>> wx.util.xr_check_coords(dax, ref)
    True

    """
    # only process the coords of the xarray
    if isinstance(coords, (xr.DataArray, xr.Dataset)):
        coords = coords.coords

    for key, check in ref.items():
        if key not in coords:
            if check.get("optional", False):
                # optional coordinate is absent - nothing to validate
                continue
            raise KeyError(f"Could not find required coordinate '{key}'.")

        # only if the key "values" is given do the validation
        if "values" in check and not np.all(
                coords[key].values == check["values"]):
            raise ValueError(f"Value mismatch in DataArray and ref['{key}']"
                             f"\n{coords[key].values}"
                             f"\n{check['values']}")

        # only if the key "dtype" is given do the validation
        if "dtype" in check:
            dtype_list = check["dtype"]
            if not isinstance(dtype_list, list):
                dtype_list = [dtype_list]
            if not any(
                    _check_dtype(coords[key].dtype, var_dtype)
                    for var_dtype in dtype_list):
                raise TypeError(
                    f"Mismatch in the dtype of the DataArray and ref['{key}']")

        if UNITS_KEY in check:
            units = coords[key].attrs.get(UNITS_KEY, None)
            # Test against None rather than falsiness so that an empty string
            # (a dimensionless unit) is compared like any other unit, matching
            # the handling in the dimensionality check below.
            if units is None or not U_(units) == U_(check[UNITS_KEY]):
                raise ValueError(
                    f"Unit mismatch in coordinate '{key}'\n"
                    f"Coordinate has unit '{units}', expected '{check[UNITS_KEY]}'"
                )

        if "dimensionality" in check:
            units = coords[key].attrs.get(UNITS_KEY, None)
            dim = check["dimensionality"]
            if units is None or not U_(units).is_compatible_with(dim):
                raise DimensionalityError(
                    units,
                    check["dimensionality"],
                    extra_msg=
                    f"\nDimensionality mismatch in coordinate '{key}'\n"
                    f"Coordinate has unit '{units}', expected '{dim}'",
                )

    return True