Code example #1
def tensor_to_vector(tt, vertical=True):
    """Converts TT-tensor to TT-matrix with column shape equals to I = (1, ..., 1). If by_columns is False,
  then row shape equals to I.

    :type tt:         `TT`
    :param tt:        TT-tensor
    :type vertical: `bool`
    :param vertical: defines, whether tt will be located by columns or by rows
    :return:          `TT-Matrix`
    :rtype:           `TT-Matrix`
    :raises [ValueError]: if the argument is not a TT-tensor
  """
    if not isinstance(tt, TT):
        raise ValueError('The argument should be a TT-tensor, not a TT-matrix.')

    cores = []
    for core in tt.tt_cores:
        if vertical:
            cores.append(
                jnp.reshape(core,
                            (core.shape[0], core.shape[1], 1, core.shape[2]),
                            order="F"))
        else:
            cores.append(
                jnp.reshape(core,
                            (core.shape[0], 1, core.shape[1], core.shape[2]),
                            order="F"))
    return TTMatrix(cores)
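
A minimal usage sketch (assuming `TT` and `tensor_to_vector` from this module, plus `jax`); the core shapes in the comments follow directly from the reshapes above:

import jax

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
# A 2-dimensional TT-tensor of shape (3, 4) with TT-ranks (1, 2, 1).
tt = TT([jax.random.normal(k1, (1, 3, 2)),
         jax.random.normal(k2, (2, 4, 1))])

col = tensor_to_vector(tt)                  # column vector: all column modes are 1
row = tensor_to_vector(tt, vertical=False)  # row vector: all row modes are 1
print([c.shape for c in col.tt_cores])      # [(1, 3, 1, 2), (2, 4, 1, 1)]
print([c.shape for c in row.tt_cores])      # [(1, 1, 3, 2), (2, 1, 4, 1)]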
Code example #2
def _mul_by_scalar(tt, c):
    """Multiplies a TT-object by the scalar `c` by scaling its first TT-core."""
    tt = unwrap_tt(tt)
    cores = list(tt.tt_cores)
    cores[0] = c * cores[0]
    if tt.is_tt_matrix:
        return TTMatrix(cores)
    else:
        return TT(cores)
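
A short sketch of the effect (assuming `TT` and `unwrap_tt` from this module): the scalar is absorbed into the first core only, so the TT-ranks of the result match the input's.

import jax

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
tt = TT([jax.random.normal(k1, (1, 3, 2)),
         jax.random.normal(k2, (2, 4, 1))])
scaled = _mul_by_scalar(tt, 3.0)      # elementwise equal to 3.0 * tt
print(tt.tt_ranks, scaled.tt_ranks)   # identical ranks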
Code example #3
def deltas_to_tangent(deltas: List[jnp.ndarray],
                      tt: TTTensOrMat) -> TTTensOrMat:
    """Converts deltas representation of tangent space vector to `TT-object`.
  Takes as input a list of [dP1, ..., dPd] and returns
  dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd.
  
  This function is hard to use correctly because deltas should obey the
  so called gauge conditions. If they don't, the function will silently return
  incorrect result. That is why this function is not imported in __init__.
  
  :param deltas: a list of deltas (essentially `TT-cores`) obeying the gauge
                 conditions.
  :param tt: object on which the tangent space tensor represented
             by delta is projected.
  :type tt: `TT-Tensor` or `TT-Matrix`
  :return: object constructed from deltas, that is from the tangent
           space at point `tt`.
  :rtype: `TT-Tensor` or `TT-Matrix`
  """
    cores = []
    dtype = tt.dtype
    left = orthogonalize(tt)
    right = orthogonalize(left, left_to_right=False)
    left_rank_dim = 0
    right_rank_dim = 3 if tt.is_tt_matrix else 2
    for i in range(tt.ndim):
        left_tt_core = left.tt_cores[i]
        right_tt_core = right.tt_cores[i]

        if i == 0:
            tangent_core = jnp.concatenate((deltas[i], left_tt_core),
                                           axis=right_rank_dim)
        elif i == tt.ndim - 1:
            tangent_core = jnp.concatenate((right_tt_core, deltas[i]),
                                           axis=left_rank_dim)
        else:
            rank_1 = right.tt_ranks[i]
            rank_2 = left.tt_ranks[i + 1]
            if tt.is_tt_matrix:
                mode_size_n = tt.raw_tensor_shape[0][i]
                mode_size_m = tt.raw_tensor_shape[1][i]
                shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            else:
                mode_size_n = tt.shape[i]
                shape = [rank_1, mode_size_n, rank_2]
            zeros = jnp.zeros(shape, dtype=dtype)
            upper = jnp.concatenate((right_tt_core, zeros),
                                    axis=right_rank_dim)
            lower = jnp.concatenate((deltas[i], left_tt_core),
                                    axis=right_rank_dim)
            tangent_core = jnp.concatenate((upper, lower), axis=left_rank_dim)
        cores.append(tangent_core)
    if tt.is_tt_matrix:
        return TTMatrix(cores)
    else:
        return TT(cores)
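
A minimal sketch (assuming `TT`, `orthogonalize` and `deltas_to_tangent` from this module): all-zero deltas satisfy the gauge conditions trivially and represent the zero tangent vector, which makes the expected shapes easy to see. For modest ranks the orthogonalizations keep the original TT-ranks, so each `deltas[i]` can simply mirror the shape of the corresponding core.

import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.PRNGKey(0), 3)
# A 3-dimensional TT-tensor of shape (4, 5, 6) with TT-ranks (1, 2, 2, 1).
cores = [jax.random.normal(keys[0], (1, 4, 2)),
         jax.random.normal(keys[1], (2, 5, 2)),
         jax.random.normal(keys[2], (2, 6, 1))]
tt = TT(cores)

# deltas[i] must have shape (left_ranks[i], n_i, right_ranks[i + 1]); here
# that coincides with the original core shapes.
deltas = [jnp.zeros_like(c) for c in cores]
tangent = deltas_to_tangent(deltas, tt)
# The tangent-space representation doubles the interior TT-ranks.
print(tangent.tt_ranks)  # expected: (1, 4, 4, 1)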
Code example #4
def matrix(rng, shape, tt_rank=2, batch_shape=None, dtype=jnp.float32):
    """Generate a random `TT-Matrix` of the given shape and `TT-rank`.
    
  :param rng: JAX PRNG key
  :type rng: jax.random.PRNGKey (a random state described by two unsigned 32-bit integers)
  :param shape: desired matrix shape as a pair of row and column mode shapes.
      One of the two sides may be omitted:
        matrix(..., shape=((2, 2, 2), None))
      and
        matrix(..., shape=(None, (2, 2, 2)))
      create an 8-element column and row vector, respectively.
  :type shape: tuple
  :param tt_rank: desired `TT-ranks` of `TT-Matrix`
  :type tt_rank: single number for equal `TT-ranks` or array specifying all `TT-ranks`
  :param batch_shape: desired batch shape of `TT-Matrix`
  :type batch_shape: tuple
  :param dtype: type of elements in `TT-Matrix`
  :type dtype: `dtype`
  :return: generated `TT-Matrix`
  :rtype: TTMatrix
  :raises [ValueError]: if shape is (None, None)
  """
    if shape == (None, None):
        raise ValueError("At least one of shape elements must not be None")
    if None in shape:
        shape = [
            np.array(shape[0])
            if shape[0] is not None else np.ones(len(shape[1]), dtype=int),
            np.array(shape[1])
            if shape[1] is not None else np.ones(len(shape[0]), dtype=int),
        ]
    else:
        shape = [np.array(shape[0]), np.array(shape[1])]
    tt_rank = np.array(tt_rank)
    batch_shape = list(batch_shape) if batch_shape else []

    num_dims = shape[0].size
    if tt_rank.size == 1:
        tt_rank = tt_rank * np.ones(num_dims - 1)
        tt_rank = np.insert(tt_rank, 0, 1)
        tt_rank = np.append(tt_rank, 1)

    tt_rank = tt_rank.astype(int)

    tt_cores = []
    rng_arr = jax.random.split(rng, num_dims)
    for i in range(num_dims):
        curr_core_shape = [
            tt_rank[i], shape[0][i], shape[1][i], tt_rank[i + 1]
        ]
        curr_core_shape = batch_shape + curr_core_shape
        tt_cores.append(
            jax.random.normal(rng_arr[i], curr_core_shape, dtype=dtype))

    return TTMatrix(tt_cores)
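
A usage sketch (only `jax` and the `matrix` generator above are assumed): the shape argument pairs the row and column mode shapes, and passing `None` for one side produces a TT-vector.

import jax

rng = jax.random.PRNGKey(0)
# A random TT-matrix representing a (2*2*2) x (3*3*3) = 8 x 27 matrix
# with all interior TT-ranks equal to 3.
m = matrix(rng, shape=((2, 2, 2), (3, 3, 3)), tt_rank=3)
print([c.shape for c in m.tt_cores])  # [(1, 2, 3, 3), (3, 2, 3, 3), (3, 2, 3, 1)]

# Omitting one side gives an 8-element column (or row) vector.
col = matrix(rng, shape=((2, 2, 2), None))
print([c.shape for c in col.tt_cores])  # [(1, 2, 1, 2), (2, 2, 1, 2), (2, 2, 1, 1)]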
Code example #5
 def vectorized_func(*args, **kwargs):
     tt_arg = args[0]  # TODO: what if only kwargs are present?
     if tt_arg.num_batch_dims == 0:
         return func(*args, **kwargs)
     else:
         if num_batch_args is not None:
             num_non_batch_args = len(args) + len(kwargs) - num_batch_args
             in_axis = [0] * num_batch_args + [None] * num_non_batch_args
             num_args = num_batch_args
         else:
             num_args = len(args) + len(kwargs)
             in_axis = [0] * num_args
         if num_args > 1 and (isinstance(args[1], TTMatrix)
                              or isinstance(args[1], TT)):
             if args[0].is_tt_matrix != args[1].is_tt_matrix:
                 raise ValueError(
                     'Types of the arguments are different.')
             if not are_batches_broadcastable(args[0], args[1]):
                 raise ValueError(
                     'The batch sizes are different and not 1, '
                     'broadcasting is not available.')
             broadcast_shape = np.maximum(list(args[0].batch_shape),
                                          list(args[1].batch_shape))
             new_args = list(args)
             if args[0].is_tt_matrix:
                 for i, tt in enumerate(args[:2]):
                     new_cores = []
                     for core in tt.tt_cores:
                         core = jnp.broadcast_to(
                             core,
                             list(broadcast_shape) +
                             list(core.shape[-4:]))
                         new_cores.append(core)
                     new_args[i] = TTMatrix(new_cores)
             else:
                 for i, tt in enumerate(args[:2]):
                     new_cores = []
                     for core in tt.tt_cores:
                         core = jnp.broadcast_to(
                             core,
                             list(broadcast_shape) +
                             list(core.shape[-3:]))
                         new_cores.append(core)
                     new_args[i] = TT(new_cores)
         else:
             new_args = args
         vmapped = func
         for _ in range(tt_arg.num_batch_dims):
             vmapped = jax.vmap(vmapped, in_axis)
         return vmapped(*new_args, **kwargs)
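
This function is a closure: `func` and `num_batch_args` are free variables supplied by an enclosing wrapper. A hypothetical skeleton of that wrapper (the name and signature below are assumptions; the actual ttax decorator may differ) could look like:

def vectorize_over_batch_dims(func, num_batch_args=None):
    # Hypothetical outer wrapper: it only supplies the free variables `func`
    # and `num_batch_args` that `vectorized_func` closes over, then returns
    # the closure.
    def vectorized_func(*args, **kwargs):
        ...  # body exactly as shown above
    return vectorized_func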
Code example #6
def transpose(tt):
    """Transpose a TT-matrix or a batch of TT-matrices.

  :type tt:   `TT-Matrix`
  :param tt:  TT-matrix or batch of TT-matrices to transpose
  :return:    transposed TT-matrix or batch of TT-matrices
  :rtype:     `TT-Matrix`
  :raises [ValueError]: if the argument is not a TT-matrix
  """
    if not isinstance(tt, TTMatrix) or not tt.is_tt_matrix:
        raise ValueError('The argument should be a TT-matrix.')

    transposed_tt_cores = [
        jnp.transpose(tt.tt_cores[core_idx], (0, 2, 1, 3))
        for core_idx in range(tt.ndim)
    ]
    return TTMatrix(transposed_tt_cores)
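
A quick sketch (assuming `TTMatrix` and `transpose` from this module): transposition just swaps the row and column mode axes inside every core.

import jax

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
# A TT-matrix of raw shape ((2, 3), (4, 5)), i.e. a (2*3) x (4*5) matrix.
tt_mat = TTMatrix([jax.random.normal(k1, (1, 2, 4, 3)),
                   jax.random.normal(k2, (3, 3, 5, 1))])
tt_mat_t = transpose(tt_mat)
# Each core goes from (r, n, m, r') to (r, m, n, r').
print([c.shape for c in tt_mat_t.tt_cores])  # [(1, 4, 2, 3), (3, 5, 3, 1)]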
Code example #7
File: compile.py  Project: fasghq/ttax
  def new_func(*args):
    are_tt_matrix_inputs = args[0].is_tt_matrix
    tt_einsum_ = tt_einsum.resolve_i_or_ij(are_tt_matrix_inputs)

    is_fusing = any([isinstance(tt, WrappedTT) for tt in args])
    if is_fusing:
      # Have to use a different name to make upper level tt_einsum visible.
      tt_einsum_, args = _fuse_tt_einsums(tt_einsum_, args)
    einsum = tt_einsum_.to_vanilla_einsum()
    num_batch_dims = args[0].num_batch_dims
    # TODO: support broadcasting.
    res_batch_shape = list(args[0].batch_shape)
    # TODO: do in parallel w.r.t. cores.
    # TODO: use optimal einsum.
    res_cores = []
    for i in range(len(args[0].tt_cores)):
      curr_input_cores = [tt.tt_cores[i] for tt in args]
      core = oe.contract(einsum, *curr_input_cores, backend='jax')
      shape = core.shape[num_batch_dims:]
      num_left_rank_dims = len(tt_einsum_.output[0])
      num_tensor_dims = len(tt_einsum_.output[1])
      split_points = (num_left_rank_dims, num_left_rank_dims + num_tensor_dims)
      new_shape = np.split(shape, split_points)
      left_rank = np.prod(new_shape[0])
      right_rank = np.prod(new_shape[2])
      new_shape = [left_rank] + new_shape[1].tolist() + [right_rank]
      new_shape = res_batch_shape + new_shape
      res_cores.append(core.reshape(new_shape))

    if are_tt_matrix_inputs:
      res = TTMatrix(res_cores)
    else:
      res = TT(res_cores)

    if is_fusing:
      res = WrappedTT(res, args, tt_einsum_)
    return res
Code example #8
def add(tt_a, tt_b):
    """Returns a `TT-object` corresponding to elementwise sum `tt_a + tt_b`.
  The shapes of `tt_a` and `tt_b` should coincide.
  Supports broadcasting, e.g. you can add a tensor train with
  batch size 7 and a tensor train with batch size 1:
  
  ``tt_batch.add(tt_single.batch_loc[np.newaxis])``
  
  where ``tt_single.batch_loc[np.newaxis]`` 
  creates a singleton batch dimension.

  :type tt_a: `TT-Tensor` or `TT-Matrix`
  :param tt_a: first argument
  :type tt_b: `TT-Tensor` or `TT-Matrix`
  :param tt_b: second argument
  :rtype: `TT-Tensor` or `TT-Matrix`
  :return: `tt_a + tt_b`
  :raises [ValueError]: if the arguments' shapes do not coincide
  """
    tt_a = unwrap_tt(tt_a)
    tt_b = unwrap_tt(tt_b)

    if not are_shapes_equal(tt_a, tt_b):
        raise ValueError('Types of the arguments or their tensor '
                         'shapes are different, addition is not '
                         'available.')
    if not are_batches_broadcastable(tt_a, tt_b):
        raise ValueError('The batch sizes are different and not 1, '
                         'broadcasting is not available.')

    if tt_a.is_tt_matrix:
        tt_cores = _add_matrix_cores(tt_a, tt_b)
        return TTMatrix(tt_cores)
    else:
        tt_cores = _add_tensor_cores(tt_a, tt_b)
        return TT(tt_cores)
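
A usage sketch (assuming `TT` and `add` from this module): elementwise addition concatenates the cores block-wise, so the interior TT-ranks of the result are the sums of the operands' ranks.

import jax

keys = jax.random.split(jax.random.PRNGKey(0), 3)
cores = [jax.random.normal(keys[0], (1, 4, 2)),
         jax.random.normal(keys[1], (2, 5, 2)),
         jax.random.normal(keys[2], (2, 6, 1))]
a = TT(cores)
b = add(a, a)                  # elementwise 2 * a
print(a.tt_ranks, b.tt_ranks)  # (1, 2, 2, 1) and (1, 4, 4, 1)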
Code example #9
File: decompositions.py  Project: fasghq/ttax
def _orthogonalize_tt_cores_right_to_left(tt):
  """Orthogonalize TT-cores of a TT-object.
  Args:
    tt: TT-tensor or TT-matrix.
  Returns:
    TT-tensor or TT-matrix.
  """
  
  # Right to left orthogonalization.
  num_dims = tt.ndim
  if tt.is_tt_matrix:
    raw_shape = tt.raw_tensor_shape
  else:
    raw_shape = tt.shape

  tt_ranks = tt.tt_ranks
  prev_rank = tt_ranks[num_dims]
  # Copy cores references so we can change the cores.
  tt_cores = list(tt.tt_cores)
  for core_idx in range(num_dims - 1, 0, -1):
    curr_core = tt_cores[core_idx]
    # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
    # be outdated for the current TT-rank, but should be valid for the next
    # TT-rank.
    curr_rank = prev_rank
    prev_rank = tt_ranks[core_idx]
    if tt.is_tt_matrix:
      curr_mode_left = raw_shape[0][core_idx]
      curr_mode_right = raw_shape[1][core_idx]
      curr_mode = curr_mode_left * curr_mode_right
    else:
      curr_mode = raw_shape[core_idx]
    
    qr_shape = (prev_rank, curr_mode * curr_rank)
    curr_core = jnp.reshape(curr_core, qr_shape)
    curr_core, triang = jnp.linalg.qr(curr_core.T)
    curr_core = curr_core.T
    triang = triang.T
    triang_shape = triang.shape

    # The TT-rank could have changed: if qr_shape is e.g. 10 x 4, then the QR of
    # the transpose gives a 4 x 4 orthogonal factor and a 10 x 4 `triang`, which
    # means that the rank to the left of this core shrinks to 4.
    prev_rank = triang_shape[1]
    if tt.is_tt_matrix:
      new_core_shape = (prev_rank, curr_mode_left, curr_mode_right, curr_rank)
    else:
      new_core_shape = (prev_rank, curr_mode, curr_rank)
    tt_cores[core_idx] = jnp.reshape(curr_core, new_core_shape)

    prev_core = jnp.reshape(tt_cores[core_idx - 1], (-1, triang_shape[0]))
    tt_cores[core_idx - 1] = jnp.matmul(prev_core, triang)

  if tt.is_tt_matrix:
    first_core_shape = (1, raw_shape[0][0], raw_shape[1][0], prev_rank)
  else:
    first_core_shape = (1, raw_shape[0], prev_rank)
  tt_cores[0] = jnp.reshape(tt_cores[0], first_core_shape)
  
  if tt.is_tt_matrix:
    return TTMatrix(tt_cores)
  else:
    return TT(tt_cores)
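
A sanity-check sketch (assuming `TT` from this module): after right-to-left orthogonalization every core except the first has orthonormal rows when flattened over its mode and right-rank axes.

import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.PRNGKey(0), 3)
tt = TT([jax.random.normal(keys[0], (1, 4, 2)),
         jax.random.normal(keys[1], (2, 5, 2)),
         jax.random.normal(keys[2], (2, 6, 1))])
ortho = _orthogonalize_tt_cores_right_to_left(tt)
for core in ortho.tt_cores[1:]:
    mat = jnp.reshape(core, (core.shape[0], -1))
    # Rows are orthonormal: mat @ mat.T is (close to) the identity.
    print(jnp.allclose(mat @ mat.T, jnp.eye(mat.shape[0]), atol=1e-4))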
Code example #10
File: decompositions.py  Project: fasghq/ttax
def _orthogonalize_tt_cores_left_to_right(tt):
  """Orthogonalize TT-cores of a TT-object.
  Args:
    tt: TT-tensor or TT-matrix.
    TT-tensor or TT-matrix.
  """

  # Left to right orthogonalization.
  num_dims = tt.ndim
  if tt.is_tt_matrix:
    raw_shape = tt.raw_tensor_shape
  else:
    raw_shape = tt.shape

  tt_ranks = tt.tt_ranks
  next_rank = tt_ranks[0]
  # Copy cores references so we can change the cores.
  tt_cores = list(tt.tt_cores)
  for core_idx in range(num_dims - 1):
    curr_core = tt_cores[core_idx]
    # TT-ranks could have changed on the previous iteration, so `tt_ranks` can
    # be outdated for the current TT-rank, but should be valid for the next
    # TT-rank.
    curr_rank = next_rank
    next_rank = tt_ranks[core_idx + 1]
    if tt.is_tt_matrix:
      curr_mode_left = raw_shape[0][core_idx]
      curr_mode_right = raw_shape[1][core_idx]
      curr_mode = curr_mode_left * curr_mode_right
    else:
      curr_mode = raw_shape[core_idx]
    
    qr_shape = (curr_rank * curr_mode, next_rank)
    curr_core = jnp.reshape(curr_core, qr_shape)
    curr_core, triang = jnp.linalg.qr(curr_core)
    
    triang_shape = triang.shape

    # The TT-rank could have changed: if qr_shape is e.g. 4 x 10, then q would
    # be of size 4 x 4 and r would be 4 x 10, which means that the next rank
    # should be changed to 4.
    next_rank = triang_shape[0]
    if tt.is_tt_matrix:
      new_core_shape = (curr_rank, curr_mode_left, curr_mode_right, next_rank)
    else:
      new_core_shape = (curr_rank, curr_mode, next_rank)
    tt_cores[core_idx] = jnp.reshape(curr_core, new_core_shape)

    next_core = jnp.reshape(tt_cores[core_idx + 1], (triang_shape[1], -1))
    tt_cores[core_idx + 1] = jnp.matmul(triang, next_core)

  if tt.is_tt_matrix:
    last_core_shape = (next_rank, raw_shape[0][-1], raw_shape[1][-1], 1)
  else:
    last_core_shape = (next_rank, raw_shape[-1], 1)
  tt_cores[-1] = jnp.reshape(tt_cores[-1], last_core_shape)

  if tt.is_tt_matrix:
    return TTMatrix(tt_cores)
  else:
    return TT(tt_cores)
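
The `orthogonalize` function used by `deltas_to_tangent` and `round` elsewhere in these examples is presumably a thin dispatcher over these two helpers; a plausible sketch (the actual ttax implementation may differ) is:

def orthogonalize(tt, left_to_right=True):
    # Plausible wrapper, consistent with the calls orthogonalize(tt) and
    # orthogonalize(tt, left_to_right=False) seen in the other examples.
    if left_to_right:
        return _orthogonalize_tt_cores_left_to_right(tt)
    return _orthogonalize_tt_cores_right_to_left(tt)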
Code example #11
File: decompositions.py  Project: fasghq/ttax
def round(tt, max_tt_rank=None, epsilon=None):
  """Tensor Train rounding procedure, returns a `TT-object` with smaller `TT-ranks`.

  :param tt: argument whose `TT-ranks` are to be reduced
  :type tt: `TT-Tensor` or `TT-Matrix`
  :type max_tt_rank: int or list of ints
  :param max_tt_rank: 
  
    - If a number, then it defines the maximal `TT-rank` of the result. 
    - If a list of numbers, then the length of `max_tt_rank` should be d+1 
      (where d is the number of dimensions) and `max_tt_rank[i]`
      defines the maximal (i+1)-th `TT-rank` of the result.
      
      The following two versions are equivalent
      
      - ``max_tt_rank = r``
      
      - ``max_tt_rank = [1] + [r] * (d-1) + [1]``  
       
  :type epsilon: float or None
  :param epsilon:
  
    - If the `TT-ranks` are not restricted (`max_tt_rank=None`), then
      the result is guaranteed to be `epsilon`-close to `tt`
      in terms of relative Frobenius error:
      
        `||res - tt||_F / ||tt||_F <= epsilon`
        
    - If the `TT-ranks` are restricted, providing a loose `epsilon` may
      reduce the `TT-ranks` of the result. E.g.
      
        ``round(tt, max_tt_rank=100, epsilon=0.9)``
        
      will probably return you a `TT-Tensor` with `TT-ranks` close to 1, not 100.
      Note that providing a nontrivial (= not equal to `None`) epsilon will make
      the `TT-ranks` of the result change depending on the data, 
      which will prevent you from using ``jax.jit``
      for speeding up the computations.
  :return: `TT-object` with reduced `TT-ranks`
  :rtype: `TT-Tensor` or `TT-Matrix`
  :raises:
    ValueError if `max_tt_rank` is less than 1, if `max_tt_rank` is neither a number
    nor a vector of length d + 1 (where d is the number of dimensions of
    the input tensor), or if `epsilon` is less than 0.
  """

  if max_tt_rank is None:
    max_tt_rank = np.iinfo(np.int32).max
  num_dims = tt.ndim
  max_tt_rank = np.array(max_tt_rank).astype(np.int32)
  if np.any(max_tt_rank < 1):
    raise ValueError('Maximum TT-rank should be greater or equal to 1.')
  if epsilon is not None:
    raise NotImplementedError('Epsilon is not supported yet.')
  if max_tt_rank.size == 1:
    max_tt_rank = (max_tt_rank * np.ones(num_dims + 1)).astype(jnp.int32)
  elif max_tt_rank.size != num_dims + 1:
    raise ValueError('max_tt_rank should be a number or a vector of size (d+1) '
                     'where d is the number of dimensions of the tensor.')
  if tt.is_tt_matrix:
    raw_shape = tt.raw_tensor_shape
  else:
    raw_shape = tt.shape

  tt_cores = orthogonalize(tt).tt_cores
  # Copy cores references so we can change the cores.
  tt_cores = list(tt_cores)

  ranks = [1] * (num_dims + 1)

  # Right to left SVD compression.
  for core_idx in range(num_dims - 1, 0, -1):
    curr_core = tt_cores[core_idx]
    if tt.is_tt_matrix:
      curr_mode_left = raw_shape[0][core_idx]
      curr_mode_right = raw_shape[1][core_idx]
      curr_mode = curr_mode_left * curr_mode_right
    else:
      curr_mode = raw_shape[core_idx]
    
    columns = curr_mode * ranks[core_idx + 1]
    curr_core = jnp.reshape(curr_core, [-1, columns])
    rows = curr_core.shape[0]
    if max_tt_rank[core_idx] == 1:
      ranks[core_idx] = 1
    else:
      ranks[core_idx] = min(max_tt_rank[core_idx], rows, columns)
    u, s, v = jnp.linalg.svd(curr_core, full_matrices=False)
    u = u[:, 0:ranks[core_idx]]
    s = s[0:ranks[core_idx]]
    v = v[0:ranks[core_idx], :]
    if tt.is_tt_matrix:
      core_shape = (ranks[core_idx], curr_mode_left, curr_mode_right,
                    ranks[core_idx + 1])
    else:
      core_shape = (ranks[core_idx], curr_mode, ranks[core_idx + 1])
    tt_cores[core_idx] = jnp.reshape(v, core_shape)
    prev_core_shape = (-1, rows)
    tt_cores[core_idx - 1] = jnp.reshape(tt_cores[core_idx - 1], prev_core_shape)
    tt_cores[core_idx - 1] = jnp.matmul(tt_cores[core_idx - 1], u)
    tt_cores[core_idx - 1] = jnp.matmul(tt_cores[core_idx - 1], jnp.diag(s))

  if tt.is_tt_matrix:
    core_shape = (ranks[0], raw_shape[0][0], raw_shape[1][0], ranks[1])
  else:
    core_shape = (ranks[0], raw_shape[0], ranks[1])
  tt_cores[0] = jnp.reshape(tt_cores[0], core_shape)

  if tt.is_tt_matrix:
    return TTMatrix(tt_cores)
  else:
    return TT(tt_cores)
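
A usage sketch (assuming `TT`, `add`, `orthogonalize` and the `round` defined above are available from this module): adding a tensor to itself doubles the TT-ranks without adding information, and rounding recovers the minimal ranks exactly.

import jax

keys = jax.random.split(jax.random.PRNGKey(0), 3)
a = TT([jax.random.normal(keys[0], (1, 4, 2)),
        jax.random.normal(keys[1], (2, 5, 2)),
        jax.random.normal(keys[2], (2, 6, 1))])
b = add(a, a)                # elementwise 2 * a, TT-ranks (1, 4, 4, 1)
c = round(b, max_tt_rank=2)  # SVD-based truncation back to ranks (1, 2, 2, 1)
print(b.tt_ranks, c.tt_ranks)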