def testLogProbWithIsMissing(self):
        """Checks `log_prob` with an `is_missing` mask against unmasked GPs.

        For each batch member, the masked log_prob should equal the log_prob
        of a GaussianProcess built from only the non-missing index points.
        """
        # Shape is left fully unknown in the non-static case to exercise the
        # dynamic-shape code path.
        index_points = tf.Variable(
            [[-1.0, 0.0], [-0.5, -0.5], [1.5, 0.0], [1.6, 1.5]],
            shape=None if self.is_static else tf.TensorShape(None))
        self.evaluate(index_points.initializer)
        amplitude = tf.convert_to_tensor(1.1)
        length_scale = tf.convert_to_tensor(0.9)

        gp = tfd.GaussianProcess(kernel=psd_kernels.ExponentiatedQuadratic(
            amplitude, length_scale),
                                 index_points=index_points,
                                 mean_fn=lambda x: tf.reduce_mean(x, axis=-1),
                                 observation_noise_variance=.05,
                                 jitter=0.0)

        x = gp.sample(5, seed=test_util.test_seed())

        # Mask of shape [5, 4]: one row per sample, one column per index
        # point.
        is_missing = np.array([
            [False, True, False, False],
            [False, False, False, False],
            [True, False, True, True],
            [True, False, False, True],
            [False, False, True, True],
        ])

        # Masked entries are set to NaN to verify they are truly ignored.
        lp = gp.log_prob(tf.where(is_missing, np.nan, x),
                         is_missing=is_missing)

        # For each batch member, check that the log_prob is the same as for a
        # GaussianProcess without the missing index points.
        for i in range(5):
            gp_i = tfd.GaussianProcess(
                kernel=psd_kernels.ExponentiatedQuadratic(
                    amplitude, length_scale),
                index_points=tf.gather(index_points,
                                       (~is_missing[i]).nonzero()[0]),
                mean_fn=lambda x: tf.reduce_mean(x, axis=-1),
                observation_noise_variance=.05,
                jitter=0.0)
            lp_i = gp_i.log_prob(tf.gather(x[i],
                                           (~is_missing[i]).nonzero()[0]))
            # NOTE: This reshape is necessary because lp_i has shape [1] when
            # gp_i.index_points contains a single index point.
            self.assertAllClose(tf.reshape(lp_i, []), lp[i])

        # The log_prob should be zero when all points are missing out.
        self.assertAllClose(
            tf.zeros((3, 2)),
            gp.log_prob(tf.ones((3, 1, 4)) * np.nan,
                        is_missing=tf.constant(True, shape=(2, 4))))
    def testVarianceAndCovarianceMatrix(self):
        """StudentTProcess covariance/variance match a NumPy kernel matrix."""
        df = np.float64(4.)
        amplitude = np.float64(.5)
        length_scale = np.float64(.2)
        jitter = np.float64(1e-4)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
        index_points = np.expand_dims(np.random.uniform(-1., 1., 10), -1)

        tp = tfd.StudentTProcess(df=df,
                                 kernel=kernel,
                                 index_points=index_points,
                                 jitter=jitter,
                                 validate_args=True)

        def _np_kernel(x, y):
            # NumPy re-implementation of the ExponentiatedQuadratic kernel.
            sq_dist = np.squeeze((x - y)**2)
            return amplitude**2 * np.exp(-.5 * sq_dist / length_scale**2)

        expected_covariance = _np_kernel(np.expand_dims(index_points, 0),
                                         np.expand_dims(index_points, 1))

        self.assertAllClose(expected_covariance,
                            self.evaluate(tp.covariance()))
        self.assertAllClose(np.diag(expected_covariance),
                            self.evaluate(tp.variance()))
    def testPrecomputedCompositeTensor(self):
        """Flatten/unflatten round-trip preserves a precomputed STPRM.

        NOTE(review): relies on the private attribute
        `_precomputed_divisor_matrix_cholesky` of the SchurComplement kernel
        to assert the cached Cholesky is not recomputed.
        """
        # Parameter batch shapes broadcast to [2, 3].
        amplitude = np.array([1., 2.], np.float64).reshape([2, 1])
        length_scale = np.array([.1, .2, .3], np.float64).reshape([1, 3])
        observation_noise_variance = np.array([1e-9], np.float64)

        observation_index_points = (np.random.uniform(
            -1., 1., (1, 1, 7, 2)).astype(np.float64))
        observations = np.random.uniform(-1., 1., (1, 1, 7)).astype(np.float64)

        index_points = np.random.uniform(-1., 1., (6, 2)).astype(np.float64)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)

        precomputed_stprm = tfd.StudentTProcessRegressionModel.precompute_regression_model(
            df=3.,
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance,
            validate_args=True)

        # Round-trip through tf.nest with expand_composites=True.
        flat = tf.nest.flatten(precomputed_stprm, expand_composites=True)
        unflat = tf.nest.pack_sequence_as(precomputed_stprm,
                                          flat,
                                          expand_composites=True)
        self.assertIsInstance(unflat, tfd.StudentTProcessRegressionModel)
        # Check that we don't recompute the divisor matrix on flattening /
        # unflattening.
        self.assertIs(
            precomputed_stprm.kernel.schur_complement.
            _precomputed_divisor_matrix_cholesky,  # pylint:disable=line-too-long
            unflat.kernel.schur_complement._precomputed_divisor_matrix_cholesky
        )
  def testCustomMarginalFn(self):
    """A user-supplied `marginal_fn` drives the marginal distribution."""
    def test_marginal_fn(
        loc,
        covariance,
        validate_args=False,
        allow_nan_stats=False,
        name="custom_marginal"):
      # Keep only the diagonal of the covariance matrix.
      scale_diag = tf.math.sqrt(tf.linalg.diag_part(covariance))
      return tfd.MultivariateNormalDiag(
          loc=loc,
          scale_diag=scale_diag,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)

    index_points = np.expand_dims(np.random.uniform(-1., 1., 10), -1)
    gp = tfd.GaussianProcess(
        kernel=psd_kernels.ExponentiatedQuadratic(),
        index_points=index_points,
        marginal_fn=test_marginal_fn,
        validate_args=True)

    # The default ExponentiatedQuadratic kernel has unit amplitude, so the
    # diagonal covariance of the marginal is the identity.
    self.assertAllClose(
        np.eye(10),
        gp.get_marginal_distribution().covariance())
    def testMeanSameAsGPRM(self):
        """STPRM posterior mean matches GPRM's for the same data and kernel."""
        df = np.float64(3.)
        # 5x5 grid of index points in R^2, flattened to [25, 2].
        index_points = np.linspace(-4., 4., 5, dtype=np.float64)
        index_points = np.stack(np.meshgrid(index_points, index_points),
                                axis=-1)
        index_points = np.reshape(index_points, [-1, 2])

        # Kernel with batch_shape [5, 3]
        amplitude = np.array([1., 2., 3., 4., 5.], np.float64).reshape([5, 1])
        length_scale = np.array([.1, .2, .3], np.float64).reshape([1, 3])
        observation_noise_variance = np.array([1e-5, 1e-6, 1e-9],
                                              np.float64).reshape([1, 3])

        observation_index_points = (np.random.uniform(
            -1., 1., (3, 7, 2)).astype(np.float64))
        observations = np.random.uniform(-1., 1., (3, 7)).astype(np.float64)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
        stprm = tfd.StudentTProcessRegressionModel(
            df=df,
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance)
        gprm = tfd.GaussianProcessRegressionModel(
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance)

        # The Student-T and Gaussian posterior predictive means coincide.
        self.assertAllClose(self.evaluate(stprm.mean()),
                            self.evaluate(gprm.mean()))
  def testAlwaysYieldMultivariateNormal(self):
    """`always_yield_multivariate_normal` controls the event shape.

    With a single index point per batch member, the default yields a scalar
    event; forcing multivariate-normal yields an event shape of [1].
    """
    for always_mvn, expected_event_shape in [(False, []), (True, [1])]:
      gp = tfd.GaussianProcess(
          kernel=psd_kernels.ExponentiatedQuadratic(),
          index_points=tf.ones([5, 1, 2]),
          always_yield_multivariate_normal=always_mvn,
      )
      self.assertAllEqual([5], self.evaluate(gp.batch_shape_tensor()))
      self.assertAllEqual(expected_event_shape,
                          self.evaluate(gp.event_shape_tensor()))
