예제 #1
0
    def get_norm(self, stretch_text, scale_text):
        """Build an ``ImageNormalize`` for the current cube slice.

        Parameters
        ----------
        stretch_text : str
            "Linear" or "Log"; any other value selects a square-root stretch.
        scale_text : str
            "MinMax" selects ``MinMaxInterval``; any other value selects
            ``ZScaleInterval``.

        Returns
        -------
        ImageNormalize
            Normalization whose vmin/vmax come from the chosen interval applied
            to ``self.cubeObj.data_cube[self.cubeObj.currSlice]``.
        """
        # The result is unused. The lookup is kept so that an axes without
        # any image still raises IndexError here.
        _displayed = self.ax.get_images()[0]

        interval = MinMaxInterval() if scale_text == "MinMax" else ZScaleInterval()

        stretch_classes = {"Linear": LinearStretch, "Log": LogStretch}
        stretch = stretch_classes.get(stretch_text, SqrtStretch)()

        current_slice = self.cubeObj.data_cube[self.cubeObj.currSlice]
        lower, upper = interval.get_limits(current_slice)
        return ImageNormalize(vmin=lower, vmax=upper, stretch=stretch)
def build_image(id_,
                set_,
                bands=('EUC_VIS', 'EUC_H', 'EUC_J', 'EUC_Y'),
                img_size=200,
                scale=100,
                clip=True):
    """Assemble a multi-band image cube for one object.

    The FITS file of each band is located with ``get_image_filename_from_id``.
    The primary HDU of the *first* band is the reference frame: every band
    other than ``'EUC_VIS'`` is reprojected onto that header with
    ``reproject_interp``. NaNs are set to 0. Each band is then either clipped
    to ``[-0.7 * vmin, vmax]`` (0.25/99.75 percentile limits), divided by
    ``vmax`` and passed through ``MinMaxInterval() + LogStretch()``
    (``clip=True``), or passed through ``LogStretch() + MinMaxInterval()``
    (``clip=False``).

    Parameters
    ----------
    id_, set_
        Forwarded to ``get_image_filename_from_id`` together with the band.
    bands : sequence of str
        Band names in output-channel order. The first band supplies the
        reference header, so it should normally be ``'EUC_VIS'``.
    img_size : int
        Spatial size of the output. NOTE(review): each processed band is
        assumed to be ``(img_size, img_size)``; this is not checked here.
    scale : int
        Unused; kept for backward compatibility.
    clip : bool
        Select the clipped (True) or plain (False) preprocessing.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(img_size, img_size, len(bands))``.

    Raises
    ------
    FileNotFoundError
        If a band's file is missing. Files already opened are closed first.
    """
    # The transforms are stateless, so build them once instead of per band.
    if clip:
        interval = AsymmetricPercentileInterval(0.25, 99.75, n_samples=100000)
        stretch = MinMaxInterval() + LogStretch()
    else:
        stretch = LogStretch() + MinMaxInterval()

    data = np.empty((img_size, img_size, len(bands)))
    tables = []
    try:
        for i, band in enumerate(bands):
            fname = get_image_filename_from_id(id_, band, set_)
            tables.append(fits.open(fname))
            if band != 'EUC_VIS':
                band_data, _footprint = reproject_interp(
                    tables[i][0], tables[0][0].header)
            else:
                # Read this band's own HDU instead of assuming VIS is bands[0].
                band_data = tables[i][0].data
            band_data[np.isnan(band_data)] = 0.
            if clip:
                vmin, vmax = interval.get_limits(band_data)
                data[:, :, i] = stretch(
                    np.clip(band_data, -vmin * 0.7, vmax) / vmax)
            else:
                data[:, :, i] = stretch(band_data)
    finally:
        # Close every opened file, even when a later band fails.
        for table in tables:
            table.close()
    return data.astype(np.float32)
예제 #3
0
def convert_to_valid_color(
    image_color: np.ndarray,
    clip: bool = False,
    lower_clip: float = 0.0,
    upper_clip: float = 1.0,
    normalize: bool = False,
    scaling: Optional[str] = None,
    simple_norm: bool = False,
) -> np.ndarray:
    """Map one colour channel onto the 0-1 range used for RGB images.

    Parameters
    ----------
    image_color : numpy.ndarray
        Channel data.
    clip : bool
        Clip the data to ``[lower_clip, upper_clip]`` before scaling.
    lower_clip, upper_clip : float
        Clip bounds. When ``normalize`` is True they are also the limits of a
        ``ManualInterval``.
    normalize : bool
        Use ``ManualInterval(lower_clip, upper_clip)`` instead of
        ``MinMaxInterval``.
    scaling : str or None
        ``"sqrt"`` applies a ``SqrtStretch`` after the interval. Anything else
        applies a default ``ImageNormalize``.
    simple_norm : bool
        Shortcut: min-max scale, then a default ``ImageNormalize``. All other
        options are ignored.

    Returns
    -------
    numpy.ndarray
        The rescaled channel.
    """
    if simple_norm:
        return ImageNormalize()(MinMaxInterval()(image_color))

    channel = np.clip(image_color, lower_clip, upper_clip) if clip else image_color

    interval = (
        ManualInterval(lower_clip, upper_clip) if normalize else MinMaxInterval()
    )
    rescaled = interval(channel)

    if scaling == "sqrt":
        return SqrtStretch()(rescaled)
    return ImageNormalize()(rescaled)
def asinh_plot_VLASS_mJy(wcs_celestial, image_data, showgrid=False):
    """Plot an image in mJy with a min-max interval and asinh stretch.

    The global Matplotlib font size is set to 18 (``rcParams`` is modified).
    The figure is titled and saved under ``pulsarname``, and the outline of a
    ``Cutout2D`` at ``position``/``boxsize`` is drawn over the image before
    the figure is shown.

    NOTE(review): ``pulsarname``, ``position`` and ``boxsize`` are not
    parameters; they are read from module-level globals. Confirm they are
    defined wherever this is called.

    Parameters
    ----------
    wcs_celestial : astropy.wcs.WCS
        Celestial WCS used for the axes projection and the cutout.
    image_data : 2D array
        Image to display.
    showgrid : bool
        Draw a white coordinate grid.
    """
    colormap = plt.get_cmap('viridis')
    plt.rcParams.update({'font.size': 18})

    figure = plt.figure(figsize=(11, 8), dpi=75)
    axes = plt.subplot(projection=wcs_celestial)
    image, _norm = imshow_norm(image_data,
                               axes,
                               origin='lower',
                               interval=MinMaxInterval(),
                               stretch=AsinhStretch(),
                               vmin=-1e-5,
                               cmap=colormap)

    colorbar = figure.colorbar(image, cmap=colormap)
    axes.set_xlabel('RA J2000')
    axes.set_ylabel('Dec J2000')
    axes.set_title(pulsarname)
    colorbar.set_label('mJy')
    if showgrid:
        axes.grid(color='white', ls='solid')

    # Outline the cutout region on top of the full image.
    Cutout2D(image_data, position, boxsize,
             wcs=wcs_celestial).plot_on_original(color='white')

    plt.savefig(pulsarname)
    plt.show()
예제 #5
0
 def interval_select(self):
     """Slot for the interval radio buttons; acts only when the sender is checked.

     "MinMax" and "ZScale" store the interval on ``self.interval``, check
     only ``rbtn6`` or ``rbtn9`` respectively among ``rbtn5``..``rbtn9``, and
     refresh. "Manual", "Percentile" and "AsymmetricPercentile" delegate to
     ``main.manual_int``, ``main.percent_int`` and ``main.asym_int``, then
     refresh. Any other label does nothing.

     NOTE(review): ``main`` is a module-level global, not an attribute.
     """
     button = self.sender()
     if not button.isChecked():
         return

     label = button.text()

     if label == 'MinMax' or label == 'ZScale':
         if label == 'MinMax':
             self.interval = MinMaxInterval()
             selected = 'rbtn6'
         else:
             self.interval = ZScaleInterval()
             selected = 'rbtn9'
         for attr in ('rbtn5', 'rbtn6', 'rbtn7', 'rbtn8', 'rbtn9'):
             getattr(self, attr).setChecked(attr == selected)
         main.refresh_norm()
         return

     delegated = {
         'Manual': 'manual_int',
         'Percentile': 'percent_int',
         'AsymmetricPercentile': 'asym_int',
     }
     handler_name = delegated.get(label)
     if handler_name is not None:
         getattr(main, handler_name)()
         main.refresh_norm()
