Ejemplo n.º 1
0
def main(opt, ps):
	#ralo = 36
	#rahi = 42
	#declo = -1.25
	#dechi = 1.25
	#width = 7
	  
	ralo = 37.5
	rahi = 41.5
	declo = -1.5
	dechi = 2.5
	width = 2.5

	rl,rh = 39,40
	dl,dh = 0,1
	roipoly = np.array([(rl,dl),(rl,dh),(rh,dh),(rh,dl)])

	ra  = (ralo  + rahi ) / 2.
	dec = (declo + dechi) / 2.

	bandnum = 1
	band = 'w%i' % bandnum
	plt.figure(figsize=(12,12))

	#basedir = '/project/projectdirs/bigboss'
	#wisedatadir = os.path.join(basedir, 'data', 'wise')

	wisedatadirs = ['/clusterfs/riemann/raid007/bosswork/boss/wise_level1b',
					'/clusterfs/riemann/raid000/bosswork/boss/wise1ext']

	wisecatdir = '/home/boss/products/NULL/wise/trunk/fits/'

	ofn = 'wise-images-overlapping.fits'

	if os.path.exists(ofn):
		print 'File exists:', ofn
		T = fits_table(ofn)
		print 'Found', len(T), 'images overlapping'

		print 'Reading WCS headers...'
		wcses = []
		T.filename = [fn.strip() for fn in T.filename]
		for fn in T.filename:
			wcs = anwcs(fn, 0)
			wcses.append(wcs)

	else:
		TT = []
		for d in wisedatadirs:
			ifn = os.path.join(d, 'WISE-index-L1b.fits') #'index-allsky-astr-L1b.fits')
			T = fits_table(ifn, columns=['ra','dec','scan_id','frame_num'])
			print 'Read', len(T), 'from WISE index', ifn
			I = np.flatnonzero((T.ra > ralo) * (T.ra < rahi) * (T.dec > declo) * (T.dec < dechi))
			print len(I), 'overlap RA,Dec box'
			T.cut(I)

			fns = []
			for sid,fnum in zip(T.scan_id, T.frame_num):
				print 'scan,frame', sid, fnum
				fn = get_l1b_file(d, sid, fnum, bandnum)
				print '-->', fn
				assert(os.path.exists(fn))
				fns.append(fn)
			T.filename = np.array(fns)
			TT.append(T)
		T = merge_tables(TT)

		wcses = []
		corners = []
		ii = []
		for i in range(len(T)):
			wcs = anwcs(T.filename[i], 0)
			W,H = wcs.get_width(), wcs.get_height()
			rd = []
			for x,y in [(1,1),(1,H),(W,H),(W,1)]:
				rd.append(wcs.pixelxy2radec(x,y))
			rd = np.array(rd)
			if polygons_intersect(roipoly, rd):
				wcses.append(wcs)
				corners.append(rd)
				ii.append(i)

		print 'Found', len(wcses), 'overlapping'
		I = np.array(ii)
		T.cut(I)

		outlines = corners
		corners = np.vstack(corners)

		nin = sum([1 if point_in_poly(ra,dec,ol) else 0 for ol in outlines])
		print 'Number of images containing RA,Dec,', ra,dec, 'is', nin

		r0,r1 = corners[:,0].min(), corners[:,0].max()
		d0,d1 = corners[:,1].min(), corners[:,1].max()
		print 'RA,Dec extent', r0,r1, d0,d1

		T.writeto(ofn)
		print 'Wrote', ofn


	# MAGIC 2.75: approximate pixel scale, "/pix
	S = int(3600. / 2.75)
	print 'Coadd size', S
	cowcs = anwcs_create_box(ra, dec, 1., S, S)

	if False:
		print 'Plotting map...'
		plot = Plotstuff(outformat='png', ra=ra, dec=dec, width=width, size=(800,800))
		out = plot.outline
		plot.color = 'white'
		plot.alpha = 0.07
		plot.apply_settings()

		for wcs in wcses:
			out.wcs = wcs
			out.fill = False
			plot.plot('outline')
			out.fill = True
			plot.plot('outline')

		plot.color = 'gray'
		plot.alpha = 1.0
		plot.lw = 1
		plot.plot_grid(1, 1, 1, 1)

		plot.color = 'red'
		plot.lw = 3
		plot.alpha = 0.75
		out.wcs = cowcs
		out.fill = False
		plot.plot('outline')

		if opt.sources:
			rd = plot.radec
			plot_radec_set_filename(rd, opt.sources)
			plot.plot('radec')

		pfn = ps.getnext()
		plot.write(pfn)
		print 'Wrote', pfn


	# Re-sort by distance to RA,Dec center...
	#I = np.argsort(np.hypot(T.ra - ra, T.dec - dec))
	#T.cut(I)
	# IF YOU DO THIS, MUST ALSO RE-SORT 'wcses'!

	
	if opt.sources:

		# Look at a radius this big, in arcsec, around each source position.
		# 15" = about 6 WISE pixels
		Wrad = 15. / 3600.

		# Look for SDSS objects within this radius; Wrad + a margin
		Srad = Wrad + 5./3600.


		S = fits_table(opt.sources)
		print 'Read', len(S), 'sources from', opt.sources

		groups,singles = cluster_radec(S.ra, S.dec, Wrad, singles=True)
		print 'Source clusters:', groups
		print 'Singletons:', singles

		tractors = []

		sdss = DR9(basedir='data-dr9')
		sband = 'r'

		for i in singles:
			r,d = S.ra[i],S.dec[i]
			print 'Source', i, 'at', r,d
			fn = sdss.retrieve('photoObj', S.run[i], S.camcol[i], S.field[i], band=sband)
			print 'Reading', fn
			oo = fits_table(fn)
			print 'Got', len(oo)
			cat1,obj1,I = get_tractor_sources_dr9(None, None, None, bandname=sband,
												  objs=oo, radecrad=(r,d,Srad), bands=[],
												  nanomaggies=True, extrabands=[band],
												  fixedComposites=True,
												  getobjs=True, getobjinds=True)
			print 'Got', len(cat1), 'SDSS sources nearby'

			# Find images that overlap?

			ims = []
			for j,wcs in enumerate(wcses):

				print 'Filename', T.filename[j]
				ok,x,y = wcs.radec2pixelxy(r,d)
				print 'WCS', j, '-> x,y:', x,y

				if not anwcs_radec_is_inside_image(wcs, r, d):
					continue

				tim = wise.read_wise_level1b(
					T.filename[j].replace('-int-1b.fits',''),
					nanomaggies=True, mask_gz=True, unc_gz=True,
					sipwcs=True, constantInvvar=True, radecrad=(r,d,Wrad))
				ims.append(tim)
			print 'Found', len(ims), 'images containing this source'

			tr = Tractor(ims, cat1)
			tractors.append(tr)
			

		if len(groups):
			# TODO!
			assert(False)

		sys.exit(0)



		# Find additional SDSS sources nearby = within R pixels radius.
		R = 30.
		#R = 50.
		rad = R * 0.396 / 3600.

		cats = []
		objs = []
		for run,camcol,field,r,d in zip(S.run, S.camcol, S.field, S.ra, S.dec):
			fn = sdss.retrieve('photoObj', run, camcol, field, band=sband)
			print 'Reading', fn
			oo = fits_table(fn)
			print 'Got', len(oo)
			cat1,obj1,I = get_tractor_sources_dr9(None, None, None, bandname=sband,
												  objs=oo, radecrad=(r,d,rad), bands=[],
												  nanomaggies=True, extrabands=[band],
												  fixedComposites=True,
												  getobjs=True, getobjinds=True)
			print 'Got', len(cat1), 'SDSS sources nearby'
			cats.append(cat1)
			objs.append(obj1[I])

		# Merge into one big catalog.
		cat = Catalog()
		for c in cats:
			for src in c:
				cat.append(src)
		S = merge_tables(objs)

		print 'Merged catalog has', len(cat), 'entries'
		print 'S table has', len(S)
		assert(len(S) == len(cat))

		if opt.ptsrc:
			print 'Converting all sources to PointSources'
			pcat = Catalog()
			for src in cat:
				ps = PointSource(src.getPosition(), src.getBrightness())
				pcat.append(ps)
			print 'PointSource catalog:', pcat
			cat = pcat

		# ??
		WW = S
		#WW = tabledata()

		# cat = get_tractor_sources_dr9(None, None, None, bandname=sband,
		# 							  objs=S, bands=[], nanomaggies=True,
		# 							  extrabands=[band])

		print 'Got', len(cat), 'tractor sources'
		#cat = Catalog(*cat)
		print cat
		for src in cat:
			print '  ', src

		### FIXME -- match to WISE catalog to initialize mags?

		# Initialize WISE mags to be at least detectable
		# so that we identify the right pixel ROIs below.

		#minbright = NanoMaggies.magToNanomaggies()
		#minbright = 50.
		minbright = 250.

		cat.freezeParamsRecursive('*')
		cat.thawPathsTo(band)
		p0 = cat.getParams()
		cat.setParams(np.maximum(minbright, p0))

		print 'Set minimum W1 brightness:'
		for src in cat:
			print '  ', src

		# Cut images that don't overlap.
		ii = []
		for i,wcs in enumerate(wcses):
			isin = False
			for r,d in zip(S.ra, S.dec):
				if anwcs_radec_is_inside_image(wcs, r, d):
					isin = True
					break
			if isin:
				ii.append(i)
		T.cut(np.array(ii))
		print 'Cut to', len(T), 'images containing sources'


		
	else:
		wfn = 'wise-sources-nearby.fits'
		if os.path.exists(wfn):
			print 'Reading existing file', wfn
			W = fits_table(wfn)
			print 'Got', len(W), 'with range RA', W.ra.min(), W.ra.max(), ', Dec', W.dec.min(), W.dec.max()
		else:
			# Range of WISE slices (inclusive) containing this Dec range.
			ws0, ws1 = 26,27
			WW = []
			for w in range(ws0, ws1+1):
				fn = os.path.join(wisecatdir, 'wise-allsky-cat-part%02i-radec.fits' % w)
				print 'Searching for sources in', fn
				W = fits_table(fn)
				I = np.flatnonzero((W.ra >= r0) * (W.ra <= r1) * (W.dec >= d0) * (W.dec <= d1))
				fn = os.path.join(wisecatdir, 'wise-allsky-cat-part%02i.fits' % w)
				print 'Reading', len(I), 'rows from', fn
				W = fits_table(fn, rows=I)
				print 'Cut to', len(W), 'sources in range'
				WW.append(W)
			W = merge_tables(WW)
			del WW
			print 'Total of', len(W)
			W.writeto(wfn)
			print 'wrote', wfn
	
		# DEBUG
		W.cut((W.ra >= rl) * (W.ra <= rh) * (W.dec >= dl) * (W.dec <= dh))
		print 'Cut to', len(W), 'in the central region'
	
		print 'Creating', len(W), 'Tractor sources'
		cat = Catalog()
		for i in range(len(W)):
			w1 = W.w1mpro[i]
			nm = NanoMaggies.magToNanomaggies(w1)
			cat.append(PointSource(RaDecPos(W.ra[i], W.dec[i]), NanoMaggies(w1=nm)))

		WW = W

	cat.freezeParamsRecursive('*')
	cat.thawPathsTo(band)

	cat0 = cat.getParams()
	br0 = [src.getBrightness().copy() for src in cat]
	nm0 = np.array([b.getBand(band) for b in br0])

	WW.nm0 = nm0

	w1psf = wise.get_psf_model(bandnum, opt.pixpsf)

	# Create fake image in the "coadd" footprint in order to find overlapping
	# sources.
	H,W = int(cowcs.imageh), int(cowcs.imagew)
	# MAGIC -- sigma a bit smaller than typical images (4.0-ish)
	sig = 3.5
	# typical zeropoint
	zp = 20.752
	
	faketim = Image(data=np.zeros((H,W), np.float32),
					invvar=np.zeros((H,W), np.float32) + (1./sig**2),
					psf=w1psf, wcs=ConstantFitsWcs(cowcs), sky=ConstantSky(0.),
					photocal = LinearPhotoCal(NanoMaggies.zeropointToScale(zp),
											  band=band),
					#photocal=LinearPhotoCal(1., band=band),
					name='fake')
	minsb = 0.1 * sig
	#minsb = 0.

	# pc = faketim.getPhotoCal()
	# print 'Source counts:'
	# for src in cat:
	# 	print '  ', src
	# 	print '-->', pc.brightnessToCounts(src.getBrightness())
	# 	print '  -->', [pc.brightnessToCounts(br) for br in src.getBrightnesses()]
	# print 'Source pixel positions:'
	# wcs = faketim.getWcs()
	# for src in cat:
	# 	print '  ', src
	# 	print '--> x,y', wcs.positionToPixel(src.getPosition())
	

	print 'Finding overlapping sources...'
	t0 = Time()
	tractor = Tractor([faketim], cat)
	groups,L,fakemod = tractor.getOverlappingSources(0, minsb=minsb)
	print 'Overlapping sources took', Time()-t0
	print 'Got', len(groups), 'groups of sources'
	nl = L.max()
	gslices = find_objects(L, nl)

	print 'unique labels:', np.unique(L)

	# plt.clf()
	# plt.imshow(fakemod, interpolation='nearest', origin='lower',
	# 		   vmin=0, vmax=sig*3.)
	# plt.title('Fakemod')
	# ps.savefig()
	# 
	# for IM in [L, (L>0)]:
	# 	plt.clf()
	# 	plt.imshow(IM, interpolation='nearest', origin='lower')
	# 	plt.gray()
	# 	wcs = faketim.getWcs()
	# 	xy = []
	# 	for src in cat:
	# 		x,y = wcs.positionToPixel(src.getPosition())
	# 		xy.append((x,y))
	# 	xy = np.array(xy)
	# 	ax = plt.axis()
	# 	plt.plot(xy[:,0], xy[:,1], 'r+')
	# 	plt.title('Source groups')
	# 	ps.savefig()

	# Find sources touching each group's (rectangular) ROI
	tgroups = {}
	for i,gslice in enumerate(gslices):
		gl = i+1
		tg = np.unique(L[gslice])
		tsrcs = []
		for g in tg:
			if not g in [gl,0]:
				if g in groups:
					tsrcs.extend(groups[g])
		tgroups[gl] = tsrcs


	# for i,gslice in enumerate(gslices):
	# 	if not (i+1) in groups:
	# 		continue
	# 
	# 	plt.clf()
	# 	plt.imshow(IM[gslice], interpolation='nearest', origin='lower')
	# 	plt.gray()
	# 	wcs = faketim.getWcs()
	# 	xy = []
	# 	y0,x0 = gslice[0].start, gslice[1].start
	# 	for src in cat:
	# 		x,y = wcs.positionToPixel(src.getPosition())
	# 		xy.append((x-x0,y-y0))
	# 	xy = np.array(xy)
	# 
	# 	ax = plt.axis()
	# 
	# 	plt.plot(xy[:,0], xy[:,1], 'r+')
	# 
	# 	I = np.array(groups[i+1])
	# 	if len(I):
	# 		plt.plot(xy[I,0], xy[I,1], 'g.')
	# 
	# 	I = np.array(tgroups[i+1])
	# 	if len(I):
	# 		plt.plot(xy[I,0], xy[I,1], 'gx')
	# 
	# 	ps.savefig()
	# 
	# 	plt.axis(ax)
	# 	ps.savefig()



	print 'Group size histogram:'
	ng = Counter()
	for g in groups.values():
		ng[len(g)] += 1
	kk = ng.keys()
	kk.sort()
	for k in kk:
		print '  ', k, 'sources:', ng[k], 'groups'

	nms = []
	tims = []
	allrois = {}
	badrois = {}

	if opt.threads:
		mp = multiproc(opt.threads)
	else:
		mp = multiproc(1)

	tims = mp.map(_read_l1b, T.filename)

	for imi,tim in enumerate(tims):
		tim.psf = w1psf
		H,W = tim.shape
		nin = 0
		for src in cat:
			x,y = tim.getWcs().positionToPixel(src.getPosition())
			if x >= 0 and y >= 0 and x < W and y < H:
				nin += 1
		print 'Number of sources inside image:', nin

		tractor = Tractor([tim], cat)
		tractor.freezeParam('images')
		### ??
		cat.setParams(cat0)

		pgroups = 0
		pobjs = 0

		for gi in range(len(gslices)):
			gl = gi
			# note, gslices is zero-indexed
			gslice = gslices[gl]
			gl += 1
			if not gl in groups:
				print 'Group', gl, 'not in groups array; skipping'
				continue
			gsrcs = groups[gl]
			tsrcs = tgroups[gl]

			# print 'Group number', (gi+1), 'of', len(Gorder), ', id', gl, ': sources', gsrcs
			# print 'sources in groups touching slice:', tsrcs

			# Convert from 'canonical' ROI to this image.
			yl,yh = gslice[0].start, gslice[0].stop
			xl,xh = gslice[1].start, gslice[1].stop
			x0,y0 = W-1,H-1
			x1,y1 = 0,0
			for x,y in [(xl,yl),(xh-1,yl),(xh-1,yh-1),(xl,yh-1)]:
				r,d = cowcs.pixelxy2radec(x+1, y+1)
				x,y = tim.getWcs().positionToPixel(RaDecPos(r,d))
				x = int(np.round(x))
				y = int(np.round(y))

				x = np.clip(x, 0, W-1)
				y = np.clip(y, 0, H-1)
				x0 = min(x0, x)
				y0 = min(y0, y)
				x1 = max(x1, x)
				y1 = max(y1, y)
			if x1 == x0 or y1 == y0:
				print 'Gslice', gslice, 'is completely outside this image'
				continue
			
			gslice = (slice(y0,y1+1), slice(x0, x1+1))

			if np.all(tim.getInvError()[gslice] == 0):
				print 'This whole object group has invvar = 0.'

				if not gl in badrois:
					badrois[gl] = {}
				badrois[gl][imi] = gslice

				continue

			if not gl in allrois:
				allrois[gl] = {}
			allrois[gl][imi] = gslice

			if not opt.individual:
				continue

			fullcat = tractor.catalog
			subcat = Catalog(*[fullcat[i] for i in gsrcs + tsrcs])
			for i in range(len(tsrcs)):
				subcat.freezeParam(len(gsrcs) + i)
			tractor.catalog = subcat

			print len(gsrcs), 'sources unfrozen; total', len(subcat)

			pgroups += 1
			pobjs += len(gsrcs)
			
			t0 = Time()
			tractor.optimize_forced_photometry(minsb=minsb, mindlnp=1.,
											   rois=[gslice])
			print 'optimize_forced_photometry took', Time()-t0

			tractor.catalog = fullcat

		# mod = tractor.getModelImage(0, minsb=minsb)
		# noise = np.random.normal(size=mod.shape)
		# noise[tim.getInvError() == 0] = 0.
		# nz = (tim.getInvError() > 0)
		# noise[nz] *= (1./tim.getInvError()[nz])
		# ima = dict(interpolation='nearest', origin='lower',
		# 		   vmin=tim.zr[0], vmax=tim.zr[1])
		# imchi = dict(interpolation='nearest', origin='lower',
		# 		   vmin=-5, vmax=5)
		# plt.clf()
		# plt.subplot(2,2,1)
		# plt.imshow(tim.getImage(), **ima)
		# plt.gray()
		# plt.subplot(2,2,2)
		# plt.imshow(mod, **ima)
		# plt.gray()
		# plt.subplot(2,2,3)
		# plt.imshow((tim.getImage() - mod) * tim.getInvError(), **imchi)
		# plt.gray()
		# plt.subplot(2,2,4)
		# plt.imshow(mod + noise, **ima)
		# plt.gray()
		# plt.suptitle('W1, scan %s, frame %i' % (sid, fnum))
		# ps.savefig()

		if opt.individual:
			print 'Photometered', pgroups, 'groups containing', pobjs, 'objects'
	
			cat.thawPathsTo(band)
			nm1 = np.array([src.getBrightness().getBand(band) for src in cat])
			nms.append(nm1)
	
			WW.nms = np.array(nms).T
			fn = opt.output % imi
			WW.writeto(fn)
			print 'Wrote', fn

	return dict(cat0=cat0, WW=WW, band=band, tims=tims,
				allrois=allrois, badrois=badrois, groups=groups,
				tgroups=tgroups, minsb=minsb,
				gslices=gslices, cat=cat)
