def surrogate_posterior_kl_divergence_prior(self, name=None):
  """Computes `KL(surrogate inducing point posterior || prior)`.

  The "prior" is a `GaussianProcess` indexed at the inducing index points and
  built from this object's kernel, mean function and observation noise
  variance. The "surrogate" is the variational posterior over the inducing
  observations. See [Hensman, 2013][1].

  Args:
    name: Python `str` name prefixed to Ops created by this method. When
      `None` (the default), the scope name
      'surrogate_posterior_kl_divergence_prior' is used.

  Returns:
    kl: Scalar tensor representing the KL between the (surrogate/variational)
      posterior over inducing point function values and the GP prior over the
      inducing point function values.

  #### References

  [1]: Hensman, J., Lawrence, N. "Gaussian Processes for Big Data", 2013
       https://arxiv.org/abs/1309.6835
  """
  scope_name = name or 'surrogate_posterior_kl_divergence_prior'
  with tf.name_scope(scope_name):
    # NOTE(review): the prior is built with `observation_noise_variance`, so
    # it describes noisy inducing observations rather than noise-free
    # function values; confirm this is the intended ELBO prior.
    prior_at_inducing_points = gaussian_process.GaussianProcess(
        kernel=self._kernel,
        mean_fn=self._mean_fn,
        index_points=self._inducing_index_points,
        observation_noise_variance=self._observation_noise_variance)
    surrogate_posterior = self._variational_inducing_observations_posterior
    kl = kullback_leibler.kl_divergence(
        surrogate_posterior, prior_at_inducing_points)
    return kl
def variational_loss(self,
                     observations,
                     observation_index_points=None,
                     kl_weight=1.,
                     name='variational_loss'):
  """Variational loss for the VGP.

  Given `observations` and `observation_index_points`, compute the negative
  variational lower bound as specified in [Hensman, 2013][1].

  Args:
    observations: `float` `Tensor` representing collection, or batch of
      collections, of observations corresponding to
      `observation_index_points`. Shape has the form `[b1, ..., bB, e]`, which
      must be broadcastable with the batch and example shapes of
      `observation_index_points`. The batch shape `[b1, ..., bB]` must be
      broadcastable with the shapes of all other batched parameters
      (`kernel.batch_shape`, `observation_index_points`, etc.).
    observation_index_points: `float` `Tensor` representing finite (batch of)
      vector(s) of points where observations are defined. Shape has the form
      `[b1, ..., bB, e1, f1, ..., fF]` where `F` is the number of feature
      dimensions and must equal `kernel.feature_ndims` and `e1` is the number
      (size) of index points in each batch (we denote it `e1` to distinguish
      it from the number of inducing index points, `e2`). If set to `None`
      uses `index_points` as the origin for observations.
      Default value: None.
    kl_weight: Amount by which to scale the KL divergence loss between prior
      and posterior.
      Default value: 1.
    name: Python `str` name prefixed to Ops created by this method.
      Default value: 'variational_loss'. If explicitly `None`, the scope name
      'variational_gp_loss' is used instead.

  Returns:
    loss: Scalar tensor representing the negative variational lower bound.
      Can be directly used in a `tf.Optimizer`.

  #### References

  [1]: Hensman, J., Lawrence, N. "Gaussian Processes for Big Data", 2013
       https://arxiv.org/abs/1309.6835
  """
  with tf.compat.v1.name_scope(
      name, 'variational_gp_loss',
      values=[observations, observation_index_points, kl_weight]):
    if observation_index_points is None:
      observation_index_points = self._index_points
    observation_index_points = tf.convert_to_tensor(
        value=observation_index_points, dtype=self._dtype,
        name='observation_index_points')
    observations = tf.convert_to_tensor(
        value=observations, dtype=self._dtype, name='observations')
    kl_weight = tf.convert_to_tensor(
        value=kl_weight, dtype=self._dtype, name='kl_weight')

    # The variational loss is a negative ELBO. The ELBO can be broken down
    # into three terms:
    #  1. a likelihood term,
    #  2. a trace term arising from the covariance of the posterior
    #     predictive,
    #  3. a (kl_weight-scaled) KL term between the variational posterior over
    #     the inducing observations and the prior over them.

    kzx = self.kernel.matrix(self._inducing_index_points,
                             observation_index_points)

    # 1. Likelihood term.
    kzx_linop = tf.linalg.LinearOperatorFullMatrix(kzx)
    loc = (self._mean_fn(observation_index_points) +
           kzx_linop.matvec(self._kzz_inv_varloc, adjoint=True))

    likelihood = independent.Independent(
        normal.Normal(
            loc=loc,
            scale=tf.sqrt(self._observation_noise_variance + self._jitter),
            name='NormalLikelihood'),
        reinterpreted_batch_ndims=1)
    obs_ll = likelihood.log_prob(observations)

    # 2. Trace terms. `self._chol_kzz` is used as the lower-triangular factor
    # L with L L.T = Kzz.
    chol_kzz_linop = tf.linalg.LinearOperatorLowerTriangular(self._chol_kzz)
    chol_kzz_inv_kzx = chol_kzz_linop.solve(kzx)
    kzz_inv_kzx = chol_kzz_linop.solve(chol_kzz_inv_kzx, adjoint=True)

    # Tr(Kxx - Kxz Kzz^-1 Kzx) = sum(diag(Kxx)) - ||L^-1 Kzx||_F^2.
    kxx_diag = tf.linalg.diag_part(
        self.kernel.matrix(
            observation_index_points, observation_index_points))
    ktilde_trace_term = (
        tf.reduce_sum(input_tensor=kxx_diag, axis=-1) -
        tf.reduce_sum(input_tensor=chol_kzz_inv_kzx ** 2, axis=[-2, -1]))

    # Tr(SB)
    # where S = A A.T, A = variational_inducing_observations_scale
    # and B = Kzz^-1 Kzx Kzx.T Kzz^-1
    #
    # Now Tr(SB) = Tr(A A.T Kzz^-1 Kzx Kzx.T Kzz^-1)
    #            = Tr(A.T Kzz^-1 Kzx Kzx.T Kzz^-1 A)
    #            = sum_ij (A.T Kzz^-1 Kzx)_{ij}^2
    #
    # The adjoint is required: A @ (Kzz^-1 Kzx) would instead yield
    # Tr(A.T A B), which differs from Tr(A A.T B) for a general
    # (e.g. lower-triangular) scale A.
    other_trace_term = tf.reduce_sum(
        input_tensor=(
            self._variational_inducing_observations_posterior.scale.matmul(
                kzz_inv_kzx, adjoint=True) ** 2),
        axis=[-2, -1])

    trace_term = (.5 * (ktilde_trace_term + other_trace_term) /
                  self._observation_noise_variance)

    # 3. KL term. Delegate to the public helper so this loss and
    # `surrogate_posterior_kl_divergence_prior` always use the same prior.
    kl_term = kl_weight * self.surrogate_posterior_kl_divergence_prior()

    lower_bound = obs_ll - trace_term - kl_term

    return -tf.reduce_mean(input_tensor=lower_bound)