Ejemplo n.º 1
0
    def _draw_catalog(self, catalog, catalog_kwargs, projection, ax):
        """
        Overlay catalog on top of map.

        Parameters
        ----------
        catalog : :py:class:`~numpy.ndarray`
            `catalog` parameter given to :py:meth:`draw`.
        catalog_kwargs : dict
            `catalog_kwargs` parameter given to :py:meth:`draw`.
        projection : :py:class:`pyproj.Proj`
            PyProj projection object.
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to plot on.
        """
        if catalog is not None:
            N_src = catalog.size // 3
            if not (catalog.shape == (3, N_src)):
                raise ValueError(
                    "Parameter[catalog]: expected (3, N_src) array.")

            _, c_lat, c_lon = transform.cart2eq(*catalog)
            c_x, c_y = self._eq2xy(c_lat, c_lon, projection)

            if catalog_kwargs is None:
                catalog_kwargs = dict()

            plot_style = dict(s=400, facecolors="none", edgecolors="w")
            plot_style.update(catalog_kwargs)

            ax.scatter(c_x, c_y, **plot_style)
Ejemplo n.º 2
0
    def _draw_projection(self, projection):
        """
        Build the :py:class:`pyproj.Proj` object used for (lon,lat) <-> (x,y) transforms.

        Most projections accept a reference point around which distortions
        are minimised. It is placed at the mean direction of the grid. This is
        only an approximate centre, since the grid is not always a spherical cap.

        Parameters
        ----------
        projection : str
            `projection` parameter given to :py:meth:`draw`. Case-insensitive;
            one of "lcc", "aeqd", "laea", "robin", "gnom", "healpix".

        Returns
        -------
        proj : :py:class:`pyproj.Proj`

        Raises
        ------
        ValueError
            If `projection` is not one of the supported names.
        """
        # Average over everything but the leading xyz axis:
        # (3, N_height, N_width) for gridded data, (3, N_points) otherwise.
        reduce_axes = (1, 2) if self._is_gridded else 1
        center_xyz = np.mean(self._grid, axis=reduce_axes)
        _, center_lat, center_lon = transform.cart2eq(*center_xyz)
        center_lat, center_lon = self._wrapped_rad2deg(center_lat, center_lon)

        # PROJ name -> whether the projection is also given a reference latitude.
        takes_lat_0 = {
            "lcc": True,  # Lambert Conformal Conic
            "aeqd": True,  # Azimuthal Equi-Distant
            "laea": True,  # Lambert Equal-Area
            "robin": False,  # Robinson (central meridian only)
            "gnom": True,  # Gnomonic
            "healpix": True,  # Hierarchical Equal-Area Pixelisation
        }

        name = projection.lower()
        if name not in takes_lat_0:
            raise ValueError(
                "Parameter[projection] is not a valid projection specifier.")

        # Keyword order mirrors the explicit calls: proj, lon_0, [lat_0], R.
        proj_params = {"proj": name, "lon_0": center_lon}
        if takes_lat_0[name]:
            proj_params["lat_0"] = center_lat
        proj_params["R"] = 1  # sphere radius
        return pyproj.Proj(**proj_params)
Ejemplo n.º 3
0
if __name__ == '__main__':
    # Build, for each requested image, an RGB field restricted to the
    # longitude range spanned by `lon_ticks`, and save all of them to one .npz.
    args = parse_args()

    D = [nn.DataSet.from_file(str(_)) for _ in args['datasets']]
    P = [np.load(_) for _ in args['parameters']]

    # Pixel directions come from the first dataset and are shared by every
    # image, so the longitude mask is computed once, outside the loop.
    R = D[0].R
    _, R_lat, R_lon = transform.cart2eq(*R)
    _, R_lon_d = wrapped_rad2deg(R_lat, R_lon)
    min_lon, max_lon = args['lon_ticks'].min(), args['lon_ticks'].max()
    mask_lon = (min_lon <= R_lon_d) & (R_lon_d <= max_lon)

    # Size the output by the pixels that survive the mask. Sizing by
    # R.shape[1] made `I_all[i] = ...` fail with a broadcast error whenever
    # the mask dropped any pixel.
    N_px = int(np.count_nonzero(mask_lon))
    N_sample = len(args['img_idx'])

    I_all = np.zeros((N_sample, 3, N_px))
    for i, idx_img in tqdm.tqdm(list(enumerate(args['img_idx']))):
        I = get_field(D, P, idx_img, img_type=args['img_type'])
        # Assumes to_RGB returns a (3, R.shape[1]) array -- TODO confirm.
        I_rgb = to_RGB(I)
        I_all[i] = I_rgb[:, mask_lon]

    np.savez(args['out'], I_rgb=I_all)
