Beispiel #1
0
def showNorm(imaOrCcd, **kwargs):
    """Draw an image with a normalised colour scale and a colorbar.

    The current pyplot figure is cleared and reused.  When the input is a
    ``CCDData`` carrying a WCS, the image is drawn on WCS-projected axes with
    a white coordinate grid; otherwise plain axes are used.

    Parameters
    ----------
    imaOrCcd
        Either an ``astropy.nddata.CCDData`` (its ``.data`` is displayed) or
        any object that ``imshow_norm`` accepts as image data.
    **kwargs
        Forwarded to ``astropy.visualization.imshow_norm``.  Unless given by
        the caller, ``interval`` is ``PercentileInterval(99.7)``, ``stretch``
        is ``SqrtStretch()`` and ``origin`` is ``'lower'``.
    """
    from astropy.visualization import imshow_norm, SqrtStretch
    from astropy.visualization.mpl_normalize import PercentileInterval
    from astropy.nddata import CCDData

    plt.clf()
    figure = plt.gcf()

    # Only CCDData inputs can bring a WCS along; bare arrays never do.
    is_ccd = isinstance(imaOrCcd, CCDData)
    pixels = imaOrCcd.data if is_ccd else imaOrCcd
    projection = imaOrCcd.wcs if is_ccd else None

    if projection is None:
        axes = plt.subplot()
    else:
        axes = plt.subplot(projection=projection)
        axes.coords.grid(True, color='white', ls='solid')

    # Caller-supplied values always win over these display defaults.
    kwargs.setdefault('interval', PercentileInterval(99.7))
    kwargs.setdefault('stretch', SqrtStretch())
    kwargs.setdefault('origin', 'lower')

    image, _ = imshow_norm(pixels, ax=axes, **kwargs)

    colorbar = figure.colorbar(image)
    colorbar.ax.tick_params(labelsize=11)
def asinh_plot_VLASS_mJy(wcs_celestial, image_data, showgrid=False):
    """Plot an image with an asinh stretch on WCS axes, save it and show it.

    The figure gets a colorbar labelled 'mJy' (no unit conversion is done
    here), RA/Dec axis labels, and the footprint of a ``Cutout2D`` drawn in
    white on top of the image.

    NOTE: besides its arguments this function reads ``pulsarname``,
    ``position`` and ``boxsize``, which are not parameters and must exist in
    the enclosing (module) scope.  ``pulsarname`` is used both as the plot
    title and as the output file name.  It also updates the global
    ``plt.rcParams`` font size as a side effect.

    Parameters
    ----------
    wcs_celestial
        WCS used as the axes projection and for the cutout.
    image_data
        Image array to display.
    showgrid : bool, optional
        When true, draw a solid white grid on the axes.
    """
    plt.rcParams.update({'font.size': 18})
    colormap = plt.get_cmap('viridis')

    figure = plt.figure(figsize=(11, 8), dpi=75)
    axes = plt.subplot(projection=wcs_celestial)
    image, _ = imshow_norm(image_data,
                           axes,
                           origin='lower',
                           interval=MinMaxInterval(),
                           stretch=AsinhStretch(),
                           vmin=-1e-5,
                           cmap=colormap)
    colorbar = figure.colorbar(image, cmap=colormap)
    colorbar.set_label('mJy')

    axes.set_title(pulsarname)
    axes.set_xlabel('RA J2000')
    axes.set_ylabel('Dec J2000')
    if showgrid:
        axes.grid(color='white', ls='solid')

    # Outline the cutout region on the full image.
    footprint = Cutout2D(image_data, position, boxsize, wcs=wcs_celestial)
    footprint.plot_on_original(color='white')

    plt.savefig(pulsarname)

    plt.show()