예제 #6
0
def contour_plot(ax, contour_file, contour_type, image_beam):
    """Plot contours from a FITS file over the main image on ``ax``.

    The contours can come from the same image or from a different one.

    Parameters
    ----------
    ax
        WCS-aware axes holding the main image (``ax.get_transform`` is used
        to map the contour WCS).
    contour_file
        Passed to ``import_contours``. Element 0 of its result is the data
        and element 1 the header.
    contour_type : str
        "automatic": 6 levels from the data min to max, spaced with the
        exponent ``contour_bias``; only positive levels are kept.
        "sigma": ``args.contour_levels`` times an rms of 0.0004 Jy/beam,
        converted to MJy/sr with ``image_beam``.
        "manual": levels taken directly from ``args.contour_levels``.
        Any other value draws nothing.
    image_beam
        Beam used for the Jy/beam -> MJy/sr conversion ("sigma" only).

    Returns
    -------
    The contour set returned by ``ax.contour``, or None when
    ``contour_type`` matches none of the options above.

    NOTE(review): relies on the module-level globals ``contour_bias`` and
    ``args``. Confirm they are defined wherever this is used.
    """
    cblock = import_contours(contour_file)
    cdata = cblock[0]
    cheader = cblock[1]
    cwcs = WCS(cheader)

    contour_type = str(contour_type)
    ax.set_autoscale_on(False)

    norm = ImageNormalize(cdata,
                          interval=MinMaxInterval(),
                          stretch=SqrtStretch())

    def level_func(lb, ub, n, spacing=1.1):
        """Return ``n`` levels from ``lb`` to ``ub``.

        A ``spacing`` > 1 bunches the levels towards ``lb``.
        """
        span = ub - lb
        dx = 1.0 / (n - 1)
        return [lb + (i * dx)**spacing * span for i in range(n)]

    CS = None
    if contour_type == "automatic":
        spacing = contour_bias
        n = 6  # number of contours
        ub = np.max(cdata)
        lb = np.min(cdata)
        levels = np.asarray(level_func(lb, ub, n, spacing=float(spacing)))
        levels = levels[levels > 0.]
        print('Generated Levels for contour are: ', levels, 'MJy/sr')
        CS = ax.contour(cdata,
                        levels=levels,
                        colors='black',
                        transform=ax.get_transform(cwcs.celestial),
                        alpha=1.0,
                        linewidths=1)
    elif contour_type == "sigma":
        contour_levels = list(map(float, args.contour_levels))
        sigma_levels = np.array(contour_levels)
        rms = (0.0004 * u.Jy / u.beam).to(
            u.MJy / u.sr, equivalencies=u.beam_angular_area(image_beam))
        levels = rms * sigma_levels
        print('Sigma Levels for contour are: ', sigma_levels, 'sigma')
        CS = ax.contour(cdata,
                        levels=levels,
                        norm=norm,
                        transform=ax.get_transform(cwcs.celestial),
                        colors='white',
                        alpha=0.5)
    elif contour_type == "manual":
        levels = list(map(float, args.contour_levels))
        print('Contour Levels are: ', levels, 'MJy/sr')
        CS = ax.contour(cdata,
                        levels=levels,
                        norm=norm,
                        transform=ax.get_transform(cwcs.celestial),
                        colors='white',
                        alpha=0.5,
                        linewidths=1.0)
    return CS
예제 #7
0
def plot_box(box, title=None, path=None, format=None, scale="log", interval="pts", cmap="viridis"):
    """Display or save a 2D array with a log or sqrt stretch.

    Parameters
    ----------
    box : 2D array
        Data to plot.
    title : str, optional
        Plot title.
    path : str, optional
        If None the figure is shown; otherwise it is saved to ``path``.
    format : str, optional
        File format passed to ``plt.savefig``.
    scale : {"log", "sqrt"}
        Stretch applied to the data.
    interval : {"pts", "zscale", "minmax"} or (vmin, vmax)
        How the display limits are chosen:
        "pts": vmin = max(nanmin, 0), vmax = midpoint of vmin and nanmax.
        "zscale" / "minmax": astropy ``ZScaleInterval`` / ``MinMaxInterval``.
        A 2-element tuple or list gives the limits explicitly.
    cmap : str
        Matplotlib colour map name.

    Returns
    -------
    tuple
        The ``(vmin, vmax)`` limits used.

    Raises
    ------
    ValueError
        If ``scale`` or ``interval`` is not a recognised option.
    """
    # Other new colormaps: plasma, magma, inferno

    if scale == "log":
        stretch = LogStretch()
    elif scale == "sqrt":
        stretch = SqrtStretch()
    else:
        raise ValueError("Invalid option for 'scale'")

    if isinstance(interval, (tuple, list)):
        vmin, vmax = interval[0], interval[1]
    elif interval == "zscale":
        vmin, vmax = ZScaleInterval().get_limits(box)
    elif interval == "pts":
        # Clamp the minimum at 0 and use the midpoint towards the maximum as
        # the upper limit.
        vmin = max(np.nanmin(box), 0.)
        vmax = 0.5 * (np.nanmax(box) + vmin)
    elif interval == "minmax":
        vmin, vmax = MinMaxInterval().get_limits(box)
    else:
        raise ValueError("Invalid option for 'interval'")

    # The limits go into the norm itself: Matplotlib rejects a norm instance
    # passed together with vmin/vmax.
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)

    plt.figure(figsize=(7, 7))
    plt.imshow(box, origin="lower", interpolation="nearest", norm=norm, cmap=cmap)
    plt.xlim(0, box.shape[1] - 1)
    plt.ylim(0, box.shape[0] - 1)

    if title is not None:
        plt.title(title)

    if path is None:
        plt.show()
    else:
        plt.savefig(path, format=format)

    plt.close()

    return vmin, vmax
def preprocess_band(image, clip=True):
    """Preprocess the data of a single band.

    NaNs in ``image`` are replaced by 0 *in place*, so the caller's array is
    modified.

    With ``clip=True`` the 0.25/99.75 percentile limits ``(vmin, vmax)`` are
    computed, the data are clipped to ``[-0.7 * vmin, vmax]``, divided by
    ``vmax`` and passed through ``MinMaxInterval() + LogStretch()``.
    With ``clip=False`` the data go through ``LogStretch() + MinMaxInterval()``.

    Parameters
    ----------
    image : numpy.ndarray
        2D array containing a single band's data (modified in place).
    clip : bool, optional
        Whether to do the clip preprocessing. Defaults to True.

    Returns
    -------
    numpy.ndarray
        The preprocessed 2D array.
    """
    image[np.isnan(image)] = 0.

    if not clip:
        return (LogStretch() + MinMaxInterval())(image)

    percentile = AsymmetricPercentileInterval(0.25, 99.75, n_samples=100000)
    low, high = percentile.get_limits(image)
    clipped = np.clip(image, -low * 0.7, high)
    return (MinMaxInterval() + LogStretch())(clipped / high)
예제 #9
0
def logarithmic_scale(Images, Images_load_length):
    """Apply ``LogStretch() + MinMaxInterval()`` to each image in place.

    Only the first ``Images_load_length`` slices along axis 0 are
    transformed. Every 1000th index is printed as progress.

    Parameters
    ----------
    Images : numpy.ndarray
        3D stack indexed as ``Images[k, :, :]``; modified in place.
    Images_load_length : int
        Number of leading images to transform.

    Returns
    -------
    numpy.ndarray
        The same ``Images`` array.
    """
    log_minmax = LogStretch() + MinMaxInterval()

    print('Logarithmic strech :')

    for idx in range(Images_load_length):
        if not idx % 1000:
            print(idx)
        Images[idx, :, :] = log_minmax(Images[idx, :, :])

    print('Logarithmic data strech done')

    return Images
예제 #10
0
def load_images(path, extensions, load_length, IR=False, test=False, start=0):
    """Load ``load_length`` images into a float32 array of shape (N, 200, 200).

    Parameters
    ----------
    path : str
        Directory prefix holding the ``EUC_VIS/``, ``EUC_Y/``, ``EUC_J/`` and
        ``EUC_H/`` sub-directories. It is concatenated directly, so it should
        end with a path separator.
    extensions : str
        Glob pattern appended to each sub-directory (e.g. ``'*.fits'``).
    load_length : int
        Number of images to load.
    IR : bool
        If True, each output image is ``0.15*Y + 0.35*J + 0.5*H``, where each
        band is min-max scaled and resized to 200x200. Otherwise the VIS file
        data are stored as-is (NOTE(review): assumed to already be 200x200).
    test : bool
        When False, ``start`` is forced to 0.
    start : int
        Index (in sorted file order) of the first file to load.

    Returns
    -------
    numpy.ndarray
        ``Images[k]`` holds the data of file ``k + start`` for
        ``k in range(load_length)``.
    """
    print('Data loading :')

    if not test:
        start = 0

    Images = np.empty((load_length, 200, 200), dtype=np.float32)

    if IR:
        transform = MinMaxInterval()

        def _load_band(fname):
            """Min-max scale one FITS image and resize it to 200x200."""
            band = transform(np.array(fits.getdata(fname)))
            return np.array(cv2.resize(band, dsize=(200, 200)))

        filelist_y = sorted(glob.glob(path + 'EUC_Y/' + extensions))
        filelist_j = sorted(glob.glob(path + 'EUC_J/' + extensions))
        filelist_h = sorted(glob.glob(path + 'EUC_H/' + extensions))

        # The output slot k runs from 0; only the file index is offset by
        # start. Offsetting both would skip files and leave slots empty.
        for k in range(load_length):
            if k % 300 == 0:
                print(k)
            idx = k + start
            Images[k, :, :] = (0.15 * _load_band(filelist_y[idx])
                               + 0.35 * _load_band(filelist_j[idx])
                               + 0.5 * _load_band(filelist_h[idx]))

    else:
        filelist = sorted(glob.glob(path + 'EUC_VIS/' + extensions))

        for k in range(load_length):
            if k % 1000 == 0:
                print(k)
            Images[k, :, :] = np.array(fits.getdata(filelist[k + start]))

    print('Done loading data')

    return Images
