def _lu(input, output_idx_type=np.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns Lu(lu, p), as TF does."""
  del name
  input = ops.convert_to_tensor(input)
  if JAX_MODE:  # JAX uses XLA, which can do a batched factorization.
    lu_out, pivots = scipy_linalg.lu_factor(input)
    from jax import lax_linalg  # pylint: disable=g-import-not-at-top
    return Lu(lu_out,
              lax_linalg.lu_pivots_to_permutation(pivots, lu_out.shape[-1]))
  # Scipy can't batch, so we must do so manually.
  nbatch = int(np.prod(input.shape[:-2]))
  dim = input.shape[-1]
  flat_mat = input.reshape(nbatch, dim, dim)
  flat_lu = np.empty((nbatch, dim, dim), dtype=input.dtype)
  flat_piv = np.empty((nbatch, dim), dtype=utils.numpy_dtype(output_idx_type))
  if np.size(flat_lu):  # Avoid non-empty batches of empty matrices.
    for i, mat in enumerate(flat_mat):
      lu_out, pivots = scipy_linalg.lu_factor(mat)
      flat_lu[i] = lu_out
      flat_piv[i] = _lu_pivot_to_permutation(pivots, flat_lu.shape[-1])
  return Lu(flat_lu.reshape(*input.shape),
            flat_piv.reshape(*input.shape[:-1]))
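# A minimal standalone sketch (plain NumPy/SciPy, not backend code) of why
# the manual loop above is needed: scipy.linalg.lu_factor only accepts a
# single 2-D matrix, so a batch must be factored slice by slice.
import numpy as np
import scipy.linalg

batch = np.random.randn(4, 3, 3)
lu = np.empty_like(batch)
piv = np.empty((4, 3), dtype=np.int32)
for i, mat in enumerate(batch):
  lu[i], piv[i] = scipy_linalg_out = scipy.linalg.lu_factor(mat)

# Each factored slice can then back a solve via scipy.linalg.lu_solve.
b = np.random.randn(3)
x = scipy.linalg.lu_solve((lu[0], piv[0]), b)
np.testing.assert_allclose(batch[0] @ x, b, atol=1e-8)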
def _shape_tensor(self):
  # Avoid messy broadcasting if possible.
  if tensor_shape.TensorShape(self.shape).is_fully_defined():
    return ops.convert_to_tensor(
        tensor_shape.TensorShape(self.shape).as_list(),
        dtype=dtypes.int32,
        name="shape")

  # Don't check the matrix dimensions.  That would add unnecessary Asserts to
  # the graph.  Things will fail at runtime naturally if shapes are
  # incompatible.
  matrix_shape = array_ops.stack([
      self.operators[0].range_dimension_tensor(),
      self.operators[-1].domain_dimension_tensor()
  ])

  # Dummy Tensor of zeros.  Will never be materialized.
  zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
  for operator in self.operators[1:]:
    zeros = zeros + array_ops.zeros(shape=operator.batch_shape_tensor())
  batch_shape = array_ops.shape(zeros)

  return array_ops.concat((batch_shape, matrix_shape), 0)
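# A small NumPy illustration (a sketch, not backend code) of the "dummy
# zeros" trick above: adding zeros of each operator's batch shape yields an
# array whose shape is the broadcast of all the batch shapes, without ever
# materializing real data.
import numpy as np

z = np.zeros((2, 1)) + np.zeros((1, 3)) + np.zeros((3,))
assert z.shape == (2, 3)  # Broadcast of [2, 1], [1, 3], and [3].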
def _broadcast_batch_dims(self, x, spectrum):
  """Broadcast batch dims of batch matrix `x` and spectrum."""
  spectrum = ops.convert_to_tensor(spectrum, name="spectrum")
  # spectrum.shape = batch_shape + block_shape
  # First make spectrum a batch matrix with
  #   spectrum.shape = batch_shape + [prod(block_shape), 1]
  batch_shape = self._batch_shape_tensor(
      shape=self._shape_tensor(spectrum=spectrum))
  spec_mat = array_ops.reshape(
      spectrum, array_ops.concat((batch_shape, [-1, 1]), axis=0))
  # Second, broadcast, possibly requiring an addition of array of zeros.
  x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims(
      (x, spec_mat))
  # Third, put the block shape back into spectrum.
  x_batch_shape = array_ops.shape(x)[:-2]
  spectrum_shape = array_ops.shape(spectrum)
  spectrum = array_ops.reshape(
      spec_mat,
      array_ops.concat(
          (x_batch_shape,
           self._block_shape_tensor(spectrum_shape=spectrum_shape)),
          axis=0))
  return x, spectrum
def _check_matrix(self, matrix):
  """Static check of the `matrix` argument."""
  allowed_dtypes = [
      dtypes.float16,
      dtypes.float32,
      dtypes.float64,
      dtypes.complex64,
      dtypes.complex128,
  ]
  matrix = ops.convert_to_tensor(matrix, name="matrix")

  dtype = matrix.dtype
  if dtype not in allowed_dtypes:
    raise TypeError(
        f"Argument `matrix` must have dtype in {allowed_dtypes}. "
        f"Received: {dtype}.")

  if (tensor_shape.TensorShape(matrix.shape).ndims is not None and
      tensor_shape.TensorShape(matrix.shape).ndims < 2):
    raise ValueError(
        f"Argument `matrix` must have at least 2 dimensions. "
        f"Received: {matrix}.")
def add_to_tensor(self, mat, name="add_to_tensor"):
  """Add matrix represented by this operator to `mat`.  Equiv to `I + mat`.

  Args:
    mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
    name:  A name to give this `Op`.

  Returns:
    A `Tensor` with broadcast shape and same `dtype` as `self`.
  """
  with self._name_scope(name):
    # Shape [B1,...,Bb, 1]
    multiplier_vector = array_ops.expand_dims(self.multiplier, -1)

    # Shape [C1,...,Cc, M, M]
    mat = ops.convert_to_tensor(mat, name="mat")

    # Shape [C1,...,Cc, M]
    mat_diag = _linalg.diag_part(mat)

    # multiplier_vector broadcasts here.
    new_diag = multiplier_vector + mat_diag

    return _linalg.set_diag(mat, new_diag)
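# A hedged NumPy sketch of what add_to_tensor computes for a scaled-identity
# operator: adding c * I to a matrix only touches the diagonal, so the method
# adds the multiplier to diag_part(mat) instead of building a dense identity.
import numpy as np

c = 2.5
mat = np.arange(9.).reshape(3, 3)
out = mat.copy()
out[np.diag_indices(3)] += c  # Same as set_diag(mat, c + diag_part(mat)).
np.testing.assert_allclose(out, c * np.eye(3) + mat)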
def matvec(self, x, adjoint=False, name="matvec"):
  """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.

  ```python
  # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
  operator = LinearOperator(...)

  X = ...  # shape [..., N], batch vector

  Y = operator.matvec(X)
  Y.shape
  ==> [..., M]

  Y[..., :] = sum_j A[..., :, j] X[..., j]
  ```

  Args:
    x: `Tensor` with compatible shape and same `dtype` as `self`.
      `x` is treated as a [batch] vector meaning for every set of leading
      dimensions, the last dimension defines a vector.
      See class docstring for definition of compatibility.
    adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
    name:  A name for this `Op`.

  Returns:
    A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
  """
  with self._name_scope(name):
    x = ops.convert_to_tensor(x, name="x")
    # self._check_input_dtype(x)
    self_dim = -2 if adjoint else -1
    tensor_shape.dimension_at_index(
        tensor_shape.TensorShape(self.shape),
        self_dim).assert_is_compatible_with(
            tensor_shape.TensorShape(x.shape)[-1])

    return self._matvec(x, adjoint=adjoint)
def _shape(input, out_type=np.int32, name=None):  # pylint: disable=redefined-builtin,unused-argument
  return ops.convert_to_tensor(ops.convert_to_tensor(input).shape).astype(
      out_type)
def _diag_part(self):
  reflection_axis = ops.convert_to_tensor(self.reflection_axis)
  normalized_axis = nn.l2_normalize(reflection_axis, axis=-1)
  return 1. - 2 * normalized_axis * math_ops.conj(normalized_axis)
def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random.normal(shape=(2, 3, 1, 4, 4))
  y = tf.random.normal(shape=(1, 3, 2, 5, 5))

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices:  Iterable of `Tensor`s, each having two or more dimensions.
    name:  A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError:  If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor(mat)
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    # x.shape = [2, j, k]     (batch shape = [2])
    # y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = tensor_shape.TensorShape(batch_matrices[0].shape)[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = _ops.broadcast_static_shape(
          bcast_batch_shape,
          tensor_shape.TensorShape(mat.shape)[:-2])
    if bcast_batch_shape.is_fully_defined():
      for i, mat in enumerate(batch_matrices):
        if tensor_shape.TensorShape(mat.shape)[:-2] != bcast_batch_shape:
          bcast_shape = array_ops.concat(
              [bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]],
              axis=0)
          batch_matrices[i] = _ops.broadcast_to(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape, array_ops.shape(mat)[:-2])
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = _ops.broadcast_to(
          mat,
          array_ops.concat(
              [bcast_batch_shape, array_ops.shape(mat)[-2:]], axis=0))

    return batch_matrices
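# A NumPy sketch of the static-broadcast path above. np.broadcast_shapes is
# used here only as an illustrative analogue (it requires NumPy >= 1.20):
# only the leading (batch) dims broadcast; the trailing matrix dims are kept.
import numpy as np

x = np.zeros((2, 3, 1, 4, 4))
y = np.zeros((1, 3, 2, 5, 5))
bcast = np.broadcast_shapes(x.shape[:-2], y.shape[:-2])  # (2, 3, 2)
x_bc = np.broadcast_to(x, bcast + x.shape[-2:])  # shape (2, 3, 2, 4, 4)
y_bc = np.broadcast_to(y, bcast + y.shape[-2:])  # shape (2, 3, 2, 5, 5)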
def _irfftn(x, s, axes):
  x = ops.convert_to_tensor(x)
  float_dtype = np.finfo(x.dtype).dtype
  return np.fft.irfftn(x, s=s, axes=axes).astype(float_dtype)
def _ifftn(x, axes):
  x = ops.convert_to_tensor(x)
  return np.fft.ifftn(x, axes=axes).astype(x.dtype)
range = utils.copy_docstring(  # pylint: disable=redefined-builtin
    'tf.range',
    _range)

rank = utils.copy_docstring(
    'tf.rank',
    lambda input, name=None: np.int32(np.array(input).ndim))  # pylint: disable=redefined-builtin,g-long-lambda

repeat = utils.copy_docstring(
    'tf.repeat',
    lambda input, repeats, axis=None, name=None: np.repeat(  # pylint: disable=g-long-lambda
        input, repeats, axis=axis))

reshape = utils.copy_docstring(
    'tf.reshape',
    lambda tensor, shape, name=None: np.reshape(  # pylint: disable=g-long-lambda
        ops.convert_to_tensor(tensor), shape))

roll = utils.copy_docstring(
    'tf.roll',
    lambda input, shift, axis: np.roll(input, shift, axis))  # pylint: disable=unnecessary-lambda

searchsorted = utils.copy_docstring(
    'tf.searchsorted',
    _searchsorted)

shape = utils.copy_docstring(
    'tf.shape',
    _shape)

size = utils.copy_docstring(
    'tf.size',
    _size)

slice = utils.copy_docstring(  # pylint: disable=redefined-builtin
    'tf.slice', _slice)

split = utils.copy_docstring(
    'tf.split',
    _split)
def _eigvals(self):
  return ops.convert_to_tensor(self.diag)
def _diag_part(self):
  reflection_axis = ops.convert_to_tensor(self.reflection_axis)
  normalized_axis = reflection_axis / linalg.norm(
      reflection_axis, axis=-1, keepdims=True)
  return 1. - 2 * normalized_axis * math_ops.conj(normalized_axis)
def _transpose(a, perm=None, conjugate=False, name='transpose'):  # pylint: disable=unused-argument
  x = np.transpose(ops.convert_to_tensor(a), perm)
  return np.conjugate(x) if conjugate else x
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
  """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

  The returned `Tensor` will be close to an exact solution if `A` is well
  conditioned. Otherwise closeness will vary. See class docstring for details.

  Examples:

  ```python
  # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
  operator = LinearOperator(...)
  operator.shape = [..., M, N]

  # Solve R > 0 linear systems for every member of the batch.
  RHS = ...  # shape [..., M, R]

  X = operator.solve(RHS)
  # X[..., :, r] is the solution to the r'th linear system
  # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

  operator.matmul(X)
  ==> RHS
  ```

  Args:
    rhs: `Tensor` with same `dtype` as this operator and compatible shape,
      or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated
      like a [batch] matrices meaning for every set of leading dimensions,
      the last two dimensions defines a matrix.
      See class docstring for definition of compatibility.
    adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
      of this `LinearOperator`:  `A^H X = rhs`.
    adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
      is the hermitian transpose (transposition and complex conjugation).
    name:  A name scope to use for ops added by this method.

  Returns:
    `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

  Raises:
    NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
  """
  if self.is_non_singular is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "be singular.")
  if self.is_square is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "not be square.")
  if isinstance(rhs, linear_operator.LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = rhs.adjoint() if adjoint_arg else rhs

    if (right_operator.range_dimension is not None and
        left_operator.domain_dimension is not None and
        right_operator.range_dimension != left_operator.domain_dimension):
      raise ValueError(
          "Operators are incompatible. Expected `rhs` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))
    with self._name_scope(name):
      return linear_operator_algebra.solve(left_operator, right_operator)

  with self._name_scope(name):
    block_dimensions = (self._block_domain_dimensions() if adjoint
                        else self._block_range_dimensions())
    arg_dim = -1 if adjoint_arg else -2
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, rhs, arg_dim)

    if blockwise_arg:
      split_rhs = rhs
      for i, block in enumerate(split_rhs):
        if not isinstance(block, linear_operator.LinearOperator):
          block = ops.convert_to_tensor(block)
          # self._check_input_dtype(block)
          block_dimensions[i].assert_is_compatible_with(
              tensor_shape.TensorShape(block.shape)[arg_dim])
          split_rhs[i] = block
    else:
      rhs = ops.convert_to_tensor(rhs, name="rhs")
      # self._check_input_dtype(rhs)
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(
          tensor_shape.TensorShape(rhs.shape)[arg_dim])

      split_dim = -1 if adjoint_arg else -2
      # Split input by rows normally, and otherwise columns.
      split_rhs = linear_operator_util.split_arg_into_blocks(
          self._block_domain_dimensions(),
          self._block_domain_dimension_tensors,
          rhs, axis=split_dim)

    solution_list = []
    for index, operator in enumerate(self.operators):
      solution_list += [operator.solve(
          split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]

    if blockwise_arg:
      return solution_list

    solution_list = linear_operator_util.broadcast_matrix_batch_dims(
        solution_list)
    return array_ops.concat(solution_list, axis=-2)
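# A plain-NumPy sketch of the block-diagonal solve above: splitting the rhs
# into blocks and solving each block against its diagonal operator matches a
# single dense solve against the assembled block-diagonal matrix.
import numpy as np

A = np.array([[2., 1.], [0., 3.]])
B = np.array([[4.]])
full = np.block([[A, np.zeros((2, 1))],
                 [np.zeros((1, 2)), B]])
rhs = np.array([[1.], [2.], [3.]])
x = np.concatenate([np.linalg.solve(A, rhs[:2]),
                    np.linalg.solve(B, rhs[2:])], axis=-2)
np.testing.assert_allclose(full @ x, rhs)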
def solvevec(self, rhs, adjoint=False, name="solve"):
  """Solve single equation with best effort: `A X = rhs`.

  The returned `Tensor` will be close to an exact solution if `A` is well
  conditioned. Otherwise closeness will vary. See class docstring for details.

  Examples:

  ```python
  # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
  operator = LinearOperator(...)
  operator.shape = [..., M, N]

  # Solve one linear system for every member of the batch.
  RHS = ...  # shape [..., M]

  X = operator.solvevec(RHS)
  # X is the solution to the linear system
  # sum_j A[..., :, j] X[..., j] = RHS[..., :]

  operator.matvec(X)
  ==> RHS
  ```

  Args:
    rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
      (for blockwise operators). `Tensor`s are treated as [batch] vectors,
      meaning for every set of leading dimensions, the last dimension defines
      a vector.  See class docstring for definition of compatibility regarding
      batch dimensions.
    adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
      of this `LinearOperator`:  `A^H X = rhs`.
    name:  A name scope to use for ops added by this method.

  Returns:
    `Tensor` with shape `[...,N]` and same `dtype` as `rhs`.

  Raises:
    NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
  """
  with self._name_scope(name):
    block_dimensions = (self._block_domain_dimensions() if adjoint
                        else self._block_range_dimensions())
    if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
      for i, block in enumerate(rhs):
        if not isinstance(block, linear_operator.LinearOperator):
          block = ops.convert_to_tensor(block)
          # self._check_input_dtype(block)
          block_dimensions[i].assert_is_compatible_with(
              tensor_shape.TensorShape(block.shape)[-1])
          rhs[i] = block
      rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
      solution_mat = self.solve(rhs_mat, adjoint=adjoint)
      return [array_ops.squeeze(x, axis=-1) for x in solution_mat]

    rhs = ops.convert_to_tensor(rhs, name="rhs")
    # self._check_input_dtype(rhs)
    op_dimension = (self.domain_dimension if adjoint
                    else self.range_dimension)
    op_dimension.assert_is_compatible_with(
        tensor_shape.TensorShape(rhs.shape)[-1])
    rhs_mat = array_ops.expand_dims(rhs, axis=-1)
    solution_mat = self.solve(rhs_mat, adjoint=adjoint)
    return array_ops.squeeze(solution_mat, axis=-1)
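# A NumPy sketch of the expand_dims/squeeze pattern used by solvevec above:
# a vector solve is a matrix solve against a single-column right-hand side.
import numpy as np

A = np.array([[2., 1.], [0., 3.]])
rhs = np.array([5., 9.])
x = np.linalg.solve(A, rhs[..., np.newaxis])[..., 0]  # expand, solve, squeeze
np.testing.assert_allclose(A @ x, rhs)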
def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  s = _shape(input)
  if isinstance(s, (np.ndarray, np.generic)):
    return np.zeros(s, utils.numpy_dtype(dtype or input.dtype))
  # Fall back to TF when the shape is dynamic; default to the input's dtype,
  # not the shape tensor's.
  return tf.zeros(s, dtype or input.dtype, name)


# --- Begin Public Functions --------------------------------------------------

concat = utils.copy_docstring(
    tf.concat,
    lambda values, axis, name='concat': (  # pylint: disable=g-long-lambda
        np.concatenate([ops.convert_to_tensor(v) for v in values], axis)))

expand_dims = utils.copy_docstring(
    tf.expand_dims,
    lambda input, axis, name=None: np.expand_dims(input, axis))

fill = utils.copy_docstring(
    tf.fill,
    lambda dims, value, name=None: value * np.ones(
        dims, np.array(value).dtype))

gather = utils.copy_docstring(
    tf.gather,
    _gather)

gather_nd = utils.copy_docstring(
    tf.gather_nd,
    _gather_nd)

reverse = utils.copy_docstring(tf.reverse, _reverse)
cholesky_solve = utils.copy_docstring(
    'tf.linalg.cholesky_solve',
    _cholesky_solve)

det = utils.copy_docstring(
    'tf.linalg.det',
    lambda input, name=None: np.linalg.det(input))

diag = utils.copy_docstring(
    'tf.linalg.diag',
    _diag)

diag_part = utils.copy_docstring(
    'tf.linalg.diag_part',
    lambda input, name=None: np.diagonal(  # pylint: disable=g-long-lambda
        ops.convert_to_tensor(input), axis1=-2, axis2=-1))

eig = utils.copy_docstring(
    'tf.linalg.eig',
    _eig)

eigh = utils.copy_docstring(
    'tf.linalg.eigh',
    lambda tensor, name=None: np.linalg.eigh(tensor))

eigvals = utils.copy_docstring(
    'tf.linalg.eigvals',
    _eigvals)

eigvalsh = utils.copy_docstring(
    'tf.linalg.eigvalsh',
    lambda tensor, name=None: np.linalg.eigvalsh(tensor))

einsum = utils.copy_docstring(
    'tf.linalg.einsum',
    _einsum)
def _rfftn(x, s, axes):
  x = ops.convert_to_tensor(x)
  complex_dtype = np.result_type(np.complex64, x.dtype)
  return np.fft.rfftn(x, s=s, axes=axes).astype(complex_dtype)
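# A NumPy round-trip sketch of the dtype handling in _rfftn/_irfftn above:
# np.fft defaults to 64-bit output, so the helpers cast back to a precision
# matching the input (complex64 for float32 input, and vice versa).
import numpy as np

x = np.random.randn(4, 4).astype(np.float32)
freq = np.fft.rfftn(x, axes=(-2, -1)).astype(np.complex64)
back = np.fft.irfftn(freq, s=x.shape, axes=(-2, -1)).astype(np.float32)
np.testing.assert_allclose(back, x, atol=1e-5)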
def _to_dense(self):
  reflection_axis = ops.convert_to_tensor(self.reflection_axis)
  normalized_axis = nn.l2_normalize(reflection_axis, axis=-1)
  mat = normalized_axis[..., _ops.newaxis]
  matrix = -2 * _linalg.matmul(mat, mat, adjoint_b=True)
  return _linalg.set_diag(matrix, 1. + _linalg.diag_part(matrix))
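# An illustrative NumPy check of the Householder construction above:
# H = I - 2 v v^H for a unit vector v is symmetric and orthogonal, and its
# diagonal is 1 - 2 v * conj(v), matching _diag_part.
import numpy as np

v = np.array([1., 2., 2.])
v = v / np.linalg.norm(v)
H = np.eye(3) - 2. * np.outer(v, v.conj())
np.testing.assert_allclose(H @ H.T, np.eye(3), atol=1e-12)
np.testing.assert_allclose(np.diag(H), 1. - 2. * v * v.conj())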
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
  """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

  The returned `Tensor` will be close to an exact solution if `A` is well
  conditioned. Otherwise closeness will vary. See class docstring for details.

  Examples:

  ```python
  # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
  operator = LinearOperator(...)
  operator.shape = [..., M, N]

  # Solve R > 0 linear systems for every member of the batch.
  RHS = ...  # shape [..., M, R]

  X = operator.solve(RHS)
  # X[..., :, r] is the solution to the r'th linear system
  # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

  operator.matmul(X)
  ==> RHS
  ```

  Args:
    rhs: `Tensor` with same `dtype` as this operator and compatible shape.
      `rhs` is treated like a [batch] matrix meaning for every set of leading
      dimensions, the last two dimensions defines a matrix.
      See class docstring for definition of compatibility.
    adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
      of this `LinearOperator`:  `A^H X = rhs`.
    adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
      is the hermitian transpose (transposition and complex conjugation).
    name:  A name scope to use for ops added by this method.

  Returns:
    `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

  Raises:
    NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
  """
  if self.is_non_singular is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "be singular.")
  if self.is_square is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "not be square.")
  if isinstance(rhs, LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = rhs.adjoint() if adjoint_arg else rhs

    if (right_operator.range_dimension is not None and
        left_operator.domain_dimension is not None and
        right_operator.range_dimension != left_operator.domain_dimension):
      raise ValueError(
          "Operators are incompatible. Expected `rhs` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))
    with self._name_scope(name):
      return linear_operator_algebra.solve(left_operator, right_operator)

  with self._name_scope(name):
    rhs = ops.convert_to_tensor(rhs, name="rhs")
    # self._check_input_dtype(rhs)

    self_dim = -1 if adjoint else -2
    arg_dim = -1 if adjoint_arg else -2
    tensor_shape.dimension_at_index(
        _ops.TensorShape(self.shape), self_dim).assert_is_compatible_with(
            _ops.TensorShape(rhs.shape)[arg_dim])

    return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
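# A NumPy sketch of the adjoint=True semantics documented above: the system
# solved is A^H X = rhs, i.e. an ordinary solve against the conjugate
# transpose of A.
import numpy as np

A = np.array([[2., 1.], [0., 3.]])
rhs = np.array([[5.], [9.]])
x = np.linalg.solve(A.conj().T, rhs)
np.testing.assert_allclose(A.conj().T @ x, rhs)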
def _eigvals(self):
  return ops.convert_to_tensor(self.spectrum)
def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor` if input is nonreference type.

  This function converts Python objects of various types to `Tensor` objects
  except if the input has nonreference semantics. Reference semantics are
  characterized by `is_ref` and is any object which is a
  `tf.Variable` or instance of `tf.Module`. This function accepts any input
  which `tf.convert_to_tensor` would also.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor, used when dtype
      is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so dtype_hint can be used as a soft preference.
      If the conversion to `dtype_hint` is not possible, this argument has no
      effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    tensor: A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.

  #### Examples:

  ```python
  x = tf.Variable(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = tf.constant(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = np.array(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> False
  tf.is_tensor(y)
  # ==> True

  x = tfp.util.DeferredTensor(13.37, lambda x: x)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True
  tf.is_tensor(y)
  # ==> False
  tf.equal(y, 13.37)
  # ==> True
  ```
  """
  # We explicitly do not use a tf.name_scope to avoid graph clutter.
  if value is None:
    return None
  if is_ref(value):
    if dtype is None:
      return value
    dtype_base = base_dtype(dtype)
    value_dtype_base = base_dtype(value.dtype)
    if dtype_base != value_dtype_base:
      raise TypeError(
          'Mutable type must be of dtype "{}" but is "{}".'.format(
              dtype_name(dtype_base), dtype_name(value_dtype_base)))
    return value
  return ops.convert_to_tensor(
      value, dtype=dtype, dtype_hint=dtype_hint, name=name)
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
  """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

  The returned `Tensor` will be close to an exact solution if `A` is well
  conditioned. Otherwise closeness will vary. See class docstring for details.

  Given the blockwise `n + 1`-by-`n + 1` linear operator:

  op = [[A_00  0     ...  0     ...  0],
        [A_10  A_11  ...  0     ...  0],
        ...
        [A_k0  A_k1  ...  A_kk  ...  0],
        ...
        [A_n0  A_n1  ...  A_nk  ...  A_nn]]

  we find `x = op.solve(y)` by observing that

  `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

  and therefore

  `x_k = A_kk.solve(y_k -
                    A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

  where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
  and `y` along their appropriate axes.

  We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
  for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.

  The adjoint case is solved similarly, beginning with
  `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.

  Examples:

  ```python
  # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
  operator = LinearOperator(...)
  operator.shape = [..., M, N]

  # Solve R > 0 linear systems for every member of the batch.
  RHS = ...  # shape [..., M, R]

  X = operator.solve(RHS)
  # X[..., :, r] is the solution to the r'th linear system
  # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

  operator.matmul(X)
  ==> RHS
  ```

  Args:
    rhs: `Tensor` with same `dtype` as this operator and compatible shape,
      or a list of `Tensor`s. `Tensor`s are treated like a [batch] matrices
      meaning for every set of leading dimensions, the last two dimensions
      defines a matrix.
      See class docstring for definition of compatibility.
    adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
      of this `LinearOperator`:  `A^H X = rhs`.
    adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
      is the hermitian transpose (transposition and complex conjugation).
    name:  A name scope to use for ops added by this method.

  Returns:
    `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

  Raises:
    NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
  """
  if self.is_non_singular is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "be singular.")
  if self.is_square is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "not be square.")
  if isinstance(rhs, linear_operator.LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = rhs.adjoint() if adjoint_arg else rhs

    if (right_operator.range_dimension is not None and
        left_operator.domain_dimension is not None and
        right_operator.range_dimension != left_operator.domain_dimension):
      raise ValueError(
          "Operators are incompatible. Expected `rhs` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))
    with self._name_scope(name):
      return linear_operator_algebra.solve(left_operator, right_operator)

  with self._name_scope(name):
    block_dimensions = (self._block_domain_dimensions() if adjoint
                        else self._block_range_dimensions())
    arg_dim = -1 if adjoint_arg else -2
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, rhs, arg_dim)
    if blockwise_arg:
      for i, block in enumerate(rhs):
        if not isinstance(block, linear_operator.LinearOperator):
          block = ops.convert_to_tensor(block)
          # self._check_input_dtype(block)
          block_dimensions[i].assert_is_compatible_with(
              tensor_shape.TensorShape(block.shape)[arg_dim])
          rhs[i] = block
      if adjoint_arg:
        split_rhs = [linalg.adjoint(y) for y in rhs]
      else:
        split_rhs = rhs
    else:
      rhs = ops.convert_to_tensor(rhs, name="rhs")
      # self._check_input_dtype(rhs)
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(
          tensor_shape.TensorShape(rhs.shape)[arg_dim])
      rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
      split_rhs = linear_operator_util.split_arg_into_blocks(
          self._block_domain_dimensions(),
          self._block_domain_dimension_tensors,
          rhs, axis=-2)

    solution_list = []
    if adjoint:
      # For an adjoint blockwise lower-triangular linear operator, the system
      # must be solved bottom to top. Iterate backwards over rows of the
      # adjoint (i.e. columns of the non-adjoint operator).
      for index in reversed(range(len(self.operators))):
        y = split_rhs[index]
        # Iterate top to bottom over the operators in the off-diagonal portion
        # of the column-partition (i.e. row-partition of the adjoint), apply
        # the operator to the respective block of the solution found in
        # previous iterations, and subtract the result from the `rhs` block.
        # For example, let `A`, `B`, and `D` be the linear operators in the
        # top row-partition of the adjoint of
        # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
        # and `x_1` and `x_2` be blocks of the solution found in previous
        # iterations of the outer loop. The following loop (when `index == 0`)
        # expresses
        # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
        # `y_0* = y_0 - Bx_1 - Dx_2`.
        for j in reversed(range(index + 1, len(self.operators))):
          y -= self.operators[j][index].matmul(
              solution_list[len(self.operators) - 1 - j],
              adjoint=adjoint)
        # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
        solution_list.append(
            self._diagonal_operators[index].solve(y, adjoint=adjoint))
      solution_list.reverse()
    else:
      # Iterate top to bottom over the row-partitions.
      for row, y in zip(self.operators, split_rhs):
        # Iterate left to right over the operators in the off-diagonal portion
        # of the row-partition, apply the operator to the block of the
        # solution found in previous iterations, and subtract the result from
        # the `rhs` block. For example, let `D`, `E`, and `F` be the linear
        # operators in the bottom row-partition of
        # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` and
        # `x_0` and `x_1` be blocks of the solution found in previous
        # iterations of the outer loop. The following loop (when `index == 2`)
        # expresses
        # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
        # `y_2* = y_2 - Dx_0 - Ex_1`.
        for i, operator in enumerate(row[:-1]):
          y -= operator.matmul(solution_list[i], adjoint=adjoint)
        # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
        solution_list.append(row[-1].solve(y, adjoint=adjoint))

    if blockwise_arg:
      return solution_list

    solution_list = linear_operator_util.broadcast_matrix_batch_dims(
        solution_list)
    return array_ops.concat(solution_list, axis=-2)
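# A NumPy sketch of the forward substitution described in the docstring
# above, specialized to 1x1 blocks (each "block solve" becomes a scalar
# division): x_k = (y_k - sum_{j<k} L[k, j] x_j) / L[k, k].
import numpy as np

L = np.array([[2., 0., 0.],
              [1., 3., 0.],
              [4., 5., 6.]])
y = np.array([2., 5., 32.])
x = np.zeros_like(y)
for k in range(3):
  x[k] = (y[k] - L[k, :k] @ x[:k]) / L[k, k]
np.testing.assert_allclose(L @ x, y)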
# --- Begin Public Functions --------------------------------------------------

concat = utils.copy_docstring(
    'tf.concat',
    _concat)

expand_dims = utils.copy_docstring(
    'tf.expand_dims',
    lambda input, axis, name=None: np.expand_dims(input, axis))

fill = utils.copy_docstring(
    'tf.fill',
    lambda dims, value, name=None: np.full(dims, ops.convert_to_tensor(value)))

gather = utils.copy_docstring(
    'tf.gather',
    _gather)

gather_nd = utils.copy_docstring(
    'tf.gather_nd',
    _gather_nd)

reverse = utils.copy_docstring('tf.reverse', _reverse)

linspace = utils.copy_docstring(
    'tf.linspace',
    _linspace)
cholesky = utils.copy_docstring(
    'tf.linalg.cholesky',
    lambda input, name=None: np.linalg.cholesky(input))

cholesky_solve = utils.copy_docstring(
    'tf.linalg.cholesky_solve',
    _cholesky_solve)

det = utils.copy_docstring(
    'tf.linalg.det',
    lambda input, name=None: np.linalg.det(input))

diag = utils.copy_docstring(
    'tf.linalg.diag',
    _diag)

diag_part = utils.copy_docstring(
    'tf.linalg.diag_part',
    lambda input, name=None: np.diagonal(  # pylint: disable=g-long-lambda
        ops.convert_to_tensor(input), axis1=-2, axis2=-1))

eig = utils.copy_docstring('tf.linalg.eig', _eig)

eigh = utils.copy_docstring(
    'tf.linalg.eigh',
    lambda tensor, name=None: np.linalg.eigh(tensor))

eigvals = utils.copy_docstring('tf.linalg.eigvals', _eigvals)

eigvalsh = utils.copy_docstring(
    'tf.linalg.eigvalsh',
    lambda tensor, name=None: np.linalg.eigvalsh(tensor))

einsum = utils.copy_docstring('tf.linalg.einsum', _einsum)
def _size(input, out_type=np.int32, name=None):  # pylint: disable=redefined-builtin,unused-argument
  return np.asarray(
      onp.prod(ops.convert_to_tensor(input).shape), dtype=out_type)
def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
  """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

  ```python
  # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
  operator = LinearOperator(...)
  operator.shape = [..., M, N]

  X = ...  # shape [..., N, R], batch matrix, R > 0.

  Y = operator.matmul(X)
  Y.shape
  ==> [..., M, R]

  Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
  ```

  Args:
    x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
      `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
      class docstring for definition of shape compatibility.
    adjoint: Python `bool`.  If `True`, left multiply by the adjoint:
      `A^H x`.
    adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
      the hermitian transpose (transposition and complex conjugation).
    name:  A name for this `Op`.

  Returns:
    A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
      as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
      concatenate to `[..., M, R]`.
  """
  if isinstance(x, linear_operator.LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = x.adjoint() if adjoint_arg else x

    if (right_operator.range_dimension is not None and
        left_operator.domain_dimension is not None and
        right_operator.range_dimension != left_operator.domain_dimension):
      raise ValueError(
          "Operators are incompatible. Expected `x` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))
    with self._name_scope(name):
      return linear_operator_algebra.matmul(left_operator, right_operator)

  with self._name_scope(name):
    arg_dim = -1 if adjoint_arg else -2
    block_dimensions = (self._block_range_dimensions() if adjoint
                        else self._block_domain_dimensions())
    if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
      for i, block in enumerate(x):
        if not isinstance(block, linear_operator.LinearOperator):
          block = ops.convert_to_tensor(block)
          # self._check_input_dtype(block)
          block_dimensions[i].assert_is_compatible_with(
              tensor_shape.TensorShape(block.shape)[arg_dim])
          x[i] = block
    else:
      x = ops.convert_to_tensor(x, name="x")
      # self._check_input_dtype(x)
      op_dimension = (self.range_dimension if adjoint
                      else self.domain_dimension)
      op_dimension.assert_is_compatible_with(
          tensor_shape.TensorShape(x.shape)[arg_dim])
    return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
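# A NumPy sketch of the blockwise argument handling above for a
# block-structured operator: applying each block to its slice of x and
# concatenating matches a dense matmul against the assembled matrix.
import numpy as np

A = np.array([[2., 1.], [0., 3.]])
B = np.array([[4.]])
full = np.block([[A, np.zeros((2, 1))],
                 [np.zeros((1, 2)), B]])
x = np.array([[1.], [2.], [3.]])
y = np.concatenate([A @ x[:2], B @ x[2:]], axis=-2)
np.testing.assert_allclose(full @ x, y)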