Esempio n. 1
0
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        """Returns zeros shaped like the product of this operator with `x`.

        Args:
          x: Right-hand argument of the product.
          adjoint: Python `bool`. If `True`, multiply by the adjoint of the
            operator. For a non-square operator this selects `_num_columns`
            instead of `_num_rows` as the output row count.
          adjoint_arg: Python `bool`. If `True`, multiply by `x^H`.

        Returns:
          A tensor of zeros with `x`'s dtype, passed through
          `self._possibly_broadcast_batch_shape`.
        """
        # Take the adjoint of `x` exactly once, before any shape logic.
        # Previously the adjoint was applied only when shape assertions were
        # enabled. The shape code below still swapped the last two dims whenever
        # `adjoint_arg` was set, so with both flags on the dims were swapped
        # twice and the output shape was wrong.
        x = linalg.adjoint(x) if adjoint_arg else x
        if self._assert_proper_shapes:
            aps = linear_operator_util.assert_compatible_matrix_dimensions(
                self, x)
            x = distribution_util.with_dependencies([aps], x)

        x_shape = prefer_static.shape(x)
        if self.is_square:
            # Note that adjoint has no effect since this matrix is self-adjoint.
            output_shape = x_shape
        else:
            n = self._num_columns if adjoint else self._num_rows
            output_shape = prefer_static.concat(
                [x_shape[:-2], [n, x_shape[-1]]], axis=0)

        return self._possibly_broadcast_batch_shape(
            array_ops.zeros(shape=output_shape, dtype=x.dtype))
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
   """Returns `x` (or `x^H`) passed through `_possibly_broadcast_batch_shape`.

   The `adjoint` flag is ignored because this matrix is self-adjoint.
   """
   if adjoint_arg:
     x = linalg.adjoint(x)
   if self._assert_proper_shapes:
     # Tie the shape-compatibility assertion to `x` as a control dependency.
     x = control_flow_ops.with_dependencies(
         [linear_operator_util.assert_compatible_matrix_dimensions(self, x)],
         x)
   return self._possibly_broadcast_batch_shape(x)
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
     """Multiplies `x` (or `x^H`) elementwise by the operator's multiplier.

     When `adjoint` is set, the multiplier matrix is built with
     `conjugate=True`.
     """
     if adjoint_arg:
         x = linalg.adjoint(x)
     if self._assert_proper_shapes:
         shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
             self, x)
         x = distribution_util.with_dependencies([shape_check], x)
     multiplier = self._make_multiplier_matrix(conjugate=adjoint)
     return x * multiplier
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Divides `rhs` (or `rhs^H`) elementwise by the operator's multiplier.

     When `adjoint` is set, the multiplier matrix is built with
     `conjugate=True`.
     """
     if adjoint_arg:
         rhs = linalg.adjoint(rhs)
     if self._assert_proper_shapes:
         shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
             self, rhs)
         rhs = distribution_util.with_dependencies([shape_check], rhs)
     divisor = self._make_multiplier_matrix(conjugate=adjoint)
     return rhs / divisor
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        """Toeplitz matmul computed through a circulant embedding.

        A Toeplitz matrix `T` with first row `[t_0, t_1, ..., t_{n-1}]` and first
        column `[t_0, t_{-1}, ..., t_{-(n-1)}]` can be embedded in the circulant
        matrix `C` whose first column is
        `[t_0, t_{-1}, ..., t_{-(n-1)}, 0, t_{n-1}, ..., t_1]`. Pad `x` with as
        many zero rows as it already has, giving `y`. The leading rows of `C y`
        then equal the Toeplitz product. See
        http://math.mit.edu/icg/resources/teaching/18.085-spring2015/toeplitz.pdf
        for more details.
        """
        if adjoint_arg:
            x = linalg.adjoint(x)
        padded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
        col = ops.convert_to_tensor(self.col)
        row = ops.convert_to_tensor(self.row)
        # First column of the embedding: [col, 0, reversed(row[1:])].
        embedding_col = array_ops.concat(
            [
                col,
                array_ops.zeros_like(col[..., 0:1]),
                array_ops.reverse(row[..., 1:], axis=[-1]),
            ],
            axis=-1)
        spectrum = fft_ops.fft(_to_complex(embedding_col))
        embedding = linear_operator_circulant.LinearOperatorCirculant(
            spectrum, input_output_dtype=row.dtype)
        product = embedding.matmul(padded_x, adjoint=adjoint, adjoint_arg=False)

        shape = self._shape_tensor(row=row, col=col)
        num_rows = self._domain_dimension_tensor(shape=shape)
        return _ops.cast(product[..., :num_rows, :], self.dtype)
 def _assert_self_adjoint(self):
     """Default self-adjointness check via the dense matrix.

     Returns:
       The op from `check_ops.assert_equal`, comparing `self.to_dense()` with
       its adjoint.
     """
     # Warn before densifying, as `_solve` does. `warn` is a deprecated alias of
     # `warning`.
     logging.warning(
         "Using (possibly slow) default implementation of assert_self_adjoint."
         "  Requires conversion to a dense matrix.")
     dense = self.to_dense()
     return check_ops.assert_equal(
         dense,
         linalg.adjoint(dense),
         message="Matrix was not equal to its adjoint.")
Esempio n. 7
0
 def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Solve by conversion to a dense matrix.

     Args:
       rhs: Right-hand side of the system.
       adjoint: Python `bool`. If `True`, solve against the adjoint operator.
       adjoint_arg: Python `bool`. If `True`, use `rhs^H` as the right-hand side.

     Returns:
       The solution, computed from `self.to_dense()`.

     Raises:
       NotImplementedError: If `self.is_square` is `False`.
     """
     # `is False` rather than `not`: `None` means "unknown" and must not raise.
     if self.is_square is False:  # pylint: disable=g-bool-id-comparison
         raise NotImplementedError(
             "Solve is not yet implemented for non-square operators.")
     rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
     if self._can_use_cholesky():
         # Use the broadcasting Cholesky solve. This matches the
         # `matrix_solve_with_broadcast` fallback below and the default `_solve`,
         # so `rhs` batch dims are handled the same way on both paths.
         # `adjoint` is not passed here; presumably `_can_use_cholesky` implies a
         # self-adjoint operator -- TODO confirm.
         return linear_operator_util.cholesky_solve_with_broadcast(
             linalg_ops.cholesky(self.to_dense()), rhs)
     return linear_operator_util.matrix_solve_with_broadcast(
         self.to_dense(), rhs, adjoint=adjoint)
