示例#1
0
    def gslice(self, start, stop, end=1, redistribute=True):
        """
        Execute a global slice of a CatalogSource.

        The slice ``start:stop:end`` is applied to the global (collective)
        row index, i.e. as if the rows of all ranks were concatenated in
        rank order.

        .. note::
            After the global slice is performed, the data is scattered
            evenly across all ranks (unless ``redistribute=False``).

        Parameters
        ----------
        start : int
            the start index of the global slice
        stop : int
            the stop index of the global slice
        end : int, optional
            the step size of the global slice (default is 1)
        redistribute : bool, optional
            if ``True``, evenly re-distribute the sliced data across all
            ranks, otherwise just return any local data part of the global
            slice

        Returns
        -------
        CatalogSource
            if ``redistribute`` is ``False``, the local rows selected by the
            global slice; otherwise a new instance of this class, built with
            ``_from_columns`` from the evenly re-distributed columns and
            finalized from ``self``
        """
        from nbodykit.utils import ScatterArray, GatherArray

        # only the root builds the global boolean mask; ScatterArray reads
        # ``data`` on the root only, so there is no need to broadcast it
        if self.comm.rank == 0:
            index = numpy.zeros(self.csize, dtype=bool)
            index[slice(start, stop, end)] = True
        else:
            index = None

        # scatter the mask so that each rank receives the part that lines up
        # with its own rows (counts = local sizes, in rank order)
        counts = self.comm.allgather(self.size)
        index = ScatterArray(index, self.comm, root=0, counts=counts)

        # perform the needed local slice
        subset = self[index]

        # if we don't want to redistribute evenly, just return the slice
        if not redistribute:
            return subset

        # re-distribute each column from the sliced data
        # NOTE: currently Gather/ScatterArray requires numpy arrays, but
        # in principle we could pass dask arrays around between ranks and
        # avoid compute() calls
        columns = list(subset)
        data = self.compute(*[subset[col] for col in columns])

        # compute() hands back a bare array (not a 1-tuple) for a single
        # argument; normalize so the column/array pairing below stays valid
        if len(columns) == 1 and not isinstance(data, tuple):
            data = (data,)

        # gather each column to the root, then scatter it evenly
        evendata = {}
        for col, coldata in zip(columns, data):
            alldata = GatherArray(coldata, self.comm, root=0)
            evendata[col] = ScatterArray(alldata, self.comm, root=0)

        # return a new CatalogSource holding the evenly distributed data
        size = len(evendata[columns[-1]])
        toret = self.__class__._from_columns(size, self.comm, **evendata)
        return toret.__finalize__(self)
示例#2
0
def test_scatter_wrong_counts(comm):
    """
    ScatterArray must reject a ``counts`` argument whose length differs from
    the communicator size, or whose sum differs from the root data length.

    The invalid ``counts`` are derived from ``comm.size``: hard-coded lists
    are only invalid for some communicator sizes (``[0, 5, 5]`` is a valid
    split of 10 rows on 3 ranks, and ``[5, 7]`` only exercises the "wrong
    sum" path on exactly 2 ranks).
    """
    # 10 rows on the root, nothing elsewhere
    if comm.rank == 0:
        data = numpy.ones(10, dtype=[('a', 'f')])
    else:
        data = None

    # wrong counts length: one more entry than there are ranks
    wrong_length = [0] * (comm.size + 1)
    with pytest.raises(ValueError):
        ScatterArray(data, comm, root=0, counts=wrong_length)

    # wrong counts sum: correct length, but 11 rows requested instead of 10
    wrong_sum = [11] + [0] * (comm.size - 1)
    with pytest.raises(ValueError):
        ScatterArray(data, comm, root=0, counts=wrong_sum)
示例#3
0
def test_scatter_objects(comm):
    """Scattering arrays that hold Python objects must raise ValueError.

    Both an object field inside a structured dtype and a plain object
    dtype are checked.
    """
    on_root = comm.rank == 0
    structured = numpy.ones(10, dtype=[('test', 'O')]) if on_root else None
    plain = numpy.ones(10, dtype='O') if on_root else None

    for candidate in (structured, plain):
        with pytest.raises(ValueError):
            ScatterArray(candidate, comm, root=0)
