def send_jit(x):
    """Hand ``x`` to ``send`` and give the same object back.

    ``send`` is invoked with ``0`` as its second positional argument
    (presumably the destination rank -- confirm against the ``send`` in
    scope) and ``tag=rank``. Whatever ``send`` returns (possibly an
    mpi4jax ordering token) is discarded.

    The name suggests this is meant to be traced under ``jax.jit``; verify
    at the call site.
    """
    payload = x
    send(payload, 0, tag=rank)
    return payload
def enforce_boundaries(arr, grid, token=None):
    """Fill the halo cells of ``arr`` from neighboring MPI ranks via mpi4jax.

    The local field carries one halo row/column on each side. Along an
    axis, index 0 / -1 is the halo that gets overwritten, and index 1 / -2
    is the outermost interior row/column that is shipped to the neighbor.
    Axis 0 runs south -> north and axis 1 runs west -> east.

    The exchange runs in four rounds:

    * send west / receive from east
    * send north / receive from south
    * send east / receive from west
    * send south / receive from north

    A rank on the edge of the processor grid that lacks a neighbor in one
    direction only sends, or only receives, in that round. x-neighbors wrap
    around when ``PERIODIC_BOUNDARY_X`` is set.

    Args:
        arr: This rank's local field, halo included (JAX array).
        grid: Staggered grid the field lives on: ``"h"``, ``"u"`` or ``"v"``.
        token: Optional mpi4jax ordering token; a fresh one is created when
            ``None``.

    Returns:
        ``(arr, token)``: the updated field and the token after the last
        communication.
    """
    assert grid in ("h", "u", "v")

    # (direction we send to, direction we receive from), visited clockwise
    # starting with a westward send.
    exchange_plan = (
        ("west", "east"),
        ("north", "south"),
        ("east", "west"),
        ("south", "north"),
    )

    # Outermost interior row/column shipped towards each direction.
    send_slices = {
        "south": (1, slice(None), Ellipsis),
        "west": (slice(None), 1, Ellipsis),
        "north": (-2, slice(None), Ellipsis),
        "east": (slice(None), -2, Ellipsis),
    }
    # Halo row/column overwritten with data arriving from each direction.
    recv_slices = {
        "south": (0, slice(None), Ellipsis),
        "west": (slice(None), 0, Ellipsis),
        "north": (-1, slice(None), Ellipsis),
        "east": (slice(None), -1, Ellipsis),
    }

    row, col = proc_idx[0], proc_idx[1]

    # (y, x) processor coordinates of each neighbor, or None at a grid edge.
    neighbors = {
        "south": (row - 1, col) if row > 0 else None,
        "west": (row, col - 1) if col > 0 else None,
        "north": (row + 1, col) if row < nproc_y - 1 else None,
        "east": (row, col + 1) if col < nproc_x - 1 else None,
    }

    if PERIODIC_BOUNDARY_X:
        # Wrap in x: the first and last processor columns become neighbors.
        if col == 0:
            neighbors["west"] = (row, nproc_x - 1)
        if col == nproc_x - 1:
            neighbors["east"] = (row, 0)

    if token is None:
        token = jax.lax.create_token()

    proc_shape = (nproc_y, nproc_x)

    def to_rank(coords):
        # Flatten (y, x) processor coordinates into a linear rank; keep None.
        if coords is None:
            return None
        return np.ravel_multi_index(coords, proc_shape)

    for send_dir, recv_dir in exchange_plan:
        dest = to_rank(neighbors[send_dir])
        source = to_rank(neighbors[recv_dir])

        if dest is None and source is None:
            # Isolated in this direction pair: nothing to exchange.
            continue

        recv_idx = recv_slices[recv_dir]
        halo = jnp.empty_like(arr[recv_idx])
        edge = arr[send_slices[send_dir]]

        if source is None:
            # Only a destination exists: push our edge, receive nothing.
            token = mpi4jax.send(edge, dest=dest, comm=mpi_comm, token=token)
            continue

        if dest is None:
            # Only a source exists: receive into the halo, send nothing.
            halo, token = mpi4jax.recv(
                halo, source=source, comm=mpi_comm, token=token
            )
        else:
            halo, token = mpi4jax.sendrecv(
                edge,
                halo,
                source=source,
                dest=dest,
                comm=mpi_comm,
                token=token,
            )
        arr = arr.at[recv_idx].set(halo)

    # After the exchange, zero the last interior column of ``u`` on the
    # eastern-most ranks (only without x-periodicity), and the last interior
    # row of ``v`` on the northern-most ranks. These are presumably
    # closed-wall velocity conditions -- confirm against the model setup.
    if not PERIODIC_BOUNDARY_X and grid == "u" and col == nproc_x - 1:
        arr = arr.at[:, -2].set(0.0)
    if grid == "v" and row == nproc_y - 1:
        arr = arr.at[-2, :].set(0.0)

    return arr, token