def __init__(self,
               kernel,
               index_points=None,
               observation_index_points=None,
               observations=None,
               observation_noise_variance=0.,
               predictive_noise_variance=None,
               mean_fn=None,
               jitter=1e-6,
               validate_args=False,
               allow_nan_stats=False,
               name='GaussianProcessRegressionModel'):
    """Construct a GaussianProcessRegressionModel instance.

    Args:
      kernel: `PositiveSemidefiniteKernel`-like instance representing the
        GP's covariance function.
      index_points: `float` `Tensor` representing finite collection, or batch of
        collections, of points in the index set over which the GP is defined.
        Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the
        number of feature dimensions and must equal `kernel.feature_ndims` and
        `e` is the number (size) of index points in each batch. Ultimately this
        distribution corresponds to an `e`-dimensional multivariate normal. The
        batch shape must be broadcastable with `kernel.batch_shape` and any
        batch dims yielded by `mean_fn`.
      observation_index_points: `float` `Tensor` representing finite collection,
        or batch of collections, of points in the index set for which some data
        has been observed. Shape has the form `[b1, ..., bB, e, f1, ..., fF]`
        where `F` is the number of feature dimensions and must equal
        `kernel.feature_ndims`, and `e` is the number (size) of index points in
        each batch. `[b1, ..., bB, e]` must be broadcastable with the shape of
        `observations`, and `[b1, ..., bB]` must be broadcastable with the
        shapes of all other batched parameters (`kernel.batch_shape`,
        `index_points`, etc). The default value is `None`, which corresponds to
        the empty set of observations, and simply results in the prior
        predictive model (a GP with noise of variance
        `predictive_noise_variance`).
      observations: `float` `Tensor` representing collection, or batch of
        collections, of observations corresponding to
        `observation_index_points`. Shape has the form `[b1, ..., bB, e]`, which
        must be broadcastable with the batch and example shapes of
        `observation_index_points`. The batch shape `[b1, ..., bB]` must be
        broadcastable with the shapes of all other batched parameters
        (`kernel.batch_shape`, `index_points`, etc.). The default value is
        `None`, which corresponds to the empty set of observations, and simply
        results in the prior predictive model (a GP with noise of variance
        `predictive_noise_variance`).
      observation_noise_variance: `float` `Tensor` representing the variance
        of the noise in the Normal likelihood distribution of the model. May be
        batched, in which case the batch shape must be broadcastable with the
        shapes of all other batched parameters (`kernel.batch_shape`,
        `index_points`, etc.).
        Default value: `0.`
      predictive_noise_variance: `float` `Tensor` representing the variance in
        the posterior predictive model. If `None`, we simply re-use
        `observation_noise_variance` for the posterior predictive noise. If set
        explicitly, however, we use this value. This allows us, for example, to
        omit predictive noise variance (by setting this to zero) to obtain
        noiseless posterior predictions of function values, conditioned on noisy
        observations.
        Default value: `None`.
      mean_fn: Python `callable` that acts on `index_points` to produce a
        collection, or batch of collections, of mean values at `index_points`.
        Takes a `Tensor` of shape `[b1, ..., bB, f1, ..., fF]` and returns a
        `Tensor` whose shape is broadcastable with `[b1, ..., bB]`.
        Default value: `None` implies the constant zero function.
      jitter: `float` scalar `Tensor` added to the diagonal of the covariance
        matrix to ensure positive definiteness of the covariance matrix.
        Default value: `1e-6`.
      validate_args: Python `bool`. When `True` distribution parameters are
        checked for validity despite possibly degrading runtime performance.
        When `False` invalid inputs may silently render incorrect outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
        mode, variance) use the value `NaN` to indicate the result is
        undefined. When `False`, an exception is raised if one or more of the
        statistic's batch members are undefined.
        Default value: `False`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: 'GaussianProcessRegressionModel'.

    Raises:
      ValueError: if either
        - only one of `observations` and `observation_index_points` is given, or
        - `mean_fn` is not `None` and not callable.
    """
    # Snapshot the constructor arguments first, before any other local name is
    # bound, so that `parameters` holds exactly what the caller passed in.
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([
          index_points, observation_index_points, observations,
          observation_noise_variance, predictive_noise_variance, jitter
      ], tf.float32)
      index_points = tensor_util.convert_nonref_to_tensor(
          index_points, dtype=dtype, name='index_points')
      observation_index_points = tensor_util.convert_nonref_to_tensor(
          observation_index_points, dtype=dtype,
          name='observation_index_points')
      observations = tensor_util.convert_nonref_to_tensor(
          observations, dtype=dtype,
          name='observations')
      observation_noise_variance = tensor_util.convert_nonref_to_tensor(
          observation_noise_variance,
          dtype=dtype,
          name='observation_noise_variance')
      # Use the parameter's own name for the op; it was previously labelled
      # 'observation_noise_variance', duplicating the conversion just above.
      predictive_noise_variance = tensor_util.convert_nonref_to_tensor(
          predictive_noise_variance,
          dtype=dtype,
          name='predictive_noise_variance')
      # An omitted predictive noise falls back to the observation noise.
      if predictive_noise_variance is None:
        predictive_noise_variance = observation_noise_variance
      jitter = tensor_util.convert_nonref_to_tensor(
          jitter, dtype=dtype, name='jitter')
      # Observations and their index points only make sense as a pair.
      if (observation_index_points is None) != (observations is None):
        raise ValueError(
            '`observations` and `observation_index_points` must both be given '
            'or None. Got {} and {}, respectively.'.format(
                observations, observation_index_points))
      # Default to a constant zero function in the common dtype computed above.
      if mean_fn is None:
        mean_fn = lambda x: tf.zeros([1], dtype=dtype)
      elif not callable(mean_fn):
        raise ValueError('`mean_fn` must be a Python callable')

      self._name = name
      self._observation_index_points = observation_index_points
      self._observations = observations
      self._observation_noise_variance = observation_noise_variance
      self._predictive_noise_variance = predictive_noise_variance
      self._jitter = jitter
      self._validate_args = validate_args

      with tf.name_scope('init'):
        # The diagonal shift is wrapped in a DeferredTensor so that
        # `jitter + observation_noise_variance` is derived from the noise
        # variance object rather than fixed as a value here.
        conditional_kernel = tfpk.SchurComplement(
            base_kernel=kernel,
            fixed_inputs=observation_index_points,
            diag_shift=tfp_util.DeferredTensor(
                observation_noise_variance, lambda x: jitter + x))
        # Special logic for mean_fn only; SchurComplement already handles the
        # case of empty observations (ie, falls back to base_kernel).
        if _is_empty_observation_data(
            feature_ndims=kernel.feature_ndims,
            observation_index_points=observation_index_points,
            observations=observations):
          conditional_mean_fn = mean_fn
        else:
          _validate_observation_data(
              kernel=kernel,
              observation_index_points=observation_index_points,
              observations=observations)

          def conditional_mean_fn(x):
            """Conditional mean."""
            # Read the observation data from `self` at call time rather than
            # closing over the tensors converted above.
            observations = tf.convert_to_tensor(self._observations)
            observation_index_points = tf.convert_to_tensor(
                self._observation_index_points)
            k_x_obs_linop = tf.linalg.LinearOperatorFullMatrix(
                kernel.matrix(x, observation_index_points))
            chol_linop = tf.linalg.LinearOperatorLowerTriangular(
                conditional_kernel.divisor_matrix_cholesky(
                    fixed_inputs=observation_index_points))

            diff = observations - mean_fn(observation_index_points)
            # Two triangular solves (plain, then adjoint) against the Cholesky
            # factor apply the inverse of the divisor matrix to `diff`.
            return mean_fn(x) + k_x_obs_linop.matvec(
                chol_linop.solvevec(chol_linop.solvevec(diff), adjoint=True))

        super(GaussianProcessRegressionModel, self).__init__(
            kernel=conditional_kernel,
            mean_fn=conditional_mean_fn,
            index_points=index_points,
            jitter=jitter,
            # What the GP super class calls "observation noise variance" we call
            # here the "predictive noise variance". We use the observation noise
            # variance for the fit/solve process above, and predictive for
            # downstream computations like sampling.
            observation_noise_variance=predictive_noise_variance,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats, name=name)
        self._parameters = parameters
