def _fn(dtype, shape, name, trainable, add_variable_fn,
        initializer=tf.random_normal_initializer(stddev=0.1),
        regularizer=None, constraint=None, **kwargs):
    """Create an `Independent` `Deterministic` or reparametrized `Normal`.

    The location/scale pair is produced by `tensor_loc_scale_fn`. When a
    scale is available, the pair is passed (as location and precision)
    through `reparametrize_loc_scale` before the `Normal` is built.

    NOTE(review): `tensor_loc_scale_fn`, `precision_from_scale`,
    `reparametrize_loc_scale`, `loc_ratio` and `prec_ratio` are resolved
    from the enclosing scope, which is not visible here; confirm their
    contracts against their definitions.

    Args:
      dtype: Forwarded to the loc/scale function.
      shape: Forwarded to the loc/scale function.
      name: Forwarded to the loc/scale function.
      trainable: Forwarded to the loc/scale function.
      add_variable_fn: Forwarded to the loc/scale function.
      initializer: Passed to `tensor_loc_scale_fn` as `loc_initializer`.
      regularizer: Passed to `tensor_loc_scale_fn` as `loc_regularizer`.
      constraint: Passed to `tensor_loc_scale_fn` as `loc_constraint`.
      **kwargs: Extra keyword arguments forwarded to `tensor_loc_scale_fn`.

    Returns:
      A `tfd.Independent` wrapping the distribution, with every batch
      dimension reinterpreted as an event dimension.
    """
    loc_scale_fn = tensor_loc_scale_fn(
        loc_initializer=initializer,
        loc_regularizer=regularizer,
        loc_constraint=constraint,
        **kwargs)
    loc, scale = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)

    if scale is None:
        # No scale means a point estimate; there is no precision to derive.
        dist = tfd.Deterministic(loc=loc)
    else:
        # Build the deferred precision only once a scale is known to exist.
        # Doing so before the `None` check would wrap `None` in a
        # `DeferredTensor` that the deterministic branch never uses.
        prec = tfp.util.DeferredTensor(
            scale, precision_from_scale, name='precision')
        loc_reparametrized, scale_reparametrized = reparametrize_loc_scale(
            loc, prec, loc_ratio, prec_ratio)
        dist = tfd.Normal(loc=loc_reparametrized, scale=scale_reparametrized)

    # Treat every batch dimension of the base distribution as part of the event.
    batch_ndims = tf.size(dist.batch_shape_tensor())
    return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims)
def new(params, event_shape=(), log_prob=None,
        reinterpreted_batch_ndims=None, validate_args=False,
        name='DeterministicLayer'):
    """Build a `Deterministic` distribution from a `params` tensor.

    Args:
      params: Value convertible to a tensor. Its leading dimensions (all but
        the last) are kept and `event_shape` is appended to form the shape
        of the distribution's `loc`.
      event_shape: Scalar or vector of `int32` giving the trailing shape.
      log_prob: Optional callable. When callable, it is bound to the
        distribution instance and replaces its `log_prob` method.
      reinterpreted_batch_ndims: Optional number of batch dimensions to
        reinterpret as event dimensions; only applied when positive.
      validate_args: Forwarded to `tfd.Deterministic`.
      name: Name given to the `tfd.Deterministic` instance.

    Returns:
      The `tfd.Deterministic`, wrapped in `tfd.Independent` when
      `reinterpreted_batch_ndims` is a positive number.
    """
    params = tf.convert_to_tensor(value=params, name='params')
    event_shape = dist_util.expand_to_vector(
        tf.convert_to_tensor(
            value=event_shape, name='event_shape', dtype=tf.int32),
        tensor_name='event_shape',
    )

    # Drop the last axis of `params` and append the requested event shape.
    leading_dims = tf.shape(input=params)[:-1]
    output_shape = tf.concat([leading_dims, event_shape], axis=0)
    loc = tf.reshape(params, output_shape)
    dist = tfd.Deterministic(loc=loc, validate_args=validate_args, name=name)

    # Override the log-prob function (`callable(None)` is False, so a missing
    # override is skipped as well).
    if callable(log_prob):
        dist.log_prob = types.MethodType(log_prob, dist)

    # Without a positive count there is nothing to reinterpret.
    if reinterpreted_batch_ndims is None or not reinterpreted_batch_ndims > 0:
        return dist
    return tfd.Independent(
        dist, reinterpreted_batch_ndims=int(reinterpreted_batch_ndims))
