Example #1
def cholesky_concat(chol, cols, name=None):
    """Concatenates `chol @ chol.T` with additional rows and columns.

  This operation is conceptually identical to:
  ```python
  def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
    mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
    # Concat columns.
    mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
    # Concat rows.
    mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
    return tf.linalg.cholesky(mat)
  ```
  but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
  `cholesky_concat` only costs `O(z**2 + m**3)` work.

  The resulting (implicit) matrix must be symmetric and positive definite.
  Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
  separate `rows` argument (which can be inferred from `conj(cols.T)`).

  Args:
    chol: Cholesky decomposition of `mat = chol @ chol.T`.
    cols: The new columns whose first `n` rows we would like concatenated to the
      right of `mat = chol @ chol.T`, and whose conjugate transpose we would
      like concatenated to the bottom of `concat(mat, cols[:n,:])`. A `Tensor`
      with final dims `(n+m, m)`. The first `n` rows are the top right rectangle
      (their conjugate transpose forms the bottom left), and the bottom `m x m`
      is self-adjoint.
    name: Optional name for this op.

  Returns:
    chol_concat: The Cholesky decomposition of:
      ```
      [ [ mat  cols[:n, :] ]
        [   conj(cols.T)   ] ]
      ```
  """
    with tf.compat.v2.name_scope(name or 'cholesky_extend'):
        dtype = dtype_util.common_dtype([chol, cols],
                                        preferred_dtype=tf.float32)
        chol = tf.convert_to_tensor(value=chol, name='chol', dtype=dtype)
        cols = tf.convert_to_tensor(value=cols, name='cols', dtype=dtype)
        n = prefer_static.shape(chol)[-1]
        mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
        solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
            chol, mat_nm)
        lower_right_mm = tf.linalg.cholesky(
            mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
        lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
        out_batch = prefer_static.shape(solved_nm)[:-2]
        chol = tf.broadcast_to(
            chol,
            tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
        top_right_zeros_nm = tf.zeros_like(solved_nm)
        return tf.concat([
            tf.concat([chol, top_right_zeros_nm], axis=-1),
            tf.concat([lower_left_mn, lower_right_mm], axis=-1)
        ],
                         axis=-2)
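For reference, a minimal usage sketch (not part of the library source; it assumes the function above is exposed as `tfp.math.cholesky_concat`): extend the Cholesky factor of a 2 x 2 SPD matrix by one column and compare against re-factorizing the full 3 x 3 matrix.

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

mat = np.array([[4., 2.], [2., 3.]], dtype=np.float32)    # n x n, symmetric positive definite
chol = tf.linalg.cholesky(mat)                             # its Cholesky factor
cols = np.array([[1.], [0.5], [5.]], dtype=np.float32)     # (n + m) x m with n = 2, m = 1

# O(z**2 + m**3) update of the existing factor ...
fast = tfp.math.cholesky_concat(chol, cols)

# ... versus O(z**3) re-factorization of the full (n + m) x (n + m) matrix.
full = np.block([[mat, cols[:2]], [cols[:2].T, cols[2:]]])
slow = tf.linalg.cholesky(full)
np.testing.assert_allclose(fast.numpy(), slow.numpy(), rtol=1e-5, atol=1e-6)
```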
Example #2
    def test_static_dims_broadcast_matrix_has_extra_dims(self):
        # batch_shape = [2]
        matrix = rng.rand(2, 3, 3)
        rhs = rng.rand(3, 7)
        rhs_broadcast = rhs + np.zeros((2, 1, 1))

        result = linear_operator_util.matrix_triangular_solve_with_broadcast(
            matrix, rhs)
        self.assertAllEqual((2, 3, 7), result.shape)
        expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
        self.assertAllClose(*self.evaluate([expected, result]))
Example #3
  def test_static_dims_broadcast(self):
    # batch_shape = [2]
    matrix = rng.rand(2, 3, 3)
    rhs = rng.rand(3, 7)
    rhs_broadcast = rhs + np.zeros((2, 1, 1))

    with self.cached_session():
      result = linear_operator_util.matrix_triangular_solve_with_broadcast(
          matrix, rhs)
      self.assertAllEqual((2, 3, 7), result.get_shape())
      expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
      self.assertAllEqual(expected.eval(), result.eval())
Example #4
    def test_dynamic_dims_broadcast_64bit(self):
        # batch_shape = [2]
        matrix = rng.rand(2, 3, 3)
        rhs = rng.rand(3, 7)
        rhs_broadcast = rhs + np.zeros((2, 1, 1))

        matrix_ph = array_ops.placeholder_with_default(matrix, shape=None)
        rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)

        result, expected = self.evaluate([
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                matrix_ph, rhs_ph),
            linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
        ])
        self.assertAllClose(expected, result)
Example #5
    def test_static_dims_broadcast_rhs_has_extra_dims(self):
        # Since the second arg has extra dims, and the domain dim of the first arg
        # is larger than the number of linear equations, code will "flip" the extra
        # dims of the first arg to the far right, making extra linear equations
        # (then call the matrix function, then flip back).
        # We have verified that this optimization indeed happens.  How? We stepped
        # through with a debugger.
        # batch_shape = [2]
        matrix = rng.rand(3, 3)
        rhs = rng.rand(2, 3, 2)
        matrix_broadcast = matrix + np.zeros((2, 1, 1))

        result = linear_operator_util.matrix_triangular_solve_with_broadcast(
            matrix, rhs)
        self.assertAllEqual((2, 3, 2), result.shape)
        expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
        self.assertAllClose(*self.evaluate([expected, result]))
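The "flip" the comment above refers to can be illustrated in isolation. This is a sketch of the idea only (it does not reproduce the `linear_operator_util` internals): when only the right-hand side carries batch dimensions, they can be folded into extra columns so a single unbatched triangular solve handles every batch member.

```python
import numpy as np
import tensorflow as tf

matrix = np.tril(np.random.rand(3, 3) + np.eye(3))   # one 3 x 3 lower-triangular system
rhs = np.random.rand(2, 3, 2)                         # batch of 2 right-hand sides

# Move the batch dim to the far right and merge it into the column dim ...
flipped = tf.reshape(tf.transpose(rhs, [1, 2, 0]), [3, 4])
solved = tf.linalg.triangular_solve(matrix, flipped, lower=True)
# ... then undo the flip.
result = tf.transpose(tf.reshape(solved, [3, 2, 2]), [2, 0, 1])

expected = tf.linalg.triangular_solve(matrix + np.zeros((2, 1, 1)), rhs, lower=True)
np.testing.assert_allclose(result.numpy(), expected.numpy(), rtol=1e-10)
```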
Example #6
  def test_static_dims_broadcast_rhs_has_extra_dims(self):
    # Since the second arg has extra dims, and the domain dim of the first arg
    # is larger than the number of linear equations, code will "flip" the extra
    # dims of the first arg to the far right, making extra linear equations
    # (then call the matrix function, then flip back).
    # We have verified that this optimization indeed happens.  How? We stepped
    # through with a debugger.
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 2)
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    with self.cached_session():
      result = linear_operator_util.matrix_triangular_solve_with_broadcast(
          matrix, rhs)
      self.assertAllEqual((2, 3, 2), result.get_shape())
      expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
      self.assertAllClose(expected.eval(), result.eval())
Example #7
    def test_dynamic_dims_broadcast_64bit(self):
        # batch_shape = [2]
        matrix = rng.rand(2, 3, 3)
        rhs = rng.rand(3, 7)
        rhs_broadcast = rhs + np.zeros((2, 1, 1))

        matrix_ph = array_ops.placeholder(dtypes.float64)
        rhs_ph = array_ops.placeholder(dtypes.float64)

        with self.cached_session() as sess:
            result, expected = sess.run([
                linear_operator_util.matrix_triangular_solve_with_broadcast(
                    matrix_ph, rhs_ph),
                linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
            ],
                                        feed_dict={
                                            matrix_ph: matrix,
                                            rhs_ph: rhs,
                                        })
            self.assertAllClose(expected, result)
Example #8
  def test_dynamic_dims_broadcast_64bit(self):
    # batch_shape = [2]
    matrix = rng.rand(2, 3, 3)
    rhs = rng.rand(3, 7)
    rhs_broadcast = rhs + np.zeros((2, 1, 1))

    matrix_ph = array_ops.placeholder(dtypes.float64)
    rhs_ph = array_ops.placeholder(dtypes.float64)

    with self.cached_session() as sess:
      result, expected = sess.run(
          [
              linear_operator_util.matrix_triangular_solve_with_broadcast(
                  matrix_ph, rhs_ph),
              linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
          ],
          feed_dict={
              matrix_ph: matrix,
              rhs_ph: rhs,
          })
      self.assertAllEqual(expected, result)
Example #9
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.debugging.assert_near(tf.linalg.inv(x), inv_x)
  # ==> no exception raised: `inv_x` matches `tf.linalg.inv(x)`.
  ```

  """

    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   dtype_hint=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_solve_assertions(lower_upper, perm, rhs,
                                          validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        if (tensorshape_util.rank(rhs.shape) == 2
                and tensorshape_util.rank(perm.shape) == 1):
            # Both rhs and perm have scalar batch_shape.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm have non-scalar batch_shape or we can't determine
            # this information statically.
            rhs_shape = tf.shape(rhs)
            broadcast_batch_shape = tf.broadcast_dynamic_shape(
                rhs_shape[:-2],
                tf.shape(perm)[:-1])
            d, m = rhs_shape[-2], rhs_shape[-1]
            rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]],
                                            axis=0)

            # Tile out rhs.
            broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
            broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

            # Tile out perm and add batch indices.
            broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
            broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
            broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
            broadcast_batch_indices = tf.broadcast_to(
                tf.range(broadcast_batch_size)[:, tf.newaxis],
                [broadcast_batch_size, d])
            broadcast_perm = tf.stack(
                [broadcast_batch_indices, broadcast_perm], axis=-1)

            permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
            permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
        return linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower_upper,  # Only upper is accessed.
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                lower, permuted_rhs),
            lower=False)
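The trickiest part of the body above is the batched row permutation built from `perm` plus explicit batch indices for `tf.gather_nd`. A minimal sketch with toy shapes (an illustration, not code from the library; newer TF can also express it with `batch_dims`):

```python
import numpy as np
import tensorflow as tf

rhs = tf.constant(np.arange(12., dtype=np.float32).reshape(2, 3, 2))  # batch=2, d=3, m=2
perm = tf.constant([[2, 0, 1],
                    [1, 2, 0]])                                        # batch=2, d=3

batch, d = 2, 3
batch_indices = tf.broadcast_to(tf.range(batch)[:, tf.newaxis], [batch, d])
indices = tf.stack([batch_indices, perm], axis=-1)   # (batch, d, 2) index pairs
permuted = tf.gather_nd(rhs, indices)                # rows of each batch member reordered

# Equivalent, more modern spelling:
same = tf.gather(rhs, perm, axis=-2, batch_dims=1)
np.testing.assert_array_equal(permuted.numpy(), same.numpy())
```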
Example #10
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
     return linear_operator_util.matrix_triangular_solve_with_broadcast(
         self._tril, rhs, lower=True, adjoint=adjoint)
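This `_solve` appears to come from a lower-triangular `LinearOperator` (an assumption based on the `self._tril` attribute). Through the public API, the same computation looks roughly like:

```python
import numpy as np
import tensorflow as tf

tril = np.tril(np.random.rand(2, 3, 3) + np.eye(3)).astype(np.float32)  # batch of lower-triangular factors
rhs = (np.random.rand(3, 4) + np.zeros((2, 1, 1))).astype(np.float32)   # rhs broadcast against the batch

op = tf.linalg.LinearOperatorLowerTriangular(tril)
x = op.solve(rhs)
expected = tf.linalg.triangular_solve(tril, rhs, lower=True)
np.testing.assert_allclose(x.numpy(), expected.numpy(), rtol=1e-5)
```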
Example #11
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: `Tensor` representing targets for which to solve; `A X = RHS`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., "lu_solve").

  Returns:
    x: The `X` in `A X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  inv_x = tfp.math.lu_solve(
    *tf.linalg.lu(x), rhs=tf.eye(2, batch_shape=[2]))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """

    with tf.name_scope(name, 'lu_solve', [lower_upper, perm, rhs]):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           preferred_dtype=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm,
                                    preferred_dtype=tf.int32,
                                    name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   preferred_dtype=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_reconstruct_assertions(lower_upper, perm,
                                                validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        shape = tf.shape(lower_upper)

        d = shape[-1]
        lower = tf.linalg.set_diag(
            tf.matrix_band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(shape[:-1], dtype=lower_upper.dtype))
        x = linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower_upper,  # Only upper is accessed.
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                lower, rhs),
            lower=False)

        if lower_upper.shape.ndims is None or lower_upper.shape.ndims != 2:
            # We either don't know the batch rank or there are >0 batch dims.
            batch_size = tf.reduce_prod(shape[:-2])
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            batch_indices = tf.broadcast_to(
                tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
            x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
            x = tf.reshape(x, shape)
        else:
            x = tf.gather(x, perm, axis=-1)

        x.set_shape(lower_upper.shape)
        return x
Example #12
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
   return linear_operator_util.matrix_triangular_solve_with_broadcast(
       self._tril, rhs, lower=True, adjoint=adjoint)
Example #13
 def _inverse(self, y):
     return linalg_util.matrix_triangular_solve_with_broadcast(
         matrix=self._R,
         rhs=self._Q_inv_operator.matvec(y)[..., tf.newaxis],
         lower=False,
         adjoint=False)[..., 0]
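This `_inverse` seems to undo `y = Q R x` by first applying the stored `Q`-inverse operator and then back-substituting against `R` (an inference from the attribute names). The same idea with plain TF ops, for a single well-conditioned matrix:

```python
import numpy as np
import tensorflow as tf

A = (np.random.rand(4, 4) + 4. * np.eye(4)).astype(np.float32)
q, r = tf.linalg.qr(A)                     # A = Q R
y = tf.constant(np.random.rand(4).astype(np.float32))

# x solving A x = y, computed as R^{-1} (Q^T y).
qty = tf.linalg.matvec(q, y, adjoint_a=True)
x = tf.linalg.triangular_solve(r, qty[..., tf.newaxis], lower=False)[..., 0]
np.testing.assert_allclose(tf.linalg.matvec(A, x).numpy(), y.numpy(), rtol=1e-4, atol=1e-4)
```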
Example #14
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
     return linear_operator_util.matrix_triangular_solve_with_broadcast(
         array_ops.matrix_set_diag(self._tril, math_ops.exp(self._diag)),
         rhs, lower=True, adjoint=adjoint)
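Here the diagonal of the triangular factor is stored separately and exponentiated before solving, presumably (an assumption) to keep it strictly positive and hence the factor invertible. A tiny sketch of that parameterization:

```python
import numpy as np
import tensorflow as tf

tril = np.tril(np.random.rand(3, 3), k=-1).astype(np.float32)   # strictly lower-triangular part
log_diag = np.random.randn(3).astype(np.float32)                 # unconstrained diagonal parameters
rhs = np.random.rand(3, 2).astype(np.float32)

scale = tf.linalg.set_diag(tril, tf.exp(log_diag))   # diagonal is exp(.) > 0, so the factor is invertible
x = tf.linalg.triangular_solve(scale, rhs, lower=True)
np.testing.assert_allclose(tf.matmul(scale, x).numpy(), rhs, rtol=1e-4, atol=1e-5)
```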