Esempio n. 1
0
  def test_default_event_space_bijector(self):
    """Checks `_experimental_default_event_space_bijector` on two models.

    The first model contains callables (lambdas and the bare ``tfd.StudentT``
    class) and requesting the bijector is expected to raise
    ``NotImplementedError``. The second contains only distribution instances
    and must yield ``tfb.Bijector`` objects keyed like the model.
    """
    callable_model = tfd.JointDistributionNamed(
        dict(
            e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
            scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
            s=tfd.HalfNormal(2.5),
            loc=lambda s: tfd.Normal(loc=0, scale=s),
            df=tfd.Exponential(2),
            x=tfd.StudentT),
        validate_args=True)
    with self.assertRaisesRegex(
        NotImplementedError, 'all elements of `model` are `tfp.distribution`s'):
      callable_model._experimental_default_event_space_bijector()

    instance_model = tfd.JointDistributionNamed(
        dict(e=tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
             x=tfd.Normal(loc=1, scale=1.),
             s=tfd.HalfNormal(2.5)),
        validate_args=True)
    flat_bijectors = instance_model._model_flatten(
        instance_model._experimental_default_event_space_bijector())
    for bijector in flat_bijectors:
      self.assertIsInstance(bijector, tfb.Bijector)
    bijector_keys = set(
        instance_model._experimental_default_event_space_bijector().keys())
    self.assertSetEqual(set(instance_model.model.keys()), bijector_keys)
Esempio n. 2
0
 def test_sample_shape_propagation_nondefault_behavior(self):
     """Sample shapes compound through dependencies in the nondefault mode.

     With ``_always_use_specified_sample_shape`` enabled, the expected leading
     shape of each component is ``sample_shape`` repeated a number of times
     that depends on the component's arguments (see the per-entry notes).
     """
     model = dict(
         e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
         scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
         s=tfd.HalfNormal(2.5),
         loc=lambda s: tfd.Normal(loc=0, scale=s),
         df=tfd.Exponential(2),
         x=tfd.StudentT)
     d = tfd.JointDistributionNamed(model, validate_args=False)
     # Opt in to the nondefault sample-shape behavior under test.
     d._always_use_specified_sample_shape = True
     sample_shape = (2, 3)
     x = d.sample(sample_shape, seed=test_util.test_seed())
     self.assertLen(x, 6)
     self.assertEqual(sample_shape + (2,), x['e'].shape)
     # How many times `sample_shape` is expected to prefix each component.
     repeats = {
         'scale': 2,  # Has 1 arg.
         's': 1,  # Has 0 args.
         'loc': 2,  # Has 1 arg.
         'df': 1,  # Has 0 args.
         'x': 3,  # Has 3 args, one being scalar.
     }
     for name, count in repeats.items():
         self.assertEqual(sample_shape * count, x[name].shape)
     self.assertEqual(sample_shape * 3, d.log_prob(x).shape)
Esempio n. 3
0
  def test_default_event_space_bijector(self):
    """Exercises the default event-space bijector of a named joint model.

    Checks an inverse/forward round trip on a sample, that static event shapes
    pass through both directions unchanged, and that the tensor-valued
    event-shape methods keep the model's dict keys.
    """
    d = tfd.JointDistributionNamed(
        dict(
            e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
            scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
            s=tfd.HalfNormal(2.5),
            loc=lambda s: tfd.Normal(loc=0, scale=s),
            df=tfd.Exponential(2),
            x=tfd.StudentT),
        validate_args=True)

    # The event space bijector is inherited from `JointDistributionSequential`
    # and is tested more thoroughly in the tests for that class.
    bijector = d.experimental_default_event_space_bijector()
    sample = self.evaluate(d.sample(seed=test_util.test_seed()))
    round_tripped = self.evaluate(bijector.forward(bijector.inverse(sample)))
    self.assertAllClose(sample, round_tripped)

    # Event shapes must be flattened and unflattened without being changed.
    for shape_fn in (bijector.forward_event_shape,
                     bijector.inverse_event_shape):
      self.assertEqual(shape_fn(d.event_shape), d.event_shape)

    # The tensor-valued variants must preserve the model's dict structure.
    model_keys = set(d.model.keys())
    for shape_tensor_fn in (bijector.forward_event_shape_tensor,
                            bijector.inverse_event_shape_tensor):
      shape_tensors = self.evaluate(shape_tensor_fn(d.event_shape_tensor()))
      self.assertSetEqual(set(shape_tensors.keys()), model_keys)
Esempio n. 4
0
    def test_sample_complex_dependency(self):
        """Resolves a graph with shared parents and samples from it.

        ``x`` and ``y`` are both the bare ``tfd.StudentT`` class, and the
        expected graph lists ``df``, ``loc`` and ``scale`` as parents of each.
        The distributions returned by ``sample_distributions`` must have
        parameters equal to the values sampled alongside them.
        """
        d = tfd.JointDistributionNamed(
            dict(
                y=tfd.StudentT,
                x=tfd.StudentT,
                df=tfd.Exponential(2),
                loc=lambda s: tfd.Normal(loc=0, scale=s),
                s=tfd.HalfNormal(2.5),
                scale=lambda e: tfd.Gamma(concentration=e[..., 0],
                                          rate=e[..., 1]),
                e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1)),
            validate_args=False)

        self.assertEqual((
            ('e', ()),
            ('scale', ('e', )),
            ('s', ()),
            ('loc', ('s', )),
            ('df', ()),
            ('y', ('df', 'loc', 'scale')),
            ('x', ('df', 'loc', 'scale')),
        ), d.resolve_graph())

        # Seed every draw so the test is deterministic.
        x = d.sample(seed=test_util.test_seed())
        self.assertLen(x, 7)

        ds, s = d.sample_distributions(seed=test_util.test_seed())
        for component in ('x', 'y'):
            for param in ('df', 'loc', 'scale'):
                self.assertEqual(ds[component].parameters[param], s[param])
 def testTransformedKLDifferentBijectorFails(self):
     """KL between transformed distributions with unequal bijectors raises."""

     def make_scaled_exponential(scale):
         # Each call builds a fresh base distribution and bijector.
         return self._cls()(tfd.Exponential(rate=0.25),
                            bijector=tfb.Scale(scale=scale),
                            validate_args=True)

     d1 = make_scaled_exponential(2.)
     d2 = make_scaled_exponential(3.)
     with self.assertRaisesRegex(NotImplementedError,
                                 r'their bijectors are not equal'):
         tfd.kl_divergence(d1, d2)
