Пример #1
0
def plotcontrast(img):
    """Show *img* under a gallery of Matplotlib/Astropy normalizations.

    The image is drawn twelve times on a 4x3 grid of axes: once without a
    normalization, once with Matplotlib's ``LogNorm`` and once for each of
    ten ``astropy.visualization`` stretches (each wrapped in an
    ``ImageNormalize``).  Every panel is titled with the short class name of
    the normalization it uses, and the figure gets a common super-title.

    See http://docs.astropy.org/en/stable/visualization/index.html

    Parameters
    ----------
    img : array_like
        Image to display; it is also used to build the
        histogram-equalization stretch.
    """
    candidates = (
        None,
        LogNorm(),
        vis.AsinhStretch(),
        vis.ContrastBiasStretch(1, 0.5),
        vis.HistEqStretch(img),
        vis.LinearStretch(),
        vis.LogStretch(),
        vis.PowerDistStretch(a=10.),
        vis.PowerStretch(a=10.),
        vis.SinhStretch(),
        vis.SqrtStretch(),
        vis.SquaredStretch(),
    )

    fg, grid = subplots(4, 3, figsize=(10, 10))

    for panel, candidate in zip(grid.ravel(), candidates):
        is_astropy_stretch = bool(candidate) and not isinstance(candidate, LogNorm)
        if is_astropy_stretch:
            # Astropy stretches must be wrapped to act as a Matplotlib norm.
            norm = ImageNormalize(stretch=candidate)
            # str(cls) looks like "<class 'pkg.module.Name'>" -> "Name"
            class_repr = str(candidate.__class__)
            label = class_repr.split('.')[-1].split("'")[0]
        else:
            # None (no normalization) or a Matplotlib norm used as-is.
            norm = candidate
            label = str(candidate).split('.')[-1].split(" ")[0]
        panel.set_title(label)

        panel.imshow(img, origin='lower', cmap='gray', norm=norm)
        panel.axis('off')

    fg.suptitle('Matplotlib/AstroPy normalizations')
def download_tgss_image(url):
    """Download a TGSS cutout and save it as a PNG.

    The URL must return a tar archive whose first member is a FITS file.
    That member is extracted into a temporary folder named after the
    ``hPOS`` coordinates.  The ``[0, 0]`` slice of the primary HDU data is
    scaled (min/max interval, then asinh stretch) and written with the
    ``copper`` colormap under ``settings.MEDIA_ROOT + 'database_images/'``.
    The downloaded tar file and the temporary folder are removed before
    returning, including when an error occurs.

    Parameters
    ----------
    url : str
        URL of the image.  Its query string must contain an ``hPOS``
        parameter of the form ``"<ra>,<dec>"``.

    Returns
    -------
    file_url : str
        Path of the newly written PNG, or of the ``no_image.png``
        placeholder if saving the image raised ``OSError``.

    Raises
    ------
    ValueError
        If the archive's first member would be extracted outside the
        temporary folder (e.g. an absolute path or one containing ``..``).
    """
    # parsing ra and dec from the url
    query = urlparse(url).query
    h_pos = parse_qs(query)['hPOS'][0]
    ra, dec = h_pos.split(',')

    # name of the temporary tar file to be saved locally, and of the temp
    # folder the fits file is extracted into
    local_file_name = '{}_{}.tar'.format(ra, dec)
    temp_folder = local_file_name.replace('.tar', '')

    try:
        # getting the tar file downloaded; closing the response releases
        # the streamed connection
        with requests.get(url, stream=True) as request, \
                open(local_file_name, 'wb') as f:
            for chunk in request.iter_content(chunk_size=1024):
                if chunk:  # skip empty chunks
                    f.write(chunk)

        with tarfile.open(local_file_name) as tar:
            member = tar.getmembers()[0]
            # The archive comes from a remote server: refuse a member whose
            # name would place it outside the temp folder (path traversal).
            root = os.path.realpath(temp_folder)
            target = os.path.realpath(os.path.join(temp_folder, member.name))
            if os.path.commonpath([root, target]) != root:
                raise ValueError('Refusing to extract {!r} outside {!r}'.format(
                    member.name, temp_folder))
            tar.extract(member=member, path=temp_folder)

        # removing the temporary file as soon as it is no longer needed
        os.remove(local_file_name)

        fits_image = os.path.join(temp_folder, member.name)
        # interval is applied first, then the asinh stretch
        stretch = vis.AsinhStretch(0.01) + vis.MinMaxInterval()
        file_url = settings.MEDIA_ROOT + 'database_images/' + temp_folder + '_tgss.png'

        # the context manager closes the FITS file even if imsave fails
        with fits.open(fits_image) as hdu_list:
            image_data = hdu_list[0].data[0, 0]
            try:
                imsave(file_url, stretch(image_data), cmap='copper')
            except OSError as err:
                logger.info("Something is wrong with the fits file for url = {}".format(url))
                logger.error(err)
                file_url = settings.MEDIA_ROOT + 'database_images/no_image.png'
    finally:
        # best-effort cleanup of whatever temporary files were created
        if os.path.exists(local_file_name):
            os.remove(local_file_name)
        shutil.rmtree(temp_folder, ignore_errors=True)

    return file_url
Пример #3
0
    def __init__(self, data, header, **kwargs):
        """Create the map and apply AIA-specific metadata and plot defaults.

        Sets ``meta['detector']`` to ``"AIA"``, uses ``self.detector`` as the
        map nickname, and installs as plot defaults the colormap named by
        ``self._get_cmap_name()`` plus an ``ImageNormalize`` with an asinh
        stretch (a=0.01).
        """
        GenericMap.__init__(self, data, header, **kwargs)

        # Fill in missing info.  The detector is written to ``meta`` before
        # ``self.detector`` is read for the nickname -- presumably that
        # property is derived from ``meta``; verify in GenericMap.
        self.meta['detector'] = "AIA"
        self._nickname = self.detector

        colormap = cm.get_cmap(self._get_cmap_name())
        self.plot_settings['cmap'] = colormap
        asinh_norm = ImageNormalize(stretch=visualization.AsinhStretch(0.01))
        self.plot_settings['norm'] = asinh_norm
Пример #4
0
    def plot_norm(self,
                  stretch='linear',
                  power=1.0,
                  asinh_a=0.1,
                  min_cut=None,
                  max_cut=None,
                  min_percent=None,
                  max_percent=None,
                  percent=None,
                  clip=True):
        """Create a matplotlib norm object for plotting.

        This is a copy of this function that will be available in Astropy 1.3:
        `astropy.visualization.mpl_normalize.simple_norm`

        See the parameter description there!

        The limits are chosen from the first option group that is set:
        ``percent``; then ``min_percent``/``max_percent``; then
        ``min_cut``/``max_cut``; otherwise the data min/max.  They are
        evaluated on ``self.data``.

        Raises
        ------
        ValueError
            If ``stretch`` is not one of 'linear', 'sqrt', 'power', 'log'
            or 'asinh'.

        Examples
        --------
        >>> image = SkyImage()
        >>> norm = image.plot_norm(stretch='sqrt', max_percent=99)
        >>> image.plot(norm=norm)
        """
        import astropy.visualization as v
        from astropy.visualization.mpl_normalize import ImageNormalize

        # Interval: first configured option group wins.
        if percent is not None:
            interval = v.PercentileInterval(percent)
        elif any(p is not None for p in (min_percent, max_percent)):
            interval = v.AsymmetricPercentileInterval(min_percent or 0.,
                                                      max_percent or 100.)
        elif any(c is not None for c in (min_cut, max_cut)):
            interval = v.ManualInterval(min_cut, max_cut)
        else:
            interval = v.MinMaxInterval()

        # Stretch: ordered (name, factory) pairs, matched with ==.
        stretch_factories = (
            ('linear', lambda: v.LinearStretch()),
            ('sqrt', lambda: v.SqrtStretch()),
            ('power', lambda: v.PowerStretch(power)),
            ('log', lambda: v.LogStretch()),
            ('asinh', lambda: v.AsinhStretch(asinh_a)),
        )
        for name, factory in stretch_factories:
            if stretch == name:
                stretch_obj = factory()
                break
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        vmin, vmax = interval.get_limits(self.data)
        return ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch_obj, clip=clip)
Пример #5
0
def simple(arr, linear=0.00001):
    """Rescale, de-hot-pixel and asinh-stretch *arr*.

    Steps:
    1. Rescale the intensities to [0, 1].
    2. Replace outlier pixels with ``find_outlier_pixels`` (tolerance 10).
    3. Apply ``AsinhStretch(linear)``.
    4. Replace outliers again with a tighter tolerance (3).

    Parameters
    ----------
    arr : array_like
        Input image.
    linear : float, optional
        Parameter passed to ``AsinhStretch`` (default 1e-5).

    Returns
    -------
    The filtered array returned by the second ``find_outlier_pixels`` call.
    """
    unit_range = skimage.exposure.exposure.rescale_intensity(arr,
                                                             out_range=(0, 1))

    # The hot-pixel masks returned by find_outlier_pixels are not needed.
    _, cleaned = find_outlier_pixels(unit_range, tolerance=10)

    asinh = av.AsinhStretch(linear)
    _, result = find_outlier_pixels(asinh(cleaned), tolerance=3)
    return result
def create_stamp_plot(alert: dict, ax, type: str):
    """Helper function to create cutout subplot.

    Decodes ``alert[f"cutout{type}"]["stampData"]`` (base64 -> gzip -> FITS)
    and reads the primary HDU image.  It then rescales the image to [0, 1]
    between its non-NaN minimum and maximum, applies an asinh stretch, and
    shows it on *ax*.  The display range is the 0.5-99.5 percentile of the
    stretched, non-NaN values.  Ticks are removed and *type* is used as the
    title.

    A constant-valued cutout (min == max) is shown as a flat image instead
    of producing 0/0 NaNs.

    Parameters
    ----------
    alert : dict
        Alert packet holding the cutout.
    ax : matplotlib Axes
        Axes to draw on.
    type : str
        Cutout name suffix, used both for the key lookup and as the title.
        (Shadows the builtin; kept for interface compatibility.)
    """
    fits_bytes = gzip.decompress(b64decode(alert[f"cutout{type}"]["stampData"]))
    with fits.open(io.BytesIO(fits_bytes), ignore_missing_simple=True) as hdul:
        # copy so the array does not depend on the closed HDUList
        data = np.array(hdul[0].data)

    # data == data drops NaNs (NaN != NaN)
    vmin, vmax = np.percentile(data[data == data], [0, 100])
    span = vmax - vmin
    if span > 0:
        scaled = (data - vmin) / span
    else:
        # constant cutout: avoid 0/0, keep NaNs as bad pixels
        scaled = np.where(data == data, 0.0, np.nan)

    data_ = visualization.AsinhStretch()(scaled)
    ax.imshow(
        data_,
        norm=Normalize(*np.percentile(data_[data_ == data_], [0.5, 99.5])),
        aspect="auto",
    )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(type, fontdict={"fontsize": "small"})
