def jacobian_cplx(
    forward_fn: Callable,
    params: PyTree,
    samples: Array,
    chunk_size: Optional[int] = None,
    _build_fn: Callable = partial(jax.tree_util.tree_map, jax.lax.complex),
) -> PyTree:
    """Calculates Jacobian entries by vmapping grad.
    Assumes the function is R→C, backpropagates 1 and -1j

    Args:
        forward_fn: the log wavefunction ln Ψ
        params : a pytree of parameters p
        samples : an array of n samples σ
        chunk_size: forwarded unchanged to `vmap_chunked`
        _build_fn: callable combining the two backpropagated pytrees
            `(gr, gi)` into the result; the default applies
            `jax.lax.complex` leaf by leaf.

    Returns:
        The Jacobian matrix ∂/∂pₖ ln Ψ(σⱼ) as a PyTree
    """

    def _jacobian_cplx(forward_fn, params, samples, _build_fn):
        y, vjp_fun = jax.vjp(single_sample(forward_fn), params, samples)
        # Both cotangents are created with the dtype of the forward output.
        out_dtype = jnp.result_type(y)
        # Only the gradient w.r.t. the parameters is kept; the one w.r.t.
        # the sample is discarded.
        gr, _ = vjp_fun(np.array(1.0, dtype=out_dtype))
        gi, _ = vjp_fun(np.array(-1.0j, dtype=out_dtype))
        return _build_fn(gr, gi)

    # Map over the leading axis of `samples` only; everything else is
    # passed through unmapped.
    return vmap_chunked(
        _jacobian_cplx, in_axes=(None, None, 0, None), chunk_size=chunk_size
    )(forward_fn, params, samples, _build_fn)
def local_value_kernel_chunked(
    logpsi: Callable,
    pars: PyTree,
    σ: Array,
    args: PyTree,
    *,
    chunk_size: Optional[int] = None,
):
    """
    local_value kernel for MCState and generic operators

    `args` is unpacked as the pair `(σp, mels)`. The result is
    `sum(mels * exp(logpsi(σp) - logpsi(σ)))` taken over the last axis,
    with `logpsi(pars, ·)` evaluated through `nkjax.vmap_chunked` using
    the given `chunk_size`.
    """
    σp, mels = args
    n_last = σ.shape[-1]

    # Bring σp to rank 3 (leading axis taken from σ) and give mels the
    # matching leading shape.
    if jnp.ndim(σp) != 3:
        σp = σp.reshape((σ.shape[0], -1, n_last))
        mels = mels.reshape(σp.shape[:-1])

    evaluate = nkjax.vmap_chunked(
        partial(logpsi, pars), in_axes=0, chunk_size=chunk_size
    )

    def _eval_flat(configs, out_shape):
        # Flatten to 2D for the chunked evaluation, then restore the shape.
        return evaluate(configs.reshape((-1, n_last))).reshape(out_shape)

    # A trailing singleton axis lets the σ values broadcast against σp's.
    log_at_σ = _eval_flat(σ, σ.shape[:-1] + (1,))
    log_at_σp = _eval_flat(σp, σp.shape[:-1])

    weighted = mels * jnp.exp(log_at_σp - log_at_σ)
    return jnp.sum(weighted, axis=-1)
def jacobian_real_holo(
    forward_fn: Callable, params: PyTree, samples: Array, chunk_size: int = None
) -> PyTree:
    """Calculates Jacobian entries by vmapping grad.
    Assumes the function is R→R or holomorphic C→C, so single grad is enough

    Args:
        forward_fn: the log wavefunction ln Ψ
        params : a pytree of parameters p
        samples : an array of n samples σ
        chunk_size: forwarded unchanged to `vmap_chunked`

    Returns:
        The Jacobian matrix ∂/∂pₖ ln Ψ(σⱼ) as a PyTree
    """

    def _grad_one(fun, pars, sample):
        primal_out, pullback = jax.vjp(single_sample(fun), pars, sample)
        # Unit cotangent with the dtype of the forward output.
        unit = np.array(1.0, dtype=jnp.result_type(primal_out))
        # Keep the parameter gradient, drop the one w.r.t. the sample.
        grad_pars, _ = pullback(unit)
        return grad_pars

    # Only the samples are mapped (over their leading axis).
    batched_grad = vmap_chunked(
        _grad_one, in_axes=(None, None, 0), chunk_size=chunk_size
    )
    return batched_grad(forward_fn, params, samples)
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    chunk_size: Optional[int] = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩
    divided by √n

    In a somewhat intransparent way this also internally splits all parameters to real
    in the 'real' and 'complex' modes (for C→R, R&C→R, R&C→C and general C→C) resulting in the respective ΔOⱼₖ
    which is only compatible with split-to-real pytree vectors

    Args:
        apply_fun: The forward pass of the Ansatz
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model; `None` is
            treated as an empty state
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used
        chunk_size: an int specifying the size of the chunks the gradient
            should be computed in (default: None)

    Returns:
        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at the samples σ, divided by √n;
            None
        else:
            the result of `_rescale` applied to that centered jacobian
            (described by the original documentation as the same pytree
            with the entries for each parameter normalised to unit norm,
            plus a pytree containing the norms that were divided out)

    Raises:
        NotImplementedError: if `mode` is not one of the supported modes.
    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # `model_state` is declared Optional, but `**None` raises TypeError
    # inside `forward_fn`; fall back to an empty mapping instead.
    if model_state is None:
        model_state = {}

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **model_state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        jacobian_fun = dense_jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C
        jacobian_fun = dense_jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        jacobian_fun = dense_jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or "holomorphic", got {}'.format(
                mode
            )
        )

    # Stored as contiguous real stacked on top of contiguous imaginary (SOA)
    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_reim(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    def gradf_fun(params, σ):
        return jacobian_fun(f, params, σ)

    # Evaluate the jacobian per sample, mapping over the leading sample axis.
    jacobians = nkjax.vmap_chunked(
        gradf_fun, in_axes=(None, 0), chunk_size=chunk_size
    )(params, samples)

    # The local sample count is multiplied by `mpi.n_nodes` for the √n factor.
    n_samp = samples.shape[0] * mpi.n_nodes
    centered_oks = subtract_mean(jacobians, axis=0) / np.sqrt(
        n_samp, dtype=jacobians.dtype
    )

    # Collapse all leading axes so the result is two-dimensional.
    centered_oks = centered_oks.reshape(-1, centered_oks.shape[-1])

    if rescale_shift:
        return _rescale(centered_oks)
    return centered_oks, None
def _local_continuous_kernel(kernel, logpsi, pars, σ, args, *, chunk_size=None):
    """Apply `kernel(logpsi, pars, x, args)` to every entry `x` along the
    leading axis of `σ`, using `nkjax.vmap_chunked` with `chunk_size`.
    """
    # logpsi, pars and args are closed over; only the entries of σ are mapped.
    chunked_kernel = nkjax.vmap_chunked(
        lambda x: kernel(logpsi, pars, x, args),
        in_axes=0,
        chunk_size=chunk_size,
    )
    return chunked_kernel(σ)