Esempio n. 6
0
 def test_graph_resolution(self):
     """`resolve_graph` returns the expected (name, parents) pairs."""
     d = tfd.JointDistributionNamed(
         dict(
             e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
             scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
             s=tfd.HalfNormal(2.5),
             loc=lambda s: tfd.Normal(loc=0, scale=s),
             df=tfd.Exponential(2),
             x=tfd.StudentT),
         validate_args=True)
     expected_graph = (
         ('e', ()),
         ('scale', ('e',)),
         ('s', ()),
         ('loc', ('s',)),
         ('df', ()),
         ('x', ('df', 'loc', 'scale')),
     )
     self.assertEqual(expected_graph, d.resolve_graph())
 def test_dist_fn_takes_kwargs(self):
   """Component callables may receive their parents through ``**kwargs``."""

   def make_b(**kwargs):
     return tfd.Normal(loc=kwargs['negative'],
                       scale=kwargs['positive'],
                       validate_args=True)

   def make_a(**kwargs):
     scale_bijector = tfb.Scale(kwargs['b'])
     gamma = tfd.Gamma(concentration=-kwargs['negative'],
                       rate=kwargs['positive'],
                       validate_args=True)
     return scale_bijector(gamma)

   dist = tfd.JointDistributionNamed(
       {'positive': tfd.Exponential(rate=1.),
        'negative': tfb.Scale(-1.)(tfd.Exponential(rate=1.)),
        'b': make_b,
        'a': make_a},
       validate_args=True)
   samples = dist.sample(5, seed=test_util.test_seed())
   self.assertAllEqual(dist.log_prob(samples).shape, [5])
Esempio n. 8
0
 def test_cross_entropy(self):
   """Joint cross entropy equals the sum of per-component cross entropies."""

   def make_joint(exp_rate, normal_loc, normal_scale):
     return tfd.JointDistributionNamed(
         dict(e=tfd.Independent(tfd.Exponential(rate=exp_rate), 1),
              x=tfd.Normal(loc=normal_loc, scale=normal_scale)),
         validate_args=True)

   p = make_joint([100, 120], 0, 2.)
   q = make_joint([10, 12], 1, 1.)
   self.assertEqual(p.model.keys(), q.model.keys())
   expected_xent = sum(p.model[k].cross_entropy(q.model[k])
                       for k in p.model.keys())
   actual_xent = p.cross_entropy(q)
   expected_value, actual_value = self.evaluate([expected_xent, actual_xent])
   self.assertNear(actual_value, expected_value, err=1e-5)
Esempio n. 9
0
    def testStrWorksCorrectlyScalar(self):
        """`str` of a distribution reports its name, shapes and dtype."""
        normal = tfd.Normal(loc=np.float16(0), scale=1, validate_args=True)
        expected_normal = ('tfp.distributions.Normal("Normal", '
                           'batch_shape=[], event_shape=[], dtype=float16)')
        self.assertEqual(str(normal), expected_normal)

        chi2 = tfd.Chi2(df=np.float32([1., 2.]), name='silly',
                        validate_args=True)
        expected_chi2 = ('tfp.distributions.Chi2("silly", '  # A silly name!
                         'batch_shape=[2], event_shape=[], dtype=float32)')
        self.assertEqual(str(chi2), expected_chi2)

        # Partially known shapes only exist outside eager mode, so the last
        # check runs only in graph mode.
        if not tf.executing_eagerly():
            exp = tfd.Exponential(
                rate=tf1.placeholder_with_default(1., shape=None),
                validate_args=True)
            # The batch shape is unknown, so it is absent from the string.
            self.assertEqual(
                str(exp),
                'tfp.distributions.Exponential("Exponential", '
                'event_shape=[], dtype=float32)')
