def test_allreduce_grad():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    res, grad = jax.value_and_grad(lambda x: allreduce(x, op=MPI.SUM)[0].sum())(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, grad)

    res, grad = jax.jit(
        jax.value_and_grad(lambda x: allreduce(x, op=MPI.SUM)[0].sum())
    )(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, grad)

    def testfun(x):
        y, token = allreduce(x, op=MPI.SUM)
        z = x + 2 * y  # noqa: F841
        res, token = allreduce(x, op=MPI.SUM, token=token)
        return res.sum()

    res, grad = jax.jit(jax.value_and_grad(testfun))(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, grad)
def test_allreduce_vjp():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    res, vjp_fun = jax.vjp(lambda x: allreduce(x, op=MPI.SUM)[0], arr)
    (vjp,) = vjp_fun(_arr)

    expected, _ = allreduce(arr, op=MPI.SUM)
    assert jnp.array_equal(expected, res)
    assert jnp.array_equal(_arr, vjp)
def test_allreduce_jvp():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    res, jvp = jax.jvp(lambda x: allreduce(x, op=MPI.SUM)[0], (arr,), (_arr,))

    expected, _ = allreduce(arr, op=MPI.SUM)
    assert jnp.array_equal(expected, res)
    expected, _ = allreduce(_arr, op=MPI.SUM)
    assert jnp.array_equal(expected, jvp)
def test_allreduce_transpose():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(_arr)
    assert jnp.array_equal(_arr, res)
def test_set_debug_logging(capsys):
    from mpi4jax import allreduce
    from mpi4jax._src.xla_bridge.mpi_xla_bridge import set_logging

    arr = jnp.ones((3, 2))

    set_logging(True)
    res = jax.jit(lambda x: allreduce(x, op=MPI.SUM)[0])(arr)
    res.block_until_ready()
    captured = capsys.readouterr()
    assert captured.out == f"r{rank} | MPI_Allreduce with {arr.size} items\n"

    set_logging(False)
    res = jax.jit(lambda x: allreduce(x, op=MPI.SUM)[0])(arr)
    res.block_until_ready()
    captured = capsys.readouterr()
    assert not captured.out
def test_allreduce():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    res, token = allreduce(arr, op=MPI.SUM)
    assert jnp.array_equal(res, arr * size)
    assert jnp.array_equal(_arr, arr)
def test_allreduce_scalar():
    from mpi4jax import allreduce

    arr = 1
    _arr = 1

    res, token = allreduce(arr, op=MPI.SUM)
    assert jnp.array_equal(res, arr * size)
    assert jnp.array_equal(_arr, arr)
def test_allreduce_scalar_jit():
    from mpi4jax import allreduce

    arr = 1
    _arr = 1

    res = jax.jit(lambda x: allreduce(x, op=MPI.SUM)[0])(arr)
    assert jnp.array_equal(res, arr * size)
    assert jnp.array_equal(_arr, arr)
def test_allreduce_jit():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    res = jax.jit(lambda x: allreduce(x, op=MPI.SUM)[0])(arr)
    assert jnp.array_equal(res, arr * size)
    assert jnp.array_equal(_arr, arr)
def test_allreduce_vmap():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    res = jax.vmap(lambda x: allreduce(x, op=MPI.SUM)[0], in_axes=0, out_axes=0)(arr)
    assert jnp.array_equal(res, arr * size)
    assert jnp.array_equal(_arr, arr)
def sum_inplace_jax(x):
    if _n_nodes == 1:
        return x
    else:
        # Note: we must supply a token because we can't transpose `create_token`.
        # The token can't depend on x for the same reason.
        # This token depends on a constant and will be eliminated by DCE.
        token = jax.lax.create_token(0)
        res, _ = mpi4jax.allreduce(x, op=_MPI.SUM, comm=_MPI_jax_comm, token=token)
        return res
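# Usage sketch (illustrative, not part of the original file): because the token
# above depends only on a constant, `sum_inplace_jax` composes with `jit` and
# with reverse-mode differentiation. Assumes the module-level `_n_nodes`,
# `_MPI`, and `_MPI_jax_comm` are initialized from mpi4py as above.
def _example_sum_inplace_usage():
    x = jnp.ones((3, 2))
    total = jax.jit(sum_inplace_jax)(x)  # jitted global sum across ranks
    grads = jax.grad(lambda y: sum_inplace_jax(y).sum())(x)  # differentiable too
    return total, grads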
def test_allreduce_transpose2():
    # test transposing twice
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()
    _arr2 = arr.copy()

    def lt(y):
        return jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(y)[0]

    (res,) = jax.linear_transpose(lt, _arr)(_arr2)
    expected, _ = allreduce(_arr2, op=MPI.SUM)
    assert jnp.array_equal(expected, res)
def mpi_max_jax(x, *, token=None, comm=MPI_jax_comm):
    """
    Computes the elementwise maximum of an array or a scalar across all
    MPI processes.

    Args:
        x: The input array.
        token: An optional token to impose ordering of MPI operations.

    Returns:
        out: The reduced array.
        token: An output token.
    """
    if n_nodes == 1:
        return x, token
    else:
        import mpi4jax

        return mpi4jax.allreduce(x, op=MPI.MAX, comm=comm, token=token)
def mpi_prod_jax(x, *, token=None, comm=MPI_jax_comm):
    """
    Computes the elementwise product of an array or a scalar across all
    MPI processes.

    Args:
        x: The input array.
        token: An optional token to impose ordering of MPI operations.

    Returns:
        out: The reduced array.
        token: An output token.
    """
    if n_nodes == 1:
        return x, token
    else:
        import mpi4jax

        return mpi4jax.allreduce(x, op=MPI.PROD, comm=comm, token=token)
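# Usage sketch (illustrative, not from the original source): the token returned
# by one reduction can be threaded into the next to impose ordering of the
# underlying MPI calls:
def _example_chained_reductions(x):
    hi, token = mpi_max_jax(x)
    prod, token = mpi_prod_jax(x, token=token)
    return hi, prod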
def f(x):
    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(x)
    return res
def foo(x):
    token = jax.lax.create_token()
    x1, token = allreduce(x, op=MPI.SUM, comm=comm, token=token)
    x2, token = allreduce(x, op=MPI.SUM, comm=comm, token=token)
    return x1 + x2
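# Illustrative note (not from the original source): threading the same `token`
# through both `allreduce` calls keeps the two MPI operations in program order
# on every rank, which avoids reordering-induced deadlocks. Assuming `comm` is
# an mpi4py communicator defined above, `foo` also composes with jit:
def _example_foo_jit():
    # the token ordering survives tracing, so this is safe under jit
    return jax.jit(foo)(jnp.ones((3, 2)))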
def allreduce_sum(x):
    res, _ = allreduce(x, op=MPI.SUM, comm=MPI.COMM_WORLD)
    return res
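# Usage sketch (illustrative, not from the original source): `allreduce_sum`
# discards the token, which makes it convenient inside purely functional code,
# e.g. differentiating a globally-summed loss:
def _example_allreduce_sum_grad(x):
    loss = lambda y: allreduce_sum(y).sum()
    return jax.value_and_grad(loss)(x)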