Exemplo n.º 1
0
def get_isomgchain(lmax_sky, datshape, tol=1e-5, iter_max=np.inf, **kwargs):
    """Describe a single-stage chain preconditioned by ``diag_cl`` only.

    The returned list holds one stage description laid out as
    ``[stage_id, [preconditioner], lmax, nside, iter_max, tol, tr, cache]``
    (field meanings inferred from the argument names -- TODO confirm against
    the code that consumes chain descriptions).

    :param lmax_sky: maximum multipole handled by the stage.
    :param datshape: data shape; its first two entries must be equal, and the
        first one is used as the stage's nside.
    :param tol: convergence tolerance stored in the stage.
    :param iter_max: iteration cap stored in the stage.
    :param kwargs: accepted for call-site compatibility; ignored.
    :return: a one-element list of stage descriptions.
    """
    # Only square data is supported; the side length doubles as the nside.
    assert datshape[0] == datshape[1], datshape
    only_stage = [0, ["diag_cl"], lmax_sky, datshape[0], iter_max, tol,
                  cd_solve.tr_cg, cd_solve.cache_mem()]
    return [only_stage]
Exemplo n.º 2
0
    def solve(self, x, tqu):
        """Solve the filtering system for ``tqu``, writing the solution into ``x``.

        Resets the per-solve bookkeeping (stopwatch, total iteration count,
        previous residual), runs ``cd_solve.cd_solve`` with a CG transform and
        ``self.pre_op`` as the preconditioner, and returns
        ``self.opfilt.calc_fini`` applied to the solution.

        :param x: solution buffer, passed to cd_solve as ``x`` (presumably
            updated in place -- confirm against cd_solve).
        :param tqu: input data handed to ``self.opfilt.calc_prep`` to build
            the right-hand side.
        :return: whatever ``self.opfilt.calc_fini`` returns.
        """
        self.watch = util.stopwatch()
        self.iter_tot = 0
        self.prev_eps = None

        def logger(iter, eps, stage=0, **kwargs):
            # The monitor reports (iter, eps); self.log expects the stage id
            # first, so reorder the arguments (stage defaults to the top level).
            return self.log(stage, iter, eps, **kwargs)

        opfilt = self.opfilt
        s_inv = self.s_inv_filt
        n_inv = self.n_inv_filt

        monitor = cd_monitors.monitor_basic(opfilt.dot_op(),
                                            logger=logger,
                                            iter_max=self.iter_max,
                                            eps_min=self.eps_min)

        rhs = opfilt.calc_prep(tqu, s_inv, n_inv)
        forward = opfilt.fwd_op(s_inv, n_inv)
        cd_solve.cd_solve(x=x, b=rhs, fwd_op=forward,
                          pre_ops=[self.pre_op], dot_op=opfilt.dot_op(),
                          criterion=monitor, tr=cd_solve.tr_cg,
                          cache=cd_solve.cache_mem())

        return opfilt.calc_fini(x, s_inv, n_inv)
Exemplo n.º 3
0
    def __init__(self,
                 nstages,
                 opfilt,
                 s_inv_filt,
                 n_inv_filt,
                 plogdepth=0,
                 eps_min=1.e-6,
                 stage_iter_max=3):
        """ Build the multigrid preconditioner chain, coarsest level first.
             * nstages        = total number of stages in the chain.
             * opfilt         = module providing dot_op, fwd_op, calc_prep, calc_fini and pre_op_diag.
             * s_inv_filt     = S^{-1} filter of the highest-resolution stage.
             * n_inv_filt     = N^{-1} filter of the highest-resolution stage; coarser
                                stages use n_inv_filt.degrade(2**level).
             * plogdepth      = lowest stage at which convergence information is printed.
             * eps_min        = convergence criterion of the top-level solve.
             * stage_iter_max = iteration cap of every sub-stage.
        """
        self.opfilt = opfilt
        self.s_inv_filt = s_inv_filt
        self.n_inv_filt = n_inv_filt
        self.plogdepth = plogdepth
        self.eps_min = eps_min
        self.stage_iter_max = stage_iter_max
        self.iter_max = np.inf

        def stage_logger(stage_id):
            # Callback that forwards (iter, eps, **kwargs) to self.log,
            # tagged with the id of the stage that produced it.
            def log(iter, eps, **kwargs):
                self.log(stage_id, iter, eps, **kwargs)
            return log

        # Coarsest level: a plain diagonal preconditioner.
        precond = opfilt.pre_op_diag(s_inv_filt, n_inv_filt.degrade(2**nstages))

        # Walk from the coarsest level (nstages) down to level 1. Each level
        # wraps the previous preconditioner in a multigrid sub-solve and splits
        # it against a diagonal preconditioner one level finer.
        for depth in range(nstages, 0, -1):
            coarse = pre_op_multigrid(opfilt, s_inv_filt,
                                      n_inv_filt.degrade(2**depth),
                                      [precond],
                                      stage_logger(depth), cd_solve.tr_cg,
                                      cd_solve.cache_mem(), stage_iter_max, 0.0)
            finer_diag = opfilt.pre_op_diag(s_inv_filt,
                                            n_inv_filt.degrade(2**(depth - 1)))
            precond = pre_op_split(coarse, finer_diag)

        self.pre_op = precond
