def testJitterFn(self):
  """Checks the jittered Cholesky factory with zero and with unit jitter."""
  # With zero jitter the factor is asserted to match `tf.linalg.cholesky`.
  zero_jitter_chol = cholesky_util.make_cholesky_with_jitter_fn(jitter=0.)
  factor = tf.random.normal(shape=[2, 4, 4], seed=test_util.test_seed())
  gram = tf.linalg.matmul(factor, factor, transpose_b=True)
  jittered = zero_jitter_chol(gram)
  reference = tf.linalg.cholesky(gram)
  actual_chol, expected_chol = self.evaluate([jittered, reference])
  self.assertAllClose(expected_chol, actual_chol)

  # With unit jitter, `3 * I` is expected to factor as `2 * I`.
  unit_jitter_chol = cholesky_util.make_cholesky_with_jitter_fn(jitter=1.)
  scaled_identity = 3. * tf.linalg.eye(3)
  expected = self.evaluate(2. * tf.linalg.eye(3))
  actual = self.evaluate(unit_jitter_chol(scaled_identity))
  self.assertAllClose(expected, actual)
def testSchurComplementCholeskyFn(self):
  """Checks that `SchurComplement` stores and uses a custom `cholesky_fn`."""
  kernel = tfpk.ExponentiatedQuadratic([1., 2.])
  # Zero fixed inputs (leading dimension of size 0).
  empty_inputs = tf.ones([0, 2], np.float32)
  jitter_chol = cholesky_util.make_cholesky_with_jitter_fn(jitter=1e-5)

  custom_schur = tfpk.SchurComplement(
      kernel, empty_inputs, cholesky_fn=jitter_chol)
  default_schur = tfpk.SchurComplement(kernel, empty_inputs)

  # The kernel must expose exactly the callable it was given.
  self.assertEqual(jitter_chol, custom_schur.cholesky_fn)

  lhs = np.ones([4, 3], np.float32)
  rhs = np.ones([5, 3], np.float32)
  expected = self.evaluate(default_schur.matrix(lhs, rhs))
  actual = self.evaluate(custom_schur.matrix(lhs, rhs))
  self.assertAllClose(expected, actual)
def precompute_regression_model(
    kernel,
    observation_index_points,
    observations,
    observations_is_missing=None,
    index_points=None,
    observation_noise_variance=None,
    predictive_noise_variance=None,
    mean_fn=None,
    cholesky_fn=None,
    validate_args=False,
    allow_nan_stats=False,
    name='PrecomputedMultiTaskGaussianProcessRegressionModel'):
  """Returns a MTGaussianProcessRegressionModel with precomputed quantities.

  This differs from the constructor by precomputing quantities associated with
  observations in a non-tape safe way. `index_points` is the only parameter
  that is allowed to vary (i.e. is a `Variable` / changes after
  initialization).

  Specifically:

  * We make `observation_index_points` and `observations` mandatory
    parameters.
  * We precompute `kernel(observation_index_points, observation_index_points)`
    along with any other associated quantities relating to the `kernel`,
    `observations` and `observation_index_points`.

  A typical usecase would be optimizing kernel hyperparameters for a
  `MultiTaskGaussianProcess`, and computing the posterior predictive with
  respect to those optimized hyperparameters and observation / index-points
  pairs.

  WARNING: This method assumes `index_points` is the only varying parameter
  (i.e. is a `Variable` / changes after initialization) and hence is not
  tape-safe.

  Args:
    kernel: `PositiveSemidefiniteKernel`-like instance representing the
      GP's covariance function.
    observation_index_points: `float` `Tensor` representing finite collection,
      or batch of collections, of points in the index set for which some data
      has been observed. Shape has the form `[b1, ..., bB, e, f1, ..., fF]`
      where `F` is the number of feature dimensions and must equal
      `kernel.feature_ndims`, and `e` is the number (size) of index points in
      each batch. `[b1, ..., bB, e]` must be broadcastable with the shape of
      `observations`, and `[b1, ..., bB]` must be broadcastable with the
      shapes of all other batched parameters (`kernel.batch_shape`,
      `index_points`, etc). Required; it is converted to a `Tensor` when this
      function is called.
    observations: `float` `Tensor` representing collection, or batch of
      collections, of observations corresponding to
      `observation_index_points`. Shape has the form `[b1, ..., bB, e, t]`.
      The batch shape `[b1, ..., bB]` must be broadcastable with the shapes of
      all other batched parameters (`kernel.batch_shape`, `index_points`,
      etc.). Required; it is converted to a `Tensor` when this function is
      called.
    observations_is_missing: `bool` `Tensor` of shape `[..., e, t]` (the shape
      documented by the model constructor), representing a batch of boolean
      masks. When `observations_is_missing` is not `None`, the returned
      distribution is conditioned only on the observations for which the
      corresponding elements of `observations_is_missing` are `False`.
    index_points: `float` `Tensor` representing finite collection, or batch of
      collections, of points in the index set over which the GP is defined.
      Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the
      number of feature dimensions and must equal `kernel.feature_ndims` and
      `e` is the number (size) of index points in each batch. Ultimately this
      distribution corresponds to an `e`-dimensional multivariate normal. The
      batch shape must be broadcastable with `kernel.batch_shape` and any
      batch dims yielded by `mean_fn`.
    observation_noise_variance: `float` `Tensor` representing the variance
      of the noise in the Normal likelihood distribution of the model. May be
      batched, in which case the batch shape must be broadcastable with the
      shapes of all other batched parameters (`kernel.batch_shape`,
      `index_points`, etc.).
      Default value: `None`
    predictive_noise_variance: `float` `Tensor` representing the variance in
      the posterior predictive model. If `None`, we simply re-use
      `observation_noise_variance` for the posterior predictive noise. If set
      explicitly, however, we use this value. This allows us, for example, to
      omit predictive noise variance (by setting this to zero) to obtain
      noiseless posterior predictions of function values, conditioned on noisy
      observations.
    mean_fn: Python `callable` that acts on `index_points` to produce a
      collection, or batch of collections, of mean values at `index_points`.
      Takes a `Tensor` of shape `[b1, ..., bB, f1, ..., fF]` and returns a
      `Tensor` whose shape is broadcastable with `[b1, ..., bB, t]`.
      Default value: `None` implies the constant zero function.
    cholesky_fn: Callable which takes a single (batch) matrix argument and
      returns a Cholesky-like lower triangular factor.
      Default value: `None`, in which case
      `cholesky_util.make_cholesky_with_jitter_fn()` is called with its
      default arguments.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `False`. When `True`,
      statistics (e.g., mean, mode, variance) use the value `NaN` to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value:
      'PrecomputedMultiTaskGaussianProcessRegressionModel'.

  Returns:
    An instance of `MultiTaskGaussianProcessRegressionModel` with precomputed
    quantities associated with observations.

  Raises:
    ValueError: if `cholesky_fn` or `mean_fn` is given but is not callable.
  """
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([
        index_points,
        observation_index_points,
        observations,
        observation_noise_variance,
        predictive_noise_variance,
    ], tf.float32)

    # Convert-to-tensor arguments that are expected to not be Variables / not
    # going to change.
    observation_index_points = tf.convert_to_tensor(
        observation_index_points, dtype=dtype)
    if observation_noise_variance is not None:
      observation_noise_variance = tf.convert_to_tensor(
          observation_noise_variance, dtype=dtype)
    observations = tf.convert_to_tensor(observations, dtype=dtype)

    if observations_is_missing is not None:
      observations_is_missing = tf.convert_to_tensor(observations_is_missing)

    if cholesky_fn is None:
      cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()
    elif not callable(cholesky_fn):
      raise ValueError('`cholesky_fn` must be a Python callable')

    # Keep the caller's value so the returned model reports it (and applies
    # its own default for `None`) instead of the placeholder used below.
    user_mean_fn = mean_fn
    if mean_fn is None:
      mean_fn = lambda x: tf.zeros([1], dtype=dtype)
    elif not callable(mean_fn):
      raise ValueError('`mean_fn` must be a Python callable')

    if observations_is_missing is not None:
      # If observations are missing, there's nothing we can do to preserve the
      # operator structure, so densify.
      observation_covariance = kernel.matrix_over_all_tasks(
          observation_index_points, observation_index_points).to_dense()

      if observation_noise_variance is not None:
        broadcast_shape = distribution_util.get_broadcast_shape(
            observation_covariance,
            observation_noise_variance[..., tf.newaxis, tf.newaxis])
        observation_covariance = tf.broadcast_to(
            observation_covariance, broadcast_shape)
        observation_covariance = _add_diagonal_shift(
            observation_covariance, observation_noise_variance)

      vec_observations_is_missing = _vec(observations_is_missing)
      observation_covariance = tf.linalg.LinearOperatorFullMatrix(
          psd_kernels_util.mask_matrix(
              observation_covariance,
              is_missing=vec_observations_is_missing),
          is_non_singular=True,
          is_positive_definite=True)
      observation_scale = cholesky_util.cholesky_from_fn(
          observation_covariance, cholesky_fn)
    else:
      observation_scale = mtgp._compute_flattened_scale(  # pylint:disable=protected-access
          kernel=kernel,
          index_points=observation_index_points,
          cholesky_fn=cholesky_fn,
          observation_noise_variance=observation_noise_variance)

    # Note that the conditional mean is
    # k(x, o) @ (k(o, o) + sigma**2)^-1 obs. We can precompute the latter
    # term since it won't change per iteration.
    vec_diff = _vec(observations - mean_fn(observation_index_points))

    if observations_is_missing is not None:
      # Entries flagged as missing are zeroed, i.e. conditioning only uses
      # the entries whose flag is `False`.
      vec_diff = tf.where(vec_observations_is_missing,
                          tf.zeros([], dtype=vec_diff.dtype),
                          vec_diff)
    solve_on_observations = observation_scale.solvevec(
        observation_scale.solvevec(vec_diff), adjoint=True)

    def flattened_conditional_mean_fn(x):
      return _flattened_conditional_mean_fn_helper(
          x,
          kernel,
          observations,
          observation_index_points,
          observations_is_missing,
          observation_scale,
          mean_fn,
          solve_on_observations=solve_on_observations)

    # Forward the mask and mean function so the returned model records what
    # it was actually conditioned on.
    mtgprm = MultiTaskGaussianProcessRegressionModel(
        kernel=kernel,
        observation_index_points=observation_index_points,
        observations=observations,
        observations_is_missing=observations_is_missing,
        index_points=index_points,
        mean_fn=user_mean_fn,
        observation_noise_variance=observation_noise_variance,
        predictive_noise_variance=predictive_noise_variance,
        cholesky_fn=cholesky_fn,
        _flattened_conditional_mean_fn=flattened_conditional_mean_fn,
        _observation_scale=observation_scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)

  return mtgprm
