Exemplo n.º 1
0
def grad_expect_hermitian(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, PyTree]:
    """Compute the statistics of the local estimator and its gradient.

    The local values are evaluated with ``local_cost_function`` (using
    ``local_value_cost``), summarised with ``statistics``, centred by
    subtracting their mean, and then pulled back through the model with a
    vector-Jacobian product taken with respect to ``parameters``.

    Args:
        model_apply_fun: Callable invoked as
            ``model_apply_fun({"params": w, **model_state}, σ)``; when
            ``mutable`` is not ``False`` it is also passed
            ``mutable=mutable`` and its output is treated as
            ``(value, aux)`` by ``nkjax.vjp(..., has_aux=True)``.
        mutable: ``False`` to evaluate the model without mutable state;
            any other value is forwarded verbatim to ``model_apply_fun``.
        parameters: PyTree of parameters that is differentiated.
        model_state: Mapping merged next to ``"params"`` to build the
            variables passed to the model.
        σ: Samples. If not already 2-dimensional they are flattened to
            ``(-1, σ.shape[-1])``.
        σp: Forwarded unchanged to ``local_cost_function``
            (presumably the connected configurations of ``σ`` - TODO confirm).
        mels: Forwarded unchanged to ``local_cost_function``
            (presumably the matching matrix elements - TODO confirm).

    Returns:
        A 3-tuple ``(Ō, Ō_grad, new_model_state)``:

        * ``Ō``: the object returned by ``statistics`` for the local values.
        * ``Ō_grad``: gradient PyTree, each leaf cast to the dtype of the
          matching leaf of ``parameters`` and passed through
          ``sum_inplace``.
        * ``new_model_state``: ``None`` when ``mutable is False``, otherwise
          the auxiliary output returned by ``nkjax.vjp``.
    """
    # Remember the incoming shape: the statistics are computed on the
    # original batch layout, everything else on the flattened samples.
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    # Normalisation of the gradient: local sample count times utils.n_nodes
    # (presumably the number of cooperating processes - TODO confirm).
    n_samples = σ.shape[0] * utils.n_nodes

    O_loc = local_cost_function(
        local_value_cost,
        model_apply_fun,
        {"params": parameters, **model_state},
        σp,
        mels,
        σ,
    )

    # The transposed, batch-shaped view is only used for the statistics;
    # O_loc itself stays flat so that it matches the flattened σ below.
    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    # Centre the local values around their mean before back-propagating.
    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    # NOTE: the identity check is deliberate - only the literal ``False``
    # selects the stateless branch; any other value is forwarded as-is.
    if mutable is False:
        _, vjp_fun = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            conjugate=True,
        )
        new_model_state = None
    else:
        _, vjp_fun, new_model_state = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
            parameters,
            conjugate=True,
            has_aux=True,
        )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    # Make every gradient leaf match its parameter: drop the imaginary part
    # for non-complex parameters, then cast to the parameter's dtype.
    # NOTE(review): jax.tree_multimap is deprecated in newer JAX releases in
    # favour of the multi-tree form of jax.tree_util.tree_map - confirm the
    # minimum supported JAX version before switching.
    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(sum_inplace, Ō_grad), new_model_state
Exemplo n.º 2
0
    def expect_operator(self, Ô: AbstractOperator) -> Stats:
        """Return the statistics of the local values of ``Ô``.

        The samples are read from ``self.diagonal.samples``, flattened to
        two dimensions, and handed (as a NumPy array) to
        ``Ô.get_conn_padded``. The local values are then evaluated with
        ``local_cost_function`` using ``local_value_op_op_cost`` and
        summarised by ``statistics``.

        Args:
            Ô: Operator providing ``get_conn_padded``.

        Returns:
            The result of ``statistics`` applied to the local values.
        """
        samples = self.diagonal.samples
        # Every axis but the last; restored on the local values below.
        batch_shape = samples.shape[:-1]
        flat_samples = samples.reshape((-1, samples.shape[-1]))

        # get_conn_padded is fed a NumPy array rather than the original one.
        connected, matrix_elements = Ô.get_conn_padded(np.asarray(flat_samples))

        local_values = local_cost_function(
            local_value_op_op_cost,
            self._apply_fun,
            self.variables,
            connected,
            matrix_elements,
            flat_samples,
        )

        # The transpose is passed on purpose: statistics assumes that the
        # first index is the batch index.
        return statistics(local_values.reshape(batch_shape).T)