Exemplo n.º 4
0
def get_low_res_mgchain(lmax_sky=3000,
                        lmax_split=2999,
                        nside=2048,
                        tol=1e-4,
                        Nsub=20,
                        dense_file='',
                        **kwargs):
    """Describe a two-stage chain that runs both stages at the same ``nside``.

    Stage 1 uses a ``diag_cl`` preconditioner up to ``lmax_split``, capped at
    ``Nsub`` iterations with a zero tolerance. Stage 0 (the top level) splits
    stage 1 from ``diag_cl`` at ``lmax_split`` and runs up to ``lmax_sky``
    with no iteration cap and tolerance ``tol``.

    Each stage is laid out as
    ``[stage_id, [preconditioner], lmax, nside, iter_max, tol, tr, cache]``
    (field meanings inferred from the other chain builders -- TODO confirm).

    ``dense_file`` and ``kwargs`` are accepted for call-site compatibility
    and are not used.
    """
    sub_stage = [1, ["diag_cl"], lmax_split, nside, Nsub, 0.,
                 cd_solve.tr_cg, cd_solve.cache_mem()]
    split_spec = "split(stage(1), %s, diag_cl)" % (lmax_split)
    top_stage = [0, [split_spec], lmax_sky, nside, np.inf, tol,
                 cd_solve.tr_cg, cd_solve.cache_mem()]
    return [sub_stage, top_stage]
Exemplo n.º 5
0
    def solve( self, x, tqu ):
        """ Run the preconditioned CG solve for `tqu`, with `x` as the solution buffer.

        Resets the stopwatch, the total iteration counter and the previous
        residual, then calls cd_solve.cd_solve with self.pre_op as the
        preconditioner, and returns self.opfilt.calc_fini applied to `x`.
        """
        self.watch    = util.stopwatch()
        self.iter_tot = 0
        self.prev_eps = None

        def logger( iter, eps, stage=0, **kwargs ):
            # self.log takes the stage id first; top-level reports default to stage 0.
            return self.log( stage, iter, eps, **kwargs )

        monitor = cd_monitors.monitor_basic( self.opfilt.dot_op(), logger=logger,
                                             iter_max=self.iter_max, eps_min=self.eps_min )

        # Arguments are built in the same order the solver call would evaluate them.
        solver_args = dict( x        = x,
                            b        = self.opfilt.calc_prep( tqu, self.s_inv_filt, self.n_inv_filt ),
                            fwd_op   = self.opfilt.fwd_op( self.s_inv_filt, self.n_inv_filt ),
                            pre_ops  = [ self.pre_op ],
                            dot_op   = self.opfilt.dot_op(),
                            criterion= monitor,
                            tr       = cd_solve.tr_cg,
                            cache    = cd_solve.cache_mem() )
        cd_solve.cd_solve( **solver_args )

        return self.opfilt.calc_fini( x, self.s_inv_filt, self.n_inv_filt )