Ejemplo n.º 4
0
def draw_map(I,
             R,
             lon_ticks,
             catalog=None,
             show_labels=False,
             show_axis=False):
    """
    Draw a per-pixel colour image on a Miller-cylindrical Basemap.

    Each pixel is coloured individually: a colormap is built whose k-th entry
    is column k of `I`. Each vertex of the triangulation is mapped to its own
    entry via ``tripcolor``.

    Parameters
    ----------
    I : :py:class:`~numpy.ndarray`
        (3, N_px) colour channels, one column per pixel. Presumably RGB values
        in [0, 1] as required by a Matplotlib colormap -- TODO confirm.
    R : :py:class:`~numpy.ndarray`
        (3, N_px) Cartesian pixel directions (converted with
        ``transform.cart2eq``).
    lon_ticks : array-like
        Longitudes (degrees) at which meridians are drawn.
    catalog : optional
        Object exposing an ``xyz`` attribute of Cartesian source directions.
        Sources are overlaid as small white dots. Not drawn if None.
    show_labels : bool
        If True, set azimuth / elevation axis labels.
    show_axis : bool
        If True, label parallels / meridians on the left and bottom edges.

    Returns
    -------
    fig : :py:class:`~matplotlib.figure.Figure`
    ax : :py:class:`~matplotlib.axes.Axes`
    """
    import mpl_toolkits.basemap as basemap
    import matplotlib.tri as tri

    _, R_el, R_az = transform.cart2eq(*R)
    R_el, R_az = wrapped_rad2deg(R_el, R_az)
    R_el_min, R_el_max = np.around([np.min(R_el), np.max(R_el)])
    R_az_min, R_az_max = np.around([np.min(R_az), np.max(R_az)])

    fig = plt.figure()
    ax = fig.add_subplot(111)
    bm = basemap.Basemap(projection='mill',
                         llcrnrlat=R_el_min,
                         urcrnrlat=R_el_max,
                         llcrnrlon=R_az_min,
                         urcrnrlon=R_az_max,
                         resolution='c',
                         ax=ax)

    # Basemap label flags are [left, right, top, bottom].
    bm_labels = [1, 0, 0, 1] if show_axis else [0, 0, 0, 0]
    bm.drawparallels(np.linspace(R_el_min, R_el_max, 5),
                     color='w',
                     dashes=[1, 0],
                     labels=bm_labels,
                     labelstyle='+/-',
                     textcolor='#565656',
                     zorder=0,
                     linewidth=2)
    bm.drawmeridians(lon_ticks,
                     color='w',
                     dashes=[1, 0],
                     labels=bm_labels,
                     labelstyle='+/-',
                     textcolor='#565656',
                     zorder=0,
                     linewidth=2)

    if show_labels:
        ax.set_xlabel('Azimuth (degrees)', labelpad=20)
        ax.set_ylabel('Elevation (degrees)', labelpad=40)

    R_x, R_y = bm(R_az, R_el)
    triangulation = tri.Triangulation(R_x, R_y)

    # One colormap entry per pixel, taken from this call's `I`. The previous
    # code used an undefined/global `I_rgb` here instead of the parameter.
    N_px = I.shape[1]
    mycmap = cmap_from_list('mycmap', I.T, N=N_px)
    colors_cmap = np.arange(N_px)
    ax.tripcolor(triangulation,
                 colors_cmap,
                 cmap=mycmap,
                 shading='gouraud',
                 alpha=0.9,
                 edgecolors='w',
                 linewidth=0.1)

    if catalog is not None:
        _, sky_el, sky_az = transform.cart2eq(*catalog.xyz)
        sky_el, sky_az = wrapped_rad2deg(sky_el, sky_az)
        sky_x, sky_y = bm(sky_az, sky_el)
        ax.scatter(sky_x, sky_y, c='w', s=5)

    return fig, ax