Пример #7
0
    def __init__(self, data, header, **kwargs):
        """Create the map and fill in SJI-specific metadata and plot defaults.

        Wavelength source, by header level:
        - ``lvl_num == 2``: taken from ``twave1``.
        - ``lvl_num == 1``: parsed from the second ``_``-separated field of
          ``img_path``.

        The detector is always set to ``"SJI"`` and the wavelength unit to
        ``"Angstrom"``.  The colormap ``irissji<wavelength>`` (bad pixels
        drawn black) and an asinh(0.1) ``ImageNormalize`` become the plot
        defaults.
        """
        GenericMap.__init__(self, data, header, **kwargs)

        level = header.get('lvl_num')
        if level == 2:
            self.meta['wavelnth'] = header.get('twave1')
            # NOTE(review): overwritten with "SJI" unconditionally below.
            self.meta['detector'] = header.get('instrume')
            self.meta['waveunit'] = "Angstrom"
        elif level == 1:
            path_fields = header.get('img_path').split('_')
            self.meta['wavelnth'] = int(path_fields[1])
            self.meta['waveunit'] = "Angstrom"

        # Applied for every level; these take precedence over the above.
        self.meta['detector'] = "SJI"
        self.meta['waveunit'] = "Angstrom"

        colormap_name = 'irissji' + str(int(self.meta['wavelnth']))
        palette = cm.get_cmap(colormap_name)
        palette.set_bad('black')
        self.plot_settings['cmap'] = palette
        self.plot_settings['norm'] = ImageNormalize(
            stretch=visualization.AsinhStretch(0.1))
Пример #8
0
def get_im_stretch(stretch):
    """Return a stretch to feed the ImageNormalize routine from Astropy.

    :param stretch: short name of the stretch; either ``'arcsinh'`` or
        ``'linear'``.
    :type stretch: str

    :return: a new ``AsinhStretch`` (default parameters) or ``LinearStretch``.
    :rtype: :class:`astropy.visualization.BaseStretch`

    :raises ValueError: if *stretch* is not a supported name.  ``ValueError``
        subclasses ``Exception``, so handlers written for the previous
        generic ``Exception`` still catch it.
    """
    if stretch == 'arcsinh':
        return astrovis.AsinhStretch()
    if stretch == 'linear':
        return astrovis.LinearStretch()

    # (stretch,) keeps %-formatting correct even if a tuple is passed.
    raise ValueError('Ouch! Stretch %s unknown.' % (stretch,))
Пример #9
0
def shrink(arr, bounds=None, ax=None):
    """Pick tighter display limits for *arr* from its edge structure.

    Steps:
    1. Rescale *arr* from ``bounds`` (default: its min/max) to [0, 1].
    2. Apply an asinh stretch (a=0.005) and rescale to [-1, 1].
    3. Compute Scharr edge strength and normalise it to [0, 1].
    4. Pull every pixel towards the edge-weighted mean of *arr*, more
       strongly where edges are weak.

    The pixel positions with the lowest and highest blended value select
    the returned limits.

    Parameters
    ----------
    arr : ndarray
        Input image.
    bounds : (low, high), optional
        Input intensity range used for the first rescale.
    ax : ignored
        Accepted but not used.

    Returns
    -------
    (low_value, high_value)
        Values of *arr* at the argmin / argmax of the blended image.
    """
    low, high = (np.min(arr), np.max(arr)) if bounds is None else bounds

    rescale = skimage.exposure.exposure.rescale_intensity

    # Normalise the chosen range to [0, 1], then asinh-stretch it.
    unit = rescale(arr, in_range=(low, high), out_range=(0, 1))
    stretched = av.AsinhStretch(0.005)(unit)

    # Edge strength of the stretched image (centred to [-1, 1] first),
    # normalised to [0, 1] so it can serve as a weight.
    edges = skimage.filters.scharr(rescale(stretched, out_range=(-1, 1)))
    edges = rescale(edges, out_range=(0, 1))

    # Blend each pixel towards the edge-weighted mean; weak-edge pixels move
    # furthest.  (np.average raises if the weights sum to zero.)
    edge_mean = np.average(arr, weights=edges)
    blended = arr + (1 - edges) * (edge_mean - arr)

    # Map the extreme positions back to the original intensity scale.
    idx_low = np.unravel_index(blended.argmin(), blended.shape)
    idx_high = np.unravel_index(blended.argmax(), blended.shape)
    return arr[idx_low], arr[idx_high]
Пример #10
0
def download_first_image(url, galaxy):
    """
    Download the FITS image at *url* and save it as a PNG.

    The primary HDU data are scaled (min/max interval, then asinh stretch
    with a=0.01) and written with the ``inferno`` colormap to
    ``settings.MEDIA_ROOT + 'database_images/' + galaxy.first + '.png'``.

    :param url: link of the image (downloaded through astropy's cache)
    :param galaxy: galaxy object; its ``first`` attribute names the PNG
    :return: path of the saved image, or of the ``no_image.png`` placeholder
        if downloading, reading or saving raised ``OSError``
    """
    stretch = vis.AsinhStretch(0.01) + vis.MinMaxInterval()

    file_url = settings.MEDIA_ROOT + 'database_images/' + galaxy.first + '.png'
    try:
        # the context manager closes the FITS file after the image is saved
        with fits.open(download_file(url, cache=True)) as hdu_list:
            imsave(file_url, stretch(hdu_list[0].data), cmap='inferno')
    except OSError as err:
        logger.info("Something is wrong with the fits file for url = {}".format(url))
        logger.error(err)
        # NOTE(review): other download helpers fall back to
        # 'database_images/no_image.png'; confirm which folder is intended.
        file_url = settings.MEDIA_ROOT + 'temp_images/no_image.png'

    return file_url
Пример #11
0
def complex(arr, mode, linear=0.00001, trimbright=100.0):
    """Rescale *arr* to [0, 1] between mode-dependent bounds and asinh-stretch it.

    (The name shadows the ``complex`` builtin; kept for compatibility.)

    Parameters
    ----------
    arr : ndarray
        Input image.
    mode : {"adapt", "percent", "fixed"}
        How the (low, high) bounds are chosen:
        ``"adapt"`` -- from :func:`shrink`;
        ``"percent"`` -- ``(max(min(arr), -250), percentile(arr, trimbright))``;
        ``"fixed"`` -- ``(max(min(arr), -250), trimbright)``.
    linear : float, optional
        Parameter of the ``AsinhStretch`` (default 1e-5).
    trimbright : float, optional
        Upper percentile for ``"percent"``, absolute upper bound for
        ``"fixed"`` (default 100.0).

    Returns
    -------
    ndarray
        Stretched image rescaled to [0, 1].

    Raises
    ------
    ValueError
        If *mode* is not supported.  ``ValueError`` subclasses
        ``Exception``, so handlers for the previous generic exception
        still work.
    """
    # shrink
    if mode == "adapt":
        bounds = shrink(arr)
    elif mode == "percent":
        bounds = max(np.min(arr), -250.), np.percentile(arr, trimbright)
    elif mode == "fixed":
        bounds = max(np.min(arr), -250.), trimbright
    else:
        raise ValueError("Mode %s not supported" % (mode,))

    # Rescale arr to 0,1
    arr = skimage.exposure.exposure.rescale_intensity(arr,
                                                      in_range=bounds,
                                                      out_range=(0, 1))

    # Asinh stretch arr
    stretch = av.AsinhStretch(linear)
    arr = stretch(arr)

    # floating point stuff can bump values outside 0,1. re-force:
    arr = skimage.exposure.exposure.rescale_intensity(arr, out_range=(0, 1))

    return arr
    visualization_tools.make_scalebar(ax,
                                      left_side=scalebarx,
                                      length=1.213 * u.arcsec,
                                      label='0.05 pc')

    pl.savefig(paths.fpath(figfilename), bbox_inches='tight')


if __name__ == "__main__":

    # Quick-look composite: linear stretch, all other rgbfig options left
    # at their defaults.
    rgbfig(stretch=visualization.LinearStretch())

    # Zoomed Sgr B2 M composite with an asinh stretch and explicit
    # red/green/blue percentile settings.  scalebarx presumably anchors the
    # scale bar (it is used as make_scalebar's left_side).
    sgrb2m_options = dict(
        figfilename='SgrB2M_RGB.pdf',
        lims=[(266.8359, 266.8325), (-28.38600555, -28.3832)],
        redfn=paths.Fpath('SGRB2M-2012-Q-MEAN.DePree.recentered.fits'),
        greenfn=paths.Fpath(
            'sgr_b2m.M.B3.allspw.continuum.r0.5.clean1000.image.tt0.pbcor.fits'
        ),
        bluefn=paths.Fpath(
            'sgr_b2m.M.B6.allspw.continuum.r0.5.clean1000.image.tt0.pbcor.fits'
        ),
        scalebarx=coordinates.SkyCoord(266.8336007 * u.deg,
                                       -28.38553839 * u.deg),
        redpercentile=99.99,
        greenpercentile=99.98,
        bluepercentile=99.98,
        stretch=visualization.AsinhStretch(),
    )
    rgbfig(**sgrb2m_options)
Пример #13
0
            cubes['mask'].with_mask(include_mask).minimal_subcube()[0]
            if crop else cubes['mask'].with_mask(include_mask)[0])
    except AssertionError:
        # this implies there is no mask
        pass

    imgs['includemask'] = include_mask  # the mask applied to the cube

    # give up on the 'Slice' nature so we can change units
    imgs['model'] = imgs['model'].quantity * cubes[
        'image'].pixels_per_beam * u.pix / u.beam

    return imgs, cubes


# Module-level asinh normalization, used as the default ``norm`` of show().
# NOTE(review): this single instance is shared by every call that relies on
# the default; if matplotlib autoscales it (vmin/vmax initially unset), the
# first image's limits may persist for later calls -- confirm intended.
_asinh_stretch = visualization.AsinhStretch()
asinhn = visualization.ImageNormalize(stretch=_asinh_stretch)


def show(imgs,
         zoom=None,
         clear=True,
         norm=asinhn,
         imnames_toplot=('image', 'model', 'residual', 'mask'),
         **kwargs):

    if clear:
        pl.clf()

    if 'mask' not in imgs:
        imnames_toplot = list(imnames_toplot)
        imnames_toplot.remove('mask')
