Пример #1
0
  def test_default_event_space_bijector(self):
    """Exercises `_experimental_default_event_space_bijector` on named joints.

    A model that contains callables (lambdas or distribution classes) is
    expected to raise `NotImplementedError`; a model made up solely of
    distribution instances is expected to yield `tfb.Bijector`s keyed like
    the model.
    """
    # Model mixing concrete distributions with distribution-making callables.
    mixed_model = {
        'e': tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
        'scale': lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
        's': tfd.HalfNormal(2.5),
        'loc': lambda s: tfd.Normal(loc=0, scale=s),
        'df': tfd.Exponential(2),
        'x': tfd.StudentT,
    }
    joint = tfd.JointDistributionNamed(mixed_model, validate_args=True)
    with self.assertRaisesRegex(
        NotImplementedError, 'all elements of `model` are `tfp.distribution`s'):
      joint._experimental_default_event_space_bijector()

    # Model consisting only of distribution instances.
    concrete_model = {
        'e': tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
        'x': tfd.Normal(loc=1, scale=1.),
        's': tfd.HalfNormal(2.5),
    }
    joint = tfd.JointDistributionNamed(concrete_model, validate_args=True)
    flat_bijectors = joint._model_flatten(
        joint._experimental_default_event_space_bijector())
    for bijector in flat_bijectors:
      self.assertIsInstance(bijector, tfb.Bijector)

    # The bijector structure must carry exactly the model's keys.
    model_keys = set(joint.model.keys())
    bijector_keys = set(
        joint._experimental_default_event_space_bijector().keys())
    self.assertSetEqual(model_keys, bijector_keys)
Пример #2
0
 def test_cross_entropy(self):
   """Joint cross-entropy must match the sum of per-component cross-entropies."""
   first = tfd.JointDistributionNamed(
       {'e': tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
        'x': tfd.Normal(loc=0, scale=2.)},
       validate_args=True)
   second = tfd.JointDistributionNamed(
       {'e': tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
        'x': tfd.Normal(loc=1, scale=1.)},
       validate_args=True)
   self.assertEqual(first.model.keys(), second.model.keys())

   # Reference value: add up the cross-entropy of each matching component.
   componentwise = [
       first.model[name].cross_entropy(second.model[name])
       for name in first.model.keys()]
   expected_xent = sum(componentwise)
   actual_xent = first.cross_entropy(second)

   expected_value, actual_value = self.evaluate([expected_xent, actual_xent])
   self.assertNear(actual_value, expected_value, err=1e-5)
Пример #3
0
  def test_default_event_space_bijector(self):
    """Round-trips a sample through the public default event-space bijector.

    Also checks that static event shapes pass through the bijector unchanged
    and that the event-shape-tensor outputs are dicts keyed like the model.
    """
    joint = tfd.JointDistributionNamed(
        {
            'e': tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
            'scale': lambda e: tfd.Gamma(
                concentration=e[..., 0], rate=e[..., 1]),
            's': tfd.HalfNormal(2.5),
            'loc': lambda s: tfd.Normal(loc=0, scale=s),
            'df': tfd.Exponential(2),
            'x': tfd.StudentT,
        },
        validate_args=True)

    # The event space bijector is inherited from `JointDistributionSequential`
    # and is tested more thoroughly in the tests for that class.
    bijector = joint.experimental_default_event_space_bijector()
    sample = self.evaluate(joint.sample(seed=test_util.test_seed()))
    unconstrained = bijector.inverse(sample)
    round_trip = self.evaluate(bijector.forward(unconstrained))
    self.assertAllClose(sample, round_trip)

    # Verify that event shapes are passed through and flattened/unflattened
    # correctly.
    shape_after_forward = bijector.forward_event_shape(joint.event_shape)
    shape_after_inverse = bijector.inverse_event_shape(joint.event_shape)
    self.assertEqual(shape_after_forward, joint.event_shape)
    self.assertEqual(shape_after_inverse, joint.event_shape)

    # Verify that the outputs of other methods have the correct dict structure.
    shape_tensor_fns = (bijector.forward_event_shape_tensor,
                        bijector.inverse_event_shape_tensor)
    shape_tensors = [fn(joint.event_shape_tensor()) for fn in shape_tensor_fns]
    for shape_tensor in shape_tensors:
      evaluated_keys = set(self.evaluate(shape_tensor).keys())
      self.assertSetEqual(evaluated_keys, set(joint.model.keys()))
Пример #4
0
  def test_can_call_log_prob_with_kwargs(self):
    """Checks the calling conventions of `log_prob` on a dict-based joint.

    Passing the sample dict positionally, as `value=`, or as keyword
    arguments must give matching results; passing the parts as bare
    positional arguments must raise `ValueError`.
    """
    d = tfd.JointDistributionNamed({
        'e': tfd.Normal(0., 1.),
        'a': tfd.Independent(
            tfd.Exponential(rate=[100, 120]),
            reinterpreted_batch_ndims=1),
        'x': lambda a: tfd.Gamma(concentration=a[..., 0], rate=a[..., 1])
    }, validate_args=True)

    sample = self.evaluate(d.sample([2, 3], seed=test_util.test_seed()))
    e, a, x = sample['e'], sample['a'], sample['x']

    lp_value_positional = self.evaluate(d.log_prob({'e': e, 'a': a, 'x': x}))
    lp_value_named = self.evaluate(d.log_prob(value={'e': e, 'a': a, 'x': x}))
    # Assert all close (rather than equal) because order is not defined for
    # dicts, and reordering the computation can give subtly different results.
    self.assertAllClose(lp_value_positional, lp_value_named)

    lp_kwargs = self.evaluate(d.log_prob(a=a, e=e, x=x))
    self.assertAllClose(lp_value_positional, lp_kwargs)

    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
    # which no longer exists in Python 3.12+.
    with self.assertRaisesRegex(ValueError,
                                'Joint distribution with unordered variables '
                                "can't take positional args"):
      d.log_prob(e, a, x)
