def main():
    """Deblend an image that lacks a source, using a footprint that has it.

    Two fake images are built from the same random source positions: the
    first contains all ``nSrc`` sources, the second omits the last (faint,
    0.7 * flux) one.  Detection runs on the first image only, and the
    resulting footprint -- which therefore carries one peak more than the
    second image contains -- is used to deblend the second image.  The fake
    images are written to ``foo.fits`` / ``fooB.fits`` and a portion figure
    to ``test.png``.
    """

    log = pexLog.Log(pexLog.Log.getDefaultLog(), 'foo', pexLog.Log.INFO)

    ny, nx = 256, 256

    fwhm0 = 5.0
    psf = measAlg.DoubleGaussianPsf(21, 21, fwhm0)
    flux = 1.0e6

    # Make two sets of fake data; the second set is missing a source.
    # The last source is the faint one, and it is the one left out below.
    nSrc = 4
    xy = randomCoords(nSrc)
    fluxs = [flux] * (nSrc - 1) + [0.7 * flux]
    mimg = makeFakeImage(nx, ny, xy, fluxs, [3.0 * fwhm0] * nSrc)
    mimg.writeFits("foo.fits")

    # Keep every source but the last one (``nSrc - 4`` would leave none).
    nSrcB = nSrc - 1
    mimgB = makeFakeImage(nx, ny, xy[0:nSrcB], fluxs[0:nSrcB],
                          [3.0 * fwhm0] * nSrcB)
    mimgB.writeFits("fooB.fits")

    # Run the detection on the complete image only; its footprint carries
    # the peak that is absent from mimgB.
    fp = detect(mimg)

    # deblend mimgB (missing a peak) using the fp with the extra peak
    deb = measDeb.deblend(fp, mimgB, psf, fwhm0, verbose=True,
                          rampFluxAtEdge=True, log=log)
    # A single formatted argument prints the same under the Python 2 print
    # statement and the Python 3 print function.
    print("Deblended peaks: %d" % len(deb.peaks))

    fig = makePortionFigure(deb, mimg, mimgB)
    fig.savefig("test.png")
    def testPeakRemoval(self):
        '''
        A simple example: three overlapping blobs (detected as 1
        footprint with three peaks).  Additional peaks are added near
        the blob peaks that should be identified as degenerate.
        '''
        H, W = 100, 100

        fpbb = afwGeom.Box2I(afwGeom.Point2I(0, 0), afwGeom.Point2I(W - 1, H - 1))

        afwimg = afwImage.MaskedImageF(fpbb)
        imgbb = afwimg.getBBox()
        # The getArray() results below are modified in place; the test relies
        # on them being writable views onto afwimg's image/variance planes.
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        blob_fwhm = 10.
        blob_psf = doubleGaussianPsf(99, 99, blob_fwhm, 2.*blob_fwhm, 0.03)

        fakepsf_fwhm = 3.
        fakepsf = gaussianPsf(11, 11, fakepsf_fwhm)

        # Render each blob from blob_psf, clip it to the image bounds, and
        # add it into the image to be deblended.
        XY = [(75., 35.), (75., 65.), (50., 50.)]
        flux = 1e6
        for x, y in XY:
            bim = blob_psf.computeImage(afwGeom.Point2D(x, y))
            bbb = bim.getBBox()
            bbb.clip(imgbb)

            bim = bim.Factory(bim, bbb)
            img[bbb.getMinY():bbb.getMaxY()+1,
                bbb.getMinX():bbb.getMaxX()+1] += flux * bim.getArray()

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(5., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()

        self.assertEqual(len(fps), 1)

        # Add new peaks near to the first peaks that will be degenerate
        fp0 = fps[0]
        for x, y in XY:
            fp0.addPeak(x - 10, y + 6, 10)

        deb = deblend(fp0, afwimg, fakepsf, fakepsf_fwhm, verbose=True, removeDegenerateTemplates=True)

        # Per the docstring, detection yields three peaks (indices 0-2), so
        # the three peaks added above should sit at indices 3-5.
        peaks = deb.deblendedParents[0].peaks
        for i in (3, 4, 5):
            self.assertTrue(peaks[i].degenerate,
                            'peak %d was not flagged as degenerate' % i)
    def testPeakRemoval(self):
        '''
        A simple example: three overlapping blobs (detected as 1
        footprint with three peaks).  Additional peaks are added near
        the blob peaks that should be identified as degenerate.
        '''
        H, W = 100, 100

        fpbb = geom.Box2I(geom.Point2I(0, 0), geom.Point2I(W - 1, H - 1))

        afwimg = afwImage.MaskedImageF(fpbb)
        imgbb = afwimg.getBBox()
        # The getArray() results below are modified in place; the test relies
        # on them being writable views onto afwimg's image/variance planes.
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        blob_fwhm = 10.
        blob_psf = doubleGaussianPsf(99, 99, blob_fwhm, 2.*blob_fwhm, 0.03)

        fakepsf_fwhm = 3.
        fakepsf = gaussianPsf(11, 11, fakepsf_fwhm)

        # Render each blob from blob_psf, clip it to the image bounds, and
        # add it into the image to be deblended.
        XY = [(75., 35.), (75., 65.), (50., 50.)]
        flux = 1e6
        for x, y in XY:
            bim = blob_psf.computeImage(geom.Point2D(x, y))
            bbb = bim.getBBox()
            bbb.clip(imgbb)

            bim = bim.Factory(bim, bbb)
            img[bbb.getMinY():bbb.getMaxY()+1,
                bbb.getMinX():bbb.getMaxX()+1] += flux * bim.getArray()

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(5., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()

        self.assertEqual(len(fps), 1)

        # Add new peaks near to the first peaks that will be degenerate
        fp0 = fps[0]
        for x, y in XY:
            fp0.addPeak(x - 10, y + 6, 10)

        deb = deblend(fp0, afwimg, fakepsf, fakepsf_fwhm, verbose=True, removeDegenerateTemplates=True)

        # Per the docstring, detection yields three peaks (indices 0-2), so
        # the three peaks added above should sit at indices 3-5.
        peaks = deb.deblendedParents[0].peaks
        for i in (3, 4, 5):
            self.assertTrue(peaks[i].degenerate,
                            'peak %d was not flagged as degenerate' % i)
def main():
    """Deblend an image that lacks a source, using a footprint that has it.

    Two fake images are built from the same random source positions: the
    first contains all ``nSrc`` sources, the second omits the last (faint,
    0.7 * flux) one.  Detection runs on the first image only, and the
    resulting footprint -- which therefore carries one peak more than the
    second image contains -- is used to deblend the second image.  The fake
    images are written to ``foo.fits`` / ``fooB.fits`` and a portion figure
    to ``test.png``.
    """

    log = Log.getLogger('foo')
    log.setLevel(Log.INFO)

    ny, nx = 256, 256

    fwhm0 = 5.0
    psf = measAlg.DoubleGaussianPsf(21, 21, fwhm0)
    flux = 1.0e6

    # Make two sets of fake data; the second set is missing a source.
    # The last source is the faint one, and it is the one left out below.
    nSrc = 4
    xy = randomCoords(nSrc)
    fluxs = [flux] * (nSrc - 1) + [0.7 * flux]
    mimg = makeFakeImage(nx, ny, xy, fluxs, [3.0 * fwhm0] * nSrc)
    mimg.writeFits("foo.fits")

    # Keep every source but the last one (``nSrc - 4`` would leave none).
    nSrcB = nSrc - 1
    mimgB = makeFakeImage(nx, ny, xy[0:nSrcB], fluxs[0:nSrcB],
                          [3.0 * fwhm0] * nSrcB)
    mimgB.writeFits("fooB.fits")

    # Run the detection on the complete image only; its footprint carries
    # the peak that is absent from mimgB.
    fp = detect(mimg)

    # deblend mimgB (missing a peak) using the fp with the extra peak
    deb = measDeb.deblend(fp,
                          mimgB,
                          psf,
                          fwhm0,
                          verbose=True,
                          rampFluxAtEdge=True,
                          log=log)
    print("Deblended peaks: ", len(deb.peaks))

    fig = makePortionFigure(deb, mimg, mimgB)
    fig.savefig("test.png")
    def test2(self):
        '''
        A 1-d example, to test the stray-flux assignment.

        The image is a single row: a plateau of 10 between the first and
        last pixels, with the second and second-to-last pixels raised to 20.
        After deblending with one peak near each end, the summed stray flux
        and each child's stray flux are compared with the expectation
        computed below from 1 / (1 + dx**2) weights clipped at ``strayclip``.
        '''
        # ``reduce`` is not a builtin on Python 3; take it from functools.
        from functools import reduce

        H, W = 1, 100

        fpbb = afwGeom.Box2I(afwGeom.Point2I(0, 0),
                             afwGeom.Point2I(W-1, H-1))
        afwimg = afwImage.MaskedImageF(fpbb)
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        y = 0
        img[y, 1:-1] = 10.

        img[0, 1] = 20.
        img[0, -2] = 20.

        fakepsf_fwhm = 1.
        fakepsf = gaussianPsf(1, 1, fakepsf_fwhm)

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(5., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()
        self.assertEqual(len(fps), 1)
        fp = fps[0]

        # WORKAROUND: the detection alg produces ONE peak, at (1,0),
        # rather than two.
        self.assertEqual(len(fp.getPeaks()), 1)
        fp.addPeak(W-2, y, float("NaN"))

        # Change verbose to False to quiet down the meas_deblender.baseline logger
        deb = deblend(fp, afwimg, fakepsf, fakepsf_fwhm, verbose=True,
                      fitPsfs=False)
        # Deblended children of the single parent footprint; the plotting and
        # the assertions below must both read from this same list.
        peaks = deb.deblendedParents[0].peaks

        if doPlot:
            XX = np.arange(W+1).repeat(2)[1:-1]

            plt.clf()
            p1 = plt.plot(XX, img[y, :].repeat(2), 'g-', lw=3, alpha=0.3)

            for dpk in peaks:
                print(dpk)
                port = dpk.fluxPortion.getImage()
                bb = port.getBBox()
                YY = np.zeros(XX.shape)
                YY[bb.getMinX()*2: (bb.getMaxX()+1)*2] = port.getArray()[0, :].repeat(2)
                p2 = plt.plot(XX, YY, 'r-')

                simg = afwImage.ImageF(fpbb)
                dpk.strayFlux.insert(simg)
                p3 = plt.plot(XX, simg.getArray()[y, :].repeat(2), 'b-')

            plt.legend((p1[0], p2[0], p3[0]),
                       ('Parent Flux', 'Child portion', 'Child stray flux'))
            plt.ylim(-2, 22)
            plt.savefig(plotpat % 3)

        strays = []
        for dpk in peaks:
            simg = afwImage.ImageF(fpbb)
            dpk.strayFlux.insert(simg)
            strays.append(simg.getArray())

        ssum = reduce(np.add, strays)

        starget = np.zeros(W)
        starget[2:-2] = 10.

        self.assertFloatsEqual(ssum, starget)

        # Expected per-child stray flux: weights 1 / (1 + dx**2) from each
        # peak position, zeroed where below strayclip of their sum, then used
        # to split the plateau value of 10.
        X = np.arange(W)
        dx1 = X - 1.
        dx2 = X - (W-2)
        f1 = (1. / (1. + dx1**2))
        f2 = (1. / (1. + dx2**2))
        strayclip = 0.001
        fsum = f1 + f2
        f1[f1 < strayclip * fsum] = 0.
        f2[f2 < strayclip * fsum] = 0.

        s1 = f1 / (f1+f2) * 10.
        s2 = f2 / (f1+f2) * 10.

        s1[:2] = 0.
        s2[-2:] = 0.

        if doPlot:
            p4 = plt.plot(XX, s1.repeat(2), 'm-')
            plt.plot(XX, s2.repeat(2), 'm-')

            plt.legend((p1[0], p2[0], p3[0], p4[0]),
                       ('Parent Flux', 'Child portion', 'Child stray flux',
                        'Expected stray flux'))
            plt.ylim(-2, 22)
            plt.savefig(plotpat % 4)

        # test abs diff
        d = np.max(np.abs(s1 - strays[0]))
        self.assertLess(d, 1e-6)
        d = np.max(np.abs(s2 - strays[1]))
        self.assertLess(d, 1e-6)

        # test relative diff
        self.assertLess(np.max(np.abs(s1 - strays[0])/np.maximum(1e-3, s1)), 1e-6)
        self.assertLess(np.max(np.abs(s2 - strays[1])/np.maximum(1e-3, s2)), 1e-6)
    def test1(self):
        '''
        In this test, we create a test image containing two blobs, one
        of which is truncated by the edge of the image.

        We run the detection code to get realistic peaks and
        footprints.

        We then test out the different edge treatments and assert that
        they do what they claim.  We also make plots, tests/edge*.png
        '''
        # NOTE: diagnostic prints pass a single pre-formatted string so the
        # output is the same under the Python 2 print statement and the
        # Python 3 print function.

        # Create fake image...
        H, W = 100, 100
        fpbb = afwGeom.Box2I(afwGeom.Point2I(0, 0),
                             afwGeom.Point2I(W-1, H-1))
        afwimg = afwImage.MaskedImageF(fpbb)
        imgbb = afwimg.getBBox()
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        blob_fwhm = 15.
        blob_psf = doubleGaussianPsf(201, 201, blob_fwhm, 3.*blob_fwhm, 0.03)
        fakepsf_fwhm = 5.
        S = int(np.ceil(fakepsf_fwhm * 2.)) * 2 + 1
        print('S %s' % S)
        fakepsf = gaussianPsf(S, S, fakepsf_fwhm)

        # Create and save blob images, and add to image to deblend.
        blobimgs = []
        XY = [(50., 50.), (90., 50.)]
        flux = 1e6
        for x, y in XY:
            bim = blob_psf.computeImage(afwGeom.Point2D(x, y))
            bbb = bim.getBBox()
            bbb.clip(imgbb)

            bim = bim.Factory(bim, bbb)
            bim2 = bim.getArray()

            blobimg = np.zeros_like(img)
            blobimg[bbb.getMinY():bbb.getMaxY()+1,
                    bbb.getMinX():bbb.getMaxX()+1] += flux * bim2
            blobimgs.append(blobimg)

            img[bbb.getMinY():bbb.getMaxY()+1,
                bbb.getMinX():bbb.getMaxX()+1] += flux * bim2

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(10., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()
        print('found %s footprints' % len(fps))

        # set EDGE bit on edge pixels.
        margin = 5
        lo = imgbb.getMin()
        lo.shift(afwGeom.Extent2I(margin, margin))
        hi = imgbb.getMax()
        hi.shift(afwGeom.Extent2I(-margin, -margin))
        goodbbox = afwGeom.Box2I(lo, hi)
        print('Good bbox for setting EDGE pixels: %s' % (goodbbox,))
        print('image bbox: %s' % (imgbb,))
        edgebit = afwimg.getMask().getPlaneBitMask("EDGE")
        print('edgebit: %s' % (edgebit,))
        measAlg.SourceDetectionTask.setEdgeBits(afwimg, goodbbox, edgebit)

        # Disabled debugging aid: dump the mask and each set mask bit to PNGs.
        if False:
            plt.clf()
            plt.imshow(afwimg.getMask().getArray(),
                       interpolation='nearest', origin='lower')
            plt.colorbar()
            plt.title('Mask')
            plt.savefig('mask.png')

            M = afwimg.getMask().getArray()
            for bit in range(32):
                mbit = (1 << bit)
                if not np.any(M & mbit):
                    continue
                plt.clf()
                plt.imshow(M & mbit,
                           interpolation='nearest', origin='lower')
                plt.colorbar()
                plt.title('Mask bit %i (0x%x)' % (bit, mbit))
                plt.savefig('mask-%02i.png' % bit)

        for fp in fps:
            print('peaks: %s' % len(fp.getPeaks()))
            for pk in fp.getPeaks():
                print('   %s %s' % (pk.getIx(), pk.getIy()))
        self.assertEqual(len(fps), 1)
        fp = fps[0]
        self.assertEqual(len(fp.getPeaks()), 2)

        ima = dict(interpolation='nearest', origin='lower',  # cmap='gray',
                   cmap='jet',
                   vmin=0, vmax=400)

        for j, (tt, kwa) in enumerate([
                ('No edge treatment', dict()),
                ('Ramp by PSF', dict(rampFluxAtEdge=True)),
                ('No clip at edge', dict(patchEdges=True)),
        ]):
            deb = deblend(fp, afwimg, fakepsf, fakepsf_fwhm, verbose=True,
                          **kwa)

            parent_img = afwImage.ImageF(fpbb)
            afwDet.copyWithinFootprintImage(fp, afwimg.getImage(), parent_img)

            X = [x for x, y in XY]
            Y = [y for x, y in XY]
            PX = [pk.getIx() for pk in fp.getPeaks()]
            PY = [pk.getIy() for pk in fp.getPeaks()]

            # Grab 1-d slices to make assertion about.
            symms = []
            monos = []
            symm1ds = []
            mono1ds = []
            # Integer division: ``H/2`` is a float on Python 3 and cannot be
            # used as a numpy index.
            yslice = H // 2
            parent1d = img[yslice, :]
            for i, dpk in enumerate(deb.peaks):
                symm = dpk.origTemplate
                symms.append(symm)

                bbox = symm.getBBox()
                x0, y0 = bbox.getMinX(), bbox.getMinY()
                im = symm.getArray()
                h, w = im.shape
                oned = np.zeros(W)
                oned[x0: x0+w] = im[yslice-y0, :]
                symm1ds.append(oned)

                mono = afwImage.ImageF(fpbb)
                afwDet.copyWithinFootprintImage(dpk.templateFootprint,
                                                dpk.templateImage, mono)
                monos.append(mono)

                im = mono.getArray()
                bbox = mono.getBBox()
                x0, y0 = bbox.getMinX(), bbox.getMinY()
                h, w = im.shape
                oned = np.zeros(W)
                oned[x0: x0+w] = im[yslice-y0, :]
                mono1ds.append(oned)

            for i, (symm, mono) in enumerate(zip(symm1ds, mono1ds)):
                # for the first two cases, the basic symmetric
                # template for the second source drops to zero at <
                # ~75 where the symmetric part is outside the
                # footprint.
                if i == 1 and j in [0, 1]:
                    self.assertTrue(np.all(symm[:74] == 0))
                if i == 1 and j == 2:
                    # For the third case, the 'symm' template gets
                    # "patched" with the parent's value
                    self.assertTrue(np.all((symm == parent1d)[:74]))

                if i == 1 and j == 0:
                    # No edge handling: mono template == 0
                    self.assertTrue(np.all(mono[:74] == 0))
                if i == 1 and j == 1:
                    # ramp by psf: zero up to ~65, ramps up
                    self.assertTrue(np.all(mono[:64] == 0))
                    self.assertTrue(np.any(mono[65:74] > 0))
                    self.assertTrue(np.all(np.diff(mono)[60:80] >= 0.))
                if i == 1 and j == 2:
                    # no edge clipping: profile is monotonic and positive.
                    self.assertTrue(np.all(np.diff(mono)[:85] >= 0.))
                    self.assertTrue(np.all(mono[:85] > 0.))

            if not doPlot:
                continue

            plt.clf()
            p1 = plt.plot(parent1d, 'b-', lw=3, alpha=0.5)
            for i, (symm, mono) in enumerate(zip(symm1ds, mono1ds)):
                p2 = plt.plot(symm, 'r-', lw=2, alpha=0.7)
                p3 = plt.plot(mono, 'g-')
            plt.legend((p1[0], p2[0], p3[0]), ('Parent', 'Symm template', 'Mono template'),
                       loc='upper left')
            plt.title('1-d slice: %s' % tt)
            fn = plotpat % (2*j+0)
            plt.savefig(fn)
            print('Wrote %s' % fn)

            def myimshow(*args, **kwargs):
                # Shade the full image extent, draw the image over it, and
                # hide the tick labels.
                x0, x1, y0, y1 = imExt(afwimg)
                plt.fill([x0, x0, x1, x1, x0], [y0, y1, y1, y0, y0], color=(1, 1, 0.8),
                         zorder=20)
                plt.imshow(*args, zorder=25, **kwargs)
                plt.xticks([])
                plt.yticks([])
                plt.axis(imExt(afwimg))

            plt.clf()

            pa = dict(color='m', marker='.', linestyle='None', zorder=30)

            R, C = 3, 6
            plt.subplot(R, C, (2*C) + 1)
            myimshow(img, **ima)
            ax = plt.axis()
            plt.plot(X, Y, **pa)
            plt.axis(ax)
            plt.title('Image')

            plt.subplot(R, C, (2*C) + 2)
            myimshow(parent_img.getArray(), **ima)
            ax = plt.axis()
            plt.plot(PX, PY, **pa)
            plt.axis(ax)
            plt.title('Footprint')

            sumimg = None
            for i, dpk in enumerate(deb.peaks):

                plt.subplot(R, C, i*C + 1)
                myimshow(blobimgs[i], **ima)
                ax = plt.axis()
                plt.plot(PX[i], PY[i], **pa)
                plt.axis(ax)
                plt.title('true')

                plt.subplot(R, C, i*C + 2)
                t = dpk.origTemplate
                myimshow(t.getArray(), extent=imExt(t), **ima)
                ax = plt.axis()
                plt.plot(PX[i], PY[i], **pa)
                plt.axis(ax)
                plt.title('symm')

                # monotonic template
                mimg = afwImage.ImageF(fpbb)
                afwDet.copyWithinFootprintImage(dpk.templateFootprint,
                                                dpk.templateImage, mimg)

                plt.subplot(R, C, i*C + 3)
                myimshow(mimg.getArray(), extent=imExt(mimg), **ima)
                ax = plt.axis()
                plt.plot(PX[i], PY[i], **pa)
                plt.axis(ax)
                plt.title('monotonic')

                plt.subplot(R, C, i*C + 4)
                port = dpk.fluxPortion.getImage()
                myimshow(port.getArray(), extent=imExt(port), **ima)
                plt.title('portion')
                ax = plt.axis()
                plt.plot(PX[i], PY[i], **pa)
                plt.axis(ax)

                if dpk.strayFlux is not None:
                    simg = afwImage.ImageF(fpbb)
                    dpk.strayFlux.insert(simg)

                    plt.subplot(R, C, i*C + 5)
                    myimshow(simg.getArray(), **ima)
                    plt.title('stray')
                    ax = plt.axis()
                    plt.plot(PX, PY, **pa)
                    plt.axis(ax)

                himg2 = afwImage.ImageF(fpbb)
                portion = dpk.getFluxPortion()
                portion.insert(himg2)

                if sumimg is None:
                    sumimg = himg2.getArray().copy()
                else:
                    sumimg += himg2.getArray()

                plt.subplot(R, C, i*C + 6)
                myimshow(himg2.getArray(), **ima)
                plt.title('portion+stray')
                ax = plt.axis()
                plt.plot(PX, PY, **pa)
                plt.axis(ax)

            plt.subplot(R, C, (2*C) + C)
            myimshow(sumimg, **ima)
            ax = plt.axis()
            plt.plot(X, Y, **pa)
            plt.axis(ax)
            plt.title('Sum of deblends')

            plt.suptitle(tt)
            fn = plotpat % (2*j + 1)
            plt.savefig(fn)
            print('Wrote %s' % fn)
Beispiel #7
0
    def __call__(self, source, exposure):
        """Deblend the footprint of ``source`` using its brightest and faintest peaks.

        If the footprint has fewer than two peaks, an unmodified copy of
        ``source`` is returned.  Otherwise the footprint's peak list is
        trimmed **in place** to the peaks with the lowest and highest
        ``getPeakValue()``, and the footprint is deblended with the
        exposure's PSF.  For each deblended peak, the PSF-fit centroid, the
        chi^2 per degree of freedom, and ``psfFitFlux`` scaled by the sum of
        the peak's template image are logged at INFO level.  Each log line
        carries a ``pos``/``neg`` suffix chosen by the sign of ``psfFitFlux``.

        Parameters
        ----------
        source : source record
            Record whose footprint is deblended; its footprint's peak list is
            modified.
        exposure : exposure
            Supplies the masked image and the PSF.

        Returns
        -------
        deblendedSource : source record
            Copy of ``source`` whose parent is set to ``source.getId()`` and
            whose footprint peak list is refilled with the deblended peaks'
            ``peak`` records.
        """
        fp = source.getFootprint()
        peaks = fp.getPeaks()

        # if fewer than 2 peaks, just return a copy of the source.  This is
        # checked before any PSF or pixel work so such sources cost nothing.
        if len(peaks) < 2:
            return source.getTable().copyRecord(source)

        psf = exposure.getPsf()
        psfSigPix = psf.computeShape().getDeterminantRadius()
        psfFwhmPix = psfSigPix * self.sigma2fwhm

        # make sure you only deblend 2 peaks; take the brightest and faintest.
        # Sort on the peak value alone: sorting (value, peak) tuples would
        # fall back to comparing the peak records themselves on ties.
        speaks = sorted(peaks, key=lambda p: p.getPeakValue())
        dpeaks = [speaks[0], speaks[-1]]

        # and only set these peaks in the footprint (peaks is mutable)
        peaks.clear()
        for peak in dpeaks:
            peaks.append(peak)

        if True:
            # Call top-level deblend task
            fpres = deblendBaseline.deblend(fp,
                                            exposure.getMaskedImage(),
                                            psf,
                                            psfFwhmPix,
                                            log=self.log,
                                            psfChisqCut1=self.psfChisqCut1,
                                            psfChisqCut2=self.psfChisqCut2,
                                            psfChisqCut2b=self.psfChisqCut2b)
        else:
            # Call lower-level _fit_psf task.  The footprint mask, pixel
            # cutout and PSF cache are only needed on this path.
            fbb = fp.getBBox()
            fmask = afwImage.Mask(fbb)
            fmask.setXY0(fbb.getMinX(), fbb.getMinY())
            fp.spans.setMask(fmask, 1)
            subimage = afwImage.ExposureF(exposure, bbox=fbb, deep=True)
            cpsf = deblendBaseline.CachingPsf(psf)
            # Positions aligned index-for-index with dpeaks; taking them from
            # the untrimmed peak list would pair each peak with the wrong
            # position.
            peaksF = [pk.getF() for pk in dpeaks]

            # Prepare results structure
            fpres = deblendBaseline.DeblenderResult(fp,
                                                    exposure.getMaskedImage(),
                                                    psf, psfFwhmPix, self.log)

            for pki, (pk, pkres, pkF) in enumerate(
                    zip(dpeaks, fpres.deblendedParents[0].peaks, peaksF)):
                self.log.debug('Peak %i', pki)
                deblendBaseline._fitPsf(
                    fp, fmask, pk, pkF, pkres, fbb, dpeaks, peaksF, self.log,
                    cpsf, psfFwhmPix,
                    subimage.getMaskedImage().getImage(),
                    subimage.getMaskedImage().getVariance(), self.psfChisqCut1,
                    self.psfChisqCut2, self.psfChisqCut2b)

        deblendedSource = source.getTable().copyRecord(source)
        deblendedSource.setParent(source.getId())
        # NOTE(review): copyRecord may share the Footprint object with
        # ``source``; if so, clearing this list also clears the parent's
        # peaks -- confirm.
        peakList = deblendedSource.getFootprint().getPeaks()
        peakList.clear()

        for peak in fpres.deblendedParents[0].peaks:
            suffix = "pos" if peak.psfFitFlux > 0 else "neg"
            c = peak.psfFitCenter
            self.log.info("deblended.centroid.dipole.psf.%s %f %f", suffix,
                          c[0], c[1])
            self.log.info("deblended.chi2dof.dipole.%s %f", suffix,
                          peak.psfFitChisq / peak.psfFitDof)
            self.log.info(
                "deblended.flux.dipole.psf.%s %f", suffix,
                peak.psfFitFlux * np.sum(peak.templateImage.getArray()))
            peakList.append(peak.peak)
        return deblendedSource
Beispiel #8
0
    def test1(self):
        '''
        In this test, we create a test image containing two blobs, one
        of which is truncated by the edge of the image.

        We run the detection code to get realistic peaks and
        footprints.

        We then test out the different edge treatments and assert that
        they do what they claim.  We also make plots, tests/edge*.png
        '''

        # Create fake image...
        H, W = 100, 100
        fpbb = geom.Box2I(geom.Point2I(0, 0),
                          geom.Point2I(W-1, H-1))
        afwimg = afwImage.MaskedImageF(fpbb)
        imgbb = afwimg.getBBox()
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        blob_fwhm = 15.
        blob_psf = doubleGaussianPsf(201, 201, blob_fwhm, 3.*blob_fwhm, 0.03)
        fakepsf_fwhm = 5.
        S = int(np.ceil(fakepsf_fwhm * 2.)) * 2 + 1
        print('S', S)
        fakepsf = gaussianPsf(S, S, fakepsf_fwhm)

        # Create and save blob images, and add to image to deblend.
        blobimgs = []
        XY = [(50., 50.), (90., 50.)]
        flux = 1e6
        for x, y in XY:
            bim = blob_psf.computeImage(geom.Point2D(x, y))
            bbb = bim.getBBox()
            bbb.clip(imgbb)

            bim = bim.Factory(bim, bbb)
            bim2 = bim.getArray()

            blobimg = np.zeros_like(img)
            blobimg[bbb.getMinY():bbb.getMaxY()+1,
                    bbb.getMinX():bbb.getMaxX()+1] += flux * bim2
            blobimgs.append(blobimg)

            img[bbb.getMinY():bbb.getMaxY()+1,
                bbb.getMinX():bbb.getMaxX()+1] += flux * bim2

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(10., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()
        print('found', len(fps), 'footprints')

        # set EDGE bit on edge pixels.
        margin = 5
        lo = imgbb.getMin()
        lo.shift(geom.Extent2I(margin, margin))
        hi = imgbb.getMax()
        hi.shift(geom.Extent2I(-margin, -margin))
        goodbbox = geom.Box2I(lo, hi)
        print('Good bbox for setting EDGE pixels:', goodbbox)
        print('image bbox:', imgbb)
        edgebit = afwimg.getMask().getPlaneBitMask("EDGE")
        print('edgebit:', edgebit)
        measAlg.SourceDetectionTask.setEdgeBits(afwimg, goodbbox, edgebit)

        if False:
            plt.clf()
            plt.imshow(afwimg.getMask().getArray(),
                       interpolation='nearest', origin='lower')
            plt.colorbar()
            plt.title('Mask')
            plt.savefig('mask.png')

            M = afwimg.getMask().getArray()
            for bit in range(32):
                mbit = (1 << bit)
                if not np.any(M & mbit):
                    continue
                plt.clf()
                plt.imshow(M & mbit,
                           interpolation='nearest', origin='lower')
                plt.colorbar()
                plt.title('Mask bit %i (0x%x)' % (bit, mbit))
                plt.savefig('mask-%02i.png' % bit)

        for fp in fps:
            print('peaks:', len(fp.getPeaks()))
            for pk in fp.getPeaks():
                print('  ', pk.getIx(), pk.getIy())
        assert(len(fps) == 1)
        fp = fps[0]
        assert(len(fp.getPeaks()) == 2)

        ima = dict(interpolation='nearest', origin='lower',  # cmap='gray',
                   cmap='jet',
                   vmin=0, vmax=400)

        for j, (tt, kwa) in enumerate([
            ('No edge treatment', dict()),
            ('Ramp by PSF', dict(rampFluxAtEdge=True)),
            ('No clip at edge', dict(patchEdges=True)),
        ]):
            # print 'Deblending...'
            # Change verbose to False to quiet down the meas_deblender.baseline logger
            deb = deblend(fp, afwimg, fakepsf, fakepsf_fwhm, verbose=True,
                          **kwa)
            # print 'Result:', deb
            # print len(deb.peaks), 'deblended peaks'

            parent_img = afwImage.ImageF(fpbb)
            fp.spans.copyImage(afwimg.getImage(), parent_img)

            X = [x for x, y in XY]
            Y = [y for x, y in XY]
            PX = [pk.getIx() for pk in fp.getPeaks()]
            PY = [pk.getIy() for pk in fp.getPeaks()]

            # Grab 1-d slices to make assertion about.
            symms = []
            monos = []
            symm1ds = []
            mono1ds = []
            yslice = H//2
            parent1d = img[yslice, :]
            for i, dpk in enumerate(deb.deblendedParents[0].peaks):
                symm = dpk.origTemplate
                symms.append(symm)

                bbox = symm.getBBox()
                x0, y0 = bbox.getMinX(), bbox.getMinY()
                im = symm.getArray()
                h, w = im.shape
                oned = np.zeros(W)
                oned[x0: x0+w] = im[yslice-y0, :]
                symm1ds.append(oned)

                mono = afwImage.ImageF(fpbb)
                dpk.templateFootprint.spans.copyImage(dpk.templateImage, mono)
                monos.append(mono)

                im = mono.getArray()
                bbox = mono.getBBox()
                x0, y0 = bbox.getMinX(), bbox.getMinY()
                h, w = im.shape
                oned = np.zeros(W)
                oned[x0: x0+w] = im[yslice-y0, :]
                mono1ds.append(oned)

            for i, (symm, mono) in enumerate(zip(symm1ds, mono1ds)):
                # for the first two cases, the basic symmetric
                # template for the second source drops to zero at <
                # ~75 where the symmetric part is outside the
                # footprint.
                if i == 1 and j in [0, 1]:
                    self.assertFloatsEqual(symm[:74], 0.0)
                if i == 1 and j == 2:
                    # For the third case, the 'symm' template gets
                    # "patched" with the parent's value
                    self.assertFloatsEqual(symm[:74], parent1d[:74])

                if i == 1 and j == 0:
                    # No edge handling: mono template == 0
                    self.assertFloatsEqual(mono[:74], 0.0)
                if i == 1 and j == 1:
                    # ramp by psf: zero up to ~65, ramps up
                    self.assertFloatsEqual(mono[:64], 0.0)
                    self.assertTrue(np.any(mono[65:74] > 0))
                    self.assertTrue(np.all(np.diff(mono)[60:80] >= 0.))
                if i == 1 and j == 2:
                    # no edge clipping: profile is monotonic and positive.
                    self.assertTrue(np.all(np.diff(mono)[:85] >= 0.))
                    self.assertTrue(np.all(mono[:85] > 0.))

            if not doPlot:
                continue

            plt.clf()
            p1 = plt.plot(parent1d, 'b-', lw=3, alpha=0.5)
            for i, (symm, mono) in enumerate(zip(symm1ds, mono1ds)):
                p2 = plt.plot(symm, 'r-', lw=2, alpha=0.7)
                p3 = plt.plot(mono, 'g-')
            plt.legend((p1[0], p2[0], p3[0]), ('Parent', 'Symm template', 'Mono template'),
                       loc='upper left')
            plt.title('1-d slice: %s' % tt)
            fn = plotpat % (2*j+0)
            plt.savefig(fn)
            print('Wrote', fn)

            def myimshow(*args, **kwargs):
                x0, x1, y0, y1 = imExt(afwimg)
                plt.fill([x0, x0, x1, x1, x0], [y0, y1, y1, y0, y0], color=(1, 1, 0.8),
                         zorder=20)
                plt.imshow(*args, zorder=25, **kwargs)
                plt.xticks([])
                plt.yticks([])
                plt.axis(imExt(afwimg))

            plt.clf()

            pa = dict(color='m', marker='.', linestyle='None', zorder=30)

            R, C = 3, 6
            plt.subplot(R, C, (2*C) + 1)
            myimshow(img, **ima)
            ax = plt.axis()
            plt.plot(X, Y, **pa)
            plt.axis(ax)
            plt.title('Image')

            plt.subplot(R, C, (2*C) + 2)
            myimshow(parent_img.getArray(), **ima)
            ax = plt.axis()
            plt.plot(PX, PY, **pa)
            plt.axis(ax)
            plt.title('Footprint')

            sumimg = None
            for i, dpk in enumerate(deb.deblendedParents[0].peaks):

                plt.subplot(R, C, i*C + 1)
                myimshow(blobimgs[i], **ima)
                ax = plt.axis()
                plt.plot(PX[i], PY[i], **pa)
                plt.axis(ax)
                plt.title('true')

                plt.subplot(R, C, i*C + 2)
                t = dpk.origTemplate
                myimshow(t.getArray(), extent=imExt(t), **ima)
                ax = plt.axis()
                plt.plot(PX[i], PY[i], **pa)
                plt.axis(ax)
                plt.title('symm')

                # monotonic template
                mimg = afwImage.ImageF(fpbb)
                afwDet.copyWithinFootprintImage(dpk.templateFootprint,
                                                dpk.templateImage, mimg)

                plt.subplot(R, C, i*C + 3)
                myimshow(mimg.getArray(), extent=imExt(mimg), **ima)
                ax = plt.axis()
                plt.plot(PX[i], PY[i], **pa)
                plt.axis(ax)
                plt.title('monotonic')

                plt.subplot(R, C, i*C + 4)
                port = dpk.fluxPortion.getImage()
                myimshow(port.getArray(), extent=imExt(port), **ima)
                plt.title('portion')
                ax = plt.axis()
                plt.plot(PX[i], PY[i], **pa)
                plt.axis(ax)

                if dpk.strayFlux is not None:
                    simg = afwImage.ImageF(fpbb)
                    dpk.strayFlux.insert(simg)

                    plt.subplot(R, C, i*C + 5)
                    myimshow(simg.getArray(), **ima)
                    plt.title('stray')
                    ax = plt.axis()
                    plt.plot(PX, PY, **pa)
                    plt.axis(ax)

                himg2 = afwImage.ImageF(fpbb)
                portion = dpk.getFluxPortion()
                portion.insert(himg2)

                if sumimg is None:
                    sumimg = himg2.getArray().copy()
                else:
                    sumimg += himg2.getArray()

                plt.subplot(R, C, i*C + 6)
                myimshow(himg2.getArray(), **ima)
                plt.title('portion+stray')
                ax = plt.axis()
                plt.plot(PX, PY, **pa)
                plt.axis(ax)

            plt.subplot(R, C, (2*C) + C)
            myimshow(sumimg, **ima)
            ax = plt.axis()
            plt.plot(X, Y, **pa)
            plt.axis(ax)
            plt.title('Sum of deblends')

            plt.suptitle(tt)
            fn = plotpat % (2*j + 1)
            plt.savefig(fn)
            print('Wrote', fn)
    def deblend(self, exposure, srcs, psf):
        """!
        Deblend every multi-peak parent source in a catalog, appending the children to it.

        For each source whose footprint has two or more peaks and that is flagged by
        neither isLargeFootprint nor isMasked, run the baseline deblender and append one
        child record per deblended peak to srcs.  Per-parent fields (tooBig, masked,
        tooManyPeaks, deblendSkipped, nChild, and deblendFailed when config.catchFailures
        is set) are written to the parent record, and the parent footprint is grown in
        place to include all of its children's footprints.

        @param[in]     exposure Exposure to process
        @param[in,out] srcs     SourceCatalog containing sources detected on this exposure;
                                child records are appended to it via srcs.addNew().
        @param[in]     psf      PSF

        @return None
        """
        self.log.info("Deblending %d sources" % len(srcs))

        from lsst.meas.deblender.baseline import deblend

        # find the median stdev in the image...
        # sigma1 = sqrt(median of the variance plane), using a StatisticsControl whose
        # AND-mask is built from self.config.maskPlanes so flagged pixels are ignored.
        mi = exposure.getMaskedImage()
        statsCtrl = afwMath.StatisticsControl()
        statsCtrl.setAndMask(mi.getMask().getPlaneBitMask(self.config.maskPlanes))
        stats = afwMath.makeStatistics(mi.getVariance(), mi.getMask(), afwMath.MEDIAN, statsCtrl)
        sigma1 = math.sqrt(stats.getValue(afwMath.MEDIAN))
        self.log.logdebug('sigma1: %g' % sigma1)

        # n0: catalog length before any children are appended (for the summary log).
        # nparents: parents handed to the deblender, including ones whose deblend fails.
        n0 = len(srcs)
        nparents = 0
        # NOTE(review): children are appended to srcs while it is being enumerated;
        # whether the new records are visited depends on the catalog iterator -- confirm.
        for i,src in enumerate(srcs):
            #t0 = time.clock()

            fp = src.getFootprint()
            pks = fp.getPeaks()

            # Since we use the first peak for the parent object, we should propagate its flags
            # to the parent source.
            src.assign(pks[0], self.peakSchemaMapper)

            # Single-peak sources have nothing to deblend.
            if len(pks) < 2:
                continue

            # Oversized or masked parents are flagged and handed to skipParent
            # instead of being deblended.
            if self.isLargeFootprint(fp):
                src.set(self.tooBigKey, True)
                self.skipParent(src, mi.getMask())
                self.log.logdebug('Parent %i: skipping large footprint' % (int(src.getId()),))
                continue
            if self.isMasked(fp, exposure.getMaskedImage().getMask()):
                src.set(self.maskedKey, True)
                self.skipParent(src, mi.getMask())
                self.log.logdebug('Parent %i: skipping masked footprint' % (int(src.getId()),))
                continue

            nparents += 1
            bb = fp.getBBox()
            psf_fwhm = self._getPsfFwhm(psf, bb)

            self.log.logdebug('Parent %i: deblending %i peaks' % (int(src.getId()), len(pks)))

            self.preSingleDeblendHook(exposure, srcs, i, fp, psf, psf_fwhm, sigma1)
            # Catalog length after the pre-hook; passed to the post-hook.
            npre = len(srcs)

            # This should really be set in deblend, but deblend doesn't have access to the src
            src.set(self.tooManyPeaksKey, len(fp.getPeaks()) > self.config.maxNumberOfPeaks)

            try:
                res = deblend(
                    fp, mi, psf, psf_fwhm, sigma1=sigma1,
                    psfChisqCut1 = self.config.psfChisq1,
                    psfChisqCut2 = self.config.psfChisq2,
                    psfChisqCut2b= self.config.psfChisq2b,
                    maxNumberOfPeaks=self.config.maxNumberOfPeaks,
                    strayFluxToPointSources=self.config.strayFluxToPointSources,
                    assignStrayFlux=self.config.assignStrayFlux,
                    # Stray flux must be found whenever it is to be assigned.
                    findStrayFlux=(self.config.assignStrayFlux or self.config.findStrayFlux),
                    strayFluxAssignment=self.config.strayFluxRule,
                    # config.edgeHandling selects one of the two edge options.
                    rampFluxAtEdge=(self.config.edgeHandling == 'ramp'),
                    patchEdges=(self.config.edgeHandling == 'noclip'),
                    tinyFootprintSize=self.config.tinyFootprintSize,
                    clipStrayFluxFraction=self.config.clipStrayFluxFraction,
                    )
                if self.config.catchFailures:
                    src.set(self.deblendFailedKey, False)
            except Exception as e:
                # With catchFailures: log, flag the parent, print the traceback and move
                # on to the next source; otherwise re-raise.
                if self.config.catchFailures:
                    self.log.warn("Unable to deblend source %d: %s" % (src.getId(), e))
                    src.set(self.deblendFailedKey, True)
                    import traceback
                    traceback.print_exc()
                    continue
                else:
                    raise

            kids = []
            nchild = 0
            for j, peak in enumerate(res.peaks):
                heavy = peak.getFluxPortion()
                # A peak fails when it has no flux portion or the deblender marked it skip.
                if heavy is None or peak.skip:
                    src.set(self.deblendSkippedKey, True)
                    if not self.config.propagateAllPeaks:
                        # Don't care
                        continue
                    # We need to preserve the peak: make sure we have enough info to create a minimal child src
                    self.log.logdebug("Peak at (%i,%i) failed.  Using minimal default info for child." %
                                      (pks[j].getIx(), pks[j].getIy()))
                    if heavy is None:
                        # copy the full footprint and strip out extra peaks
                        # (pixel values come from a freshly constructed MaskedImageF,
                        # presumably zero-filled -- confirm)
                        foot = afwDet.Footprint(src.getFootprint())
                        peakList = foot.getPeaks()
                        peakList.clear()
                        peakList.append(peak.peak)
                        zeroMimg = afwImage.MaskedImageF(foot.getBBox())
                        heavy = afwDet.makeHeavyFootprint(foot, zeroMimg)
                    # Fill in missing PSF-fit values so the child fields below can be set.
                    if peak.deblendedAsPsf:
                        if peak.psfFitFlux is None:
                            peak.psfFitFlux = 0.0
                        if peak.psfFitCenter is None:
                            peak.psfFitCenter = (peak.peak.getIx(), peak.peak.getIy())

                # Each child footprint must carry exactly one peak (an assert, so it is
                # not checked under python -O).
                assert(len(heavy.getPeaks()) == 1)

                # NOTE(review): this resets the parent's deblendSkipped flag for every peak
                # that produces a child -- including the peak just flagged above when
                # propagateAllPeaks is set -- so the final value reflects only the last
                # processed peak.  Confirm this is intended.
                src.set(self.deblendSkippedKey, False)
                child = srcs.addNew(); nchild += 1
                child.assign(heavy.getPeaks()[0], self.peakSchemaMapper)
                child.setParent(src.getId())
                child.setFootprint(heavy)
                child.set(self.psfKey, peak.deblendedAsPsf)
                child.set(self.hasStrayFluxKey, peak.strayFlux is not None)
                if peak.deblendedAsPsf:
                    (cx,cy) = peak.psfFitCenter
                    child.set(self.psfCenterKey, afwGeom.Point2D(cx, cy))
                    child.set(self.psfFluxKey, peak.psfFitFlux)
                child.set(self.deblendRampedTemplateKey, peak.hasRampedTemplate)
                child.set(self.deblendPatchedTemplateKey, peak.patched)
                kids.append(child)

            # Child footprints may extend beyond the full extent of their parent's which
            # results in a failure of the replace-by-noise code to reinstate these pixels
            # to their original values.  The following updates the parent footprint
            # in-place to ensure it contains the full union of itself and all of its
            # children's footprints.
            src.getFootprint().include([child.getFootprint() for child in kids])

            src.set(self.nChildKey, nchild)

            self.postSingleDeblendHook(exposure, srcs, i, npre, kids, fp, psf, psf_fwhm, sigma1, res)
            #print 'Deblending parent id', src.getId(), 'took', time.clock() - t0


        n1 = len(srcs)
        self.log.info('Deblended: of %i sources, %i were deblended, creating %i children, total %i sources'
                      % (n0, nparents, n1-n0, n1))
    def deblend(self, exposure, srcs, psf):
        """Deblend every multi-peak parent in a catalog, appending the children to it.

        For each source whose footprint has at least two peaks, run the baseline
        deblender and append one child record per in-bounds peak to ``srcs``.
        The parent's deblendFailed, tooManyPeaks and nChild fields are set.
        Deblender exceptions are logged and flag the parent as failed.  They are
        not re-raised.

        @param[in]     exposure Exposure to process
        @param[in,out] srcs     SourceCatalog containing sources detected on this exposure;
                                child records are appended to it via srcs.addNew().
        @param[in]     psf      PSF

        @return None
        """
        self.log.info("Deblending %d sources" % len(srcs))

        from lsst.meas.deblender.baseline import deblend
        import lsst.meas.algorithms as measAlg

        # sigma1 = sqrt(median of the variance plane).  No StatisticsControl is
        # passed, so no mask planes are explicitly excluded here.
        mi = exposure.getMaskedImage()
        stats = afwMath.makeStatistics(mi.getVariance(), mi.getMask(),
                                       afwMath.MEDIAN)
        sigma1 = math.sqrt(stats.getValue(afwMath.MEDIAN))

        n0 = len(srcs)
        nparents = 0
        for i, src in enumerate(srcs):
            fp = src.getFootprint()
            pks = fp.getPeaks()
            # Single-peak sources have nothing to deblend.
            if len(pks) < 2:
                continue
            nparents += 1

            # PSF FWHM evaluated at the (integer) center of the footprint bbox.
            bb = fp.getBBox()
            xc = int((bb.getMinX() + bb.getMaxX()) / 2.)
            yc = int((bb.getMinY() + bb.getMaxY()) / 2.)
            if hasattr(psf, 'getFwhm'):
                psf_fwhm = psf.getFwhm(xc, yc)
            else:
                pa = measAlg.PsfAttributes(psf, xc, yc)
                psfw = pa.computeGaussianWidth(
                    measAlg.PsfAttributes.ADAPTIVE_MOMENT)
                # 2.35 approximates the Gaussian sigma-to-FWHM factor
                # 2*sqrt(2*ln 2) ~= 2.3548.
                psf_fwhm = 2.35 * psfw

            self.log.logdebug('Parent %i: deblending %i peaks' %
                              (int(src.getId()), len(pks)))

            self.preSingleDeblendHook(exposure, srcs, i, fp, psf, psf_fwhm,
                                      sigma1)
            # Catalog length after the pre-hook; passed to the post-hook.
            npre = len(srcs)

            # This should really be set in deblend, but deblend doesn't have access to the src
            src.set(self.tooManyPeaksKey,
                    len(fp.getPeaks()) > self.config.maxNumberOfPeaks)

            try:
                res = deblend(fp,
                              mi,
                              psf,
                              psf_fwhm,
                              sigma1=sigma1,
                              psf_chisq_cut1=self.config.psf_chisq_1,
                              psf_chisq_cut2=self.config.psf_chisq_2,
                              psf_chisq_cut2b=self.config.psf_chisq_2b,
                              maxNumberOfPeaks=self.config.maxNumberOfPeaks)
                src.set(self.deblendFailedKey, False)
            except Exception as e:
                # Best effort: flag the parent and continue with the next source.
                self.log.warn("Error deblending source %d: %s" %
                              (src.getId(), e))
                src.set(self.deblendFailedKey, True)
                continue

            kids = []
            nchild = 0
            for j, pkres in enumerate(res.peaks):
                if pkres.out_of_bounds:
                    # skip this source?
                    self.log.logdebug(
                        'Skipping out-of-bounds peak at (%i,%i)' %
                        (pks[j].getIx(), pks[j].getIy()))
                    continue
                child = srcs.addNew()
                nchild += 1
                child.setParent(src.getId())
                if hasattr(pkres, 'heavy'):
                    child.setFootprint(pkres.heavy)

                child.set(self.psfKey, pkres.deblend_as_psf)
                (cx, cy) = pkres.center
                child.set(self.psfCenterKey, afwGeom.Point2D(cx, cy))
                child.set(self.psfFluxKey, pkres.psfflux)
                kids.append(child)

            src.set(self.nChildKey, nchild)

            self.postSingleDeblendHook(exposure, srcs, i, npre, kids, fp, psf,
                                       psf_fwhm, sigma1, res)

        n1 = len(srcs)
        self.log.info(
            'Deblended: of %i sources, %i were deblended, creating %i children, total %i sources'
            % (n0, nparents, n1 - n0, n1))
Beispiel #11
0
    def __call__(self, source, exposure):
        """Deblend a source into two PSF components: its faintest and brightest peaks.

        The source footprint's peak list is modified in place so that it holds
        only those two peaks.  For every deblended peak, the PSF-fit centroid,
        chi^2 per degree of freedom and flux are logged at info level.  Each
        peak is labelled "pos" or "neg" by the sign of its fitted flux.

        @param[in] source    source record whose footprint is deblended
        @param[in] exposure  exposure providing the pixels and the PSF
        @return a new record copied from ``source`` (via its table's
            ``copyRecord``) whose parent is ``source`` and whose footprint
            peaks are the deblended peaks.  If the footprint has fewer than
            two peaks, a plain copy of ``source`` is returned instead.
        """
        fp = source.getFootprint()
        peaks = fp.getPeaks()

        # if fewer than 2 peaks, just return a copy of the source.  This is
        # checked first so that no PSF or pixel work is done for such sources.
        if len(peaks) < 2:
            return source.getTable().copyRecord(source)

        psf = exposure.getPsf()
        psfSigPix = psf.computeShape().getDeterminantRadius()
        psfFwhmPix = psfSigPix * self.sigma2fwhm

        # make sure you only deblend 2 peaks; take the faintest and brightest.
        # Sort on the peak value alone: sorting (value, peak) tuples would fall
        # back to comparing peak records whenever two values tie.
        speaks = sorted(peaks, key=lambda pk: pk.getPeakValue())
        dpeaks = [speaks[0], speaks[-1]]

        # and only set these peaks in the footprint (peaks is mutable)
        peaks.clear()
        for peak in dpeaks:
            peaks.append(peak)

        if True:
            # Call top-level deblend task
            fpres = deblendBaseline.deblend(fp, exposure.getMaskedImage(), psf, psfFwhmPix,
                                            log=self.log,
                                            psfChisqCut1=self.psfChisqCut1,
                                            psfChisqCut2=self.psfChisqCut2,
                                            psfChisqCut2b=self.psfChisqCut2b)
        else:
            # Call lower-level _fitPsf directly (disabled by the ``if True``
            # above).  Only this path needs the footprint mask, the pixel copy
            # and the caching PSF, so they are built here.
            fbb = fp.getBBox()
            fmask = afwImage.MaskU(fbb)
            fmask.setXY0(fbb.getMinX(), fbb.getMinY())
            afwDetect.setMaskFromFootprint(fmask, fp, 1)
            subimage = afwImage.ExposureF(exposure, fbb, True)
            cpsf = deblendBaseline.CachingPsf(psf)
            # Peak fluxes parallel to dpeaks (not to the original peak list).
            dpeaksF = [pk.getF() for pk in dpeaks]

            # Prepare results structure
            fpres = deblendBaseline.PerFootprint()
            fpres.peaks = []
            for pki, pk in enumerate(dpeaks):
                pkres = deblendBaseline.PerPeak()
                pkres.peak = pk
                pkres.pki = pki
                fpres.peaks.append(pkres)

            for pki, (pk, pkres, pkF) in enumerate(zip(dpeaks, fpres.peaks, dpeaksF)):
                self.log.logdebug('Peak %i' % pki)
                deblendBaseline._fitPsf(fp, fmask, pk, pkF, pkres, fbb, dpeaks, dpeaksF, self.log,
                                        cpsf, psfFwhmPix,
                                        subimage.getMaskedImage().getImage(),
                                        subimage.getMaskedImage().getVariance(),
                                        self.psfChisqCut1, self.psfChisqCut2, self.psfChisqCut2b)

        deblendedSource = source.getTable().copyRecord(source)
        deblendedSource.setParent(source.getId())
        peakList = deblendedSource.getFootprint().getPeaks()
        peakList.clear()

        for peak in fpres.peaks:
            suffix = "pos" if peak.psfFitFlux > 0 else "neg"
            c = peak.psfFitCenter
            self.log.info("deblended.centroid.dipole.psf.%s %f %f" % (
                suffix, c[0], c[1]))
            self.log.info("deblended.chi2dof.dipole.%s %f" % (
                suffix, peak.psfFitChisq / peak.psfFitDof))
            self.log.info("deblended.flux.dipole.psf.%s %f" % (
                suffix, peak.psfFitFlux * np.sum(peak.templateImage.getArray())))
            peakList.append(peak.peak)
        return deblendedSource
Beispiel #12
0
def makeplots(butler,
              dataId,
              ps,
              sources=None,
              pids=None,
              minsize=0,
              maxpeaks=10):
    """Plot deblender diagnostics for every parent/child family in one calexp.

    First saves an image-wide histogram of pixel significance.  Then, for
    every parent that has children in the source catalog, it saves:

    * a plot of the parent footprint with its children's bounding boxes;
    * a grid of the children's HeavyFootprints;
    * plots from re-running ``deblend`` on the parent footprint (flux
      portions, templates);
    * PSF-fit diagnostics for each child that was deblended as a PSF.

    Every figure is saved via ``ps.savefig()``.

    Parameters
    ----------
    butler :
        Data butler used to read ``calexp`` (and ``src`` when ``sources``
        is None).
    dataId : mapping
        Keyword arguments passed to ``butler.get``.
    ps :
        Plot saver; only its ``savefig()`` method is called.
    sources : optional
        Source catalog to use instead of reading ``src`` from the butler.
    pids : collection of int, optional
        If non-empty, only plot parents whose ``id & 0xffff`` is in it.
        None or empty means plot every parent.
    minsize : int
        Skip parents with fewer than this many children.
    maxpeaks : int
        ``maxNumberOfPeaks`` passed when re-running ``deblend``.
    """
    calexp = butler.get("calexp", **dataId)
    if sources is None:
        ss = butler.get('src', **dataId)
    else:
        ss = sources

    # Index every source by id and group children under their parent id
    # (a parent id of 0 means the source has no parent).
    srcs = {}
    families = {}
    for src in ss:
        sid = src.getId()
        srcs[sid] = src
        parent = src.getParent()
        if parent == 0:
            continue
        if parent not in families:
            families[parent] = []
        families[parent].append(src)

    print()
    lsstimg = calexp.getMaskedImage().getImage()
    img = lsstimg.getArray()
    schema = ss.getSchema()
    psfkey = schema.find("deblend_deblendedAsPsf").key
    nchildkey = schema.find("deblend_nChild").key
    toomanykey = schema.find("deblend_tooManyPeaks").key
    failedkey = schema.find("deblend_failed").key

    def getFlagString(src):
        """Return a short summary of a source's deblend fields for plot titles."""
        words = ['Nchild: %i' % src.get(nchildkey)]
        for key, s in [(psfkey, 'PSF'), (toomanykey, 'TooMany'),
                       (failedkey, 'Failed')]:
            if src.get(key):
                words.append(s)
        return ', '.join(words)

    plt.subplots_adjust(left=0.05,
                        right=0.95,
                        bottom=0.05,
                        top=0.9,
                        hspace=0.2,
                        wspace=0.3)

    # Image-wide noise check: pixel values in units of the median sigma
    # (blue), per-pixel significance (green), and a unit Gaussian (black).
    variance = calexp.getMaskedImage().getVariance().getArray()
    sig1 = np.sqrt(np.median(variance.ravel()))
    pp = (img / np.sqrt(variance)).ravel()
    plt.clf()
    lo, hi = -4, 4
    n, b, _ = plt.hist(img.ravel() / sig1,
                       100,
                       range=(lo, hi),
                       histtype='step',
                       color='b')
    plt.hist(pp, 100, range=(lo, hi), histtype='step', color='g')
    xx = np.linspace(lo, hi, 200)
    yy = 1. / (np.sqrt(2. * np.pi)) * np.exp(-0.5 * xx**2)
    yy *= sum(n) * (b[1] - b[0])
    plt.plot(xx, yy, 'k-', alpha=0.5)
    plt.xlim(lo, hi)
    plt.title('image-wide sig1: %.1f' % sig1)
    ps.savefig()

    for ifam, (p, kids) in enumerate(families.items()):

        parent = srcs[p]
        pid = parent.getId() & 0xffff
        # A missing or empty pids collection means "plot every parent".
        if pids and pid not in pids:
            continue

        if len(kids) < minsize:
            print('Skipping parent', pid, ': n kids', len(kids))
            continue

        print('Parent', parent)
        print('Kids', kids)

        print('Parent', parent.getId())
        print('Kids', [k.getId() for k in kids])

        pfoot = parent.getFootprint()
        bb = pfoot.getBBox()

        y0, y1, x0, x1 = bb.getMinY(), bb.getMaxY(), bb.getMinX(), bb.getMaxX()
        slc = slice(y0, y1 + 1), slice(x0, x1 + 1)

        ima = dict(interpolation='nearest',
                   origin='lower',
                   cmap='gray',
                   vmin=-10,
                   vmax=40)
        mn, mx = ima['vmin'], ima['vmax']

        if False:
            plt.clf()
            plt.imshow(img[slc], extent=bb_to_ext(bb), **ima)
            plt.title('Parent %i, %s' %
                      (parent.getId(), getFlagString(parent)))
            ax = plt.axis()
            x, y = bb_to_xy(bb)
            plt.plot(x, y, 'r-', lw=2)
            for i, kid in enumerate(kids):
                kfoot = kid.getFootprint()
                kbb = kfoot.getBBox()
                kx, ky = bb_to_xy(kbb, margin=0.4)
                plt.plot(kx, ky, 'm-')
            for pk in pfoot.getPeaks():
                plt.plot(pk.getIx(), pk.getIy(), 'r+', ms=10, mew=3)
            plt.axis(ax)
            ps.savefig()

        # Parent footprint, with the children's bboxes and the parent's peaks.
        print('parent footprint:', pfoot)
        print('heavy?', pfoot.isHeavy())
        plt.clf()
        pimg, h = foot_to_img(pfoot, lsstimg)

        plt.imshow(img_to_rgb(pimg.getArray(), mn, mx),
                   extent=bb_to_ext(bb),
                   **ima)
        tt = 'Parent %i' % parent.getId()
        if not h:
            tt += ', no HFoot'
        tt += ', ' + getFlagString(parent)
        plt.title(tt)
        ax = plt.axis()
        plt.plot([x0, x0, x1, x1, x0], [y0, y1, y1, y0, y0], 'r-', lw=2)
        for i, kid in enumerate(kids):
            kfoot = kid.getFootprint()
            kbb = kfoot.getBBox()
            kx, ky = bb_to_xy(kbb, margin=-0.1)
            plt.plot(kx, ky, 'm-', lw=1.5)
        for pk in pfoot.getPeaks():
            plt.plot(pk.getIx(), pk.getIy(), 'r+', ms=10, mew=3)
        plt.axis(ax)
        ps.savefig()

        cols = int(np.ceil(np.sqrt(len(kids))))
        rows = int(np.ceil(len(kids) / float(cols)))

        if False:
            plt.clf()
            for i, kid in enumerate(kids):
                plt.subplot(rows, cols, 1 + i)
                kfoot = kid.getFootprint()
                print('kfoot:', kfoot)
                print('heavy?', kfoot.isHeavy())
                kbb = kfoot.getBBox()
                ky0, ky1, kx0, kx1 = (kbb.getMinY(), kbb.getMaxY(),
                                      kbb.getMinX(), kbb.getMaxX())
                kslc = slice(ky0, ky1 + 1), slice(kx0, kx1 + 1)
                plt.imshow(img[kslc], extent=bb_to_ext(kbb), **ima)
                plt.title('Child %i' % kid.getId())
                plt.axis(ax)
            ps.savefig()

        # Children's HeavyFootprints as stored in the catalog.
        plt.clf()
        for i, kid in enumerate(kids):
            plt.subplot(rows, cols, 1 + i)
            kfoot = kid.getFootprint()
            kbb = kfoot.getBBox()
            kimg, h = foot_to_img(kfoot, lsstimg)
            tt = getFlagString(kid)
            if not h:
                tt += ', no HFoot'
            plt.title('%s' % tt)
            if kimg is None:
                plt.axis(ax)
                continue
            plt.imshow(img_to_rgb(kimg.getArray(), mn, mx),
                       extent=bb_to_ext(kbb),
                       **ima)
            for pk in kfoot.getPeaks():
                plt.plot(pk.getIx(), pk.getIy(), 'g+', ms=10, mew=3)
            plt.axis(ax)
        plt.suptitle('Child HeavyFootprints')
        ps.savefig()

        print()
        print('Re-running deblender...')
        psf = calexp.getPsf()
        psf_fwhm = psf.computeShape().getDeterminantRadius() * 2.35
        deb = deblend(
            pfoot,
            calexp.getMaskedImage(),
            psf,
            psf_fwhm,
            verbose=True,
            maxNumberOfPeaks=maxpeaks,
            rampFluxAtEdge=True,
            clipStrayFluxFraction=0.01,
        )
        print('Got', deb)

        def getDebFlagString(kid):
            """Return comma-separated labels for the truthy deblender attributes of kid."""
            flags = []
            for k in ['skip', 'outOfBounds', 'tinyFootprint', 'noValidPixels',
                      ('deblendedAsPsf', 'PSF'), 'psfFitFailed',
                      'psfFitBadDof', 'psfFitBigDecenter',
                      'psfFitWithDecenter', 'failedSymmetricTemplate',
                      'hasRampedTemplate', 'patched']:
                # Entries are an attribute name, or an (attribute name, label) pair.
                if isinstance(k, tuple):
                    attr, label = k
                else:
                    attr = label = k
                if getattr(kid, attr):
                    flags.append(label)
            return ', '.join(flags)

        # Grid layout for the re-deblended children (at least one column,
        # so an empty peak list cannot divide by zero).
        N = len(deb.peaks)
        cols = max(1, int(np.ceil(np.sqrt(N))))
        rows = int(np.ceil(N / float(cols)))

        plotTitles = ['flux portion', 'flux portion + stray', 'template',
                      'psf template']
        for plotnum, supt in enumerate(plotTitles):
            plt.clf()
            for i, kid in enumerate(deb.peaks):
                kfoot = None
                if plotnum == 0:
                    kfoot = kid.getFluxPortion(strayFlux=False)
                elif plotnum == 1:
                    kfoot = kid.getFluxPortion(strayFlux=True)
                elif plotnum == 2:
                    kfoot = afwDet.makeHeavyFootprint(kid.templateFootprint,
                                                      kid.templateImage)
                elif plotnum == 3:
                    if kid.deblendedAsPsf:
                        kfoot = afwDet.makeHeavyFootprint(
                            kid.psfFootprint, kid.psfTemplate)
                        kfoot.normalize()
                        kfoot.clipToNonzero(kid.psfTemplate.getImage())
                    else:
                        kfoot = afwDet.makeHeavyFootprint(
                            kid.templateFootprint, kid.templateImage)

                kimg, h = foot_to_img(kfoot, None)
                tt = 'kid %i: %s' % (i, getDebFlagString(kid))
                if not h:
                    tt += ', no HFoot'
                plt.subplot(rows, cols, 1 + i)
                plt.title('%s' % tt, fontsize=8)
                if kimg is None:
                    plt.axis(ax)
                    continue
                kbb = kfoot.getBBox()

                plt.imshow(img_to_rgb(kimg.getArray(), mn, mx),
                           extent=bb_to_ext(kbb),
                           **ima)

                plt.axis(ax)

            plt.suptitle(supt)
            ps.savefig()

        # PSF-fit diagnostics for children deblended as point sources.
        for i, kid in enumerate(deb.peaks):
            if not kid.deblendedAsPsf:
                continue
            plt.clf()

            psfima = dict(interpolation='nearest', origin='lower', cmap='gray')

            plt.subplot(2, 4, 1)
            plt.title('weights')
            plt.imshow(kid.psfFitDebugWeight, vmin=0, **psfima)
            plt.xticks([])
            plt.yticks([])
            plt.colorbar()

            plt.subplot(2, 4, 7)
            plt.title('valid pixels')
            plt.imshow(kid.psfFitDebugValidPix, vmin=0, vmax=1, **psfima)
            plt.xticks([])
            plt.yticks([])
            plt.colorbar()

            # Histograms of data/sigma, model/sigma and chi against a unit Gaussian.
            plt.subplot(2, 4, 2)
            sig = np.sqrt(kid.psfFitDebugVar.getArray())
            data = kid.psfFitDebugStamp.getArray()
            model = kid.psfFitDebugPsfModel.getArray()
            chi = ((data - model) / sig)
            valid = kid.psfFitDebugValidPix

            plt.hist(np.clip((data / sig)[valid], -5, 5),
                     20,
                     range=(-5, 5),
                     histtype='step',
                     color='m')
            plt.hist(np.clip((model / sig)[valid], -5, 5),
                     20,
                     range=(-5, 5),
                     histtype='step',
                     color='r')
            plt.hist(np.clip(chi.ravel(), -5, 5),
                     20,
                     range=(-5, 5),
                     histtype='step',
                     color='g')
            n, b, _ = plt.hist(np.clip(chi[valid], -5, 5),
                               20,
                               range=(-5, 5),
                               histtype='step',
                               color='b')

            xx = np.linspace(-5, 5, 200)
            yy = 1. / (np.sqrt(2. * np.pi)) * np.exp(-0.5 * xx**2)
            yy *= sum(n) * (b[1] - b[0])
            plt.plot(xx, yy, 'k-', alpha=0.5)

            plt.xlim(-5, 5)

            rw = kid.psfFitDebugRampWeight
            print('Sum of ramp weights:', np.sum(rw))
            print('Quadrature sum of ramp weights:', np.sqrt(np.sum(rw**2)))
            print('Number of valid pix:', np.sum(valid))
            print('rw[valid]', np.sum(rw[valid]))
            print('rw range', rw.min(), rw.max())

            myresid = np.sum(valid * rw * ((data - model) / sig)**2)
            print('myresid:', myresid)

            # Monte-Carlo distribution of the ramp-weighted chi^2 of pure
            # unit-Gaussian noise, compared with the best-fit chi^2.
            plt.subplot(2, 4, 8)
            nSamples = 20000
            rwv = rw[valid]
            print('rwv', rwv)
            noise = np.random.normal(size=(nSamples, len(rwv)))
            chisqSamples = np.sum(rwv * noise**2, axis=1)
            plt.hist(chisqSamples, 25)
            bestChi, bestDof = kid.psfFitBest
            plt.axvline(bestChi, color='r')

            modelMax = model.max()

            plt.subplot(2, 4, 3)
            plt.title('model+noise')
            plt.imshow((model + sig * np.random.normal(size=sig.shape)) * valid,
                       vmin=0,
                       vmax=modelMax,
                       **psfima)
            plt.xticks([])
            plt.yticks([])
            plt.colorbar()

            plt.subplot(2, 4, 4)
            plt.title('fit psf model')
            plt.imshow(model, vmin=0, vmax=modelMax, **psfima)
            plt.xticks([])
            plt.yticks([])
            plt.colorbar()

            plt.subplot(2, 4, 5)
            plt.title('fit psf image')
            plt.imshow(data, vmin=0, vmax=modelMax, **psfima)
            plt.xticks([])
            plt.yticks([])
            plt.colorbar()

            validChi = valid * (data - model) / sig

            plt.subplot(2, 4, 6)
            plt.title('fit psf chi')
            plt.imshow(-validChi,
                       vmin=-3,
                       vmax=3,
                       interpolation='nearest',
                       origin='lower',
                       cmap='RdBu')
            plt.xticks([])
            plt.yticks([])
            plt.colorbar()

            params = kid.psfFitParams
            (flux, sky, skyx, skyy) = params[:4]

            print('Model sum:', model.sum())
            print('- sky', model.sum() - np.sum(valid) * sky)

            kidSig1 = np.median(sig)

            plt.suptitle(
                'PSF kid %i: flux %.1f, sky %.1f, sig1 %.1f' %
                (i, flux, sky, kidSig1))

            ps.savefig()
Beispiel #13
0
    def test1(self):
        '''
        A simple example: three overlapping blobs (detected as 1
        footprint with three peaks).  We artificially omit one of the
        peaks, meaning that its flux is "stray".  Assert that the
        stray flux assigned to the other two peaks accounts for all
        the flux in the parent.

        The same per-peak results (``deb.deblendedParents[0].peaks``) are
        used for the optional plots and for the flux-conservation check, so
        the plots show exactly the children whose sum is tested.
        '''
        H, W = 100, 100

        fpbb = afwGeom.Box2I(afwGeom.Point2I(0, 0),
                             afwGeom.Point2I(W - 1, H - 1))

        afwimg = afwImage.MaskedImageF(fpbb)
        imgbb = afwimg.getBBox()
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        blob_fwhm = 10.
        blob_psf = doubleGaussianPsf(99, 99, blob_fwhm, 3. * blob_fwhm, 0.03)

        fakepsf_fwhm = 3.
        fakepsf = gaussianPsf(11, 11, fakepsf_fwhm)

        # Render one blob per (x, y) center into the image; keep a separate
        # copy of each blob for the "input" plot.
        blobimgs = []
        XY = [(75., 35.), (75., 65.), (50., 50.)]
        flux = 1e6
        for x, y in XY:
            bim = blob_psf.computeImage(afwGeom.Point2D(x, y))
            bbb = bim.getBBox()
            # The PSF stamp can extend past the image edges; clip it.
            bbb.clip(imgbb)

            bim = bim.Factory(bim, bbb)
            stamp = flux * bim.getArray()
            ys = slice(bbb.getMinY(), bbb.getMaxY() + 1)
            xs = slice(bbb.getMinX(), bbb.getMaxX() + 1)

            blobimg = np.zeros_like(img)
            blobimg[ys, xs] += stamp
            blobimgs.append(blobimg)

            img[ys, xs] += stamp

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(5., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()
        print('found', len(fps), 'footprints')
        for fp in fps:
            print('peaks:', len(fp.getPeaks()))
            for pk in fp.getPeaks():
                print('  ', pk.getIx(), pk.getIy())

        # The first peak in this list is the one we want to omit.
        fp0 = fps[0]
        fakefp = afwDet.Footprint(fp0.getSpans(), fp0.getBBox())
        for pk in fp0.getPeaks()[1:]:
            fakefp.getPeaks().append(pk)

        ima = dict(interpolation='nearest',
                   origin='lower',
                   cmap='gray',
                   vmin=0,
                   vmax=1e3)

        def markBlobs():
            # Overlay the true blob centers without changing the axis limits.
            ax = plt.axis()
            plt.plot([bx for bx, by in XY], [by for bx, by in XY], 'r.')
            plt.axis(ax)

        if doPlot:
            plt.figure(figsize=(12, 6))

            plt.clf()
            plt.suptitle('strayFlux.py: test1 input')
            plt.subplot(2, 2, 1)
            plt.title('Image')
            plt.imshow(img, **ima)
            markBlobs()
            for i, (b, (x, y)) in enumerate(zip(blobimgs, XY)):
                plt.subplot(2, 2, 2 + i)
                plt.title('Blob %i' % i)
                plt.imshow(b, **ima)
                ax = plt.axis()
                plt.plot(x, y, 'r.')
                plt.axis(ax)
            plt.savefig(plotpat % 1)

        # Change verbose to False to quiet down the meas_deblender.baseline logger
        deb = deblend(fakefp, afwimg, fakepsf, fakepsf_fwhm, verbose=True)
        parent_img = afwImage.ImageF(fpbb)
        fakefp.spans.copyImage(afwimg.getImage(), parent_img)

        # Per-peak deblend results.  The flux-conservation check below has
        # always used this list; the plots previously iterated ``deb.peaks``
        # instead.  NOTE(review): ``deb.peaks`` presumably holds multi-band
        # wrapper objects -- confirm against the DeblenderResult API.
        children = deb.deblendedParents[0].peaks
        self.assertGreater(len(children), 0)

        if doPlot:

            def myimshow(*args, **kwargs):
                plt.imshow(*args, **kwargs)
                plt.xticks([])
                plt.yticks([])
                plt.axis(imExt(afwimg))

            plt.clf()
            plt.suptitle('strayFlux.py: test1 results')
            R, C = 3, 4
            plt.subplot(R, C, (2 * C) + 1)
            plt.title('Image')
            myimshow(img, **ima)
            markBlobs()

            plt.subplot(R, C, (2 * C) + 2)
            plt.title('Parent footprint')
            myimshow(parent_img.getArray(), **ima)
            ax = plt.axis()
            plt.plot([pk.getIx() for pk in fakefp.getPeaks()],
                     [pk.getIy() for pk in fakefp.getPeaks()], 'r.')
            plt.axis(ax)

            plotsum = np.zeros_like(parent_img.getArray())
            for i, dpk in enumerate(children):
                plt.subplot(R, C, i * C + 1)
                plt.title('ch%i symm' % i)
                symm = dpk.templateImage
                myimshow(symm.getArray(), extent=imExt(symm), **ima)

                plt.subplot(R, C, i * C + 2)
                plt.title('ch%i portion' % i)
                port = dpk.fluxPortion.getImage()
                myimshow(port.getArray(), extent=imExt(port), **ima)

                simg = afwImage.ImageF(fpbb)
                dpk.strayFlux.insert(simg)

                plt.subplot(R, C, i * C + 3)
                plt.title('ch%i stray' % i)
                myimshow(simg.getArray(), **ima)
                markBlobs()

                # Child flux portion including its share of the stray flux.
                himg2 = afwImage.ImageF(fpbb)
                dpk.getFluxPortion(strayFlux=True).insert(himg2)
                plotsum += himg2.getArray()

                plt.subplot(R, C, i * C + 4)
                myimshow(himg2.getArray(), **ima)
                plt.title('ch%i total' % i)
                markBlobs()

            plt.subplot(R, C, (2 * C) + C)
            myimshow(plotsum, **ima)
            markBlobs()
            plt.title('Sum of deblends')

            plt.savefig(plotpat % 2)

        # Compute the sum-of-children image (each child's portion plus its
        # assigned stray flux).
        sumimg = np.zeros_like(parent_img.getArray())
        for dpk in children:
            himg2 = afwImage.ImageF(fpbb)
            dpk.getFluxPortion().insert(himg2)
            sumimg += himg2.getArray()

        # Sum of children ~= Original image inside footprint (parent_img)
        absdiff = np.max(np.abs(sumimg - parent_img.getArray()))
        print('Max abs diff:', absdiff)
        imgmax = parent_img.getArray().max()
        print('Img max:', imgmax)
        self.assertLess(absdiff, imgmax * 1e-6)
Beispiel #14
0
    def test2(self):
        '''
        A 1-d example, to test the stray-flux assignment.

        A single row of flux 10 has two brighter pixels (20) at x=1 and
        x=W-2; a peak is forced at each.  After deblending, the summed
        stray flux and each child's stray flux are compared with the
        expected 1/(1 + dx^2) split, clipped at a 0.001 fraction.
        '''
        H, W = 1, 100

        fpbb = afwGeom.Box2I(afwGeom.Point2I(0, 0),
                             afwGeom.Point2I(W - 1, H - 1))
        afwimg = afwImage.MaskedImageF(fpbb)
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        y = 0
        img[y, 1:-1] = 10.

        img[0, 1] = 20.
        img[0, -2] = 20.

        fakepsf_fwhm = 1.
        fakepsf = gaussianPsf(1, 1, fakepsf_fwhm)

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(5., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()
        self.assertEqual(len(fps), 1)
        fp = fps[0]

        # WORKAROUND: the detection alg produces ONE peak, at (1,0),
        # rather than two.
        self.assertEqual(len(fp.getPeaks()), 1)
        fp.addPeak(W - 2, y, float("NaN"))

        # Change verbose to False to quiet down the meas_deblender.baseline logger
        deb = deblend(
            fp,
            afwimg,
            fakepsf,
            fakepsf_fwhm,
            verbose=True,
            fitPsfs=False,
        )

        # Per-peak deblend results, shared by the plots and the checks below
        # (the checks index children 0 and 1, so require exactly two).
        children = deb.deblendedParents[0].peaks
        self.assertEqual(len(children), 2)

        if doPlot:
            # Step-plot x coordinates: each pixel edge repeated twice.
            XX = np.arange(W + 1).repeat(2)[1:-1]

            plt.clf()
            p1 = plt.plot(XX, img[y, :].repeat(2), 'g-', lw=3, alpha=0.3)

            for dpk in children:
                print(dpk)
                port = dpk.fluxPortion.getImage()
                bb = port.getBBox()
                YY = np.zeros(XX.shape)
                YY[bb.getMinX() * 2:(bb.getMaxX() + 1) * 2] = (
                    port.getArray()[0, :].repeat(2))
                p2 = plt.plot(XX, YY, 'r-')

                simg = afwImage.ImageF(fpbb)
                dpk.strayFlux.insert(simg)
                p3 = plt.plot(XX, simg.getArray()[y, :].repeat(2), 'b-')

            plt.legend((p1[0], p2[0], p3[0]),
                       ('Parent Flux', 'Child portion', 'Child stray flux'))
            plt.ylim(-2, 22)
            plt.savefig(plotpat % 3)

        strays = []
        for dpk in children:
            simg = afwImage.ImageF(fpbb)
            dpk.strayFlux.insert(simg)
            strays.append(simg.getArray())

        # Element-wise sum of the children's stray-flux images (does not
        # rely on `reduce`, which is not a builtin in Python 3).
        ssum = np.sum(strays, axis=0)

        starget = np.zeros(W)
        starget[2:-2] = 10.

        self.assertFloatsEqual(ssum, starget)

        # Expected split of the stray flux: weight 1/(1 + dx^2) to each
        # peak, with weights below strayclip * total clipped to zero.
        X = np.arange(W)
        dx1 = X - 1.
        dx2 = X - (W - 2)
        f1 = (1. / (1. + dx1**2))
        f2 = (1. / (1. + dx2**2))
        strayclip = 0.001
        fsum = f1 + f2
        f1[f1 < strayclip * fsum] = 0.
        f2[f2 < strayclip * fsum] = 0.

        s1 = f1 / (f1 + f2) * 10.
        s2 = f2 / (f1 + f2) * 10.

        s1[:2] = 0.
        s2[-2:] = 0.

        if doPlot:
            p4 = plt.plot(XX, s1.repeat(2), 'm-')
            plt.plot(XX, s2.repeat(2), 'm-')

            plt.legend((p1[0], p2[0], p3[0], p4[0]),
                       ('Parent Flux', 'Child portion', 'Child stray flux',
                        'Expected stray flux'))
            plt.ylim(-2, 22)
            plt.savefig(plotpat % 4)

        # test abs diff
        d = np.max(np.abs(s1 - strays[0]))
        self.assertLess(d, 1e-6)
        d = np.max(np.abs(s2 - strays[1]))
        self.assertLess(d, 1e-6)

        # test relative diff
        self.assertLess(np.max(np.abs(s1 - strays[0]) / np.maximum(1e-3, s1)),
                        1e-6)
        self.assertLess(np.max(np.abs(s2 - strays[1]) / np.maximum(1e-3, s2)),
                        1e-6)
    def test1(self):
        '''
        A simple example: three overlapping blobs (detected as 1
        footprint with three peaks).  We artificially omit one of the
        peaks, meaning that its flux is "stray".  Assert that the
        stray flux assigned to the other two peaks accounts for all
        the flux in the parent.

        The same per-peak results (``deb.deblendedParents[0].peaks``) are
        used for the optional plots and for the flux-conservation check, so
        the plots show exactly the children whose sum is tested.
        '''
        H, W = 100, 100

        fpbb = afwGeom.Box2I(afwGeom.Point2I(0, 0),
                             afwGeom.Point2I(W-1, H-1))

        afwimg = afwImage.MaskedImageF(fpbb)
        imgbb = afwimg.getBBox()
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        blob_fwhm = 10.
        blob_psf = doubleGaussianPsf(99, 99, blob_fwhm, 3.*blob_fwhm, 0.03)

        fakepsf_fwhm = 3.
        fakepsf = gaussianPsf(11, 11, fakepsf_fwhm)

        # Render one blob per (x, y) center into the image; keep a separate
        # copy of each blob for the "input" plot.
        blobimgs = []
        XY = [(75., 35.), (75., 65.), (50., 50.)]
        flux = 1e6
        for x, y in XY:
            bim = blob_psf.computeImage(afwGeom.Point2D(x, y))
            bbb = bim.getBBox()
            # The PSF stamp can extend past the image edges; clip it.
            bbb.clip(imgbb)

            bim = bim.Factory(bim, bbb)
            stamp = flux * bim.getArray()
            ys = slice(bbb.getMinY(), bbb.getMaxY()+1)
            xs = slice(bbb.getMinX(), bbb.getMaxX()+1)

            blobimg = np.zeros_like(img)
            blobimg[ys, xs] += stamp
            blobimgs.append(blobimg)

            img[ys, xs] += stamp

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(5., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()
        print('found', len(fps), 'footprints')
        for fp in fps:
            print('peaks:', len(fp.getPeaks()))
            for pk in fp.getPeaks():
                print('  ', pk.getIx(), pk.getIy())

        # The first peak in this list is the one we want to omit.
        fp0 = fps[0]
        fakefp = afwDet.Footprint(fp0.getSpans(), fp0.getBBox())
        for pk in fp0.getPeaks()[1:]:
            fakefp.getPeaks().append(pk)

        ima = dict(interpolation='nearest', origin='lower', cmap='gray',
                   vmin=0, vmax=1e3)

        def markBlobs():
            # Overlay the true blob centers without changing the axis limits.
            ax = plt.axis()
            plt.plot([bx for bx, by in XY], [by for bx, by in XY], 'r.')
            plt.axis(ax)

        if doPlot:
            plt.figure(figsize=(12, 6))

            plt.clf()
            plt.suptitle('strayFlux.py: test1 input')
            plt.subplot(2, 2, 1)
            plt.title('Image')
            plt.imshow(img, **ima)
            markBlobs()
            for i, (b, (x, y)) in enumerate(zip(blobimgs, XY)):
                plt.subplot(2, 2, 2+i)
                plt.title('Blob %i' % i)
                plt.imshow(b, **ima)
                ax = plt.axis()
                plt.plot(x, y, 'r.')
                plt.axis(ax)
            plt.savefig(plotpat % 1)

        # Change verbose to False to quiet down the meas_deblender.baseline logger
        deb = deblend(fakefp, afwimg, fakepsf, fakepsf_fwhm, verbose=True)
        parent_img = afwImage.ImageF(fpbb)
        fakefp.spans.copyImage(afwimg.getImage(), parent_img)

        # Per-peak deblend results.  The flux-conservation check below has
        # always used this list; the plots previously iterated ``deb.peaks``
        # instead.  NOTE(review): ``deb.peaks`` presumably holds multi-band
        # wrapper objects -- confirm against the DeblenderResult API.
        children = deb.deblendedParents[0].peaks
        self.assertGreater(len(children), 0)

        if doPlot:
            def myimshow(*args, **kwargs):
                plt.imshow(*args, **kwargs)
                plt.xticks([])
                plt.yticks([])
                plt.axis(imExt(afwimg))

            plt.clf()
            plt.suptitle('strayFlux.py: test1 results')
            R, C = 3, 4
            plt.subplot(R, C, (2*C) + 1)
            plt.title('Image')
            myimshow(img, **ima)
            markBlobs()

            plt.subplot(R, C, (2*C) + 2)
            plt.title('Parent footprint')
            myimshow(parent_img.getArray(), **ima)
            ax = plt.axis()
            plt.plot([pk.getIx() for pk in fakefp.getPeaks()],
                     [pk.getIy() for pk in fakefp.getPeaks()], 'r.')
            plt.axis(ax)

            plotsum = np.zeros_like(parent_img.getArray())
            for i, dpk in enumerate(children):
                plt.subplot(R, C, i*C + 1)
                plt.title('ch%i symm' % i)
                symm = dpk.templateImage
                myimshow(symm.getArray(), extent=imExt(symm), **ima)

                plt.subplot(R, C, i*C + 2)
                plt.title('ch%i portion' % i)
                port = dpk.fluxPortion.getImage()
                myimshow(port.getArray(), extent=imExt(port), **ima)

                simg = afwImage.ImageF(fpbb)
                dpk.strayFlux.insert(simg)

                plt.subplot(R, C, i*C + 3)
                plt.title('ch%i stray' % i)
                myimshow(simg.getArray(), **ima)
                markBlobs()

                # Child flux portion including its share of the stray flux.
                himg2 = afwImage.ImageF(fpbb)
                dpk.getFluxPortion(strayFlux=True).insert(himg2)
                plotsum += himg2.getArray()

                plt.subplot(R, C, i*C + 4)
                myimshow(himg2.getArray(), **ima)
                plt.title('ch%i total' % i)
                markBlobs()

            plt.subplot(R, C, (2*C) + C)
            myimshow(plotsum, **ima)
            markBlobs()
            plt.title('Sum of deblends')

            plt.savefig(plotpat % 2)

        # Compute the sum-of-children image (each child's portion plus its
        # assigned stray flux).
        sumimg = np.zeros_like(parent_img.getArray())
        for dpk in children:
            himg2 = afwImage.ImageF(fpbb)
            dpk.getFluxPortion().insert(himg2)
            sumimg += himg2.getArray()

        # Sum of children ~= Original image inside footprint (parent_img)
        absdiff = np.max(np.abs(sumimg - parent_img.getArray()))
        print('Max abs diff:', absdiff)
        imgmax = parent_img.getArray().max()
        print('Img max:', imgmax)
        self.assertLess(absdiff, imgmax*1e-6)
def main():
    '''
    Runs the deblender and creates plots for the "design document",
    doc/design.tex.  See the file NOTES for how to get set up to the
    point where you can actually run this on data.
    '''

    from optparse import OptionParser
    parser = OptionParser()
    parser.add_option('--root',
                      dest='root',
                      help='Root directory for Subaru data')
    parser.add_option('--outroot',
                      '-o',
                      dest='outroot',
                      help='Output root directory for Subaru data')
    parser.add_option('--sources', help='Read a FITS table of sources')
    parser.add_option('--calexp', help='Read a FITS calexp')
    parser.add_option('--psf', help='Read a FITS PSF')

    parser.add_option('--drill',
                      '-D',
                      dest='drill',
                      action='append',
                      type=str,
                      default=[],
                      help='Drill down on individual source IDs')
    parser.add_option(
        '--drillxy',
        dest='drillxy',
        action='append',
        type=str,
        default=[],
        help='Drill down on individual source positions, eg 132,46;54,67')
    parser.add_option('--visit',
                      dest='visit',
                      type=int,
                      default=108792,
                      help='Suprimecam visit id')
    parser.add_option('--ccd',
                      dest='ccd',
                      type=int,
                      default=5,
                      help='Suprimecam CCD number')
    parser.add_option('--prefix',
                      dest='prefix',
                      default='design-',
                      help='plot filename prefix')
    parser.add_option('--suffix',
                      dest='suffix',
                      default=None,
                      help='plot filename suffix (default: ".png")')
    parser.add_option(
        '--pat',
        dest='pat',
        help=
        'Plot filename pattern: eg, "design-%(pid)04i-%(name).png"; overrides --prefix and --suffix'
    )
    parser.add_option('--pdf',
                      dest='pdf',
                      action='store_true',
                      default=False,
                      help='save in PDF format?')
    parser.add_option('-v', dest='verbose', action='store_true')
    parser.add_option('--figw',
                      dest='figw',
                      type=float,
                      help='Figure window width (inches)',
                      default=4.)
    parser.add_option('--figh',
                      dest='figh',
                      type=float,
                      help='Figure window height (inches)',
                      default=4.)
    parser.add_option('--order',
                      dest='order',
                      type=str,
                      help='Child order: eg 3,0,1,2')

    parser.add_option('--sdss',
                      dest='sec',
                      action='store_const',
                      const='sdss',
                      help='Produce plots for the SDSS section.')
    parser.add_option('--mono',
                      dest='sec',
                      action='store_const',
                      const='mono',
                      help='Produce plots for the "monotonic" section.')
    parser.add_option('--median',
                      dest='sec',
                      action='store_const',
                      const='median',
                      help='Produce plots for the "median filter" section.')
    parser.add_option('--ramp',
                      dest='sec',
                      action='store_const',
                      const='ramp',
                      help='Produce plots for the "ramp edges" section.')
    parser.add_option(
        '--ramp2',
        dest='sec',
        action='store_const',
        const='ramp2',
        help='Produce plots for the "ramp edges + stray flux" section.')
    parser.add_option('--patch',
                      dest='sec',
                      action='store_const',
                      const='patch',
                      help='Produce plots for the "patch edges" section.')

    opt, args = parser.parse_args()

    # Logging
    if opt.verbose:
        lsst.log.setLevel('', lsst.log.DEBUG)
    else:
        lsst.log.setLevel('', lsst.log.INFO)

    if opt.sec is None:
        opt.sec = 'sdss'
    if opt.pdf:
        if opt.suffix is None:
            opt.suffix = ''
        opt.suffix += '.pdf'
    if not opt.suffix:
        opt.suffix = '.png'

    if opt.pat:
        plotpattern = opt.pat
    else:
        plotpattern = opt.prefix + '%(pid)04i-%(name)s' + opt.suffix

    if opt.order is not None:
        opt.order = [int(x) for x in opt.order.split(',')]
        invorder = np.zeros(len(opt.order))
        invorder[opt.order] = np.arange(len(opt.order))

    def mapchild(i):
        if opt.order is None:
            return i
        return invorder[i]

    def savefig(pid, figname):
        fn = plotpattern % dict(pid=pid, name=figname)
        plt.savefig(fn)

    # Load data using the butler, if desired
    dr = None
    if opt.sources is None or opt.calexp is None:
        print('Creating DataRef...')
        dr = getSuprimeDataref(opt.visit,
                               opt.ccd,
                               rootdir=opt.root,
                               outrootdir=opt.outroot)
        print('Got', dr)

    # Which parent ids / deblend families are we going to plot?
    keepids = None
    if len(opt.drill):
        keepids = []
        for d in opt.drill:
            for dd in d.split(','):
                keepids.append(int(dd))
        print('Keeping parent ids', keepids)

    keepxys = None
    if len(opt.drillxy):
        keepxys = []
        for d in opt.drillxy:
            for dd in d.split(';'):
                xy = dd.split(',')
                assert (len(xy) == 2)
                keepxys.append((int(xy[0]), int(xy[1])))
        print('Keeping parents at xy', keepxys)

    # Read from butler or local file
    cat = readCatalog(opt.sources,
                      None,
                      dataref=dr,
                      keepids=keepids,
                      keepxys=keepxys,
                      patargs=dict(visit=opt.visit, ccd=opt.ccd))
    print('Got', len(cat), 'sources')

    # Load data from butler or local files
    if opt.calexp is not None:
        print('Reading exposure from', opt.calexp)
        exposure = afwImage.ExposureF(opt.calexp)
    else:
        exposure = dr.get('calexp')
    print('Exposure', exposure)
    mi = exposure.getMaskedImage()

    if opt.psf is not None:
        print('Reading PSF from', opt.psf)
        psf = afwDet.Psf.readFits(opt.psf)
        print('Got', psf)
    elif dr:
        psf = dr.get('psf')
    else:
        psf = exposure.getPsf()

    sigma1 = get_sigma1(mi)

    fams = getFamilies(cat)
    print(len(fams), 'deblend families')

    if False:
        for j, (parent, children) in enumerate(fams):
            print('parent', parent)
            print('children', children)
            plotDeblendFamily(mi,
                              parent,
                              children,
                              cat,
                              sigma1,
                              ellipses=False)
            fn = '%04i.png' % parent.getId()
            plt.savefig(fn)
            print('wrote', fn)

    def nlmap(X):
        return np.arcsinh(X / (3. * sigma1))

    def myimshow(im, **kwargs):
        kwargs = kwargs.copy()
        mn = kwargs.get('vmin', -5 * sigma1)
        kwargs['vmin'] = nlmap(mn)
        mx = kwargs.get('vmax', 100 * sigma1)
        kwargs['vmax'] = nlmap(mx)
        plt.imshow(nlmap(im), **kwargs)

    plt.figure(figsize=(opt.figw, opt.figh))
    plt.subplot(1, 1, 1)
    plt.subplots_adjust(left=0.01,
                        right=0.99,
                        bottom=0.01,
                        top=0.99,
                        wspace=0.05,
                        hspace=0.1)

    # Make plots for each deblend family.

    for j, (parent, children) in enumerate(fams):
        print('parent', parent.getId())
        print('children', [ch.getId() for ch in children])
        print('parent x,y', parent.getX(), parent.getY())

        pid = parent.getId()
        fp = parent.getFootprint()
        bb = fp.getBBox()
        pim = footprintToImage(parent.getFootprint(), mi).getArray()
        pext = getExtent(bb)
        imargs = dict(interpolation='nearest',
                      origin='lower',
                      vmax=pim.max() * 0.95,
                      vmin=-3. * sigma1)
        pksty = dict(linestyle='None',
                     marker='+',
                     color='r',
                     mew=3,
                     ms=20,
                     alpha=0.6)

        plt.clf()
        myimshow(afwImage.ImageF(mi.getImage(), bb).getArray(), **imargs)
        plt.gray()
        plt.xticks([])
        plt.yticks([])
        savefig(pid, 'image')

        # Parent footprint
        plt.clf()
        myimshow(pim, extent=pext, **imargs)
        plt.gray()
        pks = fp.getPeaks()
        plt.plot([pk.getIx() for pk in pks], [pk.getIy() for pk in pks],
                 **pksty)
        plt.xticks([])
        plt.yticks([])
        plt.axis(pext)
        savefig(pid, 'parent')

        from lsst.meas.deblender.baseline import deblend

        xc = int((bb.getMinX() + bb.getMaxX()) / 2.)
        yc = int((bb.getMinY() + bb.getMaxY()) / 2.)
        if hasattr(psf, 'getFwhm'):
            psf_fwhm = psf.getFwhm(xc, yc)
        else:
            psf_fwhm = psf.computeShape().getDeterminantRadius() * 2.35

        # Each section of the design doc runs the deblender with different args.

        kwargs = dict(sigma1=sigma1, verbose=opt.verbose, getTemplateSum=True)

        basic = kwargs.copy()
        basic.update(fit_psfs=False,
                     median_smooth_template=False,
                     monotonic_template=False,
                     lstsq_weight_templates=False,
                     assignStrayFlux=False,
                     rampFluxAtEdge=False,
                     patchEdges=False)

        if opt.sec == 'sdss':
            # SDSS intro
            kwargs = basic
            kwargs.update(lstsq_weight_templates=True)

        elif opt.sec == 'mono':
            kwargs = basic
            kwargs.update(lstsq_weight_templates=True, monotonic_template=True)
        elif opt.sec == 'median':
            kwargs = basic
            kwargs.update(lstsq_weight_templates=True,
                          median_smooth_template=True,
                          monotonic_template=True)
        elif opt.sec == 'ramp':
            kwargs = basic
            kwargs.update(median_smooth_template=True,
                          monotonic_template=True,
                          rampFluxAtEdge=True)

        elif opt.sec == 'ramp2':
            kwargs = basic
            kwargs.update(median_smooth_template=True,
                          monotonic_template=True,
                          rampFluxAtEdge=True,
                          assignStrayFlux=True)

        elif opt.sec == 'patch':
            kwargs = basic
            kwargs.update(median_smooth_template=True,
                          monotonic_template=True,
                          patchEdges=True)

        else:
            raise 'Unknown section: "%s"' % opt.sec

        print('Running deblender with kwargs:', kwargs)
        res = deblend(fp, mi, psf, psf_fwhm, **kwargs)
        # print('got result with', [x for x in dir(res) if not x.startswith('__')])
        # for pk in res.peaks:
        #     print('got peak with', [x for x in dir(pk) if not x.startswith('__')])
        #     print('  deblend as psf?', pk.deblend_as_psf)

        # Find bounding-box of all templates.
        tbb = fp.getBBox()
        for pkres, pk in zip(res.peaks, pks):
            tbb.include(pkres.template_foot.getBBox())
        print('Bounding-box of all templates:', tbb)

        # Sum-of-templates plot
        tsum = np.zeros((tbb.getHeight(), tbb.getWidth()))
        tx0, ty0 = tbb.getMinX(), tbb.getMinY()

        # Sum-of-deblended children plot(s)
        # "heavy" bbox == template bbox.
        hsum = np.zeros((tbb.getHeight(), tbb.getWidth()))
        hsum2 = np.zeros((tbb.getHeight(), tbb.getWidth()))

        # Sum of templates from the deblender itself
        plt.clf()
        t = res.templateSum
        myimshow(t.getArray(), extent=getExtent(t.getBBox()), **imargs)
        plt.gray()
        plt.xticks([])
        plt.yticks([])
        savefig(pid, 'tsum1')

        # Make plots for each deblended child (peak)

        k = 0
        for pkres, pk in zip(res.peaks, pks):

            heavy = pkres.get_flux_portion()
            if heavy is None:
                print('Child has no HeavyFootprint -- skipping')
                continue

            kk = mapchild(k)

            w = pkres.template_weight

            cfp = pkres.template_foot
            cbb = cfp.getBBox()
            cext = getExtent(cbb)

            # Template image
            tim = pkres.template_mimg.getImage()
            timext = cext
            tim = tim.getArray()

            (x0, x1, y0, y1) = timext
            print('tim ext', timext)
            tsum[y0 - ty0:y1 - ty0, x0 - tx0:x1 - tx0] += tim

            # "Heavy" image -- flux assigned to child
            him = footprintToImage(heavy).getArray()
            hext = getExtent(heavy.getBBox())

            (x0, x1, y0, y1) = hext
            hsum[y0 - ty0:y1 - ty0, x0 - tx0:x1 - tx0] += him

            # "Heavy" without stray flux
            h2 = pkres.get_flux_portion(strayFlux=False)
            him2 = footprintToImage(h2).getArray()
            hext2 = getExtent(h2.getBBox())
            (x0, x1, y0, y1) = hext2
            hsum2[y0 - ty0:y1 - ty0, x0 - tx0:x1 - tx0] += him2

            if opt.sec == 'median':
                try:
                    med = pkres.median_filtered_template
                except Exception:
                    med = pkres.orig_template

                for im, nm in [(pkres.orig_template, 'symm'), (med, 'med')]:
                    # print('im:', im)
                    plt.clf()
                    myimshow(im.getArray(), extent=cext, **imargs)
                    plt.gray()
                    plt.xticks([])
                    plt.yticks([])
                    plt.plot([pk.getIx()], [pk.getIy()], **pksty)
                    plt.axis(pext)
                    savefig(pid, nm + '%i' % (kk))

            # Template
            plt.clf()
            myimshow(pkres.template_mimg.getImage().getArray() / w,
                     extent=cext,
                     **imargs)
            plt.gray()
            plt.xticks([])
            plt.yticks([])
            plt.plot([pk.getIx()], [pk.getIy()], **pksty)
            plt.axis(pext)
            savefig(pid, 't%i' % (kk))

            # Weighted template
            plt.clf()
            myimshow(tim, extent=cext, **imargs)
            plt.gray()
            plt.xticks([])
            plt.yticks([])
            plt.plot([pk.getIx()], [pk.getIy()], **pksty)
            plt.axis(pext)
            savefig(pid, 'tw%i' % (kk))

            # "Heavy"
            plt.clf()
            myimshow(him, extent=hext, **imargs)
            plt.gray()
            plt.xticks([])
            plt.yticks([])
            plt.plot([pk.getIx()], [pk.getIy()], **pksty)
            plt.axis(pext)
            savefig(pid, 'h%i' % (kk))

            # Original symmetric template
            plt.clf()
            t = pkres.orig_template
            foot = pkres.orig_foot
            myimshow(t.getArray(), extent=getExtent(foot.getBBox()), **imargs)
            plt.gray()
            plt.xticks([])
            plt.yticks([])
            plt.plot([pk.getIx()], [pk.getIy()], **pksty)
            plt.axis(pext)
            savefig(pid, 'o%i' % (kk))

            if opt.sec == 'patch' and pkres.patched:
                pass

            if opt.sec in ['ramp', 'ramp2'] and pkres.has_ramped_template:

                # Ramped template
                plt.clf()
                t = pkres.ramped_template
                myimshow(t.getArray(), extent=getExtent(t.getBBox()), **imargs)
                plt.gray()
                plt.xticks([])
                plt.yticks([])
                plt.plot([pk.getIx()], [pk.getIy()], **pksty)
                plt.axis(pext)
                savefig(pid, 'r%i' % (kk))

                # Median-filtered template
                plt.clf()
                t = pkres.median_filtered_template
                myimshow(t.getArray(), extent=getExtent(t.getBBox()), **imargs)
                plt.gray()
                plt.xticks([])
                plt.yticks([])
                plt.plot([pk.getIx()], [pk.getIy()], **pksty)
                plt.axis(pext)
                savefig(pid, 'med%i' % (kk))

                # Assigned flux
                plt.clf()
                t = pkres.portion_mimg.getImage()
                myimshow(t.getArray(), extent=getExtent(t.getBBox()), **imargs)
                plt.gray()
                plt.xticks([])
                plt.yticks([])
                plt.plot([pk.getIx()], [pk.getIy()], **pksty)
                plt.axis(pext)
                savefig(pid, 'p%i' % (kk))

            if opt.sec == 'ramp2':
                # stray flux
                if pkres.stray_flux is not None:
                    s = pkres.stray_flux
                    strayim = footprintToImage(s).getArray()
                    strayext = getExtent(s.getBBox())

                    plt.clf()
                    myimshow(strayim, extent=strayext, **imargs)
                    plt.gray()
                    plt.xticks([])
                    plt.yticks([])
                    plt.plot([pk.getIx()], [pk.getIy()], **pksty)
                    plt.axis(pext)
                    savefig(pid, 's%i' % (kk))

                    # Assigned flux, omitting stray flux.
                    plt.clf()
                    myimshow(him2, extent=hext2, **imargs)
                    plt.gray()
                    plt.xticks([])
                    plt.yticks([])
                    plt.plot([pk.getIx()], [pk.getIy()], **pksty)
                    plt.axis(pext)
                    savefig(pid, 'hb%i' % (kk))

            k += 1

        # sum of templates
        plt.clf()
        myimshow(tsum, extent=getExtent(tbb), **imargs)
        plt.gray()
        plt.xticks([])
        plt.yticks([])
        plt.plot([pk.getIx() for pk in pks], [pk.getIy() for pk in pks],
                 **pksty)
        plt.axis(pext)
        savefig(pid, 'tsum')

        # sum of assigned flux
        plt.clf()
        myimshow(hsum, extent=getExtent(tbb), **imargs)
        plt.gray()
        plt.xticks([])
        plt.yticks([])
        plt.plot([pk.getIx() for pk in pks], [pk.getIy() for pk in pks],
                 **pksty)
        plt.axis(pext)
        savefig(pid, 'hsum')

        plt.clf()
        myimshow(hsum2, extent=getExtent(tbb), **imargs)
        plt.gray()
        plt.xticks([])
        plt.yticks([])
        plt.plot([pk.getIx() for pk in pks], [pk.getIy() for pk in pks],
                 **pksty)
        plt.axis(pext)
        savefig(pid, 'hsum2')

        k = 0
        for pkres, pk in zip(res.peaks, pks):
            heavy = pkres.get_flux_portion()
            if heavy is None:
                continue

            print('Template footprint:', pkres.template_foot.getBBox())
            print('Template img:', pkres.template_mimg.getBBox())
            print('Heavy footprint:', heavy.getBBox())

            cfp = pkres.template_foot
            cbb = cfp.getBBox()
            cext = getExtent(cbb)
            tim = pkres.template_mimg.getImage().getArray()
            (x0, x1, y0, y1) = cext

            frac = tim / tsum[y0 - ty0:y1 - ty0, x0 - tx0:x1 - tx0]

            msk = afwImage.ImageF(cbb.getWidth(), cbb.getHeight())
            msk.setXY0(cbb.getMinX(), cbb.getMinY())
            afwDet.setImageFromFootprint(msk, cfp, 1.)
            msk = msk.getArray()
            frac[msk == 0.] = np.nan

            # Fraction of flux assigned to this child.
            plt.clf()
            plt.imshow(frac,
                       extent=cext,
                       interpolation='nearest',
                       origin='lower',
                       vmin=0,
                       vmax=1)
            # plt.plot([x0,x0,x1,x1,x0], [y0,y1,y1,y0,y0], 'k-')
            plt.gray()
            plt.xticks([])
            plt.yticks([])
            plt.plot([pk.getIx()], [pk.getIy()], **pksty)
            plt.gca().set_axis_bgcolor((0.9, 0.9, 0.5))
            plt.axis(pext)
            savefig(pid, 'f%i' % (mapchild(k)))

            k += 1
def _parseOptions():
    '''
    Build the command-line parser for the design-document plots and parse
    sys.argv.

    Also normalizes options whose defaults depend on other options: after
    this call ``opt.sec`` is never None (it defaults to 'sdss') and
    ``opt.suffix`` is non-empty ('.png' by default; '.pdf' is appended when
    --pdf is given).

    Returns the optparse options object.
    '''
    from optparse import OptionParser
    parser = OptionParser()
    parser.add_option('--root', dest='root', help='Root directory for Subaru data')
    parser.add_option('--outroot', '-o', dest='outroot', help='Output root directory for Subaru data')
    parser.add_option('--sources', help='Read a FITS table of sources')
    parser.add_option('--calexp', help='Read a FITS calexp')
    parser.add_option('--psf', help='Read a FITS PSF')

    parser.add_option('--drill', '-D', dest='drill', action='append', type=str, default=[],
                      help='Drill down on individual source IDs')
    parser.add_option('--drillxy', dest='drillxy', action='append', type=str, default=[],
                      help='Drill down on individual source positions, eg 132,46;54,67')
    parser.add_option('--visit', dest='visit', type=int, default=108792, help='Suprimecam visit id')
    parser.add_option('--ccd', dest='ccd', type=int, default=5, help='Suprimecam CCD number')
    parser.add_option('--prefix', dest='prefix', default='design-', help='plot filename prefix')
    parser.add_option('--suffix', dest='suffix', default=None, help='plot filename suffix (default: ".png")')
    parser.add_option('--pat', dest='pat', help='Plot filename pattern: eg, "design-%(pid)04i-%(name).png"; overrides --prefix and --suffix')
    parser.add_option('--pdf', dest='pdf', action='store_true', default=False, help='save in PDF format?')
    parser.add_option('-v', dest='verbose', action='store_true')
    parser.add_option('--figw', dest='figw', type=float, help='Figure window width (inches)',
                      default=4.)
    parser.add_option('--figh', dest='figh', type=float, help='Figure window height (inches)',
                      default=4.)
    parser.add_option('--order', dest='order', type=str, help='Child order: eg 3,0,1,2')

    parser.add_option('--sdss', dest='sec', action='store_const', const='sdss',
                      help='Produce plots for the SDSS section.')
    parser.add_option('--mono', dest='sec', action='store_const', const='mono',
                      help='Produce plots for the "monotonic" section.')
    parser.add_option('--median', dest='sec', action='store_const', const='median',
                      help='Produce plots for the "median filter" section.')
    parser.add_option('--ramp', dest='sec', action='store_const', const='ramp',
                      help='Produce plots for the "ramp edges" section.')
    parser.add_option('--ramp2', dest='sec', action='store_const', const='ramp2',
                      help='Produce plots for the "ramp edges + stray flux" section.')
    parser.add_option('--patch', dest='sec', action='store_const', const='patch',
                      help='Produce plots for the "patch edges" section.')

    opt, _ = parser.parse_args()

    if opt.sec is None:
        opt.sec = 'sdss'
    if opt.pdf:
        if opt.suffix is None:
            opt.suffix = ''
        opt.suffix += '.pdf'
    if not opt.suffix:
        opt.suffix = '.png'
    return opt


def _sectionDeblendKwargs(sec, sigma1, verbose):
    '''
    Return the keyword arguments given to the baseline deblender for one
    section of the design document.

    Every section starts from the same "basic" configuration, with all of
    the optional deblender behaviours switched off, and then enables the
    features that section illustrates.

    sec: section name ('sdss', 'mono', 'median', 'ramp', 'ramp2', 'patch').
    sigma1: image noise level passed through to the deblender.
    verbose: passed through to the deblender.

    Raises ValueError if *sec* is not a known section name.
    '''
    sectionOptions = {
        # SDSS intro
        'sdss': dict(lstsq_weight_templates=True),
        'mono': dict(lstsq_weight_templates=True,
                     monotonic_template=True),
        'median': dict(lstsq_weight_templates=True,
                       median_smooth_template=True,
                       monotonic_template=True),
        'ramp': dict(median_smooth_template=True,
                     monotonic_template=True,
                     rampFluxAtEdge=True),
        'ramp2': dict(median_smooth_template=True,
                      monotonic_template=True,
                      rampFluxAtEdge=True,
                      findStrayFlux=True,
                      assignStrayFlux=True),
        'patch': dict(median_smooth_template=True,
                      monotonic_template=True,
                      patchEdges=True),
    }
    if sec not in sectionOptions:
        # Raising a plain string is a TypeError on Python >= 2.6; raise a
        # real exception instead.
        raise ValueError('Unknown section: "%s"' % sec)

    kwargs = dict(sigma1=sigma1, verbose=verbose,
                  getTemplateSum=True,
                  fit_psfs=False,
                  median_smooth_template=False,
                  monotonic_template=False,
                  lstsq_weight_templates=False,
                  findStrayFlux=False,
                  rampFluxAtEdge=False,
                  patchEdges=False)
    kwargs.update(sectionOptions[sec])
    return kwargs


def main():
    '''
    Runs the deblender and creates plots for the "design document",
    doc/design.tex.  See the file NOTES for how to get set up to the
    point where you can actually run this on data.

    For every deblend family (parent plus children) in the source catalog,
    the parent footprint is deblended with the settings of the selected
    design-doc section (--sdss, --mono, --median, --ramp, --ramp2, --patch).
    One plot file is written per panel: the image cutout, the parent, the
    per-child templates and assigned flux, their sums, and the per-child
    flux fractions.  Filenames come from --pat, or from --prefix/--suffix.

    Messages are printed with a single pre-formatted string, so the output
    is the same under Python 2 and Python 3.
    '''
    opt = _parseOptions()

    # Logging
    root = pexLog.Log.getDefaultLog()
    if opt.verbose:
        root.setThreshold(pexLog.Log.DEBUG)
    else:
        root.setThreshold(pexLog.Log.INFO)
    # Quiet some of the more chatty loggers
    pexLog.Log(root, 'lsst.meas.deblender.symmetrizeFootprint',
               pexLog.Log.INFO)
    pexLog.Log(root, 'lsst.meas.deblender.getSignificantEdgePixels',
               pexLog.Log.INFO)
    pexLog.Log(root, 'afw.Mask', pexLog.Log.INFO)

    if opt.pat:
        plotpattern = opt.pat
    else:
        plotpattern = opt.prefix + '%(pid)04i-%(name)s' + opt.suffix

    # --order lists children in the order they should be plotted;
    # invorder[j] is the plot index of child j (its inverse permutation).
    invorder = None
    if opt.order is not None:
        opt.order = [int(x) for x in opt.order.split(',')]
        invorder = np.zeros(len(opt.order), int)
        invorder[opt.order] = np.arange(len(opt.order))

    def mapchild(i):
        # Plot index (used in filenames) for the i-th plotted child.
        if invorder is None:
            return i
        return invorder[i]

    def savefig(pid, figname):
        # Save the current figure using the plot filename pattern.
        plt.savefig(plotpattern % dict(pid=pid, name=figname))

    def subgrid(ext, x0ref, y0ref):
        # Index selecting extent ext = (x0, x1, y0, y1) from an array
        # whose element [0, 0] corresponds to pixel (x0ref, y0ref).
        x0, x1, y0, y1 = ext
        return (slice(y0 - y0ref, y1 - y0ref), slice(x0 - x0ref, x1 - x0ref))

    # Load data using the butler, if desired
    dr = None
    if opt.sources is None or opt.calexp is None:
        print('Creating DataRef...')
        dr = getSuprimeDataref(opt.visit, opt.ccd, rootdir=opt.root,
                               outrootdir=opt.outroot)
        print('Got %s' % (dr,))

    # Which parent ids / deblend families are we going to plot?
    keepids = None
    if opt.drill:
        keepids = [int(dd) for d in opt.drill for dd in d.split(',')]
        print('Keeping parent ids %s' % (keepids,))

    keepxys = None
    if opt.drillxy:
        keepxys = []
        for d in opt.drillxy:
            for dd in d.split(';'):
                xy = dd.split(',')
                # User input: validate with an exception, not an assert
                # (asserts are stripped under python -O).
                if len(xy) != 2:
                    raise ValueError('--drillxy expects "x,y" pairs, got "%s"' % dd)
                keepxys.append((int(xy[0]), int(xy[1])))
        print('Keeping parents at xy %s' % (keepxys,))

    # Read from butler or local file
    cat = readCatalog(opt.sources, None, dataref=dr, keepids=keepids,
                      keepxys=keepxys, patargs=dict(visit=opt.visit, ccd=opt.ccd))
    print('Got %i sources' % len(cat))

    # Load data from butler or local files
    if opt.calexp is not None:
        print('Reading exposure from %s' % opt.calexp)
        exposure = afwImage.ExposureF(opt.calexp)
    else:
        exposure = dr.get('calexp')
    print('Exposure %s' % (exposure,))
    mi = exposure.getMaskedImage()

    if opt.psf is not None:
        print('Reading PSF from %s' % opt.psf)
        psf = afwDet.Psf.readFits(opt.psf)
        print('Got %s' % (psf,))
    elif dr is not None:
        psf = dr.get('psf')
    else:
        psf = exposure.getPsf()

    sigma1 = get_sigma1(mi)

    fams = getFamilies(cat)
    print('%i deblend families' % len(fams))

    def nlmap(X):
        # Non-linear (arcsinh) intensity stretch, scaled by sigma1.
        return np.arcsinh(X / (3. * sigma1))

    def myimshow(im, **kwargs):
        # plt.imshow through the arcsinh stretch.  vmin/vmax are given in
        # linear units (defaults -5 and 100 sigma1) and stretched to match.
        kwargs = kwargs.copy()
        kwargs['vmin'] = nlmap(kwargs.get('vmin', -5 * sigma1))
        kwargs['vmax'] = nlmap(kwargs.get('vmax', 100 * sigma1))
        plt.imshow(nlmap(im), **kwargs)

    # Marker style used to overplot peak positions.
    pksty = dict(linestyle='None', marker='+', color='r', mew=3, ms=20, alpha=0.6)

    def plotPanel(pid, figname, img, imargs, extent=None, peaks=None, axlim=None):
        # Draw img as a tick-free grayscale panel, optionally overplot peak
        # positions and set the axis limits, then save it as figname.
        plt.clf()
        myimshow(img, extent=extent, **imargs)
        plt.gray()
        plt.xticks([])
        plt.yticks([])
        if peaks is not None:
            plt.plot([pk.getIx() for pk in peaks], [pk.getIy() for pk in peaks],
                     **pksty)
        if axlim is not None:
            plt.axis(axlim)
        savefig(pid, figname)

    plt.figure(figsize=(opt.figw, opt.figh))
    plt.subplot(1, 1, 1)
    plt.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.99,
                        wspace=0.05, hspace=0.1)

    from lsst.meas.deblender.baseline import deblend

    # Each section of the design doc runs the deblender with different
    # args; they do not depend on the family, so compute (and validate) once.
    deblendKwargs = _sectionDeblendKwargs(opt.sec, sigma1, opt.verbose)

    # Make plots for each deblend family.
    for parent, children in fams:
        pid = parent.getId()
        print('parent %s' % (pid,))
        print('children %s' % ([ch.getId() for ch in children],))
        print('parent x,y %s %s' % (parent.getX(), parent.getY()))

        fp = parent.getFootprint()
        bb = fp.getBBox()
        pks = fp.getPeaks()
        pim = footprintToImage(fp, mi).getArray()
        pext = getExtent(bb)
        imargs = dict(interpolation='nearest', origin='lower',
                      vmax=pim.max() * 0.95, vmin=-3. * sigma1)

        # Image cutout, then the parent footprint with its peaks.
        plotPanel(pid, 'image', afwImage.ImageF(mi.getImage(), bb).getArray(), imargs)
        plotPanel(pid, 'parent', pim, imargs, extent=pext, peaks=pks, axlim=pext)

        xc = int((bb.getMinX() + bb.getMaxX()) / 2.)
        yc = int((bb.getMinY() + bb.getMaxY()) / 2.)
        if hasattr(psf, 'getFwhm'):
            psf_fwhm = psf.getFwhm(xc, yc)
        else:
            # 2.35 ~= 2 sqrt(2 ln 2), the Gaussian sigma -> FWHM factor.
            psf_fwhm = psf.computeShape().getDeterminantRadius() * 2.35

        print('Running deblender with kwargs: %s' % (deblendKwargs,))
        res = deblend(fp, mi, psf, psf_fwhm, **deblendKwargs)

        # Find bounding-box of all templates (plus the parent footprint).
        tbb = fp.getBBox()
        for pkres, _ in zip(res.peaks, pks):
            tbb.include(pkres.template_foot.getBBox())
        print('Bounding-box of all templates: %s' % (tbb,))
        tx0, ty0 = tbb.getMinX(), tbb.getMinY()

        # Accumulators on the tbb pixel grid ("heavy" bbox == template bbox).
        sumShape = (tbb.getHeight(), tbb.getWidth())
        tsum = np.zeros(sumShape)   # sum of weighted templates
        hsum = np.zeros(sumShape)   # sum of flux assigned to children
        hsum2 = np.zeros(sumShape)  # same, excluding stray flux

        # Sum of templates from the deblender itself
        t = res.templateSum
        plotPanel(pid, 'tsum1', t.getArray(), imargs, extent=getExtent(t.getBBox()))

        # Make plots for each deblended child (peak)
        k = 0
        for pkres, pk in zip(res.peaks, pks):
            heavy = pkres.get_flux_portion()
            if heavy is None:
                print('Child has no HeavyFootprint -- skipping')
                continue

            kk = mapchild(k)
            onePeak = [pk]
            w = pkres.template_weight
            cext = getExtent(pkres.template_foot.getBBox())

            # Weighted template image
            tim = pkres.template_mimg.getImage().getArray()
            print('tim ext %s' % (cext,))
            tsum[subgrid(cext, tx0, ty0)] += tim

            # "Heavy" image -- flux assigned to child
            him = footprintToImage(heavy).getArray()
            hext = getExtent(heavy.getBBox())
            hsum[subgrid(hext, tx0, ty0)] += him

            # "Heavy" without stray flux
            h2 = pkres.get_flux_portion(strayFlux=False)
            him2 = footprintToImage(h2).getArray()
            hext2 = getExtent(h2.getBBox())
            hsum2[subgrid(hext2, tx0, ty0)] += him2

            if opt.sec == 'median':
                # Best effort: fall back to the symmetric template when no
                # median-filtered template is available.
                try:
                    med = pkres.median_filtered_template
                except Exception:
                    med = pkres.orig_template
                for im, nm in [(pkres.orig_template, 'symm'), (med, 'med')]:
                    plotPanel(pid, nm + '%i' % kk, im.getArray(), imargs,
                              extent=cext, peaks=onePeak, axlim=pext)

            # Template (weight divided out), weighted template, assigned flux
            plotPanel(pid, 't%i' % kk, tim / w, imargs,
                      extent=cext, peaks=onePeak, axlim=pext)
            plotPanel(pid, 'tw%i' % kk, tim, imargs,
                      extent=cext, peaks=onePeak, axlim=pext)
            plotPanel(pid, 'h%i' % kk, him, imargs,
                      extent=hext, peaks=onePeak, axlim=pext)

            # Original symmetric template
            plotPanel(pid, 'o%i' % kk, pkres.orig_template.getArray(), imargs,
                      extent=getExtent(pkres.orig_foot.getBBox()),
                      peaks=onePeak, axlim=pext)

            # TODO: the 'patch' section has no per-child patched-template
            # plot yet.

            if opt.sec in ('ramp', 'ramp2') and pkres.has_ramped_template:
                # Ramped template
                t = pkres.ramped_template
                plotPanel(pid, 'r%i' % kk, t.getArray(), imargs,
                          extent=getExtent(t.getBBox()), peaks=onePeak, axlim=pext)
                # Median-filtered template
                t = pkres.median_filtered_template
                plotPanel(pid, 'med%i' % kk, t.getArray(), imargs,
                          extent=getExtent(t.getBBox()), peaks=onePeak, axlim=pext)
                # Assigned flux
                t = pkres.portion_mimg.getImage()
                plotPanel(pid, 'p%i' % kk, t.getArray(), imargs,
                          extent=getExtent(t.getBBox()), peaks=onePeak, axlim=pext)

            if opt.sec == 'ramp2' and pkres.stray_flux is not None:
                # Stray flux
                s = pkres.stray_flux
                plotPanel(pid, 's%i' % kk, footprintToImage(s).getArray(), imargs,
                          extent=getExtent(s.getBBox()), peaks=onePeak, axlim=pext)
                # Assigned flux, omitting stray flux.
                plotPanel(pid, 'hb%i' % kk, him2, imargs,
                          extent=hext2, peaks=onePeak, axlim=pext)

            k += 1

        # Sum of templates, sum of assigned flux, and sum of assigned flux
        # without stray flux.
        tbbext = getExtent(tbb)
        for arr, name in [(tsum, 'tsum'), (hsum, 'hsum'), (hsum2, 'hsum2')]:
            plotPanel(pid, name, arr, imargs, extent=tbbext, peaks=pks, axlim=pext)

        # Fraction of flux assigned to each child.
        k = 0
        for pkres, pk in zip(res.peaks, pks):
            heavy = pkres.get_flux_portion()
            if heavy is None:
                continue

            print('Template footprint: %s' % (pkres.template_foot.getBBox(),))
            print('Template img: %s' % (pkres.template_mimg.getBBox(),))
            print('Heavy footprint: %s' % (heavy.getBBox(),))

            cfp = pkres.template_foot
            cbb = cfp.getBBox()
            cext = getExtent(cbb)
            tim = pkres.template_mimg.getImage().getArray()

            # Pixels where the template sum is zero give 0/0 or x/0; silence
            # numpy's warnings (such pixels are typically blanked below).
            with np.errstate(divide='ignore', invalid='ignore'):
                frac = tim / tsum[subgrid(cext, tx0, ty0)]

            # Blank out (NaN) pixels outside this child's template footprint.
            msk = afwImage.ImageF(cbb.getWidth(), cbb.getHeight())
            msk.setXY0(cbb.getMinX(), cbb.getMinY())
            afwDet.setImageFromFootprint(msk, cfp, 1.)
            frac[msk.getArray() == 0.] = np.nan

            plt.clf()
            plt.imshow(frac, extent=cext, interpolation='nearest',
                       origin='lower', vmin=0, vmax=1)
            plt.gray()
            plt.xticks([])
            plt.yticks([])
            plt.plot([pk.getIx()], [pk.getIy()], **pksty)
            # Pale-yellow axes background, visible through NaN pixels.
            # Axes.patch.set_facecolor works on old and new matplotlib,
            # whereas Axes.set_axis_bgcolor was removed in matplotlib 2.2.
            plt.gca().patch.set_facecolor((0.9, 0.9, 0.5))
            plt.axis(pext)
            savefig(pid, 'f%i' % mapchild(k))

            k += 1
# Example #18
# 0
    def deblend(self, exposure, srcs, psf):
        """!
        Deblend.

        @param[in]     exposure Exposure to process
        @param[in,out] srcs     SourceCatalog containing sources detected on this exposure.
        @param[in]     psf      PSF

        @return None
        """
        self.log.info("Deblending %d sources" % len(srcs))

        from lsst.meas.deblender.baseline import deblend

        # find the median stdev in the image...
        mi = exposure.getMaskedImage()
        statsCtrl = afwMath.StatisticsControl()
        statsCtrl.setAndMask(mi.getMask().getPlaneBitMask(
            self.config.maskPlanes))
        stats = afwMath.makeStatistics(mi.getVariance(), mi.getMask(),
                                       afwMath.MEDIAN, statsCtrl)
        sigma1 = math.sqrt(stats.getValue(afwMath.MEDIAN))
        self.log.trace('sigma1: %g', sigma1)

        n0 = len(srcs)
        nparents = 0
        for i, src in enumerate(srcs):
            # t0 = time.clock()

            fp = src.getFootprint()
            pks = fp.getPeaks()

            # Since we use the first peak for the parent object, we should propagate its flags
            # to the parent source.
            src.assign(pks[0], self.peakSchemaMapper)

            if len(pks) < 2:
                continue

            if self.isLargeFootprint(fp):
                src.set(self.tooBigKey, True)
                self.skipParent(src, mi.getMask())
                self.log.warn('Parent %i: skipping large footprint (area: %i)',
                              int(src.getId()), int(fp.getArea()))
                continue
            if self.isMasked(fp, exposure.getMaskedImage().getMask()):
                src.set(self.maskedKey, True)
                self.skipParent(src, mi.getMask())
                self.log.warn(
                    'Parent %i: skipping masked footprint (area: %i)',
                    int(src.getId()), int(fp.getArea()))
                continue

            nparents += 1
            bb = fp.getBBox()
            psf_fwhm = self._getPsfFwhm(psf, bb)

            self.log.trace('Parent %i: deblending %i peaks', int(src.getId()),
                           len(pks))

            self.preSingleDeblendHook(exposure, srcs, i, fp, psf, psf_fwhm,
                                      sigma1)
            npre = len(srcs)

            # This should really be set in deblend, but deblend doesn't have access to the src
            src.set(self.tooManyPeaksKey,
                    len(fp.getPeaks()) > self.config.maxNumberOfPeaks)

            try:
                res = deblend(
                    fp,
                    mi,
                    psf,
                    psf_fwhm,
                    sigma1=sigma1,
                    psfChisqCut1=self.config.psfChisq1,
                    psfChisqCut2=self.config.psfChisq2,
                    psfChisqCut2b=self.config.psfChisq2b,
                    maxNumberOfPeaks=self.config.maxNumberOfPeaks,
                    strayFluxToPointSources=self.config.
                    strayFluxToPointSources,
                    assignStrayFlux=self.config.assignStrayFlux,
                    strayFluxAssignment=self.config.strayFluxRule,
                    rampFluxAtEdge=(self.config.edgeHandling == 'ramp'),
                    patchEdges=(self.config.edgeHandling == 'noclip'),
                    tinyFootprintSize=self.config.tinyFootprintSize,
                    clipStrayFluxFraction=self.config.clipStrayFluxFraction,
                    weightTemplates=self.config.weightTemplates,
                    removeDegenerateTemplates=self.config.
                    removeDegenerateTemplates,
                    maxTempDotProd=self.config.maxTempDotProd,
                    medianSmoothTemplate=self.config.medianSmoothTemplate)
                if self.config.catchFailures:
                    src.set(self.deblendFailedKey, False)
            except Exception as e:
                if self.config.catchFailures:
                    self.log.warn("Unable to deblend source %d: %s" %
                                  (src.getId(), e))
                    src.set(self.deblendFailedKey, True)
                    import traceback
                    traceback.print_exc()
                    continue
                else:
                    raise

            kids = []
            nchild = 0
            for j, peak in enumerate(res.deblendedParents[0].peaks):
                heavy = peak.getFluxPortion()
                if heavy is None or peak.skip:
                    src.set(self.deblendSkippedKey, True)
                    if not self.config.propagateAllPeaks:
                        # Don't care
                        continue
                    # We need to preserve the peak: make sure we have enough info to create a minimal
                    # child src
                    self.log.trace(
                        "Peak at (%i,%i) failed.  Using minimal default info for child.",
                        pks[j].getIx(), pks[j].getIy())
                    if heavy is None:
                        # copy the full footprint and strip out extra peaks
                        foot = afwDet.Footprint(src.getFootprint())
                        peakList = foot.getPeaks()
                        peakList.clear()
                        peakList.append(peak.peak)
                        zeroMimg = afwImage.MaskedImageF(foot.getBBox())
                        heavy = afwDet.makeHeavyFootprint(foot, zeroMimg)
                    if peak.deblendedAsPsf:
                        if peak.psfFitFlux is None:
                            peak.psfFitFlux = 0.0
                        if peak.psfFitCenter is None:
                            peak.psfFitCenter = (peak.peak.getIx(),
                                                 peak.peak.getIy())

                assert (len(heavy.getPeaks()) == 1)

                src.set(self.deblendSkippedKey, False)
                child = srcs.addNew()
                nchild += 1
                for key in self.toCopyFromParent:
                    child.set(key, src.get(key))
                child.assign(heavy.getPeaks()[0], self.peakSchemaMapper)
                child.setParent(src.getId())
                child.setFootprint(heavy)
                child.set(self.psfKey, peak.deblendedAsPsf)
                child.set(self.hasStrayFluxKey, peak.strayFlux is not None)
                if peak.deblendedAsPsf:
                    (cx, cy) = peak.psfFitCenter
                    child.set(self.psfCenterKey, geom.Point2D(cx, cy))
                    child.set(self.psfFluxKey, peak.psfFitFlux)
                child.set(self.deblendRampedTemplateKey,
                          peak.hasRampedTemplate)
                child.set(self.deblendPatchedTemplateKey, peak.patched)

                # Set the position of the peak from the parent footprint
                # This will make it easier to match the same source across
                # deblenders and across observations, where the peak
                # position is unlikely to change unless enough time passes
                # for a source to move on the sky.
                child.set(self.peakCenter,
                          geom.Point2I(pks[j].getIx(), pks[j].getIy()))
                child.set(self.peakIdKey, pks[j].getId())

                # The children have a single peak
                child.set(self.nPeaksKey, 1)
                # Set the number of peaks in the parent
                child.set(self.parentNPeaksKey, len(pks))

                kids.append(child)

            # Child footprints may extend beyond the full extent of their parent's which
            # results in a failure of the replace-by-noise code to reinstate these pixels
            # to their original values.  The following updates the parent footprint
            # in-place to ensure it contains the full union of itself and all of its
            # children's footprints.
            spans = src.getFootprint().spans
            for child in kids:
                spans = spans.union(child.getFootprint().spans)
            src.getFootprint().setSpans(spans)

            src.set(self.nChildKey, nchild)

            self.postSingleDeblendHook(exposure, srcs, i, npre, kids, fp, psf,
                                       psf_fwhm, sigma1, res)
            # print('Deblending parent id', src.getId(), 'took', time.clock() - t0)

        n1 = len(srcs)
        self.log.info(
            'Deblended: of %i sources, %i were deblended, creating %i children, total %i sources'
            % (n0, nparents, n1 - n0, n1))
    print 'deblend, measure', filter_
    exposure = exposures[filter_]
    fwhm = exposure.getPsf().computeShape().getDeterminantRadius() * 2.35
    sources = afwTable.SourceCatalog(merged_sources)

    for ii, src in enumerate(sources):
        deb = deblend(src.getFootprint(),
                      exposure.getMaskedImage(),
                      exposure.getPsf(),
                      fwhm,
                      verbose=False,
                      weightTemplates=False,
                      maxNumberOfPeaks=0,
                      rampFluxAtEdge=True,
                      assignStrayFlux=True,
                      strayFluxAssignment='trim',
                      strayFluxToPointSources='necessary',
                      findStrayFlux=True,
                      clipStrayFluxFraction=0.001,
                      psfChisqCut1=1.5,
                      psfChisqCut2=1.5,
                      monotonicTemplate=True,
                      medianSmoothTemplate=True,
                      medianFilterHalfsize=2,
                      tinyFootprintSize=2,
                      clipFootprintToNonzero=True)

        parent = src
        parent.assign(deb.peaks[0].peak, peakSchemaMapper)
        parent.setParent(0)
        parent.setFootprint(src.getFootprint())