Exemple #2
0
    def precompute_regression_model(
            df,
            kernel,
            observation_index_points,
            observations,
            index_points=None,
            observation_noise_variance=0.,
            predictive_noise_variance=None,
            mean_fn=None,
            cholesky_fn=None,
            validate_args=False,
            allow_nan_stats=False,
            name='PrecomputedStudentTProcessRegressionModel'):
        """Returns a StudentTProcessRegressionModel with precomputed quantities.

        This differs from the constructor by precomputing quantities associated
        with observations in a non-tape safe way. `index_points` is the only
        parameter that is allowed to vary (i.e. is a `Variable` / changes after
        initialization).

        Specifically:

        * We make `observation_index_points` and `observations` mandatory
          parameters.
        * We precompute
          `kernel(observation_index_points, observation_index_points)` along
          with any other associated quantities relating to `df`, `kernel`,
          `observations` and `observation_index_points`.

        A typical usecase would be optimizing kernel hyperparameters for a
        `StudentTProcess`, and computing the posterior predictive with respect
        to those optimized hyperparameters and observation / index-points
        pairs.

        WARNING: This method assumes `index_points` is the only varying
        parameter (i.e. is a `Variable` / changes after initialization) and
        hence is not tape-safe.

        Args:
          df: Positive Floating-point `Tensor` representing the degrees of
            freedom. Must be greater than 2.
          kernel: `PositiveSemidefiniteKernel`-like instance representing the
            StP's covariance function.
          observation_index_points: `float` `Tensor` representing finite
            collection, or batch of collections, of points in the index set for
            which some data has been observed. Shape has the form
            `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature
            dimensions and must equal `kernel.feature_ndims`, and `e` is the
            number (size) of index points in each batch. `[b1, ..., bB, e]`
            must be broadcastable with the shape of `observations`, and
            `[b1, ..., bB]` must be broadcastable with the shapes of all other
            batched parameters (`kernel.batch_shape`, `index_points`, etc).
            Required: it is converted to a `Tensor` unconditionally.
          observations: `float` `Tensor` representing collection, or batch of
            collections, of observations corresponding to
            `observation_index_points`. Shape has the form `[b1, ..., bB, e]`,
            which must be broadcastable with the batch and example shapes of
            `observation_index_points`. The batch shape `[b1, ..., bB]` must be
            broadcastable with the shapes of all other batched parameters
            (`kernel.batch_shape`, `index_points`, etc.).
            Required: it is converted to a `Tensor` unconditionally.
          index_points: `float` `Tensor` representing finite collection, or
            batch of collections, of points in the index set over which the StP
            is defined. Shape has the form `[b1, ..., bB, e, f1, ..., fF]`
            where `F` is the number of feature dimensions and must equal
            `kernel.feature_ndims` and `e` is the number (size) of index points
            in each batch. The batch shape must be broadcastable with
            `kernel.batch_shape` and any batch dims yielded by `mean_fn`.
            Default value: `None`.
          observation_noise_variance: `float` `Tensor` representing the
            variance of the noise in the Normal likelihood distribution of the
            model. May be batched, in which case the batch shape must be
            broadcastable with the shapes of all other batched parameters
            (`kernel.batch_shape`, `index_points`, etc.).
            Default value: `0.`
          predictive_noise_variance: `float` `Tensor` representing the variance
            in the posterior predictive model. If `None`, we simply re-use
            `observation_noise_variance` for the posterior predictive noise. If
            set explicitly, however, we use this value. This allows us, for
            example, to omit predictive noise variance (by setting this to
            zero) to obtain noiseless posterior predictions of function values,
            conditioned on noisy observations.
            Default value: `None`.
          mean_fn: Python `callable` that acts on `index_points` to produce a
            collection, or batch of collections, of mean values at
            `index_points`. Takes a `Tensor` of shape
            `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose shape is
            broadcastable with `[b1, ..., bB]`.
            Default value: `None` implies the constant zero function.
          cholesky_fn: Callable which takes a single (batch) matrix argument
            and returns a Cholesky-like lower triangular factor.
            Default value: `None`, in which case the result of
            `cholesky_util.make_cholesky_with_jitter_fn()` is used.
          validate_args: Python `bool`. When `True` distribution parameters are
            checked for validity despite possibly degrading runtime
            performance. When `False` invalid inputs may silently render
            incorrect outputs.
            Default value: `False`.
          allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
            mode, variance) use the value `NaN` to indicate the result is
            undefined. When `False`, an exception is raised if one or more of
            the statistic's batch members are undefined.
            Default value: `False`.
          name: Python `str` name prefixed to Ops created by this class.
            Default value: 'PrecomputedStudentTProcessRegressionModel'.

        Returns:
          An instance of `StudentTProcessRegressionModel` with precomputed
          quantities associated with observations.

        Raises:
          ValueError: if `mean_fn` is not `None` and not callable.
        """
        # Reject an unusable `mean_fn` up front, before any kernel matrix or
        # Cholesky factor is computed.
        if mean_fn is not None and not callable(mean_fn):
            raise ValueError('`mean_fn` must be a Python callable')

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([
                df,
                index_points,
                observation_index_points,
                observations,
                observation_noise_variance,
                predictive_noise_variance,
            ], tf.float32)

            # Convert to tensor arguments that are expected to not be Variables / not
            # going to change.
            df = tf.convert_to_tensor(df, dtype=dtype)
            observation_index_points = tf.convert_to_tensor(
                observation_index_points, dtype=dtype)
            observation_noise_variance = tf.convert_to_tensor(
                observation_noise_variance, dtype=dtype)
            observations = tf.convert_to_tensor(observations, dtype=dtype)

            if mean_fn is None:
                # Constant zero mean in the common dtype computed above.
                mean_fn = lambda x: tf.zeros([1], dtype=dtype)

            # Kernel matrix over the observed points, broadcast against the
            # batch shape of the noise variance before the noise is added to
            # its diagonal.
            observation_covariance = kernel.matrix(observation_index_points,
                                                   observation_index_points)

            broadcast_shape = distribution_util.get_broadcast_shape(
                observation_covariance,
                observation_noise_variance[..., tf.newaxis, tf.newaxis])

            observation_covariance = tf.broadcast_to(observation_covariance,
                                                     broadcast_shape)

            observation_covariance = tf.linalg.set_diag(
                observation_covariance,
                tf.linalg.diag_part(observation_covariance) +
                observation_noise_variance[..., tf.newaxis])
            if cholesky_fn is None:
                cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()

            observation_cholesky = cholesky_fn(observation_covariance)
            observation_cholesky_operator = tf.linalg.LinearOperatorLowerTriangular(
                observation_cholesky)

            # NOTE(review): `cholesky_fn` is not passed to this SchurComplement,
            # so the kernel presumably factorizes with its own default --
            # confirm that this is intended.
            conditional_kernel = DampedSchurComplement(
                df=df,
                schur_complement=tfpk.SchurComplement(
                    base_kernel=kernel,
                    fixed_inputs=observation_index_points,
                    diag_shift=observation_noise_variance),
                fixed_inputs_observations=observations,
                validate_args=validate_args)

            # Solve against the factorized observation covariance once, here,
            # so that the conditional mean below only needs a matvec per call.
            diff = observations - mean_fn(observation_index_points)
            solve_on_observation = observation_cholesky_operator.solvevec(
                observation_cholesky_operator.solvevec(diff), adjoint=True)

            def conditional_mean_fn(x):
                """Conditional mean at `x`, reusing the precomputed solve."""
                k_x_obs = kernel.matrix(x, observation_index_points)
                return mean_fn(x) + tf.linalg.matvec(k_x_obs,
                                                     solve_on_observation)

            stprm = StudentTProcessRegressionModel(
                df=df,
                kernel=kernel,
                observation_index_points=observation_index_points,
                observations=observations,
                index_points=index_points,
                observation_noise_variance=observation_noise_variance,
                predictive_noise_variance=predictive_noise_variance,
                cholesky_fn=cholesky_fn,
                _conditional_kernel=conditional_kernel,
                _conditional_mean_fn=conditional_mean_fn,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                name=name)

        return stprm