# --- 示例 #7 (Example 7): separator artifact from the original aggregation ---
    def testVarianceAndCovarianceMatrix(self):
        """GP covariance equals the kernel matrix plus observation noise."""
        amplitude = np.float64(.5)
        length_scale = np.float64(.2)
        jitter = np.float64(1e-4)
        observation_noise_variance = np.float64(3e-3)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
        index_points = np.expand_dims(np.random.uniform(-1., 1., 10), -1)

        gp = tfd.GaussianProcess(
            kernel,
            index_points,
            observation_noise_variance=observation_noise_variance,
            jitter=jitter,
            validate_args=True)

        def _np_kernel(x, y):
            # NumPy re-implementation of the ExponentiatedQuadratic kernel.
            sq_dist = np.squeeze((x - y)**2)
            return amplitude**2 * np.exp(-.5 * sq_dist / length_scale**2)

        # Observation noise appears on the diagonal of the covariance.
        expected_covariance = (
            _np_kernel(np.expand_dims(index_points, 0),
                       np.expand_dims(index_points, 1)) +
            observation_noise_variance * np.eye(10))

        self.assertAllClose(expected_covariance,
                            self.evaluate(gp.covariance()))
        self.assertAllClose(np.diag(expected_covariance),
                            self.evaluate(gp.variance()))
  def testShapes(self):
    """Batch/event/sample shapes broadcast correctly, static or dynamic.

    Kernel params broadcast to batch [2, 4, 3, 1]; index points add a
    leading 6, giving a model batch shape of [2, 4, 3, 6] and event [25].
    """
    # 5x5 grid of index points in R^2 and flatten to 25x2
    index_points = np.linspace(-4., 4., 5, dtype=np.float32)
    index_points = np.stack(np.meshgrid(index_points, index_points), axis=-1)
    index_points = np.reshape(index_points, [-1, 2])
    # ==> shape = [25, 2]

    # Kernel with batch_shape [2, 4, 3, 1]
    amplitude = np.array([1., 2.], np.float32).reshape([2, 1, 1, 1])
    length_scale = np.array([1., 2., 3., 4.], np.float32).reshape([1, 4, 1, 1])
    observation_noise_variance = np.array(
        [1e-5, 1e-6, 1e-5], np.float32).reshape([1, 1, 3, 1])
    batched_index_points = np.stack([index_points]*6)
    # ==> shape = [6, 25, 2]
    if not self.is_static:
      # Hide all static shape information to exercise dynamic-shape paths.
      amplitude = tf1.placeholder_with_default(amplitude, shape=None)
      length_scale = tf1.placeholder_with_default(length_scale, shape=None)
      batched_index_points = tf1.placeholder_with_default(
          batched_index_points, shape=None)
    kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
    gp = tfd.GaussianProcess(
        kernel,
        batched_index_points,
        observation_noise_variance=observation_noise_variance,
        jitter=1e-5,
        validate_args=True)

    batch_shape = [2, 4, 3, 6]
    event_shape = [25]
    sample_shape = [5, 3]

    samples = gp.sample(sample_shape, seed=test_util.test_seed())

    if self.is_static or tf.executing_eagerly():
      # Static shapes are fully known: check them directly.
      self.assertAllEqual(gp.batch_shape_tensor(), batch_shape)
      self.assertAllEqual(gp.event_shape_tensor(), event_shape)
      self.assertAllEqual(samples.shape,
                          sample_shape + batch_shape + event_shape)
      self.assertAllEqual(gp.batch_shape, batch_shape)
      self.assertAllEqual(gp.event_shape, event_shape)
      self.assertAllEqual(samples.shape,
                          sample_shape + batch_shape + event_shape)
      self.assertAllEqual(gp.mean().shape, batch_shape + event_shape)
      self.assertAllEqual(gp.variance().shape, batch_shape + event_shape)
    else:
      # Static shapes are unknown: only the evaluated shapes are checkable.
      self.assertAllEqual(self.evaluate(gp.batch_shape_tensor()), batch_shape)
      self.assertAllEqual(self.evaluate(gp.event_shape_tensor()), event_shape)
      self.assertAllEqual(
          self.evaluate(samples).shape,
          sample_shape + batch_shape + event_shape)
      self.assertIsNone(tensorshape_util.rank(samples.shape))
      self.assertIsNone(tensorshape_util.rank(gp.batch_shape))
      self.assertEqual(tensorshape_util.rank(gp.event_shape), 1)
      self.assertIsNone(
          tf.compat.dimension_value(tensorshape_util.dims(gp.event_shape)[0]))
      self.assertAllEqual(
          self.evaluate(tf.shape(gp.mean())), batch_shape + event_shape)
      self.assertAllEqual(self.evaluate(
          tf.shape(gp.variance())), batch_shape + event_shape)
