Ejemplo n.º 1
0
def _matmul(self: QGTJacobianPyTreeT,
            vec: Union[PyTree, Array]) -> Union[PyTree, Array]:
    """
    Multiply this PyTree Jacobian QGT by ``vec``.

    The product is delegated to ``mat_vec(vec, self.O, self.diag_shift)``.
    When ``self.scale`` is set, ``vec`` is multiplied leaf-wise by the scale
    before the product and the result is multiplied by it again afterwards.

    Args:
        vec: either a PyTree with the same structure as ``self.params``, or
            an array (anything with an ``ndim`` attribute), which is
            unravelled into the structure of ``self.params`` first.

    Returns:
        The product, as a flat array if ``vec`` was an array, otherwise as a
        PyTree.
    """
    # Turn vector RHS into PyTree
    if hasattr(vec, "ndim"):
        _, unravel = nkjax.tree_ravel(self.params)
        vec = unravel(vec)
        ravel = True
    else:
        ravel = False

    # Real-imaginary split RHS in R→R and R→C modes. The split is skipped when
    # ``_in_solve`` is set; presumably the solve path has already split the
    # RHS itself (as the dense variant's ``_solve`` does) — confirm.
    reassemble = None
    if self.mode != "holomorphic" and not self._in_solve:
        vec, reassemble = nkjax.tree_to_real(vec)

    check_valid_vector_type(self.params, vec)

    # ``jax.tree_multimap`` was removed from JAX; ``tree_util.tree_map``
    # accepts several trees and is the supported replacement.
    if self.scale is not None:
        vec = jax.tree_util.tree_map(jnp.multiply, vec, self.scale)

    result = mat_vec(vec, self.O, self.diag_shift)

    if self.scale is not None:
        result = jax.tree_util.tree_map(jnp.multiply, result, self.scale)

    # Reassemble real-imaginary split as needed
    if reassemble is not None:
        result = reassemble(result)

    # Ravel PyTree back into vector as needed
    if ravel:
        result, _ = nkjax.tree_ravel(result)

    return result
Ejemplo n.º 2
0
def _solve(self: QGTJacobianDenseT,
           solve_fun,
           y: PyTree,
           *,
           x0: Optional[PyTree] = None) -> PyTree:
    """
    Solve the linear system defined by this dense Jacobian QGT with ``solve_fun``.

    ``y`` (and ``x0`` if given) are ravelled into dense vectors. Outside
    holomorphic mode they are split into real form with ``vec_to_real``.
    When ``self.scale`` is set, the system is rescaled so that ``solve_fun``
    works on the unscaled operator.

    Args:
        solve_fun: solver called as ``solve_fun(op, y, x0=x0)`` that returns
            ``(solution, info)``.
        y: right-hand side PyTree.
        x0: optional initial guess, a PyTree structured like ``y``.

    Returns:
        ``(x, info)``: the solution unravelled to the structure of ``y``, and
        whatever ``info`` ``solve_fun`` returned.
    """
    # Ravel input PyTrees, record unravelling function too
    y, unravel = nkjax.tree_ravel(y)

    if self.mode != "holomorphic":
        y, reassemble = vec_to_real(y)

    if x0 is not None:
        x0, _ = nkjax.tree_ravel(x0)
        # The initial guess must be in the same representation as ``y``,
        # which was real-split above, so split it the same way before scaling.
        if self.mode != "holomorphic":
            x0, _ = vec_to_real(x0)
        if self.scale is not None:
            x0 = x0 * self.scale

    # With D = diag(scale), the scaled operator is D M D, where M is the
    # unscaled one. Solving D M D x = y is equivalent to M z = y / scale with
    # z = x * scale. So y is divided, x0 is multiplied, and the solution is
    # divided by scale afterwards.
    if self.scale is not None:
        y = y / self.scale

    # to pass the object LinearOperator itself down
    # but avoid rescaling, we pass down an object with
    # scale = None
    unscaled_self = self.replace(scale=None, _in_solve=True)

    out, info = solve_fun(unscaled_self, y, x0=x0)

    if self.scale is not None:
        out = out / self.scale

    if self.mode != "holomorphic":
        out = reassemble(out)

    return unravel(out), info
