def grad_expect_hermitian(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        # flatten the batch dimensions so that σ is (n_samples, n_sites)
        σ = σ.reshape((-1, σ_shape[-1]))

    # total number of samples across all MPI nodes
    n_samples = σ.shape[0] * utils.n_nodes

    # local estimator evaluated on the connected configurations σp with
    # matrix elements mels
    O_loc = local_cost_function(
        local_value_cost,
        model_apply_fun,
        {"params": parameters, **model_state},
        σp,
        mels,
        σ,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    # centre the local values around their mean
    O_loc -= Ō.mean

    # Then compute the vjp.
    # The code is slightly more involved than usual because we support
    # models with mutable state (if present).
    if mutable is False:
        _, vjp_fun = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            conjugate=True,
        )
        new_model_state = None
    else:
        _, vjp_fun, new_model_state = nkjax.vjp(
            lambda w: model_apply_fun(
                {"params": w, **model_state}, σ, mutable=mutable
            ),
            parameters,
            conjugate=True,
            has_aux=True,
        )

    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    # cast each leaf back to the dtype of the corresponding parameter,
    # dropping the imaginary part for real parameters
    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    # sum the gradient contributions over MPI nodes
    return Ō, tree_map(sum_inplace, Ō_grad), new_model_state
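
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the library): `grad_expect_hermitian` above
# implements the usual Monte-Carlo gradient estimator
#     Ō_grad ≈ (1/N) Σ_i conj(O_loc(σ_i) - Ō) · ∂_w log ψ(σ_i),
# pulled back through a single vjp instead of materialising a per-sample
# Jacobian.  The toy ansatz and names below (`_toy_logpsi`,
# `_sketch_grad_estimator`, the shapes) are assumptions used only to make the
# idea self-contained; the real code additionally handles complex-conjugation
# conventions, mutable model state and MPI-summed gradients.
# ---------------------------------------------------------------------------
def _toy_logpsi(params, σ):
    # trivial mean-field log-amplitude: log ψ(σ) = Σ_j w_j σ_j
    return σ @ params["w"]


def _sketch_grad_estimator(params, σ, o_loc):
    # centre the local values, then contract them with ∂_w log ψ via one vjp
    o_loc = o_loc - jnp.mean(o_loc)
    _, vjp_fun = jax.vjp(lambda w: _toy_logpsi(w, σ), params)
    (grad,) = vjp_fun(jnp.conjugate(o_loc) / o_loc.shape[0])
    return grad
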
def expect_operator(self, Ô: AbstractOperator) -> Stats:
    σ = self.diagonal.samples
    σ_shape = σ.shape
    # flatten the batch dimensions so that σ is (n_samples, n_sites)
    σ = σ.reshape((-1, σ.shape[-1]))

    σ_np = np.asarray(σ)
    σp, mels = Ô.get_conn_padded(σ_np)

    # compute the local values from the connected configurations σp and
    # their matrix elements mels
    O_loc = local_cost_function(
        local_value_op_op_cost,
        self._apply_fun,
        self.variables,
        σp,
        mels,
        σ,
    ).reshape(σ_shape[:-1])

    # notice that O_loc.T is passed to statistics, since that function assumes
    # that the first index is the batch index.
    return statistics(O_loc.T)
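
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the library): `expect_operator` above estimates
#     ⟨Ô⟩ = Tr[ρ̂ Ô] / Tr[ρ̂]
# by sampling σ from the diagonal of ρ̂ and averaging a local value of the form
#     O_loc(σ) = Σ_σ' ⟨σ|Ô|σ'⟩ ρ(σ', σ) / ρ(σ, σ).
# The dense 2x2 arrays below are assumptions chosen for illustration only, and
# the exact convention (ρ(σ',σ) vs. its conjugate) may differ from the one
# used by `local_value_op_op_cost`.
# ---------------------------------------------------------------------------
def _sketch_diagonal_expectation():
    rho = np.array([[0.7, 0.1], [0.1, 0.3]])  # toy density matrix
    obs = np.array([[1.0, 0.5], [0.5, -1.0]])  # toy hermitian observable

    p = np.diag(rho) / np.trace(rho)  # diagonal sampling distribution
    o_loc = np.array(
        [
            sum(obs[s, sp] * rho[sp, s] for sp in range(2)) / rho[s, s]
            for s in range(2)
        ]
    )
    # the exact average of the local values reproduces Tr[ρ̂ Ô] / Tr[ρ̂]
    assert np.isclose(p @ o_loc, np.trace(rho @ obs) / np.trace(rho))
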