Exemplo n.º 1
0
def test_allreduce_grad():
    """Check ``jax.value_and_grad`` through ``Allreduce`` with ``MPI.SUM``.

    Three variants are exercised:

    1. eager ``jax.value_and_grad`` of a summed ``Allreduce`` result,
    2. the same computation wrapped in ``jax.jit``,
    3. a jitted function that chains two ``Allreduce`` calls via a token.

    Every variant asserts that the value equals ``arr.sum() * size`` and
    that the input array still equals the copy taken before the call.
    """
    from mpi4jax import Allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    # NOTE(review): `grad` is computed in each variant below but is never
    # asserted on; only the value and the untouched input are checked.
    res, grad = jax.value_and_grad(
        lambda x: Allreduce(x, op=MPI.SUM)[0].sum())(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, arr)

    res, grad = jax.jit(
        jax.value_and_grad(lambda x: Allreduce(x, op=MPI.SUM)[0].sum()))(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, arr)

    def testfun(x):
        """Run two token-chained ``Allreduce`` sums and return the total."""
        y, token = Allreduce(x, op=MPI.SUM)
        # `z` does not feed the return value; it is kept from the original
        # test as an intermediate computation between the two calls.
        z = x + 2 * y  # noqa: F841
        # The token from the first call is passed on to the second call.
        res, token2 = Allreduce(x, op=MPI.SUM, token=token)
        return res.sum()

    res, grad = jax.jit(jax.value_and_grad(testfun))(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, arr)
def test_allreduce():
    """Sum-reduce a NumPy array with ``Allreduce``, eagerly and under ``jax.jit``.

    In both modes the result must equal ``data * size`` and the input must
    still equal the copy taken before the call.
    """
    data = np.ones((3, 2))
    snapshot = data.copy()

    def sum_reduce(x):
        return Allreduce(x, op=MPI.SUM)

    def verify(result):
        # Same two checks as after each call in the original test.
        assert np.array_equal(result, data * size)
        assert np.array_equal(snapshot, data)

    verify(sum_reduce(data))
    verify(jax.jit(sum_reduce)(data))
Exemplo n.º 3
0
def test_allreduce_scalar_jit():
    """Sum-reduce the Python scalar ``1`` through a jitted ``Allreduce``.

    The first element of the returned pair must equal ``scalar * size``.
    """
    from mpi4jax import Allreduce

    scalar = 1
    reference = 1

    def sum_reduce(x):
        return Allreduce(x, op=MPI.SUM)

    # The call returns a (result, token) pair; the token is not needed here.
    reduced, _ = jax.jit(sum_reduce)(scalar)

    assert jnp.array_equal(reduced, scalar * size)
    assert jnp.array_equal(reference, scalar)
Exemplo n.º 4
0
def test_allreduce_scalar():
    """Sum-reduce the Python scalar ``1`` with an eager ``Allreduce`` call.

    The first element of the returned pair must equal ``scalar * size``.
    """
    from mpi4jax import Allreduce

    scalar = 1
    reference = 1

    # Allreduce returns a (result, token) pair; only the result is checked.
    outputs = Allreduce(scalar, op=MPI.SUM)
    reduced, _ = outputs

    assert jnp.array_equal(reduced, scalar * size)
    assert jnp.array_equal(reference, scalar)
Exemplo n.º 5
0
def test_allreduce_jit():
    """Sum-reduce a ``(3, 2)`` array of ones through a jitted ``Allreduce``.

    The result must equal ``data * size`` and the input must still equal
    the copy taken before the call.
    """
    from mpi4jax import Allreduce

    data = jnp.ones((3, 2))
    snapshot = data.copy()

    def sum_reduce(x):
        return Allreduce(x, op=MPI.SUM)

    jitted = jax.jit(sum_reduce)
    # The jitted call returns a (result, token) pair; the token is unused.
    reduced, _ = jitted(data)

    assert jnp.array_equal(reduced, data * size)
    assert jnp.array_equal(snapshot, data)
Exemplo n.º 6
0
def test_allreduce():
    """Sum-reduce a ``(3, 2)`` array of ones with an eager ``Allreduce`` call.

    The result must equal ``data * size`` and the input must still equal
    the copy taken before the call.
    """
    from mpi4jax import Allreduce

    data = jnp.ones((3, 2))
    snapshot = data.copy()

    # Allreduce returns a (result, token) pair; only the result is checked.
    outputs = Allreduce(data, op=MPI.SUM)
    reduced, _ = outputs

    assert jnp.array_equal(reduced, data * size)
    assert jnp.array_equal(snapshot, data)
