Beispiel #1
0
def plotcontrast(img):
    """Show *img* under a range of Matplotlib/AstroPy normalizations.

    The image is drawn twelve times on a 4x3 grid of axes:
    - once without any normalization,
    - once with matplotlib's ``LogNorm``,
    - once per listed ``astropy.visualization`` stretch, each wrapped in an
      ``ImageNormalize``.

    Each panel is titled with the class name of the normalization or
    stretch used (``'None'`` for the un-normalized panel).

    See http://docs.astropy.org/en/stable/visualization/index.html

    Parameters
    ----------
    img : array_like
        2-D image data; it is also used to build the ``HistEqStretch``.
    """
    vistypes = (None, LogNorm(), vis.AsinhStretch(),
                vis.ContrastBiasStretch(1, 0.5), vis.HistEqStretch(img),
                vis.LinearStretch(), vis.LogStretch(),
                vis.PowerDistStretch(a=10.), vis.PowerStretch(a=10.),
                vis.SinhStretch(), vis.SqrtStretch(), vis.SquaredStretch())

    fg, axes = subplots(4, 3, figsize=(10, 10))

    # The 4x3 grid holds exactly one panel per entry of vistypes.
    for a, v in zip(axes.ravel(), vistypes):
        if v is None or isinstance(v, LogNorm):
            # Already a matplotlib norm (or no norm at all): use it directly.
            norm = v
        else:
            # astropy stretches must be wrapped in an ImageNormalize for imshow.
            norm = ImageNormalize(stretch=v)
        # Use the class name directly rather than parsing the object's repr.
        a.set_title('None' if v is None else type(v).__name__)

        a.imshow(img, origin='lower', cmap='gray', norm=norm)
        a.axis('off')

    fg.suptitle('Matplotlib/AstroPy normalizations')
    def normalizeImage(cls, image):
        """Map image data onto the [0, 1] domain via histogram equalization.

        An ``aviz.ImageNormalize`` is built from ``image`` with
        ``clip=True`` and a ``HistEqStretch`` derived from the same data.
        The normalization is then applied to ``image``.

        Parameters
        ----------
        image : `np.array`
            Image.

        Returns
        -------
        norm : `np.array`
            Normalized image.
        """
        # TODO: make things like these configurable (also see resize in
        # store_thumbnail)
        normalizer = aviz.ImageNormalize(image,
                                         stretch=aviz.HistEqStretch(image),
                                         clip=True)
        return normalizer(image)
