コード例 #1
0
    def test_non_blockwise_input(self):
        """Checks that `arg_is_blockwise` rejects non-blockwise inputs.

        The same zero array is supplied as a NumPy array, a `Tensor`, a
        placeholder with unknown shape and a nested Python list. None of them
        should be interpreted as a list of blocks, whether the operator
        dimensions are known or unknown.
        """
        x = np.zeros((2, 3, 4, 6))
        x_tensor = ops.convert_to_tensor(x)
        x_placeholder = array_ops.placeholder_with_default(x, shape=None)
        x_list = x.tolist()
        candidate_inputs = [x, x_tensor, x_placeholder, x_list]

        # For known and matching operator dimensions, interpret all as
        # non-blockwise.
        op_dimension_values = [2, 1, 3]
        op_dimensions = [
            tensor_shape.Dimension(d) for d in op_dimension_values
        ]
        for inputs in candidate_inputs:
            self.assertFalse(
                linear_operator_util.arg_is_blockwise(op_dimensions, inputs,
                                                      -1))

        # The input is still interpreted as non-blockwise for unknown operator
        # dimensions (`x_list` has an outermost dimension that does not match
        # the number of blocks, and the other inputs are not iterables).
        unknown_op_dimensions = [
            tensor_shape.Dimension(None) for _ in op_dimension_values
        ]
        for inputs in candidate_inputs:
            self.assertFalse(
                linear_operator_util.arg_is_blockwise(unknown_op_dimensions,
                                                      inputs, -1))
コード例 #2
0
    def test_ambiguous_input_raises(self):
        """`arg_is_blockwise` raises when the block structure is ambiguous."""
        x = np.zeros((3, 4, 2)).tolist()
        op_dimensions = [tensor_shape.Dimension(None) for _ in range(3)]

        # Since the leftmost dimension of `x` is equal to the number of blocks,
        # and the operators have unknown dimension, the input is ambiguous.
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which no longer exists in `unittest` as of Python 3.12.
        with self.assertRaisesRegex(ValueError, "structure is ambiguous"):
            linear_operator_util.arg_is_blockwise(op_dimensions, x, -2)
コード例 #3
0
    def test_mismatched_input_raises(self):
        """`arg_is_blockwise` raises when neither interpretation fits."""
        x = np.zeros((2, 3, 4, 6)).tolist()
        op_dimension_values = [4, 3]
        op_dimensions = [
            tensor_shape.Dimension(v) for v in op_dimension_values
        ]

        # The dimensions of the two operator-blocks sum to 7. `x` is a
        # two-element list; if interpreted blockwise, its corresponding
        # dimensions sum to 12 (=6*2). If not interpreted blockwise, its
        # corresponding dimension is 6. This is a mismatch.
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which no longer exists in `unittest` as of Python 3.12.
        with self.assertRaisesRegex(ValueError, "dimension does not match"):
            linear_operator_util.arg_is_blockwise(op_dimensions, x, -1)
コード例 #4
0
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        """Multiplies `x` by this block-diagonal operator, one block at a time.

        Args:
          x: Either a blockwise iterable (one entry per operator block) or a
            single `Tensor` that is split into blocks along `arg_dim`.
          adjoint: Python `bool`. If `True`, multiply by the adjoint `A^H`.
          adjoint_arg: Python `bool`. If `True`, use the adjoint of `x`; the
            blocks of `x` are then laid out along its last axis instead of the
            second-to-last one.

        Returns:
          A list of per-block results if `x` was blockwise. Otherwise a single
          `Tensor` obtained by broadcasting the batch dimensions of the
          per-block results and concatenating them along axis `-2`.
        """
        # Axis of `x` that is partitioned into blocks: rows normally, columns
        # when `x` enters through its adjoint.
        arg_dim = -1 if adjoint_arg else -2
        # `A x` consumes `x` through the blocks' domain dimensions; `A^H x`
        # consumes it through their range dimensions.
        if adjoint:
            block_dimensions = self._block_range_dimensions()
            block_dimension_tensors = self._block_range_dimension_tensors
        else:
            block_dimensions = self._block_domain_dimensions()
            block_dimension_tensors = self._block_domain_dimension_tensors
        blockwise_arg = linear_operator_util.arg_is_blockwise(
            block_dimensions, x, arg_dim)
        if blockwise_arg:
            split_x = x
        else:
            # Split with the same block sizes used for detection above. The
            # previous code always used domain sizes, which is wrong for
            # `adjoint=True` whenever a block is non-square.
            split_x = linear_operator_util.split_arg_into_blocks(
                block_dimensions,
                block_dimension_tensors,
                x,
                axis=arg_dim)

        # Index (rather than zip) so a short `split_x` still raises instead of
        # silently dropping operators.
        result_list = [
            operator.matmul(split_x[index],
                            adjoint=adjoint,
                            adjoint_arg=adjoint_arg)
            for index, operator in enumerate(self.operators)
        ]

        if blockwise_arg:
            return result_list

        result_list = linear_operator_util.broadcast_matrix_batch_dims(
            result_list)
        return array_ops.concat(result_list, axis=-2)
