예제 #1
0
    def _get_projection(self):
        """Return the WCS of the cutout image and store its display limits.

        If ``self.hdu`` is already set, its header and data are used directly.
        Otherwise the image is requested from SkyView (ICRS coordinates,
        ``self.survey``, ``self.radius`` arcsec, ``self.grid``). If that
        attempt raises IndexError or HTTPError, one retry is made with the
        radius reduced by 1 arcsec. On every path ``self.vlim`` is set to the
        ``PercentileInterval(99.)`` limits of the image data.

        NOTE(review): a fetched HDU is not stored on ``self.hdu``, so every
        call made while ``self.hdu is None`` queries SkyView again.
        """
        if self.hdu is not None:
            wcs = WCS(self.hdu.header)
            self.vlim = PercentileInterval(99.).get_limits(self.hdu.data)
            return wcs

        def fetch(radius):
            # First HDU of the first image SkyView returns for this query.
            return SkyView.get_images(position=self.coord,
                                      coordinates='icrs',
                                      survey=self.survey,
                                      radius=radius,
                                      grid=self.grid)[0][0]

        def project(fetched):
            # Build the WCS before touching self.vlim, so a failure here
            # leaves vlim unchanged and can still trigger the retry.
            wcs = WCS(fetched.header)
            self.vlim = PercentileInterval(99.).get_limits(fetched.data)
            return wcs

        try:
            return project(fetch(self.radius * u.arcsec))
        except (IndexError, HTTPError):
            return project(fetch(self.radius * u.arcsec - 1 * u.arcsec))
예제 #2
0
    def figure_final_before_s(self, data):
        """Redraw the figure with *data* as a square-root-stretched grey image.

        The display range is the central 50% percentile interval of *data*
        (``PercentileInterval(50.)``).

        Parameters
        ----------
        data : array-like
            Image to display.
        """
        self._figure.clf()
        ax = self._figure.add_subplot(111)
        ax.set_title('Result image')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')

        z1, z2 = PercentileInterval(50.).get_limits(data)
        norm = ImageNormalize(vmin=z1, vmax=z2, stretch=SqrtStretch())
        # The norm already carries vmin/vmax, so no separate ``clim`` is
        # needed. The colormap is given by name because
        # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9.
        ax.imshow(data, cmap='gray', norm=norm)
        self._figure.canvas.draw()
예제 #3
0
    def figure_image(self, thedata, image):
        """Swap new pixel data into the figure's image and rescale it.

        The new display range is the central 50% percentile interval of
        *thedata*, shown with a square-root stretch. The axes title reports
        the resulting colour limits.

        Parameters
        ----------
        thedata : array-like
            New data for the existing image artist.
        image :
            Object whose ``lastname`` attribute is shown in the title.
        """
        ax = self._figure.gca()
        # The single-element unpack requires exactly one image on the axes.
        im, = ax.get_images()
        im.set_data(thedata)

        z1, z2 = PercentileInterval(50.).get_limits(thedata)
        # set_norm replaces the norm wholesale, and the norm's vmin/vmax
        # define the colour limits, so no separate set_clim call is needed.
        im.set_norm(ImageNormalize(vmin=z1, vmax=z2, stretch=SqrtStretch()))
        clim = im.get_clim()
        # The stretch is SqrtStretch, so label the scale accordingly.
        ax.set_title('%s, bg=%g fg=%g, sqrtscale' %
                     (image.lastname, clim[0], clim[1]))
        self._figure.canvas.draw()
예제 #4
0
 def showFittedStars(self, perc_interval=95, aperture_radius=7):
     """Display the image and overlay circular apertures on the fitted stars.

     Parameters
     ----------
     perc_interval : float
         Percentile handed to ``PercentileInterval`` for the display
         normalization used by ``sandbox.showNorm``.
     aperture_radius : float
         Radius given to ``CircularAperture`` (pixels, presumably).
     """
     from astropy.visualization import PercentileInterval
     stars = self._starsTab
     markers = CircularAperture((stars['x_fit'], stars['y_fit']),
                                r=aperture_radius)
     sandbox.showNorm(self._image, interval=PercentileInterval(perc_interval))
     markers.plot()
예제 #5
0
def plot_tramlines(image_data, tramlines, tramlines_bg=None):
    """
    Displays image data with the tramline extraction regions using the viridis colour map, and
    the remainder in grey.

    Parameters
    ----------
    image_data : array-like
        Image to display.
    tramlines : iterable
        Index expressions (anything valid in ``mask[idx] = False``) that
        select the extraction regions. These pixels are drawn with
        ``'viridis_r'`` on top.
    tramlines_bg : iterable, optional
        Index expressions selecting background regions. When truthy, only
        those regions are drawn in grey. Otherwise the whole image is drawn
        in grey underneath the tramlines.
    """
    # One shared normalization so both layers use the same colour scale.
    norm = ImageNormalize(image_data,
                          interval=PercentileInterval(99.5),
                          stretch=LinearStretch(),
                          clip=False)
    spectrum_data = np.ma.array(image_data, copy=True)
    # Start with everything masked, then unmask the tramline regions.
    # Builtin ``bool`` is used because the ``np.bool`` alias was removed in
    # NumPy 1.24.
    tramline_mask = np.ones(spectrum_data.shape, dtype=bool)
    for tramline in tramlines:
        tramline_mask[tramline] = False
    spectrum_data[tramline_mask] = np.ma.masked

    fig = plt.figure(figsize=(15, 6), tight_layout=True)
    ax1 = fig.add_subplot(1, 1, 1)
    ax1.set_aspect('equal')

    if tramlines_bg:
        background_data = np.ma.array(image_data, copy=True)
        background_mask = np.ones(background_data.shape, dtype=bool)
        for tramline_bg in tramlines_bg:
            background_mask[tramline_bg] = False
        background_data[background_mask] = np.ma.masked
        ax1.imshow(background_data, cmap='gray_r', norm=norm, origin='lower')
    else:
        ax1.imshow(image_data, cmap='gray_r', norm=norm, origin='lower')

    spectrum_image = ax1.imshow(spectrum_data,
                                cmap='viridis_r',
                                norm=norm,
                                origin='lower')
    fig.colorbar(spectrum_image)
    plt.show()
