def test_allreduce_grad():
    """Check that a SUM ``Allreduce`` composes with ``jax.value_and_grad``.

    The reduction is exercised in three ways:

    * eagerly under ``jax.value_and_grad``;
    * under ``jax.jit(jax.value_and_grad(...))``;
    * inside a jitted, differentiated function that threads the token returned
      by one ``Allreduce`` into a second ``Allreduce``.

    In every case the forward value must equal ``arr.sum() * size``, and
    ``arr`` must be left unmodified. The gradient is computed but not
    asserted on.
    """
    from mpi4jax import Allreduce

    arr = jnp.ones((3, 2))
    # Pristine copy, used to detect any in-place modification of `arr`.
    _arr = arr.copy()

    res, grad = jax.value_and_grad(
        lambda x: Allreduce(x, op=MPI.SUM)[0].sum()
    )(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, arr)

    res, grad = jax.jit(
        jax.value_and_grad(lambda x: Allreduce(x, op=MPI.SUM)[0].sum())
    )(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, arr)

    def testfun(x):
        # The first reduction is kept only for the token it returns, which is
        # passed to the second reduction; the returned value reduces `x` again.
        _, token = Allreduce(x, op=MPI.SUM)
        res, _ = Allreduce(x, op=MPI.SUM, token=token)
        return res.sum()

    res, grad = jax.jit(jax.value_and_grad(testfun))(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, arr)
def test_allreduce():
    """Sum-reduce a NumPy array across ranks, eagerly and under ``jax.jit``.

    Each result must equal ``arr * size``, and the input array must be left
    unmodified.

    NOTE(review): another ``def test_allreduce`` appears later in this file.
    That later definition rebinds the module-level name, so pytest never
    collects this one. Rename one of them if both are meant to run. This
    version also compares the return value of ``Allreduce`` directly against an
    array (no ``(result, token)`` unpacking). Confirm that this matches the
    ``Allreduce`` bound at module level.
    """
    ones = np.ones((3, 2))
    snapshot = ones.copy()

    eager = Allreduce(ones, op=MPI.SUM)
    assert np.array_equal(eager, ones * size)
    assert np.array_equal(snapshot, ones)

    jitted = jax.jit(lambda v: Allreduce(v, op=MPI.SUM))(ones)
    assert np.array_equal(jitted, ones * size)
    assert np.array_equal(snapshot, ones)
def test_allreduce_scalar_jit():
    """Sum-reduce the Python int ``1`` inside ``jax.jit``.

    The reduced value must equal ``1 * size``.
    """
    from mpi4jax import Allreduce

    value = 1
    reference = 1
    reduced, _token = jax.jit(lambda v: Allreduce(v, op=MPI.SUM))(value)
    assert jnp.array_equal(reduced, value * size)
    assert jnp.array_equal(reference, value)
def test_allreduce_scalar():
    """Eagerly sum-reduce the Python int ``1``; the result must be ``1 * size``."""
    from mpi4jax import Allreduce

    value = 1
    reference = 1
    reduced, _token = Allreduce(value, op=MPI.SUM)
    assert jnp.array_equal(reduced, value * size)
    assert jnp.array_equal(reference, value)
def test_allreduce_jit():
    """Sum-reduce a ``(3, 2)`` array of ones inside ``jax.jit``.

    The result must equal ``arr * size``, and the input must be left unmodified.
    """
    from mpi4jax import Allreduce

    data = jnp.ones((3, 2))
    untouched = data.copy()
    summed, _token = jax.jit(lambda v: Allreduce(v, op=MPI.SUM))(data)
    assert jnp.array_equal(summed, data * size)
    assert jnp.array_equal(untouched, data)
def test_allreduce():
    """Eagerly sum-reduce a ``(3, 2)`` array of ones.

    The result must equal ``arr * size``, and the input must be left unmodified.

    NOTE(review): an earlier ``def test_allreduce`` exists in this file. Because
    this definition comes later, it replaces that one at module level.
    """
    from mpi4jax import Allreduce

    data = jnp.ones((3, 2))
    untouched = data.copy()
    summed, _token = Allreduce(data, op=MPI.SUM)
    assert jnp.array_equal(summed, data * size)
    assert jnp.array_equal(untouched, data)