Esempio n. 10
0
def get_example_phylo_model(
    taxon_count: int,
    site_count: int,
    sampling_times: tf.Tensor,
    init_values: tp.Optional[tp.Dict[str, float]] = None,
    dtype: tf.DType = DEFAULT_FLOAT_DTYPE_TF,
    tree_name: str = DEFAULT_TREE_NAME,
    pattern_counts: tp.Optional[tf.Tensor] = None,
) -> tfd.JointDistribution:
    """Build a small example phylogenetic model as a named joint distribution.

    Components, in order:

    * ``pop_size``: ``tfd.Exponential`` prior with rate 0.1.
    * ``tree``: ``ConstantCoalescent`` conditioned on ``pop_size``.
    * ``alignment``: result of ``strict_clock_fixed_alignment_func`` built
      with a fixed clock rate.

    Args:
        taxon_count: Forwarded to ``ConstantCoalescent``.
        site_count: Forwarded to ``strict_clock_fixed_alignment_func``.
        sampling_times: Forwarded to ``ConstantCoalescent``.
        init_values: Optional overrides. Only the ``"rate"`` key is read;
            it defaults to ``1e-3``.
        dtype: dtype of the constant tensors (clock rate and prior rate).
        tree_name: Forwarded to ``ConstantCoalescent``.
        pattern_counts: Optional; forwarded to
            ``strict_clock_fixed_alignment_func`` as ``weights``.

    Returns:
        A ``tfd.JointDistributionNamed`` over ``pop_size``, ``tree`` and
        ``alignment``.
    """
    clock_rate_value = (
        1e-3 if init_values is None else init_values.get("rate", 1e-3))
    clock_rate = tf.constant(clock_rate_value, dtype=dtype)

    pop_size_prior = tfd.Exponential(
        rate=tf.constant(0.1, dtype=dtype), name="pop_size")
    return tfd.JointDistributionNamed({
        "pop_size": pop_size_prior,
        # The argument name must stay `pop_size` so the joint distribution
        # wires this component to the prior above.
        "tree": lambda pop_size: ConstantCoalescent(
            taxon_count=taxon_count,
            pop_size=pop_size,
            sampling_times=sampling_times,
            tree_name=tree_name,
        ),
        "alignment": strict_clock_fixed_alignment_func(
            clock_rate, site_count=site_count, weights=pattern_counts),
    })
Esempio n. 11
0
  def test_can_call_log_prob_with_kwargs(self):
    """Dict-structured models accept `log_prob` values as a dict or kwargs.

    The value dict passed positionally, passed as ``value=``, or unpacked into
    keyword arguments must agree. Bare positional component arguments are
    rejected because the dict model has no defined variable order.
    """
    d = tfd.JointDistributionNamed({
        'e': tfd.Normal(0., 1.),
        'a': tfd.Independent(
            tfd.Exponential(rate=[100, 120]),
            reinterpreted_batch_ndims=1),
        'x': lambda a: tfd.Gamma(concentration=a[..., 0], rate=a[..., 1])
    }, validate_args=True)

    sample = self.evaluate(d.sample([2, 3], seed=test_util.test_seed()))
    e, a, x = sample['e'], sample['a'], sample['x']

    lp_value_positional = self.evaluate(d.log_prob({'e': e, 'a': a, 'x': x}))
    lp_value_named = self.evaluate(d.log_prob(value={'e': e, 'a': a, 'x': x}))
    # Assert all close (rather than equal) because order is not defined for
    # dicts, and reordering the computation can give subtly different results.
    self.assertAllClose(lp_value_positional, lp_value_named)

    lp_kwargs = self.evaluate(d.log_prob(a=a, e=e, x=x))
    self.assertAllClose(lp_value_positional, lp_kwargs)

    # `assertRaisesRegexp` is a deprecated alias that was removed from
    # `unittest` in Python 3.12; `assertRaisesRegex` is the supported name.
    with self.assertRaisesRegex(ValueError,
                                'Joint distribution with unordered variables '
                                "can't take positional args"):
      d.log_prob(e, a, x)
Esempio n. 12
0
    def test_can_call_namedtuple_log_prob_with_args_and_kwargs(self):
        """A namedtuple-structured model accepts positional and keyword args.

        Every calling convention of ``log_prob`` must agree with passing the
        full ``Model`` value positionally.
        """
        Model = collections.namedtuple('Model', ['e', 'a', 'x'])  # pylint: disable=invalid-name
        d = tfd.JointDistributionNamed(
            Model(e=tfd.Normal(0., 1.),
                  a=tfd.Independent(tfd.Exponential(rate=[100, 120]),
                                    reinterpreted_batch_ndims=1),
                  x=lambda a: tfd.Gamma(concentration=a[..., 0],
                                        rate=a[..., 1])),
            validate_args=True)

        sample = self.evaluate(d.sample([2, 3], seed=test_util.test_seed()))
        e, a, x = sample.e, sample.a, sample.x

        reference = self.evaluate(d.log_prob(Model(e=e, a=a, x=x)))
        alternative_calls = (
            lambda: d.log_prob(value=Model(e=e, a=a, x=x)),  # Named `value`.
            lambda: d.log_prob(e=e, a=a, x=x),  # All keyword args.
            lambda: d.log_prob(e, a, x),  # All positional args.
            lambda: d.log_prob(e, a=a, x=x),  # Positional, then keyword.
        )
        for call in alternative_calls:
            self.assertAllClose(reference, self.evaluate(call()))
Esempio n. 13
0
    def testReprWorksCorrectlyScalar(self):
        """`repr` of a distribution reports its name, shapes and dtype."""
        normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
        self.assertEqual(
            "<tfp.distributions.Normal 'Normal' batch_shape=[] "
            'event_shape=[] dtype=float16>',
            repr(normal))

        chi2 = tfd.Chi2(df=np.float32([1., 2.]), name='silly')
        self.assertEqual(
            "<tfp.distributions.Chi2 'silly' "  # What a silly name that is!
            'batch_shape=[2] event_shape=[] dtype=float32>',
            repr(chi2))

        # Partially known shapes only exist outside eager mode, so the last
        # check runs only in graph mode.
        if not tf.executing_eagerly():
            exp = tfd.Exponential(
                rate=tf1.placeholder_with_default(1., shape=None))
            self.assertEqual(
                "<tfp.distributions.Exponential 'Exponential' "
                'batch_shape=? event_shape=[] dtype=float32>',
                repr(exp))
