# ========= # Author : Sepand KASHANI [[email protected]] # ############################################################################## """ Linear algebra routines. """ import numpy as np import scipy.linalg as linalg import imot_tools.util.argcheck as chk @chk.check( dict( A=chk.accept_any(chk.has_reals, chk.has_complex), B=chk.allow_None(chk.accept_any(chk.has_reals, chk.has_complex)), tau=chk.is_real, N=chk.allow_None(chk.is_integer), )) def eigh(A, B=None, tau=1, N=None): """ Solve a generalized eigenvalue problem. Finds :math:`(D, V)`, solution of the generalized eigenvalue problem .. math:: A V = B V D. This function is a wrapper around :py:func:`scipy.linalg.eigh` that adds energy truncation and
class Image:
    r"""
    Container for storing real-valued images defined on :math:`\mathbb{S}^{2}`.

    Main features:

    * import/export spherical maps to FITS format;
    * advanced 2D plotting based on `Matplotlib <https://matplotlib.org/>`_;

    Examples
    --------
    .. doctest::

       import numpy as np
       import imot_tools.io.s2image as s2image
       import imot_tools.math.func as func
       import imot_tools.math.sphere.grid as grid
       import imot_tools.math.sphere.transform as transform

       # grid settings =======================
       direction = transform.eq2cart(1, lat=np.deg2rad(30), lon=np.deg2rad(20)).reshape(-1)
       FoV = np.deg2rad(60)
       N_height, N_width = 256, 384
       px_grid = grid.uniform(direction, FoV, size=[N_height, N_width])

       # data settings =======================
       beta0, a0 = 0.7, [1, 1, 1]
       beta1, a1 = 0.9, [0, 0, 1]
       kent0 = func.Kent(k=func.Kent.min_scale(FoV, beta0) * 2, beta=beta0, g1=direction, a=a0)
       kent1 = func.Kent(k=func.Kent.min_scale(FoV, beta1) * 2, beta=beta1, g1=direction, a=a1)
       data0 = (kent0(px_grid.reshape(3, N_height * N_width).T)
                .reshape(N_height, N_width))
       data1 = (kent1(px_grid.reshape(3, N_height * N_width).T)
                .reshape(N_height, N_width))
       data = np.stack([data0, data1], axis=0)

       # Image creation ======================
       I = s2image.Image(data, px_grid)

    Data IO:

    .. doctest::

       I.to_fits('test.fits')  # save to FITS
       I2 = s2image.from_fits('test.fits')  # load from FITS

    Interactive plotting:

    .. doctest::

       kwargs = dict(cmap='jet')
       I.draw(data_kwargs=kwargs)  # AEQD projection by default, all layers.

    .. image:: _img/sphericalimage_aeqd_example.png

    .. doctest::

       kwargs = dict(cmap='jet')
       I.draw(index=0, projection='GNOM', data_kwargs=kwargs)  # Only show first data slice.

    .. image:: _img/sphericalimage_gnom_example.png

    .. doctest::

       kwargs = dict(cmap='jet')
       I.draw(index=1, projection='LCC', data_kwargs=kwargs)

    .. image:: _img/sphericalimage_lcc_example.png
    """

    @chk.check(dict(data=chk.has_reals, grid=chk.has_reals))
    def __init__(self, data, grid):
        """
        Parameters
        ----------
        data : :py:class:`~numpy.ndarray`
            multi-level (float) data-cube.

            Possible shapes are:

            * (N_height, N_width);
            * (N_image, N_height, N_width);
            * (N_points,);
            * (N_image, N_points).
        grid : :py:class:`~numpy.ndarray`
            (3, ...) Cartesian coordinates of the sky on which the data points are defined.

            Possible shapes are:

            * (3, N_height, N_width);
            * (3, N_points).

        Raises
        ------
        ValueError
            If `grid` does not have one of the accepted shapes, or if `data` is inconsistent
            with `grid`.

        Notes
        -----
        For efficiency reasons, `data` is not copied internally when it already is an array.
        `grid` is stored after normalization to unit norm along its first axis, hence as a new
        array.
        """
        # np.asarray() only copies when it has to. np.array(..., copy=False) is NOT equivalent
        # under NumPy >= 2, where it raises if a copy cannot be avoided (e.g. list inputs).
        grid = np.asarray(grid)
        grid_shape_error_msg = (
            "Parameter[grid] must have shape (3, N_height, N_width) or (3, N_points)."
        )
        # Checking `ndim` first also turns 0-d inputs into the documented ValueError.
        if (grid.ndim not in (2, 3)) or (grid.shape[0] != 3):
            raise ValueError(grid_shape_error_msg)
        self._is_gridded = (grid.ndim == 3)
        self._grid = grid / linalg.norm(grid, axis=0)

        data = np.asarray(data)
        if self._is_gridded:
            N_height, N_width = self._grid.shape[1:]
            if (data.ndim == 2) and chk.has_shape([N_height, N_width])(data):
                self._data = data[np.newaxis]
            elif (data.ndim == 3) and chk.has_shape([N_height, N_width])(data[0]):
                self._data = data
            else:
                raise ValueError("Parameters[grid, data] are inconsistent.")
        else:
            N_points = self._grid.shape[1]
            if (data.ndim == 1) and chk.has_shape([N_points])(data):
                self._data = data[np.newaxis]
            elif (data.ndim == 2) and chk.has_shape([N_points])(data[0]):
                self._data = data
            else:
                raise ValueError("Parameters[grid, data] are inconsistent.")

    @property
    def data(self):
        """
        Returns
        -------
        I : :py:class:`~numpy.ndarray`
            (N_image, ...) data cube.
        """
        return self._data

    @property
    def grid(self):
        """
        Returns
        -------
        XYZ : :py:class:`~numpy.ndarray`
            (3, ...) Cartesian coordinates of the grid on which the data points are defined.
        """
        return self._grid

    def to_fits(self, file_name):
        """
        Save image to FITS file.

        An existing file with the same name is overwritten.

        Parameters
        ----------
        file_name : path-like
            Name of file.

        Notes
        -----
        * :py:class:`~imot_tools.io.s2image.Image` subclasses may write WCS information to the FITS file.
          The user-provided `grid` is assumed in ICRS.
          If this is not the case, rotate the grid accordingly before calling
          :py:meth:`~imot_tools.io.s2image.Image.to_fits`.

        * Data cubes are stored in a secondary IMAGE frame and can be viewed with DS9 using::

              $ ds9 <FITS_file>.fits[IMAGE]

          Only gridded maps are successfully visualized with DS9.
          Moreover WCS information only available in select subclasses.
        """
        primary_hdu = self._PrimaryHDU()
        image_hdu = self._ImageHDU()

        hdulist = fits.HDUList([primary_hdu, image_hdu])
        hdulist.writeto(file_name, overwrite=True)

    def _PrimaryHDU(self):
        """
        Generate primary Header Descriptor Unit (HDU) for FITS export.

        The HDU payload is the grid, the header carries the name of the image class.

        Returns
        -------
        hdu : :py:class:`~astropy.io.fits.PrimaryHDU`
        """
        metadata = dict(IMG_TYPE=(self.__class__.__name__, "Image subclass"))

        # grid: stored as angles to reduce file size.
        _, colat, lon = transform.cart2pol(*self._grid)
        coordinates = np.stack([colat, lon], axis=0)

        hdu = fits.PrimaryHDU(data=coordinates)
        for k, v in metadata.items():
            hdu.header[k] = v
        return hdu

    def _ImageHDU(self):
        """
        Generate image Header Descriptor Unit (HDU) for FITS export.

        Returns
        -------
        hdu : :py:class:`~astropy.io.fits.ImageHDU`
            HDU named "IMAGE" holding the data cube.
        """
        hdu = fits.ImageHDU(data=self._data, name="IMAGE")
        return hdu

    @classmethod
    @chk.check(
        dict(primary_hdu=chk.is_instance(fits.PrimaryHDU),
             image_hdu=chk.is_instance(fits.ImageHDU)))
    def _from_fits(cls, primary_hdu, image_hdu):
        """
        Load image from Header Descriptor Units.

        Parameters
        ----------
        primary_hdu : :py:class:`~astropy.io.fits.PrimaryHDU`
        image_hdu : :py:class:`~astropy.io.fits.ImageHDU`

        Returns
        -------
        I : :py:class:`~imot_tools.io.s2image.Image`
        """
        # PrimaryHDU: grid specification (angles, see `_PrimaryHDU`).
        colat, lon = primary_hdu.data
        grid = transform.pol2cart(1, colat, lon)

        # ImageHDU: extract data cube.
        data = image_hdu.data

        I = cls(data=data, grid=grid)
        return I

    @property
    def shape(self):
        """
        Returns
        -------
        sh : tuple
            Shape of data cube.
        """
        return self._data.shape

    @chk.check(
        dict(
            index=chk.accept_any(chk.is_integer, chk.has_integers,
                                 chk.is_instance(slice)),
            projection=chk.is_instance(str),
            catalog=chk.allow_None(chk.has_reals),
            show_gridlines=chk.is_boolean,
            show_colorbar=chk.is_boolean,
            ax=chk.allow_None(chk.is_instance(axes.Axes)),
            use_contours=chk.is_boolean,
            data_kwargs=chk.allow_None(chk.is_instance(dict)),
            grid_kwargs=chk.allow_None(chk.is_instance(dict)),
            catalog_kwargs=chk.allow_None(chk.is_instance(dict)),
        ))
    def draw(
            self,
            index=slice(None),
            projection="AEQD",
            catalog=None,
            show_gridlines=True,
            show_colorbar=True,
            ax=None,
            use_contours=False,
            data_kwargs=None,
            grid_kwargs=None,
            catalog_kwargs=None,
    ):
        """
        Plot spherical image using a 2D projection.

        Parameters
        ----------
        index : int, array-like(int), slice
            Slices of the data-cube to show.
            If multiple layers are provided, they are summed together.
        projection : str
            Plot projection.
            Must be one of (case-insensitive):

            * AEQD: `Azimuthal Equi-Distant <https://en.wikipedia.org/wiki/Azimuthal_equidistant_projection>`_; (default)
            * LAEA: `Lambert Equal-Area <https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection>`_;
            * LCC: `Lambert Conformal Conic <https://en.wikipedia.org/wiki/Lambert_conformal_conic_projection>`_;
            * ROBIN: `Robinson <https://en.wikipedia.org/wiki/Robinson_projection>`_;
            * GNOM: `Gnomonic <https://en.wikipedia.org/wiki/Gnomonic_projection>`_;
            * HEALPIX: `Hierarchical Equal-Area Pixelisation <https://en.wikipedia.org/wiki/HEALPix>`_.

            Notes
            -----
            * (AEQD, LAEA, LCC, GNOM) are recommended for mapping portions of the sphere.

                * LCC breaks down when mapping polar regions.

            * (ROBIN, HEALPIX) are recommended for mapping the entire sphere.
        catalog : :py:class:`~numpy.ndarray`
            (3, N_src) source directions to overlay on top of images. (Default: no overlay)
            The catalog is assumed to lie in the same reference frame as `grid`.
        show_gridlines : bool
            Show RA/DEC gridlines. (Default: True)
        show_colorbar : bool
            Show colorbar. (Default: True)
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to draw on.
            If :py:obj:`None`, a new axes is used.
        use_contours: bool
            * If :py:obj:`True`, use [tri]contourf() to produce the plots.
            * If :py:obj:`False` (default), use [tri]pcolor[mesh]() to produce the plots.
        data_kwargs : dict
            Keyword arguments related to data-cube visualization.
            The dictionary is not modified.

            Depending on `use_contours`, accepted keys are:

            * :py:meth:`~matplotlib.axes.Axes.contourf` / :py:meth:`~matplotlib.axes.Axes.pcolormesh` options;
            * :py:meth:`~matplotlib.axes.Axes.tricontourf` / :py:meth:`~matplotlib.axes.Axes.tripcolor` options.
        grid_kwargs : dict
            Keyword arguments related to grid visualization.
            The dictionary is not modified.

            Accepted keys are:

            * N_parallel : int
                Number declination lines to show in viewable region. (Default: 3)
            * N_meridian : int
                Number of right-ascension lines to show in viewable region. (Default: 3)
            * polar_plot : bool
                Correct RA/DEC gridlines when mapping polar regions. (Default: False)

                When mapping polar regions, meridian lines may be doubled at 180W/E, making it
                seem like a meridian line is missing.
                Setting `polar_plot` to :py:obj:`True` redistributes the meridians differently to
                correct the issue.

                This option only makes sense when mapping polar regions, and will produce
                incorrect gridlines otherwise.
            * ticks : bool
                Add RA/DEC labels next to gridlines. (Default: False)
                TODO: change to True once implemented
        catalog_kwargs : dict
            Keyword arguments related to catalog visualization.

            Accepted keys are:

            * :py:meth:`~matplotlib.axes.Axes.scatter` options.

        Returns
        -------
        ax : :py:class:`~matplotlib.axes.Axes`
        proj : :py:class:`pyproj.Proj`
        scm : :py:class:`~matplotlib.cm.ScalarMappable`
        """
        if ax is None:
            _, ax = plt.subplots()

        proj = self._draw_projection(projection)
        scm = self._draw_data(index, data_kwargs, use_contours, proj, ax)
        self._draw_colorbar(show_colorbar, scm, ax)
        self._draw_gridlines(show_gridlines, grid_kwargs, proj, ax)
        self._draw_catalog(catalog, catalog_kwargs, proj, ax)
        self._draw_beautify(proj, ax)

        return ax, proj, scm

    def _draw_projection(self, projection):
        """
        Setup :py:class:`pyproj.Proj` object to do (lon,lat) <-> (x,y) transforms.

        Parameters
        ----------
        projection : str
            `projection` parameter given to :py:meth:`draw`.

        Returns
        -------
        proj : :py:class:`pyproj.Proj`

        Raises
        ------
        ValueError
            If `projection` is not one of the supported specifiers.
        """
        # Most projections can be provided a point in space around which distortions are minimized.
        # We choose this point to approximately map to the center of the grid when appropriate.
        # (approximate since it is not always a spherical cap.)
        if self._is_gridded:  # (3, N_height, N_width) grid
            grid_dir = np.mean(self._grid, axis=(1, 2))
        else:  # (3, N_points) grid
            grid_dir = np.mean(self._grid, axis=1)
        _, grid_lat, grid_lon = transform.cart2eq(*grid_dir)
        grid_lat, grid_lon = self._wrapped_rad2deg(grid_lat, grid_lon)

        p_name = projection.lower()
        # lcc: Lambert Conformal Conic, aeqd: Azimuthal Equi-Distant, laea: Lambert Equal-Area,
        # gnom: Gnomonic, healpix: Hierarchical Equal-Area Pixelisation.
        if p_name in ("lcc", "aeqd", "laea", "gnom", "healpix"):
            proj = pyproj.Proj(proj=p_name, lon_0=grid_lon, lat_0=grid_lat, R=1)
        elif p_name == "robin":  # Robinson: only centered in longitude.
            proj = pyproj.Proj(proj=p_name, lon_0=grid_lon, R=1)
        else:
            raise ValueError(
                "Parameter[projection] is not a valid projection specifier.")

        return proj

    @chk.check(
        dict(
            index=chk.accept_any(chk.is_integer, chk.has_integers,
                                 chk.is_instance(slice)),
            data_kwargs=chk.allow_None(chk.is_instance(dict)),
            use_contours=chk.is_boolean,
            projection=chk.is_instance(pyproj.Proj),
            ax=chk.is_instance(axes.Axes),
        ))
    def _draw_data(self, index, data_kwargs, use_contours, projection, ax):
        """
        Contour plot of data.

        Parameters
        ----------
        index : int, array-like(int), slice
            `index` parameter given to :py:meth:`draw`.
        data_kwargs : dict
            `data_kwargs` parameter given to :py:meth:`draw`.
            The dictionary is not modified.
        use_contours: bool
            If :py:obj:`True`, use [tri]contourf() to produce the plots.
            If :py:obj:`False`, use [tri]pcolor[mesh]() to produce the plots.
        projection : :py:class:`~pyproj.Proj`
            PyProj projection object.
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to plot on.

        Returns
        -------
        scm : :py:class:`~matplotlib.cm.ScalarMappable`

        Raises
        ------
        ValueError
            If `index` selects no layer or is out of bounds.
        """
        # Work on a copy: keys are popped below and the caller may reuse its dictionary.
        data_kwargs = dict() if data_kwargs is None else dict(data_kwargs)

        N_image = self.shape[0]
        if chk.is_integer(index):
            index = np.array([index], dtype=int)
        elif chk.has_integers(index):
            index = np.array(index, dtype=int)
        else:  # slice()
            index = np.arange(N_image, dtype=int)[index]
        if index.size == 0:
            raise ValueError("No data-cube slice chosen.")
        if not np.all((0 <= index) & (index < N_image)):
            raise ValueError("Parameter[index] is out of bounds.")
        data = np.sum(self._data[index], axis=0)

        _, grid_lat, grid_lon = transform.cart2eq(*self._grid)
        grid_x, grid_y = self._eq2xy(grid_lat, grid_lon, projection)

        # Colormap choice: names are resolved to Colormap objects, anything else is used as-is.
        # pyplot.get_cmap() is used since matplotlib.cm.get_cmap() was removed in Matplotlib 3.9.
        cmap = data_kwargs.pop("cmap", "RdPu")
        if chk.is_instance(str)(cmap):
            cmap = plt.get_cmap(cmap)

        if self._is_gridded:
            if use_contours:
                scm = ax.contourf(grid_x, grid_y, data, cmap.N,
                                  cmap=cmap, **data_kwargs)
            else:
                scm = ax.pcolormesh(grid_x, grid_y, data,
                                    cmap=cmap, **data_kwargs)
        else:
            triangulation = tri.Triangulation(grid_x, grid_y)
            if use_contours:
                scm = ax.tricontourf(triangulation, data, cmap.N,
                                     cmap=cmap, **data_kwargs)
            else:
                scm = ax.tripcolor(triangulation, data,
                                   cmap=cmap, **data_kwargs)

        # Show coordinates in status bar
        def sexagesimal_coords(x, y):
            lon, lat = projection(x, y, errcheck=False, inverse=True)
            lon = (coord.Angle(lon * u.deg).wrap_at(180 * u.deg).to_string(
                unit=u.hourangle, sep="hms"))
            lat = coord.Angle(lat * u.deg).to_string(unit=u.degree, sep="dms")

            msg = f"RA: {lon}, DEC: {lat}"
            return msg

        ax.format_coord = sexagesimal_coords

        return scm

    @chk.check(
        dict(
            show_colorbar=chk.is_boolean,
            scm=chk.is_instance(cm.ScalarMappable),
            ax=chk.is_instance(axes.Axes),
        ))
    def _draw_colorbar(self, show_colorbar, scm, ax):
        """
        Attach colorbar.

        Parameters
        ----------
        show_colorbar : bool
            `show_colorbar` parameter given to :py:meth:`draw`.
        scm : :py:class:`~matplotlib.cm.ScalarMappable`
            Intensity scale.
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to plot on.

        Returns
        -------
        cbar : :py:class:`~matplotlib.colorbar.Colorbar`
            :py:obj:`None` if `show_colorbar` is :py:obj:`False`.
        """
        if show_colorbar:
            cbar = plot.colorbar(scm, ax)
        else:
            cbar = None

        return cbar

    @chk.check(
        dict(
            show_gridlines=chk.is_boolean,
            grid_kwargs=chk.allow_None(chk.is_instance(dict)),
            projection=chk.is_instance(pyproj.Proj),
            ax=chk.is_instance(axes.Axes),
        ))
    def _draw_gridlines(self, show_gridlines, grid_kwargs, projection, ax):
        """
        Plot Right-Ascension / Declination lines.

        Parameters
        ----------
        show_gridlines : bool
            `show_gridlines` parameter given to :py:meth:`draw`.
        grid_kwargs : dict
            `grid_kwargs` parameter given to :py:meth:`draw`.
            The dictionary is not modified.
        projection : :py:class:`pyproj.Proj`
            PyProj projection object.
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to plot on.

        Raises
        ------
        ValueError
            If one of the (N_parallel, N_meridian, polar_plot, ticks) options is invalid.
            Options are validated even if `show_gridlines` is :py:obj:`False`.
        NotImplementedError
            If gridlines and ticks are both requested.
        """
        # Work on a copy: keys are popped below and the caller may reuse its dictionary.
        grid_kwargs = dict() if grid_kwargs is None else dict(grid_kwargs)

        N_parallel = grid_kwargs.pop("N_parallel", 3)
        if not (chk.is_integer(N_parallel) and (N_parallel >= 3)):
            raise ValueError("Value[N_parallel] must be at least 3.")

        N_meridian = grid_kwargs.pop("N_meridian", 3)
        if not (chk.is_integer(N_meridian) and (N_meridian >= 3)):
            raise ValueError("Value[N_meridian] must be at least 3.")

        polar_plot = grid_kwargs.pop("polar_plot", False)
        if not chk.is_boolean(polar_plot):
            raise ValueError("Value[polar_plot] must be boolean.")

        # TODO: change default to True once implemented.
        show_ticks = grid_kwargs.pop("ticks", False)
        if not chk.is_boolean(show_ticks):
            raise ValueError("Value[ticks] must be boolean.")

        if not show_gridlines:
            return  # Nothing to draw.

        # LAT/LON ticks: fail before touching the axes rather than after drawing the lines.
        if show_ticks:
            raise NotImplementedError("Not yet implemented.")

        # Remaining keys are forwarded to Axes.plot() and override the default style.
        plot_style = dict(alpha=0.5, color="k", linewidth=1, linestyle="solid")
        plot_style.update(grid_kwargs)

        _, grid_lat, grid_lon = transform.cart2eq(*self._grid)
        grid_lat, grid_lon = self._wrapped_rad2deg(grid_lat, grid_lon)
        lat_min, lat_max = grid_lat.min(), grid_lat.max()
        lon_min, lon_max = grid_lon.min(), grid_lon.max()

        # RA curves
        dec_span = np.linspace(lat_min, lat_max, 200)
        if polar_plot:
            ra = np.linspace(-180, 180, N_meridian, endpoint=False)
        else:
            ra = np.linspace(lon_min, lon_max, N_meridian)
        for ra_value in ra:
            ra_span = np.full_like(dec_span, ra_value)
            grid_x, grid_y = self._eq2xy(np.deg2rad(dec_span),
                                         np.deg2rad(ra_span),
                                         projection)
            ax.plot(grid_x, grid_y, **plot_style)

        # DEC curves
        ra_span = np.linspace(lon_min, lon_max, 200)
        N_dec = (N_parallel + 1) if polar_plot else N_parallel
        dec = np.linspace(lat_min, lat_max, N_dec)
        for dec_value in dec:
            dec_span = np.full_like(ra_span, dec_value)
            grid_x, grid_y = self._eq2xy(np.deg2rad(dec_span),
                                         np.deg2rad(ra_span),
                                         projection)
            ax.plot(grid_x, grid_y, **plot_style)

    @chk.check(
        dict(
            catalog=chk.allow_None(chk.has_reals),
            projection=chk.is_instance(pyproj.Proj),
            ax=chk.is_instance(axes.Axes),
        ))
    def _draw_catalog(self, catalog, catalog_kwargs, projection, ax):
        """
        Overlay catalog on top of map.

        Parameters
        ----------
        catalog : :py:class:`~numpy.ndarray`
            `catalog` parameter given to :py:meth:`draw`.
        catalog_kwargs : dict
            `catalog_kwargs` parameter given to :py:meth:`draw`.
        projection : :py:class:`pyproj.Proj`
            PyProj projection object.
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to plot on.

        Raises
        ------
        ValueError
            If `catalog` is not a (3, N_src) array.
        """
        if catalog is None:
            return  # No overlay requested.

        # Accept any array-like: `.size` and `.shape` are required below.
        catalog = np.asarray(catalog)
        N_src = catalog.size // 3
        if not (catalog.shape == (3, N_src)):
            raise ValueError(
                "Parameter[catalog]: expected (3, N_src) array.")

        _, c_lat, c_lon = transform.cart2eq(*catalog)
        c_x, c_y = self._eq2xy(c_lat, c_lon, projection)

        if catalog_kwargs is None:
            catalog_kwargs = dict()

        # User-provided options override the default marker style.
        plot_style = dict(s=400, facecolors="none", edgecolors="w")
        plot_style.update(catalog_kwargs)

        ax.scatter(c_x, c_y, **plot_style)

    @chk.check(
        dict(projection=chk.is_instance(pyproj.Proj),
             ax=chk.is_instance(axes.Axes)))
    def _draw_beautify(self, projection, ax):
        """
        Format plot.

        Hides the axes decorations and enforces an equal aspect ratio.

        Parameters
        ----------
        projection : :py:class:`pyproj.Proj`
            PyProj projection object. (Currently unused.)
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to draw on.
        """
        ax.axis("off")
        ax.axis("equal")

    @classmethod
    def _wrapped_rad2deg(cls, lat_r, lon_r):
        """
        Equatorial coordinate [rad] -> [deg] unit conversion.
        Output longitude guaranteed to lie in [-180, 180) [deg].

        Parameters
        ----------
        lat_r : :py:class:`~numpy.ndarray`
        lon_r : :py:class:`~numpy.ndarray`

        Returns
        -------
        lat_d : :py:class:`~numpy.ndarray`
        lon_d : :py:class:`~numpy.ndarray`
        """
        lat_d = coord.Angle(lat_r * u.rad).to_value(u.deg)
        lon_d = coord.Angle(lon_r * u.rad).wrap_at(180 * u.deg).to_value(u.deg)
        return lat_d, lon_d

    @classmethod
    def _eq2xy(cls, lat_r, lon_r, projection):
        """
        Transform (lon,lat) [rad] to (x,y).
        Some projections have unmappable regions or exhibit singularities at certain points.
        These regions are colored white in contour plots by replacing their incorrect value
        (1e30 or non-finite, depending on the PyProj version) with NaN.

        Parameters
        ----------
        lat_r : :py:class:`~numpy.ndarray`
        lon_r : :py:class:`~numpy.ndarray`
        projection : :py:class:`~pyproj.Proj`

        Returns
        -------
        x : :py:class:`~numpy.ndarray`
        y : :py:class:`~numpy.ndarray`
        """
        lat_d, lon_d = cls._wrapped_rad2deg(lat_r, lon_r)
        x, y = projection(lon_d, lat_d, errcheck=False)
        # Failed transforms are flagged with 1e30 by legacy PyProj and with `inf` by current
        # releases: mask both, independently per coordinate.
        x[np.isclose(x, 1e30) | ~np.isfinite(x)] = np.nan
        y[np.isclose(y, 1e30) | ~np.isfinite(y)] = np.nan
        return x, y