예제 #11
0
def plot_image(image,
               cmap='gray',
               title='Image',
               input_ratio=None,
               interval_type='zscale',
               stretch='linear',
               percentile=99,
               xlim_minus=0,
               xlim_plus=1000,
               ylim_minus=0,
               ylim_plus=1000):
    """Show a cropped 2D image with a colour bar labelled 'ADU'.

    Parameters
    ----------
    image : 2D array
        Image indexed ``[row, column]`` (i.e. ``[y, x]``).
    cmap : str
        Matplotlib colour map.
    title : str
        Axes title.
    input_ratio
        Unused; kept for backward compatibility.
    interval_type : {'zscale', 'percentile', 'minmax', 'simple_norm'}
        How the display limits are chosen (from the full, uncropped image).
    stretch : str
        Stretch name for ``simple_norm``. Used only when
        ``interval_type == 'simple_norm'``.
    percentile : float
        Percentile for ``interval_type == 'percentile'``.
    xlim_minus, xlim_plus : int
        Column (x) range of the displayed crop.
    ylim_minus, ylim_plus : int
        Row (y) range of the displayed crop.

    Raises
    ------
    ValueError
        If ``interval_type`` is not recognised.
    """
    from astropy.visualization import (ZScaleInterval, PercentileInterval,
                                       MinMaxInterval, ImageNormalize,
                                       simple_norm)

    # Choose the norm before creating the figure, so an invalid option
    # does not leave an empty figure behind.
    if interval_type == 'zscale':
        norm = ImageNormalize(image, interval=ZScaleInterval())
    elif interval_type == 'percentile':
        norm = ImageNormalize(image, interval=PercentileInterval(percentile))
    elif interval_type == 'minmax':
        norm = ImageNormalize(image, interval=MinMaxInterval())
    elif interval_type == 'simple_norm':
        norm = simple_norm(image, stretch)
    else:
        raise ValueError(
            f"Unknown interval_type {interval_type!r}; expected 'zscale', "
            "'percentile', 'minmax' or 'simple_norm'")

    fig = plt.figure(figsize=(10, 10))

    ax = fig.add_subplot(1, 1, 1)

    # NumPy images are indexed [row, column] = [y, x].
    im = ax.imshow(image[ylim_minus:ylim_plus, xlim_minus:xlim_plus],
                   cmap=cmap,
                   interpolation='none',
                   origin='lower',
                   norm=norm)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="3%", pad=0.1)

    plt.colorbar(im, cax=cax, label='ADU')
    ax.set_xlabel("x [pixels]")
    ax.set_ylabel("y [pixels]")
    ax.set_title(title, fontsize=14)

    plt.show()
예제 #12
0
def optical_depth(inputfile):
    """Convert a surface-brightness map to optical depth and plot it.

    The brightness (element 0 of ``import_fits(inputfile)``) is multiplied
    by 1e-17. This presumably converts MJy/sr to cgs
    (erg s^-1 cm^-2 Hz^-1 sr^-1); confirm the input units. It is then
    inverted through the Rayleigh-Jeans relation with T = 8000 K and
    nu = 8 GHz:

        tau = -ln(1 - c^2 I / (2 k_B T nu^2))

    The map is shown with a square-root stretch over its min-max range,
    limited to +/-300 pixels around the array centre.

    Parameters
    ----------
    inputfile
        Passed to ``import_fits``. Element 0 is the data, element 1 the header.
    """
    block = import_fits(inputfile)
    SB = block[0]
    hrd = block[1]
    SB_erg = SB * 1e-17

    T = 8000  # K
    nu = 8e9  # central frequency of our broadband observations
    c = 3e10  # speed of light in cgs
    k_b = 1.38e-16  # boltzmann constant in cgs

    tau = -np.log(1 - (c**2 * SB_erg) / (2 * k_b * T * (nu**2)))

    wcs = WCS(hrd)

    # The norm must be computed from the array that is actually displayed.
    norm = ImageNormalize(tau,
                          interval=MinMaxInterval(),
                          stretch=SqrtStretch())

    fig = plt.figure(1)
    ax = fig.add_subplot(111, projection=wcs, slices=('x', 'y', 0, 0))
    em_map = ax.imshow(tau, origin='lower', cmap='viridis', norm=norm)
    ax.set_xlabel('Right Ascension\nJ2000')
    ax.set_ylabel('Declination')
    cbar = fig.colorbar(em_map)
    cbar.set_label('Optical Depth')

    # NOTE(review): centre[0] comes from dims[0] (rows) but sets the x
    # limits. This is only equivalent for square maps; confirm.
    dims = np.shape(tau)
    centre = (dims[0] / 2., dims[1] / 2.)
    ax.set_xlim(centre[0] - 300, centre[0] + 300)
    ax.set_ylim(centre[1] - 300, centre[1] + 300)

    ra = ax.coords[0]
    ra.set_format_unit('degree', decimal=True)

    dec = ax.coords[1]
    dec.set_format_unit('degree', decimal=True)

    plt.show()
def image_plot(inputfile, d_range, outputfile):
    """Plot a scaled image with WCS coordinates and a colour bar.

    Parameters
    ----------
    inputfile, d_range
        Passed to ``import_fits``. Element 0 of its result is the data array
        and element 1 the FITS header.
    outputfile : str or False
        If not False, the figure is saved to ``os.getcwd() + outputfile``.
        This is plain concatenation, so it should start with a separator.

    Raises
    ------
    ValueError
        If the WCS has neither 2 nor 4 pixel dimensions.

    NOTE(review): reads the module-level globals ``contour_file`` and
    ``contour_type``. Confirm they are defined by the calling module.
    """
    block = import_fits(inputfile, d_range)
    data = block[0]
    hrd = block[1]
    wcs = WCS(hrd)

    # No beam is read for Spitzer data. Keep the name bound either way so
    # the contour call below cannot fail on it.
    beam = None
    if hrd['TELESCOP'] != 'Spitzer':
        beam = Beam.from_fits_header(hrd)

    # The explicit vmin/vmax override the interval's limits. They are given
    # to the norm because Matplotlib rejects norm together with vmin/vmax.
    norm = ImageNormalize(data,
                          interval=MinMaxInterval(),
                          stretch=SqrtStretch(),
                          vmin=np.min(data),
                          vmax=np.max(data) - 5)

    if wcs.pixel_n_dim == 4:
        slices = ('x', 'y', 0, 0)
    elif wcs.pixel_n_dim == 2:
        slices = ('x', 'y')
    else:
        raise ValueError(
            f"Unsupported WCS dimensionality: {wcs.pixel_n_dim} (expected 2 or 4)")

    figure = plt.figure(num=1)
    ax = figure.add_subplot(111, projection=wcs, slices=slices)

    main_image = ax.imshow(X=data,
                           cmap='plasma',
                           origin='lower',
                           norm=norm)
    cbar = figure.colorbar(main_image)

    if hrd['TELESCOP'] == 'Spitzer':
        ax.invert_xaxis()
        ax.invert_yaxis()

    ax.set_xlabel('Right Ascension J2000')
    ax.set_ylabel('Declination J2000')
    cbar.set_label('Surface Brigthness (MJy/Sr)')

    if contour_file != False:
        contour_plot(ax, contour_file, contour_type, beam)

    # Save before show(): once a blocking show() returns, the figure is
    # closed and savefig would write an empty canvas.
    if outputfile != False:
        plt.savefig(os.getcwd() + outputfile)
    plt.show()
예제 #14
0
    def update_image(self, fits_file):
        """Update the image displayed in the ImageDisplayWindow.

        Reads the primary HDU of ``fits_file`` and crops 10% of the image
        from every edge via ``self.crop_center``. It then applies a min-max
        interval with a square-root stretch and draws the result on
        ``self.gca`` with the "prism" colour map. Failures are reported on
        stdout and are not re-raised.
        """
        try:
            image_data = fits.getdata(fits_file, ext=0)
            sz = image_data.shape

            # Crop 10% of image all the way around the edge.
            new_image_data = self.crop_center(image_data, int(sz[0] * 0.8),
                                              int(sz[1] * 0.8))

            # Scale to [0, 1] over the data range, then apply a sqrt stretch.
            norm = ImageNormalize(new_image_data,
                                  interval=MinMaxInterval(),
                                  stretch=SqrtStretch())

            self.gca.imshow(new_image_data, cmap='prism', norm=norm)
            self.draw()
        except Exception as err:
            # Narrowed from a bare ``except`` so that KeyboardInterrupt and
            # SystemExit still propagate. The cause is reported for debugging.
            print(f"Error reading and displaying FITS file: {err}")
    def generate_plots(self, obs_id):
        """Render a preview JPEG of the science file and register it.

        The data shown depend on the header's OBSTYPE:

        - OBJECT: flipped median over axis 0 of the 'SCI' extension.
        - FLAT/ARC/RONCHI/DARK: all 'SCI' extensions, sorted by the fourth
          field of NSCUTSEC and concatenated.
        - SHIFT: the flipped 'SCI' extension.

        ZScale is used unless the observation target is moving and the
        OBSTYPE is not DARK, in which case MinMax is used.

        Returns
        -------
        int
            Number of artefacts produced: 0 for other OBSTYPEs, otherwise 1
            plus the value returned by ``self._gen_thumbnail()``.
        """
        self._logger.debug(f'Begin generate_plots for {obs_id}')
        count = 0
        # NC - algorithm
        # NC - review - 07-08-20
        # The context manager closes the file on every path, including the
        # early return. interval() returns a new array, so the data remain
        # usable after the file is closed.
        with fits.open(self._science_fqn) as hdus:
            obs_type = hdus[0].header.get('OBSTYPE').upper()
            interval = ZScaleInterval()
            if (self._observation.target is not None and
                    self._observation.target.moving) and obs_type != 'DARK':
                interval = MinMaxInterval()
            if 'OBJECT' in obs_type:
                white_light_data = interval(
                    np.flipud(np.median(hdus['SCI'].data, axis=0)))
            elif ('FLAT' in obs_type or 'ARC' in obs_type or
                  'RONCHI' in obs_type or 'DARK' in obs_type):
                # Stitch together the 29 'SCI' extensions into one array and save.
                hdul = [x for x in hdus if x.name == 'SCI']
                hdul.sort(key=lambda x: int(re.split(r"[\[\]\:\,']+",
                                                     x.header['NSCUTSEC'])[3]))
                temp = np.concatenate([x.data for x in hdul])
                white_light_data = interval(temp)
            elif 'SHIFT' in obs_type:
                temp = np.flipud(hdus['SCI'].data)
                white_light_data = interval(temp)
            else:
                return count

        fig = plt.figure(figsize=(10.24, 10.24), dpi=100)
        try:
            plt.grid(False)
            plt.axis('off')
            plt.imshow(white_light_data, cmap='inferno')
            plt.savefig(self._preview_fqn, format='jpg')
            count += 1
            self.add_preview(self._storage_name.prev_uri,
                             self._storage_name.prev,
                             ProductType.PREVIEW)
            self.add_to_delete(self._preview_fqn)
            count += self._gen_thumbnail()
        finally:
            # Release the figure so repeated calls do not accumulate memory.
            plt.close(fig)
        self._logger.info(f'End generate_plots for {obs_id}.')
        return count
    def __loadFitsFile(self, imageDataIndex, filePath, imageSize,
                       rejectionThresholdArea):
        """Load one FITS image into ``self.imageData[:, :, imageDataIndex]``.

        Nothing is stored when ``filePath`` is empty (no file was found for
        this index), when the primary image is not square, or when
        ``self.__imageFromEdgeOfSurvey`` flags it. Otherwise the image is
        resized to ``(imageSize, imageSize)`` with ``imresize(mode="F")``,
        min-max normalized, and stored; ``self.nonBlankImageCount`` is
        incremented.
        """
        if filePath == "":
            return  # No file corresponding to this index was found.

        # Use a context manager so the file handle is always released. All
        # processing happens inside it, while the data are still attached.
        with fits.open(filePath) as hdul:
            newImageData = hdul[0].data

            if newImageData.shape[0] != newImageData.shape[1]:
                return  # The image is not square and is therefore rejected.

            if self.__imageFromEdgeOfSurvey(newImageData,
                                            rejectionThresholdArea):
                return

            resizedImageData = imresize(newImageData,
                                        (imageSize, imageSize),
                                        mode="F")
            self.nonBlankImageCount += 1

            # Normalize the pixel values of the current image.
            normalizedNewImageData = MinMaxInterval()(resizedImageData)
            # The image has been successfully loaded.
            self.imageData[:, :, imageDataIndex] = normalizedNewImageData
