Example #1
def _set_seed(seed):
    """Helper which uses graph seed if using eager."""
    # TODO(b/68017812): Deprecate once eager correctly supports seed.
    if tf.executing_eagerly():
        return None
    return seed
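A minimal usage sketch of the helper above (the distribution, the `tfp` import alias, and the seed value 17 are illustrative assumptions, not part of the original source):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def _set_seed(seed):
    # Same logic as the helper above: drop the per-op seed under eager execution.
    return None if tf.executing_eagerly() else seed

# The per-op seed is passed straight to `sample`; under eager execution it is
# suppressed, so reproducibility must come from a global seed instead.
normal = tfd.Normal(loc=0., scale=1.)
samples = normal.sample(3, seed=_set_seed(17))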
Example #2
def _maybe_seed():
    """Sets a global seed and returns None in eager mode; otherwise returns an op seed."""
    if tf.executing_eagerly():
        tf1.set_random_seed(42)
        return None
    return 42
Example #3
def variables_save(filename, variables):
    """Saves structure of `tf.Variable`s to `filename`."""
    if not tf.executing_eagerly():
        raise ValueError('Can only `save` while in eager mode.')
    np.savez_compressed(filename,
                        *[v.numpy() for v in tf.nest.flatten(variables)])
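A hedged sketch of the complementary load step (the name `variables_load` and its exact behavior are assumptions, not taken from the original source); it relies on `np.savez_compressed` storing positional arrays under the keys 'arr_0', 'arr_1', and so on:

import numpy as np
import tensorflow as tf

def variables_load(filename, variables):
    """Hypothetical counterpart: assigns arrays from `filename` to `variables`."""
    # Note: `np.savez_compressed` appends '.npz' when `filename` lacks that extension.
    with np.load(filename) as data:
        for i, v in enumerate(tf.nest.flatten(variables)):
            v.assign(data['arr_%d' % i])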
Example #4
def wrapped(*args, **kwargs):
    # `f` and `xla_best_effort` are captured from the enclosing decorator.
    maybe_xla_best_effort = (
        tf.xla.experimental.jit_scope(compile_ops=True)
        if not tf.executing_eagerly() and xla_best_effort
        else _dummy_context())
    with maybe_xla_best_effort:
        return f(*args, **kwargs)
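`_dummy_context` is not shown above; a minimal sketch of a no-op context manager that could serve that role (an assumption about its implementation, not the original code):

import contextlib

@contextlib.contextmanager
def _dummy_context():
    # No-op stand-in used when XLA compilation is not requested.
    yield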
Example #5
    def testShapes(self):
        # We'll use a batch shape of [2, 3, 5, 7, 11]

        # 5x5 grid of index points in R^2 and flatten to 25x2
        index_points = np.linspace(-4., 4., 5, dtype=np.float64)
        index_points = np.stack(np.meshgrid(index_points, index_points),
                                axis=-1)
        index_points = np.reshape(index_points, [-1, 2])
        # ==> shape = [25, 2]
        batched_index_points = np.reshape(index_points, [1, 1, 25, 2])
        batched_index_points = np.stack([batched_index_points] * 5)
        # ==> shape = [5, 1, 1, 25, 2]

        # Kernel with batch_shape [2, 3, 1, 1, 1]
        amplitude = np.array([1., 2.], np.float64).reshape([2, 1, 1, 1, 1])
        length_scale = np.array([.1, .2, .3],
                                np.float64).reshape([1, 3, 1, 1, 1])
        observation_noise_variance = np.array([1e-9], np.float64).reshape(
            [1, 1, 1, 1, 1])

        jitter = np.float64(1e-6)
        observation_index_points = (np.random.uniform(
            -1., 1., (7, 1, 7, 2)).astype(np.float64))
        observations = np.random.uniform(-1., 1., (11, 7)).astype(np.float64)

        if not self.is_static:
            amplitude = tf1.placeholder_with_default(amplitude, shape=None)
            length_scale = tf1.placeholder_with_default(length_scale,
                                                        shape=None)
            batched_index_points = tf1.placeholder_with_default(
                batched_index_points, shape=None)

            observation_index_points = tf1.placeholder_with_default(
                observation_index_points, shape=None)
            observations = tf1.placeholder_with_default(observations,
                                                        shape=None)

        kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)

        gprm = tfd.GaussianProcessRegressionModel(kernel,
                                                  batched_index_points,
                                                  observation_index_points,
                                                  observations,
                                                  observation_noise_variance,
                                                  jitter=jitter,
                                                  validate_args=True)

        batch_shape = [2, 3, 5, 7, 11]
        event_shape = [25]
        sample_shape = [9, 3]

        samples = gprm.sample(sample_shape, seed=test_util.test_seed())

        if self.is_static or tf.executing_eagerly():
            self.assertAllEqual(gprm.batch_shape_tensor(), batch_shape)
            self.assertAllEqual(gprm.event_shape_tensor(), event_shape)
            self.assertAllEqual(self.evaluate(samples).shape,
                                sample_shape + batch_shape + event_shape)
            self.assertAllEqual(gprm.batch_shape, batch_shape)
            self.assertAllEqual(gprm.event_shape, event_shape)
            self.assertAllEqual(samples.shape,
                                sample_shape + batch_shape + event_shape)
        else:
            self.assertAllEqual(self.evaluate(gprm.batch_shape_tensor()),
                                batch_shape)
            self.assertAllEqual(self.evaluate(gprm.event_shape_tensor()),
                                event_shape)
            self.assertAllEqual(
                self.evaluate(samples).shape,
                sample_shape + batch_shape + event_shape)
            self.assertIsNone(tensorshape_util.rank(samples.shape))
            self.assertIsNone(tensorshape_util.rank(gprm.batch_shape))
            self.assertEqual(tensorshape_util.rank(gprm.event_shape), 1)
            self.assertIsNone(
                tf.compat.dimension_value(
                    tensorshape_util.dims(gprm.event_shape)[0]))
Example #6
    def testCopy(self):
        # 5 random index points in R^2
        index_points_1 = np.random.uniform(-4., 4., (5, 2)).astype(np.float32)
        # 10 random index points in R^2
        index_points_2 = np.random.uniform(-4., 4., (10, 2)).astype(np.float32)

        observation_index_points_1 = (np.random.uniform(
            -4., 4., (7, 2)).astype(np.float32))
        observation_index_points_2 = (np.random.uniform(
            -4., 4., (9, 2)).astype(np.float32))

        observations_1 = np.random.uniform(-1., 1., 7).astype(np.float32)
        observations_2 = np.random.uniform(-1., 1., 9).astype(np.float32)

        if not self.is_static:
            index_points_1 = tf1.placeholder_with_default(index_points_1,
                                                          shape=None)
            index_points_2 = tf1.placeholder_with_default(index_points_2,
                                                          shape=None)
            observation_index_points_1 = tf1.placeholder_with_default(
                observation_index_points_1, shape=None)
            observation_index_points_2 = tf1.placeholder_with_default(
                observation_index_points_2, shape=None)
            observations_1 = tf1.placeholder_with_default(observations_1,
                                                          shape=None)
            observations_2 = tf1.placeholder_with_default(observations_2,
                                                          shape=None)

        mean_fn = lambda x: np.array([0.], np.float32)
        kernel_1 = psd_kernels.ExponentiatedQuadratic()
        kernel_2 = psd_kernels.ExpSinSquared()

        gprm1 = tfd.GaussianProcessRegressionModel(
            kernel=kernel_1,
            index_points=index_points_1,
            observation_index_points=observation_index_points_1,
            observations=observations_1,
            mean_fn=mean_fn,
            jitter=1e-5,
            validate_args=True)
        gprm2 = gprm1.copy(kernel=kernel_2,
                           index_points=index_points_2,
                           observation_index_points=observation_index_points_2,
                           observations=observations_2)

        event_shape_1 = [5]
        event_shape_2 = [10]

        self.assertIsInstance(gprm1.kernel.base_kernel,
                              psd_kernels.ExponentiatedQuadratic)
        self.assertIsInstance(gprm2.kernel.base_kernel,
                              psd_kernels.ExpSinSquared)

        if self.is_static or tf.executing_eagerly():
            self.assertAllEqual(gprm1.batch_shape, gprm2.batch_shape)
            self.assertAllEqual(gprm1.event_shape, event_shape_1)
            self.assertAllEqual(gprm2.event_shape, event_shape_2)
            self.assertAllEqual(gprm1.index_points, index_points_1)
            self.assertAllEqual(gprm2.index_points, index_points_2)
            self.assertAllEqual(tf.get_static_value(gprm1.jitter),
                                tf.get_static_value(gprm2.jitter))
        else:
            self.assertAllEqual(self.evaluate(gprm1.batch_shape_tensor()),
                                self.evaluate(gprm2.batch_shape_tensor()))
            self.assertAllEqual(self.evaluate(gprm1.event_shape_tensor()),
                                event_shape_1)
            self.assertAllEqual(self.evaluate(gprm2.event_shape_tensor()),
                                event_shape_2)
            self.assertEqual(self.evaluate(gprm1.jitter),
                             self.evaluate(gprm2.jitter))
            self.assertAllEqual(self.evaluate(gprm1.index_points),
                                index_points_1)
            self.assertAllEqual(self.evaluate(gprm2.index_points),
                                index_points_2)