Exemplo n.º 7
0
def test_set_debug_logging(capsys):
    """Check that ``set_logging`` toggles the bridge's debug output on stdout.

    With logging enabled, the captured stdout of a jitted ``Allreduce`` must
    start with ``"r{rank} | MPI_Allreduce with token"``.  With logging
    disabled again, the same call must produce no stdout at all.

    Args:
        capsys: pytest fixture used to capture stdout.
    """
    from mpi4jax import Allreduce
    from mpi4jax.cython.mpi_xla_bridge import set_logging

    arr = jnp.ones((3, 2))

    # Reset logging in `finally` so a failing assertion cannot leave debug
    # logging switched on for tests that run afterwards.
    try:
        set_logging(True)
        res = jax.jit(lambda x: Allreduce(x, op=MPI.SUM)[0])(arr)
        # Wait for the computation to finish before reading captured output.
        res.block_until_ready()

        captured = capsys.readouterr()
        assert captured.out.startswith(f"r{rank} | MPI_Allreduce with token")
    finally:
        set_logging(False)

    res = jax.jit(lambda x: Allreduce(x, op=MPI.SUM)[0])(arr)
    res.block_until_ready()

    captured = capsys.readouterr()
    assert not captured.out
def test_allreduce_jit_deprecated():
    """Expect a ``UserWarning`` matching "deprecated" from a jitted ``Allreduce``.

    Besides the warning, the first output must equal ``data * size`` and the
    input must still equal the copy taken before the call.
    """
    from mpi4jax import Allreduce

    data = jnp.ones((3, 2))
    snapshot = data.copy()

    def first_output(x):
        return Allreduce(x, op=MPI.SUM)[0]

    # Build and invoke the jitted function inside the warning context, as
    # the original test does.
    with pytest.warns(UserWarning, match="deprecated"):
        reduced = jax.jit(first_output)(data)

    assert jnp.array_equal(reduced, data * size)
    assert jnp.array_equal(snapshot, data)
Exemplo n.º 9
0
def test_debug_logging_enabled(capsys, monkeypatch):
    """With logging enabled, a jitted ``Allreduce`` must write to stdout.

    The captured stdout has to start with
    ``"r{rank} | MPI_Allreduce with token"``.

    Args:
        capsys: pytest fixture used to capture stdout.
        monkeypatch: pytest fixture; requested but not used in the body.
    """
    from mpi4jax import Allreduce
    from mpi4jax.cython.mpi_xla_bridge import set_logging

    data = jnp.ones((3, 2))

    def sum_reduce(x):
        return Allreduce(x, op=MPI.SUM)

    try:
        set_logging(True)
        outputs = jax.jit(sum_reduce)(data)
        # Block on the first output so the call has finished before stdout
        # is read.
        outputs[0].block_until_ready()
    finally:
        # Always switch logging back off, even if the call above raises.
        set_logging(False)

    stdout = capsys.readouterr().out
    expected_prefix = f"r{rank} | MPI_Allreduce with token"
    assert stdout.startswith(expected_prefix)
Exemplo n.º 10
0
def test_debug_logging_disabled(capsys, monkeypatch):
    """After toggling logging on and off, ``Allreduce`` must print nothing.

    Args:
        capsys: pytest fixture used to capture stdout.
        monkeypatch: pytest fixture; requested but not used in the body.
    """
    from mpi4jax import Allreduce
    from mpi4jax.cython.mpi_xla_bridge import set_logging

    data = jnp.ones((3, 2))

    # Switch logging on and straight back off before running the call.
    for enabled in (True, False):
        set_logging(enabled)

    def sum_reduce(x):
        return Allreduce(x, op=MPI.SUM)

    outputs = jax.jit(sum_reduce)(data)
    outputs[0].block_until_ready()

    stdout = capsys.readouterr().out
    assert not stdout
Exemplo n.º 11
0
 def testfun(x):
     """Run two token-chained ``Allreduce`` sums of ``x`` and return the total."""
     first_sum, first_token = Allreduce(x, op=MPI.SUM)
     # Intermediate value that does not feed the return value; reproduced
     # from the original.
     _combined = x + 2 * first_sum
     # The token from the first call is passed on to the second call.
     second_sum, _ = Allreduce(x, op=MPI.SUM, token=first_token)
     return second_sum.sum()