Esempio n. 14
0
  def test_scalar_distributions(self):
    """BatchConcat along axis 1 of Normal, Logistic and Exponential batches.

    Stores the three component distributions on ``self.dist1..3`` and checks
    the concatenated batch shape and the shape of one sample.
    """

    def as_maybe_static(tensor):
      # Delegates to the fixture helper with the test's static/dynamic mode.
      return self.maybe_static(tensor, self.is_static)

    self.dist1 = tfd.Normal(
        loc=as_maybe_static(tf.zeros(self.batch_dim_1, dtype=self.dtype)),
        scale=as_maybe_static(tf.ones(self.batch_dim_1, dtype=self.dtype)))
    self.dist2 = tfd.Logistic(
        loc=as_maybe_static(tf.zeros(self.batch_dim_2, dtype=self.dtype)),
        scale=as_maybe_static(tf.ones(self.batch_dim_2, dtype=self.dtype)))
    self.dist3 = tfd.Exponential(
        rate=as_maybe_static(tf.ones(self.batch_dim_3, dtype=self.dtype)))

    concat_dist = batch_concat.BatchConcat(
        distributions=[self.dist1, self.dist2, self.dist3],
        axis=1,
        validate_args=False)
    expected_shape = [2, 6, 4]
    self.assertAllEqual(
        self.evaluate(concat_dist.batch_shape_tensor()), expected_shape)

    samples = concat_dist.sample(seed=test_util.test_seed())
    self.assertAllEqual(self.evaluate(tf.shape(samples)), expected_shape)
Esempio n. 15
0
  def test_legacy_dists(self):
    """Sampling a joint with a stateful-only component emits a warning.

    ``StatefulNormal`` overrides ``_sample_n`` with a stateful
    ``tf.random.normal`` call. The first recorded warning must report the
    fallback to stateful sampling and name the component ``"loc"``.
    """

    class StatefulNormal(tfd.Normal):

      def _sample_n(self, n, seed=None):
        return self.loc + self.scale * tf.random.normal(
            tf.concat([[n], self.batch_shape_tensor()], axis=0),
            seed=seed)

    # pylint: disable=bad-whitespace
    d = tfd.JointDistributionNamed(dict(
        e    =          tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
        loc  =          StatefulNormal(loc=0, scale=2.),
        scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
        m    =          tfd.Normal,
        x    =lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12)),
                                   validate_args=True)
    # pylint: enable=bad-whitespace

    with warnings.catch_warnings(record=True) as w:
      # Set the filter inside `catch_warnings` so it is restored on exit
      # instead of leaking 'always' into every later test in the process.
      warnings.simplefilter('always')
      d.sample(seed=test_util.test_seed())
    # Fail with a clear assertion (not an IndexError) if nothing was warned.
    self.assertNotEmpty(w)
    # `assertRegexpMatches` was removed from `unittest` in Python 3.12.
    self.assertRegex(
        str(w[0].message),
        r'Falling back to stateful sampling for.*of type.*StatefulNormal.*'
        r'component name "loc" and `dist.name` "Normal"',
        msg=w)
Esempio n. 16
0
  def test_namedtuple_sample_log_prob(self):
    """Samples and scores a namedtuple-structured JointDistributionNamed.

    Checks graph resolution, the types of the per-component distributions,
    static dtypes, static and tensor-valued batch/event shapes, and that the
    joint log-prob matches the sum of per-component log-probs.
    """
    Model = collections.namedtuple('Model', ['e', 'scale', 'loc', 'm', 'x'])  # pylint: disable=invalid-name
    # pylint: disable=bad-whitespace
    model = Model(
        e    =          tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
        scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
        loc  =          tfd.Normal(loc=0, scale=2.),
        m    =          tfd.Normal,
        x    =lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12))
    # pylint: enable=bad-whitespace
    d = tfd.JointDistributionNamed(model, validate_args=True)

    self.assertEqual(
        (
            ('e', ()),
            ('scale', ('e',)),
            ('loc', ()),
            ('m', ('loc', 'scale')),
            ('x', ('m',)),
        ),
        d.resolve_graph())

    xs = d.sample(seed=test_util.test_seed())
    self.assertLen(xs, 5)
    # We'll verify the shapes work as intended when we plumb these back into the
    # respective log_probs.

    ds, _ = d.sample_distributions(value=xs, seed=test_util.test_seed())
    self.assertLen(ds, 5)
    self.assertIsInstance(ds.e, tfd.Independent)
    self.assertIsInstance(ds.scale, tfd.Gamma)
    self.assertIsInstance(ds.loc, tfd.Normal)
    self.assertIsInstance(ds.m, tfd.Normal)
    self.assertIsInstance(ds.x, tfd.Sample)

    # Static properties.
    self.assertAllEqual(Model(e=tf.float32, scale=tf.float32, loc=tf.float32,
                              m=tf.float32, x=tf.int32),
                        d.dtype)

    batch_shape_tensor_, event_shape_tensor_ = self.evaluate([
        d.batch_shape_tensor(), d.event_shape_tensor()])

    expected_batch_shape = Model(e=[], scale=[], loc=[], m=[], x=[])
    for (expected, actual_tensorshape, actual_shape_tensor_) in zip(
        expected_batch_shape, d.batch_shape, batch_shape_tensor_):
      self.assertAllEqual(expected, actual_tensorshape)
      self.assertAllEqual(expected, actual_shape_tensor_)

    expected_event_shape = Model(e=[2], scale=[], loc=[], m=[], x=[12])
    for (expected, actual_tensorshape, actual_shape_tensor_) in zip(
        expected_event_shape, d.event_shape, event_shape_tensor_):
      self.assertAllEqual(expected, actual_tensorshape)
      self.assertAllEqual(expected, actual_shape_tensor_)

    # Use a distinct loop name so the joint distribution `d` is not shadowed
    # by the per-component distributions.
    expected_jlp = sum(
        component.log_prob(value) for component, value in zip(ds, xs))
    actual_jlp = d.log_prob(xs)
    self.assertAllClose(*self.evaluate([expected_jlp, actual_jlp]),
                        atol=0., rtol=1e-4)
