Пример #1
0
def stage_tune(tims=None,
               cat=None,
               targetwcs=None,
               coimgs=None,
               cons=None,
               bands=None,
               invvars=None,
               brickid=None,
               Tcat=None,
               version_header=None,
               ps=None,
               **kwargs):
    """Prune and simplify the source catalog, then write tuned coadd products.

    The stage runs these steps in order:

    1. Prepare each tim for rendering (``modelMinval``, a ``CachingPsfEx``
       PSF built from ``tim.psfex``) and set source radii.
    2. Plot the image coadd with source markers (via ``ps.savefig()``) and
       write ``tunebrick/coadd/image-<brickid>-full.jpg``.
    3. Drop sources whose parameters are not all finite, keeping
       ``invvars`` and ``Tcat`` aligned with ``cat``.
    4. Render one model image per tim and plot the model coadd.
    5. For each source in turn (updating the per-tim model images in place):
       convert composite galaxies whose clipped ``fracDev`` is exactly 0 or 1
       to Exp / Dev; replace non-point sources with r mag < 12.5 by a
       PointSource; drop the source if removing it raises the
       log-likelihood; otherwise try simpler model types and switch to the
       best candidate with dlnp >= 0.
    6. Write per-band chi2, image (+ ivar HDU), depth, model and nexp FITS
       files under ``tunebrick/coadd/``, plot the final model with sources,
       and write ``tunebrick/coadd/model-<brickid>-full.jpg``.

    Args:
        tims: list of images; each must provide ``sig1``, ``psfex``,
            ``wcs``, ``band``, ``shape``, ``psf_sigma``, ``getImage()`` and
            ``getInvError()``.  Modified in place (``modelMinval``,
            ``psfex.radius``, ``psf``).
        cat: sequence of sources.
        targetwcs: WCS of the output (brick) pixel grid.
        coimgs: per-band image coadds, indexed like ``bands``.
        cons: per-band per-pixel counts, indexed like ``bands``; used as the
            divisor (floored at 1) when averaging model coadds.
        bands: band names.
        invvars: flat sequence of inverse variances, one per catalog
            parameter, consumed ``numberOfParams()`` at a time in catalog
            order.
        brickid: integer brick id, used in ``%06i`` output filenames.
        Tcat: table with one row per source in ``cat``; cut in place so it
            stays row-aligned with the returned catalog.
        version_header: not used in this body.
        ps: plot saver; only ``ps.savefig()`` is called.
        **kwargs: only their keys are printed.

    Returns:
        dict(cat=<list of kept / modified sources>, Tcat=<cut Tcat>,
             invvars=<numpy array, one entry per parameter of ``cat``>).
    """
    tstage = t0 = Time()
    print('kwargs:', kwargs.keys())

    #print 'invvars:', invvars

    # How far down to render model profiles
    # (expressed as a multiple of each tim's per-pixel noise sig1).
    minsigma = 0.1
    for tim in tims:
        tim.modelMinval = minsigma * tim.sig1

    # Caching PSF
    # Replace each tim's PSF with a CachingPsfEx built from its PsfEx model.
    for tim in tims:
        from tractor.psfex import CachingPsfEx
        tim.psfex.radius = 20
        tim.psfex.fitSavedData(*tim.psfex.splinedata)
        tim.psf = CachingPsfEx.fromPsfEx(tim.psfex)

    orig_wcsxy0 = [tim.wcs.getX0Y0() for tim in tims]
    set_source_radii(bands, orig_wcsxy0, tims, cat, minsigma)

    # Square figure with (almost) no margins, for image-like plots.
    plt.figure(figsize=(10, 10))
    plt.subplots_adjust(left=0.002, right=0.998, bottom=0.002, top=0.998)

    plt.clf()
    rgb = get_rgb(coimgs, bands)
    dimshow(rgb)
    #plt.title('Image')
    ps.savefig()

    # Save the RGB image coadd as a PNG, then convert it to JPEG with the
    # external pngtopnm | pnmtojpeg commands.
    # NOTE(review): shell command built with %-formatting and run via
    # os.system; the inputs are a temp filename and an integer brick id.
    # The exit status is ignored.
    tmpfn = create_temp(suffix='.png')
    plt.imsave(tmpfn, rgb)
    del rgb
    cmd = 'pngtopnm %s | pnmtojpeg -quality 90 > tunebrick/coadd/image-%06i-full.jpg' % (
        tmpfn, brickid)
    os.system(cmd)
    os.unlink(tmpfn)

    # Marker style shared by both source-overlay plots.
    pla = dict(ms=5, mew=1)

    # Overlay source positions: '+' for point sources, open circles otherwise.
    # The "- 1" presumably converts 1-indexed WCS pixel coords to the
    # 0-indexed image plot -- TODO confirm.
    ax = plt.axis()
    for i, src in enumerate(cat):
        rd = src.getPosition()
        ok, x, y = targetwcs.radec2pixelxy(rd.ra, rd.dec)
        cc = (0, 1, 0)
        if isinstance(src, PointSource):
            plt.plot(x - 1, y - 1, '+', color=cc, **pla)
        else:
            plt.plot(x - 1, y - 1, 'o', mec=cc, mfc='none', **pla)
        # plt.text(x, y, '%i' % i, color=cc, ha='center', va='bottom')
    plt.axis(ax)
    ps.savefig()

    print('Plots:', Time() - t0)

    # print 'Catalog:'
    # for src in cat:
    #     print '  ', src
    # switch_to_soft_ellipses(cat)

    # invvars must hold exactly one entry per catalog parameter.
    assert (Catalog(*cat).numberOfParams() == len(invvars))

    # Drop sources with any non-finite parameter.  invvars is walked in step
    # with the catalog (N entries per source), and the indices of kept
    # sources are recorded so Tcat can be cut to match.
    keepcat = []
    keepinvvars = []
    iterinvvars = invvars
    ikeep = []
    for i, src in enumerate(cat):
        N = src.numberOfParams()
        iv = iterinvvars[:N]
        iterinvvars = iterinvvars[N:]
        if not np.all(np.isfinite(src.getParams())):
            print('Dropping source:', src)
            continue
        keepcat.append(src)
        keepinvvars.extend(iv)
        #print 'Keep:', src
        #print 'iv:', iv
        #print 'sigma', 1./np.sqrt(np.array(iv))
        ikeep.append(i)
    cat = keepcat
    Tcat.cut(np.array(ikeep))
    invvars = keepinvvars
    print(len(cat), 'sources with finite params')
    assert (Catalog(*cat).numberOfParams() == len(invvars))
    assert (len(iterinvvars) == 0)

    # Render one full model image per tim for the current catalog.  These
    # arrays are updated in place below as sources are removed or changed.
    print('Rendering model images...')
    t0 = Time()
    mods = _map(_get_mod, [(tim, cat) for tim in tims])
    print('Getting model images:', Time() - t0)

    wcsW = targetwcs.get_width()
    wcsH = targetwcs.get_height()

    # Per-band model coadd: resample each tim's model onto the brick grid,
    # sum, and divide by the per-pixel count (floored at 1).  Only plotted.
    t0 = Time()
    comods = []
    for iband, band in enumerate(bands):
        comod = np.zeros((wcsH, wcsW), np.float32)
        for itim, (tim, mod) in enumerate(zip(tims, mods)):
            if tim.band != band:
                continue
            R = tim_get_resamp(tim, targetwcs)
            if R is None:
                continue
            (Yo, Xo, Yi, Xi) = R
            comod[Yo, Xo] += mod[Yi, Xi]
        comod /= np.maximum(cons[iband], 1)
        comods.append(comod)
    print('Creating model coadd:', Time() - t0)

    plt.clf()
    dimshow(get_rgb(comods, bands))
    plt.title('Model')
    ps.savefig()
    del comods

    # Model selection, one source at a time.  invvars is again consumed N
    # entries per source; removed sources contribute nothing to keepcat /
    # keepinvvars / ikeep.
    t0 = Time()
    keepinvvars = []
    keepcat = []
    iterinvvars = invvars
    ikeep = []
    for isrc, src in enumerate(cat):
        # Replacement invvars for this source (None = keep oldiv).
        newiv = None
        N = src.numberOfParams()

        gc = get_galaxy_cache()
        print('Galaxy cache:', gc)
        if gc is not None:
            gc.clear()

        print('Checking source', isrc, 'of', len(cat), ':', src)
        #print 'N params:', N
        #print 'iterinvvars:', len(iterinvvars)

        oldiv = iterinvvars[:N]
        iterinvvars = iterinvvars[N:]
        # Set when the source type changes in a way that requires the
        # invvars to be re-derived from the model derivatives.
        recompute_iv = False

        if isinstance(src, FixedCompositeGalaxy):
            # Obvious simplification: for composite galaxies with fracdev
            # out of bounds, convert to exp or dev.
            f = src.fracDev.getClippedValue()
            if f == 0.:
                oldsrc = src
                src = ExpGalaxy(oldsrc.pos, oldsrc.brightness, oldsrc.shapeExp)
                print('Converted comp to exp')
                #print '   ', oldsrc
                #print ' ->', src
                # pull out the invvar elements!
                # Trick: temporarily load oldiv into oldsrc's parameter
                # slots so each sub-component's getParams() yields the
                # invvars for just that component, then restore the new
                # source's params from pp.  The save/restore is needed
                # because the new source was built from the same pos /
                # brightness / shape objects -- presumably stored by
                # reference; confirm.
                pp = src.getParams()
                oldsrc.setParams(oldiv)
                newiv = oldsrc.pos.getParams() + oldsrc.brightness.getParams(
                ) + oldsrc.shapeExp.getParams()
                src.setParams(pp)
            elif f == 1.:
                # Same invvar-extraction trick as the exp branch above.
                oldsrc = src
                src = DevGalaxy(oldsrc.pos, oldsrc.brightness, oldsrc.shapeDev)
                print('Converted comp to dev')
                ##print '   ', oldsrc
                print(' ->', src)
                pp = src.getParams()
                oldsrc.setParams(oldiv)
                newiv = oldsrc.pos.getParams() + oldsrc.brightness.getParams(
                ) + oldsrc.shapeDev.getParams()
                src.setParams(pp)

        # treated_as_pointsource: do the bright-star check at least!
        if not isinstance(src, PointSource):
            # This is the check we use in unWISE
            if src.getBrightness().getMag('r') < 12.5:
                oldsrc = src
                src = PointSource(oldsrc.pos, oldsrc.brightness)
                print('Bright star: replacing', oldsrc)
                print('With', src)
                # Not QUITE right.
                #oldsrc.setParams(oldiv)
                #newiv = oldsrc.pos.getParams() + oldsrc.brightness.getParams()
                # NOTE(review): mods still contain the original (galaxy)
                # rendering of this source, while the patches computed
                # below come from the new PointSource, so the removal /
                # change tests compare against a slightly inconsistent
                # model -- confirm this is acceptable.
                recompute_iv = True

        #print 'Try removing source:', src
        tsrc = Time()

        # Collect (tim index, patch) for every tim this source's model
        # touches, clipped to the tim's bounds; empty patches are skipped.
        srcmodlist = []
        for itim, tim in enumerate(tims):
            patch = src.getModelPatch(tim)
            if patch is None:
                continue
            if patch.patch is None:
                continue

            # HACK -- this shouldn't be necessary, but seems to be!
            # FIXME -- track down why patches are being made with extent outside
            # that of the parent!
            H, W = tim.shape
            if patch.x0 < 0 or patch.y0 < 0 or patch.x1 > W or patch.y1 > H:
                print('Warning: Patch extends outside tim bounds:')
                print('patch extent:', patch.getExtent())
                print('image size:', W, 'x', H)
            patch.clipTo(W, H)
            ph, pw = patch.shape
            if pw * ph == 0:
                continue
            srcmodlist.append((itim, patch))

        # Try removing the source from the model;
        # check chi-squared change in the patches.
        # sdlnp = -0.5 * (chi2 without source - chi2 with source), summed over
        # the touched pixels; > 0 means the fit is better without it.
        sdlnp = 0.
        for itim, patch in srcmodlist:
            tim = tims[itim]
            mod = mods[itim]
            slc = patch.getSlice(tim)
            simg = tim.getImage()[slc]
            sie = tim.getInvError()[slc]
            chisq0 = np.sum(((simg - mod[slc]) * sie)**2)
            chisq1 = np.sum(((simg - (mod[slc] - patch.patch)) * sie)**2)
            sdlnp += -0.5 * (chisq1 - chisq0)
        print('Removing source: dlnp =', sdlnp)
        print('Testing source removal:', Time() - tsrc)

        if sdlnp > 0:
            #print 'Removing source!'
            # Subtract the source from the running models and drop it (it is
            # not added to keepcat / ikeep).
            for itim, patch in srcmodlist:
                patch.addTo(mods[itim], scale=-1)
            continue

        # Try some model changes...
        # Candidate simpler models sharing the same pos / brightness / shape.
        newsrcs = []
        if isinstance(src, FixedCompositeGalaxy):
            newsrcs.append(ExpGalaxy(src.pos, src.brightness, src.shapeExp))
            newsrcs.append(DevGalaxy(src.pos, src.brightness, src.shapeDev))
            newsrcs.append(PointSource(src.pos, src.brightness))
        elif isinstance(src, (DevGalaxy, ExpGalaxy)):
            newsrcs.append(PointSource(src.pos, src.brightness))

        bestnew = None
        bestdlnp = 0.
        bestdpatches = None

        # Current-source patches indexed by tim (None where it has none).
        srcmodlist2 = [None for tim in tims]
        for itim, patch in srcmodlist:
            srcmodlist2[itim] = patch

        for newsrc in newsrcs:

            # For each tim, the model difference (new - old) and its effect
            # on the log-likelihood.
            dpatches = []
            dlnp = 0.
            for itim, tim in enumerate(tims):
                patch = newsrc.getModelPatch(tim)
                if patch is not None:
                    if patch.patch is None:
                        patch = None
                if patch is not None:
                    # HACK -- this shouldn't be necessary, but seems to be!
                    # FIXME -- track down why patches are being made with extent outside
                    # that of the parent!
                    H, W = tim.shape
                    patch.clipTo(W, H)
                    ph, pw = patch.shape
                    if pw * ph == 0:
                        patch = None

                oldpatch = srcmodlist2[itim]
                if oldpatch is None and patch is None:
                    continue

                # Find difference in models
                if oldpatch is None:
                    dpatch = patch
                elif patch is None:
                    dpatch = oldpatch * -1.
                else:
                    dpatch = patch - oldpatch
                dpatches.append((itim, dpatch))

                mod = mods[itim]
                slc = dpatch.getSlice(tim)
                simg = tim.getImage()[slc]
                sie = tim.getInvError()[slc]
                chisq0 = np.sum(((simg - mod[slc]) * sie)**2)
                chisq1 = np.sum(((simg - (mod[slc] + dpatch.patch)) * sie)**2)
                dlnp += -0.5 * (chisq1 - chisq0)

            #print 'Trying source change:'
            #print 'from', src
            #print '  to', newsrc
            print('Trying source change to',
                  type(newsrc).__name__, ': dlnp =', dlnp)

            # ">=" with bestdlnp starting at 0: a change that leaves dlnp
            # exactly unchanged is accepted, and ties go to the later
            # candidate in newsrcs.
            if dlnp >= bestdlnp:
                bestnew = newsrc
                bestdlnp = dlnp
                bestdpatches = dpatches

        if bestnew is not None:
            print('Found model improvement!  Switching to', end=' ')
            print(bestnew)
            # Apply the winning model difference to the running models.
            for itim, dpatch in bestdpatches:
                dpatch.addTo(mods[itim])
            src = bestnew
            recompute_iv = True

        del srcmodlist
        del srcmodlist2

        if recompute_iv:
            # Invvar per parameter = sum over tims and pixels of
            # (d model / d param * inverse error)^2.
            dchisqs = np.zeros(src.numberOfParams())
            for tim in tims:
                derivs = src.getParamDerivatives(tim)
                h, w = tim.shape
                ie = tim.getInvError()
                for i, deriv in enumerate(derivs):
                    if deriv is None:
                        continue
                    deriv.clipTo(w, h)
                    slc = deriv.getSlice(ie)
                    chi = deriv.patch * ie[slc]
                    dchisqs[i] += (chi**2).sum()
            newiv = dchisqs

        if newiv is None:
            keepinvvars.append(oldiv)
        else:
            keepinvvars.append(newiv)

        keepcat.append(src)
        ikeep.append(isrc)
    cat = keepcat
    Tcat.cut(np.array(ikeep))

    gc = get_galaxy_cache()
    print('Galaxy cache:', gc)
    if gc is not None:
        gc.clear()

    # Per-source invvar chunks are flattened into one array aligned with
    # the parameters of the new catalog.
    assert (len(iterinvvars) == 0)
    keepinvvars = np.hstack(keepinvvars)
    assert (Catalog(*keepcat).numberOfParams() == len(keepinvvars))
    invvars = keepinvvars
    assert (len(cat) == len(Tcat))
    print('Model selection:', Time() - t0)

    t0 = Time()
    # WCS header for these images
    hdr = fitsio.FITSHDR()
    targetwcs.add_to_header(hdr)
    fwa = dict(clobber=True, header=hdr)

    # Per band: build model, chi^2, inverse-variance and depth maps on the
    # brick grid from the (updated) per-tim models, and write them out.
    comods = []
    for iband, band in enumerate(bands):
        comod = np.zeros((wcsH, wcsW), np.float32)
        cochi2 = np.zeros((wcsH, wcsW), np.float32)
        coiv = np.zeros((wcsH, wcsW), np.float32)
        detiv = np.zeros((wcsH, wcsW), np.float32)
        for itim, (tim, mod) in enumerate(zip(tims, mods)):
            if tim.band != band:
                continue
            R = tim_get_resamp(tim, targetwcs)
            if R is None:
                continue
            (Yo, Xo, Yi, Xi) = R
            comod[Yo, Xo] += mod[Yi, Xi]
            ie = tim.getInvError()
            cochi2[Yo, Xo] += ((tim.getImage()[Yi, Xi] - mod[Yi, Xi]) *
                               ie[Yi, Xi])**2
            coiv[Yo, Xo] += ie[Yi, Xi]**2

            # 1 / (2 sqrt(pi) sigma) is presumably the L2 norm of a Gaussian
            # PSF of width psf_sigma, making detsig1 a point-source
            # detection noise; detiv accumulates its inverse square on
            # pixels with nonzero inverse error -- confirm.
            psfnorm = 1. / (2. * np.sqrt(np.pi) * tim.psf_sigma)
            detsig1 = tim.sig1 / psfnorm
            detiv[Yo, Xo] += (ie[Yi, Xi] > 0) * (1. / detsig1**2)

        comod /= np.maximum(cons[iband], 1)
        comods.append(comod)
        del comod

        fn = 'tunebrick/coadd/chi2-%06i-%s.fits' % (brickid, band)
        fitsio.write(fn, cochi2, **fwa)
        del cochi2
        print('Wrote', fn)

        # Image coadd in the first HDU; clobber=False then adds the ivar map
        # to the same file.
        fn = 'tunebrick/coadd/image-%06i-%s.fits' % (brickid, band)
        fitsio.write(fn, coimgs[iband], **fwa)
        print('Wrote', fn)
        fitsio.write(fn, coiv, clobber=False)
        print('Appended ivar to', fn)
        del coiv

        fn = 'tunebrick/coadd/depth-%06i-%s.fits' % (brickid, band)
        fitsio.write(fn, detiv, **fwa)
        print('Wrote', fn)
        del detiv

        fn = 'tunebrick/coadd/model-%06i-%s.fits' % (brickid, band)
        fitsio.write(fn, comods[iband], **fwa)
        print('Wrote', fn)

        # NOTE(review): this filename uses "b%06i" while the others use
        # "%06i" -- intentional?
        fn = 'tunebrick/coadd/nexp-b%06i-%s.fits' % (brickid, band)
        fitsio.write(fn, cons[iband], **fwa)
        print('Wrote', fn)

    plt.clf()
    rgb = get_rgb(comods, bands)
    dimshow(rgb)
    plt.title('Model')
    ps.savefig()
    del comods

    # Plot sources over top
    ax = plt.axis()
    for i, src in enumerate(cat):
        rd = src.getPosition()
        ok, x, y = targetwcs.radec2pixelxy(rd.ra, rd.dec)
        cc = (0, 1, 0)
        if isinstance(src, PointSource):
            plt.plot(x - 1, y - 1, '+', color=cc, **pla)
        else:
            plt.plot(x - 1, y - 1, 'o', mec=cc, mfc='none', **pla)
        # plt.text(x, y, '%i' % i, color=cc, ha='center', va='bottom')
    plt.axis(ax)
    ps.savefig()

    # Same PNG -> JPEG conversion as for the image coadd above.
    tmpfn = create_temp(suffix='.png')
    plt.imsave(tmpfn, rgb)
    del rgb
    cmd = 'pngtopnm %s | pnmtojpeg -quality 90 > tunebrick/coadd/model-%06i-full.jpg' % (
        tmpfn, brickid)
    os.system(cmd)
    os.unlink(tmpfn)

    assert (len(cat) == len(Tcat))
    print('Coadd FITS files and plots:', Time() - t0)

    print('Whole stage:', Time() - tstage)

    return dict(cat=cat, Tcat=Tcat, invvars=invvars)