示例#4
0
def test_fibercolls(comm):
    """
    Validate FiberCollisions against a brute-force pairwise distance
    computation on the root rank.
    """
    from scipy.spatial.distance import pdist, squareform
    from nbodykit.utils import ScatterArray, GatherArray

    CurrentMPIComm.set(comm)
    N = 10000

    # generate the initial data (only the root draws random numbers)
    numpy.random.seed(42)
    if comm.rank == 0:
        ra = 10. * numpy.random.random(size=N)
        dec = 5. * numpy.random.random(size=N) - 5.0
    else:
        ra = None
        dec = None

    ra = ScatterArray(ra, comm)
    dec = ScatterArray(dec, comm)

    # compute the fiber collisions
    r = FiberCollisions(ra, dec, degrees=True, seed=42)
    rad = r._collision_radius_rad

    #  gather collided and position to root
    idx = GatherArray(r.labels['Collided'].compute().astype(bool),
                      comm,
                      root=0)
    pos = GatherArray(r.source['Position'].compute(), comm, root=0)

    # manually compute distances and check on root
    if comm.rank == 0:
        dists = squareform(pdist(pos, metric='euclidean'))
        numpy.fill_diagonal(dists, numpy.inf)  # ignore self pairs

        clean = ~idx

        # no objects in clean sample (Collided==0) should be within
        # the collision radius of any other objects in the sample.
        # NOTE: numpy.ix_ selects the clean-by-clean sub-matrix; indexing
        # with two boolean masks (dists[~idx, ~idx]) pairs them elementwise
        # and returns only diagonal entries (all inf), so it checks nothing
        clean_dists = dists[numpy.ix_(clean, clean)]
        assert (clean_dists <= rad).sum(
        ) == 0, "objects in 'clean' sample within collision radius!"

        # the collided objects must collide with at least
        # one object in the clean sample (rows: collided, columns: clean)
        ncolls_per = (dists[numpy.ix_(idx, clean)] <= rad).sum(axis=-1)
        assert (ncolls_per >= 1).all(
        ), "objects in 'collided' sample that do not collide with any objects!"
示例#5
0
    def repopulate(self, seed=None, **params):
        """
        Update the HOD parameters and then re-populate the mock catalog

        .. warning::
            This operation is done in-place, so the size of the Source
            changes

        Parameters
        ----------
        seed : int, optional
            the new seed to use when populating the mock; if ``None``, the
            root draws a random seed and broadcasts it to all ranks
        params :
            key/value pairs of HOD parameters to update

        Raises
        ------
        ValueError
            if any name in ``params`` is not a valid HOD parameter; all names
            are checked before any parameter is modified
        Exception
            any error raised on the root while populating is re-raised on
            every rank, so that non-root ranks do not block waiting in
            ``ScatterArray``
        """
        # set the seed randomly if it is None
        if seed is None:
            if self.comm.rank == 0:
                seed = numpy.random.randint(0, 4294967295)
            seed = self.comm.bcast(seed)
        self.attrs['seed'] = seed

        # validate every parameter name first, so that an invalid name cannot
        # leave the model (and attrs) partially updated
        for name in params:
            if name not in self._model.param_dict:
                valid = list(self._model.param_dict.keys())
                raise ValueError("'%s' is not a valid Hod parameter name; valid are: %s" %(name, str(valid)))

        # update the HOD model parameters
        for name, value in params.items():
            self._model.param_dict[name] = value
            self.attrs[name] = value

        # the root will do the mock population
        exception = None
        data = None
        if self.comm.rank == 0:
            try:
                # re-populate the mock (without halo catalog pre-processing)
                self._model.mock.populate(Num_ptcl_requirement=1,
                                          halo_mass_column_key=self.mass,
                                          seed=self.attrs['seed'])

                # remap gal_type to integers (cen: 0, sats: 1)
                gal_type_integers(self._model.mock.galaxy_table)

                # crash if any object dtypes (they cannot be scattered)
                data = find_object_dtypes(self._model.mock.galaxy_table).as_array()
                del self._model.mock.galaxy_table
            except Exception as e:
                exception = e

        # share the root's outcome so every rank raises together
        exception = self.comm.bcast(exception, root=0)
        if exception is not None:
            raise exception

        # log the stats
        if self.comm.rank == 0:
            self._log_populated_stats(data)

        # re-initialize with new source
        ArrayCatalog.__init__(self, ScatterArray(data, self.comm), comm=self.comm, use_cache=self.use_cache)
示例#6
0
def test_scatter_list(comm):
    """ScatterArray accepts only numpy arrays; a plain list on the root must raise ValueError."""
    data = list(numpy.ones(10, dtype=[('a', 'f')])) if comm.rank == 0 else None

    with pytest.raises(ValueError):
        ScatterArray(data, comm, root=0)
