예제 #1
0
def expect_and_grad(  # noqa: F811
    vstate: MCState,
    Ô: AbstractOperator,
    use_covariance: TrueT,
    *,
    mutable: Any,
) -> Tuple[Stats, PyTree]:
    """Estimate ``<Ô>`` on ``vstate`` together with its parameter gradient.

    This overload is the one annotated for ``use_covariance: TrueT``. The
    dispatch decorator is not visible in this chunk, so TODO confirm how it
    is registered. It delegates the work to ``grad_expect_hermitian``.

    Args:
        vstate: Monte Carlo variational state. Its local-kernel arguments,
            ``_apply_fun``, ``parameters`` and ``model_state`` are read.
        Ô: Operator whose expectation value and gradient are estimated.
        use_covariance: Not read in the body; only selects this overload.
        mutable: Forwarded to the gradient kernel. When it is anything
            other than ``False``, the model state returned by the kernel is
            stored back into ``vstate.model_state``.

    Returns:
        ``(Ō, Ō_grad)``: the estimate statistics and the gradient pytree,
        as declared by the ``Tuple[Stats, PyTree]`` annotation.
    """
    # Fetch the kernel inputs before building the kernel itself, keeping the
    # original call order (the first call may trigger sampling — TODO confirm).
    samples, extra_args = get_local_kernel_arguments(vstate, Ô)
    kernel = get_local_kernel(vstate, Ô)

    apply_fun = vstate._apply_fun
    params = vstate.parameters
    model_state = vstate.model_state

    estimate, gradient, updated_state = grad_expect_hermitian(
        kernel, apply_fun, mutable, params, model_state, samples, extra_args
    )

    # Only persist the model state when the caller asked for mutability.
    if mutable is False:
        return estimate, gradient

    vstate.model_state = updated_state
    return estimate, gradient
예제 #2
0
def expect_and_grad(  # noqa: F811
    vstate,
    Ô,
    use_covariance,
    *,
    mutable: Any,
) -> Tuple[Stats, PyTree]:
    """Estimate ``<Ô>`` and its gradient through the generic operator kernel.

    This is the untyped overload. It forwards to
    ``grad_expect_operator_kernel`` and additionally passes the sampler's
    ``machine_pow``.

    Args:
        vstate: Variational state. Its local-kernel arguments,
            ``_apply_fun``, ``sampler.machine_pow``, ``parameters`` and
            ``model_state`` are read.
        Ô: Operator to estimate.
        use_covariance: Not read in the body; presumably only used for
            dispatch (TODO confirm with the registering decorator).
        mutable: Forwarded to the kernel. When it is not ``False``, the
            kernel's new model state is written to ``vstate.model_state``.

    Returns:
        ``(Ō, Ō_grad)`` per the ``Tuple[Stats, PyTree]`` annotation.
    """
    samples, extra_args = get_local_kernel_arguments(vstate, Ô)
    kernel = get_local_kernel(vstate, Ô)

    outputs = grad_expect_operator_kernel(
        kernel,
        vstate._apply_fun,
        vstate.sampler.machine_pow,
        mutable,
        vstate.parameters,
        vstate.model_state,
        samples,
        extra_args,
    )
    estimate, gradient, updated_state = outputs

    should_store = mutable is not False
    if should_store:
        vstate.model_state = updated_state

    return estimate, gradient
예제 #3
0
def local_estimators(
    state: MCState, op: AbstractOperator, *, chunk_size: Optional[int]
):
    """Compute the local estimators of ``op`` for the samples of ``state``.

    Args:
        state: Monte Carlo state providing the kernel arguments,
            ``_apply_fun``, ``variables`` and (as a fallback) ``chunk_size``.
        op: Operator whose local estimators are computed.
        chunk_size: Chunk size for the kernel. When ``None``,
            ``state.chunk_size`` is used instead, and that may itself be
            ``None``, in which case the unchunked kernel is requested.

    Returns:
        Whatever ``_local_estimators_kernel`` returns. It receives the
        original leading shape ``shape[:-1]``, presumably to restore the
        batch layout of the result (TODO confirm).
    """
    samples, extra_args = get_local_kernel_arguments(state, op)

    original_shape = samples.shape
    # Collapse every leading axis into one while keeping the last axis, so
    # the kernel always sees a 2D array.
    if jnp.ndim(samples) != 2:
        samples = samples.reshape((-1, original_shape[-1]))

    # An explicit argument wins; otherwise defer to the state's setting.
    effective_chunk = state.chunk_size if chunk_size is None else chunk_size

    kernel = (
        get_local_kernel(state, op)
        if effective_chunk is None
        else get_local_kernel(state, op, effective_chunk)
    )

    batch_shape = original_shape[:-1]
    return _local_estimators_kernel(
        kernel, state._apply_fun, batch_shape, state.variables, samples, extra_args
    )
