예제 #1
0
    def __init__(self,
                 tt_cores,
                 shape=None,
                 tt_ranks=None,
                 batch_size=None,
                 convert_to_tensors=True,
                 name="TensorTrainBatch"):
        """Creates a `TensorTrainBatch`.

        Args:
          tt_cores: A tuple of 4d or 5d tensor-like objects of shape
            `[batch_size, r_k-1, n_k, r_k]` or
            `[batch_size, r_k-1, n_k, m_k, r_k]`.
            Tensor-like can be a numpy array, a tf.Tensor, or a tf.Variable.
          shape: Shape of the underlying tensor. If None, it is inferred
            from the TT-cores.
          tt_ranks: a TensorShape of length d+1 (d is the dimensionality of
            the underlying tensor). The first and the last ranks are assumed
            to be equal to 1. If None, the ranks are inferred from the cores.
          batch_size: number of elements in the batch. If None, it is read
            from the static first dimension of the first TT-core (which may
            itself be unknown, e.g. when the whole shape of a core is
            unknown).
          convert_to_tensors: bool, if True then each element of `tt_cores`
            is converted into a tf.Tensor (e.g. to initialize from np.array).
          name: The name of the scope the conversion ops are created in.

        Raises:
          ValueError if the provided TT-cores are not valid, have different
            dtypes, or are inconsistent with the provided batch_size, shape,
            or TT-ranks.
        """
        tt_cores = list(tt_cores)
        if convert_to_tensors:
            with tf.name_scope(name):
                # Each core gets its own op name ("core0", "core1", ...)
                # inside the enclosing scope.
                tt_cores = [
                    tf.convert_to_tensor(core, name="core%d" % core_idx)
                    for core_idx, core in enumerate(tt_cores)
                ]

        cores_ok = _are_batch_tt_cores_valid(tt_cores, shape, tt_ranks,
                                             batch_size)
        if not cores_ok:
            raise ValueError(
                'The tt_cores provided to TensorTrainBatch constructor are '
                'not valid, have different dtypes, or are inconsistent '
                'with the provided batch_size, shape, or TT-ranks.')

        self._tt_cores = tuple(tt_cores)

        # An explicitly given batch size wins; otherwise fall back to the
        # static batch dimension of the first core.
        if batch_size is not None:
            self._batch_size = batch_size
        else:
            self._batch_size = tt_cores[0].get_shape()[0].value

        raw_shape = shapes.clean_raw_shape(shape)
        if raw_shape is None:
            raw_shape = _infer_batch_raw_shape(self._tt_cores)
        self._raw_shape = raw_shape

        if tt_ranks is None:
            self._tt_ranks = _infer_batch_tt_ranks(self._tt_cores)
        else:
            self._tt_ranks = tf.TensorShape(tt_ranks)
예제 #2
0
def _are_batch_tt_cores_valid(tt_cores, shape, tt_ranks, batch_size):
    """Check if dimensions of the TT-cores are consistent and the dtypes coincide.

  Args:
    tt_cores: a tuple of `Tensor` objects
    shape: An np.array, a tf.TensorShape (for tensors), a tuple of
      tf.TensorShapes (for TT-matrices or tensors), or None
    tt_ranks: An np.array or a tf.TensorShape of length len(tt_cores)+1.
    batch_size: a number or None

  Returns:
    boolean, True if the dimensions and dtypes are consistent.
  """
    shape = shapes.clean_raw_shape(shape)
    num_dims = len(tt_cores)

    for core_idx in range(1, num_dims):
        if tt_cores[core_idx].dtype != tt_cores[0].dtype:
            return False
    try:
        for core_idx in range(num_dims):
            curr_core_shape = tt_cores[core_idx].get_shape()
            if len(curr_core_shape) != len(tt_cores[0].get_shape()):
                # Shapes are inconsistent.
                return False
            if batch_size is not None and curr_core_shape[0].value is not None:
                if curr_core_shape[0].value != batch_size:
                    # The TT-cores are not aligned with the given batch_size.
                    return False
            if shape is not None:
                for i in range(len(shape)):
                    if curr_core_shape[i + 2] != shape[i][core_idx]:
                        # The TT-cores are not aligned with the given shape.
                        return False
            if core_idx >= 1:
                prev_core_shape = tt_cores[core_idx - 1].get_shape()
                if curr_core_shape[1] != prev_core_shape[-1]:
                    # TT-ranks are inconsistent.
                    return False
            if tt_ranks is not None:
                if curr_core_shape[1] != tt_ranks[core_idx]:
                    # The TT-ranks are not aligned with the TT-cores shape.
                    return False
                if curr_core_shape[-1] != tt_ranks[core_idx + 1]:
                    # The TT-ranks are not aligned with the TT-cores shape.
                    return False
        if tt_cores[0].get_shape()[1] != 1 or tt_cores[-1].get_shape(
        )[-1] != 1:
            # The first or the last rank is not 1.
            return False
    except ValueError:
        # The shape of the TT-cores is undetermined, can not validate it.
        pass
    return True
예제 #3
0
    def __init__(self,
                 tt_cores,
                 shape=None,
                 tt_ranks=None,
                 convert_to_tensors=True):
        """Creates a `TensorTrain`.

        Args:
          tt_cores: A tuple of 3d or 4d tensor-like objects of shape
            `[r_k-1, n_k, r_k]` or `[r_k-1, n_k, m_k, r_k]`.
            Tensor-like can be a numpy array, a tf.Tensor, or a tf.Variable.
          shape: Shape of the underlying tensor. If None, tries to infer it
            from the cores (not always possible even if it should be, e.g.
            if ranks are unknown, then the whole shape of a core can be
            unknown).
          tt_ranks: a TensorShape of length d+1 (d is the dimensionality of
            the underlying tensor). The first and the last ranks are assumed
            to be equal to 1. If None, tries to infer the ranks from the
            cores.
          convert_to_tensors: bool, if True then each element of `tt_cores`
            is converted into a tf.Tensor (e.g. to initialize from np.array).

        Raises:
          ValueError if the provided TT-cores are not valid, have different
            dtypes, or are inconsistent with the provided shape or TT-ranks.
        """
        tt_cores = list(tt_cores)
        if convert_to_tensors:
            # The name scope only groups the conversion ops under a common
            # "TensorTrain" prefix. Just the scope name is passed: a second
            # positional argument is `default_name` in the TF1 signature
            # (ignored when a name is given) and is rejected by the
            # single-argument TF2 signature, so the cores must not be
            # passed there.
            with tf.name_scope("TensorTrain"):
                tt_cores = [
                    tf.convert_to_tensor(core, name="core%d" % core_idx)
                    for core_idx, core in enumerate(tt_cores)
                ]

        if not _are_tt_cores_valid(tt_cores, shape, tt_ranks):
            raise ValueError(
                'The tt_cores provided to TensorTrain constructor are '
                'not valid, have different dtypes, or are inconsistent '
                'with the provided shape or TT-ranks.')

        self._tt_cores = tuple(tt_cores)

        # Prefer the explicitly given shape and ranks; fall back to
        # inferring them from the cores.
        self._raw_shape = shapes.clean_raw_shape(shape)
        if self._raw_shape is None:
            self._raw_shape = _infer_raw_shape(self._tt_cores)

        if tt_ranks is None:
            self._tt_ranks = _infer_tt_ranks(self._tt_cores)
        else:
            self._tt_ranks = tf.TensorShape(tt_ranks)