Exemple #1
0
def test_lognormal_invariance(comm):
    """
    Check that a LogNormalCatalog does not depend on the communicator.

    One catalog is built after ``CurrentMPIComm.set(comm)`` without an
    explicit communicator, the other on ``MPI.COMM_SELF`` with otherwise
    identical settings. The columns gathered from ``comm`` must match the
    single-rank catalog.
    """
    cosmo = cosmology.Planck15
    CurrentMPIComm.set(comm)

    Plin = cosmology.LinearPower(cosmo, redshift=0.55, transfer='EisensteinHu')

    # settings shared by both catalogs; only the communicator differs
    settings = dict(Plin=Plin, nbar=0.5e-2, BoxSize=128., Nmesh=32, seed=42)
    source = LogNormalCatalog(**settings)
    source1 = LogNormalCatalog(comm=MPI.COMM_SELF, **settings)

    # the collective size on ``comm`` equals the size of the reference
    assert source.csize == source1.size

    # gather each column with root=Ellipsis and compare with the reference
    for column in ('Position', 'Velocity'):
        gathered = GatherArray(source[column].compute(),
                               root=Ellipsis,
                               comm=comm)
        assert_allclose(gathered, source1[column])
Exemple #2
0
def test_gather_array(comm):
    """
    GatherArray must concatenate structured and plain arrays of all ranks.

    Every rank contributes 10 ones. With ``root=0`` only rank 0 receives
    the result and the other ranks get ``None``; with ``root=Ellipsis``
    the result is checked on every rank.
    """
    CurrentMPIComm.set(comm)

    structured = numpy.ones(10, dtype=[('test', 'f8')])
    plain = numpy.ones(10, dtype='f8')
    total = 10 * comm.size
    on_root = comm.rank == 0

    # structured array, gathered to rank 0 only
    gathered = GatherArray(structured, comm, root=0)
    if not on_root:
        assert gathered is None
    else:
        numpy.testing.assert_array_equal(gathered['test'], 1)
        assert len(gathered) == total

    # plain array, gathered to rank 0 only
    gathered = GatherArray(plain, comm, root=0)
    if not on_root:
        assert gathered is None
    else:
        numpy.testing.assert_array_equal(gathered, 1)
        assert len(gathered) == total

    # plain array, result checked on all ranks
    gathered = GatherArray(plain, comm, root=Ellipsis)
    numpy.testing.assert_array_equal(gathered, 1)
    assert len(gathered) == total

    # structured array, result checked on all ranks
    gathered = GatherArray(structured, comm, root=Ellipsis)
    numpy.testing.assert_array_equal(gathered['test'], 1)
    assert len(gathered) == total
Exemple #3
0
def test_gather_list(comm):
    """GatherArray must raise ValueError when handed a plain list."""
    records = numpy.ones(10, dtype=[('a', 'f')])

    # a list is rejected both for a single root and for root=Ellipsis
    for root in (0, Ellipsis):
        with pytest.raises(ValueError):
            GatherArray(list(records), comm, root=root)
Exemple #4
0
def test_gather_objects(comm):
    """GatherArray must raise ValueError for object-dtype arrays."""
    CurrentMPIComm.set(comm)

    # an object field inside a structured array, and a plain object array
    object_arrays = (
        numpy.ones(10, dtype=[('test', 'O')]),
        numpy.ones(10, dtype='O'),
    )

    for array in object_arrays:
        with pytest.raises(ValueError):
            GatherArray(array, comm, root=0)
Exemple #5
0
def test_gather_bad_dtype(comm):
    """GatherArray must raise ValueError when field dtypes differ by rank."""
    # rank 0 holds 'f4' data, every other rank holds 'f8' data
    field_type = 'f4' if comm.rank == 0 else 'f8'
    data = numpy.ones(10, dtype=[('a', field_type)])

    # dtype mismatch, for a single root and for root=Ellipsis
    for root in (0, Ellipsis):
        with pytest.raises(ValueError):
            GatherArray(data, comm, root=root)
Exemple #6
0
def test_gather_bad_shape(comm):
    """GatherArray must raise ValueError when trailing shapes differ by rank."""
    # rank 0 holds 2 columns, every other rank holds 3 columns
    ncols = 2 if comm.rank == 0 else 3
    data = numpy.ones((10, ncols))

    # shape mismatch, for a single root and for root=Ellipsis
    for root in (0, Ellipsis):
        with pytest.raises(ValueError):
            GatherArray(data, comm, root=root)
Exemple #7
0
def test_gather_bad_data(comm):
    """GatherArray must raise ValueError when field names differ by rank."""
    CurrentMPIComm.set(comm)

    # rank 0 names its field 'a', every other rank names it 'b'
    field_name = 'a' if comm.rank == 0 else 'b'
    data = numpy.ones(10, dtype=[(field_name, 'f')])

    # fields mismatch, for a single root and for root=Ellipsis
    for root in (0, Ellipsis):
        with pytest.raises(ValueError):
            GatherArray(data, comm, root=root)
