def cholesky_concat(chol, cols, name=None):
  """Concatenates `chol @ chol.T` with additional rows and columns.

  This operation is conceptually identical to:
  ```python
  def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
    mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
    # Concat columns.
    mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
    # Concat rows.
    mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
    return tf.linalg.cholesky(mat)
  ```
  but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
  `cholesky_concat` only costs `O(z**2 + m**3)` work.

  The resulting (implicit) matrix must be symmetric and positive definite.
  Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
  separate `rows` argument (which can be inferred from `conj(cols.T)`).

  Args:
    chol: Cholesky decomposition of `mat = chol @ chol.T`.
    cols: The new columns whose first `n` rows we would like concatenated to
      the right of `mat = chol @ chol.T`, and whose conjugate transpose we
      would like concatenated to the bottom of `concat(mat, cols[:n,:])`. A
      `Tensor` with final dims `(n+m, m)`. The first `n` rows are the top
      right rectangle (their conjugate transpose forms the bottom left), and
      the bottom `m x m` is self-adjoint.
    name: Optional name for this op.

  Returns:
    chol_concat: The Cholesky decomposition of:

      ```
      [ [ mat          cols[:n, :] ]
        [ conj(cols.T)             ] ]
      ```
  """
  with tf.compat.v2.name_scope(name or 'cholesky_extend'):
    dtype = dtype_util.common_dtype([chol, cols], preferred_dtype=tf.float32)
    chol = tf.convert_to_tensor(value=chol, name='chol', dtype=dtype)
    cols = tf.convert_to_tensor(value=cols, name='cols', dtype=dtype)
    n = prefer_static.shape(chol)[-1]
    mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
    solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
        chol, mat_nm)
    lower_right_mm = tf.linalg.cholesky(
        mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
    lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
    out_batch = prefer_static.shape(solved_nm)[:-2]
    chol = tf.broadcast_to(
        chol,
        tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
    top_right_zeros_nm = tf.zeros_like(solved_nm)
    return tf.concat([
        tf.concat([chol, top_right_zeros_nm], axis=-1),
        tf.concat([lower_left_mn, lower_right_mm], axis=-1)
    ], axis=-2)
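A minimal usage sketch (illustrative only, not part of the library source): it assumes eager execution and that the function above is exported as `tfp.math.cholesky_concat`, and checks the fast path against factoring the full matrix directly, as the docstring's `cholesky_concat_slow` would.

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp  # assumed export: tfp.math.cholesky_concat

chol = tf.linalg.cholesky([[4., 0.], [0., 9.]])  # Cholesky of an n x n matrix, n = 2.
cols = tf.constant([[1.], [0.], [5.]])           # (n + m) x m new columns, m = 1.
fast = tfp.math.cholesky_concat(chol, cols)

# Reference: build the full (n + m) x (n + m) matrix and factor it directly.
full = tf.constant([[4., 0., 1.],
                    [0., 9., 0.],
                    [1., 0., 5.]])
np.testing.assert_allclose(fast, tf.linalg.cholesky(full), atol=1e-5)
```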
def test_static_dims_broadcast_matrix_has_extra_dims(self):
  # batch_shape = [2]
  matrix = rng.rand(2, 3, 3)
  rhs = rng.rand(3, 7)
  rhs_broadcast = rhs + np.zeros((2, 1, 1))

  result = linear_operator_util.matrix_triangular_solve_with_broadcast(
      matrix, rhs)
  self.assertAllEqual((2, 3, 7), result.shape)
  expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
  self.assertAllClose(*self.evaluate([expected, result]))
def test_static_dims_broadcast(self):
  # batch_shape = [2]
  matrix = rng.rand(2, 3, 3)
  rhs = rng.rand(3, 7)
  rhs_broadcast = rhs + np.zeros((2, 1, 1))

  with self.cached_session():
    result = linear_operator_util.matrix_triangular_solve_with_broadcast(
        matrix, rhs)
    self.assertAllEqual((2, 3, 7), result.get_shape())
    expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
    self.assertAllEqual(expected.eval(), result.eval())
def test_dynamic_dims_broadcast_64bit(self):
  # batch_shape = [2]
  matrix = rng.rand(2, 3, 3)
  rhs = rng.rand(3, 7)
  rhs_broadcast = rhs + np.zeros((2, 1, 1))

  matrix_ph = array_ops.placeholder_with_default(matrix, shape=None)
  rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)

  result, expected = self.evaluate([
      linear_operator_util.matrix_triangular_solve_with_broadcast(
          matrix_ph, rhs_ph),
      linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
  ])
  self.assertAllClose(expected, result)
def test_static_dims_broadcast_rhs_has_extra_dims(self):
  # Since the second arg has extra dims, and the domain dim of the first arg
  # is larger than the number of linear equations, code will "flip" the extra
  # dims of the first arg to the far right, making extra linear equations
  # (then call the matrix function, then flip back).
  # We have verified that this optimization indeed happens. How? We stepped
  # through with a debugger.
  # batch_shape = [2]
  matrix = rng.rand(3, 3)
  rhs = rng.rand(2, 3, 2)
  matrix_broadcast = matrix + np.zeros((2, 1, 1))

  result = linear_operator_util.matrix_triangular_solve_with_broadcast(
      matrix, rhs)
  self.assertAllEqual((2, 3, 2), result.shape)
  expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
  self.assertAllClose(*self.evaluate([expected, result]))
def test_static_dims_broadcast_rhs_has_extra_dims(self):
  # Since the second arg has extra dims, and the domain dim of the first arg
  # is larger than the number of linear equations, code will "flip" the extra
  # dims of the first arg to the far right, making extra linear equations
  # (then call the matrix function, then flip back).
  # We have verified that this optimization indeed happens. How? We stepped
  # through with a debugger.
  # batch_shape = [2]
  matrix = rng.rand(3, 3)
  rhs = rng.rand(2, 3, 2)
  matrix_broadcast = matrix + np.zeros((2, 1, 1))

  with self.cached_session():
    result = linear_operator_util.matrix_triangular_solve_with_broadcast(
        matrix, rhs)
    self.assertAllEqual((2, 3, 2), result.get_shape())
    expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
    self.assertAllClose(expected.eval(), result.eval())
def test_dynamic_dims_broadcast_64bit(self):
  # batch_shape = [2]
  matrix = rng.rand(2, 3, 3)
  rhs = rng.rand(3, 7)
  rhs_broadcast = rhs + np.zeros((2, 1, 1))

  matrix_ph = array_ops.placeholder(dtypes.float64)
  rhs_ph = array_ops.placeholder(dtypes.float64)

  with self.cached_session() as sess:
    result, expected = sess.run(
        [
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                matrix_ph, rhs_ph),
            linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
        ],
        feed_dict={
            matrix_ph: matrix,
            rhs_ph: rhs,
        })
    self.assertAllClose(expected, result)
def test_dynamic_dims_broadcast_64bit(self):
  # batch_shape = [2]
  matrix = rng.rand(2, 3, 3)
  rhs = rng.rand(3, 7)
  rhs_broadcast = rhs + np.zeros((2, 1, 1))

  matrix_ph = array_ops.placeholder(dtypes.float64)
  rhs_ph = array_ops.placeholder(dtypes.float64)

  with self.cached_session() as sess:
    result, expected = sess.run(
        [
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                matrix_ph, rhs_ph),
            linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
        ],
        feed_dict={
            matrix_ph: matrix,
            rhs_ph: rhs,
        })
    self.assertAllEqual(expected, result)
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually
  invertible, nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness. Note: this function does not verify the implied
      matrix is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```
  """
  with tf.name_scope(name or 'lu_solve'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
    rhs = tf.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)
        rhs = tf.identity(rhs)

    if (tensorshape_util.rank(rhs.shape) == 2 and
        tensorshape_util.rank(perm.shape) == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = tf.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = tf.shape(rhs)
      broadcast_batch_shape = tf.broadcast_dynamic_shape(
          rhs_shape[:-2], tf.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]], axis=0)

      # Tile out rhs.
      broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = tf.broadcast_to(
          tf.range(broadcast_batch_size)[:, tf.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = tf.stack(
          [broadcast_batch_indices, broadcast_perm], axis=-1)

      permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return linear_operator_util.matrix_triangular_solve_with_broadcast(
        lower_upper,  # Only upper is accessed.
        linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower, permuted_rhs),
        lower=False)
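The docstring example above covers matrix inversion (`rhs = tf.eye(2)`). The sketch below is illustrative only (it assumes eager execution and the `tfp.math.lu_solve` export used in that example) and shows the general right-hand-side case, checking against `tf.linalg.solve`.

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

a = tf.constant([[1., 2.], [3., 4.]])
rhs = tf.constant([[1., 0., 2.],
                   [3., -1., 0.]])  # n x m targets, m = 3.
lower_upper, perm = tf.linalg.lu(a)
x = tfp.math.lu_solve(lower_upper, perm, rhs)
np.testing.assert_allclose(x, tf.linalg.solve(a, rhs), atol=1e-5)
```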
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  return linear_operator_util.matrix_triangular_solve_with_broadcast(
      self._tril, rhs, lower=True, adjoint=adjoint)
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually
  invertible, nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: `Tensor` representing targets for which to solve; `A X = RHS`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness. Note: this function does not verify the implied
      matrix is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., "lu_solve").

  Returns:
    x: The x in `A X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  inv_x = tfp.math.lu_solve(
      *tf.linalg.lu(x), rhs=tf.eye(2, batch_shape=[2]))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```
  """
  with tf.name_scope(name, 'lu_solve', [lower_upper, perm, rhs]):
    lower_upper = tf.convert_to_tensor(
        lower_upper, preferred_dtype=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, preferred_dtype=tf.int32, name='perm')
    rhs = tf.convert_to_tensor(
        rhs, preferred_dtype=lower_upper.dtype, name='rhs')

    assertions = _lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)
        rhs = tf.identity(rhs)

    shape = tf.shape(lower_upper)
    d = shape[-1]

    lower = tf.linalg.set_diag(
        tf.matrix_band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(shape[:-1], dtype=lower_upper.dtype))
    x = linear_operator_util.matrix_triangular_solve_with_broadcast(
        lower_upper,  # Only upper is accessed.
        linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower, rhs),
        lower=False)

    if lower_upper.shape.ndims is None or lower_upper.shape.ndims != 2:
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = tf.reduce_prod(shape[:-2])
      x = tf.reshape(x, [batch_size, d, d])
      perm = tf.reshape(perm, [batch_size, d])
      batch_indices = tf.broadcast_to(
          tf.range(batch_size)[:, tf.newaxis],
          [batch_size, d])
      x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
      x = tf.reshape(x, shape)
    else:
      x = tf.gather(x, perm, axis=-1)

    x.set_shape(lower_upper.shape)
    return x
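To make the `lower_upper = L + U - eye` packing concrete, here is a small illustrative sketch (plain TensorFlow ops, not library code) of how the unit-diagonal `L` and the `U` factor are unpacked from the combined tensor, mirroring the `band_part`/`set_diag` step in the body above.

```python
import tensorflow as tf

# A packed factorization in the L + U - eye layout described above.
lower_upper = tf.constant([[4., 5., 6.],
                           [1., 7., 8.],
                           [2., 3., 9.]])
lower = tf.linalg.set_diag(
    tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
    tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
# lower == [[1., 0., 0.], [1., 1., 0.], [2., 3., 1.]]  (unit diagonal)
# upper == [[4., 5., 6.], [0., 7., 8.], [0., 0., 9.]]  (holds the packed diagonal)
```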
def _inverse(self, y):
  return linalg_util.matrix_triangular_solve_with_broadcast(
      matrix=self._R,
      rhs=self._Q_inv_operator.matvec(y)[..., tf.newaxis],
      lower=False,
      adjoint=False)[..., 0]
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  return linear_operator_util.matrix_triangular_solve_with_broadcast(
      array_ops.matrix_set_diag(self._tril, math_ops.exp(self._diag)),
      rhs,
      lower=True,
      adjoint=adjoint)