예제 #6
0
def plot_superbias(superbias,
                   nbiases,
                   dataname,
                   title=None,
                   outname=None,
                   show_plots=False):
    """Plot a superbias frame and a histogram of its pixel values.

    Parameters
    ----------
    superbias : numpy.ndarray
        Superbias image in ADU (``.flat`` and ``np.rot90`` are applied to it).
    nbiases : int
        Number of exposures combined; used only in the default title.
    dataname : str
        Data-set label; used only in the default title.
    title : str, optional
        Figure title. A descriptive default is built when None.
    outname : str, optional
        If given, the figure is saved to this path.
    show_plots : bool
        If true, ``plt.show()`` is called before the figure is closed.
    """
    if title is None:
        title = f'Superbias: mean of {nbiases} exposures post-PCA oscan in {dataname}'
    fig = plt.figure(dpi=150, facecolor='white')

    # Top panel: rotated image, scaled to the central 99% of pixel values.
    plt.subplot(211)
    norm = ImageNormalize(superbias, interval=PercentileInterval(99))
    plt.imshow(np.rot90(superbias), norm=norm, cmap='gist_heat')
    # The colour bar carries the ADU unit; the image x-axis is in pixels.
    plt.colorbar(orientation='horizontal', label='ADU')

    # Bottom panel: log-scaled histogram of the pixel values.
    plt.subplot(212)
    plt.hist(superbias.flat, bins=100)
    plt.xlabel('ADU')
    plt.ylabel('Bin count')
    plt.yscale('log')
    plt.suptitle(title, wrap=True)

    if outname is not None:
        plt.savefig(outname)
    if show_plots:
        plt.show()
    plt.close(fig)
예제 #7
0
    def set_compass(self, image):
        """Redraw the Compass plugin for the given image Data object.

        Returns immediately when ``self.compass`` is None. The zoom box starts
        as the viewer state's x/y limits. If the image has a valid WCS and is
        WCS-linked, the box is converted from reference-data pixels into this
        image's pixels. The compass is drawn with a linear stretch over the
        95% percentile interval of the image's first main component.
        """
        if self.compass is None:  # Maybe another viewer has it
            return

        state = self.state
        zoom_limits = (state.x_min, state.y_min, state.x_max, state.y_max)

        has_wcs = data_has_valid_wcs(image)
        wcs = image.coords if has_wcs else None
        if has_wcs and self.get_link_type(image.label) == 'wcs':
            # Map the zoom box: reference pixels -> world -> this image's pixels.
            world = state.reference_data.coords.pixel_to_world(
                (state.x_min, state.x_max),
                (state.y_min, state.y_max))
            pixels = wcs.world_to_pixel(world)
            xs, ys = pixels[0], pixels[1]
            zoom_limits = (xs[0], ys[0], xs[1], ys[1])

        arr = image[image.main_components[0]]
        lo, hi = PercentileInterval(95).get_limits(arr)
        compass_norm = ImageNormalize(vmin=lo, vmax=hi, stretch=LinearStretch())
        rendered = wcs_utils.draw_compass_mpl(arr,
                                              wcs,
                                              show=False,
                                              zoom_limits=zoom_limits,
                                              norm=compass_norm)
        self.compass.draw_compass(image.label, rendered)
예제 #8
0
def get_image_datatab(ra, dec, width):
    """Fetch a PS1 i-band cutout, overplot PS1 catalog sources and save outputs.

    Writes ``test_run.jpg`` (the plot), ``test_run.fits`` (the cutout with
    NaNs replaced by 0) and ``test_run.csv`` (the catalog) to the current
    directory.

    Parameters
    ----------
    ra, dec : float
        Cutout centre, passed to ``geturl`` and ``panstarrs_query_pos``.
        Presumably degrees (the catalog coordinates use ``unit='deg'``);
        confirm.
    width : float
        Cutout size. ``int(width * 240)`` pixels are requested.

    Returns
    -------
    datatab
        Catalog table returned by ``panstarrs_query_pos``.
    """
    fitsurl = geturl(ra,
                     dec,
                     size=int(width * 240),
                     filters="i",
                     format="fits")
    # Use a context manager so the file handle is released. The data is
    # copied so it stays valid (and writable) after the file is closed.
    with fits.open(fitsurl[0]) as fh:
        fhead = fh[0].header
        fim = fh[0].data.copy()
    wcs = WCS(fhead)
    # replace NaN values with zero for display
    fim[numpy.isnan(fim)] = 0.0
    # set contrast to something reasonable
    transform = AsinhStretch() + PercentileInterval(90)
    bfim = transform(fim)

    # query PS1 catalog and place 0.5" apertures at each source position
    datatab = panstarrs_query_pos(ra, dec, width)
    positions = SkyCoord(ra=datatab['raMean'],
                         dec=datatab['decMean'],
                         unit='deg')
    aper = SkyCircularAperture(positions, 0.5 * u.arcsec)
    pix_aperture = aper.to_pixel(wcs)

    fig = plt.figure()
    fig.add_subplot(111, projection=wcs)
    plt.imshow(bfim, cmap='Greys', origin='lower')
    pix_aperture.plot(color='blue', lw=0.5, alpha=0.5)
    plt.xlabel('RA')
    plt.ylabel('Dec')
    plt.savefig('test_run.jpg', dpi=1000)
    fits.writeto('test_run.fits', fim, fhead, overwrite=True)
    ascii.write(datatab, 'test_run.csv', format='csv', fast_writer=False)
    return datatab
