def _matmul(
    self: QGTJacobianPyTreeT, vec: Union[PyTree, Array]
) -> Union[PyTree, Array]:
    """
    Multiply the geometric tensor by ``vec``.

    Args:
        vec: either a PyTree, or a flat array (anything exposing ``ndim``),
            which is unravelled using the structure of ``self.params``.

    Returns:
        The product, ravelled back into a flat array if ``vec`` was given
        as one, and as a PyTree otherwise.
    """
    # `jax.tree_multimap` no longer exists in recent JAX releases, where
    # `jax.tree_util.tree_map` accepts several trees instead. Keep using the
    # old name wherever it is still available so older JAX behaves as before.
    multi_tree_map = getattr(jax, "tree_multimap", jax.tree_util.tree_map)

    # Turn vector RHS into PyTree
    ravel = hasattr(vec, "ndim")
    if ravel:
        _, unravel = nkjax.tree_ravel(self.params)
        vec = unravel(vec)

    # Real-imaginary split RHS in R→R and R→C modes
    reassemble = None
    if self.mode != "holomorphic" and not self._in_solve:
        vec, reassemble = nkjax.tree_to_real(vec)

    check_valid_vector_type(self.params, vec)

    # The scale is applied leaf-wise on both sides of the product
    if self.scale is not None:
        vec = multi_tree_map(jnp.multiply, vec, self.scale)

    result = mat_vec(vec, self.O, self.diag_shift)

    if self.scale is not None:
        result = multi_tree_map(jnp.multiply, result, self.scale)

    # Reassemble real-imaginary split as needed
    if reassemble is not None:
        result = reassemble(result)

    # Ravel PyTree back into vector as needed
    if ravel:
        result, _ = nkjax.tree_ravel(result)

    return result
def _solve(
    self: QGTJacobianDenseT, solve_fun, y: PyTree, *, x0: Optional[PyTree] = None
) -> PyTree:
    """
    Solve the linear system with right-hand side ``y`` using ``solve_fun``.

    Args:
        solve_fun: callable invoked as ``solve_fun(operator, y, x0=x0)`` and
            returning a ``(solution, info)`` pair.
        y: the right-hand side, as a PyTree.
        x0: optional initial guess, as a PyTree.

    Returns:
        A tuple ``(solution, info)`` where the solution is unravelled with
        the structure of ``y`` and ``info`` is whatever ``solve_fun`` returned.
    """
    # Ravel input PyTrees, record unravelling function too
    y, unravel = nkjax.tree_ravel(y)

    real_split = self.mode != "holomorphic"
    if real_split:
        y, reassemble = vec_to_real(y)

    if x0 is not None:
        x0, _ = nkjax.tree_ravel(x0)
        # The initial guess has to be expressed in the same representation
        # as `y`, otherwise the solver receives inconsistent vectors.
        if real_split:
            x0, _ = vec_to_real(x0)
        if self.scale is not None:
            x0 = x0 * self.scale

    if self.scale is not None:
        y = y / self.scale

    # to pass the object LinearOperator itself down
    # but avoid rescaling, we pass down an object with
    # scale = None
    unscaled_self = self.replace(scale=None, _in_solve=True)

    out, info = solve_fun(unscaled_self, y, x0=x0)

    if self.scale is not None:
        out = out / self.scale

    if real_split:
        out = reassemble(out)

    return unravel(out), info
def onthefly_mat_treevec(
    S: QGTOnTheFly, vec: Union[PyTree, jnp.ndarray]
) -> Union[PyTree, jnp.ndarray]:
    """
    Perform the lazy mat-vec product, where vec is either a tree with the same
    structure as params or a ravelled vector

    Raises:
        ValueError: if ``vec`` is an array that is not one-dimensional, or
            whose size differs from the number of parameters.
    """
    # if it has an ndim it's an array and not a pytree
    ravel_result = hasattr(vec, "ndim")
    if ravel_result:
        if vec.ndim != 1:
            raise ValueError("Unsupported mat-vec for chunks of vectors")

        # If the input is a vector, its size must match the parameters
        n_params = nkjax.tree_size(S._params)
        if n_params != vec.size:
            # f-string so that the two sizes actually show up in the message
            raise ValueError(
                f"Size mismatch between number of parameters ({n_params}) "
                f"and vector size {vec.size}."
            )

        _, unravel = nkjax.tree_ravel(S._params)
        vec = unravel(vec)

    check_valid_vector_type(S._params, vec)

    vec = nkjax.tree_cast(vec, S._params)

    res = S._mat_vec(vec, S.diag_shift)

    if ravel_result:
        res, _ = nkjax.tree_ravel(res)

    return res