Пример #5
0
def get_example_phylo_model(
    taxon_count: int,
    site_count: int,
    sampling_times: tf.Tensor,
    init_values: tp.Optional[tp.Dict[str, float]] = None,
    dtype: tf.DType = DEFAULT_FLOAT_DTYPE_TF,
    tree_name: str = DEFAULT_TREE_NAME,
    pattern_counts: tp.Optional[tf.Tensor] = None,
) -> tfd.JointDistribution:
    """Build an example named joint distribution over pop_size, tree, alignment.

    Args:
        taxon_count: Forwarded to `ConstantCoalescent`.
        site_count: Forwarded to `strict_clock_fixed_alignment_func`.
        sampling_times: Forwarded to `ConstantCoalescent`.
        init_values: Optional mapping; only the "rate" entry is read here
            (default 1e-3 when absent).
        dtype: dtype used for the constants created in this function.
        tree_name: Forwarded to `ConstantCoalescent`.
        pattern_counts: Forwarded as `weights` to
            `strict_clock_fixed_alignment_func`.

    Returns:
        A `tfd.JointDistributionNamed` with keys `pop_size`, `tree` and
        `alignment`.
    """
    starting_values = {} if init_values is None else init_values

    as_constant = partial(tf.constant, dtype=dtype)
    clock_rate = as_constant(starting_values.get("rate", 1e-3))

    components = {
        "pop_size": tfd.Exponential(rate=as_constant(0.1), name="pop_size"),
        # The tree prior is conditioned on the sampled population size.
        "tree": lambda pop_size: ConstantCoalescent(
            taxon_count=taxon_count,
            pop_size=pop_size,
            sampling_times=sampling_times,
            tree_name=tree_name,
        ),
        "alignment": strict_clock_fixed_alignment_func(
            clock_rate, site_count=site_count, weights=pattern_counts
        ),
    }
    return tfd.JointDistributionNamed(components)
Пример #6
0
  def test_legacy_dists(self):
    """Sampling with a legacy stateful component must emit a fallback warning.

    `StatefulNormal` overrides `_sample_n` with a `tf.random.normal`-based
    sampler; the first warning recorded while sampling the joint is matched
    against the expected "Falling back to stateful sampling" message.
    """

    class StatefulNormal(tfd.Normal):

      def _sample_n(self, n, seed=None):
        return self.loc + self.scale * tf.random.normal(
            tf.concat([[n], self.batch_shape_tensor()], axis=0),
            seed=seed)

    # pylint: disable=bad-whitespace
    d = tfd.JointDistributionNamed(dict(
        e    =          tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
        loc  =          StatefulNormal(loc=0, scale=2.),
        scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
        m    =          tfd.Normal,
        x    =lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12)),
                                   validate_args=True)
    # pylint: enable=bad-whitespace

    with warnings.catch_warnings(record=True) as w:
      # Configure the filter inside the context manager so the process-wide
      # warning filters are restored on exit instead of leaking into other
      # tests.
      warnings.simplefilter('always')
      d.sample(seed=test_util.test_seed())
    # `assertRegex` replaces the deprecated `assertRegexpMatches` alias, which
    # no longer exists in Python 3.12+.
    self.assertRegex(
        str(w[0].message),
        r'Falling back to stateful sampling for.*of type.*StatefulNormal.*'
        r'component name "loc" and `dist.name` "Normal"',
        msg=w)
Пример #7
0
  def test_sample_kwargs(self):
    """Checks pinning values in `sample` via keywords and via `value=`.

    Covers: a keyword-pinned value is returned as-is, keyword and `value`
    forms produce equal samples, unknown keywords raise, and mixing
    `value` with keywords raises.
    """
    chain = {
        'a': tfd.Normal(0., 1.),
        'b': lambda a: tfd.Normal(a, 1.),
        'c': lambda a, b: tfd.Normal(a + b, 1.),
    }
    joint = tfd.JointDistributionNamed(chain)

    seed = test_util.test_seed()
    tf.random.set_seed(seed)
    samples = joint.sample(seed=seed, a=1.)
    # Check the first value is actually 1.
    self.assertEqual(1., self.evaluate(samples['a']))

    # Check the sample is reproducible using the `value` argument.
    tf.random.set_seed(seed)
    samples_named = joint.sample(seed=seed, value={'a': 1.})
    self.assertAllEqual(self.evaluate(samples), self.evaluate(samples_named))

    # Make sure to throw an exception if strange keywords are passed.
    expected_error = '\n'.join([
        'Found unexpected keyword arguments. Distribution names are',
        'a, b, c',
        'but received',
        'z',
        'These names were invalid:',
        'z',
    ])
    with self.assertRaisesRegex(ValueError, expected_error):
      joint.sample(seed=seed, z=2.)

    # Raise if value and keywords are passed.
    both_supplied_pattern = r'Supplied both `value` and keyword arguments .*'
    with self.assertRaisesRegex(ValueError, both_supplied_pattern):
      joint.sample(seed=seed, a=1., value={'a': 1})
