Code example #1
def _mul_by_scalar(tt, c):
    tt = unwrap_tt(tt)
    cores = list(tt.tt_cores)
    cores[0] = c * cores[0]
    if tt.is_tt_matrix:
        return TTMatrix(cores)
    else:
        return TT(cores)
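A minimal usage sketch (not from the library; it assumes `TT`, `ops.full` and `_mul_by_scalar` are importable as in the surrounding examples): since only the first core is scaled, converting back to a dense array should match ordinary scalar multiplication.

import jax.numpy as jnp

core_1 = jnp.ones((1, 3, 2))        # core shape (r0, n1, r1)
core_2 = jnp.ones((2, 4, 1))        # core shape (r1, n2, r2)
t = TT((core_1, core_2))            # TT-tensor of shape (3, 4)
scaled = _mul_by_scalar(t, 0.5)     # only the first core is multiplied by 0.5
# Dense check: ops.full(scaled) should equal 0.5 * ops.full(t).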
Code example #2
def deltas_to_tangent(deltas: List[jnp.ndarray],
                      tt: TTTensOrMat) -> TTTensOrMat:
    """Converts deltas representation of tangent space vector to `TT-object`.
  Takes as input a list of [dP1, ..., dPd] and returns
  dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd.
  
  This function is hard to use correctly because deltas should obey the
  so called gauge conditions. If they don't, the function will silently return
  incorrect result. That is why this function is not imported in __init__.
  
  :param deltas: a list of deltas (essentially `TT-cores`) obeying the gauge
                 conditions.
  :param tt: object on which the tangent space tensor represented
             by delta is projected.
  :type tt: `TT-Tensor` or `TT-Matrix`
  :return: object constructed from deltas, that is from the tangent
           space at point `tt`.
  :rtype: `TT-Tensor` or `TT-Matrix`
  """
    cores = []
    dtype = tt.dtype
    left = orthogonalize(tt)
    right = orthogonalize(left, left_to_right=False)
    left_rank_dim = 0
    right_rank_dim = 3 if tt.is_tt_matrix else 2
    for i in range(tt.ndim):
        left_tt_core = left.tt_cores[i]
        right_tt_core = right.tt_cores[i]

        if i == 0:
            tangent_core = jnp.concatenate((deltas[i], left_tt_core),
                                           axis=right_rank_dim)
        elif i == tt.ndim - 1:
            tangent_core = jnp.concatenate((right_tt_core, deltas[i]),
                                           axis=left_rank_dim)
        else:
            rank_1 = right.tt_ranks[i]
            rank_2 = left.tt_ranks[i + 1]
            if tt.is_tt_matrix:
                mode_size_n = tt.raw_tensor_shape[0][i]
                mode_size_m = tt.raw_tensor_shape[1][i]
                shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            else:
                mode_size_n = tt.shape[i]
                shape = [rank_1, mode_size_n, rank_2]
            zeros = jnp.zeros(shape, dtype=dtype)
            upper = jnp.concatenate((right_tt_core, zeros),
                                    axis=right_rank_dim)
            lower = jnp.concatenate((deltas[i], left_tt_core),
                                    axis=right_rank_dim)
            tangent_core = jnp.concatenate((upper, lower), axis=left_rank_dim)
        cores.append(tangent_core)
    if tt.is_tt_matrix:
        return TTMatrix(cores)
    else:
        return TT(cores)
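For reference (a restatement of what the concatenations above assemble, with U_i the left-orthogonal cores of `left`, V_i the right-orthogonal cores of `right`, and dP_i the deltas), the tangent cores have the block structure

G_1 = \begin{pmatrix} dP_1 & U_1 \end{pmatrix}, \quad
G_i = \begin{pmatrix} V_i & 0 \\ dP_i & U_i \end{pmatrix} \ (1 < i < d), \quad
G_d = \begin{pmatrix} V_d \\ dP_d \end{pmatrix},

where the blocks are stacked along the left and right rank modes; contracting these cores reproduces dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd.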
Code example #3
 def vectorized_func(*args, **kwargs):
     tt_arg = args[0]  # TODO: what if only kwargs are present?
     if tt_arg.num_batch_dims == 0:
         return func(*args, **kwargs)
     else:
         if num_batch_args is not None:
             num_non_batch_args = len(args) + len(kwargs) - num_batch_args
             in_axis = [0] * num_batch_args + [None] * num_non_batch_args
             num_args = num_batch_args
         else:
             num_args = len(args) + len(kwargs)
             in_axis = [0] * num_args
         if num_args > 1 and isinstance(args[1], (TTMatrix, TT)):
             if args[0].is_tt_matrix != args[1].is_tt_matrix:
                 raise ValueError(
                     'Types of the arguments are different.')
             if not are_batches_broadcastable(args[0], args[1]):
                 raise ValueError(
                     'The batch sizes are different and not 1, '
                     'broadcasting is not available.')
             broadcast_shape = np.maximum(list(args[0].batch_shape),
                                          list(args[1].batch_shape))
             new_args = list(args)
             if args[0].is_tt_matrix:
                 for i, tt in enumerate(args[:2]):
                     new_cores = []
                     for core in tt.tt_cores:
                         core = jnp.broadcast_to(
                             core, list(broadcast_shape) + list(core.shape[-4:]))
                         new_cores.append(core)
                     new_args[i] = TTMatrix(new_cores)
             else:
                 for i, tt in enumerate(args[:2]):
                     new_cores = []
                     for core in tt.tt_cores:
                         core = jnp.broadcast_to(
                             core, list(broadcast_shape) + list(core.shape[-3:]))
                         new_cores.append(core)
                     new_args[i] = TT(new_cores)
         else:
             new_args = args
         vmapped = func
         for _ in range(tt_arg.num_batch_dims):
             vmapped = jax.vmap(vmapped, in_axis)
         return vmapped(*new_args, **kwargs)
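For orientation, `vectorized_func` is the inner closure of a batching decorator, with `func` and `num_batch_args` captured from the enclosing scope. A hypothetical outline of that wrapper (the name `auto_batch` is illustrative, not the library's actual API):

def auto_batch(func, num_batch_args=None):
    # Hypothetical outer wrapper: `func` and `num_batch_args` become the free
    # variables used inside `vectorized_func` exactly as above.
    def vectorized_func(*args, **kwargs):
        ...  # body as in code example #3
    return vectorized_func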
Code example #4
File: decompositions_test.py  Project: fasghq/ttax
 def testRound2d(self):
   dtype = jnp.float32
   rank = 5
   np.random.seed(0)
   x = np.random.randn(10, 20).astype(dtype)
   u, s, v = np.linalg.svd(x, full_matrices=False)
   core_1 = u @ np.diag(s)
   core_1 = core_1.reshape(1, 10, 10)
   core_2 = v
   core_2 = core_2.reshape(10, 20, 1)
   tt = TT((core_1, core_2))
   truncated_x = u[:, :rank] @ np.diag(s[:rank]) @ v[:rank, :]
   rounded = decompositions.round(tt, max_tt_rank=rank)
   self.assertAllClose(truncated_x, ops.full(rounded), rtol=1e-5, atol=1e-5)
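The test relies on the fact that a two-core train is exactly the SVD factorization of `x`, so rounding it to rank 5 amounts to a single truncated SVD; by the Eckart-Young theorem this is the best rank-5 approximation in the Frobenius norm:

\min_{\operatorname{rank}(Y) \le 5} \|X - Y\|_F = \|X - U_{:,:5}\, \Sigma_{:5}\, V_{:5,:}\|_F .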
Code example #5
def vector_to_tensor(tt):
    """Converts TT-matrix to TT-tensor, if matrix has shape N x 1 or 1 x N

    :type tt:   `TT-Matrix`
    :param tt:  TT-matrix
    :return:    `TT`
    :rtype:     `TT`
    :raises [ValueError]: if the argument is not a TT-matrix, or if matrix has wrong shape
  """
    if not isinstance(tt, TTMatrix) or not tt.is_tt_matrix:
        raise ValueError('The argument should be a TT-matrix')
    if tt.shape[0] != 1 and tt.shape[1] != 1:
        raise ValueError(
            'At least one of matrix dimensions should be equal to one')

    cores = []
    for core in tt.tt_cores:
        if tt.shape[0] == 1:
            cores.append(jnp.squeeze(core, 1))
        else:
            cores.append(jnp.squeeze(core, 2))
    return TT(cores)
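A minimal usage sketch with hand-built cores (it assumes, as in the other examples, that `TTMatrix` takes 4-dimensional cores of shape (r, n, m, r)): a 1 x 12 row vector in TT-matrix format becomes a TT-tensor of shape (3, 4).

import jax.numpy as jnp

core_1 = jnp.ones((1, 1, 3, 2))       # row-mode size 1, column-mode size 3
core_2 = jnp.ones((2, 1, 4, 1))       # row-mode size 1, column-mode size 4
row_vec = TTMatrix((core_1, core_2))  # TT-matrix of shape 1 x 12
vec = vector_to_tensor(row_vec)       # TT-tensor of shape (3, 4)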
Code example #6
def tensor(rng, shape, tt_rank=2, batch_shape=None, dtype=jnp.float32):
    """Generate a random `TT-Tensor` of the given shape and `TT-rank`.
  
  :param rng: JAX PRNG key (a random state described by two unsigned 32-bit integers)
  :type rng: `jax.random.PRNGKey`
  :param shape: desired tensor shape
  :type shape: array
  :param tt_rank: desired `TT-ranks` of `TT-Tensor`
  :type tt_rank: single number for equal `TT-ranks` or array specifying all `TT-ranks`
  :param batch_shape: desired batch shape of `TT-Tensor`
  :type batch_shape: array
  :param dtype: type of elements in `TT-Tensor`
  :type dtype: `dtype`
  :return: generated `TT-Tensor`
  :rtype: TT
  """
    shape = np.array(shape)
    tt_rank = np.array(tt_rank)
    batch_shape = list(batch_shape) if batch_shape else []

    num_dims = shape.size
    if tt_rank.size == 1:
        tt_rank = tt_rank * np.ones(num_dims - 1)
        tt_rank = np.insert(tt_rank, 0, 1)
        tt_rank = np.append(tt_rank, 1)

    tt_rank = tt_rank.astype(int)

    tt_cores = []
    rng_arr = jax.random.split(rng, num_dims)
    for i in range(num_dims):
        curr_core_shape = [tt_rank[i], shape[i], tt_rank[i + 1]]
        curr_core_shape = batch_shape + curr_core_shape
        tt_cores.append(
            jax.random.normal(rng_arr[i], curr_core_shape, dtype=dtype))

    return TT(tt_cores)
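A minimal usage sketch for the generator above (the keyword names follow its signature):

import jax

rng = jax.random.PRNGKey(0)
t = tensor(rng, shape=(4, 5, 6), tt_rank=3)                    # TT-ranks (1, 3, 3, 1)
t_batch = tensor(rng, (4, 5, 6), tt_rank=3, batch_shape=(7,))  # cores get a leading batch axis of size 7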
Code example #7
File: compile.py  Project: fasghq/ttax
  def new_func(*args):
    are_tt_matrix_inputs = args[0].is_tt_matrix
    tt_einsum_ = tt_einsum.resolve_i_or_ij(are_tt_matrix_inputs)

    is_fusing = any([isinstance(tt, WrappedTT) for tt in args])
    if is_fusing:
      # Have to use a different name to make upper level tt_einsum visible.
      tt_einsum_, args = _fuse_tt_einsums(tt_einsum_, args)
    einsum = tt_einsum_.to_vanilla_einsum()
    num_batch_dims = args[0].num_batch_dims
    # TODO: support broadcasting.
    res_batch_shape = list(args[0].batch_shape)
    # TODO: do in parallel w.r.t. cores.
    # TODO: use optimal einsum.
    res_cores = []
    for i in range(len(args[0].tt_cores)):
      curr_input_cores = [tt.tt_cores[i] for tt in args]
      core = oe.contract(einsum, *curr_input_cores, backend='jax')
      shape = core.shape[num_batch_dims:]
      num_left_rank_dims = len(tt_einsum_.output[0])
      num_tensor_dims = len(tt_einsum_.output[1])
      split_points = (num_left_rank_dims, num_left_rank_dims + num_tensor_dims)
      new_shape = np.split(shape, split_points)
      left_rank = np.prod(new_shape[0])
      right_rank = np.prod(new_shape[2])
      new_shape = [left_rank] + new_shape[1].tolist() + [right_rank]
      new_shape = res_batch_shape + new_shape
      res_cores.append(core.reshape(new_shape))

    if are_tt_matrix_inputs:
      res = TTMatrix(res_cores)
    else:
      res = TT(res_cores)

    if is_fusing:
      res = WrappedTT(res, args, tt_einsum_)
    return res
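Note on the final reshape: the left-rank axes of all operands are merged into a single left rank and likewise on the right (the `np.prod` calls), so for a binary fused operation such as an elementwise product the resulting TT-ranks, before any rounding, are products of the operands' ranks:

r_k^{res} = r_k^{a} \cdot r_k^{b}, \qquad k = 0, \dots, d .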
Code example #8
def add(tt_a, tt_b):
    """Returns a `TT-object` corresponding to elementwise sum `tt_a + tt_b`.
  The shapes of `tt_a` and `tt_b` should coincide.
  Supports broadcasting, e.g. you can add a tensor train with
  batch size 7 and a tensor train with batch size 1:
  
  ``add(tt_batch, tt_single.batch_loc[np.newaxis])``
  
  where ``tt_single.batch_loc[np.newaxis]`` 
  creates a singleton batch dimension.

  :type tt_a: `TT-Tensor` or `TT-Matrix`
  :param tt_a: first argument
  :type tt_b: `TT-Tensor` or `TT-Matrix`
  :param tt_b: second argument
  :rtype: `TT-Tensor` or `TT-Matrix`
  :return: `tt_a + tt_b`
  :raises [ValueError]: if the arguments' shapes do not coincide
  """
    tt_a = unwrap_tt(tt_a)
    tt_b = unwrap_tt(tt_b)

    if not are_shapes_equal(tt_a, tt_b):
        raise ValueError('Types of the arguments or their tensor '
                         'shapes are different, addition is not '
                         'available.')
    if not are_batches_broadcastable(tt_a, tt_b):
        raise ValueError('The batch sizes are different and not 1, '
                         'broadcasting is not available.')

    if tt_a.is_tt_matrix:
        tt_cores = _add_matrix_cores(tt_a, tt_b)
        return TTMatrix(tt_cores)
    else:
        tt_cores = _add_tensor_cores(tt_a, tt_b)
        return TT(tt_cores)
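A usage sketch combining `add` with the random generator from code example #6 (a minimal sketch; it assumes both are importable from the same package):

import jax

rng_a, rng_b = jax.random.split(jax.random.PRNGKey(0))
a = tensor(rng_a, (4, 5, 6), tt_rank=2)
b = tensor(rng_b, (4, 5, 6), tt_rank=3)
c = add(a, b)
# The inner TT-ranks add up: c has TT-ranks (1, 5, 5, 1),
# and ops.full(c) should equal ops.full(a) + ops.full(b).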
Code example #9
File: decompositions.py  Project: fasghq/ttax
def _orthogonalize_tt_cores_right_to_left(tt):
  """Orthogonalize TT-cores of a TT-object.
  Args:
    tt: TT-tensor or TT-matrix.
  Returns:
    TT-tensor or TT-matrix.
  """
  
  # Right to left orthogonalization.
  num_dims = tt.ndim
  if tt.is_tt_matrix:
    raw_shape = tt.raw_tensor_shape
  else:
    raw_shape = tt.shape

  tt_ranks = tt.tt_ranks
  prev_rank = tt_ranks[num_dims]
  # Copy cores references so we can change the cores.
  tt_cores = list(tt.tt_cores)
  for core_idx in range(num_dims - 1, 0, -1):
    curr_core = tt_cores[core_idx]
    # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
    # be outdated for the current TT-rank, but should be valid for the next
    # TT-rank.
    curr_rank = prev_rank
    prev_rank = tt_ranks[core_idx]
    if tt.is_tt_matrix:
      curr_mode_left = raw_shape[0][core_idx]
      curr_mode_right = raw_shape[1][core_idx]
      curr_mode = curr_mode_left * curr_mode_right
    else:
      curr_mode = raw_shape[core_idx]
    
    qr_shape = (prev_rank, curr_mode * curr_rank)
    curr_core = jnp.reshape(curr_core, qr_shape)
    curr_core, triang = jnp.linalg.qr(curr_core.T)
    curr_core = curr_core.T
    triang = triang.T
    triang_shape = triang.shape

    # The TT-rank could have changed: since we take the QR of the transposed
    # core, if qr_shape is e.g. 10 x 4, then q would be of size 4 x 4 and r
    # would be 4 x 10, which means that the rank should be changed to 4.
    prev_rank = triang_shape[1]
    if tt.is_tt_matrix:
      new_core_shape = (prev_rank, curr_mode_left, curr_mode_right, curr_rank)
    else:
      new_core_shape = (prev_rank, curr_mode, curr_rank)
    tt_cores[core_idx] = jnp.reshape(curr_core, new_core_shape)

    prev_core = jnp.reshape(tt_cores[core_idx - 1], (-1, triang_shape[0]))
    tt_cores[core_idx - 1] = jnp.matmul(prev_core, triang)

  if tt.is_tt_matrix:
    first_core_shape = (1, raw_shape[0][0], raw_shape[1][0], prev_rank)
  else:
    first_core_shape = (1, raw_shape[0], prev_rank)
  tt_cores[0] = jnp.reshape(tt_cores[0], first_core_shape)
  
  if tt.is_tt_matrix:
    return TTMatrix(tt_cores)
  else:
    return TT(tt_cores)
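After this right-to-left sweep every core except the first is right-orthogonal: reshaping the k-th core into an r_{k-1} x (n_k r_k) matrix G_k (for a TT-matrix, n_k is the product of its two mode sizes) gives

G_k G_k^\top = I_{r_{k-1}}, \qquad k = 2, \dots, d .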
Code example #10
File: decompositions.py  Project: fasghq/ttax
def _orthogonalize_tt_cores_left_to_right(tt):
  """Orthogonalize TT-cores of a TT-object.
  Args:
    tt: TT-tensor or TT-matrix.
    TT-tensor or TT-matrix.
  """

  # Left to right orthogonalization.
  num_dims = tt.ndim
  if tt.is_tt_matrix:
    raw_shape = tt.raw_tensor_shape
  else:
    raw_shape = tt.shape

  tt_ranks = tt.tt_ranks
  next_rank = tt_ranks[0]
  # Copy cores references so we can change the cores.
  tt_cores = list(tt.tt_cores)
  for core_idx in range(num_dims - 1):
    curr_core = tt_cores[core_idx]
    # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
    # be outdated for the current TT-rank, but should be valid for the next
    # TT-rank.
    curr_rank = next_rank
    next_rank = tt_ranks[core_idx + 1]
    if tt.is_tt_matrix:
      curr_mode_left = raw_shape[0][core_idx]
      curr_mode_right = raw_shape[1][core_idx]
      curr_mode = curr_mode_left * curr_mode_right
    else:
      curr_mode = raw_shape[core_idx]
    
    qr_shape = (curr_rank * curr_mode, next_rank)
    curr_core = jnp.reshape(curr_core, qr_shape)
    curr_core, triang = jnp.linalg.qr(curr_core)
    
    triang_shape = triang.shape

    # The TT-rank could have changed: if qr_shape is e.g. 4 x 10, then q would
    # be of size 4 x 4 and r would be 4 x 10, which means that the next rank
    # should be changed to 4.
    next_rank = triang_shape[0]
    if tt.is_tt_matrix:
      new_core_shape = (curr_rank, curr_mode_left, curr_mode_right, next_rank)
    else:
      new_core_shape = (curr_rank, curr_mode, next_rank)
    tt_cores[core_idx] = jnp.reshape(curr_core, new_core_shape)

    next_core = jnp.reshape(tt_cores[core_idx + 1], (triang_shape[1], -1))
    tt_cores[core_idx + 1] = jnp.matmul(triang, next_core)

  if tt.is_tt_matrix:
    last_core_shape = (next_rank, raw_shape[0][-1], raw_shape[1][-1], 1)
  else:
    last_core_shape = (next_rank, raw_shape[-1], 1)
  tt_cores[-1] = jnp.reshape(tt_cores[-1], last_core_shape)

  if tt.is_tt_matrix:
    return TTMatrix(tt_cores)
  else:
    return TT(tt_cores)
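Symmetrically, after this left-to-right sweep every core except the last is left-orthogonal: reshaping the k-th core into an (r_{k-1} n_k) x r_k matrix Q_k gives

Q_k^\top Q_k = I_{r_k}, \qquad k = 1, \dots, d-1 .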
Code example #11
File: decompositions.py  Project: fasghq/ttax
def round(tt, max_tt_rank=None, epsilon=None):
  """Tensor Train rounding procedure, returns a `TT-object` with smaller `TT-ranks`.

  :param tt: argument whose `TT-ranks` should be reduced
  :type tt: `TT-Tensor` or `TT-Matrix`
  :type max_tt_rank: int or list of ints
  :param max_tt_rank: 
  
    - If a number, then it defines the maximal `TT-rank` of the result.
    - If a list of numbers, then the length of `max_tt_rank` should be d+1
      (where d is the number of dimensions) and `max_tt_rank[i]`
      defines the maximal (i+1)-th `TT-rank` of the result.
      
      The following two versions are equivalent
      
      - ``max_tt_rank = r``
      
      - ``max_tt_rank = [1] + [r] * (d-1) + [1]``  
       
  :type epsilon: float or None
  :param epsilon:
  
    - If the `TT-ranks` are not restricted (`max_tt_rank=None`), then
      the result would be guaranteed to be `epsilon`-close to `tt`
      in terms of relative Frobenius error:
      
        `||res - tt||_F / ||tt||_F <= epsilon`
        
    - If the `TT-ranks` are restricted, providing a loose `epsilon` may
      reduce the `TT-ranks` of the result. E.g.
      
        ``round(tt, max_tt_rank=100, epsilon=0.9)``
        
      will probably return a `TT-Tensor` with `TT-ranks` close to 1, not 100.
      Note that providing a nontrivial (= not equal to `None`) epsilon will make
      the `TT-ranks` of the result change depending on the data, 
      which will prevent you from using ``jax.jit``
      for speeding up the computations.
  :return: `TT-object` with reduced `TT-ranks`
  :rtype: `TT-Tensor` or `TT-Matrix`
  :raises:
    ValueError if `max_tt_rank` is less than 1, if `max_tt_rank` is not a number and
    not a vector of length d + 1 (where d is the number of dimensions of
    the input tensor), or if `epsilon` is less than 0.
  """

  if max_tt_rank is None:
    max_tt_rank = np.iinfo(np.int32).max
  num_dims = tt.ndim
  max_tt_rank = np.array(max_tt_rank).astype(np.int32)
  if np.any(max_tt_rank < 1):
    raise ValueError('Maximum TT-rank should be greater or equal to 1.')
  if epsilon is not None:
    raise NotImplementedError('Epsilon is not supported yet.')
  if max_tt_rank.size == 1:
    max_tt_rank = (max_tt_rank * np.ones(num_dims + 1)).astype(jnp.int32)
  elif max_tt_rank.size != num_dims + 1:
    raise ValueError('max_tt_rank should be a number or a vector of size (d+1) '
                     'where d is the number of dimensions of the tensor.')
  if tt.is_tt_matrix:
    raw_shape = tt.raw_tensor_shape
  else:
    raw_shape = tt.shape

  tt_cores = orthogonalize(tt).tt_cores
  # Copy cores references so we can change the cores.
  tt_cores = list(tt_cores)

  ranks = [1] * (num_dims + 1)

  # Right to left SVD compression.
  for core_idx in range(num_dims - 1, 0, -1):
    curr_core = tt_cores[core_idx]
    if tt.is_tt_matrix:
      curr_mode_left = raw_shape[0][core_idx]
      curr_mode_right = raw_shape[1][core_idx]
      curr_mode = curr_mode_left * curr_mode_right
    else:
      curr_mode = raw_shape[core_idx]
    
    columns = curr_mode * ranks[core_idx + 1]
    curr_core = jnp.reshape(curr_core, [-1, columns])
    rows = curr_core.shape[0]
    if max_tt_rank[core_idx] == 1:
      ranks[core_idx] = 1
    else:
      ranks[core_idx] = min(max_tt_rank[core_idx], rows, columns)
    u, s, v = jnp.linalg.svd(curr_core, full_matrices=False)
    u = u[:, 0:ranks[core_idx]]
    s = s[0:ranks[core_idx]]
    v = v[0:ranks[core_idx], :]
    if tt.is_tt_matrix:
      core_shape = (ranks[core_idx], curr_mode_left, curr_mode_right,
                    ranks[core_idx + 1])
    else:
      core_shape = (ranks[core_idx], curr_mode, ranks[core_idx + 1])
    tt_cores[core_idx] = jnp.reshape(v, core_shape)
    prev_core_shape = (-1, rows)
    tt_cores[core_idx - 1] = jnp.reshape(tt_cores[core_idx - 1], prev_core_shape)
    tt_cores[core_idx - 1] = jnp.matmul(tt_cores[core_idx - 1], u)
    tt_cores[core_idx - 1] = jnp.matmul(tt_cores[core_idx - 1], jnp.diag(s))

  if tt.is_tt_matrix:
    core_shape = (ranks[0], raw_shape[0][0], raw_shape[1][0], ranks[1])
  else:
    core_shape = (ranks[0], raw_shape[0], ranks[1])
  tt_cores[0] = jnp.reshape(tt_cores[0], core_shape)

  if tt.is_tt_matrix:
    return TTMatrix(tt_cores)
  else:
    return TT(tt_cores)
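A usage sketch, reusing the random generator from code example #6 (a minimal sketch; both calls fix the output ranks in advance, so, unlike a data-dependent `epsilon`, they keep all shapes static):

import jax

tt = tensor(jax.random.PRNGKey(0), (8, 9, 10), tt_rank=7)
r1 = round(tt, max_tt_rank=3)               # cap every internal TT-rank at 3
r2 = round(tt, max_tt_rank=[1, 3, 5, 1])    # per-rank caps; length must be ndim + 1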