def test_structured_parameters(self):
  make_trainable_normal = trainable_state_util.as_stateful_builder(
      yields_structured_parameter)
  trainable_normal = make_trainable_normal()
  self.assertLen(trainable_normal.trainable_variables, 2)
  self.evaluate(tf1.global_variables_initializer())
  self.assertAllEqual(trainable_normal.loc, tf.zeros([2]))
  self.assertAllEqual(trainable_normal.scale, tf.ones([2]))
def test_rewrites_yield_to_return_in_docstring(self):
  wrapped = trainable_state_util.as_stateful_builder(
      generator_with_docstring)
  self.assertContainsExactSubsequence(
      generator_with_docstring.__doc__, 'Yields:')
  self.assertNotIn('Yields:', wrapped.__doc__)
  self.assertContainsExactSubsequence(
      wrapped.__doc__, 'Test generator with a docstring.')
  self.assertContainsExactSubsequence(
      wrapped.__doc__, trainable_state_util._STATEFUL_RETURNS_DOCSTRING)
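# Note: the rewrite exercised above is purely textual. A trailing section of
# the form
#
#   Yields:
#     <description of the yielded parameters>
#
# in the generator's docstring is replaced by the standard stateful text in
# `trainable_state_util._STATEFUL_RETURNS_DOCSTRING`, so the built function
# documents the value it returns rather than the parameters its underlying
# coroutine yields.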
def test_initialization_is_deterministic_with_seed(self):
  seed = test_util.test_seed(sampler_type='stateless')
  make_trainable_jd = trainable_state_util.as_stateful_builder(
      seed_generator)
  trainable_jd1 = make_trainable_jd(seed=seed)
  variables1 = trainable_jd1.trainable_variables
  self.assertLen(variables1, 5)
  trainable_jd2 = make_trainable_jd(seed=seed)
  variables2 = trainable_jd2.trainable_variables
  self.evaluate([v.initializer for v in variables1 + variables2])
  vals1, vals2 = self.evaluate((variables1, variables2))
  self.assertAllCloseNested(vals1, vals2)
def test_fitting_example(self):
  build_trainable_normal = trainable_state_util.as_stateful_builder(
      normal_generator)
  trainable_dist = build_trainable_normal(
      shape=[], seed=test_util.test_seed(sampler_type='stateless'))
  optimizer = tf.optimizers.Adam(1.0)
  # Find the maximum likelihood distribution given observed data.
  x_observed = [3., -2., 1.7]
  losses = tfp.math.minimize(
      loss_fn=lambda: -trainable_dist.log_prob(x_observed),
      optimizer=optimizer,
      num_steps=300)
  self.evaluate(tf1.global_variables_initializer())
  losses = self.evaluate(losses)
  self.assertAllClose(trainable_dist.mean(), np.mean(x_observed), atol=0.1)
  self.assertAllClose(trainable_dist.stddev(), np.std(x_observed), atol=0.1)
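# For contrast, the stateless path expresses the same fit functionally. A
# minimal sketch (assuming `init_fn` accepts an optional seed, analogous to
# the stateful builder above):
#
#   build_stateless = trainable_state_util.as_stateless_builder(
#       normal_generator)
#   init_fn, apply_fn = build_stateless(shape=[])
#   params = init_fn(seed=test_util.test_seed(sampler_type='stateless'))
#   dist = apply_fn(params)  # Rebuilds the distribution from `params`.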
          parameter_dtype=nest_util.broadcast_structure(event_shape, dtype),
          _up_to=event_shape))
  unconstrained_trainable_distribution = (
      joint_distribution_util.independent_joint_distribution_from_structure(
          unconstrained_trainable_distributions,
          batch_ndims=ps.rank_from_shape(batch_shape),
          validate_args=validate_args))
  if event_space_bijector is None:
    return unconstrained_trainable_distribution
  return transformed_distribution.TransformedDistribution(
      unconstrained_trainable_distribution, event_space_bijector)


build_factored_surrogate_posterior = trainable_state_util.as_stateful_builder(
    _factored_surrogate_posterior)
build_factored_surrogate_posterior_stateless = (
    trainable_state_util.as_stateless_builder(_factored_surrogate_posterior))


def _affine_surrogate_posterior(event_shape,
                                operators='diag',
                                bijector=None,
                                base_distribution=normal.Normal,
                                dtype=tf.float32,
                                batch_shape=(),
                                validate_args=False,
                                name=None):
  """Builds a joint variational posterior with a given `event_shape`.

  This function builds a surrogate posterior by applying a trainable
  transformation to a standard base distribution and constraining the samples
  with `bijector`.
    return tf.math.abs(x) + dtype_util.eps(dtype)

  def _inverse(self, y):
    dtype = dtype_util.base_dtype(y.dtype)
    return tf.math.abs(y) - np.finfo(dtype.as_numpy_dtype).eps

  def _inverse_log_det_jacobian(self, y):
    return tf.zeros([], dtype_util.base_dtype(y.dtype))


_OPERATOR_COROUTINES = {
    tf.linalg.LinearOperatorLowerTriangular: _trainable_linear_operator_tril,
    tf.linalg.LinearOperatorDiag: _trainable_linear_operator_diag,
    tf.linalg.LinearOperatorFullMatrix: _trainable_linear_operator_full_matrix,
    tf.linalg.LinearOperatorZeros: _linear_operator_zeros,
    None: _linear_operator_zeros,
}

# TODO(davmre): also expose stateless builders.
build_trainable_linear_operator_block = (
    trainable_state_util.as_stateful_builder(
        _trainable_linear_operator_block))
build_trainable_linear_operator_tril = (
    trainable_state_util.as_stateful_builder(_trainable_linear_operator_tril))
build_trainable_linear_operator_diag = (
    trainable_state_util.as_stateful_builder(_trainable_linear_operator_diag))
build_trainable_linear_operator_full_matrix = (
    trainable_state_util.as_stateful_builder(
        _trainable_linear_operator_full_matrix))
build_linear_operator_zeros = (
    trainable_state_util.as_stateful_builder(_linear_operator_zeros))
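# Example usage of the stateful builders (a minimal sketch, assuming `dims`
# is the leading argument of the underlying coroutines):
#
#   scale = build_trainable_linear_operator_tril(dims=3)
#   # `scale` acts as a `tf.linalg.LinearOperatorLowerTriangular` whose
#   # parameters are `tf.Variable`s, e.g.:
#   y = scale.matvec(tf.ones([3]))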
  https://arxiv.org/abs/2002.00643
  """
  with tf.name_scope(name or 'build_asvi_surrogate_posterior'):
    surrogate_posterior = yield from _asvi_surrogate_for_distribution(
        dist=prior,
        base_distribution_surrogate_fn=functools.partial(
            _asvi_convex_update_for_base_distribution,
            mean_field=mean_field,
            initial_prior_weight=initial_prior_weight),
        prior_substitution_rules=prior_substitution_rules,
        surrogate_rules=surrogate_rules)
    return surrogate_posterior


build_asvi_surrogate_posterior = trainable_state_util.as_stateful_builder(
    _build_asvi_surrogate_posterior)
# TODO(davmre): replace stateful example code in the stateless docstring.
build_asvi_surrogate_posterior_stateless = (
    trainable_state_util.as_stateless_builder(_build_asvi_surrogate_posterior))


def _get_coroutine_parameters(jdc_model, seed):
  """Runs a coroutine and intercepts yielded dists, yielding params only."""
  gen = jdc_model()
  to_send = None
  raw_parameters = []
  try:
    while True:
      val = gen.send(to_send)
      if isinstance(val, trainable_state_util.Parameter):
        to_send = yield val
        name=parameter_name)
  return cls(**init_kwargs)


make_trainable = docstring_util.expand_docstring(
    minimize_example_code="""
  ```python
  model = tfp.util.make_trainable(tfd.Normal)
  losses = tfp.math.minimize(
      lambda: -model.log_prob(samples),
      optimizer=tf.optimizers.Adam(0.1),
      num_steps=200)
  print('Fit Normal distribution with mean {} and stddev {}'.format(
      model.mean(), model.stddev()))
  ```""")(trainable_state_util.as_stateful_builder(_make_trainable))

make_trainable_stateless = docstring_util.expand_docstring(
    minimize_example_code="""
  ```python
  init_fn, apply_fn = tfe_util.make_trainable_stateless(tfd.Normal)

  import optax  # JAX only.
  mle_params, losses = tfp.math.minimize_stateless(
      lambda *params: -apply_fn(params).log_prob(samples),
      init=init_fn(),
      optimizer=optax.adam(0.1),
      num_steps=200)
  model = apply_fn(mle_params)
  print('Fit Normal distribution with mean {} and stddev {}'.format(
      model.mean(), model.stddev()))
  ```""")(trainable_state_util.as_stateless_builder(_make_trainable))