예제 #9
0
    def exposeSF(self):
        """Acquire a single frame and display the FITS file it produces.

        Sends ``ACQUIRESINGLEFRAME`` over ``self.s``, shows the reply in
        ``self.l1``, finds the file newly created under
        ``path_to_watch + "/Reference"``, optionally applies
        ``channelrefcorrect``, and displays it with a linear stretch over the
        99.5% percentile interval.
        """
        watchpath = path_to_watch+"/Reference"
        before = set(os.listdir(watchpath))
        print("Acquiring Single Frame")

        # NOTE(review): on Python 3, socket.send needs bytes. If self.s is a
        # plain socket, this str argument must be encoded; confirm.
        self.s.send("ACQUIRESINGLEFRAME")
        response = self.s.recv(buffersize)
        print(response)
        self.l1["text"] = response

        # New entries are those present now that were absent before.
        added = [f for f in os.listdir(watchpath) if f not in before]
        self.status["text"] = "READY"
        self.status["bg"] = "green"

        print("Added File: "+added[0])

        # ``* 1.0`` makes an in-memory float copy, so the file can be closed
        # immediately; the with-block also closes it if reading fails.
        with fits.open(watchpath+"/"+added[0]) as hdu:
            image = hdu[0].data*1.0
        # NOTE(review): if correctData is a Tk variable object, this truth
        # test is always True and ``.get()`` is needed; confirm.
        if self.correctData:
            channelrefcorrect(image)

        norm = ImageNormalize(image, interval=PercentileInterval(99.5),
                              stretch=LinearStretch())

        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        im = ax.imshow(image, origin='lower', norm=norm, interpolation='none')
        ax.format_coord = Formatter(im)
        ax.set_title(added[0])
        fig.colorbar(im)
예제 #10
0
    def takeImage(self):
        """Take an exposure and display it.

        The exposure time is read from ``self.entryExpVariable`` (converted
        with ``int``). For the 'Dark' image type a dark frame is taken, and
        the camera is switched back to normal frames afterwards.

        Returns
        -------
        The image from ``self.cam.take_photo()``, or None when the camera or
        the focuser is not connected.
        """
        img = None  # returned as-is when the camera/focuser is unavailable
        if self.cam and self.foc:
            is_dark = self.imgtypeVariable.get() == 'Dark'
            self.cam.end_exposure()
            exptime = int(self.entryExpVariable.get())
            if is_dark:
                self.cam.set_exposure(exptime, frametype='dark')
                img = self.cam.take_photo()
                # Restore normal frames so later exposures are not darks.
                self.cam.set_exposure(exptime, frametype='normal')
            else:
                self.cam.set_exposure(exptime, frametype='normal')
                img = self.cam.take_photo()

            mpl.close()
            fig = mpl.figure()
            ax = fig.add_subplot(1, 1, 1)

            norm = ImageNormalize(img,
                                  interval=PercentileInterval(99.9),
                                  stretch=LinearStretch())

            im = ax.imshow(img,
                           interpolation='none',
                           norm=norm,
                           cmap='gray',
                           origin='lower')
            ax.format_coord = Formatter(im)
            fig.colorbar(im)
            mpl.show()
        return img
예제 #11
0
def bandsMerge2tif(dataPath, dataName, savePath, saveName, stretchFlag):
    """Stack single-band GeoTIFFs into one multi-band GeoTIFF.

    All files matching ``dataPath + dataName + '*.tif'`` are read in sorted
    path order. The projection and geotransform of the first file are used
    for the output ``savePath + saveName + '.tif'``.

    Parameters
    ----------
    dataPath, dataName : str
        Concatenated to form the glob prefix.
    savePath, saveName : str
        Output directory (created if missing) and file stem.
    stretchFlag : bool
        If true, each band is rescaled through ``PercentileInterval(95.)``
        before stacking.
    """
    run = GRID()

    bandNameList = []
    interval = PercentileInterval(95.)

    # glob order is filesystem-dependent; sort for a reproducible band order.
    fileNameList = sorted(glob.glob(dataPath + "{}*.tif".format(dataName)))
    num = len(fileNameList)

    proj, geotrans, data = run.read_data(fileNameList[0])  # read data
    # NOTE(review): the band name is the text between the first and second
    # '.' of the whole path, which breaks if dataPath itself contains a '.'.
    bandNameList.append(fileNameList[0].split(".")[1])
    if stretchFlag:
        data = interval(data)

    if num == 1:
        DATA = data
    else:
        DATA = np.zeros([num, data.shape[0], data.shape[1]])
        DATA[0, ...] = data

        for i, fileName in enumerate(fileNameList[1:], start=1):
            _, _, data = run.read_data(fileName)  # read data
            bandNameList.append(fileName.split(".")[1])
            DATA[i, ...] = interval(data) if stretchFlag else data

    if not os.path.exists(savePath):
        os.makedirs(savePath)
    print("{}: {}".format(saveName, bandNameList))
    run.write_data(savePath + saveName + ".tif", proj, geotrans, DATA,
                   bandNameList)
예제 #12
0
def calculate_movie_normalization(mc, percentile_interval=99.0, stretch=None):
    """
    A convenience function that calculates an image normalization
    that means a movie of the input mapcube will not flicker.
    Assumes that all layers are similar and the stretch function
    for all layers is the same

    Parameters
    ----------
    mc : `sunpy.map.MapCube`
        a sunpy mapcube

    percentile_interval : float
        Width, in percent, of the central interval of the pooled pixel
        values (from all maps) that sets vmin and vmax.

    stretch :
        image stretch function. If None, the stretch of the first map's
        ``plot_settings['norm']`` is used when available, otherwise None is
        passed to ``ImageNormalize``.

    Returns
    -------
    An image normalization setting that can be used with all the images
    in the mapcube ensuring no flickering in a movie of the images.
    """
    # Pool every map's pixels so one set of limits fits all frames.
    # ravel avoids flatten's extra copy; concatenate copies anyway.
    data = np.concatenate([m.data.ravel() for m in mc])
    vmin, vmax = PercentileInterval(percentile_interval).get_limits(data)
    if stretch is None:
        try:
            stretcher = mc[0].plot_settings['norm'].stretch
        except (AttributeError, KeyError):
            # No 'norm' entry, or a norm object without a ``stretch``.
            stretcher = None
    else:
        stretcher = stretch
    return ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretcher)