Beispiel #3
0
def read_and_refine_wcs(filepath,
                        catalog_coords,
                        use_astrometry_net=False,
                        show=False):
    """Read an image, match its source table to a catalog and refine the WCS.

    The file is read as ``CCDData`` from the 'SCI' HDU (unit adu); the source
    table comes from the 'CAT' extension and the mask from the 'BPM'
    extension.  Detected sources are matched to ``catalog_coords`` on the
    sky; the sources whose separation falls in the most populated bin of the
    separation histogram are treated as true matches and handed to
    ``refine_wcs`` together with their catalog coordinates.

    Parameters
    ----------
    filepath : str
        Path to the FITS file.  With ``use_astrometry_net`` a trailing '.fz'
        is unpacked with ``funpack`` first and the unpacked file is used.
    catalog_coords
        Catalog positions; must provide ``.ra.deg``/``.dec.deg`` and be usable
        with ``SkyCoord.match_to_catalog_sky``.
    use_astrometry_net : bool, optional
        Run ``solve-field`` on the file and replace it with the solved
        '.new' output before reading.
    show : bool, optional
        Save a PNG of the image with the matched catalog positions circled
        in red, both under ``image_dir`` (a module-level name) and as
        'latest_image.png'.

    Returns
    -------
    CCDData
        The image that was read.  ``refine_wcs`` is called on its ``wcs``
        (presumably updating it in place -- confirm against ``refine_wcs``).
    """
    if use_astrometry_net:
        # Local import: quoting keeps paths with spaces or shell
        # metacharacters from breaking (or hijacking) the shell commands.
        import shlex

        if filepath.endswith('.fz'):
            os.system(f'funpack {shlex.quote(filepath)}')
            filepath = filepath[:-3]
        quoted_path = shlex.quote(filepath)
        os.system(
            'solve-field -p --temp-axy -S none -M none -R none -W none -B none -O -U indx.xyls '
            f'{quoted_path} && rm indx.xyls')
        solved_path = shlex.quote(filepath.replace(".fits", ".new"))
        os.system(f'mv {solved_path} {quoted_path}')

    ccddata = CCDData.read(filepath, unit='adu', hdu='SCI')
    sources = fits.getdata(filepath, extname='CAT')
    ccddata.mask = fits.getdata(filepath, extname='BPM')

    # Sky positions of the detected sources under the current WCS.
    ra, dec = ccddata.wcs.all_pix2world(sources['x'], sources['y'], 0)
    source_coords = SkyCoord(ra, dec, unit=u.deg)
    i, sep, _ = source_coords.match_to_catalog_sky(catalog_coords)

    # Keep only matches whose separation lies in the histogram's peak bin.
    n_hist, bins = np.histogram(sep.arcsec)
    i_peak = np.argmax(n_hist)
    match = (sep.arcsec > bins[i_peak]) & (sep.arcsec < bins[i_peak + 1])
    xy = np.array([sources['x'][match], sources['y'][match]]).T
    radec = np.array(
        [catalog_coords.ra.deg[i[match]], catalog_coords.dec.deg[i[match]]]).T
    refine_wcs(ccddata.wcs, xy, radec)

    if show:
        plt.figure(figsize=(6., 6.))
        imshow_norm(
            ccddata.data,
            interval=ZScaleInterval(),
            origin='lower' if ccddata.wcs.wcs.cd[1, 1] > 0. else 'upper')
        plt.axis('off')
        plt.axis('tight')
        plt.tight_layout(pad=0.)
        x, y = ccddata.wcs.all_world2pix(radec, 0).T
        plt.plot(x, y, ls='none', marker='o', mec='r', mfc='none')
        image_filename = os.path.basename(filepath).replace('.fz', '').replace(
            '.fits', '.png')
        # savefig always replaces an existing file and has no ``overwrite``
        # keyword; passing one is rejected by current matplotlib backends.
        plt.savefig(os.path.join(image_dir, image_filename))
        plt.savefig('latest_image.png')
        plt.close()

    return ccddata
