def main():
    """Display the first extension of the stacked image on WCS axes.

    Flagged pixels are masked and painted gray, the display limits come
    from a linear z-scale normalization, and the figure is saved with the
    input name (``.fits`` replaced by ``.png``) before being shown.
    """
    stack_path = get_stack_filename()
    first_ext = astrodata.open(stack_path)[0]

    # Mask flagged pixels so they are drawn with the colormap's "bad" colour.
    image = np.ma.masked_where(first_ext.mask, first_ext.data, copy=True)

    # Work on a copy so the shared viridis colormap is not modified.
    cmap = copy(plt.cm.viridis)
    cmap.set_bad('gray')

    zscale = visualization.ImageNormalize(
        image,
        interval=visualization.ZScaleInterval(),
        stretch=visualization.LinearStretch(),
    )

    fig, ax = plt.subplots(subplot_kw={'projection': wcs.WCS(first_ext.hdr)})
    ax.imshow(image, cmap=cmap, vmin=zscale.vmin, vmax=zscale.vmax)
    ax.set_title(os.path.basename(stack_path))
    for clear_tick_labels in (ax.set_xticklabels, ax.set_yticklabels):
        clear_tick_labels([])

    fig.savefig(stack_path.replace('.fits', '.png'))
    plt.show()
def plotcontrast(img):
    """Show ``img`` under a gallery of Matplotlib/Astropy normalizations.

    Draws a 4x3 grid: no normalization, a Matplotlib ``LogNorm`` and ten
    Astropy stretches (each wrapped in an ``ImageNormalize``). Every panel
    is titled with the short class name of the normalization used.

    http://docs.astropy.org/en/stable/visualization/index.html
    """
    normalizers = (
        None,
        LogNorm(),
        vis.AsinhStretch(),
        vis.ContrastBiasStretch(1, 0.5),
        vis.HistEqStretch(img),
        vis.LinearStretch(),
        vis.LogStretch(),
        vis.PowerDistStretch(a=10.),
        vis.PowerStretch(a=10.),
        vis.SinhStretch(),
        vis.SqrtStretch(),
        vis.SquaredStretch(),
    )

    fg, axes = subplots(4, 3, figsize=(10, 10))
    for panel, normalizer in zip(axes.ravel(), normalizers):
        if not normalizer or isinstance(normalizer, LogNorm):
            # Used directly as the imshow norm (None or a matplotlib Norm).
            norm = normalizer
            label = str(normalizer).split('.')[-1].split(" ")[0]
        else:
            # Astropy stretches have to be wrapped before imshow can use them.
            norm = ImageNormalize(stretch=normalizer)
            label = str(normalizer.__class__).split('.')[-1].split("'")[0]
        panel.set_title(label)
        panel.imshow(img, origin='lower', cmap='gray', norm=norm)
        panel.axis('off')

    fg.suptitle('Matplotlib/AstroPy normalizations')
def plot_norm(self, stretch='linear', power=1.0, asinh_a=0.1, min_cut=None,
              max_cut=None, min_percent=None, max_percent=None, percent=None,
              clip=True):
    """Create a matplotlib norm object for plotting.

    This is a copy of this function that will be available in Astropy 1.3:
    `astropy.visualization.mpl_normalize.simple_norm`

    See the parameter description there!

    The cut levels come from the first option that is set, in this order:
    ``percent``; ``min_percent``/``max_percent``; ``min_cut``/``max_cut``.
    If none is set, the full data range of ``self.data`` is used.

    Raises
    ------
    ValueError
        If ``stretch`` is not one of 'linear', 'sqrt', 'power', 'log',
        'asinh'.

    Examples
    --------
    >>> image = SkyImage()
    >>> norm = image.plot_norm(stretch='sqrt', max_percent=99)
    >>> image.plot(norm=norm)
    """
    import astropy.visualization as v
    from astropy.visualization.mpl_normalize import ImageNormalize

    # Interval selection: the first option that is set wins.
    if percent is not None:
        interval = v.PercentileInterval(percent)
    elif any(p is not None for p in (min_percent, max_percent)):
        interval = v.AsymmetricPercentileInterval(min_percent or 0.,
                                                  max_percent or 100.)
    elif any(c is not None for c in (min_cut, max_cut)):
        interval = v.ManualInterval(min_cut, max_cut)
    else:
        interval = v.MinMaxInterval()

    # Factories are called lazily, so only the requested stretch is built.
    stretch_factories = {
        'linear': v.LinearStretch,
        'sqrt': v.SqrtStretch,
        'power': lambda: v.PowerStretch(power),
        'log': v.LogStretch,
        'asinh': lambda: v.AsinhStretch(asinh_a),
    }
    factory = (stretch_factories.get(stretch)
               if isinstance(stretch, str) else None)
    if factory is None:
        raise ValueError('Unknown stretch: {0}.'.format(stretch))
    stretch_obj = factory()

    vmin, vmax = interval.get_limits(self.data)
    return ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch_obj, clip=clip)
def main():
    """Plot the first extension of the FITS file given on the command line.

    When ``--mask`` (``args.mask``) is set, flagged pixels are masked and
    drawn in Gainsboro. The z-scale limits are computed and printed, the
    image is drawn on RA/Dec WCS axes, and the figure is saved in the
    current directory as the input's base name with ``.fits`` replaced by
    ``.png`` before being shown.
    """
    args = _parse_args()
    filename = args.filename
    ext = astrodata.open(filename)[0]
    pixels, bad_pixels, header = ext.data, ext.mask, ext.hdr

    image = (np.ma.masked_where(bad_pixels, pixels, copy=True)
             if args.mask else pixels)

    # Copy so the shared viridis colormap is not modified.
    cmap = copy(plt.cm.viridis)
    cmap.set_bad('Gainsboro')

    zscale = visualization.ImageNormalize(
        image,
        interval=visualization.ZScaleInterval(),
        stretch=visualization.LinearStretch(),
    )

    fig = plt.figure(num=filename)
    ax = fig.subplots(subplot_kw={"projection": wcs.WCS(header)})

    for limit in (zscale.vmin, zscale.vmax):
        print(limit)

    # NOTE(review): display limits are hard-coded to 750-900; the z-scale
    # limits printed above are not used for the display. Confirm intended.
    ax.imshow(image, cmap=cmap, vmin=750., vmax=900., origin='lower')
    ax.set_title(os.path.basename(filename))

    ra_axis, dec_axis = ax.coords[0], ax.coords[1]
    ra_axis.set_axislabel('Right Ascension')
    ra_axis.set_ticklabel(fontsize='small')
    dec_axis.set_axislabel('Declination')
    dec_axis.set_ticklabel(rotation='vertical', fontsize='small')

    fig.tight_layout(rect=[0.05, 0, 1, 1])
    fig.savefig(os.path.basename(filename.replace('.fits', '.png')))
    plt.show()
def get_im_stretch(stretch):
    '''
    Returns a stretch to feed the ImageNormalize routine from Astropy.

    :param stretch: short name for the stretch I want. Possibilities are
                    'arcsinh' or 'linear'.
    :type stretch: string
    :return: a new stretch instance.
    :rtype: :class:`astropy.visualization.AsinhStretch` for 'arcsinh',
            :class:`astropy.visualization.LinearStretch` for 'linear'
    :raises ValueError: if ``stretch`` is not a supported name. ValueError
                        is a subclass of Exception, so existing
                        ``except Exception`` handlers still catch it.
    '''
    if stretch == 'arcsinh':
        return astrovis.AsinhStretch()
    if stretch == 'linear':
        return astrovis.LinearStretch()
    # Wrap in a 1-tuple so tuple-valued input is formatted, not unpacked.
    raise ValueError('Ouch! Stretch %s unknown.' % (stretch,))
def main():
    """Plot every extension of the stacked image side by side.

    A linear z-scale normalization over all extensions (stacked with
    ``np.dstack``) is computed and its limits printed, while the panels are
    drawn with fixed limits 750-900. Each panel is labelled ``dNN``. The
    figure is saved in the current directory as the input's base name with
    ``.fits`` replaced by ``.png`` and then shown.
    """
    filename = get_stack_filename()
    ad = astrodata.open(filename)

    fig = plt.figure(num=filename, figsize=(7, 4.5))
    fig.suptitle(os.path.basename(filename), y=0.97)
    # squeeze=False always yields a 2-D array of Axes; without it a
    # single-extension file returns a bare Axes and indexing it fails.
    axs = fig.subplots(1, len(ad), sharey=True, squeeze=False)[0]

    # Copy so the shared viridis colormap is not modified.
    palette = copy(plt.cm.viridis)
    palette.set_bad("Gainsboro", 1.0)

    norm = visualization.ImageNormalize(
        np.dstack([ext.data for ext in ad]),
        stretch=visualization.LinearStretch(),
        interval=visualization.ZScaleInterval()
    )
    print(norm.vmin)
    print(norm.vmax)

    # NOTE(review): the z-scale limits above are only printed; the panels
    # use hard-coded limits, and the data are drawn without applying
    # ``ext.mask``. Confirm this is intended.
    for i, (ax, ext) in enumerate(zip(axs, ad)):
        ax.imshow(
            ext.data,
            norm=colors.Normalize(vmin=750, vmax=900),
            origin="lower",
            cmap=palette,
        )
        ax.set_xlabel('d{:02d}'.format(i + 1))
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout(rect=[0, 0, 1, 1], w_pad=0.05)
    fig.savefig(os.path.basename(filename.replace('.fits', '.png')))
    plt.show()