def __init__(self,
             kernel,
             observation_index_points,
             observations,
             observations_is_missing=None,
             index_points=None,
             mean_fn=None,
             observation_noise_variance=None,
             predictive_noise_variance=None,
             cholesky_fn=None,
             validate_args=False,
             allow_nan_stats=False,
             name='MultiTaskGaussianProcessRegressionModelWithCholesky',
             _flattened_conditional_mean_fn=None,
             _observation_scale=None):
  """Construct a MultiTaskGaussianProcessRegressionModel instance.

  Args:
    kernel: `MultiTaskKernel`-like instance representing the GP's covariance
      function.
    observation_index_points: `float` `Tensor` representing finite collection,
      or batch of collections, of points in the index set for which some data
      has been observed. Shape has the form `[b1, ..., bB, e, f1, ..., fF]`
      where `F` is the number of feature dimensions and must equal
      `kernel.feature_ndims`, and `e` is the number (size) of index points in
      each batch. `[b1, ..., bB, e]` must be broadcastable with the shape of
      `observations`, and `[b1, ..., bB]` must be broadcastable with the
      shapes of all other batched parameters (`kernel.batch_shape`,
      `index_points`, etc).
    observations: `float` `Tensor` representing collection, or batch of
      collections, of observations corresponding to
      `observation_index_points`. Shape has the form `[b1, ..., bB, e, t]`,
      which must be broadcastable with the batch and example shapes of
      `observation_index_points`. The batch shape `[b1, ..., bB]` must be
      broadcastable with the shapes of all other batched parameters
      (`kernel.batch_shape`, `index_points`, etc.).
    observations_is_missing: `bool` `Tensor` of shape `[..., e, t]`,
      representing a batch of boolean masks. When `observations_is_missing`
      is not `None`, this distribution is conditioned only on the
      observations for which the corresponding elements of
      `observations_is_missing` are `False`.
    index_points: `float` `Tensor` representing finite collection, or batch of
      collections, of points in the index set over which the GP is defined.
      Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the
      number of feature dimensions and must equal `kernel.feature_ndims` and
      `e` is the number (size) of index points in each batch. Ultimately this
      distribution corresponds to an `e`-dimensional multivariate normal. The
      batch shape must be broadcastable with `kernel.batch_shape`.
    mean_fn: Python `callable` that acts on `index_points` to produce a
      (batch of) collection of mean values at `index_points`. Takes a `Tensor`
      of shape `[b1, ..., bB, e, f1, ..., fF]` and returns a `Tensor` whose
      shape is broadcastable with `[b1, ..., bB, e, t]`, where `t` is the
      number of tasks.
      Default value: `None`, in which case a function returning zeros of
      shape `[b1, ..., bB, e, t]` is used.
    observation_noise_variance: `float` `Tensor` representing the variance
      of the noise in the Normal likelihood distribution of the model. May be
      batched, in which case the batch shape must be broadcastable with the
      shapes of all other batched parameters (`kernel.batch_shape`,
      `index_points`, etc.).
      Default value: `None`
    predictive_noise_variance: `float` `Tensor` representing the variance in
      the posterior predictive model. If `None`, we simply re-use
      `observation_noise_variance` for the posterior predictive noise. If set
      explicitly, however, we use this value. This allows us, for example, to
      omit predictive noise variance (by setting this to zero) to obtain
      noiseless posterior predictions of function values, conditioned on noisy
      observations.
    cholesky_fn: Callable which takes a single (batch) matrix argument and
      returns a Cholesky-like lower triangular factor.
      Default value: `None`, in which case
      `cholesky_util.make_cholesky_with_jitter_fn()` is called with its
      default arguments.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `False`. When `True`,
      statistics (e.g., mean, mode, variance) use the value `NaN` to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: 'MultiTaskGaussianProcessRegressionModelWithCholesky'.
    _flattened_conditional_mean_fn: Internal parameter -- do not use.
    _observation_scale: Internal parameter -- do not use.

  Raises:
    ValueError: if `kernel` is not a `MultiTaskKernel`, or if `cholesky_fn`
      or `mean_fn` is given but is not callable.
  """
  # Must stay the first statement so that only constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:

    if not isinstance(kernel, multitask_kernel.MultiTaskKernel):
      raise ValueError('`kernel` must be a `MultiTaskKernel`.')

    dtype = dtype_util.common_dtype([
        index_points, observation_index_points, observations,
        observation_noise_variance, predictive_noise_variance
    ], tf.float32)
    index_points = tensor_util.convert_nonref_to_tensor(
        index_points, dtype=dtype, name='index_points')
    observation_index_points = tensor_util.convert_nonref_to_tensor(
        observation_index_points,
        dtype=dtype,
        name='observation_index_points')
    observations = tensor_util.convert_nonref_to_tensor(
        observations, dtype=dtype, name='observations')
    if observations_is_missing is not None:
      observations_is_missing = tensor_util.convert_nonref_to_tensor(
          observations_is_missing, dtype=tf.bool)
    if observation_noise_variance is not None:
      observation_noise_variance = tensor_util.convert_nonref_to_tensor(
          observation_noise_variance,
          dtype=dtype,
          name='observation_noise_variance')
    predictive_noise_variance = tensor_util.convert_nonref_to_tensor(
        predictive_noise_variance,
        dtype=dtype,
        name='predictive_noise_variance')
    # Fall back to the observation noise when no predictive noise is given.
    if predictive_noise_variance is None:
      predictive_noise_variance = observation_noise_variance

    if cholesky_fn is None:
      self._cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()
    elif not callable(cholesky_fn):
      raise ValueError('`cholesky_fn` must be a Python callable')
    else:
      self._cholesky_fn = cholesky_fn

    self._kernel = kernel
    self._index_points = index_points

    if mean_fn is None:
      def _mean_fn(x):
        # Shape B1 + [E, N], where E is the number of index points, and N is
        # the number of tasks.
        return tf.zeros(
            tf.concat([
                tf.shape(x)[:-self.kernel.feature_ndims],
                [self.kernel.num_tasks]
            ], axis=0),
            dtype=self.dtype)
      mean_fn = _mean_fn
    elif not callable(mean_fn):
      raise ValueError('`mean_fn` must be a Python callable')

    self._mean_fn = mean_fn
    self._observation_noise_variance = observation_noise_variance
    self._predictive_noise_variance = predictive_noise_variance
    self._observation_index_points = observation_index_points
    self._observations = observations
    self._observations_is_missing = observations_is_missing

    if _flattened_conditional_mean_fn is None:

      def flattened_conditional_mean_fn(x):
        """Flattened Conditional mean."""
        # The observation scale is rebuilt on every call rather than cached.
        observation_scale = _compute_observation_scale(
            kernel,
            observation_index_points,
            self._cholesky_fn,
            observation_noise_variance=self.observation_noise_variance,
            observations_is_missing=observations_is_missing)

        return _flattened_conditional_mean_fn_helper(
            x,
            self.kernel,
            self._observations,
            self._observation_index_points,
            observations_is_missing,
            observation_scale,
            mean_fn)

      _flattened_conditional_mean_fn = flattened_conditional_mean_fn

    self._flattened_conditional_mean_fn = _flattened_conditional_mean_fn
    self._observation_scale = _observation_scale

    super(MultiTaskGaussianProcessRegressionModel, self).__init__(
        dtype=dtype,
        reparameterization_type=(reparameterization.FULLY_REPARAMETERIZED),
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def __init__(self,
             kernel,
             observation_index_points,
             observations,
             observations_is_missing=None,
             index_points=None,
             mean_fn=None,
             observation_noise_variance=None,
             predictive_noise_variance=None,
             cholesky_fn=None,
             validate_args=False,
             allow_nan_stats=False,
             name='MultiTaskGaussianProcessRegressionModelWithCholesky'):
  """Construct the model, precomputing observation-dependent quantities.

  The Cholesky-like factor of the observation covariance, and the solve of
  that factor against the (mean-subtracted) observations, are computed here
  once and stored on the instance.

  WARNING: This method assumes `index_points` is the only varying parameter
  (i.e. is a `Variable` / changes after initialization) and hence is not
  tape-safe.

  Args:
    kernel: `MultiTaskKernel`-like instance representing the GP's covariance
      function.
    observation_index_points: `float` `Tensor` representing finite collection,
      or batch of collections, of points in the index set for which some data
      has been observed. Shape has the form `[b1, ..., bB, e, f1, ..., fF]`
      where `F` is the number of feature dimensions and must equal
      `kernel.feature_ndims`, and `e` is the number (size) of index points in
      each batch. `[b1, ..., bB, e]` must be broadcastable with the shape of
      `observations`, and `[b1, ..., bB]` must be broadcastable with the
      shapes of all other batched parameters (`kernel.batch_shape`,
      `index_points`, etc).
    observations: `float` `Tensor` representing collection, or batch of
      collections, of observations corresponding to
      `observation_index_points`. Shape has the form `[b1, ..., bB, e, t]`,
      which must be broadcastable with the batch and example shapes of
      `observation_index_points`. The batch shape `[b1, ..., bB]` must be
      broadcastable with the shapes of all other batched parameters
      (`kernel.batch_shape`, `index_points`, etc.).
    observations_is_missing: `bool` `Tensor` of shape `[..., e, t]`,
      representing a batch of boolean masks. When `observations_is_missing`
      is not `None`, this distribution is conditioned only on the
      observations for which the corresponding elements of
      `observations_is_missing` are `False`.
    index_points: `float` `Tensor` representing finite collection, or batch of
      collections, of points in the index set over which the GP is defined.
      Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the
      number of feature dimensions and must equal `kernel.feature_ndims` and
      `e` is the number (size) of index points in each batch. Ultimately this
      distribution corresponds to an `e`-dimensional multivariate normal. The
      batch shape must be broadcastable with `kernel.batch_shape`.
    mean_fn: Python `callable` that acts on `index_points` to produce a
      (batch of) collection of mean values at `index_points`. Takes a `Tensor`
      of shape `[b1, ..., bB, e, f1, ..., fF]` and returns a `Tensor` whose
      shape is broadcastable with `[b1, ..., bB, e, t]`, where `t` is the
      number of tasks.
      Default value: `None`, in which case no mean is subtracted from the
      observations.
    observation_noise_variance: `float` `Tensor` representing the variance
      of the noise in the Normal likelihood distribution of the model. May be
      batched, in which case the batch shape must be broadcastable with the
      shapes of all other batched parameters (`kernel.batch_shape`,
      `index_points`, etc.).
      Default value: `None`
    predictive_noise_variance: `float` `Tensor` representing the variance in
      the posterior predictive model. If `None`, we simply re-use
      `observation_noise_variance` for the posterior predictive noise. If set
      explicitly, however, we use this value. This allows us, for example, to
      omit predictive noise variance (by setting this to zero) to obtain
      noiseless posterior predictions of function values, conditioned on noisy
      observations.
    cholesky_fn: Callable which takes a single (batch) matrix argument and
      returns a Cholesky-like lower triangular factor.
      Default value: `None`, in which case
      `cholesky_util.make_cholesky_with_jitter_fn()` is called with its
      default arguments.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `False`. When `True`,
      statistics (e.g., mean, mode, variance) use the value `NaN` to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: 'MultiTaskGaussianProcessRegressionModelWithCholesky'.

  Raises:
    ValueError: if `kernel` is not a `MultiTaskKernel`, or if `cholesky_fn`
      or `mean_fn` is given but is not callable.
  """
  # Must stay the first statement so that only constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:

    if not isinstance(kernel, multitask_kernel.MultiTaskKernel):
      raise ValueError('`kernel` must be a `MultiTaskKernel`.')

    dtype = dtype_util.common_dtype([
        index_points, observation_index_points, observations,
        observation_noise_variance, predictive_noise_variance
    ], tf.float32)
    index_points = tensor_util.convert_nonref_to_tensor(
        index_points, dtype=dtype, name='index_points')
    # The observation-side arguments are read eagerly (not kept as refs)
    # because everything derived from them is precomputed below.
    observation_index_points = tf.convert_to_tensor(
        observation_index_points,
        dtype=dtype,
        name='observation_index_points')
    observations = tf.convert_to_tensor(
        observations, dtype=dtype, name='observations')
    if observations_is_missing is not None:
      observations_is_missing = tf.convert_to_tensor(
          observations_is_missing, dtype=tf.bool)
    if observation_noise_variance is not None:
      observation_noise_variance = tf.convert_to_tensor(
          observation_noise_variance,
          dtype=dtype,
          name='observation_noise_variance')
    predictive_noise_variance = tensor_util.convert_nonref_to_tensor(
        predictive_noise_variance,
        dtype=dtype,
        name='predictive_noise_variance')
    # Fall back to the observation noise when no predictive noise is given.
    if predictive_noise_variance is None:
      predictive_noise_variance = observation_noise_variance

    if cholesky_fn is None:
      self._cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()
    elif not callable(cholesky_fn):
      raise ValueError('`cholesky_fn` must be a Python callable')
    else:
      self._cholesky_fn = cholesky_fn

    self._kernel = kernel
    self._index_points = index_points

    if mean_fn is not None and not callable(mean_fn):
      raise ValueError('`mean_fn` must be a Python callable')

    self._mean_fn = mean_fn
    self._observation_noise_variance = observation_noise_variance
    self._predictive_noise_variance = predictive_noise_variance
    self._observation_index_points = observation_index_points
    self._observations = observations
    self._observations_is_missing = observations_is_missing

    observation_covariance = self.kernel.matrix_over_all_tasks(
        observation_index_points, observation_index_points)

    if observation_noise_variance is not None:
      # Adding the noise to the diagonal requires a dense matrix; it is
      # wrapped back into a linear operator afterwards.
      observation_covariance = observation_covariance.to_dense()
      broadcast_shape = distribution_util.get_broadcast_shape(
          observation_covariance,
          observation_noise_variance[..., tf.newaxis, tf.newaxis])
      observation_covariance = tf.broadcast_to(observation_covariance,
                                               broadcast_shape)
      observation_covariance = _add_diagonal_shift(observation_covariance,
                                                   observation_noise_variance)
      observation_covariance = tf.linalg.LinearOperatorFullMatrix(
          observation_covariance,
          is_non_singular=True,
          is_positive_definite=True)

    if observations_is_missing is not None:
      vec_observations_is_missing = _vec(observations_is_missing)
      # `mask_matrix` is given the negated flags, i.e. `True` marks the
      # entries that are present.
      observation_covariance = tf.linalg.LinearOperatorFullMatrix(
          psd_kernels_util.mask_matrix(
              observation_covariance.to_dense(),
              mask=~vec_observations_is_missing),
          is_non_singular=True,
          is_positive_definite=True)

    self._observation_cholesky = cholesky_util.cholesky_from_fn(
        observation_covariance, self._cholesky_fn)

    # Note that the conditional mean is
    # k(x, o) @ (k(o, o) + sigma**2)^-1 obs. We can precompute the latter
    # term since it won't change per iteration.
    # Test against `None` (not truthiness) to match the validation above.
    if mean_fn is not None:
      vec_observations = _vec(observations -
                              mean_fn(observation_index_points))
    else:
      vec_observations = _vec(observations)
    if observations_is_missing is not None:
      # Missing entries contribute zero to the solve.
      vec_observations = tf.where(~vec_observations_is_missing,
                                  vec_observations,
                                  tf.zeros([], dtype=vec_observations.dtype))
    self._solve_on_obs = self._observation_cholesky.solvevec(
        self._observation_cholesky.solvevec(vec_observations), adjoint=True)

    super(MultiTaskGaussianProcessRegressionModel, self).__init__(
        dtype=dtype,
        reparameterization_type=(reparameterization.FULLY_REPARAMETERIZED),
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def with_precomputed_divisor(
    base_kernel,
    fixed_inputs,
    fixed_inputs_mask=None,
    diag_shift=None,
    cholesky_fn=None,
    validate_args=False,
    name='PrecomputedSchurComplement'):
  """Returns a `SchurComplement` with a precomputed divisor matrix.

  This method is the same as creating a `SchurComplement` kernel, but assumes
  that `fixed_inputs`, `diag_shift` and `base_kernel` are unchanging /
  not parameterized by any mutable state. We explicitly read / concretize
  these values when this method is called, since we can precompute some
  factorizations in order to speed up subsequent invocations of the kernel.

  WARNING: This method assumes passed in arguments are not parameterized
  by mutable state (`fixed_inputs`, `diag_shift` and `base_kernel`), and hence
  is not tape-safe.

  Args:
    base_kernel: A `PositiveSemidefiniteKernel` instance, the kernel used to
      build the block matrices of which this kernel computes the Schur
      complement.
    fixed_inputs: A Tensor, representing a collection of inputs. The Schur
      complement that this kernel computes comes from a block matrix, whose
      bottom-right corner is derived from `base_kernel.matrix(fixed_inputs,
      fixed_inputs)`, and whose top-right and bottom-left pieces are
      constructed by computing the base_kernel at pairs of input locations
      together with these `fixed_inputs`. `fixed_inputs` is allowed to be an
      empty collection (either `None` or having a zero shape entry), in which
      case the kernel falls back to the trivial application of `base_kernel`
      to inputs. When `fixed_inputs` is `None` nothing is precomputed and a
      plain `SchurComplement` is returned. See class-level docstring for more
      details on the exact computation this does; `fixed_inputs` correspond
      to the `Z` structure discussed there. `fixed_inputs` is assumed to have
      shape `[b1, ..., bB, N, f1, ..., fF]` where the `b`'s are batch shape
      entries, the `f`'s are feature_shape entries, and `N` is the number of
      fixed inputs. Use of this kernel entails a 1-time O(N^3) cost of
      computing the Cholesky decomposition of the k(Z, Z) matrix. The batch
      shape elements of `fixed_inputs` must be broadcast compatible with
      `base_kernel.batch_shape`.
    fixed_inputs_mask: A boolean Tensor of shape `[..., N]`. When
      `fixed_inputs_mask` is not None and an element of it is `False`, the
      returned kernel will return values computed as if the divisor matrix
      did not contain the corresponding row or column.
    diag_shift: A floating point scalar to be added to the diagonal of the
      divisor_matrix before computing its Cholesky.
    cholesky_fn: Callable which takes a single (batch) matrix argument and
      returns a Cholesky-like lower triangular factor.
      Default value: `None`, in which case
      `cholesky_util.make_cholesky_with_jitter_fn()` is called with its
      default arguments.
    validate_args: If `True`, parameters are checked for validity despite
      possibly degrading runtime performance.
      Default value: `False`
    name: Python `str` name prefixed to Ops created by this class.
      Default value: `"PrecomputedSchurComplement"`

  Returns:
    A `SchurComplement` kernel. Unless `fixed_inputs` is `None`, it carries
    the Cholesky-like factor of the (masked) divisor matrix computed here.
  """
  dtype = dtype_util.common_dtype(
      [base_kernel, fixed_inputs, diag_shift], tf.float32)

  if fixed_inputs_mask is not None:
    fixed_inputs_mask = tf.convert_to_tensor(fixed_inputs_mask, tf.bool)
  if diag_shift is not None:
    diag_shift = tf.convert_to_tensor(diag_shift, dtype)

  if cholesky_fn is None:
    # Imported lazily, as in the original code, rather than at module level.
    from tensorflow_probability.python.distributions import cholesky_util  # pylint:disable=g-import-not-at-top
    cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()

  if fixed_inputs is None:
    # The documented "empty collection" case: there is no divisor matrix to
    # factorize, and `tf.convert_to_tensor(None)` would raise, so build the
    # kernel without a precomputed factor.
    return SchurComplement(
        base_kernel=base_kernel,
        fixed_inputs=None,
        fixed_inputs_mask=fixed_inputs_mask,
        diag_shift=diag_shift,
        cholesky_fn=cholesky_fn,
        validate_args=validate_args,
        name=name)

  # Concretize `fixed_inputs`; it is assumed not to change afterwards.
  fixed_inputs = tf.convert_to_tensor(fixed_inputs, dtype)

  # TODO(b/196219597): Add a check to ensure that we have a `base_kernel`
  # that is explicitly concretized.
  divisor_matrix_cholesky = cholesky_fn(
      util.mask_matrix(
          _compute_divisor_matrix(
              base_kernel,
              diag_shift=diag_shift,
              fixed_inputs=fixed_inputs),
          mask=fixed_inputs_mask))

  schur_complement = SchurComplement(
      base_kernel=base_kernel,
      fixed_inputs=fixed_inputs,
      fixed_inputs_mask=fixed_inputs_mask,
      diag_shift=diag_shift,
      cholesky_fn=cholesky_fn,
      validate_args=validate_args,
      _precomputed_divisor_matrix_cholesky=divisor_matrix_cholesky,
      name=name)

  return schur_complement
def __init__(self,
             base_kernel,
             fixed_inputs,
             fixed_inputs_mask=None,
             diag_shift=None,
             cholesky_fn=None,
             validate_args=False,
             name='SchurComplement',
             _precomputed_divisor_matrix_cholesky=None):
  """Construct a SchurComplement kernel instance.

  Args:
    base_kernel: A `PositiveSemidefiniteKernel` instance, the kernel used to
      build the block matrices of which this kernel computes the Schur
      complement.
    fixed_inputs: A Tensor, representing a collection of inputs. The Schur
      complement that this kernel computes comes from a block matrix, whose
      bottom-right corner is derived from `base_kernel.matrix(fixed_inputs,
      fixed_inputs)`, and whose top-right and bottom-left pieces are
      constructed by computing the base_kernel at pairs of input locations
      together with these `fixed_inputs`. `fixed_inputs` is allowed to be an
      empty collection (either `None` or having a zero shape entry), in which
      case the kernel falls back to the trivial application of `base_kernel`
      to inputs. See class-level docstring for more details on the exact
      computation this does; `fixed_inputs` correspond to the `Z` structure
      discussed there. `fixed_inputs` is assumed to have shape
      `[b1, ..., bB, N, f1, ..., fF]` where the `b`'s are batch shape
      entries, the `f`'s are feature_shape entries, and `N` is the number of
      fixed inputs. Use of this kernel entails a 1-time O(N^3) cost of
      computing the Cholesky decomposition of the k(Z, Z) matrix. The batch
      shape elements of `fixed_inputs` must be broadcast compatible with
      `base_kernel.batch_shape`.
    fixed_inputs_mask: A boolean Tensor of shape `[..., N]`. When
      `fixed_inputs_mask` is not None and an element of it is `False`, this
      kernel will return values computed as if the divisor matrix did not
      contain the corresponding row or column.
    diag_shift: A floating point scalar to be added to the diagonal of the
      divisor_matrix before computing its Cholesky.
    cholesky_fn: Callable which takes a single (batch) matrix argument and
      returns a Cholesky-like lower triangular factor.
      Default value: `None`, in which case
      `cholesky_util.make_cholesky_with_jitter_fn()` is called with its
      default arguments.
    validate_args: If `True`, parameters are checked for validity despite
      possibly degrading runtime performance.
      Default value: `False`
    name: Python `str` name prefixed to Ops created by this class.
      Default value: `"SchurComplement"`
    _precomputed_divisor_matrix_cholesky: Internal parameter -- do not use.
  """
  # Must stay the first statement so that only constructor arguments are
  # captured.
  parameters = dict(locals())

  # Delayed import to avoid circular dependency between `tfp.bijectors` and
  # `tfp.math`
  # pylint: disable=g-import-not-at-top
  from tensorflow_probability.python.bijectors import cholesky_outer_product
  from tensorflow_probability.python.bijectors import invert
  # pylint: enable=g-import-not-at-top
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([
        base_kernel,
        fixed_inputs,
        diag_shift,
        _precomputed_divisor_matrix_cholesky
    ], tf.float32)
    self._base_kernel = base_kernel
    self._diag_shift = tensor_util.convert_nonref_to_tensor(
        diag_shift, dtype=dtype, name='diag_shift')
    self._fixed_inputs = tensor_util.convert_nonref_to_tensor(
        fixed_inputs, dtype=dtype, name='fixed_inputs')
    self._fixed_inputs_mask = tensor_util.convert_nonref_to_tensor(
        fixed_inputs_mask, dtype=tf.bool, name='fixed_inputs_mask')

    # A precomputed factor, when supplied, is read eagerly into a `Tensor`.
    self._precomputed_divisor_matrix_cholesky = (
        _precomputed_divisor_matrix_cholesky)
    if self._precomputed_divisor_matrix_cholesky is not None:
      self._precomputed_divisor_matrix_cholesky = tf.convert_to_tensor(
          self._precomputed_divisor_matrix_cholesky, dtype)

    if cholesky_fn is None:
      from tensorflow_probability.python.distributions import cholesky_util  # pylint:disable=g-import-not-at-top
      cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()
    self._cholesky_fn = cholesky_fn

    # Built once, after `cholesky_fn` is resolved, so the bijector uses the
    # same factorization as the rest of the kernel.
    self._cholesky_bijector = invert.Invert(
        cholesky_outer_product.CholeskyOuterProduct(cholesky_fn=cholesky_fn))

    super(SchurComplement, self).__init__(
        base_kernel.feature_ndims,
        dtype=dtype,
        name=name,
        parameters=parameters)
def precompute_regression_model(
    kernel,
    observation_index_points,
    observations,
    observations_mask=None,
    index_points=None,
    observation_noise_variance=0.,
    predictive_noise_variance=None,
    mean_fn=None,
    cholesky_fn=None,
    jitter=1e-6,
    validate_args=False,
    allow_nan_stats=False,
    name='PrecomputedGaussianProcessRegressionModel'):
  """Builds a `GaussianProcessRegressionModel` with cached observation terms.

  Unlike the plain constructor, this factory eagerly converts the
  observation-related arguments to `Tensor`s and computes, once, the Cholesky
  factor of the (noise-shifted) kernel matrix over
  `observation_index_points` together with the solve against the centered
  observations. The resulting conditional kernel and conditional mean function
  are handed to the constructor so they are not recomputed.

  Consequences:

    * `observation_index_points` and `observations` are required here.
    * Only `index_points` may vary after construction (i.e. be a `Variable`).

  A typical use is to first optimize kernel hyperparameters for a
  `GaussianProcess`, and then form the posterior predictive for those
  hyperparameters and observation / index-point pairs.

  WARNING: Because everything except `index_points` is frozen at call time,
  this method is not tape-safe.

  Args:
    kernel: `PositiveSemidefiniteKernel`-like instance representing the GP's
      covariance function.
    observation_index_points: `float` `Tensor` holding a collection, or batch
      of collections, of index points at which data has been observed. Shape
      is `[b1, ..., bB, e, f1, ..., fF]`, where `F` must equal
      `kernel.feature_ndims` and `e` is the number of index points per batch
      member. `[b1, ..., bB, e]` must broadcast with the shape of
      `observations`, and `[b1, ..., bB]` must broadcast with every other
      batched parameter (`kernel.batch_shape`, `index_points`, etc).
    observations: `float` `Tensor` holding the observed values matching
      `observation_index_points`. Shape is `[b1, ..., bB, e]`, which must be
      broadcastable with the batch and example shapes of
      `observation_index_points`; the batch shape must broadcast with every
      other batched parameter.
    observations_mask: Optional `bool` `Tensor` of shape `[..., e]` holding a
      batch of boolean masks. When not `None`, the model is conditioned only
      on observations whose mask entry is `True`; masked-out entries
      contribute zero to the precomputed solve and to the cross-covariance.
      Default value: `None`.
    index_points: `float` `Tensor` holding the collection, or batch of
      collections, of index points over which the GP is defined. Shape is
      `[b1, ..., bB, e, f1, ..., fF]` with `F == kernel.feature_ndims`; the
      distribution is then an `e`-dimensional multivariate normal. The batch
      shape must broadcast with `kernel.batch_shape` and any batch dims
      produced by `mean_fn`.
      Default value: `None`.
    observation_noise_variance: `float` `Tensor` giving the variance of the
      noise in the Normal likelihood. May be batched, in which case its batch
      shape must broadcast with every other batched parameter.
      Default value: `0.`
    predictive_noise_variance: `float` `Tensor` giving the noise variance of
      the posterior predictive model. If `None`, `observation_noise_variance`
      is re-used. Setting it to zero yields noiseless posterior predictions of
      function values conditioned on noisy observations.
      Default value: `None`.
    mean_fn: Python `callable` mapping index points to a collection, or batch
      of collections, of mean values. Takes a `Tensor` of shape
      `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose shape is
      broadcastable with `[b1, ..., bB]`.
      Default value: `None`, meaning the constant zero function.
    cholesky_fn: Callable taking a single (batch) matrix argument and
      returning a Cholesky-like lower triangular factor.
      Default value: `None`, in which case `make_cholesky_with_jitter_fn` is
      called with the `jitter` parameter.
    jitter: `float` scalar `Tensor` added to the diagonal of the covariance
      matrix to ensure positive definiteness. Used to build the default
      `cholesky_fn` and also forwarded to the model constructor.
      Default value: `1e-6`.
    validate_args: Python `bool`. When `True`, distribution parameters are
      checked for validity despite possibly degrading runtime performance.
      When `False`, invalid inputs may silently render incorrect outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
      mode, variance) use the value `NaN` to indicate an undefined result.
      When `False`, an exception is raised if one or more of the statistic's
      batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: 'PrecomputedGaussianProcessRegressionModel'.

  Returns:
    An instance of `GaussianProcessRegressionModel` with precomputed
    quantities associated with observations.

  Raises:
    ValueError: if `mean_fn` is not `None` and is not callable.
  """
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [
            index_points,
            observation_index_points,
            observations,
            observation_noise_variance,
            predictive_noise_variance,
            jitter,
        ],
        tf.float32)

    # Everything below is treated as fixed: these arguments are not expected
    # to be Variables, so they are turned into plain Tensors right away.
    jitter = tf.convert_to_tensor(jitter, dtype=dtype)
    observation_index_points = tf.convert_to_tensor(
        observation_index_points, dtype=dtype)
    observation_noise_variance = tf.convert_to_tensor(
        observation_noise_variance, dtype=dtype)
    observations = tf.convert_to_tensor(observations, dtype=dtype)
    has_mask = observations_mask is not None
    if has_mask:
      observations_mask = tf.convert_to_tensor(observations_mask)

    if cholesky_fn is None:
      cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn(jitter)

    # The Schur complement kernel caches the Cholesky factor of the divisor
    # matrix built from the observation index points.
    schur_kernel = tfpk.SchurComplement.with_precomputed_divisor(
        base_kernel=kernel,
        fixed_inputs=observation_index_points,
        fixed_inputs_mask=observations_mask,
        cholesky_fn=cholesky_fn,
        diag_shift=observation_noise_variance)
    chol_operator = tf.linalg.LinearOperatorLowerTriangular(
        schur_kernel.divisor_matrix_cholesky())

    if mean_fn is None:
      def mean_fn(x):  # pylint: disable=unused-argument,function-redefined
        return tf.zeros([1], dtype=dtype)
    elif not callable(mean_fn):
      raise ValueError('`mean_fn` must be a Python callable')

    # Centered observations; masked-out entries are zeroed so they drop out of
    # the solve.
    residual = observations - mean_fn(observation_index_points)
    if has_mask:
      residual = tf.where(
          observations_mask, residual, tf.zeros([], dtype=residual.dtype))

    # Two triangular solves (L, then L^T) applied to the residual.
    forward_solve = chol_operator.solvevec(residual)
    solved_residual = chol_operator.solvevec(forward_solve, adjoint=True)

    def _precomputed_mean_fn(x):
      """Posterior mean at `x`, re-using the cached solve."""
      cross_cov = kernel.matrix(x, observation_index_points)
      if has_mask:
        cross_cov = tf.where(
            observations_mask[..., tf.newaxis, :],
            cross_cov,
            tf.zeros([], dtype=cross_cov.dtype))
      prior_mean = mean_fn(x)
      return prior_mean + tf.linalg.matvec(cross_cov, solved_residual)

    return GaussianProcessRegressionModel(
        kernel=kernel,
        observation_index_points=observation_index_points,
        observations=observations,
        index_points=index_points,
        observation_noise_variance=observation_noise_variance,
        predictive_noise_variance=predictive_noise_variance,
        cholesky_fn=cholesky_fn,
        jitter=jitter,
        _conditional_kernel=schur_kernel,
        _conditional_mean_fn=_precomputed_mean_fn,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
def __init__(self,
             kernel,
             index_points=None,
             observation_index_points=None,
             observations=None,
             observation_noise_variance=0.,
             predictive_noise_variance=None,
             mean_fn=None,
             cholesky_fn=None,
             jitter=1e-6,
             validate_args=False,
             allow_nan_stats=False,
             name='GaussianProcessRegressionModel',
             _conditional_kernel=None,
             _conditional_mean_fn=None):
  """Construct a GaussianProcessRegressionModel instance.

  The model is a GP whose kernel is the Schur complement of `kernel` with
  respect to the observation index points, and whose mean function is the
  matching conditional mean. Both are built here unless supplied through the
  internal `_conditional_kernel` / `_conditional_mean_fn` arguments.

  Args:
    kernel: `PositiveSemidefiniteKernel`-like instance representing the GP's
      covariance function.
    index_points: `float` `Tensor` holding a collection, or batch of
      collections, of index points over which the GP is defined. Shape is
      `[b1, ..., bB, e, f1, ..., fF]`, where `F` must equal
      `kernel.feature_ndims` and `e` is the number of index points per batch
      member; the distribution is then an `e`-dimensional multivariate
      normal. The batch shape must broadcast with `kernel.batch_shape` and
      any batch dims produced by `mean_fn`.
    observation_index_points: `float` `Tensor` holding a collection, or batch
      of collections, of index points at which data has been observed. Shape
      is `[b1, ..., bB, e, f1, ..., fF]` with `F == kernel.feature_ndims`.
      `[b1, ..., bB, e]` must broadcast with the shape of `observations`, and
      `[b1, ..., bB]` must broadcast with every other batched parameter
      (`kernel.batch_shape`, `index_points`, etc). The default of `None`
      means there are no observations, giving the prior predictive model (a
      GP with noise of variance `predictive_noise_variance`).
    observations: `float` `Tensor` holding the observed values matching
      `observation_index_points`. Shape is `[b1, ..., bB, e]`, which must be
      broadcastable with the batch and example shapes of
      `observation_index_points`; the batch shape must broadcast with every
      other batched parameter. The default of `None` means there are no
      observations, giving the prior predictive model.
    observation_noise_variance: `float` `Tensor` giving the variance of the
      noise in the Normal likelihood. May be batched, in which case its batch
      shape must broadcast with every other batched parameter.
      Default value: `0.`
    predictive_noise_variance: `float` `Tensor` giving the noise variance of
      the posterior predictive model. If `None`, `observation_noise_variance`
      is re-used. Setting it to zero yields noiseless posterior predictions of
      function values conditioned on noisy observations.
      Default value: `None`.
    mean_fn: Python `callable` mapping index points to a collection, or batch
      of collections, of mean values. Takes a `Tensor` of shape
      `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose shape is
      broadcastable with `[b1, ..., bB]`.
      Default value: `None`, meaning the constant zero function.
    cholesky_fn: Callable taking a single (batch) matrix argument and
      returning a Cholesky-like lower triangular factor.
      Default value: `None`, in which case `make_cholesky_with_jitter_fn` is
      called with the `jitter` parameter.
    jitter: `float` scalar `Tensor` added to the diagonal of the covariance
      matrix to ensure positive definiteness. Used to build the default
      `cholesky_fn` and also forwarded to the base class constructor.
      Default value: `1e-6`.
    validate_args: Python `bool`. When `True`, distribution parameters are
      checked for validity despite possibly degrading runtime performance.
      When `False`, invalid inputs may silently render incorrect outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
      mode, variance) use the value `NaN` to indicate an undefined result.
      When `False`, an exception is raised if one or more of the statistic's
      batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: 'GaussianProcessRegressionModel'.
    _conditional_kernel: Internal parameter -- do not use.
    _conditional_mean_fn: Internal parameter -- do not use.

  Raises:
    ValueError: if either
      - only one of `observations` and `observation_index_points` is given, or
      - `mean_fn` is not `None` and not callable.
  """
  # Must stay the first statement so that only the constructor arguments (and
  # `self`) are captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [
            index_points,
            observation_index_points,
            observations,
            observation_noise_variance,
            predictive_noise_variance,
            jitter,
        ],
        tf.float32)

    def _to_tensor(value, arg_name):
      # Shared conversion used for every tensor-like argument.
      return tensor_util.convert_nonref_to_tensor(
          value, dtype=dtype, name=arg_name)

    index_points = _to_tensor(index_points, 'index_points')
    observation_index_points = _to_tensor(
        observation_index_points, 'observation_index_points')
    observations = _to_tensor(observations, 'observations')
    observation_noise_variance = _to_tensor(
        observation_noise_variance, 'observation_noise_variance')
    predictive_noise_variance = _to_tensor(
        predictive_noise_variance, 'predictive_noise_variance')
    if predictive_noise_variance is None:
      # Fall back to the likelihood noise for predictions.
      predictive_noise_variance = observation_noise_variance
    jitter = _to_tensor(jitter, 'jitter')

    # Observations and their index points must be supplied together.
    points_missing = observation_index_points is None
    values_missing = observations is None
    if points_missing != values_missing:
      raise ValueError(
          '`observations` and `observation_index_points` must both be given '
          'or None. Got {} and {}, respectively.'.format(
              observations, observation_index_points))

    if mean_fn is None:
      # Constant zero mean in the common dtype computed above.
      def mean_fn(x):  # pylint: disable=unused-argument,function-redefined
        return tf.zeros([1], dtype=dtype)
    elif not callable(mean_fn):
      raise ValueError('`mean_fn` must be a Python callable')

    if cholesky_fn is None:
      cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn(jitter)

    self._name = name
    self._observation_index_points = observation_index_points
    self._observations = observations
    self._observation_noise_variance = observation_noise_variance
    self._predictive_noise_variance = predictive_noise_variance
    self._jitter = jitter
    self._validate_args = validate_args

    with tf.name_scope('init'):
      if _conditional_kernel is None:
        _conditional_kernel = tfpk.SchurComplement(
            base_kernel=kernel,
            fixed_inputs=observation_index_points,
            cholesky_fn=cholesky_fn,
            diag_shift=observation_noise_variance)

      # Only the mean function needs special handling for empty observations;
      # SchurComplement already falls back to the base kernel in that case.
      no_observations = _is_empty_observation_data(
          feature_ndims=kernel.feature_ndims,
          observation_index_points=observation_index_points,
          observations=observations)
      if not no_observations:
        _validate_observation_data(
            kernel=kernel,
            observation_index_points=observation_index_points,
            observations=observations)

      if _conditional_mean_fn is None:
        if no_observations:
          _conditional_mean_fn = mean_fn
        else:

          def _posterior_mean_fn(x):
            """Conditional mean."""
            # Read from `self` at call time rather than closing over the
            # constructor-time values.
            obs = tf.convert_to_tensor(self._observations)
            obs_points = tf.convert_to_tensor(self._observation_index_points)
            cross_cov_linop = tf.linalg.LinearOperatorFullMatrix(
                kernel.matrix(x, obs_points))
            chol_linop = tf.linalg.LinearOperatorLowerTriangular(
                _conditional_kernel.divisor_matrix_cholesky(
                    fixed_inputs=obs_points))
            residual = obs - mean_fn(obs_points)
            prior_mean = mean_fn(x)
            forward_solve = chol_linop.solvevec(residual)
            solved_residual = chol_linop.solvevec(forward_solve, adjoint=True)
            return prior_mean + cross_cov_linop.matvec(solved_residual)

          _conditional_mean_fn = _posterior_mean_fn

      super(GaussianProcessRegressionModel, self).__init__(
          kernel=_conditional_kernel,
          mean_fn=_conditional_mean_fn,
          index_points=index_points,
          cholesky_fn=cholesky_fn,
          jitter=jitter,
          # The GP base class's "observation noise variance" is what this
          # class calls the "predictive noise variance": the observation
          # noise is consumed by the fit/solve above, while the predictive
          # noise drives downstream computations such as sampling.
          observation_noise_variance=predictive_noise_variance,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
      self._parameters = parameters
def __init__(self,
             df,
             kernel,
             index_points=None,
             mean_fn=None,
             observation_noise_variance=0.,
             marginal_fn=None,
             cholesky_fn=None,
             jitter=1e-6,
             validate_args=False,
             allow_nan_stats=False,
             name='StudentTProcess'):
  """Instantiate a StudentTProcess Distribution.

  Args:
    df: Positive Floating-point `Tensor` representing the degrees of freedom.
      Must be greater than 2.
    kernel: `PositiveSemidefiniteKernel`-like instance representing the
      TP's covariance function.
    index_points: `float` `Tensor` holding a finite (batch of) vector(s) of
      points in the index set over which the TP is defined. Shape is
      `[b1, ..., bB, e, f1, ..., fF]`, where `F` must equal
      `kernel.feature_ndims` and `e` is the number of index points per batch
      member; the distribution is then an `e`-dimensional multivariate
      Student's T. The batch shape must broadcast with `kernel.batch_shape`
      and any batch dims produced by `mean_fn`.
      Default value: `None`.
    mean_fn: Python `callable` mapping `index_points` to a (batch of)
      vector(s) of mean values. Takes a `Tensor` of shape
      `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose shape is
      broadcastable with `[b1, ..., bB]`.
      Default value: `None`, meaning the constant zero function.
    observation_noise_variance: `float` `Tensor` giving the (batch of) scalar
      variance(s) of the noise in the likelihood of the model. If batched,
      the batch shape must broadcast with every other batched parameter
      (`kernel.batch_shape`, `index_points`, etc.).
      Default value: `0.`
    marginal_fn: A Python callable that builds the marginal distribution
      (a `tfd.Distribution`) from a location, a covariance matrix and the
      optional `validate_args`, `allow_nan_stats` and `name` arguments.
      Default value: `None`, in which case a Cholesky-factorizing function is
      created with `make_cholesky_factored_marginal_fn` from `cholesky_fn`
      (or from the `jitter`-based default when `cholesky_fn` is also `None`).
    cholesky_fn: Callable taking a single (batch) matrix argument and
      returning a Cholesky-like lower triangular factor.
      Default value: `None`, in which case `make_cholesky_with_jitter_fn` is
      called with the `jitter` parameter. At most one of `cholesky_fn` and
      `marginal_fn` should be set.
    jitter: `float` scalar `Tensor` added to the diagonal of the covariance
      matrix to ensure positive definiteness. Only used to build the default
      `cholesky_fn`, so it has no effect on the factorization when
      `cholesky_fn` or `marginal_fn` is set.
      Default value: `1e-6`.
    validate_args: Python `bool`. When `True`, distribution parameters are
      checked for validity despite possibly degrading runtime performance.
      When `False`, invalid inputs may silently render incorrect outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
      mode, variance) use the value "`NaN`" to indicate an undefined result.
      When `False`, an exception is raised if one or more of the statistic's
      batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "StudentTProcess".

  Raises:
    ValueError: if `mean_fn` is not `None` and is not callable, or if both
      `marginal_fn` and `cholesky_fn` are given.
  """
  # Must stay the first statement so that only the constructor arguments (and
  # `self`) are captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [df, kernel, index_points, observation_noise_variance, jitter],
        tf.float32)

    def _to_tensor(value, arg_name):
      # Shared conversion used for every tensor-like argument.
      return tensor_util.convert_nonref_to_tensor(
          value, dtype=dtype, name=arg_name)

    df = _to_tensor(df, 'df')
    observation_noise_variance = _to_tensor(
        observation_noise_variance, 'observation_noise_variance')
    index_points = _to_tensor(index_points, 'index_points')
    jitter = _to_tensor(jitter, 'jitter')

    self._kernel = kernel
    self._index_points = index_points

    if mean_fn is None:
      # Constant zero mean in the common dtype computed above.
      def mean_fn(x):  # pylint: disable=unused-argument,function-redefined
        return tf.zeros([1], dtype=dtype)
    elif not callable(mean_fn):
      raise ValueError('`mean_fn` must be a Python callable')

    self._df = df
    self._observation_noise_variance = observation_noise_variance
    self._mean_fn = mean_fn
    self._jitter = jitter
    self._cholesky_fn = cholesky_fn

    # `marginal_fn` and `cholesky_fn` are two ways of specifying the same
    # thing, so supplying both is ambiguous.
    if not (marginal_fn is None or cholesky_fn is None):
      raise ValueError(
          'At most one of `marginal_fn` and `cholesky_fn` should be set.')

    if marginal_fn is not None:
      self._marginal_fn = marginal_fn
    else:
      if cholesky_fn is None:
        cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn(jitter)
        self._cholesky_fn = cholesky_fn
      self._marginal_fn = make_cholesky_factored_marginal_fn(cholesky_fn)

    with tf.name_scope('init'):
      super(StudentTProcess, self).__init__(
          dtype=dtype,
          reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)