# --- 示例 #9 (Example 9): separator artifact from the original aggregation ---
 def testOneOfCholeskyAndMarginalFn(self):
     """Supplying both `marginal_fn` and `cholesky_fn` raises ValueError."""
     with self.assertRaises(ValueError):
         index_points = np.array([3., 4., 5.])[..., np.newaxis]
         tfd.GaussianProcess(kernel=psd_kernels.ExponentiatedQuadratic(),
                             index_points=index_points,
                             marginal_fn=lambda x: x,
                             cholesky_fn=lambda x: x,
                             validate_args=True)
    def testPrecomputedWithMasking(self):
        """Masked precomputed GPRM matches unmasked GPRMs per batch member.

        Builds a precomputed GaussianProcessRegressionModel with
        `observations_is_missing`, then checks each batch member against a
        model conditioned on only the not-masked-out observations.
        """
        # Parameter batch shapes broadcast to [3, 2] (leading dims).
        amplitude = np.array([1., 2.], np.float64)
        length_scale = np.array([[.1], [.2], [.3]], np.float64)
        observation_noise_variance = np.array([[1e-2], [1e-4], [1e-6]],
                                              np.float64)

        rng = test_util.test_np_rng()
        observations_is_missing = np.array([
            [False, True, False, True, False, True],
            [False, False, False, False, False, False],
            [True, True, False, False, True, True],
        ]).reshape((3, 1, 6))
        # Masked entries are NaN so any accidental use would surface.
        observation_index_points = np.where(
            observations_is_missing[..., np.newaxis], np.nan,
            rng.uniform(-1., 1., (3, 1, 6, 2)).astype(np.float64))
        observations = np.where(
            observations_is_missing, np.nan,
            rng.uniform(-1., 1., (3, 1, 6)).astype(np.float64))

        index_points = rng.uniform(-1., 1., (5, 2)).astype(np.float64)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
        gprm = tfd.GaussianProcessRegressionModel.precompute_regression_model(
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observations_is_missing=observations_is_missing,
            observation_noise_variance=observation_noise_variance,
            validate_args=True)

        # NaNs in the masked observations must not leak into the statistics.
        self.assertAllNotNan(gprm.mean())
        self.assertAllNotNan(gprm.variance())
        self.assertAllNotNan(gprm.covariance())

        # For each batch member of `gprm`, check that the distribution is the same
        # as a GaussianProcessRegressionModel with no masking but conditioned on
        # only the not-masked-out index points.
        x = gprm.sample(seed=test_util.test_seed())
        for i in range(3):
            observation_index_points_i = tf.gather(
                observation_index_points[i, 0],
                (~observations_is_missing[i, 0]).nonzero()[0])
            observations_i = tf.gather(
                observations[i,
                             0], (~observations_is_missing[i, 0]).nonzero()[0])
            gprm_i = tfd.GaussianProcessRegressionModel.precompute_regression_model(
                kernel=kernel[i],
                index_points=index_points,
                observation_index_points=observation_index_points_i,
                observations=observations_i,
                observation_noise_variance=observation_noise_variance[i, 0],
                validate_args=True)

            self.assertAllClose(gprm.mean()[i], gprm_i.mean())
            self.assertAllClose(gprm.variance()[i], gprm_i.variance())
            self.assertAllClose(gprm.covariance()[i], gprm_i.covariance())
            self.assertAllClose(gprm.log_prob(x)[i], gprm_i.log_prob(x[i]))
    def testCopy(self):
        """`copy` overrides the given parameters and shares the rest.

        Copies a StudentTProcess with a new `df`, `kernel`, and
        `index_points`, then verifies the overridden parameters changed while
        non-overridden ones (`mean_fn`, `jitter`) carried over.
        """
        # 5 random index points in R^2
        index_points_1 = np.random.uniform(-4., 4., (5, 2)).astype(np.float32)
        # 10 random index points in R^2
        index_points_2 = np.random.uniform(-4., 4., (10, 2)).astype(np.float32)

        if not self.is_static:
            # Hide static shapes to exercise the dynamic-shape code path.
            index_points_1 = tf1.placeholder_with_default(index_points_1,
                                                          shape=None)
            index_points_2 = tf1.placeholder_with_default(index_points_2,
                                                          shape=None)

        mean_fn = lambda x: np.array([0.], np.float32)
        kernel_1 = psd_kernels.ExponentiatedQuadratic()
        kernel_2 = psd_kernels.ExpSinSquared()

        tp1 = tfd.StudentTProcess(df=3.,
                                  kernel=kernel_1,
                                  index_points=index_points_1,
                                  mean_fn=mean_fn,
                                  jitter=1e-5,
                                  validate_args=True)
        tp2 = tp1.copy(df=4., index_points=index_points_2, kernel=kernel_2)

        event_shape_1 = [5]
        event_shape_2 = [10]

        # mean_fn was not overridden, so the copy shares it.
        self.assertEqual(tp1.mean_fn, tp2.mean_fn)
        self.assertIsInstance(tp1.kernel, psd_kernels.ExponentiatedQuadratic)
        self.assertIsInstance(tp2.kernel, psd_kernels.ExpSinSquared)

        if self.is_static or tf.executing_eagerly():
            self.assertAllEqual(tp1.batch_shape, tp2.batch_shape)
            self.assertAllEqual(tp1.event_shape, event_shape_1)
            self.assertAllEqual(tp2.event_shape, event_shape_2)
            self.assertEqual(self.evaluate(tp1.df), 3.)
            self.assertEqual(self.evaluate(tp2.df), 4.)
            self.assertAllEqual(tp1.index_points, index_points_1)
            self.assertAllEqual(tp2.index_points, index_points_2)
            self.assertAllEqual(tf.get_static_value(tp1.jitter),
                                tf.get_static_value(tp2.jitter))
        else:
            self.assertAllEqual(self.evaluate(tp1.batch_shape_tensor()),
                                self.evaluate(tp2.batch_shape_tensor()))
            self.assertAllEqual(self.evaluate(tp1.event_shape_tensor()),
                                event_shape_1)
            self.assertAllEqual(self.evaluate(tp2.event_shape_tensor()),
                                event_shape_2)
            self.assertEqual(self.evaluate(tp1.jitter),
                             self.evaluate(tp2.jitter))
            self.assertEqual(self.evaluate(tp1.df), 3.)
            self.assertEqual(self.evaluate(tp2.df), 4.)
            self.assertAllEqual(self.evaluate(tp1.index_points),
                                index_points_1)
            self.assertAllEqual(self.evaluate(tp2.index_points),
                                index_points_2)
    def testInitParameterVariations(self, noise_kwargs, implied_values):
        """Parameterized check of GPRM noise-parameter plumbing.

        Args:
          noise_kwargs: dict of noise-related constructor kwargs to forward
            to `tfd.GaussianProcessRegressionModel`.
          implied_values: dict with keys
            'observation_noise_variance_parameter' and
            'predictive_noise_variance_parameter' giving the effective
            parameter values implied by `noise_kwargs`.
        """
        num_test_points = 3
        num_obs_points = 4
        kernel = psd_kernels.ExponentiatedQuadratic()
        index_points = np.random.uniform(-1., 1., (num_test_points, 1))
        observation_index_points = np.random.uniform(-1., 1.,
                                                     (num_obs_points, 1))
        observations = np.random.uniform(-1., 1., num_obs_points)
        jitter = 1e-6

        gprm = tfd.GaussianProcessRegressionModel(
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            jitter=jitter,
            validate_args=True,
            **noise_kwargs)

        # 'Property' means what was passed to CTOR. 'Parameter' is the effective
        # value by which the distribution is parameterized.
        implied_onv_param = implied_values[
            'observation_noise_variance_parameter']
        implied_pnv_param = implied_values[
            'predictive_noise_variance_parameter']

        # k_xx - k_xn @ (k_nn + ONV) @ k_nx + PNV
        # `_np_kernel_matrix_fn` is a module-level NumPy helper (defined
        # elsewhere in this file).
        k = lambda x, y: _np_kernel_matrix_fn(1., 1., x, y)
        k_tt_ = k(index_points, index_points)
        k_tx_ = k(index_points, observation_index_points)
        k_xx_plus_noise_ = (
            k(observation_index_points, observation_index_points) +
            (jitter + implied_onv_param) * np.eye(num_obs_points))

        expected_predictive_covariance = (
            k_tt_ - np.dot(k_tx_, np.linalg.solve(k_xx_plus_noise_, k_tx_.T)) +
            implied_pnv_param * np.eye(num_test_points))

        # Assertion 1: predictive covariance is correct.
        self.assertAllClose(self.evaluate(gprm.covariance()),
                            expected_predictive_covariance)

        # Assertion 2: predictive_noise_variance property is correct
        self.assertIsInstance(gprm.predictive_noise_variance, tf.Tensor)
        self.assertAllClose(self.evaluate(gprm.predictive_noise_variance),
                            implied_pnv_param)

        # Assertion 3: observation_noise_variance property is correct.
        self.assertIsInstance(gprm.observation_noise_variance, tf.Tensor)
        self.assertAllClose(
            self.evaluate(gprm.observation_noise_variance),
            # Note that this is, somewhat unintuitively, expceted to equal the
            # predictive_noise_variance. This is because of 1) the inheritance
            # structure of GPRM as a subclass of GaussianProcess and 2) the poor
            # choice of name of the GaussianProcess noise parameter. The latter
            # issue is being cleaned up in cl/256413439.
            implied_pnv_param)
 def testMean(self):
   """`gp.mean()` equals `mean_fn` evaluated at the index points."""
   mean_fn = lambda x: x[:, 0]**2
   kernel = psd_kernels.ExponentiatedQuadratic()
   index_points = np.expand_dims(np.random.uniform(-1., 1., 10), -1)
   gp = tfd.GaussianProcess(
       kernel, index_points, mean_fn=mean_fn, validate_args=True)
   expected_mean = mean_fn(index_points)
   self.assertAllClose(expected_mean,
                       self.evaluate(gp.mean()))
  def testOptimalVariationalShapes(self):
    """`optimal_variational_posterior` broadcasts batch shapes correctly."""
    # 5x5 grid of observation index points in R^2 and flatten to 25x2
    observation_index_points = np.linspace(-4., 4., 5, dtype=np.float64)
    observation_index_points = np.stack(
        np.meshgrid(
            observation_index_points, observation_index_points), axis=-1)
    observation_index_points = np.reshape(
        observation_index_points, [-1, 2])
    # ==> shape = [25, 2]
    observation_index_points = np.expand_dims(
        np.stack([observation_index_points]*6), -3)
    # ==> shape = [6, 1, 25, 2]
    observations = np.sin(observation_index_points[..., 0])
    # ==> shape = [6, 1, 25]

    # 9 inducing index points in R^2
    inducing_index_points = np.linspace(-4., 4., 3, dtype=np.float64)
    inducing_index_points = np.stack(np.meshgrid(inducing_index_points,
                                                 inducing_index_points),
                                     axis=-1)
    inducing_index_points = np.reshape(inducing_index_points, [-1, 2])
    # ==> shape = [9, 2]

    # Kernel with batch_shape [2, 4, 1, 1]
    amplitude = np.array([1., 2.], np.float64).reshape([2, 1, 1, 1])
    length_scale = np.array([.1, .2, .3, .4], np.float64).reshape([1, 4, 1, 1])

    jitter = np.float64(1e-6)
    observation_noise_variance = np.float64(1e-2)

    if not self.is_static:
      # Hide static shapes to exercise the dynamic-shape code path.
      amplitude = tf1.placeholder_with_default(amplitude, shape=None)
      length_scale = tf1.placeholder_with_default(length_scale, shape=None)
      observation_index_points = tf1.placeholder_with_default(
          observation_index_points, shape=None)

      inducing_index_points = tf1.placeholder_with_default(
          inducing_index_points, shape=None)
    kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)

    loc, scale = tfd.VariationalGaussianProcess.optimal_variational_posterior(
        kernel=kernel,
        inducing_index_points=inducing_index_points,
        observation_index_points=observation_index_points,
        observations=observations,
        observation_noise_variance=observation_noise_variance,
        jitter=jitter,
    )
    # We should expect that loc has shape [2, 4, 6, 1, 9]. This is because:
    # * [2, 4] comes from the batch shape of the kernel.
    # * [6, 1] comes from the batch shape of the observations / observation
    # index points.
    # * [9] comes from the number of inducing points.
    # Similar reasoning applies to scale.
    self.assertAllEqual([2, 4, 6, 1, 9], tf.shape(loc))
    self.assertAllEqual([2, 4, 6, 1, 9, 9], tf.shape(scale))
