Exemple #1
0
    def plotSideBySide(self):
        """Show both HST images in separate figures and record double-clicked stars.

        Each image is displayed with a ZScale interval. Double-clicking inside
        an image calls ``cutStar`` on that image. If the returned pixel
        position lies within 15 pixels of the click, the position is circled
        on that same image and ``(rowPix, columnPix)`` is appended to
        ``self.coords1`` (first image, red markers) or ``self.coords2``
        (second image, blue markers). Blocks in ``plt.show()``.
        """
        self.coords1 = []
        self.coords2 = []
        interval = ZScaleInterval()

        def show_clickable(data, coords, color):
            # One figure per image; each gets its own click handler bound to
            # its own axes, data and coordinate list.
            fig = plt.figure()
            ax = fig.add_subplot()
            vmin, vmax = interval.get_limits(data)
            ax.imshow(data,
                      cmap='Greys',
                      origin='lower',
                      interpolation='nearest',
                      vmin=vmin,
                      vmax=vmax)

            def onclick(event):
                # Ignore single clicks and clicks outside the image axes
                # (event.xdata/ydata are None there).
                if not event.dblclick or event.inaxes is not ax:
                    return
                print(f'x = {event.xdata}, y = {event.ydata}')
                _star, columnPix, rowPix = cutStar(event.xdata, event.ydata,
                                                   data)
                # Only accept the star if its position is close to the click.
                if (np.abs(rowPix - event.xdata) < 15
                        and np.abs(columnPix - event.ydata) < 15):
                    # Draw on this figure's axes explicitly: plt.scatter would
                    # target the current axes, i.e. the last figure created.
                    ax.scatter(rowPix,
                               columnPix,
                               facecolors='none',
                               edgecolors=color,
                               s=50)
                    fig.canvas.draw()
                    coords.append((rowPix, columnPix))

            fig.canvas.mpl_connect('button_press_event', onclick)
            return fig

        show_clickable(self.HSTImage1Data, self.coords1, 'r')
        show_clickable(self.HSTImage2Data, self.coords2, 'b')
        plt.show()
Exemple #2
0
def quick_rgb(image_red, image_green, image_blue, contrast=0.25):
    """Combine three single-band images into an 8-bit RGB array.

    ZScale limits (with the given ``contrast``) are computed per channel. The
    lowest minimum and highest maximum across all three channels define one
    shared log-stretch normalisation (clipped to [0, 1]), so the channels keep
    their relative intensities. Returns a ``shape + (3,)`` uint8 array.
    """
    channels = (image_red, image_green, image_blue)
    zscale = ZScaleInterval(contrast=contrast)
    limits = [zscale.get_limits(channel) for channel in channels]

    # One normalisation for every channel, spanning the union of their limits.
    lower = min(lo for lo, _ in limits)
    upper = max(hi for _, hi in limits)
    norm = ImageNormalize(vmin=lower, vmax=upper, stretch=LogStretch(), clip=True)

    rgb = np.zeros(image_red.shape + (3,), dtype=np.uint8)
    for band, channel in enumerate(channels):
        rgb[:, :, band] = (norm(channel) * 255).astype(np.uint8)
    return rgb
Exemple #3
0
    def create_jpg_from_fits(self, fits_filepath, outdir):
        '''
        Basic convert fits primary data to jpg.  Instrument subclasses can override this function.

        The primary HDU is drawn with a ZScale interval and an asinh stretch on
        a borderless figure. The figure is sized from NAXIS1/NAXIS2 at 100 dpi,
        so there is one image pixel per output pixel. The result is written to
        ``{outdir}/{basename}.jpg``, where ``basename`` is the FITS file name
        with ``.fits`` removed.
        '''

        #form filepaths
        basename = os.path.basename(fits_filepath).replace('.fits', '')
        jpg_filepath = f'{outdir}/{basename}.jpg'

        # The context manager releases the FITS handle even if plotting fails.
        # All data access happens while the file is still open.
        with fits.open(fits_filepath, ignore_missing_end=True) as hdul:
            data = hdul[0].data
            hdr = hdul[0].header

            #create jpg
            interval = ZScaleInterval()
            vmin, vmax = interval.get_limits(data)
            norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=AsinhStretch())
            dpi = 100
            width_inches = hdr['NAXIS1'] / dpi
            height_inches = hdr['NAXIS2'] / dpi
            fig = plt.figure(figsize=(width_inches, height_inches),
                             frameon=False,
                             dpi=dpi)
            try:
                ax = fig.add_axes([0, 0, 1, 1])  #this forces no border padding
                ax.axis('off')
                ax.imshow(data, cmap='gray', origin='lower', norm=norm)
                # JPEG quality is forwarded to Pillow. The bare ``quality``
                # keyword to savefig has been removed from Matplotlib.
                fig.savefig(jpg_filepath, pil_kwargs={'quality': 92})
            finally:
                plt.close(fig)
Exemple #4
0
def fits2png(filename: Union[str, Path], sh: SizeHint, hdu_index=1):
    """Render one 2-D FITS HDU as an 8-bit grayscale PNG and return its bytes.

    :param filename: FITS file to read.
    :param sh: size hint. ``max_height``/``max_width`` request the smallest
        integer block-averaging factor that fits the image inside them, and
        ``factor`` forces a minimum factor. The largest request wins. With no
        hint at all the image keeps its native resolution (factor 1).
    :param hdu_index: index of the HDU to render (default 1).
    :return: PNG-encoded bytes.

    Rows are reversed (``[::-1]``) before rendering. Pixel values are mapped
    linearly from the ZScale limits to 0..255, with clipping.
    """
    with timeit(f'fits2png({filename})'):
        with timeit('astropy.fits.open'):
            with afits.open(filename) as hdul:
                data = hdul[hdu_index].data[::-1]  # type: ignore
        assert len(data.shape) == 2
        factor = max(
            (data.shape[0] - 1) // sh.max_height + 1 if sh.max_height else 0,
            (data.shape[1] - 1) // sh.max_width + 1 if sh.max_width else 0,
            sh.factor or 0,
            # Floor of 1: downscale_local_mean rejects a factor of 0, which
            # is what an empty size hint used to produce.
            1,
        )
        with timeit('resize'):
            data = skimage.transform.downscale_local_mean(
                data, (factor, factor))
        zscale = ZScaleInterval()
        with timeit('zscale'):
            vmin, vmax = zscale.get_limits(data)
        with timeit('convert2uint'):
            span = vmax - vmin
            if span > 0:
                scaled = numpy.clip((data - vmin) / span, 0., 1.)
            else:
                # Flat image (or NaN limits): avoid 0/0 and emit a black frame.
                scaled = numpy.zeros(data.shape)
            data8 = numpy.array(255 * scaled, dtype=numpy.uint8)
        img = Image.fromarray(data8)
        buffer = io.BytesIO()
        img.save(buffer, format='png')
        return buffer.getvalue()
Exemple #5
0
def plot_aper_mask(fluxes, rad, aper_shape, contrast=0.1, epic=None):
    """Plot the time-summed flux image with the aperture mask overlaid.

    ``fluxes`` is summed along axis 0 (ignoring NaNs). The sum is shown
    between its ZScale limits (``contrast`` is passed to ZScaleInterval), with
    the mask from ``make_mask`` drawn on top at 30% opacity. For
    ``aper_shape == 'round'`` a circle of radius ``rad`` marks the aperture,
    and the centroid is marked with a red '+'. Blocks in ``pl.show()`` and
    returns the figure.
    """
    stacked = np.nansum(fluxes, axis=0)
    mask = make_mask(fluxes, rad=rad, shape=aper_shape)
    # NOTE(review): photutils' centroid_com returns (x, y); if that is the
    # implementation in scope, this unpacking swaps the axes — confirm.
    y, x = centroid_com(stacked)

    fig, ax = pl.subplots(1, 1)
    interval = ZScaleInterval(contrast=contrast)
    zmin, zmax = interval.get_limits(stacked)

    # Apply the ZScale limits (previously computed but unused). 'lower' is
    # the current spelling of the removed origin='bottom'.
    cb = ax.imshow(stacked, origin='lower', interpolation=None,
                   vmin=zmin, vmax=zmax)
    # Same origin as the flux image, so this call does not re-invert the y axis.
    ax.imshow(mask, origin='lower', alpha=0.3)
    if aper_shape == 'round':
        circ = pl.Circle((x, y), rad, color='w', alpha=0.2, lw=5,
                         label='r={}'.format(rad))
        ax.add_artist(circ)

    ax.plot(x, y, 'r+', ms=20, lw=10, label='centroid')
    pl.colorbar(cb, ax=ax)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.legend()
    if epic is not None:
        ax.set_title('EPIC {}'.format(epic))
    pl.show()

    return fig
