def expect_and_grad(  # noqa: F811
    vstate: MCState,
    Ô: AbstractOperator,
    use_covariance: TrueT,
    *,
    mutable: Any,
) -> Tuple[Stats, PyTree]:
    """Return the expectation value of ``Ô`` on ``vstate`` and its gradient.

    The work is delegated to ``grad_expect_hermitian``, which receives the
    local kernel selected for ``(vstate, Ô)`` together with the samples and
    extra arguments produced by ``get_local_kernel_arguments``.

    Args:
        vstate: The variational state; its ``_apply_fun``, ``parameters`` and
            ``model_state`` are forwarded to the gradient routine.
        Ô: The operator whose expectation value is estimated.
        use_covariance: Not read in this body.
            NOTE(review): presumably only used to select this overload via a
            dispatch decorator that is not visible here — confirm.
        mutable: Forwarded to ``grad_expect_hermitian``. When it is anything
            other than ``False``, the model state returned by that routine is
            written back to ``vstate.model_state``.

    Returns:
        A ``(stats, gradient)`` pair as returned by ``grad_expect_hermitian``.
    """
    samples, kernel_args = get_local_kernel_arguments(vstate, Ô)
    kernel = get_local_kernel(vstate, Ô)

    stats, gradient, updated_model_state = grad_expect_hermitian(
        kernel,
        vstate._apply_fun,
        mutable,
        vstate.parameters,
        vstate.model_state,
        samples,
        kernel_args,
    )

    # Nothing to write back when the model state was requested as immutable.
    if mutable is False:
        return stats, gradient

    vstate.model_state = updated_model_state
    return stats, gradient
def expect_and_grad(  # noqa: F811
    vstate,
    Ô,
    use_covariance,
    *,
    mutable: Any,
) -> Tuple[Stats, PyTree]:
    """Return the expectation value of ``Ô`` on ``vstate`` and its gradient.

    The work is delegated to ``grad_expect_operator_kernel``, which receives
    the local kernel selected for ``(vstate, Ô)``, the sampler's
    ``machine_pow``, and the samples and extra arguments produced by
    ``get_local_kernel_arguments``.

    Args:
        vstate: The variational state; its ``_apply_fun``,
            ``sampler.machine_pow``, ``parameters`` and ``model_state`` are
            forwarded to the gradient routine.
        Ô: The operator whose expectation value is estimated.
        use_covariance: Not read in this body.
            NOTE(review): presumably only relevant for overload selection —
            confirm against the caller/decorator.
        mutable: Forwarded to ``grad_expect_operator_kernel``. When it is
            anything other than ``False``, the returned model state is written
            back to ``vstate.model_state``.

    Returns:
        A ``(stats, gradient)`` pair as returned by
        ``grad_expect_operator_kernel``.
    """
    samples, kernel_args = get_local_kernel_arguments(vstate, Ô)
    kernel = get_local_kernel(vstate, Ô)

    stats, gradient, updated_model_state = grad_expect_operator_kernel(
        kernel,
        vstate._apply_fun,
        vstate.sampler.machine_pow,
        mutable,
        vstate.parameters,
        vstate.model_state,
        samples,
        kernel_args,
    )

    # Nothing to write back when the model state was requested as immutable.
    if mutable is False:
        return stats, gradient

    vstate.model_state = updated_model_state
    return stats, gradient
def local_estimators(
    state: MCState, op: AbstractOperator, *, chunk_size: Optional[int]
):
    """Evaluate the local kernel of ``op`` on the samples of ``state``.

    The samples returned by ``get_local_kernel_arguments`` are flattened to a
    2-D array (keeping the last axis) before being handed to
    ``_local_estimators_kernel``; the original leading dimensions are passed
    along separately.

    Args:
        state: The variational state providing ``_apply_fun``, ``variables``
            and the fallback ``chunk_size``.
        op: The operator whose local kernel is evaluated.
        chunk_size: Chunk size used to select the kernel. If ``None``,
            ``state.chunk_size`` is used instead; if that is also ``None``,
            the kernel is looked up without a chunk size.

    Returns:
        Whatever ``_local_estimators_kernel`` returns.
        NOTE(review): presumably reshaped to the samples' leading dimensions,
        since ``shape[:-1]`` is forwarded — confirm in that helper.
    """
    samples, extra_args = get_local_kernel_arguments(state, op)

    original_shape = samples.shape
    if jnp.ndim(samples) != 2:
        # Collapse every leading axis into one, preserving the last axis.
        samples = samples.reshape((-1, original_shape[-1]))

    # Fall back to the state's own setting, which may itself still be None.
    effective_chunk = state.chunk_size if chunk_size is None else chunk_size

    if effective_chunk is None:
        dispatch_args = (state, op)
    else:
        dispatch_args = (state, op, effective_chunk)
    kernel = get_local_kernel(*dispatch_args)

    return _local_estimators_kernel(
        kernel,
        state._apply_fun,
        original_shape[:-1],
        state.variables,
        samples,
        extra_args,
    )