def _to_dense(self: QGTJacobianPyTreeT) -> jnp.ndarray:
    """
    Build the dense matrix corresponding to this operator.

    Every entry of ``self.O`` along the mapped axis is ravelled into a row,
    the rows are optionally rescaled by ``self.scale``, and the result is
    ``O^† O`` summed through ``mpi.mpi_sum_jax`` plus the diagonal shift.
    """

    def flatten(tree):
        return nkjax.tree_ravel(tree)[0]

    jacobian = jax.vmap(flatten)(self.O)

    if self.scale is not None:
        flat_scale, _ = nkjax.tree_ravel(self.scale)
        jacobian = jacobian * flat_scale[jnp.newaxis, :]
        # With a scale, the shift is weighted by the squared scale
        shift_matrix = jnp.diag(flat_scale**2)
    else:
        shift_matrix = jnp.eye(jacobian.shape[1])

    gram = mpi.mpi_sum_jax(jacobian.T.conj() @ jacobian)[0]
    return gram + self.diag_shift * shift_matrix
def lazysmatrix_mat_treevec(
    S: LazySMatrix, vec: Union[PyTree, jnp.ndarray]
) -> Union[PyTree, jnp.ndarray]:
    """
    Perform the lazy mat-vec product, where vec is either a tree with the same
    structure as params or a ravelled vector

    Raises:
        ValueError: if ``vec`` is an array that is not one-dimensional, or
            whose size differs from the number of parameters.
    """

    def fun(W, σ):
        return S.apply_fun({"params": W, **S.model_state}, σ)

    # if it has an ndim it's an array and not a pytree
    ravel_result = hasattr(vec, "ndim")
    if ravel_result:
        if vec.ndim != 1:
            raise ValueError("Unsupported mat-vec for batches of vectors")

        # If the input is a vector, its size must match the parameters
        n_params = nkjax.tree_size(S.params)
        if n_params != vec.size:
            # f-string so that the two sizes actually show up in the message
            raise ValueError(
                f"Size mismatch between number of parameters ({n_params}) "
                f"and vector size {vec.size}."
            )

        _, unravel = nkjax.tree_ravel(S.params)
        vec = unravel(vec)

    # Collapse any leading axes of the samples into a single one
    samples = S.samples
    if jnp.ndim(samples) != 2:
        samples = samples.reshape((-1, samples.shape[-1]))

    vec = tree_cast(vec, S.params)

    mat_vec = partial(
        mat_vec_onthefly,
        forward_fn=fun,
        params=S.params,
        samples=samples,
        diag_shift=S.sr.diag_shift,
        centered=S.sr.centered,
    )

    res = mat_vec(vec)

    if ravel_result:
        res, _ = nkjax.tree_ravel(res)

    return res
def _matmul(
    self: QGTJacobianDenseT, vec: Union[PyTree, jnp.ndarray]
) -> Union[PyTree, jnp.ndarray]:
    """
    Multiply the geometric tensor by ``vec``.

    ``vec`` may be a flat array (anything exposing ``ndim``) or a PyTree; a
    PyTree is ravelled before the product and the result is unravelled back
    into the same structure.
    """
    from_pytree = not hasattr(vec, "ndim")
    if from_pytree:
        vec, unravel = nkjax.tree_ravel(vec)

    # Real-imaginary split RHS in R→R and R→C modes
    split = self.mode != "holomorphic" and not self._in_solve
    if split:
        vec, reassemble = vec_to_real(vec)

    scaled = self.scale is not None
    if scaled:
        vec = vec * self.scale

    # O^† (O v), written through conjugate transposes, then summed via MPI
    o_vec = self.O @ vec
    local_product = (o_vec.T.conj() @ self.O).T.conj()
    result = mpi.mpi_sum_jax(local_product)[0] + self.diag_shift * vec

    if scaled:
        result = result * self.scale

    if split:
        result = reassemble(result)

    if from_pytree:
        result = unravel(result)

    return result
def solve(A, b, sym_pos=True, x0=None):
    """
    Solve ``Ax = b`` by densifying ``A`` and calling ``jsp.linalg.solve``.

    Args:
        A: the operator; its ``to_dense()`` method provides the matrix
        b: the right-hand side, ravelled with ``tree_ravel``
        sym_pos: forwarded unchanged to ``jsp.linalg.solve``
        x0: unused

    Returns:
        A tuple ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """
    del x0  # the initial guess is discarded

    dense = A.to_dense()
    rhs, unravel = tree_ravel(b)

    solution = jsp.linalg.solve(dense, rhs, sym_pos=sym_pos)
    return unravel(solution), None
