def stage_psfplots(
    T=None, sedsn=None, coimgs=None, cons=None,
    detmaps=None, detivs=None,
    blobsrcs=None,blobflux=None,blobslices=None, blobs=None,
    tractor=None, cat=None, targetrd=None, pixscale=None, targetwcs=None,
    W=None,H=None, brickid=None,
    bands=None, ps=None, tims=None,
    plots=False,
    **kwargs):
    """Refit PsfEx spline models and optionally plot them against the grid fit.

    First prints the knots of the first spline of ``tims[0]`` (interior knots
    and the ``degree + 1`` boundary knots at each end, with degree 3).  Then,
    for every tim, calls ``psfex.fitSavedData(*psfex.splinedata)``; when
    *plots* is true, each PSF parameter gets a figure with the gridded values
    (left) next to the spline evaluated on a 50 x 100 (x by y) grid (right).

    Only *tims*, *ps* and *plots* are used; the other keyword arguments are
    accepted so the function can be called with a full pipeline-stage state.
    """
    first_psfex = tims[0].psfex
    first_psfex.fitSavedData(*first_psfex.splinedata)
    spline0 = first_psfex.splines[0]
    print('Spline: %s' % (spline0,))
    knots = spline0.get_knots()
    print('knots: %s' % (knots,))
    tx, ty = knots
    degree = 3
    for axis_name, axis_knots in (('x', tx), ('y', ty)):
        print('interior knots %s: %s' %
              (axis_name, axis_knots[degree + 1:-degree - 1]))
        print('additional knots %s: %s and %s' %
              (axis_name, axis_knots[:degree + 1], axis_knots[-degree - 1:]))

    for tim in tims:
        model = tim.psfex
        model.fitSavedData(*model.splinedata)
        if not plots:
            continue

        print('')
        print('Tim %s' % (tim,))
        print('')
        # splinedata = (parameter grid [ny, nx, nparams], x nodes, y nodes)
        grid, xnodes, ynodes = model.splinedata
        ny, nx, nparams = grid.shape
        assert(len(xnodes) == nx)
        assert(len(ynodes) == ny)
        param_names = model.psfclass(*np.zeros(nparams)).getParamNames()
        xs = np.linspace(xnodes[0], xnodes[-1], 50)
        ys = np.linspace(ynodes[0], ynodes[-1], 100)
        print('xa %s' % (xs,))
        print('ya %s' % (ys,))
        extent = [xnodes[0], xnodes[-1], ynodes[0], ynodes[-1]]

        for ip in range(nparams):
            plt.clf()
            plt.subplot(1, 2, 1)
            dimshow(grid[:, :, ip])
            plt.title('grid fit')
            plt.colorbar()
            plt.subplot(1, 2, 2)
            # The spline returns (len(xs), len(ys)); transpose to rows = y.
            surface = model.splines[ip](xs, ys).T
            print('spline shape %s' % (surface.shape,))
            assert(surface.shape == (len(ys), len(xs)))
            dimshow(surface, extent=extent)
            plt.title('spline')
            plt.colorbar()
            plt.suptitle('tim %s: PSF param %s' % (tim.name, param_names[ip]))
            ps.savefig()
def halo_plots_before(tims, bands, targetwcs, halostars, ps):
    """Plot the RGB coadd with halo stars circled, before halo subtraction.

    Parameters
    ----------
    tims : images passed to ``quick_coadds``.
    bands : band names for ``quick_coadds`` and ``get_rgb``.
    targetwcs : WCS of the output coadd.
    halostars : table with pixel positions ``ibx`` / ``iby``.
    ps : plot sequence; one figure is saved.

    Returns
    -------
    The per-band coadd images produced by ``quick_coadds``.
    """
    coadds, _unused = quick_coadds(tims, bands, targetwcs)
    plt.clf()
    dimshow(get_rgb(coadds, bands))
    # Remember the image limits so the markers do not rescale the axes.
    view = plt.axis()
    plt.plot(halostars.ibx, halostars.iby, 'o', mec='r', ms=15, mfc='none')
    plt.axis(view)
    plt.title('Before star halo subtraction')
    ps.savefig()
    return coadds
Exemple #3
0
    def test_fourier(self):
        """Print the pieces of the PSF Fourier transform and optionally plot F.

        Calls ``self.psf.getFourierTransform(100., 100., 32)`` and prints the
        transform, its (cx, cy), shape and (v, w).  When ``ps`` is not None
        the real and imaginary parts of F are shown side by side.

        NOTE(review): ``ps`` is not defined in this method; presumably a
        module-level plot sequence (or None) — confirm.
        """
        transform = self.psf.getFourierTransform(100., 100., 32)
        F, (cx, cy), shape, (v, w) = transform
        print('F', F)
        print('cx,cy', cx, cy)
        print('shape', shape)
        print('v, w', v, w)

        if ps is None:
            return

        import pylab as plt
        from astrometry.util.plotutils import dimshow
        plt.clf()
        for panel, component in enumerate((F.real, F.imag), start=1):
            plt.subplot(1, 2, panel)
            dimshow(component)
        ps.savefig()
def _plot_derivs(subtims, newsrc, srctractor, ps, name=None, bands=None,
                 subtarget=None):
    """Plot the model patch and parameter derivatives of *newsrc* per image.

    One subplot row per image in *subtims*: column 1 is the source's model
    patch, the following columns are its parameter derivatives, each on a
    symmetric colour scale.  If *bands* and *subtarget* are given, a second
    figure shows the RGB coadd of ``srctractor.getModelImages()``.

    Parameters
    ----------
    subtims : list of images to plot.
    newsrc : source whose model and derivatives are plotted.
    srctractor : tractor providing ``_getSourceDerivatives``,
        ``getModelPatchNoCache`` and ``getModelImages``.
    ps : plot sequence used to save the figures.
    name : str, optional
        Label used in the figure titles; defaults to ``str(newsrc)``.
    bands, subtarget : optional
        Band list and target WCS for ``compute_coadds``; the coadd figure is
        only produced when both are supplied.
    """
    if name is None:
        name = str(newsrc)
    plt.clf()
    rows = len(subtims)
    cols = 1 + newsrc.numberOfParams()
    for it, tim in enumerate(subtims):
        derivs = srctractor._getSourceDerivatives(newsrc, tim)
        c0 = 1 + cols * it
        # The model patch of the source being inspected (previously the
        # undefined name `src`).
        mod = srctractor.getModelPatchNoCache(tim, newsrc)
        if mod is not None and mod.patch is not None:
            plt.subplot(rows, cols, c0)
            dimshow(mod.patch, extent=mod.getExtent())
        c0 += 1
        for ip, deriv in enumerate(derivs):
            if deriv is None or deriv.patch is None:
                continue
            plt.subplot(rows, cols, c0 + ip)
            mx = np.max(np.abs(deriv.patch))
            dimshow(deriv.patch, extent=deriv.getExtent(), vmin=-mx, vmax=mx)
    # suptitle labels the whole figure, not just the last subplot.
    plt.suptitle('Derivatives for ' + name)
    ps.savefig()

    if bands is None or subtarget is None:
        return
    plt.clf()
    modimgs = srctractor.getModelImages()
    comods, _ = compute_coadds(subtims, bands, subtarget, images=modimgs)
    dimshow(get_rgb(comods, bands))
    plt.title('Initial ' + name)
    ps.savefig()
def _plot_derivs(subtims, newsrc, srctractor, ps, name=None, bands=None,
                 subtarget=None):
    """Plot the model patch and parameter derivatives of *newsrc* per image.

    One subplot row per image in *subtims*: column 1 is the source's model
    patch, the following columns are its parameter derivatives, each on a
    symmetric colour scale.  If *bands* and *subtarget* are given, a second
    figure shows the RGB coadd of ``srctractor.getModelImages()``.

    Parameters
    ----------
    subtims : list of images to plot.
    newsrc : source whose model and derivatives are plotted.
    srctractor : tractor providing ``_getSourceDerivatives``,
        ``getModelPatchNoCache`` and ``getModelImages``.
    ps : plot sequence used to save the figures.
    name : str, optional
        Label used in the figure titles; defaults to ``str(newsrc)``.
    bands, subtarget : optional
        Band list and target WCS for ``compute_coadds``; the coadd figure is
        only produced when both are supplied.
    """
    if name is None:
        name = str(newsrc)
    plt.clf()
    rows = len(subtims)
    cols = 1 + newsrc.numberOfParams()
    for it, tim in enumerate(subtims):
        derivs = srctractor._getSourceDerivatives(newsrc, tim)
        c0 = 1 + cols * it
        # The model patch of the source being inspected (previously the
        # undefined name `src`).
        mod = srctractor.getModelPatchNoCache(tim, newsrc)
        if mod is not None and mod.patch is not None:
            plt.subplot(rows, cols, c0)
            dimshow(mod.patch, extent=mod.getExtent())
        c0 += 1
        for ip, deriv in enumerate(derivs):
            if deriv is None or deriv.patch is None:
                continue
            plt.subplot(rows, cols, c0 + ip)
            mx = np.max(np.abs(deriv.patch))
            dimshow(deriv.patch, extent=deriv.getExtent(), vmin=-mx, vmax=mx)
    # suptitle labels the whole figure, not just the last subplot.
    plt.suptitle('Derivatives for ' + name)
    ps.savefig()

    if bands is None or subtarget is None:
        return
    plt.clf()
    modimgs = srctractor.getModelImages()
    comods, _ = compute_coadds(subtims, bands, subtarget, images=modimgs)
    dimshow(get_rgb(comods, bands))
    plt.title('Initial ' + name)
    ps.savefig()
def fitblobs_plots_2(blobs, refstars, ps):
    """Show the blob mask with reference stars marked and labelled.

    Pixels where ``blobs >= 0`` are shown as 1, others as 0.  Each reference
    star is drawn as a red dot and labelled with its ``ref_cat``; stars from
    the 'T2' catalog are labelled 'T(<mag>)' and those from 'G2' are
    labelled 'G(<phot_g_mean_mag>)'.
    """

    def star_label(star):
        # Default to the catalog name when no magnitude format applies.
        catalog = star.ref_cat
        if catalog == 'T2':
            return 'T(%.1f)' % star.mag
        if catalog == 'G2':
            return 'G(%.1f)' % star.phot_g_mean_mag
        return catalog

    plt.clf()
    dimshow(blobs >= 0, vmin=0, vmax=1)
    view = plt.axis()
    plt.plot(refstars.ibx, refstars.iby, 'ro')
    for star in refstars:
        text = star_label(star)
        plt.text(star.ibx, star.iby, text, color='r', fontsize=10,
                 bbox=dict(facecolor='w', alpha=0.5))
    plt.axis(view)
    plt.title('Reference stars')
    ps.savefig()
def halo_plots_after(tims, bands, targetwcs, halostars, coimgs, ps):
    """Plot the coadd after star-halo subtraction and the subtracted halos.

    Saves, via *ps*:
      1. the RGB coadd of *tims* with halo stars circled;
      2. the RGB of ``coimgs - coimgs2`` (what was subtracted), stars circled;
      3. for each of the first 10 halo stars, a 2x2 zoom around the star:
         before (top left), after (top right), difference (bottom left).

    Parameters
    ----------
    tims : images coadded with ``quick_coadds`` (after halo subtraction).
    bands : band names for ``quick_coadds`` and ``get_rgb``.
    targetwcs : WCS of the coadd; ``pixel_scale()`` converts radii to pixels.
    halostars : table with ``ibx``, ``iby`` (pixel positions) and ``radius``.
        ``radius * 3600 / pixel_scale`` is used as a pixel radius, so
        presumably degrees with an arcsec/pixel scale — confirm.
    coimgs : per-band coadds from before subtraction.
    ps : plot sequence.
    """
    coimgs2, _ = quick_coadds(tims, bands, targetwcs)
    # These do not change per star, so compute them once instead of inside
    # the zoom loop below.
    rgb_before = get_rgb(coimgs, bands)
    rgb_after = get_rgb(coimgs2, bands)
    rgb_diff = get_rgb([co - co2 for co, co2 in zip(coimgs, coimgs2)], bands)
    pixscale = targetwcs.pixel_scale()

    plt.clf()
    dimshow(rgb_after)
    ax = plt.axis()
    plt.plot(halostars.ibx, halostars.iby, 'o', mec='r', ms=15, mfc='none')
    plt.axis(ax)
    plt.title('After star halo subtraction')
    ps.savefig()

    plt.clf()
    dimshow(rgb_diff)
    ax = plt.axis()
    plt.plot(halostars.ibx, halostars.iby, 'o', mec='r', ms=15, mfc='none')
    plt.axis(ax)
    plt.title('Subtracted halos')
    ps.savefig()

    ima = dict(interpolation='nearest', origin='lower')
    for g in halostars[:10]:
        plt.clf()
        pixrad = int(g.radius * 3600. / pixscale)
        ax = [g.ibx - pixrad, g.ibx + pixrad, g.iby - pixrad, g.iby + pixrad]
        plt.subplot(2, 2, 1)
        plt.imshow(rgb_before, **ima)
        plt.plot(halostars.ibx, halostars.iby, 'o', mec='r', ms=15, mfc='none')
        plt.axis(ax)
        plt.subplot(2, 2, 2)
        plt.imshow(rgb_after, **ima)
        plt.axis(ax)
        plt.subplot(2, 2, 3)
        plt.imshow(rgb_diff, **ima)
        plt.axis(ax)
        ps.savefig()
def detection_plots_2(tims, bands, targetwcs, refstars, Tnew, hot,
                      saturated_pix, ps):
    """Plot detection diagnostics.

    Saves three figures via *ps*:
      1. the RGB coadd of *tims*;
      2. the same with '+' markers for reference stars (Tycho-2, Gaia and
         Large Galaxy, selected by the first letter of ``ref_cat``) and for
         the new SED-matched detections in *Tnew*, plus a legend;
      3. the *hot* map next to an RGB image of *saturated_pix*, in which
         entry i of *saturated_pix* fills colour channel ``2 - i``.
    """
    coadds, _ = quick_coadds(tims, bands, targetwcs)
    cross = dict(ms=10, mew=1.5)
    # (first letter of ref_cat, marker colour, legend label), in plot order.
    reference_styles = [
        ('T', (0, 1, 1), 'Tycho-2'),
        ('G', (0.2, 0.2, 1), 'Gaia'),
        ('L', (0.6, 0.6, 0.2), 'Large Galaxy'),
    ]

    plt.clf()
    dimshow(get_rgb(coadds, bands))
    plt.title('Detections')
    ps.savefig()

    view = plt.axis()
    if len(refstars):
        for prefix, colour, label in reference_styles:
            picked, = np.nonzero([len(r) and r[0] == prefix
                                  for r in refstars.ref_cat])
            if len(picked):
                plt.plot(refstars.ibx[picked], refstars.iby[picked], '+',
                         color=colour, label=label, **cross)
    plt.plot(Tnew.ibx, Tnew.iby, '+', color=(0, 1, 0),
             label='New SED-matched detections', **cross)
    plt.axis(view)
    plt.title('Detections')
    plt.legend(loc='upper left')
    ps.savefig()

    plt.clf()
    plt.subplot(1, 2, 1)
    dimshow(hot, vmin=0, vmax=1, cmap='hot')
    plt.title('hot')
    plt.subplot(1, 2, 2)
    H, W = targetwcs.shape
    sat_rgb = np.zeros((H, W, 3))
    for index, satpix in enumerate(saturated_pix):
        sat_rgb[:, :, 2 - index] = satpix
    dimshow(sat_rgb)
    plt.title('saturated_pix')
    ps.savefig()
def _psf_check_plots(tims):
    """Compare PsfEx PSF stamps with mixture-of-Gaussians (MoG) approximations.

    For each (im, tim) pair:
      * instantiate the PsfEx model from ``im.psffn`` at the centre of a
        hard-coded 2048 x 4096 image, fit constant 2- and 3-component
        ``GaussianMixtureEllipsePSF`` models to it (printing fit timings), and
        plot the stamp, both models and their residuals;
      * for a 7 (y) x 4 (x) grid of image positions, make five figures:
        the PsfEx stamp, ``tim.psfex``'s point-source patch (labelled
        "varying MoG"), and the residuals of the stamp against that patch
        and against the constant 3- and 2-component fits.

    NOTE(review): ``ims`` and ``ps`` are not parameters and are not assigned
    here; presumably module-level globals — confirm, otherwise NameError.
    The loop variable ``round`` shadows the builtin inside this function.

    Parameters
    ----------
    tims : list
        Images whose ``psfex`` provides ``getPointSourcePatch``; zipped with
        ``ims``, whose entries provide ``psffn``.
    """
    plt.figure(num=2, figsize=(7, 4.08))
    for im, tim in zip(ims, tims):
        print()
        print('Image', tim.name)

        plt.subplots_adjust(left=0,
                            right=1,
                            bottom=0,
                            top=0.95,
                            hspace=0,
                            wspace=0)
        # Hard-coded image size used to build the PsfEx model.
        W, H = 2048, 4096
        psfex = PsfEx(im.psffn, W, H)

        # PsfEx stamp at the image centre; psfim0 keeps the untrimmed copy.
        psfim0 = psfim = psfex.instantiateAt(W / 2, H / 2)
        # trim
        psfim = psfim[10:-10, 10:-10]

        # Time a 2-component fit on the stamp trimmed by 10 pixels ...
        tfit = Time()
        psffit2 = GaussianMixtureEllipsePSF.fromStamp(psfim, N=2)
        print('Fitting PSF mog:', psfim.shape, Time() - tfit)

        # ... and on the stamp trimmed by 5 pixels; this fit replaces the
        # previous psffit2 and is the one used below.
        psfim = psfim0[5:-5, 5:-5]
        tfit = Time()
        psffit2 = GaussianMixtureEllipsePSF.fromStamp(psfim, N=2)
        print('Fitting PSF mog:', psfim.shape, Time() - tfit)

        ph, pw = psfim.shape
        # 3-component fit on the same 5-pixel-trimmed stamp.
        psffit = GaussianMixtureEllipsePSF.fromStamp(psfim, N=3)

        #mx = 0.03
        # Common colour-scale maximum for the panels below.
        mx = psfim.max()

        # Render each constant MoG fit at the stamp centre into a
        # stamp-sized image.
        mod3 = np.zeros_like(psfim)
        p = psffit.getPointSourcePatch(pw / 2, ph / 2, radius=pw / 2)
        p.addTo(mod3)
        mod2 = np.zeros_like(psfim)
        p = psffit2.getPointSourcePatch(pw / 2, ph / 2, radius=pw / 2)
        p.addTo(mod2)

        # Top row: stamp, MoG(3), MoG(2); bottom row: stamp minus each model.
        plt.clf()
        plt.subplot(2, 3, 1)
        dimshow(psfim, vmin=0, vmax=mx, ticks=False)
        plt.subplot(2, 3, 2)
        dimshow(mod3, vmin=0, vmax=mx, ticks=False)
        plt.subplot(2, 3, 3)
        dimshow(mod2, vmin=0, vmax=mx, ticks=False)
        plt.subplot(2, 3, 5)
        dimshow(psfim - mod3, vmin=-mx / 2, vmax=mx / 2, ticks=False)
        plt.subplot(2, 3, 6)
        dimshow(psfim - mod2, vmin=-mx / 2, vmax=mx / 2, ticks=False)
        ps.savefig()
        #continue

        # One figure per "round"; each round draws a different quantity at
        # every grid position (see the suptitles below).
        for round in [1, 2, 3, 4, 5]:
            plt.clf()
            k = 1
            #rows,cols = 10,5
            rows, cols = 7, 4
            for iy, y in enumerate(np.linspace(0, H, rows).astype(int)):
                for ix, x in enumerate(np.linspace(0, W, cols).astype(int)):
                    psfimg = psfex.instantiateAt(x, y)
                    # trim
                    psfimg = psfimg[5:-5, 5:-5]
                    print('psfimg', psfimg.shape)
                    ph, pw = psfimg.shape
                    psfimg2 = tim.psfex.getPointSourcePatch(x,
                                                            y,
                                                            radius=pw / 2)
                    mod = np.zeros_like(psfimg)
                    h, w = mod.shape
                    #psfimg2.x0 -= x
                    #psfimg2.x0 += w/2
                    #psfimg2.y0 -= y
                    #psfimg2.y0 += h/2
                    # Put the patch at the stamp origin rather than at its
                    # image position before adding it into `mod`.
                    psfimg2.x0 = 0
                    psfimg2.y0 = 0
                    print('psfimg2:', (psfimg2.x0, psfimg2.y0))
                    psfimg2.addTo(mod)
                    print('psfimg:', psfimg.min(), psfimg.max(), psfimg.sum())
                    print('psfimg2:', psfimg2.patch.min(), psfimg2.patch.max(),
                          psfimg2.patch.sum())
                    print('mod:', mod.min(), mod.max(), mod.sum())

                    #plt.subplot(rows, cols, k)
                    # NOTE(review): the subplot grid is (cols, rows) = 4 x 7
                    # while k runs y-major over 7 x 4 positions, so panels
                    # are not laid out as a y-by-x grid — confirm intent.
                    plt.subplot(cols, rows, k)
                    k += 1
                    kwa = dict(vmin=0, vmax=mx, ticks=False)
                    if round == 1:
                        dimshow(psfimg, **kwa)
                        plt.suptitle('PsfEx')
                    elif round == 2:
                        dimshow(mod, **kwa)
                        plt.suptitle('varying MoG')
                    elif round == 3:
                        dimshow(psfimg - mod,
                                vmin=-mx / 2,
                                vmax=mx / 2,
                                ticks=False)
                        plt.suptitle('PsfEx - varying MoG')
                    elif round == 4:
                        dimshow(psfimg - mod3,
                                vmin=-mx / 2,
                                vmax=mx / 2,
                                ticks=False)
                        plt.suptitle('PsfEx - const MoG(3)')
                    elif round == 5:
                        dimshow(psfimg - mod2,
                                vmin=-mx / 2,
                                vmax=mx / 2,
                                ticks=False)
                        plt.suptitle('PsfEx - const MoG(2)')
            ps.savefig()
Exemple #10
0
def main():
    """Compare the analytic galaxy-PSF convolution with brute-force FFTs.

    For each PSF width (currently only psf_sigma = 2) an Airy PSF is rendered
    on an S x S pixel grid and plotted (central row/column profiles, image,
    log image).  For each galaxy sigma, ``galaxy_psf_convolution`` is run and
    the L2 norm of the highest-|frequency| row and column of the PSF, galaxy
    and product transforms is printed.  The transforms are then plotted next
    to the FFT of the subsampled galaxy.  Finally ``compare_subsampled`` is
    run at subsampling factors 1, 2 and 4 and the RMS differences are
    printed as a table.
    """
    ps = PlotSequence('conv')

    S = 51
    # Integer index of the central pixel.  `S/2` is 25.5 under true
    # division, which mis-centres the PSF and cannot index `pixpsf` below.
    center = S // 2
    print('Center', center)

    def l2_row_col(F, row, col):
        """Return the L2 norms of row *row* and column *col* of complex *F*."""
        l2_row = np.sqrt(np.sum(F[row, :].real**2 + F[row, :].imag**2))
        l2_col = np.sqrt(np.sum(F[:, col].real**2 + F[:, col].imag**2))
        return l2_row, l2_col

    for psf_sigma in [2.]:
        rms2 = []

        x = np.arange(S)
        y = np.arange(S)

        scale = 1.5 / psf_sigma
        psf = (scale, center)
        eval_psf = render_airy
        pixpsf = render_airy(psf, x, y)

        # Central row (blue) and column (red), linear then log scale.
        plt.clf()
        plt.subplot(2, 1, 1)
        plt.plot(x, pixpsf[center, :], 'b-')
        plt.plot(x, pixpsf[:, center], 'r-')
        plt.subplot(2, 1, 2)
        plt.plot(x, np.maximum(1e-16, pixpsf[center, :]), 'b-')
        plt.plot(x, np.maximum(1e-16, pixpsf[:, center]), 'r-')
        plt.yscale('log')
        ps.savefig()

        plt.clf()
        plt.imshow(pixpsf, interpolation='nearest', origin='lower')
        ps.savefig()

        plt.clf()
        plt.imshow(np.log10(np.maximum(1e-16, pixpsf)),
                   interpolation='nearest', origin='lower')
        plt.colorbar()
        plt.title('log PSF')
        ps.savefig()

        gal_sigmas = [2, 1, 0.5, 0.25]
        for gal_sigma in gal_sigmas:
            pixscale = 1.
            cd = pixscale * np.eye(2) / 3600.
            P, FG, Gmine, v, w = galaxy_psf_convolution(
                gal_sigma, 0., 0., GaussianGalaxy, cd,
                0., 0., pixpsf, debug=True)

            print()
            print('PSF %g, Gal %g' % (psf_sigma, gal_sigma))

            # Row / column holding the largest-|frequency| component.
            rmax = np.argmax(np.abs(w))
            cmax = np.argmax(np.abs(v))
            C = P * FG
            for label, F in (('PSF', P), ('Gal', FG), ('PSF*Gal', C)):
                l2_rmax, l2_cmax = l2_row_col(F, rmax, cmax)
                print(label + ' L_2 in highest-frequency rows & cols:',
                      l2_rmax, l2_cmax)
            print()

            _Fpsf, Fgal = compare_subsampled(
                S, 1, ps, psf, pixpsf, Gmine, v, w,
                gal_sigma, psf_sigma, cd, get_ffts=True, eval_psf=eval_psf)

            # Top row: real parts; bottom row: imaginary parts.
            plt.clf()
            panels = [
                (1, P.real, 'PSF real'), (5, P.imag, 'PSF imag'),
                (2, FG.real, 'Gal real'), (6, FG.imag, 'Gal imag'),
                (3, C.real, 'P*Gal real'), (7, C.imag, 'P*Gal imag'),
                (4, Fgal.real, 'pixGal real'), (8, Fgal.imag, 'pixGal imag'),
            ]
            for position, image, title in panels:
                plt.subplot(2, 4, position)
                dimshow(image)
                plt.colorbar()
                plt.title(title)
            plt.suptitle('PSF %g, Gal %g' % (psf_sigma, gal_sigma))
            ps.savefig()

            rms1 = []
            for s in [1, 2, 4]:
                rms = compare_subsampled(S, s, ps, psf, pixpsf, Gmine, v, w,
                                         gal_sigma, psf_sigma, cd,
                                         eval_psf=eval_psf)
                rms1.append(rms)
            rms2.append(rms1)

        print()
        print('PSF sigma =', psf_sigma)
        print('RMSes:')
        for rms1, gal_sigma in zip(rms2, gal_sigmas):
            print('Gal sigma', gal_sigma, 'rms:',
                  ', '.join(['%.3g' % r for r in rms1]))
def compare_subsampled(S,
                       s,
                       ps,
                       psf,
                       pixpsf,
                       Gmine,
                       v,
                       w,
                       gal_sigma,
                       psf_sigma,
                       cd,
                       get_ffts=False,
                       eval_psf=integrate_gaussian):
    """Compare ``galaxy_psf_convolution`` with a subsampled FFT convolution.

    The PSF is evaluated on a grid subsampled by a factor *s*, binned back to
    pixel scale, and cropped to ``pixpsf.shape``.  The analytic convolution
    is then recomputed from it.  Separately, a Gaussian galaxy (sigma
    *gal_sigma*, centred at pixel ``S // 2``) is convolved with the
    subsampled PSF by FFT, binned and normalized to unit sum, and compared
    with the analytic result.

    Parameters
    ----------
    S : int
        Image size in pixels.
    s : int
        Subsampling factor.
    ps : plot sequence used to save the comparison figure.
    psf : PSF description passed to *eval_psf*.
    pixpsf : 2-D pixelized PSF; its shape bounds the binned PSF.
    Gmine, v, w : accepted for interface compatibility; they are recomputed
        here from the binned PSF and the passed values are not used.
    gal_sigma, psf_sigma : float
        Galaxy sigma and PSF sigma (the latter only appears in the title).
    cd : CD matrix passed to ``galaxy_psf_convolution``.
    get_ffts : bool
        If true, return ``(Fpsf, Fgal)`` before convolving.
    eval_psf : callable ``(psf, x, y) -> 2-D array``.

    Returns
    -------
    ``(Fpsf, Fgal)`` if *get_ffts*, else the RMS of (analytic - binned FFT).
    """
    print()
    print('Subsample', s)
    print()

    step = 1. / s
    # Sub-samples from x = 0 to x = S-1 inclusive.
    sz = s * (S - 1) + 1
    x = np.arange(0, S, step)[:sz]
    # Evaluate the PSF at sub-pixel centres, then bin back to pixel scale.
    sx = x - 0.5 + step / 2.
    subpixpsf = eval_psf(psf, sx, sx)
    binned = bin_image(subpixpsf, s)
    ph, pw = pixpsf.shape
    binned = binned[:ph, :pw]

    print('Binned PSF:')
    measure(binned)
    print('Pix PSF:')
    measure(pixpsf)

    # Recompute the analytic convolution using the binned PSF.
    _P, _FG, Gmine, _v, _w = galaxy_psf_convolution(gal_sigma,
                                                    0.,
                                                    0.,
                                                    GaussianGalaxy,
                                                    cd,
                                                    0.,
                                                    0.,
                                                    binned,
                                                    debug=True)

    # Subsampled Gaussian galaxy.  Centre at the integer pixel S // 2
    # (`S / 2` would be a half pixel off under true division); for odd S
    # this is the element that ifftshift moves to index 0 below.
    xx, yy = np.meshgrid(x, x)
    center = S // 2
    subpixgal = np.exp(-0.5 * ((xx - center)**2 + (yy - center)**2) /
                       gal_sigma**2)
    sh, sw = subpixpsf.shape
    subpixgal = subpixgal[:sh, :sw]

    print('Subpix psf, gal', subpixpsf.shape, subpixgal.shape)

    print('Subpix PSF:')
    measure(subpixpsf)
    print('Subpix gal:')
    measure(subpixgal)

    # FFT convolution; the galaxy is shifted so the product is not translated.
    Fpsf = np.fft.rfft2(subpixpsf)
    spg = np.fft.ifftshift(subpixgal)
    Fgal = np.fft.rfft2(spg)

    if get_ffts:
        return Fpsf, Fgal

    Fconv = Fpsf * Fgal
    subpixfft = np.fft.irfft2(Fconv, s=subpixpsf.shape)
    print('Shapes:', 'subpixpsf', subpixpsf.shape, 'Fpsf', Fpsf.shape)
    print('spg', spg.shape, 'Fgal', Fgal.shape, 'Fconv', Fconv.shape,
          'subpixfft', subpixfft.shape)

    print('Subpix conv', subpixfft.shape)

    binned = bin_image(subpixfft, s)
    binned /= np.sum(binned)

    print('Binned', binned.shape)
    print('Mine:', Gmine.shape)
    print('Mine:')
    measure(Gmine)
    print('Binned subpix FFT:')
    measure(binned)

    mh, mw = Gmine.shape
    binned = binned[:mh, :mw]

    plt.clf()

    plt.subplot(2, 3, 1)
    dimshow(subpixpsf)
    plt.title('subpix psf')
    plt.colorbar()

    plt.subplot(2, 3, 2)
    dimshow(subpixgal)
    plt.title('subpix galaxy')
    plt.colorbar()

    plt.subplot(2, 3, 3)
    dimshow(subpixfft)
    plt.title('subpix FFT conv')
    plt.colorbar()

    plt.subplot(2, 3, 4)
    dimshow(np.log10(np.maximum(binned / np.max(binned), 1e-12)))
    plt.title('log binned FFT conv')
    plt.colorbar()

    plt.subplot(2, 3, 5)
    dimshow(np.log10(np.maximum(Gmine / np.max(Gmine), 1e-12)))
    plt.title('log my conv')
    plt.colorbar()

    # binned is already no larger than Gmine; crop Gmine to match it.
    bh, bw = binned.shape
    Gmine = Gmine[:bh, :bw]
    diff = Gmine - binned

    plt.subplot(2, 3, 6)
    dimshow(diff)
    plt.title('mine - FFT')
    plt.colorbar()

    plt.suptitle('PSF %g, Gal %g, subsample %i' % (psf_sigma, gal_sigma, s))

    ps.savefig()

    rmsdiff = np.sqrt(np.mean(diff**2))
    return rmsdiff
def stage_initplots(
    coimgs=None, cons=None, bands=None, ps=None,
    targetwcs=None,
    blobs=None,
    T=None, cat=None, tims=None, tractor=None, **kwargs):
    """Make zoomed-in figure panels of the data, SDSS models and detections.

    Works on the pixel box x in [x0, x1), y in [y0, y1) of *targetwcs*
    (currently 0..500 on both axes) and saves, via *ps* (PNG and PDF):
      1. the RGB data coadd in the box;
      2. the RGB model of the SDSS sources plus simulated noise;
      3. the noiseless RGB model of the SDSS sources;
      4. figure 3 with SDSS sources circled (green: PointSource, red: other);
      5. figure 4 with SED-matched detections also circled in cyan;
      6. the data coadd with cyan arrows at SED-matched detections;
      7. an RGB SDSS coadd (``sdss_coadd``) of the same box.

    Sources with nonzero ``T.objid`` are treated as "SDSS" sources and those
    with ``objid == 0`` as "SED-matched" detections.
    """
    print('T:')
    T.about()

    # Zoom-in box in targetwcs pixels.  Earlier choices (x0,x1, y0,y1):
    # (1700,2700, 200,1200), (1900,2500, 400,1000), (1900,2400, 450,950).
    x0, x1, y0, y1 = 0, 500, 0, 500

    clco = [co[y0:y1, x0:x1] for co in coimgs]
    clW, clH = x1 - x0, y1 - y0
    clwcs = targetwcs.get_subimage(x0, y0, clW, clH)

    plt.figure(figsize=(6, 6))
    plt.subplots_adjust(left=0.005, right=0.995, bottom=0.005, top=0.995)
    ps.suffixes = ['png', 'pdf']

    # 1. Data coadd in the zoom box.
    plt.clf()
    dimshow(get_rgb(clco, bands), ticks=False)
    ps.savefig()

    # objid may be strings with '' for sources without an SDSS id; convert
    # to int.  If it is already numeric (len() raises TypeError) or cannot be
    # parsed, leave it unchanged as before, but let other errors propagate.
    try:
        T.objid = np.array([int(x) if len(x) else 0 for x in T.objid])
    except (TypeError, ValueError):
        pass
    scat = Catalog(*[cat[i] for i in np.flatnonzero(T.objid)])
    sedcat = Catalog(*[cat[i] for i in np.flatnonzero(T.objid == 0)])

    print('%d total sources' % len(cat))
    print('%d SDSS sources' % len(scat))
    print('%d SED-matched sources' % len(sedcat))
    tr = Tractor(tractor.images, scat)

    # Per-band coadds of the SDSS-only model, without and with noise.
    comods = []
    comods2 = []
    for band in bands:
        comod = np.zeros((clH, clW))
        comod2 = np.zeros((clH, clW))
        con = np.zeros((clH, clW))
        for tim in tims:
            if tim.band != band:
                continue
            (Yo, Xo, Yi, Xi) = tim.resamp
            mod = tr.getModelImage(tim)
            # Shift into zoom-box coordinates.  Use new arrays: the previous
            # `Yo -= y0` modified the arrays stored in tim.resamp in place.
            Yo = Yo - y0
            Xo = Xo - x0
            K, = np.nonzero((Yo >= 0) * (Yo < clH) * (Xo >= 0) * (Xo < clW))
            Xo, Yo, Xi, Yi = Xo[K], Yo[K], Xi[K], Yi[K]
            comod[Yo, Xo] += mod[Yi, Xi]
            ie = tim.getInvError()
            # Pixels with zero inverse error get zero noise; silence the
            # division warning for them.
            with np.errstate(divide='ignore', invalid='ignore'):
                noise = np.random.normal(size=ie.shape) / ie
            noise[ie == 0] = 0.
            comod2[Yo, Xo] += mod[Yi, Xi] + noise[Yi, Xi]
            con[Yo, Xo] += 1
        comod /= np.maximum(con, 1)
        comods.append(comod)
        comod2 /= np.maximum(con, 1)
        comods2.append(comod2)

    # 2. Model + noise.
    plt.clf()
    dimshow(get_rgb(comods2, bands), ticks=False)
    ps.savefig()

    # 3. Noiseless model.
    plt.clf()
    dimshow(get_rgb(comods, bands), ticks=False)
    ps.savefig()

    # 4. Overplot SDSS sources.  The -1 presumably converts 1-based WCS
    # pixel coordinates to 0-based image coordinates.
    ax = plt.axis()
    for src in scat:
        rd = src.getPosition()
        ok, x, y = clwcs.radec2pixelxy(rd.ra, rd.dec)
        colour = (0, 1, 0) if isinstance(src, PointSource) else 'r'
        plt.plot(x - 1, y - 1, 'o', mec=colour, mfc='none', ms=10, mew=1.5)
    plt.axis(ax)
    ps.savefig()

    # 5. Add SED-matched detections.
    for src in sedcat:
        rd = src.getPosition()
        ok, x, y = clwcs.radec2pixelxy(rd.ra, rd.dec)
        plt.plot(x - 1, y - 1, 'o', mec='c', mfc='none', ms=10, mew=1.5)
    plt.axis(ax)
    ps.savefig()

    # 6. Mark SED-matched detections on the data with arrows pointing down
    # from `hi` to `lo` pixels above each detection.
    plt.clf()
    dimshow(get_rgb(clco, bands), ticks=False)
    ax = plt.axis()
    hi, lo = 20, 7
    for src in sedcat:
        rd = src.getPosition()
        ok, x, y = clwcs.radec2pixelxy(rd.ra, rd.dec)
        x, y = x - 1, y - 1
        plt.annotate('', (x, y + lo), xytext=(x, y + hi),
                     arrowprops=dict(color='c', width=1, frac=0.3, headwidth=5))
    plt.axis(ax)
    ps.savefig()

    # 7. SDSS coadd of the same zoom box.
    sdsscoimgs, _ = sdss_coadd(clwcs, bands)

    plt.clf()
    # NOTE(review): `rgbkwargs` is not defined in this function; presumably
    # a module-level dict of get_rgb options — confirm.
    dimshow(get_rgb(sdsscoimgs, bands, **rgbkwargs), ticks=False)
    ps.savefig()
