Esempio n. 1
0
    def __init__(self,
                 lib_dir,
                 lmax,
                 nside,
                 cl,
                 transf,
                 ninv,
                 pcf='default',
                 chain_descr=None):
        """Polarization inverse-variance filter built on an ``opfilt_pp`` multigrid chain.

        Args:
            lib_dir: directory where the filter hash and a few derived quantities are
                cached (must not be None).
            lmax: filtering is performed up to this multipole (must be >= 1024).
            nside: healpy resolution of the input maps (must be >= 512).
            cl: fiducial spectra passed on to the multigrid chain.
            transf: transfer function; only entries ``[0, lmax]`` are used.
            ninv: inverse-noise description forwarded to ``opfilt_pp.alm_filter_ninv``.
            pcf: ``'default'`` puts the dense preconditioner file at ``lib_dir/dense.pk``;
                any other value passes an empty path to ``dense()`` (presumably no
                on-disk caching -- TODO confirm against the multigrid parser).
            chain_descr: multigrid chain description; a 3-stage default is used if None.
        """
        assert lib_dir is not None and lmax >= 1024 and nside >= 512, (lib_dir,
                                                                       lmax,
                                                                       nside)
        super(cinv_p, self).__init__(lib_dir, lmax)

        self.nside = nside
        self.cl = cl
        self.transf = transf
        self.ninv = ninv

        # Non-default values map to '' (not None) so that the string concatenation
        # in the default chain description below cannot raise a TypeError.
        pcf = os.path.join(lib_dir, "dense.pk") if pcf == 'default' else ''
        if chain_descr is None:
            chain_descr = [
                [2, ["split(dense(" + pcf + "), 32, diag_cl)"], 512, 256, 3, 0.0,
                 cd_solve.tr_cg, cd_solve.cache_mem()],
                [1, ["split(stage(2),  512, diag_cl)"], 1024, 512, 3, 0.0,
                 cd_solve.tr_cg, cd_solve.cache_mem()],
                [0, ["split(stage(1), 1024, diag_cl)"], lmax, nside, np.inf, 1.0e-5,
                 cd_solve.tr_cg, cd_solve.cache_mem()]]
        n_inv_filt = util.jit(opfilt_pp.alm_filter_ninv, ninv,
                              transf[0:lmax + 1])
        self.chain = util.jit(multigrid.multigrid_chain, opfilt_pp,
                              chain_descr, cl, n_inv_filt)

        hash_path = os.path.join(lib_dir, "filt_hash.pk")
        # Only rank 0 writes the cache files; all ranks then sync at the barrier
        # and compare the stored hash with the current configuration.
        if mpi.rank == 0:
            if not os.path.exists(lib_dir):
                os.makedirs(lib_dir)

            if not os.path.exists(hash_path):
                # Closing the file here guarantees the pickle is flushed to disk
                # before other ranks read it after the barrier.
                with open(hash_path, 'wb') as f:
                    pk.dump(self.hashdict(), f, protocol=2)

            if not os.path.exists(os.path.join(self.lib_dir, "fbl.dat")):
                fel, fbl = self._calc_febl()
                np.savetxt(os.path.join(self.lib_dir, "fel.dat"), fel)
                np.savetxt(os.path.join(self.lib_dir, "fbl.dat"), fbl)

            if not os.path.exists(os.path.join(self.lib_dir, "tal.dat")):
                np.savetxt(os.path.join(self.lib_dir, "tal.dat"),
                           self._calc_tal())

            if not os.path.exists(os.path.join(self.lib_dir, "fmask.fits.gz")):
                hp.write_map(os.path.join(self.lib_dir, "fmask.fits.gz"),
                             self._calc_mask())

        mpi.barrier()
        with open(hash_path, 'rb') as f:
            stored_hash = pk.load(f)
        utils.hash_check(stored_hash, self.hashdict())