Exemple #6
0
    def __init__(self, parent, controller):
        """Build a Tk frame that shows ``sky_image`` with two display sliders.

        The primary HDU of the module-level ``sky_image`` file is drawn with a
        ZScale interval and a power stretch. The initial power comes from the
        module-level ``power_normalise``. The "Normalise" slider sets the
        stretch power and the "Contrast" slider sets the ZScale contrast;
        moving either one re-normalises the displayed image. ``controller``
        is accepted but not used here.
        """
        tk.Frame.__init__(self, parent)

        # change 'figure_width': 1920, 'figure_height': 1080 =(19.2,10.8)
        f = Figure(figsize=(8, 6))
        a = f.add_subplot(111)
        f.subplots_adjust(left=0, bottom=0.005, right=1, top=1)
        a.get_xaxis().set_visible(False)
        a.get_yaxis().set_visible(False)

        # add axes for sliders
        ax_norma = f.add_axes([0.81, 0.1, 0.15, 0.025])
        ax_contr = f.add_axes([0.81, 0.05, 0.15, 0.025])

        hdu_list = fits.open(sky_image)
        hdu_list.info()
        img = hdu_list[0].data

        contrast = 0.15
        interval = ZScaleInterval(contrast=contrast)
        vmin, vmax = interval.get_limits(img)

        norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=PowerStretch(power_normalise))

        # Keep the AxesImage so the slider callback can change its
        # normalisation in place instead of stacking a new image per event.
        image = a.imshow(img, origin='lower', norm=norm, cmap='cividis')

        # Embedding In Tk
        canvas = FigureCanvasTkAgg(f, self)
        canvas.draw()
        canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # Show ToolBar
        toolbar = NavigationToolbar(canvas, self)
        toolbar.update()
        # Activate Zoom
        toolbar.zoom(self)
        canvas._tkcanvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        delta_f = 0.1

        # add sliders
        s_norma = Slider(ax_norma, 'Normalise', 0.1, 5.0, valinit=power_normalise, valstep=delta_f)
        s_contr = Slider(ax_contr, 'Contrast', 0.1, 1.0, valinit=contrast)

        def update(val):
            # Recompute limits and stretch from the current slider values and
            # apply them to the existing image.
            n_norma = s_norma.val
            n_contr = s_contr.val
            interval = ZScaleInterval(contrast=n_contr)
            vmin, vmax = interval.get_limits(img)
            image.set_norm(ImageNormalize(vmin=vmin, vmax=vmax, stretch=PowerStretch(n_norma)))
            canvas.draw_idle()

        s_norma.on_changed(update)
        s_contr.on_changed(update)

        # NOTE(review): ``update`` still reads ``img`` after this close. With
        # memory-mapped FITS data that relies on astropy keeping the map alive
        # while it is referenced — confirm, or open with memmap=False.
        hdu_list.close()
    def make_jpg(self):
        '''
        Converts HIRES FITS file to JPG image
        Output filename = KOAID_CCD#_HDU##.jpg
            # = 1, 2, 3...
            ## = 01, 02, 03...

        Each image extension ``self.fits_hdu[1:]`` is rotated by 90 degrees,
        drawn with a ZScale interval and asinh stretch, saved as a temporary
        PNG, resized to half the extension's row count in width, and written
        as JPG. Returns False (after ``log_warn``) if the lev0 file cannot be
        found or any extension fails; otherwise returns True.
        '''

        # TODO: Can we utilize instrument.make_jpg() to reduce duplicate code?
        # Perhaps add an 'ext' param to make_jpg().

        # file to convert is lev0Dir/KOAID

        koaid = self.get_keyword('KOAID')
        filePath = ''
        for root, _dirs, files in os.walk(self.dirs['lev0']):
            if koaid in files:
                filePath = ''.join((root, '/', koaid))
        if not filePath or not os.path.isfile(filePath):
            self.log_warn('MAKE_JPG_ERROR')
            return False

        koaid = filePath.replace('.fits', '')
        for ext in range(1, len(self.fits_hdu)):
            ext2 = str(ext)
            pngFile = koaid + '_CCD' + ext2 + '_HDU' + ext2.zfill(2) + '.png'
            jpgFile = pngFile.replace('.png', '.jpg')
            fig = None
            try:
                # image data to convert
                image = self.fits_hdu[ext].data
                interval = ZScaleInterval()
                vmin, vmax = interval.get_limits(image)
                norm = ImageNormalize(vmin=vmin,
                                      vmax=vmax,
                                      stretch=AsinhStretch())
                fig = plt.figure()
                ax = fig.add_axes([0, 0, 1, 1])
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
                ax.imshow(np.rot90(image),
                          cmap='gray',
                          origin='lower',
                          norm=norm)
                ax.axis('off')
                # save as png, then convert to jpg
                fig.savefig(pngFile, bbox_inches='tight', pad_inches=0)
                # Context manager releases the PNG file handle before removal.
                with Image.open(pngFile) as png:
                    img = png.convert('RGB')
                basewidth = int(len(image) / 2)
                wpercent = basewidth / float(img.size[0])
                hsize = int((float(img.size[1]) * float(wpercent)))
                # LANCZOS is the filter ANTIALIAS aliased; ANTIALIAS was
                # removed in Pillow 10.
                img = img.resize((basewidth, hsize), Image.LANCZOS)
                img.save(jpgFile)
            except Exception:
                self.log_warn("MAKE_JPG_ERROR", jpgFile)
                return False
            finally:
                if fig is not None:
                    plt.close(fig)
                # Best-effort cleanup of the intermediate PNG, also on failure.
                if os.path.isfile(pngFile):
                    try:
                        os.remove(pngFile)
                    except OSError:
                        pass

        return True
Exemple #8
0
def plot_image(image, width=800, downsample=2, title=None):
    """
    Render ``image`` as a grayscale bokeh figure ``width`` pixels wide.

    The image is reduced with ``downsample_image(image, downsample)``, clipped
    to its ZScale limits and stored as uint8 (0..255) to keep the payload
    small. It is still drawn over the original pixel extent (0..nx, 0..ny),
    so axis coordinates stay in full-resolution pixels. The figure height is
    ``width - 50``, and ``title`` is applied when given.
    """
    ny, nx = image.shape
    small = downsample_image(image, downsample)

    lo, hi = ZScaleInterval().get_limits(small)

    # Experimental: map the clipped range onto 0..255 to save space.
    clipped = small.clip(lo, hi)
    u8img = (255 * (clipped - lo) / (hi - lo)).astype(np.uint8)
    mapper = LinearColorMapper(palette=gray(256), low=0, high=255)

    fig = bk.figure(width=width,
                    height=width - 50,
                    active_drag='box_zoom',
                    active_scroll='wheel_zoom')
    fig.image([u8img], 0, 0, nx, ny, color_mapper=mapper)
    fig.x_range.start, fig.x_range.end = 0, nx
    fig.y_range.start, fig.y_range.end = 0, ny

    if title is not None:
        fig.title.text = title
    return fig
Exemple #9
0
def plot_dss_image(hdu,
                   cmap="gray",
                   contrast=0.5,
                   coord_format="dd:mm:ss",
                   ax=None):
    """
    Plot output of get_dss_data:
    hdu = get_dss_data(ra, dec)

    The image is shown between its ZScale limits (``contrast`` is passed to
    ZScaleInterval). If ``ax`` is None, a new figure is created with a WCS
    projection built from the HDU header. The title comes from the header's
    SURVEY, FILTER and DATE-OBS keywords. When the axes have WCS ``coords``,
    both coordinates use ``coord_format``. Returns the axes.
    """
    data, header = hdu.data, hdu.header
    interval = ZScaleInterval(contrast=contrast)
    zmin, zmax = interval.get_limits(data)

    if ax is None:
        fig = pl.figure(constrained_layout=True)
        ax = fig.add_subplot(projection=WCS(header))
    # origin='lower' puts pixel (0, 0) at the bottom-left, the FITS convention
    # used for WCSAxes and by the other FITS plots here. The default 'upper'
    # shows the image upside down.
    ax.imshow(data, vmin=zmin, vmax=zmax, cmap=cmap, origin="lower")
    ax.set_xlabel("RA")
    ax.set_ylabel("DEC", y=0.9)
    title = f"{header['SURVEY']} ({header['FILTER']})\n"
    title += f"{header['DATE-OBS'][:10]}"
    ax.set_title(title)
    # Format both axes with coord_format (only WCS axes have ``coords``).
    if hasattr(ax, "coords"):
        ax.coords[1].set_major_formatter(coord_format)
        ax.coords[0].set_major_formatter(coord_format)
    return ax
Exemple #10
0
    def __init__(self,
                 data,
                 XWIN_IMAGE,
                 YWIN_IMAGE,
                 FLUX_AUTO,
                 FLUXERR_AUTO,
                 zscaleNsamp=200,
                 zscaleContrast=1.,
                 minGoodVal=None):
        """Store an image with its source measurements and ZScale display limits.

        The source columns are stored unchanged. ZScale limits ``z1``/``z2``
        are computed from the pixels of ``data`` that are not below
        ``minGoodVal`` (all pixels when it is None), using ``zscaleNsamp``
        samples and ``zscaleContrast``. ``normer`` is a ManualInterval fixed
        to those limits.
        """
        self.XWIN_IMAGE, self.YWIN_IMAGE = XWIN_IMAGE, YWIN_IMAGE
        self.FLUX_AUTO, self.FLUXERR_AUTO = FLUX_AUTO, FLUXERR_AUTO
        self.data = data
        self.minGoodVal = minGoodVal

        keep = np.ones(self.data.shape, dtype=bool)
        if self.minGoodVal is not None:
            # Drop only pixels strictly below the threshold.
            too_low = np.where(self.data < self.minGoodVal)
            keep[too_low] = 0

        scaler = ZScaleInterval(nsamples=zscaleNsamp, contrast=zscaleContrast)
        self.z1, self.z2 = scaler.get_limits(self.data[keep])
        self.normer = interval.ManualInterval(self.z1, self.z2)

        self._increment = 0
Exemple #11
0
def scale_array_for_jpg(array):
    """Map ``array`` onto uint8 (0..255) for writing as a JPEG.

    Values are clipped to the ZScale limits. If ZScale raises IndexError, the
    1st/99th NaN-ignoring percentiles are used instead. The clipped range is
    then stretched linearly to 0..255. NaN pixels become 0, and a constant
    (or all-NaN) input yields an all-zero image instead of a 0/0 division.
    """
    zscale = ZScaleInterval()
    try:
        low, upp = zscale.get_limits(array)
    except IndexError:
        low, upp = np.nanpercentile(array, (1, 99))
    scaled_array = np.clip(array, low, upp)
    mi, ma = np.nanmin(scaled_array), np.nanmax(scaled_array)
    span = ma - mi
    if not span > 0:
        # Flat or all-NaN input: no range to stretch over.
        return np.zeros(np.shape(array), dtype=np.uint8)
    normalised = (scaled_array - mi) / span * 255
    # Casting NaN to uint8 is undefined, so map NaN pixels to 0 explicitly.
    return np.nan_to_num(normalised, nan=0.0).astype(np.uint8)
