Ejemplo n.º 1
0
Archivo: ops.py Proyecto: vseledkin/t3f
def _full_tt_batch(tt):
    """Materialize a `TensorTrainBatch` as a dense, batched tf.Tensor.

    The TT-cores are contracted left to right, keeping the leading batch
    axis separate throughout. For TT-matrices the row and column modes are
    then regrouped so the result can be reshaped to `shapes.lazy_shape(tt)`.

    Args:
      tt: `TensorTrainBatch` object.

    Returns:
      tf.Tensor reshaped to `shapes.lazy_shape(tt)`.
    """
    d = tt.ndims()
    tt_ranks = shapes.lazy_tt_ranks(tt)
    target_shape = shapes.lazy_shape(tt)
    mode_sizes = shapes.lazy_raw_shape(tt)

    cores = tt.tt_cores
    full = cores[0]
    n_batch = shapes.lazy_batch_size(tt)
    for k in range(1, d):
        # Collapse everything contracted so far into one axis (per batch
        # element), then contract its trailing rank axis with core k.
        left = tf.reshape(full, (n_batch, -1, tt_ranks[k]))
        right = tf.reshape(cores[k], (n_batch, tt_ranks[k], -1))
        full = tf.einsum('oqb,obw->oqw', left, right)

    if not tt.is_tt_matrix():
        return tf.reshape(full, target_shape)

    # Unfold to (batch, rows_0, cols_0, rows_1, cols_1, ...).
    interleaved = [n_batch]
    for k in range(d):
        interleaved.extend([mode_sizes[0][k], mode_sizes[1][k]])
    full = tf.reshape(full, interleaved)
    # Keep the batch axis first, then gather all odd (row-mode) axes,
    # then all even non-zero (column-mode) axes.
    perm = [0] + list(range(1, 2 * d + 1, 2)) + list(range(2, 2 * d + 1, 2))
    full = tf.transpose(full, perm)
    return tf.reshape(full, target_shape)
Ejemplo n.º 2
0
Archivo: ops.py Proyecto: vseledkin/t3f
def _full_tt(tt):
    """Materialize a `TensorTrain` as a dense tf.Tensor.

    The TT-cores are contracted left to right with matrix products. For
    TT-matrices the row and column modes are then regrouped so the result
    can be reshaped to `shapes.lazy_shape(tt)`.

    Args:
      tt: `TensorTrain` object.

    Returns:
      tf.Tensor reshaped to `shapes.lazy_shape(tt)`.
    """
    d = tt.ndims()
    tt_ranks = shapes.lazy_tt_ranks(tt)
    target_shape = shapes.lazy_shape(tt)
    mode_sizes = shapes.lazy_raw_shape(tt)

    cores = tt.tt_cores
    full = cores[0]
    for k in range(1, d):
        # Flatten the running product to (-1, r_k) and core k to (r_k, -1)
        # so the shared rank axis can be contracted with a single matmul.
        left = tf.reshape(full, (-1, tt_ranks[k]))
        right = tf.reshape(cores[k], (tt_ranks[k], -1))
        full = tf.matmul(left, right)

    if not tt.is_tt_matrix():
        return tf.reshape(full, target_shape)

    # Unfold to (rows_0, cols_0, rows_1, cols_1, ...).
    interleaved = [size
                   for k in range(d)
                   for size in (mode_sizes[0][k], mode_sizes[1][k])]
    full = tf.reshape(full, interleaved)
    # Gather all even (row-mode) axes first, then all odd (column-mode) axes.
    perm = list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2))
    full = tf.transpose(full, perm)
    return tf.reshape(full, target_shape)
