def build_image(id_, set_, bands=('EUC_VIS', 'EUC_H', 'EUC_J', 'EUC_Y'),
                img_size=200, scale=100, clip=True):
    """Assemble a stretched multi-band image cube for one object.

    For every band the FITS file named by ``get_image_filename_from_id`` is
    opened.  Bands other than ``'EUC_VIS'`` are reprojected with
    ``reproject_interp`` onto the header of the first opened file; for
    ``'EUC_VIS'`` the data of the first opened file is used as is.  NaNs are
    replaced by zero and the band is stretched before being stored.

    Parameters
    ----------
    id_, set_
        Forwarded unchanged to ``get_image_filename_from_id``.
    bands : sequence of str
        Band names, one output plane per band.
        NOTE(review): the reprojection target and the ``'EUC_VIS'`` data are
        both read from the first opened file, so ``'EUC_VIS'`` is presumably
        expected to come first -- confirm with callers.
    img_size : int
        Side length of each output plane; band data must match it.
    scale
        Unused; kept so existing callers keep working.
    clip : bool
        If True, clip to percentile-derived limits, divide by the upper
        limit and apply ``MinMaxInterval() + LogStretch()``.  Otherwise
        apply ``LogStretch() + MinMaxInterval()`` to the raw band.

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape ``(img_size, img_size, len(bands))``.

    Raises
    ------
    FileNotFoundError
        Propagated from ``fits.open`` when a band file is missing.
    """
    tables = []
    data = np.empty((img_size, img_size, len(bands)))
    try:
        for i, band in enumerate(bands):
            fname = get_image_filename_from_id(id_, band, set_)
            tables.append(fits.open(fname))
            if band != 'EUC_VIS':
                band_data, _ = reproject_interp(
                    tables[i][0], tables[0][0].header)
            else:
                band_data = tables[0][0].data
            band_data[np.isnan(band_data)] = 0.
            if clip:
                interval = AsymmetricPercentileInterval(
                    0.25, 99.75, n_samples=100000)
                vmin, vmax = interval.get_limits(band_data)
                stretch = MinMaxInterval() + LogStretch()
                data[:, :, i] = stretch(
                    np.clip(band_data, -vmin * 0.7, vmax) / vmax)
            else:
                stretch = LogStretch() + MinMaxInterval()
                data[:, :, i] = stretch(band_data)
    finally:
        # Close every file opened so far, even when a later band fails.
        for table in tables:
            table.close()
    return data.astype(np.float32)
def _initialize_plot(self):
    """Prepare the colours, normalisations and plot keyword arguments.

    Sets ``d_colors``, ``norm_1000``/``norm_100``/``norm_10``, ``f_norm``,
    ``states``, ``state_names``, ``_scatter_kwargs`` and
    ``_geo_plot_kwargs`` on the instance.  All normalisations use a
    ``LogStretch`` with ``vmin=0`` and a ``vmax`` scaled by
    ``cfg["N_tot"] / 580_000``.
    """
    self.d_colors = dict(
        S="#7F7F7F",
        I="#D62728",
        I_UK="#135DD8",
        R="#2CA02C",
    )

    factor = self.cfg["N_tot"] / 580_000

    self.norm_1000 = ImageNormalize(
        vmin=0.0, vmax=1000 * factor, stretch=LogStretch())
    self.norm_100 = ImageNormalize(
        vmin=0.0, vmax=100 * factor, stretch=LogStretch())
    self.norm_10 = ImageNormalize(
        vmin=0.0, vmax=10 * factor, stretch=LogStretch())
    # Factory for a normalisation with an arbitrary (scaled) upper limit.
    self.f_norm = lambda x: ImageNormalize(
        vmin=0.0, vmax=x * factor, stretch=LogStretch())

    self.states = (
        ["S", "I", "I_UK", "R"] if self.split_corona_types
        else ["S", "I", "R"]
    )
    self.state_names = dict(
        S="Susceptable",
        I=r"Infected $\&$ Exposed",
        I_UK=r"I $\&$ E UK",
        R="Recovered",
    )

    # Two-colour map: values below 0.5 use the "R" colour, above the "I" one.
    two_colour_map = mpl.colors.ListedColormap(
        [self.d_colors["R"], self.d_colors["I"]])
    boundary_norm = mpl.colors.BoundaryNorm([0, 0.5, 1], two_colour_map.N)
    self._scatter_kwargs = dict(
        cmap=two_colour_map, norm=boundary_norm, edgecolor="none")

    self._geo_plot_kwargs = {
        "S": dict(alpha=0.2, norm=self.norm_1000),
        "R": dict(alpha=0.3, norm=self.norm_100),
        "I": dict(norm=self.norm_10),
        "I_UK": dict(norm=self.norm_10),
    }