Beispiel #3
0
def plot_image(image,
               ax=None,
               scale='log',
               cmap=None,
               origin='lower',
               xlabel=None,
               ylabel=None,
               cbar=None,
               clabel='Flux ($e^{-}s^{-1}$)',
               cbar_ticks=None,
               cbar_ticklabels=None,
               cbar_pad=None,
               cbar_size='5%',
               title=None,
               percentile=95.0,
               vmin=None,
               vmax=None,
               offset_axes=None,
               color_bad='k',
               **kwargs):
    """
    Utility function to plot a 2D image.

    Parameters:
        image (2d array_like): Image data. Anything accepted by
            :py:func:`numpy.asanyarray` works; ndarray subclasses such as
            masked arrays are kept as-is.
        ax (matplotlib.pyplot.axes, optional): Axes in which to plot.
            Default (None) is to use current active axes.
        scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional):
            Normalization used to stretch the colormap.
            Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'``
            and ``'squared'``.
            Can also be a :py:class:`astropy.visualization.ImageNormalize` (or any
            :py:class:`matplotlib.colors.Normalize`) object, which is used as-is.
            Default is ``'log'``.
        cmap (str or matplotlib colormap, optional): Colormap to use.
            Default is the ``Blues`` colormap. The colormap is copied before
            ``color_bad`` is applied, so the caller's object is never modified.
        origin (str, optional): The origin of the coordinate system.
        xlabel (str, optional): Label for the x-axis.
        ylabel (str, optional): Label for the y-axis.
        cbar (string, optional): Location of color bar.
            Choices are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``.
            Default is not to create colorbar.
        clabel (str, optional): Label for the color bar.
        cbar_ticks (list, optional): Tick positions for the color bar.
        cbar_ticklabels (list of str, optional): Labels for ``cbar_ticks``.
        cbar_size (str or float, optional): Size of colorbar compared to axes. Default='5%'.
        cbar_pad (float, optional): Padding between axes and colorbar.
        title (str or None, optional): Title for the plot.
        percentile (float or pair of floats, optional): The fraction of pixels to keep in color-trim.
            If single float given, the same fraction of pixels is eliminated from both ends.
            If tuple of two floats is given, the two are used as the percentiles.
            Default=95.
        vmin (float, optional): Lower limit to use for colormap.
        vmax (float, optional): Upper limit to use for colormap.
        offset_axes (pair of numbers, optional): Pixel coordinates of the
            lower-left pixel; shifts the plotted extent.
        color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black.
        kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`.
            The removed keyword ``make_cbar`` is rejected (a FutureWarning is raised).

    Returns:
        :py:class:`matplotlib.image.AxesImage`: Image returned
            by :py:func:`matplotlib.pyplot.imshow`.

    .. codeauthor:: Rasmus Handberg <*****@*****.**>
    """

    logger = logging.getLogger(__name__)

    # 'make_cbar' is the removed predecessor of 'cbar'. It is rejected
    # explicitly so that old call sites fail loudly instead of having it
    # forwarded to imshow through **kwargs.
    make_cbar = kwargs.pop('make_cbar', None)
    if make_cbar:
        raise FutureWarning("'make_cbar' is deprecated. Use 'cbar' instead.")

    # Accept any array-like input; this is a no-op for ndarrays (incl. subclasses).
    image = np.asanyarray(image)

    # Special treatment for boolean arrays:
    if image.dtype == bool:
        if vmin is None:
            vmin = 0
        if vmax is None:
            vmax = 1
        if cbar_ticks is None:
            cbar_ticks = [0, 1]
        if cbar_ticklabels is None:
            cbar_ticklabels = ['False', 'True']

    # Calculate limits of color scaling:
    interval = None
    if vmin is None or vmax is None:
        if allnan(image):
            logger.warning("Image is all NaN")
            vmin = 0
            vmax = 1
            if cbar_ticks is None:
                cbar_ticks = []
            if cbar_ticklabels is None:
                cbar_ticklabels = []
        elif isinstance(percentile, (list, tuple, np.ndarray)):
            interval = viz.AsymmetricPercentileInterval(
                percentile[0], percentile[1])
        else:
            interval = viz.PercentileInterval(percentile)

    # Create ImageNormalize object with extracted limits:
    if isinstance(scale, str) and scale in ('log', 'linear', 'sqrt', 'asinh',
                                            'histeq', 'sinh', 'squared'):
        finite = image[np.isfinite(image)]
        if scale == 'histeq':
            # Histogram equalization is built from the finite pixel values.
            stretch = viz.HistEqStretch(finite)
        else:
            stretch = {
                'log': viz.LogStretch,
                'linear': viz.LinearStretch,
                'sqrt': viz.SqrtStretch,
                'asinh': viz.AsinhStretch,
                'sinh': viz.SinhStretch,
                'squared': viz.SquaredStretch,
            }[scale]()

        # Create ImageNormalize object. Very important to use clip=False if the image contains
        # NaNs, otherwise NaN points will not be plotted correctly.
        norm = viz.ImageNormalize(data=finite,
                                  interval=interval,
                                  vmin=vmin,
                                  vmax=vmax,
                                  stretch=stretch,
                                  clip=not anynan(image))

    elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)):
        norm = scale
    else:
        raise ValueError("scale {} is not available.".format(scale))

    # Pixel centres sit on integer coordinates, so the extent is offset by half a pixel.
    if offset_axes:
        extent = (offset_axes[0] - 0.5, offset_axes[0] + image.shape[1] - 0.5,
                  offset_axes[1] - 0.5, offset_axes[1] + image.shape[0] - 0.5)
    else:
        extent = (-0.5, image.shape[1] - 0.5, -0.5, image.shape[0] - 0.5)

    if ax is None:
        ax = plt.gca()

    # Set up the colormap to use. Always work on a copy so that set_bad()
    # never modifies a registered colormap or a Colormap object owned by
    # the caller.
    if cmap is None:
        cmap = 'Blues'
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    cmap = copy.copy(cmap)

    if color_bad:
        cmap.set_bad(color_bad, 1.0)

    # Plotting the image using all the settings set above:
    im = ax.imshow(image,
                   cmap=cmap,
                   norm=norm,
                   origin=origin,
                   extent=extent,
                   interpolation='nearest',
                   **kwargs)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    ax.set_xlim([extent[0], extent[1]])
    ax.set_ylim([extent[2], extent[3]])

    if cbar:
        colorbar(im,
                 ax=ax,
                 loc=cbar,
                 size=cbar_size,
                 pad=cbar_pad,
                 label=clabel,
                 ticks=cbar_ticks,
                 ticklabels=cbar_ticklabels)

    # Settings for ticks. Every Axis gets its own locator instances:
    # a Locator keeps a reference to the Axis it is attached to, so sharing
    # one instance makes ticks be computed from the wrong axis' limits.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(MaxNLocator(nbins=10, integer=True))
        axis.set_minor_locator(MaxNLocator(nbins=10, integer=True))
    ax.tick_params(which='both', direction='out', pad=5)
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()

    return im