Exemple #3
0
def schur_complements(draw,
                      batch_shape=None,
                      event_dim=None,
                      feature_dim=None,
                      feature_ndims=None,
                      enable_vars=None,
                      depth=None):
    """Hypothesis strategy that yields `SchurComplement` kernels.

    The kernel being wrapped is itself drawn from the `kernels` strategy, one
    nesting level shallower than `depth`.

    Args:
      draw: Hypothesis strategy sampler supplied by `@hps.composite`.
      batch_shape: Optional `TensorShape` giving the batch shape of the
        resulting kernel. Drawn from `tfp_hps.shapes()` when omitted.
      event_dim: Optional Python int giving the size of each of the kernel's
        parameters' event dimensions, shared across all parameters. Drawn from
        the integers 2..6 when omitted.
      feature_dim: Optional Python int giving the size of each feature
        dimension. Drawn from the integers 2..6 when omitted.
      feature_ndims: Optional Python int giving the number of feature
        dimensions of the inputs. Drawn from the integers 2..6 when omitted.
      enable_vars: TODO(bjp): Make this `True` all the time and put variable
        initialization in slicing_test.  When falsy, `fixed_inputs` and
        `diag_shift` are passed to the kernel as drawn, never wrapped in a
        `tf.Variable`.
      depth: Python `int` giving maximum nesting depth of compound kernel.
        Drawn from `depths()` when omitted.

    Returns:
      A `(kernel, variable_names)` pair: the `SchurComplement` kernel and the
      list of variable names reported by the base-kernel draw, extended with
      the name of every parameter that was turned into a `tf.Variable` here.
    """
    # Resolve unspecified arguments; the order of the draws is significant.
    depth = draw(depths()) if depth is None else depth
    batch_shape = draw(tfp_hps.shapes()) if batch_shape is None else batch_shape
    if event_dim is None:
        event_dim = draw(hps.integers(min_value=2, max_value=6))
    if feature_dim is None:
        feature_dim = draw(hps.integers(min_value=2, max_value=6))
    if feature_ndims is None:
        feature_ndims = draw(hps.integers(min_value=2, max_value=6))

    # Shape arguments common to the base kernel and its fixed inputs.
    shape_kwargs = dict(batch_shape=batch_shape,
                        feature_dim=feature_dim,
                        feature_ndims=feature_ndims)

    inner_kernel, variable_names = draw(
        kernels(event_dim=event_dim,
                enable_vars=False,
                depth=depth - 1,
                **shape_kwargs))

    # SchurComplement requires the inputs to have one example dimension.
    anchor_points = draw(kernel_input(example_ndims=1, **shape_kwargs))

    # Shifts are drawn from [1, 100] so the divisor matrix stays PD.
    shift_elements = hps.floats(1, 100, allow_nan=False, allow_infinity=False)
    shift_array = draw(
        hpnp.arrays(dtype=np.float64,
                    shape=tensorshape_util.as_list(batch_shape),
                    elements=shift_elements))
    shift = np.float64(shift_array)

    hp.note('Forming SchurComplement kernel with fixed_inputs: {} '
            'and diag_shift: {}'.format(anchor_points, shift))

    kernel_args = dict(fixed_inputs=anchor_points, diag_shift=shift)

    for arg_name in ('fixed_inputs', 'diag_shift'):
        # No boolean is drawn at all unless variables are enabled.
        if not enable_vars or not draw(hps.booleans()):
            continue
        variable_names.append(arg_name)
        wrapped = tf.Variable(kernel_args[arg_name], name=arg_name)
        # A second coin flip decides whether the variable is additionally
        # wrapped by `defer_and_count_usage`.
        if draw(hps.booleans()):
            wrapped = tfp_hps.defer_and_count_usage(wrapped)
        kernel_args[arg_name] = wrapped

    def _first_of_retrying_cholesky(matrix):
        # Only the first element of `retrying_cholesky`'s result is used.
        return marginal_fns.retrying_cholesky(matrix)[0]

    schur_kernel = tfpk.SchurComplement(
        base_kernel=inner_kernel,
        fixed_inputs=kernel_args['fixed_inputs'],
        diag_shift=kernel_args['diag_shift'],
        cholesky_fn=_first_of_retrying_cholesky,
        validate_args=True)
    return schur_kernel, variable_names
