def gather_new(sendbuf, root=0, split_recvbuf=False):
    """Gather one array from every rank onto ``root``, avoiding ``displs`` overflow.

    When all ranks' arrays share the same trailing shape (``shape[1:]``), one
    row is packed into a contiguous MPI datatype. Counts and displacements are
    then expressed in rows rather than elements, and the transfer is split into
    segments of ``BLKSIZE`` rows.

    Parameters
    ----------
    sendbuf : array_like
        This rank's contribution. It is converted to a C-ordered array and must
        already have the common dtype (checked via ``_assert``) unless it is
        empty.
    root : int
        Rank that receives the gathered data.
    split_recvbuf : bool
        On ``root``, return a list with one array per rank (each in the shape
        that rank sent) instead of one concatenated array.

    Returns
    -------
    On ``root``:
        The list of per-rank arrays if ``split_recvbuf`` is true. Otherwise the
        concatenated buffer reshaped to ``(-1,) + shape[1:]``, where ``shape``
        is the root's own shape. If that reshape raises ``ValueError``, the
        flat buffer is returned.
    On other ranks:
        ``sendbuf`` after ``numpy.asarray`` conversion.
    """
    sendbuf = numpy.asarray(sendbuf, order='C')
    shape = sendbuf.shape
    size_dtype = comm.allgather((shape, sendbuf.dtype.char))
    rshape = [x[0] for x in size_dtype]
    # Integer element counts: numpy.prod(()) would be the float 1.0 for 0-d input.
    counts = numpy.array([numpy.prod(x, dtype=numpy.int64) for x in rshape],
                         dtype=numpy.int64)

    dtype = mpi_dtype = numpy.result_type(*[x[1] for x in size_dtype]).char
    _assert(sendbuf.dtype == mpi_dtype or sendbuf.size == 0)

    # rshape comes from allgather, so every rank reaches the same decision below.
    matched = all(x[1:] == rshape[0][1:] for x in rshape[1:])
    each_count = 1
    if matched:
        # Elements per row. int() matters: for 1-D arrays rshape[0][1:] is ()
        # and an un-typed numpy.prod(()) returns a float, which breaks slicing,
        # the prange bounds and Create_contiguous.
        each_count = int(numpy.prod(rshape[0][1:], dtype=numpy.int64))
        if each_count == 0:
            # Zero-size rows: every array is empty; use the plain element path
            # instead of dividing the counts by zero.
            matched = False
            each_count = 1
    if matched:
        elem_dtype = MPI._typedict.get(mpi_dtype)
        counts = counts // each_count
        mpi_dtype = MPI.Datatype(elem_dtype).Create_contiguous(
            each_count).Commit()

    if rank == root:
        displs = numpy.append(0, numpy.cumsum(counts[:-1]))
        recvbuf = numpy.empty(int(counts.sum()) * each_count, dtype=dtype)

        sendbuf = sendbuf.ravel()
        for p0, p1 in lib.prange(0, numpy.max(counts), BLKSIZE):
            # Per-rank counts for rows [p0, p1) of this segment.
            counts_seg = _segment_counts(counts, p0, p1)
            comm.Gatherv([sendbuf[p0 * each_count:p1 * each_count], mpi_dtype],
                         [recvbuf, counts_seg, displs + p0, mpi_dtype], root)
        if matched:
            mpi_dtype.Free()

        if split_recvbuf:
            return [
                recvbuf[p0 * each_count:(p0 + c) * each_count].reshape(s)
                for p0, c, s in zip(displs, counts, rshape)
            ]
        try:
            return recvbuf.reshape((-1, ) + shape[1:])
        except ValueError:
            return recvbuf
    else:
        send_seg = sendbuf.ravel()
        # Same number of segments as on root: numpy.max(counts) is identical on
        # all ranks because counts derives from the allgathered shapes.
        for p0, p1 in lib.prange(0, numpy.max(counts), BLKSIZE):
            comm.Gatherv(
                [send_seg[p0 * each_count:p1 * each_count], mpi_dtype], None,
                root)
        if matched:
            mpi_dtype.Free()
        return sendbuf