Esempio n. 17
0
def basic_ordered_model_fn():
    """Return an ordered model spec with components 'a', 'e' and 'x'.

    'a' and 'e' are standalone distributions. 'x' is a Gamma whose
    concentration and rate are the two entries of the sampled 'e'.
    """
    model = collections.OrderedDict()
    model['a'] = tfd.Normal(0., 1.)
    model['e'] = tfd.Independent(tfd.Exponential(rate=[100, 120]),
                                 reinterpreted_batch_ndims=1)
    # The argument name must stay `e` so the model wires 'x' to 'e'.
    model['x'] = lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1])
    return model
Esempio n. 18
0
 def testScalarBatchScalarEventIdentityScale(self):
   """log_prob of a scaled Exponential includes the bijector's Jacobian term.

   With Y = 2 * X and X ~ Exponential(0.25):
   log p_Y(1) = log p_X(0.5) - log(2).
   """
   # `tfb.AffineScalar` is deprecated and removed in newer TFP releases;
   # `tfb.Scale` is its scale-only replacement.
   exp2 = self._cls()(
       tfd.Exponential(rate=0.25),
       bijector=tfb.Scale(scale=2.),
       validate_args=True)
   log_prob = exp2.log_prob(1.)
   log_prob_ = self.evaluate(log_prob)
   base_log_prob = -0.5 * 0.25 + np.log(0.25)
   ildj = np.log(2.)
   self.assertAllClose(base_log_prob - ildj, log_prob_, rtol=1e-6, atol=0.)
Esempio n. 19
0
 def test_kl_divergence(self):
   """Joint KL equals the sum of per-component KLs, via both entry points."""

   def make_joint(exp_rate, normal_loc, normal_scale):
     return tfd.JointDistributionNamed(
         dict(e=tfd.Independent(tfd.Exponential(rate=exp_rate), 1),
              x=tfd.Normal(loc=normal_loc, scale=normal_scale)),
         validate_args=True)

   p = make_joint([100, 120], 0, 2.)
   q = make_joint([10, 12], 1, 1.)
   self.assertEqual(p.model.keys(), q.model.keys())
   expected_kl = sum(tfd.kl_divergence(p.model[k], q.model[k])
                     for k in p.model.keys())
   # Both the free function and the method form must agree with the sum.
   expected_value, *actual_values = self.evaluate(
       [expected_kl, tfd.kl_divergence(p, q), p.kl_divergence(q)])
   for actual_value in actual_values:
     self.assertNear(expected_value, actual_value, err=1e-5)
Esempio n. 20
0
    def test_nested_partial_value(self, sample_fn):
        """Partially specified nested values are re-sampled reproducibly.

        Each case replaces one (possibly nested) component of a full sample
        with ``None`` and samples again with the same stateless seed; the
        result must match the original full sample.

        Args:
            sample_fn: Callable invoked as ``sample_fn(dist, value=..., seed=...)``;
                presumably supplied by test parameterization — TODO confirm.
        """
        innermost = tfd.JointDistributionNamed({
            'a':
            tfd.Exponential(1.),
            'b':
            lambda a: tfd.Sample(tfd.LogNormal(a, a), [5]),
        })

        inner = tfd.JointDistributionNamed({
            'c': tfd.Exponential(1.),
            'd': innermost,
        })

        outer = tfd.JointDistributionNamed({
            'e': tfd.Exponential(1.),
            'f': inner,
        })

        seed = test_util.test_seed(sampler_type='stateless')
        true_xs = outer.sample(seed=seed)

        def _update(dict_, **kwargs):
            # Return an updated shallow copy. The previous version updated a
            # throwaway copy and returned `dict_` unchanged, so no component
            # was ever actually replaced by `None`.
            updated = dict_.copy()
            updated.update(**kwargs)
            return updated

        # These asserts work because we advance the stateless seed inside the model
        # whether or not a sample is actually generated.
        partial_xs = _update(true_xs, f=None)
        xs = sample_fn(outer, value=partial_xs, seed=seed)
        self.assertAllCloseNested(true_xs, xs)

        partial_xs = _update(true_xs, e=None)
        xs = sample_fn(outer, value=partial_xs, seed=seed)
        self.assertAllCloseNested(true_xs, xs)

        partial_xs = _update(true_xs, f=_update(true_xs['f'], d=None))
        xs = sample_fn(outer, value=partial_xs, seed=seed)
        self.assertAllCloseNested(true_xs, xs)

        partial_xs = _update(true_xs,
                             f=_update(true_xs['f'],
                                       d=_update(true_xs['f']['d'], a=None)))
        xs = sample_fn(outer, value=partial_xs, seed=seed)
        self.assertAllCloseNested(true_xs, xs)
Esempio n. 21
0
 def testScalarBatchScalarEventIdentityScale(self):
   """log_prob of a scaled Exponential includes the bijector's Jacobian term."""
   base = tfd.Exponential(rate=0.25)
   scaled = tfd.TransformedDistribution(
       base, bijector=tfb.Scale(scale=2.), validate_args=True)
   actual = self.evaluate(scaled.log_prob(1.))
   # Y = 2 * X, so log p_Y(1) = log p_X(0.5) - log(2), where
   # log p_X(0.5) = log(rate) - rate * 0.5.
   expected = (np.log(0.25) - 0.25 * 0.5) - np.log(2.)
   self.assertAllClose(expected, actual, rtol=1e-6, atol=0.)