Ejemplo n.º 3
0
Archivo: ops.py Proyecto: vseledkin/t3f
def tt_dense_matmul(tt_matrix_a, matrix_b):
    """Multiplies a TT-matrix by a regular matrix, returns a regular matrix.

    Args:
      tt_matrix_a: `TensorTrain` object containing a TT-matrix of size M x N
      matrix_b: tf.Tensor of size N x P

    Returns:
      tf.Tensor of size M x P

    Raises:
      ValueError: if `tt_matrix_a` is not a TT-matrix, or if the statically
        known number of columns of `tt_matrix_a` differs from the statically
        known number of rows of `matrix_b`.
    """
    if not isinstance(tt_matrix_a,
                      TensorTrain) or not tt_matrix_a.is_tt_matrix():
        raise ValueError('The first argument should be a TT-matrix')

    ndims = tt_matrix_a.ndims()
    a_columns = tt_matrix_a.get_shape()[1].value
    b_rows = matrix_b.get_shape()[0].value
    # The check runs only when both sizes are statically known; otherwise
    # it is skipped here.
    if a_columns is not None and b_rows is not None:
        if a_columns != b_rows:
            # The arguments are shape objects, not ints. Formatting them
            # with %d would raise TypeError and mask this ValueError.
            raise ValueError(
                'Arguments shapes should align got %s and %s instead.' %
                (tt_matrix_a.get_shape(), matrix_b.get_shape()))

    a_shape = shapes.lazy_shape(tt_matrix_a)
    a_raw_shape = shapes.lazy_raw_shape(tt_matrix_a)
    if matrix_b.get_shape().is_fully_defined():
        b_shape = matrix_b.get_shape().as_list()
    else:
        b_shape = tf.shape(matrix_b)
    a_ranks = shapes.lazy_tt_ranks(tt_matrix_a)
    # If A is (i0, ..., id-1) x (j0, ..., jd-1) and B is (j0, ..., jd-1) x K,
    # data is (K, j0, ..., jd-2) x jd-1 x 1
    data = tf.transpose(matrix_b)
    data = tf.reshape(data, (-1, a_raw_shape[1][-1], 1))
    for core_idx in reversed(range(ndims)):
        curr_core = tt_matrix_a.tt_cores[core_idx]
        # On the k = core_idx iteration, after applying einsum the shape of data
        # becomes ik x (ik-1..., id-1, K, j0, ..., jk-1) x rank_k
        data = tf.einsum('aijb,rjb->ira', curr_core, data)
        if core_idx > 0:
            # After reshape the shape of data becomes
            # (ik, ..., id-1, K, j0, ..., jk-2) x jk-1 x rank_k
            new_data_shape = (-1, a_raw_shape[1][core_idx - 1],
                              a_ranks[core_idx])
            data = tf.reshape(data, new_data_shape)
    # At the end the shape of the data is (i0, ..., id-1) x K
    return tf.reshape(data, (a_shape[0], b_shape[1]))
Ejemplo n.º 4
0
def _full_tt(tt, verbose=True):
    """Converts a TensorTrain into a regular tensor or matrix (tf.Tensor).

    This variant passes every TT-core, and every intermediate product of the
    left-to-right contraction, through `fw` before continuing. The debug
    output labels the `fw` results "8bit".

    Args:
      tt: `TensorTrain` object.
      verbose: if True (the default, preserving the previous output), print
        each core and each partial product before and after `fw` is applied.

    Returns:
      tf.Tensor reshaped to `shapes.lazy_shape(tt)`.
    """
    num_dims = tt.ndims()
    ranks = shapes.lazy_tt_ranks(tt)
    shape = shapes.lazy_shape(tt)
    raw_shape = shapes.lazy_raw_shape(tt)

    # Quantize every core, including the first and last ones.
    # TODO(review): an earlier first/last-core special case had two identical
    # branches. Confirm whether boundary cores were meant to be quantized
    # differently.
    quan_core_list = []
    for i in range(num_dims):
        if verbose:
            print('FP core:', tt.tt_cores[i])
        quan_core_list.append(fw(tt.tt_cores[i]))
        if verbose:
            print('8bit core:', quan_core_list[i])

    res = quan_core_list[0]
    for i in range(1, num_dims):
        res = tf.reshape(res, (-1, ranks[i]))
        curr_core = tf.reshape(quan_core_list[i], (ranks[i], -1))
        res = tf.matmul(res, curr_core)
        if verbose:
            print('core multi FP: ', res)
        # Re-quantize the partial product before contracting the next core.
        res = fw(res)
        if verbose:
            print('core multi 8bit: ', res)

    if tt.is_tt_matrix():
        # Unfold to (rows_0, cols_0, rows_1, cols_1, ...), then move all row
        # modes (even axes) before all column modes (odd axes).
        intermediate_shape = []
        for i in range(num_dims):
            intermediate_shape.append(raw_shape[0][i])
            intermediate_shape.append(raw_shape[1][i])
        res = tf.reshape(res, intermediate_shape)
        transpose = (list(range(0, 2 * num_dims, 2)) +
                     list(range(1, 2 * num_dims, 2)))
        res = tf.transpose(res, transpose)
    return tf.reshape(res, shape)
Ejemplo n.º 5
0
 def testLazyShapeOverflow(self):
     # 20 modes of size 10 give 10**20 rows/columns, which exceeds int64.
     # lazy_shape must still report the exact Python-int sizes.
     modes = [10] * 20
     batch = initializers.random_matrix_batch(
         [modes, modes], batch_size=5, dtype=self.dtype)
     expected = [5, 10**20, 10**20]
     self.assertAllEqual(expected, shapes.lazy_shape(batch))