示例#7
0
    def __makesource__(self):
        """
        Make the source of galaxies by performing the halo HOD population

        .. note::
            The mock population is only done by the root, and the resulting
            catalog is then distributed evenly amongst the available ranks

        Returns
        -------
        numpy.ndarray
            this rank's share of the populated galaxy table, as returned by
            ``ScatterArray``

        Raises
        ------
        Exception
            any error raised on the root while populating is re-raised on
            every rank, so that non-root ranks do not block waiting in
            ``ScatterArray``
        """
        from astropy.table import Table

        # gather all halos to root (every rank participates)
        halo_table = find_object_dtypes(self._halos.halo_table)
        all_halos = GatherArray(halo_table.as_array(), self.comm, root=0)

        # root does the mock population
        exception = None
        data = None
        if self.comm.rank == 0:
            try:
                # set the halo table on the root to the Table containing all halos
                self._halos.halo_table = Table(data=all_halos, copy=True)
                del all_halos

                # populate
                self._model.populate_mock(halocat=self._halos, halo_mass_column_key=self.mass,
                                          Num_ptcl_requirement=1, seed=self.attrs['seed'])

                # remap gal_type to integers (cen: 0, sats: 1)
                gal_type_integers(self._model.mock.galaxy_table)

                # crash if any object dtypes (they cannot be scattered)
                data = find_object_dtypes(self._model.mock.galaxy_table).as_array()
                del self._model.mock.galaxy_table
            except Exception as e:
                exception = e

        # share the root's outcome so every rank raises together
        exception = self.comm.bcast(exception, root=0)
        if exception is not None:
            raise exception

        # log the stats
        if self.comm.rank == 0:
            self._log_populated_stats(data)

        return ScatterArray(data, self.comm)
示例#8
0
def fof_catalog(source, label, comm,
                position='Position', velocity='Velocity', initposition='InitialPosition',
                peakcolumn=None, periodic=True):
    """
    Catalog of FOF groups based on label from a parent source

    This is a collective operation -- the returned halo catalog will be
    equally distributed across all ranks

    Notes
    -----
    This computes the center-of-mass position and velocity in the same
    units as the corresponding columns ``source``

    Parameters
    ----------
    source: CatalogSource
        the parent source of particles from which the center-of-mass
        position and velocity are computed for each halo
    label : array_like
        the label for each particle that identifies which halo it
        belongs to
    comm: MPI.Comm
        the mpi communicator. Must agree with the datasource
    position : str, optional
        the column name specifying the position
    velocity : str, optional
        the column name specifying the velocity
    initposition : str, optional
        the column name specifying the initial position; this is only
        computed if available
    peakcolumn : str, optional
        if not None, also compute PeakPosition and PeakVelocity: the center
        of mass of the particles in each halo whose ``peakcolumn`` value
        equals the maximum of that halo
    periodic : bool, optional
        if True, ``source.attrs['BoxSize']`` is passed to the center-of-mass
        computation of positions (presumably to handle periodic wrapping);
        ``BoxSize`` must then be present in ``source.attrs``

    Returns
    -------
    catalog: numpy.ndarray
        this rank's part of a 1-d structured array with fields
        'CMPosition', 'CMVelocity' and 'Length' (plus 'InitialPosition',
        and 'PeakPosition'/'PeakVelocity', when available/requested).
        CMPosition and CMVelocity are the center-of-mass position and
        velocity of each FOF halo, and Length is the number of particles in
        a halo. ``catalog[0]`` corresponds to label 0, which is not a halo,
        and its Length is set to 0. Per the original documentation, halos are
        ordered most massive first; this depends on how ``label`` was
        assigned.

    Raises
    ------
    ValueError
        if a requested column is missing from ``source``, or if
        ``periodic`` is True and ``BoxSize`` is not in ``source.attrs``
    """
    from nbodykit.utils import ScatterArray

    # make sure all of the columns are there
    for col in [position, velocity]:
        if col not in source:
            raise ValueError("the column '%s' is missing from parent source; cannot compute halos" %col)
    if peakcolumn is not None and peakcolumn not in source:
        raise ValueError("the column '%s' is missing from parent source; cannot compute halo peaks" % peakcolumn)

    dtype=[('CMPosition', ('f4', 3)),('CMVelocity', ('f4', 3)),('Length', 'i4')]
    N = count(label, comm=comm)

    if periodic:
        # make sure BoxSize is there
        boxsize = source.attrs.get('BoxSize', None)
        if boxsize is None:
            raise ValueError("cannot compute halo catalog from source without 'BoxSize' in ``attrs`` dict")
    else:
        boxsize = None

    # evaluate the particle positions/velocities once; they are reused for
    # the peak-based centers below instead of being recomputed
    pos = source.compute(source[position])
    vel = source.compute(source[velocity])

    # center of mass position
    hpos = centerofmass(label, pos, boxsize=boxsize, comm=comm)

    # center of mass velocity
    hvel = centerofmass(label, vel, boxsize=None, comm=comm)

    # center of mass initial position
    if initposition in source:
        dtype.append(('InitialPosition', ('f4', 3)))
        hpos_init = centerofmass(label, source.compute(source[initposition]), boxsize=boxsize, comm=comm)

    if peakcolumn is not None:
        dtype.append(('PeakPosition', ('f4', 3)))
        dtype.append(('PeakVelocity', ('f4', 3)))

        # per-halo maximum of the peak column, reduced across all ranks
        density = source.compute(source[peakcolumn])
        dmax = equiv_class(label, density, op=numpy.fmax, dense_labels=True, minlength=len(N), identity=-numpy.inf)
        comm.Allreduce(MPI.IN_PLACE, dmax, op=MPI.MAX)
        # remove any non-peak particle from the labels (relabel to 0)
        label1 = label * (density >= dmax[label])

        # compute the center of mass on the new labels
        ppos = centerofmass(label1, pos, boxsize=boxsize, comm=comm)
        pvel = centerofmass(label1, vel, boxsize=None, comm=comm)

    dtype = numpy.dtype(dtype)
    if comm.rank == 0:
        catalog = numpy.empty(shape=len(N), dtype=dtype)

        catalog['CMPosition'] = hpos
        catalog['CMVelocity'] = hvel
        catalog['Length'] = N
        catalog['Length'][0] = 0
        if 'InitialPosition' in dtype.names:
            catalog['InitialPosition'] = hpos_init

        if peakcolumn is not None:
            catalog['PeakPosition'] = ppos
            catalog['PeakVelocity'] = pvel
    else:
        catalog = None

    return ScatterArray(catalog, comm, root=0)
