def _asvi_surrogate_for_markov_chain(dist, build_nested_surrogate):
  """Builds a structured surrogate posterior for a Markov chain.

  The surrogate is itself a `MarkovChain`. Its initial-state prior and its
  transition distributions are surrogates produced by
  `build_nested_surrogate`. The parameters of all transitions live in one
  yielded `Parameter`. Each step of the chain gathers its own slice of that
  parameter along axis 0.

  Args:
    dist: The prior Markov chain. Only its `initial_state_prior`,
      `transition_fn`, `num_steps` and `validate_args` attributes are read.
    build_nested_surrogate: Generator function mapping a distribution to a
      surrogate for it. It is consumed via `yield from` (and wrapped with
      `trainable_state_util.as_stateless_builder`), so it yields trainable
      parameters and returns the surrogate distribution.

  Yields:
    First, the trainable parameters of the initial-state surrogate, forwarded
    from `build_nested_surrogate`. Then a single
    `trainable_state_util.Parameter` named `'markov_chain_transition_params'`
    holding the (possibly nested) parameters for all transitions.

  Returns:
    A `markov_chain.MarkovChain` surrogate with the same `num_steps` and
    `validate_args` as `dist`.
  """
  # Surrogate for the initial state; its parameters are yielded through to
  # our caller.
  surrogate_prior = yield from build_nested_surrogate(
      dist.initial_state_prior)

  # Build an initializer for the parameters of *all* transitions at once.
  # It evaluates `transition_fn` on the vector of steps
  # `tf.range(num_steps - 1)` and on a batch of `num_steps - 1` placeholder
  # states, drawn with a fixed zeros seed from the initial state prior.
  # Presumably this gives every transition parameter a leading axis indexed
  # by step; the `tf.gather(..., axis=0)` in `surrogate_transition_fn` below
  # relies on that. Only the init half of the stateless builder is used.
  (transition_all_steps_init_fn, _) = trainable_state_util.as_stateless_builder(
      lambda: build_nested_surrogate(  # pylint: disable=g-long-lambda
          dist.transition_fn(
              tf.range(dist.num_steps - 1),
              dist.initial_state_prior.sample(
                  dist.num_steps - 1, seed=samplers.zeros_seed()))))()
  transition_params = yield trainable_state_util.Parameter(
      transition_all_steps_init_fn,
      name='markov_chain_transition_params')

  # Builder for a single transition's surrogate. Only its apply half is used;
  # it is fed the per-step slice of `transition_params`.
  build_transition_one_step = trainable_state_util.as_stateless_builder(
      lambda step, state: build_nested_surrogate(  # pylint: disable=g-long-lambda
          dist.transition_fn(step, state)))

  def surrogate_transition_fn(step, state):
    _, one_step_apply_fn = build_transition_one_step(step, state)
    return one_step_apply_fn(
        tf.nest.map_structure(
            # Gather parameters for this specific step of the chain.
            lambda v: tf.gather(v, step, axis=0),
            transition_params))

  return markov_chain.MarkovChain(initial_state_prior=surrogate_prior,
                                  transition_fn=surrogate_transition_fn,
                                  num_steps=dist.num_steps,
                                  validate_args=dist.validate_args)
def test_trivial_generator(self):
  """A generator that never yields has no parameters and builds `None`."""
  builder = trainable_state_util.as_stateless_builder(yields_never)
  init_fn, apply_fn = builder()
  params = init_fn(seed=test_util.test_seed())
  self.assertEmpty(params)
  self.assertIsNone(apply_fn(params))
def test_raises_when_generator_is_not_a_generator(self):
  """Both init and apply reject a wrapped function that never yields."""
  builder = trainable_state_util.as_stateless_builder(
      lambda: tfd.Normal(0., 1.))
  init_fn, apply_fn = builder()
  expected_message = 'must contain at least one `yield` statement'
  with self.assertRaisesRegex(ValueError, expected_message):
    init_fn()
  with self.assertRaisesRegex(ValueError, expected_message):
    apply_fn([])
def test_basic_parameter_names(self):
  """The Normal generator's two parameters are named `loc` and `scale`."""
  init_fn, _ = trainable_state_util.as_stateless_builder(
      normal_generator)([2])
  params = init_fn(test_util.test_seed(sampler_type='stateless'))
  # `params` is a namedtuple / structuple, so its field names are the keys.
  names = params._asdict().keys()
  self.assertLen(names, 2)
  for expected_name in ('loc', 'scale'):
    self.assertIn(expected_name, names)
def test_assigns_default_names(self):
  """Default names are `parameter`, then `parameter_0001` ... `_0004`."""
  init_fn, _ = trainable_state_util.as_stateless_builder(seed_generator)()
  params = init_fn(test_util.test_seed(sampler_type='stateless'))
  # `params` is a namedtuple / structuple, so its field names are the keys.
  names = params._asdict().keys()
  self.assertLen(names, 5)
  expected_names = ['parameter'] + [
      'parameter_{:04d}'.format(index) for index in range(1, 5)]
  for expected_name in expected_names:
    self.assertIn(expected_name, names)