Esempio n. 8
0
  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves against the circulant operator by dividing in Fourier space.

    The right-hand side is transformed with `self._fft`, divided by the
    spectrum (`self._conj_spectrum` when `adjoint` is set), and transformed
    back with `self._ifft`. The result is reshaped and cast to `self.dtype`.
    """
    if adjoint_arg:
      rhs = linalg.adjoint(rhs)
    if adjoint:
      spectrum = self._conj_spectrum
    else:
      spectrum = self._spectrum_complex

    rhs, spectrum = self._broadcast_batch_dims(rhs, spectrum)

    fft_blocks = self._fft(self._vectorize_then_blockify(rhs))
    solution_blocks = self._ifft(fft_blocks / spectrum)
    solution = self._unblockify_then_matricize(solution_blocks)
    return _ops.cast(solution, self.dtype)
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Default implementation of _solve.

     Densifies the operator. It solves with a Cholesky factorization when
     `self._can_use_cholesky()` is true, otherwise with
     `matrix_solve_with_broadcast`.

     Raises:
       NotImplementedError: If `self.is_square` is `False`.
     """
     # `is False` rather than `not`: `None` means "unknown" and must not raise.
     if self.is_square is False:  # pylint: disable=g-bool-id-comparison
         raise NotImplementedError(
             "Solve is not yet implemented for non-square operators.")
     # `warn` is a deprecated alias of `warning`.
     logging.warning(
         "Using (possibly slow) default implementation of solve."
         "  Requires conversion to a dense matrix and O(N^3) operations.")
     rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
     if self._can_use_cholesky():
         return linear_operator_util.cholesky_solve_with_broadcast(
             linalg_ops.cholesky(self.to_dense()), rhs)
     return linear_operator_util.matrix_solve_with_broadcast(
         self.to_dense(), rhs, adjoint=adjoint)
