def onthefly_mat_treevec(
    S: QGTOnTheFly, vec: Union[PyTree, jnp.ndarray]
) -> Union[PyTree, jnp.ndarray]:
    """
    Perform the lazy mat-vec product, where vec is either a tree with the same
    structure as params or a ravelled vector.

    Args:
        S: The lazy on-the-fly S matrix. Its ``_params`` pytree fixes the parameter
            structure, and its ``_mat_vec`` method is called (together with
            ``S.diag_shift``) to compute the product.
        vec: Either a pytree with the same structure as ``S._params``, or a 1D
            array whose size equals the total number of parameters.

    Returns:
        The product, in the same format as ``vec``: a pytree if ``vec`` was a
        pytree, a ravelled 1D array if ``vec`` was an array.

    Raises:
        ValueError: If ``vec`` is an array that is not 1D, or whose size does not
            match the number of parameters.
    """
    # If it has an ndim attribute it's an array and not a pytree.
    if hasattr(vec, "ndim"):
        if vec.ndim != 1:
            raise ValueError("Unsupported mat-vec for chunks of vectors")
        # The input is a flat vector: it must hold exactly one entry per parameter.
        n_params = nkjax.tree_size(S._params)
        if n_params != vec.size:
            raise ValueError(
                f"Size mismatch between number of parameters ({n_params}) "
                f"and vector size {vec.size}."
            )
        # Unravel into the params structure; the result is ravelled back at the end.
        _, unravel = nkjax.tree_ravel(S._params)
        vec = unravel(vec)
        ravel_result = True
    else:
        ravel_result = False

    check_valid_vector_type(S._params, vec)

    # Presumably aligns vec's leaf dtypes with those of the params; verify in nkjax.
    vec = nkjax.tree_cast(vec, S._params)

    res = S._mat_vec(vec, S.diag_shift)

    if ravel_result:
        res, _ = nkjax.tree_ravel(res)

    return res
def to_dense(self) -> jnp.ndarray:
    """
    Convert the lazy matrix representation to a dense matrix representation.

    The matrix is built by applying ``self @`` to every row of the identity
    matrix of size ``nkjax.tree_size(self.params)``.

    Returns:
        A dense matrix representation of this S matrix.
    """
    Npars = nkjax.tree_size(self.params)
    I = jax.numpy.eye(Npars)
    # Row i of `out` is S @ e_i, i.e. the i-th *column* of S, so `out` is S^T.
    out = jax.vmap(lambda x: self @ x, in_axes=0)(I)
    # For complex matrices S^T != S in general, so transpose to recover S.
    # In the real case the transpose is skipped, matching the other dense
    # conversions; this assumes S is symmetric there.
    if nkjax.is_complex(out):
        out = out.T
    return out
def _to_dense(self: QGTOnTheFlyT) -> jnp.ndarray:
    """
    Materialize this lazy S matrix as a dense array.

    Every row of an identity matrix (sized by ``nkjax.tree_size(self.params)``)
    is pushed through ``self @``, and the results are stacked along axis 0.

    Returns:
        A dense matrix representation of this S matrix.
    """
    basis = jax.numpy.eye(nkjax.tree_size(self.params))

    def apply_to_basis_vector(e):
        return self @ e

    stacked = jax.vmap(apply_to_basis_vector, in_axes=0)(basis)
    # The stacked products hold S's columns as rows; transpose them back when complex.
    return stacked.T if nkjax.is_complex(stacked) else stacked
def lazysmatrix_mat_treevec(
    S: LazySMatrix, vec: Union[PyTree, jnp.ndarray]
) -> Union[PyTree, jnp.ndarray]:
    """
    Perform the lazy mat-vec product, where vec is either a tree with the same
    structure as params or a ravelled vector.

    Args:
        S: The lazy S matrix. Its ``apply_fun``, ``model_state``, ``params``,
            ``samples`` and ``sr`` (``diag_shift``, ``centered``) are forwarded
            to ``mat_vec_onthefly``.
        vec: Either a pytree with the same structure as ``S.params``, or a 1D
            array whose size equals the total number of parameters.

    Returns:
        The product, in the same format as ``vec``: a pytree if ``vec`` was a
        pytree, a ravelled 1D array if ``vec`` was an array.

    Raises:
        ValueError: If ``vec`` is an array that is not 1D, or whose size does not
            match the number of parameters.
    """

    # Forward function evaluated with trial parameters W, keeping the model state fixed.
    def fun(W, σ):
        return S.apply_fun({"params": W, **S.model_state}, σ)

    # If it has an ndim attribute it's an array and not a pytree.
    if hasattr(vec, "ndim"):
        if vec.ndim != 1:
            raise ValueError("Unsupported mat-vec for batches of vectors")
        # The input is a flat vector: it must hold exactly one entry per parameter.
        n_params = nkjax.tree_size(S.params)
        if n_params != vec.size:
            raise ValueError(
                f"Size mismatch between number of parameters ({n_params}) "
                f"and vector size {vec.size}."
            )
        # Unravel into the params structure; the result is ravelled back at the end.
        _, unravel = nkjax.tree_ravel(S.params)
        vec = unravel(vec)
        ravel_result = True
    else:
        ravel_result = False

    # Flatten any leading batch axes so samples become a 2D array, keeping the last axis.
    samples = S.samples
    if jnp.ndim(samples) != 2:
        samples = samples.reshape((-1, samples.shape[-1]))

    vec = tree_cast(vec, S.params)

    mat_vec = partial(
        mat_vec_onthefly,
        forward_fn=fun,
        params=S.params,
        samples=samples,
        diag_shift=S.sr.diag_shift,
        centered=S.sr.centered,
    )
    res = mat_vec(vec)

    if ravel_result:
        res, _ = nkjax.tree_ravel(res)

    return res
def _to_dense(self: QGTOnTheFlyT) -> jnp.ndarray:
    """
    Materialize this lazy S matrix as a dense array.

    Every row of an identity matrix (sized by ``nkjax.tree_size(self._params)``)
    is pushed through ``self @``, and the results are stacked along axis 0.
    With chunking enabled the rows are processed sequentially via
    ``jax.lax.scan``; otherwise they are batched with ``jax.vmap``.

    Returns:
        A dense matrix representation of this S matrix.
    """
    basis = jax.numpy.eye(nkjax.tree_size(self._params))

    if not self._chunking:
        stacked = jax.vmap(lambda e: self @ e, in_axes=0)(basis)
    else:
        # The linear_call inside the chunked mat-vec has no jax batching rule,
        # so it cannot be vmapped. A scan works instead, and it also keeps
        # memory consumption lower.
        def scan_body(carry, e):
            return carry, self @ e

        _, stacked = jax.lax.scan(scan_body, None, basis)

    # The stacked products hold S's columns as rows; transpose them back when complex.
    if nkjax.is_complex(stacked):
        stacked = stacked.T
    return stacked
def n_parameters(self) -> int:
    r"""Return the total number of parameters in the model, as counted by
    ``nkjax.tree_size`` over ``self.parameters``."""
    total = nkjax.tree_size(self.parameters)
    return total