def __init__(self, tt_cores, shape=None, tt_ranks=None, batch_size=None,
             convert_to_tensors=True, name="TensorTrainBatch"):
  """Creates a `TensorTrainBatch`.

  Args:
    tt_cores: A tuple of 4d or 5d tensor-like objects of shape
      `[batch_size, r_{k-1}, n_k, r_k]` or
      `[batch_size, r_{k-1}, n_k, m_k, r_k]`.
      Tensor-like can be a numpy array, a tf.Tensor, or a tf.Variable.
    shape: Shape of the underlying tensor. If None, tries to infer from the
      TT-cores.
    tt_ranks: a TensorShape of length d+1 (d is the dimensionality of
      the underlying tensor). The first and the last ranks are assumed to
      equal 1. If None, tries to infer the ranks from the cores.
    batch_size: number of elements in the batch. If None, tries to infer
      from the TT-cores (not always possible even if it should be, e.g. if
      ranks are unknown, then the whole shape of a core can be unknown).
    convert_to_tensors: bool, if True then convert each element of the
      tt_cores tuple into a tf.Tensor (e.g. to initialize from np.array).
    name: The name scope used for the ops that convert the cores to tensors.

  Raises:
    ValueError if the provided TT-cores are not valid or inconsistent with
      the provided batch_size, shape, or TT-ranks.
  """
  tt_cores = list(tt_cores)
  if convert_to_tensors:
    with tf.name_scope(name):
      # Per-core op names use their own expression so the `name` argument is
      # not clobbered by the loop.
      tt_cores = [tf.convert_to_tensor(core, name="core%d" % core_idx)
                  for core_idx, core in enumerate(tt_cores)]

  if not _are_batch_tt_cores_valid(tt_cores, shape, tt_ranks, batch_size):
    raise ValueError(
        'The tt_cores provided to TensorTrainBatch constructor '
        'are not valid, have different dtypes, or are '
        'inconsistent with the provided batch_size, shape, or '
        'TT-ranks.')

  self._tt_cores = tuple(tt_cores)

  if batch_size is None:
    # Static size of the leading (batch) axis of the first core; this is
    # None when that dimension is not statically known.
    self._batch_size = tt_cores[0].get_shape()[0].value
  else:
    self._batch_size = batch_size

  self._raw_shape = shapes.clean_raw_shape(shape)
  if self._raw_shape is None:
    self._raw_shape = _infer_batch_raw_shape(self._tt_cores)

  self._tt_ranks = None if tt_ranks is None else tf.TensorShape(tt_ranks)
  if self._tt_ranks is None:
    self._tt_ranks = _infer_batch_tt_ranks(self._tt_cores)
