示例#1
0
    def plot_norm(self,
                  stretch='linear',
                  power=1.0,
                  asinh_a=0.1,
                  min_cut=None,
                  max_cut=None,
                  min_percent=None,
                  max_percent=None,
                  percent=None,
                  clip=True):
        """Create a matplotlib norm object for plotting.

        This is a copy of the function that will be available in Astropy 1.3:
        `astropy.visualization.mpl_normalize.simple_norm`

        The display limits are computed from ``self.data`` with an interval
        chosen from the cut / percent arguments. The first option that is set
        wins, in this order: ``percent``, then ``min_percent`` /
        ``max_percent``, then ``min_cut`` / ``max_cut``, and finally the
        plain min/max of the data.

        Parameters
        ----------
        stretch : {'linear', 'sqrt', 'power', 'log', 'asinh'}
            Name of the stretch to apply.
        power : float
            Argument passed to ``PowerStretch``; only used when
            ``stretch='power'``.
        asinh_a : float
            Argument passed to ``AsinhStretch``; only used when
            ``stretch='asinh'``.
        min_cut, max_cut : float or None
            Limits passed to ``ManualInterval``.
        min_percent, max_percent : float or None
            Percentiles passed to ``AsymmetricPercentileInterval``. If only
            one of the two is given, the other defaults to 0 (lower) or
            100 (upper).
        percent : float or None
            Argument passed to ``PercentileInterval``.
        clip : bool
            Passed through to ``ImageNormalize``.

        Returns
        -------
        norm : `astropy.visualization.mpl_normalize.ImageNormalize`
            Normalization built from the computed limits and the stretch.

        Raises
        ------
        ValueError
            If ``stretch`` is not one of the supported names.

        Examples
        --------
        >>> image = SkyImage()
        >>> norm = image.plot_norm(stretch='sqrt', max_percent=99)
        >>> image.plot(norm=norm)
        """
        import astropy.visualization as v
        from astropy.visualization.mpl_normalize import ImageNormalize

        if percent is not None:
            interval = v.PercentileInterval(percent)
        elif min_percent is not None or max_percent is not None:
            # Compare against None explicitly: a plain ``x or default`` would
            # silently replace a legitimate percentile of 0 with the default.
            lower_percent = 0. if min_percent is None else min_percent
            upper_percent = 100. if max_percent is None else max_percent
            interval = v.AsymmetricPercentileInterval(lower_percent,
                                                      upper_percent)
        elif min_cut is not None or max_cut is not None:
            interval = v.ManualInterval(min_cut, max_cut)
        else:
            interval = v.MinMaxInterval()

        if stretch == 'linear':
            stretch = v.LinearStretch()
        elif stretch == 'sqrt':
            stretch = v.SqrtStretch()
        elif stretch == 'power':
            stretch = v.PowerStretch(power)
        elif stretch == 'log':
            stretch = v.LogStretch()
        elif stretch == 'asinh':
            stretch = v.AsinhStretch(asinh_a)
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        # Resolve the interval into concrete limits once, so the returned norm
        # carries fixed vmin / vmax values.
        vmin, vmax = interval.get_limits(self.data)

        return ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch, clip=clip)
示例#2
0
def scale_rgb(red, green, blue, stretch_coeff, pmin=0, pmax=98.):
    """Apply a percentile interval and a power-law stretch to three images.

    Parameters
    ----------
    red : CCDData object
        Red filter image data, parsed with load_image
    green : CCDData object
        Green filter image data, parsed with load_image
    blue : CCDData object
        Blue filter image data, parsed with load_image
    stretch_coeff : tuple, (1,3)
        Power-law stretch coefficients for the (red, green, blue)
        filters respectively. Only the first three entries are read.
    pmin : float, default = 0
        The lower percentile below which to ignore pixels. Passed to
        the AsymmetricPercentileInterval function.
    pmax : float, default = 98
        The upper percentile above which to ignore pixels. Passed to
        the AsymmetricPercentileInterval function.

    Returns
    -------
    rgb : np.array, N x N x 3
        The three processed channels stacked in (red, green, blue) order
        and then transposed with ``.T``. ``.T`` reverses *all* axes, so the
        colour axis ends up last and the two image axes are swapped as well.

    """
    # One stretch per channel, built in red, green, blue order.
    channel_stretches = [
        vis.stretch.PowerStretch(stretch_coeff[index]) for index in range(3)
    ]

    # The same percentile interval is shared by all three channels.
    percentile_interval = vis.AsymmetricPercentileInterval(pmin, pmax)

    processed = [
        channel_stretch(percentile_interval(channel))
        for channel_stretch, channel in zip(channel_stretches,
                                            (red, green, blue))
    ]
    return np.array(processed).T
