def _mul_by_scalar(tt, c):
    """Return a TT-object equal to ``c`` times ``tt``.

    Only the first TT-core is multiplied by ``c``; since the represented
    object is a product of its cores, scaling one core scales the whole
    object. The remaining cores are reused unchanged.

    :param tt: (possibly wrapped) `TT-Tensor` or `TT-Matrix`
    :param c: scalar multiplier
    :return: `TT-Tensor` or `TT-Matrix` of the same kind as ``tt``
    """
    plain = unwrap_tt(tt)
    cores = list(plain.tt_cores)
    scaled_cores = [c * cores[0]] + cores[1:]
    container = TTMatrix if plain.is_tt_matrix else TT
    return container(scaled_cores)
def deltas_to_tangent(deltas: List[jnp.ndarray], tt: TTTensOrMat) -> TTTensOrMat:
    """Converts deltas representation of tangent space vector to `TT-object`.

    Takes as input a list of [dP1, ..., dPd] and returns
    dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd,
    where U_i are the cores of the left-orthogonalized `tt` and V_i the cores
    of the right-orthogonalized `tt`.

    The result is built core by core as block matrices (along the TT-rank
    axes): the first core is [dP1, U1], every interior core is
    [[V_i, 0], [dP_i, U_i]] and the last core is [V_d; dP_d]. Multiplying
    these blocks reproduces the sum above. For a single-core `tt` the sum is
    just dP1.

    This function is hard to use correctly because deltas should obey the
    so called gauge conditions. If they don't, the function will silently
    return incorrect result. That is why this function is not imported in
    __init__.

    :param deltas: a list of deltas (essentially `TT-cores`) obeying the
      gauge conditions.
    :param tt: object on which the tangent space tensor represented by delta
      is projected.
    :type tt: `TT-Tensor` or `TT-Matrix`
    :return: object constructed from deltas, that is from the tangent space at
      point `tt`.
    :rtype: `TT-Tensor` or `TT-Matrix`
    """
    container = TTMatrix if tt.is_tt_matrix else TT
    if tt.ndim == 1:
        # With a single core the tangent vector is just dP1. The generic
        # construction below would concatenate dP1 with U1 along the right
        # rank axis and produce a last core of right rank 2, i.e. an invalid
        # TT-object.
        return container([deltas[0]])

    cores = []
    dtype = tt.dtype
    left = orthogonalize(tt)
    right = orthogonalize(left, left_to_right=False)
    # Axis indices of the left / right TT-rank in an (unbatched) core.
    left_rank_dim = 0
    right_rank_dim = 3 if tt.is_tt_matrix else 2
    for i in range(tt.ndim):
        left_tt_core = left.tt_cores[i]
        right_tt_core = right.tt_cores[i]

        if i == 0:
            # [dP1, U1]
            tangent_core = jnp.concatenate((deltas[i], left_tt_core),
                                           axis=right_rank_dim)
        elif i == tt.ndim - 1:
            # [V_d; dP_d]
            tangent_core = jnp.concatenate((right_tt_core, deltas[i]),
                                           axis=left_rank_dim)
        else:
            # [[V_i, 0], [dP_i, U_i]]; the zero block has V_i's left rank
            # and U_i's right rank.
            rank_1 = right.tt_ranks[i]
            rank_2 = left.tt_ranks[i + 1]
            if tt.is_tt_matrix:
                mode_size_n = tt.raw_tensor_shape[0][i]
                mode_size_m = tt.raw_tensor_shape[1][i]
                shape = [rank_1, mode_size_n, mode_size_m, rank_2]
            else:
                mode_size_n = tt.shape[i]
                shape = [rank_1, mode_size_n, rank_2]
            zeros = jnp.zeros(shape, dtype=dtype)
            upper = jnp.concatenate((right_tt_core, zeros), axis=right_rank_dim)
            lower = jnp.concatenate((deltas[i], left_tt_core), axis=right_rank_dim)
            tangent_core = jnp.concatenate((upper, lower), axis=left_rank_dim)
        cores.append(tangent_core)
    return container(cores)
