Esempio n. 1
0
def _asvi_surrogate_for_markov_chain(dist, build_nested_surrogate):
    """Constructs a surrogate `MarkovChain` that mirrors the structure of `dist`.

    This is a generator. It delegates to `build_nested_surrogate` (via
    `yield from`) for the initial-state prior, and then yields one
    `trainable_state_util.Parameter` named `'markov_chain_transition_params'`
    whose initializer builds transition-surrogate parameters for every step
    index in `range(dist.num_steps - 1)` in a single call.

    Args:
        dist: The chain to approximate. Its `initial_state_prior`,
            `transition_fn`, `num_steps` and `validate_args` attributes are
            read.
        build_nested_surrogate: Callable taking a distribution and returning
            a generator that produces a surrogate for that distribution.

    Returns:
        A `markov_chain.MarkovChain` whose initial-state prior and transition
        function are the surrogates built here, with `num_steps` and
        `validate_args` copied from `dist`.
    """
    initial_surrogate = yield from build_nested_surrogate(
        dist.initial_state_prior)

    # Wrap the "all steps at once" transition surrogate in a stateless
    # builder. The lambda body runs lazily, i.e. only when the resulting init
    # function is invoked. NOTE(review): the parameters presumably gain a
    # leading axis of size `num_steps - 1` -- confirm.
    all_steps_builder = trainable_state_util.as_stateless_builder(
        lambda: build_nested_surrogate(  # pylint: disable=g-long-lambda
            dist.transition_fn(
                tf.range(dist.num_steps - 1),
                dist.initial_state_prior.sample(
                    dist.num_steps - 1, seed=samplers.zeros_seed()))))
    init_all_steps_fn, _ = all_steps_builder()
    stacked_params = yield trainable_state_util.Parameter(
        init_all_steps_fn, name='markov_chain_transition_params')

    # Builder for the surrogate of a single transition; only its apply
    # function is used, fed with the slice of `stacked_params` for that step.
    single_step_builder = trainable_state_util.as_stateless_builder(
        lambda step, state: build_nested_surrogate(  # pylint: disable=g-long-lambda
            dist.transition_fn(step, state)))

    def surrogate_transition_fn(step, state):
        def select_step(stacked):
            # Gather parameters for this specific step of the chain.
            return tf.gather(stacked, step, axis=0)

        _, apply_single_step = single_step_builder(step, state)
        params_for_step = tf.nest.map_structure(select_step, stacked_params)
        return apply_single_step(params_for_step)

    return markov_chain.MarkovChain(
        initial_state_prior=initial_surrogate,
        transition_fn=surrogate_transition_fn,
        num_steps=dist.num_steps,
        validate_args=dist.validate_args)
 def test_trivial_generator(self):
   # The `yields_never` generator is expected to initialize to an empty
   # parameter collection and to produce `None` when applied to it.
   build = trainable_state_util.as_stateless_builder(yields_never)
   initialize, apply_params = build()
   no_params = initialize(seed=test_util.test_seed())
   self.assertEmpty(no_params)
   self.assertIsNone(apply_params(no_params))
 def test_raises_when_generator_is_not_a_generator(self):
   # Wrapping a plain function (one that returns a distribution instead of
   # yielding) must fail with a `ValueError` both at init and at apply time.
   build = trainable_state_util.as_stateless_builder(
       lambda: tfd.Normal(0., 1.))
   initialize, apply_params = build()
   expected_regex = 'must contain at least one `yield` statement'
   for failing_call in (initialize, lambda: apply_params([])):
     with self.assertRaisesRegex(ValueError, expected_regex):
       failing_call()
 def test_basic_parameter_names(self):
   # Initializing `normal_generator` should give exactly two parameters,
   # named `loc` and `scale`.
   build = trainable_state_util.as_stateless_builder(normal_generator)
   initialize, _ = build([2])
   seed = test_util.test_seed(sampler_type='stateless')
   # Params is a namedtuple / structuple.
   names = initialize(seed)._asdict().keys()
   self.assertLen(names, 2)
   for expected in ('loc', 'scale'):
     self.assertIn(expected, names)
 def test_assigns_default_names(self):
   # `seed_generator` is expected to produce five parameters carrying the
   # default names `parameter`, `parameter_0001`, ..., `parameter_0004`.
   build = trainable_state_util.as_stateless_builder(seed_generator)
   initialize, _ = build()
   seed = test_util.test_seed(sampler_type='stateless')
   # Params is a namedtuple / structuple.
   names = initialize(seed)._asdict().keys()
   self.assertLen(names, 5)
   expected_names = ['parameter'] + [
       'parameter_{:04d}'.format(index) for index in range(1, 5)]
   for expected in expected_names:
     self.assertIn(expected, names)
 def test_structured_parameters(self):
   # A single yielded parameter may itself be a structure (here a dict with
   # `loc` and `scale` entries); check both the raw initial values and the
   # distribution obtained by applying them.
   build = trainable_state_util.as_stateless_builder(
       yields_structured_parameter)
   initialize, apply_params = build()
   initial_params = initialize(
       test_util.test_seed(sampler_type='stateless'))

   # The stored scale is compared against the inverse-softplus of ones.
   expected_raw = {
       'scale': tfb.Softplus().inverse(tf.ones([2])),
       'loc': tf.zeros([2]),
   }
   self.assertAllEqualNested(initial_params.dict_loc_scale, expected_raw)

   built = apply_params(initial_params)
   self.assertAllEqual(built.loc, tf.zeros([2]))
   self.assertAllEqual(built.scale, tf.ones([2]))
 def test_rewrites_yield_to_return_in_docstring(self):
   # The wrapped builder's docstring should keep the generator's summary,
   # drop its `Yields:` section, and include the stateless returns text.
   wrapped = trainable_state_util.as_stateless_builder(
       generator_with_docstring)

   source_doc = generator_with_docstring.__doc__
   self.assertContainsExactSubsequence(source_doc, 'Yields:')

   wrapped_doc = wrapped.__doc__
   self.assertNotIn('Yields:', wrapped_doc)
   for required_text in (
       'Test generator with a docstring.',
       trainable_state_util._STATELESS_RETURNS_DOCSTRING):
     self.assertContainsExactSubsequence(wrapped_doc, required_text)
 def test_assigns_unique_names(self):
   # With three nested components, repeated `loc` / `scale` names are
   # expected to be disambiguated with `_0001`, `_0002` suffixes.
   build = trainable_state_util.as_stateless_builder(
       joint_normal_nested_generator)
   initialize, _ = build([[1], [2], [3]])
   seed = test_util.test_seed(sampler_type='stateless')
   # Params is a namedtuple / structuple.
   names = initialize(seed)._asdict().keys()
   self.assertLen(names, 6)
   expected_names = (
       'loc', 'scale',
       'loc_0001', 'scale_0001',
       'loc_0002', 'scale_0002',
   )
   for expected in expected_names:
     self.assertIn(expected, names)
  def test_apply_raises_on_bad_parameters(self):
    # Exercises `apply_fn` with valid parameters (in both calling styles),
    # with missing parameters, and with one parameter too many.
    build = trainable_state_util.as_stateless_builder(normal_generator)
    initialize, apply_params = build(shape=[2])
    valid_params = initialize(
        seed=test_util.test_seed(sampler_type='stateless'))

    # Check that both calling styles are supported.
    self.assertIsInstance(apply_params(valid_params), tfd.Normal)
    self.assertIsInstance(apply_params(*valid_params), tfd.Normal)

    with self.assertRaisesRegex(ValueError, 'Insufficient parameters'):
      apply_params()
    with self.assertRaisesRegex(ValueError, 'Insufficient parameters'):
      apply_params(None)

    # NOTE(review): this call is not wrapped in an `assertRaises*` context,
    # so the test only passes if surplus parameters are accepted -- confirm
    # that this is the intended contract.
    surplus_params = list(valid_params) + [np.array(2.)]
    apply_params(surplus_params)
  def test_init_supports_arg_or_kwarg_seed(self):
    # `init_fn` must accept its seed positionally or by keyword and give
    # matching results either way.
    seed = test_util.test_seed(sampler_type='stateless')
    build = trainable_state_util.as_stateless_builder(seed_generator)
    initialize, _ = build()
    self.assertLen(initialize(seed=seed), 5)

    # Check that we can invoke init_fn with an arg or kwarg seed,
    # regardless of how the inner functions are parameterized.
    from_positional = initialize(seed)
    from_keyword = initialize(seed=seed)
    self.assertAllCloseNested(from_positional, from_keyword)

    if JAX_MODE:
      return
    # Check that we can initialize with no seed.
    self.assertLen(initialize(), 5)
  def test_fitting_example(self):
    # End-to-end check: fit a trainable normal to a few observations with
    # `minimize_stateless` and compare its mean / stddev to the sample
    # statistics. Skipped unless running in JAX mode.
    if not JAX_MODE:
      self.skipTest('Requires JAX with optax.')
    import optax  # pylint: disable=g-import-not-at-top

    build_normal = trainable_state_util.as_stateless_builder(normal_generator)
    initialize, apply_params = build_normal(shape=[])

    # Find the maximum likelihood distribution given observed data.
    observations = [3., -2., 1.7]

    def negative_log_likelihood(*params):
      return -apply_params(*params).log_prob(observations)

    initial_params = initialize(
        seed=test_util.test_seed(sampler_type='stateless'))
    fitted_params, _ = tfp.math.minimize_stateless(
        loss_fn=negative_log_likelihood,
        init=initial_params,
        optimizer=optax.adam(1.0),
        num_steps=400)

    fitted = apply_params(fitted_params)
    self.assertAllClose(fitted.mean(), np.mean(observations), atol=0.1)
    self.assertAllClose(fitted.stddev(), np.std(observations), atol=0.1)
  def test_distribution_init_apply(self, generator, expected_num_params, shape):
    # For the given `generator`: initialization yields `expected_num_params`
    # parameters, samples from the applied distribution have structure/shape
    # `shape`, and log-prob gradients exist for every parameter.

    # Test passing arguments to the wrapper.
    build = trainable_state_util.as_stateless_builder(generator)
    initialize, apply_params = build(shape)

    seed = test_util.test_seed(sampler_type='stateless')
    initial_params = initialize(seed)
    self.assertLen(initial_params, expected_num_params)

    # Check that the distribution's samples have the expected shape.
    built = apply_params(initial_params)
    draw = built.sample(seed=seed)
    draw_shapes = tf.nest.map_structure(ps.shape, draw)
    self.assertAllEqualNested(shape, draw_shapes)

    # Check that gradients are defined.
    def log_prob_of_draw(*params):
      return apply_params(*params).log_prob(draw)

    _, grads = tfp.math.value_and_gradient(log_prob_of_draw, initial_params)
    self.assertLen(grads, expected_num_params)
    self.assertAllNotNone(grads)
