def sum(a, axis=None, keepdims: bool = False): """ Compute the sum along the specified axis and over MPI processes. Args: a: The input array axis: Axis or axes along which the mean is computed. The default (None) is to compute the mean of the flattened array. out: An optional pre-allocated array to fill with the result. keepdims: If True the output array will have the same number of dimensions as the input, with the reduced axes having length 1. (default=False) Returns: The array with reduced dimensions defined by axis. If out is not none, returns out. """ # if it's a numpy-like array... if hasattr(a, "shape"): # use jax a_sum = a.sum(axis=axis, keepdims=keepdims) else: # assume it's a scalar a_sum = jnp.asarray(a) out, _ = mpi.mpi_sum_jax(a_sum) return out
def _to_array_rank(apply_fun, variables, σ_rank, n_states, normalize):
    """
    Computes apply_fun(variables, σ_rank) and gathers all results across all
    ranks.

    The input σ_rank should be a slice of all states in the hilbert space of
    equal length across all ranks because mpi4jax does not support allgatherv
    (yet).

    Args:
        n_states: total number of elements in the hilbert space.
    """
    # number of 'fake' padding states, present only on the last rank.
    n_fake_states = σ_rank.shape[0] * mpi.n_nodes - n_states

    log_psi_local = apply_fun(variables, σ_rank)

    # last rank: get rid of the fake elements by setting their log-amplitude
    # to -inf, so that exp() below maps them exactly to 0.
    # (Setting them to 0.0 in log-space — as the previous version did — makes
    # them exp(-logmax) != 0 after the shift, which corrupts both logmax and
    # the normalization norm2 below.)
    if mpi.rank == mpi.n_nodes - 1 and n_fake_states > 0:
        log_psi_local = jax.ops.index_update(
            log_psi_local, jax.ops.index[-n_fake_states:], -jnp.inf
        )

    # subtract the global log-max for numerical stability of exp
    logmax, _ = mpi.mpi_max_jax(log_psi_local.real.max())
    psi_local = jnp.exp(log_psi_local - logmax)

    # compute normalization over all ranks
    if normalize:
        norm2 = jnp.linalg.norm(psi_local) ** 2
        norm2, _ = mpi.mpi_sum_jax(norm2)
        psi_local /= jnp.sqrt(norm2)

    psi, _ = mpi.mpi_allgather_jax(psi_local)
    psi = psi.reshape(-1)

    # remove the fake padding states
    psi = psi[0:n_states]

    return psi
def Odagger_DeltaO_v(forward_fn, params, samples, v):
    """Compute Oᴴ ΔO v: the jvp of the forward function, centered, then
    pulled back through Oᴴ and allreduced over MPI ranks."""
    # O v, rescaled by the total number of samples across all ranks
    jvp_res = O_jvp(forward_fn, params, samples, v)
    n_samples_total = samples.shape[0] * samples.shape[1] * mpi.n_nodes
    jvp_res = jvp_res * (1.0 / n_samples_total)

    # center: subtract the (MPI-)mean, respecting the chunked layout
    flat, rechunk = unchunk(jvp_res)
    centered = rechunk(subtract_mean(flat))  # w/ MPI

    # apply Oᴴ and allreduce each leaf over all ranks
    pulled_back = OH_w(forward_fn, params, samples, centered)
    return jax.tree_map(lambda leaf: mpi.mpi_sum_jax(leaf)[0], pulled_back)  # MPI
def O_vjp_rc(forward_fn, params, samples, w):
    """Vector-Jacobian product with the real and imaginary parts of the
    cotangent pulled back separately, recombined into a complex pytree and
    allreduced with MPI.SUM."""
    _, pullback = jax.vjp(forward_fn, params, samples)
    grad_r, _ = pullback(w)
    grad_i, _ = pullback(-1.0j * w)
    combined = jax.tree_multimap(jax.lax.complex, grad_r, grad_i)
    # allreduce w/ MPI.SUM
    return jax.tree_map(lambda leaf: mpi.mpi_sum_jax(leaf)[0], combined)
