Beispiel #1
0
def plotcontrast(img):
    """
    Show ``img`` under a gallery of Matplotlib / Astropy normalizations.

    One panel is drawn per normalization on a 4x3 grid: no normalization,
    Matplotlib's ``LogNorm`` and eleven ``astropy.visualization`` stretches
    (each wrapped in an ``ImageNormalize``).
    See http://docs.astropy.org/en/stable/visualization/index.html

    Parameters
    ----------
    img : 2-D array
        Image to display; also used to build the ``HistEqStretch``.

    Returns
    -------
    fg : matplotlib.figure.Figure
        Figure holding the 12 panels, so callers can save or customize it.
    """
    vistypes = (None, LogNorm(), vis.AsinhStretch(),
                vis.ContrastBiasStretch(1, 0.5), vis.HistEqStretch(img),
                vis.LinearStretch(), vis.LogStretch(),
                vis.PowerDistStretch(a=10.), vis.PowerStretch(a=10.),
                vis.SinhStretch(), vis.SqrtStretch(), vis.SquaredStretch())

    fg, ax = subplots(4, 3, figsize=(10, 10))
    ax = ax.ravel()

    for a, v in zip(ax, vistypes):
        if v is None or isinstance(v, LogNorm):
            # Already a Matplotlib norm (or no norm at all): use as-is.
            norm = v
        else:
            # Astropy stretches must be wrapped to behave as a Matplotlib norm.
            norm = ImageNormalize(stretch=v)
        a.set_title('None' if v is None else type(v).__name__)

        a.imshow(img, origin='lower', cmap='gray', norm=norm)
        a.axis('off')

    fg.suptitle('Matplotlib/AstroPy normalizations')
    return fg
Beispiel #2
0
def uvtracks_airydisk2D(tel_tracks, veritas_tels, baselines, airy_func,
                        guess_r, wavelength, save_dir, star_name):
    """Plot the (u, v) tracks of the baselines over a 2-D Airy disk image.

    The Airy disk is evaluated on a square grid whose half-width is 1.2 times
    the largest absolute coordinate in ``tel_tracks``. It is shown with a
    logarithmic stretch, and both tracks of every baseline are drawn on top.

    Parameters
    ----------
    tel_tracks : sequence
        For each baseline, a pair ``(track[0], track[1])`` of (N, 2) arrays.
        Column 0 is plotted as U and column 1 as V (metres). Presumably the
        two are a track and its mirror image -- TODO confirm.
    veritas_tels :
        Observatory object; only ``veritas_tels.time_info.T`` is read (used
        in the output file name).
    baselines, guess_r :
        Unused; kept for interface compatibility.
    airy_func :
        Model whose ``radius.value`` sets the disk's angular diameter as
        ``1.22 * wavelength / radius``.
    wavelength : float
        Observing wavelength forwarded to ``IImodels.airy_disk2D``.
    save_dir : str or None
        If truthy, the figure is written there with ``graph_saver``;
        otherwise it is shown with ``plt.show()``.
    star_name : str
        Used in the output file name.
    """
    half_width = int(np.max(np.abs(tel_tracks)) * 1.2)
    x_0 = y_0 = half_width
    airy_disk, airy_funcd = IImodels.airy_disk2D(shape=(x_0, y_0),
                                                 xpos=x_0,
                                                 ypos=y_0,
                                                 angdiam=1.22 * wavelength /
                                                 airy_func.radius.value,
                                                 wavelength=wavelength)
    # np.mgrid's first axis indexes rows (y), the second columns (x).
    y, x = np.mgrid[:y_0 * 2, :x_0 * 2]
    airy_disk = airy_funcd(x, y)
    plt.figure(figsize=(18, 12))

    plt.imshow(airy_disk,
               norm=viz.ImageNormalize(airy_disk, stretch=viz.LogStretch()),
               extent=[-x_0, x_0, -y_0, y_0],
               cmap='gray')
    for track in tel_tracks:
        plt.plot(track[0][:, 0], track[0][:, 1], linewidth=6, color='b')
        plt.plot(track[1][:, 0], track[1][:, 1], linewidth=6, color='b')

    plt.xlabel("U (m)", fontsize=36)
    plt.ylabel("V (m)", fontsize=36)
    # Net effect of the four successive (overriding) tick_params calls.
    plt.tick_params(axis='both', which='major', labelsize=24, length=8,
                    width=3)
    plt.tick_params(axis='both', which='minor', labelsize=28, length=6,
                    width=2)
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=24, length=6, width=3)

    if save_dir:
        graph_saver(
            save_dir,
            "CoverageOf%sOn%sUTC" % (star_name, veritas_tels.time_info.T))
    else:
        plt.show()
Beispiel #3
0
    def plot_norm(self,
                  stretch='linear',
                  power=1.0,
                  asinh_a=0.1,
                  min_cut=None,
                  max_cut=None,
                  min_percent=None,
                  max_percent=None,
                  percent=None,
                  clip=True):
        """Build a matplotlib-compatible norm for displaying ``self.data``.

        Copy of ``astropy.visualization.mpl_normalize.simple_norm`` (available
        from Astropy 1.3); see that function for the parameter description.

        Display limits are chosen by precedence:

        1. ``percent`` gives a symmetric percentile interval.
        2. ``min_percent``/``max_percent`` give asymmetric percentiles; a
           missing end defaults to 0 / 100.
        3. ``min_cut``/``max_cut`` give manual limits.
        4. Otherwise the data min/max are used.

        Raises
        ------
        ValueError
            If ``stretch`` is not 'linear', 'sqrt', 'power', 'log' or 'asinh'.

        Examples
        --------
        >>> image = SkyImage()
        >>> norm = image.plot_norm(stretch='sqrt', max_percent=99)
        >>> image.plot(norm=norm)
        """
        import astropy.visualization as v
        from astropy.visualization.mpl_normalize import ImageNormalize

        no_percentiles = min_percent is None and max_percent is None
        no_cuts = min_cut is None and max_cut is None
        if percent is not None:
            interval = v.PercentileInterval(percent)
        elif not no_percentiles:
            interval = v.AsymmetricPercentileInterval(min_percent or 0.,
                                                      max_percent or 100.)
        elif not no_cuts:
            interval = v.ManualInterval(min_cut, max_cut)
        else:
            interval = v.MinMaxInterval()

        # Builders are lazy so only the requested stretch gets instantiated.
        builders = (
            ('linear', lambda: v.LinearStretch()),
            ('sqrt', lambda: v.SqrtStretch()),
            ('power', lambda: v.PowerStretch(power)),
            ('log', lambda: v.LogStretch()),
            ('asinh', lambda: v.AsinhStretch(asinh_a)),
        )
        for name, build in builders:
            if stretch == name:
                chosen_stretch = build()
                break
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        lower, upper = interval.get_limits(self.data)
        return ImageNormalize(vmin=lower,
                              vmax=upper,
                              stretch=chosen_stretch,
                              clip=clip)
Beispiel #4
0
def select_cutout(image, wcs):
    """Interactively pick one pointing of a mosaic and return a cutout of it.

    The mosaic is assumed to be a regular grid of pointings about 300 pixels
    on a side. This is a heuristic and may not always hold.

    The image is shown with a 99.9-percentile log stretch and grid lines.
    Each mouse click inside the axes selects the 300x300 tile under the
    cursor; the last selection is used once ``plt.show()`` returns.

    Parameters
    ----------
    image : 2-D array
        Mosaic image.
    wcs : astropy.wcs.WCS
        WCS of ``image``, forwarded to ``Cutout2D``.

    Returns
    -------
    Cutout2D
        Square cutout of side ``image.shape[0] / nrow - 20`` pixels centred
        on the selected tile.

    Raises
    ------
    ValueError
        If the image is smaller than one tile along an axis, or if no tile
        was selected before the plot window was closed.
    """
    tile = 300.  # assumed pointing size in pixels

    # How many pointings are in the mosaic (heuristic; may not be exact).
    nrow = image.shape[0] // tile
    ncol = image.shape[1] // tile
    if nrow == 0 or ncol == 0:
        raise ValueError(
            'image shape {} is smaller than one {:.0f}-pixel tile'.format(
                image.shape, tile))
    # Exact width of a row and of a column, used only for the grid lines.
    drow = image.shape[0] / nrow
    dcol = image.shape[1] / ncol

    fig, ax = plt.subplots(1, 1)
    interval = vis.PercentileInterval(99.9)
    vmin, vmax = interval.get_limits(image)
    norm = vis.ImageNormalize(vmin=vmin,
                              vmax=vmax,
                              stretch=vis.LogStretch(1000))
    ax.imshow(image, cmap=plt.cm.Greys, norm=norm, origin='lower')
    for x in np.arange(0, image.shape[1], dcol):
        ax.axvline(x)
    for y in np.arange(0, image.shape[0], drow):
        ax.axhline(y)

    # Centre chosen by the click callback. Kept local (not module globals)
    # so a selection from a previous call can never leak into this one.
    selected = {}

    def onclick(event):
        if event.xdata is None or event.ydata is None:
            # Click landed outside the axes: nothing to select.
            return
        col = event.xdata // tile
        row = event.ydata // tile
        print(col, row)
        selected['x'] = tile / 2 + tile * col  # x of the tile centre
        selected['y'] = tile / 2 + tile * row  # y of the tile centre
        print('x: {:3.0f}, y: {:3.0f}'.format(selected['x'], selected['y']))
        if event.key == 'q':
            fig.canvas.mpl_disconnect(cid)

    cid = fig.canvas.mpl_connect('button_press_event', onclick)
    plt.show()

    if not selected:
        raise ValueError('no tile was selected: click inside the image '
                         'before closing the window')

    print(image.shape[0] / nrow)
    x = int(selected['x'])
    y = int(selected['y'])
    print(x, y)
    cutout = Cutout2D(image, (x, y),
                      size=(image.shape[0] / nrow - 20) * u.pixel,
                      wcs=wcs)
    return cutout
Beispiel #5
0
 def to_fig(self,
            rowrange,
            colrange,
            extension=1,
            cmap='Greys_r',
            cut=None,
            dpi=50):
     """Turn a FITS frame into a cropped, contrast-stretched matplotlib figure.

     Parameters
     ----------
     rowrange, colrange : (int, int)
         Slice bounds ``[start, stop)`` of the rows / columns to keep.
     extension : int
         HDU of ``self.fits_filename`` to read.
     cmap : str
         Matplotlib colormap name.
     cut : (float, float), optional
         Lower / upper display limits. Defaults to the 10th and 99.5th
         percentiles of the finite pixels of the cropped image.
     dpi : int
         Used only to derive the figure size. The figure's actual dpi is
         then chosen so that the figure is 440 pixels wide.

     Returns
     -------
     matplotlib.figure.Figure

     Raises
     ------
     InvalidFrameException
         If the whole HDU contains no finite pixel.
     """
     fts = fitsio.FITS(self.fits_filename)
     try:
         # Read the HDU once; the handle is released even if reading fails.
         data = fts[extension].read()
     finally:
         fts.close()
     if np.isfinite(data).sum() == 0:
         raise InvalidFrameException()
     image = data[rowrange[0]:rowrange[1], colrange[0]:colrange[1]]
     if cut is None:
         cut = np.percentile(image[np.isfinite(image)], [10, 99.5])
     transform = visualization.LogStretch() + visualization.ManualInterval(
         vmin=cut[0], vmax=cut[1])
     image_scaled = transform(image)
     px_per_kepler_px = 20
     dimensions = [
         image.shape[0] * px_per_kepler_px,
         image.shape[1] * px_per_kepler_px
     ]
     figsize = [dimensions[1] / dpi, dimensions[0] / dpi]
     # Override dpi so that figsize[0] * dpi == 440 pixels.
     dpi = 440 / float(figsize[0])
     fig = pl.figure(figsize=figsize, dpi=dpi)
     ax = fig.add_subplot(1, 1, 1)
     ax.matshow(image_scaled,
                aspect='auto',
                cmap=cmap,
                origin='lower',
                interpolation='nearest')
     ax.set_xticks([])
     ax.set_yticks([])
     ax.axis('off')
     fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
     fig.canvas.draw()
     return fig
Beispiel #6
0
def plot_image_fit_residuals(fig, image, fit, residuals=None, percentile=95.0):
    """
    Make a figure with three subplots showing the image, the fit and the
    residuals. The image and the fit are shown with logarithmic scaling and a
    common colorbar. The residuals are shown with linear scaling and a separate
    colorbar.

    Parameters:
        fig (fig object): Figure object in which to make the subplots.
        image (2D array): Image numpy array.
        fit (2D array): Fitted image numpy array.
        residuals (2D array, optional): Fitted image subtracted from image
            numpy array. Computed as ``image - fit`` when not given.
        percentile (float, optional): Width of the percentile interval used
            to derive the colour limits of every panel. Default=95.

    Returns:
        list: List with Matplotlib subplot axes objects for each subplot.
    """
    if residuals is None:
        residuals = image - fit

    interval = viz.PercentileInterval(percentile)

    # Common normalization for the image and fit panels, covering both ranges:
    vmin_image, vmax_image = interval.get_limits(image)
    vmin_fit, vmax_fit = interval.get_limits(fit)
    vmin = np.nanmin([vmin_image, vmin_fit])
    vmax = np.nanmax([vmax_image, vmax_fit])
    norm = viz.ImageNormalize(vmin=vmin, vmax=vmax, stretch=viz.LogStretch())

    # Add subplot with the image:
    ax1 = fig.add_subplot(131)
    im1 = plot_image(image, ax=ax1, scale=norm, cbar=None, title='Image')

    # Add subplot with the fit:
    ax2 = fig.add_subplot(132)
    plot_image(fit, ax=ax2, scale=norm, cbar=None, title='PSF fit')

    # Residuals use limits symmetric around zero so the diverging colormap
    # is centred:
    vmin, vmax = interval.get_limits(residuals)
    v = np.max(np.abs([vmin, vmax]))

    # Add subplot with the residuals:
    ax3 = fig.add_subplot(133)
    im3 = plot_image(residuals,
                     ax=ax3,
                     scale='linear',
                     cmap='seismic',
                     vmin=-v,
                     vmax=v,
                     cbar=None,
                     title='Residuals')

    # Make the common colorbar for image and fit subplots:
    cbar_ax12 = fig.add_axes([0.125, 0.2, 0.494, 0.03])
    fig.colorbar(im1, cax=cbar_ax12, orientation='horizontal')

    # Make the colorbar for the residuals subplot:
    cbar_ax3 = fig.add_axes([0.7, 0.2, 0.205, 0.03])
    fig.colorbar(im3, cax=cbar_ax3, orientation='horizontal')

    # Add more space between subplots. Act on ``fig`` itself, not on
    # whatever figure pyplot currently considers active.
    fig.subplots_adjust(wspace=0.4, hspace=0.4)

    return [ax1, ax2, ax3]