예제 #13
0
    def saveImage(self):
        """Take an exposure, save it as a timestamped FITS file, and display it.

        Does nothing when no camera is connected. The exposure time comes
        from ``self.entryExpVariable``. For the 'Dark' image type a dark frame
        is taken, and the camera is switched back to normal frames afterwards.
        The FITS header is built by ``self.makeHeader`` from telemetry read
        with ``WG.get_telemetry(self.telSock)``.

        The file path is ``self.direc + self.todaydate + 'T' + HHMMSS``. If
        ``self.entryFilepathVariable`` is non-empty, ``'_' + <entry>`` is
        appended. The path ends in ``'.fits'``, and an existing file at that
        path is overwritten.
        """
        if not self.cam:
            return

        is_dark = self.imgtypeVariable.get() == 'Dark'
        self.cam.end_exposure()
        exptime = int(self.entryExpVariable.get())
        if is_dark:
            self.cam.set_exposure(exptime, frametype='dark')
            img = self.cam.take_photo()
            # Restore normal frames so later exposures are not darks.
            self.cam.set_exposure(exptime, frametype='normal')
        else:
            self.cam.set_exposure(exptime, frametype='normal')
            img = self.cam.take_photo()

        telemDict = WG.get_telemetry(self.telSock)
        hduhdr = self.makeHeader(telemDict)

        # Build the name once. Separate time.strftime calls for the message
        # and the write could produce two different timestamps.
        filename = self.direc + self.todaydate + 'T' + time.strftime('%H%M%S')
        label = self.entryFilepathVariable.get()
        if label != "":
            filename += '_' + label
        filename += '.fits'
        # print() with a single argument behaves the same on Python 2 and 3.
        print("Writing to: " + filename)
        # ``overwrite`` supersedes astropy's deprecated ``clobber`` keyword.
        fits.writeto(filename, img, hduhdr, overwrite=True)

        mpl.close()
        fig = mpl.figure()
        ax = fig.add_subplot(1, 1, 1)

        norm = ImageNormalize(img,
                              interval=PercentileInterval(99.9),
                              stretch=LinearStretch())

        im = ax.imshow(img,
                       interpolation='none',
                       norm=norm,
                       cmap='gray',
                       origin='lower')
        ax.format_coord = Formatter(im)
        fig.colorbar(im)
        mpl.show()
예제 #14
0
 def percent_int(self):
     """Switch the display interval to a percentile interval.

     Sets ``self.interval`` to ``PercentileInterval(self.percent)``. Checks
     ``rbtn7`` and unchecks ``rbtn5``, ``rbtn6``, ``rbtn8`` and ``rbtn9``.
     Then refreshes the normalization and prints the percentage.
     """
     self.interval = PercentileInterval(self.percent)
     for button, checked in ((self.rbtn5, False),
                             (self.rbtn6, False),
                             (self.rbtn7, True),
                             (self.rbtn8, False),
                             (self.rbtn9, False)):
         button.setChecked(checked)
     self.refresh_norm()
     print('Percent =', self.percent)
예제 #15
0
def tifBand2png_GDAL(dataPath, dataName, savePath, saveName, pngSretch):
    """Render a single-band GeoTIFF as a percentile-stretched grey PNG.

    Reads ``dataPath + dataName + '.tif'`` and writes
    ``savePath + saveName + '.png'``, creating ``savePath`` if needed.

    Parameters
    ----------
    pngSretch : float
        Percentile passed to ``PercentileInterval``. It is ignored (100 is
        used, i.e. a plain min-max rescale) when ``dataName`` contains
        'water'.

    Raises
    ------
    ValueError
        If the raster read is not 2-D.
    """
    run = GRID()
    # 'water' rasters are only min-max rescaled, never percentile-clipped.
    interval = PercentileInterval(100 if 'water' in dataName else pngSretch)

    _, _, data = run.read_data(dataPath + dataName + ".tif")  # read data

    if len(data.shape) != 2:
        raise ValueError(
            "expected a 2-D single-band raster, got shape {}".format(
                data.shape))
    jpg_data = interval(data)

    if not os.path.exists(savePath):
        os.makedirs(savePath)
    print("{}".format(saveName))
    plt.imsave(savePath + saveName + ".png", jpg_data, cmap="gray")
예제 #16
0
def percentile(img: np.ndarray, percentile: int) -> Tuple[float, float]:
    """Return the ``(vmin, vmax)`` range holding a central share of the pixels.

    Parameters
    ----------
    img
        Image array. It is flattened before the limits are computed.
    percentile
        Width, in percent, of the central range, as understood by
        ``astropy.visualization.PercentileInterval``.

    Returns
    -------
    Tuple[float, float]
        Lower and upper limits of that range.
    """
    interval = PercentileInterval(percentile)
    lower, upper = interval.get_limits(img.ravel())
    return lower, upper
