# Example #1
def sum(a, axis=None, keepdims: bool = False):
    """
    Compute the sum along the specified axis and over MPI processes.

    Args:
        a: The input array. Objects without a `shape` attribute are treated
            as scalars and converted with `jnp.asarray`.
        axis: Axis or axes along which the sum is computed. The default (None) is to
              compute the sum of the flattened array. Ignored for scalars.
        keepdims: If True the output array will have the same number of dimensions as
              the input, with the reduced axes having length 1. (default=False)
              Ignored for scalars.

    Returns:
        The array with reduced dimensions defined by axis, after it has been
        passed through `mpi.mpi_sum_jax`.
    """
    # if it's a numpy-like array...
    if hasattr(a, "shape"):
        # reduce locally first, so only the reduced array is sent through MPI
        a_sum = a.sum(axis=axis, keepdims=keepdims)
    else:
        # assume it's a scalar; without this `else` the local reduction above
        # would be discarded and the unreduced input summed instead
        a_sum = jnp.asarray(a)

    out, _ = mpi.mpi_sum_jax(a_sum)
    return out
# Example #2
def _to_array_rank(apply_fun, variables, σ_rank, n_states, normalize):
    """
    Computes apply_fun(variables, σ_rank) and gathers all results across all ranks.
    The input σ_rank should be a slice of all states in the hilbert space of equal
    length across all ranks because mpi4jax does not support allgatherv (yet).

    Args:
        apply_fun: Callable evaluated as `apply_fun(variables, σ_rank)`; its
            output is exponentiated, i.e. it is treated as log-amplitudes.
        variables: First argument forwarded to `apply_fun`.
        σ_rank: The slice of states handled by this rank.
        n_states: total number of elements in the hilbert space
        normalize: If True, divide the amplitudes by their global 2-norm.

    Returns:
        A flat array with the first `n_states` gathered amplitudes. The
        amplitudes are always rescaled by `exp(-logmax)`, where `logmax` is the
        largest real part of the log-amplitudes over all ranks, also when
        `normalize` is False.
    """
    # number of 'fake' states, in the last rank.
    n_fake_states = σ_rank.shape[0] * mpi.n_nodes - n_states

    log_psi_local = apply_fun(variables, σ_rank)

    # last rank, get rid of fake elements.
    # They are set to -inf (amplitude exactly 0) so that they influence neither
    # `logmax` nor the norm below; a log-amplitude of 0.0 would leak
    # exp(-logmax) per padding entry into `norm2`.
    if mpi.rank == mpi.n_nodes - 1 and n_fake_states > 0:
        # `.at[...].set(...)` replaces `jax.ops.index_update`, which has been
        # removed from JAX.
        log_psi_local = log_psi_local.at[-n_fake_states:].set(-jnp.inf)

    # subtract the global maximum before exponentiating, for numerical stability
    logmax, _ = mpi.mpi_max_jax(log_psi_local.real.max())
    psi_local = jnp.exp(log_psi_local - logmax)

    # compute normalization
    if normalize:
        norm2 = jnp.linalg.norm(psi_local) ** 2
        norm2, _ = mpi.mpi_sum_jax(norm2)

        psi_local /= jnp.sqrt(norm2)

    psi, _ = mpi.mpi_allgather_jax(psi_local)
    psi = psi.reshape(-1)

    # remove fake states
    psi = psi[0:n_states]
    return psi
