def Odagger_DeltaO_v(forward_fn, params, samples, v):
    r"""
    Compute \langle O^\dagger \Delta O \rangle v for chunked samples.

    The per-sample values O v are scaled by 1 / (total sample count) and
    centred (mean subtracted, with MPI). They are then contracted back with
    O^\dagger, and the resulting pytree is summed over MPI ranks.

    NOTE(review): the normalisation uses ``samples.shape[0] * samples.shape[1]``,
    so ``samples`` presumably carries two leading batch axes (e.g. chunks x
    chunk size) — confirm against callers.
    """
    jvp_out = O_jvp(forward_fn, params, samples, v)
    n_total = samples.shape[0] * samples.shape[1] * mpi.n_nodes
    scaled = jvp_out * (1.0 / n_total)

    # Remove the chunk structure so the mean is taken over all samples at
    # once, then map the centred values back to the chunked layout
    # (presumably what ``unchunk`` and its returned function do — confirm).
    flat, rechunk = unchunk(scaled)
    centered = rechunk(subtract_mean(flat))  # w/ MPI

    partial_sums = OH_w(forward_fn, params, samples, centered)
    # Each rank only holds its own slice of samples: sum contributions (MPI).
    return jax.tree_map(lambda leaf: mpi.mpi_sum_jax(leaf)[0], partial_sums)
def mat_vec(jvp_fn, v, diag_shift):
    """
    Apply the centred, normalised Oᴴ O product to ``v`` and add a diagonal shift.

    Steps:
      1. w = jvp_fn(v), scaled by 1 / (w.size * number of MPI nodes).
      2. w is centred (mean subtracted, with MPI).
      3. Oᴴ w is obtained via the linear transpose of ``jvp_fn``.
      4. The result is summed over MPI ranks.

    Returns ``res + diag_shift * v`` (via ``tree_axpy``).
    """
    # Build the transpose up front to save linearisation work.
    # TODO move to mat_vec_factory after jax v0.2.19
    transpose_fn = jax.linear_transpose(jvp_fn, v)

    ov = jvp_fn(v)
    normalization = 1.0 / (ov.size * mpi.n_nodes)
    centered = subtract_mean(ov * normalization)  # w/ MPI

    # Oᴴw = (wᴴO)ᴴ = (w* O)*: 1D arrays are not transposed, so conjugate on
    # the way in and on the way out. The transposed function returns a
    # length-1 tuple, which is unpacked here.
    (local_res,) = tree_conj(transpose_fn(centered.conjugate()))
    summed = jax.tree_map(lambda leaf: mpi.mpi_sum_jax(leaf)[0], local_res)

    return tree_axpy(diag_shift, v, summed)  # res + diag_shift * v
def Odagger_O_v(forward_fn, params, samples, v, *, center=False):
    r"""
    Multiply ``v`` by the sample average of O^\dagger O (optionally centred).

    center=False (default): returns \langle O^\dagger O \rangle v
    center=True:            returns \langle O^\dagger \Delta O \rangle v,
                            with \Delta O = O - \langle O \rangle

    The result is NOT summed over MPI ranks here; each rank returns the
    contribution of its own samples.
    """
    # One entry per local sample; every MPI rank owns a separate slice, so
    # the normalisation uses the global count.
    per_sample = O_jvp(forward_fn, params, samples, v)
    global_count = samples.shape[0] * mpi.n_nodes
    per_sample = per_sample * (1.0 / global_count)

    # Mean subtraction (with MPI) only when the centred variant is requested.
    weights = subtract_mean(per_sample) if center else per_sample
    return OH_w(forward_fn, params, samples, weights)
def Odagger_O_v(samples, params, v, forward_fn, *, vjp_fun=None, center=False):
    r"""
    Multiply ``v`` by the sample average of O^\dagger O (optionally centred).

    center=False (default): returns \langle O^\dagger O \rangle v
    center=True:            returns \langle O^\dagger \Delta O \rangle v,
                            with \Delta O = O - \langle O \rangle

    ``vjp_fun`` may be supplied so that ``OH_w`` can reuse it instead of
    recomputing it.

    NOTE(review): a function named ``Odagger_O_v`` with a different argument
    order (forward_fn first) also appears in this source. If both are in one
    module, this later definition shadows the earlier one — confirm intent.
    """
    # Local slice of the per-sample values; each MPI rank has its own.
    ov = O_jvp(samples, params, v, forward_fn)
    local_count = samples.shape[0]
    ov = ov * (1.0 / (local_count * n_nodes))

    if not center:
        return OH_w(samples, params, ov, forward_fn, vjp_fun=vjp_fun)
    return OH_w(samples, params, subtract_mean(ov), forward_fn, vjp_fun=vjp_fun)  # w/ MPI
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    chunk_size: Optional[int] = None,
) -> PyTree:
    """
    Compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩ divided by √n.

    Here n is the total number of samples across all MPI ranks. In the 'real'
    and 'complex' modes, this also internally splits all parameters into real
    and imaginary parts (covering C→R, R&C→R, R&C→C and general C→C). The
    resulting ΔOⱼₖ is therefore only compatible with pytree vectors that were
    split to real in the same way.

    Args:
        apply_fun: The forward pass of the Ansatz.
        params: A pytree of parameters p.
        samples: An array of batched samples σ. All leading axes are
            flattened into a single sample axis.
        model_state: Untrained state parameters of the model, merged into
            the variables passed to ``apply_fun``. ``None`` is treated as an
            empty state.
        mode: Differentiation mode, must be one of 'real', 'complex',
            'holomorphic'.
        rescale_shift: Whether scale-invariant regularisation should be used.
        chunk_size: An int specifying the size of the chunks the gradient
            should be computed in (default: None).

    Returns:
        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at
            the samples σ, divided by √n; None
        else:
            the same pytree, but the entries for each parameter normalised
            to unit norm; pytree containing the norms that were divided out
            (same shape as params)

    Raises:
        NotImplementedError: if ``mode`` is not a supported mode.
    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # `model_state` is Optional; `**None` would raise TypeError, so a missing
    # state is treated as empty.
    state = {} if model_state is None else model_state

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        jacobian_fun = dense_jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C
        # Avoid converting to complex and then back by passing around the
        # oks as real and imaginary parts (handled by dense_jacobian_cplx).
        jacobian_fun = dense_jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        jacobian_fun = dense_jacobian_real_holo
    else:
        raise NotImplementedError(
            f'Differentiation mode should be one of "real", "complex", or "holomorphic", got {mode}'
        )

    # Stored as contiguous real stacked on top of contiguous imaginary (SOA)
    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_reim(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    def gradf_fun(W, σ):
        return jacobian_fun(f, W, σ)

    # Per-sample jacobians: params are shared (None), samples mapped on axis 0.
    jacobians = nkjax.vmap_chunked(gradf_fun, in_axes=(None, 0), chunk_size=chunk_size)(
        params, samples
    )

    # Normalise by √n with n the global (all-rank) sample count.
    n_samp = samples.shape[0] * mpi.n_nodes
    centered_oks = subtract_mean(jacobians, axis=0) / np.sqrt(
        n_samp, dtype=jacobians.dtype
    )

    # Collapse any extra leading axes into rows, keeping the parameter axis last.
    centered_oks = centered_oks.reshape(-1, centered_oks.shape[-1])

    if rescale_shift:
        return _rescale(centered_oks)
    return centered_oks, None