Example #1
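These snippets assume a common module-level setup that is not shown here. A minimal sketch of what they rely on (the exact test fixture may differ):

import jax
import jax.numpy as jnp
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()   # rank of the calling process
size = comm.Get_size()   # number of MPI processes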
def test_allreduce_grad():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

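    # Summing the allreduce result gives arr.sum() * size; the gradient of that value
    # with respect to the local input is an array of ones (equal to _arr).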
    res, grad = jax.value_and_grad(
        lambda x: allreduce(x, op=MPI.SUM)[0].sum())(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, grad)

    res, grad = jax.jit(
        jax.value_and_grad(lambda x: allreduce(x, op=MPI.SUM)[0].sum()))(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, grad)

    def testfun(x):
        y, token = allreduce(x, op=MPI.SUM)
        z = x + 2 * y  # noqa: F841
        res, token = allreduce(x, op=MPI.SUM, token=token)
        return res.sum()

    res, grad = jax.jit(jax.value_and_grad(testfun))(arr)
    assert jnp.array_equal(res, arr.sum() * size)
    assert jnp.array_equal(_arr, grad)
Example #2
def test_allreduce_vjp():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

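    # The VJP of a summed allreduce passes the cotangent through unchanged on each rank.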
    res, vjp_fun = jax.vjp(lambda x: allreduce(x, op=MPI.SUM)[0], arr)
    (vjp,) = vjp_fun(_arr)

    expected, _ = allreduce(arr, op=MPI.SUM)
    assert jnp.array_equal(expected, res)
    assert jnp.array_equal(_arr, vjp)
Example #3
def test_allreduce_jvp():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

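    # The JVP of allreduce is the allreduce of the tangent.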
    res, jvp = jax.jvp(lambda x: allreduce(x, op=MPI.SUM)[0], (arr,), (_arr,))

    expected, _ = allreduce(arr, op=MPI.SUM)
    assert jnp.array_equal(expected, res)
    expected, _ = allreduce(_arr, op=MPI.SUM)
    assert jnp.array_equal(expected, jvp)
Example #4
def test_allreduce_transpose():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

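    # The linear transpose of allreduce with MPI.SUM returns the cotangent unchanged.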
    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(_arr)
    assert jnp.array_equal(_arr, res)
Example #5
def test_set_debug_logging(capsys):
    from mpi4jax import allreduce
    from mpi4jax._src.xla_bridge.mpi_xla_bridge import set_logging

    arr = jnp.ones((3, 2))
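    # With logging enabled, the XLA bridge prints a debug line each time the MPI call runs.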
    set_logging(True)
    res = jax.jit(lambda x: allreduce(x, op=MPI.SUM)[0])(arr)
    res.block_until_ready()

    captured = capsys.readouterr()
    assert captured.out == f"r{rank} | MPI_Allreduce with {arr.size} items\n"

    set_logging(False)
    res = jax.jit(lambda x: allreduce(x, op=MPI.SUM)[0])(arr)
    res.block_until_ready()

    captured = capsys.readouterr()
    assert not captured.out
Example #6
def test_allreduce():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

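    # Every rank contributes an array of ones, so each element of the result equals size.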
    res, token = allreduce(arr, op=MPI.SUM)
    assert jnp.array_equal(res, arr * size)
    assert jnp.array_equal(_arr, arr)
Example #7
def test_allreduce_scalar():
    from mpi4jax import allreduce

    arr = 1
    _arr = 1

    res, token = allreduce(arr, op=MPI.SUM)
    assert jnp.array_equal(res, arr * size)
    assert jnp.array_equal(_arr, arr)
Example #8
def test_allreduce_scalar_jit():
    from mpi4jax import allreduce

    arr = 1
    _arr = 1

    res = jax.jit(lambda x: allreduce(x, op=MPI.SUM)[0])(arr)
    assert jnp.array_equal(res, arr * size)
    assert jnp.array_equal(_arr, arr)
Example #9
def test_allreduce_jit():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    res = jax.jit(lambda x: allreduce(x, op=MPI.SUM)[0])(arr)
    assert jnp.array_equal(res, arr * size)
    assert jnp.array_equal(_arr, arr)
Example #10
def test_allreduce_vmap():
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

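    # vmap maps allreduce over the leading axis; elements are still summed across ranks.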
    res = jax.vmap(lambda x: allreduce(x, op=MPI.SUM)[0],
                   in_axes=0,
                   out_axes=0)(arr)
    assert jnp.array_equal(res, arr * size)
    assert jnp.array_equal(_arr, arr)
Example #11
def sum_inplace_jax(x):
    if _n_nodes == 1:
        return x
    else:
        # Note: We must supply a token because we can't transpose `create_token`.
        # The token can't depend on x for the same reason.
        # This token depends on a constant and will be eliminated by DCE.
        token = jax.lax.create_token(0)
        res, _ = mpi4jax.allreduce(x,
                                   op=_MPI.SUM,
                                   comm=_MPI_jax_comm,
                                   token=token)
        return res
Example #12
def test_allreduce_transpose2():
    # test transposing twice
    from mpi4jax import allreduce

    arr = jnp.ones((3, 2))
    _arr = arr.copy()
    _arr2 = arr.copy()

    def lt(y):
        return jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0], arr)(y)[0]

    (res,) = jax.linear_transpose(lt, _arr)(_arr2)
    expected, _ = allreduce(_arr2, op=MPI.SUM)
    assert jnp.array_equal(expected, res)
Example #13
def mpi_max_jax(x, *, token=None, comm=MPI_jax_comm):
    """
    Computes the elementwise maximum of an array or a scalar across all MPI
    processes.

    Args:
        x: The input array.
        token: An optional token to impose ordering of MPI operations.

    Returns:
        out: The reduced array.
        token: An output token.
    """
    if n_nodes == 1:
        return x, token
    else:
        import mpi4jax

        return mpi4jax.allreduce(x, op=MPI.MAX, comm=comm, token=token)
Example #14
def mpi_prod_jax(x, *, token=None, comm=MPI_jax_comm):
    """
    Computes the elementwise product of an array or a scalar across all MPI
    processes. Attempts to perform this product in place if possible, but for
    some types a copy might be returned.

    Args:
        x: The input array.
        token: An optional token to impose ordering of MPI operations.

    Returns:
        out: The reduced array.
        token: An output token.
    """
    if n_nodes == 1:
        return x, token
    else:
        import mpi4jax

        return mpi4jax.allreduce(x, op=MPI.PROD, comm=comm, token=token)
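A minimal usage sketch for these two helpers, assuming the surrounding module defines `n_nodes` and `MPI_jax_comm` as the snippets imply, and reusing the `rank` variable from the setup sketched at the top of this page:

local = jnp.full((2,), float(rank + 1))
maxima, _ = mpi_max_jax(local)     # elementwise maximum across all ranks
products, _ = mpi_prod_jax(local)  # elementwise product across all ranks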
Example #15
def lt(y):
    return jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0],
                                arr)(y)[0]
Example #16
def f(x):
    (res,) = jax.linear_transpose(lambda x: allreduce(x, op=MPI.SUM)[0],
                                  arr)(x)
    return res
Example #17
def foo(x):
    # Thread the token through both reductions so they execute in a well-defined order.
    token = jax.lax.create_token()
    x1, token = allreduce(x, op=MPI.SUM, comm=comm, token=token)
    x2, token = allreduce(x, op=MPI.SUM, comm=comm, token=token)
    return x1 + x2
Example #18
def testfun(x):
    y, token = allreduce(x, op=MPI.SUM)
    z = x + 2 * y  # noqa: F841
    res, token = allreduce(x, op=MPI.SUM, token=token)
    return res.sum()
def allreduce_sum(x):
    res, _ = allreduce(x, op=MPI.SUM, comm=MPI.COMM_WORLD)
    return res
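As a quick check, `allreduce_sum` can be called eagerly or under `jit` (assuming the setup sketched at the top of this page):

res = jax.jit(allreduce_sum)(jnp.ones(4))
# On each rank, every element of res equals the number of participating processes.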