# Author : Sepand KASHANI [[email protected]] # ############################################################################## """ Coordinate transforms. """ import astropy.coordinates as coord import astropy.units as u import numpy as np import imot_tools.util.argcheck as chk @chk.check( dict( r=chk.accept_any(chk.is_real, chk.has_reals), colat=chk.accept_any(chk.is_real, chk.has_reals), lon=chk.accept_any(chk.is_real, chk.has_reals), )) def pol2eq(r, colat, lon): """ Polar coordinates to Equatorial coordinates. Parameters ---------- r : float or :py:class:`~numpy.ndarray` Radius. colat : :py:class:`~numpy.ndarray` Polar/Zenith angle [rad]. lon : :py:class:`~numpy.ndarray` Longitude angle [rad].
class Wishart(RandomSampler):
    """
    Random sampler for the
    `Wishart <https://en.wikipedia.org/wiki/Wishart_distribution>`_ distribution.
    """

    @chk.check(
        dict(V=chk.accept_any(chk.has_reals, chk.has_complex),
             n=chk.is_integer))
    def __init__(self, V, n):
        """
        Parameters
        ----------
        V : :py:class:`~numpy.ndarray`
            (p, p) positive-semidefinite scale matrix.
        n : int
            degrees of freedom. Must be strictly greater than `p`.

        Raises
        ------
        ValueError
            If `V` is not a (p, p) hermitian matrix, or if `n` is not greater than `p`.
        """
        super().__init__()

        V = np.array(V)
        p = len(V)

        is_square = chk.has_shape([p, p])(V)
        # Short-circuit: hermitian symmetry is only tested on square inputs.
        if not (is_square and np.allclose(V, V.conj().T)):
            raise ValueError("Parameter[V] must be hermitian symmetric.")
        if n <= p:
            raise ValueError(f"Parameter[n] must be greater than {p}.")

        self._V, self._p, self._n = V, p, n

        # Factor V = L @ L^H through the QR decomposition of its matrix square-root.
        _, R = linalg.qr(linalg.sqrtm(V))
        self._L = R.conj().T

    @chk.check("N_sample", chk.is_integer)
    def __call__(self, N_sample=1):
        """
        Generate random samples.

        Parameters
        ----------
        N_sample : int
            Number of samples to generate.

        Returns
        -------
        x : :py:class:`~numpy.ndarray`
            (N_sample, p, p) samples.

        Raises
        ------
        ValueError
            If `N_sample` is smaller than 1.

        Notes
        -----
        The Wishart estimate is obtained using the `Bartlett Decomposition`_.

        .. _Bartlett Decomposition: https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition
        """
        if N_sample < 1:
            raise ValueError("Parameter[N_sample] must be positive.")

        p = self._p
        bartlett = np.zeros((N_sample, p, p))

        # Diagonal: square-roots of chi-squared draws with (n, n-1, ..., n-p+1) degrees of
        # freedom.
        d_row, d_col = np.diag_indices(p)
        dof = self._n * np.ones((N_sample, 1)) - np.arange(p)
        bartlett[:, d_row, d_col] = np.sqrt(stats.chi2.rvs(df=dof))

        # Strictly-lower triangle: standard normal draws.
        l_row, l_col = np.tril_indices(p, k=-1)
        N_lower = p * (p - 1) // 2
        bartlett[:, l_row, l_col] = stats.norm.rvs(size=(N_sample, N_lower))

        W = self._L @ bartlett
        return W @ W.conj().transpose(0, 2, 1)
