def _shape(self):
  """Return the static shape of this block-structured operator.

  The matrix part of the shape is `[range_dimension, domain_dimension]`
  (rows first, then columns), where each total is the sum of the values
  returned by `_block_range_dimensions()` / `_block_domain_dimensions()`.
  The batch part is the broadcast of every operator's `batch_shape`.

  Returns:
    The broadcast batch shape concatenated with the matrix shape.
  """
  # Get final matrix shape.
  domain_dimension = sum(self._block_domain_dimensions())
  range_dimension = sum(self._block_range_dimensions())
  # A linear operator maps domain -> range, so its matrix has
  # `range_dimension` rows and `domain_dimension` columns. The previous
  # ordering ([domain, range]) was only correct for square operators.
  matrix_shape = tensor_shape.TensorShape([range_dimension, domain_dimension])

  # Get broadcast batch shape.
  # broadcast_shape checks for compatibility.
  batch_shape = self.operators[0].batch_shape
  for operator in self.operators[1:]:
    batch_shape = common_shapes.broadcast_shape(
        batch_shape, operator.batch_shape)

  return batch_shape.concatenate(matrix_shape)
def _shape(self):
  """Return the static shape of the chained (composed) operator.

  Each operator's `domain_dimension` is asserted compatible with the
  `range_dimension` of the operator that follows it. The matrix part of the
  result takes its rows from the first operator's range and its columns
  from the last operator's domain; the batch part is the broadcast of all
  operators' batch shapes.

  Returns:
    The broadcast batch shape concatenated with the matrix shape.
  """
  factors = self.operators

  # Walk neighbouring pairs and check that they can be chained together.
  for left, right in zip(factors[:-1], factors[1:]):
    left.domain_dimension.assert_is_compatible_with(right.range_dimension)

  num_rows = factors[0].range_dimension
  num_cols = factors[-1].domain_dimension
  matrix_part = tensor_shape.TensorShape([num_rows, num_cols])

  # `broadcast_shape` itself verifies that the batch shapes are compatible.
  batch_part = factors[0].batch_shape
  for factor in factors[1:]:
    batch_part = common_shapes.broadcast_shape(batch_part, factor.batch_shape)

  return batch_part.concatenate(matrix_part)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  """Solve against the Kronecker product one factor at a time.

  This mirrors the Roth's-column-lemma scheme used for `matmul`, with every
  per-factor `matmul` replaced by `solve`; that is valid because
  inv(A x B) = inv(A) x inv(B). The shape notes below describe the
  `adjoint=False, adjoint_arg=False` case, but the other adjoint
  combinations are handled as well.

  Args:
    rhs: Right-hand side, shaped `[B, R, C]` (batch dims, rows, columns).
    adjoint: Forwarded to each factor's `solve` / `solvevec`.
    adjoint_arg: If true, `rhs` is replaced by `linalg.adjoint(rhs)` first.

  Returns:
    The solution tensor. When the (broadcast) `rhs` shape is fully defined,
    its static shape is set explicitly.
  """
  if adjoint_arg:
    rhs = linalg.adjoint(rhs)

  # Adding zeros of this shape broadcasts `rhs` so that batch broadcasting
  # between the factors works later on.
  ones_shape = self._compute_ones_matrix_shape()
  rhs = rhs + array_ops.zeros(ones_shape, dtype=rhs.dtype)

  # [B, R, C] -> [C, B, R]: treat the columns as a leading batch of vectors
  # so broadcasting over B is undisturbed.
  acc = _rotate_last_dim(rhs, rotate_right=True)
  # -> [A, C, B, R] with A = 1. A accumulates the factor dimensions; the
  # last axis (V) starts at R and is consumed factor by factor.
  acc = acc[_ops.newaxis, ...]

  for factor in self.operators[:-1]:
    factor_dim = (
        factor.range_dimension_tensor() if adjoint
        else factor.domain_dimension_tensor())

    # [A, C, B, V] -> [A, C, B, V / factor_dim, factor_dim]
    acc = _unvec_by(acc, factor_dim)

    # Compute X inv(A)^T as (inv(A) X^T)^T: transpose, solve, transpose back.
    # Result: [A, C, B, V / op.domain_dimension, op.range_dimension]
    acc = _linalg.matrix_transpose(acc)
    acc = factor.solve(acc, adjoint=adjoint, adjoint_arg=False)
    acc = _linalg.matrix_transpose(acc)

    # Fold the solved axis into A:
    # -> [A * op.range_dimension, C, B, V / op.domain_dimension]
    acc = _rotate_last_dim(acc, rotate_right=False)
    acc = _vec(acc)
    acc = _rotate_last_dim(acc, rotate_right=True)

  # Now A = self.range_dimension / op[-1].range_dimension and
  # V = op[-1].domain_dimension; `solvevec` with the last factor gives
  # [A, C, B, op[-1].range_dimension].
  acc = self.operators[-1].solvevec(acc, adjoint=adjoint)

  # Rearrange to [B1, ... Bn, self.range_dimension, C].
  acc = _rotate_last_dim(acc, rotate_right=False)
  acc = _vec(acc)
  acc = _rotate_last_dim(acc, rotate_right=False)

  rhs_static = tensor_shape.TensorShape(rhs.shape)
  if rhs_static.is_fully_defined():
    num_columns = rhs_static[-1]
    batch_part = common_shapes.broadcast_shape(
        rhs_static[:-2], self.batch_shape)
    num_rows = self.domain_dimension if adjoint else self.range_dimension
    tensorshape_util.set_shape(
        acc, batch_part.concatenate([num_rows, num_columns]))

  return acc
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  """Multiply by the Kronecker product without forming the dense matrix.

  The method relies on Roth's column lemma [1]:

      (A x B) * vec X = vec BXA^T,

  where `vec` stacks the columns of a matrix under each other. `x` is
  treated as a batch of `vec X` column vectors, each of which can be
  reshaped into a matrix. Whenever B is itself a Kronecker product the lemma
  can be applied again, so the factors are processed one after another.

  Efficiency: materialising the dense product and applying it can take
  cubic time in `domain_dimension` (for a square matrix) and a prohibitive
  amount of memory. Only multiplying by the factors avoids that blow-up. For
  factors of sizes (n1, n1), ..., (nJ, nJ) with N = prod n_i and an [N, M]
  input, the naive cost is O(N^2 M) whereas this scheme (ignoring reshapes
  and transposes) can be O(M * (sum n_i) * N), and it benefits from batched
  multiplication. The repeated transposes are not accounted for and may
  affect cache behaviour.

  The shape notes below describe `adjoint=False, adjoint_arg=False`; the
  other adjoint combinations are handled as well.

  [1] W. E. Roth, "On direct product matrices," Bulletin of the American
      Mathematical Society, vol. 40, pp. 461-468, 1934.

  Args:
    x: Input shaped `[B, R, C]` (batch dims, rows, columns).
    adjoint: Forwarded to each factor's `matmul` / `matvec`.
    adjoint_arg: If true, `x` is replaced by `linalg.adjoint(x)` first.

  Returns:
    The product tensor. When the (broadcast) `x` shape is fully defined,
    its static shape is set explicitly.
  """
  if adjoint_arg:
    x = linalg.adjoint(x)

  # Adding zeros of this shape broadcasts `x` so that batch broadcasting
  # between the factors works later on.
  ones_shape = self._compute_ones_matrix_shape()
  x = x + array_ops.zeros(ones_shape, dtype=x.dtype)

  # [B, R, C] -> [C, B, R]: treat the columns as a leading batch of vectors
  # so broadcasting over B is undisturbed.
  acc = _rotate_last_dim(x, rotate_right=True)
  # -> [A, C, B, R] with A = 1. A accumulates the factor dimensions; the
  # last axis (V) starts at R and is consumed factor by factor.
  acc = acc[_ops.newaxis, ...]

  for factor in self.operators[:-1]:
    factor_dim = (
        factor.range_dimension_tensor() if adjoint
        else factor.domain_dimension_tensor())

    # [A, C, B, V] -> [A, C, B, V / factor_dim, factor_dim]
    acc = _unvec_by(acc, factor_dim)

    # Compute X A^T as (A X^T)^T: transpose, multiply, transpose back.
    # Result: [A, C, B, V / op.domain_dimension, op.range_dimension]
    acc = _linalg.matrix_transpose(acc)
    acc = factor.matmul(acc, adjoint=adjoint, adjoint_arg=False)
    acc = _linalg.matrix_transpose(acc)

    # Fold the multiplied axis into A:
    # -> [A * op.range_dimension, C, B, V / op.domain_dimension]
    acc = _rotate_last_dim(acc, rotate_right=False)
    acc = _vec(acc)
    acc = _rotate_last_dim(acc, rotate_right=True)

  # Now A = self.range_dimension / op[-1].range_dimension and
  # V = op[-1].domain_dimension; `matvec` with the last factor gives
  # [A, C, B, op[-1].range_dimension].
  acc = self.operators[-1].matvec(acc, adjoint=adjoint)

  # Rearrange to [B1, ... Bn, self.range_dimension, C].
  acc = _rotate_last_dim(acc, rotate_right=False)
  acc = _vec(acc)
  acc = _rotate_last_dim(acc, rotate_right=False)

  x_static = tensor_shape.TensorShape(x.shape)
  if x_static.is_fully_defined():
    num_columns = x_static[-1]
    batch_part = common_shapes.broadcast_shape(
        x_static[:-2], self.batch_shape)
    num_rows = self.domain_dimension if adjoint else self.range_dimension
    tensorshape_util.set_shape(
        acc, batch_part.concatenate([num_rows, num_columns]))

  return acc
