Exemplo n.º 1
0
def test_debug_logging_disabled(capsys, monkeypatch):
    """Check that Allreduce writes nothing to stdout once logging is turned off.

    Logging is enabled and then disabled again before the collective runs, so
    the last ``set_logging`` call is the one that must take effect.

    The ``monkeypatch`` fixture is requested but not used by this test.
    """
    from mpi4jax import Allreduce
    from mpi4jax.cython.mpi_xla_bridge import set_logging

    data = jnp.ones((3, 2))

    # Flip the flag on, then off: only the final (disabled) state should matter.
    for enabled in (True, False):
        set_logging(enabled)

    def summed(x):
        return Allreduce(x, op=MPI.SUM)

    output = jax.jit(summed)(data)
    # Element 0 of the jitted result is the reduced array; block on it so the
    # collective has actually executed before stdout is inspected.
    output[0].block_until_ready()

    assert not capsys.readouterr().out
Exemplo n.º 2
0
def test_debug_logging_enabled(capsys, monkeypatch):
    """Check that Allreduce emits a debug line on stdout while logging is on.

    Logging is switched off again in a ``finally`` clause so that a failure
    here cannot leave the flag enabled for subsequent tests.

    The ``monkeypatch`` fixture is requested but not used by this test.
    """
    from mpi4jax import Allreduce
    from mpi4jax.cython.mpi_xla_bridge import set_logging

    data = jnp.ones((3, 2))
    try:
        set_logging(True)
        reduced = jax.jit(lambda v: Allreduce(v, op=MPI.SUM))(data)
        # Wait for the reduced array (element 0) so the log line is written.
        reduced[0].block_until_ready()
    finally:
        # Always switch logging back off, even if the collective raised.
        set_logging(False)

    stdout_text = capsys.readouterr().out
    assert stdout_text.startswith(f"r{rank} | MPI_Allreduce with token")
Exemplo n.º 3
0
def test_set_debug_logging(capsys):
    """Check that ``set_logging`` toggles Allreduce's stdout debug output.

    Phase 1 enables logging, runs a jitted Allreduce, and expects a line
    starting with ``r{rank} | MPI_Allreduce with token``. Phase 2 runs the
    same reduction with logging disabled and expects no stdout output.

    The logging-enabled phase is wrapped in ``try``/``finally`` so that the
    flag is switched off even if the collective or the first assertion fails.
    Otherwise, the global logging state would leak into later tests.
    """
    from mpi4jax import Allreduce
    from mpi4jax.cython.mpi_xla_bridge import set_logging

    arr = jnp.ones((3, 2))

    try:
        set_logging(True)
        # ``[0]`` selects the reduced array from Allreduce's return value.
        res = jax.jit(lambda x: Allreduce(x, op=MPI.SUM)[0])(arr)
        res.block_until_ready()

        captured = capsys.readouterr()
        assert captured.out.startswith(f"r{rank} | MPI_Allreduce with token")
    finally:
        # Restore the disabled state unconditionally; this is also the
        # precondition for the second phase below.
        set_logging(False)

    res = jax.jit(lambda x: Allreduce(x, op=MPI.SUM)[0])(arr)
    res.block_until_ready()

    captured = capsys.readouterr()
    assert not captured.out