예제 #1
0
def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                      conjugate_rhs, partial_pivoting, name):
    """Helper function used after the input has been cast to compact form.

    Performs whatever static shape validation is possible, prepares `rhs`
    (conjugation, transposition, vector handling) and calls
    `linalg_ops.tridiagonal_solve`.

    Args:
      diagonals: Tensor in compact form. If its static rank is known it must
        be at least 2, and its dimension -2 must be 3 when known.
      rhs: Right-hand side. If both static ranks are known, its rank must
        equal the rank of `diagonals` (matrix case) or be one less (vector
        case).
      transpose_rhs: Whether to transpose a matrix `rhs` before solving.
        Ignored for a vector `rhs`.
      conjugate_rhs: Whether to complex-conjugate `rhs` before solving.
      partial_pivoting: Forwarded to `linalg_ops.tridiagonal_solve`.
      name: Op name forwarded to `linalg_ops.tridiagonal_solve`.

    Returns:
      The result of `linalg_ops.tridiagonal_solve`. For a vector `rhs` the
      trailing axis added with `expand_dims` is squeezed away again. For a
      transposed matrix `rhs` the result is transposed back only while
      forward compatibility with 2019-10-18 is not yet enabled.

    Raises:
      ValueError: If statically known ranks or dimensions are inconsistent.
    """
    diags_rank, rhs_rank = diagonals.shape.rank, rhs.shape.rank

    # Static checks only run for known ranks/dimensions. Unknown values are
    # None, so compare against None explicitly: a known rank or dimension of
    # 0 is falsy but must still be validated.
    if diags_rank is not None:
        if diags_rank < 2:
            raise ValueError(
                'Expected diagonals to have rank at least 2, got {}'.format(
                    diags_rank))
        if (rhs_rank is not None and rhs_rank != diags_rank
                and rhs_rank != diags_rank - 1):
            raise ValueError(
                'Expected the rank of rhs to be {} or {}, got {}'.format(
                    diags_rank - 1, diags_rank, rhs_rank))
        if (rhs_rank is not None
                and not diagonals.shape[:-2].is_compatible_with(
                    rhs.shape[:diags_rank - 2])):
            raise ValueError('Batch shapes {} and {} are incompatible'.format(
                diagonals.shape[:-2], rhs.shape[:diags_rank - 2]))

    if diagonals.shape[-2] is not None and diagonals.shape[-2] != 3:
        raise ValueError('Expected 3 diagonals got {}'.format(
            diagonals.shape[-2]))

    def check_num_lhs_matches_num_rhs():
        # Reads `rhs` at call time, i.e. after it has been reshaped/transposed
        # by the enclosing function.
        if (diagonals.shape[-1] is not None and rhs.shape[-2] is not None
                and diagonals.shape[-1] != rhs.shape[-2]):
            raise ValueError(
                'Expected number of left-hand sided and right-hand '
                'sides to be equal, got {} and {}'.format(
                    diagonals.shape[-1], rhs.shape[-2]))

    if (rhs_rank is not None and diags_rank is not None
            and rhs_rank == diags_rank - 1):
        # Rhs provided as a vector, ignoring transpose_rhs.
        if conjugate_rhs:
            rhs = math_ops.conj(rhs)
        # Solve as a single-column matrix, then drop the added column axis.
        rhs = array_ops.expand_dims(rhs, -1)
        check_num_lhs_matches_num_rhs()
        return array_ops.squeeze(
            linalg_ops.tridiagonal_solve(
                diagonals, rhs, partial_pivoting=partial_pivoting, name=name),
            -1)

    if transpose_rhs:
        # Conjugation, if requested, is folded into the transpose.
        rhs = array_ops.matrix_transpose(rhs, conjugate=conjugate_rhs)
    elif conjugate_rhs:
        rhs = math_ops.conj(rhs)

    check_num_lhs_matches_num_rhs()
    result = linalg_ops.tridiagonal_solve(
        diagonals, rhs, partial_pivoting=partial_pivoting, name=name)
    # The result is transposed back only before the forward-compatibility
    # window of 2019-10-18.
    if transpose_rhs and not compat.forward_compatible(2019, 10, 18):
        return array_ops.matrix_transpose(result)
    return result
예제 #2
0
def _tridiagonal_solve_compact_format(diagonals,
                                      rhs,
                                      transpose_rhs=False,
                                      conjugate_rhs=False,
                                      name=None):
    """Helper function used after the input has been cast to compact form."""
    diags_rank, rhs_rank = len(diagonals.shape), len(rhs.shape)

    if diags_rank < 2:
        raise ValueError(
            'Expected diagonals to have rank at least 2, got {}'.format(
                diags_rank))
    if rhs_rank != diags_rank and rhs_rank != diags_rank - 1:
        raise ValueError(
            'Expected the rank of rhs to be {} or {}, got {}'.format(
                diags_rank - 1, diags_rank, rhs_rank))
    if diagonals.shape[-2] != 3:
        raise ValueError('Expected 3 diagonals got {}'.format(
            diagonals.shape[-2]))
    if not diagonals.shape[:-2].is_compatible_with(rhs.shape[:diags_rank - 2]):
        raise ValueError('Batch shapes {} and {} are incompatible'.format(
            diagonals.shape[:-2], rhs.shape[:diags_rank - 2]))

    def check_num_lhs_matches_num_rhs():
        if diagonals.shape[-1] != rhs.shape[-2]:
            raise ValueError(
                'Expected number of left-hand sided and right-hand '
                'sides to be equal, got {} and {}'.format(
                    diagonals.shape[-1], rhs.shape[-2]))

    if rhs_rank == diags_rank - 1:
        # Rhs provided as a vector, ignoring transpose_rhs
        if conjugate_rhs:
            rhs = math_ops.conj(rhs)
        rhs = array_ops.expand_dims(rhs, -1)
        check_num_lhs_matches_num_rhs()
        return array_ops.squeeze(
            linalg_ops.tridiagonal_solve(diagonals, rhs, name), -1)

    if transpose_rhs:
        rhs = array_ops.matrix_transpose(rhs, conjugate=conjugate_rhs)
    elif conjugate_rhs:
        rhs = math_ops.conj(rhs)

    check_num_lhs_matches_num_rhs()
    result = linalg_ops.tridiagonal_solve(diagonals, rhs, name)
    return array_ops.matrix_transpose(result) if transpose_rhs else result
