예제 #1
0
def main():
    """Plot the first extension of a stacked frame on WCS axes and save a PNG.

    The file name is obtained from ``get_stack_filename()``.  Pixels selected
    by the extension mask are hidden (and drawn in gray), the colour limits
    come from a ZScale interval with a linear stretch, and the figure is
    saved to the input path with ``.fits`` replaced by ``.png`` before being
    shown interactively.
    """
    stack_path = get_stack_filename()
    first_ext = astrodata.open(stack_path)[0]

    pixels = first_ext.data
    bad_pixels = first_ext.mask
    wcs_header = first_ext.hdr
    image = np.ma.masked_where(bad_pixels, pixels, copy=True)

    # Work on a copy so the registered viridis colormap is left untouched.
    cmap = copy(plt.cm.viridis)
    cmap.set_bad('gray')

    # The normalization object is only used as a source of vmin / vmax.
    scaling = visualization.ImageNormalize(
        image,
        interval=visualization.ZScaleInterval(),
        stretch=visualization.LinearStretch(),
    )

    projection = wcs.WCS(wcs_header)
    fig, ax = plt.subplots(subplot_kw=dict(projection=projection))
    ax.imshow(image, cmap=cmap, vmin=scaling.vmin, vmax=scaling.vmax)

    ax.set_title(os.path.basename(stack_path))
    for clear_labels in (ax.set_xticklabels, ax.set_yticklabels):
        clear_labels([])

    png_path = stack_path.replace('.fits', '.png')
    fig.savefig(png_path)
    plt.show()
예제 #2
0
def plotcontrast(img):
    """
    Show ``img`` under a set of Matplotlib/AstroPy normalizations.

    Twelve panels are drawn on a 4x3 grid: no normalization, matplotlib's
    ``LogNorm`` and ten AstroPy stretches, each titled with the name of the
    normalization used.

    http://docs.astropy.org/en/stable/visualization/index.html
    """
    candidates = (None, LogNorm(), vis.AsinhStretch(),
                  vis.ContrastBiasStretch(1, 0.5), vis.HistEqStretch(img),
                  vis.LinearStretch(), vis.LogStretch(),
                  vis.PowerDistStretch(a=10.), vis.PowerStretch(a=10.),
                  vis.SinhStretch(), vis.SqrtStretch(), vis.SquaredStretch())

    fg, grid = subplots(4, 3, figsize=(10, 10))

    for panel, candidate in zip(grid.ravel(), candidates):
        if not candidate or isinstance(candidate, LogNorm):
            # None and LogNorm are handed to imshow unchanged.
            norm = candidate
            title = str(candidate).split('.')[-1].split(" ")[0]
        else:
            # AstroPy stretches must be wrapped in an ImageNormalize.
            norm = ImageNormalize(stretch=candidate)
            title = type(candidate).__name__

        panel.set_title(title)
        panel.imshow(img, origin='lower', cmap='gray', norm=norm)
        panel.axis('off')

    fg.suptitle('Matplotlib/AstroPy normalizations')
예제 #3
0
파일: core.py 프로젝트: vorugantia/gammapy
    def plot_norm(self,
                  stretch='linear',
                  power=1.0,
                  asinh_a=0.1,
                  min_cut=None,
                  max_cut=None,
                  min_percent=None,
                  max_percent=None,
                  percent=None,
                  clip=True):
        """Build a matplotlib normalization object for displaying ``self.data``.

        This is a copy of this function that will be available in Astropy 1.3:
        `astropy.visualization.mpl_normalize.simple_norm`

        See the parameter description there!

        The cut levels are chosen from the first option that is set, in this
        order: ``percent``, ``min_percent`` / ``max_percent`` (a missing one
        falls back to 0 / 100), ``min_cut`` / ``max_cut``, and finally the
        data minimum and maximum.

        ``stretch`` must be one of ``'linear'``, ``'sqrt'``, ``'power'``
        (exponent ``power``), ``'log'`` or ``'asinh'`` (parameter
        ``asinh_a``); anything else raises `ValueError`.

        Examples
        --------
        >>> image = SkyImage()
        >>> norm = image.plot_norm(stretch='sqrt', max_percent=99)
        >>> image.plot(norm=norm)
        """
        import astropy.visualization as v
        from astropy.visualization.mpl_normalize import ImageNormalize

        # --- interval: the first matching option wins -----------------------
        percent_pair_given = not (min_percent is None and max_percent is None)
        cut_pair_given = not (min_cut is None and max_cut is None)

        if percent is not None:
            interval = v.PercentileInterval(percent)
        elif percent_pair_given:
            lower = min_percent or 0.
            upper = max_percent or 100.
            interval = v.AsymmetricPercentileInterval(lower, upper)
        elif cut_pair_given:
            interval = v.ManualInterval(min_cut, max_cut)
        else:
            interval = v.MinMaxInterval()

        # --- stretch: look the name up in a table of factories --------------
        factories = (
            ('linear', v.LinearStretch),
            ('sqrt', v.SqrtStretch),
            ('power', lambda: v.PowerStretch(power)),
            ('log', v.LogStretch),
            ('asinh', lambda: v.AsinhStretch(asinh_a)),
        )
        for name, make_stretch in factories:
            if stretch == name:
                stretch_obj = make_stretch()
                break
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        lower_limit, upper_limit = interval.get_limits(self.data)

        return ImageNormalize(vmin=lower_limit,
                              vmax=upper_limit,
                              stretch=stretch_obj,
                              clip=clip)
예제 #4
0
def main():
    """Plot the first extension of a FITS file on WCS axes and save a PNG.

    The file name and the ``mask`` switch come from ``_parse_args()``.  When
    ``args.mask`` is set, pixels selected by the extension mask are hidden
    and drawn in the colormap's "bad" colour.  The PNG is written to the
    current directory under the input's base name with ``.fits`` replaced by
    ``.png``, and the figure is then shown interactively.
    """
    args = _parse_args()
    filename = args.filename

    first_ext = astrodata.open(filename)[0]
    data = first_ext.data
    mask = first_ext.mask
    header = first_ext.hdr

    masked_data = (np.ma.masked_where(mask, data, copy=True)
                   if args.mask else data)

    # Copy the colormap so the registered viridis instance is not modified.
    palette = copy(plt.cm.viridis)
    palette.set_bad('Gainsboro')

    norm_factor = visualization.ImageNormalize(
        masked_data,
        interval=visualization.ZScaleInterval(),
        stretch=visualization.LinearStretch(),
    )

    fig = plt.figure(num=filename)
    ax = fig.subplots(subplot_kw=dict(projection=wcs.WCS(header)))

    # NOTE(review): the ZScale limits are only printed; the image itself is
    # drawn with the fixed 750-900 range below -- confirm this is intended.
    for limit in (norm_factor.vmin, norm_factor.vmax):
        print(limit)

    display_range = dict(vmin=750., vmax=900.)
    ax.imshow(masked_data, cmap=palette, origin='lower', **display_range)

    ax.set_title(os.path.basename(filename))

    ra_axis = ax.coords[0]
    ra_axis.set_axislabel('Right Ascension')
    ra_axis.set_ticklabel(fontsize='small')

    dec_axis = ax.coords[1]
    dec_axis.set_axislabel('Declination')
    dec_axis.set_ticklabel(rotation='vertical', fontsize='small')

    fig.tight_layout(rect=[0.05, 0, 1, 1])

    png_name = os.path.basename(filename.replace('.fits', '.png'))
    fig.savefig(png_name)
    plt.show()
예제 #5
0
def get_im_stretch(stretch):
    ''' Map a short stretch name to an Astropy stretch object for ImageNormalize.

   :param stretch: short name of the wanted stretch, either 'arcsinh' or 'linear'.
   :type stretch: string

   :return: an ``AsinhStretch`` for 'arcsinh', a ``LinearStretch`` for 'linear'.
   :rtype: :class:`astropy.visualization.stretch`

   :raises Exception: if the name is not one of the two supported values.

   '''

    if stretch == 'arcsinh':
        chosen = astrovis.AsinhStretch()
    elif stretch == 'linear':
        chosen = astrovis.LinearStretch()
    else:
        raise Exception('Ouch! Stretch %s unknown.' % (stretch))

    return chosen
