def PatchFitter(
    all_data,
    all_dq,
    ini_psf,
    patch_shape,
    id_start,
    background="linear",
    sequence=("shifts", "evaluate", "psf"),
    tol=1.0e-10,
    eps=1.0e-4,
    gamma=1.0e0,
    ini_shifts=None,
    Nthreads=20,
    floor=None,
    plotfilebase=None,
    gain=None,
    maxiter=np.inf,
    dumpfilebase=None,
    trim_frac=0.005,
    min_data_frac=0.75,
    core_size=5,
    lower=1.0e-5,
    k=1,
    plot=False,
    clip_parms=None,
    final_clip=None,
    q=1.0,
    clip_shifts=False,
    h=1.4901161193847656e-08,
    Nplot=20,
    small=1.0e-5,
    Nsearch=256,
    search_rate=0.05,
    search_scale=1e-2,
    shift_test_thresh=0.475,
    min_frac=0.5,
    max_nll=1.0e10,
    Nburn=10,
):
    """
    Patch fitting routines for psf inference.

    Iteratively refines a PSF model against a stack of image patches.  Each
    outer iteration runs the steps named in ``sequence``, in order:

    ``"shifts"``
        Run ``update_shifts`` on the core pixels (``parms.core_ind``) starting
        from zero shifts.  If ``trim_frac`` is set, drop the ``trim_frac``
        fraction of patches with the largest shift NLL (never going below
        ``min_data_frac`` of the patches in use this iteration), then run
        ``update_shifts`` again seeded with the first-pass shifts.
    ``"evaluate"``
        Call ``evaluate`` to obtain ``nll``, ``fit_parms`` and ``masks``.
    ``"psf"``
        Build a new PSF with ``update_psf_linear`` (``k == 1``) or
        ``update_psf`` from the latest evaluation; this also yields the global
        cost used for the convergence test.
    ``"plot_data"``
        Call ``evaluate`` in plotting mode on the first ``Nplot`` patches.

    ``"psf"`` consumes the outputs of ``"evaluate"`` and produces the cost,
    so ``sequence`` must contain ``"psf"`` with an ``"evaluate"`` before it.
    (The former default ``["shifts", "psf"]`` always failed at the PSF step
    because nothing had produced ``fit_parms``/``masks``.)

    Parameters (selection)
    ----------------------
    all_data, all_dq : 2-D arrays
        One flattened patch per row; ``all_data.shape[1]`` must equal
        ``patch_shape[0] * patch_shape[1]``.
    ini_psf : array
        Starting PSF.  The working copy is normalised to a peak of 1.
    patch_shape : (int, int)
        Odd patch dimensions.
    sequence : sequence of str
        Steps to run each iteration (see above).
    maxiter : number
        Return once ``parms.iter`` reaches this value.
    dumpfilebase : str or None
        If set, per-iteration masks, shifts, shift NLLs and the PSF are
        written to files with this prefix.
    trim_frac : float or None
        Fraction of worst-fitting patches to drop in the ``"shifts"`` step.
    final_clip : list, optional
        Clip parameters used when ``clip_parms`` has no entry for the current
        iteration; defaults to ``[1, 3.0]``.
    Nburn : int or None
        If a positive int, iteration i (1-based) of the first ``Nburn`` uses
        only the first ``i * ceil(Npatches / Nburn)`` patches and skips the
        convergence test; later iterations use all patches.  ``None`` or 0
        disables burn-in.
    lower, clip_shifts :
        Accepted for compatibility; not used in this function.

    Returns
    -------
    current_psf : array
        The PSF after ``maxiter`` iterations, or once the relative cost change
        falls below ``tol``.

    Raises
    ------
    AssertionError
        On invalid ``background``, ``patch_shape``, ``core_size``,
        ``sequence`` or ``trim_frac``.
    """
    assert background in [None, "constant", "linear"]
    assert np.mod(patch_shape[0], 2) == 1, "Patch shape[0] must be odd"
    assert np.mod(patch_shape[1], 2) == 1, "Patch shape[1] must be odd"
    assert np.mod(core_size, 2) == 1, "Core size must be odd"
    assert (patch_shape[0] * patch_shape[1]) == all_data.shape[1], "Patch shape does not match data shape"
    kinds = ["shifts", "psf", "evaluate", "plot_data"]
    seq = list(sequence)
    for kind in seq:
        assert kind in kinds, "sequence not allowed"
    # The cost printed and tested at the end of every iteration is only
    # produced by the "psf" step, which in turn needs "evaluate" outputs.
    assert "psf" in seq, "sequence must include 'psf' (it produces the cost)"
    assert "evaluate" in seq[: seq.index("psf")], "sequence must run 'evaluate' before the first 'psf'"

    if final_clip is None:
        final_clip = [1, 3.0]

    def _report_nll(label, nll):
        """Print total/min/median/max of the shift NLL array."""
        stats = (
            ("total", nll.sum()),
            ("min", nll.min()),
            ("median", np.median(nll)),
            ("max", nll.max()),
        )
        for stat, value in stats:
            print("%s done nll, %s: %s" % (label, stat, value))

    # set parameters
    parms = InferenceParms(
        h, k, q, eps, tol, gain, plot, floor, gamma, all_data.shape[0],
        Nplot, small, Nsearch, id_start, max_nll, min_frac, Nthreads,
        core_size, background, None, patch_shape, search_rate, plotfilebase,
        search_scale, ini_psf.shape, shift_test_thresh, ini_psf,
    )
    # Snapshot of the row ids so they can be re-aligned whenever ``data`` is
    # re-sliced from ``all_data`` during burn-in (trimming shrinks
    # parms.data_ids in step with ``data``).
    # NOTE(review): assumes InferenceParms sets data_ids with one entry per
    # row of all_data -- confirm.
    all_ids = getattr(parms, "data_ids", None)

    # initialize
    current_psf = ini_psf / ini_psf.max()
    if ini_shifts is not None:
        shifts = ini_shifts

    if Nburn is not None and Nburn > 0:
        # No convergence test while the working set is still growing.
        current_cost = None
        burn_iter = 0
        # Integer chunk size: float slice bounds raise TypeError.
        burn_size = int(np.ceil(1.0 * all_data.shape[0] / Nburn))
    else:
        data = all_data
        dq = all_dq
        current_cost = np.inf
        burn_iter = None

    # run
    t0 = time.time()
    while True:
        t = time.time()

        # assign data used during burn-in
        if burn_iter is not None:
            burn_iter += 1
            if burn_iter == Nburn:
                n_use = all_data.shape[0]
                current_cost = np.inf
            else:
                # Past Nburn this bound exceeds Npatches, i.e. all patches.
                n_use = burn_iter * burn_size
            data = all_data[:n_use]
            dq = all_dq[:n_use]
            if all_ids is not None:
                parms.data_ids = all_ids[:n_use]

        # minimum number of patches, mask initialization
        Nmin = int(np.ceil(min_data_frac * data.shape[0]))
        mask = np.arange(data.shape[0])
        parms.Ndata = data.shape[0]

        # run an iteration
        for kind in seq:
            if parms.iter >= maxiter:
                return current_psf

            if kind == "shifts":
                parms.clip_parms = None
                shifts, nll = update_shifts(
                    data[:, parms.core_ind], dq[:, parms.core_ind], current_psf,
                    np.zeros((data.shape[0], 2)), parms,
                )
                ref_shifts = shifts.copy()
                _report_nll("Shift step 1", nll)

                if trim_frac is not None and mask.size > Nmin:
                    assert trim_frac > 0.0, "trim_frac must be positive or None"
                    Ntrim = min(int(np.ceil(mask.size * trim_frac)), mask.size - Nmin)
                    # keep the lowest-NLL patches, preserving original order
                    keep = np.sort(np.argsort(nll)[: mask.size - Ntrim])
                    dq = dq[keep]
                    data = data[keep]
                    mask = mask[keep]
                    ref_shifts = ref_shifts[keep]
                    parms.Ndata = data.shape[0]
                    parms.data_ids = parms.data_ids[keep]
                    # re-run shifts, seeded with the first-pass estimates
                    shifts, nll = update_shifts(
                        data[:, parms.core_ind], dq[:, parms.core_ind], current_psf,
                        ref_shifts, parms,
                    )
                _report_nll("Shift step 2", nll)

                if dumpfilebase is not None:
                    np.savetxt(dumpfilebase + "_mask_%d.dat" % parms.iter, mask, fmt="%d")
                    np.savetxt(dumpfilebase + "_shifts_%d.dat" % parms.iter, shifts)
                    np.savetxt(dumpfilebase + "_shift_nll_%d.dat" % parms.iter, nll)

            elif kind == "evaluate":
                parms.return_parms = True
                set_clip_parameters(clip_parms, parms, final_clip)
                nll, fit_parms, masks = evaluate((data, dq, shifts, current_psf, parms, False))
                parms.return_parms = False

            elif kind == "psf":
                set_clip_parameters(clip_parms, parms, final_clip)
                updater = update_psf_linear if parms.k == 1 else update_psf
                new_psf, cost = updater(current_psf, data, dq, shifts, nll, fit_parms, masks, parms)
                if new_psf is not None:
                    psf_plot(
                        ini_psf,
                        np.maximum(parms.small, current_psf),
                        np.maximum(parms.small, new_psf),
                        parms.small,
                        parms,
                    )
                    current_psf = new_psf
                if dumpfilebase is not None:
                    hdu = pf.PrimaryHDU(current_psf / current_psf.max())
                    # NOTE(review): newer astropy spells this `overwrite=True`.
                    hdu.writeto(dumpfilebase + "_psf_%d.fits" % parms.iter, clobber=True)

            elif kind == "plot_data":
                if clip_parms is None:
                    parms.clip_parms = [1, np.inf]
                else:
                    try:
                        parms.clip_parms = clip_parms[parms.iter]
                    except (IndexError, KeyError, TypeError):
                        # no per-iteration entry: fall back to the final clip
                        parms.clip_parms = final_clip
                parms.plot_data = True
                n_plot = parms.Nplot
                evaluate((data[:n_plot], dq[:n_plot], shifts[:n_plot], current_psf, parms, False))
                parms.plot_data = False

        shown_cost = -99.0 if current_cost is None else current_cost
        print("\n\nUsing %d patches" % data.shape[0])
        print("Current cost: %0.2e, new cost %0.2e" % (shown_cost, cost))
        dt = (time.time() - t) / 3600.0
        dt0 = (time.time() - t0) / 3600.0
        print("Iter %d took %0.2e hrs, total %0.2e hrs\n\n" % (parms.iter, dt, dt0))

        # convergence test (skipped while burn-in leaves current_cost None)
        if current_cost is not None:
            if np.abs((current_cost - cost) / cost) < tol:
                print("Converged at cost %s" % cost)
                return current_psf
            current_cost = cost
        parms.iter += 1
