def testCopy(self):
        """Checks that `StudentTProcess.copy` overrides only the given args.

        `tp2` is derived from `tp1` with a new `df`, `kernel` and
        `index_points`; the `mean_fn` object and the `jitter` value must carry
        over unchanged. Shapes and values are compared through the static
        properties when `self.is_static` or eager execution is active, and
        through evaluated `*_tensor()` results otherwise.
        """
        # 5 random index points in R^2.
        index_points_1 = np.random.uniform(-4., 4., (5, 2)).astype(np.float32)
        # 10 random index points in R^2.
        index_points_2 = np.random.uniform(-4., 4., (10, 2)).astype(np.float32)

        # Dynamic-shape variant: wrap the inputs in placeholders declared with
        # `shape=None` so no static shape is attached to them.
        if not self.is_static:
            index_points_1 = tf1.placeholder_with_default(index_points_1,
                                                          shape=None)
            index_points_2 = tf1.placeholder_with_default(index_points_2,
                                                          shape=None)

        mean_fn = lambda x: np.array([0.], np.float32)
        kernel_1 = psd_kernels.ExponentiatedQuadratic()
        kernel_2 = psd_kernels.ExpSinSquared()

        tp1 = tfd.StudentTProcess(df=3.,
                                  kernel=kernel_1,
                                  index_points=index_points_1,
                                  mean_fn=mean_fn,
                                  jitter=1e-5,
                                  validate_args=True)
        tp2 = tp1.copy(df=4., index_points=index_points_2, kernel=kernel_2)

        event_shape_1 = [5]
        event_shape_2 = [10]

        # Arguments not passed to `copy` must be reused as-is.
        self.assertIs(tp1.mean_fn, tp2.mean_fn)
        self.assertIsInstance(tp1.kernel, psd_kernels.ExponentiatedQuadratic)
        self.assertIsInstance(tp2.kernel, psd_kernels.ExpSinSquared)

        if self.is_static or tf.executing_eagerly():
            self.assertAllEqual(tp1.batch_shape, tp2.batch_shape)
            self.assertAllEqual(tp1.event_shape, event_shape_1)
            self.assertAllEqual(tp2.event_shape, event_shape_2)
            self.assertEqual(self.evaluate(tp1.df), 3.)
            self.assertEqual(self.evaluate(tp2.df), 4.)
            self.assertAllEqual(tp1.index_points, index_points_1)
            self.assertAllEqual(tp2.index_points, index_points_2)
            self.assertAllEqual(tf.get_static_value(tp1.jitter),
                                tf.get_static_value(tp2.jitter))
        else:
            self.assertAllEqual(self.evaluate(tp1.batch_shape_tensor()),
                                self.evaluate(tp2.batch_shape_tensor()))
            self.assertAllEqual(self.evaluate(tp1.event_shape_tensor()),
                                event_shape_1)
            self.assertAllEqual(self.evaluate(tp2.event_shape_tensor()),
                                event_shape_2)
            self.assertEqual(self.evaluate(tp1.jitter),
                             self.evaluate(tp2.jitter))
            self.assertEqual(self.evaluate(tp1.df), 3.)
            self.assertEqual(self.evaluate(tp2.df), 4.)
            self.assertAllEqual(self.evaluate(tp1.index_points),
                                index_points_1)
            self.assertAllEqual(self.evaluate(tp2.index_points),
                                index_points_2)
    def testCopy(self):
        """Checks `StudentTProcessRegressionModel.copy`, plain and precomputed.

        A copy built with a new kernel, index points and observations must
        expose the new base kernel and index points. A copy of a precomputed
        model that only swaps `index_points` must keep the very same
        `mean_fn` and `kernel` objects.
        """
        uniform = np.random.uniform
        # Prediction locations in R^2: 5 for the original, 10 for the copy.
        index_points_1 = uniform(-4., 4., (5, 2)).astype(np.float32)
        index_points_2 = uniform(-4., 4., (10, 2)).astype(np.float32)
        # Conditioning locations in R^2: 7 for the original, 9 for the copy.
        obs_points_1 = uniform(-4., 4., (7, 2)).astype(np.float32)
        obs_points_2 = uniform(-4., 4., (9, 2)).astype(np.float32)
        # One scalar observation per conditioning location.
        obs_1 = uniform(-1., 1., 7).astype(np.float32)
        obs_2 = uniform(-1., 1., 9).astype(np.float32)

        mean_fn = lambda x: np.array([0.], np.float32)
        base_kernel_1 = psd_kernels.ExponentiatedQuadratic()
        base_kernel_2 = psd_kernels.ExpSinSquared()

        # Identical arguments build both the plain and the precomputed model.
        model_kwargs = dict(
            df=5.,
            kernel=base_kernel_1,
            index_points=index_points_1,
            observation_index_points=obs_points_1,
            observations=obs_1,
            mean_fn=mean_fn,
            validate_args=True)

        stprm1 = tfd.StudentTProcessRegressionModel(**model_kwargs)
        stprm2 = stprm1.copy(
            kernel=base_kernel_2,
            index_points=index_points_2,
            observation_index_points=obs_points_2,
            observations=obs_2)

        precomputed_stprm1 = (
            tfd.StudentTProcessRegressionModel.precompute_regression_model(
                **model_kwargs))
        precomputed_stprm2 = precomputed_stprm1.copy(
            index_points=index_points_2)
        # Only `index_points` changed, so these must be the same objects.
        self.assertIs(precomputed_stprm1.mean_fn, precomputed_stprm2.mean_fn)
        self.assertIs(precomputed_stprm1.kernel, precomputed_stprm2.kernel)

        # (model, expected base kernel type, event shape, index points)
        expectations = (
            (stprm1, psd_kernels.ExponentiatedQuadratic, [5], index_points_1),
            (stprm2, psd_kernels.ExpSinSquared, [10], index_points_2),
        )

        # The model's kernel exposes the user kernel through
        # `schur_complement.base_kernel`.
        for model, kernel_type, _, _ in expectations:
            self.assertIsInstance(model.kernel.schur_complement.base_kernel,
                                  kernel_type)
        self.assertAllEqual(self.evaluate(stprm1.batch_shape_tensor()),
                            self.evaluate(stprm2.batch_shape_tensor()))
        for model, _, event_shape, _ in expectations:
            self.assertAllEqual(self.evaluate(model.event_shape_tensor()),
                                event_shape)
        for model, _, _, points in expectations:
            self.assertAllEqual(self.evaluate(model.index_points), points)
    def testCopy(self):
        """Checks `GaussianProcessRegressionModel.copy`, plain and precomputed.

        A copy built with a new kernel, index points and observations must
        expose the new base kernel, event shape and index points while keeping
        the original batch shape and `jitter`. A copy of a precomputed model
        that only swaps `index_points` must keep the very same `mean_fn` and
        `kernel` objects.
        """
        uniform = np.random.uniform
        # Prediction locations in R^2: 5 for the original, 10 for the copy.
        index_points_1 = uniform(-4., 4., (5, 2)).astype(np.float32)
        index_points_2 = uniform(-4., 4., (10, 2)).astype(np.float32)
        # Conditioning locations in R^2: 7 for the original, 9 for the copy.
        obs_points_1 = uniform(-4., 4., (7, 2)).astype(np.float32)
        obs_points_2 = uniform(-4., 4., (9, 2)).astype(np.float32)
        # One scalar observation per conditioning location.
        obs_1 = uniform(-1., 1., 7).astype(np.float32)
        obs_2 = uniform(-1., 1., 9).astype(np.float32)

        if not self.is_static:
            # Wrap every input in a placeholder declared with `shape=None` so
            # no static shape is attached to it.
            (index_points_1, index_points_2, obs_points_1, obs_points_2,
             obs_1, obs_2) = [
                 tf1.placeholder_with_default(value, shape=None)
                 for value in (index_points_1, index_points_2, obs_points_1,
                               obs_points_2, obs_1, obs_2)]

        mean_fn = lambda x: np.array([0.], np.float32)
        base_kernel_1 = psd_kernels.ExponentiatedQuadratic()
        base_kernel_2 = psd_kernels.ExpSinSquared()

        # Identical arguments build both the plain and the precomputed model.
        model_kwargs = dict(
            kernel=base_kernel_1,
            index_points=index_points_1,
            observation_index_points=obs_points_1,
            observations=obs_1,
            mean_fn=mean_fn,
            jitter=1e-5,
            validate_args=True)

        gprm1 = tfd.GaussianProcessRegressionModel(**model_kwargs)
        gprm2 = gprm1.copy(kernel=base_kernel_2,
                           index_points=index_points_2,
                           observation_index_points=obs_points_2,
                           observations=obs_2)

        precomputed_gprm1 = (
            tfd.GaussianProcessRegressionModel.precompute_regression_model(
                **model_kwargs))
        precomputed_gprm2 = precomputed_gprm1.copy(index_points=index_points_2)
        # Only `index_points` changed, so these must be the same objects.
        self.assertIs(precomputed_gprm1.mean_fn, precomputed_gprm2.mean_fn)
        self.assertIs(precomputed_gprm1.kernel, precomputed_gprm2.kernel)

        # (model, expected base kernel type, event shape, index points)
        expectations = (
            (gprm1, psd_kernels.ExponentiatedQuadratic, [5], index_points_1),
            (gprm2, psd_kernels.ExpSinSquared, [10], index_points_2),
        )

        for model, kernel_type, _, _ in expectations:
            self.assertIsInstance(model.kernel.base_kernel, kernel_type)

        if not self.is_static and not tf.executing_eagerly():
            # Graph mode with unknown shapes: compare evaluated tensors.
            self.assertAllEqual(self.evaluate(gprm1.batch_shape_tensor()),
                                self.evaluate(gprm2.batch_shape_tensor()))
            for model, _, event_shape, _ in expectations:
                self.assertAllEqual(self.evaluate(model.event_shape_tensor()),
                                    event_shape)
            self.assertEqual(self.evaluate(gprm1.jitter),
                             self.evaluate(gprm2.jitter))
            for model, _, _, points in expectations:
                self.assertAllEqual(self.evaluate(model.index_points), points)
            return

        # Static shapes (or eager mode): compare the properties directly.
        self.assertAllEqual(gprm1.batch_shape, gprm2.batch_shape)
        for model, _, event_shape, _ in expectations:
            self.assertAllEqual(model.event_shape, event_shape)
        for model, _, _, points in expectations:
            self.assertAllEqual(model.index_points, points)
        self.assertAllEqual(tf.get_static_value(gprm1.jitter),
                            tf.get_static_value(gprm2.jitter))