def vectorized_func(*args, **kwargs):
    """Call ``func`` on TT-objects, vectorizing over their batch dimensions.

    ``func`` and ``num_batch_args`` come from the enclosing scope (not visible
    here). If the first argument has no batch dimensions, ``func`` is called
    directly. Otherwise ``func`` is wrapped in one ``jax.vmap`` per batch
    dimension of the first argument.

    When ``num_batch_args`` is given, only the first ``num_batch_args``
    positional arguments are mapped over axis 0. The remaining positional
    arguments and all keyword arguments are passed through unmapped. When it
    is ``None``, every positional argument is mapped over axis 0. Keyword
    arguments are then handed to ``jax.vmap``, which maps them over axis 0.

    If the second positional argument is also a TT-object, the batch shapes
    of the first two arguments are first broadcast against each other.

    Raises:
      ValueError: if the first two arguments are of different kinds
        (TT-tensor vs TT-matrix) or their batch shapes are not broadcastable.
    """
    import functools

    tt_arg = args[0]
    # TODO: what if only kwargs are present?
    if tt_arg.num_batch_dims == 0:
        return func(*args, **kwargs)

    # `jax.vmap`'s `in_axes` describes the positional arguments only, so it
    # must have one entry per positional argument and none for kwargs.
    mapped_func = func
    call_kwargs = kwargs
    if num_batch_args is not None:
        num_args = num_batch_args
        in_axis = [0] * num_batch_args + [None] * (len(args) - num_batch_args)
        if kwargs:
            # Keyword arguments count as non-batched arguments: bind them
            # outside of `jax.vmap` so that they are not mapped over axis 0.
            mapped_func = functools.partial(func, **kwargs)
            call_kwargs = {}
    else:
        num_args = len(args) + len(kwargs)
        in_axis = [0] * len(args)

    new_args = list(args)
    # `len(args) > 1` guards against indexing a missing second positional
    # argument when the other arguments were passed as keywords.
    if (num_args > 1 and len(args) > 1
            and isinstance(args[1], (TTMatrix, TT))):
        if args[0].is_tt_matrix != args[1].is_tt_matrix:
            raise ValueError(
                'Types of the arguments are different.')
        if not are_batches_broadcastable(args[0], args[1]):
            raise ValueError(
                'The batch sizes are different and not 1, '
                'broadcasting is not available.')
        broadcast_shape = np.maximum(list(args[0].batch_shape),
                                     list(args[1].batch_shape))
        # Number of trailing (non-batch) axes of a core: rank, mode(s), rank.
        if args[0].is_tt_matrix:
            num_core_dims, container = 4, TTMatrix
        else:
            num_core_dims, container = 3, TT
        for i, tt in enumerate(args[:2]):
            new_cores = [
                jnp.broadcast_to(
                    core,
                    list(broadcast_shape) + list(core.shape[-num_core_dims:]))
                for core in tt.tt_cores
            ]
            new_args[i] = container(new_cores)

    vmapped = mapped_func
    for _ in range(tt_arg.num_batch_dims):
        vmapped = jax.vmap(vmapped, in_axis)
    return vmapped(*new_args, **call_kwargs)
def testRound2d(self):
    """Rounding a rank-10 2-D TT to rank `rank` matches the truncated SVD."""
    dtype = jnp.float32
    rank = 5
    np.random.seed(0)
    x = np.random.randn(10, 20).astype(dtype)
    u, s, v = np.linalg.svd(x, full_matrices=False)
    # Exact rank-10 TT representation of x: x = (u @ diag(s)) @ v.
    core_1 = (u @ np.diag(s)).reshape(1, 10, 10)
    core_2 = v.reshape(10, 20, 1)
    tt = TT((core_1, core_2))
    # The best rank-`rank` approximation keeps the leading singular triplets
    # (Eckart-Young), which is what TT-rounding of a 2-D TT must reproduce.
    truncated_x = u[:, :rank] @ np.diag(s[:rank]) @ v[:rank, :]
    # Use the same `rank` as the reference above instead of a duplicated
    # literal, so the two cannot drift apart.
    rounded = decompositions.round(tt, rank)
    self.assertAllClose(truncated_x, ops.full(rounded), rtol=1e-5, atol=1e-5)
