def test_default_event_space_bijector(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      s    =tfd.HalfNormal(2.5),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      df   =tfd.Exponential(2),
      x    =tfd.StudentT),
      validate_args=True)
  # pylint: enable=bad-whitespace
  with self.assertRaisesRegex(
      NotImplementedError,
      'all elements of `model` are `tfp.distribution`s'):
    d._experimental_default_event_space_bijector()

  d = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
           x=tfd.Normal(loc=1, scale=1.),
           s=tfd.HalfNormal(2.5)),
      validate_args=True)
  for b in d._model_flatten(d._experimental_default_event_space_bijector()):
    self.assertIsInstance(b, tfb.Bijector)
  self.assertSetEqual(
      set(d.model.keys()),
      set(d._experimental_default_event_space_bijector().keys()))
def test_sample_shape_propagation_nondefault_behavior(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      s    =tfd.HalfNormal(2.5),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      df   =tfd.Exponential(2),
      x    =tfd.StudentT),
      validate_args=False)
  # pylint: enable=bad-whitespace
  # The following enables the nondefault sample shape behavior.
  d._always_use_specified_sample_shape = True
  sample_shape = (2, 3)
  x = d.sample(sample_shape, seed=test_util.test_seed())
  self.assertLen(x, 6)
  self.assertEqual(sample_shape + (2,), x['e'].shape)
  self.assertEqual(sample_shape * 2, x['scale'].shape)  # Has 1 arg.
  self.assertEqual(sample_shape * 1, x['s'].shape)      # Has 0 args.
  self.assertEqual(sample_shape * 2, x['loc'].shape)    # Has 1 arg.
  self.assertEqual(sample_shape * 1, x['df'].shape)     # Has 0 args.
  # Has 3 args, one being scalar.
  self.assertEqual(sample_shape * 3, x['x'].shape)
  lp = d.log_prob(x)
  self.assertEqual(sample_shape * 3, lp.shape)
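# Note on the nondefault behavior exercised above: with
# `_always_use_specified_sample_shape = True`, the requested `sample_shape` is
# applied at each node *in addition to* the shape it inherits from sampled
# parents, so shapes compound along the dependency graph (e.g. `scale` depends
# on `e`, which was itself sampled with shape (2, 3), hence
# `x['scale'].shape == sample_shape * 2`) rather than being applied once
# overall, as in the default-behavior test further below.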
def test_default_event_space_bijector(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      s    =tfd.HalfNormal(2.5),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      df   =tfd.Exponential(2),
      x    =tfd.StudentT),
      validate_args=True)
  # pylint: enable=bad-whitespace
  # The event space bijector is inherited from `JointDistributionSequential`
  # and is tested more thoroughly in the tests for that class.
  b = d.experimental_default_event_space_bijector()
  y = self.evaluate(d.sample(seed=test_util.test_seed()))
  y_ = self.evaluate(b.forward(b.inverse(y)))
  self.assertAllClose(y, y_)

  # Verify that event shapes are passed through and flattened/unflattened
  # correctly.
  forward_event_shapes = b.forward_event_shape(d.event_shape)
  inverse_event_shapes = b.inverse_event_shape(d.event_shape)
  self.assertEqual(forward_event_shapes, d.event_shape)
  self.assertEqual(inverse_event_shapes, d.event_shape)

  # Verify that the outputs of other methods have the correct dict structure.
  forward_event_shape_tensors = b.forward_event_shape_tensor(
      d.event_shape_tensor())
  inverse_event_shape_tensors = b.inverse_event_shape_tensor(
      d.event_shape_tensor())
  for item in [forward_event_shape_tensors, inverse_event_shape_tensors]:
    self.assertSetEqual(set(self.evaluate(item).keys()), set(d.model.keys()))
def test_sample_complex_dependency(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      y    =tfd.StudentT,
      x    =tfd.StudentT,
      df   =tfd.Exponential(2),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      s    =tfd.HalfNormal(2.5),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1)),
      validate_args=False)
  # pylint: enable=bad-whitespace

  self.assertEqual((
      ('e', ()),
      ('scale', ('e',)),
      ('s', ()),
      ('loc', ('s',)),
      ('df', ()),
      ('y', ('df', 'loc', 'scale')),
      ('x', ('df', 'loc', 'scale')),
  ), d.resolve_graph())

  x = d.sample()
  self.assertLen(x, 7)

  ds, s = d.sample_distributions()
  self.assertEqual(ds['x'].parameters['df'], s['df'])
  self.assertEqual(ds['x'].parameters['loc'], s['loc'])
  self.assertEqual(ds['x'].parameters['scale'], s['scale'])
  self.assertEqual(ds['y'].parameters['df'], s['df'])
  self.assertEqual(ds['y'].parameters['loc'], s['loc'])
  self.assertEqual(ds['y'].parameters['scale'], s['scale'])
def testTransformedKLDifferentBijectorFails(self):
  d1 = self._cls()(
      tfd.Exponential(rate=0.25),
      bijector=tfb.Scale(scale=2.),
      validate_args=True)
  d2 = self._cls()(
      tfd.Exponential(rate=0.25),
      bijector=tfb.Scale(scale=3.),
      validate_args=True)
  with self.assertRaisesRegex(
      NotImplementedError, r'their bijectors are not equal'):
    tfd.kl_divergence(d1, d2)
def test_graph_resolution(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      s    =tfd.HalfNormal(2.5),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      df   =tfd.Exponential(2),
      x    =tfd.StudentT),
      validate_args=True)
  # pylint: enable=bad-whitespace
  self.assertEqual(
      (('e', ()),
       ('scale', ('e',)),
       ('s', ()),
       ('loc', ('s',)),
       ('df', ()),
       ('x', ('df', 'loc', 'scale'))),
      d.resolve_graph())
def test_dist_fn_takes_kwargs(self):
  dist = tfd.JointDistributionNamed(
      {'positive': tfd.Exponential(rate=1.),
       'negative': tfb.Scale(-1.)(tfd.Exponential(rate=1.)),
       'b': lambda **kwargs: tfd.Normal(  # pylint: disable=g-long-lambda
           loc=kwargs['negative'],
           scale=kwargs['positive'],
           validate_args=True),
       'a': lambda **kwargs: tfb.Scale(kwargs['b'])(  # pylint: disable=g-long-lambda
           tfd.Gamma(concentration=-kwargs['negative'],
                     rate=kwargs['positive'],
                     validate_args=True))
      }, validate_args=True)
  lp = dist.log_prob(dist.sample(5, seed=test_util.test_seed()))
  self.assertAllEqual(lp.shape, [5])
def test_cross_entropy(self):
  d0 = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
           x=tfd.Normal(loc=0, scale=2.)),
      validate_args=True)
  d1 = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
           x=tfd.Normal(loc=1, scale=1.)),
      validate_args=True)
  self.assertEqual(d0.model.keys(), d1.model.keys())
  expected_xent = sum(
      d0.model[k].cross_entropy(d1.model[k]) for k in d0.model.keys())
  actual_xent = d0.cross_entropy(d1)
  expected_xent_, actual_xent_ = self.evaluate([expected_xent, actual_xent])
  self.assertNear(actual_xent_, expected_xent_, err=1e-5)
def testStrWorksCorrectlyScalar(self):
  normal = tfd.Normal(loc=np.float16(0), scale=1, validate_args=True)
  self.assertEqual(
      str(normal),
      'tfp.distributions.Normal('
      '"Normal", '
      'batch_shape=[], '
      'event_shape=[], '
      'dtype=float16)')

  chi2 = tfd.Chi2(df=np.float32([1., 2.]), name='silly', validate_args=True)
  self.assertEqual(
      str(chi2),
      'tfp.distributions.Chi2('
      '"silly", '  # What a silly name that is!
      'batch_shape=[2], '
      'event_shape=[], '
      'dtype=float32)')

  # There's no notion of partially known shapes in eager mode, so exit
  # early.
  if tf.executing_eagerly():
    return

  exp = tfd.Exponential(rate=tf1.placeholder_with_default(1., shape=None),
                        validate_args=True)
  self.assertEqual(
      str(exp),
      'tfp.distributions.Exponential("Exponential", '
      # No batch shape.
      'event_shape=[], '
      'dtype=float32)')
def get_example_phylo_model(
    taxon_count: int,
    site_count: int,
    sampling_times: tf.Tensor,
    init_values: tp.Optional[tp.Dict[str, float]] = None,
    dtype: tf.DType = DEFAULT_FLOAT_DTYPE_TF,
    tree_name: str = DEFAULT_TREE_NAME,
    pattern_counts: tp.Optional[tf.Tensor] = None,
) -> tfd.JointDistribution:
  if init_values is None:
    init_values = {}
  constant = partial(tf.constant, dtype=dtype)
  rate = constant(init_values.get("rate", 1e-3))
  model_dict = dict(
      pop_size=tfd.Exponential(rate=constant(0.1), name="pop_size"),
      tree=lambda pop_size: ConstantCoalescent(
          taxon_count=taxon_count,
          pop_size=pop_size,
          sampling_times=sampling_times,
          tree_name=tree_name,
      ),
      alignment=strict_clock_fixed_alignment_func(
          rate, site_count=site_count, weights=pattern_counts),
  )
  return tfd.JointDistributionNamed(model_dict)
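# A minimal usage sketch for `get_example_phylo_model` (not part of the
# original module; the counts, sampling times, and seed below are hypothetical
# placeholder values):
#
#   sampling_times = tf.zeros([4], dtype=DEFAULT_FLOAT_DTYPE_TF)
#   model = get_example_phylo_model(
#       taxon_count=4, site_count=20, sampling_times=sampling_times)
#   sample = model.sample(seed=42)  # dict: 'pop_size', 'tree', 'alignment'
#   lp = model.log_prob(sample)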
def test_can_call_log_prob_with_kwargs(self):
  d = tfd.JointDistributionNamed(
      {'e': tfd.Normal(0., 1.),
       'a': tfd.Independent(tfd.Exponential(rate=[100, 120]),
                            reinterpreted_batch_ndims=1),
       'x': lambda a: tfd.Gamma(concentration=a[..., 0], rate=a[..., 1])
      }, validate_args=True)

  sample = self.evaluate(d.sample([2, 3], seed=test_util.test_seed()))
  e, a, x = sample['e'], sample['a'], sample['x']

  lp_value_positional = self.evaluate(d.log_prob({'e': e, 'a': a, 'x': x}))
  lp_value_named = self.evaluate(d.log_prob(value={'e': e, 'a': a, 'x': x}))
  # Assert all close (rather than equal) because order is not defined for
  # dicts, and reordering the computation can give subtly different results.
  self.assertAllClose(lp_value_positional, lp_value_named)

  lp_kwargs = self.evaluate(d.log_prob(a=a, e=e, x=x))
  self.assertAllClose(lp_value_positional, lp_kwargs)

  with self.assertRaisesRegex(
      ValueError, 'Joint distribution with unordered variables '
      "can't take positional args"):
    lp_kwargs = d.log_prob(e, a, x)
def test_can_call_namedtuple_log_prob_with_args_and_kwargs(self):
  # With a namedtuple, we can pass keyword and/or positional args.
  Model = collections.namedtuple('Model', ['e', 'a', 'x'])  # pylint: disable=invalid-name
  d = tfd.JointDistributionNamed(
      Model(e=tfd.Normal(0., 1.),
            a=tfd.Independent(tfd.Exponential(rate=[100, 120]),
                              reinterpreted_batch_ndims=1),
            x=lambda a: tfd.Gamma(concentration=a[..., 0], rate=a[..., 1])),
      validate_args=True)

  sample = self.evaluate(d.sample([2, 3], seed=test_util.test_seed()))
  e, a, x = sample.e, sample.a, sample.x

  lp_value_positional = self.evaluate(d.log_prob(Model(e=e, a=a, x=x)))
  lp_value_named = self.evaluate(d.log_prob(value=Model(e=e, a=a, x=x)))
  self.assertAllClose(lp_value_positional, lp_value_named)

  lp_kwargs = self.evaluate(d.log_prob(e=e, a=a, x=x))
  self.assertAllClose(lp_value_positional, lp_kwargs)

  lp_args = self.evaluate(d.log_prob(e, a, x))
  self.assertAllClose(lp_value_positional, lp_args)

  lp_args_then_kwargs = self.evaluate(d.log_prob(e, a=a, x=x))
  self.assertAllClose(lp_value_positional, lp_args_then_kwargs)
def testReprWorksCorrectlyScalar(self):
  normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
  self.assertEqual(
      repr(normal),
      '<tfp.distributions.Normal'
      ' \'Normal\''
      ' batch_shape=[]'
      ' event_shape=[]'
      ' dtype=float16>')

  chi2 = tfd.Chi2(df=np.float32([1., 2.]), name='silly')
  self.assertEqual(
      repr(chi2),
      '<tfp.distributions.Chi2'
      ' \'silly\''  # What a silly name that is!
      ' batch_shape=[2]'
      ' event_shape=[]'
      ' dtype=float32>')

  # There's no notion of partially known shapes in eager mode, so exit
  # early.
  if tf.executing_eagerly():
    return

  exp = tfd.Exponential(
      rate=tf1.placeholder_with_default(1., shape=None))
  self.assertEqual(
      repr(exp),
      '<tfp.distributions.Exponential'
      ' \'Exponential\''
      ' batch_shape=?'
      ' event_shape=[]'
      ' dtype=float32>')
def test_scalar_distributions(self):
  self.dist1 = tfd.Normal(
      loc=self.maybe_static(
          tf.zeros(self.batch_dim_1, dtype=self.dtype), self.is_static),
      scale=self.maybe_static(
          tf.ones(self.batch_dim_1, dtype=self.dtype), self.is_static))
  self.dist2 = tfd.Logistic(
      loc=self.maybe_static(
          tf.zeros(self.batch_dim_2, dtype=self.dtype), self.is_static),
      scale=self.maybe_static(
          tf.ones(self.batch_dim_2, dtype=self.dtype), self.is_static))
  self.dist3 = tfd.Exponential(
      rate=self.maybe_static(
          tf.ones(self.batch_dim_3, dtype=self.dtype), self.is_static))
  concat_dist = batch_concat.BatchConcat(
      distributions=[self.dist1, self.dist2, self.dist3],
      axis=1,
      validate_args=False)
  self.assertAllEqual(
      self.evaluate(concat_dist.batch_shape_tensor()), [2, 6, 4])

  seed = test_util.test_seed()
  samples = concat_dist.sample(seed=seed)
  self.assertAllEqual(self.evaluate(tf.shape(samples)), [2, 6, 4])
def test_legacy_dists(self):

  class StatefulNormal(tfd.Normal):

    def _sample_n(self, n, seed=None):
      return self.loc + self.scale * tf.random.normal(
          tf.concat([[n], self.batch_shape_tensor()], axis=0), seed=seed)

  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      loc  =StatefulNormal(loc=0, scale=2.),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      m    =tfd.Normal,
      x    =lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12)),
      validate_args=True)
  # pylint: enable=bad-whitespace

  # Set the filter inside `catch_warnings` so the change doesn't leak out of
  # the test.
  with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter('always')
    d.sample(seed=test_util.test_seed())
  self.assertRegex(
      str(w[0].message),
      r'Falling back to stateful sampling for.*of type.*StatefulNormal.*'
      r'component name "loc" and `dist.name` "Normal"',
      msg=w)
def test_namedtuple_sample_log_prob(self):
  Model = collections.namedtuple('Model', ['e', 'scale', 'loc', 'm', 'x'])  # pylint: disable=invalid-name
  # pylint: disable=bad-whitespace
  model = Model(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      loc  =tfd.Normal(loc=0, scale=2.),
      m    =tfd.Normal,
      x    =lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12))
  # pylint: enable=bad-whitespace
  d = tfd.JointDistributionNamed(model, validate_args=True)

  self.assertEqual((
      ('e', ()),
      ('scale', ('e',)),
      ('loc', ()),
      ('m', ('loc', 'scale')),
      ('x', ('m',)),
  ), d.resolve_graph())

  xs = d.sample(seed=test_util.test_seed())
  self.assertLen(xs, 5)
  # We'll verify the shapes work as intended when we plumb these back into
  # the respective log_probs.

  ds, _ = d.sample_distributions(value=xs, seed=test_util.test_seed())
  self.assertLen(ds, 5)
  self.assertIsInstance(ds.e, tfd.Independent)
  self.assertIsInstance(ds.scale, tfd.Gamma)
  self.assertIsInstance(ds.loc, tfd.Normal)
  self.assertIsInstance(ds.m, tfd.Normal)
  self.assertIsInstance(ds.x, tfd.Sample)

  # Static properties.
  self.assertAllEqual(
      Model(e=tf.float32, scale=tf.float32, loc=tf.float32,
            m=tf.float32, x=tf.int32),
      d.dtype)

  batch_shape_tensor_, event_shape_tensor_ = self.evaluate([
      d.batch_shape_tensor(), d.event_shape_tensor()])

  expected_batch_shape = Model(e=[], scale=[], loc=[], m=[], x=[])
  for (expected, actual_tensorshape, actual_shape_tensor_) in zip(
      expected_batch_shape, d.batch_shape, batch_shape_tensor_):
    self.assertAllEqual(expected, actual_tensorshape)
    self.assertAllEqual(expected, actual_shape_tensor_)

  expected_event_shape = Model(e=[2], scale=[], loc=[], m=[], x=[12])
  for (expected, actual_tensorshape, actual_shape_tensor_) in zip(
      expected_event_shape, d.event_shape, event_shape_tensor_):
    self.assertAllEqual(expected, actual_tensorshape)
    self.assertAllEqual(expected, actual_shape_tensor_)

  expected_jlp = sum(d_.log_prob(x) for d_, x in zip(ds, xs))
  actual_jlp = d.log_prob(xs)
  self.assertAllClose(*self.evaluate([expected_jlp, actual_jlp]),
                      atol=0., rtol=1e-4)
def basic_ordered_model_fn():
  return collections.OrderedDict((
      ('a', tfd.Normal(0., 1.)),
      ('e', tfd.Independent(tfd.Exponential(rate=[100, 120]),
                            reinterpreted_batch_ndims=1)),
      ('x', lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1])),
  ))
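# Because `basic_ordered_model_fn` returns an `OrderedDict`, the resulting
# joint distribution has a well-defined variable order, so (unlike the
# plain-`dict` case above, which raises on positional args) `log_prob` should
# also accept positional values. A minimal sketch of that usage, not one of
# the original tests:
def _ordered_model_positional_log_prob_sketch():
  d = tfd.JointDistributionNamed(basic_ordered_model_fn(), validate_args=True)
  s = d.sample(seed=test_util.test_seed())
  # Positional order follows the OrderedDict keys: ('a', 'e', 'x').
  return d.log_prob(s['a'], s['e'], s['x'])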
def testScalarBatchScalarEventIdentityScale(self):
  exp2 = self._cls()(
      tfd.Exponential(rate=0.25),
      bijector=tfb.AffineScalar(scale=2.))
  log_prob = exp2.log_prob(1.)
  log_prob_ = self.evaluate(log_prob)
  base_log_prob = -0.5 * 0.25 + np.log(0.25)
  ildj = np.log(2.)
  self.assertAllClose(base_log_prob - ildj, log_prob_,
                      rtol=1e-6, atol=0.)
def test_kl_divergence(self):
  d0 = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
           x=tfd.Normal(loc=0, scale=2.)),
      validate_args=True)
  d1 = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
           x=tfd.Normal(loc=1, scale=1.)),
      validate_args=True)
  self.assertEqual(d0.model.keys(), d1.model.keys())
  expected_kl = sum(
      tfd.kl_divergence(d0.model[k], d1.model[k]) for k in d0.model.keys())
  actual_kl = tfd.kl_divergence(d0, d1)
  other_actual_kl = d0.kl_divergence(d1)
  expected_kl_, actual_kl_, other_actual_kl_ = self.evaluate([
      expected_kl, actual_kl, other_actual_kl])
  self.assertNear(expected_kl_, actual_kl_, err=1e-5)
  self.assertNear(expected_kl_, other_actual_kl_, err=1e-5)
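# Note: the cross-entropy and KL checks above rely on the fact that for joint
# distributions whose components are independent (no cross-component
# dependencies, as in `d0` and `d1`), both quantities factor into sums of the
# corresponding component-wise quantities.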
def test_nested_partial_value(self, sample_fn):
  innermost = tfd.JointDistributionNamed({
      'a': tfd.Exponential(1.),
      'b': lambda a: tfd.Sample(tfd.LogNormal(a, a), [5]),
  })

  inner = tfd.JointDistributionNamed({
      'c': tfd.Exponential(1.),
      'd': innermost,
  })

  outer = tfd.JointDistributionNamed({
      'e': tfd.Exponential(1.),
      'f': inner,
  })

  seed = test_util.test_seed(sampler_type='stateless')
  true_xs = outer.sample(seed=seed)

  def _update(dict_, **kwargs):
    dict_ = dict_.copy()  # Don't mutate the caller's dict.
    dict_.update(**kwargs)
    return dict_

  # These asserts work because we advance the stateless seed inside the model
  # whether or not a sample is actually generated.
  partial_xs = _update(true_xs, f=None)
  xs = sample_fn(outer, value=partial_xs, seed=seed)
  self.assertAllCloseNested(true_xs, xs)

  partial_xs = _update(true_xs, e=None)
  xs = sample_fn(outer, value=partial_xs, seed=seed)
  self.assertAllCloseNested(true_xs, xs)

  partial_xs = _update(true_xs, f=_update(true_xs['f'], d=None))
  xs = sample_fn(outer, value=partial_xs, seed=seed)
  self.assertAllCloseNested(true_xs, xs)

  partial_xs = _update(
      true_xs, f=_update(true_xs['f'], d=_update(true_xs['f']['d'], a=None)))
  xs = sample_fn(outer, value=partial_xs, seed=seed)
  self.assertAllCloseNested(true_xs, xs)
def testScalarBatchScalarEventIdentityScale(self):
  exp2 = tfd.TransformedDistribution(
      tfd.Exponential(rate=0.25),
      bijector=tfb.Scale(scale=2.),
      validate_args=True)
  log_prob = exp2.log_prob(1.)
  log_prob_ = self.evaluate(log_prob)
  base_log_prob = -0.5 * 0.25 + np.log(0.25)
  ildj = np.log(2.)
  self.assertAllClose(base_log_prob - ildj, log_prob_,
                      rtol=1e-6, atol=0.)
def nested_lists_model_fn():
  return collections.OrderedDict((
      ('abc', tfd.JointDistributionSequential([
          tfd.MultivariateNormalDiag([0., 0.], [1., 1.]),
          tfd.JointDistributionSequential(
              [tfd.StudentT(3., -2., 5.),
               tfd.Exponential(4.)])])),
      ('de', lambda abc: tfd.JointDistributionSequential([  # pylint: disable=g-long-lambda
          tfd.Normal(abc[0] * abc[1][0], abc[1][1]),
          tfd.Normal(abc[0] + abc[1][0], abc[1][1])]))))
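# A minimal sketch (not one of the original tests) of consuming
# `nested_lists_model_fn`: the nested `JointDistributionSequential`
# components sample as nested lists inside the outer dict.
def _nested_lists_model_usage_sketch():
  d = tfd.JointDistributionNamed(nested_lists_model_fn(), validate_args=True)
  x = d.sample(seed=test_util.test_seed())
  # x['abc'] is [mvn_sample, [student_t_sample, exponential_sample]];
  # x['de'] is [normal_sample, normal_sample].
  return d.log_prob(x)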
def test_sample_shape_propagation_default_behavior(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      s    =tfd.HalfNormal(2.5),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      df   =tfd.Exponential(2),
      x    =tfd.StudentT),
      validate_args=False)
  # pylint: enable=bad-whitespace
  x = d.sample([2, 3], seed=test_util.test_seed())
  self.assertLen(x, 6)
  self.assertEqual((2, 3, 2), x['e'].shape)
  self.assertEqual((2, 3), x['scale'].shape)
  self.assertEqual((2, 3), x['s'].shape)
  self.assertEqual((2, 3), x['loc'].shape)
  self.assertEqual((2, 3), x['df'].shape)
  self.assertEqual((2, 3), x['x'].shape)
  lp = d.log_prob(x)
  self.assertEqual((2, 3), lp.shape)
def test_legacy_dists_stateless_seed_raises(self):

  class StatefulNormal(tfd.Normal):

    def _sample_n(self, n, seed=None):
      return self.loc + self.scale * tf.random.normal(
          tf.concat([[n], self.batch_shape_tensor()], axis=0), seed=seed)

  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      loc  =StatefulNormal(loc=0, scale=2.),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      m    =tfd.Normal,
      x    =lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12)),
      validate_args=True)
  # pylint: enable=bad-whitespace

  with self.assertRaisesRegex(TypeError, r'Expected int for argument'):
    d.sample(seed=samplers.zeros_seed())
def test_batch_slicing(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      s=tfd.Exponential(rate=[10, 12, 14]),
      n=lambda s: tfd.Normal(loc=0, scale=s),
      x=lambda: tfd.Beta(concentration0=[3, 2, 1], concentration1=1)),
      validate_args=True)
  # pylint: enable=bad-whitespace

  d0, d1 = d[:1], d[1:]
  x0 = d0.sample(seed=test_util.test_seed())
  x1 = d1.sample(seed=test_util.test_seed())

  self.assertLen(x0, 3)
  self.assertEqual([1], x0['s'].shape)
  self.assertEqual([1], x0['n'].shape)
  self.assertEqual([1], x0['x'].shape)

  self.assertLen(x1, 3)
  self.assertEqual([2], x1['s'].shape)
  self.assertEqual([2], x1['n'].shape)
  self.assertEqual([2], x1['x'].shape)
def test_custom_weights_prior(self):
  batch_shape = [4, 3]
  num_timesteps = 10
  num_features = 2
  design_matrix = self._build_placeholder(
      np.random.randn(*(batch_shape + [num_timesteps, num_features])))

  # Build a model with a batch of scalar Exponential(1.) weight priors.
  linear_regression = LinearRegression(
      design_matrix=design_matrix,
      weights_prior=tfd.Exponential(
          rate=self._build_placeholder(np.ones(batch_shape))))

  # Check that the prior is broadcast to match the shape of the weights.
  weights = linear_regression.parameters[0]
  self.assertAllEqual([num_features],
                      self.evaluate(weights.prior.event_shape_tensor()))
  self.assertAllEqual(batch_shape,
                      self.evaluate(weights.prior.batch_shape_tensor()))

  prior_sampled_weights = weights.prior.sample()
  ssm = linear_regression.make_state_space_model(
      num_timesteps=num_timesteps,
      param_vals={"weights": prior_sampled_weights})

  lp = ssm.log_prob(ssm.sample())
  self.assertAllEqual(batch_shape, self.evaluate(lp).shape)

  # Verify that the bijector enforces the prior constraint that
  # weights must be nonnegative.
  self.assertAllFinite(
      self.evaluate(
          weights.prior.log_prob(
              weights.bijector(
                  tf.random.normal(tf.shape(weights.prior.sample(64)),
                                   seed=test_util.test_seed(),
                                   dtype=self.dtype)))))
def test_dict_sample_log_prob(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      loc  =tfd.Normal(loc=0, scale=2.),
      m    =tfd.Normal,
      x    =lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12)),
      validate_args=True)
  # pylint: enable=bad-whitespace

  self.assertEqual((
      ('e', ()),
      ('scale', ('e',)),
      ('loc', ()),
      ('m', ('loc', 'scale')),
      ('x', ('m',)),
  ), d.resolve_graph())

  xs = d.sample(seed=test_util.test_seed())
  self.assertLen(xs, 5)
  # We'll verify the shapes work as intended when we plumb these back into
  # the respective log_probs.

  ds, _ = d.sample_distributions(value=xs)
  self.assertLen(ds, 5)
  self.assertIsInstance(ds['e'], tfd.Independent)
  self.assertIsInstance(ds['scale'], tfd.Gamma)
  self.assertIsInstance(ds['loc'], tfd.Normal)
  self.assertIsInstance(ds['m'], tfd.Normal)
  self.assertIsInstance(ds['x'], tfd.Sample)

  # Static properties.
  self.assertAllEqual(
      {'e': tf.float32, 'scale': tf.float32, 'loc': tf.float32,
       'm': tf.float32, 'x': tf.int32},
      d.dtype)

  batch_shape_tensor_, event_shape_tensor_ = self.evaluate(
      [d.batch_shape_tensor(), d.event_shape_tensor()])

  expected_batch_shape = {'e': [], 'scale': [], 'loc': [], 'm': [], 'x': []}
  batch_tensorshape = d.batch_shape
  for k in expected_batch_shape:
    self.assertAllEqual(expected_batch_shape[k], batch_tensorshape[k])
    self.assertAllEqual(expected_batch_shape[k], batch_shape_tensor_[k])

  expected_event_shape = {
      'e': [2], 'scale': [], 'loc': [], 'm': [], 'x': [12]}
  event_tensorshape = d.event_shape
  for k in expected_event_shape:
    self.assertAllEqual(expected_event_shape[k], event_tensorshape[k])
    self.assertAllEqual(expected_event_shape[k], event_shape_tensor_[k])

  expected_jlp = sum(ds[k].log_prob(xs[k]) for k in ds.keys())
  actual_jlp = d.log_prob(xs)
  self.assertAllClose(*self.evaluate([expected_jlp, actual_jlp]),
                      atol=0., rtol=1e-4)
        tfd.Normal(loc=-1., scale=0.1),
        tfd.Normal(loc=-0.5, scale=0.1),
        tfd.Normal(loc=0., scale=0.1),
        tfd.Normal(loc=0.5, scale=0.1),
        tfd.Normal(loc=1., scale=0.1)
    ])

AsymmetricDoubleClaw = tfd.Mixture(
    cat=tfd.Categorical(probs=[
        46. / 100., 46. / 100., 1. / 300., 1. / 300., 1. / 300.,
        7. / 300., 7. / 300., 7. / 300.
    ]),
    components=[
        tfd.Normal(loc=-1., scale=2. / 3.),
        tfd.Normal(loc=1., scale=2. / 3.),
        tfd.Normal(loc=-1. / 2., scale=0.01),
        tfd.Normal(loc=-1., scale=0.01),
        tfd.Normal(loc=-3. / 2., scale=0.01),
        tfd.Normal(loc=1. / 2., scale=0.07),
        tfd.Normal(loc=1., scale=0.07),
        tfd.Normal(loc=3. / 2., scale=0.07)
    ])

Mix3gauss1exp1uni = tfd.Mixture(
    cat=tfd.Categorical(probs=[0.1, 0.2, 0.1, 0.4, 0.2]),
    components=[
        tfd.Normal(loc=-1., scale=0.4),
        tfd.Normal(loc=+1., scale=0.5),
        tfd.Normal(loc=+1., scale=0.3),
        tfd.Exponential(rate=2),
        tfd.Uniform(low=-5, high=5)
    ])
def test_slicing_does_not_modify_the_sliced_distribution(self):
  dist = tfd.Exponential(tf.ones((5, 2, 3)))
  sliced = dist[:4, :, 2]
  self.assertAllEqual([2], sliced[-1].batch_shape_tensor())
  self.assertAllEqual([3], sliced[:-1, 1].batch_shape_tensor())
class MarkovChainBijectorTest(test_util.TestCase):

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      dict(testcase_name='deterministic_prior',
           prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
           transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.)),
      dict(testcase_name='deterministic_transition',
           prior_fn=lambda: tfd.Normal(loc=[-100., 0., 100.], scale=1.),
           transition_fn=lambda _, x: tfd.Deterministic(x)),
      dict(testcase_name='fully_deterministic',
           prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
           transition_fn=lambda _, x: tfd.Deterministic(x)),
      dict(testcase_name='mvn_diag',
           prior_fn=(
               lambda: tfd.MultivariateNormalDiag(loc=[[2.], [2.]],
                                                  scale_diag=[1.])),
           transition_fn=lambda _, x: tfd.VectorDeterministic(x)),
      dict(testcase_name='docstring_dirichlet',
           prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
               {'probs': tfd.Dirichlet([1., 1.])}),
           transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
               {'probs': tfd.MultivariateNormalDiag(loc=x['probs'],
                                                    scale_diag=[0.1, 0.1])},
               batch_ndims=ps.rank(x['probs']))),
      dict(testcase_name='uniform_step',
           prior_fn=lambda: tfd.Exponential(tf.ones([4, 1])),
           transition_fn=lambda _, x: tfd.Uniform(low=x, high=x + 1.)),
      dict(testcase_name='joint_distribution',
           prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
               batch_ndims=2,
               model={
                   'a': tfd.Gamma(tf.zeros([5]), 1.),
                   'b': lambda a: (
                       tfb.Reshape(event_shape_in=[4, 3],
                                   event_shape_out=[2, 3, 2])(
                           tfd.Independent(
                               tfd.Normal(
                                   loc=tf.zeros([5, 4, 3]),
                                   scale=a[..., tf.newaxis, tf.newaxis]),
                               reinterpreted_batch_ndims=2)))}),
           transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
               batch_ndims=ps.rank_from_shape(x['a'].shape),
               model={
                   'a': tfd.Normal(loc=x['a'], scale=1.),
                   'b': lambda a: tfd.Deterministic(
                       x['b'] + a[..., tf.newaxis, tf.newaxis, tf.newaxis])})),
      dict(testcase_name='nested_chain',
           prior_fn=lambda: tfd.MarkovChain(
               initial_state_prior=tfb.Split(2)(
                   tfd.MultivariateNormalDiag(0., [1., 2.])),
               transition_fn=lambda _, x: tfb.Split(2)(
                   tfd.MultivariateNormalDiag(x[0], [1., 2.])),
               num_steps=6),
           transition_fn=(
               lambda _, x: tfd.JointDistributionSequentialAutoBatched(
                   [tfd.MultivariateNormalDiag(x[0], [1.]),
                    tfd.MultivariateNormalDiag(x[1], [1.])],
                   batch_ndims=ps.rank(x[0])))))
  # pylint: enable=g-long-lambda
  def test_default_bijector(self, prior_fn, transition_fn):
    chain = tfd.MarkovChain(initial_state_prior=prior_fn(),
                            transition_fn=transition_fn,
                            num_steps=7)

    y = self.evaluate(chain.sample(seed=test_util.test_seed()))
    bijector = chain.experimental_default_event_space_bijector()

    self.assertAllEqual(chain.batch_shape_tensor(),
                        bijector.experimental_batch_shape_tensor())

    x = bijector.inverse(y)
    yy = bijector.forward(
        tf.nest.map_structure(tf.identity, x))  # Bypass bijector cache.
    self.assertAllCloseNested(y, yy)

    chain_event_ndims = tf.nest.map_structure(
        ps.rank_from_shape, chain.event_shape_tensor())
    self.assertAllEqualNested(bijector.inverse_min_event_ndims,
                              chain_event_ndims)

    ildj = bijector.inverse_log_det_jacobian(
        tf.nest.map_structure(tf.identity, y),  # Bypass bijector cache.
        event_ndims=chain_event_ndims)
    if not bijector.is_constant_jacobian:
      self.assertAllEqual(ildj.shape, chain.batch_shape)
    fldj = bijector.forward_log_det_jacobian(
        tf.nest.map_structure(tf.identity, x),  # Bypass bijector cache.
        event_ndims=bijector.inverse_event_ndims(chain_event_ndims))
    self.assertAllClose(ildj, -fldj)

    # Verify that event shapes are passed through and flattened/unflattened
    # correctly.
    inverse_event_shapes = bijector.inverse_event_shape(chain.event_shape)
    x_event_shapes = tf.nest.map_structure(
        lambda t, nd: t.shape[ps.rank(t) - nd:],
        x, bijector.forward_min_event_ndims)
    self.assertAllEqualNested(inverse_event_shapes, x_event_shapes)

    forward_event_shapes = bijector.forward_event_shape(inverse_event_shapes)
    self.assertAllEqualNested(forward_event_shapes, chain.event_shape)

    # Verify that the outputs of other methods have the correct structure.
    inverse_event_shape_tensors = bijector.inverse_event_shape_tensor(
        chain.event_shape_tensor())
    self.assertAllEqualNested(inverse_event_shape_tensors, x_event_shapes)

    forward_event_shape_tensors = bijector.forward_event_shape_tensor(
        inverse_event_shape_tensors)
    self.assertAllEqualNested(forward_event_shape_tensors,
                              chain.event_shape_tensor())