예제 #1
0
# =========
# Author : Sepand KASHANI [[email protected]]
# ##############################################################################
"""
Linear algebra routines.
"""

import numpy as np
import scipy.linalg as linalg

import imot_tools.util.argcheck as chk


@chk.check(
    dict(
        A=chk.accept_any(chk.has_reals, chk.has_complex),
        B=chk.allow_None(chk.accept_any(chk.has_reals, chk.has_complex)),
        tau=chk.is_real,
        N=chk.allow_None(chk.is_integer),
    ))
def eigh(A, B=None, tau=1, N=None):
    """
    Solve a generalized eigenvalue problem.

    Finds :math:`(D, V)`, solution of the generalized eigenvalue problem

    .. math::

       A V = B V D.

    This function is a wrapper around :py:func:`scipy.linalg.eigh` that adds energy truncation and
예제 #2
0
class Image:
    r"""
    Container for storing real-valued images defined on :math:`\mathbb{S}^{2}`.

    Main features:

    * import/export spherical maps to FITS format;
    * advanced 2D plotting based on `Matplotlib <https://matplotlib.org/>`_;

    Examples
    --------
    .. doctest::

       import numpy as np

       import imot_tools.io.s2image as s2image
       import imot_tools.math.func as func
       import imot_tools.math.sphere.grid as grid
       import imot_tools.math.sphere.transform as transform

       # grid settings =======================
       direction = transform.eq2cart(1, lat=np.deg2rad(30), lon=np.deg2rad(20)).reshape(-1)
       FoV = np.deg2rad(60)
       N_height, N_width = 256, 384
       px_grid = grid.uniform(direction, FoV, size=[N_height, N_width])

       # data settings =======================
       beta0, a0 = 0.7, [1, 1, 1]
       beta1, a1 = 0.9, [0, 0, 1]
       kent0 = func.Kent(k=func.Kent.min_scale(FoV, beta0) * 2,
                         beta=beta0,
                         g1=direction,
                         a=a0)
       kent1 = func.Kent(k=func.Kent.min_scale(FoV, beta1) * 2,
                         beta=beta1,
                         g1=direction,
                         a=a1)

       data0 = (kent0(px_grid.reshape(3, N_height * N_width).T)
                .reshape(N_height, N_width))
       data1 = (kent1(px_grid.reshape(3, N_height * N_width).T)
                .reshape(N_height, N_width))
       data = np.stack([data0, data1], axis=0)

       # Image creation ======================
       I = s2image.Image(data, px_grid)

    Data IO:

    .. doctest::

       I.to_fits('test.fits')  # save to FITS
       I2 = s2image.from_fits('test.fits')  # load from FITS

    Interactive plotting:

    .. doctest::

       kwargs = dict(cmap='jet')
       I.draw(data_kwargs=kwargs)  # AEQD projection by default, all layers.

    .. image:: _img/sphericalimage_aeqd_example.png

    .. doctest::

       kwargs = dict(cmap='jet')
       I.draw(index=0, projection='GNOM', data_kwargs=kwargs)  # Only show first data slice.

    .. image:: _img/sphericalimage_gnom_example.png

    .. doctest::

       kwargs = dict(cmap='jet')
       I.draw(index=1, projection='LCC', data_kwargs=kwargs)

    .. image:: _img/sphericalimage_lcc_example.png
    """
    @chk.check(dict(data=chk.has_reals, grid=chk.has_reals))
    def __init__(self, data, grid):
        """
        Parameters
        ----------
        data : :py:class:`~numpy.ndarray`
            multi-level (float) data-cube.

            Possible shapes are:

            * (N_height, N_width);
            * (N_image, N_height, N_width);
            * (N_points,);
            * (N_image, N_points).
        grid : :py:class:`~numpy.ndarray`
            (3, ...) Cartesian coordinates of the sky on which the data points are defined.

            Possible shapes are:

            * (3, N_height, N_width);
            * (3, N_points).

            Coordinates are normalized to unit norm internally.

        Raises
        ------
        ValueError
            If `grid` has an invalid shape, or if `data` and `grid` are inconsistent.

        Notes
        -----
        For efficiency reasons, `data` is not copied internally when it is already a
        :py:class:`~numpy.ndarray`. `grid` is always stored as a new, normalized array.
        """
        # np.asarray() avoids a copy when possible. np.array(..., copy=False) instead raises
        # on NumPy >= 2.0 whenever a copy is unavoidable (e.g. list inputs).
        grid = np.asarray(grid)
        if (grid.ndim not in (2, 3)) or (grid.shape[0] != 3):
            raise ValueError(
                "Parameter[grid] must have shape (3, N_height, N_width) or (3, N_points)."
            )
        self._is_gridded = grid.ndim == 3
        self._grid = grid / linalg.norm(grid, axis=0)

        # Shape of one image: (N_height, N_width) if gridded, (N_points,) otherwise.
        sample_shape = self._grid.shape[1:]
        data = np.asarray(data)
        if data.shape == sample_shape:  # single image: prepend the N_image axis.
            self._data = data[np.newaxis]
        elif (data.ndim == len(sample_shape) + 1) and (data.shape[1:] == sample_shape):
            self._data = data
        else:
            raise ValueError("Parameters[grid, data] are inconsistent.")

    @property
    def data(self):
        """
        Returns
        -------
        I : :py:class:`~numpy.ndarray`
            (N_image, ...) data cube.
        """
        return self._data

    @property
    def grid(self):
        """
        Returns
        -------
        XYZ : :py:class:`~numpy.ndarray`
            (3, ...) Cartesian coordinates of the grid on which the data points are defined.
        """
        return self._grid

    def to_fits(self, file_name):
        """
        Save image to FITS file.

        Parameters
        ----------
        file_name : path-like
            Name of file. An existing file is overwritten.

        Notes
        -----
        * :py:class:`~imot_tools.io.s2image.Image` subclasses may write WCS information to the FITS
          file.  The user-provided `grid` is assumed in ICRS.  If this is not the case, rotate the
          grid accordingly before calling :py:meth:`~imot_tools.io.s2image.Image.to_fits`.

        * Data cubes are stored in a secondary IMAGE frame and can be viewed with DS9 using::

              $ ds9 <FITS_file>.fits[IMAGE]

          Only gridded maps are successfully visualized with DS9.  Moreover, WCS information is
          only available in select subclasses.
        """
        primary_hdu = self._PrimaryHDU()
        image_hdu = self._ImageHDU()

        hdulist = fits.HDUList([primary_hdu, image_hdu])
        hdulist.writeto(file_name, overwrite=True)

    def _PrimaryHDU(self):
        """
        Generate primary Header Data Unit (HDU) for FITS export.

        The primary HDU stores the grid as (colatitude, longitude) angles, together with the
        name of the concrete Image subclass in the `IMG_TYPE` header keyword.

        Returns
        -------
        hdu : :py:class:`~astropy.io.fits.PrimaryHDU`
        """
        metadata = dict(IMG_TYPE=(self.__class__.__name__, "Image subclass"))

        # grid: stored as angles to reduce file size.
        _, colat, lon = transform.cart2pol(*self._grid)
        coordinates = np.stack([colat, lon], axis=0)

        hdu = fits.PrimaryHDU(data=coordinates)
        for k, v in metadata.items():
            hdu.header[k] = v
        return hdu

    def _ImageHDU(self):
        """
        Generate image Header Data Unit (HDU) for FITS export.

        Returns
        -------
        hdu : :py:class:`~astropy.io.fits.ImageHDU`
            HDU named "IMAGE" holding the (N_image, ...) data cube.
        """
        hdu = fits.ImageHDU(data=self._data, name="IMAGE")
        return hdu

    @classmethod
    @chk.check(
        dict(primary_hdu=chk.is_instance(fits.PrimaryHDU),
             image_hdu=chk.is_instance(fits.ImageHDU)))
    def _from_fits(cls, primary_hdu, image_hdu):
        """
        Load image from Header Data Units.

        Parameters
        ----------
        primary_hdu : :py:class:`~astropy.io.fits.PrimaryHDU`
            HDU holding the (colatitude, longitude) grid angles.
        image_hdu : :py:class:`~astropy.io.fits.ImageHDU`
            HDU holding the data cube.

        Returns
        -------
        I : :py:class:`~imot_tools.io.s2image.Image`
        """
        # PrimaryHDU: grid specification.
        colat, lon = primary_hdu.data
        grid = transform.pol2cart(1, colat, lon)

        # ImageHDU: extract data cube.
        data = image_hdu.data

        I = cls(data=data, grid=grid)
        return I

    @property
    def shape(self):
        """
        Returns
        -------
        sh : tuple
            Shape of data cube.
        """
        return self._data.shape

    @chk.check(
        dict(
            index=chk.accept_any(chk.is_integer, chk.has_integers,
                                 chk.is_instance(slice)),
            projection=chk.is_instance(str),
            catalog=chk.allow_None(chk.has_reals),
            show_gridlines=chk.is_boolean,
            show_colorbar=chk.is_boolean,
            ax=chk.allow_None(chk.is_instance(axes.Axes)),
            use_contours=chk.is_boolean,
            data_kwargs=chk.allow_None(chk.is_instance(dict)),
            grid_kwargs=chk.allow_None(chk.is_instance(dict)),
            catalog_kwargs=chk.allow_None(chk.is_instance(dict)),
        ))
    def draw(
        self,
        index=slice(None),
        projection="AEQD",
        catalog=None,
        show_gridlines=True,
        show_colorbar=True,
        ax=None,
        use_contours=False,
        data_kwargs=None,
        grid_kwargs=None,
        catalog_kwargs=None,
    ):
        """
        Plot spherical image using a 2D projection.

        Parameters
        ----------
        index : int, array-like(int), slice
            Slices of the data-cube to show.

            If multiple layers are provided, they are summed together.
        projection : str
            Plot projection.

            Must be one of (case-insensitive):

            * AEQD: `Azimuthal Equi-Distant <https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection>`_; (default)
            * LAEA: `Lambert Equal-Area <https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection>`_;
            * LCC: `Lambert Conformal Conic <https://en.wikipedia.org/wiki/Lambert_conformal_conic_projection>`_;
            * ROBIN: `Robinson <https://en.wikipedia.org/wiki/Robinson_projection>`_;
            * GNOM: `Gnomonic <https://en.wikipedia.org/wiki/Gnomonic_projection>`_;
            * HEALPIX: `Hierarchical Equal-Area Pixelisation <https://en.wikipedia.org/wiki/HEALPix>`_.

            Notes
            -----
            * (AEQD, LAEA, LCC, GNOM) are recommended for mapping portions of the sphere.

                * LCC breaks down when mapping polar regions.

            * (ROBIN, HEALPIX) are recommended for mapping the entire sphere.
        catalog : :py:class:`~numpy.ndarray`
            (3, N_src) source directions to overlay on top of images. (Default: no overlay)
            The catalog is assumed to lie in the same reference frame as `grid`.
        show_gridlines : bool
            Show RA/DEC gridlines. (Default: True)
        show_colorbar : bool
            Show colorbar. (Default: True)
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to draw on.

            If :py:obj:`None`, a new axes is used.
        use_contours: bool

            * If :py:obj:`True`, use [tri]contourf() to produce the plots.
            * If :py:obj:`False` (default), use [tri]pcolor[mesh]() to produce the plots.
        data_kwargs : dict
            Keyword arguments related to data-cube visualization. (Not modified in place.)

            Depending on `use_contours`, accepted keys are:

            * :py:meth:`~matplotlib.axes.Axes.contourf` / :py:meth:`~matplotlib.axes.Axes.pcolormesh` options;
            * :py:meth:`~matplotlib.axes.Axes.tricontourf` / :py:meth:`~matplotlib.axes.Axes.tripcolor` options.
        grid_kwargs : dict
            Keyword arguments related to grid visualization. (Not modified in place.)

            Accepted keys are:

            * N_parallel : int
                Number of declination lines to show in viewable region. (Default: 3)
            * N_meridian : int
                Number of right-ascension lines to show in viewable region. (Default: 3)
            * polar_plot : bool
                Correct RA/DEC gridlines when mapping polar regions. (Default: False)

                When mapping polar regions, meridian lines may be doubled at 180W/E, making it seem like a meridian line is missing.
                Setting `polar_plot` to :py:obj:`True` redistributes the meridians differently to correct the issue.

                This option only makes sense when mapping polar regions, and will produce incorrect gridlines otherwise.
            * ticks : bool
                Add RA/DEC labels next to gridlines. (Default: False)
                TODO: change to True once implemented

            Any other key is forwarded to :py:meth:`~matplotlib.axes.Axes.plot`.
        catalog_kwargs : dict
            Keyword arguments related to catalog visualization.

            Accepted keys are:

            * :py:meth:`~matplotlib.axes.Axes.scatter` options.

        Returns
        -------
        ax : :py:class:`~matplotlib.axes.Axes`
        proj : :py:class:`pyproj.Proj`
        scm : :py:class:`~matplotlib.cm.ScalarMappable`
        """
        if ax is None:
            _, ax = plt.subplots()

        proj = self._draw_projection(projection)
        scm = self._draw_data(index, data_kwargs, use_contours, proj, ax)
        self._draw_colorbar(show_colorbar, scm, ax)
        self._draw_gridlines(show_gridlines, grid_kwargs, proj, ax)
        self._draw_catalog(catalog, catalog_kwargs, proj, ax)
        self._draw_beautify(proj, ax)

        return ax, proj, scm

    # Projections accepted by `draw(projection=...)` (lower-case PROJ names).
    _PROJECTIONS = ("lcc", "aeqd", "laea", "robin", "gnom", "healpix")

    def _draw_projection(self, projection):
        """
        Setup :py:class:`pyproj.Proj` object to do (lon,lat) <-> (x,y) transforms.

        Parameters
        ----------
        projection : str
            `projection` parameter given to :py:meth:`draw`.

        Returns
        -------
        proj : :py:class:`pyproj.Proj`
            Projection on the unit sphere (R=1), centered on the mean grid direction.

        Raises
        ------
        ValueError
            If `projection` is not one of the supported specifiers.
        """
        # Most projections can be provided a point in space around which distortions are minimized.
        # We choose this point to approximately map to the center of the grid when appropriate.
        # (approximate since it is not always a spherical cap.)
        grid_dir = np.mean(self._grid.reshape(3, -1), axis=1)
        _, grid_lat, grid_lon = transform.cart2eq(*grid_dir)
        grid_lat, grid_lon = self._wrapped_rad2deg(grid_lat, grid_lon)

        p_name = projection.lower()
        if p_name not in self._PROJECTIONS:
            raise ValueError(
                "Parameter[projection] is not a valid projection specifier.")

        proj_kwargs = dict(proj=p_name, lon_0=grid_lon, R=1)
        if p_name != "robin":
            # Robinson is parameterized by its central meridian only; all others also take
            # a central latitude.
            proj_kwargs["lat_0"] = grid_lat
        return pyproj.Proj(**proj_kwargs)

    @chk.check(
        dict(
            index=chk.accept_any(chk.is_integer, chk.has_integers,
                                 chk.is_instance(slice)),
            data_kwargs=chk.allow_None(chk.is_instance(dict)),
            use_contours=chk.is_boolean,
            projection=chk.is_instance(pyproj.Proj),
            ax=chk.is_instance(axes.Axes),
        ))
    def _draw_data(self, index, data_kwargs, use_contours, projection, ax):
        """
        Contour plot of data.

        Parameters
        ----------
        index : int, array-like(int), slice
            `index` parameter given to :py:meth:`draw`.
        data_kwargs : dict
            `data_kwargs` parameter given to :py:meth:`draw`. Not modified in place.
        use_contours: bool
            If :py:obj:`True`, use [tri]contourf() to produce the plots.
            If :py:obj:`False`, use [tri]pcolor[mesh]() to produce the plots.
        projection : :py:class:`~pyproj.Proj`
            PyProj projection object.
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to plot on.

        Returns
        -------
        scm : :py:class:`~matplotlib.cm.ScalarMappable`

        Raises
        ------
        ValueError
            If `index` selects no layer or is out of bounds.
        """
        # Work on a copy: keys are popped below and must not vanish from the caller's dict.
        data_kwargs = dict() if data_kwargs is None else dict(data_kwargs)

        N_image = self.shape[0]
        if chk.is_integer(index):
            index = np.array([index], dtype=int)
        elif chk.has_integers(index):
            index = np.array(index, dtype=int)
        else:  # slice()
            index = np.arange(N_image, dtype=int)[index]
        if index.size == 0:
            raise ValueError("No data-cube slice chosen.")
        if not np.all((0 <= index) & (index < N_image)):
            raise ValueError("Parameter[index] is out of bounds.")
        data = np.sum(self._data[index], axis=0)

        _, grid_lat, grid_lon = transform.cart2eq(*self._grid)
        grid_x, grid_y = self._eq2xy(grid_lat, grid_lon, projection)

        # Colormap choice. plt.get_cmap() replaces matplotlib.cm.get_cmap() (removed in
        # Matplotlib 3.9); `None` resolves to the default colormap.
        cmap = data_kwargs.pop("cmap", "RdPu")
        if (cmap is None) or isinstance(cmap, str):
            cmap = plt.get_cmap(cmap)

        if self._is_gridded:
            if use_contours:
                scm = ax.contourf(grid_x,
                                  grid_y,
                                  data,
                                  cmap.N,
                                  cmap=cmap,
                                  **data_kwargs)
            else:
                scm = ax.pcolormesh(grid_x,
                                    grid_y,
                                    data,
                                    cmap=cmap,
                                    **data_kwargs)
        else:
            triangulation = tri.Triangulation(grid_x, grid_y)
            if use_contours:
                scm = ax.tricontourf(triangulation,
                                     data,
                                     cmap.N,
                                     cmap=cmap,
                                     **data_kwargs)
            else:
                scm = ax.tripcolor(triangulation,
                                   data,
                                   cmap=cmap,
                                   **data_kwargs)

        # Show coordinates in status bar
        def sexagesimal_coords(x, y):
            lon, lat = projection(x, y, errcheck=False, inverse=True)
            lon = (coord.Angle(lon * u.deg).wrap_at(180 * u.deg).to_string(
                unit=u.hourangle, sep="hms"))
            lat = coord.Angle(lat * u.deg).to_string(unit=u.degree, sep="dms")

            msg = f"RA: {lon}, DEC: {lat}"
            return msg

        ax.format_coord = sexagesimal_coords

        return scm

    @chk.check(
        dict(
            show_colorbar=chk.is_boolean,
            scm=chk.is_instance(cm.ScalarMappable),
            ax=chk.is_instance(axes.Axes),
        ))
    def _draw_colorbar(self, show_colorbar, scm, ax):
        """
        Attach colorbar.

        Parameters
        ----------
        show_colorbar : bool
            `show_colorbar` parameter given to :py:meth:`draw`.
        scm : :py:class:`~matplotlib.cm.ScalarMappable`
            Intensity scale.
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to plot on.

        Returns
        -------
        cbar : :py:class:`~matplotlib.colorbar.Colorbar`
            :py:obj:`None` if `show_colorbar` is :py:obj:`False`.
        """
        return plot.colorbar(scm, ax) if show_colorbar else None

    @chk.check(
        dict(
            show_gridlines=chk.is_boolean,
            grid_kwargs=chk.allow_None(chk.is_instance(dict)),
            projection=chk.is_instance(pyproj.Proj),
            ax=chk.is_instance(axes.Axes),
        ))
    def _draw_gridlines(self, show_gridlines, grid_kwargs, projection, ax):
        """
        Plot Right-Ascension / Declination lines.

        Parameters
        ----------
        show_gridlines : bool
            `show_gridlines` parameter given to :py:meth:`draw`.
        grid_kwargs : dict
            `grid_kwargs` parameter given to :py:meth:`draw`. Not modified in place.
        projection : :py:class:`pyproj.Proj`
            PyProj projection object.
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to plot on.

        Raises
        ------
        ValueError
            If a gridline option in `grid_kwargs` is invalid (checked even if gridlines are hidden).
        NotImplementedError
            If gridlines are shown with `ticks=True`.
        """
        # Work on a copy: options are popped below and must not vanish from the caller's dict.
        options = dict() if grid_kwargs is None else dict(grid_kwargs)

        N_parallel = options.pop("N_parallel", 3)
        if not (chk.is_integer(N_parallel) and (N_parallel >= 3)):
            raise ValueError("Value[N_parallel] must be at least 3.")

        N_meridian = options.pop("N_meridian", 3)
        if not (chk.is_integer(N_meridian) and (N_meridian >= 3)):
            raise ValueError("Value[N_meridian] must be at least 3.")

        polar_plot = options.pop("polar_plot", False)
        if not chk.is_boolean(polar_plot):
            raise ValueError("Value[polar_plot] must be boolean.")

        # TODO: change default to True once implemented.
        show_ticks = options.pop("ticks", False)
        if not chk.is_boolean(show_ticks):
            raise ValueError("Value[ticks] must be boolean.")

        if not show_gridlines:  # options validated; nothing to draw.
            return

        # Remaining options are forwarded to Axes.plot().
        plot_style = dict(alpha=0.5, color="k", linewidth=1, linestyle="solid")
        plot_style.update(options)

        _, grid_lat, grid_lon = transform.cart2eq(*self._grid)
        grid_lat, grid_lon = self._wrapped_rad2deg(grid_lat, grid_lon)
        lat_min, lat_max = grid_lat.min(), grid_lat.max()
        lon_min, lon_max = grid_lon.min(), grid_lon.max()

        # RA curves: constant RA, DEC spanning the grid's extent.
        dec_span = np.linspace(lat_min, lat_max, 200)
        if polar_plot:
            ra_values = np.linspace(-180, 180, N_meridian, endpoint=False)
        else:
            ra_values = np.linspace(lon_min, lon_max, N_meridian)
        for ra in ra_values:
            line_x, line_y = self._eq2xy(np.deg2rad(dec_span),
                                         np.deg2rad(np.full_like(dec_span, ra)),
                                         projection)
            ax.plot(line_x, line_y, **plot_style)

        # DEC curves: constant DEC, RA spanning the grid's extent.
        ra_span = np.linspace(lon_min, lon_max, 200)
        N_dec = (N_parallel + 1) if polar_plot else N_parallel
        for dec in np.linspace(lat_min, lat_max, N_dec):
            line_x, line_y = self._eq2xy(np.deg2rad(np.full_like(ra_span, dec)),
                                         np.deg2rad(ra_span),
                                         projection)
            ax.plot(line_x, line_y, **plot_style)

        # LAT/LON ticks
        if show_ticks:
            raise NotImplementedError("Not yet implemented.")

    @chk.check(
        dict(
            catalog=chk.allow_None(chk.has_reals),
            projection=chk.is_instance(pyproj.Proj),
            ax=chk.is_instance(axes.Axes),
        ))
    def _draw_catalog(self, catalog, catalog_kwargs, projection, ax):
        """
        Overlay catalog on top of map.

        Parameters
        ----------
        catalog : :py:class:`~numpy.ndarray`
            `catalog` parameter given to :py:meth:`draw`.
        catalog_kwargs : dict
            `catalog_kwargs` parameter given to :py:meth:`draw`.
        projection : :py:class:`pyproj.Proj`
            PyProj projection object.
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to plot on.

        Raises
        ------
        ValueError
            If `catalog` is not a (3, N_src) array.
        """
        if catalog is None:
            return

        catalog = np.asarray(catalog)
        if not ((catalog.ndim == 2) and (len(catalog) == 3)):
            raise ValueError(
                "Parameter[catalog]: expected (3, N_src) array.")

        _, c_lat, c_lon = transform.cart2eq(*catalog)
        c_x, c_y = self._eq2xy(c_lat, c_lon, projection)

        plot_style = dict(s=400, facecolors="none", edgecolors="w")
        if catalog_kwargs is not None:
            plot_style.update(catalog_kwargs)

        ax.scatter(c_x, c_y, **plot_style)

    @chk.check(
        dict(projection=chk.is_instance(pyproj.Proj),
             ax=chk.is_instance(axes.Axes)))
    def _draw_beautify(self, projection, ax):
        """
        Format plot: hide axis decorations and use an equal aspect ratio.

        Parameters
        ----------
        projection : :py:class:`pyproj.Proj`
            PyProj projection object. (Currently unused.)
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to draw on.
        """
        ax.axis("off")
        ax.axis("equal")

    @classmethod
    def _wrapped_rad2deg(cls, lat_r, lon_r):
        """
        Equatorial coordinate [rad] -> [deg] unit conversion.
        Output longitude guaranteed to lie in [-180, 180) [deg].

        Parameters
        ----------
        lat_r : :py:class:`~numpy.ndarray`
        lon_r : :py:class:`~numpy.ndarray`

        Returns
        -------
        lat_d : :py:class:`~numpy.ndarray`
        lon_d : :py:class:`~numpy.ndarray`
        """
        lat_d = coord.Angle(lat_r * u.rad).to_value(u.deg)
        lon_d = coord.Angle(lon_r * u.rad).wrap_at(180 * u.deg).to_value(u.deg)
        return lat_d, lon_d

    @classmethod
    def _eq2xy(cls, lat_r, lon_r, projection):
        """
        Transform (lon,lat) [rad] to (x,y).

        Some projections have unmappable regions or exhibit singularities at certain points.
        Depending on the PROJ version, such points are reported as 1e30 or as non-finite values.
        Both are replaced with NaN so that these regions are left blank in plots.

        Parameters
        ----------
        lat_r : :py:class:`~numpy.ndarray`
        lon_r : :py:class:`~numpy.ndarray`
        projection : :py:class:`~pyproj.Proj`

        Returns
        -------
        x : :py:class:`~numpy.ndarray`
        y : :py:class:`~numpy.ndarray`
        """
        lat_d, lon_d = cls._wrapped_rad2deg(lat_r, lon_r)
        x, y = projection(lon_d, lat_d, errcheck=False)

        # Convert to float arrays so that scalar outputs can be masked as well.
        x = np.array(x, dtype=float)
        y = np.array(y, dtype=float)
        for arr in (x, y):
            arr[~np.isfinite(arr) | np.isclose(arr, 1e30)] = np.nan
        return x, y
예제 #3
0
# Author : Sepand KASHANI [[email protected]]
# ##############################################################################
"""
Coordinate transforms.
"""

import astropy.coordinates as coord
import astropy.units as u
import numpy as np

import imot_tools.util.argcheck as chk


@chk.check(
    dict(
        r=chk.accept_any(chk.is_real, chk.has_reals),
        colat=chk.accept_any(chk.is_real, chk.has_reals),
        lon=chk.accept_any(chk.is_real, chk.has_reals),
    ))
def pol2eq(r, colat, lon):
    """
    Polar coordinates to Equatorial coordinates.

    Parameters
    ----------
    r : float or :py:class:`~numpy.ndarray`
        Radius.
    colat : :py:class:`~numpy.ndarray`
        Polar/Zenith angle [rad].
    lon : :py:class:`~numpy.ndarray`
        Longitude angle [rad].
