예제 #1
0
    def testKLVersusNormal(self):
        """Checks KL agrees across LogStddevNormal and Normal parameterizations.

        The same Gaussian (mean 2, stddev 2) is expressed both as a
        `LogStddevNormal` (constructed from the log of the stddev) and as a
        `Normal`. A zero-mean prior is expressed in both forms as well. The KL
        divergences for the LogStddevNormal/LogStddevNormal,
        Normal/LogStddevNormal and Normal/Normal pairs must all match.
        """
        mean = jnp.array([2.0])
        stddev = jnp.array([2.0])
        log_stddev = jnp.log(stddev)

        # The LogStddevNormal prior takes a log-stddev of 0, i.e. stddev 1, so
        # it matches the Normal prior with unit scale.
        prior_lsn = lsn.LogStddevNormal(jnp.array([0.0]), jnp.array([0.0]))
        prior_normal = normal.Normal(jnp.array([0.0]), jnp.array([1.0]))
        posterior_lsn = lsn.LogStddevNormal(mean, log_stddev)
        posterior_normal = normal.Normal(mean, stddev)

        kl_divergence = tfp.distributions.kl_divergence
        kl_lsn_lsn = kl_divergence(posterior_lsn, prior_lsn)
        kl_normal_lsn = kl_divergence(posterior_normal, prior_lsn)
        kl_normal_normal = kl_divergence(posterior_normal, prior_normal)

        # Each pair is (actual, desired). assert_allclose measures its relative
        # tolerance against `desired`, so the argument order is significant.
        comparisons = (
            (kl_normal_lsn, kl_lsn_lsn),
            (kl_normal_normal, kl_normal_lsn),
            (kl_lsn_lsn, kl_normal_normal),
        )
        for actual, desired in comparisons:
            np.testing.assert_allclose(actual, desired)
예제 #2
0
  def test_jittable(self):
    """Checks that a `Transformed` distribution can be passed through `jax.jit`."""
    def log_prob_fn(point, distribution):
      # The distribution object is a traced argument of the jitted function.
      return distribution.log_prob(point)

    jitted_log_prob = jax.jit(log_prob_fn)

    base_dist = normal.Normal(0, 1)
    affine = scalar_affine.ScalarAffine(0, 1)
    transformed_dist = transformed.Transformed(base_dist, affine)

    scalar_point = np.zeros(())
    jitted_log_prob(scalar_point, transformed_dist)
예제 #3
0
 def test_cdf(self, function_string, distr_params, value):
   """Compares `cdf`/`log_cdf` against per-dimension univariate Normals.

   Args:
     function_string: Name of the method under test; must be `'cdf'` or
       `'log_cdf'`.
     distr_params: Constructor kwargs for `self.distrax_cls`. Must contain
       `'scale_diag'`; may contain `'loc'` (treated as 0 when absent).
     value: Point at which the method is evaluated.

   Raises:
     ValueError: If `function_string` is not a supported method name.
   """
   # Per-dimension results are combined over the last axis: probabilities
   # multiply, log-probabilities add.
   reduce_fns = {
       'cdf': lambda x: jnp.prod(x, axis=-1),
       'log_cdf': lambda x: jnp.sum(x, axis=-1),
   }
   # Reject unsupported names up front with a clear message. Otherwise the
   # reduction lookup would fail with an unrelated error.
   if function_string not in reduce_fns:
     raise ValueError(
         f'Unsupported function_string {function_string!r}; '
         f'expected one of {sorted(reduce_fns)}.')

   distr_params = {
       k: np.asarray(v, dtype=np.float32) for k, v in distr_params.items()}
   value = np.asarray(value)
   dist = self.distrax_cls(**distr_params)
   result = self.variant(getattr(dist, function_string))(value)

   # The `cdf` is not implemented in TFP, so we test against a `Normal`.
   loc = distr_params.get('loc', 0.)
   univariate_normal = normal.Normal(loc, distr_params['scale_diag'])
   per_dim_result = getattr(univariate_normal, function_string)(value)
   expected_result = reduce_fns[function_string](per_dim_result)
   self.assertion_fn(result, expected_result)
예제 #4
0
 def test_constructor_is_jittable_given_ndims(self, ndims):
   """Checks that constructing an `Independent` with `ndims` works under jit."""
   shape = (2, 3)
   base_dist = normal.Normal(loc=jnp.zeros(shape), scale=jnp.ones(shape))

   def make_independent(distribution):
     # `ndims` is closed over, so it stays a static Python value under jit.
     return independent.Independent(distribution, ndims)

   jax.jit(make_independent)(base_dist)
예제 #5
0
 def _make_distrax_base_distribution(self, loc, scale):
   """Builds the distrax `Normal` used as the base distribution.

   Args:
     loc: Location parameter forwarded to `normal.Normal`.
     scale: Scale parameter forwarded to `normal.Normal`.

   Returns:
     A `normal.Normal` constructed from `loc` and `scale`.
   """
   base_distribution = normal.Normal(loc=loc, scale=scale)
   return base_distribution