예제 #17
0
def image_plot(inputfile, d_range, imsize, outputfile):
    """Plot a scaled image with coordinates, a colour bar, a star and the beam.

    Parameters
    ----------
    inputfile, d_range
        Passed to ``import_fits``. Element 0 is the data array and element 1
        the FITS header.
    imsize : float
        Half-width, in pixels, of the displayed window around the array centre.
    outputfile : str or False
        If not False, saved to ``os.getcwd() + '/thesis_figs/' + outputfile``.

    Raises
    ------
    ValueError
        If the WCS has neither 2 nor 4 pixel dimensions.

    NOTE(review): ``star_coord`` is a module-level global. Confirm it is
    defined wherever this is called.
    """
    block = import_fits(inputfile, d_range)
    data = block[0]
    hrd = block[1]
    wcs = WCS(hrd, naxis=2)

    # No beam is read for Spitzer data. The beam is read once here and
    # reused for the beam circle drawn below.
    beam = None
    if hrd['TELESCOP'] != 'Spitzer':
        beam = Beam.from_fits_header(hrd)

    print(np.shape(data))
    # The explicit limits go into the norm: Matplotlib rejects a norm passed
    # together with vmin/vmax.
    norm = ImageNormalize(data, interval=MinMaxInterval(), stretch=SqrtStretch(),
                          vmin=np.min(data), vmax=np.max(data))

    if wcs.pixel_n_dim == 4:
        slices = ('x', 'y', 0, 0)
    elif wcs.pixel_n_dim == 2:
        slices = ('x', 'y')
    else:
        raise ValueError(
            f"Unsupported WCS dimensionality: {wcs.pixel_n_dim} (expected 2 or 4)")

    figure = plt.figure(num=1)
    figure.subplots_adjust(
        top=0.924,
        bottom=0.208,
        left=0.042,
        right=0.958,
        hspace=0.125,
        wspace=0.2,
    )
    ax = figure.add_subplot(111, projection=wcs, slices=slices)

    main_image = ax.imshow(X=data, cmap='plasma', origin='lower', norm=norm)
    cbar = figure.colorbar(main_image)
    star_index = wcs.world_to_pixel(star_coord)
    ax.scatter(star_index[0], star_index[1], marker='*', c='w')

    if hrd['TELESCOP'] == 'Spitzer':
        ax.invert_xaxis()
        ax.invert_yaxis()

    dims = np.shape(data)
    centre = (dims[0] / 2., dims[1] / 2.)
    ax.set_xlim(centre[0] - imsize, centre[0] + imsize)
    ax.set_ylim(centre[1] - imsize, centre[1] + imsize)

    ax.set_xlabel('Right Ascension J2000')
    ax.set_ylabel('Declination J2000')
    cbar.set_label('Surface Brigthness (MJy/Sr)')
    ra = ax.coords[0]
    ra.set_format_unit('degree', decimal=True)

    dec = ax.coords[1]
    dec.set_format_unit('degree', decimal=True)

    ax.set_title('4-12 GH 5\u03C3 mask')

    if beam is not None:
        # Beam-sized circle at a fixed sky position, drawn in FK5 coordinates.
        c = SphericalCircle((350.32, 61.14) * u.degree, beam.major,
                            edgecolor='black', facecolor='none',
                            transform=ax.get_transform('fk5'))
        ax.add_patch(c)

    if outputfile != False:
        plt.savefig(os.getcwd() + '/thesis_figs/' + outputfile)
    plt.show()
예제 #18
0
# For each object name in ``name`` (starting at the current ``i``), resolve its
# ICRS position and, for every survey in ``hips_list``, fetch a HiPS2FITS
# cutout. Each cutout is plotted with an asinh stretch and a green marker at
# the object, then saved as '<obj>-<hips>.jpg'; the file name is appended to
# ``ap``. Relies on the script-level variables i, name, hips_list, width,
# height, fov and ap.
while i < len(name):

    obj = name[i]
    sc = SkyCoord.from_name(obj)
    ra = sc.icrs.ra.deg
    dec = sc.icrs.dec.deg
    for hips in hips_list:
        url = 'http://alasky.u-strasbg.fr/hips-image-services/hips2fits?hips={}&width={}&height={}&fov={}&projection=TAN&coordsys=icrs&ra={}&dec={}'.format(
            quote(hips), width, height, fov, ra, dec)
        with fits.open(url) as hdu:
            wcs = WCS(hdu[0].header)
            im = hdu[0].data
            fig = plt.figure()
            plt.subplot(projection=wcs)
            norm = ImageNormalize(im,
                                  interval=MinMaxInterval(),
                                  stretch=AsinhStretch())
            file_name = '{}-{}.jpg'.format(obj, hips.replace('/', '_'))
            ap.append(file_name)
            plt.imshow(im, cmap='Greys', norm=norm, origin='lower')
            # origin=0: matplotlib pixel coordinates are 0-based. With
            # origin=1 (FITS convention) the marker is off by one pixel.
            px, py = wcs.wcs_world2pix(ra, dec, 0)
            plt.scatter(px, py, c='g', s=100)
            plt.title('{} - {}'.format(obj, hips))
            plt.xlabel('RA')
            plt.ylabel('DEC')
            print(file_name)
            plt.savefig(file_name)
            # Free each figure so long object lists do not exhaust memory.
            plt.close(fig)
    i = i + 1

#%%
예제 #19
0
def asin_stretch_norm(images: Stamp):
    """Return an asinh-stretched ``ImageNormalize`` scaled to the min/max of ``images``."""
    interval = MinMaxInterval()
    stretch = AsinhStretch()
    return ImageNormalize(images, interval=interval, stretch=stretch)
예제 #20
0
def log_stretch_norm(images: Stamp):
    """Return a log-stretched ``ImageNormalize`` scaled to the min/max of ``images``."""
    interval = MinMaxInterval()
    stretch = LogStretch()
    return ImageNormalize(images, interval=interval, stretch=stretch)
