def deadlock(arr):
    """Exchange ``arr`` with a peer using opposite send/recv orderings.

    Rank 0 sends ``arr`` to rank 1 and then receives from rank 1. Every other
    rank receives from rank 0 and then sends ``arr`` back to rank 0. In both
    branches, the second operation is passed the token returned by the first
    one, so the two communications are chained.

    NOTE(review): the name suggests this pattern would deadlock without the
    token chaining (presumably when traced under ``jax.jit`` with two ranks) —
    confirm against the calling test.

    Returns the array received from the peer.
    """
    if rank != 0:
        # Non-root ranks: receive from rank 0 first, then reply using that token.
        received, recv_token = Recv(arr, 0)
        Send(arr, 0, token=recv_token)
        return received

    # Rank 0: send to rank 1 first, then receive ordered after the send.
    send_token = Send(arr, 1)
    received, _ = Recv(arr, 1, token=send_token)
    return received
def test_send_recv():
    """Every non-zero rank sends a filled array to rank 0, which verifies each one.

    Each non-root rank sends ``ones((3, 2)) * rank`` with ``tag=rank``. Rank 0
    receives from every other rank in order and checks the payload. All ranks
    also check that the local input array was left unchanged.
    """
    from mpi4jax import Send, Recv

    arr = jnp.ones((3, 2)) * rank
    original = arr.copy()

    if rank != 0:
        Send(arr, 0, tag=rank)
        assert jnp.array_equal(original, arr)
        return

    for source in range(1, size):
        received, _token = Recv(arr, source=source, tag=source)
        expected = jnp.ones_like(arr) * source
        assert jnp.array_equal(received, expected)
        # Receiving into a buffer shaped like ``arr`` must not mutate ``arr``.
        assert jnp.array_equal(original, arr)
def test_send_recv_scalar_jit():
    """Scalar point-to-point test with both Send and Recv wrapped in ``jax.jit``.

    Each non-root rank sends the Python int ``rank`` to rank 0 with
    ``tag=rank``. Rank 0 receives from every other rank and checks the value.
    """
    from mpi4jax import Send, Recv

    arr = 1 * rank
    reference = 1 * rank

    if rank != 0:
        jax.jit(lambda x: Send(x, 0, tag=rank))(arr)
        assert jnp.array_equal(reference, arr)
        return

    for source in range(1, size):
        # The jitted lambda is invoked within this same iteration, so it sees
        # the current value of ``source`` despite closing over the loop variable.
        recv_jit = jax.jit(lambda x: Recv(x, source=source, tag=source))
        received, _token = recv_jit(arr)
        assert jnp.array_equal(received, jnp.ones_like(arr) * source)
        assert jnp.array_equal(reference, arr)
def test_send_recv_scalar():
    """Scalar point-to-point test without ``jax.jit``.

    Each non-root rank sends the Python int ``rank`` to rank 0 with
    ``tag=rank``. Rank 0 receives from every other rank and checks the value.
    """
    from mpi4jax import Recv, Send

    arr = 1 * rank
    reference = 1 * rank

    if rank != 0:
        Send(arr, 0, tag=rank)
        assert jnp.array_equal(reference, arr)
        return

    for source in range(1, size):
        received, _token = Recv(arr, source=source, tag=source)
        expected = jnp.ones_like(arr) * source
        assert jnp.array_equal(received, expected)
        assert jnp.array_equal(reference, arr)
def test_send_recv_status():
    """Like ``test_send_recv``, but also checks the ``MPI.Status`` filled by Recv.

    Rank 0 passes a fresh ``MPI.Status`` to each receive. It then asserts that
    the status reports the expected source rank.
    """
    from mpi4jax import Recv, Send

    arr = jnp.ones((3, 2)) * rank
    original = arr.copy()

    if rank != 0:
        Send(arr, 0, tag=rank)
        assert jnp.array_equal(original, arr)
        return

    for source in range(1, size):
        recv_status = MPI.Status()
        received, _token = Recv(arr, source=source, tag=source, status=recv_status)
        assert jnp.array_equal(received, jnp.ones_like(arr) * source)
        assert jnp.array_equal(original, arr)
        assert recv_status.Get_source() == source
def send_jit(x):
    """Send ``x`` to rank 0 tagged with this process's rank; return ``x`` as-is."""
    root = 0
    Send(x, root, tag=rank)
    return x
def send_jit_deprecated(x):
    """Send ``x`` to rank 0 with ``tag=rank`` and hand ``x`` back unchanged.

    NOTE(review): the visible body is the same as ``send_jit``. The name
    suggests it exercises a deprecated calling convention somewhere in the
    caller — confirm.
    """
    destination = 0
    Send(x, destination, tag=rank)
    return x