示例#1
0
def _solve(self: QGTJacobianPyTreeT,
           solve_fun,
           y: PyTree,
           *,
           x0: Optional[PyTree] = None) -> tuple:
    """
    Solve the linear system defined by this QGT for the right-hand side ``y``.

    The right-hand side (and the initial guess, if any) are brought into the
    representation the underlying jacobian works with:
    - split into real and imaginary parts unless ``self.mode == "holomorphic"``;
    - rescaled by ``self.scale`` when a scale is present.
    ``solve_fun`` is then called on a copy of ``self`` with ``scale=None`` and
    ``_in_solve=True``. The solution is mapped back to the original
    representation.

    Args:
        solve_fun: callable invoked as ``solve_fun(op, y, x0=x0)``; it must
            return a pair ``(solution, info)``.
        y: right-hand side of the linear system (a pytree).
        x0: optional initial guess for the solution (a pytree).

    Returns:
        A tuple ``(out, info)``. ``out`` is the solution pytree, reassembled to
        the original real/complex structure. ``info`` is the second element
        returned by ``solve_fun``, passed through untouched.
    """
    # Real-imaginary split of the RHS (and initial guess) in R→R and R→C modes
    if self.mode != "holomorphic":
        y, reassemble = nkjax.tree_to_real(y)
        if x0 is not None:
            x0, _ = nkjax.tree_to_real(x0)

    check_valid_vector_type(self.params, y)

    # The solver works in rescaled coordinates z = scale * x:
    # - the RHS is divided by scale;
    # - the initial guess is multiplied by it;
    # - the solution is divided by it again below.
    if self.scale is not None:
        y = jax.tree_util.tree_map(jnp.divide, y, self.scale)
        if x0 is not None:
            x0 = jax.tree_util.tree_map(jnp.multiply, x0, self.scale)

    # Pass the operator itself down to the solver, but
    # - with scale=None, so that its matmul does not apply the scale again;
    # - with _in_solve=True, so that its matmul does not split the
    #   (already real-split) vectors into real and imaginary parts again.
    unscaled_self = self.replace(scale=None, _in_solve=True)

    out, info = solve_fun(unscaled_self, y, x0=x0)

    if self.scale is not None:
        out = jax.tree_util.tree_map(jnp.divide, out, self.scale)

    # Undo the real-imaginary split as needed
    if self.mode != "holomorphic":
        out = reassemble(out)

    return out, info
示例#2
0
def DeltaOdagger_DeltaO_v(forward_fn, params, samples, v, holomorphic=True):
    r"""
    Compute \langle \Delta O^\dagger \Delta O \rangle v.

    Here \Delta O = O - \langle O \rangle, with O the jacobian of
    ``forward_fn`` with respect to ``params`` evaluated on ``samples``.

    Parameters and ``v`` are split into real and imaginary parts first, and
    the result is reassembled at the end. The only exception is when the
    parameters are homogeneous and either all real or treated as holomorphic.

    Args:
        forward_fn: function called as ``forward_fn(params, samples)``.
        params: pytree of parameters.
        samples: samples on which the jacobian is evaluated.
        v: pytree vector with the same structure as ``params``.
        holomorphic: whether complex parameters are treated holomorphically.

    Returns:
        The product, as a pytree with the structure of ``params``.
    """
    is_homogeneous = nkjax.tree_ishomogeneous(params)
    has_real_params = not nkjax.tree_leaf_iscomplex(params)

    # Everything except R->R, holomorphic C->C and R->C is split to real
    split_to_real = not (is_homogeneous and (has_real_params or holomorphic))

    if split_to_real:
        params, reassemble = nkjax.tree_to_real(params)
        v, _ = nkjax.tree_to_real(v)

        def model_fn(p, x):
            # the model itself still expects the original (complex) params
            return forward_fn(reassemble(p), x)

    else:
        model_fn = forward_fn

    o_mean = O_mean(model_fn, params, samples, holomorphic=holomorphic)

    def centered_fn(p, x):
        # Subtracting the linear term <O>·p is meant to shift the jacobian
        # to O - <O> (assuming tree_dot is a plain sum of elementwise
        # products — confirm against its definition).
        return model_fn(p, x) - tree_dot(p, o_mean)

    result = Odagger_O_v(centered_fn, params, samples, v)

    return reassemble(result) if split_to_real else result
示例#3
0
def vec_to_real(vec: Array) -> Tuple[Array, Callable]:
    """
    If the input vector is complex, splits it into its real and imaginary
    parts and concatenates them along the 0-th axis.

    This is equivalent to changing the complex storage from AOS (array of
    structures: real and imaginary parts interleaved) to SOA (structure of
    arrays: all real parts followed by all imaginary parts).

    For a real input vector, the output of ``nkjax.tree_to_real`` (which
    leaves real inputs unchanged) is returned as is.

    Args:
        vec: a dense vector

    Returns:
        A tuple ``(out, reassemble)``:
        - ``out`` is the real vector. When ``vec`` is complex, it is twice as
          long along axis 0.
        - ``reassemble`` is a callable that maps a vector in that layout back
          to the original (complex) form.
    """
    out, reassemble = nkjax.tree_to_real(vec)

    if not nkjax.is_complex(vec):
        return out, reassemble

    re, im = out

    def reassemble_concat(x):
        # Undo the concatenation into (re, im), then rebuild the complex vector
        return reassemble(tuple(jnp.split(x, 2, axis=0)))

    return jnp.concatenate([re, im], axis=0), reassemble_concat