Ejemplo n.º 3
0
def onthefly_mat_treevec(
        S: QGTOnTheFly, vec: Union[PyTree,
                                   jnp.ndarray]) -> Union[PyTree, jnp.ndarray]:
    """
    Perform the lazy mat-vec product, where vec is either a tree with the same structure as
    params or a ravelled vector.

    Args:
        S: the on-the-fly QGT; the product is computed by
            ``S._mat_vec(vec, S.diag_shift)``.
        vec: a PyTree with the same structure as ``S._params``, or a 1D array
            with as many entries as ``S._params`` has parameters.

    Returns:
        The product, as a flat array if ``vec`` was an array, otherwise as a
        PyTree.

    Raises:
        ValueError: if ``vec`` is an array that is not 1D, or whose size does
            not match the number of parameters.
    """
    # If vec has an ``ndim`` attribute it is an array, not a PyTree
    if hasattr(vec, "ndim"):
        if vec.ndim != 1:
            raise ValueError("Unsupported mat-vec for chunks of vectors")

        n_params = nkjax.tree_size(S._params)
        if n_params != vec.size:
            # Must be an f-string, otherwise the sizes are not interpolated.
            raise ValueError(
                f"Size mismatch between number of parameters ({n_params}) "
                f"and vector size {vec.size}.")

        _, unravel = nkjax.tree_ravel(S._params)
        vec = unravel(vec)
        ravel_result = True
    else:
        ravel_result = False

    check_valid_vector_type(S._params, vec)

    # presumably casts each leaf of vec to the dtype of the matching
    # parameter leaf — confirm against nkjax.tree_cast
    vec = nkjax.tree_cast(vec, S._params)

    res = S._mat_vec(vec, S.diag_shift)

    if ravel_result:
        res, _ = nkjax.tree_ravel(res)

    return res
Ejemplo n.º 4
0
def _to_dense(self: QGTJacobianPyTreeT) -> jnp.ndarray:
    """
    Build the dense matrix of this PyTree Jacobian QGT.

    Each slice of ``self.O`` along the leading (vmapped) axis is ravelled
    into a row, giving a 2D Jacobian matrix. If ``self.scale`` is set, its
    columns are multiplied by the ravelled scale. The returned matrix is the
    MPI-summed ``O^† O`` plus ``diag_shift`` times either the identity
    (unscaled) or ``diag(scale**2)`` (scaled).
    """
    def _row_to_vector(row):
        flat, _ = nkjax.tree_ravel(row)
        return flat

    jac = jax.vmap(_row_to_vector)(self.O)

    if self.scale is not None:
        scale_vec, _ = nkjax.tree_ravel(self.scale)
        jac = jac * scale_vec[jnp.newaxis, :]
        shift_matrix = jnp.diag(scale_vec**2)
    else:
        shift_matrix = jnp.eye(jac.shape[1])

    gram = mpi.mpi_sum_jax(jac.T.conj() @ jac)[0]
    return gram + self.diag_shift * shift_matrix
Ejemplo n.º 5
0
def lazysmatrix_mat_treevec(
        S: LazySMatrix, vec: Union[PyTree,
                                   jnp.ndarray]) -> Union[PyTree, jnp.ndarray]:
    """
    Perform the lazy mat-vec product, where vec is either a tree with the same structure as
    params or a ravelled vector.

    The product is computed by ``mat_vec_onthefly``, using the model's
    ``apply_fun`` as forward function and ``S.sr.diag_shift`` /
    ``S.sr.centered`` as options.

    Args:
        S: the lazy S matrix (model, parameters, samples and SR settings).
        vec: a PyTree with the same structure as ``S.params``, or a 1D array
            with as many entries as ``S.params`` has parameters.

    Returns:
        The product, as a flat array if ``vec`` was an array, otherwise as a
        PyTree.

    Raises:
        ValueError: if ``vec`` is an array that is not 1D, or whose size does
            not match the number of parameters.
    """
    def fun(W, σ):
        return S.apply_fun({"params": W, **S.model_state}, σ)

    # If vec has an ``ndim`` attribute it is an array, not a PyTree
    if hasattr(vec, "ndim"):
        if vec.ndim != 1:
            raise ValueError("Unsupported mat-vec for batches of vectors")

        n_params = nkjax.tree_size(S.params)
        if n_params != vec.size:
            # Must be an f-string, otherwise the sizes are not interpolated.
            raise ValueError(
                f"Size mismatch between number of parameters ({n_params}) "
                f"and vector size {vec.size}.")

        _, unravel = nkjax.tree_ravel(S.params)
        vec = unravel(vec)
        ravel_result = True
    else:
        ravel_result = False

    # Collapse any extra leading dimensions so that samples is 2D.
    samples = S.samples
    if jnp.ndim(samples) != 2:
        samples = samples.reshape((-1, samples.shape[-1]))

    vec = tree_cast(vec, S.params)

    mat_vec = partial(
        mat_vec_onthefly,
        forward_fn=fun,
        params=S.params,
        samples=samples,
        diag_shift=S.sr.diag_shift,
        centered=S.sr.centered,
    )

    res = mat_vec(vec)

    if ravel_result:
        res, _ = nkjax.tree_ravel(res)

    return res
