def get_isomgchain(lmax_sky, datshape, tol=1e-5, iter_max=np.inf, **kwargs):
    """Build a one-stage chain description that uses only the "diag_cl" entry.

    * lmax_sky = value stored in the third slot of the stage.
    * datshape = indexable shape; entries 0 and 1 must be equal, and entry 0
                 is stored in the fourth slot of the stage.
    * tol      = value stored in the sixth slot of the stage.
    * iter_max = value stored in the fifth slot of the stage.
    * kwargs   = accepted and ignored.

    Returns a list with a single stage of the form
    [0, ["diag_cl"], lmax_sky, datshape[0], iter_max, tol,
     cd_solve.tr_cg, cd_solve.cache_mem()].
    """
    assert datshape[0] == datshape[1], datshape
    side = datshape[0]
    only_stage = [
        0,
        ["diag_cl"],
        lmax_sky,
        side,
        iter_max,
        tol,
        cd_solve.tr_cg,
        cd_solve.cache_mem(),
    ]
    return [only_stage]
def solve(self, x, tqu):
    """Run cd_solve.cd_solve for `x` with this chain's preconditioner.

    Resets the stopwatch and the iteration bookkeeping on `self`, builds a
    cd_monitors.monitor_basic stopping criterion from self.iter_max and
    self.eps_min, and calls cd_solve.cd_solve with self.pre_op as the only
    preconditioner.

    * x   = solution object handed to cd_solve and then to opfilt.calc_fini
            (presumably updated in place by the solver -- confirm in cd_solve).
    * tqu = input passed to opfilt.calc_prep to build the right-hand side `b`.

    Returns self.opfilt.calc_fini(x, self.s_inv_filt, self.n_inv_filt).
    """
    # Fresh bookkeeping for this solve.
    self.watch = util.stopwatch()
    self.iter_tot = 0
    self.prev_eps = None

    def report(iter, eps, stage=0, **kwargs):
        # self.log takes the stage first; stage defaults to 0 when not supplied.
        return self.log(stage, iter, eps, **kwargs)

    stop_rule = cd_monitors.monitor_basic(
        self.opfilt.dot_op(),
        logger=report,
        iter_max=self.iter_max,
        eps_min=self.eps_min,
    )

    # Solver inputs, evaluated in the same order as the keyword arguments below.
    rhs = self.opfilt.calc_prep(tqu, self.s_inv_filt, self.n_inv_filt)
    forward = self.opfilt.fwd_op(self.s_inv_filt, self.n_inv_filt)
    preconditioners = [self.pre_op]
    inner_product = self.opfilt.dot_op()

    cd_solve.cd_solve(
        x=x,
        b=rhs,
        fwd_op=forward,
        pre_ops=preconditioners,
        dot_op=inner_product,
        criterion=stop_rule,
        tr=cd_solve.tr_cg,
        cache=cd_solve.cache_mem(),
    )
    return self.opfilt.calc_fini(x, self.s_inv_filt, self.n_inv_filt)
def __init__(self, nstages, opfilt, s_inv_filt, n_inv_filt, plogdepth=0, eps_min=1.e-6, stage_iter_max=3):
    """Initialize this multigrid chain by constructing the resolution stages.

    * nstages        = total number of stages in the chain.
    * opfilt         = module defining dot_op, fwd_op, calc_prep, calc_fini,
                       and pre_op_diag filter functions.
    * s_inv_filt     = the S^{-1} filter for the highest resolution stage in the chain.
    * n_inv_filt     = the N^{-1} filter for the highest resolution stage in the chain.
    * plogdepth      = the lowest stage at which to print convergence information.
    * eps_min        = the convergence criterion.
    * stage_iter_max = the maximum number of iterations in the substages.

    The resulting nested preconditioner is stored in self.pre_op.
    """
    self.opfilt = opfilt
    self.s_inv_filt = s_inv_filt
    self.n_inv_filt = n_inv_filt
    self.plogdepth = plogdepth
    self.eps_min = eps_min
    self.stage_iter_max = stage_iter_max
    self.iter_max = np.inf

    # Defined once, outside the loop: the class does not depend on the loop
    # variable, so there is no reason to rebuild it for every stage.
    class slog(object):
        """Forward convergence reports to the chain's log, tagged with a stage id."""

        def __init__(self, stage_id, cobj):
            # 1 * stage_id stores a product rather than the passed object
            # itself (only matters if a mutable numeric type is passed).
            self.id = 1 * stage_id
            self.cobj = cobj

        def log(self, iter, eps, **kwargs):
            self.cobj.log(self.id, iter, eps, **kwargs)

    # Innermost preconditioner: diagonal, on N^{-1} degraded by 2**nstages.
    pre_op = opfilt.pre_op_diag(s_inv_filt, n_inv_filt.degrade(2**nstages))

    # Wrap from level nstages down to level 1. Each level pairs a multigrid
    # preconditioner at degrade factor 2**level (using the previous pre_op)
    # with a diagonal preconditioner at degrade factor 2**(level - 1).
    # range() rather than xrange() so the loop also runs on Python 3.
    for level in range(nstages, 0, -1):
        multigrid = pre_op_multigrid(
            opfilt,
            s_inv_filt,
            n_inv_filt.degrade(2**level),
            [pre_op],
            slog(level, self).log,
            cd_solve.tr_cg,
            cd_solve.cache_mem(),
            stage_iter_max,
            0.0,
        )
        diagonal = opfilt.pre_op_diag(s_inv_filt, n_inv_filt.degrade(2**(level - 1)))
        pre_op = pre_op_split(multigrid, diagonal)

    self.pre_op = pre_op