def _matmul(
    self: QGTJacobianDenseT, vec: Union[PyTree, jnp.ndarray]
) -> Union[PyTree, jnp.ndarray]:
    """
    Compute (S + diag_shift) @ vec, where S is the quantum geometric tensor
    OᴴO represented by the dense Jacobian self.O, allreduced over MPI ranks.
    Accepts either a flat vector or a pytree; the output has the same form
    as the input.
    """
    # If the input is a pytree (no .ndim attribute), flatten it to a 1D
    # vector and remember how to restore its structure afterwards.
    unravel = None
    if not hasattr(vec, "ndim"):
        vec, unravel = nkjax.tree_ravel(vec)

    # Real-imaginary split RHS in R→R and R→C modes
    reassemble = None
    if self.mode != "holomorphic" and not self._in_solve:
        vec, reassemble = vec_to_real(vec)

    # scale-invariant regularization: pre-scale the vector...
    if self.scale is not None:
        vec = vec * self.scale

    # S vec = Oᴴ(O vec) summed over ranks, plus the diagonal shift
    result = (
        mpi.mpi_sum_jax(((self.O @ vec).T.conj() @ self.O).T.conj())[0]
        + self.diag_shift * vec
    )

    # ...and post-scale the result
    if self.scale is not None:
        result = result * self.scale

    # undo the real/imag split and the flattening, if they were applied
    if reassemble is not None:
        result = reassemble(result)

    if unravel is not None:
        result = unravel(result)

    return result
def grad_expect_hermitian_chunked(
    chunk_size: int,
    local_value_kernel_chunked: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree]:
    """
    Compute the statistics of the local estimator and its gradient
    (via a chunked vjp), allreduced over all MPI ranks.

    NOTE(review): the annotation says Tuple[PyTree, PyTree] but three values
    are returned (stats, gradient, new_model_state) — confirm the intended
    annotation.
    """
    # flatten extra batch dimensions so σ is (n_samples, hilbert_size)
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    # total sample count across all MPI ranks
    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel_chunked(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
        chunk_size=chunk_size,
    )

    # statistics over the original (chain, sample) layout
    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    # center the local values around their mean
    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        vjp_fun_chunked = nkjax.vjp_chunked(
            lambda w, σ: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            σ,
            conjugate=True,
            chunk_size=chunk_size,
            chunk_argnums=1,
            nondiff_argnums=1,
        )
        new_model_state = None
    else:
        # mutable model state is not supported by the chunked code path
        raise NotImplementedError

    Ō_grad = vjp_fun_chunked(
        (jnp.conjugate(O_loc) / n_samples),
    )[0]

    # cast each gradient leaf to the dtype of the matching parameter;
    # for real parameters keep twice the real part
    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    # allreduce the gradient over all ranks
    return Ō, tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
def grad_expect_hermitian(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree]:
    """
    Compute the statistics of the local estimator O_loc and its gradient
    via a vjp, allreduced over all MPI ranks.

    NOTE(review): the annotation says Tuple[PyTree, PyTree] but three values
    are returned (stats, gradient, new_model_state) — confirm the intended
    annotation.
    """
    # flatten extra batch dimensions so σ is (n_samples, hilbert_size)
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    # total sample count across all MPI ranks
    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_cost_function(
        local_value_cost,
        model_apply_fun,
        {"params": parameters, **model_state},
        σp,
        mels,
        σ,
    )

    # statistics over the original (chain, sample) layout
    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    # center the local values around their mean
    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        _, vjp_fun = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            conjugate=True,
        )
        new_model_state = None
    else:
        # mutable state: the apply function also returns the updated state
        _, vjp_fun, new_model_state = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
            parameters,
            conjugate=True,
            has_aux=True,
        )

    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    # cast each gradient leaf to the dtype of the matching parameter.
    # NOTE(review): the chunked variant uses `2 * x.real` for real-valued
    # parameters, while this one uses `x.real` — confirm which factor is
    # intended here.
    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    # allreduce the gradient over all ranks
    return Ō, tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
def _to_dense(self: QGTJacobianDenseT) -> jnp.ndarray:
    """Return the dense matrix OᴴO (allreduced over MPI ranks) plus the
    diagonal shift, undoing the scale-invariant rescaling if present."""
    if self.scale is None:
        jac = self.O
        shift_diag = jnp.eye(self.O.shape[1])
    else:
        # undo the per-column rescaling and shift by diag(scale²)
        jac = self.O * self.scale[jnp.newaxis, :]
        shift_diag = jnp.diag(self.scale**2)
    gram = mpi.mpi_sum_jax(jac.T.conj() @ jac)[0]
    return gram + self.diag_shift * shift_diag