Ejemplo n.º 2
0
def simult_photom(cat0=None, WW=None, band=None, tims=None,
				  allrois=None, badrois=None, groups=None,
				  tgroups=None, minsb=None,
				  gslices=None, cat=None,
				  opt=None, ps=None):
	'''
	Simultaneous forced photometry of each source group over all of the
	WISE images that the group touches.

	The data keyword arguments have the same names as the keys of the
	dict returned by main() in this file:

	cat0: initial catalog params (not used in this function).
	WW: output table; receives new columns and is written to disk
		(presumably one row per catalog source -- WW.nmall is assigned
		one value per source in `cat`).
	band: band name (e.g. 'w1'); the brightness param that is fitted.
	tims: list of Tractor images, indexed by the image numbers used as
		keys in allrois / badrois.
	allrois, badrois: {group label: {image index: ROI slice tuple}};
		badrois holds ROIs whose inverse error was entirely zero.
	groups: {group label: list of catalog indices} to fit.
	tgroups: {group label: list of catalog indices} of neighbouring
		sources touching the group's ROI; these are held fixed.
	minsb: minimum surface brightness passed to forced photometry.
	gslices: list of group ROIs; group label = list index + 1.
	cat: the full Tractor catalog; its band params (and, with opt.opt,
		its RA,Dec params) are modified in place.
	opt: options; attributes read are osources, opt, plotmask, output.
	ps: plot-sequence helper providing savefig() and getnext().

	Side effects: saves plots via ps / plt.savefig, sets WW.nmall (and,
	with opt.opt, WW.nmoptrd, WW.raoptrd, WW.decoptrd), writes
	opt.output % 998 (only with opt.opt) and opt.output % 999.
	Returns None.
	'''

	def _plot_grid(ims, kwas):
		'''
		Clear the figure and show each image in `ims` in a near-square
		grid of R rows x C columns of subplots, using the matching imshow
		kwargs dict from `kwas`.  Returns (R, C).
		'''
		N = len(ims)
		C = int(np.ceil(np.sqrt(N)))
		R = int(np.ceil(N / float(C)))
		plt.clf()
		for i,(im,kwa) in enumerate(zip(ims, kwas)):
			plt.subplot(R,C, i+1)
			#print 'plotting grid cell', i, 'img shape', im.shape
			plt.imshow(im, **kwa)
			plt.gray()
			plt.xticks([]); plt.yticks([])
		return R,C

	def _plot_grid2(ims, cat, tims, kwas, ptype='mod'):
		'''
		Plot the model (ptype='mod') or chi (ptype='chi') stamp of each
		(img, mod, chi, roi) tuple in `ims` -- one per image in `tims` --
		with _plot_grid, then overlay every source in `cat` as a red '+'
		at its pixel position relative to that ROI's origin.

		NOTE(review): any other ptype leaves the stamp list empty, so
		_plot_grid would divide by zero.
		'''
		xys = []
		stamps = []
		for (img,mod,chi,roi),tim in zip(ims, tims):
			if ptype == 'mod':
				stamps.append(mod)
			elif ptype == 'chi':
				stamps.append(chi)
			wcs = tim.getWcs()
			# ROI origin: source pixel positions are shifted into stamp coords.
			y0,x0 = roi[0].start, roi[1].start
			xy = []
			for src in cat:
				xi,yi = wcs.positionToPixel(src.getPosition())
				xy.append((xi - x0, yi - y0))
			xys.append(xy)
			#print 'X,Y source positions in stamp of shape', stamps[-1].shape
			#print '  ', xy
		R,C = _plot_grid(stamps, kwas)
		for i,xy in enumerate(xys):
			plt.subplot(R, C, i+1)
			# Save and restore the axis limits so the markers do not
			# change the view set by imshow.
			ax = plt.axis()
			xy = np.array(xy)
			plt.plot(xy[:,0], xy[:,1], 'r+', lw=2)
			plt.axis(ax)

		

	# Simultaneous photometry

	# Optional comparison catalog (opt.osources) whose models are plotted
	# alongside ours; its first 'wiseflux' column is used as the w1 flux.
	if opt.osources:
		O = fits_table(opt.osources)
		ocat = Catalog()
		print 'Other catalog:'
		for i in range(len(O)):
			w1 = O.wiseflux[i, 0]
			s = PointSource(RaDecPos(O.ra[i], O.dec[i]), NanoMaggies(w1=w1))
			ocat.append(s)
		print ocat
		ocat.freezeParamsRecursive('*')
		ocat.thawPathsTo(band)

	# Keep track of params after simultaneous photometry...
	# catsim: band-only params.  catopt (opt.opt only): band + RA,Dec params.
	# Both are carried across groups so each group starts from the results
	# of earlier groups.
	cat.freezeParamsRecursive('*')
	cat.thawPathsTo(band)
	catsim = cat.getParams()
	if opt.opt:
		# ... and also after RA,Dec opt.
		cat.thawPathsTo('ra','dec')
		catopt = cat.getParams()
		cat.freezeParamsRecursive('*')
		cat.thawPathsTo(band)

	# Group labels are 1-based: gslices[gi] belongs to label gi+1.
	for gi in range(len(gslices)):
		gl = gi
		gl += 1
		if not gl in groups:
			print 'Group', gl, 'not in groups array; skipping'
			continue
		gsrcs = groups[gl]
		tsrcs = tgroups[gl]

		print 'Group', gl
		print 'gsrcs:', gsrcs
		print 'tsrcs:', tsrcs

		if (not gl in allrois) and (not gl in badrois):
			print 'Group', gl, 'does not touch any images?'
			continue

		# Images (and the ROI in each) where this group has usable pixels.
		mytims = []
		rois = []
		if gl in allrois:
			for imi,roi in allrois[gl].items():
				mytims.append(tims[imi])
				rois.append(roi)

		# Images where this group's ROI had zero inverse error everywhere.
		mybadtims = []
		mybadrois = []
		if gl in badrois:
			for imi,roi in badrois[gl].items():
				mybadtims.append(tims[imi])
				mybadrois.append(roi)

		print 'Group', gl, 'touches', len(mytims), 'images and', len(mybadtims), 'bad ones'

		tt = 'group %i: %i+%i sources' % (gl, len(gsrcs), len(tsrcs))

		if len(mytims):

			# Start from the band params saved after earlier groups.
			cat.setParams(catsim)

			#print 'Restoring catsim:'
			#cat.printThawedParams()

			# Sub-catalog: the group's own sources (fitted) first, then the
			# touching neighbours (frozen).  Entries are shared with `cat`.
			subcat = Catalog(*[cat[i] for i in gsrcs + tsrcs])
			for i in range(len(tsrcs)):
				subcat.freezeParam(len(gsrcs) + i)

			tractor = Tractor(mytims, subcat)
			tractor.freezeParam('images')

			print len(gsrcs), 'sources unfrozen; total', len(subcat)

			print 'Before fitting:'
			for src in subcat[:len(gsrcs)]:
				print '  ', src
				
			t0 = Time()
			# ims0 / ims1: lists of (img, mod, chi, roi) per image, before
			# and after the fit (ims1 may be None).
			ims0,ims1 = tractor.optimize_forced_photometry(minsb=minsb, mindlnp=1.,
														   rois=rois)
			print 'optimize_forced_photometry took', Time()-t0

			print 'After fitting:'
			for src in subcat[:len(gsrcs)]:
				print '  ', src

			imas = [dict(interpolation='nearest', origin='lower',
						 vmin=tim.zr[0], vmax=tim.zr[1])
					for tim in mytims]
			imchi = dict(interpolation='nearest', origin='lower', vmin=-5, vmax=5)
			imchis = [imchi] * len(mytims)

			_plot_grid([img for (img, mod, chi, roi) in ims0], imas)
			plt.suptitle('Data: ' + tt)
			ps.savefig()

			if ims1 is not None:
				#_plot_grid([mod for (img, mod, chi, roi) in ims1], imas)
				_plot_grid2(ims1, subcat, mytims, imas)
				plt.suptitle('Forced-phot model: ' + tt)
				ps.savefig()

				#_plot_grid([chi for (img, mod, chi, roi) in ims1], imchis)
				_plot_grid2(ims1, subcat, mytims, imchis, ptype='chi')
				plt.suptitle('Forced-phot chi: ' + tt)
				ps.savefig()

			if opt.osources:
				# Temporarily swap in the comparison catalog just to render
				# its model / chi in the same ROIs (justims0=True).
				cc = tractor.catalog
				tractor.catalog = ocat
				nil,nil,ims3 = tractor.optimize_forced_photometry(minsb=minsb, rois=rois,
																  justims0=True)
				tractor.catalog = cc

				_plot_grid2(ims3, ocat, mytims, imas)
				plt.suptitle("Schlegel's model: group %i" % gl)
				ps.savefig()

				_plot_grid2(ims3, ocat, mytims, imchis, ptype='chi')
				plt.suptitle("Schlegel's chi: group %i" % gl)
				ps.savefig()
				


			if opt.opt:
				# Reserve two plot filenames now, so the RA,Dec-opt plots
				# made below keep their place in the plot sequence.
				op1 = ps.getnext()
				op2 = ps.getnext()
				#fits[gl] = (tractor, len(gsrcs), rois, op1, op2)


			# print 'Plotting mods after simul photom'
			# #_plot_grid([mod for (img, mod, chi, roi) in ims0], imas)
			# _plot_grid2(ims0, subcat, mytims, imas)
			# plt.suptitle('Initial model: ' + tt)
			# ps.savefig()
			# 
			# print 'Plotting chis after simul photom'
			# #_plot_grid([chi for (img, mod, chi, roi) in ims0], imchis)
			# _plot_grid2(ims0, subcat, mytims, imchis, ptype='chi')
			# plt.suptitle('Initial chi: ' + tt)
			# ps.savefig()

			print 'After simultaneous photometry:'
			subcat.printThawedParams()

			# Copy updated params to "catsim"
			catsim = cat.getParams()

			#print 'Saving catsim:'
			#cat.printThawedParams()

			cat.freezeParamsRecursive('*')
			cat.thawPathsTo(band)
			# Overwritten after every fitted group; holds the latest values.
			WW.nmall = np.array([src.getBrightness().getBand(band) for src in cat])



		# Optional second stage: also fit RA,Dec of the group's sources,
		# starting from catopt with this group's forced-phot fluxes.
		if len(mytims) and opt.opt:
			print 'Optimizing RA,Dec'

			subcat = tractor.catalog

			# Copy updated forced-phot params from catsim to catopt.

			#print 'Saving subcat forced-phot params:'
			#subcat.printThawedParams()
			fphot = subcat.getParams()

			cat.thawPathsTo('ra','dec')
			cat.setParams(catopt)

			#print 'Copying forced-phot results to catopt:'
			cat.freezeParamsRecursive('*')
			cat.thawPathsTo(band)
			cat.freezeAllBut(*gsrcs)
			#cat.printThawedParams()
			# NOTE(review): this assumes the params thawed by
			# freezeAllBut(*gsrcs) come in the same order as subcat's first
			# len(gsrcs) sources, i.e. that gsrcs is sorted -- confirm.
			NP = cat.numberOfParams()
			cat.setParams(fphot[:NP])
			#print 'Result:'

			#print 'Restoring catopt:'
			#cat.printThawedParams()

			cat.freezeParamsRecursive('*')
			cat.thawPathsTo(band)
			NG = len(gsrcs)
			for i in range(NG):
				subcat[i].thawPathsTo('ra','dec')
			p0 = subcat.getParams()
			print 'Optimizing params:'
			subcat.printThawedParams()

			# Fit on cut-out sub-images covering only each ROI, with the
			# WCS shifted to the ROI origin.
			thetims = tractor.images
			subimgs = []
			for i,img in enumerate(thetims):
				roi = rois[i]
				y0 = roi[0].start
				x0 = roi[1].start
				subwcs = ShiftedWcs(img.wcs, x0, y0)
				subimg = Image(data=img.data[roi], invvar=img.invvar[roi],
							   psf=img.psf, wcs=subwcs, sky=img.sky,
							   photocal=img.photocal, name=img.name)
				subimgs.append(subimg)
			tractor.images = Images(*subimgs)

			# Iterate until the improvement in log-probability drops below
			# 0.1 (there is no iteration cap).
			while True:
				dlnp,X,alpha = tractor.optimize()
				print 'dlnp', dlnp
				print 'alpha', alpha
				if dlnp < 0.1:
					break

			p1 = subcat.getParams()

			print 'Param changes:'
			for nm,pp0,pp1 in zip(subcat.getParamNames(), p0, p1):
				print '  ', nm, pp0, 'to', pp1, '; delta', pp1-pp0


			# Save band + RA,Dec params for the next group.
			cat.thawPathsTo('ra','dec')
			catopt = cat.getParams()

			print 'Saving catopt:'
			cat.printThawedParams()

			cat.freezeParamsRecursive('ra', 'dec')

			# Back to the full images for the final model / chi plots.
			tractor.images = thetims

			nil,nil,ims2 = tractor.optimize_forced_photometry(minsb=minsb, rois=rois,
															  justims0=True)

			print 'Plotting mods after RA,Dec opt'
			#_plot_grid([mod for (img, mod, chi, roi) in ims2], imas)
			_plot_grid2(ims2, subcat, mytims, imas)
			plt.suptitle('RA,Dec-opt model: ' + tt)
			plt.savefig(op1)
			
			print 'Plotting chis after RA,Dec opt'
			#_plot_grid([chi for (img, mod, chi, roi) in ims2], imchis)
			_plot_grid2(ims2, subcat, mytims, imchis, ptype='chi')
			plt.suptitle('RA,Dec-opt chi: ' + tt)
			plt.savefig(op2)







		# Disabled ('and False') debug plots of the zero-invvar ROIs.
		N = len(mybadtims)
		if N and False:
			C = int(np.ceil(np.sqrt(N)))
			R = int(np.ceil(N / float(C)))
			plt.clf()
			for i,(tim,roi) in enumerate(zip(mybadtims, mybadrois)):
				plt.subplot(R,C, i+1)
				plt.imshow(tim.getImage()[roi], interpolation='nearest', origin='lower',
						   vmin=tim.zr[0], vmax=tim.zr[1])
				plt.gray()
			plt.suptitle('Data in bad regions')
			ps.savefig()

			plt.clf()
			for i,(tim,roi) in enumerate(zip(mybadtims, mybadrois)):
				plt.subplot(R,C, i+1)
				plt.imshow(tim.getInvError()[roi], interpolation='nearest', origin='lower')
				plt.gray()
			plt.suptitle('Inverr in bad regions')
			ps.savefig()

		# For the first group only (and only if it was not skipped above),
		# plot the uncertainty plane and each mask bit inside its ROIs.
		if gi == 0 and opt.plotmask:

			alltims = mybadtims+mytims
			_plot_grid([tim.uncplane[roi] for tim,roi in zip(alltims,
															 mybadrois + rois)],
					   [dict(interpolation='nearest', origin='lower')]*len(alltims))
			plt.suptitle('Uncertainty plane')
			ps.savefig()
	
			# (bit number, description) pairs for the mask plane.
			for bit,txt in [
				(0 ,  'static: excessively noisy due to high dark current alone'),
				(1 ,  'static: generally noisy [includes bit 0]'),
				(2 ,  'static: dead or very low responsivity'),
				(3 ,  'static: low responsivity or low dark current'),
				(4 ,  'static: high responsivity or high dark current'),
				(5 ,  'static: saturated anywhere in ramp'),
				(6 ,  'static: high, uncertain, or unreliable non-linearity'),
				(7 ,  'static: known broken hardware pixel or excessively noisy responsivity estimate [may include bit 1]'),
				(9 ,  'broken pixel or negative slope fit value'),
				(10,  'saturated in sample read 1'),
				(11,  'saturated in sample read 2'),
				(12,  'saturated in sample read 3'),
				(13,  'saturated in sample read 4'),
				(14,  'saturated in sample read 5'),
				(15,  'saturated in sample read 6'),
				(16,  'saturated in sample read 7'),
				(17,  'saturated in sample read 8'),
				(18,  'saturated in sample read 9'),
				(21,  'new/transient bad pixel from dynamic masking'),
				(26,  'non-linearity correction unreliable'),
				(27,  'contains cosmic-ray or outlier that cannot be classified (from temporal outlier rejection in multi-frame pipeline)'),
				(28,  'contains positive or negative spike-outlier'),
				]:

				_plot_grid([tim.maskplane[roi] & (1 << bit)
							for tim,roi in zip(alltims, mybadrois + rois)],
					   [dict(interpolation='nearest', origin='lower',
							 vmin=0, vmax=1)]*len(alltims))
				plt.suptitle('Mask: ' + txt)
				ps.savefig()



	if opt.opt:

		# Written before the RA,Dec-opt columns below are added.
		fn = opt.output % 998
		WW.writeto(fn)
		print 'Wrote', fn

		# Restore the RA,Dec-optimized params and record the brightness
		# and RA and Dec values separately, by thawing one kind at a time.
		cat.thawPathsTo('ra','dec')
		cat.setParams(catopt)
		WW.nmoptrd = np.array([src.getBrightness().getBand(band) for src in cat])
		cat.freezeParamsRecursive(band, 'dec')
		WW.raoptrd = np.array(cat.getParams())
		cat.freezeParamsRecursive('ra')
		cat.thawPathsTo('dec')
		WW.decoptrd = np.array(cat.getParams())
		cat.freezeParamsRecursive('dec')

	fn = opt.output % 999
	WW.writeto(fn)
	print 'Wrote', fn
