Example #1
0
def gather_new(sendbuf, root=0, split_recvbuf=False):
    """
    gather that avoids overflow of displs.

    Every rank calls this with its local ``sendbuf``; the arrays are
    concatenated on ``root`` in rank order.  When all ranks share the same
    trailing dimensions (``shape[1:]``), one leading-axis row is packed into a
    contiguous MPI datatype so that counts and displacements are measured in
    rows rather than in elements.  The transfer is issued as a series of
    ``Gatherv`` calls, each covering a window of at most ``BLKSIZE`` units per
    rank.

    Args:
        sendbuf: array-like local contribution (converted to a C-ordered
            numpy array).
        root: rank that receives the gathered data.
        split_recvbuf: if true, ``root`` gets one array per rank, reshaped to
            that rank's original shape.

    Returns:
        On ``root``: a list of per-rank arrays when ``split_recvbuf`` is set;
        otherwise the concatenated buffer reshaped to ``(-1,) + shape[1:]``
        (``shape`` being the root's own input shape), or left 1-D when that
        reshape raises ``ValueError``.
        On every other rank: the local ``sendbuf`` as a numpy array.
    """
    sendbuf = numpy.asarray(sendbuf, order='C')
    shape = sendbuf.shape
    size_dtype = comm.allgather((shape, sendbuf.dtype.char))
    rshape = [x[0] for x in size_dtype]
    # numpy.prod(()) is the float 1.0, which would turn counts (and later the
    # buffer size, slice bounds and range limits) into floats for 0-d/1-d
    # inputs; pin an integer dtype instead.
    counts = numpy.array([numpy.prod(x, dtype=numpy.int64) for x in rshape],
                         dtype=numpy.int64)

    dtype = mpi_dtype = numpy.result_type(*[x[1] for x in size_dtype]).char
    _assert(sendbuf.dtype == mpi_dtype or sendbuf.size == 0)

    # Rows can only be packed into a contiguous type when every rank has the
    # same trailing shape and a row holds at least one element (otherwise we
    # would divide counts by zero when shape[1:] contains a 0).
    matched = all(x[1:] == rshape[0][1:] for x in rshape[1:])
    each_count = int(numpy.prod(rshape[0][1:], dtype=numpy.int64))
    use_contig = matched and each_count > 0
    if use_contig:
        elem_dtype = MPI._typedict.get(mpi_dtype)
        counts = counts // each_count
        mpi_dtype = MPI.Datatype(elem_dtype).Create_contiguous(
            each_count).Commit()
    else:
        each_count = 1

    is_root = rank == root
    send_flat = sendbuf.ravel()
    if is_root:
        displs = numpy.append(0, numpy.cumsum(counts[:-1]))
        recvbuf = numpy.empty(int(counts.sum()) * each_count, dtype=dtype)
    try:
        # counts is derived from an allgather, so every rank should walk the
        # same windows, as the collective Gatherv requires.
        for p0, p1 in lib.prange(0, int(numpy.max(counts)), BLKSIZE):
            send_seg = [send_flat[p0 * each_count:p1 * each_count], mpi_dtype]
            if is_root:
                counts_seg = _segment_counts(counts, p0, p1)
                comm.Gatherv(send_seg,
                             [recvbuf, counts_seg, displs + p0, mpi_dtype],
                             root)
            else:
                comm.Gatherv(send_seg, None, root)
    finally:
        # Free the committed datatype even if a Gatherv call raises.
        if use_contig:
            mpi_dtype.Free()

    if not is_root:
        return sendbuf
    if split_recvbuf:
        return [
            recvbuf[p0 * each_count:(p0 + c) * each_count].reshape(s)
            for p0, c, s in zip(displs, counts, rshape)
        ]
    # NOTE(review): when ranks disagree on shape[1:] but the total size still
    # divides, this reshape succeeds with the root's trailing shape even
    # though rows do not line up with the other ranks' data.
    try:
        return recvbuf.reshape((-1, ) + shape[1:])
    except ValueError:
        return recvbuf