Exemple #12
0
    def fits_to_png(self, file_path, dst_path, contrast=0.15):
        """Save a FITS image as a PNG scaled to its ZScale limits.

        The data is read with ``fits.getdata`` and shifted/scaled so that the
        ZScale limits (with ``contrast``) map to 0 and 1. It is then written
        via ``save_image``. A flat image (zero range) is written as zeros
        instead of dividing by zero.
        """
        import numpy as np

        img = fits.getdata(file_path, ignore_missing_end=True)
        interval = ZScaleInterval(contrast=contrast)
        # vmin/vmax rather than min/max, so the builtins are not shadowed.
        vmin, vmax = interval.get_limits(img)

        span = vmax - vmin
        if span > 0:
            img = (img - vmin) / span
        else:
            img = np.zeros(img.shape)

        save_image(torch.from_numpy(img), dst_path)
    def make_jpg(self):
        '''
        Converts a FITS file to JPG image

        Finds the file named by the KOAID header keyword under
        ``self.dirs['lev0']`` and renders ``self.fitsHdu[0].data`` with a
        ZScale interval and asinh stretch. The result is written as
        ``<file>.jpg`` next to the FITS file, via a temporary PNG. Returns
        True on success or when the JPG already exists, and False when the
        FITS file is missing or the conversion fails.
        '''

        # file to convert is lev0Dir/KOAID

        path = self.dirs['lev0']
        koaid = self.fitsHeader.get('KOAID')
        filePath = ''
        for root, _dirs, files in os.walk(path):
            if koaid in files:
                filePath = ''.join((root, '/', koaid))
        if not filePath:
            # str(): a missing KOAID keyword (None) must not raise TypeError here.
            self.log.error('make_jpg: Could not find KOAID: ' + str(koaid))
            return False
        self.log.info(
            'make_jpg: converting {} to jpeg format'.format(filePath))

        #check if already exists? (JPG conversion is time consuming)
        #todo: Only allow skip if not fullRun? (ie Will we ever want to regenerate the jpg?)

        jpgFile = filePath.replace('.fits', '.jpg')
        if os.path.isfile(jpgFile):
            self.log.warning('make_jpg: file already exists. SKIPPING')
            return True

        # verify file exists
        if not os.path.isfile(filePath):
            #TODO: if this errors, should we remove .fits file added previously?
            self.log.error(
                'make_jpg: file does not exist {}'.format(filePath))
            return False

        pngFile = filePath.replace('.fits', '.png')
        fig = None
        try:
            # image data to convert
            image = self.fitsHdu[0].data
            interval = ZScaleInterval()
            vmin, vmax = interval.get_limits(image)
            norm = ImageNormalize(vmin=vmin,
                                  vmax=vmax,
                                  stretch=AsinhStretch())
            # Fresh figure, so nothing already open in pyplot is drawn over.
            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.imshow(image, cmap='gray', origin='lower', norm=norm)
            ax.axis('off')
            # save as png, then convert to jpg
            fig.savefig(pngFile)
            with Image.open(pngFile) as png:
                png.convert('RGB').save(jpgFile)
        except Exception:
            self.log.error('make_jpg: Could not create JPG: ' + jpgFile)
            return False
        finally:
            if fig is not None:
                plt.close(fig)
            # Best-effort cleanup of the intermediate PNG, also on failure.
            if os.path.isfile(pngFile):
                try:
                    os.remove(pngFile)
                except OSError:
                    pass

        return True
Exemple #14
0
def zscale_image(imgarr):
    '''
    Return ``imgarr`` rescaled linearly between its ZScale limits.

    The limits come from astropy's ZScaleInterval. The rescaling itself is
    delegated to ``direct_linscale_img``.
    '''
    low, high = ZScaleInterval().get_limits(imgarr)
    return direct_linscale_img(imgarr, low, high)
Exemple #15
0
 def update(val):
     """Slider callback: redraw ``img`` with the current slider settings.

     ``val`` (the value of whichever slider moved) is ignored; both sliders
     are read instead. ``s_norma`` supplies the PowerStretch power and
     ``s_contr`` supplies the ZScale contrast.
     """
     power = s_norma.val
     zscale_contrast = s_contr.val
     lo, hi = ZScaleInterval(contrast=zscale_contrast).get_limits(img)
     stretch = PowerStretch(power)
     # NOTE(review): each call adds another AxesImage to ``a`` instead of
     # updating the existing one — consider keeping the image and using set_norm.
     a.imshow(img, origin='lower',
              norm=ImageNormalize(vmin=lo, vmax=hi, stretch=stretch),
              cmap='cividis')
     canvas.draw_idle()
Exemple #16
0
    def make_thumbnail(oid: str, detection: Mapping, thumbnail_type: str):
        """Convert lossless FITS cutouts from ZTF images into PNGs

        :param oid: Fritz obj id
        :param detection: Tails detection dict. ``detection["cutouts"]`` is
            indexed as ``[..., k]`` with k = 0 (new), 1 (ref) or 2 (sub).
        :param thumbnail_type: <new|ref|sub>. Any value other than "ref" or
            "sub" selects the "new" cutout.
        :return: dict with ``obj_id``, ``data`` (base64-encoded PNG text) and
            ``ttype`` (the requested thumbnail type)
        """
        index = {"ref": 1, "sub": 2}.get(thumbnail_type, 0)
        # np.array copies the slice, so the caller's detection is never mutated.
        img = np.array(detection["cutouts"][..., index])
        # flip up/down
        # img = np.flipud(img)
        buff = io.BytesIO()
        plt.close("all")
        fig = plt.figure()
        fig.set_size_inches(4, 4, forward=False)
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)

        # Replace dubiously large values with NaN. ``out`` pre-fills the result
        # so entries skipped by ``where`` (the NaNs) are False, not garbage.
        xl = np.greater(np.abs(img), 1e20,
                        out=np.zeros(img.shape, dtype=bool),
                        where=~np.isnan(img))
        if xl.any():
            img[xl] = np.nan
        # Replace NaNs with the median of the remaining pixels.
        if np.isnan(img).any():
            median = float(np.nanmedian(img))
            img = np.nan_to_num(img, nan=median)

        interval = ZScaleInterval(nsamples=img.shape[0] * img.shape[1])
        limits = interval.get_limits(img)
        ax.imshow(img,
                  origin="upper",
                  cmap="bone",
                  vmin=limits[0],
                  vmax=limits[1])
        plt.savefig(buff, dpi=42)

        buff.seek(0)
        plt.close("all")

        thumb = {
            "obj_id": oid,
            "data": base64.b64encode(buff.read()).decode("utf-8"),
            "ttype": thumbnail_type,
        }

        return thumb
    def make_jpg(self):
        '''
        Converts HIRES FITS file to JPG image
        Output filename = KOAID_CCD#_HDU##.jpg
            # = 1, 2, 3...
            ## = 01, 02, 03...

        One JPG is written per image extension ``self.fitsHdu[1:]``. Each is
        drawn with a ZScale interval and asinh stretch and rotated by 90
        degrees. An extension that fails is logged and skipped. Returns False
        if the lev0 file cannot be found, True otherwise.
        '''

        # TODO: Can we merge this with instrument.make_jpg()?

        # file to convert is lev0Dir/KOAID

        koaid = self.get_keyword('KOAID')
        filePath = ''
        for root, _dirs, files in os.walk(self.dirs['lev0']):
            if koaid in files:
                filePath = ''.join((root, '/', koaid))
        if not filePath:
            # str(): a missing KOAID keyword (None) must not raise TypeError here.
            self.log.error('make_jpg: Could not find KOAID: ' + str(koaid))
            return False
        self.log.info(
            'make_jpg: converting {} to jpeg format'.format(filePath))

        koaid = filePath.replace('.fits', '')

        if not os.path.isfile(filePath):
            self.log.error('make_jpg: file does not exist {}'.format(filePath))
            return False

        for ext in range(1, len(self.fitsHdu)):
            ext2 = str(ext)
            pngFile = koaid + '_CCD' + ext2 + '_HDU' + ext2.zfill(2) + '.png'
            jpgFile = pngFile.replace('.png', '.jpg')
            fig = None
            try:
                # image data to convert
                image = self.fitsHdu[ext].data
                interval = ZScaleInterval()
                vmin, vmax = interval.get_limits(image)
                norm = ImageNormalize(vmin=vmin,
                                      vmax=vmax,
                                      stretch=AsinhStretch())
                # A fresh figure per extension, so a failed extension never
                # leaves a stale figure for the next one to draw over.
                fig = plt.figure()
                ax = fig.add_subplot(111)
                ax.imshow(image, cmap='gray', origin='lower', norm=norm)
                ax.axis('off')
                # save as png, then convert to jpg
                fig.savefig(pngFile)
                with Image.open(pngFile) as png:
                    png.convert('RGB').rotate(-90).save(jpgFile)
            except Exception:
                # Best effort: log and carry on with the next extension.
                self.log.error('make_jpg: Could not create JPG: ' +
                               jpgFile)
            finally:
                if fig is not None:
                    plt.close(fig)
                # Best-effort cleanup of the intermediate PNG, also on failure.
                if os.path.isfile(pngFile):
                    try:
                        os.remove(pngFile)
                    except OSError:
                        pass

        return True
