def test_structured_parameters(self):
    """Builder from `yields_structured_parameter` yields two variables.

    After variable initialization, the built distribution's `loc` must be
    all zeros and its `scale` all ones (length-2 vectors).
    """
    builder = trainable_state_util.as_stateful_builder(
        yields_structured_parameter)
    dist = builder()
    self.assertLen(dist.trainable_variables, 2)
    self.evaluate(tf1.global_variables_initializer())
    self.assertAllEqual(dist.loc, tf.zeros([2]))
    self.assertAllEqual(dist.scale, tf.ones([2]))
 def test_rewrites_yield_to_return_in_docstring(self):
   wrapped = trainable_state_util.as_stateful_builder(
       generator_with_docstring)
   self.assertContainsExactSubsequence(
       generator_with_docstring.__doc__, 'Yields:')
   self.assertNotIn('Yields:', wrapped.__doc__)
   self.assertContainsExactSubsequence(
       wrapped.__doc__,
       'Test generator with a docstring.')
   self.assertContainsExactSubsequence(
       wrapped.__doc__,
       trainable_state_util._STATEFUL_RETURNS_DOCSTRING)
  def test_initialization_is_deterministic_with_seed(self):
    seed = test_util.test_seed(sampler_type='stateless')
    make_trainable_jd = trainable_state_util.as_stateful_builder(
        seed_generator)

    trainable_jd1 = make_trainable_jd(seed=seed)
    variables1 = trainable_jd1.trainable_variables
    self.assertLen(variables1, 5)

    trainable_jd2 = make_trainable_jd(seed=seed)
    variables2 = trainable_jd2.trainable_variables
    self.evaluate([v.initializer for v in variables1 + variables2])
    vals1, vals2 = self.evaluate((variables1, variables2))
    self.assertAllCloseNested(vals1, vals2)
 def test_fitting_example(self):
   build_trainable_normal = trainable_state_util.as_stateful_builder(
       normal_generator)
   trainable_dist = build_trainable_normal(
       shape=[],
       seed=test_util.test_seed(sampler_type='stateless'))
   optimizer = tf.optimizers.Adam(1.0)
   # Find the maximum likelihood distribution given observed data.
   x_observed = [3., -2., 1.7]
   losses = tfp.math.minimize(
       loss_fn=lambda: -trainable_dist.log_prob(x_observed),
       optimizer=optimizer,
       num_steps=300)
   self.evaluate(tf1.global_variables_initializer())
   losses = self.evaluate(losses)
   self.assertAllClose(trainable_dist.mean(),
                       np.mean(x_observed), atol=0.1)
   self.assertAllClose(trainable_dist.stddev(),
                       np.std(x_observed), atol=0.1)
Example #5
0
                parameter_dtype=nest_util.broadcast_structure(
                    event_shape, dtype),
                _up_to=event_shape))
        unconstrained_trainable_distribution = (
            joint_distribution_util.
            independent_joint_distribution_from_structure(
                unconstrained_trainable_distributions,
                batch_ndims=ps.rank_from_shape(batch_shape),
                validate_args=validate_args))
        if event_space_bijector is None:
            return unconstrained_trainable_distribution
        return transformed_distribution.TransformedDistribution(
            unconstrained_trainable_distribution, event_space_bijector)


# Public constructors for the factored surrogate posterior. Both wrap
# `_factored_surrogate_posterior`: one as a stateful builder, one stateless.
build_factored_surrogate_posterior = (
    trainable_state_util.as_stateful_builder(_factored_surrogate_posterior))
build_factored_surrogate_posterior_stateless = (
    trainable_state_util.as_stateless_builder(
        _factored_surrogate_posterior))


