Example #1
def mpi_bcast_jax(x, *, token=None, root, comm=MPI_jax_comm):
    if n_nodes == 1:
        # Single process: the only rank is 0, so root must be 0 and the
        # broadcast is a no-op.
        assert root == 0
        return x, token
    else:
        # Import lazily so MPI support stays optional.
        import mpi4jax

        return mpi4jax.bcast(x, token=token, root=root, comm=comm)
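For context, a minimal single-process sketch of how this helper might be invoked. `n_nodes` and `MPI_jax_comm` are module-level globals of the enclosing library, not part of this snippet; they are stubbed here as assumptions, and must be in scope before the function definition above runs (the default `comm` is evaluated at definition time):

# Hypothetical single-process usage; n_nodes and MPI_jax_comm are stubs
# standing in for the enclosing library's module globals.
import jax.numpy as jnp
from mpi4py import MPI

n_nodes = 1
MPI_jax_comm = MPI.COMM_WORLD

x = jnp.ones(4)
x_bcast, token = mpi_bcast_jax(x, root=0)  # no-op with a single process
assert token is None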
Example #2
def test_bcast_scalar_jit():
    from mpi4jax import bcast

    arr = 1
    _arr = 1

    # Non-root ranks zero out their local value before broadcasting.
    if rank != 0:
        _arr = _arr * 0

    # bcast returns (result, token); the jitted lambda keeps the result.
    res = jax.jit(lambda x: bcast(x, root=0)[0])(_arr)
    assert jnp.array_equal(res, arr)
    if rank == 0:
        assert jnp.array_equal(_arr, arr)
Example #3
def test_bcast_scalar():
    from mpi4jax import bcast

    arr = 1
    _arr = 1

    if rank != 0:
        _arr = _arr * 0

    res, token = bcast(_arr, root=0)
    assert jnp.array_equal(res, arr)
    if rank == 0:
        assert jnp.array_equal(_arr, arr)
Example #4
def test_bcast_jit():
    from mpi4jax import bcast

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    if rank != 0:
        _arr = _arr * 0

    # Broadcast the locally zeroed copy (not `arr`), so the assertion
    # genuinely checks that root's data reached the other ranks.
    res = jax.jit(lambda x: bcast(x, root=0)[0])(_arr)
    assert jnp.array_equal(res, arr)
    if rank == 0:
        assert jnp.array_equal(_arr, arr)
Example #5
def test_bcast():
    from mpi4jax import bcast

    arr = jnp.ones((3, 2))
    _arr = arr.copy()

    if rank != 0:
        _arr = _arr * 0

    res, token = bcast(_arr, root=0)
    assert jnp.array_equal(res, arr)
    if rank == 0:
        assert jnp.array_equal(_arr, arr)
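The tests above rely on module-level `rank`, `jax`, and `jnp` from the test file and only exercise real communication under an MPI launch. A self-contained sketch of the same pattern, assuming two or more processes (e.g. `mpirun -n 2 python sketch.py`):

# Self-contained sketch of the broadcast pattern tested above.
# Launch under MPI, e.g.:  mpirun -n 2 python sketch.py
import jax
import jax.numpy as jnp
from mpi4py import MPI
from mpi4jax import bcast

rank = MPI.COMM_WORLD.Get_rank()

arr = jnp.ones((3, 2))
# Non-root ranks start from zeros; the broadcast restores root's data.
_arr = arr if rank == 0 else jnp.zeros((3, 2))

res = jax.jit(lambda x: bcast(x, root=0)[0])(_arr)
assert jnp.array_equal(res, arr)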
Example #6
def PRNGKey(
    seed: Optional[SeedT] = None, root: int = 0, comm=MPI.COMM_WORLD
) -> PRNGKeyT:
    """
    Initialises a PRNGKey using an optional starting seed.
    The same seed will be distributed to all processes.
    """
    # Build the key: draw a fresh random seed when none is given, wrap a
    # plain integer, or accept an existing key unchanged.
    if seed is None:
        key = jax.random.PRNGKey(random_seed())
    elif isinstance(seed, int):
        key = jax.random.PRNGKey(seed)
    else:
        key = seed

    if n_nodes > 1:
        # Broadcast so every process ends up with the root's key.
        import mpi4jax

        key, _ = mpi4jax.bcast(key, root=root, comm=comm)

    return key
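A hedged usage sketch: `PRNGKey` below is the helper defined above (not `jax.random.PRNGKey`), and the enclosing module is assumed to provide `n_nodes`, `random_seed`, and the `SeedT`/`PRNGKeyT` aliases:

# Hypothetical usage of the helper above. All processes end up with the
# same key, which can then be split locally as usual.
key = PRNGKey(42)                # identical on every MPI process
k1, k2 = jax.random.split(key)   # ordinary JAX key handling afterwards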
Example #7
def mpi_split(key, root=0, comm=MPI.COMM_WORLD) -> PRNGKeyT:
    """
    Split a key across MPI nodes in the communicator.
    Only the input key on the root process matters.

    Arguments:
        key: The key to split. Only the one on the root process is considered.
        root: (default=0) The root rank from which to take the input key.
        comm: (default=MPI.COMM_WORLD) The MPI communicator.

    Returns:
        A PRNGKey depending on rank number and key.
    """

    # Maybe add error/warning if in_key is not the same
    # on all MPI nodes?
    # Every rank computes the split, but only the root's result survives
    # the broadcast below.
    keys = jax.random.split(key, n_nodes)

    if n_nodes > 1:
        import mpi4jax

        keys, _ = mpi4jax.bcast(keys, root=root, comm=comm)

    # Each rank keeps the subkey at its own rank index.
    return keys[rank]
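A hedged sketch of typical use, assuming the module globals `n_nodes` and `rank` and an MPI launch: combining the two helpers gives every process its own subkey derived from the root's key.

# Hypothetical usage (e.g. under mpirun -n 4): every rank receives a
# distinct subkey derived from the root process's key.
key = PRNGKey(0)             # same key everywhere (see Example #6)
local_key = mpi_split(key)   # unique per rank
sample = jax.random.normal(local_key, (3,))  # differs across ranks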