예제 #17
0
def preview_image(HDU):
    """Return a figure previewing an image HDU.

    The data are passed through ``PercentileInterval(90)``, then drawn with
    a log-stretch normalization and the 'Blues_r' colormap. The axes use
    the WCS built from the HDU header, and a colour bar is attached.
    """
    from astropy.visualization import quantity_support, PercentileInterval, LogStretch
    from astropy.visualization.mpl_normalize import ImageNormalize
    from astropy.wcs import WCS

    with quantity_support():
        fig = plt.figure()
        sky_axes = fig.add_subplot(1, 1, 1, projection=WCS(HDU.header))
        scaled = PercentileInterval(90)(HDU.data)
        log_norm = ImageNormalize(stretch=LogStretch())
        mappable = sky_axes.imshow(scaled, norm=log_norm, cmap='Blues_r')
        fig.colorbar(mappable, ax=sky_axes)
    return fig
def xmkpy3_finder_chart_survey_fits_image_get_v1():
    """Demo: draw a 2MASS-J finder chart around Kepler-138b and save it.

    Downloads the Kepler quarter-10 long-cadence target pixel file for
    Kepler-138b with lightkurve. Fetches a 2MASS-J survey image at its
    coordinates (``radius_arcmin=3.0``) through
    ``mkpy3_finder_chart_survey_fits_image_get_v1``. Plots the image with a
    square-root stretch, marks the target with a yellow ring, and writes
    ``mkpy3_plot.png``.
    """
    import lightkurve as lk
    lk.log.setLevel('INFO')
    import matplotlib.pyplot as plt
    import astropy.units as u
    from astropy.visualization import ImageNormalize, PercentileInterval, SqrtStretch
    import os
    import ntpath

    # Exoplanet Kepler-138b is "KIC 7603200":
    search = lk.search_targetpixelfile(target='kepler-138b', mission='kepler',
                                       cadence='long', quarter=10)
    tpf = search.download(quality_bitmask=0)
    print('TPF filename:', ntpath.basename(tpf.path))
    print('TPF dirname: ', os.path.dirname(tpf.path))

    target = 'Kepler-138b'
    ra_deg, dec_deg = tpf.ra, tpf.dec

    # Survey image data around the target.
    width_height_arcmin = 3.00
    survey = '2MASS-J'
    (survey_hdu, survey_hdr, survey_data, survey_wcs,
     survey_cframe) = mkpy3_finder_chart_survey_fits_image_get_v1(
        ra_deg, dec_deg, radius_arcmin=width_height_arcmin, survey=survey,
        verbose=True)

    # New figure; plt.subplot then draws RA/Dec axes from the survey WCS
    # into it.
    plt.figure(figsize=(12, 12))
    ax = plt.subplot(projection=survey_wcs)

    sqrt_norm = ImageNormalize(survey_data, interval=PercentileInterval(99.0),
                               stretch=SqrtStretch())
    ax.imshow(survey_data, origin='lower', norm=sqrt_norm, cmap='gray_r')

    ax.set_xlabel('Right Ascension (J2000)')
    ax.set_ylabel('Declination (J2000)')
    ax.set_title('')
    plt.suptitle(target)

    # Yellow ring at the target, positioned in the survey's celestial frame.
    ax.scatter(ra_deg * u.deg, dec_deg * u.deg,
               transform=ax.get_transform(survey_cframe),
               s=600, edgecolor='yellow', facecolor='None', lw=3, zorder=100)

    pname = 'mkpy3_plot.png'
    if pname:
        plt.savefig(pname, bbox_inches="tight")
        print(pname, ' <--- plot filename has been written!  :-)\n')
예제 #19
0
    def exposeCDS(self):
        """Acquire a CDS frame, record observation data, and display the result.

        Sends ``ACQUIRECDS`` over ``self.s`` and finds the directory newly
        created under ``path_to_watch + "/CDSReference"``. Writes observation
        data for it, then displays its ``Result/CDSResult.fits`` with a linear
        stretch over the 99.5% percentile interval. Optionally also shows a
        histogram of the pixel block ``[900:1100, 900:1100]`` and runs
        ``pointing`` on the result file.
        """
        watchpath = path_to_watch + "/CDSReference"
        before = set(os.listdir(watchpath))
        print("Acquiring CDS Frame")

        self.s.send("ACQUIRECDS")
        response = self.s.recv(buffersize)
        print(response)

        self.l1["text"] = response

        # New entries are those present now that were absent before.
        added = [f for f in os.listdir(watchpath) if f not in before]
        self.status["text"] = "READY"
        self.status["bg"] = "green"

        print("Added Directory: " + added[0] + ' , ' + self.sourcename.get())

        self.writeObsdata(watchpath + '/' + added[0])

        result_path = watchpath + "/" + added[0] + "/Result/CDSResult.fits"
        # ``* 1.0`` makes an in-memory float copy, so the file can be closed
        # right away; the with-block also closes it if reading fails.
        with fits.open(result_path) as hdu:
            image = hdu[0].data * 1.0
        if self.correctData.get():
            channelrefcorrect(image)

        norm = ImageNormalize(image,
                              interval=PercentileInterval(99.5),
                              stretch=LinearStretch())

        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        im = ax.imshow(image, origin='lower', norm=norm, interpolation='none')
        ax.format_coord = Formatter(im)
        ax.set_title(added[0])
        fig.colorbar(im)

        if self.histogram.get():
            plt.figure()
            subimage = image[900:1100, 900:1100].flatten()
            mean = np.mean(subimage)
            std = np.std(subimage)
            plt.hist(subimage, bins=200)
            plt.xlim([mean - 3 * std, mean + 3 * std])
            # "%.5f" gives five decimals. "%f5" would print six decimals
            # followed by a literal '5'.
            plt.title("Mean = %.5f, Std = %.5f" % (mean, std))
            plt.show()

        if self.arcfwhm.get():
            pointing(result_path)