Beispiel #4
0
def extract_photometry(ccddata,
                       catalog,
                       catalog_coords,
                       target,
                       image_path=None,
                       aperture_radius=2. * u.arcsec,
                       bg_radius_in=None,
                       bg_radius_out=None):
    """Measure aperture photometry, calibrate it and append to the light curve.

    Circular apertures (and, when both background radii are given, annuli)
    are placed at every catalog position.  Instrumental magnitudes are
    derived from the (optionally annulus-subtracted) aperture sums divided by
    ``ccddata.meta['exptime']``.  The zero point is the NaN-ignoring median of
    ``catalog_mag - aperture_mag`` over all non-target rows, with a standard
    error from ``mad_std``.  One line with the calibrated target magnitude is
    appended to ``lc_file`` (a module-level name).

    Parameters
    ----------
    ccddata
        Image to measure; ``meta`` must contain 'exptime', 'MJD-OBS',
        'FILTER', 'TELESCOP' and 'filename'.
    catalog
        Table stacked column-wise with the photometry table; must contain a
        'catalog_mag' column.
    catalog_coords
        Sky positions of the catalog rows, used to place the apertures.
    target
        Boolean mask over the rows selecting the target (it is used to index
        rows and negated with ``~`` to select calibration stars).
    image_path : optional
        If given, a calibration plot is saved there and as 'latest_cal.pdf',
        and an image with the apertures overplotted is saved under
        ``image_dir`` (a module-level name) and as 'latest_image.png'.
    aperture_radius : optional
        Radius of the circular apertures.
    bg_radius_in, bg_radius_out : optional
        Inner/outer radius of the background annulus; the annulus is used
        only when both are given.

    Returns
    -------
    None
        If the target's pixel position is outside the image an error is
        logged and the function returns before anything is written.
    """
    apertures = SkyCircularAperture(catalog_coords, aperture_radius)
    if bg_radius_in is not None and bg_radius_out is not None:
        apertures = [
            apertures,
            SkyCircularAnnulus(catalog_coords, bg_radius_in, bg_radius_out)
        ]
    photometry = aperture_photometry(ccddata, apertures)

    # Bail out early when the target does not land on the detector.
    target_row = photometry[target][0]
    if target_row['xcenter'].value < 0. or target_row['xcenter'].value > ccddata.shape[1] or \
            target_row['ycenter'].value < 0. or target_row['ycenter'].value > ccddata.shape[0]:
        logging.error(
            'target not contained in the image (or coordinate solution is bad)'
        )
        return

    # With an annulus there are two sum columns (suffixes _0 and _1).
    if 'aperture_sum_1' in photometry.colnames:
        flux = photometry['aperture_sum_0'] - photometry['aperture_sum_1']
        dflux = (photometry['aperture_sum_err_0']**2. +
                 photometry['aperture_sum_err_1']**2.)**0.5
    else:
        flux = photometry['aperture_sum']
        dflux = photometry['aperture_sum_err']
    photometry['aperture_mag'] = u.Magnitude(flux / ccddata.meta['exptime'])
    photometry['aperture_mag_err'] = 2.5 / np.log(10.) * dflux / flux
    photometry = hstack([catalog, photometry])

    # Zero point from every row except the target.
    photometry['zeropoint'] = photometry['catalog_mag'] - photometry[
        'aperture_mag'].value
    zeropoints = photometry['zeropoint'][~target].filled(np.nan)
    zp = np.nanmedian(zeropoints)
    zperr = mad_std(
        zeropoints,
        ignore_nan=True) / np.isfinite(zeropoints).sum()**0.5  # std error

    # Re-read the target row: the table was rebuilt by hstack above.
    target_row = photometry[target][0]
    mag = target_row['aperture_mag'].value + zp
    dmag = (target_row['aperture_mag_err'].value**2. + zperr**2.)**0.5
    with open(lc_file, 'a') as f:
        f.write(
            f'{ccddata.meta["MJD-OBS"]:11.5f} {mag:6.3f} {dmag:5.3f} {zp:6.3f} {zperr:5.3f} {ccddata.meta["FILTER"]:>6s} '
            f'{ccddata.meta["TELESCOP"]:>16s} {ccddata.meta["filename"]:>22s}\n'
        )

    if image_path is not None:
        # Calibration plot: instrumental vs. catalog magnitude.
        ax = plt.axes()
        # Pixel markers keep very crowded plots readable.
        mark = ',' if np.isfinite(
            photometry['aperture_mag']).sum() > 1000 else '.'
        ax.plot(photometry['aperture_mag'],
                photometry['catalog_mag'],
                ls='none',
                marker=mark,
                zorder=1,
                label='calibration stars')
        ax.plot(mag - zp, mag, ls='none', marker='*', zorder=3, label='target')
        yfit = np.array([21., 13.])
        xfit = yfit - zp
        ax.plot(xfit, yfit, label=f'$Z = {zp:.2f}$ mag', zorder=2)
        ax.set_xlabel('Instrumental Magnitude')
        ax.set_ylabel('AB Magnitude')
        ax.legend()
        # savefig always replaces an existing file and has no ``overwrite``
        # keyword; passing one is rejected by current matplotlib backends.
        plt.savefig(image_path)
        plt.savefig('latest_cal.pdf')
        plt.close()

        # Image with the apertures (and annuli, if any) drawn in red.
        plt.figure(figsize=(6., 6.))
        imshow_norm(ccddata.data, interval=ZScaleInterval())
        plt.axis('off')
        plt.axis('tight')
        plt.tight_layout(pad=0.)
        if isinstance(apertures, list):
            for aperture in apertures:
                aperture.to_pixel(ccddata.wcs).plot(color='r', lw=1)
        else:
            apertures.to_pixel(ccddata.wcs).plot(color='r', lw=1)
        image_filename = ccddata.meta['filename'].replace('.fz', '').replace(
            '.fits', '.png')
        plt.savefig(os.path.join(image_dir, image_filename))
        plt.savefig('latest_image.png')
        plt.close()