def model():
    """Coroutine describing a discrete Markov chain over `n_steps` states.

    Yields a root `Categorical` built from `initial_prob`, then
    `n_steps - 1` `Categorical`s whose probabilities are gathered from
    `transition_matrix` at the previously yielded state, and finally a
    `Deterministic` located at the last state.
    """
    state = yield Root(tfd.Categorical(probs=initial_prob, dtype=tf.int32))
    num_transitions = n_steps - 1
    for _ in range(num_transitions):
        # Select the transition probabilities for the current state.
        next_probs = tf.gather(transition_matrix, state)
        state = yield tfd.Categorical(probs=next_probs, dtype=tf.int32)
    yield tfd.Deterministic(state)
def model():
    """Coroutine yielding mapped weights followed by their combined norm.

    `_horseshoe` is mapped (with tuple paths) over a two-entry `scale`
    structure, and the resulting `weights['a']` and `weights['b']` are
    combined into a single `Deterministic` named `'weights_norm'`.
    """
    scales = {
        'a': tf.ones([5]) * 100.,
        'b': tf.ones([2]) * 1e-2
    }
    weights = yield from nest_util.map_structure_coroutine(
        _horseshoe, scale=scales, _with_tuple_paths=True)

    # Square root of the sum of the squared norms of both weight groups.
    squared_norm_a = tf.norm(weights['a'])**2
    squared_norm_b = tf.norm(weights['b'])**2
    total_norm = tf.sqrt(squared_norm_a + squared_norm_b)
    yield tfd.Deterministic(total_norm, name='weights_norm')
def transition_fn(_, previous_state):
    """Return the joint distribution over the next state.

    Args:
      _: Unused first argument.
      previous_state: Mapping with keys `'coefs'`, `'log_scale'`, `'level'`
        and `'previous_level'`.

    Returns:
      A `tfd.JointDistributionNamed` over the same four keys.
    """
    # The autoregressive coefficients and the `log_scale` each follow
    # an independent slow-moving random walk.
    coefs_dist = tfd.Independent(
        tfd.Normal(loc=previous_state['coefs'], scale=0.01),
        reinterpreted_batch_ndims=1)
    log_scale_dist = tfd.Normal(loc=previous_state['log_scale'], scale=0.01)

    def level_dist(coefs, log_scale):
        # The level is a linear combination of the previous *two* levels,
        # with additional noise of scale `exp(log_scale)`.
        mean = (coefs[..., 0] * previous_state['level'] +
                coefs[..., 1] * previous_state['previous_level'])
        return tfd.Normal(loc=mean, scale=tf.exp(log_scale))

    # Store the previous level to access at the next step.
    previous_level_dist = tfd.Deterministic(previous_state['level'])

    return tfd.JointDistributionNamed({
        'coefs': coefs_dist,
        'log_scale': log_scale_dist,
        'level': level_dist,
        'previous_level': previous_level_dist,
    })
def transition_fn(_, previous_state):
    """Return the auto-batched joint distribution over the next state.

    Args:
      _: Unused first argument.
      previous_state: Mapping with keys `'coefs'`, `'log_scale'`, `'level'`
        and `'previous_level'`.

    Returns:
      A `tfd.JointDistributionNamedAutoBatched` over the same four keys.
    """
    # The previous state may include batch dimensions. Since the log scale
    # is a scalar quantity, its shape is the batch shape.
    num_batch_dims = ps.rank(previous_state['log_scale'])

    # The autoregressive coefficients and the `log_scale` each follow
    # an independent slow-moving random walk.
    next_coefs = tfd.Normal(loc=previous_state['coefs'], scale=0.01)
    next_log_scale = tfd.Normal(loc=previous_state['log_scale'], scale=0.01)

    def next_level(coefs, log_scale):
        # The level is a linear combination of the previous *two* levels,
        # with additional noise of scale `exp(log_scale)`.
        weighted_levels = (
            coefs[..., 0] * previous_state['level'] +
            coefs[..., 1] * previous_state['previous_level'])
        return tfd.Normal(loc=weighted_levels, scale=tf.exp(log_scale))

    # Store the previous level to access at the next step.
    carried_level = tfd.Deterministic(previous_state['level'])

    components = {
        'coefs': next_coefs,
        'log_scale': next_log_scale,
        'level': next_level,
        'previous_level': carried_level,
    }
    return tfd.JointDistributionNamedAutoBatched(
        batch_ndims=num_batch_dims, model=components)
