Esempio n. 1
0
def find_nearest_tile(ra, dec,
                      index_file='/home/rt2122/Data/fulldepth_neo4_index.fits',
                      shape=(2048, 2048)):
    """Return the TAN WCS of the index tile whose centre is nearest (ra, dec).

    Extension 1 of ``index_file`` is read as a table. Its 'RA' and 'DEC'
    columns (degrees) give the tile centres. The 'CD', 'CDELT', 'CRPIX',
    'CRVAL', 'LONGPOLE' and 'LATPOLE' values of the row closest on the sky
    are copied into a new two-axis 'RA---TAN' / 'DEC--TAN' WCS.

    Parameters
    ----------
    ra, dec : float
        Target position in degrees (ICRS).
    index_file : str, optional
        Path to the FITS tile index. Defaults to the path that was
        previously hard-coded here.
    shape : tuple of int, optional
        ``array_shape`` given to the returned WCS (default ``(2048, 2048)``).

    Returns
    -------
    astropy.wcs.WCS
        WCS of the nearest tile.
    """
    from astropy.coordinates import SkyCoord
    from astropy.wcs import WCS
    from astropy.io import fits
    from astropy import units as u
    import numpy as np

    target = SkyCoord(ra=ra * u.degree, dec=dec * u.degree, frame='icrs')

    with fits.open(index_file) as hdul:
        tiles = hdul[1].data
        centres = SkyCoord(ra=tiles['RA'] * u.degree,
                           dec=tiles['DEC'] * u.degree)
        line = tiles[np.argmin(target.separation(centres).degree)]

        # Build the WCS while the file is still open so the row data
        # (possibly memory-mapped) is read before the handle is closed.
        w = WCS(naxis=2)
        w.wcs.cd = line['CD']
        w.wcs.cdelt = line['CDELT']
        w.wcs.crpix = line['CRPIX']
        w.wcs.crval = line['CRVAL']
        w.wcs.ctype = ['RA---TAN', 'DEC--TAN']
        w.wcs.lonpole = line['LONGPOLE']
        w.wcs.latpole = line['LATPOLE']

    w.wcs.set_pv([(0, 0, 0)])
    w.array_shape = tuple(shape)
    return w
Esempio n. 2
0
    def from_header(cls, header, hdu_bands=None, format=None):
        """Create a WCS geometry object from a FITS header.

        Parameters
        ----------
        header : `~astropy.io.fits.Header`
            The FITS header
        hdu_bands : `~astropy.io.fits.BinTableHDU`
            The BANDS table HDU.
        format : {'gadf', 'fgst-ccube','fgst-template'}
            FITS format convention.

        Returns
        -------
        wcs : `~WcsGeom`
            WCS geometry object.

        Notes
        -----
        The spatial shape is taken, in order of precedence, from the ``NPIX``
        (and ``CDELT``) columns of ``hdu_bands``, then the ``WCSSHAPE``
        header keyword, then the ``NAXIS1``/``NAXIS2`` keywords.
        """
        import ast

        wcs = WCS(header, naxis=2)
        # TODO: see https://github.com/astropy/astropy/issues/9259
        wcs._naxis = wcs._naxis[:2]

        axes = MapAxes.from_table_hdu(hdu_bands, format=format)
        shape = axes.shape

        if hdu_bands is not None and "NPIX" in hdu_bands.columns.names:
            # Per-band image sizes and pixel scales: one (x, y) pair per
            # position on the non-spatial axes.
            npix = hdu_bands.data.field("NPIX").reshape(shape + (2, ))
            npix = (npix[..., 0], npix[..., 1])
            cdelt = hdu_bands.data.field("CDELT").reshape(shape + (2, ))
            cdelt = (cdelt[..., 0], cdelt[..., 1])
        elif "WCSSHAPE" in header:
            # WCSSHAPE holds a sequence literal such as "(nx, ny)".
            # literal_eval only accepts Python literals, so a crafted header
            # cannot run arbitrary code the way eval() could.
            wcs_shape = ast.literal_eval(header["WCSSHAPE"])
            npix = (wcs_shape[0], wcs_shape[1])
            cdelt = None
            wcs.array_shape = npix
        else:
            npix = (header["NAXIS1"], header["NAXIS2"])
            cdelt = None

        if "PSLICE1" in header:
            # Cutout bookkeeping; slices are stored in (axis 2, axis 1) order.
            cutout_info = {}
            cutout_info["parent-slices"] = (
                str_to_slice(header["PSLICE2"]),
                str_to_slice(header["PSLICE1"]),
            )
            cutout_info["cutout-slices"] = (
                str_to_slice(header["CSLICE2"]),
                str_to_slice(header["CSLICE1"]),
            )
        else:
            cutout_info = None

        return cls(wcs, npix, cdelt=cdelt, axes=axes, cutout_info=cutout_info)
Esempio n. 3
0
def custom_wcs(ra, dec):
    """Return a 2048x2048 TAN WCS seeded from a Legacy Survey viewer cutout.

    The function downloads a cutout FITS image centred on (ra, dec) from
    legacysurvey.org (network access required). It then copies the first two
    axes of that cutout's CD, CDELT and CRVAL, plus its LONPOLE and LATPOLE,
    into a fresh two-axis 'RA---TAN' / 'DEC--TAN' WCS. The reference pixel is
    fixed at (1024.5, 1024.5) and the array shape at 2048x2048.
    """
    from astropy.io import fits
    from astropy.wcs import WCS

    url_template = ('https://www.legacysurvey.org/viewer/cutout.fits'
                    '?ra={:.4f}&dec={:.4f}&layer=sdss2&pixscale=4.00')
    with fits.open(url_template.format(ra, dec)) as remote:
        source = WCS(remote[0].header).wcs

    result = WCS(naxis=2)
    target = result.wcs
    target.cd = source.cd[:2, :2]
    target.cdelt = source.cdelt[:2]
    target.crpix = [1024.5, 1024.5]
    target.crval = source.crval[:2]
    target.ctype = ['RA---TAN', 'DEC--TAN']
    target.lonpole = source.lonpole
    target.latpole = source.latpole
    target.set_pv([(0, 0, 0)])
    result.array_shape = [2048, 2048]
    return result
Esempio n. 4
0
def brickwcs(ra, dec, npix=3600, step=0.262):
    """Build a square TAN WCS and a matching FITS header for one brick.

    Parameters
    ----------
    ra, dec : float
        Brick centre, stored as CRVAL1 / CRVAL2.
    npix : int, optional
        Number of pixels along each side (default 3600).
    step : float, optional
        Divided by ``npix`` to give CDELT on both axes (default 0.262).

    Returns
    -------
    (WCS, Header)
        The WCS object and a PrimaryHDU header carrying the same
        NAXIS / CDELT / CRPIX / CRVAL / CTYPE keywords.
    """
    # NOTE(review): CDELT = step / npix. With the defaults this is
    # 0.262 / 3600 deg (= 0.262 arcsec/pixel). That is consistent both with
    # "step = brick width in degrees" and with "step = arcsec/pixel converted
    # to degrees". The two readings diverge when npix != 3600 -- confirm
    # which one callers expect.
    pixel_step = step / npix
    ref_pix = npix // 2 + 1  # 1-based FITS reference pixel at the centre

    wcs_out = WCS()
    wcs_out.wcs.crpix = [ref_pix, ref_pix]
    wcs_out.wcs.cdelt = np.array([pixel_step, pixel_step])
    wcs_out.wcs.crval = [ra, dec]
    wcs_out.wcs.ctype = ["RA---TAN", "DEC--TAN"]
    wcs_out.array_shape = (npix, npix)

    # Mirror the WCS into a plain primary header, axis by axis.
    primary = fits.PrimaryHDU()
    header = primary.header
    header['NAXIS'] = 2
    for axis, refval, ctype in ((1, ra, 'RA---TAN'), (2, dec, 'DEC--TAN')):
        header['NAXIS%d' % axis] = npix
        header['CDELT%d' % axis] = pixel_step
        header['CRPIX%d' % axis] = ref_pix
        header['CRVAL%d' % axis] = refval
        header['CTYPE%d' % axis] = ctype

    return wcs_out, header
