Example #1
def log_norm(image: np.ndarray, a=2000):
    return ImageNormalize(image, interval=ZScaleInterval(), stretch=LogStretch(a))
            # use the cleaned image if it exists
            if not os.path.isfile('{}/images/image_filtered_low_clean.fits'.format(odf_dir)):
                fits_image = '{}/images/image_filtered_low.fits'.format(odf_dir)
            else:
                fits_image = '{}/images/image_filtered_low_clean.fits'.format(odf_dir)

            hdu = fits.open(fits_image)
            wcs = WCS(hdu[0].header)
            g2_kernel = Gaussian2DKernel(2)
            smoothed_data_g2 = convolve(hdu[0].data, g2_kernel, mask=np.logical_not(detmask.data)) * detmask.data

            fig = plt.figure(figsize=(10, 10), dpi=100)
            pp = 99.9  # colour cut percentage

            ax = fig.add_subplot(111, projection=wcs)
            #ax.set_title("Gaussian smoothed image")
            norm_xmm = ImageNormalize(smoothed_data_g2, interval=ManualInterval(vmin=0.01, vmax=100.0),
                                      stretch=LogStretch())
            # norm_xmm = ImageNormalize(smoothed_data_g2,interval=PercentileInterval(pp), stretch=AsinhStretch())
            ax.imshow(smoothed_data_g2, cmap=plt.cm.hot, norm=norm_xmm, origin='lower', interpolation='nearest')
            ax.set_xlabel('RA')
            ax.set_ylabel('DEC')
            #
            # only show the last aperture
            #
            circle_sky = CircleSkyRegion(center=center, radius=r_end)
            pix_reg = circle_sky.to_pixel(wcs)
            pix_reg.plot(ax=ax, edgecolor='yellow')
            plt.savefig('/home/aaranda/tfm/results_v2/{}/{}/smoothed_g2_image_low.png'.format(target, obsid))
            plt.close(fig)
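
A hedged usage sketch of the log_norm helper defined at the top of this example (it is not called in the excerpt itself); the image here is synthetic stand-in data, and the helper's astropy imports are assumed to be in scope:

import numpy as np
import matplotlib.pyplot as plt

image = np.random.lognormal(size=(64, 64))  # stand-in for FITS pixel data
plt.imshow(image, norm=log_norm(image), origin='lower', cmap='hot')
plt.show()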
Example #3
rt = razmap(dfile,
            rbins,
            tbins,
            incl=disk.disk['HD143006']['incl'],
            PA=disk.disk['HD143006']['PA'],
            offx=disk.disk['HD143006']['dx'],
            offy=disk.disk['HD143006']['dy'])

# Image setups
im_bounds = (rt.dRA.max(), rt.dRA.min(), rt.dDEC.min(), rt.dDEC.max())
dRA_lims, dDEC_lims = [rs * rout, -rs * rout], [-rs * rout, rs * rout]
rt_bounds = (rbins.min(), rbins.max(), tbins.min(), tbins.max())
rlims, tlims = [0, rbins.max()], [tbins.min(), tbins.max()]

# intensity limits and stretch
cnorm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=LinearStretch())

# Residual emission map (sky-plane, Cartesian)
plt.style.use('default')
plt.rc('font', size=6)

# set location in figure
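# (fig and gs are not created in this excerpt; presumably something like
# fig = plt.figure() and gs = fig.add_gridspec(...) precedes this block)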
ax = fig.add_subplot(gs[0, 0])

# plot the residual image
im = ax.imshow(1e6 * rt.image,
               origin='lower',
               cmap=cmap,
               extent=im_bounds,
               norm=cnorm,
               aspect='equal')
Example #4
def kepmask(infile, frameno=100, maskfile='maskfile.txt', plotfile='kepmask.png',
            imin=None, imax=None, iscale='linear', cmap='bone',
            verbose=False, logfile='kepmask.log'):
    """
    kepmask - plots, creates or edits custom target masks for target pixel
    files.

    The product from this task is a target mask definition file which
    can be used by kepextract to extract a light curve from target pixel data.
    This tool is a GUI interface for defining a pixel mask by moving a mouse
    over image pixels and selecting them by pressing the left-button of your
    mouse/keypad.

    Parameters
    ----------
    infile : str
        The name of a target pixel file from the MAST Kepler archive,
        containing a standard mask definition image in the second data
        extension.
    frameno : int
        Frame number in the target pixel file.
    maskfile : str
        The name of an ASCII mask definition file. This is either the name of
        a file to be plotted, a file to be created, or a file to be edited.
    plotfile : str
        The name of a PNG plot file containing a record of the mask defined or
        uploaded by this task.
    imin : float or None
        Minimum intensity value (in electrons per cadence) for the image
        display. If None, it defaults to the lower limit of a 95%
        percentile interval of the image.
    imax : float or None
        Maximum intensity value (in electrons per cadence) for the image
        display. If None, it defaults to the upper limit of a 95%
        percentile interval of the image.
    iscale : str
        Type of intensity scaling for the image display.
        * linear
        * log
        * sqrt
    cmap : str
        Color intensity scheme for the image display.
    verbose : bool
        Print informative messages and warnings to the shell and logfile?
    logfile : str
        Name of the logfile containing error and warning messages.

    Examples
    --------
    .. code-block:: bash

        $ kepmask ktwo202933888-c02_lpd-targ.fits.gz

    .. image:: ../_static/images/api/kepmask.png
        :align: center
    """

    global pimg, zscale, zmin, zmax, xmin, xmax, ymin, ymax, quarter, norm
    global pxdim, pydim, kepmag, skygroup, season, channel
    global module, output, row, column, mfile, pfile
    global pkepid, pkepmag, pra, pdec, colmap

    # input arguments
    zmin = imin
    zmax = imax
    zscale = iscale
    colmap = cmap
    mfile = maskfile
    pfile = plotfile

    # log the call
    hashline = '--------------------------------------------------------------'
    kepmsg.log(logfile, hashline, verbose)
    call = ('KEPMASK -- '
            + ' infile={}'.format(infile)
            + ' maskfile={}'.format(mfile)
            + ' plotfile={}'.format(pfile)
            + ' frameno={}'.format(frameno)
            + ' imin={}'.format(imin)
            + ' imax={}'.format(imax)
            + ' iscale={}'.format(iscale)
            + ' cmap={}'.format(cmap)
            + ' verbose={}'.format(verbose)
            + ' logfile={}'.format(logfile))

    kepmsg.log(logfile, call + '\n', verbose)
    kepmsg.clock('KEPMASK started at', logfile, verbose)

    # open TPF FITS file and check whether or not frameno exists
    try:
        tpf = pyfits.open(infile, mode='readonly')
    except Exception:
        errmsg = ('ERROR -- KEPIO.OPENFITS: cannot open ' +
                  infile + ' as a FITS file')
        kepmsg.err(logfile, errmsg, verbose)

    try:
        naxis2 = tpf['TARGETTABLES'].header['NAXIS2']
    except Exception:
        errmsg = ('ERROR -- KEPMASK: No NAXIS2 keyword in ' + infile +
                  '[TARGETTABLES]')
        kepmsg.err(logfile, errmsg, verbose)

    if frameno > naxis2:
        errmsg = ('ERROR -- KEPMASK: frameno is too large. There are'
                  ' {} rows in the table.'.format(naxis2))
        kepmsg.err(logfile, errmsg, verbose)

    tpf.close()

    # read TPF data pixel image
    kepid, channel, skygroup, module, output, quarter, season, \
    ra, dec, column, row, kepmag, xdim, ydim, pixels = \
        kepio.readTPF(infile, 'FLUX', logfile, verbose)
    img = pixels[frameno]
    pkepid = copy(kepid)
    pra = copy(ra)
    pdec = copy(dec)
    pkepmag = copy(kepmag)
    pxdim = copy(xdim)
    pydim = copy(ydim)
    pimg = copy(img)

    # print target data
    print('')
    print('      KepID:  {}'.format(kepid))
    print(' RA (J2000):  {}'.format(ra))
    print('Dec (J2000):  {}'.format(dec))
    print('     KepMag:  {}'.format(kepmag))
    print('   SkyGroup:  {}'.format(skygroup))
    print('     Season:  {}'.format(season))
    print('    Channel:  {}'.format(channel))
    print('     Module:  {}'.format(module))
    print('     Output:  {}'.format(output))
    print('')

    # subimage of channel for plot
    ymin = copy(row)
    ymax = ymin + ydim
    xmin = copy(column)
    xmax = xmin + xdim

    # intensity scale: fill in any missing limit from a 95% percentile interval
    if imin is None:
        imin, _ = PercentileInterval(95.).get_limits(pimg)
    if imax is None:
        _, imax = PercentileInterval(95.).get_limits(pimg)

    if zscale == 'sqrt':
        norm = ImageNormalize(vmin=imin, vmax=imax, stretch=SqrtStretch())
    elif zscale == 'log':
        norm = ImageNormalize(vmin=imin, vmax=imax, stretch=LogStretch())
    else:  # 'linear' (and any unrecognized value) falls back to a linear stretch
        norm = ImageNormalize(vmin=imin, vmax=imax, stretch=LinearStretch())

    zmin = copy(imin)
    zmax = copy(imax)

    # plot limits
    ymin = float(ymin) - 0.5
    ymax = float(ymax) - 0.5
    xmin = float(xmin) - 0.5
    xmax = float(xmax) - 0.5

    # plot style
    plt.rcParams['figure.dpi'] = 80
    plt.figure(figsize=[10, 7])

    global mask, aid, bid, cid, did, fid

    aid = plt.connect('button_press_event', clicker1)
    bid = plt.connect('button_press_event', clicker2)
    cid = plt.connect('button_press_event', clicker3)
    did = plt.connect('button_press_event', clicker4)
    fid = plt.connect('button_press_event', clicker6)

    redraw()
    plt.show()
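
The default intensity limits in kepmask come from astropy's PercentileInterval; a minimal standalone sketch of that scaling step on synthetic data:

import numpy as np
from astropy.visualization import ImageNormalize, LogStretch, PercentileInterval

img = np.random.lognormal(size=(20, 20))
vmin, vmax = PercentileInterval(95.).get_limits(img)  # bounds of the central 95% of pixel values
norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=LogStretch())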
Example #5
                backgroundcolor="white",
                transform=ax1[i].transAxes)

    ax1[i].plot(null, np.zeros_like(null), "r:", lw=2)
    ax1[i].errorbar(distance,
                    fluxring[i],
                    yerr=erb[i],
                    fmt="ko-",
                    lw=3,
                    elinewidth=1,
                    capsize=2,
                    label=wav)

    # Plots SZ signal
    # normalises data with 'zscale' and stretches it by 'PowerStretch' as in ds9
    norm = ImageNormalize(SZ[i], interval=zscale(), stretch=PowerStretch(1.3))
    P1 = ax2[1, i].imshow(MCsim[i], norm=norm, aspect="auto")
    P2 = ax2[0, i].imshow(SZ[i], norm=norm, aspect="auto")
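    # note: P1 and P2 share the same norm, so the simulated (MCsim) and
    # observed (SZ) panels are displayed on an identical colour scale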
    cb = fig2.colorbar(P1, cax=cax[i], orientation="horizontal")
    cb.ax.xaxis.set_ticks_position("top")
    cb.set_label(r"$\mathrm{mJy/px}$", fontsize=10)
    cb.ax.xaxis.set_label_position("top")
    cb.ax.tick_params(labelsize=8)

    for j in range(len(ax2)):
        ax2[j, i].axis("off")

        t = ax2[j, i].text(0.69,
                           0.90,
                           wav,
                           color="w",
Example #6
def create_index(filenames, directory, display=False, imagestretch='linear'):
    """
    create index.html
    diagnostic root website for one pipeline process
    """

    if display:
        print('create frame index table and frame images')
    logging.info('create frame index table and frame images')

    # obtain grating from first image file
    telescope, obsparam = CheckInstrument([filenames[0]])
    refheader = fits.open(filenames[0], ignore_missing_end=True)[0].header
    grating = refheader[obsparam['grating']]

    del refheader

    html = "<H2>data directory: %s</H2>\n" % directory

    html += ("<H1>%s/%s-band - Diagnostic Output</H1>\n" + \
               "%d frames total, see full pipeline " + \
               "<A HREF=\"%s\">log</A> for more information\n") % \
               (obsparam['telescope_instrument'], grating,
                len(filenames),
                '.diagnostics/' +
                _SP_conf.log_filename.split('.diagnostics/')[1])

    ### create frame information table
    html += "<P><TABLE BORDER=\"1\">\n<TR>\n"
    html += "<TH>Idx</TH><TH>Filename</TH><TH>Date</TH>" + \
            "<TH>Objectname</TH><TH>Grating</TH>" + \
            "<TH>Airmass</TH><TH>Exptime (s)</TH>" + \
            "<TH>Type</TH>\n</TR>\n"

    # fill table and create frames
    for idx, filename in enumerate(filenames):

        ### fill table
        hdulist = fits.open(filename, ignore_missing_end=True)
        header = hdulist[0].header

        # read out image binning mode
        binning = tb.get_binning(header, obsparam)

        #framefilename = _pp_conf.diagroot + '/' + filename + '.png'
        framefilename = '.diagnostics/' + filename + '.png'

        try:
            objectname = header[obsparam['object']]
        except KeyError:
            objectname = 'Unknown Target'

        html += ("<TR><TD>%d</TD><TD><A HREF=\"%s\">%s</A></TD>" + \
                 "<TD>%s</TD><TD>%s</TD>" + \
                 "<TD>%s</TD><TD>%4.2f</TD><TD>%.1f</TD>" + \
                 "<TD>%s</TD>\n</TR>\n") % \
            (idx+1, framefilename, filename, header[obsparam['date_keyword']],
             objectname,
             header[obsparam['grating']],
             float(header[obsparam['airmass']]),
             float(header[obsparam['exptime']]),
             header[obsparam['obstype']])

        ### create frame image
        imgdat = hdulist[0].data
        # clip extreme values to prevent crash of imresize
        imgdat = np.clip(imgdat, np.percentile(imgdat, 1),
                         np.percentile(imgdat, 99))
        imgdat = imresize(imgdat,
                          min(1., 1000. / np.max(imgdat.shape)),
                          interp='nearest')
        #resize image larger than 1000px on one side

        norm = ImageNormalize(imgdat,
                              interval=ZScaleInterval(),
                              stretch={
                                  'linear': LinearStretch(),
                                  'log': LogStretch()
                              }[imagestretch])
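        # an unrecognized imagestretch value raises a KeyError here; only
        # 'linear' and 'log' are handled by this dispatch dict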

        plt.figure(figsize=(5, 5))

        img = plt.imshow(imgdat, cmap='gray', norm=norm, origin='lower')
        # remove axes
        plt.axis('off')
        img.axes.get_xaxis().set_visible(False)
        img.axes.get_yaxis().set_visible(False)

        plt.savefig(framefilename,
                    format='png',
                    bbox_inches='tight',
                    pad_inches=0,
                    dpi=200)

        plt.close()
        hdulist.close()
        del imgdat

    html += '</TABLE>\n'

    create_website(_SP_conf.index_filename, html)

    ### add to summary website, if requested
    #    if _pp_conf.use_diagnostics_summary:
    #        add_to_summary(header[obsparam['object']], filtername,
    #                       len(filenames))

    return None
Example #7
def add_Extract(filenames_im,
                filenames_spec,
                html_file,
                imagestretch='linear'):
    """
    add pre-processing log to website
    """

    # update index.html
    html = '<H2>Extract list</H2>\n'
    html += '%d images have been processed' % \
            (len(filenames_im))

    html += "<P><TABLE BORDER=\"1\">\n<TR>\n"
    html += "<TH>Idx</TH><TH>Filename</TH> <TH>Images</TH></TR>\n"

    # make sure the output directory exists before saving any plots into it
    if not os.path.isdir('diagnostics'):
        os.mkdir('diagnostics')

    for idx, filename in enumerate(filenames_im):
        ### fill table
        spec = np.loadtxt(filenames_spec[idx])
        plt.figure()
        plt.plot(spec)
        plt.xlabel('pixels')
        plt.ylabel('Relative reflectance')
        specfilename = 'diagnostics/' + filenames_spec[idx] + '.png'
        plt.savefig(specfilename)
        plt.close()
        hdulist = fits.open(filename, ignore_missing_end=True)

        #framefilename = _SP_conf.diagroot + '/' + filename + '.png'
        framefilename = 'diagnostics/' + filename + '.png'

        html += ("<TR><TD>%d</TD><TD><A HREF=\"%s\">%s</A></TD> <TH> <IMG SRC=\"%s\">  </TH> <TH> <IMG SRC=\"%s\">  </TH></TR>\n") % \
            (idx+1, framefilename, filename,framefilename,specfilename)
        ### create frame image
        imgdat = hdulist[0].data
        # clip extreme values to prevent crash of imresize
        imgdat = np.clip(imgdat, np.percentile(imgdat, 1),
                         np.percentile(imgdat, 99))
        imgdat = imresize(imgdat,
                          min(1., 1000. / np.max(imgdat.shape)),
                          interp='nearest')
        #resize image larger than 1000px on one side

        norm = ImageNormalize(imgdat,
                              interval=ZScaleInterval(),
                              stretch={
                                  'linear': LinearStretch(),
                                  'log': LogStretch()
                              }[imagestretch])

        plt.figure(figsize=(5, 5))

        img = plt.imshow(imgdat, cmap='gray', norm=norm, origin='lower')
        # remove axes
        plt.axis('off')
        img.axes.get_xaxis().set_visible(False)
        img.axes.get_yaxis().set_visible(False)

        plt.savefig(framefilename,
                    format='png',
                    bbox_inches='tight',
                    pad_inches=0,
                    dpi=200)

        plt.close()
        hdulist.close()
        del imgdat

    html += '</TABLE>\n'

    append_website(html_file, html, replace_below="<H2>Extract list</H2>\n")
Example #8
    def kde2D_plot(self,
                   parameter1,
                   parameter2,
                   normtype='log',
                   interval=None,
                   xlim=None,
                   ylim=None,
                   gridsize=100):
        """Generate a 2D KDE for the given parameters.

        Parameters
        ----------
        parameter1 : `numpy.array`
            X-axis variable

        parameter2 : `numpy.array`
            Y-axis variable

        normtype : {'log', 'linear', 'sqrt'}
            Normalization type to apply to the data

        interval : tuple
            Limits of the interval to use when computing the image scaling

        xlim : tuple
            X-limits to use for the plot and the KDE grid

        ylim : tuple
            Y-limits to use for the plot and the KDE grid

        gridsize : int
            Number of grid points along each axis of the KDE grid

        Returns
        -------
        fig : :py:class:`matplotlib.figure.Figure`

        ax : :py:class:`matplotlib.axes.Axes`

        surface : numpy.array
            The KDE surface plot
        """

        data = np.vstack([parameter1, parameter2])

        if xlim is None:
            xlim = (np.min(parameter1), np.max(parameter1))

        if ylim is None:
            ylim = (np.min(parameter2), np.max(parameter2))

        # Generate a grid to compute the KDE over
        xgrid = np.linspace(xlim[0], xlim[1], gridsize)
        ygrid = np.linspace(ylim[0], ylim[1], gridsize)

        kde = gaussian_kde(data)
        Xgrid, Ygrid = np.meshgrid(xgrid, ygrid)
        surface = kde.evaluate(np.vstack([Xgrid.ravel(), Ygrid.ravel()]))

        if isinstance(interval, tuple):
            Interval = ManualInterval(vmin=interval[0], vmax=interval[1])
        else:
            Interval = ZScaleInterval()

        norm = ImageNormalize(surface,
                              stretch=self.image_norms[normtype],
                              interval=Interval)
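        # self.image_norms is presumably a mapping from 'log'/'linear'/'sqrt'
        # to the corresponding astropy stretch instances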

        fig, ax = self.mk_fig(nrows=1, ncols=1)
        ax.imshow(surface.reshape(Xgrid.shape),
                  norm=norm,
                  cmap='gray',
                  origin='lower',
                  aspect='auto',
                  extent=[xgrid.min(),
                          xgrid.max(),
                          ygrid.min(),
                          ygrid.max()])

        return fig, ax, surface
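
A hypothetical call of kde2D_plot; `viz` stands in for an instance of the class that defines this method (not shown in the excerpt), and the samples are synthetic correlated draws:

import numpy as np

x = np.random.normal(size=1000)
y = 0.5 * x + np.random.normal(scale=0.5, size=1000)
# viz: hypothetical instance of the plotting class defined above
fig, ax, surface = viz.kde2D_plot(x, y, normtype='log', gridsize=200)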
Example #9
    def plot_hst_loc(self,
                     i=5,
                     df=None,
                     title='',
                     thresh=5,
                     fout='',
                     min_exptime=800,
                     key='start',
                     save=False,
                     orbital_path1=None,
                     orbital_path2=None):

        self.fig = plt.figure(figsize=(8, 6))
        # Get the model for the SAA
        self.map = Basemap(projection='cyl')
        self._draw_map()
        df = df[df.integration_time.gt(min_exptime)]
        df.sort_values(by='incident_cr_rate', inplace=True)
        # Generate an SAA contour
        saa = [list(t) for t in zip(*costools.saamodel.saaModel(i))]
        # Ensure the polygon representing the SAA is a closed curve by adding
        # the starting points to the end of the list of lat/lon coords
        saa[0].append(saa[0][0])
        saa[1].append(saa[1][0])
        self.map.plot(saa[1],
                      saa[0],
                      c='k',
                      latlon=True,
                      label='SAA contour {}'.format(i))
        # df = self.perform_SAA_cut(df=df, key=key)
        if df is None:
            lat, lon, rate = self.data_df['latitude_{}'.format(key)], \
                             self.data_df['longitude_{}'.format(key)], \
                             self.data_df['incident_cr_rate']
        else:
            #df = df[df['integration_time'] > 800]
            lat, lon, rate = df['latitude_{}'.format(key)], \
                             df['longitude_{}'.format(key)], \
                             df['incident_cr_rate']
            LOG.info('{} {} {}'.format(len(lat), len(lon), len(rate)))

            # lat1, lon1, rate1 = lat[rate >0], lon[rate >0], rate[rate>0]
            # LOG.info('{} {} {}'.format(len(lat), len(lon), len(rate)))
        # median = np.median(rate)
        # std = np.std(rate)
        mean, median, std = sigma_clipped_stats(rate,
                                                sigma_lower=3,
                                                sigma_upper=3)
        LOG.info('{} +/- {}'.format(median, std))
        norm = ImageNormalize(rate,
                              stretch=LinearStretch(),
                              vmin=mean - thresh * std,
                              vmax=mean + thresh * std)
        cbar_below_mean = [mean - (i + 1) * std for i in range(thresh)]
        cbar_above_mean = [mean + (i + 1) * std for i in range(thresh)]

        cbar_bounds = cbar_below_mean + [mean] + cbar_above_mean
        print(cbar_bounds)
        cbar_bounds.sort()
        sci_cmap = plt.cm.viridis
        custom_norm = colors.BoundaryNorm(boundaries=cbar_bounds,
                                          ncolors=sci_cmap.N)

        scat = self.map.scatter(lon.values,
                                lat.values,
                                marker='o',
                                s=5,
                                latlon=True,
                                c=rate,
                                alpha=0.15,
                                norm=custom_norm,
                                cmap='viridis')
        #im = self.map.contourf(lon_grid, lat_grid, rate, norm=norm, cmap='viridis')
        ax = plt.gca()
        ax.set_title(title)

        # Plot the path of HST
        #self.map.plot(
        #    orbital_path1.metadata['longitude'],
        #    orbital_path1.metadata['latitude'],lw=1.25,
        #    label=f'Int. Time: {1000:.1f}s', color='k', ls='-'
        #)
        if orbital_path2 is not None:
            self.map.scatter(orbital_path2.metadata['longitude'][::4][1:],
                             orbital_path2.metadata['latitude'][::4][1:],
                             c='k',
                             s=20,
                             label='285 second interval')
        if orbital_path1 is not None:
            self.map.plot(orbital_path1.metadata['longitude'],
                          orbital_path1.metadata['latitude'],
                          label=f'Orbital Path Over {2000:.0f} seconds',
                          color='k',
                          ls='--',
                          lw=1.25)

        ax1_legend = ax.legend(loc='upper right',
                               ncol=1,
                               labelspacing=0.2,
                               columnspacing=0.5,
                               edgecolor='k')
        # for i in range(len(ax1_legend.legendHandles)):
        #     ax1_legend.legendHandles[i]._sizes = [30]
        #cbar_tick_labels = [f'<x>-{thresh}$\sigma$', '<x>', f'<x>+{thresh}$\sigma$']
        #cbar_ticks = [mean - thresh*std,mean, mean + thresh*std]
        cbar_ticks = cbar_bounds
        cax = self.fig.add_axes([0.1, 0.1, 0.8, 0.05])
        cbar = self.fig.colorbar(scat,
                                 cax=cax,
                                 ticks=cbar_ticks,
                                 orientation='horizontal')
        cbar.set_alpha(1)
        cbar.draw_all()
        cbar_tick_labels = [rf'<x>-{i}$\sigma$' for i in [5, 4, 3, 2, 1]] + [
            '<x>'
        ] + [rf'<x>+{i}$\sigma$' for i in [1, 2, 3, 4, 5]]
        cbar.ax.set_xticklabels(cbar_tick_labels,
                                horizontalalignment='right',
                                rotation=30)

        cbar.set_label('CR Flux [CR/s/$cm^2$]', fontsize=10)
        # cbar.ax.set_xticklabels(cbar.ax.get_xticklabels(),
        #                         fontweight='medium',fontsize=8)
        if save:
            if not fout:
                fout = 'lat_lon_{}.png'.format(key)

            self.fig.savefig(fout,
                             format='png',
                             bbox_inches='tight',
                             dpi=350,
                             transparent=False)
        plt.show()
        return self.fig
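
The colorbar boundaries above are spaced in units of the sigma-clipped standard deviation around the mean; a minimal standalone sketch of that pattern on synthetic rates (thresh assumed to be 5, as in the default):

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from astropy.stats import sigma_clipped_stats

rate = np.random.normal(loc=100, scale=15, size=5000)
mean, median, std = sigma_clipped_stats(rate, sigma_lower=3, sigma_upper=3)
thresh = 5
bounds = sorted(mean + k * std for k in range(-thresh, thresh + 1))
norm = colors.BoundaryNorm(boundaries=bounds, ncolors=plt.cm.viridis.N)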
Example #10
def get_finding_chart(
    source_ra,
    source_dec,
    source_name,
    image_source='ps1',
    output_format='pdf',
    imsize=3.0,
    tick_offset=0.02,
    tick_length=0.03,
    fallback_image_source='dss',
    zscale_contrast=0.045,
    zscale_krej=2.5,
    extra_display_string="",
    **offset_star_kwargs,
):
    """Create a finder chart suitable for spectroscopic observations of
       the source

    Parameters
    ----------
    source_ra : float
        Right ascension (J2000) of the source
    source_dec : float
        Declination (J2000) of the source
    source_name : str
        Name of the source
    image_source : {'desi', 'dss', 'ztfref', 'ps1'}, optional
        Survey where the image comes from "desi", "dss", "ztfref", "ps1"
        defaults to "ps1"
    output_format : str, optional
        "pdf" or "png" -- determines the format of the returned finder
    imsize : float, optional
        Requested image size (on a side) in arcmin. Should be between 2-15.
    tick_offset : float, optional
        How far from each source should the tick mark start? (in arcsec)
    tick_length : float, optional
        How long should the tick mark be? (in arcsec)
    fallback_image_source : str, optional
        Which `image_source` to fall back to if the one requested fails
    zscale_contrast : float, optional
        Contrast parameter for the ZScale interval
    zscale_krej : float, optional
        krej parameter for the ZScale interval
    extra_display_string :  str, optional
        What else to show for the source itself in the chart (e.g. proper motion)
    **offset_star_kwargs : dict, optional
        Other parameters passed to `get_nearby_offset_stars`

    Returns
    -------
    dict
        success : bool
            Whether the request was successful or not, returning
            a sensible error in 'reason'
        name : str
            suggested filename based on `source_name` and `output_format`
        data : str
            binary encoded data for the image (to be streamed)
        reason : str
            If not successful, a reason is returned.
    """
    if (imsize < 2.0) or (imsize > 15):
        return {
            'success': False,
            'reason': 'Requested `imsize` out of range',
            'data': '',
            'name': '',
        }

    if image_source not in source_image_parameters:
        return {
            'success': False,
            'reason': f'image source {image_source} not in list',
            'data': '',
            'name': '',
        }

    matplotlib.use("Agg")
    fig = plt.figure(figsize=(11, 8.5), constrained_layout=False)
    widths = [2.6, 1]
    heights = [2.6, 1]
    spec = fig.add_gridspec(
        ncols=2,
        nrows=2,
        width_ratios=widths,
        height_ratios=heights,
        left=0.05,
        right=0.95,
    )

    # how wide on the side will the image be? 256 as default
    npixels = source_image_parameters[image_source].get("npixels", 256)
    # set the pixelscale in arcsec (typically about 1 arcsec/pixel)
    pixscale = 60 * imsize / npixels

    hdu = fits_image(source_ra,
                     source_dec,
                     imsize=imsize,
                     image_source=image_source)

    # skeleton WCS - this is the field that the user requested
    wcs = WCS(naxis=2)

    # set the headers of the WCS.
    # The center of the image is the reference point (source_ra, source_dec):
    wcs.wcs.crpix = [npixels / 2, npixels / 2]
    wcs.wcs.crval = [source_ra, source_dec]

    # create the pixel scale and orientation North up, East left
    # pixelscale is in degrees, established in the tangent plane
    # to the reference point
    wcs.wcs.cd = np.array([[-pixscale / 3600, 0], [0, pixscale / 3600]])
    wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]
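    # (this is a gnomonic/TAN projection centered on the source; the negative
    #  CD1_1 element makes RA increase to the left, i.e. North up, East left)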

    fallback = True
    if hdu is not None:
        im = hdu.data

        # replace the nans with medians
        im[np.isnan(im)] = np.nanmedian(im)

        # Fix the header keyword for the input system, if needed
        hdr = hdu.header
        if 'RADECSYS' in hdr:
            hdr.set('RADESYSa', hdr['RADECSYS'], before='RADECSYS')
            del hdr['RADECSYS']

        if source_image_parameters[image_source].get("reproject", False):
            # project image to the skeleton WCS solution
            log("Reprojecting image to requested position and orientation")
            im, _ = reproject_adaptive(hdu, wcs, shape_out=(npixels, npixels))
        else:
            wcs = WCS(hdu.header)

        if source_image_parameters[image_source].get("smooth", False):
            im = gaussian_filter(
                hdu.data,
                source_image_parameters[image_source]["smooth"] / pixscale)

        cent = int(npixels / 2)
        width = int(0.05 * npixels)
        test_slice = slice(cent - width, cent + width)
        all_nans = np.isnan(im[test_slice, test_slice].flatten()).all()
        all_zeros = (im[test_slice, test_slice].flatten() == 0).all()
        if not (all_zeros or all_nans):
            percents = np.nanpercentile(im.flatten(), [10, 99.0])
            vmin = percents[0]
            vmax = percents[1]
            interval = ZScaleInterval(
                nsamples=int(0.1 * (im.shape[0] * im.shape[1])),
                contrast=zscale_contrast,
                krej=zscale_krej,
            )
            norm = ImageNormalize(im, vmin=vmin, vmax=vmax, interval=interval)
            watermark = source_image_parameters[image_source]["str"]
            fallback = False

    if hdu is None or fallback:
        # if we got back a blank image, try to fallback on another survey
        # and return the results from that call
        if fallback_image_source is not None:
            if fallback_image_source != image_source:
                log(f"Falling back on image source {fallback_image_source}")
                return get_finding_chart(
                    source_ra,
                    source_dec,
                    source_name,
                    image_source=fallback_image_source,
                    output_format=output_format,
                    imsize=imsize,
                    tick_offset=tick_offset,
                    tick_length=tick_length,
                    fallback_image_source=None,
                    **offset_star_kwargs,
                )

        # we don't have an image here, so let's create a dummy one
        # so we can still plot
        im = np.zeros((npixels, npixels))
        watermark = None
        vmin = 0
        vmax = 0
        norm = ImageNormalize(im, vmin=vmin, vmax=vmax)

    # add the images in the top left corner
    ax = fig.add_subplot(spec[0, 0], projection=wcs)
    ax_text = fig.add_subplot(spec[0, 1])
    ax_text.axis('off')
    ax_starlist = fig.add_subplot(spec[1, 0:])
    ax_starlist.axis('off')

    ax.imshow(im, origin='lower', norm=norm, cmap='gray_r')
    ax.set_autoscale_on(False)
    ax.grid(color='white', ls='dotted')
    ax.set_xlabel(r'$\alpha$ (J2000)', fontsize='large')
    ax.set_ylabel(r'$\delta$ (J2000)', fontsize='large')
    obstime = offset_star_kwargs.get("obstime",
                                     datetime.datetime.utcnow().isoformat())
    ax.set_title(
        f'{source_name} Finder (for {obstime.split("T")[0]})',
        fontsize='large',
        fontweight='bold',
    )

    star_list, _, _, _, used_ztfref = get_nearby_offset_stars(
        source_ra, source_dec, source_name, **offset_star_kwargs)

    if not isinstance(star_list, list) or len(star_list) == 0:
        return {
            'success': False,
            'reason': 'failure to get star list',
            'data': '',
            'name': '',
        }

    ncolors = len(star_list)
    if star_list[0]['str'].startswith("!Data"):
        ncolors -= 1
    colors = sns.color_palette("colorblind", ncolors)

    start_text = [-0.45, 0.99]
    origin = "GaiaEDR3" if not used_ztfref else "ZTFref"
    starlist_str = (
        f"# Note: {origin} used for offset star positions\n"
        "# Note: spacing in starlist many not copy/paste correctly in PDF\n" +
        "#       you can get starlist directly from" +
        f" /api/sources/{source_name}/offsets?" +
        f"facility={offset_star_kwargs.get('facility', 'Keck')}\n" +
        "\n".join([x["str"] for x in star_list]))

    # add the starlist
    ax_starlist.text(
        0,
        0.50,
        starlist_str,
        fontsize="x-small",
        family='monospace',
        transform=ax_starlist.transAxes,
    )

    # add the watermark for the survey
    props = dict(boxstyle='round', facecolor='gray', alpha=0.7)

    if watermark is not None:
        ax.text(
            0.035,
            0.035,
            watermark,
            horizontalalignment='left',
            verticalalignment='center',
            transform=ax.transAxes,
            fontsize='medium',
            fontweight='bold',
            color="yellow",
            alpha=0.5,
            bbox=props,
        )

    date_obs = hdr.get('DATE-OBS')
    if not date_obs:
        mjd_obs = hdr.get('MJD-OBS')
        if mjd_obs:
            date_obs = Time(f"{mjd_obs}",
                            format='mjd').to_value('fits', subfmt='date_hms')

    if date_obs:
        ax.text(
            0.95,
            0.95,
            f'image date {date_obs.split("T")[0]}',
            horizontalalignment='right',
            verticalalignment='center',
            transform=ax.transAxes,
            fontsize='small',
            color="yellow",
            alpha=0.5,
            bbox=props,
        )

    ax.text(
        0.95,
        0.035,
        f"{imsize}\u2032 \u00D7 {imsize}\u2032",  # size'x size'
        horizontalalignment='right',
        verticalalignment='center',
        transform=ax.transAxes,
        fontsize='medium',
        fontweight='bold',
        color="yellow",
        alpha=0.5,
        bbox=props,
    )

    # compass rose
    # rose_center_pixel = ax.transAxes.transform((0.04, 0.95))
    rose_center = pixel_to_skycoord(int(npixels * 0.1), int(npixels * 0.9),
                                    wcs)
    props = dict(boxstyle='round', facecolor='gray', alpha=0.5)

    for ang, label, off in [(0, "N", 0.01), (90, "E", 0.03)]:
        position_angle = ang * u.deg
        separation = (0.05 * imsize * 60) * u.arcsec  # 5%
        p2 = rose_center.directional_offset_by(position_angle, separation)
        ax.plot(
            [rose_center.ra.value, p2.ra.value],
            [rose_center.dec.value, p2.dec.value],
            transform=ax.get_transform('world'),
            color="gold",
            linewidth=2,
        )

        # label N and E
        position_angle = (ang + 15) * u.deg
        separation = ((0.05 + off) * imsize * 60) * u.arcsec
        p2 = rose_center.directional_offset_by(position_angle, separation)
        ax.text(
            p2.ra.value,
            p2.dec.value,
            label,
            color="gold",
            transform=ax.get_transform('world'),
            fontsize='large',
            fontweight='bold',
        )

    # account for Shane header
    if star_list[0]['str'].startswith("!Data"):
        star_list = star_list[1:]

    for i, star in enumerate(star_list):

        c1 = SkyCoord(star["ra"] * u.deg, star["dec"] * u.deg, frame='icrs')

        # mark up the right side of the page with position and offset info
        name_title = star["name"]
        if star.get("mag") is not None:
            name_title += f" {star.get('mag'):.2f} mag"
        ax_text.text(
            start_text[0],
            start_text[1] - (i * 1.1) / ncolors,
            name_title,
            ha='left',
            va='top',
            fontsize='large',
            fontweight='bold',
            transform=ax_text.transAxes,
            color=colors[i],
        )
        source_text = f"  {star['ra']:.5f} {star['dec']:.5f}\n"
        source_text += f"  {c1.to_string('hmsdms', precision=2)}\n"
        if i == 0 and extra_display_string != "":
            source_text += f"  {extra_display_string}\n"
        if ((star.get("dras") is not None) and (star.get("ddecs") is not None)
                and (star.get("pa") is not None)):
            source_text += f'  {star.get("dras")} {star.get("ddecs")} (PA={star.get("pa"):<0.02f}°)'
        ax_text.text(
            start_text[0],
            start_text[1] - (i * 1.1) / ncolors - 0.06,
            source_text,
            ha='left',
            va='top',
            fontsize='large',
            transform=ax_text.transAxes,
            color=colors[i],
        )

        # work on making marks where the stars are
        for ang in [0, 90]:
            # for the source itself (i=0), change the angle of the lines in
            # case the offset star is the same as the source itself
            position_angle = ang * u.deg if i != 0 else (ang + 225) * u.deg
            separation = (tick_offset * imsize * 60) * u.arcsec
            p1 = c1.directional_offset_by(position_angle, separation)
            separation = (tick_offset + tick_length) * imsize * 60 * u.arcsec
            p2 = c1.directional_offset_by(position_angle, separation)
            ax.plot(
                [p1.ra.value, p2.ra.value],
                [p1.dec.value, p2.dec.value],
                transform=ax.get_transform('world'),
                color=colors[i],
                linewidth=3 if imsize <= 4 else 2,
                alpha=0.8,
            )
        if star["name"].find("_o") != -1:
            # this is an offset star
            text = star["name"].split("_o")[-1]
            position_angle = 14 * u.deg
            separation = (tick_offset +
                          tick_length * 1.6) * imsize * 60 * u.arcsec
            p1 = c1.directional_offset_by(position_angle, separation)
            ax.text(
                p1.ra.value,
                p1.dec.value,
                text,
                color=colors[i],
                transform=ax.get_transform('world'),
                fontsize='large',
                fontweight='bold',
            )

    buf = io.BytesIO()
    fig.savefig(buf, format=output_format)
    plt.close(fig)
    buf.seek(0)

    return {
        "success": True,
        "name": f"finder_{source_name}.{output_format}",
        "data": buf.read(),
        "reason": "",
    }
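
A hypothetical invocation of get_finding_chart, writing the returned bytes to disk; the coordinates and source name are illustrative only:

result = get_finding_chart(210.2715, 54.3163, 'example-source',
                           image_source='ps1', output_format='pdf')
if result['success']:
    with open(result['name'], 'wb') as f:
        f.write(result['data'])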
Example #11
def pathtest(step_input_filename,
             reffile,
             comparison_filename,
             writefile=True,
             show_figs=False,
             save_figs=True,
             threshold_diff=1.0e-7,
             debug=False):
    """
    This function calculates the difference between the pipeline and
    calculated pathloss values.
    Args:
        step_input_filename: str, full path name of sourcetype output fits file
        reffile: str, path to the pathloss IFU reference fits file
        comparison_filename: str, path to comparison pipeline pathloss file
        writefile: boolean, if True writes calculated flat and
                   difference image fits files
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, save the plots
        threshold_diff: float, threshold diff between pipeline and ESA file
        debug: boolean, if true print statements will show on-screen
    Returns:
        - 1 plot, if told to save and/or show them.
        - median_diff: Boolean, True if smaller or equal to threshold.
        - log_msgs: list, all print statements are captured in this variable
    """

    log_msgs = []

    # start the timer
    pathtest_start_time = time.time()

    # get info from the rate file header
    msg = 'step_input_filename=' + step_input_filename
    print(msg)
    log_msgs.append(msg)
    exptype = fits.getval(step_input_filename, "EXP_TYPE", 0)
    grat = fits.getval(step_input_filename, "GRATING", 0)
    filt = fits.getval(step_input_filename, "FILTER", 0)

    msg = "pathloss file: Grating:" + grat + " Filter:" + filt + " EXP_TYPE:" + exptype
    print(msg)
    log_msgs.append(msg)

    msg = "Using reference file: " + reffile
    print(msg)
    log_msgs.append(msg)

    is_point_source = True

    if writefile:
        # create the fits list to hold the calculated pathloss values for each slit
        hdu0 = fits.PrimaryHDU()
        outfile = fits.HDUList()
        outfile.append(hdu0)

        # create fits list to hold pipeline-calculated difference values
        hdu0 = fits.PrimaryHDU()
        compfile = fits.HDUList()
        compfile.append(hdu0)

    # list to determine if pytest is passed or not
    total_test_result = []

    # get all the science extensions
    if is_point_source:
        ext = 1  # for all PS IFU

    # get files
    print('Checking files exist & obtaining datamodels. Takes a few mins...')
    if os.path.isfile(comparison_filename):
        if debug:
            print('Comparison file does exist.')
    else:
        result_msg = 'Comparison file does NOT exist. Skipping pathloss test.'
        print(result_msg)
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    # get the comparison data model
    ifu_pipe_model = datamodels.open(comparison_filename)
    if debug:
        print('got comparison datamodel!')

    if os.path.isfile(step_input_filename):
        if debug:
            print('Input file does exist.')
    else:
        result_msg = 'Input file does NOT exist. Skipping pathloss test.'
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    # get the input data model
    ifu_input_model = datamodels.open(step_input_filename)
    if debug:
        print('got input datamodel!')

    # get slices (instead of using .slit)
    pl_ifu_slits = nirspec.nrs_ifu_wcs(ifu_input_model)
    print("got input slices")

    plcor_ref_ext = fits.getdata(reffile, ext)
    hdul = fits.open(reffile)
    plcor_ref = hdul[1].data
    w = wcs.WCS(hdul[1].header)

    w1, y1, x1 = np.mgrid[:plcor_ref.shape[0],
                          :plcor_ref.shape[1],
                          :plcor_ref.shape[2]]
    slitx_ref, slity_ref, wave_ref = w.all_pix2world(x1, y1, w1, 0)

    # these are full 2048 * 2048 files:
    previous_sci = fits.getdata(step_input_filename, "SCI")
    comp_sci = fits.getdata(comparison_filename, "SCI")
    pathloss_divided = comp_sci / previous_sci

    # Can manually test correction at nonzero point
    # slit_x = -0.3
    # slit_y = 0.3
    # if debug:
    #     print("""WARNING: Using manually set slit_x and slit_y!
    #           The pipeline correction will not use manually set values
    #           and thus the residuals will change""")

    # set up generals for all plots
    font = {'weight': 'normal', 'size': 10}
    matplotlib.rc('font', **font)

    # loop through the slices
    msg = " Looping through the slices... "
    print(msg)
    log_msgs.append(msg)

    slit_list = list(range(30))
    for slit, slit_num in zip(pl_ifu_slits, slit_list):
        print("working with slice {}".format(slit_num))

        x, y = wcstools.grid_from_bounding_box(slit.bounding_box, step=(1, 1))
        ra, dec, wave = slit(x, y)

        slit_x = 0  # This assumption is made for IFU sources only
        slit_y = 0

        if debug:
            print("slit_x, slit_y", slit_x, slit_y)

        correction_array = np.array([])
        lambda_array = np.array([])

        wave_sci = wave * 10**(-6)  # microns --> meters
        wave_sci_flat = wave_sci.reshape(wave_sci.size)
        wave_ref_flat = wave_ref.reshape(wave_ref.size)

        ref_xy = np.column_stack((slitx_ref.reshape(slitx_ref.size),
                                  slity_ref.reshape(slitx_ref.size)))

        # loop through slices in lambda from reference file
        shape = 0
        for lambda_val in wave_ref_flat:
            # loop through every lambda value
            # flattened so that looping works smoothly
            shape = shape + 1
            index = np.where(wave_ref[:, 0, 0] == lambda_val)
            # index of closest lambda value in reffile to given sci lambda
            #   took index of only the first slice of wave_ref because
            #   the others were repetitive & we got extra indices
            # take slice where lambda=index:
            plcor_slice = plcor_ref_ext[index[0][0]].reshape(
                plcor_ref_ext[index[0][0]].size)
            # do 2d interpolation to get a single correction factor for each slice
            corr_val = scipy.interpolate.griddata(ref_xy[:plcor_slice.size],
                                                  plcor_slice,
                                                  np.asarray([slit_x, slit_y]),
                                                  method='linear')
            # append values from loop to create a vector of correction factors
            correction_array = np.append(correction_array, corr_val[0])
            # map to array with corresponding lambda
            lambda_array = np.append(lambda_array, lambda_val)

        # get correction value for each pixel
        corr_vals = np.interp(wave_sci_flat, lambda_array, correction_array)
        corr_vals = corr_vals.reshape(wave_sci.shape)

        box = slit.bounding_box
        small_y = box[0][0]
        big_y = box[0][1]
        small_x = box[1][0]
        big_x = box[1][1]

        left = int(math.trunc(small_x))
        right = int(math.ceil(big_x))
        bottom = int(math.trunc(small_y))
        top = int(math.ceil(big_y))

        full_cut2slice = previous_sci[left:right, bottom:top]
        print("shapes:", full_cut2slice.shape, corr_vals.shape)

        if full_cut2slice.shape != corr_vals.shape:
            value = 0
            while (((full_cut2slice.shape[0] - corr_vals.shape[0]) != 0) or
                   ((full_cut2slice.shape[1] - corr_vals.shape[1]) != 0)) and \
                    value < 7:  # can delete second criteria once all pass
                if value == 6:
                    print("WARNING: may be in infinite loop!")
                x_amount_off = full_cut2slice.shape[0] - corr_vals.shape[0]
                if x_amount_off >= 1:
                    if x_amount_off % 2 == 0:  # need to add two values, so do one on each side
                        right = right - 1
                        left = left + 1
                        print("ALTERED SHAPE OF SLICE: V1")
                        value = value + 1
                        print("left {}, right {}, top {}, bottom {}".format(
                            left, right, top, bottom))
                    else:  # just add one value
                        left = left + 1
                        print("ALTERED SHAPE OF SLICE: V2")
                        value = value + 1
                        print("left {}, right {}, top {}, bottom {}".format(
                            left, right, top, bottom))
                elif x_amount_off <= -1:
                    if x_amount_off % 2 == 0:
                        right = right + 1
                        left = left - 1
                        print("ALTERED SHAPE OF SLICE: V3")
                        value = value + 1
                        print("left {}, right {}, top {}, bottom {}".format(
                            left, right, top, bottom))
                    else:
                        left = left - 1
                        print("ALTERED SHAPE OF SLICE: V4")
                        value = value + 1
                        print("left {}, right {}, top {}, bottom {}".format(
                            left, right, top, bottom))
                y_amount_off = full_cut2slice.shape[1] - corr_vals.shape[1]
                if y_amount_off >= 1:
                    if y_amount_off % 2 == 0:
                        bottom = bottom - 1
                        top = top + 1
                        print("ALTERED SHAPE OF SLICE: V5")
                        value = value + 1
                        print("left {}, right {}, top {}, bottom {}".format(
                            left, right, top, bottom))
                    else:
                        bottom = bottom + 1
                        print("ALTERED SHAPE OF SLICE: V6")
                        value = value + 1
                        print("left {}, right {}, top {}, bottom {}".format(
                            left, right, top, bottom))
                elif y_amount_off <= -1:
                    if y_amount_off % 2 == 0:
                        top = top + 1
                        bottom = bottom - 1
                        print("ALTERED SHAPE OF SLICE: V7")
                        value = value + 1
                        print("left {}, right {}, top {}, bottom {}".format(
                            left, right, top, bottom))
                    else:
                        bottom = bottom - 1
                        print("ALTERED SHAPE OF SLICE: V8")
                        value = value + 1
                        print("left {}, right {}, top {}, bottom {}".format(
                            left, right, top, bottom))
                full_cut2slice = previous_sci[left:right, bottom:top]
                print("final left {}, right {}, top {}, bottom {}".format(
                    left, right, top, bottom))
                print("NEW SHAPE OF SLICE: {} and corr_vals.shape: {}".format(
                    full_cut2slice.shape, corr_vals.shape))

        if full_cut2slice.shape != corr_vals.shape:
            print("shapes did not match! full_cut2slice: {}, corr_vals {}".
                  format(full_cut2slice.shape, corr_vals.shape))
            continue

        corrected_array = full_cut2slice / corr_vals

        pipe_correction = pathloss_divided[left:right, bottom:top]
        if pipe_correction.shape != corr_vals.shape:
            print("shapes did not match! pipe_correction: {}, corr_vals {}".
                  format(pipe_correction.shape, corr_vals.shape))
            continue

        prev_sci_slit = previous_sci[left:right, bottom:top]
        if prev_sci_slit.shape != corr_vals.shape:
            print(
                "shapes did not match! prev_sci_slit: {}, corr_vals {}".format(
                    prev_sci_slit.shape, corr_vals.shape))
            continue

        comp_sci_slit = comp_sci[left:right, bottom:top]
        if comp_sci_slit.shape != corr_vals.shape:
            print(
                "shapes did not match! comp_sci_slit: {}, corr_vals {}".format(
                    comp_sci_slit.shape, corr_vals.shape))
            continue

        # Plots:
        step_input_filepath = step_input_filename.replace(".fits", "")
        # my correction values
        fig = plt.figure(figsize=(15, 15))
        plt.subplot(221)
        norm = ImageNormalize(corr_vals)
        plt.imshow(corr_vals,
                   vmin=0.999995,
                   vmax=1.000005,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.title('Calculated Correction')
        plt.colorbar()
        # pipe correction
        plt.subplot(222)
        norm = ImageNormalize(pipe_correction)
        plt.imshow(pipe_correction,
                   vmin=0.999995,
                   vmax=1.000005,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.title('Pipe Correction')
        plt.colorbar()
        # residuals (pipe correction - my correction)
        if pipe_correction.shape == corr_vals.shape:
            corr_residuals = pipe_correction - corr_vals
            plt.subplot(223)
            norm = ImageNormalize(corr_residuals)
            plt.imshow(corr_residuals,
                       vmin=-0.000000005,
                       vmax=0.000000005,
                       aspect=10.0,
                       origin='lower',
                       cmap='viridis')
            plt.xlabel('dispersion in pixels')
            plt.ylabel('y in pixels')
            plt.title('Correction residuals')
            plt.colorbar()
        # my science data after pathloss
        plt.subplot(224)
        norm = ImageNormalize(corrected_array)
        plt.imshow(corrected_array,
                   vmin=0,
                   vmax=300,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.title('My slit science data after pathloss')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.colorbar()
        fig.suptitle("IFU PS Pathloss Correction Testing" +
                     str(corrected_array.shape))

        if show_figs:
            plt.show()
        if save_figs:
            plt_name = step_input_filepath + "_Pathloss_test_slitlet_IFU_PS_" + str(
                slit_num) + ".png"
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)

        elif not save_figs and not show_figs:
            msg = "Not making plots because both show_figs and save_figs were set to False."
            if debug:
                print(msg)
            log_msgs.append(msg)
        elif not save_figs:
            msg = "Not saving plots because save_figs was set to False."
            if debug:
                print(msg)
            log_msgs.append(msg)
        plt.clf()

        # create fits file to hold the calculated pathloss for each slit
        if writefile:
            msg = "Saving the fits files with the calculated pathloss for each slit..."
            print(msg)
            log_msgs.append(msg)

            # this is the file to hold the image of calculated pathloss values
            outfile_ext = fits.ImageHDU(corr_vals, name=str(slit_num))
            outfile.append(outfile_ext)

            # this is the file to hold the image of pipeline-calculated difference values
            compfile_ext = fits.ImageHDU(corr_residuals, name=str(slit_num))
            compfile.append(compfile_ext)

        if corr_residuals[~np.isnan(corr_residuals)].size == 0:
            msg1 = " * Unable to calculate statistics because difference array has all values as NaN. " \
                   "Test will be set to FAILED."
            print(msg1)
            log_msgs.append(msg1)
            test_result = "FAILED"
        else:
            msg = "Calculating statistics... "
            print(msg)
            log_msgs.append(msg)
            corr_residuals = corr_residuals[
                np.where((corr_residuals != 999.0) & (corr_residuals < 0.1)
                         & (corr_residuals > -0.1))]  # ignore outliers
            if corr_residuals.size == 0:
                msg1 = """ * Unable to calculate statistics because
                           difference array has all outlier values.
                           Test will be set to FAILED."""
                print(msg1)
                log_msgs.append(msg1)
                test_result = "FAILED"
            else:
                stats_and_strings = auxfunc.print_stats(corr_residuals,
                                                        "Difference",
                                                        float(threshold_diff),
                                                        absolute=True)
                stats, stats_print_strings = stats_and_strings
                corr_residuals_mean, corr_residuals_median, corr_residuals_std = stats
                for msg in stats_print_strings:
                    log_msgs.append(msg)

                # This is the key argument for the assert pytest function
                median_diff = False
                if abs(corr_residuals_median) <= float(threshold_diff):
                    median_diff = True
                if median_diff:
                    test_result = "PASSED"
                else:
                    test_result = "FAILED"

        msg = " *** Result of the test: " + test_result + "\n"
        print(msg)
        log_msgs.append(msg)
        total_test_result.append(test_result)

    if writefile:
        outfile_name = step_input_filename.replace("srctype",
                                                   "_calcuated_pathloss")
        compfile_name = step_input_filename.replace("srctype",
                                                    "_comparison_pathloss")

        # create the fits list to hold the calculated pathloss values for each slit
        outfile.writeto(outfile_name, overwrite=True)

        # this is the file to hold the image of pipeline-calculated difference values
        compfile.writeto(compfile_name, overwrite=True)

        msg = "\nFits file with calculated pathloss values of each slit saved as: "
        print(msg)
        log_msgs.append(msg)
        print(outfile_name)
        log_msgs.append(outfile_name)

        msg = "Fits file with comparison (pipeline pathloss - calculated pathloss) saved as: "
        print(msg)
        log_msgs.append(msg)
        print(compfile_name)
        log_msgs.append(compfile_name)

    # If all tests passed then pytest will be marked as PASSED, else FAILED
    FINAL_TEST_RESULT = bool(total_test_result) and "FAILED" not in total_test_result

    if FINAL_TEST_RESULT:
        msg = "\n *** Final pathloss test result reported as PASSED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "All slits PASSED path_loss test."
    else:
        msg = "\n *** Final pathloss test result reported as FAILED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "One or more slits FAILED path_loss test."

    # end the timer
    pathloss_end_time = time.time() - pathtest_start_time
    if pathloss_end_time > 60.0:
        pathloss_end_time = pathloss_end_time / 60.0  # in minutes
        pathloss_tot_time = "* Script IFU_PS.py took ", repr(
            pathloss_end_time) + " minutes to finish."
        if pathloss_end_time > 60.0:
            pathloss_end_time = pathloss_end_time / 60.  # in hours
            pathloss_tot_time = "* Script IFU_PS.py took ", repr(
                pathloss_end_time) + " hours to finish."
    else:
        pathloss_tot_time = "* Script IFU_PS.py took ", repr(
            pathloss_end_time) + " seconds to finish."
    print(pathloss_tot_time)
    log_msgs.append(pathloss_tot_time)

    return FINAL_TEST_RESULT, result_msg, log_msgs
Example #12
    hd = fits.open('data/' + disk_name[i] + '_resid.JvMcorr.fits')[0].header
    nx, ny = hd['NAXIS1'], hd['NAXIS2']
    RAo = 3600 * hd['CDELT1'] * (np.arange(nx) - (hd['CRPIX1'] - 1))
    DECo = 3600 * hd['CDELT2'] * (np.arange(ny) - (hd['CRPIX2'] - 1))
    dRA, dDEC = np.meshgrid(RAo - offs[0], DECo - offs[1])

    # beam parameters
    bmaj, bmin, bPA = 3600 * hd['BMAJ'], 3600 * hd['BMIN'], hd['BPA']
    barea = (np.pi * bmaj * bmin / (4 * np.log(2))) / (3600 * 180 / np.pi)**2

    # image setups
    im_bounds = (dRA.max(), dRA.min(), dDEC.min(), dDEC.max())
    dRA_lims, dDEC_lims = [rs * rout, -rs * rout], [-rs * rout, rs * rout]

    # intensity limits, and stretch
    norm = ImageNormalize(vmin=-vspan, vmax=vspan, stretch=LinearStretch())

    # set location in figure
    ax = fig.add_subplot(gs[np.floor_divide(i, 2), i % 2])

    # load image
    hdu = fits.open('data/' + disk_name[i] + '_resid.JvMcorr.fits')
    img = 1e6 * np.squeeze(hdu[0].data) / 6.

    # plot the image (in uJy / beam units)
    im = ax.imshow(img,
                   origin='lower',
                   cmap=cmap,
                   extent=im_bounds,
                   norm=norm,
                   aspect='equal')
Example #13
xd = (dRA * np.cos(PAr) - dDEC * np.sin(PAr)) / np.cos(inclr)
yd = (dRA * np.sin(PAr) + dDEC * np.cos(PAr))
r, theta = np.sqrt(xd**2 + yd**2), np.degrees(np.arctan2(yd, xd))

# beam parameters
bmaj, bmin, bPA = 3600 * hd['BMAJ'], 3600 * hd['BMIN'], hd['BPA']
beam_area = (np.pi * bmaj * bmin / (4 * np.log(2))) / (3600 * 180 / np.pi)**2

# image setups
rout = disk.disk[target]['rout']
im_bounds = (dRA.max(), dRA.min(), dDEC.min(), dDEC.max())
dRA_lims, dDEC_lims = [1.2 * rout, -1.2 * rout], [-1.2 * rout, 1.2 * rout]

# intensity limits, and stretch
norm = ImageNormalize(vmin=0,
                      vmax=disk.disk[target]['maxTb'],
                      stretch=AsinhStretch())
cmap = 'inferno'

### Plot the data image
plt.style.use('default')
fig = plt.figure(figsize=(7.0, 5.9))
gs = gridspec.GridSpec(1, 2, width_ratios=(1, 0.04))

# image (sky-plane)
ax = fig.add_subplot(gs[0, 0])
Tb = (1e-23 * dimg / beam_area) * c_**2 / (2 * k_ * freq**2)
im = ax.imshow(Tb,
               origin='lower',
               cmap=cmap,
               extent=im_bounds,
               norm=norm,
               aspect='equal')
Example #14
def plot_all_params(image: Union[str, fits.hdu.HDUList], cat: Union[str, Table, np.ndarray], show: bool = True,
                    cutout: bool = False, ra_key: str = "ALPHA_SKY", dec_key: str = "DELTA_SKY", a_key: str = "A_WORLD",
                    b_key: str = "B_WORLD", theta_key: str = "THETA_WORLD", kron: bool = False,
                    kron_key: str = "KRON_RADIUS"):
    """
    Plots the positions and shapes of catalogued objects over an image.
    :param image: path to a FITS image, or an astropy HDUList.
    :param cat: catalogue with the named columns (path, Table or structured array).
    :param cutout: Plots a cutout of the image centre. Forces 'show' to True.
    :return:
    """

    if cutout:
        show = True

    image, path = ff.path_or_hdu(image)

    data = image[0].data

    wcs_image = wcs.WCS(header=image[0].header)

    plt.subplot(projection=wcs_image)
    norm = ImageNormalize(data, interval=ZScaleInterval(), stretch=SqrtStretch())
    plt.imshow(data, origin='lower', norm=norm, )
    plot_gal_params(hdu=image, ras=cat[ra_key], decs=cat[dec_key], a=cat[a_key],
                    b=cat[b_key],
                    theta=cat[theta_key], colour='red')
    if kron:
        plot_gal_params(hdu=image, ras=cat[ra_key], decs=cat[dec_key], a=cat[kron_key] * cat[a_key],
                        b=cat[kron_key] * cat[b_key],
                        theta=cat[theta_key], colour='violet')

    if show:
        plt.show()

    if cutout:
        n_x = data.shape[1]
        n_y = data.shape[0]

        mid_x = int(n_x / 2)
        mid_y = int(n_y / 2)

        left = mid_x - 45
        right = mid_x + 45
        bottom = mid_y - 45
        top = mid_y + 45

        gal = ff.trim(hdu=image, left=left, right=right, bottom=bottom, top=top)
        plt.imshow(gal[0].data)
        plot_gal_params(hdu=gal, ras=cat[ra_key], decs=cat[dec_key], a=cat[a_key],
                        b=cat[b_key],
                        theta=cat[theta_key], colour='red')
        if kron:
            # use the trimmed cutout here, matching the non-Kron branch above
            plot_gal_params(hdu=gal, ras=cat[ra_key], decs=cat[dec_key], a=cat[kron_key] * cat[a_key],
                            b=cat[kron_key] * cat[b_key],
                            theta=cat[theta_key], colour='violet')

        if show:
            plt.show()

    if path:
        image.close()
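A minimal usage sketch (not part of the original snippet): plot_all_params indexes cat by column name directly, so the catalogue is assumed to be loaded as an astropy Table; the file names are placeholders.

from astropy.table import Table

cat = Table.read('coadd_sextractor.fits')  # hypothetical SExtractor catalogue
plot_all_params(image='coadd.fits', cat=cat, show=True, kron=True)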
Example #15
def pathtest(step_input_filename,
             reffile,
             comparison_filename,
             writefile=True,
             show_figs=False,
             save_figs=True,
             threshold_diff=1e-7,
             debug=False):
    """
    This function calculates the difference between the pipeline and the
    calculated pathloss values. The functions use the output of the
    compute_world_coordinates.py script.
    Args:
        step_input_filename: str, name of the output fits file from the
        source type step (with full path) 
        reffile: str, path to the pathloss FS reference fits file
        comparison_filename: str, path to comparison pipeline pathloss file
        writefile: boolean, if True writes the fits files of the calculated
        pathloss and difference image.
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, whether to save the plots (the function will
        name the plots by default)
        threshold_diff: float, threshold difference between pipeline output
        and comparison file
        debug: boolean, if true a series of print statements will show
        on-screen
    Returns:
        - 1 plot, if told to save and/or show them.
        - median_diff: Boolean, True if smaller or equal to threshold.
        - log_msgs: list, all print statements are captured in this variable
    """

    log_msgs = []

    # start the timer
    pathtest_start_time = time.time()

    # get info from the rate file header
    det = fits.getval(step_input_filename, "DETECTOR", 0)
    msg = 'step_input_filename=' + step_input_filename
    print(msg)
    log_msgs.append(msg)
    exptype = fits.getval(step_input_filename, "EXP_TYPE", 0)
    grat = fits.getval(step_input_filename, "GRATING", 0)
    filt = fits.getval(step_input_filename, "FILTER", 0)

    msg = "path_loss_file  -->  Grating:" + grat + "   Filter:" + filt + "   EXP_TYPE:" + exptype
    print(msg)
    log_msgs.append(msg)

    is_point_source = True

    # get the datamodel from the assign_wcs output file
    extract2d_wcs_file = step_input_filename.replace("srctype.fits",
                                                     "extract_2d.fits")
    model = datamodels.MultiSlitModel(extract2d_wcs_file)

    if writefile:
        # create the fits list to hold the calculated pathloss values for
        # each slit
        hdu0 = fits.PrimaryHDU()
        outfile = fits.HDUList()
        outfile.append(hdu0)

        # create the fits list to hold the image of pipeline-calculated
        # difference values
        hdu0 = fits.PrimaryHDU()
        compfile = fits.HDUList()
        compfile.append(hdu0)

    # list to determine if pytest is passed or not
    total_test_result = []

    # loop over the slits
    sltname_list = ["S200A1", "S200A2", "S400A1", "S1600A1"]
    msg = "Now looping through the slits. This may take a while... "
    print(msg)
    log_msgs.append(msg)
    if det == "NRS2":
        sltname_list.append("S200B1")

    # but check if data is BOTS
    if fits.getval(step_input_filename, "EXP_TYPE", 0) == "NRS_BRIGHTOBJ":
        sltname_list = ["S1600A1"]

    # get all the science extensions
    ps_uni_ext_list = get_ps_uni_extensions(reffile, is_point_source)

    # get files
    print("""Checking if files exist & obtaining datamodels.
          This takes a few minutes...""")
    if os.path.isfile(comparison_filename):
        if debug:
            print('Comparison file does exist.')
    else:
        result_msg = 'Comparison file does NOT exist. Skipping pathloss test.'
        print(result_msg)
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    # get the comparison data model
    pathloss_pipe = datamodels.open(comparison_filename)
    # For the moment, the pipeline is using the wrong reference file for slit 400A1, so read the file that was
    # re-processed with the right reference file and open the corresponding data model.
    # BUT we are skipping this since this is not in the release candidate of the pipeline
    """
    pathloss_400a1 = step_input_filename.replace("srctype.fits", "pathloss_400A1.fits")
    pathloss_pipe_400a1 = datamodels.open(pathloss_400a1)
    """
    if debug:
        print('got comparison datamodel!')

    if os.path.isfile(step_input_filename):
        if debug:
            print('Input file does exist.')
    else:
        result_msg = 'Input file does NOT exist. Skipping pathloss test.'
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs
    # get the input data model
    pl = datamodels.open(step_input_filename)
    if debug:
        print('got input datamodel!')

    # loop through the wavelengths
    msg = "Looping through the wavelengths... "
    print(msg)
    log_msgs.append(msg)

    slit_val = 0
    for slit, pipe_slit in zip(pl.slits, pathloss_pipe.slits):
        slit_val = slit_val + 1

        slit_id = pipe_slit.name
        # with the current reference file, skip S400A1
        #if slit_id == "S400A1":
        #    continue
        print('\nWorking with slitlet ', slit_id)

        if slit.name == slit_id:
            msg = """Slitlet name in fits file previous to pathloss
            and in pathloss output file are the same.
            """
            log_msgs.append(msg)
            print(msg)
        else:
            msg = """* Missmatch of slitlet names in fits file previous to
            pathloss and in pathloss output file. Skipping test.
            """
            result = 'skip'
            log_msgs.append(msg)
            return result, msg, log_msgs

        # S-flat
        mode = "FS"

        if debug:
            print("grat = ", grat)

        continue_pl_test = False
        if fits.getval(step_input_filename, "EXP_TYPE", 0) == "NRS_BRIGHTOBJ":
            slit = model
            continue_pl_test = True
        else:
            for slit_in_MultiSlitModel in pl.slits:
                if slit_in_MultiSlitModel.name == slit_id:
                    slit = slit_in_MultiSlitModel
                    continue_pl_test = True
                    break

        if not continue_pl_test:
            continue
        else:
            try:
                if is_point_source is True:
                    ext = ps_uni_ext_list[0][slit_id]
                    print("Retrieved point source extension")
                elif is_point_source is False:
                    ext = ps_uni_ext_list[1][slit_id]
                    print("WARNING: Retrieved extended source extension")
            except KeyError:
                ext = sltname_list.index(slit_id)
                print("Unable to retrieve extension by name; using the slit index instead.")

        wcs_obj = slit.meta.wcs

        # get the wavelength
        x, y = wcstools.grid_from_bounding_box(wcs_obj.bounding_box,
                                               step=(1, 1),
                                               center=True)
        ra, dec, wave = wcs_obj(x, y)  # wave is in microns
        wave_sci = wave * 10**(-6)  # microns --> meters

        # get positions of source in file:
        slit_x = slit.source_xpos
        slit_y = slit.source_ypos
        if debug:
            print("slit_x, slit_y", slit_x, slit_y)
        """
        if slit_id == "S400A1":
            if is_point_source:
                ext = 1
            else:
                ext = 3
            if os.path.isfile("jwst-nirspec-a400.plrf.fits"):
                reffile2use = "jwst-nirspec-a400.plrf.fits"
            else:
                msg = "Skipping slit S400A1 because reference file not present"
                print(msg)
                log_msgs.append(msg)
                continue
        else:"""
        reffile2use = reffile

        msg = "Using reference file: " + reffile2use
        print(msg)
        log_msgs.append(msg)

        plcor_ref_ext = fits.getdata(reffile2use, ext)
        if debug:
            print("ext:", ext)
        hdul = fits.open(reffile2use)
        plcor_ref = hdul[1].data
        w = wcs.WCS(hdul[1].header)

        # make cube
        w1, y1, x1 = np.mgrid[:plcor_ref.shape[0],
                              :plcor_ref.shape[1],
                              :plcor_ref.shape[2]]
        slitx_ref, slity_ref, wave_ref = w.all_pix2world(x1, y1, w1, 0)

        previous_sci = slit.data
        pipe_correction = pipe_slit.pathloss
        """
        if slit_id == "S400A1":
            for pipe_slit_400a1 in pathloss_pipe_400a1.slits:
                if pipe_slit_400a1.name == "S400A1":
                    pipe_correction = pipe_slit_400a1.pathloss
                    break
                else:
                    continue
        """
        if len(pipe_correction) == 0:
            print(
                "Pipeline pathloss correction in datamodel is empty. Skipping testing this slit."
            )
            continue

        # Set up manually to test correction at nonzero point
        # slit_x = 0.2
        # slit_y = 0.2
        # if debug:
        #     print("""WARNING: Using manually set slit_x and slit_y!
        # The pipeline correction will not use manually set values and
        # thus the residuals will change
        # """)

        correction_array = np.array([])
        lambda_array = np.array([])

        wave_sci_flat = wave_sci.reshape(wave_sci.size)
        wave_ref_flat = wave_ref.reshape(wave_ref.size)

        ref_xy = np.column_stack((slitx_ref.reshape(slitx_ref.size),
                                  slity_ref.reshape(slitx_ref.size)))

        # loop through slices in lambda from reference file
        shape = 0
        for lambda_val in wave_ref_flat:
            # loop through every lambda value
            # flattened so that looping works smoothly
            shape = shape + 1
            index = np.where(wave_ref[:, 0, 0] == lambda_val)
            # index of closest lambda value in reffile to given sci lambda
            #   took index of only the first slice of wave_ref because
            #   the others were repetitive & we got extra indices
            # take slice where lambda=index:
            plcor_slice = plcor_ref_ext[index[0][0]].reshape(
                plcor_ref_ext[index[0][0]].size)
            # do 2d interpolation to get a single correction factor for each slice
            corr_val = scipy.interpolate.griddata(ref_xy[:plcor_slice.size],
                                                  plcor_slice,
                                                  np.asarray([slit_x, slit_y]),
                                                  method='linear')
            # append values from loop to create a vector of correction factors
            correction_array = np.append(correction_array, corr_val[0])
            # map to array with corresponding lambda
            lambda_array = np.append(lambda_array, lambda_val)

        # get correction value for each pixel
        corr_vals = np.interp(wave_sci_flat, lambda_array, correction_array)
        corr_vals = corr_vals.reshape(wave_sci.shape)
        corrected_array = previous_sci / corr_vals

        # set up generals for all the plots
        font = {'weight': 'normal', 'size': 7}
        matplotlib.rc('font', **font)

        # Plots:
        step_input_filepath = step_input_filename.replace(".fits", "")
        # my correction values
        fig = plt.figure()
        plt.subplot(221)
        norm = ImageNormalize(corr_vals)
        plt.imshow(corr_vals,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.title('Calculated Correction')
        plt.colorbar()
        # pipe correction
        plt.subplot(222)
        norm = ImageNormalize(pipe_correction)
        plt.imshow(pipe_correction,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.title("Pipeline Correction")
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.colorbar()
        # residuals (pipe correction - my correction)
        corr_residuals = pipe_correction - corr_vals
        plt.subplot(223)
        norm = ImageNormalize(corr_residuals)
        plt.imshow(corr_residuals,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.title('Correction residuals')
        plt.colorbar()
        # my science data after
        plt.subplot(224)
        norm = ImageNormalize(corrected_array)
        plt.imshow(corrected_array,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.title('My science data after pathloss')
        plt.xlabel('dispersion in pixels')
        plt.ylabel('y in pixels')
        plt.colorbar()
        fig.suptitle("FS PS Pathloss Correction Test for slit " + str(slit_id))

        if save_figs:
            plt_name = step_input_filepath + "_Pathloss_test_slitlet_"+str(mode) + "_" + str(slit_id) + "_" + \
                       str(slit_val) + ".png"
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_figs:
            plt.show()
        if not save_figs and not show_figs:
            msg = "Not making plots because both show_figs and save_figs were set to False."
            if debug:
                print(msg)
            log_msgs.append(msg)
        elif not save_figs:
            msg = "Not saving plots because save_figs was set to False."
            if debug:
                print(msg)
            log_msgs.append(msg)
        plt.clf()

        # create fits file to hold the calculated pathloss for each slit
        if writefile:
            msg = "Saving the fits files with the calculated pathloss for each slit..."
            print(msg)
            log_msgs.append(msg)

            # this is the file to hold the image of the calculated pathloss values
            outfile_ext = fits.ImageHDU(corr_vals, name=slit_id)
            outfile.append(outfile_ext)

            # this is the file to hold the image of pipeline-calculated difference values
            compfile_ext = fits.ImageHDU(corr_residuals, name=slit_id)
            compfile.append(compfile_ext)

        if corr_residuals[~np.isnan(corr_residuals)].size == 0:
            msg1 = " * Unable to calculate statistics because difference array has all values as NaN. " \
                   "Test will be set to FAILED."
            print(msg1)
            log_msgs.append(msg1)
            test_result = "FAILED"
        else:
            msg = "Calculating statistics... "
            print(msg)
            log_msgs.append(msg)
            corr_residuals = corr_residuals[
                np.where((corr_residuals != 999.0) & (corr_residuals < 0.1)
                         & (corr_residuals > -0.1))]  # ignore outliers
            if corr_residuals.size == 0:
                msg1 = " * Unable to calculate statistics because difference array has all outlier values. " \
                       "Test will be set to FAILED."
                print(msg1)
                log_msgs.append(msg1)
                test_result = "FAILED"
            else:
                stats_and_strings = auxfunc.print_stats(corr_residuals,
                                                        "Difference",
                                                        float(threshold_diff),
                                                        absolute=True)
                stats, stats_print_strings = stats_and_strings
                corr_residuals_mean, corr_residuals_median, corr_residuals_std = stats
                for msg in stats_print_strings:
                    log_msgs.append(msg)

                # This is the key argument for the assert pytest function
                median_diff = False
                if abs(corr_residuals_median) <= float(threshold_diff):
                    median_diff = True
                if median_diff:
                    test_result = "PASSED"
                else:
                    test_result = "FAILED"

        msg = " *** Result of the test: " + test_result + "\n"
        print(msg)
        log_msgs.append(msg)
        total_test_result.append(test_result)

    if writefile:
        outfile_name = step_input_filename.replace("srctype",
                                                   "calcuated_pathloss")
        compfile_name = step_input_filename.replace("srctype",
                                                    "comparison_pathloss")

        # create the fits list to hold the calculated pathloss values for each slit
        outfile.writeto(outfile_name, overwrite=True)

        # this is the file to hold the image of pipeline-calculated difference values
        compfile.writeto(compfile_name, overwrite=True)

        msg = "\nFits file with calculated pathloss values of each slit saved as: "
        print(msg)
        log_msgs.append(msg)
        print(outfile_name)
        log_msgs.append(outfile_name)

        msg = "Fits file with comparison (pipeline pathloss - calculated pathloss) saved as: "
        print(msg)
        log_msgs.append(msg)
        print(compfile_name)
        log_msgs.append(compfile_name)

    # If all tests passed then pytest is marked as PASSED, else it is FAILED
    FINAL_TEST_RESULT = bool(total_test_result) and "FAILED" not in total_test_result

    if FINAL_TEST_RESULT:
        msg = "\n *** Final pathloss test result reported as PASSED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "All slits PASSED path_loss test."
    else:
        msg = "\n *** Final pathloss test result reported as FAILED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "One or more slits FAILED path_loss test."

    # end the timer
    pathloss_end_time = time.time() - pathtest_start_time
    if pathloss_end_time > 60.0:
        pathloss_end_time = pathloss_end_time / 60.0  # in minutes
        pathloss_tot_time = "* Script FS_PS.py took ", repr(
            pathloss_end_time) + " minutes to finish."
        if pathloss_end_time > 60.0:
            pathloss_end_time = pathloss_end_time / 60.  # in hours
            pathloss_tot_time = "* Script FS_PS.py took ", repr(
                pathloss_end_time) + " hours to finish."
    else:
        pathloss_tot_time = "* Script FS_PS.py took ", repr(
            pathloss_end_time) + " seconds to finish."
    print(pathloss_tot_time)
    log_msgs.append(pathloss_tot_time)

    return FINAL_TEST_RESULT, result_msg, log_msgs
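A hypothetical invocation of pathtest (all file paths below are placeholders; the input name must contain "srctype", since the function derives the extract_2d and output names from it):

result, result_msg, log_msgs = pathtest(
    'final_output_caldet1_NRS1_srctype.fits',
    reffile='jwst-nirspec-fs-pathloss.plrf.fits',
    comparison_filename='final_output_caldet1_NRS1_pathloss.fits',
    writefile=True, save_figs=True, threshold_diff=1e-7)
assert result, result_msg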
Example #16
    def plot_hst_loc_cartopy(self,
                             i=5,
                             df=None,
                             title='',
                             thresh=5,
                             fout='',
                             min_exptime=800,
                             key='start',
                             save=False,
                             orbital_path1=None,
                             orbital_path2=None,
                             projection=ccrs.PlateCarree()):
        fig, ax = plt.subplots(nrows=1,
                               ncols=1,
                               figsize=(8, 7),
                               tight_layout=True,
                               subplot_kw={'projection': projection})
        crs = projection
        transform = crs._as_mpl_transform(ax)
        df = df[df.integration_time.gt(min_exptime)]
        df.sort_values(by='incident_cr_rate', inplace=True)

        # Plot configuration
        ax.coastlines()
        gl = ax.gridlines(crs=crs,
                          draw_labels=True,
                          linewidth=1,
                          color='k',
                          alpha=0.4,
                          linestyle='--')
        fname = '/ifs/missions/projects/plcosmic/hst_cosmic_rays/APJ_plots/HYP_50M_SR_W.tif'
        ax.imshow(plt.imread(fname),
                  origin='upper',
                  transform=crs,
                  extent=[-180, 180, -90, 90])
        gl.xlabels_top = False
        gl.ylabels_left = True
        gl.ylabels_right = False
        gl.xlines = True
        # gl.xlocator = mticker.FixedLocator([-180, -45, 0, 45, 180])
        gl.xformatter = LONGITUDE_FORMATTER
        gl.yformatter = LATITUDE_FORMATTER
        gl.xlocator = MultipleLocator(60)
        gl.ylocator = MultipleLocator(15)
        gl.xlabel_style = {'size': 10, 'color': 'black'}
        gl.ylabel_style = {'size': 10, 'color': 'black'}

        date = 2005
        altitude = 565

        # Calculate the B field grid
        # Evenly space grid with 1 degree resolution in both Latitude and Longitude
        lat = np.linspace(-90, 90, 1 * 180 + 1)
        lon = np.linspace(0, 360, 1 * 360 + 1)
        lat_grid, lon_grid = np.meshgrid(lat, lon)
        coordinates = list(zip(lat_grid.ravel(), lon_grid.ravel()))
        B_strength = []
        for coords in coordinates:
            b_field = ipmag.igrf([date, altitude, coords[0], coords[1]])
            B_strength.append(b_field[-1])
        B_strength_grid = np.array(B_strength).reshape(lat_grid.shape)

        # Get the CR rate information
        lat, lon, rate = df['latitude_{}'.format(key)], \
                             df['longitude_{}'.format(key)], \
                             df['incident_cr_rate']
        LOG.info('{} {} {}'.format(len(lat), len(lon), len(rate)))

        # Get average statistics to generate contour
        mean, median, std = sigma_clipped_stats(rate,
                                                sigma_lower=3,
                                                sigma_upper=3)
        LOG.info('{} +/- {}'.format(mean, std))
        norm = ImageNormalize(rate,
                              stretch=LinearStretch(),
                              vmin=mean - thresh * std,
                              vmax=mean + thresh * std)
        cbar_below_mean = [mean - (i + 1) * std for i in range(thresh)]
        cbar_above_mean = [mean + (i + 1) * std for i in range(thresh)]

        cbar_bounds = cbar_below_mean + [mean] + cbar_above_mean
        print(cbar_bounds)
        cbar_bounds.sort()
        sci_cmap = plt.cm.viridis
        custom_norm = colors.BoundaryNorm(boundaries=cbar_bounds,
                                          ncolors=sci_cmap.N)

        scat = ax.scatter(lon.values,
                          lat.values,
                          marker='o',
                          s=3.5,
                          c=rate,
                          alpha=0.2,
                          norm=custom_norm,
                          cmap='viridis',
                          transform=ccrs.PlateCarree())

        cbar_ticks = cbar_bounds
        cax = fig.add_axes([0.1, 0.2, 0.8, 0.05])
        cbar = fig.colorbar(scat,
                            cax=cax,
                            ticks=cbar_ticks,
                            orientation='horizontal')
        cbar.set_alpha(1)
        cbar.draw_all()
        cbar_tick_labels = [rf'<x>-{i}$\sigma$' for i in range(thresh, 0, -1)] + [
            '<x>'
        ] + [rf'<x>+{i}$\sigma$' for i in range(1, thresh + 1)]
        cbar.ax.set_xticklabels(cbar_tick_labels,
                                horizontalalignment='right',
                                rotation=30)

        cbar.set_label('CR Flux [CR/s/$cm^2$]', fontsize=10)

        cntr = ax.contour(lon_grid,
                          lat_grid,
                          B_strength_grid,
                          cmap='plasma',
                          levels=10,
                          alpha=1,
                          linewidths=2,
                          transform=ccrs.PlateCarree())

        h1, l1 = cntr.legend_elements("B_strength_grid")
        l1_custom = [
            f"{val.split('=')[-1].strip('$').strip()} nT" for val in l1
        ]

        leg1 = Legend(ax,
                      h1,
                      l1_custom,
                      loc='upper left',
                      edgecolor='k',
                      fontsize=8,
                      framealpha=0.45,
                      facecolor='tab:gray',
                      bbox_to_anchor=(1.05, 1.03),
                      title='Total Magnetic Intensity')
        ax.add_artist(leg1)

        if orbital_path1 is not None:
            ax.scatter(orbital_path1.metadata['longitude'][::4][1:],
                       orbital_path1.metadata['latitude'][::4][1:],
                       c='k',
                       s=20,
                       label='285 second interval')

        if orbital_path2 is not None:
            ax.plot(orbital_path2.metadata['longitude'],
                    orbital_path2.metadata['latitude'],
                    label=f'Orbital Path Over {2000:.0f} seconds',
                    color='k',
                    ls='--',
                    lw=1.25)
        plt.show()
        return fig
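The enclosing class is not shown in this snippet, so a call can only be sketched; the receiver name and DataFrame columns below are assumptions based on the method body.

# Hypothetical invocation; `monitor` stands in for an instance of the
# (unshown) class, and cr_df must carry the columns used above
# (integration_time, incident_cr_rate, latitude_start, longitude_start).
# fig = monitor.plot_hst_loc_cartopy(df=cr_df, thresh=5,
#                                    min_exptime=800, key='start')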
Example #17
    a.jsoc.Notify(jsoc_email),
    a.jsoc.Segment.image,
    cutout,
)

#####################################################
# Submit the export request and download the data.

files = Fido.fetch(query)
files.sort()

#####################################################
# Now that we've downloaded the files, we can create
# a `~sunpy.map.MapSequence` from them.

sequence = sunpy.map.Map(files, sequence=True)

#####################################################
# Finally, we can construct an animation in time from
# our stack of cutouts and interactively flip through
# each image in our sequence. We first adjust the plot
# settings on each image to ensure the colorbar is the
# same at each time step.

for each_map in sequence:
    each_map.plot_settings['norm'] = ImageNormalize(vmin=0, vmax=5e3, stretch=SqrtStretch())
plt.figure()
ani = sequence.plot()

plt.show()
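#####################################################
# To keep the animation rather than just viewing it
# interactively, it can be written to disk with
# FuncAnimation.save. This extra step is illustrative
# (not part of the original example) and assumes a
# matplotlib movie writer such as ffmpeg is available.

ani.save('aia_cutouts.mp4', writer='ffmpeg', dpi=150)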
Example #18
def defringeflat(flat_file,
                 wbin=10,
                 start_col=10,
                 end_col=980,
                 clip1=0,
                 diagnostic=True,
                 save_to_path=None,
                 filename=None):
    """
	This function is to remove the fringe pattern using
	the method described in Rojo and Harrington (2006).

	Use a fifth order polynomial to remove the continuum.

	Parameters
	----------
	flat_file 		: 	str
						path to the original flat file

	Optional Parameters
	-------------------
	wbin 			:	int
						the bin width to calculate each 
						enhance row
						Default is 10

	start_col 		: 	int
						starting column number for the
						wavelet analysis
						Default is 10

	end_col 		: 	int
						ending column number for the
						wavelet analysis
						Default is 980

	diagnostic 		: 	boolean
						output the diagnostic plots
						Default is True

	Returns
	-------
	hdu 			: 	fits.PrimaryHDU
						defringed flat image (original header preserved)

	"""
    # the path to save the diagnostic plots
    #save_to_path = 'defringeflat/allflat/'

    #print(flat_file)

    data = fits.open(flat_file, ignore_missing_end=True)

    # Use the data to figure out the values to mask through the image (low counts/order edges)
    hist, bins = np.histogram(data[0].data.flatten(),
                              bins=int(np.sqrt(len(data[0].data.flatten()))))
    bins = bins[0:-1]
    index1 = np.where((bins > np.percentile(data[0].data.flatten(), 10))
                      & (bins < np.percentile(data[0].data.flatten(), 30)))
    try:
        lowval = bins[index1][np.where(hist[index1] == np.min(hist[index1]))]
        #print(lowval, len(lowval))
        if len(lowval) >= 2: lowval = np.min(lowval)
    except ValueError:  # no values for index1, so np.min sees an empty array
        lowval = 0

    flat = data

    # initial flat plot
    if diagnostic is True:

        # Save the images to a separate folder
        save_to_image_path = save_to_path + '/images/'
        if not os.path.exists(save_to_image_path):
            os.makedirs(save_to_image_path)

        fig = plt.figure(figsize=(8, 8))
        fig.suptitle("original flat", fontsize=12)
        gs = gridspec.GridSpec(2, 1, height_ratios=[6, 1])
        ax0 = plt.subplot(gs[0])
        # Create an ImageNormalize object
        norm = ImageNormalize(flat[0].data, interval=ZScaleInterval())
        ax0.imshow(flat[0].data,
                   cmap='gray',
                   norm=norm,
                   origin='lower',
                   aspect='auto')
        ax0.set_ylabel("Row number")
        ax1 = plt.subplot(gs[1], sharex=ax0)
        ax1.plot(flat[0].data[60, :],
                 'k-',
                 alpha=0.5,
                 label='60th row profile')
        ax1.set_ylabel("Amp (DN)")
        ax1.set_xlabel("Column number")
        plt.legend()
        plt.savefig(save_to_image_path + "defringeflat_{}_0_original_flat.png"\
                 .format(filename), bbox_inches='tight')
        plt.close()

    defringeflat_img = data
    defringe_data = np.array(defringeflat_img[0].data, dtype=float)

    for k in np.arange(0, 1024 - wbin, wbin):
        #print(k)
        #if k != 310: continue
        """
		# Use the data to figure out the values to mask through the image (low counts/order edges)
		hist, bins = np.histogram(flat[0].data[k:k+wbin+1, 0:1024-clip1].flatten(), 
			                      bins=int(np.sqrt(len(flat[0].data[k:k+wbin+1, 0:1024-clip1].flatten()))))
		bins       = bins[0:-1]
		index1     = np.where( (bins > np.percentile(flat[0].data[k:k+wbin+1, 0:1024-clip1].flatten(), 10)) & 
			                   (bins < np.percentile(flat[0].data[k:k+wbin+1, 0:1024-clip1].flatten(), 30)) )
		lowval     = bins[index1][np.where(hist[index1] == np.min(hist[index1]))]
		
		#print(lowval, len(lowval))
		if len(lowval) >= 2: lowval = np.min(lowval)
		"""
        # Find the mask
        mask = np.zeros(flat[0].data[k:k + wbin + 1, 0:1024 - clip1].shape)
        baddata = np.where(
            flat[0].data[k:k + wbin + 1, 0:1024 - clip1] <= lowval)
        mask[baddata] = 1

        # extract the patch from the fits file
        #flat_patch = np.ma.array(flat[0].data[k:k+wbin,:], mask=mask)
        flat_patch = np.array(flat[0].data[k:k + wbin + 1, 0:1024 - clip1])

        # median average the selected region in the order
        flat_patch_median = np.ma.median(flat_patch, axis=0)

        # continuum fit
        # smooth the continuum (Chris's method)
        smoothed = sp.ndimage.uniform_filter1d(flat_patch_median, 30)
        splinefit = sp.interpolate.interp1d(np.arange(len(smoothed)),
                                            smoothed,
                                            kind='cubic')
        cont_fit = splinefit(np.arange(0, 1024 - clip1))  #smoothed

        # Now fit a polynomial
        #pcont     = np.ma.polyfit(np.arange(0, 1024-clip1),
        #	                      cont_fit, 10)
        #cont_fit2 = np.polyval(pcont, np.arange(0,1024))

        #plt.plot(flat_patch_median, c='r')
        #plt.plot(smoothed, c='b')
        #plt.savefig(save_to_image_path + "TEST.png", bbox_inches='tight')
        #plt.close()
        #plt.show()
        #sys.exit()

        #pcont    = np.ma.polyfit(np.arange(start_col,end_col),
        #	                     flat_patch_median[start_col:end_col],10)
        #cont_fit = np.polyval(pcont, np.arange(0,1024))

        # use wavelets package: WaveletAnalysis
        enhance_row = flat_patch_median - cont_fit

        dt = 0.1
        wa = WaveletAnalysis(enhance_row[start_col:end_col], dt=dt)
        # wavelet power spectrum
        power = wa.wavelet_power
        # scales (unused below, kept from the original analysis)
        scales = wa.scales
        # associated time vector
        t = wa.time
        # reconstruction of the original data
        rx = wa.reconstruction()

        # reconstruct the fringe image
        reconstruct_image = np.zeros(defringe_data[k:k + wbin + 1,
                                                   0:1024 - clip1].shape)
        for i in range(wbin + 1):
            for j in np.arange(start_col, end_col):
                reconstruct_image[i, j] = rx[j - start_col]

        defringe_data[k:k + wbin + 1,
                      0:1024 - clip1] -= reconstruct_image[0:1024 - clip1]

        # Add in something for the edges/masked out data in the reconstructed image
        defringe_data[k:k + wbin + 1, 0:1024 -
                      clip1][baddata] = flat[0].data[k:k + wbin + 1,
                                                     0:1024 - clip1][baddata]

        #print("{} row starting {} is done".format(filename,k))

        # diagnostic plots
        if diagnostic is True:
            print("Generating diagnostic plots")
            # middle cut plot
            fig = plt.figure(figsize=(10, 6))
            fig.suptitle("middle cut at row {}".format(k + wbin // 2),
                         fontsize=12)
            ax1 = fig.add_subplot(2, 1, 1)

            norm = ImageNormalize(flat_patch, interval=ZScaleInterval())
            ax1.imshow(flat_patch,
                       cmap='gray',
                       norm=norm,
                       origin='lower',
                       aspect='auto')
            ax1.set_ylabel("Row number")
            ax2 = fig.add_subplot(2, 1, 2, sharex=ax1)
            ax2.plot(flat_patch[wbin // 2, :], 'k-', alpha=0.5)
            ax2.set_ylabel("Amp (DN)")
            ax2.set_xlabel("Column number")

            plt.tight_layout()
            plt.subplots_adjust(top=0.85, hspace=0.5)
            plt.savefig(save_to_image_path + \
             'defringeflat_{}_flat_start_row_{}_middle_profile.png'\
             .format(filename,k), bbox_inches='tight')
            plt.close()

            # continuum fit plot
            fig = plt.figure(figsize=(10, 6))
            fig.suptitle("continuum fit row {}-{}".format(k, k + wbin),
                         fontsize=12)
            gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
            ax0 = plt.subplot(gs[0])
            ax0.plot(flat_patch_median,
                     'k-',
                     alpha=0.5,
                     label='median of patch')
            ax0.plot(cont_fit, 'r-', alpha=0.5, label='continuum fit')
            #ax0.plot(cont_fit2,'m-', alpha=0.5, label='continuum fit poly')
            ax0.set_ylabel("Amp (DN)")
            plt.legend()
            ax1 = plt.subplot(gs[1])
            ax1.plot(flat_patch_median - cont_fit,
                     'k-',
                     alpha=0.5,
                     label='residual')
            ax1.set_ylabel("Amp (DN)")
            ax1.set_xlabel("Column number")
            plt.legend()

            plt.tight_layout()
            plt.subplots_adjust(top=0.85, hspace=0.5)
            plt.savefig(save_to_image_path + \
                       "defringeflat_{}_start_row_{}_continuum_fit.png".\
                     format(filename,k), bbox_inches='tight')
            #plt.show()
            #sys.exit()
            plt.close()

            # enhance row vs. reconstructed wavelet plot
            try:
                fig = plt.figure(figsize=(10, 6))
                fig.suptitle("reconstruct fringe comparison row {}-{}".\
                          format(k,k+wbin), fontsize=10)
                ax1 = fig.add_subplot(3, 1, 1)

                ax1.set_title('enhance_row start row')
                ax1.plot(enhance_row,
                         'k-',
                         alpha=0.5,
                         label="enhance_row start row {}".format(k))
                ax1.set_ylabel("Amp (DN)")
                #plt.legend()

                ax2 = fig.add_subplot(3, 1, 2, sharex=ax1)
                ax2.set_title('reconstructed fringe pattern')
                ax2.plot(rx,
                         'k-',
                         alpha=0.5,
                         label='reconstructed fringe pattern')
                ax2.set_ylabel("Amp (DN)")
                #plt.legend()

                ax3 = fig.add_subplot(3, 1, 3, sharex=ax1)
                ax3.set_title('residual')
                ax3.plot(enhance_row[start_col:end_col] - rx,
                         'k-',
                         alpha=0.5,
                         label='residual')
                ax3.set_ylabel("Amp (DN)")
                ax3.set_xlabel("Column number")
                #plt.legend()
                plt.tight_layout()
                plt.subplots_adjust(top=0.85, hspace=0.5)
                plt.savefig(save_to_image_path + \
                         "defringeflat_{}_start_row_{}_reconstruct_profile.png".\
                         format(filename,k), bbox_inches='tight')
                plt.close()
            except RuntimeError:
                print("CANNOT GENERATE THE PLOT "
                      "defringeflat_{}_start_row_{}_reconstruct_profile.png"
                      .format(filename, k))

            # reconstruct image comparison plot
            fig = plt.figure(figsize=(10, 6))
            fig.suptitle("reconstructed image row {}-{}".\
                      format(k,k+wbin), fontsize=12)

            ax1 = fig.add_subplot(3, 1, 1)
            ax1.set_title('raw flat image')
            norm = ImageNormalize(flat_patch, interval=ZScaleInterval())
            ax1.imshow(flat_patch,
                       cmap='gray',
                       norm=norm,
                       origin='lower',
                       label='raw flat image',
                       aspect='auto')
            ax1.set_ylabel("Row number")
            #plt.legend()

            ax2 = fig.add_subplot(3, 1, 2, sharex=ax1)
            ax2.set_title('reconstructed fringe image')
            norm = ImageNormalize(reconstruct_image, interval=ZScaleInterval())
            ax2.imshow(reconstruct_image,
                       cmap='gray',
                       norm=norm,
                       origin='lower',
                       label='reconstructed fringe image',
                       aspect='auto')
            ax2.set_ylabel("Row number")
            #plt.legend()

            ax3 = fig.add_subplot(3, 1, 3, sharex=ax1)
            ax3.set_title('residual')
            norm = ImageNormalize(flat_patch - reconstruct_image,
                                  interval=ZScaleInterval())
            ax3.imshow(flat_patch - reconstruct_image,
                       norm=norm,
                       origin='lower',
                       cmap='gray',
                       label='residual',
                       aspect='auto')
            ax3.set_ylabel("Row number")
            ax3.set_xlabel("Column number")
            #plt.legend()
            plt.tight_layout()
            plt.subplots_adjust(top=0.85, hspace=0.5)
            plt.savefig(save_to_image_path + \
                     "defringeflat_{}_start_row_{}_reconstruct_image.png".\
                     format(filename,k), bbox_inches='tight')
            plt.close()

            # middle residual comparison plot
            fig = plt.figure(figsize=(10, 6))
            fig.suptitle("middle row comparison row {}-{}".\
                      format(k,k+wbin), fontsize=12)

            ax1 = fig.add_subplot(3, 1, 1)
            ax1.plot(flat_patch[wbin // 2, :],
                     'k-',
                     alpha=0.5,
                     label='original flat row {}'.format(k + wbin / 2))
            ax1.set_ylabel("Amp (DN)")
            plt.legend()

            ax2 = fig.add_subplot(3, 1, 2, sharex=ax1)
            ax2.plot(flat_patch[wbin//2,:]-\
                  reconstruct_image[wbin//2,:],'k-',
                  alpha=0.5, label='defringed flat row {}'.format(k+wbin/2))
            ax2.set_ylabel("Amp (DN)")
            plt.legend()

            ax3 = fig.add_subplot(3, 1, 3, sharex=ax1)
            ax3.plot(reconstruct_image[wbin // 2, :],
                     'k-',
                     alpha=0.5,
                     label='difference')
            ax3.set_ylabel("Amp (DN)")
            ax3.set_xlabel("Column number")
            plt.legend()

            plt.tight_layout()
            plt.subplots_adjust(top=0.85, hspace=0.5)
            plt.savefig(save_to_image_path + \
                     "defringeflat_{}_start_row_{}_defringe_middle_profile.png".\
                     format(filename,k), bbox_inches='tight')
            plt.close()

        #if k > 30: sys.exit() # for testing purposes

    # final diagnostic plot
    if diagnostic is True:
        fig = plt.figure(figsize=(8, 8))
        fig.suptitle("defringed flat", fontsize=12)
        gs = gridspec.GridSpec(2, 1, height_ratios=[6, 1])
        ax0 = plt.subplot(gs[0])
        norm = ImageNormalize(defringe_data, interval=ZScaleInterval())
        ax0.imshow(defringe_data,
                   cmap='gray',
                   norm=norm,
                   origin='lower',
                   aspect='auto')
        ax0.set_ylabel("Row number")
        ax1 = plt.subplot(gs[1], sharex=ax0)
        ax1.plot(defringe_data[60, :],
                 'k-',
                 alpha=0.5,
                 label='60th row profile')
        ax1.set_ylabel("Amp (DN)")
        ax1.set_xlabel("Column number")
        plt.legend()

        plt.tight_layout()
        plt.subplots_adjust(top=0.85, hspace=0.5)
        plt.savefig(save_to_image_path + "defringeflat_{}_0_defringe_flat.png"\
         .format(filename), bbox_inches='tight')
        plt.close()

    hdu = fits.PrimaryHDU(data=defringe_data)
    hdu.header = flat[0].header
    return hdu
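A usage sketch (not from the original source); the flat path and output directory are placeholders. With diagnostic=True a save_to_path must be supplied, since the diagnostic plots are written beneath it.

hdu = defringeflat('data/flat_0001.fits', wbin=10, diagnostic=True,
                   save_to_path='defringe_out', filename='flat_0001')
hdu.writeto('defringe_out/flat_0001_defringed.fits', overwrite=True)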
Example #19
def add_FlatSummary(filenames, MasterFlat, html_file, imagestretch='linear'):
    """
    add flat processing summary to website
    """

    if not os.path.isdir('diagnostics'):
        os.mkdir('diagnostics')
    hdulist = fits.open(MasterFlat)
    framefilename = 'diagnostics/' + MasterFlat + '.png'

    imgdat = hdulist[0].data

    html = '<H2>Flat Processing summary</H2>'
    html += '<p> %d flats have been processed </p>' % \
            (len(filenames))
    html += '<p> Master flat file created as <A HREF=\"%s\">%s</A> </p>' % \
            (framefilename, MasterFlat)

    html +=  ('<p><IMG SRC=\"%s\"></p>')% \
            (framefilename)

    html += '\n'

    html += 'Statistics of the master flat:'

    html += "<P><TABLE BORDER=\"1\">\n<TR>\n"
    html += ("<TR><TD>Mean</TD><TD>%f</TD></TR><TR><TD>Median</TD><TD>%f</TD></TR><TR><TD>Std</TD><TD>%f</TD></TR>\n") % \
            (np.nanmean(imgdat),np.nanmedian(imgdat),np.nanstd(imgdat))
    html += '</TABLE>\n'

    ### Create frame image

    imgdat = hdulist[0].data
    # clip extreme values to prevent crash of imresize
    imgdat = np.clip(imgdat, np.percentile(imgdat, 1),
                     np.percentile(imgdat, 99))
    imgdat = imresize(imgdat,
                      min(1., 1000. / np.max(imgdat.shape)),
                      interp='nearest')
    #resize image larger than 1000px on one side
    norm = ImageNormalize(imgdat,
                          interval=ZScaleInterval(),
                          stretch={
                              'linear': LinearStretch(),
                              'log': LogStretch()
                          }[imagestretch])

    plt.figure(figsize=(5, 5))

    img = plt.imshow(imgdat, cmap='gray', norm=norm, origin='lower')
    # remove axes
    plt.axis('off')
    img.axes.get_xaxis().set_visible(False)
    img.axes.get_yaxis().set_visible(False)

    plt.savefig(framefilename,
                format='png',
                bbox_inches='tight',
                pad_inches=0,
                dpi=200)

    plt.close()
    hdulist.close()
    del imgdat

    append_website(html_file,
                   html,
                   replace_below="<H2>Flat Processing summary</H2>\n")

    return None
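A usage sketch (file names are placeholders; append_website and the diagnostics/ directory convention are assumed to be defined elsewhere in this module):

add_FlatSummary(['flat1.fits', 'flat2.fits', 'flat3.fits'],
                'masterflat.fits', 'diagnostics.html',
                imagestretch='log')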
Example #20
# %%
psf = aiapy.psf.psf(m.wavelength)

# %% [markdown]
# We'll plot just a 500-by-500 pixel section centered on the center pixel. The
# diffraction "arms" extending from the center pixel can often be seen in
# flare observations due to the intense, small-scale brightening.
#
#

# %%
fov = 500
lc_x, lc_y = psf.shape[0] // 2 - fov // 2, psf.shape[1] // 2 - fov // 2
plt.imshow(psf[lc_x:lc_x + fov, lc_y:lc_y + fov],
           norm=ImageNormalize(vmin=1e-8, vmax=1e-3, stretch=LogStretch()))
plt.colorbar()
plt.show()

# %% [markdown]
# Now that we've downloaded our image and computed the PSF, we can deconvolve
# the image with the PSF using the
# `Richardson-Lucy deconvolution algorithm <https://en.wikipedia.org/wiki/Richardson%E2%80%93Lucy_deconvolution>`_.
# Note that passing in the PSF is optional. If you exclude it, it will be
# calculated automatically. However, when deconvolving many images of the same
# wavelength, it is most efficient to only calculate the PSF once.
#
# As with `~aiapy.psf.psf`, this will be much faster if you have
# a GPU and `cupy` installed.
#
#
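# %% [markdown]
# The deconvolution step described above is sketched in the next cell;
# it is an illustrative addition (not part of the original snippet),
# reusing the PSF computed earlier so it is not recalculated.

# %%
m_deconvolved = aiapy.psf.deconvolve(m, psf=psf)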
Example #21
 if args.wcs == 'True' or args.wcs == 'yes':
     fig, ax = plt.subplots(figsize=(7, 7))
     ax = plt.subplot(projection=wcs_object)
     gaia = Irsa.query_region(SkyCoord(hdr[0].header['CRVAL1'] * u.deg,
                                       hdr[0].header['CRVAL2'] * u.deg,
                                       frame='fk5'),
                              catalog="gaia_dr2_source",
                              spatial="Cone",
                              radius=4 * u.arcmin)
     if len(gaia) == 0:
         log.info('No GAIA stars found within 4 arcmin for starlist.')
         plt.close()
     else:
         ax.imshow(sci_med,
                   cmap='gray',
                   norm=ImageNormalize(sci_med, interval=ZScaleInterval()))
         _, median, std = sigma_clipped_stats(sci_med, sigma=3.0)
         daofind = DAOStarFinder(fwhm=7.0, threshold=5. * std)
         sources = daofind(np.asarray(sci_med))
         for l, m in enumerate(gaia['source_id']):
             x, y = (wcs.WCS(hdr[0].header)).all_world2pix(
                 gaia['ra'][l], gaia['dec'][l], 1)
             ax.add_patch(
                 patches.Circle((x, y),
                                radius=4,
                                edgecolor='g',
                                alpha=0.5,
                                facecolor='none',
                                linewidth=2,
                                label='Gaia star: RA = %f, Dec = %f' %
                                 (gaia['ra'][l], gaia['dec'][l])))
Example #22
def stack(files: list,
          output: str = None,
          directory: str = '',
          stack_type: str = 'median',
          inherit: bool = True,
          show: bool = False,
          normalise: bool = False):
    accepted_stack_types = ['mean', 'median', 'add']
    if stack_type not in accepted_stack_types:
        raise ValueError('stack_type must be in ' + str(accepted_stack_types))
    if directory != '' and directory[-1] != '/':
        directory = directory + '/'

    data = []
    template = None

    print('Stacking:')

    for f in files:
        # Extract image data and append to list.
        if type(f) is str:
            f = u.sanitise_file_ext(f, 'fits')
            print(' ' + f)
            f = fits.open(directory + f)

        if type(f) is fits.hdu.hdulist.HDUList:
            data_append = f[0].data
            if template is None and inherit:
                # Get a template file to use for output.
                # TODO: Refine this - keep a record of which files went in in the header.
                template = f.copy()
        elif type(f) is CCDData:
            data_append = f.data

        else:
            raise TypeError(
                'files must contain only strings, HDUList or CCDData objects.')

        if normalise:
            data_append = data_append / np.nanmedian(
                data_append[np.isfinite(data_append)])
        if show:
            norm = ImageNormalize(data_append,
                                  interval=ZScaleInterval(),
                                  stretch=SqrtStretch())
            plt.imshow(data_append, origin='lower', norm=norm)
            plt.show()

        data.append(data_append)

    data = np.array(data)
    if stack_type == 'mean':
        stacked = np.mean(data, axis=0)
    elif stack_type == 'median':
        stacked = np.median(data, axis=0)
    else:
        stacked = np.sum(data, axis=0)

    if show:
        norm = ImageNormalize(stacked,
                              interval=ZScaleInterval(),
                              stretch=SqrtStretch())
        plt.imshow(stacked, origin='lower', norm=norm)
        plt.show()

    if inherit:
        template[0].data = stacked
    else:
        template = fits.PrimaryHDU(stacked)
        template = fits.HDUList([template])
    add_log(template, 'Stacked.')
    if output is not None:
        print('Writing stacked image to', output)
        template[0].header['BZERO'] = 0
        template.writeto(output, overwrite=True, output_verify='warn')

    return template
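A usage sketch (frame names and directory are placeholders):

stacked = stack(['frame1.fits', 'frame2.fits', 'frame3.fits'],
                output='stacked.fits', directory='reduced',
                stack_type='median', normalise=True)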
Example #23
def fits_finder_chart(
        fitsfile,
        outfile,
        fitsext=0,
        wcsfrom=None,
        scale=ZScaleInterval(),
        stretch=LinearStretch(),
        colormap=plt.cm.gray_r,
        findersize=None,
        finder_coordlimits=None,
        overlay_ra=None,
        overlay_decl=None,
        overlay_pltopts={'marker':'o',
                         'markersize':10.0,
                         'markerfacecolor':'none',
                         'markeredgewidth':2.0,
                         'markeredgecolor':'red'},
        overlay_zoomcontain=False,
        grid=False,
        gridcolor='k'
):
    '''This makes a finder chart for a given FITS image with an optional
    object position overlay.

    Parameters
    ----------

    fitsfile : str
        `fitsfile` is the FITS file to use to make the finder chart.

    outfile : str
        `outfile` is the name of the output file. This can be a png or pdf or
        whatever else matplotlib can write given a filename and extension.

    fitsext : int
        Sets the FITS extension in `fitsfile` to use to extract the image array
        from.

    wcsfrom : str or None
        If `wcsfrom` is None, the WCS to transform the RA/Dec to pixel x/y will
        be taken from the FITS header of `fitsfile`. If this is not None, it
        must be a FITS or similar file that contains a WCS header in its first
        extension.

    scale : astropy.visualization.Interval object
        `scale` sets the normalization for the FITS pixel values. This is an
        astropy.visualization Interval object.
        See http://docs.astropy.org/en/stable/visualization/normalization.html
        for details on `scale` and `stretch` objects.

    stretch : astropy.visualization.Stretch object
        `stretch` sets the stretch function for mapping FITS pixel values to
        output pixel values. This is an astropy.visualization Stretch object.
        See http://docs.astropy.org/en/stable/visualization/normalization.html
        for details on `scale` and `stretch` objects.

    colormap : matplotlib Colormap object
        `colormap` is a matplotlib color map object to use for the output image.

    findersize : None or tuple of two ints
        If `findersize` is None, the output image size will be set by the NAXIS1
        and NAXIS2 keywords in the input `fitsfile` FITS header. Otherwise,
        `findersize` must be a tuple with the intended x and y size of the image
        in inches (all output images will use a DPI = 100).

    finder_coordlimits : list of four floats or None
        If not None, `finder_coordlimits` sets x and y limits for the plot,
        effectively zooming it in if these are smaller than the dimensions of
        the FITS image. This should be a list of the form: [minra, maxra,
        mindecl, maxdecl] all in decimal degrees.

    overlay_ra, overlay_decl : np.array or None
        `overlay_ra` and `overlay_decl` are ndarrays containing the RA and Dec
        values to overplot on the image as an overlay. If these are both None,
        then no overlay will be plotted.

    overlay_pltopts : dict
        `overlay_pltopts` controls how the overlay points will be plotted. This
        is a dict with standard matplotlib marker, etc. kwargs as key-val pairs,
        e.g. 'markersize', 'markerfacecolor', etc. The default options make red
        outline circles at the location of each object in the overlay.

    overlay_zoomcontain : bool
        `overlay_zoomcontain` controls if the finder chart will be zoomed to
        just contain the overlaid points. Everything outside the footprint of
        these points will be discarded.

    grid : bool
        `grid` sets if a grid will be made on the output image.

    gridcolor : str
        `gridcolor` sets the color of the grid lines. This is a usual matplotlib
        color spec string.

    Returns
    -------

    str or None
        The filename of the generated output image if successful. None
        otherwise.

    '''

    # read in the FITS file; the image and header are needed in all cases
    hdulist = pyfits.open(fitsfile)
    img, hdr = hdulist[fitsext].data, hdulist[fitsext].header
    hdulist.close()

    frameshape = (hdr['NAXIS1'], hdr['NAXIS2'])

    # get the WCS either from the FITS header itself or from a separate file
    if wcsfrom is None:
        w = WCS(hdr)
    elif os.path.exists(wcsfrom):
        w = WCS(wcsfrom)
    else:
        LOGERROR('could not determine WCS info for input FITS: %s' %
                 fitsfile)
        return None

    # use the frame shape to set the output PNG's dimensions
    if findersize is None:
        fig = plt.figure(figsize=(frameshape[0]/100.0,
                                  frameshape[1]/100.0))
    else:
        fig = plt.figure(figsize=findersize)

    # set the coord limits if zoomcontain is True
    # we'll leave 30 arcseconds of padding on each side
    if (overlay_zoomcontain and
        overlay_ra is not None and
        overlay_decl is not None):

        finder_coordlimits = [overlay_ra.min()-30.0/3600.0,
                              overlay_ra.max()+30.0/3600.0,
                              overlay_decl.min()-30.0/3600.0,
                              overlay_decl.max()+30.0/3600.0]

    # set the coordinate limits if provided
    if finder_coordlimits and isinstance(finder_coordlimits, (list,tuple)):

        minra, maxra, mindecl, maxdecl = finder_coordlimits
        cntra, cntdecl = (minra + maxra)/2.0, (mindecl + maxdecl)/2.0

        pixelcoords = w.all_world2pix([[minra, mindecl],
                                       [maxra, maxdecl],
                                       [cntra, cntdecl]],1)
        x1, y1, x2, y2 = (int(pixelcoords[0,0]),
                          int(pixelcoords[0,1]),
                          int(pixelcoords[1,0]),
                          int(pixelcoords[1,1]))

        xmin, xmax = min(x1, x2), max(x1, x2)
        ymin, ymax = min(y1, y2), max(y1, y2)

        # create a new WCS with the same transform but new center coordinates
        whdr = w.to_header()
        whdr['CRPIX1'] = (xmax - xmin)/2
        whdr['CRPIX2'] = (ymax - ymin)/2
        whdr['CRVAL1'] = cntra
        whdr['CRVAL2'] = cntdecl
        whdr['NAXIS1'] = xmax - xmin
        whdr['NAXIS2'] = ymax - ymin
        w = WCS(whdr)

    else:
        # numpy image columns are x (NAXIS1) and rows are y (NAXIS2)
        xmin, xmax, ymin, ymax = 0, hdr['NAXIS1'], 0, hdr['NAXIS2']

    # add the axes with the WCS projection
    # this should automatically handle subimages because we fix the WCS
    # appropriately above for these
    fig.add_subplot(111,projection=w)

    if scale is not None and stretch is not None:

        norm = ImageNormalize(img,
                              interval=scale,
                              stretch=stretch)

        plt.imshow(img[ymin:ymax,xmin:xmax],
                   origin='lower',
                   cmap=colormap,
                   norm=norm)

    else:

        plt.imshow(img[ymin:ymax,xmin:xmax],
                   origin='lower',
                   cmap=colormap)

    # handle additional options
    if grid:
        plt.grid(color=gridcolor,ls='solid',lw=1.0)

    # handle the object overlay
    if overlay_ra is not None and overlay_decl is not None:

        our_pltopts = dict(
            transform=plt.gca().get_transform('fk5'),
            marker='o',
            markersize=10.0,
            markerfacecolor='none',
            markeredgewidth=2.0,
            markeredgecolor='red',
            rasterized=True,
            linestyle='none'
        )
        if overlay_pltopts is not None and isinstance(overlay_pltopts,
                                                      dict):
            our_pltopts.update(overlay_pltopts)

        plt.gca().set_autoscale_on(False)
        plt.gca().plot(overlay_ra, overlay_decl,
                       **our_pltopts)

    plt.xlabel('Right Ascension [deg]')
    plt.ylabel('Declination [deg]')

    # get the x and y axes objects to fix the ticks
    xax = plt.gca().coords[0]
    yax = plt.gca().coords[1]

    yax.set_major_formatter('d.ddd')
    xax.set_major_formatter('d.ddd')

    # save the figure
    plt.savefig(outfile, dpi=100.0)
    plt.close('all')

    return outfile
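A usage sketch for fits_finder_chart(), assuming a WCS-calibrated image field.fits; the file names and coordinates are placeholders:

import numpy as np

# Hypothetical overlay targets (decimal degrees).
target_ra = np.array([150.1, 150.2])
target_decl = np.array([2.2, 2.3])

fits_finder_chart('field.fits',
                  'field-finder.png',
                  overlay_ra=target_ra,
                  overlay_decl=target_decl,
                  overlay_zoomcontain=True,
                  grid=True)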
Example #24
    async def get(self, object_id: str = None):
        """
        ---
        summary: Serve alert cutout as fits or png
        tags:
          - alerts
          - kowalski

        parameters:
          - in: query
            name: instrument
            required: false
            schema:
              type: str
          - in: query
            name: candid
            description: "ZTF alert candid"
            required: true
            schema:
              type: integer
          - in: query
            name: cutout
            description: "retrieve science, template, or difference cutout image?"
            required: true
            schema:
              type: string
              enum: [science, template, difference]
          - in: query
            name: file_format
            description: "response file format: original loss-less FITS or rendered png"
            required: true
            default: png
            schema:
              type: string
              enum: [fits, png]
          - in: query
            name: interval
            description: "Interval to use when rendering png"
            required: false
            schema:
              type: string
              enum: [min_max, zscale]
          - in: query
            name: stretch
            description: "Stretch to use when rendering png"
            required: false
            schema:
              type: string
              enum: [linear, log, asinh, sqrt]
          - in: query
            name: cmap
            description: "Color map to use when rendering png"
            required: false
            schema:
              type: string
              enum: [bone, gray, cividis, viridis, magma]

        responses:
          '200':
            description: retrieved cutout
            content:
              image/fits:
                schema:
                  type: string
                  format: binary
              image/png:
                schema:
                  type: string
                  format: binary

          '400':
            description: retrieval failed
            content:
              application/json:
                schema: Error
        """
        instrument = self.get_query_argument("instrument", "ZTF").upper()
        if instrument not in INSTRUMENTS:
            raise ValueError("Instrument name not recognised")

        # allow access to public data only by default
        selector = {1}

        for stream in self.associated_user_object.streams:
            if "ztf" in stream.name.lower():
                selector.update(set(stream.altdata.get("selector", [])))

        selector = list(selector)

        try:
            candid = int(self.get_argument("candid"))
            cutout = self.get_argument("cutout").capitalize()
            file_format = self.get_argument("file_format", "png").lower()
            interval = self.get_argument("interval", default=None)
            stretch = self.get_argument("stretch", default=None)
            cmap = self.get_argument("cmap", default=None)

            known_cutouts = ["Science", "Template", "Difference"]
            if cutout not in known_cutouts:
                return self.error(
                    f"Cutout {cutout} of {object_id}/{candid} not in {str(known_cutouts)}"
                )
            known_file_formats = ["fits", "png"]
            if file_format not in known_file_formats:
                return self.error(
                    f"File format {file_format} of {object_id}/{candid}/{cutout} not in {str(known_file_formats)}"
                )

            normalization_methods = {
                "asymmetric_percentile": AsymmetricPercentileInterval(
                    lower_percentile=1, upper_percentile=100
                ),
                "min_max": MinMaxInterval(),
                "zscale": ZScaleInterval(nsamples=600, contrast=0.045, krej=2.5),
            }
            if interval is None:
                interval = "asymmetric_percentile"
            normalizer = normalization_methods.get(
                interval.lower(),
                AsymmetricPercentileInterval(lower_percentile=1, upper_percentile=100),
            )

            stretching_methods = {
                "linear": LinearStretch,
                "log": LogStretch,
                "asinh": AsinhStretch,
                "sqrt": SqrtStretch,
            }
            if stretch is None:
                stretch = "log" if cutout != "Difference" else "linear"
            stretcher = stretching_methods.get(stretch.lower(), LogStretch)()

            if (cmap is None) or (
                cmap.lower() not in ["bone", "gray", "cividis", "viridis", "magma"]
            ):
                cmap = "bone"
            else:
                cmap = cmap.lower()

            query = {
                "query_type": "find",
                "query": {
                    "catalog": "ZTF_alerts",
                    "filter": {
                        "candid": candid,
                        "candidate.programid": {"$in": selector},
                    },
                    "projection": {"_id": 0, f"cutout{cutout}": 1},
                },
                "kwargs": {"limit": 1, "max_time_ms": 5000},
            }

            response = kowalski.query(query=query)

            if response.get("status", "error") == "success":
                alert = response.get("data", [dict()])[0]
            else:
                return self.error("No cutout found.")

            cutout_data = bj.loads(bj.dumps([alert[f"cutout{cutout}"]["stampData"]]))[0]

            # unzipped fits name
            fits_name = pathlib.Path(alert[f"cutout{cutout}"]["fileName"]).with_suffix(
                ""
            )

            # unzip and flip about y axis on the server side
            with gzip.open(io.BytesIO(cutout_data), "rb") as f:
                with fits.open(io.BytesIO(f.read())) as hdu:
                    header = hdu[0].header
                    data_flipped_y = np.flipud(hdu[0].data)

            if file_format == "fits":
                hdu = fits.PrimaryHDU(data_flipped_y, header=header)
                hdul = fits.HDUList([hdu])

                stamp_fits = io.BytesIO()
                hdul.writeto(fileobj=stamp_fits)

                self.set_header("Content-Type", "image/fits")
                self.set_header(
                    "Content-Disposition", f"Attachment;filename={fits_name}"
                )
                self.write(stamp_fits.getvalue())

            if file_format == "png":
                buff = io.BytesIO()
                plt.close("all")

                fig, ax = plt.subplots(figsize=(4, 4))
                fig.subplots_adjust(0, 0, 1, 1)
                ax.set_axis_off()

                img = np.array(data_flipped_y)
                # flag dubiously large values as NaN
                xl = np.greater(np.abs(img), 1e20, where=~np.isnan(img),
                                out=np.zeros(img.shape, dtype=bool))
                if img[xl].any():
                    img[xl] = np.nan
                # replace any NaNs with the image median
                if np.isnan(img).any():
                    median = float(np.nanmedian(img))
                    img = np.nan_to_num(img, nan=median)
                norm = ImageNormalize(img, stretch=stretcher)
                img_norm = norm(img)
                vmin, vmax = normalizer.get_limits(img_norm)
                ax.imshow(img_norm, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax)
                plt.savefig(buff, dpi=42)
                buff.seek(0)
                plt.close("all")
                self.set_header("Content-Type", "image/png")
                self.write(buff.getvalue())

        except Exception:
            _err = traceback.format_exc()
            return self.error(f"failure: {_err}")
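The PNG branch above stretches the pixel values first (via ImageNormalize) and only then derives the display limits from the chosen interval on the stretched image. A standalone sketch of the same recipe, with a random array standing in for a real cutout stamp:

import io

import matplotlib.pyplot as plt
import numpy as np
from astropy.visualization import (AsymmetricPercentileInterval, ImageNormalize,
                                   LogStretch)

img = np.random.lognormal(size=(63, 63))  # placeholder for an alert cutout

# Stretch first, then clip the display range on the stretched image.
norm = ImageNormalize(img, stretch=LogStretch())
img_norm = norm(img)
vmin, vmax = AsymmetricPercentileInterval(1, 100).get_limits(img_norm)

fig, ax = plt.subplots(figsize=(4, 4))
ax.set_axis_off()
ax.imshow(img_norm, cmap='bone', origin='lower', vmin=vmin, vmax=vmax)
buff = io.BytesIO()
fig.savefig(buff, dpi=42)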
Example #25
result = Observations.query_object('M83')
selected_bands = result[(result['obs_collection'] == 'HST')
                        & (result['instrument_name'] == 'WFC3/UVIS') &
                        ((result['filters'] == 'F657N') |
                         (result['filters'] == 'F487N') |
                         (result['filters'] == 'F336W')) &
                        (result['target_name'] == 'MESSIER-083')]
prodlist = Observations.get_product_list(selected_bands)
filtered_prodlist = Observations.filter_products(prodlist)

downloaded = Observations.download_products(filtered_prodlist)

blue = fits.open(downloaded['Local Path'][2])
red = fits.open(downloaded['Local Path'][5])
green = fits.open(downloaded['Local Path'][8])

target_header = red['SCI'].header
green_repr, _ = reproject.reproject_interp(green['SCI'], target_header)
blue_repr, _ = reproject.reproject_interp(blue['SCI'], target_header)

rgb_img = make_lupton_rgb(
    ImageNormalize(vmin=0, vmax=1)(red['SCI'].data),
    ImageNormalize(vmin=0, vmax=0.3)(green_repr),
    ImageNormalize(vmin=0, vmax=1)(blue_repr),
    stretch=0.1,
    minimum=0,
)

plt.imshow(rgb_img, origin='lower', interpolation='none')
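Note the design choice above: the green and blue frames are reprojected onto the red frame's WCS (target_header) so that all three bands are pixel-aligned before being normalised and composited by make_lupton_rgb.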
Example #26
def _make_pretty_from_fits(fname=None,
                           title=None,
                           figsize=(10, 10 / 1.325),
                           dpi=150,
                           alpha=0.2,
                           number_ticks=7,
                           clip_percent=99.9,
                           **kwargs):

    with open_fits(fname) as hdu:
        header = hdu[0].header
        data = hdu[0].data
        data = focus_utils.mask_saturated(data)
        wcs = WCS(header)

    if not title:
        field = header.get('FIELD', 'Unknown field')
        exptime = header.get('EXPTIME', 'Unknown exptime')
        filter_type = header.get('FILTER', 'Unknown filter')

        try:
            date_time = header['DATE-OBS']
        except KeyError:
            # If we don't have DATE-OBS, check filename for date
            try:
                basename = os.path.splitext(os.path.basename(fname))[0]
                date_time = date_parser.parse(basename).isoformat()
            except Exception:
                # Otherwise use now
                date_time = current_time(pretty=True)

        date_time = date_time.replace('T', ' ', 1)

        title = '{} ({}s {}) {}'.format(field, exptime, filter_type, date_time)

    norm = ImageNormalize(interval=PercentileInterval(clip_percent), stretch=LogStretch())

    fig = Figure()
    FigureCanvas(fig)
    fig.set_size_inches(*figsize)
    fig.dpi = dpi

    if wcs.is_celestial:
        ax = fig.add_subplot(1, 1, 1, projection=wcs)
        ax.coords.grid(True, color='white', ls='-', alpha=alpha)

        ra_axis = ax.coords['ra']
        ra_axis.set_axislabel('Right Ascension')
        ra_axis.set_major_formatter('hh:mm')
        ra_axis.set_ticks(
            number=number_ticks,
            color='white',
            exclude_overlapping=True
        )

        dec_axis = ax.coords['dec']
        dec_axis.set_axislabel('Declination')
        dec_axis.set_major_formatter('dd:mm')
        dec_axis.set_ticks(
            number=number_ticks,
            color='white',
            exclude_overlapping=True
        )
    else:
        ax = fig.add_subplot(111)
        ax.grid(True, color='white', ls='-', alpha=alpha)

        ax.set_xlabel('X / pixels')
        ax.set_ylabel('Y / pixels')

    im = ax.imshow(data, norm=norm, cmap=palette, origin='lower')  # palette: module-level colormap
    fig.colorbar(im)
    fig.suptitle(title)

    new_filename = fname.replace('.fits', '.jpg')
    fig.savefig(new_filename, bbox_inches='tight')

    # explicitly close and delete figure
    fig.clf()
    del fig

    return new_filename
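A usage sketch, assuming a plate-solved image observation.fits on disk (the filename is a placeholder):

# Hypothetical call; returns the path of the rendered JPEG.
jpg_path = _make_pretty_from_fits('observation.fits', clip_percent=99.5)
print(jpg_path)  # e.g. observation.jpg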
Example #27
def show_stamps(pscs,
                frame_idx=None,
                stamp_size=11,
                aperture_position=None,
                show_residual=False,
                stretch=None,
                save_name=None,
                show_max=False,
                show_pixel_grid=False,
                **kwargs):

    if aperture_position is None:
        midpoint = (stamp_size - 1) / 2
        aperture_position = (midpoint, midpoint)

    ncols = len(pscs)

    if show_residual:
        ncols += 1

    nrows = 1

    fig = Figure()
    FigureCanvas(fig)
    fig.set_figheight(4)
    fig.set_figwidth(8)

    if frame_idx is not None:
        s0 = pscs[0][frame_idx]
        s1 = pscs[1][frame_idx]
    else:
        s0 = pscs[0]
        s1 = pscs[1]

    if stretch == 'log':
        stretch = LogStretch()
    else:
        stretch = LinearStretch()

    norm = ImageNormalize(s0, interval=MinMaxInterval(), stretch=stretch)

    ax1 = fig.add_subplot(nrows, ncols, 1)

    im = ax1.imshow(s0, cmap=get_palette(), norm=norm)

    # create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
    # https://stackoverflow.com/questions/18195758/set-matplotlib-colorbar-size-to-match-graph
    divider = make_axes_locatable(ax1)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    ax1.set_title('Target')

    # Comparison
    ax2 = fig.add_subplot(nrows, ncols, 2)
    im = ax2.imshow(s1, cmap=get_palette(), norm=norm)

    divider = make_axes_locatable(ax2)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    ax2.set_title('Comparison')

    if show_pixel_grid:
        add_pixel_grid(ax1, stamp_size, stamp_size, show_superpixel=False)
        add_pixel_grid(ax2, stamp_size, stamp_size, show_superpixel=False)

    if show_residual:
        ax3 = fig.add_subplot(nrows, ncols, 3)

        # Residual
        residual = s0 - s1
        im = ax3.imshow(residual,
                        cmap=get_palette(),
                        norm=ImageNormalize(residual,
                                            interval=MinMaxInterval(),
                                            stretch=LinearStretch()))

        divider = make_axes_locatable(ax3)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax)
        ax3.set_title('Noise Residual (RMS: {:.01%})'.format(residual.std()))
        ax3.set_yticklabels([])
        ax3.set_xticklabels([])

        if show_pixel_grid:
            add_pixel_grid(ax3, stamp_size, stamp_size, show_superpixel=False)

    # Turn off tick labels
    ax1.set_yticklabels([])
    ax1.set_xticklabels([])
    ax2.set_yticklabels([])
    ax2.set_xticklabels([])

    if save_name:
        try:
            fig.savefig(save_name)
        except Exception as e:
            warn("Can't save figure: {}".format(e))

    return fig
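A usage sketch, assuming target_psc and comparison_psc are stamp cubes of shape (n_frames, stamp_size, stamp_size); the names are placeholders:

# Hypothetical call comparing frame 0 of two stamp cubes.
fig = show_stamps([target_psc, comparison_psc],
                  frame_idx=0,
                  stretch='log',
                  show_residual=True,
                  save_name='stamps.png')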
Example #28
def pathtest(step_input_filename,
             reffile,
             comparison_filename,
             writefile=True,
             show_figs=True,
             save_figs=False,
             threshold_diff=1.0e-7,
             debug=False):
    """
    This function calculates the difference between the pipeline and the
    independently calculated pathloss values. It uses the output of the
    sourcetype step.
    Args:
        step_input_filename: str, name of sourcetype step output fits file
        reffile: str, path to the pathloss MSA UNI reference fits files
        comparison_filename: str, path to comparison pipeline pathloss file
        writefile: boolean, if True writes fits files of calculated pathloss
                   and difference images
        show_figs: boolean, whether to show plots or not
        save_figs: boolean, save the plots
        threshold_diff: float, threshold difference between pipeline output
                        & ESA file
        debug: boolean, if true print statements will show on-screen
    Returns:
        - 1 plot, if told to save and/or show them.
        - median_diff: Boolean, True if smaller or equal to threshold.
        - log_msgs: list, all print statements are captured in this variable
    """

    log_msgs = []

    # start the timer
    pathtest_start_time = time.time()

    # get info from the rate file header
    det = fits.getval(step_input_filename, "DETECTOR", 0)
    msg = 'step_input_filename=' + step_input_filename
    print(msg)
    log_msgs.append(msg)
    exptype = fits.getval(step_input_filename, "EXP_TYPE", 0)
    grat = fits.getval(step_input_filename, "GRATING", 0)
    filt = fits.getval(step_input_filename, "FILTER", 0)
    # aper = fits.getval(step_input_filename[1], "SLTNAME", 0)

    msg = "path_loss_file:  Grating:" + grat + " Filter:" + filt + " EXP_TYPE:" + exptype
    print(msg)
    log_msgs.append(msg)

    # get the reference files
    msg = "Using reference file: " + reffile
    print(msg)
    log_msgs.append(msg)

    if writefile:
        # create the fits list to hold the calculated flat values for each slit
        hdu0 = fits.PrimaryHDU()
        outfile = fits.HDUList()
        outfile.append(hdu0)

        # create fits list to hold image of pipeline-calculated diff values
        hdu0 = fits.PrimaryHDU()
        compfile = fits.HDUList()
        compfile.append(hdu0)

    # list to determine if pytest is passed or not
    total_test_result = []

    print("""Checking if files exist and obtaining datamodels.
             This takes a few minutes...""")
    if os.path.isfile(comparison_filename):
        if debug:
            print('Comparison file does exist.')
    else:
        result_msg = 'Comparison file does NOT exist. Skipping pathloss test.'
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    # get the comparison data model
    pathloss_pipe = datamodels.open(comparison_filename)
    if debug:
        print('got comparison datamodel!')

    if os.path.isfile(step_input_filename):
        if debug:
            print('Input file does exist.')
    else:
        result_msg = 'Input file does NOT exist. Skipping pathloss test.'
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    # get the input data model
    pl = datamodels.open(step_input_filename)
    if debug:
        print('got input datamodel!')

    # loop through the slits
    msg = " Looping through the slits... "
    print(msg)
    log_msgs.append(msg)

    slit_val = 0
    for slit, pipe_slit in zip(pl.slits, pathloss_pipe.slits):
        slit_val = slit_val + 1

        mode = "MOS"

        is_point_source = False

        print("Retrieving extensions")
        ps_uni_ext_list = get_mos_ps_uni_extensions(reffile, is_point_source)

        slit_id = slit.name
        try:
            nshutters = util.get_num_msa_open_shutters(slit.shutter_state)
            if is_point_source:
                if nshutters == 3:
                    shutter_key = "MOS1x3"
                elif nshutters == 1:
                    shutter_key = "MOS1x1"
                ext = ps_uni_ext_list[0][shutter_key]
                print("Retrieved point source extension")
            if is_point_source is False:
                if nshutters == 1:
                    shutter_key = "MOS1x1"
                elif nshutters == 3:
                    shutter_key = "MOS1x3"
                ext = ps_uni_ext_list[1][shutter_key]
                print("Retrieved extended source extension {}".format(ext))
        except KeyError:
            print("Unable to retrieve extension. Using ext 3, but may be 7")
            ext = 3

        wcs_obj = slit.meta.wcs

        x, y = wcstools.grid_from_bounding_box(wcs_obj.bounding_box,
                                               step=(1, 1),
                                               center=True)
        ra, dec, wave = slit.meta.wcs(x, y)
        wave_sci = wave * 10**(-6)  # microns --> meters

        plcor_ref_ext = fits.getdata(reffile, ext)
        print("plcor_ref_ext.shape", plcor_ref_ext.shape)

        hdul = fits.open(reffile)

        plcor_ref = hdul[1].data
        w = wcs.WCS(hdul[1].header)

        w1, y1, x1 = np.mgrid[:plcor_ref.shape[0],
                              :plcor_ref.shape[1],
                              :plcor_ref.shape[2]]
        slitx_ref, slity_ref, wave_ref = w.all_pix2world(x1, y1, w1, 0)

        comp_sci = pipe_slit.data
        previous_sci = slit.data

        pipe_correction = pipe_slit.pathloss

        # set up generals for all the plots
        font = {'weight': 'normal', 'size': 10}
        matplotlib.rc('font', **font)

        corr_vals = np.interp(wave_sci, wave_ref[:, 0, 0], plcor_ref_ext)
        corrected_array = previous_sci / corr_vals

        # plots:
        step_input_filepath = step_input_filename.replace(".fits", "")
        # my correction values
        fig = plt.figure()
        ax = plt.gca()
        ax.get_xaxis().get_major_formatter().set_useOffset(False)
        ax.get_xaxis().get_major_formatter().set_scientific(False)
        # calculated correction values
        plt.subplot(221)
        norm = ImageNormalize(corr_vals)
        plt.imshow(corr_vals,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('x in pixels')
        plt.ylabel('y in pixels')
        plt.title('Calculated Correction')
        plt.colorbar()
        # pipeline correction values
        plt.subplot(222)
        norm = ImageNormalize(pipe_correction)
        plt.imshow(pipe_correction,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('x in pixels')
        plt.ylabel('y in pixels')
        plt.title('Pipeline Correction')
        plt.colorbar()
        # residuals (pipe correction - my correction)
        corr_residuals = pipe_correction - corr_vals
        plt.subplot(223)
        norm = ImageNormalize(corr_residuals)
        plt.ticklabel_format(useOffset=False)
        plt.imshow(corr_residuals,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.xlabel('x in pixels')
        plt.ylabel('y in pixels')
        plt.title('Correction residuals')
        plt.colorbar()
        # my science data after pathloss
        plt.subplot(224)
        norm = ImageNormalize(corrected_array)
        plt.imshow(corrected_array,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.title('Corrected Data After Pathloss')
        plt.xlabel('x in pixels')
        plt.ylabel('y in pixels')
        plt.colorbar()
        fig.suptitle("MOS UNI Pathloss Calibration Testing")

        # save before showing, since show() can leave an empty canvas behind
        if save_figs:
            plt_name = step_input_filepath + "Pathloss_test_slitlet_" + str(
                mode) + "_UNI_" + str(slit_id) + ".png"
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_figs:
            plt.show()
        plt.close()

        ax = plt.subplot(212)
        plt.hist(corr_residuals[~np.isnan(corr_residuals)],
                 bins=100,
                 range=(-0.00000013, 0.00000013))
        plt.title('Residuals Histogram')
        plt.xlabel("Correction Value")
        plt.ylabel("Number of Occurences")
        nanind = np.isnan(corr_residuals)  # get all the nan indexes
        notnan = ~nanind  # get all the not-nan indexes
        arr_mean = np.mean(corr_residuals[notnan])
        arr_median = np.median(corr_residuals[notnan])
        arr_stddev = np.std(corr_residuals[notnan])
        plt.axvline(arr_mean, label="mean = %0.3e" % (arr_mean), color="g")
        plt.axvline(arr_median,
                    label="median = %0.3e" % (arr_median),
                    linestyle="-.",
                    color="b")
        str_arr_stddev = "stddev = {:0.3e}".format(arr_stddev)
        ax.text(0.73,
                0.67,
                str_arr_stddev,
                transform=ax.transAxes,
                fontsize=16)
        plt.legend()
        plt.minorticks_on()

        # Show and/or save figures
        if save_figs:
            plt_name = step_input_filepath + "Pathlosstest_MOS_UNI_slitlet_" + slit_id + ".png"
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_figs:
            plt.show()
        elif not save_figs:
            # neither showing nor saving: note it in the log
            msg = "Not making plots because both show_figs and save_figs were set to False."
            if debug:
                print(msg)
            log_msgs.append(msg)
        plt.close()

        # create fits file to hold the calculated pathloss for each slit
        if writefile:
            msg = "Saving the fits files with the calculated pathloss for each slit..."
            print(msg)
            log_msgs.append(msg)

            # this is the file to hold the image of pipeline-calculated difference values
            outfile_ext = fits.ImageHDU(corr_vals, name=slit_id)
            outfile.append(outfile_ext)

            # this is the file to hold the image of pipeline-calculated difference values
            compfile_ext = fits.ImageHDU(corr_residuals, name=slit_id)
            compfile.append(compfile_ext)

        if corr_residuals[~np.isnan(corr_residuals)].size == 0:
            msg1 = """Unable to calculate statistics because difference
                    array has all values as NaN.
                    Test will be set to FAILED."""
            print(msg1)
            log_msgs.append(msg1)
            test_result = "FAILED"
        else:
            msg = "Calculating statistics... "
            print(msg)
            log_msgs.append(msg)
            # ignore outliers:
            corr_residuals = corr_residuals[np.where((corr_residuals != 999.0)
                                                     & (corr_residuals < 0.1)
                                                     &
                                                     (corr_residuals > -0.1))]
            if corr_residuals.size == 0:
                msg1 = """ * Unable to calculate statistics because
                       difference array has all outlier values.
                       Test will be set to FAILED."""
                print(msg1)
                log_msgs.append(msg1)
                test_result = "FAILED"
            else:
                stats_and_strings = auxfunc.print_stats(corr_residuals,
                                                        "Difference",
                                                        float(threshold_diff),
                                                        abs=True)
                stats, stats_print_strings = stats_and_strings
                corr_residuals_mean, corr_residuals_median, corr_residuals_std = stats
                for msg in stats_print_strings:
                    log_msgs.append(msg)

                # This is the key argument for the assert pytest function
                median_diff = False
                if abs(corr_residuals_median) <= float(threshold_diff):
                    median_diff = True
                if median_diff:
                    test_result = "PASSED"
                else:
                    test_result = "FAILED"

            msg = " *** Result of the test: " + test_result + "\n"
            print(msg)
            log_msgs.append(msg)
            total_test_result.append(test_result)

    if writefile:
        outfile_name = step_input_filename.replace(
            "srctype", det + "_calculated_FS_UNI_pathloss")
        compfile_name = step_input_filename.replace(
            "srctype", det + "_comparison_FS_UNI_pathloss")

        # create the fits list to hold the calculated flat values for each slit
        outfile.writeto(outfile_name, overwrite=True)

        # this is the file to hold the image of pipeline-calculated difference values
        compfile.writeto(compfile_name, overwrite=True)

        msg = "\nFits file with calculated pathloss values of each slit saved as: "
        print(msg)
        log_msgs.append(msg)
        print(outfile_name)
        log_msgs.append(outfile_name)

        msg = "Fits file with comparison (pipeline pathloss - calculated pathloss) saved as: "
        print(msg)
        log_msgs.append(msg)
        print(compfile_name)
        log_msgs.append(compfile_name)

    # If all tests passed then pytest will be marked as PASSED, else it will be FAILED
    FINAL_TEST_RESULT = bool(total_test_result) and "FAILED" not in total_test_result

    if FINAL_TEST_RESULT:
        msg = "\n *** Final result for path_loss test will be reported as PASSED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "All slits PASSED path_loss test."
    else:
        msg = "\n *** Final result for path_loss test will be reported as FAILED *** \n"
        print(msg)
        log_msgs.append(msg)
        result_msg = "One or more slits FAILED path_loss test."

    # end the timer
    pathloss_end_time = time.time() - pathtest_start_time
    if pathloss_end_time > 60.0:
        pathloss_end_time = pathloss_end_time / 60.0  # in minutes
        pathloss_tot_time = ("* Script msa_uni.py took " +
                             repr(pathloss_end_time) + " minutes to finish.")
        if pathloss_end_time > 60.0:
            pathloss_end_time = pathloss_end_time / 60.0  # in hours
            pathloss_tot_time = ("* Script msa_uni.py took " +
                                 repr(pathloss_end_time) + " hours to finish.")
    else:
        pathloss_tot_time = ("* Script msa_uni.py took " +
                             repr(pathloss_end_time) + " seconds to finish.")
    print(pathloss_tot_time)
    log_msgs.append(pathloss_tot_time)

    return FINAL_TEST_RESULT, result_msg, log_msgs
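A usage sketch for pathtest(); all paths below are placeholders for real NIRSpec pipeline products and reference files:

# Hypothetical invocation of the MOS UNI pathloss test.
result, result_msg, log_msgs = pathtest(
    'final_output_caldet1_NRS1_srctype.fits',
    'jwst-nirspec-mos-uni-pathloss.fits',
    'final_output_caldet1_NRS1_pathloss.fits',
    writefile=True,
    show_figs=False,
    save_figs=True)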
Example #29
def main():

    #Load in file, need to change name to file desired
    flatname = 'customsci237'
    filein = fits.open("rg" + flatname + ".fits")

    #initialize data array
    data = np.zeros((12, 4176, 512))
    #fill data array
    for i in range(12):
        data[i] = filein[i + 1].data
    print(data.shape)

    #import the corresponding fiber-bundle gap mask that defines where the gaps are
    gaps = np.genfromtxt("blkmask_xS20180408S0238", dtype='int',
                         unpack=True)[2:, :]
    print(gaps.shape)

    #initialize scattered light model array
    x = np.arange(512)
    y = np.arange(4176)
    xx, yy = np.meshgrid(x, y)
    scatteredlight = np.zeros((12, 4176, 512)) * np.nan

    #Create copy of data array for later
    testdata = data[0]
    for i in range(11):
        testdata = np.hstack((testdata, data[i + 1]))
    print(testdata.shape)

    #Show image before scattered light subtraction
    norm = ImageNormalize(testdata, interval=ZScaleInterval())
    plt.figure()
    plt.imshow(testdata, norm=norm)
    plt.show()

    #Define the order of the fit in the x direction. It can be higher than the y-direction order because each gap provides many points with full coverage of the axis.
    xorder = 3

    #x position of every pixel, used for the fits below
    xv = np.arange(512 * 4)

    #initialize variables to be filled in later
    #there are 15 gaps, 3 ccds, and each ccd has 4 amps, each with a width of 512 pixels
    polyfits = np.zeros((15, 3, xorder + 1))
    ccds = np.zeros((3, 15, 512 * 4))
    yerr = np.zeros((3, 15, 512 * 4))

    #There are 15 gaps in every image used here; initialize the y position variable for these gaps
    yv = np.zeros(15)

    #Save all fits to a pdf to be examined later if needed.
    with PdfPages('guanolight_xfits.pdf') as pdf:
        #Iterate over all gaps (in this case 15)
        for i in range(len(gaps[0])):

            #Measure the y position of the gap we are on
            yv[i] = np.median(np.arange(gaps[1, i] - gaps[0, i]) + gaps[0, i])

            #There are 3 CCDs next to each other in the x axis. Iterate over to fit each separately
            for j in range(3):

                #initialize the cleaned total CCD array
                cleanx = np.zeros((4176, 512 * 4)) * np.nan

                #4 amps per CCD, each in its own extension of the fits file. Here we manually clean hot columns of pixels that are constant in all images
                #and group each CCD's amps together into one cleaned array.
                for k in range(4):

                    #get overall amp number (0-11)
                    l = j * 4 + k
                    print("amp: " + str(l))

                    #temporary data variable
                    test = data[l]

                    #upper and lower x indices that this amp covers on the CCD
                    ux, lx = 512 * (k + 1), 512 * k

                    #initialize cleaned amp data to be filled in
                    clean = np.zeros((gaps[1, i] - gaps[0, i], 512)) * np.nan

                    #All of these if, elif statements target hot pixels that are constant and are caught by the manual mask in @remove_outliers function.
                    #These pixels are removed first before removing outliers from other sources in the else.
                    if l == 0:
                        clean[:, 53:] = remove_outliers(
                            test[gaps[0, i]:gaps[1, i], 53:], 4)
                    elif l == 3:
                        clean[:, :497] = remove_outliers(
                            test[gaps[0, i]:gaps[1, i], :497], 4)
                    elif l == 4:
                        temp = test[gaps[0, i]:gaps[1, i], :]
                        temp[:, 239:243] = np.nan
                        clean[:, 16:] = remove_outliers(temp[:, 16:], 4)
                        print(clean.shape)
                    elif l == 7:
                        clean[:, :489] = remove_outliers(
                            test[gaps[0, i]:gaps[1, i], :489], 4)
                    elif l == 8:
                        clean[:, 17:] = remove_outliers(
                            test[gaps[0, i]:gaps[1, i], 17:], 4)
                    elif l == 11:
                        clean[:, :489] = remove_outliers(
                            test[gaps[0, i]:gaps[1, i], :489], 4)
                    else:
                        clean = remove_outliers(test[gaps[0, i]:gaps[1, i], :],
                                                4)

                    #input the cleaned amp data into the cleaned CCD array
                    cleanx[gaps[0, i]:gaps[1, i], lx:ux] = clean

                #condense each gap down to a single point along its y-axis extent to create a single line along x
                ynew = np.median(cleanx[gaps[0, i]:gaps[1, i], :], axis=0)

                #only fit to points that aren't nans
                mask = ~np.isnan(ynew)
                xfit = np.polyfit(xv[mask], ynew[mask], xorder)

                #save the polynomial fit parameters for later
                polyfits[i, j, :] = xfit

                #use difference from fit to actual image within the gap for error estimate
                yfit = np.poly1d(xfit)
                yerr[j, i, :] = np.ones(
                    len(xv)) * np.std(ynew[mask] - yfit(xv[mask]))

                #save fit across ccd for the y fits later
                ccds[j, i, :] = yfit(xv)
                fig, ax = plt.subplots()
                ax.plot(xv[mask], ynew[mask])
                ax.plot(xv, ynew)
                ax.plot(xv, yfit(xv))
                ax.set_ylim(-10., 30.)
                if (j == 0) & (i == 0):
                    plt.show()
                pdf.savefig(fig)
                plt.close()

    #Set y axis fit order. Have to be careful with this as we will essentially be fitting each column of pixels using 15 points
    yorder = 3

    #initialize y-axis for fitting and the resulting scattered light model
    ynew = np.arange(4176)
    modelscatlight = np.zeros((4176, 512 * 12))

    #again saving fits in pdf
    with PdfPages('guanolight_yfits.pdf') as pdf:

        #iterating over 3 CCDs
        for i in range(3):

            #iterating over the columns in groups of 4 columns median'd together along the x axis
            for j in range(512):

                #initial x-position of column to be fit
                xpos = j * 4 + i * 512 * 4

                #median 4 columns together to be fit
                z = np.median(ccds[i, :, j * 4:(j + 1) * 4], axis=-1)

                #pull errors from x-axis fits
                err = yerr[i, :, j * 4]
                # z=np.hstack((z,z[-1]))
                # print(np.median(testdata[0:10,i*512*4+j],axis=0))

                #fit along y-axis (sorry for poor variable naming). Weight each point in fit by the inverse of the error.
                zfit = np.polyfit(yv, z, yorder, w=err**(-1.))
                znew = np.poly1d(zfit)

                #save the fit for those 4 columns into the model
                modelscatlight[:, xpos:xpos + 4] = np.array(
                    [znew(ynew),
                     znew(ynew),
                     znew(ynew),
                     znew(ynew)]).T

                #plot for pdf
                fig, ax = plt.subplots()
                ax.plot(ynew, znew(ynew), 'r', linewidth=2.0)
                ax.errorbar(yv, z, yerr=err, fmt='o', color='b', ms=4.0)
                pdf.savefig(fig)

                #show first plot for sanity check
                if (i == 0) & (j == 0):
                    plt.show()
                plt.close()

    #plot total model after all fitting
    plt.figure()
    plt.imshow(modelscatlight)
    plt.show()

    #subtract the model from the data over all 12 amps
    for i in range(12):
        filein[i + 1].data = data[i] - modelscatlight[:, 512 * i:512 * (i + 1)]

    #here we examine all the fiber-bundle gaps post subtraction to see how close to zero they become. If perfectly subtracted then they should all show zero flux.
    #iterate over each gap
    for i in range(len(gaps[0])):

        plt.figure()

        #iterate over each amp
        for j in range(12):
            # print("ccd: "+str(j))
            amp = filein[j + 1].data
            print(amp.shape)

            #if first time then initialize variable otherwise add new amp to variable
            if j == 0:
                fibergap = np.median(amp[gaps[0, i]:gaps[1, i], :], axis=0)
            else:
                fibergap = np.hstack(
                    (fibergap, np.median(amp[gaps[0, i]:gaps[1, i], :],
                                         axis=0)))
        print(fibergap.shape)

        #plot gap spanning all amps
        plt.plot(np.arange(6144), fibergap)
        plt.show()

    #write the subtracted image to file.
    filein.writeto("brgcustomsci237.fits", overwrite=True)
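main() relies on a remove_outliers(data, nsigma) helper that is not included in this snippet. A minimal sketch, assuming it sigma-clips about the median and returns NaN at rejected pixels (consistent with how the cleaned arrays are used above):

import numpy as np

def remove_outliers(data: np.ndarray, nsigma: float) -> np.ndarray:
    # Hypothetical stand-in for the helper used above: mask pixels more
    # than nsigma standard deviations from the median with NaN.
    clean = np.asarray(data, dtype=float).copy()
    med = np.nanmedian(clean)
    std = np.nanstd(clean)
    clean[np.abs(clean - med) > nsigma * std] = np.nan
    return clean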
Example #30
def nice_norm(image: np.ndarray):
    return ImageNormalize(image, interval=ZScaleInterval(), stretch=SqrtStretch())
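A usage sketch combining nice_norm with matplotlib; the FITS filename is a placeholder:

import matplotlib.pyplot as plt
from astropy.io import fits

data = fits.open('image.fits')[0].data  # hypothetical input image
plt.imshow(data, origin='lower', norm=nice_norm(data))
plt.show()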