def PRNGKey(
    seed: Optional[SeedT] = None, *, root: int = 0, comm=MPI_jax_comm
) -> PRNGKeyT:
    """
    Initialises a PRNGKey using an optional starting seed.

    The key built on rank ``root`` is broadcast over ``comm``, so every
    process returns the same key. Only the ``seed`` supplied on ``root``
    matters: whatever other ranks computed is overwritten by the broadcast.

    Args:
        seed: ``None`` to build a key from a seed obtained via
            ``random_seed()``; an integer (a Python ``int`` or any
            ``numbers.Integral``, e.g. a NumPy integer scalar) to build a key
            with ``jax.random.PRNGKey``; anything else is taken to already be
            a key (an array, or a pytree of arrays) and used as-is.
        root: Rank whose key is broadcast to all other processes.
        comm: Communicator used for the broadcast.

    Returns:
        The PRNG key, identical on all processes of ``comm``.
    """
    # numbers.Integral also matches NumPy integer scalars (np.int64, ...),
    # which are not subclasses of ``int`` and would otherwise be mistaken
    # for an already-built key.
    import numbers

    if seed is None:
        key = jax.random.PRNGKey(random_seed())
    elif isinstance(seed, numbers.Integral):
        key = jax.random.PRNGKey(seed)
    else:
        key = seed

    # jax.tree_map is deprecated (removed in recent JAX releases);
    # jax.tree_util.tree_map is the stable spelling. Each leaf is broadcast
    # separately; ``[0]`` keeps the broadcast value and drops the second
    # element of the returned tuple (presumably an ordering token).
    key = jax.tree_util.tree_map(
        lambda k: mpi.mpi_bcast_jax(k, root=root, comm=comm)[0], key
    )
    return key
def PRNGKey(
    seed: Optional[SeedT] = None, root: int = 0, comm=MPI.COMM_WORLD
) -> PRNGKeyT:
    """
    Initialises a PRNGKey using an optional starting seed.

    When ``n_nodes > 1`` the key built on rank ``root`` is broadcast over
    ``comm`` with mpi4jax, so every process returns the same key and seeds
    supplied on non-root ranks are overwritten. Otherwise the locally built
    key is returned without any communication.

    Args:
        seed: ``None`` to build a key from a seed obtained via
            ``random_seed()``; an integer (a Python ``int`` or any
            ``numbers.Integral``, e.g. a NumPy integer scalar) to build a key
            with ``jax.random.PRNGKey``; anything else is taken to already be
            a key and used as-is.
        root: Rank whose key is broadcast to all other processes.
        comm: MPI communicator used for the broadcast.

    Returns:
        The PRNG key (identical on all processes when a broadcast happens).
    """
    # numbers.Integral also matches NumPy integer scalars (np.int64, ...),
    # which are not subclasses of ``int`` and would otherwise be mistaken
    # for an already-built key.
    import numbers

    if seed is None:
        key = jax.random.PRNGKey(random_seed())
    elif isinstance(seed, numbers.Integral):
        key = jax.random.PRNGKey(seed)
    else:
        key = seed

    # n_nodes is presumably the number of MPI processes; the mpi4jax import
    # is deferred so single-process runs never need it.
    if n_nodes > 1:
        import mpi4jax

        # bcast returns a 2-tuple; keep the value and discard the second
        # element (presumably the ordering token).
        key, _ = mpi4jax.bcast(key, root=root, comm=comm)
    return key