def norm_set(self, text):
    """Apply the stretch named by ``text`` to the base cube and redraw.

    Parameters
    ----------
    text : str
        One of ``'No Stretch'``, ``'Sqrt Stretch'``, ``'Linear Stretch'``,
        ``'Squared Stretch'``, ``'Power Stretch'`` or ``'Log Stretch'``.

    Raises
    ------
    KeyError
        If ``text`` is not one of the names above.

    Notes
    -----
    Every branch derives ``self.cube`` from ``self.cube_base``.  The
    square-root branch used to operate on the current ``self.cube``, which
    compounded with whatever stretch was selected before.
    """
    stretch_dict = {
        'No Stretch': None,
        'Sqrt Stretch': SqrtStretch(),
        'Linear Stretch': LinearStretch(),
        'Squared Stretch': SquaredStretch(),
        'Power Stretch': PowerStretch(self.a),
        'Log Stretch': LogStretch(self.a),
    }
    self.stretch_val = stretch_dict[text]

    if text == 'Log Stretch':
        self.var_max = 10000
        self.var_int = 10
        self.cube = self.stretch_val(self.cube_base)
    elif text == 'Sqrt Stretch':
        # Start from the unstretched data so repeated selections do not
        # stack one stretch on top of another.
        self.cube = np.sqrt(self.cube_base)
    elif text == 'No Stretch':
        self.norm = None
        self.cube = self.cube_base
    else:
        self.var_max = 10
        self.var_int = 1
        self.cube = self.stretch_val(self.cube_base, clip=False)

    main.create_image(self.cube)
def muti_draw_picture(NDVI, predict):
    """Save a scatter-density map of ``predict`` against ``NDVI``.

    The map is drawn on a ``scatter_density`` projection with a log-stretched
    colour scale from 0 to 100, a black 1:1 line and both axes limited to
    [-1, 1].  The figure is written to ``./result2.png``.
    """
    # Font able to render Chinese text, and a proper minus sign.
    plt.rcParams['font.sans-serif'] = ["DFKai-SB"]
    plt.rcParams['axes.unicode_minus'] = False

    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='scatter_density')

    log_norm = ImageNormalize(vmin=0, vmax=100, stretch=LogStretch())
    mappable = ax.scatter_density(
        predict, NDVI, cmap=plt.cm.Blues, norm=log_norm)

    # 1:1 reference line.
    ax.plot([-1, 1], [-1, 1], 'black')
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    cbar = plt.colorbar(mappable)
    cbar.set_ticks(np.linspace(0, 100, 5))
    cbar.set_ticklabels([0, '', 50, '', 100])
    cbar.ax.tick_params(labelsize=50)

    fig.savefig("./result2.png")
def make_image(parms, fitsfile, pngfile):
    """Convert a FITS image into an annotated, log-stretched PNG.

    The image is shifted so its minimum is zero, scaled to a maximum of one,
    passed through ``LogStretch`` and converted to 8-bit.  A title, the UTC
    timestamp and the exposure time are drawn in the top-left corner.

    Parameters
    ----------
    parms : mapping
        Must provide ``'observatory'``, ``'night'``, ``'utc'`` (an object
        with ``strftime``) and ``'exp'`` (a number).
    fitsfile
        Path of the FITS file to read.
    pngfile
        Destination passed to ``Image.save``.
    """
    img_data = fits.getdata(fitsfile)
    img_data = img_data - numpy.min(img_data)
    peak = numpy.max(img_data)
    if peak > 0:
        img_data = img_data / peak
    else:
        # Flat image (peak 0) or NaN peak: avoid dividing by it and fall
        # back to an all-black frame.
        img_data = numpy.zeros(img_data.shape, dtype=float)
    img_data = LogStretch()(img_data) * 255
    result = Image.fromarray(img_data.astype(numpy.uint8))

    font_path = "/usr/share/fonts/truetype/freefont/FreeSans.ttf"
    fontB = ImageFont.truetype(font_path, 32)
    fontS = ImageFont.truetype(font_path, 24)

    titletxt = '%s All Sky - %s' % (parms['observatory'], parms['night'])
    timetxt = 'timestamp = %s' % (parms['utc'].strftime("%Y-%m-%d %H:%M:%S"))
    # Sub-second exposures need decimals to be readable.
    if parms['exp'] >= 1:
        expotxt = 'exposure = %.0f s' % (parms['exp'])
    else:
        expotxt = 'exposure = %.4f s' % (parms['exp'])

    draw = ImageDraw.Draw(result)
    draw.text((10, 10), titletxt, font=fontB, fill=255)
    draw.text((10, 60), timetxt, font=fontS, fill=255)
    draw.text((10, 100), expotxt, font=fontS, fill=255)
    result.save(pngfile)
