def test_gather_scalar_jit():
    """Gather each rank's scalar rank id to root 0 from inside ``jax.jit``.

    Root must receive ``[0, 1, ..., size - 1]``; every other rank must get
    its own input value back.
    """
    from mpi4jax import gather

    def gather_to_root(value):
        # Keep only the gathered data; the second element of the returned
        # pair (presumably the mpi4jax token) is not needed here.
        gathered, _token = gather(value, root=0)
        return gathered

    local_value = rank
    result = jax.jit(gather_to_root)(local_value)

    expected = jnp.arange(size) if rank == 0 else local_value
    assert jnp.array_equal(result, expected)
def test_gather_scalar():
    """Gather each rank's scalar rank id to root 0 (eager, no jit).

    Root must receive ``[0, 1, ..., size - 1]``; every other rank must get
    its own input value back.
    """
    from mpi4jax import gather

    local_value = rank
    gathered, _token = gather(local_value, root=0)

    if rank != 0:
        # Non-root ranks receive their input back unchanged.
        assert jnp.array_equal(gathered, local_value)
        return

    assert jnp.array_equal(gathered, jnp.arange(size))
def test_gather():
    """Gather a (3, 2) block filled with each rank's id to root 0 (eager).

    On root, the result must be exactly the stack of every rank's block:
    shape ``(size, 3, 2)`` with slice ``p`` filled with ``p``. Every other
    rank must get its input back unchanged.
    """
    from mpi4jax import gather

    arr = jnp.ones((3, 2)) * rank

    res, _ = gather(arr, root=0)
    if rank == 0:
        # Compare against the full stacked array instead of row by row, so a
        # wrong leading dimension (missing or extra rows) also fails.
        expected = jnp.stack([jnp.ones((3, 2)) * p for p in range(size)])
        assert res.shape == expected.shape
        assert jnp.array_equal(res, expected)
    else:
        assert jnp.array_equal(res, arr)
def test_gather_jit():
    """Gather a (3, 2) block filled with each rank's id to root 0 under jit.

    On root, the result must be exactly the stack of every rank's block:
    shape ``(size, 3, 2)`` with slice ``p`` filled with ``p``. Every other
    rank must get its input back unchanged.
    """
    from mpi4jax import gather

    arr = jnp.ones((3, 2)) * rank

    # [0] keeps the gathered data and drops the second return value.
    res = jax.jit(lambda x: gather(x, root=0)[0])(arr)
    if rank == 0:
        # Compare against the full stacked array instead of row by row, so a
        # wrong leading dimension (missing or extra rows) also fails.
        expected = jnp.stack([jnp.ones((3, 2)) * p for p in range(size)])
        assert res.shape == expected.shape
        assert jnp.array_equal(res, expected)
    else:
        assert jnp.array_equal(res, arr)
Example #5
0
    anim = animation.FuncAnimation(
        fig, animate, frames=len(sol), interval=50, blit=True, repeat_delay=3_000
    )
    return anim


if __name__ == "__main__":
    # Script entry point: run the shallow-water solver on every rank, gather
    # all ranks' solutions onto rank 0, and animate the result there.
    #
    # Command-line flags (plain membership tests on sys.argv):
    #   --benchmark       run the solver only, then exit before the gather
    #                     and before any plotting.
    #   --save-animation  on rank 0, write shallow-water.mp4 instead of
    #                     opening an interactive matplotlib window.
    benchmark_mode = "--benchmark" in sys.argv

    # End time t1 is 10 * DAY_IN_SECONDS. num_multisteps=PLOT_EVERY presumably
    # sets how many solver steps run between stored snapshots -- assumption,
    # confirm against solve_shallow_water.
    sol = solve_shallow_water(t1=10 * DAY_IN_SECONDS, num_multisteps=PLOT_EVERY)

    if benchmark_mode:
        # Exit with status 0 right after solving: no gather, no animation.
        # All ranks take this branch together, so none is left waiting in
        # the gather below.
        sys.exit(0)

    # copy solution to mpi_rank 0
    # Called on every rank (outside the mpi_rank check) because gather is an
    # MPI collective over mpi_comm. The second return value is unused.
    full_sol_arr, _ = mpi4jax.gather(jnp.asarray(sol), root=0, comm=mpi_comm)

    # Only rank 0 reassembles and animates; the other ranks fall through and
    # reach the end of the script.
    if mpi_rank == 0:
        # full_sol_arr has shape (nproc, time, nvars, ny, nx)
        # Move the process axis to position 2 -> (time, nvars, nproc, ny, nx),
        # so each iteration below sees one (nvars, nproc, ny, nx) slab per
        # stored time step.
        full_sol_arr = jnp.moveaxis(full_sol_arr, 0, 2)
        # reassemble_array (defined elsewhere) presumably stitches the
        # per-process subdomains into the global grid -- confirm. Its result
        # is unpacked positionally into ModelState's fields.
        full_sol = [ModelState(*reassemble_array(x)) for x in full_sol_arr]

        anim = animate_shallow_water(full_sol)

        if "--save-animation" in sys.argv:
            # save animation as MP4 video (requires ffmpeg)
            anim.save("shallow-water.mp4", writer="ffmpeg", dpi=100)
        else:
            # Deferred import: pyplot is only needed for the interactive
            # window shown on rank 0.
            import matplotlib.pyplot as plt

            plt.show()