Exemple #8
0
def test_fibercolls(comm):
    """
    Validate the FiberCollisions labels against brute-force pair distances.

    Random (ra, dec) points are drawn on rank 0 and scattered over ``comm``.
    After running FiberCollisions, the 'Collided' flags and the positions
    are gathered to rank 0, where all pairwise Euclidean distances are
    computed and compared with the collision radius.
    """
    from scipy.spatial.distance import pdist, squareform
    from nbodykit.utils import ScatterArray, GatherArray

    CurrentMPIComm.set(comm)
    N = 10000

    # generate the initial data (only the root draws random numbers)
    numpy.random.seed(42)
    if comm.rank == 0:
        ra = 10. * numpy.random.random(size=N)
        dec = 5. * numpy.random.random(size=N) - 5.0
    else:
        ra = None
        dec = None

    ra = ScatterArray(ra, comm)
    dec = ScatterArray(dec, comm)

    # compute the fiber collisions
    r = FiberCollisions(ra, dec, degrees=True, seed=42)
    rad = r._collision_radius_rad

    # gather collided flags and positions to root
    idx = GatherArray(r.labels['Collided'].compute().astype(bool),
                      comm,
                      root=0)
    pos = GatherArray(r.source['Position'].compute(), comm, root=0)

    # manually compute distances and check on root
    if comm.rank == 0:
        dists = squareform(pdist(pos, metric='euclidean'))
        numpy.fill_diagonal(dists, numpy.inf)  # ignore self pairs

        # no objects in clean sample (Collided==0) should be within
        # the collision radius of any other objects in the sample.
        # NOTE: ``dists[~idx, ~idx]`` would pair the two masks element-wise
        # and return only diagonal entries (all inf), making the check
        # vacuous; numpy.ix_ selects the full clean-by-clean sub-matrix.
        clean = ~idx
        clean_dists = dists[numpy.ix_(clean, clean)]
        assert (clean_dists <= rad).sum(
        ) == 0, "objects in 'clean' sample within collision radius!"

        # every collided object must be within the collision radius of at
        # least one other object (checked against the full sample)
        ncolls_per = (dists[idx] <= rad).sum(axis=-1)
        assert (ncolls_per >= 1).all(
        ), "objects in 'collided' sample that do not collide with any objects!"
Exemple #9
0
    def gslice(self, start, stop, end=1, redistribute=True):
        """
        Execute a global slice of a CatalogSource.

        .. note::
            After the global slice is performed, the data is scattered
            evenly across all ranks, unless ``redistribute`` is ``False``.

        Parameters
        ----------
        start : int
            the start index of the global slice
        stop : int
            the stop index of the global slice
        end : int, optional
            the step size of the global slice (default is 1); despite its
            name, this value is used as the ``step`` of the slice
        redistribute : bool, optional
            if ``True``, evenly re-distribute the sliced data across all
            ranks, otherwise just return any local data part of the global
            slice

        Returns
        -------
        CatalogSource
            ``self[index]`` for the local part of the slice if
            ``redistribute`` is ``False``; otherwise a new object of the
            same class built from the gathered and re-scattered columns
        """
        from nbodykit.utils import ScatterArray, GatherArray

        # the boolean index of the global slice is only built on the root:
        # ScatterArray sends data out from ``root``, so the other ranks do
        # not need a copy of the full-length index
        if self.comm.rank == 0:
            index = numpy.zeros(self.csize, dtype=bool)
            index[slice(start, stop, end)] = True
        else:
            index = None

        # scatter the index so each rank gets the part for its local rows
        counts = self.comm.allgather(self.size)
        index = ScatterArray(index, self.comm, root=0, counts=counts)

        # perform the needed local slice
        subset = self[index]

        # if we don't want to redistribute evenly, just return the slice
        if not redistribute:
            return subset

        # re-distribute each column from the sliced data
        # NOTE: currently Gather/ScatterArray requires numpy arrays, but
        # in principle we could pass dask arrays around between ranks and
        # avoid compute() calls
        data = self.compute(*[subset[col] for col in subset])

        # gather/scatter each column
        evendata = {}
        for i, col in enumerate(subset):
            alldata = GatherArray(data[i], self.comm, root=0)
            evendata[col] = ScatterArray(alldata, self.comm, root=0)

        # return a new CatalogSource holding the evenly distributed data;
        # the local size is taken from the last column processed above
        size = len(evendata[col])
        toret = self.__class__._from_columns(size, self.comm, **evendata)
        return toret.__finalize__(self)
Exemple #10
0
def test_gather_list(comm):
    """GatherArray to a single root must raise ValueError for a list."""
    CurrentMPIComm.set(comm)

    records = numpy.ones(10, dtype=[('a', 'f')])

    # a plain list is not an acceptable input
    with pytest.raises(ValueError):
        GatherArray(list(records), comm, root=0)
Exemple #11
0
def test_uniform_invariant(comm):
    """
    Check that a UniformCatalog does not depend on the communicator.

    The same catalog is created on ``comm`` and on ``MPI.COMM_SELF``; the
    columns gathered from ``comm`` must equal the single-rank columns.
    """
    # settings shared by both catalogs; only the communicator differs
    settings = dict(nbar=2, BoxSize=100., seed=1234, dtype='f4')
    cat = UniformCatalog(comm=comm, **settings)
    cat1 = UniformCatalog(comm=MPI.COMM_SELF, **settings)

    # gather each column with root=Ellipsis and compare with the reference
    for column in ('Position', 'Velocity'):
        gathered = GatherArray(cat[column].compute(),
                               root=Ellipsis,
                               comm=comm)
        assert_array_equal(gathered, cat1[column])