Пример #8
0
    def test_can_call_ordereddict_log_prob_with_args_and_kwargs(
            self, model_fn):
        """Checks every supported way of calling `log_prob` on an ordered model.

        The model returned by `model_fn()` is sampled once, and the
        log-probability of that sample is computed via a positional dict,
        `value=`, positional parts, keyword parts, and a positional/keyword
        mix; all results must agree.
        """
        # With an OrderedDict, we can pass keyword and/or positional args.
        joint = tfd.JointDistributionNamed(model_fn(), validate_args=True)

        def _listify(part):
            # Destructure vector-valued Tensors into Python lists, to mimic
            # the values a user might type.
            return list(part) if isinstance(part, np.ndarray) else part

        sample = tf.nest.map_structure(
            _listify,
            self.evaluate(joint.sample(seed=test_util.test_seed())))
        sample_dict = dict(sample)

        reference_lp = self.evaluate(joint.log_prob(sample_dict))
        lp_from_value_kwarg = self.evaluate(joint.log_prob(value=sample_dict))
        self.assertAllClose(reference_lp, lp_from_value_kwarg)

        lp_from_args = self.evaluate(joint.log_prob(*sample.values()))
        self.assertAllClose(reference_lp, lp_from_args)

        lp_from_kwargs = self.evaluate(joint.log_prob(**sample_dict))
        self.assertAllClose(reference_lp, lp_from_kwargs)

        # First part positionally, the remaining parts by keyword.
        leading_values = list(sample.values())[:1]
        trailing_items = dict(list(sample.items())[1:])
        lp_from_mixed = self.evaluate(
            joint.log_prob(*leading_values, **trailing_items))
        self.assertAllClose(reference_lp, lp_from_mixed)
Пример #9
0
  def test_works_with_structured_samples(self):
    """Monte Carlo expectation must accept dict-structured samples.

    Estimates E[x**2 + y**2] for two independent `Normal(0., 1.)` parts, with
    and without reparameterization, and checks both against 2. (atol 0.2).
    """
    # Check that we don't accidentally destroy the structure of `samples` when
    # it's a dict or other non-Tensor object from a joint distribution.
    joint = tfd.JointDistributionNamed({
        'x': tfd.Normal(0., 1.),
        'y': tfd.Normal(0., 1.)})

    def _sum_of_squares(draw):
      return draw['x']**2 + draw['y']**2

    def _estimate(use_reparametrization):
      return tfp.monte_carlo.expectation(
          f=_sum_of_squares,
          samples=joint.sample(1000, seed=42),
          log_prob=joint.log_prob,
          use_reparametrization=use_reparametrization)

    with_reparam = _estimate(True)
    without_reparam = _estimate(False)
    with_reparam_, without_reparam_ = self.evaluate(
        [with_reparam, without_reparam])
    for estimate in (with_reparam_, without_reparam_):
      self.assertAllClose(estimate, 2., atol=0.2)
Пример #10
0
    def test_can_call_namedtuple_log_prob_with_args_and_kwargs(self):
        """Checks every supported way of calling `log_prob` on a namedtuple model.

        The same sample is scored via a positional namedtuple, `value=`,
        keyword parts, positional parts, and a positional/keyword mix; all
        results must agree.
        """
        # With an namedtuple, we can pass keyword and/or positional args.
        Model = collections.namedtuple('Model', ['e', 'a', 'x'])  # pylint: disable=invalid-name
        model = Model(
            e=tfd.Normal(0., 1.),
            a=tfd.Independent(tfd.Exponential(rate=[100, 120]),
                              reinterpreted_batch_ndims=1),
            x=lambda a: tfd.Gamma(concentration=a[..., 0], rate=a[..., 1]))
        joint = tfd.JointDistributionNamed(model, validate_args=True)

        sample = self.evaluate(
            joint.sample([2, 3], seed=test_util.test_seed()))
        e, a, x = sample.e, sample.a, sample.x

        reference_lp = self.evaluate(joint.log_prob(Model(e=e, a=a, x=x)))
        lp_from_value_kwarg = self.evaluate(
            joint.log_prob(value=Model(e=e, a=a, x=x)))
        self.assertAllClose(reference_lp, lp_from_value_kwarg)

        lp_from_kwargs = self.evaluate(joint.log_prob(e=e, a=a, x=x))
        self.assertAllClose(reference_lp, lp_from_kwargs)

        lp_from_args = self.evaluate(joint.log_prob(e, a, x))
        self.assertAllClose(reference_lp, lp_from_args)

        lp_from_mixed = self.evaluate(joint.log_prob(e, a=a, x=x))
        self.assertAllClose(reference_lp, lp_from_mixed)
