def test_update_clip_path_change_wcs(self, tmpdir):
    """Draw an image after ``reset_wcs`` has been called on an already-drawn axes.

    Auto axis labels are switched off on both coordinates before the figure
    is returned.
    """
    # Changing the WCS creates a new frame, so the clip path has to be
    # carried over to that new frame.
    figure = plt.figure()
    axes = WCSAxes(figure, [0.1, 0.1, 0.8, 0.8], aspect='equal')
    figure.add_axes(axes)
    axes.set_xlim(0., 2.)
    axes.set_ylim(0., 2.)

    # Saving forces a draw, which freezes the clip path returned by WCSAxes
    scratch_file = tmpdir.join('nothing').strpath
    figure.savefig(scratch_file)

    axes.reset_wcs()

    image = np.zeros((12, 4))
    axes.imshow(image)
    axes.set_xlim(-0.5, 3.5)
    axes.set_ylim(-0.5, 11.5)

    for coord_index in (0, 1):
        axes.coords[coord_index].set_auto_axislabel(False)

    return figure
def test_ticks_labels(self):
    """Build a ``WCSAxes`` without a WCS and customise ticks and labels.

    Tick appearance, tick positions, axis labels and tick labels are set
    separately for the two coordinates; the figure is returned.
    """
    figure = plt.figure(figsize=(6, 6))
    axes = WCSAxes(figure, [0.1, 0.1, 0.7, 0.7], wcs=None)
    figure.add_axes(axes)
    axes.set_xlim(-0.5, 2)
    axes.set_ylim(-0.5, 2)

    x_coord = axes.coords[0]
    y_coord = axes.coords[1]

    # Tick appearance and placement
    x_coord.set_ticks(size=10, color='blue', alpha=0.2, width=1)
    y_coord.set_ticks(size=20, color='red', alpha=0.9, width=1)
    x_coord.set_ticks_position('all')
    y_coord.set_ticks_position('all')

    # Axis labels and their placement
    x_coord.set_axislabel('X-axis', size=20)
    y_coord.set_axislabel('Y-axis', color='green', size=25,
                          weight='regular', style='normal',
                          family='cmtt10')
    x_coord.set_axislabel_position('t')
    y_coord.set_axislabel_position('r')

    # Tick labels and their placement
    x_coord.set_ticklabel(color='purple', size=15, alpha=1,
                          weight='light', style='normal',
                          family='cmss10')
    y_coord.set_ticklabel(color='black', size=18, alpha=0.9,
                          weight='bold', family='cmr10')
    x_coord.set_ticklabel_position('all')
    y_coord.set_ticklabel_position('r')

    return figure
def test_direct_init(self):
    """Create a ``WCSAxes`` directly from a transform plus ``coord_meta``.

    A ``DistanceToLonLat`` transform is passed instead of a WCS; grids, tick
    labels and tick spacing are then configured on the 'lon' and 'lat'
    coordinates and the figure is returned.
    """
    transform = DistanceToLonLat(R=6378.273)

    coord_meta = {
        'type': ('longitude', 'latitude'),
        'wrap': (360., None),
        'unit': (u.deg, u.deg),
        'name': ('lon', 'lat'),
    }

    figure = plt.figure(figsize=(4, 4))
    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7],
                   transform=transform, coord_meta=coord_meta)
    figure.add_axes(axes)

    lon = axes.coords['lon']
    lat = axes.coords['lat']

    lon.grid(color='red', linestyle='solid', alpha=0.3)
    lat.grid(color='blue', linestyle='solid', alpha=0.3)

    lon.set_ticklabel(size=7, exclude_overlapping=True)
    lat.set_ticklabel(size=7, exclude_overlapping=True)

    lon.set_ticklabel_position('brtl')
    lat.set_ticklabel_position('brtl')

    lon.set_ticks(spacing=10. * u.deg)
    lat.set_ticks(spacing=10. * u.deg)

    axes.set_xlim(-400., 500.)
    axes.set_ylim(-300., 400.)

    return figure
def time_basic_plot():
    """Draw a bare ``WCSAxes`` built from ``MSX_WCS`` on a fresh canvas."""
    figure = Figure()
    canvas = FigureCanvas(figure)

    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)

    axes.set_xlim(-0.5, 148.5)
    axes.set_ylim(-0.5, 148.5)

    canvas.draw()
def time_basic_plot_with_grid():
    """Draw a ``WCSAxes`` built from ``MSX_WCS`` with a red coordinate grid."""
    figure = Figure()
    canvas = FigureCanvas(figure)

    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)

    axes.grid(color='red', alpha=0.5, linestyle='solid')

    axes.set_xlim(-0.5, 148.5)
    axes.set_ylim(-0.5, 148.5)

    canvas.draw()
def time_contourf_with_transform():
    """Draw filled contours of ``DATA`` through the ``TWOMASS_WCS`` transform.

    The axes themselves are built from ``MSX_WCS``.
    """
    figure = Figure()
    canvas = FigureCanvas(figure)

    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)

    twomass_transform = axes.get_transform(TWOMASS_WCS)
    axes.contourf(DATA, transform=twomass_transform)

    # These limits keep the contours in the middle of the result
    axes.set_xlim(32.5, 150.5)
    axes.set_ylim(-64.5, 64.5)

    canvas.draw()