Exemple #18
0
def plot_epics(ids, kwargs=None):
    '''
    Plot a grid of image cutouts: one row per EPIC id, one column per filter.

    ids: epic ids
    kwargs: display options, unpacked by ``parse_kwargs`` into
        (textloc, fontsize, pad, cmap, contrast, figsize)

    Images are read from ``<epic>_<filters>/*.fits`` (``filters`` is a
    module-level name) and ordered by ``sort_filternames``. Each is shown
    between its ZScale limits, labelled with its band, and overlaid with a
    translucent circle at the cutout centre. Returns the figure.
    '''
    #there were 240 pixels in 60 arcsec
    pix_size = 0.25  #arcsec/pix

    textloc, fontsize, pad, cmap, contrast, figsize = parse_kwargs(kwargs)

    rows = len(ids)
    cols = len(filters)

    # squeeze=False keeps ``ax`` 2-D, so ax[m, n] also works for one id or filter.
    if figsize is not None:
        fig, ax = pl.subplots(rows, cols, figsize=figsize, squeeze=False)
    else:
        fig, ax = pl.subplots(rows, cols, squeeze=False)

    interval = ZScaleInterval(contrast=contrast)
    # Row title above column 2 (the middle of five filters), or above the
    # last column when there are fewer filters.
    title_col = min(2, cols - 1)
    for m, epic in enumerate(ids):
        foldername = str(epic) + '_' + str(filters)
        file_list = glob(foldername + '/*.fits')

        d = sort_filternames(file_list)
        file_list = [d[i] for i in sorted(d)]

        for n, f in enumerate(file_list):
            #open fits file
            img, hdr = get_data(f)
            band = hdr['HIERARCH FPA.FILTER'].split('.')[0]
            #get limits
            zmin, zmax = interval.get_limits(img)

            # 'lower' is the current spelling of the removed origin='bottom'.
            ax[m, n].imshow(img,
                            origin='lower',
                            label='${}$'.format(band),
                            cmap=cmap,
                            vmin=zmin,
                            vmax=zmax)
            ax[m, title_col].set_title('{}'.format(epic), fontsize=fontsize, pad=pad)
            ax[m, n].yaxis.set_major_locator(pl.NullLocator())
            ax[m, n].xaxis.set_major_formatter(pl.NullFormatter())
            xpos, ypos = textloc
            ax[m, n].text(xpos, ypos, band, fontsize=fontsize, color='w')
            # Same radius convention as plot_multi (pixels * pix_size). Dividing
            # gave a circle four times larger than the image itself.
            rad = img.shape[0] * pix_size
            # Cutout centre: x runs along columns, y along rows.
            x, y = img.shape[1] / 2, img.shape[0] / 2
            circ = pl.Circle((x, y),
                             rad,
                             linewidth=2,
                             alpha=0.4,
                             edgecolor='w')
            ax[m, n].add_artist(circ)
    #fig.tight_layout()
    pl.show()
    return fig
Exemple #19
0
def plot_multi(file_list, epic, kwargs, aperture=None, show_centroid=True):
    """Plot one row of square cutouts, one panel per FITS file.

    Files are ordered by ``sort_filternames``. Each image is shown in gray
    between its ZScale limits (``contrast`` from ``parse_kwargs(kwargs)``)
    and labelled with its band. If ``aperture`` is given it is overlaid with
    ``cmap``; otherwise a translucent circle of radius ``shape * pix_size``
    marks the centre, plus a red '+' when ``show_centroid``. The middle panel
    is titled with ``epic``. Returns the figure.
    """
    textloc, fontsize, pad, cmap, contrast, figsize = parse_kwargs(kwargs)

    # squeeze=False keeps the axes indexable even for a single file.
    if figsize is not None:
        fig, axes = pl.subplots(1, len(file_list), figsize=figsize, squeeze=False)
    else:
        fig, axes = pl.subplots(1, len(file_list), squeeze=False)
    ax = axes[0]

    #there were 240 pixels in 60 arcsec
    pix_size = 0.25  #arcsec/pix

    d = sort_filternames(file_list)
    file_list = [d[i] for i in sorted(d)]

    interval = ZScaleInterval(contrast=contrast)
    for n, fname in enumerate(file_list):
        img, hdr = get_data(fname)
        band = hdr['HIERARCH FPA.FILTER'].split('.')[0]

        assert img.shape[0] == img.shape[1]  #square?

        zmin, zmax = interval.get_limits(img)

        # 'lower' is the current spelling of the removed origin='bottom'.
        ax[n].imshow(img,
                     origin='lower',
                     label='${}$'.format(band),
                     cmap='gray',
                     vmin=zmin,
                     vmax=zmax)
        ax[n].yaxis.set_major_locator(pl.NullLocator())
        ax[n].xaxis.set_major_formatter(pl.NullFormatter())
        xpos, ypos = textloc
        ax[n].text(xpos, ypos, band, fontsize=fontsize, color='w')
        if aperture is not None:
            assert aperture.size == img.size
            ax[n].matshow(aperture, cmap=cmap, alpha=0.1, label='K2 aperture')
        else:
            rad = img.shape[0] * pix_size
            x, y = img.shape[0] / 2, img.shape[0] / 2
            #aperture
            circ = pl.Circle((x, y),
                             rad,
                             linewidth=2,
                             alpha=0.4,
                             edgecolor='w')
            ax[n].add_artist(circ)
            #centroid
            if show_centroid:
                ax[n].text(x, y, '+', color='r', fontsize=fontsize)
    # Middle panel: same index as ``n // 2`` for the last loop index n.
    ax[(len(file_list) - 1) // 2].set_title('{}'.format(epic), fontsize=fontsize)  #, pad=pad)
    # Lay out before showing; after a blocking show() it has no visible effect.
    fig.tight_layout()
    pl.show()
    return fig
Exemple #20
0
def plot_contourOnImage(fitsfile, total_mask_bool, verbose=False):
    """Save a PNG of ``fitsfile`` with ``total_mask_bool`` drawn as a contour.

    The mask is written to a temporary ``./contour.fits`` carrying the input
    image's WCS keywords (EQUINOX defaults to 2000.0 when absent), so aplpy
    can place the contour. The image is shown in inverted grayscale between
    its ZScale limits. The PNG goes to ``<fitsfile>_skymask_contour.png``
    (``.fits`` replaced), and the temporary file is removed afterwards.
    Returns None.
    """
    # Read in image
    image, h = fits.getdata(fitsfile, header=True)

    # Create header with wcs
    contour_fits = fits.PrimaryHDU()
    contour_fits.data = total_mask_bool.astype('int')
    for key in ('CTYPE1', 'CRPIX1', 'CRVAL1',
                'CTYPE2', 'CRPIX2', 'CRVAL2',
                'CD1_1', 'CD1_2', 'CD2_1', 'CD2_2'):
        contour_fits.header[key] = h[key]
    try:
        contour_fits.header['EQUINOX'] = h['EQUINOX']
    except KeyError:
        print('IMPORTANT NOTE!!!! Equinox of input image assumed to be 2000.0')
        print('                   This is just for plotting checkim purposes')
        contour_fits.header['EQUINOX'] = 2000.0

    # Save contour_image to file, with fitsfile WCS. overwrite=True, so a file
    # left behind by an earlier failed run does not abort this one.
    total_mask_fitsWithWCS = './contour.fits'
    contour_fits.writeto(total_mask_fitsWithWCS, overwrite=True)
    printme = f'SAVED  : {total_mask_fitsWithWCS}'
    print_verbose_string(printme, verbose=verbose)

    # Plot total_mask as contour on fits image
    fig = plt.figure(figsize=(48, 36))
    try:
        f2 = aplpy.FITSFigure(fitsfile, figure=fig)
        f2.ticks.hide()
        f2.tick_labels.hide_x()
        f2.tick_labels.hide_y()
        f2.axis_labels.hide()
        interval = ZScaleInterval()
        vmin, vmax = interval.get_limits(image)
        f2.show_grayscale(invert=True, stretch='linear', vmin=vmin, vmax=vmax)
        f2.show_contour(data=total_mask_fitsWithWCS, linewidths=3.0, colors='MediumPurple')
        cont_name = fitsfile.replace('.fits', '_skymask_contour.png')
        f2.save(cont_name)
        print(f'SAVED  : {cont_name}')
    finally:
        # Release the very large figure and always remove the temporary file.
        plt.close(fig)
        clearit(total_mask_fitsWithWCS)
        printme = f'REMOVED: {total_mask_fitsWithWCS}'
        print_verbose_string(printme, verbose=verbose)

    return None
Exemple #21
0
    def display_image(self, ax = None, display = True):
        '''Display the primary HDU of ``self.oriname`` in inverted grayscale.

        The image is negated and shown between the negated ZScale limits, so
        bright sources appear dark. When ``ax`` is None, a new 8x8-inch figure
        is created. When ``display`` is True, the figure owning the axes is
        shown.
        '''

        with fits.open(self.oriname) as f:
            im = f[0].data

        if ax is None:
            fig, ax = plt.subplots(figsize = (8, 8))
        else:
            # Caller supplied the axes; use its figure so display=True works.
            fig = ax.figure
        z = ZScaleInterval()
        # Pass the array itself (``im.data`` would be the raw memoryview).
        zlim = z.get_limits(im)
        ax.imshow(-1*im, cmap = 'gray', vmin = -1*zlim[1], vmax = -1*zlim[0])
        if display:
            fig.show()
Exemple #22
0
def display_image(img, minclip=5, maxclip=95, label=None, cmap='Greys_r',
                  srcs=None, projection=None, calibrated=False, png=None):
    """Simple wrapper to display an image.

    The image is shown with an asinh stretch (a=0.9) between its ZScale
    limits (contrast 0.5), with a colorbar on the right. When ``projection``
    is given, RA/Dec axes and a red grid are drawn; otherwise pixel axes are
    used. ``srcs`` (with ``xcentroid``/``ycentroid`` columns) are circled in
    red. The colorbar label is ``label`` if given, otherwise
    e-/s (``calibrated``) or ADU. When ``png`` is given, the figure is saved
    there. ``minclip``/``maxclip`` are currently unused.
    """
    from astropy.visualization import AsinhStretch as Stretch
    from astropy.visualization import ZScaleInterval as Interval
    from astropy.visualization.mpl_normalize import ImageNormalize

    #from astropy.visualization import simple_norm
    #norm = simple_norm(img, min_percent=minclip, max_percent=maxclip)

    interval = Interval(contrast=0.5)
    vmin, vmax = interval.get_limits(img)
    # The limits go into the normalisation: Matplotlib rejects vmin/vmax
    # passed to imshow together with a norm.
    norm = ImageNormalize(interval=interval, vmin=vmin, vmax=vmax,
                          stretch=Stretch(a=0.9))

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw={'projection': projection})
    im = ax.imshow(img, origin='lower', norm=norm, cmap=cmap)
    if projection:
        ax.coords.grid(color='red')
        ax.coords['ra'].set_axislabel('Right Ascension')
        ax.coords['dec'].set_axislabel('Declination')
    else:
        ax.set_xlabel('Column Number (pixels)')
        ax.set_ylabel('Row Number (pixels)')

    # Mark the locations of stars.
    if srcs:
        from photutils import CircularAperture
        pos = np.transpose((srcs['xcentroid'], srcs['ycentroid']))
        aps = CircularAperture(pos, r=12.)
        aps.plot(color='red', lw=1.5, alpha=0.6, axes=ax)

    # Make room for the colorbar
    fig.subplots_adjust(right=0.8)
    cax = fig.add_axes([0.85, 0.28, 0.05, 0.45])
    c = plt.colorbar(im, cax=cax)
    if label:
        c.set_label(label)
    else:
        if calibrated:
            c.set_label(r'Intensity ($e^{-}/s$)')
        else:
            c.set_label('Intensity (ADU)')

    if png:
        print('Writing {}'.format(png))
        fig.savefig(png)
