Пример #1
0
def test_sendrecv_scalar_jit():
    """Swap a scalar with the partner rank inside ``jax.jit``.

    The partner is ``1 - rank``, so this pairing presumes a 2-process run.
    """
    from mpi4jax import sendrecv

    partner = 1 - rank
    payload = 1 * rank
    snapshot = payload

    @jax.jit
    def exchange(send_buf, recv_buf):
        # keep only the received data, dropping the returned token
        return sendrecv(send_buf, recv_buf, source=partner, dest=partner)[0]

    received = exchange(payload, payload)

    expected = jnp.ones_like(payload) * partner
    assert jnp.array_equal(received, expected)
    # the input value must be unchanged after the call
    assert jnp.array_equal(snapshot, payload)
Пример #2
0
def test_sendrecv():
    """Swap a (3, 2) array with the partner rank and check what arrives.

    The partner is ``1 - rank``, so this pairing presumes a 2-process run.
    """
    from mpi4jax import sendrecv

    partner = 1 - rank
    local = jnp.ones((3, 2)) * rank
    original = local.copy()

    received, _token = sendrecv(local, local, source=partner, dest=partner)

    assert jnp.array_equal(received, jnp.ones_like(local) * partner)
    # the input array must be unchanged after the call
    assert jnp.array_equal(original, local)
Пример #3
0
def test_sendrecv_scalar():
    """Swap a scalar (``1 * rank``) with the partner rank, without ``jit``.

    The partner is ``1 - rank``, so this pairing presumes a 2-process run.
    """
    from mpi4jax import sendrecv

    partner = 1 - rank
    value = 1 * rank
    before = value

    received, _token = sendrecv(value, value, source=partner, dest=partner)

    expected = jnp.ones_like(value) * partner
    assert jnp.array_equal(received, expected)
    # the input value must be unchanged after the call
    assert jnp.array_equal(before, value)
Пример #4
0
def test_sendrecv_status():
    """Swap arrays with the partner rank and check the reported MPI source.

    The partner is ``1 - rank``, so this pairing presumes a 2-process run.
    """
    from mpi4jax import sendrecv

    partner = 1 - rank
    local = jnp.ones((3, 2)) * rank
    original = local.copy()

    status = MPI.Status()
    received, _token = sendrecv(
        local, local, source=partner, dest=partner, status=status
    )

    assert jnp.array_equal(received, jnp.ones_like(local) * partner)
    assert jnp.array_equal(original, local)
    # the Status object handed to sendrecv is expected to record the sender
    assert status.Get_source() == partner
Пример #5
0
def test_sendrecv_vmap():
    """Check that ``sendrecv`` under ``jax.vmap`` matches the unbatched call.

    The partner is ``1 - rank``, so this pairing presumes a 2-process run.
    """
    from mpi4jax import sendrecv

    arr = jnp.ones((3, 2)) * rank
    _arr = arr.copy()

    other = 1 - rank

    # Unbatched reference exchange. Every rank performs it, so the MPI
    # call sequence stays matched; its result is now actually checked.
    expected = sendrecv(arr, arr, source=other, dest=other)[0]

    def fun(x, y):
        return sendrecv(x, y, source=other, dest=other)[0]

    # map over axis 0 of both the send and the receive argument
    vfun = jax.vmap(fun, in_axes=(0, 0))
    res = vfun(_arr, arr)

    assert jnp.array_equal(res, jnp.ones_like(arr) * other)
    assert jnp.array_equal(res, expected)
    assert jnp.array_equal(_arr, arr)
Пример #6
0
 def f(x):
     """Exchange ``x`` with rank ``other``, scale by ``rank + 1`` and sum."""
     received, _token = sendrecv(x, x, source=other, dest=other)
     scaled = received * (rank + 1)
     return scaled.sum()
Пример #7
0
 def fun(x, y):
     """Send ``x`` to rank ``other`` (``y`` is the receive buffer); return the received data."""
     result = sendrecv(x, y, source=other, dest=other)
     return result[0]