Пример #11
0
 def test_summary_statistic(self, attr):
     """The joint's `attr` statistic must equal the per-component statistics."""
     joint = tfd.JointDistributionNamed(
         {'logits': tfd.Normal(0., 1.), 'x': tfd.Bernoulli(logits=0.)},
         validate_args=True)

     # Reference: call the same statistic on every component individually.
     expected = {}
     for name in joint.model.keys():
         expected[name] = getattr(joint.model[name], attr)()
     actual = getattr(joint, attr)()

     self.assertAllEqual(*self.evaluate([expected, actual]))
Пример #12
0
 def test_sample_shape_propagation_nondefault_behavior(self):
     """Checks sample shapes when `_always_use_specified_sample_shape` is set.

     With the flag enabled, each component's sample shape is expected to be
     the requested sample shape repeated according to its dependencies.
     """
     joint = tfd.JointDistributionNamed(
         {
             'e': tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
             'scale': lambda e: tfd.Gamma(
                 concentration=e[..., 0], rate=e[..., 1]),
             's': tfd.HalfNormal(2.5),
             'loc': lambda s: tfd.Normal(loc=0, scale=s),
             'df': tfd.Exponential(2),
             'x': tfd.StudentT,
         },
         validate_args=False)
     # The following enables the nondefault sample shape behavior.
     joint._always_use_specified_sample_shape = True

     sample_shape = (2, 3)
     draws = joint.sample(sample_shape, seed=test_util.test_seed())
     self.assertLen(draws, 6)

     expected_shapes = (
         ('e', sample_shape + (2, )),
         ('scale', sample_shape * 2),  # Has 1 arg.
         ('s', sample_shape * 1),  # Has 0 args.
         ('loc', sample_shape * 2),  # Has 1 arg.
         ('df', sample_shape * 1),  # Has 0 args.
         ('x', sample_shape * 3),  # Has 3 args, one being scalar.
     )
     for name, expected in expected_shapes:
         self.assertEqual(expected, draws[name].shape)

     log_prob = joint.log_prob(draws)
     self.assertEqual(sample_shape * 3, log_prob.shape)
 def test_copy(self):
     """A copied joint must report the same constructor parameters."""
     pgm = dict(logits=tfd.Normal(0., 1.), probs=tfd.Bernoulli(logits=0.5))
     original = tfd.JointDistributionNamed(pgm, validate_args=True)
     duplicate = original.copy()

     expected_parameters = (
         ('model', pgm),
         ('validate_args', True),
         ('name', 'JointDistributionNamed'),
     )
     for key, expected in expected_parameters:
         self.assertEqual(duplicate.parameters[key], expected)
Пример #14
0
    def test_sample_complex_dependency(self):
        """Checks graph resolution and sampling for a model declared out of order.

        The model dict lists dependents before their parents; the resolved
        graph, the sample size, and the parameters of the sampled `x` / `y`
        distributions are verified.
        """
        joint = tfd.JointDistributionNamed(
            {
                'y': tfd.StudentT,
                'x': tfd.StudentT,
                'df': tfd.Exponential(2),
                'loc': lambda s: tfd.Normal(loc=0, scale=s),
                's': tfd.HalfNormal(2.5),
                'scale': lambda e: tfd.Gamma(
                    concentration=e[..., 0], rate=e[..., 1]),
                'e': tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
            },
            validate_args=False)

        expected_graph = (
            ('e', ()),
            ('scale', ('e', )),
            ('s', ()),
            ('loc', ('s', )),
            ('df', ()),
            ('y', ('df', 'loc', 'scale')),
            ('x', ('df', 'loc', 'scale')),
        )
        self.assertEqual(expected_graph, joint.resolve_graph())

        draws = joint.sample()
        self.assertLen(draws, 7)

        # Each StudentT must have been parameterized by the sampled parents.
        dists, values = joint.sample_distributions()
        for child in ('x', 'y'):
            for parent in ('df', 'loc', 'scale'):
                self.assertEqual(dists[child].parameters[parent],
                                 values[parent])