Exemple #12
0
    def __makesource__(self):
        """
        Build the galaxy data by populating the halos with the HOD model.

        .. note::
            All halos are gathered to the root rank, which alone runs the
            mock population; the resulting array is then scattered over
            the available ranks.

        Returns
        -------
        the local part of the galaxy array, as returned by ScatterArray
        """
        from astropy.table import Table

        comm = self.comm
        is_root = comm.rank == 0

        # collect the halo table of every rank on the root
        local_halos = find_object_dtypes(self._halos.halo_table).as_array()
        all_halos = GatherArray(local_halos, comm, root=0)

        # only the root ends up with galaxy data to scatter
        data = None
        if is_root:
            # the root's halo table now holds the halos of all ranks
            self._halos.halo_table = Table(data=all_halos, copy=True)
            del all_halos

            # run the mock population
            self._model.populate_mock(halocat=self._halos,
                                      halo_mass_column_key=self.mass,
                                      Num_ptcl_requirement=1,
                                      seed=self.attrs['seed'])

            # remap gal_type to integers (cen: 0, sats: 1)
            gal_type_integers(self._model.mock.galaxy_table)

            # pass the table through find_object_dtypes before converting
            # (the original comment: crash if any object dtypes)
            data = find_object_dtypes(self._model.mock.galaxy_table).as_array()
            del self._model.mock.galaxy_table

            # log the stats
            self._log_populated_stats(data)

        return ScatterArray(data, comm)