def plot_sources(self, data, centroids, vmin=0, vmax=1000, radius=3,
                 color='red', ax=None):
    """Show ``data`` and overlay one aperture patch per centroid.

    Parameters
    ----------
    data
        Image passed to ``imshow`` and to ``ImageNormalize``.
    centroids
        Iterable of positions; each is handed to ``self.mk_aperture``.
    vmin, vmax
        Limits of the ``ManualInterval`` used for the log normalisation.
    radius
        Radius forwarded to ``self.mk_aperture``.
    color
        Colour forwarded to ``self.mk_aperture``.
    ax
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    The axes that were drawn on.
    """
    limits = ManualInterval(vmin=vmin, vmax=vmax)
    log_norm = ImageNormalize(data, stretch=LogStretch(), interval=limits)

    if ax is None:
        _, ax = plt.subplots(nrows=1, ncols=1)

    ax.grid(False)
    ax.imshow(data, norm=log_norm, cmap='gray', origin='lower')
    for position in centroids:
        ax.add_patch(self.mk_aperture(position, radius, color))
    return ax
def get_norm(self, stretch_text, scale_text):
    """Build an ``ImageNormalize`` for the current slice of the data cube.

    Parameters
    ----------
    stretch_text : str
        ``"Linear"`` or ``"Log"``; any other value selects a square-root
        stretch.
    scale_text : str
        ``"MinMax"`` selects ``MinMaxInterval``; any other value selects
        ``ZScaleInterval``.

    Returns
    -------
    ImageNormalize
        Normalisation whose limits come from the chosen interval applied to
        ``self.cubeObj.data_cube[self.cubeObj.currSlice]``.
    """
    # The first image on the axes is looked up but not otherwise used; the
    # lookup is kept because it raises when the axes holds no image.
    self.ax.get_images()[0]

    interval = MinMaxInterval() if scale_text == "MinMax" else ZScaleInterval()

    if stretch_text == "Linear":
        chosen_stretch = LinearStretch()
    elif stretch_text == "Log":
        chosen_stretch = LogStretch()
    else:
        chosen_stretch = SqrtStretch()

    current_slice = self.cubeObj.data_cube[self.cubeObj.currSlice]
    lower, upper = interval.get_limits(current_slice)
    return ImageNormalize(vmin=lower, vmax=upper, stretch=chosen_stretch)
def make_oneone(ax, img, result):
    '''Plot the cleaned image with its asymmetry centre and sky statistics.

    Parameters
    ----------
    ax : matplotlib axis object
        Axis to draw on.
    img : np.ndarray
        Image data to be plotted.
    result : Result dataclass
        Calculated results for the object; ``apix``, ``sky`` and
        ``sky_err`` are read.

    Returns
    -------
    None
    '''
    log_stretch = LogStretch(10000.)
    ax.imshow(log_stretch(_normalise(img)), origin="lower", aspect="auto")
    ax.scatter(result.apix[0], result.apix[1], label="Asym. centre")

    # The x axis spans the image columns, i.e. the second array dimension
    # (using the row count only worked for square images).
    ax.set_xlim([-0.5, img.shape[1] + 0.5])
    ax.set_title("Cleaned Image")

    text = f"Sky={result.sky:.2f}\n" fr"Sky $\sigma$={result.sky_err:.2f}"
    textbox = AnchoredText(text, frameon=True, loc=3, pad=0.5)
    ax.add_artist(textbox)
def plot_sdss_image(sdss_hdu, wcs_sdss):
    """Plot an SDSS image on WCS axes with a log stretch and no decorations.

    Parameters
    ----------
    sdss_hdu
        HDU list; the data of its first HDU is displayed.
    wcs_sdss
        Projection handed to ``add_subplot``.

    Returns
    -------
    (fig, ax)
        The created figure and axes.

    Notes
    -----
    The colour limits 0..3 are now set on the normalisation itself.
    ``imshow`` previously received ``norm`` together with ``vmin``/``vmax``,
    which current matplotlib rejects; on older releases those explicit
    limits replaced the interval-derived ones, so the rendering is the same.
    """
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection=wcs_sdss)

    norm = ImageNormalize(vmin=0, vmax=3, stretch=LogStretch())
    shifted_cmap = shiftedColorMap(mpl.cm.Greys, midpoint=0.6, name='shifted')
    ax.imshow(sdss_hdu[0].data, origin='lower', cmap=shifted_cmap, norm=norm)
    ax.set_autoscale_on(False)

    # Hide ticks, tick labels and axis labels on both world coordinates.
    for coord in (ax.coords[0], ax.coords[1]):
        coord.set_ticks_visible(False)
        coord.set_ticklabel_visible(False)
        coord.set_axislabel('')
    ax.coords.frame.set_color('None')

    return fig, ax
def __init__(self):
    """Create the stretch lookup table and start with no map loaded."""
    # Available image stretches, keyed by their short name.
    self.image_norms = dict(
        log=LogStretch(),
        linear=LinearStretch(),
        sqrt=SqrtStretch(),
    )
    self.map = None
def __init__(self, data, header, **kwargs):
    """Initialise the map and set its default colour map and normalisation.

    ``data``, ``header`` and ``kwargs`` are forwarded to the parent
    initialiser.  The colour map name is ``'trace'`` followed by the
    ``WAVE_LEN`` metadata value; the normalisation is unclipped and uses the
    stretch returned by ``source_stretch`` for a ``LogStretch``.
    """
    super().__init__(data, header, **kwargs)

    self._nickname = self.detector

    # Colour map
    wave_len = self.meta['WAVE_LEN']
    self.plot_settings['cmap'] = 'trace' + str(wave_len)

    # Normalisation
    stretch = source_stretch(self.meta, LogStretch())
    self.plot_settings['norm'] = ImageNormalize(stretch=stretch, clip=False)