Esempio n. 5
0
def meascutout(meas, obj, size=10, outdir='./', domask=True):
    """Create per-epoch cutouts of an object and combine them into a GIF.

    Each measurement of ``obj`` must have a usable FWHM and a matching
    exposure in the NSC v3 exposure list. For each such measurement:

    - the chip image and mask are read and the mask is converted to
      bit flags;
    - the image is smoothed and reprojected onto a common TAN grid centred
      on the object, then scaled to the flux of the first epoch;
    - the result is saved as a JPEG with an MJD progress bar.

    A final frame shows the mean image with the track of the measured
    positions. It is copied nine more times so the animation lingers on it.
    All frames are combined into ``<outdir><objid>_cutouts.gif`` with the
    external ``convert`` command, and the JPEGs are then removed.

    Parameters
    ----------
    meas : table
        Measurements of the object. Columns read here: 'mjd', 'exposure',
        'ccdnum', 'fwhm', 'x', 'y', 'ra', 'dec', 'mag_auto', 'magerr_auto'
        and 'filter'.
    obj : table
        Object table; only the first row's 'id', 'ra' and 'dec' are used.
    size : float, optional
        Cutout width in arcsec (default 10).
    outdir : str, optional
        Output prefix. It is concatenated directly with file names, so it
        should end with a path separator (default './').
    domask : bool, optional
        If True, masked pixels are replaced by the median of the unmasked
        pixels before smoothing and reprojection.

    Returns
    -------
    None
        Returns early, after printing a message, when no measurement has a
        usable FWHM or no exposure matches.

    Raises
    ------
    Exception
        Whatever the FITS readers raise when an image or mask cannot be
        loaded (previously this dropped into pdb).
    """
    expstr = fits.getdata(
        '/net/dl2/dnidever/nsc/instcal/v3/lists/nsc_v3_exposures.fits.gz', 1)
    decam = Table.read('/home/dnidever/projects/delvered/data/decam.txt',
                       format='ascii')

    objid = obj['id'][0]

    # Sort by MJD
    si = np.argsort(meas['mjd'])
    meas = meas[si]

    # FWHM cut: keep measurements whose FWHM lies between 0.2 and 2.0 times
    # the FWHM of the chip they were measured on.
    sql = "select chip.* from nsc_dr2.chip as chip join nsc_dr2.meas as meas on chip.exposure=meas.exposure and chip.ccdnum=meas.ccdnum"
    sql += " where meas.objectid='" + objid + "'"
    chip = qc.query(sql=sql, fmt='table')
    ind3, ind4 = dln.match(chip['exposure'], meas['exposure'])
    si = np.argsort(ind4)  # sort by input meas catalog
    ind3 = ind3[si]
    ind4 = ind4[si]
    chip = chip[ind3]
    meas = meas[ind4]
    gdfwhm, = np.where((meas['fwhm'] > 0.2 * chip['fwhm'])
                       & (meas['fwhm'] < 2.0 * chip['fwhm']))
    if len(gdfwhm) == 0:
        print('All measurements have bad FWHM values')
        return
    if len(gdfwhm) < len(meas):
        print('Removing ' + str(len(meas) - len(gdfwhm)) +
              ' measurements with bad FWHM values')
        meas = meas[gdfwhm]

    ind1, ind2 = dln.match(expstr['base'], meas['exposure'])
    nind = len(ind1)
    if nind == 0:
        print('No matches')
        return
    # Sort by input meas catalog
    si = np.argsort(ind2)
    ind1 = ind1[si]
    ind2 = ind2[si]

    # Create the reference WCS (odd-sized TAN grid centred on the object)
    wref = WCS(naxis=2)
    pixscale = 0.26  # DECam, "/pix
    npix = round(size / pixscale)
    if npix % 2 == 0:  # must be odd
        npix += 1
    hpix = npix // 2  # center of image
    wref.wcs.ctype = ['RA---TAN', 'DEC--TAN']
    wref.wcs.crval = [obj['ra'][0], obj['dec'][0]]
    wref.wcs.crpix = [npix // 2, npix // 2]
    wref.wcs.cd = np.array([[pixscale / 3600.0, 0.0], [0.0, pixscale / 3600]])
    wref.array_shape = (npix, npix)
    refheader = wref.to_header()
    refheader['NAXIS'] = 2
    refheader['NAXIS1'] = npix
    refheader['NAXIS2'] = npix

    # Load the data
    instrument = expstr['instrument'][ind1]
    plver = expstr['plver'][ind1]
    fluxfile = expstr['file'][ind1]
    fluxfile = fluxfile.replace('/net/mss1/', '/mss1/')  # for thing/hulk
    maskfile = expstr['maskfile'][ind1]
    maskfile = maskfile.replace('/net/mss1/', '/mss1/')  # for thing/hulk
    ccdnum = meas['ccdnum'][ind2]
    # meas is not modified inside the loop, so its MJD span is fixed.
    mjdmin = np.min(meas['mjd'])
    mjdrange = dln.valrange(meas['mjd'])
    figfiles = []
    xmeas = []
    ymeas = []
    cutimarr = np.zeros((npix, npix, nind), float)

    # Plot settings shared by every frame
    matplotlib.use('Agg')
    plt.rc('font', size=15)
    plt.rc('axes', titlesize=20)
    plt.rc('axes', labelsize=20)
    plt.rc('xtick', labelsize=20)
    plt.rc('ytick', labelsize=20)
    gskw = dict(width_ratios=[30, 1])
    figheight = 8.0
    figwidth = 9.0

    for i in range(nind):
        instcode = instrument[i]
        plver1 = plver[i]
        try:
            if instcode == 'c4d':
                dind, = np.where(decam['CCDNUM'] == ccdnum[i])
                extname = decam['NAME'][dind[0]]
                im, head = getfitsext(fluxfile[i], extname, header=True)
                mim, mhead = getfitsext(maskfile[i], extname, header=True)
            else:
                im, head = fits.getdata(fluxfile[i], ccdnum[i], header=True)
                mim, mhead = fits.getdata(maskfile[i], ccdnum[i], header=True)
        except Exception:
            # Continuing here would silently reuse the previous epoch's
            # im/mim (or raise NameError on the first epoch), so fail loudly.
            print('error reading ' + str(fluxfile[i]) + ' / ' +
                  str(maskfile[i]))
            raise

        # Turn the mask from integer to bitmask
        if ((instcode == 'c4d') &
            (plver1 >= 'V3.5.0')) | (instcode == 'k4m') | (instcode == 'ksb'):
            omim = mim.copy()
            mim *= 0
            nonzero = (omim > 0)
            mim[nonzero] = 2**((omim - 1)[nonzero])  # This takes about 1 sec
        # Fix the DECam Pre-V3.5.0 masks
        if (instcode == 'c4d') & (plver1 < 'V3.5.0'):
            omim = mim.copy()
            mim *= 0  # re-initialize
            mim += (np.bitwise_and(omim, 1) == 1) * 1  # bad pixels
            mim += (np.bitwise_and(omim, 2) == 2) * 4  # saturated
            mim += (np.bitwise_and(omim, 4) == 4) * 32  # interpolated
            mim += (np.bitwise_and(omim, 16) == 16) * 16  # cosmic ray
            mim += (np.bitwise_and(omim, 64) == 64) * 8  # bleed trail

        # Get chip-level information
        exposure = os.path.basename(fluxfile[i])[0:-8]  # remove fits.fz
        chres = qc.query(sql="select * from nsc_dr2.chip where exposure='" +
                         exposure + "' and ccdnum=" + str(ccdnum[i]),
                         fmt='table')

        w = WCS(head)
        # RA/DEC correction for the object; only reported, not applied.
        lon = obj['ra'][0] - chres['ra'][0]
        lat = obj['dec'][0] - chres['dec'][0]
        racorr = chres['ra_coef1'][0] + chres['ra_coef2'][0] * lon + chres[
            'ra_coef3'][0] * lon * lat + chres['ra_coef4'][0] * lat
        deccorr = chres['dec_coef1'][0] + chres['dec_coef2'][0] * lon + chres[
            'dec_coef3'][0] * lon * lat + chres['dec_coef4'][0] * lat
        print(racorr, deccorr)

        # Measurement centre in 0-based pixel coordinates
        xcen = meas['x'][ind2[i]] - 1
        ycen = meas['y'][ind2[i]] - 1
        smim = dln.gsmooth(im, 2)

        # Mask the bad pixels
        if domask:
            badmask = (mim > 0)
            im[badmask] = np.nanmedian(im[~badmask])
        else:
            badmask = (im < 0)

        # Reproject onto the common TAN grid
        smim1 = dln.gsmooth(im, 1.5)
        hdu = fits.PrimaryHDU(smim1, head)
        cutim, footprint = reproject_interp(hdu, refheader,
                                            order='bicubic')
        # set out-of-bounds to background
        cutim[footprint == 0] = np.nanmedian(im[~badmask])
        xr = [-hpix * pixscale, hpix * pixscale]
        yr = [-hpix * pixscale, hpix * pixscale]

        # exposure_ccdnum, filter, MJD, delta_MJD, mag
        print(
            str(i + 1) + ' ' + meas['exposure'][ind2[i]] + ' ' +
            str(ccdnum[i]) + ' ' + str(meas['x'][ind2[i]]) + ' ' +
            str(meas['y'][ind2[i]]) + ' ' + str(meas['mag_auto'][ind2[i]]))

        figfile = outdir
        figfile += '%s_%04d_%s_%02d.jpg' % (str(
            obj['id'][0]), i + 1, meas['exposure'][ind2[i]], ccdnum[i])
        figfiles.append(figfile)
        if os.path.exists(figfile):
            os.remove(figfile)

        fig, ax = plt.subplots(ncols=2, nrows=1, gridspec_kw=gskw)
        fig.set_figheight(figheight)
        fig.set_figwidth(figwidth)
        med = np.nanmedian(smim)
        sig = dln.mad(smim)
        bigim, xr2, yr2 = cutout(smim, xcen, ycen, 151, missing=med)
        lmed = np.nanmedian(bigim)

        # Instrumental flux of the object, used to scale each image to the
        # same level as the first one.
        instmag = meas['mag_auto'][ind2[i]] - 2.5 * np.log10(
            chres['exptime'][0]) - chres['zpterm'][0]
        instflux = 10**((25.0 - instmag) / 2.5)
        print('flux = ' + str(instflux))
        # Peak height, using flux of a 2D Gaussian ~ 2*pi*height*sigma^2
        pixscale1 = np.max(np.abs(w.wcs.cd)) * 3600
        fwhm = chres['fwhm'][0] / pixscale1
        instheight = instflux / (2 * 3.14 * (fwhm / 2.35)**2)
        print('height = ' + str(instheight))
        # Scale the images to the flux level of the first image
        cutim -= lmed
        if i == 0:
            instflux0 = instflux.copy()
            instheight0 = instheight.copy()
        else:
            scale = instflux0 / instflux
            cutim *= scale
            print('scaling image by ' + str(scale))

        # Display range is fixed by the first epoch
        if i == 0:
            vmin = -8 * sig
            vmax = 0.5 * instheight
            vmin0 = vmin
            vmax0 = vmax
        else:
            vmin = vmin0
            vmax = vmax0

        print('vmin = ' + str(vmin))
        print('vmax = ' + str(vmax))

        cutimarr[:, :, i] = cutim

        ax[0].imshow(cutim,
                     origin='lower',
                     aspect='auto',
                     interpolation='none',
                     extent=(xr[0], xr[1], yr[0], yr[1]),
                     vmin=vmin,
                     vmax=vmax,
                     cmap='viridis')

        # Cross-hair through the centre of the reference grid
        ax[0].plot(np.array([0, 0]),
                   np.array([-0.066 * npix, 0.066 * npix]) * pixscale,
                   c='white',
                   alpha=0.7)
        ax[0].plot(np.array([-0.066 * npix, 0.066 * npix]) * pixscale,
                   np.array([0, 0]),
                   c='white',
                   alpha=0.7)

        # Measured position of this epoch on the reference grid
        xmeas1, ymeas1 = wref.all_world2pix(meas['ra'][ind2[i]],
                                            meas['dec'][ind2[i]], 0)
        xmeas.append(xmeas1)
        ymeas.append(ymeas1)
        ax[0].scatter([(xmeas1 - hpix) * pixscale],
                      [(ymeas1 - hpix) * pixscale],
                      c='r',
                      marker='+',
                      s=20)
        ax[0].set_xlabel(r'$\Delta$ RA (arcsec)')
        ax[0].set_ylabel(r'$\Delta$ DEC (arcsec)')
        ax[0].set_xlim((xr[1], xr[0]))  # reversed: positive dRA on the left
        ax[0].set_ylim(yr)
        co = 'white'
        ax[0].annotate('%s  %02d  %s  %6.1f  ' %
                       (meas['exposure'][ind2[i]], ccdnum[i],
                        meas['filter'][ind2[i]], expstr['exptime'][ind1[i]]),
                       xy=(np.mean(xr), yr[0] + dln.valrange(yr) * 0.05),
                       ha='center',
                       color=co)
        ax[0].annotate(
            r'%10.2f    $\Delta$t=%7.2f  ' %
            (meas['mjd'][ind2[i]], meas['mjd'][ind2[i]] - mjdmin),
            xy=(xr[1] - dln.valrange(xr) * 0.05,
                yr[1] - dln.valrange(yr) * 0.05),
            ha='left',
            color=co)
        ax[0].annotate('%s = %5.2f +/- %4.2f' %
                       (meas['filter'][ind2[i]], meas['mag_auto'][ind2[i]],
                        meas['magerr_auto'][ind2[i]]),
                       xy=(xr[0] + dln.valrange(xr) * 0.05,
                           yr[1] - dln.valrange(yr) * 0.05),
                       ha='right',
                       color=co)

        # Progress bar: fraction of the MJD baseline elapsed. When every
        # epoch shares one MJD the baseline is zero (0/0 -> NaN, and
        # int(round(NaN)) raises), so fall back to the frame fraction.
        frameratio = (i + 1) / float(nind)
        if mjdrange > 0:
            ratio = (meas['mjd'][ind2[i]] - mjdmin) / mjdrange
        else:
            ratio = frameratio
        print('ratio = ' + str(100 * ratio))
        barim = np.zeros((100, 100), int)
        ind = dln.limit(int(round(ratio * 100)), 1, 99)
        barim[:, 0:ind] = 1
        ax[1].imshow(barim.T, origin='lower', aspect='auto', cmap='Greys')
        ax[1].set_xlabel('%7.1f \n days' % (meas['mjd'][ind2[i]] - mjdmin))
        ax[1].set_title('%d/%d' % (i + 1, nind))
        ax[1].axes.xaxis.set_ticks([])
        ax[1].axes.yaxis.set_visible(False)
        for side in ('right', 'left', 'top'):
            ax[1].spines[side].set_visible(False)

        fig.savefig(figfile)
        # Release the figure; otherwise one figure per epoch stays open.
        plt.close(fig)
        print('Cutout written to ' + figfile)

    avgim = np.sum(cutimarr, axis=2) / nind
    avgim *= instheight0 / np.max(avgim)

    # Summary frame: mean image plus the track of the measured positions
    figfile = outdir
    figfile += '%s_%04d_%s.jpg' % (str(obj['id'][0]), nind + 1, 'path')
    figfiles.append(figfile)
    if os.path.exists(figfile):
        os.remove(figfile)
    fig, ax = plt.subplots(ncols=2, nrows=1, gridspec_kw=gskw)
    fig.set_figheight(figheight)
    fig.set_figwidth(figwidth)
    ax[0].imshow(avgim,
                 origin='lower',
                 aspect='auto',
                 interpolation='none',
                 extent=(xr[0], xr[1], yr[0], yr[1]),
                 vmin=vmin,
                 vmax=vmax,
                 cmap='viridis')
    ax[0].plot(np.array([0, 0]),
               np.array([-0.066 * npix, 0.066 * npix]) * pixscale,
               c='white',
               alpha=0.7,
               zorder=1)
    ax[0].plot(np.array([-0.066 * npix, 0.066 * npix]) * pixscale,
               np.array([0, 0]),
               c='white',
               alpha=0.7,
               zorder=1)
    xmeas = np.array(xmeas)
    ymeas = np.array(ymeas)
    ax[0].plot((xmeas - hpix) * pixscale, (ymeas - hpix) * pixscale, c='r')
    ax[0].set_xlabel(r'$\Delta$ RA (arcsec)')
    ax[0].set_ylabel(r'$\Delta$ DEC (arcsec)')
    ax[0].set_xlim((xr[1], xr[0]))  # reversed: positive dRA on the left
    ax[0].set_ylim(yr)
    ax[1].axis('off')
    fig.savefig(figfile)
    plt.close(fig)

    # Nine copies of the summary frame (numbered nind+2 .. nind+10) make the
    # animation pause on it. The names are built directly rather than via
    # str.replace, which could also alter digits in the object id or outdir.
    for frame in range(nind + 2, nind + 11):
        pathfile = outdir + '%s_%04d_%s.jpg' % (str(obj['id'][0]), frame,
                                                'path')
        if os.path.exists(pathfile):
            os.remove(pathfile)
        shutil.copyfile(figfile, pathfile)
        figfiles.append(pathfile)

    # Make the animated gif
    animfile = outdir + str(objid) + '_cutouts.gif'
    if os.path.exists(animfile):
        os.remove(animfile)
    # put list of files in a separate file
    listfile = outdir + str(objid) + '_cutouts.lst'
    if os.path.exists(listfile):
        os.remove(listfile)
    dln.writelines(listfile, figfiles)
    delay = dln.scale(nind, [20, 1000], [20, 1])
    delay = int(np.round(dln.limit(delay, 1, 20)))
    print('delay = ' + str(delay))
    print('Creating animated gif ' + animfile)
    # Argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed through intact.
    try:
        subprocess.run(['convert', '@' + listfile, '-delay', str(delay),
                        animfile])
    except FileNotFoundError:
        print('convert not found; ' + animfile + ' not created')
    dln.remove(figfiles)
Esempio n. 6
0
def meascutout(meas, obj, size=10, outdir='./'):
    """Create per-epoch cutouts of an object and combine them into a GIF.

    For each measurement of ``obj`` whose exposure appears in the NSC v3
    exposure list:

    - the chip image and mask are read;
    - the header CRVAL is shifted by the chip's RA/DEC correction
      polynomial, evaluated at the object position;
    - masked pixels are replaced by the median of the unmasked ones;
    - the image is smoothed and reprojected onto a common TAN grid centred
      on the object, then scaled to the flux of the first epoch;
    - the result is saved as a JPEG.

    A blank frame is appended, all frames are combined into
    ``<outdir><objid>_cutouts.gif`` with the external ``convert`` command,
    and the JPEGs are then removed.

    Parameters
    ----------
    meas : table
        Measurements of the object. Columns read here: 'mjd', 'exposure',
        'ccdnum', 'x', 'y', 'ra', 'dec', 'mag_auto', 'magerr_auto' and
        'filter'.
    obj : table
        Object table; only the first row's 'id', 'ra' and 'dec' are used.
    size : float, optional
        Cutout width in arcsec (default 10).
    outdir : str, optional
        Output prefix. It is concatenated directly with file names, so it
        should end with a path separator (default './').

    Returns
    -------
    None
        Returns early, after printing a message, when no exposure matches.

    Raises
    ------
    Exception
        Whatever the FITS readers raise when an image or mask cannot be
        loaded (previously this dropped into pdb).
    """
    expstr = fits.getdata(
        '/net/dl2/dnidever/nsc/instcal/v3/lists/nsc_v3_exposures.fits.gz', 1)
    decam = Table.read('/home/dnidever/projects/delvered/data/decam.txt',
                       format='ascii')

    objid = obj['id'][0]

    # Sort by MJD
    si = np.argsort(meas['mjd'])
    meas = meas[si]

    ind1, ind2 = dln.match(expstr['base'], meas['exposure'])
    nind = len(ind1)
    if nind == 0:
        print('No matches')
        return
    # Sort by input meas catalog
    si = np.argsort(ind2)
    ind1 = ind1[si]
    ind2 = ind2[si]

    # Create the reference WCS (odd-sized TAN grid centred on the object)
    wref = WCS(naxis=2)
    pixscale = 0.26  # DECam, "/pix
    npix = round(size / pixscale)
    if npix % 2 == 0:  # must be odd
        npix += 1
    hpix = npix // 2  # center of image
    wref.wcs.ctype = ['RA---TAN', 'DEC--TAN']
    wref.wcs.crval = [obj['ra'][0], obj['dec'][0]]
    wref.wcs.crpix = [npix // 2, npix // 2]
    wref.wcs.cd = np.array([[pixscale / 3600.0, 0.0], [0.0, pixscale / 3600]])
    wref.array_shape = (npix, npix)
    refheader = wref.to_header()
    refheader['NAXIS'] = 2
    refheader['NAXIS1'] = npix
    refheader['NAXIS2'] = npix

    # Load the data
    instrument = expstr['instrument'][ind1]
    fluxfile = expstr['file'][ind1]
    fluxfile = fluxfile.replace('/net/mss1/', '/mss1/')  # for thing/hulk
    maskfile = expstr['maskfile'][ind1]
    maskfile = maskfile.replace('/net/mss1/', '/mss1/')  # for thing/hulk
    ccdnum = meas['ccdnum'][ind2]
    mjdmin = np.min(meas['mjd'])  # meas is not modified inside the loop
    figfiles = []

    # Plot settings shared by every frame
    matplotlib.use('Agg')
    plt.rcParams.update({'font.size': 11})
    figsize = 8.0

    for i in range(nind):
        try:
            if instrument[i] == 'c4d':
                dind, = np.where(decam['CCDNUM'] == ccdnum[i])
                extname = decam['NAME'][dind[0]]
                im, head = getfitsext(fluxfile[i], extname, header=True)
                mim, mhead = getfitsext(maskfile[i], extname, header=True)
            else:
                im, head = fits.getdata(fluxfile[i], ccdnum[i], header=True)
                mim, mhead = fits.getdata(maskfile[i], ccdnum[i], header=True)
        except Exception:
            # Continuing here would silently reuse the previous epoch's
            # im/mim (or raise NameError on the first epoch), so fail loudly.
            print('error reading ' + str(fluxfile[i]) + ' / ' +
                  str(maskfile[i]))
            raise

        # Get chip-level information
        exposure = os.path.basename(fluxfile[i])[0:-8]  # remove fits.fz
        chres = qc.query(sql="select * from nsc_dr2.chip where exposure='" +
                         exposure + "' and ccdnum=" + str(ccdnum[i]),
                         fmt='table')

        w = WCS(head)
        # RA/DEC correction for the object
        lon = obj['ra'][0] - chres['ra'][0]
        lat = obj['dec'][0] - chres['dec'][0]
        racorr = chres['ra_coef1'][0] + chres['ra_coef2'][0] * lon + chres[
            'ra_coef3'][0] * lon * lat + chres['ra_coef4'][0] * lat
        deccorr = chres['dec_coef1'][0] + chres['dec_coef2'][0] * lon + chres[
            'dec_coef3'][0] * lon * lat + chres['dec_coef4'][0] * lat
        # apply these offsets to the header WCS CRVAL (head drives the
        # reprojection below)
        w.wcs.crval += [racorr, deccorr]
        head['CRVAL1'] += racorr
        head['CRVAL2'] += deccorr

        # Measurement centre in 0-based pixel coordinates
        xcen = meas['x'][ind2[i]] - 1
        ycen = meas['y'][ind2[i]] - 1
        smim = dln.gsmooth(im, 2)

        # Mask the bad pixels
        badmask = (mim > 0)
        im[badmask] = np.nanmedian(im[~badmask])

        # Reproject onto the common TAN grid
        smim1 = dln.gsmooth(im, 1.5)
        hdu = fits.PrimaryHDU(smim1, head)
        cutim, footprint = reproject_interp(hdu, refheader,
                                            order='bicubic')
        # set out-of-bounds to background
        cutim[footprint == 0] = np.nanmedian(im[~badmask])
        xr = [-hpix * pixscale, hpix * pixscale]
        yr = [-hpix * pixscale, hpix * pixscale]

        # exposure_ccdnum, filter, MJD, delta_MJD, mag
        print(
            str(i + 1) + ' ' + meas['exposure'][ind2[i]] + ' ' +
            str(ccdnum[i]) + ' ' + str(meas['x'][ind2[i]]) + ' ' +
            str(meas['y'][ind2[i]]) + ' ' + str(meas['mag_auto'][ind2[i]]))

        figfile = outdir
        figfile += '%s_%04d_%s_%02d.jpg' % (str(
            obj['id'][0]), i + 1, meas['exposure'][ind2[i]], ccdnum[i])
        figfiles.append(figfile)
        if os.path.exists(figfile):
            os.remove(figfile)
        # Reuse the current figure (cleared each epoch) instead of opening
        # a new one.
        fig = plt.gcf()
        fig.clf()
        ax = fig.subplots()
        fig.set_figheight(figsize)
        fig.set_figwidth(figsize)
        med = np.nanmedian(smim)
        sig = dln.mad(smim)
        bigim, xr2, yr2 = cutout(smim, xcen, ycen, 151, missing=med)
        lmed = np.nanmedian(bigim)

        # Instrumental flux of the object, used to scale each image to the
        # same level as the first one.
        instmag = meas['mag_auto'][ind2[i]] - 2.5 * np.log10(
            chres['exptime'][0]) - chres['zpterm'][0]
        instflux = 10**((25.0 - instmag) / 2.5)
        print('flux = ' + str(instflux))
        # Peak height, using flux of a 2D Gaussian ~ 2*pi*height*sigma^2
        pixscale1 = np.max(np.abs(w.wcs.cd)) * 3600
        fwhm = chres['fwhm'][0] / pixscale1
        instheight = instflux / (2 * 3.14 * (fwhm / 2.35)**2)
        print('height = ' + str(instheight))
        # Scale the images to the flux level of the first image
        cutim -= lmed
        if i == 0:
            instflux0 = instflux.copy()
            instheight0 = instheight.copy()
        else:
            scale = instflux0 / instflux
            cutim *= scale
            print('scaling image by ' + str(scale))

        # Display range is fixed by the first epoch
        if i == 0:
            vmin = -8 * sig
            vmax = 0.5 * instheight
            vmin0 = vmin
            vmax0 = vmax
        else:
            vmin = vmin0
            vmax = vmax0

        print('vmin = ' + str(vmin))
        print('vmax = ' + str(vmax))

        ax.imshow(cutim,
                  origin='lower',
                  aspect='auto',
                  interpolation='none',
                  extent=(xr[0], xr[1], yr[0], yr[1]),
                  vmin=vmin,
                  vmax=vmax,
                  cmap='viridis')

        # Cross-hair through the centre of the reference grid
        ax.plot(np.array([0, 0]),
                np.array([-0.066 * npix, 0.066 * npix]) * pixscale,
                c='white',
                alpha=0.7)
        ax.plot(np.array([-0.066 * npix, 0.066 * npix]) * pixscale,
                np.array([0, 0]),
                c='white',
                alpha=0.7)

        # Measured position of this epoch on the reference grid
        xmeas, ymeas = wref.all_world2pix(meas['ra'][ind2[i]],
                                          meas['dec'][ind2[i]], 0)
        ax.scatter([(xmeas - hpix) * pixscale], [(ymeas - hpix) * pixscale],
                   c='r',
                   marker='+',
                   s=20)
        ax.set_xlabel(r'$\Delta$ RA (arcsec)')
        ax.set_ylabel(r'$\Delta$ DEC (arcsec)')
        co = 'white'
        ax.annotate('%s  %02d  %s  %6.1f  ' %
                    (meas['exposure'][ind2[i]], ccdnum[i],
                     meas['filter'][ind2[i]], expstr['exptime'][ind1[i]]),
                    xy=(np.mean(xr), yr[0] + dln.valrange(yr) * 0.05),
                    ha='center',
                    color=co)
        ax.annotate(
            '%10.2f  %10.2f  ' %
            (meas['mjd'][ind2[i]], meas['mjd'][ind2[i]] - mjdmin),
            xy=(xr[0] + dln.valrange(xr) * 0.05,
                yr[1] - dln.valrange(yr) * 0.05),
            ha='left',
            color=co)
        ax.annotate('%s = %5.2f +/- %4.2f' %
                    (meas['filter'][ind2[i]], meas['mag_auto'][ind2[i]],
                     meas['magerr_auto'][ind2[i]]),
                    xy=(xr[1] - dln.valrange(xr) * 0.05,
                        yr[1] - dln.valrange(yr) * 0.05),
                    ha='right',
                    color=co)
        fig.savefig(figfile)
        print('Cutout written to ' + figfile)

    # Make a single blank file at the end so you know it looped
    figfile = outdir
    figfile += '%s_%04d_%s.jpg' % (str(obj['id'][0]), nind + 1, 'blank')
    figfiles.append(figfile)
    if os.path.exists(figfile):
        os.remove(figfile)
    fig = plt.gcf()
    fig.clf()
    fig.set_figheight(figsize)
    fig.set_figwidth(figsize)
    fig.savefig(figfile)
    print(figfile)

    # Make the animated gif
    animfile = outdir + str(objid) + '_cutouts.gif'
    print('Creating animated gif ' + animfile)
    if os.path.exists(animfile):
        os.remove(animfile)
    # Argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed through intact.
    try:
        subprocess.run(['convert', '-delay', '20'] + figfiles + [animfile])
    except FileNotFoundError:
        print('convert not found; ' + animfile + ' not created')
    dln.remove(figfiles)
def _build_parser():
    """Return the argparse parser for the RGBfromGaiaEDR3 command line."""
    parser = argparse.ArgumentParser(
        description="RGB predictions for Gaia EDR3 stars")
    parser.add_argument("ra_center",
                        help="right Ascension (decimal degrees)",
                        type=float)
    parser.add_argument("dec_center",
                        help="declination (decimal degrees)",
                        type=float)
    parser.add_argument("search_radius",
                        help="search radius (decimal degrees)",
                        type=float)
    parser.add_argument("g_limit",
                        help="limiting Gaia G magnitude",
                        type=float)
    parser.add_argument("--basename",
                        help="file basename for output files",
                        type=str,
                        default="rgbsearch")
    parser.add_argument(
        "--brightlimit",
        help="stars brighter than this Gaia G limit are displayed with star "
        "symbols (default=8.0)",
        type=float,
        default=8.0)
    parser.add_argument(
        "--symbsize",
        help="multiplying factor for symbol size (default=1.0)",
        type=float,
        default=1.0)
    parser.add_argument("--nonumbers",
                        help="do not display star numbers in PDF chart",
                        action="store_true")
    parser.add_argument("--noplot",
                        help="skip PDF chart generation",
                        action="store_true")
    parser.add_argument("--nocolor",
                        help="do not use colors in PDF chart",
                        action="store_true")
    parser.add_argument("--starhorse_block",
                        help="number of stars/query (default=0, no query)",
                        default=0,
                        type=int)
    parser.add_argument("--verbose",
                        help="increase program verbosity",
                        action="store_true")
    parser.add_argument("--debug", help="debug flag", action="store_true")
    return parser


def _outside_colour_range(bp_rp):
    """Return a boolean mask of G_BP - G_RP values outside (-0.5, 2.0)."""
    return np.logical_or(bp_rp <= -0.5, bp_rp >= 2.0)


def _add_rgb_columns(table):
    """Insert the predicted RGB magnitude columns into *table* in place.

    Each prediction is ``phot_g_mean_mag + P(bp_rp)``, where ``P`` is a
    degree-5 polynomial (coefficients in ascending order), rounded to two
    decimals. The columns ``b_rgb``, ``g_rgb``, ``r_rgb`` and ``g_br_rgb``
    are inserted at column positions 3, 4, 5 and 6 respectively.
    """
    coefficients = (
        ('b_rgb', [-0.13748689, 0.44265552, 0.37878846, -0.14923841,
                   0.09172474, -0.02594726]),
        ('g_rgb', [-0.02330159, 0.12884074, 0.22149167, -0.1455048,
                   0.10635149, -0.0236399]),
        ('r_rgb', [0.10979647, -0.14579334, 0.10747392, -0.1063592,
                   0.08494556, -0.01368962]),
        ('g_br_rgb', [-0.01252185, 0.13983574, 0.23688188, -0.10175532,
                      0.07401939, -0.0182115]),
    )
    for index, (name, coef) in enumerate(coefficients, start=3):
        poly = Polynomial(np.array(coef))
        table.add_column(Column(np.round(
            table['phot_g_mean_mag'] + poly(table['bp_rp']), 2),
                                name=name,
                                unit=u.mag,
                                format='.2f'),
                         index=index)


def _rows_with_ids(id_column, ids):
    """Return the row indices of *id_column* that hold each value of *ids*.

    Every value in *ids* is assumed to be present in *id_column*; otherwise
    ``searchsorted`` points at a neighbouring row (or past the end).
    """
    sorter = np.argsort(id_column)
    return np.array(sorter[np.searchsorted(id_column, ids, sorter=sorter)])


def _add_axes_text(ax, x, y, text, ha, va):
    """Write *text* at axes-fraction position (x, y) on a white background."""
    ax.text(x,
            y,
            text,
            fontsize=12,
            backgroundcolor='white',
            horizontalalignment=ha,
            verticalalignment=va,
            transform=ax.transAxes)


def _add_legend_entry(ax, y, marker, edgecolor, label):
    """Draw one legend symbol and its label at axes-fraction height *y*."""
    ax.scatter(0.03,
               y,
               s=240,
               marker=marker,
               facecolors='white',
               edgecolors=edgecolor,
               linewidth=0.5,
               transform=ax.transAxes)
    _add_axes_text(ax, 0.06, y, label, 'left', 'center')


def main():
    """Predict RGB magnitudes for the Gaia EDR3 stars inside a cone.

    The steps follow the console messages:

    1. cone search in Gaia EDR3 (G limit, non-null G/BP/RP photometry);
    2. optional StarHorse cross-match through the Gaia@AIP TAP service
       (only when ``--starhorse_block > 0``);
    3. cross-match with the auxiliary 15M-star FITS table (downloaded if
       missing);
    4. cone search in Gaia DR2 to find stars flagged ``VARIABLE``;
    5. DR2 -> EDR3 cross-match of those variables;
    6. RGB magnitude prediction from polynomials in G_BP - G_RP;
    7. write ``<basename>_edr3.csv``, ``<basename>_15m.csv`` and
       ``<basename>_var.csv``;
    8. unless ``--noplot``, write the ``<basename>.pdf`` finding chart.

    Raises
    ------
    SystemExit
        When called without arguments, on out-of-range arguments, when the
        auxiliary table cannot be read, when no EDR3 star is found, on an
        unexpected ``phot_variable_flag`` type, or (normally) after the CSV
        files are written if ``--noplot`` is given.
    """
    from contextlib import ExitStack

    parser = _build_parser()
    # Must run before parse_args(): argparse would otherwise abort with an
    # error about the missing positional arguments.
    if len(sys.argv) == 1:
        parser.print_usage()
        raise SystemExit()
    args = parser.parse_args()

    if args.ra_center < 0 or args.ra_center > 360:
        raise SystemExit('ERROR: right ascension out of valid range')
    if args.dec_center < -90 or args.dec_center > 90:
        raise SystemExit('ERROR: declination out of valid range')
    # A zero radius gives a zero pixel scale (division by zero further down).
    if args.search_radius <= 0:
        raise SystemExit('ERROR: search radius must be > 0 degrees')
    if args.search_radius > MAX_SEARCH_RADIUS:
        raise SystemExit(
            f'ERROR: search radius must be <= {MAX_SEARCH_RADIUS} degrees')

    # check whether the auxiliary FITS binary table exists
    if args.debug:
        auxbintable = RGB_FROM_GAIA_ALLSKY
    else:
        auxbintable = EDR3_SOURCE_ID_15M_ALLSKY
    if not os.path.isfile(auxbintable):
        urldir = f'http://nartex.fis.ucm.es/~ncl/rgbphot/gaia/{auxbintable}'
        sys.stdout.write(f'Downloading {urldir}... (please wait)')
        sys.stdout.flush()
        urllib.request.urlretrieve(urldir, auxbintable)
        print(' ...OK!')

    # read the auxiliary table; in debug mode also keep the reference values
    # that are appended (with their output format) to the 15M CSV rows
    debug_15m_columns = []
    try:
        with fits.open(auxbintable) as hdul_table:
            table_15m = hdul_table[1].data
            edr3_source_id_15M_allsky = table_15m.source_id
            if args.debug:
                debug_15m_columns = [
                    (table_15m.B_rgb, '6.2f'),
                    (table_15m.G_rgb, '6.2f'),
                    (table_15m.R_rgb, '6.2f'),
                    (table_15m.G_BR_rgb, '6.2f'),
                    (table_15m.G_gaia, '8.4f'),
                    (table_15m.BP_gaia, '8.4f'),
                    (table_15m.RP_gaia, '8.4f'),
                    (table_15m.av50, '7.3f'),
                    (table_15m.met50, '7.3f'),
                    (table_15m.dist50, '7.3f'),
                ]
    except FileNotFoundError as err:
        # report the file actually opened (it differs in debug mode)
        raise SystemExit(
            f'ERROR: unexpected problem while reading {auxbintable}') from err

    # define WCS: square TAN image centred on the field, whose side spans
    # the full search diameter
    naxis1 = 1024
    naxis2 = naxis1
    pixscale = 2 * args.search_radius / naxis1

    wcs_image = WCS(naxis=2)
    wcs_image.wcs.crpix = [naxis1 / 2, naxis2 / 2]
    wcs_image.wcs.crval = [args.ra_center, args.dec_center]
    wcs_image.wcs.cunit = ["deg", "deg"]
    wcs_image.wcs.ctype = ["RA---TAN", "DEC--TAN"]
    wcs_image.wcs.cdelt = [-pixscale, pixscale]
    wcs_image.array_shape = [naxis1, naxis2]
    if args.verbose:
        print(wcs_image)

    # ---

    # EDR3 query (canonical ADQL cone search: the source position inside a
    # circle around the field centre; same result set, index-friendly)
    query = f"""
    SELECT source_id, ra, dec,
    phot_g_mean_mag, phot_bp_mean_mag, phot_rp_mean_mag

    FROM gaiaedr3.gaia_source
    WHERE 1=CONTAINS(
      POINT('ICRS', ra, dec),
      CIRCLE('ICRS', {args.ra_center}, {args.dec_center}, {args.search_radius}))
    AND phot_g_mean_mag IS NOT NULL
    AND phot_bp_mean_mag IS NOT NULL
    AND phot_rp_mean_mag IS NOT NULL
    AND phot_g_mean_mag < {args.g_limit}

    ORDER BY ra
    """
    sys.stdout.write(
        '<STEP1> Starting cone search in Gaia EDR3... (please wait)\n  ')
    sys.stdout.flush()
    job = Gaia.launch_job_async(query)
    r_edr3 = job.get_results()
    # compute G_BP - G_RP colour
    r_edr3.add_column(
        Column(r_edr3['phot_bp_mean_mag'] - r_edr3['phot_rp_mean_mag'],
               name='bp_rp',
               unit=u.mag))
    nstars = len(r_edr3)
    print(f'        --> {nstars} stars found')
    nstars_colorcut = int(np.sum(_outside_colour_range(r_edr3['bp_rp'])))
    print(
        f'        --> {nstars_colorcut} stars outside -0.5 < G_BP-G_RP < 2.0')
    if nstars == 0:
        raise SystemExit('ERROR: no stars found. Change search parameters!')
    if args.verbose:
        r_edr3.pprint(max_width=1000)

    # ---

    # intersection with StarHorse star sample
    if args.starhorse_block > 0:
        param_starhorse = [
            'dr3_source_id', 'sh_gaiaflag', 'sh_outflag', 'dist05', 'dist16',
            'dist50', 'dist84', 'dist95', 'av05', 'av16', 'av50', 'av84',
            'av95', 'teff16', 'teff50', 'teff84', 'logg16', 'logg50', 'logg84',
            'met16', 'met50', 'met84', 'mass16', 'mass50', 'mass84', 'xgal',
            'ygal', 'zgal', 'rgal', 'ruwe', 'angular_distance',
            'magnitude_difference', 'proper_motion_propagation',
            'dup_max_number'
        ]
        print(
            '<STEP2> Retrieving StarHorse data from Gaia@AIP... (please wait)')
        print(f'        pyvo version {pyvo.__version__}')
        print('        TAP service GAIA@AIP')
        nstars_per_block = args.starhorse_block
        # ceiling division: the last block may hold fewer stars
        nblocks = (nstars + nstars_per_block - 1) // nstars_per_block
        # the TAP session/service do not depend on the block: build them once
        tap_session = requests.Session()
        # NOTE(review): placeholder Authorization value kept from the
        # original code; confirm whether Gaia@AIP requires a real token.
        tap_session.headers['Authorization'] = "kkk"
        tap_service = pyvo.dal.TAPService('https://gaia.aip.de/tap',
                                          session=tap_session)
        starhorse_tables = []
        for iblock in range(nblocks):
            irow1 = iblock * nstars_per_block
            irow2 = min(irow1 + nstars_per_block, nstars)
            print(f'        Starting query #{iblock+1} of {nblocks}...')
            dumstr = ','.join(
                str(item) for item in r_edr3[irow1:irow2]['source_id'])
            query = f"""
            SELECT {','.join(param_starhorse)}
            FROM gaiadr2_contrib.starhorse
            WHERE dr3_source_id IN ({dumstr})
            """
            block_table = tap_service.run_sync(query).to_table()
            if args.debug:
                print(block_table)
            starhorse_tables.append(block_table)
        # stack once instead of re-copying the accumulated table per block
        r_starhorse = vstack(starhorse_tables,
                             join_type='exact',
                             metadata_conflicts='silent')

        nstars_starhorse = len(r_starhorse)
        if args.verbose and nstars_starhorse > 0:
            r_starhorse.pprint(max_width=1000)

        # join tables
        print(f'        --> {nstars_starhorse} stars found in StarHorse')
        print('        Joining EDR3 and StarHorse queries...')
        r_starhorse.rename_column('dr3_source_id', 'source_id')
        r_edr3 = join(r_edr3, r_starhorse, keys='source_id', join_type='outer')
        r_edr3.sort('ra')
        if args.verbose:
            r_edr3.pprint(max_width=1000)

    else:
        print('<STEP2> Retrieving StarHorse data from Gaia@AIP... (skipped!)')

    # ---

    # intersection with 15M star sample
    sys.stdout.write(
        '<STEP3> Cross-matching EDR3 with 15M subsample... (please wait)')
    sys.stdout.flush()
    set1 = set(np.array(r_edr3['source_id']))
    set2 = set(edr3_source_id_15M_allsky)
    intersection = set2.intersection(set1)
    print(f'\n        --> {len(intersection)} stars in common with 15M sample')
    if args.verbose:
        print(len(set1), len(set2), len(intersection))

    # ---

    # DR2 query to identify variable stars
    query = f"""
    SELECT source_id, ra, dec, phot_g_mean_mag, phot_variable_flag

    FROM gaiadr2.gaia_source
    WHERE  1=CONTAINS(
      POINT('ICRS', ra, dec),
      CIRCLE('ICRS', {args.ra_center}, {args.dec_center}, {args.search_radius}))
    AND phot_g_mean_mag < {args.g_limit}
    """
    sys.stdout.write(
        '<STEP4> Looking for variable stars in Gaia DR2... (please wait)\n  ')
    sys.stdout.flush()
    job = Gaia.launch_job_async(query)
    r_dr2 = job.get_results()
    nstars_dr2 = len(r_dr2)
    if nstars_dr2 == 0:
        nvariables = 0
        mask_var = None
    else:
        # the flag column may come back as bytes or as str
        if isinstance(r_dr2['phot_variable_flag'][0], bytes):
            mask_var = r_dr2['phot_variable_flag'] == b'VARIABLE'
        elif isinstance(r_dr2['phot_variable_flag'][0], str):
            mask_var = r_dr2['phot_variable_flag'] == 'VARIABLE'
        else:
            raise SystemExit(
                'Unexpected type of data in column phot_variable_flag')
        nvariables = sum(mask_var)
        print(
            f'        --> {nstars_dr2} stars in DR2, ({nvariables} initial variables)'
        )
    if nvariables > 0 and args.verbose:
        r_dr2[mask_var].pprint(max_width=1000)

    # ---

    # cross-match between DR2 and EDR3 to identify the variable stars
    r_cross_var = None
    if nvariables > 0:
        # generate sequence of source_id of variable stars
        dumstr = ','.join(str(item) for item in r_dr2[mask_var]['source_id'])
        # cross-match
        query = f"""
        SELECT *
        FROM gaiaedr3.dr2_neighbourhood
        WHERE dr2_source_id IN ({dumstr})
        ORDER BY angular_distance
        """
        sys.stdout.write(
            '<STEP5> Cross-matching variables in DR2 with stars in EDR3... (please wait)\n  '
        )
        sys.stdout.flush()
        job = Gaia.launch_job_async(query)
        r_cross_var = job.get_results()
        if args.verbose:
            r_cross_var.pprint(max_width=1000)
        nvariables = len(r_cross_var)
        if nvariables > 0:
            # keep only matches belonging to the selected EDR3 sample (same
            # cone, G limit and non-null photometry); set1 holds its ids
            keep = [item in set1 for item in r_cross_var['dr3_source_id']]
            r_cross_var = r_cross_var[keep]
            nvariables = len(r_cross_var)
            if args.verbose:
                r_cross_var.pprint(max_width=1000)
        else:
            r_cross_var = None
    print(f'        --> {nvariables} variable(s) in selected EDR3 star sample')

    # ---

    sys.stdout.write('<STEP6> Computing RGB magnitudes...')
    sys.stdout.flush()
    _add_rgb_columns(r_edr3)
    print('OK')
    if args.verbose:
        r_edr3.pprint(max_width=1000)

    # ---

    sys.stdout.write('<STEP7> Saving output CSV files...')
    sys.stdout.flush()
    outtypes = ['edr3', '15m', 'var']
    outtypes_color = {'edr3': 'black', '15m': 'red', 'var': 'blue'}
    # number_csv: index into outtypes of the file each star is written to;
    # number_<type>: running star number within that file
    r_edr3.add_column(
        Column(np.zeros(len(r_edr3)), name='number_csv', dtype=int))
    for item in outtypes:
        r_edr3.add_column(
            Column(np.zeros(len(r_edr3)), name=f'number_{item}', dtype=int))
    outlist = [f'./{args.basename}_{ftype}.csv' for ftype in outtypes]
    # remove previous versions of the output files (if present)
    for file in outlist:
        if os.path.isfile(file):
            try:
                os.remove(file)
            except OSError:
                print(f'ERROR: while deleting existing file {file}')
    # columns to be saved and their format (dicts keep insertion order)
    outcolumns = {
        'source_id': '19d',
        'ra': '14.9f',
        'dec': '14.9f',
        'b_rgb': '6.2f',
        'g_rgb': '6.2f',
        'r_rgb': '6.2f',
        'g_br_rgb': '6.2f',
        'phot_g_mean_mag': '8.4f',
        'phot_bp_mean_mag': '8.4f',
        'phot_rp_mean_mag': '8.4f'
    }
    csv_header_ini = 'number,' + ','.join(outcolumns)
    var_ids = set(r_cross_var['dr3_source_id']) if nvariables > 0 else set()
    with ExitStack() as stack:
        flist = [stack.enter_context(open(fname, 'wt')) for fname in outlist]
        # headers must list exactly the columns appended to each row below
        for ftype, fout in zip(outtypes, flist):
            csv_header = csv_header_ini
            if (args.starhorse_block > 0) and (ftype in ['edr3', '15m']):
                csv_header += ',av50,met50,dist50'
            if args.debug and (ftype == '15m'):
                csv_header += ',b_rgb_bis,g_rgb_bis,r_rgb_bis,g_br_rgb_bis,' \
                              'phot_g_mean_mag_bis,phot_bp_mean_mag_bis,phot_rp_mean_mag_bis,' \
                              'av50_bis,met50_bis,dist50_bis'
            fout.write(csv_header + '\n')
        # save each star in its corresponding output file
        krow = np.ones(len(outtypes), dtype=int)
        for irow, row in enumerate(r_edr3):
            cout = [format(row[item], fmt) for item, fmt in outcolumns.items()]
            if row['source_id'] in var_ids:
                iout = 2
            else:
                iout = 0
                if args.starhorse_block > 0:
                    for item in ['av50', 'met50', 'dist50']:
                        value = row[item]
                        # masked (no StarHorse match) -> 99.999; np.floating
                        # also accepts float32 columns
                        if not isinstance(value, (float, np.floating)):
                            value = 99.999
                        cout.append(f'{value:7.3f}')
                if row['source_id'] in intersection:
                    iout = 1
                    if args.debug:
                        iloc = np.flatnonzero(
                            edr3_source_id_15M_allsky == row['source_id'])[0]
                        cout.extend(format(values[iloc], fmt)
                                    for values, fmt in debug_15m_columns)
            flist[iout].write(f'{krow[iout]:6d}, ' + ','.join(cout) + '\n')
            r_edr3[irow]['number_csv'] = iout
            r_edr3[irow][f'number_{outtypes[iout]}'] = krow[iout]
            krow[iout] += 1
    print('OK')

    if args.verbose:
        print(r_edr3)

    if args.noplot:
        raise SystemExit()

    # ---

    sys.stdout.write('<STEP8> Generating PDF plot...')
    sys.stdout.flush()
    # generate plot (brightest stars first)
    r_edr3.sort('phot_g_mean_mag')
    if args.verbose:
        print('')
        r_edr3.pprint(max_width=1000)

    symbol_size = args.symbsize * (50 /
                                   np.array(r_edr3['phot_g_mean_mag']))**2.5
    c = SkyCoord(ra=np.array(r_edr3['ra']) * u.degree,
                 dec=np.array(r_edr3['dec']) * u.degree,
                 frame='icrs')
    x_pix, y_pix = wcs_image.world_to_pixel(c)

    fig = plt.figure(figsize=(13, 10))
    ax = plt.subplot(projection=wcs_image)
    bright = np.asarray(r_edr3['phot_g_mean_mag'] < args.brightlimit)
    if args.nocolor:
        bright_style = {'color': 'grey'}
        faint_style = {'color': 'grey'}
    else:
        # colour-coded by G_BP - G_RP; passing the name avoids the
        # matplotlib.cm.get_cmap API removed in Matplotlib 3.9
        colour_scale = {'cmap': 'jet', 'vmin': -0.5, 'vmax': 2.0}
        bright_style = dict(colour_scale, c=r_edr3[bright]['bp_rp'])
        faint_style = dict(colour_scale, c=r_edr3[~bright]['bp_rp'])
    sc = ax.scatter(x_pix[bright],
                    y_pix[bright],
                    marker='*',
                    edgecolors='black',
                    linewidth=0.2,
                    s=symbol_size[bright],
                    **bright_style)
    ax.scatter(x_pix[~bright],
               y_pix[~bright],
               marker='.',
               edgecolors='black',
               linewidth=0.2,
               s=symbol_size[~bright],
               **faint_style)

    # display numbers if requested (colour identifies the CSV file)
    if not args.nonumbers:
        for irow, row in enumerate(r_edr3):
            ftype = outtypes[row['number_csv']]
            ax.text(x_pix[irow],
                    y_pix[irow],
                    row[f'number_{ftype}'],
                    color=outtypes_color[ftype],
                    fontsize='5',
                    horizontalalignment='left',
                    verticalalignment='bottom')

    # stars outside the -0.5 < G_BP - G_RP < 2.0 colour cut
    if nstars_colorcut > 0:
        outside = np.asarray(_outside_colour_range(r_edr3['bp_rp']))
        ax.scatter(x_pix[outside],
                   y_pix[outside],
                   s=240,
                   marker='D',
                   facecolors='none',
                   edgecolors='grey',
                   linewidth=0.5)

    # variable stars
    if nvariables > 0:
        iok = _rows_with_ids(r_edr3['source_id'], r_cross_var['dr3_source_id'])
        ax.scatter(x_pix[iok],
                   y_pix[iok],
                   s=240,
                   marker='s',
                   facecolors='none',
                   edgecolors='blue',
                   linewidth=0.5)

    # stars in 15M sample
    if len(intersection) > 0:
        iok = _rows_with_ids(r_edr3['source_id'], np.array(list(intersection)))
        ax.scatter(x_pix[iok],
                   y_pix[iok],
                   s=240,
                   marker='o',
                   facecolors='none',
                   edgecolors=outtypes_color['15m'],
                   linewidth=0.5)

    _add_legend_entry(ax, 0.96, 'o', outtypes_color['15m'],
                      'star in 15M sample')
    _add_legend_entry(ax, 0.92, 's', outtypes_color['var'],
                      'variable in Gaia DR2')
    _add_legend_entry(ax, 0.88, 'D', 'grey', 'outside colour range')

    ax.set_xlabel('ra')
    ax.set_ylabel('dec')

    ax.set_aspect('equal')

    if not args.nocolor:
        cbaxes = fig.add_axes([0.683, 0.81, 0.15, 0.02])
        cbar = plt.colorbar(sc,
                            cax=cbaxes,
                            orientation='horizontal',
                            format='%1.0f')
        cbar.ax.tick_params(labelsize=12)
        cbar.set_label(label=r'$G_{\rm BP}-G_{\rm RP}$',
                       size=12,
                       backgroundcolor='white')

    _add_axes_text(ax, 0.98, 0.96,
                   f'Field radius: {args.search_radius:.4f} degree', 'right',
                   'center')
    _add_axes_text(ax, 0.02, 0.06, r'$\alpha_{\rm center}$:', 'left',
                   'bottom')
    _add_axes_text(ax, 0.25, 0.06, f'{args.ra_center:.4f} degree', 'right',
                   'bottom')
    _add_axes_text(ax, 0.02, 0.02, r'$\delta_{\rm center}$:', 'left',
                   'bottom')
    _add_axes_text(ax, 0.25, 0.02, f'{args.dec_center:+.4f} degree', 'right',
                   'bottom')
    _add_axes_text(ax, 0.98, 0.02, f'RGBfromGaiaEDR3, version {VERSION}',
                   'right', 'bottom')

    # outline of the search circle (radius in pixels = search radius / scale)
    theta = np.radians(np.arange(361))
    radius_pix = args.search_radius / pixscale
    xp = naxis1 / 2 + radius_pix * np.cos(theta)
    yp = naxis2 / 2 + radius_pix * np.sin(theta)
    ax.plot(xp, yp, '-', color='orange', linewidth=0.5, alpha=0.5)

    ax.set_xlim([-naxis1 * 0.12, naxis1 * 1.12])
    ax.set_ylim([-naxis2 * 0.05, naxis2 * 1.05])

    ax.set_axisbelow(True)
    overlay = ax.get_coords_overlay('icrs')
    overlay.grid(color='black', ls='dotted')

    plt.savefig(f'{args.basename}.pdf')
    plt.close(fig)
    if not args.verbose:
        print('OK')