Example #1
    def _get_flattened_marginal_distribution(self, index_points=None):
        # This returns an MVN of event size [N * E], where N is the number of
        # tasks and E is the number of index points.
        with self._name_and_control_scope(
                'get_flattened_marginal_distribution'):
            index_points = self._get_index_points(index_points)
            covariance = self._compute_flattened_covariance(index_points)

            batch_shape = self._batch_shape_tensor(index_points=index_points)
            event_shape = self._event_shape_tensor(index_points=index_points)

            # Take the Cholesky, specializing to the cases where the
            # covariance is block-diagonal or Kronecker-structured.
            covariance_cholesky = cholesky_util.cholesky_from_fn(
                covariance, self._cholesky_fn)
            loc = self._mean_fn(index_points)
            # Broadcast the mean function result so that constant mean
            # functions are supported (constant over all tasks, or a constant
            # per task).
            loc = ps.broadcast_to(
                loc, ps.concat([batch_shape, event_shape], axis=0))
            loc = _vec(loc)
            return mvn_linear_operator.MultivariateNormalLinearOperator(
                loc=loc,
                scale=covariance_cholesky,
                validate_args=self._validate_args,
                allow_nan_stats=self._allow_nan_stats,
                name='marginal_distribution')
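For a separable multi-task kernel, the flattened [N * E] covariance is a Kronecker product, so its Cholesky factor can be assembled from the per-factor Choleskys (chol(A) kron chol(B) is a valid lower-triangular factor of A kron B). A minimal standalone sketch of that structure; the kernel matrices below are made-up inputs, not TFP internals:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

task_kernel_matrix = tf.constant([[1.0, 0.5], [0.5, 1.0]])   # [N, N], N = 2
index_kernel_matrix = tf.eye(3) + 0.1 * tf.ones([3, 3])      # [E, E], E = 3

# Assemble chol(A) kron chol(B) lazily, so the dense [N * E, N * E] factor
# is never materialized.
kron_chol = tf.linalg.LinearOperatorKronecker([
    tf.linalg.LinearOperatorLowerTriangular(
        tf.linalg.cholesky(task_kernel_matrix)),
    tf.linalg.LinearOperatorLowerTriangular(
        tf.linalg.cholesky(index_kernel_matrix)),
])

flat_marginal = tfd.MultivariateNormalLinearOperator(
    loc=tf.zeros([6]), scale=kron_chol)
print(flat_marginal.sample().shape)  # (6,)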
Example #2
    def _get_flattened_marginal_distribution(self, index_points=None):
        # This returns an MVN of event size [N * E], where N is the number of
        # tasks and E is the number of index points.
        with self._name_and_control_scope(
                'get_flattened_marginal_distribution'):
            index_points = self._get_index_points(index_points)
            scale = _compute_flattened_scale(
                kernel=self.kernel,
                index_points=index_points,
                cholesky_fn=self._cholesky_fn,
                observation_noise_variance=self.observation_noise_variance)

            batch_shape = self._batch_shape_tensor(index_points=index_points)
            event_shape = self._event_shape_tensor(index_points=index_points)

            loc = self._mean_fn(index_points)
            # Broadcast the mean function result so that constant mean
            # functions are supported (constant over all tasks, or a constant
            # per task).
            loc = ps.broadcast_to(
                loc, ps.concat([batch_shape, event_shape], axis=0))
            loc = _vec(loc)
            return mvn_linear_operator.MultivariateNormalLinearOperator(
                loc=loc,
                scale=scale,
                validate_args=self._validate_args,
                allow_nan_stats=self._allow_nan_stats,
                name='marginal_distribution')
Example #3
 def _as_multivariate_normal(self, loc=None):
     # Rebuild the Multivariate Normal Distribution on every call because the
     # underlying tensor shapes might have changed.
     loc = tf.convert_to_tensor(self.loc if loc is None else loc)
     return mvn_linear_operator.MultivariateNormalLinearOperator(
         loc=_vec(loc),
         scale=tf.linalg.LinearOperatorKronecker(
             [self.scale_row, self.scale_column]),
         validate_args=self.validate_args)
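This is the matrix normal distribution expressed as an MVN over vec(X). A hedged sketch with concrete shapes; the operator values are arbitrary, and the exact vec ordering used internally by TFP may differ from a plain row-major reshape:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

scale_row = tf.linalg.LinearOperatorLowerTriangular([[1.0, 0.0], [0.3, 0.8]])
scale_column = tf.linalg.LinearOperatorDiag([0.5, 1.0, 2.0])

# vec of a 2 x 3 matrix variate: event size 2 * 3 = 6.
mvn = tfd.MultivariateNormalLinearOperator(
    loc=tf.zeros([6]),
    scale=tf.linalg.LinearOperatorKronecker([scale_row, scale_column]))

vec_sample = mvn.sample()
matrix_sample = tf.reshape(vec_sample, [2, 3])  # undo the vec by reshaping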
Example #4
 def marginal_fn(loc,
                 covariance,
                 validate_args=False,
                 allow_nan_stats=False,
                 name=name):
     # `name` and `cholesky_like` are free variables bound by the enclosing
     # factory function that builds this `marginal_fn`.
     with tf.name_scope(name) as name:
         scale = tf.linalg.LinearOperatorLowerTriangular(
             cholesky_like(covariance), is_non_singular=True)
         return mvn_linear_operator.MultivariateNormalLinearOperator(
             loc=loc,
             scale=scale,
             validate_args=validate_args,
             allow_nan_stats=allow_nan_stats)
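This signature (loc, covariance, validate_args, allow_nan_stats, name) is what a custom `marginal_fn` must accept. A hedged usage sketch with a self-contained variant, assuming a TFP version whose `tfd.GaussianProcess` exposes a `marginal_fn` argument; `my_marginal_fn` and the index points are illustrative:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
psd_kernels = tfp.math.psd_kernels

def my_marginal_fn(loc, covariance, validate_args=False,
                   allow_nan_stats=False, name='marginal_distribution'):
    # Plain jittered Cholesky; stands in for the closure-based versions above.
    jittered = covariance + 1e-6 * tf.eye(
        tf.shape(covariance)[-1], dtype=covariance.dtype)
    scale = tf.linalg.LinearOperatorLowerTriangular(
        tf.linalg.cholesky(jittered), is_non_singular=True)
    return tfd.MultivariateNormalLinearOperator(
        loc=loc, scale=scale, validate_args=validate_args,
        allow_nan_stats=allow_nan_stats, name=name)

gp = tfd.GaussianProcess(
    kernel=psd_kernels.ExponentiatedQuadratic(),
    index_points=tf.linspace(-1., 1., 10)[:, tf.newaxis],
    marginal_fn=my_marginal_fn)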
Example #5
 def marginal_fn(loc,
                 covariance,
                 validate_args=False,
                 allow_nan_stats=False,
                 name='marginal_distribution'):
     # `jitter` is a free variable from the enclosing scope (the factory
     # that builds this `marginal_fn`); `_add_diagonal_shift` is sketched
     # after this example.
     scale = tf.linalg.LinearOperatorLowerTriangular(
         tf.linalg.cholesky(_add_diagonal_shift(covariance, jitter)),
         is_non_singular=True,
         name='GaussianProcessScaleLinearOperator')
     return mvn_linear_operator.MultivariateNormalLinearOperator(
         loc=loc,
         scale=scale,
         validate_args=validate_args,
         allow_nan_stats=allow_nan_stats,
         name=name)
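The `_add_diagonal_shift` helper used above is a one-liner; a sketch of the usual pattern:

import tensorflow as tf

def _add_diagonal_shift(matrix, shift):
    # Add `shift` to the diagonal of a (batch of) matrix(es); cheaper and
    # more broadcast-friendly than adding `shift * tf.eye(n)`.
    return tf.linalg.set_diag(matrix, tf.linalg.diag_part(matrix) + shift)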
Example #6
    def get_marginal_distribution(self, index_points=None):
        """Compute the marginal of this GP over function values at `index_points`.

    Args:
      index_points: `float` `Tensor` representing finite (batch of) vector(s) of
        points in the index set over which the GP is defined. Shape has the form
        `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature
        dimensions and must equal `kernel.feature_ndims` and `e` is the number
        (size) of index points in each batch. Ultimately this distribution
        corresponds to an `e`-dimensional multivariate normal. The batch shape
        must be broadcastable with `kernel.batch_shape` and any batch dims
        yielded by `mean_fn`.

    Returns:
      marginal: a `Normal` or `MultivariateNormalLinearOperator` distribution,
        according to whether `index_points` consists of one or many index
        points, respectively.
    """
        with self._name_and_control_scope('get_marginal_distribution'):
            # TODO(cgs): consider caching the result here, keyed on `index_points`.
            index_points = self._get_index_points(index_points)
            covariance = self._compute_covariance(index_points)
            loc = self._mean_fn(index_points)
            # If we're sure the number of index points is 1, we can just construct a
            # scalar Normal. This has computational benefits and supports things like
            # CDF that aren't otherwise straightforward to provide.
            if self._is_univariate_marginal(index_points):
                scale = tf.sqrt(covariance)
                # `loc` has a trailing 1 in the shape; squeeze it.
                loc = tf.squeeze(loc, axis=-1)
                return normal.Normal(loc=loc,
                                     scale=scale,
                                     validate_args=self._validate_args,
                                     allow_nan_stats=self._allow_nan_stats,
                                     name='marginal_distribution')
            else:
                scale = tf.linalg.LinearOperatorLowerTriangular(
                    tf.linalg.cholesky(
                        _add_diagonal_shift(covariance, self.jitter)),
                    is_non_singular=True,
                    name='GaussianProcessScaleLinearOperator')
                return mvn_linear_operator.MultivariateNormalLinearOperator(
                    loc=loc,
                    scale=scale,
                    validate_args=self._validate_args,
                    allow_nan_stats=self._allow_nan_stats,
                    name='marginal_distribution')
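A hedged usage sketch of `get_marginal_distribution` on a TFP `GaussianProcess` (index point values are arbitrary), showing the univariate fast path versus the MVN path, in TFP versions matching the snippet above:

import tensorflow as tf
import tensorflow_probability as tfp

gp = tfp.distributions.GaussianProcess(
    kernel=tfp.math.psd_kernels.ExponentiatedQuadratic(),
    index_points=tf.constant([[0.0], [0.5], [1.0]]))

marginal = gp.get_marginal_distribution()
print(marginal.event_shape)  # [3]: three index points give an MVN

single = gp.get_marginal_distribution(index_points=tf.constant([[0.0]]))
print(type(single).__name__)  # Normal: one index point takes the scalar path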
Example #7
  def _sample_n(self, n, seed=None):
    # Like with the univariate Student's t, sampling can be implemented as a
    # ratio of samples from a multivariate Gaussian with the appropriate
    # covariance matrix and a sample from the chi-squared distribution.
    seed = seed_stream.SeedStream(seed, salt="multivariate t")

    loc = tf.broadcast_to(self.loc, self._sample_shape())
    mvn = mvn_linear_operator.MultivariateNormalLinearOperator(
        loc=tf.zeros_like(loc), scale=self.scale)
    normal_samp = mvn.sample(n, seed=seed())

    df = tf.broadcast_to(self.df, self.batch_shape_tensor())
    chi2 = chi2_lib.Chi2(df=df)
    chi2_samp = chi2.sample(n, seed=seed())

    return (self._loc +
            normal_samp * tf.math.rsqrt(chi2_samp / self._df)[..., tf.newaxis])
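A standalone sketch of the ratio construction used in `_sample_n`: if Z ~ N(0, Sigma) and V ~ Chi2(df) are independent, then loc + Z / sqrt(V / df) is multivariate Student's t. Values and seeds below are arbitrary:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

df = 5.0
loc = tf.zeros([2])
scale = tf.linalg.LinearOperatorLowerTriangular([[1.0, 0.0], [0.5, 0.8]])

z = tfd.MultivariateNormalLinearOperator(
    loc=tf.zeros([2]), scale=scale).sample(1000, seed=1)
v = tfd.Chi2(df=df).sample(1000, seed=2)

# Scale each Gaussian draw by 1 / sqrt(v / df), broadcasting over the event dim.
t_samples = loc + z * tf.math.rsqrt(v / df)[..., tf.newaxis]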
Example #8
 def _get_flattened_marginal_distribution(self, index_points=None):
   # This returns an MVN of event size [N * E], where N is the number of tasks
   # and E is the number of index points.
   with self._name_and_control_scope('get_flattened_marginal_distribution'):
     index_points = self._get_index_points(index_points)
     covariance = self._compute_flattened_covariance(index_points)
     loc = self._conditional_mean_fn(index_points)
     scale = tf.linalg.LinearOperatorLowerTriangular(
         self._cholesky_fn(covariance),
         is_non_singular=True,
         name='GaussianProcessScaleLinearOperator')
     return mvn_linear_operator.MultivariateNormalLinearOperator(
         loc=loc,
         scale=scale,
         validate_args=self._validate_args,
         allow_nan_stats=self._allow_nan_stats,
         name='marginal_distribution')
Example #9
 def eigh_marginal_fn(loc,
                      covariance,
                      validate_args=False,
                      allow_nan_stats=False,
                      name=name):
     """Compute EigH-based square root and return a MVN."""
     with tf.name_scope(name) as name:
         values, vectors = tf.linalg.eigh(covariance)
         safe_root = tf.math.sqrt(tf.where(values < tol, tol, values))
         scale = tf.linalg.LinearOperatorFullMatrix(
             tf.einsum('...ij,...j->...ij', vectors, safe_root),
             is_square=True,
             is_positive_definite=True,
             is_non_singular=True,
             name='GaussianProcessEigHScaleLinearOperator')
         return mvn_linear_operator.MultivariateNormalLinearOperator(
             loc=loc,
             scale=scale,
             validate_args=validate_args,
             allow_nan_stats=allow_nan_stats,
             name=name)
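A quick check of the EigH square root above (`tol` here stands in for the factory-supplied tolerance): since sqrt = V diag(sqrt(lambda)), sqrt @ sqrt^T = V diag(lambda) V^T reconstructs the covariance.

import tensorflow as tf

tol = 1e-6  # stand-in for the enclosing factory's tolerance
cov = tf.constant([[2.0, 0.8], [0.8, 1.0]])

values, vectors = tf.linalg.eigh(cov)
safe_root = tf.math.sqrt(tf.where(values < tol, tol, values))
sqrt = tf.einsum('...ij,...j->...ij', vectors, safe_root)

# Maximum reconstruction error is ~0 (up to the eigenvalue floor `tol`).
print(tf.reduce_max(tf.abs(tf.matmul(sqrt, sqrt, adjoint_b=True) - cov)))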
Example #10
    def __init__(self,
                 kernel,
                 index_points,
                 inducing_index_points,
                 variational_inducing_observations_loc,
                 variational_inducing_observations_scale,
                 mean_fn=None,
                 observation_noise_variance=0.,
                 predictive_noise_variance=0.,
                 jitter=1e-6,
                 validate_args=False,
                 allow_nan_stats=False,
                 name='VariationalGaussianProcess'):
        """Instantiate a VariationalGaussianProcess Distribution.

    Args:
      kernel: `PositiveSemidefiniteKernel`-like instance representing the
        GP's covariance function.
      index_points: `float` `Tensor` representing finite (batch of) vector(s) of
        points in the index set over which the VGP is defined. Shape has the
        form `[b1, ..., bB, e1, f1, ..., fF]` where `F` is the number of feature
        dimensions and must equal `kernel.feature_ndims` and `e1` is the number
        (size) of index points in each batch (we denote it `e1` to distinguish
        it from the number of inducing index points, denoted `e2` below).
        Ultimately the VariationalGaussianProcess distribution corresponds to an
        `e1`-dimensional multivariate normal. The batch shape must be
        broadcastable with `kernel.batch_shape`, the batch shape of
        `inducing_index_points`, and any batch dims yielded by `mean_fn`.
      inducing_index_points: `float` `Tensor` of locations of inducing points in
        the index set. Shape has the form `[b1, ..., bB, e2, f1, ..., fF]`, just
        like `index_points`. The batch shape components needn't be identical to
        those of `index_points`, but must be broadcast compatible with them.
      variational_inducing_observations_loc: `float` `Tensor`; the mean of the
        (full-rank Gaussian) variational posterior over function values at the
        inducing points, conditional on observed data. Shape has the form `[b1,
        ..., bB, e2]`, where `b1, ..., bB` is broadcast compatible with other
        parameters' batch shapes, and `e2` is the number of inducing points.
      variational_inducing_observations_scale: `float` `Tensor`; the scale
        matrix of the (full-rank Gaussian) variational posterior over function
        values at the inducing points, conditional on observed data. Shape has
        the form `[b1, ..., bB, e2, e2]`, where `b1, ..., bB` is broadcast
        compatible with other parameters and `e2` is the number of inducing
        points.
      mean_fn: Python `callable` that acts on index points to produce a (batch
        of) vector(s) of mean values at those index points. Takes a `Tensor` of
        shape `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose shape is
        (broadcastable with) `[b1, ..., bB]`. Default value: `None` implies
        constant zero function.
      observation_noise_variance: `float` `Tensor` representing the variance
        of the noise in the Normal likelihood distribution of the model. May be
        batched, in which case the batch shape must be broadcastable with the
        shapes of all other batched parameters (`kernel.batch_shape`,
        `index_points`, etc.).
        Default value: `0.`
      predictive_noise_variance: `float` `Tensor` representing additional
        variance in the posterior predictive model. If `None`, we simply re-use
        `observation_noise_variance` for the posterior predictive noise. If set
        explicitly, however, we use the given value. This allows us, for
        example, to omit predictive noise variance (by setting this to zero) to
        obtain noiseless posterior predictions of function values, conditioned
        on noisy observations.
      jitter: `float` scalar `Tensor` added to the diagonal of the covariance
        matrix to ensure positive definiteness of the covariance matrix.
        Default value: `1e-6`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `False`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
        Default value: `False`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "VariationalGaussianProcess".

    Raises:
      ValueError: if `mean_fn` is not `None` and is not callable.
    """
        parameters = dict(locals())
        with tf.name_scope(name or 'VariationalGaussianProcess') as name:
            dtype = dtype_util.common_dtype([
                kernel, index_points, inducing_index_points,
                variational_inducing_observations_loc,
                variational_inducing_observations_scale,
                observation_noise_variance, predictive_noise_variance, jitter
            ], tf.float32)

            index_points = tf.convert_to_tensor(index_points,
                                                dtype=dtype,
                                                name='index_points')
            inducing_index_points = tf.convert_to_tensor(
                inducing_index_points,
                dtype=dtype,
                name='inducing_index_points')
            variational_inducing_observations_loc = tf.convert_to_tensor(
                variational_inducing_observations_loc,
                dtype=dtype,
                name='variational_inducing_observations_loc')
            variational_inducing_observations_scale = tf.convert_to_tensor(
                variational_inducing_observations_scale,
                dtype=dtype,
                name='variational_inducing_observations_scale')
            observation_noise_variance = tf.convert_to_tensor(
                observation_noise_variance,
                dtype=dtype,
                name='observation_noise_variance')
            if predictive_noise_variance is None:
                predictive_noise_variance = observation_noise_variance
            else:
                predictive_noise_variance = tf.convert_to_tensor(
                    predictive_noise_variance,
                    dtype=dtype,
                    name='predictive_noise_variance')
            jitter = tf.convert_to_tensor(jitter, dtype=dtype, name='jitter')

            self._kernel = kernel
            self._index_points = index_points
            self._inducing_index_points = inducing_index_points
            self._variational_inducing_observations_posterior = (
                mvn_linear_operator.MultivariateNormalLinearOperator(
                    loc=variational_inducing_observations_loc,
                    scale=tf.linalg.LinearOperatorFullMatrix(
                        variational_inducing_observations_scale),
                    name='variational_inducing_observations_posterior'))

            # Default to a constant zero function, borrowing the dtype from
            # index_points to ensure consistency.
            if mean_fn is None:
                mean_fn = lambda x: tf.zeros([1], dtype=dtype)
            else:
                if not callable(mean_fn):
                    raise ValueError('`mean_fn` must be a Python callable')
            self._mean_fn = mean_fn
            self._observation_noise_variance = observation_noise_variance
            self._predictive_noise_variance = predictive_noise_variance
            self._jitter = jitter

            with tf.name_scope('init'):
                # We let t and z denote predictive and inducing index points, resp.
                kzz = _add_diagonal_shift(
                    kernel.matrix(inducing_index_points,
                                  inducing_index_points), jitter)

                self._chol_kzz = tf.linalg.cholesky(kzz)
                self._kzz_inv_varloc = _solve_cholesky_factored_system_vec(
                    self._chol_kzz, (variational_inducing_observations_loc -
                                     mean_fn(inducing_index_points)),
                    name='kzz_inv_varloc')

                loc, scale = self._compute_posterior_predictive_params()

                super(VariationalGaussianProcess,
                      self).__init__(loc=loc,
                                     scale=scale,
                                     validate_args=validate_args,
                                     allow_nan_stats=allow_nan_stats,
                                     name=name)
                self._parameters = parameters
                self._graph_parents = [
                    index_points, inducing_index_points,
                    variational_inducing_observations_loc,
                    variational_inducing_observations_scale,
                    observation_noise_variance, jitter
                ]