예제 #6
0
def main():
    """Show every extension of a stacked frame side by side and save a PNG.

    The file name comes from ``get_stack_filename()``.  One panel is drawn
    per extension of the opened file, labelled ``d01``, ``d02``, ... on the
    x axis, with a shared y axis and no ticks.  The PNG is written to the
    current directory under the input's base name with ``.fits`` replaced by
    ``.png``, and the figure is then shown interactively.
    """
    filename = get_stack_filename()

    ad = astrodata.open(filename)

    fig = plt.figure(num=filename, figsize=(7, 4.5))
    fig.suptitle(os.path.basename(filename), y=0.97)

    # squeeze=False always returns a 2-D array of axes; without it a file
    # with a single extension yields a bare Axes that cannot be indexed.
    axs = fig.subplots(1, len(ad), sharey=True, squeeze=False)[0]

    # Copy the colormap so the registered viridis instance is not modified.
    palette = copy(plt.cm.viridis)
    palette.set_bad("Gainsboro", 1.0)

    norm = visualization.ImageNormalize(
        np.dstack([ext.data for ext in ad]),
        stretch=visualization.LinearStretch(),
        interval=visualization.ZScaleInterval()
    )

    # NOTE(review): the ZScale limits are only printed; the panels are drawn
    # with the fixed 750-900 range below -- confirm this is intended.
    print(norm.vmin)
    print(norm.vmax)

    for number, (ax, ext) in enumerate(zip(axs, ad), start=1):
        ax.imshow(
            ext.data,
            norm=colors.Normalize(vmin=750, vmax=900),
            origin="lower",
            cmap=palette,
        )
        ax.set_xlabel('d{:02d}'.format(number))
        ax.set_xticks([])

    # Clear the y ticks through the last panel (the panels share their
    # y axis), without depending on a loop variable leaking out of the loop.
    axs[-1].set_yticks([])

    fig.tight_layout(rect=[0, 0, 1, 1], w_pad=0.05)

    fig.savefig(os.path.basename(filename.replace('.fits', '.png')))
    plt.show()
예제 #7
0
파일: core.py 프로젝트: keatonb/k2flix
    def create_figure(self,
                      frameno=0,
                      binning=1,
                      dpi=None,
                      stretch='log',
                      vmin=1,
                      vmax=5000,
                      cmap='gray',
                      data_col='FLUX',
                      annotate=True,
                      time_format='ut',
                      show_flags=False,
                      label=None):
        """Build a matplotlib Figure that visualizes one frame.

        Parameters
        ----------
        frameno : int
            Image number in the target pixel file.

        binning : int
            Number of frames around `frameno` to co-add. (default: 1).

        dpi : float, optional [dots per inch]
            Resolution of the output in dots per Kepler CCD pixel.
            By default the dpi is chosen such that the image is 440px wide.

        stretch : str, optional
            Name of the intensity stretch: 'linear', 'sqrt', 'power',
            'log' or 'asinh' (default: 'log').

        vmin : float, optional
            Minimum cut level (default: 1).

        vmax : float, optional
            Maximum cut level (default: 5000).

        cmap : str, optional
            The matplotlib color map name.  The default is 'gray',
            can also be e.g. 'gist_heat'.

        data_col : str, optional
            Forwarded to `flux_binned` to select the data to show
            (default: 'FLUX').

        annotate : boolean, optional
            Annotate the Figure with a timestamp and target name?
            (Default: `True`.)

        time_format : str, optional
            Forwarded to `timestamp` when annotating (default: 'ut').

        show_flags : boolean, optional
            Show the quality flags?  Only has an effect when `annotate`
            is true.  (Default: `False`.)

        label : str
            Label text to show in the bottom left corner of the movie.
            Falls back to ``self.objectname`` when `None`.

        Returns
        -------
        fig : matplotlib Figure
            The figure, already drawn on its canvas.

        Raises
        ------
        ValueError
            If `stretch` is not one of the supported names.
        """
        # Flux data to visualize
        flux = self.flux_binned(frameno=frameno,
                                binning=binning,
                                data_col=data_col)

        # Figure size in "inches" is (columns, rows) of the flux array, so
        # that dpi is the number of output pixels per data pixel.
        size = [flux.shape[1], flux.shape[0]]
        if dpi is None:
            # Twitter timeline requires dimensions between 440x220 and 1024x512
            # so we make 440 the default
            dpi = 440 / float(size[0])

        # libx264 requires even pixel dimensions; trim the first figsize
        # entry so that size[0] * dpi is divisible by 2.
        # NOTE(review): the original comment spoke of the height, but the
        # entry adjusted here is the first (width) one -- confirm.
        size[0] -= ((size[0] * dpi) % 2) / dpi

        fig = pl.figure(figsize=size, dpi=dpi)
        ax = fig.add_subplot(1, 1, 1)
        if self.verbose:
            print('{} vmin/vmax = {}/{} (median={})'.format(
                data_col, vmin, vmax, np.nanmedian(flux)))

        # Pick the stretch by name; only the selected one is instantiated.
        stretch_table = (
            ('linear', visualization.LinearStretch),
            ('sqrt', visualization.SqrtStretch),
            ('power', lambda: visualization.PowerStretch(1.0)),
            ('log', visualization.LogStretch),
            ('asinh', lambda: visualization.AsinhStretch(0.1)),
        )
        for name, factory in stretch_table:
            if stretch == name:
                stretch_fn = factory()
                break
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        cuts = visualization.ManualInterval(vmin=vmin, vmax=vmax)
        transform = stretch_fn + cuts
        scaled = 255 * transform(flux)
        # Make sure to remove all NaNs!
        scaled[np.logical_not(np.isfinite(scaled))] = 0
        # The values are already scaled to 0-255, hence NoNorm.
        ax.imshow(scaled.astype(int),
                  aspect='auto',
                  origin='lower',
                  interpolation='nearest',
                  cmap=cmap,
                  norm=NoNorm())

        if annotate:
            fontsize = 3. * size[0]
            margin = 0.03

            def outlined_text(x, y, text, text_size, **extra):
                # White monospace text in axes coordinates with a black
                # outline so it stays readable on any background.
                artist = ax.text(x,
                                 y,
                                 text,
                                 family="monospace",
                                 fontsize=text_size,
                                 color='white',
                                 transform=ax.transAxes,
                                 **extra)
                artist.set_path_effects([
                    path_effects.Stroke(linewidth=fontsize / 6.,
                                        foreground='black'),
                    path_effects.Normal()
                ])
                return artist

            # Target name in the lower left corner
            if label is None:
                label = self.objectname
            outlined_text(margin, margin, label, fontsize)

            # Time string in the lower right corner
            outlined_text(1 - margin,
                          margin,
                          self.timestamp(frameno, time_format=time_format),
                          fontsize,
                          ha='right')

            # Quality flags in the upper left corner
            if show_flags:
                flags = self.quality_flags(frameno)
                if len(flags) > 0:
                    outlined_text(margin,
                                  1 - margin,
                                  '\n'.join(flags),
                                  fontsize * 1.3,
                                  ha='left',
                                  va='top',
                                  linespacing=1.5,
                                  backgroundcolor='red')

        # Strip all axis decoration and let the image fill the whole figure.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
        fig.canvas.draw()
        return fig
예제 #8
0
                        interval=v.ManualInterval(vmin=image.min() - 5,
                                                  vmax=image.max() + 10),
                        stretch=v.LogStretch(10))
im = plt.imshow(image, origin='lower', norm=norm, cmap='Greys')
plt.colorbar(im)

plt.subplot(1, 2, 2)
plt.title('Nebular Emission Mask')
mimage = np.ma.MaskedArray(image)
mimage.mask = ~nmask
mimagef = np.ma.filled(mimage, fill_value=0)
norm = v.ImageNormalize(mimagef,
                        interval=v.ManualInterval(
                            vmin=image.min() - 5,
                            vmax=np.percentile(image, mask_pcnt) + 5),
                        stretch=v.LinearStretch())
im = plt.imshow(mimagef, origin='lower', norm=norm, cmap='Greys')
plt.colorbar(im)

plt.savefig('H-beta Image.png', bbox_inches='tight', pad_inches=0.10)

