Exemplo n.º 1
0
def test_sendrecv():
    """Exchange an array with the partner rank via ``Sendrecv``.

    Each process sends an array filled with its own ``rank`` and expects
    to get back an array filled with the partner's rank.
    NOTE(review): ``1 - rank`` only names a valid partner when the test
    runs on ranks 0 and 1 -- confirm the surrounding skip conditions.
    """
    from mpi4jax import Sendrecv

    peer = 1 - rank

    payload = jnp.ones((3, 2)) * rank
    # Keep an independent copy so we can check the input was not modified.
    payload_before = payload.copy()

    # The second element of the returned pair is not needed here.
    received, _ = Sendrecv(payload, payload, source=peer, dest=peer)

    expected = jnp.ones_like(payload) * peer
    assert jnp.array_equal(received, expected)
    assert jnp.array_equal(payload_before, payload)
Exemplo n.º 2
0
def test_sendrecv_scalar_jit():
    """Exchange a scalar with the partner rank through a jitted ``Sendrecv``.

    The scalar sent is ``1 * rank``; the value received is expected to
    equal the partner's rank.
    NOTE(review): ``1 - rank`` only names a valid partner when the test
    runs on ranks 0 and 1 -- confirm the surrounding skip conditions.
    """
    from mpi4jax import Sendrecv

    peer = 1 - rank

    scalar = 1 * rank
    scalar_before = scalar

    def exchange(send_value, recv_value):
        # Both the source and the destination are the partner process.
        return Sendrecv(send_value, recv_value, source=peer, dest=peer)

    # The second element of the returned pair is not needed here.
    received, _ = jax.jit(exchange)(scalar, scalar)

    expected = jnp.ones_like(scalar) * peer
    assert jnp.array_equal(received, expected)
    assert jnp.array_equal(scalar_before, scalar)
Exemplo n.º 3
0
def test_sendrecv_jit():
    """Exchange an array with the partner rank through a jitted ``Sendrecv``.

    Only the first element of the ``Sendrecv`` result is returned from the
    jitted function; it is expected to be filled with the partner's rank.
    NOTE(review): ``1 - rank`` only names a valid partner when the test
    runs on ranks 0 and 1 -- confirm the surrounding skip conditions.
    """
    from mpi4jax import Sendrecv

    peer = 1 - rank

    payload = jnp.ones((3, 2)) * rank
    # Keep an independent copy so we can check the input was not modified.
    payload_before = payload.copy()

    def exchange(send_buf, recv_buf):
        outcome = Sendrecv(send_buf, recv_buf, source=peer, dest=peer)
        return outcome[0]

    received = jax.jit(exchange)(payload, payload)

    expected = jnp.ones_like(payload) * peer
    assert jnp.array_equal(received, expected)
    assert jnp.array_equal(payload_before, payload)
Exemplo n.º 4
0
def test_sendrecv_status():
    """Exchange an array with the partner rank and inspect the MPI status.

    Besides checking the received data, this verifies that the
    ``MPI.Status`` object passed to ``Sendrecv`` reports the partner as
    the message source afterwards.
    NOTE(review): ``1 - rank`` only names a valid partner when the test
    runs on ranks 0 and 1 -- confirm the surrounding skip conditions.
    """
    from mpi4jax import Sendrecv

    peer = 1 - rank

    payload = jnp.ones((3, 2)) * rank
    # Keep an independent copy so we can check the input was not modified.
    payload_before = payload.copy()

    recv_status = MPI.Status()
    # The second element of the returned pair is not needed here.
    received, _ = Sendrecv(
        payload, payload, source=peer, dest=peer, status=recv_status
    )

    expected = jnp.ones_like(payload) * peer
    assert jnp.array_equal(received, expected)
    assert jnp.array_equal(payload_before, payload)
    assert recv_status.Get_source() == peer
def test_sendrecv_jit_deprecated():
    """Check that the jitted ``Sendrecv`` call emits a deprecation warning.

    The call must raise a ``UserWarning`` whose message matches
    ``"deprecated"`` while still returning the partner's data.
    NOTE(review): ``1 - rank`` only names a valid partner when the test
    runs on ranks 0 and 1 -- confirm the surrounding skip conditions.
    """
    from mpi4jax import Sendrecv

    peer = 1 - rank

    payload = jnp.ones((3, 2)) * rank
    # Keep an independent copy so we can check the input was not modified.
    payload_before = payload.copy()

    def exchange(send_buf, recv_buf):
        outcome = Sendrecv(send_buf, recv_buf, source=peer, dest=peer)
        return outcome[0]

    # Wrapping with jit does not run `exchange`; tracing and execution
    # both happen on the call inside the warning context.
    jitted_exchange = jax.jit(exchange)
    with pytest.warns(UserWarning, match="deprecated"):
        received = jitted_exchange(payload, payload)

    expected = jnp.ones_like(payload) * peer
    assert jnp.array_equal(received, expected)
    assert jnp.array_equal(payload_before, payload)