def main(opt, ps): #ralo = 36 #rahi = 42 #declo = -1.25 #dechi = 1.25 #width = 7 ralo = 37.5 rahi = 41.5 declo = -1.5 dechi = 2.5 width = 2.5 rl,rh = 39,40 dl,dh = 0,1 roipoly = np.array([(rl,dl),(rl,dh),(rh,dh),(rh,dl)]) ra = (ralo + rahi ) / 2. dec = (declo + dechi) / 2. bandnum = 1 band = 'w%i' % bandnum plt.figure(figsize=(12,12)) #basedir = '/project/projectdirs/bigboss' #wisedatadir = os.path.join(basedir, 'data', 'wise') wisedatadirs = ['/clusterfs/riemann/raid007/bosswork/boss/wise_level1b', '/clusterfs/riemann/raid000/bosswork/boss/wise1ext'] wisecatdir = '/home/boss/products/NULL/wise/trunk/fits/' ofn = 'wise-images-overlapping.fits' if os.path.exists(ofn): print 'File exists:', ofn T = fits_table(ofn) print 'Found', len(T), 'images overlapping' print 'Reading WCS headers...' wcses = [] T.filename = [fn.strip() for fn in T.filename] for fn in T.filename: wcs = anwcs(fn, 0) wcses.append(wcs) else: TT = [] for d in wisedatadirs: ifn = os.path.join(d, 'WISE-index-L1b.fits') #'index-allsky-astr-L1b.fits') T = fits_table(ifn, columns=['ra','dec','scan_id','frame_num']) print 'Read', len(T), 'from WISE index', ifn I = np.flatnonzero((T.ra > ralo) * (T.ra < rahi) * (T.dec > declo) * (T.dec < dechi)) print len(I), 'overlap RA,Dec box' T.cut(I) fns = [] for sid,fnum in zip(T.scan_id, T.frame_num): print 'scan,frame', sid, fnum fn = get_l1b_file(d, sid, fnum, bandnum) print '-->', fn assert(os.path.exists(fn)) fns.append(fn) T.filename = np.array(fns) TT.append(T) T = merge_tables(TT) wcses = [] corners = [] ii = [] for i in range(len(T)): wcs = anwcs(T.filename[i], 0) W,H = wcs.get_width(), wcs.get_height() rd = [] for x,y in [(1,1),(1,H),(W,H),(W,1)]: rd.append(wcs.pixelxy2radec(x,y)) rd = np.array(rd) if polygons_intersect(roipoly, rd): wcses.append(wcs) corners.append(rd) ii.append(i) print 'Found', len(wcses), 'overlapping' I = np.array(ii) T.cut(I) outlines = corners corners = np.vstack(corners) nin = sum([1 if point_in_poly(ra,dec,ol) else 0 for ol in 
outlines]) print 'Number of images containing RA,Dec,', ra,dec, 'is', nin r0,r1 = corners[:,0].min(), corners[:,0].max() d0,d1 = corners[:,1].min(), corners[:,1].max() print 'RA,Dec extent', r0,r1, d0,d1 T.writeto(ofn) print 'Wrote', ofn # MAGIC 2.75: approximate pixel scale, "/pix S = int(3600. / 2.75) print 'Coadd size', S cowcs = anwcs_create_box(ra, dec, 1., S, S) if False: print 'Plotting map...' plot = Plotstuff(outformat='png', ra=ra, dec=dec, width=width, size=(800,800)) out = plot.outline plot.color = 'white' plot.alpha = 0.07 plot.apply_settings() for wcs in wcses: out.wcs = wcs out.fill = False plot.plot('outline') out.fill = True plot.plot('outline') plot.color = 'gray' plot.alpha = 1.0 plot.lw = 1 plot.plot_grid(1, 1, 1, 1) plot.color = 'red' plot.lw = 3 plot.alpha = 0.75 out.wcs = cowcs out.fill = False plot.plot('outline') if opt.sources: rd = plot.radec plot_radec_set_filename(rd, opt.sources) plot.plot('radec') pfn = ps.getnext() plot.write(pfn) print 'Wrote', pfn # Re-sort by distance to RA,Dec center... #I = np.argsort(np.hypot(T.ra - ra, T.dec - dec)) #T.cut(I) # IF YOU DO THIS, MUST ALSO RE-SORT 'wcses'! if opt.sources: # Look at a radius this big, in arcsec, around each source position. # 15" = about 6 WISE pixels Wrad = 15. / 3600. # Look for SDSS objects within this radius; Wrad + a margin Srad = Wrad + 5./3600. 
S = fits_table(opt.sources) print 'Read', len(S), 'sources from', opt.sources groups,singles = cluster_radec(S.ra, S.dec, Wrad, singles=True) print 'Source clusters:', groups print 'Singletons:', singles tractors = [] sdss = DR9(basedir='data-dr9') sband = 'r' for i in singles: r,d = S.ra[i],S.dec[i] print 'Source', i, 'at', r,d fn = sdss.retrieve('photoObj', S.run[i], S.camcol[i], S.field[i], band=sband) print 'Reading', fn oo = fits_table(fn) print 'Got', len(oo) cat1,obj1,I = get_tractor_sources_dr9(None, None, None, bandname=sband, objs=oo, radecrad=(r,d,Srad), bands=[], nanomaggies=True, extrabands=[band], fixedComposites=True, getobjs=True, getobjinds=True) print 'Got', len(cat1), 'SDSS sources nearby' # Find images that overlap? ims = [] for j,wcs in enumerate(wcses): print 'Filename', T.filename[j] ok,x,y = wcs.radec2pixelxy(r,d) print 'WCS', j, '-> x,y:', x,y if not anwcs_radec_is_inside_image(wcs, r, d): continue tim = wise.read_wise_level1b( T.filename[j].replace('-int-1b.fits',''), nanomaggies=True, mask_gz=True, unc_gz=True, sipwcs=True, constantInvvar=True, radecrad=(r,d,Wrad)) ims.append(tim) print 'Found', len(ims), 'images containing this source' tr = Tractor(ims, cat1) tractors.append(tr) if len(groups): # TODO! assert(False) sys.exit(0) # Find additional SDSS sources nearby = within R pixels radius. R = 30. #R = 50. rad = R * 0.396 / 3600. cats = [] objs = [] for run,camcol,field,r,d in zip(S.run, S.camcol, S.field, S.ra, S.dec): fn = sdss.retrieve('photoObj', run, camcol, field, band=sband) print 'Reading', fn oo = fits_table(fn) print 'Got', len(oo) cat1,obj1,I = get_tractor_sources_dr9(None, None, None, bandname=sband, objs=oo, radecrad=(r,d,rad), bands=[], nanomaggies=True, extrabands=[band], fixedComposites=True, getobjs=True, getobjinds=True) print 'Got', len(cat1), 'SDSS sources nearby' cats.append(cat1) objs.append(obj1[I]) # Merge into one big catalog. 
cat = Catalog() for c in cats: for src in c: cat.append(src) S = merge_tables(objs) print 'Merged catalog has', len(cat), 'entries' print 'S table has', len(S) assert(len(S) == len(cat)) if opt.ptsrc: print 'Converting all sources to PointSources' pcat = Catalog() for src in cat: ps = PointSource(src.getPosition(), src.getBrightness()) pcat.append(ps) print 'PointSource catalog:', pcat cat = pcat # ?? WW = S #WW = tabledata() # cat = get_tractor_sources_dr9(None, None, None, bandname=sband, # objs=S, bands=[], nanomaggies=True, # extrabands=[band]) print 'Got', len(cat), 'tractor sources' #cat = Catalog(*cat) print cat for src in cat: print ' ', src ### FIXME -- match to WISE catalog to initialize mags? # Initialize WISE mags to be at least detectable # so that we identify the right pixel ROIs below. #minbright = NanoMaggies.magToNanomaggies() #minbright = 50. minbright = 250. cat.freezeParamsRecursive('*') cat.thawPathsTo(band) p0 = cat.getParams() cat.setParams(np.maximum(minbright, p0)) print 'Set minimum W1 brightness:' for src in cat: print ' ', src # Cut images that don't overlap. ii = [] for i,wcs in enumerate(wcses): isin = False for r,d in zip(S.ra, S.dec): if anwcs_radec_is_inside_image(wcs, r, d): isin = True break if isin: ii.append(i) T.cut(np.array(ii)) print 'Cut to', len(T), 'images containing sources' else: wfn = 'wise-sources-nearby.fits' if os.path.exists(wfn): print 'Reading existing file', wfn W = fits_table(wfn) print 'Got', len(W), 'with range RA', W.ra.min(), W.ra.max(), ', Dec', W.dec.min(), W.dec.max() else: # Range of WISE slices (inclusive) containing this Dec range. 
ws0, ws1 = 26,27 WW = [] for w in range(ws0, ws1+1): fn = os.path.join(wisecatdir, 'wise-allsky-cat-part%02i-radec.fits' % w) print 'Searching for sources in', fn W = fits_table(fn) I = np.flatnonzero((W.ra >= r0) * (W.ra <= r1) * (W.dec >= d0) * (W.dec <= d1)) fn = os.path.join(wisecatdir, 'wise-allsky-cat-part%02i.fits' % w) print 'Reading', len(I), 'rows from', fn W = fits_table(fn, rows=I) print 'Cut to', len(W), 'sources in range' WW.append(W) W = merge_tables(WW) del WW print 'Total of', len(W) W.writeto(wfn) print 'wrote', wfn # DEBUG W.cut((W.ra >= rl) * (W.ra <= rh) * (W.dec >= dl) * (W.dec <= dh)) print 'Cut to', len(W), 'in the central region' print 'Creating', len(W), 'Tractor sources' cat = Catalog() for i in range(len(W)): w1 = W.w1mpro[i] nm = NanoMaggies.magToNanomaggies(w1) cat.append(PointSource(RaDecPos(W.ra[i], W.dec[i]), NanoMaggies(w1=nm))) WW = W cat.freezeParamsRecursive('*') cat.thawPathsTo(band) cat0 = cat.getParams() br0 = [src.getBrightness().copy() for src in cat] nm0 = np.array([b.getBand(band) for b in br0]) WW.nm0 = nm0 w1psf = wise.get_psf_model(bandnum, opt.pixpsf) # Create fake image in the "coadd" footprint in order to find overlapping # sources. H,W = int(cowcs.imageh), int(cowcs.imagew) # MAGIC -- sigma a bit smaller than typical images (4.0-ish) sig = 3.5 # typical zeropoint zp = 20.752 faketim = Image(data=np.zeros((H,W), np.float32), invvar=np.zeros((H,W), np.float32) + (1./sig**2), psf=w1psf, wcs=ConstantFitsWcs(cowcs), sky=ConstantSky(0.), photocal = LinearPhotoCal(NanoMaggies.zeropointToScale(zp), band=band), #photocal=LinearPhotoCal(1., band=band), name='fake') minsb = 0.1 * sig #minsb = 0. 
# pc = faketim.getPhotoCal() # print 'Source counts:' # for src in cat: # print ' ', src # print '-->', pc.brightnessToCounts(src.getBrightness()) # print ' -->', [pc.brightnessToCounts(br) for br in src.getBrightnesses()] # print 'Source pixel positions:' # wcs = faketim.getWcs() # for src in cat: # print ' ', src # print '--> x,y', wcs.positionToPixel(src.getPosition()) print 'Finding overlapping sources...' t0 = Time() tractor = Tractor([faketim], cat) groups,L,fakemod = tractor.getOverlappingSources(0, minsb=minsb) print 'Overlapping sources took', Time()-t0 print 'Got', len(groups), 'groups of sources' nl = L.max() gslices = find_objects(L, nl) print 'unique labels:', np.unique(L) # plt.clf() # plt.imshow(fakemod, interpolation='nearest', origin='lower', # vmin=0, vmax=sig*3.) # plt.title('Fakemod') # ps.savefig() # # for IM in [L, (L>0)]: # plt.clf() # plt.imshow(IM, interpolation='nearest', origin='lower') # plt.gray() # wcs = faketim.getWcs() # xy = [] # for src in cat: # x,y = wcs.positionToPixel(src.getPosition()) # xy.append((x,y)) # xy = np.array(xy) # ax = plt.axis() # plt.plot(xy[:,0], xy[:,1], 'r+') # plt.title('Source groups') # ps.savefig() # Find sources touching each group's (rectangular) ROI tgroups = {} for i,gslice in enumerate(gslices): gl = i+1 tg = np.unique(L[gslice]) tsrcs = [] for g in tg: if not g in [gl,0]: if g in groups: tsrcs.extend(groups[g]) tgroups[gl] = tsrcs # for i,gslice in enumerate(gslices): # if not (i+1) in groups: # continue # # plt.clf() # plt.imshow(IM[gslice], interpolation='nearest', origin='lower') # plt.gray() # wcs = faketim.getWcs() # xy = [] # y0,x0 = gslice[0].start, gslice[1].start # for src in cat: # x,y = wcs.positionToPixel(src.getPosition()) # xy.append((x-x0,y-y0)) # xy = np.array(xy) # # ax = plt.axis() # # plt.plot(xy[:,0], xy[:,1], 'r+') # # I = np.array(groups[i+1]) # if len(I): # plt.plot(xy[I,0], xy[I,1], 'g.') # # I = np.array(tgroups[i+1]) # if len(I): # plt.plot(xy[I,0], xy[I,1], 'gx') # # 
ps.savefig() # # plt.axis(ax) # ps.savefig() print 'Group size histogram:' ng = Counter() for g in groups.values(): ng[len(g)] += 1 kk = ng.keys() kk.sort() for k in kk: print ' ', k, 'sources:', ng[k], 'groups' nms = [] tims = [] allrois = {} badrois = {} if opt.threads: mp = multiproc(opt.threads) else: mp = multiproc(1) tims = mp.map(_read_l1b, T.filename) for imi,tim in enumerate(tims): tim.psf = w1psf H,W = tim.shape nin = 0 for src in cat: x,y = tim.getWcs().positionToPixel(src.getPosition()) if x >= 0 and y >= 0 and x < W and y < H: nin += 1 print 'Number of sources inside image:', nin tractor = Tractor([tim], cat) tractor.freezeParam('images') ### ?? cat.setParams(cat0) pgroups = 0 pobjs = 0 for gi in range(len(gslices)): gl = gi # note, gslices is zero-indexed gslice = gslices[gl] gl += 1 if not gl in groups: print 'Group', gl, 'not in groups array; skipping' continue gsrcs = groups[gl] tsrcs = tgroups[gl] # print 'Group number', (gi+1), 'of', len(Gorder), ', id', gl, ': sources', gsrcs # print 'sources in groups touching slice:', tsrcs # Convert from 'canonical' ROI to this image. yl,yh = gslice[0].start, gslice[0].stop xl,xh = gslice[1].start, gslice[1].stop x0,y0 = W-1,H-1 x1,y1 = 0,0 for x,y in [(xl,yl),(xh-1,yl),(xh-1,yh-1),(xl,yh-1)]: r,d = cowcs.pixelxy2radec(x+1, y+1) x,y = tim.getWcs().positionToPixel(RaDecPos(r,d)) x = int(np.round(x)) y = int(np.round(y)) x = np.clip(x, 0, W-1) y = np.clip(y, 0, H-1) x0 = min(x0, x) y0 = min(y0, y) x1 = max(x1, x) y1 = max(y1, y) if x1 == x0 or y1 == y0: print 'Gslice', gslice, 'is completely outside this image' continue gslice = (slice(y0,y1+1), slice(x0, x1+1)) if np.all(tim.getInvError()[gslice] == 0): print 'This whole object group has invvar = 0.' 
if not gl in badrois: badrois[gl] = {} badrois[gl][imi] = gslice continue if not gl in allrois: allrois[gl] = {} allrois[gl][imi] = gslice if not opt.individual: continue fullcat = tractor.catalog subcat = Catalog(*[fullcat[i] for i in gsrcs + tsrcs]) for i in range(len(tsrcs)): subcat.freezeParam(len(gsrcs) + i) tractor.catalog = subcat print len(gsrcs), 'sources unfrozen; total', len(subcat) pgroups += 1 pobjs += len(gsrcs) t0 = Time() tractor.optimize_forced_photometry(minsb=minsb, mindlnp=1., rois=[gslice]) print 'optimize_forced_photometry took', Time()-t0 tractor.catalog = fullcat # mod = tractor.getModelImage(0, minsb=minsb) # noise = np.random.normal(size=mod.shape) # noise[tim.getInvError() == 0] = 0. # nz = (tim.getInvError() > 0) # noise[nz] *= (1./tim.getInvError()[nz]) # ima = dict(interpolation='nearest', origin='lower', # vmin=tim.zr[0], vmax=tim.zr[1]) # imchi = dict(interpolation='nearest', origin='lower', # vmin=-5, vmax=5) # plt.clf() # plt.subplot(2,2,1) # plt.imshow(tim.getImage(), **ima) # plt.gray() # plt.subplot(2,2,2) # plt.imshow(mod, **ima) # plt.gray() # plt.subplot(2,2,3) # plt.imshow((tim.getImage() - mod) * tim.getInvError(), **imchi) # plt.gray() # plt.subplot(2,2,4) # plt.imshow(mod + noise, **ima) # plt.gray() # plt.suptitle('W1, scan %s, frame %i' % (sid, fnum)) # ps.savefig() if opt.individual: print 'Photometered', pgroups, 'groups containing', pobjs, 'objects' cat.thawPathsTo(band) nm1 = np.array([src.getBrightness().getBand(band) for src in cat]) nms.append(nm1) WW.nms = np.array(nms).T fn = opt.output % imi WW.writeto(fn) print 'Wrote', fn return dict(cat0=cat0, WW=WW, band=band, tims=tims, allrois=allrois, badrois=badrois, groups=groups, tgroups=tgroups, minsb=minsb, gslices=gslices, cat=cat)
def simult_photom(cat0=None, WW=None, band=None, tims=None,
                  allrois=None, badrois=None, groups=None, tgroups=None,
                  minsb=None, gslices=None, cat=None, opt=None, ps=None):
    '''
    Simultaneous forced photometry of each source group across all images.

    The keyword arguments match the keys of the dict returned by main(),
    plus the options object `opt` and plot sequence `ps`; presumably this is
    called as simult_photom(opt=opt, ps=ps, **state) -- confirm at call site.
    `cat0` is accepted but not used in this body.

    For each group label gl (= index into gslices + 1) that has usable ROIs
    (allrois[gl], a dict image-index -> slice):
      * fits the `band` fluxes of the group's own sources (groups[gl]) while
        holding the touching sources (tgroups[gl]) fixed, over all images at
        once, and plots data / model / chi via ps.savefig();
      * stores the current catalog fluxes in WW.nmall;
      * if opt.osources: also renders the alternate catalog read from that
        file (fluxes from its 'wiseflux' column) for comparison;
      * if opt.opt: additionally fits RA,Dec of the group's sources on ROI
        cut-outs and saves model/chi plots to two reserved filenames.
    If opt.plotmask, the first group (gi == 0) also gets plots of the
    uncertainty plane and of individual mask bits.
    If opt.opt, WW is written to opt.output % 998, then the RA,Dec-optimized
    fluxes/positions are added (nmoptrd, raoptrd, decoptrd) and WW is
    written to opt.output % 999.
    '''

    def _plot_grid(ims, kwas):
        # Show images in a roughly square grid of subplots, one imshow kwargs
        # dict per image; returns the (rows, cols) used.
        N = len(ims)
        C = int(np.ceil(np.sqrt(N)))
        R = int(np.ceil(N / float(C)))
        plt.clf()
        for i,(im,kwa) in enumerate(zip(ims, kwas)):
            plt.subplot(R,C, i+1)
            #print 'plotting grid cell', i, 'img shape', im.shape
            plt.imshow(im, **kwa)
            plt.gray()
            plt.xticks([]); plt.yticks([])
        return R,C

    def _plot_grid2(ims, cat, tims, kwas, ptype='mod'):
        # Like _plot_grid, but plots the model (ptype='mod') or chi
        # (ptype='chi') stamp from each (img, mod, chi, roi) tuple and marks
        # the pixel positions of the sources in `cat` relative to the ROI
        # origin with red '+'.
        # NOTE(review): if `cat` is empty, np.array(xy) is 1-D and the
        # xy[:,0] indexing below will fail.
        xys = []
        stamps = []
        for (img,mod,chi,roi),tim in zip(ims, tims):
            if ptype == 'mod':
                stamps.append(mod)
            elif ptype == 'chi':
                stamps.append(chi)
            wcs = tim.getWcs()
            y0,x0 = roi[0].start, roi[1].start
            xy = []
            for src in cat:
                xi,yi = wcs.positionToPixel(src.getPosition())
                xy.append((xi - x0, yi - y0))
            xys.append(xy)
            #print 'X,Y source positions in stamp of shape', stamps[-1].shape
            #print '  ', xy
        R,C = _plot_grid(stamps, kwas)
        for i,xy in enumerate(xys):
            plt.subplot(R, C, i+1)
            # Preserve the image axis limits while over-plotting markers.
            ax = plt.axis()
            xy = np.array(xy)
            plt.plot(xy[:,0], xy[:,1], 'r+', lw=2)
            plt.axis(ax)

    # Simultaneous photometry

    if opt.osources:
        # Alternate ("other") catalog, only used for comparison plots.
        O = fits_table(opt.osources)
        ocat = Catalog()
        print 'Other catalog:'
        for i in range(len(O)):
            w1 = O.wiseflux[i, 0]
            s = PointSource(RaDecPos(O.ra[i], O.dec[i]), NanoMaggies(w1=w1))
            ocat.append(s)
        print ocat
        ocat.freezeParamsRecursive('*')
        ocat.thawPathsTo(band)

    # Keep track of params after simultaneous photometry...
    # catsim holds the band-only parameter vector; catopt (if opt.opt) holds
    # the band + ra + dec vector.  The catalog is left with only `band`
    # thawed between groups.
    cat.freezeParamsRecursive('*')
    cat.thawPathsTo(band)
    catsim = cat.getParams()
    if opt.opt:
        # ... and also after RA,Dec opt.
        cat.thawPathsTo('ra','dec')
        catopt = cat.getParams()
        cat.freezeParamsRecursive('*')
        cat.thawPathsTo(band)

    for gi in range(len(gslices)):
        # Group labels are 1-based; gslices is 0-based.
        gl = gi
        gl += 1
        if not gl in groups:
            print 'Group', gl, 'not in groups array; skipping'
            continue
        gsrcs = groups[gl]
        tsrcs = tgroups[gl]

        print 'Group', gl
        print 'gsrcs:', gsrcs
        print 'tsrcs:', tsrcs

        if (not gl in allrois) and (not gl in badrois):
            print 'Group', gl, 'does not touch any images?'
            continue

        # Images (and their ROIs) where this group has usable pixels...
        mytims = []
        rois = []
        if gl in allrois:
            for imi,roi in allrois[gl].items():
                mytims.append(tims[imi])
                rois.append(roi)

        # ... and where the whole ROI had zero inverse-variance.
        mybadtims = []
        mybadrois = []
        if gl in badrois:
            for imi,roi in badrois[gl].items():
                mybadtims.append(tims[imi])
                mybadrois.append(roi)

        print 'Group', gl, 'touches', len(mytims), 'images and', len(mybadtims), 'bad ones'

        tt = 'group %i: %i+%i sources' % (gl, len(gsrcs), len(tsrcs))

        if len(mytims):
            # Start from the fluxes accumulated over previous groups.
            cat.setParams(catsim)
            #print 'Restoring catsim:'
            #cat.printThawedParams()

            # Fit gsrcs; tsrcs (appended after them) are frozen.
            subcat = Catalog(*[cat[i] for i in gsrcs + tsrcs])
            for i in range(len(tsrcs)):
                subcat.freezeParam(len(gsrcs) + i)
            tractor = Tractor(mytims, subcat)
            tractor.freezeParam('images')

            print len(gsrcs), 'sources unfrozen; total', len(subcat)

            print 'Before fitting:'
            for src in subcat[:len(gsrcs)]:
                print '  ', src

            t0 = Time()
            ims0,ims1 = tractor.optimize_forced_photometry(minsb=minsb, mindlnp=1.,
                                                          rois=rois)
            print 'optimize_forced_photometry took', Time()-t0

            print 'After fitting:'
            for src in subcat[:len(gsrcs)]:
                print '  ', src

            # Per-image display ranges from tim.zr; chi shown in [-5, 5].
            imas = [dict(interpolation='nearest', origin='lower',
                         vmin=tim.zr[0], vmax=tim.zr[1])
                    for tim in mytims]
            imchi = dict(interpolation='nearest', origin='lower', vmin=-5, vmax=5)
            imchis = [imchi] * len(mytims)

            _plot_grid([img for (img, mod, chi, roi) in ims0], imas)
            plt.suptitle('Data: ' + tt)
            ps.savefig()

            if ims1 is not None:
                #_plot_grid([mod for (img, mod, chi, roi) in ims1], imas)
                _plot_grid2(ims1, subcat, mytims, imas)
                plt.suptitle('Forced-phot model: ' + tt)
                ps.savefig()

                #_plot_grid([chi for (img, mod, chi, roi) in ims1], imchis)
                _plot_grid2(ims1, subcat, mytims, imchis, ptype='chi')
                plt.suptitle('Forced-phot chi: ' + tt)
                ps.savefig()

            if opt.osources:
                # Temporarily swap in the other catalog just to render it.
                cc = tractor.catalog
                tractor.catalog = ocat
                nil,nil,ims3 = tractor.optimize_forced_photometry(minsb=minsb, rois=rois,
                                                                  justims0=True)
                tractor.catalog = cc

                _plot_grid2(ims3, ocat, mytims, imas)
                plt.suptitle("Schlegel's model: group %i" % gl)
                ps.savefig()

                _plot_grid2(ims3, ocat, mytims, imchis, ptype='chi')
                plt.suptitle("Schlegel's chi: group %i" % gl)
                ps.savefig()

            if opt.opt:
                # Reserve two plot filenames now; they are written after the
                # RA,Dec optimization below.
                op1 = ps.getnext()
                op2 = ps.getnext()

            #fits[gl] = (tractor, len(gsrcs), rois, op1, op2)

            # print 'Plotting mods after simul photom'
            # #_plot_grid([mod for (img, mod, chi, roi) in ims0], imas)
            # _plot_grid2(ims0, subcat, mytims, imas)
            # plt.suptitle('Initial model: ' + tt)
            # ps.savefig()
            #
            # print 'Plotting chis after simul photom'
            # #_plot_grid([chi for (img, mod, chi, roi) in ims0], imchis)
            # _plot_grid2(ims0, subcat, mytims, imchis, ptype='chi')
            # plt.suptitle('Initial chi: ' + tt)
            # ps.savefig()

            print 'After simultaneous photometry:'
            subcat.printThawedParams()

            # Copy updated params to "catsim"
            catsim = cat.getParams()
            #print 'Saving catsim:'
            #cat.printThawedParams()

            cat.freezeParamsRecursive('*')
            cat.thawPathsTo(band)
            WW.nmall = np.array([src.getBrightness().getBand(band) for src in cat])

        if len(mytims) and opt.opt:
            print 'Optimizing RA,Dec'

            subcat = tractor.catalog

            # Copy updated forced-phot params from catsim to catopt:
            # restore catopt (band+ra+dec), then overwrite the band params of
            # this group's sources with the just-fit fluxes.
            # NOTE(review): assumes cat.freezeAllBut(*gsrcs) orders the
            # thawed params the same way as gsrcs in subcat -- confirm gsrcs
            # is sorted.
            #print 'Saving subcat forced-phot params:'
            #subcat.printThawedParams()
            fphot = subcat.getParams()
            cat.thawPathsTo('ra','dec')
            cat.setParams(catopt)
            #print 'Copying forced-phot results to catopt:'
            cat.freezeParamsRecursive('*')
            cat.thawPathsTo(band)
            cat.freezeAllBut(*gsrcs)
            #cat.printThawedParams()
            NP = cat.numberOfParams()
            cat.setParams(fphot[:NP])
            #print 'Result:'
            #print 'Restoring catopt:'
            #cat.printThawedParams()
            cat.freezeParamsRecursive('*')
            cat.thawPathsTo(band)

            # Also free RA,Dec of the group's own sources.
            NG = len(gsrcs)
            for i in range(NG):
                subcat[i].thawPathsTo('ra','dec')
            p0 = subcat.getParams()

            print 'Optimizing params:'
            subcat.printThawedParams()

            # Fit on ROI cut-outs (WCS shifted by the ROI origin) rather than
            # the full images; the full images are restored afterwards.
            thetims = tractor.images
            subimgs = []
            for i,img in enumerate(thetims):
                roi = rois[i]
                y0 = roi[0].start
                x0 = roi[1].start
                subwcs = ShiftedWcs(img.wcs, x0, y0)
                subimg = Image(data=img.data[roi], invvar=img.invvar[roi],
                               psf=img.psf, wcs=subwcs, sky=img.sky,
                               photocal=img.photocal, name=img.name)
                subimgs.append(subimg)
            tractor.images = Images(*subimgs)

            # Iterate until the improvement falls below 0.1 (no iteration cap).
            while True:
                dlnp,X,alpha = tractor.optimize()
                print 'dlnp', dlnp
                print 'alpha', alpha
                if dlnp < 0.1:
                    break

            p1 = subcat.getParams()

            print 'Param changes:'
            for nm,pp0,pp1 in zip(subcat.getParamNames(), p0, p1):
                print '  ', nm, pp0, 'to', pp1, '; delta', pp1-pp0

            # Save the full band+ra+dec vector, then go back to band-only.
            cat.thawPathsTo('ra','dec')
            catopt = cat.getParams()
            print 'Saving catopt:'
            cat.printThawedParams()
            cat.freezeParamsRecursive('ra', 'dec')

            tractor.images = thetims
            nil,nil,ims2 = tractor.optimize_forced_photometry(minsb=minsb, rois=rois,
                                                              justims0=True)

            print 'Plotting mods after RA,Dec opt'
            #_plot_grid([mod for (img, mod, chi, roi) in ims2], imas)
            _plot_grid2(ims2, subcat, mytims, imas)
            plt.suptitle('RA,Dec-opt model: ' + tt)
            plt.savefig(op1)

            print 'Plotting chis after RA,Dec opt'
            #_plot_grid([chi for (img, mod, chi, roi) in ims2], imchis)
            _plot_grid2(ims2, subcat, mytims, imchis, ptype='chi')
            plt.suptitle('RA,Dec-opt chi: ' + tt)
            plt.savefig(op2)

        N = len(mybadtims)
        # Plots of the zero-invvar ROIs; currently disabled by 'and False'.
        if N and False:
            C = int(np.ceil(np.sqrt(N)))
            R = int(np.ceil(N / float(C)))
            plt.clf()
            for i,(tim,roi) in enumerate(zip(mybadtims, mybadrois)):
                plt.subplot(R,C, i+1)
                plt.imshow(tim.getImage()[roi], interpolation='nearest', origin='lower',
                           vmin=tim.zr[0], vmax=tim.zr[1])
                plt.gray()
            plt.suptitle('Data in bad regions')
            ps.savefig()

            plt.clf()
            for i,(tim,roi) in enumerate(zip(mybadtims, mybadrois)):
                plt.subplot(R,C, i+1)
                plt.imshow(tim.getInvError()[roi], interpolation='nearest', origin='lower')
                plt.gray()
            plt.suptitle('Inverr in bad regions')
            ps.savefig()

        if gi == 0 and opt.plotmask:
            # Bad images first, then good ones, paired with their ROIs.
            alltims = mybadtims+mytims
            _plot_grid([tim.uncplane[roi] for tim,roi in zip(alltims, mybadrois + rois)],
                       [dict(interpolation='nearest', origin='lower')]*len(alltims))
            plt.suptitle('Uncertainty plane')
            ps.savefig()

            # One plot per mask bit, showing where that bit is set.
            for bit,txt in [
                (0 , 'static: excessively noisy due to high dark current alone'),
                (1 , 'static: generally noisy [includes bit 0]'),
                (2 , 'static: dead or very low responsivity'),
                (3 , 'static: low responsivity or low dark current'),
                (4 , 'static: high responsivity or high dark current'),
                (5 , 'static: saturated anywhere in ramp'),
                (6 , 'static: high, uncertain, or unreliable non-linearity'),
                (7 , 'static: known broken hardware pixel or excessively noisy responsivity estimate [may include bit 1]'),
                (9 , 'broken pixel or negative slope fit value'),
                (10, 'saturated in sample read 1'),
                (11, 'saturated in sample read 2'),
                (12, 'saturated in sample read 3'),
                (13, 'saturated in sample read 4'),
                (14, 'saturated in sample read 5'),
                (15, 'saturated in sample read 6'),
                (16, 'saturated in sample read 7'),
                (17, 'saturated in sample read 8'),
                (18, 'saturated in sample read 9'),
                (21, 'new/transient bad pixel from dynamic masking'),
                (26, 'non-linearity correction unreliable'),
                (27, 'contains cosmic-ray or outlier that cannot be classified (from temporal outlier rejection in multi-frame pipeline)'),
                (28, 'contains positive or negative spike-outlier'),
                ]:
                _plot_grid([tim.maskplane[roi] & (1 << bit)
                            for tim,roi in zip(alltims, mybadrois + rois)],
                           [dict(interpolation='nearest', origin='lower',
                                 vmin=0, vmax=1)]*len(alltims))
                plt.suptitle('Mask: ' + txt)
                ps.savefig()

    if opt.opt:
        # Table with forced-photometry fluxes (nmall) only.
        fn = opt.output % 998
        WW.writeto(fn)
        print 'Wrote', fn

        # Restore the RA,Dec-optimized vector and extract flux, RA and Dec
        # columns one at a time by freezing the other parameter kinds.
        cat.thawPathsTo('ra','dec')
        cat.setParams(catopt)

        WW.nmoptrd = np.array([src.getBrightness().getBand(band) for src in cat])
        cat.freezeParamsRecursive(band, 'dec')
        WW.raoptrd = np.array(cat.getParams())
        cat.freezeParamsRecursive('ra')
        cat.thawPathsTo('dec')
        WW.decoptrd = np.array(cat.getParams())
        cat.freezeParamsRecursive('dec')

        fn = opt.output % 999
        WW.writeto(fn)
        print 'Wrote', fn