def create_figure(self, frameno=0, binning=1, dpi=None, stretch='log',
                  vmin=1, vmax=5000, cmap='gray', data_col='FLUX',
                  annotate=True, time_format='ut', show_flags=False,
                  label=None):
    """Returns a matplotlib Figure object that visualizes a frame.

    Parameters
    ----------
    frameno : int
        Image number in the target pixel file.

    binning : int
        Number of frames around `frameno` to co-add. (default: 1).

    dpi : float, optional [dots per inch]
        Resolution of the output in dots per Kepler CCD pixel.
        By default the dpi is chosen such that the image is 440px wide.

    stretch : str, optional
        Intensity stretch: 'linear', 'sqrt', 'power', 'log' or 'asinh'.
        (Default: 'log'.)

    vmin : float, optional
        Minimum cut level (default: 1).

    vmax : float, optional
        Maximum cut level (default: 5000).

    cmap : str, optional
        The matplotlib color map name.  The default is 'gray',
        can also be e.g. 'gist_heat'.

    data_col : str, optional
        Data column passed to ``self.flux_binned`` (default: 'FLUX').

    annotate : boolean, optional
        Annotate the Figure with a timestamp and target name?
        (Default: `True`.)

    time_format : str, optional
        Passed to ``self.timestamp`` to format the time label
        (default: 'ut').

    show_flags : boolean, optional
        Show the quality flags? Only has an effect when `annotate` is True.
        (Default: `False`.)

    label : str
        Label text to show in the bottom left corner of the movie.
        Defaults to ``self.objectname``.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure showing the frame; its canvas has already been drawn.

    Raises
    ------
    ValueError
        If `stretch` is not one of the supported names.
    """
    # Get the flux data to visualize
    flx = self.flux_binned(frameno=frameno, binning=binning,
                           data_col=data_col)

    # Determine the figsize and dpi. figsize is (width, height), i.e.
    # (number of columns, number of rows) of the flux image.
    shape = [flx.shape[1], flx.shape[0]]
    if dpi is None:
        # Twitter timeline requires dimensions between 440x220 and 1024x512
        # so we make 440 the default
        dpi = 440 / float(shape[0])

    # libx264 requires the output width and height in pixels to be even;
    # trim each dimension down to an even pixel count.
    shape = [size - ((size * dpi) % 2) / dpi for size in shape]

    # Create the figure and display the flux image
    fig = pl.figure(figsize=shape, dpi=dpi)
    ax = fig.add_subplot(1, 1, 1)

    if self.verbose:
        print('{} vmin/vmax = {}/{} (median={})'.format(
            data_col, vmin, vmax, np.nanmedian(flx)))

    if stretch == 'linear':
        stretch_fn = visualization.LinearStretch()
    elif stretch == 'sqrt':
        stretch_fn = visualization.SqrtStretch()
    elif stretch == 'power':
        stretch_fn = visualization.PowerStretch(1.0)
    elif stretch == 'log':
        stretch_fn = visualization.LogStretch()
    elif stretch == 'asinh':
        stretch_fn = visualization.AsinhStretch(0.1)
    else:
        raise ValueError('Unknown stretch: {0}.'.format(stretch))

    # Astropy composes "stretch + interval" as: cut with the interval first,
    # then apply the stretch; the result is scaled to 0..255 below.
    transform = (stretch_fn +
                 visualization.ManualInterval(vmin=vmin, vmax=vmax))
    flx_transform = 255 * transform(flx)
    # Make sure to remove all NaNs!
    flx_transform[~np.isfinite(flx_transform)] = 0
    # NoNorm: the integer values are used directly as colormap indices.
    ax.imshow(flx_transform.astype(int), aspect='auto',
              origin='lower', interpolation='nearest',
              cmap=cmap, norm=NoNorm())

    if annotate:  # Annotate the frame with a timestamp and target name?
        fontsize = 3. * shape[0]
        margin = 0.03

        def _outline():
            # Black stroke behind the white text keeps it readable.
            return [path_effects.Stroke(linewidth=fontsize / 6.,
                                        foreground='black'),
                    path_effects.Normal()]

        # Print target name in lower left corner
        if label is None:
            label = self.objectname
        txt = ax.text(margin, margin, label,
                      family="monospace", fontsize=fontsize,
                      color='white', transform=ax.transAxes)
        txt.set_path_effects(_outline())

        # Print a timestring in the lower right corner
        txt2 = ax.text(1 - margin, margin,
                       self.timestamp(frameno, time_format=time_format),
                       family="monospace", fontsize=fontsize,
                       color='white', ha='right',
                       transform=ax.transAxes)
        txt2.set_path_effects(_outline())

        # Print quality flags in the upper left corner
        if show_flags:
            flags = self.quality_flags(frameno)
            if len(flags) > 0:
                txt3 = ax.text(margin, 1 - margin, '\n'.join(flags),
                               family="monospace", fontsize=fontsize * 1.3,
                               color='white', ha='left', va='top',
                               transform=ax.transAxes,
                               linespacing=1.5, backgroundcolor='red')
                txt3.set_path_effects(_outline())

    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')
    fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
    fig.canvas.draw()
    return fig
