# Common imports assumed by the snippets below (a sketch; the actual test
# modules' imports may differ):
import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.internal import test_util

tfd = tfp.distributions
psd_kernels = tfp.math.psd_kernels


def testMeanSameAsGPRM(self):
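        # The posterior mean of a Student-T process regression model is its
        # location parameter, which does not depend on `df` and matches the
        # Gaussian process posterior mean, so `stprm.mean()` and `gprm.mean()`
        # should agree.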
        df = np.float64(3.)
        index_points = np.linspace(-4., 4., 5, dtype=np.float64)
        index_points = np.stack(np.meshgrid(index_points, index_points),
                                axis=-1)
        index_points = np.reshape(index_points, [-1, 2])

        # Kernel with batch_shape [5, 3]
        amplitude = np.array([1., 2., 3., 4., 5.], np.float64).reshape([5, 1])
        length_scale = np.array([.1, .2, .3], np.float64).reshape([1, 3])
        observation_noise_variance = np.array([1e-5, 1e-6, 1e-9],
                                              np.float64).reshape([1, 3])

        observation_index_points = (np.random.uniform(
            -1., 1., (3, 7, 2)).astype(np.float64))
        observations = np.random.uniform(-1., 1., (3, 7)).astype(np.float64)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
        stprm = tfd.StudentTProcessRegressionModel(
            df=df,
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance)
        gprm = tfd.GaussianProcessRegressionModel(
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance)

        self.assertAllClose(self.evaluate(stprm.mean()),
                            self.evaluate(gprm.mean()))
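
# Hypothesis strategy: draws a random base kernel, a broadcast-compatible
# batch shape, index points, observation index points and observations, and
# assembles them into a GaussianProcessRegressionModel for property-based
# tests.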
def gaussian_process_regression_models(draw,
                                       kernel_name=None,
                                       batch_shape=None,
                                       event_dim=None,
                                       feature_dim=None,
                                       feature_ndims=None,
                                       enable_vars=False):
    # First draw a kernel.
    k, _ = draw(
        kernel_hps.base_kernels(
            kernel_name=kernel_name,
            batch_shape=batch_shape,
            event_dim=event_dim,
            feature_dim=feature_dim,
            feature_ndims=feature_ndims,
            # Disable variables
            enable_vars=False))
    compatible_batch_shape = draw(
        tfp_hps.broadcast_compatible_shape(k.batch_shape))
    index_points = draw(
        kernel_hps.kernel_input(batch_shape=compatible_batch_shape,
                                example_ndims=1,
                                feature_dim=feature_dim,
                                feature_ndims=feature_ndims,
                                enable_vars=enable_vars,
                                name='index_points'))

    observation_index_points = draw(
        kernel_hps.kernel_input(batch_shape=compatible_batch_shape,
                                example_ndims=1,
                                feature_dim=feature_dim,
                                feature_ndims=feature_ndims,
                                enable_vars=enable_vars,
                                name='observation_index_points'))

    observations = draw(
        kernel_hps.kernel_input(
            batch_shape=compatible_batch_shape,
            example_ndims=1,
            # This is the example dimension suggested by `observation_index_points`.
            example_dim=int(
                observation_index_points.shape[-(feature_ndims + 1)]),
            # No feature dimensions.
            feature_dim=0,
            feature_ndims=0,
            enable_vars=enable_vars,
            name='observations'))

    params = draw(
        broadcasting_params('GaussianProcessRegressionModel',
                            compatible_batch_shape,
                            event_dim=event_dim,
                            enable_vars=enable_vars))
    gp = tfd.GaussianProcessRegressionModel(
        kernel=k,
        index_points=index_points,
        observation_index_points=observation_index_points,
        observations=observations,
        observation_noise_variance=params['observation_noise_variance'])
    return gp
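

# Several tests below compute ground truth with a NumPy helper,
# `_np_kernel_matrix_fn(amp, length_scale, x, y)`, that is not shown in this
# listing. A minimal sketch of what such a helper presumably computes (the
# exponentiated-quadratic kernel matrix for 1-d features; the exact upstream
# definition is an assumption):
def _np_kernel_matrix_fn(amp, length_scale, x, y):
    # Pairwise kernel matrix between the rows of `x` and `y`, each of shape
    # [num_points, 1].
    x = np.expand_dims(x, -2)[..., 0]
    y = np.expand_dims(y, -3)[..., 0]
    return amp ** 2 * np.exp(-.5 * (x - y) ** 2 / length_scale ** 2)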
    def testInitParameterVariations(self, noise_kwargs, implied_values):
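        # A plausible parameterization for the cases driving this test (an
        # assumption; the actual `noise_kwargs` / `implied_values` cases are
        # not shown in this listing):
        #   noise_kwargs={'observation_noise_variance': 1e-2}
        #   implied_values={'observation_noise_variance_parameter': 1e-2,
        #                   'predictive_noise_variance_parameter': 1e-2}
        # i.e. when `predictive_noise_variance` is not supplied it defaults to
        # `observation_noise_variance`.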
        num_test_points = 3
        num_obs_points = 4
        kernel = psd_kernels.ExponentiatedQuadratic()
        index_points = np.random.uniform(-1., 1., (num_test_points, 1))
        observation_index_points = np.random.uniform(-1., 1.,
                                                     (num_obs_points, 1))
        observations = np.random.uniform(-1., 1., num_obs_points)
        jitter = 1e-6

        gprm = tfd.GaussianProcessRegressionModel(
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            jitter=jitter,
            validate_args=True,
            **noise_kwargs)

        # 'Property' means what was passed to the constructor; 'parameter' is
        # the effective value by which the distribution is parameterized.
        implied_onv_param = implied_values[
            'observation_noise_variance_parameter']
        implied_pnv_param = implied_values[
            'predictive_noise_variance_parameter']

        # Expected predictive covariance:
        #   k_tt - k_tx @ inv(k_xx + ONV) @ k_xt + PNV
        k = lambda x, y: _np_kernel_matrix_fn(1., 1., x, y)
        k_tt_ = k(index_points, index_points)
        k_tx_ = k(index_points, observation_index_points)
        k_xx_plus_noise_ = (
            k(observation_index_points, observation_index_points) +
            (jitter + implied_onv_param) * np.eye(num_obs_points))

        expected_predictive_covariance = (
            k_tt_ - np.dot(k_tx_, np.linalg.solve(k_xx_plus_noise_, k_tx_.T)) +
            implied_pnv_param * np.eye(num_test_points))

        # Assertion 1: predictive covariance is correct.
        self.assertAllClose(self.evaluate(gprm.covariance()),
                            expected_predictive_covariance)

        # Assertion 2: predictive_noise_variance property is correct
        self.assertIsInstance(gprm.predictive_noise_variance, tf.Tensor)
        self.assertAllClose(self.evaluate(gprm.predictive_noise_variance),
                            implied_pnv_param)

        # Assertion 3: observation_noise_variance property is correct.
        self.assertIsInstance(gprm.observation_noise_variance, tf.Tensor)
        self.assertAllClose(
            self.evaluate(gprm.observation_noise_variance),
            # Note that this is, somewhat unintuitively, expected to equal the
            # predictive_noise_variance. This is because of 1) the inheritance
            # structure of GPRM as a subclass of GaussianProcess and 2) the poor
            # choice of name of the GaussianProcess noise parameter. The latter
            # issue is being cleaned up in cl/256413439.
            implied_pnv_param)
    def testErrorCases(self):
        kernel = psd_kernels.ExponentiatedQuadratic()
        index_points = np.random.uniform(-1., 1., (10, 1)).astype(np.float64)
        observation_index_points = (np.random.uniform(-1., 1., (5, 1)).astype(
            np.float64))
        observations = np.random.uniform(-1., 1., 3).astype(np.float64)

        # Both or neither of `observation_index_points` and `observations` must be
        # specified.
        with self.assertRaises(ValueError):
            tfd.GaussianProcessRegressionModel(kernel,
                                               index_points,
                                               observation_index_points=None,
                                               observations=observations)
        with self.assertRaises(ValueError):
            tfd.GaussianProcessRegressionModel(kernel,
                                               index_points,
                                               observation_index_points,
                                               observations=None)

        # If specified, mean_fn must be a callable.
        with self.assertRaises(ValueError):
            tfd.GaussianProcessRegressionModel(kernel,
                                               index_points,
                                               mean_fn=0.)

        # Observation index point and observation counts must be broadcastable.
        if self.is_static or tf.executing_eagerly():
            with self.assertRaises(ValueError):
                tfd.GaussianProcessRegressionModel(
                    kernel,
                    index_points,
                    observation_index_points=np.ones([2, 2, 2]),
                    observations=np.ones([5, 5]))
        else:
            gprm = tfd.GaussianProcessRegressionModel(
                kernel,
                index_points,
                observation_index_points=tf.placeholder_with_default(
                    np.ones([2, 2, 2]), shape=None),
                observations=tf.placeholder_with_default(np.ones([5, 5]),
                                                         shape=None))
            with self.assertRaises(ValueError):
                self.evaluate(gprm.event_shape_tensor())
  def testEmptyDataMatchesGPPrior(self):
    amp = np.float64(.5)
    len_scale = np.float64(.2)
    jitter = np.float64(1e-4)
    index_points = np.random.uniform(-1., 1., (10, 1)).astype(np.float64)

    # Use a non-trivial prior mean function.
    mean_fn = lambda x: x[:, 0]**2

    kernel = psd_kernels.ExponentiatedQuadratic(amp, len_scale)
    gp = tfd.GaussianProcess(
        kernel,
        index_points,
        mean_fn=mean_fn,
        jitter=jitter,
        validate_args=True)

    gprm_nones = tfd.GaussianProcessRegressionModel(
        kernel,
        index_points,
        mean_fn=mean_fn,
        jitter=jitter,
        validate_args=True)

    gprm_zero_shapes = tfd.GaussianProcessRegressionModel(
        kernel,
        index_points,
        observation_index_points=tf.ones([5, 0, 1], tf.float64),
        observations=tf.ones([5, 0], tf.float64),
        mean_fn=mean_fn,
        jitter=jitter,
        validate_args=True)

    for gprm in [gprm_nones, gprm_zero_shapes]:
      self.assertAllClose(self.evaluate(gp.mean()), self.evaluate(gprm.mean()))
      self.assertAllClose(self.evaluate(gp.covariance()),
                          self.evaluate(gprm.covariance()))
      self.assertAllClose(self.evaluate(gp.variance()),
                          self.evaluate(gprm.variance()))

      observations = np.random.uniform(-1., 1., 10).astype(np.float64)
      self.assertAllClose(self.evaluate(gp.log_prob(observations)),
                          self.evaluate(gprm.log_prob(observations)))
    def testErrorCases(self):
        kernel = psd_kernels.ExponentiatedQuadratic()
        index_points = np.random.uniform(-1., 1., (10, 1)).astype(np.float64)
        observation_index_points = (np.random.uniform(-1., 1., (5, 1)).astype(
            np.float64))
        observations = np.random.uniform(-1., 1., 3).astype(np.float64)

        # Both or neither of `observation_index_points` and `observations` must be
        # specified.
        with self.assertRaises(ValueError):
            tfd.GaussianProcessRegressionModel(kernel,
                                               index_points,
                                               observation_index_points=None,
                                               observations=observations,
                                               validate_args=True)
        with self.assertRaises(ValueError):
            tfd.GaussianProcessRegressionModel(kernel,
                                               index_points,
                                               observation_index_points,
                                               observations=None,
                                               validate_args=True)

        # If specified, mean_fn must be a callable.
        with self.assertRaises(ValueError):
            tfd.GaussianProcessRegressionModel(kernel,
                                               index_points,
                                               mean_fn=0.,
                                               validate_args=True)

        # Observation index point and observation counts must be broadcastable.
        # Errors based on conditions of dynamic shape in graph mode cannot be
        # caught, so we only check this error case in static shape or eager mode.
        if self.is_static or tf.executing_eagerly():
            with self.assertRaises(ValueError):
                tfd.GaussianProcessRegressionModel(
                    kernel,
                    index_points,
                    observation_index_points=np.ones([2, 2, 2]),
                    observations=np.ones([5, 5]),
                    validate_args=True)
  def testGPPosteriorPredictive(self):
    amplitude = np.float64(.5)
    length_scale = np.float64(2.)
    jitter = np.float64(1e-4)
    observation_noise_variance = np.float64(3e-3)
    kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)

    index_points = np.random.uniform(-1., 1., 10)[..., np.newaxis]

    gp = tfd.GaussianProcess(
        kernel,
        index_points,
        observation_noise_variance=observation_noise_variance,
        jitter=jitter,
        validate_args=True)

    predictive_index_points = np.random.uniform(1., 2., 10)[..., np.newaxis]
    observations = np.linspace(1., 10., 10)

    expected_gprm = tfd.GaussianProcessRegressionModel(
        kernel=kernel,
        observation_index_points=index_points,
        observations=observations,
        observation_noise_variance=observation_noise_variance,
        jitter=jitter,
        index_points=predictive_index_points,
        validate_args=True)

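    # `posterior_predictive` conditions the prior GP on `observations` at
    # `index_points` and returns the predictive distribution over
    # `predictive_index_points`; it should match the explicitly constructed
    # GaussianProcessRegressionModel above.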
    actual_gprm = gp.posterior_predictive(
        predictive_index_points=predictive_index_points,
        observations=observations)

    samples = self.evaluate(actual_gprm.sample(10, seed=test_util.test_seed()))

    self.assertAllClose(
        self.evaluate(expected_gprm.mean()),
        self.evaluate(actual_gprm.mean()))

    self.assertAllClose(
        self.evaluate(expected_gprm.covariance()),
        self.evaluate(actual_gprm.covariance()))

    self.assertAllClose(
        self.evaluate(expected_gprm.log_prob(samples)),
        self.evaluate(actual_gprm.log_prob(samples)))
    def testLogProbNearGPRM(self):
        # For large df, the log_prob calculations should be the same.
        df = np.float64(1e6)
        index_points = np.linspace(-4., 4., 5, dtype=np.float64)
        index_points = np.stack(np.meshgrid(index_points, index_points),
                                axis=-1)
        index_points = np.reshape(index_points, [-1, 2])

        # Kernel with batch_shape [5, 3]
        amplitude = np.array([1., 2., 3., 4., 5.], np.float64).reshape([5, 1])
        length_scale = np.array([.1, .2, .3], np.float64).reshape([1, 3])
        observation_noise_variance = np.array([1e-5, 1e-6, 1e-9],
                                              np.float64).reshape([1, 3])

        observation_index_points = (np.random.uniform(
            -1., 1., (3, 7, 2)).astype(np.float64))
        observations = np.random.uniform(-1., 1., (3, 7)).astype(np.float64)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
        stprm = tfd.StudentTProcessRegressionModel(
            df=df,
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance)
        gprm = tfd.GaussianProcessRegressionModel(
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance)

        x = np.linspace(-3., 3., 25)

        self.assertAllClose(self.evaluate(stprm.log_prob(x)),
                            self.evaluate(gprm.log_prob(x)),
                            rtol=2e-5)
    def testMeanVarianceAndCovariancePrecomputed(self):
        amplitude = np.array([1., 2.], np.float64).reshape([2, 1])
        length_scale = np.array([.1, .2, .3], np.float64).reshape([1, 3])
        observation_noise_variance = np.array([1e-9], np.float64)

        jitter = np.float64(1e-6)
        observation_index_points = (np.random.uniform(
            -1., 1., (1, 1, 7, 2)).astype(np.float64))
        observations = np.random.uniform(-1., 1., (1, 1, 7)).astype(np.float64)

        index_points = np.random.uniform(-1., 1., (6, 2)).astype(np.float64)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
        gprm = tfd.GaussianProcessRegressionModel(
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance,
            jitter=jitter,
            validate_args=True)

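        # `precompute_regression_model` performs the observation-dependent
        # linear algebra eagerly at construction (so repeated predictions are
        # cheaper); its mean, variance and covariance should match the lazily
        # computed model above.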
        precomputed_gprm = tfd.GaussianProcessRegressionModel.precompute_regression_model(
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance,
            jitter=jitter,
            validate_args=True)

        self.assertAllClose(self.evaluate(precomputed_gprm.covariance()),
                            self.evaluate(gprm.covariance()))
        self.assertAllClose(self.evaluate(precomputed_gprm.variance()),
                            self.evaluate(gprm.variance()))
        self.assertAllClose(self.evaluate(precomputed_gprm.mean()),
                            self.evaluate(gprm.mean()))
  def testShapes(self):
    # 5x5 grid of index points in R^2, flattened to shape [25, 2].
    index_points = np.linspace(-4., 4., 5, dtype=np.float64)
    index_points = np.stack(np.meshgrid(index_points, index_points), axis=-1)
    index_points = np.reshape(index_points, [-1, 2])
    # ==> shape = [25, 2]
    batched_index_points = np.expand_dims(np.stack([index_points]*6), -3)
    # ==> shape = [6, 1, 25, 2]

    # Kernel with batch_shape [2, 4, 1, 1]
    amplitude = np.array([1., 2.], np.float64).reshape([2, 1, 1, 1])
    length_scale = np.array([.1, .2, .3, .4], np.float64).reshape([1, 4, 1, 1])

    jitter = np.float64(1e-6)
    observation_noise_variance = np.float64(1e-2)
    observation_index_points = (
        np.random.uniform(-1., 1., (3, 7, 2)).astype(np.float64))
    observations = np.random.uniform(-1., 1., (3, 7)).astype(np.float64)

    if not self.is_static:
      amplitude = tf.placeholder_with_default(amplitude, shape=None)
      length_scale = tf.placeholder_with_default(length_scale, shape=None)
      batched_index_points = tf.placeholder_with_default(
          batched_index_points, shape=None)

      observation_index_points = tf.placeholder_with_default(
          observation_index_points, shape=None)
      observations = tf.placeholder_with_default(observations, shape=None)

    kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)

    gprm = tfd.GaussianProcessRegressionModel(
        kernel,
        batched_index_points,
        observation_index_points,
        observations,
        observation_noise_variance,
        jitter=jitter)

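    # The expected batch shape is the broadcast of the kernel's batch shape
    # [2, 4, 1, 1], the batched index points' batch shape [6, 1], and the
    # observations' batch shape [3].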
    batch_shape = [2, 4, 6, 3]
    event_shape = [25]
    sample_shape = [9, 3]

    samples = gprm.sample(sample_shape)

    if self.is_static or tf.executing_eagerly():
      self.assertAllEqual(gprm.batch_shape_tensor(), batch_shape)
      self.assertAllEqual(gprm.event_shape_tensor(), event_shape)
      self.assertAllEqual(samples.shape,
                          sample_shape + batch_shape + event_shape)
      self.assertAllEqual(gprm.batch_shape, batch_shape)
      self.assertAllEqual(gprm.event_shape, event_shape)
      self.assertAllEqual(samples.shape,
                          sample_shape + batch_shape + event_shape)
    else:
      self.assertAllEqual(self.evaluate(gprm.batch_shape_tensor()), batch_shape)
      self.assertAllEqual(self.evaluate(gprm.event_shape_tensor()), event_shape)
      self.assertAllEqual(self.evaluate(samples).shape,
                          sample_shape + batch_shape + event_shape)
      self.assertIsNone(samples.shape.ndims)
      self.assertIsNone(gprm.batch_shape.ndims)
      self.assertEqual(gprm.event_shape.ndims, 1)
      self.assertIsNone(gprm.event_shape.dims[0].value)
  def testCopy(self):
    # 5 random index points in R^2
    index_points_1 = np.random.uniform(-4., 4., (5, 2)).astype(np.float32)
    # 10 random index points in R^2
    index_points_2 = np.random.uniform(-4., 4., (10, 2)).astype(np.float32)

    observation_index_points_1 = (
        np.random.uniform(-4., 4., (7, 2)).astype(np.float32))
    observation_index_points_2 = (
        np.random.uniform(-4., 4., (9, 2)).astype(np.float32))

    observations_1 = np.random.uniform(-1., 1., 7).astype(np.float32)
    observations_2 = np.random.uniform(-1., 1., 9).astype(np.float32)

    if not self.is_static:
      index_points_1 = tf.placeholder_with_default(index_points_1, shape=None)
      index_points_2 = tf.placeholder_with_default(index_points_2, shape=None)
      observation_index_points_1 = tf.placeholder_with_default(
          observation_index_points_1, shape=None)
      observation_index_points_2 = tf.placeholder_with_default(
          observation_index_points_2, shape=None)
      observations_1 = tf.placeholder_with_default(observations_1, shape=None)
      observations_2 = tf.placeholder_with_default(observations_2, shape=None)

    mean_fn = lambda x: np.array([0.], np.float32)
    kernel_1 = psd_kernels.ExponentiatedQuadratic()
    kernel_2 = psd_kernels.ExpSinSquared()

    gprm1 = tfd.GaussianProcessRegressionModel(
        kernel=kernel_1,
        index_points=index_points_1,
        observation_index_points=observation_index_points_1,
        observations=observations_1,
        mean_fn=mean_fn,
        jitter=1e-5)
    gprm2 = gprm1.copy(
        kernel=kernel_2,
        index_points=index_points_2,
        observation_index_points=observation_index_points_2,
        observations=observations_2)

    event_shape_1 = [5]
    event_shape_2 = [10]

    self.assertEqual(gprm1.mean_fn, gprm2.mean_fn)
    self.assertIsInstance(gprm1.kernel, psd_kernels.ExponentiatedQuadratic)
    self.assertIsInstance(gprm2.kernel, psd_kernels.ExpSinSquared)

    if self.is_static or tf.executing_eagerly():
      self.assertAllEqual(gprm1.batch_shape, gprm2.batch_shape)
      self.assertAllEqual(gprm1.event_shape, event_shape_1)
      self.assertAllEqual(gprm2.event_shape, event_shape_2)
      self.assertAllEqual(gprm1.index_points, index_points_1)
      self.assertAllEqual(gprm2.index_points, index_points_2)
      self.assertAllEqual(tensor_util.constant_value(gprm1.jitter),
                          tensor_util.constant_value(gprm2.jitter))
    else:
      self.assertAllEqual(self.evaluate(gprm1.batch_shape_tensor()),
                          self.evaluate(gprm2.batch_shape_tensor()))
      self.assertAllEqual(self.evaluate(gprm1.event_shape_tensor()),
                          event_shape_1)
      self.assertAllEqual(self.evaluate(gprm2.event_shape_tensor()),
                          event_shape_2)
      self.assertEqual(self.evaluate(gprm1.jitter), self.evaluate(gprm2.jitter))
      self.assertAllEqual(self.evaluate(gprm1.index_points), index_points_1)
      self.assertAllEqual(self.evaluate(gprm2.index_points), index_points_2)
  def testMeanVarianceAndCovariance(self):
    amp = np.float64(.5)
    len_scale = np.float64(.2)
    observation_noise_variance = np.float64(1e-3)
    jitter = np.float64(1e-4)
    num_test = 10
    num_obs = 3
    index_points = np.random.uniform(-1., 1., (num_test, 1)).astype(np.float64)
    observation_index_points = (
        np.random.uniform(-1., 1., (num_obs, 1)).astype(np.float64))
    observations = np.random.uniform(-1., 1., num_obs).astype(np.float64)

    # k_xx - k_xn @ inv(k_nn + sigma^2) @ k_nx + sigma^2
    k = lambda x, y: _np_kernel_matrix_fn(amp, len_scale, x, y)
    k_xx_ = k(index_points, index_points)
    k_xn_ = k(index_points, observation_index_points)
    k_nn_plus_noise_ = (
        k(observation_index_points, observation_index_points) +
        (jitter + observation_noise_variance) * np.eye(num_obs))

    expected_predictive_covariance_no_noise = (
        k_xx_ - np.dot(k_xn_, np.linalg.solve(k_nn_plus_noise_, k_xn_.T)) +
        np.eye(num_test) * jitter)

    expected_predictive_covariance_with_noise = (
        expected_predictive_covariance_no_noise +
        np.eye(num_test) * observation_noise_variance)

    mean_fn = lambda x: x[:, 0]**2
    prior_mean = mean_fn(observation_index_points)
    # Posterior mean: mean_fn(x) + k_xn @ inv(k_nn + sigma^2) @ (y - mean_fn(n))
    expected_mean = mean_fn(index_points) + np.dot(
        k_xn_, np.linalg.solve(k_nn_plus_noise_, observations - prior_mean))

    kernel = psd_kernels.ExponentiatedQuadratic(amp, len_scale)
    gprm = tfd.GaussianProcessRegressionModel(
        kernel=kernel,
        index_points=index_points,
        observation_index_points=observation_index_points,
        observations=observations,
        observation_noise_variance=observation_noise_variance,
        mean_fn=mean_fn,
        jitter=jitter)

    self.assertAllClose(expected_predictive_covariance_with_noise,
                        self.evaluate(gprm.covariance()))
    self.assertAllClose(np.diag(expected_predictive_covariance_with_noise),
                        self.evaluate(gprm.variance()))
    self.assertAllClose(expected_mean,
                        self.evaluate(gprm.mean()))

    gprm_no_predictive_noise = tfd.GaussianProcessRegressionModel(
        kernel=kernel,
        index_points=index_points,
        observation_index_points=observation_index_points,
        observations=observations,
        observation_noise_variance=observation_noise_variance,
        predictive_noise_variance=0.,
        mean_fn=mean_fn,
        jitter=jitter)

    self.assertAllClose(expected_predictive_covariance_no_noise,
                        self.evaluate(gprm_no_predictive_noise.covariance()))
    self.assertAllClose(np.diag(expected_predictive_covariance_no_noise),
                        self.evaluate(gprm_no_predictive_noise.variance()))
    self.assertAllClose(expected_mean,
                        self.evaluate(gprm_no_predictive_noise.mean()))
    def testShapes(self):
        # We'll use a batch shape of [2, 3, 5, 7, 11]

        # 5x5 grid of index points in R^2, flattened to shape [25, 2].
        index_points = np.linspace(-4., 4., 5, dtype=np.float64)
        index_points = np.stack(np.meshgrid(index_points, index_points),
                                axis=-1)
        index_points = np.reshape(index_points, [-1, 2])
        # ==> shape = [25, 2]
        batched_index_points = np.reshape(index_points, [1, 1, 25, 2])
        batched_index_points = np.stack([batched_index_points] * 5)
        # ==> shape = [5, 1, 1, 25, 2]

        # Kernel with batch_shape [2, 3, 1, 1, 1]
        amplitude = np.array([1., 2.], np.float64).reshape([2, 1, 1, 1, 1])
        length_scale = np.array([.1, .2, .3],
                                np.float64).reshape([1, 3, 1, 1, 1])
        observation_noise_variance = np.array([1e-9], np.float64).reshape(
            [1, 1, 1, 1, 1])

        jitter = np.float64(1e-6)
        observation_index_points = (np.random.uniform(
            -1., 1., (7, 1, 7, 2)).astype(np.float64))
        observations = np.random.uniform(-1., 1., (11, 7)).astype(np.float64)

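        # A custom Cholesky that adds 1. to the diagonal before factorizing;
        # the `assertIs` below checks it is threaded through to
        # `gprm.cholesky_fn`.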
        def cholesky_fn(x):
            return tf.linalg.cholesky(
                tf.linalg.set_diag(x,
                                   tf.linalg.diag_part(x) + 1.))

        if not self.is_static:
            amplitude = tf1.placeholder_with_default(amplitude, shape=None)
            length_scale = tf1.placeholder_with_default(length_scale,
                                                        shape=None)
            batched_index_points = tf1.placeholder_with_default(
                batched_index_points, shape=None)

            observation_index_points = tf1.placeholder_with_default(
                observation_index_points, shape=None)
            observations = tf1.placeholder_with_default(observations,
                                                        shape=None)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)

        gprm = tfd.GaussianProcessRegressionModel(kernel,
                                                  batched_index_points,
                                                  observation_index_points,
                                                  observations,
                                                  observation_noise_variance,
                                                  cholesky_fn=cholesky_fn,
                                                  jitter=jitter,
                                                  validate_args=True)

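        # The expected batch shape is the broadcast of the kernel's batch
        # shape [2, 3, 1, 1, 1], the index points' batch shape [5, 1, 1], the
        # observation index points' batch shape [7, 1], and the observations'
        # batch shape [11].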
        batch_shape = [2, 3, 5, 7, 11]
        event_shape = [25]
        sample_shape = [9, 3]

        samples = gprm.sample(sample_shape, seed=test_util.test_seed())

        self.assertIs(cholesky_fn, gprm.cholesky_fn)

        if self.is_static or tf.executing_eagerly():
            self.assertAllEqual(gprm.batch_shape_tensor(), batch_shape)
            self.assertAllEqual(gprm.event_shape_tensor(), event_shape)
            self.assertAllEqual(samples.shape,
                                sample_shape + batch_shape + event_shape)
            self.assertAllEqual(gprm.batch_shape, batch_shape)
            self.assertAllEqual(gprm.event_shape, event_shape)
            self.assertAllEqual(samples.shape,
                                sample_shape + batch_shape + event_shape)
        else:
            self.assertAllEqual(self.evaluate(gprm.batch_shape_tensor()),
                                batch_shape)
            self.assertAllEqual(self.evaluate(gprm.event_shape_tensor()),
                                event_shape)
            self.assertAllEqual(
                self.evaluate(samples).shape,
                sample_shape + batch_shape + event_shape)
            self.assertIsNone(tensorshape_util.rank(samples.shape))
            self.assertIsNone(tensorshape_util.rank(gprm.batch_shape))
            self.assertEqual(tensorshape_util.rank(gprm.event_shape), 1)
            self.assertIsNone(
                tf.compat.dimension_value(
                    tensorshape_util.dims(gprm.event_shape)[0]))