Exemplo n.º 1
0
    def test_trainable_bijectors(self, cls, batch_and_event_shape):
        """Checks parameter count, gradients and round trip for `cls`."""
        init_fn, apply_fn = tfe_util.make_trainable_stateless(
            cls, batch_and_event_shape=batch_and_event_shape,
            validate_args=True)

        # One raw parameter is expected for every parameter that is both
        # tensor-valued and preferred.
        raw_parameters = init_fn(seed=test_util.test_seed())
        bijector = apply_fn(raw_parameters)
        expected_names = [
            name
            for name, properties in bijector.parameter_properties().items()
            if properties.is_tensor and properties.is_preferred
        ]
        self.assertLen(raw_parameters, len(expected_names))

        # The forward transformation must yield a gradient for each parameter.
        x = self.evaluate(
            samplers.normal(batch_and_event_shape, seed=test_util.test_seed()))

        def forward_fn(params):
            return apply_fn(params).forward(x)

        y, grad = tfp.math.value_and_gradient(forward_fn, [raw_parameters])
        self.assertAllNotNone(grad)

        # forward followed by inverse must recover `x` without broadcasting,
        # i.e. batch_and_event_shape is preserved. Passing `y` through
        # tf.identity disables the bijector cache.
        round_trip = bijector.inverse(tf.identity(y))
        self.assertAllCloseNested(x, round_trip, atol=1e-2)
Exemplo n.º 2
0
 def test_can_specify_initial_values(self):
     """Checks that an explicit initial `scale` is used for a Normal."""
     requested_scale = 1e-4
     batch_size = 3
     init_fn, apply_fn = tfe_util.make_trainable_stateless(
         tfd.Normal,
         initial_parameters={'scale': requested_scale},
         batch_and_event_shape=[batch_size],
         validate_args=True)

     # Build the distribution from freshly initialized parameters and compare
     # its scale against the requested value repeated across the batch.
     distribution = apply_fn(init_fn(seed=test_util.test_seed()))
     self.assertAllClose(distribution.scale,
                         [requested_scale] * batch_size)
Exemplo n.º 3
0
 def test_dynamic_shape(self):
     """Checks sampling when the shape is a placeholder of unknown shape."""
     # shape=None leaves the placeholder's static shape unspecified.
     dynamic_shape = tf1.placeholder_with_default([4, 3, 2], shape=None)
     init_fn, apply_fn = tfe_util.make_trainable_stateless(
         tfd.Normal,
         batch_and_event_shape=dynamic_shape,
         validate_args=True)

     raw_parameters = init_fn(seed=test_util.test_seed())
     distribution = apply_fn(raw_parameters)
     sample = distribution.sample(seed=test_util.test_seed())
     x = self.evaluate(sample)
     self.assertAllEqual(x.shape, dynamic_shape)
Exemplo n.º 4
0
 def test_can_specify_parameter_dtype(self):
     """Checks that `parameter_dtype` sets parameter and sample dtypes."""
     expected_dtype = tf.float64
     init_fn, apply_fn = tfe_util.make_trainable_stateless(
         tfd.Normal,
         initial_parameters={'loc': 17.},
         parameter_dtype=expected_dtype,
         validate_args=True)
     raw_parameters = init_fn(seed=test_util.test_seed())
     distribution = apply_fn(raw_parameters)

     # Both the explicitly initialized `loc` and the `scale` must carry the
     # requested dtype.
     for parameter_name in ('loc', 'scale'):
         parameter = getattr(distribution, parameter_name)
         self.assertEqual(parameter.dtype, expected_dtype)

     sample = distribution.sample(seed=test_util.test_seed())
     self.assertEqual(sample.dtype, expected_dtype)
Exemplo n.º 5
0
 def test_can_specify_fixed_values(self):
     """Checks that a parameter passed as a keyword is held fixed."""
     fixed_df = 3
     shape = [2, 2]
     init_fn, apply_fn = tfe_util.make_trainable_stateless(
         tfd.WishartTriL,
         batch_and_event_shape=shape,
         validate_args=True,
         df=fixed_df)

     # With `df` supplied directly, a single raw parameter is expected.
     raw_parameters = init_fn(seed=test_util.test_seed())
     self.assertLen(raw_parameters, 1)

     distribution = apply_fn(raw_parameters)
     self.assertAllClose(distribution.df, 3.)
     sample = distribution.sample(seed=test_util.test_seed())
     self.assertAllEqual(sample.shape, [2, 2])