def _affine_surrogate_posterior(event_shape,
                                operators='diag',
                                bijector=None,
                                base_distribution=normal.Normal,
                                dtype=tf.float32,
                                batch_shape=(),
                                validate_args=False,
                                name=None):
    """Builds a joint variational posterior with a given `event_shape`.

  This function builds a surrogate posterior by applying a trainable
        return tf.math.abs(x) + dtype_util.eps(dtype)

    def _inverse(self, y):
        """Undo the `+ eps` shift applied by the forward computation.

        The forward line in this class computes
        `tf.math.abs(x) + dtype_util.eps(dtype)`. Using the same
        `dtype_util.eps` helper here (rather than a separate `np.finfo`
        lookup) keeps the two directions consistent by construction.

        Args:
          y: Tensor in the forward map's output space.

        Returns:
          `|y| - eps` in `y`'s base dtype.
        """
        dtype = dtype_util.base_dtype(y.dtype)
        return tf.math.abs(y) - dtype_util.eps(dtype)

    def _inverse_log_det_jacobian(self, y):
        """Return a scalar zero log-det-Jacobian in `y`'s base dtype.

        For positive `y`, `_inverse` reduces to `y - eps`, a unit-slope
        shift, so the log absolute Jacobian determinant is 0.
        """
        zero_dtype = dtype_util.base_dtype(y.dtype)
        return tf.zeros([], zero_dtype)


# Dispatch table from a `tf.linalg` LinearOperator class to the coroutine
# (generator) function that builds a trainable instance of it. Both
# `LinearOperatorZeros` and `None` map to `_linear_operator_zeros`.
# NOTE(review): presumably consulted when building the sub-operators of a
# block-structured operator; confirm against the lookup site, which is
# outside this view.
_OPERATOR_COROUTINES = {
    tf.linalg.LinearOperatorLowerTriangular: _trainable_linear_operator_tril,
    tf.linalg.LinearOperatorDiag: _trainable_linear_operator_diag,
    tf.linalg.LinearOperatorFullMatrix: _trainable_linear_operator_full_matrix,
    tf.linalg.LinearOperatorZeros: _linear_operator_zeros,
    None: _linear_operator_zeros,
}

# Public stateful builders, one per trainable linear-operator coroutine.
# TODO(davmre): also expose stateless builders.
build_trainable_linear_operator_block = (
    trainable_state_util.as_stateful_builder(
        _trainable_linear_operator_block))
build_trainable_linear_operator_tril = (
    trainable_state_util.as_stateful_builder(
        _trainable_linear_operator_tril))
build_trainable_linear_operator_diag = (
    trainable_state_util.as_stateful_builder(
        _trainable_linear_operator_diag))
build_trainable_linear_operator_full_matrix = (
    trainable_state_util.as_stateful_builder(
        _trainable_linear_operator_full_matrix))
build_linear_operator_zeros = (
    trainable_state_util.as_stateful_builder(
        _linear_operator_zeros))
Example #7
0
        https://arxiv.org/abs/2002.00643

  """
    with tf.name_scope(name or 'build_asvi_surrogate_posterior'):
        surrogate_posterior = yield from _asvi_surrogate_for_distribution(
            dist=prior,
            base_distribution_surrogate_fn=functools.partial(
                _asvi_convex_update_for_base_distribution,
                mean_field=mean_field,
                initial_prior_weight=initial_prior_weight),
            prior_substitution_rules=prior_substitution_rules,
            surrogate_rules=surrogate_rules)
        return surrogate_posterior


# Public ASVI surrogate-posterior constructors wrapping
# `_build_asvi_surrogate_posterior` (stateful and stateless variants).
build_asvi_surrogate_posterior = (
    trainable_state_util.as_stateful_builder(_build_asvi_surrogate_posterior))
# TODO(davmre): replace stateful example code in the stateless docstring.
build_asvi_surrogate_posterior_stateless = (
    trainable_state_util.as_stateless_builder(
        _build_asvi_surrogate_posterior))


def _get_coroutine_parameters(jdc_model, seed):
    """Runs a coroutine and intercepts yielded dists, yielding params only."""
    gen = jdc_model()
    to_send = None
    raw_parameters = []
    try:
        while True:
            val = gen.send(to_send)
            if isinstance(val, trainable_state_util.Parameter):
                to_send = yield val
Example #8
0
                name=parameter_name)

    return cls(**init_kwargs)


# Public stateful builder wrapping `_make_trainable`. The snippet below is
# passed to `docstring_util.expand_docstring` as `minimize_example_code`,
# presumably to fill a matching placeholder in the wrapped docstring.
make_trainable = docstring_util.expand_docstring(
    minimize_example_code="""
```python
model = tfp.util.make_trainable(tfd.Normal)
losses = tfp.math.minimize(
  lambda: -model.log_prob(samples),
  optimizer=tf.optimizers.Adam(0.1),
  num_steps=200)
print('Fit Normal distribution with mean {} and stddev {}'.format(
  model.mean(),
  model.stddev()))
```""")(
        trainable_state_util.as_stateful_builder(_make_trainable))

make_trainable_stateless = docstring_util.expand_docstring(
    minimize_example_code="""
```python
init_fn, apply_fn = tfe_util.make_trainable_stateless(tfd.Normal)

import optax  # JAX only.
mle_params, losses = tfp.math.minimize_stateless(
  lambda *params: -apply_fn(params).log_prob(samples),
  init=init_fn(),
  optimizer=optax.adam(0.1),
  num_steps=200)
model = apply_fn(mle_params)
print('Fit Normal distribution with mean {} and stddev {}'.format(
  model.mean(),