Beispiel #1
0
def unwise_forcedphot(cat, tiles, bands=(1, 2, 3, 4), roiradecbox=None,
                      unwise_dir='.',
                      use_ceres=True, ceres_block=8,
                      save_fits=False, ps=None,
                      psf_broadening=None):
    '''
    Given a list of tractor sources *cat*
    and a list of unWISE tiles *tiles* (a fits_table with RA,Dec,coadd_id)
    runs forced photometry, returning a FITS table the same length as *cat*.

    Each band in *bands* is processed independently; within a band, each
    tile is fit separately and, for sources seen in several overlapping
    tiles, the fit from the tile whose center is closest is retained.

    NOTE: the sources in *cat* are modified in place: their model sizes
    are clamped (``fixedRadius`` / ``halfsize`` = 10) and their 'w'-band
    brightnesses are updated by the fits.

    *cat*: list of tractor sources.
    *tiles*: iterable of tiles; each must provide ``coadd_id``, ``ra``
        and ``dec``.
    *bands*: iterable of integer WISE band numbers to photometer.
    *roiradecbox*: passed through to ``get_unwise_tractor_image``.
    *unwise_dir*: passed through to ``get_unwise_tractor_image``.
    *use_ceres*: if True, use the Ceres optimizer with block size
        *ceres_block* x *ceres_block*.
    *save_fits*: if True, write ``fit-<coadd_id>-w<band>-{data,mod,chi}.fits``
        in the current directory for every tile that is fit.
    *ps*: optional plot-sequence object (must provide ``savefig()``); when
        given, data / model / chi images are plotted for every tile.
    *psf_broadening*: optional factor by which to scale the PSF FWHMs;
        only applied when the tile PSF is a ``GaussianMixturePSF``.

    Returns a fits_table with a ``tile`` column and, per band ``wN``, the
    columns ``wN_nanomaggies``, ``wN_nanomaggies_ivar``, ``wN_mag``,
    ``wN_mag_err``, ``wN_nexp``, one ``wN_<key>`` column per fit
    statistic, and ``wN_mjd`` when any MJD information was available.

    Raises RuntimeError if the Ceres optimizer reports a non-zero
    termination status.
    '''

    # Severely limit sizes of models
    for src in cat:
        if isinstance(src, PointSource):
            src.fixedRadius = 10
        else:
            src.halfsize = 10

    wantims = ((ps is not None) or save_fits)
    # All WISE bands are fit through a single generic band name on the
    # sources; the per-band results are copied out after each band's fit.
    wanyband = 'w'

    fskeys = ['prochi2', 'pronpix', 'profracflux', 'proflux', 'npix',
              'pronexp']

    Nsrcs = len(cat)
    phot = fits_table()
    phot.tile = np.array(['        '] * Nsrcs)

    ra = np.array([src.getPosition().ra for src in cat])
    dec = np.array([src.getPosition().dec for src in cat])

    for band in bands:
        print('Photometering WISE band', band)
        wband = 'w%i' % band

        # The tiles have some overlap, so for each source, keep the
        # fit in the tile whose center is closest to the source.
        tiledists = np.empty(Nsrcs)
        tiledists[:] = 1e100
        flux_invvars = np.zeros(Nsrcs, np.float32)
        fitstats = {k: np.zeros(Nsrcs, np.float32) for k in fskeys}
        nexp = np.zeros(Nsrcs, np.int16)
        mjd = np.zeros(Nsrcs, np.float64)

        for tile in tiles:
            print('Reading tile', tile.coadd_id)

            tim = get_unwise_tractor_image(unwise_dir, tile.coadd_id, band,
                                           bandname=wanyband, roiradecbox=roiradecbox)
            if tim is None:
                print('Actually, no overlap with tile', tile.coadd_id)
                continue

            if psf_broadening is not None:
                # psf_broadening is a factor by which the PSF FWHMs
                # should be scaled; the PSF is a little wider
                # post-reactivation.
                psf = tim.getPsf()
                from tractor import GaussianMixturePSF
                if isinstance(psf, GaussianMixturePSF):
                    print('Broadening PSF: from', psf)
                    p0 = psf.getParams()
                    pnames = psf.getParamNames()
                    # Scaling the FWHM by f scales the variances by f^2;
                    # only parameters whose name contains 'var' are touched.
                    p1 = [p * psf_broadening**2 if 'var' in name else p
                          for (p, name) in zip(p0, pnames)]
                    psf.setParams(p1)
                    print('Broadened PSF:', psf)
                else:
                    print(
                        'WARNING: cannot apply psf_broadening to WISE PSF of type', type(psf))

            print('Read image with shape', tim.shape)

            # Select sources in play.
            wcs = tim.wcs.wcs
            H, W = tim.shape
            _, x, y = wcs.radec2pixelxy(ra, dec)
            # Convert from FITS 1-indexed to 0-indexed pixel coordinates.
            x = (x - 1.).astype(np.float32)
            y = (y - 1.).astype(np.float32)
            margin = 10.
            I = np.flatnonzero((x >= -margin) * (x < W + margin) *
                               (y >= -margin) * (y < H + margin))
            print(len(I), 'within the image + margin')

            inbox = ((x[I] >= -0.5) * (x[I] < (W - 0.5)) *
                     (y[I] >= -0.5) * (y[I] < (H - 0.5)))
            print(np.sum(inbox), 'strictly within the image')

            # Compute L_inf distance to (full) tile center.
            tilewcs = unwise_tile_wcs(tile.ra, tile.dec)
            cx, cy = tilewcs.crpix
            _, tx, ty = tilewcs.radec2pixelxy(ra[I], dec[I])
            td = np.maximum(np.abs(tx - cx), np.abs(ty - cy))
            closest = (td < tiledists[I])
            tiledists[I[closest]] = td[closest]

            keep = inbox * closest

            # Source indices (in the full "cat") to keep (the fit values for)
            srci = I[keep]

            if not len(srci):
                print('No sources to be kept; skipping.')
                continue

            phot.tile[srci] = tile.coadd_id
            nexp[srci] = tim.nuims[np.clip(np.round(y[srci]).astype(int), 0, H - 1),
                                   np.clip(np.round(x[srci]).astype(int), 0, W - 1)]

            # Source indices in the margins
            margi = I[np.logical_not(keep)]

            # sources in the box -- at the start of the subcat list.
            subcat = [cat[i] for i in srci]

            # include *copies* of sources in the margins
            # (that way we automatically don't save the results)
            subcat.extend([cat[i].copy() for i in margi])
            assert(len(subcat) == len(I))

            # FIXME -- set source radii, ...?

            minsb = 0.
            fitsky = False

            # Look in image and set radius based on peak height??

            tractor = Tractor([tim], subcat)
            if use_ceres:
                from tractor.ceres_optimizer import CeresOptimizer
                tractor.optimizer = CeresOptimizer(BW=ceres_block,
                                                   BH=ceres_block)
            # Fit only the 'w'-band fluxes; everything else stays frozen.
            tractor.freezeParamsRecursive('*')
            tractor.thawPathsTo(wanyband)

            kwa = dict(fitstat_extras=[('pronexp', [tim.nims])])
            t0 = Time()

            R = tractor.optimize_forced_photometry(
                minsb=minsb, mindlnp=1., sky=fitsky, fitstats=True,
                variance=True, shared_params=False,
                wantims=wantims, **kwa)
            print('unWISE forced photometry took', Time() - t0)

            if use_ceres:
                term = R.ceres_status['termination']
                print('Ceres termination status:', term)
                # Running out of memory can cause failure to converge
                # and term status = 2.
                # Fail completely in this case.
                if term != 0:
                    raise RuntimeError(
                        'Ceres terminated with status %i' % term)

            if wantims:
                ims1 = R.ims1
            IV, fs = R.IV, R.fitstats

            if save_fits:
                import fitsio
                (dat, mod, _, chi, _) = ims1[0]
                wcshdr = fitsio.FITSHDR()
                tim.wcs.wcs.add_to_header(wcshdr)

                tag = 'fit-%s-w%i' % (tile.coadd_id, band)
                fitsio.write('%s-data.fits' %
                             tag, dat, clobber=True, header=wcshdr)
                fitsio.write('%s-mod.fits' % tag,  mod,
                             clobber=True, header=wcshdr)
                fitsio.write('%s-chi.fits' % tag,  chi,
                             clobber=True, header=wcshdr)

            if ps:
                tag = '%s W%i' % (tile.coadd_id, band)
                (dat, mod, _, chi, _) = ims1[0]

                sig1 = tim.sig1
                plt.clf()
                plt.imshow(dat, interpolation='nearest', origin='lower',
                           cmap='gray', vmin=-3 * sig1, vmax=10 * sig1)
                plt.colorbar()
                plt.title('%s: data' % tag)
                ps.savefig()

                plt.clf()
                plt.imshow(mod, interpolation='nearest', origin='lower',
                           cmap='gray', vmin=-3 * sig1, vmax=10 * sig1)
                plt.colorbar()
                plt.title('%s: model' % tag)
                ps.savefig()

                plt.clf()
                plt.imshow(chi, interpolation='nearest', origin='lower',
                           cmap='gray', vmin=-5, vmax=+5)
                plt.colorbar()
                plt.title('%s: chi' % tag)
                ps.savefig()

            # Save results for this tile.
            # the "keep" sources are at the beginning of the "subcat" list
            flux_invvars[srci] = IV[:len(srci)].astype(np.float32)
            if hasattr(tim, 'mjdmin') and hasattr(tim, 'mjdmax'):
                mjd[srci] = (tim.mjdmin + tim.mjdmax) / 2.
            if fs is None:
                continue
            for k in fskeys:
                vals = getattr(fs, k)
                # fitstats are returned only for un-frozen sources
                fitstats[k][srci] = np.array(vals).astype(np.float32)[:len(srci)]

        # Note, this is *outside* the loop over tiles.
        # The fluxes are saved in the source objects, and will be set based on
        # the 'tiledists' logic above.
        nm = np.array([src.getBrightness().getBand(wanyband) for src in cat])
        nm_ivar = flux_invvars

        # Sources out of bounds, eg, never change from their default
        # (1-sigma or whatever) initial fluxes.  Zero them out instead.
        nm[nm_ivar == 0] = 0.

        phot.set(wband + '_nanomaggies', nm.astype(np.float32))
        phot.set(wband + '_nanomaggies_ivar', nm_ivar)
        # Flux uncertainty, defined only where the inverse-variance is positive.
        dnm = np.zeros(len(nm_ivar), np.float32)
        okiv = (nm_ivar > 0)
        dnm[okiv] = (1. / np.sqrt(nm_ivar[okiv])).astype(np.float32)
        # Magnitudes are defined only for positive fluxes; NaN elsewhere.
        okflux = (nm > 0)
        mag = np.zeros(len(nm), np.float32)
        mag[okflux] = (NanoMaggies.nanomaggiesToMag(nm[okflux])
                       ).astype(np.float32)
        dmag = np.zeros(len(nm), np.float32)
        ok = (okiv * okflux)
        dmag[ok] = (np.abs((-2.5 / np.log(10.)) * dnm[ok] / nm[ok])
                    ).astype(np.float32)
        mag[np.logical_not(okflux)] = np.nan
        dmag[np.logical_not(ok)] = np.nan

        phot.set(wband + '_mag', mag)
        phot.set(wband + '_mag_err', dmag)
        for k in fskeys:
            phot.set(wband + '_' + k, fitstats[k])
        phot.set(wband + '_nexp', nexp)

        if not np.all(mjd == 0):
            phot.set(wband + '_mjd', mjd)

    return phot