def plot_box(box, title=None, path=None, format=None, scale="log",
             interval="pts", cmap="viridis"):
    """Show or save a 2D array with a stretched colour scale.

    :param box: 2D array to display
    :param title: optional plot title
    :param path: if given, save the figure there instead of showing it
    :param format: forwarded to ``plt.savefig``
    :param scale: ``"log"`` or ``"sqrt"``
    :param interval: ``"zscale"``, ``"pts"``, ``"minmax"`` or a
        ``(vmin, vmax)`` tuple
    :param cmap: colour map name
    :return: the ``(vmin, vmax)`` limits that were used
    :raises ValueError: for an unknown ``scale`` or ``interval``

    The limits are stored on the normalisation; passing ``vmin``/``vmax`` to
    ``imshow`` alongside ``norm`` is rejected by current matplotlib.
    """
    # Stretch
    if scale == "log":
        stretch = LogStretch()
    elif scale == "sqrt":
        stretch = SqrtStretch()
    else:
        raise ValueError("Invalid option for 'scale'")

    # Limits
    if interval == "zscale":
        vmin, vmax = ZScaleInterval().get_limits(box)
    elif interval == "pts":
        # Non-negative minimum, and the midpoint between it and the maximum.
        vmin = max(np.nanmin(box), 0.)
        vmax = 0.5 * (np.nanmax(box) + vmin)
    elif interval == "minmax":
        vmin, vmax = MinMaxInterval().get_limits(box)
    elif isinstance(interval, tuple):
        vmin = interval[0]
        vmax = interval[1]
    else:
        raise ValueError("Invalid option for 'interval'")

    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)

    # Make the plot
    plt.figure(figsize=(7, 7))
    plt.imshow(box, origin="lower", interpolation="nearest", norm=norm,
               cmap=cmap)
    plt.xlim(0, box.shape[1] - 1)
    plt.ylim(0, box.shape[0] - 1)

    if title is not None:
        plt.title(title)

    if path is None:
        plt.show()
    else:
        plt.savefig(path, format=format)

    plt.close()

    return vmin, vmax
def plot_grid(self, filename=None, show=False, plot_radii=(), xy_lim=None):
    """Plot the real part of the uv grid on a log-stretched grey scale.

    Parameters
    ----------
    filename : str, optional
        If given, the figure is saved there.  A label is added when the
        name contains ``'ska1_v5'`` or ``'model'``.
    show : bool
        If True, call ``plt.show()``.
    plot_radii : iterable
        Radii of red circles drawn around the origin.
    xy_lim : optional
        If given, both axes are limited to ``[-xy_lim, xy_lim]``.

    Returns
    -------
    The figure when it was neither saved nor shown, otherwise ``None``
    (the figure is closed in that case).
    """
    if self.uv_grid is None:
        self.grid_uvw_coords()

    grid_size = self.uv_grid.shape[0]
    wavelength = const.c.value / self.freq_hz
    fov_rad = Imager.uv_cellsize_to_fov(
        self.grid_cell_size_m / wavelength, grid_size)
    extent = Imager.grid_extent_wavelengths(degrees(fov_rad), grid_size)
    # Convert the extent from wavelengths to metres.
    extent = np.array(extent) * wavelength

    fig, ax = plt.subplots(figsize=(8, 8), ncols=1, nrows=1)
    fig.subplots_adjust(left=0.125, bottom=0.1, right=0.9, top=0.9,
                        wspace=0.2, hspace=0.2)
    image = self.uv_grid.real
    options = dict(interpolation='nearest', cmap='gray_r', extent=extent,
                   origin='lower')
    im = ax.imshow(image, norm=ImageNormalize(stretch=LogStretch()),
                   **options)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="2%", pad=0.03)
    cbar = ax.figure.colorbar(im, cax=cax)
    cbar.set_label('baselines per pixel')
    cbar.ax.tick_params(labelsize='small')
    ticks = np.arange(5) * round(image.max() / 5)
    ticks = np.append(ticks, image.max())
    # ``update_ticks`` is no longer accepted by matplotlib; updating was
    # its default behaviour, so the argument is simply omitted.
    cbar.set_ticks(ticks)

    for radius in plot_radii:
        ax.add_artist(plt.Circle((0, 0), radius, fill=False, color='r'))
    ax.set_xlabel('uu (m)')
    ax.set_ylabel('vv (m)')
    ax.grid(True)
    if xy_lim is not None:
        ax.set_xlim(-xy_lim, xy_lim)
        ax.set_ylim(-xy_lim, xy_lim)

    if filename is not None:
        label = ''
        if 'ska1_v5' in filename:
            label = 'SKA1 v5'
        if 'model' in filename:
            # A model file name without digits used to crash here.
            match = re.search(r'\d+', os.path.basename(filename))
            label = ('Model ' + match.group()) if match else 'Model'
        ax.text(0.02, 0.95, label, weight='bold', transform=ax.transAxes)
        fig.savefig(filename)
    if show:
        plt.show()
    if filename is not None or show:
        plt.close(fig)
    else:
        return fig