예제 #20
0
 def __init__(self, image):
     """Show *image* and start an interactive star-selection session.

     The image is displayed through ``showNorm`` with a 99.5% percentile
     interval. Mouse press and release events on the new figure are routed
     to ``self._press`` and ``self._release``. The on-screen title tells the
     user to left-click stars and right-click to finish. ``plt.show()`` is
     called before ``self._start()``.
     """
     self._ima = image
     self.fig = plt.figure()
     showNorm(image, interval=PercentileInterval(99.5))
     plt.title("Left click to select a star. End with right click.")
     self.ax = self.fig.gca()
     canvas = self.fig.canvas
     canvas.mpl_connect("button_press_event", self._press)
     canvas.mpl_connect("button_release_event", self._release)
     self.goon = True
     self.circles = []
     plt.show()
     self._start()
예제 #21
0
def plot(ra, dec, gal_list, s_list, ser_list, catalogue, search_size, filename):
    """Plot a PS1 cutout around (ra, dec) with candidate sources, and save it.

    Parameters
    ----------
    ra, dec : float
        Cutout centre in degrees (offsets use ``/180*np.pi`` on dec).
    gal_list : sequence
        Galaxy entries. For ``catalogue == 'glade'`` each needs ``'ra'``,
        ``'dec'`` and ``'z'`` keys and is drawn as a blue cross with an index
        and redshift label. For ``'texas'`` the list goes to ``plot_ellipse``.
    s_list : iterable
        Point sources with ``'raMean'``/``'decMean'`` keys (red crosses).
    ser_list : sequence
        Further sources passed to ``plot_ellipse`` in magenta.
    catalogue : str
        'glade' or 'texas'. Any other value draws no galaxies.
    search_size : float
        The cutout is ``240 * search_size`` pixels on a side. The
        ``4*3600`` degree-to-pixel factor below implies 0.25 arcsec/pixel.
    filename : str
        Output path for ``plt.savefig``.
    """
    size = 240*search_size

    fitsurl = geturl(ra, dec, size=size, filters=texas_cfg['filters'], format="fits")
    # Close the FITS file once read. The copy keeps the pixel array (edited
    # below) valid after closing.
    with fits.open(fitsurl[0]) as fh:
        wcs = WCS(fh[0].header)
        fim = fh[0].data.copy()
    fim[np.isnan(fim)] = 0.
    transform = AsinhStretch() + PercentileInterval(99)
    bfim = transform(fim)

    # Create a single WCS-projected axes. Calling subplots() and then
    # plt.subplot(projection=wcs) would leave a stray empty axes behind it.
    fig = plt.figure(figsize=(search_size*3, search_size*3))
    ax = fig.add_subplot(1, 1, 1, projection=wcs)
    ax.imshow(bfim, cmap='summer')
    ax.set_xlim(0, size)
    ax.set_ylim(0, size)

    k = 1
    if len(gal_list) > 0:
        last = gal_list[0]
    if catalogue == 'glade':
        for j in gal_list:
            # Draw the first entry. After that, draw an entry only when it is
            # more than 0.005 deg (RA*cos(dec) or Dec) from the entry before.
            cos_dec = np.cos(j['dec']/180*np.pi)
            if k == 1 or (abs(j['ra']-last['ra'])*cos_dec > 0.005 or abs(j['dec']-last['dec']) > 0.005):
                # Degrees -> pixels, relative to the cutout centre (size/2).
                x = (ra-j['ra'])*4*3600*cos_dec + (size/2)
                y = (j['dec']-dec)*4*3600 + (size/2)
                ax.plot(x, y, 'bx', label='glade source')
                # Truncate the redshift to 4 decimals for the label.
                z = int(float(j['z'])*10000)/10000.
                ax.annotate(str(k) + ': z='+str(z), xy=(x+20, y-5), fontsize=15, ha="center", color='b')
            k = k+1
            last = j
    elif catalogue == 'texas':
        plot_ellipse(gal_list, ra, dec, size, 'k', ax)

    for s in s_list:
        cos_dec = np.cos(s['decMean']/180*np.pi)
        x = (ra-s['raMean'])*4*3600*cos_dec + (size/2)
        y = (s['decMean']-dec)*4*3600 + (size/2)
        ax.plot(x, y, 'rx', label='PS point source')

    if len(ser_list) > 0:
        plot_ellipse(ser_list, ra, dec, size, 'm', ax)

    # Crosshair: four short ticks around the centre with a 10-pixel gap.
    centre = size/2
    ax.plot([centre+10, centre+20], [centre, centre], 'k')
    ax.plot([centre, centre], [centre+10, centre+20], 'k')
    ax.plot([centre-10, centre-20], [centre, centre], 'k')
    ax.plot([centre, centre], [centre-10, centre-20], 'k')

    plt.savefig(filename)
예제 #22
0
def band2png(data, savePath, saveName, pngSretch):
    """Save a 2-D array as a percentile-stretched grey PNG.

    Writes ``savePath + saveName + '.png'``, creating ``savePath`` if needed.

    Parameters
    ----------
    data : numpy.ndarray
        2-D band to render.
    savePath, saveName : str
        Output directory prefix and file stem.
    pngSretch : float
        Percentile passed to ``PercentileInterval`` for the rescale.

    Raises
    ------
    ValueError
        If ``data`` is not 2-D. Without this check, the later use of the
        rescaled array would fail with an UnboundLocalError.
    """
    if len(data.shape) != 2:
        raise ValueError(
            "expected a 2-D array, got shape {}".format(data.shape))
    jpg_data = PercentileInterval(pngSretch)(data)

    if not os.path.exists(savePath):
        os.makedirs(savePath)
    print("{}".format(saveName))
    plt.imsave(savePath + saveName + ".png", jpg_data, cmap="gray")