Beispiel #7
0
def plot_image(image,
               ax=None,
               scale='log',
               cmap=None,
               origin='lower',
               xlabel=None,
               ylabel=None,
               cbar=None,
               clabel='Flux ($e^{-}s^{-1}$)',
               cbar_ticks=None,
               cbar_ticklabels=None,
               cbar_pad=None,
               cbar_size='5%',
               title=None,
               percentile=95.0,
               vmin=None,
               vmax=None,
               offset_axes=None,
               color_bad='k',
               **kwargs):
    """
    Utility function to plot a 2D image.

    Parameters:
        image (2d array): Image data.
        ax (matplotlib.pyplot.axes, optional): Axes in which to plot.
            Default (None) is to use current active axes.
        scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional):
            Normalization used to stretch the colormap.
            Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'``
            and ``'squared'``.
            Can also be a :py:class:`astropy.visualization.ImageNormalize` (or any
            :py:class:`matplotlib.colors.Normalize`) object, which is used as-is.
            Default is ``'log'``.
        origin (str, optional): The origin of the coordinate system.
        xlabel (str, optional): Label for the x-axis.
        ylabel (str, optional): Label for the y-axis.
        cbar (string, optional): Location of color bar.
            Choices are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``.
            Default is not to create colorbar.
        clabel (str, optional): Label for the color bar.
        cbar_ticks (list, optional): Tick positions for the color bar.
        cbar_ticklabels (list, optional): Tick labels for the color bar.
        cbar_size (str or float, optional): Size of colorbar relative to the axes.
            Default='5%'.
        cbar_pad (float, optional): Padding between axes and colorbar.
        title (str or None, optional): Title for the plot.
        percentile (float, optional): The fraction of pixels to keep in color-trim.
            If single float given, the same fraction of pixels is eliminated from both ends.
            If tuple of two floats is given, the two are used as the percentiles.
            Default=95.
        cmap (matplotlib colormap, optional): Colormap to use. Default is the ``Blues`` colormap.
        vmin (float, optional): Lower limit to use for colormap.
        vmax (float, optional): Upper limit to use for colormap.
        offset_axes (tuple, optional): ``(column, row)`` offset added to the plotted
            pixel coordinates. Default is no offset.
        color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black.
        kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`.
            The deprecated ``make_cbar`` keyword is still accepted (with a
            FutureWarning) and used when ``cbar`` is not given.

    Returns:
        :py:class:`matplotlib.image.AxesImage`: Image from returned
            by :py:func:`matplotlib.pyplot.imshow`.

    Raises:
        ValueError: If ``scale`` is neither a supported name nor a Normalize object.

    .. codeauthor:: Rasmus Handberg <*****@*****.**>
    """
    import warnings

    logger = logging.getLogger(__name__)

    # Backward compatible settings. Raising the FutureWarning (as before)
    # made the fallback unreachable; warn and honour the old keyword instead.
    make_cbar = kwargs.pop('make_cbar', None)
    if make_cbar:
        warnings.warn("'make_cbar' is deprecated. Use 'cbar' instead.",
                      FutureWarning,
                      stacklevel=2)
        if not cbar:
            cbar = make_cbar

    # Special treatment for boolean arrays:
    if isinstance(image, np.ndarray) and image.dtype == 'bool':
        if vmin is None:
            vmin = 0
        if vmax is None:
            vmax = 1
        if cbar_ticks is None:
            cbar_ticks = [0, 1]
        if cbar_ticklabels is None:
            cbar_ticklabels = ['False', 'True']

    # Calculate limits of color scaling:
    interval = None
    if vmin is None or vmax is None:
        if allnan(image):
            logger.warning("Image is all NaN")
            vmin = 0
            vmax = 1
            if cbar_ticks is None:
                cbar_ticks = []
            if cbar_ticklabels is None:
                cbar_ticklabels = []
        elif isinstance(percentile, (list, tuple, np.ndarray)):
            interval = viz.AsymmetricPercentileInterval(
                percentile[0], percentile[1])
        else:
            interval = viz.PercentileInterval(percentile)

    # Create ImageNormalize object with extracted limits:
    if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh',
                 'squared'):
        if scale == 'log':
            stretch = viz.LogStretch()
        elif scale == 'linear':
            stretch = viz.LinearStretch()
        elif scale == 'sqrt':
            stretch = viz.SqrtStretch()
        elif scale == 'asinh':
            stretch = viz.AsinhStretch()
        elif scale == 'histeq':
            stretch = viz.HistEqStretch(image[np.isfinite(image)])
        elif scale == 'sinh':
            stretch = viz.SinhStretch()
        elif scale == 'squared':
            stretch = viz.SquaredStretch()

        # Create ImageNormalize object. Very important to use clip=False if the image contains
        # NaNs, otherwise NaN points will not be plotted correctly.
        norm = viz.ImageNormalize(data=image[np.isfinite(image)],
                                  interval=interval,
                                  vmin=vmin,
                                  vmax=vmax,
                                  stretch=stretch,
                                  clip=not anynan(image))

    elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)):
        norm = scale
    else:
        raise ValueError("scale {} is not available.".format(scale))

    # Pixel centres sit on integer coordinates, so the extent spans +-0.5.
    if offset_axes:
        extent = (offset_axes[0] - 0.5, offset_axes[0] + image.shape[1] - 0.5,
                  offset_axes[1] - 0.5, offset_axes[1] + image.shape[0] - 0.5)
    else:
        extent = (-0.5, image.shape[1] - 0.5, -0.5, image.shape[0] - 0.5)

    if ax is None:
        ax = plt.gca()

    # Set up the colormap to use. Copy it so set_bad below does not modify
    # the globally registered colormap:
    if cmap is None:
        cmap = copy.copy(plt.get_cmap('Blues'))
    elif isinstance(cmap, str):
        cmap = copy.copy(plt.get_cmap(cmap))

    if color_bad:
        cmap.set_bad(color_bad, 1.0)

    # Plotting the image using all the settings set above:
    im = ax.imshow(image,
                   cmap=cmap,
                   norm=norm,
                   origin=origin,
                   extent=extent,
                   interpolation='nearest',
                   **kwargs)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    ax.set_xlim([extent[0], extent[1]])
    ax.set_ylim([extent[2], extent[3]])

    if cbar:
        colorbar(im,
                 ax=ax,
                 loc=cbar,
                 size=cbar_size,
                 pad=cbar_pad,
                 label=clabel,
                 ticks=cbar_ticks,
                 ticklabels=cbar_ticklabels)

    # Settings for ticks. A locator binds to the Axis it is attached to, so
    # every axis (and major/minor slot) needs its own instance.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(MaxNLocator(nbins=10, integer=True))
        axis.set_minor_locator(MaxNLocator(nbins=10, integer=True))
    ax.tick_params(which='both', direction='out', pad=5)
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()

    return im
Beispiel #8
0
        gs = gridspec.GridSpec(1,
                               3,
                               height_ratios=[1],
                               width_ratios=[1, 0.05, 0.01])
        gs.update(left=0.05,
                  right=0.95,
                  bottom=0.12,
                  top=0.95,
                  wspace=0.01,
                  hspace=0.03)
        ax1 = plt.subplot(gs[0, 0])

        # TPF plot
        mean_tpf = np.mean(tpf.flux, axis=0)
        nx, ny = np.shape(mean_tpf)
        norm = ImageNormalize(stretch=stretching.LogStretch())
        division = np.int(
            np.log10(np.nanmax(np.nanmean(tpf.flux.value, axis=0))))
        image = np.nanmean(tpf.flux, axis=0) / 10**division
        splot = plt.imshow(image,norm=norm, \
            extent=[tpf.column,tpf.column+ny,tpf.row,tpf.row+nx],origin='lower', zorder=0)

        # Pipeline aperture
        if pipeline == "True":  #
            aperture_mask = tpf.pipeline_mask
            aperture = tpf._parse_aperture_mask(aperture_mask)
            maskcolor = 'tomato'
            print("    --> Using pipeline aperture...")
        else:
            aperture_mask = tpf.create_threshold_mask(threshold=10,
                                                      reference_pixel='center')
Beispiel #9
0
def plot_fits(img,
              header,
              figsize=(10, 10),
              fontsize=16,
              levels=(None, None),
              lognorm=False,
              title=None,
              show=True,
              cmap="viridis"):
    """
    Show a fits image. (c) Sunil Sumha's code

    Parameters
    ----------
    img: np.ndarray
        Image data
    header: fits.header.Header
        Fits image header; used to build the WCS projection of the axes.
    figsize: tuple of ints, optional
        Size of figure to be displayed (x,y)
    fontsize: int, optional
        Assigned to the global ``plt.rcParams['font.size']``; the setting
        persists after this call.
    levels: tuple of floats, optional
        Minimum and maximum pixel values for visualisation.
        A missing minimum defaults to the sigma-clipped median and a
        missing maximum to median + 10 sigma. If a given minimum exceeds a
        given maximum, the minimum is reset to maximum - 10 sigma with a
        warning.
    lognorm: bool, optional
        If true, the visualisation is log stretched between the levels.
    title: str, optional
        Title of the image
    show: bool, optional
        If true, displays the image.
        Else, returns the fig, ax
    cmap: str or pyplot cmap, optional
        Defaults to viridis

    Returns
    -------
    None if show is True; (fig, ax) if show is False.
    """
    import warnings

    from astropy.wcs import WCS
    from astropy.stats import sigma_clipped_stats
    from astropy import visualization as vis

    plt.rcParams['font.size'] = fontsize
    wcs = WCS(header)

    _, median, sigma = sigma_clipped_stats(img)

    assert len(levels) == 2, "Invalid levels. Use this format: (vmin,vmax)"
    vmin, vmax = levels

    if vmin is None:
        vmin = median
    if vmax is not None:
        if vmin > vmax:
            vmin = vmax - 10 * sigma
            warnings.warn(
                "levels changed to ({:f},{:f}) because input vmin was greater than vmax"
                .format(vmin, vmax))
    else:
        vmax = median + 10 * sigma

    fig = plt.figure(figsize=figsize)
    ax = plt.subplot(projection=wcs)

    if lognorm:
        # The limits must live inside the norm: Matplotlib rejects vmin/vmax
        # passed alongside a Normalize instance.
        norm = vis.ImageNormalize(vmin=vmin,
                                  vmax=vmax,
                                  stretch=vis.LogStretch())
        ax.imshow(img, norm=norm, cmap=cmap)
    else:
        ax.imshow(img, vmax=vmax, vmin=vmin, cmap=cmap)
    ax.set_xlabel("RA")
    ax.set_ylabel("Dec")
    ax.set_title(title)
    if show:
        plt.show()
    else:
        return fig, ax
Beispiel #10
0
img2.data = img2.data * photflam / exptime / 0.0455**2  # get into units of erg/s/cm^2/A/arcsec^2

# Zoom into region of interest
cut_ctr = SkyCoord('12h18m57.5s 47d18m14s')
cut_dims = np.array([4.0, 4.0]) * u.arcmin
cut = Cutout2D(img2.data, cut_ctr, cut_dims, wcs=img.wcs)

# Plot first subplot: raw data gathered by the telescope.
# cut.data starts at the cutout's corner, so the axes must use the cutout's
# own WCS (cut.wcs); the parent img.wcs would label wrong sky coordinates.
plt.subplot(131, projection=cut.wcs)
plt.imshow(cut.data, origin='lower', cmap='plasma')
plt.grid(color='yellow', ls='solid')
plt.title('Raw Telescope Data', weight='bold')
plt.ylabel('Declination (J2000)')

# Features of raw data are hard to see, so stretch the values
# (log stretch between fixed limits 0 and 5e-19). Note: this overwrites
# cut.data in place.
trans = viz.LogStretch() + viz.ManualInterval(0, 5e-19)
cut.data = trans(cut.data)

# Plot second subplot: enhanced so all bright regions are more visible
plt.subplot(132, projection=cut.wcs)
plt.imshow(cut.data, origin='lower', cmap="plasma")
plt.grid(color='yellow', ls='solid')
plt.title('Enhanced', weight='bold')
plt.xlabel('Right Ascension (J2000)')


#Gaussian filter to supress point sources of light, like other stars
def destar(I, sigma, t):
    D = np.zeros_like(I)
    B = gauss(I, sigma)
    M = I - B