def _fn(dtype, shape, name, trainable, add_variable_fn):
    """Build an `Independent` `Deterministic` or `Normal` distribution.

    The location and scale come from `loc_scale_fn`; a missing scale
    (`None`) selects the `Deterministic` form.

    Args:
      dtype: Type of parameter's event.
      shape: Python `list`-like representing the parameter's event shape.
      name: Python `str` name prepended to any created (or existing)
        `tf.Variable`s.
      trainable: Python `bool` indicating all created `tf.Variable`s should
        be added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
      add_variable_fn: `tf.get_variable`-like `callable` used to create (or
        access existing) `tf.Variable`s.

    Returns:
      Multivariate `Deterministic` or `Normal` distribution.
    """
    loc, scale = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)
    base = (tfd.Deterministic(loc=loc) if scale is None
            else tfd.Normal(loc=loc, scale=scale))
    # Every batch dimension of the base distribution becomes an event dimension.
    num_batch_dims = tf.size(input=base.batch_shape_tensor())
    return tfd.Independent(base, reinterpreted_batch_ndims=num_batch_dims)
class MarkovChainBijectorTest(test_util.TestCase):
    """Checks the default event-space bijector of `tfd.MarkovChain`.

    Each named parameter set supplies a `prior_fn` (a zero-argument callable
    that builds the initial state prior) and a `transition_fn` passed
    straight to `tfd.MarkovChain`.
    """

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        # Point-mass prior with unit-scale Normal transitions.
        dict(testcase_name='deterministic_prior',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.)),
        # Normal prior with point-mass transitions.
        dict(testcase_name='deterministic_transition',
             prior_fn=lambda: tfd.Normal(loc=[-100., 0., 100.], scale=1.),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        # Point masses for both the prior and the transitions.
        dict(testcase_name='fully_deterministic',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        # Vector-valued prior with `VectorDeterministic` transitions.
        dict(testcase_name='mvn_diag',
             prior_fn=(lambda: tfd.MultivariateNormalDiag(loc=[[2.], [2.]],
                                                          scale_diag=[1.])),
             transition_fn=lambda _, x: tfd.VectorDeterministic(x)),
        # Dict-structured state: Dirichlet prior, MVN transitions.
        dict(testcase_name='docstring_dirichlet',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 {'probs': tfd.Dirichlet([1., 1.])}),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 {
                     'probs': tfd.MultivariateNormalDiag(loc=x['probs'],
                                                         scale_diag=[0.1, 0.1])
                 },
                 batch_ndims=ps.rank(x['probs']))),
        # Exponential prior; each step is Uniform on [x, x + 1].
        dict(testcase_name='uniform_step',
             prior_fn=lambda: tfd.Exponential(tf.ones([4, 1])),
             transition_fn=lambda _, x: tfd.Uniform(low=x, high=x + 1.)),
        # Two-part joint state in which 'b' depends on 'a' and is reshaped
        # from event shape [4, 3] to [2, 3, 2].
        dict(testcase_name='joint_distribution',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=2,
                 model={
                     'a': tfd.Gamma(tf.zeros([5]), 1.),
                     'b': lambda a: (
                         tfb.Reshape(event_shape_in=[4, 3],
                                     event_shape_out=[2, 3, 2])
                         (tfd.Independent(tfd.Normal(
                             loc=tf.zeros([5, 4, 3]),
                             scale=a[..., tf.newaxis, tf.newaxis]),
                                          reinterpreted_batch_ndims=2)))
                 }),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=ps.rank_from_shape(x['a'].shape),
                 model={
                     'a': tfd.Normal(loc=x['a'], scale=1.),
                     'b': lambda a: tfd.Deterministic(x['b'] + a[
                         ..., tf.newaxis, tf.newaxis, tf.newaxis])
                 })),
        # The prior is itself a six-step `MarkovChain` over a split MVN.
        dict(testcase_name='nested_chain',
             prior_fn=lambda: tfd.
             MarkovChain(initial_state_prior=tfb.Split(2)
                         (tfd.MultivariateNormalDiag(0., [1., 2.])),
                         transition_fn=lambda _, x: tfb.Split(2)
                         (tfd.MultivariateNormalDiag(x[0], [1., 2.])),
                         num_steps=6),
             transition_fn=(
                 lambda _, x: tfd.JointDistributionSequentialAutoBatched(
                     [
                         tfd.MultivariateNormalDiag(x[0], [1.]),
                         tfd.MultivariateNormalDiag(x[1], [1.])
                     ],
                     batch_ndims=ps.rank(x[0])))))
    # pylint: enable=g-long-lambda
    def test_default_bijector(self, prior_fn, transition_fn):
        """Round-trips a sample through the chain's default bijector.

        Builds a seven-step chain and verifies that the bijector's batch
        shape matches the chain's, that `forward(inverse(y))` recovers the
        sample, that forward and inverse log-det-Jacobians are negatives
        of each other, and that event shapes (static and tensor-valued)
        map consistently in both directions.

        Args:
          prior_fn: Zero-argument callable returning the initial state prior.
          transition_fn: Transition callable handed to `tfd.MarkovChain`.
        """
        chain = tfd.MarkovChain(initial_state_prior=prior_fn(),
                                transition_fn=transition_fn,
                                num_steps=7)

        y = self.evaluate(chain.sample(seed=test_util.test_seed()))
        bijector = chain.experimental_default_event_space_bijector()

        self.assertAllEqual(chain.batch_shape_tensor(),
                            bijector.experimental_batch_shape_tensor())

        x = bijector.inverse(y)
        yy = bijector.forward(tf.nest.map_structure(
            tf.identity, x))  # Bypass bijector cache.
        self.assertAllCloseNested(y, yy)

        chain_event_ndims = tf.nest.map_structure(ps.rank_from_shape,
                                                  chain.event_shape_tensor())
        self.assertAllEqualNested(bijector.inverse_min_event_ndims,
                                  chain_event_ndims)

        ildj = bijector.inverse_log_det_jacobian(
            tf.nest.map_structure(tf.identity, y),  # Bypass bijector cache.
            event_ndims=chain_event_ndims)
        # The ILDJ shape is only compared with the chain's batch shape when
        # the Jacobian is not constant.
        if not bijector.is_constant_jacobian:
            self.assertAllEqual(ildj.shape, chain.batch_shape)
        fldj = bijector.forward_log_det_jacobian(
            tf.nest.map_structure(tf.identity, x),  # Bypass bijector cache.
            event_ndims=bijector.inverse_event_ndims(chain_event_ndims))
        self.assertAllClose(ildj, -fldj)

        # Verify that event shapes are passed through and flattened/unflattened
        # correctly.
        inverse_event_shapes = bijector.inverse_event_shape(chain.event_shape)
        # Trailing `nd` dimensions of each unconstrained part form its event
        # shape.
        x_event_shapes = tf.nest.map_structure(
            lambda t, nd: t.shape[ps.rank(t) - nd:],
            x, bijector.forward_min_event_ndims)
        self.assertAllEqualNested(inverse_event_shapes, x_event_shapes)
        forward_event_shapes = bijector.forward_event_shape(
            inverse_event_shapes)
        self.assertAllEqualNested(forward_event_shapes, chain.event_shape)

        # Verify that the outputs of other methods have the correct structure.
        inverse_event_shape_tensors = bijector.inverse_event_shape_tensor(
            chain.event_shape_tensor())
        self.assertAllEqualNested(inverse_event_shape_tensors, x_event_shapes)
        forward_event_shape_tensors = bijector.forward_event_shape_tensor(
            inverse_event_shape_tensors)
        self.assertAllEqualNested(forward_event_shape_tensors,
                                  chain.event_shape_tensor())