Exemple #14
0
def stage0(**kwargs):
    ps = PlotSequence('cfht')

    decals = CfhtDecals()
    B = decals.get_bricks()
    print 'Bricks:'
    B.about()

    ra,dec = 190.0, 11.0

    #bands = 'ugri'
    bands = 'gri'
    
    B.cut(np.argsort(degrees_between(ra, dec, B.ra, B.dec)))
    print 'Nearest bricks:', B.ra[:5], B.dec[:5], B.brickid[:5]

    brick = B[0]
    pixscale = 0.186
    #W,H = 1024,1024
    #W,H = 2048,2048
    #W,H = 3600,3600
    W,H = 4800,4800

    targetwcs = wcs_for_brick(brick, pixscale=pixscale, W=W, H=H)
    ccdfn = 'cfht-ccds.fits'
    if os.path.exists(ccdfn):
        T = fits_table(ccdfn)
    else:
        T = get_ccd_list()
        T.writeto(ccdfn)
    print len(T), 'CCDs'
    T.cut(ccds_touching_wcs(targetwcs, T))
    print len(T), 'CCDs touching brick'

    T.cut(np.array([b in bands for b in T.filter]))
    print len(T), 'in bands', bands

    ims = []
    for t in T:
        im = CfhtImage(t)
        # magzp = hdr['PHOT_C'] + 2.5 * np.log10(hdr['EXPTIME'])
        # fwhm = t.seeing / (pixscale * 3600)
        # print '-> FWHM', fwhm, 'pix'
        im.seeing = t.seeing
        im.pixscale = t.pixscale
        print 'seeing', t.seeing
        print 'pixscale', im.pixscale*3600, 'arcsec/pix'
        im.run_calibs(t.ra, t.dec, im.pixscale, W=t.width, H=t.height)
        ims.append(im)


    # Read images, clip to ROI
    targetrd = np.array([targetwcs.pixelxy2radec(x,y) for x,y in
                         [(1,1),(W,1),(W,H),(1,H),(1,1)]])
    keepims = []
    tims = []
    for im in ims:
        print
        print 'Reading expnum', im.expnum, 'name', im.extname, 'band', im.band, 'exptime', im.exptime
        band = im.band
        wcs = im.read_wcs()
        imh,imw = wcs.imageh,wcs.imagew
        imgpoly = [(1,1),(1,imh),(imw,imh),(imw,1)]
        ok,tx,ty = wcs.radec2pixelxy(targetrd[:-1,0], targetrd[:-1,1])
        tpoly = zip(tx,ty)
        clip = clip_polygon(imgpoly, tpoly)
        clip = np.array(clip)
        #print 'Clip', clip
        if len(clip) == 0:
            continue
        x0,y0 = np.floor(clip.min(axis=0)).astype(int)
        x1,y1 = np.ceil (clip.max(axis=0)).astype(int)
        slc = slice(y0,y1+1), slice(x0,x1+1)

        ## FIXME -- it seems I got lucky and the cross product is
        ## negative == clockwise, as required by clip_polygon. One
        ## could check this and reverse the polygon vertex order.
        # dx0,dy0 = tx[1]-tx[0], ty[1]-ty[0]
        # dx1,dy1 = tx[2]-tx[1], ty[2]-ty[1]
        # cross = dx0*dy1 - dx1*dy0
        # print 'Cross:', cross

        print 'Image slice: x [%i,%i], y [%i,%i]' % (x0,x1, y0,y1)
        print 'Reading image from', im.imgfn, 'HDU', im.hdu
        img,imghdr = im.read_image(header=True, slice=slc)
        goodpix = (img != 0)
        print 'Number of pixels == 0:', np.sum(img == 0)
        print 'Number of pixels != 0:', np.sum(goodpix)
        if np.sum(goodpix) == 0:
            continue
        # print 'Image shape', img.shape
        print 'Image range', img.min(), img.max()
        print 'Goodpix image range:', (img[goodpix]).min(), (img[goodpix]).max()
        if img[goodpix].min() == img[goodpix].max():
            print 'No dynamic range in image'
            continue
        # print 'Reading invvar from', im.wtfn, 'HDU', im.hdu
        # invvar = im.read_invvar(slice=slc)
        # # print 'Invvar shape', invvar.shape
        # # print 'Invvar range:', invvar.min(), invvar.max()
        # invvar[goodpix == 0] = 0.
        # if np.all(invvar == 0.):
        #     print 'Skipping zero-invvar image'
        #     continue
        # assert(np.all(np.isfinite(img)))
        # assert(np.all(np.isfinite(invvar)))
        # assert(not(np.all(invvar == 0.)))
        # # Estimate per-pixel noise via Blanton's 5-pixel MAD
        # slice1 = (slice(0,-5,10),slice(0,-5,10))
        # slice2 = (slice(5,None,10),slice(5,None,10))
        # # print 'sliced shapes:', img[slice1].shape, img[slice2].shape
        # # print 'good shape:', (goodpix[slice1] * goodpix[slice2]).shape
        # # print 'good values:', np.unique(goodpix[slice1] * goodpix[slice2])
        # # print 'sliced[good] shapes:', (img[slice1] -  img[slice2])[goodpix[slice1] * goodpix[slice2]].shape
        # mad = np.median(np.abs(img[slice1] - img[slice2])[goodpix[slice1] * goodpix[slice2]].ravel())
        # sig1 = 1.4826 * mad / np.sqrt(2.)
        # print 'MAD sig1:', sig1
        # # invvar was 1 or 0
        # invvar *= (1./(sig1**2))
        # medsky = np.median(img[goodpix])

        # Read full image for sig1 and sky estimate
        fullimg = im.read_image()
        fullgood = (fullimg != 0)
        # Estimate per-pixel noise via Blanton's 5-pixel MAD
        slice1 = (slice(0,-5,10),slice(0,-5,10))
        slice2 = (slice(5,None,10),slice(5,None,10))
        mad = np.median(np.abs(fullimg[slice1] - fullimg[slice2])[fullgood[slice1] * fullgood[slice2]].ravel())
        sig1 = 1.4826 * mad / np.sqrt(2.)
        print 'MAD sig1:', sig1
        medsky = np.median(fullimg[fullgood])
        invvar = np.zeros_like(img)
        invvar[goodpix] = 1./sig1**2

        # Median-smooth sky subtraction
        plt.clf()
        dimshow(np.round((img-medsky) / sig1), vmin=-3, vmax=5)
        plt.title('Scalar median: %s' % im.name)
        ps.savefig()

        # medsky = np.zeros_like(img)
        # # astrometry.util.util
        # median_smooth(img, np.logical_not(goodpix), 256, medsky)
        fullmed = np.zeros_like(fullimg)
        median_smooth(fullimg - medsky, np.logical_not(fullgood), 256, fullmed)
        fullmed += medsky
        medimg = fullmed[slc]
        
        plt.clf()
        dimshow(np.round((img - medimg) / sig1), vmin=-3, vmax=5)
        plt.title('Median filtered: %s' % im.name)
        ps.savefig()
        
        #print 'Subtracting median:', medsky
        #img -= medsky
        img -= medimg
        
        primhdr = im.read_image_primary_header()

        magzp = decals.get_zeropoint_for(im)
        print 'magzp', magzp
        zpscale = NanoMaggies.zeropointToScale(magzp)
        print 'zpscale', zpscale

        # Scale images to Nanomaggies
        img /= zpscale
        sig1 /= zpscale
        invvar *= zpscale**2
        orig_zpscale = zpscale

        zpscale = 1.
        assert(np.sum(invvar > 0) > 0)
        print 'After scaling:'
        print 'sig1', sig1
        print 'invvar range', invvar.min(), invvar.max()
        print 'image range', img.min(), img.max()

        assert(np.all(np.isfinite(img)))
        assert(np.all(np.isfinite(invvar)))
        assert(np.isfinite(sig1))

        plt.clf()
        lo,hi = -5*sig1, 10*sig1
        n,b,p = plt.hist(img[goodpix].ravel(), 100, range=(lo,hi), histtype='step', color='k')
        xx = np.linspace(lo, hi, 200)
        plt.plot(xx, max(n)*np.exp(-xx**2 / (2.*sig1**2)), 'r-')
        plt.xlim(lo,hi)
        plt.title('Pixel histogram: %s' % im.name)
        ps.savefig()

        twcs = ConstantFitsWcs(wcs)
        if x0 or y0:
            twcs.setX0Y0(x0,y0)

        info = im.get_image_info()
        fullh,fullw = info['dims']

        # read fit PsfEx model
        psfex = PsfEx.fromFits(im.psffitfn)
        print 'Read', psfex

        # HACK -- highly approximate PSF here!
        #psf_fwhm = imghdr['FWHM']
        #psf_fwhm = im.seeing

        psf_fwhm = im.seeing / (im.pixscale * 3600)
        print 'PSF FWHM', psf_fwhm, 'pixels'
        psf_sigma = psf_fwhm / 2.35
        psf = NCircularGaussianPSF([psf_sigma],[1.])

        print 'img type', img.dtype
        
        tim = Image(img, invvar=invvar, wcs=twcs, psf=psf,
                    photocal=LinearPhotoCal(zpscale, band=band),
                    sky=ConstantSky(0.), name=im.name + ' ' + band)
        tim.zr = [-3. * sig1, 10. * sig1]
        tim.sig1 = sig1
        tim.band = band
        tim.psf_fwhm = psf_fwhm
        tim.psf_sigma = psf_sigma
        tim.sip_wcs = wcs
        tim.x0,tim.y0 = int(x0),int(y0)
        tim.psfex = psfex
        tim.imobj = im
        mn,mx = tim.zr
        tim.ima = dict(interpolation='nearest', origin='lower', cmap='gray',
                       vmin=mn, vmax=mx)
        tims.append(tim)
        keepims.append(im)

    ims = keepims

    print 'Computing resampling...'
    # save resampling params
    for tim in tims:
        wcs = tim.sip_wcs
        subh,subw = tim.shape
        subwcs = wcs.get_subimage(tim.x0, tim.y0, subw, subh)
        tim.subwcs = subwcs
        try:
            Yo,Xo,Yi,Xi,rims = resample_with_wcs(targetwcs, subwcs, [], 2)
        except OverlapError:
            print 'No overlap'
            continue
        if len(Yo) == 0:
            continue
        tim.resamp = (Yo,Xo,Yi,Xi)

    print 'Creating coadds...'
    # Produce per-band coadds, for plots
    coimgs = []
    cons = []
    for ib,band in enumerate(bands):
        coimg = np.zeros((H,W), np.float32)
        con   = np.zeros((H,W), np.uint8)
        for tim in tims:
            if tim.band != band:
                continue
            (Yo,Xo,Yi,Xi) = tim.resamp
            if len(Yo) == 0:
                continue
            nn = (tim.getInvvar()[Yi,Xi] > 0)
            coimg[Yo,Xo] += tim.getImage ()[Yi,Xi] * nn
            con  [Yo,Xo] += nn

            # print
            # print 'tim', tim.name
            # print 'number of resampled pix:', len(Yo)
            # reim = np.zeros_like(coimg)
            # ren  = np.zeros_like(coimg)
            # reim[Yo,Xo] = tim.getImage()[Yi,Xi] * nn
            # ren[Yo,Xo] = nn
            # print 'number of resampled pix with positive invvar:', ren.sum()
            # plt.clf()
            # plt.subplot(2,2,1)
            # mn,mx = [np.percentile(reim[ren>0], p) for p in [25,95]]
            # print 'Percentiles:', mn,mx
            # dimshow(reim, vmin=mn, vmax=mx)
            # plt.colorbar()
            # plt.subplot(2,2,2)
            # dimshow(con)
            # plt.colorbar()
            # plt.subplot(2,2,3)
            # dimshow(reim, vmin=tim.zr[0], vmax=tim.zr[1])
            # plt.colorbar()
            # plt.subplot(2,2,4)
            # plt.hist(reim.ravel(), 100, histtype='step', color='b')
            # plt.hist(tim.getImage().ravel(), 100, histtype='step', color='r')
            # plt.suptitle('%s: %s' % (band, tim.name))
            # ps.savefig()

        coimg /= np.maximum(con,1)
        coimgs.append(coimg)
        cons  .append(con)

    plt.clf()
    dimshow(get_rgb(coimgs, bands))
    ps.savefig()

    plt.clf()
    for i,b in enumerate(bands):
        plt.subplot(2,2,i+1)
        dimshow(cons[i], ticks=False)
        plt.title('%s band' % b)
        plt.colorbar()
    plt.suptitle('Number of exposures')
    ps.savefig()

    print 'Grabbing SDSS sources...'
    bandlist = [b for b in bands]
    cat,T = get_sdss_sources(bandlist, targetwcs)
    # record coordinates in target brick image
    ok,T.tx,T.ty = targetwcs.radec2pixelxy(T.ra, T.dec)
    T.tx -= 1
    T.ty -= 1
    T.itx = np.clip(np.round(T.tx).astype(int), 0, W-1)
    T.ity = np.clip(np.round(T.ty).astype(int), 0, H-1)

    plt.clf()
    dimshow(get_rgb(coimgs, bands))
    ax = plt.axis()
    plt.plot(T.tx, T.ty, 'o', mec=green, mfc='none', ms=10, mew=1.5)
    plt.axis(ax)
    plt.title('SDSS sources')
    ps.savefig()

    print 'Detmaps...'
    # Render the detection maps
    detmaps = dict([(b, np.zeros((H,W), np.float32)) for b in bands])
    detivs  = dict([(b, np.zeros((H,W), np.float32)) for b in bands])
    for tim in tims:
        iv = tim.getInvvar()
        psfnorm = 1./(2. * np.sqrt(np.pi) * tim.psf_sigma)
        detim = tim.getImage().copy()
        detim[iv == 0] = 0.
        detim = gaussian_filter(detim, tim.psf_sigma) / psfnorm**2
        detsig1 = tim.sig1 / psfnorm
        subh,subw = tim.shape
        detiv = np.zeros((subh,subw), np.float32) + (1. / detsig1**2)
        detiv[iv == 0] = 0.
        (Yo,Xo,Yi,Xi) = tim.resamp
        detmaps[tim.band][Yo,Xo] += detiv[Yi,Xi] * detim[Yi,Xi]
        detivs [tim.band][Yo,Xo] += detiv[Yi,Xi]

    rtn = dict()
    for k in ['T', 'coimgs', 'cons', 'detmaps', 'detivs',
              'targetrd', 'pixscale', 'targetwcs', 'W','H',
              'bands', 'tims', 'ps', 'brick', 'cat']:
        rtn[k] = locals()[k]
    return rtn
def stage_initplots(coimgs=None,
                    cons=None,
                    bands=None,
                    ps=None,
                    targetwcs=None,
                    blobs=None,
                    T=None,
                    cat=None,
                    tims=None,
                    tractor=None,
                    **kwargs):
    '''
    Make presentation plots of a zoom-in on the brick, brick pixels
    [x0,x1) x [y0,y1) below.  The plots are written through *ps*:

    - the RGB image coadd of the zoom-in;
    - coadded tractor models of the SDSS-only catalog, with and without a
      per-pixel Gaussian noise realization;
    - the same model with SDSS sources circled (point sources green, others
      red), then additionally the SED-matched detections (cyan);
    - the image coadd with arrows marking the SED-matched detections;
    - an SDSS coadd of the same region.

    Sources in *cat* with a non-zero T.objid are treated as SDSS sources;
    those with objid == 0 are treated as SED-matched detections.

    Side effects: may convert T.objid from strings to ints, creates a new
    6x6-inch figure and sets ps.suffixes = ['png', 'pdf'].
    The arrays in tim.resamp are NOT modified.
    '''
    print('T:')
    T.about()

    # cluster zoom-in region, in brick pixel coordinates
    x0, x1, y0, y1 = 0, 500, 0, 500

    clco = [co[y0:y1, x0:x1] for co in coimgs]
    clW, clH = x1 - x0, y1 - y0
    clwcs = targetwcs.get_subimage(x0, y0, clW, clH)

    plt.figure(figsize=(6, 6))
    plt.subplots_adjust(left=0.005, right=0.995, bottom=0.005, top=0.995)
    ps.suffixes = ['png', 'pdf']

    # cluster zoom-in
    rgb = get_rgb(clco, bands)
    plt.clf()
    dimshow(rgb, ticks=False)
    ps.savefig()

    # Initial model (SDSS only)
    try:
        # convert from string to int ('' -> 0)
        T.objid = np.array([int(x) if len(x) else 0 for x in T.objid])
    except (TypeError, ValueError):
        # Already numeric (len() fails), or not parseable: leave as-is.
        pass
    scat = Catalog(*[cat[i] for i in np.flatnonzero(T.objid)])
    sedcat = Catalog(*[cat[i] for i in np.flatnonzero(T.objid == 0)])

    print(len(cat), 'total sources')
    print(len(scat), 'SDSS sources')
    print(len(sedcat), 'SED-matched sources')
    tr = Tractor(tractor.images, scat)

    # Per-band mean model coadds over the zoom-in: comods = model only,
    # comods2 = model plus one noise realization.
    comods = []
    comods2 = []
    for iband, band in enumerate(bands):
        comod = np.zeros((clH, clW))
        comod2 = np.zeros((clH, clW))
        con = np.zeros((clH, clW))
        for itim, tim in enumerate(tims):
            if tim.band != band:
                continue
            (Yo, Xo, Yi, Xi) = tim.resamp
            mod = tr.getModelImage(tim)
            # Shift into zoom-in coordinates.  Out-of-place on purpose: the
            # arrays belong to tim.resamp and must not be modified.
            Yo = Yo - y0
            Xo = Xo - x0
            K, = np.nonzero((Yo >= 0) * (Yo < clH) * (Xo >= 0) * (Xo < clW))
            Xo, Yo, Xi, Yi = Xo[K], Yo[K], Xi[K], Yi[K]
            comod[Yo, Xo] += mod[Yi, Xi]
            ie = tim.getInvError()
            # Noise with sigma = 1/ie; no-data pixels (ie == 0) get none.
            with np.errstate(divide='ignore', invalid='ignore'):
                noise = np.random.normal(size=ie.shape) / ie
            noise[ie == 0] = 0.
            comod2[Yo, Xo] += mod[Yi, Xi] + noise[Yi, Xi]
            con[Yo, Xo] += 1
        comod /= np.maximum(con, 1)
        comods.append(comod)
        comod2 /= np.maximum(con, 1)
        comods2.append(comod2)

    plt.clf()
    dimshow(get_rgb(comods2, bands), ticks=False)
    ps.savefig()

    plt.clf()
    dimshow(get_rgb(comods, bands), ticks=False)
    ps.savefig()

    # Overplot SDSS sources
    ax = plt.axis()
    for src in scat:
        rd = src.getPosition()
        ok, x, y = clwcs.radec2pixelxy(rd.ra, rd.dec)
        cc = (0, 1, 0)
        if isinstance(src, PointSource):
            plt.plot(x - 1, y - 1, 'o', mec=cc, mfc='none', ms=10, mew=1.5)
        else:
            plt.plot(x - 1, y - 1, 'o', mec='r', mfc='none', ms=10, mew=1.5)
    plt.axis(ax)
    ps.savefig()

    # Add SED-matched detections
    for src in sedcat:
        rd = src.getPosition()
        ok, x, y = clwcs.radec2pixelxy(rd.ra, rd.dec)
        plt.plot(x - 1, y - 1, 'o', mec='c', mfc='none', ms=10, mew=1.5)
    plt.axis(ax)
    ps.savefig()

    # Mark SED-matched detections on image with a downward arrow ending
    # `lo` pixels above the source.
    plt.clf()
    dimshow(get_rgb(clco, bands), ticks=False)
    ax = plt.axis()
    for src in sedcat:
        rd = src.getPosition()
        ok, x, y = clwcs.radec2pixelxy(rd.ra, rd.dec)
        x, y = x - 1, y - 1
        hi, lo = 20, 7
        plt.annotate('', (x, y + lo),
                     xytext=(x, y + hi),
                     arrowprops=dict(color='c', width=1, frac=0.3,
                                     headwidth=5))
    plt.axis(ax)
    ps.savefig()

    # find SDSS fields within that WCS
    sdsscoimgs, nil = sdss_coadd(clwcs, bands)

    plt.clf()
    # NOTE(review): `rgbkwargs` is not defined in this function; presumably
    # module-level -- confirm.
    dimshow(get_rgb(sdsscoimgs, bands, **rgbkwargs), ticks=False)
    ps.savefig()
Exemple #16
0
def _plot_blob_overlay(mask, blobslices, T, ps):
    '''
    Debug plot for segment_and_group_sources.  Draws *mask*, outlines each
    bounding box in *blobslices* (labelled 1..N) and marks every source in
    *T* at its integer (itx, ity) position, then saves via *ps*.
    '''
    import pylab as plt
    from astrometry.util.plotutils import dimshow
    plt.clf()
    dimshow(mask, vmin=0, vmax=1)
    ax = plt.axis()
    for i,bs in enumerate(blobslices):
        sy,sx = bs
        by0,by1 = sy.start, sy.stop
        bx0,bx1 = sx.start, sx.stop
        plt.plot([bx0, bx0, bx1, bx1, bx0], [by0, by1, by1, by0, by0], 'r-')
        plt.text((bx0+bx1)/2., by0, '%i' % (i+1), ha='center', va='bottom', color='r')
    plt.plot(T.itx, T.ity, 'rx')
    for i,t in enumerate(T):
        plt.text(t.itx, t.ity, 'src %i' % i, color='red', ha='left', va='center')
    plt.axis(ax)
    plt.title('Blobs')
    ps.savefig()


def segment_and_group_sources(image, T, name=None, ps=None, plots=False):
    '''
    Segment a binary image into connected "blobs" and group the sources in
    *T* by the blob they fall in.

    *image*: binary image that defines "blobs"; holes are filled before
        labelling.
    *T*: source table; only ".itx" and ".ity" elements are used (x,y integer
        pix pos, which must index into *image*).  Note: ".blob" field is
        added.
    *name*: for debugging only (used in the debug FITS file names)
    *ps*: PlotSequence; needed when *plots* is True.
    *plots*: make debug plots and write blobs-before-<name>.fits and
        blobs-after-<name>.fits.

    Blobs that contain no sources are dropped.  A source that lands on no
    blob gets an 11x11-pixel square (clipped to the image) around it.  If
    that square touches existing blob(s), the source joins the lowest-numbered
    one, whose bounding box is grown to cover the square.  Otherwise the
    square becomes a new singleton blob.

    Returns: (blobs, blobsrcs, blobslices)

    *blobs*: image, values -1 = no blob, integer blob indices
    *blobsrcs*: list of np arrays of integers, elements in T within each blob
    *blobslices*: list of slice objects for blob bounding-boxes.
    '''
    from scipy.ndimage import binary_fill_holes, label, find_objects

    # Label that scipy.ndimage.label gives to background pixels.
    emptyblob = 0
    # Half-width of the square "blob" made around a source on no blob.
    S = 5

    image = binary_fill_holes(image)

    blobs,nblobs = label(image)
    print('N detected blobs:', nblobs)
    H,W = image.shape
    del image

    blobslices = find_objects(blobs)
    T.blob = blobs[T.ity, T.itx]

    if plots:
        _plot_blob_overlay(blobs > 0, blobslices, T, ps)

    # Find sets of sources within blobs.  blobmap: label -> index into
    # blobsrcs, or -1 for a source-less blob (dropped, slice kept aside).
    blobsrcs = []
    keepslices = []
    blobmap = {}
    dropslices = {}
    for blob in range(1, nblobs+1):
        Isrcs = np.flatnonzero(T.blob == blob)
        if len(Isrcs) == 0:
            blobmap[blob] = -1
            dropslices[blob] = blobslices[blob-1]
            continue
        blobmap[blob] = len(blobsrcs)
        blobsrcs.append(Isrcs)
        keepslices.append(blobslices[blob-1])
    blobslices = keepslices

    # Find sources that do not belong to a blob and add them as
    # singleton "blobs"; otherwise they don't get optimized.
    inblobs = np.zeros(len(T), bool)
    for Isrcs in blobsrcs:
        inblobs[Isrcs] = True
    noblobs = np.flatnonzero(np.logical_not(inblobs))
    del inblobs
    for ib,i in enumerate(noblobs):
        bslc = (slice(np.clip(T.ity[i] - S, 0, H-1), np.clip(T.ity[i] + S+1, 0, H)),
                slice(np.clip(T.itx[i] - S, 0, W-1), np.clip(T.itx[i] + S+1, 0, W)))

        # Does this new blob overlap existing blob(s)?
        oblobs = np.unique(blobs[bslc])
        oblobs = oblobs[oblobs != emptyblob]

        if len(oblobs) > 1:
            print('WARNING: not merging overlapping blobs like maybe we should')
        if len(oblobs):
            blob = oblobs[0]
            # blobs[bslc] is a view: paint the square's background pixels
            # with the existing blob's label.
            blobs[bslc][blobs[bslc] == emptyblob] = blob
            blobindex = blobmap[blob]
            if blobindex == -1:
                # the overlapping blob was going to be dropped -- restore it.
                blobindex = len(blobsrcs)
                blobmap[blob] = blobindex
                blobslices.append(dropslices[blob])
                blobsrcs.append(np.array([], np.int64))
            # Expand the existing blob slice to encompass this new source
            sy,sx = blobslices[blobindex]
            oy0,oy1, ox0,ox1 = sy.start,sy.stop, sx.start,sx.stop
            sy,sx = bslc
            ny0,ny1, nx0,nx1 = sy.start,sy.stop, sx.start,sx.stop
            blobslices[blobindex] = (slice(min(oy0,ny0), max(oy1,ny1)),
                                     slice(min(ox0,nx0), max(ox1,nx1)))
            # Add this source to the list of source indices for the existing blob.
            blobsrcs[blobindex] = np.append(blobsrcs[blobindex], np.array([i]))
        else:
            # Synthetic blob number, above every label from label().
            blob = nblobs+1 + ib
            blobs[bslc][blobs[bslc] == emptyblob] = blob
            blobmap[blob] = len(blobsrcs)
            blobslices.append(bslc)
            blobsrcs.append(np.array([i]))

    # Remap the "blobs" image so that empty regions are = -1 and the blob values
    # correspond to their indices in the "blobsrcs" list.
    maxblob = max(blobmap.keys()) if len(blobmap) else 0
    maxblob = max(maxblob, blobs.max())
    # Default to -1 so that a label missing from blobmap reads as "no blob"
    # rather than silently becoming blob 0.
    bm = np.full(maxblob + 1, -1, int)
    for k,v in blobmap.items():
        bm[k] = v
    bm[0] = -1

    # DEBUG
    if plots:
        import fitsio
        fitsio.write('blobs-before-%s.fits' % name, blobs, clobber=True)

    # Remap blob numbers
    blobs = bm[blobs]

    if plots:
        import fitsio
        fitsio.write('blobs-after-%s.fits' % name, blobs, clobber=True)
        _plot_blob_overlay(blobs > -1, blobslices, T, ps)

    # Sanity check: every source lands on the blob it was assigned to.
    for j,Isrcs in enumerate(blobsrcs):
        for i in Isrcs:
            if (blobs[T.ity[i], T.itx[i]] != j):
                print('---------------------------!!!--------------------------')
                print('Blob', j, 'sources', Isrcs)
                print('Source', i, 'coords x,y', T.itx[i], T.ity[i])
                print('Expected blob value', j, 'but got', blobs[T.ity[i], T.itx[i]])

    T.blob = blobs[T.ity, T.itx]
    assert(len(blobsrcs) == len(blobslices))

    return blobs, blobsrcs, blobslices
def detection_plots(detmaps, detivs, bands, saturated_pix, tims, targetwcs,
                    refstars, large_galaxies, gaia_stars, ps):
    '''
    Write diagnostic plots for the detection step to *ps*:

    1. an RGB rendering of *detmaps*;
    2. the same rendering with the *saturated_pix* masks painted in
       (mask i sets RGB channel 2 - i to 1);
    3. if *refstars* is truthy: a quick coadd of *tims* with the reference
       sources overplotted.  Tycho-2 (isbright) is drawn in red.  Gaia
       (isgaia, only when *gaia_stars*) is drawn in cyan, labelled with
       phot_g_mean_mag.  Large galaxies (islargegalaxy, only when
       *large_galaxies*) are drawn in green;
    4. for each band, log and linear histograms of detmap * sqrt(detiv)
       over pixels with detiv > 0.
    '''
    rgb = get_rgb(detmaps, bands)
    plt.clf()
    dimshow(rgb)
    plt.title('detmaps')
    ps.savefig()

    for i, satpix in enumerate(saturated_pix):
        channel = 2 - i
        rgb[:, :, channel][satpix] = 1
    plt.clf()
    dimshow(rgb)
    plt.title('detmaps & saturated')
    ps.savefig()

    coimgs, _ = quick_coadds(tims, bands, targetwcs, fill_holes=False)

    if refstars:
        plt.clf()
        dimshow(get_rgb(coimgs, bands))
        ax = plt.axis()
        handles = []
        labels = []

        tycho = refstars[refstars.isbright]
        if len(tycho):
            _, tx, ty = targetwcs.radec2pixelxy(tycho.ra, tycho.dec)
            handles.append(plt.plot(tx - 1, ty - 1, 'o', mew=3, ms=14,
                                    mec='r', mfc='none'))
            labels.append('Tycho-2 only')

        gaia = refstars[refstars.isgaia] if gaia_stars else None
        if gaia_stars and len(gaia):
            _, gx, gy = targetwcs.radec2pixelxy(gaia.ra, gaia.dec)
            line = plt.plot(gx - 1, gy - 1, 'o', mew=3, ms=10,
                            mec='c', mfc='none')
            # Label each Gaia star with its G magnitude.
            for sx, sy, gmag in zip(gx, gy, gaia.phot_g_mean_mag):
                plt.text(sx, sy, '%.1f' % gmag, color='k',
                         bbox=dict(facecolor='w', alpha=0.5))
            handles.append(line)
            labels.append('Gaia')

        # star_clusters?
        galaxies = refstars[refstars.islargegalaxy] if large_galaxies else None
        if large_galaxies and len(galaxies):
            _, lx, ly = targetwcs.radec2pixelxy(galaxies.ra, galaxies.dec)
            handles.append(plt.plot(lx - 1, ly - 1, 'o', mew=3, ms=14,
                                    mec=(0, 1, 0), mfc='none'))
            labels.append('Galaxies')

        plt.axis(ax)
        plt.title('Ref sources')
        plt.figlegend([h[0] for h in handles], labels)
        ps.savefig()

    for band, detmap, detiv in zip(bands, detmaps, detivs):
        # Detection-map values in units of sigma, over pixels with data.
        detsn = (detmap * np.sqrt(detiv))[detiv > 0]
        plt.clf()
        plt.subplot(2, 1, 1)
        plt.hist(detsn, bins=50, range=(-5, 8), log=True)
        plt.title('Detection map pixel values (sigmas): band %s' % band)
        plt.subplot(2, 1, 2)
        plt.hist(detsn, bins=50, range=(-5, 8))
        ps.savefig()
