def _matmul(self, x, adjoint=False, adjoint_arg=False):
  """Return the all-zeros product of this operator with `x`.

  Only the shape and dtype of `x` are used to build the result.

  Args:
    x: Argument of the product; its last two dimensions are the matrix dims.
    adjoint: If true, the operator's column count (instead of its row count)
      sets the number of output rows in the non-square case.  It has no
      effect in the square case.
    adjoint_arg: If true, the product is taken with the adjoint of `x`, so
      the last two dimensions of `x` are swapped when computing the output
      shape.

  Returns:
    A zeros tensor of dtype `x.dtype`, passed through
    `self._possibly_broadcast_batch_shape`.
  """
  if self._assert_proper_shapes:
    # Validate the *effective* argument (x^H when adjoint_arg is set) but keep
    # `x` itself in its original orientation.  Rebinding `x` to its adjoint
    # here would make the `adjoint_arg` shape handling below swap the matrix
    # dimensions a second time and produce a transposed output shape.
    effective_x = linalg.adjoint(x) if adjoint_arg else x
    aps = linear_operator_util.assert_compatible_matrix_dimensions(
        self, effective_x)
    x = distribution_util.with_dependencies([aps], x)

  x_shape = prefer_static.shape(x)

  if self.is_square:
    # Note that adjoint has no effect since this matrix is self-adjoint.
    if adjoint_arg:
      output_shape = prefer_static.concat(
          [x_shape[:-2], [x_shape[-1], x_shape[-2]]], axis=0)
    else:
      output_shape = x_shape
    return self._possibly_broadcast_batch_shape(
        array_ops.zeros(shape=output_shape, dtype=x.dtype))

  # Non-square case: rows come from the operator, columns from the argument.
  n = self._num_columns if adjoint else self._num_rows
  m = x_shape[-2] if adjoint_arg else x_shape[-1]
  output_shape = prefer_static.concat([x_shape[:-2], [n, m]], axis=0)
  zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
  return self._possibly_broadcast_batch_shape(zeros)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  """Return `x` (or its adjoint) unchanged apart from batch broadcasting.

  Args:
    x: Argument of the product.
    adjoint: Never read; the original notes the operator is self-adjoint.
    adjoint_arg: If true, the adjoint of `x` is used in place of `x`.

  Returns:
    The (possibly adjointed) argument, passed through
    `self._possibly_broadcast_batch_shape`.
  """
  if adjoint_arg:
    x = linalg.adjoint(x)

  if not self._assert_proper_shapes:
    return self._possibly_broadcast_batch_shape(x)

  # Tie the dimension-compatibility assertion to the returned value.
  shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
      self, x)
  guarded_x = control_flow_ops.with_dependencies([shape_check], x)
  return self._possibly_broadcast_batch_shape(guarded_x)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  """Scale `x` (or its adjoint) elementwise by the multiplier matrix.

  Args:
    x: Argument of the product.
    adjoint: Forwarded as `conjugate=` to `self._make_multiplier_matrix`.
    adjoint_arg: If true, the adjoint of `x` is used in place of `x`.

  Returns:
    The argument multiplied by `self._make_multiplier_matrix(...)`.
  """
  if adjoint_arg:
    x = linalg.adjoint(x)

  if self._assert_proper_shapes:
    # Tie the dimension-compatibility assertion to the value being scaled.
    shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
        self, x)
    x = distribution_util.with_dependencies([shape_check], x)

  multiplier = self._make_multiplier_matrix(conjugate=adjoint)
  return x * multiplier
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  """Divide `rhs` (or its adjoint) elementwise by the multiplier matrix.

  Args:
    rhs: Right-hand side of the system.
    adjoint: Forwarded as `conjugate=` to `self._make_multiplier_matrix`.
    adjoint_arg: If true, the adjoint of `rhs` is used in place of `rhs`.

  Returns:
    The right-hand side divided by `self._make_multiplier_matrix(...)`.
  """
  if adjoint_arg:
    rhs = linalg.adjoint(rhs)

  if self._assert_proper_shapes:
    # Tie the dimension-compatibility assertion to the value being divided.
    shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
        self, rhs)
    rhs = distribution_util.with_dependencies([shape_check], rhs)

  multiplier = self._make_multiplier_matrix(conjugate=adjoint)
  return rhs / multiplier
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  """Multiply by the Toeplitz matrix through a circulant embedding.

  A Toeplitz matrix with first row [t_0, t_1, ..., t_{n-1}] and first column
  [t_0, t_{-1}, ..., t_{-(n-1)}] is embedded in the circulant matrix C whose
  first column is
    [t_0, t_{-1}, ..., t_{-(n-1)}, 0, t_{n-1}, ..., t_1].
  Padding the input with `n` zero rows gives a length-`2n` argument `y`; the
  first `n` rows of `C y` equal the Toeplitz product.  See
  http://math.mit.edu/icg/resources/teaching/18.085-spring2015/toeplitz.pdf

  Args:
    x: Argument of the product.
    adjoint: Forwarded to the circulant operator's `matmul`.
    adjoint_arg: If true, the adjoint of `x` is used in place of `x`.

  Returns:
    The leading rows of the circulant product, cast to `self.dtype`.
  """
  if adjoint_arg:
    x = linalg.adjoint(x)

  # Append a block of zeros below x along the row axis.
  padded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)

  first_col = ops.convert_to_tensor(self.col)
  first_row = ops.convert_to_tensor(self.row)

  # First column of the circulant: col, one zero, then row[1:] reversed.
  embedding_col = array_ops.concat(
      [
          first_col,
          array_ops.zeros_like(first_col[..., 0:1]),
          array_ops.reverse(first_row[..., 1:], axis=[-1]),
      ],
      axis=-1)
  embedding = linear_operator_circulant.LinearOperatorCirculant(
      fft_ops.fft(_to_complex(embedding_col)),
      input_output_dtype=first_row.dtype)

  product = embedding.matmul(padded_x, adjoint=adjoint, adjoint_arg=False)

  # Keep only the leading rows, as many as this operator's domain dimension.
  op_shape = self._shape_tensor(row=first_row, col=first_col)
  num_kept_rows = self._domain_dimension_tensor(shape=op_shape)
  return _ops.cast(product[..., :num_kept_rows, :], self.dtype)