def vector_to_tensor(tt):
    """Convert an N x 1 or 1 x N TT-matrix into a TT-tensor.

    Every core of the TT-matrix has one unit-size mode axis. That axis is
    dropped: axis 1 when the matrix has a single row, axis 2 otherwise.

    :type tt: `TT-Matrix`
    :param tt: TT-matrix with at least one dimension equal to one
    :return: `TT`
    :rtype: `TT`
    :raises [ValueError]: if the argument is not a TT-matrix, or if neither
      matrix dimension equals one
    """
    if not (isinstance(tt, TTMatrix) and tt.is_tt_matrix):
        raise ValueError('The argument should be a TT-matrix')
    if 1 not in (tt.shape[0], tt.shape[1]):
        raise ValueError(
            'At least one of matrix dimensions should be equal to one')
    unit_axis = 1 if tt.shape[0] == 1 else 2
    return TT([jnp.squeeze(core, unit_axis) for core in tt.tt_cores])
def tensor(rng, shape, tt_rank=2, batch_shape=None, dtype=jnp.float32):
    """Generate a random `TT-Tensor` of the given shape and `TT-rank`.

    Every core entry is drawn from a standard normal distribution, using an
    independent PRNG key per core.

    :param rng: JAX PRNG key
    :type rng: random state is described by two unsigned 32-bit integers
    :param shape: desired tensor shape
    :type shape: array
    :param tt_rank: desired `TT-ranks` of `TT-Tensor`
    :type tt_rank: single number for equal `TT-ranks` (boundary ranks are set
      to 1) or array of length ``len(shape) + 1`` specifying all `TT-ranks`
    :param batch_shape: desired batch shape of `TT-Tensor`
    :type batch_shape: array
    :param dtype: type of elements in `TT-Tensor`
    :type dtype: `dtype`
    :return: generated `TT-Tensor`
    :rtype: TT
    :raises [ValueError]: if `tt_rank` is an array whose length is not
      ``len(shape) + 1``
    """
    shape = np.array(shape)
    tt_rank = np.array(tt_rank)
    batch_shape = list(batch_shape) if batch_shape else []

    num_dims = shape.size
    if tt_rank.size == 1:
        # Equal interior ranks, boundary ranks fixed to 1.
        tt_rank = tt_rank * np.ones(num_dims - 1)
        tt_rank = np.insert(tt_rank, 0, 1)
        tt_rank = np.append(tt_rank, 1)
    elif tt_rank.size != num_dims + 1:
        # Without this check a short vector fails with an opaque IndexError
        # and a long one silently builds cores from the wrong ranks.
        raise ValueError('tt_rank should be a number or a vector of size (d+1) '
                         'where d is the number of dimensions of the tensor.')
    tt_rank = tt_rank.astype(int)

    rng_arr = jax.random.split(rng, num_dims)
    tt_cores = []
    for i in range(num_dims):
        curr_core_shape = batch_shape + [tt_rank[i], shape[i], tt_rank[i + 1]]
        tt_cores.append(
            jax.random.normal(rng_arr[i], curr_core_shape, dtype=dtype))
    return TT(tt_cores)