interval=v.ManualInterval(vmin=image.min() - 5, vmax=image.max() + 10), stretch=v.LogStretch(10)) im = plt.imshow(image, origin='lower', norm=norm, cmap='Greys') plt.colorbar(im) plt.subplot(1, 2, 2) plt.title('Nebular Emission Mask') mimage = np.ma.MaskedArray(image) mimage.mask = ~nmask mimagef = np.ma.filled(mimage, fill_value=0) norm = v.ImageNormalize(mimagef, interval=v.ManualInterval( vmin=image.min() - 5, vmax=np.percentile(image, mask_pcnt) + 5), stretch=v.LinearStretch()) im = plt.imshow(mimagef, origin='lower', norm=norm, cmap='Greys') plt.colorbar(im) plt.savefig('H-beta Image.png', bbox_inches='tight', pad_inches=0.10) ##------------------------------------------------------------------------- ## Model and Plot Nebular Background log.info('Model and plot nebular background') background_0 = models.Polynomial1D(degree=2) H_beta_0 = models.Gaussian1D(amplitude=500, mean=4861, stddev=1., bounds={ 'mean': (4855, 4865), 'stddev': (0.1, 5)
def show_image(image,
               percl=99, percu=None, is_mask=False,
               figsize=(6, 10),
               cmap='viridis', log=False,
               show_colorbar=True, show_ticks=True,
               fig=None, ax=None, input_ratio=None):
    """
    Show an image in matplotlib with some basic astronomically-appropriate
    stretching.

    Parameters
    ----------
    image
        The image to show (a 2-D array).

    percl : number
        The percentile for the lower edge of the stretch (or both edges if
        ``percu`` is None, in which case the lower edge is ``100 - percl``).

    percu : number or None
        The percentile for the upper edge of the stretch (or None to use
        ``percl`` for both).

    is_mask : bool
        If True, display with fixed limits ``vmin=0, vmax=1`` instead of the
        percentile stretch (for images whose pixels are zero or one).

    figsize : 2-tuple
        The size of the matplotlib figure in inches. Only used when a new
        figure is created, i.e. when ``fig`` and ``ax`` are both None.

    cmap : str
        Name of the matplotlib colormap.

    log : bool
        If True use a logarithmic stretch, otherwise a linear one.

    show_colorbar : bool
        If True, add a colorbar next to the image.

    show_ticks : bool
        If False, hide the tick labels on all four sides.

    fig, ax : matplotlib Figure and Axes, optional
        Existing figure and axes to draw into; provide both or neither.

    input_ratio : int, optional
        Block-reduction factor to use instead of the one computed from the
        figure size.

    Raises
    ------
    ValueError
        If only one of ``fig`` and ``ax`` is provided.
    """
    if percu is None:
        percu = percl
        percl = 100 - percl

    if (fig is None) != (ax is None):
        raise ValueError('Must provide both "fig" and "ax" '
                         'if you provide one of them')
    if fig is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    # To preserve details we should *really* downsample correctly and
    # not rely on matplotlib to do it correctly for us (it won't).

    # So, calculate the size of the figure in pixels, block_reduce to
    # roughly that, and display the block reduced image.

    # Thanks, https://stackoverflow.com/questions/29702424/how-to-get-matplotlib-figure-size
    fig_size_pix = fig.get_size_inches() * fig.dpi

    # get_size_inches() is (width, height) while image.shape is
    # (rows, cols) == (height, width); reverse one so like axes are compared.
    ratio = int(max((image.shape // fig_size_pix[::-1]).max(), 1))
    ratio = input_ratio or ratio

    # Divide by the square of the ratio to keep the flux the same in the
    # reduced image
    reduced_data = block_reduce(image, ratio) / ratio**2

    # Of course, now that we have downsampled, the axis limits are changed to
    # match the smaller image size. Setting the extent will do the trick to
    # change the axis display back to showing the actual extent of the image.
    extent = [0, image.shape[1], 0, image.shape[0]]

    stretch = aviz.LogStretch() if log else aviz.LinearStretch()
    norm = aviz.ImageNormalize(reduced_data,
                               interval=aviz.AsymmetricPercentileInterval(
                                   percl, percu),
                               stretch=stretch)

    if is_mask:
        # The image is a mask in which pixels are zero or one. Set the image
        # scale limits appropriately.
        scale_args = dict(vmin=0, vmax=1)
    else:
        scale_args = dict(norm=norm)

    im = ax.imshow(reduced_data, origin='lower',
                   cmap=cmap, extent=extent, aspect='equal', **scale_args)

    if show_colorbar:
        # I haven't a clue why the fraction and pad arguments below work to
        # make the colorbar the same height as the image, but they do....
        # unless the image is wider than it is tall. Sticking with this for
        # now anyway...
        # Thanks: https://stackoverflow.com/a/26720422/3486425
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, format='%2.0f')
        # In case someone in the future wants to improve this:
        # https://joseph-long.com/writing/colorbars/
        # https://stackoverflow.com/a/33505522/3486425
        # https://matplotlib.org/mpl_toolkits/axes_grid/users/overview.html#colorbar-whose-height-or-width-in-sync-with-the-master-axes

    if not show_ticks:
        ax.tick_params(labelbottom=False, labelleft=False,
                       labelright=False, labeltop=False)
def find_bar_positions_from_image(imagefile, filtersize=5, plot=False,
                                  pixel_shim=5):
    '''Measure the X pixel position of every CSU bar in an image.

    For each slit (1-46), the Y pixel range over which the slit should be
    found is obtained from `physical_to_pixel` and shrunk by `pixel_shim`
    pixels at each end. The image is median filtered along X only; within
    that Y range the X-direction gradient is taken and collapsed in Y to
    form a 1D profile. `find_bar_edges` then gives the X pixel positions of
    the two bars forming the slit.

    Parameters
    ----------
    imagefile : str or path-like
        FITS file; the data of the primary HDU are used.
    filtersize : int
        Width in pixels of the median filter applied along X.
    plot : bool
        If exactly ``True``, write a PNG (same stem as `imagefile`) showing
        the image, the slit edges and the measured bar positions.
    pixel_shim : int
        Number of pixels trimmed from each end of a slit's Y range.

    Returns
    -------
    dict
        Bar number -> X pixel position returned by `find_bar_edges`. Bars
        whose slit could not be fit are omitted.
    '''
    ## Get image from file
    imagefile = Path(imagefile).absolute()
    try:
        # memmap=False reads the pixels into memory, so the file can be
        # closed right away instead of being left open.
        with fits.open(imagefile, memmap=False) as hdul:
            data = hdul[0].data
    except Exception as e:
        log.error(e)
        raise

    # median X pixels only (preserve Y structure)
    medimage = ndimage.median_filter(data, size=(1, filtersize))

    bars = {}
    ypos = {}
    for slit in range(1, 47):
        b1, b2 = slit_to_bars(slit)
        ## Determine y pixel range
        y1 = int(np.ceil((physical_to_pixel(np.array([(4.0, slit + 0.5)])))[0][0][1])) + pixel_shim
        y2 = int(np.floor((physical_to_pixel(np.array([(270.4, slit - 0.5)])))[0][0][1])) - pixel_shim
        ypos[b1] = [y1, y2]
        ypos[b2] = [y1, y2]
        gradx = np.gradient(medimage[y1:y2, :], axis=1)
        horizontal_profile = np.sum(gradx, axis=0)
        try:
            bars[b1], bars[b2] = find_bar_edges(horizontal_profile)
        except Exception:
            # Best effort: skip this slit but keep measuring the others.
            log.warning(f'Unable to fit bars: {b1}, {b2}')

    # Generate plot if called for
    if plot is True:
        plotfile = imagefile.with_name(f"{imagefile.stem}.png")
        log.info(f'Creating PNG image {plotfile}')
        if plotfile.exists():
            plotfile.unlink()
        fig = plt.figure(figsize=(16, 16), dpi=300)
        norm = viz.ImageNormalize(data, interval=viz.PercentileInterval(99.9),
                                  stretch=viz.LinearStretch())
        plt.imshow(data, norm=norm, origin='lower', cmap='Greys')
        # Physical X positions (mm) of the two ends of each slit edge line.
        mms = np.linspace(4, 270.4, 2)
        for bar, xpix in bars.items():
            slit = bar_to_slit(bar)
            pix = np.array([(physical_to_pixel(np.array([(mm, slit + 0.5)])))[0][0]
                            for mm in mms])
            plt.plot(pix.transpose()[0], pix.transpose()[1], 'g-', alpha=0.5)
            plt.plot([xpix, xpix], ypos[bar], 'r-', alpha=0.75)
            # Even and odd bars are labelled on opposite sides of the line.
            offset = {0: -20, 1: +20}[bar % 2]
            plt.text(xpix + offset, np.mean(ypos[bar]), bar,
                     fontsize=8, color='r', alpha=0.75,
                     horizontalalignment='center',
                     verticalalignment='center')
        plt.savefig(str(plotfile), bbox_inches='tight')
        # Release the (large) figure; it is only needed for the PNG.
        plt.close(fig)

    return bars
def plot_image(image, ax=None, scale='log', cmap=None, origin='lower',
               xlabel=None, ylabel=None, cbar=None,
               clabel='Flux ($e^{-}s^{-1}$)', cbar_ticks=None,
               cbar_ticklabels=None, cbar_pad=None, cbar_size='5%',
               title=None, percentile=95.0, vmin=None, vmax=None,
               offset_axes=None, color_bad='k', **kwargs):
    """
    Utility function to plot a 2D image.

    Parameters:
        image (2d array): Image data.
        ax (matplotlib.pyplot.axes, optional): Axes in which to plot.
            Default (None) is to use current active axes.
        scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional):
            Normalization used to stretch the colormap.
            Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``,
            ``'histeq'``, ``'sinh'`` and ``'squared'``.
            Can also be a :py:class:`astropy.visualization.ImageNormalize`
            or :py:class:`matplotlib.colors.Normalize` object.
            Default is ``'log'``.
        origin (str, optional): The origin of the coordinate system.
        xlabel (str, optional): Label for the x-axis.
        ylabel (str, optional): Label for the y-axis.
        cbar (string, optional): Location of color bar.
            Choices are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``.
            Default is not to create colorbar.
        clabel (str, optional): Label for the color bar.
        cbar_ticks (list, optional): Tick positions for the color bar.
        cbar_ticklabels (list, optional): Tick labels for the color bar.
        cbar_size (float or str, optional): Size of colorbar compared to
            axes. Default='5%'.
        cbar_pad (float, optional): Padding between axes and colorbar.
        title (str or None, optional): Title for the plot.
        percentile (float, optional): The fraction of pixels to keep in
            color-trim. If single float given, the same fraction of pixels is
            eliminated from both ends. If tuple of two floats is given, the
            two are used as the percentiles. Default=95.
        cmap (matplotlib colormap, optional): Colormap to use. Default is the
            ``Blues`` colormap. A copy is always used, so the caller's (or
            matplotlib's registered) colormap is never modified.
        vmin (float, optional): Lower limit to use for colormap.
        vmax (float, optional): Upper limit to use for colormap.
        offset_axes (tuple, optional): ``(x, y)`` pixel offset of the lower
            left corner; shifts the plotted extent.
        color_bad (str, optional): Color to apply to bad pixels (NaN).
            Default is black.
        kwargs (dict, optional): Keyword arguments to be passed to
            :py:func:`matplotlib.pyplot.imshow`.

    Returns:
        :py:class:`matplotlib.image.AxesImage`: Image returned by
            :py:func:`matplotlib.pyplot.imshow`.

    Raises:
        FutureWarning: If the removed ``make_cbar`` keyword is passed a true
            value.
        ValueError: If ``scale`` is not a supported name or norm object.

    .. codeauthor:: Rasmus Handberg <*****@*****.**>
    """
    logger = logging.getLogger(__name__)

    # Backward compatible settings: 'make_cbar' has been replaced by 'cbar'.
    make_cbar = kwargs.pop('make_cbar', None)
    if make_cbar:
        raise FutureWarning("'make_cbar' is deprecated. Use 'cbar' instead.")

    # Special treatment for boolean arrays:
    if isinstance(image, np.ndarray) and image.dtype == 'bool':
        if vmin is None:
            vmin = 0
        if vmax is None:
            vmax = 1
        if cbar_ticks is None:
            cbar_ticks = [0, 1]
        if cbar_ticklabels is None:
            cbar_ticklabels = ['False', 'True']

    # Calculate limits of color scaling:
    interval = None
    if vmin is None or vmax is None:
        if allnan(image):
            logger.warning("Image is all NaN")
            vmin = 0
            vmax = 1
            if cbar_ticks is None:
                cbar_ticks = []
            if cbar_ticklabels is None:
                cbar_ticklabels = []
        elif isinstance(percentile, (list, tuple, np.ndarray)):
            interval = viz.AsymmetricPercentileInterval(percentile[0],
                                                        percentile[1])
        else:
            interval = viz.PercentileInterval(percentile)

    # Create ImageNormalize object with extracted limits:
    if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh',
                 'squared'):
        if scale == 'log':
            stretch = viz.LogStretch()
        elif scale == 'linear':
            stretch = viz.LinearStretch()
        elif scale == 'sqrt':
            stretch = viz.SqrtStretch()
        elif scale == 'asinh':
            stretch = viz.AsinhStretch()
        elif scale == 'histeq':
            stretch = viz.HistEqStretch(image[np.isfinite(image)])
        elif scale == 'sinh':
            stretch = viz.SinhStretch()
        elif scale == 'squared':
            stretch = viz.SquaredStretch()

        # Create ImageNormalize object. Very important to use clip=False if
        # the image contains NaNs, otherwise NaN points will not be plotted
        # correctly.
        norm = viz.ImageNormalize(data=image[np.isfinite(image)],
                                  interval=interval,
                                  vmin=vmin,
                                  vmax=vmax,
                                  stretch=stretch,
                                  clip=not anynan(image))
    elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)):
        norm = scale
    else:
        raise ValueError("scale {} is not available.".format(scale))

    # Pixel centres sit on integer coordinates, so edges are at +/- 0.5.
    if offset_axes:
        extent = (offset_axes[0] - 0.5,
                  offset_axes[0] + image.shape[1] - 0.5,
                  offset_axes[1] - 0.5,
                  offset_axes[1] + image.shape[0] - 0.5)
    else:
        extent = (-0.5, image.shape[1] - 0.5, -0.5, image.shape[0] - 0.5)

    if ax is None:
        ax = plt.gca()

    # Set up the colormap to use. plt.get_cmap accepts both names and
    # Colormap objects; copying in every case keeps set_bad from modifying a
    # colormap owned by the caller or registered globally.
    cmap = copy.copy(plt.get_cmap('Blues' if cmap is None else cmap))
    if color_bad:
        cmap.set_bad(color_bad, 1.0)

    # Plotting the image using all the settings set above:
    im = ax.imshow(image, cmap=cmap, norm=norm, origin=origin, extent=extent,
                   interpolation='nearest', **kwargs)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    ax.set_xlim([extent[0], extent[1]])
    ax.set_ylim([extent[2], extent[3]])

    if cbar:
        colorbar(im, ax=ax, loc=cbar, size=cbar_size, pad=cbar_pad,
                 label=clabel, ticks=cbar_ticks, ticklabels=cbar_ticklabels)

    # Settings for ticks. Matplotlib does not support sharing one Locator
    # between several Axis objects, so each axis gets its own instances.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(MaxNLocator(nbins=10, integer=True))
        axis.set_minor_locator(MaxNLocator(nbins=10, integer=True))
    ax.tick_params(which='both', direction='out', pad=5)
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()

    return im