示例#3
0
def get_im_interval(pmin=10, pmax=99.9, vmin=None, vmax=None):
    ''' Build the interval object to hand to Astropy's ImageNormalize.

    Absolute limits are used only when *both* of them are supplied; in every
    other case (including when just one of vmin / vmax is given) the
    percentile limits are used.

    :param pmin: lower-limit percentile
    :type pmin: float
    :param pmax: upper-limit percentile
    :type pmax: float
    :param vmin: absolute lower limit
    :type vmin: float
    :param vmax: absolute upper limit
    :type vmax: float

    :return: a ``ManualInterval`` if both vmin and vmax are given, otherwise
             an ``AsymmetricPercentileInterval``
    :rtype: :class:`astropy.visualization.interval`

    .. note:: Specifying *both* vmin and vmax will override pmin and pmax.

    '''

    absolute_limits_missing = vmin is None or vmax is None
    if absolute_limits_missing:
        return astrovis.AsymmetricPercentileInterval(pmin, pmax)

    return astrovis.ManualInterval(vmin, vmax)
示例#4
0
def show_image(image,
               percl=99,
               percu=None,
               is_mask=False,
               figsize=(6, 10),
               cmap='viridis',
               log=False,
               show_colorbar=True,
               show_ticks=True,
               fig=None,
               ax=None,
               input_ratio=None):
    """
    Show an image in matplotlib with some basic astronomically-appropriate stretching.

    Parameters
    ----------
    image
        The image to show
    percl : number
        The percentile for the lower edge of the stretch (or both edges if ``percu`` is None)
    percu : number or None
        The percentile for the upper edge of the stretch (or None to use ``percl`` for both)
    is_mask : bool
        If True, the image is displayed with fixed limits ``vmin=0``, ``vmax=1``
        instead of the percentile-based normalization.
    figsize : 2-tuple or None
        The size of the matplotlib figure in inches. Only used when a new
        figure is created (``fig`` and ``ax`` not given). Unless it is None,
        it is rescaled to ``(max(figsize) * rows / columns, max(figsize))``
        before the figure is created.
    cmap : str or colormap
        Colormap passed to ``imshow``.
    log : bool
        Use a logarithmic stretch if True, a linear stretch otherwise.
    show_colorbar : bool
        If True, add a colorbar to the figure.
    show_ticks : bool
        If False, hide the tick labels on all four sides.
    fig, ax : matplotlib figure and axes, or None
        Existing figure and axes to draw into. Either both or neither must
        be given; if neither is given a new figure is created.
    input_ratio : number or None
        If set (truthy), used as the block-reduction factor instead of the
        value derived from the figure size.

    Raises
    ------
    ValueError
        If only one of ``fig`` and ``ax`` is provided.
    """
    if percu is None:
        percu = percl
        percl = 100 - percl

    if (fig is None) != (ax is None):
        raise ValueError('Must provide both "fig" and "ax" '
                         'if you provide one of them')
    if fig is None:
        if figsize is not None:
            # Rescale the fig size to match the image dimensions, roughly.
            # This has to happen *before* the figure is created, otherwise
            # the rescaled size is never applied.
            image_aspect_ratio = image.shape[0] / image.shape[1]
            figsize = (max(figsize) * image_aspect_ratio, max(figsize))
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    # To preserve details we should *really* downsample correctly and
    # not rely on matplotlib to do it correctly for us (it won't).

    # So, calculate the size of the figure in pixels, block_reduce to
    # roughly that, and display the block reduced image.

    # Thanks, https://stackoverflow.com/questions/29702424/how-to-get-matplotlib-figure-size
    fig_size_pix = fig.get_size_inches() * fig.dpi

    ratio = (image.shape // fig_size_pix).max()

    # Never "reduce" by a factor below one.
    if ratio < 1:
        ratio = 1

    # An explicit reduction factor from the caller wins over the computed one.
    ratio = input_ratio or ratio

    # Divide by the square of the ratio to keep the flux the same in the
    # reduced image
    reduced_data = block_reduce(image, ratio) / ratio**2

    # Of course, now that we have downsampled, the axis limits are changed to
    # match the smaller image size. Setting the extent will do the trick to
    # change the axis display back to showing the actual extent of the image.
    extent = [0, image.shape[1], 0, image.shape[0]]

    if log:
        stretch = aviz.LogStretch()
    else:
        stretch = aviz.LinearStretch()

    norm = aviz.ImageNormalize(reduced_data,
                               interval=aviz.AsymmetricPercentileInterval(
                                   percl, percu),
                               stretch=stretch)

    if is_mask:
        # The image is a mask in which pixels are zero or one. Set the image scale
        # limits appropriately.
        scale_args = dict(vmin=0, vmax=1)
    else:
        scale_args = dict(norm=norm)

    im = ax.imshow(reduced_data,
                   origin='lower',
                   cmap=cmap,
                   extent=extent,
                   aspect='equal',
                   **scale_args)

    if show_colorbar:
        # I haven't a clue why the fraction and pad arguments below work to make
        # the colorbar the same height as the image, but they do....unless the image
        # is wider than it is tall. Sticking with this for now anyway...
        # Thanks: https://stackoverflow.com/a/26720422/3486425
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, format='%2.0f')
        # In case someone in the future wants to improve this:
        # https://joseph-long.com/writing/colorbars/
        # https://stackoverflow.com/a/33505522/3486425
        # https://matplotlib.org/mpl_toolkits/axes_grid/users/overview.html#colorbar-whose-height-or-width-in-sync-with-the-master-axes

    if not show_ticks:
        ax.tick_params(labelbottom=False,
                       labelleft=False,
                       labelright=False,
                       labeltop=False)
