Beispiel #1
0
 def deadlock(arr):
     """Exchange ``arr`` between rank 0 and a peer using token-chained Send/Recv.

     Rank 0 sends ``arr`` to rank 1 and then receives from rank 1. Every
     other rank receives from rank 0 and then sends ``arr`` back to rank 0.
     The token returned by the first call is passed as ``token=`` to the
     second, presumably so the two communication calls are kept in program
     order once traced.

     NOTE(review): despite the name, the token chaining here looks intended to
     *avoid* a deadlock; confirm against the surrounding documentation.
     NOTE(review): rank 0 only talks to rank 1, so any rank >= 2 would block in
     ``Recv`` waiting for a message from rank 0 that is never sent. This
     presumably assumes exactly two processes; confirm.

     Relies on module-level ``rank``, ``Send`` and ``Recv`` defined outside this
     snippet. ``arr`` presumably also serves as the shape/dtype template for
     the receive buffer; verify against the mpi4jax API.

     Returns:
         The array obtained from ``Recv`` on this rank.
     """
     if rank == 0:
         # send, then receive
         # Feed Send's token into Recv so the receive is chained after the send.
         token = Send(arr, 1)
         newarr, _ = Recv(arr, 1, token=token)
     else:
         # receive, then send
         newarr, token = Recv(arr, 0)
         # Chain the send on Recv's token so it is ordered after the receive.
         Send(arr, 0, token=token)
     return newarr
Beispiel #2
0
def test_send_recv():
    """Every non-root rank sends its array to rank 0, which checks each payload.

    Rank 0 issues one ``Recv`` per peer (tag equal to the sender's rank) and
    verifies the received contents equal ``ones * sender``. Every rank also
    verifies that its local input array was not modified by the call.
    """
    from mpi4jax import Send, Recv

    template = jnp.ones((3, 2)) * rank
    snapshot = template.copy()

    if rank != 0:
        Send(template, 0, tag=rank)
        assert jnp.array_equal(snapshot, template)
        return

    for sender in range(1, size):
        received, _token = Recv(template, source=sender, tag=sender)
        expected = jnp.ones_like(template) * sender
        assert jnp.array_equal(received, expected)
        assert jnp.array_equal(snapshot, template)
Beispiel #3
0
def test_send_recv_scalar_jit():
    """Jitted scalar point-to-point exchange from every non-root rank to rank 0.

    The payload is the Python int ``rank``. Rank 0 runs one jitted ``Recv`` per
    peer (tag equal to the sender's rank) and checks that the received value
    equals the sender's rank. All ranks check their local payload is unchanged.
    """
    from mpi4jax import Send, Recv

    payload = 1 * rank
    reference = 1 * rank

    if rank != 0:
        jax.jit(lambda x: Send(x, 0, tag=rank))(payload)
        assert jnp.array_equal(reference, payload)
        return

    for sender in range(1, size):
        # The jitted lambda is invoked within this same iteration, so closing
        # over the loop variable is safe.
        recv_fn = jax.jit(lambda x: Recv(x, source=sender, tag=sender))
        received, _token = recv_fn(payload)
        assert jnp.array_equal(received, jnp.ones_like(payload) * sender)
        assert jnp.array_equal(reference, payload)
Beispiel #4
0
def test_send_recv_scalar():
    """Non-jitted scalar point-to-point exchange from every non-root rank to rank 0.

    Each rank's payload is the Python int ``rank``. Rank 0 receives one message
    per peer (tag equal to the sender's rank) and checks that it equals the
    sender's rank. All ranks check that their local payload is unchanged.
    """
    from mpi4jax import Recv, Send

    value = 1 * rank
    baseline = 1 * rank

    if rank != 0:
        Send(value, 0, tag=rank)
        assert jnp.array_equal(baseline, value)
        return

    for origin in range(1, size):
        got, _token = Recv(value, source=origin, tag=origin)
        want = jnp.ones_like(value) * origin
        assert jnp.array_equal(got, want)
        assert jnp.array_equal(baseline, value)
Beispiel #5
0
def test_send_recv_status():
    """Rank 0 receives from every peer and checks both the payload and the MPI status.

    A fresh ``MPI.Status`` object is passed to each ``Recv``. Afterwards its
    reported source must equal the rank the message was requested from.
    """
    from mpi4jax import Recv, Send

    local = jnp.ones((3, 2)) * rank
    untouched = local.copy()

    if rank != 0:
        Send(local, 0, tag=rank)
        assert jnp.array_equal(untouched, local)
        return

    for peer in range(1, size):
        info = MPI.Status()
        got, _token = Recv(local, source=peer, tag=peer, status=info)
        assert jnp.array_equal(got, jnp.ones_like(local) * peer)
        assert jnp.array_equal(untouched, local)
        assert info.Get_source() == peer
Beispiel #6
0
def test_send_recv_jit():
    """Point-to-point exchange where both the send and the receives are jitted.

    Non-root ranks call a jitted wrapper around ``Send``. Rank 0 jits one
    ``Recv`` per peer, keeps only element 0 of its result (the received array),
    and checks that it equals ``ones * sender``. All ranks check that their
    local input array is unchanged.
    """
    from mpi4jax import Recv, Send

    local = jnp.ones((3, 2)) * rank
    pristine = local.copy()

    @jax.jit
    def jitted_send(buf):
        Send(buf, 0, tag=rank)
        return buf

    if rank != 0:
        jitted_send(local)
        assert jnp.array_equal(pristine, local)
        return

    for peer in range(1, size):
        # Invoked within this same iteration, so capturing ``peer`` is safe.
        recv_payload = jax.jit(lambda buf: Recv(buf, source=peer, tag=peer)[0])
        got = recv_payload(local)
        assert jnp.array_equal(got, jnp.ones_like(local) * peer)
        assert jnp.array_equal(pristine, local)