Пример #2
0
def stage_tune(tims=None, cat=None, targetwcs=None, coimgs=None, cons=None,
               bands=None, invvars=None, brickid=None,
               Tcat=None, version_header=None, ps=None, **kwargs):
    """Prune and simplify the source catalog, then write tuned coadd products.

    The stage runs these steps in order:

    1. Prepare each tim for rendering (``modelMinval``, a ``CachingPsfEx``
       PSF built from ``tim.psfex``) and set source radii.
    2. Plot the image coadd with source markers (via ``ps.savefig()``) and
       write ``tunebrick/coadd/image-<brickid>-full.jpg``.
    3. Drop sources whose parameters are not all finite, keeping
       ``invvars`` and ``Tcat`` aligned with ``cat``.
    4. Render one model image per tim and plot the model coadd.
    5. For each source in turn (updating the per-tim model images in place):
       convert composite galaxies whose clipped ``fracDev`` is exactly 0 or 1
       to Exp / Dev; replace non-point sources with r mag < 12.5 by a
       PointSource; drop the source if removing it raises the
       log-likelihood; otherwise try simpler model types and switch to the
       best candidate with dlnp >= 0.
    6. Write per-band chi2, image (+ ivar HDU), depth, model and nexp FITS
       files under ``tunebrick/coadd/``, plot the final model with sources,
       and write ``tunebrick/coadd/model-<brickid>-full.jpg``.

    Args:
        tims: list of images; each must provide ``sig1``, ``psfex``,
            ``wcs``, ``band``, ``shape``, ``psf_sigma``, ``getImage()`` and
            ``getInvError()``.  Modified in place (``modelMinval``,
            ``psfex.radius``, ``psf``).
        cat: sequence of sources.
        targetwcs: WCS of the output (brick) pixel grid.
        coimgs: per-band image coadds, indexed like ``bands``.
        cons: per-band per-pixel counts, indexed like ``bands``; used as the
            divisor (floored at 1) when averaging model coadds.
        bands: band names.
        invvars: flat sequence of inverse variances, one per catalog
            parameter, consumed ``numberOfParams()`` at a time in catalog
            order.
        brickid: integer brick id, used in ``%06i`` output filenames.
        Tcat: table with one row per source in ``cat``; cut in place so it
            stays row-aligned with the returned catalog.
        version_header: not used in this body.
        ps: plot saver; only ``ps.savefig()`` is called.
        **kwargs: only their keys are printed.

    Returns:
        dict(cat=<list of kept / modified sources>, Tcat=<cut Tcat>,
             invvars=<numpy array, one entry per parameter of ``cat``>).
    """
    tstage = t0 = Time()
    print 'kwargs:', kwargs.keys()

    #print 'invvars:', invvars

    # How far down to render model profiles
    # (expressed as a multiple of each tim's per-pixel noise sig1).
    minsigma = 0.1
    for tim in tims:
        tim.modelMinval = minsigma * tim.sig1

    # Caching PSF
    # Replace each tim's PSF with a CachingPsfEx built from its PsfEx model.
    for tim in tims:
        from tractor.psfex import CachingPsfEx
        tim.psfex.radius = 20
        tim.psfex.fitSavedData(*tim.psfex.splinedata)
        tim.psf = CachingPsfEx.fromPsfEx(tim.psfex)

    orig_wcsxy0 = [tim.wcs.getX0Y0() for tim in tims]
    set_source_radii(bands, orig_wcsxy0, tims, cat, minsigma)

    # Square figure with (almost) no margins, for image-like plots.
    plt.figure(figsize=(10,10))
    plt.subplots_adjust(left=0.002, right=0.998, bottom=0.002, top=0.998)

    plt.clf()
    rgb = get_rgb(coimgs, bands)
    dimshow(rgb)
    #plt.title('Image')
    ps.savefig()

    # Save the RGB image coadd as a PNG, then convert it to JPEG with the
    # external pngtopnm | pnmtojpeg commands.
    # NOTE(review): shell command built with %-formatting and run via
    # os.system; the inputs are a temp filename and an integer brick id.
    # The exit status is ignored.
    tmpfn = create_temp(suffix='.png')
    plt.imsave(tmpfn, rgb)
    del rgb
    cmd = 'pngtopnm %s | pnmtojpeg -quality 90 > tunebrick/coadd/image-%06i-full.jpg' % (tmpfn, brickid)
    os.system(cmd)
    os.unlink(tmpfn)

    # Marker style shared by both source-overlay plots.
    pla = dict(ms=5, mew=1)

    # Overlay source positions: '+' for point sources, open circles otherwise.
    # The "-1" presumably converts 1-indexed WCS pixel coords to the
    # 0-indexed image plot -- TODO confirm.
    ax = plt.axis()
    for i,src in enumerate(cat):
        rd = src.getPosition()
        ok,x,y = targetwcs.radec2pixelxy(rd.ra, rd.dec)
        cc = (0,1,0)
        if isinstance(src, PointSource):
            plt.plot(x-1, y-1, '+', color=cc, **pla)
        else:
            plt.plot(x-1, y-1, 'o', mec=cc, mfc='none', **pla)
        # plt.text(x, y, '%i' % i, color=cc, ha='center', va='bottom')
    plt.axis(ax)
    ps.savefig()

    print 'Plots:', Time()-t0

    # print 'Catalog:'
    # for src in cat:
    #     print '  ', src
    # switch_to_soft_ellipses(cat)

    # invvars must hold exactly one entry per catalog parameter.
    assert(Catalog(*cat).numberOfParams() == len(invvars))

    # Drop sources with any non-finite parameter.  invvars is walked in step
    # with the catalog (N entries per source), and the indices of kept
    # sources are recorded so Tcat can be cut to match.
    keepcat = []
    keepinvvars = []
    iterinvvars = invvars
    ikeep = []
    for i,src in enumerate(cat):
        N = src.numberOfParams()
        iv = iterinvvars[:N]
        iterinvvars = iterinvvars[N:]
        if not np.all(np.isfinite(src.getParams())):
            print 'Dropping source:', src
            continue
        keepcat.append(src)
        keepinvvars.extend(iv)
        #print 'Keep:', src
        #print 'iv:', iv
        #print 'sigma', 1./np.sqrt(np.array(iv))
        ikeep.append(i)
    cat = keepcat
    Tcat.cut(np.array(ikeep))
    invvars = keepinvvars
    print len(cat), 'sources with finite params'
    assert(Catalog(*cat).numberOfParams() == len(invvars))
    assert(len(iterinvvars) == 0)

    # Render one full model image per tim for the current catalog.  These
    # arrays are updated in place below as sources are removed or changed.
    print 'Rendering model images...'
    t0 = Time()
    mods = _map(_get_mod, [(tim, cat) for tim in tims])
    print 'Getting model images:', Time()-t0

    wcsW = targetwcs.get_width()
    wcsH = targetwcs.get_height()

    # Per-band model coadd: resample each tim's model onto the brick grid,
    # sum, and divide by the per-pixel count (floored at 1).  Only plotted.
    t0 = Time()
    comods = []
    for iband,band in enumerate(bands):
        comod  = np.zeros((wcsH,wcsW), np.float32)
        for itim, (tim,mod) in enumerate(zip(tims, mods)):
            if tim.band != band:
                continue
            R = tim_get_resamp(tim, targetwcs)
            if R is None:
                continue
            (Yo,Xo,Yi,Xi) = R
            comod[Yo,Xo] += mod[Yi,Xi]
        comod  /= np.maximum(cons[iband], 1)
        comods.append(comod)
    print 'Creating model coadd:', Time()-t0

    plt.clf()
    dimshow(get_rgb(comods, bands))
    plt.title('Model')
    ps.savefig()
    del comods

    # Model selection, one source at a time.  invvars is again consumed N
    # entries per source; removed sources contribute nothing to keepcat /
    # keepinvvars / ikeep.
    t0 = Time()
    keepinvvars = []
    keepcat = []
    iterinvvars = invvars
    ikeep = []
    for isrc,src in enumerate(cat):
        # Replacement invvars for this source (None = keep oldiv).
        newiv = None
        N = src.numberOfParams()

        gc = get_galaxy_cache()
        print 'Galaxy cache:', gc
        if gc is not None:
            gc.clear()

        print 'Checking source', isrc, 'of', len(cat), ':', src
        #print 'N params:', N
        #print 'iterinvvars:', len(iterinvvars)

        oldiv = iterinvvars[:N]
        iterinvvars = iterinvvars[N:]
        # Set when the source type changes in a way that requires the
        # invvars to be re-derived from the model derivatives.
        recompute_iv = False

        if isinstance(src, FixedCompositeGalaxy):
            # Obvious simplification: for composite galaxies with fracdev
            # out of bounds, convert to exp or dev.
            f = src.fracDev.getClippedValue()
            if f == 0.:
                oldsrc = src
                src = ExpGalaxy(oldsrc.pos, oldsrc.brightness, oldsrc.shapeExp)
                print 'Converted comp to exp'
                #print '   ', oldsrc
                #print ' ->', src
                # pull out the invvar elements!
                # Trick: temporarily load oldiv into oldsrc's parameter
                # slots so each sub-component's getParams() yields the
                # invvars for just that component, then restore the new
                # source's params from pp.  The save/restore is needed
                # because the new source was built from the same pos /
                # brightness / shape objects -- presumably stored by
                # reference; confirm.
                pp = src.getParams()
                oldsrc.setParams(oldiv)
                newiv = oldsrc.pos.getParams() + oldsrc.brightness.getParams() + oldsrc.shapeExp.getParams()
                src.setParams(pp)
            elif f == 1.:
                # Same invvar-extraction trick as the exp branch above.
                oldsrc = src
                src = DevGalaxy(oldsrc.pos, oldsrc.brightness, oldsrc.shapeDev)
                print 'Converted comp to dev'
                ##print '   ', oldsrc
                print ' ->', src
                pp = src.getParams()
                oldsrc.setParams(oldiv)
                newiv = oldsrc.pos.getParams() + oldsrc.brightness.getParams() + oldsrc.shapeDev.getParams()
                src.setParams(pp)

        # treated_as_pointsource: do the bright-star check at least!
        if not isinstance(src, PointSource):
            # This is the check we use in unWISE
            if src.getBrightness().getMag('r') < 12.5:
                oldsrc = src
                src = PointSource(oldsrc.pos, oldsrc.brightness)
                print 'Bright star: replacing', oldsrc
                print 'With', src
                # Not QUITE right.
                #oldsrc.setParams(oldiv)
                #newiv = oldsrc.pos.getParams() + oldsrc.brightness.getParams()
                # NOTE(review): mods still contain the original (galaxy)
                # rendering of this source, while the patches computed
                # below come from the new PointSource, so the removal /
                # change tests compare against a slightly inconsistent
                # model -- confirm this is acceptable.
                recompute_iv = True

        #print 'Try removing source:', src
        tsrc = Time()

        # Collect (tim index, patch) for every tim this source's model
        # touches, clipped to the tim's bounds; empty patches are skipped.
        srcmodlist = []
        for itim,tim in enumerate(tims):
            patch = src.getModelPatch(tim)
            if patch is None:
                continue
            if patch.patch is None:
                continue

            # HACK -- this shouldn't be necessary, but seems to be!
            # FIXME -- track down why patches are being made with extent outside
            # that of the parent!
            H,W = tim.shape
            if patch.x0 < 0 or patch.y0 < 0 or patch.x1 > W or patch.y1 > H:
                print 'Warning: Patch extends outside tim bounds:'
                print 'patch extent:', patch.getExtent()
                print 'image size:', W, 'x', H
            patch.clipTo(W,H)
            ph,pw = patch.shape
            if pw*ph == 0:
                continue
            srcmodlist.append((itim, patch))
    
        # Try removing the source from the model;
        # check chi-squared change in the patches.
        # sdlnp = -0.5 * (chi2 without source - chi2 with source), summed over
        # the touched pixels; > 0 means the fit is better without it.
        sdlnp = 0.
        for itim,patch in srcmodlist:
            tim = tims[itim]
            mod = mods[itim]
            slc = patch.getSlice(tim)
            simg = tim.getImage()[slc]
            sie  = tim.getInvError()[slc]
            chisq0 = np.sum(((simg - mod[slc]) * sie)**2)
            chisq1 = np.sum(((simg - (mod[slc] - patch.patch)) * sie)**2)
            sdlnp += -0.5 * (chisq1 - chisq0)
        print 'Removing source: dlnp =', sdlnp
        print 'Testing source removal:', Time()-tsrc
    
        if sdlnp > 0:
            #print 'Removing source!'
            # Subtract the source from the running models and drop it (it is
            # not added to keepcat / ikeep).
            for itim,patch in srcmodlist:
                patch.addTo(mods[itim], scale=-1)
            continue

        # Try some model changes...
        # Candidate simpler models sharing the same pos / brightness / shape.
        newsrcs = []
        if isinstance(src, FixedCompositeGalaxy):
            newsrcs.append(ExpGalaxy(src.pos, src.brightness, src.shapeExp))
            newsrcs.append(DevGalaxy(src.pos, src.brightness, src.shapeDev))
            newsrcs.append(PointSource(src.pos, src.brightness))
        elif isinstance(src, (DevGalaxy, ExpGalaxy)):
            newsrcs.append(PointSource(src.pos, src.brightness))

        bestnew = None
        bestdlnp = 0.
        bestdpatches = None

        # Current-source patches indexed by tim (None where it has none).
        srcmodlist2 = [None for tim in tims]
        for itim,patch in srcmodlist:
            srcmodlist2[itim] = patch

        for newsrc in newsrcs:

            # For each tim, the model difference (new - old) and its effect
            # on the log-likelihood.
            dpatches = []
            dlnp = 0.
            for itim,tim in enumerate(tims):
                patch = newsrc.getModelPatch(tim)
                if patch is not None:
                    if patch.patch is None:
                        patch = None
                if patch is not None:
                    # HACK -- this shouldn't be necessary, but seems to be!
                    # FIXME -- track down why patches are being made with extent outside
                    # that of the parent!
                    H,W = tim.shape
                    patch.clipTo(W,H)
                    ph,pw = patch.shape
                    if pw*ph == 0:
                        patch = None

                oldpatch = srcmodlist2[itim]
                if oldpatch is None and patch is None:
                    continue

                # Find difference in models
                if oldpatch is None:
                    dpatch = patch
                elif patch is None:
                    dpatch = oldpatch * -1.
                else:
                    dpatch = patch - oldpatch
                dpatches.append((itim, dpatch))
                
                mod = mods[itim]
                slc = dpatch.getSlice(tim)
                simg = tim.getImage()[slc]
                sie  = tim.getInvError()[slc]
                chisq0 = np.sum(((simg - mod[slc]) * sie)**2)
                chisq1 = np.sum(((simg - (mod[slc] + dpatch.patch)) * sie)**2)
                dlnp += -0.5 * (chisq1 - chisq0)

            #print 'Trying source change:'
            #print 'from', src
            #print '  to', newsrc
            print 'Trying source change to', type(newsrc).__name__, ': dlnp =', dlnp

            # ">=" with bestdlnp starting at 0: a change that leaves dlnp
            # exactly unchanged is accepted, and ties go to the later
            # candidate in newsrcs.
            if dlnp >= bestdlnp:
                bestnew = newsrc
                bestdlnp = dlnp
                bestdpatches = dpatches

        if bestnew is not None:
            print 'Found model improvement!  Switching to',
            print bestnew
            # Apply the winning model difference to the running models.
            for itim,dpatch in bestdpatches:
                dpatch.addTo(mods[itim])
            src = bestnew
            recompute_iv = True

        del srcmodlist
        del srcmodlist2

        if recompute_iv:
            # Invvar per parameter = sum over tims and pixels of
            # (d model / d param * inverse error)^2.
            dchisqs = np.zeros(src.numberOfParams())
            for tim in tims:
                derivs = src.getParamDerivatives(tim)
                h,w = tim.shape
                ie = tim.getInvError()
                for i,deriv in enumerate(derivs):
                    if deriv is None:
                        continue
                    deriv.clipTo(w,h)
                    slc = deriv.getSlice(ie)
                    chi = deriv.patch * ie[slc]
                    dchisqs[i] += (chi**2).sum()
            newiv = dchisqs

        if newiv is None:
            keepinvvars.append(oldiv)
        else:
            keepinvvars.append(newiv)
    
        keepcat.append(src)
        ikeep.append(isrc)
    cat = keepcat
    Tcat.cut(np.array(ikeep))

    gc = get_galaxy_cache()
    print 'Galaxy cache:', gc
    if gc is not None:
        gc.clear()

    # Per-source invvar chunks are flattened into one array aligned with
    # the parameters of the new catalog.
    assert(len(iterinvvars) == 0)
    keepinvvars = np.hstack(keepinvvars)
    assert(Catalog(*keepcat).numberOfParams() == len(keepinvvars))
    invvars = keepinvvars
    assert(len(cat) == len(Tcat))
    print 'Model selection:', Time()-t0

    t0 = Time()
    # WCS header for these images
    hdr = fitsio.FITSHDR()
    targetwcs.add_to_header(hdr)
    fwa = dict(clobber=True, header=hdr)

    # Per band: build model, chi^2, inverse-variance and depth maps on the
    # brick grid from the (updated) per-tim models, and write them out.
    comods = []
    for iband,band in enumerate(bands):
        comod  = np.zeros((wcsH,wcsW), np.float32)
        cochi2 = np.zeros((wcsH,wcsW), np.float32)
        coiv   = np.zeros((wcsH,wcsW), np.float32)
        detiv   = np.zeros((wcsH,wcsW), np.float32)
        for itim, (tim,mod) in enumerate(zip(tims, mods)):
            if tim.band != band:
                continue
            R = tim_get_resamp(tim, targetwcs)
            if R is None:
                continue
            (Yo,Xo,Yi,Xi) = R
            comod[Yo,Xo] += mod[Yi,Xi]
            ie = tim.getInvError()
            cochi2[Yo,Xo] += ((tim.getImage()[Yi,Xi] - mod[Yi,Xi]) * ie[Yi,Xi])**2
            coiv[Yo,Xo] += ie[Yi,Xi]**2

            # 1 / (2 sqrt(pi) sigma) is presumably the L2 norm of a Gaussian
            # PSF of width psf_sigma, making detsig1 a point-source
            # detection noise; detiv accumulates its inverse square on
            # pixels with nonzero inverse error -- confirm.
            psfnorm = 1./(2. * np.sqrt(np.pi) * tim.psf_sigma)
            detsig1 = tim.sig1 / psfnorm
            detiv[Yo,Xo] += (ie[Yi,Xi] > 0) * (1. / detsig1**2)

        comod  /= np.maximum(cons[iband], 1)
        comods.append(comod)
        del comod

        fn = 'tunebrick/coadd/chi2-%06i-%s.fits' % (brickid, band)
        fitsio.write(fn, cochi2, **fwa)
        del cochi2
        print 'Wrote', fn

        # Image coadd in the first HDU; clobber=False then adds the ivar map
        # to the same file.
        fn = 'tunebrick/coadd/image-%06i-%s.fits' % (brickid, band)
        fitsio.write(fn, coimgs[iband], **fwa)
        print 'Wrote', fn
        fitsio.write(fn, coiv, clobber=False)
        print 'Appended ivar to', fn
        del coiv

        fn = 'tunebrick/coadd/depth-%06i-%s.fits' % (brickid, band)
        fitsio.write(fn, detiv, **fwa)
        print 'Wrote', fn
        del detiv

        fn = 'tunebrick/coadd/model-%06i-%s.fits' % (brickid, band)
        fitsio.write(fn, comods[iband], **fwa)
        print 'Wrote', fn

        # NOTE(review): this filename uses "b%06i" while the others use
        # "%06i" -- intentional?
        fn = 'tunebrick/coadd/nexp-b%06i-%s.fits' % (brickid, band)
        fitsio.write(fn, cons[iband], **fwa)
        print 'Wrote', fn

    plt.clf()
    rgb = get_rgb(comods, bands)
    dimshow(rgb)
    plt.title('Model')
    ps.savefig()
    del comods
    
    # Plot sources over top
    ax = plt.axis()
    for i,src in enumerate(cat):
        rd = src.getPosition()
        ok,x,y = targetwcs.radec2pixelxy(rd.ra, rd.dec)
        cc = (0,1,0)
        if isinstance(src, PointSource):
            plt.plot(x-1, y-1, '+', color=cc, **pla)
        else:
            plt.plot(x-1, y-1, 'o', mec=cc, mfc='none', **pla)
        # plt.text(x, y, '%i' % i, color=cc, ha='center', va='bottom')
    plt.axis(ax)
    ps.savefig()

    # Same PNG -> JPEG conversion as for the image coadd above.
    tmpfn = create_temp(suffix='.png')
    plt.imsave(tmpfn, rgb)
    del rgb
    cmd = 'pngtopnm %s | pnmtojpeg -quality 90 > tunebrick/coadd/model-%06i-full.jpg' % (tmpfn, brickid)
    os.system(cmd)
    os.unlink(tmpfn)

    assert(len(cat) == len(Tcat))
    print 'Coadd FITS files and plots:', Time()-t0

    print 'Whole stage:', Time()-tstage

    return dict(cat=cat, Tcat=Tcat, invvars=invvars)