def plot_unwise(tile, ax=None):
    """Display the first HDU of a FITS tile with a log stretch.

    Parameters
    ----------
    tile
        Path (or file object) accepted by ``fits.open``.
    ax : optional
        Axes to draw on; ``plt.imshow`` is used when omitted.
    """
    # The context manager closes the file handle, which was previously
    # left open.
    with fits.open(tile) as unwise_hdu:
        data = unwise_hdu[0].data
        unorm = ImageNormalize(stretch=LogStretch(400), data=data)
        target = plt if ax is None else ax
        target.imshow(data, cmap="plasma", norm=unorm)
def LogNorm():
    """Create a logarithmic image normalisation.

    Returns
    -------
    ImageNormalize
        A normalisation using a default ``LogStretch``.
    """
    log_stretch = LogStretch()
    return ImageNormalize(stretch=log_stretch)
def FITSprocess(img):
    """Load a FITS image and return it ZScale-normalised and log-stretched.

    Parameters
    ----------
    img
        File path of the FITS image.

    Returns
    -------
    The image data after ``ZScaleInterval`` followed by ``LogStretch``.

    Notes
    -----
    The interval is applied before the stretch.  The previous order
    (stretch first) clipped the raw pixel values to [0, 1] before ZScale
    saw them, because stretches operate on already normalised values.
    """
    from astropy.io import fits
    from astropy.visualization import ZScaleInterval
    from astropy.visualization import LogStretch

    interval = ZScaleInterval()
    stretch = LogStretch()
    image_data = stretch(interval(fits.getdata(img)))
    return image_data
def preprocess_band(image, clip=True):
    """Preprocess a single band, with or without percentile clipping.

    NaNs in ``image`` are set to zero in place before any further step.

    param: image (ndarray): 2D array containing a single band's data.
    param: clip (bool): If True, clip the band to limits derived from the
        0.25 / 99.75 percentiles, divide by the upper limit and apply
        ``MinMaxInterval() + LogStretch()``.  If False, apply
        ``LogStretch() + MinMaxInterval()`` to the band.  Defaults to True.
    returns: newimage (ndarray): Preprocessed 2D array.
    """
    nan_mask = np.isnan(image)
    image[nan_mask] = 0.

    if not clip:
        return (LogStretch() + MinMaxInterval())(image)

    percentiles = AsymmetricPercentileInterval(0.25, 99.75, n_samples=100000)
    low, high = percentiles.get_limits(image)
    rescaled = np.clip(image, -low * 0.7, high) / high
    return (MinMaxInterval() + LogStretch())(rescaled)
def plot_difference(box_a, box_b, share_colorscale=False, title=None):
    """Show two arrays and their difference side by side.

    :param box_a: first 2D array ("Data a")
    :param box_b: second 2D array ("Data b")
    :param share_colorscale: if True the residual uses the same log scale
        as "Data b"; otherwise it is autoscaled and gets a colour bar
    :param title: optional figure title
    :return: None

    Each panel gets its own normalisation carrying that panel's limits.
    ``imshow`` previously received one shared ``norm`` together with
    per-panel ``vmin``/``vmax``, which current matplotlib rejects.
    """
    # Lower limit from box_a; upper limit halfway between it and the
    # maximum of box_b.
    vmin = np.nanmin(box_a)
    vmax = 0.5 * (np.nanmax(box_b) + vmin)

    def log_norm(lower):
        return ImageNormalize(vmin=lower, vmax=vmax, stretch=LogStretch())

    plt.figure(figsize=(8, 2.5))

    plt.subplot(1, 3, 1)
    plt.imshow(box_a, origin='lower', interpolation="nearest",
               norm=log_norm(vmin), cmap="viridis")
    plt.xlim(0, box_a.shape[1] - 1)
    plt.ylim(0, box_a.shape[0] - 1)
    plt.title("Data a")

    plt.subplot(1, 3, 2)
    plt.imshow(box_b, origin='lower', interpolation="nearest",
               norm=log_norm(0.0), cmap="viridis")
    plt.xlim(0, box_a.shape[1] - 1)
    plt.ylim(0, box_a.shape[0] - 1)
    plt.title("Data b")

    plt.subplot(1, 3, 3)
    if share_colorscale:
        residualimage = plt.imshow(box_a - box_b, origin='lower',
                                   interpolation="nearest",
                                   norm=log_norm(0.0), cmap="viridis")
    else:
        residualimage = plt.imshow(box_a - box_b, origin='lower',
                                   interpolation="nearest", cmap="viridis")
    plt.xlim(0, box_a.shape[1] - 1)
    plt.ylim(0, box_a.shape[0] - 1)
    plt.title("Residual")
    if not share_colorscale:
        plt.colorbar(residualimage, format="%.2f")

    # Set the main title
    if title is not None:
        plt.suptitle(title, size=16)

    plt.show()
