Example #1
def expect_and_grad(  # noqa: F811
    vstate: MCState,
    Ô: AbstractOperator,
    use_covariance: TrueT,
    *,
    mutable: Any,
) -> Tuple[Stats, PyTree]:
    # Retrieve the samples σ and the operator-specific arguments needed by the
    # local estimator.
    σ, args = get_local_kernel_arguments(vstate, Ô)

    # Kernel that evaluates the local estimator of Ô on a batch of samples.
    local_estimator_fun = get_local_kernel(vstate, Ô)

    # Hermitian case: expectation value and gradient via the covariance formula.
    Ō, Ō_grad, new_model_state = grad_expect_hermitian(
        local_estimator_fun,
        vstate._apply_fun,
        mutable,
        vstate.parameters,
        vstate.model_state,
        σ,
        args,
    )

    if mutable is not False:
        vstate.model_state = new_model_state

    return Ō, Ō_grad
Example #2
def expect_and_grad(  # noqa: F811
    vstate,
    Ô,
    use_covariance,
    *,
    mutable: Any,
) -> Tuple[Stats, PyTree]:
    σ, args = get_local_kernel_arguments(vstate, Ô)

    local_estimator_fun = get_local_kernel(vstate, Ô)

    Ō, Ō_grad, new_model_state = grad_expect_operator_kernel(
        local_estimator_fun,
        vstate._apply_fun,
        vstate.sampler.machine_pow,
        mutable,
        vstate.parameters,
        vstate.model_state,
        σ,
        args,
    )

    if mutable is not False:
        vstate.model_state = new_model_state

    return Ō, Ō_grad
Example #3
def local_estimators(
    state: MCState, op: AbstractOperator, *, chunk_size: Optional[int]
):
    s, extra_args = get_local_kernel_arguments(state, op)

    shape = s.shape
    # Flatten extra leading dimensions so the samples form a single 2D batch.
    if jnp.ndim(s) != 2:
        s = s.reshape((-1, shape[-1]))

    # Fall back to the state's own chunk size when none is passed explicitly.
    if chunk_size is None:
        chunk_size = state.chunk_size  # state.chunk_size can still be None

    # Only the chunked variant of the kernel needs to know the chunk size.
    if chunk_size is None:
        kernel = get_local_kernel(state, op)
    else:
        kernel = get_local_kernel(state, op, chunk_size)

    return _local_estimators_kernel(
        kernel, state._apply_fun, shape[:-1], state.variables, s, extra_args
    )
Example #4
def expect(vstate: MCState, Ô: AbstractOperator) -> Stats:  # noqa: F811
    σ, args = get_local_kernel_arguments(vstate, Ô)

    local_estimator_fun = get_local_kernel(vstate, Ô)

    return _expect(
        local_estimator_fun,
        vstate._apply_fun,
        vstate.sampler.machine_pow,
        vstate.parameters,
        vstate.model_state,
        σ,
        args,
    )
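
These dispatch overloads are not usually called directly; they are reached through the high-level MCState interface. A minimal usage sketch (assuming NetKet 3.x; the names hi, H and vstate below are illustrative choices of Hilbert space, operator, sampler and model, not part of the listings above):

import netket as nk

hi = nk.hilbert.Spin(s=1 / 2, N=4)
H = nk.operator.Ising(hilbert=hi, graph=nk.graph.Chain(4), h=1.0)

vstate = nk.vqs.MCState(
    sampler=nk.sampler.MetropolisLocal(hi),
    model=nk.models.RBM(alpha=1),
    n_samples=512,
)

# vstate.expect(H) dispatches to the `expect` implementation of Example #4:
# it samples σ, builds the local-estimator kernel and averages it into a Stats object.
stats = vstate.expect(H)
print(stats.mean, stats.error_of_mean)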
Example #5
def expect_mcstate_operator_chunked(
    vstate: MCState, Ô: AbstractOperator, chunk_size: int
) -> Stats:  # noqa: F811
    σ, args = get_local_kernel_arguments(vstate, Ô)

    local_estimator_fun = get_local_kernel(vstate, Ô, chunk_size)

    return _expect_chunking(
        chunk_size,
        local_estimator_fun,
        vstate._apply_fun,
        vstate.sampler.machine_pow,
        vstate.parameters,
        vstate.model_state,
        σ,
        args,
    )
Example #6
def expect_and_grad(  # noqa: F811
    vstate,
    Ô,
    use_covariance,
    *,
    mutable: Any,
) -> Tuple[Stats, PyTree]:

    if not isinstance(Ô, Squared) and not config.FLAGS["NETKET_EXPERIMENTAL"]:
        raise RuntimeError(
            """
            Computing the gradient of a non-Hermitian operator is an
            experimental feature under development and is known to
            sometimes return wrong values.

            If you want to debug it, set the environment variable
            NETKET_EXPERIMENTAL=1
            """
        )

    σ, args = get_local_kernel_arguments(vstate, Ô)

    local_estimator_fun = get_local_kernel(vstate, Ô)

    Ō, Ō_grad, new_model_state = grad_expect_operator_kernel(
        local_estimator_fun,
        vstate._apply_fun,
        vstate.sampler.machine_pow,
        mutable,
        vstate.parameters,
        vstate.model_state,
        σ,
        args,
    )

    if mutable is not False:
        vstate.model_state = new_model_state

    return Ō, Ō_grad
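
The gradient overloads (Examples #1, #2 and #6) are reached through MCState.expect_and_grad in the same way. A short sketch reusing the vstate and H from the sketch after Example #4 (again assuming NetKet 3.x):

# For a Hermitian operator such as H this dispatches to the covariance-based
# overload of Example #1; Ō is a Stats object and Ō_grad is a PyTree with the
# same structure as vstate.parameters.
Ō, Ō_grad = vstate.expect_and_grad(H)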