Пример #14
0
def plot_image(image,
               ax=None,
               scale='log',
               cmap=None,
               origin='lower',
               xlabel=None,
               ylabel=None,
               cbar=None,
               clabel='Flux ($e^{-}s^{-1}$)',
               cbar_ticks=None,
               cbar_ticklabels=None,
               cbar_pad=None,
               cbar_size='5%',
               title=None,
               percentile=95.0,
               vmin=None,
               vmax=None,
               offset_axes=None,
               color_bad='k',
               **kwargs):
    """
    Utility function to plot a 2D image.

    Parameters:
        image (2d array): Image data.
        ax (matplotlib.pyplot.axes, optional): Axes in which to plot.
            Default (None) is to use current active axes.
        scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional):
            Normalization used to stretch the colormap.
            Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'``
            and ``'squared'``.
            Can also be a :py:class:`astropy.visualization.ImageNormalize` (or any
            :py:class:`matplotlib.colors.Normalize`) object, which is used as-is.
            Default is ``'log'``.
        cmap (str or matplotlib colormap, optional): Colormap to use. Default is the ``Blues``
            colormap. A copy is always used, so setting the bad-pixel color never modifies
            the caller's or a registered colormap.
        origin (str, optional): The origin of the coordinate system.
        xlabel (str, optional): Label for the x-axis.
        ylabel (str, optional): Label for the y-axis.
        cbar (string, optional): Location of color bar.
            Choices are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``.
            Default is not to create colorbar.
        clabel (str, optional): Label for the color bar.
        cbar_ticks (list, optional): Tick positions for the color bar.
        cbar_ticklabels (list, optional): Tick labels for the color bar.
        cbar_size (str or float, optional): Size of the colorbar, passed to the ``colorbar``
            helper. Default is ``'5%'``.
        cbar_pad (float, optional): Padding between axes and colorbar.
        title (str or None, optional): Title for the plot.
        percentile (float, optional): The fraction of pixels to keep in color-trim.
            If single float given, the same fraction of pixels is eliminated from both ends.
            If tuple of two floats is given, the two are used as the percentiles.
            Default=95.
        vmin (float, optional): Lower limit to use for colormap.
        vmax (float, optional): Upper limit to use for colormap.
        offset_axes (tuple, optional): ``(x, y)`` coordinate of the first pixel; shifts the
            plotted extent. Default is ``(0, 0)``.
        color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black.
        kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`.
            The deprecated ``make_cbar`` keyword is still accepted as an alias for ``cbar``
            and emits a :py:class:`FutureWarning`.

    Returns:
        :py:class:`matplotlib.image.AxesImage`: Image from returned
            by :py:func:`matplotlib.pyplot.imshow`.

    Raises:
        ValueError: If ``scale`` is neither a known name nor a normalization object.

    .. codeauthor:: Rasmus Handberg <*****@*****.**>
    """
    import warnings

    logger = logging.getLogger(__name__)

    # Backward compatible settings: warn (instead of raising, which made the
    # fallback below unreachable) and map make_cbar onto cbar.
    make_cbar = kwargs.pop('make_cbar', None)
    if make_cbar:
        warnings.warn("'make_cbar' is deprecated. Use 'cbar' instead.",
                      FutureWarning, stacklevel=2)
        if not cbar:
            cbar = make_cbar

    # Special treatment for boolean arrays:
    if isinstance(image, np.ndarray) and image.dtype == 'bool':
        if vmin is None: vmin = 0
        if vmax is None: vmax = 1
        if cbar_ticks is None: cbar_ticks = [0, 1]
        if cbar_ticklabels is None: cbar_ticklabels = ['False', 'True']

    # Calculate limits of color scaling:
    interval = None
    if vmin is None or vmax is None:
        if allnan(image):
            logger.warning("Image is all NaN")
            vmin = 0
            vmax = 1
            if cbar_ticks is None:
                cbar_ticks = []
            if cbar_ticklabels is None:
                cbar_ticklabels = []
        elif isinstance(percentile, (list, tuple, np.ndarray)):
            interval = viz.AsymmetricPercentileInterval(
                percentile[0], percentile[1])
        else:
            interval = viz.PercentileInterval(percentile)

    # Create ImageNormalize object with extracted limits:
    if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh',
                 'squared'):
        if scale == 'log':
            stretch = viz.LogStretch()
        elif scale == 'linear':
            stretch = viz.LinearStretch()
        elif scale == 'sqrt':
            stretch = viz.SqrtStretch()
        elif scale == 'asinh':
            stretch = viz.AsinhStretch()
        elif scale == 'histeq':
            stretch = viz.HistEqStretch(image[np.isfinite(image)])
        elif scale == 'sinh':
            stretch = viz.SinhStretch()
        elif scale == 'squared':
            stretch = viz.SquaredStretch()

        # Create ImageNormalize object. Very important to use clip=False if the image contains
        # NaNs, otherwise NaN points will not be plotted correctly.
        norm = viz.ImageNormalize(data=image[np.isfinite(image)],
                                  interval=interval,
                                  vmin=vmin,
                                  vmax=vmax,
                                  stretch=stretch,
                                  clip=not anynan(image))

    elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)):
        norm = scale
    else:
        raise ValueError("scale {} is not available.".format(scale))

    if offset_axes:
        extent = (offset_axes[0] - 0.5, offset_axes[0] + image.shape[1] - 0.5,
                  offset_axes[1] - 0.5, offset_axes[1] + image.shape[0] - 0.5)
    else:
        extent = (-0.5, image.shape[1] - 0.5, -0.5, image.shape[0] - 0.5)

    if ax is None:
        ax = plt.gca()

    # Set up the colormap to use. Always work on a copy, so that adding the
    # bad color below never mutates a registered or caller-owned colormap:
    if cmap is None:
        cmap = plt.get_cmap('Blues')
    elif isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    cmap = copy.copy(cmap)

    if color_bad:
        cmap.set_bad(color_bad, 1.0)

    # Plotting the image using all the settings set above:
    im = ax.imshow(image,
                   cmap=cmap,
                   norm=norm,
                   origin=origin,
                   extent=extent,
                   interpolation='nearest',
                   **kwargs)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    ax.set_xlim([extent[0], extent[1]])
    ax.set_ylim([extent[2], extent[3]])

    if cbar:
        colorbar(im,
                 ax=ax,
                 loc=cbar,
                 size=cbar_size,
                 pad=cbar_pad,
                 label=clabel,
                 ticks=cbar_ticks,
                 ticklabels=cbar_ticklabels)

    # Settings for ticks. A matplotlib locator stores a reference to the Axis
    # it is attached to, so it must not be shared between Axis objects:
    # create a fresh integer locator for each axis and tick level.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(MaxNLocator(nbins=10, integer=True))
        axis.set_minor_locator(MaxNLocator(nbins=10, integer=True))
    ax.tick_params(which='both', direction='out', pad=5)
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()

    return im
Пример #15
0
pl.savefig('pspec_threshold.png')

rslts_nothresh = uvcombine.feather_plot(alma_image,
                                        lores=mgps_image,
                                        lowresfwhm=11 * u.arcsec)
pl.savefig('pspec_nothreshold.png')

hdu_low, im_lowraw, header_low = file_in(lores)
hdu_hi, im_hi, header_hi = file_in(hires)

pl.figure(2).clf()
combined = uvcombine.feather_simple(alma_image,
                                    mgps_image,
                                    lowresfwhm=10 * u.arcsec)

pl.imshow(visualization.AsinhStretch()(combined.real),
          vmax=0.1,
          origin='lower',
          interpolation='none')
hdr = fits.getheader(alma_image)

combined_default = uvcombine.feather_simple(alma_image, mgps_image)
fits.PrimaryHDU(data=np.abs(combined_default),
                header=hdr).writeto('feathered_MGPS_ALMATCTE7m.fits',
                                    overwrite=True)

combined_deconvsd = uvcombine.feather_simple(alma_image,
                                             mgps_image,
                                             deconvSD=True)