예제 #4
0
def expect(vstate: MCState, Ô: AbstractOperator) -> Stats:  # noqa: F811
    """Estimate the expectation value of ``Ô`` on the Monte Carlo state.

    The kernel inputs and the kernel are built from ``vstate`` and ``Ô``.
    They are then evaluated by ``_expect`` together with the state's
    ``_apply_fun``, sampler ``machine_pow``, ``parameters`` and
    ``model_state``.

    Returns:
        The ``Stats`` produced by ``_expect``.
    """
    samples, extra_args = get_local_kernel_arguments(vstate, Ô)
    kernel = get_local_kernel(vstate, Ô)

    apply_fun = vstate._apply_fun
    machine_pow = vstate.sampler.machine_pow
    params = vstate.parameters
    model_state = vstate.model_state

    return _expect(
        kernel, apply_fun, machine_pow, params, model_state, samples, extra_args
    )
예제 #5
0
def expect_mcstate_operator_chunked(
    vstate: MCState, Ô: AbstractOperator, chunk_size: int
) -> Stats:  # noqa: F811
    """Estimate ``<Ô>`` on ``vstate``, evaluating the local kernel in chunks.

    Args:
        vstate: Monte Carlo state supplying kernel arguments, ``_apply_fun``,
            sampler ``machine_pow``, ``parameters`` and ``model_state``.
        Ô: Operator to estimate.
        chunk_size: Passed both to ``get_local_kernel`` and to
            ``_expect_chunking``.

    Returns:
        The ``Stats`` produced by ``_expect_chunking``.
    """
    samples, extra_args = get_local_kernel_arguments(vstate, Ô)
    chunked_kernel = get_local_kernel(vstate, Ô, chunk_size)

    return _expect_chunking(
        chunk_size,
        chunked_kernel,
        vstate._apply_fun,
        vstate.sampler.machine_pow,
        vstate.parameters,
        vstate.model_state,
        samples,
        extra_args,
    )
예제 #6
0
def expect_and_grad(  # noqa: F811
    vstate,
    Ô,
    use_covariance,
    *,
    mutable: Any,
) -> Tuple[Stats, PyTree]:
    """Estimate ``<Ô>`` and its gradient for a possibly non-hermitian operator.

    Unless ``Ô`` is a ``Squared`` operator, this path is refused unless the
    ``NETKET_EXPERIMENTAL`` flag is set. This is because the gradient of
    non-hermitian operators is experimental.

    Args:
        vstate: Variational state. Its local-kernel arguments,
            ``_apply_fun``, ``sampler.machine_pow``, ``parameters`` and
            ``model_state`` are read.
        Ô: Operator to estimate.
        use_covariance: Not read in the body; presumably only used for
            dispatch (TODO confirm with the registering decorator).
        mutable: Forwarded to ``grad_expect_operator_kernel``. When it is not
            ``False``, the returned model state is written back to
            ``vstate.model_state``.

    Returns:
        ``(Ō, Ō_grad)`` per the ``Tuple[Stats, PyTree]`` annotation.

    Raises:
        RuntimeError: If ``Ô`` is not a ``Squared`` instance and
            ``config.FLAGS["NETKET_EXPERIMENTAL"]`` is falsy.
    """
    if not isinstance(Ô, Squared) and not config.FLAGS["NETKET_EXPERIMENTAL"]:
        # The previous wording ("is known not to return wrong values") was a
        # double negative that stated the opposite of the intended warning.
        raise RuntimeError(
            """
            Computing the gradient of a non-hermitian operator is an
            experimental feature under development and is known to
            return wrong values sometimes.

            If you want to debug it, set the environment variable
            NETKET_EXPERIMENTAL=1
            """
        )

    σ, args = get_local_kernel_arguments(vstate, Ô)

    local_estimator_fun = get_local_kernel(vstate, Ô)

    Ō, Ō_grad, new_model_state = grad_expect_operator_kernel(
        local_estimator_fun,
        vstate._apply_fun,
        vstate.sampler.machine_pow,
        mutable,
        vstate.parameters,
        vstate.model_state,
        σ,
        args,
    )

    # Persist the updated model state only when mutation was requested.
    if mutable is not False:
        vstate.model_state = new_model_state

    return Ō, Ō_grad