Esempio n. 1
0
def plotcontrast(img):
    """
    Show `img` in a 4x3 grid, one panel per Matplotlib/AstroPy normalization.

    The panels are, in order: no normalization, Matplotlib's ``LogNorm``
    and ten AstroPy stretches, each of the latter wrapped in an
    ``ImageNormalize``. Every panel is titled with the name of the
    normalization it uses and drawn in grayscale with the origin at the
    bottom.

    See http://docs.astropy.org/en/stable/visualization/index.html

    :param img: image data handed to ``imshow`` (and to ``HistEqStretch``)
    """
    vistypes = (
        None,
        LogNorm(),
        vis.AsinhStretch(),
        vis.ContrastBiasStretch(1, 0.5),
        vis.HistEqStretch(img),
        vis.LinearStretch(),
        vis.LogStretch(),
        vis.PowerDistStretch(a=10.),
        vis.PowerStretch(a=10.),
        vis.SinhStretch(),
        vis.SqrtStretch(),
        vis.SquaredStretch(),
    )

    fg, ax = subplots(4, 3, figsize=(10, 10))

    # Twelve normalizations, twelve axes: pair them up one-to-one.
    for a, v in zip(ax.ravel(), vistypes):
        if not v or isinstance(v, LogNorm):
            # None and LogNorm are handed to imshow unchanged; the title
            # is extracted from the object's string representation.
            norm = v
            title = str(v).split('.')[-1].split(" ")[0]
        else:
            # AstroPy stretches need an ImageNormalize wrapper.
            norm = ImageNormalize(stretch=v)
            title = v.__class__.__name__
        a.set_title(title)

        a.imshow(img, origin='lower', cmap='gray', norm=norm)
        a.axis('off')

    fg.suptitle('Matplotlib/AstroPy normalizations')
Esempio n. 2
0
    def plot_norm(self,
                  stretch='linear',
                  power=1.0,
                  asinh_a=0.1,
                  min_cut=None,
                  max_cut=None,
                  min_percent=None,
                  max_percent=None,
                  percent=None,
                  clip=True):
        """Create a matplotlib norm object for plotting.

        This is a copy of this function that will be available in Astropy 1.3:
        `astropy.visualization.mpl_normalize.simple_norm`

        The cut levels are computed from ``self.data``. The interval options
        are checked in this order and the first group that is given wins:
        ``percent``, then ``min_percent`` / ``max_percent``, then
        ``min_cut`` / ``max_cut``. If none is given the full min/max range
        of the data is used.

        Parameters
        ----------
        stretch : {'linear', 'sqrt', 'power', 'log', 'asinh'}
            Name of the stretch to apply.
        power : float
            Argument passed to ``PowerStretch`` (used for ``'power'`` only).
        asinh_a : float
            Argument passed to ``AsinhStretch`` (used for ``'asinh'`` only).
        min_cut, max_cut : float, optional
            Limits passed to ``ManualInterval``.
        min_percent, max_percent : float, optional
            Percentiles passed to ``AsymmetricPercentileInterval``. A side
            left as `None` defaults to 0 (lower) or 100 (upper).
        percent : float, optional
            Argument passed to ``PercentileInterval``.
        clip : bool
            Forwarded to ``ImageNormalize``.

        Returns
        -------
        norm : `astropy.visualization.mpl_normalize.ImageNormalize`
            Norm built from the computed limits and the chosen stretch.

        Raises
        ------
        ValueError
            If ``stretch`` is not one of the supported names.

        Examples
        --------
        >>> image = SkyImage()
        >>> norm = image.plot_norm(stretch='sqrt', max_percent=99)
        >>> image.plot(norm=norm)
        """
        import astropy.visualization as v
        from astropy.visualization.mpl_normalize import ImageNormalize

        if percent is not None:
            interval = v.PercentileInterval(percent)
        elif min_percent is not None or max_percent is not None:
            # Test against None instead of relying on truthiness, so that an
            # explicit 0 is honoured rather than replaced by the default.
            lower = 0. if min_percent is None else min_percent
            upper = 100. if max_percent is None else max_percent
            interval = v.AsymmetricPercentileInterval(lower, upper)
        elif min_cut is not None or max_cut is not None:
            interval = v.ManualInterval(min_cut, max_cut)
        else:
            interval = v.MinMaxInterval()

        if stretch == 'linear':
            stretch = v.LinearStretch()
        elif stretch == 'sqrt':
            stretch = v.SqrtStretch()
        elif stretch == 'power':
            stretch = v.PowerStretch(power)
        elif stretch == 'log':
            stretch = v.LogStretch()
        elif stretch == 'asinh':
            stretch = v.AsinhStretch(asinh_a)
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        vmin, vmax = interval.get_limits(self.data)

        return ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch, clip=clip)