예제 #23
0
def plot_image(image,
               scale='linear',
               origin='lower',
               xlabel='Pixel Column Number',
               ylabel='Pixel Row Number',
               clabel='Flux ($e^{-}s^{-1}$)',
               title=None,
               **kwargs):
    """Utility function to plot a 2D image.

    The colour range is the central 95% percentile interval of *image*.

    Parameters
    ----------
    image : 2d array
        Image data.
    scale : str
        Scale used to stretch the colormap.
        Options: 'linear', 'sqrt', or 'log'.
    origin : str
        The origin of the coordinate system.
    xlabel : str
        Label for the x-axis.
    ylabel : str
        Label for the y-axis.
    clabel : str
        Label for the color bar.
    title : str or None
        Title for the plot.
    kwargs : dict
        Keyword arguments to be passed to `matplotlib.pyplot.imshow`.

    Returns
    -------
    fig, ax
        The created figure and axes.

    Raises
    ------
    ValueError
        If *scale* is not one of the supported options. This is checked
        before any figure is created, so no empty figure is left open.
    """
    if scale not in ('linear', 'sqrt', 'log'):
        raise ValueError("scale {} is not available.".format(scale))
    stretch = {'linear': LinearStretch,
               'sqrt': SqrtStretch,
               'log': LogStretch}[scale]()

    fig, ax = plt.subplots()
    vmin, vmax = PercentileInterval(95.).get_limits(image)
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)

    cax = ax.imshow(image, origin=origin, norm=norm, **kwargs)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # The colour bar takes its norm from the image, so none is passed here.
    fig.colorbar(cax, label=clabel)
    return fig, ax
예제 #24
0
def plot_image(image,
               cmap='gray',
               title='Image',
               input_ratio=None,
               interval_type='zscale',
               stretch='linear',
               percentile=99,
               xlim_minus=0,
               xlim_plus=1000,
               ylim_minus=0,
               ylim_plus=1000):
    """Display a cropped *image* with an astropy normalization and colour bar.

    The normalization is computed from the full image, and only the crop is
    drawn.

    Parameters
    ----------
    image : numpy.ndarray
        Image to display.
    cmap : str
        Matplotlib colormap name.
    title : str
        Axes title.
    input_ratio :
        Unused; accepted for backward compatibility.
    interval_type : {'zscale', 'percentile', 'minmax', 'simple_norm'}
        How the display range is chosen.
    stretch : str
        Stretch name. Used only when ``interval_type == 'simple_norm'``.
    percentile : float
        Percentile used when ``interval_type == 'percentile'``.
    xlim_minus, xlim_plus, ylim_minus, ylim_plus : int
        Crop bounds. NOTE(review): ``xlim_*`` slices array axis 0 (rows,
        drawn vertically) and ``ylim_*`` slices axis 1 (columns, drawn
        horizontally), which looks swapped relative to the names. Confirm
        before relying on asymmetric limits.

    Raises
    ------
    ValueError
        If *interval_type* is not one of the supported values. This is
        checked before any figure is created.
    """
    valid_intervals = ('zscale', 'percentile', 'minmax', 'simple_norm')
    if interval_type not in valid_intervals:
        raise ValueError(
            "interval_type {} is not available.".format(interval_type))

    from astropy.visualization import (ZScaleInterval, PercentileInterval,
                                       MinMaxInterval, ImageNormalize,
                                       simple_norm)

    fig = plt.figure(figsize=(10, 10))

    ax = fig.add_subplot(1, 1, 1)

    if interval_type == 'zscale':
        norm = ImageNormalize(image, interval=ZScaleInterval())
    elif interval_type == 'percentile':
        norm = ImageNormalize(image, interval=PercentileInterval(percentile))
    elif interval_type == 'minmax':
        norm = ImageNormalize(image, interval=MinMaxInterval())
    else:  # 'simple_norm'
        norm = simple_norm(image, stretch)

    im = ax.imshow(image[xlim_minus:xlim_plus, ylim_minus:ylim_plus],
                   cmap=cmap,
                   interpolation='none',
                   origin='lower',
                   norm=norm)

    # Give the colour bar its own slim axes to the right of the image.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="3%", pad=0.1)

    plt.colorbar(im, cax=cax, label='ADU')
    ax.set_xlabel("x [pixels]")
    ax.set_ylabel("y [pixels]")
    ax.set_title(title, fontsize=14)

    plt.show()
예제 #25
0
def tifBand2biMap_GDAL(dataPath, dataName, savePath, saveName, TH):
    """Threshold a GeoTIFF band and save it as a binary grey PNG.

    Reads ``dataPath + dataName + '.tif'``. For a 2-D raster, pixels above
    *TH* become 1.0 and pixels at or below it become 0.0; NaNs satisfy
    neither comparison and are left unchanged. Rasters of any other shape
    are written unchanged. The result goes to ``savePath + saveName +
    '.png'``, and ``savePath`` is created if needed.
    """
    run = GRID()

    _, _, data = run.read_data(dataPath + dataName + ".tif")  # read data

    if len(data.shape) == 2:
        # Evaluate both masks before writing. Assigning 1.0 first and then
        # testing ``<= TH`` would zero those pixels again whenever TH >= 1.
        above = data > TH
        at_or_below = data <= TH
        data[above] = 1.0
        data[at_or_below] = 0.0

    if not os.path.exists(savePath):
        os.makedirs(savePath)
    print("{}".format(saveName))
    plt.imsave(savePath + saveName + ".png", data, cmap="gray")
예제 #26
0
def preview_cube(HDU):
    """Return a figure previewing a datacube collapsed along its WAVE axis.

    The cube is median-combined over the spectral axis. The result goes
    through ``PercentileInterval(90)`` and is drawn with a log stretch and
    the 'Blues_r' colormap, on axes using the header WCS with the WAVE axis
    dropped.
    """
    from astropy.visualization import quantity_support, PercentileInterval, LogStretch
    from astropy.visualization.mpl_normalize import ImageNormalize
    from astropy.wcs import WCS
    import numpy as np

    cube_wcs = WCS(HDU.header)
    wave_index = cube_wcs.axis_type_names.index('WAVE')
    # WCS axis order is the reverse of numpy's axis order.
    spectral_axis = HDU.data.ndim - wave_index - 1
    sky_wcs = cube_wcs.dropaxis(wave_index)
    collapsed = np.median(HDU.data, axis=spectral_axis)

    with quantity_support():
        fig = plt.figure()
        sky_axes = fig.add_subplot(1, 1, 1, projection=sky_wcs)
        scaled = PercentileInterval(90)(collapsed)
        log_norm = ImageNormalize(stretch=LogStretch())
        mappable = sky_axes.imshow(scaled, norm=log_norm, cmap='Blues_r')
        fig.colorbar(mappable, ax=sky_axes)
    return fig
    