def plot_radio(tile, ax=None):
    """Display the first HDU of a FITS tile with a log stretch.

    The lower colour limit is a quarter of the RMS returned by
    ``pu.rms_estimate``; the upper limit is the NaN-ignoring maximum.

    Parameters
    ----------
    tile
        Path (or file object) accepted by ``fits.open``.
    ax : optional
        Axes to draw on; ``plt.imshow`` is used when omitted.
    """
    # The context manager closes the file handle, which was previously
    # left open.
    with fits.open(tile) as vlass_hdu:
        data = vlass_hdu[0].data
        rms = pu.rms_estimate(data)
        vmin = 0.25 * rms
        vmax = np.nanmax(data)
        norm = ImageNormalize(stretch=LogStretch(100), vmin=vmin, vmax=vmax)
        target = plt if ax is None else ax
        target.imshow(data, norm=norm)
def LSBImage(dat, noise):
    """Draw a two-layer image: histogram-equalised base, log-stretched top.

    The base layer shows ``dat`` with ``HistEqStretch``.  The top layer
    shows ``dat`` clipped from below at ``noise`` with a log stretch whose
    lower limit is ``3 * noise``; values under that limit are transparent
    so the base layer shows through.

    Parameters
    ----------
    dat
        2D image array.
    noise
        Noise level used for the clipping and the lower colour limit.
    """
    import copy

    plt.figure(figsize=(6, 6))
    plt.imshow(dat, origin='lower', cmap='Greys',
               norm=ImageNormalize(stretch=HistEqStretch(dat)))

    # Work on a private copy: changing the shared ``cm.Greys_r`` object
    # would alter the colour map for every other user in the process.
    my_cmap = copy.copy(cm.Greys_r)
    my_cmap.set_under('k', alpha=0)

    # The lower limit lives on the normalisation; passing ``vmin``/``clim``
    # to ``imshow`` alongside ``norm`` is rejected by current matplotlib.
    top_norm = ImageNormalize(vmin=3 * noise, stretch=LogStretch(),
                              clip=False)
    plt.imshow(np.clip(dat, a_min=noise, a_max=None), origin='lower',
               cmap=my_cmap, norm=top_norm)

    plt.xticks([])
    plt.yticks([])
    plt.subplots_adjust(left=0.03, right=0.97, top=0.97, bottom=0.05)
def test_norm(self):
    """Add a scatter-density artist that uses a log ``ImageNormalize``.

    Returns the test figure after fixing the axis limits.
    """
    from astropy.visualization import LogStretch
    from astropy.visualization.mpl_normalize import ImageNormalize

    log_norm = ImageNormalize(vmin=0., vmax=1000, stretch=LogStretch())
    artist = ScatterDensityArtist(self.ax, self.x1, self.y1, norm=log_norm)
    self.ax.add_artist(artist)

    self.ax.set_xlim(-3, 5)
    self.ax.set_ylim(-2, 4)
    return self.fig
def plot_background_center(cutout, mask, peaks=None, title=None, show=True,
                           scale="sqrt"):
    """Plot a cutout, the masked source and the masked background.

    :param cutout: 2D image; its ``xsize`` and ``ysize`` attributes set the
        axis limits
    :param mask: boolean mask, or an object whose ``data`` attribute is one
    :param peaks: optional pair of (x, y) coordinates marked with white
        crosses on the third panel
    :param title: optional figure title
    :param show: if True, call ``plt.show()``
    :param scale: ``"sqrt"`` or ``"log"``
    :return: None
    :raises ValueError: for an unknown ``scale``

    The colour limits are stored on the normalisation; passing
    ``vmin``/``vmax`` to ``imshow`` alongside ``norm`` is rejected by
    current matplotlib.
    """
    if scale == "sqrt":
        stretch = SqrtStretch()
    elif scale == "log":
        stretch = LogStretch()
    else:
        raise ValueError("Invalid scale option")

    # Upper limit is the maximum; the lower limit is zero unless the whole
    # cutout is non-positive, in which case the minimum is used.
    vmax = np.nanmax(cutout)
    vmin = np.nanmin(cutout) if vmax <= 0 else 0.0
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)

    # Get raw data of mask as a numpy array
    maskdata = mask.data if hasattr(mask, "data") else mask

    def draw(index, image, panel_title, with_peaks=False):
        plt.subplot(1, 3, index)
        plt.imshow(image, origin='lower', interpolation="nearest",
                   norm=norm, cmap="viridis")
        if with_peaks and peaks is not None:
            plt.plot(peaks[0], peaks[1], ls='none', color='white',
                     marker='+', ms=40, lw=10, mew=4)
        plt.xlim(0.5, cutout.xsize - 0.5)
        plt.ylim(0.5, cutout.ysize - 0.5)
        plt.title(panel_title)

    plt.figure(figsize=(10, 4))
    draw(1, cutout, "Cutout")
    draw(2, np.ma.masked_array(cutout, mask=maskdata), "Masked source")
    draw(3, np.ma.masked_array(cutout, mask=np.logical_not(maskdata)),
         "Masked background", with_peaks=True)

    # Set the main title
    if title is not None:
        plt.suptitle(title, size=16)

    # Show the plot
    if show:
        plt.show()