def LU(A, b, trans=0, x0=None):
    """
    Solve the linear system using an LU factorisation of the dense matrix.

    Args:
        A: the operator; its ``to_dense()`` method provides the matrix
        b: the right-hand side, ravelled with ``tree_ravel``
        trans: forwarded to ``jsp.linalg.lu_solve`` to select which system
            (plain, transposed, ...) is solved
        x0: unused

    Returns:
        A tuple ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """
    del x0  # the initial guess is discarded

    A = A.to_dense()
    b, unravel = tree_ravel(b)

    lu, piv = jsp.linalg.lu_factor(A)
    # Honour the caller's `trans` rather than always solving the plain system
    x = jsp.linalg.lu_solve((lu, piv), b, trans=trans)
    return unravel(x), None
def cholesky(A, b, lower=False, x0=None):
    """
    Solve ``Ax = b`` through a Cholesky factorisation of the dense matrix.

    Args:
        A: the operator; its ``to_dense()`` method provides the matrix
        b: the right-hand side, ravelled with ``tree_ravel``
        lower: forwarded to ``jsp.linalg.cho_factor``
        x0: unused

    Returns:
        A tuple ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """
    del x0  # the initial guess is discarded

    dense = A.to_dense()
    rhs, unravel = tree_ravel(b)

    factor, is_lower = jsp.linalg.cho_factor(dense, lower=lower)
    solution = jsp.linalg.cho_solve((factor, is_lower), rhs)
    return unravel(solution), None
def svd(A, b, rcond=None, x0=None):
    """
    Solve the linear system using Singular Value Decomposition.
    The diagonal shift on the matrix should be 0.

    The dense matrix is passed to ``jnp.linalg.lstsq``.

    Args:
        A: the matrix A in Ax=b
        b: the vector b in Ax=b
        rcond: forwarded to ``jnp.linalg.lstsq`` as its ``rcond`` argument
        x0: unused

    Returns:
        A tuple ``(x, (residuals, rank, s))`` where ``x`` is unravelled to
        the structure of ``b`` and the rest is what ``lstsq`` reported.
    """
    del x0  # the initial guess is discarded

    dense = A.to_dense()
    rhs, unravel = tree_ravel(b)

    solution, *diagnostics = jnp.linalg.lstsq(dense, rhs, rcond=rcond)
    return unravel(solution), tuple(diagnostics)
def solve(A, b, sym_pos=True, x0=None):
    """
    Solve the linear system.
    The diagonal shift on the matrix should be 0.

    Internally densifies ``A`` and calls ``jsp.linalg.solve``.

    Args:
        A: the matrix A in Ax=b
        b: the vector b in Ax=b
        sym_pos: forwarded unchanged to ``jsp.linalg.solve``
        x0: unused

    Returns:
        A tuple ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """
    del x0  # the initial guess is discarded

    dense = A.to_dense()
    rhs, unravel = tree_ravel(b)

    solution = jsp.linalg.solve(dense, rhs, sym_pos=sym_pos)
    return unravel(solution), None
def LU(A, b, trans=0, x0=None):
    """
    Solve the linear system using a LU Factorisation.
    The diagonal shift on the matrix should be 0.

    Internally uses ``jsp.linalg.lu_factor`` and ``jsp.linalg.lu_solve``.

    Args:
        A: the matrix A in Ax=b
        b: the vector b in Ax=b
        trans: forwarded to ``jsp.linalg.lu_solve`` to select which system
            (plain, transposed, ...) is solved
        x0: unused

    Returns:
        A tuple ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """
    del x0  # the initial guess is discarded

    A = A.to_dense()
    b, unravel = tree_ravel(b)

    lu, piv = jsp.linalg.lu_factor(A)
    # Honour the caller's `trans` rather than always solving the plain system
    x = jsp.linalg.lu_solve((lu, piv), b, trans=trans)
    return unravel(x), None
def cholesky(A, b, lower=False, x0=None):
    """
    Solve the linear system using a Cholesky Factorisation.
    The diagonal shift on the matrix should be 0.

    Internally uses ``jsp.linalg.cho_factor`` and ``jsp.linalg.cho_solve``.

    Args:
        A: the matrix A in Ax=b
        b: the vector b in Ax=b
        lower: if True uses the lower half of the A matrix
        x0: unused

    Returns:
        A tuple ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """
    del x0  # the initial guess is discarded

    dense = A.to_dense()
    rhs, unravel = tree_ravel(b)

    factor, is_lower = jsp.linalg.cho_factor(dense, lower=lower)
    solution = jsp.linalg.cho_solve((factor, is_lower), rhs)
    return unravel(solution), None
def ravel(x: PyTree) -> Array:
    """
    Return only the flattened array produced by ``nkjax.tree_ravel``,
    dropping the unravelling function.
    """
    return nkjax.tree_ravel(x)[0]