Esempio n. 10
0
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        """Toeplitz matmul through the circulant operator in `self._circulant`.

        A Toeplitz matrix `T` with first row `[t_0, t_1, ..., t_{n-1}]` and first
        column `[t_0, t_{-1}, ..., t_{-(n-1)}]` can be embedded in the circulant
        matrix `C` whose first column is
        `[t_0, t_{-1}, ..., t_{-(n-1)}, 0, t_{n-1}, ..., t_1]`. Pad `x` with as
        many zero rows as it already has, giving `y`. The leading rows of `C y`
        then equal the Toeplitz product. See
        http://math.mit.edu/icg/resources/teaching/18.085-spring2015/toeplitz.pdf
        for more details. `self._circulant` is presumably that embedding (it is
        built elsewhere -- confirm).
        """
        if adjoint_arg:
            x = linalg.adjoint(x)
        zero_padding = array_ops.zeros_like(x)
        padded_x = array_ops.concat([x, zero_padding], axis=-2)
        product = self._circulant.matmul(
            padded_x, adjoint=adjoint, adjoint_arg=False)
        num_rows = self.domain_dimension_tensor()
        leading_rows = product[..., :num_rows, :]
        return _ops.cast(leading_rows, self.dtype)
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        """Reflects `x` (or `x^H`) about the hyperplane orthogonal to the axis.

        With `u` the L2-normalized reflection axis, the reflection is
        `x - 2 u (u^H x)`. The projection of `x` onto `u` is `u (u^H x)`.
        Negating it reflects that component, and the component orthogonal to
        `u` lies in the hyperplane and is left unchanged.

        The reflection lies in O(n) (real) or U(n) (complex), so it is its own
        adjoint and the `adjoint` flag has no effect.
        """
        raw_axis = ops.convert_to_tensor(self.reflection_axis)
        if adjoint_arg:
            x = linalg.adjoint(x)
        # Normalize along the last dim, then append a trailing unit dim so the
        # axis acts as a column vector `u`.
        unit_axis = nn.l2_normalize(raw_axis, axis=-1)[..., _ops.newaxis]
        # u^H x; broadcasting `u * (u^H x)` forms the rank-one update.
        projection = _linalg.matmul(unit_axis, x, adjoint_a=True)
        return x - 2 * unit_axis * projection