예제 #21
0
            data = photom[sbset]

            cut = (data[z_col] > 0.)
            #data = data[cut][::10]
            data = data[data[z_col] >= 0.]

            zdata = data[z_col].data
            mdata = data[mag_col].data

            Fig, Ax = plt.subplots(1, figsize=(5, 3.5))
            obs_peaks = []

            b1 = sampler.flatchain[np.argmax(sampler.flatlnprobability)]
            b1[-1] = 0.

            intvl = MinMaxInterval()
            colors = plt.cm.viridis(intvl(mrange))

            for im, m in enumerate(mrange):
                pzm = pzl(zrange, m, *best, lzc=0.001)
                pzm /= np.trapz(pzm, zrange)
                P = Ax.plot(zrange,
                            pzm,
                            color=colors[im],
                            lw=3,
                            alpha=0.7,
                            label='{0} = {1:.1f}'.format('$m_{I}$', m))

                cut = np.logical_and(mdata > m - 0.5, mdata < m + 0.5)
                Ax.hist(zdata[cut],
                        normed=True,
예제 #22
0
def data_to_pitch(data_array, pitch_range=(100, 10000), center_pitch=440, zero_point="median",
                  stretch='linear', minmax_percent=None, minmax_value=None, invert=False):
    """
    Map data array to audible pitches in the given range, and apply stretch and scaling
    as required.

    Parameters
    ----------
    data_array : array-like
        Data to map to pitch values. Individual data values should be floats.
        Lists and tuples are accepted as well as numpy arrays.
    pitch_range : array
        Optional, default (100, 10000). Range of acceptable pitches in Hz.
    center_pitch : float
        Optional, default 440. The pitch in Hz to which the zero point of the
        data will be mapped.
    zero_point : str or float
        Optional, default "median". The data value that will be mapped to the center
        pitch. Options are "med"/"median", "ave"/"mean"/"average", or a specified
        data value (float).
    stretch : str
        Optional, default 'linear'. The stretch to apply to the data array.
        Valid values are: asinh, sinh, sqrt, log, linear
    minmax_percent : array
        Optional. Interval based on keeping a specified fraction of data values
        (can be asymmetric) when scaling the data. The format is [lower percentile,
        upper percentile], where data values below the lower percentile and above the upper
        percentile are clipped. Only one of minmax_percent and minmax_value should be specified.
    minmax_value : array
        Optional. Interval based on user-specified data values when scaling the data array.
        The format is [min value, max value], where data values below the min value and above
        the max value are clipped.
        Only one of minmax_percent and minmax_value should be specified.
    invert : bool
        Optional, default False.  If True the pitch array is inverted (low pitches become high
        and vice versa).

    Returns
    -------
    response : array
        The normalized data array, with values in given pitch range.

    Raises
    ------
    InvalidInputError
        If ``zero_point`` is an unrecognised string or ``stretch`` is unsupported.
    """
    # Convert up front: the all-equal check and the zero-point append below
    # need array semantics. A plain list would fail at ``(... == ...).all()``.
    data_array = np.asarray(data_array)

    # Parsing the zero point
    if isinstance(zero_point, str):
        if zero_point in ("med", "median"):
            zero_point = np.median(data_array)
        elif zero_point in ("ave", "mean", "average"):
            zero_point = np.mean(data_array)
        else:
            raise InvalidInputError("Zero point {} is not supported!".format(zero_point))

    # The center pitch cannot be >= max() pitch range, or <= min() of pitch range.
    # If it is, fall back to using the mean of the pitch range provided.
    if center_pitch <= pitch_range[0] or center_pitch >= pitch_range[1]:
        warnings.warn("Given center pitch is outside the pitch range, defaulting to the mean.",
                      InputWarning)
        center_pitch = np.mean(pitch_range)

    if (data_array == zero_point).all():  # All values are the same, no more calculation needed
        return np.full(len(data_array), center_pitch)

    # Append the zero point so it goes through the same transform as the data
    data_array = np.append(data_array, zero_point)

    # Setting up the transform with the stretch
    if stretch == 'asinh':
        transform = AsinhStretch()
    elif stretch == 'sinh':
        transform = SinhStretch()
    elif stretch == 'sqrt':
        transform = SqrtStretch()
    elif stretch == 'log':
        transform = LogStretch()
    elif stretch == 'linear':
        transform = LinearStretch()
    else:
        raise InvalidInputError("Stretch {} is not supported!".format(stretch))

    # Adding the scaling to the transform
    if minmax_percent is not None:
        transform += AsymmetricPercentileInterval(*minmax_percent)

        if minmax_value is not None:
            warnings.warn("Both minmax_percent and minmax_value are set, minmax_value will be ignored.",
                          InputWarning)
    elif minmax_value is not None:
        transform += ManualInterval(*minmax_value)
    else:  # Default, scale the entire image range to [0,1]
        transform += MinMaxInterval()

    # Performing the transform and then putting it into the pitch range
    pitch_array = transform(data_array)

    if invert:
        pitch_array = 1 - pitch_array

    zero_point = pitch_array[-1]
    pitch_array = pitch_array[:-1]

    # In rare cases, the zero-point at this stage might be 0.0.
    # One example is an input array of two values where the median() is the same as the
    # lowest of the two values. In this case, the zero-point is 0.0 and will lead to error
    # (divide by zero). Change to small value to avoid dividing by zero (in reality the choice
    # of zero-point calculation by the user was probably poor, but not in purview to mandate or
    # change user's choice here.  May want to consider providing info back to the user about the
    # distribution of pitches actually used based on their sonification options in some way.
    if zero_point == 0.0:
        zero_point = 1E-6

    if ((1/zero_point)*(center_pitch - pitch_range[0]) + pitch_range[0]) <= pitch_range[1]:
        pitch_array = (pitch_array/zero_point)*(center_pitch - pitch_range[0]) + pitch_range[0]
    else:
        pitch_array = (((pitch_array-zero_point)/(1-zero_point))*(pitch_range[1] - center_pitch) +
                       center_pitch)

    return pitch_array
예제 #23
0
    async def get(self, object_id: str = None):
        """
        ---
        summary: Serve alert cutout as fits or png
        tags:
          - alerts
          - kowalski

        parameters:
          - in: query
            name: instrument
            required: false
            schema:
              type: str
          - in: query
            name: candid
            description: "ZTF alert candid"
            required: true
            schema:
              type: integer
          - in: query
            name: cutout
            description: "retrieve science, template, or difference cutout image?"
            required: true
            schema:
              type: string
              enum: [science, template, difference]
          - in: query
            name: file_format
            description: "response file format: original loss-less FITS or rendered png"
            required: true
            default: png
            schema:
              type: string
              enum: [fits, png]
          - in: query
            name: interval
            description: "Interval to use when rendering png"
            required: false
            schema:
              type: string
              enum: [min_max, zscale]
          - in: query
            name: stretch
            description: "Stretch to use when rendering png"
            required: false
            schema:
              type: string
              enum: [linear, log, asinh, sqrt]
          - in: query
            name: cmap
            description: "Color map to use when rendering png"
            required: false
            schema:
              type: string
              enum: [bone, gray, cividis, viridis, magma]

        responses:
          '200':
            description: retrieved cutout
            content:
              image/fits:
                schema:
                  type: string
                  format: binary
              image/png:
                schema:
                  type: string
                  format: binary

          '400':
            description: retrieval failed
            content:
              application/json:
                schema: Error
        """
        instrument = self.get_query_argument("instrument", "ZTF").upper()
        if instrument not in INSTRUMENTS:
            raise ValueError("Instrument name not recognised")

        # allow access to public data only by default
        selector = {1}

        for stream in self.associated_user_object.streams:
            if "ztf" in stream.name.lower():
                selector.update(set(stream.altdata.get("selector", [])))

        selector = list(selector)

        try:
            candid = int(self.get_argument("candid"))
            cutout = self.get_argument("cutout").capitalize()
            file_format = self.get_argument("file_format", "png").lower()
            interval = self.get_argument("interval", default=None)
            stretch = self.get_argument("stretch", default=None)
            cmap = self.get_argument("cmap", default=None)

            known_cutouts = ["Science", "Template", "Difference"]
            if cutout not in known_cutouts:
                return self.error(
                    f"Cutout {cutout} of {object_id}/{candid} not in {str(known_cutouts)}"
                )
            known_file_formats = ["fits", "png"]
            if file_format not in known_file_formats:
                return self.error(
                    f"File format {file_format} of {object_id}/{candid}/{cutout} not in {str(known_file_formats)}"
                )

            normalization_methods = {
                "asymmetric_percentile": AsymmetricPercentileInterval(
                    lower_percentile=1, upper_percentile=100
                ),
                "min_max": MinMaxInterval(),
                "zscale": ZScaleInterval(nsamples=600, contrast=0.045, krej=2.5),
            }
            if interval is None:
                interval = "asymmetric_percentile"
            normalizer = normalization_methods.get(
                interval.lower(),
                AsymmetricPercentileInterval(lower_percentile=1, upper_percentile=100),
            )

            stretching_methods = {
                "linear": LinearStretch,
                "log": LogStretch,
                "asinh": AsinhStretch,
                "sqrt": SqrtStretch,
            }
            if stretch is None:
                stretch = "log" if cutout != "Difference" else "linear"
            stretcher = stretching_methods.get(stretch.lower(), LogStretch)()

            if (cmap is None) or (
                cmap.lower() not in ["bone", "gray", "cividis", "viridis", "magma"]
            ):
                cmap = "bone"
            else:
                cmap = cmap.lower()

            query = {
                "query_type": "find",
                "query": {
                    "catalog": "ZTF_alerts",
                    "filter": {
                        "candid": candid,
                        "candidate.programid": {"$in": selector},
                    },
                    "projection": {"_id": 0, f"cutout{cutout}": 1},
                },
                "kwargs": {"limit": 1, "max_time_ms": 5000},
            }

            response = kowalski.query(query=query)

            if response.get("status", "error") == "success":
                alert = response.get("data", [dict()])[0]
            else:
                return self.error("No cutout found.")

            cutout_data = bj.loads(bj.dumps([alert[f"cutout{cutout}"]["stampData"]]))[0]

            # unzipped fits name
            fits_name = pathlib.Path(alert[f"cutout{cutout}"]["fileName"]).with_suffix(
                ""
            )

            # unzip and flip about y axis on the server side
            with gzip.open(io.BytesIO(cutout_data), "rb") as f:
                with fits.open(io.BytesIO(f.read())) as hdu:
                    header = hdu[0].header
                    data_flipped_y = np.flipud(hdu[0].data)

            if file_format == "fits":
                hdu = fits.PrimaryHDU(data_flipped_y, header=header)
                hdul = fits.HDUList([hdu])

                stamp_fits = io.BytesIO()
                hdul.writeto(fileobj=stamp_fits)

                self.set_header("Content-Type", "image/fits")
                self.set_header(
                    "Content-Disposition", f"Attachment;filename={fits_name}"
                )
                self.write(stamp_fits.getvalue())

            if file_format == "png":
                buff = io.BytesIO()
                plt.close("all")

                fig, ax = plt.subplots(figsize=(4, 4))
                fig.subplots_adjust(0, 0, 1, 1)
                ax.set_axis_off()

                # replace nans with median:
                img = np.array(data_flipped_y)
                # replace dubiously large values
                xl = np.greater(np.abs(img), 1e20, where=~np.isnan(img))
                if img[xl].any():
                    img[xl] = np.nan
                if np.isnan(img).any():
                    median = float(np.nanmean(img.flatten()))
                    img = np.nan_to_num(img, nan=median)
                norm = ImageNormalize(img, stretch=stretcher)
                img_norm = norm(img)
                vmin, vmax = normalizer.get_limits(img_norm)
                ax.imshow(img_norm, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax)
                plt.savefig(buff, dpi=42)
                buff.seek(0)
                plt.close("all")
                self.set_header("Content-Type", "image/png")
                self.write(buff.getvalue())

        except Exception:
            _err = traceback.format_exc()
            return self.error(f"failure: {_err}")
예제 #24
0
def _load_resized_band(fname, transform, size=200):
    """Read one FITS image, rescale it with ``transform`` and resize it to ``size`` x ``size``."""
    image = transform(np.array(fits.getdata(fname)))
    return np.array(cv2.resize(image, dsize=(size, size)))


def load_data(path, extensions, label_name, load_length, IR, test=False, start=0):
    """Load ``load_length`` single-channel 200x200 images for one band group.

    Image numbers are read with ``load_csv_data(path + label_name, 1)``;
    output slot ``k`` is filled from list entry ``k + start``. When that file
    is missing, entry ``k + start - 1`` is loaded in its place and a
    "not found" message is printed (for ``k + start == 0`` this wraps to the
    last entry of the list).

    Parameters
    ----------
    path : str
        Directory prefix; band sub-directories (``EUC_H/``, ``EUC_Y/``,
        ``EUC_J/``, ``EUC_VIS/``) are appended to it directly.
    extensions : str
        Suffix appended after the image number (e.g. a file extension).
    label_name : str
        CSV file name, relative to ``path``, holding the image numbers.
    load_length : int
        Number of images to load.
    IR : str
        ``'IR'`` loads the H band; ``'mid'`` averages the Y and J bands;
        anything else loads the VIS band.
    test : bool
        Unused; kept for backward compatibility.
    start : int
        Offset into the image-number list.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(load_length, 200, 200)``.
    """
    print('Data loading :')
    file_n = load_csv_data(path + label_name, 1)

    def band_files(band):
        # <path><band>/image<band>-<number><extensions>, one entry per image number
        return [path + band + '/image' + band + '-' + str(int(file_n[x])) + extensions
                for x in range(len(file_n))]

    images = np.empty((load_length, 200, 200), dtype=np.float32)

    if IR == 'IR':
        # Deeper infrared (H band) channel on its own, rescaled to [0, 1] by
        # MinMaxInterval and resized to 200x200.
        transform = MinMaxInterval()
        filelist_h = band_files('EUC_H')
        for k in range(load_length):
            idx = k + start
            if k % 300 == 0:
                print(str(k) + ', loading file ' + str(idx))
            try:
                images[k, :, :] = _load_resized_band(filelist_h[idx], transform)
            except FileNotFoundError:
                # Store the substitute image so this slot never keeps the
                # uninitialised contents of np.empty.
                images[k, :, :] = _load_resized_band(filelist_h[idx - 1], transform)
                print(filelist_h[idx] + ' not found')

    elif IR == 'mid':
        # Mid-infrared Y and J bands, each rescaled/resized, then averaged.
        transform = MinMaxInterval()
        filelist_y = band_files('EUC_Y')
        filelist_j = band_files('EUC_J')
        for k in range(load_length):
            idx = k + start
            if k % 300 == 0:
                print(str(k) + ', loading file ' + str(idx))
            try:
                im_y = _load_resized_band(filelist_y[idx], transform)
                im_j = _load_resized_band(filelist_j[idx], transform)
            except FileNotFoundError:
                im_y = _load_resized_band(filelist_y[idx - 1], transform)
                im_j = _load_resized_band(filelist_j[idx - 1], transform)
                print(filelist_j[idx] + ' not found')
            images[k, :, :] = 0.5 * im_y + 0.5 * im_j

    else:
        # Visible (VIS) channel, stored as read: no rescaling or resizing, so
        # these files presumably are already 200x200 -- TODO confirm.
        filelist = band_files('EUC_VIS')
        for k in range(load_length):
            idx = k + start
            if k % 1000 == 0:
                print(str(k) + ', loading file ' + str(idx))
            try:
                images[k, :, :] = np.array(fits.getdata(filelist[idx]))
            except FileNotFoundError:
                images[k, :, :] = np.array(fits.getdata(filelist[idx - 1]))
                print(filelist[idx] + ' not found')

    print('Done loading data')

    return images
예제 #25
0
                       lw=2,
                       color='steelblue',
                       label='Test - Orig.')
            Ax[0].plot(bins,
                       ci_test_mod,
                       lw=2,
                       color='steelblue',
                       label='Test - Smthd.')
            Ax[0].plot([0, 1], [0, 1], ':', color='k', lw=2)

            leg1 = Ax[0].legend(loc='upper left', prop={'size': 9})

            Ax[1].plot(bins, ci_train_all, 'k', label='All', lw=2)

            magrange = np.arange(17., 23., 1.)
            magcols = plt.cm.viridis(MinMaxInterval()(magrange))
            for ic, magc in enumerate(magrange):
                sample = (photom[mag_col][sbset * mcut][train_index] >= magc -
                          1.) * (photom[mag_col][sbset * mcut][train_index] <
                                 magc + 1.)
                ci_train_mag, bins = pdf.calc_ci_dist(pz_train[sample], zgrid,
                                                      zspec_train[sample])
                Ax[1].plot(bins,
                           ci_train_mag,
                           color=magcols[ic],
                           ls='--',
                           lw=2)

                ci_train_mag, bins = pdf.calc_ci_dist(pz_mod[sample], zgrid,
                                                      zspec_train[sample])
                Ax[1].plot(bins, ci_train_mag, color=magcols[ic], lw=2)