Пример #15
0
  def test_namedtuple_sample_log_prob(self):
    """End-to-end check of a namedtuple-structured `JointDistributionNamed`.

    Verifies the resolved graph, sample size, sampled distribution types,
    dtypes, batch/event shapes (static and as tensors), and that the joint
    log-prob equals the sum of the component log-probs.
    """
    Model = collections.namedtuple('Model', ['e', 'scale', 'loc', 'm', 'x'])  # pylint: disable=invalid-name
    model = Model(
        e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
        scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
        loc=tfd.Normal(loc=0, scale=2.),
        m=tfd.Normal,
        x=lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12))
    joint = tfd.JointDistributionNamed(model, validate_args=True)

    expected_graph = (
        ('e', ()),
        ('scale', ('e',)),
        ('loc', ()),
        ('m', ('loc', 'scale')),
        ('x', ('m',)),
    )
    self.assertEqual(expected_graph, joint.resolve_graph())

    samples = joint.sample(seed=test_util.test_seed())
    self.assertLen(samples, 5)
    # We'll verify the shapes work as intended when we plumb these back into the
    # respective log_probs.

    dists, _ = joint.sample_distributions(
        value=samples, seed=test_util.test_seed())
    self.assertLen(dists, 5)
    expected_types = Model(e=tfd.Independent, scale=tfd.Gamma, loc=tfd.Normal,
                           m=tfd.Normal, x=tfd.Sample)
    for field in Model._fields:
      self.assertIsInstance(getattr(dists, field),
                            getattr(expected_types, field))

    # Static properties.
    expected_dtype = Model(e=tf.float32, scale=tf.float32, loc=tf.float32,
                           m=tf.float32, x=tf.int32)
    self.assertAllEqual(expected_dtype, joint.dtype)

    batch_shape_tensor_, event_shape_tensor_ = self.evaluate([
        joint.batch_shape_tensor(), joint.event_shape_tensor()])

    def _assert_shapes(expected_shapes, static_shapes, dynamic_shapes):
      # Compare each expected shape with both the static shape and the
      # evaluated shape tensor of the corresponding component.
      for expected, static, dynamic in zip(
          expected_shapes, static_shapes, dynamic_shapes):
        self.assertAllEqual(expected, static)
        self.assertAllEqual(expected, dynamic)

    _assert_shapes(Model(e=[], scale=[], loc=[], m=[], x=[]),
                   joint.batch_shape, batch_shape_tensor_)
    _assert_shapes(Model(e=[2], scale=[], loc=[], m=[], x=[12]),
                   joint.event_shape, event_shape_tensor_)

    expected_jlp = sum(
        dist.log_prob(value) for dist, value in zip(dists, samples))
    actual_jlp = joint.log_prob(samples)
    self.assertAllClose(*self.evaluate([expected_jlp, actual_jlp]),
                        atol=0., rtol=1e-4)
Пример #16
0
 def test_notimplemented_quantile(self):
   """Calling `quantile` on a named joint must raise `NotImplementedError`."""
   components = {
       'logits': tfd.Normal(0., 1.),
       'x': tfd.Bernoulli(probs=0.5),
   }
   joint = tfd.JointDistributionNamed(components, validate_args=True)
   expected_message = 'quantile is not implemented: JointDistributionNamed'
   with self.assertRaisesWithPredicateMatch(
       NotImplementedError, expected_message):
     joint.quantile(0.5)
Пример #17
0
 def test_notimplemented_evaluative_statistic(self, attr):
     """Calling the method named `attr` on a joint must raise `NotImplementedError`."""
     components = {
         'logits': tfd.Normal(0., 1.),
         'x': tfd.Bernoulli(probs=0.5),
     }
     joint = tfd.JointDistributionNamed(components, validate_args=True)
     expected_message = attr + ' is not implemented: JointDistributionNamed'
     with self.assertRaisesWithPredicateMatch(
             NotImplementedError, expected_message):
         statistic = getattr(joint, attr)
         statistic(dict(logits=0., x=0.5))
Пример #18
0
 def test_kl_divergence(self):
   """Joint KL divergence must match the sum of per-component KL divergences.

   Both the functional form `tfd.kl_divergence(d0, d1)` and the method form
   `d0.kl_divergence(d1)` are compared with the component-wise sum.
   """
   first = tfd.JointDistributionNamed(
       {'e': tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
        'x': tfd.Normal(loc=0, scale=2.)},
       validate_args=True)
   second = tfd.JointDistributionNamed(
       {'e': tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
        'x': tfd.Normal(loc=1, scale=1.)},
       validate_args=True)
   self.assertEqual(first.model.keys(), second.model.keys())

   componentwise = [
       tfd.kl_divergence(first.model[name], second.model[name])
       for name in first.model.keys()]
   expected_kl = sum(componentwise)
   functional_kl = tfd.kl_divergence(first, second)
   method_kl = first.kl_divergence(second)

   expected_value, functional_value, method_value = self.evaluate(
       [expected_kl, functional_kl, method_kl])
   for observed in (functional_value, method_value):
     self.assertNear(expected_value, observed, err=1e-5)
Пример #19
0
 def test_copy(self):
   """A copied joint must report exactly the original constructor parameters."""
   pgm = dict(logits=tfd.Normal(0., 1.), probs=tfd.Bernoulli(logits=0.5))
   original = tfd.JointDistributionNamed(pgm, validate_args=True)
   duplicate = original.copy()

   expected_parameters = {
       'model': pgm,
       'validate_args': True,
       'name': 'JointDistributionNamed',
   }
   self.assertAllEqual(expected_parameters, duplicate.parameters)