コード例 #5
0
    def matvec(self, x, adjoint=False, name="matvec"):
        """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ... # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
    ```

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`, or an
        iterable of `Tensor`s. `Tensor`s are treated as [batch] vectors, meaning
        for every set of leading dimensions, the last dimension defines a
        vector.
        See class docstring for definition of compatibility.
        A blockwise `x` is never modified; its converted blocks are collected
        into a new list.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      name:  A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`, or, if `x`
        is blockwise, a list with one result `Tensor` per block.
    """
        with self._name_scope(name):  # pylint: disable=not-callable
            block_dimensions = (self._block_range_dimensions() if adjoint else
                                self._block_domain_dimensions())
            if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
                # Collect converted blocks into a new list. Assigning into `x`
                # would mutate the caller's list, and would raise `TypeError`
                # if `x` were a tuple.
                blocks = []
                for i, block in enumerate(x):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor_v2_with_dispatch(block)
                        self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            block.shape[-1])
                    blocks.append(block)
                # Promote each vector block to a one-column matrix, reuse
                # `matmul`, then drop the column axis again.
                x_mat = [block[..., array_ops.newaxis] for block in blocks]
                y_mat = self.matmul(x_mat, adjoint=adjoint)
                return [array_ops.squeeze(y, axis=-1) for y in y_mat]

            x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
            self._check_input_dtype(x)
            op_dimension = (self.range_dimension
                            if adjoint else self.domain_dimension)
            op_dimension.assert_is_compatible_with(x.shape[-1])
            x_mat = x[..., array_ops.newaxis]
            y_mat = self.matmul(x_mat, adjoint=adjoint)
            return array_ops.squeeze(y_mat, axis=-1)
コード例 #6
0
    def test_blockwise_input(self, op_dimension_values, split_dim):
        """Iterables of per-block inputs are always detected as blockwise.

        One block is built per entry `d` of `op_dimension_values`. Its shape is
        `[2, 1, 5, d]` when `split_dim == -1`, and `[2, 1, d, 5]` otherwise.
        """
        known_dims = [tensor_shape.Dimension(v) for v in op_dimension_values]
        unknown_dims = [
            tensor_shape.Dimension(None) for _ in op_dimension_values
        ]

        leading_shape = [2, 1]
        shared_size = 5

        def shape_for(block_size):
            # The per-block size sits on the split axis; the other matrix
            # dimension is common to every block.
            if split_dim == -1:
                return leading_shape + [shared_size, block_size]
            return leading_shape + [block_size, shared_size]

        arrays = [np.zeros(shape_for(d)) for d in op_dimension_values]
        nested_lists = [arr.tolist() for arr in arrays]
        tensors = [ops.convert_to_tensor(arr) for arr in arrays]
        placeholders = [
            array_ops.placeholder_with_default(arr, shape=None)
            for arr in arrays
        ]

        # Iterables of non-nested structures are always interpreted as
        # blockwise. The list of lists is interpreted as blockwise as well,
        # whether or not the operator dimensions are known, because the sizes
        # of its elements along `split_dim` differ.
        for dims in (known_dims, unknown_dims):
            for candidate in (arrays, nested_lists, tensors, placeholders):
                self.assertTrue(
                    linear_operator_util.arg_is_blockwise(
                        dims, candidate, split_dim))