Beispiel #11
0
 def set_normalization(self,
                       stretch=None,
                       interval=None,
                       stretchkwargs=None,
                       intervalkwargs=None,
                       perm_linear=None):
     """Configure the stretch and interval used to normalize ``self.data``.

     String names are turned into ``astropy.visualization`` objects only once
     ``self.data`` is set. Until then only the extra keyword arguments are
     stored, so a later call can build the objects.

     Parameters
     ----------
     stretch : str or stretch object, optional
         Defaults to the current ``self.stretch`` or ``'linear'``. Without
         ``perm_linear`` only ``'linear'`` can be built. With it, one of
         'asinh', 'contrastbias', 'histogram', 'log', 'powerdist', 'power',
         'sinh', 'sqrt' or 'square' is combined with a ``LinearStretch`` in a
         ``CompositeStretch``. Non-string values are stored unchanged.
     interval : str or interval object, optional
         One of 'minmax', 'manual', 'percentile', 'asymetric' or 'zscale'.
         Defaults to the current ``self.interval`` or ``'zscale'``.
         Non-string values are stored unchanged.
     stretchkwargs, intervalkwargs : dict, optional
         Extra keyword arguments, combined with the stored defaults via
         ``self.prepare_kwargs`` (presumably later dicts override earlier
         ones -- confirm).
     perm_linear : dict, optional
         Keyword arguments for the ``LinearStretch`` composed with the named
         stretch.

     Raises
     ------
     KeyError
         If a name has no entry in ``self.stretch_kws_defaults`` /
         ``self.interval_kws_defaults``.
     ValueError
         If a name has defaults but no matching branch, e.g. a non-linear
         stretch without ``perm_linear``.

     Notes
     -----
     If ``self.img`` is set, its norm is replaced by an ``ImageNormalize``
     with the new stretch and interval and ``clip=True``.
     """
     # Fresh dicts per call: these objects may be stored on the instance
     # below, so a shared ``{}`` default would leak mutations between calls
     # and between instances.
     if stretchkwargs is None:
         stretchkwargs = {}
     if intervalkwargs is None:
         intervalkwargs = {}
     if stretch is None:
         if self.stretch is None:
             stretch = 'linear'
         else:
             stretch = self.stretch
     if isinstance(stretch, str):
         print(stretch,
               ' '.join([f'{k}={v}' for k, v in stretchkwargs.items()]))
         if self.data is None:  # can not calculate objects yet
             self.stretch_kwargs = stretchkwargs
         else:
             kwargs = self.prepare_kwargs(
                 self.stretch_kws_defaults[stretch], self.stretch_kwargs,
                 stretchkwargs)
             if perm_linear is not None:
                 perm_linear_kwargs = self.prepare_kwargs(
                     self.stretch_kws_defaults['linear'], perm_linear)
                 print(
                     'linear', ' '.join([
                         f'{k}={v}' for k, v in perm_linear_kwargs.items()
                     ]))
                 if stretch == 'asinh':  # arg: a=0.1
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.AsinhStretch(**kwargs))
                 elif stretch == 'contrastbias':  # args: contrast, bias
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.ContrastBiasStretch(**kwargs))
                 elif stretch == 'histogram':
                     stretch = vis.CompositeStretch(
                         vis.HistEqStretch(self.data, **kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'log':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LogStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'powerdist':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.PowerDistStretch(**kwargs))
                 elif stretch == 'power':  # args: a
                     stretch = vis.CompositeStretch(
                         vis.PowerStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'sinh':  # args: a=0.33
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SinhStretch(**kwargs))
                 elif stretch == 'sqrt':
                     stretch = vis.CompositeStretch(
                         vis.SqrtStretch(),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'square':
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SquaredStretch())
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
             else:
                 if stretch == 'linear':  # args: slope=1, intercept=0
                     stretch = vis.LinearStretch(**kwargs)
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
     self.stretch = stretch
     if interval is None:
         if self.interval is None:
             interval = 'zscale'
         else:
             interval = self.interval
     if isinstance(interval, str):
         print(interval,
               ' '.join([f'{k}={v}' for k, v in intervalkwargs.items()]))
         kwargs = self.prepare_kwargs(self.interval_kws_defaults[interval],
                                      self.interval_kwargs, intervalkwargs)
         if self.data is None:
             self.interval_kwargs = intervalkwargs
         else:
             if interval == 'minmax':
                 interval = vis.MinMaxInterval()
             elif interval == 'manual':  # args: vmin, vmax
                 interval = vis.ManualInterval(**kwargs)
             elif interval == 'percentile':  # args: percentile, n_samples
                 interval = vis.PercentileInterval(**kwargs)
             elif interval == 'asymetric':  # args: lower_percentile, upper_percentile, n_samples
                 interval = vis.AsymmetricPercentileInterval(**kwargs)
             elif interval == 'zscale':  # args: nsamples=1000, contrast=0.25, max_reject=0.5, min_npixels=5, krej=2.5, max_iterations=5
                 interval = vis.ZScaleInterval(**kwargs)
             else:
                 raise ValueError('Unknown interval:' + interval)
     self.interval = interval
     if self.img is not None:
         self.img.set_norm(
             vis.ImageNormalize(self.data,
                                interval=self.interval,
                                stretch=self.stretch,
                                clip=True))
Beispiel #12
0
def _plot_phased_panel(row, title, phase, flux_phase, model):
    """Draw one phase-folded light-curve panel covering phases 0-2.

    Binned points (``tul.avg_array`` with 50, presumably points per bin --
    confirm), their 15-point ``tul.running_mean`` and the model curve are
    drawn once, then repeated shifted by +1 in phase.
    """
    phi_avg = tul.avg_array(phase, 50)
    fphi_avg = tul.avg_array(flux_phase, 50)

    plt.subplot2grid((6, 10), (row, 4), colspan=6, rowspan=3)
    plt.title(title)
    plt.xlabel('Phase')
    plt.ylabel('Relative flux')
    plt.xlim(0, 2)
    for shift in (0.0, 1.0):
        plt.scatter(phi_avg + shift,
                    fphi_avg,
                    marker='.',
                    color='0.5',
                    zorder=0)
        plt.plot(tul.running_mean(phi_avg, 15) + shift,
                 tul.running_mean(fphi_avg, 15),
                 '.k',
                 zorder=1)
        plt.plot(phase + shift, model, 'r--', lw=3, zorder=2)


def make_plot(f, pow, fap, bjd0, flux0, bjd, flux, phi, flux_phi, fit, phi2,
              flux_phi2, fit2, period, crowd, ns):
    """Build the six-panel summary figure for one TIC target.

    Panels:

    - mean TPF image with sources overplotted;
    - Gaia HR diagram;
    - periodogram plotted against period ``1/f`` (axis labelled in hours);
    - light curve;
    - light curve folded on the dominant period;
    - light curve folded on twice that period.

    Parameters
    ----------
    f, pow : arrays
        Frequency grid and periodogram power (``pow`` shadows the builtin
        but is kept for interface compatibility).
    fap : float
        Power level drawn as a horizontal line, presumably the false-alarm
        threshold -- confirm.
    bjd0, flux0 : arrays
        Light curve drawn in grey underneath ``bjd, flux`` (black).
    phi, flux_phi, fit : arrays
        Phases, fluxes and model for the dominant-period fold.
    phi2, flux_phi2, fit2 : arrays
        Same for the fold on twice the period.
    period : float
        Period shown in the periodogram title (in hours) and marked there.
    crowd : array
        Its mean is printed on the TPF panel as 'crowdsap'.
    ns :
        Shown in the light-curve title as the number of sectors.

    Returns
    -------
    matplotlib.figure.Figure

    Notes
    -----
    Also reads the module-level names TIC, tpf, warning, coords, sizes, idx,
    s_bprp, s_MG, gaia, bprp_all, MG_all, bprp and MG. It also updates the
    global ``plt.rcParams['font.size']``.
    """
    fig = plt.figure(figsize=(24, 15))

    plt.rcParams.update({'font.size': 22})

    # Target pixel file panel.
    plt.subplot2grid((6, 10), (0, 0), colspan=2, rowspan=2)
    plt.title('TIC %d' % (TIC))
    mean_flux = np.nanmean(tpf.flux, axis=0)
    nx, ny = np.shape(mean_flux)
    norm = ImageNormalize(stretch=stretching.LogStretch())
    # Rescale by a power of ten. Builtin int replaces np.int, which was
    # removed in NumPy 1.24.
    division = int(np.log10(np.nanmax(mean_flux)))
    plt.imshow(mean_flux / 10**division,
               norm=norm,
               extent=[tpf.column, tpf.column + ny, tpf.row, tpf.row + nx],
               origin='lower',
               zorder=0)
    plt.xlim(tpf.column, tpf.column + 10)
    plt.ylim(tpf.row, tpf.row + 10)
    if not warning:
        x = coords[:, 0] + tpf.column + 0.5
        y = coords[:, 1] + tpf.row + 0.5
        plt.scatter(x, y, c='firebrick', alpha=0.5, edgecolors='r', s=sizes)
        plt.scatter(x, y, c='None', edgecolors='r', s=sizes)
        plt.scatter(x[idx], y[idx], marker='x', c='white')
    plt.text(tpf.column,
             tpf.row,
             'crowdsap = %4.2f' % np.mean(crowd),
             color='w')
    plt.ylabel('Pixel count')
    plt.xlabel('Pixel count')

    # Gaia HR-diagram panel.
    plt.subplot2grid((6, 10), (0, 2), colspan=2, rowspan=2)
    plt.scatter(s_bprp, s_MG, c='0.75', s=0.5, zorder=0)
    if len(gaia) > 1:
        plt.scatter(bprp_all, MG_all, marker='s', c='b', s=10, zorder=1)
    plt.gca().invert_yaxis()
    plt.title('$Gaia$ HR-diagram')
    if not warning:
        plt.plot(bprp, MG, 'or', markersize=10, zorder=2)
    plt.ylabel('$M_G$')
    plt.xlabel('$G_{BP}-G_{RP}$')

    # Periodogram panel (x axis is period, 1/f).
    plt.subplot2grid((6, 10), (2, 0), colspan=4, rowspan=2)
    plt.title("Period = %5.2f h" % period)
    plt.plot(1.0 / f, pow, color='k')
    plt.xlim(min(1.0 / f), max(1.0 / f))
    plt.axhline(fap, color='b')
    plt.axvline(period, color='r', ls='--', zorder=0)
    plt.xscale('log')
    plt.xlabel('P [h]')
    plt.ylabel('Power')

    # Light-curve panel.
    plt.subplot2grid((6, 10), (4, 0), colspan=4, rowspan=2)
    plt.title('%s sector/s' % ns)
    plt.xlabel("BJD - 2457000")
    plt.ylabel('Relative flux')
    plt.xlim(np.min(bjd), np.max(bjd))
    plt.scatter(bjd0, flux0, c='0.25', zorder=1, s=0.5)
    plt.scatter(bjd, flux, c='k', zorder=1, s=0.5)

    # Phase-folded panels.
    _plot_phased_panel(0, 'Phased to dominant peak', phi, flux_phi, fit)
    _plot_phased_panel(3, 'Phased to twice the peak', phi2, flux_phi2, fit2)

    plt.tight_layout()

    return fig
Beispiel #13
0
def plot_image_and_lines(cube,
                         wavs,
                         xrange,
                         yrange,
                         Hbeta_ref=None,
                         title='',
                         filename=None,
                         include_OIII=False):
    """Fit the H-beta (and optionally [OIII]) lines of a cube region and plot.

    The figure has four panels:
    1. the mean image around H-beta with the fitted region outlined;
    2. the H-beta line, in wavelength or in velocity if ``Hbeta_ref`` is given;
    3. the [OIII] 4959 line;
    4. the [OIII] 5007 line.

    Parameters
    ----------
    cube : array
        Data cube sliced as ``cube[z, y, x]``. ``z`` indexes the planes whose
        wavelengths are listed in ``wavs``.
    wavs : sequence of float
        Wavelength of each ``z`` plane. The hard-coded model bounds
        (4855-5012) assume Angstroms.
    xrange, yrange : (int, int)
        Inclusive pixel bounds of the region whose mean spectrum is fitted.
    Hbeta_ref : Quantity, optional
        Reference wavelength passed to ``u.doppler_optical``. When given,
        the H-beta panel is plotted in km/s.
    title : str, optional
        Title of the image panel.
    filename : str, optional
        If given, the figure is saved there. Otherwise ``plt.show()`` is called.
    include_OIII : bool, optional
        Also fit Gaussians to [OIII] 4959 and 5007.

    Returns
    -------
    spect : list
        Mean flux of the region for every ``z`` plane.
    model : astropy model
        The fitted compound model. Component 1 is H-beta; components 2 and 3
        are [OIII] 4959/5007 when ``include_OIII`` is set.

    Raises
    ------
    ValueError
        If no entry of ``wavs`` lies within 5 of ``h_beta_std``.
    """
    # Normalise once so e.g. numpy.bool_(True) also enables the [OIII] fit.
    include_OIII = bool(include_OIII)

    wavs_arr = np.asarray(wavs)
    lambda_delta = 5
    hbeta_z = np.where((wavs_arr > h_beta_std.value - lambda_delta)
                       & (wavs_arr < h_beta_std.value + lambda_delta))[0]
    if hbeta_z.size == 0:
        raise ValueError(f'No wavelengths within {lambda_delta} of H-beta '
                         f'({h_beta_std.value}); cannot build the image')
    image = np.mean(cube[hbeta_z.min():hbeta_z.max() + 1, :, :], axis=0)

    # Mean spectrum over the (inclusive) spatial region.
    spect = [
        np.mean(cube[z, yrange[0]:yrange[1] + 1, xrange[0]:xrange[1] + 1])
        for z in range(cube.shape[0])
    ]
    i_peak = spect.index(max(spect))

    background_0 = models.Polynomial1D(degree=2)
    H_beta_0 = models.Gaussian1D(amplitude=500,
                                 mean=4861,
                                 stddev=1.,
                                 bounds={
                                     'mean': (4855, 4865),
                                     'stddev': (0.1, 5)
                                 })
    OIII4959_0 = models.Gaussian1D(amplitude=100,
                                   mean=4959,
                                   stddev=1.,
                                   bounds={
                                       'mean': (4955, 4965),
                                       'stddev': (0.1, 5)
                                   })
    OIII5007_0 = models.Gaussian1D(amplitude=200,
                                   mean=5007,
                                   stddev=1.,
                                   bounds={
                                       'mean': (5002, 5012),
                                       'stddev': (0.1, 5)
                                   })
    fitter = fitting.LevMarLSQFitter()
    # Component order (0 background, 1 H-beta, 2 OIII4959, 3 OIII5007)
    # determines the mean_N / stddev_N parameter names used below.
    model0 = background_0 + H_beta_0
    if include_OIII:
        model0 = model0 + OIII4959_0 + OIII5007_0

    # Start the H-beta centroid at the brightest plane.
    model0.mean_1 = wavs[i_peak]
    model = fitter(model0, wavs, spect)

    plt.figure(figsize=(20, 8))

    # Panel 1: image with the fitted region outlined.
    plt.subplot(1, 4, 1)
    plt.title(title)
    norm = v.ImageNormalize(image,
                            interval=v.MinMaxInterval(),
                            stretch=v.LogStretch(1))
    plt.imshow(image, origin='lower', norm=norm)
    region_x = [
        xrange[0] - 0.5, xrange[1] + 0.5, xrange[1] + 0.5, xrange[0] - 0.5,
        xrange[0] - 0.5
    ]
    region_y = [
        yrange[0] - 0.5, yrange[0] - 0.5, yrange[1] + 0.5, yrange[1] + 0.5,
        yrange[0] - 0.5
    ]
    plt.plot(region_x, region_y, 'r-', alpha=0.5, lw=2)

    def _velocity(lams):
        """Convert wavelengths (Angstrom) to optical-Doppler km/s."""
        return (np.asarray(lams) * u.Angstrom).to(
            u.km / u.s, equivalencies=u.doppler_optical(Hbeta_ref)).value

    # Panel 2: H-beta.
    plt.subplot(1, 4, 2)
    hb_mean = model.mean_1.value
    if Hbeta_ref is not None:
        Hbeta_velocity = (hb_mean * u.Angstrom).to(
            u.km / u.s, equivalencies=u.doppler_optical(Hbeta_ref))
        plt.title(f'H-beta ({hb_mean:.1f} A, v={Hbeta_velocity.value:.1f} km/s)')
    else:
        plt.title(f'H-beta ({hb_mean:.1f} A, sigma={model.stddev_1.value:.3f} A)')
    w = list(np.arange(4856, 4866, 0.05))
    if Hbeta_ref is not None:
        plt.plot(_velocity(wavs), spect, drawstyle='steps-mid', label='data')
        plt.plot(_velocity(w), model(w), 'r-', alpha=0.7, label='Fit')
        plt.xlabel('Velocity (km/s)')
        plt.xlim(-200, 200)
    else:
        plt.plot(wavs, spect, drawstyle='steps-mid', label='data')
        plt.plot(w, model(w), 'r-', alpha=0.7, label='Fit')
        plt.xlabel('Wavelength (angstroms)')
        plt.xlim(4856, 4866)
    plt.grid()
    plt.ylabel('Flux')
    plt.legend(loc='best')

    def _oiii_panel(position, label, index, lo, hi):
        """Draw one [OIII] panel; ``index`` is the model component number."""
        plt.subplot(1, 4, position)
        if include_OIII:
            mean = getattr(model, f'mean_{index}').value
            sigma = getattr(model, f'stddev_{index}').value
            plt.title(f'{label} ({mean:.1f} A, sigma={sigma:.3f} A)')
        else:
            plt.title(label)
        plt.plot(wavs, spect, drawstyle='steps-mid', label='data')
        w_panel = list(np.arange(lo, hi, 0.05))
        plt.plot(w_panel, model(w_panel), 'r-', alpha=0.7, label='Fit')
        plt.xlabel('Wavelength (angstroms)')
        plt.ylabel('Flux')
        plt.legend(loc='best')
        plt.xlim(lo, hi)

    # Panels 3 and 4: [OIII] 4959 and 5007.
    _oiii_panel(3, 'OIII 4959', 2, 4954, 4964)
    _oiii_panel(4, 'OIII 5007', 3, 5002, 5012)

    if filename is not None:
        plt.savefig(filename, bbox_inches='tight', pad_inches=0.10)
    else:
        plt.show()

    return spect, model
Beispiel #14
0
def timegen(in_path, out_path, m, n, cell, stretch, full_hd):
    """Generate a timelapse from the FITS files in a directory.

    parameters
    ----------
    in_path  : Directory containing the input FITS files (*.fits or *.fits.fz).
               Files whose names contain 'background.fits' or
               'pointing00.fits' are skipped.
    out_path : Path passed to ``generate_timelapse_from_images`` for the output.
    m        : Number of rows to split each image into.
    n        : Number of columns to split each image into.
    cell     : Grid cell to keep, as (row, column) indices, e.g. (0, 1).
    stretch  : Name of the stretch to apply. One of
               'TG_LOG_1_PERCENTILE_99', 'TG_SQRT_PERCENTILE_99',
               'TG_LOG_PERCENTILE_99', 'TG_ASINH_1_PERCENTILE_99',
               'TG_ASINH_PERCENTILE_99', 'TG_SQUARE_PERCENTILE_99' or
               'TG_SINH_1_PERCENTILE_99'. Any other value uses SinhStretch().
               Every stretch is combined with a 99th-percentile interval.
    full_hd  : Force the video to be 1920 * 1080 pixels.

    returns
    -------
    True. Unreadable or unsavable frames are logged and skipped rather than
    reported through the return value.
    """
    # Step 1: Collect FITS files, dropping background and pointing frames.
    excluded = ('background.fits', 'pointing00.fits')
    fits_files = [
        fname for fname in get_file_list(in_path, ['fits', 'fz'])
        if not any(tag in fname for tag in excluded)
    ]

    # Step 2: Choose the transform. 'TG_LOG_1_PERCENTILE_99' must be listed
    # explicitly; otherwise it would hit the SinhStretch fallback below.
    stretch_factories = {
        'TG_LOG_1_PERCENTILE_99': lambda: v.LogStretch(1),
        'TG_SQRT_PERCENTILE_99': lambda: v.SqrtStretch(),
        'TG_LOG_PERCENTILE_99': lambda: v.LogStretch(),
        'TG_ASINH_1_PERCENTILE_99': lambda: v.AsinhStretch(1),
        'TG_ASINH_PERCENTILE_99': lambda: v.AsinhStretch(),
        'TG_SQUARE_PERCENTILE_99': lambda: v.SquaredStretch(),
        'TG_SINH_1_PERCENTILE_99': lambda: v.SinhStretch(1),
    }
    make_stretch = stretch_factories.get(stretch, lambda: v.SinhStretch())
    transform = make_stretch() + v.PercentileInterval(99)

    # Step 3: Process each frame and save it to the temporary directory.
    for file in tqdm.tqdm(fits_files):
        try:
            fits_data = fits.getdata(file)
        except Exception as e:
            # If the current FITS file can't be opened, log and skip it.
            logging.error(str(e))
            continue
        # Flip vertically, debayer with an RGGB pattern and cut out the cell.
        flipped_data = np.flipud(fits_data)
        rgb_data = debayer_image_array(flipped_data, pattern='RGGB')
        interested_data = get_sub_image(rgb_data, m, n, cell[0], cell[1])
        # Stretch to [0, 1], then scale to the 8-bit range.
        interested_data = 255 * transform(interested_data)
        rgb_data = interested_data.astype(np.uint8)
        # NOTE(review): bgr_data is never used and the float
        # interested_data is what gets saved; confirm which array
        # save_image expects.
        bgr_data = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2BGR)
        try:
            save_image(interested_data,
                       os.path.split(file)[-1].split('.')[0],
                       path=temp_dir)
        except Exception as e:
            logging.error(str(e))

    # Step 4: Create the timelapse from the temporary files.
    # NOTE(review): frames are written to temp_dir but read back from the
    # literal 'temp_timelapse'; confirm these are the same directory.
    generate_timelapse_from_images('temp_timelapse', out_path, hd_flag=full_hd)

    # Delete temporary files (best effort).
    try:
        clear_dir(temp_dir)
    except Exception as e:
        print('Clearing TEMP Files failed. See log for more details')
        logging.error(str(e))
    return True