Пример #3
0
def stage_recoadd(tims=None,
                  bands=None,
                  targetwcs=None,
                  ps=None,
                  brickid=None,
                  basedir=None,
                  ccds=None,
                  **kwargs):
    """Re-coadd the tims per band with inverse-variance weighting.

    Writes ``<basedir>/ccds-<brickid>.fits`` and, for each band,
    ``<basedir>/coadd/image2-<brickid>-<band>.fits`` with the ``targetwcs``
    header.  The ``coadd`` directory is created if it does not exist.

    Args:
        tims: images to coadd; each must provide ``band``, ``name``,
            ``getImage()`` and ``getInvvar()``.
        bands: band names; one coadd is produced per band, in this order.
        targetwcs: WCS of the output (brick) grid.  If None (e.g. no CCDs
            overlap the brick) the process exits with status 0.
        ps: not used.
        brickid: integer brick id used in ``%06i`` output filenames.
        basedir: output directory.
        ccds: table written to disk via ``ccds.writeto``.
        **kwargs: ignored.

    Returns:
        dict(coimgs=<list of per-band float32 coadds>, tims=None).  Pixels
        with zero total inverse variance fall back to the unweighted mean of
        the overlapping tims (0 where no tim overlaps).

    Raises:
        OSError: if the coadd directory cannot be created.
    """
    #print 'kwargs:', kwargs.keys()
    if targetwcs is None:
        # can happen if no CCDs overlap...
        import sys
        sys.exit(0)

    fn = os.path.join(basedir, 'ccds-%06i.fits' % brickid)
    ccds.writeto(fn)
    print('Wrote', fn)

    W = targetwcs.get_width()
    H = targetwcs.get_height()

    cowimgs = []
    for band in bands:
        coimg = np.zeros((H, W), np.float32)   # unweighted sum -> mean
        cowimg = np.zeros((H, W), np.float32)  # invvar-weighted sum -> mean
        wimg = np.zeros((H, W), np.float32)    # summed inverse variance
        # Per-pixel exposure count.  uint16 (not uint8) so more than 255
        # overlapping tims cannot wrap around; float32 / uint16 still
        # computes in float32.
        nimg = np.zeros((H, W), np.uint16)
        for tim in tims:
            if tim.band != band:
                continue
            print('Coadding', tim.name)
            R = tim_get_resamp(tim, targetwcs)
            if R is None:
                continue
            (Yo, Xo, Yi, Xi) = R
            img = tim.getImage()[Yi, Xi]
            iv = tim.getInvvar()[Yi, Xi]
            coimg[Yo, Xo] += img
            nimg[Yo, Xo] += 1
            cowimg[Yo, Xo] += img * iv
            wimg[Yo, Xo] += iv
            del R, Yo, Xo, Yi, Xi, img, iv
        coimg /= np.maximum(nimg, 1)
        cowimg /= np.maximum(wimg, 1e-16)
        # Where no tim contributed any weight, fall back to the plain mean.
        noweight = (wimg == 0)
        cowimg[noweight] = coimg[noweight]
        cowimgs.append(cowimg)
        del coimg, wimg, nimg, noweight

    coadd_dir = os.path.join(basedir, 'coadd')
    try:
        os.makedirs(coadd_dir)
    except OSError:
        # An existing directory is fine; any other failure is a real error
        # (the FITS writes below would fail anyway).
        if not os.path.isdir(coadd_dir):
            raise

    # WCS header for these images
    hdr = fitsio.FITSHDR()
    targetwcs.add_to_header(hdr)
    fwa = dict(clobber=True, header=hdr)

    for band, cow in zip(bands, cowimgs):
        fn = os.path.join(coadd_dir,
                          'image2-%06i-%s.fits' % (brickid, band))
        fitsio.write(fn, cow, **fwa)
        print('Wrote', fn)

    return dict(coimgs=cowimgs, tims=None)