Beispiel #4
0
 def stretchModel(self):
     """Return the configured stretch-model factory, or a default one.

     If ``self._stretchModel`` has been set, it is returned unchanged.
     Otherwise a notice is printed and a callable
     ``(target, header, data)`` is returned. That callable builds an
     ``astrovis.HistEqStretch`` from ``data`` and ignores ``target`` and
     ``header``.
     """
     configured = self._stretchModel
     if configured is not None:
         return configured
     print('Using default HistEqStretch stretch model')
     return lambda target, header, data: astrovis.HistEqStretch(data)
Beispiel #5
0
 def set_normalization(self,
                       stretch=None,
                       interval=None,
                       stretchkwargs=None,
                       intervalkwargs=None,
                       perm_linear=None):
     """Choose the display stretch and interval and apply them to the image.

     Parameters
     ----------
     stretch : str or stretch object, optional
         Stretch name or an already-built stretch. If None, ``self.stretch``
         is reused, falling back to ``'linear'``. Once ``self.data`` is
         available, a name is resolved as follows:
         - without ``perm_linear``, only ``'linear'`` is accepted;
         - with ``perm_linear``, the names ``'asinh'``, ``'contrastbias'``,
           ``'histogram'``, ``'log'``, ``'powerdist'``, ``'power'``,
           ``'sinh'``, ``'sqrt'`` and ``'square'`` are accepted. Each is
           combined with a ``LinearStretch`` via ``vis.CompositeStretch``.
     interval : str or interval object, optional
         Interval name or an already-built interval. If None,
         ``self.interval`` is reused, falling back to ``'zscale'``.
         Accepted names: ``'minmax'``, ``'manual'``, ``'percentile'``,
         ``'asymetric'`` and ``'zscale'``.
     stretchkwargs, intervalkwargs : dict, optional
         Extra constructor keyword arguments for the stretch / interval.
         They are combined by ``self.prepare_kwargs`` with the per-name
         defaults and the stored ``self.stretch_kwargs`` /
         ``self.interval_kwargs``. Default: no extra arguments (a fresh
         empty dict per call).
     perm_linear : dict, optional
         Keyword arguments for the additional ``LinearStretch`` that is
         composed with the named stretch.

     Notes
     -----
     While ``self.data`` is None, names are stored unresolved and the
     given kwargs are saved on the instance for a later call. When
     ``self.img`` is set, a new ``vis.ImageNormalize`` (``clip=True``) is
     installed on it.

     Raises
     ------
     ValueError
         If a stretch or interval name is not handled. An unknown name may
         instead surface as a KeyError from the ``*_kws_defaults`` lookup.
     """
     # Fresh dicts per call: the former ``{}`` defaults were single shared
     # objects that got stored on the instance below, so mutating one
     # instance's kwargs leaked into every later default call.
     if stretchkwargs is None:
         stretchkwargs = {}
     if intervalkwargs is None:
         intervalkwargs = {}

     if stretch is None:
         if self.stretch is None:
             stretch = 'linear'
         else:
             stretch = self.stretch
     if isinstance(stretch, str):
         print(stretch,
               ' '.join([f'{k}={v}' for k, v in stretchkwargs.items()]))
         if self.data is None:  # can not calculate objects yet
             self.stretch_kwargs = stretchkwargs
         else:
             kwargs = self.prepare_kwargs(
                 self.stretch_kws_defaults[stretch], self.stretch_kwargs,
                 stretchkwargs)
             if perm_linear is not None:
                 perm_linear_kwargs = self.prepare_kwargs(
                     self.stretch_kws_defaults['linear'], perm_linear)
                 print(
                     'linear', ' '.join([
                         f'{k}={v}' for k, v in perm_linear_kwargs.items()
                     ]))
                 # NOTE: the order of the two stretches inside
                 # CompositeStretch differs between names; kept as-is.
                 if stretch == 'asinh':  # arg: a=0.1
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.AsinhStretch(**kwargs))
                 elif stretch == 'contrastbias':  # args: contrast, bias
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.ContrastBiasStretch(**kwargs))
                 elif stretch == 'histogram':
                     stretch = vis.CompositeStretch(
                         vis.HistEqStretch(self.data, **kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'log':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LogStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'powerdist':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.PowerDistStretch(**kwargs))
                 elif stretch == 'power':  # args: a
                     stretch = vis.CompositeStretch(
                         vis.PowerStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'sinh':  # args: a=0.33
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SinhStretch(**kwargs))
                 elif stretch == 'sqrt':
                     stretch = vis.CompositeStretch(
                         vis.SqrtStretch(),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'square':
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SquaredStretch())
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
             else:
                 if stretch == 'linear':  # args: slope=1, intercept=0
                     stretch = vis.LinearStretch(**kwargs)
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
     self.stretch = stretch

     if interval is None:
         if self.interval is None:
             interval = 'zscale'
         else:
             interval = self.interval
     if isinstance(interval, str):
         print(interval,
               ' '.join([f'{k}={v}' for k, v in intervalkwargs.items()]))
         kwargs = self.prepare_kwargs(self.interval_kws_defaults[interval],
                                      self.interval_kwargs, intervalkwargs)
         if self.data is None:  # can not calculate objects yet
             self.interval_kwargs = intervalkwargs
         else:
             if interval == 'minmax':
                 interval = vis.MinMaxInterval()
             elif interval == 'manual':  # args: vmin, vmax
                 interval = vis.ManualInterval(**kwargs)
             elif interval == 'percentile':  # args: percentile, n_samples
                 interval = vis.PercentileInterval(**kwargs)
             elif interval == 'asymetric':  # args: lower_percentile, upper_percentile, n_samples
                 interval = vis.AsymmetricPercentileInterval(**kwargs)
             elif interval == 'zscale':  # args: nsamples=1000, contrast=0.25, max_reject=0.5, min_npixels=5, krej=2.5, max_iterations=5
                 interval = vis.ZScaleInterval(**kwargs)
             else:
                 raise ValueError('Unknown interval:' + interval)
     self.interval = interval

     # Re-apply the normalization to an already displayed image.
     if self.img is not None:
         self.img.set_norm(
             vis.ImageNormalize(self.data,
                                interval=self.interval,
                                stretch=self.stretch,
                                clip=True))
Beispiel #6
0
def get_image(data, fmt='JPEG', norm='percentile', lo=None, hi=None,
              zcontrast=0.25, nsamples=1000, krej=2.5, max_iterations=5,
              stretch='linear', a=None, bias=0.5, contrast=1, cmap=None,
              dpi=100, **kwargs):
    u"""
    Return a byte array containing image in the given format

    Image scaling is done using `~astropy.visualization`. It includes
    normalization of the input data (mapping to [0, 1]) and stretching -
    optional non-linear mapping [0, 1] -> [0, 1] for contrast enhancement.
    A colormap can be applied to the normalized data. Conversion to the target
    image format is done by matplotlib or Pillow.

    :param array_like data: input 2D image data
    :param str fmt: output image format
    :param str norm: data normalization mode::
        "manual": lower and higher clipping limits are set explicitly
        "minmax": limits are set to the minimum and maximum data values
        "percentile" (default): limits are set based on the specified fraction
            of pixels
        "zscale": use IRAF ZScale algorithm
        If the resulting limits coincide (lo == hi, e.g. a constant image with
        "minmax"), pixels above the limit map to 1 and all others to 0
        instead of producing NaNs.
    :param int | float lo::
        for ``norm`` == "manual", lower data limit
        for ``norm`` == "percentile", lower percentile clipping value,
            defaulting to 10
        for ``norm`` == "zscale", minimum number of pixels that must remain
            after rejection (``min_npixels`` of ZScaleInterval), defaulting
            to 5
    :param int | float hi::
        for ``norm`` == "manual", upper data limit
        for ``norm`` == "percentile", upper percentile clipping value,
            defaulting to 98
        for ``norm`` == "zscale", upper limit on the number of rejected pixels,
            defaulting to data.size/2
    :param float zcontrast: for ``norm`` == "zscale", the scaling factor,
        0 < zcontrast < 1, defaulting to 0.25
    :param int nsamples: for ``norm`` == "zscale", the number of points in
        the input array for determining scale factors, defaulting to 1000
    :param float krej: for ``norm`` == "zscale", the sigma clipping factor,
        defaulting to 2.5
    :param int max_iterations: for ``norm`` == "zscale", the maximum number
        of rejection iterations, defaulting to 5
    :param str stretch: [0, 1] → [0, 1] mapping mode::
        "asinh": hyperbolic arcsine stretch y = asinh(x/a)/asinh(1/a)
        "contrast": linear bias/contrast-based stretch
            y = (x - bias)*contrast + 0.5
        "exp": exponential stretch y = (a^x - 1)/(a - 1)
        "histeq": histogram equalization stretch
        "linear" (default): direct mapping
        "log": logarithmic stretch y = log(ax + 1)/log(a + 1)
        "power": power stretch y = x^a
        "sinh": hyperbolic sine stretch y = sinh(x/a)/sinh(1/a)
        "sqrt": square root stretch y = √x
        "square": power stretch y = x^2
    :param float a: non-linear stretch parameter::
        for ``stretch`` == "asinh", the point of transition from linear to
            logarithmic behavior, 0 < a <= 1, defaulting to 0.1
        for ``stretch`` == "exp", base of the exponent, a != 1, defaulting to
            1000
        for ``stretch`` == "log", base of the logarithm minus 1, a > 0,
            defaulting to 1000
        for ``stretch`` == "power", the power index, defaulting to 3
        for ``stretch`` == "sinh", a > 0, defaulting to 1/3
    :param float bias: for ``stretch`` == "contrast", the bias parameter,
        defaulting to 0.5
    :param float contrast: for ``stretch`` == "contrast", the contrast
        parameter, defaulting to 1
    :param str cmap: optional matplotlib colormap name, defaulting
        to grayscale; when a non-grayscale colormap is specified,
        the conversion is always done by matplotlib, regardless of the
        availability of Pillow; see https://matplotlib.org/users/colormaps.html
        for more info on matplotlib colormaps and
            [name for name in matplotlib.cm.cmap_d.keys()
             if not name.endswith('_r')]
        to list the available colormap names
    :param int dpi: target image resolution in dots per inch
    :param kwargs: optional format-specific keyword arguments passed to Pillow,
        e.g. "quality" for JPEG; see
        `https://pillow.readthedocs.io/en/stable/handbook/
        image-file-formats.html`_

    :return: a bytes object containing the image in the given format
    :rtype: bytes
    """
    data = asanyarray(data)

    # Determine the [lo, hi] data limits
    if norm == 'manual':
        if lo is None:
            raise ValueError(
                'Missing lower clipping boundary for norm="manual"')
        if hi is None:
            raise ValueError(
                'Missing upper clipping boundary for norm="manual"')
    elif norm == 'minmax':
        lo, hi = data.min(), data.max()
    elif norm == 'percentile':
        if lo is None:
            lo = 10
        elif not 0 <= lo <= 100:
            raise ValueError(
                'Lower clipping percentile must be in the [0,100] range')
        if hi is None:
            hi = 98
        elif not 0 <= hi <= 100:
            raise ValueError(
                'Upper clipping percentile must be in the [0,100] range')
        if hi < lo:
            raise ValueError(
                'Upper clipping percentile must be greater or equal to '
                'lower percentile')
        lo, hi = percentile(data, [lo, hi])
    elif norm == 'zscale':
        if lo is None:
            lo = 5
        if hi is None:
            hi = 0.5
        else:
            # ZScaleInterval expects the rejection limit as a fraction of
            # the total number of pixels
            hi /= data.size
        # Positionally, hi is passed as max_reject and lo as min_npixels
        lo, hi = apy_vis.ZScaleInterval(
            nsamples, zcontrast, hi, lo, krej, max_iterations).get_limits(data)
    else:
        raise ValueError('Unknown normalization mode "{}"'.format(norm))

    # Map [lo, hi] onto [0, 1]
    span = hi - lo
    if span == 0:
        # Degenerate range (e.g. a constant image): avoid 0/0 -> NaN, which
        # would become garbage when cast to uint8. Use the limiting step
        # function of the clipped linear map instead.
        data = (data > lo).astype(float)
    else:
        data = clip((data - lo)/span, 0, 1)

    # Stretch the data in place
    if stretch == 'asinh':
        if a is None:
            a = 0.1
        apy_vis.AsinhStretch(a)(data, out=data)
    elif stretch == 'contrast':
        # The default bias/contrast is the identity mapping
        if bias != 0.5 or contrast != 1:
            apy_vis.ContrastBiasStretch(contrast, bias)(data, out=data)
    elif stretch == 'exp':
        if a is None:
            a = 1000
        apy_vis.PowerDistStretch(a)(data, out=data)
    elif stretch == 'histeq':
        apy_vis.HistEqStretch(data)(data, out=data)
    elif stretch == 'linear':
        pass
    elif stretch == 'log':
        if a is None:
            a = 1000
        apy_vis.LogStretch(a)(data, out=data)
    elif stretch == 'power':
        if a is None:
            a = 3
        apy_vis.PowerStretch(a)(data, out=data)
    elif stretch == 'sinh':
        if a is None:
            a = 1/3
        apy_vis.SinhStretch(a)(data, out=data)
    elif stretch == 'sqrt':
        apy_vis.SqrtStretch()(data, out=data)
    elif stretch == 'square':
        apy_vis.SquaredStretch()(data, out=data)
    else:
        raise ValueError('Unknown stretch mode "{}"'.format(stretch))

    # Choose the backend for making an image: Pillow for grayscale output
    # when installed, matplotlib otherwise
    if cmap is None:
        cmap = 'gray'
    pil_image = None
    if cmap == 'gray':
        try:
            # noinspection PyPackageRequirements,PyPep8Naming
            from PIL import Image as pil_image
        except ImportError:
            pil_image = None

    with BytesIO() as buf:
        if pil_image is not None:
            # Use Pillow for grayscale output if available; flip the image to
            # match the bottom-to-top FITS convention and convert from [0,1] to
            # unsigned byte
            pil_image.fromarray(
                (data[::-1]*255 + 0.5).astype(uint8),
            ).save(buf, fmt, dpi=(dpi, dpi), **kwargs)
        else:
            # Use matplotlib for non-grayscale colormaps or if PIL is not
            # available
            # noinspection PyPackageRequirements
            from matplotlib import image as mpl_image
            if fmt.lower() == 'png':
                # PNG images are saved upside down by matplotlib, regardless of
                # the origin parameter
                data = data[::-1]
            # noinspection PyTypeChecker
            mpl_image.imsave(
                buf, data, cmap=cmap, format=fmt, origin='lower', dpi=dpi)

        return buf.getvalue()
Beispiel #7
0
def make_fov_image(fov, pngfn=None, **kwargs):
    """Draw the four CCD images of a field of view as a 2x2 mosaic.

    Parameters
    ----------
    fov : mapping
        Must contain ``'CCD1'`` .. ``'CCD4'`` entries, each a mapping with
        ``'im'`` (2-D image, or 3-D which is averaged over its last axis) and
        ``'x'``/``'y'`` (2-D coordinate grids). It must also contain
        ``'coordsys'``; if this is ``'sky'``, the x axis is inverted. The
        optional ``'file'`` and ``'objname'`` entries form the default title.
    pngfn : str, optional
        If given, the figure is saved to this path and then closed.
    **kwargs
        ``stretch`` ('linear' (default), 'histeq', 'asinh'),
        ``interval`` ('zscale' (default), 'rms', 'fixed'),
        ``imrange`` (sigma limits for 'rms', data limits for 'fixed'),
        ``contrast`` (zscale contrast, default 0.25),
        ``cmap`` (default 'viridis'), ``interpolation`` (default
        'nearest') and ``title``.

    Returns
    -------
    matplotlib.figure.Figure
        The mosaic figure (already closed if ``pngfn`` was given).

    Raises
    ------
    KeyError
        For an unknown ``stretch``.
    ValueError
        For an unknown ``interval``.
    """
    import copy

    stretch = kwargs.get('stretch', 'linear')
    interval = kwargs.get('interval', 'zscale')
    imrange = kwargs.get('imrange')
    contrast = kwargs.get('contrast', 0.25)
    ccdplotorder = ['CCD2', 'CCD4', 'CCD1', 'CCD3']
    if interval == 'rms':
        try:
            losig, hisig = imrange
        except (TypeError, ValueError):
            # imrange not given (None) or not a pair: default sigma limits
            losig, hisig = (2.5, 5.0)
    # Copy the colormap so set_bad() never modifies a registered colormap
    # or a Colormap object passed in by the caller
    cmap = copy.copy(plt.get_cmap(kwargs.get('cmap', 'viridis')))
    cmap.set_bad('w', 1.0)
    # Panel width/height in figure-fraction units
    w = 0.4575
    h = 0.455
    rc('text', usetex=False)
    fig = plt.figure(figsize=(6, 6.5))
    cax = fig.add_axes([0.1, 0.04, 0.8, 0.01])
    ims = [fov[ccd]['im'] for ccd in ccdplotorder]
    # Pixels of all CCDs together, so every panel shares one normalization
    allpix = np.ma.array(ims).flatten()
    # Only construct the requested stretch; HistEqStretch(allpix) was
    # previously built on every call even when it was not used
    stretch_factories = {
        'linear': vis.LinearStretch,
        'histeq': lambda: vis.HistEqStretch(allpix),
        'asinh': vis.AsinhStretch,
    }
    stretch = stretch_factories[stretch]()
    if interval == 'zscale':
        iv = vis.ZScaleInterval(contrast=contrast)
        vmin, vmax = iv.get_limits(allpix)
    elif interval == 'rms':
        # NOTE(review): `nbin` is not defined in this function; presumably a
        # module-level global - confirm, otherwise this raises NameError.
        nsample = 1000 // nbin
        # NOTE(review): newer astropy releases spell this keyword `maxiters`
        # instead of `iters` - confirm against the installed version.
        background = sigma_clip(allpix[::nsample], iters=3, sigma=2.2)
        m, s = background.mean(), background.std()
        vmin, vmax = m - losig * s, m + hisig * s
    elif interval == 'fixed':
        vmin, vmax = imrange
    else:
        raise ValueError('Unknown interval: {}'.format(interval))
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)
    for n, (im, ccd) in enumerate(zip(ims, ccdplotorder)):
        if im.ndim == 3:
            im = im.mean(axis=-1)
        x = fov[ccd]['x']
        y = fov[ccd]['y']
        # Column i and row j of this CCD in the 2x2 mosaic
        i = n % 2
        j = n // 2
        pos = [0.0225 + i * w + i * 0.04, 0.05 + j * h + j * 0.005, w, h]
        ax = fig.add_axes(pos)
        _im = ax.imshow(im,
                        origin='lower',
                        extent=[x[0, 0], x[0, -1], y[0, 0], y[-1, 0]],
                        norm=norm,
                        cmap=cmap,
                        interpolation=kwargs.get('interpolation', 'nearest'))
        if fov['coordsys'] == 'sky':
            ax.set_xlim(x.max(), x.min())
        else:
            ax.set_xlim(x.min(), x.max())
        ax.set_ylim(y.min(), y.max())
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        if n == 0:
            # All panels share `norm`, so one colorbar describes them all
            cb = fig.colorbar(_im, cax, orientation='horizontal')
            cb.ax.tick_params(labelsize=9)
    tstr = fov.get('file', '') + ' ' + fov.get('objname', '')
    title = kwargs.get('title', tstr)
    title = title[-60:]
    fig.text(0.5, 0.99, title, ha='center', va='top', size=12)
    if pngfn is not None:
        # Save this figure explicitly rather than whatever is "current"
        fig.savefig(pngfn)
        plt.close(fig)
    return fig