def _testErrorWithShapesEager(self, exception_regex, superdiag_shape,
                              maindiag_shape, subdiag_shape, rhs_shape):
  """Asserts that `tridiagonal_mat_mul` rejects the given operand shapes.

  Inside eager mode, builds all-ones tensors with the requested shapes and
  checks that calling `linalg_ops.tridiagonal_mat_mul` on them raises
  `errors_impl.InvalidArgumentError` whose message matches `exception_regex`.

  Args:
    exception_regex: regex the raised error message must match.
    superdiag_shape: shape of the superdiagonal operand.
    maindiag_shape: shape of the main-diagonal operand.
    subdiag_shape: shape of the subdiagonal operand.
    rhs_shape: shape of the right-hand-side operand.
  """
  operand_shapes = (superdiag_shape, maindiag_shape, subdiag_shape, rhs_shape)
  with context.eager_mode():
    # Operands are created in the op's positional order:
    # superdiag, maindiag, subdiag, rhs.
    operands = [array_ops.ones(shape) for shape in operand_shapes]
    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                exception_regex):
      linalg_ops.tridiagonal_mat_mul(*operands)
Esempio n. 2
0
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
  r"""Multiplies tridiagonal matrix by matrix.

  `diagonals` is a representation of a tridiagonal MxM matrix; its layout
  depends on `diagonals_format`.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored. The inner size M
  must be statically known (and non-zero) for this format.

  In `sequence` format, `diagonals` is a list or tuple of exactly three
  tensors: `[superdiag, maindiag, subdiag]`, each having shape `[..., M]`. The
  last element of `superdiag` and the first element of `subdiag` are ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, where the second-to-last dimension indexes the
  superdiagonal, the main diagonal, and the subdiagonal, in that order.
  Similarly to `sequence` format, elements `diagonals[..., 0, M-1]` and
  `diagonals[..., 2, 0]` are ignored.

  The `sequence` format is recommended as the one with the best performance.

  `rhs` is the matrix to the right of the multiplication. It has shape
  `[..., M, N]`.

  Example:

  ```python
  superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
  maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
  subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
  diagonals = [superdiag, maindiag, subdiag]
  rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
  x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
  ```

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    name:  A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M, N] containing the result of multiplication.

  Raises:
    ValueError: if `diagonals_format` is not recognized; if, in `sequence`
      format, `diagonals` does not unpack into exactly three elements; or if,
      in `matrix` format, the two inner dimensions are unknown, zero, or
      unequal.
  """
  if diagonals_format == 'compact':
    superdiag = diagonals[..., 0, :]
    maindiag = diagonals[..., 1, :]
    subdiag = diagonals[..., 2, :]
  elif diagonals_format == 'sequence':
    try:
      superdiag, maindiag, subdiag = diagonals
    except ValueError:
      # A bare "too many / not enough values to unpack" hides the contract;
      # report it in terms of the expected [superdiag, maindiag, subdiag].
      raise ValueError(
          'Expected diagonals to contain exactly three elements '
          '[superdiag, maindiag, subdiag] for diagonals_format="sequence"')
  elif diagonals_format == 'matrix':
    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
    if not m1 or not m2:
      raise ValueError('The size of the matrix needs to be known for '
                       'diagonals_format="matrix"')
    if m1 != m2:
      raise ValueError(
          'Expected last two dimensions of diagonals to be same, got {} and {}'
          .format(m1, m2))

    # TODO(b/131695260): use matrix_diag_part when it supports extracting
    # arbitrary diagonals.
    maindiag = array_ops.matrix_diag_part(diagonals)
    # `transpose` without `perm` reverses ALL dimensions, so the two matrix
    # dimensions move to the front (swapped) and batch dimensions to the back:
    # transposed[c, r, ...] holds original[..., r, c]. This lets `gather_nd`
    # pick single matrix entries across every batch at once.
    diagonals = array_ops.transpose(diagonals)
    # The kernel ignores the last superdiagonal and the first subdiagonal
    # element, so any in-range index serves as a placeholder there.
    dummy_index = [0, 0]
    # [i + 1, i] in transposed coordinates is original[..., i, i + 1].
    superdiag_indices = [[i + 1, i] for i in range(0, m1 - 1)] + [dummy_index]
    # [i - 1, i] in transposed coordinates is original[..., i, i - 1].
    subdiag_indices = [dummy_index] + [[i - 1, i] for i in range(1, m1)]
    # Gathered results are [M, <reversed batch dims>]; reversing all
    # dimensions again restores [..., M].
    superdiag = array_ops.transpose(
        array_ops.gather_nd(diagonals, superdiag_indices))
    subdiag = array_ops.transpose(
        array_ops.gather_nd(diagonals, subdiag_indices))
  else:
    raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format)

  # C++ backend requires matrices.
  # Converting 1-dimensional vectors to matrices with 1 row.
  superdiag = array_ops.expand_dims(superdiag, -2)
  maindiag = array_ops.expand_dims(maindiag, -2)
  subdiag = array_ops.expand_dims(subdiag, -2)

  return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs, name)