def test_set_debug_logging(capsys):
    """Check that ``set_logging`` toggles the per-call debug output on stdout.

    With logging enabled, running a jitted ``Allreduce`` must print a line
    starting with ``"r{rank} | MPI_Allreduce with token"``. After logging is
    disabled, the same kind of call must print nothing.

    Args:
        capsys: pytest fixture used to capture stdout.
    """
    from mpi4jax import Allreduce
    from mpi4jax.cython.mpi_xla_bridge import set_logging

    arr = jnp.ones((3, 2))

    set_logging(True)
    try:
        res = jax.jit(lambda x: Allreduce(x, op=MPI.SUM)[0])(arr)
        # Wait for the computation so its log line has been written.
        res.block_until_ready()
        captured = capsys.readouterr()
        assert captured.out.startswith(f"r{rank} | MPI_Allreduce with token")
    finally:
        # Always switch logging back off, so a failing assertion above does not
        # leave debug logging enabled for tests that run afterwards.
        set_logging(False)

    res = jax.jit(lambda x: Allreduce(x, op=MPI.SUM)[0])(arr)
    res.block_until_ready()
    captured = capsys.readouterr()
    assert not captured.out
def test_allreduce_jit_deprecated():
    """Jit a function that indexes the result of ``Allreduce``.

    The test expects a ``UserWarning`` whose message matches ``"deprecated"``.
    The reduced value must still equal ``arr * size``, and the input must be
    left unmodified.
    """
    from mpi4jax import Allreduce

    data = jnp.ones((3, 2))
    untouched = data.copy()

    with pytest.warns(UserWarning, match="deprecated"):
        first = jax.jit(lambda v: Allreduce(v, op=MPI.SUM)[0])(data)

    assert jnp.array_equal(first, data * size)
    assert jnp.array_equal(untouched, data)
def test_debug_logging_enabled(capsys, monkeypatch):
    """With debug logging on, a jitted ``Allreduce`` must log its call to stdout.

    Logging is switched off again in a ``finally`` block, so it cannot leak into
    other tests. The ``monkeypatch`` fixture is requested but not used here.
    """
    from mpi4jax import Allreduce
    from mpi4jax.cython.mpi_xla_bridge import set_logging

    data = jnp.ones((3, 2))
    try:
        set_logging(True)
        outputs = jax.jit(lambda v: Allreduce(v, op=MPI.SUM))(data)
        outputs[0].block_until_ready()
    finally:
        set_logging(False)

    expected_prefix = f"r{rank} | MPI_Allreduce with token"
    assert capsys.readouterr().out.startswith(expected_prefix)
def test_debug_logging_disabled(capsys, monkeypatch):
    """After logging is turned on and then off, a jitted ``Allreduce`` prints nothing.

    The ``monkeypatch`` fixture is requested but not used here.
    """
    from mpi4jax import Allreduce
    from mpi4jax.cython.mpi_xla_bridge import set_logging

    data = jnp.ones((3, 2))
    # Enable, then disable: the final state must be "off".
    for flag in (True, False):
        set_logging(flag)

    outputs = jax.jit(lambda v: Allreduce(v, op=MPI.SUM))(data)
    outputs[0].block_until_ready()
    assert not capsys.readouterr().out
def testfun(x):
    """Reduce ``x`` twice with SUM, chaining the token, and return the sum.

    The first reduction is kept only for the token it returns. That token is
    passed to the second reduction, whose summed result is returned.

    Args:
        x: array-like input to ``Allreduce``.

    Returns:
        The scalar sum of the second reduction's result.
    """
    _, token = Allreduce(x, op=MPI.SUM)
    res, _ = Allreduce(x, op=MPI.SUM, token=token)
    return res.sum()


# The name starts with "test", so pytest's default discovery would otherwise
# collect this helper as a test and fail looking for a fixture named `x`.
testfun.__test__ = False