def __init__(self,
             kernel,
             index_points=None,
             mean_fn=None,
             observation_noise_variance=None,
             cholesky_fn=None,
             validate_args=False,
             allow_nan_stats=False,
             name='MultiTaskGaussianProcess'):
  """Constructs a MultiTaskGaussianProcess instance.

  Args:
    kernel: `MultiTaskKernel`-like instance representing the GP's covariance
      function.
    index_points: `float` `Tensor` holding a collection, or batch of
      collections, of index points over which the GP is defined. Shape is
      `[b1, ..., bB, e, f1, ..., fF]`, where `F` must equal
      `kernel.feature_ndims` and `e` is the number of index points per batch
      member. The batch shape must broadcast with `kernel.batch_shape`.
      Default value: `None`.
    mean_fn: Python `callable` mapping `index_points` to a (batch of)
      collection of mean values. Takes a `Tensor` of shape
      `[b1, ..., bB, e, f1, ..., fF]` and returns a `Tensor` whose shape is
      broadcastable with `[b1, ..., bB, e, t]`, where `t` is the number of
      tasks.
      Default value: `None`, in which case a function returning zeros of
      shape `[b1, ..., bB, e, t]` is used.
    observation_noise_variance: `float` `Tensor` giving the variance of the
      noise in the Normal likelihood of the model; a scalar or a vector the
      size of the number of tasks. May be batched, in which case the batch
      shape must broadcast with every other batched parameter
      (`kernel.batch_shape`, `index_points`, etc.).
      Default value: `None`.
    cholesky_fn: Callable taking a single (batch) matrix argument and
      returning a Cholesky-like lower triangular factor.
      Default value: `None`, in which case `make_cholesky_with_jitter_fn` is
      called with its default arguments.
    validate_args: Python `bool`. When `True`, distribution parameters are
      checked for validity despite possibly degrading runtime performance.
      When `False`, invalid inputs may silently render incorrect outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
      mode, variance) use the value `NaN` to indicate an undefined result.
      When `False`, an exception is raised if one or more of the statistic's
      batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: 'MultiTaskGaussianProcess'.

  Raises:
    ValueError: if `kernel` is not a `MultiTaskKernel`, or if `mean_fn` or
      `cholesky_fn` is not `None` and is not callable.
  """
  # Must stay the first statement so that only the constructor arguments (and
  # `self`) are captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [index_points, observation_noise_variance], tf.float32)

    def _to_tensor(value, arg_name):
      # Shared conversion used for every tensor-like argument.
      return tensor_util.convert_nonref_to_tensor(
          value, dtype=dtype, name=arg_name)

    index_points = _to_tensor(index_points, 'index_points')
    observation_noise_variance = _to_tensor(
        observation_noise_variance, 'observation_noise_variance')

    if not isinstance(kernel, multitask_kernel.MultiTaskKernel):
      raise ValueError('`kernel` must be a `MultiTaskKernel`.')
    self._kernel = kernel
    self._index_points = index_points

    if mean_fn is None:

      def _zero_mean_fn(x):
        # Output shape is B1 + [E, N]: strip the feature dims from `x` and
        # append the number of tasks N (E is the number of index points).
        leading_shape = ps.shape(x)[:-self.kernel.feature_ndims]
        zeros_shape = ps.concat(
            [leading_shape, [self.kernel.num_tasks]], axis=0)
        return tf.zeros(zeros_shape, dtype=dtype)

      mean_fn = _zero_mean_fn
    elif not callable(mean_fn):
      raise ValueError('`mean_fn` must be a Python callable')
    self._mean_fn = mean_fn

    # Scalar or vector the size of the number of tasks.
    self._observation_noise_variance = observation_noise_variance

    if cholesky_fn is None:
      cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()
    elif not callable(cholesky_fn):
      raise ValueError('`cholesky_fn` must be a Python callable')
    self._cholesky_fn = cholesky_fn

    with tf.name_scope('init'):
      super(MultiTaskGaussianProcess, self).__init__(
          dtype=dtype,
          reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)