Exemple #23
0
def star_size(mag, N=None, zmin=None, zmax=None):
    """
    Convert magnitudes into intensities and define sizes of stars in
    finding chart.

    Parameters
    ----------
    mag : array-like
        Magnitudes.
    N : int, optional
        Number of stars used to scale the overall size; defaults to ``mag.size``.
    zmin, zmax : float, optional
        Clipping range for ``mag``. If ``zmin`` is None, it is taken from the
        ZScale limits of ``mag``; ``zmax`` is also filled from ZScale in that
        case when it is None. A given ``zmin`` with ``zmax=None`` clips only
        from below.

    Returns
    -------
    numpy.ndarray
        ``0.1 + factor * 10**((clip(mag) - zmin) / -2.5)``, where ``factor``
        shrinks as ``N`` grows. An empty ``mag`` gives an empty array.
    """
    mag = np.array(mag)
    if mag.size == 0:
        # Nothing to size; also avoids 150 / 0 when N defaults to 0.
        return np.zeros(mag.shape)
    if N is None:
        N = mag.size
    if zmin is None:
        # zmin is required by the size formula below; take it from ZScale.
        zscale_min, zscale_max = ZScaleInterval().get_limits(mag)
        zmin = zscale_min
        if zmax is None:
            zmax = zscale_max

    mag = mag.clip(zmin, zmax)
    factor = 500. * (1 - 1 / (1 + 150 / N**0.85))
    sizes = .1 + factor * (10**((mag - zmin) / -2.5))
    return sizes
Exemple #24
0
def plotTestImage(imageArray):
    """Show ``imageArray`` with detected sources circled.

    The image is displayed with a square-root stretch between its ZScale
    limits. Sources from ``locateStarsInImage`` are marked with red apertures
    of radius 4 pixels. Blocks in ``plt.show()``.
    """
    interval = ZScaleInterval()
    limits = interval.get_limits(imageArray)
    sources = locateStarsInImage(imageArray)
    positions = np.transpose((sources['xcentroid'], sources['ycentroid']))
    apertures = CircularAperture(positions, r=4.)
    # The ZScale limits go into the normalisation: Matplotlib rejects
    # vmin/vmax passed to imshow together with a norm.
    norm = ImageNormalize(vmin=limits[0], vmax=limits[1], stretch=SqrtStretch())
    plt.imshow(imageArray,
               cmap='Greys',
               origin='lower',
               norm=norm,
               interpolation='nearest')
    apertures.plot(color='red', lw=1.5, alpha=0.5)
    plt.show()
Exemple #25
0
    def image_to_png(self, image, outname):
        """Outputs an image array into a png file.

        The PNG is written to ``<self.data_dir>/<outname>.png`` only if that
        file does not exist yet. The image is shown in gray between its ZScale
        limits, titled ``outname``, with a "Signal [DN]" colorbar, and saved
        at 200 dpi.

        Parameters
        ----------
        image : numpy.ndarray
            2D image array

        outname : str
            The name given to the output png file

        Returns
        -------
        output_filename : str
            The full path to the output png file
        """

        output_filename = os.path.join(self.data_dir, '{}.png'.format(outname))

        if not os.path.isfile(output_filename):
            # Get image scale limits
            z = ZScaleInterval()
            vmin, vmax = z.get_limits(image)

            # Plot the image
            fig = plt.figure(figsize=(12, 12))
            try:
                ax = fig.gca()
                im = ax.imshow(image,
                               cmap='gray',
                               origin='lower',
                               vmin=vmin,
                               vmax=vmax)
                ax.set_title('{}'.format(outname))

                # Make the colorbar
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="5%", pad=0.4)
                cbar = fig.colorbar(im, cax=cax)
                cbar.set_label('Signal [DN]')

                fig.savefig(output_filename, bbox_inches='tight', dpi=200)
            finally:
                # Release the figure; pyplot otherwise keeps every one alive.
                plt.close(fig)
            set_permissions(output_filename)
            logging.info('\t{} created'.format(output_filename))
        else:
            logging.info('\t{} already exists'.format(output_filename))

        return output_filename
Exemple #26
0
def scale_image(raw_img):
    """
    Clip pixel intensities to the limits of IRAF's ZScale algorithm.

    Pixels brighter than the upper ZScale limit are set to that limit, and
    then pixels fainter than the lower limit are set to the lower limit.
    The input array is modified in place and also returned.

    Parameters
    ----------
    raw_img : numpy.array(np.float32)
        Mosaic with raw intensities
    Return
    ----------
    raw_img : numpy.array(np.float32)
        The same mosaic object, now clipped
    """
    lower, upper = ZScaleInterval().get_limits(raw_img)

    # Clip the top first, then build the lower mask from the updated array.
    too_bright = raw_img > upper
    raw_img[too_bright] = upper

    too_faint = raw_img < lower
    raw_img[too_faint] = lower

    return raw_img
Exemple #27
0
    def __call__(self, sample):
        """Normalize and stretch the first channel of ``sample['image']``.

        ZScale limits are estimated on the ``[200:800, 200:800]`` cutout of
        ``sample['image'][0]``. The full channel is clipped to those limits and
        rescaled to [0, 1]. A contrast/bias stretch is then applied, the
        result is offset so that pixel ``[0, 0]`` becomes zero, and it is
        passed through ``LinearStretch`` and multiplied by 512.

        Parameters
        ----------
        sample : dict
            Must contain ``'image'`` (``sample['image'][0]`` is indexed as a
            2D array) and ``'clouds'`` (passed through unchanged).

        Returns
        -------
        dict
            ``{'image': array reshaped to (1, H, W), 'clouds': sample['clouds']}``
        """
        channel = sample['image'][0]

        scale = ZScaleInterval()
        vmin, vmax = scale.get_limits(channel[200:800, 200:800])
        span = vmax - vmin
        if span == 0:
            # Constant cutout: (x - vmin) / 0 would fill the image with NaN.
            newimage = np.zeros_like(channel, dtype=float)
        else:
            newimage = (np.clip(channel, vmin, vmax) - vmin) / span

        stretch = ContrastBiasStretch(contrast=0.5, bias=0.2)
        newimage = stretch(newimage)
        newimage -= newimage[0, 0]
        newimage = LinearStretch()(newimage) * 512

        return {
            'image': newimage.reshape(1, *newimage.shape),
            'clouds': sample['clouds']
        }