Esempio n. 2
0
    def __init__(self,
                 lib_dir,
                 lmax,
                 nside,
                 cl,
                 transf,
                 ninv,
                 marge_maps_t=(),
                 marge_monopole=False,
                 marge_dipole=False,
                 pcf='default',
                 rescal_cl='default'):
        """Instance for joint temperature-polarization filtering.

        Args:
            lib_dir: directory where the filter hash and mask are cached.
            lmax: filtering is performed up to this multipole (must be >= 1024).
            nside: healpy resolution of the input maps (must be >= 512).
            cl: dict of fiducial spectra; every entry is truncated to ``[0, lmax]``.
            transf: transfer function; only entries ``[0, lmax]`` are used.
            ninv: list of lists with mask paths and / or inverse pixel noise levels:
                TT, (QQ + UU) / 2 if len(ninv) == 2, or TT, QQ, QU, UU if len(ninv) == 4,
                e.g. [[iNevT, mask1, mask2, ...], [iNevP, mask1, mask2, ...]].
            marge_maps_t: maps to project out in the filtering (T-part).
            marge_monopole: forwarded to ``opfilt_tp.alm_filter_ninv``.
            marge_dipole: forwarded to ``opfilt_tp.alm_filter_ninv``.
            pcf: ``'default'`` puts the dense preconditioner file at
                ``lib_dir/dense_tp.pk``; any other value passes an empty path to
                ``dense()`` (presumably no on-disk caching -- TODO confirm).
            rescal_cl: per-multipole factor r_l. The chain works with r_l^2 * C_l and
                the transfer function is divided by r_l. ``'default'`` uses
                sqrt(l(l+1)/2pi), i.e. Dl's. None means r_l = 1. Any other value is
                used as is and must broadcast against ``cl[k][:lmax + 1]``.
        """
        assert (lmax >= 1024)
        assert (nside >= 512)
        assert len(
            ninv) == 2 or len(ninv) == 4  # TT, (QQ + UU)/2 or TT,QQ,QU,UU

        # Only compare with the 'default' string when rescal_cl is a string:
        # comparing a numpy array with a string gives an element-wise result (or a
        # FutureWarning, depending on numpy version) rather than a usable boolean.
        if isinstance(rescal_cl, str) and rescal_cl == 'default':
            rescal_cl = np.sqrt(
                np.arange(lmax + 1, dtype=float) *
                np.arange(1, lmax + 2, dtype=float) / 2. / np.pi)
        elif rescal_cl is None:
            rescal_cl = np.ones(lmax + 1, dtype=float)
        dl = {k: rescal_cl**2 * cl[k][:lmax + 1]
              for k in cl.keys()}  # rescaled cls (Dls by default)
        transf_dl = transf[:lmax + 1] * utils.cli(rescal_cl)

        self.lmax = lmax
        self.nside = nside
        self.cl = cl
        self.transf = transf
        self.ninv = ninv
        self.marge_maps_t = marge_maps_t
        self.marge_maps_p = []

        self.lib_dir = lib_dir
        self.rescal_cl = rescal_cl

        # Non-default values map to '' (not None) so that the string concatenation
        # below cannot raise a TypeError.
        pcf = lib_dir + "/dense_tp.pk" if pcf == 'default' else ''
        chain_descr = [
            [3, ["split(dense(" + pcf + "), 64, diag_cl)"], 256, 128, 3, 0.0,
             cd_solve.tr_cg, cd_solve.cache_mem()],
            [2, ["split(stage(3),  256, diag_cl)"], 512, 256, 3, 0.0,
             cd_solve.tr_cg, cd_solve.cache_mem()],
            [1, ["split(stage(2),  512, diag_cl)"], 1024, 512, 3, 0.0,
             cd_solve.tr_cg, cd_solve.cache_mem()],
            [0, ["split(stage(1), 1024, diag_cl)"], lmax, nside, np.inf, 1.0e-5,
             cd_solve.tr_cg, cd_solve.cache_mem()]]

        n_inv_filt = util.jit(opfilt_tp.alm_filter_ninv,
                              ninv,
                              transf_dl,
                              marge_maps_t=marge_maps_t,
                              marge_monopole=marge_monopole,
                              marge_dipole=marge_dipole)
        self.chain = util.jit(multigrid.multigrid_chain, opfilt_tp,
                              chain_descr, dl, n_inv_filt)

        hash_path = os.path.join(lib_dir, "filt_hash.pk")
        # Only rank 0 writes the cache files; all ranks then sync at the barrier
        # and compare the stored hash with the current configuration.
        if mpi.rank == 0:
            if not os.path.exists(lib_dir):
                os.makedirs(lib_dir)

            if not os.path.exists(hash_path):
                # Closing the file here guarantees the pickle is flushed to disk
                # before other ranks read it after the barrier.
                with open(hash_path, 'wb') as f:
                    pk.dump(self.hashdict(), f, protocol=2)

            # NOTE: fel/fbl/tal are not precomputed for this filter; only the mask is.
            if not os.path.exists(os.path.join(self.lib_dir, "fmask.fits.gz")):
                fmask = self.calc_mask()
                hp.write_map(os.path.join(self.lib_dir, "fmask.fits.gz"),
                             fmask)

        mpi.barrier()
        with open(hash_path, 'rb') as f:
            stored_hash = pk.load(f)
        utils.hash_check(stored_hash, self.hashdict())
