コード例 #1
0
def test_matvec_linear_transpose(e, jit, chunk_size):
    """Check jax.linear_transpose of the on-the-fly S matvec against conj(S conj(w)).

    S = Oᴴ O is hermitian, so Sᵀ = (Oᴴ O)ᵀ = Oᵀ O* = (Oᴴ O)* = S*, which gives
    Sᵀ w = S* w = (S w*)*. The transposed matvec must therefore agree with
    conjugating the matvec of the conjugated vector.
    """

    def f(params_model_state, x):
        return e.f(params_model_state["params"], x)

    if chunk_size is not None:
        factory = qgt_onthefly_logic.mat_vec_chunked_factory
        # split the leading sample axis into (n_chunks, chunk_size)
        chunked_shape = (-1, chunk_size) + e.samples.shape[1:]
        samples = e.samples.reshape(chunked_shape)
    else:
        factory = qgt_onthefly_logic.mat_vec_factory
        samples = e.samples

    mv = factory(f, e.params, {}, samples)

    def mvt(v, w):
        # `mv` is looked up at call time, so after jitting below this uses
        # the jitted matvec as well.
        transpose_fn = jax.linear_transpose(lambda v_: mv(v_, 0.0), v)
        (res,) = transpose_fn(w)
        return res

    if jit:
        mv = jax.jit(mv)
        mvt = jax.jit(mvt)

    w = e.v
    actual = mvt(e.v, w)

    # Sᵀ w = (S w*)*
    conj_w = nkjax.tree_conj(w)
    expected = nkjax.tree_conj(mv(conj_w, 0.0))

    assert tree_allclose(actual, expected)
コード例 #2
0
def OH_w(forward_fn, params, samples, w):
    r"""
    Compute the product Oᴴw, where ᴴ is the hermitian (conjugate) transpose.

    Uses the identity Oᴴw = (wᴴO)ᴴ. The 1D arrays are never explicitly
    transposed, so this is evaluated as (w* O)*: a vector-Jacobian product
    with the conjugated vector, followed by conjugating the resulting pytree.
    """
    w_conj = w.conjugate()
    wH_O = O_vjp(forward_fn, params, samples, w_conj)
    return tree_conj(wH_O)
コード例 #3
0
def mat_vec(jvp_fn, v, diag_shift):
    """
    Apply the centred Oᴴ O operator plus a diagonal shift to ``v``.

    The product is evaluated matrix-free:

    1. ``w = jvp_fn(v)`` (i.e. ``O v``).
    2. ``w`` is normalised and centred via ``subtract_mean``.
    3. ``Oᴴ w`` is obtained by transposing ``jvp_fn``.
    4. The result is summed over MPI ranks.

    Args:
        jvp_fn: Map ``v -> O v``. It must be linear, because it is
            transposed with ``jax.linear_transpose``.
        v: Pytree vector the operator is applied to.
        diag_shift: Scalar multiplying ``v`` in the final ``axpy``.

    Returns:
        Pytree ``res + diag_shift * v``, where ``res = Oᴴ w`` summed over
        MPI ranks.
    """
    # Save linearisation work
    # TODO move to mat_vec_factory after jax v0.2.19
    vjp_fn = jax.linear_transpose(jvp_fn, v)

    w = jvp_fn(v)
    # Normalise by the total number of entries over all ranks.
    # NOTE(review): assumes every rank holds the same w.size — confirm.
    w = w * (1.0 / (w.size * mpi.n_nodes))
    w = subtract_mean(w)  # w/ MPI
    # Oᴴw = (wᴴO)ᴴ = (w* O)* since 1D arrays are not transposed
    # vjp_fn packages output into a length-1 tuple
    (res,) = tree_conj(vjp_fn(w.conjugate()))
    # `jax.tree_map` is deprecated (JAX 0.4.26) and removed in newer JAX;
    # `jax.tree_util.tree_map` is the stable spelling.
    res = jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], res)

    return tree_axpy(diag_shift, v, res)  # res + diag_shift * v
コード例 #4
0
def OH_w(forward_fn, params, samples, w):
    r"""
    Compute O^H w, where ^H denotes the hermitian (conjugate) transpose.

    Evaluated through the identity O^H w = (w^H O)^H. The 1D arrays are not
    transposed explicitly, so this becomes (w* O)*. The conjugated result is
    then passed through ``tree_cast`` together with ``params``; presumably
    this matches the result to the dtypes/structure of ``params``. TODO confirm.
    """
    # TODO The allreduce in O_vjp could be deferred until after the tree_cast
    # where the amount of data to be transferred would potentially be smaller
    conj_w = w.conjugate()
    oh_w = tree_conj(O_vjp(forward_fn, params, samples, conj_w))
    return tree_cast(oh_w, params)
コード例 #5
0
def _tree_reassemble_complex(x, target, fun=_tree_to_reim):
    """
    Pull ``x`` back through the linear transpose of ``fun`` and conjugate it.

    ``jax.linear_transpose(fun, target)`` builds the transpose of the linear
    map ``fun`` at primals shaped like ``target``. It is applied to ``x``, and
    its single output is unpacked and conjugated.

    NOTE(review): presumably ``_tree_to_reim`` splits a pytree into real and
    imaginary parts. In that case the final conjugation undoes JAX's transpose
    convention for the imaginary part, and this rebuilds a complex pytree
    shaped like ``target``. Confirm against ``_tree_to_reim``.
    """
    fun_transpose = jax.linear_transpose(fun, target)
    (reassembled,) = fun_transpose(x)
    return nkjax.tree_conj(reassembled)
コード例 #6
0
def _mat_vec(v: PyTree, oks: PyTree) -> PyTree:
    """
    Compute ⟨Oᴴ O⟩ v, i.e. the vector with entries ∑ₗ ⟨Oₖᴴ Oₗ⟩ vₗ.

    Evaluated matrix-free: u = O v via ``_jvp``, then Oᴴ u = (u* O)* via
    ``_vjp``. The result is passed through ``tree_cast`` with ``v``,
    presumably to match the structure/dtypes of ``v``.
    """
    o_v = _jvp(oks, v)
    oh_o_v = tree_conj(_vjp(oks, o_v.conjugate()))
    return tree_cast(oh_o_v, v)
コード例 #7
0
def OH_w(forward_fn, params, samples, w):
    """Return Oᴴw, evaluated as (w* O)* through a vector-Jacobian product."""
    wc = w.conjugate()
    return tree_conj(O_vjp(forward_fn, params, samples, wc))
コード例 #8
0
 def _mv_trans(extra_args, y):
     """
     Apply the transpose of the chunked matvec to ``y``.

     Relies on the linear operator being hermitian, which gives
     Sᵀ y = S* y = (S y*)*. The input is therefore conjugated, passed through
     ``mat_vec_chunked``, and the result is conjugated back.

     NOTE(review): if ``mat_vec_chunked`` adds ``diag_shift * v``, the identity
     also needs a real ``diag_shift``. With a complex shift this would apply
     conj(diag_shift) instead. Confirm with callers.
     """
     params, samples, diag_shift = extra_args
     y_conj = tree_conj(y)
     s_y_conj = mat_vec_chunked(forward_fn, params, samples, y_conj, diag_shift)
     return tree_conj(s_y_conj)