def _single_worker_call_function(payload, worker):
    """Execute one queued function call on the MPI rank ``worker``.

    On the root rank (``mpi.rank0``), ``payload[0]`` must unpack into a
    ``(function, args, kwargs)`` triple. If ``worker`` is 0, the root runs the
    call itself. Otherwise it sends the triple to ``worker`` and blocks until
    that rank sends back the return value.

    On non-root ranks, ``payload`` is not read. Only the rank equal to
    ``worker`` participates: it receives the triple from rank 0, runs it
    through ``mpi.function_call`` and sends the return value back to rank 0.

    Returns:
        On the root, the call's return value. On every other rank, None.
    """
    if not mpi.rank0:
        # Non-root ranks that were not selected take no part in the exchange.
        if mpi.rank != worker:
            return None
        func, pos_args, kw_args = mpi.comm.recv(source=0)
        mpi.comm.send(mpi.function_call(func, *pos_args, **kw_args), dest=0)
        return None

    if worker == 0:
        # The root is itself the target, so no messaging is needed.
        func, pos_args, kw_args = payload[0]
        return mpi.function_call(func, *pos_args, **kw_args)

    # Delegate: ship the whole triple to the worker and wait for its reply.
    mpi.comm.send(payload[0], dest=worker)
    return mpi.comm.recv(source=worker)
def _worker_map_function(payload, function, **kwargs):
    """Map ``function`` over argument chunks spread across all MPI ranks.

    On the root rank (``mpi.rank0``), ``payload[0]`` is transposed with
    ``zip(*...)`` and the resulting list is scattered from rank 0. Assuming
    mpi4py-style scatter semantics, item ``r`` of that list goes to rank
    ``r``. Presumably ``payload[0]`` holds one sequence per positional
    argument, each containing one chunk per rank; TODO confirm against the
    callers.

    Each rank transposes its share back into per-call argument tuples and
    evaluates ``mpi.function_call(function, *call_args, **kwargs)`` for each
    tuple. The per-rank result lists are then gathered on rank 0.

    Returns:
        On the root, a single flat list of all results, ordered by rank and
        then by position within each rank's share. On other ranks, None.
    """
    to_scatter = list(zip(*payload[0])) if mpi.rank0 else None
    local_share = mpi.comm.scatter(to_scatter, root=0)

    # The outermost iterable of a comprehension is evaluated immediately,
    # just as when the zip object is bound to a name first.
    local_results = [
        mpi.function_call(function, *call_args, **kwargs)
        for call_args in zip(*local_share)
    ]

    gathered = mpi.comm.gather(local_results, root=0)
    if not mpi.rank0:
        return None
    # Flatten the list of per-rank result lists into one list.
    return list(chain.from_iterable(gathered))