Ejemplo n.º 3
0
def _meisner_psf_models():
    global plotslice
    # Meisner's PSF models
    # for band in [4]:
    for band in [1, 2, 3, 4]:
        print()
        print('W%i' % band)
        print()

        #pix = fitsio.read('wise-psf-avg-pix.fits', ext=band-1)
        pix = fitsio.read('wise-psf-avg-pix-bright.fits', ext=band - 1)
        fit = fits_table('wise-psf-avg.fits', hdu=band)
        #fit = fits_table('psf-allwise-con3.fits', hdu=band)

        scale = 1.

        print('Pix shape', pix.shape)
        h, w = pix.shape

        xx, yy = np.meshgrid(np.arange(w), np.arange(h))
        cx, cy = np.sum(xx * pix) / np.sum(pix), np.sum(yy * pix) / np.sum(pix)
        print('Centroid:', cx, cy)

        #S = 100
        S = h / 2
        slc = slice(h / 2 - S, h / 2 + S + 1), slice(w / 2 - S, w / 2 + S + 1)

        plotss = 30
        plotslice = slice(S - plotss,
                          -(S - plotss)), slice(S - plotss, -(S - plotss))

        opix = pix
        print('Original pixels sum:', opix.sum())
        pix /= pix.sum()
        pix = pix[slc]
        print('Sliced pix sum', pix.sum())

        psf = GaussianMixturePSF(fit.amp, fit.mean * scale, fit.var * scale**2)
        psfmodel = psf.getPointSourcePatch(0., 0., radius=h / 2)
        mod = psfmodel.patch
        print('Orig mod sum', mod.sum())
        mod = mod[slc]
        print('Sliced mod sum', mod.sum())
        print('Amps:', np.sum(psf.mog.amp))

        _plot_psf(pix, mod, psf)
        plt.suptitle('W%i: Orig' % band)
        ps.savefig()

        # Lanczos sub-sample
        if band == 4:
            lpix = _lanczos_subsample(opix, 2)
            pix = lpix
            h, w = pix.shape

            #S = 140
            slc = slice(h / 2 - S,
                        h / 2 + S + 1), slice(w / 2 - S, w / 2 + S + 1)
            print('Resampled pix sum', pix.sum())
            pix = pix[slc]
            print('sliced pix sum', pix.sum())

            psf = GaussianMixturePSF(fit.amp, fit.mean * scale,
                                     fit.var * scale**2)
            psfmodel = psf.getPointSourcePatch(0., 0., radius=h / 2)
            mod = psfmodel.patch
            print('Scaled mod sum', mod.sum())
            mod = mod[slc]
            print('Sliced mod sum', mod.sum())

            _plot_psf(pix, mod, psf)
            plt.suptitle('W%i: Scaled' % band)
            ps.savefig()

            plotslice = slice(S - plotss,
                              -(S - plotss)), slice(S - plotss, -(S - plotss))

        psfx = GaussianMixturePSF.fromStamp(pix,
                                            P0=(fit.amp, fit.mean * scale,
                                                fit.var * scale**2))
        psfmodel = psfx.getPointSourcePatch(0., 0., radius=h / 2)
        mod = psfmodel.patch
        print('From stamp: mod sum', mod.sum())
        mod = mod[slc]
        print('Sliced mod sum:', mod.sum())
        _plot_psf(pix, mod, psfx)
        plt.suptitle('W%i: fromStamp: %g = %s, res %.3f' %
                     (band, np.sum(psfx.mog.amp), ','.join(
                         ['%.3f' % a
                          for a in psfx.mog.amp]), np.sum(pix - mod)))
        ps.savefig()

        print('Stamp-Fit PSF params:', psfx)

        # class MyGaussianMixturePSF(GaussianMixturePSF):
        #     def getLogPrior(self):
        #         if np.any(self.mog.amp < 0.):
        #             return -np.inf
        #         for k in range(self.mog.K):
        #             if np.linalg.det(self.mog.var[k]) <= 0:
        #                 return -np.inf
        #         return 0
        #
        #     @property
        #     def amp(self):
        #         return self.mog.amp
        #
        # mypsf = MyGaussianMixturePSF(psfx.mog.amp, psfx.mog.mean, psfx.mog.var)
        # mypsf.radius = sh/2

        sh, sw = pix.shape

        # Initialize from original fit params
        psfx = psf

        # Try concentric gaussian PSF
        sigmas = []
        for k in range(psfx.mog.K):
            v = psfx.mog.var[k, :, :]
            sigmas.append(np.sqrt(np.sqrt(np.abs(v[0, 0] * v[1, 1]))))
        print('Initializing concentric Gaussian PSF with sigmas', sigmas)
        gpsf = NCircularGaussianPSF(sigmas, psfx.mog.amp)
        gpsf.radius = sh / 2
        mypsf = gpsf

        tim = Image(data=pix, invvar=1e6 * np.ones_like(pix), psf=mypsf)
        tim.modelMinval = 1e-16

        # xx,yy = np.meshgrid(np.arange(sw), np.arange(sh))
        # cx,cy = np.sum(xx*pix)/np.sum(pix), np.sum(yy*pix)/np.sum(pix)
        # print 'Centroid:', cx,cy
        # print 'Pix midpoint:', sw/2, sh/2
        cx, cy = sw / 2, sh / 2
        src = PointSource(PixPos(cx, cy), Flux(1.0))

        tractor = Tractor([tim], [src])

        tim.freezeAllBut('psf')
        # tractor.freezeParam('catalog')
        src.freezeAllBut('pos')

        # src.freezeAllBut('brightness')
        # tractor.freezeParam('images')
        # tractor.optimize_forced_photometry()
        # tractor.thawParam('images')
        # print 'Source flux after forced-photom fit:', src.getBrightness()

        print('Optimizing Params:')
        tractor.printThawedParams()

        for i in range(200):
            #dlnp,X,alpha = tractor.optimize(damp=0.1)
            dlnp, X, alpha = tractor.optimize(damp=1)
            print(i, 'dlnp %.3g' % dlnp, 'psf', gpsf)
            if dlnp < 1e-6:
                break

        tractor.freezeParam('catalog')
        gpsf.sigmas.stepsize = len(gpsf.sigmas) * [1e-6]
        gpsf.weights.stepsize = len(gpsf.sigmas) * [1e-6]

        for i in range(200):
            #dlnp,X,alpha = tractor.optimize(damp=0.1)
            dlnp, X, alpha = tractor.optimize(damp=1)
            print(i, 'dlnp %.3g' % dlnp, 'psf', gpsf)
            if dlnp < 1e-6:
                break

        print('PSF3(opt): flux', src.brightness)
        print('PSF amps:', np.sum(mypsf.amp))
        print('PSF amps * Source brightness:',
              src.brightness.getValue() * np.sum(mypsf.amp))
        print('pix sum:', pix.sum())

        print('Optimize source:', src)
        print('Optimized PSF:', mypsf)

        mod = tractor.getModelImage(0)
        print('Mod sum:', mod.sum())
        _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue())
        plt.suptitle('W%i psf3 (opt): %g = %s, resid %.3f' %
                     (band, np.sum(mypsf.amp), ','.join(
                         ['%.3f' % a for a in mypsf.amp]), np.sum(pix - mod)))
        ps.savefig()

        # Write before normalizing!
        T = fits_table()
        T.amp = mypsf.mog.amp
        T.mean = mypsf.mog.mean
        T.var = mypsf.mog.var
        T.writeto('psf3-w%i.fits' % band)
        T.writeto('psf3.fits', append=(band != 1))

        mypsf.weights.setParams(
            np.array(mypsf.weights.getParams()) /
            sum(mypsf.weights.getParams()))
        print('Normalized PSF weights:', mypsf)
        mod = tractor.getModelImage(0)
        print('Mod sum:', mod.sum())
        _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue())
        plt.suptitle('W%i psf3 (opt): %g = %s, resid %.3f' %
                     (band, np.sum(mypsf.amp), ','.join(
                         ['%.3f' % a for a in mypsf.amp]), np.sum(pix - mod)))
        ps.savefig()

        class MyGaussianPSF(NCircularGaussianPSF):
            '''
            A PSF model with strictly positive weights that sum to
            unity.
            '''
            def __init__(self, sigmas, weights):
                ww = np.array(weights)
                ww = np.log(ww)
                super(MyGaussianPSF, self).__init__(sigmas, ww)

            @staticmethod
            def getNamedParams():
                return dict(sigmas=0, logweights=1)

            def __str__(self):
                return ('MyGaussianPSF: sigmas [ ' +
                        ', '.join(['%.3f' % s for s in self.mysigmas]) +
                        ' ], weights [ ' +
                        ', '.join(['%.3f' % w for w in self.myweights]) + ' ]')

            @property
            def myweights(self):
                ww = np.exp(self.logweights.getAllParams())
                ww /= ww.sum()
                return ww

            @property
            def weights(self):
                ww = np.exp(self.logweights.getParams())
                wsum = np.sum(np.exp(self.logweights.getAllParams()))
                return ww / wsum

        if band == 4:
            # HACK
            mypsf = MyGaussianPSF([1.7, 6.4, 15.0], [0.333, 0.666, 0.1])
        else:
            mypsf = MyGaussianPSF(gpsf.sigmas, gpsf.amp)
        mypsf.radius = sh / 2

        tim.psf = mypsf

        print('Optimizing Params:')
        tractor.printThawedParams()

        for i in range(200):
            #dlnp,X,alpha = tractor.optimize(damp=0.1)
            dlnp, X, alpha = tractor.optimize(damp=1)
            print(i, 'dlnp %.3g' % dlnp, 'psf', mypsf)
            if dlnp < 1e-6:
                break

        mypsf.sigmas.stepsize = len(mypsf.sigmas) * [1e-6]
        mypsf.logweights.stepsize = len(mypsf.sigmas) * [1e-6]

        for i in range(200):
            #dlnp,X,alpha = tractor.optimize(damp=0.1)
            dlnp, X, alpha = tractor.optimize(damp=1)
            print(i, 'dlnp %.3g' % dlnp, 'psf', mypsf)
            if dlnp < 1e-6:
                break

        print('PSF amps:', np.sum(mypsf.amp))
        print('pix sum:', pix.sum())
        print('Optimize source:', src)
        print('Optimized PSF:', mypsf)

        mod = tractor.getModelImage(0)
        print('Mod sum:', mod.sum())
        _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue())
        plt.suptitle('W%i psf3 (opt2): %g = %s, resid %.3f' %
                     (band, np.sum(mypsf.amp), ','.join(
                         ['%.3f' % a for a in mypsf.amp]), np.sum(pix - mod)))
        ps.savefig()

        # Write
        mog = mypsf.getMixtureOfGaussians()
        T = fits_table()
        T.amp = mog.amp
        T.mean = mog.mean
        T.var = mog.var
        T.writeto('psf4.fits', append=(band != 1))