コード例 #7
0
    def solvevec(self, rhs, adjoint=False, name="solve"):
        """Solve single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ... # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
        (for blockwise operators). `Tensor`s are treated as [batch] vectors,
        meaning for every set of leading dimensions, the last dimension defines
        a vector.  See class docstring for definition of compatibility regarding
        batch dimensions. A blockwise `rhs` is never modified; its converted
        blocks are collected into a new list.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N]` and same `dtype` as `rhs`, or, if `rhs` is
        blockwise, a list with one solution `Tensor` per block.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
        with self._name_scope(name):  # pylint: disable=not-callable
            block_dimensions = (self._block_domain_dimensions()
                                if adjoint else self._block_range_dimensions())
            if linear_operator_util.arg_is_blockwise(block_dimensions, rhs,
                                                     -1):
                # Collect converted blocks into a new list. Assigning into
                # `rhs` would mutate the caller's list, and would raise
                # `TypeError` if `rhs` were a tuple.
                blocks = []
                for i, block in enumerate(rhs):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor_v2_with_dispatch(block)
                        self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            block.shape[-1])
                    blocks.append(block)
                # Promote each vector block to a one-column matrix, reuse
                # `solve`, then drop the column axis again.
                rhs_mat = [
                    array_ops.expand_dims(block, axis=-1) for block in blocks
                ]
                solution_mat = self.solve(rhs_mat, adjoint=adjoint)
                return [array_ops.squeeze(x, axis=-1) for x in solution_mat]

            rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
            self._check_input_dtype(rhs)
            op_dimension = (self.domain_dimension
                            if adjoint else self.range_dimension)
            op_dimension.assert_is_compatible_with(rhs.shape[-1])
            rhs_mat = array_ops.expand_dims(rhs, axis=-1)
            solution_mat = self.solve(rhs_mat, adjoint=adjoint)
            return array_ops.squeeze(solution_mat, axis=-1)
コード例 #8
0
    def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
        """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Given the blockwise `n + 1`-by-`n + 1` linear operator:

    op = [[A_00     0  ...     0  ...    0],
          [A_10  A_11  ...     0  ...    0],
          ...
          [A_k0  A_k1  ...  A_kk  ...    0],
          ...
          [A_n0  A_n1  ...  A_nk  ... A_nn]]

    we find `x = op.solve(y)` by observing that

    `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

    and therefore

    `x_k = A_kk.solve(y_k -
                      A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

    where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
    and `y` along their appropriate axes.

    We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
    for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.

    The adjoint case is solved similarly, beginning with
    `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s. `Tensor`s are treated as [batch] matrices,
        meaning for every set of leading dimensions, the last two dimensions
        define a matrix.
        See class docstring for definition of compatibility. A blockwise
        `rhs` is never modified; its converted blocks are collected into a
        new list.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`, or, if `rhs`
        is blockwise, a list with one solution per block. If `rhs` is a
        `LinearOperator`, the result of `linear_operator_algebra.solve`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
      ValueError: If `rhs` is a `LinearOperator` whose range dimension is known
        and differs from the domain dimension of this operator (or of its
        adjoint).
    """
        if self.is_non_singular is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "be singular.")
        if self.is_square is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "not be square.")
        if isinstance(rhs, linear_operator.LinearOperator):
            left_operator = self.adjoint() if adjoint else self
            right_operator = rhs.adjoint() if adjoint_arg else rhs

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `rhs` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):  # pylint: disable=not-callable
                return linear_operator_algebra.solve(left_operator,
                                                     right_operator)

        with self._name_scope(name):  # pylint: disable=not-callable
            block_dimensions = (self._block_domain_dimensions()
                                if adjoint else self._block_range_dimensions())
            arg_dim = -1 if adjoint_arg else -2
            blockwise_arg = linear_operator_util.arg_is_blockwise(
                block_dimensions, rhs, arg_dim)
            if blockwise_arg:
                # Collect converted blocks into a new list. Assigning into
                # `rhs` would mutate the caller's list, and would raise
                # `TypeError` if `rhs` were a tuple.
                blocks = []
                for i, block in enumerate(rhs):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor_v2_with_dispatch(block)
                        self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            block.shape[arg_dim])
                    blocks.append(block)
                if adjoint_arg:
                    split_rhs = [linalg.adjoint(y) for y in blocks]
                else:
                    split_rhs = blocks

            else:
                rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
                self._check_input_dtype(rhs)
                op_dimension = (self.domain_dimension
                                if adjoint else self.range_dimension)
                op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])

                rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
                split_rhs = linear_operator_util.split_arg_into_blocks(
                    self._block_domain_dimensions(),
                    self._block_domain_dimension_tensors,
                    rhs,
                    axis=-2)

            solution_list = []
            if adjoint:
                # For an adjoint blockwise lower-triangular linear operator, the
                # system must be solved bottom to top. Iterate backwards over rows
                # of the adjoint (i.e. columns of the non-adjoint operator).
                for index in reversed(range(len(self.operators))):
                    y = split_rhs[index]
                    # Iterate over the operators in the off-diagonal portion of the
                    # column-partition (i.e. row-partition of the adjoint), apply
                    # the operator to the respective block of the solution found in
                    # previous iterations, and subtract the result from the `rhs`
                    # block. For example, let `A`, `B`, and `D` be the linear
                    # operators in the top row-partition of the adjoint of
                    # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
                    # and `x_1` and `x_2` be blocks of the solution found in
                    # previous iterations of the outer loop. The following loop
                    # (when `index == 0`) expresses
                    # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
                    # `y_0* = y_0 - Bx_1 - Dx_2`.
                    for j in reversed(range(index + 1, len(self.operators))):
                        # Solutions were appended last-block-first, so block `j`
                        # lives at position `n - 1 - j` of `solution_list`.
                        y = y - self.operators[j][index].matmul(
                            solution_list[len(self.operators) - 1 - j],
                            adjoint=adjoint)
                    # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
                    solution_list.append(self._diagonal_operators[index].solve(
                        y, adjoint=adjoint))
                solution_list.reverse()
            else:
                # Iterate top to bottom over the row-partitions.
                for row, y in zip(self.operators, split_rhs):
                    # Iterate left to right over the operators in the off-diagonal
                    # portion of the row-partition, apply the operator to the block
                    # of the solution found in previous iterations, and subtract the
                    # result from the `rhs` block. For example, let `D`, `E`, and
                    # `F` be the linear operators in the bottom row-partition of
                    # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`
                    # and `x_0` and `x_1` be blocks of the solution found in
                    # previous iterations of the outer loop. For that third
                    # row-partition (where `y` is `y_2`), the following loop
                    # expresses `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
                    # `y_2* = y_2 - Dx_0 - Ex_1`.
                    for i, operator in enumerate(row[:-1]):
                        y = y - operator.matmul(solution_list[i],
                                                adjoint=adjoint)
                    # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
                    solution_list.append(row[-1].solve(y, adjoint=adjoint))

            if blockwise_arg:
                return solution_list

            solution_list = linear_operator_util.broadcast_matrix_batch_dims(
                solution_list)
            return array_ops.concat(solution_list, axis=-2)
コード例 #9
0
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        """Multiplies `x` by this block lower-triangular operator.

        Each output block is the sum of the products of the nonzero operators
        in one row-partition (or, for `adjoint=True`, one column-partition)
        with the matching blocks of `x`.

        Args:
          x: Either a blockwise iterable of inputs or a single `Tensor`. A
            single `Tensor` is split using the domain block sizes along axis
            `-1` when `adjoint_arg` is `True`, and along axis `-2` otherwise.
          adjoint: Python `bool`. If `True`, multiply by the adjoint `A^H`.
          adjoint_arg: Python `bool`. Forwarded to every block `matmul`.

        Returns:
          A list of per-block results if `x` was blockwise. Otherwise a single
          `Tensor` obtained by broadcasting the batch dimensions of the
          per-block results and concatenating them along axis `-2`.
        """
        axis = -1 if adjoint_arg else -2
        if adjoint:
            detect_dims = self._block_range_dimensions()
        else:
            detect_dims = self._block_domain_dimensions()
        is_blockwise = linear_operator_util.arg_is_blockwise(
            detect_dims, x, axis)

        if not is_blockwise:
            # Columns of `x` when `adjoint_arg` is True, rows otherwise.
            x_blocks = linear_operator_util.split_arg_into_blocks(
                self._block_domain_dimensions(),
                self._block_domain_dimension_tensors,
                x,
                axis=axis)
        else:
            x_blocks = x

        def apply_block(op, block):
            return op.matmul(block, adjoint=adjoint, adjoint_arg=adjoint_arg)

        outputs = []
        grid = self.operators
        num_blocks = len(grid)
        if adjoint:
            # Output block `col` of `A^H x` gathers column `col` of the
            # operator, from the diagonal entry downwards. For
            # op = [[A, 0, 0], [B, C, 0], [D, E, F]] and `col == 1`, this is
            # `C.matmul(x_1, adjoint=True) + E.matmul(x_2, adjoint=True)`.
            for col in range(num_blocks):
                acc = apply_block(grid[col][col], x_blocks[col])
                for row in range(col + 1, num_blocks):
                    acc += apply_block(grid[row][col], x_blocks[row])
                outputs.append(acc)
        else:
            # Output block `k` of `A x` gathers row-partition `k`, left to
            # right, against the leading blocks of `x`.
            for row_ops in grid:
                acc = apply_block(row_ops[0], x_blocks[0])
                for col in range(1, len(row_ops)):
                    acc += apply_block(row_ops[col], x_blocks[col])
                outputs.append(acc)

        if is_blockwise:
            return outputs

        outputs = linear_operator_util.broadcast_matrix_batch_dims(outputs)
        return array_ops.concat(outputs, axis=-2)
コード例 #10
0
    def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
        """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ... # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```

    Args:
      x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
        `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
        class docstring for definition of shape compatibility. A blockwise `x`
        is never modified; its converted blocks are collected into a new list.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name:  A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
        as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
        concatenate to `[..., M, R]`.

    Raises:
      ValueError: If `x` is a `LinearOperator` whose range dimension is known
        and differs from the domain dimension of this operator (or of its
        adjoint).
    """
        if isinstance(x, linear_operator.LinearOperator):
            left_operator = self.adjoint() if adjoint else self
            right_operator = x.adjoint() if adjoint_arg else x

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `x` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):  # pylint: disable=not-callable
                return linear_operator_algebra.matmul(left_operator,
                                                      right_operator)

        with self._name_scope(name):  # pylint: disable=not-callable
            arg_dim = -1 if adjoint_arg else -2
            block_dimensions = (self._block_range_dimensions() if adjoint else
                                self._block_domain_dimensions())
            if linear_operator_util.arg_is_blockwise(block_dimensions, x,
                                                     arg_dim):
                # Collect converted blocks into a new list. Assigning into `x`
                # would mutate the caller's list, and would raise `TypeError`
                # if `x` were a tuple.
                converted = []
                for i, block in enumerate(x):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor_v2_with_dispatch(block)
                        self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            block.shape[arg_dim])
                    converted.append(block)
                x = converted
            else:
                x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
                self._check_input_dtype(x)
                op_dimension = (self.range_dimension
                                if adjoint else self.domain_dimension)
                op_dimension.assert_is_compatible_with(x.shape[arg_dim])
            return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
コード例 #11
0
    def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
        """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated
        as [batch] matrices, meaning for every set of leading dimensions, the
        last two dimensions define a matrix.
        See class docstring for definition of compatibility. A blockwise `rhs`
        is never modified; its converted blocks are collected into a new list.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`, or, if `rhs`
        is blockwise, a list with one solution per block. If `rhs` is a
        `LinearOperator`, the result of `linear_operator_algebra.solve`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
      ValueError: If `rhs` is a `LinearOperator` whose range dimension is known
        and differs from the domain dimension of this operator (or of its
        adjoint).
    """
        if self.is_non_singular is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "be singular.")
        if self.is_square is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "not be square.")
        if isinstance(rhs, linear_operator.LinearOperator):
            left_operator = self.adjoint() if adjoint else self
            right_operator = rhs.adjoint() if adjoint_arg else rhs

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `rhs` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):
                return linear_operator_algebra.solve(left_operator,
                                                     right_operator)

        with self._name_scope(name):
            # `A X = rhs` consumes `rhs` through the blocks' range dimensions;
            # `A^H X = rhs` consumes it through their domain dimensions.
            if adjoint:
                block_dimensions = self._block_domain_dimensions()
                block_dimension_tensors = self._block_domain_dimension_tensors
            else:
                block_dimensions = self._block_range_dimensions()
                block_dimension_tensors = self._block_range_dimension_tensors
            # Rows of `rhs` normally, columns when `rhs` enters via its adjoint.
            arg_dim = -1 if adjoint_arg else -2
            blockwise_arg = linear_operator_util.arg_is_blockwise(
                block_dimensions, rhs, arg_dim)

            if blockwise_arg:
                # Collect converted blocks into a new list. Assigning into
                # `rhs` would mutate the caller's list, and would raise
                # `TypeError` if `rhs` were a tuple.
                split_rhs = []
                for i, block in enumerate(rhs):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor_v2_with_dispatch(block)
                        self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            block.shape[arg_dim])
                    split_rhs.append(block)
            else:
                rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
                self._check_input_dtype(rhs)
                op_dimension = (self.domain_dimension
                                if adjoint else self.range_dimension)
                op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])
                # Split with the same block sizes used for detection above, so
                # both paths agree on how `rhs` is partitioned.
                split_rhs = linear_operator_util.split_arg_into_blocks(
                    block_dimensions,
                    block_dimension_tensors,
                    rhs,
                    axis=arg_dim)

            # Index (rather than zip) so a short `split_rhs` still raises
            # instead of silently dropping operators.
            solution_list = [
                operator.solve(split_rhs[index],
                               adjoint=adjoint,
                               adjoint_arg=adjoint_arg)
                for index, operator in enumerate(self.operators)
            ]

            if blockwise_arg:
                return solution_list

            solution_list = linear_operator_util.broadcast_matrix_batch_dims(
                solution_list)
            return array_ops.concat(solution_list, axis=-2)