def _covariance(self):
    """Returns the covariance matrix of the von Mises-Fisher distribution.

    With `d = event_dim`, `k = concentration`, `mu = mean_direction` and
    `h = _bessel_ive(d / 2, k) / _bessel_ive(d / 2 - 1, k)`, batch members with
    `k > 0` get

        Cov = (1 - d * h / k - h**2) * mu @ mu^T + (h / k) * I.

    Batch members with `k <= 0` get `I / d`, the covariance of the uniform
    distribution on the unit sphere (the `k -> 0` limit).

    Raises:
      ValueError: if the event dimension is not statically known, or if it is
        greater than 2, where this formula is numerically unstable.
    """
    event_dim = tf.compat.dimension_value(self.event_shape[0])
    if event_dim is None:
        raise ValueError(
            'event shape must be statically known for _bessel_ive')
    # TODO(bjp): Enable this; numerically unstable.
    if event_dim > 2:
        raise ValueError(
            'vMF covariance is numerically unstable for dim>2')
    concentration = self.concentration[..., tf.newaxis]
    # Replace non-positive concentrations with 1 so the Bessel ratio and the
    # `h / safe_conc` terms stay finite; those batch members are overwritten
    # by the `I / d` fallback at the end.
    safe_conc = tf.where(concentration > 0, concentration,
                         tf.ones_like(concentration))
    h = (_bessel_ive(event_dim / 2, safe_conc) /
         _bessel_ive(event_dim / 2 - 1, safe_conc))
    # Rank-one `mu @ mu^T` term, scaled per batch member.
    intermediate = (
        tf.matmul(self.mean_direction[..., :, tf.newaxis],
                  self.mean_direction[..., tf.newaxis, :]) *
        (1 - event_dim * h / safe_conc - h**2)[..., tf.newaxis])
    # Add `h / k` to the diagonal, i.e. the `(h / k) * I` term.
    cov = tf.linalg.set_diag(
        intermediate,
        tf.linalg.diag_part(intermediate) + (h / safe_conc))
    # `tf.linalg.eye` defaults to float32; match `cov` so both `tf.where`
    # branches share a dtype for float64 distributions too.
    return tf.where(
        concentration[..., tf.newaxis] > tf.zeros_like(cov), cov,
        tf.linalg.eye(event_dim, batch_shape=self.batch_shape_tensor(),
                      dtype=cov.dtype) / event_dim)