Beispiel #15
0
def log(image_name):
    """Interactively log-stretch one HDU of a FITS file until the user accepts.

    The file is opened, with one re-prompt for the name if it is missing or
    empty, and its HDU list is printed. The user picks an HDU, whose
    min/max/median are printed. The user is then repeatedly asked for a
    ``LogStretch`` parameter. From the second try on, the new parameter is
    shown next to the previous one. Answering 'Y'/'y' returns the image
    normalised with the most recently entered parameter.

    Parameters
    ----------
    image_name : str
        Path of the FITS file to open.

    Returns
    -------
    numpy.ndarray
        ``ImageNormalize(...)(image)`` for the accepted parameter.

    Notes
    -----
    A TypeError or IndexError while selecting or displaying the HDU prints
    a message and restarts from opening the file.
    """
    while True:
        try:
            try:
                hdul = fits.open(image_name)
            except (ValueError, FileNotFoundError):
                image_name = input(
                    "\nfile missing or empty file name !!! \nPlease re-enter file name : "
                )
                hdul = fits.open(image_name)
            # The context manager closes the file even when the HDU index is
            # invalid (IndexError).
            with hdul:
                hdul.info()
                header_number = int(
                    input("Enter Header number whose data  you want view : "))
                image = hdul[header_number].data
            #printing stats about the data
            print("\n", " Minimum Value = ", np.min(image),
                  "\t Maximum Value = ", np.max(image), "\t Meadian Value = ",
                  np.median(image))

            def make_norm(a):
                """LogStretch(a) normalisation with the fixed vmin/vmax recipe."""
                return viz.ImageNormalize(
                    image,
                    vmin=((np.median(image))**2 - abs(np.min(image))),
                    vmax=50,
                    stretch=viz.LogStretch(a))

            total_count = 1
            previous_parameter = 0
            while True:
                ##stretching and normalizing using LogStretch() and MinMaxInterval() like in DS9
                log_param = float(
                    input("Enter base value for logrithmic stretch : "))
                norm = make_norm(log_param)
                if total_count > 1:
                    # Left: the new parameter. Right: the previous one, for
                    # comparison. `norm` keeps the new parameter.
                    plt.subplot(1, 2, 1)
                    plt.imshow(image, cmap='gray', norm=norm)
                    plt.title('a=' + str(log_param))
                    plt.grid(True)
                    plt.subplot(1, 2, 2)
                    plt.imshow(image,
                               cmap='gray',
                               norm=make_norm(previous_parameter))
                    plt.title('a=' + str(previous_parameter))
                    plt.grid(True)
                else:
                    plt.imshow(image, cmap='gray', norm=norm)
                plt.show()
                ch = input(
                    "Are you happy with your choice of log_parameters(Y/N) : ")
                if ch == 'Y' or ch == 'y':
                    print("Stretched Image stored temporarily!!! \n")
                    return norm(image)
                total_count += 1
                previous_parameter = log_param
        except (TypeError):
            print("INCORRECT header chosen for viewing the data !!!! ")
            print("Please enter correct header number!!!\n")
        except (IndexError):
            print("HEADER UNIT not found!!!\nPlease recheck!!!\n")
Beispiel #16
0
def get_image(data, fmt='JPEG', norm='percentile', lo=None, hi=None,
              zcontrast=0.25, nsamples=1000, krej=2.5, max_iterations=5,
              stretch='linear', a=None, bias=0.5, contrast=1, cmap=None,
              dpi=100, **kwargs):
    """
    Return a byte array containing image in the given format

    Image scaling is done using `~astropy.visualization`. It includes
    normalization of the input data (mapping to [0, 1]) and stretching -
    optional non-linear mapping [0, 1] -> [0, 1] for contrast enhancement.
    A colormap can be applied to the normalized data. Conversion to the target
    image format is done by matplotlib or Pillow.

    If the clipping limits coincide (``hi == lo``, e.g. a flat image with
    ``norm="minmax"``), the linear mapping is undefined. Pixels above the
    limit are then set to 1 and all others to 0.

    :param array_like data: input 2D image data
    :param str fmt: output image format
    :param str norm: data normalization mode::
        "manual": lower and higher clipping limits are set explicitly
        "minmax": limits are set to the minimum and maximum data values
        "percentile" (default): limits are set based on the specified fraction
            of pixels
        "zscale": use IRAF ZScale algorithm
    :param int | float lo::
        for ``norm`` == "manual", lower data limit
        for ``norm`` == "percentile", lower percentile clipping value,
            defaulting to 10
        for ``norm`` == "zscale", lower limit on the number of rejected pixels,
            defaulting to 5
    :param int | float hi::
        for ``norm`` == "manual", upper data limit
        for ``norm`` == "percentile", upper percentile clipping value,
            defaulting to 98
        for ``norm`` == "zscale", upper limit on the number of rejected pixels,
            defaulting to data.size/2
    :param float zcontrast: for ``norm`` == "zscale", the scaling factor,
        0 < zcontrast < 1, defaulting to 0.25
    :param int nsamples: for ``norm`` == "zscale", the number of points in
        the input array for determining scale factors, defaulting to 1000
    :param float krej: for ``norm`` == "zscale", the sigma clipping factor,
        defaulting to 2.5
    :param int max_iterations: for ``norm`` == "zscale", the maximum number
        of rejection iterations, defaulting to 5
    :param str stretch: [0, 1] → [0, 1] mapping mode::
        "asinh": hyperbolic arcsine stretch y = asinh(x/a)/asinh(1/a)
        "contrast": linear bias/contrast-based stretch
            y = (x - bias)*contrast + 0.5
        "exp": exponential stretch y = (a^x - 1)/(a - 1)
        "histeq": histogram equalization stretch
        "linear" (default): direct mapping
        "log": logarithmic stretch y = log(ax + 1)/log(a + 1)
        "power": power stretch y = x^a
        "sinh": hyperbolic sine stretch y = sinh(x/a)/sinh(1/a)
        "sqrt": square root stretch y = √x
        "square": power stretch y = x^2
    :param float a: non-linear stretch parameter::
        for ``stretch`` == "asinh", the point of transition from linear to
            logarithmic behavior, 0 < a <= 1, defaulting to 0.1
        for ``stretch`` == "exp", base of the exponent, a != 1, defaulting to
            1000
        for ``stretch`` == "log", base of the logarithm minus 1, a > 0,
            defaulting to 1000
        for ``stretch`` == "power", the power index, defaulting to 3
        for ``stretch`` == "sinh", a > 0, defaulting to 1/3
    :param float bias: for ``stretch`` == "contrast", the bias parameter,
        defaulting to 0.5
    :param float contrast: for ``stretch`` == "contrast", the contrast
        parameter, defaulting to 1
    :param str cmap: optional matplotlib colormap name, defaulting
        to grayscale; when a non-grayscale colormap is specified,
        the conversion is always done by matplotlib, regardless of the
        availability of Pillow; see https://matplotlib.org/users/colormaps.html
        for more info on matplotlib colormaps and
            [name for name in matplotlib.cm.cmap_d.keys()
             if not name.endswith('_r')]
        to list the available colormap names
    :param int dpi: target image resolution in dots per inch
    :param kwargs: optional format-specific keyword arguments passed to Pillow,
        e.g. "quality" for JPEG; see
        `https://pillow.readthedocs.io/en/stable/handbook/
        image-file-formats.html`_

    :return: a bytes object containing the image in the given format
    :rtype: bytes
    :raises ValueError: on an unknown ``norm``/``stretch``, missing manual
        limits, or out-of-range / inverted percentiles
    """
    data = asanyarray(data)

    # Normalize image data
    if norm == 'manual':
        if lo is None:
            raise ValueError(
                'Missing lower clipping boundary for norm="manual"')
        if hi is None:
            raise ValueError(
                'Missing upper clipping boundary for norm="manual"')
    elif norm == 'minmax':
        lo, hi = data.min(), data.max()
    elif norm == 'percentile':
        if lo is None:
            lo = 10
        elif not 0 <= lo <= 100:
            raise ValueError(
                'Lower clipping percentile must be in the [0,100] range')
        if hi is None:
            hi = 98
        elif not 0 <= hi <= 100:
            raise ValueError(
                'Upper clipping percentile must be in the [0,100] range')
        if hi < lo:
            raise ValueError(
                'Upper clipping percentile must be greater or equal to '
                'lower percentile')
        lo, hi = percentile(data, [lo, hi])
    elif norm == 'zscale':
        # ZScaleInterval takes max_reject as a fraction of the pixel count
        # and min_npixels as an absolute count.
        if lo is None:
            lo = 5
        if hi is None:
            hi = 0.5
        else:
            hi /= data.size
        lo, hi = apy_vis.ZScaleInterval(
            nsamples, zcontrast, hi, lo, krej, max_iterations).get_limits(data)
    else:
        raise ValueError('Unknown normalization mode "{}"'.format(norm))

    span = hi - lo
    if span:
        data = clip((data - lo)/span, 0, 1)
    else:
        # Degenerate range: dividing by zero would produce NaNs, which turn
        # into arbitrary bytes after the uint8 conversion below. Use the
        # limiting step function of the clipped linear map instead.
        data = (data > lo).astype(float)

    # Stretch the data
    if stretch == 'asinh':
        if a is None:
            a = 0.1
        apy_vis.AsinhStretch(a)(data, out=data)
    elif stretch == 'contrast':
        if bias != 0.5 or contrast != 1:
            apy_vis.ContrastBiasStretch(contrast, bias)(data, out=data)
    elif stretch == 'exp':
        if a is None:
            a = 1000
        apy_vis.PowerDistStretch(a)(data, out=data)
    elif stretch == 'histeq':
        apy_vis.HistEqStretch(data)(data, out=data)
    elif stretch == 'linear':
        pass
    elif stretch == 'log':
        if a is None:
            a = 1000
        apy_vis.LogStretch(a)(data, out=data)
    elif stretch == 'power':
        if a is None:
            a = 3
        apy_vis.PowerStretch(a)(data, out=data)
    elif stretch == 'sinh':
        if a is None:
            a = 1/3
        apy_vis.SinhStretch(a)(data, out=data)
    elif stretch == 'sqrt':
        apy_vis.SqrtStretch()(data, out=data)
    elif stretch == 'square':
        apy_vis.SquaredStretch()(data, out=data)
    else:
        raise ValueError('Unknown stretch mode "{}"'.format(stretch))

    buf = BytesIO()
    try:
        # Choose the backend for making an image
        if cmap is None:
            cmap = 'gray'
        if cmap == 'gray':
            try:
                # noinspection PyPackageRequirements,PyPep8Naming
                from PIL import Image as pil_image
            except ImportError:
                pil_image = None
        else:
            pil_image = None

        if pil_image is not None:
            # Use Pillow for grayscale output if available; flip the image to
            # match the bottom-to-top FITS convention and convert from [0,1] to
            # unsigned byte
            pil_image.fromarray(
                (data[::-1]*255 + 0.5).astype(uint8),
            ).save(buf, fmt, dpi=(dpi, dpi), **kwargs)
        else:
            # Use matplotlib for non-grayscale colormaps or if PIL is not
            # available
            # noinspection PyPackageRequirements
            from matplotlib import image as mpl_image
            if fmt.lower() == 'png':
                # PNG images are saved upside down by matplotlib, regardless of
                # the origin parameter
                data = data[::-1]
            # noinspection PyTypeChecker
            mpl_image.imsave(
                buf, data, cmap=cmap, format=fmt, origin='lower', dpi=dpi)

        return buf.getvalue()
    finally:
        buf.close()
