def _full_tt_batch(tt):
  """Converts a TensorTrainBatch into a regular tensor or matrix (tf.Tensor).

  Args:
    tt: `TensorTrainBatch` object.

  Returns:
    tf.Tensor.
  """
  num_dims = tt.ndims()
  ranks = shapes.lazy_tt_ranks(tt)
  shape = shapes.lazy_shape(tt)
  raw_shape = shapes.lazy_raw_shape(tt)

  res = tt.tt_cores[0]
  batch_size = shapes.lazy_batch_size(tt)
  for i in range(1, num_dims):
    res = tf.reshape(res, (batch_size, -1, ranks[i]))
    curr_core = tf.reshape(tt.tt_cores[i], (batch_size, ranks[i], -1))
    # Contract the shared TT-rank index for every element of the batch.
    res = tf.einsum('oqb,obw->oqw', res, curr_core)
  if tt.is_tt_matrix():
    intermediate_shape = [batch_size]
    for i in range(num_dims):
      intermediate_shape.append(raw_shape[0][i])
      intermediate_shape.append(raw_shape[1][i])
    res = tf.reshape(res, intermediate_shape)
    # Permute the interleaved (row, column) modes into all row modes followed
    # by all column modes; index 0 is the batch dimension.
    transpose = [0]
    for i in range(0, 2 * num_dims, 2):
      transpose.append(i + 1)
    for i in range(1, 2 * num_dims, 2):
      transpose.append(i + 1)
    res = tf.transpose(res, transpose)
    return tf.reshape(res, shape)
  else:
    return tf.reshape(res, shape)
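# The einsum string 'oqb,obw->oqw' above is just a batch of ordinary matrix
# products, one per batch element o. A minimal NumPy sketch (the shapes below
# are illustrative assumptions, not values taken from the library):
import numpy as np

batch, q, r, w = 2, 3, 4, 5
res = np.random.rand(batch, q, r)     # partially contracted cores
core = np.random.rand(batch, r, w)    # next TT-core, reshaped to 3-D
step = np.einsum('oqb,obw->oqw', res, core)
assert np.allclose(step, res @ core)  # equivalent to a batched matmul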
def _full_tt(tt):
  """Converts a TensorTrain into a regular tensor or matrix (tf.Tensor).

  Args:
    tt: `TensorTrain` object.

  Returns:
    tf.Tensor.
  """
  num_dims = tt.ndims()
  ranks = shapes.lazy_tt_ranks(tt)
  shape = shapes.lazy_shape(tt)
  raw_shape = shapes.lazy_raw_shape(tt)

  res = tt.tt_cores[0]
  for i in range(1, num_dims):
    res = tf.reshape(res, (-1, ranks[i]))
    curr_core = tf.reshape(tt.tt_cores[i], (ranks[i], -1))
    res = tf.matmul(res, curr_core)
  if tt.is_tt_matrix():
    intermediate_shape = []
    for i in range(num_dims):
      intermediate_shape.append(raw_shape[0][i])
      intermediate_shape.append(raw_shape[1][i])
    res = tf.reshape(res, intermediate_shape)
    transpose = []
    for i in range(0, 2 * num_dims, 2):
      transpose.append(i)
    for i in range(1, 2 * num_dims, 2):
      transpose.append(i)
    res = tf.transpose(res, transpose)
    return tf.reshape(res, shape)
  else:
    return tf.reshape(res, shape)
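# For intuition, a minimal NumPy sketch of the same left-to-right contraction
# for a plain (non-matrix) TT-tensor. `full_from_cores` is a hypothetical
# helper written for illustration; it is not part of the library.
import numpy as np

def full_from_cores(cores):
  # Each core has shape (r_{k-1}, n_k, r_k) with r_0 = r_d = 1.
  res = cores[0]
  for core in cores[1:]:
    r = core.shape[0]
    # Flatten everything left of the shared rank index, then contract it.
    res = res.reshape(-1, r) @ core.reshape(r, -1)
  return res.reshape([c.shape[1] for c in cores])

# Example: a rank-1 TT of a 2 x 3 tensor is just an outer product.
c0 = np.arange(2, dtype=float).reshape(1, 2, 1)
c1 = np.arange(3, dtype=float).reshape(1, 3, 1)
print(full_from_cores([c0, c1]))  # [[0. 0. 0.], [0. 1. 2.]]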
def tt_dense_matmul(tt_matrix_a, matrix_b):
  """Multiplies a TT-matrix by a regular matrix, returns a regular matrix.

  Args:
    tt_matrix_a: `TensorTrain` object containing a TT-matrix of size M x N
    matrix_b: tf.Tensor of size N x P

  Returns:
    tf.Tensor of size M x P
  """
  if not isinstance(tt_matrix_a, TensorTrain) or not tt_matrix_a.is_tt_matrix():
    raise ValueError('The first argument should be a TT-matrix')

  ndims = tt_matrix_a.ndims()
  a_columns = tt_matrix_a.get_shape()[1].value
  b_rows = matrix_b.get_shape()[0].value
  if a_columns is not None and b_rows is not None:
    if a_columns != b_rows:
      raise ValueError('Arguments shapes should align, got %s and %s instead.' %
                       (tt_matrix_a.get_shape(), matrix_b.get_shape()))

  a_shape = shapes.lazy_shape(tt_matrix_a)
  a_raw_shape = shapes.lazy_raw_shape(tt_matrix_a)
  if matrix_b.get_shape().is_fully_defined():
    b_shape = matrix_b.get_shape().as_list()
  else:
    b_shape = tf.shape(matrix_b)
  a_ranks = shapes.lazy_tt_ranks(tt_matrix_a)
  # If A is (i0, ..., id-1) x (j0, ..., jd-1) and B is (j0, ..., jd-1) x K,
  # data is (K, j0, ..., jd-2) x jd-1 x 1
  data = tf.transpose(matrix_b)
  data = tf.reshape(data, (-1, a_raw_shape[1][-1], 1))
  for core_idx in reversed(range(ndims)):
    curr_core = tt_matrix_a.tt_cores[core_idx]
    # On the k = core_idx iteration, after applying einsum the shape of data
    # becomes ik x (ik+1, ..., id-1, K, j0, ..., jk-1) x rank_k
    data = tf.einsum('aijb,rjb->ira', curr_core, data)
    if core_idx > 0:
      # After reshape the shape of data becomes
      # (ik, ..., id-1, K, j0, ..., jk-2) x jk-1 x rank_k
      new_data_shape = (-1, a_raw_shape[1][core_idx - 1], a_ranks[core_idx])
      data = tf.reshape(data, new_data_shape)
  # At the end the shape of the data is (i0, ..., id-1) x K
  return tf.reshape(data, (a_shape[0], b_shape[1]))
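# A minimal usage sketch, assuming TF 1.x and that `initializers.random_matrix`
# follows the same calling convention as the `random_matrix_batch` used in the
# tests below; the mode sizes are illustrative assumptions.
import tensorflow as tf

# An 8 x 27 TT-matrix with row modes (2, 2, 2) and column modes (3, 3, 3).
tt_a = initializers.random_matrix(((2, 2, 2), (3, 3, 3)), tt_rank=2)
b = tf.random_normal((27, 4))
c = tt_dense_matmul(tt_a, b)  # tf.Tensor of shape (8, 4)
# Note that the dense 8 x 27 form of tt_a is never materialized.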
def _full_tt(tt):
  """Converts a TensorTrain into a regular tensor or matrix (tf.Tensor).

  Quantized variant: every core and every intermediate product is passed
  through the quantization function `fw` (defined elsewhere).

  Args:
    tt: `TensorTrain` object.

  Returns:
    tf.Tensor.
  """
  num_dims = tt.ndims()
  ranks = shapes.lazy_tt_ranks(tt)
  shape = shapes.lazy_shape(tt)
  raw_shape = shapes.lazy_raw_shape(tt)

  # Quantize every core (boundary cores are treated the same as interior ones).
  quan_core_list = []
  for i in range(num_dims):
    print('FP core:', tt.tt_cores[i])
    quan_core_list.append(fw(tt.tt_cores[i]))
    print('8bit core:', quan_core_list[i])

  res = quan_core_list[0]
  for i in range(1, num_dims):
    res = tf.reshape(res, (-1, ranks[i]))
    curr_core = tf.reshape(quan_core_list[i], (ranks[i], -1))
    res = tf.matmul(res, curr_core)
    print('core multi FP: ', res)
    # Quantize the intermediate product of the cores.
    res = fw(res)
    print('core multi 8bit: ', res)
  if tt.is_tt_matrix():
    intermediate_shape = []
    for i in range(num_dims):
      intermediate_shape.append(raw_shape[0][i])
      intermediate_shape.append(raw_shape[1][i])
    res = tf.reshape(res, intermediate_shape)
    transpose = []
    for i in range(0, 2 * num_dims, 2):
      transpose.append(i)
    for i in range(1, 2 * num_dims, 2):
      transpose.append(i)
    res = tf.transpose(res, transpose)
    return tf.reshape(res, shape)
  else:
    return tf.reshape(res, shape)
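# `fw` is assumed to be the quantization forward function supplied elsewhere
# in this codebase. As a HYPOTHETICAL stand-in for experimentation, a
# symmetric uniform 8-bit fake-quantizer could look like this (it rounds to
# discrete levels but returns floats, so it drops into the float graph above):
import tensorflow as tf

def fw(x, num_bits=8):
  max_abs = tf.reduce_max(tf.abs(x)) + 1e-8    # avoid division by zero
  scale = max_abs / (2 ** (num_bits - 1) - 1)  # step size of the uniform grid
  return tf.round(x / scale) * scale           # fake-quantize, keep float dtype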
def testLazyShapeOverflow(self):
  large_shape = [10] * 20
  tensor = initializers.random_matrix_batch([large_shape, large_shape],
                                            batch_size=5, dtype=self.dtype)
  self.assertAllEqual([5, 10**20, 10**20], shapes.lazy_shape(tensor))
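# Context for the test above: 10**20 does not fit in a signed 64-bit integer,
# so computing the flattened matrix dimensions with int64 arithmetic (e.g.
# np.prod) would overflow. The test presumably checks that lazy_shape keeps
# the computation in unbounded Python ints. A quick sanity check:
assert 10**20 > 2**63 - 1  # exceeds the int64 range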