# --- 示例 #15 (Example 15): separator artifact from the original aggregation ---
    def testMarginalHasCorrectTypes(self):
        """Marginal is Normal for one point, MVN for several points."""
        gp = tfd.GaussianProcess(kernel=psd_kernels.ExponentiatedQuadratic(),
                                 validate_args=True)

        single_point = np.ones([1, 1], dtype=np.float32)
        many_points = np.ones([10, 1], dtype=np.float32)

        # One index point collapses to a scalar Normal marginal.
        self.assertIsInstance(
            gp.get_marginal_distribution(index_points=single_point),
            tfd.Normal)

        # Multiple index points yield a multivariate normal marginal.
        self.assertIsInstance(
            gp.get_marginal_distribution(index_points=many_points),
            tfd.MultivariateNormalLinearOperator)
    def testMarginalHasCorrectTypes(self):
        """Marginal is StudentT for one point, MVT for several points."""
        tp = tfd.StudentTProcess(df=3.,
                                 kernel=psd_kernels.ExponentiatedQuadratic(),
                                 validate_args=True)

        single_point = np.ones([1, 1], dtype=np.float32)
        many_points = np.ones([10, 1], dtype=np.float32)

        # One index point collapses to a scalar StudentT marginal.
        self.assertIsInstance(
            tp.get_marginal_distribution(index_points=single_point),
            tfd.StudentT)

        # Multiple index points yield a multivariate Student-T marginal.
        self.assertIsInstance(
            tp.get_marginal_distribution(index_points=many_points),
            tfd.MultivariateStudentTLinearOperator)
    def testInstantiate(self):
        """STPRM instantiation: batch/event/sample shapes and cholesky_fn.

        Parameter batch shapes broadcast to [2, 4, 1, 3]; the 25 grid index
        points give event shape [25]; sampling prepends the sample shape.
        Also checks that a custom `cholesky_fn` is stored as-is.
        """
        df = np.float64(1.)
        # 5x5 grid of index points in R^2 and flatten to 25x2
        index_points = np.linspace(-4., 4., 5, dtype=np.float64)
        index_points = np.stack(np.meshgrid(index_points, index_points),
                                axis=-1)
        index_points = np.reshape(index_points, [-1, 2])
        # ==> shape = [25, 2]

        # Parameter batch shapes broadcast to [2, 4, 1, 3].
        amplitude = np.array([1., 2.], np.float64).reshape([2, 1, 1, 1])
        length_scale = np.array([.1, .2, .3, .4],
                                np.float64).reshape([1, 4, 1, 1])
        observation_noise_variance = np.array([1e-5, 1e-6, 1e-9],
                                              np.float64).reshape([1, 1, 1, 3])

        observation_index_points = (np.random.uniform(
            -1., 1., (3, 7, 2)).astype(np.float64))
        observations = np.random.uniform(-1., 1., (3, 7)).astype(np.float64)

        def cholesky_fn(x):
            # Adds 1 to the diagonal before factorizing; distinguishable from
            # the default so we can verify it is plumbed through.
            return tf.linalg.cholesky(
                tf.linalg.set_diag(x,
                                   tf.linalg.diag_part(x) + 1.))

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
        stprm = tfd.StudentTProcessRegressionModel(
            df=df,
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance,
            cholesky_fn=cholesky_fn)
        batch_shape = [2, 4, 1, 3]
        event_shape = [25]
        sample_shape = [7, 2]

        self.assertIs(cholesky_fn, stprm.cholesky_fn)

        samples = stprm.sample(sample_shape, seed=test_util.test_seed())
        self.assertAllEqual(stprm.batch_shape_tensor(), batch_shape)
        self.assertAllEqual(stprm.event_shape_tensor(), event_shape)
        self.assertAllEqual(
            self.evaluate(samples).shape,
            sample_shape + batch_shape + event_shape)
  def testGPPosteriorPredictive(self):
    """`posterior_predictive` matches an explicitly constructed GPRM."""
    amplitude = np.float64(.5)
    length_scale = np.float64(2.)
    jitter = np.float64(1e-4)
    observation_noise_variance = np.float64(3e-3)
    kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)

    index_points = np.random.uniform(-1., 1., 10)[..., np.newaxis]

    gp = tfd.GaussianProcess(
        kernel,
        index_points,
        observation_noise_variance=observation_noise_variance,
        jitter=jitter,
        validate_args=True)

    predictive_index_points = np.random.uniform(1., 2., 10)[..., np.newaxis]
    observations = np.linspace(1., 10., 10)

    # The reference: a GPRM conditioned on the same observations, evaluated
    # at the same predictive index points.
    expected_gprm = tfd.GaussianProcessRegressionModel(
        kernel=kernel,
        observation_index_points=index_points,
        observations=observations,
        observation_noise_variance=observation_noise_variance,
        jitter=jitter,
        index_points=predictive_index_points,
        validate_args=True)

    actual_gprm = gp.posterior_predictive(
        predictive_index_points=predictive_index_points,
        observations=observations)

    samples = self.evaluate(actual_gprm.sample(10, seed=test_util.test_seed()))

    # Mean, covariance, and log_prob of both models should agree.
    self.assertAllClose(
        self.evaluate(expected_gprm.mean()),
        self.evaluate(actual_gprm.mean()))

    self.assertAllClose(
        self.evaluate(expected_gprm.covariance()),
        self.evaluate(actual_gprm.covariance()))

    self.assertAllClose(
        self.evaluate(expected_gprm.log_prob(samples)),
        self.evaluate(actual_gprm.log_prob(samples)))