Exemple #18
0
def sed_matched_detection(sedname, sed, detmaps, detivs, bands,
                          xomit, yomit,
                          nsigma=5.,
                          saturated_pix=None,
                          saddle=2.,
                          cutonaper=True,
                          ps=None):
    '''
    Runs a single SED-matched detection filter.

    Avoids creating sources close to existing sources.

    Parameters
    ----------
    sedname : string
        Name of this SED; only used for plots.
    sed : list of floats
        The SED -- a list of floats, one per band, of this SED.
    detmaps : list of numpy arrays
        The per-band detection maps.  These must all be the same size, the
        brick image size.
    detivs : list of numpy arrays
        The inverse-variance maps associated with `detmaps`.
    bands : list of strings
        The band names of the `detmaps` and `detivs` images.
    xomit, yomit : iterables (lists or numpy arrays) of int
        Previously known sources that are to be avoided.
    nsigma : float, optional
        Detection threshold.
    saturated_pix : None or numpy array, boolean
        A map of pixels that are always considered "hot" when
        determining whether a new source touches hot pixels of an
        existing source.
    saddle : float, optional
        Saddle-point depth from existing sources down to new sources.
    cutonaper : bool, optional
        Apply a cut that the source's detection strength must be greater
        than `nsigma` above the 16th percentile of the detection strength in
        an annulus (from 10 to 20 pixels) around the source.
    ps : PlotSequence object, optional
        Create plots?

    Returns
    -------
    hotblobs : numpy array of bool
        A map of the blobs yielding sources in this SED.
    px, py : numpy array of int
        The new sources found.
    aper : list of float
        For each new source, the 16th percentile of the detection strength
        in the annulus around it, or -1 if the annulus contains no pixels
        with nonzero inverse-variance.  Computed whether or not
        `cutonaper` is set.
    peakval : list of float
        The detection strength of each new source.

    If no band has both a nonzero SED weight and some pixel with nonzero
    inverse-variance, (None, None, None, None, None) is returned.

    See also
    --------
    sed_matched_filters : creates the `(sedname, sed)` pairs used here
    run_sed_matched_filters : calls this method

    '''
    # The scipy.ndimage.measurements / .morphology namespaces are deprecated;
    # these names live directly in scipy.ndimage.
    from scipy.ndimage import (label, find_objects, binary_dilation,
                               binary_fill_holes)

    t0 = Time()
    H,W = detmaps[0].shape

    # Nothing to detect unless some band has nonzero SED weight AND some
    # pixel with nonzero inverse-variance.
    allzero = not any((sed[iband] != 0) and not np.all(detivs[iband] == 0)
                      for iband in range(len(bands)))
    if allzero:
        print('SED', sedname, 'has all zero weight')
        return None,None,None,None,None

    sedmap = np.zeros((H,W), np.float32)
    sediv  = np.zeros((H,W), np.float32)
    for iband in range(len(bands)):
        if sed[iband] == 0:
            continue
        # We convert the detmap to canonical band via
        #   detmap * w
        # And the corresponding change to sig1 is
        #   sig1 * w
        # So the invvar-weighted sum is
        #    (detmap * w) / (sig1**2 * w**2)
        #  = detmap / (sig1**2 * w)
        sedmap += detmaps[iband] * detivs[iband] / sed[iband]
        sediv  += detivs [iband] / sed[iband]**2
    sedmap /= np.maximum(1e-16, sediv)
    sedsn   = sedmap * np.sqrt(sediv)
    del sedmap

    peaks = (sedsn > nsigma)
    print('SED sn:', Time()-t0)
    t0 = Time()

    def saddle_level(Y):
        # Require a saddle that drops by (the larger of) "saddle"
        # sigma, or 20% of the peak height
        drop = max(saddle, Y * 0.2)
        return Y - drop

    # saddle_level(Y) = min(Y - saddle, 0.8*Y) increases with Y, and every
    # candidate peak has S/N > nsigma, so saddle_level(nsigma) is the lowest
    # saddle any peak can require.  Segmenting there guarantees each peak's
    # saddle region lies inside its "allblob".  (For saddle >= 0.2*nsigma,
    # e.g. the defaults, this equals nsigma - saddle.)
    lowest_saddle = saddle_level(nsigma)

    # zero out the edges -- larger margin here?
    peaks[0 ,:] = 0
    peaks[:, 0] = 0
    peaks[-1,:] = 0
    peaks[:,-1] = 0

    # Label the N-sigma blobs at this point... we'll use this to build
    # "sedhot", which in turn is used to define the blobs that we will
    # optimize simultaneously.  This also determines which pixels go
    # into the fitting!
    dilate = 8
    hotblobs,nhot = label(binary_fill_holes(
            binary_dilation(peaks, iterations=dilate)))

    # find pixels that are larger than their 8 neighbors
    peaks[1:-1, 1:-1] &= (sedsn[1:-1,1:-1] >= sedsn[0:-2,1:-1])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1,1:-1] >= sedsn[2:  ,1:-1])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1,1:-1] >= sedsn[1:-1,0:-2])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1,1:-1] >= sedsn[1:-1,2:  ])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1,1:-1] >= sedsn[0:-2,0:-2])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1,1:-1] >= sedsn[0:-2,2:  ])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1,1:-1] >= sedsn[2:  ,0:-2])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1,1:-1] >= sedsn[2:  ,2:  ])
    print('Peaks:', Time()-t0)
    t0 = Time()

    if ps is not None:
        from astrometry.util.plotutils import dimshow

        def plot_boundary_map(X):
            # Overlay, in green, the pixels just outside boolean mask X.
            # (numpy does not support "-" on boolean arrays, so use & ~X.)
            bounds = binary_dilation(X) & ~X
            h,w = X.shape
            rgba = np.zeros((h,w,4), np.uint8)
            rgba[:,:,1] = bounds*255
            rgba[:,:,3] = bounds*255
            plt.imshow(rgba, interpolation='nearest', origin='lower')

        plt.clf()
        plt.subplot(1,2,2)
        dimshow(sedsn, vmin=-2, vmax=100, cmap='hot', ticks=False)
        plt.subplot(1,2,1)
        dimshow(sedsn, vmin=-2, vmax=10, cmap='hot', ticks=False)
        above = (sedsn > nsigma)
        plot_boundary_map(above)
        ax = plt.axis()
        y,x = np.nonzero(peaks)
        plt.plot(x, y, 'r+')
        plt.axis(ax)
        plt.title('SED %s: S/N & peaks' % sedname)
        ps.savefig()

    # For each new source, compute the saddle value, segment at that
    # level, and drop the source if it is in the same blob as a
    # previously-detected source.  We dilate the blobs a bit too, to
    # catch slight differences in centroid vs SDSS sources.
    dilate = 2

    # For efficiency, segment at the minimum saddle level to compute
    # slices; the operations described above need only happen within
    # the slice.
    saddlemap = (sedsn > lowest_saddle)
    if saturated_pix is not None:
        saddlemap |= saturated_pix
    saddlemap = binary_dilation(saddlemap, iterations=dilate)
    allblobs,nblobs = label(saddlemap)
    allslices = find_objects(allblobs)
    ally0 = [sy.start for sy,sx in allslices]
    allx0 = [sx.start for sy,sx in allslices]

    # brightest peaks first
    py,px = np.nonzero(peaks)
    I = np.argsort(-sedsn[py,px])
    py = py[I]
    px = px[I]

    keep = np.zeros(len(px), bool)

    peakval = []
    aper = []
    apin = 10
    apout = 20

    # Map of pixels that are vetoed by sources found so far.  The veto
    # area is based on saddle height.  We go from brightest to
    # faintest pixels.  Thus the saddle level decreases, and the
    # saddlemap areas become larger; the saddlemap when a source is
    # found is a lower bound on the pixels that it will veto based on
    # the saddle heights of fainter sources.  Thus the vetomap isn't
    # the final word, it is just a quick veto of pixels we know for
    # sure will be vetoed.
    vetomap = np.zeros(sedsn.shape, bool)

    # For each peak, determine whether it is isolated enough --
    # separated by a low enough saddle from other sources.  Need only
    # search within its "allblob", which is defined by the lowest
    # saddle.
    print('Found', len(px), 'potential peaks')
    for i,(x,y) in enumerate(zip(px, py)):
        if vetomap[y,x]:
            continue

        level = saddle_level(sedsn[y,x])
        ablob = allblobs[y,x]
        index = ablob - 1
        slc = allslices[index]

        # This peak's saddle region, restricted to its allblob.
        saddlemap = (sedsn[slc] > level)
        if saturated_pix is not None:
            saddlemap |= saturated_pix[slc]
        saddlemap &= (allblobs[slc] == ablob)
        saddlemap = binary_fill_holes(saddlemap)
        saddlemap = binary_dilation(saddlemap, iterations=dilate)
        blobs,nblobs = label(saddlemap)
        x0,y0 = allx0[index], ally0[index]
        thisblob = blobs[y-y0, x-x0]

        # previously found sources (existing + kept so far), in slice coords.
        # np.append may promote to float (e.g. when xomit is an empty list),
        # hence the cast back to int.
        ox = np.append(xomit, px[:i][keep[:i]]) - x0
        oy = np.append(yomit, py[:i][keep[:i]]) - y0
        h,w = blobs.shape
        cut = False
        if len(ox):
            ox = ox.astype(int)
            oy = oy.astype(int)
            inslice = (ox >= 0) & (ox < w) & (oy >= 0) & (oy < h)
            cut = bool(np.any(inslice &
                              (blobs[np.clip(oy,0,h-1), np.clip(ox,0,w-1)] ==
                               thisblob)))

        if cut:
            # in same blob as previously found source: update vetomap
            vetomap[slc] |= saddlemap
            continue

        # Measure the S/N distribution in an annulus around the peak.
        ylo, yhi = max(0, y-apout), min(H, y+apout+1)
        xlo, xhi = max(0, x-apout), min(W, x+apout+1)
        ap   = sedsn[ylo:yhi, xlo:xhi]
        apiv = (sediv[ylo:yhi, xlo:xhi] > 0)
        aph,apw = ap.shape
        R2 = ((np.arange(aph)+ylo - y)[:,np.newaxis]**2 +
              (np.arange(apw)+xlo - x)[np.newaxis,:]**2)
        ap = ap[apiv & (R2 >= apin**2) & (R2 <= apout**2)]
        if len(ap):
            # 16th percentile ~ -1 sigma point.
            m = np.percentile(ap, 16.)
        else:
            # fake
            m = -1.
        if cutonaper and sedsn[y,x] - m < nsigma:
            continue

        aper.append(m)
        peakval.append(sedsn[y,x])
        keep[i] = True

        vetomap[slc] |= saddlemap

    print('New sources:', Time()-t0)

    if ps is not None:
        pxdrop = px[np.logical_not(keep)]
        pydrop = py[np.logical_not(keep)]

    py = py[keep]
    px = px[keep]

    # Which of the hotblobs yielded sources?  Those are the ones to keep.
    hbmap = np.zeros(nhot+1, bool)
    hbmap[hotblobs[py,px]] = True
    if len(xomit):
        h,w = hotblobs.shape
        # Cast to int (as is done for the omit positions above) so that
        # float-valued positions can be used as indices.
        oxi = np.clip(np.asarray(xomit).astype(int), 0, w-1)
        oyi = np.clip(np.asarray(yomit).astype(int), 0, h-1)
        hbmap[hotblobs[oyi, oxi]] = True
    # in case a source is (somehow) not in a hotblob?
    hbmap[0] = False
    hotblobs = hbmap[hotblobs]

    if ps is not None:
        plt.clf()
        dimshow(vetomap, vmin=0, vmax=1, cmap='hot')
        plt.title('SED %s: veto map' % sedname)
        ps.savefig()

        plt.clf()
        dimshow(hotblobs, vmin=0, vmax=1, cmap='hot')
        ax = plt.axis()
        p1 = plt.plot(px, py, 'g+', ms=8, mew=2)
        p2 = plt.plot(pxdrop, pydrop, 'm+', ms=8, mew=2)
        p3 = plt.plot(xomit, yomit, 'r+', ms=8, mew=2)
        plt.axis(ax)
        plt.title('SED %s: hot blobs' % sedname)
        plt.figlegend((p3[0],p1[0],p2[0]), ('Existing', 'Keep', 'Drop'),
                      loc='upper left')
        ps.savefig()

    return hotblobs, px, py, aper, peakval
Exemple #19
0
def compare_subsampled(S, s, ps, psf, pixpsf, Gmine,v,w, gal_sigma, psf_sigma,
                       cd,
                       get_ffts=False, eval_psf=integrate_gaussian):
    '''
    Compares a galaxy-PSF convolution computed by `galaxy_psf_convolution`
    against a brute-force FFT convolution done on an `s`-times subsampled
    grid, plotting both and their difference.

    Parameters
    ----------
    S : int
        Size of the grid in pixels.
    s : int
        Subsampling factor; the subsampled grid has spacing 1/s.
    ps : PlotSequence
        Receives the comparison figure.
    psf :
        PSF description passed to `eval_psf`.
    pixpsf : numpy array
        Pixelized PSF; printed via `measure` and used to trim the binned
        subsampled PSF to the same shape.
    Gmine, v, w :
        Not used: they are overwritten by the result of
        `galaxy_psf_convolution` with the binned subsampled PSF.  Kept
        for interface compatibility.
    gal_sigma : float
        Sigma of the circular Gaussian galaxy, in grid (pixel) units.
    psf_sigma : float
        Only used in the plot title.
    cd :
        Passed through to `galaxy_psf_convolution`.
    get_ffts : bool, optional
        If True, return `(Fpsf, Fgal)` -- the real FFTs of the subsampled
        PSF and the ifftshift-ed subsampled galaxy -- without convolving
        or plotting.
    eval_psf : callable, optional
        Called as `eval_psf(psf, xs, ys)` to evaluate the subsampled PSF.

    Returns
    -------
    rmsdiff : float
        RMS of (my convolution - binned FFT convolution); or
        `(Fpsf, Fgal)` when `get_ffts` is set.
    '''
    print()
    print('Subsample', s)
    print()

    step = 1./s
    # Keep s*(S-1)+1 samples: positions 0, 1/s, ..., S-1 inclusive.
    sz = s * (S-1) + 1
    x = np.arange(0, S, step)[:sz]

    # Create pixelized PSF, evaluated at the subpixel centers
    sx = x - 0.5 + step/2.
    subpixpsf = eval_psf(psf, sx, sx)
    binned = bin_image(subpixpsf, s)
    ph,pw = pixpsf.shape
    binned = binned[:ph,:pw]

    print('Binned PSF:')
    measure(binned)
    print('Pix PSF:')
    measure(pixpsf)

    # Recompute my convolution using the binned PSF (this replaces the
    # Gmine passed in).
    _, _, Gmine, _, _ = galaxy_psf_convolution(
        gal_sigma, 0., 0., GaussianGalaxy, cd,
        0., 0., binned, debug=True)

    # Create pixelized galaxy image: a circular Gaussian centered at S/2,
    # cropped to the shape of the subsampled PSF.
    xx,yy = np.meshgrid(x,x)
    center = S/2
    subpixgal = np.exp(-0.5 * ((xx-center)**2 + (yy-center)**2)/gal_sigma**2)
    sh,sw = subpixpsf.shape
    subpixgal = subpixgal[:sh,:sw]

    print('Subpix psf, gal', subpixpsf.shape, subpixgal.shape)

    print('Subpix PSF:')
    measure(subpixpsf)
    print('Subpix gal:')
    measure(subpixgal)

    # FFT convolution.  The galaxy is ifftshift-ed, presumably so that its
    # center sits at the array origin and the result is not translated.
    Fpsf = np.fft.rfft2(subpixpsf)
    spg = np.fft.ifftshift(subpixgal)
    Fgal = np.fft.rfft2(spg)

    if get_ffts:
        return Fpsf, Fgal

    Fconv = Fpsf * Fgal
    subpixfft = np.fft.irfft2(Fconv, s=subpixpsf.shape)
    print('Shapes:', 'subpixpsf', subpixpsf.shape, 'Fpsf', Fpsf.shape)
    print('spg', spg.shape, 'Fgal', Fgal.shape, 'Fconv', Fconv.shape,
          'subpixfft', subpixfft.shape)

    print('Subpix conv', subpixfft.shape)

    # Bin back down to the pixel grid and normalize to unit sum.
    binned = bin_image(subpixfft, s)
    binned /= np.sum(binned)

    print('Binned', binned.shape)
    print('Mine:', Gmine.shape)
    print('Mine:')
    measure(Gmine)
    print('Binned subpix FFT:')
    measure(binned)

    mh,mw = Gmine.shape
    binned = binned[:mh,:mw]

    plt.clf()

    plt.subplot(2,3,1)
    dimshow(subpixpsf)
    plt.title('subpix psf')
    plt.colorbar()

    plt.subplot(2,3,2)
    dimshow(subpixgal)
    plt.title('subpix galaxy')
    plt.colorbar()

    plt.subplot(2,3,3)
    dimshow(subpixfft)
    plt.title('subpix FFT conv')
    plt.colorbar()

    plt.subplot(2,3,4)
    dimshow(np.log10(np.maximum(binned / np.max(binned), 1e-12)))
    plt.title('log binned FFT conv')
    plt.colorbar()

    plt.subplot(2,3,5)
    dimshow(np.log10(np.maximum(Gmine / np.max(Gmine), 1e-12)))
    plt.title('log my conv')
    plt.colorbar()

    # binned is already no larger than Gmine; crop Gmine to match it.
    bh,bw = binned.shape
    Gmine = Gmine[:bh,:bw]
    diff  = Gmine - binned

    plt.subplot(2,3,6)
    dimshow(diff)
    plt.title('mine - FFT')
    plt.colorbar()

    plt.suptitle('PSF %g, Gal %g, subsample %i' % (psf_sigma, gal_sigma, s))

    ps.savefig()

    rmsdiff = np.sqrt(np.mean(diff**2))
    return rmsdiff