def _are_batch_tt_cores_valid(tt_cores, shape, tt_ranks, batch_size):
  """Check if dimensions of the TT-cores are consistent and the dtypes coincide.

  Args:
    tt_cores: a tuple (or list) of `Tensor` objects, one per TT-core, each
      expected to be `[batch_size, r_{k-1}, n_k, r_k]` or
      `[batch_size, r_{k-1}, n_k, m_k, r_k]`.
    shape: An np.array, a tf.TensorShape (for tensors), a tuple of
      tf.TensorShapes (for TT-matrices or tensors), or None.
    tt_ranks: An np.array or a tf.TensorShape of length len(tt_cores)+1.
    batch_size: a number or None.

  Returns:
    boolean, True if the dimensions and dtypes are consistent. A statically
    unknown batch dimension is not compared against `batch_size`. If a
    ValueError is raised while inspecting shapes (e.g. a core whose rank is
    unknown), the remaining checks are abandoned and True is returned. An
    empty `tt_cores` is reported as valid.
  """
  shape = shapes.clean_raw_shape(shape)
  num_dims = len(tt_cores)

  # All cores must share the dtype of the first one.
  for core in tt_cores[1:]:
    if core.dtype != tt_cores[0].dtype:
      return False

  if num_dims == 0:
    # Nothing to compare against; keep the historical (vacuous) result.
    return True

  try:
    first_core_shape = tt_cores[0].get_shape()
    # The boundary TT-ranks must equal 1. This does not depend on the per-core
    # loop below, so check it once up front instead of on every iteration.
    if first_core_shape[1] != 1 or tt_cores[-1].get_shape()[-1] != 1:
      # The first or the last rank is not 1.
      return False

    core_ndims = len(first_core_shape)
    # Core axis `mode_idx + 2` is compared against shape[mode_idx] below, so
    # `shape` must describe exactly (core rank - batch axis - 2 rank axes)
    # modes; otherwise we would compare against a TT-rank axis or index past
    # the end of the core shape.
    if shape is not None and len(shape) != core_ndims - 3:
      return False

    for core_idx, core in enumerate(tt_cores):
      core_shape = core.get_shape()
      if len(core_shape) != core_ndims:
        # Shapes are inconsistent.
        return False

      static_batch = core_shape[0].value
      if (batch_size is not None and static_batch is not None and
          static_batch != batch_size):
        # The TT-cores are not aligned with the given batch_size.
        return False

      if shape is not None:
        for mode_idx, mode_shape in enumerate(shape):
          if core_shape[mode_idx + 2] != mode_shape[core_idx]:
            # The TT-cores are not aligned with the given shape.
            return False

      if core_idx >= 1:
        prev_core_shape = tt_cores[core_idx - 1].get_shape()
        if core_shape[1] != prev_core_shape[-1]:
          # TT-ranks are inconsistent.
          return False

      if tt_ranks is not None:
        if (core_shape[1] != tt_ranks[core_idx] or
            core_shape[-1] != tt_ranks[core_idx + 1]):
          # The TT-ranks are not aligned with the TT-cores shape.
          return False
  except ValueError:
    # The shape of the TT-cores is undetermined, can not validate it.
    # Deliberate best-effort: treat as valid.
    pass
  return True
def __init__(self, tt_cores, shape=None, tt_ranks=None,
             convert_to_tensors=True, name="TensorTrain"):
  """Creates a `TensorTrain`.

  Args:
    tt_cores: A tuple of 3d or 4d tensor-like objects of shape
      `[r_{k-1}, n_k, r_k]` or `[r_{k-1}, n_k, m_k, r_k]`.
      Tensor-like can be a numpy array, a tf.Tensor, or a tf.Variable.
    shape: Shape of the underlying tensor. If None, tries to infer from the
      cores (not always possible even if it should be, e.g. if ranks are
      unknown, then the whole shape of a core can be unknown).
    tt_ranks: a TensorShape of length d+1 (d is the dimensionality of
      the underlying tensor). The first and the last ranks are assumed to
      equal 1. If None, tries to infer the ranks from the cores.
    convert_to_tensors: bool, if True then convert each element of the
      tt_cores tuple into a tf.Tensor (e.g. to initialize from np.array).
    name: The name scope used for the ops that convert the cores to tensors.

  Raises:
    ValueError if the provided TT-cores are not valid or inconsistent with
      the provided shape or TT-ranks.
  """
  tt_cores = list(tt_cores)
  if convert_to_tensors:
    # Previously the cores were passed positionally as the second argument of
    # tf.name_scope, which in TF1 is `default_name` (not `values`) and is
    # ignored whenever a name is given, so scoping by name alone is
    # equivalent.
    with tf.name_scope(name):
      # Per-core op names use their own expression so the `name` argument is
      # not clobbered by the loop.
      tt_cores = [tf.convert_to_tensor(core, name="core%d" % core_idx)
                  for core_idx, core in enumerate(tt_cores)]

  if not _are_tt_cores_valid(tt_cores, shape, tt_ranks):
    raise ValueError(
        'The tt_cores provided to TensorTrain constructor are '
        'not valid, have different dtypes, or are inconsistent '
        'with the provided shape or TT-ranks.')

  self._tt_cores = tuple(tt_cores)

  self._raw_shape = shapes.clean_raw_shape(shape)
  if self._raw_shape is None:
    self._raw_shape = _infer_raw_shape(self._tt_cores)

  self._tt_ranks = None if tt_ranks is None else tf.TensorShape(tt_ranks)
  if self._tt_ranks is None:
    self._tt_ranks = _infer_tt_ranks(self._tt_cores)