def _to_array_rank(apply_fun, variables, σ_rank, n_states, normalize):
    """
    Evaluate ``apply_fun(variables, σ_rank)`` on this rank, exponentiate it and
    gather the results of all ranks into one flat array of length ``n_states``.

    The input ``σ_rank`` should be a slice of all states in the hilbert space,
    of equal length across all ranks, because mpi4jax does not support
    allgatherv (yet). To make the slices equal in length the caller pads the
    list of states with 'fake' entries; those are masked out here and dropped
    from the returned array.

    Args:
        apply_fun: callable invoked as ``apply_fun(variables, σ_rank)``. Its
            output is treated as log-amplitudes (it is passed through
            ``jnp.exp``).
        variables: forwarded unchanged to ``apply_fun``.
        σ_rank: the (padded) slice of states owned by this rank; only its
            leading dimension is inspected here.
        n_states: total number of elements in the hilbert space.
        normalize: if true, divide the amplitudes by their global L2 norm.

    Returns:
        A 1D array with the first ``n_states`` gathered amplitudes.

    NOTE: the global maximum of the real part of the log-amplitudes is
    subtracted before exponentiating regardless of ``normalize``, so with
    ``normalize=False`` the result is rescaled by ``exp(-logmax)``.
    """
    # Entries are laid out contiguously in rank order (the gathered array is
    # flattened and cut at n_states below), so the local entry i of this rank
    # has global index mpi.rank * n_local + i and is fake iff that index is
    # >= n_states. When the padding is larger than one slice, ranks other than
    # the last one also hold fake entries, hence the per-rank count.
    n_local = σ_rank.shape[0]
    n_fake_local = min(n_local, max(0, (mpi.rank + 1) * n_local - n_states))

    log_psi_local = jnp.asarray(apply_fun(variables, σ_rank))

    # Mask fake entries with -inf (not 0.0): we are still in log-space, so only
    # -inf maps to a zero amplitude and stays out of both the max and the norm.
    if n_fake_local > 0:
        log_psi_local = log_psi_local.at[-n_fake_local:].set(-jnp.inf)

    # Subtract the maximum over all ranks for better numerical stability.
    # (The second return value of the mpi helpers is discarded.)
    logmax, _ = mpi.mpi_max_jax(log_psi_local.real.max())
    psi_local = jnp.exp(log_psi_local - logmax)

    # compute normalization
    if normalize:
        norm2 = jnp.linalg.norm(psi_local) ** 2
        norm2, _ = mpi.mpi_sum_jax(norm2)
        psi_local /= jnp.sqrt(norm2)

    psi, _ = mpi.mpi_allgather_jax(psi_local)
    psi = psi.reshape(-1)

    # remove fake states
    return psi[0:n_states]
def _to_array_rank(apply_fun, variables, σ_rank, n_states, normalize, allgather):
    """
    Evaluate ``apply_fun(variables, σ_rank)`` on this rank, exponentiate it and
    optionally gather the results across all ranks.

    The input ``σ_rank`` should be a slice of all states in the hilbert space,
    of equal length across all ranks, because mpi4jax does not support
    allgatherv (yet). To make the slices equal in length the caller pads the
    list of states with 'fake' entries; those are given zero amplitude here.

    Args:
        apply_fun: callable invoked as ``apply_fun(variables, σ_rank)``. Its
            output is treated as log-amplitudes (it is passed through
            ``jnp.exp``).
        variables: forwarded unchanged to ``apply_fun``.
        σ_rank: the (padded) slice of states owned by this rank; only its
            leading dimension is inspected here.
        n_states: total number of elements in the hilbert space.
        normalize: if true, shift the log-amplitudes by their global maximum
            and divide the amplitudes by their global L2 norm.
        allgather: if true, gather the amplitudes of all ranks; otherwise only
            this rank's amplitudes are returned.

    Returns:
        A 1D array truncated to at most ``n_states`` entries. With
        ``allgather`` false this is the flattened local slice, in which fake
        entries (if any) appear as zeros.
    """
    # Entries are laid out contiguously in rank order (the gathered array is
    # flattened and cut at n_states below), so the local entry i of this rank
    # has global index mpi.rank * n_local + i and is fake iff that index is
    # >= n_states. When the padding is larger than one slice, ranks other than
    # the last one also hold fake entries, hence the per-rank count.
    n_local = σ_rank.shape[0]
    n_fake_local = min(n_local, max(0, (mpi.rank + 1) * n_local - n_states))

    log_psi_local = apply_fun(variables, σ_rank)

    # Fake entries get log-amplitude -inf, i.e. amplitude exactly 0, so they
    # contribute to neither the maximum nor the norm.
    if n_fake_local > 0:
        log_psi_local = log_psi_local.at[-n_fake_local:].set(-jnp.inf)

    if normalize:
        # subtract logmax for better numerical stability
        # (the second return value of the mpi helpers is discarded)
        logmax, _ = mpi.mpi_max_jax(log_psi_local.real.max())
        log_psi_local -= logmax

    psi_local = jnp.exp(log_psi_local)

    if normalize:
        # compute normalization
        norm2 = jnp.linalg.norm(psi_local) ** 2
        norm2, _ = mpi.mpi_sum_jax(norm2)
        psi_local /= jnp.sqrt(norm2)

    if allgather:
        psi, _ = mpi.mpi_allgather_jax(psi_local)
    else:
        psi = psi_local

    psi = psi.reshape(-1)

    # remove fake states
    return psi[0:n_states]