# Script fragment: detect sources, keep those in a flux window, and save two
# diagnostic PNGs.  `daofind`, `data`, `median`, `data_file` and `hdu` are
# defined outside this fragment.
sources = daofind(data - median)

# Cut objects that are too faint or too bright: keep only sources whose
# 'flux' lies strictly between flux_min and flux_max.
flux_min = 5
flux_max = 50
flux_range = np.where((sources['flux'] > flux_min)
                      & (sources['flux'] < flux_max))[0]

# Table of the selected sources' centroids and fluxes (columns x_0, y_0,
# flux_0) -- presumably initial guesses for a later PSF fit; confirm against
# the code that consumes init_tbl.
init_tbl = Table()
init_tbl['x_0'] = sources['xcentroid'][flux_range]
init_tbl['y_0'] = sources['ycentroid'][flux_range]
init_tbl['flux_0'] = sources['flux'][flux_range]

# Plot the original image (99-percentile interval, sqrt stretch) and save it.
plt.figure(figsize=(9, 9))
imshow_norm(data, interval=PercentileInterval(99.), stretch=SqrtStretch())
plt.colorbar()
plt.savefig(data_file + '.png', dpi=300)
plt.clf()

# Plot the image again with the selected source positions marked in black.
# NOTE(review): plt.scatter is called before plt.colorbar(), so the colorbar
# may be attached to the scatter collection rather than the image -- confirm.
plt.figure(figsize=(9, 9))
imshow_norm(data, interval=PercentileInterval(99.), stretch=SqrtStretch())
plt.scatter(init_tbl['x_0'], init_tbl['y_0'], s=10, color='black')
plt.colorbar()
plt.savefig(data_file + '_with_daostarfinder_objects.png', dpi=300)