Esempio n. 3
0
    def __init__(self,
                 lib_dir,
                 lmax,
                 nside,
                 cl,
                 transf,
                 ninv,
                 marge_monopole=True,
                 marge_dipole=True,
                 marge_maps=(),
                 pcf='default',
                 chain_descr=None):
        """Temperature inverse-variance filter built on an ``opfilt_tt`` multigrid chain.

        Args:
            lib_dir: directory where the filter hash and a few derived quantities are
                cached (must not be None).
            lmax: filtering is performed up to this multipole (must be >= 1024).
            nside: healpy resolution of the input maps (must be >= 512).
            cl: fiducial spectra passed on to the multigrid chain.
            transf: transfer function; only entries ``[0, lmax]`` are used.
            ninv: inverse-noise description (must be a list), forwarded to
                ``opfilt_tt.alm_filter_ninv``.
            marge_monopole: forwarded to ``opfilt_tt.alm_filter_ninv``.
            marge_dipole: forwarded to ``opfilt_tt.alm_filter_ninv``.
            marge_maps: forwarded to ``opfilt_tt.alm_filter_ninv``.
            pcf: ``'default'`` puts the dense preconditioner file at ``lib_dir/dense.pk``;
                any other value passes an empty path to ``dense()``.
            chain_descr: multigrid chain description; a 4-stage default is used if None.
        """
        assert lib_dir is not None and lmax >= 1024 and nside >= 512, (lib_dir,
                                                                       lmax,
                                                                       nside)
        assert isinstance(ninv, list)
        super(cinv_t, self).__init__(lib_dir, lmax)

        self.nside = nside
        self.cl = cl
        self.transf = transf
        self.ninv = ninv
        self.marge_monopole = marge_monopole
        self.marge_dipole = marge_dipole
        self.marge_maps = marge_maps

        # Dense matrices will be cached there.
        pcf = os.path.join(lib_dir, "dense.pk") if pcf == 'default' else ''
        if chain_descr is None:
            chain_descr = [
                [3, ["split(dense(" + pcf + "), 64, diag_cl)"], 256, 128, 3, 0.0,
                 cd_solve.tr_cg, cd_solve.cache_mem()],
                [2, ["split(stage(3),  256, diag_cl)"], 512, 256, 3, 0.0,
                 cd_solve.tr_cg, cd_solve.cache_mem()],
                [1, ["split(stage(2),  512, diag_cl)"], 1024, 512, 3, 0.0,
                 cd_solve.tr_cg, cd_solve.cache_mem()],
                [0, ["split(stage(1), 1024, diag_cl)"], lmax, nside, np.inf, 1.0e-5,
                 cd_solve.tr_cg, cd_solve.cache_mem()]]

        n_inv_filt = util.jit(opfilt_tt.alm_filter_ninv,
                              ninv,
                              transf[0:lmax + 1],
                              marge_monopole=marge_monopole,
                              marge_dipole=marge_dipole,
                              marge_maps=marge_maps)
        self.chain = util.jit(multigrid.multigrid_chain, opfilt_tt,
                              chain_descr, cl, n_inv_filt)

        hash_path = os.path.join(lib_dir, "filt_hash.pk")
        # Only rank 0 writes the cache files; all ranks then sync at the barrier
        # and compare the stored hash with the current configuration.
        if mpi.rank == 0:
            if not os.path.exists(lib_dir):
                os.makedirs(lib_dir)

            if not os.path.exists(hash_path):
                # Closing the file here guarantees the pickle is flushed to disk
                # before other ranks read it after the barrier.
                with open(hash_path, 'wb') as f:
                    pk.dump(self.hashdict(), f, protocol=2)

            if not os.path.exists(os.path.join(self.lib_dir, "ftl.dat")):
                np.savetxt(os.path.join(self.lib_dir, "ftl.dat"),
                           self._calc_ftl())

            if not os.path.exists(os.path.join(self.lib_dir, "tal.dat")):
                np.savetxt(os.path.join(self.lib_dir, "tal.dat"),
                           self._calc_tal())

            if not os.path.exists(os.path.join(self.lib_dir, "fmask.fits.gz")):
                hp.write_map(os.path.join(self.lib_dir, "fmask.fits.gz"),
                             self._calc_mask())

        mpi.barrier()
        with open(hash_path, 'rb') as f:
            stored_hash = pk.load(f)
        utils.hash_check(stored_hash, self.hashdict())
