def _rotate_tensor_block(buf):
    '''Generator that yields (task, buf) pairs while rotating the tensor
    block `buf` around the MPI ring.  The next block is prefetched in a
    background thread so communication overlaps with the caller's work.

    Note: `mpi`, `rank` and `lib` are assumed to be module-level names
    provided by the surrounding package.
    '''
    ntasks = mpi.pool.size
    # Start from the local block, then walk through the other ranks' blocks.
    tasks = list(range(ntasks))
    tasks = tasks[rank:] + tasks[:rank]

    buf_prefetch = [None]
    def rotate():
        buf_prefetch[0] = mpi.rotate(buf, blocking=False)

    # DO NOT use ThreadWithReturnValue: the return value of mpi.rotate is too
    # large for the Queue module.
    handler = lib.ThreadWithTraceBack(target=rotate, args=())
    handler.start()

    for k, task in enumerate(tasks):
        if task != rank:
            handler.join()
            buf = buf_prefetch[0]
            if k + 1 < ntasks:
                # Prefetch the next block while the caller consumes this one.
                handler = lib.ThreadWithTraceBack(target=rotate, args=())
                handler.start()
        yield task, buf
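# Illustrative sketch (not part of the library): the same prefetch pattern
# without any MPI dependency, to show how the generator above overlaps
# "communication" with the caller's work.  fetch_next() is a hypothetical
# stand-in for mpi.rotate, and threading.Thread stands in for
# lib.ThreadWithTraceBack.
def _demo_prefetch_pattern(nblocks=4):
    import threading
    import time

    def fetch_next(block):
        time.sleep(0.1)             # stands in for communication latency
        return block + 1

    def blocks(buf):
        prefetch = [None]
        def worker():
            prefetch[0] = fetch_next(buf)

        handler = threading.Thread(target=worker)
        handler.start()
        for k in range(nblocks):
            if k > 0:
                handler.join()      # wait for the prefetched block
                buf = prefetch[0]
                if k + 1 < nblocks:
                    handler = threading.Thread(target=worker)
                    handler.start()
            yield k, buf            # caller works on buf while the next block arrives

    for task, block in blocks(0):
        print('task', task, 'holds block', block)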
def rotate(sendbuf, blocking=True, tag=0):
    '''On every process, pass the sendbuf to the next process.

    Node-ID   Before-rotate   After-rotate
    node-0    buf-0           buf-1
    node-1    buf-1           buf-2
    node-2    buf-2           buf-3
    node-3    buf-3           buf-0
    '''
    if pool.size <= 1:
        return sendbuf

    # Neighbours on the ring: receive from next_node, send to prev_node.
    if rank == 0:
        prev_node = pool.size - 1
        next_node = 1
    elif rank == pool.size - 1:
        prev_node = rank - 1
        next_node = 0
    else:
        prev_node = rank - 1
        next_node = rank + 1

    if isinstance(sendbuf, numpy.ndarray):
        if blocking:
            # Alternate the send/recv order on even/odd ranks to avoid deadlock.
            if rank % 2 == 0:
                send(sendbuf, prev_node, tag)
                recvbuf = recv(next_node, tag)
            else:
                recvbuf = recv(next_node, tag)
                send(sendbuf, prev_node, tag)
        else:
            # Non-blocking: send in a background thread while receiving.
            handler = lib.ThreadWithTraceBack(target=send,
                                              args=(sendbuf, prev_node, tag))
            handler.start()
            recvbuf = recv(next_node, tag)
            handler.join()
    else:
        # Generic (picklable) objects: rotate in the same direction as the
        # ndarray branch, i.e. send to prev_node and receive from next_node.
        if rank % 2 == 0:
            comm.send(sendbuf, dest=prev_node, tag=tag)
            recvbuf = comm.recv(source=next_node, tag=tag)
        else:
            recvbuf = comm.recv(source=next_node, tag=tag)
            comm.send(sendbuf, dest=prev_node, tag=tag)
    return recvbuf
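# Illustrative sketch (not part of the library): the same ring rotation
# written with plain mpi4py.  comm.sendrecv pairs the send and the receive,
# so no even/odd ordering is needed to avoid deadlock.  Assumes mpi4py is
# installed; call _demo_ring_rotate() on every rank, e.g. under `mpirun -n 4`.
def _demo_ring_rotate():
    from mpi4py import MPI
    import numpy

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # Each rank starts with its own block; after one rotation rank i holds
    # the block that lived on rank (i + 1) % size, as in the table above.
    buf = numpy.full(4, rank)
    prev_node = (rank - 1) % size
    next_node = (rank + 1) % size

    recvbuf = comm.sendrecv(buf, dest=prev_node, source=next_node)
    print('rank %d now holds the block of rank %d' % (rank, int(recvbuf[0])))
    return recvbuf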