def _assert_self_adjoint(self):
  """Default self-adjointness check via a dense comparison.

  Converts the operator to a dense matrix, logs a warning that this default
  path may be slow, and asserts the dense matrix equals its adjoint.

  Returns:
    The result of `check_ops.assert_equal` comparing the dense matrix with
    its adjoint.
  """
  dense = self.to_dense()
  # `warning` rather than the deprecated `warn` alias (the alias was removed
  # from the standard `logging` module in Python 3.13).
  logging.warning(
      "Using (possibly slow) default implementation of assert_self_adjoint."
      "  Requires conversion to a dense matrix.")
  return check_ops.assert_equal(
      dense,
      linalg.adjoint(dense),
      message="Matrix was not equal to its adjoint.")
def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
  """Solve the system by materializing the operator as a dense matrix.

  Args:
    rhs: Right-hand side of the system.
    adjoint: Forwarded to `matrix_solve_with_broadcast`; not read on the
      Cholesky path.
    adjoint_arg: If true, the adjoint of `rhs` is used in place of `rhs`.

  Returns:
    The solution from a Cholesky solve when `self._can_use_cholesky()` is
    true, otherwise from a general dense solve.

  Raises:
    NotImplementedError: If `self.is_square` is explicitly `False`.
  """
  # `is False` is deliberate: a value of `None` must not raise.
  if self.is_square is False:  # pylint: disable=g-bool-id-comparison
    raise NotImplementedError(
        "Solve is not yet implemented for non-square operators.")

  if adjoint_arg:
    rhs = linalg.adjoint(rhs)

  if not self._can_use_cholesky():
    return linear_operator_util.matrix_solve_with_broadcast(
        self.to_dense(), rhs, adjoint=adjoint)

  chol_factor = linalg_ops.cholesky(self.to_dense())
  return linalg_ops.cholesky_solve(chol_factor, rhs)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  """Solve by dividing the transformed right-hand side by the spectrum.

  Args:
    rhs: Right-hand side of the system.
    adjoint: If true, `self._conj_spectrum` is used instead of
      `self._spectrum_complex`.
    adjoint_arg: If true, the adjoint of `rhs` is used in place of `rhs`.

  Returns:
    The solution, cast to `self.dtype`.
  """
  if adjoint_arg:
    rhs = linalg.adjoint(rhs)

  if adjoint:
    spectrum = self._conj_spectrum
  else:
    spectrum = self._spectrum_complex

  rhs, spectrum = self._broadcast_batch_dims(rhs, spectrum)

  # Forward transform, divide by the spectrum, transform back.
  rhs_blocks = self._vectorize_then_blockify(rhs)
  transformed_rhs = self._fft(rhs_blocks)
  solution_blocks = self._ifft(transformed_rhs / spectrum)

  solution = self._unblockify_then_matricize(solution_blocks)
  return _ops.cast(solution, self.dtype)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  """Default implementation of _solve.

  Falls back to a dense solve: the operator is converted with `to_dense()`
  and solved with Cholesky when `self._can_use_cholesky()` is true, otherwise
  with a general matrix solve.

  Args:
    rhs: Right-hand side of the system.
    adjoint: Forwarded to `matrix_solve_with_broadcast`; not read on the
      Cholesky path.
    adjoint_arg: If true, the adjoint of `rhs` is used in place of `rhs`.

  Returns:
    The solution of the dense system.

  Raises:
    NotImplementedError: If `self.is_square` is explicitly `False`.
  """
  # `is False` is deliberate: a value of `None` must not raise.
  if self.is_square is False:
    raise NotImplementedError(
        "Solve is not yet implemented for non-square operators.")
  # `warning` rather than the deprecated `warn` alias (the alias was removed
  # from the standard `logging` module in Python 3.13).
  logging.warning(
      "Using (possibly slow) default implementation of solve."
      "  Requires conversion to a dense matrix and O(N^3) operations.")
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  if self._can_use_cholesky():
    return linear_operator_util.cholesky_solve_with_broadcast(
        linalg_ops.cholesky(self.to_dense()), rhs)
  return linear_operator_util.matrix_solve_with_broadcast(
      self.to_dense(), rhs, adjoint=adjoint)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  """Multiply by the Toeplitz matrix using the stored circulant operator.

  A Toeplitz matrix with first row [t_0, t_1, ..., t_{n-1}] and first column
  [t_0, t_{-1}, ..., t_{-(n-1)}] embeds in the circulant matrix C whose first
  column is
    [t_0, t_{-1}, ..., t_{-(n-1)}, 0, t_{n-1}, ..., t_1].
  Padding the input with `n` zero rows gives a length-`2n` argument `y`; the
  first `n` rows of `C y` equal the Toeplitz product.  See
  http://math.mit.edu/icg/resources/teaching/18.085-spring2015/toeplitz.pdf

  Args:
    x: Argument of the product.
    adjoint: Forwarded to `self._circulant.matmul`.
    adjoint_arg: If true, the adjoint of `x` is used in place of `x`.

  Returns:
    The leading rows of the circulant product, cast to `self.dtype`.
  """
  if adjoint_arg:
    x = linalg.adjoint(x)

  # Append a block of zeros below x along the row axis.
  zero_block = array_ops.zeros_like(x)
  padded_x = array_ops.concat([x, zero_block], axis=-2)

  product = self._circulant.matmul(
      padded_x, adjoint=adjoint, adjoint_arg=False)

  # Keep only the leading rows, as many as this operator's domain dimension.
  num_kept_rows = self.domain_dimension_tensor()
  return _ops.cast(product[..., :num_kept_rows, :], self.dtype)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  """Reflect `x` about the hyperplane orthogonal to the reflection axis.

  With `v` the reflection axis, the projection of `x` onto `v` is
  v * dot(v, x) / dot(v, v).  Reflection flips the sign of that projection
  and leaves the orthogonal remainder untouched, which gives
    x - 2 * v * dot(v, x) / dot(v, v).
  Using the l2-normalized axis makes the denominator one.

  Args:
    x: Argument of the product.
    adjoint: Never read; a reflection is its own adjoint.
    adjoint_arg: If true, the adjoint of `x` is used in place of `x`.

  Returns:
    The reflected argument.
  """
  axis = ops.convert_to_tensor(self.reflection_axis)
  if adjoint_arg:
    x = linalg.adjoint(x)

  unit_axis = nn.l2_normalize(axis, axis=-1)
  # Append a trailing axis so the unit axis can act as a column in matmul.
  unit_column = unit_axis[..., _ops.newaxis]

  # unit_column^H x: the coefficient of x along the normalized axis.
  axis_coefficient = _linalg.matmul(unit_column, x, adjoint_a=True)
  return x - 2 * unit_column * axis_coefficient
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  """Multiply by the circulant operator in the Fourier domain.

  With F the DFT matrix, F^{-1} = F^{H} is the IDFT matrix, so
    matmul(x)               = F^{H} diag(spectrum) F x,
    matmul(x, adjoint=True) = F^{H} diag(conj(spectrum)) F x.

  Args:
    x: Argument of the product.
    adjoint: If true, `self._conj_spectrum` is used instead of
      `self._spectrum_complex`.
    adjoint_arg: If true, the adjoint of `x` is used in place of `x`.

  Returns:
    The product, cast to `self.dtype`.
  """
  if adjoint_arg:
    x = linalg.adjoint(x)

  if adjoint:
    spectrum = self._conj_spectrum
  else:
    spectrum = self._spectrum_complex

  # Match the argument's dtype to the spectrum before broadcasting.
  x = _ops.cast(x, spectrum.dtype)
  x, spectrum = self._broadcast_batch_dims(x, spectrum)

  # Forward transform, scale by the spectrum, transform back.
  x_blocks = self._vectorize_then_blockify(x)
  transformed_x = self._fft(x_blocks)
  product_blocks = self._ifft(spectrum * transformed_x)

  product = self._unblockify_then_matricize(product_blocks)
  return _ops.cast(product, self.dtype)