예제 #27
0
def tifBand2png_GDAL(dataPath, dataName, savePath, saveName):
    """Save a GeoTIFF as a grey PNG.

    Reads ``dataPath + dataName + '.tif'``. A 2-D raster is rescaled with a
    98% percentile interval. For a 3-D raster, the first band (``data[0]``)
    is written without rescaling. Any other shape is passed to
    ``plt.imsave`` unchanged. The output is ``savePath + saveName + '.png'``,
    and ``savePath`` is created if needed.
    """
    run = GRID()
    interval = PercentileInterval(98.)

    _, _, data = run.read_data(dataPath + dataName + ".tif")  # read data

    ndim = len(data.shape)
    if ndim == 2:
        jpg_data = interval(data)
    elif ndim == 3:
        jpg_data = data[0]
    else:
        jpg_data = data

    if not os.path.exists(savePath):
        os.makedirs(savePath)
    print("{}".format(saveName))
    plt.imsave(savePath + saveName + ".png", jpg_data, cmap="gray")
예제 #28
0
def fits_to_png(ff, outfile, log=False):
    """Save extension 1 of an opened FITS file as a borderless greyscale PNG.

    NaNs are shown as 0, and contrast uses an asinh stretch over the 99.5%
    percentile interval. A red circle of radius 15 pixels marks the image
    centre.

    Parameters
    ----------
    ff : HDUList
        Opened FITS file; ``ff[1].data`` is displayed. It is not modified.
    outfile : str or path-like
        Destination passed to ``plt.savefig``.
    log : bool
        Unused; kept for backward compatibility.
    """
    plt.clf()
    ax = plt.axes()
    # Work on a copy so the caller's HDU data is not modified in place.
    fim = np.array(ff[1].data, copy=True)
    # replace NaN values with zero for display
    fim[np.isnan(fim)] = 0.0
    # set contrast to something reasonable
    transform = AsinhStretch() + PercentileInterval(99.5)
    bfim = transform(fim)
    ax.imshow(bfim, cmap="gray", origin="lower")
    # Circle takes (x, y) = (column, row), so the column count comes first.
    circle = plt.Circle((np.shape(fim)[1] / 2 - 1, np.shape(fim)[0] / 2 - 1),
                        15,
                        color='r',
                        fill=False)
    ax.add_artist(circle)
    ax.set_axis_off()
    # Strip all padding so the PNG contains only the image.
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    plt.savefig(outfile, bbox_inches='tight', pad_inches=0)
예제 #29
0
def plot_image(image,
               scale='linear',
               origin='lower',
               xlabel='Pixel Column Number',
               ylabel='Pixel Row Number',
               title=None,
               **kwargs):
    """Draw *image* on the current pyplot axes and add a colour bar.

    The colour range is the central 95% percentile interval of *image*.

    Parameters
    ----------
    image : 2d array
        Image data.
    scale : {'linear', 'sqrt', 'log'}
        Stretch applied to the colour scale.
    origin : str
        Passed to ``plt.imshow``.
    xlabel, ylabel, title : str
        Axis labels and title.
    kwargs : dict
        Extra keyword arguments for ``plt.imshow``.

    Raises
    ------
    ValueError
        If *scale* is not one of the supported options (checked first).
    """
    if scale not in ('linear', 'sqrt', 'log'):
        raise ValueError("scale {} is not available.".format(scale))
    stretch = {'linear': LinearStretch,
               'sqrt': SqrtStretch,
               'log': LogStretch}[scale]()

    vmin, vmax = PercentileInterval(95.).get_limits(image)
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)

    plt.imshow(image, origin=origin, norm=norm, **kwargs)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    # The colour bar takes its norm from the current image.
    plt.colorbar()
예제 #30
0
    def plot(self, nframe=100, scale='linear', **kwargs):
        """Show frame *nframe* of ``self.flux`` with a colour bar.

        The colour range is the central 95% percentile interval of the
        frame. The ``extent`` offsets the axes by ``self.column`` /
        ``self.row``, spanning ``self.shape[2]`` columns and
        ``self.shape[1]`` rows (presumably flux is (time, row, column);
        confirm).

        Parameters
        ----------
        nframe : int
            Index into ``self.flux``.
        scale : {'linear', 'sqrt', 'log'}
            Stretch applied to the colour scale.
        kwargs : dict
            Extra keyword arguments for ``plt.imshow``.

        Raises
        ------
        ValueError
            If *scale* is not one of the supported options (checked first).
        """
        if scale not in ('linear', 'sqrt', 'log'):
            raise ValueError("scale {} is not available.".format(scale))
        stretch = {'linear': LinearStretch,
                   'sqrt': SqrtStretch,
                   'log': LogStretch}[scale]()

        pflux = self.flux[nframe]
        vmin, vmax = PercentileInterval(95.).get_limits(pflux)
        norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)

        plt.imshow(pflux,
                   origin='lower',
                   norm=norm,
                   extent=(self.column, self.column + self.shape[2], self.row,
                           self.row + self.shape[1]),
                   **kwargs)
        plt.xlabel('Pixel Column Number')
        plt.ylabel('Pixel Row Number')
        plt.title('Kepler ID: {}'.format(self.keplerid))
        # The colour bar takes its norm from the current image.
        plt.colorbar()