class Interpolator:
    r"""
    Interpolate order-limited zonal function from spatial samples.

    Computes :math:`f(r) = \sum_{q} \alpha_{q} f(r_{q}) K_{N}(\langle r, r_{q} \rangle)`, where
    :math:`r_{q} \in \mathbb{S}^{2}` are points from a spatial sampling scheme,
    :math:`K_{N}(\cdot)` is the spherical Dirichlet kernel of order :math:`N`, and the
    :math:`\alpha_{q}` are scaling factors tailored to the sampling scheme.
    """

    @chk.check(dict(N=chk.is_integer, approximate_kernel=chk.is_boolean))
    def __init__(self, N, approximate_kernel=False):
        r"""
        Parameters
        ----------
        N : int
            Order of the reconstructed zonal function. Must be positive.
        approximate_kernel : bool
            If :py:obj:`True`, pass the `approx` option to
            :py:class:`~imot_tools.math.func.SphericalDirichlet`.

        Raises
        ------
        ValueError
            If `N` is not positive.
        """
        if N <= 0:
            raise ValueError("Parameter[N] must be positive.")

        self._N = N
        self._kernel_func = func.SphericalDirichlet(N, approximate_kernel)

    @chk.check(
        dict(
            weight=chk.has_reals,
            support=chk.has_reals,
            f=chk.accept_any(chk.has_reals, chk.has_complex),
            r=chk.has_reals,
            sparsity_mask=chk.allow_None(
                chk.require_all(chk.is_instance(sp.spmatrix),
                                lambda mask: np.issubdtype(mask.dtype, np.bool_))),
        ))
    def __call__(self, weight, support, f, r, sparsity_mask=None):
        """
        Interpolate function samples at order `N`.

        Parameters
        ----------
        weight : :py:class:`~numpy.ndarray`
            (N_s,) weights to apply per support point.
        support : :py:class:`~numpy.ndarray`
            (3, N_s) critical support points.
        f : :py:class:`~numpy.ndarray`
            (L, N_s) zonal function values at support points. (float or complex)
        r : :py:class:`~numpy.ndarray`
            (3, N_px) evaluation points.
        sparsity_mask : :py:class:`~scipy.sparse.spmatrix`
            (N_s, N_px) sparsity mask (bool) to perform localized kernel evaluation.

        Returns
        -------
        f_interp : :py:class:`~numpy.ndarray`
            (L, N_px) function values at specified coordinates.

        Raises
        ------
        ValueError
            If the shapes of the parameters are inconsistent.
        NotImplementedError
            If `sparsity_mask` is provided: only dense evaluation is available.
        """
        N_s = weight.size
        if weight.shape != (N_s, ):
            raise ValueError("Parameter[weight] must have shape (N_s,).")

        if support.shape != (3, N_s):
            raise ValueError("Parameter[support] must have shape (3, N_s).")

        L = len(f)
        if f.shape != (L, N_s):
            raise ValueError("Parameter[f] must have shape (L, N_s).")

        if (r.ndim != 2) or (r.shape[0] != 3):
            raise ValueError("Parameter[r] must have shape (3, N_px).")
        N_px = r.shape[1]

        if sparsity_mask is not None:
            if sparsity_mask.shape != (N_s, N_px):
                raise ValueError(
                    "Parameter[sparsity_mask] must have shape (N_s, N_px).")
            # Sparse evaluation
            raise NotImplementedError

        # Dense evaluation: kernel on all (support, evaluation) point similarities.
        similarity = support.T @ r
        kernel = self._kernel_func(similarity)
        weighted_f = f * weight
        return weighted_f @ kernel