示例#5
0
def plot_image(image,
               ax=None,
               scale='log',
               cmap=None,
               origin='lower',
               xlabel=None,
               ylabel=None,
               cbar=None,
               clabel='Flux ($e^{-}s^{-1}$)',
               cbar_ticks=None,
               cbar_ticklabels=None,
               cbar_pad=None,
               cbar_size='5%',
               title=None,
               percentile=95.0,
               vmin=None,
               vmax=None,
               offset_axes=None,
               color_bad='k',
               **kwargs):
    """
    Utility function to plot a 2D image.

    Parameters:
        image (2d array): Image data.
        ax (matplotlib.pyplot.axes, optional): Axes in which to plot.
            Default (None) is to use current active axes.
        scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional):
            Normalization used to stretch the colormap.
            Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'``
            and ``'squared'``.
            Can also be a :py:class:`astropy.visualization.ImageNormalize` or
            :py:class:`matplotlib.colors.Normalize` object, which is then used as-is.
            Default is ``'log'``.
        cmap (matplotlib colormap or str, optional): Colormap to use.
            Default is the ``Blues`` colormap. The colormap is copied before the
            bad-pixel color is applied, so the object passed in is not modified.
        origin (str, optional): The origin of the coordinate system.
        xlabel (str, optional): Label for the x-axis.
        ylabel (str, optional): Label for the y-axis.
        cbar (string, optional): Location of color bar.
            Choices are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``.
            Default is not to create colorbar.
        clabel (str, optional): Label for the color bar.
        cbar_ticks (list, optional): Tick positions passed to the colorbar.
        cbar_ticklabels (list, optional): Tick labels passed to the colorbar.
        cbar_pad (float, optional): Padding between axes and colorbar.
        cbar_size (str or float, optional): Size of the colorbar, passed to the
            colorbar helper. Default is ``'5%'``.
        title (str or None, optional): Title for the plot.
        percentile (float, optional): The fraction of pixels to keep in color-trim.
            If single float given, the same fraction of pixels is eliminated from both ends.
            If tuple of two floats is given, the two are used as the percentiles.
            Only used when ``vmin`` or ``vmax`` is not given.
            Default=95.
        vmin (float, optional): Lower limit to use for colormap.
        vmax (float, optional): Upper limit to use for colormap.
        offset_axes (tuple, optional): Two offsets ``(x, y)`` added to the
            extent of the image, shifting the axis coordinates.
        color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black.
        kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`.
            The deprecated keyword ``make_cbar`` is still accepted; it emits a
            ``FutureWarning`` and is used as ``cbar`` when ``cbar`` is not set.

    Returns:
        :py:class:`matplotlib.image.AxesImage`: Image returned
            by :py:func:`matplotlib.pyplot.imshow`.

    Raises:
        ValueError: If ``scale`` is neither a supported name nor a normalization object.

    .. codeauthor:: Rasmus Handberg <*****@*****.**>
    """
    import warnings

    logger = logging.getLogger(__name__)

    # Backward compatible settings:
    make_cbar = kwargs.pop('make_cbar', None)
    if make_cbar:
        # Emit a warning rather than raising it: raising would abort the call
        # and make the fallback below unreachable.
        warnings.warn("'make_cbar' is deprecated. Use 'cbar' instead.",
                      FutureWarning,
                      stacklevel=2)
        if not cbar:
            cbar = make_cbar

    # Special treatment for boolean arrays:
    if isinstance(image, np.ndarray) and image.dtype == 'bool':
        if vmin is None:
            vmin = 0
        if vmax is None:
            vmax = 1
        if cbar_ticks is None:
            cbar_ticks = [0, 1]
        if cbar_ticklabels is None:
            cbar_ticklabels = ['False', 'True']

    # Calculate limits of color scaling:
    interval = None
    if vmin is None or vmax is None:
        if allnan(image):
            logger.warning("Image is all NaN")
            vmin = 0
            vmax = 1
            if cbar_ticks is None:
                cbar_ticks = []
            if cbar_ticklabels is None:
                cbar_ticklabels = []
        elif isinstance(percentile, (list, tuple, np.ndarray)):
            interval = viz.AsymmetricPercentileInterval(
                percentile[0], percentile[1])
        else:
            interval = viz.PercentileInterval(percentile)

    # Create ImageNormalize object with extracted limits:
    if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh',
                 'squared'):
        if scale == 'log':
            stretch = viz.LogStretch()
        elif scale == 'linear':
            stretch = viz.LinearStretch()
        elif scale == 'sqrt':
            stretch = viz.SqrtStretch()
        elif scale == 'asinh':
            stretch = viz.AsinhStretch()
        elif scale == 'histeq':
            stretch = viz.HistEqStretch(image[np.isfinite(image)])
        elif scale == 'sinh':
            stretch = viz.SinhStretch()
        elif scale == 'squared':
            stretch = viz.SquaredStretch()

        # Create ImageNormalize object. Very important to use clip=False if the image contains
        # NaNs, otherwise NaN points will not be plotted correctly.
        norm = viz.ImageNormalize(data=image[np.isfinite(image)],
                                  interval=interval,
                                  vmin=vmin,
                                  vmax=vmax,
                                  stretch=stretch,
                                  clip=not anynan(image))

    elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)):
        norm = scale
    else:
        raise ValueError("scale {} is not available.".format(scale))

    # Pixel centres sit on integer coordinates, hence the half-pixel margins.
    if offset_axes:
        extent = (offset_axes[0] - 0.5, offset_axes[0] + image.shape[1] - 0.5,
                  offset_axes[1] - 0.5, offset_axes[1] + image.shape[0] - 0.5)
    else:
        extent = (-0.5, image.shape[1] - 0.5, -0.5, image.shape[0] - 0.5)

    if ax is None:
        ax = plt.gca()

    # Set up the colormap to use. Always work on a copy, so that setting the
    # bad color below never modifies a registered colormap or the colormap
    # object handed in by the caller.
    if cmap is None:
        cmap = copy.copy(plt.get_cmap('Blues'))
    elif isinstance(cmap, str):
        cmap = copy.copy(plt.get_cmap(cmap))
    else:
        cmap = copy.copy(cmap)

    if color_bad:
        cmap.set_bad(color_bad, 1.0)

    # Plotting the image using all the settings set above:
    im = ax.imshow(image,
                   cmap=cmap,
                   norm=norm,
                   origin=origin,
                   extent=extent,
                   interpolation='nearest',
                   **kwargs)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    ax.set_xlim([extent[0], extent[1]])
    ax.set_ylim([extent[2], extent[3]])

    if cbar:
        colorbar(im,
                 ax=ax,
                 loc=cbar,
                 size=cbar_size,
                 pad=cbar_pad,
                 label=clabel,
                 ticks=cbar_ticks,
                 ticklabels=cbar_ticklabels)

    # Settings for ticks:
    # A locator keeps a reference to the axis it is attached to, so every
    # axis slot gets its own instance instead of sharing a single one.
    ax.xaxis.set_major_locator(MaxNLocator(nbins=10, integer=True))
    ax.xaxis.set_minor_locator(MaxNLocator(nbins=10, integer=True))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=10, integer=True))
    ax.yaxis.set_minor_locator(MaxNLocator(nbins=10, integer=True))
    ax.tick_params(which='both', direction='out', pad=5)
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()

    return im