def _matmul(
    a,
    b,
    transpose_a=False,
    transpose_b=False,
    adjoint_a=False,
    adjoint_b=False,
    a_is_sparse=False,
    b_is_sparse=False,
    name=None):
  """Dispatch a two-argument matmul to a `LinearOperator.matmul` call.

  When `a` is a `LinearOperator`, its `matmul` is called directly.  Otherwise
  `b.matmul` is called with `a` as the argument, using the identity
  (B^H A^H)^H = A B, and the adjoint of that result is returned.

  Args:
    a: Left factor.
    b: Right factor; its `matmul` method is used when `a` is not a
      `LinearOperator`.
    transpose_a: Must be falsy.
    transpose_b: Must be falsy.
    adjoint_a: Whether to use the adjoint of `a`.
    adjoint_b: Whether to use the adjoint of `b`.
    a_is_sparse: Must be falsy.
    b_is_sparse: Must be falsy.
    name: Forwarded to the operator's `matmul`.

  Returns:
    The product computed by the operator's `matmul`.

  Raises:
    ValueError: If any transpose or sparse flag is truthy.
  """
  if transpose_a or transpose_b:
    raise ValueError("Transposing not supported at this time.")
  if a_is_sparse or b_is_sparse:
    raise ValueError("Sparse methods not supported at this time.")

  if isinstance(a, LinearOperator):
    return a.matmul(b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)

  # We use the identity (B^H A^H)^H = A B: compute B^H A^H through the
  # operator `b`, then take the adjoint of the result.
  flipped_product = b.matmul(
      a, adjoint=not adjoint_b, adjoint_arg=not adjoint_a, name=name)
  return linalg.adjoint(flipped_product)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  """Solve against the Kronecker product one factor at a time.

  Uses Roth's column lemma in the same way as `matmul`, with every factor
  `matmul` replaced by `solve`; this relies on inv(A x B) = inv(A) x inv(B).
  Shape comments describe the adjoint=False, adjoint_arg=False case, but
  other adjoint combinations are handled as well.

  Args:
    rhs: Right-hand side of the system.
    adjoint: Forwarded to each factor's `solve` / `solvevec`.
    adjoint_arg: If true, the adjoint of `rhs` is used in place of `rhs`.

  Returns:
    The solution; its static shape is set when the (broadcast) shape of
    `rhs` is fully defined.
  """
  if adjoint_arg:
    rhs = linalg.adjoint(rhs)

  # Always add a batch dimension so that broadcasting works.
  ones_matrix_shape = self._compute_ones_matrix_shape()
  rhs = rhs + array_ops.zeros(ones_matrix_shape, dtype=rhs.dtype)

  # rhs is [B, R, C]: B batch dims, R rows, C columns.  The lemma acts on
  # column vectors, so the column axis is rotated to the front, where it does
  # not interfere with broadcasting against the operators' batch dims B.
  acc = _rotate_last_dim(rhs, rotate_right=True)

  # Prepend an axis, giving [A, C, B, R] with A = 1.  A accumulates the
  # dimension produced by each factor and ends up as self.range_dimension;
  # the last axis V starts at R and ends up as 1.
  acc = acc[_ops.newaxis, ...]

  for factor in self.operators[:-1]:
    # [A, C, B, V] -> [A, C, B, V / op.domain_dimension, op.domain_dimension]
    factor_dim = (
        factor.range_dimension_tensor() if adjoint
        else factor.domain_dimension_tensor())
    acc = _unvec_by(acc, factor_dim)

    # Compute (X A^-1^T) as (A^-1 X^T)^T, turning the last axis from
    # op.domain_dimension into op.range_dimension.
    acc = _linalg.matrix_transpose(acc)
    acc = factor.solve(acc, adjoint=adjoint, adjoint_arg=False)
    acc = _linalg.matrix_transpose(acc)

    # Rearrange to [A * op.range_dimension, C, B, V / op.domain_dimension].
    acc = _rotate_last_dim(acc, rotate_right=False)
    acc = _vec(acc)
    acc = _rotate_last_dim(acc, rotate_right=True)

  # After the loop A = self.range_dimension / op[-1].range_dimension and
  # V = op[-1].domain_dimension.  The last factor is applied with solvevec,
  # giving [A, C, B, op[-1].range_dimension].
  acc = self.operators[-1].solvevec(acc, adjoint=adjoint)

  # Rearrange to [B1, ..., Bn, self.range_dimension, C].
  acc = _rotate_last_dim(acc, rotate_right=False)
  acc = _vec(acc)
  acc = _rotate_last_dim(acc, rotate_right=False)

  static_rhs_shape = tensor_shape.TensorShape(rhs.shape)
  if static_rhs_shape.is_fully_defined():
    column_dim = static_rhs_shape[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        static_rhs_shape[:-2], self.batch_shape)
    row_dim = self.domain_dimension if adjoint else self.range_dimension
    tensorshape_util.set_shape(
        acc, broadcast_batch_shape.concatenate([row_dim, column_dim]))

  return acc
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  """Multiply by the Kronecker product one factor at a time.

  Relies on Roth's column lemma [1]:
    (A x B) * vec X = vec B X A^T,
  where vec stacks the columns of a matrix under each other.  Here `x` is
  treated as a batch of vec X; each member is reshaped into a matrix, and the
  lemma is applied again whenever B is itself a Kronecker product.

  Efficiency: forming the dense Kronecker product can take a large amount of
  memory, and applying it to an [N, M] input costs O(N^2 M).  With factors of
  size (n1, n1), ..., (nJ, nJ) and N = prod n_i, multiplying by the factors
  only can cost O(M * (sum n_i) * N), ignoring reshapes and transposes, and
  never realizes the dense matrix.  The repeated transposes are not counted
  in that estimate and can affect cache behavior.

  Shape comments describe the adjoint=False, adjoint_arg=False case, but
  other adjoint combinations are handled as well.

  [1] W. E. Roth, "On direct product matrices," Bulletin of the American
      Mathematical Society, vol. 40, pp. 461-468, 1934.

  Args:
    x: Argument of the product.
    adjoint: Forwarded to each factor's `matmul` / `matvec`.
    adjoint_arg: If true, the adjoint of `x` is used in place of `x`.

  Returns:
    The product; its static shape is set when the (broadcast) shape of `x`
    is fully defined.
  """
  if adjoint_arg:
    x = linalg.adjoint(x)

  # Always add a batch dimension so that broadcasting works.
  ones_matrix_shape = self._compute_ones_matrix_shape()
  x = x + array_ops.zeros(ones_matrix_shape, dtype=x.dtype)

  # x is [B, R, C]: B batch dims, R rows, C columns.  The lemma acts on
  # column vectors, so the column axis is rotated to the front, where it does
  # not interfere with broadcasting against the operators' batch dims B.
  acc = _rotate_last_dim(x, rotate_right=True)

  # Prepend an axis, giving [A, C, B, R] with A = 1.  A accumulates the
  # dimension produced by each factor and ends up as self.range_dimension;
  # the last axis V starts at R and ends up as 1.
  acc = acc[_ops.newaxis, ...]

  for factor in self.operators[:-1]:
    # [A, C, B, V] -> [A, C, B, V / op.domain_dimension, op.domain_dimension]
    factor_dim = (
        factor.range_dimension_tensor() if adjoint
        else factor.domain_dimension_tensor())
    acc = _unvec_by(acc, factor_dim)

    # Compute (X A^T) as (A X^T)^T, turning the last axis from
    # op.domain_dimension into op.range_dimension.
    acc = _linalg.matrix_transpose(acc)
    acc = factor.matmul(acc, adjoint=adjoint, adjoint_arg=False)
    acc = _linalg.matrix_transpose(acc)

    # Rearrange to [A * op.range_dimension, C, B, V / op.domain_dimension].
    acc = _rotate_last_dim(acc, rotate_right=False)
    acc = _vec(acc)
    acc = _rotate_last_dim(acc, rotate_right=True)

  # After the loop A = self.range_dimension / op[-1].range_dimension and
  # V = op[-1].domain_dimension.  The last factor is applied with matvec,
  # giving [A, C, B, op[-1].range_dimension].
  acc = self.operators[-1].matvec(acc, adjoint=adjoint)

  # Rearrange to [B1, ..., Bn, self.range_dimension, C].
  acc = _rotate_last_dim(acc, rotate_right=False)
  acc = _vec(acc)
  acc = _rotate_last_dim(acc, rotate_right=False)

  static_x_shape = tensor_shape.TensorShape(x.shape)
  if static_x_shape.is_fully_defined():
    column_dim = static_x_shape[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        static_x_shape[:-2], self.batch_shape)
    row_dim = self.domain_dimension if adjoint else self.range_dimension
    tensorshape_util.set_shape(
        acc, broadcast_batch_shape.concatenate([row_dim, column_dim]))

  return acc
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve.

  When `b` has more leading dimensions than `a`, those extra dimensions may
  be moved to the end of `b` and folded into its last axis.  The returned
  inverse map undoes that folding on a result `y`.

  Args:
    a: Left matrix argument; never reshaped, but may be adjointed/transposed.
    b: Right matrix argument; may be reshaped.
    transpose_a: Whether `a` is to be transposed.
    transpose_b: Whether `b` is to be transposed.
    adjoint_a: Whether `a` is to be adjointed (takes precedence over
      `transpose_a`).
    adjoint_b: Whether `b` is to be adjointed (takes precedence over
      `transpose_b`).

  Returns:
    A tuple `(a, b, reshape_inv, still_need_to_transpose)`.  If no reshape
    happened, `a` and `b` are the inputs, `reshape_inv` is the identity and
    `still_need_to_transpose` is `True`.  Otherwise the adjoint/transpose has
    already been applied, `b` is the reshaped tensor and
    `still_need_to_transpose` is `False`.
  """

  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  a_ndims = _ops.TensorShape(a.shape).ndims
  b_ndims = _ops.TensorShape(b.shape).ndims

  if a_ndims is None or b_ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a_ndims >= b_ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  # Then S holds the "extra" leading dims of b.
  b_extra_ndims = b_ndims - a_ndims

  # No reason to flip unless the extra dims of b are big enough.  Why?
  # Assume adjoint/transpose = False.  Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory.  But in all cases, the final output has
  # shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a.  This often happens, e.g. if b was a vector that
  # was expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = _ops.TensorShape(
      a.shape)[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = _ops.TensorShape(
      b.shape)[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(_ops.TensorShape(b.shape)[:b_extra_ndims].as_list())
      if _ops.TensorShape(b.shape)[:b_extra_ndims].is_fully_defined()
      else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in
  # calling code.  Why?  To avoid having to write separate complex code for
  # each case.
  if adjoint_a:
    a = linalg.adjoint(a)
  elif transpose_a:
    a = linalg.transpose(a)
  if adjoint_b:
    b = linalg.adjoint(b)
  elif transpose_b:
    b = linalg.transpose(b)
  still_need_to_transpose = False

  # Dynamic shapes are taken only here, after the transpose/adjoint, since
  # those may have changed the trailing dims of b.
  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = np.concatenate(
      (np.arange(b_extra_ndims, b_ndims), np.arange(0, b_extra_ndims)), 0)
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because
    # y could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  """Multiply by the diagonal operator through a broadcast product.

  Args:
    x: Argument of the product.
    adjoint: If true, the conjugate of `self._diag` is used.
    adjoint_arg: If true, the adjoint of `x` is used in place of `x`.

  Returns:
    The diagonal, with a trailing axis appended, times the argument.
  """
  if adjoint:
    diag = math_ops.conj(self._diag)
  else:
    diag = self._diag

  if adjoint_arg:
    x = linalg.adjoint(x)

  # A trailing axis lets the diagonal broadcast across the last axis of x.
  diag_column = array_ops.expand_dims(diag, -1)
  return diag_column * x
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  """Solve against the diagonal operator by scaling with its reciprocal.

  Args:
    rhs: Right-hand side of the system.
    adjoint: If true, the conjugate of `self._diag` is used.
    adjoint_arg: If true, the adjoint of `rhs` is used in place of `rhs`.

  Returns:
    The right-hand side times the reciprocal diagonal (with a trailing axis
    appended).
  """
  if adjoint:
    diag = math_ops.conj(self._diag)
  else:
    diag = self._diag

  if adjoint_arg:
    rhs = linalg.adjoint(rhs)

  # A trailing axis lets the reciprocal broadcast across the last axis of rhs.
  reciprocal_column = array_ops.expand_dims(1. / diag, -1)
  return rhs * reciprocal_column
def _to_dense(self):
  """Return the dense adjoint of the wrapped operator.

  When `self.is_self_adjoint` is truthy the wrapped operator's dense matrix
  is returned as is; otherwise (including when the hint is `None`) its
  adjoint is taken.
  """
  skip_adjoint = bool(self.is_self_adjoint)
  dense = self.operator.to_dense()
  return dense if skip_adjoint else linalg.adjoint(dense)
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
  """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

  The returned `Tensor` will be close to an exact solution if `A` is well
  conditioned. Otherwise closeness will vary. See class docstring for details.

  Given the blockwise `n + 1`-by-`n + 1` linear operator:

  op = [[A_00     0  ...     0  ...    0],
        [A_10  A_11  ...     0  ...    0],
        ...
        [A_k0  A_k1  ...  A_kk  ...    0],
        ...
        [A_n0  A_n1  ...  A_nk  ... A_nn]]

  we find `x = op.solve(y)` by observing that

  `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

  and therefore

  `x_k = A_kk.solve(y_k -
                    A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

  where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
  and `y` along their appropriate axes.

  We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
  for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.

  The adjoint case is solved similarly, beginning with
  `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.

  Examples:

  ```python
  # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
  operator = LinearOperator(...)
  operator.shape = [..., M, N]

  # Solve R > 0 linear systems for every member of the batch.
  RHS = ... # shape [..., M, R]

  X = operator.solve(RHS)
  # X[..., :, r] is the solution to the r'th linear system
  # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

  operator.matmul(X)
  ==> RHS
  ```

  Args:
    rhs: `Tensor` with same `dtype` as this operator and compatible shape,
      or a list of `Tensor`s. `Tensor`s are treated like a [batch] matrices
      meaning for every set of leading dimensions, the last two dimensions
      defines a matrix.  A `LinearOperator` is also accepted, in which case
      the solve is delegated to `linear_operator_algebra.solve`.
      See class docstring for definition of compatibility.
    adjoint: Python `bool`. If `True`, solve the system involving the adjoint
      of this `LinearOperator`:  `A^H X = rhs`.
    adjoint_arg:  Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H`
      is the hermitian transpose (transposition and complex conjugation).
    name:  A name scope to use for ops added by this method.

  Returns:
    `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.  When `rhs`
    is given blockwise (as a list), the list of solution blocks is returned
    instead.

  Raises:
    NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    ValueError:  If `rhs` is a `LinearOperator` whose range dimension is
      statically known to differ from this operator's domain dimension.
  """
  # `is False` is deliberate: a hint of `None` must not raise.
  if self.is_non_singular is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "be singular.")
  if self.is_square is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "not be square.")

  if isinstance(rhs, linear_operator.LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = rhs.adjoint() if adjoint_arg else rhs

    if (right_operator.range_dimension is not None and
        left_operator.domain_dimension is not None and
        right_operator.range_dimension != left_operator.domain_dimension):
      raise ValueError(
          "Operators are incompatible. Expected `rhs` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension,
              right_operator.range_dimension))
    with self._name_scope(name):
      return linear_operator_algebra.solve(left_operator, right_operator)

  with self._name_scope(name):
    block_dimensions = (self._block_domain_dimensions() if adjoint
                        else self._block_range_dimensions())
    arg_dim = -1 if adjoint_arg else -2
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, rhs, arg_dim)

    if blockwise_arg:
      # Work on a private copy so the caller's sequence is never mutated (and
      # so an immutable sequence such as a tuple is accepted).
      rhs = list(rhs)
      for i, block in enumerate(rhs):
        if not isinstance(block, linear_operator.LinearOperator):
          block = ops.convert_to_tensor(block)
          # NOTE: the dtype check (`self._check_input_dtype`) is disabled.
          block_dimensions[i].assert_is_compatible_with(
              tensor_shape.TensorShape(block.shape)[arg_dim])
          rhs[i] = block
      if adjoint_arg:
        split_rhs = [linalg.adjoint(y) for y in rhs]
      else:
        split_rhs = rhs

    else:
      rhs = ops.convert_to_tensor(rhs, name="rhs")
      # NOTE: the dtype check (`self._check_input_dtype`) is disabled.
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(
          tensor_shape.TensorShape(rhs.shape)[arg_dim])

      rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
      split_rhs = linear_operator_util.split_arg_into_blocks(
          self._block_domain_dimensions(),
          self._block_domain_dimension_tensors,
          rhs, axis=-2)

    solution_list = []
    if adjoint:
      # For an adjoint blockwise lower-triangular linear operator, the system
      # must be solved bottom to top. Iterate backwards over rows of the
      # adjoint (i.e. columns of the non-adjoint operator).
      for index in reversed(range(len(self.operators))):
        y = split_rhs[index]
        # Iterate top to bottom over the operators in the off-diagonal portion
        # of the column-partition (i.e. row-partition of the adjoint), apply
        # the operator to the respective block of the solution found in
        # previous iterations, and subtract the result from the `rhs` block.
        # For example, let `A`, `B`, and `D` be the linear operators in the
        # top row-partition of the adjoint of
        # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
        # and `x_1` and `x_2` be blocks of the solution found in previous
        # iterations of the outer loop. The following loop (when
        # `index == 0`) expresses
        # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
        # `y_0* = y_0 - Bx_1 - Dx_2`.
        for j in reversed(range(index + 1, len(self.operators))):
          # Out-of-place subtraction: `y -= ...` could write into the
          # caller's buffer when blocks are views of a mutable array.
          y = y - self.operators[j][index].matmul(
              solution_list[len(self.operators) - 1 - j],
              adjoint=adjoint)
        # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
        solution_list.append(
            self._diagonal_operators[index].solve(y, adjoint=adjoint))
      solution_list.reverse()
    else:
      # Iterate top to bottom over the row-partitions.
      for row, y in zip(self.operators, split_rhs):
        # Iterate left to right over the operators in the off-diagonal portion
        # of the row-partition, apply the operator to the block of the
        # solution found in previous iterations, and subtract the result from
        # the `rhs` block. For example, let `D`, `E`, and `F` be the linear
        # operators in the bottom row-partition of
        # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` and
        # `x_0` and `x_1` be blocks of the solution found in previous
        # iterations of the outer loop. The following loop (for the bottom
        # row) expresses
        # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
        # `y_2* = y_2 - Dx_0 - Ex_1`.
        for i, operator in enumerate(row[:-1]):
          # Out-of-place subtraction, for the same reason as above.
          y = y - operator.matmul(solution_list[i], adjoint=adjoint)
        # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
        solution_list.append(row[-1].solve(y, adjoint=adjoint))

    if blockwise_arg:
      return solution_list

    solution_list = linear_operator_util.broadcast_matrix_batch_dims(
        solution_list)
    return array_ops.concat(solution_list, axis=-2)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  """Solve against the lower-triangular matrix with a triangular solve.

  Args:
    rhs: Right-hand side of the system.
    adjoint: Forwarded to `linalg.triangular_solve`.
    adjoint_arg: If true, the adjoint of `rhs` is used in place of `rhs`.

  Returns:
    The result of `linalg.triangular_solve` on `self._get_tril()` with
    `lower=True`.
  """
  if adjoint_arg:
    rhs = linalg.adjoint(rhs)
  lower_triangular = self._get_tril()
  return linalg.triangular_solve(
      lower_triangular, rhs, lower=True, adjoint=adjoint)