def _to_dense(self: QGTJacobianPyTreeT) -> jnp.ndarray:
    """Densify the pytree Jacobian and return OᴴO (allreduced over MPI
    ranks) plus the diagonal shift, undoing the rescaling if present."""
    # ravel each sample's pytree of derivatives into one row
    jac = jax.vmap(lambda leaf: nkjax.tree_ravel(leaf)[0])(self.O)
    if self.scale is None:
        shift_diag = jnp.eye(jac.shape[1])
    else:
        flat_scale, _ = nkjax.tree_ravel(self.scale)
        jac = jac * flat_scale[jnp.newaxis, :]
        shift_diag = jnp.diag(flat_scale**2)
    gram = mpi.mpi_sum_jax(jac.T.conj() @ jac)[0]
    return gram + self.diag_shift * shift_diag
def _rescale(centered_oks):
    """
    compute ΔOₖ/√Sₖₖ and √Sₖₖ
    to do scale-invariant regularization (Becca & Sorella 2017, pp. 143)
    Sₖₗ/(√Sₖₖ√Sₗₗ) = ΔOₖᴴΔOₗ/(√Sₖₖ√Sₗₗ) = (ΔOₖ/√Sₖₖ)ᴴ(ΔOₗ/√Sₗₗ)
    """
    # √Sₖₖ: per-column norms, summed over samples and over MPI ranks
    local_sq = jnp.sum(
        (centered_oks * centered_oks.conj()).real, axis=0, keepdims=True
    )
    scale = mpi.mpi_sum_jax(local_sq)[0] ** 0.5
    rescaled = jnp.divide(centered_oks, scale)
    return rescaled, jnp.squeeze(scale, axis=0)
def _rescale(centered_oks):
    """
    compute ΔOₖ/√Sₖₖ and √Sₖₖ
    to do scale-invariant regularization (Becca & Sorella 2017, pp. 143)
    Sₖₗ/(√Sₖₖ√Sₗₗ) = ΔOₖᴴΔOₗ/(√Sₖₖ√Sₗₗ) = (ΔOₖ/√Sₖₖ)ᴴ(ΔOₗ/√Sₗₗ)
    """

    def _col_norms(leaf):
        # √Sₖₖ for one leaf: column norms, allreduced over MPI ranks
        sq = jnp.sum((leaf * leaf.conj()).real, axis=0, keepdims=True)
        return mpi.mpi_sum_jax(sq)[0] ** 0.5

    scale = jax.tree_map(_col_norms, centered_oks)
    rescaled = jax.tree_map(jnp.divide, centered_oks, scale)
    scale = jax.tree_map(partial(jnp.squeeze, axis=0), scale)
    return rescaled, scale
def mat_vec(jvp_fn, v, diag_shift):
    """Apply (S + diag_shift·I) to v using the linearised forward map jvp_fn
    and its transpose, with MPI-aware centering and reduction."""
    # Save linearisation work
    # TODO move to mat_vec_factory after jax v0.2.19
    transpose_fn = jax.linear_transpose(jvp_fn, v)

    ov = jvp_fn(v)
    ov = ov * (1.0 / (ov.size * mpi.n_nodes))
    ov = subtract_mean(ov)  # w/ MPI

    # Oᴴw = (wᴴO)ᴴ = (w* O)* since 1D arrays are not transposed
    # transpose_fn packages its output into a length-1 tuple
    (ohw,) = tree_conj(transpose_fn(ov.conjugate()))
    ohw = jax.tree_map(lambda leaf: mpi.mpi_sum_jax(leaf)[0], ohw)

    # ohw + diag_shift * v
    return tree_axpy(diag_shift, v, ohw)