def alltoall_new(sendbuf, split_recvbuf=False):
    """All-to-all exchange of arrays, transferred in segments of ``BLKSIZE`` units.

    ``sendbuf`` holds one array per process in ``pool`` (checked with
    ``_assert``). ``sendbuf[i]`` is presumably the array sent to rank ``i`` —
    confirm. All arrays are cast to the dtype broadcast by ``comm.bcast``. An
    axis is chosen on which ``prod(shape) // shape[axis]`` is equal for all of
    this rank's arrays. That per-unit element count (``each_count``) becomes a
    contiguous MPI datatype, so counts and displacements are expressed in units
    rather than elements.

    Parameters
    ----------
    sendbuf : sequence of array_like
        One array per process in ``pool``.
    split_recvbuf : bool
        If true, return a list with one array per source, reshaped to the shape
        that source sent. Otherwise return the flat receive buffer.

    Raises
    ------
    ValueError
        If no axis gives equal per-unit counts across this rank's arrays.
    """
    _assert(len(sendbuf) == pool.size)
    dtype = comm.bcast(sendbuf[0].dtype)
    sendbuf = [numpy.asarray(x, dtype) for x in sendbuf]
    elem_dtype = MPI._typedict.get(dtype.char)

    sshape = [x.shape for x in sendbuf]
    # find an axis to segment
    # NOTE(review): numpy.prod(s) // s[i] divides by zero if an array has a
    # zero-length axis i — confirm callers never pass such shapes.
    for i in range(len(sshape[0])):
        counts = numpy.array([numpy.prod(s) // s[i] for s in sshape])
        if all(counts == counts[0]):
            each_count = counts[0]
            break
    else:
        raise ValueError
    # NOTE(review): each_count is derived from this rank's sshape only, but the
    # segmented Alltoallv below pairs unit p0..p1 on the sender with p0..p1 on
    # the receiver. That appears to require every rank to pick the same
    # each_count, with every received size divisible by it (rcounts uses
    # floor division). Confirm callers guarantee this.
    mpi_dtype = MPI.Datatype(elem_dtype).Create_contiguous(each_count).Commit()

    rshape = comm.alltoall(sshape)
    # Send side: counts/offsets in units of each_count elements.
    scounts = numpy.asarray([x.size // each_count for x in sendbuf])
    sdispls = numpy.append(0, numpy.cumsum(scounts[:-1]))
    sendbuf = numpy.hstack([x.ravel() for x in sendbuf])

    # Receive side: sizes come from the shapes the other ranks announced.
    rcounts = numpy.asarray([numpy.prod(x) // each_count for x in rshape])
    rdispls = numpy.append(0, numpy.cumsum(rcounts[:-1]))
    recvbuf = numpy.empty(sum(rcounts) * each_count, dtype=dtype)

    max_counts = max(numpy.max(scounts), numpy.max(rcounts))
    sendbuf = sendbuf.ravel()
    # Do NOT use lib.prange here: lib.prange may terminate early in some
    # processes. Every rank must issue the same number of Alltoallv calls,
    # while max_counts can differ between ranks. Presumably the module-level
    # prange equalises the step count across ranks — confirm.
    for p0, p1 in prange(0, max_counts, BLKSIZE):
        scounts_seg = _segment_counts(scounts, p0, p1)
        rcounts_seg = _segment_counts(rcounts, p0, p1)
        comm.Alltoallv([sendbuf, scounts_seg, sdispls + p0, mpi_dtype],
                       [recvbuf, rcounts_seg, rdispls + p0, mpi_dtype])
    mpi_dtype.Free()

    if split_recvbuf:
        return [
            recvbuf[p0 * each_count:(p0 + c) * each_count].reshape(shape)
            for p0, c, shape in zip(rdispls, rcounts, rshape)
        ]
    else:
        return recvbuf
def scatter_new(sendbuf, root=0, data=None):
    """Scatter one array per rank from ``root``, avoiding ``displs`` overflow.

    When all arrays share the same trailing shape (``shape[1:]``), one row is
    packed into a contiguous MPI datatype. Counts and displacements are then
    expressed in rows, and the transfer is split into ``BLKSIZE``-row segments.

    Parameters
    ----------
    sendbuf : sequence of arrays
        Significant on ``root`` only. ``sendbuf[i]`` is the array for rank
        ``i``. The shapes and the dtype of ``sendbuf[0]`` define what every
        rank receives.
    root : int
        Rank that owns the data.
    data : array_like, optional
        On ``root``, a buffer sent instead of concatenating ``sendbuf``. It is
        converted with ``numpy.asarray(data, order='C')``. Presumably it holds
        the flattened arrays of ``sendbuf`` back to back in the same dtype —
        confirm with callers.

    Returns
    -------
    numpy.ndarray
        This rank's piece, reshaped to the shape ``root`` scattered to it.
    """
    if rank == root:
        dtype = sendbuf[0].dtype
        shape = [x.shape for x in sendbuf]
        matched = all(x[1:] == shape[0][1:] for x in shape[1:])
        shape = comm.scatter(shape)
        counts = numpy.asarray([x.size for x in sendbuf])
        comm.bcast((dtype, counts, matched))
        if data is None:
            sendbuf = [
                numpy.asarray(x, dtype, order='C').ravel() for x in sendbuf
            ]
            sendbuf = numpy.hstack(sendbuf)
        else:
            sendbuf = numpy.asarray(data, order='C')
    else:
        shape = comm.scatter(None)
        dtype, counts, matched = comm.bcast(None)

    # When matched, every rank has the same shape[1:], so the decisions below
    # are identical on all ranks.
    each_count = 1
    if matched:
        # int() matters: for 1-D arrays shape[1:] is () and an un-typed
        # numpy.prod(()) returns a float, which breaks slicing, the prange
        # bounds and Create_contiguous.
        each_count = int(numpy.prod(shape[1:], dtype=numpy.int64))
        if each_count == 0:
            # Zero-size rows: every array is empty; use the plain element path
            # instead of dividing the counts by zero.
            matched = False
            each_count = 1
    if matched:
        elem_dtype = MPI._typedict.get(dtype.char)
        counts = counts // each_count
        mpi_dtype = MPI.Datatype(elem_dtype).Create_contiguous(
            each_count).Commit()
    else:
        # Type character, as gather_new uses for its element-wise path.
        mpi_dtype = dtype.char
    displs = numpy.append(0, numpy.cumsum(counts[:-1]))
    # Integer size: numpy.prod(()) would be the float 1.0 for 0-d shapes.
    recvbuf = numpy.empty(int(numpy.prod(shape, dtype=numpy.int64)), dtype=dtype)

    for p0, p1 in prange(0, numpy.max(counts), BLKSIZE):
        # Per-rank counts for rows [p0, p1) of this segment.
        counts_seg = _segment_counts(counts, p0, p1)
        comm.Scatterv([sendbuf, counts_seg, displs + p0, mpi_dtype],
                      [recvbuf[p0 * each_count:p1 * each_count], mpi_dtype],
                      root)
    if matched:
        mpi_dtype.Free()
    return recvbuf.reshape(shape)
def testBoolEqNe(self):
    """Check truthiness and ``==`` / ``!=`` for each entry of ``datatypes``.

    Every datatype must be truthy. It must compare equal to, and not unequal
    to, a new ``MPI.Datatype`` constructed from it.
    """
    for datatype in datatypes:
        self.assertTrue(bool(datatype))
        equal_copy = MPI.Datatype(datatype)
        self.assertTrue(datatype == equal_copy)
        unequal_copy = MPI.Datatype(datatype)
        self.assertFalse(datatype != unequal_copy)