  def test_less_than_two_dims_raises_static(self):
    x = rng.rand(3)
    y = rng.rand(1, 1)

    with self.assertRaisesRegexp(ValueError, "at least two dimensions"):
      linear_operator_util.broadcast_matrix_batch_dims([x, y])

    with self.assertRaisesRegexp(ValueError, "at least two dimensions"):
      linear_operator_util.broadcast_matrix_batch_dims([y, x])
Example #3
    def _broadcast_batch_dims(self, x, spectrum):
        """Broadcast batch dims of batch matrix `x` and spectrum."""
        spectrum = ops.convert_to_tensor_v2_with_dispatch(spectrum,
                                                          name="spectrum")
        # spectrum.shape = batch_shape + block_shape
        # First make spectrum a batch matrix with
        #   spectrum.shape = batch_shape + [prod(block_shape), 1]
        batch_shape = self._batch_shape_tensor(shape=self._shape_tensor(
            spectrum=spectrum))
        spec_mat = array_ops.reshape(
            spectrum, array_ops.concat((batch_shape, [-1, 1]), axis=0))
        # Second, broadcast, possibly requiring the addition of an array of zeros.
        x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims(
            (x, spec_mat))
        # Third, put the block shape back into spectrum.
        x_batch_shape = array_ops.shape(x)[:-2]
        spectrum_shape = array_ops.shape(spectrum)
        spectrum = array_ops.reshape(
            spec_mat,
            array_ops.concat(
                (x_batch_shape,
                 self._block_shape_tensor(spectrum_shape=spectrum_shape)),
                axis=0))

        return x, spectrum
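For intuition, a minimal NumPy-only sketch of the same reshape -> broadcast -> reshape shape bookkeeping. The shapes are made up, and the broadcast step is imitated by adding an array of zeros (as the tests on this page do); this is not the `_broadcast_batch_dims` implementation itself.

import numpy as np

batch_shape, block_shape = (2, 3), (4, 5)
spectrum = np.random.rand(*batch_shape, *block_shape)  # batch_shape + block_shape
x = np.random.rand(6, 1, 1, 20, 7)                     # batch dims broadcast against (2, 3)

# First: flatten the block dims so spectrum becomes a batch of [prod(block_shape), 1] matrices.
spec_mat = spectrum.reshape(*batch_shape, -1, 1)       # shape (2, 3, 20, 1)

# Second: broadcast the batch dims of x and spec_mat by adding zeros.
bcast_batch = np.broadcast_shapes(x.shape[:-2], spec_mat.shape[:-2])  # (6, 2, 3)
x = x + np.zeros(bcast_batch + (1, 1))                 # shape (6, 2, 3, 20, 7)
spec_mat = spec_mat + np.zeros(bcast_batch + (1, 1))   # shape (6, 2, 3, 20, 1)

# Third: put the block shape back into the spectrum.
spectrum = spec_mat.reshape(*bcast_batch, *block_shape)  # shape (6, 2, 3, 4, 5)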
    def test_one_batch_matrix_returned_after_tensor_conversion(self):
        arr = rng.rand(2, 3, 4)
        tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr])
        self.assertTrue(isinstance(tensor, ops.Tensor))

        with self.cached_session():
            self.assertAllClose(arr, self.evaluate(tensor))
  def _to_dense(self):
    num_cols = 0
    rows = []
    broadcasted_blocks = [operator.to_dense() for operator in self.operators]
    broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims(
        broadcasted_blocks)
    for block in broadcasted_blocks:
      batch_row_shape = array_ops.shape(block)[:-1]

      zeros_to_pad_before_shape = array_ops.concat(
          [batch_row_shape, [num_cols]], axis=-1)
      zeros_to_pad_before = array_ops.zeros(
          shape=zeros_to_pad_before_shape, dtype=block.dtype)
      num_cols += array_ops.shape(block)[-1]
      zeros_to_pad_after_shape = array_ops.concat(
          [batch_row_shape,
           [self.domain_dimension_tensor() - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=block.dtype)

      rows.append(array_ops.concat(
          [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

    mat = array_ops.concat(rows, axis=-2)
    mat.set_shape(self.shape)
    return mat
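The loop above pads each block with zeros on the left and right so that stacking the rows yields a dense block-diagonal matrix. A minimal NumPy sketch of the same padding scheme (illustrative 2x2 and 3x3 blocks, no batch dimensions):

import numpy as np

blocks = [np.ones((2, 2)), 2 * np.ones((3, 3))]
total_cols = sum(b.shape[-1] for b in blocks)

num_cols = 0
rows = []
for block in blocks:
  pad_before = np.zeros(block.shape[:-1] + (num_cols,))
  num_cols += block.shape[-1]
  pad_after = np.zeros(block.shape[:-1] + (total_cols - num_cols,))
  rows.append(np.concatenate([pad_before, block, pad_after], axis=-1))

dense = np.concatenate(rows, axis=-2)  # shape (5, 5), with the blocks on the diagonal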
  def test_one_batch_matrix_returned_after_tensor_conversion(self):
    arr = rng.rand(2, 3, 4)
    tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr])
    self.assertTrue(isinstance(tensor, ops.Tensor))

    with self.cached_session():
      self.assertAllClose(arr, tensor.eval())
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        arg_dim = -1 if adjoint_arg else -2
        block_dimensions = (self._block_range_dimensions()
                            if adjoint else self._block_domain_dimensions())
        blockwise_arg = linear_operator_util.arg_is_blockwise(
            block_dimensions, x, arg_dim)
        if blockwise_arg:
            split_x = x
        else:
            split_dim = -1 if adjoint_arg else -2
            # Split input by rows normally, or by columns if adjoint_arg is True.
            split_x = linear_operator_util.split_arg_into_blocks(
                self._block_domain_dimensions(),
                self._block_domain_dimension_tensors,
                x,
                axis=split_dim)

        result_list = []
        for index, operator in enumerate(self.operators):
            result_list += [
                operator.matmul(split_x[index],
                                adjoint=adjoint,
                                adjoint_arg=adjoint_arg)
            ]

        if blockwise_arg:
            return result_list

        result_list = linear_operator_util.broadcast_matrix_batch_dims(
            result_list)
        return array_ops.concat(result_list, axis=-2)
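As a sanity check on the block-diagonal `_matmul` above, a small NumPy sketch (arbitrary block sizes; `scipy.linalg.block_diag` is used only to build the dense reference): applying each block to the matching row-slice of `x` and concatenating agrees with multiplying by the dense block-diagonal matrix.

import numpy as np
from scipy.linalg import block_diag

a, b = np.random.rand(2, 2), np.random.rand(3, 3)
x = np.random.rand(5, 4)

# Blockwise: split x by rows, apply each block, concatenate the results.
blockwise = np.concatenate([a @ x[:2], b @ x[2:]], axis=-2)
# Dense reference.
dense = block_diag(a, b) @ x
assert np.allclose(blockwise, dense)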
  def _operator_and_matrix(self, build_info, dtype, use_placeholder):
    shape = list(build_info.shape)
    expected_factors = build_info.__dict__["factors"]
    matrices = [
        linear_operator_test_util.random_positive_definite_matrix(
            block_shape, dtype, force_well_conditioned=True)
        for block_shape in expected_factors
    ]

    lin_op_matrices = matrices

    if use_placeholder:
      lin_op_matrices = [
          array_ops.placeholder_with_default(m, shape=None) for m in matrices]

    operator = kronecker.LinearOperatorKronecker(
        [linalg.LinearOperatorFullMatrix(
            l, is_square=True) for l in lin_op_matrices])

    matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

    kronecker_dense = _kronecker_dense(matrices)

    if not use_placeholder:
      kronecker_dense.set_shape(shape)

    return operator, kronecker_dense
def _test_matmul_base(self, use_placeholder, shapes_info, dtype, adjoint,
                      adjoint_arg, blockwise_arg, with_batch):
    # If batch dimensions are omitted, but there are
    # no batch dimensions for the linear operator, then
    # skip the test case. This is already checked with
    # with_batch=True.
    if not with_batch and len(shapes_info.shape) <= 2:
        return
    with self.session(graph=ops.Graph()) as sess:
        sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
        operator, mat = self.operator_and_matrix(
            shapes_info, dtype, use_placeholder=use_placeholder)
        x = self.make_x(operator, adjoint=adjoint, with_batch=with_batch)
        # If adjoint_arg, compute A X^H^H = A X.
        if adjoint_arg:
            op_matmul = operator.matmul(linalg.adjoint(x),
                                        adjoint=adjoint,
                                        adjoint_arg=adjoint_arg)
        else:
            op_matmul = operator.matmul(x, adjoint=adjoint)
        mat_matmul = math_ops.matmul(mat, x, adjoint_a=adjoint)
        if not use_placeholder:
            self.assertAllEqual(op_matmul.shape, mat_matmul.shape)

        # If the operator is blockwise, test both blockwise `x` and `Tensor` `x`;
        # else test only `Tensor` `x`. In both cases, evaluate all results in a
        # single `sess.run` call to avoid re-sampling the random `x` in graph mode.
        if blockwise_arg and len(operator.operators) > 1:
            # pylint: disable=protected-access
            block_dimensions = (operator._block_range_dimensions() if adjoint
                                else operator._block_domain_dimensions())
            block_dimensions_fn = (operator._block_range_dimension_tensors
                                   if adjoint else
                                   operator._block_domain_dimension_tensors)
            # pylint: enable=protected-access
            split_x = linear_operator_util.split_arg_into_blocks(
                block_dimensions, block_dimensions_fn, x, axis=-2)
            if adjoint_arg:
                split_x = [linalg.adjoint(y) for y in split_x]
            split_matmul = operator.matmul(split_x,
                                           adjoint=adjoint,
                                           adjoint_arg=adjoint_arg)

            self.assertEqual(len(split_matmul), len(operator.operators))
            split_matmul = linear_operator_util.broadcast_matrix_batch_dims(
                split_matmul)
            fused_block_matmul = array_ops.concat(split_matmul, axis=-2)
            op_matmul_v, mat_matmul_v, fused_block_matmul_v = sess.run(
                [op_matmul, mat_matmul, fused_block_matmul])

            # Check that the operator applied to blockwise input gives the same result
            # as matrix multiplication.
            self.assertAC(fused_block_matmul_v, mat_matmul_v)
        else:
            op_matmul_v, mat_matmul_v = sess.run([op_matmul, mat_matmul])

        # Check that the operator applied to a `Tensor` gives the same result as
        # matrix multiplication.
        self.assertAC(op_matmul_v, mat_matmul_v)
 def _diag_part(self):
   diag_list = []
   for operator in self.operators:
     # Extend the axis for broadcasting.
     diag_list += [operator.diag_part()[..., array_ops.newaxis]]
   diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
   diagonal = array_ops.concat(diag_list, axis=-2)
   return array_ops.squeeze(diagonal, axis=-1)
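The trailing `newaxis` matters because `broadcast_matrix_batch_dims` treats only the last two dimensions as matrix dimensions. A short shape walk-through with made-up sizes:

import numpy as np

diag = np.random.rand(2, 3, 4)      # batch_shape + [n] for one block
as_matrix = diag[..., np.newaxis]   # batch_shape + [n, 1]: a batch of "column" matrices
assert as_matrix.shape == (2, 3, 4, 1)
# broadcast_matrix_batch_dims can now broadcast the (2, 3) batch dims across
# blocks; the final squeeze(axis=-1) drops the dummy column dimension again.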
 def _eigvals(self):
     eig_list = []
     for op in self._diagonal_operators:
         # Extend the axis for broadcasting.
         eig_list.append(op.eigvals()[..., array_ops.newaxis])
     eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
     eigs = array_ops.concat(eig_list, axis=-2)
     return array_ops.squeeze(eigs, axis=-1)
def _test_solve_base(self, use_placeholder, shapes_info, dtype, adjoint,
                     adjoint_arg, blockwise_arg, with_batch):
    # If batch dimensions are omitted, but there are
    # no batch dimensions for the linear operator, then
    # skip the test case. This is already checked with
    # with_batch=True.
    if not with_batch and len(shapes_info.shape) <= 2:
        return
    with self.session(graph=ops.Graph()) as sess:
        sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
        operator, mat = self.operator_and_matrix(
            shapes_info, dtype, use_placeholder=use_placeholder)
        rhs = self.make_rhs(operator, adjoint=adjoint, with_batch=with_batch)
        # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
        if adjoint_arg:
            op_solve = operator.solve(linalg.adjoint(rhs),
                                      adjoint=adjoint,
                                      adjoint_arg=adjoint_arg)
        else:
            op_solve = operator.solve(rhs,
                                      adjoint=adjoint,
                                      adjoint_arg=adjoint_arg)
        mat_solve = linear_operator_util.matrix_solve_with_broadcast(
            mat, rhs, adjoint=adjoint)
        if not use_placeholder:
            self.assertAllEqual(op_solve.shape, mat_solve.shape)

        # If the operator is blockwise, test both blockwise rhs and `Tensor` rhs;
        # else test only `Tensor` rhs. In both cases, evaluate all results in a
        # single `sess.run` call to avoid re-sampling the random rhs in graph mode.
        if blockwise_arg and len(operator.operators) > 1:
            split_rhs = linear_operator_util.split_arg_into_blocks(
                operator._block_domain_dimensions(),  # pylint: disable=protected-access
                operator._block_domain_dimension_tensors,  # pylint: disable=protected-access
                rhs,
                axis=-2)
            if adjoint_arg:
                split_rhs = [linalg.adjoint(y) for y in split_rhs]
            split_solve = operator.solve(split_rhs,
                                         adjoint=adjoint,
                                         adjoint_arg=adjoint_arg)
            self.assertEqual(len(split_solve), len(operator.operators))
            split_solve = linear_operator_util.broadcast_matrix_batch_dims(
                split_solve)
            fused_block_solve = array_ops.concat(split_solve, axis=-2)
            op_solve_v, mat_solve_v, fused_block_solve_v = sess.run(
                [op_solve, mat_solve, fused_block_solve])

            # Check that the operator and matrix give the same solution when the rhs
            # is blockwise.
            self.assertAC(mat_solve_v, fused_block_solve_v)
        else:
            op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])

        # Check that the operator and matrix give the same solution when the rhs is
        # a `Tensor`.
        self.assertAC(op_solve_v, mat_solve_v)
 def _diag_part(self):
     diag_list = []
     for op in self._diagonal_operators:
         # Extend the axis, since `broadcast_matrix_batch_dims` treats all but the
         # final two dimensions as batch dimensions.
         diag_list.append(op.diag_part()[..., array_ops.newaxis])
     diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
     diagonal = array_ops.concat(diag_list, axis=-2)
     return array_ops.squeeze(diagonal, axis=-1)
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    split_dim = -1 if adjoint_arg else -2
    # Split input by rows normally, or by columns if adjoint_arg is True.
    split_x = self._split_input_into_blocks(x, axis=split_dim)

    result_list = []
    for index, operator in enumerate(self.operators):
      result_list += [operator.matmul(
          split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]
    result_list = linear_operator_util.broadcast_matrix_batch_dims(
        result_list)
    return array_ops.concat(result_list, axis=-2)
Example #18
 def _eigvals(self):
   if not all(operator.is_square for operator in self.operators):
     raise NotImplementedError(
         "`eigvals` not implemented for an operator whose blocks are not "
         "square.")
   eig_list = []
   for operator in self.operators:
     # Extend the axis for broadcasting.
     eig_list += [operator.eigvals()[..., array_ops.newaxis]]
   eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
   eigs = array_ops.concat(eig_list, axis=-2)
   return array_ops.squeeze(eigs, axis=-1)
Example #19
 def _diag_part(self):
   if not all(operator.is_square for operator in self.operators):
     raise NotImplementedError(
         "`diag_part` not implemented for an operator whose blocks are not "
         "square.")
   diag_list = []
   for operator in self.operators:
     # Extend the axis for broadcasting.
     diag_list += [operator.diag_part()[..., array_ops.newaxis]]
   diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
   diagonal = array_ops.concat(diag_list, axis=-2)
   return array_ops.squeeze(diagonal, axis=-1)
  def operator_and_matrix(
      self, shape_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):

    expected_blocks = (
        shape_info.__dict__["blocks"] if "blocks" in shape_info.__dict__
        else [[list(shape_info.shape)]])

    matrices = []
    for i, row_shapes in enumerate(expected_blocks):
      row = []
      for j, block_shape in enumerate(row_shapes):
        if i == j:  # operator is on the diagonal
          row.append(
              linear_operator_test_util.random_positive_definite_matrix(
                  block_shape, dtype, force_well_conditioned=True))
        else:
          row.append(
              linear_operator_test_util.random_normal(block_shape, dtype=dtype))
      matrices.append(row)

    lin_op_matrices = matrices

    if use_placeholder:
      lin_op_matrices = [[
          array_ops.placeholder_with_default(
              matrix, shape=None) for matrix in row] for row in matrices]

    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[linalg.LinearOperatorFullMatrix(  # pylint:disable=g-complex-comprehension
            l,
            is_square=True,
            is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
            is_positive_definite=True if ensure_self_adjoint_and_pd else None)
          for l in row] for row in lin_op_matrices])

    # Should be auto-set.
    self.assertTrue(operator.is_square)

    # Broadcast the shapes.
    expected_shape = list(shape_info.shape)
    broadcasted_matrices = linear_operator_util.broadcast_matrix_batch_dims(
        [op for row in matrices for op in row])  # pylint: disable=g-complex-comprehension
    matrices = [broadcasted_matrices[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
                for i in range(len(matrices))]

    block_lower_triangular_dense = _block_lower_triangular_dense(
        expected_shape, matrices)

    if not use_placeholder:
      block_lower_triangular_dense.set_shape(expected_shape)

    return operator, block_lower_triangular_dense
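The slicing `broadcasted_matrices[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]` above recovers the rows of the flattened lower-triangular block structure: row `i` holds `i + 1` blocks, so its blocks start at the `i`-th triangular number. A tiny illustration with placeholder block names:

flat = ["A00", "A10", "A11", "A20", "A21", "A22"]
rows = [flat[i * (i + 1) // 2:(i + 1) * (i + 2) // 2] for i in range(3)]
assert rows == [["A00"], ["A10", "A11"], ["A20", "A21", "A22"]]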
  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    split_dim = -1 if adjoint_arg else -2
    # Split input by rows normally, or by columns if adjoint_arg is True.
    split_rhs = self._split_input_into_blocks(rhs, axis=split_dim)

    solution_list = []
    for index, operator in enumerate(self.operators):
      solution_list += [operator.solve(
          split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]

    solution_list = linear_operator_util.broadcast_matrix_batch_dims(
        solution_list)
    return array_ops.concat(solution_list, axis=-2)
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    split_dim = -1 if adjoint_arg else -2
    # Split input by columns if adjoint_arg is True, else rows
    split_x = self._split_input_into_blocks(x, axis=split_dim)

    result_list = []
    # Iterate over row-partitions (i.e. column-partitions of the adjoint).
    if adjoint:
      for index in range(len(self.operators)):
        # Begin with the operator on the diagonal and apply it to the respective
        # `rhs` block.
        result = self.operators[index][index].matmul(
            split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)

        # Iterate top to bottom over the operators in the remainder of the
        # column-partition (i.e. left to right over the row-partition of the
        # adjoint), apply the operator to the respective `rhs` block and
        # accumulate the sum. For example, given the
        # `LinearOperatorBlockLowerTriangular`:
        #
        # op = [[A, 0, 0],
        #       [B, C, 0],
        #       [D, E, F]]
        #
        # if `index == 1`, the following loop calculates:
        # `y_1 = (C.matmul(x_1, adjoint=adjoint) +
        #         E.matmul(x_2, adjoint=adjoint))`,
        # where `x_1` and `x_2` are splits of `x`.
        for j in range(index + 1, len(self.operators)):
          result += self.operators[j][index].matmul(
              split_x[j], adjoint=adjoint, adjoint_arg=adjoint_arg)
        result_list.append(result)
    else:
      for row in self.operators:
        # Begin with the left-most operator in the row-partition and apply it to
        # the first `rhs` block.
        result = row[0].matmul(
            split_x[0], adjoint=adjoint, adjoint_arg=adjoint_arg)
        # Iterate left to right over the operators in the remainder of the row
        # partition, apply the operator to the respective `rhs` block, and
        # accumulate the sum.
        for j, operator in enumerate(row[1:]):
          result += operator.matmul(
              split_x[j + 1], adjoint=adjoint, adjoint_arg=adjoint_arg)
        result_list.append(result)

    result_list = linear_operator_util.broadcast_matrix_batch_dims(
        result_list)
    return array_ops.concat(result_list, axis=-2)
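A small NumPy sketch of the non-adjoint branch for a hypothetical 2x2 block case `[[A, 0], [B, C]]` (arbitrary shapes): `y_0 = A x_0` and `y_1 = B x_0 + C x_1`, which matches multiplying by the dense block lower-triangular matrix.

import numpy as np

a = np.random.rand(2, 2)
b = np.random.rand(3, 2)
c = np.random.rand(3, 3)
x0, x1 = np.random.rand(2, 4), np.random.rand(3, 4)

blockwise = np.concatenate([a @ x0, b @ x0 + c @ x1], axis=-2)
dense = np.block([[a, np.zeros((2, 3))], [b, c]]) @ np.concatenate([x0, x1])
assert np.allclose(blockwise, dense)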
    def _operator_and_matrix(self,
                             build_info,
                             dtype,
                             use_placeholder,
                             ensure_self_adjoint_and_pd=False):
        shape = list(build_info.shape)
        expected_blocks = (build_info.__dict__["blocks"]
                           if "blocks" in build_info.__dict__ else [shape])
        matrices = [
            linear_operator_test_util.random_positive_definite_matrix(
                block_shape, dtype, force_well_conditioned=True)
            for block_shape in expected_blocks
        ]

        lin_op_matrices = matrices

        if use_placeholder:
            lin_op_matrices = [
                array_ops.placeholder_with_default(matrix, shape=None)
                for matrix in matrices
            ]

        operator = block_diag.LinearOperatorBlockDiag([
            linalg.LinearOperatorFullMatrix(
                l,
                is_square=True,
                is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
                is_positive_definite=True
                if ensure_self_adjoint_and_pd else None)
            for l in lin_op_matrices
        ])

        # Should be auto-set.
        self.assertTrue(operator.is_square)

        # Broadcast the shapes.
        expected_shape = list(build_info.shape)

        matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

        block_diag_dense = _block_diag_dense(expected_shape, matrices)

        if not use_placeholder:
            block_diag_dense.set_shape(
                expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

        return operator, block_diag_dense
    def test_static_dims_broadcast_second_arg_higher_rank(self):
        # x.batch_shape =    [1, 2]
        # y.batch_shape = [1, 3, 1]
        # broadcast batch shape = [1, 3, 2]
        x = rng.rand(1, 2, 1, 5)
        y = rng.rand(1, 3, 2, 3, 7)
        batch_of_zeros = np.zeros((1, 3, 2, 1, 1))
        x_bc_expected = x + batch_of_zeros
        y_bc_expected = y + batch_of_zeros

        x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])

        self.assertAllEqual(x_bc_expected.shape, x_bc.shape)
        self.assertAllEqual(y_bc_expected.shape, y_bc.shape)
        x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
        self.assertAllClose(x_bc_expected, x_bc_)
        self.assertAllClose(y_bc_expected, y_bc_)
    def test_static_dims_broadcast(self):
        # x.batch_shape = [3, 1, 2]
        # y.batch_shape = [4, 1]
        # broadcast batch shape = [3, 4, 2]
        x = rng.rand(3, 1, 2, 1, 5)
        y = rng.rand(4, 1, 3, 7)
        batch_of_zeros = np.zeros((3, 4, 2, 1, 1))
        x_bc_expected = x + batch_of_zeros
        y_bc_expected = y + batch_of_zeros

        x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])

        with self.cached_session() as sess:
            self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
            self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
            x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
            self.assertAllClose(x_bc_expected, x_bc_)
            self.assertAllClose(y_bc_expected, y_bc_)
  def _broadcast_batch_dims(self, x, spectrum):
    """Broadcast batch dims of batch matrix `x` and spectrum."""
    # spectrum.shape = batch_shape + block_shape
    # First make spectrum a batch matrix with
    #   spectrum.shape = batch_shape + [prod(block_shape), 1]
    spec_mat = array_ops.reshape(
        spectrum, array_ops.concat(
            (self.batch_shape_tensor(), [-1, 1]), axis=0))
    # Second, broadcast, possibly requiring the addition of an array of zeros.
    x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims((x,
                                                                    spec_mat))
    # Third, put the block shape back into spectrum.
    batch_shape = array_ops.shape(x)[:-2]
    spectrum = array_ops.reshape(
        spec_mat,
        array_ops.concat((batch_shape, self.block_shape_tensor()), axis=0))

    return x, spectrum
Example #30
  def test_dynamic_dims_broadcast_32bit_second_arg_higher_rank(self):
    # x.batch_shape =    [1, 2]
    # y.batch_shape = [3, 4, 1]
    # broadcast batch shape = [3, 4, 2]
    x = rng.rand(1, 2, 1, 5).astype(np.float32)
    y = rng.rand(3, 4, 1, 3, 7).astype(np.float32)
    batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32)
    x_bc_expected = x + batch_of_zeros
    y_bc_expected = y + batch_of_zeros

    x_ph = array_ops.placeholder_with_default(x, shape=None)
    y_ph = array_ops.placeholder_with_default(y, shape=None)

    x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])

    x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
    self.assertAllClose(x_bc_expected, x_bc_)
    self.assertAllClose(y_bc_expected, y_bc_)
  def test_static_dims_broadcast_second_arg_higher_rank(self):
    # x.batch_shape =    [1, 2]
    # y.batch_shape = [1, 3, 1]
    # broadcast batch shape = [1, 3, 2]
    x = rng.rand(1, 2, 1, 5)
    y = rng.rand(1, 3, 2, 3, 7)
    batch_of_zeros = np.zeros((1, 3, 2, 1, 1))
    x_bc_expected = x + batch_of_zeros
    y_bc_expected = y + batch_of_zeros

    x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])

    with self.cached_session() as sess:
      self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
      self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
      x_bc_, y_bc_ = sess.run([x_bc, y_bc])
      self.assertAllClose(x_bc_expected, x_bc_)
      self.assertAllClose(y_bc_expected, y_bc_)
  def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
    shape = list(build_info.shape)
    expected_blocks = (
        build_info.__dict__["blocks"] if "blocks" in build_info.__dict__
        else [shape])
    matrices = [
        linear_operator_test_util.random_positive_definite_matrix(
            block_shape, dtype, force_well_conditioned=True)
        for block_shape in expected_blocks
    ]

    if use_placeholder:
      matrices_ph = [
          array_ops.placeholder(dtype=dtype) for _ in expected_blocks
      ]
      # Evaluate here because (i) you cannot feed a tensor, and (ii)
      # values are random and we want the same value used for both mat and
      # feed_dict.
      matrices = self.evaluate(matrices)
      operator = block_diag.LinearOperatorBlockDiag(
          [linalg.LinearOperatorFullMatrix(
              m_ph, is_square=True) for m_ph in matrices_ph],
          is_square=True)
      feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
    else:
      operator = block_diag.LinearOperatorBlockDiag(
          [linalg.LinearOperatorFullMatrix(
              m, is_square=True) for m in matrices])
      feed_dict = None
      # Should be auto-set.
      self.assertTrue(operator.is_square)

    # Broadcast the shapes.
    expected_shape = list(build_info.shape)

    matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

    block_diag_dense = _block_diag_dense(expected_shape, matrices)

    if not use_placeholder:
      block_diag_dense.set_shape(
          expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

    return operator, block_diag_dense, feed_dict
  def operator_and_matrix(
      self, shape_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    shape = list(shape_info.shape)
    expected_blocks = (
        shape_info.__dict__["blocks"] if "blocks" in shape_info.__dict__
        else [shape])
    matrices = [
        linear_operator_test_util.random_positive_definite_matrix(
            block_shape, dtype, force_well_conditioned=True)
        for block_shape in expected_blocks
    ]

    lin_op_matrices = matrices

    if use_placeholder:
      lin_op_matrices = [
          array_ops.placeholder_with_default(
              matrix, shape=None) for matrix in matrices]

    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(
            l,
            is_square=True,
            is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
            is_positive_definite=True if ensure_self_adjoint_and_pd else None)
         for l in lin_op_matrices])

    # Should be auto-set.
    self.assertTrue(operator.is_square)

    # Broadcast the shapes.
    expected_shape = list(shape_info.shape)

    matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

    block_diag_dense = _block_diag_dense(expected_shape, matrices)

    if not use_placeholder:
      block_diag_dense.set_shape(
          expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

    return operator, block_diag_dense
  def test_dynamic_dims_broadcast_32bit_second_arg_higher_rank(self):
    # x.batch_shape =    [1, 2]
    # y.batch_shape = [3, 4, 1]
    # broadcast batch shape = [3, 4, 2]
    x = rng.rand(1, 2, 1, 5).astype(np.float32)
    y = rng.rand(3, 4, 1, 3, 7).astype(np.float32)
    batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32)
    x_bc_expected = x + batch_of_zeros
    y_bc_expected = y + batch_of_zeros

    x_ph = array_ops.placeholder(dtypes.float32)
    y_ph = array_ops.placeholder(dtypes.float32)

    x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])

    with self.cached_session() as sess:
      x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
      self.assertAllClose(x_bc_expected, x_bc_)
      self.assertAllClose(y_bc_expected, y_bc_)
Example #35
    def operator_and_matrix(self,
                            shape_info,
                            dtype,
                            use_placeholder,
                            ensure_self_adjoint_and_pd=False):
        del ensure_self_adjoint_and_pd
        shape = list(shape_info.shape)
        expected_blocks = (shape_info.__dict__["blocks"]
                           if "blocks" in shape_info.__dict__ else [shape])
        matrices = [
            linear_operator_test_util.random_normal(block_shape, dtype=dtype)
            for block_shape in expected_blocks
        ]

        lin_op_matrices = matrices

        if use_placeholder:
            lin_op_matrices = [
                array_ops.placeholder_with_default(matrix, shape=None)
                for matrix in matrices
            ]

        blocks = []
        for l in lin_op_matrices:
            blocks.append(
                linalg.LinearOperatorFullMatrix(l,
                                                is_square=False,
                                                is_self_adjoint=False,
                                                is_positive_definite=False))
        operator = block_diag.LinearOperatorBlockDiag(blocks)

        # Broadcast the shapes.
        expected_shape = list(shape_info.shape)

        matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

        block_diag_dense = _block_diag_dense(expected_shape, matrices)

        if not use_placeholder:
            block_diag_dense.set_shape(expected_shape)

        return operator, block_diag_dense
  def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
    shape = list(build_info.shape)
    expected_blocks = (
        build_info.__dict__["blocks"] if "blocks" in build_info.__dict__
        else [shape])
    diag_matrices = [
        linear_operator_test_util.random_uniform(
            shape=block_shape[:-1], minval=1., maxval=20., dtype=dtype)
        for block_shape in expected_blocks
    ]

    if use_placeholder:
      diag_matrices_ph = [
          array_ops.placeholder(dtype=dtype) for _ in expected_blocks
      ]
      diag_matrices = self.evaluate(diag_matrices)
      # Evaluate here because (i) you cannot feed a tensor, and (ii)
      # values are random and we want the same value used for both mat and
      # feed_dict.
      operator = block_diag.LinearOperatorBlockDiag(
          [linalg.LinearOperatorDiag(m_ph) for m_ph in diag_matrices_ph])
      feed_dict = {m_ph: m for (m_ph, m) in zip(
          diag_matrices_ph, diag_matrices)}
    else:
      operator = block_diag.LinearOperatorBlockDiag(
          [linalg.LinearOperatorDiag(m) for m in diag_matrices])
      feed_dict = None
      # Should be auto-set.
      self.assertTrue(operator.is_square)

    # Broadcast the shapes.
    expected_shape = list(build_info.shape)

    matrices = linear_operator_util.broadcast_matrix_batch_dims(
        [array_ops.matrix_diag(diag_block) for diag_block in diag_matrices])

    block_diag_dense = _block_diag_dense(expected_shape, matrices)
    if not use_placeholder:
      block_diag_dense.set_shape(
          expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

    return operator, block_diag_dense, feed_dict
Example #37
    def operator_and_matrix(self,
                            build_info,
                            dtype,
                            use_placeholder,
                            ensure_self_adjoint_and_pd=False):
        # Kronecker products constructed below will be from symmetric
        # positive-definite matrices.
        del ensure_self_adjoint_and_pd
        shape = list(build_info.shape)
        expected_factors = build_info.__dict__["factors"]
        matrices = [
            linear_operator_test_util.random_positive_definite_matrix(
                block_shape, dtype, force_well_conditioned=True)
            for block_shape in expected_factors
        ]

        lin_op_matrices = matrices

        if use_placeholder:
            lin_op_matrices = [
                array_ops.placeholder_with_default(m, shape=None)
                for m in matrices
            ]

        operator = kronecker.LinearOperatorKronecker([
            linalg.LinearOperatorFullMatrix(l,
                                            is_square=True,
                                            is_self_adjoint=True,
                                            is_positive_definite=True)
            for l in lin_op_matrices
        ])

        matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

        kronecker_dense = _kronecker_dense(matrices)

        if not use_placeholder:
            kronecker_dense.set_shape(shape)

        return operator, kronecker_dense
  def _to_dense(self):
    num_cols = 0
    dense_rows = []
    flat_broadcast_operators = linear_operator_util.broadcast_matrix_batch_dims(
        [op.to_dense() for row in self.operators for op in row])  # pylint: disable=g-complex-comprehension
    broadcast_operators = [
        flat_broadcast_operators[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
        for i in range(len(self.operators))]
    for row_blocks in broadcast_operators:
      batch_row_shape = array_ops.shape(row_blocks[0])[:-1]
      num_cols += array_ops.shape(row_blocks[-1])[-1]
      zeros_to_pad_after_shape = array_ops.concat(
          [batch_row_shape,
           [self.domain_dimension_tensor() - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=self.dtype)

      row_blocks.append(zeros_to_pad_after)
      dense_rows.append(array_ops.concat(row_blocks, axis=-1))

    mat = array_ops.concat(dense_rows, axis=-2)
    mat.set_shape(self.shape)
    return mat
  def _operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    # Kronecker products constructed below will be from symmetric
    # positive-definite matrices.
    del ensure_self_adjoint_and_pd
    shape = list(build_info.shape)
    expected_factors = build_info.__dict__["factors"]
    matrices = [
        linear_operator_test_util.random_positive_definite_matrix(
            block_shape, dtype, force_well_conditioned=True)
        for block_shape in expected_factors
    ]

    lin_op_matrices = matrices

    if use_placeholder:
      lin_op_matrices = [
          array_ops.placeholder_with_default(m, shape=None) for m in matrices]

    operator = kronecker.LinearOperatorKronecker(
        [linalg.LinearOperatorFullMatrix(
            l,
            is_square=True,
            is_self_adjoint=True,
            is_positive_definite=True)
         for l in lin_op_matrices])

    matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

    kronecker_dense = _kronecker_dense(matrices)

    if not use_placeholder:
      kronecker_dense.set_shape(shape)

    return operator, kronecker_dense
 def test_zero_batch_matrices_returned_as_empty_list(self):
   self.assertAllEqual([],
                       linear_operator_util.broadcast_matrix_batch_dims([]))
    def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
        """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated
        like [batch] matrices, meaning that for every set of leading dimensions,
        the last two dimensions define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
        if self.is_non_singular is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "be singular.")
        if self.is_square is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "not be square.")
        if isinstance(rhs, linear_operator.LinearOperator):
            left_operator = self.adjoint() if adjoint else self
            right_operator = rhs.adjoint() if adjoint_arg else rhs

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `rhs` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):
                return linear_operator_algebra.solve(left_operator,
                                                     right_operator)

        with self._name_scope(name):
            block_dimensions = (self._block_domain_dimensions()
                                if adjoint else self._block_range_dimensions())
            arg_dim = -1 if adjoint_arg else -2
            blockwise_arg = linear_operator_util.arg_is_blockwise(
                block_dimensions, rhs, arg_dim)

            if blockwise_arg:
                split_rhs = rhs
                for i, block in enumerate(split_rhs):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor_v2_with_dispatch(block)
                        self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            block.shape[arg_dim])
                        split_rhs[i] = block
            else:
                rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
                self._check_input_dtype(rhs)
                op_dimension = (self.domain_dimension
                                if adjoint else self.range_dimension)
                op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])
                split_dim = -1 if adjoint_arg else -2
                # Split input by rows normally, or by columns if adjoint_arg is True.
                split_rhs = linear_operator_util.split_arg_into_blocks(
                    self._block_domain_dimensions(),
                    self._block_domain_dimension_tensors,
                    rhs,
                    axis=split_dim)

            solution_list = []
            for index, operator in enumerate(self.operators):
                solution_list += [
                    operator.solve(split_rhs[index],
                                   adjoint=adjoint,
                                   adjoint_arg=adjoint_arg)
                ]

            if blockwise_arg:
                return solution_list

            solution_list = linear_operator_util.broadcast_matrix_batch_dims(
                solution_list)
            return array_ops.concat(solution_list, axis=-2)
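The blockwise `solve` above mirrors the blockwise `matmul`: for a block-diagonal operator, solving against `rhs` reduces to solving each block against the corresponding row-slice of `rhs`. A minimal NumPy check with small, well-conditioned blocks (illustrative only):

import numpy as np
from scipy.linalg import block_diag

a = 2.0 * np.eye(2)
b = 3.0 * np.eye(3)
rhs = np.random.rand(5, 4)

blockwise = np.concatenate(
    [np.linalg.solve(a, rhs[:2]), np.linalg.solve(b, rhs[2:])], axis=-2)
dense = np.linalg.solve(block_diag(a, b), rhs)
assert np.allclose(blockwise, dense)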
    def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
        """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Given the blockwise `n + 1`-by-`n + 1` linear operator:

    op = [[A_00     0  ...     0  ...    0],
          [A_10  A_11  ...     0  ...    0],
          ...
          [A_k0  A_k1  ...  A_kk  ...    0],
          ...
          [A_n0  A_n1  ...  A_nk  ... A_nn]]

    we find `x = op.solve(y)` by observing that

    `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

    and therefore

    `x_k = A_kk.solve(y_k -
                      A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

    where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
    and `y` along their appropriate axes.

    We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
    for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.

    The adjoint case is solved similarly, beginning with
    `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s. `Tensor`s are treated like [batch] matrices,
        meaning that for every set of leading dimensions, the last two
        dimensions define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
        if self.is_non_singular is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "be singular.")
        if self.is_square is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "not be square.")
        if isinstance(rhs, linear_operator.LinearOperator):
            left_operator = self.adjoint() if adjoint else self
            right_operator = rhs.adjoint() if adjoint_arg else rhs

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `rhs` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):  # pylint: disable=not-callable
                return linear_operator_algebra.solve(left_operator,
                                                     right_operator)

        with self._name_scope(name):  # pylint: disable=not-callable
            block_dimensions = (self._block_domain_dimensions()
                                if adjoint else self._block_range_dimensions())
            arg_dim = -1 if adjoint_arg else -2
            blockwise_arg = linear_operator_util.arg_is_blockwise(
                block_dimensions, rhs, arg_dim)
            if blockwise_arg:
                for i, block in enumerate(rhs):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor_v2_with_dispatch(block)
                        self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            block.shape[arg_dim])
                        rhs[i] = block
                if adjoint_arg:
                    split_rhs = [linalg.adjoint(y) for y in rhs]
                else:
                    split_rhs = rhs

            else:
                rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
                self._check_input_dtype(rhs)
                op_dimension = (self.domain_dimension
                                if adjoint else self.range_dimension)
                op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])

                rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
                split_rhs = linear_operator_util.split_arg_into_blocks(
                    self._block_domain_dimensions(),
                    self._block_domain_dimension_tensors,
                    rhs,
                    axis=-2)

            solution_list = []
            if adjoint:
                # For an adjoint blockwise lower-triangular linear operator, the system
                # must be solved bottom to top. Iterate backwards over rows of the
                # adjoint (i.e. columns of the non-adjoint operator).
                for index in reversed(range(len(self.operators))):
                    y = split_rhs[index]
                    # Iterate top to bottom over the operators in the off-diagonal portion
                    # of the column-partition (i.e. row-partition of the adjoint), apply
                    # the operator to the respective block of the solution found in
                    # previous iterations, and subtract the result from the `rhs` block.
                    # For example, let `A`, `B`, and `D` be the linear operators in the top
                    # row-partition of the adjoint of
                    # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
                    # and `x_1` and `x_2` be blocks of the solution found in previous
                    # iterations of the outer loop. The following loop (when `index == 0`)
                    # expresses
                    # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
                    # `y_0* = y_0 - Bx_1 - Dx_2`.
                    for j in reversed(range(index + 1, len(self.operators))):
                        y = y - self.operators[j][index].matmul(
                            solution_list[len(self.operators) - 1 - j],
                            adjoint=adjoint)
                    # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
                    solution_list.append(self._diagonal_operators[index].solve(
                        y, adjoint=adjoint))
                solution_list.reverse()
            else:
                # Iterate top to bottom over the row-partitions.
                for row, y in zip(self.operators, split_rhs):
                    # Iterate left to right over the operators in the off-diagonal portion
                    # of the row-partition, apply the operator to the block of the
                    # solution found in previous iterations, and subtract the result from
                    # the `rhs` block. For example, let `D`, `E`, and `F` be the linear
                    # operators in the bottom row-partition of
                    # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` and
                    # `x_0` and `x_1` be blocks of the solution found in previous
                    # iterations of the outer loop. The following loop
                    # (when `index == 2`) expresses
                    # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
                    # `y_2* = y_2 - Dx_0 - Ex_1`.
                    for i, operator in enumerate(row[:-1]):
                        y = y - operator.matmul(solution_list[i],
                                                adjoint=adjoint)
                    # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
                    solution_list.append(row[-1].solve(y, adjoint=adjoint))

            if blockwise_arg:
                return solution_list

            solution_list = linear_operator_util.broadcast_matrix_batch_dims(
                solution_list)
            return array_ops.concat(solution_list, axis=-2)
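Finally, the forward substitution described in the docstring can be checked with a minimal NumPy sketch for the 2x2 block case `[[A, 0], [B, C]]` (illustrative shapes, invertible diagonal blocks assumed): first `x_0 = A^{-1} y_0`, then `x_1 = C^{-1}(y_1 - B x_0)`.

import numpy as np

a = 2.0 * np.eye(2)
b = np.random.rand(3, 2)
c = 3.0 * np.eye(3)
y0, y1 = np.random.rand(2, 1), np.random.rand(3, 1)

x0 = np.linalg.solve(a, y0)
x1 = np.linalg.solve(c, y1 - b @ x0)

dense = np.block([[a, np.zeros((2, 3))], [b, c]])
assert np.allclose(dense @ np.concatenate([x0, x1]), np.concatenate([y0, y1]))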