def test_trainable_bijectors(self, cls, batch_and_event_shape):
  init_fn, apply_fn = tfe_util.make_trainable_stateless(
      cls, batch_and_event_shape=batch_and_event_shape, validate_args=True)

  # Verify expected number of trainable variables.
  raw_parameters = init_fn(seed=test_util.test_seed())
  bijector = apply_fn(raw_parameters)
  self.assertLen(
      raw_parameters,
      len([k for k, p in bijector.parameter_properties().items()
           if p.is_tensor and p.is_preferred]))

  # Verify gradients to all parameters.
  x = self.evaluate(
      samplers.normal(batch_and_event_shape, seed=test_util.test_seed()))
  y, grad = tfp.math.value_and_gradient(
      lambda params: apply_fn(params).forward(x),
      [raw_parameters])
  self.assertAllNotNone(grad)

  # Verify that the round trip doesn't broadcast, i.e., that it preserves
  # batch_and_event_shape.
  self.assertAllCloseNested(
      x,
      bijector.inverse(tf.identity(y)),  # Disable bijector cache.
      atol=1e-2)
def test_can_specify_initial_values(self):
  init_fn, apply_fn = tfe_util.make_trainable_stateless(
      tfd.Normal,
      initial_parameters={'scale': 1e-4},
      batch_and_event_shape=[3],
      validate_args=True)
  raw_parameters = init_fn(seed=test_util.test_seed())
  self.assertAllClose(apply_fn(raw_parameters).scale, [1e-4, 1e-4, 1e-4])
def test_dynamic_shape(self):
  # Shape is known only at runtime, not at graph-construction time.
  batch_and_event_shape = tf1.placeholder_with_default([4, 3, 2], shape=None)
  init_fn, apply_fn = tfe_util.make_trainable_stateless(
      tfd.Normal,
      batch_and_event_shape=batch_and_event_shape,
      validate_args=True)
  distribution = apply_fn(init_fn(seed=test_util.test_seed()))
  x = self.evaluate(distribution.sample(seed=test_util.test_seed()))
  self.assertAllEqual(x.shape, batch_and_event_shape)
def test_can_specify_parameter_dtype(self):
  init_fn, apply_fn = tfe_util.make_trainable_stateless(
      tfd.Normal,
      initial_parameters={'loc': 17.},
      parameter_dtype=tf.float64,
      validate_args=True)
  distribution = apply_fn(init_fn(seed=test_util.test_seed()))
  self.assertEqual(distribution.loc.dtype, tf.float64)
  self.assertEqual(distribution.scale.dtype, tf.float64)
  self.assertEqual(
      distribution.sample(seed=test_util.test_seed()).dtype, tf.float64)
def test_can_specify_fixed_values(self):
  # Passing `df` directly pins it, so only `scale_tril` remains trainable.
  init_fn, apply_fn = tfe_util.make_trainable_stateless(
      tfd.WishartTriL,
      batch_and_event_shape=[2, 2],
      validate_args=True,
      df=3)
  raw_parameters = init_fn(seed=test_util.test_seed())
  self.assertLen(raw_parameters, 1)
  distribution = apply_fn(raw_parameters)
  self.assertAllClose(distribution.df, 3.)
  self.assertAllEqual(
      distribution.sample(seed=test_util.test_seed()).shape, [2, 2])
def test_docstring_example_normal(self):
  if not JAX_MODE:
    self.skipTest('Stateless minimization requires optax.')
  import optax  # pylint: disable=g-import-not-at-top
  samples = [4.57, 6.37, 5.93, 7.98, 2.03, 3.59, 8.55, 3.45, 5.06, 6.44]
  init_fn, apply_fn = tfe_util.make_trainable_stateless(tfd.Normal)
  final_params, losses = tfp.math.minimize_stateless(
      lambda *params: -apply_fn(params).log_prob(samples),
      init=init_fn(seed=test_util.test_seed(sampler_type='stateless')),
      optimizer=optax.adam(0.1),
      num_steps=200)
  model = apply_fn(final_params)
  self.evaluate(losses)
  self.assertAllClose(tf.reduce_mean(samples), model.mean(), atol=2.0)
  self.assertAllClose(tf.math.reduce_std(samples), model.stddev(), atol=2.0)
def test_can_specify_callable_initializer(self):

  def uniform_initializer(_, shape, dtype, seed, constraining_bijector):
    return constraining_bijector.forward(
        tf.random.stateless_uniform(
            constraining_bijector.inverse_event_shape_tensor(shape),
            dtype=dtype,
            seed=seed))

  init_fn, _ = tfe_util.make_trainable_stateless(
      tfd.Normal,
      initial_parameters=uniform_initializer,
      batch_and_event_shape=[3, 4],
      validate_args=True)
  raw_parameters = init_fn(test_util.test_seed())
  for v in tf.nest.flatten(raw_parameters):
    # Unconstrained parameters should be uniform in [0, 1].
    self.assertAllGreater(v, 0.)
    self.assertAllLess(v, 1.)
def test_trainable_distributions(self, cls, batch_and_event_shape):
  init_fn, apply_fn = tfe_util.make_trainable_stateless(
      cls, batch_and_event_shape=batch_and_event_shape, validate_args=True)
  raw_parameters = init_fn(seed=test_util.test_seed())
  distribution = apply_fn(raw_parameters)
  x = self.evaluate(distribution.sample(seed=test_util.test_seed()))
  self.assertAllEqual(x.shape, batch_and_event_shape)

  # Verify expected number of trainable variables.
  self.assertLen(
      raw_parameters,
      len([k for k, p in distribution.parameter_properties().items()
           if p.is_preferred]))

  # Verify gradients to all parameters.
  _, grad = tfp.math.value_and_gradient(
      lambda params: apply_fn(params).log_prob(x),
      [raw_parameters])
  self.assertAllNotNone(grad)
def test_initialization_is_deterministic_with_seed(self):
  # Two calls to `init_fn` with the same seed should yield identical values.
  seed = test_util.test_seed(sampler_type='stateless')
  init_fn, _ = tfe_util.make_trainable_stateless(
      tfd.Normal, validate_args=True)
  self.assertAllCloseNested(init_fn(seed=seed), init_fn(seed=seed))