Esempio n. 13
0
        surrogate_posterior = yield from _asvi_surrogate_for_distribution(
            dist=prior,
            base_distribution_surrogate_fn=functools.partial(
                _asvi_convex_update_for_base_distribution,
                mean_field=mean_field,
                initial_prior_weight=initial_prior_weight),
            prior_substitution_rules=prior_substitution_rules,
            surrogate_rules=surrogate_rules)
        return surrogate_posterior


# Public entry points. Both wrap the same `_build_asvi_surrogate_posterior`
# generator: the first through `as_stateful_builder`, the second through
# `as_stateless_builder`.
build_asvi_surrogate_posterior = trainable_state_util.as_stateful_builder(
    _build_asvi_surrogate_posterior)
# TODO(davmre): replace stateful example code in the stateless docstring.
build_asvi_surrogate_posterior_stateless = (
    trainable_state_util.as_stateless_builder(_build_asvi_surrogate_posterior))


def _get_coroutine_parameters(jdc_model, seed):
    """Runs a coroutine and intercepts yielded dists, yielding params only."""
    gen = jdc_model()
    to_send = None
    raw_parameters = []
    try:
        while True:
            val = gen.send(to_send)
            if isinstance(val, trainable_state_util.Parameter):
                to_send = yield val
                raw_parameters.append(to_send)
            else:  # Random variable.
                seed, local_seed = samplers.split_seed(seed, n=2)
 def test_raises_when_non_callable_yielded(self, generator):
   # Initializing from the supplied (invalid) `generator` must raise a
   # `ValueError` whose message matches 'Expected generator to yield'.
   build = trainable_state_util.as_stateless_builder(generator)
   initialize, _ = build()
   with self.assertRaisesRegex(ValueError, 'Expected generator to yield'):
     initialize()
 def test_respects_constraining_bijector(self):
   # Every entry of the freshly initialized normal's `scale` must be
   # strictly positive.
   build = trainable_state_util.as_stateless_builder(normal_generator)
   initialize, apply_params = build([50])
   seed = test_util.test_seed(sampler_type='stateless')
   built = apply_params(initialize(seed))
   self.assertAllGreater(built.scale, 0)
