def test_gather_scalar_jit():
    """Gather a Python-int scalar from every rank inside ``jax.jit``.

    The root rank (0) is expected to receive ``[0, 1, ..., size - 1]``; every
    other rank is expected to get its own input value back.
    """
    from mpi4jax import gather

    def gather_to_root(x):
        # gather returns a pair; only the first element (the gathered data)
        # is returned from the jitted function.
        return gather(x, root=0)[0]

    local_value = rank
    gathered = jax.jit(gather_to_root)(local_value)

    if rank != 0:
        # Non-root ranks: the test expects the input to come back unchanged.
        assert jnp.array_equal(gathered, local_value)
        return

    # Root rank: one entry per process, and each rank contributed its own id.
    assert jnp.array_equal(gathered, jnp.arange(size))
def test_gather_scalar():
    """Gather a Python-int scalar from every rank eagerly (no jit).

    Rank 0 should hold ``[0, 1, ..., size - 1]``; other ranks should hold
    their own input value.
    """
    from mpi4jax import gather

    local_value = rank
    gathered, _ = gather(local_value, root=0)

    # Pick the expected result for this rank, then check it with one assert.
    expected = jnp.arange(size) if rank == 0 else local_value
    assert jnp.array_equal(gathered, expected)
def test_gather():
    """Gather a (3, 2) array filled with the rank id from every rank, eagerly.

    On rank 0, slot ``p`` of the result must equal the block sent by rank
    ``p``. On any other rank, the result must equal the local input.
    """
    from mpi4jax import gather

    block_shape = (3, 2)
    local_block = jnp.ones(block_shape) * rank
    gathered, _ = gather(local_block, root=0)

    if rank != 0:
        assert jnp.array_equal(gathered, local_block)
        return

    # Walk every source rank and compare its slot with the block it sent.
    for source_rank in range(size):
        sent_block = jnp.ones(block_shape) * source_rank
        assert jnp.array_equal(gathered[source_rank], sent_block)
def test_gather_jit():
    """Gather a (3, 2) array filled with the rank id from every rank under jit.

    On rank 0, slot ``p`` of the result must equal the block sent by rank
    ``p``. On any other rank, the result must equal the local input.
    """
    from mpi4jax import gather

    def gather_to_root(x):
        # Drop the second element of gather's returned pair so the jitted
        # function yields just the gathered array.
        return gather(x, root=0)[0]

    block_shape = (3, 2)
    local_block = jnp.ones(block_shape) * rank
    gathered = jax.jit(gather_to_root)(local_block)

    if rank != 0:
        assert jnp.array_equal(gathered, local_block)
        return

    # Walk every source rank and compare its slot with the block it sent.
    for source_rank in range(size):
        sent_block = jnp.ones(block_shape) * source_rank
        assert jnp.array_equal(gathered[source_rank], sent_block)
anim = animation.FuncAnimation( fig, animate, frames=len(sol), interval=50, blit=True, repeat_delay=3_000 ) return anim if __name__ == "__main__": benchmark_mode = "--benchmark" in sys.argv sol = solve_shallow_water(t1=10 * DAY_IN_SECONDS, num_multisteps=PLOT_EVERY) if benchmark_mode: sys.exit(0) # copy solution to mpi_rank 0 full_sol_arr, _ = mpi4jax.gather(jnp.asarray(sol), root=0, comm=mpi_comm) if mpi_rank == 0: # full_sol_arr has shape (nproc, time, nvars, ny, nx) full_sol_arr = jnp.moveaxis(full_sol_arr, 0, 2) full_sol = [ModelState(*reassemble_array(x)) for x in full_sol_arr] anim = animate_shallow_water(full_sol) if "--save-animation" in sys.argv: # save animation as MP4 video (requires ffmpeg) anim.save("shallow-water.mp4", writer="ffmpeg", dpi=100) else: import matplotlib.pyplot as plt plt.show()