##-------------------------------------------------------------------------
## Model and Plot Nebular Background
log.info('Model and plot nebular background')
background_0 = models.Polynomial1D(degree=2)
H_beta_0 = models.Gaussian1D(amplitude=500,
                             mean=4861,
                             stddev=1.,
                             bounds={
                                 'mean': (4855, 4865),
                                 'stddev': (0.1, 5)
예제 #9
0
def show_image(image,
               percl=99,
               percu=None,
               is_mask=False,
               figsize=(6, 10),
               cmap='viridis',
               log=False,
               show_colorbar=True,
               show_ticks=True,
               fig=None,
               ax=None,
               input_ratio=None):
    """
    Show an image in matplotlib with some basic astronomically-appropriate stretching.

    Parameters
    ----------
    image
        The image to show
    percl : number
        The percentile for the lower edge of the stretch (or both edges if ``percu`` is None)
    percu : number or None
        The percentile for the upper edge of the stretch (or None to use ``percl`` for both)
    is_mask : bool
        If True, display with fixed limits 0 and 1 instead of the percentile
        normalization.
    figsize : 2-tuple
        The size of the matplotlib figure in inches
    cmap : str
        Colormap passed to ``imshow``.
    log : bool
        Use a logarithmic stretch instead of a linear one.
    show_colorbar : bool
        Add a colorbar next to the image.
    show_ticks : bool
        If False, hide the tick labels on all four sides.
    fig, ax
        Existing figure and axes to draw into; give both or neither.
    input_ratio : number or None
        Block-reduction factor to use instead of the automatically computed
        one (a falsy value means "compute it").

    Raises
    ------
    ValueError
        If exactly one of ``fig`` and ``ax`` is provided.
    """
    if percu is None:
        # A single percentile p means the symmetric range [100 - p, p].
        percl, percu = 100 - percl, percl

    if (fig is None) != (ax is None):
        raise ValueError('Must provide both "fig" and "ax" '
                         'if you provide one of them')

    if fig is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        if figsize is not None:
            # Rescale the fig size to match the image dimensions, roughly
            # NOTE(review): this value is computed after the figure has been
            # created and is only printed, never applied to the figure.
            longest_side = max(figsize)
            aspect = image.shape[0] / image.shape[1]
            figsize = (longest_side * aspect, longest_side)
            print(figsize)

    # To preserve details we should *really* downsample correctly and
    # not rely on matplotlib to do it correctly for us (it won't).
    # So, calculate the size of the figure in pixels, block_reduce to
    # roughly that, and display the block reduced image.

    # Thanks, https://stackoverflow.com/questions/29702424/how-to-get-matplotlib-figure-size
    fig_size_pix = fig.get_size_inches() * fig.dpi

    ratio = (image.shape // fig_size_pix).max()
    if ratio < 1:
        ratio = 1
    ratio = input_ratio or ratio

    # Divide by the square of the ratio to keep the flux the same in the
    # reduced image
    reduced_data = block_reduce(image, ratio) / ratio**2

    # The reduced image has fewer pixels, so set the extent explicitly to
    # keep the axes showing the size of the original image.
    n_rows, n_cols = image.shape[0], image.shape[1]
    extent = [0, n_cols, 0, n_rows]

    stretch = aviz.LogStretch() if log else aviz.LinearStretch()
    percentile_cut = aviz.AsymmetricPercentileInterval(percl, percu)
    norm = aviz.ImageNormalize(reduced_data,
                               interval=percentile_cut,
                               stretch=stretch)

    # The image is a mask in which pixels are zero or one: use fixed limits
    # in that case, otherwise the percentile normalization.
    scale_args = dict(vmin=0, vmax=1) if is_mask else dict(norm=norm)

    im = ax.imshow(reduced_data,
                   origin='lower',
                   cmap=cmap,
                   extent=extent,
                   aspect='equal',
                   **scale_args)

    if show_colorbar:
        # I haven't a clue why the fraction and pad arguments below work to make
        # the colorbar the same height as the image, but they do....unless the image
        # is wider than it is tall. Sticking with this for now anyway...
        # Thanks: https://stackoverflow.com/a/26720422/3486425
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, format='%2.0f')
        # In case someone in the future wants to improve this:
        # https://joseph-long.com/writing/colorbars/
        # https://stackoverflow.com/a/33505522/3486425
        # https://matplotlib.org/mpl_toolkits/axes_grid/users/overview.html#colorbar-whose-height-or-width-in-sync-with-the-master-axes

    if not show_ticks:
        hidden = dict.fromkeys(
            ('labelbottom', 'labelleft', 'labelright', 'labeltop'), False)
        ax.tick_params(**hidden)
예제 #10
0
def find_bar_positions_from_image(imagefile, filtersize=5, plot=False,
                                  pixel_shim=5):
    '''Find the X pixel position of every slit bar in an image.

    Loop over slits 1-46.  For each slit, use `physical_to_pixel` to select
    the Y pixel range over which the slit should be found, shrunk by
    `pixel_shim` pixels at both ends.  Take a median filtered version of the
    image and determine the X direction gradient (derivative) inside that
    range, then collapse it in the Y direction to form a 1D profile.

    Using the `find_bar_edges` method, determine the X pixel positions of
    the two bars forming the slit.  Slits for which `find_bar_edges` fails
    are reported on stdout and left out of the result.

    NOTE(review): an earlier version of this docstring also described a
    conversion to physical coordinates and a comparison with the CSU bar
    state; neither is performed in this function.

    Parameters
    ----------
    imagefile : str or path-like
        FITS file to analyse; the data of its first HDU is used.
    filtersize : int
        Width (along X) of the median filter; Y structure is preserved.
    plot : bool
        If True, write a PNG named after the image (same directory, same
        stem) showing the image with the detected bar positions.
    pixel_shim : int
        Number of pixels trimmed from each end of a slit's Y range.

    Returns
    -------
    dict
        Maps bar number to the X pixel position returned by
        `find_bar_edges`.
    '''
    ## Get image from file
    imagefile = Path(imagefile).absolute()
    try:
        # The context manager closes the file handle; the data is accessed
        # inside the block so it is loaded before the file is closed.
        with fits.open(imagefile) as hdul:
            data = hdul[0].data
    except Exception as e:
        # Log and re-raise.  The original handler named an `Error` class
        # that is not defined in view, which would mask the real failure
        # with a NameError.
        log.error(e)
        raise
    # median X pixels only (preserve Y structure)
    medimage = ndimage.median_filter(data, size=(1, filtersize))

    bars = {}
    ypos = {}
    for slit in range(1, 47):
        b1, b2 = slit_to_bars(slit)
        ## Determine y pixel range
        y1 = int(np.ceil((physical_to_pixel(np.array([(4.0, slit+0.5)])))[0][0][1])) + pixel_shim
        y2 = int(np.floor((physical_to_pixel(np.array([(270.4, slit-0.5)])))[0][0][1])) - pixel_shim
        ypos[b1] = [y1, y2]
        ypos[b2] = [y1, y2]
        gradx = np.gradient(medimage[y1:y2, :], axis=1)
        horizontal_profile = np.sum(gradx, axis=0)
        try:
            bars[b1], bars[b2] = find_bar_edges(horizontal_profile)
        except Exception:
            # Best effort: skip this slit, but let KeyboardInterrupt and
            # SystemExit propagate (a bare ``except:`` would swallow them).
            print(f'Unable to fit bars: {b1}, {b2}')

    # Generate plot if called for
    if plot is True:
        plotfile = imagefile.with_name(f"{imagefile.stem}.png")
        log.info(f'Creating PNG image {plotfile}')
        if plotfile.exists():
            plotfile.unlink()
        plt.figure(figsize=(16, 16), dpi=300)
        norm = viz.ImageNormalize(data, interval=viz.PercentileInterval(99.9),
                                  stretch=viz.LinearStretch())
        plt.imshow(data, norm=norm, origin='lower', cmap='Greys')

        # End points (in mm) of the line drawn along each slit edge.
        mms = np.linspace(4, 270.4, 2)
        for bar in bars:
            slit = bar_to_slit(bar)
            pix = np.array([(physical_to_pixel(np.array([(mm, slit+0.5)])))[0][0] for mm in mms])
            plt.plot(pix.transpose()[0], pix.transpose()[1], 'g-', alpha=0.5)

            # Vertical red line at the detected bar position.
            plt.plot([bars[bar], bars[bar]], ypos[bar], 'r-', alpha=0.75)
            # Put the label on opposite sides for even and odd bars.
            offset = {0: -20, 1: +20}[bar % 2]
            plt.text(bars[bar]+offset, np.mean(ypos[bar]), bar,
                     fontsize=8, color='r', alpha=0.75,
                     horizontalalignment='center', verticalalignment='center')
        plt.savefig(str(plotfile), bbox_inches='tight')

    return bars
예제 #11
0
def plot_image(image,
               ax=None,
               scale='log',
               cmap=None,
               origin='lower',
               xlabel=None,
               ylabel=None,
               cbar=None,
               clabel='Flux ($e^{-}s^{-1}$)',
               cbar_ticks=None,
               cbar_ticklabels=None,
               cbar_pad=None,
               cbar_size='5%',
               title=None,
               percentile=95.0,
               vmin=None,
               vmax=None,
               offset_axes=None,
               color_bad='k',
               **kwargs):
    """
    Utility function to plot a 2D image.

    Parameters:
        image (2d array): Image data.
        ax (matplotlib.pyplot.axes, optional): Axes in which to plot.
            Default (None) is to use current active axes.
        scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional):
            Normalization used to stretch the colormap.
            Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'``
            and ``'squared'``.
            Can also be a :py:class:`astropy.visualization.ImageNormalize` or
            :py:class:`matplotlib.colors.Normalize` object, which is used as-is.
            Default is ``'log'``.
        cmap (matplotlib colormap or str, optional): Colormap to use.
            Default is the ``Blues`` colormap. The colormap is copied, so a
            colormap object passed in is not modified.
        origin (str, optional): The origin of the coordinate system.
        xlabel (str, optional): Label for the x-axis.
        ylabel (str, optional): Label for the y-axis.
        cbar (string, optional): Location of color bar.
            Choices are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``.
            Default is not to create colorbar.
        clabel (str, optional): Label for the color bar.
        cbar_ticks (list, optional): Ticks passed on to the color bar.
        cbar_ticklabels (list, optional): Tick labels passed on to the color bar.
        cbar_pad (float, optional): Padding between axes and colorbar.
        cbar_size (optional): Size of colorbar compared to axes, passed on to
            the color bar. Default=``'5%'``.
        title (str or None, optional): Title for the plot.
        percentile (float, optional): The fraction of pixels to keep in color-trim.
            If single float given, the same fraction of pixels is eliminated from both ends.
            If tuple of two floats is given, the two are used as the percentiles.
            Only used when ``vmin`` or ``vmax`` is not given.
            Default=95.
        vmin (float, optional): Lower limit to use for colormap.
        vmax (float, optional): Upper limit to use for colormap.
        offset_axes (tuple, optional): Offsets ``(x, y)`` added to the axes
            extent of the image. Default is no offset.
        color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black.
        kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`.
            The deprecated keyword ``make_cbar`` is still accepted: it emits a
            :py:class:`FutureWarning` and is used as ``cbar`` if ``cbar`` is not set.

    Returns:
        :py:class:`matplotlib.image.AxesImage`: Image from returned
            by :py:func:`matplotlib.pyplot.imshow`.

    Raises:
        ValueError: If ``scale`` is neither a known name nor a normalization object.

    .. codeauthor:: Rasmus Handberg <*****@*****.**>
    """
    # Local import: only needed for the deprecated 'make_cbar' keyword.
    import warnings

    logger = logging.getLogger(__name__)

    # Backward compatible settings:
    make_cbar = kwargs.pop('make_cbar', None)
    if make_cbar:
        # Warn instead of raising: raising made the fallback below
        # unreachable and broke every caller still using 'make_cbar'.
        warnings.warn("'make_cbar' is deprecated. Use 'cbar' instead.",
                      FutureWarning, stacklevel=2)
        if not cbar:
            cbar = make_cbar

    # Special treatment for boolean arrays:
    if isinstance(image, np.ndarray) and image.dtype == 'bool':
        if vmin is None:
            vmin = 0
        if vmax is None:
            vmax = 1
        if cbar_ticks is None:
            cbar_ticks = [0, 1]
        if cbar_ticklabels is None:
            cbar_ticklabels = ['False', 'True']

    # Calculate limits of color scaling:
    interval = None
    if vmin is None or vmax is None:
        if allnan(image):
            logger.warning("Image is all NaN")
            vmin = 0
            vmax = 1
            if cbar_ticks is None:
                cbar_ticks = []
            if cbar_ticklabels is None:
                cbar_ticklabels = []
        elif isinstance(percentile, (list, tuple, np.ndarray)):
            interval = viz.AsymmetricPercentileInterval(
                percentile[0], percentile[1])
        else:
            interval = viz.PercentileInterval(percentile)

    # Create ImageNormalize object with extracted limits:
    if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh',
                 'squared'):
        if scale == 'log':
            stretch = viz.LogStretch()
        elif scale == 'linear':
            stretch = viz.LinearStretch()
        elif scale == 'sqrt':
            stretch = viz.SqrtStretch()
        elif scale == 'asinh':
            stretch = viz.AsinhStretch()
        elif scale == 'histeq':
            stretch = viz.HistEqStretch(image[np.isfinite(image)])
        elif scale == 'sinh':
            stretch = viz.SinhStretch()
        elif scale == 'squared':
            stretch = viz.SquaredStretch()

        # Create ImageNormalize object. Very important to use clip=False if the image contains
        # NaNs, otherwise NaN points will not be plotted correctly.
        norm = viz.ImageNormalize(data=image[np.isfinite(image)],
                                  interval=interval,
                                  vmin=vmin,
                                  vmax=vmax,
                                  stretch=stretch,
                                  clip=not anynan(image))

    elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)):
        norm = scale
    else:
        raise ValueError("scale {} is not available.".format(scale))

    # Pixel centres sit on integer coordinates, hence the half-pixel margins.
    if offset_axes:
        extent = (offset_axes[0] - 0.5, offset_axes[0] + image.shape[1] - 0.5,
                  offset_axes[1] - 0.5, offset_axes[1] + image.shape[0] - 0.5)
    else:
        extent = (-0.5, image.shape[1] - 0.5, -0.5, image.shape[0] - 0.5)

    if ax is None:
        ax = plt.gca()

    # Set up the colormap to use. Always work on a copy, so that adding the
    # bad color below never modifies a registered colormap or a colormap
    # object owned by the caller.
    if cmap is None:
        cmap = copy.copy(plt.get_cmap('Blues'))
    elif isinstance(cmap, str):
        cmap = copy.copy(plt.get_cmap(cmap))
    else:
        cmap = copy.copy(cmap)

    if color_bad:
        cmap.set_bad(color_bad, 1.0)

    # Plotting the image using all the settings set above:
    im = ax.imshow(image,
                   cmap=cmap,
                   norm=norm,
                   origin=origin,
                   extent=extent,
                   interpolation='nearest',
                   **kwargs)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    ax.set_xlim([extent[0], extent[1]])
    ax.set_ylim([extent[2], extent[3]])

    if cbar:
        colorbar(im,
                 ax=ax,
                 loc=cbar,
                 size=cbar_size,
                 pad=cbar_pad,
                 label=clabel,
                 ticks=cbar_ticks,
                 ticklabels=cbar_ticklabels)

    # Settings for ticks. Each axis slot gets its own locator: a Locator
    # keeps a reference to the axis it is attached to, so one instance must
    # not be shared between several axes.
    ax.xaxis.set_major_locator(MaxNLocator(nbins=10, integer=True))
    ax.xaxis.set_minor_locator(MaxNLocator(nbins=10, integer=True))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=10, integer=True))
    ax.yaxis.set_minor_locator(MaxNLocator(nbins=10, integer=True))
    ax.tick_params(which='both', direction='out', pad=5)
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()

    return im