def time_basic_plot_with_grid_and_overlay():
    """Draw a gridded ``WCSAxes`` plus a dotted purple 'fk5' overlay grid."""
    figure = Figure()
    canvas = FigureCanvas(figure)

    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)

    axes.grid(color='red', alpha=0.5, linestyle='solid')

    axes.set_xlim(-0.5, 148.5)
    axes.set_ylim(-0.5, 148.5)

    fk5_overlay = axes.get_coords_overlay('fk5')
    fk5_overlay.grid(color='purple', ls='dotted')

    canvas.draw()
def test_coords_overlay(self):
    """Overlay a lon/lat grid on axes whose WCS is a plain distance grid.

    The base tick labels are hidden and a ``DistanceToLonLat`` overlay with
    explicit ``coord_meta`` supplies the grid, ticks and tick labels. The
    figure is returned.
    """
    # Simple WCS that maps pixels to non-projected distances
    distance_wcs = WCS(naxis=2)
    distance_wcs.wcs.ctype = ['x', 'y']
    distance_wcs.wcs.cunit = ['km', 'km']
    distance_wcs.wcs.crpix = [614.5, 856.5]
    distance_wcs.wcs.cdelt = [6.25, 6.25]
    distance_wcs.wcs.crval = [0., 0.]

    figure = plt.figure(figsize=(4, 4))
    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=distance_wcs)
    figure.add_axes(axes)

    transform = DistanceToLonLat(R=6378.273)

    # Hide the tick labels of the underlying x/y coordinates
    axes.coords['x'].set_ticklabel_position('')
    axes.coords['y'].set_ticklabel_position('')

    coord_meta = {
        'type': ('longitude', 'latitude'),
        'wrap': (360., None),
        'unit': (u.deg, u.deg),
        'name': ('lon', 'lat'),
    }

    overlay = axes.get_coords_overlay(transform, coord_meta=coord_meta)
    overlay.grid(color='red')

    lon = overlay['lon']
    lat = overlay['lat']

    lon.grid(color='red', linestyle='solid', alpha=0.3)
    lat.grid(color='blue', linestyle='solid', alpha=0.3)

    lon.set_ticklabel(size=7, exclude_overlapping=True)
    lat.set_ticklabel(size=7, exclude_overlapping=True)

    lon.set_ticklabel_position('brtl')
    lat.set_ticklabel_position('brtl')

    lon.set_ticks(spacing=10. * u.deg)
    lat.set_ticks(spacing=10. * u.deg)

    axes.set_xlim(-0.5, 1215.5)
    axes.set_ylim(-0.5, 1791.5)

    return figure
def test_update_clip_path_rectangular(self, tmpdir):
    """Draw an image on axes whose limits change after a forced draw.

    The figure is saved once with the initial limits, then an image is added
    and the limits are changed before the figure is returned.
    """
    figure = plt.figure()
    axes = WCSAxes(figure, [0.1, 0.1, 0.8, 0.8], aspect='equal')
    figure.add_axes(axes)
    axes.set_xlim(0., 2.)
    axes.set_ylim(0., 2.)

    # Saving forces a draw, which freezes the clip path returned by WCSAxes
    scratch_file = tmpdir.join('nothing').strpath
    figure.savefig(scratch_file)

    image = np.zeros((12, 4))
    axes.imshow(image)
    axes.set_xlim(-0.5, 3.5)
    axes.set_ylim(-0.5, 11.5)

    return figure
def test_coords_overlay_auto_coord_meta(self):
    """Add an 'fk5' overlay without passing ``coord_meta`` explicitly.

    The axes are built from ``self.msx_header``; both the base grid and the
    overlay grid are drawn and the figure is returned.
    """
    figure = plt.figure(figsize=(4, 4))
    msx_wcs = WCS(self.msx_header)
    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=msx_wcs)
    figure.add_axes(axes)

    axes.grid(color='red', alpha=0.5, linestyle='solid')

    # coord_meta is set automatically for a named frame
    fk5_overlay = axes.get_coords_overlay('fk5')
    fk5_overlay.grid(color='black', alpha=0.5, linestyle='solid')

    for coord_name in ('ra', 'dec'):
        fk5_overlay[coord_name].set_ticks(color='black')

    axes.set_xlim(-0.5, 148.5)
    axes.set_ylim(-0.5, 148.5)

    return figure