def tim_plots(tims, bands, ps):
    '''
    Saves diagnostic plots of the images `tims` to PlotSequence `ps`.

    For each band in `bands`: histograms of raw pixel values and of pixel
    values times inverse-error, for the tims in that band.  Then for each
    tim: a page of image / inverse-error / masked-image panels (plus the
    DQ map and its saturated pixels when `tim.dq` is set); and, when
    `tim.dq` is set, a page with one panel per DQ mask bit and -- for
    DECam images whose DQ plane holds integer codes -- a page of DQ codes.
    '''
    # Pixel histograms of subimages.
    for b in bands:
        btims = [tim for tim in tims if tim.band == b]
        sig1 = np.median([tim.sig1 for tim in btims])
        plt.clf()
        for tim in btims:
            # broaden range to encompass most pixels... only req'd
            # when sky is bad
            lo, hi = -5. * sig1, 5. * sig1
            pix = tim.getImage()[tim.getInvError() > 0]
            if len(pix):
                # Skipped for fully-masked tims: the percentile of an empty
                # array is undefined (NaN or an error, depending on numpy).
                lo = min(lo, np.percentile(pix, 5))
                hi = max(hi, np.percentile(pix, 95))
            plt.hist(pix,
                     range=(lo, hi),
                     bins=50,
                     histtype='step',
                     alpha=0.5,
                     label=tim.name)
        plt.legend()
        plt.xlabel('Pixel values')
        plt.title('Pixel distributions: %s band' % b)
        ps.savefig()

        plt.clf()
        lo, hi = -5., 5.
        for tim in btims:
            ie = tim.getInvError()
            pix = (tim.getImage() * ie)[ie > 0]
            plt.hist(pix,
                     range=(lo, hi),
                     bins=50,
                     histtype='step',
                     alpha=0.5,
                     label=tim.name)
        plt.legend()
        plt.xlabel('Pixel values (sigma)')
        plt.xlim(lo, hi)
        plt.title('Pixel distributions: %s band' % b)
        ps.savefig()

    # Plot image pixels, invvars, masks
    for tim in tims:
        img = tim.getImage()
        inverr = tim.getInvError()
        plt.clf()
        # 2x3 grid so that the DQ and SATUR panels each get their own axes
        # (drawing both into the same subplot hides the DQ map).
        plt.subplot(2, 3, 1)
        dimshow(img, vmin=-3. * tim.sig1, vmax=10. * tim.sig1)
        plt.title('image')
        plt.subplot(2, 3, 2)
        dimshow(inverr, vmin=0, vmax=1.1 / tim.sig1)
        plt.title('inverr')
        plt.subplot(2, 3, 3)
        dimshow(img * (inverr > 0),
                vmin=-3. * tim.sig1,
                vmax=10. * tim.sig1)
        plt.title('image (masked)')
        if tim.dq is not None:
            plt.subplot(2, 3, 4)
            dimshow(tim.dq, vmin=0, vmax=tim.dq.max())
            plt.title('DQ')
            plt.subplot(2, 3, 5)
            dimshow(((tim.dq & tim.dq_saturation_bits) > 0),
                    vmin=0,
                    vmax=1.5,
                    cmap='hot')
            plt.title('SATUR')
        plt.suptitle(tim.name)
        ps.savefig()

        if tim.dq is None:
            continue

        from legacypipe.bits import DQ_BITS
        plt.clf()
        bitmap = {v: k for k, v in DQ_BITS.items()}
        # One panel per named bit among the low 12 bits; size the grid to
        # the number of planes (3x3 when there are 9) instead of assuming 9.
        bitvals = [1 << i for i in range(12) if (1 << i) in bitmap]
        nplanes = max(1, len(bitvals))
        ncols = int(np.ceil(np.sqrt(nplanes)))
        nrows = -(-nplanes // ncols)
        for k, bitval in enumerate(bitvals):
            plt.subplot(nrows, ncols, k + 1)
            # dimshow (origin='lower') so the planes match the panels above
            dimshow((tim.dq & bitval) > 0, vmin=0, vmax=1.5, cmap='hot')
            plt.title(bitmap[bitval])
        plt.suptitle(
            'Mask planes: %s (%s %s)' %
            (tim.name, tim.imobj.image_filename, tim.imobj.ccdname))
        ps.savefig()

        im = tim.imobj
        from legacypipe.decam import decam_has_dq_codes
        print(tim.name, ': plver "%s"' % im.plver, 'has DQ codes:',
              decam_has_dq_codes(im.plver))
        if im.camera == 'decam' and decam_has_dq_codes(im.plver):
            # Integer codes, not bitmask.  Re-read and plot.
            dq = im.read_dq(slice=tim.slice)
            plt.clf()
            plt.subplot(1, 3, 1)
            dimshow(img,
                    vmin=-3. * tim.sig1,
                    vmax=30. * tim.sig1)
            plt.title('image')
            plt.subplot(1, 3, 2)
            dimshow(inverr, vmin=0, vmax=1.1 / tim.sig1)
            plt.title('inverr')
            plt.subplot(1, 3, 3)
            plt.imshow(dq,
                       interpolation='nearest',
                       origin='lower',
                       cmap='tab10',
                       vmin=-0.5,
                       vmax=9.5)
            plt.colorbar()
            plt.title('DQ codes')
            plt.suptitle(
                '%s (%s %s) PLVER %s' %
                (tim.name, im.image_filename, im.ccdname, im.plver))
            ps.savefig()
def fitblobs_plots(tims, bands, targetwcs, blobslices, blobsrcs, cat, blobs,
                   ps):
    '''
    Saves four blob diagnostic pages to PlotSequence `ps`:

    1. blob bounding boxes over an RGB coadd of `tims`;
    2. the same figure with every source labelled "b<blob>/s<source>";
    3. the boxes over the `blobs` label map;
    4. the boxes over the `blobs != -1` mask.
    '''
    def draw_blob_boxes():
        # Red outline of each blob's bounding box, tagged with its index.
        for iblob, (yslc, xslc) in enumerate(blobslices):
            x0, x1 = xslc.start, xslc.stop
            y0, y1 = yslc.start, yslc.stop
            plt.plot([x0, x0, x1, x1, x0], [y0, y1, y1, y0, y0], 'r-')
            plt.text((x0 + x1) / 2., y0, '%i' % iblob,
                     ha='center', va='bottom', color='r')

    coimgs, _ = quick_coadds(tims, bands, targetwcs)
    plt.clf()
    dimshow(get_rgb(coimgs, bands))
    ax = plt.axis()
    draw_blob_boxes()
    plt.axis(ax)
    plt.title('Blobs')
    ps.savefig()

    # No clf: the source labels are added on top of the previous page.
    for iblob, isrcs in enumerate(blobsrcs):
        for isrc in isrcs:
            pos = cat[isrc].getPosition()
            _, srcx, srcy = targetwcs.radec2pixelxy(pos.ra, pos.dec)
            plt.text(srcx, srcy, 'b%i/s%i' % (iblob, isrc),
                     ha='center', va='bottom', color='r')
    plt.axis(ax)
    plt.title('Blobs + Sources')
    ps.savefig()

    for backdrop in (blobs, blobs != -1):
        plt.clf()
        dimshow(backdrop)
        ax = plt.axis()
        draw_blob_boxes()
        plt.axis(ax)
        plt.title('Blobs')
        ps.savefig()
Exemple #22
0
def stage1(T=None, coimgs=None, cons=None, detmaps=None, detivs=None,
           targetrd=None, pixscale=None, targetwcs=None, W=None,H=None,
           bands=None, tims=None, ps=None, brick=None, cat=None):
    orig_wcsxy0 = [tim.wcs.getX0Y0() for tim in tims]
    hot = np.zeros((H,W), np.float32)

    for band in bands:
        detmap = detmaps[band] / np.maximum(1e-16, detivs[band])
        detsn = detmap * np.sqrt(detivs[band])
        hot = np.maximum(hot, detsn)
        detmaps[band] = detmap

    ### FIXME -- ugri
    for sedname,sed in [('Flat', (1.,1.,1.)), ('Red', (2.5, 1.0, 0.4))]:
        sedmap = np.zeros((H,W), np.float32)
        sediv  = np.zeros((H,W), np.float32)
        for iband,band in enumerate(bands):
            # We convert the detmap to canonical band via
            #   detmap * w
            # And the corresponding change to sig1 is
            #   sig1 * w
            # So the invvar-weighted sum is
            #    (detmap * w) / (sig1**2 * w**2)
            #  = detmap / (sig1**2 * w)
            sedmap += detmaps[band] * detivs[band] / sed[iband]
            sediv  += detivs [band] / sed[iband]**2
        sedmap /= np.maximum(1e-16, sediv)
        sedsn   = sedmap * np.sqrt(sediv)
        hot = np.maximum(hot, sedsn)

        plt.clf()
        dimshow(np.round(sedsn), vmin=0, vmax=10, cmap='hot')
        plt.title('SED-matched detection filter: %s' % sedname)
        ps.savefig()

    peaks = (hot > 4)
    blobs,nblobs = label(peaks)
    print 'N detected blobs:', nblobs
    blobslices = find_objects(blobs)
    # Un-set catalog blobs
    for x,y in zip(T.itx, T.ity):
        # blob number
        bb = blobs[y,x]
        if bb == 0:
            continue
        # un-set 'peaks' within this blob
        slc = blobslices[bb-1]
        peaks[slc][blobs[slc] == bb] = 0

    # Now, after having removed catalog sources, crank up the detection threshold
    peaks &= (hot > 5)
        
    # zero out the edges(?)
    peaks[0 ,:] = peaks[:, 0] = 0
    peaks[-1,:] = peaks[:,-1] = 0
    peaks[1:-1, 1:-1] &= (hot[1:-1,1:-1] >= hot[0:-2,1:-1])
    peaks[1:-1, 1:-1] &= (hot[1:-1,1:-1] >= hot[2:  ,1:-1])
    peaks[1:-1, 1:-1] &= (hot[1:-1,1:-1] >= hot[1:-1,0:-2])
    peaks[1:-1, 1:-1] &= (hot[1:-1,1:-1] >= hot[1:-1,2:  ])

    # These are our peaks
    pki = np.flatnonzero(peaks)
    peaky,peakx = np.unravel_index(pki, peaks.shape)
    print len(peaky), 'peaks'

    crossa = dict(ms=10, mew=1.5)
    plt.clf()
    dimshow(get_rgb(coimgs, bands))
    ax = plt.axis()
    plt.plot(T.tx, T.ty, 'r+', **crossa)
    plt.plot(peakx, peaky, '+', color=green, **crossa)
    plt.axis(ax)
    plt.title('SDSS + SED-matched detections')
    ps.savefig()


    ### HACK -- high threshold again

    # Segment, and record which sources fall into each blob
    blobs,nblobs = label((hot > 20))
    print 'N detected blobs:', nblobs
    blobslices = find_objects(blobs)
    T.blob = blobs[T.ity, T.itx]
    blobsrcs = []
    blobflux = []
    fluximg = coimgs[1]
    for blob in range(1, nblobs+1):
        blobsrcs.append(np.flatnonzero(T.blob == blob))
        bslc = blobslices[blob-1]
        blobflux.append(np.sum(fluximg[bslc][blobs[bslc] == blob]))

    # Fit the SDSS sources
    
    for tim in tims:
        tim.psfex.fitSavedData(*tim.psfex.splinedata)
        tim.psf = tim.psfex
        
    # How far down to render model profiles
    minsigma = 0.1
    for tim in tims:
        tim.modelMinval = minsigma * tim.sig1
    srcvariances = [[] for src in cat]
    # Fit in order of flux
    for blobnumber,iblob in enumerate(np.argsort(-np.array(blobflux))):

        bslc  = blobslices[iblob]
        Isrcs = blobsrcs  [iblob]
        if len(Isrcs) == 0:
            continue

        print
        print 'Blob', blobnumber, 'of', len(blobflux), ':', len(Isrcs), 'sources'
        print 'Source indices:', Isrcs
        print

        # blob bbox in target coords
        sy,sx = bslc
        by0,by1 = sy.start, sy.stop
        bx0,bx1 = sx.start, sx.stop
        blobh,blobw = by1 - by0, bx1 - bx0

        rr,dd = targetwcs.pixelxy2radec([bx0,bx0,bx1,bx1],[by0,by1,by1,by0])
        alphas = [0.1, 0.3, 1.0]
        subtims = []
        for itim,tim in enumerate(tims):
            h,w = tim.shape
            ok,x,y = tim.subwcs.radec2pixelxy(rr,dd)
            sx0,sx1 = x.min(), x.max()
            sy0,sy1 = y.min(), y.max()
            if sx1 < 0 or sy1 < 0 or sx1 > w or sy1 > h:
                continue
            sx0 = np.clip(int(np.floor(sx0)), 0, w-1)
            sx1 = np.clip(int(np.ceil (sx1)), 0, w-1) + 1
            sy0 = np.clip(int(np.floor(sy0)), 0, h-1)
            sy1 = np.clip(int(np.ceil (sy1)), 0, h-1) + 1
            subslc = slice(sy0,sy1),slice(sx0,sx1)
            subimg = tim.getImage ()[subslc]
            subie  = tim.getInvError()[subslc]
            subwcs = tim.getWcs().copy()
            ox0,oy0 = orig_wcsxy0[itim]
            subwcs.setX0Y0(ox0 + sx0, oy0 + sy0)

            # Mask out inverr for pixels that are not within the blob.
            subtarget = targetwcs.get_subimage(bx0, by0, blobw, blobh)
            subsubwcs = tim.subwcs.get_subimage(int(sx0), int(sy0), int(sx1-sx0), int(sy1-sy0))
            try:
                Yo,Xo,Yi,Xi,rims = resample_with_wcs(subsubwcs, subtarget, [], 2)
            except OverlapError:
                print 'No overlap'
                continue
            if len(Yo) == 0:
                continue
            subie2 = np.zeros_like(subie)
            I = np.flatnonzero(blobs[bslc][Yi, Xi] == (iblob+1))
            subie2[Yo[I],Xo[I]] = subie[Yo[I],Xo[I]]
            subie = subie2
            # If the subimage (blob) is small enough, instantiate a
            # constant PSF model in the center.
            if sy1-sy0 < 100 and sx1-sx0 < 100:
                subpsf = tim.psf.mogAt(ox0 + (sx0+sx1)/2., oy0 + (sy0+sy1)/2.)
            else:
                # Otherwise, instantiate a (shifted) spatially-varying
                # PsfEx model.
                subpsf = ShiftedPsf(tim.psf, ox0+sx0, oy0+sy0)

            subtim = Image(data=subimg, inverr=subie, wcs=subwcs,
                           psf=subpsf, photocal=tim.getPhotoCal(),
                           sky=tim.getSky(), name=tim.name)
            subtim.band = tim.band
            subtim.sig1 = tim.sig1
            subtim.modelMinval = tim.modelMinval
            subtims.append(subtim)

        subcat = Catalog(*[cat[i] for i in Isrcs])
        subtr = Tractor(subtims, subcat)
        subtr.freezeParam('images')
        # Optimize individual sources in order of flux
        fluxes = []
        for src in subcat:
            # HACK -- here we just *sum* the nanomaggies in each band.  Bogus!
            br = src.getBrightness()
            flux = sum([br.getFlux(band) for band in bands])
            fluxes.append(flux)
        Ibright = np.argsort(-np.array(fluxes))

        if len(Ibright) >= 5:
            # -Remember the original subtim images
            # -Compute initial models for each source (in each tim)
            # -Subtract initial models from images
            # -During fitting, for each source:
            #   -add back in the source's initial model (to each tim)
            #   -fit, with Catalog([src])
            #   -subtract final model (from each tim)
            # -Replace original subtim images
            #
            # --Might want to omit newly-added detection-filter sources, since their
            # fluxes are bogus.

            # Remember original tim images
            orig_timages = [tim.getImage().copy() for tim in subtims]
            initial_models = []
            # Create initial models for each tim x each source
            for tim in subtims:
                mods = []
                for src in subcat:
                    mod = src.getModelPatch(tim)
                    mods.append(mod)
                    if mod is not None:
                        if not np.all(np.isfinite(mod.patch)):
                            print 'Non-finite mod patch'
                            print 'source:', src
                            print 'tim:', tim
                            print 'PSF:', tim.getPsf()
                        assert(np.all(np.isfinite(mod.patch)))
                        mod.addTo(tim.getImage(), scale=-1)
                initial_models.append(mods)
            # For sources in decreasing order of brightness
            for numi,i in enumerate(Ibright):
                tsrc = Time()
                print 'Fitting source', i, '(%i of %i in blob)' % (numi, len(Ibright))
                src = subcat[i]
                print src

                srctractor = Tractor(subtims, [src])
                srctractor.freezeParams('images')
                
                # Add this source's initial model back in.
                for tim,mods in zip(subtims, initial_models):
                    mod = mods[i]
                    if mod is not None:
                        mod.addTo(tim.getImage())

                print 'Optimizing:', srctractor
                srctractor.printThawedParams()
                for step in range(50):
                    dlnp,X,alpha = srctractor.optimize(priors=False, shared_params=False,
                                                  alphas=alphas)
                    print 'dlnp:', dlnp, 'src', src
                    if dlnp < 0.1:
                        break

                for tim in subtims:
                    mod = src.getModelPatch(tim)
                    if mod is not None:
                        mod.addTo(tim.getImage(), scale=-1)
    
            for tim,img in zip(subtims, orig_timages):
                tim.data = img

            del orig_timages
            del initial_models
        else:
            # Fit sources one at a time, but don't subtract other models
            subcat.freezeAllParams()
            for numi,i in enumerate(Ibright):
                tsrc = Time()
                print 'Fitting source', i, '(%i of %i in blob)' % (numi, len(Ibright))
                print subcat[i]
                subcat.freezeAllBut(i)
                print 'Optimizing:', subtr
                subtr.printThawedParams()
                for step in range(10):
                    dlnp,X,alpha = subtr.optimize(priors=False, shared_params=False,
                                                  alphas=alphas)
                    print 'dlnp:', dlnp
                    if dlnp < 0.1:
                        break
                print 'Fitting source took', Time()-tsrc
                print subcat[i]
        if len(Isrcs) > 1 and len(Isrcs) <= 10:
            tfit = Time()
            # Optimize all at once?
            subcat.thawAllParams()
            print 'Optimizing:', subtr
            subtr.printThawedParams()
            for step in range(20):
                dlnp,X,alpha = subtr.optimize(priors=False, shared_params=False,
                                              alphas=alphas)
                print 'dlnp:', dlnp
                if dlnp < 0.1:
                    break

        # Variances
        subcat.thawAllRecursive()
        subcat.freezeAllParams()
        for isub,srci in enumerate(Isrcs):
            print 'Variances for source', srci
            subcat.thawParam(isub)

            src = subcat[isub]
            print 'Source', src
            print 'Params:', src.getParamNames()
            
            if isinstance(src, (DevGalaxy, ExpGalaxy)):
                src.shape = EllipseE.fromEllipseESoft(src.shape)
            elif isinstance(src, FixedCompositeGalaxy):
                src.shapeExp = EllipseE.fromEllipseESoft(src.shapeExp)
                src.shapeDev = EllipseE.fromEllipseESoft(src.shapeDev)

            print 'Converted ellipse:', src

            allderivs = subtr.getDerivs()
            for iparam,derivs in enumerate(allderivs):
                dchisq = 0
                for deriv,tim in derivs:
                    h,w = tim.shape
                    deriv.clipTo(w,h)
                    ie = tim.getInvError()
                    slc = deriv.getSlice(ie)
                    chi = deriv.patch * ie[slc]
                    dchisq += (chi**2).sum()
                if dchisq == 0.:
                    v = np.nan
                else:
                    v = 1./dchisq
                srcvariances[srci].append(v)
            assert(len(srcvariances[srci]) == subcat[isub].numberOfParams())
            subcat.freezeParam(isub)

    cat.thawAllRecursive()

    for i,src in enumerate(cat):
        print 'Source', i, src
        print 'variances:', srcvariances[i]
        print len(srcvariances[i]), 'vs', src.numberOfParams()
        if len(srcvariances[i]) != src.numberOfParams():
            # This can happen for sources outside the brick bounds: they never get optimized?
            print 'Warning: zeroing variances for source', src
            srcvariances[i] = [0]*src.numberOfParams()
            if isinstance(src, (DevGalaxy, ExpGalaxy)):
                src.shape = EllipseE.fromEllipseESoft(src.shape)
            elif isinstance(src, FixedCompositeGalaxy):
                src.shapeExp = EllipseE.fromEllipseESoft(src.shapeExp)
                src.shapeDev = EllipseE.fromEllipseESoft(src.shapeDev)
        assert(len(srcvariances[i]) == src.numberOfParams())

    variances = np.hstack(srcvariances)
    assert(len(variances) == cat.numberOfParams())

    return dict(cat=cat, variances=variances)
Exemple #23
0
def _segment_plot_blobs(mask, blobslices, T, ps):
    '''
    Debug plot for segment_and_group_sources().

    Shows the boolean image *mask*, outlines each bounding box in
    *blobslices* labelled with its 1-based position in that list, marks
    every source at (T.ibx, T.iby) with its row index, and saves the figure
    with ps.savefig().
    '''
    import pylab as plt
    from astrometry.util.plotutils import dimshow

    plt.clf()
    dimshow(mask, vmin=0, vmax=1)
    ax = plt.axis()
    for i, (sy, sx) in enumerate(blobslices):
        by0, by1 = sy.start, sy.stop
        bx0, bx1 = sx.start, sx.stop
        plt.plot([bx0, bx0, bx1, bx1, bx0], [by0, by1, by1, by0, by0],
                 'r-')
        plt.text((bx0 + bx1) / 2., by0, '%i' % (i + 1),
                 ha='center', va='bottom', color='r')
    plt.plot(T.ibx, T.iby, 'rx')
    for i, t in enumerate(T):
        plt.text(t.ibx, t.iby, 'src %i' % i,
                 color='red', ha='left', va='center')
    plt.axis(ax)
    plt.title('Blobs')
    ps.savefig()


def segment_and_group_sources(image, T, name=None, ps=None, plots=False):
    '''
    Segment a binary image into connected "blobs" and group sources by blob.

    *image*: binary image that defines "blobs"; holes are filled before the
    connected components are labelled.
    *T*: source table; only ".ibx" and ".iby" elements are used (x,y integer
    pix pos).  Positions outside the image are clipped to the nearest edge
    pixel for the lookup.  Note: ".blob" field is added (index of the blob
    containing each source, or -1 if the source lies in no blob).
    *name*: for debugging only (currently unused)
    *ps*: plot saver (needs .savefig()); only used when *plots* is true.
    *plots*: if true, save debug plots of the blobs before and after the
    empty blobs are dropped.

    Returns: (blobs, blobsrcs, blobslices)

    *blobs*: image, values -1 = no blob, integer blob indices (indices into
    *blobsrcs* and *blobslices*)
    *blobsrcs*: list of np arrays of integers, elements in T within each blob
    *blobslices*: list of slice objects for blob bounding-boxes.

    Blobs that contain no source are dropped (their pixels become -1).
    Sources that lie in no blob are NOT given singleton blobs of their own:
    they get T.blob == -1 and appear in no element of *blobsrcs*, so
    *image* must cover every source that should be grouped.
    '''
    # The scipy.ndimage.morphology / scipy.ndimage.measurements submodules
    # are deprecated; the same functions are exported by scipy.ndimage.
    from scipy.ndimage import binary_fill_holes, label, find_objects

    image = binary_fill_holes(image)
    blobs, nblobs = label(image)
    H, W = image.shape
    del image

    # One (y-slice, x-slice) bounding box per label 1..nblobs.
    blobslices = find_objects(blobs)
    clipx = np.clip(T.ibx, 0, W - 1)
    clipy = np.clip(T.iby, 0, H - 1)
    T.blob = blobs[clipy, clipx]

    if plots:
        _segment_plot_blobs(blobs > 0, blobslices, T, ps)

    # Find sets of sources within blobs.  *blobmap* is a lookup table from
    # label value (0..nblobs) to output blob index; the background (label 0)
    # and blobs containing no sources map to -1.
    blobmap = np.full(nblobs + 1, -1, int)
    blobsrcs = []
    keepslices = []
    for blob in range(1, nblobs + 1):
        Isrcs, = np.nonzero(T.blob == blob)
        if len(Isrcs) == 0:
            continue
        blobmap[blob] = len(blobsrcs)
        blobsrcs.append(Isrcs)
        keepslices.append(blobslices[blob - 1])
    blobslices = keepslices

    # Remap the "blobs" image so that empty regions are = -1 and the blob
    # values correspond to their indices in the "blobsrcs" list.
    blobs = blobmap[blobs]

    if plots:
        _segment_plot_blobs(blobs > -1, blobslices, T, ps)

    # Sanity check: every grouped source must land on its own blob's pixels.
    for j, Isrcs in enumerate(blobsrcs):
        for i in Isrcs:
            if blobs[clipy[i], clipx[i]] != j:
                print(
                    '---------------------------!!!-------------------------')
                print('Blob', j, 'sources', Isrcs)
                print('Source', i, 'coords x,y', T.ibx[i], T.iby[i])
                print('Expected blob value', j, 'but got',
                      blobs[clipy[i], clipx[i]])

    T.blob = blobs[clipy, clipx]
    assert (len(blobsrcs) == len(blobslices))
    return blobs, blobsrcs, blobslices
def _psf_check_plots(tims):
    '''
    Debugging hack: compare PsfEx PSF models with Gaussian-mixture PSF
    models and save comparison plots via ps.savefig().

    For each (im, tim) pair this:
      1. builds a PsfEx model from ``im.psffn`` for a hard-coded
         W x H = 2048 x 4096 frame, instantiates it at the frame centre,
         fits constant 2- and 3-component GaussianMixtureEllipsePSF models
         to the trimmed stamp, and plots stamp, both models and residuals;
      2. makes five plots over a 7 x 4 grid of positions spanning the frame:
         the PsfEx stamp, ``tim.psfex`` rendered as a point-source patch
         ("varying MoG"), and the PsfEx residuals against the varying,
         constant 3-component and constant 2-component models.

    *tims*: images; each is used for ``.name`` and ``.psfex``.

    NOTE(review): ``ims`` and ``ps`` are neither parameters nor assigned in
    this function, so they must exist at module scope or this raises
    NameError.  ``plt``, ``dimshow``, ``np``, ``PsfEx``,
    ``GaussianMixtureEllipsePSF`` and ``Time`` are likewise taken from
    module scope.
    '''
    # HACK -- check PSF models
    # All plots are drawn into figure number 2.
    plt.figure(num=2, figsize=(7,4.08))
    for im,tim in zip(ims,tims):
        print
        print 'Image', tim.name

        plt.subplots_adjust(left=0, right=1, bottom=0, top=0.95,
                            hspace=0, wspace=0)
        # Hard-coded frame size; it is not read from the image.
        # NOTE(review): presumably the CCD size of these images -- confirm.
        W,H = 2048,4096
        psfex = PsfEx(im.psffn, W, H)

        # PSF stamp at the frame centre; psfim0 keeps the untrimmed copy.
        psfim0 = psfim = psfex.instantiateAt(W/2, H/2)
        # trim
        psfim = psfim[10:-10, 10:-10]

        # This 2-component fit on the 10-pixel-trimmed stamp is only timed:
        # psffit2 is overwritten by the fit below before it is used.
        tfit = Time()
        psffit2 = GaussianMixtureEllipsePSF.fromStamp(psfim, N=2)
        print 'Fitting PSF mog:', psfim.shape, Time()-tfit

        # Re-trim the untrimmed stamp by only 5 pixels; the constant
        # models used in the rest of this iteration are fit to this stamp.
        psfim = psfim0[5:-5, 5:-5]
        tfit = Time()
        psffit2 = GaussianMixtureEllipsePSF.fromStamp(psfim, N=2)
        print 'Fitting PSF mog:', psfim.shape, Time()-tfit

        ph,pw = psfim.shape
        psffit = GaussianMixtureEllipsePSF.fromStamp(psfim, N=3)

        #mx = 0.03
        # Display scale shared by every plot for this image.
        mx = psfim.max()

        # Render the constant 3- (mod3) and 2-component (mod2) fits into
        # stamp-sized images, centred at (pw/2, ph/2).
        mod3 = np.zeros_like(psfim)
        p = psffit.getPointSourcePatch(pw/2, ph/2, radius=pw/2)
        p.addTo(mod3)
        mod2 = np.zeros_like(psfim)
        p = psffit2.getPointSourcePatch(pw/2, ph/2, radius=pw/2)
        p.addTo(mod2)

        # Top row: stamp, 3-comp model, 2-comp model;
        # bottom row (panels 5, 6): stamp minus each model.
        plt.clf()
        plt.subplot(2,3,1)
        dimshow(psfim, vmin=0, vmax=mx, ticks=False)
        plt.subplot(2,3,2)
        dimshow(mod3, vmin=0, vmax=mx, ticks=False)
        plt.subplot(2,3,3)
        dimshow(mod2, vmin=0, vmax=mx, ticks=False)
        plt.subplot(2,3,5)
        dimshow(psfim-mod3, vmin=-mx/2, vmax=mx/2, ticks=False)
        plt.subplot(2,3,6)
        dimshow(psfim-mod2, vmin=-mx/2, vmax=mx/2, ticks=False)
        ps.savefig()
        #continue

        # Five passes over the same grid of positions; *round* selects what
        # is drawn (see the suptitle in each branch).
        # NOTE: *round* shadows the builtin of the same name here.
        for round in [1,2,3,4,5]:
            plt.clf()
            k = 1
            #rows,cols = 10,5
            rows,cols = 7,4
            for iy,y in enumerate(np.linspace(0, H, rows).astype(int)):
                for ix,x in enumerate(np.linspace(0, W, cols).astype(int)):
                    psfimg = psfex.instantiateAt(x, y)
                    # trim
                    # Same 5-pixel trim as the centre stamp that mod3/mod2
                    # were fit to (assumes PsfEx stamps have the same size
                    # at every position -- TODO confirm).
                    psfimg = psfimg[5:-5, 5:-5]
                    print 'psfimg', psfimg.shape
                    ph,pw = psfimg.shape
                    # Spatially-varying model: tim.psfex rendered at (x, y).
                    psfimg2 = tim.psfex.getPointSourcePatch(x, y, radius=pw/2)
                    mod = np.zeros_like(psfimg)
                    h,w = mod.shape
                    #psfimg2.x0 -= x
                    #psfimg2.x0 += w/2
                    #psfimg2.y0 -= y
                    #psfimg2.y0 += h/2
                    # Zero the patch offset so it is added at the origin of
                    # *mod* (presumably so it overlays the trimmed stamp).
                    psfimg2.x0 = 0
                    psfimg2.y0 = 0
                    print 'psfimg2:', (psfimg2.x0,psfimg2.y0)
                    psfimg2.addTo(mod)
                    print 'psfimg:', psfimg.min(), psfimg.max(), psfimg.sum()
                    print 'psfimg2:', psfimg2.patch.min(), psfimg2.patch.max(), psfimg2.patch.sum()
                    print 'mod:', mod.min(), mod.max(), mod.sum()

                    #plt.subplot(rows, cols, k)
                    # Subplot grid is cols x rows (4 x 7), i.e. transposed
                    # relative to the (y, x) loop order.
                    plt.subplot(cols, rows, k)
                    k += 1
                    kwa = dict(vmin=0, vmax=mx, ticks=False)
                    if round == 1:
                        dimshow(psfimg, **kwa)
                        plt.suptitle('PsfEx')
                    elif round == 2:
                        dimshow(mod, **kwa)
                        plt.suptitle('varying MoG')
                    elif round == 3:
                        dimshow(psfimg - mod, vmin=-mx/2, vmax=mx/2, ticks=False)
                        plt.suptitle('PsfEx - varying MoG')
                    elif round == 4:
                        # mod3/mod2 are the constant fits from the centre.
                        dimshow(psfimg - mod3, vmin=-mx/2, vmax=mx/2, ticks=False)
                        plt.suptitle('PsfEx - const MoG(3)')
                    elif round == 5:
                        dimshow(psfimg - mod2, vmin=-mx/2, vmax=mx/2, ticks=False)
                        plt.suptitle('PsfEx - const MoG(2)')
            ps.savefig()
Exemple #25
0
    def run(self, ps=None, focus=False, momentsize=5, n_fwhm=100):
        import pylab as plt
        from astrometry.util.plotutils import dimshow, plothist
        from legacyanalysis.ps1cat import ps1cat
        import photutils
        import tractor

        fn = self.fn
        ext = self.ext
        pixsc = self.pixscale

        F = fitsio.FITS(fn)
        primhdr = F[0].read_header()
        self.primhdr = primhdr
        img, hdr = self.read_raw(F, ext)
        self.hdr = hdr

        # pre sky-sub
        mn, mx = np.percentile(img.ravel(), [25, 98])
        self.imgkwa = dict(vmin=mn, vmax=mx, cmap='gray')

        if self.debug and ps is not None:
            plt.clf()
            dimshow(img, **self.imgkwa)
            plt.title('Raw image')
            ps.savefig()

            M = 200
            plt.clf()
            plt.subplot(2, 2, 1)
            dimshow(img[-M:, :M], ticks=False, **self.imgkwa)
            plt.subplot(2, 2, 2)
            dimshow(img[-M:, -M:], ticks=False, **self.imgkwa)
            plt.subplot(2, 2, 3)
            dimshow(img[:M, :M], ticks=False, **self.imgkwa)
            plt.subplot(2, 2, 4)
            dimshow(img[:M, -M:], ticks=False, **self.imgkwa)
            plt.suptitle('Raw image corners')
            ps.savefig()

        img, trim_x0, trim_y0 = self.trim_edges(img)

        fullH, fullW = img.shape

        if self.debug and ps is not None:
            plt.clf()
            dimshow(img, **self.imgkwa)
            plt.title('Trimmed image')
            ps.savefig()

            M = 200
            plt.clf()
            plt.subplot(2, 2, 1)
            dimshow(img[-M:, :M], ticks=False, **self.imgkwa)
            plt.subplot(2, 2, 2)
            dimshow(img[-M:, -M:], ticks=False, **self.imgkwa)
            plt.subplot(2, 2, 3)
            dimshow(img[:M, :M], ticks=False, **self.imgkwa)
            plt.subplot(2, 2, 4)
            dimshow(img[:M, -M:], ticks=False, **self.imgkwa)
            plt.suptitle('Trimmed corners')
            ps.savefig()

        band = self.get_band(primhdr)
        exptime = primhdr['EXPTIME']
        airmass = primhdr['AIRMASS']
        print('Band', band, 'Exptime', exptime, 'Airmass', airmass)

        zp0 = self.nom.zeropoint(band, ext=self.ext)
        sky0 = self.nom.sky(band)
        kx = self.nom.fiducial_exptime(band).k_co

        # Find the sky value and noise level
        sky, sig1 = self.get_sky_and_sigma(img)

        sky1 = np.median(sky)
        skybr = -2.5 * np.log10(sky1 / pixsc / pixsc / exptime) + zp0
        print('Sky brightness: %8.3f mag/arcsec^2' % skybr)
        print('Fiducial:       %8.3f mag/arcsec^2' % sky0)

        img -= sky

        self.remove_sky_gradients(img)

        # Post sky-sub
        mn, mx = np.percentile(img.ravel(), [25, 98])
        self.imgkwa = dict(vmin=mn, vmax=mx, cmap='gray')

        if ps is not None:
            plt.clf()
            dimshow(img, **self.imgkwa)
            plt.title('Sky-sub image: %s-%s' % (os.path.basename(fn).replace(
                '.fits', '').replace('.fz', ''), ext))
            plt.colorbar()
            ps.savefig()

        # Read WCS header and compute boresight
        wcs = self.get_wcs(hdr)
        ra_ccd, dec_ccd = wcs.pixelxy2radec((fullW + 1) / 2., (fullH + 1) / 2.)

        # Detect stars
        psfsig = self.nominal_fwhm / 2.35
        detsn = self.detection_map(img, sig1, psfsig, ps)

        slices = self.detect_sources(detsn, self.det_thresh, ps)
        print(len(slices), 'sources detected')
        if len(slices) < 20:
            slices = self.detect_sources(detsn, 10., ps)
            print(len(slices), 'sources detected')
        ndetected = len(slices)

        camera = primhdr.get('INSTRUME', '').strip().lower()
        # -> "decam" / "mosaic3"
        meas = dict(band=band,
                    airmass=airmass,
                    skybright=skybr,
                    pixscale=pixsc,
                    primhdr=primhdr,
                    hdr=hdr,
                    wcs=wcs,
                    ra_ccd=ra_ccd,
                    dec_ccd=dec_ccd,
                    extension=ext,
                    camera=camera,
                    ndetected=ndetected)

        if ndetected == 0:
            print('NO SOURCES DETECTED')
            return meas

        xx, yy = [], []
        fx, fy = [], []
        mx2, my2, mxy = [], [], []
        wmx2, wmy2, wmxy = [], [], []
        # "Peak" region to centroid
        P = momentsize
        H, W = img.shape

        for i, slc in enumerate(slices):
            y0 = slc[0].start
            x0 = slc[1].start
            subimg = detsn[slc]
            imax = np.argmax(subimg)
            y, x = np.unravel_index(imax, subimg.shape)
            if (x0 + x) < P or (x0 + x) > W - 1 - P or (y0 + y) < P or (
                    y0 + y) > H - 1 - P:
                #print('Skipping edge peak', x0+x, y0+y)
                continue
            xx.append(x0 + x)
            yy.append(y0 + y)
            pkarea = detsn[y0 + y - P:y0 + y + P + 1,
                           x0 + x - P:x0 + x + P + 1]

            from scipy.ndimage.measurements import center_of_mass
            cy, cx = center_of_mass(pkarea)
            #print('Center of mass', cx,cy)
            fx.append(x0 + x - P + cx)
            fy.append(y0 + y - P + cy)
            #print('x,y', x0+x, y0+y, 'vs centroid', x0+x-P+cx, y0+y-P+cy)

            ### HACK -- measure source ellipticity
            # go back to the image (not detection map)
            #subimg = img[slc]
            subimg = img[y0 + y - P:y0 + y + P + 1,
                         x0 + x - P:x0 + x + P + 1].copy()
            subimg /= subimg.sum()
            ph, pw = subimg.shape
            px, py = np.meshgrid(np.arange(pw), np.arange(ph))
            mx2.append(np.sum(subimg * (px - cx)**2))
            my2.append(np.sum(subimg * (py - cy)**2))
            mxy.append(np.sum(subimg * (px - cx) * (py - cy)))
            # Gaussian windowed version
            s = 1.
            wimg = subimg * np.exp(-0.5 * ((px - cx)**2 + (py - cy)**2) / s**2)
            wimg /= np.sum(wimg)
            wmx2.append(np.sum(wimg * (px - cx)**2))
            wmy2.append(np.sum(wimg * (py - cy)**2))
            wmxy.append(np.sum(wimg * (px - cx) * (py - cy)))

        mx2 = np.array(mx2)
        my2 = np.array(my2)
        mxy = np.array(mxy)
        wmx2 = np.array(wmx2)
        wmy2 = np.array(wmy2)
        wmxy = np.array(wmxy)

        # semi-major/minor axes and position angle
        theta = np.rad2deg(np.arctan2(2 * mxy, mx2 - my2) / 2.)
        theta = np.abs(theta) * np.sign(mxy)
        s = np.sqrt(((mx2 - my2) / 2.)**2 + mxy**2)
        a = np.sqrt((mx2 + my2) / 2. + s)
        b = np.sqrt((mx2 + my2) / 2. - s)
        ell = 1. - b / a

        wtheta = np.rad2deg(np.arctan2(2 * wmxy, wmx2 - wmy2) / 2.)
        wtheta = np.abs(wtheta) * np.sign(wmxy)
        ws = np.sqrt(((wmx2 - wmy2) / 2.)**2 + wmxy**2)
        wa = np.sqrt((wmx2 + wmy2) / 2. + ws)
        wb = np.sqrt((wmx2 + wmy2) / 2. - ws)
        well = 1. - wb / wa

        fx = np.array(fx)
        fy = np.array(fy)
        xx = np.array(xx)
        yy = np.array(yy)

        if ps is not None:

            plt.clf()
            dimshow(detsn, vmin=-3, vmax=50, cmap='gray')
            ax = plt.axis()
            plt.plot(fx, fy, 'go', mec='g', mfc='none', ms=10)
            plt.colorbar()
            plt.title('Detected sources')
            plt.axis(ax)
            ps.savefig()

            # show centroids too
            # plt.plot(xx, yy, 'go', mec='g', mfc='none', ms=8)
            # plt.axis(ax)
            # ps.savefig()

        # if ps is not None:
        #     plt.clf()
        #     plt.subplot(2,1,1)
        #     mx = np.percentile(np.append(mx2,my2), 99)
        #     ha = dict(histtype='step', range=(0,mx), bins=50)
        #     plt.hist(mx2, color='b', label='mx2', **ha)
        #     plt.hist(my2, color='r', label='my2', **ha)
        #     plt.hist(mxy, color='g', label='mxy', **ha)
        #     plt.legend()
        #     plt.xlim(0,mx)
        #     plt.subplot(2,1,2)
        #     mx = np.percentile(np.append(wmx2,wmy2), 99)
        #     ha = dict(histtype='step', range=(0,mx), bins=50, lw=3, alpha=0.3)
        #     plt.hist(wmx2, color='b', label='wx2', **ha)
        #     plt.hist(wmy2, color='r', label='wy2', **ha)
        #     plt.hist(wmxy, color='g', label='wxy', **ha)
        #     plt.legend()
        #     plt.xlim(0,mx)
        #     plt.suptitle('Source moments')
        #     ps.savefig()
        #
        #     #mx = np.percentile(np.abs(np.append(mxy,wmxy)), 99)
        #     plt.clf()
        #     plt.subplot(2,1,1)
        #     ha = dict(histtype='step', range=(0,1), bins=50)
        #     plt.hist(ell, color='g', label='ell', **ha)
        #     plt.hist(well, color='g', lw=3, alpha=0.3, label='windowed ell', **ha)
        #     plt.legend()
        #     plt.subplot(2,1,2)
        #     ha = dict(histtype='step', range=(-90,90), bins=50)
        #     plt.hist(theta, color='g', label='theta', **ha)
        #     plt.hist(wtheta, color='g', lw=3, alpha=0.3,
        #              label='windowed theta', **ha)
        #     plt.xlim(-90,90)
        #     plt.legend()
        #     plt.suptitle('Source ellipticities & angles')
        #     ps.savefig()

        # Cut down to stars whose centroids are within 1 pixel of their peaks...
        #keep = (np.hypot(fx - xx, fy - yy) < 2)
        #print(sum(keep), 'of', len(keep), 'stars have centroids within 2 of peaks')
        #print('mean dx', np.mean(fx-xx), 'dy', np.mean(fy-yy), 'pixels')
        #assert(float(sum(keep)) / len(keep) > 0.9)
        #fx = fx[keep]
        #fy = fy[keep]

        apxy = np.vstack((fx, fy)).T
        ap = []
        aprad_pix = self.aprad / pixsc
        aper = photutils.CircularAperture(apxy, aprad_pix)
        p = photutils.aperture_photometry(img, aper)
        apflux = p.field('aperture_sum')

        # Manual aperture photometry to get clipped means in sky annulus
        sky_inner_r, sky_outer_r = [r / pixsc for r in self.skyrad]
        sky = []
        for xi, yi in zip(fx, fy):
            ix = int(np.round(xi))
            iy = int(np.round(yi))
            skyR = int(np.ceil(sky_outer_r))
            xlo = max(0, ix - skyR)
            xhi = min(W, ix + skyR + 1)
            ylo = max(0, iy - skyR)
            yhi = min(H, iy + skyR + 1)
            xx, yy = np.meshgrid(np.arange(xlo, xhi), np.arange(ylo, yhi))
            r2 = (xx - xi)**2 + (yy - yi)**2
            inannulus = ((r2 >= sky_inner_r**2) * (r2 < sky_outer_r**2))
            skypix = img[ylo:yhi, xlo:xhi][inannulus]
            #print('ylo,yhi, xlo,xhi', ylo,yhi, xlo,xhi, 'img subshape', img[ylo:yhi, xlo:xhi].shape, 'inann shape', inannulus.shape)
            s, nil = sensible_sigmaclip(skypix)
            sky.append(s)
        sky = np.array(sky)

        apflux2 = apflux - sky * (np.pi * aprad_pix**2)
        good = (apflux2 > 0) * (apflux > 0)
        apflux = apflux[good]
        apflux2 = apflux2[good]
        fx = fx[good]
        fy = fy[good]

        # Read in the PS1 catalog, and keep those within 0.25 deg of CCD center
        # and those with main sequence colors
        pscat = ps1cat(ccdwcs=wcs)
        stars = pscat.get_stars()
        #print('Got PS1 stars:', len(stars))

        # we add the color term later
        ps1band = ps1cat.ps1band[band]
        stars.mag = stars.median[:, ps1band]

        ok, px, py = wcs.radec2pixelxy(stars.ra, stars.dec)
        px -= 1
        py -= 1

        if ps is not None:
            #kwa = dict(vmin=-3*sig1, vmax=50*sig1, cmap='gray')
            # Add to the 'detected sources' plot
            # mn,mx = np.percentile(img.ravel(), [50,99])
            # kwa = dict(vmin=mn, vmax=mx, cmap='gray')
            # plt.clf()
            # dimshow(img, **kwa)
            ax = plt.axis()
            #plt.plot(fx, fy, 'go', mec='g', mfc='none', ms=10)
            K = np.argsort(stars.mag)
            plt.plot(px[K[:10]] - trim_x0,
                     py[K[:10]] - trim_y0,
                     'o',
                     mec='m',
                     mfc='none',
                     ms=12,
                     mew=2)
            plt.plot(px[K[10:]] - trim_x0,
                     py[K[10:]] - trim_y0,
                     'o',
                     mec='m',
                     mfc='none',
                     ms=8)
            plt.axis(ax)
            plt.title('PS1 stars')
            #plt.colorbar()
            ps.savefig()

        # we trimmed the image before running detection; re-add that margin
        fullx = fx + trim_x0
        fully = fy + trim_y0

        # Match PS1 to our detections, find offset
        radius = self.maxshift / pixsc

        I, J, dx, dy = self.match_ps1_stars(px, py, fullx, fully, radius,
                                            stars)
        print(len(I), 'spatial matches with large radius', self.maxshift,
              'arcsec,', radius, 'pix')

        bins = 2 * int(np.ceil(radius))
        #print('Histogramming with', bins, 'bins')
        histo, xe, ye = np.histogram2d(dx,
                                       dy,
                                       bins=bins,
                                       range=((-radius, radius), (-radius,
                                                                  radius)))
        # smooth histogram before finding peak -- fuzzy matching
        from scipy.ndimage.filters import gaussian_filter
        histo = gaussian_filter(histo, 1.)
        histo = histo.T
        mx = np.argmax(histo)
        my, mx = np.unravel_index(mx, histo.shape)
        shiftx = (xe[mx] + xe[mx + 1]) / 2.
        shifty = (ye[my] + ye[my + 1]) / 2.

        if ps is not None:
            plt.clf()
            plothist(dx, dy, range=((-radius, radius), (-radius, radius)))
            plt.xlabel('dx (pixels)')
            plt.ylabel('dy (pixels)')
            plt.title('Offsets to PS1 stars')
            ax = plt.axis()
            plt.axhline(0, color='b')
            plt.axvline(0, color='b')
            plt.plot(shiftx, shifty, 'o', mec='m', mfc='none', ms=15, mew=3)
            plt.axis(ax)
            ps.savefig()

        # Refine with smaller search radius
        radius2 = 3. / pixsc
        I, J, dx, dy = self.match_ps1_stars(px, py, fullx + shiftx,
                                            fully + shifty, radius2, stars)
        print(len(J), 'matches to PS1 with small radius', 3, 'arcsec')
        shiftx2 = np.median(dx)
        shifty2 = np.median(dy)
        #print('Stage-1 shift', shiftx, shifty)
        #print('Stage-2 shift', shiftx2, shifty2)
        sx = shiftx + shiftx2
        sy = shifty + shifty2
        print('Astrometric shift (%.0f, %.0f) pixels' % (sx, sy))

        if self.debug and ps is not None:
            plt.clf()
            plothist(dx, dy, range=((-radius2, radius2), (-radius2, radius2)))
            plt.xlabel('dx (pixels)')
            plt.ylabel('dy (pixels)')
            plt.title('Offsets to PS1 stars')
            ax = plt.axis()
            plt.axhline(0, color='b')
            plt.axvline(0, color='b')
            plt.plot(shiftx2, shifty2, 'o', mec='m', mfc='none', ms=15, mew=3)
            plt.axis(ax)
            ps.savefig()

        if ps is not None:
            mn, mx = np.percentile(img.ravel(), [50, 99])
            kwa2 = dict(vmin=mn, vmax=mx, cmap='gray')
            plt.clf()
            dimshow(img, **kwa2)
            ax = plt.axis()
            plt.plot(fx[J], fy[J], 'go', mec='g', mfc='none', ms=10, mew=2)
            plt.plot(px[I] - sx - trim_x0,
                     py[I] - sy - trim_y0,
                     'm+',
                     ms=10,
                     mew=2)
            plt.axis(ax)
            plt.title('Matched PS1 stars')
            plt.colorbar()
            ps.savefig()

            plt.clf()
            dimshow(img, **kwa2)
            ax = plt.axis()
            plt.plot(fx[J], fy[J], 'go', mec='g', mfc='none', ms=10, mew=2)
            K = np.argsort(stars.mag)
            plt.plot(px[K[:10]] - sx - trim_x0,
                     py[K[:10]] - sy - trim_y0,
                     'o',
                     mec='m',
                     mfc='none',
                     ms=12,
                     mew=2)
            plt.plot(px[K[10:]] - sx - trim_x0,
                     py[K[10:]] - sy - trim_y0,
                     'o',
                     mec='m',
                     mfc='none',
                     ms=8,
                     mew=2)
            plt.axis(ax)
            plt.title('All PS1 stars')
            plt.colorbar()
            ps.savefig()

        # Now cut to just *stars* with good colors
        stars.gicolor = stars.median[:, 0] - stars.median[:, 2]
        keep = (stars.gicolor > 0.4) * (stars.gicolor < 2.7)
        stars.cut(keep)
        if len(stars) == 0:
            print('No overlap or too few stars in PS1')
            return None
        px = px[keep]
        py = py[keep]
        # Re-match
        I, J, dx, dy = self.match_ps1_stars(px, py, fullx + sx, fully + sy,
                                            radius2, stars)
        print('Cut to', len(stars), 'PS1 stars with good colors; matched',
              len(I))

        nmatched = len(I)

        meas.update(dx=sx, dy=sy, nmatched=nmatched)

        if focus:
            meas.update(img=img,
                        hdr=hdr,
                        primhdr=primhdr,
                        fx=fx,
                        fy=fy,
                        px=px - trim_x0 - sx,
                        py=py - trim_y0 - sy,
                        sig1=sig1,
                        stars=stars,
                        moments=(mx2, my2, mxy, theta, a, b, ell),
                        wmoments=(wmx2, wmy2, wmxy, wtheta, wa, wb, well),
                        apflux=apflux,
                        apflux2=apflux2)
            return meas

        #print('Mean astrometric shift (arcsec): delta-ra=', -np.mean(dy)*0.263, 'delta-dec=', np.mean(dx)*0.263)

        # Compute photometric offset compared to PS1
        # as the PS1 minus observed mags
        colorterm = self.colorterm_ps1_to_observed(stars.median, band)
        stars.mag += colorterm
        ps1mag = stars.mag[I]

        if False and ps is not None:
            plt.clf()
            plt.semilogy(ps1mag, apflux2[J], 'b.')
            plt.xlabel('PS1 mag')
            plt.ylabel('DECam ap flux (with sky sub)')
            ps.savefig()

            plt.clf()
            plt.semilogy(ps1mag, apflux[J], 'b.')
            plt.xlabel('PS1 mag')
            plt.ylabel('DECam ap flux (no sky sub)')
            ps.savefig()

        apmag2 = -2.5 * np.log10(apflux2) + zp0 + 2.5 * np.log10(exptime)
        apmag = -2.5 * np.log10(apflux) + zp0 + 2.5 * np.log10(exptime)

        if ps is not None:
            plt.clf()
            plt.plot(ps1mag, apmag[J], 'b.', label='No sky sub')
            plt.plot(ps1mag, apmag2[J], 'r.', label='Sky sub')
            # ax = plt.axis()
            # mn = min(ax[0], ax[2])
            # mx = max(ax[1], ax[3])
            # plt.plot([mn,mx], [mn,mx], 'k-', alpha=0.1)
            # plt.axis(ax)
            plt.xlabel('PS1 mag')
            plt.ylabel('DECam ap mag')
            plt.legend(loc='upper left')
            plt.title('Zeropoint')
            ps.savefig()

        dm = ps1mag - apmag[J]
        dmag, dsig = sensible_sigmaclip(dm, nsigma=2.5)
        print('Mag offset: %8.3f' % dmag)
        print('Scatter:    %8.3f' % dsig)

        if not np.isfinite(dmag) or not np.isfinite(dsig):
            print('FAILED TO GET ZEROPOINT!')
            meas.update(zp=None)
            return meas

        from scipy.stats import sigmaclip
        goodpix, lo, hi = sigmaclip(dm, low=3, high=3)
        dmagmed = np.median(goodpix)
        print(len(goodpix), 'stars used for zeropoint median')
        print('Using median zeropoint:')
        zp_med = zp0 + dmagmed
        trans_med = 10.**(-0.4 * (zp0 - zp_med - kx * (airmass - 1.)))
        print('Zeropoint %6.3f' % zp_med)
        print('Transparency: %.3f' % trans_med)

        dm = ps1mag - apmag2[J]
        dmag2, dsig2 = sensible_sigmaclip(dm, nsigma=2.5)
        #print('Sky-sub mag offset', dmag2)
        #print('Scatter', dsig2)

        if ps is not None:
            plt.clf()
            plt.plot(ps1mag,
                     apmag[J] + dmag - ps1mag,
                     'b.',
                     label='No sky sub')
            plt.plot(ps1mag, apmag2[J] + dmag2 - ps1mag, 'r.', label='Sky sub')
            plt.xlabel('PS1 mag')
            plt.ylabel('DECam ap mag - PS1 mag')
            plt.legend(loc='upper left')
            plt.ylim(-0.25, 0.25)
            plt.axhline(0, color='k', alpha=0.25)
            plt.title('Zeropoint')
            ps.savefig()

        zp_obs = zp0 + dmag
        transparency = 10.**(-0.4 * (zp0 - zp_obs - kx * (airmass - 1.)))
        meas.update(zp=zp_obs, transparency=transparency)

        print('Zeropoint %6.3f' % zp_obs)
        print('Fiducial  %6.3f' % zp0)
        print('Transparency: %.3f' % transparency)

        # print('Using sky-subtracted values:')
        # zp_sky = zp0 + dmag2
        # trans_sky = 10.**(-0.4 * (zp0 - zp_sky - kx * (airmass - 1.)))
        # print('Zeropoint %6.3f' % zp_sky)
        # print('Transparency: %.3f' % trans_sky)

        fwhms = []
        psf_r = 15
        if n_fwhm not in [0, None]:
            Jf = J[:n_fwhm]

        for i, (xi, yi, fluxi) in enumerate(zip(fx[Jf], fy[Jf], apflux[Jf])):
            #print('Fitting source', i, 'of', len(Jf))
            ix = int(np.round(xi))
            iy = int(np.round(yi))
            xlo = max(0, ix - psf_r)
            xhi = min(W, ix + psf_r + 1)
            ylo = max(0, iy - psf_r)
            yhi = min(H, iy + psf_r + 1)
            xx, yy = np.meshgrid(np.arange(xlo, xhi), np.arange(ylo, yhi))
            r2 = (xx - xi)**2 + (yy - yi)**2
            keep = (r2 < psf_r**2)
            pix = img[ylo:yhi, xlo:xhi].copy()
            ie = np.zeros_like(pix)
            ie[keep] = 1. / sig1
            #print('fitting source at', ix,iy)
            #print('number of active pixels:', np.sum(ie > 0), 'shape', ie.shape)

            psf = tractor.NCircularGaussianPSF([4.], [1.])
            tim = tractor.Image(data=pix, inverr=ie, psf=psf)
            src = tractor.PointSource(tractor.PixPos(xi - xlo, yi - ylo),
                                      tractor.Flux(fluxi))
            tr = tractor.Tractor([tim], [src])

            #print('Posterior before prior:', tr.getLogProb())
            src.pos.addGaussianPrior('x', 0., 1.)
            #print('Posterior after prior:', tr.getLogProb())

            doplot = (i < 5) * (ps is not None)
            if doplot:
                mod0 = tr.getModelImage(0)

            tim.freezeAllBut('psf')
            psf.freezeAllBut('sigmas')

            # print('Optimizing params:')
            # tr.printThawedParams()

            #print('Parameter step sizes:', tr.getStepSizes())
            optargs = dict(priors=False, shared_params=False)
            for step in range(50):
                dlnp, x, alpha = tr.optimize(**optargs)
                #print('dlnp', dlnp)
                #print('src', src)
                #print('psf', psf)
                if dlnp == 0:
                    break
            # Now fit only the PSF size
            tr.freezeParam('catalog')
            # print('Optimizing params:')
            # tr.printThawedParams()

            for step in range(50):
                dlnp, x, alpha = tr.optimize(**optargs)
                #print('dlnp', dlnp)
                #print('src', src)
                #print('psf', psf)
                if dlnp == 0:
                    break

            fwhms.append(psf.sigmas[0] * 2.35 * pixsc)

            if doplot:
                mod1 = tr.getModelImage(0)
                chi1 = tr.getChiImage(0)

                plt.clf()
                plt.subplot(2, 2, 1)
                plt.title('Image')
                dimshow(pix, **self.imgkwa)
                plt.subplot(2, 2, 2)
                plt.title('Initial model')
                dimshow(mod0, **self.imgkwa)
                plt.subplot(2, 2, 3)
                plt.title('Final model')
                dimshow(mod1, **self.imgkwa)
                plt.subplot(2, 2, 4)
                plt.title('Final chi')
                dimshow(chi1, vmin=-10, vmax=10)
                plt.suptitle('PSF fit')
                ps.savefig()

        fwhms = np.array(fwhms)
        fwhm = np.median(fwhms)
        print('Median FWHM: %.3f' % np.median(fwhms))
        meas.update(seeing=fwhm)

        if False and ps is not None:
            lo, hi = np.percentile(fwhms, [5, 95])
            lo -= 0.1
            hi += 0.1
            plt.clf()
            plt.hist(fwhms, 25, range=(lo, hi), histtype='step', color='b')
            plt.xlabel('FWHM (arcsec)')
            ps.savefig()

        if ps is not None:
            plt.clf()
            for i, (xi, yi) in enumerate(zip(fx[J], fy[J])[:50]):
                ix = int(np.round(xi))
                iy = int(np.round(yi))
                xlo = max(0, ix - psf_r)
                xhi = min(W, ix + psf_r + 1)
                ylo = max(0, iy - psf_r)
                yhi = min(H, iy + psf_r + 1)
                pix = img[ylo:yhi, xlo:xhi]

                slc = pix[iy - ylo, :].copy()
                slc /= np.sum(slc)
                p1 = plt.plot(slc, 'b-', alpha=0.2)
                slc = pix[:, ix - xlo].copy()
                slc /= np.sum(slc)
                p2 = plt.plot(slc, 'r-', alpha=0.2)
                ph, pw = pix.shape
                cx, cy = pw / 2, ph / 2
                if i == 0:
                    xx = np.linspace(0, pw, 300)
                    dx = xx[1] - xx[0]
                    sig = fwhm / pixsc / 2.35
                    yy = np.exp(-0.5 * (xx - cx)**2 / sig**2)  # * np.sum(pix)
                    yy /= (np.sum(yy) * dx)
                    p3 = plt.plot(xx, yy, 'k-', zorder=20)
            #plt.ylim(-0.2, 1.0)
            plt.legend([p1[0], p2[0], p3[0]],
                       ['image slice (y)', 'image slice (x)', 'fit'])
            plt.title('PSF fit')
            ps.savefig()

        return meas
def stage_fitplots(
    T=None, coimgs=None, cons=None,
    cat=None, targetrd=None, pixscale=None, targetwcs=None,
    W=None,H=None,
    bands=None, ps=None, brickid=None,
    plots=False, plots2=False, tims=None, tractor=None,
    pipe=None,
    outdir=None,
    **kwargs):

    for tim in tims:
        print 'Tim', tim, 'PSF', tim.getPsf()
        
    writeModels = False

    if pipe:
        t0 = Time()
        # Produce per-band coadds, for plots
        coimgs,cons = compute_coadds(tims, bands, targetwcs)
        print 'Coadds:', Time()-t0

    plt.figure(figsize=(10,10.5))
    #plt.subplots_adjust(left=0.002, right=0.998, bottom=0.002, top=0.998)
    plt.subplots_adjust(left=0.002, right=0.998, bottom=0.002, top=0.95)

    plt.clf()
    dimshow(get_rgb(coimgs, bands))
    plt.title('Image')
    ps.savefig()

    ax = plt.axis()
    cat = tractor.getCatalog()
    for i,src in enumerate(cat):
        rd = src.getPosition()
        ok,x,y = targetwcs.radec2pixelxy(rd.ra, rd.dec)
        cc = (0,1,0)
        if isinstance(src, PointSource):
            plt.plot(x-1, y-1, '+', color=cc, ms=10, mew=1.5)
        else:
            plt.plot(x-1, y-1, 'o', mec=cc, mfc='none', ms=10, mew=1.5)
        # plt.text(x, y, '%i' % i, color=cc, ha='center', va='bottom')
    plt.axis(ax)
    ps.savefig()

    mnmx = -5,300
    arcsinha = dict(mnmx=mnmx, arcsinh=1)

    # After plot
    rgbmod = []
    rgbmod2 = []
    rgbresids = []
    rgbchisqs = []

    chibins = np.linspace(-10., 10., 200)
    chihist = [np.zeros(len(chibins)-1, int) for band in bands]

    wcsW = targetwcs.get_width()
    wcsH = targetwcs.get_height()
    print 'Target WCS shape', wcsW,wcsH

    t0 = Time()
    mods = _map(_get_mod, [(tim, cat) for tim in tims])
    print 'Getting model images:', Time()-t0

    orig_wcsxy0 = [tim.wcs.getX0Y0() for tim in tims]
    for iband,band in enumerate(bands):
        coimg = coimgs[iband]
        comod  = np.zeros((wcsH,wcsW), np.float32)
        comod2 = np.zeros((wcsH,wcsW), np.float32)
        cochi2 = np.zeros((wcsH,wcsW), np.float32)
        for itim, (tim,mod) in enumerate(zip(tims, mods)):
            if tim.band != band:
                continue

            #mod = tractor.getModelImage(tim)

            if plots2:
                plt.clf()
                dimshow(tim.getImage(), **tim.ima)
                plt.title(tim.name)
                ps.savefig()
                plt.clf()
                dimshow(mod, **tim.ima)
                plt.title(tim.name)
                ps.savefig()
                plt.clf()
                dimshow((tim.getImage() - mod) * tim.getInvError(), **imchi)
                plt.title(tim.name)
                ps.savefig()

            R = tim_get_resamp(tim, targetwcs)
            if R is None:
                continue
            (Yo,Xo,Yi,Xi) = R
            comod[Yo,Xo] += mod[Yi,Xi]
            ie = tim.getInvError()
            noise = np.random.normal(size=ie.shape) / ie
            noise[ie == 0] = 0.
            comod2[Yo,Xo] += mod[Yi,Xi] + noise[Yi,Xi]
            chi = ((tim.getImage()[Yi,Xi] - mod[Yi,Xi]) * tim.getInvError()[Yi,Xi])
            cochi2[Yo,Xo] += chi**2
            chi = chi[chi != 0.]
            hh,xe = np.histogram(np.clip(chi, -10, 10).ravel(), bins=chibins)
            chihist[iband] += hh

            if not writeModels:
                continue

            im = tim.imobj
            fn = 'image-b%06i-%s-%s.fits' % (brickid, band, im.name)

            wcsfn = create_temp()
            wcs = tim.getWcs().wcs
            x0,y0 = orig_wcsxy0[itim]
            h,w = tim.shape
            subwcs = wcs.get_subimage(int(x0), int(y0), w, h)
            subwcs.write_to(wcsfn)

            primhdr = fitsio.FITSHDR()
            primhdr.add_record(dict(name='X0', value=x0, comment='Pixel origin of subimage'))
            primhdr.add_record(dict(name='Y0', value=y0, comment='Pixel origin of subimage'))
            xfn = im.wcsfn.replace(decals_dir+'/', '')
            primhdr.add_record(dict(name='WCS_FILE', value=xfn))
            xfn = im.psffn.replace(decals_dir+'/', '')
            primhdr.add_record(dict(name='PSF_FILE', value=xfn))
            primhdr.add_record(dict(name='INHERIT', value=True))

            imhdr = fitsio.read_header(wcsfn)
            imhdr.add_record(dict(name='EXTTYPE', value='IMAGE', comment='This HDU contains image data'))
            ivhdr = fitsio.read_header(wcsfn)
            ivhdr.add_record(dict(name='EXTTYPE', value='INVVAR', comment='This HDU contains an inverse-variance map'))
            fits = fitsio.FITS(fn, 'rw', clobber=True)
            tim.toFits(fits, primheader=primhdr, imageheader=imhdr, invvarheader=ivhdr)

            imhdr.add_record(dict(name='EXTTYPE', value='MODEL', comment='This HDU contains a Tractor model image'))
            fits.write(mod, header=imhdr)
            print 'Wrote image and model to', fn
            
        comod  /= np.maximum(cons[iband], 1)
        comod2 /= np.maximum(cons[iband], 1)

        rgbmod.append(comod)
        rgbmod2.append(comod2)
        resid = coimg - comod
        resid[cons[iband] == 0] = np.nan
        rgbresids.append(resid)
        rgbchisqs.append(cochi2)

        # Plug the WCS header cards into these images
        wcsfn = create_temp()
        targetwcs.write_to(wcsfn)
        hdr = fitsio.read_header(wcsfn)
        os.remove(wcsfn)

        if outdir is None:
            outdir = '.'
        wa = dict(clobber=True, header=hdr)
        for name,img in [('image', coimg), ('model', comod), ('resid', resid), ('chi2', cochi2)]:
            fn = os.path.join(outdir, '%s-coadd-%06i-%s.fits' % (name, brickid, band))
            fitsio.write(fn, img,  **wa)
            print 'Wrote', fn

    del cons

    plt.clf()
    dimshow(get_rgb(rgbmod, bands))
    plt.title('Model')
    ps.savefig()

    plt.clf()
    dimshow(get_rgb(rgbmod2, bands))
    plt.title('Model + Noise')
    ps.savefig()

    plt.clf()
    dimshow(get_rgb(rgbresids, bands))
    plt.title('Residuals')
    ps.savefig()

    plt.clf()
    dimshow(get_rgb(rgbresids, bands, mnmx=(-30,30)))
    plt.title('Residuals (2)')
    ps.savefig()

    plt.clf()
    dimshow(get_rgb(coimgs, bands, **arcsinha))
    plt.title('Image (stretched)')
    ps.savefig()

    plt.clf()
    dimshow(get_rgb(rgbmod2, bands, **arcsinha))
    plt.title('Model + Noise (stretched)')
    ps.savefig()

    del coimgs
    del rgbresids
    del rgbmod
    del rgbmod2

    plt.clf()
    g,r,z = rgbchisqs
    im = np.log10(np.dstack((z,r,g)))
    mn,mx = 0, im.max()
    dimshow(np.clip((im - mn) / (mx - mn), 0., 1.))
    plt.title('Chi-squared')
    ps.savefig()

    plt.clf()
    xx = np.repeat(chibins, 2)[1:-1]
    for y,cc in zip(chihist, 'grm'):
        plt.plot(xx, np.repeat(np.maximum(0.1, y),2), '-', color=cc)
    plt.xlabel('Chi')
    plt.yticks([])
    plt.axvline(0., color='k', alpha=0.25)
    ps.savefig()

    plt.yscale('log')
    mx = np.max([max(y) for y in chihist])
    plt.ylim(1, mx * 1.05)
    ps.savefig()

    return dict(tims=tims)
Exemple #27
0
def _plot_segmentation_blobs(mask, blobslices, T, ps):
    '''Show *mask* with a numbered red box per blob slice and a labelled red
    cross per source in *T*, then save the figure via *ps*.'''
    import pylab as plt
    from astrometry.util.plotutils import dimshow
    plt.clf()
    dimshow(mask, vmin=0, vmax=1)
    ax = plt.axis()
    for i, (sy, sx) in enumerate(blobslices):
        by0, by1 = sy.start, sy.stop
        bx0, bx1 = sx.start, sx.stop
        plt.plot([bx0, bx0, bx1, bx1, bx0], [by0, by1, by1, by0, by0],
                 'r-')
        plt.text((bx0 + bx1) / 2.,
                 by0,
                 '%i' % (i + 1),
                 ha='center',
                 va='bottom',
                 color='r')
    plt.plot(T.ibx, T.iby, 'rx')
    for i, t in enumerate(T):
        plt.text(t.ibx,
                 t.iby,
                 'src %i' % i,
                 color='red',
                 ha='left',
                 va='center')
    plt.axis(ax)
    plt.title('Blobs')
    ps.savefig()


def segment_and_group_sources(image, T, name=None, ps=None, plots=False):
    '''
    *image*: binary image that defines "blobs"
    *T*: source table; only ".ibx" and ".iby" elements are used (x,y integer
    pix pos).  Sources outside the image are clamped to the nearest edge
    pixel before the blob lookup.
    *name*: for debugging only (currently unused)
    *ps*, *plots*: when *plots* is true, diagnostic plots are saved via *ps*.

    Holes in *image* are filled before connected-component labelling.
    Blobs that contain no source are dropped.

    Returns: (blobmap, blobsrcs, blobslices)

    *blobmap*: integer image; -1 = background or a blob with no sources,
    otherwise the index of the blob in *blobsrcs* / *blobslices*.
    *blobsrcs*: list of np arrays of integers (ascending), indices of the
    elements of T within each blob.
    *blobslices*: list of slice objects for blob bounding-boxes.
    '''
    # The scipy.ndimage.morphology / .measurements namespaces are deprecated.
    from scipy.ndimage import binary_fill_holes, label, find_objects

    image = binary_fill_holes(image)
    blobmap, nblobs = label(image)
    H, W = image.shape
    del image

    blobslices = find_objects(blobmap)
    clipx = np.clip(T.ibx, 0, W - 1)
    clipy = np.clip(T.iby, 0, H - 1)
    source_blobs = blobmap[clipy, clipx]

    if plots:
        _plot_segmentation_blobs(blobmap > 0, blobslices, T, ps)

    # Group source indices by blob label in one pass: a stable sort keeps
    # indices ascending within each blob, and np.unique yields the start of
    # each label's run.  (Previously one np.nonzero per blob: O(nblobs*nsrcs).)
    order = np.argsort(source_blobs, kind='stable')
    labels, starts = np.unique(source_blobs[order], return_index=True)
    ends = np.append(starts[1:], len(order))

    # remap[label] -> index into blobsrcs; -1 for background (label 0) and
    # for blobs that contain no sources.
    remap = np.full(nblobs + 1, -1, np.int32)
    blobsrcs = []
    keepslices = []
    for blob, s, e in zip(labels, starts, ends):
        if blob == 0:
            # Sources not inside any blob.
            continue
        remap[blob] = len(blobsrcs)
        blobsrcs.append(order[s:e])
        keepslices.append(blobslices[blob - 1])
    blobslices = keepslices

    # Remap blob numbers so they index "blobsrcs"; empty regions become -1.
    blobmap = remap[blobmap]
    del remap

    if plots:
        _plot_segmentation_blobs(blobmap > -1, blobslices, T, ps)

    # Sanity check: every source must land in the blob it was grouped into.
    for j, Isrcs in enumerate(blobsrcs):
        for i in Isrcs:
            if (blobmap[clipy[i], clipx[i]] != j):
                info('---------------------------!!!-------------------------')
                info('Blob', j, 'sources', Isrcs)
                info('Source', i, 'coords x,y', T.ibx[i], T.iby[i])
                info('Expected blob value', j, 'but got', blobmap[clipy[i],
                                                                  clipx[i]])

    assert (len(blobsrcs) == len(blobslices))
    return blobmap, blobsrcs, blobslices
def _highfreq_l2(F, rmax, cmax):
    '''Return the L2 norms of row *rmax* and column *cmax* of complex array *F*.'''
    row = F[rmax, :]
    col = F[:, cmax]
    return (np.sqrt(np.sum(row.real**2 + row.imag**2)),
            np.sqrt(np.sum(col.real**2 + col.imag**2)))


def main():
    '''
    Experiment comparing the galaxy-PSF convolution against subsampled
    rendering for an Airy PSF and a range of Gaussian galaxy sizes.

    Saves diagnostic plots through a PlotSequence('conv') and prints the
    high-frequency L2 power and the RMS errors returned by
    compare_subsampled() for subsampling factors 1, 2 and 4.
    '''
    ps = PlotSequence('conv')

    S = 51
    # Integer division: `center` is used as an array index and as the PSF
    # centre; S / 2 would be 25.5 under Python 3 / true division.
    center = S // 2
    print('Center', center)

    #for psf_sigma in [2., 1.5, 1.]:
    for psf_sigma in [2.]:

        rms2 = []

        x = np.arange(S)
        y = np.arange(S)

        scale = 1.5 / psf_sigma
        psf = (scale, center)
        pixpsf = render_airy(psf, x, y)
        eval_psf = render_airy

        # Central row (blue) and column (red) profiles, linear and log.
        plt.clf()
        plt.subplot(2, 1, 1)
        plt.plot(x, pixpsf[center, :], 'b-')
        plt.plot(x, pixpsf[:, center], 'r-')
        plt.subplot(2, 1, 2)
        plt.plot(x, np.maximum(1e-16, pixpsf[center, :]), 'b-')
        plt.plot(x, np.maximum(1e-16, pixpsf[:, center]), 'r-')
        plt.yscale('log')
        ps.savefig()

        plt.clf()
        plt.imshow(pixpsf, interpolation='nearest', origin='lower')
        ps.savefig()

        plt.clf()
        plt.imshow(np.log10(np.maximum(1e-16, pixpsf)),
                   interpolation='nearest',
                   origin='lower')
        plt.colorbar()
        plt.title('log PSF')
        ps.savefig()

        # (An earlier variant low-pass filtered the pixelized PSF in Fourier
        # space before use; it has been removed.)

        gal_sigmas = [2, 1, 0.5, 0.25]
        for gal_sigma in gal_sigmas:
            # my convolution
            pixscale = 1.
            cd = pixscale * np.eye(2) / 3600.
            P, FG, Gmine, v, w = galaxy_psf_convolution(gal_sigma,
                                                        0.,
                                                        0.,
                                                        GaussianGalaxy,
                                                        cd,
                                                        0.,
                                                        0.,
                                                        pixpsf,
                                                        debug=True)
            C = P * FG

            print()
            print('PSF %g, Gal %g' % (psf_sigma, gal_sigma))

            # Power left in the highest-frequency row / column of each
            # Fourier-space array.
            rmax = np.argmax(np.abs(w))
            cmax = np.argmax(np.abs(v))
            for prefix, F in [('PSF', P), ('Gal', FG), ('PSF*Gal', C)]:
                l2_rmax, l2_cmax = _highfreq_l2(F, rmax, cmax)
                print(prefix + ' L_2 in highest-frequency rows & cols:',
                      l2_rmax, l2_cmax)
            print()

            _, Fgal = compare_subsampled(S,
                                         1,
                                         ps,
                                         psf,
                                         pixpsf,
                                         Gmine,
                                         v,
                                         w,
                                         gal_sigma,
                                         psf_sigma,
                                         cd,
                                         get_ffts=True,
                                         eval_psf=eval_psf)

            # 2x4 grid: real parts on the top row, imaginary on the bottom.
            plt.clf()
            panels = [(P, 'PSF'), (FG, 'Gal'), (C, 'P*Gal'), (Fgal, 'pixGal')]
            for col, (F, title) in enumerate(panels):
                plt.subplot(2, 4, col + 1)
                dimshow(F.real)
                plt.colorbar()
                plt.title(title + ' real')
                plt.subplot(2, 4, col + 5)
                dimshow(F.imag)
                plt.colorbar()
                plt.title(title + ' imag')

            plt.suptitle('PSF %g, Gal %g' % (psf_sigma, gal_sigma))

            ps.savefig()

            subsample = [1, 2, 4]
            rms1 = [compare_subsampled(S,
                                       s,
                                       ps,
                                       psf,
                                       pixpsf,
                                       Gmine,
                                       v,
                                       w,
                                       gal_sigma,
                                       psf_sigma,
                                       cd,
                                       eval_psf=eval_psf)
                    for s in subsample]
            rms2.append(rms1)

        print()
        print('PSF sigma =', psf_sigma)
        print('RMSes:')
        for rms1, gal_sigma in zip(rms2, gal_sigmas):
            print('Gal sigma', gal_sigma, 'rms:',
                  ', '.join(['%.3g' % r for r in rms1]))
Exemple #29
0
def stage1(T=None,
           coimgs=None,
           cons=None,
           detmaps=None,
           detivs=None,
           targetrd=None,
           pixscale=None,
           targetwcs=None,
           W=None,
           H=None,
           bands=None,
           tims=None,
           ps=None,
           brick=None,
           cat=None):
    '''
    Detect, segment and fit sources in a brick, then estimate a variance for
    every catalog-source parameter.

    Steps:
      1. Normalize the per-band detection maps and build a "hot" S/N map that
         is the pixelwise maximum of the per-band S/N maps and two hard-coded
         SED-matched combinations ('Flat' and 'Red').
      2. Label S/N > 4 regions, un-set the regions that contain a catalog
         source (T.itx, T.ity), then keep S/N > 5 local maxima.  These new
         peaks are only plotted; they are NOT added to `cat`.
      3. Segment the hot map at S/N > 20.  Blob by blob, in decreasing order
         of summed `coimgs[1]` flux, fit the catalog sources in that blob
         against cut-outs of every tim that overlaps the blob.
      4. For each source, record per-parameter variances as
         1 / sum((derivative * inverr)**2), ignoring parameter covariances.

    Parameters
    ----------
    T : table
        Catalog table.  `itx`, `ity` are used as integer pixel indices into
        the brick; `tx`, `ty` are plotted.  `T.blob` is written.
    coimgs : list of images
        Coadds, used for the RGB plot.  `coimgs[1]` ranks blobs by flux.
    detmaps, detivs : dict band -> array
        Per-band detection maps and their inverse variances.  NOTE:
        `detmaps` is modified in place; each entry is replaced by
        detmap / max(1e-16, detiv).
    targetwcs : brick WCS.
    W, H : brick size in pixels.
    bands : list of band names.  The SED tuples below have three entries,
        so more than three bands will fail (see FIXME).
    tims : list of tractor images.  `tim.psf` and `tim.modelMinval` are set.
    ps : PlotSequence for the diagnostic plots (required).
    cat : tractor Catalog.  It is optimized in place, and galaxy shapes are
        converted from EllipseESoft to EllipseE.
    cons, targetrd, pixscale, brick : not used.

    Returns
    -------
    dict(cat=cat, variances=variances)
        `variances` is a flat array with one entry per catalog parameter.
    '''
    orig_wcsxy0 = [tim.wcs.getX0Y0() for tim in tims]
    hot = np.zeros((H, W), np.float32)

    # Per-band S/N.  detmaps is overwritten with the normalized maps, which
    # the SED-matched sums below rely on.
    for band in bands:
        detmap = detmaps[band] / np.maximum(1e-16, detivs[band])
        detsn = detmap * np.sqrt(detivs[band])
        hot = np.maximum(hot, detsn)
        detmaps[band] = detmap

    ### FIXME -- ugri
    for sedname, sed in [('Flat', (1., 1., 1.)), ('Red', (2.5, 1.0, 0.4))]:
        sedmap = np.zeros((H, W), np.float32)
        sediv = np.zeros((H, W), np.float32)
        for iband, band in enumerate(bands):
            # We convert the detmap to canonical band via
            #   detmap * w
            # And the corresponding change to sig1 is
            #   sig1 * w
            # So the invvar-weighted sum is
            #    (detmap * w) / (sig1**2 * w**2)
            #  = detmap / (sig1**2 * w)
            sedmap += detmaps[band] * detivs[band] / sed[iband]
            sediv += detivs[band] / sed[iband]**2
        sedmap /= np.maximum(1e-16, sediv)
        sedsn = sedmap * np.sqrt(sediv)
        hot = np.maximum(hot, sedsn)

        plt.clf()
        dimshow(np.round(sedsn), vmin=0, vmax=10, cmap='hot')
        plt.title('SED-matched detection filter: %s' % sedname)
        ps.savefig()

    peaks = (hot > 4)
    blobs, nblobs = label(peaks)
    print('N detected blobs:', nblobs)
    blobslices = find_objects(blobs)
    # Un-set catalog blobs
    for x, y in zip(T.itx, T.ity):
        # blob number
        bb = blobs[y, x]
        if bb == 0:
            continue
        # un-set 'peaks' within this blob (peaks[slc] is a view, so this
        # writes through to `peaks`)
        slc = blobslices[bb - 1]
        peaks[slc][blobs[slc] == bb] = 0

    # Now, after having removed catalog sources, crank up the detection threshold
    peaks &= (hot > 5)

    # zero out the edges(?)
    peaks[0, :] = peaks[:, 0] = 0
    peaks[-1, :] = peaks[:, -1] = 0
    # Keep pixels at least as high as their four N/S/E/W neighbours.
    peaks[1:-1, 1:-1] &= (hot[1:-1, 1:-1] >= hot[0:-2, 1:-1])
    peaks[1:-1, 1:-1] &= (hot[1:-1, 1:-1] >= hot[2:, 1:-1])
    peaks[1:-1, 1:-1] &= (hot[1:-1, 1:-1] >= hot[1:-1, 0:-2])
    peaks[1:-1, 1:-1] &= (hot[1:-1, 1:-1] >= hot[1:-1, 2:])

    # These are our peaks
    pki = np.flatnonzero(peaks)
    peaky, peakx = np.unravel_index(pki, peaks.shape)
    print(len(peaky), 'peaks')

    crossa = dict(ms=10, mew=1.5)
    plt.clf()
    dimshow(get_rgb(coimgs, bands))
    ax = plt.axis()
    plt.plot(T.tx, T.ty, 'r+', **crossa)
    # NOTE(review): `green` is not defined in this function; presumably a
    # module-level colour constant -- confirm.
    plt.plot(peakx, peaky, '+', color=green, **crossa)
    plt.axis(ax)
    plt.title('SDSS + SED-matched detections')
    ps.savefig()

    ### HACK -- high threshold again

    # Segment, and record which sources fall into each blob
    blobs, nblobs = label((hot > 20))
    print('N detected blobs:', nblobs)
    blobslices = find_objects(blobs)
    T.blob = blobs[T.ity, T.itx]
    blobsrcs = []
    blobflux = []
    fluximg = coimgs[1]
    for blob in range(1, nblobs + 1):
        blobsrcs.append(np.flatnonzero(T.blob == blob))
        bslc = blobslices[blob - 1]
        blobflux.append(np.sum(fluximg[bslc][blobs[bslc] == blob]))

    # Fit the SDSS sources

    for tim in tims:
        tim.psfex.fitSavedData(*tim.psfex.splinedata)
        tim.psf = tim.psfex

    # How far down to render model profiles
    minsigma = 0.1
    for tim in tims:
        tim.modelMinval = minsigma * tim.sig1
    srcvariances = [[] for src in cat]
    # Fit in order of flux
    for blobnumber, iblob in enumerate(np.argsort(-np.array(blobflux))):

        bslc = blobslices[iblob]
        Isrcs = blobsrcs[iblob]
        if len(Isrcs) == 0:
            continue

        print()
        print('Blob', blobnumber, 'of', len(blobflux), ':', len(Isrcs),
              'sources')
        print('Source indices:', Isrcs)
        print()

        # blob bbox in target coords
        sy, sx = bslc
        by0, by1 = sy.start, sy.stop
        bx0, bx1 = sx.start, sx.stop
        blobh, blobw = by1 - by0, bx1 - bx0

        rr, dd = targetwcs.pixelxy2radec([bx0, bx0, bx1, bx1],
                                         [by0, by1, by1, by0])
        alphas = [0.1, 0.3, 1.0]
        subtims = []
        for itim, tim in enumerate(tims):
            h, w = tim.shape
            ok, x, y = tim.subwcs.radec2pixelxy(rr, dd)
            sx0, sx1 = x.min(), x.max()
            sy0, sy1 = y.min(), y.max()
            # Skip only tims the blob bbox misses entirely.  Partial overlaps
            # on any edge are handled by the clipping below.  Testing
            # sx1 > w / sy1 > h here would drop tims that the blob merely
            # overhangs on the right/top.
            if sx1 < 0 or sy1 < 0 or sx0 > w or sy0 > h:
                continue
            sx0 = np.clip(int(np.floor(sx0)), 0, w - 1)
            sx1 = np.clip(int(np.ceil(sx1)), 0, w - 1) + 1
            sy0 = np.clip(int(np.floor(sy0)), 0, h - 1)
            sy1 = np.clip(int(np.ceil(sy1)), 0, h - 1) + 1
            subslc = slice(sy0, sy1), slice(sx0, sx1)
            subimg = tim.getImage()[subslc]
            subie = tim.getInvError()[subslc]
            subwcs = tim.getWcs().copy()
            ox0, oy0 = orig_wcsxy0[itim]
            subwcs.setX0Y0(ox0 + sx0, oy0 + sy0)

            # Mask out inverr for pixels that are not within the blob.
            subtarget = targetwcs.get_subimage(bx0, by0, blobw, blobh)
            subsubwcs = tim.subwcs.get_subimage(int(sx0), int(sy0),
                                                int(sx1 - sx0), int(sy1 - sy0))
            try:
                Yo, Xo, Yi, Xi, rims = resample_with_wcs(
                    subsubwcs, subtarget, [], 2)
            except OverlapError:
                print('No overlap')
                continue
            if len(Yo) == 0:
                continue
            subie2 = np.zeros_like(subie)
            I = np.flatnonzero(blobs[bslc][Yi, Xi] == (iblob + 1))
            subie2[Yo[I], Xo[I]] = subie[Yo[I], Xo[I]]
            subie = subie2
            # If the subimage (blob) is small enough, instantiate a
            # constant PSF model in the center.
            if sy1 - sy0 < 100 and sx1 - sx0 < 100:
                subpsf = tim.psf.mogAt(ox0 + (sx0 + sx1) / 2.,
                                       oy0 + (sy0 + sy1) / 2.)
            else:
                # Otherwise, instantiate a (shifted) spatially-varying
                # PsfEx model.
                subpsf = ShiftedPsf(tim.psf, ox0 + sx0, oy0 + sy0)

            subtim = Image(data=subimg,
                           inverr=subie,
                           wcs=subwcs,
                           psf=subpsf,
                           photocal=tim.getPhotoCal(),
                           sky=tim.getSky(),
                           name=tim.name)
            subtim.band = tim.band
            subtim.sig1 = tim.sig1
            subtim.modelMinval = tim.modelMinval
            subtims.append(subtim)

        subcat = Catalog(*[cat[i] for i in Isrcs])
        subtr = Tractor(subtims, subcat)
        subtr.freezeParam('images')
        # Optimize individual sources in order of flux
        fluxes = []
        for src in subcat:
            # HACK -- here we just *sum* the nanomaggies in each band.  Bogus!
            br = src.getBrightness()
            flux = sum([br.getFlux(band) for band in bands])
            fluxes.append(flux)
        Ibright = np.argsort(-np.array(fluxes))

        if len(Ibright) >= 5:
            # -Remember the original subtim images
            # -Compute initial models for each source (in each tim)
            # -Subtract initial models from images
            # -During fitting, for each source:
            #   -add back in the source's initial model (to each tim)
            #   -fit, with Catalog([src])
            #   -subtract final model (from each tim)
            # -Replace original subtim images
            #
            # --Might want to omit newly-added detection-filter sources, since their
            # fluxes are bogus.

            # Remember original tim images
            orig_timages = [tim.getImage().copy() for tim in subtims]
            initial_models = []
            # Create initial models for each tim x each source
            for tim in subtims:
                mods = []
                for src in subcat:
                    mod = src.getModelPatch(tim)
                    mods.append(mod)
                    if mod is not None:
                        if not np.all(np.isfinite(mod.patch)):
                            print('Non-finite mod patch')
                            print('source:', src)
                            print('tim:', tim)
                            print('PSF:', tim.getPsf())
                        assert (np.all(np.isfinite(mod.patch)))
                        mod.addTo(tim.getImage(), scale=-1)
                initial_models.append(mods)
            # For sources in decreasing order of brightness
            for numi, i in enumerate(Ibright):
                print('Fitting source', i,
                      '(%i of %i in blob)' % (numi, len(Ibright)))
                src = subcat[i]
                print(src)

                srctractor = Tractor(subtims, [src])
                srctractor.freezeParams('images')

                # Add this source's initial model back in.
                for tim, mods in zip(subtims, initial_models):
                    mod = mods[i]
                    if mod is not None:
                        mod.addTo(tim.getImage())

                print('Optimizing:', srctractor)
                srctractor.printThawedParams()
                for step in range(50):
                    dlnp, X, alpha = srctractor.optimize(priors=False,
                                                         shared_params=False,
                                                         alphas=alphas)
                    print('dlnp:', dlnp, 'src', src)
                    if dlnp < 0.1:
                        break

                # Subtract the fitted model so fainter sources see a
                # residual image.
                for tim in subtims:
                    mod = src.getModelPatch(tim)
                    if mod is not None:
                        mod.addTo(tim.getImage(), scale=-1)

            for tim, img in zip(subtims, orig_timages):
                tim.data = img

            del orig_timages
            del initial_models
        else:
            # Fit sources one at a time, but don't subtract other models
            subcat.freezeAllParams()
            for numi, i in enumerate(Ibright):
                tsrc = Time()
                print('Fitting source', i,
                      '(%i of %i in blob)' % (numi, len(Ibright)))
                print(subcat[i])
                subcat.freezeAllBut(i)
                print('Optimizing:', subtr)
                subtr.printThawedParams()
                for step in range(10):
                    dlnp, X, alpha = subtr.optimize(priors=False,
                                                    shared_params=False,
                                                    alphas=alphas)
                    print('dlnp:', dlnp)
                    if dlnp < 0.1:
                        break
                print('Fitting source took', Time() - tsrc)
                print(subcat[i])
        if len(Isrcs) > 1 and len(Isrcs) <= 10:
            tfit = Time()
            # Optimize all at once?
            subcat.thawAllParams()
            print('Optimizing:', subtr)
            subtr.printThawedParams()
            for step in range(20):
                dlnp, X, alpha = subtr.optimize(priors=False,
                                                shared_params=False,
                                                alphas=alphas)
                print('dlnp:', dlnp)
                if dlnp < 0.1:
                    break

        # Variances: thaw one source at a time and take
        # 1 / sum((d model / d param * inverr)**2) per parameter.
        subcat.thawAllRecursive()
        subcat.freezeAllParams()
        for isub, srci in enumerate(Isrcs):
            print('Variances for source', srci)
            subcat.thawParam(isub)

            src = subcat[isub]
            print('Source', src)
            print('Params:', src.getParamNames())

            # Switch galaxy shapes from EllipseESoft to EllipseE before
            # taking derivatives.
            if isinstance(src, (DevGalaxy, ExpGalaxy)):
                src.shape = EllipseE.fromEllipseESoft(src.shape)
            elif isinstance(src, FixedCompositeGalaxy):
                src.shapeExp = EllipseE.fromEllipseESoft(src.shapeExp)
                src.shapeDev = EllipseE.fromEllipseESoft(src.shapeDev)

            print('Converted ellipse:', src)

            allderivs = subtr.getDerivs()
            for iparam, derivs in enumerate(allderivs):
                dchisq = 0
                for deriv, tim in derivs:
                    h, w = tim.shape
                    deriv.clipTo(w, h)
                    ie = tim.getInvError()
                    slc = deriv.getSlice(ie)
                    chi = deriv.patch * ie[slc]
                    dchisq += (chi**2).sum()
                if dchisq == 0.:
                    v = np.nan
                else:
                    v = 1. / dchisq
                srcvariances[srci].append(v)
            assert (len(srcvariances[srci]) == subcat[isub].numberOfParams())
            subcat.freezeParam(isub)

    cat.thawAllRecursive()

    # Sources that never got variances (e.g. not in any blob) get zeros, and
    # their shapes are converted as well so parameter counts line up.
    for i, src in enumerate(cat):
        print('Source', i, src)
        print('variances:', srcvariances[i])
        print(len(srcvariances[i]), 'vs', src.numberOfParams())
        if len(srcvariances[i]) != src.numberOfParams():
            # This can happen for sources outside the brick bounds: they never get optimized?
            print('Warning: zeroing variances for source', src)
            srcvariances[i] = [0] * src.numberOfParams()
            if isinstance(src, (DevGalaxy, ExpGalaxy)):
                src.shape = EllipseE.fromEllipseESoft(src.shape)
            elif isinstance(src, FixedCompositeGalaxy):
                src.shapeExp = EllipseE.fromEllipseESoft(src.shapeExp)
                src.shapeDev = EllipseE.fromEllipseESoft(src.shapeDev)
        assert (len(srcvariances[i]) == src.numberOfParams())

    variances = np.hstack(srcvariances)
    assert (len(variances) == cat.numberOfParams())

    return dict(cat=cat, variances=variances)
Exemple #30
0
def read_raw_decam(F, ext):
    '''
    Read one extension of a raw DECam image: subtract each amplifier's
    median overscan, multiply by its gain, and trim to the data section.

    Parameters
    ----------
    F : fitsio FITS object
    ext : extension index or name within `F`.

    Returns
    -------
    img : numpy array, float32
        Overscan-subtracted, gain-multiplied, trimmed image.
    hdr : the header of `F[ext]`.

    Raises
    ------
    ValueError
        If the header contains DESBIAS, i.e. the extension is not a raw
        image.
    '''
    img = F[ext].read()
    hdr = F[ext].read_header()
    img = img.astype(np.float32)

    # Explicit raise rather than `assert`: an assert would be stripped under
    # `python -O`, silently re-processing an already-processed image.
    if 'DESBIAS' in hdr:
        raise ValueError('read_raw_decam: extension %s has DESBIAS set; '
                         'expected a raw DECam image' % (ext,))
    # DECam RAW image

    # Raw image size 2160 x 4146

    # Amplifier A / B data and overscan (bias) sections, plus gains.
    dataA = parse_section(hdr['DATASECA'], slices=True)
    biasA = parse_section(hdr['BIASSECA'], slices=True)
    dataB = parse_section(hdr['DATASECB'], slices=True)
    biasB = parse_section(hdr['BIASSECB'], slices=True)
    gainA = hdr['GAINA']
    gainB = hdr['GAINB']

    # Subtract median overscan and multiply by gains
    img[dataA] = (img[dataA] - np.median(img[biasA])) * gainA
    img[dataB] = (img[dataB] - np.median(img[biasB])) * gainB

    # Trim the image -- could just take the min/max of TRIMSECA/TRIMSECB...
    trimA = parse_section(hdr['TRIMSECA'], slices=True)
    trimB = parse_section(hdr['TRIMSECB'], slices=True)
    # copy the TRIM A,B sections into a new image...
    trimg = np.zeros_like(img)
    trimg[trimA] = img[trimA]
    trimg[trimB] = img[trimB]
    # ... and then cut that new image
    trim = parse_section(hdr['TRIMSEC'], slices=True)
    img = trimg[trim]

    return img, hdr
Exemple #31
0
    gal_re = 10.
    
    assert(ph == pw)
    
    ps = PlotSequence('conv')

    plt.clf()
    plt.imshow(pixpsf, interpolation='nearest', origin='lower')
    ps.savefig()
    
    Fpsf = np.fft.rfft2(pixpsf)
    
    plt.clf()
    plt.subplot(2,4,1)
    dimshow(Fpsf.real)
    plt.colorbar()
    plt.title('PSF real')
    plt.subplot(2,4,5)
    dimshow(Fpsf.imag)
    plt.colorbar()
    plt.title('PSF imag')
    ps.savefig()

    # Subsample the PSF via resampling
    from astrometry.util.util import lanczos_shift_image

    scale = 2
    sh,sw = ph*scale, pw*scale
    subpixpsf = np.zeros((sh,sw))
    for ix in np.arange(scale):
Exemple #32
0
def sed_matched_detection(sedname,
                          sed,
                          detmaps,
                          detivs,
                          bands,
                          xomit,
                          yomit,
                          romit,
                          nsigma=5.,
                          saddle_fraction=0.1,
                          saddle_min=2.,
                          saturated_pix=None,
                          veto_map=None,
                          cutonaper=True,
                          ps=None,
                          rgbimg=None):
    '''
    Runs a single SED-matched detection filter.

    Avoids creating sources close to existing sources.

    Parameters
    ----------
    sedname : string
        Name of this SED; only used for plots and log messages.
    sed : list of floats
        The SED -- a list of floats, one per band, of this SED.  Bands with
        a zero SED weight are skipped.
    detmaps : list of numpy arrays
        The per-band detection maps.  These must all be the same size, the
        brick image size.
    detivs : list of numpy arrays
        The inverse-variance maps associated with `detmaps`.
    bands : list of strings
        The band names of the `detmaps` and `detivs` images.
    xomit, yomit, romit : iterables (lists or numpy arrays) of int
        Previously known sources that are to be avoided; x,y +- radius
    nsigma : float, optional
        Detection threshold.
    saddle_fraction : float, optional
        Fraction of the peak height for selecting new sources.
    saddle_min : float, optional
        Saddle-point depth from existing sources down to new sources.
    saturated_pix : None or list of numpy arrays, boolean
        A map of pixels that are always considered "hot" when
        determining whether a new source touches hot pixels of an
        existing source.
    veto_map : None or numpy array of bool, optional
        Pixels where new peaks are not allowed.  It is copied, not modified.
    cutonaper : bool, optional
        Apply a cut that the source's detection strength must be greater
        than `nsigma` above the 16th percentile of the detection strength in
        an annulus (from 10 to 20 pixels) around the source.
    ps : PlotSequence object, optional
        Create plots?
    rgbimg : optional
        Image passed through to the per-peak diagnostic plot.

    Returns
    -------
    hotblobs : numpy array of bool
        A map of the blobs yielding sources in this SED.
    px, py : numpy array of int
        The new sources found.
    aper : list of float
        The 16th-percentile detection strength in the annulus around each
        kept source, or -1 if the annulus has no valid pixels.
    peakval : list of float
        The detection strength.

    If every band has zero SED weight or all-zero `detivs`, five Nones are
    returned instead.

    See also
    --------
    sed_matched_filters : creates the `(sedname, sed)` pairs used here
    run_sed_matched_filters : calls this method
    '''
    # Import from the top-level scipy.ndimage namespace; the
    # scipy.ndimage.measurements / .morphology submodules are deprecated.
    from scipy.ndimage import label, find_objects
    from scipy.ndimage import binary_dilation, binary_fill_holes

    H, W = detmaps[0].shape
    allzero = True
    for iband in range(len(bands)):
        if sed[iband] == 0:
            continue
        if np.all(detivs[iband] == 0):
            continue
        allzero = False
        break
    if allzero:
        info('SED', sedname, 'has all zero weight')
        return None, None, None, None, None

    sedmap = np.zeros((H, W), np.float32)
    sediv = np.zeros((H, W), np.float32)
    # Union of the per-band saturation maps.  It is bound (None) even when
    # saturated_pix is not given, because the per-peak plot receives it
    # whenever `ps` is set.
    satur = None
    if saturated_pix is not None:
        satur = np.zeros((H, W), bool)

    for iband in range(len(bands)):
        if sed[iband] == 0:
            continue
        # We convert the detmap to canonical band via
        #   detmap * w
        # And the corresponding change to sig1 is
        #   sig1 * w
        # So the invvar-weighted sum is
        #    (detmap * w) / (sig1**2 * w**2)
        #  = detmap / (sig1**2 * w)
        sedmap += detmaps[iband] * detivs[iband] / sed[iband]
        sediv += detivs[iband] / sed[iband]**2
        if saturated_pix is not None:
            satur |= saturated_pix[iband]
    sedmap /= np.maximum(1e-16, sediv)
    sedsn = sedmap * np.sqrt(sediv)
    del sedmap
    peaks = (sedsn > nsigma)

    def saddle_level(Y):
        # Require a saddle that drops by the larger of `saddle_min` sigma
        # or `saddle_fraction` of the peak height.
        drop = max(saddle_min, Y * saddle_fraction)
        return Y - drop

    lowest_saddle = nsigma - saddle_min

    # zero out the edges -- larger margin here?
    peaks[0, :] = 0
    peaks[:, 0] = 0
    peaks[-1, :] = 0
    peaks[:, -1] = 0

    # Label the N-sigma blobs at this point... we'll use this to build
    # "sedhot", which in turn is used to define the blobs that we will
    # optimize simultaneously.  This also determines which pixels go
    # into the fitting!
    dilate = 8
    hotblobs, nhot = label(
        binary_fill_holes(binary_dilation(peaks, iterations=dilate)))

    # find pixels that are at least as large as their 8 neighbors
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[0:-2, 1:-1])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[2:, 1:-1])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[1:-1, 0:-2])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[1:-1, 2:])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[0:-2, 0:-2])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[0:-2, 2:])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[2:, 0:-2])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[2:, 2:])

    if ps is not None:
        import pylab as plt
        from astrometry.util.plotutils import dimshow
        plt.clf()
        plt.subplot(1, 2, 2)
        dimshow(sedsn, vmin=-2, vmax=100, cmap='hot', ticks=False)
        plt.subplot(1, 2, 1)
        dimshow(sedsn, vmin=-2, vmax=10, cmap='hot', ticks=False)
        above = (sedsn > nsigma)
        plot_boundary_map(above)
        ax = plt.axis()
        y, x = np.nonzero(peaks)
        plt.plot(xomit, yomit, 'm.')
        plt.plot(x, y, 'r+')
        plt.axis(ax)
        plt.title('SED %s: S/N & peaks' % sedname)
        ps.savefig()

    # For each new source, compute the saddle value, segment at that
    # level, and drop the source if it is in the same blob as a
    # previously-detected source.

    # We dilate the blobs a bit too, to
    # catch slight differences in centroid positions.
    dilate = 1

    # For efficiency, segment at the minimum saddle level to compute
    # slices; the operations described above need only happen within
    # the slice.
    saddlemap = (sedsn > lowest_saddle)
    saddlemap = binary_dilation(saddlemap, iterations=dilate)
    if saturated_pix is not None:
        saddlemap |= satur
    allblobs, _ = label(saddlemap)
    allslices = find_objects(allblobs)
    ally0 = [sy.start for sy, sx in allslices]
    allx0 = [sx.start for sy, sx in allslices]

    # brightest peaks first
    py, px = np.nonzero(peaks)
    I = np.argsort(-sedsn[py, px])
    py = py[I]
    px = px[I]

    keep = np.zeros(len(px), bool)

    peakval = []
    aper = []
    apin = 10
    apout = 20

    # Map of pixels that are vetoed by sources found so far.  The veto
    # area is based on saddle height.  We go from brightest to
    # faintest pixels.  Thus the saddle level decreases, and the
    # saddlemap areas become larger; the saddlemap when a source is
    # found is a lower bound on the pixels that it will veto based on
    # the saddle heights of fainter sources.  Thus the vetomap isn't
    # the final word, it is just a quick veto of pixels we know for
    # sure will be vetoed.
    if veto_map is None:
        this_veto_map = np.zeros(sedsn.shape, bool)
    else:
        this_veto_map = veto_map.copy()

    # Veto a disk of radius r around every existing source.
    for x, y, r in zip(xomit, yomit, romit):
        xlo = int(np.clip(np.floor(x - r), 0, W - 1))
        xhi = int(np.clip(np.ceil(x + r), 0, W - 1))
        ylo = int(np.clip(np.floor(y - r), 0, H - 1))
        yhi = int(np.clip(np.ceil(y + r), 0, H - 1))
        this_veto_map[ylo:yhi + 1, xlo:xhi + 1] |= (np.hypot(
            (x - np.arange(xlo, xhi + 1))[np.newaxis, :],
            (y - np.arange(ylo, yhi + 1))[:, np.newaxis]) < r)

    if ps is not None:
        plt.clf()
        plt.imshow(this_veto_map,
                   interpolation='nearest',
                   origin='lower',
                   vmin=0,
                   vmax=1,
                   cmap='hot')
        plt.title('Veto map')
        ps.savefig()

        info('Peaks in initial veto map:', np.sum(this_veto_map[py, px]), 'of',
             len(px))

        plt.clf()
        plt.imshow(saddlemap,
                   interpolation='nearest',
                   origin='lower',
                   vmin=0,
                   vmax=1,
                   cmap='hot')
        ax = plt.axis()
        for slc in allslices:
            sy, sx = slc
            by0, by1 = sy.start, sy.stop
            bx0, bx1 = sx.start, sx.stop
            plt.plot([bx0, bx0, bx1, bx1, bx0], [by0, by1, by1, by0, by0],
                     'r-')
        plt.axis(ax)
        plt.title('Saddle map (lowest level): %i blobs' % len(allslices))
        ps.savefig()

    # For each peak, determine whether it is isolated enough --
    # separated by a low enough saddle from other sources.  Need only
    # search within its "allblob", which is defined by the lowest
    # saddle.
    info('Found', len(px), 'potential peaks')
    nveto = 0
    nsaddle = 0
    naper = 0
    for i, (x, y) in enumerate(zip(px, py)):
        if this_veto_map[y, x]:
            nveto += 1
            continue
        level = saddle_level(sedsn[y, x])
        ablob = allblobs[y, x]
        index = int(ablob - 1)
        slc = allslices[index]
        saddlemap = (sedsn[slc] > level)
        saddlemap = binary_dilation(saddlemap, iterations=dilate)
        if saturated_pix is not None:
            saddlemap |= satur[slc]
        saddlemap *= (allblobs[slc] == ablob)
        saddlemap = binary_fill_holes(saddlemap)
        blobs, _ = label(saddlemap)
        x0, y0 = allx0[index], ally0[index]
        thisblob = blobs[y - y0, x - x0]
        saddlemap *= (blobs == thisblob)

        # previously found sources:
        ox = np.append(xomit, px[:i][keep[:i]]) - x0
        oy = np.append(yomit, py[:i][keep[:i]]) - y0
        h, w = blobs.shape
        cut = False
        if len(ox):
            ox = ox.astype(int)
            oy = oy.astype(int)
            cut = any((ox >= 0) * (ox < w) * (oy >= 0) * (oy < h) *
                      (blobs[np.clip(oy, 0, h - 1),
                             np.clip(ox, 0, w - 1)] == thisblob))

        # one plot per peak is a little excessive!
        if ps is not None and i < 10:
            _peak_plot_1(this_veto_map, x, y, px, py, keep, i, xomit, yomit,
                         sedsn, allblobs, level, dilate, saturated_pix, satur,
                         ps, rgbimg, cut)
        if False and cut and ps is not None:
            _peak_plot_2(ox, oy, w, h, blobs, thisblob, sedsn, x0, y0, x, y,
                         level, ps)
        if False and (not cut) and ps is not None:
            _peak_plot_3(sedsn, nsigma, x, y, x0, y0, slc, saddlemap, xomit,
                         yomit, px, py, keep, i, cut, ps)

        if cut:
            # in same blob as previously found source.
            # update vetomap
            this_veto_map[slc] |= saddlemap
            nsaddle += 1
            continue

        # Measure in aperture...
        ap = sedsn[max(0, y - apout):min(H, y + apout + 1),
                   max(0, x - apout):min(W, x + apout + 1)]
        apiv = (sediv[max(0, y - apout):min(H, y + apout + 1),
                      max(0, x - apout):min(W, x + apout + 1)] > 0)
        aph, apw = ap.shape
        apx0, apy0 = max(0, x - apout), max(0, y - apout)
        R2 = ((np.arange(aph) + apy0 - y)[:, np.newaxis]**2 +
              (np.arange(apw) + apx0 - x)[np.newaxis, :]**2)
        ap = ap[apiv * (R2 >= apin**2) * (R2 <= apout**2)]
        if len(ap):
            # 16th percentile ~ -1 sigma point.
            m = np.percentile(ap, 16.)
        else:
            # fake
            m = -1.
        if cutonaper:
            if sedsn[y, x] - m < nsigma:
                naper += 1
                continue

        aper.append(m)
        peakval.append(sedsn[y, x])
        keep[i] = True
        this_veto_map[slc] |= saddlemap

        if False and ps is not None:
            plt.clf()
            plt.subplot(1, 2, 1)
            dimshow(ap,
                    vmin=-2,
                    vmax=10,
                    cmap='hot',
                    extent=[apx0, apx0 + apw, apy0, apy0 + aph])
            plt.subplot(1, 2, 2)
            dimshow(ap * ((R2 >= apin**2) * (R2 <= apout**2)),
                    vmin=-2,
                    vmax=10,
                    cmap='hot',
                    extent=[apx0, apx0 + apw, apy0, apy0 + aph])
            plt.suptitle('peak %.1f vs ap %.1f' % (sedsn[y, x], m))
            ps.savefig()

    info('Of', len(px), 'potential peaks:', nveto, 'in veto map,',
         nsaddle, 'cut by saddle test,', naper, 'cut by aper test,',
         np.sum(keep), 'kept')

    if ps is not None:
        pxdrop = px[np.logical_not(keep)]
        pydrop = py[np.logical_not(keep)]
    py = py[keep]
    px = px[keep]

    # Which of the hotblobs yielded sources?  Those are the ones to keep.
    hbmap = np.zeros(nhot + 1, bool)
    hbmap[hotblobs[py, px]] = True
    if len(xomit):
        h, w = hotblobs.shape
        hbmap[hotblobs[np.clip(yomit, 0, h - 1),
                       np.clip(xomit, 0, w - 1)]] = True
    # in case a source is (somehow) not in a hotblob?
    hbmap[0] = False
    hotblobs = hbmap[hotblobs]

    if ps is not None:
        plt.clf()
        dimshow(this_veto_map, vmin=0, vmax=1, cmap='hot')
        plt.title('SED %s: veto map' % sedname)
        ps.savefig()

        plt.clf()
        dimshow(hotblobs, vmin=0, vmax=1, cmap='hot')
        ax = plt.axis()
        p2 = plt.plot(pxdrop, pydrop, 'm+', ms=8, mew=2)
        p1 = plt.plot(px, py, 'g+', ms=8, mew=2)
        p3 = plt.plot(xomit, yomit, 'r+', ms=8, mew=2)
        plt.axis(ax)
        plt.title('SED %s: hot blobs' % sedname)
        plt.figlegend((p3[0], p1[0], p2[0]), ('Existing', 'Keep', 'Drop'),
                      loc='upper left')
        ps.savefig()

    return hotblobs, px, py, aper, peakval