Ejemplo n.º 5
0
    def _draw_gridlines(self, show_gridlines, grid_kwargs, projection, ax):
        """
        Plot Right-Ascension / Declination lines.

        Parameters
        ----------
        show_gridlines : bool
            `show_gridlines` parameter given to :py:meth:`draw`. When false,
            `grid_kwargs` is still validated but nothing is drawn.
        grid_kwargs : dict
            `grid_kwargs` parameter given to :py:meth:`draw`. The keys below
            are consumed here; all other entries are forwarded to
            :py:meth:`~matplotlib.axes.Axes.plot`. The caller's dict is not
            modified.

            * ``N_parallel`` : int >= 3 (default 3), number of DEC curves.
            * ``N_meridian`` : int >= 3 (default 3), number of RA curves.
            * ``polar_plot`` : bool (default False). If True, RA curves are
              spread over [-180, 180) degrees and N_parallel + 1 DEC curves
              are drawn.
            * ``ticks`` : bool (default False), LAT/LON ticks (not yet
              implemented).
        projection : :py:class:`pyproj.Proj`
            PyProj projection object.
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to plot on.

        Raises
        ------
        ValueError
            If one of the options above has an invalid type or value.
        NotImplementedError
            If both `show_gridlines` and the ``ticks`` option are true.
        """
        # Work on a private copy. The option keys are popped below, and doing
        # that on the caller's dict would silently drop them on the next draw().
        options = dict() if grid_kwargs is None else dict(grid_kwargs)

        N_parallel = options.pop("N_parallel", 3)
        if not (chk.is_integer(N_parallel) and (N_parallel >= 3)):
            raise ValueError("Value[N_parallel] must be at least 3.")

        N_meridian = options.pop("N_meridian", 3)
        if not (chk.is_integer(N_meridian) and (N_meridian >= 3)):
            raise ValueError("Value[N_meridian] must be at least 3.")

        polar_plot = options.pop("polar_plot", False)
        if not chk.is_boolean(polar_plot):
            raise ValueError("Value[polar_plot] must be boolean.")

        # TODO: default to True once ticks are implemented.
        show_ticks = options.pop("ticks", False)
        if not chk.is_boolean(show_ticks):
            raise ValueError("Value[ticks] must be boolean.")

        # Fail before drawing anything rather than after the lines are plotted.
        if show_gridlines and show_ticks:
            raise NotImplementedError("Not yet implemented.")
        if not show_gridlines:
            return

        plot_style = dict(alpha=0.5, color="k", linewidth=1, linestyle="solid")
        plot_style.update(options)

        _, grid_lat, grid_lon = transform.cart2eq(*self._grid)
        grid_lat, grid_lon = self._wrapped_rad2deg(grid_lat, grid_lon)

        def _plot_curve(lat_deg, lon_deg):
            # Project one (lat, lon) curve given in degrees and draw it.
            x, y = self._eq2xy(np.deg2rad(lat_deg), np.deg2rad(lon_deg),
                               projection)
            ax.plot(x, y, **plot_style)

        # RA curves: constant longitude, latitude sweeps the grid's range.
        dec_span = np.linspace(grid_lat.min(), grid_lat.max(), 200)
        if polar_plot:
            ra = np.linspace(-180, 180, N_meridian, endpoint=False)
        else:
            ra = np.linspace(grid_lon.min(), grid_lon.max(), N_meridian)
        for ra_value in ra:
            _plot_curve(dec_span, np.full_like(dec_span, ra_value))

        # DEC curves: constant latitude, longitude sweeps the grid's range.
        ra_span = np.linspace(grid_lon.min(), grid_lon.max(), 200)
        N_dec = N_parallel + 1 if polar_plot else N_parallel
        dec = np.linspace(grid_lat.min(), grid_lat.max(), N_dec)
        for dec_value in dec:
            _plot_curve(np.full_like(ra_span, dec_value), ra_span)
