def _get_flattened_marginal_distribution(self, index_points=None):
    """Return the joint marginal over all tasks as one flattened MVN.

    The event is flattened to size `N * E`, where N is the number of tasks
    and E is the number of index points.

    Args:
      index_points: Optional index points; resolved through
        `self._get_index_points`, which supplies a fallback when `None`.

    Returns:
      A `MultivariateNormalLinearOperator` named 'marginal_distribution'. Its
      `loc` is the mean-function output broadcast to
      `batch_shape + event_shape` and flattened with `_vec`. Its `scale` is the
      Cholesky factor of the flattened covariance.
    """
    with self._name_and_control_scope(
            'get_flattened_marginal_distribution'):
        points = self._get_index_points(index_points)
        flat_cov = self._compute_flattened_covariance(points)
        full_shape = ps.concat(
            [self._batch_shape_tensor(index_points=points),
             self._event_shape_tensor(index_points=points)],
            axis=0)

        # Factor the covariance. `cholesky_from_fn` is relied on to
        # specialize for structured (block-diagonal / Kronecker) operators.
        chol = cholesky_util.cholesky_from_fn(flat_cov, self._cholesky_fn)

        # Broadcast first so constant mean functions (shared across tasks,
        # or constant per task) expand to the full shape before flattening.
        raw_mean = self._mean_fn(points)
        flat_mean = _vec(ps.broadcast_to(raw_mean, full_shape))

        return mvn_linear_operator.MultivariateNormalLinearOperator(
            loc=flat_mean,
            scale=chol,
            validate_args=self._validate_args,
            allow_nan_stats=self._allow_nan_stats,
            name='marginal_distribution')
def _get_flattened_marginal_distribution(self, index_points=None):
    """Return the joint marginal over all tasks as one flattened MVN.

    The event is flattened to size `N * E`, where N is the number of tasks
    and E is the number of index points. The scale operator comes straight
    from `_compute_flattened_scale` rather than from an explicit covariance.

    Args:
      index_points: Optional index points; resolved through
        `self._get_index_points`.

    Returns:
      A `MultivariateNormalLinearOperator` named 'marginal_distribution'.
    """
    with self._name_and_control_scope(
            'get_flattened_marginal_distribution'):
        points = self._get_index_points(index_points)
        scale_op = _compute_flattened_scale(
            kernel=self.kernel,
            index_points=points,
            cholesky_fn=self._cholesky_fn,
            observation_noise_variance=self.observation_noise_variance)
        full_shape = ps.concat(
            [self._batch_shape_tensor(index_points=points),
             self._event_shape_tensor(index_points=points)],
            axis=0)

        # Broadcast before flattening so constant mean functions (shared
        # across tasks, or constant per task) are supported.
        raw_mean = self._mean_fn(points)
        flat_mean = _vec(ps.broadcast_to(raw_mean, full_shape))

        return mvn_linear_operator.MultivariateNormalLinearOperator(
            loc=flat_mean,
            scale=scale_op,
            validate_args=self._validate_args,
            allow_nan_stats=self._allow_nan_stats,
            name='marginal_distribution')
def _as_multivariate_normal(self, loc=None):
    """Return an equivalent `MultivariateNormalLinearOperator`.

    The distribution is rebuilt on every call instead of being cached,
    because the shapes of the underlying tensors may change between calls.

    Args:
      loc: Optional mean to use in place of `self.loc`.

    Returns:
      An MVN whose `loc` is the mean flattened with `_vec`. Its `scale` is the
      Kronecker product of `self.scale_row` and `self.scale_column`.
    """
    if loc is None:
        loc = self.loc
    mean = tf.convert_to_tensor(loc)
    flat_mean = _vec(mean)
    kron_scale = tf.linalg.LinearOperatorKronecker(
        [self.scale_row, self.scale_column])
    return mvn_linear_operator.MultivariateNormalLinearOperator(
        loc=flat_mean,
        scale=kron_scale,
        validate_args=self.validate_args)