Example #2
0
def alltoall_new(sendbuf, split_recvbuf=False):
    """
    alltoall that avoids overflow of displs.

    ``sendbuf`` holds one array per rank (``pool.size`` entries); entry ``j``
    is sent to rank ``j``.  All arrays are cast to the dtype of rank 0's first
    entry.  Data move in units of a contiguous MPI datatype of ``each_count``
    elements, over several ``Alltoallv`` calls of at most ``BLKSIZE`` units
    per peer.

    Args:
        sendbuf: sequence of array-likes, one per destination rank.
        split_recvbuf: if true, return one array per source rank reshaped to
            the shape that rank sent.

    Returns:
        A list of per-source arrays when ``split_recvbuf`` is set, otherwise
        the received data concatenated into one 1-D array in source-rank
        order.

    Raises:
        ValueError: if no axis splits every local send array into blocks of
            equal size (e.g. 0-d arrays).
    """
    import math
    from functools import reduce

    _assert(len(sendbuf) == pool.size)

    dtype = comm.bcast(sendbuf[0].dtype)
    sendbuf = [numpy.asarray(x, dtype) for x in sendbuf]
    elem_dtype = MPI._typedict.get(dtype.char)
    sshape = [x.shape for x in sendbuf]
    # Find an axis i for which every local array has the same number of
    # elements per index along i; that number divides every local array size.
    for i in range(len(sshape[0])):
        counts = numpy.array([numpy.prod(s) // s[i] for s in sshape])
        if all(counts == counts[0]):
            local_count = int(counts[0])
            break
    else:
        raise ValueError('alltoall_new: no axis splits all send arrays into '
                         'equally sized blocks (shapes %s)' % (sshape, ))

    # Each rank derives its block size from its own send shapes, but sender
    # and receiver must use the same block size or the per-window Alltoallv
    # counts disagree (and prod(shape) // each_count may truncate on the
    # receiver).  The gcd over all ranks divides every exchanged array; a 0
    # (only possible when all of a rank's arrays are empty) is neutral.
    each_count = reduce(math.gcd, comm.allgather(local_count), 0) or 1
    rshape = comm.alltoall(sshape)

    scounts = numpy.asarray([x.size // each_count for x in sendbuf])
    sdispls = numpy.append(0, numpy.cumsum(scounts[:-1]))
    # hstack of 1-D arrays is already 1-D; no further ravel needed.
    sendbuf = numpy.hstack([x.ravel() for x in sendbuf])

    rcounts = numpy.asarray([
        int(numpy.prod(x, dtype=numpy.int64)) // each_count for x in rshape
    ])
    rdispls = numpy.append(0, numpy.cumsum(rcounts[:-1]))
    recvbuf = numpy.empty(int(rcounts.sum()) * each_count, dtype=dtype)

    max_counts = int(max(numpy.max(scounts), numpy.max(rcounts)))
    mpi_dtype = MPI.Datatype(elem_dtype).Create_contiguous(each_count).Commit()
    try:
        # DO NOT use lib.prange: it may terminate early on some processes.
        # The module-level prange is used instead, presumably so that every
        # rank performs the same number of Alltoallv calls.
        for p0, p1 in prange(0, max_counts, BLKSIZE):
            scounts_seg = _segment_counts(scounts, p0, p1)
            rcounts_seg = _segment_counts(rcounts, p0, p1)
            comm.Alltoallv([sendbuf, scounts_seg, sdispls + p0, mpi_dtype],
                           [recvbuf, rcounts_seg, rdispls + p0, mpi_dtype])
    finally:
        # Free the committed datatype even if an Alltoallv call raises.
        mpi_dtype.Free()

    if split_recvbuf:
        return [
            recvbuf[p0 * each_count:(p0 + c) * each_count].reshape(s)
            for p0, c, s in zip(rdispls, rcounts, rshape)
        ]
    return recvbuf
Example #3
0
def scatter_new(sendbuf, root=0, data=None):
    """
    scatter that avoids the displs overflow.

    On ``root``, ``sendbuf`` is a sequence of arrays, one per rank; rank ``i``
    receives ``sendbuf[i]`` with that array's shape.  When every piece has
    the same trailing dimensions (``shape[1:]``), one leading-axis row is
    packed into a contiguous MPI datatype so counts/displacements are
    measured in rows; the transfer is issued as several ``Scatterv`` calls of
    at most ``BLKSIZE`` units per rank.

    Args:
        sendbuf: on ``root``, the per-rank arrays (their shapes, sizes and the
            first one's dtype are always used).  Only significant on ``root``.
        root: rank that owns the data.
        data: optional pre-packed buffer on ``root`` holding the pieces back
            to back; when given it is sent instead of concatenating
            ``sendbuf``.

    Returns:
        This rank's piece, reshaped to the shape of its ``sendbuf`` entry.
    """
    if rank == root:
        dtype = sendbuf[0].dtype
        shape = [x.shape for x in sendbuf]
        matched = all(x[1:] == shape[0][1:] for x in shape[1:])
        shape = comm.scatter(shape)
        counts = numpy.asarray([x.size for x in sendbuf])
        comm.bcast((dtype, counts, matched))
        if data is None:
            sendbuf = numpy.hstack([
                numpy.asarray(x, dtype, order='C').ravel() for x in sendbuf
            ])
        else:
            sendbuf = numpy.asarray(data, order='C')
    else:
        shape = comm.scatter(None)
        dtype, counts, matched = comm.bcast(None)

    # numpy.prod(()) is the float 1.0; keep the row length integral so 1-d
    # pieces do not yield float counts, slice bounds and range limits.
    each_count = int(numpy.prod(shape[1:], dtype=numpy.int64))
    # A row must hold at least one element, otherwise counts // 0.
    use_contig = matched and each_count > 0
    if use_contig:
        elem_dtype = MPI._typedict.get(dtype.char)
        counts = counts // each_count
        mpi_dtype = MPI.Datatype(elem_dtype).Create_contiguous(
            each_count).Commit()
    else:
        each_count = 1
        # Typecode string, as gather_new passes for its element datatype.
        mpi_dtype = dtype.char

    displs = numpy.append(0, numpy.cumsum(counts[:-1]))
    recvbuf = numpy.empty(int(numpy.prod(shape, dtype=numpy.int64)),
                          dtype=dtype)
    try:
        for p0, p1 in prange(0, int(numpy.max(counts)), BLKSIZE):
            counts_seg = _segment_counts(counts, p0, p1)
            comm.Scatterv([sendbuf, counts_seg, displs + p0, mpi_dtype],
                          [recvbuf[p0 * each_count:p1 * each_count],
                           mpi_dtype],
                          root)
    finally:
        # Free the committed datatype even if a Scatterv call raises.
        if use_contig:
            mpi_dtype.Free()
    return recvbuf.reshape(shape)
Example #4
0
 def testBoolEqNe(self):
     """Each entry of ``datatypes`` is truthy, == and not != to MPI.Datatype(entry)."""
     for datatype in datatypes:
         truthy = bool(datatype)
         self.assertTrue(truthy)
         same = datatype == MPI.Datatype(datatype)
         self.assertTrue(same)
         differs = datatype != MPI.Datatype(datatype)
         self.assertFalse(differs)