def expect(vstate: MCState, Ô: AbstractOperator) -> Stats:  # noqa: F811
    """Return the expectation value of ``Ô`` on ``vstate``.

    Looks up the samples, extra arguments and local kernel for
    ``(vstate, Ô)`` and forwards them, together with the state's apply
    function, the sampler's ``machine_pow``, the parameters and the model
    state, to ``_expect``.

    Args:
        vstate: The variational state to evaluate on.
        Ô: The operator whose expectation value is estimated.

    Returns:
        The result of ``_expect``.
    """
    samples, kernel_args = get_local_kernel_arguments(vstate, Ô)
    kernel = get_local_kernel(vstate, Ô)

    stats = _expect(
        kernel,
        vstate._apply_fun,
        vstate.sampler.machine_pow,
        vstate.parameters,
        vstate.model_state,
        samples,
        kernel_args,
    )
    return stats
def expect_mcstate_operator_chunked(
    vstate: MCState, Ô: AbstractOperator, chunk_size: int
) -> Stats:  # noqa: F811
    """Return the expectation value of ``Ô`` on ``vstate`` using chunking.

    Same flow as the non-chunked path, except that ``chunk_size`` is used
    both to select the local kernel and as the first argument of
    ``_expect_chunking``.

    Args:
        vstate: The variational state to evaluate on.
        Ô: The operator whose expectation value is estimated.
        chunk_size: Chunk size passed to ``get_local_kernel`` and
            ``_expect_chunking``.

    Returns:
        The result of ``_expect_chunking``.
    """
    samples, kernel_args = get_local_kernel_arguments(vstate, Ô)
    chunked_kernel = get_local_kernel(vstate, Ô, chunk_size)

    stats = _expect_chunking(
        chunk_size,
        chunked_kernel,
        vstate._apply_fun,
        vstate.sampler.machine_pow,
        vstate.parameters,
        vstate.model_state,
        samples,
        kernel_args,
    )
    return stats
def expect_and_grad(  # noqa: F811
    vstate,
    Ô,
    use_covariance,
    *,
    mutable: Any,
) -> Tuple[Stats, PyTree]:
    """Return the expectation value of ``Ô`` on ``vstate`` and its gradient.

    Unless ``Ô`` is a ``Squared`` operator, this path is gated behind the
    ``NETKET_EXPERIMENTAL`` configuration flag. The computation itself is
    delegated to ``grad_expect_operator_kernel``.

    Args:
        vstate: The variational state; its ``_apply_fun``,
            ``sampler.machine_pow``, ``parameters`` and ``model_state`` are
            forwarded to the gradient routine.
        Ô: The operator whose expectation value is estimated.
        use_covariance: Not read in this body.
            NOTE(review): presumably only relevant for overload selection —
            confirm against the caller/decorator.
        mutable: Forwarded to ``grad_expect_operator_kernel``. When it is
            anything other than ``False``, the returned model state is written
            back to ``vstate.model_state``.

    Returns:
        A ``(stats, gradient)`` pair as returned by
        ``grad_expect_operator_kernel``.

    Raises:
        RuntimeError: If ``Ô`` is not a ``Squared`` instance and the
            ``NETKET_EXPERIMENTAL`` flag is not set.
    """
    if not isinstance(Ô, Squared) and not config.FLAGS["NETKET_EXPERIMENTAL"]:
        # The previous wording ("known not to return wrong values") stated the
        # opposite of the intended warning.
        raise RuntimeError(
            """
            Computing the gradient of non hermitian operator is an
            experimental feature under development and is known to
            return wrong values sometimes.

            If you want to debug it, set the environment variable
            NETKET_EXPERIMENTAL=1
            """
        )

    σ, args = get_local_kernel_arguments(vstate, Ô)

    local_estimator_fun = get_local_kernel(vstate, Ô)

    Ō, Ō_grad, new_model_state = grad_expect_operator_kernel(
        local_estimator_fun,
        vstate._apply_fun,
        vstate.sampler.machine_pow,
        mutable,
        vstate.parameters,
        vstate.model_state,
        σ,
        args,
    )

    # Persist the updated model state only when mutation was requested.
    if mutable is not False:
        vstate.model_state = new_model_state

    return Ō, Ō_grad