Пример #20
0
    def test_nested_partial_value(self, sample_fn):
        """Sampling with partially specified nested values must reproduce a full draw.

        A three-level nested `JointDistributionNamed` is sampled once with a
        stateless seed. Parts of that sample are then replaced by `None` at
        different nesting depths and `sample_fn` is asked to fill them back
        in using the same seed; the result must match the original sample.

        Args:
            sample_fn: Callable invoked as
                `sample_fn(dist, value=..., seed=...)` returning a sample.
        """
        innermost = tfd.JointDistributionNamed({
            'a':
            tfd.Exponential(1.),
            'b':
            lambda a: tfd.Sample(tfd.LogNormal(a, a), [5]),
        })

        inner = tfd.JointDistributionNamed({
            'c': tfd.Exponential(1.),
            'd': innermost,
        })

        outer = tfd.JointDistributionNamed({
            'e': tfd.Exponential(1.),
            'f': inner,
        })

        seed = test_util.test_seed(sampler_type='stateless')
        true_xs = outer.sample(seed=seed)

        def _update(dict_, **kwargs):
            """Returns a shallow copy of `dict_` with `kwargs` applied."""
            # The copy itself must be updated and returned. Updating a
            # throwaway copy and returning `dict_` would leave every value in
            # place, so no partial value would ever be tested.
            updated = dict_.copy()
            updated.update(**kwargs)
            return updated

        # These asserts work because we advance the stateless seed inside the model
        # whether or not a sample is actually generated.
        partial_xs = _update(true_xs, f=None)
        xs = sample_fn(outer, value=partial_xs, seed=seed)
        self.assertAllCloseNested(true_xs, xs)

        partial_xs = _update(true_xs, e=None)
        xs = sample_fn(outer, value=partial_xs, seed=seed)
        self.assertAllCloseNested(true_xs, xs)

        partial_xs = _update(true_xs, f=_update(true_xs['f'], d=None))
        xs = sample_fn(outer, value=partial_xs, seed=seed)
        self.assertAllCloseNested(true_xs, xs)

        partial_xs = _update(true_xs,
                             f=_update(true_xs['f'],
                                       d=_update(true_xs['f']['d'], a=None)))
        xs = sample_fn(outer, value=partial_xs, seed=seed)
        self.assertAllCloseNested(true_xs, xs)
Пример #21
0
  def testPartsWithUnusedInternalStructure(self):
    """`tfb.Restructure` must swap list entries that are themselves dicts.

    The swap is applied both to a sample and to the distribution; the
    resulting values, event shapes and dtypes must appear in swapped order.
    """
    components = [
        tfd.JointDistributionNamed({'a': tfd.Normal(0., 1.)}),
        tfd.JointDistributionNamed({'b': tfd.Normal(1000., 1.)}),
    ]
    dist = tfd.JointDistributionSequential(components)
    sample = dist.sample(  # Shape: [{'a': []}, {'b': []}]
        seed=test_util.test_seed(sampler_type='stateless'))

    # Test that we can swap the outer list entries, even though they contain
    # internal structure (i.e., are themselves dicts).
    swap_elements = tfb.Restructure(input_structure=[1, 0],
                                    output_structure=[0, 1])
    expected_sample = [sample[1], sample[0]]
    self.assertAllEqualNested(
        swap_elements(sample), expected_sample, check_types=True)

    swapped_dist = swap_elements(dist)
    expected_event_shape = [dist.event_shape[1], dist.event_shape[0]]
    self.assertAllEqualNested(
        swapped_dist.event_shape, expected_event_shape, check_types=True)
    expected_dtype = [dist.dtype[1], dist.dtype[0]]
    self.assertEqual(swapped_dist.dtype, expected_dtype)
  def test_transform_joint_to_joint(self, split_sizes):
    """Transforms a joint distribution part-by-part with a `tfb.JointMap`.

    A joint of `MultivariateNormalDiag` parts, structured like `split_sizes`,
    is transformed with one bijector per part. The string representation,
    event shapes (static and tensor) and batch shapes of the transformed
    distribution are then verified.

    Args:
      split_sizes: Structure (dict or other nest) of per-part event sizes.
    """
    raw_batch_shapes = [[2, 3], [2, 1], [1, 3]]
    dist_batch_shape = tf.nest.pack_sequence_as(
        split_sizes,
        [tensorshape_util.constant_value_as_shape(shape)
         for shape in raw_batch_shapes])
    bijector_batch_shape = [1, 3]

    # Build a joint distribution with parts of the specified sizes.
    seed = test_util.test_seed_stream()

    def _make_component(size, batch_shape):
      # Random location first, then random scale, so the seed stream is
      # consumed in a fixed order.
      param_shape = batch_shape + [size]
      loc = tf.random.normal(param_shape, seed=seed())
      scale_diag = tf.random.uniform(
          minval=1., maxval=2., shape=param_shape, seed=seed())
      return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)

    component_dists = tf.nest.map_structure(
        _make_component, split_sizes, dist_batch_shape)
    joint_cls = (tfd.JointDistributionNamed if isinstance(split_sizes, dict)
                 else tfd.JointDistributionSequential)
    base_dist = joint_cls(component_dists)

    # Transform the distribution by applying a separate bijector to each part.
    random_scale = tf.random.uniform(
        minval=1., maxval=2., shape=bijector_batch_shape, seed=seed())
    bijectors = [tfb.Exp(), tfb.Scale(random_scale), tfb.Reshape([2, 1])]
    bijector = tfb.JointMap(tf.nest.pack_sequence_as(split_sizes, bijectors),
                            validate_args=True)

    # Transform a joint distribution that has different batch shape components
    transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

    repr_pattern = '{}.*batch_shape.*event_shape.*dtype'.format(
        transformed_dist.name)
    self.assertRegex(str(transformed_dist), repr_pattern)

    expected_event_shape = bijector.forward_event_shape(base_dist.event_shape)
    self.assertAllEqualNested(transformed_dist.event_shape,
                              expected_event_shape)
    actual_shape_tensor_, expected_shape_tensor_ = self.evaluate((
        transformed_dist.event_shape_tensor(),
        bijector.forward_event_shape_tensor(base_dist.event_shape_tensor())))
    self.assertAllEqualNested(actual_shape_tensor_, expected_shape_tensor_)

    # Test that the batch shape components of the input are the same as those of
    # the output.
    self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
    batch_shape_tensor_ = self.evaluate(transformed_dist.batch_shape_tensor())
    self.assertAllEqualNested(batch_shape_tensor_, dist_batch_shape)
    self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)