def get_low_res_mgchain(lmax_sky=3000, lmax_split=2999, nside=2048, tol=1e-4, Nsub=20, dense_file='', **kwargs):
    """Build a two-stage chain description, both stages at resolution `nside`.

    * lmax_sky   = third slot of stage 0.
    * lmax_split = third slot of stage 1, and the split value written into the
                   stage 0 description string.
    * nside      = fourth slot of both stages.
    * tol        = sixth slot of stage 0 (stage 1 uses 0.).
    * Nsub       = fifth slot of stage 1 (stage 0 uses np.inf).
    * dense_file = accepted but not used here.
    * kwargs     = accepted and ignored.

    Returns [stage 1, stage 0], each of the form
    [id, [description], lmax, nside, iter_max, tol, cd_solve.tr_cg, cd_solve.cache_mem()].
    """
    # Each stage gets its own cd_solve.cache_mem() instance.
    sub_stage = [
        1,
        ["diag_cl"],
        lmax_split,
        nside,
        Nsub,
        0.,
        cd_solve.tr_cg,
        cd_solve.cache_mem(),
    ]
    top_stage = [
        0,
        ["split(stage(1), %s, diag_cl)" % (lmax_split)],
        lmax_sky,
        nside,
        np.inf,
        tol,
        cd_solve.tr_cg,
        cd_solve.cache_mem(),
    ]
    return [sub_stage, top_stage]
def solve(self, x, tqu):
    """Solve for `x` from the input `tqu` and return the finalized result.

    Steps, as performed below:
      1. reset self.watch, self.iter_tot and self.prev_eps;
      2. build a cd_monitors.monitor_basic criterion that logs through self.log;
      3. call cd_solve.cd_solve with self.pre_op as the single preconditioner,
         cd_solve.tr_cg as `tr` and a new cd_solve.cache_mem() as `cache`;
      4. return self.opfilt.calc_fini(x, self.s_inv_filt, self.n_inv_filt).

    * x   = solution object passed to the solver (presumably modified in place;
            confirm against cd_solve.cd_solve).
    * tqu = input handed to self.opfilt.calc_prep to form `b`.
    """
    self.watch = util.stopwatch()
    self.iter_tot = 0
    self.prev_eps = None

    def logger(iter, eps, stage=0, **kwargs):
        # Reorder into the (stage, iter, eps) order that self.log takes.
        return self.log(stage, iter, eps, **kwargs)

    dot = self.opfilt.dot_op()
    monitor = cd_monitors.monitor_basic(
        dot, logger=logger, iter_max=self.iter_max, eps_min=self.eps_min
    )

    cd_solve.cd_solve(
        x=x,
        b=self.opfilt.calc_prep(tqu, self.s_inv_filt, self.n_inv_filt),
        fwd_op=self.opfilt.fwd_op(self.s_inv_filt, self.n_inv_filt),
        pre_ops=[self.pre_op],
        dot_op=self.opfilt.dot_op(),
        criterion=monitor,
        tr=cd_solve.tr_cg,
        cache=cd_solve.cache_mem(),
    )

    finished = self.opfilt.calc_fini(x, self.s_inv_filt, self.n_inv_filt)
    return finished