Exemplo n.º 6
0
def get_densediagchain(lsides,
                       lmax_sky,
                       datshape,
                       dense_file,
                       tol=1e-5,
                       iter_max=np.inf):
    """Describe a single-stage chain whose preconditioner splits a dense block from ``diag_cl``.

    The stage is laid out as
    ``[stage_id, [preconditioner], lmax, nside, iter_max, tol, tr, cache]``
    (field meanings inferred from the argument names -- TODO confirm).

    :param lsides: patch side lengths; only their product is used. It is
        compared against 4*pi, so presumably this is a solid angle in radians.
    :param lmax_sky: maximum multipole of the stage.
    :param datshape: data shape; must be square, and datshape[0] is the nside.
    :param dense_file: path inserted into the ``dense(...)`` preconditioner.
    :param tol: convergence tolerance of the stage.
    :param iter_max: iteration cap of the stage.
    :return: a one-element list of stage descriptions.
    """
    assert datshape[0] == datshape[1], datshape
    dense_size = 2000
    if np.prod(lsides) >= (4. * np.pi) - 0.1:
        # (Near-)full sky: fixed dense-block lmax.
        lmax_dense = 64
    else:
        # A flat patch of area A holds about A * lmax**2 / (4 pi) modes below
        # lmax, so this appears to keep roughly `dense_size` modes in the dense
        # block. The result is capped at 1300.
        lmax_dense = np.sqrt(2. / 2. / np.pi * (2 * np.pi)**2 /
                             np.prod(lsides) * dense_size)
        lmax_dense = int(np.round(min(lmax_dense, 1300)))
    # A single-argument print is valid on Python 3. On Python 2 it emits exactly
    # the text of the former `print "... to ", lmax_dense` statement (the
    # softspace rule inserted a second space before the value).
    print("chain_samples : setting lmax_dense to  %s" % lmax_dense)
    chain_descr = [[
        0,
        ["split(dense(" + dense_file + "), %s, diag_cl)" % (int(lmax_dense))],
        lmax_sky, datshape[0], iter_max, tol, cd_solve.tr_cg,
        cd_solve.cache_mem()
    ]]
    return chain_descr
Exemplo n.º 7
0
    def __init__(self, nstages, opfilt, s_inv_filt, n_inv_filt, plogdepth=0, eps_min=1.e-6, stage_iter_max=3):
        """ Assemble the multigrid preconditioner, starting from the coarsest level.
             * nstages        = total number of stages in the chain.
             * opfilt         = module providing dot_op, fwd_op, calc_prep, calc_fini and pre_op_diag.
             * s_inv_filt     = S^{-1} filter for the highest-resolution stage.
             * n_inv_filt     = N^{-1} filter for the highest-resolution stage; level k uses
                                n_inv_filt.degrade(2**k).
             * plogdepth      = lowest stage at which convergence information is printed.
             * eps_min        = convergence criterion of the top-level solve.
             * stage_iter_max = maximum number of iterations in each sub-stage.
        """
        self.opfilt         = opfilt
        self.s_inv_filt     = s_inv_filt
        self.n_inv_filt     = n_inv_filt
        self.plogdepth      = plogdepth
        self.eps_min        = eps_min
        self.stage_iter_max = stage_iter_max
        self.iter_max       = np.inf

        chain = self

        class _StageLog(object):
            # Tags every log call coming out of a sub-stage with that stage's id.
            def __init__(self, stage_id):
                self.stage_id = 1*stage_id
            def log(self, iter, eps, **kwargs):
                chain.log( self.stage_id, iter, eps, **kwargs )

        def degraded(level):
            # N^{-1} filter degraded by 2**level.
            return n_inv_filt.degrade(2**level)

        pre_op = opfilt.pre_op_diag( s_inv_filt, degraded(nstages) )

        # Levels nstages, nstages-1, ..., 1: each one wraps the previous
        # preconditioner in a multigrid sub-solve, then splits it against a
        # diagonal preconditioner one level finer.
        for level in reversed(range(1, nstages + 1)):
            sub_solve = pre_op_multigrid( opfilt, s_inv_filt, degraded(level),
                                          [pre_op], _StageLog(level).log, cd_solve.tr_cg,
                                          cd_solve.cache_mem(), stage_iter_max, 0.0 )
            pre_op = pre_op_split( sub_solve,
                                   opfilt.pre_op_diag( s_inv_filt, degraded(level - 1) ) )

        self.pre_op = pre_op