def learn_psf(
    data,
    dq,
    initial_psf,
    clip_parms,
    noise_parms,
    plotfilebase,
    kernel_parms,
    patch_shape,
    knn=32,
    min_patch_frac=0.75,
    core_size=5,
    nll_tol=1.0e-5,
    k=3,
    q=1.0,
    Nplot=20,
    plot=False,
    flann_precision=0.99,
    final_clip=None,
    background="constant",
    Nthreads=20,
    max_iter=20,
    max_nll=1.0e10,
    shift_test_thresh=0.475,
):
    """
    Inference routine for learning a psf model via scaled data and a kernel basis.

    Experimental driver.  It does the following:

    1. Runs ``update_shifts`` and ``fit_patches`` against ``initial_psf``.
    2. Drops patches whose summed fit NLL is not below ``max_nll``.
    3. Builds a PSF with ``psf_builder``.
    4. Repeats the shift / fit / reject / rebuild cycle 12 times.

    On each pass of step 4 it saves a before/after image of the PSF to
    ``../../plots/foo.png``.  The routine ends with ``assert 0``, so it
    never returns normally unless assertions are disabled.

    Parameters (selection)
    ----------------------
    data, dq : 2-D arrays
        One flattened patch per row; ``data.shape[1]`` must equal
        ``patch_shape[0] * patch_shape[1]``.
    initial_psf : array
        Starting PSF.  A peak-normalised copy is used; the caller's array is
        not modified here.
    final_clip : list, optional
        Passed to ``set_clip_parameters``; defaults to ``[1, 3.0]``.
    max_iter : int
        Accepted but not used by this routine (the loop count is fixed at 12).

    Raises
    ------
    AssertionError
        On invalid ``background``, ``patch_shape`` or ``core_size``, and
        unconditionally at the end (see above).
    """
    assert background in [None, "constant", "linear"]
    assert np.mod(patch_shape[0], 2) == 1, "Patch shape[0] must be odd"
    assert np.mod(patch_shape[1], 2) == 1, "Patch shape[1] must be odd"
    assert np.mod(core_size, 2) == 1, "Core size must be odd"
    assert (patch_shape[0] * patch_shape[1]) == data.shape[1], "Patch shape does not match data shape"

    if final_clip is None:
        final_clip = [1, 3.0]

    # bundle parameters to be passed to other functions
    # NOTE(review): this InferenceParms call (20 args, no id_start) differs
    # from the 27-arg call in PatchFitter; one of them is likely stale.
    parms = InferenceParms(
        k, q, knn, plot, data.shape[0], Nplot, nll_tol, max_nll, Nthreads,
        core_size, background, clip_parms, patch_shape, noise_parms,
        kernel_parms, plotfilebase, min_patch_frac, flann_precision,
        initial_psf.shape, shift_test_thresh,
    )

    def _drop_bad_patches(nll, data, dq, shifts):
        """Keep only patches whose summed NLL is below ``parms.max_nll``."""
        keep = nll < parms.max_nll
        # Filter the stored ids with the same mask (previously this stored
        # ``data[keep]`` *after* ``data`` had already been filtered).
        ids = getattr(parms, "data_ids", None)
        if ids is not None:
            parms.data_ids = ids[keep]
        return data[keep], dq[keep], shifts[keep]

    # initialize
    print("Initialized with %d patches\n" % data.shape[0])
    # Normalised copy; avoids dividing the caller's array in place.
    psf_model = initial_psf / initial_psf.max()

    # Run through data, reject patches that are bad/crowded.
    # NOTE(review): update_shifts takes 4 args here but 5 (with reference
    # shifts) in PatchFitter -- confirm which signature is current.
    parms.clip_parms = None
    shifts, nll = update_shifts(data[:, parms.core_ind], dq[:, parms.core_ind], psf_model, parms)
    set_clip_parameters(clip_parms, parms, final_clip)
    fit_parms, fit_vars, nll, masks = fit_patches(data, dq, shifts, psf_model, parms)
    nll = np.sum(nll, axis=1)
    data, dq, shifts = _drop_bad_patches(nll, data, dq, shifts)
    print("%d patches are ok under the initial model\n" % data.shape[0])
    print("Initial NLL is %0.6e" % np.sum(nll))

    # Build a new psf
    # NOTE(review): masks / fit_parms / fit_vars still describe every patch
    # given to fit_patches, while data / shifts were filtered above -- confirm
    # psf_builder tolerates the length mismatch or filter them too.
    psf_model = psf_builder(data, masks, shifts, fit_parms, fit_vars, parms)

    for step in range(12):
        parms.clip_parms = None
        shifts, nll = update_shifts(data[:, parms.core_ind], dq[:, parms.core_ind], psf_model, parms)
        print("%d %s %s %s" % (step, nll.sum(), shifts[0], shifts[-1]))
        set_clip_parameters(clip_parms, parms, final_clip)
        fit_parms, fit_vars, nll, masks = fit_patches(data, dq, shifts, psf_model, parms)
        nll = np.sum(nll, axis=1)
        data, dq, shifts = _drop_bad_patches(nll, data, dq, shifts)
        print("New NLL is %0.6e" % np.sum(nll))

        # left panel: psf before this rebuild; right panel: after
        fig = pl.figure(figsize=(16, 8))
        pl.subplot(121)
        pl.imshow(np.abs(psf_model), interpolation="nearest", origin="lower", norm=LogNorm(vmin=1.0e-6, vmax=1.0))
        psf_model = psf_builder(data, masks, shifts, fit_parms, fit_vars, parms)
        pl.colorbar(shrink=0.7)
        pl.subplot(122)
        pl.imshow(np.abs(psf_model), interpolation="nearest", origin="lower", norm=LogNorm(vmin=1.0e-6, vmax=1.0))
        pl.colorbar(shrink=0.7)
        # central pixel (was hard-coded [50, 50], the centre of a 101x101 psf)
        print(psf_model[psf_model.shape[0] // 2, psf_model.shape[1] // 2])
        print(psf_model.max())
        fig.savefig("../../plots/foo.png")
        # release the figure; a new one is opened on every pass
        pl.close(fig)

    # Deliberate stop: this routine is unfinished.
    assert 0