def surrogate_posterior_kl_divergence_prior(self, name=None):
    """Compute `KL(surrogate inducing point posterior || prior)`.

    The surrogate (variational) distribution is
    `self._variational_inducing_observations_posterior`. The prior is a
    `GaussianProcess` evaluated at `self._inducing_index_points`, using this
    object's kernel, mean function and observation noise variance. See
    [Hensman, 2013][1].

    Args:
      name: Python `str` name prefixed to Ops created by this method.
        Default value: 'surrogate_posterior_kl_divergence_prior'.
    Returns:
      kl: Scalar tensor representing the KL between the (surrogate/variational)
        posterior over inducing point values, and the GP prior over the
        inducing point values.

    #### References

    [1]: Hensman, J., Fusi, N., Lawrence, N. D. "Gaussian Processes for Big
         Data", 2013. https://arxiv.org/abs/1309.6835
    """
    scope_name = name or 'surrogate_posterior_kl_divergence_prior'
    with tf.name_scope(scope_name):
        # NOTE(review): the prior is built with `observation_noise_variance`,
        # i.e. it is a prior over noisy inducing observations rather than
        # noise-free function values -- confirm this is intended.
        prior_at_inducing_points = gaussian_process.GaussianProcess(
            kernel=self._kernel,
            mean_fn=self._mean_fn,
            index_points=self._inducing_index_points,
            observation_noise_variance=self._observation_noise_variance)
        surrogate = self._variational_inducing_observations_posterior
        return kullback_leibler.kl_divergence(
            surrogate, prior_at_inducing_points)
  def variational_loss(self,
                       observations,
                       observation_index_points=None,
                       kl_weight=1.,
                       name='variational_loss'):
    """Variational loss for the VGP.

    Given `observations` and `observation_index_points`, compute the
    negative variational lower bound as specified in [Hensman, 2013][1].

    Args:
      observations: `float` `Tensor` representing collection, or batch of
        collections, of observations corresponding to
        `observation_index_points`. Shape has the form `[b1, ..., bB, e]`, which
        must be broadcastable with the batch and example shapes of
        `observation_index_points`. The batch shape `[b1, ..., bB]` must be
        broadcastable with the shapes of all other batched parameters
        (`kernel.batch_shape`, `observation_index_points`, etc.).
      observation_index_points: `float` `Tensor` representing finite (batch of)
        vector(s) of points where observations are defined. Shape has the
        form `[b1, ..., bB, e1, f1, ..., fF]` where `F` is the number of feature
        dimensions and must equal `kernel.feature_ndims` and `e1` is the number
        (size) of index points in each batch (we denote it `e1` to distinguish
        it from the number of inducing index points, denoted `e2` below). If
        set to `None` uses `index_points` as the origin for observations.
        Default value: None.
      kl_weight: Amount by which to scale the KL divergence loss between prior
        and posterior.
        Default value: 1.
      name: Python `str` name prefixed to Ops created by this method.
        Default value: 'variational_loss'.
    Returns:
      loss: Scalar tensor representing the negative variational lower bound,
        averaged over all batch dimensions. Can be directly used in a
        `tf.Optimizer`.

    #### References

    [1]: Hensman, J., Fusi, N., Lawrence, N. D. "Gaussian Processes for Big
         Data", 2013. https://arxiv.org/abs/1309.6835
    """

    with tf.compat.v1.name_scope(
        name, 'variational_gp_loss', values=[
            observations,
            observation_index_points,
            kl_weight]):
      if observation_index_points is None:
        observation_index_points = self._index_points
      observation_index_points = tf.convert_to_tensor(
          value=observation_index_points, dtype=self._dtype,
          name='observation_index_points')
      observations = tf.convert_to_tensor(
          value=observations, dtype=self._dtype, name='observations')
      kl_weight = tf.convert_to_tensor(
          value=kl_weight, dtype=self._dtype,
          name='kl_weight')

      # The variational loss is a negative ELBO. The ELBO can be broken down
      # into three terms:
      #  1. a likelihood term
      #  2. a trace term arising from the covariance of the posterior predictive
      #  3. a KL term between the variational posterior over inducing values
      #     and the GP prior at the inducing index points

      kzx = self.kernel.matrix(self._inducing_index_points,
                               observation_index_points)

      # Predictive mean: m(x) + Kxz @ (Kzz^-1 varloc); `adjoint=True` applies
      # Kzx^T = Kxz.
      kzx_linop = tf.linalg.LinearOperatorFullMatrix(kzx)
      loc = (self._mean_fn(observation_index_points) +
             kzx_linop.matvec(self._kzz_inv_varloc, adjoint=True))

      likelihood = independent.Independent(
          normal.Normal(
              loc=loc,
              scale=tf.sqrt(self._observation_noise_variance + self._jitter),
              name='NormalLikelihood'),
          reinterpreted_batch_ndims=1)
      obs_ll = likelihood.log_prob(observations)

      # With L = chol(Kzz): L^-1 Kzx, then Kzz^-1 Kzx = L^-T (L^-1 Kzx).
      chol_kzz_linop = tf.linalg.LinearOperatorLowerTriangular(self._chol_kzz)
      chol_kzz_inv_kzx = chol_kzz_linop.solve(kzx)
      kzz_inv_kzx = chol_kzz_linop.solve(chol_kzz_inv_kzx, adjoint=True)

      # Tr(Kxx - Kxz Kzz^-1 Kzx) = sum(diag(Kxx)) - ||L^-1 Kzx||_F^2.
      # NOTE(review): only the diagonal of Kxx is used; forming the full
      # e1 x e1 matrix costs O(e1^2) memory.
      kxx_diag = tf.linalg.diag_part(
          self.kernel.matrix(
              observation_index_points, observation_index_points))
      ktilde_trace_term = (
          tf.reduce_sum(input_tensor=kxx_diag, axis=-1) -
          tf.reduce_sum(input_tensor=chol_kzz_inv_kzx ** 2, axis=[-2, -1]))

      # Tr(SB)
      # where S = A A.T, A = variational_inducing_observations_scale
      # and B = Kzz^-1 Kzx Kzx.T Kzz^-1
      #
      # Now Tr(SB) = Tr(A A.T Kzz^-1 Kzx Kzx.T Kzz^-1)
      #            = Tr(A.T Kzz^-1 Kzx Kzx.T Kzz^-1 A)
      #            = sum_ij (A.T Kzz^-1 Kzx)_{ij}^2
      #
      # `adjoint=True` is required: it forms A.T @ (Kzz^-1 Kzx). Using A
      # itself would give Tr(A.T A B), which differs from Tr(SB) whenever A
      # is not symmetric (e.g. a lower-triangular scale).
      other_trace_term = tf.reduce_sum(
          input_tensor=(
              self._variational_inducing_observations_posterior.scale.matmul(
                  kzz_inv_kzx, adjoint=True) ** 2),
          axis=[-2, -1])

      trace_term = (.5 * (ktilde_trace_term + other_trace_term) /
                    self._observation_noise_variance)

      inducing_prior = gaussian_process.GaussianProcess(
          kernel=self._kernel,
          mean_fn=self._mean_fn,
          index_points=self._inducing_index_points,
          observation_noise_variance=self._observation_noise_variance)

      kl_term = kl_weight * kullback_leibler.kl_divergence(
          self._variational_inducing_observations_posterior,
          inducing_prior)

      lower_bound = (obs_ll - trace_term - kl_term)

      # Average the (possibly batched) bound into a scalar loss.
      return -tf.reduce_mean(input_tensor=lower_bound)