# --- 示例 #19 (Example 19): separator artifact from the original aggregation ---
  def testEmptyDataMatchesGPPrior(self):
    """A GPRM with no observations reduces to the GP prior.

    Checks both the `None`-observations form and the explicit
    zero-length-observations form against a plain GaussianProcess.
    """
    amp = np.float64(.5)
    len_scale = np.float64(.2)
    jitter = np.float64(1e-4)
    index_points = np.random.uniform(-1., 1., (10, 1)).astype(np.float64)

    # k_xx - k_xn @ (k_nn + sigma^2) @ k_nx + sigma^2
    mean_fn = lambda x: x[:, 0]**2

    kernel = psd_kernels.ExponentiatedQuadratic(amp, len_scale)
    gp = tfd.GaussianProcess(
        kernel,
        index_points,
        mean_fn=mean_fn,
        jitter=jitter,
        validate_args=True)

    # No observation arguments at all.
    gprm_nones = tfd.GaussianProcessRegressionModel(
        kernel,
        index_points,
        mean_fn=mean_fn,
        jitter=jitter,
        validate_args=True)

    # Explicitly empty (zero-length) observations.
    gprm_zero_shapes = tfd.GaussianProcessRegressionModel(
        kernel,
        index_points,
        observation_index_points=tf.ones([5, 0, 1], tf.float64),
        observations=tf.ones([5, 0], tf.float64),
        mean_fn=mean_fn,
        jitter=jitter,
        validate_args=True)

    for gprm in [gprm_nones, gprm_zero_shapes]:
      self.assertAllClose(self.evaluate(gp.mean()), self.evaluate(gprm.mean()))
      self.assertAllClose(self.evaluate(gp.covariance()),
                          self.evaluate(gprm.covariance()))
      self.assertAllClose(self.evaluate(gp.variance()),
                          self.evaluate(gprm.variance()))

      observations = np.random.uniform(-1., 1., 10).astype(np.float64)
      self.assertAllClose(self.evaluate(gp.log_prob(observations)),
                          self.evaluate(gprm.log_prob(observations)))
# --- 示例 #20 (Example 20): separator artifact from the original aggregation ---
    def testCustomCholeskyFn(self):
        """A user-supplied `cholesky_fn` is applied to the kernel matrix."""
        def shifted_cholesky(x):
            # Adds 3 to the diagonal before factorizing.
            return tf.linalg.cholesky(
                tf.linalg.set_diag(x, tf.linalg.diag_part(x) + 3.))

        # Make sure the points are far away so that this is roughly diagonal.
        index_points = np.array([-100., -50., 50., 100])[..., np.newaxis]

        gp = tfd.GaussianProcess(kernel=psd_kernels.ExponentiatedQuadratic(),
                                 index_points=index_points,
                                 cholesky_fn=shifted_cholesky,
                                 validate_args=True)

        # Roughly, the kernel matrix will look like the identity matrix.
        # When we add 3 to the diagonal, this leads to 2's on the diagonal
        # for the cholesky factor.
        self.assertAllClose(2 * np.ones([4], dtype=np.float64),
                            gp.get_marginal_distribution().stddev())
    def testEmptyDataMatchesStPPrior(self):
        """An STPRM with no observations reduces to the StudentTProcess prior.

        Checks both the `None`-observations form and the explicit
        zero-length-observations form against a plain StudentTProcess.
        """
        df = np.float64(3.5)
        amp = np.float64(.5)
        len_scale = np.float64(.2)
        index_points = np.random.uniform(-1., 1., (10, 1)).astype(np.float64)

        # k_xx - k_xn @ (k_nn + sigma^2) @ k_nx + sigma^2
        mean_fn = lambda x: x[:, 0]**2

        kernel = psd_kernels.ExponentiatedQuadratic(amp, len_scale)
        stp = tfd.StudentTProcess(df,
                                  kernel,
                                  index_points,
                                  mean_fn=mean_fn,
                                  validate_args=True)

        # No observation arguments at all.
        stprm_nones = tfd.StudentTProcessRegressionModel(
            df,
            kernel=kernel,
            index_points=index_points,
            mean_fn=mean_fn,
            validate_args=True)

        # Explicitly empty (zero-length) observations.
        stprm_zero_shapes = tfd.StudentTProcessRegressionModel(
            df,
            kernel=kernel,
            index_points=index_points,
            observation_index_points=tf.ones([0, 1], tf.float64),
            observations=tf.ones([0], tf.float64),
            mean_fn=mean_fn,
            validate_args=True)

        for stprm in [stprm_nones, stprm_zero_shapes]:
            self.assertAllClose(self.evaluate(stp.mean()),
                                self.evaluate(stprm.mean()))
            self.assertAllClose(self.evaluate(stp.covariance()),
                                self.evaluate(stprm.covariance()))
            self.assertAllClose(self.evaluate(stp.variance()),
                                self.evaluate(stprm.variance()))

            observations = np.random.uniform(-1., 1., 10).astype(np.float64)
            self.assertAllClose(self.evaluate(stp.log_prob(observations)),
                                self.evaluate(stprm.log_prob(observations)))