Exemplo n.º 8
0
def get_defaultmgchain(lmax_sky,
                       lsides,
                       datshape,
                       tol=1e-5,
                       iter_max=np.inf,
                       dense_file='',
                       **kwargs):
    """Return a multigrid chain description whose depth depends on ``lmax_sky``.

    Each stage is the list
    ``[stage_id, [preconditioner], lmax, nside, iter_max, tol, tr, cache]``
    (field meanings inferred from the argument names -- TODO confirm against
    the chain consumer). Stages are listed coarsest first. The last entry
    (id 0) is the top-level solve and uses ``iter_max`` and ``tol``; sub-stages
    use 3 iterations and a zero tolerance.

    * lmax_sky > 4000        : 4 stages, dense block at the coarsest stage.
    * 3000 < lmax_sky <= 4000 : 3 stages, dense block at the coarsest stage.
    * otherwise              : 4 stages modelled on the PL2015 pipeline chain.

    :param lmax_sky: maximum multipole of the top-level stage.
    :param lsides: patch side lengths. Only their product is used, and only
        when lmax_sky > 3000. It is compared against 4*pi, so presumably this
        is a solid angle in radians.
    :param datshape: data shape; must be square, and datshape[0] is the
        top-level nside.
    :param tol: convergence tolerance of the top-level stage.
    :param iter_max: iteration cap of the top-level stage.
    :param dense_file: path inserted into the ``dense(...)`` preconditioner.
    :param kwargs: accepted for call-site compatibility; ignored.
    :return: list of stage descriptions.
    """
    assert datshape[0] == datshape[1], datshape
    nside_max = datshape[0]

    def _lmax_dense():
        # Dense-block lmax. A flat patch of area A holds about A*lmax**2/(4 pi)
        # modes below lmax, so the formula appears to keep roughly `dense_size`
        # modes in the dense block (capped at 1300). (Near-)full-sky patches use
        # 64, as in the PL2015 chain. This clamp used to be missing from the
        # 3000 < lmax_sky <= 4000 branch only.
        dense_size = 2000
        if np.prod(lsides) >= (4. * np.pi) - 0.1:
            lmax_dense = 64
        else:
            lmax_dense = np.sqrt(2. / 2. / np.pi * (2 * np.pi)**2 /
                                 np.prod(lsides) * dense_size)
            lmax_dense = int(np.round(min(lmax_dense, 1300)))
        # A single-argument print is valid on Python 3. On Python 2 it emits
        # exactly the text of the former `print "... to ", lmax_dense` statement
        # (the softspace rule inserted a second space before the value).
        print("chain_samples : setting lmax_dense to  %s" % lmax_dense)
        return lmax_dense

    def _dense_precond(lmax_dense):
        # Dense block below lmax_dense, diag_cl above it.
        return "split(dense(" + dense_file + "), %s, diag_cl)" % (int(lmax_dense))

    def _stage(stage_id, precond, lmax, nside, n_iter, eps):
        # Every stage uses the CG transform with an in-memory cache.
        return [stage_id, [precond], lmax, nside, n_iter, eps,
                cd_solve.tr_cg, cd_solve.cache_mem()]

    # Floor division keeps nside an integer on Python 3 as well; for the
    # integer shape entries it equals Python 2's `/`.
    if lmax_sky > 4000:
        lmax_dense = _lmax_dense()
        chain_descr = [
            _stage(3, _dense_precond(lmax_dense), 1400, nside_max // 4, 3, 0.),
            _stage(2, "split(stage(3), %s, diag_cl)" % (1400), 3000,
                   nside_max // 2, 3, 0.),
            _stage(1, "split(stage(2), %s, diag_cl)" % (3000), 4000,
                   nside_max // 2, 3, 0.),
            _stage(0, "split(stage(1), %s, diag_cl)" % (4000), lmax_sky,
                   nside_max, iter_max, tol),
        ]
    elif lmax_sky > 3000:
        lmax_dense = _lmax_dense()
        chain_descr = [
            _stage(2, _dense_precond(lmax_dense), 1400, nside_max // 4, 3, 0.),
            _stage(1, "split(stage(2), %s, diag_cl)" % (1400), 3000,
                   nside_max // 2, 3, 0.),
            # NOTE(review): unlike the other branches, the top-level stage
            # here runs at nside_max // 2 rather than nside_max -- confirm
            # this is intended.
            _stage(0, "split(stage(1), %s, diag_cl)" % (3000), lmax_sky,
                   nside_max // 2, iter_max, tol),
        ]
    else:
        # Modelled on the PL2015 pipeline chain: a dense block up to 64, then
        # split points at 256, 512 and 1024. On the full flat sky with
        # lmax_sky 2048 it solves to 1e-5 in about 8 min on a laptop.
        def res(fac):
            # Coarse-stage resolution, never below 10.
            return max(10, nside_max // fac)

        chain_descr = [
            _stage(3, _dense_precond(64), 256, res(16), 3, 0.),
            _stage(2, "split(stage(3), %s, diag_cl)" % (256), 512,
                   res(8), 3, 0.),
            _stage(1, "split(stage(2), %s, diag_cl)" % (512), 1024,
                   res(4), 3, 0.),
            _stage(0, "split(stage(1), %s, diag_cl)" % (1024), lmax_sky,
                   nside_max, iter_max, tol),
        ]

    return chain_descr