def test_direct_init(self):
    """Build a WCSAxes from an explicit pixel->lon/lat transform (no WCS)."""
    transform = DistanceToLonLat(R=6378.273)

    meta = {
        'type': ('longitude', 'latitude'),
        'wrap': (360., None),
        'unit': (u.deg, u.deg),
        'name': ('lon', 'lat'),
    }

    fig = plt.figure(figsize=(4, 4))
    ax = WCSAxes(fig, [0.15, 0.15, 0.7, 0.7], transform=transform,
                 coord_meta=meta)
    fig.add_axes(ax)

    lon = ax.coords['lon']
    lat = ax.coords['lat']

    lon.grid(color='red', linestyle='solid', alpha=0.3)
    lat.grid(color='blue', linestyle='solid', alpha=0.3)

    for coord in (lon, lat):
        coord.set_ticklabel(size=7, exclude_overlapping=True)
    for coord in (lon, lat):
        coord.set_ticklabel_position('brtl')
    for coord in (lon, lat):
        coord.set_ticks(spacing=10. * u.deg)

    ax.set_xlim(-400., 500.)
    ax.set_ylim(-300., 400.)
    return fig
def test_ticks_labels(self):
    """Exercise tick, axis-label and tick-label styling on both coordinates."""
    fig = plt.figure(figsize=(6, 6))
    ax = WCSAxes(fig, [0.1, 0.1, 0.7, 0.7], wcs=None)
    fig.add_axes(ax)
    ax.set_xlim(-0.5, 2)
    ax.set_ylim(-0.5, 2)

    first, second = ax.coords[0], ax.coords[1]

    # Tick marks
    first.set_ticks(size=10, color='blue', alpha=0.2, width=1)
    second.set_ticks(size=20, color='red', alpha=0.9, width=1)
    first.set_ticks_position('all')
    second.set_ticks_position('all')

    # Axis labels
    first.set_axislabel('X-axis', size=20)
    second.set_axislabel('Y-axis', color='green', size=25,
                         weight='regular', style='normal',
                         family='cmtt10')
    first.set_axislabel_position('t')
    second.set_axislabel_position('r')

    # Tick labels
    first.set_ticklabel(color='purple', size=15, alpha=1,
                        weight='light', style='normal',
                        family='cmss10')
    second.set_ticklabel(color='black', size=18, alpha=0.9,
                         weight='bold', family='cmr10')
    first.set_ticklabel_position('all')
    second.set_ticklabel_position('r')

    return fig
def time_basic_plot():
    """Benchmark: draw an empty WCSAxes on the MSX WCS."""
    figure = Figure()
    canvas = FigureCanvas(figure)
    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)
    pixel_range = (-0.5, 148.5)
    axes.set_xlim(*pixel_range)
    axes.set_ylim(*pixel_range)
    canvas.draw()
def test_copy_frame_properties_change_wcs(self):
    """Frame color/linewidth must survive reset_wcs().

    Changing the WCS creates a brand new frame, so the styling applied to the
    old one has to be transferred.
    """
    fig = plt.figure()
    ax = WCSAxes(fig, [0.1, 0.1, 0.8, 0.8])
    fig.add_axes(ax)

    old_frame = ax.coords.frame
    old_frame.set_linewidth(5)
    old_frame.set_color('purple')

    ax.reset_wcs()

    # Look the frame up again: reset_wcs() replaced it.
    new_frame = ax.coords.frame
    assert new_frame.get_linewidth() == 5
    assert new_frame.get_color() == 'purple'
def time_contourf_with_transform():
    """Benchmark: filled contours of 2MASS data drawn onto MSX-WCS axes."""
    figure = Figure()
    canvas = FigureCanvas(figure)
    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)
    data_transform = axes.get_transform(TWOMASS_WCS)
    axes.contourf(DATA, transform=data_transform)
    # Limits chosen so that the contours end up in the middle of the result
    axes.set_xlim(32.5, 150.5)
    axes.set_ylim(-64.5, 64.5)
    canvas.draw()
def time_basic_plot_with_grid_and_overlay():
    """Benchmark: gridded WCSAxes plus an FK5 overlay grid."""
    figure = Figure()
    canvas = FigureCanvas(figure)
    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)
    axes.grid(color='red', alpha=0.5, linestyle='solid')
    pixel_range = (-0.5, 148.5)
    axes.set_xlim(*pixel_range)
    axes.set_ylim(*pixel_range)
    fk5_overlay = axes.get_coords_overlay('fk5')
    fk5_overlay.grid(color='purple', ls='dotted')
    canvas.draw()
def time_basic_plot_with_grid():
    """Benchmark: WCSAxes on the MSX WCS with a coordinate grid."""
    figure = Figure()
    canvas = FigureCanvas(figure)
    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)
    axes.grid(color='red', alpha=0.5, linestyle='solid')
    pixel_range = (-0.5, 148.5)
    axes.set_xlim(*pixel_range)
    axes.set_ylim(*pixel_range)
    canvas.draw()
def test_coords_overlay_auto_coord_meta(self):
    """An overlay for a named frame ('fk5') needs no explicit coord_meta."""
    fig = plt.figure(figsize=(4, 4))
    ax = WCSAxes(fig, [0.15, 0.15, 0.7, 0.7], wcs=WCS(self.msx_header))
    fig.add_axes(ax)

    ax.grid(color='red', alpha=0.5, linestyle='solid')

    fk5 = ax.get_coords_overlay('fk5')  # coord_meta is set automatically
    fk5.grid(color='black', alpha=0.5, linestyle='solid')
    for name in ('ra', 'dec'):
        fk5[name].set_ticks(color='black')

    ax.set_xlim(-0.5, 148.5)
    ax.set_ylim(-0.5, 148.5)
    return fig
def test_coords_overlay(self):
    """Overlay lon/lat coordinates on a WCS of non-projected distances."""
    # A plain linear WCS mapping pixel positions to distances in km
    wcs = WCS(naxis=2)
    params = wcs.wcs
    params.ctype = ['x', 'y']
    params.cunit = ['km', 'km']
    params.crpix = [614.5, 856.5]
    params.cdelt = [6.25, 6.25]
    params.crval = [0., 0.]

    fig = plt.figure(figsize=(4, 4))
    ax = WCSAxes(fig, [0.15, 0.15, 0.7, 0.7], wcs=wcs)
    fig.add_axes(ax)

    to_lonlat = DistanceToLonLat(R=6378.273)

    # Hide the distance labels; only the overlay carries tick labels
    for axis_name in ('x', 'y'):
        ax.coords[axis_name].set_ticklabel_position('')

    meta = {
        'type': ('longitude', 'latitude'),
        'wrap': (360., None),
        'unit': (u.deg, u.deg),
        'name': ('lon', 'lat'),
    }
    overlay = ax.get_coords_overlay(to_lonlat, coord_meta=meta)
    overlay.grid(color='red')

    lon, lat = overlay['lon'], overlay['lat']
    lon.grid(color='red', linestyle='solid', alpha=0.3)
    lat.grid(color='blue', linestyle='solid', alpha=0.3)
    for coord in (lon, lat):
        coord.set_ticklabel(size=7, exclude_overlapping=True)
    for coord in (lon, lat):
        coord.set_ticklabel_position('brtl')
    for coord in (lon, lat):
        coord.set_ticks(spacing=10. * u.deg)

    ax.set_xlim(-0.5, 1215.5)
    ax.set_ylim(-0.5, 1791.5)
    return fig