示例#6
0
    def plotDirectCutouts(self,
                          savePath=None,
                          colourMap='viridis',
                          gridSpecs=None):
        """Plot every direct-image cutout in ``self.directCutouts``.

        One subplot is drawn per ``(grism, cutout)`` entry. Entries whose
        cutout is None, or whose pixels are all negative, get a text
        placeholder instead of an image. Each image gets a second y-axis
        whose limits are the pixel limits, centred on their midpoint and
        multiplied by the ``'IDCSCALE'`` value read from
        ``self.directHdus[grism][1]``.

        Parameters
        ----------
        savePath : str or None
            If not None, the figure is saved there (dpi=300, tight bounding
            box) and then closed.
        colourMap : str or colormap
            Colormap passed to ``imshow``.
        gridSpecs : sequence or None
            If None, a new 10x5 inch figure is created and the cutouts are
            placed on a 2-row, 1-column grid. Otherwise ``gridSpecs[i]`` is
            passed to ``subplot`` for the i-th cutout.

        Notes
        -----
        If ``self.directCutouts`` is None a message is printed and nothing is
        plotted. If ``tight_layout`` raises a ``ValueError`` the error is
        printed and the method returns without saving.
        """
        if self.directCutouts is None:
            print(
                'The loadDirectCutouts(...) method must be called before direct cutouts can be plotted.'
            )
            return

        if gridSpecs is None:
            mplplot.figure(figsize=(10, 5))

        for stampIndex, (grism, cutoutData) in enumerate(
                self.directCutouts.items()):
            if gridSpecs is None:
                subplotAxes = mplplot.subplot(2, 1, stampIndex + 1)
            else:
                subplotAxes = mplplot.subplot(gridSpecs[stampIndex])

            # Decide whether this panel only gets a placeholder message.
            # NOTE(review): the second test is for all-negative pixels while
            # its message talks about "nonzero" data - confirm the intended
            # condition.
            if cutoutData is None:
                placeholder = 'Field {}, Object {}:\nNO DATA AVAILABLE.'
            elif np.all(cutoutData < 0):
                placeholder = 'Field {}, Object {}:\nNO NONZERO DATA AVAILABLE.'
            else:
                placeholder = None

            if placeholder is not None:
                subplotAxes.text(
                    0.5,
                    0.5,
                    placeholder.format(self.targetPar, self.targetObject),
                    horizontalalignment='center',
                    fontsize='large',
                    transform=subplotAxes.transAxes)
                continue

            norm = astromplnorm.ImageNormalize(
                cutoutData,
                interval=astrovis.AsymmetricPercentileInterval(0, 99.5),
                stretch=astrovis.LinearStretch(),
                clip=True)

            mplplot.imshow(cutoutData,
                           origin='lower',
                           interpolation='nearest',
                           cmap=colourMap,
                           norm=norm)

            mplplot.xlabel('X (pixels)')
            mplplot.ylabel('Y (pixels)')
            mplplot.title(
                'Field {}, Object {}:\nDirect cutout for F{} (G{})'.format(
                    self.targetPar, self.targetObject,
                    self.getDirectFilterForGrism(grism), grism))

            # Secondary y-axis: pixel limits centred on their midpoint and
            # scaled by the IDCSCALE value.
            arcsecYAxis = subplotAxes.twinx()
            pixelLimits = np.array(subplotAxes.get_ylim())
            arcsecLimits = ((pixelLimits - 0.5 * np.sum(pixelLimits)) *
                            self.directHdus[grism][1]['IDCSCALE'])
            arcsecYAxis.set_ylim(*arcsecLimits)
            # Raw string: '\D' is not a valid escape sequence in a normal
            # string literal.
            arcsecYAxis.set_ylabel(r'$\Delta Y$ (arcsec)')
            mplplot.grid(color='white', ls='solid')

        try:
            mplplot.tight_layout()
        except ValueError as e:
            print(
                'Error attempting tight_layout for: Field {}, Object {} ({})'
                .format(self.targetPar, self.targetObject, e))
            return

        if savePath is not None:
            mplplot.savefig(savePath, dpi=300, bbox_inches='tight')
            mplplot.close()