예제 #26
0
def plot_subimage(fig: plt.Figure, hdu: Union[str, fits.HDUList], ra: float, dec: float,
                  frame: Union[int, float], world_frame: bool = False, title: str = None,
                  n: int = 1, n_x: int = 1, n_y: int = 1,
                  cmap: str = 'viridis', show_cbar: bool = False, stretch: str = 'sqrt', vmin: float = None,
                  vmax: float = None,
                  show_grid: bool = False,
                  ticks: int = None, interval: str = 'minmax',
                  show_coords: bool = True, ylabel: str = None,
                  font_size: int = 12,
                  reverse_y=False):
    """Cut out a region around (ra, dec) and draw it as a subplot of ``fig``.

    All drawing goes through ``fig`` and the new axes explicitly, so the
    result does not depend on which figure pyplot currently considers active.

    :param fig: figure the subplot is added to.
    :param hdu: FITS path or HDUList, resolved with ``ff.path_or_hdu``.
    :param ra: right ascension of the cut-out centre, passed to ``ff.trim_frame_point``.
    :param dec: declination of the cut-out centre, passed to ``ff.trim_frame_point``.
    :param frame: in pixels, or in arcsecs (?) if world_frame is True.
    :param world_frame: forwarded to ``ff.trim_frame_point``.
    :param title: axes title text.
    :param n: subplot index, as in ``fig.add_subplot(n_y, n_x, n)``.
    :param n_x: number of subplot columns.
    :param n_y: number of subplot rows.
    :param cmap: colormap name for ``imshow``.
    :param show_cbar: if True, add a colorbar next to the subplot.
    :param stretch: ``'log'`` or ``'sqrt'``.
    :param vmin: lower display limit, or ``'median_full'`` / ``'median_cut'``
        for the NaN-ignoring median of the full / cut-out data.
    :param vmax: upper display limit.
    :param show_grid: draw a dashed black grid.
    :param ticks: number of ticks on the first world coordinate (only with show_coords).
    :param interval: ``'minmax'`` or ``'zscale'``.
    :param show_coords: use WCS world-coordinate axes; otherwise hide the x axis
        and remove the y ticks.
    :param ylabel: optional y-axis label (drawn at size 12).
    :param font_size: title font size.
    :param reverse_y: invert the y axis after drawing.
    :return: tuple ``(axes, cut-out HDUList)``.
    :raises ValueError: for an unrecognised vmin string, interval or stretch.
    """
    hdu, _ = ff.path_or_hdu(hdu=hdu)

    hdu_cut = ff.trim_frame_point(hdu=hdu, ra=ra, dec=dec, frame=frame, world_frame=world_frame)
    wcs_cut = wcs.WCS(header=hdu_cut[0].header)

    if show_coords:
        plot = fig.add_subplot(n_y, n_x, n, projection=wcs_cut)
        if ticks is not None:
            plot.coords[0].set_ticks(number=ticks)
    else:
        plot = fig.add_subplot(n_y, n_x, n)
        plot.get_xaxis().set_visible(False)
        plot.set_yticks([])

    if show_grid:
        plot.grid(color='black', ls='dashed')

    if isinstance(vmin, str):
        if vmin == 'median_full':
            vmin = np.nanmedian(hdu[0].data)
        elif vmin == 'median_cut':
            vmin = np.nanmedian(hdu_cut[0].data)
        else:
            raise ValueError('Unrecognised vmin string argument.')

    intervals = {'minmax': MinMaxInterval, 'zscale': ZScaleInterval}
    if interval not in intervals:
        raise ValueError('Interval not recognised.')
    interval = intervals[interval]()

    stretches = {'log': LogStretch, 'sqrt': SqrtStretch}
    if stretch not in stretches:
        raise ValueError('Stretch not recognised.')
    # Explicit vmin/vmax, when given, take precedence over the interval's limits.
    norm = ImageNormalize(hdu_cut[0].data, interval=interval, stretch=stretches[stretch](),
                          vmin=vmin, vmax=vmax)

    plot.title.set_text(title)
    plot.title.set_size(font_size)
    if ylabel is not None:
        plot.set_ylabel(ylabel, size=12)

    im = plot.imshow(hdu_cut[0].data, norm=norm, cmap=cmap)
    if reverse_y:
        plot.invert_yaxis()
    if show_cbar:
        fig.colorbar(im, ax=plot)

    return plot, hdu_cut