def plotDirectCutouts(self, savePath=None, colourMap='viridis',
                      gridSpecs=None):
    """Plot the direct-image cutout of the target for each grism.

    Parameters
    ----------
    savePath : str, optional
        If given, the figure is saved there (dpi=300) and then closed.
    colourMap : str
        Matplotlib colormap name used for the cutouts.
    gridSpecs : sequence of SubplotSpec, optional
        Subplot specs indexed by stamp number. If None, a new 10x5 inch
        figure with a 2x1 subplot layout is created.

    Each panel shows the cutout with a 0-99.5 percentile linear stretch, and
    a secondary y-axis in arcsec centred on the cutout (scaled by the
    ``IDCSCALE`` keyword of extension 1 of the matching direct image).
    Missing or entirely negative cutouts get a text notice instead.
    """
    if self.directCutouts is None:
        print(
            'The loadDirectCutouts(...) method must be called before direct cutouts can be plotted.'
        )
        return

    if gridSpecs is None:
        mplplot.figure(figsize=(10, 5))

    for stampIndex, (grism, cutoutData) in enumerate(
            self.directCutouts.items()):
        if gridSpecs is None:
            subplotAxes = mplplot.subplot(2, 1, stampIndex + 1)
        else:
            subplotAxes = mplplot.subplot(gridSpecs[stampIndex])

        if cutoutData is None:
            subplotAxes.text(
                0.5, 0.5,
                'Field {}, Object {}:\nNO DATA AVAILABLE.'.format(
                    self.targetPar, self.targetObject),
                horizontalalignment='center', fontsize='large',
                transform=subplotAxes.transAxes)
            continue

        # NOTE(review): this tests "all negative", so an all-zero cutout is
        # still plotted despite the "NO NONZERO DATA" wording.
        if np.all(cutoutData < 0):
            subplotAxes.text(
                0.5, 0.5,
                'Field {}, Object {}:\nNO NONZERO DATA AVAILABLE.'.format(
                    self.targetPar, self.targetObject),
                horizontalalignment='center', fontsize='large',
                transform=subplotAxes.transAxes)
            continue

        norm = astromplnorm.ImageNormalize(
            cutoutData,
            interval=astrovis.AsymmetricPercentileInterval(0, 99.5),
            stretch=astrovis.LinearStretch(),
            clip=True)

        mplplot.imshow(cutoutData, origin='lower', interpolation='nearest',
                       cmap=colourMap, norm=norm)
        mplplot.xlabel('X (pixels)')
        mplplot.ylabel('Y (pixels)')
        mplplot.title(
            'Field {}, Object {}:\nDirect cutout for F{} (G{})'.format(
                self.targetPar, self.targetObject,
                self.getDirectFilterForGrism(grism), grism))

        # Secondary y-axis: pixel offset from the cutout centre times the
        # IDCSCALE header value, labelled in arcsec.
        arcsecYAxis = subplotAxes.twinx()
        pixelYLimits = np.array(subplotAxes.get_ylim())
        pixelScale = self.directHdus[grism][1]['IDCSCALE']
        arcsecYAxis.set_ylim(
            *(pixelYLimits - 0.5 * np.sum(pixelYLimits)) * pixelScale)
        # Raw string: '\D' is not a valid escape sequence in a normal literal.
        arcsecYAxis.set_ylabel(r'$\Delta Y$ (arcsec)')
        mplplot.grid(color='white', ls='solid')

    try:
        mplplot.tight_layout()
    except ValueError as e:
        print(
            'Error attempting tight_layout for: Field {}, Object {} ({})'
            .format(self.targetPar, self.targetObject, e))
        return

    if savePath is not None:
        mplplot.savefig(savePath, dpi=300, bbox_inches='tight')
        mplplot.close()
