def test_default_event_space_bijector(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      s    =tfd.HalfNormal(2.5),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      df   =tfd.Exponential(2),
      x    =tfd.StudentT),
      validate_args=True)
  # pylint: enable=bad-whitespace
  # `x` is a distribution class rather than an instance, so the default
  # event space bijector is not defined for this model.
  with self.assertRaisesRegex(
      NotImplementedError,
      'all elements of `model` are `tfp.distribution`s'):
    d._experimental_default_event_space_bijector()

  d = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
           x=tfd.Normal(loc=1, scale=1.),
           s=tfd.HalfNormal(2.5)),
      validate_args=True)
  for b in d._model_flatten(d._experimental_default_event_space_bijector()):
    self.assertIsInstance(b, tfb.Bijector)
  self.assertSetEqual(
      set(d.model.keys()),
      set(d._experimental_default_event_space_bijector().keys()))

def test_cross_entropy(self):
  d0 = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
           x=tfd.Normal(loc=0, scale=2.)),
      validate_args=True)
  d1 = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
           x=tfd.Normal(loc=1, scale=1.)),
      validate_args=True)
  self.assertEqual(d0.model.keys(), d1.model.keys())
  expected_xent = sum(d0.model[k].cross_entropy(d1.model[k])
                      for k in d0.model.keys())
  actual_xent = d0.cross_entropy(d1)
  expected_xent_, actual_xent_ = self.evaluate([expected_xent, actual_xent])
  self.assertNear(actual_xent_, expected_xent_, err=1e-5)

def test_default_event_space_bijector(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      s    =tfd.HalfNormal(2.5),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      df   =tfd.Exponential(2),
      x    =tfd.StudentT),
      validate_args=True)
  # pylint: enable=bad-whitespace

  # The event space bijector is inherited from `JointDistributionSequential`
  # and is tested more thoroughly in the tests for that class.
  b = d.experimental_default_event_space_bijector()
  y = self.evaluate(d.sample(seed=test_util.test_seed()))
  y_ = self.evaluate(b.forward(b.inverse(y)))
  self.assertAllClose(y, y_)

  # Verify that event shapes are passed through and flattened/unflattened
  # correctly.
  forward_event_shapes = b.forward_event_shape(d.event_shape)
  inverse_event_shapes = b.inverse_event_shape(d.event_shape)
  self.assertEqual(forward_event_shapes, d.event_shape)
  self.assertEqual(inverse_event_shapes, d.event_shape)

  # Verify that the outputs of other methods have the correct dict structure.
  forward_event_shape_tensors = b.forward_event_shape_tensor(
      d.event_shape_tensor())
  inverse_event_shape_tensors = b.inverse_event_shape_tensor(
      d.event_shape_tensor())
  for item in [forward_event_shape_tensors, inverse_event_shape_tensors]:
    self.assertSetEqual(set(self.evaluate(item).keys()),
                        set(d.model.keys()))

def test_can_call_log_prob_with_kwargs(self):
  d = tfd.JointDistributionNamed({
      'e': tfd.Normal(0., 1.),
      'a': tfd.Independent(tfd.Exponential(rate=[100, 120]),
                           reinterpreted_batch_ndims=1),
      'x': lambda a: tfd.Gamma(concentration=a[..., 0], rate=a[..., 1])
  }, validate_args=True)

  sample = self.evaluate(d.sample([2, 3], seed=test_util.test_seed()))
  e, a, x = sample['e'], sample['a'], sample['x']

  lp_value_positional = self.evaluate(d.log_prob({'e': e, 'a': a, 'x': x}))
  lp_value_named = self.evaluate(d.log_prob(value={'e': e, 'a': a, 'x': x}))
  # Assert all close (rather than equal) because order is not defined for
  # dicts, and reordering the computation can give subtly different results.
  self.assertAllClose(lp_value_positional, lp_value_named)

  lp_kwargs = self.evaluate(d.log_prob(a=a, e=e, x=x))
  self.assertAllClose(lp_value_positional, lp_kwargs)

  with self.assertRaisesRegexp(ValueError,
                               'Joint distribution with unordered variables '
                               "can't take positional args"):
    lp_kwargs = d.log_prob(e, a, x)

def get_example_phylo_model(
    taxon_count: int,
    site_count: int,
    sampling_times: tf.Tensor,
    init_values: tp.Optional[tp.Dict[str, float]] = None,
    dtype: tf.DType = DEFAULT_FLOAT_DTYPE_TF,
    tree_name: str = DEFAULT_TREE_NAME,
    pattern_counts: tp.Optional[tf.Tensor] = None,
) -> tfd.JointDistribution:
    if init_values is None:
        init_values = {}
    constant = partial(tf.constant, dtype=dtype)
    rate = constant(init_values.get("rate", 1e-3))
    model_dict = dict(
        pop_size=tfd.Exponential(rate=constant(0.1), name="pop_size"),
        tree=lambda pop_size: ConstantCoalescent(
            taxon_count=taxon_count,
            pop_size=pop_size,
            sampling_times=sampling_times,
            tree_name=tree_name,
        ),
        alignment=strict_clock_fixed_alignment_func(
            rate, site_count=site_count, weights=pattern_counts
        ),
    )
    return tfd.JointDistributionNamed(model_dict)

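# A minimal usage sketch (not from the original source). It assumes this
# module's imports are in scope; the `sampling_times` tensor of
# contemporaneous taxa and the taxon/site counts are illustrative values.
def _example_phylo_model_usage() -> tf.Tensor:
    sampling_times = tf.zeros([4], dtype=DEFAULT_FLOAT_DTYPE_TF)
    model = get_example_phylo_model(
        taxon_count=4, site_count=20, sampling_times=sampling_times
    )
    # `sample` is a dict with keys `pop_size`, `tree`, and `alignment`;
    # scoring it under the joint density returns a scalar log-probability.
    sample = model.sample()
    return model.log_prob(sample)
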
def test_legacy_dists(self):

  class StatefulNormal(tfd.Normal):

    def _sample_n(self, n, seed=None):
      return self.loc + self.scale * tf.random.normal(
          tf.concat([[n], self.batch_shape_tensor()], axis=0),
          seed=seed)

  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      loc  =StatefulNormal(loc=0, scale=2.),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      m    =tfd.Normal,
      x    =lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12)),
      validate_args=True)
  # pylint: enable=bad-whitespace

  warnings.simplefilter('always')
  with warnings.catch_warnings(record=True) as w:
    d.sample(seed=test_util.test_seed())
  self.assertRegexpMatches(
      str(w[0].message),
      r'Falling back to stateful sampling for.*of type.*StatefulNormal.*'
      r'component name "loc" and `dist.name` "Normal"',
      msg=w)

def test_sample_kwargs(self):
  joint = tfd.JointDistributionNamed(
      dict(
          a=tfd.Normal(0., 1.),
          b=lambda a: tfd.Normal(a, 1.),
          c=lambda a, b: tfd.Normal(a + b, 1.)))

  seed = test_util.test_seed()
  tf.random.set_seed(seed)
  samples = joint.sample(seed=seed, a=1.)
  # Check the first value is actually 1.
  self.assertEqual(1., self.evaluate(samples['a']))
  # Check the sample is reproducible using the `value` argument.
  tf.random.set_seed(seed)
  samples_named = joint.sample(seed=seed, value={'a': 1.})
  self.assertAllEqual(self.evaluate(samples), self.evaluate(samples_named))

  # Make sure to throw an exception if strange keywords are passed.
  expected_error = (
      'Found unexpected keyword arguments. Distribution names are\n'
      'a, b, c\n'
      'but received\n'
      'z\n'
      'These names were invalid:\n'
      'z')
  with self.assertRaisesRegex(ValueError, expected_error):
    joint.sample(seed=seed, z=2.)

  # Raise if value and keywords are passed.
  with self.assertRaisesRegex(
      ValueError, r'Supplied both `value` and keyword arguments .*'):
    joint.sample(seed=seed, a=1., value={'a': 1})

def test_can_call_ordereddict_log_prob_with_args_and_kwargs(self, model_fn):
  # With an OrderedDict, we can pass keyword and/or positional args.
  d = tfd.JointDistributionNamed(model_fn(), validate_args=True)

  # Destructure vector-valued Tensors into Python lists, to mimic the values
  # a user might type.
  sample = tf.nest.map_structure(
      lambda x: list(x) if isinstance(x, np.ndarray) else x,
      self.evaluate(d.sample(seed=test_util.test_seed())))
  sample_dict = dict(sample)

  lp_value_positional = self.evaluate(d.log_prob(sample_dict))
  lp_value_named = self.evaluate(d.log_prob(value=sample_dict))
  self.assertAllClose(lp_value_positional, lp_value_named)

  lp_args = self.evaluate(d.log_prob(*sample.values()))
  self.assertAllClose(lp_value_positional, lp_args)

  lp_kwargs = self.evaluate(d.log_prob(**sample_dict))
  self.assertAllClose(lp_value_positional, lp_kwargs)

  lp_args_then_kwargs = self.evaluate(
      d.log_prob(*list(sample.values())[:1],
                 **dict(list(sample.items())[1:])))
  self.assertAllClose(lp_value_positional, lp_args_then_kwargs)

def test_works_with_structured_samples(self):
  # Check that we don't accidentally destroy the structure of `samples` when
  # it's a dict or other non-Tensor object from a joint distribution.
  p = tfd.JointDistributionNamed({
      'x': tfd.Normal(0., 1.),
      'y': tfd.Normal(0., 1.)})

  total_variance_with_reparam = tfp.monte_carlo.expectation(
      f=lambda d: d['x']**2 + d['y']**2,
      samples=p.sample(1000, seed=42),
      log_prob=p.log_prob,
      use_reparametrization=True)
  total_variance_without_reparam = tfp.monte_carlo.expectation(
      f=lambda d: d['x']**2 + d['y']**2,
      samples=p.sample(1000, seed=42),
      log_prob=p.log_prob,
      use_reparametrization=False)
  [
      total_variance_with_reparam_,
      total_variance_without_reparam_
  ] = self.evaluate([
      total_variance_with_reparam,
      total_variance_without_reparam])

  self.assertAllClose(total_variance_with_reparam_, 2., atol=0.2)
  self.assertAllClose(total_variance_without_reparam_, 2., atol=0.2)

def test_can_call_namedtuple_log_prob_with_args_and_kwargs(self):
  # With a namedtuple, we can pass keyword and/or positional args.
  Model = collections.namedtuple('Model', ['e', 'a', 'x'])  # pylint: disable=invalid-name
  d = tfd.JointDistributionNamed(
      Model(e=tfd.Normal(0., 1.),
            a=tfd.Independent(tfd.Exponential(rate=[100, 120]),
                              reinterpreted_batch_ndims=1),
            x=lambda a: tfd.Gamma(concentration=a[..., 0], rate=a[..., 1])),
      validate_args=True)

  sample = self.evaluate(d.sample([2, 3], seed=test_util.test_seed()))
  e, a, x = sample.e, sample.a, sample.x

  lp_value_positional = self.evaluate(d.log_prob(Model(e=e, a=a, x=x)))
  lp_value_named = self.evaluate(d.log_prob(value=Model(e=e, a=a, x=x)))
  self.assertAllClose(lp_value_positional, lp_value_named)

  lp_kwargs = self.evaluate(d.log_prob(e=e, a=a, x=x))
  self.assertAllClose(lp_value_positional, lp_kwargs)

  lp_args = self.evaluate(d.log_prob(e, a, x))
  self.assertAllClose(lp_value_positional, lp_args)

  lp_args_then_kwargs = self.evaluate(d.log_prob(e, a=a, x=x))
  self.assertAllClose(lp_value_positional, lp_args_then_kwargs)

def test_summary_statistic(self, attr):
  d = tfd.JointDistributionNamed(
      dict(logits=tfd.Normal(0., 1.), x=tfd.Bernoulli(logits=0.)),
      validate_args=True)
  expected = {k: getattr(d.model[k], attr)() for k in d.model.keys()}
  actual = getattr(d, attr)()
  self.assertAllEqual(*self.evaluate([expected, actual]))

def test_sample_shape_propagation_nondefault_behavior(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      s    =tfd.HalfNormal(2.5),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      df   =tfd.Exponential(2),
      x    =tfd.StudentT),
      validate_args=False)
  # pylint: enable=bad-whitespace
  # The following enables the nondefault sample shape behavior.
  d._always_use_specified_sample_shape = True
  sample_shape = (2, 3)
  x = d.sample(sample_shape, seed=test_util.test_seed())
  self.assertLen(x, 6)
  self.assertEqual(sample_shape + (2,), x['e'].shape)
  self.assertEqual(sample_shape * 2, x['scale'].shape)  # Has 1 arg.
  self.assertEqual(sample_shape * 1, x['s'].shape)      # Has 0 args.
  self.assertEqual(sample_shape * 2, x['loc'].shape)    # Has 1 arg.
  self.assertEqual(sample_shape * 1, x['df'].shape)     # Has 0 args.
  # Has 3 args, one being scalar.
  self.assertEqual(sample_shape * 3, x['x'].shape)
  lp = d.log_prob(x)
  self.assertEqual(sample_shape * 3, lp.shape)

def test_copy(self):
  pgm = dict(logits=tfd.Normal(0., 1.), probs=tfd.Bernoulli(logits=0.5))
  d = tfd.JointDistributionNamed(pgm, validate_args=True)
  d_copy = d.copy()
  self.assertEqual(d_copy.parameters['model'], pgm)
  self.assertEqual(d_copy.parameters['validate_args'], True)
  self.assertEqual(d_copy.parameters['name'], 'JointDistributionNamed')

def test_sample_complex_dependency(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      y    =tfd.StudentT,
      x    =tfd.StudentT,
      df   =tfd.Exponential(2),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      s    =tfd.HalfNormal(2.5),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1)),
      validate_args=False)
  # pylint: enable=bad-whitespace

  self.assertEqual(
      (
          ('e', ()),
          ('scale', ('e',)),
          ('s', ()),
          ('loc', ('s',)),
          ('df', ()),
          ('y', ('df', 'loc', 'scale')),
          ('x', ('df', 'loc', 'scale')),
      ),
      d.resolve_graph())

  x = d.sample()
  self.assertLen(x, 7)

  ds, s = d.sample_distributions()
  self.assertEqual(ds['x'].parameters['df'], s['df'])
  self.assertEqual(ds['x'].parameters['loc'], s['loc'])
  self.assertEqual(ds['x'].parameters['scale'], s['scale'])
  self.assertEqual(ds['y'].parameters['df'], s['df'])
  self.assertEqual(ds['y'].parameters['loc'], s['loc'])
  self.assertEqual(ds['y'].parameters['scale'], s['scale'])

def test_namedtuple_sample_log_prob(self):
  Model = collections.namedtuple('Model', ['e', 'scale', 'loc', 'm', 'x'])  # pylint: disable=invalid-name
  # pylint: disable=bad-whitespace
  model = Model(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      loc  =tfd.Normal(loc=0, scale=2.),
      m    =tfd.Normal,
      x    =lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12))
  # pylint: enable=bad-whitespace
  d = tfd.JointDistributionNamed(model, validate_args=True)

  self.assertEqual(
      (
          ('e', ()),
          ('scale', ('e',)),
          ('loc', ()),
          ('m', ('loc', 'scale')),
          ('x', ('m',)),
      ),
      d.resolve_graph())

  xs = d.sample(seed=test_util.test_seed())
  self.assertLen(xs, 5)
  # We'll verify the shapes work as intended when we plumb these back into
  # the respective log_probs.

  ds, _ = d.sample_distributions(value=xs, seed=test_util.test_seed())
  self.assertLen(ds, 5)
  self.assertIsInstance(ds.e, tfd.Independent)
  self.assertIsInstance(ds.scale, tfd.Gamma)
  self.assertIsInstance(ds.loc, tfd.Normal)
  self.assertIsInstance(ds.m, tfd.Normal)
  self.assertIsInstance(ds.x, tfd.Sample)

  # Static properties.
  self.assertAllEqual(
      Model(e=tf.float32, scale=tf.float32, loc=tf.float32,
            m=tf.float32, x=tf.int32),
      d.dtype)

  batch_shape_tensor_, event_shape_tensor_ = self.evaluate([
      d.batch_shape_tensor(), d.event_shape_tensor()])

  expected_batch_shape = Model(e=[], scale=[], loc=[], m=[], x=[])
  for (expected, actual_tensorshape, actual_shape_tensor_) in zip(
      expected_batch_shape, d.batch_shape, batch_shape_tensor_):
    self.assertAllEqual(expected, actual_tensorshape)
    self.assertAllEqual(expected, actual_shape_tensor_)

  expected_event_shape = Model(e=[2], scale=[], loc=[], m=[], x=[12])
  for (expected, actual_tensorshape, actual_shape_tensor_) in zip(
      expected_event_shape, d.event_shape, event_shape_tensor_):
    self.assertAllEqual(expected, actual_tensorshape)
    self.assertAllEqual(expected, actual_shape_tensor_)

  expected_jlp = sum(d.log_prob(x) for d, x in zip(ds, xs))
  actual_jlp = d.log_prob(xs)
  self.assertAllClose(*self.evaluate([expected_jlp, actual_jlp]),
                      atol=0., rtol=1e-4)

def test_notimplemented_quantile(self):
  d = tfd.JointDistributionNamed(
      dict(logits=tfd.Normal(0., 1.), x=tfd.Bernoulli(probs=0.5)),
      validate_args=True)
  with self.assertRaisesWithPredicateMatch(
      NotImplementedError,
      'quantile is not implemented: JointDistributionNamed'):
    d.quantile(0.5)

def test_notimplemented_evaluative_statistic(self, attr):
  d = tfd.JointDistributionNamed(
      dict(logits=tfd.Normal(0., 1.), x=tfd.Bernoulli(probs=0.5)),
      validate_args=True)
  with self.assertRaisesWithPredicateMatch(
      NotImplementedError,
      attr + ' is not implemented: JointDistributionNamed'):
    getattr(d, attr)(dict(logits=0., x=0.5))

def test_kl_divergence(self):
  d0 = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
           x=tfd.Normal(loc=0, scale=2.)),
      validate_args=True)
  d1 = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
           x=tfd.Normal(loc=1, scale=1.)),
      validate_args=True)
  self.assertEqual(d0.model.keys(), d1.model.keys())
  expected_kl = sum(tfd.kl_divergence(d0.model[k], d1.model[k])
                    for k in d0.model.keys())
  actual_kl = tfd.kl_divergence(d0, d1)
  other_actual_kl = d0.kl_divergence(d1)
  expected_kl_, actual_kl_, other_actual_kl_ = self.evaluate([
      expected_kl, actual_kl, other_actual_kl])
  self.assertNear(expected_kl_, actual_kl_, err=1e-5)
  self.assertNear(expected_kl_, other_actual_kl_, err=1e-5)

def test_copy(self):
  pgm = dict(logits=tfd.Normal(0., 1.), probs=tfd.Bernoulli(logits=0.5))
  d = tfd.JointDistributionNamed(pgm, validate_args=True)
  d_copy = d.copy()
  self.assertAllEqual(
      {'model': pgm,
       'validate_args': True,
       'name': 'JointDistributionNamed'},
      d_copy.parameters)

def test_nested_partial_value(self, sample_fn):
  innermost = tfd.JointDistributionNamed({
      'a': tfd.Exponential(1.),
      'b': lambda a: tfd.Sample(tfd.LogNormal(a, a), [5]),
  })

  inner = tfd.JointDistributionNamed({
      'c': tfd.Exponential(1.),
      'd': innermost,
  })

  outer = tfd.JointDistributionNamed({
      'e': tfd.Exponential(1.),
      'f': inner,
  })

  seed = test_util.test_seed(sampler_type='stateless')
  true_xs = outer.sample(seed=seed)

  def _update(dict_, **kwargs):
    # Update a copy so the caller's dict is left intact. (`dict.update`
    # mutates in place and returns None, so the copy must be bound to a name
    # before updating.)
    dict_ = dict_.copy()
    dict_.update(**kwargs)
    return dict_

  # These asserts work because we advance the stateless seed inside the model
  # whether or not a sample is actually generated.
  partial_xs = _update(true_xs, f=None)
  xs = sample_fn(outer, value=partial_xs, seed=seed)
  self.assertAllCloseNested(true_xs, xs)

  partial_xs = _update(true_xs, e=None)
  xs = sample_fn(outer, value=partial_xs, seed=seed)
  self.assertAllCloseNested(true_xs, xs)

  partial_xs = _update(true_xs, f=_update(true_xs['f'], d=None))
  xs = sample_fn(outer, value=partial_xs, seed=seed)
  self.assertAllCloseNested(true_xs, xs)

  partial_xs = _update(true_xs, f=_update(true_xs['f'],
                                          d=_update(true_xs['f']['d'],
                                                    a=None)))
  xs = sample_fn(outer, value=partial_xs, seed=seed)
  self.assertAllCloseNested(true_xs, xs)

def testPartsWithUnusedInternalStructure(self):
  dist = tfd.JointDistributionSequential([
      tfd.JointDistributionNamed({'a': tfd.Normal(0., 1.)}),
      tfd.JointDistributionNamed({'b': tfd.Normal(1000., 1.)}),
  ])
  x = dist.sample(  # Shape: [{'a': []}, {'b': []}]
      seed=test_util.test_seed(sampler_type='stateless'))

  # Test that we can swap the outer list entries, even though they contain
  # internal structure (i.e., are themselves dicts).
  swap_elements = tfb.Restructure(input_structure=[1, 0],
                                  output_structure=[0, 1])
  self.assertAllEqualNested(swap_elements(x),
                            [x[1], x[0]],
                            check_types=True)

  swapped_dist = swap_elements(dist)
  self.assertAllEqualNested(swapped_dist.event_shape,
                            [dist.event_shape[1], dist.event_shape[0]],
                            check_types=True)
  self.assertEqual(swapped_dist.dtype, [dist.dtype[1], dist.dtype[0]])

def test_transform_joint_to_joint(self, split_sizes):
  dist_batch_shape = tf.nest.pack_sequence_as(
      split_sizes,
      [tensorshape_util.constant_value_as_shape(s)
       for s in [[2, 3], [2, 1], [1, 3]]])
  bijector_batch_shape = [1, 3]

  # Build a joint distribution with parts of the specified sizes.
  seed = test_util.test_seed_stream()
  component_dists = tf.nest.map_structure(
      lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
          loc=tf.random.normal(batch_shape + [size], seed=seed()),
          scale_diag=tf.random.uniform(
              minval=1., maxval=2., shape=batch_shape + [size], seed=seed())),
      split_sizes, dist_batch_shape)
  if isinstance(split_sizes, dict):
    base_dist = tfd.JointDistributionNamed(component_dists)
  else:
    base_dist = tfd.JointDistributionSequential(component_dists)

  # Transform the distribution by applying a separate bijector to each part.
  bijectors = [tfb.Exp(),
               tfb.Scale(
                   tf.random.uniform(
                       minval=1., maxval=2.,
                       shape=bijector_batch_shape, seed=seed())),
               tfb.Reshape([2, 1])]
  bijector = tfb.JointMap(tf.nest.pack_sequence_as(split_sizes, bijectors),
                          validate_args=True)

  # Transform a joint distribution that has different batch shape components.
  transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

  self.assertRegex(
      str(transformed_dist),
      '{}.*batch_shape.*event_shape.*dtype'.format(transformed_dist.name))

  self.assertAllEqualNested(
      transformed_dist.event_shape,
      bijector.forward_event_shape(base_dist.event_shape))
  self.assertAllEqualNested(*self.evaluate((
      transformed_dist.event_shape_tensor(),
      bijector.forward_event_shape_tensor(base_dist.event_shape_tensor()))))

  # Test that the batch shape components of the input are the same as those
  # of the output.
  self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
  self.assertAllEqualNested(
      self.evaluate(transformed_dist.batch_shape_tensor()),
      dist_batch_shape)
  self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)

def test_graph_resolution(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      s    =tfd.HalfNormal(2.5),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      df   =tfd.Exponential(2),
      x    =tfd.StudentT),
      validate_args=True)
  # pylint: enable=bad-whitespace
  self.assertEqual(
      (('e', ()),
       ('scale', ('e',)),
       ('s', ()),
       ('loc', ('s',)),
       ('df', ()),
       ('x', ('df', 'loc', 'scale'))),
      d.resolve_graph())

def test_dist_fn_takes_kwargs(self):
  dist = tfd.JointDistributionNamed(
      {'positive': tfd.Exponential(rate=1.),
       'negative': tfb.Scale(-1.)(tfd.Exponential(rate=1.)),
       'b': lambda **kwargs: tfd.Normal(loc=kwargs['negative'],  # pylint: disable=g-long-lambda
                                        scale=kwargs['positive'],
                                        validate_args=True),
       'a': lambda **kwargs: tfb.Scale(kwargs['b'])(  # pylint: disable=g-long-lambda
           tfd.Gamma(concentration=-kwargs['negative'],
                     rate=kwargs['positive'],
                     validate_args=True))
      }, validate_args=True)
  lp = dist.log_prob(dist.sample(5, seed=test_util.test_seed()))
  self.assertAllEqual(lp.shape, [5])

def test_get_mean_field_approximation_tree(
    flat_tree_test_data: TreeTestData, with_init_loc: bool
):
    test_tree = data_to_tensor_tree(flat_tree_test_data)
    taxon_count = test_tree.taxon_count
    tree_name = "tree_dist_name"

    init_loc: tp.Optional[tp.Dict[str, object]]
    if with_init_loc:
        init_loc = dict(tree=test_tree)
    else:
        init_loc = None

    model = tfd.JointDistributionNamed(
        dict(
            pop_size=tfd.LogNormal(_constant(0.0), _constant(1.0)),
            tree=lambda pop_size: ConstantCoalescent(
                taxon_count, pop_size, test_tree.sampling_times, tree_name=tree_name
            ),
            obs=lambda tree: tfd.Normal(
                _constant(0.0), tf.reduce_sum(tree.branch_lengths)
            ),
        )
    )
    obs = _constant([10.0])
    pinned = model.experimental_pin(obs=obs)
    approximation = get_fixed_topology_mean_field_approximation(
        pinned,
        dtype=DEFAULT_FLOAT_DTYPE_TF,
        topology_pins={tree_name: test_tree.topology},
        init_loc=init_loc,
    )

    sample = approximation.sample()
    assert (
        tf.reduce_all(
            sample["tree"].topology.parent_indices
            == test_tree.topology.parent_indices
        )
        .numpy()
        .item()
    )
    assert_allclose(
        sample["tree"].sampling_times.numpy(), test_tree.sampling_times.numpy()
    )

    model_log_prob = pinned.unnormalized_log_prob(sample)
    approx_log_prob = approximation.log_prob(sample)
    assert np.isfinite(model_log_prob.numpy())
    assert np.isfinite(approx_log_prob.numpy())

def transition_fn(_, previous_state):
  return tfd.JointDistributionNamed(
      # The autoregressive coefficients and the `log_scale` each follow
      # an independent slow-moving random walk.
      {'coefs': tfd.Independent(
          tfd.Normal(loc=previous_state['coefs'], scale=0.01),
          reinterpreted_batch_ndims=1),
       'log_scale': tfd.Normal(loc=previous_state['log_scale'],
                               scale=0.01),
       # The level is a linear combination of the previous *two* levels,
       # with additional noise of scale `exp(log_scale)`.
       'level': lambda coefs, log_scale: tfd.Normal(  # pylint: disable=g-long-lambda
           loc=(coefs[..., 0] * previous_state['level'] +
                coefs[..., 1] * previous_state['previous_level']),
           scale=tf.exp(log_scale)),
       # Store the previous level to access at the next step.
       'previous_level': tfd.Deterministic(previous_state['level'])})

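# A minimal sketch of how a transition function with this signature is
# typically consumed (an assumption; the consumer is not shown in the
# original snippet): `tfd.MarkovChain` calls `transition_fn(step, state)` to
# build each next-state distribution. The initial-state prior here is
# illustrative.
def _example_markov_chain_usage():
  initial_state_prior = tfd.JointDistributionNamed({
      'coefs': tfd.Independent(tfd.Normal(loc=tf.zeros([2]), scale=0.1),
                               reinterpreted_batch_ndims=1),
      'log_scale': tfd.Normal(loc=-1., scale=0.1),
      'level': tfd.Normal(loc=0., scale=1.),
      'previous_level': tfd.Deterministic(0.)})
  chain = tfd.MarkovChain(initial_state_prior=initial_state_prior,
                          transition_fn=transition_fn,
                          num_steps=50)
  # Each entry of the sampled dict carries a leading `[50]` time axis.
  return chain.sample(seed=test_util.test_seed())
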
def test_legacy_dists_stateless_seed_raises(self):

  class StatefulNormal(tfd.Normal):

    def _sample_n(self, n, seed=None):
      return self.loc + self.scale * tf.random.normal(
          tf.concat([[n], self.batch_shape_tensor()], axis=0),
          seed=seed)

  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      loc  =StatefulNormal(loc=0, scale=2.),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      m    =tfd.Normal,
      x    =lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12)),
      validate_args=True)
  # pylint: enable=bad-whitespace

  with self.assertRaisesRegexp(TypeError, r'Expected int for argument'):
    d.sample(seed=samplers.zeros_seed())

def test_get_mean_field_approximation():
    sample_size = 3
    model = tfd.JointDistributionNamed(
        dict(
            a=tfd.Normal(_constant(0.0), _constant(1.0)),
            b=lambda a: tfd.Sample(tfd.LogNormal(a, _constant(1.0)), sample_size),
            obs=lambda b: tfd.Independent(
                tfd.Normal(b, _constant(1.0)), reinterpreted_batch_ndims=1
            ),
        )
    )
    obs = _constant([-1.1, 2.1, 0.1])
    pinned = model.experimental_pin(obs=obs)
    approximation = get_mean_field_approximation(
        pinned, init_loc=dict(a=_constant(0.1)), dtype=DEFAULT_FLOAT_DTYPE_TF
    )
    sample = approximation.sample()
    model_log_prob = pinned.unnormalized_log_prob(sample)
    approx_log_prob = approximation.log_prob(sample)
    assert np.isfinite(model_log_prob.numpy())
    assert np.isfinite(approx_log_prob.numpy())

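# A follow-on sketch, not part of the original tests. It assumes the returned
# approximation is a trainable surrogate posterior compatible with
# `tfp.vi.fit_surrogate_posterior`, and that `tfp` is imported in this module.
# Fitting minimizes the negative ELBO of the pinned model under the
# mean-field surrogate.
def _fit_mean_field_approximation(pinned, approximation, num_steps=100):
    losses = tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=pinned.unnormalized_log_prob,
        surrogate_posterior=approximation,
        optimizer=tf.optimizers.Adam(learning_rate=0.01),
        num_steps=num_steps,
    )
    # Per-step negative ELBO estimates, shape [num_steps].
    return losses
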
def test_sample_shape_propagation_default_behavior(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      e    =tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
      scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
      s    =tfd.HalfNormal(2.5),
      loc  =lambda s: tfd.Normal(loc=0, scale=s),
      df   =tfd.Exponential(2),
      x    =tfd.StudentT),
      validate_args=False)
  # pylint: enable=bad-whitespace
  x = d.sample([2, 3], seed=test_util.test_seed())
  self.assertLen(x, 6)
  self.assertEqual((2, 3, 2), x['e'].shape)
  self.assertEqual((2, 3), x['scale'].shape)
  self.assertEqual((2, 3), x['s'].shape)
  self.assertEqual((2, 3), x['loc'].shape)
  self.assertEqual((2, 3), x['df'].shape)
  self.assertEqual((2, 3), x['x'].shape)
  lp = d.log_prob(x)
  self.assertEqual((2, 3), lp.shape)

def test_batch_slicing(self):
  # pylint: disable=bad-whitespace
  d = tfd.JointDistributionNamed(dict(
      s=tfd.Exponential(rate=[10, 12, 14]),
      n=lambda s: tfd.Normal(loc=0, scale=s),
      x=lambda: tfd.Beta(concentration0=[3, 2, 1], concentration1=1)),
      validate_args=True)
  # pylint: enable=bad-whitespace

  d0, d1 = d[:1], d[1:]
  x0 = d0.sample(seed=test_util.test_seed())
  x1 = d1.sample(seed=test_util.test_seed())

  self.assertLen(x0, 3)
  self.assertEqual([1], x0['s'].shape)
  self.assertEqual([1], x0['n'].shape)
  self.assertEqual([1], x0['x'].shape)

  self.assertLen(x1, 3)
  self.assertEqual([2], x1['s'].shape)
  self.assertEqual([2], x1['n'].shape)
  self.assertEqual([2], x1['x'].shape)