def _psf_amps_summary(amps):
    '''
    Format mixture-component amplitudes for a plot title.

    Returns '<sum> = <a0>,<a1>,...' with the sum formatted with %g and
    each amplitude with %.3f, e.g. [0.5, 0.25] -> '0.75 = 0.500,0.250'.
    '''
    return '%g = %s' % (np.sum(amps), ','.join(['%.3f' % a for a in amps]))


def _optimize_psf_until_converged(tractor, psf, maxiter=200, tol=1e-6):
    '''
    Repeatedly call tractor.optimize(damp=1) until the returned dlnp drops
    below `tol`, or `maxiter` calls have been made.

    `psf` is only printed after each step, to trace the fit.
    Returns the number of optimize() calls made.
    '''
    for i in range(maxiter):
        dlnp, X, alpha = tractor.optimize(damp=1)
        print(i, 'dlnp %.3g' % dlnp, 'psf', psf)
        if dlnp < tol:
            return i + 1
    return maxiter


def _meisner_psf_models():
    '''
    Fit concentric-Gaussian PSF models to Meisner's average WISE PSF images.

    For each band W1..W4:
      * read the pixelized PSF ('wise-psf-avg-pix-bright.fits', HDU band-1)
        and a Gaussian-mixture fit ('wise-psf-avg.fits', HDU band);
      * plot pixels vs. the mixture model (for W4 after a 2x Lanczos
        resampling of the pixels), then a GaussianMixturePSF.fromStamp fit;
      * fit an NCircularGaussianPSF (plus the source position) to the stamp,
        write it to 'psf3-w<band>.fits' and 'psf3.fits' (one HDU per band),
        then normalize its weights;
      * refit with MyGaussianPSF (log-weights, normalized to unit sum) and
        append the resulting mixture to 'psf4.fits'.

    Uses the module-level plot sequence `ps`, and sets the module-level
    `plotslice` (presumably read by _plot_psf to crop the displayed region
    -- not visible here).

    Slice bounds and radii use floor division (//) so they stay integers
    under Python 3 true division; NumPy rejects float slice indices.
    '''
    global plotslice

    class MyGaussianPSF(NCircularGaussianPSF):
        '''
        A PSF model with strictly positive weights that sum to unity.
        '''
        def __init__(self, sigmas, weights):
            # Parameterize the weights by their logs, so they stay positive.
            ww = np.array(weights)
            ww = np.log(ww)
            super(MyGaussianPSF, self).__init__(sigmas, ww)

        @staticmethod
        def getNamedParams():
            return dict(sigmas=0, logweights=1)

        def __str__(self):
            return ('MyGaussianPSF: sigmas [ ' +
                    ', '.join(['%.3f' % s for s in self.mysigmas]) +
                    ' ], weights [ ' +
                    ', '.join(['%.3f' % w for w in self.myweights]) +
                    ' ]')

        @property
        def myweights(self):
            # All weights (frozen or not), normalized to unit sum.
            ww = np.exp(self.logweights.getAllParams())
            ww /= ww.sum()
            return ww

        @property
        def weights(self):
            # Thawed weights, normalized by the sum over all weights.
            ww = np.exp(self.logweights.getParams())
            wsum = np.sum(np.exp(self.logweights.getAllParams()))
            return ww / wsum

    # Meisner's PSF models
    for band in [1, 2, 3, 4]:
        print()
        print('W%i' % band)
        print()

        pix = fitsio.read('wise-psf-avg-pix-bright.fits', ext=band - 1)
        fit = fits_table('wise-psf-avg.fits', hdu=band)

        scale = 1.

        print('Pix shape', pix.shape)
        h, w = pix.shape

        xx, yy = np.meshgrid(np.arange(w), np.arange(h))
        cx, cy = np.sum(xx * pix) / np.sum(pix), np.sum(yy * pix) / np.sum(pix)
        print('Centroid:', cx, cy)
        # Half-size of the central cut-out (here: the whole stamp).
        S = h // 2
        slc = slice(h // 2 - S, h // 2 + S + 1), slice(w // 2 - S, w // 2 + S + 1)

        plotss = 30
        plotslice = slice(S - plotss, -(S - plotss)), slice(S - plotss, -(S - plotss))

        # NOTE: pix /= ... normalizes in place, so opix is normalized too.
        opix = pix
        print('Original pixels sum:', opix.sum())
        pix /= pix.sum()
        pix = pix[slc]
        print('Sliced pix sum', pix.sum())

        psf = GaussianMixturePSF(fit.amp, fit.mean * scale, fit.var * scale**2)
        psfmodel = psf.getPointSourcePatch(0., 0., radius=h // 2)
        mod = psfmodel.patch
        print('Orig mod sum', mod.sum())
        mod = mod[slc]
        print('Sliced mod sum', mod.sum())
        print('Amps:', np.sum(psf.mog.amp))

        _plot_psf(pix, mod, psf)
        plt.suptitle('W%i: Orig' % band)
        ps.savefig()

        # Lanczos sub-sample
        if band == 4:
            lpix = _lanczos_subsample(opix, 2)
            pix = lpix
            h, w = pix.shape
            slc = slice(h // 2 - S, h // 2 + S + 1), slice(w // 2 - S, w // 2 + S + 1)
            print('Resampled pix sum', pix.sum())
            pix = pix[slc]
            print('sliced pix sum', pix.sum())

            psf = GaussianMixturePSF(fit.amp, fit.mean * scale, fit.var * scale**2)
            psfmodel = psf.getPointSourcePatch(0., 0., radius=h // 2)
            mod = psfmodel.patch
            print('Scaled mod sum', mod.sum())
            mod = mod[slc]
            print('Sliced mod sum', mod.sum())

            _plot_psf(pix, mod, psf)
            plt.suptitle('W%i: Scaled' % band)
            ps.savefig()

            plotslice = slice(S - plotss, -(S - plotss)), slice(S - plotss, -(S - plotss))

        psfx = GaussianMixturePSF.fromStamp(pix, P0=(fit.amp, fit.mean * scale,
                                                     fit.var * scale**2))
        psfmodel = psfx.getPointSourcePatch(0., 0., radius=h // 2)
        mod = psfmodel.patch
        print('From stamp: mod sum', mod.sum())
        mod = mod[slc]
        print('Sliced mod sum:', mod.sum())
        _plot_psf(pix, mod, psfx)
        plt.suptitle('W%i: fromStamp: %s, res %.3f' %
                     (band, _psf_amps_summary(psfx.mog.amp), np.sum(pix - mod)))
        ps.savefig()

        print('Stamp-Fit PSF params:', psfx)

        sh, sw = pix.shape

        # Initialize from original fit params
        psfx = psf

        # Try concentric gaussian PSF: one circular sigma per mixture
        # component, from the geometric mean of its x and y variances.
        sigmas = []
        for k in range(psfx.mog.K):
            v = psfx.mog.var[k, :, :]
            sigmas.append(np.sqrt(np.sqrt(np.abs(v[0, 0] * v[1, 1]))))
        print('Initializing concentric Gaussian PSF with sigmas', sigmas)
        gpsf = NCircularGaussianPSF(sigmas, psfx.mog.amp)
        gpsf.radius = sh // 2
        mypsf = gpsf

        # Very high (uniform) inverse-variance: fit the stamp essentially exactly.
        tim = Image(data=pix, invvar=1e6 * np.ones_like(pix), psf=mypsf)
        tim.modelMinval = 1e-16

        cx, cy = sw // 2, sh // 2
        src = PointSource(PixPos(cx, cy), Flux(1.0))

        tractor = Tractor([tim], [src])

        # First fit the PSF together with the source position...
        tim.freezeAllBut('psf')
        src.freezeAllBut('pos')

        print('Optimizing Params:')
        tractor.printThawedParams()
        _optimize_psf_until_converged(tractor, gpsf)

        # ... then the PSF alone, with small finite-difference step sizes.
        tractor.freezeParam('catalog')
        gpsf.sigmas.stepsize = len(gpsf.sigmas) * [1e-6]
        gpsf.weights.stepsize = len(gpsf.sigmas) * [1e-6]
        _optimize_psf_until_converged(tractor, gpsf)

        print('PSF3(opt): flux', src.brightness)
        print('PSF amps:', np.sum(mypsf.amp))
        print('PSF amps * Source brightness:', src.brightness.getValue() * np.sum(mypsf.amp))
        print('pix sum:', pix.sum())

        print('Optimize source:', src)
        print('Optimized PSF:', mypsf)

        mod = tractor.getModelImage(0)
        print('Mod sum:', mod.sum())
        _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue())
        plt.suptitle('W%i psf3 (opt): %s, resid %.3f' %
                     (band, _psf_amps_summary(mypsf.amp), np.sum(pix - mod)))
        ps.savefig()

        # Write before normalizing!
        T = fits_table()
        T.amp = mypsf.mog.amp
        T.mean = mypsf.mog.mean
        T.var = mypsf.mog.var
        T.writeto('psf3-w%i.fits' % band)
        T.writeto('psf3.fits', append=(band != 1))

        mypsf.weights.setParams(
            np.array(mypsf.weights.getParams()) / sum(mypsf.weights.getParams()))
        print('Normalized PSF weights:', mypsf)
        mod = tractor.getModelImage(0)
        print('Mod sum:', mod.sum())
        _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue())
        plt.suptitle('W%i psf3 (opt): %s, resid %.3f' %
                     (band, _psf_amps_summary(mypsf.amp), np.sum(pix - mod)))
        ps.savefig()

        if band == 4:
            # HACK: hand-picked starting point for W4.
            mypsf = MyGaussianPSF([1.7, 6.4, 15.0], [0.333, 0.666, 0.1])
        else:
            mypsf = MyGaussianPSF(gpsf.sigmas, gpsf.amp)
        mypsf.radius = sh // 2

        tim.psf = mypsf

        print('Optimizing Params:')
        tractor.printThawedParams()
        _optimize_psf_until_converged(tractor, mypsf)

        mypsf.sigmas.stepsize = len(mypsf.sigmas) * [1e-6]
        mypsf.logweights.stepsize = len(mypsf.sigmas) * [1e-6]
        _optimize_psf_until_converged(tractor, mypsf)

        print('PSF amps:', np.sum(mypsf.amp))
        print('pix sum:', pix.sum())
        print('Optimize source:', src)
        print('Optimized PSF:', mypsf)

        mod = tractor.getModelImage(0)
        print('Mod sum:', mod.sum())
        _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue())
        plt.suptitle('W%i psf3 (opt2): %s, resid %.3f' %
                     (band, _psf_amps_summary(mypsf.amp), np.sum(pix - mod)))
        ps.savefig()

        # Write
        mog = mypsf.getMixtureOfGaussians()
        T = fits_table()
        T.amp = mog.amp
        T.mean = mog.mean
        T.var = mog.var
        T.writeto('psf4.fits', append=(band != 1))
#_meisner_psf_models() sys.exit(0) psf2 = GaussianMixturePSF(mypsf.mog.amp[:2] / mypsf.mog.amp[:2].sum(), mypsf.mog.mean[:2, :], mypsf.mog.var[:2, :, :]) psf2.radius = sh / 2 tim.psf = psf2 mod = tractor.getModelImage(0) _plot_psf(lpix, mod, psf2, flux=src.brightness.getValue()) plt.suptitle('psf2 (init)') ps.savefig() src.freezeAllBut('brightness') tractor.freezeParam('catalog') for i in range(100): print('Optimizing PSF:') tractor.printThawedParams() dlnp, X, alpha = tractor.optimize(damp=1.) tractor.freezeParam('images') tractor.thawParam('catalog') print('Optimizing flux:') tractor.printThawedParams() tractor.optimize_forced_photometry() tractor.thawParam('images')
def _meisner_psf_models(): global plotslice # Meisner's PSF models #for band in [4]: for band in [1,2,3,4]: print print 'W%i' % band print #pix = fitsio.read('wise-psf-avg-pix.fits', ext=band-1) pix = fitsio.read('wise-psf-avg-pix-bright.fits', ext=band-1) fit = fits_table('wise-psf-avg.fits', hdu=band) #fit = fits_table('psf-allwise-con3.fits', hdu=band) scale = 1. print 'Pix shape', pix.shape h,w = pix.shape xx,yy = np.meshgrid(np.arange(w), np.arange(h)) cx,cy = np.sum(xx*pix)/np.sum(pix), np.sum(yy*pix)/np.sum(pix) print 'Centroid:', cx,cy #S = 100 S = h/2 slc = slice(h/2-S, h/2+S+1), slice(w/2-S, w/2+S+1) plotss = 30 plotslice = slice(S-plotss, -(S-plotss)), slice(S-plotss, -(S-plotss)) opix = pix print 'Original pixels sum:', opix.sum() pix /= pix.sum() pix = pix[slc] print 'Sliced pix sum', pix.sum() psf = GaussianMixturePSF(fit.amp, fit.mean * scale, fit.var * scale**2) psfmodel = psf.getPointSourcePatch(0., 0., radius=h/2) mod = psfmodel.patch print 'Orig mod sum', mod.sum() mod = mod[slc] print 'Sliced mod sum', mod.sum() print 'Amps:', np.sum(psf.mog.amp) _plot_psf(pix, mod, psf) plt.suptitle('W%i: Orig' % band) ps.savefig() # Lanczos sub-sample if band == 4: lpix = _lanczos_subsample(opix, 2) pix = lpix h,w = pix.shape #S = 140 slc = slice(h/2-S, h/2+S+1), slice(w/2-S, w/2+S+1) print 'Resampled pix sum', pix.sum() pix = pix[slc] print 'sliced pix sum', pix.sum() psf = GaussianMixturePSF(fit.amp, fit.mean * scale, fit.var * scale**2) psfmodel = psf.getPointSourcePatch(0., 0., radius=h/2) mod = psfmodel.patch print 'Scaled mod sum', mod.sum() mod = mod[slc] print 'Sliced mod sum', mod.sum() _plot_psf(pix, mod, psf) plt.suptitle('W%i: Scaled' % band) ps.savefig() plotslice = slice(S-plotss,-(S-plotss)),slice(S-plotss,-(S-plotss)) psfx = GaussianMixturePSF.fromStamp(pix, P0=(fit.amp, fit.mean*scale, fit.var*scale**2)) psfmodel = psfx.getPointSourcePatch(0., 0., radius=h/2) mod = psfmodel.patch print 'From stamp: mod sum', mod.sum() mod = mod[slc] print 
'Sliced mod sum:', mod.sum() _plot_psf(pix, mod, psfx) plt.suptitle('W%i: fromStamp: %g = %s, res %.3f' % (band, np.sum(psfx.mog.amp), ','.join(['%.3f'%a for a in psfx.mog.amp]), np.sum(pix - mod))) ps.savefig() print 'Stamp-Fit PSF params:', psfx # class MyGaussianMixturePSF(GaussianMixturePSF): # def getLogPrior(self): # if np.any(self.mog.amp < 0.): # return -np.inf # for k in range(self.mog.K): # if np.linalg.det(self.mog.var[k]) <= 0: # return -np.inf # return 0 # # @property # def amp(self): # return self.mog.amp # # mypsf = MyGaussianMixturePSF(psfx.mog.amp, psfx.mog.mean, psfx.mog.var) # mypsf.radius = sh/2 sh,sw = pix.shape # Initialize from original fit params psfx = psf # Try concentric gaussian PSF sigmas = [] for k in range(psfx.mog.K): v = psfx.mog.var[k,:,:] sigmas.append(np.sqrt(np.sqrt(np.abs(v[0,0] * v[1,1])))) print 'Initializing concentric Gaussian PSF with sigmas', sigmas gpsf = NCircularGaussianPSF(sigmas, psfx.mog.amp) gpsf.radius = sh/2 mypsf = gpsf tim = Image(data=pix, invvar=1e6 * np.ones_like(pix), psf=mypsf) tim.modelMinval = 1e-16 # xx,yy = np.meshgrid(np.arange(sw), np.arange(sh)) # cx,cy = np.sum(xx*pix)/np.sum(pix), np.sum(yy*pix)/np.sum(pix) # print 'Centroid:', cx,cy # print 'Pix midpoint:', sw/2, sh/2 cx,cy = sw/2, sh/2 src = PointSource(PixPos(cx, cy), Flux(1.0)) tractor = Tractor([tim], [src]) tim.freezeAllBut('psf') #tractor.freezeParam('catalog') src.freezeAllBut('pos') # src.freezeAllBut('brightness') # tractor.freezeParam('images') # tractor.optimize_forced_photometry() # tractor.thawParam('images') # print 'Source flux after forced-photom fit:', src.getBrightness() print 'Optimizing Params:' tractor.printThawedParams() for i in range(200): #dlnp,X,alpha = tractor.optimize(damp=0.1) dlnp,X,alpha = tractor.optimize(damp=1) print i,'dlnp %.3g' % dlnp, 'psf', gpsf if dlnp < 1e-6: break tractor.freezeParam('catalog') gpsf.sigmas.stepsize = len(gpsf.sigmas) * [1e-6] gpsf.weights.stepsize = len(gpsf.sigmas) * [1e-6] for i in 
range(200): #dlnp,X,alpha = tractor.optimize(damp=0.1) dlnp,X,alpha = tractor.optimize(damp=1) print i,'dlnp %.3g' % dlnp, 'psf', gpsf if dlnp < 1e-6: break print 'PSF3(opt): flux', src.brightness print 'PSF amps:', np.sum(mypsf.amp) print 'PSF amps * Source brightness:', src.brightness.getValue() * np.sum(mypsf.amp) print 'pix sum:', pix.sum() print 'Optimize source:', src print 'Optimized PSF:', mypsf mod = tractor.getModelImage(0) print 'Mod sum:', mod.sum() _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue()) plt.suptitle('W%i psf3 (opt): %g = %s, resid %.3f' % (band, np.sum(mypsf.amp), ','.join(['%.3f'%a for a in mypsf.amp]), np.sum(pix-mod))) ps.savefig() # Write before normalizing! T = fits_table() T.amp = mypsf.mog.amp T.mean = mypsf.mog.mean T.var = mypsf.mog.var T.writeto('psf3-w%i.fits' % band) T.writeto('psf3.fits', append=(band != 1)) mypsf.weights.setParams(np.array(mypsf.weights.getParams()) / sum(mypsf.weights.getParams())) print 'Normalized PSF weights:', mypsf mod = tractor.getModelImage(0) print 'Mod sum:', mod.sum() _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue()) plt.suptitle('W%i psf3 (opt): %g = %s, resid %.3f' % (band, np.sum(mypsf.amp), ','.join(['%.3f'%a for a in mypsf.amp]), np.sum(pix-mod))) ps.savefig() class MyGaussianPSF(NCircularGaussianPSF): ''' A PSF model with strictly positive weights that sum to unity. 
''' def __init__(self, sigmas, weights): ww = np.array(weights) ww = np.log(ww) super(MyGaussianPSF, self).__init__(sigmas, ww) @staticmethod def getNamedParams(): return dict(sigmas=0, logweights=1) def __str__(self): return ('MyGaussianPSF: sigmas [ ' + ', '.join(['%.3f'%s for s in self.mysigmas]) + ' ], weights [ ' + ', '.join(['%.3f'%w for w in self.myweights]) + ' ]') @property def myweights(self): ww = np.exp(self.logweights.getAllParams()) ww /= ww.sum() return ww @property def weights(self): ww = np.exp(self.logweights.getParams()) wsum = np.sum(np.exp(self.logweights.getAllParams())) return ww / wsum if band == 4: # HACK mypsf = MyGaussianPSF([1.7, 6.4, 15.0], [0.333, 0.666, 0.1]) else: mypsf = MyGaussianPSF(gpsf.sigmas, gpsf.amp) mypsf.radius = sh/2 tim.psf = mypsf print 'Optimizing Params:' tractor.printThawedParams() for i in range(200): #dlnp,X,alpha = tractor.optimize(damp=0.1) dlnp,X,alpha = tractor.optimize(damp=1) print i,'dlnp %.3g' % dlnp, 'psf', mypsf if dlnp < 1e-6: break mypsf.sigmas.stepsize = len(mypsf.sigmas) * [1e-6] mypsf.logweights.stepsize = len(mypsf.sigmas) * [1e-6] for i in range(200): #dlnp,X,alpha = tractor.optimize(damp=0.1) dlnp,X,alpha = tractor.optimize(damp=1) print i,'dlnp %.3g' % dlnp, 'psf', mypsf if dlnp < 1e-6: break print 'PSF amps:', np.sum(mypsf.amp) print 'pix sum:', pix.sum() print 'Optimize source:', src print 'Optimized PSF:', mypsf mod = tractor.getModelImage(0) print 'Mod sum:', mod.sum() _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue()) plt.suptitle('W%i psf3 (opt2): %g = %s, resid %.3f' % (band, np.sum(mypsf.amp), ','.join(['%.3f'%a for a in mypsf.amp]), np.sum(pix-mod))) ps.savefig() # Write mog = mypsf.getMixtureOfGaussians() T = fits_table() T.amp = mog.amp T.mean = mog.mean T.var = mog.var T.writeto('psf4.fits', append=(band != 1))
#_meisner_psf_models() sys.exit(0) psf2 = GaussianMixturePSF(mypsf.mog.amp[:2]/mypsf.mog.amp[:2].sum(), mypsf.mog.mean[:2,:], mypsf.mog.var[:2,:,:]) psf2.radius = sh/2 tim.psf = psf2 mod = tractor.getModelImage(0) _plot_psf(lpix, mod, psf2, flux=src.brightness.getValue()) plt.suptitle('psf2 (init)') ps.savefig() src.freezeAllBut('brightness') tractor.freezeParam('catalog') for i in range(100): print 'Optimizing PSF:' tractor.printThawedParams() dlnp,X,alpha = tractor.optimize(damp=1.) tractor.freezeParam('images') tractor.thawParam('catalog') print 'Optimizing flux:' tractor.printThawedParams() tractor.optimize_forced_photometry() tractor.thawParam('images')