def image_plot(inputfile, d_range, imsize, outputfile):
    """Plot a FITS image with world-coordinate axes, a colorbar and (non-Spitzer) beam.

    Parameters
    ----------
    inputfile :
        Passed with ``d_range`` to ``import_fits``; the first two returned
        items are used as the image data and the FITS header.
    d_range :
        Forwarded to ``import_fits``.
    imsize : float
        Half-width, in pixels, of the displayed window around the array centre.
    outputfile : str or False
        File name saved under ``<cwd>/thesis_figs/``; a falsy value skips saving.

    Raises
    ------
    ValueError
        If the header's WCS has neither 2 nor 4 pixel dimensions.

    Notes
    -----
    Uses the module-level ``font`` dict for label and title fonts. If the
    header's ``TELESCOP`` is ``'Spitzer'`` both axes are inverted; otherwise
    the beam from the header is drawn as a circle at a fixed sky position.
    """
    block = import_fits(inputfile, d_range)
    data = block[0]
    hrd = block[1]
    image_wcs = WCS(hrd)
    is_spitzer = hrd['TELESCOP'] == 'Spitzer'

    # MinMaxInterval already maps the finite data range onto the colour scale.
    # vmin/vmax must not also be passed to imshow: recent Matplotlib rejects a
    # norm combined with vmin/vmax, and np.min/np.max return NaN on masked data.
    norm = ImageNormalize(data,
                          interval=MinMaxInterval(),
                          stretch=SqrtStretch())

    figure = plt.figure(num=1)
    if image_wcs.pixel_n_dim == 4:
        ax = figure.add_subplot(111, projection=image_wcs, slices=('x', 'y', 0, 0))
    elif image_wcs.pixel_n_dim == 2:
        ax = figure.add_subplot(111, projection=image_wcs, slices=('x', 'y'))
    else:
        raise ValueError('Unsupported WCS with {} pixel dimensions; expected 2 or 4.'
                         .format(image_wcs.pixel_n_dim))

    main_image = ax.imshow(X=data,
                           cmap='plasma',
                           origin='lower',
                           norm=norm)
    cbar = figure.colorbar(main_image)

    # Zoom to a 2*imsize window around the array centre. numpy shapes are
    # (rows, cols) = (y, x), so the x centre comes from the last axis.
    ny, nx = np.shape(data)[-2:]
    centre_x, centre_y = nx / 2., ny / 2.
    ax.set_xlim(centre_x - imsize, centre_x + imsize)
    ax.set_ylim(centre_y - imsize, centre_y + imsize)

    # Invert after setting the limits: set_xlim/set_ylim with ascending
    # bounds would otherwise undo the inversion.
    if is_spitzer:
        ax.invert_xaxis()
        ax.invert_yaxis()

    ax.set_xlabel('Right Ascension J2000', fontdict=font)
    ax.set_ylabel('Declination J2000', fontdict=font)
    cbar.set_label('Surface Brightness (MJy/Sr)', fontdict=font)

    ra = ax.coords[0]
    ra.set_format_unit('degree', decimal=True)
    dec = ax.coords[1]
    dec.set_format_unit('degree', decimal=True)

    # Alternative title used for the Spitzer figure: 'Spitzer 24\u03BCm'
    ax.set_title('8-12 GHz, 5\u03C3 mask', fontdict=font)

    if not is_spitzer:
        # Beam size from the header, drawn at a fixed sky position (degrees, fk5).
        beam = Beam.from_fits_header(hrd)
        c = SphericalCircle((350.34, 61.13) * u.degree,
                            beam.major,
                            edgecolor='black',
                            facecolor='none',
                            transform=ax.get_transform('fk5'))
        ax.add_patch(c)

    if outputfile:
        plt.savefig(os.path.join(os.getcwd(), 'thesis_figs', outputfile))
    plt.show()
예제 #28
0
def data_to_pitch(data_array,
                  pitch_range=(100, 10000),
                  center_pitch=440,
                  zero_point="median",
                  stretch='linear',
                  minmax_percent=None,
                  minmax_value=None,
                  invert=False):
    """
    Map data array to audible pitches in the given range, and apply stretch and scaling
    as required.

    Parameters
    ----------
    data_array : array-like
        Data to map to pitch values. Individual data values should be floats.
        Any array-like (including a plain list) is accepted; it is converted
        with ``np.asarray`` and the result is one-dimensional.
    pitch_range : sequence of two floats
        Optional, default (100, 10000). Range of acceptable pitches in Hz.
    center_pitch : float
        Optional, default 440. The pitch in Hz that the zero point of the
        data will be mapped to.
    zero_point : str or float
        Optional, default "median". The data value that will be mapped to the center
        pitch. Options are "median" (or "med"), "mean" (or "ave"/"average"),
        or a specified data value (float).
    stretch : str
        Optional, default 'linear'. The stretch to apply to the data array.
        Valid values are: asinh, sinh, sqrt, log, linear
    minmax_percent : array
        Optional. Interval based on keeping a specified fraction of data values (can be asymmetric)
        when scaling the data. The format is [lower percentile, upper percentile], where data
        values below the lower percentile and above the upper percentile are clipped.
        Only one of minmax_percent and minmax_value should be specified.
    minmax_value : array
        Optional. Interval based on user-specified data values when scaling the data array.
        The format is [min value, max value], where data values below the min value and above
        the max value are clipped.
        Only one of minmax_percent and minmax_value should be specified.
    invert : bool
        Optional, default False.  If True the pitch array is inverted (low pitches become high
        and vice versa).

    Returns
    -------
    response : array
        The normalized data array, with values in given pitch range. If every
        data value equals the zero point, every entry is ``center_pitch``.

    Raises
    ------
    InvalidInputError
        If ``zero_point`` is an unrecognised string or ``stretch`` is not supported.
    """
    # Accept any array-like: the comparison below needs an ndarray (a list
    # compared to a float yields a plain bool, which has no .all()).
    data_array = np.asarray(data_array)

    # Parsing the zero point
    if isinstance(zero_point, str):
        if zero_point in ("med", "median"):
            zero_point = np.median(data_array)
        elif zero_point in ("ave", "mean", "average"):
            zero_point = np.mean(data_array)
        else:
            raise InvalidInputError("Zero point {} is not supported!".format(zero_point))

    # All values equal the zero point, and the zero point maps to the
    # center pitch, so no further calculation is needed.
    if (data_array == zero_point).all():
        return np.full(data_array.size, center_pitch, dtype=float)

    # Append the zero point so it goes through exactly the same transform.
    data_array = np.append(data_array, zero_point)

    # Setting up the transform with the stretch
    stretches = {
        'asinh': AsinhStretch,
        'sinh': SinhStretch,
        'sqrt': SqrtStretch,
        'log': LogStretch,
        'linear': LinearStretch,
    }
    if stretch not in stretches:
        raise InvalidInputError("Stretch {} is not supported!".format(stretch))
    transform = stretches[stretch]()

    # Adding the scaling to the transform
    if minmax_percent is not None:
        transform += AsymmetricPercentileInterval(*minmax_percent)

        if minmax_value is not None:
            warnings.warn(
                "Both minmax_percent and minmax_value are set, minmax_value will be ignored.",
                InputWarning)
    elif minmax_value is not None:
        transform += ManualInterval(*minmax_value)
    else:  # Default, scale the entire data range to [0, 1]
        transform += MinMaxInterval()

    # Performing the transform and then putting it into the pitch range
    pitch_array = transform(data_array)

    if invert:
        pitch_array = 1 - pitch_array

    # Split the transformed zero point back off the data.
    zero_point = pitch_array[-1]
    pitch_array = pitch_array[:-1]

    # Place the zero point on center_pitch. Scale linearly up from the bottom
    # of the range when the top value (1) then stays within range; otherwise
    # scale linearly from center_pitch up to the top of the range.
    if ((1 / zero_point) *
        (center_pitch - pitch_range[0]) + pitch_range[0]) <= pitch_range[1]:
        pitch_array = (pitch_array / zero_point) * (
            center_pitch - pitch_range[0]) + pitch_range[0]
    else:
        pitch_array = (
            (pitch_array - zero_point) /
            (1 - zero_point)) * (pitch_range[1] - center_pitch) + center_pitch

    return pitch_array