Exemple #4
0
    def __init__(self,
                 df,
                 kernel,
                 index_points=None,
                 observation_index_points=None,
                 observations=None,
                 observation_noise_variance=0.,
                 predictive_noise_variance=None,
                 mean_fn=None,
                 cholesky_fn=None,
                 marginal_fn=None,
                 validate_args=False,
                 allow_nan_stats=False,
                 name='StudentTProcessRegressionModel',
                 _conditional_kernel=None,
                 _conditional_mean_fn=None):
        """Construct a StudentTProcessRegressionModel instance.

        Args:
          df: Positive Floating-point `Tensor` representing the degrees of
            freedom. Must be greater than 2. When observations are present,
            the value handed to the parent constructor is `df + n`, where `n`
            is the size of the last dimension of `observations`.
          kernel: `PositiveSemidefiniteKernel`-like instance representing the
            StP's covariance function.
          index_points: `float` `Tensor` representing finite collection, or
            batch of collections, of points in the index set over which the
            StP is defined. Shape has the form `[b1, ..., bB, e, f1, ..., fF]`
            where `F` is the number of feature dimensions and must equal
            `kernel.feature_ndims` and `e` is the number (size) of index points
            in each batch. The batch shape must be broadcastable with
            `kernel.batch_shape`.
          observation_index_points: `float` `Tensor` representing finite
            collection, or batch of collections, of points in the index set for
            which some data has been observed. Shape has the form
            `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature
            dimensions and must equal `kernel.feature_ndims`, and `e` is the
            number (size) of index points in each batch. `[b1, ..., bB, e]`
            must be broadcastable with the shape of `observations`, and
            `[b1, ..., bB]` must be broadcastable with the shapes of all other
            batched parameters (`kernel.batch_shape`, `index_points`, etc).
            Default value: `None`.
          observations: `float` `Tensor` representing collection, or batch of
            collections, of observations corresponding to
            `observation_index_points`. Shape has the form `[b1, ..., bB, e]`,
            which must be broadcastable with the batch and example shapes of
            `observation_index_points`. The batch shape `[b1, ..., bB]` must be
            broadcastable with the shapes of all other batched parameters
            (`kernel.batch_shape`, `index_points`, etc.).
            Default value: `None`.
          observation_noise_variance: `float` `Tensor` representing the
            variance of the noise in the Normal likelihood distribution of the
            model. May be batched, in which case the batch shape must be
            broadcastable with the shapes of all other batched parameters
            (`kernel.batch_shape`, `index_points`, etc.).
            Default value: `0.`
          predictive_noise_variance: `float` `Tensor` representing the variance
            in the posterior predictive model. If `None`, we simply re-use
            `observation_noise_variance` for the posterior predictive noise. If
            set explicitly, however, we use this value. This allows us, for
            example, to omit predictive noise variance (by setting this to
            zero) to obtain noiseless posterior predictions of function values,
            conditioned on noisy observations.
            Default value: `None`.
          mean_fn: Python `callable` that acts on `index_points` to produce a
            collection, or batch of collections, of mean values at
            `index_points`. Takes a `Tensor` of shape
            `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose shape is
            broadcastable with `[b1, ..., bB]`.
            Default value: `None` implies the constant zero function.
          cholesky_fn: Callable which takes a single (batch) matrix argument
            and returns a Cholesky-like lower triangular factor.
            Default value: `None`, in which case the result of
            `cholesky_util.make_cholesky_with_jitter_fn()` is used.
          marginal_fn: A Python callable that takes a location, covariance
            matrix, optional `validate_args`, `allow_nan_stats` and `name`
            arguments, and returns a multivariate Student-T subclass of
            `tfd.Distribution`.
            Default value: `None`.
            NOTE(review): this constructor records `marginal_fn` in
            `parameters` but does not otherwise reference it, and it is not
            forwarded to the parent constructor -- confirm whether that is
            intended.
          validate_args: Python `bool`. When `True` distribution parameters are
            checked for validity despite possibly degrading runtime
            performance. When `False` invalid inputs may silently render
            incorrect outputs.
            Default value: `False`.
          allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
            mode, variance) use the value `NaN` to indicate the result is
            undefined. When `False`, an exception is raised if one or more of
            the statistic's batch members are undefined.
            Default value: `False`.
          name: Python `str` name prefixed to Ops created by this class.
            Default value: 'StudentTProcessRegressionModel'.
          _conditional_kernel: Internal parameter -- do not use.
          _conditional_mean_fn: Internal parameter -- do not use.

        Raises:
          ValueError: if either
            - only one of `observations` and `observation_index_points` is
              given, or
            - `mean_fn` is not `None` and not callable.
        """
        # Snapshot the constructor arguments before any other local is bound.
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([
                df, kernel, index_points, observation_noise_variance,
                observations
            ], tf.float32)
            df = tensor_util.convert_nonref_to_tensor(df,
                                                      dtype=dtype,
                                                      name='df')
            index_points = tensor_util.convert_nonref_to_tensor(
                index_points, dtype=dtype, name='index_points')
            observation_index_points = tensor_util.convert_nonref_to_tensor(
                observation_index_points,
                dtype=dtype,
                name='observation_index_points')
            observations = tensor_util.convert_nonref_to_tensor(
                observations, dtype=dtype, name='observations')
            observation_noise_variance = tensor_util.convert_nonref_to_tensor(
                observation_noise_variance,
                dtype=dtype,
                name='observation_noise_variance')
            predictive_noise_variance = tensor_util.convert_nonref_to_tensor(
                predictive_noise_variance,
                dtype=dtype,
                name='predictive_noise_variance')
            # An omitted predictive noise falls back to the observation noise.
            if predictive_noise_variance is None:
                predictive_noise_variance = observation_noise_variance
            # Observations and their index points only make sense as a pair.
            if (observation_index_points is None) != (observations is None):
                raise ValueError(
                    '`observations` and `observation_index_points` must both be given '
                    'or None. Got {} and {}, respectively.'.format(
                        observations, observation_index_points))
            # Default to a constant zero function in the common dtype computed
            # above.
            if mean_fn is None:
                mean_fn = lambda x: tf.zeros([1], dtype=dtype)
            else:
                if not callable(mean_fn):
                    raise ValueError('`mean_fn` must be a Python callable')

            if cholesky_fn is None:
                cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()

            self._observation_index_points = observation_index_points
            self._observations = observations
            self._observation_noise_variance = observation_noise_variance
            self._predictive_noise_variance = predictive_noise_variance

            with tf.name_scope('init'):
                # A caller-supplied `_conditional_kernel` is used as-is;
                # otherwise one is built from the base kernel and the
                # observation data.
                if _conditional_kernel is None:
                    _conditional_kernel = DampedSchurComplement(
                        df=df,
                        schur_complement=tfpk.SchurComplement(
                            base_kernel=kernel,
                            fixed_inputs=self._observation_index_points,
                            diag_shift=observation_noise_variance),
                        fixed_inputs_observations=self._observations,
                        validate_args=validate_args)

                # Special logic for mean_fn only; SchurComplement already handles the
                # case of empty observations (ie, falls back to base_kernel).
                if _is_empty_observation_data(
                        feature_ndims=kernel.feature_ndims,
                        observation_index_points=observation_index_points,
                        observations=observations):
                    # With no observations the conditional mean is the prior
                    # mean, unless the caller supplied one explicitly.
                    if _conditional_mean_fn is None:
                        _conditional_mean_fn = mean_fn
                else:
                    _validate_observation_data(
                        kernel=kernel,
                        observation_index_points=observation_index_points,
                        observations=observations)
                    # `n` is the size of the last dimension of `observations`;
                    # the df passed to the parent below is `df + n`. The kernel
                    # built above was given `df` before this rebinding.
                    n = tf.cast(ps.shape(observations)[-1], dtype=dtype)
                    df = tfp_util.DeferredTensor(df, lambda x: x + n)

                    if _conditional_mean_fn is None:

                        def conditional_mean_fn(x):
                            """Conditional mean."""
                            # Read the observation data from `self` at call
                            # time rather than closing over the tensors
                            # converted above.
                            observations = tf.convert_to_tensor(
                                self._observations)
                            observation_index_points = tf.convert_to_tensor(
                                self._observation_index_points)
                            k_x_obs_linop = tf.linalg.LinearOperatorFullMatrix(
                                kernel.matrix(x, observation_index_points))
                            chol_linop = tf.linalg.LinearOperatorLowerTriangular(
                                _conditional_kernel.divisor_matrix_cholesky(
                                    fixed_inputs=observation_index_points))
                            diff = observations - mean_fn(
                                observation_index_points)
                            # Two triangular solves (plain, then adjoint)
                            # against the Cholesky factor apply the inverse of
                            # the divisor matrix to `diff`.
                            return mean_fn(x) + k_x_obs_linop.matvec(
                                chol_linop.solvevec(chol_linop.solvevec(diff),
                                                    adjoint=True))

                        _conditional_mean_fn = conditional_mean_fn

                # The parent's "observation noise variance" is this class's
                # predictive noise variance; the observation noise was consumed
                # above as the kernel's `diag_shift`.
                super(StudentTProcessRegressionModel, self).__init__(
                    df=df,
                    kernel=_conditional_kernel,
                    mean_fn=_conditional_mean_fn,
                    cholesky_fn=cholesky_fn,
                    index_points=index_points,
                    observation_noise_variance=predictive_noise_variance,
                    validate_args=validate_args,
                    allow_nan_stats=allow_nan_stats,
                    name=name)
                self._parameters = parameters