def _to_array_rank(apply_fun, variables, σ_rank, n_states, normalize, allgather): """ Computes apply_fun(variables, σ_rank) and gathers all results across all ranks. The input σ_rank should be a slice of all states in the hilbert space of equal length across all ranks because mpi4jax does not support allgatherv (yet). Args: n_states: total number of elements in the hilbert space. """ # number of 'fake' states, in the last rank. n_fake_states = σ_rank.shape[0] * mpi.n_nodes - n_states log_psi_local = apply_fun(variables, σ_rank) # last rank, get rid of fake elements if mpi.rank == mpi.n_nodes - 1 and n_fake_states > 0: log_psi_local = log_psi_local.at[jax.ops.index[-n_fake_states:]].set( -jnp.inf) if normalize: # subtract logmax for better numerical stability logmax, _ = mpi.mpi_max_jax(log_psi_local.real.max()) log_psi_local -= logmax psi_local = jnp.exp(log_psi_local) if normalize: # compute normalization norm2 = jnp.linalg.norm(psi_local)**2 norm2, _ = mpi.mpi_sum_jax(norm2) psi_local /= jnp.sqrt(norm2) if allgather: psi, _ = mpi.mpi_allgather_jax(psi_local) else: psi = psi_local psi = psi.reshape(-1) # remove fake states psi = psi[0:n_states] return psi
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:
    """
    Compute ⟨L†L⟩ statistics and its gradient from local values of L and their
    (non-centered) gradients, allreduced over all MPI ranks.
    """
    # flatten extra batch dimensions so σ is (n_samples, hilbert_size)
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    # samples held by this rank only
    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #    out_axes = (0, 0)
    # else:
    #    out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # local_kernel_vmap = jax.vmap(
    #    partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    (
        Lρ,
        der_loc_vals,
    ) = _der_local_values_jax._local_values_and_grads_notcentered_kernel(
        logpsi, parameters, σp, mels, σ
    )
    # _der_local_values_jax._local_values_and_grads_notcentered_kernel returns
    # a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    # old implementation
    # this is faster, even though i think the one below should be faster
    # (this works, but... yeah. let's keep it here and delete in a while.)
    grad_fun = jax.vmap(nkjax.grad(logpsi, argnums=0), in_axes=(None, 0), out_axes=0)
    der_logs = grad_fun(parameters, σ)
    der_logs_ave = tree_map(lambda x: mean(x, axis=0), der_logs)

    # TODO
    # NEW IMPLEMENTATION
    # This should be faster, but should benchmark as it seems slower
    # to compute der_logs_ave i can just do a jvp with a ones vector
    # _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type but we only need floatXX
    # and we cut in 1/2 the # of operations to do.
    # der_logs_ave = d_logpsi(
    #    jnp.ones_like(_logpsi_ave).real / (n_samples_node * utils.n_nodes)
    # )[0]

    # allreduce the average derivatives over all ranks
    der_logs_ave = tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        # gradient of ⟨L†L⟩ for one pytree leaf:
        # E[(∂logψ L)* L] − (∂logψ)̄* ⟨L†L⟩
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
def n_accepted(self) -> int:
    """Total number of moves accepted across all processes since the last reset."""
    local_total = self.n_accepted_proc.sum()
    global_total, _ = mpi.mpi_sum_jax(local_total)
    return global_total
def _vjp(oks: PyTree, w: Array) -> PyTree:
    """
    Compute the vector-matrix product between the vector w and the pytree
    jacobian oks
    """
    contract = partial(jnp.tensordot, w, axes=1)
    local = jax.tree_map(contract, oks)
    # MPI allreduce of each leaf
    return jax.tree_map(lambda leaf: mpi.mpi_sum_jax(leaf)[0], local)
def O_vjp(forward_fn, params, samples, w):
    """Vector-Jacobian product wᴴO w.r.t. the parameters, allreduced with
    MPI.SUM."""
    _, pullback = jax.vjp(forward_fn, params, samples)
    grad, _ = pullback(w)
    # allreduce w/ MPI.SUM
    return jax.tree_map(lambda leaf: mpi.mpi_sum_jax(leaf)[0], grad)
def sum_inplace_jax(x):
    """Sum x over all MPI processes and return the reduced value."""
    return mpi_sum_jax(x)[0]
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:
    """
    Compute ⟨L†L⟩ statistics and its gradient from local values of L and their
    (non-centered) gradients, using a single vjp for the average derivatives,
    allreduced over all MPI ranks.
    """
    # flatten extra batch dimensions so σ is (n_samples, hilbert_size)
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    # samples held by this rank only
    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #    out_axes = (0, 0)
    # else:
    #    out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # local_kernel_vmap = jax.vmap(
    #    partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    (
        Lρ,
        der_loc_vals,
    ) = _local_values_and_grads_notcentered_kernel(logpsi, parameters, σp, mels, σ)
    # _local_values_and_grads_notcentered_kernel returns a loc_val that is
    # conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    # average derivatives via a single vjp against a ones vector
    _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type but we only need floatXX
    # and we cut in 1/2 the # of operations to do.
    der_logs_ave = d_logpsi(
        jnp.ones_like(_logpsi_ave).real / (n_samples_node * mpi.n_nodes)
    )[0]
    # allreduce the average derivatives over all ranks
    der_logs_ave = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        # gradient of ⟨L†L⟩ for one pytree leaf:
        # E[(∂logψ L)* L] − (∂logψ)̄* ⟨L†L⟩
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ R, so if the parameters are real we should cast away
    # the imaginary part of the gradient.
    # we do this also for standard gradient of energy.
    # this avoid errors in #867, #789, #850
    LdagL_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        LdagL_grad,
        parameters,
    )

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )