Example #1
0
def mat_vec(v, forward_fn, params, samples, diag_shift,
            centered=True, holomorphic=True):
    r"""
    Return (S + diag_shift) v.

    S admits two equivalent formulations; ``centered`` selects which one is
    actually evaluated:

    - centered=True (default): S_kl = \langle \Delta O_k^\dagger \Delta O_l \rangle
    - centered=False:          S_kl = \langle O_k^\dagger \Delta O_l \rangle

    Here \Delta O_k = O_k - \langle O_k \rangle, O_k (operator) is the derivative
    of the log wavefunction w.r.t. parameter k, and expectation values are
    means over the samples.

    Args:
        v: a pytree with the same structure as params.
        forward_fn: forward_fn(params, x) is a vectorised function returning the
            logarithm of the wavefunction for each configuration in x.
        params: pytree of parameters with arrays as leaves.
        samples: an array of samples (when using MPI each rank has its own
            slice of samples).
        diag_shift: a scalar diagonal shift.
        centered: if True, S v is computed with DeltaOdagger_DeltaO_v,
            otherwise with Odagger_DeltaO_v.
        holomorphic: whether forward_fn is holomorphic. It is forwarded to
            DeltaOdagger_DeltaO_v only, so it has an effect only when
            centered=True (relevant when forward_fn has complex params and
            output).

    Returns:
        S v + diag_shift * v, combined via tree_axpy.
    """
    if not centered:
        s_times_v = Odagger_DeltaO_v(forward_fn, params, samples, v)
    else:
        s_times_v = DeltaOdagger_DeltaO_v(
            forward_fn, params, samples, v, holomorphic=holomorphic
        )
    # (S + diag_shift) v == S v + diag_shift * v
    return tree_axpy(diag_shift, v, s_times_v)
Example #2
0
def mat_vec(jvp_fn, v, diag_shift):
    """
    Compute (S + diag_shift) v from a linear Jacobian-vector-product function.

    The product is formed as Oᴴ w, where w = O v is first normalised and then
    centered with subtract_mean. The vector-Jacobian product Oᴴ(·) is obtained
    by transposing jvp_fn with jax.linear_transpose.

    Args:
        jvp_fn: linear function v ↦ O v returning a 1D array (presumably one
            entry per local sample — confirm at the call site).
        v: pytree in the input space of jvp_fn.
        diag_shift: scalar diagonal shift δ.

    Returns:
        pytree with the structure of the transposed jvp output, holding the
        MPI-summed Oᴴw plus diag_shift * v.
    """
    # Save linearisation work
    # TODO move to mat_vec_factory after jax v0.2.19
    vjp_fn = jax.linear_transpose(jvp_fn, v)

    w = jvp_fn(v)
    # Normalise by the global sample count; assumes every MPI rank holds the
    # same number of samples (w.size locally) — TODO confirm.
    w = w * (1.0 / (w.size * mpi.n_nodes))
    w = subtract_mean(w)  # w/ MPI
    # Oᴴw = (wᴴO)ᴴ = (w* O)* since 1D arrays are not transposed
    # vjp_fn packages output into a length-1 tuple
    (res,) = tree_conj(vjp_fn(w.conjugate()))
    # Each rank only contributed its local samples: sum the partial products.
    # jax.tree_util.tree_map replaces the deprecated (and since removed)
    # jax.tree_map alias; semantics are identical.
    res = jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], res)

    return tree_axpy(diag_shift, v, res)  # res + diag_shift * v
Example #3
0
def mat_vec(v: PyTree, centered_oks: PyTree, diag_shift: Scalar) -> PyTree:
    """
    Return the SR matrix-vector product (S + δ) v.

    With centered_oks holding the pre-scaled gradients 1/√n ΔOⱼₖ, the matrix is
    S = ΔOᴴ ΔO, i.e. component-wise the result is ∑ₗ Sₖₗ vₗ + δ vₖ with
    Sₖₗ = 1/n ∑ⱼ (ΔOⱼₖ)* ΔOⱼₗ (the S v part is delegated to _mat_vec).

    Only compatible with R→R, R→C, and holomorphic C→C.
    For C→R, R&C→R, R&C→C and general C→C the parameters used to generate ΔOⱼₖ
    should be converted to R; v and the returned pytree are then expected to
    be in that real representation as well.

    Args:
        v: pytree representing the vector v, compatible with centered_oks.
        centered_oks: pytree of gradients 1/√n ΔOⱼₖ.
        diag_shift: a scalar diagonal shift δ.

    Returns:
        a pytree holding (S + δ) v.
    """
    s_v = _mat_vec(v, centered_oks)
    # add the diagonal shift: S v + δ v
    return tree_axpy(diag_shift, v, s_v)
Example #4
0
def mat_vec(v, forward_fn, params, samples, diag_shift):
    r"""
    Return (S + diag_shift) v for the uncentered form of S.

    Sₖₗ = ⟨Oₖ† ΔOₗ⟩ with ΔOₖ = Oₖ - ⟨Oₖ⟩, where Oₖ (operator) is the
    derivative of the log wavefunction w.r.t. parameter k and ⟨·⟩ is the mean
    over the samples.

    Args:
        v: a pytree with the same structure as params.
        forward_fn: forward_fn(params, x) is a vectorised function returning the
            logarithm of the wavefunction for each configuration in x.
        params: pytree of parameters with arrays as leaves.
        samples: an array of samples (when using MPI each rank has its own
            slice of samples).
        diag_shift: a scalar diagonal shift.

    Returns:
        S v + diag_shift * v, combined via tree_axpy.
    """
    s_times_v = Odagger_DeltaO_v(forward_fn, params, samples, v)
    return tree_axpy(diag_shift, v, s_times_v)  # S v + diag_shift * v
Example #5
0
def mat_vec_chunked(forward_fn, params, samples, v, diag_shift):
    """
    Return S v + diag_shift * v, with S v computed by Odagger_DeltaO_v.

    Note the argument order: the vector v comes after samples, not first.

    NOTE(review): despite the name, this body performs no chunking itself;
    presumably chunking is handled by forward_fn or the caller — confirm.

    Args:
        forward_fn: function passed through to Odagger_DeltaO_v.
        params: parameter pytree passed through to Odagger_DeltaO_v.
        samples: samples passed through to Odagger_DeltaO_v.
        v: pytree the matrix is applied to.
        diag_shift: scalar diagonal shift.

    Returns:
        the result of tree_axpy(diag_shift, v, S v).
    """
    product = Odagger_DeltaO_v(forward_fn, params, samples, v)
    shifted = tree_axpy(diag_shift, v, product)
    return shifted