Esempio n. 22
0
def nested_lists_model_fn():
  """Return an ordered model whose components are nested JDSequentials.

  'abc' is a JointDistributionSequential of an MVNDiag and an inner
  JointDistributionSequential of [StudentT, Exponential]. 'de' builds two
  Normals from a sample of 'abc'.
  """
  abc_dist = tfd.JointDistributionSequential([
      tfd.MultivariateNormalDiag([0., 0.], [1., 1.]),
      tfd.JointDistributionSequential(
          [tfd.StudentT(3., -2., 5.), tfd.Exponential(4.)]),
  ])

  def make_de(abc):
    # The argument name must stay `abc` so the model wires 'de' to 'abc'.
    mvn_draw = abc[0]
    student_t_draw, exponential_draw = abc[1][0], abc[1][1]
    return tfd.JointDistributionSequential([
        tfd.Normal(mvn_draw * student_t_draw, exponential_draw),
        tfd.Normal(mvn_draw + student_t_draw, exponential_draw)])

  return collections.OrderedDict([('abc', abc_dist), ('de', make_de)])
Esempio n. 23
0
 def test_sample_shape_propagation_default_behavior(self):
     """By default each component's sample is prefixed by `sample_shape` once.

     Every component (and the joint log-prob) has leading shape ``(2, 3)``,
     followed by the component's own event shape (``(2,)`` for ``e``).
     """
     d = tfd.JointDistributionNamed(
         dict(
             e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
             scale=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
             s=tfd.HalfNormal(2.5),
             loc=lambda s: tfd.Normal(loc=0, scale=s),
             df=tfd.Exponential(2),
             x=tfd.StudentT),
         validate_args=False)
     x = d.sample([2, 3], seed=test_util.test_seed())
     self.assertLen(x, 6)
     self.assertEqual((2, 3, 2), x['e'].shape)
     for name in ('scale', 's', 'loc', 'df', 'x'):
         self.assertEqual((2, 3), x[name].shape)
     self.assertEqual((2, 3), d.log_prob(x).shape)
Esempio n. 24
0
    def test_legacy_dists_stateless_seed_raises(self):
        """A stateful-only component rejects a stateless seed.

        ``StatefulNormal`` draws with ``tf.random.normal``. Sampling the joint
        with ``samplers.zeros_seed()`` must raise a ``TypeError`` matching
        ``Expected int for argument``.
        """

        class StatefulNormal(tfd.Normal):

            def _sample_n(self, n, seed=None):
                return self.loc + self.scale * tf.random.normal(
                    tf.concat([[n], self.batch_shape_tensor()], axis=0),
                    seed=seed)

        d = tfd.JointDistributionNamed(
            dict(
                e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
                loc=StatefulNormal(loc=0, scale=2.),
                scale=lambda e: tfd.Gamma(concentration=e[..., 0],
                                          rate=e[..., 1]),
                m=tfd.Normal,
                x=lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12)),
            validate_args=True)

        # `assertRaisesRegexp` is a deprecated alias that was removed from
        # `unittest` in Python 3.12; `assertRaisesRegex` is the supported name.
        with self.assertRaisesRegex(TypeError, r'Expected int for argument'):
            d.sample(seed=samplers.zeros_seed())
Esempio n. 25
0
    def test_batch_slicing(self):
        """Batch slicing a joint distribution slices every component."""
        d = tfd.JointDistributionNamed(
            dict(
                s=tfd.Exponential(rate=[10, 12, 14]),
                n=lambda s: tfd.Normal(loc=0, scale=s),
                x=lambda: tfd.Beta(concentration0=[3, 2, 1],
                                   concentration1=1)),
            validate_args=True)

        # Each entry pairs a batch slice with the shape its samples must have.
        slices_and_shapes = [(d[:1], [1]), (d[1:], [2])]
        samples = [sliced.sample(seed=test_util.test_seed())
                   for sliced, _ in slices_and_shapes]

        for sample, (_, expected_shape) in zip(samples, slices_and_shapes):
            self.assertLen(sample, 3)
            for name in ('s', 'n', 'x'):
                self.assertEqual(expected_shape, sample[name].shape)
Esempio n. 26
0
  def test_custom_weights_prior(self):
    """A custom batched weights prior is broadcast to the weights' shape.

    Builds a ``LinearRegression`` with an Exponential weights prior of batch
    shape ``batch_shape``, then checks three things:

    * the prior's event/batch shapes after broadcasting;
    * the log-prob shape of a sample from the induced state space model;
    * that the parameter bijector maps unconstrained draws into the support
      of the prior.

    Every random draw is seeded so the test is deterministic.
    """
    batch_shape = [4, 3]
    num_timesteps = 10
    num_features = 2
    design_matrix = self._build_placeholder(
        np.random.randn(*(batch_shape + [num_timesteps, num_features])))

    # Build a model with scalar Exponential(1.) prior.
    linear_regression = LinearRegression(
        design_matrix=design_matrix,
        weights_prior=tfd.Exponential(
            rate=self._build_placeholder(np.ones(batch_shape))))

    # Check that the prior is broadcast to match the shape of the weights.
    weights = linear_regression.parameters[0]
    self.assertAllEqual([num_features],
                        self.evaluate(weights.prior.event_shape_tensor()))
    self.assertAllEqual(batch_shape,
                        self.evaluate(weights.prior.batch_shape_tensor()))

    prior_sampled_weights = weights.prior.sample(seed=test_util.test_seed())
    ssm = linear_regression.make_state_space_model(
        num_timesteps=num_timesteps,
        param_vals={"weights": prior_sampled_weights})

    lp = ssm.log_prob(ssm.sample(seed=test_util.test_seed()))
    self.assertAllEqual(batch_shape,
                        self.evaluate(lp).shape)

    # Verify that the bijector enforces the prior constraint that
    # weights must be nonnegative.
    unconstrained = tf.random.normal(
        tf.shape(weights.prior.sample(64, seed=test_util.test_seed())),
        seed=test_util.test_seed(),
        dtype=self.dtype)
    self.assertAllFinite(
        self.evaluate(
            weights.prior.log_prob(weights.bijector(unconstrained))))