Esempio n. 3
0
def plotPrior(xdgmm, ax, c='black', lw=1, stretch=False):
    """Draw one outline per mixture component of `xdgmm` on `ax`.

    For every component the outline points come from
    ``drawEllipse.plotvector(mu, V)``; the second coordinate is converted
    with ``testXD.absMagKinda2absMag`` before plotting. The line opacity is
    the component weight relative to the largest weight.

    :param xdgmm: object exposing ``n_components``, ``mu``, ``V`` and
        ``weights`` (presumably a fitted XDGMM model -- confirm with caller)
    :param ax: object with a ``plot`` method (presumably a matplotlib Axes)
    :param c: passed to ``ax.plot`` as its third positional argument
    :param lw: line width passed to ``ax.plot``
    :param stretch: if true, raise the relative weight to the power 1/10
        so that low-weight components stay visible
    """
    components = range(xdgmm.n_components)
    if not components:
        # Nothing to draw; also avoids taking the max of an empty sequence.
        return

    # The largest weight does not change inside the loop.
    max_weight = np.max(xdgmm.weights)

    for gg in components:
        points = drawEllipse.plotvector(xdgmm.mu[gg], xdgmm.V[gg])
        alpha = xdgmm.weights[gg] / max_weight
        if stretch:
            alpha = np.power(alpha, 1. / 10)
        ax.plot(points[0, :],
                testXD.absMagKinda2absMag(points[1, :]),
                c,
                lw=lw,
                alpha=alpha)
Esempio n. 4
0
 def get_offset(self, fitsobj, calcFFTs=False):
     """Get the pixel offset between this object and another FITS object.

     The offset is taken from the peak of the inverse FFT of
     ``self.fft * conj(fitsobj.fft)``. The real part of that array is also
     displayed with ``plt.imshow`` / ``plt.show`` before the result is
     returned.

     INPUT:
         fitsobj:  instance of FITS to compare against
         calcFFTs: if true, call ``get_fft()`` on both objects first
     RETURNS:
         tuple ``(dx, dy)``: the flat peak index divided by the row length
         of ``self.data`` (quotient, remainder). Each value is wrapped to a
         negative offset when it is at least half of ``self.size[0]`` /
         ``self.size[1]`` respectively.
     """
     logging.info('Calculating alignment offset %s <-- %s', self, fitsobj)
     if calcFFTs:
         self.get_fft()
         fitsobj.get_fft()
     convolution = np.multiply(self.fft, np.conj(fitsobj.fft))
     # Work on the real part only: it is what gets plotted, and locating
     # the peak on it avoids relying on the ordering of complex numbers.
     correlation = np.real(np.fft.ifft2(convolution))
     plt.imshow(correlation, cmap='pink')
     plt.xlabel("Offset [pix]")
     plt.ylabel("Offset [pix]")
     plt.show()
     peak = np.argmax(correlation)
     row_length = len(self.data[0])
     # Integer division: the flat index must split into whole-pixel indices
     # (true division would give a fractional value on Python 3).
     dx, dy = divmod(peak, row_length)
     if dx >= 0.5 * self.size[0]:
         dx = dx - self.size[0]
     if dy >= 0.5 * self.size[1]:
         dy = dy - self.size[1]
     return (dx, dy)