def test_rcparams(self):
    """Build a ``WCSAxes`` inside an ``rc_context`` with custom rcParams.

    The figure is created, configured and returned while the custom
    rcParams are active.
    """
    # Custom rcParams under test
    custom_rc = {
        'axes.labelcolor': 'purple',
        'axes.labelsize': 14,
        'axes.labelweight': 'bold',
        'axes.linewidth': 3,
        'axes.facecolor': '0.5',
        'axes.edgecolor': 'green',
        'xtick.color': 'red',
        'xtick.labelsize': 8,
        'xtick.direction': 'in',
        'xtick.minor.visible': True,
        'xtick.minor.size': 5,
        'xtick.major.size': 20,
        'xtick.major.width': 3,
        'xtick.major.pad': 10,
        'grid.color': 'blue',
        'grid.linestyle': ':',
        'grid.linewidth': 1,
        'grid.alpha': 0.5,
    }

    with rc_context(custom_rc):
        figure = plt.figure(figsize=(6, 6))
        axes = WCSAxes(figure, [0.15, 0.1, 0.7, 0.7], wcs=None)
        figure.add_axes(axes)

        axes.set_xlim(-0.5, 2)
        axes.set_ylim(-0.5, 2)
        axes.grid()

        axes.set_xlabel('X label')
        axes.set_ylabel('Y label')

        for coord_index in (0, 1):
            axes.coords[coord_index].set_ticklabel(exclude_overlapping=True)

        return figure
def test_custom_frame(self):
    """Plot an image inside a ``HexagonalFrame`` and clip it to the frame.

    Axis labels and tick labels for 'glon' go on spines 'a' and 'd'; those
    for 'glat' go on spines 'b', 'c', 'e' and 'f'. The figure is returned.
    """
    msx_wcs = WCS(self.msx_header)

    figure = plt.figure(figsize=(4, 4))
    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7],
                   wcs=msx_wcs, frame_class=HexagonalFrame)
    figure.add_axes(axes)

    axes.coords.grid(color='white')

    image = axes.imshow(np.ones((149, 149)), vmin=0., vmax=2.,
                        origin='lower', cmap=plt.cm.gist_heat)

    # Per-spine minimum padding for the axis labels
    minpad = {'a': 1, 'd': 1, 'b': 2.75, 'c': 2.75, 'e': 2.75, 'f': 2.75}

    glon = axes.coords['glon']
    glat = axes.coords['glat']

    glon.set_axislabel("Longitude", minpad=minpad)
    glon.set_axislabel_position('ad')

    glat.set_axislabel("Latitude", minpad=minpad)
    glat.set_axislabel_position('bcef')

    glon.set_ticklabel_position('ad')
    glat.set_ticklabel_position('bcef')

    # Limits chosen so that no labels overlap
    axes.set_xlim(5.5, 100.5)
    axes.set_ylim(5.5, 110.5)

    # Clip the image to the frame
    image.set_clip_path(axes.coords.frame.patch)

    return figure
def test_update_clip_path_change_wcs(self, tmpdir):
    """Draw an image after ``reset_wcs`` has been called on an already-drawn axes.

    The figure is saved once, the WCS is reset, and an image with new limits
    is added before the figure is returned.
    """
    # Changing the WCS creates a new frame, so the clip path has to be
    # carried over to that new frame.
    figure = plt.figure()
    axes = WCSAxes(figure, [0.1, 0.1, 0.8, 0.8], aspect='equal')
    figure.add_axes(axes)
    axes.set_xlim(0., 2.)
    axes.set_ylim(0., 2.)

    # Saving forces a draw, which freezes the clip path returned by WCSAxes
    scratch_file = tmpdir.join('nothing').strpath
    figure.savefig(scratch_file)

    axes.reset_wcs()

    image = np.zeros((12, 4))
    axes.imshow(image)
    axes.set_xlim(-0.5, 3.5)
    axes.set_ylim(-0.5, 11.5)

    return figure