Exemple #28
0
def plot_aperture_outline2(
    img,
    mask,
    ax=None,
    imgwcs=None,
    cmap="viridis",
    color_aper="C6",
    figsize=None,
):
    """
    Display ``img`` with ZScale limits and outline the aperture ``mask``.

    The mask is sampled on a grid 100x finer than its pixels and contoured at
    level 0.5 to draw the aperture boundary over the image.

    Parameters
    ----------
    img : 2D array
        Image to display.
    mask : 2D array
        Aperture mask; its shape sets the sampling grid and extent.
    ax : matplotlib axes, optional
        Axes to draw on. If None, a new figure is created using ``imgwcs``
        as projection, with RA/Dec axis labels.
    imgwcs : WCS, optional
        Projection used only when ``ax`` is None.
    cmap : str
        Colormap for the image.
    color_aper : str
        Color of the aperture outline.
    figsize : tuple, optional
        Figure size used only when ``ax`` is None.

    Returns
    -------
    ax : matplotlib axes

    see https://github.com/afeinstein20/eleanor/blob/master/eleanor/visualize.py#L78
    """
    interval = ZScaleInterval(contrast=0.5)

    if ax is None:
        _, ax = pl.subplots(subplot_kw={"projection": imgwcs},
                            figsize=figsize)
        ax.set_xlabel("RA")
        ax.set_ylabel("Dec")
    x = np.linspace(0, mask.shape[1], mask.shape[1] * 100)
    y = np.linspace(0, mask.shape[0], mask.shape[0] * 100)
    extent = [0 - 0.5, x[:-1].max() - 0.5, 0 - 0.5, y[:-1].max() - 0.5]
    X, Y = np.meshgrid(x[:-1], y[:-1])
    # Mask value of the pixel containing each fine-grid point. Fancy indexing
    # replaces a per-element np.vectorize call; astype(int) truncates like
    # int() for these non-negative coordinates.
    Z = mask[Y[:-1].astype(int), X[:-1].astype(int)]
    # plot contour
    _ = ax.contour(
        Z[::-1],
        levels=[0.5],
        colors=color_aper,
        linewidths=[3],
        extent=extent,
        origin="lower",
    )
    zmin, zmax = interval.get_limits(img)
    # plot image
    ax.matshow(img,
               origin="lower",
               cmap=cmap,
               vmin=zmin,
               vmax=zmax,
               extent=extent)
    return ax
Exemple #29
0
    def image_to_png(self, image, outname):
        """Outputs an image array into a png file.

        The image is drawn in grayscale with ZScale display limits and a
        readnoise-difference colorbar. An existing file with the same name is
        overwritten.

        Parameters
        ----------
        image : numpy.ndarray
            2D image array.

        outname : str
            The name given to the output png file.

        Returns
        -------
        output_filename : str
            The full path to the output png file.
        """
        output_filename = os.path.join(self.data_dir, '{}.png'.format(outname))

        # Get image scale limits
        zscale = ZScaleInterval()
        vmin, vmax = zscale.get_limits(image)

        # Plot the image
        fig = plt.figure(figsize=(12, 12))
        try:
            im = plt.imshow(image,
                            cmap='gray',
                            origin='lower',
                            vmin=vmin,
                            vmax=vmax)
            plt.colorbar(
                im, label='Readnoise Difference (most recent dark - reffile) [DN]')
            plt.title('{}'.format(outname))

            # Save the figure. savefig always replaces an existing file and has
            # no ``overwrite`` keyword; newer Matplotlib rejects unknown kwargs.
            fig.savefig(output_filename, bbox_inches='tight', dpi=200)
        finally:
            # Release the figure so repeated calls do not accumulate figures.
            plt.close(fig)

        set_permissions(output_filename)
        logging.info('\t{} created'.format(output_filename))

        return output_filename
Exemple #30
0
def plot_aperture_outline(
    img,
    mask,
    ax=None,
    imgwcs=None,
    cmap="viridis",
    color_aper="C6",
    figsize=None,
):
    """
    Display ``img`` with ZScale limits and outline the aperture ``mask``.

    The mask is converted to a 0/1 array, padded by one pixel, upsampled 100x
    with nearest-neighbour zoom, and contoured at level 0.5 over the image.

    Parameters
    ----------
    img : 2D array
        Image to display.
    mask : 2D array
        Aperture mask; non-zero entries are inside the aperture.
    ax : matplotlib axes, optional
        Axes to draw on. If None, a new figure is created using ``imgwcs``
        as projection, with RA/Dec axis labels.
    imgwcs : WCS, optional
        Projection used only when ``ax`` is None.
    cmap : str
        Colormap for the image.
    color_aper : str
        Color of the aperture outline.
    figsize : tuple, optional
        Figure size used only when ``ax`` is None.

    Returns
    -------
    ax : matplotlib axes

    see https://github.com/rodluger/everest/blob/56f61a36625c0d9a39cc52e96e38d257ee69dcd5/everest/standalone.py
    """
    interval = ZScaleInterval(contrast=0.5)
    ny, nx = mask.shape
    contour = np.zeros((ny, nx))
    contour[np.where(mask)] = 1
    # np.pad is the public API; the np.lib.pad alias is deprecated and no
    # longer exposed in NumPy 2's np.lib namespace. PadWithZeros is passed
    # through as the padding-function mode, as before.
    contour = np.pad(contour, 1, PadWithZeros)
    highres = zoom(contour, 100, order=0, mode="nearest")
    extent = np.array([-1, nx, -1, ny])

    if ax is None:
        _, ax = pl.subplots(subplot_kw={"projection": imgwcs},
                            figsize=figsize)
        ax.set_xlabel("RA")
        ax.set_ylabel("Dec")
    _ = ax.contour(
        highres,
        levels=[0.5],
        linewidths=[3],
        extent=extent,
        origin="lower",
        colors=color_aper,
    )
    zmin, zmax = interval.get_limits(img)
    ax.matshow(img,
               origin="lower",
               cmap=cmap,
               vmin=zmin,
               vmax=zmax,
               extent=extent)
    return ax
Exemple #31
0
def make_rrl_finder_chart(target_stem, ra, dec):
    """Show a grayscale finder chart with the target position marked.

    The target stem is normalised ('.' -> '_', '-' removed) and mapped through
    ``sgr_setup.get_target_stem``. The image ``<stem>_e1_3p6um.fits`` is shown
    with ZScale limits (inverted grayscale), and two translucent magenta
    markers are drawn at (ra, dec).

    Parameters
    ----------
    target_stem : str
        Target name used to build the FITS file name.
    ra, dec : float or str
        Target coordinates, interpreted in degrees.
    """
    target_stem = target_stem.replace('.', '_').replace('-', '')
    target_stem = sgr_setup.get_target_stem(target_stem)

    # Build the coordinate once and take both components from it
    coord = SkyCoord(str(ra) + ' ' + str(dec), unit=(u.deg, u.deg))
    ra, dec = coord.ra, coord.dec

    fitsfile = target_stem + '_e1_3p6um.fits'
    # '%s' formatting prints the same text under Python 2 and Python 3
    print('%s %s' % (ra, dec))

    # Manually compute ZScale limits; the context manager closes the file
    with astropy.io.fits.open(fitsfile) as hdul:
        zmin, zmax = ZScaleInterval().get_limits(hdul[0].data)

    fig = mp.figure(figsize=(10, 10))
    mosaic = aplpy.FITSFigure(fitsfile, figure=fig)
    mosaic.show_grayscale(vmin=zmin, vmax=zmax, invert=True)
    mosaic.tick_labels.set_font(size='small')
    mosaic.tick_labels.set_xformat("hh:mm:ss")
    mosaic.set_theme('publication')
    # Two concentric translucent markers highlight the target
    mosaic.show_markers(ra.deg, dec.deg, edgecolor='magenta', facecolor='magenta', marker='o', s=100, alpha=0.3)
    mosaic.show_markers(ra.deg, dec.deg, edgecolor='magenta', facecolor='magenta', marker='o', s=300, alpha=0.1)
    mp.show()
