def new(params, event_shape='auto', pre_softplus=False, clip_for_stable=True,
        validate_args=False, name=None):
  """Create the distribution instance from a `params` vector."""
  event_shape = _preprocess_eventshape(params, event_shape)
  with tf.compat.v1.name_scope(name, 'Dirichlet', [params, event_shape]):
    params = tf.convert_to_tensor(value=params, name='params')
    event_shape = dist_util.expand_to_vector(
        tf.convert_to_tensor(
            value=event_shape, name='event_shape', dtype=tf.int32),
        tensor_name='event_shape')
    output_shape = tf.concat([
        tf.shape(input=params)[:-1],
        event_shape,
    ], axis=0)
    # Clips the Dirichlet parameters to the numerically stable KL region.
    if pre_softplus:
      params = tf.nn.softplus(params)
    if clip_for_stable:
      params = tf.clip_by_value(params, 1e-3, 1e3)
    return tfd.Independent(
        tfd.Dirichlet(
            concentration=tf.reshape(params, output_shape),
            validate_args=validate_args),
        reinterpreted_batch_ndims=tf.size(input=event_shape),
        validate_args=validate_args)
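# Hedged sketch (not part of the source above): the softplus-then-clip
# stabilization that `new` applies, shown directly on a toy tensor. Uses only
# stock TensorFlow / TFP; all names and shapes below are illustrative.
import tensorflow as tf
import tensorflow_probability as tfp

raw = tf.constant([[-5., 0., 5.]])  # unconstrained network outputs
# Softplus maps to (0, inf); the clip keeps concentrations in the
# numerically stable KL region [1e-3, 1e3], as in `new` above.
concentration = tf.clip_by_value(tf.nn.softplus(raw), 1e-3, 1e3)
dist = tfp.distributions.Dirichlet(concentration=concentration)
sample = dist.sample()  # a point on the 3-simplex, shape [1, 3]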
def testGradientsThroughParams(self):
  logits = tf.Variable(np.zeros((3, 5, 2)), dtype=tf.float32,
                       shape=tf.TensorShape([None, None, 2]))
  concentration = tf.Variable(np.ones((3, 5, 4)), dtype=tf.float32,
                              shape=tf.TensorShape(None))
  loc = tf.Variable(np.zeros((3, 5, 4)), dtype=tf.float32,
                    shape=tf.TensorShape(None))
  scale = tf.Variable(1., dtype=tf.float32, shape=tf.TensorShape(None))

  dist = tfd.Mixture(
      tfd.Categorical(logits=logits),
      components=[
          tfd.Dirichlet(concentration),
          tfd.MultivariateNormalDiag(
              loc=loc, scale_identity_multiplier=scale)
      ],
      use_static_graph=self.use_static_graph,
      validate_args=True)

  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(dist.log_prob(tf.ones((3, 5, 4)) / 4.))
  grad = tape.gradient(loss, dist.trainable_variables)
  self.assertLen(grad, 4)
  self.assertAllNotNone(grad)
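# Standalone sketch of the pattern exercised by the test above (simplified
# here, by assumption, to a single Dirichlet): gradients flow from `log_prob`
# back to the `tf.Variable` that parameterizes the distribution.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

concentration = tf.Variable(np.ones((4,)), dtype=tf.float32)
dist = tfd.Dirichlet(concentration)
with tf.GradientTape() as tape:
  loss = -tf.reduce_sum(dist.log_prob(tf.ones((4,)) / 4.))
grad = tape.gradient(loss, dist.trainable_variables)
# `grad[0]` has shape (4,) and is ready for an optimizer step.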
def test_bug170030378(self):
  n_item = 50
  n_rater = 7

  stream = test_util.test_seed_stream()
  weight = self.evaluate(
      tfd.Sample(tfd.Dirichlet([0.25, 0.25]), n_item).sample(seed=stream()))
  mixture_dist = tfd.Categorical(probs=weight)  # batch_shape=[50]

  rater_sensitivity = self.evaluate(
      tfd.Sample(tfd.Beta(5., 1.), n_rater).sample(seed=stream()))
  rater_specificity = self.evaluate(
      tfd.Sample(tfd.Beta(2., 5.), n_rater).sample(seed=stream()))

  probs = tf.stack([rater_sensitivity, rater_specificity])[None, ...]

  components_dist = tfd.BatchBroadcast(  # batch_shape=[50, 2]
      tfd.Independent(tfd.Bernoulli(probs=probs),
                      reinterpreted_batch_ndims=1),
      [50, 2])

  obs_dist = tfd.MixtureSameFamily(mixture_dist, components_dist)
  observed = self.evaluate(obs_dist.sample(seed=stream()))

  mixture_logp = obs_dist.log_prob(observed)
  expected_logp = tf.math.reduce_logsumexp(
      tf.math.log(weight) + components_dist.distribution.log_prob(
          observed[:, None, ...]),
      axis=-1)
  self.assertAllClose(expected_logp, mixture_logp)
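# Hedged sketch of the identity the test verifies, on a toy mixture:
# log p(x) = logsumexp_k(log w_k + log p_k(x)) for MixtureSameFamily.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

weight = tf.constant([0.3, 0.7])
components = tfd.Normal(loc=[-1., 1.], scale=[1., 1.])
mixture = tfd.MixtureSameFamily(tfd.Categorical(probs=weight), components)
x = tf.constant(0.5)
manual = tf.math.reduce_logsumexp(
    tf.math.log(weight) + components.log_prob(x), axis=-1)
# mixture.log_prob(x) and `manual` agree to numerical precision.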
def testExcessiveConcretizationOfParams(self):
  logits = tfp_hps.defer_and_count_usage(
      tf.Variable(np.zeros((3, 5, 2)), dtype=tf.float32,
                  shape=tf.TensorShape([None, None, 2]), name='logits'))
  concentration = tfp_hps.defer_and_count_usage(
      tf.Variable(np.ones((3, 5, 4)), dtype=tf.float32,
                  shape=tf.TensorShape(None), name='concentration'))
  loc = tfp_hps.defer_and_count_usage(
      tf.Variable(np.zeros((3, 5, 4)), dtype=tf.float32,
                  shape=tf.TensorShape(None), name='loc'))
  scale = tfp_hps.defer_and_count_usage(
      tf.Variable(1., dtype=tf.float32,
                  shape=tf.TensorShape(None), name='scale'))

  dist = tfd.Mixture(
      tfd.Categorical(logits=logits),
      components=[
          tfd.Dirichlet(concentration),
          tfd.Independent(tfd.Normal(loc=loc, scale=scale),
                          reinterpreted_batch_ndims=1)
      ],
      use_static_graph=self.use_static_graph,
      validate_args=True)

  for method in ('batch_shape_tensor', 'event_shape_tensor',
                 'entropy_lower_bound'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=2):
      getattr(dist, method)()

  with tfp_hps.assert_no_excessive_var_usage('sample', max_permissible=2):
    dist.sample(seed=test_util.test_seed())

  for method in ('prob', 'log_prob'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=2):
      getattr(dist, method)(tf.ones((3, 5, 4)) / 4.)

  # TODO(b/140579567): The `stddev()` and `variance()` methods require
  # calling both:
  #  - `self.components[i].mean()`
  #  - `self.components[i].stddev()`
  # Thus, these methods incur an additional concretization (or two if
  # `validate_args=True` for `self.components[i]`).
  for method in ('stddev', 'variance'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=3):
      getattr(dist, method)()
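# Illustrative eager-mode sketch (the real counting helpers live in TFP's
# internal `tfp_hps` hypothesis test library): a `tfp.util.DeferredTensor`
# with a side-effecting identity roughly mimics `defer_and_count_usage`,
# counting how often a parameter is concretized.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

reads = [0]
def counting_identity(x):
  reads[0] += 1  # incremented each time the parameter is concretized
  return x

var = tf.Variable(tf.ones([4]))
concentration = tfp.util.DeferredTensor(var, counting_identity)
dist = tfd.Dirichlet(concentration=concentration)
_ = dist.sample()
print(reads[0])  # number of concretizations incurred by `sample`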
def new(params, event_shape=(), concentration_activation=softplus1,
        concentration_clip=True, validate_args=False, name='DirichletLayer'):
  r"""Create the distribution instance from a `params` vector."""
  params = tf.convert_to_tensor(value=params, name='params')
  # Clips the Dirichlet parameters to the numerically stable KL region.
  concentration_activation = parse_activation(concentration_activation, 'tf')
  params = concentration_activation(params)
  if concentration_clip:
    params = tf.clip_by_value(params, 1e-3, 1e3)
  return tfd.Dirichlet(concentration=params,
                       validate_args=validate_args,
                       name=name)
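# Hedged usage sketch for the layer-style factory above. `softplus1` and
# `parse_activation` are helpers defined elsewhere in this module, and the
# shapes below are illustrative.
import tensorflow as tf

raw = tf.random.normal([8, 4])    # unconstrained outputs, e.g. from a Dense head
dist = new(raw)                   # activation, optional clip, then Dirichlet
sample = dist.sample()            # shape [8, 4]; each row lies on the simplex
log_prob = dist.log_prob(sample)  # shape [8]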
def get_distributions(self, validate_args=False):
  self.dist1 = tfd.MultivariateNormalDiag(
      loc=self.maybe_static(
          tf.zeros(self.batch_dim_1 + self.event_dim_1, dtype=self.dtype),
          self.is_static),
      scale_diag=self.maybe_static(
          tf.ones(self.batch_dim_1 + self.event_dim_1, dtype=self.dtype),
          self.is_static))

  self.dist2 = tfd.OneHotCategorical(
      logits=self.maybe_static(
          tf.zeros(self.batch_dim_2 + self.event_dim_2), self.is_static),
      dtype=self.dtype)

  self.dist3 = tfd.Dirichlet(
      self.maybe_static(
          tf.zeros(self.batch_dim_3 + self.event_dim_3, dtype=self.dtype),
          self.is_static))

  return batch_concat.BatchConcat(
      distributions=[self.dist1, self.dist2, self.dist3],
      axis=self.axis,
      validate_args=validate_args)
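# Hedged standalone sketch of what this fixture builds (imports the
# `batch_concat` module directly, as the test does; the toy shapes are
# assumptions): distributions with identical event shapes, concatenated
# along a batch axis.
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.distributions import batch_concat
tfd = tfp.distributions

d1 = tfd.MultivariateNormalDiag(loc=tf.zeros([2, 3]),
                                scale_diag=tf.ones([2, 3]))  # batch [2]
d2 = tfd.Dirichlet(tf.ones([5, 3]))                          # batch [5]
concat = batch_concat.BatchConcat(distributions=[d1, d2], axis=0)
# concat.batch_shape == [7], concat.event_shape == [3]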
class MarkovChainBijectorTest(test_util.TestCase):

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      dict(testcase_name='deterministic_prior',
           prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
           transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.)),
      dict(testcase_name='deterministic_transition',
           prior_fn=lambda: tfd.Normal(loc=[-100., 0., 100.], scale=1.),
           transition_fn=lambda _, x: tfd.Deterministic(x)),
      dict(testcase_name='fully_deterministic',
           prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
           transition_fn=lambda _, x: tfd.Deterministic(x)),
      dict(testcase_name='mvn_diag',
           prior_fn=(
               lambda: tfd.MultivariateNormalDiag(loc=[[2.], [2.]],
                                                  scale_diag=[1.])),
           transition_fn=lambda _, x: tfd.VectorDeterministic(x)),
      dict(testcase_name='docstring_dirichlet',
           prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
               {'probs': tfd.Dirichlet([1., 1.])}),
           transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
               {'probs': tfd.MultivariateNormalDiag(loc=x['probs'],
                                                    scale_diag=[0.1, 0.1])},
               batch_ndims=ps.rank(x['probs']))),
      dict(testcase_name='uniform_step',
           prior_fn=lambda: tfd.Exponential(tf.ones([4, 1])),
           transition_fn=lambda _, x: tfd.Uniform(low=x, high=x + 1.)),
      dict(testcase_name='joint_distribution',
           prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
               batch_ndims=2,
               model={
                   'a': tfd.Gamma(tf.zeros([5]), 1.),
                   'b': lambda a: (
                       tfb.Reshape(
                           event_shape_in=[4, 3],
                           event_shape_out=[2, 3, 2])(
                               tfd.Independent(
                                   tfd.Normal(
                                       loc=tf.zeros([5, 4, 3]),
                                       scale=a[..., tf.newaxis, tf.newaxis]),
                                   reinterpreted_batch_ndims=2)))
               }),
           transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
               batch_ndims=ps.rank_from_shape(x['a'].shape),
               model={
                   'a': tfd.Normal(loc=x['a'], scale=1.),
                   'b': lambda a: tfd.Deterministic(
                       x['b'] + a[..., tf.newaxis, tf.newaxis, tf.newaxis])
               })),
      dict(testcase_name='nested_chain',
           prior_fn=lambda: tfd.MarkovChain(
               initial_state_prior=tfb.Split(2)(
                   tfd.MultivariateNormalDiag(0., [1., 2.])),
               transition_fn=lambda _, x: tfb.Split(2)(
                   tfd.MultivariateNormalDiag(x[0], [1., 2.])),
               num_steps=6),
           transition_fn=(
               lambda _, x: tfd.JointDistributionSequentialAutoBatched(
                   [tfd.MultivariateNormalDiag(x[0], [1.]),
                    tfd.MultivariateNormalDiag(x[1], [1.])],
                   batch_ndims=ps.rank(x[0])))))
  # pylint: enable=g-long-lambda
  def test_default_bijector(self, prior_fn, transition_fn):
    chain = tfd.MarkovChain(initial_state_prior=prior_fn(),
                            transition_fn=transition_fn,
                            num_steps=7)

    y = self.evaluate(chain.sample(seed=test_util.test_seed()))
    bijector = chain.experimental_default_event_space_bijector()

    self.assertAllEqual(chain.batch_shape_tensor(),
                        bijector.experimental_batch_shape_tensor())

    x = bijector.inverse(y)
    yy = bijector.forward(
        tf.nest.map_structure(tf.identity, x))  # Bypass bijector cache.
    self.assertAllCloseNested(y, yy)

    chain_event_ndims = tf.nest.map_structure(
        ps.rank_from_shape, chain.event_shape_tensor())
    self.assertAllEqualNested(bijector.inverse_min_event_ndims,
                              chain_event_ndims)

    ildj = bijector.inverse_log_det_jacobian(
        tf.nest.map_structure(tf.identity, y),  # Bypass bijector cache.
        event_ndims=chain_event_ndims)
    if not bijector.is_constant_jacobian:
      self.assertAllEqual(ildj.shape, chain.batch_shape)
    fldj = bijector.forward_log_det_jacobian(
        tf.nest.map_structure(tf.identity, x),  # Bypass bijector cache.
        event_ndims=bijector.inverse_event_ndims(chain_event_ndims))
    self.assertAllClose(ildj, -fldj)

    # Verify that event shapes are passed through and flattened/unflattened
    # correctly.
    inverse_event_shapes = bijector.inverse_event_shape(chain.event_shape)
    x_event_shapes = tf.nest.map_structure(
        lambda t, nd: t.shape[ps.rank(t) - nd:],
        x, bijector.forward_min_event_ndims)
    self.assertAllEqualNested(inverse_event_shapes, x_event_shapes)

    forward_event_shapes = bijector.forward_event_shape(inverse_event_shapes)
    self.assertAllEqualNested(forward_event_shapes, chain.event_shape)

    # Verify that the outputs of other methods have the correct structure.
    inverse_event_shape_tensors = bijector.inverse_event_shape_tensor(
        chain.event_shape_tensor())
    self.assertAllEqualNested(inverse_event_shape_tensors, x_event_shapes)

    forward_event_shape_tensors = bijector.forward_event_shape_tensor(
        inverse_event_shape_tensors)
    self.assertAllEqualNested(forward_event_shape_tensors,
                              chain.event_shape_tensor())
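# Hedged standalone sketch of the API exercised above: a scalar-state
# MarkovChain and one round trip through its default event-space bijector.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

chain = tfd.MarkovChain(
    initial_state_prior=tfd.Normal(loc=0., scale=1.),
    transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.),
    num_steps=7)
y = chain.sample(seed=(0, 0))  # event shape [7]: one scalar state per step
bijector = chain.experimental_default_event_space_bijector()
x = bijector.inverse(y)        # pull back into unconstrained space
tf.debugging.assert_near(bijector.forward(x), y)  # round trip recovers y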