def get_densediagchain(lsides, lmax_sky, datshape, dense_file, tol=1e-5, iter_max=np.inf):
    """Build a one-stage chain description splitting between a dense block and "diag_cl".

    * lsides     = side lengths; only their product is used (presumably the
                   patch area in steradians, given the comparison with 4*pi
                   -- confirm with the caller).
    * lmax_sky   = third slot of the stage.
    * datshape   = indexable shape; entries 0 and 1 must be equal, and entry 0
                   is stored in the fourth slot of the stage.
    * dense_file = string spliced into the "dense(...)" part of the description.
    * tol        = sixth slot of the stage.
    * iter_max   = fifth slot of the stage.

    lmax_dense is 64 when prod(lsides) >= 4*pi - 0.1; otherwise it is computed
    from prod(lsides) and a fixed dense_size of 2000. In both cases it is
    capped at 1300 and rounded to an int. The chosen value is printed.

    Returns a list holding the single stage
    [0, ["split(dense(<dense_file>), <lmax_dense>, diag_cl)"], lmax_sky,
     datshape[0], iter_max, tol, cd_solve.tr_cg, cd_solve.cache_mem()].
    """
    assert datshape[0] == datshape[1], datshape
    dense_size = 2000

    if np.prod(lsides) >= (4. * np.pi) - 0.1:
        lmax_dense = 64
    else:
        lmax_dense = np.sqrt(2. / 2. / np.pi * (2 * np.pi)**2 / np.prod(lsides) * dense_size)
    lmax_dense = int(np.round(min(lmax_dense, 1300)))

    # Single-argument print(...) is valid on both Python 2 and Python 3. The
    # double space reproduces the separator the old two-item print statement
    # inserted after the message.
    print("chain_samples : setting lmax_dense to  %s" % lmax_dense)

    # '%' binds tighter than '+', so only the trailing literal is formatted and
    # dense_file is concatenated verbatim.
    description = "split(dense(" + dense_file + "), %s, diag_cl)" % lmax_dense
    return [[
        0,
        [description],
        lmax_sky,
        datshape[0],
        iter_max,
        tol,
        cd_solve.tr_cg,
        cd_solve.cache_mem(),
    ]]
def __init__(self, nstages, opfilt, s_inv_filt, n_inv_filt, plogdepth=0, eps_min=1.e-6, stage_iter_max=3):
    """Initialize this multigrid chain by constructing the resolution stages.

    * nstages        = total number of stages in the chain.
    * opfilt         = module defining dot_op, fwd_op, calc_prep, calc_fini,
                       and pre_op_diag filter functions.
    * s_inv_filt     = the S^{-1} filter for the highest resolution stage in the chain.
    * n_inv_filt     = the N^{-1} filter for the highest resolution stage in the chain.
    * plogdepth      = the lowest stage at which to print convergence information.
    * eps_min        = the convergence criterion.
    * stage_iter_max = the maximum number of iterations in the substages.

    On return, self.pre_op holds the fully nested preconditioner.
    """
    self.opfilt = opfilt
    self.s_inv_filt = s_inv_filt
    self.n_inv_filt = n_inv_filt
    self.plogdepth = plogdepth
    self.eps_min = eps_min
    self.stage_iter_max = stage_iter_max
    self.iter_max = np.inf

    # The logger class is loop-invariant, so it is created once here instead
    # of being redefined on every pass through the stage loop.
    class slog(object):
        """Adapter that prepends a fixed stage id to calls to cobj.log."""

        def __init__(self, stage_id, cobj):
            # Multiplying by 1 stores a product rather than the argument
            # object itself (only relevant for mutable numeric types).
            self.id = 1 * stage_id
            self.cobj = cobj

        def log(self, iter, eps, **kwargs):
            self.cobj.log(self.id, iter, eps, **kwargs)

    # Seed: diagonal preconditioner on N^{-1} degraded by a factor 2**nstages.
    pre_op = opfilt.pre_op_diag(s_inv_filt, n_inv_filt.degrade(2**nstages))

    # Build outward, from degrade factor 2**nstages down to 2**1. range() is
    # used instead of xrange() so the loop also works on Python 3.
    for depth in range(nstages, 0, -1):
        coarse_solver = pre_op_multigrid(
            opfilt,
            s_inv_filt,
            n_inv_filt.degrade(2**depth),
            [pre_op],
            slog(depth, self).log,
            cd_solve.tr_cg,
            cd_solve.cache_mem(),
            stage_iter_max,
            0.0,
        )
        finer_diag = opfilt.pre_op_diag(s_inv_filt, n_inv_filt.degrade(2**(depth - 1)))
        pre_op = pre_op_split(coarse_solver, finer_diag)

    self.pre_op = pre_op