Exemple #32
0
def image_thumbnails(dataMap,photCat,band='g',objid=None,
                     nbin=None,old=False,trim=None):
	"""Write per-object PDF pages of cutout thumbnails for one band.

	For every (objId, filter) group in ``photCat.bokPhot`` whose filter equals
	``band``, opens ``<outdir>bok<catname><objId>_<band>.fits`` (one cutout per
	HDU after the first) and tiles the cutouts 8 rows x 6 columns per page
	into a PDF of the same stem. Objects whose PDF already exists are skipped.

	Parameters
	----------
	dataMap : object
		Provides ``obsDb`` (a table with at least 'mjd' and 'expTime'; the
		'rmqso' branch also reads 'frameIndex').
	photCat : object
		Provides ``bokPhot`` (photometry table) and ``name``.
	band : str
		Filter to render.
	objid : optional
		If given, only this objId is processed.
	nbin : optional
		Block-average factor passed to ``block_reduce`` for each cutout.
	old : bool
		Read from 'bokcutouts_old/' and skip the CCD reorientation.
	trim : int, optional
		Number of pixels trimmed from every edge of each cutout.

	NOTE(review): ``errlog`` is only opened when photCat.name == 'rmqso', yet
	it is written whenever the cutout and observation counts differ; for any
	other catalog that path raises NameError. ``cutfits`` is never closed.
	"""
	from astropy.visualization import ZScaleInterval
	from matplotlib.backends.backend_pdf import PdfPages
	# load object database
	objs = photCat.bokPhot
	try:
		# Alias 'frameId' as 'frameIndex' when the catalog uses the former name
		objs['frameIndex'] = objs['frameId']
	except:
		pass
	if objid is not None:
		obj_ii = np.where(objs['objId']==objid)[0]
		objs = objs[obj_ii]
	tmpObsDb = dataMap.obsDb.copy()
	# Mid-exposure MJD: start MJD plus half the exposure time converted to days
	tmpObsDb['mjd_mid'] = tmpObsDb['mjd'] + (tmpObsDb['expTime']/2)/(3600*24.)
	objs = bokrmphot.join_by_frameid(objs,tmpObsDb)
	objs = objs.group_by(['objId','filter'])
	# configure figures
	nrows,ncols = 8,6
	figsize = (7.0,10.25)
	subplots = (0.11,0.07,0.89,0.93,0.00,0.03)
	size = 65  # NOTE(review): not used anywhere in this function
	zscl = ZScaleInterval()
	nplot = nrows*ncols
	if old:
		outdir = 'bokcutouts_old/'
	else:
		outdir = 'bokcutouts/'
	ccdcolors = ['darkblue','darkgreen','darkred','darkmagenta']
	# ('True and' looks like a leftover on/off toggle; only the name test matters)
	if True and photCat.name=='rmqso':
		# Load difference-photometry flags and map each flagged epoch to the
		# frameIndex of the observation whose mid-exposure MJD is closest.
		diffphot = Table.read('bok%s_photflags.fits'%band)
		errlog = open('bokflags_%s_err.log'%band,'a')
		bitstr = [ 'TinyFlux','BigFlux','TinyErr','BigErr','BigOff']
		frameid = np.zeros(diffphot['MJD'].shape,dtype=np.int32)
		for i in range(len(diffphot)):
			jj = np.where(diffphot['MJD'][i]>0)[0]
			for j in jj:
				dt = diffphot['MJD'][i,j]-tmpObsDb['mjd_mid']
				_j = np.abs(dt).argmin()
				# Matches must agree to within 5e-4 days
				if np.abs(dt[_j]) > 5e-4:
					raise ValueError("no match for ",i,diffphot['MJD'][i,j])
				else:
					frameid[i,j] = tmpObsDb['frameIndex'][_j]
		diffphot['frameIndex'] = frameid
		# Epochs with MJD == 0 start out "matched" so they are never reported
		matched = diffphot['MJD'] == 0 # ignoring these
	plt.ioff()
	for k,obj in zip(objs.groups.keys,objs.groups):
		objId = k['objId']
		_band = k['filter']
		if _band != band: continue
		if photCat.name=='rmqso' and objId >= 850:
			break
		cutfile = outdir+'bok%s%03d_%s.fits' % (photCat.name,
		                                        obj['objId'][0],band)
		pdffile = cutfile.replace('.fits','.pdf')
		if os.path.exists(pdffile) or len(obj)==0:
			continue
		pdf = PdfPages(pdffile)
		cutfits = fits.open(cutfile)
		# number cutouts matches number observations
		# NOTE(review): the message says "skipping" but processing continues;
		# the zip below simply stops at the shorter of the two sequences.
		if len(cutfits)-1 != len(obj):
			errlog.write('[RM%03d]: %d cutouts, %d obs; skipping\n' %
			              (obj['objId'][0],len(cutfits)-1,len(obj)))
		pnum = -1
		for i,(obs,hdu) in enumerate(zip(obj,cutfits[1:])):
			sys.stdout.write('\rRM%03d %4d/%4d' % 
			                 (obs['objId'],(i+1),len(obj)))
			sys.stdout.flush()
			ccdNum = obs['ccdNum']
			cut = hdu.data
			# Display limits: ZScale on positive pixels, falling back to the
			# 10th/90th percentiles, then to the full data range.
			try:
				z1,z2 = zscl.get_limits(cut[cut>0])
			except:
				try:
					z1,z2 = np.percentile(cut[cut>0],[10,90])
				except:
					z1,z2 = cut.min(),cut.max()
			if not old:
				# rotate to N through E
				if ccdNum==1:
					cut = cut[:,::-1]
				elif ccdNum==2:
					cut = cut[::-1,::-1]
				elif ccdNum==3:
					pass
				elif ccdNum==4:
					cut = cut[::-1,:]
				# except now flip x-axis so east is right direction
				cut = cut[:,::-1]
			if trim is not None:
				cut = cut[trim:-trim,trim:-trim]
			if nbin is not None:
				cut = block_reduce(cut,nbin,np.mean)
			#
			# Start a new page on the first cutout and after every nplot panels
			if pnum==nplot+1 or pnum==-1:
				if pnum != -1:
					pdf.savefig()
					plt.close()
				plt.figure(figsize=figsize)
				plt.subplots_adjust(*subplots)
				pnum = 1
			ax = plt.subplot(nrows,ncols,pnum)
			plt.imshow(cut,origin='lower',interpolation='nearest',
			           vmin=z1,vmax=z2,cmap=plt.cm.gray_r,aspect='equal')
			framestr1 = '(%d,%d,%d)' % (obs['ccdNum'],obs['x'],obs['y'])
			framestr2 = '%.3f' % (obs['mjd'])
			utstr = obs['utDate'][2:]+' '+obs['utObs'][:5]
			frameclr = ccdcolors[obs['ccdNum']-1]
			ax.set_title(utstr,size=7,color='k',weight='bold')
			t = ax.text(0.01,0.98,framestr1,
			            size=7,va='top',color=frameclr,
			            transform=ax.transAxes)
			t.set_bbox(dict(color='white',alpha=0.45,boxstyle="square,pad=0"))
			t = ax.text(0.01,0.02,framestr2,
			            size=7,color='blue',
			            transform=ax.transAxes)
			t.set_bbox(dict(color='white',alpha=0.45,boxstyle="square,pad=0"))
			if obs['flags'][2] > 0:
				t = ax.text(0.03,0.7,'%d' % obs['flags'][2],
				            size=10,ha='left',va='top',color='red',
				            transform=ax.transAxes)
			if True and photCat.name=='rmqso':
				# NOTE(review): assumes diffphot row index == objId; confirm.
				_j = np.where(diffphot['frameIndex'][objId] ==
				              obs['frameIndex'])[0]
				if len(_j)>0:
					matched[objId,_j] = True
					flg = diffphot['FLAG'][objId,_j]
					if flg > 0:
						# Names of the set bits in the difference-photometry flag
						flgstr = [ s for bit,s in enumerate(bitstr)
						               if (flg & (1<<bit)) > 0 ]
						t = ax.text(0.97,0.8,'\n'.join(flgstr),
						            size=10,ha='right',va='top',color='red',
						            transform=ax.transAxes)
				else:
					errlog.write('no diff phot for %d %.4f %.4f\n' % 
					             (objId,obs['mjd'],obs['mjd_mid']))
			ax.xaxis.set_visible(False)
			ax.yaxis.set_visible(False)
			pnum += 1
		if True and photCat.name=='rmqso':
			# Log difference-photometry epochs that no cutout was matched to
			jj = np.where(~matched[objId])[0]
			if len(jj)>0:
				errlog.write('unmatched for %d:\n'%objId)
				for j in jj:
					errlog.write('    %.5f  %d\n'%
					     (diffphot['MJD'][objId,j],diffphot['FLAG'][objId,j]))
		# Flush the last (possibly partial) page; errors are ignored here
		try:
			pdf.savefig()
			plt.close()
		except:
			pass
		pdf.close()
	plt.ion()
	if True and photCat.name=='rmqso':
		errlog.close()
def on_key(event):
    """Keyboard handler for the single-image mask editor.

    Keys
    ----
    '1'..'6'       : set the global ``brushSize`` to that number.
    'right'/'left' : step ``imgNum`` forward/backward through ``imgList``,
                     wrapping at either end. Loads ``<stokesDir>/<base>_mask.fits``
                     if it exists, otherwise starts a blank int16 mask, then
                     redraws the image, its limits, extent, title and contour.
    'enter'        : write ``maskImg`` to its filename (clobbering).
    'backspace'    : zero the mask and redraw its contour.

    Besides the globals declared below, this reads module-level ``ax`` and
    ``axImg``.
    """
    global xxList, yyList, imgList, imgNum
    global fig, brushSize, maskImg
    global stokesDir

    def _redraw_mask_contour(xx, yy):
        # Remove old contour lines, then contour the current mask. Assigning
        # ``ax.collections = []`` fails on newer Matplotlib (3.5+), where
        # ``collections`` is a read-only property; Artist.remove() works in all.
        for artist in list(ax.collections):
            artist.remove()
        ax.contour(xx, yy, maskImg.data, levels=[0.5], colors='white', alpha=0.2)

    # Handle brush sizing
    if event.key in ('1', '2', '3', '4', '5', '6'):
        brushSize = int(event.key)

    # Change the displayed image
    if event.key in ('right', 'left'):
        if event.key == 'right':
            # Advance, looping back to the start after the last image
            imgNum += 1
            if imgNum > imgList.size - 1:
                imgNum = 0
        else:
            # Go back, looping to the end before the first image
            imgNum -= 1
            if imgNum < 0:
                imgNum = imgList.size - 1

        # Build the image scaling intervals
        img = imgList[imgNum]
        thisMin, thisMax = ZScaleInterval().get_limits(img.data)
        # Raise the upper display limit to 10x the ZScale maximum
        thisMax *= 10

        # Use the saved mask for this file if there is one, else a blank slate
        baseFile = os.path.basename(img.filename).split('_I')[0]
        maskFile = os.path.join(stokesDir, baseFile + '_mask.fits')
        if os.path.isfile(maskFile):
            print('using this mask: ', os.path.basename(maskFile))
            maskImg = ai.reduced.ReducedScience.read(maskFile)
        else:
            # Mask template (0 = not masked, 1 = masked)
            maskImg = ai.reduced.ReducedScience(
                (img.data*0).astype(np.int16),
                header=img.header
            )
            maskImg.filename = maskFile

        # Grab the pixel positions for this image
        yy, xx = yyList[imgNum], xxList[imgNum]

        _redraw_mask_contour(xx, yy)

        # Reassign display limits, data and extent
        axImg.set_clim(vmin=thisMin, vmax=thisMax)
        axImg.set_data(img.data)
        axImg.set_extent((xx.min(), xx.max(), yy.min(), yy.max()))

        # Update the annotation and the display
        ax.set_title(os.path.basename(img.filename))
        fig.canvas.draw()

    # Save the generated mask
    if event.key == 'enter':
        print('Writing mask for file {}'.format(maskImg.filename))
        maskImg.write(clobber=True)

    # Clear out the mask values
    if event.key == 'backspace':
        try:
            maskImg.data = (maskImg.data*0).astype(np.int16)
            # xx/yy are only bound in the arrow-key branch, so fetch the grids
            # again here (previously this raised UnboundLocalError, which the
            # bare except hid, so the display never refreshed).
            yy, xx = yyList[imgNum], xxList[imgNum]
            _redraw_mask_contour(xx, yy)
            fig.canvas.draw()
        except Exception as err:
            # Best effort (e.g. no mask loaded yet), but report instead of hiding
            print('Could not clear the mask: {}'.format(err))