Пример #23
0
 def test_graph_resolution(self):
     """`resolve_graph` must list every component with its parent names."""
     joint = tfd.JointDistributionNamed(
         {
             'e': tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
             'scale': lambda e: tfd.Gamma(
                 concentration=e[..., 0], rate=e[..., 1]),
             's': tfd.HalfNormal(2.5),
             'loc': lambda s: tfd.Normal(loc=0, scale=s),
             'df': tfd.Exponential(2),
             'x': tfd.StudentT,
         },
         validate_args=True)

     expected_graph = (
         ('e', ()),
         ('scale', ('e', )),
         ('s', ()),
         ('loc', ('s', )),
         ('df', ()),
         ('x', ('df', 'loc', 'scale')),
     )
     self.assertEqual(expected_graph, joint.resolve_graph())
 def test_dist_fn_takes_kwargs(self):
   """Distribution-making callables may take `**kwargs` instead of named args.

   Builds a model whose `b` and `a` entries read their parents from
   `kwargs`, then checks that the log-prob of five samples has shape [5].
   """
   model = {
       'positive': tfd.Exponential(rate=1.),
       'negative': tfb.Scale(-1.)(tfd.Exponential(rate=1.)),
       'b': lambda **kwargs: tfd.Normal(  # pylint: disable=g-long-lambda
           loc=kwargs['negative'],
           scale=kwargs['positive'],
           validate_args=True),
       'a': lambda **kwargs: tfb.Scale(kwargs['b'])(  # pylint: disable=g-long-lambda
           tfd.Gamma(concentration=-kwargs['negative'],
                     rate=kwargs['positive'],
                     validate_args=True)),
   }
   joint = tfd.JointDistributionNamed(model, validate_args=True)

   draws = joint.sample(5, seed=test_util.test_seed())
   log_prob = joint.log_prob(draws)
   self.assertAllEqual(log_prob.shape, [5])
Пример #25
0
def test_get_mean_field_approximation_tree(
    flat_tree_test_data: TreeTestData, with_init_loc: bool
):
    """Check a fixed-topology mean-field approximation of a tree model.

    A sample from the approximation must keep the pinned topology and the
    sampling times of the test tree, and both the pinned model's
    unnormalized log-prob and the approximation's log-prob of that sample
    must be finite.
    """
    test_tree = data_to_tensor_tree(flat_tree_test_data)
    taxon_count = test_tree.taxon_count
    tree_name = "tree_dist_name"

    # Optionally start the approximation at the test tree itself.
    init_loc: tp.Optional[tp.Dict[str, object]] = (
        dict(tree=test_tree) if with_init_loc else None
    )

    components = dict(
        pop_size=tfd.LogNormal(_constant(0.0), _constant(1.0)),
        tree=lambda pop_size: ConstantCoalescent(
            taxon_count, pop_size, test_tree.sampling_times, tree_name=tree_name
        ),
        obs=lambda tree: tfd.Normal(
            _constant(0.0), tf.reduce_sum(tree.branch_lengths)
        ),
    )
    model = tfd.JointDistributionNamed(components)
    observed = _constant([10.0])
    pinned = model.experimental_pin(obs=observed)
    approximation = get_fixed_topology_mean_field_approximation(
        pinned,
        dtype=DEFAULT_FLOAT_DTYPE_TF,
        topology_pins={tree_name: test_tree.topology},
        init_loc=init_loc,
    )

    sample = approximation.sample()
    sampled_tree = sample["tree"]

    topologies_match = tf.reduce_all(
        sampled_tree.topology.parent_indices == test_tree.topology.parent_indices
    )
    assert topologies_match.numpy().item()
    assert_allclose(
        sampled_tree.sampling_times.numpy(), test_tree.sampling_times.numpy()
    )

    model_log_prob = pinned.unnormalized_log_prob(sample)
    approx_log_prob = approximation.log_prob(sample)
    assert np.isfinite(model_log_prob.numpy())
    assert np.isfinite(approx_log_prob.numpy())
Пример #26
0
 def transition_fn(_, previous_state):
   """Returns the joint distribution of the next state given `previous_state`.

   The first argument is ignored. `previous_state` is indexed with the keys
   'coefs', 'log_scale', 'level' and 'previous_level'.
   """
   # The autoregressive coefficients and the `log_scale` each follow
   # an independent slow-moving random walk.
   coefs_walk = tfd.Independent(
       tfd.Normal(loc=previous_state['coefs'], scale=0.01),
       reinterpreted_batch_ndims=1)
   log_scale_walk = tfd.Normal(loc=previous_state['log_scale'], scale=0.01)

   next_state = {
       'coefs': coefs_walk,
       'log_scale': log_scale_walk,
       # The level is a linear combination of the previous *two* levels,
       # with additional noise of scale `exp(log_scale)`.
       'level': lambda coefs, log_scale: tfd.Normal(  # pylint: disable=g-long-lambda
           loc=(coefs[..., 0] * previous_state['level'] +
                coefs[..., 1] * previous_state['previous_level']),
           scale=tf.exp(log_scale)),
       # Store the previous level to access at the next step.
       'previous_level': tfd.Deterministic(previous_state['level']),
   }
   return tfd.JointDistributionNamed(next_state)
