Code example #1
0
def deltas_to_tangent(deltas: List[jnp.ndarray],
                      tt: TTTensOrMat) -> TTTensOrMat:
    """Assemble a `TT-object` from the deltas representation of a tangent vector.

  Given a list [dP1, ..., dPd], returns
  dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd.

  This function is hard to use correctly because deltas should obey the
  so called gauge conditions. If they don't, the function will silently return
  incorrect result. That is why this function is not imported in __init__.

  :param deltas: a list of deltas (essentially `TT-cores`) obeying the gauge
                 conditions.
  :param tt: object on which the tangent space tensor represented
             by delta is projected.
  :type tt: `TT-Tensor` or `TT-Matrix`
  :return: object constructed from deltas, that is from the tangent
           space at point `tt`.
  :rtype: `TT-Tensor` or `TT-Matrix`
  """
    is_matrix = tt.is_tt_matrix
    dtype = tt.dtype
    num_dims = tt.ndim
    left_orth = orthogonalize(tt)
    right_orth = orthogonalize(left_orth, left_to_right=False)
    # Rank axes of a core: axis 0 is the left TT-rank, the last axis
    # (3 for TT-matrices, 2 for TT-tensors) is the right TT-rank.
    left_axis = 0
    right_axis = 3 if is_matrix else 2

    tangent_cores = []
    for core_idx in range(num_dims):
        u_core = left_orth.tt_cores[core_idx]
        v_core = right_orth.tt_cores[core_idx]

        if core_idx == 0:
            # First core: [dP1, U1] stacked along the right rank axis.
            new_core = jnp.concatenate((deltas[core_idx], u_core),
                                       axis=right_axis)
        elif core_idx == num_dims - 1:
            # Last core: [Vd; dPd] stacked along the left rank axis.
            new_core = jnp.concatenate((v_core, deltas[core_idx]),
                                       axis=left_axis)
        else:
            # Middle cores form the 2x2 block structure [[V, 0], [dP, U]].
            rank_left = right_orth.tt_ranks[core_idx]
            rank_right = left_orth.tt_ranks[core_idx + 1]
            if is_matrix:
                n_size = tt.raw_tensor_shape[0][core_idx]
                m_size = tt.raw_tensor_shape[1][core_idx]
                zero_shape = [rank_left, n_size, m_size, rank_right]
            else:
                zero_shape = [rank_left, tt.shape[core_idx], rank_right]
            zero_block = jnp.zeros(zero_shape, dtype=dtype)
            top_row = jnp.concatenate((v_core, zero_block), axis=right_axis)
            bottom_row = jnp.concatenate((deltas[core_idx], u_core),
                                         axis=right_axis)
            new_core = jnp.concatenate((top_row, bottom_row), axis=left_axis)
        tangent_cores.append(new_core)

    constructor = TTMatrix if is_matrix else TT
    return constructor(tangent_cores)
Code example #2
0
    def _grad(x: TTTensOrMat) -> TangentVector:
        """Riemannian gradient of `func` at `x`, returned as a tangent TT-object."""
        # TODO: support runtime checks
        left = decompositions.orthogonalize(x)
        right = decompositions.orthogonalize(left, left_to_right=False)
        # Seed deltas: first core of the right-orthogonal form, zeros elsewhere.
        seed_deltas = [right.tt_cores[0]]
        seed_deltas.extend(jnp.zeros_like(core) for core in right.tt_cores[1:])

        def augmented_func(d):
            return func(riemannian.deltas_to_tangent(d, x))

        # Differentiate w.r.t. the deltas; the function value is not needed.
        _, cores_grad = jax.value_and_grad(augmented_func)(seed_deltas)

        gauged_deltas = _enforce_gauge_conditions(cores_grad, left)
        return riemannian.deltas_to_tangent(gauged_deltas, x)
Code example #3
0
    def _hess_by_vec(x: TTTensOrMat, vector: TangentVector) -> TangentVector:
        """Hessian-by-vector product of `func` at `x` along `vector`,
        returned as a tangent TT-object."""
        left = decompositions.orthogonalize(x)
        right = decompositions.orthogonalize(left, left_to_right=False)
        # Seed deltas: first core of the right-orthogonal form, zeros elsewhere.
        seed_deltas = [right.tt_cores[0]]
        seed_deltas.extend(jnp.zeros_like(core) for core in right.tt_cores[1:])

        def augmented_outer_func(deltas_outer):

            def augmented_inner_func(deltas_inner):
                return func(riemannian.deltas_to_tangent(deltas_inner, x))

            _, cores_grad = jax.value_and_grad(
                augmented_inner_func)(deltas_outer)
            # TODO: support runtime checks

            # Inner product between the gradient deltas and the deltas of the
            # projected vector.
            vector_projected = project(vector, x)
            vec_deltas = riemannian.tangent_to_deltas(vector_projected)
            return sum(jnp.sum(g * v) for g, v in zip(cores_grad, vec_deltas))

        _, second_cores_grad = jax.value_and_grad(
            augmented_outer_func)(seed_deltas)
        final_deltas = _enforce_gauge_conditions(second_cores_grad, left)
        # TODO: pass left and right?
        return riemannian.deltas_to_tangent(final_deltas, x)
Code example #4
0
def norm(tt, differentiable=False):
    """Frobenius norm of a TT object.

  :type tt: `TT` or `TTMatrix`
  :param tt: TT object (tensor or matrix)
  :type differentiable: bool
  :param differentiable: whether to use a differentiable implementation
                         or a fast implementation based on QR decomposition
  :return: non-negative number which is the Frobenius norm of `tt`
  :rtype: `float`
  """

    from ttax.decompositions import orthogonalize
    if differentiable:
        # Differentiable path: sqrt of the TT inner product <tt, tt>.
        return jnp.sqrt(flat_inner(tt, tt))
    # Fast path: after left-to-right orthogonalization the whole norm is
    # carried by the last core, so its Frobenius norm is the norm of `tt`.
    orthogonalized = orthogonalize(tt)
    return jnp.linalg.norm(orthogonalized.tt_cores[-1])
Code example #5
0
File: decompositions_test.py  Project: fasghq/ttax
  def testOrthogonalizeRightToLeft(self):
    """Right-to-left orthogonalization keeps the tensor, truncates the
    excessive rank, and makes cores 2..d right-orthogonal."""
    dtype = jnp.float32
    rng = jax.random.PRNGKey(0)
    shape = (2, 4, 3, 3)
    tt_ranks = (1, 5, 2, 17, 1)
    # Rank 17 exceeds the maximal feasible rank at that position and is
    # expected to be truncated to 3.
    updated_tt_ranks = (1, 5, 2, 3, 1)
    tens = random_.tensor(rng, shape, tt_rank=tt_ranks, dtype=dtype)
    orthogonal = decompositions.orthogonalize(tens, left_to_right=False)

    # The represented full tensor must be unchanged.
    self.assertAllClose(ops.full(tens), ops.full(orthogonal), atol=1e-5,
                        rtol=1e-5)
    self.assertArraysEqual(updated_tt_ranks, orthogonal.tt_ranks)
    # Every core except the first must have orthonormal rows after being
    # matricized to (left_rank, mode * right_rank).
    for idx in range(1, 4):
      matricized = jnp.reshape(
          orthogonal.tt_cores[idx],
          (updated_tt_ranks[idx], shape[idx] * updated_tt_ranks[idx + 1]))
      self.assertAllClose(np.eye(updated_tt_ranks[idx]),
                          matricized @ matricized.T)