Beispiel #17
0
    def vetting_field_of_view(self, indir, tic, ra, dec, sectors):
        """Plot the TESS pixel field of view with nearby Gaia sources.

        For every product found by ``lightkurve.search_targetpixelfile`` for
        ``"TIC <tic>"`` (or, if none, by ``lightkurve.search_tesscut`` around
        ``ra``/``dec``), this downloads a 12x12 cutout and draws:

        * the mean flux;
        * the pipeline aperture;
        * Gaia sources down to ``maglim`` (6) magnitudes fainter than the target;
        * orientation arrows and a colorbar.

        It then writes ``TPF_Gaia_TIC<tic>_S<sector>.pdf`` and a
        ``Gaia_TIC<tic>_S<sector>.dat`` source table into ``<indir>/tpfplot``.

        Parameters
        ----------
        indir : str
            Directory under which ``tpfplot`` is created.
        tic : int or str
            TIC identifier of the target.
        ra, dec : float
            Coordinates used for the TESScut fallback search. Presumably
            degrees; confirm against lightkurve's expected target format.
        sectors : list or None
            Sectors to search. An empty list is treated like ``None``.

        Returns
        -------
        str
            The ``tpfplot`` output directory.
        """
        maglim = 6
        # tic may be numeric; string building needs str.
        tic_str = str(tic)
        # An empty sector list means "no restriction" for both searches.
        sectors_search = None if sectors is not None and len(
            sectors) == 0 else sectors
        tpf_source = lightkurve.search_targetpixelfile("TIC " + tic_str,
                                                       sector=sectors_search,
                                                       mission='TESS')
        if tpf_source is None or len(tpf_source) == 0:
            ra_str = str(ra)
            dec_str = "+" + str(dec) if dec >= 0 else str(dec)
            tpf_source = lightkurve.search_tesscut(ra_str + " " + dec_str,
                                                   sector=sectors_search)
        # Defined before the loop so it can be returned even when nothing
        # was found.
        save_dir = indir + "/tpfplot"
        for product_index in range(len(tpf_source)):
            tpf = tpf_source[product_index].download(cutout_size=(12, 12))
            pipeline = True
            fig = plt.figure(figsize=(6.93, 5.5))
            gs = gridspec.GridSpec(1,
                                   3,
                                   height_ratios=[1],
                                   width_ratios=[1, 0.05, 0.01])
            gs.update(left=0.05,
                      right=0.95,
                      bottom=0.12,
                      top=0.95,
                      wspace=0.01,
                      hspace=0.03)
            ax1 = plt.subplot(gs[0, 0])
            # TPF plot: mean flux, divided by a power of ten shown on the
            # colorbar label.
            nx, ny = np.shape(tpf.flux.value)[1:]
            norm = ImageNormalize(stretch=stretching.LogStretch())
            # np.int/np.float were removed in NumPy 1.24; the builtins
            # truncate/convert identically.
            division = int(np.log10(np.nanmax(tpf.flux.value)))
            splot = plt.imshow(np.nanmean(tpf.flux, axis=0) / 10 ** division,
                               norm=norm,
                               extent=[tpf.column, tpf.column + ny,
                                       tpf.row, tpf.row + nx],
                               origin='lower',
                               zorder=0)
            # Aperture: the pipeline mask, or a threshold mask if disabled.
            if pipeline:
                aperture_mask = tpf.pipeline_mask
                aperture = tpf._parse_aperture_mask(aperture_mask)
                print("    --> Using pipeline aperture...")
            else:
                aperture_mask = tpf.create_threshold_mask(
                    threshold=10, reference_pixel='center')
                aperture = tpf._parse_aperture_mask(aperture_mask)
                print("    --> Using threshold aperture...")
            maskcolor = 'lightgray'
            # (row, column) of every aperture pixel, in row-major order.
            aperture_pixels = np.argwhere(aperture)
            for row, col in aperture_pixels:
                corner = (col + tpf.column, row + tpf.row)
                ax1.add_patch(
                    patches.Rectangle(corner, 1, 1, color=maskcolor,
                                      fill=True, alpha=0.4))
                ax1.add_patch(
                    patches.Rectangle(corner, 1, 1, color=maskcolor,
                                      fill=False, alpha=1, lw=2))
            # Gaia sources
            gaia_id, mag = tpfplotter.get_gaia_data_from_tic(tic)
            r, res = tpfplotter.add_gaia_figure_elements(
                tpf, magnitude_limit=mag + float(maglim), targ_mag=mag)
            x, y, gaiamags = r
            # Shift to pixel centres.
            x, y, gaiamags = np.array(x) + 0.5, np.array(y) + 0.5, np.array(
                gaiamags)
            size = 128.0 / 2**((gaiamags - mag))
            plt.scatter(x,
                        y,
                        s=size,
                        c='red',
                        alpha=0.6,
                        edgecolor=None,
                        zorder=10)
            # Gaia source for the target
            this = np.where(np.array(res['Source']) == int(gaia_id))[0]
            plt.scatter(x[this],
                        y[this],
                        marker='x',
                        c='white',
                        s=32,
                        zorder=11)
            # Legend: invisible points at (0, 0) with sizes for a range of
            # magnitude differences, rounded up to an even maxmag.
            maxmag = int(maglim) + (1 if int(maglim) % 2 != 0 else 0)
            legend_mags = np.linspace(-2, maxmag, int((maxmag + 2) / 2 + 1))
            for f in mag + legend_mags:
                plt.scatter(0,
                            0,
                            s=128.0 / 2**((f - mag)),
                            c='red',
                            alpha=0.6,
                            edgecolor=None,
                            zorder=10,
                            label=r'$\Delta m=$ ' + str(int(f - mag)))
            ax1.legend(fancybox=True, framealpha=0.7)
            # Number sources within 6 pixels of the target, nearest first.
            dist = np.sqrt((x - x[this])**2 + (y - y[this])**2)
            dsort = np.argsort(dist)
            for d, elem in enumerate(dsort):
                if dist[elem] < 6:
                    plt.text(x[elem] + 0.1,
                             y[elem] + 0.1,
                             str(d + 1),
                             color='white',
                             zorder=100)
            # Orientation arrows
            tpfplotter.plot_orientation(tpf)
            # Labels and titles
            plt.xlim(tpf.column, tpf.column + ny)
            plt.ylim(tpf.row, tpf.row + nx)
            plt.xlabel('Pixel Column Number', fontsize=16)
            plt.ylabel('Pixel Row Number', fontsize=16)
            plt.title('Coordinates ' + tic_str + ' - Sector ' + str(tpf.sector),
                      fontsize=16)
            # Colorbar, shifted slightly left of its grid cell.
            cbax = plt.subplot(gs[0, 1])
            pos1 = cbax.get_position()
            cbax.set_position([pos1.x0 - 0.05, pos1.y0, pos1.width,
                               pos1.height])
            cb = Colorbar(ax=cbax,
                          mappable=splot,
                          orientation='vertical',
                          ticklocation='right')
            plt.xticks(fontsize=14)
            exponent = r'$\times 10^' + str(division) + '$'
            cb.set_label(r'Flux ' + exponent + r' (e$^-$)',
                         labelpad=10,
                         fontsize=16)
            os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_dir + '/TPF_Gaia_TIC' + tic_str + '_S' +
                        str(tpf.sector) + '.pdf')
            # One figure is created per product; release it once saved.
            plt.close(fig)
            # Save Gaia sources info, sorted by distance to the target.
            GaiaID = np.array(res['Source'])
            x, y, gaiamags, dist, GaiaID = (x[dsort], y[dsort],
                                            gaiamags[dsort], dist[dsort],
                                            GaiaID[dsort])
            IDs = np.arange(len(x)) + 1
            # Flag sources whose position falls inside an aperture pixel.
            inside = np.zeros(len(x))
            for row, col in aperture_pixels:
                xtpf, ytpf = col + tpf.column, row + tpf.row
                _inside = np.where((x > xtpf) & (x < xtpf + 1)
                                   & (y > ytpf) & (y < ytpf + 1))[0]
                inside[_inside] = 1
            data = Table([
                IDs, GaiaID, x, y, dist, dist * 21., gaiamags,
                inside.astype('int')
            ],
                         names=[
                             '# ID', 'GaiaID', 'x', 'y', 'Dist_pix',
                             'Dist_arcsec', 'Gmag', 'InAper'
                         ])
            ascii.write(data,
                        save_dir + '/Gaia_TIC' + tic_str + '_S' +
                        str(tpf.sector) + '.dat',
                        overwrite=True)
        return save_dir
Beispiel #18
0
    def create_figure(self,
                      output_filename,
                      survey,
                      stretch='log',
                      vmin=1,
                      vmax=None,
                      min_percent=1,
                      max_percent=95,
                      cmap='gray',
                      contour_color='red',
                      data_col='FLUX'):
        """Render the binned TPF flux with survey contours and save it.

        Parameters
        ----------
        output_filename : str
            Path passed to ``plt.savefig`` (tight bounding box, 300 dpi).
        survey : str
            Survey passed to ``surveyquery.getSVImg``; the contours of that
            image are overlaid.
        stretch : {'linear', 'sqrt', 'power', 'log', 'asinh'}, optional
            Stretch applied to the flux (default: 'log').
        vmin : float, optional
            Minimum cut level (default: 1). Only used when ``vmax`` is given;
            otherwise both levels come from ``self.cut_levels``.
        vmax : float or None, optional
            Maximum cut level. If None (default), ``vmin``/``vmax`` are set by
            ``self.cut_levels(min_percent, max_percent, data_col)``.
        min_percent, max_percent : float, optional
            Percentiles handed to ``self.cut_levels`` (defaults: 1 and 95).
        cmap : str, optional
            The matplotlib color map name. The default is 'gray'; it can
            also be e.g. 'gist_heat'.
        contour_color : str, optional
            Color of the survey contours (default: 'red').
        data_col : str, optional
            Column name handed to ``self.cut_levels`` and shown in the verbose
            message. The plotted data is always ``self.TPF.flux_binned()``.

        Returns
        -------
        fig : matplotlib.figure.Figure
            The rendered figure, left open for the caller.

        Raises
        ------
        ValueError
            If ``stretch`` is not one of the supported names. This is checked
            before any figure is created.
        """
        # Get the flux data to visualize
        flx = self.TPF.flux_binned()

        # calculate cut_levels
        if vmax is None:
            vmin, vmax = self.cut_levels(min_percent, max_percent, data_col)

        if self.verbose:
            print('{} vmin/vmax = {}/{} (median={})'.format(
                data_col, vmin, vmax, np.nanmedian(flx)))

        # Resolve the stretch before creating a figure, so an invalid name
        # does not leave an orphaned figure open.
        if stretch == 'linear':
            stretch_fn = visualization.LinearStretch()
        elif stretch == 'sqrt':
            stretch_fn = visualization.SqrtStretch()
        elif stretch == 'power':
            stretch_fn = visualization.PowerStretch(1.0)
        elif stretch == 'log':
            stretch_fn = visualization.LogStretch()
        elif stretch == 'asinh':
            stretch_fn = visualization.AsinhStretch(0.1)
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        # The figure size (in inches) is taken directly from the flux shape.
        shape = list(flx.shape)
        fig = plt.figure(figsize=shape)

        # WCS axes so the survey contours can be reprojected onto the TPF.
        ax = plt.subplot(projection=self.TPF.wcs)
        ax.set_xlabel('RA')
        ax.set_ylabel('Dec')

        transform = (stretch_fn +
                     visualization.ManualInterval(vmin=vmin, vmax=vmax))
        # Pre-scaled 0-255 values; NoNorm stops matplotlib from rescaling.
        ax.imshow((255 * transform(flx)).astype(int),
                  aspect='auto',
                  origin='lower',
                  interpolation='nearest',
                  cmap=cmap,
                  norm=NoNorm())
        ax.set_xticks([])
        ax.set_yticks([])

        # Remember the image limits; contouring may otherwise expand them.
        current_ylims = ax.get_ylim()
        current_xlims = ax.get_xlim()

        pixels, header = surveyquery.getSVImg(self.TPF.position, survey)
        levels = np.linspace(np.min(pixels), np.percentile(pixels, 95), 10)
        ax.contour(pixels,
                   transform=ax.get_transform(WCS(header)),
                   levels=levels,
                   colors=contour_color)

        ax.set_xlim(current_xlims)
        ax.set_ylim(current_ylims)

        fig.canvas.draw()
        plt.savefig(output_filename, bbox_inches='tight', dpi=300)
        return fig