Esempio n. 5
0
    def create_figure(self,
                      frameno=0,
                      binning=1,
                      dpi=None,
                      stretch='log',
                      vmin=1,
                      vmax=5000,
                      cmap='gray',
                      data_col='FLUX',
                      annotate=True,
                      time_format='ut',
                      show_flags=False,
                      label=None):
        """Build a matplotlib Figure that visualizes one frame.

        Parameters
        ----------
        frameno : int
            Image number in the target pixel file.

        binning : int
            Number of frames around `frameno` to co-add. (default: 1).

        dpi : float, optional [dots per inch]
            Resolution of the output in dots per Kepler CCD pixel.
            By default the dpi is chosen such that the image is 440px wide.

        stretch : str, optional
            One of 'linear', 'sqrt', 'power', 'log' (default) or 'asinh'.

        vmin : float, optional
            Minimum cut level (default: 1).

        vmax : float, optional
            Maximum cut level (default: 5000).

        cmap : str, optional
            The matplotlib color map name.  The default is 'gray',
            can also be e.g. 'gist_heat'.

        data_col : str, optional
            Forwarded to `flux_binned` to select the data to show
            (default: 'FLUX').

        annotate : boolean, optional
            Annotate the Figure with a timestamp and target name?
            (Default: `True`.)

        time_format : str, optional
            Forwarded to `timestamp` (default: 'ut').

        show_flags : boolean, optional
            Show the quality flags?  Only has an effect when `annotate`
            is true. (Default: `False`.)

        label : str
            Label text to show in the bottom left corner of the movie.
            Defaults to `self.objectname`.

        Returns
        -------
        fig : Figure
            The figure containing the rendered frame.

        Raises
        ------
        ValueError
            If `stretch` is not one of the supported names.
        """
        # Flux image to display
        flx = self.flux_binned(frameno=frameno,
                               binning=binning,
                               data_col=data_col)

        # Figure size in "inches" is (columns, rows) of the flux array, so
        # that one CCD pixel spans `dpi` output pixels.
        figsize = [flx.shape[1], flx.shape[0]]
        if dpi is None:
            # Twitter timeline requires dimensions between 440x220 and
            # 1024x512, hence the 440 px default width.
            dpi = 440 / float(figsize[0])

        # libx264 needs even pixel dimensions; trim the first figure
        # dimension so that its pixel size is a multiple of two.
        # NOTE(review): this adjusts figsize[0] (the width), while the
        # original comment spoke of the height -- confirm which is intended.
        figsize[0] -= ((figsize[0] * dpi) % 2) / dpi

        fig = pl.figure(figsize=figsize, dpi=dpi)
        ax = fig.add_subplot(1, 1, 1)
        if self.verbose:
            print('{} vmin/vmax = {}/{} (median={})'.format(
                data_col, vmin, vmax, np.nanmedian(flx)))

        if stretch == 'linear':
            stretch_fn = visualization.LinearStretch()
        elif stretch == 'sqrt':
            stretch_fn = visualization.SqrtStretch()
        elif stretch == 'power':
            stretch_fn = visualization.PowerStretch(1.0)
        elif stretch == 'log':
            stretch_fn = visualization.LogStretch()
        elif stretch == 'asinh':
            stretch_fn = visualization.AsinhStretch(0.1)
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        # Apply cut levels and stretch, then scale to 0..255 colour indices.
        to_unit_range = (stretch_fn +
                         visualization.ManualInterval(vmin=vmin, vmax=vmax))
        pixels = 255 * to_unit_range(flx)
        # Non-finite values cannot be cast to int; show them as 0.
        pixels[~np.isfinite(pixels)] = 0
        ax.imshow(pixels.astype(int),
                  aspect='auto',
                  origin='lower',
                  interpolation='nearest',
                  cmap=cmap,
                  norm=NoNorm())

        if annotate:
            fontsize = 3. * figsize[0]
            margin = 0.03

            def outlined_text(x, y, text, **text_kwargs):
                # White monospace text in axes coordinates with a black
                # outline so it stays readable on any background.
                artist = ax.text(x,
                                 y,
                                 text,
                                 family="monospace",
                                 color='white',
                                 transform=ax.transAxes,
                                 **text_kwargs)
                artist.set_path_effects([
                    path_effects.Stroke(linewidth=fontsize / 6.,
                                        foreground='black'),
                    path_effects.Normal()
                ])

            # Target name in the lower left corner
            if label is None:
                label = self.objectname
            outlined_text(margin, margin, label, fontsize=fontsize)

            # Time string in the lower right corner
            outlined_text(1 - margin,
                          margin,
                          self.timestamp(frameno, time_format=time_format),
                          fontsize=fontsize,
                          ha='right')

            # Quality flags on a red background in the upper left corner
            if show_flags:
                flags = self.quality_flags(frameno)
                if len(flags) > 0:
                    outlined_text(margin,
                                  1 - margin,
                                  '\n'.join(flags),
                                  fontsize=fontsize * 1.3,
                                  ha='left',
                                  va='top',
                                  linespacing=1.5,
                                  backgroundcolor='red')

        # Strip all axis decoration and let the image fill the figure.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
        fig.canvas.draw()
        return fig