def test_custom_frame(self):
    """Plot an image inside a HexagonalFrame with per-side label padding."""
    wcs = WCS(self.msx_header)

    fig = plt.figure(figsize=(4, 4))
    ax = WCSAxes(fig, [0.15, 0.15, 0.7, 0.7], wcs=wcs,
                 frame_class=HexagonalFrame)
    fig.add_axes(ax)

    ax.coords.grid(color='white')

    im = ax.imshow(np.ones((149, 149)), vmin=0., vmax=2., origin='lower',
                   cmap=plt.cm.gist_heat)

    # Axis-label padding for each side of the hexagon (keys in the same
    # insertion order as before: a, d, b, c, e, f)
    minpad = {**dict.fromkeys('ad', 1), **dict.fromkeys('bcef', 2.75)}

    glon = ax.coords['glon']
    glat = ax.coords['glat']
    glon.set_axislabel("Longitude", minpad=minpad)
    glon.set_axislabel_position('ad')
    glat.set_axislabel("Latitude", minpad=minpad)
    glat.set_axislabel_position('bcef')
    glon.set_ticklabel_position('ad')
    glat.set_ticklabel_position('bcef')

    # Limits chosen so that no labels overlap
    ax.set_xlim(5.5, 100.5)
    ax.set_ylim(5.5, 110.5)

    # Clip the image to the hexagonal frame
    im.set_clip_path(ax.coords.frame.patch)

    return fig
def test_rcparams(self):
    """WCSAxes should honour custom rcParams for labels, ticks and grid."""
    custom_style = {
        'axes.labelcolor': 'purple',
        'axes.labelsize': 14,
        'axes.labelweight': 'bold',
        'axes.linewidth': 3,
        'axes.facecolor': '0.5',
        'axes.edgecolor': 'green',
        'xtick.color': 'red',
        'xtick.labelsize': 8,
        'xtick.direction': 'in',
        'xtick.minor.visible': True,
        'xtick.minor.size': 5,
        'xtick.major.size': 20,
        'xtick.major.width': 3,
        'xtick.major.pad': 10,
        'grid.color': 'blue',
        'grid.linestyle': ':',
        'grid.linewidth': 1,
        'grid.alpha': 0.5,
    }
    with rc_context(custom_style):
        fig = plt.figure(figsize=(6, 6))
        ax = WCSAxes(fig, [0.15, 0.1, 0.7, 0.7], wcs=None)
        fig.add_axes(ax)
        ax.set_xlim(-0.5, 2)
        ax.set_ylim(-0.5, 2)
        ax.grid()
        ax.set_xlabel('X label')
        ax.set_ylabel('Y label')
        for index in (0, 1):
            ax.coords[index].set_ticklabel(exclude_overlapping=True)
        return fig
import matplotlib.cm as cm
# The script uses plt, fits and pyregion below; import them explicitly so it
# runs on its own.
import matplotlib.pyplot as plt
import pyregion
from photutils import find_peaks
from astropy import wcs
from astropy.io import fits
from astropy.nddata import Cutout2D
from astropy import units as u

filename = './rcutout.fits'
image = fits.open(filename)

try:
    from astropy.wcs import WCS
    from astropy.visualization.wcsaxes import WCSAxes

    # NOTE: this rebinds the name ``wcs`` (imported above as a module) to the
    # image's WCS object.
    wcs = WCS(image[0].header)
    fig = plt.figure()
    ax = WCSAxes(fig, [.1, .1, .8, .8], wcs=wcs)
    fig.add_axes(ax)
except ImportError:
    # Fall back to plain matplotlib axes when WCSAxes is unavailable.
    ax = plt.subplot(111)

ax.imshow(image[0].data, cmap=cm.gray, vmin=0, vmax=0.00038, origin='lower')

region_name = './satpeaks4.reg'
r = pyregion.open(region_name).as_imagecoord(header=image[0].header)

from pyregion.mpl_helper import properties_func_default


def fixed_color(shape, saved_attrs):
    """Region property callback that forces every shape to be drawn in red.

    Parameters
    ----------
    shape :
        The pyregion shape being converted to a matplotlib artist.
    saved_attrs : tuple
        ``(attr_list, attr_dict)`` as passed by pyregion; ``attr_dict`` is
        modified in place to set ``"color"`` to ``"red"``.

    Returns
    -------
    dict
        Matplotlib keyword arguments for the artist. pyregion expects this
        value back; returning ``None`` breaks artist creation.
    """
    attr_list, attr_dict = saved_attrs
    attr_dict["color"] = "red"
    kwargs = properties_func_default(shape, (attr_list, attr_dict))
    return kwargs
def plot_image_rs_full(map_in, xrange=(-1.4, 1.4), yrange=(-1.4, 1.4),
                       log_min=0.5, log_max=3.5, cmap_name=None, outfile=None,
                       dpi=100, save_interactive=False):
    """
    Quick method to plot a map by specifying the x and y range in SOLAR coordinates.

    - Unlike plot_image_rs, here the image fills the entire frame, with no outside
      annotations like axes labels or colorbars
    - xrange and yrange are two elements lists or tuples that specify the solar coords in Rs
      - e.g. xrange=[-1.3, 1.3], yrange=[-1.3, 1.3]
    - if a output file is specified, it will switch to a non-interactive backend
      and save the file without showing the plot (unless save_interactive=True).
    - cmap_name (optional) is a string that specifies a sunpy or matplotlib colormap
    - the display is log-scaled between 10**log_min and 10**log_max

    ToDo:
    - put white annotations in the corners that describe the image (time, inst, clon, b0)
    - overplot some solar grid lines (e.g. the lat=0 line and/or various clons)
    """
    # Work on a deep copy so the caller's map (including plot_settings) is
    # never modified. Named smap to avoid shadowing the builtin ``map``.
    smap = copy.deepcopy(map_in)

    # info from the map
    rs_obs = smap.rsun_obs

    # get the coordinate positions of the x and y ranges
    x0 = xrange[0] * rs_obs.value * u.arcsec
    x1 = xrange[1] * rs_obs.value * u.arcsec
    y0 = yrange[0] * rs_obs.value * u.arcsec
    y1 = yrange[1] * rs_obs.value * u.arcsec
    bot_left = SkyCoord(x0, y0, frame=smap.coordinate_frame)
    top_right = SkyCoord(x1, y1, frame=smap.coordinate_frame)

    # Two ways to window the plot: "submap" crops the data, while "limits"
    # only changes the axis limits, which lets you plot outside of the image
    # window (important for aligning images).
    plot_method = 'limits'
    if plot_method == 'submap':
        smap = smap.submap(bot_left, top_right)

    # setup the optional colormap. Copy it: plt.get_cmap may return the shared
    # registered instance, and set_bad() below would otherwise mutate it
    # globally.
    if cmap_name is not None:
        cmap = copy.copy(plt.get_cmap(cmap_name))
        smap.plot_settings['cmap'] = cmap

    # Set the map plot min/max (log scale)
    pmin = 10.0**(log_min)
    pmax = 10.0**(log_max)
    smap.plot_settings['norm'] = colors.LogNorm(pmin, pmax)

    # Change the colormap so undefined values don't show up white
    smap.plot_settings['cmap'].set_bad(color='black')

    # if saving a file, don't use the interactive backend
    if outfile is not None and not save_interactive:
        matplotlib.use(mpl_backend_non_interactive)

    # setup the figure
    fig = plt.figure(figsize=(9, 9))

    # Manually specify the axis (vs. getting through map.plot) for more control;
    # the axes have to be explicitly added to the figure.
    axis = WCSAxes(fig, [0.0, 0.0, 1.0, 1.0], wcs=smap.wcs)
    fig.add_axes(axis)

    # plot the image
    smap.plot(axes=axis)

    # example for adjusting the tick spacing (see astropy examples for WCSAxes)
    custom_ticks = True
    if custom_ticks:
        spacing = smap.rsun_obs
        axis.coords[0].set_ticks(spacing=spacing)
        axis.coords[1].set_ticks(spacing=spacing)

    # if plot is NOT a submap, compute the pixel positions and change the
    # matplotlib limits
    if plot_method == 'limits':
        pp_bot_left = smap.world_to_pixel(bot_left)
        pp_top_right = smap.world_to_pixel(top_right)
        axis.set_xlim(left=pp_bot_left.x.value, right=pp_top_right.x.value)
        axis.set_ylim(bottom=pp_bot_left.y.value, top=pp_top_right.y.value)

    # save the plot (optional)
    if outfile is not None:
        print("Saving image plot to: " + outfile)
        fig.savefig(outfile, dpi=dpi)
        # revert to the default MPL backend
        if not save_interactive:
            plt.close()
            matplotlib.use(mpl_backend_default)
        else:
            plt.show()
    else:
        plt.show()