Exemple #33
0
    def test_psfex(self):
        '''
        Check that the PsfEx model (self.psf), its pixelized stamp at the
        image center, and a 2-component Gaussian-mixture fit to that stamp
        produce consistent star and galaxy models.

        Asserts that:
          * the mixture-fit parameters match hard-coded reference values to
            within 0.1;
          * star and ExpGalaxy models rendered with the PsfEx PSF have
            flux-weighted centroids within 0.1 pixel of the image center
            shifted by the PSF stamp's own centroid offset;
          * PsfEx-vs-mixture model differences stay below 0.25 (star) and
            0.2 (galaxy) in every pixel.

        Plots are produced only when the module-level name `ps` (presumably
        a PlotSequence) is not None.
        '''

        if ps is not None:
            from astrometry.util.plotutils import dimshow
            import pylab as plt

        # All models are rendered at the center of a 100x100 image.
        H, W = 100, 100
        cx, cy = W / 2., H / 2.

        pixpsf = self.psf.constantPsfAt(cx, cy)

        # Flux-weighted centroid of the normalized PSF stamp.
        ph, pw = pixpsf.shape
        xx, yy = np.meshgrid(np.arange(pw), np.arange(ph))
        im = pixpsf.img.copy()
        im /= np.sum(im)
        cenx, ceny = np.sum(im * xx), np.sum(im * yy)
        print('Pixpsf centroid:', cenx, ceny)
        print('shape:', ph, pw)

        # Offset of the stamp centroid from the stamp's central pixel; the
        # rendered model centroids below are expected to shift by this much.
        dx, dy = cenx - pw // 2, ceny - ph // 2
        print('dx,dy', dx, dy)

        # gpsf = GaussianMixturePSF.fromStamp(im, N=1)
        # print('Fit gpsf:', gpsf)
        # self.assertTrue(np.abs(gpsf.mog.mean[0,0] - dx) < 0.1)
        # self.assertTrue(np.abs(gpsf.mog.mean[0,1] - dy) < 0.1)
        # self.assertTrue(np.abs(gpsf.mog.var[0,0,0] - 15.5) < 1.)
        # self.assertTrue(np.abs(gpsf.mog.var[0,1,1] - 13.5) < 1.)
        # self.assertTrue(np.abs(gpsf.mog.var[0,1,0] -   -1) < 1.)

        # Fit a 2-component Gaussian mixture to the normalized stamp.
        gpsf = GaussianMixturePSF.fromStamp(im, N=2)
        print('Fit gpsf:', gpsf)
        print('Params:', ', '.join(['%.1f' % p for p in gpsf.getParams()]))

        # Reference mixture parameters (12 values).  NOTE(review): presumably
        # recorded from an earlier run and ordered as returned by
        # gpsf.getParams(); confirm before changing the fit.
        pp = np.array(
            [0.8, 0.2, 0.1, -0.0, 1.2, 0.2, 7.6, 6.0, -1.0, 51.6, 49.1, -1.3])
        self.assertTrue(np.all(np.abs(np.array(gpsf.getParams()) - pp) < 0.1))

        # Blank image with unit inverse-variance; its PSF is swapped below.
        tim = Image(data=np.zeros((H, W)),
                    invvar=np.ones((H, W)),
                    psf=self.psf)

        # Pixel-coordinate grids for centroiding full-size model images.
        xx, yy = np.meshgrid(np.arange(W), np.arange(H))

        star = PointSource(PixPos(cx, cy), Flux(100.))
        gal = ExpGalaxy(PixPos(cx, cy), Flux(100.), EllipseE(1., 0., 0.))

        tr1 = Tractor([tim], [star])
        tr2 = Tractor([tim], [gal])

        # NOTE(review): presumably keeps cached galaxy renderings from being
        # reused after tim.psf is swapped below -- confirm.
        disable_galaxy_cache()

        # --- Star, PsfEx PSF: centroid must follow the stamp offset ---
        tim.psf = self.psf
        mod = tr1.getModelImage(0)
        mod1 = mod

        im = mod.copy()
        im /= im.sum()
        cenx, ceny = np.sum(im * xx), np.sum(im * yy)
        print('Star model + PsfEx centroid', cenx, ceny)

        self.assertTrue(np.abs(cenx - (cx + dx)) < 0.1)
        self.assertTrue(np.abs(ceny - (cy + dy)) < 0.1)

        if ps is not None:
            plt.clf()
            dimshow(mod)
            plt.title('Star model, PsfEx')
            ps.savefig()

        # --- Star, pixelized PSF stamp (plotted only) ---
        tim.psf = pixpsf

        mod = tr1.getModelImage(0)

        if ps is not None:
            plt.clf()
            dimshow(mod)
            plt.title('Star model, pixpsf')
            ps.savefig()

        # --- Star, Gaussian-mixture PSF: compare against PsfEx model ---
        tim.psf = gpsf

        mod = tr1.getModelImage(0)
        mod2 = mod

        if ps is not None:
            plt.clf()
            dimshow(mod)
            plt.title('Star model, gpsf')
            plt.colorbar()
            ps.savefig()

            plt.clf()
            dimshow(mod1 - mod2)
            plt.title('Star model, PsfEx - gpsf')
            plt.colorbar()
            ps.savefig()

        # range ~ -0.15 to +0.25
        self.assertTrue(np.all(np.abs(mod1 - mod2) < 0.25))

        # --- Galaxy, PsfEx PSF: centroid must follow the stamp offset ---
        tim.psf = self.psf
        mod = tr2.getModelImage(0)
        mod1 = mod

        im = mod.copy()
        im /= im.sum()
        cenx, ceny = np.sum(im * xx), np.sum(im * yy)
        print('Gal model + PsfEx centroid', cenx, ceny)

        self.assertTrue(np.abs(cenx - (cx + dx)) < 0.1)
        self.assertTrue(np.abs(ceny - (cy + dy)) < 0.1)

        if ps is not None:
            plt.clf()
            dimshow(mod)
            plt.title('Gal model, PsfEx')
            ps.savefig()

        # tim.psf = pixpsf
        # mod = tr2.getModelImage(0)
        # plt.clf()
        # dimshow(mod)
        # plt.title('Gal model, pixpsf')
        # ps.savefig()

        # --- Galaxy, Gaussian-mixture PSF: compare against PsfEx model ---
        tim.psf = gpsf
        mod = tr2.getModelImage(0)
        mod2 = mod
        # range ~ -0.1 to +0.2
        self.assertTrue(np.all(np.abs(mod1 - mod2) < 0.2))

        if ps is not None:
            plt.clf()
            dimshow(mod)
            plt.title('Gal model, gpsf')
            ps.savefig()

            plt.clf()
            dimshow(mod1 - mod2)
            plt.title('Gal model, PsfEx - gpsf')
            plt.colorbar()
            ps.savefig()