def new_func(*args):
    """Evaluate the (possibly fused) TT-einsum core by core.

    The i-th cores of all arguments are contracted with the vanilla einsum
    obtained from the resolved TT-einsum. The resulting axes are then grouped
    into (left-rank axes, mode axes, right-rank axes), and each rank group is
    merged into a single TT-rank axis. If any argument is a `WrappedTT`, the
    einsums are fused first and the result is wrapped again.
    """
    is_matrix = args[0].is_tt_matrix
    # Must not be named `tt_einsum`: assigning to that name here would make it
    # local and hide the enclosing `tt_einsum`.
    resolved = tt_einsum.resolve_i_or_ij(is_matrix)
    fusing = any(isinstance(arg, WrappedTT) for arg in args)
    if fusing:
        resolved, args = _fuse_tt_einsums(resolved, args)
    vanilla = resolved.to_vanilla_einsum()

    n_batch = args[0].num_batch_dims
    # TODO: support broadcasting.
    batch_prefix = list(args[0].batch_shape)

    # TODO: do in parallel w.r.t. cores.
    # TODO: use optimal einsum.
    out_cores = []
    num_cores = len(args[0].tt_cores)
    for idx in range(num_cores):
        contracted = oe.contract(vanilla, *[arg.tt_cores[idx] for arg in args],
                                 backend='jax')
        n_left = len(resolved.output[0])
        n_modes = len(resolved.output[1])
        left_part, mode_part, right_part = np.split(
            contracted.shape[n_batch:], (n_left, n_left + n_modes))
        target_shape = (batch_prefix
                        + [np.prod(left_part)]
                        + mode_part.tolist()
                        + [np.prod(right_part)])
        out_cores.append(contracted.reshape(target_shape))

    result = (TTMatrix if is_matrix else TT)(out_cores)
    return WrappedTT(result, args, resolved) if fusing else result
def add(tt_a, tt_b):
    """Elementwise sum ``tt_a + tt_b`` of two `TT-objects`.

    Both arguments must have equal types and tensor shapes. Their batch
    shapes must be broadcastable, so a tensor train with batch size 7 can be
    added to one with batch size 1:
    ``tt_batch.add(tt_single.batch_loc[np.newaxis])``, where
    ``tt_single.batch_loc[np.newaxis]`` creates a singleton batch dimension.

    :type tt_a: `TT-Tensor` or `TT-Matrix`
    :param tt_a: first summand
    :type tt_b: `TT-Tensor` or `TT-Matrix`
    :param tt_b: second summand
    :rtype: `TT-Tensor` or `TT-Matrix`
    :return: `tt_a + tt_b`
    :raises [ValueError]: if the types or tensor shapes differ, or if the
      batch shapes are not broadcastable
    """
    tt_a, tt_b = unwrap_tt(tt_a), unwrap_tt(tt_b)

    if not are_shapes_equal(tt_a, tt_b):
        raise ValueError('Types of the arguments or their tensor '
                         'shapes are different, addition is not '
                         'available.')
    if not are_batches_broadcastable(tt_a, tt_b):
        raise ValueError('The batch sizes are different and not 1, '
                         'broadcasting is not available.')

    if tt_a.is_tt_matrix:
        return TTMatrix(_add_matrix_cores(tt_a, tt_b))
    return TT(_add_tensor_cores(tt_a, tt_b))
