예제 #1
0
def onthefly_mat_treevec(
        S: QGTOnTheFly, vec: Union[PyTree,
                                   jnp.ndarray]) -> Union[PyTree, jnp.ndarray]:
    """
    Perform the lazy mat-vec product of ``S`` with ``vec``.

    ``vec`` is either a tree with the same structure as the parameters of
    ``S`` or a ravelled (1D) vector. A ravelled input is unravelled into the
    parameter structure before the product, and the result is ravelled again
    before being returned, so the output has the same form as the input.

    Raises:
        ValueError: if ``vec`` is an array that is not 1D, or if its size
            differs from the number of parameters.
    """
    # Anything exposing `ndim` is treated as an array and not as a pytree.
    ravel_result = hasattr(vec, "ndim")

    if ravel_result:
        if vec.ndim != 1:
            raise ValueError("Unsupported mat-vec for chunks of vectors")

        # The flat vector must hold exactly one entry per parameter.
        n_params = nkjax.tree_size(S._params)
        if n_params != vec.size:
            # f-string so that the sizes are actually interpolated in the
            # message instead of being printed as literal placeholders.
            raise ValueError(
                f"Size mismatch between number of parameters ({n_params}) "
                f"and vector size ({vec.size}).")

        _, unravel = nkjax.tree_ravel(S._params)
        vec = unravel(vec)

    check_valid_vector_type(S._params, vec)

    # Align the vector with the parameters (see nkjax.tree_cast).
    vec = nkjax.tree_cast(vec, S._params)

    res = S._mat_vec(vec, S.diag_shift)

    # Hand back a flat vector if that is what the caller passed in.
    if ravel_result:
        res, _ = nkjax.tree_ravel(res)

    return res
예제 #2
0
def _solve(self: QGTOnTheFlyT, solve_fun, y: PyTree, *, x0: Optional[PyTree],
           **kwargs) -> PyTree:
    """
    Solve the linear system associated with this object using ``solve_fun``.

    Args:
        solve_fun: callable invoked as ``solve_fun(self, y, x0=x0)`` and
            expected to return a ``(solution, info)`` pair.
        y: the right-hand side; it is passed through ``nkjax.tree_cast``
            against ``self.params`` before solving.
        x0: optional initial guess. When ``None``, a tree of zeros with the
            same structure as ``y`` is used.
        **kwargs: accepted for interface compatibility; they are not
            forwarded to ``solve_fun``.

    Returns:
        The ``(out, info)`` pair produced by ``solve_fun``.
        NOTE(review): the return annotation says ``PyTree`` but a 2-tuple is
        returned; the annotation is kept unchanged for compatibility.
    """
    y = nkjax.tree_cast(y, self.params)

    # we could cache this...
    if x0 is None:
        # Use the stable `jax.tree_util.tree_map` spelling: the top-level
        # `jax.tree_map` alias is deprecated and missing from recent JAX.
        x0 = jax.tree_util.tree_map(jnp.zeros_like, y)

    out, info = solve_fun(self, y, x0=x0)
    return out, info
예제 #3
0
def OH_w(forward_fn, params, samples, w):
    r"""
    Compute O^H w, where ^H denotes the hermitian transpose.

    The identity O^H w = (w^H O)^H is used. Since the operands are 1D arrays
    the transposition is dropped and only the conjugations remain:
    (w^H O)^H -> (w* O)*

    The result is cast against ``params`` with ``tree_cast`` before being
    returned.
    """
    # Inner conjugation: w -> w*
    w_conj = w.conjugate()

    # TODO The allreduce in O_vjp could be deferred until after the tree_cast
    # where the amount of data to be transferred would potentially be smaller
    w_conj_O = O_vjp(forward_fn, params, samples, w_conj)

    # Outer conjugation: (w* O) -> (w* O)*
    conjugated = tree_conj(w_conj_O)

    return tree_cast(conjugated, params)
예제 #4
0
def onthefly_mat_treevec(
        S: QGTOnTheFly, vec: Union[PyTree,
                                   jnp.ndarray]) -> Union[PyTree, jnp.ndarray]:
    """
    Perform the lazy mat-vec product of ``S`` with ``vec``.

    ``vec`` is either a tree with the same structure as ``S.params`` or a
    ravelled (1D) vector. A ravelled input is unravelled into the parameter
    structure before the product, and the result is ravelled again before
    being returned, so the output has the same form as the input.

    Raises:
        ValueError: if ``vec`` is an array that is not 1D, or if its size
            differs from the number of parameters.
    """
    # Anything exposing `ndim` is treated as an array and not as a pytree.
    ravel_result = hasattr(vec, "ndim")

    if ravel_result:
        if vec.ndim != 1:
            raise ValueError("Unsupported mat-vec for batches of vectors")

        # The flat vector must hold exactly one entry per parameter.
        n_params = nkjax.tree_size(S.params)
        if n_params != vec.size:
            # f-string so that the sizes are actually interpolated in the
            # message instead of being printed as literal placeholders.
            raise ValueError(
                f"Size mismatch between number of parameters ({n_params}) "
                f"and vector size ({vec.size}).")

        _, unravel = nkjax.tree_ravel(S.params)
        vec = unravel(vec)

    # Align the vector with the parameters (see nkjax.tree_cast).
    vec = nkjax.tree_cast(vec, S.params)

    def fun(W, σ):
        # Evaluate the model with parameters W, keeping the stored model state.
        return S.apply_fun({"params": W, **S.model_state}, σ)

    # Bind everything except the vector, leaving a function of `vec` only.
    mat_vec = partial(
        mat_vec_onthefly,
        forward_fn=fun,
        params=S.params,
        samples=S.samples,
        diag_shift=S.diag_shift,
    )

    res = mat_vec(vec)

    # Hand back a flat vector if that is what the caller passed in.
    if ravel_result:
        res, _ = nkjax.tree_ravel(res)

    return res
예제 #5
0
def reassemble_complex(x, target, fun=tree_toreal_flat):
    """
    Apply the linear transpose of ``fun`` to ``x``, conjugate the result and
    cast it against ``target``.

    ``target`` is a tree with the expected shape and types of the result; it
    is used both as the example input for ``jax.linear_transpose`` and as the
    reference for the final ``nkjax.tree_cast``.
    """
    transposed_fun = jax.linear_transpose(fun, target)

    # The transpose returns a 1-tuple (one entry per argument of `fun`).
    (transposed,) = transposed_fun(x)
    conjugated = qgt_onthefly_logic.tree_conj(transposed)

    # fix the dtypes:
    return nkjax.tree_cast(conjugated, target)
예제 #6
0
def _mat_vec(v: PyTree, oks: PyTree) -> PyTree:
    """
    Apply ⟨O† O⟩ to v, i.e. compute ∑ₗ ⟨Oₖᴴ Oₗ⟩ vₗ

    The result is cast against ``v`` with ``tree_cast`` before being returned.
    """
    # Forward product, conjugated before being fed to the vjp.
    jvp_conj = _jvp(oks, v).conjugate()

    # Backward product, then conjugate the result.
    pulled_back = _vjp(oks, jvp_conj)
    conjugated = tree_conj(pulled_back)

    return tree_cast(conjugated, v)