def _testErrorWithShapesEager(self, exception_regex, superdiag_shape,
                              maindiag_shape, subdiag_shape, rhs_shape):
  """Checks that the raw op rejects the given operand shapes in eager mode.

  Builds all-ones tensors of the requested shapes and expects
  `linalg_ops.tridiagonal_mat_mul` to raise `InvalidArgumentError` with a
  message matching `exception_regex`.

  Args:
    exception_regex: Regex the error message is expected to match.
    superdiag_shape: Shape of the superdiagonal operand.
    maindiag_shape: Shape of the main-diagonal operand.
    subdiag_shape: Shape of the subdiagonal operand.
    rhs_shape: Shape of the right-hand-side operand.
  """
  with context.eager_mode():
    # Operands are created in the op's positional order:
    # superdiag, maindiag, subdiag, rhs.
    operand_shapes = (superdiag_shape, maindiag_shape, subdiag_shape,
                      rhs_shape)
    operands = [array_ops.ones(shape) for shape in operand_shapes]
    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                exception_regex):
      linalg_ops.tridiagonal_mat_mul(*operands)
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
  r"""Multiplies tridiagonal matrix by matrix.

  `diagonals` is a representation of a 3-diagonal NxN matrix, which depends on
  `diagonals_format`.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` is a list or tuple of three tensors:
  `[superdiag, maindiag, subdiag]`, each having shape `[..., M]`. The last
  element of `superdiag` and the first element of `subdiag` are ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `sequence` format is recommended as the one with the best performance.

  `rhs` is matrix to the right of multiplication. It has shape `[..., M, N]`.

  Example:

  ```python
  superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
  maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
  subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
  diagonals = [superdiag, maindiag, subdiag]
  rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
  x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
  ```

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    name: A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M, N] containing the result of multiplication.

  Raises:
    ValueError: If `diagonals_format` is not one of the supported formats; or,
      in `matrix` format, if the size of the last two dimensions is not
      statically available or the two dimensions differ. Unsupported input
      types or incorrect tensor shapes may additionally be reported by the
      underlying `linalg_ops.tridiagonal_mat_mul` op.
  """
  if diagonals_format == 'compact':
    superdiag = diagonals[..., 0, :]
    maindiag = diagonals[..., 1, :]
    subdiag = diagonals[..., 2, :]
  elif diagonals_format == 'sequence':
    superdiag, maindiag, subdiag = diagonals
  elif diagonals_format == 'matrix':
    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
    # The index lists below are built in Python, so the size must be a
    # concrete, truthy int here (note that `not m1` also rejects a size of 0).
    if not m1 or not m2:
      raise ValueError('The size of the matrix needs to be known for '
                       'diagonals_format="matrix"')
    if m1 != m2:
      raise ValueError(
          'Expected last two dimensions of diagonals to be same, got {} and {}'
          .format(m1, m2))

    # TODO(b/131695260): use matrix_diag_part when it supports extracting
    # arbitrary diagonals.
    maindiag = array_ops.matrix_diag_part(diagonals)
    # Transpose so that the two matrix axes lead and can be addressed by
    # 2-element `gather_nd` indices; the gathered result is transposed back.
    # NOTE(review): relies on `transpose` without `perm` reversing all axes.
    diagonals = array_ops.transpose(diagonals)
    # Placeholder for the one slot of each off-diagonal that is ignored
    # (last element of superdiag, first element of subdiag).
    dummy_index = [0, 0]
    superdiag_indices = [[i + 1, i] for i in range(0, m1 - 1)] + [dummy_index]
    subdiag_indices = [dummy_index] + [[i - 1, i] for i in range(1, m1)]
    superdiag = array_ops.transpose(
        array_ops.gather_nd(diagonals, superdiag_indices))
    subdiag = array_ops.transpose(
        array_ops.gather_nd(diagonals, subdiag_indices))
  else:
    # `.format` rather than `%`: a tuple-valued `diagonals_format` would be
    # unpacked as the `%` argument list and raise TypeError instead of the
    # documented ValueError.
    raise ValueError(
        'Unrecognized diagonals_format: {}'.format(diagonals_format))

  # C++ backend requires matrices.
  # Converting 1-dimensional vectors to matrices with 1 row.
  superdiag = array_ops.expand_dims(superdiag, -2)
  maindiag = array_ops.expand_dims(maindiag, -2)
  subdiag = array_ops.expand_dims(subdiag, -2)

  return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs, name)