def plot(LSM, fileName=None, labelBy=None):
    """
    Shows a simple plot of the sky model.

    The circles in the plot are scaled with flux. If the sky model is grouped
    into patches, sources are colored by patch and the patch positions are
    indicated with stars.

    Parameters
    ----------
    LSM : SkyModel object
        Input sky model
    fileName : str, optional
        If given, the plot is saved to a file instead of displayed
    labelBy : str, optional
        One of 'source' or 'patch': label points using source names ('source')
        or patch names ('patch')

    Raises
    ------
    ImportError
        If matplotlib.pyplot cannot be imported
    ValueError
        If labelBy is not one of 'source' or 'patch'

    Examples:
    ---------
    Plot and display to the screen::

        >>> LSM = lsmtool.load('sky.model')
        >>> plot(LSM)

    Plot and save to a PDF file::

        >>> plot(LSM, 'sky_plot.pdf')

    """
    try:
        import os
        if 'DISPLAY' not in os.environ:
            # No display available: use the non-interactive backend.
            import matplotlib
            # Compare by value; ``is not`` on a string literal tests identity.
            if matplotlib.get_backend() != 'Agg':
                matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except Exception as e:
        # Python 3 exceptions have no ``.message`` attribute; use str(e).
        raise ImportError('PyPlot could not be imported. Plotting is not '
                          'available: {0}'.format(e))
    try:
        from astropy.visualization.wcsaxes import WCSAxes
        hasWCSaxes = True
    except Exception:
        # Older installs provide WCSAxes as a standalone package.
        try:
            from wcsaxes import WCSAxes
            hasWCSaxes = True
        except Exception:
            hasWCSaxes = False
    import numpy as np
    from ..operations_lib import radec2xy, makeWCS
    global midRA, midDec, ymin, xmin

    if len(LSM) == 0:
        log.error('Sky model is empty.')
        return

    fig = plt.figure(1, figsize=(7.66, 7))
    plt.clf()

    x, y, midRA, midDec = LSM._getXY()
    if hasWCSaxes:
        wcs = makeWCS(midRA, midDec)
        ax = WCSAxes(fig, [0.16, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        ax = plt.gca()
    if LSM.hasPatches:
        nsrc = len(LSM.getPatchNames())
    else:
        nsrc = len(LSM)
    sm = plt.cm.ScalarMappable(cmap=plt.cm.Set3,
                               norm=plt.Normalize(vmin=0, vmax=nsrc))
    sm._A = []

    # Set symbol sizes by flux, making sure no symbol is smaller than 50 or
    # larger than 1000
    fluxes = LSM.getColValues('I')
    positive_fluxes = fluxes[fluxes > 0.0]
    if len(positive_fluxes) == 0:
        minflux = 0.0
    else:
        minflux = np.min(positive_fluxes)
    s = []
    for flux in fluxes:
        if flux > 0.0:
            s.append(min(1000.0, (1.0 + 2.0 * np.log10(flux / minflux)) * 50.0))
        else:
            s.append(50.0)

    # Color sources by patch if grouped
    c = [0] * len(LSM)
    cp = []
    if LSM.hasPatches:
        for p, patchName in enumerate(LSM.getPatchNames()):
            indices = LSM.getRowIndex(patchName)
            cp.append(sm.to_rgba(p))
            for ind in indices:
                c[ind] = sm.to_rgba(p)
    else:
        c = [sm.to_rgba(0)] * nsrc

    # Plot sources
    if hasWCSaxes:
        ax.set_xlim(np.min(x) - 20, np.max(x) + 20)
        ax.set_ylim(np.min(y) - 20, np.max(y) + 20)
    plt.scatter(x, y, s=s, c=c)

    if LSM.hasPatches:
        RAp, Decp = LSM.getPatchPositions(asArray=True)
        # Patches at (0, 0) are treated as having no position set.
        goodInd = np.where((RAp != 0.0) & (Decp != 0.0))
        if len(goodInd[0]) < len(RAp):
            log.info('Some patch positions are unset. Run setPatchPositions() '
                     'before plotting to see patch positions and patch names.')
        xp, yp = radec2xy(RAp[goodInd], Decp[goodInd], midRA, midDec)
        plt.scatter(xp, yp, s=100, c=cp, marker='*')

    # Set axis labels, etc.
    if hasWCSaxes:
        RAAxis = ax.coords['ra']
        RAAxis.set_axislabel('RA', minpad=0.75)
        RAAxis.set_major_formatter('hh:mm:ss')
        DecAxis = ax.coords['dec']
        DecAxis.set_axislabel('Dec', minpad=0.75)
        DecAxis.set_major_formatter('dd:mm:ss')
        ax.coords.grid(color='black', alpha=0.5, linestyle='solid')
    else:
        plt.xlabel("RA (arb. units)")
        plt.ylabel("Dec (arb. units)")

    if labelBy is not None:
        if labelBy.lower() == 'source':
            labels = LSM.getColValues('name')
            xls = x
            yls = y
        elif labelBy.lower() == 'patch':
            if LSM.hasPatches:
                # Only patches with a set position were plotted, so select the
                # same subset of names to keep them aligned with xp/yp.
                labels = np.asarray(LSM.getPatchNames())[goodInd]
                xls = xp
                yls = yp
            else:
                labels = LSM.getColValues('name')
                xls = x
                yls = y
        else:
            raise ValueError("The labelBy parameter must be one of 'source' or "
                             "'patch'.")
        for label, xl, yl in zip(labels, xls, yls):
            plt.annotate(label, xy=(xl, yl), xytext=(-2, 2),
                         textcoords='offset points', ha='right', va='bottom')

    # Show RA and Dec under the mouse pointer
    ax.format_coord = formatCoord

    if fileName is not None:
        plt.savefig(fileName)
    else:
        plt.show()
    plt.close(fig)
def plot_state(directions_list, trim_names=True):
    """
    Plots the facets of a run

    Draws one polygon per direction (from its vertices file) plus a circular
    'field' marker, styles each patch via set_patch_color(), adds a legend and
    an info box, then shows an interactive window that calls update_plot every
    60 seconds.

    Note: a 'field' Direction is appended to ``directions_list`` in place.

    Parameters
    ----------
    directions_list : list of Direction
        Directions to plot
    trim_names : bool, optional
        If True, label each facet with only the part of its name after the last
        underscore
    """
    global midRA, midDec, fig, at, selected_direction, choose_from_list
    selected_direction = None
    choose_from_list = False

    # Set up coordinate system and figure
    points, midRA, midDec = factor.directions.getxy(directions_list)
    fig = plt.figure(1, figsize=(10, 9))
    if hasWCSaxes:
        wcs = factor.directions.makeWCS(midRA, midDec)
        ax = WCSAxes(fig, [0.16, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        ax = plt.gca()

    # Place the 'field' marker at the top-left corner of the directions and
    # nudge it up/left until it is at least 10 units from every direction.
    # points[0] holds all x values and points[1] all y values, so pair them up
    # to visit each direction (iterating ``points`` directly would yield the
    # x-row and the y-row instead).
    field_x = min(points[0])
    field_y = max(points[1])
    adjust_xy = True
    while adjust_xy:
        adjust_xy = False
        for px, py in zip(points[0], points[1]):
            dist = np.sqrt((px - field_x)**2 + (py - field_y)**2)
            if dist < 10.0:
                field_x -= 1
                field_y += 1
                adjust_xy = True
                break
    field_ra, field_dec = factor.directions.xy2radec([field_x], [field_y],
                                                     refRA=midRA, refDec=midDec)
    field = Direction('field', field_ra[0], field_dec[0],
                      factor_working_dir=directions_list[0].working_dir)
    directions_list.append(field)

    ax.set_title('Overview of Factor run in\n{}'.format(
        directions_list[0].working_dir))

    # Plot facets
    markers = []
    for direction in directions_list:
        if direction.name != 'field':
            vertices = read_vertices(direction.vertices_file)
            RAverts = vertices[0]
            Decverts = vertices[1]
            xverts, yverts = factor.directions.radec2xy(RAverts, Decverts,
                                                        refRA=midRA, refDec=midDec)
            xyverts = [np.array([xp, yp]) for xp, yp in zip(xverts, yverts)]
            mpl_poly = Polygon(np.array(xyverts), edgecolor='#a9a9a9',
                               facecolor='#F2F2F2', clip_box=ax.bbox,
                               picker=3.0, linewidth=2)
        else:
            xverts = [field_x]
            yverts = [field_y]
            mpl_poly = Circle((field_x, field_y), radius=5.0,
                              edgecolor='#a9a9a9', facecolor='#F2F2F2',
                              clip_box=ax.bbox, picker=3.0, linewidth=2)
        mpl_poly.facet_name = direction.name
        mpl_poly.completed_ops = get_completed_ops(direction)
        mpl_poly.started_ops = get_started_ops(direction)
        mpl_poly.current_op = get_current_op(direction)
        set_patch_color(mpl_poly, direction)
        ax.add_patch(mpl_poly)

        # Add facet names
        if direction.name != 'field':
            poly_tuple = tuple([(xp, yp) for xp, yp in zip(xverts, yverts)])
            # Build the shapely polygon once and reuse its centroid.
            centroid = SPolygon(poly_tuple).centroid
            xmid = centroid.x
            ymid = centroid.y
        else:
            xmid = field_x
            ymid = field_y
        if trim_names:
            name = direction.name.split('_')[-1]
        else:
            name = direction.name
        marker = ax.text(xmid, ymid, name, color='k', clip_on=True,
                         clip_box=ax.bbox, ha='center', va='bottom')
        marker.set_zorder(1001)
        markers.append(marker)

    # Add info box
    at = AnchoredText("Selected direction: None", prop=dict(size=12),
                      frameon=True, loc=3)
    at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
    at.set_zorder(1002)
    ax.add_artist(at)

    ax.relim()
    ax.autoscale()
    ax.set_aspect('equal')

    if hasWCSaxes:
        RAAxis = ax.coords['ra']
        RAAxis.set_axislabel('RA', minpad=0.75)
        RAAxis.set_major_formatter('hh:mm:ss')
        DecAxis = ax.coords['dec']
        DecAxis.set_axislabel('Dec', minpad=0.75)
        DecAxis.set_major_formatter('dd:mm:ss')
        ax.coords.grid(color='black', alpha=0.5, linestyle='solid')
    else:
        plt.xlabel("RA (arb. units)")
        plt.ylabel("Dec (arb. units)")

    # Show RA and Dec under the mouse pointer
    ax.format_coord = formatCoord

    # Show legend
    not_processed_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                        facecolor='#F2F2F2', linewidth=2)
    processing_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                     facecolor='#F2F5A9', linewidth=2)
    selfcal_ok_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                     facecolor='#A9F5A9', linewidth=2)
    selfcal_not_ok_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                         facecolor='#A4A4A4', linewidth=2)
    processing_error = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                     facecolor='#F5A9A9', linewidth=2)
    patch_list = [
        not_processed_patch, processing_patch, processing_error,
        selfcal_not_ok_patch, selfcal_ok_patch
    ]
    label_list = [
        'Unprocessed', 'Processing', 'Pipeline Error', 'Selfcal Failed',
        'Selfcal OK'
    ]
    for i in range(options['reimages']):
        label_list.append('Image ' + str(i + 1))
        # Progressively darker green for each re-imaging pass
        color = (0.66 / (i + 2)**0.5, 0.96 / (i + 2)**0.5,
                 0.66 / (i + 2)**0.5, 1.0)
        reimage_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                      facecolor=color, linewidth=2)
        patch_list.append(reimage_patch)
    legend = ax.legend(patch_list, label_list, loc="upper right")
    legend.set_zorder(1002)

    # Add check for mouse clicks and key presses
    fig.canvas.mpl_connect('pick_event', on_pick)
    fig.canvas.mpl_connect('key_press_event', on_press)

    # Add timer to update the plot every 60 seconds
    timer = fig.canvas.new_timer(interval=60000)
    timer.add_callback(update_plot)
    timer.start()

    # Show plot
    plt.show()
    plt.close(fig)

    # Clean up any temp casacore images
    if not hasaplpy:
        if os.path.exists('/tmp/tempimage'):
            try:
                shutil.rmtree('/tmp/tempimage')
            except OSError:
                pass