# Set up the NIRCam instrument model used to build the PSF grid (original
# note: the grid will be 3x3 for now); the filter is taken from the header.
nrc = webbpsf.NIRCam()
nrc.filter = hdu[0].header['FILTER']  # e.g. "F150W"
if (hdu[0].header['DETECTOR'] == 'NRCALONG'):
Beispiel #6
0
def showds9(ax,
            hdu,
            stretch=LinearStretch(),
            cmap='gray',
            pixscale=None,
            ticksevery=None):
    '''
    Description
      Uses astropy.visualization's imshow_norm
      to mimic viewing fits images with
      zscale in DS9

    Parameters
      ax: Axes object
      hdu: fits HDU object with data
      stretch: see astropy.visualization for options
      cmap: matplotlib color map
      pixscale: pixel scale WITH astropy unit attached
      ticksevery: label tick marks every __ angular units
                  must provide astropy unit
                  requires pixscale (ignored without it)
                  the axis unit will adopt this unit

    Axis labels
      no pixscale:            'X [pix]' / 'Y [pix]'
      pixscale only:          'X [<value><unit>/pix]'
      pixscale + ticksevery:  'X [<unit>]', with ticks relabelled
                              in angular units

    Returns
      ax: Axes object
      im: returned by imshow_norm
    '''
    im, norm = imshow_norm(hdu.data,
                           ax,
                           origin='lower',
                           interval=ZScaleInterval(),
                           stretch=stretch,
                           cmap=cmap)

    if pixscale is None:
        # Bracketed like the other branches so the label reads 'X [pix]'
        # instead of the run-together 'Xpix'.
        unit_str = ' [pix]'
    elif ticksevery is None:
        unit_str = ' [' + str(
            pixscale.value) + angunit_to_latex(pixscale) + '/pix]'
    else:
        #convert pixscale to ticksevery unit
        pixscale = pixscale.to(ticksevery.unit)
        #N pixels per tick label
        tickstride = (ticksevery / pixscale).value

        #use as many decimal places for label as in ticksevery
        tickstr = str(ticksevery.value)
        if int(ticksevery.value) == ticksevery.value:
            ndec = 0
        else:
            ndec = len(tickstr[tickstr.find('.') + 1:])
        fmt = "%." + str(ndec) + "f"

        #get axes limits in units of pixels
        xmin, xmax = ax.get_xlim()
        ymin, ymax = ax.get_ylim()

        #re-label x
        newxvals, newxlabs = label_angticks(np.array(ax.get_xticks()),
                                            pixscale, tickstride, fmt,
                                            xmin, xmax)
        #re-label y
        newyvals, newylabs = label_angticks(np.array(ax.get_yticks()),
                                            pixscale, tickstride, fmt,
                                            ymin, ymax)

        ax.set_xticks(newxvals)
        ax.set_xticklabels(newxlabs)
        ax.set_yticks(newyvals)
        ax.set_yticklabels(newylabs)
        unit_str = ' [' + angunit_to_latex(pixscale) + ']'

    #label units
    ax.set_xlabel('X' + unit_str)
    ax.set_ylabel('Y' + unit_str)

    return ax, im
def _save_sqrt_stretch_png(image, png_path, marker_xy=None):
    """Save *image* (99-percentile interval, sqrt stretch, colorbar) as a PNG.

    ``marker_xy`` is an optional ``(x, y)`` pair of coordinate sequences that
    are overplotted as small black points.
    """
    plt.figure(figsize=(9, 9))
    im, _ = imshow_norm(image,
                        interval=PercentileInterval(99.),
                        stretch=SqrtStretch())
    if marker_xy is not None:
        plt.scatter(marker_xy[0], marker_xy[1], s=10, color='black')
    # Pass the image explicitly: after plt.scatter the "current" mappable is
    # the scatter collection, so a bare plt.colorbar() would describe that.
    plt.colorbar(im)
    plt.savefig(png_path, dpi=300)
    # Close (not just clear) so repeated calls do not pile up open figures.
    plt.close()