Ejemplo n.º 6
0
def _matmul(self: QGTJacobianDenseT,
            vec: Union[PyTree, jnp.ndarray]) -> Union[PyTree, jnp.ndarray]:
    """
    Multiply this dense Jacobian QGT by ``vec``.

    PyTree inputs are ravelled to a dense vector and the result is unravelled
    back. Arrays (anything with ``ndim``) are used as-is. Outside holomorphic
    mode, and when not called from within a solve, the vector is real-split
    with ``vec_to_real`` and reassembled afterwards. When ``self.scale`` is
    set, the vector is scaled before and after the product.
    """
    if hasattr(vec, "ndim"):
        unravel = None
    else:
        vec, unravel = nkjax.tree_ravel(vec)

    # Real-imaginary split RHS in R→R and R→C modes
    split_needed = self.mode != "holomorphic" and not self._in_solve
    reassemble = None
    if split_needed:
        vec, reassemble = vec_to_real(vec)

    scale = self.scale
    if scale is not None:
        vec = vec * scale

    # ((O v)^† O)^† == O^† O v, summed over MPI ranks.
    local_product = ((self.O @ vec).T.conj() @ self.O).T.conj()
    result = mpi.mpi_sum_jax(local_product)[0] + self.diag_shift * vec

    if scale is not None:
        result = result * scale

    if reassemble is not None:
        result = reassemble(result)

    return result if unravel is None else unravel(result)
Ejemplo n.º 7
0
def solve(A, b, sym_pos=True, x0=None):
    """
    Solve ``A x = b`` with a dense direct solver.

    Args:
        A: linear operator exposing ``to_dense()``.
        b: right-hand side (PyTree or array), ravelled with ``tree_ravel``.
        sym_pos: forwarded to ``jsp.linalg.solve``; in the SciPy-compatible
            API, True means the matrix is assumed symmetric positive definite.
        x0: ignored; accepted for interface compatibility.

    Returns:
        ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """
    del x0  # a direct solve needs no initial guess

    dense_matrix = A.to_dense()
    rhs, unravel = tree_ravel(b)

    solution = jsp.linalg.solve(dense_matrix, rhs, sym_pos=sym_pos)
    return unravel(solution), None
Ejemplo n.º 8
0
def LU(A, b, trans=0, x0=None):
    """
    Solve the linear system using an LU factorisation of the dense matrix.

    Args:
        A: linear operator exposing ``to_dense()``.
        b: right-hand side (PyTree or array), ravelled with ``tree_ravel``.
        trans: forwarded to ``jsp.linalg.lu_solve``: 0 solves ``A x = b``,
            1 solves ``A^T x = b``, 2 solves ``A^H x = b``.
        x0: unused; accepted for interface compatibility.

    Returns:
        ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """
    del x0

    A = A.to_dense()
    b, unravel = tree_ravel(b)

    lu, piv = jsp.linalg.lu_factor(A)
    # Honour the caller's ``trans``; hard-coding 0 silently ignored it.
    x = jsp.linalg.lu_solve((lu, piv), b, trans=trans)
    return unravel(x), None
Ejemplo n.º 9
0
def cholesky(A, b, lower=False, x0=None):
    """
    Solve ``A x = b`` using a Cholesky factorisation of the dense matrix.

    Args:
        A: linear operator exposing ``to_dense()``; the dense matrix should be
            Hermitian positive definite, as Cholesky requires.
        b: right-hand side (PyTree or array), ravelled with ``tree_ravel``.
        lower: forwarded to ``jsp.linalg.cho_factor``; if True the lower
            triangle is used.
        x0: unused; accepted for interface compatibility.

    Returns:
        ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """
    del x0

    dense_matrix = A.to_dense()
    rhs, unravel = tree_ravel(b)

    factor, is_lower = jsp.linalg.cho_factor(dense_matrix, lower=lower)
    solution = jsp.linalg.cho_solve((factor, is_lower), rhs)
    return unravel(solution), None
