def _solve(
    self: QGTJacobianPyTreeT, solve_fun, y: PyTree, *, x0: Optional[PyTree] = None
) -> PyTree:
    """
    Solve the linear system defined by this operator for the right-hand side `y`.

    Args:
        self: the operator; this function reads `mode`, `params` and `scale`
            and calls `replace` on it.
        solve_fun: callable invoked as `solve_fun(operator, y, x0=x0)` that must
            return a pair `(solution, info)`.
        y: the right-hand side, a PyTree checked against `self.params` with
            `check_valid_vector_type`.
        x0: optional initial guess, preprocessed the same way as `y`.

    Returns:
        A tuple `(out, info)`: the solution (reassembled from the
        real-imaginary split when one was applied) and the second output of
        `solve_fun`, passed through untouched. Note that this is a tuple even
        though the return annotation only mentions a PyTree.
    """
    # Real-imaginary split RHS in R→R and R→C modes
    if self.mode != "holomorphic":
        y, reassemble = nkjax.tree_to_real(y)
        if x0 is not None:
            x0, _ = nkjax.tree_to_real(x0)

    check_valid_vector_type(self.params, y)

    # `tree_multimap` is deprecated in JAX; `tree_map` accepts several trees.
    if self.scale is not None:
        y = jax.tree_util.tree_map(jnp.divide, y, self.scale)
        if x0 is not None:
            x0 = jax.tree_util.tree_map(jnp.multiply, x0, self.scale)

    # Pass the LinearOperator itself down to the solver, but with
    # scale=None so the rescaling done above is not applied a second time,
    # and with _in_solve=True so that `_matmul` skips the real-imaginary
    # split (the vectors handed to the solver are already split).
    unscaled_self = self.replace(scale=None, _in_solve=True)

    out, info = solve_fun(unscaled_self, y, x0=x0)

    if self.scale is not None:
        out = jax.tree_util.tree_map(jnp.divide, out, self.scale)

    # Reassemble real-imaginary split as needed
    if self.mode != "holomorphic":
        out = reassemble(out)

    return out, info
def DeltaOdagger_DeltaO_v(forward_fn, params, samples, v, holomorphic=True):
    r"""
    Apply the centered operator to a vector, i.e. compute
    \langle \Delta O^\dagger \Delta O \rangle v
    with \Delta O = O - \langle O \rangle.

    Args:
        forward_fn: callable invoked as `forward_fn(params, samples)`.
        params: pytree of parameters.
        samples: samples handed unchanged to `O_mean` and `Odagger_O_v`.
        v: the pytree vector the operator is applied to.
        holomorphic: forwarded to `O_mean`; also decides, together with the
            parameter dtypes, whether the parameters are split to real.

    Returns:
        The result of `Odagger_O_v` on the centered forward function,
        reassembled if the parameters had been split to real.
    """
    is_homogeneous = nkjax.tree_ishomogeneous(params)
    has_real_params = not nkjax.tree_leaf_iscomplex(params)

    # Everything except R->R, holomorphic C->C and R->C goes through the
    # real-imaginary split of the parameters.
    needs_split = not (is_homogeneous and (has_real_params or holomorphic))

    if needs_split:
        params, reassemble = nkjax.tree_to_real(params)
        v, _ = nkjax.tree_to_real(v)

        def apply_fn(p, x):
            # the wrapped function still expects the original parameters
            return forward_fn(reassemble(p), x)

    else:
        apply_fn = forward_fn

    omean = O_mean(apply_fn, params, samples, holomorphic=holomorphic)

    def centered_apply_fn(p, x):
        return apply_fn(p, x) - tree_dot(p, omean)

    res = Odagger_O_v(centered_apply_fn, params, samples, v)

    return reassemble(res) if needs_split else res
def vec_to_real(vec: Array) -> Tuple[Array, Callable]:
    """
    Convert a dense vector to a real representation.

    If the input vector is complex, its real and imaginary parts are
    concatenated along the 0-th axis (this is equivalent to changing the
    complex storage from AOS to SOA). Otherwise the output of
    `nkjax.tree_to_real` is returned as is.

    Args:
        vec: a dense vector

    Returns:
        A tuple with the converted vector and a function mapping a vector in
        the converted layout back through the `nkjax.tree_to_real` reassembler.
    """
    split, rebuild = nkjax.tree_to_real(vec)

    if not nkjax.is_complex(vec):
        # nothing to concatenate
        return split, rebuild

    real_part, imag_part = split
    stacked = jnp.concatenate([real_part, imag_part], axis=0)

    def rebuild_from_stacked(x):
        # undo the concatenation: first half is real, second half is imaginary
        halves = jnp.split(x, 2, axis=0)
        return rebuild(tuple(halves))

    return stacked, rebuild_from_stacked