def marginal_fn(loc,
                covariance,
                validate_args=False,
                allow_nan_stats=False,
                name=name):
    """Build an MVN marginal whose scale factor is `cholesky_like(covariance)`.

    Args:
      loc: Mean of the returned distribution.
      covariance: Covariance matrix passed to `cholesky_like` (a callable from
        the enclosing scope) to obtain a lower-triangular factor.
      validate_args: Forwarded to `MultivariateNormalLinearOperator`.
      allow_nan_stats: Forwarded to `MultivariateNormalLinearOperator`.
      name: Name scope for the created ops. It is also used as the name of
        the returned distribution. The default comes from the enclosing scope.

    Returns:
      A `MultivariateNormalLinearOperator` with a
      `LinearOperatorLowerTriangular` scale.
    """
    with tf.name_scope(name) as name:
        scale = tf.linalg.LinearOperatorLowerTriangular(
            cholesky_like(covariance),
            is_non_singular=True)
        # Forward `name`; otherwise the caller's requested name is silently
        # dropped and the distribution falls back to its class default.
        return mvn_linear_operator.MultivariateNormalLinearOperator(
            loc=loc,
            scale=scale,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
def marginal_fn(loc,
                covariance,
                validate_args=False,
                allow_nan_stats=False,
                name='marginal_distribution'):
    """Build an MVN marginal from a Cholesky factor of the jittered covariance.

    `jitter` comes from the enclosing scope. It is applied with
    `_add_diagonal_shift` (presumably added to the diagonal) before the
    Cholesky factorization.

    Args:
      loc: Mean of the returned distribution.
      covariance: Covariance matrix to factor.
      validate_args: Forwarded to `MultivariateNormalLinearOperator`.
      allow_nan_stats: Forwarded to `MultivariateNormalLinearOperator`.
      name: Name of the returned distribution.

    Returns:
      A `MultivariateNormalLinearOperator` with a lower-triangular scale.
    """
    shifted_cov = _add_diagonal_shift(covariance, jitter)
    lower_factor = tf.linalg.cholesky(shifted_cov)
    tril_op = tf.linalg.LinearOperatorLowerTriangular(
        lower_factor,
        is_non_singular=True,
        name='GaussianProcessScaleLinearOperator')
    return mvn_linear_operator.MultivariateNormalLinearOperator(
        loc=loc,
        scale=tril_op,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
def get_marginal_distribution(self, index_points=None):
    """Compute the marginal of this GP over function values at `index_points`.

    Args:
      index_points: `float` `Tensor` representing a finite (batch of)
        vector(s) of points in the index set over which the GP is defined.
        Shape has the form `[b1, ..., bB, e, f1, ..., fF]`, where `F` is the
        number of feature dimensions and must equal `kernel.feature_ndims`,
        and `e` is the number (size) of index points in each batch. The result
        corresponds to an `e`-dimensional multivariate normal. The batch shape
        must be broadcastable with `kernel.batch_shape` and with any batch
        dims yielded by `mean_fn`.

    Returns:
      marginal: a scalar `Normal` when `index_points` holds a single point.
        Otherwise a `MultivariateNormalLinearOperator`.
    """
    with self._name_and_control_scope('get_marginal_distribution'):
        # TODO(cgs): consider caching the result here, keyed on `index_points`.
        points = self._get_index_points(index_points)
        cov = self._compute_covariance(points)
        mean = self._mean_fn(points)

        if not self._is_univariate_marginal(points):
            # General case: factor the jittered covariance matrix.
            tril = tf.linalg.cholesky(_add_diagonal_shift(cov, self.jitter))
            scale_op = tf.linalg.LinearOperatorLowerTriangular(
                tril,
                is_non_singular=True,
                name='GaussianProcessScaleLinearOperator')
            return mvn_linear_operator.MultivariateNormalLinearOperator(
                loc=mean,
                scale=scale_op,
                validate_args=self._validate_args,
                allow_nan_stats=self._allow_nan_stats,
                name='marginal_distribution')

        # Exactly one index point: a scalar Normal is cheaper and supports
        # methods like CDF that the MVN does not readily provide.
        stddev = tf.sqrt(cov)
        # `mean` has a trailing size-1 dimension here; drop it.
        scalar_mean = tf.squeeze(mean, axis=-1)
        return normal.Normal(
            loc=scalar_mean,
            scale=stddev,
            validate_args=self._validate_args,
            allow_nan_stats=self._allow_nan_stats,
            name='marginal_distribution')
def _sample_n(self, n, seed=None):
    """Draw `n` samples from the multivariate Student's t distribution.

    This mirrors the univariate Student's t construction. A zero-mean
    multivariate normal draw with this distribution's scale is divided by
    `sqrt(chi2 / df)`, where `chi2` is a chi-squared draw, and then shifted
    by the location.
    """
    stream = seed_stream.SeedStream(seed, salt="multivariate t")

    full_loc = tf.broadcast_to(self.loc, self._sample_shape())
    zero_mean_mvn = mvn_linear_operator.MultivariateNormalLinearOperator(
        loc=tf.zeros_like(full_loc), scale=self.scale)
    # Seeds are consumed in a fixed order: Gaussian first, then chi-squared.
    gaussian_draws = zero_mean_mvn.sample(n, seed=stream())

    batch_df = tf.broadcast_to(self.df, self.batch_shape_tensor())
    chi2_draws = chi2_lib.Chi2(df=batch_df).sample(n, seed=stream())

    # One mixing factor per sample/batch member, broadcast over the event.
    mixing = tf.math.rsqrt(chi2_draws / self._df)
    return self._loc + gaussian_draws * mixing[..., tf.newaxis]
def _get_flattened_marginal_distribution(self, index_points=None):
    """Return the joint marginal over all tasks as one flattened MVN.

    The event is flattened to size `N * E`, where N is the number of tasks
    and E is the number of index points.

    Args:
      index_points: Optional index points; resolved through
        `self._get_index_points`.

    Returns:
      A `MultivariateNormalLinearOperator` named 'marginal_distribution'.
    """
    with self._name_and_control_scope('get_flattened_marginal_distribution'):
        points = self._get_index_points(index_points)
        flat_cov = self._compute_flattened_covariance(points)
        # NOTE(review): the conditional mean is used as-is, with no broadcast
        # or `_vec` here; presumably it is already flattened to match the
        # scale. TODO confirm.
        cond_mean = self._conditional_mean_fn(points)
        tril_op = tf.linalg.LinearOperatorLowerTriangular(
            self._cholesky_fn(flat_cov),
            is_non_singular=True,
            name='GaussianProcessScaleLinearOperator')
        return mvn_linear_operator.MultivariateNormalLinearOperator(
            loc=cond_mean,
            scale=tril_op,
            validate_args=self._validate_args,
            allow_nan_stats=self._allow_nan_stats,
            name='marginal_distribution')
def eigh_marginal_fn(loc,
                     covariance,
                     validate_args=False,
                     allow_nan_stats=False,
                     name=name):
    """Compute EigH-based square root and return a MVN.

    The covariance is eigendecomposed. Eigenvalues below `tol` (from the
    enclosing scope) are replaced by `tol` before taking square roots. The
    scale is then the eigenvector matrix with each column multiplied by the
    corresponding root.

    Args:
      loc: Mean of the returned distribution.
      covariance: Covariance matrix to decompose with `tf.linalg.eigh`.
      validate_args: Forwarded to `MultivariateNormalLinearOperator`.
      allow_nan_stats: Forwarded to `MultivariateNormalLinearOperator`.
      name: Name scope for the ops and name of the returned distribution.

    Returns:
      A `MultivariateNormalLinearOperator` with a full-matrix scale.
    """
    with tf.name_scope(name) as name:
        eigvals, eigvecs = tf.linalg.eigh(covariance)
        clipped = tf.where(eigvals < tol, tol, eigvals)
        roots = tf.math.sqrt(clipped)
        # Column j of `eigvecs` is scaled by `roots[..., j]`.
        scaled_vecs = tf.einsum('...ij,...j->...ij', eigvecs, roots)
        full_op = tf.linalg.LinearOperatorFullMatrix(
            scaled_vecs,
            is_square=True,
            is_positive_definite=True,
            is_non_singular=True,
            name='GaussianProcessEigHScaleLinearOperator')
        return mvn_linear_operator.MultivariateNormalLinearOperator(
            loc=loc,
            scale=full_op,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
def __init__(self,
             kernel,
             index_points,
             inducing_index_points,
             variational_inducing_observations_loc,
             variational_inducing_observations_scale,
             mean_fn=None,
             observation_noise_variance=0.,
             predictive_noise_variance=0.,
             jitter=1e-6,
             validate_args=False,
             allow_nan_stats=False,
             name='VariationalGaussianProcess'):
    """Instantiate a VariationalGaussianProcess Distribution.

    Args:
      kernel: `PositiveSemidefiniteKernel`-like instance representing the
        GP's covariance function.
      index_points: `float` `Tensor` representing a finite (batch of)
        vector(s) of points in the index set over which the VGP is defined.
        Shape has the form `[b1, ..., bB, e1, f1, ..., fF]`, where `F` is the
        number of feature dimensions and must equal `kernel.feature_ndims`,
        and `e1` is the number (size) of index points in each batch. It is
        called `e1` to distinguish it from the number of inducing index
        points, denoted `e2` below. The VariationalGaussianProcess
        distribution corresponds to an `e1`-dimensional multivariate normal.
        The batch shape must be broadcastable with `kernel.batch_shape`, the
        batch shape of `inducing_index_points`, and any batch dims yielded by
        `mean_fn`.
      inducing_index_points: `float` `Tensor` of locations of inducing points
        in the index set. Shape has the form `[b1, ..., bB, e2, f1, ..., fF]`,
        just like `index_points`. The batch shape components needn't be
        identical to those of `index_points`, but must be broadcast
        compatible with them.
      variational_inducing_observations_loc: `float` `Tensor`; the mean of the
        (full-rank Gaussian) variational posterior over function values at
        the inducing points, conditional on observed data. Shape has the form
        `[b1, ..., bB, e2]`, where `b1, ..., bB` is broadcast compatible with
        other parameters' batch shapes, and `e2` is the number of inducing
        points.
      variational_inducing_observations_scale: `float` `Tensor`; the scale
        matrix of the (full-rank Gaussian) variational posterior over
        function values at the inducing points, conditional on observed data.
        Shape has the form `[b1, ..., bB, e2, e2]`, where `b1, ..., bB` is
        broadcast compatible with other parameters and `e2` is the number of
        inducing points.
      mean_fn: Python `callable` that acts on index points to produce a
        (batch of) vector(s) of mean values at those index points. Takes a
        `Tensor` of shape `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor`
        whose shape is (broadcastable with) `[b1, ..., bB]`.
        Default value: `None` implies constant zero function.
      observation_noise_variance: `float` `Tensor` representing the variance
        of the noise in the Normal likelihood distribution of the model. May
        be batched, in which case the batch shape must be broadcastable with
        the shapes of all other batched parameters (`kernel.batch_shape`,
        `index_points`, etc.).
        Default value: `0.`
      predictive_noise_variance: `float` `Tensor` representing additional
        variance in the posterior predictive model. If `None`,
        `observation_noise_variance` is reused for the posterior predictive
        noise. If set explicitly, the given value is used instead. This
        allows, for example, omitting predictive noise variance (by setting
        this to zero) to obtain noiseless posterior predictions of function
        values, conditioned on noisy observations.
        Default value: `0.`
      jitter: `float` scalar `Tensor` added to the diagonal of the covariance
        matrix to ensure positive definiteness of the covariance matrix.
        Default value: `1e-6`.
      validate_args: Python `bool`. When `True`, distribution parameters are
        checked for validity despite possibly degrading runtime performance.
        When `False`, invalid inputs may silently render incorrect outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
        mode, variance) use the value "`NaN`" to indicate the result is
        undefined. When `False`, an exception is raised if one or more of the
        statistic's batch members are undefined.
        Default value: `False`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "VariationalGaussianProcess".

    Raises:
      ValueError: if `mean_fn` is not `None` and is not callable.
    """
    # Must run first: captures exactly the constructor arguments.
    parameters = dict(locals())
    with tf.name_scope(name or 'VariationalGaussianProcess') as name:
        dtype = dtype_util.common_dtype([
            kernel, index_points, inducing_index_points,
            variational_inducing_observations_loc,
            variational_inducing_observations_scale,
            observation_noise_variance, predictive_noise_variance, jitter
        ], tf.float32)

        index_points = tf.convert_to_tensor(
            index_points, dtype=dtype, name='index_points')
        inducing_index_points = tf.convert_to_tensor(
            inducing_index_points, dtype=dtype, name='inducing_index_points')
        variational_inducing_observations_loc = tf.convert_to_tensor(
            variational_inducing_observations_loc, dtype=dtype,
            name='variational_inducing_observations_loc')
        variational_inducing_observations_scale = tf.convert_to_tensor(
            variational_inducing_observations_scale, dtype=dtype,
            name='variational_inducing_observations_scale')
        observation_noise_variance = tf.convert_to_tensor(
            observation_noise_variance, dtype=dtype,
            name='observation_noise_variance')
        # `None` means: reuse the observation noise for predictions.
        if predictive_noise_variance is None:
            predictive_noise_variance = observation_noise_variance
        else:
            predictive_noise_variance = tf.convert_to_tensor(
                predictive_noise_variance, dtype=dtype,
                name='predictive_noise_variance')
        jitter = tf.convert_to_tensor(jitter, dtype=dtype, name='jitter')

        self._kernel = kernel
        self._index_points = index_points
        self._inducing_index_points = inducing_index_points
        self._variational_inducing_observations_posterior = (
            mvn_linear_operator.MultivariateNormalLinearOperator(
                loc=variational_inducing_observations_loc,
                scale=tf.linalg.LinearOperatorFullMatrix(
                    variational_inducing_observations_scale),
                name='variational_inducing_observations_posterior'))

        # Default to a constant zero mean function, using the common `dtype`
        # computed above so all tensors stay consistent.
        if mean_fn is None:
            mean_fn = lambda x: tf.zeros([1], dtype=dtype)
        elif not callable(mean_fn):
            raise ValueError('`mean_fn` must be a Python callable')
        self._mean_fn = mean_fn
        self._observation_noise_variance = observation_noise_variance
        self._predictive_noise_variance = predictive_noise_variance
        self._jitter = jitter

        with tf.name_scope('init'):
            # We let t and z denote predictive and inducing index points, resp.
            kzz = _add_diagonal_shift(
                kernel.matrix(inducing_index_points, inducing_index_points),
                jitter)
            self._chol_kzz = tf.linalg.cholesky(kzz)
            self._kzz_inv_varloc = _solve_cholesky_factored_system_vec(
                self._chol_kzz,
                (variational_inducing_observations_loc -
                 mean_fn(inducing_index_points)),
                name='kzz_inv_varloc')

            loc, scale = self._compute_posterior_predictive_params()

        super(VariationalGaussianProcess, self).__init__(
            loc=loc,
            scale=scale,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
        self._parameters = parameters
        self._graph_parents = [
            index_points,
            inducing_index_points,
            variational_inducing_observations_loc,
            variational_inducing_observations_scale,
            observation_noise_variance,
            jitter,
        ]