def _check_shapes(self):
  """Static check that shapes are compatible."""
  # Broadcast shape also checks that u and v are compatible.
  uv_shape = _ops.broadcast_static_shape(
      tensor_shape.TensorShape(self.u.shape),
      tensor_shape.TensorShape(self.v.shape))

  batch_shape = _ops.broadcast_static_shape(
      self.base_operator.batch_shape, uv_shape[:-2])

  tensor_shape.Dimension(
      self.base_operator.domain_dimension).assert_is_compatible_with(
          uv_shape[-2])

  if self._diag_update is not None:
    tensor_shape.dimension_at_index(uv_shape, -1).assert_is_compatible_with(
        tensor_shape.TensorShape(self._diag_update.shape)[-1])
    _ops.broadcast_static_shape(
        batch_shape,
        tensor_shape.TensorShape(self._diag_update.shape)[:-1])
def batch_shape(self):
  """`TensorShape` of batch dimensions of this `LinearOperator`.

  If this operator acts like the batch matrix `A` with
  `tensor_shape.TensorShape(A.shape) = [B1,...,Bb, M, N]`, then this returns
  `TensorShape([B1,...,Bb])`, equivalent to
  `tensor_shape.TensorShape(A.shape)[:-2]`.

  Returns:
    `TensorShape`, statically determined, may be undefined.
  """
  # Derived classes get this "for free" once .shape is implemented.
  return tensor_shape.TensorShape(self.shape)[:-2]
def _shape_tensor(self):
  # Avoid messy broadcasting if possible.
  if tensor_shape.TensorShape(self.shape).is_fully_defined():
    return ops.convert_to_tensor(
        tensor_shape.TensorShape(self.shape).as_list(),
        dtype=dtypes.int32,
        name="shape")

  # Don't check the matrix dimensions.  That would add unnecessary Asserts to
  # the graph.  Things will fail at runtime naturally if shapes are
  # incompatible.
  matrix_shape = array_ops.stack([
      self.operators[0].range_dimension_tensor(),
      self.operators[-1].domain_dimension_tensor()
  ])

  # Dummy Tensor of zeros.  Will never be materialized.
  zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
  for operator in self.operators[1:]:
    zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
  batch_shape = array_ops.shape(zeros)

  return array_ops.concat((batch_shape, matrix_shape), 0)
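# Illustrative sketch (not part of the operator implementation): the "dummy
# zeros" trick above computes a broadcast batch shape without explicit shape
# arithmetic — summing zeros of each batch shape yields a tensor whose shape
# is the broadcast shape.  A minimal NumPy analogue, assuming NumPy
# broadcasting matches the backend's; `_sketch_batch_shape_via_zeros` is a
# hypothetical name.
def _sketch_batch_shape_via_zeros():
  import numpy as np
  batch_shapes = [(2, 1), (3,), (1, 1)]
  zeros = np.zeros(batch_shapes[0])
  for s in batch_shapes[1:]:
    zeros = zeros + np.zeros(s)  # broadcasting does the shape math
  assert zeros.shape == (2, 3)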
def _shape(self):
  # Get final matrix shape.
  domain_dimension = sum(self._block_domain_dimensions())
  range_dimension = sum(self._block_range_dimensions())
  # Matrix shape is [M, N] == [range_dimension, domain_dimension], matching
  # the convention used elsewhere (range dimension first).
  matrix_shape = tensor_shape.TensorShape(
      [range_dimension, domain_dimension])

  # Get broadcast batch shape.
  # broadcast_shape checks for compatibility.
  batch_shape = self.operators[0].batch_shape
  for operator in self.operators[1:]:
    batch_shape = common_shapes.broadcast_shape(
        batch_shape, operator.batch_shape)

  return batch_shape.concatenate(matrix_shape)
def _trace(self):
  # The diagonal of the [[nested] block] circulant operator is the mean of
  # the spectrum.
  # Proof:  For the [0,...,0] element, this follows from the IDFT formula.
  # Then the result follows since all diagonal elements are the same.

  # Therefore, the trace is the sum of the spectrum.

  # Get shape of diag along with the axis over which to reduce the spectrum.
  # We will reduce the spectrum over all block indices.
  if tensor_shape.TensorShape(self.spectrum.shape).is_fully_defined():
    spec_rank = tensor_shape.TensorShape(self.spectrum.shape).ndims
    axis = np.arange(spec_rank - self.block_depth, spec_rank, dtype=np.int32)
  else:
    spec_rank = array_ops.rank(self.spectrum)
    axis = array_ops.range(spec_rank - self.block_depth, spec_rank)

  # Real diag part "re_d".
  # Suppose tensor_shape.TensorShape(spectrum.shape) = [B1,...,Bb, N1, N2],
  # tensor_shape.TensorShape(self.shape) = [B1,...,Bb, N, N],
  # with N1 * N2 = N.
  # Then tensor_shape.TensorShape(re_d_value.shape) = [B1,...,Bb].
  re_d_value = math_ops.reduce_sum(math_ops.real(self.spectrum), axis=axis)

  if not np.issubdtype(self.dtype, np.complexfloating):
    return _ops.cast(re_d_value, self.dtype)

  # Imaginary part, "im_d".
  if self.is_self_adjoint:
    im_d_value = array_ops.zeros_like(re_d_value)
  else:
    im_d_value = math_ops.reduce_sum(math_ops.imag(self.spectrum), axis=axis)

  return _ops.cast(math_ops.complex(re_d_value, im_d_value), self.dtype)
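# Illustrative sketch (not part of the implementation): for a circulant
# C = IDFT @ diag(spectrum) @ DFT, similarity invariance of the trace gives
# trace(C) = sum(spectrum), which is the identity `_trace` relies on.  A
# hedged NumPy check under that assumption; `_sketch_circulant_trace` is a
# hypothetical name.
def _sketch_circulant_trace():
  import numpy as np
  n = 4
  spectrum = np.random.randn(n) + 1j * np.random.randn(n)
  dft = np.fft.fft(np.eye(n))   # DFT matrix F
  idft = np.conj(dft).T / n     # F^{-1}
  circulant = idft @ np.diag(spectrum) @ dft
  np.testing.assert_allclose(
      np.trace(circulant), np.sum(spectrum), atol=1e-10)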
def _unblockify_then_matricize(self, vec):
  """Flatten the block dimensions then reshape to a batch matrix."""
  # Suppose
  #   tensor_shape.TensorShape(vec.shape) = [v0, v1, v2, v3],
  #   self.block_depth = 2.
  # Then
  #   leading shape = [v0, v1]
  #   block shape = [v2, v3].
  # We will reshape vec to
  #   [v1, v2*v3, v0].

  # Un-blockify:  Flatten block dimensions.  Reshape
  #   [v0, v1, v2, v3] --> [v0, v1, v2*v3].
  if tensor_shape.TensorShape(vec.shape).is_fully_defined():
    # vec_shape = [v0, v1, v2, v3]
    vec_shape = tensor_shape.TensorShape(vec.shape).as_list()
    # vec_leading_shape = [v0, v1]
    vec_leading_shape = vec_shape[:-self.block_depth]
    # vec_block_shape = [v2, v3]
    vec_block_shape = vec_shape[-self.block_depth:]
    # flat_shape = [v0, v1, v2*v3]
    flat_shape = vec_leading_shape + [np.prod(vec_block_shape)]
  else:
    vec_shape = prefer_static.shape(vec)
    vec_leading_shape = vec_shape[:-self.block_depth]
    vec_block_shape = vec_shape[-self.block_depth:]
    flat_shape = prefer_static.concat(
        (vec_leading_shape, [math_ops.reduce_prod(vec_block_shape)]), 0)
  vec_flat = array_ops.reshape(vec, flat_shape)

  # Matricize:  Reshape to batch matrix.
  #   [v0, v1, v2*v3] --> [v1, v2*v3, v0],
  # representing a shape [v1] batch of [v2*v3, v0] matrices.
  matrix = distribution_util.rotate_transpose(vec_flat, shift=-1)
  return matrix
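# Illustrative sketch (not part of the implementation) of the final
# "matricize" step: per the docstring, rotating the leading axis to the back
# (what `distribution_util.rotate_transpose(..., shift=-1)` does here) turns
# [v0, v1, v2*v3] into [v1, v2*v3, v0].  NumPy analogue with a hypothetical
# name.
def _sketch_rotate_leading_axis():
  import numpy as np
  vec_flat = np.zeros((2, 3, 4))                 # [v0, v1, v2*v3]
  axes = np.roll(np.arange(vec_flat.ndim), -1)   # [1, 2, 0]
  matrix = np.transpose(vec_flat, axes)
  assert matrix.shape == (3, 4, 2)               # [v1, v2*v3, v0]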
def matvec(self, x, adjoint=False, name="matvec"):
  """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.

  ```python
  # Make an operator acting like batch matrix A.  Assume
  # tensor_shape.TensorShape(A.shape) = [..., M, N]
  operator = LinearOperator(...)

  X = ...  # shape [..., N], batch vector

  Y = operator.matvec(X)
  tensor_shape.TensorShape(Y.shape)
  ==> [..., M]

  Y[..., :] = sum_j A[..., :, j] X[..., j]
  ```

  Args:
    x: `Tensor` with compatible shape and same `dtype` as `self`.
      `x` is treated as a [batch] vector meaning for every set of leading
      dimensions, the last dimension defines a vector.
      See class docstring for definition of compatibility.
    adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
    name: A name for this `Op`.

  Returns:
    A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
  """
  with self._name_scope(name):
    x = ops.convert_to_tensor(x, name="x")
    # self._check_input_dtype(x)
    self_dim = -2 if adjoint else -1
    tensor_shape.dimension_at_index(
        tensor_shape.TensorShape(self.shape),
        self_dim).assert_is_compatible_with(
            tensor_shape.TensorShape(x.shape)[-1])

    return self._matvec(x, adjoint=adjoint)
def tensor_rank(self, name="tensor_rank"):
  """Rank (in the sense of tensors) of matrix corresponding to this operator.

  If this operator acts like the batch matrix `A` with
  `tensor_shape.TensorShape(A.shape) = [B1,...,Bb, M, N]`, then this returns
  `b + 2`.

  Args:
    name: A name for this `Op`.

  Returns:
    Python integer, or None if the tensor rank is undefined.
  """
  # Derived classes get this "for free" once .shape() is implemented.
  with self._name_scope(name):
    return tensor_shape.TensorShape(self.shape).ndims
def _eigvals(self):
  # This will be the Kronecker product of all the eigenvalues.
  # Note: It doesn't matter in which order the Kronecker product is taken,
  # since all Kronecker products of the same matrices are similar.
  eigvals = [operator.eigvals() for operator in self.operators]

  # Now compute the Kronecker product.
  product = eigvals[0]
  for eigval in eigvals[1:]:
    # Product has shape [B, R1, 1].
    product = product[..., _ops.newaxis]
    # Eigval has shape [B, 1, R2].  Produces shape [B, R1, R2].
    product = product * eigval[..., _ops.newaxis, :]
    # Reshape to [B, R1 * R2].
    product = array_ops.reshape(
        product,
        shape=prefer_static.concat(
            [prefer_static.shape(product)[:-2], [-1]], axis=0))
  tensorshape_util.set_shape(
      product, tensor_shape.TensorShape(self.shape)[:-1])
  return product
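# Illustrative sketch (not part of the implementation): the eigenvalues of a
# Kronecker product are all pairwise products of the factors' eigenvalues,
# which is exactly what the broadcasted outer product above accumulates.
# Hedged NumPy check; `_sketch_kronecker_eigvals` is a hypothetical name.
def _sketch_kronecker_eigvals():
  import numpy as np
  a = np.random.randn(3, 3)
  b = np.random.randn(2, 2)
  kron_eigs = np.sort_complex(np.linalg.eigvals(np.kron(a, b)))
  # Outer product of the factors' eigenvalues, flattened.
  outer = np.linalg.eigvals(a)[:, None] * np.linalg.eigvals(b)[None, :]
  np.testing.assert_allclose(
      kron_eigs, np.sort_complex(outer.reshape(-1)), atol=1e-8)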
def _shape(self):
  # Get final matrix shape.
  domain_dimension = self.operators[0].domain_dimension
  for operator in self.operators[1:]:
    domain_dimension.assert_is_compatible_with(operator.range_dimension)
    domain_dimension = operator.domain_dimension

  matrix_shape = tensor_shape.TensorShape(
      [self.operators[0].range_dimension,
       self.operators[-1].domain_dimension])

  # Get broadcast batch shape.
  # broadcast_shape checks for compatibility.
  batch_shape = self.operators[0].batch_shape
  for operator in self.operators[1:]:
    batch_shape = common_shapes.broadcast_shape(
        batch_shape, operator.batch_shape)

  return batch_shape.concatenate(matrix_shape)
def _to_dense(self):
  product = self.operators[0].to_dense()
  for operator in self.operators[1:]:
    # Product has shape [B, R1, 1, C1, 1].
    product = product[..., :, _ops.newaxis, :, _ops.newaxis]
    # Operator has shape [B, 1, R2, 1, C2].
    op_to_mul = operator.to_dense()[..., _ops.newaxis, :, _ops.newaxis, :]
    # This is now [B, R1, R2, C1, C2].
    product = product * op_to_mul
    # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
    product_shape = _prefer_static_shape(product)
    shape = _prefer_static_concat_shape(
        product_shape[:-4],
        [product_shape[-4] * product_shape[-3],
         product_shape[-2] * product_shape[-1]])
    product = array_ops.reshape(product, shape=shape)
  tensorshape_util.set_shape(product, tensor_shape.TensorShape(self.shape))
  return product
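# Illustrative sketch (not part of the implementation): the
# broadcast-and-reshape trick above is exactly a Kronecker product.  Hedged
# NumPy check for a single pair of factors; `_sketch_kron_via_reshape` is a
# hypothetical name.  The same trick appears again below in the
# dynamic-shape variant of `_to_dense`.
def _sketch_kron_via_reshape():
  import numpy as np
  a = np.random.randn(2, 3)  # [R1, C1]
  b = np.random.randn(4, 5)  # [R2, C2]
  # [R1, 1, C1, 1] * [1, R2, 1, C2] -> [R1, R2, C1, C2]
  product = a[:, None, :, None] * b[None, :, None, :]
  np.testing.assert_allclose(
      product.reshape(2 * 4, 3 * 5), np.kron(a, b))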
def _to_dense(self):
  num_cols = 0
  dense_rows = []
  flat_broadcast_operators = linear_operator_util.broadcast_matrix_batch_dims(
      [op.to_dense() for row in self.operators for op in row])  # pylint: disable=g-complex-comprehension
  broadcast_operators = [
      flat_broadcast_operators[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
      for i in range(len(self.operators))]
  for row_blocks in broadcast_operators:
    batch_row_shape = prefer_static.shape(row_blocks[0])[:-1]
    num_cols = num_cols + prefer_static.shape(row_blocks[-1])[-1]
    zeros_to_pad_after_shape = prefer_static.concat(
        [batch_row_shape,
         [self.domain_dimension_tensor() - num_cols]], axis=-1)
    zeros_to_pad_after = array_ops.zeros(
        shape=zeros_to_pad_after_shape, dtype=self.dtype)

    row_blocks.append(zeros_to_pad_after)
    dense_rows.append(prefer_static.concat(row_blocks, axis=-1))

  mat = prefer_static.concat(dense_rows, axis=-2)
  tensorshape_util.set_shape(mat, tensor_shape.TensorShape(self.shape))
  return mat
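# Illustrative sketch (not part of the implementation): the
# triangular-number slicing above cuts a flattened lower-triangular list of
# blocks back into rows.  `_sketch_triangular_rows` is a hypothetical name.
def _sketch_triangular_rows():
  flat = ["b00", "b10", "b11", "b20", "b21", "b22"]
  # Row i starts at the i-th triangular number i*(i+1)//2.
  rows = [flat[i * (i + 1) // 2:(i + 1) * (i + 2) // 2] for i in range(3)]
  assert rows == [["b00"], ["b10", "b11"], ["b20", "b21", "b22"]]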
def _broadcast_static_shape(shape_x, shape_y):
  """Reimplements `tf.broadcast_static_shape` in JAX/NumPy."""
  if (tensor_shape.TensorShape(shape_x).ndims is None or
      tensor_shape.TensorShape(shape_y).ndims is None):
    return tensor_shape.TensorShape(None)
  shape_x = tuple(tensor_shape.TensorShape(shape_x).as_list())
  shape_y = tuple(tensor_shape.TensorShape(shape_y).as_list())
  try:
    if JAX_MODE:
      error_message = 'Incompatible shapes for broadcasting'
      return tensor_shape.TensorShape(lax.broadcast_shapes(shape_x, shape_y))
    error_message = ('shape mismatch: objects cannot be broadcast to'
                     ' a single shape')
    return tensor_shape.TensorShape(
        np.broadcast(np.zeros(shape_x), np.zeros(shape_y)).shape)
  except ValueError as e:
    # Match TF error message.
    if error_message in str(e):
      raise ValueError(
          'Incompatible shapes for broadcasting: {} and {}'.format(
              shape_x, shape_y))
    raise
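# Illustrative usage sketch (NumPy branch, i.e. JAX_MODE=False): the helper
# follows NumPy's broadcasting rule while normalizing the error text to TF's.
# The equivalence with `np.broadcast_shapes` (NumPy >= 1.20) is an assumption
# for illustration; `_sketch_broadcast_rule` is a hypothetical name.
def _sketch_broadcast_rule():
  import numpy as np
  assert np.broadcast_shapes((2, 1, 3), (5, 3)) == (2, 5, 3)
  try:
    np.broadcast_shapes((2,), (3,))
  except ValueError:
    pass  # incompatible; surfaced above as "Incompatible shapes for ..."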
def _to_dense(self):
  product = self.operators[0].to_dense()
  for operator in self.operators[1:]:
    # Product has shape [B, R1, 1, C1, 1].
    product = product[..., :, _ops.newaxis, :, _ops.newaxis]
    # Operator has shape [B, 1, R2, 1, C2].
    op_to_mul = operator.to_dense()[..., _ops.newaxis, :, _ops.newaxis, :]
    # This is now [B, R1, R2, C1, C2].
    product *= op_to_mul
    # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
    product = array_ops.reshape(
        product,
        shape=array_ops.concat(
            [array_ops.shape(product)[:-4],
             [array_ops.shape(product)[-4] * array_ops.shape(product)[-3],
              array_ops.shape(product)[-2] * array_ops.shape(product)[-1]]
            ], axis=0))
  product.set_shape(tensor_shape.TensorShape(self.shape))
  return product
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  # Here we follow the same use of Roth's column lemma as in `matmul`, with
  # the key difference that we replace all `matmul` instances with `solve`.
  # This follows from the property that inv(A x B) = inv(A) x inv(B).

  # Below we document the shape manipulation for adjoint=False,
  # adjoint_arg=False, but the general case of different adjoints is still
  # handled.

  if adjoint_arg:
    rhs = linalg.adjoint(rhs)

  # Always add a batch dimension to enable broadcasting to work.
  batch_shape = self._compute_ones_matrix_shape()
  rhs = rhs + array_ops.zeros(batch_shape, dtype=rhs.dtype)

  # rhs has shape [B, R, C], where B represents some number of batch
  # dimensions, R represents the number of rows, and C represents the number
  # of columns.
  # In order to apply Roth's column lemma, we need to operate on a batch of
  # column vectors, so we reshape into a batch of column vectors.  We put it
  # at the front to ensure that broadcasting between operators to the batch
  # dimensions B still works.
  output = _rotate_last_dim(rhs, rotate_right=True)

  # Also expand the shape to be [A, C, B, R].  The first dimension will be
  # used to accumulate dimensions from each operator matmul.
  output = output[_ops.newaxis, ...]

  # In this loop, A is going to refer to the value of the accumulated
  # dimension.  A = 1 at the start, and will end up being
  # self.range_dimension.  V will refer to the last dimension.  V = R at the
  # start, and will end up being 1 in the end.
  for operator in self.operators[:-1]:
    # Reshape output from [A, C, B, V] to be
    # [A, C, B, V / op.domain_dimension, op.domain_dimension]
    if adjoint:
      operator_dimension = operator.range_dimension_tensor()
    else:
      operator_dimension = operator.domain_dimension_tensor()

    output = _unvec_by(output, operator_dimension)

    # We are computing X A^-T = (A^-1 X^T)^T.
    # output has shape [A, C, B, V / op.domain_dimension,
    # op.domain_dimension], which is being converted to:
    # [A, C, B, V / op.domain_dimension, op.range_dimension]
    output = _linalg.matrix_transpose(output)
    output = operator.solve(output, adjoint=adjoint, adjoint_arg=False)
    output = _linalg.matrix_transpose(output)
    # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension].
    output = _rotate_last_dim(output, rotate_right=False)
    output = _vec(output)
    output = _rotate_last_dim(output, rotate_right=True)

  # After the loop, we will have
  # A = self.range_dimension / op[-1].range_dimension
  # V = op[-1].domain_dimension

  # We convert that using solvevec to get:
  # [A, C, B, op[-1].range_dimension]
  output = self.operators[-1].solvevec(output, adjoint=adjoint)
  # Rearrange shape to be [B1, ... Bn, self.range_dimension, C].
  output = _rotate_last_dim(output, rotate_right=False)
  output = _vec(output)
  output = _rotate_last_dim(output, rotate_right=False)

  if tensor_shape.TensorShape(rhs.shape).is_fully_defined():
    column_dim = tensor_shape.TensorShape(rhs.shape)[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        tensor_shape.TensorShape(rhs.shape)[:-2], self.batch_shape)
    if adjoint:
      matrix_dimensions = [self.domain_dimension, column_dim]
    else:
      matrix_dimensions = [self.range_dimension, column_dim]

    tensorshape_util.set_shape(
        output, broadcast_batch_shape.concatenate(matrix_dimensions))

  return output
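# Illustrative sketch (not part of the implementation) of the property the
# comment above invokes: inv(A x B) = inv(A) x inv(B), so solving against a
# Kronecker product reduces to solves against the factors.  Hedged NumPy
# check; `_sketch_kron_inverse` is a hypothetical name.
def _sketch_kron_inverse():
  import numpy as np
  a = np.random.randn(3, 3) + 3. * np.eye(3)  # keep well conditioned
  b = np.random.randn(2, 2) + 3. * np.eye(2)
  np.testing.assert_allclose(
      np.linalg.inv(np.kron(a, b)),
      np.kron(np.linalg.inv(a), np.linalg.inv(b)),
      atol=1e-8)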
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  # Here we heavily rely on Roth's column lemma [1]:
  # (A x B) * vec X = vec BXA^T,
  # where vec stacks all the columns of the matrix under each other.
  # In our case, x represents a batch of vec X (i.e. we think of x as a
  # batch of column vectors, rather than a matrix).  Each member of the
  # batch can be reshaped to a matrix (hence we get a batch of matrices).
  # We can iteratively apply this lemma by noting that if B is a Kronecker
  # product, then we can apply the lemma again.

  # [1] W. E. Roth, "On direct product matrices,"
  # Bulletin of the American Mathematical Society, vol. 40, pp. 461-468,
  # 1934.

  # Efficiency

  # Naively doing the Kronecker product, by calculating the dense matrix and
  # applying it, can take cubic time in the size of domain_dimension
  # (assuming a square matrix).  The other issue is that calculating the
  # dense matrix can be prohibitively expensive, in that it can take a large
  # amount of memory.
  #
  # This implementation avoids this memory blow up by only computing matmuls
  # with the factors.  In this way, we don't have to realize the dense
  # matrix.  In terms of complexity, if we have Kronecker factors of size:
  # (n1, n1), (n2, n2), (n3, n3), ... (nJ, nJ), with N = \prod n_i, and we
  # have as input a [N, M] matrix, the naive approach would take O(N^2 M).
  # With this approach (ignoring reshaping of tensors and transposes for
  # now), the time complexity can be O(M * (\sum n_i) * N).  There is also
  # the benefit of batched multiplication (in this example, the batch size
  # is roughly M * N), so this can be much faster.  However, not factored in
  # are the costs of the several transposes of tensors, which can affect
  # cache behavior.

  # Below we document the shape manipulation for adjoint=False,
  # adjoint_arg=False, but the general case of different adjoints is still
  # handled.

  if adjoint_arg:
    x = linalg.adjoint(x)

  # Always add a batch dimension to enable broadcasting to work.
  batch_shape = self._compute_ones_matrix_shape()
  x = x + array_ops.zeros(batch_shape, dtype=x.dtype)

  # x has shape [B, R, C], where B represents some number of batch
  # dimensions, R represents the number of rows, and C represents the number
  # of columns.
  # In order to apply Roth's column lemma, we need to operate on a batch of
  # column vectors, so we reshape into a batch of column vectors.  We put it
  # at the front to ensure that broadcasting between operators to the batch
  # dimensions B still works.
  output = _rotate_last_dim(x, rotate_right=True)

  # Also expand the shape to be [A, C, B, R].  The first dimension will be
  # used to accumulate dimensions from each operator matmul.
  output = output[_ops.newaxis, ...]

  # In this loop, A is going to refer to the value of the accumulated
  # dimension.  A = 1 at the start, and will end up being
  # self.range_dimension.  V will refer to the last dimension.  V = R at the
  # start, and will end up being 1 in the end.
  for operator in self.operators[:-1]:
    # Reshape output from [A, C, B, V] to be
    # [A, C, B, V / op.domain_dimension, op.domain_dimension]
    if adjoint:
      operator_dimension = operator.range_dimension_tensor()
    else:
      operator_dimension = operator.domain_dimension_tensor()

    output = _unvec_by(output, operator_dimension)

    # We are computing (XA^T) = (AX^T)^T.
    # output has shape [A, C, B, V / op.domain_dimension,
    # op.domain_dimension], which is being converted to:
    # [A, C, B, V / op.domain_dimension, op.range_dimension]
    output = _linalg.matrix_transpose(output)
    output = operator.matmul(output, adjoint=adjoint, adjoint_arg=False)
    output = _linalg.matrix_transpose(output)
    # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension].
    output = _rotate_last_dim(output, rotate_right=False)
    output = _vec(output)
    output = _rotate_last_dim(output, rotate_right=True)

  # After the loop, we will have
  # A = self.range_dimension / op[-1].range_dimension
  # V = op[-1].domain_dimension

  # We convert that using matvec to get:
  # [A, C, B, op[-1].range_dimension]
  output = self.operators[-1].matvec(output, adjoint=adjoint)

  # Rearrange shape to be [B1, ... Bn, self.range_dimension, C].
  output = _rotate_last_dim(output, rotate_right=False)
  output = _vec(output)
  output = _rotate_last_dim(output, rotate_right=False)

  if tensor_shape.TensorShape(x.shape).is_fully_defined():
    column_dim = tensor_shape.TensorShape(x.shape)[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        tensor_shape.TensorShape(x.shape)[:-2], self.batch_shape)
    if adjoint:
      matrix_dimensions = [self.domain_dimension, column_dim]
    else:
      matrix_dimensions = [self.range_dimension, column_dim]
    tensorshape_util.set_shape(
        output, broadcast_batch_shape.concatenate(matrix_dimensions))

  return output
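# Illustrative sketch (not part of the implementation) of Roth's column
# lemma as cited above: (A x B) vec X = vec(B X A^T), with vec stacking
# columns (Fortran order).  Hedged NumPy check; `_sketch_roth_column_lemma`
# is a hypothetical name.
def _sketch_roth_column_lemma():
  import numpy as np
  a = np.random.randn(3, 3)
  b = np.random.randn(4, 4)
  x = np.random.randn(4, 3)  # so that B @ X @ A^T is defined
  vec = lambda m: m.reshape(-1, order="F")  # stack columns
  np.testing.assert_allclose(
      np.kron(a, b) @ vec(x), vec(b @ x @ a.T), atol=1e-10)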
def _shape(self):
  matrix_shape = tensor_shape.TensorShape(
      (self._num_rows_static, self._num_rows_static))
  batch_shape = tensor_shape.TensorShape(self.multiplier.shape)
  return batch_shape.concatenate(matrix_shape)
def _shape(self):
  batch_shape = _ops.broadcast_static_shape(
      self.base_operator.batch_shape,
      tensor_shape.TensorShape(self.u.shape)[:-2])
  return batch_shape.concatenate(
      tensor_shape.TensorShape(self.base_operator.shape)[-2:])
def solvevec(self, rhs, adjoint=False, name="solve"):
  """Solve single equation with best effort: `A X = rhs`.

  The returned `Tensor` will be close to an exact solution if `A` is well
  conditioned. Otherwise closeness will vary. See class docstring for
  details.

  Examples:

  ```python
  # Make an operator acting like batch matrix A.  Assume
  # tensor_shape.TensorShape(A.shape) = [..., M, N]
  operator = LinearOperator(...)
  tensor_shape.TensorShape(operator.shape) = [..., M, N]

  # Solve one linear system for every member of the batch.
  RHS = ...  # shape [..., M]

  X = operator.solvevec(RHS)
  # X is the solution to the linear system
  # sum_j A[..., :, j] X[..., j] = RHS[..., :]

  operator.matvec(X)
  ==> RHS
  ```

  Args:
    rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
      (for blockwise operators). `Tensor`s are treated as [batch] vectors,
      meaning for every set of leading dimensions, the last dimension
      defines a vector.  See class docstring for definition of compatibility
      regarding batch dimensions.
    adjoint: Python `bool`.  If `True`, solve the system involving the
      adjoint of this `LinearOperator`:  `A^H X = rhs`.
    name:  A name scope to use for ops added by this method.

  Returns:
    `Tensor` with shape `[...,N]` and same `dtype` as `rhs`.

  Raises:
    NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
  """
  with self._name_scope(name):
    block_dimensions = (self._block_domain_dimensions() if adjoint
                        else self._block_range_dimensions())
    if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
      for i, block in enumerate(rhs):
        if not isinstance(block, linear_operator.LinearOperator):
          block = ops.convert_to_tensor(block)
          # self._check_input_dtype(block)
          block_dimensions[i].assert_is_compatible_with(
              tensor_shape.TensorShape(block.shape)[-1])
          rhs[i] = block
      rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
      solution_mat = self.solve(rhs_mat, adjoint=adjoint)
      return [array_ops.squeeze(x, axis=-1) for x in solution_mat]

    rhs = ops.convert_to_tensor(rhs, name="rhs")
    # self._check_input_dtype(rhs)
    op_dimension = (self.domain_dimension if adjoint
                    else self.range_dimension)
    op_dimension.assert_is_compatible_with(
        tensor_shape.TensorShape(rhs.shape)[-1])
    rhs_mat = array_ops.expand_dims(rhs, axis=-1)
    solution_mat = self.solve(rhs_mat, adjoint=adjoint)
    return array_ops.squeeze(solution_mat, axis=-1)
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
  """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

  The returned `Tensor` will be close to an exact solution if `A` is well
  conditioned. Otherwise closeness will vary. See class docstring for
  details.

  Examples:

  ```python
  # Make an operator acting like batch matrix A.  Assume
  # tensor_shape.TensorShape(A.shape) = [..., M, N]
  operator = LinearOperator(...)
  tensor_shape.TensorShape(operator.shape) = [..., M, N]

  # Solve R > 0 linear systems for every member of the batch.
  RHS = ...  # shape [..., M, R]

  X = operator.solve(RHS)
  # X[..., :, r] is the solution to the r'th linear system
  # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

  operator.matmul(X)
  ==> RHS
  ```

  Args:
    rhs: `Tensor` with same `dtype` as this operator and compatible shape,
      or a list of `Tensor`s (for blockwise operators). `Tensor`s are
      treated like [batch] matrices, meaning for every set of leading
      dimensions, the last two dimensions define a matrix.
      See class docstring for definition of compatibility.
    adjoint: Python `bool`.  If `True`, solve the system involving the
      adjoint of this `LinearOperator`:  `A^H X = rhs`.
    adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where
      `rhs^H` is the hermitian transpose (transposition and complex
      conjugation).
    name:  A name scope to use for ops added by this method.

  Returns:
    `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

  Raises:
    NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
  """
  if self.is_non_singular is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "be singular.")
  if self.is_square is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "not be square.")
  if isinstance(rhs, linear_operator.LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = rhs.adjoint() if adjoint_arg else rhs

    if (right_operator.range_dimension is not None and
        left_operator.domain_dimension is not None and
        right_operator.range_dimension != left_operator.domain_dimension):
      raise ValueError(
          "Operators are incompatible. Expected `rhs` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension,
              right_operator.range_dimension))
    with self._name_scope(name):
      return linear_operator_algebra.solve(left_operator, right_operator)

  with self._name_scope(name):
    block_dimensions = (self._block_domain_dimensions() if adjoint
                        else self._block_range_dimensions())
    arg_dim = -1 if adjoint_arg else -2
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, rhs, arg_dim)

    if blockwise_arg:
      split_rhs = rhs
      for i, block in enumerate(split_rhs):
        if not isinstance(block, linear_operator.LinearOperator):
          block = ops.convert_to_tensor(block)
          # self._check_input_dtype(block)
          block_dimensions[i].assert_is_compatible_with(
              tensor_shape.TensorShape(block.shape)[arg_dim])
          split_rhs[i] = block
    else:
      rhs = ops.convert_to_tensor(rhs, name="rhs")
      # self._check_input_dtype(rhs)
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(
          tensor_shape.TensorShape(rhs.shape)[arg_dim])
      split_dim = -1 if adjoint_arg else -2
      # Split input by rows normally, and by columns if adjoint_arg is True.
      split_rhs = linear_operator_util.split_arg_into_blocks(
          self._block_domain_dimensions(),
          self._block_domain_dimension_tensors,
          rhs, axis=split_dim)

    solution_list = []
    for index, operator in enumerate(self.operators):
      solution_list += [operator.solve(
          split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]

    if blockwise_arg:
      return solution_list

    solution_list = linear_operator_util.broadcast_matrix_batch_dims(
        solution_list)
    return array_ops.concat(solution_list, axis=-2)
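# Illustrative sketch (not part of the implementation): a block-diagonal
# solve splits the RHS by block domain dimensions, solves each block
# independently, and concatenates, as the loop above does.  Hedged NumPy
# check for two dense blocks; `_sketch_block_diag_solve` is a hypothetical
# name.
def _sketch_block_diag_solve():
  import numpy as np
  a0 = np.random.randn(2, 2) + 3. * np.eye(2)  # keep well conditioned
  a1 = np.random.randn(3, 3) + 3. * np.eye(3)
  rhs = np.random.randn(5, 1)
  x = np.concatenate([np.linalg.solve(a0, rhs[:2]),
                      np.linalg.solve(a1, rhs[2:])], axis=-2)
  full = np.block([[a0, np.zeros((2, 3))], [np.zeros((3, 2)), a1]])
  np.testing.assert_allclose(full @ x, rhs, atol=1e-8)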
def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
  """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

  ```python
  # Make an operator acting like batch matrix A.  Assume
  # tensor_shape.TensorShape(A.shape) = [..., M, N]
  operator = LinearOperator(...)
  tensor_shape.TensorShape(operator.shape) = [..., M, N]

  X = ...  # shape [..., N, R], batch matrix, R > 0.

  Y = operator.matmul(X)
  tensor_shape.TensorShape(Y.shape)
  ==> [..., M, R]

  Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
  ```

  Args:
    x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
      `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
      class docstring for definition of shape compatibility.
    adjoint: Python `bool`.  If `True`, left multiply by the adjoint:
      `A^H x`.
    adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
      the hermitian transpose (transposition and complex conjugation).
    name:  A name for this `Op`.

  Returns:
    A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
      as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes
      that concatenate to `[..., M, R]`.
  """
  if isinstance(x, linear_operator.LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = x.adjoint() if adjoint_arg else x

    if (right_operator.range_dimension is not None and
        left_operator.domain_dimension is not None and
        right_operator.range_dimension != left_operator.domain_dimension):
      raise ValueError(
          "Operators are incompatible. Expected `x` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension,
              right_operator.range_dimension))
    with self._name_scope(name):
      return linear_operator_algebra.matmul(left_operator, right_operator)

  with self._name_scope(name):
    arg_dim = -1 if adjoint_arg else -2
    block_dimensions = (self._block_range_dimensions() if adjoint
                        else self._block_domain_dimensions())
    if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
      for i, block in enumerate(x):
        if not isinstance(block, linear_operator.LinearOperator):
          block = ops.convert_to_tensor(block)
          # self._check_input_dtype(block)
          block_dimensions[i].assert_is_compatible_with(
              tensor_shape.TensorShape(block.shape)[arg_dim])
          x[i] = block
    else:
      x = ops.convert_to_tensor(x, name="x")
      # self._check_input_dtype(x)
      op_dimension = (self.range_dimension if adjoint
                      else self.domain_dimension)
      op_dimension.assert_is_compatible_with(
          tensor_shape.TensorShape(x.shape)[arg_dim])
    return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
  """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

  The returned `Tensor` will be close to an exact solution if `A` is well
  conditioned. Otherwise closeness will vary. See class docstring for
  details.

  Given the blockwise `n + 1`-by-`n + 1` linear operator:

  op = [[A_00     0  ...     0  ...    0],
        [A_10  A_11  ...     0  ...    0],
        ...
        [A_k0  A_k1  ...  A_kk  ...    0],
        ...
        [A_n0  A_n1  ...  A_nk  ... A_nn]]

  we find `x = op.solve(y)` by observing that

  `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

  and therefore

  `x_k = A_kk.solve(y_k -
                    A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

  where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
  and `y` along their appropriate axes.

  We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
  for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.

  The adjoint case is solved similarly, beginning with
  `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.

  Examples:

  ```python
  # Make an operator acting like batch matrix A.  Assume
  # tensor_shape.TensorShape(A.shape) = [..., M, N]
  operator = LinearOperator(...)
  tensor_shape.TensorShape(operator.shape) = [..., M, N]

  # Solve R > 0 linear systems for every member of the batch.
  RHS = ...  # shape [..., M, R]

  X = operator.solve(RHS)
  # X[..., :, r] is the solution to the r'th linear system
  # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

  operator.matmul(X)
  ==> RHS
  ```

  Args:
    rhs: `Tensor` with same `dtype` as this operator and compatible shape,
      or a list of `Tensor`s. `Tensor`s are treated like [batch] matrices,
      meaning for every set of leading dimensions, the last two dimensions
      define a matrix.
      See class docstring for definition of compatibility.
    adjoint: Python `bool`.  If `True`, solve the system involving the
      adjoint of this `LinearOperator`:  `A^H X = rhs`.
    adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where
      `rhs^H` is the hermitian transpose (transposition and complex
      conjugation).
    name:  A name scope to use for ops added by this method.

  Returns:
    `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

  Raises:
    NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
  """
  if self.is_non_singular is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "be singular.")
  if self.is_square is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "not be square.")
  if isinstance(rhs, linear_operator.LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = rhs.adjoint() if adjoint_arg else rhs

    if (right_operator.range_dimension is not None and
        left_operator.domain_dimension is not None and
        right_operator.range_dimension != left_operator.domain_dimension):
      raise ValueError(
          "Operators are incompatible. Expected `rhs` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension,
              right_operator.range_dimension))
    with self._name_scope(name):
      return linear_operator_algebra.solve(left_operator, right_operator)

  with self._name_scope(name):
    block_dimensions = (self._block_domain_dimensions() if adjoint
                        else self._block_range_dimensions())
    arg_dim = -1 if adjoint_arg else -2
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, rhs, arg_dim)
    if blockwise_arg:
      for i, block in enumerate(rhs):
        if not isinstance(block, linear_operator.LinearOperator):
          block = ops.convert_to_tensor(block)
          # self._check_input_dtype(block)
          block_dimensions[i].assert_is_compatible_with(
              tensor_shape.TensorShape(block.shape)[arg_dim])
          rhs[i] = block
      if adjoint_arg:
        split_rhs = [linalg.adjoint(y) for y in rhs]
      else:
        split_rhs = rhs
    else:
      rhs = ops.convert_to_tensor(rhs, name="rhs")
      # self._check_input_dtype(rhs)
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(
          tensor_shape.TensorShape(rhs.shape)[arg_dim])
      rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
      split_rhs = linear_operator_util.split_arg_into_blocks(
          self._block_domain_dimensions(),
          self._block_domain_dimension_tensors,
          rhs, axis=-2)

    solution_list = []
    if adjoint:
      # For an adjoint blockwise lower-triangular linear operator, the
      # system must be solved bottom to top.  Iterate backwards over rows of
      # the adjoint (i.e. columns of the non-adjoint operator).
      for index in reversed(range(len(self.operators))):
        y = split_rhs[index]
        # Iterate top to bottom over the operators in the off-diagonal
        # portion of the column-partition (i.e. row-partition of the
        # adjoint), apply the operator to the respective block of the
        # solution found in previous iterations, and subtract the result
        # from the `rhs` block.  For example, let `A`, `B`, and `D` be the
        # linear operators in the top row-partition of the adjoint of
        # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
        # and `x_1` and `x_2` be blocks of the solution found in previous
        # iterations of the outer loop.  The following loop
        # (when `index == 0`) expresses
        # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
        # `y_0* = y_0 - Bx_1 - Dx_2`.
        for j in reversed(range(index + 1, len(self.operators))):
          y -= self.operators[j][index].matmul(
              solution_list[len(self.operators) - 1 - j],
              adjoint=adjoint)
        # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
        solution_list.append(
            self._diagonal_operators[index].solve(y, adjoint=adjoint))
      solution_list.reverse()
    else:
      # Iterate top to bottom over the row-partitions.
      for row, y in zip(self.operators, split_rhs):
        # Iterate left to right over the operators in the off-diagonal
        # portion of the row-partition, apply the operator to the block of
        # the solution found in previous iterations, and subtract the
        # result from the `rhs` block.  For example, let `D`, `E`, and `F`
        # be the linear operators in the bottom row-partition of
        # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`
        # and `x_0` and `x_1` be blocks of the solution found in previous
        # iterations of the outer loop.  The following loop
        # (when `index == 2`) expresses
        # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
        # `y_2* = y_2 - Dx_0 - Ex_1`.
        for i, operator in enumerate(row[:-1]):
          y -= operator.matmul(solution_list[i], adjoint=adjoint)
        # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
        solution_list.append(row[-1].solve(y, adjoint=adjoint))

    if blockwise_arg:
      return solution_list

    solution_list = linear_operator_util.broadcast_matrix_batch_dims(
        solution_list)
    return array_ops.concat(solution_list, axis=-2)
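# Illustrative sketch (not part of the implementation) of the blockwise
# forward substitution described in the docstring above, for a 2x2-block
# lower-triangular system with dense blocks.  Hedged NumPy check;
# `_sketch_blockwise_forward_substitution` is a hypothetical name.
def _sketch_blockwise_forward_substitution():
  import numpy as np
  a00 = np.random.randn(2, 2) + 3. * np.eye(2)  # keep well conditioned
  a10 = np.random.randn(3, 2)
  a11 = np.random.randn(3, 3) + 3. * np.eye(3)
  y0, y1 = np.random.randn(2, 1), np.random.randn(3, 1)
  x0 = np.linalg.solve(a00, y0)              # x_0 = A_00.solve(y_0)
  x1 = np.linalg.solve(a11, y1 - a10 @ x0)   # x_1 = A_11.solve(y_1 - A_10 x_0)
  full = np.block([[a00, np.zeros((2, 3))], [a10, a11]])
  np.testing.assert_allclose(
      full @ np.concatenate([x0, x1]), np.concatenate([y0, y1]), atol=1e-8)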
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
  """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

  The returned `Tensor` will be close to an exact solution if `A` is well
  conditioned. Otherwise closeness will vary. See class docstring for
  details.

  Examples:

  ```python
  # Make an operator acting like batch matrix A.  Assume
  # tensor_shape.TensorShape(A.shape) = [..., M, N]
  operator = LinearOperator(...)
  tensor_shape.TensorShape(operator.shape) = [..., M, N]

  # Solve R > 0 linear systems for every member of the batch.
  RHS = ...  # shape [..., M, R]

  X = operator.solve(RHS)
  # X[..., :, r] is the solution to the r'th linear system
  # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

  operator.matmul(X)
  ==> RHS
  ```

  Args:
    rhs: `Tensor` with same `dtype` as this operator and compatible shape.
      `rhs` is treated like a [batch] matrix, meaning for every set of
      leading dimensions, the last two dimensions define a matrix.
      See class docstring for definition of compatibility.
    adjoint: Python `bool`.  If `True`, solve the system involving the
      adjoint of this `LinearOperator`:  `A^H X = rhs`.
    adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where
      `rhs^H` is the hermitian transpose (transposition and complex
      conjugation).
    name:  A name scope to use for ops added by this method.

  Returns:
    `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

  Raises:
    NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
  """
  if self.is_non_singular is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "be singular.")
  if self.is_square is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "not be square.")
  if isinstance(rhs, LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = rhs.adjoint() if adjoint_arg else rhs

    if (right_operator.range_dimension is not None and
        left_operator.domain_dimension is not None and
        right_operator.range_dimension != left_operator.domain_dimension):
      raise ValueError(
          "Operators are incompatible. Expected `rhs` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension,
              right_operator.range_dimension))
    with self._name_scope(name):
      return linear_operator_algebra.solve(left_operator, right_operator)

  with self._name_scope(name):
    rhs = ops.convert_to_tensor(rhs, name="rhs")
    # self._check_input_dtype(rhs)

    self_dim = -1 if adjoint else -2
    arg_dim = -1 if adjoint_arg else -2
    tensor_shape.dimension_at_index(
        tensor_shape.TensorShape(self.shape),
        self_dim).assert_is_compatible_with(
            tensor_shape.TensorShape(rhs.shape)[arg_dim])

    return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
def block_shape(self):
  return tensor_shape.TensorShape(
      self.spectrum.shape)[-self.block_depth:]
def _solve_matmul_internal(self, x, solve_matmul_fn,
                           adjoint=False, adjoint_arg=False):
  # We heavily rely on Roth's column lemma [1]:
  # (A x B) * vec X = vec BXA^T,
  # where vec stacks all the columns of the matrix under each other.
  # In our case, we use a variant of the lemma that is row-major friendly:
  # (A x B) * vec' X = vec' AXB^T,
  # where vec' reshapes a matrix into a vector.  We can repeatedly apply
  # this for a collection of Kronecker products.
  # Given that (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, we can
  # use the above to compute multiplications and solves with any composition
  # of transposes.
  output = x

  if adjoint_arg:
    if np.issubdtype(self.dtype, np.complexfloating):
      output = math_ops.conj(output)
    else:
      output = linalg.transpose(output)

  for o in reversed(self.operators):
    # Statically compute the reshape.
    if adjoint:
      operator_dimension = o.range_dimension_tensor()
    else:
      operator_dimension = o.domain_dimension_tensor()
    output_shape = _prefer_static_shape(output)

    if ops.get_static_value(operator_dimension) is not None:
      operator_dimension = ops.get_static_value(operator_dimension)
      if (tensor_shape.TensorShape(output.shape)[-2] is not None and
          tensor_shape.TensorShape(output.shape)[-1] is not None):
        dim = int(tensor_shape.TensorShape(output.shape)[-2] *
                  output_shape[-1] // operator_dimension)
    else:
      dim = _ops.cast(
          output_shape[-2] * output_shape[-1] // operator_dimension,
          dtype=dtypes.int32)

    output_shape = _prefer_static_concat_shape(
        output_shape[:-2], [dim, operator_dimension])
    output = array_ops.reshape(output, shape=output_shape)

    # Conjugate because we are trying to compute A @ B^T, but
    # `LinearOperator` only supports `adjoint_arg`.
    if np.issubdtype(self.dtype, np.complexfloating):
      output = math_ops.conj(output)
    output = solve_matmul_fn(o, output, adjoint=adjoint, adjoint_arg=True)

  if adjoint_arg:
    col_dim = _prefer_static_shape(x)[-2]
  else:
    col_dim = _prefer_static_shape(x)[-1]

  if adjoint:
    row_dim = self.domain_dimension_tensor()
  else:
    row_dim = self.range_dimension_tensor()

  matrix_shape = [row_dim, col_dim]
  output = array_ops.reshape(
      output,
      _prefer_static_concat_shape(
          _prefer_static_shape(output)[:-2], matrix_shape))

  if tensor_shape.TensorShape(x.shape).is_fully_defined():
    if adjoint_arg:
      column_dim = tensor_shape.TensorShape(x.shape)[-2]
    else:
      column_dim = tensor_shape.TensorShape(x.shape)[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        tensor_shape.TensorShape(x.shape)[:-2], self.batch_shape)
    if adjoint:
      matrix_dimensions = [self.domain_dimension, column_dim]
    else:
      matrix_dimensions = [self.range_dimension, column_dim]
    tensorshape_util.set_shape(
        output, broadcast_batch_shape.concatenate(matrix_dimensions))

  return output
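# Illustrative sketch (not part of the implementation) of the row-major
# variant of Roth's lemma cited above: (A x B) vec' X = vec'(A X B^T), where
# vec' flattens in row-major (C) order.  Hedged NumPy check;
# `_sketch_row_major_roth` is a hypothetical name.
def _sketch_row_major_roth():
  import numpy as np
  a = np.random.randn(3, 3)
  b = np.random.randn(4, 4)
  x = np.random.randn(3, 4)  # so that A @ X @ B^T is defined
  vec_r = lambda m: m.reshape(-1)  # row-major flatten
  np.testing.assert_allclose(
      np.kron(a, b) @ vec_r(x), vec_r(a @ x @ b.T), atol=1e-10)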
def _prefer_static_shape(x):
  if tensor_shape.TensorShape(x.shape).is_fully_defined():
    return tensor_shape.TensorShape(x.shape)
  return prefer_static.shape(x)