Exemple #13
0
def poisson_sample_to_points(delta,
                             displacement,
                             pm,
                             nbar,
                             bias=1.,
                             seed=None,
                             comm=None):
    """
    Poisson sample the linear delta and displacement fields to points.

    The steps in this function:

    #.  Apply a biased, lognormal transformation to the input ``delta`` field
    #.  Poisson sample the overdensity field to discrete points
    #.  Distribute the positions of particles uniformly within the mesh cells,
        and assign the displacement field at each cell to the particles

    All random numbers are drawn on rank 0 and scattered to the other ranks.

    Parameters
    ----------
    delta : RealField
        the linear overdensity field to sample
    displacement : list of RealField (3,)
        the linear displacement fields which is used to move the particles
    pm :
        not used in the body of this function; the mesh coordinates are
        taken from ``delta.pm``
    nbar : float
        the desired number density of the output catalog of objects
    bias : float, optional
        apply a linear bias to the overdensity field (default is 1.)
    seed : int, optional
        the random seed used to Poisson sample the field to points
    comm : optional
        the MPI communicator; ``MPI.COMM_WORLD`` is used if ``None``

    Returns
    -------
    pos : array_like, (N, 3)
        the Cartesian positions of each of the generated particles
    displ : array_like, (N, 3)
        the displacement field sampled for each of the generated particles in the
        same units as the ``pos`` array
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    on_root = comm.rank == 0

    # random state seeded with the input seed (only drawn from on rank 0)
    rng = numpy.random.RandomState(seed)

    # lognormal transformation of the initial density, using the
    # Lagrangian bias; this creates a positive-definite delta
    # (necessary for Poisson sampling)
    delta = lognormal_transform(delta, bias=bias - 1.)
    ndim = delta.ndim

    # cell size along each axis and mean number of objects per cell
    H = delta.BoxSize / delta.Nmesh
    overallmean = H.prod() * nbar

    # expected number of objects in each local cell, gathered to rank 0
    local_means = (delta.value * overallmean).flatten()
    all_means = GatherArray(local_means, comm, root=0)

    # rank 0 computes the poisson sampling
    N = rng.poisson(all_means) if on_root else None

    # hand every rank the counts belonging to its own cells
    counts = comm.allgather(delta.value.size)
    N = ScatterArray(N, comm, root=0, counts=counts).reshape(delta.shape)

    Nlocal = N.sum()  # local number of particles
    Ntot = comm.allreduce(Nlocal)  # the collective number of particles
    nonzero_cells = N.nonzero()  # indices of nonzero cells
    per_cell = N[nonzero_cells]  # number of particles in each nonzero cell

    # per-cell positions and displacements, with shape
    # (number of dimensions, number of nonzero cells)
    pos_mesh = numpy.empty(numpy.shape(nonzero_cells), dtype=delta.dtype)
    disp_mesh = numpy.empty_like(pos_mesh)
    for axis in range(ndim):
        # particle positions initially on the coordinate grid
        pos_mesh[axis] = numpy.squeeze(delta.pm.x[axis])[nonzero_cells[axis]]
        # displacements for each particle
        disp_mesh[axis] = displacement[axis][nonzero_cells]

    # rank 0 computes the in-cell uniform offsets for all particles
    in_cell_shift = None
    if on_root:
        in_cell_shift = numpy.empty((Ntot, ndim), dtype=delta.dtype)
        for axis in range(ndim):
            in_cell_shift[:, axis] = rng.uniform(0, H[axis], size=Ntot)

    # scatter the in-cell uniform offsets back to the ranks
    counts = comm.allgather(Nlocal)
    in_cell_shift = ScatterArray(in_cell_shift, comm, root=0, counts=counts)

    # output arrays with shape (local number of particles, number of dimensions)
    pos = numpy.zeros((Nlocal, ndim), dtype=delta.dtype)
    disp = numpy.zeros_like(pos)
    for axis in range(ndim):
        # each object is placed randomly in its cell, wrapped into the box
        pos[:, axis] = numpy.repeat(pos_mesh[axis],
                                    per_cell) + in_cell_shift[:, axis]
        pos[:, axis] %= delta.BoxSize[axis]

        # each object inherits the displacement of its cell
        disp[:, axis] = numpy.repeat(disp_mesh[axis], per_cell)

    return pos, disp
def main():
    """
    Convert grids4plots 3d grids to 2d slices.

    Every entry of the input directory whose name does not start with
    ``SLICE`` is read as a BigFileMesh, smoothed, read out at the uniform
    grid points inside the requested index range, gathered to rank 0 and
    saved as an ArrayMesh inside a ``SLICE_...`` sub-directory of the
    output path.

    Command line options: ``--inbasepath``, ``--outbasepath``, ``--inpath``,
    ``--Rsmooth`` and the inclusive index range ``--ixmin/--ixmax``,
    ``--iymin/--iymax``, ``--izmin/--izmax`` (taken modulo Ngrid).

    Raises
    ------
    Exception
        on rank 0, if the number of gathered readout values does not match
        the requested index range
    """

    #####################################
    # PARSE COMMAND LINE ARGS
    #####################################
    ap = ArgumentParser()

    ap.add_argument('--inbasepath',
                    type=str,
                    default='$SCRATCH/perr/grids4plots/',
                    help='Input base path.')

    ap.add_argument('--outbasepath',
                    type=str,
                    default='$SCRATCH/perr/grids4plots/',
                    help='Output base path.')

    ap.add_argument(
        '--inpath',
        type=str,
        default=
        'main_calc_Perr_2020_Sep_22_18:44:31_time1600800271.dill',  # laptop
        help='Input path.')

    ap.add_argument('--Rsmooth',
                    type=float,
                    default=2.0,
                    help='3D Gaussian smoothing applied to field.')

    # min and max index included in output. inclusive.
    ap.add_argument('--ixmin',
                    type=int,
                    default=5,
                    help='xmin of output. must be between 0 and Ngrid.')
    ap.add_argument('--ixmax', type=int, default=5, help='xmax of output')
    ap.add_argument('--iymin', type=int, default=0, help='ymin of output')
    ap.add_argument('--iymax', type=int, default=-1, help='ymax of output')
    ap.add_argument('--izmin', type=int, default=0, help='zmin of output')
    ap.add_argument('--izmax', type=int, default=-1, help='zmax of output')

    cmd_args = ap.parse_args()

    verbose = True

    #####################################
    # START PROGRAM
    #####################################
    comm = CurrentMPIComm.get()
    rank = comm.rank

    path = os.path.join(os.path.expandvars(cmd_args.inbasepath),
                        cmd_args.inpath)
    if rank == 0:
        print('path: ', path)

    initialized_slicecat = False

    # os.listdir returns entries in arbitrary order; sort them so that all
    # ranks visit the files in the same order, since the loop body makes
    # collective calls (GatherArray, mesh save)
    for fname in sorted(os.listdir(path)):
        # skip the slice directories written by previous runs
        if fname.startswith('SLICE'):
            continue

        full_fname = os.path.join(path, fname)
        print('%d Reading %s' % (rank, full_fname))

        inmesh = BigFileMesh(full_fname,
                             dataset='tmp4storage',
                             header='header')
        Ngrid = inmesh.attrs['Ngrid']
        boxsize = inmesh.attrs['boxsize']

        # apply smoothing
        mesh = apply_smoothing(inmesh, mode='Gaussian', R=cmd_args.Rsmooth)
        del inmesh

        # convert indices to modulo ngrid
        ixmin = cmd_args.ixmin % Ngrid
        ixmax = cmd_args.ixmax % Ngrid
        iymin = cmd_args.iymin % Ngrid
        iymax = cmd_args.iymax % Ngrid
        izmin = cmd_args.izmin % Ngrid
        izmax = cmd_args.izmax % Ngrid

        # convert to boxsize units (Mpc/h); the max index is inclusive,
        # hence the upper edge sits one cell further
        xmin = float(ixmin) / float(Ngrid) * boxsize
        xmax = float(ixmax + 1) / float(Ngrid) * boxsize
        ymin = float(iymin) / float(Ngrid) * boxsize
        ymax = float(iymax + 1) / float(Ngrid) * boxsize
        zmin = float(izmin) / float(Ngrid) * boxsize
        zmax = float(izmax + 1) / float(Ngrid) * boxsize

        if not initialized_slicecat:
            # Generate catalog with positions of slice points, to readout mesh there.
            # First generate all 3D points. Then keep only subset in slice.
            # Then readout mesh at those points.
            # NOTE: this catalog is built once, from the Ngrid and boxsize
            # of the first file, and re-used for all later files.

            # use pmesh generate_uniform_particle_grid
            # http://rainwoodman.github.io/pmesh/pmesh.pm.html?highlight=
            # readout#pmesh.pm.ParticleMesh.generate_uniform_particle_grid
            partmesh = ParticleMesh(BoxSize=boxsize,
                                    Nmesh=[Ngrid, Ngrid, Ngrid])
            ptcles = partmesh.generate_uniform_particle_grid(shift=0.0,
                                                             dtype='f8')

            dtype = np.dtype([('Position', ('f8', 3))])

            # number of rows is given by number of ptcles on this rank
            uni_cat_array = np.empty((ptcles.shape[0], ), dtype=dtype)
            uni_cat_array['Position'] = ptcles

            uni_cat = ArrayCatalog(uni_cat_array,
                                   comm=None,
                                   BoxSize=boxsize * np.ones(3),
                                   Nmesh=[Ngrid, Ngrid, Ngrid])

            del ptcles
            del uni_cat_array

            print("%d: Before cut: local Nptcles=%d, global Nptcles=%d" %
                  (comm.rank, uni_cat.size, uni_cat.csize))

            # only keep points in the slice
            uni_cat = uni_cat[(uni_cat['Position'][:, 0] >= xmin)
                              & (uni_cat['Position'][:, 0] < xmax)
                              & (uni_cat['Position'][:, 1] >= ymin)
                              & (uni_cat['Position'][:, 1] < ymax)
                              & (uni_cat['Position'][:, 2] >= zmin)
                              & (uni_cat['Position'][:, 2] < zmax)]

            print("%d: After cut: local Nptcles=%d, global Nptcles=%d" %
                  (comm.rank, uni_cat.size, uni_cat.csize))

            initialized_slicecat = True

        # read out full 3D mesh at catalog positions. this is a numpy array
        slicecat = readout_mesh_at_cat_pos(mesh=mesh,
                                           cat=uni_cat,
                                           readout_window='nearest')

        if rank == 0:
            print('slicecat type:', type(slicecat))

        slicecat = GatherArray(slicecat, comm, root=0)

        if rank == 0:
            # number of grid points requested along each axis
            slice_shape = (ixmax - ixmin + 1, iymax - iymin + 1,
                           izmax - izmin + 1)
            expected_size = slice_shape[0] * slice_shape[1] * slice_shape[2]
            if slicecat.shape != (expected_size, ):
                raise Exception(
                    'Unexpected shape of particles read out on slice: %s' %
                    str(slicecat.shape))

            slicecat = slicecat.reshape(slice_shape)

            print('slicecat shape:', slicecat.shape)
            if verbose:
                print('slicecat:', slicecat)

        # convert to a mesh. assume full numpy array sits on rank 0.
        Lx = xmax - xmin
        Ly = ymax - ymin
        Lz = zmax - zmin
        # fall back to one cell width for a zero extent
        if Lx == 0.: Lx = boxsize / float(Ngrid)
        if Ly == 0.: Ly = boxsize / float(Ngrid)
        if Lz == 0.: Lz = boxsize / float(Ngrid)
        BoxSize_slice = np.array([Lx, Ly, Lz])
        slicemesh = ArrayMesh(slicecat, BoxSize=BoxSize_slice, root=0)

        # compute the real-space field once and re-use it for the output
        slicefield = slicemesh.compute(mode='real')
        if verbose:
            print('slicemesh: ', slicefield)

        # write to disk
        outpath = os.path.join(
            os.path.expandvars(cmd_args.outbasepath), cmd_args.inpath,
            'SLICE_R%g_%d-%d_%d-%d_%d-%d/' %
            (cmd_args.Rsmooth, ixmin, ixmax, iymin, iymax, izmin, izmax))
        if rank == 0:
            # exist_ok avoids the race between checking and creating
            os.makedirs(outpath, exist_ok=True)
        full_outfname = os.path.join(outpath, fname)
        if rank == 0:
            print('Writing %s' % full_outfname)
        slicemesh.save(full_outfname)
        if rank == 0:
            print('Wrote %s' % full_outfname)
Exemple #15
0
    def populate(self, model, BoxSize=None, seed=None, **params):
        """
        Populate the HaloCatalog using a :mod:`halotools` model.

        The model can be a built-in model from :mod:`nbodykit.hod` (which
        will be converted to a Halotools model) or directly a Halotools model
        instance.

        This assumes that this is the first time this catalog has been
        populated with the input model. To re-populate using the same
        model (but different parameters), call the :func:`repopulate`
        function of the returned :class:`PopulatedHaloCatalog`.

        Parameters
        ----------
        model : :class:`nbodykit.hod.HODModel` or halotools model object
            the model to use to populate; an :class:`nbodykit.hod.HODModel`
            subclass or instance is converted with its ``to_halotools``
            method
        BoxSize : float, 3-vector, optional
            the box size of the catalog; this must be supplied if 'BoxSize'
            is not in :attr:`attrs`
        seed : int, optional
            the random seed to use when populating the mock
        **params :
            key/value pairs specifying the model parameters to use

        Returns
        -------
        cat : :class:`PopulatedHaloCatalog`
            the catalog object storing information about the populated objects

        Raises
        ------
        TypeError
            if ``model`` is not a Halotools ``ModelFactory`` after the
            optional conversion

        Examples
        --------
        Initialize a demo halo catalog:

        >>> from nbodykit.tutorials import DemoHaloCatalog
        >>> cat = DemoHaloCatalog('bolshoi', 'rockstar', 0.5)

        Populate with the built-in Zheng07 model:

        >>> from nbodykit.hod import Zheng07Model
        >>> galcat = cat.populate(Zheng07Model, seed=42)

        And then re-populate galaxy catalog with new parameters:

        >>> galcat.repopulate(alpha=0.9, logMmin=13.5, seed=42)
        """
        from nbodykit.hod import HODModel
        from halotools.empirical_models import ModelFactory
        from halotools.sim_manager import UserSuppliedHaloCatalog

        # handle builtin model types, given either as a class or as an
        # instance; issubclass() is only called for classes, since it
        # raises TypeError when its first argument is an instance
        is_builtin = isinstance(model, HODModel) or (
            isinstance(model, type) and issubclass(model, HODModel))
        if is_builtin:
            model = model.to_halotools(self.cosmo,
                                       self.attrs['redshift'],
                                       self.attrs['mdef'],
                                       concentration_key='halo_nfw_conc')

        # check model type
        if not isinstance(model, ModelFactory):
            raise TypeError(
                "model for populating mocks should be a Halotools ModelFactory"
            )

        # make halotools catalog
        halocat = self.to_halotools(BoxSize=BoxSize)

        # gather the halo data to root
        all_halos = GatherArray(halocat.halo_table.as_array(),
                                self.comm,
                                root=0)

        # only the root rank needs to store the halo data
        if self.comm.rank == 0:
            # one entry per halo column, plus the catalog meta-data
            data = {col: all_halos[col] for col in all_halos.dtype.names}
            data.update({
                col: getattr(halocat, col)
                for col in ['Lbox', 'redshift', 'particle_mass']
            })
            halocat = UserSuppliedHaloCatalog(**data)
        else:
            halocat = None

        # cache the model so we have option to call repopulate later
        self.model = model

        # return the populated catalog
        return _populate_mock(self,
                              model,
                              seed=seed,
                              halocat=halocat,
                              **params)
Exemple #16
0
def main(ns):
    """
    Accumulate convergence maps over redshift bins and write one bigfile
    per source plane.

    The catalog ``ns.source`` / ``ns.dataset`` is read in redshift bins
    between ``ns.zlmax`` and ``ns.zlmin``. The per-bin outputs of
    ``make_kappa_maps`` are summed, then for each source redshift in
    ``ns.zs`` the maps are gathered and rank 0 writes the datasets
    "kappa" and "Nm" below ``ns.output``.

    Parameters
    ----------
    ns : namespace
        the attributes read here are ``zs``, ``zlmin``, ``zlmax`` (set to
        ``max(ns.zs)`` if ``None``), ``zstep``, ``nside``, ``source``,
        ``dataset`` and ``output``
    """
    import gc

    if ns.zlmax is None:
        ns.zlmax = max(ns.zs)

    zs_list = ns.zs
    zlmin, zlmax = ns.zlmin, ns.zlmax

    # no need to be accurate here
    ds_list = Planck15.comoving_distance(zs_list)

    cat = BigFileCatalog(ns.source, dataset=ns.dataset)
    comm = cat.comm
    is_root = comm.rank == 0

    # running sums over the redshift bins
    kappa = 0
    Nm = 0
    kappabar = 0

    # number of healpix pixels assigned to this rank
    npix = healpix.nside2npix(ns.nside)
    localsize = npix * (comm.rank + 1) // comm.size - npix * (
        comm.rank) // comm.size
    nbar = (cat.attrs['NC']**3 / cat.attrs['BoxSize']**3 *
            cat.attrs['ParticleFraction'])[0]

    # redshift bin edges, from zlmax down to zlmin (at least 2 edges)
    Nsteps = max(int(numpy.round((zlmax - zlmin) / ns.zstep)), 2)
    z = numpy.linspace(zlmax, zlmin, Nsteps, endpoint=True)

    if is_root:
        cat.logger.info("Splitting data redshift bins %s" % str(z))

    for z_hi, z_lo in zip(z[:-1], z[1:]):
        gc.collect()
        if is_root:
            cat.logger.info("nbar = %g, zlmin = %g, zlmax = %g zs = %s" %
                            (nbar, z_lo, z_hi, zs_list))

        # particles with scale factor between 1/(1+z_hi) and 1/(1+z_lo)
        shell = read_range(cat, 1 / (1 + z_hi), 1 / (1 + z_lo))

        # nothing to add for an empty bin
        if shell.csize == 0:
            continue
        if is_root:
            cat.logger.info("read %d particles" % shell.csize)

        kappa1, kappa1bar, Nm1 = make_kappa_maps(shell, ns.nside, zs_list,
                                                 ds_list, localsize, nbar)

        kappa = kappa + kappa1
        Nm = Nm + Nm1
        kappabar = kappabar + kappa1bar

    comm.barrier()

    if is_root:
        # use bigfile because it allows concurrent write to different datasets.
        cat.logger.info("writing to %s", ns.output)

    for i, (zs, ds) in enumerate(zip(zs_list, ds_list)):
        # spread and mean of the local map sizes, for the log only
        std = numpy.std(comm.allgather(len(kappa[i])))
        mean = numpy.mean(comm.allgather(len(kappa[i])))
        if is_root:
            cat.logger.info(
                "started gathering source plane %s, size-var = %g, size-bar = %g"
                % (zs, std, mean))

        kappa1 = GatherArray(kappa[i], comm)
        Nm1 = GatherArray(Nm[i], comm)

        if is_root:
            cat.logger.info("done gathering source plane %s" % zs)

            fname = ns.output + "/WL-%02.2f-N%04d" % (zs, ns.nside)
            cat.logger.info("started writing source plane %s" % zs)

            with bigfile.File(fname, create=True) as ff:

                ds1 = ff.create_from_array("kappa", kappa1, Nfile=1)
                ds2 = ff.create_from_array("Nm", Nm1, Nfile=1)

                # both datasets carry the same meta-data
                for d in ds1, ds2:
                    d.attrs['kappabar'] = kappabar[i]
                    d.attrs['nside'] = ns.nside
                    d.attrs['zlmin'] = zlmin
                    d.attrs['zlmax'] = zlmax
                    d.attrs['zs'] = zs
                    d.attrs['ds'] = ds
                    d.attrs['nbar'] = nbar

        comm.barrier()
        if is_root:
            # use bigfile because it allows concurrent write to different datasets.
            cat.logger.info("source plane at %g written. " % zs)
def main(ns):
    """
    Accumulate convergence maps over redshift bins, keep the per-bin maps,
    and write one bigfile per source plane.

    The lens redshift range runs from ``ns.zlmin`` up to the last entry of
    ``ns.zs``. The per-bin outputs of ``make_kappa_maps`` are summed and
    also stored per bin in ``kappa_all``. For each source redshift the
    summed maps are gathered and rank 0 writes the datasets "kappa" and
    "Nm" below ``ns.output``. The per-bin maps are gathered on rank 0 as
    well, but only reported in a debug print; they are not written.

    Parameters
    ----------
    ns : namespace
        the attributes read here are ``zs``, ``zlmin``, ``zlmax`` (set to
        ``max(ns.zs)`` if ``None``, but not used afterwards), ``zstep``,
        ``nside``, ``source``, ``dataset`` and ``output``
    """
    import gc

    if ns.zlmax is None:
        ns.zlmax = max(ns.zs)

    zs_list = ns.zs

    # the upper lens redshift is the last source redshift, not ns.zlmax
    zlmin = ns.zlmin
    zlmax = zs_list[-1]

    # no need to be accurate here
    ds_list = Planck15.comoving_distance(zs_list)

    cat = BigFileCatalog(ns.source, dataset=ns.dataset)
    comm = cat.comm
    is_root = comm.rank == 0

    # running sums over the redshift bins
    kappa = 0
    Nm = 0
    kappabar = 0

    # number of healpix pixels assigned to this rank
    npix = healpix.nside2npix(ns.nside)
    localsize = npix * (comm.rank + 1) // comm.size - npix * (comm.rank) // comm.size
    nbar = (cat.attrs['NC'] ** 3 / cat.attrs['BoxSize'] ** 3 * cat.attrs['ParticleFraction'])[0]

    # Nsteps bins (at least 2), hence Nsteps + 1 edges from zlmax to zlmin
    Nsteps = max(int(numpy.round((zlmax - zlmin) / ns.zstep)), 2)
    z = numpy.linspace(zlmax, zlmin, Nsteps + 1, endpoint=True)

    if is_root:
        cat.logger.info("Splitting data redshift bins %s" % str(z))

    # local maps of every redshift bin and source plane
    kappa_all = numpy.zeros((Nsteps, len(zs_list), localsize))
    for ibin, (z_hi, z_lo) in enumerate(zip(z[:-1], z[1:])):
        gc.collect()
        if is_root:
            cat.logger.info("nbar = %g, zlmin = %g, zlmax = %g zs = %s" % (nbar, z_lo, z_hi, zs_list))

        # particles with scale factor between 1/(1+z_hi) and 1/(1+z_lo)
        shell = read_range(cat, 1 / (1 + z_hi), 1 / (1 + z_lo))

        # an empty bin leaves its row of kappa_all at zero
        if shell.csize == 0:
            continue
        if is_root:
            cat.logger.info("read %d particles" % shell.csize)

        kappa1, kappa1bar, Nm1 = make_kappa_maps(shell, ns.nside, zs_list, ds_list, localsize, nbar)

        kappa = kappa + kappa1
        kappa_all[ibin] = kappa1
        Nm = Nm + Nm1
        kappabar = kappabar + kappa1bar

    comm.barrier()

    if is_root:
        # use bigfile because it allows concurrent write to different datasets.
        cat.logger.info("writing to %s", ns.output)

    # array to get all map slices (only allocated on the root)
    if is_root:
        kappa1_all = numpy.zeros((Nsteps, int(12*ns.nside**2)))

    for i, (zs, ds) in enumerate(zip(zs_list, ds_list)):
        # spread and mean of the local map sizes, for the log only
        std = numpy.std(comm.allgather(len(kappa[i])))
        mean = numpy.mean(comm.allgather(len(kappa[i])))
        if is_root:
            cat.logger.info("started gathering source plane %s, size-var = %g, size-bar = %g" % (zs, std, mean))

        kappa1 = GatherArray(kappa[i], comm)
        Nm1 = GatherArray(Nm[i], comm)

        # gather the map of every redshift bin for this source plane
        for j in range(Nsteps):
            gathered_bin = GatherArray(kappa_all[j, i], comm)
            if is_root:
                kappa1_all[j] = gathered_bin

        if is_root:
            cat.logger.info("done gathering source plane %s" % zs)

            fname = ns.output + "/WL-%02.2f-N%04d" % (zs, ns.nside)
            cat.logger.info("started writing source plane %s" % zs)

            with bigfile.File(fname, create=True) as ff:
                print('DEBUG', kappa1_all.shape, len(kappa1_all), numpy.dtype((kappa1_all.dtype, kappa1_all.shape[1:])))
                ds1 = ff.create_from_array("kappa", kappa1, Nfile=1)
                ds2 = ff.create_from_array("Nm", Nm1, Nfile=1)

                # both datasets carry the same meta-data
                for d in ds1, ds2:
                    d.attrs['kappabar'] = kappabar[i]
                    d.attrs['nside'] = ns.nside
                    d.attrs['zlmin'] = zlmin
                    d.attrs['zlmax'] = zlmax
                    d.attrs['zstep'] = ns.zstep
                    d.attrs['zs'] = zs
                    d.attrs['ds'] = ds
                    d.attrs['nbar'] = nbar

        comm.barrier()
        if is_root:
            # use bigfile because it allows concurrent write to different datasets.
            cat.logger.info("source plane at %g written. " % zs)
def make_kappa_maps(cat, nside, zs_list, ds_list, localsize, nbar):
    """
    Make kappa maps at a list of source distances.

    For every (zs, ds) pair the lensing kernel ``wlen`` is evaluated per
    particle, divided by ``area * nbar`` and binned into healpix pixels
    with ``weighted_map``. The mean convergence of each plane is estimated
    by a trapezoidal integral of the kernel over distance, using a
    subsample of particles gathered to rank 0.

    Parameters
    ----------
    cat :
        the catalog; the columns 'Position', 'RA', 'DEC' and 'Aemit' and
        the attribute 'OmegaM' are read
    nside : int
        the healpix nside of the maps
    zs_list, ds_list :
        the source redshifts and the matching distances, iterated in pairs
    localsize : int
        the size of the map on this rank, passed on to ``weighted_map``
    nbar : float
        the number density dividing the weights

    Returns
    -------
    kappa, kappabar, Nm : numpy.ndarray
        kappa and Nm in shape of (n_ds, localsize), kappabar in shape of
        (n_ds,). The maps are distributed in memory.
    """
    comm = cat.comm
    is_root = comm.rank == 0

    # distance of every particle from the origin
    dl = (abs(cat['Position'] **2).sum(axis=-1)) ** 0.5
    chunks = dl.chunks
    ra = cat['RA']
    dec = cat['DEC']
    # lens redshift from the scale factor at emission
    zl = (1 / cat['Aemit'] - 1)

    # healpix pixel of every particle
    ipix = da.apply_gufunc(lambda ra, dec, nside:
                           healpix.ang2pix(nside, numpy.radians(90-dec), numpy.radians(ra)),
                        '(),()->()', ra, dec, nside=nside)

    npix = healpix.nside2npix(nside)

    ipix = ipix.compute()
    dl = dl.persist()

    comm.barrier()

    if is_root:
        cat.logger.info("ipix and dl are persisted")

    area = (4 * numpy.pi / npix) * dl**2

    Om = cat.attrs['OmegaM'][0]

    # stride of the subsample used for kappa bar (at least 1); it does not
    # change between source planes
    every = (cat.csize // 100000)
    if every == 0:
        every = 1

    kappa_maps = []
    kappabar_values = []
    count_maps = []
    for zs, ds in zip(zs_list, ds_list):
        LensKernel = da.apply_gufunc(lambda dl, zl, Om, ds: wlen(Om, dl, zl, ds), 
                                     "(), ()-> ()",
                                     dl, zl, Om=Om, ds=ds)

        weights = (LensKernel / (area * nbar)).compute()

        comm.barrier()

        if is_root:
            cat.logger.info("source plane %g weights are persisted" % zs)
        Wmap, Nmap = weighted_map(ipix, npix, weights, localsize, comm)

        comm.barrier()
        if is_root:
            cat.logger.info("source plane %g maps generated" % zs)

        # compute kappa bar
        # this is a simple integral, but we do not know dl, dz relation
        # so do it with values from a subsample of particles

        # use GatherArray, because it is faster than comm.gather at this scale
        # (> 4000 ranks on CrayMPI)
        ssdl = GatherArray(dl[::every].compute(), comm)
        ssLensKernel = GatherArray(LensKernel[::every].compute(), comm)

        kappa1bar = None
        if is_root:
            # integrate the kernel over distance, sorted by distance
            order = ssdl.argsort()
            kappa1bar = numpy.trapz(ssLensKernel[order], ssdl[order])
        kappa1bar = comm.bcast(kappa1bar)

        comm.barrier()
        if is_root:
            cat.logger.info("source plane %g bar computed " % zs)

        kappa_maps.append(Wmap)
        kappabar_values.append(kappa1bar)
        count_maps.append(Nmap)

    # disabled sketch for estimating nbar:
    #   dlmin = dl.min()
    #   dlmax = dl.max()
    #   volume = (Nmap > 0).sum() / len(Nmap) * 4  / 3 * numpy.pi * (dlmax**3 - dlmin ** 3)

    # returns number rather than delta, since we do not know fsky here.
    return numpy.array(kappa_maps), numpy.array(kappabar_values), numpy.array(count_maps)