Exemplo n.º 6
0
    def test_docstring_example_normal(self):
        """Fits a trainable Normal to samples; skipped outside JAX mode."""
        if not JAX_MODE:
            self.skipTest('Stateless minimization requires optax.')
        import optax  # pylint: disable=g-import-not-at-top

        samples = [4.57, 6.37, 5.93, 7.98, 2.03, 3.59, 8.55, 3.45, 5.06, 6.44]
        init_fn, apply_fn = tfe_util.make_trainable_stateless(tfd.Normal)

        def negative_log_prob(*params):
            return -apply_fn(params).log_prob(samples)

        initial_params = init_fn(
            seed=test_util.test_seed(sampler_type='stateless'))
        optimizer = optax.adam(0.1)
        final_params, losses = tfp.math.minimize_stateless(
            negative_log_prob,
            init=initial_params,
            optimizer=optimizer,
            num_steps=200)

        model = apply_fn(final_params)
        self.evaluate(losses)

        # The fitted mean and standard deviation must be within atol=2.0 of
        # the sample statistics.
        sample_mean = tf.reduce_mean(samples)
        self.assertAllClose(sample_mean, model.mean(), atol=2.0)
        sample_stddev = tf.math.reduce_std(samples)
        self.assertAllClose(sample_stddev, model.stddev(), atol=2.0)
Exemplo n.º 7
0
    def test_can_specify_callable_initializer(self):
        """Checks that a callable `initial_parameters` drives initialization."""

        def uniform_initializer(_, shape, dtype, seed, constraining_bijector):
            # Draw uniform values in the bijector's inverse event shape, then
            # push them through the bijector's forward transformation.
            draw_shape = constraining_bijector.inverse_event_shape_tensor(
                shape)
            uniform_draw = tf.random.stateless_uniform(
                draw_shape, dtype=dtype, seed=seed)
            return constraining_bijector.forward(uniform_draw)

        init_fn, _ = tfe_util.make_trainable_stateless(
            tfd.Normal,
            initial_parameters=uniform_initializer,
            batch_and_event_shape=[3, 4],
            validate_args=True)
        raw_parameters = init_fn(test_util.test_seed())

        # Unconstrained parameters should be uniform in [0, 1], so every
        # flattened value must lie strictly between the two bounds.
        for raw_value in tf.nest.flatten(raw_parameters):
            self.assertAllGreater(raw_value, 0.)
            self.assertAllLess(raw_value, 1.)
Exemplo n.º 8
0
    def test_trainable_distributions(self, cls, batch_and_event_shape):
        """Checks sample shape, parameter count and gradients for `cls`."""
        init_fn, apply_fn = tfe_util.make_trainable_stateless(
            cls, batch_and_event_shape=batch_and_event_shape,
            validate_args=True)

        raw_parameters = init_fn(seed=test_util.test_seed())
        distribution = apply_fn(raw_parameters)
        sample = distribution.sample(seed=test_util.test_seed())
        x = self.evaluate(sample)
        self.assertAllEqual(x.shape, batch_and_event_shape)

        # One raw parameter is expected per preferred parameter.
        num_preferred = sum(
            1 for properties in distribution.parameter_properties().values()
            if properties.is_preferred)
        self.assertLen(raw_parameters, num_preferred)

        # The log-probability of the sample must yield a gradient for each
        # parameter.
        def log_prob_fn(params):
            return apply_fn(params).log_prob(x)

        unused_value, grad = tfp.math.value_and_gradient(
            log_prob_fn, [raw_parameters])
        self.assertAllNotNone(grad)
Exemplo n.º 9
0
 def test_initialization_is_deterministic_with_seed(self):
     """Checks that two initializations with one seed give close values."""
     init_fn, _ = tfe_util.make_trainable_stateless(
         tfd.Normal, validate_args=True)
     seed = test_util.test_seed(sampler_type='stateless')

     # Initialize twice with the identical seed and compare the nested
     # results.
     first_init = init_fn(seed=seed)
     second_init = init_fn(seed=seed)
     self.assertAllCloseNested(first_init, second_init)