示例#4
0
def test_matvec_treemv_modes(e, jit, holomorphic, pardtype, outdtype):
    """
    Check ``mat_vec`` against the dense reference ``S_real @ v + shift * v``.

    The differentiation mode is derived from the parameter and output dtypes.
    """
    shift = 0.01
    state = {}
    rescale = False

    def apply_fun(params, samples):
        return e.f(params["params"], samples)

    matvec = qgt_jacobian_pytree_logic.mat_vec

    # Real output -> "real"; homogeneous complex params with a holomorphic
    # ansatz -> "holomorphic"; anything else -> "complex".
    if nkjax.is_complex_dtype(outdtype):
        is_holo = (pardtype is not None
                   and nkjax.is_complex_dtype(pardtype)
                   and holomorphic)
        mode = "holomorphic" if is_holo else "complex"
    else:
        mode = "real"

    if mode == "holomorphic":
        vec = e.v

        def reassemble(x):
            return x

    else:
        vec, reassemble = nkjax.tree_to_real(e.v)

    if jit:
        matvec = jax.jit(matvec)

    centered_oks, _ = qgt_jacobian_pytree_logic.prepare_centered_oks(
        apply_fun, e.params, e.samples, state, mode, rescale)

    actual = reassemble(matvec(vec, centered_oks, shift))
    dense = e.S_real @ e.v_real_flat + shift * e.v_real_flat
    expected = reassemble_complex(dense, target=e.target)

    assert tree_allclose(actual, expected)
示例#5
0
def _matmul(self: QGTJacobianPyTreeT,
            vec: Union[PyTree, Array]) -> Union[PyTree, Array]:
    """
    Compute the product of this QGT with ``vec``.

    Args:
        vec: either a pytree with the same structure as ``self.params``, or a
            flat array (anything exposing an ``ndim`` attribute). A flat array
            is unraveled into that structure before the product.

    Returns:
        The product. It is flattened back into a vector if ``vec`` was an
        array; otherwise it is returned as a pytree.
    """
    # Turn a dense vector RHS into a PyTree
    ravel = hasattr(vec, "ndim")
    if ravel:
        _, unravel = nkjax.tree_ravel(self.params)
        vec = unravel(vec)

    # Real-imaginary split of the RHS in R→R and R→C modes.
    # When called from inside a solve (`_in_solve`), the vector has already
    # been split by the caller and must not be split again.
    reassemble = None
    if self.mode != "holomorphic" and not self._in_solve:
        vec, reassemble = nkjax.tree_to_real(vec)

    check_valid_vector_type(self.params, vec)

    # Apply the scale on both sides of the matrix-vector product
    if self.scale is not None:
        vec = jax.tree_util.tree_map(jnp.multiply, vec, self.scale)

    result = mat_vec(vec, self.O, self.diag_shift)

    if self.scale is not None:
        result = jax.tree_util.tree_map(jnp.multiply, result, self.scale)

    # Reassemble the real-imaginary split as needed
    if reassemble is not None:
        result = reassemble(result)

    # Ravel the PyTree back into a vector as needed
    if ravel:
        result, _ = nkjax.tree_ravel(result)

    return result
示例#6
0
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    pdf=None,
    chunk_size: Optional[int] = None,
) -> tuple:
    """
    Compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩.

    The result is divided by √n, or weighted by √pdf when ``pdf`` is given.

    In the 'real' and 'complex' modes (for C→R, R&C→R, R&C→C and general C→C)
    this also, somewhat opaquely, splits all parameters into their real and
    imaginary parts internally. The resulting ΔOⱼₖ are therefore only
    compatible with pytree vectors that were split to real in the same way.

    Args:
        apply_fun: The forward pass of the Ansatz
        params: a pytree of parameters p
        samples: an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model; ``None`` is
            treated as an empty state
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used
        pdf: |ψ(x)|^2 if exact optimization is being used else None
        chunk_size: an int specifying the size of the chunks the gradient
            should be computed in (default: None). Only used when pdf is None.

    Returns:
        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at the samples σ, divided by √n;
            None
        else:
            the same pytree, but the entries for each parameter normalised to unit norm;
            pytree containing the norms that were divided out (same shape as params)

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported modes.
    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # A missing model state is equivalent to an empty one
    # (`**None` would raise a TypeError below).
    if model_state is None:
        model_state = {}

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **model_state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C
        # Rather than building complex jacobians and stacking them
        # (stack_jacobian ∘ centered_jacobian_cplx), pass the oks around as
        # a tuple of two pytrees holding the real and imaginary parts. This
        # avoids converting to complex and then back.
        centered_jacobian_fun = compose(
            stack_jacobian_tuple,
            partial(centered_jacobian_cplx, _build_fn=lambda *x: x),
        )
        jacobian_fun = jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or "holomorphic", got {}'
            .format(mode))

    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_real(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    if pdf is None:
        centered_oks = _divide_by_sqrt_n_samp(
            centered_jacobian_fun(
                f,
                params,
                samples,
                chunk_size=chunk_size,
            ),
            samples,
        )
    else:
        # NOTE(review): chunk_size is not forwarded on this path.
        oks = jacobian_fun(f, params, samples)
        # NOTE: `sum` must be an axis-aware sum in module scope; the Python
        # builtin `sum` does not accept an `axis` keyword.
        oks_mean = jax.tree_util.tree_map(partial(sum, axis=0),
                                          _multiply_by_pdf(oks, pdf))
        centered_oks = jax.tree_util.tree_map(jnp.subtract, oks, oks_mean)

        centered_oks = _multiply_by_pdf(centered_oks, jnp.sqrt(pdf))

    if rescale_shift:
        return _rescale(centered_oks)
    return centered_oks, None