Exemple #34
0
            for x in XX:
                psfimg = psfex.instantiateAt(x, y)
                psfgrid.append(psfimg)
                h, w = psfimg.shape
                img = psfimg[h / 2 - crop:h / 2 + crop + 1,
                             w / 2 - crop:w / 2 + crop + 1]
                psfcropgrid.append(img)
        mx = np.max([psfimg.max() for psfimg in psfgrid])
        logmx = np.log10(mx)
        plt.clf()
        for i, psfimg in enumerate(psfcropgrid):
            #plt.subplot(len(YY), len(XX), i+1)
            subplot_grid(len(YY), len(XX), i)
            dimshow(np.log10(np.maximum(psfimg, mx * 1e-16)),
                    vmax=logmx,
                    vmin=logmx - 4,
                    ticks=False,
                    cmap='jet')
        plt.suptitle('PsfEx models')
        ps.savefig()

        modnames = [
            'Dense-grid MoG', 'Coarse-grid MoG', 'Dense-grid MoG (variance)',
            'Coarse-grid MoG (variance)'
        ]
        models = [psfex, subpsfex, psfexvar, subpsfexvar]

        modgrids = [[] for m in models]

        for iy, y in enumerate(YY):
            for ix, x in enumerate(XX):
Exemple #35
0
def main():
    '''
    Render a Tractor model of a brick catalog onto the DECam CCD nearest a
    fixed sky position, and save comparison plots.

    Picks the CCD whose center is closest to (RA, Dec) = (242, 7),
    optionally restricts it to a pixel region of interest, reads the
    catalog entries covering that region (plus a small margin) from the
    ``pipebrick-cats`` brick catalogs, and writes:

      1.png  the image
      2.png  the Tractor model of the catalog
      3.png  the model with catalog source positions overlaid
      4.png  the Tractor model of SDSS sources in the same footprint
    '''
    decals = Decals()

    catpattern = 'pipebrick-cats/tractor-phot-b%06i.fits'
    ra, dec = 242, 7

    # Region-of-interest, in pixels: x0, x1, y0, y1 (None = whole CCD).
    roi = [500, 1000, 500, 1000]
    if roi is not None:
        x0, x1, y0, y1 = roi

    # Select the CCD whose center is closest to (ra, dec).  This is a flat
    # distance in degrees (no cos(dec) factor, no RA wrap-around handling).
    chips = decals.get_ccds()
    D = np.argsort(np.hypot(chips.ra - ra, chips.dec - dec))
    chip = chips[D[0]]
    print('Closest chip:', chip)

    im = DecamImage(decals, chip)
    print('Image:', im)

    targetwcs = Sip(im.wcsfn)
    if roi is not None:
        targetwcs = targetwcs.get_subimage(x0, y0, x1 - x0, y1 - y0)

    r0, r1, d0, d1 = targetwcs.radec_bounds()
    # ~ 30-pixel margin
    margin = 2e-3
    if r0 > r1:
        # RA wrap-around: query [0, r1] and [r0, 360] separately and merge.
        # The loop names differ from `ra` so the target RA is not shadowed.
        TT = [
            brick_catalog_for_radec_box(rlo, rhi, d0 - margin, d1 + margin,
                                        decals, catpattern)
            for (rlo, rhi) in [(0, r1 + margin), (r0 - margin, 360.)]
        ]
        T = merge_tables(TT)
        T._header = TT[0]._header
    else:
        T = brick_catalog_for_radec_box(r0 - margin, r1 + margin,
                                        d0 - margin, d1 + margin,
                                        decals, catpattern)

    print('Got', len(T), 'catalog entries within range')
    cat = read_fits_catalog(T, T._header)
    print('Got', len(cat), 'catalog objects')

    print('Switching ellipse parameterizations')
    switch_to_soft_ellipses(cat)
    keepcat = []
    for src in cat:
        # Drop sources with any non-finite parameter.
        if not np.all(np.isfinite(src.getParams())):
            print('Src has infinite params:', src)
            continue
        # A composite galaxy whose fracDev is exactly 0 or 1 uses only one
        # of its components; replace it with the simpler single model.
        if isinstance(src, FixedCompositeGalaxy):
            f = src.fracDev.getClippedValue()
            if f == 0.:
                src = ExpGalaxy(src.pos, src.brightness, src.shapeExp)
            elif f == 1.:
                src = DevGalaxy(src.pos, src.brightness, src.shapeDev)
        keepcat.append(src)
    cat = keepcat

    slc = None
    if roi is not None:
        slc = slice(y0, y1), slice(x0, x1)
    tim = im.get_tractor_image(slc=slc)
    print('Got', tim)
    tim.psfex.fitSavedData(*tim.psfex.splinedata)
    tim.psfex.radius = 20
    tim.psf = CachingPsfEx.fromPsfEx(tim.psfex)

    tractor = Tractor([tim], cat)
    print('Created', tractor)

    mod = tractor.getModelImage(0)

    plt.clf()
    dimshow(tim.getImage(), **tim.ima)
    plt.title('Image')
    plt.savefig('1.png')

    plt.clf()
    dimshow(mod, **tim.ima)
    plt.title('Model')
    plt.savefig('2.png')

    # Overlay catalog positions on the model plot.  radec2pixelxy returns
    # 1-indexed (FITS-convention) pixel coordinates, while the image is
    # presumably drawn with 0-indexed pixel centers, so shift by one.
    _ok, x, y = targetwcs.radec2pixelxy(
        [src.getPosition().ra for src in cat],
        [src.getPosition().dec for src in cat])
    ax = plt.axis()
    plt.plot(np.asarray(x) - 1, np.asarray(y) - 1, 'rx')
    plt.axis(ax)
    plt.title('Sources')
    plt.savefig('3.png')

    bands = [im.band]
    import runbrick
    runbrick.photoobjdir = '.'
    scat, _sdss_table = get_sdss_sources(bands, targetwcs, local=False)
    print('Got', len(scat), 'SDSS sources in bounds')

    stractor = Tractor([tim], scat)
    print('Created', stractor)
    smod = stractor.getModelImage(0)

    plt.clf()
    dimshow(smod, **tim.ima)
    plt.title('SDSS model')
    plt.savefig('4.png')