Пример #27
0
    def test_legacy_dists_stateless_seed_raises(self):
        """Sampling a legacy stateful component with `zeros_seed()` must raise.

        `StatefulNormal` overrides `_sample_n` with a `tf.random.normal`-based
        sampler; sampling the joint with `samplers.zeros_seed()` is expected
        to raise a `TypeError` mentioning 'Expected int for argument'.
        """
        class StatefulNormal(tfd.Normal):
            def _sample_n(self, n, seed=None):
                return self.loc + self.scale * tf.random.normal(tf.concat(
                    [[n], self.batch_shape_tensor()], axis=0),
                                                                seed=seed)

        # pylint: disable=bad-whitespace
        d = tfd.JointDistributionNamed(dict(
            e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
            loc=StatefulNormal(loc=0, scale=2.),
            scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
            m=tfd.Normal,
            x=lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12)),
                                       validate_args=True)
        # pylint: enable=bad-whitespace

        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which no longer exists in Python 3.12+.
        with self.assertRaisesRegex(TypeError, r'Expected int for argument'):
            d.sample(seed=samplers.zeros_seed())
Пример #28
0
def test_get_mean_field_approximation():
    """Check that a mean-field approximation of a pinned model is usable.

    A sample drawn from the approximation must have a finite unnormalized
    log-prob under the pinned model and a finite log-prob under the
    approximation itself.
    """
    sample_size = 3
    components = dict(
        a=tfd.Normal(_constant(0.0), _constant(1.0)),
        b=lambda a: tfd.Sample(tfd.LogNormal(a, _constant(1.0)), sample_size),
        obs=lambda b: tfd.Independent(
            tfd.Normal(b, _constant(1.0)), reinterpreted_batch_ndims=1
        ),
    )
    model = tfd.JointDistributionNamed(components)

    observed = _constant([-1.1, 2.1, 0.1])
    pinned = model.experimental_pin(obs=observed)
    initial_locations = dict(a=_constant(0.1))
    approximation = get_mean_field_approximation(
        pinned, init_loc=initial_locations, dtype=DEFAULT_FLOAT_DTYPE_TF
    )

    sample = approximation.sample()
    model_log_prob = pinned.unnormalized_log_prob(sample)
    approx_log_prob = approximation.log_prob(sample)
    assert np.isfinite(model_log_prob.numpy())
    assert np.isfinite(approx_log_prob.numpy())
Пример #29
0
 def test_sample_shape_propagation_default_behavior(self):
     """Checks per-component sample shapes for a `[2, 3]` sample request.

     Every scalar component must have shape (2, 3), the vector-valued `e`
     must have shape (2, 3, 2), and the joint log-prob must have shape (2, 3).
     """
     joint = tfd.JointDistributionNamed(
         {
             'e': tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
             'scale': lambda e: tfd.Gamma(
                 concentration=e[..., 0], rate=e[..., 1]),
             's': tfd.HalfNormal(2.5),
             'loc': lambda s: tfd.Normal(loc=0, scale=s),
             'df': tfd.Exponential(2),
             'x': tfd.StudentT,
         },
         validate_args=False)

     draws = joint.sample([2, 3], seed=test_util.test_seed())
     self.assertLen(draws, 6)

     expected_shapes = (
         ('e', (2, 3, 2)),
         ('scale', (2, 3)),
         ('s', (2, 3)),
         ('loc', (2, 3)),
         ('df', (2, 3)),
         ('x', (2, 3)),
     )
     for name, expected in expected_shapes:
         self.assertEqual(expected, draws[name].shape)

     log_prob = joint.log_prob(draws)
     self.assertEqual((2, 3), log_prob.shape)
Пример #30
0
    def test_batch_slicing(self):
        """Slicing a batched joint must slice the batch of every component.

        A joint with batch size 3 is split into `[:1]` and `[1:]`; samples
        from the two slices must have batch sizes 1 and 2 respectively for
        each of the three components.
        """
        joint = tfd.JointDistributionNamed(
            {
                's': tfd.Exponential(rate=[10, 12, 14]),
                'n': lambda s: tfd.Normal(loc=0, scale=s),
                'x': lambda: tfd.Beta(
                    concentration0=[3, 2, 1], concentration1=1),
            },
            validate_args=True)

        head, tail = joint[:1], joint[1:]
        head_sample = head.sample(seed=test_util.test_seed())
        tail_sample = tail.sample(seed=test_util.test_seed())

        for sample, batch_size in ((head_sample, 1), (tail_sample, 2)):
            self.assertLen(sample, 3)
            for name in ('s', 'n', 'x'):
                self.assertEqual([batch_size], sample[name].shape)