Esempio n. 16
0
        unconstrained_trainable_distribution = (
            joint_distribution_util.
            independent_joint_distribution_from_structure(
                unconstrained_trainable_distributions,
                batch_ndims=ps.rank_from_shape(batch_shape),
                validate_args=validate_args))
        if event_space_bijector is None:
            return unconstrained_trainable_distribution
        return transformed_distribution.TransformedDistribution(
            unconstrained_trainable_distribution, event_space_bijector)


# Public entry points. Both wrap the same `_factored_surrogate_posterior`
# generator: the first through `as_stateful_builder`, the second through
# `as_stateless_builder`.
build_factored_surrogate_posterior = trainable_state_util.as_stateful_builder(
    _factored_surrogate_posterior)
build_factored_surrogate_posterior_stateless = (
    trainable_state_util.as_stateless_builder(_factored_surrogate_posterior))


def _affine_surrogate_posterior(event_shape,
                                operators='diag',
                                bijector=None,
                                base_distribution=normal.Normal,
                                dtype=tf.float32,
                                batch_shape=(),
                                validate_args=False,
                                name=None):
    """Builds a joint variational posterior with a given `event_shape`.

  This function builds a surrogate posterior by applying a trainable
  transformation to a standard base distribution and constraining the samples
  with `bijector`. The surrogate posterior has event shape equal to
Esempio n. 17
0

# Stateful public entry point: `_make_trainable` wrapped by
# `as_stateful_builder`, then passed through `expand_docstring` with the
# example below as `minimize_example_code`.
# NOTE(review): presumably `expand_docstring` substitutes this text for a
# matching placeholder in the wrapped docstring -- confirm in docstring_util.
make_trainable = docstring_util.expand_docstring(minimize_example_code="""
```python
model = tfp.util.make_trainable(tfd.Normal)
losses = tfp.math.minimize(
  lambda: -model.log_prob(samples),
  optimizer=tf.optimizers.Adam(0.1),
  num_steps=200)
print('Fit Normal distribution with mean {} and stddev {}'.format(
  model.mean(),
  model.stddev()))
```""")(trainable_state_util.as_stateful_builder(_make_trainable))

# Stateless public entry point: `_make_trainable` wrapped by
# `as_stateless_builder`, then passed through `expand_docstring` with the
# example below as `minimize_example_code`.
# NOTE(review): presumably `expand_docstring` substitutes this text for a
# matching placeholder in the wrapped docstring -- confirm in docstring_util.
make_trainable_stateless = docstring_util.expand_docstring(
    minimize_example_code="""
```python
init_fn, apply_fn = tfe_util.make_trainable_stateless(tfd.Normal)

import optax  # JAX only.
mle_params, losses = tfp.math.minimize_stateless(
  lambda *params: -apply_fn(params).log_prob(samples),
  init=init_fn(),
  optimizer=optax.adam(0.1),
  num_steps=200)
model = apply_fn(mle_params)
print('Fit Normal distribution with mean {} and stddev {}'.format(
  model.mean(),
  model.stddev()))
```""")(trainable_state_util.as_stateless_builder(_make_trainable))