def matrix_triangular_solve_with_broadcast(matrix,
                                           rhs,
                                           lower=True,
                                           adjoint=False,
                                           name=None):
  """Solves triangular systems of linear equations by back substitution.

  Works identically to `tf.linalg.triangular_solve`, but broadcasts batch dims
  of `matrix` and `rhs` (by replicating) if they are determined statically to
  be different, or if static shapes are not fully defined. Thus, this may
  result in an inefficient replication of data.

  Args:
    matrix: A `Tensor`. Must be one of the following types: `float64`,
      `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`.
    rhs: A `Tensor`. Must have the same `dtype` as `matrix`. Shape is
      `[..., M, K]`.
    lower: An optional `bool`. Defaults to `True`. Indicates whether the
      innermost matrices in `matrix` are lower or upper triangular.
    adjoint: An optional `bool`. Defaults to `False`. Indicates whether to
      solve with `matrix` or its (block-wise) adjoint.
    name: A name for the operation (optional).

  Returns:
    `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`.
  """
  with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype)

    # If either matrix/rhs has extra dims, we can reshape to get rid of them.
    matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # lower indicates whether the matrix is lower triangular. If we have
    # manually taken adjoint inside _reshape_for_efficiency, it is now upper
    # triangular.
    if not still_need_to_transpose and adjoint:
      lower = not lower

    # This will broadcast by brute force if we still need to.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    solution = linalg_ops.matrix_triangular_solve(
        matrix,
        rhs,
        lower=lower,
        adjoint=adjoint and still_need_to_transpose)

    return reshape_inv(solution)
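# A minimal usage sketch (not part of the library): illustrates the batch
# broadcasting this helper performs before solving. It assumes TensorFlow is
# importable as `tf` and that this module is importable; the shapes are
# illustrative only.
#
#   import tensorflow as tf
#
#   # matrix batch shape [2, 1] and rhs batch shape [1, 3] broadcast to [2, 3].
#   chol = tf.linalg.cholesky(
#       2. * tf.eye(4, batch_shape=[2, 1]))                  # [2, 1, 4, 4]
#   rhs = tf.random.normal([1, 3, 4, 2])                     # [1, 3, 4, 2]
#   x = matrix_triangular_solve_with_broadcast(chol, rhs)    # [2, 3, 4, 2]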
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  # If requested, take the (block-wise) adjoint of `rhs` before solving.
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  # The operator is stored via its lower-triangular factor, so a single
  # triangular solve (optionally against the adjoint of the factor) suffices.
  return linalg.triangular_solve(
      self._get_tril(), rhs, lower=True, adjoint=adjoint)
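# A minimal usage sketch, assuming this `_solve` is the private solve hook of
# `tf.linalg.LinearOperatorLowerTriangular`, whose public `solve` method
# dispatches here:
#
#   import tensorflow as tf
#
#   tril = [[2., 0.], [1., 3.]]
#   operator = tf.linalg.LinearOperatorLowerTriangular(tril)
#   rhs = tf.constant([[4.], [7.]])
#   x = operator.solve(rhs)  # Solves tril @ x = rhs by back substitution.
#   # x ==> [[2.], [5. / 3.]]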