def test_update_clip_path_change_wcs(self, tmpdir):
    """The frame clip path must carry over when the WCS is replaced.

    reset_wcs() creates a new frame, so the clip path used by images drawn
    afterwards has to come from that new frame.
    """
    fig = plt.figure()
    ax = WCSAxes(fig, [0.1, 0.1, 0.8, 0.8], aspect='equal')
    fig.add_axes(ax)
    ax.set_xlim(0., 2.)
    ax.set_ylim(0., 2.)

    # Draw once: this freezes the clip path handed out by WCSAxes
    scratch_file = tmpdir.join('nothing').strpath
    fig.savefig(scratch_file)

    ax.reset_wcs()

    ax.imshow(np.zeros((12, 4)))
    ax.set_xlim(-0.5, 3.5)
    ax.set_ylim(-0.5, 11.5)

    return fig
def plot_image(img_file, hdu="IMAGE", stretch='linear', vmin=None, vmax=None,
               facecolor='black', center=None, width=None, figsize=(10, 10),
               cmap=None):
    """
    Plot a FITS image created by SOXS using Matplotlib.

    Parameters
    ----------
    img_file : str
        The on-disk FITS image to plot.
    hdu : str or int, optional
        The image extension to plot. Default is "IMAGE"
    stretch : str, optional
        The stretch to apply to the colorbar scale. Options are "linear",
        "log", and "sqrt". Default: "linear"
    vmin : float, optional
        The minimum value of the colorbar. If not set, it will be the minimum
        value in the image.
    vmax : float, optional
        The maximum value of the colorbar. If not set, it will be the maximum
        value in the image.
    facecolor : str, optional
        The color of zero-valued pixels. Default: "black"
    center : array-like
        A 2-element object giving an (RA, Dec) coordinate for the center
        in degrees. If not set, the reference pixel of the image (usually
        the center) is used.
    width : float, optional
        The width of the image in degrees. If not set, the width of the
        entire image will be used.
    figsize : tuple, optional
        A 2-tuple giving the size of the image in inches, e.g. (12, 15).
        Default: (10,10)
    cmap : str, optional
        The colormap to be used. If not set, the default Matplotlib
        colormap will be used.

    Returns
    -------
    A tuple of the :class:`~matplotlib.figure.Figure` and the
    :class:`~matplotlib.axes.Axes` objects.

    Raises
    ------
    RuntimeError
        If ``stretch`` is not one of the supported options.
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import PowerNorm, LogNorm, Normalize
    from astropy.wcs.utils import proj_plane_pixel_scales
    from astropy.visualization.wcsaxes import WCSAxes

    if stretch == "linear":
        norm = Normalize(vmin=vmin, vmax=vmax)
    elif stretch == "log":
        norm = LogNorm(vmin=vmin, vmax=vmax)
    elif stretch == "sqrt":
        norm = PowerNorm(0.5, vmin=vmin, vmax=vmax)
    else:
        raise RuntimeError(f"'{stretch}' is not a valid stretch!")

    # Everything that touches the HDU (including reading its data in imshow)
    # happens while the file is still open.
    with fits.open(img_file) as f:
        hdu = f[hdu]
        w = wcs.WCS(hdu.header)
        pix_scale = proj_plane_pixel_scales(w)
        if center is None:
            # crpix follows the 1-based FITS convention; matplotlib/WCSAxes
            # pixel coordinates are 0-based, matching the origin=0 used below.
            center = w.wcs.crpix - 1.0
        else:
            center = w.wcs_world2pix(center[0], center[1], 0)
        # dx_pix/dy_pix are the FULL extent of the view in pixels; limits
        # below extend half of that on either side of the center.
        if width is None:
            # numpy image shape is (ny, nx)
            dx_pix = hdu.shape[1]
            dy_pix = hdu.shape[0]
        else:
            dx_pix = width / pix_scale[0]
            dy_pix = width / pix_scale[1]
        fig = plt.figure(figsize=figsize)
        ax = WCSAxes(fig, [0.15, 0.1, 0.8, 0.8], wcs=w)
        fig.add_axes(ax)
        im = ax.imshow(hdu.data, norm=norm, cmap=cmap)
        ax.set_xlim(center[0] - 0.5 * dx_pix, center[0] + 0.5 * dx_pix)
        ax.set_ylim(center[1] - 0.5 * dy_pix, center[1] + 0.5 * dy_pix)
        ax.set_facecolor(facecolor)

    plt.colorbar(im)
    return fig, ax
def plot(LSM, fileName=None, labelBy=None):
    """
    Shows a simple plot of the sky model.

    The circles in the plot are scaled with flux. If the sky model is grouped
    into patches, sources are colored by patch and the patch positions are
    indicated with stars.

    Parameters
    ----------
    LSM : SkyModel object
        Input sky model
    fileName : str, optional
        If given, the plot is saved to a file instead of displayed
    labelBy : str, optional
        One of 'source' or 'patch': label points using source names ('source')
        or patch names ('patch')

    Raises
    ------
    ImportError
        If matplotlib.pyplot cannot be imported
    ValueError
        If labelBy is not one of 'source' or 'patch'

    Examples:
    ---------
    Plot and display to the screen::

        >>> LSM = lsmtool.load('sky.model')
        >>> plot(LSM)

    Plot and save to a PDF file::

        >>> plot(LSM, 'sky_plot.pdf')

    """
    try:
        import os
        if 'DISPLAY' not in os.environ:
            # No display available: use the non-interactive backend.
            import matplotlib
            # Compare by value; ``is not`` on a string literal tests identity.
            if matplotlib.get_backend() != 'Agg':
                matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except Exception as e:
        # Python 3 exceptions have no ``.message`` attribute; use str(e).
        raise ImportError('PyPlot could not be imported. Plotting is not '
                          'available: {0}'.format(e))
    try:
        from astropy.visualization.wcsaxes import WCSAxes
        hasWCSaxes = True
    except Exception:
        # Older installs provide WCSAxes as a standalone package.
        try:
            from wcsaxes import WCSAxes
            hasWCSaxes = True
        except Exception:
            hasWCSaxes = False
    import numpy as np
    from ..operations_lib import radec2xy, makeWCS
    global midRA, midDec, ymin, xmin

    if len(LSM) == 0:
        log.error('Sky model is empty.')
        return

    fig = plt.figure(1, figsize=(7.66, 7))
    plt.clf()

    x, y, midRA, midDec = LSM._getXY()
    if hasWCSaxes:
        wcs = makeWCS(midRA, midDec)
        ax = WCSAxes(fig, [0.16, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        ax = plt.gca()
    if LSM.hasPatches:
        nsrc = len(LSM.getPatchNames())
    else:
        nsrc = len(LSM)
    sm = plt.cm.ScalarMappable(cmap=plt.cm.Set3,
                               norm=plt.Normalize(vmin=0, vmax=nsrc))
    sm._A = []

    # Set symbol sizes by flux, making sure no symbol is smaller than 50 or
    # larger than 1000
    fluxes = LSM.getColValues('I')
    positive_fluxes = fluxes[fluxes > 0.0]
    if len(positive_fluxes) == 0:
        minflux = 0.0
    else:
        minflux = np.min(positive_fluxes)
    s = []
    for flux in fluxes:
        if flux > 0.0:
            s.append(min(1000.0, (1.0 + 2.0 * np.log10(flux / minflux)) * 50.0))
        else:
            s.append(50.0)

    # Color sources by patch if grouped
    c = [0] * len(LSM)
    cp = []
    if LSM.hasPatches:
        for p, patchName in enumerate(LSM.getPatchNames()):
            indices = LSM.getRowIndex(patchName)
            cp.append(sm.to_rgba(p))
            for ind in indices:
                c[ind] = sm.to_rgba(p)
    else:
        c = [sm.to_rgba(0)] * nsrc

    # Plot sources
    if hasWCSaxes:
        ax.set_xlim(np.min(x) - 20, np.max(x) + 20)
        ax.set_ylim(np.min(y) - 20, np.max(y) + 20)
    plt.scatter(x, y, s=s, c=c)

    if LSM.hasPatches:
        RAp, Decp = LSM.getPatchPositions(asArray=True)
        # Patches at (0, 0) are treated as having no position set.
        goodInd = np.where((RAp != 0.0) & (Decp != 0.0))
        if len(goodInd[0]) < len(RAp):
            log.info('Some patch positions are unset. Run setPatchPositions() '
                     'before plotting to see patch positions and patch names.')
        xp, yp = radec2xy(RAp[goodInd], Decp[goodInd], midRA, midDec)
        plt.scatter(xp, yp, s=100, c=cp, marker='*')

    # Set axis labels, etc.
    if hasWCSaxes:
        RAAxis = ax.coords['ra']
        RAAxis.set_axislabel('RA', minpad=0.75)
        RAAxis.set_major_formatter('hh:mm:ss')
        DecAxis = ax.coords['dec']
        DecAxis.set_axislabel('Dec', minpad=0.75)
        DecAxis.set_major_formatter('dd:mm:ss')
        ax.coords.grid(color='black', alpha=0.5, linestyle='solid')
    else:
        plt.xlabel("RA (arb. units)")
        plt.ylabel("Dec (arb. units)")

    if labelBy is not None:
        if labelBy.lower() == 'source':
            labels = LSM.getColValues('name')
            xls = x
            yls = y
        elif labelBy.lower() == 'patch':
            if LSM.hasPatches:
                # Only patches with a set position were plotted, so select the
                # same subset of names to keep them aligned with xp/yp.
                labels = np.asarray(LSM.getPatchNames())[goodInd]
                xls = xp
                yls = yp
            else:
                labels = LSM.getColValues('name')
                xls = x
                yls = y
        else:
            raise ValueError("The labelBy parameter must be one of 'source' or "
                             "'patch'.")
        for label, xl, yl in zip(labels, xls, yls):
            plt.annotate(label, xy=(xl, yl), xytext=(-2, 2),
                         textcoords='offset points', ha='right', va='bottom')

    # Show RA and Dec under the mouse pointer
    ax.format_coord = formatCoord

    if fileName is not None:
        plt.savefig(fileName)
    else:
        plt.show()
    plt.close(fig)
def test_update_clip_path_rectangular(self, tmpdir):
    """Images added after a draw must still be clipped to the rectangular frame."""
    fig = plt.figure()
    ax = WCSAxes(fig, [0.1, 0.1, 0.8, 0.8], aspect='equal')
    fig.add_axes(ax)
    ax.set_xlim(0., 2.)
    ax.set_ylim(0., 2.)

    # Draw once: this freezes the clip path handed out by WCSAxes
    scratch_file = tmpdir.join('nothing').strpath
    fig.savefig(scratch_file)

    ax.imshow(np.zeros((12, 4)))
    ax.set_xlim(-0.5, 3.5)
    ax.set_ylim(-0.5, 11.5)

    return fig
def plot(self, center, width, s=None, c=None, marker=None, stride=1,
         emin=None, emax=None, label=None, fontsize=18, fig=None, ax=None,
         **kwargs):
    """
    Plot event coordinates from this photon list in a scatter plot,
    optionally restricting the photon energies which are plotted
    and using only a subset of the photons.

    Parameters
    ----------
    center : array-like
        The RA, Dec of the center of the plot in degrees.
    width : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The width of the plot in arcminutes.
    s : integer, optional
        Size of the scatter marker in points^2.
    c : string, optional
        The color of the points.
    marker : string, optional
        The marker to use for the points in the scatter plot. Default: 'o'
    stride : integer, optional
        Plot every *stride* events. Default: 1
    emin : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The minimum energy of the photons to plot. Default is the minimum
        energy in the list.
    emax : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The maximum energy of the photons to plot. Default is the maximum
        energy in the list.
    label : string, optional
        The label of the spectrum. Default: None
    fontsize : int
        Font size for labels and axes. Default: 18
    fig : :class:`~matplotlib.figure.Figure`, optional
        A Figure instance to plot in. Default: None, one will be
        created if not provided.
    ax : :class:`~matplotlib.axes.Axes`, optional
        An Axes instance to plot in. Default: None, one will be
        created if not provided. If given, it must expose a ``wcs``
        attribute (e.g. a WCSAxes).

    Returns
    -------
    A tuple of the :class:`~matplotlib.figure.Figure` and the
    :class:`~matplotlib.axes.Axes` objects.
    """
    import matplotlib.pyplot as plt
    from astropy.visualization.wcsaxes import WCSAxes
    from astropy.wcs.utils import proj_plane_pixel_scales
    if fig is None:
        fig = plt.figure(figsize=(10, 10))
    if ax is None:
        wcs = construct_wcs(center[0], center[1])
        ax = WCSAxes(fig, [0.15, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        wcs = ax.wcs
    if emin is None:
        emin = self.energy.value.min()
    else:
        emin = parse_value(emin, "keV")
    if emax is None:
        emax = self.energy.value.max()
    else:
        emax = parse_value(emax, "keV")
    idxs = np.logical_and(self.energy.value >= emin, self.energy.value <= emax)
    ra = self.ra[idxs][::stride].value
    dec = self.dec[idxs][::stride].value
    # WCSAxes works in 0-based pixel coordinates, so convert with origin=0;
    # origin=1 would shift every point by one pixel relative to the grid
    # and tick labels.
    x, y = wcs.wcs_world2pix(ra, dec, 0)
    ax.scatter(x, y, s=s, c=c, marker=marker, label=label, **kwargs)
    x0, y0 = wcs.wcs_world2pix(center[0], center[1], 0)
    # Convert the requested width (arcmin) to pixels using the WCS's own
    # pixel scale (degrees per pixel) rather than assuming 1 arcsec pixels,
    # so a user-supplied ``ax`` with a different resolution also works.
    width_deg = parse_value(width, "arcmin") / 60.0
    pix_scale = proj_plane_pixel_scales(wcs)
    dx_pix = width_deg / pix_scale[0]
    dy_pix = width_deg / pix_scale[1]
    ax.set_xlim(x0 - 0.5 * dx_pix, x0 + 0.5 * dx_pix)
    ax.set_ylim(y0 - 0.5 * dy_pix, y0 + 0.5 * dy_pix)
    ax.set_xlabel("RA")
    ax.set_ylabel("Dec")
    ax.tick_params(axis='both', labelsize=fontsize)
    return fig, ax
def plot_state(directions_list, trim_names=True):
    """
    Plots the facets of a run

    Draws one polygon per direction (from its vertices file) plus a circular
    'field' marker, styles each patch via set_patch_color(), adds a legend and
    an info box, then shows an interactive window that calls update_plot every
    60 seconds.

    Note: a 'field' Direction is appended to ``directions_list`` in place.

    Parameters
    ----------
    directions_list : list of Direction
        Directions to plot
    trim_names : bool, optional
        If True, label each facet with only the part of its name after the last
        underscore
    """
    global midRA, midDec, fig, at, selected_direction, choose_from_list
    selected_direction = None
    choose_from_list = False

    # Set up coordinate system and figure
    points, midRA, midDec = factor.directions.getxy(directions_list)
    fig = plt.figure(1, figsize=(10, 9))
    if hasWCSaxes:
        wcs = factor.directions.makeWCS(midRA, midDec)
        ax = WCSAxes(fig, [0.16, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        ax = plt.gca()

    # Place the 'field' marker at the top-left corner of the directions and
    # nudge it up/left until it is at least 10 units from every direction.
    # points[0] holds all x values and points[1] all y values, so pair them up
    # to visit each direction (iterating ``points`` directly would yield the
    # x-row and the y-row instead).
    field_x = min(points[0])
    field_y = max(points[1])
    adjust_xy = True
    while adjust_xy:
        adjust_xy = False
        for px, py in zip(points[0], points[1]):
            dist = np.sqrt((px - field_x)**2 + (py - field_y)**2)
            if dist < 10.0:
                field_x -= 1
                field_y += 1
                adjust_xy = True
                break
    field_ra, field_dec = factor.directions.xy2radec([field_x], [field_y],
                                                     refRA=midRA, refDec=midDec)
    field = Direction('field', field_ra[0], field_dec[0],
                      factor_working_dir=directions_list[0].working_dir)
    directions_list.append(field)

    ax.set_title('Overview of Factor run in\n{}'.format(
        directions_list[0].working_dir))

    # Plot facets
    markers = []
    for direction in directions_list:
        if direction.name != 'field':
            vertices = read_vertices(direction.vertices_file)
            RAverts = vertices[0]
            Decverts = vertices[1]
            xverts, yverts = factor.directions.radec2xy(RAverts, Decverts,
                                                        refRA=midRA, refDec=midDec)
            xyverts = [np.array([xp, yp]) for xp, yp in zip(xverts, yverts)]
            mpl_poly = Polygon(np.array(xyverts), edgecolor='#a9a9a9',
                               facecolor='#F2F2F2', clip_box=ax.bbox,
                               picker=3.0, linewidth=2)
        else:
            xverts = [field_x]
            yverts = [field_y]
            mpl_poly = Circle((field_x, field_y), radius=5.0,
                              edgecolor='#a9a9a9', facecolor='#F2F2F2',
                              clip_box=ax.bbox, picker=3.0, linewidth=2)
        mpl_poly.facet_name = direction.name
        mpl_poly.completed_ops = get_completed_ops(direction)
        mpl_poly.started_ops = get_started_ops(direction)
        mpl_poly.current_op = get_current_op(direction)
        set_patch_color(mpl_poly, direction)
        ax.add_patch(mpl_poly)

        # Add facet names
        if direction.name != 'field':
            poly_tuple = tuple([(xp, yp) for xp, yp in zip(xverts, yverts)])
            # Build the shapely polygon once and reuse its centroid.
            centroid = SPolygon(poly_tuple).centroid
            xmid = centroid.x
            ymid = centroid.y
        else:
            xmid = field_x
            ymid = field_y
        if trim_names:
            name = direction.name.split('_')[-1]
        else:
            name = direction.name
        marker = ax.text(xmid, ymid, name, color='k', clip_on=True,
                         clip_box=ax.bbox, ha='center', va='bottom')
        marker.set_zorder(1001)
        markers.append(marker)

    # Add info box
    at = AnchoredText("Selected direction: None", prop=dict(size=12),
                      frameon=True, loc=3)
    at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
    at.set_zorder(1002)
    ax.add_artist(at)

    ax.relim()
    ax.autoscale()
    ax.set_aspect('equal')

    if hasWCSaxes:
        RAAxis = ax.coords['ra']
        RAAxis.set_axislabel('RA', minpad=0.75)
        RAAxis.set_major_formatter('hh:mm:ss')
        DecAxis = ax.coords['dec']
        DecAxis.set_axislabel('Dec', minpad=0.75)
        DecAxis.set_major_formatter('dd:mm:ss')
        ax.coords.grid(color='black', alpha=0.5, linestyle='solid')
    else:
        plt.xlabel("RA (arb. units)")
        plt.ylabel("Dec (arb. units)")

    # Show RA and Dec under the mouse pointer
    ax.format_coord = formatCoord

    # Show legend
    not_processed_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                        facecolor='#F2F2F2', linewidth=2)
    processing_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                     facecolor='#F2F5A9', linewidth=2)
    selfcal_ok_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                     facecolor='#A9F5A9', linewidth=2)
    selfcal_not_ok_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                         facecolor='#A4A4A4', linewidth=2)
    processing_error = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                     facecolor='#F5A9A9', linewidth=2)
    patch_list = [
        not_processed_patch, processing_patch, processing_error,
        selfcal_not_ok_patch, selfcal_ok_patch
    ]
    label_list = [
        'Unprocessed', 'Processing', 'Pipeline Error', 'Selfcal Failed',
        'Selfcal OK'
    ]
    for i in range(options['reimages']):
        label_list.append('Image ' + str(i + 1))
        # Progressively darker green for each re-imaging pass
        color = (0.66 / (i + 2)**0.5, 0.96 / (i + 2)**0.5,
                 0.66 / (i + 2)**0.5, 1.0)
        reimage_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
                                      facecolor=color, linewidth=2)
        patch_list.append(reimage_patch)
    legend = ax.legend(patch_list, label_list, loc="upper right")
    legend.set_zorder(1002)

    # Add check for mouse clicks and key presses
    fig.canvas.mpl_connect('pick_event', on_pick)
    fig.canvas.mpl_connect('key_press_event', on_press)

    # Add timer to update the plot every 60 seconds
    timer = fig.canvas.new_timer(interval=60000)
    timer.add_callback(update_plot)
    timer.start()

    # Show plot
    plt.show()
    plt.close(fig)

    # Clean up any temp casacore images
    if not hasaplpy:
        if os.path.exists('/tmp/tempimage'):
            try:
                shutil.rmtree('/tmp/tempimage')
            except OSError:
                pass
def plot(self, center, width, s=None, c=None, marker=None, stride=1,
         emin=None, emax=None, label=None, fontsize=18, fig=None, ax=None,
         **kwargs):
    """
    Plot event coordinates from this photon list in a scatter plot,
    optionally restricting the photon energies which are plotted
    and using only a subset of the photons.

    Parameters
    ----------
    center : array-like
        The RA, Dec of the center of the plot in degrees.
    width : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The width of the plot in arcminutes.
    s : integer, optional
        Size of the scatter marker in points^2. Default: None, which lets
        matplotlib choose its default size.
    c : string, optional
        The color of the points.
    marker : string, optional
        The marker to use for the points in the scatter plot. Default: None,
        which lets matplotlib choose its default marker ('o' out of the box).
    stride : integer, optional
        Plot every *stride* events. Default: 1
    emin : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The minimum energy of the photons to plot. Default is the
        minimum energy in the list.
    emax : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
        The maximum energy of the photons to plot. Default is the
        maximum energy in the list.
    label : string, optional
        The label attached to the scatter points (e.g. for a legend).
        Default: None
    fontsize : int
        Font size for the axis labels and the tick labels. Default: 18
    fig : :class:`~matplotlib.figure.Figure`, optional
        A Figure instance to plot in. Default: None. If omitted, the figure
        that owns *ax* is used when *ax* is given; otherwise a new figure
        is created.
    ax : :class:`~matplotlib.axes.Axes`, optional
        An Axes instance to plot in. It must provide a ``wcs`` attribute
        (e.g. a WCSAxes). Default: None, one will be created if not provided.
    **kwargs
        Extra keyword arguments forwarded to ``ax.scatter``.

    Returns
    -------
    fig : :class:`~matplotlib.figure.Figure`
        The figure containing the plot.
    ax : :class:`~matplotlib.axes.Axes`
        The axes the events were drawn in.
    """
    import matplotlib.pyplot as plt
    from astropy.visualization.wcsaxes import WCSAxes

    if fig is None:
        # A caller-supplied axes already belongs to a figure; return that one
        # rather than creating an unrelated, empty figure.
        fig = ax.figure if ax is not None else plt.figure(figsize=(10, 10))
    if ax is None:
        wcs = construct_wcs(center[0], center[1])
        ax = WCSAxes(fig, [0.15, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        wcs = ax.wcs

    energy = self.energy.value
    emin = energy.min() if emin is None else parse_value(emin, "keV")
    emax = energy.max() if emax is None else parse_value(emax, "keV")
    # Keep events inside the closed interval [emin, emax], then thin the
    # selection to every *stride*-th event.
    idxs = (energy >= emin) & (energy <= emax)
    ra = self.ra[idxs][::stride].value
    dec = self.dec[idxs][::stride].value

    # Sky coordinates -> pixel coordinates (origin=1, FITS convention).
    x, y = wcs.wcs_world2pix(ra, dec, 1)
    ax.scatter(x, y, s=s, c=c, marker=marker, label=label, **kwargs)

    x0, y0 = wcs.wcs_world2pix(center[0], center[1], 1)
    # arcmin -> arcsec. NOTE(review): this equals a pixel count only if the
    # WCS has 1-arcsecond pixels (presumably what construct_wcs builds --
    # confirm); a caller-supplied ax with a different pixel scale will get
    # incorrect limits.
    width = parse_value(width, "arcmin")*60.0
    ax.set_xlim(x0-0.5*width, x0+0.5*width)
    ax.set_ylim(y0-0.5*width, y0+0.5*width)
    ax.set_xlabel("RA", fontsize=fontsize)
    ax.set_ylabel("Dec", fontsize=fontsize)
    ax.tick_params(axis='both', labelsize=fontsize)
    return fig, ax
import matplotlib.pyplot as plt import matplotlib.cm as cm from astropy.io import fits import pyregion # read in the image xray_name = "pspc_skyview.fits" f_xray = fits.open(xray_name) try: from astropy.wcs import WCS from astropy.visualization.wcsaxes import WCSAxes wcs = WCS(f_xray[0].header) fig = plt.figure() ax = WCSAxes(fig, [0.1, 0.1, 0.8, 0.8], wcs=wcs) fig.add_axes(ax) except ImportError: ax = plt.subplot(111) ax.imshow(f_xray[0].data, cmap=cm.gray, vmin=0., vmax=0.00038, origin="lower") reg_name = "test.reg" r = pyregion.open(reg_name).as_imagecoord(header=f_xray[0].header) from pyregion.mpl_helper import properties_func_default # Use custom function for patch attribute def fixed_color(shape, saved_attrs): attr_list, attr_dict = saved_attrs
def test_rcparams(self):
    # Every styling value under test comes from rcParams rather than being
    # passed explicitly to the plotting calls below.
    custom_rc = {
        'axes.labelcolor': 'purple',
        'axes.labelsize': 14,
        'axes.labelweight': 'bold',
        'axes.linewidth': 3,
        'axes.facecolor': '0.5',
        'axes.edgecolor': 'green',
        'xtick.color': 'red',
        'xtick.labelsize': 8,
        'xtick.direction': 'in',
        'xtick.minor.visible': True,
        'xtick.minor.size': 5,
        'xtick.major.size': 20,
        'xtick.major.width': 3,
        'xtick.major.pad': 10,
        'grid.color': 'blue',
        'grid.linestyle': ':',
        'grid.linewidth': 1,
        'grid.alpha': 0.5,
    }
    with rc_context(custom_rc):
        # Build the whole plot while the custom rcParams are active.
        fig = plt.figure(figsize=(6, 6))
        ax = WCSAxes(fig, [0.15, 0.1, 0.7, 0.7], wcs=None)
        fig.add_axes(ax)
        for set_limits in (ax.set_xlim, ax.set_ylim):
            set_limits(-0.5, 2)
        ax.grid()
        ax.set_xlabel('X label')
        ax.set_ylabel('Y label')
        for coord_index in range(2):
            ax.coords[coord_index].set_ticklabel(exclude_overlapping=True)
        return fig
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from astropy.io import fits
import pyregion

# Overlay the shapes from a DS9 region file on an X-ray image.

# read in the image
xray_name = "pspc_skyview.fits"
# Open the FITS file in a context manager so its file handle is released
# once the pixel data and header have been handed off.
with fits.open(xray_name) as f_xray:
    try:
        # Prefer WCS-aware axes so the ticks are labelled in sky coordinates.
        from astropy.wcs import WCS
        from astropy.visualization.wcsaxes import WCSAxes

        wcs = WCS(f_xray[0].header)
        fig = plt.figure()
        ax = WCSAxes(fig, [0.1, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    except ImportError:
        # Fall back to plain pixel axes when WCSAxes is unavailable.
        ax = plt.subplot(111)

    ax.imshow(f_xray[0].data, cmap=cm.gray, vmin=0., vmax=0.00038,
              origin="lower")

    reg_name = "test.reg"
    # Convert the region shapes into this image's pixel coordinates.
    r = pyregion.open(reg_name).as_imagecoord(f_xray[0].header)

# Turn the regions into matplotlib artists and draw them on the image axes.
patch_list, text_list = r.get_mpl_patches_texts()
for p in patch_list:
    ax.add_patch(p)
for t in text_list:
    ax.add_artist(t)
import pyregion

# Region files to load, one per panel (left panel first, then right panel).
region_list = [
    "test_text.reg",
    "test_context.reg",
]

# Create figure
fig = plt.figure(figsize=(8, 4))

# Parse WCS information
# NOTE(review): plt, Header, WCS and WCSAxes are not imported in this snippet;
# they presumably come from matplotlib.pyplot, astropy.io.fits, astropy.wcs
# and astropy.visualization.wcsaxes elsewhere -- confirm before running alone.
header = Header.fromtextfile('sample_fits01.header')
wcs = WCS(header)

# Create axes: two side-by-side panels that share the same WCS.
ax1 = WCSAxes(fig, [0.1, 0.1, 0.4, 0.8], wcs=wcs)
fig.add_axes(ax1)
ax2 = WCSAxes(fig, [0.5, 0.1, 0.4, 0.8], wcs=wcs)
fig.add_axes(ax2)

# Hide labels on y axis (of the right-hand panel only)
ax2.coords[1].set_ticklabel_position('')

# Give both panels the same pixel window and a 1:1 aspect ratio, and load
# each panel's region file converted to this header's image coordinates.
for ax, reg_name in zip([ax1, ax2], region_list):
    ax.set_xlim(300, 1300)
    ax.set_ylim(300, 1300)
    ax.set_aspect(1)
    # NOTE(review): r is rebuilt on every iteration but never drawn in this
    # snippet; the patches are presumably added by code that follows --
    # confirm, otherwise the regions never appear on the axes.
    r = pyregion.open(reg_name).as_imagecoord(header)
@pytest.mark.parametrize("test_input, test_kwargs, expected_error", [ (cube[0, 0], {"axes_coordinates": np.arange(10, 10 + cube_unit[0, 0].data.shape[0]), "axes_units": u.C}, TypeError), (cube[0, 0], {"data_unit": u.C}, TypeError) ]) def test_cube_plot_1D_errors(test_input, test_kwargs, expected_error): with pytest.raises(expected_error): output = test_input.plot(**test_kwargs) @pytest.mark.parametrize("test_input, test_kwargs, expected_values", [ (cube[0], {}, (np.ma.masked_array(cube[0].data, cube[0].mask), "time [min]", "em.wl [m]", (-0.5, 3.5, -0.5, 2.5))), (cube_spatial, {'axes': WCSAxes(plt.figure(), (0, 0, 1, 1), wcs=cube_spatial.wcs)}, (cube_spatial.data, "custom:pos.helioprojective.lat [deg]", "custom:pos.helioprojective.lon [deg]", (-0.5, 3.5, -0.5, 2.5))), (cube[0], {"axes_coordinates": ["bye", None], "axes_units": [None, u.cm]}, (np.ma.masked_array(cube[0].data, cube[0].mask), "bye [m]", "em.wl [cm]", (0.0, 3.0, 2e-9, 6e-9))), (cube[0], {"axes_coordinates": [np.arange(10, 10 + cube[0].data.shape[1]), u.Quantity(np.arange(10, 10 + cube[0].data.shape[0]), unit=u.m)], "axes_units": [None, u.cm]}, (np.ma.masked_array(cube[0].data, cube[0].mask), " [None]", " [cm]", (10, 13, 1000, 1200))), (cube[0], {"axes_coordinates": [np.arange(10, 10 + cube[0].data.shape[1]),