Esempio n. 12
0
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Circulant matmul evaluated in the Fourier domain.

    With F the DFT matrix and F^{-1} = F^{H} the inverse DFT,
    `matmul(x) = F^{H} diag(spectrum) F x`. The adjoint product uses
    `self._conj_spectrum` in place of the spectrum.
    """
    if adjoint_arg:
      x = linalg.adjoint(x)
    if adjoint:
      spectrum = self._conj_spectrum
    else:
      spectrum = self._spectrum_complex

    # Match the spectrum's dtype before broadcasting and transforming.
    x = _ops.cast(x, spectrum.dtype)
    x, spectrum = self._broadcast_batch_dims(x, spectrum)

    fft_blocks = self._fft(self._vectorize_then_blockify(x))
    result_blocks = self._ifft(spectrum * fft_blocks)
    result = self._unblockify_then_matricize(result_blocks)
    return _ops.cast(result, self.dtype)
def _matmul(  # pylint:disable=missing-docstring
        a,
        b,
        transpose_a=False,
        transpose_b=False,
        adjoint_a=False,
        adjoint_b=False,
        a_is_sparse=False,
        b_is_sparse=False,
        name=None):
    if transpose_a or transpose_b:
        raise ValueError("Transposing not supported at this time.")
    if a_is_sparse or b_is_sparse:
        raise ValueError("Sparse methods not supported at this time.")
    if not isinstance(a, LinearOperator):
        # We use the identity (B^HA^H)^H =  AB
        adjoint_matmul = b.matmul(a,
                                  adjoint=(not adjoint_b),
                                  adjoint_arg=(not adjoint_a),
                                  name=name)
        return linalg.adjoint(adjoint_matmul)
    return a.matmul(b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)
    def _solve(self, rhs, adjoint=False, adjoint_arg=False):
        """Solves against the Kronecker product one factor at a time.

        Mirrors `_matmul`, replacing each factor's `matmul` with `solve` (and the
        final `matvec` with `solvevec`).

        Args:
          rhs: Right-hand side, treated as a batch of matrices `[B, R, C]`.
          adjoint: Python `bool`. If `True`, solve against the adjoint.
          adjoint_arg: Python `bool`. If `True`, use `rhs^H` as the right-hand side.

        Returns:
          The solution. When `rhs`'s static shape is fully defined (after
          batch broadcasting), the result's static shape is set to the
          broadcast batch shape plus `[range_dimension, C]`
          (`[domain_dimension, C]` if `adjoint`).
        """
        # Here we follow the same use of Roth's column lemma as in `matmul`, with
        # the key difference that we replace all `matmul` instances with `solve`.
        # This follows from the property that inv(A x B) = inv(A) x inv(B).

        # Below we document the shape manipulation for adjoint=False,
        # adjoint_arg=False, but the general case of different adjoints is still
        # handled.

        # NOTE(review): the range/domain choices in the loop and in the final
        # `set_shape` are copied from `_matmul`. For a solve they would be
        # swapped, but the two coincide for square factors. Presumably solve is
        # only reached for square operators -- confirm.

        if adjoint_arg:
            rhs = linalg.adjoint(rhs)

        # Always add a batch dimension to enable broadcasting to work.
        batch_shape = self._compute_ones_matrix_shape()
        rhs = rhs + array_ops.zeros(batch_shape, dtype=rhs.dtype)

        # rhs has shape [B, R, C], where B represents some number of batch
        # dimensions, R the number of rows, and C the number of columns.
        # In order to apply Roth's column lemma, we need to operate on a batch of
        # column vectors, so we reshape into a batch of column vectors. We put it
        # at the front to ensure that broadcasting between operators to the batch
        # dimensions B still works.
        output = _rotate_last_dim(rhs, rotate_right=True)

        # Also expand the shape to be [A, C, B, R]. The first dimension will be
        # used to accumulate dimensions from each operator solve.
        output = output[_ops.newaxis, ...]

        # In this loop, A is going to refer to the value of the accumulated
        # dimension. A = 1 at the start, and will end up being self.range_dimension.
        # V will refer to the last dimension. V = R at the start, and will end up
        # being 1 in the end.
        for operator in self.operators[:-1]:
            # Reshape output from [A, C, B, V] to be
            # [A, C, B, V / op.domain_dimension, op.domain_dimension]
            if adjoint:
                operator_dimension = operator.range_dimension_tensor()
            else:
                operator_dimension = operator.domain_dimension_tensor()

            output = _unvec_by(output, operator_dimension)

            # We are computing (XA^-1^T) = (A^-1 X^T)^T.
            # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
            # which is being converted to:
            # [A, C, B, V / op.domain_dimension, op.range_dimension]
            output = _linalg.matrix_transpose(output)
            output = operator.solve(output, adjoint=adjoint, adjoint_arg=False)
            output = _linalg.matrix_transpose(output)
            # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
            output = _rotate_last_dim(output, rotate_right=False)
            output = _vec(output)
            output = _rotate_last_dim(output, rotate_right=True)

        # After the loop, we will have
        # A = self.range_dimension / op[-1].range_dimension
        # V = op[-1].domain_dimension

        # We convert that using solvevec to get:
        # [A, C, B, op[-1].range_dimension]
        output = self.operators[-1].solvevec(output, adjoint=adjoint)
        # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
        output = _rotate_last_dim(output, rotate_right=False)
        output = _vec(output)
        output = _rotate_last_dim(output, rotate_right=False)

        # Restore static shape information lost in the reshapes above.
        if tensor_shape.TensorShape(rhs.shape).is_fully_defined():
            column_dim = tensor_shape.TensorShape(rhs.shape)[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                tensor_shape.TensorShape(rhs.shape)[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        """Multiplies by the Kronecker product without forming the dense matrix.

        Each factor in `self.operators[:-1]` is applied in turn to a reshaped
        view of `x` (Roth's column lemma, described below). The last factor is
        applied with `matvec`.

        Args:
          x: Right-hand argument, treated as a batch of matrices `[B, R, C]`.
          adjoint: Python `bool`. If `True`, multiply by the adjoint.
          adjoint_arg: Python `bool`. If `True`, multiply by `x^H`.

        Returns:
          The product. When `x`'s static shape is fully defined (after batch
          broadcasting), the result's static shape is set to the broadcast
          batch shape plus `[range_dimension, C]` (`[domain_dimension, C]` if
          `adjoint`).
        """
        # Here we heavily rely on Roth's column Lemma [1]:
        # (A x B) * vec X = vec BXA^T,
        # where vec stacks all the columns of the matrix under each other. In our
        # case, x represents a batch of vec X (i.e. we think of x as a batch of
        # column vectors, rather than a matrix). Each member of the batch can be
        # reshaped to a matrix (hence we get a batch of matrices).
        # We can iteratively apply this lemma by noting that if B is a Kronecker
        # product, then we can apply the lemma again.

        # [1] W. E. Roth, "On direct product matrices,"
        # Bulletin of the American Mathematical Society, vol. 40, pp. 461-468,
        # 1934

        # Efficiency

        # Naively doing the Kronecker product, by calculating the dense matrix and
        # applying it, can take cubic time in the size of domain_dimension
        # (assuming a square matrix). The other issue is that calculating the dense
        # matrix can be prohibitively expensive, in that it can take a large amount
        # of memory.
        #
        # This implementation avoids this memory blow up by only computing matmuls
        # with the factors. In this way, we don't have to realize the dense matrix.
        # In terms of complexity, if we have Kronecker Factors of size:
        # (n1, n1), (n2, n2), (n3, n3), ... (nJ, nJ), with N = \prod n_i, and we
        # have as input a [N, M] matrix, the naive approach would take O(N^2 M).
        # With this approach (ignoring reshaping of tensors and transposes for now),
        # the time complexity can be O(M * (\sum n_i) * N). There is also the
        # benefit of batched multiplication (In this example, the batch size is
        # roughly M * N) so this can be much faster. However, not factored in are
        # the costs of the several transposing of tensors, which can affect cache
        # behavior.

        # Below we document the shape manipulation for adjoint=False,
        # adjoint_arg=False, but the general case of different adjoints is still
        # handled.

        if adjoint_arg:
            x = linalg.adjoint(x)

        # Always add a batch dimension to enable broadcasting to work.
        batch_shape = self._compute_ones_matrix_shape()
        x = x + array_ops.zeros(batch_shape, dtype=x.dtype)

        # x has shape [B, R, C], where B represents some number of batch dimensions,
        # R the number of rows, and C the number of columns.
        # In order to apply Roth's column lemma, we need to operate on a batch of
        # column vectors, so we reshape into a batch of column vectors. We put it
        # at the front to ensure that broadcasting between operators to the batch
        # dimensions B still works.
        output = _rotate_last_dim(x, rotate_right=True)

        # Also expand the shape to be [A, C, B, R]. The first dimension will be
        # used to accumulate dimensions from each operator matmul.
        output = output[_ops.newaxis, ...]

        # In this loop, A is going to refer to the value of the accumulated
        # dimension. A = 1 at the start, and will end up being self.range_dimension.
        # V will refer to the last dimension. V = R at the start, and will end up
        # being 1 in the end.
        for operator in self.operators[:-1]:
            # Reshape output from [A, C, B, V] to be
            # [A, C, B, V / op.domain_dimension, op.domain_dimension]
            if adjoint:
                operator_dimension = operator.range_dimension_tensor()
            else:
                operator_dimension = operator.domain_dimension_tensor()

            output = _unvec_by(output, operator_dimension)

            # We are computing (XA^T) = (AX^T)^T.
            # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
            # which is being converted to:
            # [A, C, B, V / op.domain_dimension, op.range_dimension]
            output = _linalg.matrix_transpose(output)
            output = operator.matmul(output,
                                     adjoint=adjoint,
                                     adjoint_arg=False)
            output = _linalg.matrix_transpose(output)
            # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
            output = _rotate_last_dim(output, rotate_right=False)
            output = _vec(output)
            output = _rotate_last_dim(output, rotate_right=True)

        # After the loop, we will have
        # A = self.range_dimension / op[-1].range_dimension
        # V = op[-1].domain_dimension

        # We convert that using matvec to get:
        # [A, C, B, op[-1].range_dimension]
        output = self.operators[-1].matvec(output, adjoint=adjoint)
        # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
        output = _rotate_last_dim(output, rotate_right=False)
        output = _vec(output)
        output = _rotate_last_dim(output, rotate_right=False)

        # Restore static shape information lost in the reshapes above.
        if tensor_shape.TensorShape(x.shape).is_fully_defined():
            column_dim = tensor_shape.TensorShape(x.shape)[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                tensor_shape.TensorShape(x.shape)[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
Esempio n. 16
0
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
    """Maybe reshape a, b, and return an inverse map.  For matmul/solve.

    When `b` has statically more dimensions than `a`, the extra leading
    dimensions of `b` may be folded into its last dimension. Then `a` does not
    have to be broadcast across them.

    Args:
      a: Left argument (never reshaped, but possibly transposed/adjointed).
      b: Right argument.
      transpose_a: Python `bool`, transpose `a`.
      transpose_b: Python `bool`, transpose `b`.
      adjoint_a: Python `bool`, adjoint `a` (takes precedence over transpose).
      adjoint_b: Python `bool`, adjoint `b` (takes precedence over transpose).

    Returns:
      A tuple `(a, b, reshape_inv, still_need_to_transpose)`. `reshape_inv`
      maps a result computed from the returned `a` and `b` back to the
      original layout (the identity when no reshape happened). If
      `still_need_to_transpose` is `False`, any transpose/adjoint has already
      been applied to the returned `a` and `b`.
    """
    def identity(x):
        return x

    # At this point, we have not taken transpose/adjoint of a/b.
    still_need_to_transpose = True

    a_static_shape = _ops.TensorShape(a.shape)
    b_static_shape = _ops.TensorShape(b.shape)
    a_ndims = a_static_shape.ndims
    b_ndims = b_static_shape.ndims

    if a_ndims is None or b_ndims is None:
        return a, b, identity, still_need_to_transpose

    # This could be handled in the future, but seems less common.
    if a_ndims >= b_ndims:
        return a, b, identity, still_need_to_transpose

    # From now on, we might modify b, but will not modify a.

    # Suppose:
    #   a.shape =     C + [m, n]
    #   b.shape = S + C + [n, r]
    # so that b_extra_ndims = len(S).
    b_extra_ndims = b_ndims - a_ndims

    # No reason to flip unless the extra dims of b are big enough.  Why?
    # Assume adjoint/transpose = False.  Then...
    # By not flipping, we have to replicate a to shape
    #   S + a.shape,
    # which could use extra memory.  But in all cases, the final output has shape
    #   S + a.shape[:-1] + [b.shape[-1]]
    # So we only end up creating a larger object if the end dim of b is smaller
    # than the end dim of a.  This often happens, e.g. if b was a vector that was
    # expanded to a matrix (by appending a singleton).

    # Since adjoint/transpose may not be False, we must make adjustments here.
    # a_domain_sz_ is a's domain size; b_eq_sz_ is the dim of b that holds the
    # multiple equations; b_extra_sz_ is prod(S) when statically known.
    a_domain_sz_ = a_static_shape[-2 if adjoint_a or transpose_a else -1]
    b_eq_sz_ = b_static_shape[-2 if adjoint_b or transpose_b else -1]
    b_extra_static_shape = b_static_shape[:b_extra_ndims]
    b_extra_sz_ = (
        np.prod(b_extra_static_shape.as_list())
        if b_extra_static_shape.is_fully_defined() else None)
    if (a_domain_sz_ is not None and b_eq_sz_ is not None
            and b_extra_sz_ is not None):
        if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
            return a, b, identity, still_need_to_transpose

    # At this point, we're flipping for sure!
    # Any transposes/adjoints will happen here explicitly, rather than in calling
    # code.  Why?  To avoid having to write separate complex code for each case.
    if adjoint_a:
        a = linalg.adjoint(a)
    elif transpose_a:
        a = linalg.transpose(a)
    if adjoint_b:
        b = linalg.adjoint(b)
    elif transpose_b:
        b = linalg.transpose(b)
    still_need_to_transpose = False

    # Read b's dynamic shape only now: the transpose/adjoint above may have
    # swapped its last two dims.  b_extra_sh = S, b_main_sh = the rest.
    b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
    b_main_sh = array_ops.shape(b)[b_extra_ndims:]

    # Permutation to put the extra dims at the end (rank is unchanged by the
    # transpose/adjoint, so b_ndims still applies).
    perm = np.concatenate(
        (np.arange(b_extra_ndims, b_ndims), np.arange(0, b_extra_ndims)), 0)
    # Computed once here rather than on every call of `reshape_inv`.
    inverse_perm = np.argsort(perm)
    b_extra_on_end = array_ops.transpose(b, perm=perm)

    # Now squash this end into one long dim.
    b_squashed_end = array_ops.reshape(
        b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

    def reshape_inv(y):
        # Expand the extra dims hanging off the end, "b_extra_sh".
        # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
        # could have different batch dims than a and b, because of broadcasting.
        y_extra_shape = array_ops.concat(
            (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
        y_extra_on_end = array_ops.reshape(y, y_extra_shape)
        return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

    return a, b_squashed_end, reshape_inv, still_need_to_transpose
Esempio n. 17
0
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
     """Multiplies `x` (or `x^H`) by the diagonal stored in `self._diag`.

     When `adjoint` is set, the conjugated diagonal is used.
     """
     if adjoint:
         diag_term = math_ops.conj(self._diag)
     else:
         diag_term = self._diag
     if adjoint_arg:
         x = linalg.adjoint(x)
     # A trailing unit dim lets the diagonal broadcast across the columns of x
     # (assumes `self._diag` has shape [..., n] -- confirm).
     return array_ops.expand_dims(diag_term, -1) * x
Esempio n. 18
0
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Solves against the diagonal by multiplying with its reciprocal.

     When `adjoint` is set, the conjugated diagonal is used. `rhs^H` is used
     when `adjoint_arg` is set.
     """
     if adjoint:
         diag_term = math_ops.conj(self._diag)
     else:
         diag_term = self._diag
     if adjoint_arg:
         rhs = linalg.adjoint(rhs)
     reciprocal = 1. / diag_term
     # Trailing unit dim so the reciprocal broadcasts across rhs columns.
     return rhs * array_ops.expand_dims(reciprocal, -1)
 def _to_dense(self):
   """Returns the adjoint of `self.operator.to_dense()`.

   The adjoint is skipped when `self.is_self_adjoint` is truthy.
   """
   skip_adjoint = self.is_self_adjoint
   dense = self.operator.to_dense()
   if skip_adjoint:
     return dense
   return linalg.adjoint(dense)
    def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
        """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

        The returned `Tensor` will be close to an exact solution if `A` is well
        conditioned. Otherwise closeness will vary. See class docstring for details.

        Given the blockwise `n + 1`-by-`n + 1` linear operator:

        op = [[A_00     0  ...     0  ...    0],
              [A_10  A_11  ...     0  ...    0],
              ...
              [A_k0  A_k1  ...  A_kk  ...    0],
              ...
              [A_n0  A_n1  ...  A_nk  ... A_nn]]

        we find `x = op.solve(y)` by observing that

        `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

        and therefore

        `x_k = A_kk.solve(y_k -
                          A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

        where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
        and `y` along their appropriate axes.

        We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
        for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.

        The adjoint case is solved similarly, beginning with
        `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.

        Examples:

        ```python
        # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
        operator = LinearOperator(...)
        operator.shape = [..., M, N]

        # Solve R > 0 linear systems for every member of the batch.
        RHS = ... # shape [..., M, R]

        X = operator.solve(RHS)
        # X[..., :, r] is the solution to the r'th linear system
        # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

        operator.matmul(X)
        ==> RHS
        ```

        Args:
          rhs: `Tensor` with same `dtype` as this operator and compatible shape,
            a list of `Tensor`s (one per block), or a `LinearOperator`.
            `Tensor`s are treated like [batch] matrices, meaning that for every
            set of leading dimensions, the last two dimensions define a matrix.
            See class docstring for definition of compatibility. A list passed
            by the caller is not modified.
          adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
            of this `LinearOperator`:  `A^H X = rhs`.
          adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
            is the hermitian transpose (transposition and complex conjugation).
          name:  A name scope to use for ops added by this method.

        Returns:
          `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`. If `rhs` was
          given blockwise, a Python list with one solution block per block. If
          `rhs` is a `LinearOperator`, the result of
          `linear_operator_algebra.solve`.

        Raises:
          NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
          ValueError: If `rhs` is a `LinearOperator` whose static range dimension
            does not match this operator's domain dimension.
        """
        if self.is_non_singular is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "be singular.")
        if self.is_square is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "not be square.")
        if isinstance(rhs, linear_operator.LinearOperator):
            left_operator = self.adjoint() if adjoint else self
            right_operator = rhs.adjoint() if adjoint_arg else rhs

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `rhs` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):
                return linear_operator_algebra.solve(left_operator,
                                                     right_operator)

        with self._name_scope(name):
            block_dimensions = (self._block_domain_dimensions()
                                if adjoint else self._block_range_dimensions())
            arg_dim = -1 if adjoint_arg else -2
            blockwise_arg = linear_operator_util.arg_is_blockwise(
                block_dimensions, rhs, arg_dim)
            if blockwise_arg:
                # Work on a copy. Converting blocks in place would otherwise
                # overwrite entries of the caller's list, and item assignment
                # would fail outright on an immutable sequence.
                rhs = list(rhs)
                for i, block in enumerate(rhs):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor(block)
                        # self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            tensor_shape.TensorShape(block.shape)[arg_dim])
                        rhs[i] = block
                if adjoint_arg:
                    split_rhs = [linalg.adjoint(y) for y in rhs]
                else:
                    split_rhs = rhs

            else:
                rhs = ops.convert_to_tensor(rhs, name="rhs")
                # self._check_input_dtype(rhs)
                op_dimension = (self.domain_dimension
                                if adjoint else self.range_dimension)
                op_dimension.assert_is_compatible_with(
                    tensor_shape.TensorShape(rhs.shape)[arg_dim])

                rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
                split_rhs = linear_operator_util.split_arg_into_blocks(
                    self._block_domain_dimensions(),
                    self._block_domain_dimension_tensors,
                    rhs,
                    axis=-2)

            # Note: `y = y - ...` (never `y -= ...`) throughout, so array
            # backends with an in-place `__isub__` cannot overwrite the rhs
            # blocks or fail when the subtrahend broadcasts to a larger shape.
            solution_list = []
            if adjoint:
                # For an adjoint blockwise lower-triangular linear operator, the
                # system must be solved bottom to top. Iterate backwards over rows
                # of the adjoint (i.e. columns of the non-adjoint operator).
                for index in reversed(range(len(self.operators))):
                    y = split_rhs[index]
                    # Iterate top to bottom over the operators in the off-diagonal
                    # portion of the column-partition (i.e. row-partition of the
                    # adjoint). Apply each operator to the respective block of the
                    # solution found in previous iterations, and subtract the
                    # result from the `rhs` block.
                    # For example, let `A`, `B`, and `D` be the linear operators in
                    # the top row-partition of the adjoint of
                    # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
                    # and `x_1` and `x_2` be blocks of the solution found in
                    # previous iterations of the outer loop. The following loop
                    # (when `index == 0`) expresses
                    # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
                    # `y_0* = y_0 - Bx_1 - Dx_2`.
                    for j in reversed(range(index + 1, len(self.operators))):
                        y = y - self.operators[j][index].matmul(
                            solution_list[len(self.operators) - 1 - j],
                            adjoint=adjoint)
                    # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
                    solution_list.append(self._diagonal_operators[index].solve(
                        y, adjoint=adjoint))
                solution_list.reverse()
            else:
                # Iterate top to bottom over the row-partitions.
                for row, y in zip(self.operators, split_rhs):
                    # Iterate left to right over the operators in the off-diagonal
                    # portion of the row-partition. Apply each operator to the
                    # block of the solution found in previous iterations, and
                    # subtract the result from the `rhs` block. For example, let
                    # `D`, `E`, and `F` be the linear operators in the bottom
                    # row-partition of
                    # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`
                    # and `x_0` and `x_1` be blocks of the solution found in
                    # previous iterations of the outer loop. For that third
                    # row-partition, the following loop expresses
                    # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
                    # `y_2* = y_2 - Dx_0 - Ex_1`.
                    for i, operator in enumerate(row[:-1]):
                        y = y - operator.matmul(solution_list[i], adjoint=adjoint)
                    # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
                    solution_list.append(row[-1].solve(y, adjoint=adjoint))

            if blockwise_arg:
                return solution_list

            solution_list = linear_operator_util.broadcast_matrix_batch_dims(
                solution_list)
            return array_ops.concat(solution_list, axis=-2)
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Triangular solve against the lower-triangular matrix from `_get_tril()`.

   `rhs^H` is used when `adjoint_arg` is set. `adjoint` is forwarded to
   `linalg.triangular_solve`.
   """
   if adjoint_arg:
     rhs = linalg.adjoint(rhs)
   lower_triangle = self._get_tril()
   return linalg.triangular_solve(
       lower_triangle, rhs, lower=True, adjoint=adjoint)