示例#9
0
def _populate_mock(cat,
                   model,
                   seed=None,
                   halocat=None,
                   inplace=False,
                   **params):
    """
    Internal function to perform the mock population on a HaloCatalog, given
    a :mod:`halotools` model.

    The implementation is not massively parallel: the mock population is
    performed only on the root rank, and the resulting galaxy table is then
    scattered evenly across ranks.

    Parameters
    ----------
    cat : HaloCatalog
        the catalog being populated; its ``comm``, ``cosmo`` and ``attrs``
        (``'halo_mass_key'``, ``'redshift'``) are read
    model :
        the :mod:`halotools` model; ``model.param_dict`` is updated with
        ``params``. If the model has no ``mock`` attribute yet, it is
        populated for the first time from ``halocat``
    seed : int, optional
        the random seed; if ``None``, the root draws one and broadcasts it
    halocat : optional
        the halotools halo catalog used for a first-time population; required
        (on the root) when ``model`` has no ``mock`` attribute
    inplace : bool, optional
        if ``True``, re-initialize ``cat`` itself via
        ``PopulatedHaloCatalog.__init__``; otherwise build a new
        ``PopulatedHaloCatalog``
    **params :
        model parameters to update; each name must be a key of
        ``model.param_dict``

    Returns
    -------
    galcat : PopulatedHaloCatalog
        the populated catalog (``cat`` itself when ``inplace=True``) with
        added ``Position``, ``Velocity`` and ``VelocityOffset`` columns and
        updated ``attrs``

    Raises
    ------
    ValueError
        if ``params`` contains invalid names, if ``halocat`` is ``None`` on a
        first-time population, or if the populated catalog is empty. Any
        exception raised on the root during population is broadcast and
        re-raised on every rank.
    """
    # verify input params
    valid = sorted(model.param_dict)
    missing = set(params) - set(valid)
    if len(missing):
        raise ValueError("invalid halo model parameter names: %s" %
                         str(missing))

    # update the model parameters
    model.param_dict.update(params)

    # set the seed randomly if it is None
    if seed is None:
        if cat.comm.rank == 0:
            seed = numpy.random.randint(0, 4294967295)
        seed = cat.comm.bcast(seed)

    # the types of galaxies we are populating
    gal_types = getattr(model, 'gal_types', [])

    # capture failures instead of raising immediately: if the root raised
    # here, the other ranks would otherwise wait forever in the collectives
    # below
    exception = None
    try:
        # the root will do the mock population
        if cat.comm.rank == 0:

            # re-populate the mock (without halo catalog pre-processing)
            kws = {
                'seed': seed,
                'Num_ptcl_requirement': 0,
                'halo_mass_column_key': cat.attrs['halo_mass_key']
            }
            if hasattr(model, 'mock'):
                model.mock.populate(**kws)
            # populating model for the first time (initialization costs)
            else:
                if halocat is None:
                    raise ValueError(
                        "halocat cannot be None if we are populating for the first time"
                    )
                model.populate_mock(halocat=halocat, **kws)

            # enumerate gal types as integers
            # NOTE: necessary to avoid "O" type columns
            _enum_gal_types(model.mock.galaxy_table, gal_types)

            # crash if any object dtypes
            # NOTE: we cannot use GatherArray/ScatterArray on objects
            data = _test_for_objects(model.mock.galaxy_table).as_array()

        else:
            data = None

    except Exception as e:
        exception = e

    # re-raise the error
    # NOTE: only the root's exception is broadcast; every rank raises it
    exception = cat.comm.bcast(exception, root=0)
    if exception is not None:
        raise exception

    # re-scatter the data evenly
    data = ScatterArray(data, cat.comm, root=0)

    # re-initialize with new source
    if inplace:
        PopulatedHaloCatalog.__init__(cat,
                                      data,
                                      model,
                                      cat.cosmo,
                                      comm=cat.comm)
        galcat = cat
    else:
        galcat = PopulatedHaloCatalog(data, model, cat.cosmo, comm=cat.comm)

    # crash with no particles!
    if galcat.csize == 0:
        raise ValueError(
            "no particles in catalog after populating halo catalog")

    # add Position, Velocity
    galcat['Position'] = transform.StackColumns(galcat['x'], galcat['y'],
                                                galcat['z'])
    galcat['Velocity'] = transform.StackColumns(galcat['vx'], galcat['vy'],
                                                galcat['vz'])

    # add VelocityOffset
    # rsd_factor = (1 + z) / (100 * E(z)); presumably converts velocities to
    # a position offset in Mpc/h for redshift-space distortions -- confirm
    z = cat.attrs['redshift']
    rsd_factor = (1 + z) / (100. * cat.cosmo.efunc(z))
    galcat['VelocityOffset'] = galcat['Velocity'] * rsd_factor

    # add meta-data
    galcat.attrs.update(cat.attrs)
    galcat.attrs.update(model.param_dict)
    galcat.attrs['seed'] = seed
    galcat.attrs['gal_types'] = {t: i for i, t in enumerate(gal_types)}

    # propagate total number of halos for logging
    # (only the root holds the populated model's halo table)
    if galcat.comm.rank == 0:
        Nhalos = len(galcat.model.mock.halo_table)
    else:
        Nhalos = None
    Nhalos = galcat.comm.bcast(Nhalos, root=0)

    # and log some info
    _log_populated_stats(galcat, Nhalos)

    return galcat