def test_structured_parameters(self):
  """A dict-valued parameter initializes and applies as expected."""
  builder = trainable_state_util.as_stateless_builder(
      yields_structured_parameter)
  init_fn, apply_fn = builder()
  params = init_fn(test_util.test_seed(sampler_type='stateless'))

  # `scale` is stored unconstrained, i.e. as the Softplus preimage of ones.
  expected_initial_values = {
      'scale': tfb.Softplus().inverse(tf.ones([2])),
      'loc': tf.zeros([2]),
  }
  self.assertAllEqualNested(params.dict_loc_scale, expected_initial_values)

  built = apply_fn(params)
  self.assertAllEqual(built.loc, tf.zeros([2]))
  self.assertAllEqual(built.scale, tf.ones([2]))
def test_rewrites_yield_to_return_in_docstring(self):
  """The wrapper's docstring swaps the `Yields:` section for a returns one."""
  wrapped_doc = trainable_state_util.as_stateless_builder(
      generator_with_docstring).__doc__
  # Read after wrapping: the source generator's docstring must be untouched.
  self.assertContainsExactSubsequence(
      generator_with_docstring.__doc__, 'Yields:')
  self.assertNotIn('Yields:', wrapped_doc)
  for expected_fragment in (
      'Test generator with a docstring.',
      trainable_state_util._STATELESS_RETURNS_DOCSTRING):
    self.assertContainsExactSubsequence(wrapped_doc, expected_fragment)
def test_assigns_unique_names(self):
  """Repeated parameter names are de-duplicated with `_000N` suffixes."""
  init_fn, _ = trainable_state_util.as_stateless_builder(
      joint_normal_nested_generator)([[1], [2], [3]])
  params = init_fn(test_util.test_seed(sampler_type='stateless'))
  # `params` is a namedtuple / structuple, so its field names are the keys.
  names = params._asdict().keys()
  self.assertLen(names, 6)
  for suffix in ('', '_0001', '_0002'):
    for base_name in ('loc', 'scale'):
      self.assertIn(base_name + suffix, names)
def test_apply_raises_on_bad_parameters(self):
  """`apply_fn` accepts both calling styles and rejects bad parameter lists.

  Covers three malformed inputs: no parameters, a `None` parameter
  structure, and one parameter too many.
  """
  init_fn, apply_fn = trainable_state_util.as_stateless_builder(
      normal_generator)(shape=[2])
  good_params = init_fn(seed=test_util.test_seed(sampler_type='stateless'))

  # Check that both calling styles are supported.
  self.assertIsInstance(apply_fn(good_params), tfd.Normal)
  self.assertIsInstance(apply_fn(*good_params), tfd.Normal)

  with self.assertRaisesRegex(ValueError, 'Insufficient parameters'):
    apply_fn()
  with self.assertRaisesRegex(ValueError, 'Insufficient parameters'):
    apply_fn(None)
  # The extra-parameter case needs its own context manager. Placed after
  # `apply_fn(None)` in the block above, it would never run; left bare, its
  # outcome would go unchecked.
  with self.assertRaisesRegex(ValueError, 'Too many parameters'):
    apply_fn(list(good_params) + [np.array(2.)])
def test_init_supports_arg_or_kwarg_seed(self):
  """`init_fn` takes its seed positionally or by keyword (or, off JAX, none)."""
  seed = test_util.test_seed(sampler_type='stateless')
  init_fn, _ = trainable_state_util.as_stateless_builder(seed_generator)()
  self.assertLen(init_fn(seed=seed), 5)

  # Positional and keyword seeds must give identical parameters, regardless
  # of how the inner functions are parameterized.
  self.assertAllCloseNested(init_fn(seed), init_fn(seed=seed))

  if JAX_MODE:
    return
  # Outside JAX, initializing with no seed at all must also work.
  self.assertLen(init_fn(), 5)
def test_fitting_example(self):
  """Fits a Normal by maximum likelihood using `minimize_stateless`."""
  if not JAX_MODE:
    self.skipTest('Requires JAX with optax.')
  import optax  # pylint: disable=g-import-not-at-top

  build_trainable_normal_stateless = (
      trainable_state_util.as_stateless_builder(normal_generator))
  init_fn, apply_fn = build_trainable_normal_stateless(shape=[])

  observations = [3., -2., 1.7]

  def negative_log_likelihood(*params):
    return -apply_fn(*params).log_prob(observations)

  # Find the maximum likelihood distribution given observed data.
  fitted_params, _ = tfp.math.minimize_stateless(
      loss_fn=negative_log_likelihood,
      init=init_fn(seed=test_util.test_seed(sampler_type='stateless')),
      optimizer=optax.adam(1.0),
      num_steps=400)
  fitted = apply_fn(fitted_params)
  # The Normal MLE is the sample mean and the population (ddof=0) stddev.
  self.assertAllClose(fitted.mean(), np.mean(observations), atol=0.1)
  self.assertAllClose(fitted.stddev(), np.std(observations), atol=0.1)
