Example #1
0
    def __init__(self, comm=None, nnz=1, dtype=np.float64, nside=16):
        """Create a ring-distributed HEALPix map buffer.

        The HEALPix rings of resolution ``nside`` are split across the
        processes of ``comm`` with ``distribute_rings``.  This process then
        records the pixel indices of its local rings and allocates a
        ``(local_npix, nnz)`` data array, both held in a private Cache.

        Args:
            comm: communicator providing ``rank`` and ``size``.  Although
                the signature defaults to None, a communicator is required.
            nnz (int): number of values stored per pixel.
            dtype: element type of the data array.
            nside (int): HEALPix resolution parameter.

        Raises:
            RuntimeError: if libsharp is not available or ``comm`` is None.
        """
        if libsharp is None:
            raise RuntimeError('libsharp not available')
        if comm is None:
            # The ring distribution below reads comm.rank / comm.size; fail
            # here with a clear message instead of an AttributeError.
            raise RuntimeError('a communicator is required to distribute '
                               'HEALPix rings')
        self.data = None
        self._comm = comm
        self._nnz = nnz
        self._dtype = dtype
        # Local pixel numbers are built from ring start/length information.
        self._nest = False
        self._nside = nside

        self._cache = Cache()

        # Rings assigned to this process and the matching libsharp grid.
        self._libsharp_grid, self._local_ring_indices = distribute_rings(
            self._nside, self._comm.rank, self._comm.size)
        # First pixel index and pixel count of each local ring; the other
        # ringinfo outputs are not needed here.
        startpix, ringpix, _, _, _ = hp.ringinfo(
            self._nside, self._local_ring_indices.astype(np.int64))

        local_npix = self._libsharp_grid.local_size()
        self._local_pixels = self._cache.create("local_pixels",
                                                shape=(local_npix, ),
                                                type=np.int64)
        # Presumably fills self._local_pixels in place with the pixel
        # indices of every local ring -- TODO confirm against expand_pix.
        expand_pix(startpix, ringpix, local_npix, self._local_pixels)

        self.data = self._cache.create("data",
                                       shape=(local_npix, self._nnz),
                                       type=self._dtype)
Example #2
0
def _available_mb():
    """Return the system memory currently available, in MB (2**20 bytes)."""
    return psutil.virtual_memory()._asdict()["available"] / 2**20


def _log_residual(log, diff_mb, expected_mb):
    """Log a measured change in available memory against the expected one.

    The residual is ``|diff - expected| / expected`` in percent.  It is
    reported as "n/a" when nothing was expected to be allocated (e.g.
    ``--ndet 0``), because the ratio is undefined in that case.
    """
    if expected_mb > 0:
        residual = "{:0.2f}%".format(
            100.0 * abs(diff_mb - expected_mb) / expected_mb)
    else:
        residual = "n/a"
    log.info("  Difference is {:0.2f}MB, expected {:0.2f}MB ({} residual)"
             .format(diff_mb, expected_mb, residual))


def main():
    """Measure system memory while allocating and freeing Cache objects.

    For each of ``--nloop`` iterations, one Cache per simulated
    observation is filled with signal, flag, pixel and weight buffers for
    every detector, and the buffers are then destroyed one by one.  The
    change in available memory (as reported by psutil) after each step is
    logged next to the size expected from the buffer dtypes.
    """
    log = Logger.get()

    parser = argparse.ArgumentParser(
        description="Allocate and free cache objects.")

    parser.add_argument("--ndet",
                        required=False,
                        type=int,
                        default=10,
                        help="The number of detectors")

    parser.add_argument("--nobs",
                        required=False,
                        type=int,
                        default=2,
                        help="The number of observations")

    parser.add_argument(
        "--obsminutes",
        required=False,
        type=int,
        default=60,
        help="The number of minutes in each observation.",
    )

    parser.add_argument("--rate",
                        required=False,
                        type=float,
                        default=37.0,
                        help="The sample rate.")

    parser.add_argument(
        "--nloop",
        required=False,
        type=int,
        default=2,
        help="The number of allocate / free loops",
    )

    args = parser.parse_args()

    log.info("Input parameters:")
    log.info("  {} observations".format(args.nobs))
    log.info("  {} minutes per obs".format(args.obsminutes))
    log.info("  {} detectors per obs".format(args.ndet))
    log.info("  {}Hz sample rate".format(args.rate))

    # (suffix, dtype, trailing shape) of every per-detector cache object.
    # The expected memory footprint is derived from this same table, so the
    # allocation and the expectation cannot drift apart.
    fields = (
        ("sig", np.float64, ()),
        ("flg", np.uint8, ()),
        ("pix", np.int64, ()),
        ("wgt", np.float32, (3, )),
    )
    bytes_per_sample = sum(
        np.dtype(dtype).itemsize * int(np.prod(tail, dtype=np.int64))
        for _, dtype, tail in fields)

    nsampobs = int(args.obsminutes * 60 * args.rate)
    nsamptot = args.ndet * args.nobs * nsampobs

    log.info("{} total samples across all detectors and observations".format(
        nsamptot))

    bytes_tot = nsamptot * bytes_per_sample
    bytes_tot_mb = bytes_tot / 2**20
    log.info("{} total bytes ({:0.2f}MB) of data expected".format(
        bytes_tot, bytes_tot_mb))

    for lp in range(args.nloop):
        log.info("Allocation loop {:02d}".format(lp))
        avstart_mb = _available_mb()
        log.info(
            "  Starting with {:0.2f}MB of available memory".format(avstart_mb))

        # The list of Caches, one per "observation"
        caches = []
        # External references to every cache object.  They stay alive while
        # the objects are destroyed, to check that memory is freed even if
        # external references are held.
        refs = []
        for _ in range(args.nobs):
            ch = Cache()
            rf = {}
            for det in range(args.ndet):
                for suffix, dtype, tail in fields:
                    cname = "{:04d}_{}".format(det, suffix)
                    rf[cname] = ch.create(cname, dtype, (nsampobs, ) + tail)
            caches.append(ch)
            refs.append(rf)

        avpost_mb = _available_mb()
        log.info("  After allocation, {:0.2f}MB of available memory".format(
            avpost_mb))
        _log_residual(log, avstart_mb - avpost_mb, bytes_tot_mb)

        # Destroy in reverse order of the field table (weights first).
        for suffix, _, _ in reversed(fields):
            for ch in caches:
                for det in range(args.ndet):
                    ch.destroy("{:04d}_{}".format(det, suffix))

        avfinal_mb = _available_mb()
        log.info("  After destruction, {:0.2f}MB of available memory".format(
            avfinal_mb))
        _log_residual(log, avfinal_mb - avpost_mb, bytes_tot_mb)
Example #3
0
class OpMadam(Operator):
    """
    Operator which passes data to libmadam for map-making.

    Args:
        params (dictionary): parameters to pass to madam.
        detweights (dictionary): individual noise weights to use for each
            detector.
        pixels (str): the name of the cache object (<pixels>_<detector>)
            containing the pixel indices to use.
        pixels_nested (bool): Set to False if the pixel numbers are in
            ring ordering. Default is True.
        weights (str): the name of the cache object (<weights>_<detector>)
            containing the pointing weights to use.
        name (str): the name of the cache object (<name>_<detector>) to
            use for the detector timestream.  If None, use the TOD.
        name_out (str): the name of the cache object (<name>_<detector>)
            to use to output destriped detector timestream.
            No output if None.
        flag_name (str): the name of the cache object
            (<flag_name>_<detector>) to use for the detector flags.
            If None, use the TOD.
        flag_mask (int): the integer bit mask (0-255) that should be
            used with the detector flags in a bitwise AND.
        common_flag_name (str): the name of the cache object
            to use for the common flags.  If None, use the TOD.
        common_flag_mask (int): the integer bit mask (0-255) that should
            be used with the common flags in a bitwise AND.
        apply_flags (bool): whether to apply flags to the pixel numbers.
        purge (bool): if True, clear any cached data that is copied into
            the Madam buffers.
        purge_tod (bool): if True, clear any cached signal that is
            copied into the Madam buffers.
        purge_pixels (bool): if True, clear any cached pixels that are
            copied into the Madam buffers.
        purge_weights (bool): if True, clear any cached weights that are
            copied into the Madam buffers.
        purge_flags (bool): if True, clear any cached flags that are
            copied into the Madam buffers.
        dets (iterable):  List of detectors to map. If left as None, all
            available detectors are mapped.
        mcmode (bool): If true, the operator is constructed in
            Monte Carlo mode and Madam will cache auxiliary information
            such as pixel matrices and noise filter.
        noise (str): Keyword to use when retrieving the noise object
            from the observation.
        conserve_memory(bool): Stagger the Madam buffer staging on node.
        translate_timestamps(bool): Translate timestamps to enforce
            monotonity.
    """
    def __init__(self,
                 params=None,
                 detweights=None,
                 pixels='pixels',
                 pixels_nested=True,
                 weights='weights',
                 name=None,
                 name_out=None,
                 flag_name=None,
                 flag_mask=255,
                 common_flag_name=None,
                 common_flag_mask=255,
                 apply_flags=True,
                 purge=False,
                 dets=None,
                 mcmode=False,
                 purge_tod=False,
                 purge_pixels=False,
                 purge_weights=False,
                 purge_flags=False,
                 noise='noise',
                 intervals='intervals',
                 conserve_memory=True,
                 translate_timestamps=True):
        """Store the operator configuration (arguments are documented on
        the class).

        ``params`` is stored by reference, and this constructor sets its
        ``mcmode`` and ``write_tod`` entries.  When ``params`` is omitted, a
        new dictionary is created for this instance; a shared mutable
        default would carry those entries over to later instances.
        ``intervals`` is the observation key holding the intervals to map.
        """
        # We call the parent class constructor, which currently does nothing
        super().__init__()
        if params is None:
            params = {}
        # madam uses time-based distribution
        self._name = name
        self._name_out = name_out
        self._flag_name = flag_name
        self._flag_mask = flag_mask
        self._common_flag_name = common_flag_name
        self._common_flag_mask = common_flag_mask
        self._pixels = pixels
        self._pixels_nested = pixels_nested
        self._weights = weights
        self._detw = detweights
        self._purge = purge
        # purge=True implies every individual purge option.
        self._purge_tod = True if purge else purge_tod
        self._purge_pixels = True if purge else purge_pixels
        self._purge_weights = True if purge else purge_weights
        self._purge_flags = True if purge else purge_flags
        self._apply_flags = apply_flags
        self.params = params
        self._dets = None if dets is None else set(dets)
        self._mcmode = mcmode
        self.params['mcmode'] = bool(mcmode)
        # Madam only writes destriped TOD when an output name was given.
        self.params['write_tod'] = self._name_out is not None
        # True once Madam holds cached data (set by _destripe in mcmode).
        self._cached = False
        self._noisekey = noise
        self._intervals = intervals
        self._cache = Cache()
        self._madam_timestamps = None
        self._madam_pixels = None
        self._madam_pixweights = None
        self._madam_signal = None
        self._conserve_memory = conserve_memory
        self._translate_timestamps = translate_timestamps

    def __del__(self):
        """Release the staging buffers and any Madam-side caches.

        Attributes are read with ``getattr`` because ``__del__`` also runs
        for instances whose ``__init__`` raised before they were assigned.
        """
        cache = getattr(self, '_cache', None)
        if cache is not None:
            cache.clear()
        if getattr(self, '_cached', False):
            # Guard in case the module-level libmadam name is already None
            # (e.g. during interpreter shutdown).
            if libmadam is not None:
                libmadam.clear_caches()
            self._cached = False

    @property
    def available(self):
        """
        (bool): Whether the Madam library is usable, i.e. the module-level
        ``libmadam`` handle is set.
        """
        return libmadam is not None

    def _dict2parstring(self, d):
        """Serialize a parameter dictionary into Madam's parameter string.

        Every entry becomes ``"key = value;"``.  Keys listed in the
        module-level ``repeated_keys`` hold an iterable of values and are
        emitted once per value, in iteration order.
        """
        parts = []
        for key, value in d.items():
            values = value if key in repeated_keys else (value, )
            parts.extend('{} = {};'.format(key, v) for v in values)
        # join instead of repeated += (quadratic string building)
        return ''.join(parts)

    def _dets2detstring(self, dets):
        """Join detector names into Madam's ``"det1;det2;...;"`` format.

        Every name, including the last one, is followed by ``;``.
        """
        return ''.join('{};'.format(d) for d in dets)

    def exec(self, data, comm=None):
        """
        Copy data to Madam-compatible buffers and make a map.

        The work proceeds in four steps: inspect the data (``_prepare``),
        copy it into Madam buffers (``_stage_data``), run Madam
        (``_destripe``), then release the buffers and restore the TOAST
        caches (``_unstage_data``).

        Args:
            data (toast.Data): The distributed data.
            comm: communicator to run Madam on.  Defaults to
                ``data.comm.comm_world``.

        Raises:
            RuntimeError: if libmadam is unavailable or ``data`` holds no
                observations.
        """
        if libmadam is None:
            raise RuntimeError("Cannot find libmadam")
        if not data.obs:
            raise RuntimeError(
                'OpMadam requires every supplied data object to '
                'contain at least one observation')

        # Kept alive until the method returns (presumably times this call).
        timer = timing.auto_timer(type(self).__name__)

        # Fall back to the world communicator of the data.
        if comm is None:
            comm = data.comm.comm_world

        prepared = self._prepare(data, comm)
        (parstring, detstring, nsamp, ndet, nnz, nnz_full, nnz_stride,
         periods, obs_period_ranges, psdfreqs, detectors, nside) = prepared

        staged = self._stage_data(data, comm, nsamp, ndet, nnz, nnz_full,
                                  nnz_stride, obs_period_ranges, psdfreqs,
                                  detectors, nside)
        psdinfo, pixels_dtype, weight_dtype = staged

        self._destripe(comm, parstring, ndet, detstring, nsamp, nnz,
                       periods, psdinfo)

        self._unstage_data(comm, data, nsamp, nnz, nnz_full,
                           obs_period_ranges, detectors, pixels_dtype,
                           nside, weight_dtype)

    def _destripe(self, comm, parstring, ndet, detstring, nsamp, nnz, periods,
                  psdinfo):
        """Run the Madam destriper on the staged buffers.

        When Madam already holds cached auxiliary data from an earlier
        Monte Carlo call (``self._cached``), ``destripe_with_cache`` is used
        and ``parstring``, ``detstring``, ``periods`` and ``psdinfo`` are
        not needed.  Otherwise the full ``destripe`` call is made; in Monte
        Carlo mode the instance is then marked as cached.
        """
        timer = timing.auto_timer(type(self).__name__)
        fcomm = comm.py2f()

        if not self._cached:
            (detweights, npsd, npsdtot, psdstarts, npsdbin, psdfreqs,
             npsdval, psdvals) = psdinfo
            libmadam.destripe(
                fcomm, parstring.encode(), ndet, detstring.encode(),
                detweights, nsamp, nnz, self._madam_timestamps,
                self._madam_pixels, self._madam_pixweights,
                self._madam_signal, len(periods), periods, npsd, npsdtot,
                psdstarts, npsdbin, psdfreqs, npsdval, psdvals)
            if self._mcmode:
                self._cached = True
            return

        # Reuse what Madam cached in Monte Carlo mode (pixel matrices,
        # noise filter) and only pass the output path.
        if 'path_output' in self.params:
            outpath = self.params['path_output']
        else:
            outpath = ''
        libmadam.destripe_with_cache(
            fcomm, ndet, nsamp, nnz, self._madam_timestamps,
            self._madam_pixels, self._madam_pixweights, self._madam_signal,
            outpath.encode('ascii'))

    def _count_samples(self, data):
        """Return the number of local samples summed over all observations.

        Raises:
            RuntimeError: if the observations do not all carry the same
                local detectors (same names in the same order).  This is
                only checked when there is more than one observation.
        """
        if len(data.obs) == 1:
            return data.obs[0]['tod'].local_samples[1]

        reference_dets = data.obs[0]['tod'].local_dets
        total = 0
        for obs in data.obs:
            tod = obs['tod']
            # For the moment, every observation must have the same
            # detectors as the first one.
            local_dets = tod.local_dets
            mismatch = (
                len(reference_dets) != len(local_dets)
                or any(a != b for a, b in zip(reference_dets, local_dets)))
            if mismatch:
                raise RuntimeError(
                    'When calling Madam, all TOD assigned to a process '
                    'must have the same local detectors.')
            total += tod.local_samples[1]
        return total

    def _get_period_ranges(self, comm, data, detectors, nsamp):
        """Collect the valid sample ranges of every observation.

        Intervals shorter than ``basis_order + 1`` samples (1 when
        ``basis_order`` is not set) are discarded.  The method also checks
        that all noise PSDs share one frequency binning.

        Args:
            comm: communicator used to sum sample counts over processes.
            data (toast.Data): the distributed data.
            detectors (list): local detectors whose PSD binning is checked.
            nsamp (int): local sample count before interval selection.

        Returns:
            tuple: ``(obs_period_ranges, psdfreqs, periods, nsamp)``.
            ``obs_period_ranges`` holds, per observation, a list of local
            ``(start, stop)`` ranges.  ``psdfreqs`` is the common PSD
            frequency grid (None when no noise object was found).
            ``periods`` holds the start offset of every kept range in the
            concatenated buffer.  ``nsamp`` is the local number of samples
            in kept ranges.

        Raises:
            RuntimeError: if the PSD binnings differ, or no samples fall in
                valid intervals on any process.
        """
        # Discard intervals that are too short to fit a baseline
        if 'basis_order' in self.params:
            norder = int(self.params['basis_order']) + 1
        else:
            norder = 1

        psdfreqs = None
        period_lengths = []
        obs_period_ranges = []

        for obs in data.obs:
            tod = obs['tod']
            # Check that all noise objects have the same binning
            if self._noisekey in obs.keys():
                nse = obs[self._noisekey]
                if nse is not None:
                    if psdfreqs is None:
                        psdfreqs = nse.freq(detectors[0]).astype(
                            np.float64).copy()
                    for det in detectors:
                        check_psdfreqs = np.asarray(nse.freq(det))
                        # Compare shapes first: np.allclose broadcasts, so a
                        # different number of bins would raise ValueError
                        # (or, for a single bin, be compared against every
                        # frequency) instead of reporting the mismatch.
                        if (check_psdfreqs.shape != psdfreqs.shape
                                or not np.allclose(psdfreqs, check_psdfreqs)):
                            raise RuntimeError(
                                'All PSDs passed to Madam must have'
                                ' the same frequency binning.')
            # Collect the valid intervals for this observation
            period_ranges = []
            if self._intervals in obs:
                intervals = obs[self._intervals]
            else:
                intervals = None
            local_intervals = tod.local_intervals(intervals)

            for ival in local_intervals:
                local_start = ival.first
                local_stop = ival.last + 1
                if local_stop - local_start < norder:
                    continue
                period_lengths.append(local_stop - local_start)
                period_ranges.append((local_start, local_stop))
            obs_period_ranges.append(period_ranges)

        # Update the number of samples based on the valid intervals

        nsamp_tot_full = comm.allreduce(nsamp, op=MPI.SUM)
        nperiod = len(period_lengths)
        period_lengths = np.array(period_lengths, dtype=np.int64)
        nsamp = np.sum(period_lengths, dtype=np.int64)
        nsamp_tot = comm.allreduce(nsamp, op=MPI.SUM)
        if nsamp_tot == 0:
            raise RuntimeError(
                'No samples in valid intervals: nsamp_tot_full = {}, '
                'nsamp_tot = {}'.format(nsamp_tot_full, nsamp_tot))
        if comm.rank == 0:
            print('OpMadam: {:.2f} % of samples are included in valid '
                  'intervals.'.format(nsamp_tot * 100. / nsamp_tot_full))

        # Madam expects starting indices, not period lengths: the first
        # period starts at 0, each later one where the previous ended.
        periods = np.zeros(nperiod, dtype=np.int64)
        periods[1:] = np.cumsum(period_lengths[:-1])

        return obs_period_ranges, psdfreqs, periods, nsamp

    def _prepare(self, data, comm):
        """ Examine the data object.

        Determines the local detectors, the number of pointing-weight
        non-zeros, the map resolution and the valid sample ranges, and
        renders the Madam parameter and detector strings.  Rank 0 also
        creates ``path_output`` when it is set in the parameters.

        Returns:
            tuple: ``(parstring, detstring, nsamp, ndet, nnz, nnz_full,
            nnz_stride, periods, obs_period_ranges, psdfreqs, detectors,
            nside)``.

        Raises:
            RuntimeError: if no detectors are selected on this process, a
                temperature-only map is requested with an unsupported
                number of weights, or ``nside_map`` is missing from params.
        """
        auto_timer = timing.auto_timer(type(self).__name__)

        nsamp = self._count_samples(data)

        # Determine the detectors and the pointing matrix non-zeros
        # from the first observation. Madam will expect these to remain
        # unchanged across observations.

        tod = data.obs[0]['tod']

        if self._dets is None:
            detectors = tod.local_dets
        else:
            detectors = [det for det in tod.local_dets if det in self._dets]
        ndet = len(detectors)
        if ndet == 0:
            # The weights lookup below needs at least one detector; report
            # that clearly instead of failing with an IndexError.
            raise RuntimeError(
                'OpMadam: no detectors to map on process {}'.format(
                    comm.rank))
        detstring = self._dets2detstring(detectors)

        # to get the number of Non-zero pointing weights per pixel,
        # we use the fact that for Madam, all processes have all detectors
        # for some slice of time.  So we can get this information from the
        # shape of the data from the first detector

        nnzname = "{}_{}".format(self._weights, detectors[0])
        nnz_full = tod.cache.reference(nnzname).shape[1]

        if 'temperature_only' in self.params \
           and self.params['temperature_only'] in [
               'T', 'True', 'TRUE', 'true', True]:
            if nnz_full not in [1, 3]:
                raise RuntimeError(
                    'OpMadam: Don\'t know how to make a temperature map '
                    'with nnz={}'.format(nnz_full))
            # Keep only the first weight of every sample.
            nnz = 1
            nnz_stride = nnz_full
        else:
            nnz = nnz_full
            nnz_stride = 1

        if 'nside_map' not in self.params:
            raise RuntimeError(
                'OpMadam: "nside_map" must be set in the parameter dictionary')
        nside = int(self.params['nside_map'])

        parstring = self._dict2parstring(self.params)

        if comm.rank == 0 and 'path_output' in self.params:
            # exist_ok replaces the racy isdir()-then-create check; a
            # non-directory at that path still raises FileExistsError.
            os.makedirs(self.params['path_output'], exist_ok=True)

        # Inspect the valid intervals across all observations to
        # determine the number of samples per detector

        obs_period_ranges, psdfreqs, periods, nsamp = self._get_period_ranges(
            comm, data, detectors, nsamp)

        return (parstring, detstring, nsamp, ndet, nnz, nnz_full, nnz_stride,
                periods, obs_period_ranges, psdfreqs, detectors, nside)

    def _stage_time(self, data, detectors, nsamp, obs_period_ranges):
        """Fill the Madam timestamp buffer and gather the PSD history.

        Timestamps of every valid ``(start, stop)`` range are written
        back-to-back into a buffer of ``nsamp`` entries.  For each detector,
        a list of ``(start_time, psd)`` pairs is built.  The first PSD gets
        start time 0, and a new pair is appended only when an observation's
        noise-scaled PSD differs from the last one recorded.

        Returns:
            dict: detector name -> list of ``(start_time, psd)`` pairs
            (empty when no noise object was found).
        """
        timer = timing.auto_timer(type(self).__name__)
        stamp_buffer = self._cache.create('timestamps', np.float64, (nsamp, ))
        self._madam_timestamps = stamp_buffer

        history = {}
        cursor = 0
        next_start = 0
        for iobs, obs in enumerate(data.obs):
            tod = obs['tod']

            # Work on a copy so the TOD's own timestamps are untouched.
            stamps = tod.local_times().copy()
            if self._translate_timestamps:
                # Make the time stamps monotonous across observations: shift
                # this observation to begin 1 time unit after the previous
                # observation's last (shifted) sample.
                stamps -= stamps[0] - next_start
                next_start = stamps[-1] + 1

            for first, stop in obs_period_ranges[iobs]:
                count = stop - first
                stamp_buffer[cursor:cursor + count] = stamps[first:stop]
                cursor += count

            # Record a new PSD entry only when the PSD actually changes.
            if self._noisekey not in obs.keys():
                continue
            nse = obs[self._noisekey]
            scale = obs['noise_scale'] if 'noise_scale' in obs else 1
            if nse is None:
                continue
            for det in detectors:
                psd = nse.psd(det) * scale**2
                if det not in history:
                    history[det] = [(0, psd)]
                elif not np.allclose(history[det][-1][1], psd):
                    history[det].append((stamps[0], psd))

        return history

    def _stage_signal(self, data, detectors, nsamp, ndet, obs_period_ranges):
        """Copy the detector signal of the valid ranges into the Madam buffer.

        The buffer holds ``nsamp * ndet`` values laid out detector by
        detector: detector ``idet`` starts at ``idet * nsamp``.  It is
        initialized to NaN.  After an observation is copied, its cached
        signal is cleared when ``purge_tod`` is set or when the output name
        equals the input name.
        """
        auto_timer = timing.auto_timer(type(self).__name__)
        self._madam_signal = self._cache.create('signal', np.float64,
                                                (nsamp * ndet, ))
        self._madam_signal[:] = np.nan

        purge_signal = self._name is not None and (
            self._purge_tod or self._name == self._name_out)

        global_offset = 0
        for iobs, obs in enumerate(data.obs):
            tod = obs['tod']
            period_ranges = obs_period_ranges[iobs]

            for idet, det in enumerate(detectors):
                # Get the signal.
                signal = tod.local_signal(det, self._name)
                offset = idet * nsamp + global_offset
                for istart, istop in period_ranges:
                    nn = istop - istart
                    self._madam_signal[offset:offset + nn] = \
                        signal[istart:istop]
                    offset += nn
                del signal

            # Clear only after every detector of this observation is copied.
            # NOTE(review): cache.clear() takes a *pattern*; if it matches by
            # prefix, "<name>_d1" may also clear "<name>_d10" -- confirm.
            if purge_signal:
                for det in detectors:
                    cachename = "{}_{}".format(self._name, det)
                    tod.cache.clear(pattern=cachename)

            # Advance by this observation's number of valid samples.  This
            # does not rely on the per-detector loop variable, which would be
            # unbound when there are no detectors.
            global_offset += sum(istop - istart
                                 for istart, istop in period_ranges)

        return

    def _stage_pixels(self, data, detectors, nsamp, ndet, obs_period_ranges,
                      nside):
        """Copy the pixel numbers of the valid ranges into the Madam buffer.

        The buffer uses the same detector-major layout as the signal and is
        initialized to -1.  Pixels are converted from RING to NESTED
        ordering when ``pixels_nested`` is False.  Flagged samples are set
        to -1 when ``apply_flags`` is True.  The cached pixels are then
        always cleared; cached signal and flags are cleared according to
        the purge options.

        Returns:
            numpy.dtype: dtype of the cached pixel arrays.
        """
        auto_timer = timing.auto_timer(type(self).__name__)
        self._madam_pixels = self._cache.create('pixels', np.int64,
                                                (nsamp * ndet, ))
        self._madam_pixels[:] = -1

        # A private copy is needed whenever the pixel values are modified.
        need_copy = self._apply_flags or not self._pixels_nested

        global_offset = 0
        for iobs, obs in enumerate(data.obs):
            tod = obs['tod']
            period_ranges = obs_period_ranges[iobs]

            commonflags = None
            common_bad = None
            if self._apply_flags:
                # The common flags do not depend on the detector, so fetch
                # and mask them once per observation.
                commonflags = tod.local_common_flags(self._common_flag_name)
                common_bad = (commonflags & self._common_flag_mask) != 0

            for idet, det in enumerate(detectors):
                # get the pixels for the valid intervals from the cache
                pixelsname = "{}_{}".format(self._pixels, det)
                pixels = tod.cache.reference(pixelsname)
                pixels_dtype = pixels.dtype

                if need_copy:
                    # One copy is enough; the cached array stays untouched.
                    pixels = pixels.copy()

                if not self._pixels_nested:
                    # Madam expects the pixels to be in nested ordering
                    good = pixels >= 0
                    pixels[good] = hp.ring2nest(nside, pixels[good])

                if self._apply_flags:
                    # Otherwise the flags are assumed to have been applied
                    # to the pixel numbers already.
                    detflags = tod.local_flags(det, self._flag_name)
                    flags = np.logical_or(
                        (detflags & self._flag_mask) != 0, common_bad)
                    del detflags
                    pixels[flags] = -1
                    del flags

                offset = idet * nsamp + global_offset
                for istart, istop in period_ranges:
                    nn = istop - istart
                    self._madam_pixels[offset:offset + nn] = \
                        pixels[istart:istop]
                    offset += nn

                del pixels

            # Always purge the pixels but restore them from the Madam
            # buffers when purge_pixels=False
            # NOTE(review): cache.clear() takes a *pattern*; if it matches by
            # prefix, "<pixels>_d1" may also clear "<pixels>_d10" -- confirm.
            for det in detectors:
                pixelsname = "{}_{}".format(self._pixels, det)
                tod.cache.clear(pattern=pixelsname)
                # Same signal purge as in _stage_signal (repeated here).
                if self._name is not None and (self._purge_tod or self._name
                                               == self._name_out):
                    cachename = "{}_{}".format(self._name, det)
                    tod.cache.clear(pattern=cachename)
                if self._purge_flags and self._flag_name is not None:
                    cacheflagname = "{}_{}".format(self._flag_name, det)
                    tod.cache.clear(pattern=cacheflagname)

            del commonflags, common_bad
            if self._purge_flags and self._common_flag_name is not None:
                tod.cache.clear(pattern=self._common_flag_name)

            # Advance by this observation's number of valid samples.
            global_offset += sum(istop - istart
                                 for istart, istop in period_ranges)

        return pixels_dtype

    def _stage_pixweights(self, data, detectors, nsamp, ndet, nnz, nnz_full,
                          nnz_stride, obs_period_ranges):
        """Copy the pointing weights of the valid ranges into Madam.

        The buffer holds ``nsamp * ndet * nnz`` values with ``nnz``
        consecutive weights per sample, detector by detector.  From each
        cached ``(nsamples, nnz_full)`` weight array, every
        ``nnz_stride``-th flattened value is kept.

        Returns:
            numpy.dtype: dtype of the cached weight arrays.
        """
        auto_timer = timing.auto_timer(type(self).__name__)

        self._madam_pixweights = self._cache.create('pixweights', np.float64,
                                                    (nsamp * ndet * nnz, ))
        self._madam_pixweights[:] = 0

        # Cached weights are purged and later restored from the Madam
        # buffers when purge_weights=False.  That restoration is impossible
        # when Madam only stores a subset of the weights (nnz != nnz_full),
        # so in that case they are kept unless purging was requested.
        purge_cached = self._purge_weights or nnz == nnz_full

        global_offset = 0
        for iobs, obs in enumerate(data.obs):
            tod = obs['tod']
            period_ranges = obs_period_ranges[iobs]
            for idet, det in enumerate(detectors):
                # get the weights for the valid intervals from the cache
                weightsname = "{}_{}".format(self._weights, det)
                weights = tod.cache.reference(weightsname)
                weight_dtype = weights.dtype
                offset = idet * nsamp + global_offset
                for istart, istop in period_ranges:
                    nn = istop - istart
                    dwslice = slice(offset * nnz, (offset + nn) * nnz)
                    self._madam_pixweights[dwslice] = weights[
                        istart:istop].flatten()[::nnz_stride]
                    offset += nn
                del weights

            if purge_cached:
                for det in detectors:
                    weightsname = "{}_{}".format(self._weights, det)
                    tod.cache.clear(pattern=weightsname)

            # Advance by this observation's number of valid samples.
            global_offset += sum(istop - istart
                                 for istart, istop in period_ranges)

        return weight_dtype

    def _stage_data(self, data, comm, nsamp, ndet, nnz, nnz_full, nnz_stride,
                    obs_period_ranges, psdfreqs, detectors, nside):
        """ create madam-compatible buffers

        Collect the TOD into Madam buffers.  Pixel weights are processed
        separately from the rest to reduce the memory high water mark when
        the user has set purge=True.

        Moving data between toast and Madam buffers has an overhead.  With
        ``conserve_memory`` the operation is staggered: the processes of a
        node take turns, separated by barriers, so only one of them stages
        data at a time.

        Returns:
            tuple: ``(psdinfo, pixels_dtype, weight_dtype)``, where
            ``psdinfo`` is ``(detweights, npsd, npsdtot, psdstarts, npsdbin,
            psdfreqs, npsdval, psdvals)`` as passed to libmadam.destripe.
        """
        auto_timer = timing.auto_timer(type(self).__name__)

        if self._conserve_memory:
            nodecomm = comm.Split_type(MPI.COMM_TYPE_SHARED, comm.rank)
            nread = nodecomm.size
        else:
            nodecomm = MPI.COMM_SELF
            nread = 1

        for iread in range(nread):
            nodecomm.Barrier()
            if nodecomm.rank % nread != iread:
                continue
            psds = self._stage_time(data, detectors, nsamp, obs_period_ranges)
            self._stage_signal(data, detectors, nsamp, ndet, obs_period_ranges)
            pixels_dtype = self._stage_pixels(data, detectors, nsamp, ndet,
                                              obs_period_ranges, nside)
            weight_dtype = self._stage_pixweights(data, detectors, nsamp, ndet,
                                                  nnz, nnz_full, nnz_stride,
                                                  obs_period_ranges)

        # A communicator created by Split_type must be released explicitly;
        # dropping the Python reference does not free it.  COMM_SELF is
        # predefined and must not be freed.  (Free is collective, so it is
        # deliberately not placed in a `finally` that a failing rank alone
        # would reach.)
        if self._conserve_memory:
            nodecomm.Free()
        del nodecomm

        # detweights is either a dictionary of weights specified at
        # construction time, or else we use uniform weighting.
        if self._detw is None:
            detw = {det: 1.0 for det in detectors}
        else:
            detw = self._detw

        detweights = np.array([detw[det] for det in detectors],
                              dtype=np.float64)

        if len(psds) > 0:
            npsdbin = len(psdfreqs)

            npsd = np.zeros(ndet, dtype=np.int64)
            psdstarts = []
            psdvals = []
            for idet, det in enumerate(detectors):
                if det not in psds:
                    raise RuntimeError('Every detector must have at least '
                                       'one PSD')
                psdlist = psds[det]
                npsd[idet] = len(psdlist)
                for psdstart, psd in psdlist:
                    psdstarts.append(psdstart)
                    psdvals.append(psd)
            npsdtot = np.sum(npsd)
            psdstarts = np.array(psdstarts, dtype=np.float64)
            psdvals = np.hstack(psdvals).astype(np.float64)
            npsdval = psdvals.size
        else:
            # No noise model: pass one flat placeholder PSD per detector.
            npsd = np.ones(ndet, dtype=np.int64)
            npsdtot = np.sum(npsd)
            psdstarts = np.zeros(npsdtot)
            npsdbin = 10
            fsample = 10.
            psdfreqs = np.arange(npsdbin) * fsample / npsdbin
            npsdval = npsdbin * npsdtot
            psdvals = np.ones(npsdval)
        psdinfo = (detweights, npsd, npsdtot, psdstarts, npsdbin, psdfreqs,
                   npsdval, psdvals)

        return psdinfo, pixels_dtype, weight_dtype

    def _unstage_data(self, comm, data, nsamp, nnz, nnz_full,
                      obs_period_ranges, detectors, pixels_dtype, nside,
                      weight_dtype):
        """ Clear Madam buffers, restore pointing into TOAST caches
        and cache the destriped signal.

        The Madam buffers are laid out detector by detector: detector
        ``idet`` owns the samples starting at ``idet * nsamp`` (scaled by
        ``nnz`` for the pixel weights), and within that span the staged
        periods of all observations follow one another.

        Args:
            comm: MPI communicator.  When ``self._conserve_memory`` is set it
                is split per shared-memory node so that the processes of a
                node unstage one at a time.
            data: TOAST data whose ``obs`` list is traversed in step with
                ``obs_period_ranges``.
            nsamp (int): number of staged samples per detector.
            nnz (int): number of weights per sample in the Madam buffer.
            nnz_full (int): number of weights per sample expected in the
                TOAST cache; weights are only restored when it equals
                ``nnz``.
            obs_period_ranges: for each observation, the ``(istart, istop)``
                local sample ranges that were staged.
            detectors: detector names in buffer order.
            pixels_dtype: dtype of the restored pixel arrays.
            nside (int): HEALPix resolution of the staged pixel numbers.
            weight_dtype: dtype of the restored weight arrays.
        """
        auto_timer = timing.auto_timer(type(self).__name__)
        self._madam_timestamps = None
        self._cache.destroy('timestamps')

        if self._conserve_memory:
            # Processes sharing a node take turns (one per iteration of the
            # loop below) to limit the peak memory use of the node.
            nodecomm = comm.Split_type(MPI.COMM_TYPE_SHARED, comm.rank)
            nread = nodecomm.size
        else:
            nodecomm = MPI.COMM_SELF
            nread = 1

        npix = 12 * nside**2

        for iread in range(nread):
            nodecomm.Barrier()
            if nodecomm.rank % nread != iread:
                continue
            if self._name_out is not None:
                # Cache the destriped signal; samples outside the staged
                # periods are NaN.
                global_offset = 0
                for obs, period_ranges in zip(data.obs, obs_period_ranges):
                    tod = obs['tod']
                    nlocal = tod.local_samples[1]
                    for idet, det in enumerate(detectors):
                        signal = np.full(nlocal, np.nan)
                        offset = global_offset
                        for istart, istop in period_ranges:
                            nn = istop - istart
                            dslice = slice(idet * nsamp + offset,
                                           idet * nsamp + offset + nn)
                            signal[istart:istop] = self._madam_signal[dslice]
                            offset += nn
                        cachename = "{}_{}".format(self._name_out, det)
                        tod.cache.put(cachename, signal, replace=True)
                    # Advance by this observation's staged length rather than
                    # by the last detector's ``offset``, which is undefined
                    # when ``detectors`` is empty.
                    global_offset += sum(istop - istart
                                         for istart, istop in period_ranges)
            self._madam_signal = None
            self._cache.destroy('signal')

            if not self._purge_pixels:
                # Restore the pixels from the Madam buffers.  Samples that
                # were not staged, or lie outside [0, npix), become -1.
                global_offset = 0
                for obs, period_ranges in zip(data.obs, obs_period_ranges):
                    tod = obs['tod']
                    nlocal = tod.local_samples[1]
                    for idet, det in enumerate(detectors):
                        pixels = -np.ones(nlocal, dtype=pixels_dtype)
                        offset = global_offset
                        for istart, istop in period_ranges:
                            nn = istop - istart
                            dslice = slice(idet * nsamp + offset,
                                           idet * nsamp + offset + nn)
                            pixels[istart:istop] = self._madam_pixels[dslice]
                            offset += nn
                        good = np.logical_and(pixels >= 0, pixels < npix)
                        if not self._pixels_nested:
                            # The Madam buffer holds NESTED pixel numbers.
                            pixels[good] = hp.nest2ring(nside, pixels[good])
                        pixels[np.logical_not(good)] = -1
                        cachename = "{}_{}".format(self._pixels, det)
                        tod.cache.put(cachename, pixels, replace=True)
                    global_offset += sum(istop - istart
                                         for istart, istop in period_ranges)
            self._madam_pixels = None
            self._cache.destroy('pixels')

            if not self._purge_weights and nnz == nnz_full:
                # Restore the weights from the Madam buffers; samples that
                # were not staged get zero weights.
                global_offset = 0
                for obs, period_ranges in zip(data.obs, obs_period_ranges):
                    tod = obs['tod']
                    nlocal = tod.local_samples[1]
                    for idet, det in enumerate(detectors):
                        weights = np.zeros([nlocal, nnz], dtype=weight_dtype)
                        offset = global_offset
                        for istart, istop in period_ranges:
                            nn = istop - istart
                            dwslice = slice((idet * nsamp + offset) * nnz,
                                            (idet * nsamp + offset + nn) * nnz)
                            weights[istart:istop] = self._madam_pixweights[
                                dwslice].reshape([-1, nnz])
                            offset += nn
                        cachename = "{}_{}".format(self._weights, det)
                        tod.cache.put(cachename, weights, replace=True)
                    global_offset += sum(istop - istart
                                         for istart, istop in period_ranges)
            self._madam_pixweights = None
            self._cache.destroy('pixweights')

        # A communicator created with Split_type must be released explicitly
        # with Free() (a collective call, reached by every node rank after
        # the loop).  The predefined MPI.COMM_SELF must not be freed.
        if self._conserve_memory:
            nodecomm.Free()
        del nodecomm
        return
Exemple #4
0
    def __init__(self,
                 params=None,
                 detweights=None,
                 pixels='pixels',
                 pixels_nested=True,
                 weights='weights',
                 name=None,
                 name_out=None,
                 flag_name=None,
                 flag_mask=255,
                 common_flag_name=None,
                 common_flag_mask=255,
                 apply_flags=True,
                 purge=False,
                 dets=None,
                 mcmode=False,
                 purge_tod=False,
                 purge_pixels=False,
                 purge_weights=False,
                 purge_flags=False,
                 noise='noise',
                 intervals='intervals',
                 conserve_memory=True,
                 translate_timestamps=True):
        """Configure the Madam operator.

        Args:
            params (dict, optional): Madam parameters.  A copy is stored, so
                neither the caller's dictionary nor a shared default is
                modified.  The ``'mcmode'`` and ``'write_tod'`` entries are
                always overwritten from ``mcmode`` and ``name_out``.
            detweights (dict, optional): weight per detector name; when None
                every detector gets weight 1.0.
            pixels (str): cache prefix of the pixel numbers.
            pixels_nested (bool): True when the cached pixel numbers use
                NESTED ordering.
            weights (str): cache prefix of the pointing weights.
            name (str, optional): name of the input signal (presumably a
                cache prefix -- confirm against the staging code).
            name_out (str, optional): cache prefix for the destriped signal;
                when set, ``params['write_tod']`` is True.
            flag_name, flag_mask, common_flag_name, common_flag_mask,
            apply_flags: flag settings, stored for use by other methods.
            purge (bool): when True, every ``purge_*`` option is enabled.
            dets (iterable, optional): detector names, stored as a set.
            mcmode (bool): forwarded as ``params['mcmode']``.
            purge_tod, purge_pixels, purge_weights, purge_flags (bool):
                individual purge options, ignored when ``purge`` is True.
            noise (str): stored as ``self._noisekey``.
            intervals (str): stored as ``self._intervals``.
            conserve_memory (bool): stage/unstage one process per node at a
                time.
            translate_timestamps (bool): stored for use by other methods.
        """
        # We call the parent class constructor, which currently does nothing
        super().__init__()
        # madam uses time-based distribution
        self._name = name
        self._name_out = name_out
        self._flag_name = flag_name
        self._flag_mask = flag_mask
        self._common_flag_name = common_flag_name
        self._common_flag_mask = common_flag_mask
        self._pixels = pixels
        self._pixels_nested = pixels_nested
        self._weights = weights
        self._detw = detweights
        self._purge = purge
        if self._purge:
            self._purge_tod = True
            self._purge_pixels = True
            self._purge_weights = True
            self._purge_flags = True
        else:
            self._purge_tod = purge_tod
            self._purge_pixels = purge_pixels
            self._purge_weights = purge_weights
            self._purge_flags = purge_flags
        self._apply_flags = apply_flags
        # Copy so that the keys set below never leak into the caller's
        # dictionary or into a default shared between instances.
        self.params = {} if params is None else dict(params)
        if dets is not None:
            self._dets = set(dets)
        else:
            self._dets = None
        self._mcmode = mcmode
        self.params['mcmode'] = bool(mcmode)
        self.params['write_tod'] = self._name_out is not None
        self._cached = False
        self._noisekey = noise
        self._intervals = intervals
        self._cache = Cache()
        self._madam_timestamps = None
        self._madam_pixels = None
        self._madam_pixweights = None
        self._madam_signal = None
        self._conserve_memory = conserve_memory
        self._translate_timestamps = translate_timestamps
    def __init__(self,
                 nside,
                 file_sync,
                 file_sync_pol,
                 file_freefree,
                 file_ame,
                 file_dust,
                 file_dust_pol,
                 comm,
                 fwhm=0,
                 verbose=False,
                 groupsize=256,
                 quickpolbeam=None):
        """Load the foreground templates and share them within a group.

        The input communicator is split into groups of ``groupsize``
        consecutive ranks.  Within a group, the six components are assigned
        to reader ranks spaced 16 apart (modulo the group size), loaded, and
        then distributed by the ``broadcast_*`` methods.

        Args:
            nside (int): HEALPix resolution of the returned maps.
            file_sync, file_sync_pol, file_freefree, file_ame, file_dust,
            file_dust_pol (str): paths of the component templates.
            comm: MPI communicator to split.
            fwhm (float): minimum smoothing of every returned component, in
                arcmin.
            verbose (bool): print progress messages.
            groupsize (int): number of ranks per sub-communicator.
            quickpolbeam (str, optional): beam window file read with
                ``hp.read_cl`` in place of a Gaussian beam.
        """
        # Not every process has work, so operate within smaller groups.
        self.comm = comm.Split(color=(comm.rank // groupsize), key=comm.rank)
        self.global_rank = comm.rank
        self.rank = self.comm.rank
        self.ntask = self.comm.size

        self.cache = Cache()
        self.verbose = verbose

        # Reader rank of each component: 0, 16, 32, ... wrapped to the group.
        id_step = 16
        (self.id_sync, self.id_sync_pol, self.id_ff, self.id_ame,
         self.id_dust, self.id_dust_pol) = [
             (icomp * id_step) % self.ntask for icomp in range(6)]

        # libsharp processing grid
        self.nside = nside
        self.npix = 12 * self.nside**2
        self.dist_rings = {}
        self.my_pix = self.get_my_pix(self.nside)
        self.my_npix = len(self.my_pix)

        # Smoothing floor applied to every returned component
        self.fwhm = fwhm
        self.quickpolbeam = quickpolbeam

        if self.verbose and self.global_rank == 0:
            banner = ('Initializing SkyModel. nside = {}, fwhm = {}, '
                      'quickpolbeam = {}'.format(self.nside, self.fwhm,
                                                 self.quickpolbeam))
            print(banner, flush=True)

        self.file_sync = file_sync
        self.file_sync_pol = file_sync_pol
        self.file_freefree = file_freefree
        self.file_ame = file_ame
        self.file_dust = file_dust
        self.file_dust_pol = file_dust_pol

        # Read the templates; each loader only reads on its reader rank.
        for loader in (self.load_synchrotron_temperature,
                       self.load_synchrotron_polarization,
                       self.load_freefree,
                       self.load_ame,
                       self.load_dust_temperature,
                       self.load_dust_polarization):
            loader()

        # Share the inputs across the group.
        self.comm.Barrier()
        for broadcaster in (self.broadcast_synchrotron,
                            self.broadcast_freefree,
                            self.broadcast_ame,
                            self.broadcast_dust_temperature,
                            self.broadcast_dust_polarization):
            broadcaster()
class SkyModel():
    def __init__(self,
                 nside,
                 file_sync,
                 file_sync_pol,
                 file_freefree,
                 file_ame,
                 file_dust,
                 file_dust_pol,
                 comm,
                 fwhm=0,
                 verbose=False,
                 groupsize=256,
                 quickpolbeam=None):
        """Load the foreground templates and share them within a group.

        The input communicator is split into groups of ``groupsize``
        consecutive ranks.  Within a group, each of the six components is
        read by its own rank (``self.id_*``, spaced 16 apart modulo the group
        size), and then distributed by the ``broadcast_*`` methods.

        Args:
            nside (int): HEALPix resolution of the returned maps.
            file_sync, file_sync_pol, file_freefree, file_ame, file_dust,
            file_dust_pol (str): paths of the component templates.
            comm: MPI communicator to split.
            fwhm (float): minimum smoothing of every returned component, in
                arcmin.
            verbose (bool): print progress messages.
            groupsize (int): number of ranks per sub-communicator.
            quickpolbeam (str, optional): beam window file read with
                ``hp.read_cl`` in place of a Gaussian beam.
        """
        # Not every process has work, so operate within smaller groups.
        self.comm = comm.Split(color=(comm.rank // groupsize), key=comm.rank)
        self.global_rank = comm.rank
        self.rank = self.comm.rank
        self.ntask = self.comm.size

        self.cache = Cache()
        self.verbose = verbose

        # Reader rank of each component: 0, 16, 32, ... wrapped to the group.
        id_step = 16
        (self.id_sync, self.id_sync_pol, self.id_ff, self.id_ame,
         self.id_dust, self.id_dust_pol) = [
             (icomp * id_step) % self.ntask for icomp in range(6)]

        # libsharp processing grid
        self.nside = nside
        self.npix = 12 * self.nside**2
        self.dist_rings = {}
        self.my_pix = self.get_my_pix(self.nside)
        self.my_npix = len(self.my_pix)

        # Smoothing floor applied to every returned component
        self.fwhm = fwhm
        self.quickpolbeam = quickpolbeam

        if self.verbose and self.global_rank == 0:
            banner = ('Initializing SkyModel. nside = {}, fwhm = {}, '
                      'quickpolbeam = {}'.format(self.nside, self.fwhm,
                                                 self.quickpolbeam))
            print(banner, flush=True)

        self.file_sync = file_sync
        self.file_sync_pol = file_sync_pol
        self.file_freefree = file_freefree
        self.file_ame = file_ame
        self.file_dust = file_dust
        self.file_dust_pol = file_dust_pol

        # Read the templates; each loader only reads on its reader rank.
        for loader in (self.load_synchrotron_temperature,
                       self.load_synchrotron_polarization,
                       self.load_freefree,
                       self.load_ame,
                       self.load_dust_temperature,
                       self.load_dust_polarization):
            loader()

        # Share the inputs across the group.
        self.comm.Barrier()
        for broadcaster in (self.broadcast_synchrotron,
                            self.broadcast_freefree,
                            self.broadcast_ame,
                            self.broadcast_dust_temperature,
                            self.broadcast_dust_polarization):
            broadcaster()

    def __del__(self):

        del self.my_pix
        del self.dist_rings
        self.comm.Free()
        del self.sync_As
        del self.sync_beta
        del self.sync_As_Q
        del self.sync_As_U
        del self.sync_pol_beta
        del self.ff_em
        del self.ff_amp
        del self.ff_T_e
        del self.ame_1
        del self.ame_2
        del self.ame_nu_p_1
        del self.dust_Ad
        del self.dust_Ad_Q
        del self.dust_Ad_U
        del self.dust_temp
        del self.dust_beta
        del self.dust_temp_pol
        del self.dust_beta_pol

        self.cache.clear()

    def optimal_lmax(self, fwhm_in, nside_in):
        lmax = 2 * min(nside_in, self.nside)
        if fwhm_in < self.fwhm and self.quickpolbeam is None:
            beam = hp.gauss_beam(self.fwhm * arcmin, lmax=lmax, pol=False)
            better_lmax = np.argmin(np.abs(beam - 1e-4)) + 1
            if better_lmax < lmax:
                lmax = better_lmax
        return lmax

    def total_beam(self, fwhm_in, lmax, pol=False):
        total_beam = None
        if fwhm_in < .99 * self.fwhm:
            if self.quickpolbeam is None:
                total_beam = hp.gauss_beam(self.fwhm * arcmin,
                                           lmax=lmax,
                                           pol=True)
                total_beam = total_beam[:, 0:3].copy()
            else:
                total_beam = np.array(hp.read_cl(self.quickpolbeam))
                if total_beam.ndim == 1:
                    total_beam = np.vstack(
                        [total_beam, total_beam, total_beam])
                total_beam = total_beam[:, :lmax + 1].T.copy()
            beam_in = hp.gauss_beam(fwhm_in * arcmin, lmax=lmax, pol=True)
            beam_in = beam_in[:, 0:3].copy()
            good = beam_in != 0
            total_beam[good] /= beam_in[good]
            if pol:
                total_beam = np.ascontiguousarray(total_beam[:, (1, 2)])
            else:
                total_beam = np.ascontiguousarray(total_beam[:, 0:1])
        return total_beam

    def smooth(self, fwhm_in, nside_in, maps_in):
        """ Smooth a distributed map and change the resolution.
        """
        if fwhm_in > .9 * self.fwhm and nside_in == self.nside:
            return maps_in
        if fwhm_in > .9 * self.fwhm and self.nside < nside_in:
            # Simple ud_grade
            if self.global_rank == 0 and self.verbose:
                print('Downgrading Nside {} -> {}'
                      ''.format(nside_in, self.nside),
                      flush=True)
            maps_out = []
            npix_in = hp.nside2npix(nside_in)
            my_pix_in = self.get_my_pix(nside_in)
            for m in maps_in:
                my_outmap = np.zeros(npix_in, dtype=np.float)
                outmap = np.zeros(npix_in, dtype=np.float)
                my_outmap[my_pix_in] = m
                self.comm.Allreduce(my_outmap, outmap)
                del my_outmap
                maps_out.append(hp.ud_grade(outmap, self.nside)[self.my_pix])
        else:
            # Full smoothing
            lmax = self.optimal_lmax(fwhm_in, nside_in)
            total_beam = self.total_beam(fwhm_in, lmax, pol=False)
            if self.global_rank == 0 and self.verbose:
                print('Smoothing {} -> {}. lmax = {}. Nside {} -> {}'
                      ''.format(fwhm_in, self.fwhm, lmax, nside_in,
                                self.nside),
                      flush=True)
            local_m = np.arange(self.rank,
                                lmax + 1,
                                self.ntask,
                                dtype=np.int32)
            alminfo = packed_real_order(lmax, ms=local_m)
            grid_in = self.get_grid(nside_in)
            grid_out = self.get_grid(self.nside)
            maps_out = []
            for local_map in maps_in:
                map_I = np.ascontiguousarray(local_map.reshape([1, 1, -1]),
                                             dtype=np.float64)
                alm_I = analysis(grid_in,
                                 alminfo,
                                 map_I,
                                 spin=0,
                                 comm=self.comm)
                if total_beam is not None:
                    alminfo.almxfl(alm_I, total_beam)
                map_I = synthesis(grid_out,
                                  alminfo,
                                  alm_I,
                                  spin=0,
                                  comm=self.comm)[0][0]
                maps_out.append(map_I)
        return maps_out

    def smooth_pol(self, fwhm_in, nside_in, maps_in):
        if fwhm_in > .9 * self.fwhm and nside_in == self.nside:
            return maps_in
        if fwhm_in > .9 * self.fwhm and self.nside < nside_in:
            # Simple ud_grade
            if self.global_rank == 0 and self.verbose:
                print('Downgrading Nside {} -> {}'
                      ''.format(nside_in, self.nside),
                      flush=True)
            maps_out = []
            npix_in = hp.nside2npix(nside_in)
            my_pix_in = self.get_my_pix(nside_in)
            for (qmap, umap) in maps_in:
                my_mapout = np.zeros(npix_in, dtype=np.float)
                qmapout = np.zeros(npix_in, dtype=np.float)
                umapout = np.zeros(npix_in, dtype=np.float)
                my_mapout[my_pix_in] = qmap
                self.comm.Allreduce(my_mapout, qmapout)
                my_mapout[my_pix_in] = umap
                self.comm.Allreduce(my_mapout, umapout)
                del my_mapout
                maps_out.append((hp.ud_grade(qmapout, self.nside)[self.my_pix],
                                 hp.ud_grade(umapout,
                                             self.nside)[self.my_pix]))
        else:
            # Full smoothing
            lmax = self.optimal_lmax(fwhm_in, nside_in)
            total_beam = self.total_beam(fwhm_in, lmax, pol=True)
            if self.global_rank == 0 and self.verbose:
                print('Smoothing {} -> {}. lmax = {}. Nside {} -> {}'
                      ''.format(fwhm_in, self.fwhm, lmax, nside_in,
                                self.nside),
                      flush=True)
            local_m = np.arange(self.rank,
                                lmax + 1,
                                self.ntask,
                                dtype=np.int32)
            alminfo = packed_real_order(lmax, ms=local_m)
            grid_in = self.get_grid(nside_in)
            grid_out = self.get_grid(self.nside)
            maps_out = []
            for (local_map_Q, local_map_U) in maps_in:
                map_P = np.ascontiguousarray(np.vstack(
                    [local_map_Q, local_map_U]).reshape((1, 2, -1)),
                                             dtype=np.float64)
                alm_P = analysis(grid_in,
                                 alminfo,
                                 map_P,
                                 spin=2,
                                 comm=self.comm)
                if total_beam is not None:
                    alminfo.almxfl(alm_P, total_beam)
                map_P = synthesis(grid_out,
                                  alminfo,
                                  alm_P,
                                  spin=2,
                                  comm=self.comm)[0]
                maps_out.append(map_P)
        return maps_out

    def load_synchrotron_temperature(self):
        # # Synchrotron temperature
        if self.rank == self.id_sync:
            try:
                # Try old format first
                with pf.open(self.file_sync, 'readonly') as h:
                    self.sync_psd_freq = h[2].data.field(0)
                    self.sync_psd = h[2].data.field(1)
                    self.sync_nu_ref = np.float(
                        h[1].header['nu_ref'].split()[0]) * 1e-3  # To GHz
                    self.sync_fwhm = h[1].header['fwhm']
                self.sync_As = hp.read_map(self.file_sync,
                                           verbose=False,
                                           dtype=DTYPE,
                                           memmap=True)
                self.sync_nside = hp.get_nside(self.sync_As)
                self.sync_beta = None
                if self.verbose:
                    print('Loaded synchrotron T: nside = {}, fwhm = {}'.format(
                        self.sync_nside, self.sync_fwhm),
                          flush=True)
            except Exception as e:
                if self.verbose:
                    print('Old synchrotron T format failed ("{}"). Trying new '
                          'format'.format(e),
                          flush=True)
                self.sync_fwhm = 60.
                self.sync_nu_ref = 0.408
                self.sync_As = hp.read_map(self.file_sync,
                                           verbose=False,
                                           dtype=DTYPE,
                                           memmap=True)
                self.sync_beta = hp.read_map(self.file_sync.replace(
                    'synch_', 'synch_beta_'),
                                             verbose=False,
                                             dtype=DTYPE,
                                             memmap=True)
                self.sync_nside = hp.get_nside(self.sync_As)
                self.sync_psd_freq = None
                self.sync_psd = None
                if self.verbose:
                    print('Loaded synchrotron T: nside = {}, fwhm = {}'.format(
                        self.sync_nside, self.sync_fwhm),
                          flush=True)
        else:
            self.sync_As = None
            self.sync_beta = None
            self.sync_psd_freq = None
            self.sync_psd = None
            self.sync_nu_ref = None
            self.sync_fwhm = None
            self.sync_nside = None
        return

    def load_synchrotron_polarization(self):
        # # Synchrotron polarization
        if self.rank == self.id_sync_pol:
            try:
                # Try old format first
                with pf.open(self.file_sync_pol) as h:
                    self.sync_pol_nu_ref = np.float(
                        h[1].header['nu_ref'].split()[0])  # In GHz
                    self.sync_pol_fwhm = h[1].header['fwhm']
                self.sync_As_Q, self.sync_As_U = hp.read_map(
                    self.file_sync_pol, [0, 1],
                    verbose=False,
                    dtype=DTYPE,
                    memmap=True)
                self.sync_pol_nside = hp.get_nside(self.sync_As_Q)
                self.sync_pol_psd_freq, self.sync_pol_psd = np.genfromtxt(
                    os.path.join(DATADIR, 'synchrotron_psd_2015.dat')).T
                self.sync_pol_beta = None
                if self.verbose:
                    print('Loaded synchrotron P: nside = {}, fwhm = {}'.format(
                        self.sync_pol_nside, self.sync_pol_fwhm),
                          flush=True)
            except Exception as e:
                if self.verbose:
                    print('Old synchrotron T format failed ("{}"). Trying new '
                          'format'.format(e),
                          flush=True)
                self.sync_pol_fwhm = 40.
                self.sync_pol_nu_ref = 30.
                self.sync_As_Q, self.sync_As_U = hp.read_map(
                    self.file_sync_pol, [1, 2],
                    verbose=False,
                    dtype=DTYPE,
                    memmap=True)
                self.sync_pol_beta = hp.read_map(self.file_sync.replace(
                    'synch_', 'synch_beta_'),
                                                 verbose=False,
                                                 dtype=DTYPE,
                                                 memmap=True)
                self.sync_pol_nside = hp.get_nside(self.sync_As_Q)
                self.sync_pol_psd_freq = None
                self.sync_pol_psd = None
                if self.verbose:
                    print('Loaded synchrotron P: nside = {}, fwhm = {}'.format(
                        self.sync_pol_nside, self.sync_pol_fwhm),
                          flush=True)
        else:
            self.sync_As_Q = None
            self.sync_As_U = None
            self.sync_pol_beta = None
            self.sync_pol_psd_freq = None
            self.sync_pol_psd = None
            self.sync_pol_nu_ref = None
            self.sync_pol_fwhm = None
            self.sync_pol_nside = None
        return

    def load_freefree(self):
        # # free-free
        if self.rank == self.id_ff:
            try:
                # Try old format first
                with pf.open(self.file_freefree) as h:
                    self.ff_fwhm = h[1].header['fwhm']
                self.ff_em, self.ff_T_e = hp.read_map(self.file_freefree,
                                                      [0, 3],
                                                      verbose=False,
                                                      dtype=DTYPE,
                                                      memmap=True)
                self.ff_nside = hp.get_nside(self.ff_em)
                self.ff_nu_ref = None
                self.ff_amp = None
                if self.verbose:
                    print('Loaded freefree: nside = {}, fwhm = {}'.format(
                        self.ff_nside, self.ff_fwhm),
                          flush=True)
            except Exception as e:
                if self.verbose:
                    print('Old freefree format failed ("{}"). Trying new '
                          'format'.format(e),
                          flush=True)
                self.ff_fwhm = 20.
                self.ff_nu_ref = 1.4
                self.ff_amp = hp.read_map(self.file_freefree,
                                          verbose=False,
                                          dtype=DTYPE,
                                          memmap=True)
                # self.ff_em = hp.read_map(
                #    self.file_freefree.replace('ff_', 'ff_EM'), verbose=False,
                #    dtype=DTYPE, memmap=True)
                self.ff_T_e = hp.read_map(self.file_freefree.replace(
                    'ff_', 'ff_Te_'),
                                          verbose=False,
                                          dtype=DTYPE,
                                          memmap=True)
                self.ff_nside = hp.get_nside(self.ff_amp)
                self.ff_em = None
                if self.verbose:
                    print('Loaded freefree: nside = {}, fwhm = {}'.format(
                        self.ff_nside, self.ff_fwhm),
                          flush=True)
        else:
            self.ff_amp = None
            self.ff_em = None
            self.ff_T_e = None
            self.ff_fwhm = None
            self.ff_nside = None
            self.ff_nu_ref = None
        return

    def load_ame(self):
        # # spinning dust
        if self.rank == self.id_ame:
            try:
                # Try old format first
                self.ame_nu_p0 = 30.
                with pf.open(self.file_ame) as h:
                    self.ame_fwhm = h[1].header['fwhm']
                    # All frequencies are in GHz
                    self.ame_nu_ref1 = np.float(
                        h[1].header['nu_ref'].split()[0])
                    self.ame_nu_ref2 = np.float(
                        h[2].header['nu_ref'].split()[0])
                    self.ame_nu_p_2 = np.float(h[2].header['nu_p'].split()[0])
                    self.ame_psd_freq = h[3].data.field(0)
                    self.ame_psd = h[3].data.field(1)
                self.ame_1, self.ame_nu_p_1 = hp.read_map(self.file_ame,
                                                          [0, 3],
                                                          verbose=False,
                                                          dtype=DTYPE,
                                                          memmap=True)
                self.ame_2 = hp.read_map(self.file_ame,
                                         hdu=2,
                                         verbose=False,
                                         dtype=DTYPE,
                                         memmap=True)
                self.ame_nside = hp.get_nside(self.ame_1)
                if self.verbose:
                    print('Loaded AME: nside = {}, fwhm = {}'.format(
                        self.ame_nside, self.ame_fwhm),
                          flush=True)
            except Exception as e:
                if self.verbose:
                    print('Old AME format failed ("{}"). Trying new '
                          'format'.format(e),
                          flush=True)
                self.ame_nu_p0 = 22.2  # GHz
                self.ame_fwhm = 30.  # GHz
                self.ame_nu_ref1 = 30.  # arc min
                self.ame_1 = hp.read_map(self.file_ame,
                                         verbose=False,
                                         dtype=DTYPE,
                                         memmap=True)
                self.ame_nu_p_1 = hp.read_map(self.file_ame.replace(
                    'ame_', 'ame_nu_p_'),
                                              verbose=False,
                                              dtype=DTYPE,
                                              memmap=True)
                self.ame_nu_ref2 = None
                self.ame_2 = None
                self.ame_nu_p_2 = None
                self.ame_psd_freq, self.ame_psd = np.genfromtxt(
                    os.path.join(DATADIR, 'spdust2_cnm.dat')).T
                self.ame_nside = hp.get_nside(self.ame_1)
                if self.verbose:
                    print('Loaded AME: nside = {}, fwhm = {}'.format(
                        self.ame_nside, self.ame_fwhm),
                          flush=True)
        else:
            self.ame_1 = None
            self.ame_2 = None
            self.ame_nu_p0 = None
            self.ame_nu_p_1 = None
            self.ame_nu_p_2 = None
            self.ame_nu_ref1 = None
            self.ame_nu_ref2 = None
            self.ame_fwhm = None
            self.ame_nside = None
            self.ame_psd_freq = None
            self.ame_psd = None
        return

    def load_dust_temperature(self):
        """Read the thermal dust intensity model on rank ``self.id_dust``.

        Two input layouts are supported:

        * old format: a single file ``self.file_dust`` whose fields 0, 3 and 6
          hold the amplitude, temperature and spectral index, with ``fwhm``
          and ``nu_ref`` (in GHz) taken from the header of HDU 1;
        * new format: separate files for amplitude, temperature and spectral
          index (``dust`` replaced by ``dust_T`` / ``dust_beta`` in the file
          name) with fixed ``fwhm = 5.`` and ``nu_ref = 857``.

        Any failure while reading the old format triggers the new-format
        fallback. Ranks other than ``self.id_dust`` set every attribute to
        None; ``broadcast_dust_temperature`` shares them from ``self.id_dust``.
        """
        # # Thermal dust temperature
        if self.rank == self.id_dust:
            try:
                # Try old format first
                with pf.open(self.file_dust) as hdul:
                    self.dust_fwhm = hdul[1].header['fwhm']
                    # Builtin float: the np.float alias was removed in NumPy
                    # 1.24, and the resulting AttributeError would silently
                    # send every old-format file to the fallback below.
                    self.dust_nu_ref = float(
                        hdul[1].header['nu_ref'].split()[0])  # in GHz
                self.dust_Ad, self.dust_temp, self.dust_beta = hp.read_map(
                    self.file_dust, [0, 3, 6],
                    verbose=False,
                    dtype=DTYPE,
                    memmap=True)
            except Exception as e:
                # Deliberate best-effort: fall back to the new layout
                if self.verbose:
                    print('Old dust T format failed ("{}"). Trying new '
                          'format'.format(e),
                          flush=True)
                self.dust_fwhm = 5.
                self.dust_nu_ref = 857
                self.dust_Ad = hp.read_map(self.file_dust,
                                           verbose=False,
                                           dtype=DTYPE,
                                           memmap=True)
                self.dust_temp = hp.read_map(self.file_dust.replace(
                    'dust', 'dust_T'),
                                             verbose=False,
                                             dtype=DTYPE,
                                             memmap=True)
                self.dust_beta = hp.read_map(self.file_dust.replace(
                    'dust', 'dust_beta'),
                                             verbose=False,
                                             dtype=DTYPE,
                                             memmap=True)
            self.dust_nside = hp.get_nside(self.dust_Ad)
            if self.verbose:
                print('Loaded dust T: nside = {}, fwhm = {}'.format(
                    self.dust_nside, self.dust_fwhm),
                      flush=True)
        else:
            self.dust_Ad = None
            self.dust_temp = None
            self.dust_beta = None
            self.dust_fwhm = None
            self.dust_nside = None
            self.dust_nu_ref = None
        return

    def load_dust_polarization(self):
        """Read the thermal dust polarization model on rank ``self.id_dust_pol``.

        Old format: Q and U amplitudes are fields 0 and 1 of
        ``self.file_dust_pol``, with ``fwhm`` and ``nu_ref`` (in GHz) from the
        header of HDU 1. Temperature and spectral index are left as None.

        New format: Q and U are fields 1 and 2. Field 1 of the matching
        ``dust_T`` and ``dust_beta`` files supplies the temperature and
        spectral index. ``fwhm = 5`` and ``nu_ref = 353`` are fixed.
        Temperatures below 12 are clamped to 12, with a warning.

        Any failure while reading the old format triggers the new-format
        fallback. Ranks other than ``self.id_dust_pol`` set every attribute
        to None.
        """
        # # Thermal dust polarization
        if self.rank == self.id_dust_pol:
            try:
                # Try old format first
                with pf.open(self.file_dust_pol) as hdul:
                    self.dust_pol_fwhm = hdul[1].header['fwhm']
                    # Builtin float: np.float no longer exists in NumPy >= 1.24
                    # and would silently force the new-format fallback.
                    self.dust_pol_nu_ref = float(
                        hdul[1].header['nu_ref'].split()[0])  # in GHz
                self.dust_temp_pol = None
                self.dust_beta_pol = None
                self.dust_Ad_Q, self.dust_Ad_U = hp.read_map(
                    self.file_dust_pol,
                    range(2),
                    verbose=False,
                    dtype=DTYPE,
                    memmap=True)
            except Exception as e:
                # Deliberate best-effort: fall back to the new layout
                if self.verbose:
                    print('Old dust P format failed ("{}"). Trying new '
                          'format'.format(e),
                          flush=True)
                self.dust_pol_fwhm = 5
                self.dust_pol_nu_ref = 353
                self.dust_Ad_Q, self.dust_Ad_U = hp.read_map(
                    self.file_dust_pol, [1, 2],
                    verbose=False,
                    dtype=DTYPE,
                    memmap=True)
                fname = self.file_dust_pol.replace('dust', 'dust_T')
                self.dust_temp_pol = hp.read_map(fname,
                                                 1,
                                                 verbose=False,
                                                 dtype=DTYPE,
                                                 memmap=True)
                # Clamp implausibly cold pixels to a floor temperature
                tlim = 12
                bad = self.dust_temp_pol < tlim
                nbad = np.sum(bad)
                if nbad > 0:
                    print('WARNING: regularizing {} cold pixels in {}'
                          ''.format(nbad, fname),
                          flush=True)
                    self.dust_temp_pol[bad] = tlim
                self.dust_beta_pol = hp.read_map(self.file_dust_pol.replace(
                    'dust', 'dust_beta'),
                                                 1,
                                                 verbose=False,
                                                 dtype=DTYPE,
                                                 memmap=True)
            self.dust_pol_nside = hp.get_nside(self.dust_Ad_Q)
            if self.verbose:
                print('Loaded dust P: nside = {}, fwhm = {}'.format(
                    self.dust_pol_nside, self.dust_pol_fwhm),
                      flush=True)
        else:
            self.dust_Ad_Q = None
            self.dust_Ad_U = None
            self.dust_temp_pol = None
            self.dust_beta_pol = None
            self.dust_pol_fwhm = None
            self.dust_pol_nside = None
            self.dust_pol_nu_ref = None
        return

    def get_my_pix(self, nside):
        """Return the pixel indices owned by this process at resolution ``nside``.

        The ring distribution for each ``nside`` is built on first request and
        memoized in ``self.dist_rings``.
        """
        try:
            rings = self.dist_rings[nside]
        except KeyError:
            rings = self.dist_rings[nside] = DistRings(self.comm,
                                                       nside=nside,
                                                       nnz=3)
        return rings.local_pixels

    def get_grid(self, nside):
        """Return the libsharp grid for this process's rings at ``nside``.

        Uses the same lazily populated ``self.dist_rings`` mapping as
        ``get_my_pix``.
        """
        try:
            rings = self.dist_rings[nside]
        except KeyError:
            rings = self.dist_rings[nside] = DistRings(self.comm,
                                                       nside=nside,
                                                       nnz=3)
        return rings.libsharp_grid

    def broadcast_synchrotron(self):
        """Share synchrotron inputs from their loading ranks.

        Metadata and spectral templates are broadcast in full. The pixel maps
        are broadcast, reduced to this process's pixels (``get_my_pix``) and
        stored in ``self.cache``. The spectral-index map for intensity is only
        broadcast when no template was loaded (``sync_psd is None``).
        """
        bcast = self.comm.bcast

        # Synchrotron temperature
        src = self.id_sync
        for attr in ('sync_psd_freq', 'sync_psd', 'sync_nu_ref', 'sync_fwhm',
                     'sync_nside'):
            setattr(self, attr, bcast(getattr(self, attr), root=src))
        local = self.get_my_pix(self.sync_nside)
        full_amp = bcast(self.sync_As, root=src)
        self.sync_As = self.cache.put('sync_As', full_amp[local])
        if self.sync_psd is None:
            # New format: per-pixel spectral index instead of a template
            full_beta = bcast(self.sync_beta, root=src)
            self.sync_beta = self.cache.put('sync_beta', full_beta[local])

        # Synchrotron polarization
        src = self.id_sync_pol
        for attr in ('sync_pol_psd_freq', 'sync_pol_psd', 'sync_pol_nu_ref',
                     'sync_pol_fwhm', 'sync_pol_nside'):
            setattr(self, attr, bcast(getattr(self, attr), root=src))
        local = self.get_my_pix(self.sync_pol_nside)
        self.sync_pol_beta = bcast(self.sync_pol_beta, root=src)
        if self.sync_pol_beta is not None:
            self.sync_pol_beta = self.cache.put('sync_pol_beta',
                                                self.sync_pol_beta[local])
        for attr in ('sync_As_Q', 'sync_As_U'):
            full_map = bcast(getattr(self, attr), root=src)
            setattr(self, attr, self.cache.put(attr, full_map[local]))
        return

    def broadcast_freefree(self):
        """Share free-free inputs from rank ``self.id_ff`` and keep local pixels.

        ``ff_amp`` and ``ff_em`` are optional (either may be None); when
        present they are reduced to this process's pixels and cached, as is
        the electron temperature ``ff_T_e``.
        """
        src = self.id_ff
        bcast = self.comm.bcast
        for attr in ('ff_nu_ref', 'ff_fwhm', 'ff_nside'):
            setattr(self, attr, bcast(getattr(self, attr), root=src))
        local = self.get_my_pix(self.ff_nside)
        for attr in ('ff_amp', 'ff_em'):
            full_map = bcast(getattr(self, attr), root=src)
            if full_map is None:
                setattr(self, attr, None)
            else:
                setattr(self, attr, self.cache.put(attr, full_map[local]))
        full_temp = bcast(self.ff_T_e, root=src)
        self.ff_T_e = self.cache.put('ff_T_e', full_temp[local])
        return

    def broadcast_ame(self):
        """Share AME inputs from rank ``self.id_ame`` and keep local pixels.

        Scalars and the tabulated spectrum are broadcast in full, including
        ``ame_nu_p_2``, which is not localized. The first-component amplitude,
        the optional second-component amplitude and the first-component
        peak-frequency map are reduced to this process's pixels and cached.
        """
        src = self.id_ame
        bcast = self.comm.bcast
        for attr in ('ame_nu_p0', 'ame_nu_p_2', 'ame_nu_ref1', 'ame_nu_ref2',
                     'ame_fwhm', 'ame_nside', 'ame_psd_freq', 'ame_psd'):
            setattr(self, attr, bcast(getattr(self, attr), root=src))
        local = self.get_my_pix(self.ame_nside)
        first = bcast(self.ame_1, root=src)
        self.ame_1 = self.cache.put('ame_1', first[local])
        second = bcast(self.ame_2, root=src)
        if second is None:
            self.ame_2 = None
        else:
            self.ame_2 = self.cache.put('ame_2', second[local])
        # NB: the peak-frequency map is cached under 'ame_freq_1'
        peak = bcast(self.ame_nu_p_1, root=src)
        self.ame_nu_p_1 = self.cache.put('ame_freq_1', peak[local])
        return

    def broadcast_dust_temperature(self):
        """Share dust intensity inputs from rank ``self.id_dust``.

        Only the amplitude ``dust_Ad`` is reduced to local pixels here.
        ``dust_temp`` and ``dust_beta`` stay full-sky because the polarization
        broadcast may need to regrade them when the dust inputs are in the old
        format and the polarization resolution differs from the temperature
        resolution.
        """
        src = self.id_dust
        bcast = self.comm.bcast
        for attr in ('dust_fwhm', 'dust_nside', 'dust_nu_ref'):
            setattr(self, attr, bcast(getattr(self, attr), root=src))
        local = self.get_my_pix(self.dust_nside)
        full_amp = bcast(self.dust_Ad, root=src)
        self.dust_Ad = self.cache.put('dust_Ad', full_amp[local])
        for attr in ('dust_temp', 'dust_beta'):
            setattr(self, attr, bcast(getattr(self, attr), root=src))
        return

    def broadcast_dust_polarization(self):
        """Share dust polarization inputs and localize every dust map.

        Broadcasts metadata and maps from rank ``self.id_dust_pol`` and keeps
        only this process's pixels (``get_my_pix``) in ``self.cache``.

        When the loader left ``dust_temp_pol`` / ``dust_beta_pol`` as None
        (old input format), they are derived by ``hp.ud_grade``-ing the
        full-sky ``self.dust_temp`` / ``self.dust_beta`` to
        ``dust_pol_nside``. ``broadcast_dust_temperature`` deliberately leaves
        those two maps full-sky, so this method also performs their final
        reduction to local pixels (at ``dust_nside``). Presumably this must
        run after ``broadcast_dust_temperature``; confirm against the caller.
        """
        root = self.id_dust_pol
        self.dust_pol_fwhm = self.comm.bcast(self.dust_pol_fwhm, root=root)
        self.dust_pol_nside = self.comm.bcast(self.dust_pol_nside, root=root)
        self.dust_pol_nu_ref = self.comm.bcast(self.dust_pol_nu_ref, root=root)
        self.dust_temp_pol = self.comm.bcast(self.dust_temp_pol, root=root)
        my_pix = self.get_my_pix(self.dust_pol_nside)
        # amplitude
        self.dust_Ad_Q = self.cache.put(
            'dust_Ad_Q',
            self.comm.bcast(self.dust_Ad_Q, root=root)[my_pix])
        self.dust_Ad_U = self.cache.put(
            'dust_Ad_U',
            self.comm.bcast(self.dust_Ad_U, root=root)[my_pix])
        # temperature
        if self.dust_temp_pol is None:
            # Old dust inputs: regrade the full-sky intensity temperature
            self.dust_temp_pol = hp.ud_grade(self.dust_temp,
                                             self.dust_pol_nside)
        self.dust_temp_pol = self.cache.put('dust_temp_pol',
                                            self.dust_temp_pol[my_pix])
        # beta
        self.dust_beta_pol = self.comm.bcast(self.dust_beta_pol, root=root)
        if self.dust_beta_pol is None:
            # Old dust inputs: regrade the full-sky intensity spectral index
            self.dust_beta_pol = hp.ud_grade(self.dust_beta,
                                             self.dust_pol_nside)
        self.dust_beta_pol = self.cache.put('dust_beta_pol',
                                            self.dust_beta_pol[my_pix])
        # Only now, can we extract the local portion of dust
        # temperature and beta.
        # (Pixel set recomputed at the intensity resolution, which may
        # differ from dust_pol_nside.)
        my_pix = self.get_my_pix(self.dust_nside)
        self.dust_temp = self.cache.put('dust_temp', self.dust_temp[my_pix])
        self.dust_beta = self.cache.put('dust_beta', self.dust_beta[my_pix])
        return

    def add_synchrotron(self, map_I, map_Q, map_U, freq, krj2kcmb):
        """Add synchrotron emission at ``freq`` (GHz) to the local I/Q/U maps.

        Intensity and polarization each use one of two spectral models:

        * old format (spectral-index map is None): a tabulated template,
          interpolated in log-log space, normalized at the reference
          frequency and multiplied by ``(nu_ref / freq) ** 2``;
        * new format: a per-pixel power law ``(freq / nu_ref) ** beta``.

        The scaled maps are multiplied by ``krj2kcmb``, cast to ``DTYPE``,
        smoothed to the component beam and added to the maps in place.
        """
        def spectral_scale(beta, nu_ref, psd_freq, psd):
            if beta is None:
                # Old format: interpolate the template in log-log space
                log_nu = np.log(psd_freq)
                log_psd = np.log(psd)
                psd_at_ref = np.exp(np.interp(np.log(nu_ref), log_nu, log_psd))
                psd_at_freq = np.exp(np.interp(np.log(freq), log_nu, log_psd))
                return (nu_ref / freq)**2 * psd_at_freq / psd_at_ref
            # New format: power law with a per-pixel spectral index
            return (freq / nu_ref) ** beta.astype(np.float64)

        # Intensity
        t_scale = spectral_scale(self.sync_beta, self.sync_nu_ref,
                                 self.sync_psd_freq, self.sync_psd)
        sync_T = (self.sync_As * t_scale * krj2kcmb).astype(DTYPE)
        if self.verbose and self.global_rank == 0:
            print('Smoothing synchrotron T', flush=True)
        map_I += self.smooth(self.sync_fwhm, self.sync_nside, [sync_T])[0]

        # Polarization
        p_scale = spectral_scale(self.sync_pol_beta, self.sync_pol_nu_ref,
                                 self.sync_pol_psd_freq, self.sync_pol_psd)
        sync_Q = (self.sync_As_Q * p_scale * krj2kcmb).astype(DTYPE)
        sync_U = (self.sync_As_U * p_scale * krj2kcmb).astype(DTYPE)
        if self.verbose and self.global_rank == 0:
            print('Smoothing synchrotron P', flush=True)
        sync_Q, sync_U = self.smooth_pol(self.sync_pol_fwhm,
                                         self.sync_pol_nside,
                                         [(sync_Q, sync_U)])[0]
        map_Q += sync_Q
        map_U += sync_U
        return

    def add_freefree(self, map_I, freq, krj2kcmb):
        """Add free-free emission at ``freq`` (GHz) to the local intensity map.

        Args:
            map_I: local intensity map, updated in place.
            freq: observing frequency in GHz.
            krj2kcmb: unit conversion factor computed by ``eval``.

        Two parameterizations are supported:

        * ``ff_amp is None`` (2015 model): optical depth from the emission
          measure ``ff_em`` and electron temperature ``ff_T_e``, giving
          ``1e6 * T_e * (1 - exp(-tau))``;
        * otherwise (2018 model): amplitude ``ff_amp`` at ``ff_nu_ref`` scaled
          by the Gaunt-factor ratio, a Boltzmann factor and ``nu ** -2``.

        The result is multiplied by ``krj2kcmb``, cast to ``DTYPE``, smoothed
        to ``ff_fwhm`` and accumulated into ``map_I``.
        """
        T_e = self.ff_T_e.astype(np.float64)

        def gaunt(nu):
            # Gaunt-factor approximation; nu in GHz, T_e scaled by 1e-4
            return np.log(
                np.exp(5.960 - np.sqrt(3) / np.pi *
                       np.log(nu * (T_e * 1e-4)**(-1.5))) + np.exp(1))

        if self.ff_amp is None:
            # 2015 model
            tau = (0.05468 * T_e**(-1.5) * freq**(-2) * self.ff_em *
                   gaunt(freq))
            ff = 1e6 * T_e * (1 - np.exp(-tau))
        else:
            # 2018 model. freq and ff_nu_ref are in GHz (both enter the
            # GHz-based Gaunt factor), so convert the difference to Hz for
            # h/k. eval() and add_dust() apply the same 1e9 factor; it was
            # previously missing here.
            boltzmann = np.exp(-h * (freq - self.ff_nu_ref) * 1e9 / k / T_e)
            ff = (self.ff_amp * gaunt(freq) / gaunt(self.ff_nu_ref) *
                  boltzmann * (freq / self.ff_nu_ref)**(-2))
        # K_RJ -> K_CMB
        ff = (ff * krj2kcmb).astype(DTYPE)
        if self.verbose and self.global_rank == 0:
            print('Smoothing freefree', flush=True)
        ff = self.smooth(self.ff_fwhm, self.ff_nside, [ff])[0]
        map_I += ff
        return

    def add_ame(self, map_I, freq, krj2kcmb):
        """Add anomalous microwave (spinning dust) emission at ``freq`` (GHz).

        Each component evaluates the tabulated spectrum (``ame_psd`` against
        ``ame_psd_freq``, interpolated in log-log space) at frequencies scaled
        by ``ame_nu_p0 / nu_peak``. It normalizes at the component's reference
        frequency and applies ``(nu_ref / freq) ** 2``. The second component
        is used only when ``ame_2`` is loaded. The result is multiplied by
        ``krj2kcmb``, smoothed to ``ame_fwhm`` and added to ``map_I`` in place.
        """
        log_template_nu = np.log(self.ame_psd_freq)
        log_template = np.log(self.ame_psd)

        def spinning_dust(amplitude, nu_ref, nu_peak):
            # Map observed frequencies onto the template's frequency axis
            shift = self.ame_nu_p0 / nu_peak
            at_freq = np.exp(np.interp(np.log(freq * shift),
                                       log_template_nu, log_template))
            at_ref = np.exp(np.interp(np.log(nu_ref * shift),
                                      log_template_nu, log_template))
            signal = amplitude * (nu_ref / freq) ** 2 * at_freq / at_ref
            return (signal * krj2kcmb).astype(DTYPE)

        first = spinning_dust(self.ame_1, self.ame_nu_ref1,
                              self.ame_nu_p_1.astype(np.float64))
        second = None
        if self.ame_2 is not None:
            second = spinning_dust(self.ame_2, self.ame_nu_ref2,
                                   self.ame_nu_p_2)
        if self.verbose and self.global_rank == 0:
            print('Smoothing AME', flush=True)
        if second is None:
            first = self.smooth(self.ame_fwhm, self.ame_nside, [first])[0]
            second = 0
        else:
            first, second = self.smooth(self.ame_fwhm, self.ame_nside,
                                        [first, second])
        map_I += first + second
        return

    def add_dust(self, map_I, map_Q, map_U, freq, krj2kcmb):
        """Add thermal dust emission at ``freq`` (GHz) to the local I/Q/U maps.

        Intensity and polarization are each scaled per pixel by
        ``(freq / nu_ref) ** (beta + 1) * (exp(h nu_ref / k T) - 1) /
        (exp(h freq / k T) - 1)``, with frequencies converted from GHz to Hz
        inside the exponentials. Each result is multiplied by ``krj2kcmb``,
        cast to ``DTYPE``, smoothed and added to the maps in place.
        """
        def greybody(temperature, beta, nu_ref):
            h_over_kt = h / k / temperature.astype(np.float64)
            return ((freq / nu_ref) ** (beta.astype(np.float64) + 1) *
                    (np.exp(h_over_kt * nu_ref * 1e9) - 1) /
                    (np.exp(h_over_kt * freq * 1e9) - 1))

        # Intensity
        t_scale = greybody(self.dust_temp, self.dust_beta, self.dust_nu_ref)
        dust_T = (self.dust_Ad * t_scale * krj2kcmb).astype(DTYPE)
        if self.verbose and self.global_rank == 0:
            print('Smoothing dust T', flush=True)
        map_I += self.smooth(self.dust_fwhm, self.dust_nside, [dust_T])[0]

        # Polarization
        p_scale = greybody(self.dust_temp_pol, self.dust_beta_pol,
                           self.dust_pol_nu_ref)
        dust_Q = (self.dust_Ad_Q * p_scale * krj2kcmb).astype(DTYPE)
        dust_U = (self.dust_Ad_U * p_scale * krj2kcmb).astype(DTYPE)
        if self.verbose and self.global_rank == 0:
            print('Smoothing dust P', flush=True)
        dust_Q, dust_U = self.smooth_pol(self.dust_pol_fwhm,
                                         self.dust_pol_nside,
                                         [(dust_Q, dust_U)])[0]
        map_Q += dust_Q
        map_U += dust_U
        return

    def eval(self, freq, synchrotron=True, freefree=True, ame=True, dust=True):
        """
        Evaluate the total sky model.

        Args:
            freq: observing frequency [GHz].
            synchrotron, freefree, ame, dust: enable/disable each component.

        Returns:
            A (3, npix) float64 array with the I, Q and U maps. Each process
            fills in its own pixels, and ``Allreduce`` assembles the pieces so
            every process receives the full map.
        """

        my_mtot = np.zeros(self.my_npix)
        my_mtot_Q = np.zeros(self.my_npix)
        my_mtot_U = np.zeros(self.my_npix)

        # uK_RJ -> K_CMB
        x = h * freq * 1e9 / k / TCMB
        # delta T_CMB / delta T_RJ
        g = (np.exp(x) - 1)**2 / (x**2 * np.exp(x))
        # uK_CMB -> K_CMB
        krj2kcmb = g * 1e-6

        if synchrotron:
            self.add_synchrotron(my_mtot, my_mtot_Q, my_mtot_U, freq, krj2kcmb)
        if freefree:
            self.add_freefree(my_mtot, freq, krj2kcmb)
        if ame:
            self.add_ame(my_mtot, freq, krj2kcmb)
        if dust:
            self.add_dust(my_mtot, my_mtot_Q, my_mtot_U, freq, krj2kcmb)

        # Gather the pieces, each process gets a copy of the full map

        my_pix = self.get_my_pix(self.nside)
        # np.float64 rather than the np.float alias removed in NumPy 1.24
        my_outmap = np.zeros([3, self.npix], dtype=np.float64)
        outmap = np.zeros([3, self.npix], dtype=np.float64)
        my_outmap[0, my_pix] = my_mtot
        my_outmap[1, my_pix] = my_mtot_Q
        my_outmap[2, my_pix] = my_mtot_U
        self.comm.Allreduce(my_outmap, outmap)
        del my_outmap

        return outmap
    def __init__(
        self,
        map_path,
        pol=False,
        pol_fwhm=None,
        no_temperature=False,
        dtype=None,
        plug_holes=True,
        verbose=False,
        nside=None,
        comm=None,
        cache=None,
        preloaded_map=None,
        buflen=1000000,
        nest=False,
        pscorrect=False,
        psradius=30,
        use_shmem=True,
    ):
        """
        Instantiate the map sampler object, load a healpix
        map in a file located at map_path

        if pol==True, reads I,Q,U maps from extensions 0, 1, 2

        Args:
            map_path: HEALPix file read on rank 0 (may be None when
                ``preloaded_map`` is given).
            pol: also load the Q and U maps.
            pol_fwhm: if set, ``self.smooth(pol_fwhm, pol_only=True)`` is
                called after loading.
            no_temperature: load only Q and U (requires ``pol=True``).
            dtype: deprecated and ignored; maps are stored as ``DTYPE``.
            plug_holes: fill holes with ``utilities.plug_holes``.
            verbose: passed through to healpy and the utilities.
            nside: if set, maps are ``hp.ud_grade``-d to this resolution.
            comm: optional MPI communicator; rank 0 does all the loading.
            cache: Cache to store the maps in (a new one if None).
            preloaded_map: in-memory map(s) used instead of reading a file.
            buflen: stored as ``self.buflen``.
            nest: maps use NESTED rather than RING ordering.
            pscorrect: run ``utilities.remove_bright_sources`` (takes
                precedence over ``plug_holes``).
            psradius: ``fwhm`` passed to ``remove_bright_sources``.
            use_shmem: with more than one process, share maps via
                ``MPIShared`` instead of broadcasting private copies.

        Raises:
            RuntimeError: on inconsistent arguments or a missing map.
        """

        if not pol and no_temperature:
            raise RuntimeError("You cannot have pol=False, " "no_temperature=True")

        self.path = map_path
        self.pol = pol
        self.pol_fwhm = pol_fwhm
        self.nest = nest
        if nest:
            self.order = "NESTED"
        else:
            self.order = "RING"
        self.pscorrect = pscorrect
        self.psradius = psradius
        self.buflen = buflen
        # Output data type, internal is always DTYPE
        if dtype is not None:
            warnings.warn("MapSampler no longer supports dtype", DeprecationWarning)

        # Use healpy to load the map into memory.

        if comm is None:
            self.comm = None
            self.rank = 0
            self.ntask = 1
        else:
            self.comm = comm
            self.rank = comm.Get_rank()
            self.ntask = comm.Get_size()

        self.shmem = self.ntask > 1 and use_shmem

        if self.rank == 0:
            if map_path is None and preloaded_map is None:
                raise RuntimeError("Either map_path or preloaded_map must be provided")
            if map_path is not None and preloaded_map is not None:
                if os.path.isfile(map_path):
                    raise RuntimeError(
                        "Both map_path and preloaded_map cannot be provided"
                    )
            if self.pol:
                if preloaded_map is not None:
                    # Copy only if the maps will be modified in place below
                    if pscorrect or plug_holes or self.pol_fwhm is not None:
                        copy = True
                    else:
                        copy = False
                    if no_temperature:
                        self.Map_Q = preloaded_map[0].astype(DTYPE, copy=copy)
                        self.Map_U = preloaded_map[1].astype(DTYPE, copy=copy)
                    else:
                        self.Map = preloaded_map[0].astype(DTYPE, copy=copy)
                        self.Map_Q = preloaded_map[1].astype(DTYPE, copy=copy)
                        self.Map_U = preloaded_map[2].astype(DTYPE, copy=copy)
                else:
                    if no_temperature:
                        self.Map_Q, self.Map_U = hp.read_map(
                            self.path,
                            field=[1, 2],
                            dtype=DTYPE,
                            verbose=verbose,
                            # Was misspelled "memmmap", which hp.read_map
                            # rejects with a TypeError.
                            memmap=True,
                            nest=self.nest,
                        )
                    else:
                        try:
                            self.Map, self.Map_Q, self.Map_U = hp.read_map(
                                self.path,
                                field=[0, 1, 2],
                                dtype=DTYPE,
                                verbose=verbose,
                                memmap=True,
                                nest=self.nest,
                            )
                        except IndexError:
                            print(
                                "WARNING: {} is not polarized".format(self.path),
                                flush=True,
                            )
                            self.pol = False
                            self.Map = hp.read_map(
                                self.path,
                                dtype=DTYPE,
                                verbose=verbose,
                                memmap=True,
                                nest=self.nest,
                            )

                if nside is not None:
                    if not no_temperature:
                        self.Map = hp.ud_grade(
                            self.Map,
                            nside,
                            dtype=DTYPE,
                            order_in=self.order,
                            order_out=self.order,
                        )
                    if self.pol:
                        self.Map_Q = hp.ud_grade(
                            self.Map_Q,
                            nside,
                            dtype=DTYPE,
                            order_in=self.order,
                            order_out=self.order,
                        )
                        self.Map_U = hp.ud_grade(
                            self.Map_U,
                            nside,
                            dtype=DTYPE,
                            order_in=self.order,
                            order_out=self.order,
                        )

                if self.pscorrect:
                    if not no_temperature:
                        utilities.remove_bright_sources(
                            self.Map, nest=self.nest, fwhm=self.psradius
                        )
                    if self.pol:
                        utilities.remove_bright_sources(
                            self.Map_Q, nest=self.nest, fwhm=self.psradius
                        )
                        utilities.remove_bright_sources(
                            self.Map_U, nest=self.nest, fwhm=self.psradius
                        )
                elif plug_holes or self.pol_fwhm is not None:
                    if not no_temperature:
                        utilities.plug_holes(self.Map, verbose=verbose, nest=self.nest)
                    if self.pol:
                        utilities.plug_holes(
                            self.Map_Q, verbose=verbose, nest=self.nest
                        )
                        utilities.plug_holes(
                            self.Map_U, verbose=verbose, nest=self.nest
                        )
            else:
                if preloaded_map is not None:
                    self.Map = np.array(preloaded_map, dtype=DTYPE)
                else:
                    self.Map = hp.read_map(
                        map_path,
                        field=[0],
                        dtype=DTYPE,
                        verbose=verbose,
                        memmap=True,
                        nest=self.nest,
                    )
                if nside is not None:
                    self.Map = hp.ud_grade(
                        self.Map,
                        nside,
                        dtype=DTYPE,
                        order_in=self.order,
                        order_out=self.order,
                    )
                if self.pscorrect:
                    utilities.remove_bright_sources(
                        self.Map, nest=self.nest, fwhm=self.psradius
                    )
                elif plug_holes:
                    utilities.plug_holes(self.Map, verbose=verbose, nest=self.nest)

        if self.ntask > 1:
            # Rank 0 may have discovered the file is unpolarized
            self.pol = comm.bcast(self.pol, root=0)
            npix = 0
            if self.rank == 0:
                if self.pol:
                    npix = len(self.Map_Q)
                else:
                    npix = len(self.Map)
            npix = comm.bcast(npix, root=0)
            if self.shmem:
                if not no_temperature:
                    # Allocate the shared intensity buffer only when it is used
                    shared = MPIShared((npix,), np.dtype(DTYPE), comm)
                    if self.rank == 0 and self.Map is None:
                        raise RuntimeError("Cannot set shared map from None")
                    shared.set(self.Map, (0,), fromrank=0)
                    self.Map = shared
                if self.pol:
                    if self.rank == 0 and self.Map_Q is None:
                        raise RuntimeError("Cannot set shared map from None")
                    shared_Q = MPIShared((npix,), np.dtype(DTYPE), comm)
                    shared_Q.set(self.Map_Q, (0,), fromrank=0)
                    self.Map_Q = shared_Q
                    shared_U = MPIShared((npix,), np.dtype(DTYPE), comm)
                    shared_U.set(self.Map_U, (0,), fromrank=0)
                    self.Map_U = shared_U
            else:
                if self.rank != 0:
                    if not no_temperature:
                        self.Map = np.zeros(npix, dtype=DTYPE)
                    if self.pol:
                        self.Map_Q = np.zeros(npix, dtype=DTYPE)
                        self.Map_U = np.zeros(npix, dtype=DTYPE)

                if not no_temperature:
                    comm.Bcast(self.Map, root=0)
                if self.pol:
                    comm.Bcast(self.Map_Q, root=0)
                    comm.Bcast(self.Map_U, root=0)

        if self.pol:
            self.npix = len(self.Map_Q[:])
        else:
            self.npix = len(self.Map[:])
        self.nside = hp.npix2nside(self.npix)

        if cache is None:
            self.cache = Cache()
        else:
            self.cache = cache
        self.instance = 0
        if not self.shmem:
            # Increase the instance counter until we find an unused
            # instance.  If the user did not want to store duplicates,
            # they would not have created two identical mapsampler
            # objects.
            while self.cache.exists(self._cachename("I")):
                self.instance += 1
            if not no_temperature:
                self.Map = self.cache.put(self._cachename("I"), self.Map)
            if self.pol:
                self.Map_Q = self.cache.put(self._cachename("Q"), self.Map_Q)
                self.Map_U = self.cache.put(self._cachename("U"), self.Map_U)

        if self.pol_fwhm is not None:
            self.smooth(self.pol_fwhm, pol_only=True)
        return
class MapSampler:
    """
    Interpolate a HEALPix map (intensity and optionally Q/U) at arbitrary
    pointing.

    Rank 0 of ``comm`` loads the map(s), either from ``map_path`` or from
    ``preloaded_map``, and optionally ud_grades them and runs the
    point-source / hole-plugging utilities on them.  The result is then made
    available on every process: in node-shared ``MPIShared`` buffers when
    ``use_shmem`` is set and there is more than one task, or otherwise as
    per-process copies stored in a ``Cache``.
    """

    Map = None
    Map_Q = None
    Map_U = None

    def __init__(
        self,
        map_path,
        pol=False,
        pol_fwhm=None,
        no_temperature=False,
        dtype=None,
        plug_holes=True,
        verbose=False,
        nside=None,
        comm=None,
        cache=None,
        preloaded_map=None,
        buflen=1000000,
        nest=False,
        pscorrect=False,
        psradius=30,
        use_shmem=True,
    ):
        """
        Instantiate the map sampler object, load a healpix
        map in a file located at map_path

        if pol==True, reads I,Q,U maps from extensions 0, 1, 2

        Args:
            map_path (str): HEALPix FITS file to read on rank 0.
            pol (bool): also load and sample the Q and U maps.
            pol_fwhm (float): if set, smooth the Q/U maps to this FWHM
                (arcmin) once they are loaded.
            no_temperature (bool): do not load the intensity map.  Requires
                ``pol=True``.
            dtype: deprecated and ignored (a DeprecationWarning is issued).
            plug_holes (bool): run ``utilities.plug_holes`` on the maps.
            verbose (bool): forwarded to healpy and the utilities.
            nside (int): if set, ud_grade the maps to this resolution.
            comm (mpi4py.MPI.Comm): communicator to share the maps over.
            cache (Cache): cache for non-shared maps.  Created if None.
            preloaded_map (array): map(s) to use instead of reading map_path.
            buflen (int): number of samples interpolated per chunk.
            nest (bool): the maps use NESTED pixel ordering.
            pscorrect (bool): run ``utilities.remove_bright_sources`` instead
                of ``utilities.plug_holes``.
            psradius (float): ``fwhm`` argument of remove_bright_sources.
            use_shmem (bool): use node-shared memory when comm has >1 task.

        Raises:
            RuntimeError: on inconsistent arguments.
        """
        # Stokes maps ("I", "Q", "U") this instance stored in the cache.
        # __del__ releases exactly these.
        self._cached_stokes = []

        if not pol and no_temperature:
            raise RuntimeError("You cannot have pol=False, " "no_temperature=True")

        self.path = map_path
        self.pol = pol
        self.pol_fwhm = pol_fwhm
        self.nest = nest
        self.order = "NESTED" if nest else "RING"
        self.pscorrect = pscorrect
        self.psradius = psradius
        self.buflen = buflen
        # Output data type, internal is always DTYPE
        if dtype is not None:
            warnings.warn("MapSampler no longer supports dtype", DeprecationWarning)

        # Use healpy to load the map into memory.

        if comm is None:
            self.comm = None
            self.rank = 0
            self.ntask = 1
        else:
            self.comm = comm
            self.rank = comm.Get_rank()
            self.ntask = comm.Get_size()

        self.shmem = self.ntask > 1 and use_shmem

        if self.rank == 0:
            if map_path is None and preloaded_map is None:
                raise RuntimeError("Either map_path or preloaded_map must be provided")
            if map_path is not None and preloaded_map is not None:
                if os.path.isfile(map_path):
                    raise RuntimeError(
                        "Both map_path and preloaded_map cannot be provided"
                    )
            if self.pol:
                if preloaded_map is not None:
                    # The maps are modified in place below, so only alias the
                    # caller's arrays when nothing will touch them.
                    copy = bool(pscorrect or plug_holes or self.pol_fwhm is not None)
                    if no_temperature:
                        self.Map_Q = preloaded_map[0].astype(DTYPE, copy=copy)
                        self.Map_U = preloaded_map[1].astype(DTYPE, copy=copy)
                    else:
                        self.Map = preloaded_map[0].astype(DTYPE, copy=copy)
                        self.Map_Q = preloaded_map[1].astype(DTYPE, copy=copy)
                        self.Map_U = preloaded_map[2].astype(DTYPE, copy=copy)
                else:
                    if no_temperature:
                        self.Map_Q, self.Map_U = hp.read_map(
                            self.path,
                            field=[1, 2],
                            dtype=DTYPE,
                            verbose=verbose,
                            memmap=True,
                            nest=self.nest,
                        )
                    else:
                        try:
                            self.Map, self.Map_Q, self.Map_U = hp.read_map(
                                self.path,
                                field=[0, 1, 2],
                                dtype=DTYPE,
                                verbose=verbose,
                                memmap=True,
                                nest=self.nest,
                            )
                        except IndexError:
                            print(
                                "WARNING: {} is not polarized".format(self.path),
                                flush=True,
                            )
                            self.pol = False
                            self.Map = hp.read_map(
                                self.path,
                                dtype=DTYPE,
                                verbose=verbose,
                                memmap=True,
                                nest=self.nest,
                            )

                if nside is not None:
                    if not no_temperature:
                        self.Map = hp.ud_grade(
                            self.Map,
                            nside,
                            dtype=DTYPE,
                            order_in=self.order,
                            order_out=self.order,
                        )
                    if self.pol:
                        self.Map_Q = hp.ud_grade(
                            self.Map_Q,
                            nside,
                            dtype=DTYPE,
                            order_in=self.order,
                            order_out=self.order,
                        )
                        self.Map_U = hp.ud_grade(
                            self.Map_U,
                            nside,
                            dtype=DTYPE,
                            order_in=self.order,
                            order_out=self.order,
                        )

                if self.pscorrect:
                    if not no_temperature:
                        utilities.remove_bright_sources(
                            self.Map, nest=self.nest, fwhm=self.psradius
                        )
                    if self.pol:
                        utilities.remove_bright_sources(
                            self.Map_Q, nest=self.nest, fwhm=self.psradius
                        )
                        utilities.remove_bright_sources(
                            self.Map_U, nest=self.nest, fwhm=self.psradius
                        )
                elif plug_holes or self.pol_fwhm is not None:
                    if not no_temperature:
                        utilities.plug_holes(self.Map, verbose=verbose, nest=self.nest)
                    if self.pol:
                        utilities.plug_holes(
                            self.Map_Q, verbose=verbose, nest=self.nest
                        )
                        utilities.plug_holes(
                            self.Map_U, verbose=verbose, nest=self.nest
                        )
            else:
                if preloaded_map is not None:
                    self.Map = np.array(preloaded_map, dtype=DTYPE)
                else:
                    self.Map = hp.read_map(
                        map_path,
                        field=[0],
                        dtype=DTYPE,
                        verbose=verbose,
                        memmap=True,
                        nest=self.nest,
                    )
                if nside is not None:
                    self.Map = hp.ud_grade(
                        self.Map,
                        nside,
                        dtype=DTYPE,
                        order_in=self.order,
                        order_out=self.order,
                    )
                if self.pscorrect:
                    utilities.remove_bright_sources(
                        self.Map, nest=self.nest, fwhm=self.psradius
                    )
                elif plug_holes:
                    utilities.plug_holes(self.Map, verbose=verbose, nest=self.nest)

        if self.ntask > 1:
            # Rank 0 may have discovered that the file is not polarized
            self.pol = comm.bcast(self.pol, root=0)
            npix = 0
            if self.rank == 0:
                if self.pol:
                    npix = len(self.Map_Q)
                else:
                    npix = len(self.Map)
            npix = comm.bcast(npix, root=0)
            if self.shmem:
                if not no_temperature:
                    # Only allocate the intensity buffer when it is used
                    shared = MPIShared((npix,), np.dtype(DTYPE), comm)
                    if self.rank == 0 and self.Map is None:
                        raise RuntimeError("Cannot set shared map from None")
                    shared.set(self.Map, (0,), fromrank=0)
                    self.Map = shared
                if self.pol:
                    if self.rank == 0 and self.Map_Q is None:
                        raise RuntimeError("Cannot set shared map from None")
                    shared_Q = MPIShared((npix,), np.dtype(DTYPE), comm)
                    shared_Q.set(self.Map_Q, (0,), fromrank=0)
                    self.Map_Q = shared_Q
                    shared_U = MPIShared((npix,), np.dtype(DTYPE), comm)
                    shared_U.set(self.Map_U, (0,), fromrank=0)
                    self.Map_U = shared_U
            else:
                if self.rank != 0:
                    if not no_temperature:
                        self.Map = np.zeros(npix, dtype=DTYPE)
                    if self.pol:
                        self.Map_Q = np.zeros(npix, dtype=DTYPE)
                        self.Map_U = np.zeros(npix, dtype=DTYPE)

                if not no_temperature:
                    comm.Bcast(self.Map, root=0)
                if self.pol:
                    comm.Bcast(self.Map_Q, root=0)
                    comm.Bcast(self.Map_U, root=0)

        if self.pol:
            self.npix = len(self.Map_Q[:])
        else:
            self.npix = len(self.Map[:])
        self.nside = hp.npix2nside(self.npix)

        self.cache = Cache() if cache is None else cache
        self.instance = 0
        if not self.shmem:
            stokes_to_store = [] if no_temperature else ["I"]
            if self.pol:
                stokes_to_store.extend(["Q", "U"])
            # Increase the instance counter until none of the names we are
            # about to store is taken.  If the user did not want to store
            # duplicates, they would not have created two identical
            # mapsampler objects.  Probing every stored name (not just "I")
            # keeps no_temperature samplers from colliding on "Q"/"U".
            while any(self.cache.exists(self._cachename(s)) for s in stokes_to_store):
                self.instance += 1
            if not no_temperature:
                self.Map = self.cache.put(self._cachename("I"), self.Map)
                self._cached_stokes.append("I")
            if self.pol:
                self.Map_Q = self.cache.put(self._cachename("Q"), self.Map_Q)
                self._cached_stokes.append("Q")
                self.Map_U = self.cache.put(self._cachename("U"), self.Map_U)
                self._cached_stokes.append("U")

        if self.pol_fwhm is not None:
            self.smooth(self.pol_fwhm, pol_only=True)
        return

    def smooth(self, fwhm, lmax=None, pol_only=False):
        """ Smooth the map with a Gaussian kernel.

        Args:
            fwhm (float): kernel FWHM in arcmin.
            lmax (int): maximum multipole used by ``hp.smoothing``.  Defaults
                to ``min(int(fwhm / 60 * 512), 2 * nside)``.
            pol_only (bool): only replace the Q and U maps.
        """
        if self.rank == 0:
            if pol_only:
                print(
                    "Smoothing the polarization to {} arcmin".format(fwhm), flush=True
                )
            else:
                print("Smoothing the map to {} arcmin".format(fwhm), flush=True)

        if lmax is None:
            # np.int was removed from NumPy; the builtin is equivalent
            lmax = min(int(fwhm / 60 * 512), 2 * self.nside)

        if pol_only and not self.pol:
            # There is no polarization to replace; nothing would be written.
            self.pol_fwhm = fwhm
            return

        # Any stored map carries the node layout; Map is None with
        # no_temperature, in which case Map_Q exists because pol is set.
        ref_map = self.Map if self.Map is not None else self.Map_Q

        # If the map is in node-shared memory, only the root process on each
        # node does the smoothing.
        if not self.shmem or ref_map._noderank == 0:
            if self.pol:
                if self.Map is None:
                    # no_temperature: smooth Q/U alongside an empty I map
                    temperature = np.zeros_like(self.Map_Q[:])
                else:
                    temperature = self.Map[:]
                m = np.vstack([temperature, self.Map_Q[:], self.Map_U[:]])
            else:
                m = self.Map[:]
            if self.nest:
                m = hp.reorder(m, n2r=True)
            smap = hp.smoothing(m, fwhm=fwhm * arcmin, lmax=lmax, verbose=False)
            del m
            if self.nest:
                smap = hp.reorder(smap, r2n=True)
            # An intensity-only result is 1-D; make row 0 the smoothed map
            # instead of its first pixel.
            smap = np.atleast_2d(smap)
        else:
            # Convenience dummy variable
            smap = np.zeros([3, 12])

        if not pol_only and self.Map is not None:
            if self.shmem:
                self.Map.set(smap[0].astype(DTYPE, copy=False), (0,), fromrank=0)
            else:
                self.Map[:] = smap[0]

        if self.pol:
            if self.shmem:
                self.Map_Q.set(smap[1].astype(DTYPE, copy=False), (0,), fromrank=0)
                self.Map_U.set(smap[2].astype(DTYPE, copy=False), (0,), fromrank=0)
            else:
                self.Map_Q[:] = smap[1]
                self.Map_U[:] = smap[2]

        self.pol_fwhm = fwhm
        return

    def _cachename(self, stokes):
        """
        Construct a cache name string for the selected Stokes map
        """
        return "{}_ns{:04}_{}_{:04}".format(
            self.path, self.nside, stokes, self.instance
        )

    def __del__(self):
        """
        Explicitly free memory taken up in the cache.

        Only the Stokes maps this instance stored are destroyed.  The method
        is a no-op for objects whose constructor did not get far enough to
        store anything.
        """
        if getattr(self, "shmem", True):
            return
        # Ensure the cache objects are destroyed after their references
        self.Map = None
        self.Map_Q = None
        self.Map_U = None
        cache = getattr(self, "cache", None)
        if cache is None:
            return
        for stokes in getattr(self, "_cached_stokes", ()):
            cache.destroy(self._cachename(stokes))

    def __iadd__(self, other):
        """ Accumulate provided Mapsampler object with this one.
        """
        if self.shmem:
            # One process does the manipulation on each node; the three maps
            # are handled by (up to) three different node ranks.
            self.Map._nodecomm.Barrier()
            if self.Map._noderank == 0:
                self.Map.data[:] += other.Map[:]
            if self.pol and other.pol:
                if self.Map_Q._noderank == (1 % self.Map_Q._nodeprocs):
                    self.Map_Q.data[:] += other.Map_Q[:]
                if self.Map_U._noderank == (2 % self.Map_U._nodeprocs):
                    self.Map_U.data[:] += other.Map_U[:]
            self.Map._nodecomm.Barrier()
        else:
            self.Map += other.Map
            if self.pol and other.pol:
                self.Map_Q += other.Map_Q
                self.Map_U += other.Map_U
        return self

    def __isub__(self, other):
        """ Subtract provided Mapsampler object from this one.
        """
        if self.shmem:
            # One process does the manipulation on each node
            self.Map._nodecomm.Barrier()
            if self.Map._noderank == 0:
                self.Map.data[:] -= other.Map[:]
            if self.pol and other.pol:
                if self.Map_Q._noderank == (1 % self.Map_Q._nodeprocs):
                    self.Map_Q.data[:] -= other.Map_Q[:]
                if self.Map_U._noderank == (2 % self.Map_U._nodeprocs):
                    self.Map_U.data[:] -= other.Map_U[:]
            self.Map._nodecomm.Barrier()
        else:
            self.Map -= other.Map
            if self.pol and other.pol:
                self.Map_Q -= other.Map_Q
                self.Map_U -= other.Map_U
        return self

    def __imul__(self, other):
        """ Scale the maps in this MapSampler object
        """
        if self.shmem:
            # One process does the manipulation on each node
            self.Map._nodecomm.Barrier()
            if self.Map._noderank == 0:
                self.Map.data[:] *= other
            if self.pol:
                if self.Map_Q._noderank == (1 % self.Map_Q._nodeprocs):
                    self.Map_Q.data[:] *= other
                if self.Map_U._noderank == (2 % self.Map_U._nodeprocs):
                    self.Map_U.data[:] *= other
            self.Map._nodecomm.Barrier()
        else:
            self.Map *= other
            if self.pol:
                self.Map_Q *= other
                self.Map_U *= other
        return self

    def __itruediv__(self, other):
        """ Divide the maps in this MapSampler object
        """
        if self.shmem:
            # One process does the manipulation on each node
            self.Map._nodecomm.Barrier()
            if self.Map._noderank == 0:
                self.Map.data[:] /= other
            if self.pol:
                if self.Map_Q._noderank == (1 % self.Map_Q._nodeprocs):
                    self.Map_Q.data[:] /= other
                if self.Map_U._noderank == (2 % self.Map_U._nodeprocs):
                    self.Map_U.data[:] /= other
            self.Map._nodecomm.Barrier()
        else:
            self.Map /= other
            if self.pol:
                self.Map_Q /= other
                self.Map_U /= other
        return self

    def at(self, theta, phi, interp_pix=None, interp_weights=None):
        """
        Use healpy bilinear interpolation to interpolate the
        map.  User must make sure that coordinate system used
        for theta and phi matches the map coordinate system.

        Precomputed ``interp_pix`` / ``interp_weights`` (indexed as
        ``[:, sample]``) are used when both are given.  The samples are
        processed in chunks of ``self.buflen`` and returned as float32.
        """

        if self.Map is None:
            raise RuntimeError("No temperature map to sample")

        n = len(theta)
        stepsize = self.buflen
        signal = np.zeros(n, dtype=np.float32)

        for istart in range(0, n, stepsize):
            istop = min(istart + stepsize, n)
            ind = slice(istart, istop)
            if interp_pix is None or interp_weights is None:
                p, w = hp.get_interp_weights(
                    self.nside, theta[ind], phi[ind], nest=self.nest
                )
            else:
                p = np.ascontiguousarray(interp_pix[:, ind])
                w = np.ascontiguousarray(interp_weights[:, ind])
            buffer = np.zeros(istop - istart, dtype=np.float64)
            fast_scanning32(buffer, p, w, self.Map[:])
            signal[ind] = buffer
        return signal

    def atpol(
        self,
        theta,
        phi,
        IQUweight,
        onlypol=False,
        interp_pix=None,
        interp_weights=None,
        pol=True,
        pol_deriv=False,
    ):
        """
        Use healpy bilinear interpolation to interpolate the
        map.  User must make sure that coordinate system used
        for theta and phi matches the map coordinate system.
        IQUweight is an array of shape (nsamp,3) returned by the
        pointing library that gives the weights of the I,Q, and U maps.

        Args:
            pol_deriv(bool):  Return the polarization angle derivative
                of the signal instead of the actual signal.

        Returns None when ``onlypol`` is requested from an unpolarized
        sampler, and falls back to :meth:`at` when polarization is
        unavailable or ``pol`` is False.
        """

        if onlypol and not self.pol:
            return None

        if not self.pol or not pol:
            return self.at(
                theta, phi, interp_pix=interp_pix, interp_weights=interp_weights
            )

        if np.shape(IQUweight)[1] != 3:
            raise RuntimeError(
                "Cannot sample polarized map with only " "intensity weights"
            )

        n = len(theta)
        stepsize = self.buflen
        signal = np.zeros(n, dtype=np.float32)

        for istart in range(0, n, stepsize):
            istop = min(istart + stepsize, n)
            ind = slice(istart, istop)

            if interp_pix is None or interp_weights is None:
                p, w = hp.get_interp_weights(
                    self.nside, theta[ind], phi[ind], nest=self.nest
                )
            else:
                p = np.ascontiguousarray(interp_pix[:, ind])
                w = np.ascontiguousarray(interp_weights[:, ind])

            # Rows of `weights` are the I, Q and U weights for this chunk
            weights = np.ascontiguousarray(IQUweight[ind].T)

            buffer = np.zeros(istop - istart, dtype=np.float64)
            fast_scanning32(buffer, p, w, self.Map_Q[:])
            if pol_deriv:
                signal[ind] = -2 * weights[2] * buffer
            else:
                signal[ind] = weights[1] * buffer

            buffer[:] = 0
            fast_scanning32(buffer, p, w, self.Map_U[:])
            if pol_deriv:
                signal[ind] += 2 * weights[1] * buffer
            else:
                signal[ind] += weights[2] * buffer

            # NOTE(review): with pol_deriv=True and onlypol=False the
            # intensity term is still added even though it does not depend
            # on the polarization angle -- confirm callers expect this.
            if not onlypol:
                if self.Map is None:
                    raise RuntimeError("No temperature map to sample")
                buffer[:] = 0
                fast_scanning32(buffer, p, w, self.Map[:])
                signal[ind] += weights[0] * buffer

        return signal
Exemple #9
0
    def __init__(
        self,
        params=None,
        detweights=None,
        pixels="pixels",
        pixels_nested=True,
        weights="weights",
        name=None,
        name_out=None,
        flag_name=None,
        flag_mask=255,
        common_flag_name=None,
        common_flag_mask=255,
        apply_flags=True,
        purge=False,
        dets=None,
        mcmode=False,
        purge_tod=False,
        purge_pixels=False,
        purge_weights=False,
        purge_flags=False,
        noise="noise",
        intervals="intervals",
        conserve_memory=True,
        translate_timestamps=True,
    ):
        """
        Store the Madam configuration on the operator.

        Args:
            params (dict): parameters to pass to madam.  The dictionary is
                copied before the ``"mcmode"`` and ``"write_tod"`` entries are
                written, so neither the caller's dictionary nor a shared
                default is modified.  None means an empty dictionary.
            purge (bool): if True, overrides all ``purge_*`` flags to True.
            dets (iterable): detectors to process (stored as a set), or None.
            conserve_memory (bool): stored as an int; None means True.

        The remaining arguments are stored on the instance under
        underscore-prefixed attribute names.
        """
        # We call the parent class constructor, which currently does nothing
        super().__init__()
        # A mutable default (or the caller's dict) would be shared and
        # mutated by the "mcmode" / "write_tod" assignments below.
        params = {} if params is None else dict(params)
        # madam uses time-based distribution
        self._name = name
        self._name_out = name_out
        self._flag_name = flag_name
        self._flag_mask = flag_mask
        self._common_flag_name = common_flag_name
        self._common_flag_mask = common_flag_mask
        self._pixels = pixels
        self._pixels_nested = pixels_nested
        self._weights = weights
        self._detw = detweights
        self._purge = purge
        if self._purge:
            self._purge_tod = True
            self._purge_pixels = True
            self._purge_weights = True
            self._purge_flags = True
        else:
            self._purge_tod = purge_tod
            self._purge_pixels = purge_pixels
            self._purge_weights = purge_weights
            self._purge_flags = purge_flags
        self._apply_flags = apply_flags
        self.params = params
        self._dets = set(dets) if dets is not None else None
        self._mcmode = mcmode
        self.params["mcmode"] = bool(mcmode)
        self.params["write_tod"] = self._name_out is not None
        self._cached = False
        self._noisekey = noise
        self._intervals = intervals
        self._cache = Cache()
        self._madam_timestamps = None
        self._madam_pixels = None
        self._madam_pixweights = None
        self._madam_signal = None
        if conserve_memory is None:
            conserve_memory = True
        self._conserve_memory = int(conserve_memory)
        self._translate_timestamps = translate_timestamps
        if "info" in params:
            self._verbose = int(params["info"]) > 0
        else:
            self._verbose = True
Exemple #10
0
class DistRings(object):
    """
    A map with unique pixels distributed as disjoint isolatitude rings.

    Designed for Harmonic Transforms with libsharp

    Pixel domain data is distributed across an MPI communicator.  Each
    process has a number of isolatitude rings and a list of the pixels
    within those rings.  Pixel indices are in HEALPix RING ordering.

    Args:
        comm (mpi4py.MPI.Comm): the MPI communicator containing all
            processes.  If None, this process holds every ring.
        nnz (int): the number of values per pixel.
        dtype (numpy.dtype): the data type of the pixel values.
        nside (int): the HEALPix NSIDE of the map.
    """
    def __init__(self, comm=None, nnz=1, dtype=np.float64, nside=16):
        if libsharp is None:
            raise RuntimeError('libsharp not available')
        self.data = None
        self._comm = comm
        self._nnz = nnz
        self._dtype = dtype
        self._nest = False
        self._nside = nside

        self._cache = Cache()

        # Without a communicator, behave like a single-process run.
        if self._comm is None:
            rank, ntask = 0, 1
        else:
            rank, ntask = self._comm.rank, self._comm.size
        self._libsharp_grid, self._local_ring_indices = distribute_rings(
            self._nside, rank, ntask)
        # returns start index of the ring and number of pixels
        startpix, ringpix, _, _, _ = hp.ringinfo(
            self._nside, self._local_ring_indices.astype(np.int64))

        local_npix = self._libsharp_grid.local_size()
        self._local_pixels = self._cache.create("local_pixels",
                                                shape=(local_npix, ),
                                                type=np.int64)
        expand_pix(startpix, ringpix, local_npix, self._local_pixels)

        self.data = self._cache.create("data",
                                       shape=(local_npix, self._nnz),
                                       type=self._dtype)

    def __del__(self):
        # Drop our references to the cached buffers before clearing the
        # cache.  getattr guards keep this safe when __init__ raised early
        # (e.g. libsharp missing) and some attributes were never set.
        self.data = None
        self._local_pixels = None
        cache = getattr(self, "_cache", None)
        if cache is not None:
            cache.clear()

    @property
    def comm(self):
        """
        (mpi4py.MPI.Comm): The MPI communicator used.
        """
        return self._comm

    @property
    def nnz(self):
        """
        (int): The number of non-zero values per pixel.
        """
        return self._nnz

    @property
    def dtype(self):
        """
        (numpy.dtype): The data type of the values.
        """
        return self._dtype

    @property
    def nested(self):
        """
        (bool): If True, data is HEALPix NESTED ordering.
        """
        return self._nest

    @property
    def local_pixels(self):
        """
        (numpy.ndarray int64): Array of local pixel indices in RING ordering
        """
        return self._local_pixels

    @property
    def libsharp_grid(self):
        """
        (libsharp grid): Libsharp grid distribution object
        """
        return self._libsharp_grid
Exemple #11
0
    def __init__(
        self,
        params=None,
        detweights=None,
        pixels="pixels",
        pixels_nested=True,
        weights="weights",
        name=None,
        name_out=None,
        flag_name=None,
        flag_mask=255,
        common_flag_name=None,
        common_flag_mask=255,
        apply_flags=True,
        purge=False,
        dets=None,
        mcmode=False,
        purge_tod=False,
        purge_pixels=False,
        purge_weights=False,
        purge_flags=False,
        noise="noise",
        intervals="intervals",
        conserve_memory=True,
        translate_timestamps=True,
    ):
        """
        Store the Madam configuration on the operator.

        Args:
            params (dict): parameters to pass to madam.  The dictionary is
                copied before the ``"mcmode"`` and ``"write_tod"`` entries are
                written, so neither the caller's dictionary nor a shared
                default is modified.  None means an empty dictionary.
            purge (bool): if True, overrides all ``purge_*`` flags to True.
            dets (iterable): detectors to process (stored as a set), or None.
            conserve_memory (bool): stored as an int; None means True.

        The remaining arguments are stored on the instance under
        underscore-prefixed attribute names.
        """
        # We call the parent class constructor, which currently does nothing
        super().__init__()
        # A mutable default (or the caller's dict) would be shared and
        # mutated by the "mcmode" / "write_tod" assignments below.
        params = {} if params is None else dict(params)
        # madam uses time-based distribution
        self._name = name
        self._name_out = name_out
        self._flag_name = flag_name
        self._flag_mask = flag_mask
        self._common_flag_name = common_flag_name
        self._common_flag_mask = common_flag_mask
        self._pixels = pixels
        self._pixels_nested = pixels_nested
        self._weights = weights
        self._detw = detweights
        self._purge = purge
        if self._purge:
            self._purge_tod = True
            self._purge_pixels = True
            self._purge_weights = True
            self._purge_flags = True
        else:
            self._purge_tod = purge_tod
            self._purge_pixels = purge_pixels
            self._purge_weights = purge_weights
            self._purge_flags = purge_flags
        self._apply_flags = apply_flags
        self.params = params
        self._dets = set(dets) if dets is not None else None
        self._mcmode = mcmode
        self.params["mcmode"] = bool(mcmode)
        self.params["write_tod"] = self._name_out is not None
        self._cached = False
        self._noisekey = noise
        self._intervals = intervals
        self._cache = Cache()
        self._madam_timestamps = None
        self._madam_pixels = None
        self._madam_pixweights = None
        self._madam_signal = None
        if conserve_memory is None:
            conserve_memory = True
        self._conserve_memory = int(conserve_memory)
        self._translate_timestamps = translate_timestamps
        if "info" in params:
            self._verbose = int(params["info"]) > 0
        else:
            self._verbose = True
Exemple #12
0
class OpMadam(Operator):
    """
    Operator which passes data to libmadam for map-making.

    Args:
        params (dictionary): parameters to pass to madam.  The dictionary
            is updated in place with the "mcmode" and "write_tod" keys.
            If None (the default), a new empty dictionary is used.
        detweights (dictionary): individual noise weights to use for each
            detector.  If None, every detector gets unit weight.
        pixels (str): the name of the cache object (<pixels>_<detector>)
            containing the pixel indices to use.
        pixels_nested (bool): Set to False if the pixel numbers are in
            ring ordering. Default is True.
        weights (str): the name of the cache object (<weights>_<detector>)
            containing the pointing weights to use.
        name (str): the name of the cache object (<name>_<detector>) to
            use for the detector timestream.  If None, use the TOD.
        name_out (str): the name of the cache object (<name>_<detector>)
            to use to output destriped detector timestream.
            No output if None.
        flag_name (str): the name of the cache object
            (<flag_name>_<detector>) to use for the detector flags.
            If None, use the TOD.
        flag_mask (int): the integer bit mask (0-255) that should be
            used with the detector flags in a bitwise AND.
        common_flag_name (str): the name of the cache object
            to use for the common flags.  If None, use the TOD.
        common_flag_mask (int): the integer bit mask (0-255) that should
            be used with the common flags in a bitwise AND.
        apply_flags (bool): whether to apply flags to the pixel numbers.
        purge (bool): if True, clear any cached data that is copied into
            the Madam buffers.
        purge_tod (bool): if True, clear any cached signal that is
            copied into the Madam buffers.
        purge_pixels (bool): if True, clear any cached pixels that are
            copied into the Madam buffers.
        purge_weights (bool): if True, clear any cached weights that are
            copied into the Madam buffers.
        purge_flags (bool): if True, clear any cached flags that are
            copied into the Madam buffers.
        dets (iterable):  List of detectors to map. If left as None, all
            available detectors are mapped.
        mcmode (bool): If true, the operator is constructed in
            Monte Carlo mode and Madam will cache auxiliary information
            such as pixel matrices and noise filter.
        noise (str): Keyword to use when retrieving the noise object
            from the observation.
        intervals (str): Keyword of the interval list in each observation.
            It is passed to tod.local_intervals() and only samples inside
            the resulting intervals are given to Madam; intervals shorter
            than the baseline order are skipped.
        conserve_memory(bool/int): Stagger the Madam buffer staging on node.
            True (or 1) makes each process on a node stage its data in
            turn; an integer N > 1 stages in min(N, processes per node)
            groups.  None is treated as True.
        translate_timestamps(bool): Translate timestamps to enforce
            monotonicity.
    """

    def __init__(
        self,
        params=None,
        detweights=None,
        pixels="pixels",
        pixels_nested=True,
        weights="weights",
        name=None,
        name_out=None,
        flag_name=None,
        flag_mask=255,
        common_flag_name=None,
        common_flag_mask=255,
        apply_flags=True,
        purge=False,
        dets=None,
        mcmode=False,
        purge_tod=False,
        purge_pixels=False,
        purge_weights=False,
        purge_flags=False,
        noise="noise",
        intervals="intervals",
        conserve_memory=True,
        translate_timestamps=True,
    ):

        # We call the parent class constructor, which currently does nothing
        super().__init__()
        # The parameter dictionary is modified below ("mcmode", "write_tod"),
        # so never fall back on a dictionary shared between instances.
        if params is None:
            params = {}
        # madam uses time-based distribution
        self._name = name
        self._name_out = name_out
        self._flag_name = flag_name
        self._flag_mask = flag_mask
        self._common_flag_name = common_flag_name
        self._common_flag_mask = common_flag_mask
        self._pixels = pixels
        self._pixels_nested = pixels_nested
        self._weights = weights
        self._detw = detweights
        self._purge = purge
        if self._purge:
            self._purge_tod = True
            self._purge_pixels = True
            self._purge_weights = True
            self._purge_flags = True
        else:
            self._purge_tod = purge_tod
            self._purge_pixels = purge_pixels
            self._purge_weights = purge_weights
            self._purge_flags = purge_flags
        self._apply_flags = apply_flags
        self.params = params
        if dets is not None:
            self._dets = set(dets)
        else:
            self._dets = None
        self._mcmode = mcmode
        self.params["mcmode"] = bool(mcmode)
        self.params["write_tod"] = self._name_out is not None
        self._cached = False
        self._noisekey = noise
        self._intervals = intervals
        self._cache = Cache()
        self._madam_timestamps = None
        self._madam_pixels = None
        self._madam_pixweights = None
        self._madam_signal = None
        if conserve_memory is None:
            conserve_memory = True
        self._conserve_memory = int(conserve_memory)
        self._translate_timestamps = translate_timestamps
        if "info" in self.params:
            self._verbose = int(self.params["info"]) > 0
        else:
            self._verbose = True

    def __del__(self):
        # Guard with getattr: __del__ also runs when __init__ raised before
        # these attributes were assigned.
        cache = getattr(self, "_cache", None)
        if cache is not None:
            cache.clear()
        if getattr(self, "_cached", False):
            madam.clear_caches()
            self._cached = False

    @property
    def available(self):
        """
        (bool): True if libmadam is found in the library search path.
        """
        return madam is not None and madam.available

    def exec(self, data, comm=None):
        """
        Copy data to Madam-compatible buffers and make a map.

        Args:
            data (toast.Data): The distributed data.
            comm: MPI communicator to run Madam on.  Defaults to
                data.comm.comm_world.

        Raises:
            RuntimeError: if libmadam is unavailable or data has no
                observations.
        """
        if not self.available:
            raise RuntimeError("libmadam is not available")

        if len(data.obs) == 0:
            raise RuntimeError(
                "OpMadam requires every supplied data object to "
                "contain at least one observation"
            )

        # Kept alive for the duration of the call: the timer measures scope.
        auto_timer = timing.auto_timer(type(self).__name__)

        if comm is None:
            # Just use COMM_WORLD
            comm = data.comm.comm_world

        (
            pars,
            dets,
            nsamp,
            ndet,
            nnz,
            nnz_full,
            nnz_stride,
            periods,
            obs_period_ranges,
            psdfreqs,
            nside,
        ) = self._prepare(data, comm)

        psdinfo, signal_type, pixels_dtype, weight_dtype = self._stage_data(
            data,
            comm,
            nsamp,
            ndet,
            nnz,
            nnz_full,
            nnz_stride,
            obs_period_ranges,
            psdfreqs,
            dets,
            nside,
        )

        self._destripe(comm, pars, dets, periods, psdinfo)

        self._unstage_data(
            comm,
            data,
            nsamp,
            nnz,
            nnz_full,
            obs_period_ranges,
            dets,
            signal_type,
            pixels_dtype,
            nside,
            weight_dtype,
        )

        return

    def _destripe(self, comm, pars, dets, periods, psdinfo):
        """ Destripe the buffered data

        On the first call (or outside Monte Carlo mode) the full
        madam.destripe is called with the detector weights and PSDs.
        In Monte Carlo mode subsequent calls use madam.destripe_with_cache.
        """
        auto_timer = timing.auto_timer(type(self).__name__)
        if self._verbose:
            memreport(comm, "just before calling libmadam.destripe")
        if self._cached:
            # destripe
            outpath = ""
            if "path_output" in self.params:
                outpath = self.params["path_output"]
            outpath = outpath.encode("ascii")
            madam.destripe_with_cache(
                comm,
                self._madam_timestamps,
                self._madam_pixels,
                self._madam_pixweights,
                self._madam_signal,
                outpath,
            )
        else:
            (detweights, npsd, psdstarts, psdfreqs, psdvals) = psdinfo

            # destripe
            madam.destripe(
                comm,
                pars,
                dets,
                detweights,
                self._madam_timestamps,
                self._madam_pixels,
                self._madam_pixweights,
                self._madam_signal,
                periods,
                npsd,
                psdstarts,
                psdfreqs,
                psdvals,
            )

            if self._mcmode:
                self._cached = True
        return

    def _count_samples(self, data):
        """ Loop over the observations and count the number of local samples.

        Raises:
            RuntimeError: if the observations do not all have the same
                local detectors (in the same order).
        """
        if len(data.obs) == 1:
            return data.obs[0]["tod"].local_samples[1]

        detectors0 = list(data.obs[0]["tod"].local_dets)
        nsamp = 0
        for obs in data.obs:
            tod = obs["tod"]
            # For the moment, we require that all observations have
            # the same set of detectors
            if list(tod.local_dets) != detectors0:
                raise RuntimeError(
                    "When calling Madam, all TOD assigned to a process "
                    "must have the same local detectors."
                )
            nsamp += tod.local_samples[1]
        return nsamp

    def _get_period_ranges(self, comm, data, detectors, nsamp):
        """ Collect the ranges of every observation.

        Returns:
            (obs_period_ranges, psdfreqs, periods, nsamp): the per-observation
            lists of local (start, stop) sample ranges, the common PSD
            frequency grid (None if no noise model was found), the start
            index of every period in the Madam buffers and the number of
            local samples that fall inside the valid periods.

        Raises:
            RuntimeError: if PSD binnings differ or no sample is valid.
        """
        # Discard intervals that are too short to fit a baseline
        if "basis_order" in self.params:
            norder = int(self.params["basis_order"]) + 1
        else:
            norder = 1

        psdfreqs = None
        period_lengths = []
        obs_period_ranges = []

        for obs in data.obs:
            tod = obs["tod"]
            # Check that all noise objects have the same binning
            if self._noisekey in obs.keys():
                nse = obs[self._noisekey]
                if nse is not None:
                    if psdfreqs is None:
                        psdfreqs = nse.freq(detectors[0]).astype(np.float64).copy()
                    for det in detectors:
                        check_psdfreqs = nse.freq(det)
                        if not np.allclose(psdfreqs, check_psdfreqs):
                            raise RuntimeError(
                                "All PSDs passed to Madam must have"
                                " the same frequency binning."
                            )
            # Collect the valid intervals for this observation
            period_ranges = []
            if self._intervals in obs:
                intervals = obs[self._intervals]
            else:
                intervals = None
            local_intervals = tod.local_intervals(intervals)

            for ival in local_intervals:
                local_start = ival.first
                local_stop = ival.last + 1
                if local_stop - local_start < norder:
                    continue
                period_lengths.append(local_stop - local_start)
                period_ranges.append((local_start, local_stop))
            obs_period_ranges.append(period_ranges)

        # Update the number of samples based on the valid intervals

        nsamp_tot_full = comm.allreduce(nsamp, op=MPI.SUM)
        nperiod = len(period_lengths)
        period_lengths = np.array(period_lengths, dtype=np.int64)
        nsamp = np.sum(period_lengths, dtype=np.int64)
        nsamp_tot = comm.allreduce(nsamp, op=MPI.SUM)
        if nsamp_tot == 0:
            raise RuntimeError(
                "No samples in valid intervals: nsamp_tot_full = {}, "
                "nsamp_tot = {}".format(nsamp_tot_full, nsamp_tot)
            )
        if comm.rank == 0:
            print(
                "OpMadam: {:.2f} % of samples are included in valid "
                "intervals.".format(nsamp_tot * 100.0 / nsamp_tot_full)
            )

        # Madam expects starting indices, not period lengths:
        # periods[k] = sum(period_lengths[:k])
        periods = np.zeros(nperiod, dtype=np.int64)
        if nperiod > 1:
            periods[1:] = np.cumsum(period_lengths[:-1])

        return obs_period_ranges, psdfreqs, periods, nsamp

    def _prepare(self, data, comm):
        """ Examine the data object.

        Determine the detectors, pointing-matrix non-zeros, map resolution
        and valid sample ranges, and create the output directory.

        Raises:
            RuntimeError: for an unsupported temperature-only nnz or a
                missing "nside_map" parameter.
        """
        auto_timer = timing.auto_timer(type(self).__name__)

        nsamp = self._count_samples(data)

        # Determine the detectors and the pointing matrix non-zeros
        # from the first observation. Madam will expect these to remain
        # unchanged across observations.

        tod = data.obs[0]["tod"]

        if self._dets is None:
            dets = tod.local_dets
        else:
            dets = [det for det in tod.local_dets if det in self._dets]
        ndet = len(dets)

        # to get the number of Non-zero pointing weights per pixel,
        # we use the fact that for Madam, all processes have all detectors
        # for some slice of time.  So we can get this information from the
        # shape of the data from the first detector

        nnzname = "{}_{}".format(self._weights, dets[0])
        nnz_full = tod.cache.reference(nnzname).shape[1]

        if "temperature_only" in self.params and self.params["temperature_only"] in [
            "T",
            "True",
            "TRUE",
            "true",
            True,
        ]:
            if nnz_full not in [1, 3]:
                raise RuntimeError(
                    "OpMadam: Don't know how to make a temperature map "
                    "with nnz={}".format(nnz_full)
                )
            # Keep only the first weight of every sample
            nnz = 1
            nnz_stride = nnz_full
        else:
            nnz = nnz_full
            nnz_stride = 1

        if "nside_map" not in self.params:
            raise RuntimeError(
                'OpMadam: "nside_map" must be set in the parameter dictionary'
            )
        nside = int(self.params["nside_map"])

        if comm.rank == 0 and (
            "path_output" in self.params
            and not os.path.isdir(self.params["path_output"])
        ):
            os.makedirs(self.params["path_output"], exist_ok=True)

        # Inspect the valid intervals across all observations to
        # determine the number of samples per detector

        obs_period_ranges, psdfreqs, periods, nsamp = self._get_period_ranges(
            comm, data, dets, nsamp
        )

        return (
            self.params,
            dets,
            nsamp,
            ndet,
            nnz,
            nnz_full,
            nnz_stride,
            periods,
            obs_period_ranges,
            psdfreqs,
            nside,
        )

    def _stage_time(self, data, detectors, nsamp, obs_period_ranges):
        """ Stage the timestamps and use them to build PSD inputs.

        Returns:
            dict mapping detector -> list of (start_time, psd) tuples; a new
            entry is appended only when the PSD changes.
        """
        auto_timer = timing.auto_timer(type(self).__name__)
        self._madam_timestamps = self._cache.create(
            "timestamps", madam.TIMESTAMP_TYPE, (nsamp,)
        )

        offset = 0
        time_offset = 0
        psds = {}
        for iobs, obs in enumerate(data.obs):
            tod = obs["tod"]
            period_ranges = obs_period_ranges[iobs]

            # Collect the timestamps for the valid intervals
            timestamps = tod.local_times().copy()
            if self._translate_timestamps:
                # Translate the time stamps to be monotonous
                timestamps -= timestamps[0] - time_offset
                time_offset = timestamps[-1] + 1

            for istart, istop in period_ranges:
                nn = istop - istart
                ind = slice(offset, offset + nn)
                self._madam_timestamps[ind] = timestamps[istart:istop]
                offset += nn

            # get the noise object for this observation and create new
            # entries in the dictionary when the PSD actually changes
            if self._noisekey in obs.keys():
                nse = obs[self._noisekey]
                if "noise_scale" in obs:
                    noise_scale = obs["noise_scale"]
                else:
                    noise_scale = 1
                if nse is not None:
                    for det in detectors:
                        psd = nse.psd(det) * noise_scale ** 2
                        if det not in psds:
                            psds[det] = [(0, psd)]
                        else:
                            if not np.allclose(psds[det][-1][1], psd):
                                psds[det] += [(timestamps[0], psd)]

        return psds

    def _stage_signal(self, data, detectors, nsamp, ndet, obs_period_ranges):
        """ Stage signal

        Copy the valid samples of every detector into the Madam signal
        buffer (detector-major, nsamp samples per detector).  Unused
        entries stay NaN.

        Returns:
            The dtype of the TOAST signal arrays.
        """
        auto_timer = timing.auto_timer(type(self).__name__)
        self._madam_signal = self._cache.create(
            "signal", madam.SIGNAL_TYPE, (nsamp * ndet,)
        )
        self._madam_signal[:] = np.nan

        global_offset = 0
        for iobs, obs in enumerate(data.obs):
            tod = obs["tod"]
            period_ranges = obs_period_ranges[iobs]

            for idet, det in enumerate(detectors):
                # Get the signal.
                signal = tod.local_signal(det, self._name)
                signal_dtype = signal.dtype
                offset = global_offset
                for istart, istop in period_ranges:
                    nn = istop - istart
                    dslice = slice(idet * nsamp + offset, idet * nsamp + offset + nn)
                    self._madam_signal[dslice] = signal[istart:istop]
                    offset += nn

                del signal

            for idet, det in enumerate(detectors):
                if self._name is not None and (
                    self._purge_tod or self._name == self._name_out
                ):
                    cachename = "{}_{}".format(self._name, det)
                    tod.cache.clear(pattern=cachename)

            global_offset = offset

        return signal_dtype

    def _stage_pixels(self, data, detectors, nsamp, ndet, obs_period_ranges, nside):
        """ Stage pixels

        Copy the valid pixel numbers of every detector into the Madam pixel
        buffer, converting ring to nested ordering when needed and setting
        flagged samples to -1 when apply_flags is set.  Unused entries
        stay -1.  The cached pixels are always cleared afterwards.

        Returns:
            The dtype of the cached pixel arrays.
        """
        auto_timer = timing.auto_timer(type(self).__name__)
        self._madam_pixels = self._cache.create(
            "pixels", madam.PIXEL_TYPE, (nsamp * ndet,)
        )
        self._madam_pixels[:] = -1

        global_offset = 0
        for iobs, obs in enumerate(data.obs):
            tod = obs["tod"]
            period_ranges = obs_period_ranges[iobs]

            # The common flags do not depend on the detector: evaluate them
            # once per observation rather than once per detector.
            commonflags = None
            common_flagged = None
            if self._apply_flags:
                commonflags = tod.local_common_flags(self._common_flag_name)
                common_flagged = (commonflags & self._common_flag_mask) != 0

            for idet, det in enumerate(detectors):
                # Optionally get the flags, otherwise they are
                # assumed to have been applied to the pixel numbers.
                flags = None
                if self._apply_flags:
                    detflags = tod.local_flags(det, self._flag_name)
                    flags = np.logical_or(
                        (detflags & self._flag_mask) != 0, common_flagged
                    )
                    del detflags

                # get the pixels for the valid intervals from the cache

                pixelsname = "{}_{}".format(self._pixels, det)
                pixels = tod.cache.reference(pixelsname)
                pixels_dtype = pixels.dtype

                if not self._pixels_nested or self._apply_flags:
                    # Modify a private copy, never the cached array
                    pixels = pixels.copy()

                if not self._pixels_nested:
                    # Madam expects the pixels to be in nested ordering
                    good = pixels >= 0
                    pixels[good] = hp.ring2nest(nside, pixels[good])

                if self._apply_flags:
                    pixels[flags] = -1

                offset = global_offset
                for istart, istop in period_ranges:
                    nn = istop - istart
                    dslice = slice(idet * nsamp + offset, idet * nsamp + offset + nn)
                    self._madam_pixels[dslice] = pixels[istart:istop]
                    offset += nn

                del pixels, flags

            # Always purge the pixels but restore them from the Madam
            # buffers when purge_pixels=False
            for idet, det in enumerate(detectors):
                pixelsname = "{}_{}".format(self._pixels, det)
                tod.cache.clear(pattern=pixelsname)
                if self._name is not None and (
                    self._purge_tod or self._name == self._name_out
                ):
                    cachename = "{}_{}".format(self._name, det)
                    tod.cache.clear(pattern=cachename)
                if self._purge_flags and self._flag_name is not None:
                    cacheflagname = "{}_{}".format(self._flag_name, det)
                    tod.cache.clear(pattern=cacheflagname)

            del commonflags, common_flagged
            if self._purge_flags and self._common_flag_name is not None:
                tod.cache.clear(pattern=self._common_flag_name)
            global_offset = offset

        return pixels_dtype

    def _stage_pixweights(
        self, data, detectors, nsamp, ndet, nnz, nnz_full, nnz_stride, obs_period_ranges
    ):
        """Now collect the pixel weights

        Copy nnz weights per valid sample into the Madam weight buffer,
        taking every nnz_stride-th weight of the flattened cached array
        (so only the first weight is kept in temperature-only mode).

        Returns:
            The dtype of the cached weight arrays.
        """
        auto_timer = timing.auto_timer(type(self).__name__)

        self._madam_pixweights = self._cache.create(
            "pixweights", madam.WEIGHT_TYPE, (nsamp * ndet * nnz,)
        )
        self._madam_pixweights[:] = 0

        global_offset = 0
        for iobs, obs in enumerate(data.obs):
            tod = obs["tod"]
            period_ranges = obs_period_ranges[iobs]
            for idet, det in enumerate(detectors):
                # get the pixels and weights for the valid intervals
                # from the cache
                weightsname = "{}_{}".format(self._weights, det)
                weights = tod.cache.reference(weightsname)
                weight_dtype = weights.dtype
                offset = global_offset
                for istart, istop in period_ranges:
                    nn = istop - istart
                    dwslice = slice(
                        (idet * nsamp + offset) * nnz,
                        (idet * nsamp + offset + nn) * nnz,
                    )
                    self._madam_pixweights[dwslice] = weights[istart:istop].flatten()[
                        ::nnz_stride
                    ]
                    offset += nn
                del weights
            # Purge the weights but restore them from the Madam
            # buffers when purge_weights=False.
            # Handle special case when Madam only stores a subset of
            # the weights: those cannot be restored, so keep them cached.
            if self._purge_weights or nnz == nnz_full:
                for idet, det in enumerate(detectors):
                    weightsname = "{}_{}".format(self._weights, det)
                    tod.cache.clear(pattern=weightsname)

            global_offset = offset

        return weight_dtype

    def _stage_data(
        self,
        data,
        comm,
        nsamp,
        ndet,
        nnz,
        nnz_full,
        nnz_stride,
        obs_period_ranges,
        psdfreqs,
        detectors,
        nside,
    ):
        """ create madam-compatible buffers

        Collect the TOD into Madam buffers. Process pixel weights
        Separate from the rest to reduce the memory high water mark
        When the user has set purge=True

        Moving data between toast and Madam buffers has an overhead.
        We perform the operation in a staggered fashion to have the
        overhead only once per node.

        Returns:
            (psdinfo, signal_dtype, pixels_dtype, weight_dtype) where
            psdinfo = (detweights, npsd, psdstarts, psdfreqs, psdvals).
        """
        auto_timer = timing.auto_timer(type(self).__name__)

        if self._conserve_memory:
            # The user has elected to stagger staging the data on each
            # node to avoid exhausting memory
            nodecomm = comm.Split_type(MPI.COMM_TYPE_SHARED, comm.rank)
            if self._conserve_memory == 1:
                nread = nodecomm.size
            else:
                nread = min(self._conserve_memory, nodecomm.size)
        else:
            nodecomm = MPI.COMM_SELF
            nread = 1

        # Every rank runs all nread iterations (and barriers) but stages
        # its data only in the iteration matching its node rank.
        for iread in range(nread):
            nodecomm.Barrier()
            if nodecomm.rank % nread != iread:
                continue
            psds = self._stage_time(data, detectors, nsamp, obs_period_ranges)
            signal_dtype = self._stage_signal(
                data, detectors, nsamp, ndet, obs_period_ranges
            )
            pixels_dtype = self._stage_pixels(
                data, detectors, nsamp, ndet, obs_period_ranges, nside
            )
            weight_dtype = self._stage_pixweights(
                data,
                detectors,
                nsamp,
                ndet,
                nnz,
                nnz_full,
                nnz_stride,
                obs_period_ranges,
            )
        # Release the communicator created by Split_type (collective over
        # the node); COMM_SELF is predefined and must never be freed.
        if self._conserve_memory:
            nodecomm.Free()
        del nodecomm

        # detweights is either a dictionary of weights specified at
        # construction time, or else we use uniform weighting.
        detweights = np.zeros(ndet, dtype=np.float64)
        for idet, det in enumerate(detectors):
            detweights[idet] = 1.0 if self._detw is None else self._detw[det]

        if len(psds) > 0:
            npsd = np.zeros(ndet, dtype=np.int64)
            psdstarts = []
            psdvals = []
            for idet, det in enumerate(detectors):
                if det not in psds:
                    raise RuntimeError("Every detector must have at least one PSD")
                psdlist = psds[det]
                npsd[idet] = len(psdlist)
                for psdstart, psd in psdlist:
                    psdstarts.append(psdstart)
                    psdvals.append(psd)
            psdstarts = np.array(psdstarts, dtype=np.float64)
            psdvals = np.hstack(psdvals).astype(madam.PSD_TYPE)
        else:
            # No PSDs were collected: pass one flat (all ones) placeholder
            # PSD per detector.
            npsd = np.ones(ndet, dtype=np.int64)
            npsdtot = np.sum(npsd)
            psdstarts = np.zeros(npsdtot)
            npsdbin = 10
            fsample = 10.0
            psdfreqs = np.arange(npsdbin) * fsample / npsdbin
            npsdval = npsdbin * npsdtot
            psdvals = np.ones(npsdval)
        psdinfo = (detweights, npsd, psdstarts, psdfreqs, psdvals)

        return psdinfo, signal_dtype, pixels_dtype, weight_dtype

    def _unstage_data(
        self,
        comm,
        data,
        nsamp,
        nnz,
        nnz_full,
        obs_period_ranges,
        detectors,
        signal_type,
        pixels_dtype,
        nside,
        weight_dtype,
    ):
        """ Clear Madam buffers, restore pointing into TOAST caches
        and cache the destriped signal.

        The destriped signal is written to <name_out>_<detector> when
        name_out is set; pixels and weights are restored unless purged
        (weights only when all nnz_full weights were staged).
        """
        auto_timer = timing.auto_timer(type(self).__name__)
        self._madam_timestamps = None
        self._cache.destroy("timestamps")

        if self._conserve_memory:
            nodecomm = comm.Split_type(MPI.COMM_TYPE_SHARED, comm.rank)
            nread = nodecomm.size
        else:
            nodecomm = MPI.COMM_SELF
            nread = 1

        for iread in range(nread):
            nodecomm.Barrier()
            if nodecomm.rank % nread != iread:
                continue
            if self._name_out is not None:
                global_offset = 0
                for obs, period_ranges in zip(data.obs, obs_period_ranges):
                    tod = obs["tod"]
                    nlocal = tod.local_samples[1]
                    for idet, det in enumerate(detectors):
                        signal = np.ones(nlocal, dtype=signal_type) * np.nan
                        offset = global_offset
                        for istart, istop in period_ranges:
                            nn = istop - istart
                            dslice = slice(
                                idet * nsamp + offset, idet * nsamp + offset + nn
                            )
                            signal[istart:istop] = self._madam_signal[dslice]
                            offset += nn
                        cachename = "{}_{}".format(self._name_out, det)
                        tod.cache.put(cachename, signal, replace=True)
                    global_offset = offset
            self._madam_signal = None
            self._cache.destroy("signal")

            if not self._purge_pixels:
                # restore the pixels from the Madam buffers
                npix = 12 * nside ** 2
                global_offset = 0
                for obs, period_ranges in zip(data.obs, obs_period_ranges):
                    tod = obs["tod"]
                    nlocal = tod.local_samples[1]
                    for idet, det in enumerate(detectors):
                        pixels = -np.ones(nlocal, dtype=pixels_dtype)
                        offset = global_offset
                        for istart, istop in period_ranges:
                            nn = istop - istart
                            dslice = slice(
                                idet * nsamp + offset, idet * nsamp + offset + nn
                            )
                            pixels[istart:istop] = self._madam_pixels[dslice]
                            offset += nn
                        good = np.logical_and(pixels >= 0, pixels < npix)
                        if not self._pixels_nested:
                            pixels[good] = hp.nest2ring(nside, pixels[good])
                        pixels[np.logical_not(good)] = -1
                        cachename = "{}_{}".format(self._pixels, det)
                        tod.cache.put(cachename, pixels, replace=True)
                    global_offset = offset
            self._madam_pixels = None
            self._cache.destroy("pixels")

            if not self._purge_weights and nnz == nnz_full:
                # restore the weights from the Madam buffers
                global_offset = 0
                for obs, period_ranges in zip(data.obs, obs_period_ranges):
                    tod = obs["tod"]
                    nlocal = tod.local_samples[1]
                    for idet, det in enumerate(detectors):
                        weights = np.zeros([nlocal, nnz], dtype=weight_dtype)
                        offset = global_offset
                        for istart, istop in period_ranges:
                            nn = istop - istart
                            dwslice = slice(
                                (idet * nsamp + offset) * nnz,
                                (idet * nsamp + offset + nn) * nnz,
                            )
                            weights[istart:istop] = self._madam_pixweights[
                                dwslice
                            ].reshape([-1, nnz])
                            offset += nn
                        cachename = "{}_{}".format(self._weights, det)
                        tod.cache.put(cachename, weights, replace=True)
                    global_offset = offset
            self._madam_pixweights = None
            self._cache.destroy("pixweights")
        # Release the communicator created by Split_type (collective over
        # the node); COMM_SELF is predefined and must never be freed.
        if self._conserve_memory:
            nodecomm.Free()
        del nodecomm
        return
Exemple #13
0
class OpMappraiser(Operator):
    """
    Operator which passes data to libmappraiser for map-making.
    Args:
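        params (dictionary): parameters passed to mappraiser. Must contain
            "nside"; "Lambda" and "samplerate" are also read when building
            the noise model.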
        detweights (dictionary): individual noise weights to use for each
            detector.
        pixels (str): the name of the cache object (<pixels>_<detector>)
            containing the pixel indices to use.
        pixels_nested (bool): Set to False if the pixel numbers are in
            ring ordering. Default is True.
        weights (str): the name of the cache object (<weights>_<detector>)
            containing the pointing weights to use.
        name (str): the name of the cache object (<name>_<detector>) to
            use for the detector timestream.  If None, use the TOD.
        noise_name (str) : the name of the cache object (<name>_<detector>) to
            use for the noise timestream. If None, skip.
        flag_name (str): the name of the cache object
            (<flag_name>_<detector>) to use for the detector flags.
            If None, use the TOD.
        flag_mask (int): the integer bit mask (0-255) that should be
            used with the detector flags in a bitwise AND.
        common_flag_name (str): the name of the cache object
            to use for the common flags.  If None, use the TOD.
        common_flag_mask (int): the integer bit mask (0-255) that should
            be used with the common flags in a bitwise AND.
        apply_flags (bool): whether to apply flags to the pixel numbers.
        purge (bool): if True, clear any cached data that is copied into
            the Mappraiser buffers.
        purge_tod (bool): if True, clear any cached signal that is
            copied into the Mappraiser buffers.
        purge_pixels (bool): if True, clear any cached pixels that are
            copied into the Mappraiser buffers.
        purge_weights (bool): if True, clear any cached weights that are
            copied into the Mappraiser buffers.
        purge_flags (bool): if True, clear any cached flags that are
            copied into the Mappraiser buffers.
        dets (iterable):  List of detectors to map. If left as None, all
            available detectors are mapped.
        noise (str): Keyword to use when retrieving the noise object
            from the observation.
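        intervals (str): the name of the intervals key in each observation.
        conserve_memory (bool/int): stagger the Mappraiser buffer staging
            on each node to reduce the memory high-water mark.
        translate_timestamps (bool): translate timestamps to enforce
            monotonicity.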
    """
    def __init__(
        self,
        params=None,
        detweights=None,
        pixels="pixels",
        pixels_nested=True,
        weights="weights",
        name="signal",
        noise_name=None,
        flag_name=None,
        flag_mask=255,
        common_flag_name=None,
        common_flag_mask=255,
        apply_flags=False,
        purge=False,
        dets=None,
        purge_tod=False,
        purge_pixels=False,
        purge_weights=False,
        purge_flags=False,
        noise="noise",
        intervals="intervals",
        conserve_memory=False,
        translate_timestamps=True,
    ):
        """Store the operator configuration; no data is touched here.

        See the class docstring for the meaning of each argument. ``params``
        defaults to ``None``, which is replaced by a fresh empty dict so that
        separate operators never share one mutable default.
        """
        # Call the parent class constructor
        super().__init__()

        # mappraiser uses time-based distribution
        self._name = name
        self._noise_name = noise_name
        self._flag_name = flag_name
        self._flag_mask = flag_mask
        self._common_flag_name = common_flag_name
        self._common_flag_mask = common_flag_mask
        self._pixels = pixels
        self._pixels_nested = pixels_nested
        self._weights = weights
        self._detw = detweights
        self._purge = purge
        if self._purge:
            # A global purge overrides every individual purge flag.
            self._purge_tod = True
            self._purge_pixels = True
            self._purge_weights = True
            self._purge_flags = True
        else:
            self._purge_tod = purge_tod
            self._purge_pixels = purge_pixels
            self._purge_weights = purge_weights
            self._purge_flags = purge_flags
        self._apply_flags = apply_flags
        # Avoid the shared-mutable-default pitfall: each instance gets its own dict.
        self._params = {} if params is None else params
        if dets is not None:
            self._dets = set(dets)
        else:
            self._dets = None
        self._noisekey = noise
        self._intervals = intervals
        self._cache = Cache()
        # Staging buffers, allocated from self._cache during exec().
        self._mappraiser_timestamps = None
        self._mappraiser_noise = None
        self._mappraiser_pixels = None
        self._mappraiser_pixweights = None
        self._mappraiser_signal = None
        self._mappraiser_invtt = None
        self._conserve_memory = int(conserve_memory)
        self._translate_timestamps = translate_timestamps
        self._verbose = True

    def __del__(self):
        """Release every buffer held in the private staging cache.

        ``getattr`` guards against ``__init__`` having failed before
        ``self._cache`` was assigned; in that case there is nothing to free
        and no AttributeError is raised during finalisation.
        """
        cache = getattr(self, "_cache", None)
        if cache is not None:
            cache.clear()

    @property
    def available(self):
        """
        (bool): Whether the libmappraiser bindings were imported and report
        themselves usable.
        """
        if mappraiser is None:
            return False
        return mappraiser.available

    @function_timer
    def exec(self, data, comm=None):
        """
        Stage the distributed data into Mappraiser buffers, solve for the
        maximum-likelihood map, then unstage the buffers.

        Args:
            data (toast.Data): The distributed data.
            comm (MPI.Comm): Communicator to run on. Defaults to the world
                communicator of ``data``.

        Returns:
            None

        Raises:
            RuntimeError: if libmappraiser is unavailable or ``data`` holds
                no observations.
        """
        if not self.available:
            raise RuntimeError("libmappraiser is not available")

        if len(data.obs) == 0:
            raise RuntimeError(
                "OpMappraiser requires every supplied data object to "
                "contain at least one observation")

        if comm is None:
            # Default to the world communicator of the distributed data.
            comm = data.comm.comm_world
        self._data = data
        self._comm = comm
        self._rank = comm.rank

        dims = self._prepare()
        dets, nsamp, ndet, nnz, nnz_full, nnz_stride, psdfreqs, nside = dims

        staged = self._stage_data(nsamp, ndet, nnz, nnz_full, nnz_stride,
                                  psdfreqs, dets, nside)
        (data_size_proc, nobsloc, local_blocks_sizes, signal_type,
         noise_type, pixels_dtype, weight_dtype) = staged

        # One Toeplitz block per local (observation, detector) pair.
        nb_blocks_loc = nobsloc * ndet
        self._MLmap(data_size_proc, nb_blocks_loc, local_blocks_sizes, nnz)

        self._unstage_data(nsamp, nnz, nnz_full, dets, signal_type,
                           noise_type, pixels_dtype, nside, weight_dtype)

    @function_timer
    def _MLmap(self, data_size_proc, nb_blocks_loc, local_blocks_sizes, nnz):
        """Hand the staged buffers to libmappraiser's ML map-maker.

        Args:
            data_size_proc: data sizes per process, as produced by staging.
            nb_blocks_loc (int): number of local Toeplitz blocks.
            local_blocks_sizes (ndarray): size of each local block.
            nnz (int): number of pointing-weight non-zeros per sample.
        """
        if self._verbose:
            memreport("just before calling libmappraiser.MLmap", self._comm)

        buffers = (
            self._mappraiser_pixels,
            self._mappraiser_pixweights,
            self._mappraiser_signal,
            self._mappraiser_noise,
        )
        lambda_ = self._params["Lambda"]
        mappraiser.MLmap(
            self._comm,
            self._params,
            data_size_proc,
            nb_blocks_loc,
            local_blocks_sizes,
            nnz,
            *buffers,
            lambda_,
            self._mappraiser_invtt,
        )

    def _count_samples(self):
        """Return the number of local samples summed over all observations.

        When there are several observations, each must expose exactly the
        same ``local_dets`` as the first one, otherwise RuntimeError is raised.
        """
        observations = self._data.obs
        if len(observations) == 1:
            return observations[0]["tod"].local_samples[1]

        reference_dets = observations[0]["tod"].local_dets
        total = 0
        for obs in observations:
            tod = obs["tod"]
            # For the moment, every observation must carry the same detectors.
            current_dets = tod.local_dets
            mismatch = len(reference_dets) != len(current_dets) or any(
                ref != cur for ref, cur in zip(reference_dets, current_dets))
            if mismatch:
                raise RuntimeError(
                    "When calling Mappraiser, all TOD assigned to a process "
                    "must have the same local detectors.")
            total += tod.local_samples[1]
        return total

    def _get_period_ranges(self, detectors):
        """Return the PSD frequency binning shared by all detectors.

        This routine was taken from Madam and truncated: it no longer
        collects period ranges and only extracts the PSD frequency grid.

        Args:
            detectors (list): names of the detectors to check.

        Returns:
            (ndarray or None): float64 copy of the frequency bins of the first
            detector in the first observation carrying a noise object under
            ``self._noisekey``, or None if no observation has one.

        Raises:
            RuntimeError: if any detector's PSD uses a different binning.
        """
        psdfreqs = None
        for obs in self._data.obs:
            if self._noisekey not in obs.keys():
                continue
            nse = obs[self._noisekey]
            if nse is None:
                continue
            if psdfreqs is None:
                # astype() already returns a new array.
                psdfreqs = nse.freq(detectors[0]).astype(np.float64)
            for det in detectors:
                check_psdfreqs = nse.freq(det)
                # Compare shapes first: np.allclose broadcasts and would raise
                # an unrelated ValueError (or silently broadcast a length-1
                # grid) instead of reporting the binning mismatch.
                if (np.shape(check_psdfreqs) != psdfreqs.shape
                        or not np.allclose(psdfreqs, check_psdfreqs)):
                    raise RuntimeError(
                        "All PSDs passed to Mappraiser must have"
                        " the same frequency binning.")

        return psdfreqs

    def _psd2invtt(self, psdfreqs, psd):
        """Build the first row of a Toeplitz inverse-noise block from a PSD.

        The inverse PSD is spline-interpolated onto a power-of-two FFT grid
        and transformed to an inverse noise autocorrelation. The result is
        apodized with a Gaussian window and truncated to ``Lambda`` lags.

        Args:
            psdfreqs (ndarray): frequencies at which ``psd`` is sampled.
            psd (ndarray): noise power spectral density values.

        Returns:
            (ndarray): the first ``self._params["Lambda"]`` windowed inverse
            autocorrelation coefficients, with dtype ``mappraiser.INVTT_TYPE``.
        """
        # parameters
        sampling_freq = self._params["samplerate"]
        lambda_ = self._params["Lambda"]
        f_defl = sampling_freq / (np.pi * lambda_)
        df = f_defl / 2
        block_size = 2**(int(math.log(sampling_freq * 1. / df, 2)) + 1)
        half = int(block_size / 2)

        # Invert PSD
        psd_sim_m1 = np.reciprocal(psd)

        # Initialize full size inverse PSD in frequency domain
        fs = fftfreq(block_size, 1. / sampling_freq)
        psdm1 = np.zeros_like(fs)

        # Interpolate (s=0: no smoothing) onto |f| of the first half + 1
        # FFT bins, i.e. frequencies 0 .. Nyquist.
        tck = interpolate.splrep(psdfreqs, psd_sim_m1, s=0)
        psdfit = interpolate.splev(np.abs(fs[:half + 1]), tck, der=0)
        psdfit[0] = 0  # set offset noise contribution to zero
        # FFT ordering: non-negative bins, then the mirrored negative bins.
        psdm1[:half] = psdfit[:half]
        psdm1[half:] = np.flip(psdfit[1:], 0)

        # Compute inverse noise autocorrelation functions
        inv_tt = np.real(np.fft.ifft(psdm1, n=block_size))

        # Gaussian apodization window. scipy.signal.gaussian was removed in
        # SciPy 1.13; scipy.signal.windows.gaussian is the supported location.
        window = scipy.signal.windows.gaussian(2 * lambda_, 1. / 2 * lambda_)
        window = np.fft.ifftshift(window)[:lambda_]
        window = np.pad(window, (0, int(half - lambda_)), 'constant')
        symw = np.zeros(block_size)
        symw[:half] = window
        symw[half:] = np.flip(window, 0)

        inv_tt_w = np.multiply(symw, inv_tt, dtype=mappraiser.INVTT_TYPE)

        return inv_tt_w[:lambda_]

    def _noise2invtt(self, noise, nn, idet):
        """Fit a PSD model to a noise timestream and build a Toeplitz first row.

        A Blackman-windowed periodogram of ``noise`` is fitted, in log scale,
        with ``logpsd_model``. The fitted model is inverted with
        ``inversepsd_model``, transformed to an inverse autocorrelation,
        apodized with a Gaussian window and truncated to ``Lambda`` lags.

        Args:
            noise (ndarray): noise timestream of one detector.
            nn (int): length of ``noise``.
            idet (int): detector index. Fit diagnostics are printed on rank 0
                for ``idet == 0``.

        Returns:
            (ndarray): the first ``self._params["Lambda"]`` windowed inverse
            autocorrelation coefficients, with dtype ``mappraiser.INVTT_TYPE``.
        """
        # parameters
        sampling_freq = self._params["samplerate"]
        lambda_ = self._params["Lambda"]
        # Largest power of two not exceeding a quarter of the timestream length.
        max_lambda = 2**(int(math.log(nn / 4, 2)))
        f_defl = sampling_freq / (np.pi * max_lambda)
        df = f_defl / 2
        block_size = 2**(int(math.log(sampling_freq * 1. / df, 2)))
        half = int(block_size / 2)

        # Compute periodogram
        f, psd = scipy.signal.periodogram(noise,
                                          sampling_freq,
                                          nfft=block_size,
                                          window='blackman')

        # Fit the psd model to the periodogram (in log scale), skipping DC.
        popt, pcov = curve_fit(logpsd_model,
                               f[1:],
                               np.log10(psd[1:]),
                               p0=np.array([-7, -1.0, 0.1, 0.]),
                               bounds=([-20, -10, 0., 0.], [0., 0., 10,
                                                            0.001]),
                               maxfev=1000)

        if self._rank == 0 and idet == 0:
            print(
                "\n[det " + str(idet) +
                "]: PSD fit log(sigma2) = %1.2f, alpha = %1.2f, fknee = %1.2f, fmin = %1.2f\n"
                % tuple(popt),
                flush=True)
            print("[det " + str(idet) + "]: PSD fit covariance: \n",
                  pcov,
                  flush=True)

        # Invert the fitted psd model; the DC bin is left at zero.
        psd_fit_m1 = np.zeros_like(f)
        psd_fit_m1[1:] = inversepsd_model(f[1:], 10**(-popt[0]), popt[1],
                                          popt[2], popt[3])

        # Initialize full size inverse PSD in frequency domain
        fs = fftfreq(block_size, 1. / sampling_freq)
        psdm1 = np.zeros_like(fs)

        # Symmetrize inverse PSD according to fs shape
        psdm1[:half] = psd_fit_m1[:-1]
        psdm1[half:] = np.flip(psd_fit_m1[1:], 0)

        # Compute inverse noise autocorrelation functions
        inv_tt = np.real(np.fft.ifft(psdm1, n=block_size))

        # Gaussian apodization window. scipy.signal.gaussian was removed in
        # SciPy 1.13; scipy.signal.windows.gaussian is the supported location.
        window = scipy.signal.windows.gaussian(2 * lambda_, 1. / 2 * lambda_)
        window = np.fft.ifftshift(window)[:lambda_]
        window = np.pad(window, (0, int(half - lambda_)), 'constant')
        symw = np.zeros(block_size)
        symw[:half] = window
        symw[half:] = np.flip(window, 0)

        inv_tt_w = np.multiply(symw, inv_tt, dtype=mappraiser.INVTT_TYPE)

        return inv_tt_w[:lambda_]

    @function_timer
    def _prepare(self):
        """Examine the data object and collect the dataset dimensions.

        The detectors and the number of pointing-matrix non-zeros are taken
        from the first observation. Mappraiser expects them to remain
        unchanged across observations.

        Returns:
            (tuple): ``(dets, nsamp, ndet, nnz, nnz_full, nnz_stride,
            psdfreqs, nside)``.

        Raises:
            RuntimeError: if no detector is selected, if the pointing weights
                do not have exactly 3 non-zeros, or if "nside" is missing
                from the parameter dictionary.
        """
        timer = Timer()
        timer.start()

        nsamp = self._count_samples()

        tod = self._data.obs[0]["tod"]

        if self._dets is None:
            dets = tod.local_dets
        else:
            dets = [det for det in tod.local_dets if det in self._dets]
        ndet = len(dets)
        if ndet == 0:
            # Fail with a clear message instead of an IndexError on dets[0].
            raise RuntimeError(
                "OpMappraiser: no detectors selected for map-making")

        # Number of non-zero pointing weights per pixel, read from the shape
        # of the first detector's cached weights.
        nnzname = "{}_{}".format(self._weights, dets[0])
        nnz_full = tod.cache.reference(nnzname).shape[1]

        if nnz_full != 3:
            raise RuntimeError("OpMappraiser: Don't know how to make a map "
                               "with nnz={}".format(nnz_full))
        nnz = nnz_full
        nnz_stride = 1

        if "nside" not in self._params:
            raise RuntimeError(
                'OpMappraiser: "nside" must be set in the parameter dictionary'
            )
        nside = int(self._params["nside"])

        # Inherited from OpMadam: for now this only returns the frequency
        # binning of the PSDs.
        psdfreqs = self._get_period_ranges(dets)

        self._comm.Barrier()
        if self._rank == 0 and self._verbose:
            timer.report_clear("Collect dataset dimensions")

        return (
            dets,
            nsamp,
            ndet,
            nnz,
            nnz_full,
            nnz_stride,
            psdfreqs,
            nside,
        )

    @function_timer
    def _stage_time(self, detectors, nsamp, psdfreqs):
        """Build the Toeplitz-block inputs (first rows) from the noise PSDs.

        N.B: timestamps are not currently used in MAPPRAISER (this may change
        in the future), so none are staged and ``nsamp`` is unused. It is kept
        for interface compatibility.

        Args:
            detectors (list): detector names.
            nsamp (int): unused.
            psdfreqs (ndarray): common PSD frequency binning.

        Returns:
            (list of ndarray): one inverse-noise first row per (observation,
            detector) pair, observation-major. Only observations carrying a
            non-None noise object under ``self._noisekey`` contribute.
        """
        invtt_list = []
        for obs in self._data.obs:
            if self._noisekey not in obs.keys():
                continue
            nse = obs[self._noisekey]
            if nse is None:
                continue
            noise_scale = obs["noise_scale"] if "noise_scale" in obs else 1
            for det in detectors:
                psd = nse.psd(det) * noise_scale**2
                # Previously computed and discarded, so the list was always empty.
                invtt_list.append(self._psd2invtt(psdfreqs, psd))

        return invtt_list

    @function_timer
    def _stage_signal(self, detectors, nsamp, ndet, nodecomm, nread):
        """Copy the detector signal into the Mappraiser signal buffer.

        The buffer holds ``nsamp * ndet`` samples laid out detector-major.
        Detector ``idet`` starts at ``idet * nsamp``, with the observations
        concatenated after that offset. When the cached signal is purged,
        staging is staggered over ``nread`` groups of ``nodecomm``. Otherwise
        every process stages at once.

        Returns:
            (tuple): ``(signal_dtype, local_blocks_sizes)``. The second item is
            an int32 array with one entry per (observation, detector) pair,
            appended in observation-major order.
        """
        # NOTE(review): the block sizes are appended observation-major while
        # the buffer is detector-major. They agree trivially for a single
        # observation or equal-length observations; confirm the ordering
        # libmappraiser expects for the general case.
        timer = Timer()
        # Staggering only pays off when the cached copy is purged, since only
        # then is a second copy of the signal avoided.
        purge = self._name is not None and self._purge_tod
        if not purge:
            nread = 1
            nodecomm = MPI.COMM_SELF

        for iread in range(nread):
            nodecomm.Barrier()
            timer.start()
            if nodecomm.rank % nread == iread:
                self._mappraiser_signal = self._cache.create(
                    "signal", mappraiser.SIGNAL_TYPE, (nsamp * ndet, ))
                self._mappraiser_signal[:] = np.nan

                global_offset = 0
                local_blocks_sizes = []
                for obs in self._data.obs:
                    tod = obs["tod"]

                    for idet, det in enumerate(detectors):
                        # Get the signal.
                        signal = tod.local_signal(det, self._name)
                        signal_dtype = signal.dtype
                        offset = global_offset
                        local_V_size = len(signal)
                        dslice = slice(idet * nsamp + offset,
                                       idet * nsamp + offset + local_V_size)
                        self._mappraiser_signal[dslice] = signal
                        offset += local_V_size
                        local_blocks_sizes.append(local_V_size)

                        del signal
                    # Purge only after all detectors are staged in case some
                    # are aliased. cache.clear() will not fail if the object
                    # was already deleted as an alias.
                    if purge:
                        for det in detectors:
                            cachename = "{}_{}".format(self._name, det)
                            tod.cache.clear(cachename)
                    global_offset = offset

                local_blocks_sizes = np.array(local_blocks_sizes,
                                              dtype=np.int32)
            if self._verbose and nread > 1:
                nodecomm.Barrier()
                if self._rank == 0:
                    timer.report_clear("Stage signal {} / {}".format(
                        iread + 1, nread))

        return signal_dtype, local_blocks_sizes

    @function_timer
    def _stage_noise(self, detectors, nsamp, ndet, nodecomm, nread):
        """Stage the noise timestream (detector noise + atmosphere).

        The noise is copied into a ``nsamp * ndet`` buffer with the same
        detector-major layout as the signal. A first Toeplitz row is fitted
        per (observation, detector) pair with ``_noise2invtt``.

        If ``noise_name`` is None, the buffer is zero-filled and every first
        row is ``np.ones(1)``. This must be used with Lambda = 1.

        Returns:
            (tuple): ``(invtt_list, noise_dtype)``, with ``invtt_list`` in
            observation-major order.
        """
        timer = Timer()
        # Determine if we can purge the noise and avoid keeping two
        # copies in memory
        purge = self._noise_name is not None and self._purge_tod
        if not purge:
            nread = 1
            nodecomm = MPI.COMM_SELF

        for iread in range(nread):
            nodecomm.Barrier()
            timer.start()
            if nodecomm.rank % nread == iread:
                self._mappraiser_noise = self._cache.create(
                    "noise", mappraiser.SIGNAL_TYPE, (nsamp * ndet, ))
                if self._noise_name is None:
                    # Zero the cached buffer in place. Rebinding to a new
                    # array would orphan the cache allocation and double the
                    # memory footprint.
                    self._mappraiser_noise[:] = 0
                    invtt_list = [
                        np.ones(1)  # Must be used with lambda = 1
                        for _ in range(len(self._data.obs) * len(detectors))
                    ]
                    return invtt_list, self._mappraiser_noise.dtype

                self._mappraiser_noise[:] = np.nan

                global_offset = 0
                invtt_list = []
                for obs in self._data.obs:
                    tod = obs["tod"]

                    for idet, det in enumerate(detectors):
                        # Get the noise timestream.
                        noise = tod.local_signal(det, self._noise_name)
                        noise_dtype = noise.dtype
                        offset = global_offset
                        nn = len(noise)
                        invtt_list.append(self._noise2invtt(noise, nn, idet))
                        dslice = slice(idet * nsamp + offset,
                                       idet * nsamp + offset + nn)
                        self._mappraiser_noise[dslice] = noise
                        offset += nn

                        del noise
                    # Purge only after all detectors are staged in case some
                    # are aliased. cache.clear() will not fail if the object
                    # was already deleted as an alias.
                    if purge:
                        for det in detectors:
                            cachename = "{}_{}".format(self._noise_name, det)
                            tod.cache.clear(cachename)
                    global_offset = offset
            if self._verbose and nread > 1:
                nodecomm.Barrier()
                if self._rank == 0:
                    timer.report_clear("Stage noise {} / {}".format(
                        iread + 1, nread))

        return invtt_list, noise_dtype

    @function_timer
    def _stage_pixels(self, detectors, nsamp, ndet, nnz, nside):
        """Stage the pixel indices, expanded to one entry per weight non-zero.

        The pixel ``p`` of every sample becomes the ``nnz`` consecutive entries
        ``nnz * p + k`` for ``k = 0 .. nnz - 1``. They are written to a
        ``nsamp * ndet * nnz`` buffer laid out detector-major, and entries
        never written keep the value -1.

        Ring-ordered input is converted to nested ordering. When
        ``apply_flags`` is set, flagged samples get pixel -1.

        The cached pixels are always purged. They are meant to be restored
        from the Mappraiser buffers when ``purge_pixels`` is False.

        Returns:
            (dtype): dtype of the cached pixel arrays.
        """
        self._mappraiser_pixels = self._cache.create("pixels",
                                                     mappraiser.PIXEL_TYPE,
                                                     (nsamp * ndet * nnz, ))
        self._mappraiser_pixels[:] = -1
        # Per-sample offsets 0 .. nnz-1 added to nnz * pixel.
        component = np.arange(nnz)

        global_offset = 0
        for obs in self._data.obs:
            tod = obs["tod"]

            # N.B: MAPPRAISER doesn't use flags for now, but this might be
            # useful for future updates. Without apply_flags, the flags are
            # assumed to have been applied to the pixel numbers already.
            # The common flags do not depend on the detector: read them once.
            commonflags = None
            common_bad = None
            if self._apply_flags:
                commonflags = tod.local_common_flags(self._common_flag_name)
                common_bad = (commonflags & self._common_flag_mask) != 0

            for idet, det in enumerate(detectors):
                pixelsname = "{}_{}".format(self._pixels, det)
                pixels = tod.cache.reference(pixelsname)
                pixels_dtype = pixels.dtype

                if not self._pixels_nested or self._apply_flags:
                    # Work on a single copy so the cached pixels stay intact.
                    pixels = pixels.copy()

                if not self._pixels_nested:
                    # Madam expects the pixels to be in nested ordering.
                    # This is not the case for Mappraiser but keeping it for now
                    good = pixels >= 0
                    pixels[good] = hp.ring2nest(nside, pixels[good])

                if self._apply_flags:
                    detflags = tod.local_flags(det, self._flag_name)
                    flags = np.logical_or(
                        (detflags & self._flag_mask) != 0, common_bad)
                    del detflags
                    pixels[flags] = -1
                    del flags

                offset = global_offset
                nn = len(pixels)
                dslice = slice(
                    (idet * nsamp + offset) * nnz,
                    (idet * nsamp + offset + nn) * nnz,
                )
                # Expand each pixel p to nnz * p + [0, 1, ..., nnz - 1]; this
                # matches the former hard-coded nnz = 3 expansion exactly.
                self._mappraiser_pixels[dslice] = np.ravel(
                    nnz * np.reshape(pixels, (-1, 1)) + component)
                offset += nn

                del pixels

            # Always purge the pixels but restore them from the Mappraiser
            # buffers when purge_pixels = False.
            # Purging MUST happen after all detectors are staged because
            # some of the pixel numbers may be aliased between detectors.
            for det in detectors:
                pixelsname = "{}_{}".format(self._pixels, det)
                # cache.clear() will not fail if the object was already
                # deleted as an alias
                tod.cache.clear(pixelsname)
                if self._purge_flags and self._flag_name is not None:
                    cacheflagname = "{}_{}".format(self._flag_name, det)
                    tod.cache.clear(cacheflagname)

            del commonflags, common_bad
            if self._purge_flags and self._common_flag_name is not None:
                tod.cache.clear(self._common_flag_name)
            global_offset = offset
        return pixels_dtype

    @function_timer
    def _stage_pixweights(
        self,
        detectors,
        nsamp,
        ndet,
        nnz,
        nnz_full,
        nnz_stride,
        nodecomm,
        nread,
    ):
        """Copy the pointing weights into the Mappraiser weight buffer.

        The buffer holds ``nsamp * ndet * nnz`` values in the same
        detector-major layout as the pixel buffer, with ``nnz`` consecutive
        weights per sample. The cached weights are purged when
        ``purge_weights`` is set or when ``nnz == nnz_full``. When
        ``purge_weights`` is False they are meant to be restored from the
        Mappraiser buffers afterwards.

        Purging is also what allows staging to be staggered over ``nread``
        groups of ``nodecomm``.

        Returns:
            (dtype): dtype of the cached weight arrays.
        """
        timer = Timer()
        # Determine if we can purge the pixel weights and avoid keeping two
        # copies of the weights in memory
        purge = self._purge_weights or (nnz == nnz_full)
        if not purge:
            nread = 1
            nodecomm = MPI.COMM_SELF
        for iread in range(nread):
            nodecomm.Barrier()
            timer.start()
            if nodecomm.rank % nread == iread:
                self._mappraiser_pixweights = self._cache.create(
                    "pixweights", mappraiser.WEIGHT_TYPE,
                    (nsamp * ndet * nnz, ))
                self._mappraiser_pixweights[:] = 0

                global_offset = 0
                for obs in self._data.obs:
                    tod = obs["tod"]

                    for idet, det in enumerate(detectors):
                        # get the weights for the valid intervals from the cache
                        weightsname = "{}_{}".format(self._weights, det)
                        weights = tod.cache.reference(weightsname)
                        weight_dtype = weights.dtype
                        offset = global_offset
                        nn = len(weights)
                        dwslice = slice(
                            (idet * nsamp + offset) * nnz,
                            (idet * nsamp + offset + nn) * nnz,
                        )
                        # np.ravel avoids the full temporary copy that
                        # flatten() made before the strided slice.
                        self._mappraiser_pixweights[dwslice] = np.ravel(
                            weights)[::nnz_stride]
                        offset += nn
                        del weights
                    # Purge the weights but restore them from the Mappraiser
                    # buffers when purge_weights=False.
                    if purge:
                        for det in detectors:
                            weightsname = "{}_{}".format(self._weights, det)
                            tod.cache.clear(pattern=weightsname)

                    global_offset = offset
            if self._verbose and nread > 1:
                nodecomm.Barrier()
                if self._rank == 0:
                    timer.report_clear("Stage pixel weights {} / {}".format(
                        iread + 1, nread))
        return weight_dtype

    @function_timer
    def _stage_data(
        self,
        nsamp,
        ndet,
        nnz,
        nnz_full,
        nnz_stride,
        psdfreqs,
        detectors,
        nside,
    ):
        """Create the Mappraiser-compatible buffers.

        Collect the signal, noise, pixels and pixel weights into Mappraiser
        buffers.  The pixel weights are processed separately from the rest
        to reduce the memory high-water mark when the user has set
        purge=True.

        Moving data between TOAST and Mappraiser buffers has an overhead.
        When ``conserve_memory`` is set, the copies are staggered across the
        processes of each node to avoid exhausting memory.

        Args:
            nsamp (int): Number of local samples per detector.
            ndet (int): Number of detectors.
            nnz (int): Number of weight components kept per sample.
            nnz_full (int): Number of weight components in the TOAST cache.
            nnz_stride (int): Stride used to subsample the TOAST weights.
            psdfreqs: Currently unused (time-domain staging is skipped).
            detectors (list): Detector names, in buffer order.
            nside (int): Healpix resolution forwarded to pixel staging.

        Returns:
            tuple: ``(data_size_proc, nobsloc, local_blocks_sizes,
            signal_dtype, noise_dtype, pixels_dtype, weight_dtype)`` where
            ``data_size_proc`` is the int32 array of signal-buffer lengths
            gathered over the full communicator and ``nobsloc`` is the
            number of local observations.
        """
        nodecomm = self._comm.Split_type(MPI.COMM_TYPE_SHARED, self._rank)
        # Check if the user has elected to stagger staging the data on each
        # node to avoid exhausting memory
        if self._conserve_memory:
            if self._conserve_memory == 1:
                nread = nodecomm.size
            else:
                nread = min(self._conserve_memory, nodecomm.size)
        else:
            nread = 1

        self._comm.Barrier()
        timer_tot = Timer()
        timer_tot.start()
        timer = Timer()

        # NOTE: time-domain staging (timestamps / Toeplitz blocks built from
        # the TOAST PSDs) is intentionally skipped; a PSD fit is done when
        # staging the noise instead.

        # Stage signal.  If signal is not being purged, staging is not stepped
        timer.start()
        signal_dtype, local_blocks_sizes = self._stage_signal(
            detectors, nsamp, ndet, nodecomm, nread)
        if self._verbose:
            nodecomm.Barrier()
            if self._rank == 0:
                timer.report_clear("Stage signal")
        memreport("after staging signal", self._comm)  # DEBUG
        count_caches(self._data, self._comm, nodecomm, self._cache,
                     "after staging signal")  # DEBUG

        # Stage noise.  If noise is not being purged, staging is not stepped
        timer.start()
        invtt_list, noise_dtype = self._stage_noise(detectors, nsamp, ndet,
                                                    nodecomm, nread)
        # Concatenate the list directly: packing blocks of different lengths
        # into one numpy array first is rejected by recent numpy versions.
        if len(invtt_list) > 0:
            self._mappraiser_invtt = np.concatenate([
                np.asarray(invtt_i, dtype=mappraiser.INVTT_TYPE)
                for invtt_i in invtt_list
            ])
        else:
            self._mappraiser_invtt = np.zeros(0, dtype=mappraiser.INVTT_TYPE)
        del invtt_list
        if self._params.get("uniform_w", 0) == 1:
            # Uniform weighting requested: replace the noise weights by ones
            self._mappraiser_invtt = np.ones_like(self._mappraiser_invtt)
        if self._verbose:
            nodecomm.Barrier()
            if self._rank == 0:
                timer.report_clear("Stage noise")
        memreport("after staging noise", self._comm)  # DEBUG
        count_caches(self._data, self._comm, nodecomm, self._cache,
                     "after staging noise")  # DEBUG

        # Stage pixels, one staggered group of node-local processes at a time
        pixels_dtype = None
        timer_step = Timer()
        timer_step.start()
        for iread in range(nread):
            nodecomm.Barrier()
            timer.start()
            if nodecomm.rank % nread == iread:
                pixels_dtype = self._stage_pixels(detectors, nsamp, ndet, nnz,
                                                  nside)
            if self._verbose and nread > 1:
                nodecomm.Barrier()
                if self._rank == 0:
                    timer.report_clear("Stage pixels {} / {}".format(
                        iread + 1, nread))
        if self._verbose:
            nodecomm.Barrier()
            if self._rank == 0:
                timer_step.report_clear("Stage pixels")
        memreport("after staging pixels", self._comm)  # DEBUG
        count_caches(self._data, self._comm, nodecomm, self._cache,
                     "after staging pixels")  # DEBUG

        # Stage pixel weights
        timer_step.start()
        weight_dtype = self._stage_pixweights(
            detectors,
            nsamp,
            ndet,
            nnz,
            nnz_full,
            nnz_stride,
            nodecomm,
            nread,
        )
        if self._verbose:
            nodecomm.Barrier()
            if self._rank == 0:
                timer_step.report_clear("Stage pixel weights")
        memreport("after staging pixel weights", self._comm)  # DEBUG
        count_caches(self._data, self._comm, nodecomm, self._cache,
                     "after staging pixel weights")  # DEBUG

        # Communicators created by Split_type must be released explicitly;
        # every rank of nodecomm reaches this collective call.
        nodecomm.Free()
        del nodecomm
        if self._rank == 0 and self._verbose:
            timer_tot.report_clear("Stage all data")

        # NOTE: the detector weights (self._detw) are not forwarded to
        # Mappraiser by this method.

        # Get global array of data sizes of the full communicator
        data_size_proc = np.array(
            self._comm.allgather(len(self._mappraiser_signal)),
            dtype=np.int32,
        )
        # Get number of local observations
        nobsloc = len(self._data.obs)

        return (data_size_proc, nobsloc, local_blocks_sizes, signal_dtype,
                noise_dtype, pixels_dtype, weight_dtype)

    @function_timer
    def _unstage_signal(self, detectors, nsamp, signal_type):
        """Release the staged Mappraiser signal buffer.

        Writing the signal back into the TOD caches after map-making is not
        supported for now, so ``detectors``, ``nsamp`` and ``signal_type``
        are accepted only for interface symmetry and are not used.
        """
        # Drop our handle before asking the cache to destroy the buffer.
        self._mappraiser_signal = None
        self._cache.destroy("signal")
        return None

    @function_timer
    def _unstage_noise(self, detectors, nsamp, noise_type):
        """Release the staged Mappraiser noise buffer.

        Copying the noise back into the TOD caches is not supported for now;
        ``detectors``, ``nsamp`` and ``noise_type`` are therefore unused and
        only the Mappraiser-side buffer is dropped.
        """
        # Drop our handle before asking the cache to destroy the buffer.
        self._mappraiser_noise = None
        self._cache.destroy("noise")
        return None

    @function_timer
    def _unstage_pixels(self, detectors, nsamp, pixels_dtype, nside):
        """Release the staged Mappraiser pixel-index buffer.

        Restoring the pixel indices into the TOD caches (including any
        nested/ring conversion) is not done for now, so ``detectors``,
        ``nsamp``, ``pixels_dtype`` and ``nside`` are unused.
        """
        # Drop our handle before asking the cache to destroy the buffer.
        self._mappraiser_pixels = None
        self._cache.destroy("pixels")
        return None

    @function_timer
    def _unstage_pixweights(self, detectors, nsamp, weight_dtype, nnz,
                            nnz_full):
        """Release the staged Mappraiser pixel-weight buffer.

        The weights are not copied back into the TOD caches for now, so
        ``detectors``, ``nsamp``, ``weight_dtype``, ``nnz`` and ``nnz_full``
        are accepted only for interface symmetry and are not used.
        """
        # Drop our handle before asking the cache to destroy the buffer.
        self._mappraiser_pixweights = None
        self._cache.destroy("pixweights")
        return None

    def _unstage_data(
        self,
        nsamp,
        nnz,
        nnz_full,
        detectors,
        signal_type,
        noise_type,
        pixels_dtype,
        nside,
        weight_dtype,
    ):
        """Clear the Mappraiser buffers.

        Restoring pointing or signal into the TOAST caches is not done
        currently; this releases the signal, noise, pixel and pixel-weight
        buffers.  When ``conserve_memory`` is set, the release is staggered
        across the processes of each node using the same group count as
        ``_stage_data``.

        Args:
            nsamp (int): Number of local samples per detector.
            nnz (int): Number of weight components per sample.
            nnz_full (int): Number of weight components in the TOAST cache.
            detectors (list): Detector names.
            signal_type: dtype of the staged signal.
            noise_type: dtype of the staged noise.
            pixels_dtype: dtype of the staged pixel indices.
            nside (int): Healpix resolution.
            weight_dtype: dtype of the staged pixel weights.
        """
        if self._conserve_memory:
            nodecomm = self._comm.Split_type(MPI.COMM_TYPE_SHARED, self._rank)
            # Same staggering rule as in _stage_data
            if self._conserve_memory == 1:
                nread = nodecomm.size
            else:
                nread = min(self._conserve_memory, nodecomm.size)
        else:
            nodecomm = MPI.COMM_SELF
            nread = 1

        self._comm.Barrier()
        timer_tot = Timer()
        timer_tot.start()
        for iread in range(nread):
            # Only one group of node-local processes releases per iteration
            my_turn = nodecomm.rank % nread == iread
            timer_step = Timer()
            timer_step.start()
            timer = Timer()
            timer.start()
            if my_turn:
                self._unstage_signal(detectors, nsamp, signal_type)
            if self._verbose:
                nodecomm.Barrier()
                if self._rank == 0:
                    timer.report_clear("Unstage signal {} / {}".format(
                        iread + 1, nread))
            if my_turn:
                self._unstage_noise(detectors, nsamp, noise_type)
            if self._verbose:
                nodecomm.Barrier()
                if self._rank == 0:
                    timer.report_clear("Unstage noise {} / {}".format(
                        iread + 1, nread))
            if my_turn:
                self._unstage_pixels(detectors, nsamp, pixels_dtype, nside)
            if self._verbose:
                nodecomm.Barrier()
                if self._rank == 0:
                    timer.report_clear("Unstage pixels {} / {}".format(
                        iread + 1, nread))
            if my_turn:
                self._unstage_pixweights(detectors, nsamp, weight_dtype, nnz,
                                         nnz_full)
            nodecomm.Barrier()
            if self._verbose and self._rank == 0:
                timer.report_clear("Unstage pixel weights {} / {}".format(
                    iread + 1, nread))
            if self._rank == 0 and self._verbose and nread > 1:
                timer_step.report_clear("Unstage data {} / {}".format(
                    iread + 1, nread))
        self._comm.Barrier()
        if self._rank == 0 and self._verbose:
            timer_tot.report_clear("Unstage all data")

        # Release the communicator we created; COMM_SELF must never be freed.
        if self._conserve_memory:
            nodecomm.Free()
        del nodecomm
        return
Exemple #14
0
    def __init__(
        self,
        params={},
        detweights=None,
        pixels="pixels",
        pixels_nested=True,
        weights="weights",
        name="signal",
        noise_name=None,
        flag_name=None,
        flag_mask=255,
        common_flag_name=None,
        common_flag_mask=255,
        apply_flags=False,
        purge=False,
        dets=None,
        purge_tod=False,
        purge_pixels=False,
        purge_weights=False,
        purge_flags=False,
        noise="noise",
        intervals="intervals",
        conserve_memory=False,
        translate_timestamps=True,
    ):
        # Call the parent class constructor
        super().__init__()

        # mappraiser uses time-based distribution
        self._name = name
        self._noise_name = noise_name
        self._flag_name = flag_name
        self._flag_mask = flag_mask
        self._common_flag_name = common_flag_name
        self._common_flag_mask = common_flag_mask
        self._pixels = pixels
        self._pixels_nested = pixels_nested
        self._weights = weights
        self._detw = detweights
        self._purge = purge
        if self._purge:
            self._purge_tod = True
            self._purge_pixels = True
            self._purge_weights = True
            self._purge_flags = True
        else:
            self._purge_tod = purge_tod
            self._purge_pixels = purge_pixels
            self._purge_weights = purge_weights
            self._purge_flags = purge_flags
        self._apply_flags = apply_flags
        self._params = params
        if dets is not None:
            self._dets = set(dets)
        else:
            self._dets = None
        self._noisekey = noise
        self._intervals = intervals
        self._cache = Cache()
        self._mappraiser_timestamps = None
        self._mappraiser_noise = None
        self._mappraiser_pixels = None
        self._mappraiser_pixweights = None
        self._mappraiser_signal = None
        self._mappraiser_invtt = None
        self._conserve_memory = int(conserve_memory)
        self._translate_timestamps = translate_timestamps
        self._verbose = True