Example #1
0
    def _load_grid(
        self, pts0: u.Quantity, pts1: u.Quantity, pts2: u.Quantity, **kwargs
    ):
        r"""
        Initialize the grid object from a user-supplied grid.

        Parameters
        ----------
        pts0, pts1, pts2 : `~astropy.units.Quantity` array, shape (n0, n1, n2)
            Grids of coordinate positions. All three arrays must have the
            same shape.

        **kwargs : `~astropy.units.Quantity` array, shape (n0, n1, n2)
            Quantities defined on the grid. Each one is stored under its
            keyword name via ``self.add_quantity``.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``pts0``, ``pts1`` and ``pts2`` do not all have the same shape.
        """

        # Validate input
        if pts0.shape != pts1.shape or pts0.shape != pts2.shape:
            raise ValueError(
                "Provided arrays of grid points are of unequal "
                f"shape: pts0 = {pts0.shape}, "
                f"pts1 = {pts1.shape}, "
                f"pts2 = {pts2.shape}."
            )

        self._is_uniform = _detect_is_uniform_grid(pts0, pts1, pts2)

        # Create dataset
        self.ds = xr.Dataset()

        # The axis units are recorded separately as a dataset attribute.
        # NOTE(review): presumably because xarray coordinates do not keep
        # astropy units — confirm.
        self.ds.attrs["axis_units"] = [pts0.unit, pts1.unit, pts2.unit]

        if self.is_uniform:
            # A uniform grid is described by one 1D axis per dimension, read
            # off along each array's own axis.
            self.ds.coords["ax0"] = pts0[:, 0, 0]
            self.ds.coords["ax1"] = pts1[0, :, 0]
            self.ds.coords["ax2"] = pts2[0, 0, :]

        else:
            # Non-uniform grid: index every point by its (ax0, ax1, ax2)
            # triple along a single flattened "ax" dimension.
            mdx = pd.MultiIndex.from_arrays(
                [pts0.flatten(), pts1.flatten(), pts2.flatten()],
                names=["ax0", "ax1", "ax2"],
            )
            self.ds.coords["ax"] = mdx

        # Add quantities
        for key, quantity in kwargs.items():
            self.add_quantity(key, quantity)

        # Check to make sure that the object created satisfies any
        # requirements: eg. units correspond to the coordinate system
        self._validate()
Example #2
0
    def _load_grid(
        self,
        pts0: u.Quantity,
        pts1: u.Quantity,
        pts2: u.Quantity,
    ):
        r"""
        Initialize the grid object from a user-supplied grid.

        Parameters
        ----------
        pts0, pts1, pts2 : `~astropy.units.Quantity` array, shape (n0, n1, n2)
            Grids of coordinate positions. All three arrays must have the
            same shape.

        Raises
        ------
        ValueError
            If ``pts0``, ``pts1`` and ``pts2`` do not all have the same shape.

        Notes
        -----
        This method only sets up the coordinates of ``self.ds``; it takes no
        grid quantities.
        """
        points = (pts0, pts1, pts2)

        # Validate input
        if pts0.shape != pts1.shape or pts0.shape != pts2.shape:
            raise ValueError(
                "Provided arrays of grid points are of unequal "
                f"shape: pts0 = {pts0.shape}, "
                f"pts1 = {pts1.shape}, "
                f"pts2 = {pts2.shape}."
            )

        self._is_uniform = _detect_is_uniform_grid(pts0, pts1, pts2)

        # Create dataset
        self.ds = xr.Dataset()

        self.ds.attrs["axis_units"] = [p.unit for p in points]

        # Store the conversion factor for each axis to SI: the scale of the
        # axis unit expressed in SI base units (e.g. 0.01 for cm).
        self._si_factors = [p.unit.si.scale for p in points]

        if self.is_uniform:
            # A uniform grid is described by one 1D axis per dimension, read
            # off along each array's own axis.
            self.ds.coords["ax0"] = pts0[:, 0, 0]
            self.ds.coords["ax1"] = pts1[0, :, 0]
            self.ds.coords["ax2"] = pts2[0, 0, :]

        else:
            # Non-uniform grid: index every point by its (ax0, ax1, ax2)
            # triple along a single flattened "ax" dimension.
            mdx = pd.MultiIndex.from_arrays(
                [pts0.flatten(), pts1.flatten(), pts2.flatten()],
                names=["ax0", "ax1", "ax2"],
            )
            self.ds.coords["ax"] = mdx

        # Check to make sure that the object created satisfies any
        # requirements: eg. units correspond to the coordinate system
        self._validate()
Example #3
0
    def add_quantity(self, key: str, quantity: u.Quantity):
        r"""
        Store ``quantity`` in ``self.ds`` as a new DataArray named ``key``.

        Parameters
        ----------
        key : str
            Name under which the DataArray is stored in ``self.ds``.

        quantity : `~astropy.units.Quantity`
            Values defined on the grid. On a non-uniform grid the array is
            flattened before its shape is compared with ``self.shape``.

        Raises
        ------
        ValueError
            If the (possibly flattened) shape of ``quantity`` differs from
            ``self.shape``.
        """
        uniform = self.is_uniform_grid
        if not uniform:
            # Non-uniform grids index all points along one flattened axis.
            quantity = quantity.flatten()
        dims = ["ax0", "ax1", "ax2"] if uniform else ["ax"]

        if quantity.shape != self.shape:
            raise ValueError(
                f"Shape of quantity '{key}' {quantity.shape} "
                f"does not match the grid shape {self.shape}."
            )

        # The unit is kept as an attribute alongside the stored values.
        self.ds[key] = xr.DataArray(
            quantity, dims=dims, attrs={"unit": quantity.unit}
        )