示例#1
0
def matrix_triangular_solve_with_broadcast(matrix,
                                           rhs,
                                           lower=True,
                                           adjoint=False,
                                           name=None):
    """Solve triangular linear systems, broadcasting batch dimensions.

    Behaves like `tf.linalg.triangular_solve`. The difference is that the
    batch dimensions of `matrix` and `rhs` are broadcast against each other
    (by replicating data) when they are statically known to differ, or when
    the static shapes are not fully defined. That replication may be an
    inefficient use of memory.

    Args:
      matrix: A `Tensor` of shape `[..., M, M]` with dtype `float64`,
        `float32`, `complex64` or `complex128`.
      rhs: A `Tensor` of shape `[..., M, K]`. Must have the same `dtype` as
        `matrix` (it is converted using `matrix`'s dtype).
      lower: Optional `bool`, default `True`. Whether the innermost matrices
        of `matrix` are lower (`True`) or upper (`False`) triangular.
      adjoint: Optional `bool`, default `False`. Whether to solve with
        `matrix` or with its (block-wise) adjoint.
      name: Optional name for the operation. Ops are created under a name
        scope that defaults to "MatrixTriangularSolve".

    Returns:
      A `Tensor` with the same `dtype` as `matrix` and shape `[..., M, K]`.
    """
    with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]):
        lhs_t = ops.convert_to_tensor(matrix, name="matrix")
        rhs_t = ops.convert_to_tensor(rhs, name="rhs", dtype=lhs_t.dtype)

        # Reshape away batch dims that only one operand has. The helper may
        # apply the adjoint to the matrix itself. `needs_transpose` reports
        # whether the adjoint is still left for the solver to apply.
        lhs_t, rhs_t, undo_reshape, needs_transpose = _reshape_for_efficiency(
            lhs_t, rhs_t, adjoint_a=adjoint)

        # An adjoint already applied above swaps lower/upper triangularity,
        # so the `lower` flag handed to the solver must be flipped.
        adjoint_done_by_reshape = (not needs_transpose) and adjoint
        is_lower = (not lower) if adjoint_done_by_reshape else lower

        # Brute-force broadcast of any batch dims that still disagree.
        lhs_t, rhs_t = broadcast_matrix_batch_dims([lhs_t, rhs_t])

        # Ask the solver for the adjoint only if it was not applied yet.
        solver_adjoint = adjoint and needs_transpose
        result = linalg_ops.triangular_solve(
            lhs_t, rhs_t, lower=is_lower, adjoint=solver_adjoint)

        return undo_reshape(result)
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Solve against this operator's triangular matrix via `linalg.triangular_solve`.

   Args:
     rhs: Right-hand side of the system.
     adjoint: Forwarded unchanged as the `adjoint` flag of
       `linalg.triangular_solve`.
     adjoint_arg: If truthy, `rhs` is replaced by `linalg.adjoint(rhs)`
       before solving.

   Returns:
     The result of `linalg.triangular_solve`. The matrix is
     `self._get_tril()`, treated as lower triangular (`lower=True`).
   """
   if adjoint_arg:
     rhs = linalg.adjoint(rhs)
   tril = self._get_tril()
   return linalg.triangular_solve(tril, rhs, lower=True, adjoint=adjoint)