Example #1
    def __init__(self, pos, weights=None, boxsize=None):
        """Create a dataset object for points located at ``pos`` in a box.

        Parameters
        ----------
        pos : array_like
            Point positions, of shape (Npoints, Ndim).
        weights : optional
            Per-point weights handed to ``kdcount.KDAttr``. When ``None``,
            a ``constant_array`` of ``len(pos)`` entries set to 1 is used.
        boxsize : optional
            Box size handed to ``kdcount.KDTree``; it will be broadcast to
            the dimension of the points.
        """
        self.pos = pos
        self.tree = kdcount.KDTree(self.pos, boxsize=boxsize)
        # store the box size as reported back by the tree, not the raw argument
        self.boxsize = self.tree.boxsize

        if weights is None:
            # default to unit weights for every point
            weights = constant_array(len(self.pos))
            weights.value[...] = 1

        self.weights = weights
        self.attr = kdcount.KDAttr(self.tree, weights)
Example #2
    def _run(self, pos, w, pos_sec, w_sec, boxsize=None, bunchsize=10000):
        """
        Internal function to run the 3PCF algorithm on the input data and
        weights.

        The input data/weights have already been domain-decomposed, and
        the loads should be balanced on all ranks.

        Parameters
        ----------
        pos, w :
            positions and weights of the primary objects on this rank
        pos_sec, w_sec :
            positions and weights of the secondary objects on this rank
        boxsize : optional
            per-axis box size; when not ``None``, separations are wrapped
            periodically
        bunchsize : int, optional
            passed as ``bunch`` to the KD-tree pair enumeration

        Returns
        -------
        BinnedStatistic
            binned over (``r1``, ``r2``) using ``self.attrs['edges']`` for
            both dimensions, with one ``corr_<ell>`` variable per entry of
            ``self.attrs['poles']``
        """
        edges = self.attrs['edges']
        ells = self.attrs['poles']

        # maximum radius
        rmax = numpy.max(edges)

        # the array to hold output values
        nbins = len(edges) - 1
        Nell = len(ells)
        zeta = numpy.zeros((Nell, nbins, nbins), dtype='f8')
        alms = {}
        walms = {}

        # compute the Ylm expressions we need
        if self.comm.rank == 0:
            self.logger.info("computing Ylm expressions...")
        Ylm_cache = YlmCache(ells, self.comm)
        if self.comm.rank == 0:
            self.logger.info("...done")

        # make the KD-tree holding the secondaries
        tree_sec = kdcount.KDTree(pos_sec, boxsize=boxsize).root

        def callback(r, i, j, iprim=None):
            # accumulate the binned a_lm sums of one primary into alms/walms

            # remove self pairs
            valid = r > 0.
            r = r[valid]
            i = i[valid]

            # re-centered position array
            dpos = (pos_sec[i] - pos[iprim])

            # enforce periodicity in dpos
            if boxsize is not None:
                for axis, col in enumerate(dpos.T):
                    col[col > boxsize[axis]*0.5] -= boxsize[axis]
                    col[col <= -boxsize[axis]*0.5] += boxsize[axis]

            # normalize to unit vectors
            recen_pos = dpos / r[:, numpy.newaxis]

            # find the mapping of r to rbins
            dig = numpy.searchsorted(edges, r, side='left')

            # evaluate all Ylms
            Ylms = Ylm_cache(recen_pos[:, 0] + 1j*recen_pos[:, 1], recen_pos[:, 2])

            # weight of the primary
            w0 = w[iprim]

            # loop over each (l,m) pair
            for (l, m) in Ylms:

                # the Ylm evaluated at galaxy positions
                weights = Ylms[(l, m)] * w_sec[i]

                # sum over for each radial bin
                alm = alms.setdefault((l, m), numpy.zeros(nbins, dtype='c16'))
                walm = walms.setdefault((l, m), numpy.zeros(nbins, dtype='c16'))

                # [1:-1] discards the underflow/overflow slots of the digitization
                r1 = numpy.bincount(dig, weights=weights.real, minlength=nbins+2)[1:-1]
                alm[...] += r1
                walm[...] += w0 * r1
                if m != 0:
                    i1 = numpy.bincount(dig, weights=weights.imag, minlength=nbins+2)[1:-1]
                    alm[...] += 1j*i1
                    walm[...] += w0*1j*i1

        # determine rank with largest load
        loads = self.comm.allgather(len(pos))
        largest_load = numpy.argmax(loads)
        chunk_size = max(loads) // 10

        # chunk_size is 0 when the largest load is below 10; skip progress
        # reporting then instead of taking a modulo by zero
        report_progress = self.comm.rank == largest_load and chunk_size > 0

        # compute multipoles for each primary (s vector in the paper)
        for iprim in range(len(pos)):
            # alms must be clean for each primary particle; (s) in eq 15 and 8 of arXiv:1506.02040v2
            alms.clear()
            walms.clear()

            tree_prim = kdcount.KDTree(numpy.atleast_2d(pos[iprim]), boxsize=boxsize).root
            tree_sec.enum(tree_prim, rmax, process=callback, iprim=iprim, bunch=bunchsize)

            if report_progress and iprim % chunk_size == 0:
                self.logger.info("%d%% done" % (10*iprim//chunk_size))

            # combine alms into zeta(s);
            # this cannot be done in the callback because
            # it is a nonlinear function (outer product) of alm.
            for (l, m) in alms:
                alm = alms[(l, m)]
                walm = walms[(l, m)]

                # compute alm * conjugate(alm)
                alm_w_alm = numpy.outer(walm, alm.conj())
                if m != 0:
                    # add in the -m contribution for m != 0
                    alm_w_alm += alm_w_alm.T
                zeta[Ylm_cache.ell_to_iell[l], ...] += alm_w_alm.real

        # sum across all ranks
        zeta = self.comm.allreduce(zeta)

        # normalize according to Eq. 15 of Slepian et al. 2015
        # differs by factor of (4 pi)^2 / (2l+1) from the C++ code
        zeta /= (4*numpy.pi)

        # make a BinnedStatistic
        dtype = numpy.dtype([('corr_%d' % ell, zeta.dtype) for ell in ells])
        data = numpy.empty(zeta.shape[-2:], dtype=dtype)
        for i, ell in enumerate(ells):
            data['corr_%d' % ell] = zeta[i]

        # return the result, binned identically along both radial dimensions
        return BinnedStatistic(['r1', 'r2'], [edges, edges], data)
Example #3
def cgm(comm, data, domain, rperp, rpar, los, boxsize):
    """
    Perform the cylindrical grouping method

    This outputs a structured array with the same length as the input data
    with the following fields for each object in the original data:

    #. cgm_type :
        a flag specifying the type for each object,
        with 0 specifying CGM central and 1 denoting CGM satellite
    #. cgm_haloid :
        The index of the CGM object this object belongs to; an integer
        between 0 and the total number of CGM halos
    #. num_cgm_sats :
        The number of satellites in the CGM halo

    Parameters
    ----------
    comm :
        the MPI communicator
    data : CatalogSource
        catalog with sorted input data, including Position
    domain :
        the domain decomposition
    rperp, rpar : float
        the maximum distances to group objects together in the directions
        perpendicular and parallel to the line-of-sight; the cylinder
        has radius ``rperp`` and height ``2 * rpar``
    los :
        the line-of-sight vector; when ``None`` the line-of-sight of each
        pair is taken along the pair's mid-point vector instead
    boxsize :
        the boxsize, or ``None`` if not using periodic boundary conditions

    Returns
    -------
    numpy.ndarray
        structured array of length ``len(data)`` holding the fields
        ``cgm_type``, ``cgm_haloid`` and ``num_cgm_sats``
    """
    # whether we do periodic boundary conditions
    periodic = boxsize is not None
    flat_sky = los is not None

    # the maximum distance still inside the cylinder set by rperp,rpar
    rperp2 = rperp**2
    rpar2 = rpar**2
    rmax = (rperp2 + rpar2)**0.5

    # NOTE(review): the column names 'origind' and 'sortindex' are inferred
    # from the names of the unpacked variables -- confirm against the caller
    pos0, origind0, sortindex0 = data.compute(data['Position'],
                                              data['origind'],
                                              data['sortindex'])

    layout1 = domain.decompose(pos0, smoothing=0)
    pos1 = layout1.exchange(pos0)
    origind1 = layout1.exchange(origind0)
    sortindex1 = layout1.exchange(sortindex0)

    # exchange particles across ranks, accounting for smoothing radius
    layout2 = domain.decompose(pos1, smoothing=rmax)
    pos2 = layout2.exchange(pos1)
    sortindex2 = layout2.exchange(sortindex1)
    startrank = layout2.exchange(numpy.ones(len(pos1), dtype='i4') * comm.rank)

    # make the KD-tree
    tree1 = kdcount.KDTree(pos1, boxsize=boxsize).root
    tree2 = kdcount.KDTree(pos2, boxsize=boxsize).root

    dataframe = []
    j_gt_i = numpy.zeros(len(pos1), dtype='f4')
    wrong_rank = numpy.zeros(len(pos1), dtype='f4')

    def callback(r, i, j):
        # collect the pairs of this bunch that fall inside the cylinder

        r1 = pos1[i]
        r2 = pos2[j]
        dr = r1 - r2

        # enforce periodicity in dpos
        if periodic:
            for axis, col in enumerate(dr.T):
                col[col > boxsize[axis] * 0.5] -= boxsize[axis]
                col[col <= -boxsize[axis] * 0.5] += boxsize[axis]

        # los distance
        if flat_sky:
            rlos2 = numpy.einsum("ij,j->i", dr, los)**2
        else:
            # project the separation onto the pair's mid-point direction
            center = 0.5 * (r1 + r2)
            dot2 = numpy.einsum('ij, ij->i', dr, center)**2
            center2 = numpy.einsum('ij, ij->i', center, center)
            rlos2 = dot2 / center2

        # sky
        dr2 = numpy.einsum('ij, ij->i', dr, dr)
        rsky2 = numpy.abs(dr2 - rlos2)

        # save the valid pairs
        # To Be Valid: pairs must be within cylinder (compare rperp and rpar)
        valid = (rsky2 <= rperp2) & (rlos2 <= rpar2)
        i = i[valid]
        j = j[valid]

        # the correctly sorted indices of particles
        sort_i = sortindex1[i]
        sort_j = sortindex2[j]

        # the rank where the j object lives
        rank_j = startrank[j]

        # track pairs where sorted j > sorted i
        weights = numpy.where(sort_i < sort_j, 1, 0)
        j_gt_i[:] += numpy.bincount(i, weights=weights, minlength=len(pos1))

        # track pairs where j rank is wrong
        weights *= numpy.where(rank_j != comm.rank, 1, 0)
        wrong_rank[:] += numpy.bincount(i, weights=weights,
                                        minlength=len(pos1))

        # save the valid pairs for final calculations
        res = numpy.vstack([i, j, sort_i, sort_j]).T
        dataframe.append(res)

    # add all the valid pairs to a dataframe
    tree1.enum(tree2, rmax, process=callback)

    # sorted indices of objects that are centrals
    # (objects with no pairs with j > i)
    centrals = set(sortindex1[(j_gt_i == 0)])

    # sorted indices of objects that might be centrals
    # (pairs with j>i that live on other ranks)
    maybes = set(sortindex1[(wrong_rank > 0)])

    # store the pairs in a pandas dataframe for fast groupby;
    # numpy.concatenate raises on an empty list, so fall back to zero rows
    if dataframe:
        pairs = numpy.concatenate(dataframe, axis=0)
    else:
        pairs = numpy.empty((0, 4), dtype='i8')
    df = pd.DataFrame(pairs, columns=['i', 'j', 'sort_i', 'sort_j'])

    # we sort by the correct sorted index in descending order which puts
    # highest priority objects first
    df.sort_values("sort_i", ascending=False, inplace=True)

    # index by the correct sorted order
    df.set_index('sort_i', inplace=True)

    # to find centrals, considers objects that could be satellites of another
    # (pairs with sort_j > sort_i)
    possible_cens = df[(df['sort_j'] > df.index.values)]
    possible_cens = possible_cens.drop(centrals, errors='ignore')

    # remove objs paired with cens: every row of an object that has a pair
    # with a known central is dropped
    # NOTE(review): reconstructed from a truncated call -- confirm that this
    # matches the file's own helper for this step
    paired_with_cens = possible_cens.index[possible_cens['sort_j'].isin(centrals)]
    possible_cens = possible_cens.drop(paired_with_cens)

    # sorted indices of objects that have pairs on other ranks
    # these objects are already "maybe" centrals
    on_other_ranks = sortindex1[(wrong_rank > 0)]

    # find the centrals and associated halo labels for each central
    all_centrals, labels = _find_centrals(comm, possible_cens, on_other_ranks,
                                          centrals, maybes)

    # reset the index so that 'sort_i' is a regular column again
    df.reset_index(inplace=True)

    # add the halo labels for each pair in the dataframe; the Series must be
    # named because the name becomes the joined column
    labels = pd.Series(labels,
                       index=pd.Index(all_centrals, name='sort_i'),
                       name='label_i')
    df = df.join(labels, on='sort_i')
    labels.name = 'label_j'
    labels.index.name = 'sort_j'
    df = df.join(labels, on='sort_j')

    # initialize the output arrays
    labels = numpy.zeros(len(pos1), dtype='i8') - 1  # indexed by i
    types = numpy.zeros(len(pos1), dtype='u4')  # indexed by i
    counts = numpy.zeros(len(pos2), dtype='i8')  # indexed by j

    # assign labels of the centrals; filter on 'label_i' since that is the
    # column written into the integer array (it must not contain NaN)
    cens = df.dropna(subset=['label_i']).drop_duplicates('i')
    labels[cens['i'].values] = cens['label_i'].values

    # objects on this rank that are satellites
    # (no label for the 1st object in pair but a label for the 2nd object)
    sats = (df['label_i'].isnull()) & (~df['label_j'].isnull())
    df = df[sats]

    # find the corresponding central for each satellite
    df = df.sort_values('sort_j', ascending=False)
    df.set_index('sort_i', inplace=True)
    sats_grouped = df.groupby('sort_i', sort=False, as_index=False)
    centrals = sats_grouped.first()  # these are the centrals for each satellite

    # update the satellite info with its pair with the highest priority
    cens_i = centrals['i'].values
    cens_j = centrals['j'].values
    counts += numpy.bincount(cens_j, minlength=len(pos2))
    types[cens_i] = 1
    labels[cens_i] = centrals['label_j'].values

    # sum counts across ranks (take the sum of any repeated objects)
    counts = layout2.gather(counts, mode='sum')

    # output fields
    dtype = numpy.dtype([('cgm_haloid', 'i8'), ('num_cgm_sats', 'i8'),
                         ('cgm_type', 'u4'), ('origind', 'u4')])
    out = numpy.empty(len(data), dtype=dtype)

    # gather the data back onto the original ranks
    # no ghosts for this domain layout so choose any particle
    out['cgm_haloid'] = layout1.gather(labels, mode='any')
    out['origind'] = layout1.gather(origind1, mode='any')
    out['num_cgm_sats'] = layout1.gather(counts, mode='any')
    out['cgm_type'] = layout1.gather(types, mode='any')

    # restore the original order
    mpsort.sort(out, orderby='origind', comm=comm)

    fields = ['cgm_type', 'cgm_haloid', 'num_cgm_sats']
    return out[fields]