def get_defaultmgchain(lmax_sky, lsides, datshape, tol=1e-5, iter_max=np.inf, dense_file='', **kwargs):
    """Build the default multi-stage chain description for a given lmax_sky.

    * lmax_sky   = third slot of the final stage (stage 0); also selects the
                   layout: > 4000 gives four stages, > 3000 gives three,
                   otherwise four stages with fixed splits at 64/256/512/1024.
    * lsides     = side lengths; only their product is used, and only when
                   lmax_sky > 3000 (presumably the patch area in steradians,
                   given the comparison with 4*pi -- confirm with the caller).
    * datshape   = indexable shape; entries 0 and 1 must be equal. Entry 0 is
                   the top resolution nside_max, from which lower-stage
                   resolutions are derived by floor division.
    * tol        = sixth slot of stage 0 (lower stages use 0.).
    * iter_max   = fifth slot of stage 0, except that the 3000 < lmax_sky <= 4000
                   layout also uses it there; lower stages use 3.
    * dense_file = string spliced into the "dense(...)" description.
    * kwargs     = accepted and ignored.

    Returns a list of stages ordered from the highest stage id down to 0, each
    of the form [id, [description], lmax, nside, iter_max, tol,
    cd_solve.tr_cg, cd_solve.cache_mem()].
    """
    # FIXME : (marker carried over from the original, which gave no description)
    assert datshape[0] == datshape[1], datshape
    nside_max = datshape[0]
    dense_size = 2000

    def flat_lmax_dense():
        # lmax_dense before capping and rounding, from prod(lsides) and dense_size.
        return np.sqrt(2. / 2. / np.pi * (2 * np.pi)**2 / np.prod(lsides) * dense_size)

    def finalize_lmax_dense(value):
        # Cap at 1300, round to int, and report the choice.
        value = int(np.round(min(value, 1300)))
        # Single-argument print(...) works on Python 2 and 3. The double space
        # reproduces the separator of the old two-item print statement.
        print("chain_samples : setting lmax_dense to  %s" % value)
        return value

    def dense_split(lmax_dense):
        # '%' binds tighter than '+': dense_file is concatenated verbatim.
        return "split(dense(" + dense_file + "), %s, diag_cl)" % lmax_dense

    def stage_split(lower_stage, lmax_split):
        return "split(stage(%s), %s, diag_cl)" % (lower_stage, lmax_split)

    def stage(stage_id, description, lmax, nside, stage_iter_max, stage_tol):
        # Every stage gets its own cd_solve.cache_mem() instance.
        return [stage_id, [description], lmax, nside, stage_iter_max, stage_tol,
                cd_solve.tr_cg, cd_solve.cache_mem()]

    # Floor division (//) keeps the resolutions integral on Python 3; for
    # integer shapes it matches what '/' produced on Python 2.
    if lmax_sky > 4000:
        if np.prod(lsides) >= (4. * np.pi) - 0.1:
            lmax_dense = 64
        else:
            lmax_dense = flat_lmax_dense()
        lmax_dense = finalize_lmax_dense(lmax_dense)
        return [
            stage(3, dense_split(lmax_dense), 1400, nside_max // 4, 3, 0.),
            stage(2, stage_split(3, 1400), 3000, nside_max // 2, 3, 0.),
            stage(1, stage_split(2, 3000), 4000, nside_max // 2, 3, 0.),
            stage(0, stage_split(1, 4000), lmax_sky, nside_max, iter_max, tol),
        ]

    if lmax_sky > 3000:
        # NOTE(review): unlike the branch above, there is no prod(lsides) >= 4*pi - 0.1
        # special case here -- confirm whether that is intended.
        lmax_dense = finalize_lmax_dense(flat_lmax_dense())
        return [
            stage(2, dense_split(lmax_dense), 1400, nside_max // 4, 3, 0.),
            stage(1, stage_split(2, 1400), 3000, nside_max // 2, 3, 0.),
            # NOTE(review): stage 0 uses nside_max // 2 here, whereas the other
            # two layouts use nside_max for stage 0 -- confirm this is intended.
            stage(0, stage_split(1, 3000), lmax_sky, nside_max // 2, iter_max, tol),
        ]

    # Same layout as the PL2015 pipeline: dense block up to 64, then splits at
    # 256, 512 and 1024, with the lower stages at nside_max/16, /8 and /4.
    # On the full flat sky with lmax_sky 2048 it solves the thing to 1e-5 in
    # 8 min or so on the laptop.
    def res(fac):
        # Lower-stage resolution, never below 10.
        return max(10, nside_max // fac)

    return [
        stage(3, dense_split(64), 256, res(16), 3, 0.),
        stage(2, stage_split(3, 256), 512, res(8), 3, 0.),
        stage(1, stage_split(2, 512), 1024, res(4), 3, 0.),
        stage(0, stage_split(1, 1024), lmax_sky, nside_max, iter_max, tol),
    ]