def _orthogonalize_tt_cores_right_to_left(tt): """Orthogonalize TT-cores of a TT-object. Args: tt: TT-tensor or TT-matrix. Returns: TT-tensor or TT-matrix. """ # Right to left orthogonalization. num_dims = tt.ndim if tt.is_tt_matrix: raw_shape = tt.raw_tensor_shape else: raw_shape = tt.shape tt_ranks = tt.tt_ranks prev_rank = tt_ranks[num_dims] # Copy cores references so we can change the cores. tt_cores = list(tt.tt_cores) for core_idx in range(num_dims - 1, 0, -1): curr_core = tt_cores[core_idx] # TT-ranks could have changed on the previous iteration, so `tt_ranks` can # be outdated for the current TT-rank, but should be valid for the next # TT-rank. curr_rank = prev_rank prev_rank = tt_ranks[core_idx] if tt.is_tt_matrix: curr_mode_left = raw_shape[0][core_idx] curr_mode_right = raw_shape[1][core_idx] curr_mode = curr_mode_left * curr_mode_right else: curr_mode = raw_shape[core_idx] qr_shape = (prev_rank, curr_mode * curr_rank) curr_core = jnp.reshape(curr_core, qr_shape) curr_core, triang = jnp.linalg.qr(curr_core.T) curr_core = curr_core.T triang = triang.T triang_shape = triang.shape # The TT-rank could have changed: if qr_shape is e.g. 4 x 10, than q would # be of size 4 x 4 and r would be 4 x 10, which means that the next rank # should be changed to 4. prev_rank = triang_shape[1] if tt.is_tt_matrix: new_core_shape = (prev_rank, curr_mode_left, curr_mode_right, curr_rank) else: new_core_shape = (prev_rank, curr_mode, curr_rank) tt_cores[core_idx] = jnp.reshape(curr_core, new_core_shape) prev_core = jnp.reshape(tt_cores[core_idx - 1], (-1, triang_shape[0])) tt_cores[core_idx - 1] = jnp.matmul(prev_core, triang) if tt.is_tt_matrix: first_core_shape = (1, raw_shape[0][0], raw_shape[1][0], prev_rank) else: first_core_shape = (1, raw_shape[0], prev_rank) tt_cores[0] = jnp.reshape(tt_cores[0], first_core_shape) if tt.is_tt_matrix: return TTMatrix(tt_cores) else: return TT(tt_cores)
def _orthogonalize_tt_cores_left_to_right(tt):
    """Orthogonalize TT-cores of a TT-object from left to right.

    Sweeps over all cores except the last one. Each core is unfolded into a
    (left rank * mode size) x (right rank) matrix and replaced by the Q factor
    of its QR decomposition (orthonormal columns). The R factor is multiplied
    into the next core. The last core absorbs the remaining factor.

    Args:
      tt: TT-tensor or TT-matrix.

    Returns:
      TT-tensor or TT-matrix (same kind as `tt`) representing the same
      object, with every core except the last one left-orthogonal. TT-ranks
      may shrink.
    """
    num_cores = tt.ndim
    is_matrix = tt.is_tt_matrix
    raw_shape = tt.raw_tensor_shape if is_matrix else tt.shape
    input_ranks = tt.tt_ranks
    # Shallow copy of the core list; entries are replaced during the sweep.
    cores = list(tt.tt_cores)

    # Left rank of the core currently being processed. Ranks can shrink
    # during the sweep, so after the first step this comes from the previous
    # QR rather than from `input_ranks` (which stays valid for right ranks).
    left_rank = input_ranks[0]
    for k in range(num_cores - 1):
        right_rank = input_ranks[k + 1]
        if is_matrix:
            mode_dims = (raw_shape[0][k], raw_shape[1][k])
            mode_size = mode_dims[0] * mode_dims[1]
        else:
            mode_size = raw_shape[k]
            mode_dims = (mode_size,)

        unfolding = jnp.reshape(cores[k], (left_rank * mode_size, right_rank))
        q, r = jnp.linalg.qr(unfolding)
        # Reduced QR: q has min(rows, right_rank) columns, which becomes the
        # new right rank of core k (e.g. a 4 x 10 unfolding gives a 4 x 4 q).
        new_rank = r.shape[0]
        cores[k] = jnp.reshape(q, (left_rank, *mode_dims, new_rank))

        following = jnp.reshape(cores[k + 1], (r.shape[1], -1))
        cores[k + 1] = jnp.matmul(r, following)
        left_rank = new_rank

    if is_matrix:
        last_shape = (left_rank, raw_shape[0][-1], raw_shape[1][-1], 1)
    else:
        last_shape = (left_rank, raw_shape[-1], 1)
    cores[-1] = jnp.reshape(cores[-1], last_shape)

    return (TTMatrix if is_matrix else TT)(cores)