def make_photerror_plot(tile):
    """Produce a photometric-error plot for bands J, H and Ks of a tile.

    The tile's table is read with ``read_fits_table`` and, for each band,
    the ``er_<band>`` column is drawn against ``mag_<band>`` as a
    scatter-density panel.  The figure is saved as
    ``figphoterr_<tile.name>_v2.png``.

    :param tile: object providing ``get_file`` and ``name``
    :return: None
    """
    file = tile.get_file(data_dir)
    table = read_fits_table(file)

    fig = plt.figure()
    ax1 = fig.add_subplot(311, projection='scatter_density')
    ax2 = fig.add_subplot(312, projection='scatter_density')
    ax3 = fig.add_subplot(313, projection='scatter_density')
    fig.subplots_adjust(hspace=0)
    fig.suptitle(f'Photometric error for tile {tile.name}')

    norm = ImageNormalize(vmin=0., vmax=1000, stretch=LogStretch())
    ax1.scatter_density(table['mag_J'], table['er_J'], color='blue',
                        norm=norm, label='J')
    ax2.scatter_density(table['mag_H'], table['er_H'], color='green',
                        norm=norm, label='H')
    ax3.scatter_density(table['mag_Ks'], table['er_Ks'], color='red',
                        norm=norm, label='Ks')
    ax3.set_xlabel('Magnitudes')

    # Hide x labels and tick labels for all but the bottom plot, and give
    # every panel the same limits and an invisible-handle legend that only
    # shows the band name.
    for ax, lbl in zip([ax1, ax2, ax3], ['J', 'H', 'Ks']):
        ax.label_outer()
        ax.set_ylim(-0.05, 0.4)
        ax.set_xlim(9.9, 22.1)
        band_patch = mpatches.Patch(label=lbl, alpha=0.00)
        ax.legend(handles=[band_patch], loc='upper left', markerscale=0,
                  markerfirst=False, framealpha=0.00)
        ax.set_ylabel(r'$\sigma$')

    # ``savefig`` always replaces an existing file; it has no ``overwrite``
    # keyword, and unknown keywords are rejected by current matplotlib.
    fig.savefig(f'figphoterr_{tile.name}_v2.png')
    fig.clf()
def __init__(self, data, header, **kwargs):
    """Initialise the map, tag it as TRACE and set default plot settings.

    ``data``, ``header`` and ``kwargs`` are forwarded to
    ``GenericMap.__init__``.  The colour map is looked up by the name
    ``'trace'`` followed by the ``WAVE_LEN`` metadata value.
    """
    GenericMap.__init__(self, data, header, **kwargs)

    # It needs to be verified that these must actually be set and are not
    # already in the header.
    for key in ('detector', 'obsrvtry'):
        self.meta[key] = "TRACE"
    self._nickname = self.detector

    # Colour map
    cmap_name = 'trace' + str(self.meta['WAVE_LEN'])
    self.plot_settings['cmap'] = plt.get_cmap(cmap_name)

    # Normalisation
    stretch = source_stretch(self.meta, LogStretch())
    self.plot_settings['norm'] = ImageNormalize(stretch=stretch)
def preview_image(HDU):
    """Return a preview figure for an image HDU.

    The data is passed through ``PercentileInterval(90)`` and displayed on
    WCS axes built from ``HDU.header`` with a log stretch, the ``Blues_r``
    colour map and a colour bar.
    """
    from astropy.visualization import quantity_support, PercentileInterval, LogStretch
    from astropy.visualization.mpl_normalize import ImageNormalize
    from astropy.wcs import WCS

    with quantity_support():
        fig = plt.figure()
        axes = fig.add_subplot(1, 1, 1, projection=WCS(HDU.header))

        scaled = PercentileInterval(90)(HDU.data)
        log_norm = ImageNormalize(stretch=LogStretch())
        mappable = axes.imshow(scaled, norm=log_norm, cmap='Blues_r')

        fig.colorbar(mappable, ax=axes)
        return fig
def logarithmic_scale(Images, Images_load_length):
    """Apply ``LogStretch() + MinMaxInterval()`` to each image in place.

    Parameters
    ----------
    Images
        Array indexed as ``Images[k, :, :]``; it is modified in place.
    Images_load_length : int
        Number of leading images to transform.

    Returns
    -------
    The same ``Images`` object.
    """
    log_minmax = LogStretch() + MinMaxInterval()
    print('Logarithmic strech :')
    for index in range(Images_load_length):
        # Progress report every 1000 images.
        if not index % 1000:
            print(index)
        Images[index, :, :] = log_minmax(Images[index, :, :])
    print('Logarithmic data strech done')
    return Images