Esempio n. 3
0
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
  r"""Multiplies tridiagonal matrix by matrix.

  `diagonals` is a representation of a tridiagonal MxM matrix; its layout
  depends on `diagonals_format`.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored. The inner size M
  must be statically known (and non-zero) for this format.

  In `sequence` format, `diagonals` is a list or tuple of exactly three
  tensors: `[superdiag, maindiag, subdiag]`, each having shape `[..., M]`. The
  last element of `superdiag` and the first element of `subdiag` are ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, where the second-to-last dimension indexes the
  superdiagonal, the main diagonal, and the subdiagonal, in that order.
  Similarly to `sequence` format, elements `diagonals[..., 0, M-1]` and
  `diagonals[..., 2, 0]` are ignored.

  The `sequence` format is recommended as the one with the best performance.

  `rhs` is the matrix to the right of the multiplication. It has shape
  `[..., M, N]`.

  Example:

  ```python
  superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
  maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
  subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
  diagonals = [superdiag, maindiag, subdiag]
  rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
  x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
  ```

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    name:  A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M, N] containing the result of multiplication.

  Raises:
    ValueError: if `diagonals_format` is not recognized; if, in `sequence`
      format, `diagonals` does not unpack into exactly three elements; or if,
      in `matrix` format, the two inner dimensions are unknown, zero, or
      unequal.
  """
  if diagonals_format == 'compact':
    superdiag = diagonals[..., 0, :]
    maindiag = diagonals[..., 1, :]
    subdiag = diagonals[..., 2, :]
  elif diagonals_format == 'sequence':
    try:
      superdiag, maindiag, subdiag = diagonals
    except ValueError:
      # A bare "too many / not enough values to unpack" hides the contract;
      # report it in terms of the expected [superdiag, maindiag, subdiag].
      raise ValueError(
          'Expected diagonals to contain exactly three elements '
          '[superdiag, maindiag, subdiag] for diagonals_format="sequence"')
  elif diagonals_format == 'matrix':
    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
    if not m1 or not m2:
      raise ValueError('The size of the matrix needs to be known for '
                       'diagonals_format="matrix"')
    if m1 != m2:
      raise ValueError(
          'Expected last two dimensions of diagonals to be same, got {} and {}'
          .format(m1, m2))

    # TODO(b/131695260): use matrix_diag_part when it supports extracting
    # arbitrary diagonals.
    maindiag = array_ops.matrix_diag_part(diagonals)
    # `transpose` without `perm` reverses ALL dimensions, so the two matrix
    # dimensions move to the front (swapped) and batch dimensions to the back:
    # transposed[c, r, ...] holds original[..., r, c]. This lets `gather_nd`
    # pick single matrix entries across every batch at once.
    diagonals = array_ops.transpose(diagonals)
    # The kernel ignores the last superdiagonal and the first subdiagonal
    # element, so any in-range index serves as a placeholder there.
    dummy_index = [0, 0]
    # [i + 1, i] in transposed coordinates is original[..., i, i + 1].
    superdiag_indices = [[i + 1, i] for i in range(0, m1 - 1)] + [dummy_index]
    # [i - 1, i] in transposed coordinates is original[..., i, i - 1].
    subdiag_indices = [dummy_index] + [[i - 1, i] for i in range(1, m1)]
    # Gathered results are [M, <reversed batch dims>]; reversing all
    # dimensions again restores [..., M].
    superdiag = array_ops.transpose(
        array_ops.gather_nd(diagonals, superdiag_indices))
    subdiag = array_ops.transpose(
        array_ops.gather_nd(diagonals, subdiag_indices))
  else:
    raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format)

  # C++ backend requires matrices.
  # Converting 1-dimensional vectors to matrices with 1 row.
  superdiag = array_ops.expand_dims(superdiag, -2)
  maindiag = array_ops.expand_dims(maindiag, -2)
  subdiag = array_ops.expand_dims(subdiag, -2)

  return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs, name)