示例#10
0
def weigh_and_shift_uni_cat(delta_for_weights,
                            displacements,
                            Nptcles_per_dim,
                            out_Ngrid,
                            boxsize,
                            internal_scale_factor_for_weights=None,
                            out_scale_factor=None,
                            cosmo_params=None,
                            weighted_CIC_mode=None,
                            uni_cat_generator='pmesh',
                            plot_slices=False,
                            verbose=False,
                            return_catalog=False):
    """
    Make uniform catalog, weigh by delta_for_weights, displace by displacements
    and interpolate to grid which we return.

    Parameters
    ----------
    delta_for_weights : None or pmesh.pm.RealField object
        Particles are weighted by 1+delta_for_weights. If None, use weight=1.
        (The field is interpolated to the particle positions and stored in
        the 'Mass' column.)

    displacements : list
        [Psi_x, Psi_y, Psi_z] where Psi_i are pmesh.pm.RealField objects,
        holding the displacement field in different directions (on the grid).
        If None, do not shift. NOTE(review): this function always passes
        ``displacements`` to ``shift_catalog_by_psi_grid``; whether None is
        handled there is not visible here -- confirm.

    Nptcles_per_dim : int
        number of particles per dimension of the regular uniform grid
        (Nptcles_per_dim**3 particles in total).

    out_Ngrid : int
        Nmesh used when painting the shifted catalog to the output grid
        (and for the optional preview plot).

    boxsize : float
        side length of the box; positions are in Mpc/h.

    internal_scale_factor_for_weights, out_scale_factor : float, optional
        if these differ, the painted field is multiplied by the linear
        rescale factor between them -- but an Exception is then raised
        unconditionally (see the rescaling section), so currently only
        equal values are usable when painting.

    cosmo_params : optional
        passed through to ``nbkit03_utils.linear_rescale_fac``.

    weighted_CIC_mode : string
        'sum' or 'avg' (any other value raises an Exception when painting,
        i.e. when return_catalog=False).

    uni_cat_generator : string
        If 'pmesh', use pmesh for generating uniform catalog.
        If 'manual', use an old serial code.

    plot_slices : bool
        If True, paint the weighted (unshifted) catalog and save a preview
        plot to 'inmesh_Np*_Nm*_Ng*.pdf' on rank 0.

    verbose : bool
        Print extra diagnostics.

    return_catalog : bool
        If True, return the shifted catalog instead of painting it.

    Returns
    -------
    delta_shifted : FieldMesh object
        Density delta_shifted of shifted weighted particles (normalized to mean
        of 1 i.e. returning 1+delta). Returned if return_catalog=False.
        NOTE(review): for weighted_CIC_mode='avg' the inline comment says this
        is rho rather than 1+delta -- confirm.

    attrs : meshsource attrs
        Attrs of delta_shifted. Returned if return_catalog=False.

    catalog_shifted : ArrayCatalog object
        Shifted catalog, returned if return_catalog=True.
    """
    comm = CurrentMPIComm.get()

    # ######################################################################
    # Generate uniform catalog with Nptcles_per_dim^3 particles on regular
    # grid
    # ######################################################################

    if uni_cat_generator == 'pmesh':
        # use pmesh generate_uniform_particle_grid
        # http://rainwoodman.github.io/pmesh/pmesh.pm.html?highlight=
        # readout#pmesh.pm.ParticleMesh.generate_uniform_particle_grid
        pmesh = ParticleMesh(
            BoxSize=boxsize,
            Nmesh=[Nptcles_per_dim, Nptcles_per_dim, Nptcles_per_dim])
        ptcles = pmesh.generate_uniform_particle_grid(shift=0.0, dtype='f8')
        #print("type ptcles", type(ptcles), ptcles.shape)
        #print("head ptcles:", ptcles[:5,:])

        dtype = np.dtype([('Position', ('f8', 3))])

        # number of rows is given by number of ptcles on this rank
        uni_cat_array = np.empty((ptcles.shape[0], ), dtype=dtype)
        uni_cat_array['Position'] = ptcles

        uni_cat = ArrayCatalog(
            uni_cat_array,
            comm=None,
            BoxSize=boxsize * np.ones(3),
            Nmesh=[Nptcles_per_dim, Nptcles_per_dim, Nptcles_per_dim])

        print("%d: local Nptcles=%d, global Nptcles=%d" %
              (comm.rank, uni_cat.size, uni_cat.csize))

        del ptcles
        del uni_cat_array

    elif uni_cat_generator == 'manual':

        # Old serial code to generate regular grid and scatter across ranks.
        if comm.rank == 0:
            # Code copied from do_rec_v1.py and adopted

            # Note that nbkit UniformCatalog is random catalog, but we want a catalog
            # where each ptcle sits at grid points of a regular grid.
            # This is what we call 'regular uniform' catalog.
            Np = Nptcles_per_dim
            dtype = np.dtype([('Position', ('f8', 3))])
            # Have Np**3 particles, and each particle has position x,y,z and weight 'Weight'
            uni_cat_array = np.empty((Np**3, ), dtype=dtype)

            # x components in units such that box ranges from 0 to 1. Note dx=1/Np.
            #x_components_1d = np.linspace(0.0, (Np-1)*(L/float(Np)), num=Np, endpoint=True)/L
            x_components_1d = np.linspace(0.0, (Np - 1) / float(Np),
                                          num=Np,
                                          endpoint=True)
            ones_1d = np.ones(x_components_1d.shape)

            # Put particles on the regular grid
            # (outer products give the x, y and z coordinate of every grid
            # point, flattened in C order)
            print("%d: Fill regular uniform catalog" % comm.rank)
            uni_cat_array['Position'][:,
                                      0] = np.einsum('a,b,c->abc',
                                                     x_components_1d, ones_1d,
                                                     ones_1d).reshape(
                                                         (Np**3, ))
            uni_cat_array['Position'][:,
                                      1] = np.einsum('a,b,c->abc', ones_1d,
                                                     x_components_1d,
                                                     ones_1d).reshape(
                                                         (Np**3, ))
            uni_cat_array['Position'][:,
                                      2] = np.einsum('a,b,c->abc', ones_1d,
                                                     ones_1d,
                                                     x_components_1d).reshape(
                                                         (Np**3, ))
            print("%d: Done filling regular uniform catalog" % comm.rank)

            # in nbkit0.3 units must be in Mpc/h
            uni_cat_array['Position'] *= boxsize

        else:
            uni_cat_array = None

        # Scatter across all ranks
        print("%d: Scatter array" % comm.rank)
        from nbodykit.utils import ScatterArray
        uni_cat_array = ScatterArray(uni_cat_array, comm, root=0, counts=None)
        print("%d: Scatter array done. Shape: %s" %
              (comm.rank, str(uni_cat_array.shape)))

        # Save in ArrayCatalog object
        uni_cat = ArrayCatalog(uni_cat_array)
        uni_cat.attrs['BoxSize'] = np.ones(3) * boxsize
        uni_cat.attrs['Nmesh'] = np.ones(3) * Nptcles_per_dim
        # NOTE(review): Nmesh_orig is not defined in this function; unless it
        # is a module-level global, this line raises NameError -- verify.
        uni_cat.attrs['Nmesh_internal'] = np.ones(3) * Nmesh_orig

    else:
        raise Exception('Invalid uni_cat_generator %s' % uni_cat_generator)

    ########################################################################
    # Set weight of particles in uni_cat to delta (interpolated to ptcle
    # positions)
    ########################################################################
    if delta_for_weights is None:
        # set all weights to 1
        uni_cat['Mass'] = np.ones(uni_cat['Position'].shape[0])
    else:
        # weight by delta_for_weights
        nbkit03_utils.interpolate_pm_rfield_to_catalog(
            delta_for_weights, uni_cat, catalog_column_to_save_to='Mass')

    print("%d: rms Mass: %g" %
          (comm.rank, np.sqrt(np.mean(np.array(uni_cat['Mass'])**2))))

    # optionally plot weighted uniform cat before shifting
    if plot_slices:
        # paint the original uni_cat to a grid and plot slice
        import matplotlib.pyplot as plt

        tmp_meshsource = uni_cat.to_mesh(Nmesh=out_Ngrid,
                                         value='Mass',
                                         window='cic',
                                         compensated=False,
                                         interlaced=False)
        # paint to get delta(a_internal)
        tmp_outfield = tmp_meshsource.paint(mode='real')
        # linear rescale factor from internal_scale_factor_for_weights to
        # out_scale_factor
        rescalefac = nbkit03_utils.linear_rescale_fac(
            internal_scale_factor_for_weights,
            out_scale_factor,
            cosmo_params=cosmo_params)
        tmp_outfield = 1.0 + rescalefac * (tmp_outfield - 1.0)
        tmp_mesh = FieldMesh(tmp_outfield)
        plt.imshow(tmp_mesh.preview(Nmesh=32, axes=(0, 1)))
        if comm.rank == 0:
            # NOTE(review): Nmesh_orig is not defined in this function --
            # see the note in the 'manual' branch above.
            plt_fname = 'inmesh_Np%d_Nm%d_Ng%d.pdf' % (Nptcles_per_dim,
                                                       Nmesh_orig, out_Ngrid)
            plt.savefig(plt_fname)
            print("Made %s" % plt_fname)
        del tmp_meshsource, rescalefac, tmp_outfield, tmp_mesh

    # ######################################################################
    # Shift uniform catalog particles by Psi (changes uni_cat)
    # ######################################################################
    nbkit03_utils.shift_catalog_by_psi_grid(
        cat=uni_cat,
        in_displacement_rfields=displacements,
        pos_column='Position',
        pos_units='Mpc/h',
        displacement_units='Mpc/h',
        boxsize=boxsize,
        verbose=verbose)
    #del Psi_rfields

    if return_catalog:
        # return shifted catalog
        return uni_cat

    else:
        # return density of shifted catalog, delta_shifted

        # ######################################################################
        # paint shifted catalog to grid, using field_to_shift as weights
        # ######################################################################

        print("%d: paint shifted catalog to grid using mass weights" %
              comm.rank)

        # this gets 1+delta
        if weighted_CIC_mode == 'sum':
            delta_shifted, attrs = paint_utils.weighted_paint_cat_to_delta(
                uni_cat,
                weight='Mass',
                Nmesh=out_Ngrid,
                weighted_paint_mode=weighted_CIC_mode,
                normalize=True,  # compute 1+delta
                verbose=verbose,
                to_mesh_kwargs={
                    'window': 'cic',
                    'compensated': False,
                    'interlaced': False
                })

        # this get rho
        elif weighted_CIC_mode == 'avg':
            delta_shifted, attrs = paint_utils.mass_avg_weighted_paint_cat_to_rho(
                uni_cat,
                weight='Mass',
                Nmesh=out_Ngrid,
                verbose=verbose,
                to_mesh_kwargs={
                    'window': 'cic',
                    'compensated': False,
                    'interlaced': False
                })

        else:
            raise Exception('Invalid weighted_CIC_mode %s' % weighted_CIC_mode)

        # ######################################################################
        # rescale to output redshift
        # ######################################################################

        if internal_scale_factor_for_weights != out_scale_factor:
            # linear rescale factor from internal_scale_factor_for_weights to
            # out_scale_factor
            rescalefac = nbkit03_utils.linear_rescale_fac(
                internal_scale_factor_for_weights,
                out_scale_factor,
                cosmo_params=cosmo_params)

            delta_shifted *= rescalefac

            # print some info:
            if comm.rank == 0:
                print(
                    "%d: Linear rescalefac from a=%g to a=%g, rescalefac=%g" %
                    (comm.rank, internal_scale_factor_for_weights,
                     out_scale_factor, rescalefac))

            # Deliberate guard: this branch always raises, because scaling
            # 1+delta (rather than delta) by rescalefac is suspected to be
            # wrong. Differing scale factors are therefore unsupported here.
            raise Exception(
                'Check if rescaling of delta_shifted is correct. Looks like 1+delta.'
            )

        if verbose:
            print("%d: delta_shifted: min, mean, max, rms(x-1):" % comm.rank,
                  np.min(delta_shifted), np.mean(delta_shifted),
                  np.max(delta_shifted),
                  np.mean((delta_shifted - 1.)**2)**0.5)

        # get 1+deta mesh from field
        #outmesh = FieldMesh(1 + out_delta)

        # print some info: this makes code never finish (race condition maybe?)
        #nbkit03_utils.rfield_print_info(outfield, comm, 'outfield: ')

        return delta_shifted, attrs