def diagnostic_source_finding_plots(ifile, coo_tab=None):
    """Show a FITS image and, optionally, apertures around found sources.

    Parameters
    ----------
    ifile
        Path (or file object) accepted by ``fits.open``; the first HDU is
        displayed with a log stretch.
    coo_tab : optional
        Table with ``xcentroid`` and ``ycentroid`` columns.  Blue circular
        apertures of radius 10 are drawn at those positions.  ``None`` or an
        empty table draws no apertures.
    """
    # The context manager closes the file handle, which was previously
    # left open.
    with fits.open(ifile) as hdu:
        data = hdu[0].data
        norm = ImageNormalize(stretch=LogStretch())
        plt.imshow(data, cmap='Greys', origin='lower', norm=norm)

    # Explicit checks: the truth value of a table-like object is not a
    # reliable "was it provided" test.
    if coo_tab is not None and len(coo_tab) > 0:
        # One (x, y) row per source, the unambiguous layout for positions.
        positions = np.transpose((np.array(coo_tab['xcentroid'].tolist()),
                                  np.array(coo_tab['ycentroid'].tolist())))
        apertures = CircularAperture(positions, r=10.)
        apertures.plot(color='blue', lw=2, alpha=1)

    plt.show()
def quick_rgb(image_red, image_green, image_blue, contrast=0.25):
    """Combine three images into an 8-bit RGB array with a common log scale.

    ZScale limits are computed per channel with the given ``contrast``; the
    smallest lower limit and the largest upper limit define one clipped,
    log-stretched normalisation that is applied to all three channels.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``image_red.shape + (3,)``.
    """
    channels = (image_red, image_green, image_blue)

    # Per-channel limits, then the envelope over all channels.
    zscale = ZScaleInterval(contrast=contrast)
    limits = [zscale.get_limits(channel) for channel in channels]
    lowest = min(low for low, _ in limits)
    highest = max(high for _, high in limits)
    shared_norm = ImageNormalize(vmin=lowest, vmax=highest,
                                 stretch=LogStretch(), clip=True)

    rgbim = np.zeros(image_red.shape + (3,), dtype=np.uint8)
    for plane, channel in enumerate(channels):
        rgbim[:, :, plane] = (shared_norm(channel) * 255).astype(np.uint8)
    return rgbim
def __init__(self, data, header, **kwargs):
    """Initialise the map, tag it as TRACE and set default plot settings.

    ``header`` is updated in place: ``cunit1`` and ``cunit2`` are set to
    ``'arcsec'`` when absent.  ``data``, ``header`` and ``kwargs`` are then
    forwarded to the parent initialiser.
    """
    # Assume pixel units are arcsec if not given
    for unit_key in ('cunit1', 'cunit2'):
        header[unit_key] = header.get(unit_key, 'arcsec')

    super().__init__(data, header, **kwargs)

    # It needs to be verified that these must actually be set and are not
    # already in the header.
    self.meta['detector'] = "TRACE"
    self.meta['obsrvtry'] = "TRACE"
    self._nickname = self.detector

    # Colour map
    wave_len = self.meta['WAVE_LEN']
    self.plot_settings['cmap'] = 'trace' + str(wave_len)

    # Normalisation
    stretch = source_stretch(self.meta, LogStretch())
    self.plot_settings['norm'] = ImageNormalize(stretch=stretch, clip=False)
def plot_image(image, scale='linear', origin='lower',
               xlabel='Pixel Column Number', ylabel='Pixel Row Number',
               clabel='Flux ($e^{-}s^{-1}$)', title=None, **kwargs):
    """Plot a 2D image with a colour bar.

    The colour limits come from ``PercentileInterval(95.)`` applied to the
    image.

    Parameters
    ----------
    image : 2d array
        Image data.
    scale : str
        Stretch of the colour map: 'linear', 'sqrt', or 'log'.
    origin : str
        The origin of the coordinate system.
    xlabel : str
        Label for the x-axis.
    ylabel : str
        Label for the y-axis.
    clabel : str
        Label for the color bar.
    title : str or None
        Title for the plot.
    kwargs : dict
        Keyword arguments passed on to ``imshow``.

    Returns
    -------
    (fig, ax)
        The created figure and axes.

    Raises
    ------
    ValueError
        If ``scale`` is not one of the supported names.
    """
    fig, ax = plt.subplots()
    lower, upper = PercentileInterval(95.).get_limits(image)

    if scale == 'linear':
        stretch = LinearStretch()
    elif scale == 'sqrt':
        stretch = SqrtStretch()
    elif scale == 'log':
        stretch = LogStretch()
    else:
        raise ValueError("scale {} is not available.".format(scale))
    norm = ImageNormalize(vmin=lower, vmax=upper, stretch=stretch)

    mappable = ax.imshow(image, origin=origin, norm=norm, **kwargs)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    fig.colorbar(mappable, norm=norm, label=clabel)
    return fig, ax