Exemple #36
0
def stage0(**kwargs):
    '''
    Build the initial inputs for a CFHT brick run.

    Picks the brick nearest (RA, Dec) = (190, 11).  Selects the CFHT CCDs
    touching it in the g, r and i bands, and reads, sky-subtracts and
    calibrates each CCD cutout into a tractor Image.  Resamples the cutouts
    onto the brick WCS, then builds per-band coadds, an SDSS source list and
    per-band detection maps.  Plots go to a PlotSequence named 'cfht'.

    Extra keyword arguments are accepted and ignored.

    Returns
    -------
    dict with keys 'T', 'coimgs', 'cons', 'detmaps', 'detivs', 'targetrd',
    'pixscale', 'targetwcs', 'W', 'H', 'bands', 'tims', 'ps', 'brick',
    'cat'.  'T' is the SDSS source table (it replaces the CCD table of the
    same name partway through).  Each returned tim has ``tim.resamp`` set
    to ``(Yo, Xo, Yi, Xi)``, or to None when it does not overlap the brick.
    '''
    ps = PlotSequence('cfht')

    decals = CfhtDecals()
    B = decals.get_bricks()
    print('Bricks:')
    B.about()

    ra, dec = 190.0, 11.0

    #bands = 'ugri'
    bands = 'gri'

    B.cut(np.argsort(degrees_between(ra, dec, B.ra, B.dec)))
    print('Nearest bricks:', B.ra[:5], B.dec[:5], B.brickid[:5])

    brick = B[0]
    pixscale = 0.186
    W, H = 4800, 4800

    targetwcs = wcs_for_brick(brick, pixscale=pixscale, W=W, H=H)

    # The CCD list is cached in a local file and reused when present.
    ccdfn = 'cfht-ccds.fits'
    if os.path.exists(ccdfn):
        T = fits_table(ccdfn)
    else:
        T = get_ccd_list()
        T.writeto(ccdfn)
    print(len(T), 'CCDs')
    T.cut(ccds_touching_wcs(targetwcs, T))
    print(len(T), 'CCDs touching brick')

    T.cut(np.array([b in bands for b in T.filter]))
    print(len(T), 'in bands', bands)

    ims = []
    for t in T:
        im = CfhtImage(t)
        im.seeing = t.seeing
        im.pixscale = t.pixscale
        print('seeing', t.seeing)
        print('pixscale', im.pixscale * 3600, 'arcsec/pix')
        im.run_calibs(t.ra, t.dec, im.pixscale, W=t.width, H=t.height)
        ims.append(im)

    # Brick outline in RA,Dec: closed polygon through the 1-indexed corners.
    targetrd = np.array([
        targetwcs.pixelxy2radec(x, y)
        for x, y in [(1, 1), (W, 1), (W, H), (1, H), (1, 1)]
    ])

    # Read images, clip to ROI
    tims = []
    for im in ims:
        print()
        print('Reading expnum', im.expnum, 'name', im.extname, 'band', im.band,
              'exptime', im.exptime)
        band = im.band
        wcs = im.read_wcs()
        imh, imw = wcs.imageh, wcs.imagew
        imgpoly = [(1, 1), (1, imh), (imw, imh), (imw, 1)]
        _ok, tx, ty = wcs.radec2pixelxy(targetrd[:-1, 0], targetrd[:-1, 1])
        # Materialize the vertex list: in Python 3 zip() is a one-shot
        # iterator, not a sequence, and clip_polygon receives a polygon.
        tpoly = list(zip(tx, ty))
        clip = np.array(clip_polygon(imgpoly, tpoly))
        if len(clip) == 0:
            continue
        x0, y0 = np.floor(clip.min(axis=0)).astype(int)
        x1, y1 = np.ceil(clip.max(axis=0)).astype(int)
        slc = slice(y0, y1 + 1), slice(x0, x1 + 1)

        ## FIXME -- it seems I got lucky and the cross product is
        ## negative == clockwise, as required by clip_polygon. One
        ## could check this and reverse the polygon vertex order.

        print('Image slice: x [%i,%i], y [%i,%i]' % (x0, x1, y0, y1))
        print('Reading image from', im.imgfn, 'HDU', im.hdu)
        img, _imghdr = im.read_image(header=True, slice=slc)
        # Zero-valued pixels are treated as bad throughout.
        goodpix = (img != 0)
        print('Number of pixels == 0:', np.sum(img == 0))
        print('Number of pixels != 0:', np.sum(goodpix))
        if not np.any(goodpix):
            continue
        print('Image range', img.min(), img.max())
        print('Goodpix image range:', (img[goodpix]).min(),
              (img[goodpix]).max())
        if img[goodpix].min() == img[goodpix].max():
            print('No dynamic range in image')
            continue

        # Noise (sig1) and scalar sky are estimated from the full CCD.
        fullimg = im.read_image()
        fullgood = (fullimg != 0)
        # Blanton's 5-pixel MAD: differences between pixels offset by 5 in
        # x and y, sampled on a 10-pixel grid, using only pairs of good
        # pixels.  1.4826 converts MAD to sigma; sqrt(2) accounts for the
        # difference of two pixels.
        slice1 = (slice(0, -5, 10), slice(0, -5, 10))
        slice2 = (slice(5, None, 10), slice(5, None, 10))
        pairgood = fullgood[slice1] * fullgood[slice2]
        mad = np.median(
            np.abs(fullimg[slice1] - fullimg[slice2])[pairgood].ravel())
        sig1 = 1.4826 * mad / np.sqrt(2.)
        print('MAD sig1:', sig1)
        medsky = np.median(fullimg[fullgood])
        invvar = np.zeros_like(img)
        invvar[goodpix] = 1. / sig1**2

        # Plot the cutout with only the scalar median sky removed.
        plt.clf()
        dimshow(np.round((img - medsky) / sig1), vmin=-3, vmax=5)
        plt.title('Scalar median: %s' % im.name)
        ps.savefig()

        # Spatially varying sky: median-smooth the full CCD (size argument
        # 256, bad pixels masked), then cut out the same slice.
        fullmed = np.zeros_like(fullimg)
        median_smooth(fullimg - medsky, np.logical_not(fullgood), 256, fullmed)
        fullmed += medsky
        medimg = fullmed[slc]

        plt.clf()
        dimshow(np.round((img - medimg) / sig1), vmin=-3, vmax=5)
        plt.title('Median filtered: %s' % im.name)
        ps.savefig()

        img -= medimg

        magzp = decals.get_zeropoint_for(im)
        print('magzp', magzp)
        zpscale = NanoMaggies.zeropointToScale(magzp)
        print('zpscale', zpscale)

        # Scale images to Nanomaggies; the photocal scale is then 1.
        img /= zpscale
        sig1 /= zpscale
        invvar *= zpscale**2
        zpscale = 1.

        assert (np.sum(invvar > 0) > 0)
        print('After scaling:')
        print('sig1', sig1)
        print('invvar range', invvar.min(), invvar.max())
        print('image range', img.min(), img.max())

        assert (np.all(np.isfinite(img)))
        assert (np.all(np.isfinite(invvar)))
        assert (np.isfinite(sig1))

        # Pixel histogram with a Gaussian of width sig1 for comparison.
        plt.clf()
        lo, hi = -5 * sig1, 10 * sig1
        n = plt.hist(img[goodpix].ravel(),
                     100,
                     range=(lo, hi),
                     histtype='step',
                     color='k')[0]
        xx = np.linspace(lo, hi, 200)
        plt.plot(xx, max(n) * np.exp(-xx**2 / (2. * sig1**2)), 'r-')
        plt.xlim(lo, hi)
        plt.title('Pixel histogram: %s' % im.name)
        ps.savefig()

        twcs = ConstantFitsWcs(wcs)
        if x0 or y0:
            twcs.setX0Y0(x0, y0)

        # read fit PsfEx model
        psfex = PsfEx.fromFits(im.psffitfn)
        print('Read', psfex)

        # HACK -- highly approximate PSF here: a single circular Gaussian
        # from the seeing FWHM (seeing in arcsec, pixscale in deg/pix).
        psf_fwhm = im.seeing / (im.pixscale * 3600)
        print('PSF FWHM', psf_fwhm, 'pixels')
        psf_sigma = psf_fwhm / 2.35
        psf = NCircularGaussianPSF([psf_sigma], [1.])

        print('img type', img.dtype)

        tim = Image(img,
                    invvar=invvar,
                    wcs=twcs,
                    psf=psf,
                    photocal=LinearPhotoCal(zpscale, band=band),
                    sky=ConstantSky(0.),
                    name=im.name + ' ' + band)
        tim.zr = [-3. * sig1, 10. * sig1]
        tim.sig1 = sig1
        tim.band = band
        tim.psf_fwhm = psf_fwhm
        tim.psf_sigma = psf_sigma
        tim.sip_wcs = wcs
        tim.x0, tim.y0 = int(x0), int(y0)
        tim.psfex = psfex
        tim.imobj = im
        mn, mx = tim.zr
        tim.ima = dict(interpolation='nearest',
                       origin='lower',
                       cmap='gray',
                       vmin=mn,
                       vmax=mx)
        tims.append(tim)

    print('Computing resampling...')
    # Map each cutout onto the brick pixel grid.  Cutouts without overlap
    # get resamp = None and are skipped by the coadd / detection loops.
    for tim in tims:
        wcs = tim.sip_wcs
        subh, subw = tim.shape
        subwcs = wcs.get_subimage(tim.x0, tim.y0, subw, subh)
        tim.subwcs = subwcs
        tim.resamp = None
        try:
            Yo, Xo, Yi, Xi, _rims = resample_with_wcs(targetwcs, subwcs, [],
                                                      2)
        except OverlapError:
            print('No overlap')
            continue
        if len(Yo) == 0:
            continue
        tim.resamp = (Yo, Xo, Yi, Xi)

    print('Creating coadds...')
    # Produce per-band coadds (mean of good pixels), for plots
    coimgs = []
    cons = []
    for band in bands:
        coimg = np.zeros((H, W), np.float32)
        con = np.zeros((H, W), np.uint8)
        for tim in tims:
            if tim.band != band or tim.resamp is None:
                continue
            Yo, Xo, Yi, Xi = tim.resamp
            nn = (tim.getInvvar()[Yi, Xi] > 0)
            coimg[Yo, Xo] += tim.getImage()[Yi, Xi] * nn
            con[Yo, Xo] += nn
        coimg /= np.maximum(con, 1)
        coimgs.append(coimg)
        cons.append(con)

    plt.clf()
    dimshow(get_rgb(coimgs, bands))
    ps.savefig()

    # Square subplot grid, at least 2x2, large enough for every band.
    nsub = max(2, int(np.ceil(np.sqrt(len(bands)))))
    plt.clf()
    for i, b in enumerate(bands):
        plt.subplot(nsub, nsub, i + 1)
        dimshow(cons[i], ticks=False)
        plt.title('%s band' % b)
        plt.colorbar()
    plt.suptitle('Number of exposures')
    ps.savefig()

    print('Grabbing SDSS sources...')
    bandlist = list(bands)
    cat, T = get_sdss_sources(bandlist, targetwcs)
    # record coordinates in target brick image (FITS 1-indexed -> 0-indexed)
    _ok, T.tx, T.ty = targetwcs.radec2pixelxy(T.ra, T.dec)
    T.tx -= 1
    T.ty -= 1
    T.itx = np.clip(np.round(T.tx).astype(int), 0, W - 1)
    T.ity = np.clip(np.round(T.ty).astype(int), 0, H - 1)

    plt.clf()
    dimshow(get_rgb(coimgs, bands))
    ax = plt.axis()
    plt.plot(T.tx, T.ty, 'o', mec=green, mfc='none', ms=10, mew=1.5)
    plt.axis(ax)
    plt.title('SDSS sources')
    ps.savefig()

    print('Detmaps...')
    # Render the detection maps: PSF-smoothed images, inverse-variance
    # weighted and accumulated per band on the brick grid.
    detmaps = {b: np.zeros((H, W), np.float32) for b in bands}
    detivs = {b: np.zeros((H, W), np.float32) for b in bands}
    for tim in tims:
        if tim.resamp is None:
            continue
        iv = tim.getInvvar()
        # L2 norm of a unit-flux circular Gaussian of width psf_sigma.
        psfnorm = 1. / (2. * np.sqrt(np.pi) * tim.psf_sigma)
        detim = tim.getImage().copy()
        detim[iv == 0] = 0.
        detim = gaussian_filter(detim, tim.psf_sigma) / psfnorm**2
        detsig1 = tim.sig1 / psfnorm
        subh, subw = tim.shape
        detiv = np.zeros((subh, subw), np.float32) + (1. / detsig1**2)
        detiv[iv == 0] = 0.
        Yo, Xo, Yi, Xi = tim.resamp
        detmaps[tim.band][Yo, Xo] += detiv[Yi, Xi] * detim[Yi, Xi]
        detivs[tim.band][Yo, Xo] += detiv[Yi, Xi]

    return dict(T=T, coimgs=coimgs, cons=cons, detmaps=detmaps,
                detivs=detivs, targetrd=targetrd, pixscale=pixscale,
                targetwcs=targetwcs, W=W, H=H, bands=bands, tims=tims,
                ps=ps, brick=brick, cat=cat)
Exemple #37
0
        psfcropgrid = []
        crop = 10
        for y in YY:
            for x in XX:
                psfimg = psfex.instantiateAt(x, y)
                psfgrid.append(psfimg)
                h,w = psfimg.shape
                img = psfimg[h/2-crop:h/2+crop+1, w/2-crop:w/2+crop+1]
                psfcropgrid.append(img)
        mx = np.max([psfimg.max() for psfimg in psfgrid])
        logmx = np.log10(mx)
        plt.clf()
        for i,psfimg in enumerate(psfcropgrid):
            #plt.subplot(len(YY), len(XX), i+1)
            subplot_grid(len(YY), len(XX), i)
            dimshow(np.log10(np.maximum(psfimg, mx*1e-16)),
                    vmax=logmx, vmin=logmx-4, ticks=False, cmap='jet')
        plt.suptitle('PsfEx models')
        ps.savefig()


        modnames = ['Dense-grid MoG', 'Coarse-grid MoG',
                    'Dense-grid MoG (variance)', 'Coarse-grid MoG (variance)']
        models = [ psfex, subpsfex, psfexvar, subpsfexvar ]

        modgrids = [[] for m in models]
        
        for iy,y in enumerate(YY):
            for ix,x in enumerate(XX):

                for model,grid in zip(models, modgrids):
                    psf = model.psfAt(x, y)
def _debug_plots(srctractor, ps, params=None, X=None):
    '''
    Diagnostic plots of the log-probability around the current parameters
    of `srctractor`.

    For each thawed parameter, lnp is probed at geometric and linear
    offsets (in units of the parameter scale).  When `X` is given, lnp is
    also probed along X, and a mosaic of model-image differences for small
    steps along X is drawn.

    Parameters
    ----------
    srctractor : Tractor
        Tractor whose thawed parameters are probed.  Its parameters are
        restored to their starting values on return, even if plotting fails.
    ps : PlotSequence
        Receives the plots via ``ps.savefig()``.
    params : list of (lnp, param_vector) pairs, optional
        History to plot, e.g. from an optimizer run.  If omitted, no history
        is drawn.
    X : array-like, optional
        A parameter-space step (presumably an optimizer update direction).
        If omitted, the X-direction probes and the model-difference mosaic
        are skipped.
    '''
    if params is None:
        params = []
    if X is not None:
        X = np.asarray(X)

    thislnp0 = srctractor.getLogProb()
    p0 = np.array(srctractor.getParams())
    print('logprob:', p0, '=', thislnp0)

    try:
        print('p0 type:', p0.dtype)
        # Setting a copy of the parameters must not change the log-prob.
        px = p0 + np.zeros_like(p0)
        srctractor.setParams(px)
        lnpx = srctractor.getLogProb()
        assert (lnpx == thislnp0)
        print('logprob:', px, '=', lnpx)

        scales = srctractor.getParameterScales()
        print('Parameter scales:', scales)
        print('Parameters:')
        srctractor.printThawedParams()

        # getParameterScales better not have changed the params!!
        assert (np.all(p0 == np.array(srctractor.getParams())))
        assert (srctractor.getLogProb() == thislnp0)

        pfinal = srctractor.getParams()
        pnames = srctractor.getParamNames()

        plt.figure(3, figsize=(8, 6))

        # History of every parameter (scaled offset from final) vs lnp.
        plt.clf()
        for i in range(len(scales)):
            plt.plot([(p[i] - pfinal[i]) * scales[i] for lnp, p in params],
                     [lnp for lnp, p in params],
                     '-',
                     label=pnames[i])
        plt.ylabel('lnp')
        plt.legend()
        plt.title('scaled')
        ps.savefig()

        # Probe offsets: geometric steps 1.1**k for k in [-20, 20], plus
        # linear steps between 0 and the smallest geometric step, mirrored
        # to negative values, with 0 in the middle.
        steps = 1.1**np.arange(-20, 21)
        s2 = np.linspace(0, steps[0], 10)[1:-1]
        steps = np.concatenate([-steps[::-1], -s2[::-1], [0], s2, steps])
        print('Steps:', steps)
        # A few tiny symmetric steps along X; the middle one is 0.
        s3 = s2[:2]
        ministeps = np.concatenate([-s3[::-1], [0], s3])

        for i in range(len(scales)):
            plt.clf()
            plt.plot([p[i] for lnp, p in params], '-')
            plt.xlabel('step')
            plt.title(pnames[i])
            ps.savefig()

            plt.clf()
            plt.plot([p[i] for lnp, p in params], [lnp for lnp, p in params],
                     'b.-')
            plt.plot(p0[i], thislnp0, 'bx', ms=20)

            # We also want to know about d(lnp)/d(param) and d(lnp)/d(X).
            print('Stepping in param', pnames[i], '...')
            pp = p0.copy()
            lnps, parms = [], []
            for s in steps:
                parm = p0[i] + s / scales[i]
                pp[i] = parm
                srctractor.setParams(pp)
                lnp = srctractor.getLogProb()
                parms.append(parm)
                lnps.append(lnp)
                print('logprob:', pp, '=', lnp)

            plt.plot(parms, lnps, 'k.-')
            j = np.argmin(np.abs(steps - 1.))
            plt.plot(parms[j], lnps[j], 'ko')

            if X is None:
                plt.ylabel('lnp')
                plt.title(pnames[i])
                ps.savefig()
                continue

            print('Stepping in X...')
            lnps, parms = [], []
            for s in steps:
                pp = p0 + s * X
                srctractor.setParams(pp)
                lnp = srctractor.getLogProb()
                parms.append(pp[i])
                lnps.append(lnp)
                print('logprob:', pp, '=', lnp)

            print('mini steps:', ministeps)
            for s in ministeps:
                pp = p0 + s * X
                srctractor.setParams(pp)
                lnp = srctractor.getLogProb()
                print('logprob:', pp, '=', lnp)

            rows = len(ministeps)
            cols = len(srctractor.images)

            plt.figure(4, figsize=(8, 6))
            plt.subplots_adjust(hspace=0.05,
                                wspace=0.05,
                                left=0.01,
                                right=0.99,
                                bottom=0.01,
                                top=0.99)
            plt.clf()
            mods = []
            for s in ministeps:
                pp = p0 + s * X
                srctractor.setParams(pp)
                print('ministep', s)
                print('log prior', srctractor.getLogPrior())
                print('log likelihood', srctractor.getLogLikelihood())
                mods.append(srctractor.getModelImages())
                chis = srctractor.getChiImages()
                print('chisqs:', [(chi**2).sum() for chi in chis])
                print('sum:', sum([(chi**2).sum() for chi in chis]))

            # Difference of each model from the unperturbed (step 0) model,
            # which is the middle entry of the symmetric ministeps.
            mod0 = mods[len(ministeps) // 2]
            k = 1
            for modlist in mods:
                for mi, mod in enumerate(modlist):
                    plt.subplot(rows, cols, k)
                    k += 1
                    m0 = mod0[mi]
                    rng = m0.max() - m0.min()
                    dimshow(mod - mod0[mi],
                            vmin=-0.01 * rng,
                            vmax=0.01 * rng,
                            ticks=False,
                            cmap='gray')
            ps.savefig()
            plt.figure(3)

            plt.plot(parms, lnps, 'r.-')

            print('Stepping in X by alphas...')
            lnps = []
            for cc, ss in [('m', 0.1), ('m', 0.3), ('r', 1)]:
                pp = p0 + ss * X
                srctractor.setParams(pp)
                lnp = srctractor.getLogProb()
                print('logprob:', pp, '=', lnp)

                plt.plot(p0[i] + ss * X[i], lnp, 'o', color=cc)
                lnps.append(lnp)

            # Zoom window around the full X step and the lnp range seen.
            pend = p0[i] + X[i]
            pmid = (pend + p0[i]) / 2.
            dp = np.abs((pend - pmid) * 2.)
            hi, lo = max(max(lnps), thislnp0), min(min(lnps), thislnp0)
            lnpmid = (hi + lo) / 2.
            dlnp = np.abs((hi - lo) * 2.)

            plt.ylabel('lnp')
            plt.title(pnames[i])
            ps.savefig()

            plt.axis([pmid - dp, pmid + dp, lnpmid - dlnp, lnpmid + dlnp])
            ps.savefig()
    finally:
        srctractor.setParams(p0)
Exemple #39
0
def sed_matched_detection(sedname,
                          sed,
                          detmaps,
                          detivs,
                          bands,
                          xomit,
                          yomit,
                          nsigma=5.,
                          saturated_pix=None,
                          saddle=2.,
                          cutonaper=True,
                          ps=None,
                          save_sedmap_png=True):
    '''
    Runs a single SED-matched detection filter.

    Avoids creating sources close to existing sources.

    Parameters
    ----------
    sedname : string
        Name of this SED; only used for plots and the debug PNG filename.
    sed : list of floats
        The SED -- a list of floats, one per band, of this SED.
    detmaps : list of numpy arrays
        The per-band detection maps.  These must all be the same size, the
        brick image size.
    detivs : list of numpy arrays
        The inverse-variance maps associated with `detmaps`.
    bands : list of strings
        The band names of the `detmaps` and `detivs` images.
    xomit, yomit : iterables (lists or numpy arrays) of int
        Previously known sources that are to be avoided.  Non-integer
        values are truncated to int before being used as pixel indices.
    nsigma : float, optional
        Detection threshold.
    saturated_pix : None or numpy array, boolean
        A map of pixels that are always considered "hot" when
        determining whether a new source touches hot pixels of an
        existing source.
    saddle : float, optional
        Saddle-point depth from existing sources down to new sources.
    cutonaper : bool, optional
        Apply a cut that the source's detection strength must be greater
        than `nsigma` above the 16th percentile of the detection strength in
        an annulus (from 10 to 20 pixels) around the source.
    ps : PlotSequence object, optional
        Create plots?
    save_sedmap_png : bool, optional
        If True (the default, matching the previously hard-wired `KJB`
        flag), always write a diagnostic "./sedmap_<sedname>_<time>.png"
        into the current working directory, independent of `ps`.

    Returns
    -------
    hotblobs : numpy array of bool
        A map of the blobs yielding sources in this SED.
    px, py : numpy array of int
        The new sources found.
    aper : list of float
        For each new source, the 16th percentile of the detection S/N in
        the annulus around it (computed whether or not `cutonaper` is
        set); -1 when the annulus holds no pixels with nonzero
        inverse-variance.
    peakval : list of float
        The detection strength.

    If no band has both a nonzero SED weight and some nonzero
    inverse-variance, all five return values are None.

    See also
    --------
    sed_matched_filters : creates the `(sedname, sed)` pairs used here
    run_sed_matched_filters : calls this method

    '''
    from scipy.ndimage.measurements import label, find_objects
    from scipy.ndimage.morphology import binary_dilation, binary_fill_holes

    t0 = Time()
    H, W = detmaps[0].shape

    # Nothing to do unless some band has SED weight AND some nonzero invvar.
    allzero = not any(sed[iband] != 0 and not np.all(detivs[iband] == 0)
                      for iband in range(len(bands)))
    if allzero:
        print('SED', sedname, 'has all zero weight')
        return None, None, None, None, None

    sedmap = np.zeros((H, W), np.float32)
    sediv = np.zeros((H, W), np.float32)
    for iband in range(len(bands)):
        if sed[iband] == 0:
            continue
        # We convert the detmap to canonical band via
        #   detmap * w
        # And the corresponding change to sig1 is
        #   sig1 * w
        # So the invvar-weighted sum is
        #    (detmap * w) / (sig1**2 * w**2)
        #  = detmap / (sig1**2 * w)
        sedmap += detmaps[iband] * detivs[iband] / sed[iband]
        sediv += detivs[iband] / sed[iband]**2
    sedmap /= np.maximum(1e-16, sediv)
    sedsn = sedmap * np.sqrt(sediv)
    del sedmap

    peaks = (sedsn > nsigma)
    print('SED sn:', Time() - t0)
    t0 = Time()

    def saddle_level(Y):
        # Require a saddle that drops by (the larger of) "saddle"
        # sigma, or 20% of the peak height
        drop = max(saddle, Y * 0.2)
        return Y - drop

    lowest_saddle = nsigma - saddle

    # zero out the edges -- larger margin here?
    peaks[0, :] = 0
    peaks[:, 0] = 0
    peaks[-1, :] = 0
    peaks[:, -1] = 0

    # Label the N-sigma blobs at this point... we'll use this to build
    # "sedhot", which in turn is used to define the blobs that we will
    # optimize simultaneously.  This also determines which pixels go
    # into the fitting!
    dilate = 8
    hotblobs, nhot = label(
        binary_fill_holes(binary_dilation(peaks, iterations=dilate)))

    # find pixels that are larger than their 8 neighbors
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[0:-2, 1:-1])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[2:, 1:-1])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[1:-1, 0:-2])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[1:-1, 2:])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[0:-2, 0:-2])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[0:-2, 2:])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[2:, 0:-2])
    peaks[1:-1, 1:-1] &= (sedsn[1:-1, 1:-1] >= sedsn[2:, 2:])
    print('Peaks:', Time() - t0)
    t0 = Time()

    if ps is not None:
        from astrometry.util.plotutils import dimshow
        green = (0, 1, 0)

        def plot_boundary_map(X):
            # Overlay the one-pixel outline of the True region of X in green.
            # Boolean arrays cannot be subtracted in numpy >= 1.13, so the
            # outline is "dilated AND NOT original" rather than a difference.
            X = np.asarray(X, bool)
            bounds = np.logical_and(binary_dilation(X), np.logical_not(X))
            h, w = X.shape
            rgba = np.zeros((h, w, 4), np.uint8)
            rgba[:, :, 1] = bounds * 255
            rgba[:, :, 3] = bounds * 255
            plt.imshow(rgba, interpolation='nearest', origin='lower')

        plt.clf()
        plt.subplot(1, 2, 2)
        dimshow(sedsn, vmin=-2, vmax=100, cmap='hot', ticks=False)
        plt.subplot(1, 2, 1)
        dimshow(sedsn, vmin=-2, vmax=10, cmap='hot', ticks=False)
        above = (sedsn > nsigma)
        plot_boundary_map(above)
        ax = plt.axis()
        y, x = np.nonzero(peaks)
        plt.plot(x, y, 'r+')
        plt.axis(ax)
        plt.title('SED %s: S/N & peaks' % sedname)
        ps.savefig()

    # For each new source, compute the saddle value, segment at that
    # level, and drop the source if it is in the same blob as a
    # previously-detected source.  We dilate the blobs a bit too, to
    # catch slight differences in centroid vs SDSS sources.
    dilate = 2

    # For efficiency, segment at the minimum saddle level to compute
    # slices; the operations described above need only happen within
    # the slice.
    saddlemap = (sedsn > lowest_saddle)
    if saturated_pix is not None:
        saddlemap |= saturated_pix
    saddlemap = binary_dilation(saddlemap, iterations=dilate)
    allblobs, nblobs = label(saddlemap)
    allslices = find_objects(allblobs)
    ally0 = [sy.start for sy, sx in allslices]
    allx0 = [sx.start for sy, sx in allslices]

    # brightest peaks first
    py, px = np.nonzero(peaks)
    I = np.argsort(-sedsn[py, px])
    py = py[I]
    px = px[I]

    keep = np.zeros(len(px), bool)

    peakval = []
    aper = []
    apin = 10
    apout = 20

    # Map of pixels that are vetoed by sources found so far.  The veto
    # area is based on saddle height.  We go from brightest to
    # faintest pixels.  Thus the saddle level decreases, and the
    # saddlemap areas become larger; the saddlemap when a source is
    # found is a lower bound on the pixels that it will veto based on
    # the saddle heights of fainter sources.  Thus the vetomap isn't
    # the final word, it is just a quick veto of pixels we know for
    # sure will be vetoed.
    vetomap = np.zeros(sedsn.shape, bool)

    # For each peak, determine whether it is isolated enough --
    # separated by a low enough saddle from other sources.  Need only
    # search within its "allblob", which is defined by the lowest
    # saddle.
    print('Found', len(px), 'potential peaks')
    for i, (x, y) in enumerate(zip(px, py)):
        if vetomap[y, x]:
            continue

        level = saddle_level(sedsn[y, x])
        ablob = allblobs[y, x]
        index = ablob - 1
        slc = allslices[index]

        saddlemap = (sedsn[slc] > level)
        if saturated_pix is not None:
            saddlemap |= saturated_pix[slc]
        saddlemap *= (allblobs[slc] == ablob)
        saddlemap = binary_fill_holes(saddlemap)
        saddlemap = binary_dilation(saddlemap, iterations=dilate)
        blobs, nblobs = label(saddlemap)
        x0, y0 = allx0[index], ally0[index]
        thisblob = blobs[y - y0, x - x0]

        # previously found sources:
        ox = np.append(xomit, px[:i][keep[:i]]) - x0
        oy = np.append(yomit, py[:i][keep[:i]]) - y0
        h, w = blobs.shape
        cut = False
        if len(ox):
            ox = ox.astype(int)
            oy = oy.astype(int)
            # Cut if any previous source lies inside this slice AND in the
            # same saddle-segmented blob as this peak.
            cut = any((ox >= 0) * (ox < w) * (oy >= 0) * (oy < h) *
                      (blobs[np.clip(oy, 0, h - 1),
                             np.clip(ox, 0, w - 1)] == thisblob))

        if False and (not cut) and ps is not None:
            plt.clf()
            plt.subplot(1, 2, 1)
            dimshow(sedsn, vmin=-2, vmax=10, cmap='hot')
            plot_boundary_map((sedsn > nsigma))
            ax = plt.axis()
            plt.plot(x, y, 'm+', ms=12, mew=2)
            plt.axis(ax)

            plt.subplot(1, 2, 2)
            y1, x1 = [s.stop for s in slc]
            ext = [x0, x1, y0, y1]
            dimshow(saddlemap, extent=ext)
            plt.plot(xomit, yomit, 'rx', ms=8, mew=2)
            plt.plot(px[:i][keep[:i]],
                     py[:i][keep[:i]],
                     '+',
                     color=green,
                     ms=8,
                     mew=2)
            plt.plot(x, y, 'mo', mec='m', mfc='none', ms=12, mew=2)
            plt.axis(ax)
            if cut:
                plt.suptitle('Cut')
            else:
                plt.suptitle('Keep')
            ps.savefig()

        if cut:
            # in same blob as previously found source; update vetomap
            vetomap[slc] |= saddlemap
            continue

        # Measure in aperture...
        ap = sedsn[max(0, y - apout):min(H, y + apout + 1),
                   max(0, x - apout):min(W, x + apout + 1)]
        apiv = (sediv[max(0, y - apout):min(H, y + apout + 1),
                      max(0, x - apout):min(W, x + apout + 1)] > 0)
        aph, apw = ap.shape
        apx0, apy0 = max(0, x - apout), max(0, y - apout)
        R2 = ((np.arange(aph) + apy0 - y)[:, np.newaxis]**2 +
              (np.arange(apw) + apx0 - x)[np.newaxis, :]**2)
        ap = ap[apiv * (R2 >= apin**2) * (R2 <= apout**2)]
        if len(ap):
            # 16th percentile ~ -1 sigma point.
            m = np.percentile(ap, 16.)
        else:
            # fake
            m = -1.
        if cutonaper:
            if sedsn[y, x] - m < nsigma:
                continue

        aper.append(m)
        peakval.append(sedsn[y, x])
        keep[i] = True

        vetomap[slc] |= saddlemap

        if False and ps is not None:
            plt.clf()
            plt.subplot(1, 2, 1)
            dimshow(ap,
                    vmin=-2,
                    vmax=10,
                    cmap='hot',
                    extent=[apx0, apx0 + apw, apy0, apy0 + aph])
            plt.subplot(1, 2, 2)
            dimshow(ap * ((R2 >= apin**2) * (R2 <= apout**2)),
                    vmin=-2,
                    vmax=10,
                    cmap='hot',
                    extent=[apx0, apx0 + apw, apy0, apy0 + aph])
            plt.suptitle('peak %.1f vs ap %.1f' % (sedsn[y, x], m))
            ps.savefig()

    print('New sources:', Time() - t0)

    # Rejected peaks are only needed for the diagnostic plots.
    if ps is not None or save_sedmap_png:
        dropped = np.logical_not(keep)
        pxdrop = px[dropped]
        pydrop = py[dropped]
    py = py[keep]
    px = px[keep]

    # Which of the hotblobs yielded sources?  Those are the ones to keep.
    hbmap = np.zeros(nhot + 1, bool)
    hbmap[hotblobs[py, px]] = True
    if len(xomit):
        h, w = hotblobs.shape
        # Cast to int so float coordinates can be used as indices.
        xo = np.clip(np.asarray(xomit).astype(int), 0, w - 1)
        yo = np.clip(np.asarray(yomit).astype(int), 0, h - 1)
        hbmap[hotblobs[yo, xo]] = True
    # in case a source is (somehow) not in a hotblob?
    hbmap[0] = False
    hotblobs = hbmap[hotblobs]

    if ps is not None:
        plt.clf()
        dimshow(vetomap, vmin=0, vmax=1, cmap='hot')
        plt.title('SED %s: veto map' % sedname)
        ps.savefig()

        plt.clf()
        dimshow(hotblobs, vmin=0, vmax=1, cmap='hot')
        ax = plt.axis()
        p1 = plt.plot(px, py, 'g+', ms=8, mew=2)
        p2 = plt.plot(pxdrop, pydrop, 'm+', ms=8, mew=2)
        p3 = plt.plot(xomit, yomit, 'r+', ms=8, mew=2)
        plt.axis(ax)
        plt.title('SED %s: hot blobs' % sedname)
        plt.figlegend((p3[0], p1[0], p2[0]), ('Existing', 'Keep', 'Drop'),
                      'upper left')
        ps.savefig()

    if save_sedmap_png:
        # Debug output (formerly forced on by a hard-coded `KJB = True`):
        # written to the current directory because `ps` may be None here.
        from astrometry.util.plotutils import dimshow

        plt.clf()
        plt.subplot(1, 2, 2)
        dimshow(sedsn, vmin=-2, vmax=100, cmap='hot', ticks=False)
        plt.title('sedsn')
        plt.subplot(1, 2, 1)
        dimshow(hotblobs, vmin=0, vmax=1, cmap='hot')
        ax = plt.axis()
        p1 = plt.plot(px, py, 'g+', ms=8, mew=2)
        p2 = plt.plot(pxdrop, pydrop, 'm+', ms=8, mew=2)
        p3 = plt.plot(xomit, yomit, 'r+', ms=8, mew=2)
        plt.axis(ax)
        plt.title('SED %s: hot blobs' % sedname)
        plt.figlegend((p3[0], p1[0], p2[0]), ('Existing', 'Keep', 'Drop'),
                      'upper left')
        # Tag the filename with the current wall-clock time (parsed from
        # `date` output) so PNGs from different runs can be told apart.
        cpudate = do_bash("date")
        cpudate = cpudate[cpudate.find(':') - 2:].replace(' ',
                                                          '').strip().replace(
                                                              ':', '-')
        plt.savefig('./sedmap_%s_%s.png' % (sedname, cpudate))

    return hotblobs, px, py, aper, peakval
def _plot_mods(tims,
               mods,
               blobwcs,
               titles,
               bands,
               coimgs,
               cons,
               bslc,
               blobw,
               blobh,
               ps,
               chi_plots=True,
               rgb_plots=False,
               main_plot=True,
               rgb_format='%s'):
    '''
    Plot per-band image / model / residual coadds of a blob, RGB
    composites of them, and optionally per-image chi maps.

    `mods[m][itim]` is model image `m` for `tims[itim]`; `titles[m]`
    labels model `m`.  When `coimgs` is None the image coadds (and the
    per-pixel counts `cons`) are accumulated here on the (blobh, blobw)
    grid; otherwise `coimgs[iband][bslc]` / `cons[iband][bslc]` are used.
    Residual pixels with zero coadd count are set to NaN.  Each RGB
    panel name is %-formatted into `rgb_format`.
    '''
    import numpy as np

    shape = (blobh, blobw)
    subims = [[] for _ in mods]
    chis = {b: [] for b in bands}

    make_coimgs = coimgs is None
    if make_coimgs:
        print('_plot_mods: blob shape', (blobh, blobw))
        coimgs = [np.zeros(shape) for _ in bands]
        cons = [np.zeros(shape) for _ in bands]

    for iband, band in enumerate(bands):
        comods = [np.zeros(shape) for _ in mods]
        cochis = [np.zeros(shape) for _ in mods]
        comodn = np.zeros(shape)
        mn = mx = 0
        sig1 = 1.
        for itim, tim in enumerate(tims):
            if tim.band != band:
                continue
            resamp = tim_get_resamp(tim, blobwcs)
            if resamp is None:
                continue
            Yo, Xo, Yi, Xi = resamp

            pix = tim.getImage()[Yi, Xi]
            ierr = tim.getInvError()[Yi, Xi]
            rechi = np.zeros(shape)
            per_mod = []
            comodn[Yo, Xo] += 1
            for imod, mod in enumerate(mods):
                modpix = mod[itim][Yi, Xi]
                chi = (pix - modpix) * ierr
                rechi[Yo, Xo] = chi
                per_mod.append((rechi.copy(), itim))
                cochis[imod][Yo, Xo] += chi
                comods[imod][Yo, Xo] += modpix
            chis[band].append(per_mod)
            # The display range & sig1 end up coming from the last tim.
            sig1 = tim.sig1
            mn, mx = -10. * sig1, 30. * sig1
            if make_coimgs:
                valid = ierr > 0
                coimgs[iband][Yo, Xo] += pix * valid
                cons[iband][Yo, Xo] += valid

        if not make_coimgs:
            coimg, coimgn = coimgs[iband][bslc], cons[iband][bslc]
        else:
            coimgs[iband] /= np.maximum(cons[iband], 1)
            coimg, coimgn = coimgs[iband], cons[iband]

        for comod in comods:
            comod /= np.maximum(comodn, 1)
        ima = dict(vmin=mn, vmax=mx, ticks=False)
        resida = dict(vmin=-5. * sig1, vmax=5. * sig1, ticks=False)
        for subim, comod, cochi in zip(subims, comods, cochis):
            subim.append((coimg, coimgn, comod, ima, cochi, resida))

    # Plot per-band image, model, and chi coadds, and RGB images
    rgba = dict(ticks=False)
    rgbs = []
    rgbnames = []
    stretch = dict(mnmx=(-5, 300), arcsinh=1)
    resid_stretch = dict(mnmx=(-100, 100), arcsinh=1)
    nrows, ncols = 3, 5
    plt.figure(1)
    for i, subim in enumerate(subims):
        plt.clf()
        for ib, b in enumerate(bands):
            plt.subplot(nrows, ncols, ib + 1)
            plt.title(b)
        plt.subplot(nrows, ncols, 4)
        plt.title('RGB')
        plt.subplot(nrows, ncols, 5)
        plt.title('RGB(stretch)')

        imgs, themods, resids = [], [], []
        for j, (img, imgn, mod, ima, _chi, resida) in enumerate(subim):
            imgs.append(img)
            themods.append(mod)
            resid = img - mod
            resid[imgn == 0] = np.nan
            resids.append(resid)
            if not main_plot:
                continue
            plt.subplot(nrows, ncols, 1 + j)
            dimshow(img, **ima)
            plt.subplot(nrows, ncols, 1 + j + ncols)
            dimshow(mod, **ima)
            plt.subplot(nrows, ncols, 1 + j + 2 * ncols)
            dimshow(resid, nancolor='r', **resida)

        img_rgb = get_rgb(imgs, bands)
        if i == 0:
            rgbs.append(img_rgb)
            rgbnames.append(rgb_format % 'Image')
        if main_plot:
            plt.subplot(nrows, ncols, 4)
            dimshow(img_rgb, **rgba)
        mod_rgb = get_rgb(themods, bands)
        rgbs.append(mod_rgb)
        rgbnames.append(rgb_format % titles[i])
        if main_plot:
            plt.subplot(nrows, ncols, ncols + 4)
            dimshow(mod_rgb, **rgba)
            plt.subplot(nrows, ncols, 2 * ncols + 4)
            dimshow(get_rgb(resids, bands, mnmx=(-10, 10)), **rgba)

            plt.subplot(nrows, ncols, 5)
            dimshow(get_rgb(imgs, bands, **stretch), **rgba)
            plt.subplot(nrows, ncols, ncols + 5)
            dimshow(get_rgb(themods, bands, **stretch), **rgba)
            plt.subplot(nrows, ncols, 2 * ncols + 5)
            dimshow(get_rgb(resids, bands, **resid_stretch), **rgba)
            plt.suptitle(titles[i])
            ps.savefig()

    if rgb_plots:
        # RGB image and model
        plt.figure(2)
        for rgb, name in zip(rgbs, rgbnames):
            plt.clf()
            dimshow(rgb, **rgba)
            plt.title(name)
            ps.savefig()

    if not chi_plots:
        return

    imchi = dict(cmap='RdBu', vmin=-5, vmax=5)

    plt.figure(1)
    # One panel per tim: bands along the rows, that band's tims along cols.
    ncols = max(len(v) for v in chis.values())
    nrows = len(bands)
    for imod in range(len(mods)):
        plt.clf()
        for row, band in enumerate(bands):
            first = 1 + ncols * row
            # chis[band][k] is, for the k-th tim of this band, a list with
            # one (chi, itim) pair per model.
            for col, per_mod in enumerate(chis[band]):
                chi, itim = per_mod[imod]
                plt.subplot(nrows, ncols, first + col)
                dimshow(-chi, **imchi)
                plt.xticks([])
                plt.yticks([])
                plt.title(tims[itim].name)
        ps.savefig()