Пример #4
0
def stage_recoadd(tims=None, bands=None, targetwcs=None, ps=None, brickid=None,
                  basedir=None, ccds=None,
                  **kwargs):
    #print 'kwargs:', kwargs.keys()
    if targetwcs is None:
        # can happen if no CCDs overlap...
        import sys
        sys.exit(0)

    fn = os.path.join(basedir, 'ccds-%06i.fits' % brickid)
    ccds.writeto(fn)
    print 'Wrote', fn
    
    W = targetwcs.get_width()
    H = targetwcs.get_height()

    coimgs = []
    # moo
    cowimgs = []
    #nimgs = []
    wimgs = []
    for iband,band in enumerate(bands):
        coimg  = np.zeros((H,W), np.float32)
        cowimg  = np.zeros((H,W), np.float32)
        wimg  = np.zeros((H,W), np.float32)
        nimg  = np.zeros((H,W), np.uint8)
        for tim in tims:
            if tim.band != band:
                continue
            print 'Coadding', tim.name
            R = tim_get_resamp(tim, targetwcs)
            if R is None:
                continue
            (Yo,Xo,Yi,Xi) = R
            coimg[Yo,Xo] += tim.getImage()[Yi,Xi]
            nimg[Yo,Xo] += 1
            cowimg[Yo,Xo] += tim.getImage()[Yi,Xi] * tim.getInvvar()[Yi,Xi]
            wimg[Yo,Xo] += tim.getInvvar()[Yi,Xi]
            del R,Yo,Xo,Yi,Xi
        coimg /= np.maximum(nimg, 1)
        cowimg /= np.maximum(wimg, 1e-16)
        coimgs.append(coimg)
        cowimgs.append(cowimg)
        #nimgs.append(nimg)
        wimgs.append(wimg)

    for i,(wimg,cowimg,coimg) in enumerate(zip(wimgs, cowimgs, coimgs)):
        cowimg[wimg == 0] = coimg[wimg == 0]
    del wimgs
    del coimgs
    del wimg
    del coimg

    try:
        os.path.makedirs(os.path.join(basedir, 'coadd'))
    except:
        pass

    # WCS header for these images
    hdr = fitsio.FITSHDR()
    targetwcs.add_to_header(hdr)
    fwa = dict(clobber=True, header=hdr)

    for band,cow in zip(bands, cowimgs):
        fn = os.path.join(basedir, 'coadd', 'image2-%06i-%s.fits' % (brickid,band))
        fitsio.write(fn, cow, **fwa)
        print 'Wrote', fn

    return dict(coimgs=cowimgs, tims=None)