Ejemplo n.º 6
0
    def _draw_data(self, index, data_kwargs, use_contours, projection, ax):
        """
        Contour plot of data.

        The selected images are summed along the first axis of ``self._data``
        and drawn on `ax`.

        Parameters
        ----------
        index : int, array-like(int), slice
            `index` parameter given to :py:meth:`draw`.
        data_kwargs : dict
            `data_kwargs` parameter given to :py:meth:`draw`. An optional
            ``cmap`` entry (colormap name or object, default "RdPu") is
            consumed here; the rest is forwarded to the plotting call. The
            caller's dict is not modified.
        use_contours: bool
            If :py:obj:`True`, use [tri]contourf() to produce the plots.
            If :py:obj:`False`, use [tri]pcolor[mesh]() to produce the plots.
        projection : :py:class:`~pyproj.Proj`
            PyProj projection object.
        ax : :py:class:`~matplotlib.axes.Axes`
            Axes to plot on.

        Returns
        -------
        scm : :py:class:`~matplotlib.cm.ScalarMappable`

        Raises
        ------
        ValueError
            If `index` selects no image or is out of bounds.
        """
        # Copy so that popping "cmap" below does not mutate the caller's dict.
        data_kwargs = dict() if data_kwargs is None else dict(data_kwargs)

        N_image = self.shape[0]
        if chk.is_integer(index):
            index = np.array([index], dtype=int)
        elif chk.has_integers(index):
            index = np.array(index, dtype=int)
        else:  # slice()
            index = np.arange(N_image, dtype=int)[index]
        # Apply to every selection form, not just slices, so an empty list
        # cannot silently produce an all-zero image.
        if index.size == 0:
            raise ValueError("No data-cube slice chosen.")
        if not np.all((0 <= index) & (index < N_image)):
            raise ValueError("Parameter[index] is out of bounds.")
        data = np.sum(self._data[index], axis=0)

        _, grid_lat, grid_lon = transform.cart2eq(*self._grid)
        grid_x, grid_y = self._eq2xy(grid_lat, grid_lon, projection)

        # Colormap choice: names are resolved, objects are used as-is.
        cmap = data_kwargs.pop("cmap", "RdPu")
        if chk.is_instance(str)(cmap):
            try:
                # Registry available since Matplotlib 3.5; `cm.get_cmap` was
                # removed in Matplotlib 3.9.
                from matplotlib import colormaps
            except ImportError:
                cmap = cm.get_cmap(cmap)
            else:
                cmap = colormaps[cmap]

        if self._is_gridded:
            if use_contours:
                # cmap.N levels: one filled band per colormap entry.
                scm = ax.contourf(grid_x,
                                  grid_y,
                                  data,
                                  cmap.N,
                                  cmap=cmap,
                                  **data_kwargs)
            else:
                scm = ax.pcolormesh(grid_x,
                                    grid_y,
                                    data,
                                    cmap=cmap,
                                    **data_kwargs)
        else:
            # Scattered points: triangulate in projected coordinates.
            triangulation = tri.Triangulation(grid_x, grid_y)
            if use_contours:
                scm = ax.tricontourf(triangulation,
                                     data,
                                     cmap.N,
                                     cmap=cmap,
                                     **data_kwargs)
            else:
                scm = ax.tripcolor(triangulation,
                                   data,
                                   cmap=cmap,
                                   **data_kwargs)

        # Show coordinates in status bar
        def sexagesimal_coords(x, y):
            # Inverse-project the cursor position and format it as RA (hms) /
            # DEC (dms).
            lon, lat = projection(x, y, errcheck=False, inverse=True)
            lon = (coord.Angle(lon * u.deg).wrap_at(180 * u.deg).to_string(
                unit=u.hourangle, sep="hms"))
            lat = coord.Angle(lat * u.deg).to_string(unit=u.degree, sep="dms")

            msg = f"RA: {lon}, DEC: {lat}"
            return msg

        ax.format_coord = sexagesimal_coords

        return scm