def _debug_plots(srctractor, ps, params=None, X=None):
    '''
    Diagnostic plots of the log-probability of `srctractor` around its
    current parameter vector `p0`.

    For each parameter this saves: the parameter's trajectory over
    `params`; lnp while stepping along that parameter axis and along the
    direction `X`; a grid of model-image differences for tiny steps along
    `X`; and a zoomed view of lnp around the step p0 -> p0 + X.

    Parameters
    ----------
    srctractor : tractor object
        Must provide getLogProb, getParams, setParams, getParameterScales,
        printThawedParams, getParamNames, getLogPrior, getLogLikelihood,
        getModelImages, getChiImages and an `images` attribute.
    ps : PlotSequence object
        Every figure is saved via ``ps.savefig()``.
    params : sequence of (lnp, p) pairs
        History of log-probabilities and parameter vectors to plot.
        NOTE(review): presumably an optimizer's step history -- confirm.
    X : array-like
        Step direction in parameter space, same length as the parameter
        vector.  NOTE(review): presumably the optimizer update -- confirm.

    Raises
    ------
    ValueError
        If `params` or `X` is not given (previously a NameError).

    The original parameters are restored on exit, even if an error occurs.
    '''
    if params is None or X is None:
        raise ValueError('_debug_plots: both `params` and `X` are required')
    X = np.asarray(X)

    thislnp0 = srctractor.getLogProb()
    p0 = np.array(srctractor.getParams())
    print('logprob: %s = %s' % (p0, thislnp0))
    print('p0 type: %s' % (p0.dtype,))

    try:
        # Setting the same values back must not change lnp.
        px = p0 + np.zeros_like(p0)
        srctractor.setParams(px)
        lnpx = srctractor.getLogProb()
        assert lnpx == thislnp0
        print('logprob: %s = %s' % (px, lnpx))

        scales = srctractor.getParameterScales()
        print('Parameter scales: %s' % (scales,))
        print('Parameters:')
        srctractor.printThawedParams()

        # getParameterScales better not have changed the params!!
        assert np.all(p0 == np.array(srctractor.getParams()))
        assert srctractor.getLogProb() == thislnp0

        pfinal = srctractor.getParams()
        pnames = srctractor.getParamNames()
        hist_lnp = [lnp for lnp, p in params]

        # Step multipliers: geometric 1.1**-20 .. 1.1**20 on both sides of
        # zero, plus a few linear steps between 0 and 1.1**-20.
        geom = 1.1 ** np.arange(-20, 21)
        s2 = np.linspace(0, geom[0], 10)[1:-1]
        steps = np.concatenate([-geom[::-1], -s2[::-1], [0.], s2, geom])
        s3 = s2[:2]
        ministeps = np.concatenate([-s3[::-1], [0.], s3])

        plt.figure(3, figsize=(8, 6))

        plt.clf()
        for i in range(len(scales)):
            plt.plot([(p[i] - pfinal[i]) * scales[i] for lnp, p in params],
                     hist_lnp, '-', label=pnames[i])
        plt.ylabel('lnp')
        plt.legend()
        plt.title('scaled')
        ps.savefig()

        for i in range(len(scales)):
            hist_pi = [p[i] for lnp, p in params]
            plt.clf()
            plt.plot(hist_pi, '-')
            plt.xlabel('step')
            plt.title(pnames[i])
            ps.savefig()

            plt.clf()
            plt.plot(hist_pi, hist_lnp, 'b.-')

            # We also want to know about d(lnp)/d(param)
            # and d(lnp)/d(X)
            print('Steps: %s' % (steps,))

            plt.plot(p0[i], thislnp0, 'bx', ms=20)

            print('Stepping in param %s ...' % (pnames[i],))
            pp = p0.copy()
            lnps, parms = [], []
            for s in steps:
                parm = p0[i] + s / scales[i]
                pp[i] = parm
                srctractor.setParams(pp)
                lnp = srctractor.getLogProb()
                parms.append(parm)
                lnps.append(lnp)
                print('logprob: %s = %s' % (pp, lnp))

            plt.plot(parms, lnps, 'k.-')
            j = np.argmin(np.abs(steps - 1.))
            plt.plot(parms[j], lnps[j], 'ko')

            print('Stepping in X...')
            lnps, parms = [], []
            for s in steps:
                pp = p0 + s * X
                srctractor.setParams(pp)
                lnp = srctractor.getLogProb()
                parms.append(pp[i])
                lnps.append(lnp)
                print('logprob: %s = %s' % (pp, lnp))

            print('mini steps: %s' % (ministeps,))
            for s in ministeps:
                pp = p0 + s * X
                srctractor.setParams(pp)
                lnp = srctractor.getLogProb()
                print('logprob: %s = %s' % (pp, lnp))

            rows = len(ministeps)
            cols = len(srctractor.images)

            plt.figure(4, figsize=(8, 6))
            plt.subplots_adjust(hspace=0.05, wspace=0.05, left=0.01,
                                right=0.99, bottom=0.01, top=0.99)
            plt.clf()
            k = 1
            mods = []
            for s in ministeps:
                pp = p0 + s * X
                srctractor.setParams(pp)
                print('ministep %s' % (s,))
                print('log prior %s' % (srctractor.getLogPrior(),))
                print('log likelihood %s' % (srctractor.getLogLikelihood(),))
                mods.append(srctractor.getModelImages())
                chisqs = [(chi**2).sum() for chi in srctractor.getChiImages()]
                print('chisqs: %s' % (chisqs,))
                print('sum: %s' % (sum(chisqs),))

            # Difference of each ministep's models from the central (s=0) one.
            mod0 = mods[len(ministeps) // 2]
            for modlist in mods:
                for mi, mod in enumerate(modlist):
                    plt.subplot(rows, cols, k)
                    k += 1
                    m0 = mod0[mi]
                    rng = m0.max() - m0.min()
                    dimshow(mod - m0, vmin=-0.01 * rng, vmax=0.01 * rng,
                            ticks=False, cmap='gray')
            ps.savefig()
            plt.figure(3)

            plt.plot(parms, lnps, 'r.-')

            print('Stepping in X by alphas...')
            lnps = []
            for cc, ss in [('m', 0.1), ('m', 0.3), ('r', 1)]:
                pp = p0 + ss * X
                srctractor.setParams(pp)
                lnp = srctractor.getLogProb()
                print('logprob: %s = %s' % (pp, lnp))

                plt.plot(p0[i] + ss * X[i], lnp, 'o', color=cc)
                lnps.append(lnp)

            # Zoom window centred on the p0 -> p0 + X step.
            pend = p0[i] + X[i]
            pmid = (pend + p0[i]) / 2.
            dp = np.abs((pend - pmid) * 2.)
            hi = max(max(lnps), thislnp0)
            lo = min(min(lnps), thislnp0)
            lnpmid = (hi + lo) / 2.
            dlnp = np.abs((hi - lo) * 2.)

            plt.ylabel('lnp')
            plt.title(pnames[i])
            ps.savefig()

            plt.axis([pmid - dp, pmid + dp, lnpmid - dlnp, lnpmid + dlnp])
            ps.savefig()
    finally:
        # Never leave the tractor at a probed parameter vector.
        srctractor.setParams(p0)
def stage_psfplots(T=None,
                   sedsn=None,
                   coimgs=None,
                   cons=None,
                   detmaps=None,
                   detivs=None,
                   blobsrcs=None,
                   blobflux=None,
                   blobslices=None,
                   blobs=None,
                   tractor=None,
                   cat=None,
                   targetrd=None,
                   pixscale=None,
                   targetwcs=None,
                   W=None,
                   H=None,
                   brickid=None,
                   bands=None,
                   ps=None,
                   tims=None,
                   plots=False,
                   **kwargs):
    '''
    Print the spline knots of the first tim's PsfEx model; then refit
    every tim's PsfEx splines from its saved data and, when `plots` is
    set, plot each PSF parameter's grid fit beside its spline evaluation.
    '''
    first = tims[0]
    first.psfex.fitSavedData(*first.psfex.splinedata)
    spline0 = first.psfex.splines[0]
    print('Spline:', spline0)
    knots = spline0.get_knots()
    print('knots:', knots)
    knots_x, knots_y = knots
    order = 3
    # Knots beyond the first/last (order+1) are the interior ones.
    lo, hi = order + 1, -order - 1
    print('interior knots x:', knots_x[lo:hi])
    print('additional knots x:', knots_x[:lo], 'and', knots_x[hi:])
    print('interior knots y:', knots_y[lo:hi])
    print('additional knots y:', knots_y[:lo], 'and', knots_y[hi:])

    for tim in tims:
        model = tim.psfex
        model.fitSavedData(*model.splinedata)
        if not plots:
            continue

        print()
        print('Tim', tim)
        print()
        grid, xx, yy = model.splinedata
        ny, nx, nparams = grid.shape
        assert len(xx) == nx
        assert len(yy) == ny
        names = model.psfclass(*np.zeros(nparams)).getParamNames()
        xa = np.linspace(xx[0], xx[-1], 50)
        ya = np.linspace(yy[0], yy[-1], 100)
        print('xa', xa)
        print('ya', ya)
        extent = [xx[0], xx[-1], yy[0], yy[-1]]
        for ip, pname in zip(range(nparams), names):
            plt.clf()
            plt.subplot(1, 2, 1)
            dimshow(grid[:, :, ip])
            plt.title('grid fit')
            plt.colorbar()
            plt.subplot(1, 2, 2)
            # Evaluated on the (xa, ya) grid; transpose to (len(ya), len(xa)).
            evaluated = model.splines[ip](xa, ya).T
            print('spline shape', evaluated.shape)
            assert evaluated.shape == (len(ya), len(xa))
            dimshow(evaluated, extent=extent)
            plt.title('spline')
            plt.colorbar()
            plt.suptitle('tim %s: PSF param %s' % (tim.name, pname))
            ps.savefig()
def _plot_mods(tims, mods, blobwcs, titles, bands, coimgs, cons, bslc,
               blobw, blobh, ps,
               chi_plots=True, rgb_plots=False, main_plot=True,
               rgb_format='%s'):
    '''
    Make coadd plots (image, model, residual per band plus RGB versions)
    for a blob and, if `chi_plots`, one figure of per-tim chi maps per
    model.

    mods[m][itim] is model `m` rendered for tims[itim], titled titles[m].
    If `coimgs` is None, image coadds and their counts are built here on a
    (blobh, blobw) grid; otherwise the `bslc` cut-outs of the supplied
    `coimgs` / `cons` are used.  Residuals are NaN where the coadd count
    is zero.  `rgb_format` is %-formatted with each RGB panel's name.
    '''
    import numpy as np

    blobshape = (blobh, blobw)
    subims = [[] for _ in mods]
    chis = {}
    for b in bands:
        chis[b] = []

    make_coimgs = (coimgs is None)
    if make_coimgs:
        print('_plot_mods: blob shape %s' % ((blobh, blobw),))
        coimgs = [np.zeros(blobshape) for _ in bands]
        cons = [np.zeros(blobshape) for _ in bands]

    for iband, band in enumerate(bands):
        comods = [np.zeros(blobshape) for _ in mods]
        cochis = [np.zeros(blobshape) for _ in mods]
        comodn = np.zeros(blobshape)
        mn, mx = 0, 0
        sig1 = 1.
        for itim, tim in enumerate(tims):
            if tim.band != band:
                continue
            R = tim_get_resamp(tim, blobwcs)
            if R is None:
                continue
            Yo, Xo, Yi, Xi = R

            data = tim.getImage()[Yi, Xi]
            inverr = tim.getInvError()[Yi, Xi]
            canvas = np.zeros(blobshape)
            chilist = []
            comodn[Yo, Xo] += 1
            for imod, mod in enumerate(mods):
                mpix = mod[itim][Yi, Xi]
                chi = (data - mpix) * inverr
                canvas[Yo, Xo] = chi
                chilist.append((canvas.copy(), itim))
                cochis[imod][Yo, Xo] += chi
                comods[imod][Yo, Xo] += mpix
            chis[band].append(chilist)
            # Only the last tim of this band sets the display ranges.
            mn = -10. * tim.sig1
            mx = 30. * tim.sig1
            sig1 = tim.sig1
            if make_coimgs:
                nn = (inverr > 0)
                coimgs[iband][Yo, Xo] += data * nn
                cons[iband][Yo, Xo] += nn

        if make_coimgs:
            coimgs[iband] /= np.maximum(cons[iband], 1)
            coimg = coimgs[iband]
            coimgn = cons[iband]
        else:
            coimg = coimgs[iband][bslc]
            coimgn = cons[iband][bslc]

        for comod in comods:
            comod /= np.maximum(comodn, 1)
        ima = dict(vmin=mn, vmax=mx, ticks=False)
        resida = dict(vmin=-5. * sig1, vmax=5. * sig1, ticks=False)
        for subim, comod, cochi in zip(subims, comods, cochis):
            subim.append((coimg, coimgn, comod, ima, cochi, resida))

    # Per-band image / model / residual coadds, and RGB versions of each.
    rgba = dict(ticks=False)
    rgbs = []
    rgbnames = []
    plt.figure(1)
    for i, subim in enumerate(subims):
        plt.clf()
        R, C = 3, 5
        for ib, b in enumerate(bands):
            plt.subplot(R, C, ib + 1)
            plt.title(b)
        plt.subplot(R, C, 4)
        plt.title('RGB')
        plt.subplot(R, C, 5)
        plt.title('RGB(stretch)')

        imgs = []
        themods = []
        resids = []
        for j, entry in enumerate(subim):
            img, imgn, mod, ima, _, resida = entry
            imgs.append(img)
            themods.append(mod)
            resid = img - mod
            resid[imgn == 0] = np.nan
            resids.append(resid)
            if main_plot:
                for offset, (panel, kw) in enumerate([(img, ima),
                                                      (mod, ima)]):
                    plt.subplot(R, C, 1 + j + offset * C)
                    dimshow(panel, **kw)
                plt.subplot(R, C, 1 + j + 2 * C)
                dimshow(resid, nancolor='r', **resida)

        rgb = get_rgb(imgs, bands)
        if i == 0:
            rgbs.append(rgb)
            rgbnames.append(rgb_format % 'Image')
        if main_plot:
            plt.subplot(R, C, 4)
            dimshow(rgb, **rgba)
        rgb = get_rgb(themods, bands)
        rgbs.append(rgb)
        rgbnames.append(rgb_format % titles[i])
        if not main_plot:
            continue
        plt.subplot(R, C, C + 4)
        dimshow(rgb, **rgba)
        plt.subplot(R, C, 2 * C + 4)
        dimshow(get_rgb(resids, bands, mnmx=(-10, 10)), **rgba)

        soft = dict(mnmx=(-5, 300), arcsinh=1)
        plt.subplot(R, C, 5)
        dimshow(get_rgb(imgs, bands, **soft), **rgba)
        plt.subplot(R, C, C + 5)
        dimshow(get_rgb(themods, bands, **soft), **rgba)
        plt.subplot(R, C, 2 * C + 5)
        dimshow(get_rgb(resids, bands, mnmx=(-100, 100), arcsinh=1), **rgba)
        plt.suptitle(titles[i])
        ps.savefig()

    if rgb_plots:
        # RGB image and model
        plt.figure(2)
        for rgb, label in zip(rgbs, rgbnames):
            plt.clf()
            dimshow(rgb, **rgba)
            plt.title(label)
            ps.savefig()

    if not chi_plots:
        return

    imchi = dict(cmap='RdBu', vmin=-5, vmax=5)

    plt.figure(1)
    # Grid of per-image chis: one row per band, that band's tims across.
    cols = max(len(v) for v in chis.values())
    rows = len(bands)
    for imod in range(len(mods)):
        plt.clf()
        for row, band in enumerate(bands):
            base = 1 + cols * row
            # chis[band] holds, per tim of this band, a list of
            # (chi, itim) pairs -- one per model.
            for col, chilist in enumerate(chis[band]):
                chi, itim = chilist[imod]
                plt.subplot(rows, cols, base + col)
                dimshow(-chi, **imchi)
                plt.xticks([])
                plt.yticks([])
                plt.title(tims[itim].name)
        ps.savefig()
def _noise_realization(ie):
    """Return one Gaussian noise image with per-pixel sigma = 1 / ie.

    Pixels whose inverse-error is exactly zero get zero noise.  A standard
    normal draw is made for every pixel, so the RNG stream advances just as
    when drawing a full image.  The division is only performed where
    ``ie != 0``, which avoids divide-by-zero warnings.

    Parameters
    ----------
    ie : ndarray
        Per-pixel inverse-error (1 / sigma) image.

    Returns
    -------
    ndarray of float64, same shape as ``ie``.
    """
    draws = np.random.normal(size=ie.shape)
    noise = np.zeros(ie.shape)
    nonzero = (ie != 0)
    noise[nonzero] = draws[nonzero] / ie[nonzero]
    return noise


def _chi2_rgb_image(chisqs):
    """Map per-band chi-squared coadds to a display image with values in [0, 1].

    With exactly three bands, the planes are stacked in reverse order, so the
    last band becomes the red channel.  For bands (g, r, z) this reproduces
    the (z, r, g) mapping.  Any other number of bands is summed into one
    grayscale plane.  Values are log10-scaled and stretched from
    log10(chi2) = 0 to the maximum.

    Returns None when ``chisqs`` is empty.
    """
    if len(chisqs) == 0:
        return None
    if len(chisqs) == 3:
        stack = np.dstack(chisqs[::-1])
    else:
        stack = np.sum(chisqs, axis=0)
    # chi2 == 0 pixels (no coverage) give -inf here and are clipped to 0 below.
    with np.errstate(divide='ignore'):
        im = np.log10(stack)
    mn, mx = 0., im.max()
    if not mx > mn:
        # All chi2 <= 1 (or NaN max): avoid a zero / negative stretch range.
        mx = mn + 1.
    return np.clip((im - mn) / (mx - mn), 0., 1.)


def _write_tim_model(tim, mod, band, brickid, x0, y0):
    """Write a tim's image + invvar (via ``tim.toFits``) and its model to FITS.

    The output file is ``image-b<brickid>-<band>-<imagename>.fits`` in the
    current directory.  The image/invvar/model HDU headers carry the WCS of
    the tim's sub-image starting at pixel (x0, y0).  The temporary WCS file
    is always removed, and the FITS handle is always closed.
    """
    im = tim.imobj
    fn = 'image-b%06i-%s-%s.fits' % (brickid, band, im.name)

    # Render the sub-image WCS to a temp file only to harvest its header cards.
    wcsfn = create_temp()
    try:
        h, w = tim.shape
        subwcs = tim.getWcs().wcs.get_subimage(int(x0), int(y0), w, h)
        subwcs.write_to(wcsfn)
        imhdr = fitsio.read_header(wcsfn)
        ivhdr = fitsio.read_header(wcsfn)
    finally:
        if os.path.exists(wcsfn):
            os.remove(wcsfn)

    primhdr = fitsio.FITSHDR()
    primhdr.add_record(
        dict(name='X0', value=x0, comment='Pixel origin of subimage'))
    primhdr.add_record(
        dict(name='Y0', value=y0, comment='Pixel origin of subimage'))
    xfn = im.wcsfn.replace(decals_dir + '/', '')
    primhdr.add_record(dict(name='WCS_FILE', value=xfn))
    xfn = im.psffn.replace(decals_dir + '/', '')
    primhdr.add_record(dict(name='PSF_FILE', value=xfn))
    primhdr.add_record(dict(name='INHERIT', value=True))

    imhdr.add_record(
        dict(name='EXTTYPE',
             value='IMAGE',
             comment='This HDU contains image data'))
    ivhdr.add_record(
        dict(name='EXTTYPE',
             value='INVVAR',
             comment='This HDU contains an inverse-variance map'))

    fits = fitsio.FITS(fn, 'rw', clobber=True)
    try:
        tim.toFits(fits,
                   primheader=primhdr,
                   imageheader=imhdr,
                   invvarheader=ivhdr)
        # Reuse the image header (same WCS) for the model HDU.
        imhdr.add_record(
            dict(name='EXTTYPE',
                 value='MODEL',
                 comment='This HDU contains a Tractor model image'))
        fits.write(mod, header=imhdr)
    finally:
        fits.close()
    print('Wrote image and model to', fn)


def stage_fitplots(T=None,
                   coimgs=None,
                   cons=None,
                   cat=None,
                   targetrd=None,
                   pixscale=None,
                   targetwcs=None,
                   W=None,
                   H=None,
                   bands=None,
                   ps=None,
                   brickid=None,
                   plots=False,
                   plots2=False,
                   tims=None,
                   tractor=None,
                   pipe=None,
                   outdir=None,
                   **kwargs):
    """Make diagnostic plots and coadd FITS files for a fitted catalog.

    For each band, the per-tim model images are resampled onto ``targetwcs``
    and coadded (model, model + noise, chi^2).  The image, model, residual
    and chi^2 coadds are written to
    ``<outdir>/<name>-coadd-<brickid>-<band>.fits``; ``outdir`` defaults to
    '.'.  RGB plots of image, model, residuals, chi^2 and a per-band chi
    histogram are saved through ``ps.savefig()``.

    If ``pipe`` is true, the per-band coadds ``coimgs`` / ``cons`` are
    recomputed from ``tims``.  If ``plots2`` is true, each tim's image,
    model and chi image are also plotted.

    Returns
    -------
    dict(tims=tims)
    """
    # Chi-image display settings for the plots2 branch.  These are the same
    # values the per-image chi plots elsewhere in this file use.
    # NOTE(review): previously this name was not defined in this function.
    imchi = dict(cmap='RdBu', vmin=-5, vmax=5)

    for tim in tims:
        print('Tim', tim, 'PSF', tim.getPsf())

    # Set to True to also write per-tim image + model FITS files.
    writeModels = False

    if pipe:
        t0 = Time()
        # Produce per-band coadds, for plots
        coimgs, cons = compute_coadds(tims, bands, targetwcs)
        print('Coadds:', Time() - t0)

    plt.figure(figsize=(10, 10.5))
    plt.subplots_adjust(left=0.002, right=0.998, bottom=0.002, top=0.95)

    plt.clf()
    dimshow(get_rgb(coimgs, bands))
    plt.title('Image')
    ps.savefig()

    # Overlay catalog positions: '+' for point sources, 'o' for the rest.
    ax = plt.axis()
    cat = tractor.getCatalog()
    for src in cat:
        rd = src.getPosition()
        ok, x, y = targetwcs.radec2pixelxy(rd.ra, rd.dec)
        cc = (0, 1, 0)
        # The -1 presumably converts 1-indexed (FITS) pixel coordinates to
        # the 0-indexed plotted image grid.
        if isinstance(src, PointSource):
            plt.plot(x - 1, y - 1, '+', color=cc, ms=10, mew=1.5)
        else:
            plt.plot(x - 1, y - 1, 'o', mec=cc, mfc='none', ms=10, mew=1.5)
    plt.axis(ax)
    ps.savefig()

    mnmx = -5, 300
    arcsinha = dict(mnmx=mnmx, arcsinh=1)

    rgbmod = []
    rgbmod2 = []
    rgbresids = []
    rgbchisqs = []

    # Per-band histograms of per-pixel chi, clipped into [-10, 10].
    chibins = np.linspace(-10., 10., 200)
    chihist = [np.zeros(len(chibins) - 1, int) for _ in bands]

    wcsW = targetwcs.get_width()
    wcsH = targetwcs.get_height()
    print('Target WCS shape', wcsW, wcsH)

    t0 = Time()
    mods = _map(_get_mod, [(tim, cat) for tim in tims])
    print('Getting model images:', Time() - t0)

    # All coadd FITS files share the target WCS; build its header cards once.
    wcsfn = create_temp()
    targetwcs.write_to(wcsfn)
    hdr = fitsio.read_header(wcsfn)
    os.remove(wcsfn)
    if outdir is None:
        outdir = '.'
    wa = dict(clobber=True, header=hdr)

    orig_wcsxy0 = [tim.wcs.getX0Y0() for tim in tims]
    for iband, band in enumerate(bands):
        coimg = coimgs[iband]
        comod = np.zeros((wcsH, wcsW), np.float32)
        comod2 = np.zeros((wcsH, wcsW), np.float32)
        cochi2 = np.zeros((wcsH, wcsW), np.float32)
        for itim, (tim, mod) in enumerate(zip(tims, mods)):
            if tim.band != band:
                continue

            if plots2:
                plt.clf()
                dimshow(tim.getImage(), **tim.ima)
                plt.title(tim.name)
                ps.savefig()
                plt.clf()
                dimshow(mod, **tim.ima)
                plt.title(tim.name)
                ps.savefig()
                plt.clf()
                dimshow((tim.getImage() - mod) * tim.getInvError(), **imchi)
                plt.title(tim.name)
                ps.savefig()

            R = tim_get_resamp(tim, targetwcs)
            if R is None:
                continue
            Yo, Xo, Yi, Xi = R
            ie = tim.getInvError()
            comod[Yo, Xo] += mod[Yi, Xi]
            noise = _noise_realization(ie)
            comod2[Yo, Xo] += mod[Yi, Xi] + noise[Yi, Xi]
            chi = (tim.getImage()[Yi, Xi] - mod[Yi, Xi]) * ie[Yi, Xi]
            cochi2[Yo, Xo] += chi**2
            # Drop chi == 0 pixels (e.g. zero inverse-error) from the histogram.
            chi = chi[chi != 0.]
            hh, _ = np.histogram(np.clip(chi, -10, 10).ravel(), bins=chibins)
            chihist[iband] += hh

            if writeModels:
                x0, y0 = orig_wcsxy0[itim]
                _write_tim_model(tim, mod, band, brickid, x0, y0)

        # Average the summed models over the number of contributing exposures.
        comod /= np.maximum(cons[iband], 1)
        comod2 /= np.maximum(cons[iband], 1)

        rgbmod.append(comod)
        rgbmod2.append(comod2)
        resid = coimg - comod
        resid[cons[iband] == 0] = np.nan
        rgbresids.append(resid)
        rgbchisqs.append(cochi2)

        for name, img in [('image', coimg), ('model', comod), ('resid', resid),
                          ('chi2', cochi2)]:
            fn = os.path.join(outdir,
                              '%s-coadd-%06i-%s.fits' % (name, brickid, band))
            fitsio.write(fn, img, **wa)
            print('Wrote', fn)

    del cons

    plt.clf()
    dimshow(get_rgb(rgbmod, bands))
    plt.title('Model')
    ps.savefig()

    plt.clf()
    dimshow(get_rgb(rgbmod2, bands))
    plt.title('Model + Noise')
    ps.savefig()

    plt.clf()
    dimshow(get_rgb(rgbresids, bands))
    plt.title('Residuals')
    ps.savefig()

    plt.clf()
    dimshow(get_rgb(rgbresids, bands, mnmx=(-30, 30)))
    plt.title('Residuals (2)')
    ps.savefig()

    plt.clf()
    dimshow(get_rgb(coimgs, bands, **arcsinha))
    plt.title('Image (stretched)')
    ps.savefig()

    plt.clf()
    dimshow(get_rgb(rgbmod2, bands, **arcsinha))
    plt.title('Model + Noise (stretched)')
    ps.savefig()

    # Release the big coadd lists before building the chi^2 image.
    del coimgs
    del rgbresids
    del rgbmod
    del rgbmod2

    chi2img = _chi2_rgb_image(rgbchisqs)
    if chi2img is not None:
        plt.clf()
        dimshow(chi2img)
        plt.title('Chi-squared')
        ps.savefig()

    # Step-style histogram: duplicate bin edges and counts so each bin is flat.
    plt.clf()
    xx = np.repeat(chibins, 2)[1:-1]
    colors = 'grm'
    for i, y in enumerate(chihist):
        plt.plot(xx, np.repeat(np.maximum(0.1, y), 2), '-',
                 color=colors[i % len(colors)])
    plt.xlabel('Chi')
    plt.yticks([])
    plt.axvline(0., color='k', alpha=0.25)
    ps.savefig()

    plt.yscale('log')
    mx = np.max([max(y) for y in chihist])
    plt.ylim(1, mx * 1.05)
    ps.savefig()

    return dict(tims=tims)
def main():
    """Ad-hoc debugging driver: compare a pipeline-catalog model with SDSS.

    Picks the CCD nearest a hard-coded (ra, dec) and cuts out the pixel
    region ``roi``.  It renders the Tractor model of the pipeline-catalog
    sources overlapping that cutout, then a model of SDSS sources in the same
    footprint.  Plots are written as numbered PNG files in the current
    directory.  All inputs (position, ROI, catalog file pattern) are
    hard-coded below.
    """
    decals = Decals()

    # Per-brick pipeline catalogs; the %06i is presumably a brick id.
    catpattern = 'pipebrick-cats/tractor-phot-b%06i.fits'
    ra,dec = 242, 7

    # Region-of-interest, in pixels: x0, x1, y0, y1
    #roi = None
    roi = [500, 1000, 500, 1000]

    if roi is not None:
        x0,x1,y0,y1 = roi

    #expnum = 346623
    #ccdname = 'N12'
    #chips = decals.find_ccds(expnum=expnum, extname=ccdname)
    #print 'Found', len(chips), 'chips for expnum', expnum, 'extname', ccdname
    #if len(chips) != 1:
    #return False

    # Nearest CCD by naive Euclidean distance in (RA, Dec): there is no
    # cos(Dec) factor and no RA wrap-around handling.
    chips = decals.get_ccds()
    D = np.argsort(np.hypot(chips.ra - ra, chips.dec - dec))
    print('Closest chip:', chips[D[0]])
    chips = [chips[D[0]]]

    im = DecamImage(decals, chips[0])
    print('Image:', im)

    # WCS of the full CCD, cut down to the ROI when one is given.
    targetwcs = Sip(im.wcsfn)
    if roi is not None:
        targetwcs = targetwcs.get_subimage(x0, y0, x1-x0, y1-y0)

    r0,r1,d0,d1 = targetwcs.radec_bounds()
    # ~ 30-pixel margin
    margin = 2e-3
    if r0 > r1:
        # RA wrap-around
        # NOTE(review): under Python 2 this list comprehension's loop
        # variable `ra` leaks and overwrites the target RA set above.
        # That is harmless here only because `ra` is not read again below.
        TT = [brick_catalog_for_radec_box(ra,rb, d0-margin,d1+margin,
                                          decals, catpattern)
                for (ra,rb) in [(0, r1+margin), (r0-margin, 360.)]]
        T = merge_tables(TT)
        T._header = TT[0]._header
    else:
        T = brick_catalog_for_radec_box(r0-margin,r1+margin,d0-margin,
                                        d1+margin, decals, catpattern)

    print('Got', len(T), 'catalog entries within range')
    cat = read_fits_catalog(T, T._header)
    print('Got', len(cat), 'catalog objects')

    print('Switching ellipse parameterizations')
    switch_to_soft_ellipses(cat)
    # Drop sources with non-finite parameters.  A FixedCompositeGalaxy whose
    # clipped fracDev is exactly 0 (1) is replaced by a pure Exp (Dev)
    # galaxy built from the matching shape.
    keepcat = []
    for src in cat:
        if not np.all(np.isfinite(src.getParams())):
            print('Src has infinite params:', src)
            continue
        if isinstance(src, FixedCompositeGalaxy):
            f = src.fracDev.getClippedValue()
            if f == 0.:
                src = ExpGalaxy(src.pos, src.brightness, src.shapeExp)
            elif f == 1.:
                src = DevGalaxy(src.pos, src.brightness, src.shapeDev)
        keepcat.append(src)
    cat = keepcat

    slc = None
    if roi is not None:
        slc = slice(y0,y1), slice(x0,x1)
    tim = im.get_tractor_image(slc=slc)
    print('Got', tim)
    # Refit the PSFEx spline model from its saved grid data, set its radius
    # (presumably in pixels) and wrap it as a caching PSF for rendering.
    tim.psfex.fitSavedData(*tim.psfex.splinedata)
    tim.psfex.radius = 20
    tim.psf = CachingPsfEx.fromPsfEx(tim.psfex)
    
    tractor = Tractor([tim], cat)
    print('Created', tractor)

    mod = tractor.getModelImage(0)

    plt.clf()
    dimshow(tim.getImage(), **tim.ima)
    plt.title('Image')
    plt.savefig('1.png')

    plt.clf()
    dimshow(mod, **tim.ima)
    plt.title('Model')
    plt.savefig('2.png')

    
    # Overlay source positions on the model image still in the figure.
    # NOTE(review): unlike stage_fitplots, no -1 offset is applied to the
    # radec2pixelxy output here; markers may be shifted by one pixel.
    ok,x,y = targetwcs.radec2pixelxy([src.getPosition().ra  for src in cat],
                                  [src.getPosition().dec for src in cat])
    ax = plt.axis()
    plt.plot(x, y, 'rx')
    #plt.savefig('3.png')
    plt.axis(ax)
    plt.title('Sources')
    plt.savefig('3.png')
    
    bands = [im.band]
    # Presumably get_sdss_sources reads its photoObj directory from the
    # runbrick module global; point it at the current directory.
    import runbrick
    runbrick.photoobjdir = '.'
    scat,T = get_sdss_sources(bands, targetwcs, local=False)
    print('Got', len(scat), 'SDSS sources in bounds')
    
    # Render the SDSS catalog through the same tim (same PSF and WCS).
    stractor = Tractor([tim], scat)
    print('Created', stractor)
    smod = stractor.getModelImage(0)

    plt.clf()
    dimshow(smod, **tim.ima)
    plt.title('SDSS model')
    plt.savefig('4.png')