예제 #29
0
파일: plot.py 프로젝트: battyone/PIAA
def show_stamps(pscs,
                frame_idx=None,
                stamp_size=11,
                show_residual=False,
                stretch=None,
                save_name=None,
                show_max=False,
                show_pixel_grid=False,
                **kwargs):
    """Plot target and comparison stamps side by side, optionally with their residual.

    Parameters
    ----------
    pscs : sequence
        ``pscs[0]`` (target) and ``pscs[1]`` (comparison) are plotted;
        ``len(pscs)`` sets the number of subplot columns.
    frame_idx : int, optional
        If given, each stamp is ``pscs[i][frame_idx]``; otherwise ``pscs[i]``.
    stamp_size : int
        Grid size passed to ``add_pixel_grid``.
    show_residual : bool
        Add a third panel showing ``target - comparison``.
    stretch : str, optional
        ``'log'`` for a log stretch; anything else gives a linear stretch.
    save_name : str, optional
        If set, the figure is saved there; save errors are reported with ``warn``.
    show_max, **kwargs
        Unused.
    show_pixel_grid : bool
        Overlay a pixel grid on every panel that is drawn.

    Returns
    -------
    matplotlib.figure.Figure
    """
    ncols = len(pscs) + (1 if show_residual else 0)
    nrows = 1

    fig = Figure()
    FigureCanvas(fig)
    fig.set_figheight(4)
    fig.set_figwidth(8)

    if frame_idx is not None:
        s0 = pscs[0][frame_idx]
        s1 = pscs[1][frame_idx]
    else:
        s0 = pscs[0]
        s1 = pscs[1]

    stretch_method = LogStretch() if stretch == 'log' else LinearStretch()
    # Both stamps share the target's normalisation so their colours are comparable.
    norm = ImageNormalize(s0, interval=MinMaxInterval(), stretch=stretch_method)

    def _add_colorbar(ax, im):
        # Colorbar axes 5% of the width of ``ax`` with 0.05 inch padding, see
        # https://stackoverflow.com/questions/18195758/set-matplotlib-colorbar-size-to-match-graph
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax)

    # Target
    ax1 = fig.add_subplot(nrows, ncols, 1)
    _add_colorbar(ax1, ax1.imshow(s0, cmap=get_palette(), norm=norm))
    ax1.set_title('Target')

    # Comparison
    ax2 = fig.add_subplot(nrows, ncols, 2)
    _add_colorbar(ax2, ax2.imshow(s1, cmap=get_palette(), norm=norm))
    ax2.set_title('Comparison')

    if show_pixel_grid:
        add_pixel_grid(ax1, stamp_size, stamp_size, show_superpixel=False)
        add_pixel_grid(ax2, stamp_size, stamp_size, show_superpixel=False)

    if show_residual:
        ax3 = fig.add_subplot(nrows, ncols, 3)

        # Residual, with its own linear min-max normalisation
        residual = s0 - s1
        residual_norm = ImageNormalize(residual,
                                       interval=MinMaxInterval(),
                                       stretch=LinearStretch())
        _add_colorbar(ax3, ax3.imshow(residual, cmap=get_palette(), norm=residual_norm))
        ax3.set_title('Residual RMS: {:.01%}'.format(residual.std()))
        ax3.set_yticklabels([])
        ax3.set_xticklabels([])

        if show_pixel_grid:
            # The residual panel needs its own grid; ax1 already has one.
            add_pixel_grid(ax3, stamp_size, stamp_size, show_superpixel=False)

    # Turn off tick labels
    for ax in (ax1, ax2):
        ax.set_yticklabels([])
        ax.set_xticklabels([])

    if save_name:
        try:
            fig.savefig(save_name)
        except Exception as e:
            warn("Can't save figure: {}".format(e))

    return fig
예제 #30
0
def show_stamps(psc0,
                psc1,
                frame_idx=0,
                aperture_info=None,
                show_rgb_aperture=True,
                stamp_size=10,
                stretch=None,
                save_name=None,
                show_max=False,
                show_pixel_grid=False,
                cmap='viridis',
                bias_level=2048,
                fig=None,
                **kwargs):
    """Plot target, comparison and relative-residual stamps for one frame.

    Parameters
    ----------
    psc0, psc1 : indexable stamp cubes
        Target and comparison; frame ``frame_idx`` of each is shown.
    frame_idx : int
        Frame to display.
    aperture_info : table, optional
        Aperture description; presumably a pandas DataFrame with a MultiIndex
        (``index.levels[0][frame_idx]`` and ``.loc`` are used) -- confirm.
        When given, the aperture outline is drawn on the first three panels
        and the residual title reports the residual std under the mask.
    show_rgb_aperture : bool
        Reserve a fourth panel showing a Bayer checkerboard with the
        per-colour aperture pixels (drawn only when ``aperture_info`` is given).
    stamp_size : int
        Grid size passed to ``add_pixel_grid`` for the RGB panel.
    stretch : str, optional
        ``'log'`` for a log stretch; anything else gives a linear stretch.
    save_name : str, optional
        If set, the figure is saved there; save errors are reported with ``warn``.
    cmap : str
        Colormap name passed to ``get_palette``.
    bias_level : float
        Lower limit of the shared target/comparison normalisation.
    fig : matplotlib.figure.Figure, optional
        Existing figure whose ``axes`` are reused; otherwise a new
        ``1 x ncols`` figure is created.
    show_max, show_pixel_grid, **kwargs
        Unused.

    Returns
    -------
    matplotlib.figure.Figure
    """
    nrows = 1
    # Target, comparison and residual, plus the optional RGB-aperture panel
    ncols = 4 if show_rgb_aperture else 3

    if fig is None:
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
        fig.set_figheight(8)
        fig.set_figwidth(12)
    else:
        axes = fig.axes

    # Get our stamp index
    s0 = psc0[frame_idx]
    s1 = psc1[frame_idx]

    # Get aperture info index
    frame_aperture = None
    if aperture_info is not None:
        aperture_idx = aperture_info.index.levels[0][frame_idx]
        frame_aperture = aperture_info.loc[aperture_idx]

    stretch_method = LogStretch() if stretch == 'log' else LinearStretch()
    # Target and comparison share one normalisation, from the bias level up to
    # the comparison maximum, so their colours are directly comparable.
    norm = ImageNormalize(s1,
                          vmin=bias_level,
                          vmax=s1.max(),
                          stretch=stretch_method)

    ax1, ax2, ax3 = axes[0], axes[1], axes[2]

    def _attach_colorbar(ax, im):
        """Append a 5%-wide vertical colorbar to the right of ``ax`` and return it."""
        # https://stackoverflow.com/questions/13310594/positioning-the-colorbar
        divider = make_axes_locatable(ax)
        cax = divider.new_horizontal(size="5%", pad=0.05, pack_start=False)
        fig.add_axes(cax)
        return fig.colorbar(im, cax=cax, orientation="vertical")

    def _limit_ticks(cbar, nbins):
        """Place colorbar ticks with a MaxNLocator of ``nbins`` intervals."""
        cbar.locator = ticker.MaxNLocator(nbins=nbins)
        cbar.update_ticks()

    # Target stamp
    im = ax1.imshow(s0, cmap=get_palette(cmap=cmap), norm=norm)
    ax1.set_title('Target', fontsize=16)
    _limit_ticks(_attach_colorbar(ax1, im), nbins=3)

    # Comparison stamp
    im = ax2.imshow(s1, cmap=get_palette(cmap=cmap), norm=norm)
    ax2.set_title('Comparison', fontsize=16)
    _limit_ticks(_attach_colorbar(ax2, im), nbins=3)

    # Residual relative to the comparison stamp (a fraction, not a percentage)
    residual = (s0 - s1) / s1
    residual_norm = ImageNormalize(residual,
                                   interval=MinMaxInterval(),
                                   stretch=LinearStretch())
    im = ax3.imshow(residual, cmap=get_palette(cmap=cmap), norm=residual_norm)
    ax3.set_title('Residual', fontsize=16)  # Replaced below with aperture residual
    _limit_ticks(_attach_colorbar(ax3, im), nbins=5)

    # Show apertures
    if frame_aperture is not None:
        # Make the shapely-based aperture
        aperture_pixels = make_shapely_aperture(aperture_info, aperture_idx)
        # TODO: Sometimes holes are appearing.
        full_aperture = cascaded_union(list(chain.from_iterable(aperture_pixels.values())))

        # A single Polygon has an exterior; a disjoint union is a MultiPolygon
        # whose parts are reached via ``.geoms`` (direct iteration over
        # multi-part geometries is not supported in Shapely 2).
        if hasattr(full_aperture, 'exterior'):
            polygons = [full_aperture]
        else:
            polygons = list(full_aperture.geoms)

        # Plot aperture mask on target, comparison, and residual.
        for poly in polygons:
            xs, ys = poly.exterior.xy
            for ax in (ax1, ax2, ax3):
                ax.fill(xs, ys, fc='none', ec='orange', lw=3)

        # np.ma hides entries where the mask is True, so the std covers pixels
        # where aperture_mask is False (presumably the aperture -- confirm
        # against make_aperture_mask).
        aperture_mask = make_aperture_mask(aperture_info, frame_idx)
        residual_aperture = np.ma.array(data=residual, mask=aperture_mask)

        residual_std = residual_aperture.std()
        # residual is fractional; the '%' format multiplies by 100.
        ax3.set_title(f'Residual {residual_std:.02%}', fontsize=16)

        if show_rgb_aperture:
            ax4 = axes[3]

            # Greyscale checkerboard for the Bayer pattern; float so the
            # 0.1 entries survive even for integer stamps.
            bayer = np.ones_like(s0, dtype=float)
            bayer[1::2, 0::2] = 0.1  # Red
            bayer[1::2, 1::2] = 1  # Green
            bayer[0::2, 0::2] = 1  # Green
            bayer[0::2, 1::2] = 0.1  # Blue
            im = ax4.imshow(bayer, alpha=0.17, cmap='Greys')

            # Transparency goes on the facecolor only (not via `alpha`) so
            # the edges stay opaque.
            alpha_value = 0.75
            color_lookup = {
                'r': (1, 0, 0, alpha_value),
                'g': (0, 1, 0, alpha_value),
                'b': (0, 0, 1, alpha_value),
            }

            # Plot individual pixels of the aperture in their appropriate color.
            for color, box_list in aperture_pixels.items():
                for box in box_list:
                    xs, ys = box.exterior.xy
                    ax4.fill(xs, ys, fc=color_lookup[color], ec='k', lw=3)

            add_pixel_grid(ax4, stamp_size, stamp_size, show_superpixel=True)

            ax4.set_title('RGB Pattern', fontsize=16)
            ax4.set_yticklabels([])
            ax4.set_xticklabels([])
            ax4.grid(False)

            # Blank colorbar so this panel is laid out like the others.
            # TODO: keep sizes but get rid of colorbar
            cbar = _attach_colorbar(ax4, im)
            cbar.ax.set_xticklabels([])
            cbar.ax.set_yticklabels([])

    # Turn off tick labels and grids
    for ax in (ax1, ax2, ax3):
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        ax.grid(False)

    fig.subplots_adjust(wspace=0.3)

    if save_name:
        try:
            fig.savefig(save_name)
        except Exception as e:
            warn("Can't save figure: {}".format(e))

    return fig