예제 #3
0
def _tridiagonal_solve_compact_format(diagonals,
                                      rhs,
                                      transpose_rhs=False,
                                      conjugate_rhs=False,
                                      name=None):
  """Helper function used after the input has been cast to compact form."""
  diags_rank, rhs_rank = len(diagonals.shape), len(rhs.shape)

  if diags_rank < 2:
    raise ValueError(
        'Expected diagonals to have rank at least 2, got {}'.format(diags_rank))
  if rhs_rank != diags_rank and rhs_rank != diags_rank - 1:
    raise ValueError('Expected the rank of rhs to be {} or {}, got {}'.format(
        diags_rank - 1, diags_rank, rhs_rank))
  if diagonals.shape[-2] != 3:
    raise ValueError('Expected 3 diagonals got {}'.format(diagonals.shape[-2]))
  if not diagonals.shape[:-2].is_compatible_with(rhs.shape[:diags_rank - 2]):
    raise ValueError('Batch shapes {} and {} are incompatible'.format(
        diagonals.shape[:-2], rhs.shape[:diags_rank - 2]))

  def check_num_lhs_matches_num_rhs():
    if diagonals.shape[-1] != rhs.shape[-2]:
      raise ValueError('Expected number of left-hand sided and right-hand '
                       'sides to be equal, got {} and {}'.format(
                           diagonals.shape[-1], rhs.shape[-2]))

  if rhs_rank == diags_rank - 1:
    # Rhs provided as a vector, ignoring transpose_rhs
    if conjugate_rhs:
      rhs = math_ops.conj(rhs)
    rhs = array_ops.expand_dims(rhs, -1)
    check_num_lhs_matches_num_rhs()
    return array_ops.squeeze(
        linalg_ops.tridiagonal_solve(diagonals, rhs, name), -1)

  if transpose_rhs:
    rhs = array_ops.matrix_transpose(rhs, conjugate=conjugate_rhs)
  elif conjugate_rhs:
    rhs = math_ops.conj(rhs)

  check_num_lhs_matches_num_rhs()
  result = linalg_ops.tridiagonal_solve(diagonals, rhs, name)
  return array_ops.matrix_transpose(result) if transpose_rhs else result
예제 #4
0
def _TridiagonalSolveGrad(op, grad):
  """Gradient for the TridiagonalSolve op.

  The rhs gradient is obtained by solving against the transposed tridiagonal
  matrix. The diagonals gradient is the negated result of
  `_MatmulExtractingThreeDiagonals(rhs_grad, solution)`.

  Args:
    op: The forward op; its first input holds the diagonals and its first
      output the solution.
    grad: Upstream gradient with respect to the forward op's output.

  Returns:
    A pair (gradient w.r.t. the diagonals, gradient w.r.t. the rhs).
  """
  # The transposed system is assembled here in Python rather than inside the
  # tridiagonal_solve kernel. Swapping super- and subdiagonal in the kernel
  # would not match the paddings that the cusparse*gtsv routines require on
  # GPU.
  transposed_diagonals = _TransposeTridiagonalMatrix(op.inputs[0])
  rhs_grad = linalg_ops.tridiagonal_solve(transposed_diagonals, grad)

  # NOTE(review): no explicit conjugation is applied in this function; confirm
  # the intended gradient convention for complex dtypes.
  solution = op.outputs[0]
  diags_product = _MatmulExtractingThreeDiagonals(rhs_grad, solution)
  return -diags_product, rhs_grad
예제 #5
0
def _TridiagonalSolveGrad(op, grad):
    """Gradient for the TridiagonalSolve op.

    The rhs gradient is obtained by solving against the transposed
    tridiagonal matrix, reusing the forward op's pivoting setting. The
    diagonals gradient is the negated result of
    `_MatmulExtractingThreeDiagonals(rhs_grad, solution)`.

    Args:
      op: The forward op. Its first input holds the diagonals, its first
        output the solution, and its "partial_pivoting" attribute is reused
        for the backward solve.
      grad: Upstream gradient with respect to the forward op's output.

    Returns:
      A pair (gradient w.r.t. the diagonals, gradient w.r.t. the rhs).
    """
    pivoting = op.get_attr("partial_pivoting")

    # The transposed system is assembled here in Python rather than inside
    # the tridiagonal_solve kernel. Swapping super- and subdiagonal in the
    # kernel would not match the paddings that the cusparse*gtsv routines
    # require on GPU.
    transposed_diagonals = _TransposeTridiagonalMatrix(op.inputs[0])
    rhs_grad = linalg_ops.tridiagonal_solve(
        transposed_diagonals, grad, partial_pivoting=pivoting)

    # NOTE(review): no explicit conjugation is applied in this function;
    # confirm the intended gradient convention for complex dtypes.
    diags_product = _MatmulExtractingThreeDiagonals(rhs_grad, op.outputs[0])
    return -diags_product, rhs_grad  # pylint: disable=invalid-unary-operand-type