Esempio n. 6
0
 def set_normalization(self,
                       stretch=None,
                       interval=None,
                       stretchkwargs=None,
                       intervalkwargs=None,
                       perm_linear=None):
     """Set the stretch and interval and, if an image exists, its norm.

     String names are resolved to ``astropy.visualization`` objects once
     ``self.data`` is available; until then the names and the keyword
     dictionaries are only stored on the instance.

     Parameters
     ----------
     stretch : str or stretch object, optional
         Falls back to ``self.stretch`` and then to ``'linear'``.
     interval : str or interval object, optional
         One of 'minmax', 'manual', 'percentile', 'asymetric', 'zscale',
         or an interval object. Falls back to ``self.interval`` and then
         to ``'zscale'``.
     stretchkwargs, intervalkwargs : dict, optional
         Keyword arguments merged by ``self.prepare_kwargs`` with the
         defaults and the stored values. `None` means an empty dict.
     perm_linear : dict, optional
         Keyword arguments of a ``LinearStretch`` that is composed with
         the named stretch.

     Raises
     ------
     ValueError
         If a stretch or interval name is not handled.
     """
     # Fresh dictionaries per call: these may be stored on the instance
     # below, so a shared mutable default would be aliased across objects.
     if stretchkwargs is None:
         stretchkwargs = {}
     if intervalkwargs is None:
         intervalkwargs = {}

     if stretch is None:
         if self.stretch is None:
             stretch = 'linear'
         else:
             stretch = self.stretch
     if isinstance(stretch, str):
         print(stretch,
               ' '.join([f'{k}={v}' for k, v in stretchkwargs.items()]))
         if self.data is None:  #can not calculate objects yet
             self.stretch_kwargs = stretchkwargs
         else:
             kwargs = self.prepare_kwargs(
                 self.stretch_kws_defaults[stretch], self.stretch_kwargs,
                 stretchkwargs)
             # NOTE(review): as written, every non-linear stretch requires
             # perm_linear, and 'linear' is only accepted without it; any
             # other combination raises "Unknown stretch" -- confirm that
             # this is intended.
             if perm_linear is not None:
                 perm_linear_kwargs = self.prepare_kwargs(
                     self.stretch_kws_defaults['linear'], perm_linear)
                 print(
                     'linear', ' '.join([
                         f'{k}={v}' for k, v in perm_linear_kwargs.items()
                     ]))
                 if stretch == 'asinh':  # arg: a=0.1
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.AsinhStretch(**kwargs))
                 elif stretch == 'contrastbias':  # args: contrast, bias
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.ContrastBiasStretch(**kwargs))
                 elif stretch == 'histogram':
                     stretch = vis.CompositeStretch(
                         vis.HistEqStretch(self.data, **kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'log':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LogStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'powerdist':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.PowerDistStretch(**kwargs))
                 elif stretch == 'power':  # args: a
                     stretch = vis.CompositeStretch(
                         vis.PowerStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'sinh':  # args: a=0.33
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SinhStretch(**kwargs))
                 elif stretch == 'sqrt':
                     stretch = vis.CompositeStretch(
                         vis.SqrtStretch(),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'square':
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SquaredStretch())
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
             else:
                 if stretch == 'linear':  # args: slope=1, intercept=0
                     stretch = vis.LinearStretch(**kwargs)
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
     self.stretch = stretch
     if interval is None:
         if self.interval is None:
             interval = 'zscale'
         else:
             interval = self.interval
     if isinstance(interval, str):
         print(interval,
               ' '.join([f'{k}={v}' for k, v in intervalkwargs.items()]))
         kwargs = self.prepare_kwargs(self.interval_kws_defaults[interval],
                                      self.interval_kwargs, intervalkwargs)
         if self.data is None:
             # Objects cannot be built yet; remember the arguments only.
             self.interval_kwargs = intervalkwargs
         else:
             if interval == 'minmax':
                 interval = vis.MinMaxInterval()
             elif interval == 'manual':  # args: vmin, vmax
                 interval = vis.ManualInterval(**kwargs)
             elif interval == 'percentile':  # args: percentile, n_samples
                 interval = vis.PercentileInterval(**kwargs)
             elif interval == 'asymetric':  # args: lower_percentile, upper_percentile, n_samples
                 interval = vis.AsymmetricPercentileInterval(**kwargs)
             elif interval == 'zscale':  # args: nsamples=1000, contrast=0.25, max_reject=0.5, min_npixels=5, krej=2.5, max_iterations=5
                 interval = vis.ZScaleInterval(**kwargs)
             else:
                 raise ValueError('Unknown interval:' + interval)
     self.interval = interval
     # Push the new normalization to the displayed image, if there is one.
     if self.img is not None:
         self.img.set_norm(
             vis.ImageNormalize(self.data,
                                interval=self.interval,
                                stretch=self.stretch,
                                clip=True))
Esempio n. 7
0
def get_image(data, fmt='JPEG', norm='percentile', lo=None, hi=None,
              zcontrast=0.25, nsamples=1000, krej=2.5, max_iterations=5,
              stretch='linear', a=None, bias=0.5, contrast=1, cmap=None,
              dpi=100, **kwargs):
    u"""
    Return a byte array containing image in the given format

    Image scaling is done using `~astropy.visualization`. It includes
    normalization of the input data (mapping to [0, 1]) and stretching -
    optional non-linear mapping [0, 1] -> [0, 1] for contrast enhancement.
    A colormap can be applied to the normalized data. Conversion to the target
    image format is done by matplotlib or Pillow.

    If the resulting clipping limits coincide (e.g. a constant image), the
    data are mapped to 0 for values not exceeding the limit and to 1 above it
    instead of dividing by zero.

    :param array_like data: input 2D image data
    :param str fmt: output image format
    :param str norm: data normalization mode::
        "manual": lower and higher clipping limits are set explicitly
        "minmax": limits are set to the minimum and maximum data values
        "percentile" (default): limits are set based on the specified fraction
            of pixels
        "zscale": use IRAF ZScale algorithm
    :param int | float lo::
        for ``norm`` == "manual", lower data limit
        for ``norm`` == "percentile", lower percentile clipping value,
            defaulting to 10
        for ``norm`` == "zscale", lower limit on the number of rejected pixels,
            defaulting to 5
    :param int | float hi::
        for ``norm`` == "manual", upper data limit
        for ``norm`` == "percentile", upper percentile clipping value,
            defaulting to 98
        for ``norm`` == "zscale", upper limit on the number of rejected pixels,
            defaulting to data.size/2
    :param float zcontrast: for ``norm`` == "zscale", the scaling factor,
        0 < zcontrast < 1, defaulting to 0.25
    :param int nsamples: for ``norm`` == "zscale", the number of points in
        the input array for determining scale factors, defaulting to 1000
    :param float krej: for ``norm`` == "zscale", the sigma clipping factor,
        defaulting to 2.5
    :param int max_iterations: for ``norm`` == "zscale", the maximum number
        of rejection iterations, defaulting to 5
    :param str stretch: [0, 1] → [0, 1] mapping mode::
        "asinh": hyperbolic arcsine stretch y = asinh(x/a)/asinh(1/a)
        "contrast": linear bias/contrast-based stretch
            y = (x - bias)*contrast + 0.5
        "exp": exponential stretch y = (a^x - 1)/(a - 1)
        "histeq": histogram equalization stretch
        "linear" (default): direct mapping
        "log": logarithmic stretch y = log(ax + 1)/log(a + 1)
        "power": power stretch y = x^a
        "sinh": hyperbolic sine stretch y = sinh(x/a)/sinh(1/a)
        "sqrt": square root stretch y = √x
        "square": power stretch y = x^2
    :param float a: non-linear stretch parameter::
        for ``stretch`` == "asinh", the point of transition from linear to
            logarithmic behavior, 0 < a <= 1, defaulting to 0.1
        for ``stretch`` == "exp", base of the exponent, a != 1, defaulting to
            1000
        for ``stretch`` == "log", base of the logarithm minus 1, a > 0,
            defaulting to 1000
        for ``stretch`` == "power", the power index, defaulting to 3
        for ``stretch`` == "sinh", a > 0, defaulting to 1/3
    :param float bias: for ``stretch`` == "contrast", the bias parameter,
        defaulting to 0.5
    :param float contrast: for ``stretch`` == "contrast", the contrast
        parameter, defaulting to 1
    :param str cmap: optional matplotlib colormap name, defaulting
        to grayscale; when a non-grayscale colormap is specified,
        the conversion is always done by matplotlib, regardless of the
        availability of Pillow; see https://matplotlib.org/users/colormaps.html
        for more info on matplotlib colormaps and
            [name for name in matplotlib.pyplot.colormaps()
             if not name.endswith('_r')]
        to list the available colormap names
    :param int dpi: target image resolution in dots per inch
    :param kwargs: optional format-specific keyword arguments passed to Pillow,
        e.g. "quality" for JPEG; see
        `https://pillow.readthedocs.io/en/stable/handbook/
        image-file-formats.html`_

    :return: a bytes object containing the image in the given format
    :rtype: bytes
    :raises ValueError: for an unknown ``norm`` or ``stretch`` mode, missing
        manual limits, or percentiles that are out of range or out of order
    """
    data = asanyarray(data)

    # Normalize image data
    if norm == 'manual':
        if lo is None:
            raise ValueError(
                'Missing lower clipping boundary for norm="manual"')
        if hi is None:
            raise ValueError(
                'Missing upper clipping boundary for norm="manual"')
    elif norm == 'minmax':
        lo, hi = data.min(), data.max()
    elif norm == 'percentile':
        if lo is None:
            lo = 10
        elif not 0 <= lo <= 100:
            raise ValueError(
                'Lower clipping percentile must be in the [0,100] range')
        if hi is None:
            hi = 98
        elif not 0 <= hi <= 100:
            raise ValueError(
                'Upper clipping percentile must be in the [0,100] range')
        if hi < lo:
            raise ValueError(
                'Upper clipping percentile must be greater or equal to '
                'lower percentile')
        lo, hi = percentile(data, [lo, hi])
    elif norm == 'zscale':
        if lo is None:
            lo = 5
        if hi is None:
            hi = 0.5
        else:
            # ZScaleInterval expects the rejection limit as a fraction
            hi /= data.size
        # Positional order: nsamples, contrast, max_reject, min_npixels,
        # krej, max_iterations
        lo, hi = apy_vis.ZScaleInterval(
            nsamples, zcontrast, hi, lo, krej, max_iterations).get_limits(data)
    else:
        raise ValueError('Unknown normalization mode "{}"'.format(norm))

    if hi == lo:
        # Degenerate range (e.g. constant image): dividing by zero would
        # yield NaN, which cannot be converted to pixel values. Keep the
        # clipping semantics instead: 0 up to the limit, 1 above it.
        data = (data > hi).astype(float)
    else:
        data = clip((data - lo)/(hi - lo), 0, 1)

    # Stretch the data
    if stretch == 'asinh':
        if a is None:
            a = 0.1
        apy_vis.AsinhStretch(a)(data, out=data)
    elif stretch == 'contrast':
        # The default bias/contrast pair is the identity mapping
        if bias != 0.5 or contrast != 1:
            apy_vis.ContrastBiasStretch(contrast, bias)(data, out=data)
    elif stretch == 'exp':
        if a is None:
            a = 1000
        apy_vis.PowerDistStretch(a)(data, out=data)
    elif stretch == 'histeq':
        apy_vis.HistEqStretch(data)(data, out=data)
    elif stretch == 'linear':
        pass
    elif stretch == 'log':
        if a is None:
            a = 1000
        apy_vis.LogStretch(a)(data, out=data)
    elif stretch == 'power':
        if a is None:
            a = 3
        apy_vis.PowerStretch(a)(data, out=data)
    elif stretch == 'sinh':
        if a is None:
            a = 1/3
        apy_vis.SinhStretch(a)(data, out=data)
    elif stretch == 'sqrt':
        apy_vis.SqrtStretch()(data, out=data)
    elif stretch == 'square':
        apy_vis.SquaredStretch()(data, out=data)
    else:
        raise ValueError('Unknown stretch mode "{}"'.format(stretch))

    buf = BytesIO()
    try:
        # Choose the backend for making an image
        if cmap is None:
            cmap = 'gray'
        if cmap == 'gray':
            try:
                # noinspection PyPackageRequirements,PyPep8Naming
                from PIL import Image as pil_image
            except ImportError:
                pil_image = None
        else:
            pil_image = None

        if pil_image is not None:
            # Use Pillow for grayscale output if available; flip the image to
            # match the bottom-to-top FITS convention and convert from [0,1] to
            # unsigned byte
            pil_image.fromarray(
                (data[::-1]*255 + 0.5).astype(uint8),
            ).save(buf, fmt, dpi=(dpi, dpi), **kwargs)
        else:
            # Use matplotlib for non-grayscale colormaps or if PIL is not
            # available
            # noinspection PyPackageRequirements
            from matplotlib import image as mpl_image
            if fmt.lower() == 'png':
                # PNG images are saved upside down by matplotlib, regardless of
                # the origin parameter
                data = data[::-1]
            # noinspection PyTypeChecker
            mpl_image.imsave(
                buf, data, cmap=cmap, format=fmt, origin='lower', dpi=dpi)

        return buf.getvalue()
    finally:
        buf.close()
Esempio n. 8
0
    def create_figure(self,
                      output_filename,
                      survey,
                      stretch='log',
                      vmin=1,
                      vmax=None,
                      min_percent=1,
                      max_percent=95,
                      cmap='gray',
                      contour_color='red',
                      data_col='FLUX'):
        """Plot the target pixel flux with survey contours and save it.

        The binned flux of ``self.TPF`` is drawn on WCS axes, contours of
        an image fetched with ``surveyquery.getSVImg`` are overlaid, and
        the figure is written to `output_filename` at 300 dpi.

        Parameters
        ----------
        output_filename : str
            Destination passed to ``plt.savefig``.

        survey : str
            Passed to ``surveyquery.getSVImg`` to select the overlay image.

        stretch : str, optional
            One of 'linear', 'sqrt', 'power', 'log' (default) or 'asinh'.

        vmin : float, optional
            Minimum cut level (default: 1).  Replaced by the value from
            ``self.cut_levels`` whenever `vmax` is `None`.

        vmax : float, optional
            Maximum cut level.  If `None` (default), both cut levels are
            taken from ``self.cut_levels(min_percent, max_percent,
            data_col)``.

        min_percent, max_percent : float, optional
            Passed to ``self.cut_levels`` (defaults: 1 and 95).

        cmap : str, optional
            The matplotlib color map name.  The default is 'gray',
            can also be e.g. 'gist_heat'.

        contour_color : str, optional
            Colour of the survey contours (default: 'red').

        data_col : str, optional
            Passed to ``self.cut_levels`` and used in the verbose message
            (default: 'FLUX').

        Returns
        -------
        fig : Figure
            The figure that was saved.

        Raises
        ------
        ValueError
            If `stretch` is not one of the supported names.
        """
        # Get the flux data to visualize
        flx = self.TPF.flux_binned()

        # Calculate cut levels from the data when no maximum is given
        if vmax is None:
            vmin, vmax = self.cut_levels(min_percent, max_percent, data_col)

        # The figure size is taken directly from the flux array shape
        shape = list(flx.shape)
        fig = plt.figure(figsize=shape)

        # WCS axes, so the survey contours can be placed by sky position
        ax = plt.subplot(projection=self.TPF.wcs)
        ax.set_xlabel('RA')
        ax.set_ylabel('Dec')

        if self.verbose:
            print('{} vmin/vmax = {}/{} (median={})'.format(
                data_col, vmin, vmax, np.nanmedian(flx)))

        if stretch == 'linear':
            stretch_fn = visualization.LinearStretch()
        elif stretch == 'sqrt':
            stretch_fn = visualization.SqrtStretch()
        elif stretch == 'power':
            stretch_fn = visualization.PowerStretch(1.0)
        elif stretch == 'log':
            stretch_fn = visualization.LogStretch()
        elif stretch == 'asinh':
            stretch_fn = visualization.AsinhStretch(0.1)
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        transform = (stretch_fn +
                     visualization.ManualInterval(vmin=vmin, vmax=vmax))
        flx_transform = 255 * transform(flx)
        # Non-finite values cannot be cast to int; show them as 0.
        flx_transform[~np.isfinite(flx_transform)] = 0
        ax.imshow(flx_transform.astype(int),
                  aspect='auto',
                  origin='lower',
                  interpolation='nearest',
                  cmap=cmap,
                  norm=NoNorm())
        ax.set_xticks([])
        ax.set_yticks([])

        # Remember the view of the flux image: drawing the contours may
        # extend the axis limits.
        current_ylims = ax.get_ylim()
        current_xlims = ax.get_xlim()

        pixels, header = surveyquery.getSVImg(self.TPF.position, survey)
        # Ten levels from the minimum to the 95th percentile; the NaN-aware
        # functions keep the levels finite if the image has blank pixels.
        levels = np.linspace(np.nanmin(pixels),
                             np.nanpercentile(pixels, 95), 10)
        ax.contour(pixels,
                   transform=ax.get_transform(WCS(header)),
                   levels=levels,
                   colors=contour_color)

        ax.set_xlim(current_xlims)
        ax.set_ylim(current_ylims)

        fig.canvas.draw()
        plt.savefig(output_filename, bbox_inches='tight', dpi=300)
        return fig