# Example #3
def Odagger_DeltaO_v(forward_fn, params, samples, v):
    """
    Chain `O_jvp`, a mean subtraction and `OH_w`, then sum the result over MPI.

    Steps, as implemented:
      1. `w = O_jvp(forward_fn, params, samples, v)`
      2. scale `w` by `1 / (samples.shape[0] * samples.shape[1] * mpi.n_nodes)`
      3. subtract the mean of the unchunked `w` and restore its chunking
      4. `OH_w(forward_fn, params, samples, w)`
      5. pass every leaf of the result through `mpi.mpi_sum_jax`

    NOTE(review): judging by the name this evaluates Oᴴ ΔO v; the helpers
    (`O_jvp`, `OH_w`, `unchunk`, `subtract_mean`) are defined elsewhere, so
    confirm against them.

    Returns:
        A pytree with the structure returned by `OH_w`.
    """
    w = O_jvp(forward_fn, params, samples, v)
    w = w * (1.0 / (samples.shape[0] * samples.shape[1] * mpi.n_nodes))
    w_, chunk_fn = unchunk(w)
    w = chunk_fn(subtract_mean(w_))  # w/ MPI
    res = OH_w(forward_fn, params, samples, w)
    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is deprecated.
    return jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], res)  # MPI
def O_vjp_rc(forward_fn, params, samples, w):
    """
    Vector-Jacobian product of `forward_fn` w.r.t. `params`, assembled as a complex pytree.

    The VJP is evaluated twice, once with cotangent `w` and once with
    `-1.0j * w`; the two results become the real and imaginary parts of the
    output, which is then passed leaf by leaf through `mpi.mpi_sum_jax`.

    Args:
        forward_fn: Function called as `forward_fn(params, samples)`.
        params: Pytree of parameters; the output has the same structure.
        samples: Second argument of `forward_fn`; its cotangent is discarded.
        w: Cotangent vector for the output of `forward_fn`.

    Returns:
        A pytree of complex arrays structured like `params`.
    """
    _, vjp_fun = jax.vjp(forward_fn, params, samples)
    res_r, _ = vjp_fun(w)
    res_i, _ = vjp_fun(-1.0j * w)
    # `tree_util.tree_map` accepts several trees; `jax.tree_multimap` has been
    # removed from JAX.
    res = jax.tree_util.tree_map(jax.lax.complex, res_r, res_i)
    return jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0],
                                  res)  # allreduce w/ MPI.SUM
def _matmul(self: QGTJacobianDenseT,
            vec: Union[PyTree, jnp.ndarray]) -> Union[PyTree, jnp.ndarray]:
    """
    Apply `Oᴴ O + diag_shift` (with optional rescaling) to `vec`.

    Args:
        vec: Either an array (anything exposing `ndim`) or a pytree, which is
            flattened with `nkjax.tree_ravel` before the product.

    Returns:
        The product, unflattened again if `vec` was a pytree.
    """
    # Stays None when the input is already an array.
    restore_tree = None
    if not hasattr(vec, "ndim"):
        vec, restore_tree = nkjax.tree_ravel(vec)

    # Real-imaginary split RHS in R→R and R→C modes
    restore_complex = None
    if self.mode != "holomorphic" and not self._in_solve:
        vec, restore_complex = vec_to_real(vec)

    if self.scale is not None:
        vec = vec * self.scale

    # Local contribution, summed through MPI before the shift is added.
    local_product = ((self.O @ vec).T.conj() @ self.O).T.conj()
    global_product = mpi.mpi_sum_jax(local_product)[0]
    result = global_product + self.diag_shift * vec

    if self.scale is not None:
        result = result * self.scale

    # Undo the input transformations in reverse order.
    if restore_complex is not None:
        result = restore_complex(result)

    if restore_tree is not None:
        result = restore_tree(result)

    return result