예제 #4
0
class Wishart(RandomSampler):
    """
    `Wishart <https://en.wikipedia.org/wiki/Wishart_distribution>`_ distribution.
    """
    @chk.check(
        dict(V=chk.accept_any(chk.has_reals, chk.has_complex),
             n=chk.is_integer))
    def __init__(self, V, n):
        """
        Parameters
        ----------
        V : :py:class:`~numpy.ndarray`
            (p, p) positive-semidefinite scale matrix.
        n : int
            degrees of freedom.

        Raises
        ------
        ValueError
            If `V` is not a (p, p) hermitian positive-semidefinite matrix, or if `n <= p`.
        """
        super().__init__()

        V = np.array(V)
        # Check ndim before len(): len() of a 0-d array raises TypeError.
        if not ((V.ndim == 2) and (V.shape[0] == V.shape[1])
                and np.allclose(V, V.conj().T)):
            raise ValueError("Parameter[V] must be hermitian symmetric.")
        p = len(V)

        # A hermitian matrix is PSD iff its eigenvalues are >= 0. Tolerate small negative
        # eigenvalues caused by round-off. Without this check, sqrtm() below silently
        # produces a complex, meaningless factor for indefinite V.
        eigvals = linalg.eigvalsh(V)
        tol = 1e-8 * max(1.0, float(np.abs(eigvals).max(initial=0)))
        if np.any(eigvals < -tol):
            raise ValueError("Parameter[V] must be positive-semidefinite.")

        if not (n > p):
            raise ValueError(f"Parameter[n] must be greater than {p}.")

        self._V = V
        self._p = p
        self._n = n

        # Factor V = L L^{H}: with Vq = sqrtm(V) (hermitian) and Vq = Q R,
        # V = Vq^{H} Vq = R^{H} R, so L = R^{H} is lower-triangular.
        Vq = linalg.sqrtm(V)
        _, R = linalg.qr(Vq)
        self._L = R.conj().T

    @chk.check("N_sample", chk.is_integer)
    def __call__(self, N_sample=1):
        """
        Generate random samples.

        Parameters
        ----------
        N_sample : int
            Number of samples to generate.

        Returns
        -------
        x : :py:class:`~numpy.ndarray`
            (N_sample, p, p) samples.

        Raises
        ------
        ValueError
            If `N_sample` is not positive.

        Notes
        -----
        Samples are drawn using the `Bartlett Decomposition`_.

        .. _Bartlett Decomposition: https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition
        """
        if N_sample < 1:
            raise ValueError("Parameter[N_sample] must be positive.")

        A = np.zeros((N_sample, self._p, self._p))

        # Diagonal: sqrt of chi-squared variates with n, n-1, ..., n-p+1 degrees of freedom.
        diag_idx = np.diag_indices(self._p)
        df = self._n * np.ones((N_sample, 1)) - np.arange(self._p)
        A[:, diag_idx[0], diag_idx[1]] = np.sqrt(stats.chi2.rvs(df=df))

        # Strictly-lower triangle: independent standard normal variates.
        tril_idx = np.tril_indices(self._p, k=-1)
        size = (N_sample, self._p * (self._p - 1) // 2)
        A[:, tril_idx[0], tril_idx[1]] = stats.norm.rvs(size=size)

        # X = (L A)(L A)^{H} = L A A^{H} L^{H}, batched over the leading sample axis.
        W = self._L @ A
        X = W @ W.conj().transpose(0, 2, 1)
        return X
예제 #5
0
class Interpolator:
    r"""
    Interpolate order-limited zonal function from spatial samples.

    Computes :math:`f(r) = \sum_{q} \alpha_{q} f(r_{q}) K_{N}(\langle r, r_{q} \rangle)`, where
    :math:`r_{q} \in \mathbb{S}^{2}` are points from a spatial sampling scheme, :math:`K_{N}(\cdot)`
    is the spherical Dirichlet kernel of order :math:`N`, and the :math:`\alpha_{q}` are scaling
    factors tailored to the sampling scheme.
    """
    @chk.check(dict(N=chk.is_integer, approximate_kernel=chk.is_boolean))
    def __init__(self, N, approximate_kernel=False):
        r"""
        Parameters
        ----------
        N : int
            Order of the reconstructed zonal function.
        approximate_kernel : bool
            If :py:obj:`True`, pass the `approx` option to :py:class:`~imot_tools.math.func.SphericalDirichlet`.

        Raises
        ------
        ValueError
            If `N` is not positive.
        """
        if not (N > 0):
            raise ValueError("Parameter[N] must be positive.")
        self._N = N
        self._kernel_func = func.SphericalDirichlet(N, approximate_kernel)

    @chk.check(
        dict(
            weight=chk.has_reals,
            support=chk.has_reals,
            f=chk.accept_any(chk.has_reals, chk.has_complex),
            r=chk.has_reals,
            sparsity_mask=chk.allow_None(
                chk.require_all(chk.is_instance(sp.spmatrix),
                                lambda _: np.issubdtype(_.dtype, np.bool_))),
        ))
    def __call__(self, weight, support, f, r, sparsity_mask=None):
        """
        Interpolate function samples at order `N`.

        Parameters
        ----------
        weight : :py:class:`~numpy.ndarray`
            (N_s,) weights to apply per support point.
        support : :py:class:`~numpy.ndarray`
            (3, N_s) critical support points.
        f : :py:class:`~numpy.ndarray`
            (L, N_s) zonal function values at support points. (float or complex)
        r : :py:class:`~numpy.ndarray`
            (3, N_px) evaluation points.
        sparsity_mask : :py:class:`~scipy.sparse.spmatrix`
            (N_s, N_px) sparsity mask (bool) to perform localized kernel evaluation.
            Sparse evaluation is not implemented yet.

        Returns
        -------
        f_interp : :py:class:`~numpy.ndarray`
            (L, N_px) function values at specified coordinates.

        Raises
        ------
        ValueError
            If the parameter shapes are inconsistent.
        NotImplementedError
            If `sparsity_mask` is provided.
        """
        # The argument checks accept array-likes (e.g. lists). Convert them so that the
        # shape checks below do not fail with AttributeError.
        weight = np.asarray(weight)
        support = np.asarray(support)
        f = np.asarray(f)
        r = np.asarray(r)

        if weight.ndim != 1:
            raise ValueError("Parameter[weight] must have shape (N_s,).")
        N_s = weight.size

        if support.shape != (3, N_s):
            raise ValueError("Parameter[support] must have shape (3, N_s).")

        if not ((f.ndim == 2) and (f.shape[1] == N_s)):
            raise ValueError("Parameter[f] must have shape (L, N_s).")

        if not ((r.ndim == 2) and (r.shape[0] == 3)):
            raise ValueError("Parameter[r] must have shape (3, N_px).")
        N_px = r.shape[1]

        if sparsity_mask is not None:
            if sparsity_mask.shape != (N_s, N_px):
                raise ValueError(
                    "Parameter[sparsity_mask] must have shape (N_s, N_px).")
            raise NotImplementedError(
                "Sparse evaluation (Parameter[sparsity_mask]) is not supported yet.")

        # Dense evaluation: (N_s, N_px) inner products <r_q, r>, mapped through the kernel.
        # NOTE(review): assumes _kernel_func acts elementwise on its input.
        kernel = self._kernel_func(support.T @ r)
        beta = f * weight  # (L, N_s): alpha_q * f(r_q)
        f_interp = beta @ kernel
        return f_interp
예제 #6
0
class EqualAngleInterpolator(Interpolator):
    r"""
    Interpolate order-limited zonal function from Equal-Angle samples.

    Computes :math:`f(r) = \sum_{q, l} \alpha_{q} f(r_{q, l}) K_{N}(\langle r, r_{q, l} \rangle)`,
    where :math:`r_{q, l} \in \mathbb{S}^{2}` are points from an Equal-Angle sampling scheme,
    :math:`K_{N}(\cdot)` is the spherical Dirichlet kernel of order :math:`N`, and the
    :math:`\alpha_{q}` are scaling factors tailored to an Equal-Angle sampling scheme.

    Examples
    --------
    Let :math:`\gamma_{N}(r): \mathbb{S}^{2} \to \mathbb{R}` be the order-:math:`N` approximation of
    :math:`\gamma(r) = \delta(r - r_{0})`:

    .. math::

       \gamma_{N}(r) = \frac{N + 1}{4 \pi} \frac{P_{N + 1}(\langle r, r_{0} \rangle) - P_{N}(\langle r, r_{0} \rangle)}{\langle r, r_{0} \rangle -1}.

    As :math:`\gamma_{N}` is order-limited, it can be exactly reconstructed from its samples on an
    order-:math:`N` Equal-Angle grid:

    .. testsetup::

       import numpy as np

       from imot_tools.math.func import SphericalDirichlet
       from imot_tools.math.sphere.grid import equal_angle
       from imot_tools.math.sphere.interpolate import EqualAngleInterpolator
       from imot_tools.math.sphere.transform import pol2cart

       def gammaN(r, r0, N):
           similarity = np.tensordot(r0, r, axes=1)
           d_func = SphericalDirichlet(N)
           return d_func(similarity)

    .. doctest::

       # \gammaN Parameters
       >>> N = 3
       >>> r0 = np.array([1, 0, 0])

       # Solution at Nyquist resolution
       >>> colat_idx, lon_idx, colat_nyquist, lon_nyquist = equal_angle(N)
       >>> N_colat, N_lon = colat_nyquist.size, lon_nyquist.size
       >>> R_nyquist = pol2cart(1, colat_nyquist, lon_nyquist)
       >>> g_nyquist = gammaN(R_nyquist, r0, N)

       # Solution at high resolution
       >>> _, _, colat_dense, lon_dense = equal_angle(2 * N)
       >>> R_dense = pol2cart(1, colat_dense, lon_dense).reshape(3, -1)
       >>> g_exact = gammaN(R_dense, r0, N)

       >>> ea_interp = EqualAngleInterpolator(N)
       >>> g_interp = ea_interp(colat_idx,
       ...                      lon_idx,
       ...                      f=g_nyquist.reshape(1, N_colat, N_lon),
       ...                      r=R_dense)

       >>> np.allclose(g_exact, g_interp)
       True
    """
    @chk.check(dict(N=chk.is_integer, approximate_kernel=chk.is_boolean))
    def __init__(self, N, approximate_kernel=False):
        r"""
        Parameters
        ----------
        N : int
            Order of the reconstructed zonal function.
        approximate_kernel : bool
            If :py:obj:`True`, pass the `approx` option to :py:class:`~imot_tools.math.func.SphericalDirichlet`.
        """
        super().__init__(N, approximate_kernel)

    # TODO: Allow sparse evaluation.
    @chk.check(
        dict(
            colat_idx=chk.has_integers,
            lon_idx=chk.has_integers,
            f=chk.accept_any(chk.has_reals, chk.has_complex),
            r=chk.has_reals,
        ))
    def __call__(self, colat_idx, lon_idx, f, r):
        """
        Interpolate function samples at order `N`.

        Parameters
        ----------
        colat_idx : :py:class:`~numpy.ndarray`
            (N_colat,) polar support indices from :py:func:`~imot_tools.math.sphere.grid.equal_angle`.
        lon_idx : :py:class:`~numpy.ndarray`
            (N_lon,) azimuthal support indices from :py:func:`~imot_tools.math.sphere.grid.equal_angle`.
        f : :py:class:`~numpy.ndarray`
            (L, N_colat, N_lon) zonal function values at support points. (float or complex)
        r : :py:class:`~numpy.ndarray`
            (3, N_px) evaluation points.

        Returns
        -------
        f_interp : :py:class:`~numpy.ndarray`
            (L, N_px) function values at specified coordinates.

        Raises
        ------
        ValueError
            If any parameter has an inconsistent shape.
        """
        N_colat = colat_idx.size
        if not (colat_idx.shape == (N_colat, )):
            raise ValueError(
                "Parameter[colat_idx] must have shape (N_colat,).")

        N_lon = lon_idx.size
        if not (lon_idx.shape == (N_lon, )):
            raise ValueError("Parameter[lon_idx] must have shape (N_lon,).")

        # Test `ndim` first so a 0-d `f` gives a ValueError instead of len()'s TypeError.
        if not ((f.ndim == 3) and (f.shape[1:] == (N_colat, N_lon))):
            raise ValueError(
                "Parameter[f] must have shape (L, N_colat, N_lon).")
        L = f.shape[0]

        if not ((r.ndim == 2) and (r.shape[0] == 3)):
            raise ValueError("Parameter[r] must have shape (3, N_px).")

        # Apply weights directly onto `f` to avoid memory blow-up.
        _, _, colat, lon = grid.equal_angle(self._N)
        colat_sel = colat[colat_idx]  # polar angles of the selected rows (column vector)
        a = np.arange(self._N + 1)
        weight = (np.sum(np.sin((2 * a + 1) * colat_sel) / (2 * a + 1),
                         axis=1) * np.sin(colat_sel[:, 0]) *
                  ((2 * np.pi) / ((self._N + 1)**2)))  # (N_colat,)
        fw = f * weight.reshape((1, N_colat, 1))  # (L, N_colat, N_lon)

        # Quadrature weights are already folded into `fw`, so the base class gets unit weights.
        # Support points and `fw` are both flattened in (colat, lon) row-major order.
        f_interp = super().__call__(
            weight=np.broadcast_to([1], (N_colat * N_lon, )),
            support=transform.pol2cart(1, colat_sel,
                                       lon[:, lon_idx]).reshape((3, -1)),
            f=fw.reshape((L, -1)),
            r=r,
            sparsity_mask=None,
        )
        return f_interp