def _bcast_shape(base_shape, args):
  """Broadcast `base_shape` against the shape of every element of `args`.

  Args:
    base_shape: Starting shape; normalised with `_ensure_shape_tuple`.
    args: Iterable of values; each is passed through `np.asarray` and its
      `.shape` is folded into the running result.

  Returns:
    The result of successively applying `ops.broadcast_shape`.
  """
  result = _ensure_shape_tuple(base_shape)
  # Lazy generator: each shape is computed right before it is folded in.
  arg_shapes = (np.asarray(value).shape for value in args)
  for shape in arg_shapes:
    result = ops.broadcast_shape(result, shape)
  return result
def _solve_matmul_internal(self, x, solve_matmul_fn, adjoint=False,
                           adjoint_arg=False):
  """Shared implementation of matmul/solve against a Kronecker product.

  We heavily rely on Roth's column Lemma [1]:
      (A x B) * vec X = vec BXA^T
  where vec stacks all the columns of the matrix under each other. Here a
  row-major friendly variant is used:
      (A x B) * vec' X = vec' AXB^T
  where vec' reshapes a matrix into a vector. This can be applied repeatedly
  for a collection of Kronecker factors. Since (A x B)^-1 = A^-1 x B^-1 and
  (A x B)^T = A^T x B^T, the same loop serves multiplications and solves
  with any combination of transposes.

  Args:
    x: Input whose last two axes are the matrix dimensions.
    solve_matmul_fn: Callable invoked per factor as
      `solve_matmul_fn(o, output, adjoint=adjoint, adjoint_arg=True)`.
    adjoint: Forwarded to `solve_matmul_fn`; also selects whether each
      factor's range or domain dimension drives the reshape.
    adjoint_arg: Whether `x` should be treated as adjointed.

  Returns:
    The result reshaped to `[..., row_dim, col_dim]`. When `x` has a fully
    defined static shape, the static shape of the result is set as well.
  """
  # Loop-invariant: whether conjugation is needed to turn adjoints into
  # plain transposes.
  is_complex = np.issubdtype(self.dtype, np.complexfloating)

  # The loop below works on the transposed argument. If the caller wants
  # x^H, its transpose is conj(x) (x itself for real dtypes); otherwise
  # transpose x explicitly.
  output = x
  if adjoint_arg:
    if is_complex:
      output = math_ops.conj(output)
  else:
    output = linalg.transpose(output)

  for o in reversed(self.operators):
    if adjoint:
      operator_dimension = o.range_dimension_tensor()
    else:
      operator_dimension = o.domain_dimension_tensor()
    output_shape = _prefer_static_shape(output)

    # Compute the reshape statically when possible. `dim` is reset on every
    # iteration and falls back to the dynamic computation whenever either
    # the operator dimension or the trailing dims of `output` are not
    # statically known; previously that mixed case left `dim` unassigned
    # (or stale from the previous factor).
    dim = None
    static_operator_dimension = ops.get_static_value(operator_dimension)
    if static_operator_dimension is not None:
      operator_dimension = static_operator_dimension
      static_output_shape = tensor_shape.TensorShape(output.shape)
      if (static_output_shape[-2] is not None
          and static_output_shape[-1] is not None):
        dim = int(
            static_output_shape[-2] * output_shape[-1] // operator_dimension)
    if dim is None:
      dim = _ops.cast(
          output_shape[-2] * output_shape[-1] // operator_dimension,
          dtype=dtypes.int32)

    output_shape = _prefer_static_concat_shape(
        output_shape[:-2], [dim, operator_dimension])
    output = array_ops.reshape(output, shape=output_shape)

    # Conjugate because we are trying to compute A @ B^T, but
    # `LinearOperator` only supports `adjoint_arg`.
    if is_complex:
      output = math_ops.conj(output)

    output = solve_matmul_fn(o, output, adjoint=adjoint, adjoint_arg=True)

  # The number of result columns comes from whichever axis of `x` plays the
  # column role after the optional adjoint.
  if adjoint_arg:
    col_dim = _prefer_static_shape(x)[-2]
  else:
    col_dim = _prefer_static_shape(x)[-1]

  if adjoint:
    row_dim = self.domain_dimension_tensor()
  else:
    row_dim = self.range_dimension_tensor()

  matrix_shape = [row_dim, col_dim]

  output = array_ops.reshape(
      output,
      _prefer_static_concat_shape(
          _prefer_static_shape(output)[:-2], matrix_shape))

  x_static_shape = tensor_shape.TensorShape(x.shape)
  if x_static_shape.is_fully_defined():
    if adjoint_arg:
      column_dim = x_static_shape[-2]
    else:
      column_dim = x_static_shape[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        x_static_shape[:-2], self.batch_shape)
    if adjoint:
      matrix_dimensions = [self.domain_dimension, column_dim]
    else:
      matrix_dimensions = [self.range_dimension, column_dim]

    tensorshape_util.set_shape(
        output, broadcast_batch_shape.concatenate(matrix_dimensions))

  return output