# Example #6
def grad_expect_hermitian_chunked(
    chunk_size: int,
    local_value_kernel_chunked: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree]:
    # Computes local values O_loc on the samples σ, their statistics Ō, and a
    # gradient obtained by pulling the centred, conjugated local values back
    # through the model with a chunked VJP.
    #
    # Returns the triple (Ō, Ō_grad with every leaf passed through
    # mpi.mpi_sum_jax, new_model_state).
    # NOTE(review): the return annotation lists two elements but three values
    # are returned.
    # NOTE(review): several multi-line calls below are cut off in this listing
    # (arguments and closing parentheses are missing, as is the `else:` of the
    # `mutable` branch); restore them from the upstream source before use.

    # Flatten all leading axes so that σ is (n_local_samples, σ_shape[-1]).
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    # Local sample count times the number of MPI nodes.
    n_samples = σ.shape[0] * mpi.n_nodes

    # NOTE(review): truncated call — only the variables dict is visible.
    O_loc = local_value_kernel_chunked(
        {"params": parameters, **model_state},

    # Statistics are computed on the local values reshaped back to the
    # leading axes of the original σ, transposed.
    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    # Centre the local values.
    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        # NOTE(review): truncated call — arguments after the lambda are missing.
        vjp_fun_chunked = nkjax.vjp_chunked(
            lambda w, σ: model_apply_fun({"params": w, **model_state}, σ),
        new_model_state = None
        # NOTE(review): presumably this `raise` belongs to a missing `else:`
        # branch (mutable state unsupported in the chunked path) — confirm.
        raise NotImplementedError

    # NOTE(review): truncated call — closing parenthesis (and any indexing of
    # the result) is missing.
    Ō_grad = vjp_fun_chunked(
        (jnp.conjugate(O_loc) / n_samples),

    # Per leaf: keep x as is when the matching `target` is complex, otherwise
    # take 2 * x.real; the result is then cast with `.astype(`.
    # NOTE(review): truncated call — the dtype and the remaining trees are missing.
    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(

    return Ō, tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
# Example #7
def grad_expect_hermitian(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree]:
    # Computes local values O_loc on the samples σ, their statistics Ō, and a
    # gradient obtained by pulling the centred, conjugated local values back
    # through the model with a VJP.
    #
    # Returns the triple (Ō, Ō_grad with every leaf passed through
    # mpi.mpi_sum_jax, new_model_state).
    # NOTE(review): the return annotation lists two elements but three values
    # are returned.
    # NOTE(review): several multi-line calls below are cut off in this listing
    # (arguments and closing parentheses are missing, as is the `else:` of the
    # `mutable` branch); restore them from the upstream source before use.

    # Flatten all leading axes so that σ is (n_local_samples, σ_shape[-1]).
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    # Local sample count times the number of MPI nodes.
    n_samples = σ.shape[0] * mpi.n_nodes

    # NOTE(review): truncated call — only the variables dict is visible;
    # σp and mels are presumably among the missing arguments.
    O_loc = local_cost_function(
        {"params": parameters, **model_state},

    # Statistics are computed on the local values reshaped back to the
    # leading axes of the original σ, transposed.
    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    # Centre the local values.
    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        # NOTE(review): truncated call — arguments after the lambda are missing.
        _, vjp_fun = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ),
        new_model_state = None
        # NOTE(review): the call below presumably sits in a missing `else:`
        # branch; here the VJP also yields the updated model state.
        _, vjp_fun, new_model_state = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
    # Pull back the conjugated, centred local values divided by n_samples.
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    # Per leaf: keep x as is when the matching `target` is complex, otherwise
    # take x.real; the result is then cast with `.astype(`.
    # NOTE(review): truncated call — the dtype and the remaining trees are missing.
    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(

    return Ō, tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
def _to_dense(self: QGTJacobianDenseT) -> jnp.ndarray:
    """
    Build the dense matrix `Oᴴ O` (summed through MPI) plus the diagonal shift.

    Without a scale the shift is `diag_shift * identity`. With a scale, the
    columns of `O` are multiplied by `scale` and the shift becomes
    `diag_shift * diag(scale**2)`.

    Returns:
        A square matrix of side `self.O.shape[1]`.
    """
    if self.scale is None:
        O = self.O
        diag = jnp.eye(self.O.shape[1])
    else:
        # Only reachable with a scale; running this unconditionally would
        # index `None` in the unscaled case.
        O = self.O * self.scale[jnp.newaxis, :]
        diag = jnp.diag(self.scale**2)

    return mpi.mpi_sum_jax(O.T.conj() @ O)[0] + self.diag_shift * diag
# Example #9
def _to_dense(self: QGTJacobianPyTreeT) -> jnp.ndarray:
    """
    Build the dense matrix `Oᴴ O` (summed through MPI) plus the diagonal shift.

    `self.O` is a pytree; each entry along its leading axis is flattened with
    `nkjax.tree_ravel`, giving a 2D array. Without a scale the shift is
    `diag_shift * identity`. With a scale, the columns are multiplied by the
    flattened scale and the shift becomes `diag_shift * diag(scale**2)`.

    Returns:
        A square matrix whose side is the flattened parameter count.
    """
    O = jax.vmap(lambda l: nkjax.tree_ravel(l)[0])(self.O)

    if self.scale is None:
        diag = jnp.eye(O.shape[1])
    else:
        # Only reachable with a scale; running this unconditionally would fail
        # (or silently misbehave) on `scale=None`.
        scale, _ = nkjax.tree_ravel(self.scale)
        O = O * scale[jnp.newaxis, :]
        diag = jnp.diag(scale**2)

    return mpi.mpi_sum_jax(O.T.conj() @ O)[0] + self.diag_shift * diag
def _rescale(centered_oks):
    """
    compute ΔOₖ/√Sₖₖ and √Sₖₖ
    to do scale-invariant regularization (Becca & Sorella 2017, pp. 143)
    Sₖₗ/(√Sₖₖ√Sₗₗ) = ΔOₖᴴΔOₗ/(√Sₖₖ√Sₗₗ) = (ΔOₖ/√Sₖₖ)ᴴ(ΔOₗ/√Sₗₗ)

    Args:
        centered_oks: Array reduced along axis 0 to obtain the scale.

    Returns:
        A pair `(centered_oks / scale, scale)`, where `scale` has had its
        leading length-1 axis squeezed away.
    """
    # √Sₖₖ: square root of the MPI-summed squared modulus along axis 0.
    # keepdims=True keeps a leading axis so the division below broadcasts.
    scale = (
        mpi.mpi_sum_jax(
            jnp.sum((centered_oks * centered_oks.conj()).real, axis=0, keepdims=True)
        )[0]
        ** 0.5
    )
    centered_oks = jnp.divide(centered_oks, scale)
    scale = jnp.squeeze(scale, axis=0)
    return centered_oks, scale
# Example #11
def _rescale(centered_oks):
    """
    compute ΔOₖ/√Sₖₖ and √Sₖₖ
    to do scale-invariant regularization (Becca & Sorella 2017, pp. 143)
    Sₖₗ/(√Sₖₖ√Sₗₗ) = ΔOₖᴴΔOₗ/(√Sₖₖ√Sₗₗ) = (ΔOₖ/√Sₖₖ)ᴴ(ΔOₗ/√Sₗₗ)

    Args:
        centered_oks: Pytree of arrays, each reduced along axis 0 to obtain
            its scale.

    Returns:
        A pair `(centered_oks / scale, scale)` of pytrees with the structure
        of the input; every `scale` leaf has had its leading length-1 axis
        squeezed away.
    """
    # √Sₖₖ per leaf: square root of the MPI-summed squared modulus along axis 0.
    # keepdims=True keeps a leading axis so the division below broadcasts.
    scale = jax.tree_util.tree_map(
        lambda x: mpi.mpi_sum_jax(
            jnp.sum((x * x.conj()).real, axis=0, keepdims=True))[0]**0.5,
        centered_oks,
    )
    centered_oks = jax.tree_util.tree_map(jnp.divide, centered_oks, scale)
    scale = jax.tree_util.tree_map(partial(jnp.squeeze, axis=0), scale)
    return centered_oks, scale
# Example #12
def mat_vec(jvp_fn, v, diag_shift):
    """
    Apply `jvp_fn`, centre the result, apply the transposed map and add a shift.

    Args:
        jvp_fn: Linear function of `v`; it is transposed with
            `jax.linear_transpose`, so it must be linear in its argument.
        v: The input the operator is applied to.
        diag_shift: Coefficient of the `diag_shift * v` term added at the end.

    Returns:
        `res + diag_shift * v`, where `res` is the MPI-summed result of the
        transposed map, with the same structure as `v`.
    """
    # Save linearisation work
    # TODO move to mat_vec_factory after jax v0.2.19
    vjp_fn = jax.linear_transpose(jvp_fn, v)

    w = jvp_fn(v)
    # divide by the element count of `w` times the number of MPI nodes
    w = w * (1.0 / (w.size * mpi.n_nodes))
    w = subtract_mean(w)  # w/ MPI
    # Oᴴw = (wᴴO)ᴴ = (w* O)* since 1D arrays are not transposed
    # vjp_fn packages output into a length-1 tuple
    (res, ) = tree_conj(vjp_fn(w.conjugate()))
    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is deprecated.
    res = jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], res)

    return tree_axpy(diag_shift, v, res)  # res + diag_shift * v
# Example #13
# NOTE(review): this definition is damaged in this listing: the end of the
# signature is missing (the body reads an `allgather` flag that is not among
# the visible parameters), the docstring has lost its triple quotes, the
# `.set(` call has lost its argument, and the `else:` of the `allgather`
# branch is missing. Restore from the upstream source before use.
def _to_array_rank(apply_fun, variables, σ_rank, n_states, normalize,
    Computes apply_fun(variables, σ_rank) and gathers all results across all ranks.
    The input σ_rank should be a slice of all states in the hilbert space of equal
    length across all ranks because mpi4jax does not support allgatherv (yet).

        n_states: total number of elements in the hilbert space.
    # number of 'fake' states, in the last rank.
    n_fake_states = σ_rank.shape[0] * mpi.n_nodes - n_states

    log_psi_local = apply_fun(variables, σ_rank)

    # last rank, get rid of fake elements
    # NOTE(review): the value written into the padding entries is not visible.
    if mpi.rank == mpi.n_nodes - 1 and n_fake_states > 0:
        log_psi_local = log_psi_local.at[jax.ops.index[-n_fake_states:]].set(

    if normalize:
        # subtract logmax for better numerical stability
        logmax, _ = mpi.mpi_max_jax(log_psi_local.real.max())
        log_psi_local -= logmax

    psi_local = jnp.exp(log_psi_local)

    if normalize:
        # compute normalization
        # (squared 2-norm of the local amplitudes, summed through MPI)
        norm2 = jnp.linalg.norm(psi_local)**2
        norm2, _ = mpi.mpi_sum_jax(norm2)
        psi_local /= jnp.sqrt(norm2)

    # NOTE(review): presumably `psi = psi_local` is the `else:` branch (keep
    # only the local slice when not gathering) — as shown it always overwrites
    # the gathered array.
    if allgather:
        psi, _ = mpi.mpi_allgather_jax(psi_local)
        psi = psi_local

    psi = psi.reshape(-1)

    # remove fake states
    psi = psi[0:n_states]
    return psi
# Example #14
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:
    # Computes statistics of |Lρ|² over the samples σ and a gradient
    # `LdagL_grad` built from the per-sample derivatives of the local values
    # and the averaged log-derivatives of the model.
    # NOTE(review): this definition is damaged in this listing: the `else:` of
    # the `has_aux` branch, the assignment targets and closing parentheses of
    # several calls, and the contents of the final `return (` are missing.
    # Restore from the upstream source before use.

    # Flatten all leading axes so that σ is (n_local_samples, σ_shape[-1]).
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    # Number of samples held by this process.
    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #    out_axes = (0, 0)
    # else:
    #    out_axes = (0, 0, 0)

    # NOTE(review): the second `logpsi` definition presumably belongs to a
    # missing `else:` branch; its call is also missing the closing parenthesis.
    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable

    # local_kernel_vmap = jax.vmap(
    #    partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    # NOTE(review): the tuple being assigned to is missing; `Lρ` and
    # `der_loc_vals`, used below, are presumably its elements — confirm.
    ) = _der_local_values_jax._local_values_and_grads_notcentered_kernel(
        logpsi, parameters, σp, mels, σ
    # _der_local_values_jax._local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    # Statistics of the squared modulus of the local values.
    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    # old implementation
    # this is faster, even though i think the one below should be faster
    # (this works, but... yeah. let's keep it here and delete in a while.)
    # Per-sample gradient of logpsi w.r.t. the parameters, then averaged along
    # the sample axis.
    grad_fun = jax.vmap(nkjax.grad(logpsi, argnums=0), in_axes=(None, 0), out_axes=0)
    der_logs = grad_fun(parameters, σ)
    der_logs_ave = tree_map(lambda x: mean(x, axis=0), der_logs)

    # TODO
    # This should be faster, but should benchmark as it seems slower
    # to compute der_logs_ave i can just do a jvp with a ones vector
    # _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type but we only need floatXX
    # and we cut in 1/2 the # of operations to do.
    # der_logs_ave = d_logpsi(
    #    jnp.ones_like(_logpsi_ave).real / (n_samples_node * utils.n_nodes)
    # )[0]
    # NOTE(review): if `mean` above already averages across MPI ranks, this
    # extra sum would scale the average by the number of nodes — confirm.
    der_logs_ave = tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        # Number of parameter axes following the leading sample axis.
        par_dims = der_loc_vals.ndim - 1

        # Append singleton axes so Lρ broadcasts against der_loc_vals.
        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        # NOTE(review): closing parenthesis of this expression is missing.
        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # NOTE(review): the returned tuple is cut off in this listing.
    return (
 def n_accepted(self) -> int:
     """Number of accepted moves since the last reset, summed over all processes."""
     # Reduce this process' counters first, then combine them through MPI.
     local_count = self.n_accepted_proc.sum()
     total, _ = mpi.mpi_sum_jax(local_count)
     return total
# Example #16
def _vjp(oks: PyTree, w: Array) -> PyTree:
    """
    Compute the vector-matrix product between the vector w and the pytree jacobian oks

    Args:
        oks: Pytree of arrays; `w` is contracted with the leading axis of
            every leaf.
        w: The vector multiplied from the left.

    Returns:
        A pytree structured like `oks`, with every leaf passed through
        `mpi.mpi_sum_jax`.
    """
    # `axes=1` contracts the last axis of `w` with the first axis of each leaf.
    res = jax.tree_util.tree_map(partial(jnp.tensordot, w, axes=1), oks)
    return jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], res)  # MPI
def O_vjp(forward_fn, params, samples, w):
    """
    Vector-Jacobian product of `forward_fn` w.r.t. `params`, summed over MPI.

    Args:
        forward_fn: Function called as `forward_fn(params, samples)`.
        params: Pytree of parameters; the output has the same structure.
        samples: Second argument of `forward_fn`; its cotangent is discarded.
        w: Cotangent vector for the output of `forward_fn`.

    Returns:
        A pytree structured like `params`, with every leaf passed through
        `mpi.mpi_sum_jax`.
    """
    _, vjp_fun = jax.vjp(forward_fn, params, samples)
    res, _ = vjp_fun(w)
    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is deprecated.
    return jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0],
                                  res)  # allreduce w/ MPI.SUM
# Example #18
def sum_inplace_jax(x):
    """Call `mpi_sum_jax(x)` and return only the first element of the resulting pair."""
    total, _unused = mpi_sum_jax(x)
    return total
# Example #19
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #    out_axes = (0, 0)
    # else:
    #    out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable

    # local_kernel_vmap = jax.vmap(
    #    partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    ) = _local_values_and_grads_notcentered_kernel(logpsi, parameters, σp, mels, σ)
    # _local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type but we only need floatXX
    # and we cut in 1/2 the # of operations to do.
    der_logs_ave = d_logpsi(
        jnp.ones_like(_logpsi_ave).real / (n_samples_node * mpi.n_nodes)
    der_logs_ave = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ R, so if the parameters are real we should cast away
    # the imaginary part of the gradient.
    # we do this also for standard gradient of energy.
    # this avoid errors in #867, #789, #850
    LdagL_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(

    return (