示例#1
0
 def deadlock(arr):
     """Exchange ``arr`` between rank 0 and its peer using opposite orderings.

     Rank 0 sends ``arr`` to rank 1 and then receives from rank 1. Every other
     rank first receives from rank 0 and then sends ``arr`` back to rank 0. On
     both sides, the token returned by the first operation is passed to the
     second one (presumably to pin the send/recv order -- confirm against the
     mpi4jax token docs).

     Returns the array received from the peer.

     NOTE(review): rank 0 only ever sends to rank 1. With more than two
     processes, ranks >= 2 would block in ``Recv`` waiting for a message from
     rank 0 that is never sent, so this appears to assume exactly two ranks.
     """
     if rank != 0:
         # Non-root side: wait for rank 0 first, then reply with our own array.
         received, tok = Recv(arr, 0)
         Send(arr, 0, token=tok)
         return received
     # Root side: send to rank 1 first, then wait for its reply.
     tok = Send(arr, 1)
     received, _ = Recv(arr, 1, token=tok)
     return received
示例#2
0
def test_send_recv():
    """Point-to-point Send/Recv of a (3, 2) array from every non-root rank to rank 0.

    Each rank ``r`` builds ``ones((3, 2)) * r``. Every non-root rank sends its
    array to rank 0, tagged with its own rank. Rank 0 receives from ranks
    1..size-1 in order, using the sender's rank as both source and tag, and
    checks each payload.

    After each communication step, the local input is compared against a
    copy taken beforehand, to check that it was not modified.
    """
    from mpi4jax import Send, Recv

    local = jnp.ones((3, 2)) * rank
    snapshot = local.copy()

    if rank != 0:
        # Non-root ranks tag the message with their own rank.
        Send(local, 0, tag=rank)
        assert jnp.array_equal(snapshot, local)
        return

    for src in range(1, size):
        received, _ = Recv(local, source=src, tag=src)
        assert jnp.array_equal(received, jnp.ones_like(local) * src)
        assert jnp.array_equal(snapshot, local)
示例#3
0
def test_send_recv_scalar_jit():
    """Jitted Send/Recv of a scalar payload from every non-root rank to rank 0.

    The payload is the Python int ``1 * rank``. Each Send/Recv call is wrapped
    in a freshly created function passed through ``jax.jit`` and called
    immediately. Rank 0 receives from ranks 1..size-1, using the sender's rank
    as both source and tag, and checks the value. Every rank also checks that
    its input still equals an independently computed reference.
    """
    from mpi4jax import Send, Recv

    value = 1 * rank
    reference = 1 * rank

    if rank != 0:
        jax.jit(lambda x: Send(x, 0, tag=rank))(value)
        assert jnp.array_equal(reference, value)
        return

    for src in range(1, size):

        def recv_from_src(x):
            # Jitted and invoked within this same iteration, so it always
            # sees the current ``src``.
            return Recv(x, source=src, tag=src)

        received, _ = jax.jit(recv_from_src)(value)
        assert jnp.array_equal(received, jnp.ones_like(value) * src)
        assert jnp.array_equal(reference, value)
示例#4
0
def test_send_recv_scalar():
    """Un-jitted Send/Recv of a scalar payload from every non-root rank to rank 0.

    Each rank's payload is the Python int ``1 * rank``. Rank 0 receives from
    ranks 1..size-1, using the sender's rank as both source and tag, and
    checks the value. Every rank also checks that its input still equals an
    independently computed reference.
    """
    from mpi4jax import Recv, Send

    value = 1 * rank
    reference = 1 * rank

    if rank != 0:
        # Non-root ranks tag the message with their own rank.
        Send(value, 0, tag=rank)
        assert jnp.array_equal(reference, value)
        return

    for src in range(1, size):
        received, _ = Recv(value, source=src, tag=src)
        assert jnp.array_equal(received, jnp.ones_like(value) * src)
        assert jnp.array_equal(reference, value)
示例#5
0
def test_send_recv_status():
    """Send/Recv to rank 0 with an ``MPI.Status`` passed to each receive.

    Each rank ``r`` builds ``ones((3, 2)) * r`` and every non-root rank sends
    it to rank 0, tagged with its own rank. For each sender, rank 0 receives
    with a fresh ``MPI.Status`` object. It checks the payload and that the
    status reports the sender as the message source. Every rank also checks
    that its input equals a copy taken before communicating.
    """
    from mpi4jax import Recv, Send

    local = jnp.ones((3, 2)) * rank
    snapshot = local.copy()

    if rank != 0:
        Send(local, 0, tag=rank)
        assert jnp.array_equal(snapshot, local)
        return

    for src in range(1, size):
        stat = MPI.Status()
        received, _ = Recv(local, source=src, tag=src, status=stat)
        assert jnp.array_equal(received, jnp.ones_like(local) * src)
        assert jnp.array_equal(snapshot, local)
        # The status object passed to Recv must name the sending rank.
        assert stat.Get_source() == src
示例#6
0
 def send_jit(x):
     """Send ``x`` to rank 0, tagged with this process's rank, and return ``x``.

     Whatever ``Send`` returns is discarded, and ``x`` is passed through
     unchanged. Presumably this gives the function an array output when
     wrapped in ``jax.jit`` -- confirm with the caller.
     """
     _ = Send(x, 0, tag=rank)
     return x
 def send_jit_deprecated(x):
     """Send ``x`` to rank 0 with ``tag=rank``, then hand ``x`` back unchanged.

     The result of ``Send`` is not used by this function.
     """
     _unused_token = Send(x, 0, tag=rank)  # result intentionally ignored
     return x