Esempio n. 4
0
    def __init__(self,
                 lib_dir,
                 lmax,
                 nside,
                 cl,
                 transf,
                 ninv,
                 marge_maps_t=(),
                 marge_monopole=False,
                 marge_dipole=False,
                 pcf='default',
                 rescal_cl='default',
                 chain_descr=None,
                 transf_p=None):
        """Instance for joint temperature-polarization filtering

            Args:
                lib_dir: a few quantities might get cached there
                lmax: CMB filtering performed up to multipole lmax
                nside: healpy resolution of the input maps
                cl: fiducial CMB spectra used to filter the data (dict with 'tt', 'te', 'ee', 'bb' keys)
                transf: CMB transfer function in temperature
                ninv: list of lists with mask paths and / or inverse pixel noise levels.
                        TT, (QQ + UU) / 2 if len(ninv) == 2 or TT, QQ, QU UU if == 4
                        e.g. [[iNevT,mask1,mask2,..],[iNevP,mask1,mask2...]]
                marge_maps_t: maps to project out in the filtering (T-part)
                marge_monopole: marginalizes out the T monopole if set
                marge_dipole: marginalizes out the T dipole if set
                pcf: 'default' puts the dense preconditioner file at lib_dir/dense_tp.pk;
                     any other value passes an empty path to dense() (only used when
                     chain_descr is None)
                rescal_cl: per-field (t, e, b) multipole rescaling r_l. Spectra XY become
                     r^X_l r^Y_l C_l and transfer functions are divided by r_l.
                     'default': sqrt(l(l+1)/2pi) for t, e, b; None: no rescaling;
                     'tonly': sqrt(l(l+1)/2pi) for t only. Each r_l is normalized to unit mean.
                chain_descr: preconditioner multigrid chain description (if different from default)
                transf_p: polarization transfer function (if different from temperature)


        """
        assert (lmax >= 1024)
        assert (nside >= 512)
        assert len(
            ninv) == 2 or len(ninv) == 4  # TT, (QQ + UU)/2 or TT,QQ,QU,UU

        def _dl_factor():
            # sqrt(l (l + 1) / 2pi) for l in [0, lmax]. A fresh array is returned on
            # every call because the factors are normalized in place below.
            ell = np.arange(lmax + 1, dtype=float)
            return np.sqrt(ell * (ell + 1.) / 2. / np.pi)

        # Only compare with strings when rescal_cl is a string: comparing a numpy
        # array with a string does not give a usable boolean.
        rescal_mode = rescal_cl if isinstance(rescal_cl, str) else None
        if rescal_mode == 'default':
            rescal_cl = {a: _dl_factor() for a in ['t', 'e', 'b']}
        elif rescal_cl is None:
            rescal_cl = {
                a: np.ones(lmax + 1, dtype=float)
                for a in ['t', 'e', 'b']
            }
        elif rescal_mode == 'tonly':
            rescal_cl = {a: np.ones(lmax + 1, dtype=float) for a in ['e', 'b']}
            rescal_cl['t'] = _dl_factor()
        else:
            assert 0, "rescal_cl must be 'default', None or 'tonly', got %r" % (
                rescal_cl,)
        for k in rescal_cl.keys():
            # in order not mess around with the TEB relative weights of the spectra
            rescal_cl[k] /= np.mean(rescal_cl[k])
        dl = {
            k: rescal_cl[k[0]] * rescal_cl[k[1]] * cl[k][:lmax + 1]
            for k in cl.keys()
        }  # rescaled cls (Dls by default)
        if transf_p is None:
            transf_p = transf
        transf_dls = {
            a: transf_p[:lmax + 1] * utils.cli(rescal_cl[a])
            for a in ['e', 'b']
        }
        transf_dls['t'] = transf[:lmax + 1] * utils.cli(rescal_cl['t'])

        self.lmax = lmax
        self.nside = nside
        self.cl = cl
        self.transf_t = transf
        self.transf_p = transf_p
        self.ninv = ninv
        self.marge_maps_t = marge_maps_t
        self.marge_maps_p = []

        self.lib_dir = lib_dir
        self.rescal_cl = rescal_cl

        if chain_descr is None:
            # Non-default values map to '' (not None) so that the string
            # concatenation below cannot raise a TypeError.
            pcf = lib_dir + "/dense_tp.pk" if pcf == 'default' else ''
            chain_descr = [
                [3, ["split(dense(" + pcf + "), 64, diag_cl)"], 256, 128, 3, 0.0,
                 cd_solve.tr_cg, cd_solve.cache_mem()],
                [2, ["split(stage(3),  256, diag_cl)"], 512, 256, 3, 0.0,
                 cd_solve.tr_cg, cd_solve.cache_mem()],
                [1, ["split(stage(2),  512, diag_cl)"], 1024, 512, 3, 0.0,
                 cd_solve.tr_cg, cd_solve.cache_mem()],
                [0, ["split(stage(1), 1024, diag_cl)"], lmax, nside, np.inf, 1.0e-5,
                 cd_solve.tr_cg, cd_solve.cache_mem()]]

        n_inv_filt = util.jit(opfilt_tp.alm_filter_ninv,
                              ninv,
                              transf_dls['t'],
                              b_transf_e=transf_dls['e'],
                              b_transf_b=transf_dls['b'],
                              marge_maps_t=marge_maps_t,
                              marge_monopole=marge_monopole,
                              marge_dipole=marge_dipole)
        self.chain = util.jit(multigrid.multigrid_chain, opfilt_tp,
                              chain_descr, dl, n_inv_filt)

        hash_path = os.path.join(lib_dir, "filt_hash.pk")
        fal_path = os.path.join(lib_dir, "fal.pk")
        # Only rank 0 writes the cache files; all ranks then sync at the barrier
        # and compare the stored hash with the current configuration.
        if mpi.rank == 0:
            if not os.path.exists(lib_dir):
                os.makedirs(lib_dir)

            if not os.path.exists(hash_path):
                # Closing the file here guarantees the pickle is flushed to disk
                # before other ranks read it after the barrier.
                with open(hash_path, 'wb') as f:
                    pk.dump(self.hashdict(), f, protocol=2)

            if not os.path.exists(fal_path):
                with open(fal_path, 'wb') as f:
                    pk.dump(self._calc_fal(), f, protocol=2)

            if not os.path.exists(os.path.join(self.lib_dir, "fmask.fits.gz")):
                fmask = self.calc_mask()
                hp.write_map(os.path.join(self.lib_dir, "fmask.fits.gz"),
                             fmask)

        mpi.barrier()
        with open(hash_path, 'rb') as f:
            stored_hash = pk.load(f)
        utils.hash_check(stored_hash, self.hashdict())