def test_matvec_treemv_modes(e, jit, holomorphic, pardtype, outdtype):
    """
    Check `qgt_jacobian_pytree_logic.mat_vec` against the dense reference
    `S_real @ v + diag_shift * v` built from the example `e`, for the mode
    selected by the parameter and output dtypes.
    """
    diag_shift = 0.01

    def apply_fun(params, samples):
        return e.f(params["params"], samples)

    # Pick the differentiation mode from the dtypes: real output -> "real";
    # homogeneous complex parameters with a holomorphic function ->
    # "holomorphic"; anything else -> "complex".
    if not nkjax.is_complex_dtype(outdtype):
        mode = "real"
    else:
        is_holomorphic = (
            pardtype is not None
            and nkjax.is_complex_dtype(pardtype)
            and holomorphic
        )
        mode = "holomorphic" if is_holomorphic else "complex"

    # Outside the holomorphic mode the vector is split into real and imaginary parts
    if mode == "holomorphic":
        v = e.v

        def reassemble(x):
            return x

    else:
        v, reassemble = nkjax.tree_to_real(e.v)

    mat_vec = qgt_jacobian_pytree_logic.mat_vec
    mv = jax.jit(mat_vec) if jit else mat_vec

    centered_oks, _ = qgt_jacobian_pytree_logic.prepare_centered_oks(
        apply_fun, e.params, e.samples, {}, mode, False
    )

    actual = reassemble(mv(v, centered_oks, diag_shift))
    dense_result = e.S_real @ e.v_real_flat + diag_shift * e.v_real_flat
    expected = reassemble_complex(dense_result, target=e.target)

    assert tree_allclose(actual, expected)
def _matmul(
    self: QGTJacobianPyTreeT, vec: Union[PyTree, Array]
) -> Union[PyTree, Array]:
    """
    Apply the operator to `vec` through `mat_vec(vec, self.O, self.diag_shift)`.

    Args:
        self: the operator; this function reads `params`, `mode`, `_in_solve`,
            `scale`, `O` and `diag_shift`.
        vec: either a PyTree, or an object with an `ndim` attribute (a dense
            vector), which is first unraveled with the structure of
            `self.params`.

    Returns:
        The product in the same form as the input: raveled back into a
        vector if `vec` had an `ndim` attribute, a PyTree otherwise.
    """
    # Turn vector RHS into PyTree
    if hasattr(vec, "ndim"):
        _, unravel = nkjax.tree_ravel(self.params)
        vec = unravel(vec)
        ravel = True
    else:
        ravel = False

    # Real-imaginary split RHS in R→R and R→C modes.
    # Skipped when `_in_solve` is set: `_solve` has already split the vectors.
    reassemble = None
    if self.mode != "holomorphic" and not self._in_solve:
        vec, reassemble = nkjax.tree_to_real(vec)

    check_valid_vector_type(self.params, vec)

    # `tree_multimap` is deprecated in JAX; `tree_map` accepts several trees.
    if self.scale is not None:
        vec = jax.tree_util.tree_map(jnp.multiply, vec, self.scale)

    result = mat_vec(vec, self.O, self.diag_shift)

    if self.scale is not None:
        result = jax.tree_util.tree_map(jnp.multiply, result, self.scale)

    # Reassemble real-imaginary split as needed
    if reassemble is not None:
        result = reassemble(result)

    # Ravel PyTree back into vector as needed
    if ravel:
        result, _ = nkjax.tree_ravel(result)

    return result
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    pdf=None,
    chunk_size: Optional[int] = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩
    divided by √n

    In a somewhat intransparent way this also internally splits all parameters
    to real in the 'real' and 'complex' modes (for C→R, R&C→R, R&C→C and
    general C→C), resulting in the respective ΔOⱼₖ which is only compatible
    with split-to-real pytree vectors.

    Args:
        apply_fun: The forward pass of the Ansatz, called as
            `apply_fun({"params": W, **model_state}, σ)`
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ; it is reshaped
            to two dimensions, keeping the last axis
        model_state: untrained state parameters of the model; `None` is
            treated as an empty state
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used
        pdf: |ψ(x)|^2 if exact optimization is being used else None
        chunk_size: an int specifying the size of the chunks the gradient
            should be computed in (default: None). It is only forwarded when
            `pdf` is None.

    Returns:
        A tuple of two elements.

        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at
            the samples σ, divided by √n; None
        else:
            the output of `_rescale`: the same pytree, but the entries for
            each parameter normalised to unit norm; pytree containing the
            norms that were divided out (same shape as params)

    Raises:
        NotImplementedError: if `mode` is not one of the three supported modes.
    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # The signature allows `model_state=None`, but `**None` raises a
    # TypeError, so fall back to an empty mapping in that case.
    state = {} if model_state is None else model_state

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C

        # avoid converting to complex and then back
        # by passing around the oks as a tuple of two pytrees representing
        # the real and imag parts
        centered_jacobian_fun = compose(
            stack_jacobian_tuple,
            partial(centered_jacobian_cplx, _build_fn=lambda *x: x),
        )
        jacobian_fun = jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or '
            f'"holomorphic", got {mode}'
        )

    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_real(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    if pdf is None:
        centered_oks = _divide_by_sqrt_n_samp(
            centered_jacobian_fun(
                f,
                params,
                samples,
                chunk_size=chunk_size,
            ),
            samples,
        )
    else:
        # NOTE(review): `chunk_size` is not forwarded on this path.
        oks = jacobian_fun(f, params, samples)
        # `sum` is called with `axis=0`, so it is presumably a project-level
        # reduction shadowing the builtin -- verify against the file's imports.
        oks_mean = jax.tree_util.tree_map(
            partial(sum, axis=0), _multiply_by_pdf(oks, pdf)
        )
        centered_oks = jax.tree_util.tree_map(lambda x, y: x - y, oks, oks_mean)
        centered_oks = _multiply_by_pdf(centered_oks, jnp.sqrt(pdf))

    if rescale_shift:
        return _rescale(centered_oks)
    else:
        return centered_oks, None