Ejemplo n.º 4
0
    #_meisner_psf_models()

    sys.exit(0)

    psf2 = GaussianMixturePSF(mypsf.mog.amp[:2] / mypsf.mog.amp[:2].sum(),
                              mypsf.mog.mean[:2, :], mypsf.mog.var[:2, :, :])
    psf2.radius = sh / 2
    tim.psf = psf2
    mod = tractor.getModelImage(0)
    _plot_psf(lpix, mod, psf2, flux=src.brightness.getValue())
    plt.suptitle('psf2 (init)')
    ps.savefig()

    src.freezeAllBut('brightness')
    tractor.freezeParam('catalog')
    for i in range(100):

        print('Optimizing PSF:')
        tractor.printThawedParams()

        dlnp, X, alpha = tractor.optimize(damp=1.)

        tractor.freezeParam('images')
        tractor.thawParam('catalog')

        print('Optimizing flux:')
        tractor.printThawedParams()

        tractor.optimize_forced_photometry()
        tractor.thawParam('images')
Ejemplo n.º 5
0
def _meisner_psf_models():
    '''
    Fit and compare PSF models for Meisner's average WISE PSF stamps.

    For each WISE band W1..W4 this:

    1. reads the stacked PSF stamp (extension band-1 of
       'wise-psf-avg-pix-bright.fits') and an existing Gaussian-mixture
       fit (HDU `band` of 'wise-psf-avg.fits'), normalizes the stamp to
       unit sum and plots it against the rendered mixture;
    2. for W4 only, resamples the stamp with _lanczos_subsample(..., 2)
       and plots again;
    3. refits a GaussianMixturePSF to the stamp with fromStamp() and plots;
    4. fits an NCircularGaussianPSF (concentric Gaussians initialized from
       the ORIGINAL mixture, not the fromStamp fit) with the Tractor,
       writes its unnormalized mixture to 'psf3-w<band>.fits' and
       'psf3.fits', normalizes its weights and plots;
    5. refits with MyGaussianPSF (log-parameterized weights, normalized to
       sum to one) and writes the resulting mixture to 'psf4.fits'.

    'psf3.fits' and 'psf4.fits' are created for band 1 and appended to for
    later bands.

    Side effects only; returns None. Relies on names defined elsewhere:
    the global `ps` (its .savefig() is called after each plot) and
    `_plot_psf`, which presumably reads the module-level `plotslice` set
    below -- confirm.
    '''
    # Module-level plot window; assigned below, presumably consumed by _plot_psf.
    global plotslice
    # Meisner's PSF models
    #for band in [4]:
    # WISE bands W1..W4; FITS extension / HDU indices below are derived from band.
    for band in [1,2,3,4]:
        print
        print 'W%i' % band
        print
        
        #pix = fitsio.read('wise-psf-avg-pix.fits', ext=band-1)
        pix = fitsio.read('wise-psf-avg-pix-bright.fits', ext=band-1)
        fit = fits_table('wise-psf-avg.fits', hdu=band)
        #fit = fits_table('psf-allwise-con3.fits', hdu=band)
    
        # Multiplies the mixture means (and variances by scale**2); 1. = unscaled.
        scale = 1.
    
        print 'Pix shape', pix.shape
        h,w = pix.shape

        # Flux-weighted centroid of the stamp; printed for diagnostics only.
        xx,yy = np.meshgrid(np.arange(w), np.arange(h))
        cx,cy = np.sum(xx*pix)/np.sum(pix), np.sum(yy*pix)/np.sum(pix)
        print 'Centroid:', cx,cy

        # Cutout half-size. With S = h/2 (integer division under Python 2)
        # the slice starts at h/2-S == 0, so it spans the whole stamp.
        #S = 100
        S = h/2
        slc = slice(h/2-S, h/2+S+1), slice(w/2-S, w/2+S+1)

        # Plot window: drop (S - plotss) pixels from each edge, leaving a
        # central region roughly 2*plotss pixels wide.
        plotss = 30
        plotslice = slice(S-plotss, -(S-plotss)), slice(S-plotss, -(S-plotss))
        
        # NOTE(review): opix is an alias of pix, not a copy. The in-place
        # division below therefore also normalizes opix, so the W4 Lanczos
        # step further down resamples the *normalized* stamp.
        opix = pix
        print 'Original pixels sum:', opix.sum()
        # Normalize the stamp to unit sum (in place), then cut out the slice.
        pix /= pix.sum()
        pix = pix[slc]
        print 'Sliced pix sum', pix.sum()

        # Original mixture-of-Gaussians fit for this band, rendered as a
        # patch of radius h/2 and cut with the same slice as the pixels.
        psf = GaussianMixturePSF(fit.amp, fit.mean * scale, fit.var * scale**2)
        psfmodel = psf.getPointSourcePatch(0., 0., radius=h/2)
        mod = psfmodel.patch
        print 'Orig mod sum', mod.sum()
        mod = mod[slc]
        print 'Sliced mod sum', mod.sum()
        print 'Amps:', np.sum(psf.mog.amp)
        
        _plot_psf(pix, mod, psf)
        plt.suptitle('W%i: Orig' % band)
        ps.savefig()
                
        # Lanczos sub-sample
        # W4 only: resample the (already normalized, see note above) stamp
        # with _lanczos_subsample(opix, 2). h,w are updated to the resampled
        # shape, but S is kept, so the cutout has the same half-size as
        # before, centered on the new midpoint.
        if band == 4:
            lpix = _lanczos_subsample(opix, 2)
            pix = lpix
            h,w = pix.shape

            #S = 140
            slc = slice(h/2-S, h/2+S+1), slice(w/2-S, w/2+S+1)
            print 'Resampled pix sum', pix.sum()
            pix = pix[slc]
            print 'sliced pix sum', pix.sum()

            # Re-render the original mixture at the new stamp size.
            psf = GaussianMixturePSF(fit.amp, fit.mean * scale, fit.var * scale**2)
            psfmodel = psf.getPointSourcePatch(0., 0., radius=h/2)
            mod = psfmodel.patch
            print 'Scaled mod sum', mod.sum()
            mod = mod[slc]
            print 'Sliced mod sum', mod.sum()

            _plot_psf(pix, mod, psf)
            plt.suptitle('W%i: Scaled' % band)
            ps.savefig()

            # Same expression as above with the unchanged S, so the plot
            # window is identical.
            plotslice = slice(S-plotss,-(S-plotss)),slice(S-plotss,-(S-plotss))

        # Refit a Gaussian mixture directly to the stamp, starting from the
        # original fit parameters (P0), and report the residual sum.
        psfx = GaussianMixturePSF.fromStamp(pix, P0=(fit.amp, fit.mean*scale, fit.var*scale**2))
        psfmodel = psfx.getPointSourcePatch(0., 0., radius=h/2)
        mod = psfmodel.patch
        print 'From stamp: mod sum', mod.sum()
        mod = mod[slc]
        print 'Sliced mod sum:', mod.sum()
        _plot_psf(pix, mod, psfx)
        plt.suptitle('W%i: fromStamp: %g = %s, res %.3f' %
                     (band, np.sum(psfx.mog.amp), ','.join(['%.3f'%a for a in psfx.mog.amp]),
                      np.sum(pix - mod)))
        ps.savefig()
        
        print 'Stamp-Fit PSF params:', psfx
    
        # Disabled alternative: a GaussianMixturePSF whose log-prior rejects
        # negative amplitudes and non-positive-definite covariances.
        # class MyGaussianMixturePSF(GaussianMixturePSF):
        #     def getLogPrior(self):
        #         if np.any(self.mog.amp < 0.):
        #             return -np.inf
        #         for k in range(self.mog.K):
        #             if np.linalg.det(self.mog.var[k]) <= 0:
        #                 return -np.inf
        #         return 0
        # 
        #     @property
        #     def amp(self):
        #         return self.mog.amp
        #     
        # mypsf = MyGaussianMixturePSF(psfx.mog.amp, psfx.mog.mean, psfx.mog.var)
        # mypsf.radius = sh/2


        sh,sw = pix.shape

        # Initialize from original fit params
        # (this discards the fromStamp result computed above).
        psfx = psf
        
        # Try concentric gaussian PSF
        # Each component is approximated as circular with
        # sigma = (|var_xx * var_yy|)**(1/4), i.e. the square root of the
        # geometric-mean variance; off-diagonal terms are ignored.
        sigmas = []
        for k in range(psfx.mog.K):
            v = psfx.mog.var[k,:,:]
            sigmas.append(np.sqrt(np.sqrt(np.abs(v[0,0] * v[1,1]))))
        print 'Initializing concentric Gaussian PSF with sigmas', sigmas
        gpsf = NCircularGaussianPSF(sigmas, psfx.mog.amp)
        # Radius set to half the stamp height.
        gpsf.radius = sh/2
        mypsf = gpsf
        
        # Constant inverse variance, so every pixel gets the same weight
        # in the fit.
        tim = Image(data=pix, invvar=1e6 * np.ones_like(pix), psf=mypsf)
        # NOTE(review): presumably a floor on model pixel values -- confirm
        # in tractor's Image.
        tim.modelMinval = 1e-16
        
        # xx,yy = np.meshgrid(np.arange(sw), np.arange(sh))
        # cx,cy = np.sum(xx*pix)/np.sum(pix), np.sum(yy*pix)/np.sum(pix)
        # print 'Centroid:', cx,cy
        # print 'Pix midpoint:', sw/2, sh/2
        # Start the point source at the stamp midpoint with unit flux.
        cx,cy = sw/2, sh/2
        src = PointSource(PixPos(cx, cy), Flux(1.0))
        
        tractor = Tractor([tim], [src])
        
        # Stage 1: fit the image's PSF parameters and the source position;
        # the source flux stays frozen at 1.0.
        tim.freezeAllBut('psf')
        #tractor.freezeParam('catalog')
        src.freezeAllBut('pos')
        
        # src.freezeAllBut('brightness')
        # tractor.freezeParam('images')
        # tractor.optimize_forced_photometry()
        # tractor.thawParam('images')
        # print 'Source flux after forced-photom fit:', src.getBrightness()
    
        print 'Optimizing Params:'
        tractor.printThawedParams()
    
        # Up to 200 damped steps; stop once the log-prob gain drops below 1e-6.
        for i in range(200):
            #dlnp,X,alpha = tractor.optimize(damp=0.1)
            dlnp,X,alpha = tractor.optimize(damp=1)
            print i,'dlnp %.3g' % dlnp, 'psf', gpsf
            if dlnp < 1e-6:
                break

        # Stage 2: freeze the source entirely and refit only the PSF, with
        # parameter step sizes of 1e-6 (presumably used for numerical
        # derivatives -- confirm in tractor). Both stepsize lists are sized
        # by the number of sigmas, which equals the number of weights.
        tractor.freezeParam('catalog')
        gpsf.sigmas.stepsize = len(gpsf.sigmas) * [1e-6]
        gpsf.weights.stepsize = len(gpsf.sigmas) * [1e-6]
        
        for i in range(200):
            #dlnp,X,alpha = tractor.optimize(damp=0.1)
            dlnp,X,alpha = tractor.optimize(damp=1)
            print i,'dlnp %.3g' % dlnp, 'psf', gpsf
            if dlnp < 1e-6:
                break

            
        print 'PSF3(opt): flux', src.brightness
        print 'PSF amps:', np.sum(mypsf.amp)
        print 'PSF amps * Source brightness:', src.brightness.getValue() * np.sum(mypsf.amp)
        print 'pix sum:', pix.sum()

        print 'Optimize source:', src
        print 'Optimized PSF:', mypsf
        
        mod = tractor.getModelImage(0)
        print 'Mod sum:', mod.sum()
        _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue())
        plt.suptitle('W%i psf3 (opt): %g = %s, resid %.3f' %
                     (band, np.sum(mypsf.amp), ','.join(['%.3f'%a for a in mypsf.amp]), np.sum(pix-mod)))
        ps.savefig()

        # Write before normalizing!
        # Save the unnormalized concentric-Gaussian mixture: one file per
        # band, plus psf3.fits (created for band 1, appended for later bands).
        T = fits_table()
        T.amp  = mypsf.mog.amp
        T.mean = mypsf.mog.mean
        T.var  = mypsf.mog.var
        T.writeto('psf3-w%i.fits' % band)
        T.writeto('psf3.fits', append=(band != 1))
        
        # Rescale the thawed weights so they sum to one, then re-plot.
        mypsf.weights.setParams(np.array(mypsf.weights.getParams()) /
                                sum(mypsf.weights.getParams()))
        print 'Normalized PSF weights:', mypsf
        mod = tractor.getModelImage(0)
        print 'Mod sum:', mod.sum()
        _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue())
        plt.suptitle('W%i psf3 (opt): %g = %s, resid %.3f' %
                     (band, np.sum(mypsf.amp), ','.join(['%.3f'%a for a in mypsf.amp]), np.sum(pix-mod)))
        ps.savefig()


        # NOTE(review): this class is redefined on every band iteration; it
        # does not depend on loop state and could be hoisted out of the loop.
        class MyGaussianPSF(NCircularGaussianPSF):
            '''
            A concentric-Gaussian PSF whose weights are stored as
            logarithms (parameter "logweights"), so the effective weights
            exp(logweights) are strictly positive; the exposed weights are
            normalized to sum to unity.

            The constructor takes linear weights and stores np.log of them,
            so the weights passed in must be > 0 (np.log yields -inf/nan
            otherwise).
            '''
            def __init__(self, sigmas, weights):
                # Convert linear weights to log-weights for the parent class.
                ww = np.array(weights)
                ww = np.log(ww)
                super(MyGaussianPSF, self).__init__(sigmas, ww)

            # Renames parameter slot 1 from the parent's weights to logweights.
            @staticmethod
            def getNamedParams():
                return dict(sigmas=0, logweights=1)
                
            def __str__(self):
                return ('MyGaussianPSF: sigmas [ ' +
                        ', '.join(['%.3f'%s for s in self.mysigmas]) +
                        ' ], weights [ ' +
                        ', '.join(['%.3f'%w for w in self.myweights]) +
                        ' ]')

            # All weights (frozen or not), exponentiated and normalized to sum to 1.
            @property
            def myweights(self):
                ww = np.exp(self.logweights.getAllParams())
                ww /= ww.sum()
                return ww

            # Thawed weights only, normalized by the sum over ALL weights.
            @property
            def weights(self):
                ww = np.exp(self.logweights.getParams())
                wsum = np.sum(np.exp(self.logweights.getAllParams()))
                return ww / wsum
            

        if band == 4:
            # HACK
            # Hand-picked initial sigmas/weights for W4. The weights need not
            # sum to one; the class normalizes them when reporting.
            mypsf = MyGaussianPSF([1.7, 6.4, 15.0], [0.333, 0.666, 0.1])
        else:
            # Start from the fitted concentric Gaussians.
            mypsf = MyGaussianPSF(gpsf.sigmas, gpsf.amp)
        mypsf.radius = sh/2

        # Swap the new PSF into the existing image. The catalog is still
        # frozen from stage 2 (nothing thaws it), so only the image's PSF
        # is fit below.
        tim.psf = mypsf
        
        print 'Optimizing Params:'
        tractor.printThawedParams()
    
        for i in range(200):
            #dlnp,X,alpha = tractor.optimize(damp=0.1)
            dlnp,X,alpha = tractor.optimize(damp=1)
            print i,'dlnp %.3g' % dlnp, 'psf', mypsf
            if dlnp < 1e-6:
                break

        # Refine again with 1e-6 step sizes on sigmas and log-weights.
        mypsf.sigmas.stepsize = len(mypsf.sigmas) * [1e-6]
        mypsf.logweights.stepsize = len(mypsf.sigmas) * [1e-6]

        for i in range(200):
            #dlnp,X,alpha = tractor.optimize(damp=0.1)
            dlnp,X,alpha = tractor.optimize(damp=1)
            print i,'dlnp %.3g' % dlnp, 'psf', mypsf
            if dlnp < 1e-6:
                break
        
        print 'PSF amps:', np.sum(mypsf.amp)
        print 'pix sum:', pix.sum()
        print 'Optimize source:', src
        print 'Optimized PSF:', mypsf
        
        mod = tractor.getModelImage(0)
        print 'Mod sum:', mod.sum()
        _plot_psf(pix, mod, mypsf, flux=src.brightness.getValue())
        plt.suptitle('W%i psf3 (opt2): %g = %s, resid %.3f' %
                     (band, np.sum(mypsf.amp), ','.join(['%.3f'%a for a in mypsf.amp]), np.sum(pix-mod)))
        ps.savefig()

        # Write
        # Store the final mixture of Gaussians in psf4.fits (created for
        # band 1, appended for later bands).
        mog = mypsf.getMixtureOfGaussians()
        T = fits_table()
        T.amp  = mog.amp
        T.mean = mog.mean
        T.var  = mog.var
        T.writeto('psf4.fits', append=(band != 1))
Ejemplo n.º 6
0
    
    #_meisner_psf_models()
        
    sys.exit(0)
        
    psf2 = GaussianMixturePSF(mypsf.mog.amp[:2]/mypsf.mog.amp[:2].sum(),
                              mypsf.mog.mean[:2,:], mypsf.mog.var[:2,:,:])
    psf2.radius = sh/2
    tim.psf = psf2
    mod = tractor.getModelImage(0)
    _plot_psf(lpix, mod, psf2, flux=src.brightness.getValue())
    plt.suptitle('psf2 (init)')
    ps.savefig()

    src.freezeAllBut('brightness')
    tractor.freezeParam('catalog')
    for i in range(100):

        print 'Optimizing PSF:'
        tractor.printThawedParams()
        
        dlnp,X,alpha = tractor.optimize(damp=1.)

        tractor.freezeParam('images')
        tractor.thawParam('catalog')

        print 'Optimizing flux:'
        tractor.printThawedParams()

        tractor.optimize_forced_photometry()
        tractor.thawParam('images')