Ejemplo n.º 10
0
def svd(A, b, rcond=None, x0=None):
    """
    Solve the linear system in the least-squares sense using an SVD-based
    ``jnp.linalg.lstsq``.
    The diagonal shift on the matrix should be 0.

    Args:
        A: linear operator exposing ``to_dense()``; the matrix A in Ax=b.
        b: right-hand side (PyTree or array); the vector b in Ax=b.
        rcond: cutoff ratio for small singular values, forwarded to
            ``jnp.linalg.lstsq`` (values below ``rcond`` times the largest
            singular value are treated as zero).
        x0: unused; accepted for interface compatibility.

    Returns:
        ``(x, (residuals, rank, s))``: the solution unravelled to the
        structure of ``b``, plus the diagnostic outputs of ``lstsq``.
    """
    del x0

    dense_matrix = A.to_dense()
    rhs, unravel = tree_ravel(b)

    solution, residuals, rank, singular_values = jnp.linalg.lstsq(
        dense_matrix, rhs, rcond=rcond)

    return unravel(solution), (residuals, rank, singular_values)
Ejemplo n.º 11
0
def solve(A, b, sym_pos=True, x0=None):
    """
    Solve the linear system with a dense direct solver.
    The diagonal shift on the matrix should be 0.

    Internally uses ``jsp.linalg.solve`` (the SciPy-compatible
    ``jax.scipy.linalg.solve``).

    Args:
        A: linear operator exposing ``to_dense()``; the matrix A in Ax=b.
        b: right-hand side (PyTree or array); the vector b in Ax=b.
        sym_pos: forwarded to ``jsp.linalg.solve``; if True the matrix is
            assumed symmetric positive definite.
        x0: unused
    
    Returns:
        ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """
    del x0

    dense_matrix = A.to_dense()
    rhs, unravel = tree_ravel(b)

    solution = jsp.linalg.solve(dense_matrix, rhs, sym_pos=sym_pos)
    return unravel(solution), None
Ejemplo n.º 12
0
def LU(A, b, trans=0, x0=None):
    """
    Solve the linear system using a LU Factorisation.
    The diagonal shift on the matrix should be 0.

    Internally uses ``jsp.linalg.lu_factor`` / ``jsp.linalg.lu_solve``
    (``jax.scipy.linalg``).

    Args:
        A: linear operator exposing ``to_dense()``; the matrix A in Ax=b.
        b: right-hand side (PyTree or array); the vector b in Ax=b.
        trans: forwarded to ``lu_solve``: 0 solves ``A x = b``, 1 solves
            ``A^T x = b``, 2 solves ``A^H x = b``.
        x0: unused

    Returns:
        ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """

    del x0

    A = A.to_dense()
    b, unravel = tree_ravel(b)

    lu, piv = jsp.linalg.lu_factor(A)
    # Honour the caller's ``trans``; hard-coding 0 silently ignored it.
    x = jsp.linalg.lu_solve((lu, piv), b, trans=trans)
    return unravel(x), None
Ejemplo n.º 13
0
def cholesky(A, b, lower=False, x0=None):
    """
    Solve the linear system using a Cholesky Factorisation.
    The diagonal shift on the matrix should be 0.

    Internally uses ``jsp.linalg.cho_factor`` / ``jsp.linalg.cho_solve``
    (``jax.scipy.linalg``).

    Args:
        A: linear operator exposing ``to_dense()``; the matrix A in Ax=b,
            which should be Hermitian positive definite for Cholesky.
        b: right-hand side (PyTree or array); the vector b in Ax=b.
        lower: if True uses the lower half of the A matrix
        x0: unused

    Returns:
        ``(x, None)`` with ``x`` unravelled to the structure of ``b``.
    """
    del x0

    dense_matrix = A.to_dense()
    rhs, unravel = tree_ravel(b)

    factor, is_lower = jsp.linalg.cho_factor(dense_matrix, lower=lower)
    solution = jsp.linalg.cho_solve((factor, is_lower), rhs)
    return unravel(solution), None
Ejemplo n.º 14
0
def ravel(x: PyTree) -> Array:
    """
    Shorthand for ``nkjax.tree_ravel`` that keeps only the ravelled array.

    The unravel function that ``tree_ravel`` also returns is discarded.
    """
    flat_array, _unravel_fn = nkjax.tree_ravel(x)
    return flat_array