def test_distribution_init_apply(self, generator, expected_num_params, shape):
  """Initializes, samples and differentiates a distribution from `generator`."""
  # Test passing arguments to the wrapper.
  build = trainable_state_util.as_stateless_builder(generator)
  init_fn, apply_fn = build(shape)
  seed = test_util.test_seed(sampler_type='stateless')
  params = init_fn(seed)
  self.assertLen(params, expected_num_params)

  # Samples must have the requested (possibly nested) shape.
  sample = apply_fn(params).sample(seed=seed)
  self.assertAllEqualNested(shape, tf.nest.map_structure(ps.shape, sample))

  # Every parameter must receive a gradient.
  def log_prob_of_sample(*p):
    return apply_fn(*p).log_prob(sample)

  _, grads = tfp.math.value_and_gradient(log_prob_of_sample, params)
  self.assertLen(grads, expected_num_params)
  self.assertAllNotNone(grads)
surrogate_posterior = yield from _asvi_surrogate_for_distribution( dist=prior, base_distribution_surrogate_fn=functools.partial( _asvi_convex_update_for_base_distribution, mean_field=mean_field, initial_prior_weight=initial_prior_weight), prior_substitution_rules=prior_substitution_rules, surrogate_rules=surrogate_rules) return surrogate_posterior build_asvi_surrogate_posterior = trainable_state_util.as_stateful_builder( _build_asvi_surrogate_posterior) # TODO(davmre): replace stateful example code in the stateless docstring. build_asvi_surrogate_posterior_stateless = ( trainable_state_util.as_stateless_builder(_build_asvi_surrogate_posterior)) def _get_coroutine_parameters(jdc_model, seed): """Runs a coroutine and intercepts yielded dists, yielding params only.""" gen = jdc_model() to_send = None raw_parameters = [] try: while True: val = gen.send(to_send) if isinstance(val, trainable_state_util.Parameter): to_send = yield val raw_parameters.append(to_send) else: # Random variable. seed, local_seed = samplers.split_seed(seed, n=2)
def test_raises_when_non_callable_yielded(self, generator):
  """Initialization fails when `generator` yields an unexpected value."""
  builder = trainable_state_util.as_stateless_builder(generator)
  init_fn, _ = builder()
  with self.assertRaisesRegex(ValueError, 'Expected generator to yield'):
    init_fn()
def test_respects_constraining_bijector(self):
  """Built distributions respect parameter constraints (positive `scale`)."""
  init_fn, apply_fn = trainable_state_util.as_stateless_builder(
      normal_generator)([50])
  built = apply_fn(init_fn(test_util.test_seed(sampler_type='stateless')))
  self.assertAllGreater(built.scale, 0)
unconstrained_trainable_distribution = ( joint_distribution_util. independent_joint_distribution_from_structure( unconstrained_trainable_distributions, batch_ndims=ps.rank_from_shape(batch_shape), validate_args=validate_args)) if event_space_bijector is None: return unconstrained_trainable_distribution return transformed_distribution.TransformedDistribution( unconstrained_trainable_distribution, event_space_bijector) build_factored_surrogate_posterior = trainable_state_util.as_stateful_builder( _factored_surrogate_posterior) build_factored_surrogate_posterior_stateless = ( trainable_state_util.as_stateless_builder(_factored_surrogate_posterior)) def _affine_surrogate_posterior(event_shape, operators='diag', bijector=None, base_distribution=normal.Normal, dtype=tf.float32, batch_shape=(), validate_args=False, name=None): """Builds a joint variational posterior with a given `event_shape`. This function builds a surrogate posterior by applying a trainable transformation to a standard base distribution and constraining the samples with `bijector`. The surrogate posterior has event shape equal to
# Public stateful and stateless builders for trainable distributions. Each
# one wraps `_make_trainable` and fills the `${minimize_example_code}`
# placeholder of its docstring with a usage example. The examples refer to the
# public `tfp.experimental.util` path; the stateless (JAX) example passes an
# explicit PRNG key, because seedless initialization is not supported there.
make_trainable = docstring_util.expand_docstring(
    minimize_example_code="""
  ```python
  model = tfp.experimental.util.make_trainable(tfd.Normal)
  losses = tfp.math.minimize(
      lambda: -model.log_prob(samples),
      optimizer=tf.optimizers.Adam(0.1),
      num_steps=200)
  print('Fit Normal distribution with mean {} and stddev {}'.format(
      model.mean(),
      model.stddev()))
  ```""")(trainable_state_util.as_stateful_builder(_make_trainable))

make_trainable_stateless = docstring_util.expand_docstring(
    minimize_example_code="""
  ```python
  import jax
  import optax  # JAX only.

  init_fn, apply_fn = tfp.experimental.util.make_trainable_stateless(
      tfd.Normal)
  mle_params, losses = tfp.math.minimize_stateless(
      lambda *params: -apply_fn(*params).log_prob(samples),
      init=init_fn(seed=jax.random.PRNGKey(0)),
      optimizer=optax.adam(0.1),
      num_steps=200)
  model = apply_fn(mle_params)
  print('Fit Normal distribution with mean {} and stddev {}'.format(
      model.mean(), model.stddev()))
  ```""")(trainable_state_util.as_stateless_builder(_make_trainable))