Пример #1
0
def ab_decomposition(
        u: jnp.ndarray,
        v: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Decompose v as v = u @ a + u_orth @ b.

    u_orth is taken from the trailing columns of a complete QR
    decomposition of u, so (for full-rank u) it is orthogonal to the
    column space of u. When u has orthonormal columns, [u, u_orth] is
    unitary and the decomposition reproduces v exactly. If v is a tangent
    vector, then a is skew-Hermitian.

    Args:
        u: array like of shape (..., n, m).
        v: array like of shape (..., n, m); its leading (batch)
            dimensions must broadcast against those of u.

    Returns:
        elements of decomposition a, b, and u_orth, of shapes
        (..., m, m), (..., k, m) and (..., n, k) with k = max(n - m, 0).
        The batch dimensions of a and b are the broadcast of those of
        u and v; u_orth keeps the batch dimensions of u."""

    m = u.shape[-1]
    # jnp.linalg.qr is batched over leading axes, so no vmap or
    # flatten/unflatten reshapes are needed. In 'complete' mode Q has
    # shape (..., n, n); its columns m: complete the basis spanned by u.
    q = jnp.linalg.qr(u, mode='complete')[0]
    u_orth = q[..., m:]
    # Conjugate transpose over the last two axes only; matmul broadcasts
    # the batch axes of its operands.
    a = jnp.swapaxes(u, -1, -2).conj() @ v
    b = jnp.swapaxes(u_orth, -1, -2).conj() @ v
    return a, b, u_orth
Пример #2
0
    def _apply(
        self,
        iter: jnp.ndarray,
        grad: jnp.ndarray,
        state: Tuple[jnp.ndarray],
        param: jnp.ndarray,
        precond: Union[None, jnp.ndarray],
        use_precond=False
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray]]:
        if use_precond:
            rgrad = self.manifold.egrad_to_rgrad(param, grad.conj(), precond)
        else:
            rgrad = self.manifold.egrad_to_rgrad(param, grad.conj())
        momentum = self.beta1 * state[0] + (1 - self.beta1) * rgrad
        if use_precond:
            v = self.beta2 * state[1] + (1 - self.beta2) * self.manifold.inner(
                param, rgrad, rgrad, precond
            )
        else:
            v = self.beta2 * state[1] + (1 - self.beta2) * self.manifold.inner(
                param, rgrad, rgrad
            )
        if self.ams:
            v_hat = jax.lax.complex(jnp.maximum(jnp.real(v), jnp.real(state[2])), jnp.imag(v))

        # Bias correction
        lr_corr = (
            self.learning_rate
            * jnp.sqrt(1 - self.beta2 ** (iter + 1))
            / (1 - self.beta1 ** (iter + 1))
        )

        if self.ams:
            search_dir = -lr_corr * momentum / (jnp.sqrt(v_hat) + self.eps)
            param, momentum = self.manifold.retraction_transport(
                param, momentum, search_dir
            )
            return param, (momentum, v, v_hat)
        else:
            search_dir = -lr_corr * momentum / (jnp.sqrt(v) + self.eps)
            param, momentum = self.manifold.retraction_transport(
                param, momentum, search_dir
            )
            return param, (momentum, v)
Пример #3
0
    def inner(self,
              u: jnp.ndarray,
              vec1: jnp.ndarray,
              vec2: jnp.ndarray,
              precond: Union[None, jnp.ndarray] = None) -> jnp.ndarray:
        """Returns manifold wise inner product of vectors from
        a tangent space.

        Args:
            u: complex valued tensor of shape (..., n, p),
                a set of points from the complex Stiefel
                manifold.
            vec1: complex valued tensor of shape (..., n, p),
                a set of tangent vectors from the complex
                Stiefel manifold.
            vec2: complex valued tensor of shape (..., n, p),
                a set of tangent vectors from the complex
                Stiefel manifold.
            precond: complex valued tensor of shape (..., p, p),
                optional preconditioner representing natural metric
                of an isometric tensor network. It is used only when the
                manifold is configured with self._use_precond; if it is
                None, the inner product falls back to self._metric.

        Returns:
            tensor of shape (..., 1, 1) with the dtype of u holding the
            real part of the manifold wise inner product.

        Raises:
            ValueError: if no preconditioner is applied and self._metric
                is neither 'euclidean' nor 'canonical'.

        Note:
            The complexity for the 'euclidean' metric is O(pn),
            the complexity for the 'canonical' metric is O(np^2)"""

        if self._use_precond and precond is not None:
            # Preconditioned metric: Re tr(vec1^H vec2 precond).
            s_sq = (vec1.conj() * (vec2 @ precond)).sum(keepdims=True,
                                                        axis=(-2, -1))
        elif self._metric == "euclidean":
            s_sq = (vec1.conj() * vec2).sum(keepdims=True, axis=(-2, -1))
        elif self._metric == "canonical":
            # tr(vec1^H vec2) - 1/2 tr(vec1^H u u^H vec2). The second
            # trace uses sum(A * B^T) == tr(A @ B) with A = u^H vec2 and
            # B = vec1^H u, avoiding an explicit p x p matrix product.
            s_sq_1 = (vec1.conj() * vec2).sum(keepdims=True, axis=(-2, -1))
            vec1_dag_u = adj(vec1) @ u
            u_dag_vec2 = adj(u) @ vec2
            s_sq_2 = (u_dag_vec2 * transp(vec1_dag_u)).sum(axis=(-2, -1),
                                                           keepdims=True)
            s_sq = s_sq_1 - 0.5 * s_sq_2
        else:
            # Previously an unknown metric surfaced as UnboundLocalError.
            raise ValueError(
                f"Unknown metric {self._metric!r}; "
                "expected 'euclidean' or 'canonical'.")
        return jnp.real(s_sq).astype(dtype=u.dtype)
Пример #4
0
 def _apply(
     self,
     iter: jnp.ndarray,
     grad: jnp.ndarray,
     state: Tuple[jnp.ndarray],
     param: jnp.ndarray,
     precond: Union[None, jnp.ndarray]=None,
     use_precond=False
 ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray]]:
     if use_precond:
         rgrad = self.manifold.egrad_to_rgrad(param, grad.conj(), precond)
     else:
         rgrad = self.manifold.egrad_to_rgrad(param, grad.conj())
     if self.use_momentum:
         momentum = self.momentum * state[0] + (1 - self.momentum) * rgrad
         param, momentum = self.manifold.retraction_transport(
             param, momentum, -self.learning_rate * momentum
         )
         return param, (momentum,)
     else:
         param = self.manifold.retraction(param, -self.learning_rate * rgrad)
         return param, state
Пример #5
0
def adj(a: jnp.ndarray) -> jnp.ndarray:
    """Returns adjoint (conjugate transpose) matrix.

    Only the last two axes are swapped; any leading axes are treated as
    batch dimensions and left in place.

    Args:
        a: complex valued tensor of shape (..., n1, n2)

    Returns:
        complex valued tensor of shape (..., n2, n1)"""

    # swapaxes avoids a flatten/transpose/unflatten round trip, which
    # breaks on zero-sized matrix dimensions: reshape with -1 cannot infer
    # the batch size when n1 * n2 == 0.
    return a.swapaxes(-1, -2).conj()
Пример #6
0
Файл: utils.py Проект: jackd/jju
def symmetrize(A: jnp.ndarray) -> jnp.ndarray:
    """Return the Hermitian part of A, i.e. (A + A^H) / 2.

    For real input this is the symmetric part. Leading axes of an input of
    shape (..., n, n) are treated as batch dimensions.

    Args:
        A: array of shape (..., n, n).

    Returns:
        Array of the same shape whose last two axes form Hermitian matrices.
    """
    if A.ndim < 2:
        # Keep the previous behaviour for 0-D/1-D input, where `.T` is a
        # no-op: the result is (A + conj(A)) / 2.
        A_h = A.conj()
    else:
        # `.T` reverses *all* axes, which breaks (or silently scrambles)
        # batched input; swap only the two matrix axes instead.
        A_h = A.conj().swapaxes(-1, -2)
    return (A + A_h) / 2