def psf_fit(data_array, data_file, psf_grid, fheader, imagemodel):
    """Run PSF photometry on an image and write the results to disk.

    Sources are found with ``DAOStarFinder`` on the median-subtracted image,
    restricted to 5 < flux < 50, and used as initial guesses for
    ``BasicPSFPhotometry``.  Outputs, all prefixed with ``data_file``:

    * ``.png`` -- the input image
    * ``_with_daostarfinder_objects.png`` -- image with selected sources
    * ``_residual.fits`` / ``_residual.png`` -- residual image of the fit
    * ``_psf_fit_output.fits`` -- fit table with RA_fit/DEC_fit prepended

    Parameters
    ----------
    data_array
        2-D image to fit.
    data_file : str
        Prefix for every output file name.
    psf_grid
        PSF model passed to ``BasicPSFPhotometry``; its ``.data.shape[1:]``
        and ``.oversampling`` determine the fit shape.
    fheader
        FITS header written with the residual image.
    imagemodel
        Object whose ``meta.wcs(x, y)`` returns (RA, Dec) for pixel positions.

    Returns
    -------
    None

    NOTE(review): the two FITS writes use the library defaults, which are
    expected to raise if the output file already exists -- confirm whether
    re-runs should overwrite.
    """
    # Source detection on the background-subtracted image.
    mean, median, std = sigma_clipped_stats(data_array, sigma=3.0)
    daofind = DAOStarFinder(fwhm=3.0, threshold=5. * std)
    sources = daofind(data_array - median)

    # Remove objects that are too faint or too bright.
    flux_min = 5  #5 #0
    flux_max = 50  #50 #1000
    flux_range = np.where((sources['flux'] > flux_min)
                          & (sources['flux'] < flux_max))[0]

    init_tbl = Table()
    init_tbl['x_0'] = sources['xcentroid'][flux_range]
    init_tbl['y_0'] = sources['ycentroid'][flux_range]
    init_tbl['flux_0'] = sources['flux'][flux_range]

    # Diagnostic plots: the original image, then the same with the sources.
    _save_sqrt_stretch_png(data_array, data_file + '.png')
    _save_sqrt_stretch_png(data_array,
                           data_file + '_with_daostarfinder_objects.png',
                           marker_xy=(init_tbl['x_0'], init_tbl['y_0']))

    # Fit region: one PSF footprint at detector sampling.
    eval_xshape = int(np.ceil(psf_grid.data.shape[2] / psf_grid.oversampling))
    eval_yshape = int(np.ceil(psf_grid.data.shape[1] / psf_grid.oversampling))

    # Set up the PSF photometry.
    sigma_psf = 3.
    daogroup = DBSCANGroup(2.0 * sigma_psf * gaussian_sigma_to_fwhm)
    mmm_bkg = MMMBackground()
    fit_shape = (eval_yshape, eval_xshape)
    phot = BasicPSFPhotometry(daogroup,
                              mmm_bkg,
                              psf_grid,
                              fit_shape,
                              finder=None,
                              aperture_radius=3.)

    # The fit is the slow part, so print the date/time before and after.
    stamp_format = "%d/%m/%Y %H:%M:%S"
    print("Starting the fit: date and time = ",
          datetime.now().strftime(stamp_format))

    tbl = phot(data_array, init_guesses=init_tbl)

    print("Ending the fit: date and time = ",
          datetime.now().strftime(stamp_format))

    # Display formats for the output columns.
    for column, column_format in (('x_fit', '%.1f'), ('y_fit', '%.1f'),
                                  ('flux_fit', '%.4e'), ('flux_unc', '%.4e'),
                                  ('x_0_unc', '%.4e'), ('y_0_unc', '%.4e')):
        tbl[column].format = column_format

    # Residual image from the fit, as FITS and as PNG.
    diff = phot.get_residual_image()
    hdu_out = fits.PrimaryHDU(diff, header=fheader)
    hdul_out = fits.HDUList([hdu_out])
    hdul_out.writeto(data_file + '_residual.fits')

    _save_sqrt_stretch_png(diff, data_file + '_residual.png')

    # Sky coordinates of the fitted positions.
    RA_fit, DEC_fit = imagemodel.meta.wcs(tbl['x_fit'], tbl['y_fit'])

    # Insert DEC first, then RA, both at index 0, so the table starts with
    # RA_fit followed by DEC_fit.
    tbl.add_column(DEC_fit, index=0, name='DEC_fit')
    tbl.add_column(RA_fit, index=0, name='RA_fit')

    # Write the table to a file.
    tbl.write(data_file + '_psf_fit_output.fits')