コード例 #1
0
ファイル: utils.py プロジェクト: FermiQ/netket
def _to_array_rank(apply_fun, variables, σ_rank, n_states, normalize):
    """
    Computes apply_fun(variables, σ_rank) and gathers all results across all ranks.

    The input σ_rank should be a slice of all states in the hilbert space of equal
    length across all ranks because mpi4jax does not support allgatherv (yet).
    The last rank may therefore hold padding ('fake') states past the end of the
    hilbert space; they are masked to an amplitude of exactly 0 so they affect
    neither the normalization nor the returned vector.

    Args:
        apply_fun: Callable ``apply_fun(variables, σ)`` returning the
            log-amplitudes of the states ``σ`` (they are exponentiated here).
        variables: Parameters forwarded unchanged to ``apply_fun``.
        σ_rank: This rank's slice of the hilbert-space states; its leading
            dimension must be the same on every rank.
        n_states: total number of elements in the hilbert space.
        normalize: If True, the returned vector has unit L2 norm (computed
            over all ranks); otherwise ``exp(log_psi)`` is returned unscaled.

    Returns:
        The flattened amplitudes gathered from all ranks, truncated to the
        first ``n_states`` entries.
    """
    # number of 'fake' (padding) states, all located at the end of the last rank.
    n_fake_states = σ_rank.shape[0] * mpi.n_nodes - n_states

    log_psi_local = apply_fun(variables, σ_rank)

    # Last rank: mask the padding states with a log-amplitude of -inf so that
    # exp() maps them to exactly 0. A finite value such as 0.0 would give them a
    # nonzero amplitude that distorts logmax and is counted in the norm.
    if mpi.rank == mpi.n_nodes - 1 and n_fake_states > 0:
        log_psi_local = log_psi_local.at[-n_fake_states:].set(-jnp.inf)

    if normalize:
        # Subtract the global max of the real part before exponentiating, for
        # numerical stability; the constant factor cancels on normalization.
        logmax, _ = mpi.mpi_max_jax(log_psi_local.real.max())
        log_psi_local = log_psi_local - logmax

    psi_local = jnp.exp(log_psi_local)

    if normalize:
        # Global squared L2 norm: sum the per-rank contributions.
        norm2 = jnp.linalg.norm(psi_local) ** 2
        norm2, _ = mpi.mpi_sum_jax(norm2)
        psi_local = psi_local / jnp.sqrt(norm2)

    # Collect the equal-length chunks of all ranks and flatten them into one
    # vector (gathered layout presumably (n_nodes, chunk) — flattened anyway).
    psi, _ = mpi.mpi_allgather_jax(psi_local)
    psi = psi.reshape(-1)

    # remove fake states: after gathering they sit at the very end.
    return psi[:n_states]
コード例 #2
0
ファイル: utils.py プロジェクト: tobiaswiener/netket
def _to_array_rank(apply_fun, variables, σ_rank, n_states, normalize,
                   allgather):
    """
    Computes apply_fun(variables, σ_rank) and optionally gathers all results
    across all ranks.

    The input σ_rank should be a slice of all states in the hilbert space of equal
    length across all ranks because mpi4jax does not support allgatherv (yet).
    The last rank may therefore hold padding ('fake') states; their amplitude is
    forced to exactly 0.

    Args:
        apply_fun: Callable ``apply_fun(variables, σ)`` returning the
            log-amplitudes of the states ``σ`` (they are exponentiated here).
        variables: Parameters forwarded unchanged to ``apply_fun``.
        σ_rank: This rank's slice of the hilbert-space states; its leading
            dimension must be the same on every rank.
        n_states: total number of elements in the hilbert space.
        normalize: If True, the amplitudes are scaled so that the vector
            distributed over all ranks has unit L2 norm; otherwise
            ``exp(log_psi)`` is returned unscaled.
        allgather: If True, the per-rank chunks are gathered so the full
            vector is returned; otherwise only this rank's chunk is returned.

    Returns:
        The flattened amplitudes, truncated to at most ``n_states`` entries.
    """
    # number of 'fake' (padding) states, all located at the end of the last rank.
    n_fake_states = σ_rank.shape[0] * mpi.n_nodes - n_states

    log_psi_local = apply_fun(variables, σ_rank)

    # Last rank: give padding states a log-amplitude of -inf so exp() maps them
    # to exactly 0 and they do not contribute to logmax or the norm. A plain
    # slice inside .at[] avoids the deprecated jax.ops.index helper.
    if mpi.rank == mpi.n_nodes - 1 and n_fake_states > 0:
        log_psi_local = log_psi_local.at[-n_fake_states:].set(-jnp.inf)

    if normalize:
        # subtract logmax for better numerical stability; the constant factor
        # cancels once the vector is normalized below.
        logmax, _ = mpi.mpi_max_jax(log_psi_local.real.max())
        log_psi_local = log_psi_local - logmax

    psi_local = jnp.exp(log_psi_local)

    if normalize:
        # compute normalization: global squared L2 norm summed over ranks.
        norm2 = jnp.linalg.norm(psi_local) ** 2
        norm2, _ = mpi.mpi_sum_jax(norm2)
        psi_local = psi_local / jnp.sqrt(norm2)

    if allgather:
        psi, _ = mpi.mpi_allgather_jax(psi_local)
    else:
        # NOTE(review): with several ranks, the last rank's local chunk still
        # ends with its (zeroed) padding states — the truncation below only
        # removes them when the chunk already spans all states. Confirm callers
        # expect equal-length chunks here.
        psi = psi_local

    psi = psi.reshape(-1)

    # remove fake states (they sit at the end of the gathered vector)
    return psi[:n_states]