def on_key(event):
    """Keyboard handler for the three-panel (previous/current/next) mask editor.

    Keys
    ----
    '1'..'6'       : set the global ``brushSize`` to that number.
    'right'/'left' : shift the prev/this/next window one file along
                     ``fileList`` (wrapping), read the newly exposed file,
                     compute its ZScale limits, pick a mask (this file's, else
                     the previous/next file's if it has the same target, else a
                     blank int16 copy), and redraw all three panels.
    'enter'        : copy the current header onto ``maskImg`` and write it to
                     ``maskDir`` under the current file's basename (clobbering).
    'backspace'    : zero the mask and redraw its contour.

    Besides the globals declared below, this reads module-level ``axarr``,
    ``xx`` and ``yy``.
    """
    global fileList, targetList, fig, imgNum, brushSize
    global maskDir, maskImg
    global prevImg,   thisImg,   nextImg
    global prevAxImg, thisAxImg, nextAxImg
    global prevTarget, thisTarget, nextTarget
    global prevMin,   thisMin,   nextMin
    global prevMax,   thisMax,   nextMax
    global prevLabel, thisLabel, nextLabel

    def _redraw_mask_contour():
        # Remove old contour lines, then contour the current mask. Assigning
        # ``axarr[1].collections = []`` fails on newer Matplotlib (3.5+), where
        # ``collections`` is a read-only property; Artist.remove() works in all.
        for artist in list(axarr[1].collections):
            artist.remove()
        axarr[1].contour(xx, yy, maskImg.data, levels=[0.5], colors='white', alpha=0.2)

    def _panel_label(img):
        # OBJECT, filter name and HWP header values, one per line
        return '\n'.join(str(img.header[key]) for key in ('OBJECT', 'FILTNME2', 'HWP'))

    # Handle brush sizing
    if event.key in ('1', '2', '3', '4', '5', '6'):
        brushSize = int(event.key)

    # Shift the prev/this/next window
    if event.key in ('right', 'left'):
        zScaleGetter = ZScaleInterval()
        nFiles = len(fileList)

        if event.key == 'right':
            imgNum += 1
            prevImg, thisImg = thisImg, nextImg
            nextImg = ai.reduced.ReducedScience.read(fileList[(imgNum + 1) % nFiles])

            prevTarget, thisTarget = thisTarget, nextTarget
            nextTarget = targetList[(imgNum + 1) % nFiles]

            prevMin, thisMin = thisMin, nextMin
            prevMax, thisMax = thisMax, nextMax
            # A single ZScale fit yields both limits (it used to run twice)
            nextMin, nextMax = zScaleGetter.get_limits(nextImg.data)
        else:
            imgNum -= 1
            nextImg, thisImg = thisImg, prevImg
            prevImg = ai.reduced.ReducedScience.read(fileList[(imgNum - 1) % nFiles])

            nextTarget, thisTarget = thisTarget, prevTarget
            prevTarget = targetList[(imgNum - 1) % nFiles]

            nextMin, thisMin = thisMin, prevMin
            nextMax, thisMax = thisMax, prevMax
            prevMin, prevMax = zScaleGetter.get_limits(prevImg.data)

        #*******************************
        # Update the displayed mask
        #*******************************
        prevMaskFile = os.path.join(maskDir, os.path.basename(prevImg.filename))
        thisMaskFile = os.path.join(maskDir, os.path.basename(thisImg.filename))
        nextMaskFile = os.path.join(maskDir, os.path.basename(nextImg.filename))
        if os.path.isfile(thisMaskFile):
            print('using this mask: ', os.path.basename(thisMaskFile))
            maskImg = ai.reduced.ReducedScience.read(thisMaskFile)
        elif os.path.isfile(prevMaskFile) and (prevTarget == thisTarget):
            print('using previous mask: ', os.path.basename(prevMaskFile))
            maskImg = ai.reduced.ReducedScience.read(prevMaskFile)
        elif os.path.isfile(nextMaskFile) and (nextTarget == thisTarget):
            print('using next mask: ', os.path.basename(nextMaskFile))
            maskImg = ai.reduced.ReducedScience.read(nextMaskFile)
        else:
            # Blank mask template (0 = not masked, 1 = masked)
            maskImg = thisImg.copy()
            maskImg.filename = thisMaskFile
            maskImg = maskImg.astype(np.int16)
            # Drop any uncertainty array carried over from the image
            try:
                del maskImg.uncertainty
            except Exception:
                pass

        _redraw_mask_contour()

        # Reassign image display limits
        prevAxImg.set_clim(vmin=prevMin, vmax=prevMax)
        thisAxImg.set_clim(vmin=thisMin, vmax=thisMax)
        nextAxImg.set_clim(vmin=nextMin, vmax=nextMax)

        # Display the new images
        prevAxImg.set_data(prevImg.data)
        thisAxImg.set_data(thisImg.data)
        nextAxImg.set_data(nextImg.data)

        # Update the annotations
        axList = fig.get_axes()
        axList[1].set_title(os.path.basename(thisImg.filename))
        prevLabel.set_text(_panel_label(prevImg))
        thisLabel.set_text(_panel_label(thisImg))
        nextLabel.set_text(_panel_label(nextImg))

        fig.canvas.draw()

    # Save the generated mask
    if event.key == 'enter':
        # Make sure the header has the right values
        maskImg.header = thisImg.header

        # TODO: make sure the mask ONLY has what it needs
        # i.e., remove uncertainty and convert to np.ubyte type.

        maskBasename = os.path.basename(thisImg.filename)
        maskFullname = os.path.join(maskDir, maskBasename)
        print('Writing mask for file {}'.format(maskBasename))
        maskImg.write(maskFullname, clobber=True)

    # Clear out the mask values
    if event.key == 'backspace':
        maskImg.data = maskImg.data * np.byte(0)
        _redraw_mask_contour()
        fig.canvas.draw()
# Generate 2D X and Y position maps
# grids[0] holds each pixel's row index and grids[1] its column index; the
# three-panel on_key handler contours the mask against these xx/yy globals.
maskShape = maskImg.shape
grids     = np.mgrid[0:maskShape[0], 0:maskShape[1]]
xx        = grids[1]
yy        = grids[0]

# Build the image displays
# Start by preparing a 1x3 plotting area (previous / current / next image)
fig, axarr = plt.subplots(1, 3, sharey=True)

# Build the image scaling intervals
zScaleGetter = ZScaleInterval()

# Compute image count scaling
# NOTE(review): prevImg/thisImg/nextImg and maskImg are presumably loaded
# earlier in the script (not visible in this fragment).
prevMin, prevMax = zScaleGetter.get_limits(prevImg.data)
thisMin, thisMax = zScaleGetter.get_limits(thisImg.data)
nextMin, nextMax = zScaleGetter.get_limits(nextImg.data)

# Unused alternative: median/std-based display limits, kept for reference.
# prevMin = np.median(prevImg.data) - 0.25*np.std(prevImg.data)
# prevMax = np.median(prevImg.data) + 2*np.std(prevImg.data)
# thisMin = np.median(thisImg.data) - 0.25*np.std(thisImg.data)
# thisMax = np.median(thisImg.data) + 2*np.std(thisImg.data)
# nextMin = np.median(nextImg.data) - 0.25*np.std(nextImg.data)
# nextMax = np.median(nextImg.data) + 2*np.std(nextImg.data)

# Populate each axis with its image
# NOTE(review): nextAxImg (used by the three-panel on_key) is not created in
# this visible fragment; confirm it is set up elsewhere.
prevAxImg = prevImg.show(axes = axarr[0], cmap='viridis',
                                        vmin = prevMin, vmax = prevMax, noShow = True)
thisAxImg = thisImg.show(axes = axarr[1], cmap='viridis',
                                        vmin = thisMin, vmax = thisMax, noShow = True)
#Plot the polynomial
# Evaluate the trace polynomial (legval, presumably numpy's Legendre evaluator)
# on n, which maps y_range linearly from xmin..xmax onto -1..1.
y_range     = arange(float(xmin), float(xmax))
n           = (2 * y_range - (xmax + xmin)) / (xmax - xmin)
poly_leg    = legval(n, coefs)
trace_curve = poly_leg + aper_center[0]
low_limit   = trace_curve + aper_low[0]
high_limit  = trace_curve + aper_high[0]


#Plot Background region
# Background regions are the "low:high" offset pairs (relative to the trace)
# listed after the keyword on the first line containing '\t\tsample'.
idx_background  = [i for i, line in enumerate(file_lines) if '\t\tsample' in line]
background_line = file_lines[idx_background[0]].split()[1:]
for region in background_line:
    # A list, not map(): on Python 3 map() returns an iterator, which cannot
    # be indexed with [0]/[1].
    limits_region       = [float(value) for value in region.split(':')]
    low_limit_region    = trace_curve + limits_region[0]
    high_limit_region   = trace_curve + limits_region[1]
    Axis.fill_betweenx(y_range, low_limit_region, high_limit_region, alpha=0.1, facecolor='yellow')


# ZScale display limits; get_limits does the full fit, so call it only once
IntensityLimits = ZScaleInterval()
int_min, int_max = IntensityLimits.get_limits(image_data)
Axis.imshow(image_data, cmap='bone', origin='lower', vmin = int_min, vmax = int_max, interpolation='nearest')
Axis.plot(trace_curve, y_range, color='red')
# Shade the extraction aperture between its low and high limits
Axis.fill_betweenx(y_range, low_limit, high_limit, alpha=0.3, facecolor='purple')


plt.axis('tight')

plt.show()