Пример #8
0
def enforce_boundaries(arr, grid, token=None):
    """Handle boundary exchange between processors.

    This is where mpi4jax comes in! Every process sends the second row/column
    from each edge of ``arr`` to the neighboring process in that direction and
    writes the data it receives into its own outermost row/column. Axis 0 of
    ``arr`` runs south -> north and axis 1 runs west -> east.

    Args:
        arr: Local field whose outermost rows/columns (the halo) are refreshed.
        grid: Grid identifier, one of ``"h"``, ``"u"`` or ``"v"``. For ``"u"``
            the second-to-last column is zeroed on the easternmost process
            when the x direction is not periodic; for ``"v"`` the
            second-to-last row is zeroed on the northernmost process.
        token: Optional JAX token to chain the MPI calls onto; a new one is
            created with ``jax.lax.create_token()`` when ``None``.

    Returns:
        ``(arr, token)``: the updated array and the token returned by the last
        MPI call (or the incoming/new token if no call was made).

    Raises:
        ValueError: If ``grid`` is not ``"h"``, ``"u"`` or ``"v"``.
    """
    if grid not in ("h", "u", "v"):
        # explicit raise instead of ``assert`` so the check survives ``python -O``
        raise ValueError(f"grid must be one of 'h', 'u', 'v'; got {grid!r}")

    # start sending west, go clockwise
    send_order = (
        "west",
        "north",
        "east",
        "south",
    )

    # start receiving east, go clockwise. Entry i is the opposite side of
    # send_order[i], so at every step each process sends one way and receives
    # from the other, matching what its neighbors do in the same step.
    recv_order = (
        "east",
        "south",
        "west",
        "north",
    )

    # second row/column from each edge: the data a neighbor needs
    overlap_slices_send = dict(
        south=(1, slice(None), Ellipsis),
        west=(slice(None), 1, Ellipsis),
        north=(-2, slice(None), Ellipsis),
        east=(slice(None), -2, Ellipsis),
    )

    # outermost row/column on each edge: overwritten with neighbor data
    overlap_slices_recv = dict(
        south=(0, slice(None), Ellipsis),
        west=(slice(None), 0, Ellipsis),
        north=(-1, slice(None), Ellipsis),
        east=(slice(None), -1, Ellipsis),
    )

    # (y, x) coordinates of the neighboring process, or None at a domain edge
    proc_neighbors = {
        "south": (proc_idx[0] - 1, proc_idx[1]) if proc_idx[0] > 0 else None,
        "west": (proc_idx[0], proc_idx[1] - 1) if proc_idx[1] > 0 else None,
        "north": (proc_idx[0] + 1, proc_idx[1]) if proc_idx[0] < nproc_y - 1 else None,
        "east": (proc_idx[0], proc_idx[1] + 1) if proc_idx[1] < nproc_x - 1 else None,
    }

    if PERIODIC_BOUNDARY_X:
        # wrap around in x: the west/east edges become neighbors of each other
        if proc_idx[1] == 0:
            proc_neighbors["west"] = (proc_idx[0], nproc_x - 1)

        if proc_idx[1] == nproc_x - 1:
            proc_neighbors["east"] = (proc_idx[0], 0)

    if token is None:
        token = jax.lax.create_token()

    def to_rank(coords):
        # row-major flattening of (y, x) process coordinates into an MPI rank
        return np.ravel_multi_index(coords, (nproc_y, nproc_x))

    for send_dir, recv_dir in zip(send_order, recv_order):
        send_proc = proc_neighbors[send_dir]
        recv_proc = proc_neighbors[recv_dir]

        if send_proc is None and recv_proc is None:
            continue

        if recv_proc is None:
            # no neighbor on the receiving side: send only
            token = mpi4jax.send(
                arr[overlap_slices_send[send_dir]],
                dest=to_rank(send_proc),
                comm=mpi_comm,
                token=token,
            )
            continue

        recv_idx = overlap_slices_recv[recv_dir]
        # placeholder giving the incoming halo its shape and dtype
        recv_template = jnp.empty_like(arr[recv_idx])

        if send_proc is None:
            # no neighbor on the sending side: receive only
            recv_arr, token = mpi4jax.recv(
                recv_template,
                source=to_rank(recv_proc),
                comm=mpi_comm,
                token=token,
            )
        else:
            recv_arr, token = mpi4jax.sendrecv(
                arr[overlap_slices_send[send_dir]],
                recv_template,
                source=to_rank(recv_proc),
                dest=to_rank(send_proc),
                comm=mpi_comm,
                token=token,
            )

        arr = arr.at[recv_idx].set(recv_arr)

    if not PERIODIC_BOUNDARY_X and grid == "u" and proc_idx[1] == nproc_x - 1:
        arr = arr.at[:, -2].set(0.0)

    if grid == "v" and proc_idx[0] == nproc_y - 1:
        arr = arr.at[-2, :].set(0.0)

    return arr, token