Esempio n. 27
0
    def test_dict_sample_log_prob(self):
        """Samples and scores a dict-structured JointDistributionNamed.

        Covers graph resolution, component distribution types, static dtypes,
        static and tensor-valued batch/event shapes, and agreement between the
        joint log-prob and the sum of per-component log-probs.
        """
        d = tfd.JointDistributionNamed(
            dict(
                e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
                scale=lambda e: tfd.Gamma(concentration=e[..., 0],
                                          rate=e[..., 1]),
                loc=tfd.Normal(loc=0, scale=2.),
                m=tfd.Normal,
                x=lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12)),
            validate_args=True)

        expected_graph = (
            ('e', ()),
            ('scale', ('e',)),
            ('loc', ()),
            ('m', ('loc', 'scale')),
            ('x', ('m',)),
        )
        self.assertEqual(expected_graph, d.resolve_graph())

        xs = d.sample(seed=test_util.test_seed())
        self.assertLen(xs, 5)
        # Sample shapes are checked implicitly below, when these values are
        # fed back into the per-component log_probs.

        ds, _ = d.sample_distributions(value=xs)
        self.assertLen(ds, 5)
        expected_types = (
            ('e', tfd.Independent),
            ('scale', tfd.Gamma),
            ('loc', tfd.Normal),
            ('m', tfd.Normal),
            ('x', tfd.Sample),
        )
        for name, dist_type in expected_types:
            self.assertIsInstance(ds[name], dist_type)

        # Static properties.
        expected_dtype = dict(e=tf.float32, scale=tf.float32, loc=tf.float32,
                              m=tf.float32, x=tf.int32)
        self.assertAllEqual(expected_dtype, d.dtype)

        batch_shape_tensor_, event_shape_tensor_ = self.evaluate(
            [d.batch_shape_tensor(), d.event_shape_tensor()])

        # Every component is unbatched.
        batch_tensorshape = d.batch_shape
        for name in ('e', 'scale', 'loc', 'm', 'x'):
            self.assertAllEqual([], batch_tensorshape[name])
            self.assertAllEqual([], batch_shape_tensor_[name])

        expected_event_shape = dict(e=[2], scale=[], loc=[], m=[], x=[12])
        event_tensorshape = d.event_shape
        for name, expected in expected_event_shape.items():
            self.assertAllEqual(expected, event_tensorshape[name])
            self.assertAllEqual(expected, event_shape_tensor_[name])

        expected_jlp = sum(ds[name].log_prob(xs[name]) for name in ds)
        actual_jlp = d.log_prob(xs)
        self.assertAllClose(*self.evaluate([expected_jlp, actual_jlp]),
                            atol=0.,
                            rtol=1e-4)
Esempio n. 28
0
                       tfd.Normal(loc=-1., scale=0.1),
                       tfd.Normal(loc=-0.5, scale=0.1),
                       tfd.Normal(loc=0., scale=0.1),
                       tfd.Normal(loc=0.5, scale=0.1),
                       tfd.Normal(loc=1., scale=0.1)
                   ])

# Two broad Normals at -1 and +1 (weight 0.46 each), three narrow Normals
# (scale 0.01, weight 1/300 each) on the left and three wider ones
# (scale 0.07, weight 7/300 each) on the right.
AsymmetricDoubleClaw = tfd.Mixture(
    cat=tfd.Categorical(
        probs=[46. / 100.] * 2 + [1. / 300.] * 3 + [7. / 300.] * 3),
    components=[
        tfd.Normal(loc=loc, scale=scale)
        for loc, scale in [
            (-1., 2. / 3.),
            (1., 2. / 3.),
            (-1. / 2., 0.01),
            (-1., 0.01),
            (-3. / 2., 0.01),
            (1. / 2., 0.07),
            (1., 0.07),
            (3. / 2., 0.07),
        ]
    ])

# Mixture of three Normals, one Exponential and one Uniform.
Mix3gauss1exp1uni = tfd.Mixture(
    cat=tfd.Categorical(probs=[0.1, 0.2, 0.1, 0.4, 0.2]),
    components=[
        *(tfd.Normal(loc=loc, scale=scale)
          for loc, scale in ((-1., 0.4), (1., 0.5), (1., 0.3))),
        tfd.Exponential(rate=2),
        tfd.Uniform(low=-5, high=5),
    ])
Esempio n. 29
0
 def test_slicing_does_not_modify_the_sliced_distribution(self):
   """Slicing an already-sliced distribution uses the slice's own shape."""
   dist = tfd.Exponential(tf.ones((5, 2, 3)))
   sliced = dist[:4, :, 2]  # Batch shape [4, 2].
   last_row = sliced[-1]
   self.assertAllEqual([2], last_row.batch_shape_tensor())
   leading_rows_column_one = sliced[:-1, 1]
   self.assertAllEqual([3], leading_rows_column_one.batch_shape_tensor())