Beispiel #19
0
def show_image(image,
               percl=99,
               percu=None,
               is_mask=False,
               figsize=(6, 10),
               cmap='viridis',
               log=False,
               show_colorbar=True,
               show_ticks=True,
               fig=None,
               ax=None,
               input_ratio=None):
    """
    Show an image in matplotlib with some basic astronomically-appropriate stretching.

    The image is block-reduced to roughly the pixel size of the figure before
    it is displayed, because matplotlib's own downsampling discards detail.

    Parameters
    ----------
    image
        The 2-D image to show.
    percl : number
        The percentile for the lower edge of the stretch (or both edges if ``percu`` is None)
    percu : number or None
        The percentile for the upper edge of the stretch (or None to use ``percl`` for both)
    is_mask : bool
        If True the image is treated as a 0/1 mask: the display limits are
        fixed to ``[0, 1]`` and the percentile stretch is not applied.
    figsize : 2-tuple or None
        The size of the matplotlib figure in inches. Only used when ``fig``
        and ``ax`` are not given; the larger of the two values is kept and
        the other dimension is rescaled to roughly match the image shape.
    cmap : str
        Name of the matplotlib colormap.
    log : bool
        Use a logarithmic stretch instead of a linear one.
    show_colorbar : bool
        Draw a colorbar next to the image.
    show_ticks : bool
        If False, hide the tick labels on all four sides.
    fig, ax : matplotlib Figure and Axes, optional
        Existing figure and axes to draw into. Provide both or neither.
    input_ratio : int or None
        Block-reduction factor to use instead of the one derived from the
        figure size.

    Raises
    ------
    ValueError
        If only one of ``fig`` and ``ax`` is provided.
    """
    if percu is None:
        percu = percl
        percl = 100 - percl

    if (fig is None) != (ax is None):
        raise ValueError('Must provide both "fig" and "ax" '
                         'if you provide one of them')
    elif fig is None:
        if figsize is not None:
            # Rescale the fig size to match the image dimensions, roughly.
            # This must happen *before* the figure is created to have any
            # effect. matplotlib figsize is (width, height) while image.shape
            # is (rows, cols); keep the larger requested dimension.
            image_aspect_ratio = image.shape[0] / image.shape[1]
            longest = max(figsize)
            if image_aspect_ratio >= 1:
                # taller than wide
                figsize = (longest / image_aspect_ratio, longest)
            else:
                figsize = (longest, longest * image_aspect_ratio)
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    # To preserve details we should *really* downsample correctly and
    # not rely on matplotlib to do it correctly for us (it won't).

    # So, calculate the size of the figure in pixels, block_reduce to
    # roughly that, and display the block reduced image.

    # Thanks, https://stackoverflow.com/questions/29702424/how-to-get-matplotlib-figure-size
    # get_size_inches() is (width, height) but image.shape is (rows, cols),
    # so compare rows against the height and columns against the width.
    fig_width_pix, fig_height_pix = fig.get_size_inches() * fig.dpi
    ratio = int(max(image.shape[0] // fig_height_pix,
                    image.shape[1] // fig_width_pix))
    ratio = max(ratio, 1)

    ratio = input_ratio or ratio

    # Divide by the square of the ratio to keep the flux the same in the
    # reduced image
    reduced_data = block_reduce(image, ratio) / ratio**2

    # Of course, now that we have downsampled, the axis limits are changed to
    # match the smaller image size. Setting the extent will do the trick to
    # change the axis display back to showing the actual extent of the image.
    extent = [0, image.shape[1], 0, image.shape[0]]

    if log:
        stretch = aviz.LogStretch()
    else:
        stretch = aviz.LinearStretch()

    norm = aviz.ImageNormalize(reduced_data,
                               interval=aviz.AsymmetricPercentileInterval(
                                   percl, percu),
                               stretch=stretch)

    if is_mask:
        # The image is a mask in which pixels are zero or one. Set the image scale
        # limits appropriately.
        scale_args = dict(vmin=0, vmax=1)
    else:
        scale_args = dict(norm=norm)

    im = ax.imshow(reduced_data,
                   origin='lower',
                   cmap=cmap,
                   extent=extent,
                   aspect='equal',
                   **scale_args)

    if show_colorbar:
        # I haven't a clue why the fraction and pad arguments below work to make
        # the colorbar the same height as the image, but they do....unless the image
        # is wider than it is tall. Sticking with this for now anyway...
        # Thanks: https://stackoverflow.com/a/26720422/3486425
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, format='%2.0f')
        # In case someone in the future wants to improve this:
        # https://joseph-long.com/writing/colorbars/
        # https://stackoverflow.com/a/33505522/3486425
        # https://matplotlib.org/mpl_toolkits/axes_grid/users/overview.html#colorbar-whose-height-or-width-in-sync-with-the-master-axes

    if not show_ticks:
        ax.tick_params(labelbottom=False,
                       labelleft=False,
                       labelright=False,
                       labeltop=False)
Beispiel #20
0
def generate_verification_page(lcd, ls, freq, power, cutoutpaths, c_obj,
                               outvppath, outd, show_binned=True):
    """
    Make the verification page, which consists of:

    top row:
        lomb scargle periodogram  |  phased light curve  |  image w/ aperture

    middle row: entire light curve before detrending, with the model light
    curve overplotted and a horiz bar showing the rotation period

    bottom row: entire detrended light curve (with horiz bar showing
    rotation period)

    ----------
    args:

        lcd (dict): has the light curve, aperture positions, some lomb
        scargle results. Keys read here: 'cutout_wcss', 'predetrending_time',
        'predetrending_rel_flux', 'time', 'rel_flux', 'ls_period', 'ls_fap',
        'median_imgs'.

        ls: LombScargle instance with everything passed (not used in the
        body of this function).

        freq, power (array-like): periodogram frequencies and powers; power
        is plotted against period ``1/freq``.

        cutoutpaths (list): FFI cutout FITS paths (not used in the body of
        this function).

        c_obj (SkyCoord): astropy sky coordinate of the target

        outvppath (str): path to save verification page to

        outd (dict): target metadata; 'name', 'group_id' and 'teff' are read
        for the annotation. It is not modified.

        show_binned (bool): overplot the phase-binned light curve.

    returns:

        None. Returns early (after printing a message) when the
        pre-detrending light curve is missing from ``lcd`` or the WCS
        projection of neighbor stars fails. Neighbor stars are fetched with
        ``Catalogs.query_region`` (TIC catalog).
    """
    cutout_wcs = lcd['cutout_wcss'][0]

    mpl.rcParams['xtick.direction'] = 'in'
    mpl.rcParams['ytick.direction'] = 'in'

    plt.close('all')

    fig = plt.figure(figsize=(12,12))

    # Layout: periodogram, phased LC and cutout image across the top row; the
    # raw and detrended full light curves stacked underneath.
    ax0 = plt.subplot2grid((3, 3), (1, 0), colspan=3)
    ax1 = plt.subplot2grid((3, 3), (2, 0), colspan=3)
    ax2 = plt.subplot2grid((3, 3), (0, 0))
    ax3 = plt.subplot2grid((3, 3), (0, 1))
    ax4 = plt.subplot2grid((3, 3), (0, 2), projection=cutout_wcs)

    #
    # middle row (ax0): entire light curve, pre-detrending (with horiz bar
    # showing rotation period). plot model LC too.
    #
    try:
        ax0.scatter(lcd['predetrending_time'], lcd['predetrending_rel_flux'],
                    c='k', alpha=1.0, zorder=3, s=10, rasterized=True,
                    linewidths=0)
    except KeyError as e:
        print('ERR! {}\nReturning.'.format(e))
        return

    # The model is the ratio of raw to detrended flux; if the two arrays
    # cannot be divided (ValueError, e.g. mismatched lengths) no model is drawn.
    try:
        model_flux = nparr(lcd['predetrending_rel_flux']/lcd['rel_flux'])
    except ValueError:
        model_flux = 0

    if isinstance(model_flux, np.ndarray):
        _, groups = find_lc_timegroups(lcd['predetrending_time'], mingap=0.5)
        for group in groups:
            ax0.plot(lcd['predetrending_time'][group], model_flux[group], c='C0',
                     alpha=1.0, zorder=2, rasterized=True, lw=2)

    # add the bar showing the derived period. nanpercentile so that a single
    # NaN flux does not produce NaN axis limits, which set_ylim rejects.
    ymax = np.nanpercentile(lcd['predetrending_rel_flux'], 95)
    ymin = np.nanpercentile(lcd['predetrending_rel_flux'], 5)
    ydiff = 1.15*(ymax-ymin)

    epoch = np.nanmin(lcd['predetrending_time']) + lcd['ls_period']
    ax0.plot([epoch, epoch+lcd['ls_period']], [ymax, ymax], color='red', lw=2,
             zorder=4)

    ax0.set_ylim((ymin-ydiff,ymax+ydiff))

    ax0.set_xticklabels('')
    ax0.set_ylabel('Raw flux')

    name = outd['name']
    group_id = outd['group_id']
    if name=='nan':
        nstr = 'Group {}'.format(group_id)
    else:
        nstr = '{}'.format(name)

    # Use a local copy so the caller's outd is not modified.
    teff = outd['teff']
    if not np.isfinite(teff):
        teff = 0

    ax0.text(0.98, 0.97,
        'Teff={:d}K. {}'.format(int(teff), nstr),
             ha='right', va='top', fontsize='large', zorder=2,
             transform=ax0.transAxes
    )

    #
    # bottom row (ax1): entire detrended light curve (with horiz bar showing
    # rotation period)
    #
    ax1.scatter(lcd['time'], lcd['rel_flux'], c='k', alpha=1.0, zorder=2, s=10,
                rasterized=True, linewidths=0)

    # add the bar showing the derived period
    ymax = np.nanpercentile(lcd['rel_flux'], 95)
    ymin = np.nanpercentile(lcd['rel_flux'], 5)
    ydiff = 1.15*(ymax-ymin)

    epoch = np.nanmin(lcd['time']) + lcd['ls_period']
    ax1.plot([epoch, epoch+lcd['ls_period']], [ymax, ymax], color='red', lw=2)

    ax1.set_ylim((ymin-ydiff,ymax+ydiff))

    # raw string: '\m' is not a valid Python escape sequence
    ax1.set_xlabel(r'Time [BJD$_{\mathrm{TDB}}$]')
    ax1.set_ylabel('Detrended flux')

    #
    # top row, col 0 (ax2): lomb scargle periodogram
    #
    ax2.plot(1/freq, power, c='k')
    ax2.set_xscale('log')
    ax2.text(0.03, 0.97, 'FAP={:.1e}\nP={:.1f}d'.format(
        lcd['ls_fap'], lcd['ls_period']), ha='left', va='top',
        fontsize='large', zorder=2, transform=ax2.transAxes
    )
    ax2.set_xlabel('Period [day]', labelpad=-1)
    ax2.set_ylabel('LS power')

    #
    # top row, col 1 (ax3): phased light curve, with phase zero at the
    # time of minimum flux
    #
    phzd = phase_magseries(lcd['time'], lcd['rel_flux'], lcd['ls_period'],
                           lcd['time'][np.argmin(lcd['rel_flux'])], wrap=False,
                           sort=True)

    ax3.scatter(phzd['phase'], phzd['mags'], c='k', rasterized=True, s=10,
                linewidths=0, zorder=1)

    if show_binned:
        try:
            binphasedlc = phase_bin_magseries(phzd['phase'], phzd['mags'],
                                              binsize=1e-2, minbinelems=5)
            binplotphase = binphasedlc['binnedphases']
            binplotmags = binphasedlc['binnedmags']

            ax3.scatter(binplotphase, binplotmags, s=10, c='darkorange',
                        linewidths=0, zorder=3, rasterized=True)
        except TypeError as e:
            # best effort: skip the binned overlay if binning fails
            print(e)

    xlim = ax3.get_xlim()
    ax3.hlines(1.0, xlim[0], xlim[1], colors='gray', linestyles='dotted',
               zorder=2)
    ax3.set_xlim(xlim)

    ymax = np.nanpercentile(lcd['rel_flux'], 95)
    ymin = np.nanpercentile(lcd['rel_flux'], 5)
    ydiff = 1.15*(ymax-ymin)
    ax3.set_ylim((ymin-ydiff,ymax+ydiff))

    ax3.set_xlabel('Phase', labelpad=-1)
    ax3.set_ylabel('Flux', labelpad=-0.5)

    #
    # top row, col 2 (ax4): image w/ aperture. put on the nbhr stars as dots
    # too, to ensure the wcs isn't wonky!
    #

    # acquire neighbor stars.
    radius = 2.0*u.arcminute

    nbhr_stars = Catalogs.query_region(
        "{} {}".format(float(c_obj.ra.value), float(c_obj.dec.value)),
        catalog="TIC",
        radius=radius
    )

    try:
        Tmag_cutoff = 15
        px,py = cutout_wcs.all_world2pix(
            nbhr_stars[nbhr_stars['Tmag'] < Tmag_cutoff]['ra'],
            nbhr_stars[nbhr_stars['Tmag'] < Tmag_cutoff]['dec'],
            0
        )
    except Exception as e:
        print('ERR! wcs all_world2pix got {}'.format(repr(e)))
        return

    tmags = nbhr_stars[nbhr_stars['Tmag'] < Tmag_cutoff]['Tmag']

    # NOTE(review): the 19 px bound presumably matches the cutout size;
    # confirm against lcd['median_imgs'][0].shape.
    sel = (px > 0) & (px < 19) & (py > 0) & (py < 19)
    px,py = px[sel], py[sel]
    tmags = tmags[sel]

    ra, dec = float(c_obj.ra.value), float(c_obj.dec.value)
    target_x, target_y = cutout_wcs.all_world2pix(ra,dec,0)

    #
    # finally make it
    #

    img = lcd['median_imgs'][0]

    # some images come out as nans.
    if np.all(np.isnan(img)):
        img = np.ones_like(img)

    interval = vis.PercentileInterval(99.9)
    vmin,vmax = interval.get_limits(img)
    norm = vis.ImageNormalize(
        vmin=vmin, vmax=vmax, stretch=vis.LogStretch(1000))

    ax4.imshow(img, cmap='YlGnBu_r', origin='lower', zorder=1, norm=norm)

    ax4.scatter(px, py, marker='x', c='r', s=5, rasterized=True, zorder=2,
                linewidths=1)
    ax4.plot(target_x, target_y, mew=0.5, zorder=5, markerfacecolor='yellow',
             markersize=7, marker='*', color='k', lw=0)

    lon = ax4.coords['ra']
    lat = ax4.coords['dec']

    lon.set_ticks(spacing=1*u.arcminute)
    lat.set_ticks(spacing=1*u.arcminute)

    lon.set_ticklabel(exclude_overlapping=True)
    lat.set_ticklabel(exclude_overlapping=True)

    ax4.coords.grid(True, color='white', alpha=0.3, lw=0.3, ls='dotted')

    # overplot aperture
    radius_px = 3
    circle = plt.Circle((target_x, target_y), radius_px,
                         color='C1', fill=False, zorder=5)
    ax4.add_artist(circle)

    #
    # cleanup
    #
    for ax in [ax0,ax1,ax2,ax3,ax4]:
        ax.get_yaxis().set_tick_params(which='both', direction='in',
                                       labelsize='small', top=True, right=True)
        ax.get_xaxis().set_tick_params(which='both', direction='in',
                                       labelsize='small', top=True, right=True)

    fig.tight_layout(w_pad=0.5, h_pad=0)

    #
    # save
    #
    fig.savefig(outvppath, dpi=300, bbox_inches='tight')
    print('made {}'.format(outvppath))
Beispiel #21
0
# Write the nebular-background-subtracted data once; later runs reuse it.
nebsub_path = os.path.join(data_path, 'HH305E_nebsub.fits')
if not os.path.exists(nebsub_path):
    hdr = hdul[0].header
    # stamp the header with the time the subtraction was saved
    now = dt.utcnow().strftime('%Y/%m/%d %H:%M:%S UT')
    hdr.set('HISTORY', f'Background subtracted {now}')
    hdu = fits.PrimaryHDU(data=neb_subtracted, header=hdr)
    hdu.writeto(nebsub_path)

##-------------------------------------------------------------------------
## Plot mask of low H-beta emission
plt.figure(figsize=(8, 8))

# Left panel: the whole H-beta image on a log stretch.
plt.subplot(1, 2, 1)
plt.title('Sum of H-beta Bins')
full_range = v.ManualInterval(vmin=image.min() - 5, vmax=image.max() + 10)
norm = v.ImageNormalize(image, interval=full_range, stretch=v.LogStretch(10))
im = plt.imshow(image, origin='lower', norm=norm, cmap='Greys')
plt.colorbar(im)

# Right panel: pixels outside nmask are masked and then filled with 0.
plt.subplot(1, 2, 2)
plt.title('Nebular Emission Mask')
mimage = np.ma.MaskedArray(image)
mimage.mask = ~nmask
mimagef = np.ma.filled(mimage, fill_value=0)
masked_range = v.ManualInterval(vmin=image.min() - 5,
                                vmax=np.percentile(image, mask_pcnt) + 5)
norm = v.ImageNormalize(mimagef, interval=masked_range,
                        stretch=v.LinearStretch())
im = plt.imshow(mimagef, origin='lower', norm=norm, cmap='Greys')
plt.colorbar(im)
Beispiel #22
0
import matplotlib.pyplot as plt
from astropy.io import fits
import astropy.visualization as viz

# Interactive viewer: pick a FITS file and HDU, then display it with a
# DS9-like log stretch over the full data range.
image_name = input("Please enter the name of the file : ")
with fits.open(image_name) as hdul:
    hdul.info()
    header_number = int(input("Enter Header number whose data  you want view : "))
    image = hdul[header_number].data
##stretching and normalizing using LogStretch() and MinMaxInterval() like in DS9
log_param = float(input("Enter base value for logrithmic stretch : "))
# The user-supplied value is the LogStretch ``a`` parameter (previously it
# was read but ignored).
norm = viz.ImageNormalize(image,
                          interval=viz.MinMaxInterval(),
                          stretch=viz.LogStretch(a=log_param))
# Pass the normalization to imshow (previously it was computed but unused);
# origin='lower' follows the FITS/DS9 convention of pixel (1, 1) at bottom-left.
plt.imshow(image, cmap='gray', norm=norm, origin='lower')
plt.show()
Beispiel #23
0
    def create_figure(self,
                      frameno=0,
                      binning=1,
                      dpi=None,
                      stretch='log',
                      vmin=1,
                      vmax=5000,
                      cmap='gray',
                      data_col='FLUX',
                      annotate=True,
                      time_format='ut',
                      show_flags=False,
                      label=None):
        """Returns a matplotlib Figure object that visualizes a frame.

        Parameters
        ----------
        frameno : int
            Image number in the target pixel file.

        binning : int
            Number of frames around `frameno` to co-add. (default: 1).

        dpi : float, optional [dots per inch]
            Resolution of the output in dots per Kepler CCD pixel.
            By default the dpi is chosen such that the image is 440px wide.

        stretch : str, optional
            Intensity stretch, one of 'linear', 'sqrt', 'power', 'log' or
            'asinh'. (Default: 'log'.)

        vmin : float, optional
            Minimum cut level (default: 1).

        vmax : float, optional
            Maximum cut level (default: 5000).

        cmap : str, optional
            The matplotlib color map name.  The default is 'gray',
            can also be e.g. 'gist_heat'.

        data_col : str, optional
            Data column passed to `flux_binned`. (Default: 'FLUX'.)

        annotate : boolean, optional
            Annotate the Figure with a timestamp and target name?
            (Default: `True`.)

        time_format : str, optional
            Format passed to `timestamp` for the annotation.
            (Default: 'ut'.)

        show_flags : boolean, optional
            Show the quality flags?
            (Default: `False`.)

        label : str
            Label text to show in the bottom left corner of the movie.
            Defaults to `self.objectname`.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The figure, drawn once via ``fig.canvas.draw()``.

        Raises
        ------
        ValueError
            If `stretch` is not one of the supported names.
        """
        # Get the flux data to visualize
        flx = self.flux_binned(frameno=frameno,
                               binning=binning,
                               data_col=data_col)

        # Determine the figsize and dpi. The figure is sized in "inches" equal
        # to the number of CCD pixels, so `dpi` is output pixels per CCD pixel.
        # flx.shape is (rows, cols); matplotlib's figsize is (width, height).
        shape = list(flx.shape)
        shape = [shape[1], shape[0]]
        if dpi is None:
            # Twitter timeline requires dimensions between 440x220 and 1024x512
            # so we make 440 the default
            dpi = 440 / float(shape[0])

        # libx264 requires the output width *and* height (in pixels) to be
        # divisible by 2, so trim each dimension down to an even pixel count.
        # The small epsilons guard against float round-off, e.g. 439.9999999
        # being floored to 439, or an even count being stored as 285.9999999
        # and then truncated to an odd integer pixel size.
        for dim in range(2):
            npix = int(shape[dim] * dpi + 1e-6)
            npix -= npix % 2
            shape[dim] = (npix + 1e-3) / dpi

        # Create the figure and display the flux image
        fig = pl.figure(figsize=shape, dpi=dpi)
        ax = fig.add_subplot(1, 1, 1)
        if self.verbose:
            print('{} vmin/vmax = {}/{} (median={})'.format(
                data_col, vmin, vmax, np.nanmedian(flx)))

        if stretch == 'linear':
            stretch_fn = visualization.LinearStretch()
        elif stretch == 'sqrt':
            stretch_fn = visualization.SqrtStretch()
        elif stretch == 'power':
            stretch_fn = visualization.PowerStretch(1.0)
        elif stretch == 'log':
            stretch_fn = visualization.LogStretch()
        elif stretch == 'asinh':
            stretch_fn = visualization.AsinhStretch(0.1)
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        # Clip to [vmin, vmax], apply the stretch, and scale to 0..255 so the
        # image can be shown with NoNorm (values used as colormap indices).
        transform = (stretch_fn +
                     visualization.ManualInterval(vmin=vmin, vmax=vmax))
        flx_transform = 255 * transform(flx)
        # Make sure to remove all NaNs!
        flx_transform[~np.isfinite(flx_transform)] = 0
        ax.imshow(flx_transform.astype(int),
                  aspect='auto',
                  origin='lower',
                  interpolation='nearest',
                  cmap=cmap,
                  norm=NoNorm())
        if annotate:  # Annotate the frame with a timestamp and target name?
            fontsize = 3. * shape[0]
            margin = 0.03
            # Print target name in lower left corner
            if label is None:
                label = self.objectname
            txt = ax.text(margin,
                          margin,
                          label,
                          family="monospace",
                          fontsize=fontsize,
                          color='white',
                          transform=ax.transAxes)
            txt.set_path_effects([
                path_effects.Stroke(linewidth=fontsize / 6.,
                                    foreground='black'),
                path_effects.Normal()
            ])
            # Print a timestring in the lower right corner
            txt2 = ax.text(1 - margin,
                           margin,
                           self.timestamp(frameno, time_format=time_format),
                           family="monospace",
                           fontsize=fontsize,
                           color='white',
                           ha='right',
                           transform=ax.transAxes)
            txt2.set_path_effects([
                path_effects.Stroke(linewidth=fontsize / 6.,
                                    foreground='black'),
                path_effects.Normal()
            ])
            # Print quality flags in upper left corner
            if show_flags:
                flags = self.quality_flags(frameno)
                if len(flags) > 0:
                    txt3 = ax.text(margin,
                                   1 - margin,
                                   '\n'.join(flags),
                                   family="monospace",
                                   fontsize=fontsize * 1.3,
                                   color='white',
                                   ha='left',
                                   va='top',
                                   transform=ax.transAxes,
                                   linespacing=1.5,
                                   backgroundcolor='red')
                    txt3.set_path_effects([
                        path_effects.Stroke(linewidth=fontsize / 6.,
                                            foreground='black'),
                        path_effects.Normal()
                    ])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
        fig.canvas.draw()
        return fig
def make_hiidust_plot(
    reg,
    mgpsfile,
    width=1 * u.arcmin,
    surveys=('atlasgal',),
    figure=None,
    regname='GAL_031',
    fifth_panel_synchro=False,
    alpha=-0.12,
    cmap=None,
):
    """
    Plot a dust / free-free decomposition of the MGPS 3 mm emission around a
    region and return the summed flux of each component.

    Panels on a 1x6 subplot grid (plus a colorbar axis):

    * 1: the comparison-survey image scaled to 3 mm with a modified
      blackbody ("870 um scaled", treated as dust);
    * 2: MGPS smoothed to the survey beam and reprojected onto its grid,
      minus the scaled dust ("3 mm Free-Free");
    * 3: the MGPS cutout ("3 mm");
    * 4: MGPS minus the 20 cm image scaled to 3 mm ("3 mm Dust");
    * 5: the 20 cm image scaled to 3 mm ("20 cm scaled").

    Parameters
    ----------
    reg : region
        Region whose ``center`` is the cutout centre and whose
        ``meta['text']`` is printed and selects per-source display tweaks
        ('arches', 'w49b', 'w51').
    mgpsfile : str
        Path to the MGPS FITS image (primary HDU is used).
    width : Quantity
        Half-width of the cutouts; images of size ``2 * width`` are used.
    surveys : iterable of str
        Survey names passed to ``getimg``. Exactly one must return an image;
        otherwise the unpacking of ``images.items()`` raises ValueError.
    figure : matplotlib Figure, optional
        Figure to draw into (it is cleared first); defaults to ``pl.gcf()``.
    regname : str
        Region name; names containing 'G01' or 'G49' load hard-coded local
        20 cm images, all others query the 'gps20new' survey.
    fifth_panel_synchro : bool
        Draw a synchrotron-estimate panel instead of the 20 cm dust panel.
        NOTE(review): this path never defines ``dust20``, ``freefree20`` or
        ``gps20_proj``, so it raises NameError further down.
    alpha : float
        Spectral index used to scale 20 cm to 3 mm:
        ``(90 GHz / 1.4 GHz) ** alpha``.
    cmap : matplotlib colormap, optional
        Defaults to a copy of viridis with NaN pixels drawn white.

    Returns
    -------
    dict
        Sums of positive pixels of the dust / free-free estimates and of the
        MGPS maps ('totalpos', 'totalpos20'), plus unclipped totals ('total',
        'total20'). The colorbar labels these maps in MJy sr^-1.
    """
    if cmap is None:
        # Copy so that set_bad does not modify matplotlib's global viridis.
        import copy
        cmap = copy.copy(pl.cm.viridis)
        cmap.set_bad('w')

    mgps_fh = fits.open(mgpsfile)[0]
    frame = wcs.utils.wcs_to_celestial_frame(wcs.WCS(mgps_fh.header))

    coordinate = reg.center
    coordname = "{0:06.3f}_{1:06.3f}".format(coordinate.galactic.l.deg,
                                             coordinate.galactic.b.deg)

    mgps_cutout = Cutout2D(mgps_fh.data,
                           coordinate.transform_to(frame.name),
                           size=width * 2,
                           wcs=wcs.WCS(mgps_fh.header))
    print()
    print(reg.meta['text'])
    print(
        f"Retrieving MAGPIS data for {coordname} ({coordinate.to_string()} {coordinate.frame.name})"
    )
    # we're treating 'width' as a radius elsewhere, here it's a full width
    images = {
        survey: getimg(coordinate, image_size=width * 2, survey=survey)
        for survey in surveys
    }
    images = {x: y for x, y in images.items() if y is not None}
    assert len(images) > 0

    # coordinate stuff so images can be reprojected to same frame
    ww = mgps_cutout.wcs.celestial
    mgps_pixscale = (wcs.utils.proj_plane_pixel_area(ww) * u.deg**2)**0.5

    if figure is None:
        figure = pl.gcf()
    figure.clf()

    # exactly one comparison survey image is supported
    (survey, img), = images.items()

    new_img = img[0].data
    if hasattr(img[0], 'header'):
        outwcs = wcs.WCS(img[0].header)
    else:
        outwcs = img[0].wcs

    tgt_bm = Beam(beam_map[survey])
    convbm = tgt_bm.deconvolve(mgps_beam)

    # smooth MGPS to the survey beam, then reproject onto the survey grid
    mgps_sm = convolution.convolve_fft(mgps_cutout.data,
                                       convbm.as_kernel(mgps_pixscale))
    mgps_reproj, _ = reproject.reproject_interp((mgps_sm, mgps_cutout.wcs),
                                                outwcs,
                                                shape_out=img[0].data.shape)

    mgpsMjysr = mgps_cutout.data / mgps_beam.sr.value / 1e6

    dust_pred = dust_emissivity.blackbody.modified_blackbody(
        u.Quantity(
            [wlmap[survey].to(u.GHz, u.spectral()),
             mustang_central_frequency]),
        assumed_temperature,
        beta=assumed_dustbeta)

    # assumes "surv" is dust
    surv_to_mgps = new_img * dust_pred[1] / dust_pred[0]
    print(f"{regname} {survey}")
    print(f"{survey} to mgps ratio: {dust_pred[1]/dust_pred[0]}")

    dusty = surv_to_mgps.value / tgt_bm.sr.value / 1e6
    freefree = (mgps_reproj / mgps_beam.sr.value / 1e6 - dusty)
    assert not hasattr(freefree, 'unit')
    print("Max values: ", img[0].data.max(), mgps_sm.max())
    print("More max values: ", np.nanmax(dusty), np.nanmax(freefree),
          np.nanmax(mgps_reproj / mgps_beam.sr.value / 1e6))

    # one shared normalization for every panel, derived from the free-free map
    norm = visualization.ImageNormalize(
        freefree,
        interval=visualization.ManualInterval(np.nanpercentile(freefree, 0.1),
                                              np.nanpercentile(freefree,
                                                               99.9)),
        stretch=visualization.LogStretch(),
    )
    print(f"interval: {norm.interval.vmin}, {norm.interval.vmax}")
    assert not hasattr(norm.vmin, 'unit')
    assert not hasattr(norm.vmax, 'unit')

    Magpis.cache_location = '/Volumes/external/mgps/cache/'

    ax0 = figure.add_subplot(1, 6, 3, projection=mgps_cutout.wcs)
    ax0.imshow(mgpsMjysr,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax0.set_title("3 mm")
    ax1 = figure.add_subplot(1, 6, 1, projection=outwcs)
    ax1.imshow(dusty,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax1.set_title("870 $\\mu$m scaled")
    ax1.set_ylabel("Galactic Latitude")
    ax2 = figure.add_subplot(1, 6, 2, projection=outwcs)
    ax2.imshow(freefree,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax2.set_title("3 mm Free-Free")

    for ax in (ax0, ax1, ax2):
        ax.tick_params(direction='in')
        ax.tick_params(color='w')

    ax0.coords[1].set_axislabel("")
    ax0.coords[1].set_ticklabel_visible(False)
    ax2.coords[1].set_axislabel("")
    ax2.coords[1].set_ticklabel_visible(False)

    pl.subplots_adjust(hspace=0, wspace=0)

    if 'G01' in regname:
        gps20im = fits.open('/Users/adam/work/gc/20cm_0.fits')
    elif 'G49' in regname:
        gps20im = fits.open(
            '/Users/adam/work/w51/vla_old/W51-LBAND-feathered_ABCD.fits')
    else:
        gps20im = getimg(coordinate, image_size=width * 2, survey='gps20new')

    reproj_gps20, _ = reproject.reproject_interp(
        (gps20im[0].data.squeeze(), wcs.WCS(gps20im[0].header).celestial),
        # refactoring to make a smaller cutout would make this faster....
        mgps_cutout.wcs,
        shape_out=mgps_cutout.data.shape)

    gps20cutout = Cutout2D(
        reproj_gps20,
        coordinate.transform_to(frame.name),
        size=width * 2,
        wcs=mgps_cutout.wcs)
    ax3 = figure.add_subplot(1, 6, 5, projection=gps20cutout.wcs)

    gps20_bm = Beam.from_fits_header(gps20im[0].header)
    print(f"GPS 20 beam: {gps20_bm.__repr__()}")

    # 20 cm -> 3 mm free-free scaling; the default alpha=-0.12 is per Loren's
    # suggestion
    freefree_20cm_to_3mm = (90 * u.GHz / (1.4 * u.GHz))**alpha

    gps20_Mjysr = gps20cutout.data / gps20_bm.sr.value / 1e6

    ax3.imshow((gps20_Mjysr * freefree_20cm_to_3mm).value,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax3.set_title("20 cm scaled")

    ax3.coords[1].set_axislabel("")
    ax3.coords[1].set_ticklabel_visible(False)
    ax3.tick_params(direction='in')
    ax3.tick_params(color='w')

    # Fifth Panel:

    # use freefree_proj to get the 20cm-estimated free-free contribution even
    # if we're not using it for plotting
    # MAGPIS data are high-resolution (comparable to but better than MGPS)
    # Zadeh data are low-resolution, 30ish arcsec
    # units: Jy/sr
    freefree_proj, _ = reproject.reproject_interp(
        (freefree, outwcs), gps20cutout.wcs, shape_out=gps20cutout.data.shape)

    gps20_pixscale = (wcs.utils.proj_plane_pixel_area(gps20cutout.wcs) *
                      u.deg**2)**0.5

    # depending on which image has higher resolution, convolve one to the other
    try:
        gps20convbm = tgt_bm.deconvolve(gps20_bm)
        gps20_Mjysr_sm = convolution.convolve_fft(
            gps20_Mjysr, gps20convbm.as_kernel(gps20_pixscale))
    except ValueError:
        gps20_Mjysr_sm = gps20_Mjysr
        ff_convbm = gps20_bm.deconvolve(tgt_bm)
        freefree_proj = convolution.convolve_fft(
            freefree_proj, ff_convbm.as_kernel(gps20_pixscale))

    if fifth_panel_synchro:
        # NOTE(review): this branch leaves dust20 / freefree20 / gps20_proj
        # undefined, so the code after the if/else raises NameError.

        ax4 = figure.add_subplot(1, 6, 5, projection=gps20cutout.wcs)

        # use the central frequency corresponding to an approximately flat spectrum (flat -> 89.72)
        # inverse of the 20 cm -> 3 mm scaling, so it honours `alpha`
        freefree_3mm_to_20cm = 1 / freefree_20cm_to_3mm
        synchro = gps20_Mjysr_sm - freefree_proj * freefree_3mm_to_20cm
        synchro[np.isnan(gps20_Mjysr) | (gps20_Mjysr == 0)] = np.nan

        normsynchro = visualization.ImageNormalize(
            gps20_Mjysr_sm,
            interval=visualization.ManualInterval(
                np.nanpercentile(gps20_Mjysr_sm, 0.5),
                np.nanpercentile(gps20_Mjysr_sm, 99.9)),
            stretch=visualization.LogStretch(),
        )

        ax4.imshow(synchro.value,
                   origin='lower',
                   interpolation='none',
                   norm=normsynchro,
                   cmap=cmap)
        ax4.set_title("Synchrotron")
        ax4.tick_params(direction='in')
        ax4.tick_params(color='w')
        ax4.coords[1].set_axislabel("")
        ax4.coords[1].set_ticklabel_visible(False)

        pl.tight_layout()
    else:
        # scale 20cm to match MGPS and subtract it

        if gps20_bm.sr < mgps_beam.sr:
            # smooth GPS20 to MGPS
            gps20convbm = mgps_beam.deconvolve(gps20_bm)
            gps20_Mjysr_sm = convolution.convolve_fft(
                gps20_Mjysr, gps20convbm.as_kernel(gps20_pixscale))
            gps20_Mjysr_sm[~np.isfinite(gps20_Mjysr)] = np.nan
            gps20_proj = gps20_Mjysr_sm
        else:
            # smooth MGPS to GPS20 instead
            gps20_proj = gps20_Mjysr
            gps20_convbm = gps20_bm.deconvolve(mgps_beam)
            mgpsMjysr = convolution.convolve_fft(
                mgpsMjysr, gps20_convbm.as_kernel(mgps_pixscale))

        ax4 = figure.add_subplot(1, 6, 4, projection=mgps_cutout.wcs)

        # use the central frequency corresponding to an approximately flat spectrum (flat -> 89.72)
        freefree20 = gps20_proj * freefree_20cm_to_3mm
        dust20 = (mgpsMjysr - freefree20).value
        dust20[np.isnan(gps20_proj) | (gps20_proj == 0)] = np.nan

        # show smoothed 20 cm
        ax3.imshow((freefree20).value,
                   origin='lower',
                   interpolation='none',
                   norm=norm,
                   cmap=cmap)
        ax4.imshow(dust20,
                   origin='lower',
                   interpolation='none',
                   norm=norm,
                   cmap=cmap)
        ax4.set_title("3 mm Dust")
        ax4.tick_params(direction='in')
        ax4.tick_params(color='w')
        ax4.coords[1].set_axislabel("")
        ax4.coords[1].set_ticklabel_visible(False)

        pl.tight_layout()

    # Per-region tweaks to the shared normalization.
    if np.abs(np.nanpercentile(dust20, 0.5) -
              np.nanpercentile(freefree, 0.1)) < 1e2:
        norm.vmin = np.min(
            [np.nanpercentile(dust20, 0.5),
             np.nanpercentile(freefree, 0.1)])
    if 'arches' in reg.meta['text']:
        norm.vmin = 0.95  # force 1 to be on-scale
    if 'w49b' in reg.meta['text']:
        norm.vmin = -4
        norm.vmax = 11

    # Redraw every panel so they pick up the adjusted normalization.
    ax0.imshow(mgpsMjysr,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax1.imshow(dusty,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax2.imshow(freefree,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax3.imshow((gps20_proj * freefree_20cm_to_3mm).value,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax4.imshow(dust20,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)

    print(
        f"{reg}: dusty sum: {dusty[dusty>0].sum()}   freefreeish sum: {freefree[freefree>0].sum()}"
    )

    mgps_reproj_Mjysr = mgps_reproj / mgps_beam.sr.value / 1e6

    # only label the middle axis (ax0 is panel 3 of 5); blank the others
    for ax in figure.axes:
        ax.set_xlabel(" ")

    ax0.set_xlabel("Galactic Longitude")

    lastax = ax3
    bbox = lastax.get_position()

    # this is a painful hack to force the bbox to update
    # NOTE(review): unbounded; loops forever if the layout never settles.
    while bbox.height > 0.9:
        print(f"bbox_height = {bbox.height}")
        pl.pause(0.1)
        bbox = lastax.get_position()

    cax = figure.add_axes([bbox.x1 + 0.01, bbox.y0, 0.02, bbox.height])
    cb = figure.colorbar(mappable=lastax.images[-1], cax=cax)
    cb.set_ticks([-3, 0, 10, 50, 100])
    if 'w51' in reg.meta['text']:
        cb.set_ticks([-10, 0, 20, 200])
    if 'w49b' in reg.meta['text']:
        cb.set_ticks([-3, 0, 3, 10])
    if 'arches' in reg.meta['text']:
        cb.set_ticks([0, 1, 5, 10])
    cb.set_label('MJy sr$^{-1}$')

    return {
        'dust': dusty[dusty > 0].sum(),
        'dust20': dust20[dust20 > 0].sum(),
        'freefree': freefree[freefree > 0].sum(),
        'freefree20': freefree20[freefree20 > 0].sum(),
        'totalpos': mgps_reproj_Mjysr[mgps_reproj_Mjysr > 0].sum(),
        'total': mgps_reproj_Mjysr.sum(),
        'totalpos20': mgpsMjysr[mgpsMjysr > 0].sum(),
        'total20': mgpsMjysr.sum(),
    }
Beispiel #25
0
import astropy.visualization as vis

# Demo: display two images (one channel of the sample raccoon RGB image, and a
# FITS frame) alongside an integer-pixel-shifted copy of each.
img0 = msc.face()[:, :, 0]  # rgb image, take one channel

# Use a context manager so the FITS file handle is closed even if reading the
# primary HDU's data raises.
with fits.open(
        '/Users/luke/Dropbox/proj/timmy/data/phot/2020-04-01/TIC460205581-01-0196_Rc1_out.fit'
) as hl:
    img1 = hl[0].data

# Pixel offsets for the shift demonstration; identical for both images.
dx, dy = 200, 50

for ix, img in enumerate([img0, img1]):

    if ix == 1:
        # FITS frame: fixed display window [10, 1e4] with a log stretch.
        vmin, vmax = 10, int(1e4)
        norm = vis.ImageNormalize(vmin=vmin,
                                  vmax=vmax,
                                  stretch=vis.LogStretch(1000))
    else:
        # Raccoon image: no explicit normalization (matplotlib autoscales).
        norm = None

    fig, axs = plt.subplots(nrows=2, ncols=2)
    # note: the raccoon image really should have origin='upper' (otherwise
    # trashpanda is upside-down), but 'lower' is used to match the FITS image
    # processing convention.
    axs[0, 0].imshow(img, cmap=plt.cm.gray, origin='lower', norm=norm)
    axs[0, 0].set_title('shape: {}'.format(img.shape))

    axs[1, 0].imshow(ti.integer_shift_img(img, dx, dy),
                     cmap=plt.cm.gray,
                     origin='lower',
                     norm=norm)
    axs[1, 0].set_title('dx={}, dy={}'.format(dx, dy))
Beispiel #26
0
def show_image(image,
               percl=99,
               percu=None,
               figsize=(6, 10),
               cmap='viridis',
               log=False):
    """
    Show an image in matplotlib with some basic astronomically-appropriate stretching.

    Before display, the image is block-averaged down to roughly the pixel size
    of the figure. This way detail is not lost to matplotlib's own resampling.
    A colorbar is added to the figure.

    Parameters
    ----------
    image
        The image to show (assumed to be a 2D array indexed as ``[row, col]``).
    percl : number
        The percentile for the lower edge of the stretch (or both edges if
        ``percu`` is None).
    percu : number or None
        The percentile for the upper edge of the stretch. If None, ``percl`` is
        used for both edges and the interval becomes ``(100 - percl, percl)``.
    figsize : 2-tuple or None
        The size of the matplotlib figure in inches. Only ``max(figsize)`` is
        used: the figure is reshaped to the image's aspect ratio, with its
        longer side equal to ``max(figsize)``. If None, matplotlib's default
        figure size is used as-is.
    cmap : str or matplotlib colormap
        Colormap passed to ``imshow``.
    log : bool
        If True use a logarithmic stretch, otherwise a linear one.
    """
    if percu is None:
        percu = percl
        percl = 100 - percl

    if figsize is not None:
        # Rescale the fig size to match the image dimensions, roughly.
        # matplotlib's figsize is (width, height), while numpy's shape is
        # (rows, cols) == (height, width), so rows/cols is height/width.
        image_aspect_ratio = image.shape[0] / image.shape[1]
        side = max(figsize)
        if image_aspect_ratio >= 1:
            # Taller than wide: the height is the long side.
            figsize = (side / image_aspect_ratio, side)
        else:
            # Wider than tall: the width is the long side.
            figsize = (side, side * image_aspect_ratio)

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # To preserve details we should *really* downsample correctly and not rely on
    # matplotlib to do it correctly for us (it won't).

    # So, calculate the size of the figure in pixels, block_reduce to roughly that,
    # and display the block reduced image.

    # Thanks, https://stackoverflow.com/questions/29702424/how-to-get-matplotlib-figure-size
    # get_size_inches() is (width, height); reverse it to (height, width) so it
    # lines up axis-by-axis with image.shape (rows, cols).
    fig_size_pix = fig.get_size_inches()[::-1] * fig.dpi

    # Integer block size: the largest number of image pixels per figure pixel
    # along either axis, but at least 1 (no downsampling for small images).
    ratio = max(int((image.shape // fig_size_pix).max()), 1)

    # block_reduce sums each ratio x ratio block by default. Dividing by the
    # number of pixels per block turns that sum into the block mean, so the
    # reduced image stays on the same value scale as the original.
    reduced_data = block_reduce(image, ratio) / ratio**2

    # Now that we have downsampled, the axis limits would match the smaller
    # image size. Setting the extent maps the display back to original-image
    # pixel coordinates. block_reduce trims trailing rows/columns that do not
    # fill a whole block, so the reduced image covers
    # reduced_data.shape * ratio original pixels, not necessarily image.shape.
    extent = [0, reduced_data.shape[1] * ratio, 0, reduced_data.shape[0] * ratio]
    if log:
        stretch = aviz.LogStretch()
    else:
        stretch = aviz.LinearStretch()
    norm = aviz.ImageNormalize(reduced_data,
                               interval=aviz.AsymmetricPercentileInterval(
                                   percl, percu),
                               stretch=stretch)

    plt.colorbar(
        ax.imshow(reduced_data,
                  norm=norm,
                  origin='lower',
                  cmap=cmap,
                  extent=extent))