Beispiel #2
0
def unwise_forcedphot(cat,
                      tiles,
                      band=1,
                      roiradecbox=None,
                      use_ceres=True,
                      ceres_block=8,
                      save_fits=False,
                      get_models=False,
                      ps=None,
                      psf_broadening=None,
                      pixelized_psf=False,
                      get_masks=None,
                      move_crpix=False,
                      modelsky_dir=None,
                      tag=None):
    '''
    Given a list of tractor sources *cat*
    and a list of unWISE tiles *tiles* (a fits_table with RA,Dec,coadd_id)
    runs forced photometry, returning a FITS table the same length as *cat*.

    All overlapping tiles are prepared (background subtraction, error
    flooring, saturation masking, trimming to each tile's unique area,
    PSF setup) and then fit simultaneously in a single Tractor run.

    NOTE: the sources in *cat* are modified in place: their model sizes
    (``fixedRadius`` / ``halfsize``) are set from the central pixel flux,
    and their 'w'-band brightnesses are updated by the fit.

    *cat*: list of tractor sources.
    *tiles*: iterable of tiles; each must provide ``coadd_id``,
        ``unwise_dir``, ``ra1``, ``ra2``, ``dec1``, ``dec2`` and, when
        *move_crpix* is set, ``crpix_w<band>``.
    *band*: integer WISE band number.
    *roiradecbox*: passed through to ``get_unwise_tractor_image``.
    *use_ceres*: if True, use the Ceres optimizer with block size
        *ceres_block* x *ceres_block*.
    *save_fits*: if True, write ``fit-<coadd_id>-w<band>-{data,mod,chi}.fits``
        in the current directory for every tile.
    *get_models*: if True, return per-tile
        ``(coadd_id, band, wcs, data, model, coadd_inverr)`` tuples.
    *ps*: optional plot-sequence object (must provide ``savefig()``);
        enables diagnostic plots.
    *psf_broadening*: factor by which to scale the PSF FWHMs.  When None
        and *pixelized_psf* is False, a per-band default is used.
    *pixelized_psf*: if True, replace the tile PSF with a pixelized model
        from the ``unwise_psf`` package.
    *get_masks*: the WCS to resample mask bits into.
    *move_crpix*: if True (bands 1 and 2 only), shift the image WCS CRPIX
        by the tile's ``crpix_w<band>`` offset from 1024.5.
    *modelsky_dir*: if set (bands 1 and 2 only), directory containing
        ``<coadd_id>.<band>.mod.fits`` background models to subtract.
    *tag*: optional prefix for log messages.

    Returns a ``wphotduck`` object with attributes ``phot`` (the
    fits_table), ``models`` (list, or None unless *get_models*) and
    ``maskmap`` (uint32 array, or None unless *get_masks*).

    Raises RuntimeError if a required sky-model file is missing, or if
    the Ceres optimizer reports a non-zero termination status.
    '''
    from tractor import PointSource, Tractor, ExpGalaxy, DevGalaxy
    from tractor.sersic import SersicGalaxy

    # *tag* is used only as a log-message prefix; plot titles and output
    # filenames below use their own locals so the prefix is never clobbered.
    if tag is None:
        tag = ''
    else:
        tag = tag + ': '
    if not pixelized_psf and psf_broadening is None:
        # PSF broadening in post-reactivation data, by band.
        # Newer version from Aaron's email to decam-chatter, 2018-06-14.
        broadening = {1: 1.0405, 2: 1.0346, 3: None, 4: None}
        psf_broadening = broadening[band]

    plots = (ps is not None)
    if plots:
        import pylab as plt

    wantims = (plots or save_fits or get_models)
    # The fit is done through a generic band name on the sources.
    wanyband = 'w'
    if get_models:
        models = []

    wband = 'w%i' % band

    Nsrcs = len(cat)
    phot = fits_table()
    # Filled in based on unique tile overlap
    phot.wise_coadd_id = np.array(['        '] * Nsrcs, dtype='U8')
    phot.wise_x = np.zeros(Nsrcs, np.float32)
    phot.wise_y = np.zeros(Nsrcs, np.float32)
    phot.set('psfdepth_%s' % wband, np.zeros(Nsrcs, np.float32))
    nexp = np.zeros(Nsrcs, np.int16)
    mjd = np.zeros(Nsrcs, np.float64)
    central_flux = np.zeros(Nsrcs, np.float32)

    ra = np.array([src.getPosition().ra for src in cat])
    dec = np.array([src.getPosition().dec for src in cat])

    fskeys = ['prochi2', 'profracflux']
    fitstats = {}

    if get_masks:
        mh, mw = get_masks.shape
        maskmap = np.zeros((mh, mw), np.uint32)

    tims = []
    for tile in tiles:
        info(tag + 'Reading WISE tile', tile.coadd_id, 'band', band)
        tim = get_unwise_tractor_image(tile.unwise_dir,
                                       tile.coadd_id,
                                       band,
                                       bandname=wanyband,
                                       roiradecbox=roiradecbox)
        if tim is None:
            debug('Actually, no overlap with WISE coadd tile', tile.coadd_id)
            continue

        # Title used for all diagnostic plots of this tile.
        plot_tag = '%s W%i' % (tile.coadd_id, band)

        if plots:
            sig1 = tim.sig1
            plt.clf()
            plt.imshow(tim.getImage(),
                       interpolation='nearest',
                       origin='lower',
                       cmap='gray',
                       vmin=-3 * sig1,
                       vmax=10 * sig1)
            plt.colorbar()
            plt.title('%s: tim data' % plot_tag)
            ps.savefig()
            plt.clf()
            plt.hist((tim.getImage() * tim.inverr)[tim.inverr > 0].ravel(),
                     range=(-5, 10),
                     bins=100)
            plt.xlabel('Per-pixel intensity (Sigma)')
            plt.title(plot_tag)
            ps.savefig()

        if move_crpix and band in [1, 2]:
            realwcs = tim.wcs.wcs
            x, y = realwcs.crpix
            tile_crpix = tile.get('crpix_w%i' % band)
            # Offset of the tile's recorded CRPIX from 1024.5.
            dx = tile_crpix[0] - 1024.5
            dy = tile_crpix[1] - 1024.5
            realwcs.set_crpix(x + dx, y + dy)
            debug('unWISE', tile.coadd_id, 'band', band, 'CRPIX', x, y,
                  'shift by', dx, dy, 'to', realwcs.crpix)

        if modelsky_dir and band in [1, 2]:
            fn = os.path.join(modelsky_dir,
                              '%s.%i.mod.fits' % (tile.coadd_id, band))
            if not os.path.exists(fn):
                raise RuntimeError(
                    'unWISE sky-model file does not exist: %s' % fn)
            x0, x1, y0, y1 = tim.roi
            bg = fitsio.FITS(fn)[2][y0:y1, x0:x1]
            assert (bg.shape == tim.shape)

            if plots:
                plt.clf()
                plt.subplot(1, 2, 1)
                plt.imshow(tim.getImage(),
                           interpolation='nearest',
                           origin='lower',
                           cmap='gray',
                           vmin=-3 * sig1,
                           vmax=5 * sig1)
                plt.subplot(1, 2, 2)
                plt.imshow(bg,
                           interpolation='nearest',
                           origin='lower',
                           cmap='gray',
                           vmin=-3 * sig1,
                           vmax=5 * sig1)
                plt.suptitle(plot_tag)
                ps.savefig()
                plt.clf()
                ha = dict(range=(-5, 10), bins=100, histtype='step')
                plt.hist((tim.getImage() * tim.inverr)[tim.inverr > 0].ravel(),
                         color='b',
                         label='Original',
                         **ha)
                plt.hist(((tim.getImage() - bg) *
                          tim.inverr)[tim.inverr > 0].ravel(),
                         color='g',
                         label='Minus Background',
                         **ha)
                plt.axvline(0, color='k', alpha=0.5)
                plt.xlabel('Per-pixel intensity (Sigma)')
                plt.legend()
                plt.title(plot_tag + ': background')
                ps.savefig()

            # Actually subtract the background!
            tim.data -= bg

        # Floor the per-pixel variances,
        # and add Poisson contribution from sources
        if band in [1, 2]:
            # in Vega nanomaggies per pixel
            floor_sigma = {1: 0.5, 2: 2.0}
            poissons = {1: 0.15, 2: 0.3}
            # NOTE(review): floor_sigma[band] is added *unsquared* to a sum
            # of variances, while the other two terms are variance-like.
            # This looks like it may be intended to be floor_sigma[band]**2;
            # left unchanged here because it alters the photometry -- confirm.
            with np.errstate(divide='ignore'):
                new_ie = 1. / np.sqrt(
                    (1. / tim.inverr)**2 + floor_sigma[band] +
                    poissons[band]**2 * np.maximum(0., tim.data))
            # Pixels that were masked (inverr == 0) stay masked.
            new_ie[tim.inverr == 0] = 0.

            if plots:
                plt.clf()
                plt.plot((1. / tim.inverr[tim.inverr > 0]).ravel(),
                         (1. / new_ie[tim.inverr > 0]).ravel(), 'b.')
                plt.title('unWISE per-pixel error: %s band %i' %
                          (tile.coadd_id, band))
                plt.xlabel('original')
                plt.ylabel('floored')
                ps.savefig()

            assert (np.all(np.isfinite(new_ie)))
            assert (np.all(new_ie >= 0.))
            tim.inverr = new_ie

            # Expand a 3-pixel radius around weight=0 (saturated) pixels
            # from Eddie via crowdsource
            # https://github.com/schlafly/crowdsource/blob/7069da3e7d9d3124be1cbbe1d21ffeb63fc36dcc/python/wise_proc.py#L74
            ## FIXME -- W3/W4 ??
            satlimit = 85000
            msat = ((tim.data > satlimit) | ((tim.nims == 0) &
                                             (tim.nuims > 1)))
            from scipy.ndimage import binary_dilation
            xx, yy = np.mgrid[-3:3 + 1, -3:3 + 1]
            # Circular structuring element of radius 3 pixels.
            dilate = xx**2 + yy**2 <= 3**2
            msat = binary_dilation(msat, dilate)
            nbefore = np.sum(tim.inverr == 0)
            tim.inverr[msat] = 0
            nafter = np.sum(tim.inverr == 0)
            debug('Masking an additional', (nafter - nbefore),
                  'near-saturated pixels in unWISE', tile.coadd_id, 'band',
                  band)

        # Read mask file?
        if get_masks:
            from astrometry.util.resample import resample_with_wcs, OverlapError
            # unwise_dir can be a colon-separated list of paths
            tilemask = None
            for d in tile.unwise_dir.split(':'):
                fn = os.path.join(d, tile.coadd_id[:3], tile.coadd_id,
                                  'unwise-%s-msk.fits.gz' % tile.coadd_id)
                if os.path.exists(fn):
                    debug('Reading unWISE mask file', fn)
                    x0, x1, y0, y1 = tim.roi
                    tilemask = fitsio.FITS(fn)[0][y0:y1, x0:x1]
                    break
            if tilemask is None:
                info('unWISE mask file for tile', tile.coadd_id,
                     'does not exist')
            else:
                try:
                    tanwcs = tim.wcs.wcs
                    assert (tanwcs.shape == tilemask.shape)
                    Yo, Xo, Yi, Xi, _ = resample_with_wcs(get_masks,
                                                          tanwcs,
                                                          intType=np.int16)
                    # Only deal with mask pixels that are set.
                    I, = np.nonzero(tilemask[Yi, Xi] > 0)
                    # Trim to unique area for this tile
                    rr, dd = get_masks.pixelxy2radec(Xo[I] + 1, Yo[I] + 1)
                    good = radec_in_unique_area(rr, dd, tile.ra1, tile.ra2,
                                                tile.dec1, tile.dec2)
                    I = I[good]
                    maskmap[Yo[I], Xo[I]] = tilemask[Yi[I], Xi[I]]
                except OverlapError:
                    # Shouldn't happen by this point
                    print('Warning: no overlap between WISE tile',
                          tile.coadd_id, 'and brick')

            if plots:
                # The tile mask may be missing; only plot it when it was read.
                if tilemask is not None:
                    plt.clf()
                    plt.imshow(tilemask, interpolation='nearest',
                               origin='lower')
                    plt.title('Tile %s: mask' % tile.coadd_id)
                    ps.savefig()
                plt.clf()
                plt.imshow(maskmap, interpolation='nearest', origin='lower')
                plt.title('Tile %s: accumulated maskmap' % tile.coadd_id)
                ps.savefig()

        # The tiles have some overlap, so zero out pixels outside the
        # tile's unique area.
        th, tw = tim.shape
        xx, yy = np.meshgrid(np.arange(tw), np.arange(th))
        rr, dd = tim.wcs.wcs.pixelxy2radec(xx + 1, yy + 1)
        unique = radec_in_unique_area(rr, dd, tile.ra1, tile.ra2, tile.dec1,
                                      tile.dec2)
        debug('Tile', tile.coadd_id, '- total of', np.sum(unique),
              'unique pixels out of', len(unique.flat), 'total pixels')
        if get_models:
            # Save the inverr before blanking out non-unique pixels, for making coadds with no gaps!
            # (actually, slightly more subtly, expand unique area by 1 pixel)
            from scipy.ndimage import binary_dilation
            du = binary_dilation(unique)
            tim.coadd_inverr = tim.inverr * du
        tim.inverr[np.logical_not(unique)] = 0.
        del xx, yy, rr, dd, unique

        if plots:
            sig1 = tim.sig1
            plt.clf()
            plt.imshow(tim.getImage() * (tim.inverr > 0),
                       interpolation='nearest',
                       origin='lower',
                       cmap='gray',
                       vmin=-3 * sig1,
                       vmax=10 * sig1)
            plt.colorbar()
            plt.title('%s: tim data (unique)' % plot_tag)
            ps.savefig()

        if pixelized_psf:
            from unwise_psf import unwise_psf
            if (band == 1) or (band == 2):
                # we only have updated PSFs for W1 and W2
                psfimg = unwise_psf.get_unwise_psf(band,
                                                   tile.coadd_id,
                                                   modelname='neo6_unwisecat')
            else:
                psfimg = unwise_psf.get_unwise_psf(band, tile.coadd_id)

            if band == 4:
                # oversample (the unwise_psf models are at native W4 5.5"/pix,
                # while the unWISE coadds are made at 2.75"/pix.
                ph, pw = psfimg.shape
                subpsf = np.zeros((ph * 2 - 1, pw * 2 - 1), np.float32)
                from astrometry.util.util import lanczos3_interpolate
                xx, yy = np.meshgrid(
                    np.arange(0., pw - 0.51, 0.5, dtype=np.float32),
                    np.arange(0., ph - 0.51, 0.5, dtype=np.float32))
                xx = xx.ravel()
                yy = yy.ravel()
                ix = xx.astype(np.int32)
                iy = yy.astype(np.int32)
                dx = (xx - ix).astype(np.float32)
                dy = (yy - iy).astype(np.float32)
                psfimg = psfimg.astype(np.float32)
                # Writes the interpolated values into subpsf (via its flat view).
                lanczos3_interpolate(ix, iy, dx, dy, [subpsf.flat],
                                     [psfimg])

                if plots:
                    plt.clf()
                    plt.imshow(psfimg, interpolation='nearest', origin='lower')
                    plt.title('Original PSF model')
                    ps.savefig()
                    plt.clf()
                    plt.imshow(subpsf, interpolation='nearest', origin='lower')
                    plt.title('Subsampled PSF model')
                    ps.savefig()

                psfimg = subpsf
                del xx, yy, ix, iy, dx, dy

            from tractor.psf import PixelizedPSF
            # Normalize to unit sum, then apply the per-band flux rescaling.
            psfimg /= psfimg.sum()
            fluxrescales = {1: 1.04, 2: 1.005, 3: 1.0, 4: 1.0}
            psfimg *= fluxrescales[band]
            tim.psf = PixelizedPSF(psfimg)

        if psf_broadening is not None and not pixelized_psf:
            # psf_broadening is a factor by which the PSF FWHMs
            # should be scaled; the PSF is a little wider
            # post-reactivation.
            psf = tim.getPsf()
            from tractor import GaussianMixturePSF
            if isinstance(psf, GaussianMixturePSF):
                debug('Broadening PSF: from', psf)
                p0 = psf.getParams()
                pnames = psf.getParamNames()
                # Scaling the FWHM by f scales the variances by f^2;
                # only parameters whose name contains 'var' are touched.
                p1 = [
                    p * psf_broadening**2 if 'var' in name else p
                    for (p, name) in zip(p0, pnames)
                ]
                psf.setParams(p1)
                debug('Broadened PSF:', psf)
            else:
                print(
                    'WARNING: cannot apply psf_broadening to WISE PSF of type',
                    type(psf))

        wcs = tim.wcs.wcs
        _, fx, fy = wcs.radec2pixelxy(ra, dec)
        # Nearest 0-indexed integer pixel for each source.
        x = np.round(fx - 1.).astype(int)
        y = np.round(fy - 1.).astype(int)
        good = (x >= 0) * (x < tw) * (y >= 0) * (y < th)
        # Which sources are in this brick's unique area?
        usrc = radec_in_unique_area(ra, dec, tile.ra1, tile.ra2, tile.dec1,
                                    tile.dec2)
        I, = np.nonzero(good * usrc)

        nexp[I] = tim.nuims[y[I], x[I]]
        if hasattr(tim, 'mjdmin') and hasattr(tim, 'mjdmax'):
            mjd[I] = (tim.mjdmin + tim.mjdmax) / 2.
        phot.wise_coadd_id[I] = tile.coadd_id
        phot.wise_x[I] = fx[I] - 1.
        phot.wise_y[I] = fy[I] - 1.

        central_flux[I] = tim.getImage()[y[I], x[I]]
        del x, y, good, usrc

        # PSF norm for depth
        psf = tim.getPsf()
        h, w = tim.shape
        patch = psf.getPointSourcePatch(h // 2, w // 2).patch
        psfnorm = np.sqrt(np.sum(patch**2))
        # To handle zero-depth, we return 1/nanomaggies^2 units rather than mags.
        # In the small empty patches of the sky (eg W4 in 0922p702), we get sig1 = NaN
        if np.isfinite(tim.sig1):
            phot.get('psfdepth_%s' % wband)[I] = 1. / (tim.sig1 / psfnorm)**2

        # Remember which tile this image came from, for the output stages.
        tim.tile = tile
        tims.append(tim)

    if plots:
        plt.clf()
        mn, mx = 0.1, 20000
        plt.hist(np.log10(np.clip(central_flux, mn, mx)),
                 bins=100,
                 range=(np.log10(mn), np.log10(mx)))
        logt = np.arange(0, 5)
        plt.xticks(logt, ['%i' % i for i in 10.**logt])
        plt.title('Central fluxes (W%i)' % band)
        plt.axvline(np.log10(20000), color='k')
        plt.axvline(np.log10(1000), color='k')
        ps.savefig()

    # Eddie's non-secret recipe:
    #- central pixel <= 1000: 19x19 pix box size
    #- central pixel in 1000 - 20000: 59x59 box size
    #- central pixel > 20000 or saturated: 149x149 box size
    #- object near "bright star": 299x299 box size
    # NOTE(review): the radii assigned below (15 / 30 / 100) do not match
    # the box sizes quoted in the recipe above -- confirm which is intended.
    nbig = nmedium = nsmall = 0
    for src, cflux in zip(cat, central_flux):
        if cflux > 20000:
            R = 100
            nbig += 1
        elif cflux > 1000:
            R = 30
            nmedium += 1
        else:
            R = 15
            nsmall += 1
        if isinstance(src, PointSource):
            src.fixedRadius = R
        else:
            ### FIXME -- sizes for galaxies..... can we set PSF size separately?
            galrad = 0
            # RexGalaxy is a subclass of ExpGalaxy
            if isinstance(src, (ExpGalaxy, DevGalaxy, SersicGalaxy)):
                galrad = src.shape.re
            pixscale = 2.75
            src.halfsize = int(np.hypot(R, galrad * 5 / pixscale))
    debug('Set WISE source sizes:', nbig, 'big', nmedium, 'medium', nsmall,
          'small')

    tractor = Tractor(tims, cat)
    if use_ceres:
        from tractor.ceres_optimizer import CeresOptimizer
        tractor.optimizer = CeresOptimizer(BW=ceres_block, BH=ceres_block)
    # Fit only the 'w'-band fluxes; everything else stays frozen.
    tractor.freezeParamsRecursive('*')
    tractor.thawPathsTo(wanyband)

    t0 = Time()
    R = tractor.optimize_forced_photometry(fitstats=True,
                                           variance=True,
                                           shared_params=False,
                                           wantims=wantims)
    info(tag + 'unWISE forced photometry took', Time() - t0)

    if use_ceres:
        term = R.ceres_status['termination']
        # Running out of memory can cause failure to converge and term
        # status = 2.  Fail completely in this case.
        if term != 0:
            info(tag + 'Ceres termination status:', term)
            raise RuntimeError('Ceres terminated with status %i' % term)

    if wantims:
        ims1 = R.ims1
        # can happen if empty source list (we still want to generate coadds)
        if ims1 is None:
            ims1 = R.ims0

    flux_invvars = R.IV
    if R.fitstats is not None:
        for k in fskeys:
            vals = getattr(R.fitstats, k)
            fitstats[k] = np.array(vals).astype(np.float32)

    if save_fits:
        for i, tim in enumerate(tims):
            tile = tim.tile
            (dat, mod, _, chi, _) = ims1[i]
            wcshdr = fitsio.FITSHDR()
            tim.wcs.wcs.add_to_header(wcshdr)
            fits_tag = 'fit-%s-w%i' % (tile.coadd_id, band)
            fitsio.write('%s-data.fits' % fits_tag,
                         dat,
                         clobber=True,
                         header=wcshdr)
            fitsio.write('%s-mod.fits' % fits_tag, mod, clobber=True,
                         header=wcshdr)
            fitsio.write('%s-chi.fits' % fits_tag, chi, clobber=True,
                         header=wcshdr)

    if plots:
        # Create models for just the brightest sources
        bright_cat = [
            src for src in cat if src.getBrightness().getBand(wanyband) > 1000
        ]
        debug('Bright soures:', len(bright_cat))
        btr = Tractor(tims, bright_cat)
        for tim in tims:
            mod = btr.getModelImage(tim)
            tile = tim.tile
            plot_tag = '%s W%i' % (tile.coadd_id, band)
            sig1 = tim.sig1
            plt.clf()
            plt.imshow(mod,
                       interpolation='nearest',
                       origin='lower',
                       cmap='gray',
                       vmin=-3 * sig1,
                       vmax=25 * sig1)
            plt.colorbar()
            plt.title('%s: bright-star models' % plot_tag)
            ps.savefig()

    if get_models:
        for i, tim in enumerate(tims):
            tile = tim.tile
            (dat, mod, _, _, _) = ims1[i]
            models.append(
                (tile.coadd_id, band, tim.wcs.wcs, dat, mod, tim.coadd_inverr))

    if plots:
        for i, tim in enumerate(tims):
            tile = tim.tile
            plot_tag = '%s W%i' % (tile.coadd_id, band)
            (dat, mod, _, chi, _) = ims1[i]
            sig1 = tim.sig1
            plt.clf()
            plt.imshow(dat,
                       interpolation='nearest',
                       origin='lower',
                       cmap='gray',
                       vmin=-3 * sig1,
                       vmax=25 * sig1)
            plt.colorbar()
            plt.title('%s: data' % plot_tag)
            ps.savefig()
            plt.clf()
            plt.imshow(mod,
                       interpolation='nearest',
                       origin='lower',
                       cmap='gray',
                       vmin=-3 * sig1,
                       vmax=25 * sig1)
            plt.colorbar()
            plt.title('%s: model' % plot_tag)
            ps.savefig()

            plt.clf()
            plt.imshow(chi,
                       interpolation='nearest',
                       origin='lower',
                       cmap='gray',
                       vmin=-5,
                       vmax=+5)
            plt.colorbar()
            plt.title('%s: chi' % plot_tag)
            ps.savefig()

    nm = np.array([src.getBrightness().getBand(wanyband) for src in cat])
    nm_ivar = flux_invvars
    # Sources out of bounds, eg, never change from their initial
    # fluxes.  Zero them out instead.
    nm[nm_ivar == 0] = 0.

    phot.set('flux_%s' % wband, nm.astype(np.float32))
    phot.set('flux_ivar_%s' % wband, nm_ivar.astype(np.float32))
    for k in fskeys:
        # Fit statistics may be missing entirely; fall back to zeros.
        phot.set(k + '_' + wband,
                 fitstats.get(k, np.zeros(len(phot), np.float32)))
    phot.set('nobs_%s' % wband, nexp)
    phot.set('mjd_%s' % wband, mjd)

    rtn = wphotduck()
    rtn.phot = phot
    rtn.models = None
    rtn.maskmap = None
    if get_models:
        rtn.models = models
    if get_masks:
        rtn.maskmap = maskmap
    return rtn
Beispiel #3
0
def unwise_forcedphot(cat, tiles, band=1, roiradecbox=None,
                      use_ceres=True, ceres_block=8,
                      save_fits=False, get_models=False, ps=None,
                      psf_broadening=None,
                      pixelized_psf=False,
                      get_masks=None,
                      move_crpix=False,
                      modelsky_dir=None):
    '''
    Given a list of tractor sources *cat*
    and a list of unWISE tiles *tiles* (a fits_table with RA,Dec,coadd_id)
    runs forced photometry, returning a FITS table the same length as *cat*.

    Parameters
    ----------
    *cat*: list of tractor sources.  Modified in place: each source gets a
        model size (``fixedRadius`` or ``halfsize``) and its brightness in
        the generic band ``'w'`` is fit.
    *tiles*: table of unWISE tiles; this function reads ``coadd_id``,
        ``unwise_dir``, ``ra1``, ``ra2``, ``dec1``, ``dec2`` from each row,
        plus ``crpix_w<band>`` when *move_crpix* is set.
    *band*: integer WISE band (1-4).
    *roiradecbox*: passed through to ``get_unwise_tractor_image``.
    *use_ceres*, *ceres_block*: use the Ceres optimizer, with this block
        size in both dimensions.
    *save_fits*: write ``fit-<coadd_id>-w<band>-{data,mod,chi}.fits``
        into the current directory.
    *get_models*: also return a dict
        ``(coadd_id, band) -> (mod, dat, ie, roi, wcs)``.
    *ps*: plot sequence object (needs ``savefig()``); if not None,
        diagnostic plots are produced.
    *psf_broadening*: factor by which PSF FWHMs are scaled.  When None and
        *pixelized_psf* is False, a per-band default is used.
    *pixelized_psf*: use pixelized PSF models from ``unwise_psf`` rather
        than the PSF attached to the tractor image.
    *get_masks*: the WCS to resample mask bits into.
    *move_crpix*: for W1/W2, shift the image WCS CRPIX by the offset of the
        tile's ``crpix_w<band>`` from 1024.5.
    *modelsky_dir*: for W1/W2, directory holding
        ``<coadd_id>.<band>.mod.fits`` background models, which are
        subtracted from the image data.

    Returns
    -------
    A ``wphotduck`` object with attributes ``phot`` (the photometry table),
    ``models`` (dict, or None unless *get_models*) and ``maskmap`` (uint32
    array with the shape of *get_masks*, or None unless *get_masks*).

    Raises
    ------
    RuntimeError if a requested model-sky file does not exist, or if Ceres
    finishes with a non-zero termination status.
    '''
    from tractor import NanoMaggies, PointSource, Tractor, ExpGalaxy, DevGalaxy, FixedCompositeGalaxy

    if not pixelized_psf and psf_broadening is None:
        # PSF broadening in post-reactivation data, by band.
        # Newer version from Aaron's email to decam-chatter, 2018-06-14.
        broadening = { 1: 1.0405, 2: 1.0346, 3: None, 4: None }
        psf_broadening = broadening[band]

    # Developer toggle: flip to True to force diagnostic plots.
    if False:
        from astrometry.util.plotutils import PlotSequence
        ps = PlotSequence('wise-forced-w%i' % band)
    plots = (ps is not None)
    if plots:
        import pylab as plt

    wantims = (plots or save_fits or get_models)
    # All WISE images are fit under one generic band name.
    wanyband = 'w'
    if get_models:
        models = {}

    wband = 'w%i' % band

    fskeys = ['prochi2', 'pronpix', 'profracflux', 'proflux', 'npix',
              'pronexp']

    Nsrcs = len(cat)
    phot = fits_table()
    # Filled in based on unique tile overlap
    phot.wise_coadd_id = np.array(['        '] * Nsrcs)
    phot.set(wband + '_psfdepth', np.zeros(len(phot), np.float32))

    ra  = np.array([src.getPosition().ra  for src in cat])
    dec = np.array([src.getPosition().dec for src in cat])

    nexp = np.zeros(Nsrcs, np.int16)
    mjd  = np.zeros(Nsrcs, np.float64)
    central_flux = np.zeros(Nsrcs, np.float32)

    fitstats = {}
    tims = []

    if get_masks:
        mh,mw = get_masks.shape
        maskmap = np.zeros((mh,mw), np.uint32)

    for tile in tiles:
        print('Reading WISE tile', tile.coadd_id, 'band', band)

        tim = get_unwise_tractor_image(tile.unwise_dir, tile.coadd_id, band,
                                       bandname=wanyband, roiradecbox=roiradecbox)
        if tim is None:
            print('Actually, no overlap with tile', tile.coadd_id)
            continue

        if plots:
            sig1 = tim.sig1
            plt.clf()
            plt.imshow(tim.getImage(), interpolation='nearest', origin='lower',
                       cmap='gray', vmin=-3 * sig1, vmax=10 * sig1)
            plt.colorbar()
            tag = '%s W%i' % (tile.coadd_id, band)
            plt.title('%s: tim data' % tag)
            ps.savefig()

            plt.clf()
            plt.hist((tim.getImage() * tim.inverr)[tim.inverr > 0].ravel(),
                     range=(-5,10), bins=100)
            plt.xlabel('Per-pixel intensity (Sigma)')
            plt.title(tag)
            ps.savefig()

        if move_crpix and band in [1, 2]:
            realwcs = tim.wcs.wcs
            x,y = realwcs.crpix
            tile_crpix = tile.get('crpix_w%i' % band)
            dx = tile_crpix[0] - 1024.5
            dy = tile_crpix[1] - 1024.5
            realwcs.set_crpix(x+dx, y+dy)

        if modelsky_dir and band in [1, 2]:
            fn = os.path.join(modelsky_dir, '%s.%i.mod.fits' % (tile.coadd_id, band))
            if not os.path.exists(fn):
                raise RuntimeError('WARNING: does not exist:', fn)
            x0,x1,y0,y1 = tim.roi
            # Close the file once the cutout is read; one file is opened
            # per tile.
            with fitsio.FITS(fn) as bgfits:
                bg = bgfits[2][y0:y1, x0:x1]

            if plots:
                plt.clf()
                plt.subplot(1,2,1)
                plt.imshow(tim.getImage(), interpolation='nearest', origin='lower',
                           cmap='gray', vmin=-3 * sig1, vmax=5 * sig1)
                plt.subplot(1,2,2)
                plt.imshow(bg, interpolation='nearest', origin='lower',
                           cmap='gray', vmin=-3 * sig1, vmax=5 * sig1)
                tag = '%s W%i' % (tile.coadd_id, band)
                plt.suptitle(tag)
                ps.savefig()

                plt.clf()
                ha = dict(range=(-5,10), bins=100, histtype='step')
                plt.hist((tim.getImage() * tim.inverr)[tim.inverr > 0].ravel(),
                         color='b', label='Original', **ha)
                plt.hist(((tim.getImage()-bg) * tim.inverr)[tim.inverr > 0].ravel(),
                         color='g', label='Minus Background', **ha)
                plt.axvline(0, color='k', alpha=0.5)
                plt.xlabel('Per-pixel intensity (Sigma)')
                plt.legend()
                plt.title(tag + ': background')
                ps.savefig()

            # Actually subtract the background!
            tim.data -= bg

        # Floor the per-pixel variances
        if band in [1,2]:
            # in Vega nanomaggies per pixel
            floor_sigma = {1: 0.5, 2: 2.0}
            # inverr == 0 gives 1/0 here; those pixels are reset to zero
            # inverse-error just below.
            with np.errstate(divide='ignore'):
                new_ie = 1. / np.hypot(1./tim.inverr, floor_sigma[band])
            new_ie[tim.inverr == 0] = 0.

            if plots:
                plt.clf()
                plt.plot((1. / tim.inverr[tim.inverr>0]).ravel(), (1./new_ie[tim.inverr>0]).ravel(), 'b.')
                plt.title('unWISE per-pixel error: %s band %i' % (tile.coadd_id, band))
                plt.xlabel('original')
                plt.ylabel('floored')
                ps.savefig()

            tim.inverr = new_ie

        # Read mask file?
        if get_masks:
            from astrometry.util.resample import resample_with_wcs, OverlapError
            # unwise_dir can be a colon-separated list of paths
            tilemask = None
            for d in tile.unwise_dir.split(':'):
                fn = os.path.join(d, tile.coadd_id[:3], tile.coadd_id,
                                  'unwise-%s-msk.fits.gz' % tile.coadd_id)
                if os.path.exists(fn):
                    print('Reading unWISE mask file', fn)
                    x0,x1,y0,y1 = tim.roi
                    with fitsio.FITS(fn) as maskfits:
                        tilemask = maskfits[0][y0:y1,x0:x1]
                    break
            if tilemask is None:
                print('unWISE mask file for tile', tile.coadd_id, 'does not exist')
            else:
                try:
                    tanwcs = tim.wcs.wcs
                    assert(tanwcs.shape == tilemask.shape)
                    Yo,Xo,Yi,Xi,_ = resample_with_wcs(get_masks, tanwcs, intType=np.int16)
                    # Only deal with mask pixels that are set.
                    I, = np.nonzero(tilemask[Yi,Xi] > 0)
                    # Trim to unique area for this tile.
                    # pixelxy2radec takes (x, y), in that order (as in
                    # the unique-area computation for the image below).
                    rr,dd = get_masks.pixelxy2radec(Xo[I]+1, Yo[I]+1)
                    good = radec_in_unique_area(rr, dd, tile.ra1, tile.ra2, tile.dec1, tile.dec2)
                    I = I[good]
                    maskmap[Yo[I],Xo[I]] = tilemask[Yi[I], Xi[I]]
                except OverlapError:
                    # Shouldn't happen by this point
                    print('No overlap between WISE tile', tile.coadd_id, 'and brick')

        # The tiles have some overlap, so zero out pixels outside the
        # tile's unique area.
        th,tw = tim.shape
        xx,yy = np.meshgrid(np.arange(tw), np.arange(th))
        rr,dd = tim.wcs.wcs.pixelxy2radec(xx+1, yy+1)
        unique = radec_in_unique_area(rr, dd, tile.ra1, tile.ra2, tile.dec1, tile.dec2)
        tim.inverr[np.logical_not(unique)] = 0.
        del xx,yy,rr,dd,unique

        if plots:
            sig1 = tim.sig1
            plt.clf()
            plt.imshow(tim.getImage() * (tim.inverr > 0),
                       interpolation='nearest', origin='lower',
                       cmap='gray', vmin=-3 * sig1, vmax=10 * sig1)
            plt.colorbar()
            tag = '%s W%i' % (tile.coadd_id, band)
            plt.title('%s: tim data (unique)' % tag)
            ps.savefig()

        if pixelized_psf:
            import unwise_psf
            if (band == 1) or (band == 2):
                # we only have updated PSFs for W1 and W2
                psfimg = unwise_psf.get_unwise_psf(band, tile.coadd_id,
                                                   modelname='neo4_unwisecat')
            else:
                psfimg = unwise_psf.get_unwise_psf(band, tile.coadd_id)

            if band == 4:
                # oversample (the unwise_psf models are at native W4 5.5"/pix,
                # while the unWISE coadds are made at 2.75"/pix.
                ph,pw = psfimg.shape
                subpsf = np.zeros((ph*2-1, pw*2-1), np.float32)
                from astrometry.util.util import lanczos3_interpolate
                xx,yy = np.meshgrid(np.arange(0., pw-0.51, 0.5, dtype=np.float32),
                                    np.arange(0., ph-0.51, 0.5, dtype=np.float32))
                xx = xx.ravel()
                yy = yy.ravel()
                ix = xx.astype(np.int32)
                iy = yy.astype(np.int32)
                dx = (xx - ix).astype(np.float32)
                dy = (yy - iy).astype(np.float32)
                psfimg = psfimg.astype(np.float32)
                # The interpolated values are written into subpsf (via its
                # flat view); the return value is not needed.
                lanczos3_interpolate(ix, iy, dx, dy, [subpsf.flat], [psfimg])

                if plots:
                    plt.clf()
                    plt.imshow(psfimg, interpolation='nearest', origin='lower')
                    plt.title('Original PSF model')
                    ps.savefig()
                    plt.clf()
                    plt.imshow(subpsf, interpolation='nearest', origin='lower')
                    plt.title('Subsampled PSF model')
                    ps.savefig()

                psfimg = subpsf
                del xx, yy, ix, iy, dx, dy

            from tractor.psf import PixelizedPSF
            # Normalize to unit sum, then apply the per-band flux rescaling.
            psfimg /= psfimg.sum()
            fluxrescales = {1: 1.04, 2: 1.005, 3: 1.0, 4: 1.0}
            psfimg *= fluxrescales[band]
            tim.psf = PixelizedPSF(psfimg)

        if psf_broadening is not None and not pixelized_psf:
            # psf_broadening is a factor by which the PSF FWHMs
            # should be scaled; the PSF is a little wider
            # post-reactivation.
            psf = tim.getPsf()
            from tractor import GaussianMixturePSF
            if isinstance(psf, GaussianMixturePSF):
                print('Broadening PSF: from', psf)
                p0 = psf.getParams()
                pnames = psf.getParamNames()
                # Widths scale linearly, so variance parameters scale by
                # the square of the broadening factor.
                p1 = [p * psf_broadening**2 if 'var' in name else p
                      for (p, name) in zip(p0, pnames)]
                psf.setParams(p1)
                print('Broadened PSF:', psf)
            else:
                print('WARNING: cannot apply psf_broadening to WISE PSF of type', type(psf))

        # Nearest (zero-indexed) pixel to each source in this tile.
        wcs = tim.wcs.wcs
        ok,x,y = wcs.radec2pixelxy(ra, dec)
        x = np.round(x - 1.).astype(int)
        y = np.round(y - 1.).astype(int)
        good = (x >= 0) * (x < tw) * (y >= 0) * (y < th)
        # Which sources are in this brick's unique area?
        usrc = radec_in_unique_area(ra, dec, tile.ra1, tile.ra2, tile.dec1, tile.dec2)
        I, = np.nonzero(good * usrc)

        nexp[I] = tim.nuims[y[I], x[I]]
        if hasattr(tim, 'mjdmin') and hasattr(tim, 'mjdmax'):
            mjd[I] = (tim.mjdmin + tim.mjdmax) / 2.
        phot.wise_coadd_id[I] = tile.coadd_id

        central_flux[I] = tim.getImage()[y[I], x[I]]
        del x,y,good,usrc

        # PSF norm for depth
        psf = tim.getPsf()
        h,w = tim.shape
        patch = psf.getPointSourcePatch(h//2, w//2).patch
        psfnorm = np.sqrt(np.sum(patch**2))
        # To handle zero-depth, we return 1/nanomaggies^2 units rather than mags.
        psfdepth = 1. / (tim.sig1 / psfnorm)**2
        phot.get(wband + '_psfdepth')[I] = psfdepth

        tim.tile = tile
        tims.append(tim)

    if plots:
        plt.clf()
        mn,mx = 0.1, 20000
        plt.hist(np.log10(np.clip(central_flux, mn, mx)), bins=100,
                 range=(np.log10(mn), np.log10(mx)))
        logt = np.arange(0, 5)
        plt.xticks(logt, ['%i' % i for i in 10.**logt])
        plt.title('Central fluxes (W%i)' % band)
        plt.axvline(np.log10(20000), color='k')
        plt.axvline(np.log10(1000), color='k')
        ps.savefig()

    # Eddie's non-secret recipe:
    #- central pixel <= 1000: 19x19 pix box size
    #- central pixel in 1000 - 20000: 59x59 box size
    #- central pixel > 20000 or saturated: 149x149 box size
    #- object near "bright star": 299x299 box size
    for src,cflux in zip(cat, central_flux):
        if cflux > 20000:
            R = 100
        elif cflux > 1000:
            R = 30
        else:
            R = 15
        if isinstance(src, PointSource):
            src.fixedRadius = R
        else:
            ### FIXME -- sizes for galaxies..... can we set PSF size separately?
            galrad = 0
            # RexGalaxy is a subclass of ExpGalaxy
            if isinstance(src, (ExpGalaxy, DevGalaxy)):
                galrad = src.shape.re
            elif isinstance(src, FixedCompositeGalaxy):
                galrad = max(src.shapeExp.re, src.shapeDev.re)
            pixscale = 2.75
            src.halfsize = int(np.hypot(R, galrad * 5 / pixscale))

    minsb = 0.
    fitsky = False

    tractor = Tractor(tims, cat)
    if use_ceres:
        from tractor.ceres_optimizer import CeresOptimizer
        tractor.optimizer = CeresOptimizer(BW=ceres_block, BH=ceres_block)
    # Fit only the source brightnesses in the generic WISE band.
    tractor.freezeParamsRecursive('*')
    tractor.thawPathsTo(wanyband)

    kwa = dict(fitstat_extras=[('pronexp', [tim.nims for tim in tims])])
    t0 = Time()

    R = tractor.optimize_forced_photometry(
        minsb=minsb, mindlnp=1., sky=fitsky, fitstats=True,
        variance=True, shared_params=False,
        wantims=wantims, **kwa)
    print('unWISE forced photometry took', Time() - t0)

    if use_ceres:
        term = R.ceres_status['termination']
        # Running out of memory can cause failure to converge
        # and term status = 2.
        # Fail completely in this case.
        if term != 0:
            print('Ceres termination status:', term)
            raise RuntimeError(
                'Ceres terminated with status %i' % term)

    if wantims:
        ims1 = R.ims1
    flux_invvars = R.IV
    if R.fitstats is not None:
        for k in fskeys:
            x = getattr(R.fitstats, k)
            fitstats[k] = np.array(x).astype(np.float32)

    if save_fits:
        for i,tim in enumerate(tims):
            tile = tim.tile
            (dat, mod, ie, chi, roi) = ims1[i]
            wcshdr = fitsio.FITSHDR()
            tim.wcs.wcs.add_to_header(wcshdr)
            tag = 'fit-%s-w%i' % (tile.coadd_id, band)
            fitsio.write('%s-data.fits' %
                         tag, dat, clobber=True, header=wcshdr)
            fitsio.write('%s-mod.fits' % tag,  mod,
                         clobber=True, header=wcshdr)
            fitsio.write('%s-chi.fits' % tag,  chi,
                         clobber=True, header=wcshdr)

    if plots:
        # Create models for just the brightest sources
        bright_cat = [src for src in cat
                      if src.getBrightness().getBand(wanyband) > 1000]
        print('Bright soures:', len(bright_cat))
        btr = Tractor(tims, bright_cat)
        for tim in tims:
            mod = btr.getModelImage(tim)
            tile = tim.tile
            tag = '%s W%i' % (tile.coadd_id, band)
            sig1 = tim.sig1
            plt.clf()
            plt.imshow(mod, interpolation='nearest', origin='lower',
                       cmap='gray', vmin=-3 * sig1, vmax=25 * sig1)
            plt.colorbar()
            plt.title('%s: bright-star models' % tag)
            ps.savefig()

    if get_models:
        for i,tim in enumerate(tims):
            tile = tim.tile
            (dat, mod, ie, chi, roi) = ims1[i]
            models[(tile.coadd_id, band)] = (mod, dat, ie, tim.roi, tim.wcs.wcs)

    if plots:
        for i,tim in enumerate(tims):
            tile = tim.tile
            tag = '%s W%i' % (tile.coadd_id, band)
            (dat, mod, ie, chi, roi) = ims1[i]
            sig1 = tim.sig1
            plt.clf()
            plt.imshow(dat, interpolation='nearest', origin='lower',
                       cmap='gray', vmin=-3 * sig1, vmax=25 * sig1)
            plt.colorbar()
            plt.title('%s: data' % tag)
            ps.savefig()
            plt.clf()
            plt.imshow(mod, interpolation='nearest', origin='lower',
                       cmap='gray', vmin=-3 * sig1, vmax=25 * sig1)
            plt.colorbar()
            plt.title('%s: model' % tag)
            ps.savefig()

            plt.clf()
            plt.imshow(chi, interpolation='nearest', origin='lower',
                       cmap='gray', vmin=-5, vmax=+5)
            plt.colorbar()
            plt.title('%s: chi' % tag)
            ps.savefig()

    nm = np.array([src.getBrightness().getBand(wanyband) for src in cat])
    nm_ivar = flux_invvars
    # Sources out of bounds, eg, never change from their default
    # (1-sigma or whatever) initial fluxes.  Zero them out instead.
    nm[nm_ivar == 0] = 0.

    phot.set(wband + '_nanomaggies', nm.astype(np.float32))
    phot.set(wband + '_nanomaggies_ivar', nm_ivar.astype(np.float32))
    # Flux error, magnitude and magnitude error; entries without a valid
    # value are set to NaN for the magnitude columns.
    dnm = np.zeros(len(nm_ivar), np.float32)
    okiv = (nm_ivar > 0)
    dnm[okiv] = (1. / np.sqrt(nm_ivar[okiv])).astype(np.float32)
    okflux = (nm > 0)
    mag = np.zeros(len(nm), np.float32)
    mag[okflux] = (NanoMaggies.nanomaggiesToMag(nm[okflux])
                   ).astype(np.float32)
    dmag = np.zeros(len(nm), np.float32)
    ok = (okiv * okflux)
    dmag[ok] = (np.abs((-2.5 / np.log(10.)) * dnm[ok] / nm[ok])
                ).astype(np.float32)
    mag[np.logical_not(okflux)] = np.nan
    dmag[np.logical_not(ok)] = np.nan

    phot.set(wband + '_mag', mag)
    phot.set(wband + '_mag_err', dmag)

    for k in fskeys:
        # The optimizer may return no fit statistics; fall back to zeros
        # rather than failing with a KeyError.
        phot.set(wband + '_' + k,
                 fitstats.get(k, np.zeros(Nsrcs, np.float32)))
    phot.set(wband + '_nexp', nexp)
    # The MJD column is only written when some tile supplied MJD bounds.
    if not np.all(mjd == 0):
        phot.set(wband + '_mjd', mjd)

    rtn = wphotduck()
    rtn.phot = phot
    rtn.models = None
    rtn.maskmap = None
    if get_models:
        rtn.models = models
    if get_masks:
        rtn.maskmap = maskmap
    return rtn
Beispiel #4
0
def galex_forcedphot(galex_dir, cat, tiles, band, roiradecbox,
                     pixelized_psf=False, ps=None):
    '''
    Run forced photometry of the tractor sources *cat* on GALEX tiles.

    *galex_dir*: passed to ``galex_tractor_image`` and ``galex_psf``.
    *cat*: list of tractor sources; their brightness in band ``'galex'``
        is fit in place.
    *tiles*: a fits_table of GALEX tiles; ``visitname`` and ``tilename``
        are read from each row.
    *band*: GALEX band string; the output columns are named
        ``flux_<band>uv`` and ``flux_ivar_<band>uv``.
    *roiradecbox*: passed to ``galex_tractor_image``.
    *pixelized_psf*: if true, replace each image's PSF with a
        ``PixelizedPSF`` built from ``galex_psf(band, galex_dir)``.
    *ps*: plot sequence object (needs ``savefig()``); if not None,
        data / model / chi plots are made for every tile.

    Returns a ``gphotduck`` object whose ``phot`` attribute is a FITS table
    the same length as *cat*, and whose ``models`` attribute is a list of
    ``(visitname, band, wcs, data, model, inverse-error)`` tuples, one per
    tile that produced an image.

    Raises RuntimeError when Ceres ends with a non-zero termination status.
    '''
    from tractor import Tractor
    from astrometry.util.ttime import Time

    make_plots = ps is not None
    if make_plots:
        import pylab as plt

    gband = 'galex'
    phot = fits_table()

    # Build one tractor image per overlapping tile.
    # NOTE: unlike the unWISE code path, no per-source MJD or PSF depth is
    # recorded here.
    tims = []
    for tile in tiles:
        info('Reading GALEX tile', tile.visitname.strip(), 'band', band)

        tim = galex_tractor_image(tile, band, galex_dir, roiradecbox, gband)
        if tim is None:
            debug('Actually, no overlap with tile', tile.tilename)
            continue

        if pixelized_psf:
            tim.psf = PixelizedPSF(galex_psf(band, galex_dir))

        tim.tile = tile
        tims.append(tim)

    # Ceres is always used for the GALEX fit, with 8x8 blocks.
    tractor = Tractor(tims, cat)
    from tractor.ceres_optimizer import CeresOptimizer
    blocksize = 8
    tractor.optimizer = CeresOptimizer(BW=blocksize, BH=blocksize)
    # Only the source brightnesses in the GALEX band are free.
    tractor.freezeParamsRecursive('*')
    tractor.thawPathsTo(gband)

    tstart = Time()
    R = tractor.optimize_forced_photometry(
        fitstats=True, variance=True, shared_params=False,
        wantims=True)
    info('GALEX forced photometry took', Time() - tstart)

    # Running out of memory can cause failure to converge and term
    # status = 2.  Fail completely in this case.
    term = R.ceres_status['termination']
    if term != 0:
        info('Ceres termination status:', term)
        raise RuntimeError('Ceres terminated with status %i' % term)

    ims1 = R.ims1
    flux_invvars = R.IV

    # Per-tile fit products to hand back to the caller.
    models = []
    for idx, tim in enumerate(tims):
        dat, mod, ie, _, _ = ims1[idx]
        models.append((tim.tile.visitname, band, tim.wcs.wcs, dat, mod, ie))

    if make_plots:
        for idx, tim in enumerate(tims):
            tag = '%s %s' % (tim.tile.tilename, band)
            dat, mod, _, chi, _ = ims1[idx]
            sig1 = tim.sig1
            # (title format, image, lower cut, upper cut) for each page.
            panels = [
                ('%s: data', dat, -3 * sig1, 25 * sig1),
                ('%s: model', mod, -3 * sig1, 25 * sig1),
                ('%s: chi', chi, -5, +5),
            ]
            for title_fmt, img, lo, hi in panels:
                plt.clf()
                plt.imshow(img, interpolation='nearest', origin='lower',
                           cmap='gray', vmin=lo, vmax=hi)
                plt.colorbar()
                plt.title(title_fmt % tag)
                ps.savefig()

    fluxes = np.array([src.getBrightness().getBand(gband) for src in cat])
    # Sources out of bounds, eg, never change from their default
    # (1-sigma or whatever) initial fluxes.  Zero them out instead.
    fluxes[flux_invvars == 0] = 0.

    niceband = band + 'uv'
    phot.set('flux_' + niceband, fluxes.astype(np.float32))
    phot.set('flux_ivar_' + niceband, flux_invvars.astype(np.float32))

    rtn = gphotduck()
    rtn.phot = phot
    rtn.models = models
    return rtn