def round(tt, max_tt_rank=None, epsilon=None):
    """Tensor Train rounding procedure, returns a `TT-object` with smaller `TT-ranks`.

    `tt` is first orthogonalized with `orthogonalize`. The cores are then
    swept from the last one down to the second one. At each step the current
    core's unfolding is compressed with a truncated SVD: the core is replaced
    by the kept right singular vectors, and ``u @ diag(s)`` is multiplied
    into the preceding core.

    :param tt: argument which ranks would be reduced
    :type tt: `TT-Tensor` or `TT-Matrix`
    :type max_tt_rank: int, list of ints or None
    :param max_tt_rank:

      - If `None` (default), the `TT-ranks` are not restricted.
      - If a number, then it defines the maximal `TT-rank` of the result.
      - If a list of numbers, then `max_tt_rank` length should be d+1
        (where d is the number of dimensions) and `max_tt_rank[i]`
        defines the maximal (i+1)-th `TT-rank` of the result.
        The following two versions are equivalent

          - ``max_tt_rank = r``
          - ``max_tt_rank = [1] + [r] * (d-1) + [1]``

    :type epsilon: float or None
    :param epsilon: accuracy-based rounding (a result guaranteed to satisfy
      `||res - tt||_F / ||tt||_F <= epsilon`) is not implemented yet; only
      `None` is currently accepted.
    :return: `TT-object` with reduced `TT-ranks`
    :rtype: `TT-Tensor` or `TT-Matrix`
    :raises ValueError: if any entry of `max_tt_rank` is less than 1, or if
      `max_tt_rank` is neither a number nor a vector of length d + 1 where d
      is the number of dimensions of the input tensor.
    :raises NotImplementedError: if `epsilon` is not `None`.
    """
    if max_tt_rank is None:
        max_tt_rank = np.iinfo(np.int32).max
    num_dims = tt.ndim
    max_tt_rank = np.array(max_tt_rank).astype(np.int32)
    if np.any(max_tt_rank < 1):
        raise ValueError('Maximum TT-rank should be greater or equal to 1.')
    if epsilon is not None:
        raise NotImplementedError('Epsilon is not supported yet.')
    if max_tt_rank.size == 1:
        max_tt_rank = (max_tt_rank * np.ones(num_dims + 1)).astype(np.int32)
    elif max_tt_rank.size != num_dims + 1:
        raise ValueError('max_tt_rank should be a number or a vector of size (d+1) '
                         'where d is the number of dimensions of the tensor.')

    is_matrix = tt.is_tt_matrix
    raw_shape = tt.raw_tensor_shape if is_matrix else tt.shape

    # Copy cores references so we can change the cores.
    tt_cores = list(orthogonalize(tt).tt_cores)
    ranks = [1] * (num_dims + 1)

    # Right to left SVD compression.
    for core_idx in range(num_dims - 1, 0, -1):
        if is_matrix:
            mode_dims = (raw_shape[0][core_idx], raw_shape[1][core_idx])
            curr_mode = mode_dims[0] * mode_dims[1]
        else:
            curr_mode = raw_shape[core_idx]
            mode_dims = (curr_mode,)
        columns = curr_mode * ranks[core_idx + 1]
        curr_core = jnp.reshape(tt_cores[core_idx], [-1, columns])
        rows = curr_core.shape[0]
        # The kept rank cannot exceed either side of the unfolding; this
        # already yields 1 when max_tt_rank[core_idx] == 1.
        ranks[core_idx] = int(min(max_tt_rank[core_idx], rows, columns))
        rank = ranks[core_idx]

        u, s, v = jnp.linalg.svd(curr_core, full_matrices=False)
        u = u[:, :rank]
        s = s[:rank]
        v = v[:rank, :]

        tt_cores[core_idx] = jnp.reshape(
            v, (rank, *mode_dims, ranks[core_idx + 1]))
        prev_core = jnp.reshape(tt_cores[core_idx - 1], (-1, rows))
        tt_cores[core_idx - 1] = jnp.matmul(jnp.matmul(prev_core, u),
                                            jnp.diag(s))

    if is_matrix:
        first_core_shape = (ranks[0], raw_shape[0][0], raw_shape[1][0],
                            ranks[1])
    else:
        first_core_shape = (ranks[0], raw_shape[0], ranks[1])
    tt_cores[0] = jnp.reshape(tt_cores[0], first_core_shape)

    return (TTMatrix if is_matrix else TT)(tt_cores)