def mpi_bcast_jax(x, *, token=None, root, comm=MPI_jax_comm):
    """Broadcast ``x`` from rank ``root`` to every process in ``comm``.

    When running on a single node no communication happens: ``root`` must be
    0 and ``(x, token)`` is handed back untouched. Otherwise the call is
    forwarded to ``mpi4jax.bcast``, which is imported lazily.

    Returns:
        A ``(value, token)`` pair, the same shape of result as
        ``mpi4jax.bcast``.
    """
    if n_nodes != 1:
        import mpi4jax

        return mpi4jax.bcast(x, token=token, root=root, comm=comm)

    # With one node the only valid root is rank 0.
    assert root == 0
    return x, token
def test_bcast_scalar_jit():
    """A jitted bcast of a Python scalar yields root's value on every rank."""
    from mpi4jax import bcast

    expected = 1
    # Non-root ranks start from a zeroed input, so the check below can only
    # succeed if the value really arrives from rank 0.
    local = expected if rank == 0 else expected * 0

    bcast_jit = jax.jit(lambda v: bcast(v, root=0)[0])
    result = bcast_jit(local)

    assert jnp.array_equal(result, expected)
    if rank == 0:
        assert jnp.array_equal(local, expected)
def test_bcast_scalar():
    """An eager bcast of a Python scalar yields root's value on every rank."""
    from mpi4jax import bcast

    expected = 1
    # Zero the input everywhere except on the root.
    local = expected if rank == 0 else expected * 0

    result, _token = bcast(local, root=0)

    assert jnp.array_equal(result, expected)
    if rank == 0:
        assert jnp.array_equal(local, expected)
def test_bcast_jit():
    """A jitted bcast of an array yields root's data on every rank.

    Non-root ranks feed a zeroed copy into the broadcast, so the equality
    check only holds if the data genuinely arrives from rank 0.
    """
    from mpi4jax import bcast

    arr = jnp.ones((3, 2))
    _arr = arr.copy()
    if rank != 0:
        _arr = _arr * 0

    # Broadcast the rank-dependent ``_arr``. Passing ``arr``, which is the
    # same on every rank, would let the test pass without any communication.
    res = jax.jit(lambda x: bcast(x, root=0)[0])(_arr)
    assert jnp.array_equal(res, arr)
    if rank == 0:
        # The root's own input is still the reference data after the call.
        assert jnp.array_equal(_arr, arr)
def test_bcast():
    """An eager bcast of an array yields root's data on every rank."""
    from mpi4jax import bcast

    reference = jnp.ones((3, 2))
    # Only the root keeps the real data; other ranks hold zeros.
    local = reference.copy() if rank == 0 else reference.copy() * 0

    received, _token = bcast(local, root=0)

    assert jnp.array_equal(received, reference)
    if rank == 0:
        assert jnp.array_equal(local, reference)
def PRNGKey(
    seed: Optional[SeedT] = None, root: int = 0, comm=MPI.COMM_WORLD
) -> PRNGKeyT:
    """
    Initialises a PRNGKey using an optional starting seed.
    The same seed will be distributed to all processes.

    Args:
        seed: ``None`` draws a fresh seed from ``random_seed()``; an ``int``
            is turned into a key with ``jax.random.PRNGKey``; any other value
            is assumed to already be a key and is used as-is.
        root: Rank whose key is broadcast when more than one node is present.
        comm: MPI communicator used for that broadcast.

    Returns:
        The resulting key. With several nodes this is the key from ``root``.
    """
    if isinstance(seed, int):
        key = jax.random.PRNGKey(seed)
    elif seed is None:
        key = jax.random.PRNGKey(random_seed())
    else:
        key = seed

    # Nothing to synchronise on a single node.
    if n_nodes <= 1:
        return key

    import mpi4jax

    key, _ = mpi4jax.bcast(key, root=root, comm=comm)
    return key
def mpi_split(key, root=0, comm=MPI.COMM_WORLD) -> PRNGKeyT:
    """
    Split a key across MPI nodes in the communicator.

    Only the input key on the root process matters.

    Arguments:
        key: The key to split. Only the one on the root process is considered.
        root: (default=0) The root rank from which to take the input key.
        comm: (default=MPI.COMM_WORLD) The MPI communicator.

    Returns:
        A PRNGKey depending on rank number and key.
    """
    # One sub-key per node. Every process computes this locally, but with
    # several nodes the root's sub-keys overwrite everyone else's below.
    # TODO: maybe add an error/warning if ``key`` differs between MPI nodes?
    per_node_keys = jax.random.split(key, n_nodes)

    if n_nodes > 1:
        import mpi4jax

        per_node_keys, _ = mpi4jax.bcast(per_node_keys, root=root, comm=comm)

    # Each process picks the sub-key at its own rank.
    return per_node_keys[rank]