示例#11
0
def poisson_sample_to_points(delta,
                             displacement,
                             pm,
                             nbar,
                             bias=1.,
                             seed=None,
                             comm=None):
    """
    Poisson sample the linear delta and displacement fields to points.

    The steps in this function:

    #.  Apply a biased, lognormal transformation to the input ``delta`` field
    #.  Poisson sample the overdensity field to discrete points
    #.  Distribute the positions of particles uniformly within the mesh cells,
        and assign the displacement field at each cell to the particles

    .. note::
        The Poisson draws and the in-cell uniform offsets are generated on
        the root rank only (the cell means of the full mesh are gathered
        there), then scattered back to the ranks.

    Parameters
    ----------
    delta : RealField
        the linear overdensity field to sample
    displacement : list of RealField (3,)
        the linear displacement fields which is used to move the particles
    pm : ParticleMesh
        not used by this function (the mesh is taken from ``delta.pm``);
        kept for interface compatibility
    nbar : float
        the desired number density of the output catalog of objects
    bias : float, optional
        apply a linear bias to the overdensity field (default is 1.)
    seed : int, optional
        the random seed used to Poisson sample the field to points
    comm : MPI.Comm, optional
        the MPI communicator; defaults to ``MPI.COMM_WORLD``

    Returns
    -------
    pos : array_like, (N, 3)
        the Cartesian positions of each of the generated particles
    displ : array_like, (N, 3)
        the displacement field sampled for each of the generated particles in the
        same units as the ``pos`` array
    """
    if comm is None:
        comm = MPI.COMM_WORLD

    # create a random state with the input seed
    rng = numpy.random.RandomState(seed)

    # apply the lognormal transformation to the initial conditions density
    # this creates a positive-definite delta (necessary for Poisson sampling)
    lagrangian_bias = bias - 1.
    delta = lognormal_transform(delta, bias=lagrangian_bias)

    # mean number of objects per cell
    H = delta.BoxSize / delta.Nmesh
    overallmean = H.prod() * nbar

    # number of objects in each cell (per rank)
    cellmean = delta.value * overallmean
    cellmean = GatherArray(cellmean.flatten(), comm, root=0)

    # rank 0 computes the poisson sampling
    if comm.rank == 0:
        N = rng.poisson(cellmean)
    else:
        N = None

    # scatter N back evenly across the ranks
    counts = comm.allgather(delta.value.size)
    N = ScatterArray(N, comm, root=0, counts=counts).reshape(delta.shape)

    Nlocal = N.sum()  # local number of particles
    Ntot = comm.allreduce(Nlocal)  # the collective number of particles
    nonzero_cells = N.nonzero()  # indices of nonzero cells

    # particle count of each nonzero cell (in the same order as
    # ``nonzero_cells``); used to repeat per-cell values per particle
    N_per_cell = N[nonzero_cells]

    # initialize the mesh of particle positions and displacement
    # this has the shape: (number of dimensions, number of nonzero cells)
    pos_mesh = numpy.empty(numpy.shape(nonzero_cells), dtype=delta.dtype)
    disp_mesh = numpy.empty_like(pos_mesh)

    # generate the coordinates for each nonzero cell
    for i in range(delta.ndim):

        # particle positions initially on the coordinate grid
        pos_mesh[i] = numpy.squeeze(delta.pm.x[i])[nonzero_cells[i]]

        # displacements for each particle
        disp_mesh[i] = displacement[i][nonzero_cells]

    # rank 0 computes the in-cell uniform offsets
    # (one column per dimension, drawn in dimension order)
    if comm.rank == 0:
        in_cell_shift = numpy.empty((Ntot, delta.ndim), dtype=delta.dtype)
        for i in range(delta.ndim):
            in_cell_shift[:, i] = rng.uniform(0, H[i], size=Ntot)
    else:
        in_cell_shift = None

    # scatter the in-cell uniform offsets back to the ranks
    counts = comm.allgather(Nlocal)
    in_cell_shift = ScatterArray(in_cell_shift, comm, root=0, counts=counts)

    # initialize the output array of particle positions and displacement
    # this has shape: (local number of particles, number of dimensions)
    pos = numpy.zeros((Nlocal, delta.ndim), dtype=delta.dtype)
    disp = numpy.zeros_like(pos)

    # coordinates of each object (placed randomly in each cell),
    # wrapped back into the periodic box
    for i in range(delta.ndim):
        pos[:, i] = numpy.repeat(pos_mesh[i], N_per_cell) + in_cell_shift[:, i]
        pos[:, i] %= delta.BoxSize[i]

    # displacements of each object
    for i in range(delta.ndim):
        disp[:, i] = numpy.repeat(disp_mesh[i], N_per_cell)

    return pos, disp