Пример #2
def cholesky_concat(chol, cols, name=None):
    """Concatenates `chol @ chol.T` with additional rows and columns.

    This operation is conceptually identical to:

    ```python
    def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
      mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
      # Concat columns.
      mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
      # Concat rows.
      mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
      return tf.linalg.cholesky(mat)
    ```

    but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
    `cholesky_concat` only costs `O(z**2 + m**3)` work.

    The resulting (implicit) matrix must be symmetric and positive definite.
    Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
    separate `rows` argument (which can be inferred from `conj(cols.T)`).

    Args:
      chol: Cholesky decomposition of `mat = chol @ chol.T`.
      cols: The new columns whose first `n` rows we would like concatenated to
        the right of `mat = chol @ chol.T`, and whose conjugate transpose we
        would like concatenated to the bottom of `concat(mat, cols[:n,:])`. A
        `Tensor` with final dims `(n+m, m)`. The first `n` rows are the top
        right rectangle (their conjugate transpose forms the bottom left), and
        the bottom `m x m` is self-adjoint.
      name: Optional name for this op.

    Returns:
      chol_concat: The Cholesky decomposition of:
        [ [ mat  cols[:n, :] ]
          [   conj(cols.T)   ] ]
    """
    with tf.name_scope(name or 'cholesky_extend'):
        dtype = dtype_util.common_dtype([chol, cols], dtype_hint=tf.float32)
        chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
        cols = tf.convert_to_tensor(cols, name='cols', dtype=dtype)
        n = prefer_static.shape(chol)[-1]
        # Top-right `n x m` block and new bottom-right `m x m` block.
        mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
        # With the result written as [[chol, 0], [S^H, L2]], matching the top
        # right block requires `chol @ S = mat_nm` (presumably a lower-triangular
        # solve; TODO confirm the helper's default).
        solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
            chol, mat_nm)
        # Matching the bottom-right block requires `S^H S + L2 L2^H = mat_mm`,
        # so `L2` is the Cholesky factor of the Schur complement.
        lower_right_mm = tf.linalg.cholesky(
            mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
        lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
        # The solve may have broadcast batch dimensions; broadcast `chol` to
        # the same batch shape so the blocks can be concatenated.
        out_batch = prefer_static.shape(solved_nm)[:-2]
        chol = tf.broadcast_to(
            chol,
            tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
        top_right_zeros_nm = tf.zeros_like(solved_nm)
        return tf.concat([
            tf.concat([chol, top_right_zeros_nm], axis=-1),
            tf.concat([lower_left_mn, lower_right_mm], axis=-1)
        ], axis=-2)
Пример #3
 def _variance(self):
     """Returns the element-wise variance as a matrix.

     Writing `S = self._square_scale_operator()` and `x = sqrt(df) * S`, the
     result is `x**2 + diag(x) @ diag(x)^H` (with `diag(x)` as a column), so
     for real `S` entry `[i, j]` is `df * (S[i, j]**2 + S[i, i] * S[j, j])`.
     """
     # `df` carries only batch dims; append two unit axes so it broadcasts
     # against the trailing matrix dimensions of the squared scale.
     sqrt_df = tf.sqrt(self.df[..., tf.newaxis, tf.newaxis])
     scaled = sqrt_df * self._square_scale_operator()
     diag_column = tf.linalg.diag_part(scaled)[..., tf.newaxis]
     outer_diag = tf.matmul(diag_column, diag_column, adjoint_b=True)
     return tf.square(scaled) + outer_diag
  def _covariance(self):
    """Returns the covariance matrix.

    With `x = _variance_scale_term(...) * _mean(...)`, off-diagonal entries
    are `-x_i * x_j`; the diagonal is replaced by `_variance(...)`.
    """
    n = tf.convert_to_tensor(self._total_count)
    alpha = tf.convert_to_tensor(self._concentration)
    scaled_mean = (self._variance_scale_term(n, alpha) *
                   self._mean(n, alpha))
    # Negated outer product of the scaled mean with itself.
    neg_outer = -tf.matmul(scaled_mean[..., :, tf.newaxis],
                           scaled_mean[..., tf.newaxis, :])
    return tf.linalg.set_diag(neg_outer, self._variance(n, alpha))
Пример #5
 def _covariance(self):
   """Returns the covariance matrix computed from `self.concentration`.

   With `alpha = concentration`, `alpha_0 = sum(alpha)`, `s = 1 / sqrt(1 +
   alpha_0)` and `x = s * alpha / alpha_0`, off-diagonal entries are
   `-x_i * x_j = -alpha_i * alpha_j / (alpha_0**2 * (alpha_0 + 1))` and the
   diagonal is `x_i * (s - x_i) = alpha_i * (alpha_0 - alpha_i) /
   (alpha_0**2 * (alpha_0 + 1))`.
   """
   concentration = tf.convert_to_tensor(self.concentration)
   total_concentration = tf.reduce_sum(concentration, axis=-1, keepdims=True)
   mean = concentration / total_concentration
   scale = tf.math.rsqrt(1. + total_concentration)
   x = scale * mean
   variance = x * (scale - x)
   return tf.linalg.set_diag(
       tf.matmul(-x[..., tf.newaxis], x[..., tf.newaxis, :]),
       variance)
Пример #6
 def _forward(self, x):
     """Maps a lower-triangular `x` to the Cholesky factor of `(x x^H)^-1`.

     Computes `y = x^-1` by a triangular solve against the identity, then
     returns `cholesky(y^H y)`, which equals `cholesky((x x^H)^-1)`.
     """
     with tf.control_dependencies(self._assertions(x)):
         x_shape = tf.shape(x)
         # Identity with the same batch shape and dtype as `x`; solving against
         # it yields `x^-1`.
         identity_matrix = tf.eye(x_shape[-1],
                                  batch_shape=x_shape[:-2],
                                  dtype=x.dtype)
         # `triangular_solve` defaults to `lower=True` and does not read the
         # strictly upper triangle of `x`, which is therefore treated as zero.
         y = tf.linalg.triangular_solve(x, identity_matrix)
         y = tf.matmul(y, y, adjoint_a=True)
         return tf.linalg.cholesky(y)
Пример #7
def sparse_or_dense_matmul(sparse_or_dense_a,
                           dense_b,
                           validate_args=False,
                           name=None,
                           **kwargs):
    """Returns (batched) matmul of a SparseTensor (or Tensor) with a Tensor.

    Args:
      sparse_or_dense_a: `SparseTensor` or `Tensor` representing a (batch of)
        matrices.
      dense_b: `Tensor` representing a (batch of) matrices, with the same batch
        shape as `sparse_or_dense_a`. The shape must be compatible with the
        shape of `sparse_or_dense_a` and kwargs.
      validate_args: When `True`, additional assertions might be embedded in
        the graph.
        Default value: `False` (i.e., no graph assertions are added).
      name: Python `str` prefixed to ops created by this function.
        Default value: 'sparse_or_dense_matmul'.
      **kwargs: Keyword arguments to `tf.sparse_tensor_dense_matmul` or
        `tf.matmul`.

    Returns:
      product: A dense (batch of) matrix-shaped Tensor of the same batch shape
        and dtype as `sparse_or_dense_a` and `dense_b`. If `sparse_or_dense_a`
        or `dense_b` is adjointed through `kwargs` then the shape is adjusted
        accordingly.
    """
    with tf.name_scope(name or 'sparse_or_dense_matmul'):
        dense_b = tf.convert_to_tensor(dense_b, dtype_hint=tf.float32,
                                       name='dense_b')

        if validate_args:
            assert_a_rank_at_least_2 = assert_util.assert_rank_at_least(
                sparse_or_dense_a,
                rank=2,
                message='Input `sparse_or_dense_a` must have at least 2 dimensions.')
            assert_b_rank_at_least_2 = assert_util.assert_rank_at_least(
                dense_b,
                rank=2,
                message='Input `dense_b` must have at least 2 dimensions.')
            # Route both inputs through `identity` so downstream ops depend on
            # the rank assertions.
            with tf.control_dependencies(
                    [assert_a_rank_at_least_2, assert_b_rank_at_least_2]):
                sparse_or_dense_a = tf.identity(sparse_or_dense_a)
                dense_b = tf.identity(dense_b)

        if isinstance(sparse_or_dense_a,
                      (tf.SparseTensor, tf1.SparseTensorValue)):
            return _sparse_tensor_dense_matmul(sparse_or_dense_a, dense_b,
                                               **kwargs)
        return tf.matmul(sparse_or_dense_a, dense_b, **kwargs)
Пример #8
def _sparse_block_diag(sp_a):
    """Returns a block diagonal rank 2 SparseTensor from a batch of SparseTensors.

    Args:
      sp_a: A rank 3 `SparseTensor` representing a batch of matrices.

    Returns:
      sp_block_diag_a: matrix-shaped `SparseTensor` with the same dtype as
        `sp_a`, of shape [B * M, B * N] where `sp_a` has shape [B, M, N]. Each
        [M, N] batch of `sp_a` is lined up along the diagonal.
    """
    # Construct the matrix [[M, N], [1, 0], [0, 1]] which maps the index
    # (b, i, j) to (Mb + i, Nb + j). This effectively creates a block-diagonal
    # matrix of dense shape [B * M, B * N].
    # Note that this transformation doesn't increase the number of non-zero
    # entries in the SparseTensor. The map is also monotone in lexicographic
    # order, so row-major ordered input indices stay row-major ordered.
    sp_a_shape = tf.convert_to_tensor(_get_shape(sp_a, tf.int64))
    ind_mat = tf.concat([[sp_a_shape[-2:]], tf.eye(2, dtype=tf.int64)], axis=0)
    indices = tf.matmul(sp_a.indices, ind_mat)
    dense_shape = sp_a_shape[0] * sp_a_shape[1:]
    return tf.SparseTensor(indices=indices,
                           values=sp_a.values,
                           dense_shape=dense_shape)
 def _covariance(self):
     """Returns the covariance matrix built from the probabilities.

     Off-diagonal entries are `-p_i * p_j`; the diagonal is `self._variance(p)`.
     """
     probs = self._probs_parameter_no_checks()
     column, row = probs[..., :, tf.newaxis], probs[..., tf.newaxis, :]
     return tf.linalg.set_diag(-tf.matmul(column, row), self._variance(probs))
Пример #10
    def _forward_log_det_jacobian(self, x):
        """Returns log|det J| of the forward map `X -> X @ X^T`.

        `x` is the lower-triangular input. Per the derivation below the result
        is `p * log(2) + sum_j (p - j) * log(X[j, j])` over the trailing
        `p x p` matrix, with one value per batch member and a scalar for a
        single (rank-2) matrix.
        """
        # Let Y be a symmetric, positive definite matrix and write:
        #   Y = X X.T
        # where X is lower-triangular.
        # Observe that,
        #   dY[i,j]/dX[a,b]
        #   = d/dX[a,b] { X[i,:] X[j,:] }
        #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
        #   = I[i=a] X[j,b] + I[j=a] X[i,b]
        # To compute the Jacobian d vec[Y] / d vec[X] we must represent X,Y as
        # vectors. Since Y is symmetric and X is lower-triangular, we need
        # vectors of dimension:
        #   d = p (p + 1) / 2
        # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
        #   k = { i (i + 1) / 2 + j   i>=j
        #       { undef               i<j
        # and assume zero-based indexes. When k is undef, the element is dropped.
        # Example:
        #           j      k
        #        0 1 2 3  /
        #    0 [ 0 . . . ]
        # i  1 [ 1 2 . . ]
        #    2 [ 3 4 5 . ]
        #    3 [ 6 7 8 9 ]
        # Write vec[.] to indicate transforming a matrix to vector via k(i,j).
        # (With slight abuse: k(i,j)=undef means the element is dropped.)
        # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
        # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
        # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
        # (1) j<=i<a, so I[i=a] = I[j=a] = 0.
        # (2) i=a and j<b<=a: X[j,b] = 0 (strictly upper entry of X) and j!=a.
        # Since the Jacobian is lower-triangular, we need only compute the
        # product of diagonal elements:
        #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
        #   = X[j,j] + I[i=j] X[i,j]
        #   = 2 X[j,j] if i=j, else X[j,j].
        # Column j has p-j lower-triangular elements (i = j..p-1), each giving a
        # factor X[j,j], and its diagonal element gives an extra factor 2, so:
        #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
        diag = tf.linalg.diag_part(x)

        # Ensure `diag` has rank >= 2 so it can be matmul'ed with the `[p, 1]`
        # exponent column below. For a single matrix, `diag` of shape `[p]`
        # presumably becomes `[1, p]` (e.g. `[1, 2, 3]` -> `[[1, 2, 3]]`); a
        # batch of diagonals such as `[[1, 2, 3], [4, 5, 6]]` is unchanged.
        diag = self._make_columnar(diag)

        with tf.control_dependencies(self._assertions(x)):
            # Create a vector equal to: [p, p-1, ..., 2, 1].
            if tf.compat.dimension_value(x.shape[-1]) is None:
                p_int = tf.shape(x)[-1]
                p_float = tf.cast(p_int, dtype=x.dtype)
            else:
                p_int = tf.compat.dimension_value(x.shape[-1])
                p_float = dtype_util.as_numpy_dtype(x.dtype)(p_int)
            exponents = tf.linspace(p_float, 1., p_int)

            # sum_j exponents[j] * log(diag[..., j]) as a matmul against the
            # exponent column; the trailing size-1 axis is then squeezed away.
            sum_weighted_log_diag = tf.squeeze(
                tf.matmul(tf.math.log(diag), exponents[..., tf.newaxis]),
                axis=-1)
            fldj = p_float * np.log(2.) + sum_weighted_log_diag

            # We finally need to undo adding an extra dimension in cases where
            # there is a single matrix as input.
            if tensorshape_util.rank(x.shape) is not None:
                if tensorshape_util.rank(x.shape) == 2:
                    fldj = tf.squeeze(fldj, axis=-1)
                return fldj

            # Rank unknown statically: drop the trailing axis only when `x` is
            # a single (rank-2) matrix.
            shape = tf.shape(fldj)
            keep_ndims = tf.size(shape) - tf.cast(
                tf.equal(tf.rank(x), 2), dtype=tf.int32)
            return tf.reshape(fldj, shape[:keep_ndims])
Пример #11
 def _forward(self, x):
     """Returns `L @ L^H`, where `L` is the lower-triangular part of `x`.

     The computation runs under the control dependencies returned by
     `self._assertions(x)`.
     """
     with tf.control_dependencies(self._assertions(x)):
         # Keep only the lower triangle so that stray upper-triangular entries
         # in `x` cannot leak into the product.
         lower = tf.linalg.band_part(x, num_lower=-1, num_upper=0)
         return tf.matmul(lower, lower, adjoint_b=True)
Пример #12
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

    Args:
      lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
      perm: `p` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False` (i.e., don't validate arguments).
      name: Python `str` name given to ops managed by this object.
        Default value: `None` (i.e., 'lu_reconstruct').

    Returns:
      x: The original input to `tf.linalg.lu`, i.e., `x` as in,
        `lu_reconstruct(*tf.linalg.lu(x))`.

    #### Examples

    ```python
    import numpy as np
    from tensorflow_probability.python.internal.backend import numpy as tf
    import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

    x = [[[3., 4], [1, 2]],
         [[7., 8], [3, 4]]]
    x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
    tf.assert_near(x, x_reconstructed)
    # ==> True
    ```
    """
    with tf.name_scope(name or 'lu_reconstruct'):
        lower_upper = tf.convert_to_tensor(
            lower_upper, dtype_hint=tf.float32, name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        shape = tf.shape(lower_upper)

        # L: lower triangle of `lower_upper` with the diagonal reset to ones.
        # U: upper triangle of `lower_upper`, diagonal included.
        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(shape[:-1], dtype=lower_upper.dtype))
        upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
        x = tf.matmul(lower, upper)

        if (tensorshape_util.rank(lower_upper.shape) is None or
                tensorshape_util.rank(lower_upper.shape) != 2):
            # We either don't know the batch rank or there are >0 batch dims.
            # Flatten the batch, invert each permutation, and gather rows with
            # (batch, row) index pairs.
            batch_size = tf.reduce_prod(shape[:-2])
            d = shape[-1]
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            perm = tf.map_fn(tf.math.invert_permutation, perm)
            batch_indices = tf.broadcast_to(
                tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
            x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
            x = tf.reshape(x, shape)
        else:
            # Single matrix: undo the row permutation directly.
            x = tf.gather(x, tf.math.invert_permutation(perm))

        return x
Пример #13
def pinv(a, rcond=None, validate_args=False, name=None):
    """Compute the Moore-Penrose pseudo-inverse of a matrix.

    Calculate the [generalized inverse of a matrix](
    https://en.wikipedia.org/wiki/Generalized_inverse) using its
    singular-value decomposition (SVD) and including all large singular values.

    The pseudo-inverse of a matrix `A`, is defined as: 'the matrix that 'solves'
    [the least-squares problem] `A @ x = b`,' i.e., if `x_hat` is a solution,
    then `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown
    that if `U @ Sigma @ V.T = A` is the singular value decomposition of `A`,
    then `A_pinv = V @ inv(Sigma) U^T`. [(Strang, 1980)][1]

    This function is analogous to [`numpy.linalg.pinv`](
    https://numpy.org/doc/stable/reference/generated/numpy.linalg.pinv.html).
    It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
    default `rcond` is `1e-15`. Here the default is
    `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.

    Args:
      a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
        pseudo-inverted.
      rcond: `Tensor` of small singular value cutoffs. Singular values smaller
        (in modulus) than `rcond` * largest_singular_value (again, in modulus)
        are set to zero. Must broadcast against `tf.shape(a)[:-2]`.
        Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
      validate_args: When `True`, additional assertions might be embedded in
        the graph.
        Default value: `False` (i.e., no graph assertions are added).
      name: Python `str` prefixed to ops created by this function.
        Default value: 'pinv'.

    Returns:
      a_pinv: The pseudo-inverse of input `a`. Has same shape as `a` except
        rightmost two dimensions are transposed.

    Raises:
      TypeError: if input `a` does not have `float`-like `dtype`.
      ValueError: if input `a` has fewer than 2 dimensions.

    #### Examples

    ```python
    from tensorflow_probability.python.internal.backend import numpy as tf
    import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

    a = tf.constant([[1.,  0.4,  0.5],
                     [0.4, 0.2,  0.25],
                     [0.5, 0.25, 0.35]])
    tf.matmul(tfp.math.pinv(a), a)
    # ==> array([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

    a = tf.constant([[1.,  0.4,  0.5,  1.],
                     [0.4, 0.2,  0.25, 2.],
                     [0.5, 0.25, 0.35, 3.]])
    tf.matmul(tfp.math.pinv(a), a)
    # ==> array([[ 0.76,  0.37,  0.21, -0.02],
                 [ 0.37,  0.43, -0.33,  0.02],
                 [ 0.21, -0.33,  0.81,  0.01],
                 [-0.02,  0.02,  0.01,  1.  ]], dtype=float32)
    ```

    #### References

    [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic
         Press, Inc., 1980, pp. 139-142.
    """
    with tf.name_scope(name or 'pinv'):
        a = tf.convert_to_tensor(a, name='a')

        assertions = _maybe_validate_matrix(a, validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                a = tf.identity(a)

        dtype = dtype_util.as_numpy_dtype(a.dtype)

        if rcond is None:

            def get_dim_size(dim):
                """Static size of `a.shape[dim]` if known, else a dynamic one."""
                static_size = tf.compat.dimension_value(a.shape[dim])
                if static_size is not None:
                    return static_size
                return tf.shape(a)[dim]

            num_rows = get_dim_size(-2)
            num_cols = get_dim_size(-1)
            if isinstance(num_rows, int) and isinstance(num_cols, int):
                max_rows_cols = float(max(num_rows, num_cols))
            else:
                max_rows_cols = tf.cast(tf.maximum(num_rows, num_cols), dtype)
            rcond = 10. * max_rows_cols * np.finfo(dtype).eps

        rcond = tf.convert_to_tensor(rcond, dtype=dtype, name='rcond')

        # Calculate pseudo inverse via SVD.
        # Note: if a is symmetric then u == v. (We might observe additional
        # performance by explicitly setting `v = u` in such cases.)
        [
            singular_values,  # Sigma
            left_singular_vectors,  # U
            right_singular_vectors,  # V
        ] = tf.linalg.svd(a, full_matrices=False, compute_uv=True)

        # Saturate small singular values to inf. This has the effect of make
        # `1. / s = 0.` while not resulting in `NaN` gradients.
        cutoff = rcond * tf.reduce_max(singular_values, axis=-1)
        singular_values = tf.where(singular_values > cutoff[..., tf.newaxis],
                                   singular_values, np.array(np.inf, dtype))

        # Although `a == tf.matmul(u, s * v, transpose_b=True)` we swap
        # `u` and `v` here so that `tf.matmul(pinv(A), A) = tf.eye()`, i.e.,
        # a matrix inverse has 'transposed' semantics.
        a_pinv = tf.matmul(right_singular_vectors /
                           singular_values[..., tf.newaxis, :],
                           left_singular_vectors,
                           adjoint_b=True)

        # Restore static shape information (last two dims swapped) when the
        # rank of `a` is statically known.
        if tensorshape_util.rank(a.shape) is not None:
            a_pinv.set_shape(a.shape[:-2].concatenate(
                [a.shape[-1], a.shape[-2]]))

        return a_pinv
Пример #14
    def _sample_n(self, n, seed):
        """Draws `n` samples with shape `[n] + batch_shape + event_shape`.

        Builds a lower-triangular matrix whose below-diagonal entries come from
        `tf.random.normal` (presumably standard normal; the call is truncated
        in this listing) and whose diagonal is `sqrt(g)` with `g` drawn from
        `tf.random.gamma` (chi-squared draws, per the comment below). The
        matrix is left-multiplied by `self.scale_operator`, and unless
        `self.input_output_cholesky` is set the result is `x @ x^H`. This looks
        like a Bartlett-decomposition sampler -- TODO confirm once the
        truncated gamma call is restored.

        Args:
          n: Number of samples; becomes the leading dimension of the result.
          seed: Seed used to build `SeedStream(seed, salt="Wishart")`.

        Returns:
          A `Tensor` of shape `[n] + batch_shape + event_shape`.
        """
        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        batch_ndims = tf.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = tf.concat([[n], batch_shape, event_shape], 0)
        stream = SeedStream(seed, salt="Wishart")

        # Complexity: O(nbk**2)
        # NOTE(review): the remaining arguments of this call and its closing
        # parenthesis are missing from this listing; the method will not parse
        # until they are restored.
        x = tf.random.normal(shape=shape,

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        # NOTE(review): the argument list of `tf.ones(` below is truncated.
        expanded_df = self.df * tf.ones(

        # NOTE(review): this `tf.random.gamma` call is truncated; the next line
        # looks like the tail of a nested call (presumably building a
        # per-diagonal-position alpha from `0.5 * expanded_df`) whose head is
        # missing -- TODO restore from the upstream source.
        g = tf.random.gamma(shape=[n],
                                0.5 * expanded_df, self.dimension),

        # Complexity: O(nbk**2)
        x = tf.linalg.band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        x = tf.linalg.set_diag(x, tf.sqrt(g))

        # Make batch-op ready: move the sample axis to the end and fold it into
        # the last event axis, giving batch_shape + [k, k * n].
        # Complexity: O(nbk**2)
        perm = tf.concat([tf.range(1, ndims), [0]], 0)
        x = tf.transpose(a=x, perm=perm)
        shape = tf.concat(
            [batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
        x = tf.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of applying the operator
        # to a vector. For LinearOperatorLowerTriangular, each matmul is O(k^3) so
        # this step has complexity O(nbk^3).
        x = self.scale_operator.matmul(x)

        # Undo make batch-op ready: unfold back to batch_shape + [k, k, n] and
        # move the sample axis to the front.
        # Complexity: O(nbk**2)
        shape = tf.concat([batch_shape, event_shape, [n]], 0)
        x = tf.reshape(x, shape)
        perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
        x = tf.transpose(a=x, perm=perm)

        if not self.input_output_cholesky:
            # Complexity: O(nbk**3)
            x = tf.matmul(x, x, adjoint_b=True)

        return x
Пример #15
    def _sample_n(self, num_samples, seed=None, name=None):
        """Returns a Tensor of samples from an LKJ distribution.

        Samples are built in Cholesky space one row at a time (see the comments
        below). Unless `self.input_output_cholesky` is set, the Cholesky factor
        is multiplied by its own transpose and the diagonal is reset to ones.

        Args:
          num_samples: Python `int`. The number of samples to draw.
          seed: Python integer seed for RNG
          name: Python `str` name prefixed to Ops created by this function.

        Returns:
          samples: A Tensor of correlation matrices with shape `[n, B, D, D]`,
            where `B` is the shape of the `concentration` parameter, and `D`
            is the `dimension`. If `self.input_output_cholesky` is set, the
            lower-triangular Cholesky factors are returned instead.

        Raises:
          ValueError: If `dimension` is negative.
          TypeError: If `concentration` does not have a floating dtype.
        """
        # NOTE(review): "reference [1]" below is not listed in this docstring;
        # presumably Lewandowski, Kurowicka & Joe (2009), "Generating random
        # correlation matrices based on vines and extended onion method" --
        # confirm against the upstream source.
        if self.dimension < 0:
            raise ValueError(
                'Cannot sample negative-dimension correlation matrices.')
        # Notation below: B is the batch shape, i.e., tf.shape(concentration)
        seed = SeedStream(seed, 'sample_lkj')
        # NOTE(review): `'sample_lkj' or name` always evaluates to 'sample_lkj',
        # so `name` is ignored; `name or 'sample_lkj'` was probably intended.
        with tf.name_scope('sample_lkj' or name):
            concentration = tf.convert_to_tensor(self.concentration)
            if not dtype_util.is_floating(concentration.dtype):
                # NOTE(review): the message/format arguments and closing
                # parenthesis of this raise are truncated in this listing.
                raise TypeError(
                    'The concentration argument should have floating type, not '

            concentration = _replicate(num_samples, concentration)
            concentration_shape = tf.shape(concentration)
            if self.dimension <= 1:
                # For any dimension <= 1, there is only one possible correlation matrix.
                # NOTE(review): this `tf.concat(` call is missing its axis
                # argument and closing parenthesis.
                shape = tf.concat(
                    [concentration_shape, [self.dimension, self.dimension]],
                return tf.ones(shape=shape, dtype=concentration.dtype)
            beta_conc = concentration + (self.dimension - 2.) / 2.
            # NOTE(review): truncated; the second Beta concentration is missing
            # (presumably also `beta_conc`, giving a symmetric Beta -- confirm).
            beta_dist = beta.Beta(concentration1=beta_conc,

            # Note that the sampler below deviates from [1], by doing the sampling in
            # cholesky space. This does not change the fundamental logic of the
            # sampler, but does speed up the sampling.

            # This is the correlation coefficient between the first two dimensions.
            # This is also `r` in reference [1].
            corr12 = 2. * beta_dist.sample(seed=seed()) - 1.

            # Below we construct the Cholesky of the initial 2x2 correlation matrix,
            # which is of the form:
            # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
            # first two dimensions.
            # This is the top-left corner of the cholesky of the final sample.
            # NOTE(review): the three `tf.concat([` calls below are missing their
            # closing `], axis=...)` lines in this listing.
            first_row = tf.concat([
                tf.ones_like(corr12)[..., tf.newaxis],
                tf.zeros_like(corr12)[..., tf.newaxis]
            second_row = tf.concat([
                corr12[..., tf.newaxis],
                tf.sqrt(1 - corr12**2)[..., tf.newaxis]

            chol_result = tf.concat([
                first_row[..., tf.newaxis, :], second_row[..., tf.newaxis, :]

            for n in range(2, self.dimension):
                # Loop invariant: on entry, result has shape B + [n, n]
                beta_conc = beta_conc - 0.5
                # norm is y in reference [1].
                # NOTE(review): truncated; the second Beta concentration and the
                # `.sample(...)` call are missing from this listing.
                norm = beta.Beta(concentration1=n / 2.,
                # distance shape: B + [1] for broadcast
                distance = tf.sqrt(norm)[..., tf.newaxis]
                # direction is u in reference [1].
                # direction shape: B + [n]
                # NOTE(review): passes the SeedStream object itself rather than
                # `seed()`; confirm `_uniform_unit_norm` expects a stream.
                direction = _uniform_unit_norm(n, concentration_shape,
                                               concentration.dtype, seed)
                # raw_correlation is w in reference [1].
                raw_correlation = distance * direction  # shape: B + [n]

                # This is the next row in the cholesky of the result,
                # which differs from the construction in reference [1].
                # In the reference, the new row `z` = chol_result @ raw_correlation^T
                # = C @ raw_correlation^T (where as short hand we use C = chol_result).
                # We prove that the below equation is the right row to add to the
                # cholesky, by showing equality with reference [1].
                # Let S be the sample constructed so far, and let `z` be as in
                # reference [1]. Then at this iteration, the new sample S' will be
                # [[S z^T]
                #  [z 1]]
                # In our case we have the cholesky decomposition factor C, so
                # we want our new row x (same size as z) to satisfy:
                #  [[S z^T]  [[C 0]    [[C^T  x^T]         [[CC^T  Cx^T]
                #   [z 1]] =  [x k]]    [0     k]]  =       [xC^T   xx^T + k**2]]
                # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
                # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
                # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
                # distance**2).
                # NOTE(review): the opening of this concat's list (presumably
                # `[raw_correlation,`, per the derivation above) and its axis
                # argument are missing from this listing.
                new_row = tf.concat(
                     tf.sqrt(1. - norm[..., tf.newaxis])],

                # Finally add this new row, by growing the cholesky of the result.
                # NOTE(review): truncated; presumably this appends a zero column to
                # `chol_result` before the new row is added -- the first list
                # element and the closing `], axis=...)` are missing here.
                chol_result = tf.concat([
                    tf.zeros_like(chol_result[..., 0][..., tf.newaxis])

                chol_result = tf.concat(
                    [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

            if self.input_output_cholesky:
                return chol_result

            result = tf.matmul(chol_result, chol_result, transpose_b=True)
            # The diagonal for a correlation matrix should always be ones. Due to
            # numerical instability the matmul might not achieve that, so manually set
            # these to ones.
            # NOTE(review): the dtype argument and closing parentheses of this
            # call are truncated in this listing.
            result = tf.linalg.set_diag(
                result, tf.ones(shape=tf.shape(result)[:-1],
            # This sampling algorithm can produce near-PSD matrices on which standard
            # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
            # fail. Specifically, as documented in b/116828694, around 2% of trials
            # of 900,000 5x5 matrices (distributed according to 9 different
            # concentration parameter values) contained at least one matrix on which
            # the Cholesky decomposition failed.
            return result