def _batch_shape_tensor(self, loc=None, scale=None, concentration=None):
     return functools.reduce(
         prefer_static.broadcast_shape,
         (prefer_static.shape(self.loc if loc is None else loc),
          prefer_static.shape(self.scale if scale is None else scale),
          prefer_static.shape(self.concentration
                              if concentration is None else concentration)))
 def _batch_shape_tensor(self, logits_or_probs=None, total_count=None):
     if logits_or_probs is None:
         logits_or_probs = self._logits if self._probs is None else self._probs
     total_count = self._total_count if total_count is None else total_count
     return prefer_static.broadcast_shape(
         prefer_static.shape(logits_or_probs),
         prefer_static.shape(total_count))
Example #3
def _swap_m_with_i(vecs, m, i):
    """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly shaped
  per-vector indices `i`, this function swaps elements `m` and `i` in each
  vector. For the use-case below, these are permutation vectors.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
    vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
    m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
    i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
    trailing_elts = tf.broadcast_to(
        tf.range(m + 1,
                 prefer_static.shape(vecs, out_type=tf.int64)[-1]),
        prefer_static.shape(vecs[..., m + 1:]))
    trailing_elts = tf.where(tf.equal(trailing_elts, i),
                             tf.gather(vecs, [m], axis=-1), vecs[..., m + 1:])
    # TODO(bjp): Could we use tensor_scatter_nd_update?
    vecs_shape = vecs.shape
    vecs = tf.concat([
        vecs[..., :m],
        tf.gather(vecs, i, batch_dims=int(prefer_static.rank(vecs)) - 1),
        trailing_elts
    ],
                     axis=-1)
    tensorshape_util.set_shape(vecs, vecs_shape)
    return vecs
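
# A minimal NumPy sketch of the swap semantics above (hypothetical
# `swap_m_with_i_np`, not the TF helper): `m` is shared across the batch while
# `i` picks a per-vector position, as when pivoting permutation vectors.
import numpy as np

def swap_m_with_i_np(vecs, m, i):
  out = np.array(vecs, dtype=np.int64)          # copy so the input is untouched
  for b in range(out.shape[0]):
    j = int(i[b, 0])
    out[b, m], out[b, j] = out[b, j], out[b, m]
  return out

perms = np.array([[0, 1, 2, 3],
                  [3, 2, 1, 0]])
print(swap_m_with_i_np(perms, m=1, i=np.array([[3], [2]])))
# ==> [[0 3 2 1]
#      [3 1 2 0]]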
Example #4
 def _batch_shape_tensor(self, temperature=None, logits=None):
     param = logits
     if param is None:
         param = self._logits if self._logits is not None else self._probs
     if temperature is None:
         temperature = self.temperature
     return prefer_static.broadcast_shape(prefer_static.shape(temperature),
                                          prefer_static.shape(param)[:-1])
Example #5
def cholesky_concat(chol, cols, name=None):
    """Concatenates `chol @ chol.T` with additional rows and columns.

  This operation is conceptually identical to:
  ```python
  def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
    mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
    # Concat columns.
    mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
    # Concat rows.
    mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
    return tf.linalg.cholesky(mat)
  ```
  but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
  `cholesky_concat` only costs `O(z**2 + m**3)` work.

  The resulting (implicit) matrix must be symmetric and positive definite.
  Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
  separate `rows` argument (which can be inferred from `conj(cols.T)`).

  Args:
    chol: Cholesky decomposition of `mat = chol @ chol.T`.
    cols: The new columns whose first `n` rows we would like concatenated to the
      right of `mat = chol @ chol.T`, and whose conjugate transpose we would
      like concatenated to the bottom of `concat(mat, cols[:n,:])`. A `Tensor`
      with final dims `(n+m, m)`. The first `n` rows are the top right rectangle
      (their conjugate transpose forms the bottom left), and the bottom `m x m`
      is self-adjoint.
    name: Optional name for this op.

  Returns:
    chol_concat: The Cholesky decomposition of:
      ```
      [ [ mat  cols[:n, :] ]
        [   conj(cols.T)   ] ]
      ```
  """
    with tf.name_scope(name or 'cholesky_extend'):
        dtype = dtype_util.common_dtype([chol, cols], dtype_hint=tf.float32)
        chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
        cols = tf.convert_to_tensor(cols, name='cols', dtype=dtype)
        n = prefer_static.shape(chol)[-1]
        mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
        solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
            chol, mat_nm)
        lower_right_mm = tf.linalg.cholesky(
            mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
        lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
        out_batch = prefer_static.shape(solved_nm)[:-2]
        chol = tf.broadcast_to(
            chol,
            tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
        top_right_zeros_nm = tf.zeros_like(solved_nm)
        return tf.concat([
            tf.concat([chol, top_right_zeros_nm], axis=-1),
            tf.concat([lower_left_mn, lower_right_mm], axis=-1)
        ],
                         axis=-2)
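
# A self-contained NumPy sketch of the block algebra `cholesky_concat` relies
# on (illustration only, with made-up sizes n = 3, m = 1 and a real dtype, so
# the conjugations drop out): given chol(A) and new columns [B; C], the
# extended factor is [[chol(A), 0], [B^T chol(A)^{-T}, chol(C - B^T A^{-1} B)]].
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
mat = a @ a.T + 4. * np.eye(4)                  # symmetric positive definite
n, m = 3, 1
chol = np.linalg.cholesky(mat[:n, :n])          # chol(A), the existing factor
cols = mat[:, n:]                               # shape (n + m, m)
solved_nm = np.linalg.solve(chol, cols[:n])     # chol(A)^{-1} @ B
lower_right = np.linalg.cholesky(cols[n:] - solved_nm.T @ solved_nm)
chol_concat = np.block([[chol, np.zeros((n, m))],
                        [solved_nm.T, lower_right]])
print(np.allclose(chol_concat @ chol_concat.T, mat))   # ==> True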
Example #6
 def _reshape_part(part, event_shape):
     part = tf.cast(part, self.dtype)
     static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
     if static_rank == 1:
         return part
     new_shape = ps.concat([
         ps.shape(part)[:ps.size(ps.shape(part)) -
                        ps.size(event_shape)], [-1]
     ],
                           axis=-1)
     return tf.reshape(part, ps.cast(new_shape, tf.int32))
Example #7
 def _batch_shape_tensor(self, distributions=None):
   if distributions is None:
     distributions = self.poisson_and_mixture_distributions()
   dist, mixture_dist = distributions
   return tf.broadcast_dynamic_shape(
       dist.batch_shape_tensor(),
       prefer_static.shape(mixture_dist.logits))[:-1]
Example #8
 def _cdf(self, x):
   df = tf.convert_to_tensor(self.df)
   # Take Abs(scale) to make subsequent where work correctly.
   y = (x - self.loc) / tf.abs(self.scale)
   x_t = df / (y**2. + df)
   neg_cdf = 0.5 * tf.math.betainc(
       0.5 * tf.broadcast_to(df, prefer_static.shape(x_t)), 0.5, x_t)
   return tf.where(y < 0., neg_cdf, 1. - neg_cdf)
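
# A quick check of the formula above, outside TF: the location-scale Student's
# t CDF via the regularized incomplete beta function, compared against SciPy.
import numpy as np
from scipy import special, stats

df, loc, scale, x = 3., 1., 2., -0.5
y = (x - loc) / abs(scale)
x_t = df / (y**2. + df)
neg_cdf = 0.5 * special.betainc(0.5 * df, 0.5, x_t)
cdf = neg_cdf if y < 0. else 1. - neg_cdf
print(np.allclose(cdf, stats.t.cdf(x, df, loc=loc, scale=scale)))  # ==> True
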
    def _inverse(self, y):
        n = prefer_static.shape(y)[-1]
        batch_shape = prefer_static.shape(y)[:-2]

        # Extract the reciprocal of the row norms from the diagonal.
        diag = tf.linalg.diag_part(y)[..., tf.newaxis]

        # Set the diagonal to 0s.
        y = tf.linalg.set_diag(
            y, tf.zeros(tf.concat([batch_shape, [n]], axis=-1), dtype=y.dtype))

        # Multiply with the norm (or divide by its reciprocal) to recover the
        # unconstrained reals in the (strictly) lower triangular part.
        x = y / diag

        # Remove the first row and last column before inverting the FillTriangular
        # transformation.
        return fill_triangular.FillTriangular().inverse(x[..., 1:, :-1])
    def _forward(self, x):
        x = tf.convert_to_tensor(x, name='x')
        batch_shape = prefer_static.shape(x)[:-1]

        # Pad zeros on the top row and right column.
        y = fill_triangular.FillTriangular().forward(x)
        rank = prefer_static.rank(y)
        paddings = tf.concat([
            tf.zeros(shape=(rank - 2, 2), dtype=tf.int32),
            tf.constant([[1, 0], [0, 1]], dtype=tf.int32)
        ],
                             axis=0)
        y = tf.pad(y, paddings)

        # Set diagonal to 1s.
        n = prefer_static.shape(y)[-1]
        diag = tf.ones(tf.concat([batch_shape, [n]], axis=-1), dtype=x.dtype)
        y = tf.linalg.set_diag(y, diag)

        # Normalize each row to have Euclidean (L2) norm 1.
        y /= tf.norm(y, axis=-1)[..., tf.newaxis]
        return y
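
# A small NumPy sketch of the forward construction above (no batching; the
# exact element ordering of FillTriangular is glossed over): fill a strictly
# lower triangle, set a unit diagonal, then L2-normalize every row, so that
# y @ y.T has a unit diagonal, i.e. is a correlation matrix.
import numpy as np

x = np.array([0.5, -1.0, 2.0])                  # n * (n - 1) / 2 entries, n = 3
n = 3
y = np.zeros((n, n))
y[np.tril_indices(n, k=-1)] = x                 # strictly lower triangular part
np.fill_diagonal(y, 1.0)
y /= np.linalg.norm(y, axis=-1, keepdims=True)  # unit-norm rows
print(np.allclose(np.diag(y @ y.T), 1.0))       # ==> True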
Example #11
 def _forward_log_det_jacobian(self, x):
     # This code is similar to tf.math.log_softmax but different because we have
     # an implicit zero column to handle. I.e., instead of:
     #   reduce_sum(logits - log(reduce_sum(exp(logits), dim)))
     # we must do:
     #   log_normalization = log(1 + reduce_sum(exp(logits)))
     #   -log_normalization + reduce_sum(logits - log_normalization)
     n = prefer_static.shape(x)[-1]
     log_normalization = tf.math.softplus(
         tf.reduce_logsumexp(x, axis=-1, keepdims=True))
     return tf.squeeze(
         (-log_normalization +
          tf.reduce_sum(x - log_normalization, axis=-1, keepdims=True)),
         axis=-1) + 0.5 * tf.math.log(tf.cast(n + 1, dtype=x.dtype))
 def _sample_n(self, n, seed=None):
     logits = self._logits_parameter_no_checks()
     sample_shape = prefer_static.concat(
         [[n], prefer_static.shape(logits)], 0)
     event_size = self._event_size(logits)
     if tensorshape_util.rank(logits.shape) == 2:
         logits_2d = logits
     else:
         logits_2d = tf.reshape(logits, [-1, event_size])
     samples = tf.random.categorical(logits_2d, n, seed=seed)
     samples = tf.transpose(a=samples)
     samples = tf.one_hot(samples, event_size, dtype=self.dtype)
     ret = tf.reshape(samples, sample_shape)
     return ret
  def _assert_compatible_shape(self, index, sample_shape, samples):
    requested_shape, _ = self._expand_sample_shape_to_vector(
        tf.convert_to_tensor(sample_shape, dtype=tf.int32),
        name='requested_shape')
    actual_shape = prefer_static.shape(samples)
    actual_rank = prefer_static.rank_from_shape(actual_shape)
    requested_rank = prefer_static.rank_from_shape(requested_shape)

    # We test for two properties we expect of yielded distributions:
    # (1) The rank of the tensor of generated samples must be at least
    #     as large as the rank requested.
    # (2) The requested shape must be a prefix of the shape of the
    #     generated tensor of samples.
    # We attempt to perform test (1) statically first.
    # We don't need to do this explicitly for test (2) because
    # `assert_equal` evaluates statically if it can.
    static_actual_rank = tf.get_static_value(actual_rank)
    static_requested_rank = tf.get_static_value(requested_rank)

    assertion_message = ('Samples yielded by distribution #{} are not '
                         'consistent with `sample_shape` passed to '
                         '`JointDistributionCoroutine` '
                         'distribution.'.format(index))

    # TODO Remove this static check (b/138738650)
    if (static_actual_rank is not None and
        static_requested_rank is not None):
      # We're able to statically check the rank
      if static_actual_rank < static_requested_rank:
        raise ValueError(assertion_message)
      else:
        control_dependencies = []
    else:
      # We're not able to statically check the rank
      control_dependencies = [
          assert_util.assert_greater_equal(
              actual_rank, requested_rank,
              message=assertion_message)
          ]

    with tf.control_dependencies(control_dependencies):
      trimmed_actual_shape = actual_shape[:requested_rank]

    control_dependencies = [
        assert_util.assert_equal(
            requested_shape, trimmed_actual_shape,
            message=assertion_message)
    ]

    return control_dependencies
Example #14
    def _inverse(self, y):
        ndims = prefer_static.rank(y)
        shifted_y = tf.pad(
            tf.slice(
                y, tf.zeros(ndims, dtype=tf.int32),
                prefer_static.shape(y) -
                tf.one_hot(ndims + self.axis, ndims, dtype=tf.int32)
            ),  # Remove the last entry of y in the chosen dimension.
            paddings=tf.one_hot(
                tf.one_hot(ndims + self.axis, ndims, on_value=0, off_value=-1),
                2,
                dtype=tf.int32
            )  # Insert zeros at the beginning of the chosen dimension.
        )

        return y - shifted_y
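
# What the slice-and-pad above computes, as a NumPy sketch (assuming axis=-1
# and no batching): drop the last entry, prepend a zero, and subtract -- the
# first difference, which inverts a cumulative sum.
import numpy as np

y = np.array([1., 3., 6., 10.])      # cumsum of [1, 2, 3, 4]
shifted_y = np.pad(y[:-1], (1, 0))   # ==> [0., 1., 3., 6.]
print(y - shifted_y)                 # ==> [1. 2. 3. 4.]
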
 def _inverse_log_det_jacobian(self, y):
     # The inverse log det jacobian (ILDJ) of the entire mapping is the sum of
     # the ILDJs of each row's mapping.
     #
     # To compute the ILDJ for each row's mapping, consider the forward mapping
     # `f_k` restricted to the `k`th (1-indexed) row. It maps unconstrained reals
     # in `R^{k-1}` to unit vectors in `R^k`. `f_k : R^{k-1} -> R^k` is given by:
     #
     #   f(x_1, x_2, ... x_{k-1}) = (x_1/s, x_2/s, ..., x_{k-1}/s, 1/s)
     #
     # where `s = norm(x_1, x_2, ..., x_{k-1}, 1)`.
     #
     # The change in infinitesimal `k-1`-dimensional volume (or surface area) is
     # given by sqrt(|det J^T J|); where J is the `k x (k-1)` Jacobian matrix.
     #
     # Claim: sqrt(|det(J^T J)|) = s^{-k}.
     #
     # Proof: We compute the entries of the Jacobian matrix J:
     #
     #     J_{i, j} =  -x_j / s^3           if i == k
     #     J_{i, j} =  (s^2 - x_i^2) / s^3  if i == j and i < k
     #     J_{i, j} = -(x_i * x_j) / s^3    if i != j and i < k
     #
     #   By spherical symmetry, the volume element depends only on `s`; w.l.o.g.
     #   we can assume that `x_1 = r` and `x_2, ..., x_n = 0`; where
     #   `r^2 + 1 = s^2`.
     #
     #   We can write `J^T = [A|B]` where `A` is a diagonal matrix of rank `k-1`
     #   with diagonal `(1/s^3, 1/s, 1/s, ..., 1/s)`; and `B` is a column vector
     #   of size `k-1`, with entries (-r/s^3, 0, 0, ..., 0). Hence,
     #
     #     det(J^T J) = det(diag((r^2 + 1) / s^6, 1/s^2, ..., 1/s^2))
     #                = s^{-2k}.
     #
     #   Or, sqrt(|det(J^T J)|) = s^{-k}.
     #
     # Hence, the forward log det jacobian (FLDJ) for the `k`th row is given by
     # `-k * log(s)`. The ILDJ is equal to negative FLDJ at the pre-image, or,
     # `k * log(s)`; where `s` is the reciprocal of the `k`th diagonal entry.
     #
     n = prefer_static.shape(y)[-1]
     return -tf.reduce_sum(tf.range(1, n + 1, dtype=y.dtype) *
                           tf.math.log(tf.linalg.diag_part(y)),
                           axis=-1)
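
# A quick NumPy check (k = 3, an arbitrary point) of the claim in the comment
# above that sqrt(|det(J^T J)|) = s^{-k} for the per-row map
# f(x) = (x_1/s, ..., x_{k-1}/s, 1/s) with s = norm(x_1, ..., x_{k-1}, 1).
import numpy as np

x = np.array([0.7, -1.3])                             # k - 1 unconstrained reals
k = x.size + 1
s = np.sqrt(np.sum(x**2) + 1.)
top = (s**2 * np.eye(k - 1) - np.outer(x, x)) / s**3  # rows i < k of J
bottom = -x / s**3                                    # row i == k of J
jac = np.vstack([top, bottom])                        # J has shape (k, k - 1)
print(np.allclose(np.sqrt(np.linalg.det(jac.T @ jac)), s**(-k)))  # ==> True
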
def maybe_check_wont_broadcast(flat_xs, validate_args):
    """Verifies that `parts` don't broadcast."""
    flat_xs = tuple(flat_xs)  # So we can receive generators.
    if not validate_args:
        # Note: we don't try static validation because it is theoretically
        # possible that a user wants to take advantage of broadcasting.
        # Only when `validate_args` is `True` do we enforce the validation.
        return flat_xs
    msg = 'Broadcasting probably indicates an error in model specification.'
    s = tuple(prefer_static.shape(x) for x in flat_xs)
    if all(prefer_static.is_numpy(s_) for s_ in s):
        if not all(np.all(a == b) for a, b in zip(s[1:], s[:-1])):
            raise ValueError(msg)
        return flat_xs
    assertions = [
        assert_util.assert_equal(a, b, message=msg)
        for a, b in zip(s[1:], s[:-1])
    ]
    with tf.control_dependencies(assertions):
        return tuple(tf.identity(x) for x in flat_xs)
Example #17
 def _inverse_log_det_jacobian(self, y):
     # Let B be the forward map defined by the bijector. Consider the map
     # F : R^n -> R^n where the image of B in R^{n+1} is restricted to the first
     # n coordinates.
     #
     # Claim: det{ dF(X)/dX } = prod(Y) where Y = B(X).
     # Proof: WLOG, in vector notation:
     #     X = log(Y[:-1]) - log(Y[-1])
     #   where,
     #     Y[-1] = 1 - sum(Y[:-1]).
     #   We have:
     #     det{dF} = 1 / det{ dX/dF(X) }                                      (1)
     #             = 1 / det{ diag(1 / Y[:-1]) + 1 / Y[-1] }
     #             = 1 / det{ inv{ diag(Y[:-1]) - Y[:-1]' Y[:-1] } }
     #             = det{ diag(Y[:-1]) - Y[:-1]' Y[:-1] }
     #             = (1 + Y[:-1]' inv{diag(Y[:-1])} Y[:-1]) det{diag(Y[:-1])} (2)
     #             = Y[-1] prod(Y[:-1])
     #             = prod(Y)
     #
     # Let P be the image of R^n under F. Define the lift G, from P to R^{n+1},
     # which appends the last coordinate, Y[-1] := 1 - \sum_k Y_k. G is linear,
     # so its Jacobian is constant.
     #
     # The differential of G, DG, is eye(n) with a row of -1s appended to the
     # bottom. To compute the Jacobian sqrt{det{(DG)^T(DG)}}, one can see that
     # (DG)^T(DG) = A + eye(n), where A is the n x n matrix of 1s. This has
     # eigenvalues (n + 1, 1,...,1), so the determinant is (n + 1). Hence, the
     # Jacobian of G is sqrt{n + 1} everywhere.
     #
     # Putting it all together, the forward bijective map B can be written as
     # B(X) = G(F(X)) and has Jacobian sqrt{n + 1} * prod(F(X)).
     #
     # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula
     #       or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector
     #       docstring "Tip".
     # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma
     n_plus_one = prefer_static.shape(y)[-1]
     return -tf.reduce_sum(tf.math.log(y), axis=-1) - 0.5 * tf.math.log(
         tf.cast(n_plus_one, dtype=y.dtype))
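
# A quick NumPy check (n = 3) of the claim above that det{ dF(X)/dX } = prod(Y),
# where Y = softmax(concat([X, 0])) and F keeps the first n coordinates of Y.
import numpy as np

x = np.array([0.3, -0.8, 1.1])
z = np.append(x, 0.)
y = np.exp(z) / np.sum(np.exp(z))                   # Y = B(X), on the simplex
jac = np.diag(y[:-1]) - np.outer(y[:-1], y[:-1])    # dF(X)/dX, an n x n matrix
print(np.allclose(np.linalg.det(jac), np.prod(y)))  # ==> True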
Example #18
 def _event_shape_tensor(self, logits=None):
     param = logits
     if param is None:
         param = self._logits if self._logits is not None else self._probs
     return prefer_static.shape(param)[-1:]
Example #19
 def _batch_shape_tensor(self, concentration=None, rate=None):
   return prefer_static.broadcast_shape(
       prefer_static.shape(
           self.concentration if concentration is None else concentration),
       prefer_static.shape(self.rate if rate is None else rate))
Example #20
 def _batch_shape_tensor(self, concentration1=None, concentration0=None):
     return prefer_static.broadcast_shape(
         prefer_static.shape(self.concentration1
                             if concentration1 is None else concentration1),
         prefer_static.shape(self.concentration0
                             if concentration0 is None else concentration0))
 def _event_shape_tensor(self):
     param = self._logits if self._logits is not None else self._probs
     # NOTE: If the last dimension of `param.shape` is statically-known, but
     # the `param.shape` is not statically-known, then we will *not* return a
     # statically-known event size here.  This could be fixed.
     return prefer_static.shape(param)[-1:]
 def _batch_shape_tensor(self):
     param = self._logits if self._logits is not None else self._probs
     return prefer_static.shape(param)[:-1]
Example #23
 def _batch_shape_tensor(self, loc=None, scale=None):
   return prefer_static.broadcast_shape(
       prefer_static.shape(self.loc if loc is None else loc),
       prefer_static.shape(self.scale if scale is None else scale))
Example #24
 def _batch_shape_tensor(self):
     return prefer_static.shape(self.concentration)
Example #25
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
    """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
  of `matrix`, i.e. as described in [(Harbrecht et al., 2012)][1]. The
  currently-worst-approximated diagonal element is selected as the pivot at each
  iteration. This yields from a `[B1...Bn, N, N]` shaped `matrix` a `[B1...Bn,
  N, K]` shaped rank-`K` approximation `lr` such that `lr @ lr.T ~= matrix`.
  Note that, unlike the Cholesky decomposition, `lr` is not triangular even in
  a rectangular-matrix sense. However, under a permutation it could be made
  triangular (it has one more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can be
  cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
       pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
  """
    with tf.name_scope(name or 'pivoted_cholesky'):
        dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                        dtype_hint=tf.float32)
        matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
        if tensorshape_util.rank(matrix.shape) is None:
            raise NotImplementedError(
                'Rank of `matrix` must be known statically')

        max_rank = tf.convert_to_tensor(max_rank,
                                        name='max_rank',
                                        dtype=tf.int64)
        max_rank = tf.minimum(
            max_rank,
            prefer_static.shape(matrix, out_type=tf.int64)[-1])
        diag_rtol = tf.convert_to_tensor(diag_rtol,
                                         dtype=dtype,
                                         name='diag_rtol')
        matrix_diag = tf.linalg.diag_part(matrix)
        # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
        orig_error = tf.reduce_max(matrix_diag, axis=-1)

        def cond(m, pchol, perm, matrix_diag):
            """Condition for `tf.while_loop` continuation."""
            del pchol
            del perm
            error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
            max_err = tf.reduce_max(error / orig_error)
            return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

        batch_dims = tensorshape_util.rank(matrix.shape) - 2

        def batch_gather(params, indices, axis=-1):
            return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

        def body(m, pchol, perm, matrix_diag):
            """Body of a single `tf.while_loop` iteration."""
            # Here is roughly a numpy, non-batched version of what's going to happen.
            # (See also Algorithm 1 of Harbrecht et al.)
            # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
            # 2: maxval = matrix_diag[perm][maxi]
            # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
            # 4: row = matrix[perm[m]][perm[m + 1:]]
            # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
            # 6: pivot = np.sqrt(maxval); row /= pivot
            # 7: row = np.concatenate([[[pivot]], row], -1)
            # 8: matrix_diag[perm[m:]] -= row**2
            # 9: pchol[m, perm[m:]] = row

            # Find the maximal position of the (remaining) permuted diagonal.
            # Steps 1, 2 above.
            permuted_diag = batch_gather(matrix_diag, perm[..., m:])
            maxi = tf.argmax(permuted_diag, axis=-1,
                             output_type=tf.int64)[..., tf.newaxis]
            maxval = batch_gather(permuted_diag, maxi)
            maxi = maxi + m
            maxval = maxval[..., 0]
            # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
            perm = _swap_m_with_i(perm, m, maxi)
            # Step 4.
            row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
            row = batch_gather(row, perm[..., m + 1:])
            # Step 5.
            prev_rows = pchol[..., :m, :]
            prev_rows_perm_m_onward = batch_gather(prev_rows, perm[...,
                                                                   m + 1:])
            prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
            row -= tf.reduce_sum(prev_rows_perm_m_onward * prev_rows_pivot_col,
                                 axis=-2)[..., tf.newaxis, :]
            # Step 6.
            pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
            # Step 7.
            row = tf.concat([pivot, row / pivot], axis=-1)
            # TODO(b/130899118): Pad grad fails with int64 paddings.
            # Step 8.
            paddings = tf.concat([
                tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
                [[tf.cast(m, tf.int32), 0]]
            ],
                                 axis=0)
            diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
            reverse_perm = _invert_permutation(perm)
            matrix_diag -= batch_gather(diag_update, reverse_perm)
            # Step 9.
            row = tf.pad(row, paddings=paddings)
            # TODO(bjp): Defer the reverse permutation all-at-once at the end?
            row = batch_gather(row, reverse_perm)
            pchol_shape = pchol.shape
            pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                              axis=-2)
            tensorshape_util.set_shape(pchol, pchol_shape)
            return m + 1, pchol, perm, matrix_diag

        m = np.int64(0)
        pchol = tf.zeros_like(matrix[..., :max_rank, :])
        matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
        perm = tf.broadcast_to(prefer_static.range(matrix_shape[-1]),
                               matrix_shape[:-1])
        _, pchol, _, _ = tf.while_loop(cond=cond,
                                       body=body,
                                       loop_vars=(m, pchol, perm, matrix_diag))
        pchol = tf.linalg.matrix_transpose(pchol)
        tensorshape_util.set_shape(
            pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
        return pchol
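
# A usage sketch, assuming this routine is the one exposed publicly as
# `tfp.math.pivoted_cholesky`: approximate a nearly low-rank SPD matrix and
# inspect the worst diagonal error of the rank-6 factor.
import tensorflow as tf
import tensorflow_probability as tfp

a = tf.random.normal([20, 4], dtype=tf.float64)
matrix = tf.matmul(a, a, adjoint_b=True) + 1e-3 * tf.eye(20, dtype=tf.float64)
lr = tfp.math.pivoted_cholesky(matrix, max_rank=6)            # shape [20, 6]
approx = tf.matmul(lr, lr, adjoint_b=True)
print(tf.reduce_max(tf.abs(tf.linalg.diag_part(matrix - approx))))  # small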
Example #26
def _invert_permutation(perm):  # TODO(b/130217510): Remove this function.
    return tf.cast(
        tf.math.top_k(perm, k=prefer_static.shape(perm)[-1],
                      sorted=True).indices[..., ::-1], perm.dtype)
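
# A small NumPy illustration of the trick above: for a permutation vector, the
# indices that sort it in descending order, read in reverse, give its inverse
# (first the position of 0, then of 1, ...) -- which is what top_k followed by
# [..., ::-1] computes batch-wise.
import numpy as np

perm = np.array([2, 0, 3, 1])
inverse = np.argsort(-perm)[::-1]   # mirrors top_k(perm, k=4).indices[..., ::-1]
print(inverse)                      # ==> [1 3 0 2]
print(perm[inverse])                # ==> [0 1 2 3]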