class EqualAngleInterpolator(Interpolator):
    r"""
    Interpolate order-limited zonal function from Equal-Angle samples.

    Computes :math:`f(r) = \sum_{q, l} \alpha_{q} f(r_{q, l}) K_{N}(\langle r, r_{q, l} \rangle)`,
    where :math:`r_{q, l} \in \mathbb{S}^{2}` are points from an Equal-Angle sampling scheme,
    :math:`K_{N}(\cdot)` is the spherical Dirichlet kernel of order :math:`N`, and the
    :math:`\alpha_{q}` are scaling factors tailored to an Equal-Angle sampling scheme.

    Examples
    --------
    Let :math:`\gamma_{N}(r): \mathbb{S}^{2} \to \mathbb{R}` be the order-:math:`N` approximation
    of :math:`\gamma(r) = \delta(r - r_{0})`:

    .. math::

       \gamma_{N}(r) = \frac{N + 1}{4 \pi} \frac{P_{N + 1}(\langle r, r_{0} \rangle) - P_{N}(\langle r, r_{0} \rangle)}{\langle r, r_{0} \rangle -1}.

    As :math:`\gamma_{N}` is order-limited, it can be exactly reconstructed from its samples on
    an order-:math:`N` Equal-Angle grid:

    .. testsetup::

       import numpy as np

       from imot_tools.math.func import SphericalDirichlet
       from imot_tools.math.sphere.grid import equal_angle
       from imot_tools.math.sphere.interpolate import EqualAngleInterpolator
       from imot_tools.math.sphere.transform import pol2cart

       def gammaN(r, r0, N):
           similarity = np.tensordot(r0, r, axes=1)
           d_func = SphericalDirichlet(N)
           return d_func(similarity)

    .. doctest::

       # \gammaN Parameters
       >>> N = 3
       >>> r0 = np.array([1, 0, 0])

       # Solution at Nyquist resolution
       >>> colat_idx, lon_idx, colat_nyquist, lon_nyquist = equal_angle(N)
       >>> N_colat, N_lon = colat_nyquist.size, lon_nyquist.size
       >>> R_nyquist = pol2cart(1, colat_nyquist, lon_nyquist)
       >>> g_nyquist = gammaN(R_nyquist, r0, N)

       # Solution at high resolution
       >>> _, _, colat_dense, lon_dense = equal_angle(2 * N)
       >>> R_dense = pol2cart(1, colat_dense, lon_dense).reshape(3, -1)
       >>> g_exact = gammaN(R_dense, r0, N)

       >>> ea_interp = EqualAngleInterpolator(N)
       >>> g_interp = ea_interp(colat_idx,
       ...                      lon_idx,
       ...                      f=g_nyquist.reshape(1, N_colat, N_lon),
       ...                      r=R_dense)

       >>> np.allclose(g_exact, g_interp)
       True
    """

    @chk.check(dict(N=chk.is_integer, approximate_kernel=chk.is_boolean))
    def __init__(self, N, approximate_kernel=False):
        r"""
        Parameters
        ----------
        N : int
            Order of the reconstructed zonal function.
        approximate_kernel : bool
            If :py:obj:`True`, pass the `approx` option to
            :py:class:`~imot_tools.math.func.SphericalDirichlet`.
        """
        super().__init__(N, approximate_kernel)

    # TODO: Allow sparse evaluation.
    @chk.check(
        dict(
            colat_idx=chk.has_integers,
            lon_idx=chk.has_integers,
            f=chk.accept_any(chk.has_reals, chk.has_complex),
            r=chk.has_reals,
        ))
    def __call__(self, colat_idx, lon_idx, f, r):
        """
        Interpolate function samples at order `N`.

        Parameters
        ----------
        colat_idx : :py:class:`~numpy.ndarray`
            (N_colat,) polar support indices from
            :py:func:`~imot_tools.math.sphere.grid.equal_angle`.
        lon_idx : :py:class:`~numpy.ndarray`
            (N_lon,) azimuthal support indices from
            :py:func:`~imot_tools.math.sphere.grid.equal_angle`.
        f : :py:class:`~numpy.ndarray`
            (L, N_colat, N_lon) zonal function values at support points. (float or complex)
        r : :py:class:`~numpy.ndarray`
            (3, N_px) evaluation points.

        Returns
        -------
        f_interp : :py:class:`~numpy.ndarray`
            (L, N_px) function values at specified coordinates.

        Raises
        ------
        ValueError
            If the shapes of the parameters are inconsistent.
        """
        N_colat, N_lon = colat_idx.size, lon_idx.size
        if colat_idx.shape != (N_colat, ):
            raise ValueError(
                "Parameter[colat_idx] must have shape (N_colat,).")
        if lon_idx.shape != (N_lon, ):
            raise ValueError("Parameter[lon_idx] must have shape (N_lon,).")

        L = len(f)
        if f.shape != (L, N_colat, N_lon):
            raise ValueError(
                "Parameter[f] must have shape (L, N_colat, N_lon).")

        if (r.ndim != 2) or (r.shape[0] != 3):
            raise ValueError("Parameter[r] must have shape (3, N_px).")

        # Apply weights directly onto `f` to avoid memory blow-up.
        _, _, colat, lon = grid.equal_angle(self._N)
        support_colat = colat[colat_idx]
        odd = 2 * np.arange(self._N + 1) + 1
        series = np.sum(np.sin(odd * support_colat) / odd, axis=1)
        scale = (2 * np.pi) / ((self._N + 1)**2)
        weight = series * np.sin(colat[colat_idx, 0]) * scale  # (N_colat,)
        fw = f * weight.reshape((1, N_colat, 1))  # (L, N_colat, N_lon)

        # The parent class then only sees unit weights on the flattened support.
        unit_weight = np.broadcast_to([1], (N_colat * N_lon, ))
        support = transform.pol2cart(1, colat[colat_idx, :],
                                     lon[:, lon_idx]).reshape((3, -1))
        return super().__call__(
            weight=unit_weight,
            support=support,
            f=fw.reshape((L, -1)),
            r=r,
            sparsity_mask=None,
        )