Esempio n. 30
0
class MarkovChainBijectorTest(test_util.TestCase):

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        dict(testcase_name='deterministic_prior',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.)),
        dict(testcase_name='deterministic_transition',
             prior_fn=lambda: tfd.Normal(loc=[-100., 0., 100.], scale=1.),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='fully_deterministic',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='mvn_diag',
             prior_fn=lambda: tfd.MultivariateNormalDiag(
                 loc=[[2.], [2.]], scale_diag=[1.]),
             transition_fn=lambda _, x: tfd.VectorDeterministic(x)),
        dict(testcase_name='docstring_dirichlet',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 {'probs': tfd.Dirichlet([1., 1.])}),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 {'probs': tfd.MultivariateNormalDiag(loc=x['probs'],
                                                      scale_diag=[0.1, 0.1])},
                 batch_ndims=ps.rank(x['probs']))),
        dict(testcase_name='uniform_step',
             prior_fn=lambda: tfd.Exponential(tf.ones([4, 1])),
             transition_fn=lambda _, x: tfd.Uniform(low=x, high=x + 1.)),
        dict(testcase_name='joint_distribution',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=2,
                 model={
                     'a': tfd.Gamma(tf.zeros([5]), 1.),
                     'b': lambda a: tfb.Reshape(
                         event_shape_in=[4, 3], event_shape_out=[2, 3, 2])(
                             tfd.Independent(
                                 tfd.Normal(
                                     loc=tf.zeros([5, 4, 3]),
                                     scale=a[..., tf.newaxis, tf.newaxis]),
                                 reinterpreted_batch_ndims=2)),
                 }),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=ps.rank_from_shape(x['a'].shape),
                 model={
                     'a': tfd.Normal(loc=x['a'], scale=1.),
                     'b': lambda a: tfd.Deterministic(
                         x['b'] + a[..., tf.newaxis, tf.newaxis, tf.newaxis]),
                 })),
        dict(testcase_name='nested_chain',
             prior_fn=lambda: tfd.MarkovChain(
                 initial_state_prior=tfb.Split(2)(
                     tfd.MultivariateNormalDiag(0., [1., 2.])),
                 transition_fn=lambda _, x: tfb.Split(2)(
                     tfd.MultivariateNormalDiag(x[0], [1., 2.])),
                 num_steps=6),
             transition_fn=(
                 lambda _, x: tfd.JointDistributionSequentialAutoBatched(
                     [tfd.MultivariateNormalDiag(x[0], [1.]),
                      tfd.MultivariateNormalDiag(x[1], [1.])],
                     batch_ndims=ps.rank(x[0])))))
    # pylint: enable=g-long-lambda
    def test_default_bijector(self, prior_fn, transition_fn):
        """Exercises the default event-space bijector of a 7-step MarkovChain.

        Verifies that the bijector:
          * has the same batch shape as the chain;
          * round-trips a chain sample (forward(inverse(y)) == y);
          * reports inverse_min_event_ndims equal to the chain's event rank;
          * has ILDJ == -FLDJ, with ILDJ shaped like the chain's batch shape
            whenever the Jacobian is not constant;
          * maps event shapes and event-shape tensors consistently in both
            directions, matching the shapes of the unconstrained values.
        """

        def _uncached(structure):
            # Copy every leaf through `tf.identity` so the bijector receives
            # new objects and cannot answer from its cache.
            return tf.nest.map_structure(tf.identity, structure)

        def _trailing_shape(t, nd):
            # The last `nd` dimensions of `t`'s shape.
            return t.shape[ps.rank(t) - nd:]

        chain = tfd.MarkovChain(initial_state_prior=prior_fn(),
                                transition_fn=transition_fn,
                                num_steps=7)
        sample = self.evaluate(chain.sample(seed=test_util.test_seed()))
        bij = chain.experimental_default_event_space_bijector()
        self.assertAllEqual(chain.batch_shape_tensor(),
                            bij.experimental_batch_shape_tensor())

        # Round trip: constrained sample -> unconstrained -> constrained.
        unconstrained = bij.inverse(sample)
        self.assertAllCloseNested(sample,
                                  bij.forward(_uncached(unconstrained)))

        event_ndims = tf.nest.map_structure(ps.rank_from_shape,
                                            chain.event_shape_tensor())
        self.assertAllEqualNested(bij.inverse_min_event_ndims, event_ndims)

        # The inverse and forward log-det-Jacobians must be negatives.
        ildj = bij.inverse_log_det_jacobian(_uncached(sample),
                                            event_ndims=event_ndims)
        if not bij.is_constant_jacobian:
            self.assertAllEqual(ildj.shape, chain.batch_shape)
        fldj = bij.forward_log_det_jacobian(
            _uncached(unconstrained),
            event_ndims=bij.inverse_event_ndims(event_ndims))
        self.assertAllClose(ildj, -fldj)

        # Static event shapes must be flattened/unflattened correctly.
        inv_shapes = bij.inverse_event_shape(chain.event_shape)
        unconstrained_shapes = tf.nest.map_structure(
            _trailing_shape, unconstrained, bij.forward_min_event_ndims)
        self.assertAllEqualNested(inv_shapes, unconstrained_shapes)
        self.assertAllEqualNested(bij.forward_event_shape(inv_shapes),
                                  chain.event_shape)

        # The dynamic (tensor) variants must produce the same structures.
        inv_shape_tensors = bij.inverse_event_shape_tensor(
            chain.event_shape_tensor())
        self.assertAllEqualNested(inv_shape_tensors, unconstrained_shapes)
        self.assertAllEqualNested(
            bij.forward_event_shape_tensor(inv_shape_tensors),
            chain.event_shape_tensor())