예제 #12
0
    def plotDirectCutouts(self,
                          savePath=None,
                          colourMap='viridis',
                          gridSpecs=None):
        """Plot the direct-image cutout stored for each grism.

        One panel is drawn per entry of ``self.directCutouts``.  Entries
        whose data is ``None``, or whose pixels are all negative, get a text
        placeholder instead of an image.  Each image panel is given a second
        y axis centred on the panel and scaled by the ``IDCSCALE`` value
        read from ``self.directHdus``.

        If ``self.directCutouts`` is ``None`` nothing is plotted and a
        message is printed instead.

        Parameters
        ----------
        savePath : str or None
            If given, the figure is saved there (dpi 300) and then closed.
        colourMap : str
            Colormap passed to ``imshow``.
        gridSpecs : sequence or None
            Subplot specifications, indexed by panel number.  If ``None`` a
            new 10x5 inch figure with a 2x1 layout is created.
        """
        if self.directCutouts is not None:
            if gridSpecs is None:
                mplplot.figure(figsize=(10, 5))
            for stampIndex, (grism, cutoutData) in enumerate(
                    self.directCutouts.items()):
                if gridSpecs is None:
                    subplotAxes = mplplot.subplot(2, 1, stampIndex + 1)
                else:
                    subplotAxes = mplplot.subplot(gridSpecs[stampIndex])

                if cutoutData is None:
                    subplotAxes.text(
                        0.5,
                        0.5,
                        'Field {}, Object {}:\nNO DATA AVAILABLE.'.format(
                            self.targetPar, self.targetObject),
                        horizontalalignment='center',
                        fontsize='large',
                        transform=subplotAxes.transAxes)
                    continue
                # NOTE(review): the message says "nonzero" but the test is
                # "all pixels negative" -- confirm which one is intended.
                if np.all(cutoutData < 0):
                    subplotAxes.text(
                        0.5,
                        0.5,
                        'Field {}, Object {}:\nNO NONZERO DATA AVAILABLE.'.
                        format(self.targetPar, self.targetObject),
                        horizontalalignment='center',
                        fontsize='large',
                        transform=subplotAxes.transAxes)
                    continue

                norm = astromplnorm.ImageNormalize(
                    cutoutData,
                    interval=astrovis.AsymmetricPercentileInterval(0, 99.5),
                    stretch=astrovis.LinearStretch(),
                    clip=True)

                mplplot.imshow(cutoutData,
                               origin='lower',
                               interpolation='nearest',
                               cmap=colourMap,
                               norm=norm)

                mplplot.xlabel('X (pixels)')
                mplplot.ylabel('Y (pixels)')
                mplplot.title(
                    'Field {}, Object {}:\nDirect cutout for F{} (G{})'.format(
                        self.targetPar, self.targetObject,
                        self.getDirectFilterForGrism(grism), grism))
                # Second y axis: pixel limits shifted so that the panel
                # centre is zero, then multiplied by IDCSCALE.
                arcsecYAxis = subplotAxes.twinx()
                arcsecYAxis.set_ylim(*(np.array(subplotAxes.get_ylim()) -
                                       0.5 * np.sum(subplotAxes.get_ylim())) *
                                     self.directHdus[grism][1]['IDCSCALE'])
                # NOTE(review): this looks like leftover debug output.
                print(
                    subplotAxes.get_ylim(), np.array(subplotAxes.get_ylim()),
                    np.array(subplotAxes.get_ylim()) *
                    self.directHdus[grism][1]['IDCSCALE'],
                    *np.array(subplotAxes.get_ylim()) *
                    self.directHdus[grism][1]['IDCSCALE'])
                # Raw string: '\D' is an invalid escape sequence in a normal
                # string literal (the runtime text is unchanged).
                arcsecYAxis.set_ylabel(r'$\Delta Y$ (arcsec)')
                mplplot.grid(color='white', ls='solid')

            try:
                mplplot.tight_layout()
            except ValueError as e:
                # The figure is neither saved nor closed in this case.
                print(
                    'Error attempting tight_layout for: Field {}, Object {} ({})'
                    .format(self.targetPar, self.targetObject, e))
                return
            if savePath is not None:
                mplplot.savefig(savePath, dpi=300, bbox_inches='tight')
                mplplot.close()

        else:
            print(
                'The loadDirectCutouts(...) method must be called before direct cutouts can be plotted.'
            )
    def makeSlitIllum(self, adinputs=None, **params):
        """
        Make the processed Slit Illumination Function.

        For every input the data is (temporarily) mosaicked if needed, split
        into bins along its first axis after the optional transposition,
        averaged inside each bin, smoothed with a 1D spline and normalized by
        the smoothed value at the slit center. A 2D Chebyshev model is then
        fitted to the bin centers and evaluated over the full frame (plus an
        optional border) to reconstruct the 2D illumination and its variance.

        The implementation is based on IRAF's
        `noao.twodspec.longslit.illumination` task, following the algorithm
        described in [Valdes, 1986].

        It expects the input calibration image to be a dispersed image of the
        slit without illumination problems (e.g, twilight flat). The spectrum
        is not required to be smooth in wavelength and may contain strong
        emission and absorption lines. The image should contain a `.mask`
        attribute in each extension, and it is expected to be overscan and
        bias corrected.

        Parameters
        ----------
        adinputs : list
            List of AstroData objects containing the dispersed image of the
            slit of a source free of illumination problems. The data needs to
            have been overscan and bias corrected and is expected to have a
            Data Quality mask.
        suffix : str
            Suffix used when updating the output filename.
        bins : {None, int}, optional
            Total number of bins across the dispersion axis. If None, the
            number of bins is the number of extensions of the input, but
            never fewer than 12. If it is an int, it will create N bins
            with the same size.
        border : int, optional
            Border size that is added on every edge of the slit illumination
            image before cutting it down to the input AstroData frame.
        debug_plot : bool
            If true, a diagnostic figure is created for each input and saved
            next to the output name with a ``.png`` extension.
        smooth_order : int, optional
            Order of the spline that is used in each bin fitting to smooth
            the data (Default: 3)
        x_order : int, optional
            Order of the x-component in the Chebyshev2D model used to
            reconstruct the 2D data from the binned data.
        y_order : int, optional
            Order of the y-component in the Chebyshev2D model used to
            reconstruct the 2D data from the binned data.

        Returns
        -------
        List of AstroData : containing an AstroData with the Slit Illumination
            Response Function for each of the input object.

        Raises
        ------
        TypeError
            If `bins` is neither None nor an int.

        References
        ----------
        .. [Valdes, 1986] Francisco Valdes "Reduction Of Long Slit Spectra With
           IRAF", Proc. SPIE 0627, Instrumentation in Astronomy VI,
           (13 October 1986); https://doi.org/10.1117/12.968155
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]

        suffix = params["suffix"]
        bins = params["bins"]
        border = params["border"]
        debug_plot = params["debug_plot"]
        smooth_order = params["smooth_order"]
        cheb2d_x_order = params["x_order"]
        cheb2d_y_order = params["y_order"]

        ad_outputs = []
        for ad in adinputs:

            if len(ad) > 1 and "mosaic" not in ad[0].wcs.available_frames:

                log.info('Add "mosaic" gWCS frame to input data')
                geotable = import_module('.geometry_conf', self.inst_lookups)

                # deepcopy prevents modifying input `ad` inplace
                ad = transform.add_mosaic_wcs(deepcopy(ad), geotable)

                log.info("Temporarily mosaicking multi-extension file")
                mosaicked_ad = transform.resample_from_wcs(
                    ad,
                    "mosaic",
                    attributes=None,
                    order=1,
                    process_objcat=False)

            else:

                log.info('Input data already has one extension and has a '
                         '"mosaic" frame.')

                # deepcopy prevents modifying input `ad` inplace
                mosaicked_ad = deepcopy(ad)

            log.info("Transposing data if needed")
            dispaxis = 2 - mosaicked_ad[0].dispersion_axis()  # python sense
            should_transpose = dispaxis == 1

            data, mask, variance = _transpose_if_needed(
                mosaicked_ad[0].data,
                mosaicked_ad[0].mask,
                mosaicked_ad[0].variance,
                transpose=should_transpose)

            log.info("Masking data")
            data = np.ma.masked_array(data, mask=mask)
            variance = np.ma.masked_array(variance, mask=mask)
            std = np.sqrt(variance)  # Easier to work with

            log.info("Creating bins for data and variance")
            height = data.shape[0]
            width = data.shape[1]

            if bins is None:
                nbins = max(len(ad), 12)
                bin_limits = np.linspace(0, height, nbins + 1, dtype=int)
            elif isinstance(bins, int):
                nbins = bins
                bin_limits = np.linspace(0, height, nbins + 1, dtype=int)
            else:
                # ToDo: Handle input bins as array
                raise TypeError("Expected None or Int for `bins`. "
                                "Found: {}".format(type(bins)))

            bin_top = bin_limits[1:]
            bin_bot = bin_limits[:-1]
            binned_data = np.zeros_like(data)
            binned_std = np.zeros_like(std)

            log.info("Smooth binned data and variance, and normalize them by "
                     "smoothed central value")

            # Pixel positions along the second axis; identical for every bin
            slit_pixels = np.arange(width)

            for b0, b1 in zip(bin_bot, bin_top):

                avg_data = np.ma.mean(data[b0:b1], axis=0)
                model_1d_data = astromodels.UnivariateSplineWithOutlierRemoval(
                    slit_pixels, avg_data, order=smooth_order)

                avg_std = np.ma.mean(std[b0:b1], axis=0)
                model_1d_std = astromodels.UnivariateSplineWithOutlierRemoval(
                    slit_pixels, avg_std, order=smooth_order)

                # Evaluate the smoothed profile once and reuse it for both
                # the normalization factor and the binned image.
                smoothed_data = model_1d_data(slit_pixels)
                slit_central_value = smoothed_data[width // 2]

                binned_data[b0:b1] = smoothed_data / slit_central_value
                binned_std[b0:b1] = \
                    model_1d_std(slit_pixels) / slit_central_value

            log.info("Reconstruct 2D mosaicked data")
            bin_center = np.array(0.5 * (bin_bot + bin_top), dtype=int)
            cols_fit, rows_fit = np.meshgrid(np.arange(width), bin_center)

            fitter = fitting.SLSQPLSQFitter()
            model_2d_init = models.Chebyshev2D(x_degree=cheb2d_x_order,
                                               x_domain=(0, width),
                                               y_degree=cheb2d_y_order,
                                               y_domain=(0, height))

            model_2d_data = fitter(model_2d_init, cols_fit, rows_fit,
                                   binned_data[rows_fit, cols_fit])

            model_2d_std = fitter(model_2d_init, cols_fit, rows_fit,
                                  binned_std[rows_fit, cols_fit])

            rows_val, cols_val = \
                np.mgrid[-border:height+border, -border:width+border]

            slit_response_data = model_2d_data(cols_val, rows_val)
            slit_response_mask = np.pad(
                mask, border, mode='edge')  # ToDo: any update to the mask?
            slit_response_std = model_2d_std(cols_val, rows_val)
            slit_response_var = slit_response_std**2

            del cols_fit, cols_val, rows_fit, rows_val

            _data, _mask, _variance = _transpose_if_needed(
                slit_response_data,
                slit_response_mask,
                slit_response_var,
                transpose=should_transpose)

            log.info("Update slit response data and data_section")
            slit_response_ad = deepcopy(mosaicked_ad)
            slit_response_ad[0].data = _data
            slit_response_ad[0].mask = _mask
            slit_response_ad[0].variance = _variance

            if "mosaic" in ad[0].wcs.available_frames:

                log.info(
                    "Map coordinates between slit function and mosaicked data"
                )  # ToDo: Improve message?
                slit_response_ad = _split_mosaic_into_extensions(
                    ad, slit_response_ad, border_size=border)

            elif len(ad) == 1:

                log.info("Trim out borders")

                # `x[0:-0]` is an empty slice, so when no border was added
                # the upper bound has to be left open instead of being `-0`.
                trim = slice(border, -border if border > 0 else None)

                slit_response_ad[0].data = \
                    slit_response_ad[0].data[trim, trim]
                slit_response_ad[0].mask = \
                    slit_response_ad[0].mask[trim, trim]
                slit_response_ad[0].variance = \
                    slit_response_ad[0].variance[trim, trim]

            log.info("Update metadata and filename")
            gt.mark_history(slit_response_ad,
                            primname=self.myself(),
                            keyword=timestamp_key)

            slit_response_ad.update_filename(suffix=suffix, strip=True)
            ad_outputs.append(slit_response_ad)

            # Plotting ------
            if debug_plot:

                log.info("Creating plots")
                palette = copy(plt.cm.cividis)
                palette.set_bad('r', 0.75)

                norm = vis.ImageNormalize(data[~data.mask],
                                          stretch=vis.LinearStretch(),
                                          interval=vis.PercentileInterval(97))

                fig = plt.figure(num="Slit Response from MEF - {}".format(
                    ad.filename),
                                 figsize=(12, 9),
                                 dpi=110)

                gs = gridspec.GridSpec(nrows=2, ncols=3, figure=fig)

                bin_labels = [
                    "Bin {}".format(i) for i in range(len(bin_center))
                ]

                # Display raw mosaicked data and its bins ---
                ax1 = fig.add_subplot(gs[0, 0])
                im1 = ax1.imshow(data,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=norm.vmin,
                                 vmax=norm.vmax)

                ax1.set_title("Mosaicked Data\n and Spectral Bins",
                              fontsize=10)
                ax1.set_xlim(-1, data.shape[1])
                ax1.set_xticks([])
                ax1.set_ylim(-1, data.shape[0])
                ax1.set_yticks(bin_center)
                ax1.tick_params(axis=u'both', which=u'both', length=0)

                ax1.set_yticklabels(bin_labels, fontsize=6)

                for s in ax1.spines:
                    ax1.spines[s].set_visible(False)
                for b in bin_limits:
                    ax1.axhline(b, c='w', lw=0.5)

                divider = make_axes_locatable(ax1)
                cax1 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im1, cax=cax1)

                # Display binned, smoothed and normalized data ---
                ax2 = fig.add_subplot(gs[0, 1])
                im2 = ax2.imshow(binned_data, cmap=palette, origin='lower')

                ax2.set_title("Binned, smoothed\n and normalized data ",
                              fontsize=10)
                ax2.set_xlim(0, data.shape[1])
                ax2.set_xticks([])
                ax2.set_ylim(0, data.shape[0])
                ax2.set_yticks(bin_center)
                ax2.tick_params(axis=u'both', which=u'both', length=0)

                ax2.set_yticklabels(bin_labels, fontsize=6)

                for s in ax2.spines:
                    ax2.spines[s].set_visible(False)
                for b in bin_limits:
                    ax2.axhline(b, c='w', lw=0.5)

                divider = make_axes_locatable(ax2)
                cax2 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im2, cax=cax2)

                # Display reconstructed slit response ---
                vmin = slit_response_data.min()
                vmax = slit_response_data.max()

                ax3 = fig.add_subplot(gs[1, 0])
                im3 = ax3.imshow(slit_response_data,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=vmin,
                                 vmax=vmax)

                ax3.set_title("Reconstructed\n Slit response", fontsize=10)
                ax3.set_xlim(0, data.shape[1])
                ax3.set_xticks([])
                ax3.set_ylim(0, data.shape[0])
                ax3.set_yticks([])
                ax3.tick_params(axis=u'both', which=u'both', length=0)
                for s in ax3.spines:
                    ax3.spines[s].set_visible(False)

                divider = make_axes_locatable(ax3)
                cax3 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im3, cax=cax3)

                # Display extensions ---
                ax4 = fig.add_subplot(gs[1, 1])
                ax4.set_xticks([])
                ax4.set_yticks([])
                for s in ax4.spines:
                    ax4.spines[s].set_visible(False)

                sub_gs4 = gridspec.GridSpecFromSubplotSpec(nrows=len(ad),
                                                           ncols=1,
                                                           subplot_spec=gs[1,
                                                                           1],
                                                           hspace=0.03)

                # The [::-1] is needed to put the first extension in the
                # bottom
                for i, ext in enumerate(slit_response_ad[::-1]):

                    ext_data, ext_mask, ext_variance = _transpose_if_needed(
                        ext.data,
                        ext.mask,
                        ext.variance,
                        transpose=should_transpose)

                    ext_data = np.ma.masked_array(ext_data, mask=ext_mask)

                    sub_ax = fig.add_subplot(sub_gs4[i])

                    im4 = sub_ax.imshow(ext_data,
                                        origin="lower",
                                        vmin=vmin,
                                        vmax=vmax,
                                        cmap=palette)

                    sub_ax.set_xlim(0, ext_data.shape[1])
                    sub_ax.set_xticks([])
                    sub_ax.set_ylim(0, ext_data.shape[0])
                    sub_ax.set_yticks([ext_data.shape[0] // 2])

                    sub_ax.set_yticklabels(
                        ["Ext {}".format(len(slit_response_ad) - i - 1)],
                        fontsize=6)

                    for s in sub_ax.spines:
                        sub_ax.spines[s].set_visible(False)

                    if i == 0:
                        sub_ax.set_title(
                            "Multi-extension\n Slit Response Function")

                divider = make_axes_locatable(ax4)
                cax4 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im4, cax=cax4)

                # Display Signal-To-Noise Ratio ---
                snr = data / np.sqrt(variance)

                norm = vis.ImageNormalize(snr[~snr.mask],
                                          stretch=vis.LinearStretch(),
                                          interval=vis.PercentileInterval(97))

                ax5 = fig.add_subplot(gs[0, 2])

                im5 = ax5.imshow(snr,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=norm.vmin,
                                 vmax=norm.vmax)

                ax5.set_title("Mosaicked Data SNR", fontsize=10)
                ax5.set_xlim(-1, data.shape[1])
                ax5.set_xticks([])
                ax5.set_ylim(-1, data.shape[0])
                ax5.set_yticks(bin_center)
                ax5.tick_params(axis=u'both', which=u'both', length=0)

                ax5.set_yticklabels(bin_labels, fontsize=6)

                for s in ax5.spines:
                    ax5.spines[s].set_visible(False)
                for b in bin_limits:
                    ax5.axhline(b, c='w', lw=0.5)

                divider = make_axes_locatable(ax5)
                cax5 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im5, cax=cax5)

                # Display Signal-To-Noise Ratio of Slit Illumination ---
                # Reuses the color limits computed for the data SNR above so
                # both SNR panels share the same scale.
                slit_response_snr = np.ma.masked_array(
                    slit_response_data / np.sqrt(slit_response_var),
                    mask=slit_response_mask)

                ax6 = fig.add_subplot(gs[1, 2])

                im6 = ax6.imshow(slit_response_snr,
                                 origin="lower",
                                 vmin=norm.vmin,
                                 vmax=norm.vmax,
                                 cmap=palette)

                ax6.set_xlim(0, slit_response_snr.shape[1])
                ax6.set_xticks([])
                ax6.set_ylim(0, slit_response_snr.shape[0])
                ax6.set_yticks([])
                ax6.set_title("Reconstructed\n Slit Response SNR")

                for s in ax6.spines:
                    ax6.spines[s].set_visible(False)

                divider = make_axes_locatable(ax6)
                cax6 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im6, cax=cax6)

                # Save plots ---
                fig.tight_layout(rect=[0, 0, 0.95, 1], pad=0.5)
                fname = slit_response_ad.filename.replace(".fits", ".png")
                log.info("Saving plots to {}".format(fname))
                plt.savefig(fname)

        return ad_outputs
예제 #14
0
    def create_figure(self,
                      output_filename,
                      survey,
                      stretch='log',
                      vmin=1,
                      vmax=None,
                      min_percent=1,
                      max_percent=95,
                      cmap='gray',
                      contour_color='red',
                      data_col='FLUX'):
        """Draw the binned TPF flux with survey contours and save the figure.

        Parameters
        ----------
        output_filename : str
            Destination passed to ``plt.savefig``.

        survey : str
            Survey identifier forwarded to ``surveyquery.getSVImg``.

        stretch : str, optional
            Name of the stretch: 'linear', 'sqrt', 'power', 'log' or
            'asinh'.  Default: 'log'.

        vmin : float, optional
            Minimum cut level (default: 1).  Replaced by the computed cut
            level whenever ``vmax`` is None.

        vmax : float, optional
            Maximum cut level.  When None (the default), both cut levels are
            obtained from ``self.cut_levels(min_percent, max_percent,
            data_col)``.

        min_percent, max_percent : float, optional
            Percentiles handed to ``self.cut_levels`` when ``vmax`` is None.

        cmap : str, optional
            The matplotlib color map name.  The default is 'gray',
            can also be e.g. 'gist_heat'.

        contour_color : str, optional
            Color used for the survey contours.

        data_col : str, optional
            Column name forwarded to ``self.cut_levels`` and used in the
            verbose message.

        Returns
        -------
        fig : matplotlib figure
            The figure that was drawn and written to ``output_filename``.

        Raises
        ------
        ValueError
            If ``stretch`` is not one of the supported names.
        """
        flux_image = self.TPF.flux_binned()

        # Cut levels are only derived from the data when no explicit
        # maximum was supplied.
        if vmax is None:
            vmin, vmax = self.cut_levels(min_percent, max_percent, data_col)

        # The figure size is taken directly from the flux array shape
        fig = plt.figure(figsize=list(flux_image.shape))

        # WCS-aware axes rather than plain pixel axes
        ax = plt.subplot(projection=self.TPF.wcs)
        ax.set_xlabel('RA')
        ax.set_ylabel('Dec')

        if self.verbose:
            print('{} vmin/vmax = {}/{} (median={})'.format(
                data_col, vmin, vmax, np.nanmedian(flux_image)))

        # Ordered (name, factory) pairs; only the matching stretch is built
        stretch_options = (
            ('linear', lambda: visualization.LinearStretch()),
            ('sqrt', lambda: visualization.SqrtStretch()),
            ('power', lambda: visualization.PowerStretch(1.0)),
            ('log', lambda: visualization.LogStretch()),
            ('asinh', lambda: visualization.AsinhStretch(0.1)),
        )
        for option_name, build_stretch in stretch_options:
            if stretch == option_name:
                stretch_fn = build_stretch()
                break
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        cut_interval = visualization.ManualInterval(vmin=vmin, vmax=vmax)
        transform = stretch_fn + cut_interval

        # Scale the transformed flux to integers and draw it without any
        # further normalization by matplotlib.
        scaled_flux = (255 * transform(flux_image)).astype(int)
        ax.imshow(scaled_flux,
                  aspect='auto',
                  origin='lower',
                  interpolation='nearest',
                  cmap=cmap,
                  norm=NoNorm())
        ax.set_xticks([])
        ax.set_yticks([])

        # Remember the image limits so they can be restored after the
        # contours are added.
        saved_ylim = ax.get_ylim()
        saved_xlim = ax.get_xlim()

        survey_pixels, survey_header = surveyquery.getSVImg(
            self.TPF.position, survey)
        contour_levels = np.linspace(np.min(survey_pixels),
                                     np.percentile(survey_pixels, 95), 10)
        ax.contour(survey_pixels,
                   transform=ax.get_transform(WCS(survey_header)),
                   levels=contour_levels,
                   colors=contour_color)

        ax.set_xlim(saved_xlim)
        ax.set_ylim(saved_ylim)

        fig.canvas.draw()
        plt.savefig(output_filename, bbox_inches='tight', dpi=300)
        return fig
예제 #15
0
def make_fov_image(fov, pngfn=None, **kwargs):
    """Draw the four CCD images of a field of view in a single figure.

    Parameters
    ----------
    fov : mapping
        Must provide, for each of 'CCD1'..'CCD4', an entry with the keys
        'im' (image), 'x' and 'y' (2D coordinate arrays), plus a 'coordsys'
        entry.  The optional 'file' and 'objname' entries build the default
        title.
    pngfn : str, optional
        If given, the figure is saved to this path and then closed.
    **kwargs
        stretch : {'linear', 'histeq', 'asinh'}, default 'linear'
        interval : {'zscale', 'rms', 'fixed'}, default 'zscale'
        imrange : pair, optional
            ``(vmin, vmax)`` for 'fixed'; ``(losig, hisig)`` for 'rms'
            (falls back to ``(2.5, 5.0)`` when it cannot be unpacked).
        contrast : float, default 0.25
            Contrast passed to the z-scale interval.
        cmap : default 'viridis'
            Passed to ``plt.get_cmap``.
        interpolation : default 'nearest'
        title : str, optional
            Only the last 60 characters are shown.

    Raises
    ------
    KeyError
        If `stretch` is not a supported name.
    ValueError
        If `interval` is not a supported name.
    """
    import copy as _copy

    stretch = kwargs.get('stretch', 'linear')
    interval = kwargs.get('interval', 'zscale')
    imrange = kwargs.get('imrange')
    contrast = kwargs.get('contrast', 0.25)
    ccdplotorder = ['CCD2', 'CCD4', 'CCD1', 'CCD3']
    if interval == 'rms':
        try:
            losig, hisig = imrange
        except (TypeError, ValueError):
            # imrange missing (None) or not a 2-element sequence
            losig, hisig = (2.5, 5.0)
    #
    # Work on a private copy so set_bad() never alters the colormap object
    # handed back by matplotlib.
    cmap = _copy.copy(plt.get_cmap(kwargs.get('cmap', 'viridis')))
    cmap.set_bad('w', 1.0)
    w = 0.4575
    h = 0.455
    rc('text', usetex=False)
    fig = plt.figure(figsize=(6, 6.5))
    cax = fig.add_axes([0.1, 0.04, 0.8, 0.01])
    ims = [fov[ccd]['im'] for ccd in ccdplotorder]
    allpix = np.ma.array(ims).flatten()
    # Factories, so that only the requested stretch is actually constructed
    # (the histogram-equalization stretch is built from every pixel).
    stretch = {
        'linear': lambda: vis.LinearStretch(),
        'histeq': lambda: vis.HistEqStretch(allpix),
        'asinh': lambda: vis.AsinhStretch(),
    }[stretch]()
    if interval == 'zscale':
        iv = vis.ZScaleInterval(contrast=contrast)
        vmin, vmax = iv.get_limits(allpix)
    elif interval == 'rms':
        # NOTE(review): `nbin` is not defined in this function; presumably a
        # module-level name -- confirm, otherwise this branch raises
        # NameError.
        nsample = 1000 // nbin
        # NOTE(review): newer astropy releases name this keyword `maxiters`
        # instead of `iters` -- confirm against the supported version.
        background = sigma_clip(allpix[::nsample], iters=3, sigma=2.2)
        m, s = background.mean(), background.std()
        vmin, vmax = m - losig * s, m + hisig * s
    elif interval == 'fixed':
        vmin, vmax = imrange
    else:
        raise ValueError('Unknown interval: {!r}'.format(interval))
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)
    for n, (im, ccd) in enumerate(zip(ims, ccdplotorder)):
        if im.ndim == 3:
            # Collapse the last axis of 3D inputs to get a 2D image
            im = im.mean(axis=-1)
        x = fov[ccd]['x']
        y = fov[ccd]['y']
        # 2x2 layout: column index i, row index j
        i = n % 2
        j = n // 2
        pos = [0.0225 + i * w + i * 0.04, 0.05 + j * h + j * 0.005, w, h]
        ax = fig.add_axes(pos)
        _im = ax.imshow(im,
                        origin='lower',
                        extent=[x[0, 0], x[0, -1], y[0, 0], y[-1, 0]],
                        norm=norm,
                        cmap=cmap,
                        interpolation=kwargs.get('interpolation', 'nearest'))
        if fov['coordsys'] == 'sky':
            # x axis is reversed for sky coordinates
            ax.set_xlim(x.max(), x.min())
        else:
            ax.set_xlim(x.min(), x.max())
        ax.set_ylim(y.min(), y.max())
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        if n == 0:
            # All panels share `norm`, so one colorbar serves the figure
            cb = fig.colorbar(_im, cax, orientation='horizontal')
            cb.ax.tick_params(labelsize=9)
    tstr = fov.get('file', '') + ' ' + fov.get('objname', '')
    title = kwargs.get('title', tstr)
    title = title[-60:]
    fig.text(0.5, 0.99, title, ha='center', va='top', size=12)
    if pngfn is not None:
        plt.savefig(pngfn)
        plt.close(fig)
def main():
    """Plot the first four extensions of a file in a 2x2 grid.

    The file name comes from ``get_filename()``.  All panels share one
    z-scale/linear normalization computed from the stacked data of the four
    extensions, and pixels whose mask value is greater than zero are drawn
    in white.  The figure is saved next to the input with a ``.png``
    extension and then shown.
    """
    filename = get_filename()

    ad = astrodata.open(filename)
    print(ad.info())

    fig = plt.figure(num=filename, figsize=(8, 8))
    fig.suptitle('{}'.format(filename))

    palette = copy(plt.cm.viridis)
    palette.set_bad('w', 1.0)

    stacked = np.dstack([ad[i].data for i in range(4)])
    norm = visualization.ImageNormalize(
        stacked,
        stretch=visualization.LinearStretch(),
        interval=visualization.ZScaleInterval())

    # (subplot position, extension index, annotation) in drawing order
    panels = (
        (224, 0, 'd1'),
        (223, 1, 'd2'),
        (221, 2, 'd3'),
        (222, 3, 'd4'),
    )

    for position, ext_index, label in panels:
        ext = ad[ext_index]
        axes = fig.add_subplot(position)

        axes.imshow(np.ma.masked_where(ext.mask > 0, ext.data),
                    norm=colors.Normalize(vmin=norm.vmin, vmax=norm.vmax),
                    origin='lower',
                    cmap=palette)

        axes.annotate(label, (20, 20), color='white')
        axes.set_xlabel('x [pixels]')
        axes.set_ylabel('y [pixels]')

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    plt.savefig(filename.replace('.fits', '.png'))
    plt.show()
    pl.imshow(rgb, origin='lower', interpolation='none', norm=norm)

    (x1, x2), (y1, y2) = celwcs.wcs_world2pix(lims[0], lims[1], 0)
    ax.axis((x1, x2, y1, y2))

    visualization_tools.make_scalebar(ax,
                                      left_side=scalebarx,
                                      length=1.213 * u.arcsec,
                                      label='0.05 pc')

    pl.savefig(paths.fpath(figfilename), bbox_inches='tight')


if __name__ == "__main__":

    rgbfig(stretch=visualization.LinearStretch(), )

    rgbfig(
        figfilename='SgrB2M_RGB.pdf',
        lims=[(266.8359, 266.8325), (-28.38600555, -28.3832)],
        redfn=paths.Fpath('SGRB2M-2012-Q-MEAN.DePree.recentered.fits'),
        greenfn=paths.Fpath(
            'sgr_b2m.M.B3.allspw.continuum.r0.5.clean1000.image.tt0.pbcor.fits'
        ),
        bluefn=paths.Fpath(
            'sgr_b2m.M.B6.allspw.continuum.r0.5.clean1000.image.tt0.pbcor.fits'
        ),
        scalebarx=coordinates.SkyCoord(266.8336007 * u.deg,
                                       -28.38553839 * u.deg),
        redpercentile=99.99,
        greenpercentile=99.98,
예제 #18
0
 def set_normalization(self,
                       stretch=None,
                       interval=None,
                       stretchkwargs=None,
                       intervalkwargs=None,
                       perm_linear=None):
     """Set the stretch and interval used to normalize the displayed image.

     Parameters
     ----------
     stretch : str or object, optional
         Stretch name or an already built stretch object.  When None,
         ``self.stretch`` is reused, or 'linear' if that is None too.
     interval : str or object, optional
         Interval name ('minmax', 'manual', 'percentile', 'asymetric',
         'zscale') or an already built interval object.  When None,
         ``self.interval`` is reused, or 'zscale' if that is None too.
     stretchkwargs : dict, optional
         Keyword arguments for the stretch.  Defaults to a new empty dict.
     intervalkwargs : dict, optional
         Keyword arguments for the interval.  Defaults to a new empty dict.
     perm_linear : optional
         When not None, it is merged with the 'linear' defaults and the
         resulting linear stretch is composed with the requested stretch.

     Notes
     -----
     While ``self.data`` is None the names are kept as strings on
     ``self.stretch`` / ``self.interval`` and the keyword arguments are
     stored on ``self.stretch_kwargs`` / ``self.interval_kwargs``; no
     stretch or interval object is built.  If ``self.img`` is not None its
     norm is updated at the end.

     Raises
     ------
     ValueError
         If a stretch or interval name is not handled by this method.
     """
     # Fresh dicts per call: these may end up stored on the instance, and a
     # shared default dict would then be aliased between instances.
     if stretchkwargs is None:
         stretchkwargs = {}
     if intervalkwargs is None:
         intervalkwargs = {}

     if stretch is None:
         if self.stretch is None:
             stretch = 'linear'
         else:
             stretch = self.stretch
     if isinstance(stretch, str):
         print(stretch,
               ' '.join([f'{k}={v}' for k, v in stretchkwargs.items()]))
         if self.data is None:  #can not calculate objects yet
             self.stretch_kwargs = stretchkwargs
         else:
             kwargs = self.prepare_kwargs(
                 self.stretch_kws_defaults[stretch], self.stretch_kwargs,
                 stretchkwargs)
             if perm_linear is not None:
                 perm_linear_kwargs = self.prepare_kwargs(
                     self.stretch_kws_defaults['linear'], perm_linear)
                 print(
                     'linear', ' '.join([
                         f'{k}={v}' for k, v in perm_linear_kwargs.items()
                     ]))
                 if stretch == 'asinh':  # arg: a=0.1
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.AsinhStretch(**kwargs))
                 elif stretch == 'contrastbias':  # args: contrast, bias
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.ContrastBiasStretch(**kwargs))
                 elif stretch == 'histogram':
                     stretch = vis.CompositeStretch(
                         vis.HistEqStretch(self.data, **kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'log':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LogStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'powerdist':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.PowerDistStretch(**kwargs))
                 elif stretch == 'power':  # args: a
                     stretch = vis.CompositeStretch(
                         vis.PowerStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'sinh':  # args: a=0.33
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SinhStretch(**kwargs))
                 elif stretch == 'sqrt':
                     stretch = vis.CompositeStretch(
                         vis.SqrtStretch(),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'square':
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SquaredStretch())
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
             else:
                 # NOTE(review): without `perm_linear` only 'linear' is
                 # accepted here, and with it 'linear' is rejected above.
                 # Confirm this asymmetry is intended.
                 if stretch == 'linear':  # args: slope=1, intercept=0
                     stretch = vis.LinearStretch(**kwargs)
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
     self.stretch = stretch
     if interval is None:
         if self.interval is None:
             interval = 'zscale'
         else:
             interval = self.interval
     if isinstance(interval, str):
         print(interval,
               ' '.join([f'{k}={v}' for k, v in intervalkwargs.items()]))
         kwargs = self.prepare_kwargs(self.interval_kws_defaults[interval],
                                      self.interval_kwargs, intervalkwargs)
         if self.data is None:
             self.interval_kwargs = intervalkwargs
         else:
             if interval == 'minmax':
                 interval = vis.MinMaxInterval()
             elif interval == 'manual':  # args: vmin, vmax
                 interval = vis.ManualInterval(**kwargs)
             elif interval == 'percentile':  # args: percentile, n_samples
                 interval = vis.PercentileInterval(**kwargs)
             elif interval == 'asymetric':  # args: lower_percentile, upper_percentile, n_samples
                 interval = vis.AsymmetricPercentileInterval(**kwargs)
             elif interval == 'zscale':  # args: nsamples=1000, contrast=0.25, max_reject=0.5, min_npixels=5, krej=2.5, max_iterations=5
                 interval = vis.ZScaleInterval(**kwargs)
             else:
                 raise ValueError('Unknown interval:' + interval)
     self.interval = interval
     if self.img is not None:
         # Rebuild the norm from the current data so the displayed image
         # picks up the new stretch/interval immediately.
         self.img.set_norm(
             vis.ImageNormalize(self.data,
                                interval=self.interval,
                                stretch=self.stretch,
                                clip=True))
예제 #19
0
def show_image(image,
               percl=99,
               percu=None,
               figsize=(6, 10),
               cmap='viridis',
               log=False):
    """
    Show an image in matplotlib with some basic astronomically-appropriate stretching.

    The image is block-reduced to roughly the pixel resolution of the figure
    before being displayed, and a colorbar is attached to the plot.

    Parameters
    ----------
    image
        The image to show. Must be two-dimensional and expose ``.shape`` as
        ``(rows, columns)``.
    percl : number
        The percentile for the lower edge of the stretch. If ``percu`` is None
        this single value defines both edges symmetrically, i.e. the interval
        runs between ``100 - percl`` and ``percl`` (smaller value first).
    percu : number or None
        The percentile for the upper edge of the stretch (or None to derive
        both edges from ``percl``).
    figsize : 2-tuple or None
        Only the larger of the two values is used: it sets the length, in
        inches, of the longer side of the figure, and the other side is scaled
        to the image's aspect ratio. If None, ``figsize=None`` is passed on to
        ``plt.subplots``.
    cmap
        Colormap passed to ``imshow``.
    log : bool
        If True use a logarithmic stretch, otherwise a linear one.
    """
    if percu is None:
        # One percentile p means the symmetric interval [100 - p, p]; sorting
        # keeps lower <= upper even when the caller passes p < 50.
        percl, percu = sorted((100 - percl, percl))

    n_rows, n_cols = image.shape

    if figsize is not None:
        # matplotlib's figsize is (width, height) whereas image.shape is
        # (rows, columns), so width must follow the column count and height
        # the row count. The longer image side is mapped to max(figsize).
        inches_per_pixel = max(figsize) / max(n_rows, n_cols)
        figsize = (n_cols * inches_per_pixel, n_rows * inches_per_pixel)

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # To preserve details we should *really* downsample correctly and not rely on
    # matplotlib to do it correctly for us (it won't).

    # So, calculate the size of the figure in pixels, block_reduce to roughly that,
    # and display the block reduced image.

    # Thanks, https://stackoverflow.com/questions/29702424/how-to-get-matplotlib-figure-size
    fig_width_pix, fig_height_pix = fig.get_size_inches() * fig.dpi

    # Compare like with like: columns against figure width, rows against
    # figure height. Never reduce by less than one pixel per block.
    ratio = max(int(n_cols // fig_width_pix),
                int(n_rows // fig_height_pix),
                1)

    # Divide by the square of the ratio to keep the flux the same in the reduced image
    reduced_data = block_reduce(image, ratio) / ratio**2

    # Of course, now that we have downsampled, the axis limits are changed to match
    # the smaller image size. Setting the extent will do the trick to change the axis display
    # back to showing the actual extent of the image.
    extent = [0, n_cols, 0, n_rows]

    stretch = aviz.LogStretch() if log else aviz.LinearStretch()
    norm = aviz.ImageNormalize(
        reduced_data,
        interval=aviz.AsymmetricPercentileInterval(percl, percu),
        stretch=stretch)

    mappable = ax.imshow(reduced_data,
                         norm=norm,
                         origin='lower',
                         cmap=cmap,
                         extent=extent)
    plt.colorbar(mappable)