    re_dense = ravel(re)
    im_dense = ravel(im)
    res = jnp.stack([re_dense, im_dense], axis=0)
    return res


def ravel(x: PyTree) -> Array:
    """
    shorthand for tree_ravel
    """
    dense, _ = nkjax.tree_ravel(x)
    return dense


dense_jacobian_real_holo = nkjax.compose(ravel, jacobian_real_holo)
dense_jacobian_cplx = nkjax.compose(
    stack_jacobian_tuple, partial(jacobian_cplx, _build_fn=lambda *x: x)
)


def _rescale(centered_oks):
    """
    compute ΔOₖ/√Sₖₖ and √Sₖₖ
    to do scale-invariant regularization (Becca & Sorella 2017, pp. 143)
    Sₖₗ/(√Sₖₖ√Sₗₗ) = ΔOₖᴴΔOₗ/(√Sₖₖ√Sₗₗ) = (ΔOₖ/√Sₖₖ)ᴴ(ΔOₗ/√Sₗₗ)
    """
    scale = (
        mpi.mpi_sum_jax(
            jnp.sum((centered_oks * centered_oks.conj()).real, axis=0, keepdims=True)
        )[0]
        ** 0.5
    )
    centered_oks = centered_oks / scale
    scale = jnp.squeeze(scale, axis=0)
    return centered_oks, scale
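
# Illustrative sketch (not part of the library): with a dense centered Jacobian of
# shape (n_samples, n_params), dividing each column by its norm √Sₖₖ yields a QGT
# Sₖₗ/(√Sₖₖ√Sₗₗ) with unit diagonal. The helper name `_rescale_example` is
# hypothetical and only demonstrates the behaviour of `_rescale`.
def _rescale_example():
    # two samples, two parameters; the columns differ only by an overall scale
    centered_oks = jnp.array([[1.0, 3.0], [-1.0, -3.0]]) / jnp.sqrt(2.0)
    scaled_oks, scale = _rescale(centered_oks)  # scale ≈ [1.0, 3.0]
    s = scaled_oks.conj().T @ scaled_oks  # rescaled QGT; diagonal entries ≈ 1
    return s, scale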
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    pdf=None,
    chunk_size: int = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩, divided by √n

    In a somewhat opaque way this also internally splits all parameters to real
    in the 'real' and 'complex' modes (for C→R, R&C→R, R&C→C and general C→C),
    resulting in the respective ΔOⱼₖ, which is only compatible with
    split-to-real pytree vectors.

    Args:
        apply_fun: The forward pass of the Ansatz
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used (default: True)
        pdf: |ψ(x)|^2 if exact optimization is being used, else None
        chunk_size: an int specifying the size of the chunks in which the
            gradient should be computed (default: None)

    Returns:
        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at the
            samples σ, divided by √n; None
        else:
            the same pytree, but with the entries for each parameter normalised
            to unit norm; a pytree containing the norms that were divided out
            (same shape as params)
    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **model_state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C
        # centered_jacobian_fun = compose(stack_jacobian, centered_jacobian_cplx)

        # avoid converting to complex and then back
        # by passing around the oks as a tuple of two pytrees
        # representing the real and imag parts
        centered_jacobian_fun = compose(
            stack_jacobian_tuple,
            partial(centered_jacobian_cplx, _build_fn=lambda *x: x),
        )
        jacobian_fun = jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or '
            '"holomorphic", got {}'.format(mode)
        )

    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_real(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    if pdf is None:
        centered_oks = _divide_by_sqrt_n_samp(
            centered_jacobian_fun(
                f,
                params,
                samples,
                chunk_size=chunk_size,
            ),
            samples,
        )
    else:
        oks = jacobian_fun(f, params, samples)
        oks_mean = jax.tree_map(partial(jnp.sum, axis=0), _multiply_by_pdf(oks, pdf))
        centered_oks = jax.tree_map(lambda x, y: x - y, oks, oks_mean)
        centered_oks = _multiply_by_pdf(centered_oks, jnp.sqrt(pdf))

    if rescale_shift:
        return _rescale(centered_oks)
    else:
        return centered_oks, None
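
# Usage sketch (illustrative, not part of the library API): a toy log-amplitude
# ln ψ(σ) = Σₖ Wₖ σₖ with real parameters, differentiated in 'real' mode. The
# names `_toy_apply` and `_prepare_centered_oks_example` are hypothetical.
def _toy_apply(variables, σ):
    # plays the role of apply_fun: variables = {"params": ...}, σ a single sample
    return σ @ variables["params"]["W"]


def _prepare_centered_oks_example():
    params = {"W": jnp.ones(3)}
    # batched samples of shape (n_batches, batch, n_sites); un-batched internally
    samples = jnp.array([[[1.0, 1.0, -1.0]], [[1.0, -1.0, 1.0]]])
    centered_oks, _ = prepare_centered_oks(
        _toy_apply,
        params,
        samples,
        model_state={},
        mode="real",
        rescale_shift=False,
    )
    # centered log-derivatives ΔOⱼₖ, already divided by √n_samples
    return centered_oks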
        partial(_vjp, nondiff_argnums=nondiff_argnums, conjugate=conjugate),
        scan_fun=scan_fun,
        argnums=argnums,
    )(fun, cotangents, *primals)
    return _multimap(lambda c, l: _tree_unchunk(l) if c else l, append_cond, res)


def _gen_append_cond_vjp(primals, nondiff_argnums, chunk_argnums):
    diff_argnums = filter(lambda i: i not in nondiff_argnums, range(len(primals)))
    return tuple(map(lambda i: i in chunk_argnums, diff_argnums))


_gen_append_cond_value_vjp = compose(lambda t: (True,) + t, _gen_append_cond_vjp)

_vjp_fun_chunked = partial(
    __vjp_fun_chunked,
    _vjp=compose(lambda yr: yr[1:], _vjp),
    _append_cond_fun=_gen_append_cond_vjp,
)

_value_and_vjp_fun_chunked = compose(
    lambda yr: (yr[0], yr[1:]),
    partial(__vjp_fun_chunked, _vjp=_vjp, _append_cond_fun=_gen_append_cond_value_vjp),
)


def vjp_chunked(
    Returns:
        The Jacobian matrix ∂/∂pₖ ln Ψ(σⱼ) as a PyTree
    """

    def _jacobian_cplx(forward_fn, params, samples, _build_fn):
        y, vjp_fun = jax.vjp(single_sample(forward_fn), params, samples)
        gr, _ = vjp_fun(np.array(1.0, dtype=jnp.result_type(y)))
        gi, _ = vjp_fun(np.array(-1.0j, dtype=jnp.result_type(y)))
        return _build_fn(gr, gi)

    return vmap_chunked(
        _jacobian_cplx, in_axes=(None, None, 0, None), chunk_size=chunk_size
    )(forward_fn, params, samples, _build_fn)


centered_jacobian_real_holo = compose(tree_subtract_mean, jacobian_real_holo)
centered_jacobian_cplx = compose(tree_subtract_mean, jacobian_cplx)


def _divide_by_sqrt_n_samp(oks, samples):
    """
    divide Oⱼₖ by √n
    """
    n_samp = samples.shape[0] * mpi.n_nodes  # MPI
    return jax.tree_map(lambda x: x / np.sqrt(n_samp, dtype=x.dtype), oks)


def _multiply_by_pdf(oks, pdf):
    """
    Computes O'ⱼₖ = Oⱼₖ pⱼ.
    Used to multiply the log-derivatives by the probability density.
    """
    return jax.tree_map(
        lambda x: jax.lax.broadcast_in_dim(pdf, x.shape, (0,)) * x,
        oks,
    )
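
# Illustrative sketch (hypothetical helper, not part of the library): with an
# explicit, normalised probability density pⱼ, the weighted mean ⟨Oₖ⟩ = Σⱼ pⱼ Oⱼₖ
# is what the full-summation (pdf is not None) branch of prepare_centered_oks
# subtracts from Oⱼₖ before re-weighting by √pⱼ.
def _pdf_weighted_mean_example():
    oks = {"W": jnp.array([[1.0, 2.0], [3.0, 4.0]])}  # Oⱼₖ, j = sample index
    pdf = jnp.array([0.25, 0.75])  # normalised |ψ(σⱼ)|²
    oks_mean = jax.tree_map(lambda x: x.sum(axis=0), _multiply_by_pdf(oks, pdf))
    # oks_mean["W"] ≈ [0.25·1 + 0.75·3, 0.25·2 + 0.75·4] = [2.5, 3.5]
    return oks_mean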