示例#7
0
 def set_normalization(self,
                       stretch=None,
                       interval=None,
                       stretchkwargs=None,
                       intervalkwargs=None,
                       perm_linear=None):
     """Set the stretch and interval and update the image normalization.

     ``stretch`` and ``interval`` may be given as names (strings) or as
     ready-made objects. If either is None, the value currently stored on
     the instance is used, falling back to ``'linear'`` / ``'zscale'``.
     Names are turned into ``astropy.visualization`` objects only when
     ``self.data`` is available; otherwise the name and the keyword
     arguments are just stored. If ``self.img`` is set, its norm is
     replaced by a new ``ImageNormalize`` built from ``self.data``.

     Parameters
     ----------
     stretch : str, object or None
         With ``perm_linear`` given, the accepted names are 'asinh',
         'contrastbias', 'histogram', 'log', 'powerdist', 'power', 'sinh',
         'sqrt' and 'square'. Without ``perm_linear`` only 'linear' is
         accepted.
     interval : str, object or None
         Accepted names: 'minmax', 'manual', 'percentile', 'asymetric'
         and 'zscale'.
     stretchkwargs : dict or None
         Keyword arguments for the stretch. None means an empty dict.
     intervalkwargs : dict or None
         Keyword arguments for the interval. None means an empty dict.
     perm_linear : dict or None
         Keyword arguments for a ``LinearStretch`` that is combined with
         the named stretch in a ``CompositeStretch``.

     Raises
     ------
     ValueError
         If a stretch or interval name is not supported on the path taken.
     """
     # Fresh dicts per call: these objects may be stored on the instance
     # below, so a shared mutable default would leak between instances.
     if stretchkwargs is None:
         stretchkwargs = {}
     if intervalkwargs is None:
         intervalkwargs = {}

     if stretch is None:
         if self.stretch is None:
             stretch = 'linear'
         else:
             stretch = self.stretch
     if isinstance(stretch, str):
         print(stretch,
               ' '.join([f'{k}={v}' for k, v in stretchkwargs.items()]))
         if self.data is None:  #can not calculate objects yet
             self.stretch_kwargs = stretchkwargs
         else:
             # presumably merges defaults, stored and new kwargs - verify
             # against prepare_kwargs
             kwargs = self.prepare_kwargs(
                 self.stretch_kws_defaults[stretch], self.stretch_kwargs,
                 stretchkwargs)
             if perm_linear is not None:
                 perm_linear_kwargs = self.prepare_kwargs(
                     self.stretch_kws_defaults['linear'], perm_linear)
                 print(
                     'linear', ' '.join([
                         f'{k}={v}' for k, v in perm_linear_kwargs.items()
                     ]))
                 # NOTE(review): the position of the linear stretch inside
                 # the composite differs between the branches below -
                 # confirm that this is intentional.
                 if stretch == 'asinh':  # arg: a=0.1
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.AsinhStretch(**kwargs))
                 elif stretch == 'contrastbias':  # args: contrast, bias
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.ContrastBiasStretch(**kwargs))
                 elif stretch == 'histogram':
                     stretch = vis.CompositeStretch(
                         vis.HistEqStretch(self.data, **kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'log':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LogStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'powerdist':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.PowerDistStretch(**kwargs))
                 elif stretch == 'power':  # args: a
                     stretch = vis.CompositeStretch(
                         vis.PowerStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'sinh':  # args: a=0.33
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SinhStretch(**kwargs))
                 elif stretch == 'sqrt':
                     stretch = vis.CompositeStretch(
                         vis.SqrtStretch(),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'square':
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SquaredStretch())
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
             else:
                 # NOTE(review): without perm_linear every name other than
                 # 'linear' is rejected here - confirm that this is intended.
                 if stretch == 'linear':  # args: slope=1, intercept=0
                     stretch = vis.LinearStretch(**kwargs)
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
     self.stretch = stretch
     if interval is None:
         if self.interval is None:
             interval = 'zscale'
         else:
             interval = self.interval
     if isinstance(interval, str):
         print(interval,
               ' '.join([f'{k}={v}' for k, v in intervalkwargs.items()]))
         kwargs = self.prepare_kwargs(self.interval_kws_defaults[interval],
                                      self.interval_kwargs, intervalkwargs)
         if self.data is None:
             # No data yet: keep the name and remember the kwargs.
             self.interval_kwargs = intervalkwargs
         else:
             if interval == 'minmax':
                 interval = vis.MinMaxInterval()
             elif interval == 'manual':  # args: vmin, vmax
                 interval = vis.ManualInterval(**kwargs)
             elif interval == 'percentile':  # args: percentile, n_samples
                 interval = vis.PercentileInterval(**kwargs)
             elif interval == 'asymetric':  # args: lower_percentile, upper_percentile, n_samples
                 interval = vis.AsymmetricPercentileInterval(**kwargs)
             elif interval == 'zscale':  # args: nsamples=1000, contrast=0.25, max_reject=0.5, min_npixels=5, krej=2.5, max_iterations=5
                 interval = vis.ZScaleInterval(**kwargs)
             else:
                 raise ValueError('Unknown interval:' + interval)
     self.interval = interval
     if self.img is not None:
         self.img.set_norm(
             vis.ImageNormalize(self.data,
                                interval=self.interval,
                                stretch=self.stretch,
                                clip=True))
示例#8
0
def show_image(image,
               percl=99,
               percu=None,
               figsize=(6, 10),
               cmap='viridis',
               log=False):
    """
    Display an image with matplotlib using a simple percentile-based stretch.

    The image is block-reduced to roughly the pixel size of the figure before
    it is drawn, and a colorbar is added.

    Parameters
    ----------
    image
        The image to show
    percl : number
        The percentile for the lower edge of the stretch (or both edges if ``percu`` is None)
    percu : number or None
        The percentile for the upper edge of the stretch (or None to use ``percl`` for both)
    figsize : 2-tuple
        The size of the matplotlib figure in inches. Unless it is None, it is
        rescaled to ``(max(figsize) * rows / columns, max(figsize))``.
    cmap : str or colormap
        Colormap passed to ``imshow``.
    log : bool
        Use a logarithmic stretch if True, a linear stretch otherwise.
    """
    # A single percentile describes a symmetric cut: (100 - p, p).
    if percu is None:
        percl, percu = 100 - percl, percl

    if figsize is not None:
        # Roughly match the figure shape to the image dimensions.
        rows_over_columns = image.shape[0] / image.shape[1]
        figsize = (max(figsize) * rows_over_columns, max(figsize))

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # matplotlib will not downsample in a way that preserves detail, so the
    # image is block-reduced here to about the figure's size in pixels.
    # Figure size in pixels, see
    # https://stackoverflow.com/questions/29702424/how-to-get-matplotlib-figure-size
    figure_pixels = fig.get_size_inches() * fig.dpi

    # Largest whole-number shrink factor over both axes, but never below one.
    shrink = max((image.shape // figure_pixels).max(), 1)

    # Dividing by the squared factor keeps the flux of the reduced image the same.
    downsampled = block_reduce(image, shrink) / shrink**2

    chosen_stretch = aviz.LogStretch() if log else aviz.LinearStretch()
    percentile_interval = aviz.AsymmetricPercentileInterval(percl, percu)
    norm = aviz.ImageNormalize(downsampled,
                               interval=percentile_interval,
                               stretch=chosen_stretch)

    # The reduced image has fewer pixels; the extent makes the axes show the
    # size of the original image again.
    full_extent = [0, image.shape[1], 0, image.shape[0]]

    drawn = ax.imshow(downsampled,
                      norm=norm,
                      origin='lower',
                      cmap=cmap,
                      extent=full_extent)
    plt.colorbar(drawn)