def makeSlitIllum(self, adinputs=None, **params):
    """
    Makes the processed Slit Illumination Function by binning a 2D
    spectrum along the dispersion direction, fitting a smooth function
    for each bin, fitting a smooth 2D model, and reconstructing the 2D
    array using this last model.

    Its implementation is based on IRAF's
    `noao.twodspec.longslit.illumination` task, following the algorithm
    described in [Valdes, 1986].

    It expects the input calibration image to be a dispersed image of the
    slit without illumination problems (e.g., twilight flat). The spectra
    are not required to be smooth in wavelength and may contain strong
    emission and absorption lines. The image should contain a `.mask`
    attribute in each extension, and it is expected to be overscan and bias
    corrected.

    Parameters
    ----------
    adinputs : list
        List of AstroData objects containing the dispersed image of the
        slit of a source free of illumination problems. The data needs to
        have been overscan and bias corrected and is expected to have a
        Data Quality mask.
    suffix : str
        Suffix used when updating the output filename.
    bins : {None, int}, optional
        Total number of bins across the dispersion axis. If None, the
        number of bins is the larger of 12 and the number of extensions of
        each input AstroData object. If it is an int, it will create N bins
        with the same size.
    border : int, optional
        Border size that is added on every edge of the slit illumination
        image before cutting it down to the input AstroData frame.
    debug_plot : bool
        If True, create diagnostic plots and save them as a PNG named after
        the output file.
    smooth_order : int, optional
        Order of the spline that is used in each bin fitting to smooth
        the data (Default: 3)
    x_order : int, optional
        Order of the x-component in the Chebyshev2D model used to
        reconstruct the 2D data from the binned data.
    y_order : int, optional
        Order of the y-component in the Chebyshev2D model used to
        reconstruct the 2D data from the binned data.

    Return
    ------
    List of AstroData : containing an AstroData with the Slit Illumination
        Response Function for each of the input object.

    Raises
    ------
    TypeError
        If `bins` is neither None nor an int.

    References
    ----------
    .. [Valdes, 1986] Francisco Valdes "Reduction Of Long Slit Spectra With
       IRAF", Proc. SPIE 0627, Instrumentation in Astronomy VI,
       (13 October 1986); https://doi.org/10.1117/12.968155
    """
    log = self.log
    log.debug(gt.log_message("primitive", self.myself(), "starting"))
    timestamp_key = self.timestamp_keys[self.myself()]

    suffix = params["suffix"]
    bins = params["bins"]
    border = params["border"]
    debug_plot = params["debug_plot"]
    smooth_order = params["smooth_order"]
    cheb2d_x_order = params["x_order"]
    cheb2d_y_order = params["y_order"]

    ad_outputs = []
    for ad in adinputs:

        if len(ad) > 1 and "mosaic" not in ad[0].wcs.available_frames:

            log.info('Add "mosaic" gWCS frame to input data')
            geotable = import_module('.geometry_conf', self.inst_lookups)

            # deepcopy prevents modifying input `ad` inplace
            ad = transform.add_mosaic_wcs(deepcopy(ad), geotable)

            log.info("Temporarily mosaicking multi-extension file")
            mosaicked_ad = transform.resample_from_wcs(
                ad, "mosaic", attributes=None, order=1,
                process_objcat=False)

        else:

            log.info('Input data already has one extension and has a '
                     '"mosaic" frame.')

            # deepcopy prevents modifying input `ad` inplace
            mosaicked_ad = deepcopy(ad)

        log.info("Transposing data if needed")
        dispaxis = 2 - mosaicked_ad[0].dispersion_axis()  # python sense
        should_transpose = dispaxis == 1

        data, mask, variance = _transpose_if_needed(
            mosaicked_ad[0].data, mosaicked_ad[0].mask,
            mosaicked_ad[0].variance, transpose=should_transpose)

        log.info("Masking data")
        data = np.ma.masked_array(data, mask=mask)
        variance = np.ma.masked_array(variance, mask=mask)
        std = np.sqrt(variance)  # Easier to work with

        log.info("Creating bins for data and variance")
        height = data.shape[0]
        width = data.shape[1]

        if bins is None:
            nbins = max(len(ad), 12)
        elif isinstance(bins, int):
            nbins = bins
        else:
            # ToDo: Handle input bins as array
            raise TypeError("Expected None or Int for `bins`. "
                            "Found: {}".format(type(bins)))
        bin_limits = np.linspace(0, height, nbins + 1, dtype=int)

        bin_top = bin_limits[1:]
        bin_bot = bin_limits[:-1]
        binned_data = np.zeros_like(data)
        binned_std = np.zeros_like(std)

        log.info("Smooth binned data and variance, and normalize them by "
                 "smoothed central value")
        # Position along each row, i.e. the column index.
        cols = np.arange(width)
        for b0, b1 in zip(bin_bot, bin_top):
            avg_data = np.ma.mean(data[b0:b1], axis=0)
            model_1d_data = astromodels.UnivariateSplineWithOutlierRemoval(
                cols, avg_data, order=smooth_order)

            avg_std = np.ma.mean(std[b0:b1], axis=0)
            model_1d_std = astromodels.UnivariateSplineWithOutlierRemoval(
                cols, avg_std, order=smooth_order)

            # Normalize each bin by its smoothed value at the slit centre.
            slit_central_value = model_1d_data(cols)[width // 2]
            binned_data[b0:b1] = model_1d_data(cols) / slit_central_value
            binned_std[b0:b1] = model_1d_std(cols) / slit_central_value

        log.info("Reconstruct 2D mosaicked data")
        bin_center = np.array(0.5 * (bin_bot + bin_top), dtype=int)
        cols_fit, rows_fit = np.meshgrid(np.arange(width), bin_center)

        fitter = fitting.SLSQPLSQFitter()
        model_2d_init = models.Chebyshev2D(
            x_degree=cheb2d_x_order, x_domain=(0, width),
            y_degree=cheb2d_y_order, y_domain=(0, height))

        model_2d_data = fitter(model_2d_init, cols_fit, rows_fit,
                               binned_data[rows_fit, cols_fit])

        model_2d_std = fitter(model_2d_init, cols_fit, rows_fit,
                              binned_std[rows_fit, cols_fit])

        # Evaluate the 2D models on a grid extended by `border` on all sides.
        rows_val, cols_val = \
            np.mgrid[-border:height + border, -border:width + border]

        slit_response_data = model_2d_data(cols_val, rows_val)
        slit_response_mask = np.pad(mask, border, mode='edge')  # ToDo: any update to the mask?
        slit_response_std = model_2d_std(cols_val, rows_val)
        slit_response_var = slit_response_std ** 2

        del cols_fit, cols_val, rows_fit, rows_val

        _data, _mask, _variance = _transpose_if_needed(
            slit_response_data, slit_response_mask, slit_response_var,
            transpose=dispaxis == 1)

        log.info("Update slit response data and data_section")
        slit_response_ad = deepcopy(mosaicked_ad)
        slit_response_ad[0].data = _data
        slit_response_ad[0].mask = _mask
        slit_response_ad[0].variance = _variance

        if "mosaic" in ad[0].wcs.available_frames:

            log.info("Map coordinates between slit function and mosaicked data")  # ToDo: Improve message?
            slit_response_ad = _split_mosaic_into_extensions(
                ad, slit_response_ad, border_size=border)

        elif len(ad) == 1:

            log.info("Trim out borders")
            # `-border or None`: a stop of -0 would give an empty slice when
            # border == 0, so use None (the full extent) in that case.
            trim = (slice(border, -border or None),) * 2
            slit_response_ad[0].data = slit_response_ad[0].data[trim]
            slit_response_ad[0].mask = slit_response_ad[0].mask[trim]
            slit_response_ad[0].variance = \
                slit_response_ad[0].variance[trim]

        log.info("Update metadata and filename")
        gt.mark_history(
            slit_response_ad, primname=self.myself(), keyword=timestamp_key)

        slit_response_ad.update_filename(suffix=suffix, strip=True)
        ad_outputs.append(slit_response_ad)

        # Plotting ------
        if debug_plot:

            log.info("Creating plots")
            palette = copy(plt.cm.cividis)
            palette.set_bad('r', 0.75)

            def _hide_spines(axes):
                # Remove the frame around a panel.
                for spine in axes.spines.values():
                    spine.set_visible(False)

            def _label_bins(axes, lower):
                # Set limits, label bin centres and draw white bin edges.
                axes.set_xlim(lower, data.shape[1])
                axes.set_xticks([])
                axes.set_ylim(lower, data.shape[0])
                axes.set_yticks(bin_center)
                axes.tick_params(axis='both', which='both', length=0)
                axes.set_yticklabels(
                    ["Bin {}".format(i) for i in range(len(bin_center))],
                    fontsize=6)
                _hide_spines(axes)
                for b in bin_limits:
                    axes.axhline(b, c='w', lw=0.5)

            def _add_colorbar(image, axes):
                # Colorbar in a thin axes appended to the right of `axes`.
                cax = make_axes_locatable(axes).append_axes(
                    "right", size="5%", pad=0.05)
                plt.colorbar(image, cax=cax)

            norm = vis.ImageNormalize(data[~data.mask],
                                      stretch=vis.LinearStretch(),
                                      interval=vis.PercentileInterval(97))

            fig = plt.figure(
                num="Slit Response from MEF - {}".format(ad.filename),
                figsize=(12, 9), dpi=110)

            gs = gridspec.GridSpec(nrows=2, ncols=3, figure=fig)

            # Display raw mosaicked data and its bins ---
            ax1 = fig.add_subplot(gs[0, 0])
            im1 = ax1.imshow(data, cmap=palette, origin='lower',
                             vmin=norm.vmin, vmax=norm.vmax)
            ax1.set_title("Mosaicked Data\n and Spectral Bins", fontsize=10)
            _label_bins(ax1, -1)
            _add_colorbar(im1, ax1)

            # Display binned, smoothed and normalized data ---
            ax2 = fig.add_subplot(gs[0, 1])
            im2 = ax2.imshow(binned_data, cmap=palette, origin='lower')
            ax2.set_title("Binned, smoothed\n and normalized data ",
                          fontsize=10)
            _label_bins(ax2, 0)
            _add_colorbar(im2, ax2)

            # Display reconstructed slit response ---
            vmin = slit_response_data.min()
            vmax = slit_response_data.max()

            ax3 = fig.add_subplot(gs[1, 0])
            im3 = ax3.imshow(slit_response_data, cmap=palette,
                             origin='lower', vmin=vmin, vmax=vmax)
            ax3.set_title("Reconstructed\n Slit response", fontsize=10)
            ax3.set_xlim(0, data.shape[1])
            ax3.set_xticks([])
            ax3.set_ylim(0, data.shape[0])
            ax3.set_yticks([])
            ax3.tick_params(axis='both', which='both', length=0)
            _hide_spines(ax3)
            _add_colorbar(im3, ax3)

            # Display extensions ---
            ax4 = fig.add_subplot(gs[1, 1])
            ax4.set_xticks([])
            ax4.set_yticks([])
            _hide_spines(ax4)

            sub_gs4 = gridspec.GridSpecFromSubplotSpec(
                nrows=len(ad), ncols=1, subplot_spec=gs[1, 1], hspace=0.03)

            # The [::-1] is needed to put the first extension in the bottom
            for i, ext in enumerate(slit_response_ad[::-1]):

                ext_data, ext_mask, ext_variance = _transpose_if_needed(
                    ext.data, ext.mask, ext.variance,
                    transpose=dispaxis == 1)

                ext_data = np.ma.masked_array(ext_data, mask=ext_mask)

                sub_ax = fig.add_subplot(sub_gs4[i])

                im4 = sub_ax.imshow(ext_data, origin="lower", vmin=vmin,
                                    vmax=vmax, cmap=palette)

                sub_ax.set_xlim(0, ext_data.shape[1])
                sub_ax.set_xticks([])
                sub_ax.set_ylim(0, ext_data.shape[0])
                sub_ax.set_yticks([ext_data.shape[0] // 2])

                sub_ax.set_yticklabels(
                    ["Ext {}".format(len(slit_response_ad) - i - 1)],
                    fontsize=6)

                _hide_spines(sub_ax)

                if i == 0:
                    sub_ax.set_title(
                        "Multi-extension\n Slit Response Function")

            _add_colorbar(im4, ax4)

            # Display Signal-To-Noise Ratio ---
            snr = data / np.sqrt(variance)

            norm = vis.ImageNormalize(snr[~snr.mask],
                                      stretch=vis.LinearStretch(),
                                      interval=vis.PercentileInterval(97))

            ax5 = fig.add_subplot(gs[0, 2])
            im5 = ax5.imshow(snr, cmap=palette, origin='lower',
                             vmin=norm.vmin, vmax=norm.vmax)
            ax5.set_title("Mosaicked Data SNR", fontsize=10)
            _label_bins(ax5, -1)
            _add_colorbar(im5, ax5)

            # Display Signal-To-Noise Ratio of Slit Illumination ---
            slit_response_snr = np.ma.masked_array(
                slit_response_data / np.sqrt(slit_response_var),
                mask=slit_response_mask)

            ax6 = fig.add_subplot(gs[1, 2])
            im6 = ax6.imshow(slit_response_snr, origin="lower",
                             vmin=norm.vmin, vmax=norm.vmax, cmap=palette)
            ax6.set_xlim(0, slit_response_snr.shape[1])
            ax6.set_xticks([])
            ax6.set_ylim(0, slit_response_snr.shape[0])
            ax6.set_yticks([])
            ax6.set_title("Reconstructed\n Slit Response SNR")
            _hide_spines(ax6)
            _add_colorbar(im6, ax6)

            # Save plots ---
            fig.tight_layout(rect=[0, 0, 0.95, 1], pad=0.5)
            fname = slit_response_ad.filename.replace(".fits", ".png")
            log.info("Saving plots to {}".format(fname))
            fig.savefig(fname)
            # Close the figure so one figure per input does not accumulate.
            plt.close(fig)

    return ad_outputs
def create_figure(self, output_filename, survey, stretch='log', vmin=1,
                  vmax=None, min_percent=1, max_percent=95, cmap='gray',
                  contour_color='red', data_col='FLUX'):
    """Render the binned TPF flux with survey contours and save it to disk.

    Parameters
    ----------
    output_filename : str
        Path the figure is written to (``savefig`` at 300 dpi, tight bbox).
    survey : str
        Survey name passed, together with ``self.TPF.position``, to
        ``surveyquery.getSVImg`` to fetch the image whose contours are
        overlaid.
    stretch : {'linear', 'sqrt', 'power', 'log', 'asinh'}, optional
        Stretch applied after the cut levels. Default: 'log'.
    vmin, vmax : float, optional
        Manual cut levels. They are only honoured when ``vmax`` is given;
        when ``vmax`` is None both are replaced by
        ``self.cut_levels(min_percent, max_percent, data_col)``.
    min_percent, max_percent : float, optional
        Percentiles handed to ``self.cut_levels`` when ``vmax`` is None.
    cmap : str, optional
        Matplotlib colormap name. Default: 'gray'.
    contour_color : str, optional
        Colour of the survey contours. Default: 'red'.
    data_col : str, optional
        Passed to ``self.cut_levels`` and shown in the verbose message; the
        displayed image itself always comes from ``self.TPF.flux_binned()``.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure that was drawn and saved.

    Raises
    ------
    ValueError
        If ``stretch`` is not one of the supported names.
    """
    # Resolve the stretch before touching the data or pyplot, so an invalid
    # name fails fast instead of leaving a half-built figure open.
    if stretch == 'linear':
        stretch_fn = visualization.LinearStretch()
    elif stretch == 'sqrt':
        stretch_fn = visualization.SqrtStretch()
    elif stretch == 'power':
        stretch_fn = visualization.PowerStretch(1.0)
    elif stretch == 'log':
        stretch_fn = visualization.LogStretch()
    elif stretch == 'asinh':
        stretch_fn = visualization.AsinhStretch(0.1)
    else:
        raise ValueError('Unknown stretch: {0}.'.format(stretch))

    # Flux image to visualize, taken from the TPF.
    flx = self.TPF.flux_binned()

    # Explicit cut levels are only used when vmax is supplied.
    if vmax is None:
        vmin, vmax = self.cut_levels(min_percent, max_percent, data_col)

    # NOTE(review): matplotlib's figsize is (width, height) while a 2-D
    # flx.shape is (rows, cols), so non-square cutouts get a transposed
    # figure shape -- confirm whether shape[::-1] was intended.
    fig = plt.figure(figsize=list(flx.shape))
    # WCS-aware axes so the survey contours below can be reprojected.
    ax = plt.subplot(projection=self.TPF.wcs)
    ax.set_xlabel('RA')
    ax.set_ylabel('Dec')

    if self.verbose:
        print('{} vmin/vmax = {}/{} (median={})'.format(
            data_col, vmin, vmax, np.nanmedian(flx)))

    # ``stretch + interval`` applies the interval first: values are mapped
    # from [vmin, vmax] onto [0, 1] and then stretched.
    transform = (stretch_fn
                 + visualization.ManualInterval(vmin=vmin, vmax=vmax))
    # Scale to integers 0..255; with NoNorm, integer data index the
    # colormap's lookup table directly.
    ax.imshow((255 * transform(flx)).astype(int), aspect='auto',
              origin='lower', interpolation='nearest', cmap=cmap,
              norm=NoNorm())
    ax.set_xticks([])
    ax.set_yticks([])

    # Overlay contours of the survey image without changing the view limits.
    current_ylims = ax.get_ylim()
    current_xlims = ax.get_xlim()
    pixels, header = surveyquery.getSVImg(self.TPF.position, survey)
    # Ten levels from the survey image minimum up to its 95th percentile.
    levels = np.linspace(np.min(pixels), np.percentile(pixels, 95), 10)
    ax.contour(pixels, transform=ax.get_transform(WCS(header)),
               levels=levels, colors=contour_color)
    ax.set_xlim(current_xlims)
    ax.set_ylim(current_ylims)

    fig.canvas.draw()
    fig.savefig(output_filename, bbox_inches='tight', dpi=300)
    return fig
def make_fov_image(fov, pngfn=None, **kwargs):
    """Render the four CCDs of a focal-plane view as one 2x2 figure.

    Parameters
    ----------
    fov : dict
        Must provide ``fov['CCD1']`` .. ``fov['CCD4']``, each a mapping with
        ``'im'`` (a 2-D image, or a 3-D cube that is averaged over its last
        axis), ``'x'`` and ``'y'`` (2-D coordinate grids used for the image
        extent), plus ``fov['coordsys']``. ``fov['file']`` and
        ``fov['objname']`` are optional and only feed the default title.
    pngfn : str, optional
        If given, the figure is saved to this path and then closed;
        otherwise it is left open.
    **kwargs
        stretch : {'linear', 'histeq', 'asinh'}, default 'linear'.
        interval : {'zscale', 'rms', 'fixed'}, default 'zscale'.
        imrange : pair -- (vmin, vmax) for 'fixed', or the
            (low, high) clipped-sigma multipliers for 'rms'
            (default (2.5, 5.0) when absent or malformed).
        contrast : ZScale contrast, default 0.25.
        nbin : binning factor used to thin the pixel sample for 'rms',
            default 1.
        cmap, interpolation, title : display options.

    Raises
    ------
    KeyError
        For an unknown ``stretch`` name.
    ValueError
        For an unknown ``interval`` name.
    """
    from copy import copy

    stretch_name = kwargs.get('stretch', 'linear')
    interval = kwargs.get('interval', 'zscale')
    imrange = kwargs.get('imrange')
    contrast = kwargs.get('contrast', 0.25)
    ccdplotorder = ['CCD2', 'CCD4', 'CCD1', 'CCD3']

    ims = [fov[ccd]['im'] for ccd in ccdplotorder]
    # Pixels of all four CCDs pooled, so every panel shares one norm.
    allpix = np.ma.array(ims).flatten()

    # Build only the requested stretch: HistEqStretch works over every
    # pixel, so constructing it eagerly is wasted work for the other
    # choices. An unknown name raises KeyError.
    stretch = {
        'linear': lambda: vis.LinearStretch(),
        'histeq': lambda: vis.HistEqStretch(allpix),
        'asinh': lambda: vis.AsinhStretch(),
    }[stretch_name]()

    if interval == 'zscale':
        iv = vis.ZScaleInterval(contrast=contrast)
        vmin, vmax = iv.get_limits(allpix)
    elif interval == 'rms':
        try:
            losig, hisig = imrange
        except (TypeError, ValueError):
            # imrange absent (None) or not a pair: default multipliers.
            losig, hisig = (2.5, 5.0)
        # Thin the pixel sample by the binning factor; max(..., 1) keeps
        # the slice step positive for large nbin.
        # NOTE(review): nbin is assumed to arrive via kwargs (default 1).
        nbin = kwargs.get('nbin', 1)
        nsample = max(1000 // nbin, 1)
        background = sigma_clip(allpix[::nsample], maxiters=3, sigma=2.2)
        m, s = background.mean(), background.std()
        vmin, vmax = m - losig * s, m + hisig * s
    elif interval == 'fixed':
        vmin, vmax = imrange
    else:
        raise ValueError('Unknown interval: {0}'.format(interval))
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)

    # Copy before set_bad() so neither matplotlib's registered colormap nor
    # a Colormap instance passed by the caller gets modified.
    cmap = copy(plt.get_cmap(kwargs.get('cmap', 'viridis')))
    cmap.set_bad('w', 1.0)

    w = 0.4575  # panel width, figure fraction
    h = 0.455   # panel height, figure fraction
    rc('text', usetex=False)
    fig = plt.figure(figsize=(6, 6.5))
    # Thin horizontal strip along the bottom for the shared colorbar.
    cax = fig.add_axes([0.1, 0.04, 0.8, 0.01])
    for n, (im, ccd) in enumerate(zip(ims, ccdplotorder)):
        if im.ndim == 3:
            im = im.mean(axis=-1)
        x = fov[ccd]['x']
        y = fov[ccd]['y']
        # 2x2 grid: n = 0, 1 fill the bottom row left to right,
        # n = 2, 3 the top row.
        i = n % 2
        j = n // 2
        pos = [0.0225 + i * w + i * 0.04, 0.05 + j * h + j * 0.005, w, h]
        ax = fig.add_axes(pos)
        _im = ax.imshow(im, origin='lower',
                        extent=[x[0, 0], x[0, -1], y[0, 0], y[-1, 0]],
                        norm=norm, cmap=cmap,
                        interpolation=kwargs.get('interpolation', 'nearest'))
        # Sky coordinates get a reversed x axis (presumably so RA increases
        # to the left, the usual astronomical orientation).
        if fov['coordsys'] == 'sky':
            ax.set_xlim(x.max(), x.min())
        else:
            ax.set_xlim(x.min(), x.max())
        ax.set_ylim(y.min(), y.max())
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        if n == 0:
            # All panels share one norm, so a single colorbar suffices.
            cb = fig.colorbar(_im, cax, orientation='horizontal')
            cb.ax.tick_params(labelsize=9)

    tstr = fov.get('file', '') + ' ' + fov.get('objname', '')
    title = kwargs.get('title', tstr)
    # Keep only the last 60 characters so long paths fit on one line.
    title = title[-60:]
    fig.text(0.5, 0.99, title, ha='center', va='top', size=12)
    if pngfn is not None:
        fig.savefig(pngfn)
        plt.close(fig)
def main():
    """Show the four extensions of an AstroData file as a 2x2 mosaic.

    All panels share one ZScale/linear normalisation computed from the
    stacked data of extensions 0-3; pixels whose mask is > 0 are drawn in
    the palette's "bad" colour (white). The figure is saved next to the
    input with ``.fits`` replaced by ``.png`` and then shown.
    """
    filename = get_filename()
    ad = astrodata.open(filename)
    # NOTE(review): AstroData.info() may print by itself and return None;
    # only echo a returned value so a stray "None" is not printed.
    info = ad.info()
    if info is not None:
        print(info)

    fig = plt.figure(num=filename, figsize=(8, 8))
    fig.suptitle('{}'.format(filename))

    palette = copy(plt.cm.viridis)
    palette.set_bad('w', 1.0)

    # One ZScale/linear normalisation shared by all four panels so their
    # colours are directly comparable; only its limits are reused, which
    # a plain Normalize reproduces for a linear stretch.
    norm = visualization.ImageNormalize(
        np.dstack([ad[i].data for i in range(4)]),
        stretch=visualization.LinearStretch(),
        interval=visualization.ZScaleInterval())
    shared_norm = colors.Normalize(vmin=norm.vmin, vmax=norm.vmax)

    # (extension index, subplot position, label). Resulting layout:
    #   d3 | d4
    #   d2 | d1
    panels = ((0, 224, 'd1'), (1, 223, 'd2'), (2, 221, 'd3'), (3, 222, 'd4'))
    for ext, position, label in panels:
        ax = fig.add_subplot(position)
        ax.imshow(np.ma.masked_where(ad[ext].mask > 0, ad[ext].data),
                  norm=shared_norm, origin='lower', cmap=palette)
        ax.annotate(label, (20, 20), color='white')
        ax.set_xlabel('x [pixels]')
        ax.set_ylabel('y [pixels]')

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(filename.replace('.fits', '.png'))
    plt.show()
pl.imshow(rgb, origin='lower', interpolation='none', norm=norm) (x1, x2), (y1, y2) = celwcs.wcs_world2pix(lims[0], lims[1], 0) ax.axis((x1, x2, y1, y2)) visualization_tools.make_scalebar(ax, left_side=scalebarx, length=1.213 * u.arcsec, label='0.05 pc') pl.savefig(paths.fpath(figfilename), bbox_inches='tight') if __name__ == "__main__": rgbfig(stretch=visualization.LinearStretch(), ) rgbfig( figfilename='SgrB2M_RGB.pdf', lims=[(266.8359, 266.8325), (-28.38600555, -28.3832)], redfn=paths.Fpath('SGRB2M-2012-Q-MEAN.DePree.recentered.fits'), greenfn=paths.Fpath( 'sgr_b2m.M.B3.allspw.continuum.r0.5.clean1000.image.tt0.pbcor.fits' ), bluefn=paths.Fpath( 'sgr_b2m.M.B6.allspw.continuum.r0.5.clean1000.image.tt0.pbcor.fits' ), scalebarx=coordinates.SkyCoord(266.8336007 * u.deg, -28.38553839 * u.deg), redpercentile=99.99, greenpercentile=99.98,
def set_normalization(self, stretch=None, interval=None, stretchkwargs=None,
                      intervalkwargs=None, perm_linear=None):
    """Configure, and once data is loaded build, the display stretch/interval.

    Parameters
    ----------
    stretch : str or stretch object, optional
        A stretch name ('linear', 'asinh', 'contrastbias', 'histogram',
        'log', 'powerdist', 'power', 'sinh', 'sqrt', 'square') or an
        already built stretch object, which is stored as-is. Defaults to
        ``self.stretch``, or 'linear' if that is None.
    interval : str or interval object, optional
        An interval name ('minmax', 'manual', 'percentile', 'asymetric',
        'zscale') or an already built interval object. Defaults to
        ``self.interval``, or 'zscale' if that is None.
    stretchkwargs, intervalkwargs : dict, optional
        Keyword overrides, merged by ``self.prepare_kwargs`` with the class
        defaults and the previously stored kwargs.
    perm_linear : dict, optional
        When given, the named stretch (any listed name except 'linear') is
        combined with a ``LinearStretch`` built from these kwargs. When
        omitted, only 'linear' can be built.

    Raises
    ------
    ValueError
        For a stretch/interval name that cannot be built in the current
        mode. An unknown name may instead surface as a KeyError from the
        ``*_kws_defaults`` lookup.

    Notes
    -----
    While ``self.data`` is None the names are stored as plain strings and
    the given kwargs are remembered, so the objects can be built on a later
    call. NOTE(review): ``perm_linear`` is not remembered in that case.
    """
    # Fresh dicts per call: a shared mutable default would end up stored on
    # (and shared between) instances via self.*_kwargs below.
    if stretchkwargs is None:
        stretchkwargs = {}
    if intervalkwargs is None:
        intervalkwargs = {}

    if stretch is None:
        if self.stretch is None:
            stretch = 'linear'
        else:
            stretch = self.stretch
    if isinstance(stretch, str):
        print(stretch, ' '.join([f'{k}={v}' for k, v in stretchkwargs.items()]))
        if self.data is None:
            # Cannot build stretch objects without data yet.
            self.stretch_kwargs = stretchkwargs
        else:
            kwargs = self.prepare_kwargs(self.stretch_kws_defaults[stretch],
                                         self.stretch_kwargs, stretchkwargs)
            if perm_linear is not None:
                perm_linear_kwargs = self.prepare_kwargs(
                    self.stretch_kws_defaults['linear'], perm_linear)
                print('linear', ' '.join(
                    [f'{k}={v}' for k, v in perm_linear_kwargs.items()]))
                # NOTE(review): CompositeStretch applies its first argument
                # first, so the linear part runs before the named stretch
                # for some names and after it for others -- confirm intended.
                if stretch == 'asinh':
                    # arg: a=0.1
                    stretch = vis.CompositeStretch(
                        vis.LinearStretch(**perm_linear_kwargs),
                        vis.AsinhStretch(**kwargs))
                elif stretch == 'contrastbias':
                    # args: contrast, bias
                    stretch = vis.CompositeStretch(
                        vis.LinearStretch(**perm_linear_kwargs),
                        vis.ContrastBiasStretch(**kwargs))
                elif stretch == 'histogram':
                    stretch = vis.CompositeStretch(
                        vis.HistEqStretch(self.data, **kwargs),
                        vis.LinearStretch(**perm_linear_kwargs))
                elif stretch == 'log':
                    # args: a=1000.0
                    stretch = vis.CompositeStretch(
                        vis.LogStretch(**kwargs),
                        vis.LinearStretch(**perm_linear_kwargs))
                elif stretch == 'powerdist':
                    # args: a=1000.0
                    stretch = vis.CompositeStretch(
                        vis.LinearStretch(**perm_linear_kwargs),
                        vis.PowerDistStretch(**kwargs))
                elif stretch == 'power':
                    # args: a
                    stretch = vis.CompositeStretch(
                        vis.PowerStretch(**kwargs),
                        vis.LinearStretch(**perm_linear_kwargs))
                elif stretch == 'sinh':
                    # args: a=0.33
                    stretch = vis.CompositeStretch(
                        vis.LinearStretch(**perm_linear_kwargs),
                        vis.SinhStretch(**kwargs))
                elif stretch == 'sqrt':
                    stretch = vis.CompositeStretch(
                        vis.SqrtStretch(),
                        vis.LinearStretch(**perm_linear_kwargs))
                elif stretch == 'square':
                    stretch = vis.CompositeStretch(
                        vis.LinearStretch(**perm_linear_kwargs),
                        vis.SquaredStretch())
                else:
                    raise ValueError('Unknown stretch:' + stretch)
            else:
                if stretch == 'linear':
                    # args: slope=1, intercept=0
                    stretch = vis.LinearStretch(**kwargs)
                else:
                    raise ValueError('Unknown stretch:' + stretch)
    self.stretch = stretch

    if interval is None:
        if self.interval is None:
            interval = 'zscale'
        else:
            interval = self.interval
    if isinstance(interval, str):
        print(interval, ' '.join([f'{k}={v}' for k, v in intervalkwargs.items()]))
        kwargs = self.prepare_kwargs(self.interval_kws_defaults[interval],
                                     self.interval_kwargs, intervalkwargs)
        if self.data is None:
            # Cannot build interval objects without data yet.
            self.interval_kwargs = intervalkwargs
        else:
            if interval == 'minmax':
                interval = vis.MinMaxInterval()
            elif interval == 'manual':
                # args: vmin, vmax
                interval = vis.ManualInterval(**kwargs)
            elif interval == 'percentile':
                # args: percentile, n_samples
                interval = vis.PercentileInterval(**kwargs)
            elif interval == 'asymetric':
                # args: lower_percentile, upper_percentile, n_samples
                interval = vis.AsymmetricPercentileInterval(**kwargs)
            elif interval == 'zscale':
                # args: nsamples=1000, contrast=0.25, max_reject=0.5,
                # min_npixels=5, krej=2.5, max_iterations=5
                interval = vis.ZScaleInterval(**kwargs)
            else:
                raise ValueError('Unknown interval:' + interval)
    self.interval = interval

    # Re-apply the normalisation to an already displayed image.
    if self.img is not None:
        self.img.set_norm(
            vis.ImageNormalize(self.data, interval=self.interval,
                               stretch=self.stretch, clip=True))
def show_image(image, percl=99, percu=None, figsize=(6, 10), cmap='viridis',
               log=False):
    """
    Show an image in matplotlib with a basic astronomically-appropriate stretch.

    The image is block-reduced to roughly the figure's size in pixels before
    display, so matplotlib does not have to resample it.

    Parameters
    ----------
    image : 2-D array
        The image to show.
    percl : number
        The percentile for the lower edge of the stretch, or, when ``percu``
        is None, the upper edge with ``100 - percl`` as the lower edge.
    percu : number or None
        The percentile for the upper edge of the stretch (or None to derive
        both edges from ``percl``).
    figsize : 2-tuple or None
        Only its larger value is used: it becomes the length in inches of
        the figure's longer side, the other side being scaled to the image's
        aspect ratio. None uses matplotlib's default figure size.
    cmap : str
        Matplotlib colormap name.
    log : bool
        Use a logarithmic stretch instead of a linear one.
    """
    if percu is None:
        percu = percl
        percl = 100 - percl

    n_rows, n_cols = image.shape[0], image.shape[1]

    if figsize is not None:
        # matplotlib's figsize is (width, height) whereas image.shape is
        # (rows, cols): give the longest requested side to the image's
        # longer axis so the figure matches the image's aspect ratio.
        longest = max(figsize)
        if n_rows >= n_cols:
            figsize = (longest * n_cols / n_rows, longest)
        else:
            figsize = (longest, longest * n_rows / n_cols)

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # To preserve details we should *really* downsample correctly and not
    # rely on matplotlib to do it for us (it won't). So compute the figure
    # size in pixels, block_reduce to roughly that, and display the reduced
    # image.
    # Thanks, https://stackoverflow.com/questions/29702424/how-to-get-matplotlib-figure-size
    width_pix, height_pix = fig.get_size_inches() * fig.dpi
    # Rows are compared with the figure height, columns with its width.
    ratio = int(max(n_rows // height_pix, n_cols // width_pix, 1))

    # block_reduce sums each ratio x ratio block by default; dividing by
    # ratio**2 turns that into the block mean, keeping the original scale.
    reduced_data = block_reduce(image, ratio) / ratio**2

    # Downsampling shrinks the array, so set the extent to keep the axes in
    # the original image's pixel coordinates.
    extent = [0, n_cols, 0, n_rows]

    stretch = aviz.LogStretch() if log else aviz.LinearStretch()
    norm = aviz.ImageNormalize(
        reduced_data,
        interval=aviz.AsymmetricPercentileInterval(percl, percu),
        stretch=stretch)

    im = ax.imshow(reduced_data, norm=norm, origin='lower', cmap=cmap,
                   extent=extent)
    fig.colorbar(im, ax=ax)