fits.PrimaryHDU(data=np.abs(combined_deconvsd),
                header=hdr).writeto('feathered_MGPS_ALMATCTE7m_deconvsd.fits',
Пример #16
0
def get_image(data, fmt='JPEG', norm='percentile', lo=None, hi=None,
              zcontrast=0.25, nsamples=1000, krej=2.5, max_iterations=5,
              stretch='linear', a=None, bias=0.5, contrast=1, cmap=None,
              dpi=100, **kwargs):
    u"""
    Return a byte array containing image in the given format

    Image scaling is done using `~astropy.visualization`. It includes
    normalization of the input data (mapping to [0, 1]) and stretching -
    optional non-linear mapping [0, 1] -> [0, 1] for contrast enhancement.
    A colormap can be applied to the normalized data. Conversion to the target
    image format is done by matplotlib or Pillow.

    If the lower and upper normalization limits coincide (e.g. a constant
    image with ``norm="minmax"``), the data are thresholded instead of
    divided by zero: values above the limit map to 1, all others to 0.

    :param array_like data: input 2D image data
    :param str fmt: output image format
    :param str norm: data normalization mode::
        "manual": lower and higher clipping limits are set explicitly
        "minmax": limits are set to the minimum and maximum data values
        "percentile" (default): limits are set based on the specified fraction
            of pixels
        "zscale": use IRAF ZScale algorithm
    :param int | float lo::
        for ``norm`` == "manual", lower data limit
        for ``norm`` == "percentile", lower percentile clipping value,
            defaulting to 10
        for ``norm`` == "zscale", lower limit on the number of rejected pixels,
            defaulting to 5
    :param int | float hi::
        for ``norm`` == "manual", upper data limit
        for ``norm`` == "percentile", upper percentile clipping value,
            defaulting to 98
        for ``norm`` == "zscale", upper limit on the number of rejected pixels,
            defaulting to data.size/2
    :param float zcontrast: for ``norm`` == "zscale", the scaling factor,
        0 < zcontrast < 1, defaulting to 0.25
    :param int nsamples: for ``norm`` == "zscale", the number of points in
        the input array for determining scale factors, defaulting to 1000
    :param float krej: for ``norm`` == "zscale", the sigma clipping factor,
        defaulting to 2.5
    :param int max_iterations: for ``norm`` == "zscale", the maximum number
        of rejection iterations, defaulting to 5
    :param str stretch: [0, 1] → [0, 1] mapping mode::
        "asinh": hyperbolic arcsine stretch y = asinh(x/a)/asinh(1/a)
        "contrast": linear bias/contrast-based stretch
            y = (x - bias)*contrast + 0.5
        "exp": exponential stretch y = (a^x - 1)/(a - 1)
        "histeq": histogram equalization stretch
        "linear" (default): direct mapping
        "log": logarithmic stretch y = log(ax + 1)/log(a + 1)
        "power": power stretch y = x^a
        "sinh": hyperbolic sine stretch y = sinh(x/a)/sinh(1/a)
        "sqrt": square root stretch y = √x
        "square": power stretch y = x^2
    :param float a: non-linear stretch parameter::
        for ``stretch`` == "asinh", the point of transition from linear to
            logarithmic behavior, 0 < a <= 1, defaulting to 0.1
        for ``stretch`` == "exp", base of the exponent, a != 1, defaulting to
            1000
        for ``stretch`` == "log", base of the logarithm minus 1, a > 0,
            defaulting to 1000
        for ``stretch`` == "power", the power index, defaulting to 3
        for ``stretch`` == "sinh", a > 0, defaulting to 1/3
    :param float bias: for ``stretch`` == "contrast", the bias parameter,
        defaulting to 0.5
    :param float contrast: for ``stretch`` == "contrast", the contrast
        parameter, defaulting to 1
    :param str cmap: optional matplotlib colormap name, defaulting
        to grayscale; when a non-grayscale colormap is specified,
        the conversion is always done by matplotlib, regardless of the
        availability of Pillow; see https://matplotlib.org/users/colormaps.html
        for more info on matplotlib colormaps and
            [name for name in matplotlib.cm.cmap_d.keys()
             if not name.endswith('_r')]
        (``list(matplotlib.colormaps)`` in newer matplotlib releases)
        to list the available colormap names
    :param int dpi: target image resolution in dots per inch
    :param kwargs: optional format-specific keyword arguments passed to Pillow,
        e.g. "quality" for JPEG; see
        `https://pillow.readthedocs.io/en/stable/handbook/
        image-file-formats.html`_

    :return: a bytes object containing the image in the given format
    :rtype: bytes

    :raises ValueError: on missing or out-of-range clipping limits, or an
        unknown ``norm`` or ``stretch`` mode
    """
    data = asanyarray(data)

    # Normalize image data
    if norm == 'manual':
        if lo is None:
            raise ValueError(
                'Missing lower clipping boundary for norm="manual"')
        if hi is None:
            raise ValueError(
                'Missing upper clipping boundary for norm="manual"')
    elif norm == 'minmax':
        lo, hi = data.min(), data.max()
    elif norm == 'percentile':
        if lo is None:
            lo = 10
        elif not 0 <= lo <= 100:
            raise ValueError(
                'Lower clipping percentile must be in the [0,100] range')
        if hi is None:
            hi = 98
        elif not 0 <= hi <= 100:
            raise ValueError(
                'Upper clipping percentile must be in the [0,100] range')
        if hi < lo:
            raise ValueError(
                'Upper clipping percentile must be greater or equal to '
                'lower percentile')
        lo, hi = percentile(data, [lo, hi])
    elif norm == 'zscale':
        if lo is None:
            lo = 5
        if hi is None:
            hi = 0.5
        else:
            # ZScaleInterval expects the rejection limit as a fraction
            hi /= data.size
        lo, hi = apy_vis.ZScaleInterval(
            nsamples, zcontrast, hi, lo, krej, max_iterations).get_limits(data)
    else:
        raise ValueError('Unknown normalization mode "{}"'.format(norm))

    span = hi - lo
    if span:
        data = clip((data - lo)/span, 0, 1)
    else:
        # Collapsed limits (e.g. constant image): avoid 0/0 -> NaN, which
        # would later be cast to uint8 unpredictably; threshold instead.
        data = (data > lo).astype(float)

    # Stretch the data. ``data`` is a fresh array at this point, so the
    # in-place (out=data) stretches never modify the caller's input.
    if stretch == 'asinh':
        if a is None:
            a = 0.1
        apy_vis.AsinhStretch(a)(data, out=data)
    elif stretch == 'contrast':
        # bias=0.5, contrast=1 is the identity mapping; skip it
        if bias != 0.5 or contrast != 1:
            apy_vis.ContrastBiasStretch(contrast, bias)(data, out=data)
    elif stretch == 'exp':
        if a is None:
            a = 1000
        apy_vis.PowerDistStretch(a)(data, out=data)
    elif stretch == 'histeq':
        apy_vis.HistEqStretch(data)(data, out=data)
    elif stretch == 'linear':
        pass
    elif stretch == 'log':
        if a is None:
            a = 1000
        apy_vis.LogStretch(a)(data, out=data)
    elif stretch == 'power':
        if a is None:
            a = 3
        apy_vis.PowerStretch(a)(data, out=data)
    elif stretch == 'sinh':
        if a is None:
            a = 1/3
        apy_vis.SinhStretch(a)(data, out=data)
    elif stretch == 'sqrt':
        apy_vis.SqrtStretch()(data, out=data)
    elif stretch == 'square':
        apy_vis.SquaredStretch()(data, out=data)
    else:
        raise ValueError('Unknown stretch mode "{}"'.format(stretch))

    buf = BytesIO()
    try:
        # Choose the backend for making an image
        if cmap is None:
            cmap = 'gray'
        if cmap == 'gray':
            try:
                # noinspection PyPackageRequirements,PyPep8Naming
                from PIL import Image as pil_image
            except ImportError:
                pil_image = None
        else:
            pil_image = None

        if pil_image is not None:
            # Use Pillow for grayscale output if available; flip the image to
            # match the bottom-to-top FITS convention and convert from [0,1] to
            # unsigned byte
            pil_image.fromarray(
                (data[::-1]*255 + 0.5).astype(uint8),
            ).save(buf, fmt, dpi=(dpi, dpi), **kwargs)
        else:
            # Use matplotlib for non-grayscale colormaps or if PIL is not
            # available
            # noinspection PyPackageRequirements
            from matplotlib import image as mpl_image
            if fmt.lower() == 'png':
                # PNG images are saved upside down by matplotlib, regardless of
                # the origin parameter
                data = data[::-1]
            # noinspection PyTypeChecker
            mpl_image.imsave(
                buf, data, cmap=cmap, format=fmt, origin='lower', dpi=dpi)

        return buf.getvalue()
    finally:
        buf.close()
Пример #17
0
    def create_figure(self,
                      output_filename,
                      survey,
                      stretch='log',
                      vmin=1,
                      vmax=None,
                      min_percent=1,
                      max_percent=95,
                      cmap='gray',
                      contour_color='red',
                      data_col='FLUX'):
        """Plot the binned TPF flux with survey-image contours and save it.

        Parameters
        ----------
        output_filename : str
            Path the figure is saved to (``bbox_inches='tight'``, 300 dpi).

        survey : str
            Survey passed to ``surveyquery.getSVImg``; contours of the
            returned image are drawn on top of the flux image.

        stretch : {'linear', 'sqrt', 'power', 'log', 'asinh'}, optional
            Stretch applied after the cut levels (default: 'log').

        vmin : float, optional
            Minimum cut level (default: 1).  Only used when *vmax* is
            given; otherwise both levels come from ``self.cut_levels``.

        vmax : float, optional
            Maximum cut level.  If None (default), ``vmin`` and ``vmax`` are
            computed by ``self.cut_levels(min_percent, max_percent, data_col)``.

        min_percent, max_percent : float, optional
            Passed to ``self.cut_levels`` when *vmax* is None
            (defaults: 1 and 95).

        cmap : str, optional
            The matplotlib color map name.  The default is 'gray',
            can also be e.g. 'gist_heat'.

        contour_color : str, optional
            Colour of the survey contours (default: 'red').

        data_col : str, optional
            Passed to ``self.cut_levels`` and shown in verbose output
            (default: 'FLUX').

        Returns
        -------
        fig : matplotlib.figure.Figure
            The figure that was drawn and saved.

        Raises
        ------
        ValueError
            If *stretch* is not one of the supported names.
        """
        # Resolve the stretch first so an invalid name fails before a
        # figure is created (and left open).
        if stretch == 'linear':
            stretch_fn = visualization.LinearStretch()
        elif stretch == 'sqrt':
            stretch_fn = visualization.SqrtStretch()
        elif stretch == 'power':
            stretch_fn = visualization.PowerStretch(1.0)
        elif stretch == 'log':
            stretch_fn = visualization.LogStretch()
        elif stretch == 'asinh':
            stretch_fn = visualization.AsinhStretch(0.1)
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        # Get the flux data to visualize (binned TPF flux)
        flx = self.TPF.flux_binned()

        # calculate cut_levels
        if vmax is None:
            vmin, vmax = self.cut_levels(min_percent, max_percent, data_col)

        # The figure size is taken directly from the flux array shape.
        fig = plt.figure(figsize=list(flx.shape))
        # WCS axes (instead of plain axes) on this specific figure
        ax = fig.add_subplot(projection=self.TPF.wcs)
        ax.set_xlabel('RA')
        ax.set_ylabel('Dec')

        if self.verbose:
            print('{} vmin/vmax = {}/{} (median={})'.format(
                data_col, vmin, vmax, np.nanmedian(flx)))

        # Interval first (clipping to [0, 1]), then the stretch; the result
        # is scaled to 0-255 and shown without further normalization.
        transform = (stretch_fn +
                     visualization.ManualInterval(vmin=vmin, vmax=vmax))
        ax.imshow((255 * transform(flx)).astype(int),
                  aspect='auto',
                  origin='lower',
                  interpolation='nearest',
                  cmap=cmap,
                  norm=NoNorm())
        ax.set_xticks([])
        ax.set_yticks([])

        # Contours must not change the view of the flux image.
        current_ylims = ax.get_ylim()
        current_xlims = ax.get_xlim()

        pixels, header = surveyquery.getSVImg(self.TPF.position, survey)
        levels = np.linspace(np.min(pixels), np.percentile(pixels, 95), 10)
        ax.contour(pixels,
                   transform=ax.get_transform(WCS(header)),
                   levels=levels,
                   colors=contour_color)

        ax.set_xlim(current_xlims)
        ax.set_ylim(current_ylims)

        fig.canvas.draw()
        # Save this figure explicitly rather than whatever is "current".
        fig.savefig(output_filename, bbox_inches='tight', dpi=300)
        return fig
Пример #18
0
import keract

# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names, so use whichever name this install provides.
_DARKGRID_STYLE = ("seaborn-v0_8-darkgrid"
                   if "seaborn-v0_8-darkgrid" in plt.style.available else
                   "seaborn-darkgrid")
plt.style.use(_DARKGRID_STYLE)

DATA_CSV = '/full/path/to/csv'
IMG_FOLD = '/full/path/to/images'  # for me, ends with /Public/EUC_VIS/
DES_IMG_SIZE = 200
NUM_IMAGES = 50000  # can go up to nearly 100000, but a few entries have no images
BATCHSIZE = 200
EPOCHS = 20  # 20 for quick runs, 40 for longer ones; more epochs may be worth trying

# The two normalization objects below are callables, so they are reused
# as plain functions later in the data preprocessing.
norm = avis.MinMaxInterval()
stretch = avis.AsinhStretch(0.010)


def read_csv(filename: str,
             num_images: int,
             to_skip: "list[int] | None" = None) -> pd.DataFrame:
    """
    Read the catalogue csv and keep only the columns used downstream.

    The first 26 lines of the file are skipped (presumably a preamble
    before the header row). A boolean ``should_detect`` column is added:
    True when ``n_source_im > 0``, ``mag_eff > 1.6`` and
    ``n_pix_source > 20``.

    Parameters
    ----------
    filename : str
        Path (or file-like object) accepted by ``pandas.read_csv``.
    num_images : int
        Desired number of images. NOTE(review): not used by this code.
    to_skip : list of int, optional
        NOTE(review): not used by this code. Defaults to None rather than
        a shared mutable ``[]``.

    Returns
    -------
    pandas.DataFrame
        The trimmed frame including the ``should_detect`` column.
    """
    if to_skip is None:
        to_skip = []
    df = pd.read_csv(filename, skiprows=26)
    # .copy() so the column assignment below targets an independent frame
    # rather than a slice of the full one (avoids SettingWithCopyWarning).
    df = df[["ID", "n_sources", "n_source_im", "mag_eff", "n_pix_source"]].copy()
    df["should_detect"] = ((df["n_source_im"] > 0) & (df["mag_eff"] > 1.6) &
                           (df["n_pix_source"] > 20))
    return df
Пример #19
0
def make_fov_image(fov, pngfn=None, **kwargs):
    """Render the four CCD images of a field of view into one figure.

    Parameters
    ----------
    fov : dict
        Maps 'CCD1'..'CCD4' to dicts with 'im' (image array, averaged over
        the last axis if 3-D), 'x' and 'y' (2-D coordinate grids). Also
        reads 'coordsys' and, optionally, 'file' and 'objname'.
    pngfn : str, optional
        If given, the figure is saved to this path and then closed.
    **kwargs
        stretch : {'linear', 'histeq', 'asinh'} (default 'linear').
        interval : {'zscale', 'rms', 'fixed'} (default 'zscale').
        imrange : (lo, hi) display limits for 'fixed', or sigma multipliers
            for 'rms' (default (2.5, 5.0) for 'rms').
        contrast : ZScale contrast (default 0.25).
        nbin : divisor of the 'rms' subsampling step 1000 // nbin
            (default 1).
        cmap, interpolation, title : display options.

    Returns
    -------
    matplotlib.figure.Figure
        The figure (already closed when ``pngfn`` is given).

    Raises
    ------
    ValueError
        If ``interval`` is not one of the supported names.
    """
    import copy

    stretch = kwargs.get('stretch', 'linear')
    interval = kwargs.get('interval', 'zscale')
    imrange = kwargs.get('imrange')
    contrast = kwargs.get('contrast', 0.25)
    ccdplotorder = ['CCD2', 'CCD4', 'CCD1', 'CCD3']
    if interval == 'rms':
        try:
            losig, hisig = imrange
        except (TypeError, ValueError):
            # imrange missing (None) or not a pair: fall back to defaults
            losig, hisig = (2.5, 5.0)
    # Work on a copy so set_bad() does not modify the globally registered
    # colormap used by every other plot in the process.
    cmap = copy.copy(plt.get_cmap(kwargs.get('cmap', 'viridis')))
    cmap.set_bad('w', 1.0)
    w = 0.4575
    h = 0.455
    rc('text', usetex=False)
    fig = plt.figure(figsize=(6, 6.5))
    cax = fig.add_axes([0.1, 0.04, 0.8, 0.01])
    ims = [fov[ccd]['im'] for ccd in ccdplotorder]
    # One shared normalization for all CCDs, computed from the pooled pixels
    allpix = np.ma.array(ims).flatten()
    stretch = {
        'linear': vis.LinearStretch(),
        'histeq': vis.HistEqStretch(allpix),
        'asinh': vis.AsinhStretch(),
    }[stretch]
    if interval == 'zscale':
        iv = vis.ZScaleInterval(contrast=contrast)
        vmin, vmax = iv.get_limits(allpix)
    elif interval == 'rms':
        # `nbin` used to be an undefined name here (NameError on every
        # 'rms' call). It is now an optional kwarg, and the slice step is
        # kept >= 1 so large values cannot produce a zero step.
        nbin = kwargs.get('nbin', 1)
        nsample = max(1, 1000 // nbin)
        # astropy renamed sigma_clip's `iters` argument to `maxiters`.
        background = sigma_clip(allpix[::nsample], maxiters=3, sigma=2.2)
        m, s = background.mean(), background.std()
        vmin, vmax = m - losig * s, m + hisig * s
    elif interval == 'fixed':
        vmin, vmax = imrange
    else:
        raise ValueError('Unknown interval: {}'.format(interval))
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)
    for n, (im, ccd) in enumerate(zip(ims, ccdplotorder)):
        if im.ndim == 3:
            im = im.mean(axis=-1)
        x = fov[ccd]['x']
        y = fov[ccd]['y']
        # 2x2 grid: column i, row j
        i = n % 2
        j = n // 2
        pos = [0.0225 + i * w + i * 0.04, 0.05 + j * h + j * 0.005, w, h]
        ax = fig.add_axes(pos)
        _im = ax.imshow(im,
                        origin='lower',
                        extent=[x[0, 0], x[0, -1], y[0, 0], y[-1, 0]],
                        norm=norm,
                        cmap=cmap,
                        interpolation=kwargs.get('interpolation', 'nearest'))
        if fov['coordsys'] == 'sky':
            # x (RA-like) increases to the left on sky plots
            ax.set_xlim(x.max(), x.min())
        else:
            ax.set_xlim(x.min(), x.max())
        ax.set_ylim(y.min(), y.max())
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        if n == 0:
            # every panel shares `norm`, so one colorbar describes all
            cb = fig.colorbar(_im, cax, orientation='horizontal')
            cb.ax.tick_params(labelsize=9)
    tstr = fov.get('file', '') + ' ' + fov.get('objname', '')
    title = kwargs.get('title', tstr)
    title = title[-60:]
    fig.text(0.5, 0.99, title, ha='center', va='top', size=12)
    if pngfn is not None:
        plt.savefig(pngfn)
        plt.close(fig)
    return fig
Пример #20
0
from matplotlib.collections import PathCollection
from matplotlib.legend_handler import HandlerPathCollection
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
import pandas as pd
import os
from utils import get_sets

# The imports visible above bind only `plt` and `colors`, not `matplotlib`
# itself; pyplot exposes the very same rcParams object.
plt.rcParams.update({'font.size': 6})
bands = [
    'U', 'F378', 'F395', 'F410', 'F430', 'G', 'F515', 'R', 'F660', 'I', 'F861',
    'Z'
]

# Shared asinh stretch applied by plot_single_band when asinh=True
stretcher = vis.AsinhStretch()
# scaler = vis.ZScaleInterval()


def plot_single_band(filename, asinh=True, output_file=None):
    """Display a fixed cutout of a FITS image in grayscale.

    Reads the image data of ``filename`` with ``fits.getdata``, crops rows
    3500:4500 and columns 7000:8000, optionally applies the module-level
    ``stretcher`` with ``clip=False``, and draws the result on a frameless
    axes covering the whole figure.

    NOTE(review): this example is truncated in this file; the
    ``output_file`` branch below has no body here.
    """
    data = fits.getdata(filename)
    # Hard-coded crop region; alternative regions are kept commented out.
    # data = data[650:750, 450:550]
    data = data[3500:4500, 7000:8000]
    # data = data[8300:8600,9100:9400]
    if asinh:
        data = stretcher(data, clip=False)
    fig = plt.figure(frameon=False)
    ax = fig.add_axes([0, 0, 1, 1])  # axes fill the entire figure
    ax.axis('off')
    plt.imshow(data, cmap='gray')
    if output_file is None:
Пример #21
0
rslts_thresh = uvcombine.feather_plot(almafn,
                                      lores=loresfn,
                                      lowresfwhm=loresfwhm,
                                      hires_threshold=0.0005,
                                      lores_threshold=0.001)
pl.figure(2).clf()
# Limits are set on the norm itself: matplotlib >= 3.5 rejects vmin/vmax
# passed to imshow alongside a norm instance.
pl.imshow(combined.real + 0.01,
          origin='lower',
          interpolation='none',
          norm=pl.matplotlib.colors.LogNorm(vmin=0.001, vmax=0.1))
pl.axis((1620, 2300, 1842, 2750))

pl.figure(3).clf()


def asinhnorm(vmin=None, vmax=None):
    """Return a fresh asinh-stretched ImageNormalize.

    ``stretch`` must be passed by keyword: the first positional parameter of
    ImageNormalize is ``data``, so a positional AsinhStretch() is treated as
    data and the stretch silently stays linear (or errors out).
    """
    return visualization.ImageNormalize(vmin=vmin,
                                        vmax=vmax,
                                        stretch=visualization.AsinhStretch())


ax1 = pl.subplot(2, 3, 1)
im1 = ax1.imshow(combined.real + 0.01,
                 origin='lower',
                 interpolation='none',
                 norm=asinhnorm(vmin=0.001, vmax=0.1))
ax1.axis((1620, 2300, 1842, 2750))
pl.colorbar(mappable=im1)

ax2 = pl.subplot(2, 3, 2)
im2 = ax2.imshow(almafh.data,
                 origin='lower',
                 interpolation='none',
Пример #22
0
    def create_figure(self,
                      frameno=0,
                      binning=1,
                      dpi=None,
                      stretch='log',
                      vmin=1,
                      vmax=5000,
                      cmap='gray',
                      data_col='FLUX',
                      annotate=True,
                      time_format='ut',
                      show_flags=False,
                      label=None):
        """Returns a matplotlib Figure object that visualizes a frame.

        Parameters
        ----------
        frameno : int
            Image number in the target pixel file.

        binning : int
            Number of frames around `frameno` to co-add. (default: 1).

        dpi : float, optional [dots per inch]
            Resolution of the output in dots per Kepler CCD pixel.
            By default the dpi is chosen such that the image is 440px wide.

        stretch : str, optional
            One of 'linear', 'sqrt', 'power', 'log' or 'asinh'
            (default: 'log').

        vmin : float, optional
            Minimum cut level (default: 1).

        vmax : float, optional
            Maximum cut level (default: 5000).

        cmap : str, optional
            The matplotlib color map name.  The default is 'gray',
            can also be e.g. 'gist_heat'.

        data_col : str, optional
            Column passed to ``self.flux_binned`` (default: 'FLUX').

        annotate : boolean, optional
            Annotate the Figure with a timestamp and target name?
            (Default: `True`.)

        time_format : str, optional
            Passed to ``self.timestamp`` for the timestamp label
            (default: 'ut').

        show_flags : boolean, optional
            Show the quality flags?
            (Default: `False`.)

        label : str
            Label text to show in the bottom left corner of the movie.
            Defaults to ``self.objectname``.

        Returns
        -------
        fig : matplotlib.figure.Figure
            The rendered figure; its canvas has already been drawn.

        Raises
        ------
        ValueError
            If `stretch` is not one of the supported names.
        """
        # Get the flux data to visualize
        flx = self.flux_binned(frameno=frameno,
                               binning=binning,
                               data_col=data_col)

        # Determine the figsize and dpi; figsize is (width, height), i.e.
        # (columns, rows) of the flux array.
        shape = list(flx.shape)
        shape = [shape[1], shape[0]]
        if dpi is None:
            # Twitter timeline requires dimensions between 440x220 and 1024x512
            # so we make 440 the default
            dpi = 440 / float(shape[0])

        # libx264 (yuv420p) requires both pixel dimensions to be divisible
        # by 2. Previously only the width (shape[0]) was adjusted although
        # the comment promised the height; trim both.
        shape[0] -= ((shape[0] * dpi) % 2) / dpi
        shape[1] -= ((shape[1] * dpi) % 2) / dpi

        # Create the figure and display the flux image using imshow
        fig = pl.figure(figsize=shape, dpi=dpi)
        ax = fig.add_subplot(1, 1, 1)
        if self.verbose:
            print('{} vmin/vmax = {}/{} (median={})'.format(
                data_col, vmin, vmax, np.nanmedian(flx)))

        if stretch == 'linear':
            stretch_fn = visualization.LinearStretch()
        elif stretch == 'sqrt':
            stretch_fn = visualization.SqrtStretch()
        elif stretch == 'power':
            stretch_fn = visualization.PowerStretch(1.0)
        elif stretch == 'log':
            stretch_fn = visualization.LogStretch()
        elif stretch == 'asinh':
            stretch_fn = visualization.AsinhStretch(0.1)
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        # Interval first (to [0, 1] via vmin/vmax), then the stretch; the
        # result is scaled to 0..255 and shown with NoNorm.
        transform = (stretch_fn +
                     visualization.ManualInterval(vmin=vmin, vmax=vmax))
        flx_transform = 255 * transform(flx)
        # Make sure to remove all NaNs!
        flx_transform[~np.isfinite(flx_transform)] = 0
        ax.imshow(flx_transform.astype(int),
                  aspect='auto',
                  origin='lower',
                  interpolation='nearest',
                  cmap=cmap,
                  norm=NoNorm())
        if annotate:  # Annotate the frame with a timestamp and target name?
            fontsize = 3. * shape[0]
            margin = 0.03
            # Print target name in lower left corner
            if label is None:
                label = self.objectname
            txt = ax.text(margin,
                          margin,
                          label,
                          family="monospace",
                          fontsize=fontsize,
                          color='white',
                          transform=ax.transAxes)
            txt.set_path_effects([
                path_effects.Stroke(linewidth=fontsize / 6.,
                                    foreground='black'),
                path_effects.Normal()
            ])
            # Print a timestring in the lower right corner
            txt2 = ax.text(1 - margin,
                           margin,
                           self.timestamp(frameno, time_format=time_format),
                           family="monospace",
                           fontsize=fontsize,
                           color='white',
                           ha='right',
                           transform=ax.transAxes)
            txt2.set_path_effects([
                path_effects.Stroke(linewidth=fontsize / 6.,
                                    foreground='black'),
                path_effects.Normal()
            ])
            # Print quality flags in the upper left corner
            if show_flags:
                flags = self.quality_flags(frameno)
                if len(flags) > 0:
                    txt3 = ax.text(margin,
                                   1 - margin,
                                   '\n'.join(flags),
                                   family="monospace",
                                   fontsize=fontsize * 1.3,
                                   color='white',
                                   ha='left',
                                   va='top',
                                   transform=ax.transAxes,
                                   linespacing=1.5,
                                   backgroundcolor='red')
                    txt3.set_path_effects([
                        path_effects.Stroke(linewidth=fontsize / 6.,
                                            foreground='black'),
                        path_effects.Normal()
                    ])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
        fig.canvas.draw()
        return fig
Пример #23
0
def make_sed_plot(coordinate,
                  mgpsfile,
                  width=1 * u.arcmin,
                  surveys=None,
                  figure=None,
                  regname='GAL_031'):
    """Plot a grid of multi-survey cutouts around ``coordinate``.

    Images come from the MGPS file ``mgpsfile``, the MAGPIS ``surveys``
    (via ``getimg``) and Hi-GAL. Hi-GAL cutouts are downloaded once and then
    cached as FITS files under ``<paths.basepath>/<regname>/HiGalCutouts``.
    Each image carrying a FITS header is reprojected onto a grid centred on
    ``coordinate`` at that image's own pixel scale. Every panel gets an
    asinh stretch, a 1 arcmin scale bar and two tick marks pointing at the
    target.

    Parameters
    ----------
    coordinate : SkyCoord
        Target position.
    mgpsfile : str
        Path to the MGPS FITS image.
    width : Quantity
        Half-width of the displayed region (angular).
    surveys : iterable of str, optional
        MAGPIS survey names. Defaults to ``Magpis.list_surveys()``,
        evaluated at call time rather than at function definition.
    figure : matplotlib Figure, optional
        Figure to draw into (it is cleared first). A new 15x12 figure is
        created if omitted.
    regname : str
        Sub-directory of ``paths.basepath`` that holds the Hi-GAL cache.
    """
    if surveys is None:
        surveys = Magpis.list_surveys()

    mgps_fh = fits.open(mgpsfile)[0]
    frame = wcs.utils.wcs_to_celestial_frame(wcs.WCS(mgps_fh.header))

    coordname = "{0:06.3f}_{1:06.3f}".format(coordinate.galactic.l.deg,
                                             coordinate.galactic.b.deg)

    mgps_cutout = Cutout2D(mgps_fh.data,
                           coordinate.transform_to(frame.name),
                           size=width * 2,
                           wcs=wcs.WCS(mgps_fh.header))
    print(
        f"Retrieving MAGPIS data for {coordname} ({coordinate.to_string()} {coordinate.frame.name})"
    )
    # we're treating 'width' as a radius elsewhere, here it's a full width
    images = {
        survey: getimg(coordinate, image_size=width * 2.75, survey=survey)
        for survey in surveys
    }
    images = {x: y for x, y in images.items() if y is not None}
    images['mgps'] = [mgps_cutout]

    regdir = os.path.join(paths.basepath, regname)
    higaldir = os.path.join(regdir, 'HiGalCutouts')
    # makedirs creates regdir as well and tolerates existing directories,
    # avoiding the exists()/mkdir() race.
    os.makedirs(higaldir, exist_ok=True)
    if not any(
            os.path.exists(f"{higaldir}/{coordname}_{wavelength}.fits")
            for wavelength in map(int, HiGal.HIGAL_WAVELENGTHS.values())):
        print(
            f"Retrieving HiGal data for {coordname} ({coordinate.to_string()} {coordinate.frame.name})"
        )
        higal_ims = HiGal.get_images(coordinate, radius=width * 1.5)
        for hgim in higal_ims:
            images['HiGal{0}'.format(hgim[0].header['WAVELEN'])] = hgim
            hgim.writeto(
                f"{higaldir}/{coordname}_{hgim[0].header['WAVELEN']}.fits")
    else:
        print(
            f"Loading HiGal data from disk for {coordname} ({coordinate.to_string()} {coordinate.frame.name})"
        )
        for wavelength in map(int, HiGal.HIGAL_WAVELENGTHS.values()):
            hgfn = f"{higaldir}/{coordname}_{wavelength}.fits"
            if os.path.exists(hgfn):
                hgim = fits.open(hgfn)
                images['HiGal{0}'.format(hgim[0].header['WAVELEN'])] = hgim

    if 'gpsmsx2' in images:
        # redundant, save some space for a SED plot
        del images['gpsmsx2']
    if 'gps90' in images:
        # too low-res to be useful
        del images['gps90']

    if figure is None:
        figure = pl.figure(figsize=(15, 12))

    # coordinate stuff so images can be reprojected to same frame
    ww = mgps_cutout.wcs.celestial
    target_header = ww.to_header()
    # Drop the pole keywords (if present), presumably so they are
    # recomputed for each new reference point set below.
    target_header.pop('LONPOLE', None)
    target_header.pop('LATPOLE', None)
    mgps_pixscale = (wcs.utils.proj_plane_pixel_area(ww) * u.deg**2)**0.5
    # Was misspelled 'NAXES', which only added a meaningless card.
    target_header['NAXIS'] = 2
    target_header['NAXIS1'] = target_header['NAXIS2'] = (
        width / mgps_pixscale).decompose().value
    outframe = wcs.utils.wcs_to_celestial_frame(ww)
    crd_outframe = coordinate.transform_to(outframe)

    figure.clf()

    # panels ordered by the wavelength recorded in wlmap
    imagelist = sorted(images.items(), key=lambda x: wlmap[x[0]])

    for ii, (survey, img) in enumerate(imagelist):

        if hasattr(img[0], 'header'):
            # FITS HDU: reproject onto a grid centred on the target, keeping
            # the input image's own pixel scale.
            inwcs = wcs.WCS(img[0].header).celestial
            pixscale_in = (wcs.utils.proj_plane_pixel_area(inwcs) *
                           u.deg**2)**0.5

            target_header['CDELT1'] = -pixscale_in.value
            target_header['CDELT2'] = pixscale_in.value
            target_header['CRVAL1'] = crd_outframe.spherical.lon.deg
            target_header['CRVAL2'] = crd_outframe.spherical.lat.deg
            axsize = int((width * 2.5 / pixscale_in).decompose().value)
            target_header['NAXIS1'] = target_header['NAXIS2'] = axsize
            target_header['CRPIX1'] = target_header['NAXIS1'] / 2
            target_header['CRPIX2'] = target_header['NAXIS2'] / 2
            shape_out = [axsize, axsize]

            print(
                f"Reprojecting {survey} to scale {pixscale_in} with shape {shape_out} and center {crd_outframe.to_string()}"
            )

            outwcs = wcs.WCS(target_header)

            new_img, _ = reproject.reproject_interp((img[0].data, inwcs),
                                                    target_header,
                                                    shape_out=shape_out)
        else:
            # Cutout2D (the MGPS image): use it directly with its own WCS
            new_img = img[0].data
            outwcs = img[0].wcs
            pixscale_in = (wcs.utils.proj_plane_pixel_area(outwcs) *
                           u.deg**2)**0.5

        ax = figure.add_subplot(4, 5, ii + 1, projection=outwcs)
        ax.set_title("{0}: {1}".format(survey_titles[survey], wlmap[survey]))

        if not np.any(np.isfinite(new_img)):
            print(f"SKIPPING {survey}")
            continue

        norm = visualization.ImageNormalize(
            new_img,
            interval=visualization.PercentileInterval(99.95),
            stretch=visualization.AsinhStretch(),
        )

        ax.imshow(new_img, origin='lower', interpolation='none', norm=norm)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.xaxis.set_ticklabels('')
        ax.yaxis.set_ticklabels('')
        ax.coords[0].set_ticklabel_visible(False)
        ax.coords[1].set_ticklabel_visible(False)

        if 'GLON' in outwcs.wcs.ctype[0]:
            xpix, ypix = outwcs.wcs_world2pix(coordinate.galactic.l,
                                              coordinate.galactic.b, 0)
        else:
            xpix, ypix = outwcs.wcs_world2pix(coordinate.fk5.ra,
                                              coordinate.fk5.dec, 0)
        ax.set_xlim(xpix - (width / pixscale_in), xpix + (width / pixscale_in))
        ax.set_ylim(ypix - (width / pixscale_in), ypix + (width / pixscale_in))

        # scalebar = 1 arcmin (from 5" to 65" inside the lower-left corner)
        ax.plot([
            xpix - width / pixscale_in + 5 * u.arcsec / pixscale_in,
            xpix - width / pixscale_in + 65 * u.arcsec / pixscale_in
        ], [
            ypix - width / pixscale_in + 5 * u.arcsec / pixscale_in,
            ypix - width / pixscale_in + 5 * u.arcsec / pixscale_in
        ],
                linestyle='-',
                linewidth=1,
                color='w')
        # Two short tick marks (below and to the right) marking the target
        ax.plot(crd_outframe.spherical.lon.deg,
                crd_outframe.spherical.lat.deg,
                marker=((0, -10), (0, -4)),
                color='w',
                linestyle='none',
                markersize=20,
                markeredgewidth=0.5,
                transform=ax.get_transform('world'))
        ax.plot(crd_outframe.spherical.lon.deg,
                crd_outframe.spherical.lat.deg,
                marker=((4, 0), (10, 0)),
                color='w',
                linestyle='none',
                markersize=20,
                markeredgewidth=0.5,
                transform=ax.get_transform('world'))

    pl.tight_layout()
Пример #24
0
    def show(self, savefile=None, show_ps_stamp=False, pcvmin=0, pcvmax=100):
        """Plot the alert stamps, its light curve and candidate information.

        Parameters
        ----------
        savefile : str, optional
            If given, the figure is also saved to this path (dpi=250).
        show_ps_stamp : bool, optional
            Also download and show a Pan-STARRS (y/g/i) stamp. A failed
            download is reported and otherwise ignored.
        pcvmin, pcvmax: [0<float<100] -optional-
            percentile of the saturating (min and max) color for the stamps

        Returns
        -------
        matplotlib.figure.Figure
        """
        from astropy import visualization
        from astropy.time import Time
        import matplotlib.dates as mdates
        from matplotlib.colors import Normalize
        import matplotlib.pyplot as mpl

        # ----------- #
        # Global      #
        # ----------- #
        prop = dict(marker="o", mec="0.7", ms=8, ecolor="0.7", ls="None")

        # ----------- #
        #   Methods   #
        # ----------- #
        def show_fid_lc(ax):
            """Plot detections (magpsf +/- sigmapsf), one series per filter."""
            if len(self.get_history_photopoints()) == 0:
                return
            mag, magerr, jd, fid = np.asarray(
                [[d[k] for k in ["magpsf", "sigmapsf", "jd", "fid"]]
                 for d in self.get_history_photopoints()]).T
            for j, i in enumerate([1, 2, 3]):
                if i in fid:
                    flag_fid = fid == i
                    ax.errorbar([
                        Time(jd_, format="jd").datetime for jd_ in jd[flag_fid]
                    ],
                                mag[flag_fid],
                                yerr=magerr[flag_fid],
                                label="magpsf %s" % FILTER_CODE[j],
                                mfc=FILTER_COLORS[j],
                                **prop)

        def show_fid_uplim(ax):
            """Plot upper limits (diffmaglim) as lower-limit arrows per filter."""
            if len(self.get_history_upperlimits()) == 0:
                return
            upmag, jdup, fidup = np.asarray(
                [[d[k] for k in ["diffmaglim", "jd", "fid"]]
                 for d in self.get_history_upperlimits()]).T
            for j, i in enumerate([1, 2, 3]):
                if i in fidup:
                    flag_fid = fidup == i
                    ax.errorbar([
                        Time(jd_, format="jd").datetime
                        for jd_ in jdup[flag_fid]
                    ],
                                upmag[flag_fid],
                                yerr=0.2,
                                lolims=True,
                                color=FILTER_COLORS[j],
                                ls="None",
                                alpha=0.7,
                                label="_no_legend_")

        # ----------- #
        #   Axes      #
        # ----------- #
        fig = mpl.figure(figsize=[9, 5])

        ref, width, height = 0.1, 0.15, 0.25
        ypos, span = 0.65, 0.05
        # Stamps
        aximg = fig.add_axes([ref, ypos, width, height])
        axref = fig.add_axes([ref + (width + span), ypos, width, height])
        axdif = fig.add_axes([ref + (width + span) * 2, ypos, width, height])
        if show_ps_stamp:
            axps = fig.add_axes(
                [ref + (width + span * 1.5) * 3, ypos, width, height])
            axps.set_yticks([])
            axps.set_xticks([])
        # - LC
        axlc = fig.add_axes([ref, 0.1, (width + span) * 2 + width, 0.5])
        axlc.set_xlabel("Date", fontsize="large")
        axlc.set_ylabel("mag (magpsf)", fontsize="large")

        # ----------------- #
        #  Plotting Stamps  #
        # ----------------- #
        # Loop Over the stamps
        for ax_, source in zip([aximg, axref, axdif],
                               ['Science', 'Template', 'Difference']):
            origdata = self.get_stamp(source).data
            # `x == x` drops NaNs before taking percentiles
            vmin, vmax = np.percentile(origdata[origdata == origdata],
                                       [pcvmin, pcvmax])
            value_span = vmax - vmin
            if value_span == 0:
                # Constant stamp: avoid 0/0 (an all-NaN image, which would
                # leave the percentile call below with nothing to work on).
                value_span = 1.0
            data_ = visualization.AsinhStretch()(
                (origdata - vmin) / value_span)

            ax_.imshow(data_,
                       norm=Normalize(
                           *np.percentile(data_[data_ == data_], [0.5, 99.5])),
                       aspect="auto")
            ax_.set_yticks([])
            ax_.set_xticks([])
            ax_.set_title(source)

        if show_ps_stamp:
            try:
                img = self.download_ps_stamp(color=["y", "g", "i"])
                axps.imshow(np.asarray(img))
            except Exception:
                # best-effort: the stamp is optional, keep plotting the rest
                print("Pan-STARRS stamp failed.")
            axps.set_title("Pan-STARRS (y/g/i)")
        # ----------- #
        #  History    #
        # ----------- #
        show_fid_lc(axlc)
        axlc.invert_yaxis()  # magnitudes: brighter is up
        show_fid_uplim(axlc)
        axlc.legend(loc="best")

        # ----------- #
        #  Alert      #
        # ----------- #
        prop['marker'] = "D"
        prop['mec'] = FILTER_COLORS[self.alert["candidate"]["fid"] - 1]
        axlc.errorbar(Time(self.alert["candidate"]["jd"],
                           format="jd").datetime,
                      self.alert["candidate"]["magpsf"],
                      yerr=self.alert["candidate"]["sigmapsf"],
                      mfc=FILTER_COLORS[self.alert["candidate"]["fid"] - 1],
                      label="_no_legend_",
                      **prop)

        # add text
        info = []
        for k in ["rb", "fwhm", "nbad", "elong", "isdiffpos"]:
            try:
                info.append("%s : %.3f" % (k, self.alert["candidate"].get(k)))
            except (TypeError, ValueError):
                # non-numeric values (e.g. None or a string flag)
                info.append("%s : %s" % (k, self.alert["candidate"].get(k)))

        for kk in ["objectidps", "sgscore", "distpsnr", "srmag"]:
            for k in [k for k in self.alert["candidate"].keys() if kk in k]:
                info.append("%s : %s" % (k, self.alert["candidate"].get(k)))

        fig.text(0.68,
                 0.6,
                 " \n".join(info),
                 va="top",
                 fontsize="medium",
                 color="0.4")
        fig.text(0.005,
                 0.995,
                 "alert: ID: %s (RA: %.5f | Dec: %.5f | Filter: %s)" %
                 (self.alert.get("candid"), self.alert['candidate']['ra'],
                  self.alert['candidate']['dec'],
                  FILTER_CODE[self.alert['candidate']['fid'] - 1]),
                 fontsize="medium",
                 color="k",
                 va="top",
                 ha="left")

        locator = mdates.AutoDateLocator()
        formatter = mdates.ConciseDateFormatter(locator)
        axlc.xaxis.set_major_locator(locator)
        axlc.xaxis.set_major_formatter(formatter)

        if savefile is not None:
            fig.savefig(savefile, dpi=250)

        return fig
def rgbfig(
        figfilename="SgrB2N_RGB.pdf",
        lims=[([266.83404223, 266.83172659]), ([-28.373138, -28.3698755])],
        scalebarx=coordinates.SkyCoord(266.833545 * u.deg,
                                       -28.37283819 * u.deg),
        redfn=paths.Fpath('SGRB2N-2012-Q.DePree_K.recentered.fits'),
        greenfn=paths.
    Fpath('sgr_b2m.N.B3.allspw.continuum.r0.5.clean1000.image.tt0.pbcor.fits'),
        bluefn=paths.
    Fpath('sgr_b2m.N.B6.allspw.continuum.r0.5.clean1000.image.tt0.pbcor.fits'),
        redpercentile=99.99,
        greenpercentile=99.99,
        bluepercentile=99.99,
        stretch=visualization.AsinhStretch(),
):
    """Build and save an RGB composite of three images of the same field.

    The green and blue images are reprojected onto the celestial pixel grid
    of the red image. Each channel is clipped to its own percentile interval
    and passed through ``stretch``. The composite is drawn on WCS axes,
    zoomed to the world-coordinate box ``lims`` (``[ra pair], [dec pair]``),
    given a 0.05 pc scale bar starting at ``scalebarx``, and saved to
    ``paths.fpath(figfilename)``.
    """
    header = fits.getheader(redfn)
    celwcs = wcs.WCS(header).celestial

    redhdu = fits.open(redfn)
    greenhdu = fits.open(greenfn)
    bluehdu = fits.open(bluefn)

    red_plane = redhdu[0].data.squeeze()
    grid_shape = red_plane.shape

    def _onto_red_grid(hdu):
        # Reproject one HDU's primary image onto the red image's pixel grid.
        reprojected, _footprint = reproject.reproject_interp(
            (hdu[0].data, wcs.WCS(hdu[0].header).celestial),
            celwcs,
            shape_out=grid_shape)
        return reprojected

    channel_inputs = (
        (red_plane, redpercentile),
        (_onto_red_grid(greenhdu), greenpercentile),
        (_onto_red_grid(bluehdu), bluepercentile),
    )

    # Each channel: clip to its percentile interval (-> [0, 1]), then stretch.
    planes = [
        stretch(visualization.PercentileInterval(pct)(plane))
        for plane, pct in channel_inputs
    ]
    # (channel, y, x) -> (y, x, channel), as imshow expects for RGB data
    rgb = np.moveaxis(np.array(planes), 0, -1)

    norm = visualization.ImageNormalize(
        rgb, interval=visualization.MinMaxInterval(), stretch=stretch)

    fig1 = pl.figure(1)
    fig1.clf()
    ax = fig1.add_subplot(1, 1, 1, projection=celwcs)
    pl.imshow(rgb, origin='lower', interpolation='none', norm=norm)

    # convert the world-coordinate zoom box into pixel limits
    (x1, x2), (y1, y2) = celwcs.wcs_world2pix(lims[0], lims[1], 0)
    ax.axis((x1, x2, y1, y2))

    visualization_tools.make_scalebar(ax,
                                      left_side=scalebarx,
                                      length=1.213 * u.arcsec,
                                      label='0.05 pc')

    pl.savefig(paths.fpath(figfilename), bbox_inches='tight')
from astropy.wcs import utils as wcsutils
import pylab as pl
import pyspeckit
import paths
from astropy import modeling
from astropy import stats

cube = SpectralCube.read(
    '/Users/adam/work/w51/alma/FITS/longbaseline/velo_cutouts/w51e2e_csv0_j2-1_r0.5_medsub.fits'
)
# Channels between 16 and 87 km/s, with the spectral axis reversed
cs21cube = subcube = cube.spectral_slab(16 * u.km / u.s, 87 * u.km / u.s)[::-1]

norm = vis.ImageNormalize(
    subcube,
    interval=vis.ManualInterval(-0.002, 0.010),
    stretch=vis.AsinhStretch(),
)

pl.rcParams['font.size'] = 12

szinch = 18
fig = pl.figure(1, figsize=(szinch, szinch))
pl.pause(0.1)
# Re-apply the figure size up to 5 times until the figure reports the
# requested size (presumably the interactive backend may not honour it
# immediately). An explicit check replaces `assert`, which is stripped
# when Python runs with -O.
for ii in range(5):
    fig.set_size_inches(szinch, szinch)
    pl.pause(0.1)
    if np.all(fig.get_size_inches() == np.array([szinch, szinch])):
        break
Пример #27
0
 def set_normalization(self,
                       stretch=None,
                       interval=None,
                       stretchkwargs=None,
                       intervalkwargs=None,
                       perm_linear=None):
     """Resolve the stretch and interval used to normalize ``self.img``.

     Parameters
     ----------
     stretch : str or stretch object, optional
         Stretch name. When None, ``self.stretch`` is reused, or 'linear' if
         that is also None. A non-string value is stored as-is.
     interval : str or interval object, optional
         Interval name ('minmax', 'manual', 'percentile', 'asymetric' or
         'zscale'). When None, ``self.interval`` is reused, or 'zscale'. A
         non-string value is stored as-is.
     stretchkwargs, intervalkwargs : dict, optional
         Extra keyword arguments, merged through ``self.prepare_kwargs``
         with the per-name defaults and the previously stored kwargs.
     perm_linear : dict, optional
         If given, the named stretch is combined with a LinearStretch built
         from these kwargs via ``vis.CompositeStretch``. Without it only
         'linear' is accepted.

     While ``self.data`` is None, names are kept as strings and the passed
     kwargs are stored in ``self.stretch_kwargs`` / ``self.interval_kwargs``
     for later. When ``self.img`` exists, its norm is replaced with a clipped
     ``vis.ImageNormalize`` built from the resolved stretch and interval.

     Raises
     ------
     ValueError
         For an unknown stretch or interval name once data is available.
         NOTE(review): an unknown name may instead fail earlier, in the
         ``*_kws_defaults`` lookup.
     """
     # Fresh dicts per call: the previous `{}` defaults were a single shared
     # object that could be stored on the instance below, so mutating
     # self.stretch_kwargs / self.interval_kwargs leaked into later calls.
     if stretchkwargs is None:
         stretchkwargs = {}
     if intervalkwargs is None:
         intervalkwargs = {}
     if stretch is None:
         if self.stretch is None:
             stretch = 'linear'
         else:
             stretch = self.stretch
     if isinstance(stretch, str):
         print(stretch,
               ' '.join([f'{k}={v}' for k, v in stretchkwargs.items()]))
         if self.data is None:  #can not calculate objects yet
             self.stretch_kwargs = stretchkwargs
         else:
             kwargs = self.prepare_kwargs(
                 self.stretch_kws_defaults[stretch], self.stretch_kwargs,
                 stretchkwargs)
             if perm_linear is not None:
                 perm_linear_kwargs = self.prepare_kwargs(
                     self.stretch_kws_defaults['linear'], perm_linear)
                 print(
                     'linear', ' '.join([
                         f'{k}={v}' for k, v in perm_linear_kwargs.items()
                     ]))
                 if stretch == 'asinh':  # arg: a=0.1
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.AsinhStretch(**kwargs))
                 elif stretch == 'contrastbias':  # args: contrast, bias
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.ContrastBiasStretch(**kwargs))
                 elif stretch == 'histogram':
                     stretch = vis.CompositeStretch(
                         vis.HistEqStretch(self.data, **kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'log':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LogStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'powerdist':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.PowerDistStretch(**kwargs))
                 elif stretch == 'power':  # args: a
                     stretch = vis.CompositeStretch(
                         vis.PowerStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'sinh':  # args: a=0.33
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SinhStretch(**kwargs))
                 elif stretch == 'sqrt':
                     stretch = vis.CompositeStretch(
                         vis.SqrtStretch(),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'square':
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SquaredStretch())
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
             else:
                 if stretch == 'linear':  # args: slope=1, intercept=0
                     stretch = vis.LinearStretch(**kwargs)
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
     self.stretch = stretch
     if interval is None:
         if self.interval is None:
             interval = 'zscale'
         else:
             interval = self.interval
     if isinstance(interval, str):
         print(interval,
               ' '.join([f'{k}={v}' for k, v in intervalkwargs.items()]))
         kwargs = self.prepare_kwargs(self.interval_kws_defaults[interval],
                                      self.interval_kwargs, intervalkwargs)
         if self.data is None:
             self.interval_kwargs = intervalkwargs
         else:
             if interval == 'minmax':
                 interval = vis.MinMaxInterval()
             elif interval == 'manual':  # args: vmin, vmax
                 interval = vis.ManualInterval(**kwargs)
             elif interval == 'percentile':  # args: percentile, n_samples
                 interval = vis.PercentileInterval(**kwargs)
             elif interval == 'asymetric':  # args: lower_percentile, upper_percentile, n_samples
                 interval = vis.AsymmetricPercentileInterval(**kwargs)
             elif interval == 'zscale':  # args: nsamples=1000, contrast=0.25, max_reject=0.5, min_npixels=5, krej=2.5, max_iterations=5
                 interval = vis.ZScaleInterval(**kwargs)
             else:
                 raise ValueError('Unknown interval:' + interval)
     self.interval = interval
     if self.img is not None:
         self.img.set_norm(
             vis.ImageNormalize(self.data,
                                interval=self.interval,
                                stretch=self.stretch,
                                clip=True))
Пример #28
0
def make_analysis_forms(
        basepath="/orange/adamginsburg/web/secure/ALMA-IMF/October31Release/",
        base_form_url="https://docs.google.com/forms/d/e/1FAIpQLSczsBdB3Am4znOio2Ky5GZqAnRYDrYTD704gspNu7fAMm2-NQ/viewform?embedded=true",
        dontskip_noresid=False):
    """Render quicklook PNGs plus linked analysis-form HTML pages for images.

    Image files are globbed per (field, band, config, robust, selfcal)
    combination.  Note that the glob paths are relative to the current working
    directory, not to ``basepath``.  For each image a diagnostic panel is
    rendered into ``{basepath}/quicklooks`` and a page is written with
    ``make_quicklook_analysis_form``.  Each page links to its ``prev``/``next_``
    neighbour.  Finally ``make_index`` builds the index page.

    Parameters
    ----------
    basepath : str
        Release directory; outputs go to ``{basepath}/quicklooks``.  If it
        contains ``'October31'``, images are searched without an imtype
        prefix; otherwise they are searched under ``cleanest/`` and ``bsens/``.
    base_form_url : str
        URL of the form embedded in each generated page.
    dontskip_noresid : bool
        If True, images that have no residual are still rendered (image and
        mask panels only).  Otherwise they are skipped.

    Returns
    -------
    list of dict
        Metadata (field, band, selfcal, array, robust, finaliter, outname,
        suffix) for every page written with ``robust == 0``.
    """
    import glob
    from diagnostic_images import load_images, show as show_images
    from astropy import visualization
    import pylab as pl

    savepath = f'{basepath}/quicklooks'
    # Tolerate an already-existing directory without hiding real failures
    # (e.g. permissions), which the previous bare ``except: pass`` swallowed.
    os.makedirs(savepath, exist_ok=True)

    fields = ("G008.67 G337.92 W43-MM3 G328.25 G351.77 G012.80 G327.29 "
              "W43-MM1 G010.62 W51-IRS2 W43-MM2 G333.60 G338.93 W51-E "
              "G353.41").split()
    imtypes = (('', ) if 'October31' in basepath else ('cleanest/',
                                                       'bsens/'))

    # One glob per (field, band, config, robust, selfcal, imtype).  imtype is
    # kept as a separate group so that, e.g., cleanest/ matches are not
    # overwritten by bsens/ matches sharing the same 5-tuple key.
    groups = [
        ((field, band, config, robust, selfcal),
         glob.glob(
             f"{field}/B{band}/{imtype}{field}*_B{band}_*_{config}_robust{robust}*selfcal{selfcal}*.image.tt0*.fits"
         ))
        for field in fields
        for band in (3, 6)
        for config in ('12M', )
        for robust in (0, )
        for selfcal in ("", ) + tuple(range(0, 9))
        for imtype in imtypes
    ]

    # Groups with exactly one match are reported as bad.  Only groups with
    # more than one match are processed; empty groups are silently dropped.
    badfiledict = {}
    filelist = []
    for key, files in groups:
        if len(files) == 1:
            badfiledict.setdefault(key, []).extend(files)
        elif len(files) > 1:
            filelist.extend(key + (fn, ) for fn in files)
    print(f"Bad files: {badfiledict}")

    prev = 'index.html'
    flist = []

    for ii, (field, band, config, robust, selfcal, fn) in enumerate(filelist):

        basename, suffix = fn.split(".image.tt0")
        if 'diff' in suffix or 'bsens-cleanest' in suffix:
            continue
        outname = basename.split("/")[-1]
        this_page = outname + ".html"

        if prev == this_page:
            print(
                f"{ii}: {(field, band, config, robust, fn)} yielded the same prev "
                f"{prev} as last time, skipping.")
            continue

        # Link to the first later entry that renders to a different page.
        # Fall back to the index when there is none.  The fallback also
        # covers a single-entry filelist, where ``next_`` used to be left
        # unassigned.
        next_ = "index.html"
        for later in filelist[ii + 1:]:
            candidate = later[5].split(".image.tt0")[0].split("/")[-1] + ".html"
            if candidate != this_page:
                next_ = candidate
                break

        assert next_ != this_page

        try:
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore')
                print(f"{ii}: {(field, band, config, robust, fn, selfcal)}"
                      f" basename='{basename}', suffix='{suffix}'")
                imgs, _ = load_images(basename, suffix=suffix)
        except KeyError as ex:
            print(ex)
            raise
        except Exception as ex:
            print(f"EXCEPTION: {type(ex)}: {str(ex)}")
            raise

        norm = visualization.ImageNormalize(
            stretch=visualization.AsinhStretch(),
            interval=visualization.PercentileInterval(99.95))
        # Set the scaling from one of these images.  Calling norm() modifies
        # it in place (according to the astropy docs).  ``x == x`` drops NaN
        # pixels before scaling.
        if 'residual' in imgs:
            norm(imgs['residual'][imgs['residual'] == imgs['residual']])
            imnames_toplot = ('mask', 'model', 'image', 'residual')
        elif 'image' in imgs and dontskip_noresid:
            imnames_toplot = (
                'image',
                'mask',
            )
            norm(imgs['image'][imgs['image'] == imgs['image']])
        else:
            print(
                f"Skipped {fn} because no image OR residual was found.  imgs.keys={imgs.keys()}"
            )
            continue

        pl.close(1)
        pl.figure(1, figsize=(14, 6))
        show_images(imgs, norm=norm, imnames_toplot=imnames_toplot)

        pl.savefig(f"{savepath}/{outname}.png", dpi=150, bbox_inches='tight')

        metadata = {
            'field': field,
            'band': band,
            'selfcal': selfcal,
            'array': config,
            'robust': robust,
            'finaliter': 'finaliter' in fn,
        }
        make_quicklook_analysis_form(filename=outname,
                                     metadata=metadata,
                                     savepath=savepath,
                                     prev=prev,
                                     next_=next_,
                                     base_form_url=base_form_url)
        metadata['outname'] = outname
        metadata['suffix'] = suffix
        if robust == 0:
            # only keep robust=0 for simplicity
            flist.append(metadata)
        prev = this_page

    make_index(savepath, flist)

    return flist
Пример #29
0
def timegen(in_path, out_path, m, n, cell, stretch, full_hd):
    """Generate a timelapse from a directory of FITS frames and save it.

    Each frame goes through these steps:
      * flip it vertically;
      * debayer it with an RGGB pattern;
      * crop it to one cell of an ``m`` x ``n`` grid;
      * pass it through the selected stretch transform;
      * scale it by 255 and save it with ``save_image`` into ``temp_dir``.

    The saved frames are then assembled by ``generate_timelapse_from_images``,
    and the temporary directory is cleared.

    Parameters
    ----------
    in_path : str
        Directory containing the input FITS files (``*.fits`` or
        ``*.fits.fz``).  Files whose name contains ``background.fits`` or
        ``pointing00.fits`` are ignored.
    out_path : str
        Where the output timelapse is written; it is passed on to
        ``generate_timelapse_from_images``.
    m : int
        Number of rows to split each image into.
    n : int
        Number of columns to split each image into.
    cell : sequence of int
        ``(row, column)`` index of the grid cell to keep, e.g. ``(0, 1)``.
    stretch : str
        Name of the stretch to apply.  Recognised names are
        ``TG_LOG_1_PERCENTILE_99``, ``TG_SQRT_PERCENTILE_99``,
        ``TG_LOG_PERCENTILE_99``, ``TG_ASINH_1_PERCENTILE_99``,
        ``TG_ASINH_PERCENTILE_99``, ``TG_SQUARE_PERCENTILE_99`` and
        ``TG_SINH_1_PERCENTILE_99``.  Any other value falls back to a
        default-parameter sinh stretch.  Every stretch is combined with
        ``PercentileInterval(99)``.
    full_hd : bool
        Force the video to be 1920 x 1080 pixels (forwarded as ``hd_flag``).

    Returns
    -------
    bool
        Always ``True``.  Per-file read/save errors and cleanup errors are
        logged rather than raised.
    """
    # Step 1: collect FITS files, skipping background and pointing00 frames.
    excluded_markers = ('background.fits', 'pointing00.fits')
    fits_files = [
        fname for fname in get_file_list(in_path, ['fits', 'fz'])
        if not any(marker in fname for marker in excluded_markers)
    ]

    # Step 2: choose the transform.  Previously the TG_LOG_1_PERCENTILE_99
    # transform was assigned first and then unconditionally overwritten by
    # the if/elif/else chain, so that name silently produced a sinh stretch.
    stretch_factories = {
        'TG_LOG_1_PERCENTILE_99': lambda: v.LogStretch(1),
        'TG_SQRT_PERCENTILE_99': v.SqrtStretch,
        'TG_LOG_PERCENTILE_99': v.LogStretch,
        'TG_ASINH_1_PERCENTILE_99': lambda: v.AsinhStretch(1),
        'TG_ASINH_PERCENTILE_99': v.AsinhStretch,
        'TG_SQUARE_PERCENTILE_99': v.SquaredStretch,
        'TG_SINH_1_PERCENTILE_99': lambda: v.SinhStretch(1),
    }
    # Unknown names keep the historical fallback: default sinh stretch.
    make_stretch = stretch_factories.get(stretch, v.SinhStretch)
    transform = make_stretch() + v.PercentileInterval(99)

    # Step 3: process every frame into the temporary directory.
    for fits_path in tqdm.tqdm(fits_files):
        try:
            fits_data = fits.getdata(fits_path)
        except Exception as e:
            # If the current FITS file can't be opened, log and skip it.
            logging.error(str(e))
            continue
        flipped_data = np.flipud(fits_data)
        rgb_data = debayer_image_array(flipped_data, pattern='RGGB')
        interested_data = get_sub_image(rgb_data, m, n, cell[0], cell[1])
        interested_data = 255 * transform(interested_data)
        rgb_data = interested_data.astype(np.uint8)
        # NOTE(review): bgr_data is computed but never used.  The float array
        # ``interested_data`` is what gets saved.  Confirm which of the two
        # save_image expects before removing either.
        bgr_data = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2BGR)
        try:
            save_image(interested_data,
                       os.path.split(fits_path)[-1].split('.')[0],
                       path=temp_dir)
        except Exception as e:
            logging.error(str(e))

    # TODO: validate out_path and create it if it doesn't exist (not
    # implemented).

    # Step 5: create the timelapse from the temporary files.
    # NOTE(review): frames are saved to ``temp_dir`` but read from the
    # literal 'temp_timelapse'; presumably these are the same directory.
    # Confirm.
    generate_timelapse_from_images('temp_timelapse', out_path, hd_flag=full_hd)

    # Delete temporary files (best effort: failures are logged, not raised).
    try:
        clear_dir(temp_dir)
    except Exception as e:
        print('Clearing TEMP Files failed. See log for more details')
        logging.error(str(e))
    return True