def plot(LSM, fileName=None, labelBy=None):
    """
    Shows a simple plot of the sky model.

    The circles in the plot are scaled with flux. If the sky model is grouped
    into patches, sources are colored by patch and the patch positions are
    indicated with stars.

    Parameters
    ----------
    LSM : SkyModel object
        Input sky model
    fileName : str, optional
        If given, the plot is saved to a file instead of displayed
    labelBy : str, optional
        One of 'source' or 'patch': label points using source names ('source')
        or patch names ('patch'). If 'patch' is requested but the model has no
        patches, source names are used.

    Raises
    ------
    ImportError
        If matplotlib cannot be imported or configured
    ValueError
        If labelBy is given but is neither 'source' nor 'patch'

    Examples
    --------
    Plot and display to the screen::

        >>> LSM = lsmtool.load('sky.model')
        >>> plot(LSM)

    Plot and save to a PDF file::

        >>> plot(LSM, 'sky_plot.pdf')

    """
    try:
        import os
        if 'DISPLAY' not in os.environ:
            import matplotlib
            # Compare by value (case-insensitively); identity comparison
            # against a string literal is not reliable
            if matplotlib.get_backend().lower() != 'agg':
                matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except Exception as e:
        # Exceptions have no ``message`` attribute on Python 3; format the
        # exception itself
        raise ImportError('PyPlot could not be imported. Plotting is not '
                          'available: {0}'.format(e))

    # WCSAxes is optional: try astropy first, then the standalone package,
    # and fall back to plain axes if neither can be imported
    try:
        from astropy.visualization.wcsaxes import WCSAxes
        hasWCSaxes = True
    except Exception:
        try:
            from wcsaxes import WCSAxes
            hasWCSaxes = True
        except Exception:
            hasWCSaxes = False

    import numpy as np
    from ..operations_lib import radec2xy, makeWCS
    global midRA, midDec, ymin, xmin

    if len(LSM) == 0:
        log.error('Sky model is empty.')
        return

    # Validate labelBy before any figure is created so that a bad value does
    # not leave a half-built figure open
    labelMode = None
    if labelBy is not None:
        labelMode = labelBy.lower()
        if labelMode not in ('source', 'patch'):
            raise ValueError("The labelBy parameter must be one of 'source' "
                             "or 'patch'.")

    fig = plt.figure(1, figsize=(7.66, 7))
    plt.clf()
    x, y, midRA, midDec = LSM._getXY()
    if hasWCSaxes:
        wcs = makeWCS(midRA, midDec)
        ax = WCSAxes(fig, [0.16, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        ax = plt.gca()

    if LSM.hasPatches:
        nsrc = len(LSM.getPatchNames())
    else:
        nsrc = len(LSM)
    sm = plt.cm.ScalarMappable(cmap=plt.cm.Set3,
                               norm=plt.Normalize(vmin=0, vmax=nsrc))
    sm._A = []

    # Set symbol sizes by flux, making sure no symbol is smaller than 50 or
    # larger than 1000
    fluxes = LSM.getColValues('I')
    positiveFluxes = fluxes[fluxes > 0.0]
    if len(positiveFluxes) == 0:
        # Never used as a divisor: only positive fluxes are scaled below
        minflux = 0.0
    else:
        minflux = np.min(positiveFluxes)
    s = []
    for flux in fluxes:
        if flux > 0.0:
            s.append(min(1000.0,
                         (1.0 + 2.0 * np.log10(flux / minflux)) * 50.0))
        else:
            s.append(50.0)

    # Color sources by patch if grouped
    c = [0] * len(LSM)
    cp = []
    if LSM.hasPatches:
        for p, patchName in enumerate(LSM.getPatchNames()):
            indices = LSM.getRowIndex(patchName)
            patchColor = sm.to_rgba(p)
            cp.append(patchColor)
            for ind in indices:
                c[ind] = patchColor
    else:
        c = [sm.to_rgba(0)] * nsrc

    # Plot sources
    if hasWCSaxes:
        ax.set_xlim(np.min(x) - 20, np.max(x) + 20)
        ax.set_ylim(np.min(y) - 20, np.max(y) + 20)
    plt.scatter(x, y, s=s, c=c)

    if LSM.hasPatches:
        RAp, Decp = LSM.getPatchPositions(asArray=True)
        goodInd = np.where((RAp != 0.0) & (Decp != 0.0))
        if len(goodInd[0]) < len(RAp):
            log.info('Some patch positions are unset. Run setPatchPositions() '
                     'before plotting to see patch positions and patch names.')
        xp, yp = radec2xy(RAp[goodInd], Decp[goodInd], midRA, midDec)
        plt.scatter(xp, yp, s=100, c=cp, marker='*')

    # Set axis labels, etc.
    if hasWCSaxes:
        RAAxis = ax.coords['ra']
        RAAxis.set_axislabel('RA', minpad=0.75)
        RAAxis.set_major_formatter('hh:mm:ss')
        DecAxis = ax.coords['dec']
        DecAxis.set_axislabel('Dec', minpad=0.75)
        DecAxis.set_major_formatter('dd:mm:ss')
        ax.coords.grid(color='black', alpha=0.5, linestyle='solid')
    else:
        plt.xlabel("RA (arb. units)")
        plt.ylabel("Dec (arb. units)")

    if labelMode is not None:
        if labelMode == 'patch' and LSM.hasPatches:
            # NOTE(review): xp/yp only cover patches with set positions, so
            # names and positions may not line up when some are unset --
            # confirm against getPatchNames()/getPatchPositions() ordering
            labels = LSM.getPatchNames()
            xls = xp
            yls = yp
        else:
            labels = LSM.getColValues('name')
            xls = x
            yls = y
        for label, xl, yl in zip(labels, xls, yls):
            plt.annotate(label, xy=(xl, yl), xytext=(-2, 2),
                         textcoords='offset points', ha='right', va='bottom')

    # Show RA and Dec under the mouse pointer
    ax.format_coord = formatCoord

    if fileName is not None:
        plt.savefig(fileName)
    else:
        plt.show()
    plt.close(fig)
def plot(self, center, width, s=None, c=None, marker=None, stride=1,
         emin=None, emax=None, label=None, fontsize=18, fig=None, ax=None,
         **kwargs):
    """
    Plot event coordinates from this photon list in a scatter plot,
    optionally restricting the photon energies which are plotted
    and using only a subset of the photons.

    Parameters
    ----------
    center : array-like
        The RA, Dec of the center of the plot in degrees.
    width : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The width of the plot in arcminutes.
    s : integer, optional
        Size of the scatter marker in points^2.
    c : string, optional
        The color of the points.
    marker : string, optional
        The marker to use for the points in the scatter plot.
        Default: 'o'
    stride : integer, optional
        Plot every *stride* events. Default: 1
    emin : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The minimum energy of the photons to plot. Default is
        the minimum energy in the list.
    emax : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The maximum energy of the photons to plot. Default is
        the maximum energy in the list.
    label : string, optional
        The label attached to the scatter points. Default: None
    fontsize : int
        Font size for labels and axes. Default: 18
    fig : :class:`~matplotlib.figure.Figure`, optional
        A Figure instance to plot in. Default: None. If not provided,
        the figure owning *ax* is used when *ax* is given, otherwise
        a new figure is created.
    ax : :class:`~matplotlib.axes.Axes`, optional
        An Axes instance to plot in. Default: None, one will be
        created if not provided. A supplied axes must expose a
        ``wcs`` attribute.

    Returns
    -------
    A tuple of the :class:`~matplotlib.figure.Figure` and the
    :class:`~matplotlib.axes.Axes` objects.
    """
    import matplotlib.pyplot as plt
    from astropy.visualization.wcsaxes import WCSAxes

    if ax is None:
        if fig is None:
            fig = plt.figure(figsize=(10, 10))
        wcs = construct_wcs(center[0], center[1])
        ax = WCSAxes(fig, [0.15, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        if fig is None:
            # Return the figure that actually contains the supplied axes
            # rather than creating an unrelated empty one
            fig = ax.figure
        wcs = ax.wcs

    if emin is None:
        emin = self.energy.value.min()
    else:
        emin = parse_value(emin, "keV")
    if emax is None:
        emax = self.energy.value.max()
    else:
        emax = parse_value(emax, "keV")

    # Select events inside [emin, emax], then thin them by *stride*
    idxs = np.logical_and(self.energy.value >= emin,
                          self.energy.value <= emax)
    ra = self.ra[idxs][::stride].value
    dec = self.dec[idxs][::stride].value
    x, y = wcs.wcs_world2pix(ra, dec, 1)
    ax.scatter(x, y, s=s, c=c, marker=marker, label=label, **kwargs)

    x0, y0 = wcs.wcs_world2pix(center[0], center[1], 1)
    # arcmin -> arcsec; used directly as a pixel extent below
    # NOTE(review): this presumes 1 pixel == 1 arcsec in the WCS -- confirm
    width = parse_value(width, "arcmin") * 60.0
    ax.set_xlim(x0 - 0.5 * width, x0 + 0.5 * width)
    ax.set_ylim(y0 - 0.5 * width, y0 + 0.5 * width)

    # Apply fontsize to the axis labels as well as the tick labels, as
    # documented
    ax.set_xlabel("RA", fontsize=fontsize)
    ax.set_ylabel("Dec", fontsize=fontsize)
    ax.tick_params(axis='both', labelsize=fontsize)
    return fig, ax
def plot_image(img_file, hdu="IMAGE", stretch='linear', vmin=None, vmax=None,
               facecolor='black', center=None, width=None, figsize=(10, 10),
               cmap=None):
    """
    Plot a FITS image created by SOXS using Matplotlib.

    Parameters
    ----------
    img_file : str
        The on-disk FITS image to plot.
    hdu : str or int, optional
        The image extension to plot. Default is "IMAGE"
    stretch : str, optional
        The stretch to apply to the colorbar scale. Options are "linear",
        "log", and "sqrt". Default: "linear"
    vmin : float, optional
        The minimum value of the colorbar. If not set, it will be the
        minimum value in the image.
    vmax : float, optional
        The maximum value of the colorbar. If not set, it will be the
        maximum value in the image.
    facecolor : str, optional
        The color of zero-valued pixels. Default: "black"
    center : array-like
        A 2-element object giving an (RA, Dec) coordinate for the center
        in degrees. If not set, the reference pixel of the image (usually
        the center) is used.
    width : float, optional
        The width of the image in degrees. If not set, the width of the
        entire image will be used.
    figsize : tuple, optional
        A 2-tuple giving the size of the image in inches, e.g. (12, 15).
        Default: (10,10)
    cmap : str, optional
        The colormap to be used. If not set, the default Matplotlib
        colormap will be used.

    Returns
    -------
    A tuple of the :class:`~matplotlib.figure.Figure` and the
    :class:`~matplotlib.axes.Axes` objects.

    Raises
    ------
    RuntimeError
        If *stretch* is not one of "linear", "log" or "sqrt".
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import PowerNorm, LogNorm, Normalize
    from astropy.wcs.utils import proj_plane_pixel_scales
    from astropy.visualization.wcsaxes import WCSAxes

    if stretch == "linear":
        norm = Normalize(vmin=vmin, vmax=vmax)
    elif stretch == "log":
        norm = LogNorm(vmin=vmin, vmax=vmax)
    elif stretch == "sqrt":
        norm = PowerNorm(0.5, vmin=vmin, vmax=vmax)
    else:
        raise RuntimeError(f"'{stretch}' is not a valid stretch!")

    with fits.open(img_file) as f:
        image_hdu = f[hdu]
        w = wcs.WCS(image_hdu.header)
        pix_scale = proj_plane_pixel_scales(w)

        if center is None:
            # CRPIX is 1-based; the image is drawn in 0-based pixel
            # coordinates, matching the origin=0 conversion used below
            center = w.wcs.crpix - 1
        else:
            center = w.wcs_world2pix(center[0], center[1], 0)

        if width is None:
            # Full image extent. The array shape is in (y, x) order.
            dx_pix = image_hdu.shape[1]
            dy_pix = image_hdu.shape[0]
        else:
            dx_pix = width / pix_scale[0]
            dy_pix = width / pix_scale[1]

        fig = plt.figure(figsize=figsize)
        ax = WCSAxes(fig, [0.15, 0.1, 0.8, 0.8], wcs=w)
        fig.add_axes(ax)

        # The data are read while the file is still open
        im = ax.imshow(image_hdu.data, norm=norm, cmap=cmap)
        ax.set_xlim(center[0] - 0.5 * dx_pix, center[0] + 0.5 * dx_pix)
        ax.set_ylim(center[1] - 0.5 * dy_pix, center[1] + 0.5 * dy_pix)
        ax.set_facecolor(facecolor)
        plt.colorbar(im)

    return fig, ax
def plot(self, center, width, s=None, c=None, marker=None, stride=1,
         emin=None, emax=None, label=None, fontsize=18, fig=None, ax=None,
         **kwargs):
    """
    Plot event coordinates from this photon list in a scatter plot,
    optionally restricting the photon energies which are plotted
    and using only a subset of the photons.

    Parameters
    ----------
    center : array-like
        The RA, Dec of the center of the plot in degrees.
    width : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The width of the plot in arcminutes.
    s : integer, optional
        Size of the scatter marker in points^2.
    c : string, optional
        The color of the points.
    marker : string, optional
        The marker to use for the points in the scatter plot.
        Default: 'o'
    stride : integer, optional
        Plot every *stride* events. Default: 1
    emin : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The minimum energy of the photons to plot. Default is
        the minimum energy in the list.
    emax : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The maximum energy of the photons to plot. Default is
        the maximum energy in the list.
    label : string, optional
        The label attached to the scatter points. Default: None
    fontsize : int
        Font size for labels and axes. Default: 18
    fig : :class:`~matplotlib.figure.Figure`, optional
        A Figure instance to plot in. Default: None. If not provided,
        the figure owning *ax* is used when *ax* is given, otherwise
        a new figure is created.
    ax : :class:`~matplotlib.axes.Axes`, optional
        An Axes instance to plot in. Default: None, one will be
        created if not provided. A supplied axes must expose a
        ``wcs`` attribute.

    Returns
    -------
    A tuple of the :class:`~matplotlib.figure.Figure` and the
    :class:`~matplotlib.axes.Axes` objects.
    """
    import matplotlib.pyplot as plt
    from astropy.visualization.wcsaxes import WCSAxes

    if ax is None:
        if fig is None:
            fig = plt.figure(figsize=(10, 10))
        wcs = construct_wcs(center[0], center[1])
        ax = WCSAxes(fig, [0.15, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        if fig is None:
            # Return the figure that actually contains the supplied axes
            # rather than creating an unrelated empty one
            fig = ax.figure
        wcs = ax.wcs

    if emin is None:
        emin = self.energy.value.min()
    else:
        emin = parse_value(emin, "keV")
    if emax is None:
        emax = self.energy.value.max()
    else:
        emax = parse_value(emax, "keV")

    # Select events inside [emin, emax], then thin them by *stride*
    idxs = np.logical_and(self.energy.value >= emin,
                          self.energy.value <= emax)
    ra = self.ra[idxs][::stride].value
    dec = self.dec[idxs][::stride].value
    x, y = wcs.wcs_world2pix(ra, dec, 1)
    ax.scatter(x, y, s=s, c=c, marker=marker, label=label, **kwargs)

    x0, y0 = wcs.wcs_world2pix(center[0], center[1], 1)
    # arcmin -> arcsec; used directly as a pixel extent below
    # NOTE(review): this presumes 1 pixel == 1 arcsec in the WCS -- confirm
    width = parse_value(width, "arcmin") * 60.0
    ax.set_xlim(x0 - 0.5 * width, x0 + 0.5 * width)
    ax.set_ylim(y0 - 0.5 * width, y0 + 0.5 * width)

    # Apply fontsize to the axis labels as well as the tick labels, as
    # documented
    ax.set_xlabel("RA", fontsize=fontsize)
    ax.set_ylabel("Dec", fontsize=fontsize)
    ax.tick_params(axis='both', labelsize=fontsize)
    return fig, ax
def plot_image_rs(map_in, xrange=(-1.4, 1.4), yrange=(-1.4, 1.4), log_min=0.5,
                  log_max=3.5, cmap_name=None, outfile=None, dpi=100,
                  save_interactive=False):
    """
    Quick method to plot a map by specifying the x and y range in SOLAR
    coordinates.

    - xrange and yrange are two-element lists or tuples that specify the
      solar coords in Rs
      - e.g. xrange=[-1.3, 1.3], yrange=[-1.3, 1.3]
    - log_min and log_max are the base-10 logarithms of the lower and upper
      limits of the logarithmic color scale.
    - if an output file is specified, it will switch to a non-interactive
      backend and save the file without showing the plot (unless
      save_interactive=True).
    - dpi is passed to ``savefig`` when an output file is written.
    - cmap_name (optional) is a string that specifies a sunpy or matplotlib
      colormap
    """
    # Never modify the input map: work on a deep copy
    smap = copy.deepcopy(map_in)

    # info from the map
    rs_obs = smap.rsun_obs

    # get the coordinate positions of the x and y ranges
    # NOTE(review): presumes rsun_obs is expressed in arcsec -- confirm
    x0 = xrange[0] * rs_obs.value * u.arcsec
    x1 = xrange[1] * rs_obs.value * u.arcsec
    y0 = yrange[0] * rs_obs.value * u.arcsec
    y1 = yrange[1] * rs_obs.value * u.arcsec
    bot_left = SkyCoord(x0, y0, frame=smap.coordinate_frame)
    top_right = SkyCoord(x1, y1, frame=smap.coordinate_frame)

    # Two styles of selecting the x and y window. Using "limits" lets you
    # plot outside of the image window, which can be important for aligning
    # images.
    plot_method = 'limits'
    if plot_method == 'submap':
        smap = smap.submap(bot_left, top_right)

    # Resolve the colormap (the optional override, or the map's own setting)
    # and copy it, so that changing the "bad" color below cannot alter a
    # colormap instance shared with other plots
    if cmap_name is not None:
        cmap_source = cmap_name
    else:
        cmap_source = smap.plot_settings['cmap']
    cmap = copy.copy(plt.get_cmap(cmap_source))
    # Undefined values should not show up white
    cmap.set_bad(color='black')
    smap.plot_settings['cmap'] = cmap

    # Set the map plot min/max
    pmin = 10.0 ** log_min
    pmax = 10.0 ** log_max
    smap.plot_settings['norm'] = colors.LogNorm(pmin, pmax)

    # if saving a file, don't use the interactive backend
    if outfile is not None and not save_interactive:
        matplotlib.use(mpl_backend_non_interactive)

    # setup the figure
    fig = plt.figure(figsize=(10, 9))

    # Manually specify the axis (vs. getting it through map.plot); this way
    # you have more control. The axes must be explicitly added to the figure.
    axis = WCSAxes(fig, [0.1, 0.025, 0.95, 0.95], wcs=smap.wcs)
    fig.add_axes(axis)

    # plot the image
    smap.plot(axes=axis)

    # example for adjusting the tick spacing (see astropy examples for
    # WCSAxes)
    custom_ticks = False
    if custom_ticks:
        spacing = 500. * u.arcsec
        axis.coords[0].set_ticks(spacing=spacing)
        axis.coords[1].set_ticks(spacing=spacing)

    # if the plot is NOT a submap, compute the pixel positions and change the
    # matplotlib limits
    if plot_method == 'limits':
        pp_bot_left = smap.world_to_pixel(bot_left)
        pp_top_right = smap.world_to_pixel(top_right)
        axis.set_xlim(left=pp_bot_left.x.value, right=pp_top_right.x.value)
        axis.set_ylim(bottom=pp_bot_left.y.value, top=pp_top_right.y.value)

    # plot the colorbar
    plt.colorbar()

    # save the plot (optional)
    if outfile is not None:
        print("Saving image plot to: " + outfile)
        fig.savefig(outfile, dpi=dpi)
        if not save_interactive:
            plt.close()
            # revert to the default MPL backend
            matplotlib.use(mpl_backend_default)
        else:
            plt.show()
    else:
        plt.show()
def plot(LSM, fileName=None, labelBy=None):
    """
    Shows a simple plot of the sky model.

    The circles in the plot are scaled with flux. If the sky model is grouped
    into patches, sources are colored by patch and the patch positions are
    indicated with stars.

    Parameters
    ----------
    LSM : SkyModel object
        Input sky model
    fileName : str, optional
        If given, the plot is saved to a file instead of displayed
    labelBy : str, optional
        One of 'source' or 'patch': label points using source names ('source')
        or patch names ('patch'). If 'patch' is requested but the model has no
        patches, source names are used.

    Raises
    ------
    ImportError
        If matplotlib cannot be imported or configured
    ValueError
        If labelBy is given but is neither 'source' nor 'patch'

    Examples
    --------
    Plot and display to the screen::

        >>> LSM = lsmtool.load('sky.model')
        >>> plot(LSM)

    Plot and save to a PDF file::

        >>> plot(LSM, 'sky_plot.pdf')

    """
    try:
        import os
        if 'DISPLAY' not in os.environ:
            import matplotlib
            # Compare by value (case-insensitively); identity comparison
            # against a string literal is not reliable
            if matplotlib.get_backend().lower() != 'agg':
                matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except Exception as e:
        # Exceptions have no ``message`` attribute on Python 3; format the
        # exception itself
        raise ImportError('PyPlot could not be imported. Plotting is not '
                          'available: {0}'.format(e))

    # WCSAxes is optional: try astropy first, then the standalone package,
    # and fall back to plain axes if neither can be imported
    try:
        from astropy.visualization.wcsaxes import WCSAxes
        hasWCSaxes = True
    except Exception:
        try:
            from wcsaxes import WCSAxes
            hasWCSaxes = True
        except Exception:
            hasWCSaxes = False

    import numpy as np
    from ..operations_lib import radec2xy, makeWCS
    global midRA, midDec, ymin, xmin

    if len(LSM) == 0:
        log.error('Sky model is empty.')
        return

    # Validate labelBy before any figure is created so that a bad value does
    # not leave a half-built figure open
    labelMode = None
    if labelBy is not None:
        labelMode = labelBy.lower()
        if labelMode not in ('source', 'patch'):
            raise ValueError("The labelBy parameter must be one of 'source' "
                             "or 'patch'.")

    fig = plt.figure(1, figsize=(7.66, 7))
    plt.clf()
    x, y, midRA, midDec = LSM._getXY()
    if hasWCSaxes:
        wcs = makeWCS(midRA, midDec)
        ax = WCSAxes(fig, [0.16, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        ax = plt.gca()

    if LSM.hasPatches:
        nsrc = len(LSM.getPatchNames())
    else:
        nsrc = len(LSM)
    sm = plt.cm.ScalarMappable(cmap=plt.cm.Set3,
                               norm=plt.Normalize(vmin=0, vmax=nsrc))
    sm._A = []

    # Set symbol sizes by flux, making sure no symbol is smaller than 50 or
    # larger than 1000
    fluxes = LSM.getColValues('I')
    positiveFluxes = fluxes[fluxes > 0.0]
    if len(positiveFluxes) == 0:
        # Never used as a divisor: only positive fluxes are scaled below
        minflux = 0.0
    else:
        minflux = np.min(positiveFluxes)
    s = []
    for flux in fluxes:
        if flux > 0.0:
            s.append(min(1000.0,
                         (1.0 + 2.0 * np.log10(flux / minflux)) * 50.0))
        else:
            s.append(50.0)

    # Color sources by patch if grouped
    c = [0] * len(LSM)
    cp = []
    if LSM.hasPatches:
        for p, patchName in enumerate(LSM.getPatchNames()):
            indices = LSM.getRowIndex(patchName)
            patchColor = sm.to_rgba(p)
            cp.append(patchColor)
            for ind in indices:
                c[ind] = patchColor
    else:
        c = [sm.to_rgba(0)] * nsrc

    # Plot sources
    if hasWCSaxes:
        ax.set_xlim(np.min(x) - 20, np.max(x) + 20)
        ax.set_ylim(np.min(y) - 20, np.max(y) + 20)
    plt.scatter(x, y, s=s, c=c)

    if LSM.hasPatches:
        RAp, Decp = LSM.getPatchPositions(asArray=True)
        goodInd = np.where((RAp != 0.0) & (Decp != 0.0))
        if len(goodInd[0]) < len(RAp):
            log.info('Some patch positions are unset. Run setPatchPositions() '
                     'before plotting to see patch positions and patch names.')
        xp, yp = radec2xy(RAp[goodInd], Decp[goodInd], midRA, midDec)
        plt.scatter(xp, yp, s=100, c=cp, marker='*')

    # Set axis labels, etc.
    if hasWCSaxes:
        RAAxis = ax.coords['ra']
        RAAxis.set_axislabel('RA', minpad=0.75)
        RAAxis.set_major_formatter('hh:mm:ss')
        DecAxis = ax.coords['dec']
        DecAxis.set_axislabel('Dec', minpad=0.75)
        DecAxis.set_major_formatter('dd:mm:ss')
        ax.coords.grid(color='black', alpha=0.5, linestyle='solid')
    else:
        plt.xlabel("RA (arb. units)")
        plt.ylabel("Dec (arb. units)")

    if labelMode is not None:
        if labelMode == 'patch' and LSM.hasPatches:
            # NOTE(review): xp/yp only cover patches with set positions, so
            # names and positions may not line up when some are unset --
            # confirm against getPatchNames()/getPatchPositions() ordering
            labels = LSM.getPatchNames()
            xls = xp
            yls = yp
        else:
            labels = LSM.getColValues('name')
            xls = x
            yls = y
        for label, xl, yl in zip(labels, xls, yls):
            plt.annotate(label, xy=(xl, yl), xytext=(-2, 2),
                         textcoords='offset points', ha='right', va='bottom')

    # Show RA and Dec under the mouse pointer
    ax.format_coord = formatCoord

    if fileName is not None:
        plt.savefig(fileName)
    else:
        plt.show()
    plt.close(fig)