def grad_expect_hermitian(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * utils.n_nodes

    O_loc = local_cost_function(
        local_value_cost,
        model_apply_fun,
        {"params": parameters, **model_state},
        σp,
        mels,
        σ,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # The code below is a bit more involved than a plain vjp because we support
    # mutable model state (when present).
    if mutable is False:
        _, vjp_fun = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            conjugate=True,
        )
        new_model_state = None
    else:
        _, vjp_fun, new_model_state = nkjax.vjp(
            lambda w: model_apply_fun(
                {"params": w, **model_state}, σ, mutable=mutable
            ),
            parameters,
            conjugate=True,
            has_aux=True,
        )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(sum_inplace, Ō_grad), new_model_state
def grad_expect_operator_kernel(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    machine_pow: int,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree, Stats]:

    if not config.FLAGS["NETKET_EXPERIMENTAL"]:
        raise RuntimeError(
            """
            Computing the gradient of a squared or non-hermitian operator is an
            experimental feature under development and is known to sometimes
            return wrong values.

            If you want to debug it, set the environment variable
            NETKET_EXPERIMENTAL=1
            """
        )

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    is_mutable = mutable is not False
    logpsi = lambda w, σ: model_apply_fun(
        {"params": w, **model_state}, σ, mutable=mutable
    )
    log_pdf = (
        lambda w, σ: machine_pow
        * model_apply_fun({"params": w, **model_state}, σ).real
    )

    def expect_closure_pars(pars):
        return nkjax.expect(
            log_pdf,
            partial(local_value_kernel, logpsi),
            pars,
            σ,
            local_value_args,
            n_chains=σ_shape[0],
        )

    Ō, Ō_pb, Ō_stats = nkjax.vjp(expect_closure_pars, parameters, has_aux=True)
    Ō_pars_grad = Ō_pb(jnp.ones_like(Ō))[0]

    if is_mutable:
        raise NotImplementedError(
            "gradient of non-hermitian operators over mutable models "
            "is not yet implemented."
        )
    new_model_state = None

    return (
        Ō_stats,
        jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], Ō_pars_grad),
        new_model_state,
    )
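# Note added for clarity (an assumption about behaviour, not part of the original
# code): the closure above is differentiated with nkjax.vjp, and nkjax.expect is
# assumed to attach a custom gradient rule that also accounts for the parameter
# dependence of the sampling distribution |ψ|^machine_pow (the score-function
# term), so the single pullback Ō_pb returns the full parameter gradient.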
def grad_expect_operator_kernel(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    machine_pow: int,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree, Stats]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    is_mutable = mutable is not False
    logpsi = lambda w, σ: model_apply_fun(
        {"params": w, **model_state}, σ, mutable=mutable
    )
    log_pdf = (
        lambda w, σ: machine_pow
        * model_apply_fun({"params": w, **model_state}, σ).real
    )

    def expect_closure_pars(pars):
        return nkjax.expect(
            log_pdf,
            partial(local_value_kernel, logpsi),
            pars,
            σ,
            local_value_args,
            n_chains=σ_shape[0],
        )

    Ō, Ō_pb, Ō_stats = nkjax.vjp(
        expect_closure_pars, parameters, has_aux=True, conjugate=True
    )
    Ō_pars_grad = Ō_pb(jnp.ones_like(Ō))[0]

    # The factor of 1/2 applied below to complex parameters is needed, otherwise
    # the result does not match the value obtained by (ha@ha).collect().
    # It is unclear why it is required.
    Ō_pars_grad = jax.tree_multimap(
        lambda x, target: x / 2 if jnp.iscomplexobj(target) else x,
        Ō_pars_grad,
        parameters,
    )

    if is_mutable:
        raise NotImplementedError(
            "gradient of non-hermitian operators over mutable models "
            "is not yet implemented."
        )
    new_model_state = None

    return (
        Ō_stats,
        jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], Ō_pars_grad),
        new_model_state,
    )
def _local_values_and_grads_notcentered_kernel(logpsi, pars, vp, mel, v):
    # Non-centered local estimator mel·ψ(vp)/ψ(v), summed over connected entries
    # and returned conjugated (see the caller), together with its gradient with
    # respect to the parameters.
    logpsi_vp, f_vjp = nkjax.vjp(lambda w: logpsi(w, vp), pars, conjugate=False)

    vec = mel * jnp.exp(logpsi_vp - logpsi(pars, v))

    # TODO: someone should bring order to these multiple conjugations
    odtype = jax.eval_shape(logpsi, pars, v).dtype
    vec = jnp.asarray(jnp.conjugate(vec), dtype=odtype)
    loc_val = vec.sum()
    grad_c = f_vjp(vec.conj())[0]

    return loc_val, grad_c
def grad_expect_hermitian(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # The code below is a bit more involved than a plain vjp because we support
    # mutable model state (when present).
    is_mutable = mutable is not False
    _, vjp_fun, *new_model_state = nkjax.vjp(
        lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
        parameters,
        conjugate=True,
        has_aux=is_mutable,
    )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    new_model_state = new_model_state[0] if is_mutable else None

    return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
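# Note added for clarity (schematic, not part of the original code): the quantity
# assembled above is the standard VMC force
#
#   f_k ≈ (1 / n_samples) Σ_i  conj(O_loc(σ_i) - ⟨O_loc⟩) · ∂_k log ψ(σ_i),
#
# with the sum over samples performed by the conjugated vjp (up to the
# complex-conjugation conventions handled inside nkjax.vjp). For real parameters
# the gradient of the real expectation value is 2 Re f_k, which is what the
# tree_multimap above enforces.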
def tree_ravel(pytree: PyTree) -> Tuple[jnp.ndarray, Callable]:
    """Ravel (i.e. flatten) a pytree of arrays down to a 1D array.

    Args:
        pytree: a pytree to ravel

    Returns:
        A pair where the first element is a 1D array representing the flattened
        and concatenated leaf values, and the second element is a callable for
        unflattening a 1D vector of the same length back to a pytree with the
        same structure as the input ``pytree``.
    """
    leaves, treedef = tree_flatten(pytree)
    flat, unravel_list = nkjax.vjp(_ravel_list, *leaves)
    unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
    return flat, unravel_pytree
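# Illustrative usage sketch (hypothetical, added for clarity; assumes `tree_ravel`
# and the helpers it relies on are importable from this module):
#
#     import jax.numpy as jnp
#
#     params = {"b": jnp.zeros(3), "w": jnp.ones((2, 3))}  # toy parameter pytree
#     flat, unravel = tree_ravel(params)
#     # `flat` is a 1D array of length 9 (= 3 + 2*3) concatenating the leaves;
#     # `unravel(flat)` rebuilds a pytree with the same structure and shapes.
#     restored = unravel(flat)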
def _exp_grad(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    OΨ: jnp.ndarray,
    Ψ: jnp.ndarray,
) -> Tuple[PyTree, PyTree]:

    is_mutable = mutable is not False

    expval_O = (Ψ.conj() * OΨ).sum()
    ΔOΨ = (OΨ - expval_O * Ψ.conj()) * Ψ

    _, vjp_fun, *new_model_state = nkjax.vjp(
        lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
        parameters,
        conjugate=True,
        has_aux=is_mutable,
    )
    Ō_grad = vjp_fun(ΔOΨ)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    new_model_state = new_model_state[0] if is_mutable else None

    return (
        None,
        Ō_grad,
        expval_O,
        new_model_state,
    )
def grad_expect_operator_kernel(
    machine_pow: int,
    model_apply_fun: Callable,
    local_kernel: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:

    if not config.FLAGS["NETKET_EXPERIMENTAL"]:
        raise RuntimeError(
            """
            Computing the gradient of a squared or non-hermitian operator is an
            experimental feature under development and is known to sometimes
            return wrong values.

            If you want to debug it, set the environment variable
            NETKET_EXPERIMENTAL=1
            """
        )

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    has_aux = mutable is not False
    if not has_aux:
        out_axes = (0, 0)
    else:
        out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    log_pdf = (
        lambda w, σ: machine_pow
        * model_apply_fun({"params": w, **model_state}, σ).real
    )

    def expect_closure(*args):
        local_kernel_vmap = jax.vmap(
            partial(local_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
        )

        return nkjax.expect(log_pdf, local_kernel_vmap, *args, n_chains=σ_shape[0])

    def expect_closure_pars(pars):
        return expect_closure(pars, σ, σp, mels)

    Ō, Ō_pb, Ō_stats = nkjax.vjp(expect_closure_pars, parameters, has_aux=True)
    Ō_pars_grad = Ō_pb(jnp.ones_like(Ō))

    return (
        Ō_stats,
        tree_map(lambda x: sum_inplace(x) / utils.n_nodes, Ō_pars_grad),
        model_state,
    )
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #     out_axes = (0, 0)
    # else:
    #     out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # local_kernel_vmap = jax.vmap(
    #     partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )
    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))

    (
        Lρ,
        der_loc_vals,
    ) = _local_values_and_grads_notcentered_kernel(logpsi, parameters, σp, mels, σ)
    # _local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type, but we only need
    # floatXX, which would halve the number of operations.
    der_logs_ave = d_logpsi(
        jnp.ones_like(_logpsi_ave).real / (n_samples_node * mpi.n_nodes)
    )[0]
    der_logs_ave = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ ℝ, so if the parameters are real we should cast away the imaginary
    # part of the gradient. We do the same for the standard gradient of the
    # energy; this avoids the errors reported in #867, #789, #850.
    LdagL_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        LdagL_grad,
        parameters,
    )

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
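# Note added for clarity (schematic, not part of the original code): writing
# C = ⟨|Lρ_loc|²⟩ (LdagL_mean) and D_k(σ) for the non-centered derivatives
# returned by the kernel above, the gradient assembled in `gradfun` is
#
#   ∂_k C ≈ mean_σ[ conj(D_k(σ)) · Lρ_loc(σ) ] - conj(⟨∂_k log ψ⟩) · C ,
#
# i.e. the usual non-centered estimator for the gradient of a squared local
# estimator, with the imaginary part cast away for real parameters as above.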