    def _get_flattened_marginal_distribution(self, index_points=None):
        # This returns an MVN with event size [N * E], where N is the number of tasks
        # and E is the number of index points.
        with self._name_and_control_scope(
                'get_flattened_marginal_distribution'):
            index_points = self._get_index_points(index_points)
            covariance = self._compute_flattened_covariance(index_points)

            batch_shape = self._batch_shape_tensor(index_points=index_points)
            event_shape = self._event_shape_tensor(index_points=index_points)

            # Take the Cholesky, specializing to the block-diagonal and
            # Kronecker cases where possible.
            covariance_cholesky = cholesky_util.cholesky_from_fn(
                covariance, self._cholesky_fn)
            loc = self._mean_fn(index_points)
            # Broadcast the mean function result so that constant mean
            # functions are supported (constant over all tasks, or constant
            # per task).
            loc = ps.broadcast_to(
                loc, ps.concat([batch_shape, event_shape], axis=0))
            loc = _vec(loc)
            return mvn_linear_operator.MultivariateNormalLinearOperator(
                loc=loc,
                scale=covariance_cholesky,
                validate_args=self._validate_args,
                allow_nan_stats=self._allow_nan_stats,
                name='marginal_distribution')
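# For reference, a minimal sketch of the `_vec` helper assumed above (it
# mirrors the module's flattening convention and is not the definitive
# implementation): the trailing [e, t] matrix dimensions are collapsed into a
# single event dimension of size e * t, matching the flattened MVN above.
import tensorflow as tf

def _vec_sketch(x):
    # Collapse the last two dimensions into one, keeping batch dimensions.
    return tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], axis=0))

# Example: a [4, 3] (index points x tasks) mean matrix flattens to shape [12].
# _vec_sketch(tf.zeros([4, 3])).shape == TensorShape([12])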
    def testCholeskyFnType(self):
        identity = tf.linalg.LinearOperatorIdentity(3)
        self.assertIsInstance(
            cholesky_util.cholesky_from_fn(identity, tf.linalg.cholesky),
            tf.linalg.LinearOperatorIdentity)

        diag = tf.linalg.LinearOperatorDiag([3., 5., 7.])
        self.assertIsInstance(
            cholesky_util.cholesky_from_fn(diag, tf.linalg.cholesky),
            tf.linalg.LinearOperatorDiag)

        kron = tf.linalg.LinearOperatorKronecker([identity, diag])
        self.assertIsInstance(
            cholesky_util.cholesky_from_fn(kron, tf.linalg.cholesky),
            tf.linalg.LinearOperatorKronecker)

        block_diag = tf.linalg.LinearOperatorBlockDiag([identity, diag])
        self.assertIsInstance(
            cholesky_util.cholesky_from_fn(block_diag, tf.linalg.cholesky),
            tf.linalg.LinearOperatorBlockDiag)
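# A hedged sketch (not the library code) of the structure-preserving dispatch
# that the assertions above exercise: identity, diagonal, Kronecker, and
# block-diagonal operators keep their operator type, and only the dense
# fallback actually calls `cholesky_fn`.
import tensorflow as tf

def cholesky_from_fn_sketch(linop, cholesky_fn):
    if isinstance(linop, tf.linalg.LinearOperatorIdentity):
        return linop  # chol(I) = I.
    if isinstance(linop, tf.linalg.LinearOperatorDiag):
        # chol(D) = sqrt(D) for a PSD diagonal.
        return tf.linalg.LinearOperatorDiag(tf.math.sqrt(linop.diag))
    if isinstance(linop, tf.linalg.LinearOperatorKronecker):
        # chol(A X B) = chol(A) X chol(B), applied recursively.
        return tf.linalg.LinearOperatorKronecker(
            [cholesky_from_fn_sketch(op, cholesky_fn)
             for op in linop.operators])
    if isinstance(linop, tf.linalg.LinearOperatorBlockDiag):
        # The Cholesky of a block-diagonal matrix is block-diagonal.
        return tf.linalg.LinearOperatorBlockDiag(
            [cholesky_from_fn_sketch(op, cholesky_fn)
             for op in linop.operators])
    # Dense fallback: the only branch that runs `cholesky_fn`.
    return tf.linalg.LinearOperatorLowerTriangular(
        cholesky_fn(linop.to_dense()))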
def _compute_observation_scale(kernel,
                               observation_index_points,
                               cholesky_fn,
                               observation_noise_variance=None,
                               observations_is_missing=None):
    """Compute matrix square root of the kernel on observation index points."""
    if observations_is_missing is not None:
        observations_is_missing = tf.convert_to_tensor(observations_is_missing)
        # If observations are missing, there's nothing we can do to preserve the
        # operator structure, so densify.

        observation_covariance = kernel.matrix_over_all_tasks(
            observation_index_points, observation_index_points).to_dense()

        if observation_noise_variance is not None:
            broadcast_shape = distribution_util.get_broadcast_shape(
                observation_covariance,
                observation_noise_variance[..., tf.newaxis, tf.newaxis])
            observation_covariance = tf.broadcast_to(observation_covariance,
                                                     broadcast_shape)
            observation_covariance = _add_diagonal_shift(
                observation_covariance, observation_noise_variance)
        vec_observations_is_missing = _vec(observations_is_missing)
        observation_covariance = tf.linalg.LinearOperatorFullMatrix(
            psd_kernels_util.mask_matrix(
                observation_covariance,
                is_missing=vec_observations_is_missing),
            is_non_singular=True,
            is_positive_definite=True)
        observation_scale = cholesky_util.cholesky_from_fn(
            observation_covariance, cholesky_fn)
    else:
        observation_scale = mtgp._compute_flattened_scale(  # pylint:disable=protected-access
            kernel=kernel,
            index_points=observation_index_points,
            cholesky_fn=cholesky_fn,
            observation_noise_variance=observation_noise_variance)

    return observation_scale
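# A minimal sketch of the masking contract the code above relies on
# (`mask_matrix_sketch` is a hypothetical stand-in for
# `psd_kernels_util.mask_matrix`): rows and columns of missing entries are
# zeroed out and their diagonal set to one, so the matrix stays positive
# definite and the missing dimensions decouple from the observed ones.
import tensorflow as tf

def mask_matrix_sketch(x, is_missing):
    keep = ~is_missing
    # Keep entry (i, j) only when both index i and index j are observed.
    pair_mask = tf.cast(
        keep[..., :, tf.newaxis] & keep[..., tf.newaxis, :], x.dtype)
    x = x * pair_mask
    # A unit diagonal for missing entries keeps the Cholesky well defined.
    diag = tf.where(keep, tf.linalg.diag_part(x), tf.ones([], x.dtype))
    return tf.linalg.set_diag(x, diag)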
    @staticmethod
    def precompute_regression_model(
            kernel,
            observation_index_points,
            observations,
            observations_is_missing=None,
            index_points=None,
            observation_noise_variance=None,
            predictive_noise_variance=None,
            mean_fn=None,
            cholesky_fn=None,
            validate_args=False,
            allow_nan_stats=False,
            name='PrecomputedMultiTaskGaussianProcessRegressionModel'):
        """Returns a MTGaussianProcessRegressionModel with precomputed quantities.

    This differs from the constructor by precomputing quantities associated with
    observations in a non-tape-safe way. `index_points` is the only parameter
    that is allowed to vary (i.e. is a `Variable` / changes after
    initialization).

    Specifically:

    * We make `observation_index_points` and `observations` mandatory
      parameters.
    * We precompute `kernel(observation_index_points, observation_index_points)`
      along with any other associated quantities relating to the `kernel`,
      `observations` and `observation_index_points`.

    A typical use case would be optimizing kernel hyperparameters for a
    `MultiTaskGaussianProcess`, and computing the posterior predictive with
    respect to those optimized hyperparameters and observation / index-points
    pairs.

    WARNING: This method assumes `index_points` is the only varying parameter
    (i.e. is a `Variable` / changes after initialization) and hence is not
    tape-safe.

    Args:
      kernel: `MultiTaskKernel`-like instance representing the GP's covariance
        GP's covariance function.
      observation_index_points: `float` `Tensor` representing finite collection,
        or batch of collections, of points in the index set for which some data
        has been observed. Shape has the form `[b1, ..., bB, e, f1, ..., fF]`
        where `F` is the number of feature dimensions and must equal
        `kernel.feature_ndims`, and `e` is the number (size) of index points in
        each batch. `[b1, ..., bB, e]` must be broadcastable with the shape of
        `observations`, and `[b1, ..., bB]` must be broadcastable with the
        shapes of all other batched parameters (`kernel.batch_shape`,
        `index_points`, etc). Unlike in the constructor, this argument is
        required here.
      observations: `float` `Tensor` representing collection, or batch of
        collections, of observations corresponding to
        `observation_index_points`. Shape has the form `[b1, ..., bB, e, t]`,
        where `t` is the number of tasks. The batch shape `[b1, ..., bB]` must
        be broadcastable with the shapes of all other batched parameters
        (`kernel.batch_shape`, `index_points`, etc.). Unlike in the
        constructor, this argument is required here.
      observations_is_missing:  `bool` `Tensor` of shape `[..., e, t]`,
        representing a batch of boolean masks.  When `observations_is_missing`
        is not `None`, the returned distribution is conditioned only on the
        observations for which the corresponding elements of
        `observations_is_missing` are `False`.
      index_points: `float` `Tensor` representing finite collection, or batch of
        collections, of points in the index set over which the GP is defined.
        Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the
        number of feature dimensions and must equal `kernel.feature_ndims` and
        `e` is the number (size) of index points in each batch. Ultimately this
        distribution corresponds to an `e`-dimensional multivariate normal. The
        batch shape must be broadcastable with `kernel.batch_shape` and any
        batch dims yielded by `mean_fn`.
      observation_noise_variance: `float` `Tensor` representing the variance
        of the noise in the Normal likelihood distribution of the model. May be
        batched, in which case the batch shape must be broadcastable with the
        shapes of all other batched parameters (`kernel.batch_shape`,
        `index_points`, etc.).
        Default value: `None`
      predictive_noise_variance: `float` `Tensor` representing the variance in
        the posterior predictive model. If `None`, we simply re-use
        `observation_noise_variance` for the posterior predictive noise. If set
        explicitly, however, we use this value. This allows us, for example, to
        omit predictive noise variance (by setting this to zero) to obtain
        noiseless posterior predictions of function values, conditioned on noisy
        observations.
      mean_fn: Python `callable` that acts on `index_points` to produce a
        collection, or batch of collections, of mean values at `index_points`.
        Takes a `Tensor` of shape `[b1, ..., bB, e, f1, ..., fF]` and returns a
        `Tensor` whose shape is broadcastable with `[b1, ..., bB, e, t]`.
        Default value: `None` implies the constant zero function.
      cholesky_fn: Callable which takes a single (batch) matrix argument and
        returns a Cholesky-like lower triangular factor.  Default value: `None`,
        in which case `make_cholesky_with_jitter_fn()` is used with its default
        jitter.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `False`. When `True`,
        statistics (e.g., mean, mode, variance) use the value `NaN` to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
        Default value: `False`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: 'PrecomputedMultiTaskGaussianProcessRegressionModel'.
    Returns:
      An instance of `MultiTaskGaussianProcessRegressionModel` with precomputed
      quantities associated with observations.
    """

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([
                index_points,
                observation_index_points,
                observations,
                observation_noise_variance,
                predictive_noise_variance,
            ], tf.float32)

            # Convert to tensor the arguments that are expected not to be
            # Variables (i.e. not going to change after initialization).
            observation_index_points = tf.convert_to_tensor(
                observation_index_points, dtype=dtype)
            if observation_noise_variance is not None:
                observation_noise_variance = tf.convert_to_tensor(
                    observation_noise_variance, dtype=dtype)
            observations = tf.convert_to_tensor(observations, dtype=dtype)

            if observations_is_missing is not None:
                observations_is_missing = tf.convert_to_tensor(
                    observations_is_missing)

            if cholesky_fn is None:
                cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()
            else:
                if not callable(cholesky_fn):
                    raise ValueError('`cholesky_fn` must be a Python callable')

            if mean_fn is None:
                mean_fn = lambda x: tf.zeros([1], dtype=dtype)
            else:
                if not callable(mean_fn):
                    raise ValueError('`mean_fn` must be a Python callable')

            if observations_is_missing is not None:
                # If observations are missing, there's nothing we can do to preserve the
                # operator structure, so densify.

                observation_covariance = kernel.matrix_over_all_tasks(
                    observation_index_points,
                    observation_index_points).to_dense()

                if observation_noise_variance is not None:
                    broadcast_shape = distribution_util.get_broadcast_shape(
                        observation_covariance,
                        observation_noise_variance[..., tf.newaxis,
                                                   tf.newaxis])
                    observation_covariance = tf.broadcast_to(
                        observation_covariance, broadcast_shape)
                    observation_covariance = _add_diagonal_shift(
                        observation_covariance, observation_noise_variance)
                vec_observations_is_missing = _vec(observations_is_missing)
                observation_covariance = tf.linalg.LinearOperatorFullMatrix(
                    psd_kernels_util.mask_matrix(
                        observation_covariance,
                        is_missing=vec_observations_is_missing),
                    is_non_singular=True,
                    is_positive_definite=True)
                observation_scale = cholesky_util.cholesky_from_fn(
                    observation_covariance, cholesky_fn)
            else:
                observation_scale = mtgp._compute_flattened_scale(  # pylint:disable=protected-access
                    kernel=kernel,
                    index_points=observation_index_points,
                    cholesky_fn=cholesky_fn,
                    observation_noise_variance=observation_noise_variance)

            # Note that the conditional mean is
            # k(x, o) @ (k(o, o) + sigma**2)^-1 obs. We can precompute the latter
            # term since it won't change per iteration.
            vec_diff = _vec(observations - mean_fn(observation_index_points))

            if observations_is_missing is not None:
                vec_diff = tf.where(vec_observations_is_missing,
                                    tf.zeros([], dtype=vec_diff.dtype),
                                    vec_diff)
            solve_on_observations = observation_scale.solvevec(
                observation_scale.solvevec(vec_diff), adjoint=True)

            def flattened_conditional_mean_fn(x):
                return _flattened_conditional_mean_fn_helper(
                    x,
                    kernel,
                    observations,
                    observation_index_points,
                    observations_is_missing,
                    observation_scale,
                    mean_fn,
                    solve_on_observations=solve_on_observations)

            mtgprm = MultiTaskGaussianProcessRegressionModel(
                kernel=kernel,
                observation_index_points=observation_index_points,
                observations=observations,
                index_points=index_points,
                observation_noise_variance=observation_noise_variance,
                predictive_noise_variance=predictive_noise_variance,
                cholesky_fn=cholesky_fn,
                _flattened_conditional_mean_fn=flattened_conditional_mean_fn,
                _observation_scale=observation_scale,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                name=name)

        return mtgprm
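# Hypothetical usage sketch for `precompute_regression_model`. The kernel and
# distribution names below assume the TensorFlow Probability experimental API
# rather than anything defined in this file; shapes and hyperparameters are
# illustrative.
import numpy as np
import tensorflow_probability as tfp

base_kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(length_scale=0.5)
# Three independent tasks sharing a single base kernel over a 1-d index set.
kernel = tfp.experimental.psd_kernels.Independent(
    num_tasks=3, base_kernel=base_kernel)

observation_index_points = np.random.uniform(
    -1., 1., [20, 1]).astype(np.float32)
observations = np.random.normal(size=[20, 3]).astype(np.float32)
index_points = np.linspace(-1., 1., 50, dtype=np.float32)[:, np.newaxis]

MTGPRM = tfp.experimental.distributions.MultiTaskGaussianProcessRegressionModel
mtgprm = MTGPRM.precompute_regression_model(
    kernel=kernel,
    observation_index_points=observation_index_points,
    observations=observations,
    index_points=index_points,
    observation_noise_variance=np.float32(0.1))
# All observation-dependent solves are now fixed; per the tape-safety warning,
# only `index_points` should vary afterwards (e.g. in a prediction loop).
posterior_mean = mtgprm.mean()  # Shape [50, 3]: per index point, per task.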
  def __init__(self,
               kernel,
               observation_index_points,
               observations,
               observations_is_missing=None,
               index_points=None,
               mean_fn=None,
               observation_noise_variance=None,
               predictive_noise_variance=None,
               cholesky_fn=None,
               validate_args=False,
               allow_nan_stats=False,
               name='MultiTaskGaussianProcessRegressionModelWithCholesky'):
    """Construct a MultiTaskGaussianProcessRegressionModelWithCholesky instance.

    WARNING: This method assumes `index_points` is the only varying parameter
    (i.e. is a `Variable` / changes after initialization) and hence is not
    tape-safe.

    Args:
      kernel: `MultiTaskKernel`-like instance representing the GP's covariance
        function.
      observation_index_points: `float` `Tensor` representing finite collection,
        or batch of collections, of points in the index set for which some data
        has been observed. Shape has the form `[b1, ..., bB, e, f1, ..., fF]`
        where `F` is the number of feature dimensions and must equal
        `kernel.feature_ndims`, and `e` is the number (size) of index points in
        each batch. `[b1, ..., bB, e]` must be broadcastable with the shape of
        `observations`, and `[b1, ..., bB]` must be broadcastable with the
        shapes of all other batched parameters (`kernel.batch_shape`,
        `index_points`, etc).
      observations: `float` `Tensor` representing collection, or batch of
        collections, of observations corresponding to
        `observation_index_points`. Shape has the form `[b1, ..., bB, e, t]`,
        which must be broadcastable with the batch and example shapes of
        `observation_index_points`. The batch shape `[b1, ..., bB]` must be
        broadcastable with the shapes of all other batched parameters
        (`kernel.batch_shape`, `index_points`, etc.).
      observations_is_missing:  `bool` `Tensor` of shape `[..., e, t]`,
        representing a batch of boolean masks.  When
        `observations_is_missing` is not `None`, this distribution is
        conditioned only on the observations for which the
        corresponding elements of `observations_is_missing` are `False`.
      index_points: `float` `Tensor` representing finite collection, or batch of
        collections, of points in the index set over which the GP is defined.
        Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the
        number of feature dimensions and must equal `kernel.feature_ndims` and
        `e` is the number (size) of index points in each batch. Ultimately this
        distribution corresponds to an `e`-dimensional multivariate normal. The
        batch shape must be broadcastable with `kernel.batch_shape`.
      mean_fn: Python `callable` that acts on `index_points` to produce a (batch
        of) collection of mean values at `index_points`. Takes a `Tensor` of
        shape `[b1, ..., bB, e, f1, ..., fF]` and returns a `Tensor` whose shape
        is broadcastable with `[b1, ..., bB, e, t]`, where `t` is the number of
        tasks.
        Default value: `None` implies the constant zero function.
      observation_noise_variance: `float` `Tensor` representing the variance of
        the noise in the Normal likelihood distribution of the model. May be
        batched, in which case the batch shape must be broadcastable with the
        shapes of all other batched parameters (`kernel.batch_shape`,
        `index_points`, etc.).
        Default value: `None`
      predictive_noise_variance: `float` `Tensor` representing the variance in
        the posterior predictive model. If `None`, we simply re-use
        `observation_noise_variance` for the posterior predictive noise. If set
        explicitly, however, we use this value. This allows us, for example, to
        omit predictive noise variance (by setting this to zero) to obtain
        noiseless posterior predictions of function values, conditioned on noisy
        observations.
      cholesky_fn: Callable which takes a single (batch) matrix argument and
        returns a Cholesky-like lower triangular factor.  Default value: `None`,
        in which case `make_cholesky_with_jitter_fn(1e-6)` is used.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
        (e.g., mean, mode, variance) use the value `NaN` to indicate the result
        is undefined. When `False`, an exception is raised if one or more of the
        statistic's batch members are undefined.
        Default value: `False`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: 'MultiTaskGaussianProcessRegressionModelWithCholesky'.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:

      if not isinstance(kernel, multitask_kernel.MultiTaskKernel):
        raise ValueError('`kernel` must be a `MultiTaskKernel`.')

      dtype = dtype_util.common_dtype([
          index_points, observation_index_points, observations,
          observation_noise_variance, predictive_noise_variance
      ], tf.float32)
      index_points = tensor_util.convert_nonref_to_tensor(
          index_points, dtype=dtype, name='index_points')
      observation_index_points = tf.convert_to_tensor(
          observation_index_points,
          dtype=dtype,
          name='observation_index_points')
      observations = tf.convert_to_tensor(
          observations, dtype=dtype, name='observations')
      if observations_is_missing is not None:
        observations_is_missing = tf.convert_to_tensor(
            observations_is_missing, dtype=tf.bool)
      if observation_noise_variance is not None:
        observation_noise_variance = tf.convert_to_tensor(
            observation_noise_variance,
            dtype=dtype,
            name='observation_noise_variance')
      predictive_noise_variance = tensor_util.convert_nonref_to_tensor(
          predictive_noise_variance,
          dtype=dtype,
          name='predictive_noise_variance')
      if predictive_noise_variance is None:
        predictive_noise_variance = observation_noise_variance
      if cholesky_fn is None:
        self._cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()
      else:
        if not callable(cholesky_fn):
          raise ValueError('`cholesky_fn` must be a Python callable')
        self._cholesky_fn = cholesky_fn

      self._kernel = kernel
      self._index_points = index_points

      # The mean may be a scalar or a vector of size equal to the number of
      # tasks.
      if mean_fn is not None:
        if not callable(mean_fn):
          raise ValueError('`mean_fn` must be a Python callable')
      self._mean_fn = mean_fn
      self._observation_noise_variance = observation_noise_variance
      self._predictive_noise_variance = predictive_noise_variance
      self._observation_index_points = observation_index_points
      self._observations = observations
      self._observations_is_missing = observations_is_missing

      observation_covariance = self.kernel.matrix_over_all_tasks(
          observation_index_points, observation_index_points)

      if observation_noise_variance is not None:
        observation_covariance = observation_covariance.to_dense()
        broadcast_shape = distribution_util.get_broadcast_shape(
            observation_covariance, observation_noise_variance[..., tf.newaxis,
                                                               tf.newaxis])
        observation_covariance = tf.broadcast_to(observation_covariance,
                                                 broadcast_shape)
        observation_covariance = _add_diagonal_shift(observation_covariance,
                                                     observation_noise_variance)
        observation_covariance = tf.linalg.LinearOperatorFullMatrix(
            observation_covariance,
            is_non_singular=True,
            is_positive_definite=True)

      if observations_is_missing is not None:
        vec_observations_is_missing = _vec(observations_is_missing)
        observation_covariance = tf.linalg.LinearOperatorFullMatrix(
            psd_kernels_util.mask_matrix(
                observation_covariance.to_dense(),
                is_missing=vec_observations_is_missing),
            is_non_singular=True,
            is_positive_definite=True)

      self._observation_cholesky = cholesky_util.cholesky_from_fn(
          observation_covariance, self._cholesky_fn)

      # Note that the conditional mean is
      # k(x, o) @ (k(o, o) + sigma**2)^-1 obs. We can precompute the latter
      # term since it won't change per iteration.
      if mean_fn:
        vec_observations = _vec(observations -
                                mean_fn(observation_index_points))
      else:
        vec_observations = _vec(observations)
      if observations_is_missing is not None:
        vec_observations = tf.where(~vec_observations_is_missing,
                                    vec_observations,
                                    tf.zeros([], dtype=vec_observations.dtype))
      self._solve_on_obs = self._observation_cholesky.solvevec(
          self._observation_cholesky.solvevec(vec_observations), adjoint=True)
      super(MultiTaskGaussianProcessRegressionModel, self).__init__(
          dtype=dtype,
          reparameterization_type=(reparameterization.FULLY_REPARAMETERIZED),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)
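# Hedged sketch of how the precomputed solve above is consumed at prediction
# time (`conditional_mean_sketch` is a hypothetical stand-in for the flattened
# conditional-mean helper, and assumes a non-`None` `mean_fn`). Since
# `solve_on_obs` already equals (k(o, o) + sigma^2 I)^-1 @ (obs - m(o)), each
# prediction reduces to one structured matvec:
def conditional_mean_sketch(kernel, x, observation_index_points, solve_on_obs,
                            mean_fn):
    # m(x) + k(x, o) @ (k(o, o) + sigma^2 I)^-1 @ (obs - m(o)).
    k_x_obs = kernel.matrix_over_all_tasks(x, observation_index_points)
    return _vec(mean_fn(x)) + k_x_obs.matvec(solve_on_obs)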
def _compute_flattened_scale(kernel,
                             index_points,
                             cholesky_fn,
                             observation_noise_variance=None):
    """Computes a matrix square root of the flattened covariance matrix.

  Given a multi-task kernel `k`, computes a matrix square root of the
  matrix over all tasks of `index_points`. That is, compute `S` such that
  `S @ S^T = k.matrix_over_all_tasks(index_points, index_points)`.

  In the case of a `Separable` or `Independent` kernel, this function tries to
  do this efficiently in O(N^3 + T^3) time where `N` is the number of
  `index_points` and `T` is the number of tasks.

  Args:
    kernel: `MultiTaskKernel`-like instance representing the GP's covariance
      function.
    index_points: `float` `Tensor` representing finite collection, or batch of
      collections, of points in the index set over which the GP is defined.
      Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the
      number of feature dimensions and must equal `kernel.feature_ndims` and
      `e` is the number (size) of index points in each batch. Ultimately this
      distribution corresponds to an `e`-dimensional multivariate normal. The
      batch shape must be broadcastable with `kernel.batch_shape`.
    cholesky_fn: Callable which takes a single (batch) matrix argument and
      returns a Cholesky-like lower triangular factor. Unlike in the class
      constructors above, this argument is required and has no default.
    observation_noise_variance: `float` `Tensor` representing the variance
      of the noise in the Normal likelihood distribution of the model. May be
      batched, in which case the batch shape must be broadcastable with the
      shapes of all other batched parameters (`kernel.batch_shape`,
      `index_points`, etc.).
      Default value: `None`
  Returns:
    scale_operator: `LinearOperator` representing a matrix square root of
    the flattened kernel matrix over all tasks.

  """
    # This is of shape KN x KN, where K is the number of outputs (tasks) and
    # N is the number of index points.
    kernel_matrix = kernel.matrix_over_all_tasks(index_points, index_points)
    if observation_noise_variance is None:
        return cholesky_util.cholesky_from_fn(kernel_matrix, cholesky_fn)

    observation_noise_variance = tf.convert_to_tensor(
        observation_noise_variance)

    # We can add the observation noise to each block.
    if isinstance(kernel, multitask_kernel.Independent):
        # The Independent kernel matrix is realized as a kronecker product of the
        # kernel over inputs, and an identity matrix per task (representing
        # independent tasks). Update the diagonal of the first matrix and take the
        # cholesky of it (since the cholesky of the second matrix will remain the
        # identity matrix.)
        base_kernel_matrix = kernel_matrix.operators[0].to_dense()

        broadcast_shape = distribution_util.get_broadcast_shape(
            base_kernel_matrix, observation_noise_variance[..., tf.newaxis,
                                                           tf.newaxis])
        base_kernel_matrix = tf.broadcast_to(base_kernel_matrix,
                                             broadcast_shape)
        base_kernel_matrix = tf.linalg.set_diag(
            base_kernel_matrix,
            tf.linalg.diag_part(base_kernel_matrix) +
            observation_noise_variance[..., tf.newaxis])
        base_kernel_matrix = tf.linalg.LinearOperatorFullMatrix(
            base_kernel_matrix)
        kernel_matrix = tf.linalg.LinearOperatorKronecker(
            operators=[base_kernel_matrix] + kernel_matrix.operators[1:])
        return cholesky_util.cholesky_from_fn(kernel_matrix, cholesky_fn)

    if isinstance(kernel, multitask_kernel.Separable):
        # When `kernel_matrix` is a kronecker product, we can compute
        # an eigenvalue decomposition to get a matrix square-root, which will
        # be faster than densifying the kronecker product.

        # Let K = A X B. Let A (and B) have an eigenvalue decomposition of
        # U @ D @ U^T, where U is an orthogonal matrix. Then,
        # K = (U_A @ D_A @ U_A^T) X (U_B @ D_B @ U_B^T) =
        # (U_A X U_B) @ (D_A X D_B) @ (U_A X U_B)^T
        # Thus, a matrix square root of K would be
        # (U_A X U_B) @ (sqrt(D_A) X sqrt(D_B)) which offers
        # efficient matmul and solves.

        # Now, if we update the diagonal by `v * I`, we have
        # K + v * I = (U_A X U_B) @ (D_A X D_B + v * I) @ (U_A X U_B)^T,
        # whose matrix square root is (U_A X U_B) @ sqrt(D_A X D_B + v * I),
        # which still admits an efficient matmul and solve.

        kronecker_diags = []
        kronecker_orths = []
        for block in kernel_matrix.operators:
            diag, orth = tf.linalg.eigh(block.to_dense())
            kronecker_diags.append(tf.linalg.LinearOperatorDiag(diag))
            kronecker_orths.append(
                linear_operator_unitary.LinearOperatorUnitary(orth))

        full_diag = tf.linalg.LinearOperatorKronecker(
            kronecker_diags).diag_part()
        full_diag = full_diag + observation_noise_variance[..., tf.newaxis]
        scale_diag = tf.math.sqrt(full_diag)
        diag_operator = tf.linalg.LinearOperatorDiag(scale_diag,
                                                     is_square=True,
                                                     is_non_singular=True,
                                                     is_positive_definite=True)

        orthogonal_operator = tf.linalg.LinearOperatorKronecker(
            kronecker_orths, is_square=True, is_non_singular=True)
        # This is efficient as a scale matrix. When used for matmuls, we take
        # advantage of the kronecker product and diagonal operator. When used for
        # solves, we take advantage of the orthogonal and diagonal structure,
        # which essentially reduces to the matmul case.
        return orthogonal_operator.matmul(diag_operator)

    # By default densify the kernel matrix and add noise.

    kernel_matrix = kernel_matrix.to_dense()
    broadcast_shape = distribution_util.get_broadcast_shape(
        kernel_matrix, observation_noise_variance[..., tf.newaxis, tf.newaxis])
    kernel_matrix = tf.broadcast_to(kernel_matrix, broadcast_shape)
    kernel_matrix = tf.linalg.set_diag(
        kernel_matrix,
        tf.linalg.diag_part(kernel_matrix) +
        observation_noise_variance[..., tf.newaxis])
    kernel_matrix = tf.linalg.LinearOperatorFullMatrix(kernel_matrix)
    kernel_cholesky = cholesky_util.cholesky_from_fn(kernel_matrix,
                                                     cholesky_fn)
    return kernel_cholesky
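# Hedged usage sketch: an eager-mode numerical check that the returned
# operator is a genuine matrix square root, i.e. scale @ scale^T ~= K + v * I.
# Assumes `kernel` is a `MultiTaskKernel` instance and `x` is a float `Tensor`
# of index points; the tolerances are illustrative.
import numpy as np
import tensorflow as tf

def check_flattened_scale(kernel, x, noise_variance=0.1):
    scale = _compute_flattened_scale(
        kernel, x, tf.linalg.cholesky,
        observation_noise_variance=noise_variance)
    # Realize scale @ scale^T densely and compare against the noisy kernel.
    realized = scale.matmul(scale.to_dense(), adjoint_arg=True)
    target = kernel.matrix_over_all_tasks(x, x).to_dense()
    target = tf.linalg.set_diag(
        target, tf.linalg.diag_part(target) + noise_variance)
    np.testing.assert_allclose(realized, target, rtol=1e-4, atol=1e-4)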