def mat_vec(v, forward_fn, params, samples, diag_shift, centered=True, holomorphic=True):
    r"""
    Return (S + diag_shift) v, with S built from log-wavefunction derivatives.

    Two equivalent formulations of the matrix elements are supported:

        centered=True (default):  S_kl = \langle \Delta O_k^\dagger \Delta O_l \rangle
        centered=False:           S_kl = \langle O_k^\dagger \Delta O_l \rangle

    Here O_k (an operator) is the derivative of the log wavefunction with respect
    to parameter k, and \Delta O_k = O_k - \langle O_k \rangle. Every expectation
    value is estimated as a mean over `samples`.

    Args:
        v: pytree with the same structure as `params`.
        forward_fn: vectorised function `forward_fn(params, x)` returning the
            logarithm of the wavefunction for each configuration in x.
        params: pytree of parameters whose leaves are arrays.
        samples: array of samples (under MPI, each rank holds its own slice).
        diag_shift: scalar added to the diagonal of S.
        centered: selects which of the two formulations above is used.
        holomorphic: whether forward_fn is holomorphic. It is only forwarded
            when centered=True and matters for complex params and output.

    Returns:
        A pytree equal to S v + diag_shift * v.
    """
    # The holomorphic flag is only consumed by the centered product.
    product_fn = (
        partial(DeltaOdagger_DeltaO_v, holomorphic=holomorphic)
        if centered
        else Odagger_DeltaO_v
    )
    s_v = product_fn(forward_fn, params, samples, v)
    # Diagonal shift: S v + diag_shift * v
    return tree_axpy(diag_shift, v, s_v)
def mat_vec(jvp_fn, v, diag_shift):
    """
    Compute (S + diag_shift) v from a linear map v -> O v.

    The result is Oᴴ ΔO v / N + diag_shift * v. Here ΔO v is `jvp_fn(v)` with
    its (MPI-wide) mean subtracted, and N is the total sample count across all
    ranks. Oᴴ is applied through the linear transpose of `jvp_fn`.

    Args:
        jvp_fn: linear function mapping a parameter pytree to a 1D array.
            It is presumably the Jacobian-vector product of the log
            wavefunction over the local samples, giving one entry per sample.
            Confirm against the caller.
        v: pytree the matrix is applied to. It is also used as the example
            primal for `jax.linear_transpose`.
        diag_shift: scalar diagonal shift.

    Returns:
        A pytree with the structure of `v`, equal to S v + diag_shift * v.
    """
    # Save linearisation work: transposing jvp_fn reuses it for the Oᴴ product.
    # TODO move to mat_vec_factory after jax v0.2.19
    vjp_fn = jax.linear_transpose(jvp_fn, v)

    w = jvp_fn(v)
    # Normalise by the global sample count. This assumes every MPI rank holds
    # w.size samples.
    w = w * (1.0 / (w.size * mpi.n_nodes))
    w = subtract_mean(w)  # w/ MPI

    # Oᴴw = (wᴴO)ᴴ = (w* O)* since 1D arrays are not transposed.
    # vjp_fn packages its output into a length-1 tuple.
    (res,) = tree_conj(vjp_fn(w.conjugate()))

    # Sum the per-rank contributions. mpi_sum_jax returns a tuple whose first
    # element is the summed value. jax.tree_util.tree_map is used instead of
    # the deprecated (and later removed) jax.tree_map alias.
    res = jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], res)
    return tree_axpy(diag_shift, v, res)  # res + diag_shift * v
def mat_vec(v: PyTree, centered_oks: PyTree, diag_shift: Scalar) -> PyTree:
    """
    Apply the shifted SR matrix to a vector: (S + δ) v.

    S is the (1/n-normalised) covariance ⟨ΔOᴴ ΔO⟩ of the centered log-derivatives,
    so componentwise (S v)ₖ = ∑ₗ 1/n ⟨ΔOₖᴴ ΔOₗ⟩ vₗ, and the result is S v + δ v.

    Only R→R, R→C, and holomorphic C→C are supported directly. For C→R, R&C→R,
    R&C→C and general C→C, the parameters used to build ΔOⱼₖ must first be
    converted to real form. In that case both `v` and the returned pytree are
    expected in that real form as well.

    Args:
        v: pytree for the vector v, compatible with `centered_oks`.
        centered_oks: pytree of gradients 1/√n ΔOⱼₖ.
        diag_shift: scalar diagonal shift δ.

    Returns:
        A pytree holding the SR matrix-vector product (S + δ) v.
    """
    s_times_v = _mat_vec(v, centered_oks)
    # Add the diagonal shift: S v + δ v
    return tree_axpy(diag_shift, v, s_times_v)
def mat_vec(v, forward_fn, params, samples, diag_shift):
    r"""
    Return (S + diag_shift) v, with Sₖₗ = ⟨Oₖ†ΔOₗ⟩.

    Oₖ (an operator) is the derivative of the log wavefunction with respect to
    parameter k, and ΔOₖ = Oₖ - ⟨Oₖ⟩. Expectation values are estimated as a
    mean over `samples`.

    Args:
        v: pytree with the same structure as `params`.
        forward_fn: vectorised function `forward_fn(params, x)` returning the
            logarithm of the wavefunction for each configuration in x.
        params: pytree of parameters whose leaves are arrays.
        samples: array of samples (under MPI, each rank holds its own slice).
        diag_shift: scalar diagonal shift.

    Returns:
        A pytree equal to S v + diag_shift * v.
    """
    s_v = Odagger_DeltaO_v(forward_fn, params, samples, v)
    # Diagonal shift: S v + diag_shift * v
    return tree_axpy(diag_shift, v, s_v)
def mat_vec_chunked(forward_fn, params, samples, v, diag_shift):
    """
    Return Odagger_DeltaO_v(forward_fn, params, samples, v) + diag_shift * v.

    Note the argument order differs from the `mat_vec` variants: `v` comes after
    `samples`. This body performs no chunking itself. Presumably the
    `Odagger_DeltaO_v` in scope for this function is a chunked implementation;
    confirm in its module.

    Args:
        forward_fn: function returning the log wavefunction for a batch of
            configurations.
        params: pytree of parameters.
        samples: array of samples.
        v: pytree the matrix is applied to.
        diag_shift: scalar diagonal shift.

    Returns:
        A pytree equal to (S + diag_shift) v.
    """
    return tree_axpy(diag_shift, v, Odagger_DeltaO_v(forward_fn, params, samples, v))