# Example #22 (scraper artifact: snippet-separator from the original code listing)
    def testLateBindingIndexPoints(self):
        """Statistics accept `index_points` at call time when unbound at init."""
        amp = np.float64(.5)
        len_scale = np.float64(.2)
        kernel = psd_kernels.ExponentiatedQuadratic(amp, len_scale)
        mean_fn = lambda x: x[:, 0]**2
        jitter = np.float64(1e-4)
        observation_noise_variance = np.float64(3e-3)

        # No index_points at construction; they are supplied per call below.
        gp = tfd.GaussianProcess(
            kernel=kernel,
            mean_fn=mean_fn,
            observation_noise_variance=observation_noise_variance,
            jitter=jitter,
            validate_args=True)

        index_points = np.random.uniform(-1., 1., [10, 1])

        self.assertAllClose(mean_fn(index_points),
                            self.evaluate(gp.mean(index_points=index_points)))

        # NumPy re-implementation of the ExponentiatedQuadratic kernel.
        def expected_kernel(x, y):
            sq_dist = np.squeeze((x - y)**2)
            return amp**2 * np.exp(-.5 * sq_dist / (len_scale**2))

        expected_covariance = (
            expected_kernel(np.expand_dims(index_points, -3),
                            np.expand_dims(index_points, -2))
            + observation_noise_variance * np.eye(10))
        expected_variance = np.diag(expected_covariance)

        self.assertAllClose(
            expected_covariance,
            self.evaluate(gp.covariance(index_points=index_points)))
        self.assertAllClose(
            expected_variance,
            self.evaluate(gp.variance(index_points=index_points)))
        self.assertAllClose(
            np.sqrt(expected_variance),
            self.evaluate(gp.stddev(index_points=index_points)))

        # With index_points bound neither at init nor at the call site,
        # computing a statistic must raise.
        with self.assertRaises(ValueError):
            gp.mean()
  def testCompositeTensor(self):
    """GaussianProcess survives a tf.nest flatten/pack round trip."""
    index_points = np.random.uniform(-1., 1., 10)[..., np.newaxis]
    gp = tfd.GaussianProcess(
        kernel=psd_kernels.ExponentiatedQuadratic(),
        index_points=index_points)

    # Round-trip the distribution through its composite-tensor components.
    pieces = tf.nest.flatten(gp, expand_composites=True)
    rebuilt = tf.nest.pack_sequence_as(gp, pieces, expand_composites=True)
    self.assertIsInstance(rebuilt, tfd.GaussianProcess)

    x = self.evaluate(gp.sample(3, seed=test_util.test_seed()))
    expected_lp = self.evaluate(gp.log_prob(x))

    # The rebuilt distribution must agree with the original.
    self.assertAllClose(self.evaluate(rebuilt.log_prob(x)), expected_lp)

    # Both the original and the rebuilt object work as tf.function arguments.
    @tf.function
    def call_log_prob(d):
      return d.log_prob(x)
    self.assertAllClose(expected_lp, call_log_prob(gp))
    self.assertAllClose(expected_lp, call_log_prob(rebuilt))
    def testErrorCases(self):
        """Constructor argument validation for GaussianProcessRegressionModel."""
        kernel = psd_kernels.ExponentiatedQuadratic()
        index_points = np.random.uniform(-1., 1., (10, 1)).astype(np.float64)
        observation_index_points = (np.random.uniform(-1., 1., (5, 1)).astype(
            np.float64))
        observations = np.random.uniform(-1., 1., 3).astype(np.float64)

        # `observation_index_points` and `observations` must be supplied
        # together (or both omitted); one without the other is an error.
        for bad_kwargs in [
            dict(observation_index_points=None,
                 observations=observations),
            dict(observation_index_points=observation_index_points,
                 observations=None),
        ]:
            with self.assertRaises(ValueError):
                tfd.GaussianProcessRegressionModel(kernel,
                                                   index_points,
                                                   validate_args=True,
                                                   **bad_kwargs)

        # A non-callable `mean_fn` is rejected.
        with self.assertRaises(ValueError):
            tfd.GaussianProcessRegressionModel(kernel,
                                               index_points,
                                               mean_fn=0.,
                                               validate_args=True)

        # Non-broadcastable observation shapes are rejected. Errors that depend
        # on dynamic shapes cannot be caught in graph mode, so only check when
        # shapes are static or we are executing eagerly.
        if self.is_static or tf.executing_eagerly():
            with self.assertRaises(ValueError):
                tfd.GaussianProcessRegressionModel(
                    kernel,
                    index_points,
                    observation_index_points=np.ones([2, 2, 2]),
                    observations=np.ones([5, 5]),
                    validate_args=True)
  def testUnivariateLogProbWithIsMissing(self):
    """`log_prob(..., is_missing=...)` zeroes out fully-missing batch members."""
    # Batch of two single-point index sets in R^2.
    index_points = tf.convert_to_tensor([[[0.0, 0.0]], [[0.5, 1.0]]])
    amplitude = tf.convert_to_tensor(1.1)
    length_scale = tf.convert_to_tensor(0.9)

    gp = tfd.GaussianProcess(
        kernel=psd_kernels.ExponentiatedQuadratic(
            amplitude, length_scale),
        index_points=index_points,
        mean_fn=lambda x: tf.reduce_mean(x, axis=-1),
        observation_noise_variance=.05,
        jitter=0.0)

    samples = gp.sample(3, seed=test_util.test_seed())
    full_lp = gp.log_prob(samples)

    # Nothing missing: identical to the plain log_prob.
    self.assertAllClose(full_lp,
                        gp.log_prob(samples, is_missing=[False, False]))
    # A fully-missing member contributes a log_prob of zero.
    self.assertAllClose(
        tf.convert_to_tensor([np.zeros((3, 2)), full_lp]),
        gp.log_prob(samples, is_missing=[[[True]], [[False]]]))
    self.assertAllClose(
        tf.convert_to_tensor(
            [[full_lp[0, 0], 0.0], [0.0, 0.0], [0., full_lp[2, 1]]]),
        gp.log_prob(
            samples,
            is_missing=[[False, True], [True, True], [True, False]]))
    def testLogProbNearGPRM(self):
        """STPRM log_prob approaches GPRM log_prob for very large `df`."""
        # At df = 1e6 the Student-T process is essentially Gaussian.
        df = np.float64(1e6)

        # 5x5 grid of 2-d index points, flattened to shape [25, 2].
        grid = np.linspace(-4., 4., 5, dtype=np.float64)
        index_points = np.reshape(
            np.stack(np.meshgrid(grid, grid), axis=-1), [-1, 2])

        # Kernel with batch_shape [5, 3].
        amplitude = np.array([1., 2., 3., 4., 5.], np.float64).reshape([5, 1])
        length_scale = np.array([.1, .2, .3], np.float64).reshape([1, 3])
        observation_noise_variance = np.array([1e-5, 1e-6, 1e-9],
                                              np.float64).reshape([1, 3])

        observation_index_points = (np.random.uniform(
            -1., 1., (3, 7, 2)).astype(np.float64))
        observations = np.random.uniform(-1., 1., (3, 7)).astype(np.float64)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
        # Both models share everything except the degrees of freedom.
        common_kwargs = dict(
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance)
        stprm = tfd.StudentTProcessRegressionModel(df=df, **common_kwargs)
        gprm = tfd.GaussianProcessRegressionModel(**common_kwargs)

        x = np.linspace(-3., 3., 25)

        self.assertAllClose(self.evaluate(stprm.log_prob(x)),
                            self.evaluate(gprm.log_prob(x)),
                            rtol=2e-5)
    def testMeanVarianceAndCovariancePrecomputed(self):
        """`precompute_regression_model` matches the lazily-computed STPRM."""
        amplitude = np.array([1., 2.], np.float64).reshape([2, 1])
        length_scale = np.array([.1, .2, .3], np.float64).reshape([1, 3])
        observation_noise_variance = np.array([1e-9], np.float64)
        df = np.float64(3.)

        observation_index_points = (np.random.uniform(
            -1., 1., (1, 1, 7, 2)).astype(np.float64))
        observations = np.random.uniform(-1., 1., (1, 1, 7)).astype(np.float64)
        index_points = np.random.uniform(-1., 1., (6, 2)).astype(np.float64)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
        # Identical constructor arguments for both build paths.
        shared_kwargs = dict(
            df=df,
            kernel=kernel,
            index_points=index_points,
            observation_index_points=observation_index_points,
            observations=observations,
            observation_noise_variance=observation_noise_variance,
            validate_args=True)

        stprm = tfd.StudentTProcessRegressionModel(**shared_kwargs)
        precomputed_stprm = (
            tfd.StudentTProcessRegressionModel.precompute_regression_model(
                **shared_kwargs))

        # Both models must agree on all first- and second-order statistics.
        self.assertAllClose(self.evaluate(precomputed_stprm.covariance()),
                            self.evaluate(stprm.covariance()))
        self.assertAllClose(self.evaluate(precomputed_stprm.variance()),
                            self.evaluate(stprm.variance()))
        self.assertAllClose(self.evaluate(precomputed_stprm.mean()),
                            self.evaluate(stprm.mean()))
    def testShapes(self):
        """Checks batch/event/sample shapes of a GPRM built from broadcast parts.

        Parameter batch shapes are chosen so they broadcast together to
        [2, 3, 5, 7, 11]; the event shape comes from the 25 index points.
        When `self.is_static` is false, all shape information is hidden behind
        placeholders to exercise the dynamic-shape code paths.
        """
        # We'll use a batch shape of [2, 3, 5, 7, 11]

        # 5x5 grid of index points in R^2 and flatten to 25x2
        index_points = np.linspace(-4., 4., 5, dtype=np.float64)
        index_points = np.stack(np.meshgrid(index_points, index_points),
                                axis=-1)
        index_points = np.reshape(index_points, [-1, 2])
        # ==> shape = [25, 2]
        batched_index_points = np.reshape(index_points, [1, 1, 25, 2])
        batched_index_points = np.stack([batched_index_points] * 5)
        # ==> shape = [5, 1, 1, 25, 2]

        # Kernel with batch_shape [2, 3, 1, 1, 1]
        amplitude = np.array([1., 2.], np.float64).reshape([2, 1, 1, 1, 1])
        length_scale = np.array([.1, .2, .3],
                                np.float64).reshape([1, 3, 1, 1, 1])
        observation_noise_variance = np.array([1e-9], np.float64).reshape(
            [1, 1, 1, 1, 1])

        jitter = np.float64(1e-6)
        # Observation arguments contribute the remaining batch dims (7 and 11).
        observation_index_points = (np.random.uniform(
            -1., 1., (7, 1, 7, 2)).astype(np.float64))
        observations = np.random.uniform(-1., 1., (11, 7)).astype(np.float64)

        def cholesky_fn(x):
            # Custom factorization; also used to verify the model retains the
            # exact callable it was given (assertIs below).
            return tf.linalg.cholesky(
                tf.linalg.set_diag(x,
                                   tf.linalg.diag_part(x) + 1.))

        if not self.is_static:
            # Rebind every parameter through a shape-None placeholder so no
            # static shape information is available to the model.
            amplitude = tf1.placeholder_with_default(amplitude, shape=None)
            length_scale = tf1.placeholder_with_default(length_scale,
                                                        shape=None)
            batched_index_points = tf1.placeholder_with_default(
                batched_index_points, shape=None)

            observation_index_points = tf1.placeholder_with_default(
                observation_index_points, shape=None)
            observations = tf1.placeholder_with_default(observations,
                                                        shape=None)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)

        gprm = tfd.GaussianProcessRegressionModel(kernel,
                                                  batched_index_points,
                                                  observation_index_points,
                                                  observations,
                                                  observation_noise_variance,
                                                  cholesky_fn=cholesky_fn,
                                                  jitter=jitter,
                                                  validate_args=True)

        batch_shape = [2, 3, 5, 7, 11]
        event_shape = [25]
        sample_shape = [9, 3]

        samples = gprm.sample(sample_shape, seed=test_util.test_seed())

        # The model must hold the very callable we passed in.
        self.assertIs(cholesky_fn, gprm.cholesky_fn)

        if self.is_static or tf.executing_eagerly():
            # Static shapes are fully known at graph-build time.
            self.assertAllEqual(gprm.batch_shape_tensor(), batch_shape)
            self.assertAllEqual(gprm.event_shape_tensor(), event_shape)
            self.assertAllEqual(samples.shape,
                                sample_shape + batch_shape + event_shape)
            self.assertAllEqual(gprm.batch_shape, batch_shape)
            self.assertAllEqual(gprm.event_shape, event_shape)
            self.assertAllEqual(samples.shape,
                                sample_shape + batch_shape + event_shape)
        else:
            # Dynamic case: shape tensors evaluate correctly, but static shape
            # info is (mostly) unknown.
            self.assertAllEqual(self.evaluate(gprm.batch_shape_tensor()),
                                batch_shape)
            self.assertAllEqual(self.evaluate(gprm.event_shape_tensor()),
                                event_shape)
            self.assertAllEqual(
                self.evaluate(samples).shape,
                sample_shape + batch_shape + event_shape)
            self.assertIsNone(tensorshape_util.rank(samples.shape))
            self.assertIsNone(tensorshape_util.rank(gprm.batch_shape))
            # Event rank is statically 1, but its single dimension is unknown.
            self.assertEqual(tensorshape_util.rank(gprm.event_shape), 1)
            self.assertIsNone(
                tf.compat.dimension_value(
                    tensorshape_util.dims(gprm.event_shape)[0]))
    def testCopy(self):
        """`copy()` overrides the given GPRM constructor args, keeps the rest.

        Builds a GPRM, copies it with a new kernel/index points/observations,
        and checks both models are consistent with their respective arguments.
        Also checks that `precompute_regression_model(...).copy(...)` retains
        the identical `mean_fn` and `kernel` objects.
        """
        # 5 random index points in R^2
        index_points_1 = np.random.uniform(-4., 4., (5, 2)).astype(np.float32)
        # 10 random index points in R^2
        index_points_2 = np.random.uniform(-4., 4., (10, 2)).astype(np.float32)

        observation_index_points_1 = (np.random.uniform(
            -4., 4., (7, 2)).astype(np.float32))
        observation_index_points_2 = (np.random.uniform(
            -4., 4., (9, 2)).astype(np.float32))

        observations_1 = np.random.uniform(-1., 1., 7).astype(np.float32)
        observations_2 = np.random.uniform(-1., 1., 9).astype(np.float32)

        # In the dynamic-shape regime, hide all static shapes behind
        # placeholders.
        if not self.is_static:
            index_points_1 = tf1.placeholder_with_default(index_points_1,
                                                          shape=None)
            index_points_2 = tf1.placeholder_with_default(index_points_2,
                                                          shape=None)
            observation_index_points_1 = tf1.placeholder_with_default(
                observation_index_points_1, shape=None)
            observation_index_points_2 = tf1.placeholder_with_default(
                observation_index_points_2, shape=None)
            observations_1 = tf1.placeholder_with_default(observations_1,
                                                          shape=None)
            observations_2 = tf1.placeholder_with_default(observations_2,
                                                          shape=None)

        mean_fn = lambda x: np.array([0.], np.float32)
        kernel_1 = psd_kernels.ExponentiatedQuadratic()
        kernel_2 = psd_kernels.ExpSinSquared()

        gprm1 = tfd.GaussianProcessRegressionModel(
            kernel=kernel_1,
            index_points=index_points_1,
            observation_index_points=observation_index_points_1,
            observations=observations_1,
            mean_fn=mean_fn,
            jitter=1e-5,
            validate_args=True)
        # Copy, overriding everything data-related; jitter etc. carry over.
        gprm2 = gprm1.copy(kernel=kernel_2,
                           index_points=index_points_2,
                           observation_index_points=observation_index_points_2,
                           observations=observations_2)

        precomputed_gprm1 = (
            tfd.GaussianProcessRegressionModel.precompute_regression_model(
                kernel=kernel_1,
                index_points=index_points_1,
                observation_index_points=observation_index_points_1,
                observations=observations_1,
                mean_fn=mean_fn,
                jitter=1e-5,
                validate_args=True))
        precomputed_gprm2 = precomputed_gprm1.copy(index_points=index_points_2)
        # Un-overridden attributes must be the *same objects*, not copies.
        self.assertIs(precomputed_gprm1.mean_fn, precomputed_gprm2.mean_fn)
        self.assertIs(precomputed_gprm1.kernel, precomputed_gprm2.kernel)

        event_shape_1 = [5]
        event_shape_2 = [10]

        # The copy must carry the overridden base kernel.
        self.assertIsInstance(gprm1.kernel.base_kernel,
                              psd_kernels.ExponentiatedQuadratic)
        self.assertIsInstance(gprm2.kernel.base_kernel,
                              psd_kernels.ExpSinSquared)

        if self.is_static or tf.executing_eagerly():
            self.assertAllEqual(gprm1.batch_shape, gprm2.batch_shape)
            self.assertAllEqual(gprm1.event_shape, event_shape_1)
            self.assertAllEqual(gprm2.event_shape, event_shape_2)
            self.assertAllEqual(gprm1.index_points, index_points_1)
            self.assertAllEqual(gprm2.index_points, index_points_2)
            # jitter was not overridden, so it must be shared by the copy.
            self.assertAllEqual(tf.get_static_value(gprm1.jitter),
                                tf.get_static_value(gprm2.jitter))
        else:
            # Dynamic-shape path: evaluate everything before comparing.
            self.assertAllEqual(self.evaluate(gprm1.batch_shape_tensor()),
                                self.evaluate(gprm2.batch_shape_tensor()))
            self.assertAllEqual(self.evaluate(gprm1.event_shape_tensor()),
                                event_shape_1)
            self.assertAllEqual(self.evaluate(gprm2.event_shape_tensor()),
                                event_shape_2)
            self.assertEqual(self.evaluate(gprm1.jitter),
                             self.evaluate(gprm2.jitter))
            self.assertAllEqual(self.evaluate(gprm1.index_points),
                                index_points_1)
            self.assertAllEqual(self.evaluate(gprm2.index_points),
                                index_points_2)
    def testCopy(self):
        """`copy()` overrides the given STPRM constructor args, keeps the rest.

        NOTE(review): a method with this same name appears earlier in this
        file; presumably the two came from different test classes in the
        original sources (this file looks like a concatenation of separate
        snippets) — confirm, since duplicate names in one class would silently
        shadow the first definition.
        """
        # 5 random index points in R^2
        index_points_1 = np.random.uniform(-4., 4., (5, 2)).astype(np.float32)
        # 10 random index points in R^2
        index_points_2 = np.random.uniform(-4., 4., (10, 2)).astype(np.float32)

        observation_index_points_1 = (np.random.uniform(
            -4., 4., (7, 2)).astype(np.float32))
        observation_index_points_2 = (np.random.uniform(
            -4., 4., (9, 2)).astype(np.float32))

        observations_1 = np.random.uniform(-1., 1., 7).astype(np.float32)
        observations_2 = np.random.uniform(-1., 1., 9).astype(np.float32)

        mean_fn = lambda x: np.array([0.], np.float32)
        kernel_1 = psd_kernels.ExponentiatedQuadratic()
        kernel_2 = psd_kernels.ExpSinSquared()

        stprm1 = tfd.StudentTProcessRegressionModel(
            df=5.,
            kernel=kernel_1,
            index_points=index_points_1,
            observation_index_points=observation_index_points_1,
            observations=observations_1,
            mean_fn=mean_fn,
            validate_args=True)
        # Copy, overriding everything data-related; df etc. carry over.
        stprm2 = stprm1.copy(
            kernel=kernel_2,
            index_points=index_points_2,
            observation_index_points=observation_index_points_2,
            observations=observations_2)

        precomputed_stprm1 = (
            tfd.StudentTProcessRegressionModel.precompute_regression_model(
                df=5.,
                kernel=kernel_1,
                index_points=index_points_1,
                observation_index_points=observation_index_points_1,
                observations=observations_1,
                mean_fn=mean_fn,
                validate_args=True))
        precomputed_stprm2 = precomputed_stprm1.copy(
            index_points=index_points_2)
        # Un-overridden attributes must be the *same objects*, not copies.
        self.assertIs(precomputed_stprm1.mean_fn, precomputed_stprm2.mean_fn)
        self.assertIs(precomputed_stprm1.kernel, precomputed_stprm2.kernel)

        event_shape_1 = [5]
        event_shape_2 = [10]

        # The copy must carry the overridden base kernel.
        self.assertIsInstance(stprm1.kernel.schur_complement.base_kernel,
                              psd_kernels.ExponentiatedQuadratic)
        self.assertIsInstance(stprm2.kernel.schur_complement.base_kernel,
                              psd_kernels.ExpSinSquared)
        self.assertAllEqual(self.evaluate(stprm1.batch_shape_tensor()),
                            self.evaluate(stprm2.batch_shape_tensor()))
        self.assertAllEqual(self.evaluate(stprm1.event_shape_tensor()),
                            event_shape_1)
        self.assertAllEqual(self.evaluate(stprm2.event_shape_tensor()),
                            event_shape_2)
        self.assertAllEqual(self.evaluate(stprm1.index_points), index_points_1)
        self.assertAllEqual(self.evaluate(stprm2.index_points), index_points_2)