def test_sendrecv():
    """Exchange a (3, 2) array with the partner rank via ``Sendrecv``.

    Every rank sends an array filled with its own rank id and expects to
    receive one filled with the partner's id. The partner is ``1 - rank``,
    so the pairing only makes sense for ranks 0 and 1.
    """
    from mpi4jax import Sendrecv

    payload = jnp.ones((3, 2)) * rank
    snapshot = payload.copy()
    partner = 1 - rank

    # The same array is passed as both the send buffer and the receive template.
    received, _token = Sendrecv(payload, payload, source=partner, dest=partner)

    expected = jnp.ones_like(payload) * partner
    assert jnp.array_equal(received, expected)
    # The exchange must not alter the array that was sent.
    assert jnp.array_equal(snapshot, payload)
def test_sendrecv_scalar_jit():
    """Exchange a scalar with the partner rank through a jitted ``Sendrecv``.

    The scalar is built from the module-level ``rank``. The partner is
    ``1 - rank``, so the pairing only makes sense for ranks 0 and 1.
    """
    from mpi4jax import Sendrecv

    value = 1 * rank
    value_before = value
    partner = 1 - rank

    def exchange(send, recv):
        # Returns the full (result, token) pair from Sendrecv.
        return Sendrecv(send, recv, source=partner, dest=partner)

    received, _token = jax.jit(exchange)(value, value)

    assert jnp.array_equal(received, jnp.ones_like(value) * partner)
    assert jnp.array_equal(value_before, value)
def test_sendrecv_jit():
    """Exchange a (3, 2) array with the partner rank inside ``jax.jit``.

    Only the received array (element 0 of Sendrecv's output) is returned
    from the jitted function. The partner is ``1 - rank``, so the pairing
    only makes sense for ranks 0 and 1.
    """
    from mpi4jax import Sendrecv

    payload = jnp.ones((3, 2)) * rank
    snapshot = payload.copy()
    partner = 1 - rank

    def exchange(send, recv):
        outputs = Sendrecv(send, recv, source=partner, dest=partner)
        return outputs[0]

    received = jax.jit(exchange)(payload, payload)

    expected = jnp.ones_like(payload) * partner
    assert jnp.array_equal(received, expected)
    # The array that was sent must be left unchanged.
    assert jnp.array_equal(snapshot, payload)
def test_sendrecv_status():
    """Exchange an array via ``Sendrecv`` and inspect the returned MPI status.

    Besides the payload checks, the ``MPI.Status`` object passed in must
    report the partner rank as the message source. The partner is
    ``1 - rank``, so the pairing only makes sense for ranks 0 and 1.
    """
    from mpi4jax import Sendrecv

    payload = jnp.ones((3, 2)) * rank
    snapshot = payload.copy()
    partner = 1 - rank
    mpi_status = MPI.Status()

    received, _token = Sendrecv(
        payload, payload, source=partner, dest=partner, status=mpi_status
    )

    expected = jnp.ones_like(payload) * partner
    assert jnp.array_equal(received, expected)
    assert jnp.array_equal(snapshot, payload)
    # Sendrecv is expected to have filled in the status object.
    assert mpi_status.Get_source() == partner
def test_sendrecv_jit_deprecated():
    """Check that the jitted ``Sendrecv`` call emits a deprecation warning.

    The test expects a ``UserWarning`` whose message matches "deprecated"
    while the jitted exchange runs. It then checks the exchanged payload as
    usual. The partner is ``1 - rank``, so the pairing only makes sense for
    ranks 0 and 1.
    """
    from mpi4jax import Sendrecv

    payload = jnp.ones((3, 2)) * rank
    snapshot = payload.copy()
    partner = 1 - rank

    def exchange(send, recv):
        outputs = Sendrecv(send, recv, source=partner, dest=partner)
        return outputs[0]

    # Only the traced call itself is wrapped, so only it can satisfy pytest.warns.
    with pytest.warns(UserWarning, match="deprecated"):
        received = jax.jit(exchange)(payload, payload)

    expected = jnp.ones_like(payload) * partner
    assert jnp.array_equal(received, expected)
    assert jnp.array_equal(snapshot, payload)