def test_dtype_is_consistent_with_tfp(self, dist_fn, bijector_fn):
    """Checks that `Transformed` reports the same dtype as TFP's equivalent.

    Args:
      dist_fn: zero-argument callable that builds the base distribution.
      bijector_fn: zero-argument callable that builds the bijector.
    """
    base = dist_fn()
    bijector = bijector_fn()
    dist = transformed.Transformed(base, bijector)
    # Convert the bijector as well as the base before handing them to TFP:
    # `to_tfp` returns TFP objects unchanged, and converts Distrax ones for
    # use with TFP, so this works whichever kind `bijector_fn` produces.
    tfp_dist = tfd.TransformedDistribution(conversion.to_tfp(base),
                                           conversion.to_tfp(bijector))
    assert dist.dtype == tfp_dist.dtype
Exemple #2
0
 def test_on_distrax_bijector(self):
     """`to_tfp` on a Distrax `Tanh` still yields a `Tanh` with equal forward."""
     distrax_tanh = Tanh()
     converted = conversion.to_tfp(distrax_tanh)
     assert isinstance(converted, Tanh)
     # The converted object's `forward` must agree with the original's at 0.
     origin = np.zeros(())
     np.testing.assert_equal(converted.forward(origin),
                             distrax_tanh.forward(origin))
Exemple #3
0
    def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
        """Joint sample/log-prob through an inverted `Scale(2)` matches TFP.

        Samples are checked against the bijector applied to base samples
        drawn with the same seed; log-probs are checked against TFP's
        `Invert(Scale(2))` evaluated at the Distrax samples.
        """
        base = base_dist(mu, sigma)
        bijector = inverse.Inverse(tfb.Scale(2))
        dist = transformed.Transformed(base, bijector)

        def _draw(key, shape):
            return dist.sample_and_log_prob(seed=key, sample_shape=shape)

        draw_fn = self.variant(_draw,
                               ignore_argnums=(1, ),
                               static_argnums=(1, ))
        samples, log_prob = draw_fn(self.seed, sample_shape)

        # Reference 1: push same-seed base samples through the bijector.
        base_samples = base.sample(seed=self.seed, sample_shape=sample_shape)
        expected_samples = bijector.forward(base_samples)

        # Reference 2: the equivalent pure-TFP construction.
        reference_bijector = tfb.Invert(tfb.Scale(2))
        reference = tfd.TransformedDistribution(conversion.to_tfp(base),
                                                reference_bijector)
        reference_samples = reference.sample(seed=self.seed,
                                             sample_shape=sample_shape)
        reference_log_prob = reference.log_prob(samples)

        chex.assert_equal_shape([samples, reference_samples])
        np.testing.assert_allclose(log_prob, reference_log_prob, rtol=RTOL)
        np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)
Exemple #4
0
  def test_event_shape(self, mu, sigma, base_dist):
    """Event shape under a `tfb.Scale(2)` bijector matches TFP's."""
    base = base_dist(mu, sigma)
    scale = tfb.Scale(2)
    distrax_shape = transformed.Transformed(base, scale).event_shape
    reference_shape = tfd.TransformedDistribution(
        conversion.to_tfp(base), scale).event_shape
    assert distrax_shape == reference_shape
Exemple #5
0
  def test_method(self, function_string, mu, sigma, base_dist):
    """A named zero-argument method agrees with TFP under `tfb.Scale(2)`.

    `function_string` is the attribute name looked up on both the Distrax
    and the TFP distribution; the Distrax one is run through `self.variant`.
    """
    base = base_dist(mu, sigma)
    scale = tfb.Scale(2)
    dist = transformed.Transformed(base, scale)
    reference = tfd.TransformedDistribution(conversion.to_tfp(base), scale)

    actual = self.variant(getattr(dist, function_string))()
    expected = getattr(reference, function_string)()
    np.testing.assert_allclose(actual, expected)
Exemple #6
0
  def test_prob(self, mu, sigma, value, base_dist):
    """Density under a `tfb.Scale(2)` bijector matches TFP's at `value`."""
    base = base_dist(mu, sigma)
    bijector = tfb.Scale(2)
    dist = transformed.Transformed(base, bijector)
    actual = self.variant(dist.prob)(value)

    tfp_dist = tfd.TransformedDistribution(conversion.to_tfp(base), bijector)
    expected = tfp_dist.prob(value)
    # `self.variant` may compile `dist.prob` while `expected` is computed
    # eagerly, so compare within a tolerance (like the sibling tests) rather
    # than requiring bit-for-bit equality.
    np.testing.assert_allclose(actual, expected, rtol=RTOL)
    def test_event_shape(self, mu, sigma, base_dist):
        """Event shape with `Lambda(jnp.tanh)` matches TFP's `tfb.Tanh()`."""
        base = base_dist(mu, sigma)
        tanh_lambda = lambda_bijector.Lambda(jnp.tanh)
        distrax_shape = transformed.Transformed(base, tanh_lambda).event_shape

        reference_bijector = tfb.Tanh()
        reference = tfd.TransformedDistribution(conversion.to_tfp(base),
                                                reference_bijector)

        assert distrax_shape == reference.event_shape
Exemple #8
0
    def test_event_shape(self, mu, sigma, base_dist):
        """Event shape through a Scale/Shift/Tanh chain matches TFP's chain."""

        def make_layers():
            # Fresh bijector instances for each of the two chains.
            return [tfb.Scale(2), tfb.Shift(3), tfb.Tanh()]

        base = base_dist(mu, sigma)
        dist = transformed.Transformed(base, chain.Chain(make_layers()))
        reference = tfd.TransformedDistribution(conversion.to_tfp(base),
                                                tfb.Chain(make_layers()))

        assert dist.event_shape == reference.event_shape
Exemple #9
0
  def test_prob(self, mu, sigma, value, base_dist):
    """Density through Distrax `tanh.Tanh` matches TFP's `tfb.Tanh` at `value`."""
    base = base_dist(mu, sigma)
    distrax_prob = transformed.Transformed(base, tanh.Tanh()).prob
    actual = self.variant(distrax_prob)(value)

    reference = tfd.TransformedDistribution(conversion.to_tfp(base),
                                            tfb.Tanh())
    np.testing.assert_allclose(actual, reference.prob(value), atol=1e-9)
    def test_method(self, function_string, mu, sigma, base_dist):
        """A named zero-arg method agrees between `x + 3` and `tfb.Shift(3)`.

        `function_string` names the method looked up on both distributions.
        """
        base = base_dist(mu, sigma)
        shift_by_three = lambda_bijector.Lambda(lambda x: x + 3)
        dist = transformed.Transformed(base, shift_by_three)

        reference = tfd.TransformedDistribution(conversion.to_tfp(base),
                                                tfb.Shift(3))

        actual = self.variant(getattr(dist, function_string))()
        expected = getattr(reference, function_string)()
        np.testing.assert_allclose(actual, expected, rtol=RTOL)
Exemple #11
0
  def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
    """Samples under a `tfb.Scale(2)` bijector have TFP's sample shape."""
    base = base_dist(mu, sigma)
    scale = tfb.Scale(2)
    dist = transformed.Transformed(base, scale)

    def draw(key, shape):
      return dist.sample(seed=key, sample_shape=shape)

    draw_fn = self.variant(draw, ignore_argnums=(1,), static_argnums=1)
    samples = draw_fn(self.seed, sample_shape)

    reference = tfd.TransformedDistribution(conversion.to_tfp(base), scale)
    reference_samples = reference.sample(sample_shape=sample_shape,
                                         seed=self.seed)

    chex.assert_equal_shape([samples, reference_samples])
    def test_auto_lambda(self, forward_fn, tfp_bijector, base_dist):
        """Passing `forward_fn` directly as the bijector matches TFP.

        `forward_fn` is given to `Transformed` in place of a bijector object
        (presumably auto-wrapped as a `Lambda` — confirm in `Transformed`);
        `tfp_bijector` is a zero-argument factory for the TFP reference.
        Both `log_prob` and `prob` are compared at three fixed points.
        """
        base = base_dist(np.array([-1.0, 0.0, 1.0], dtype=np.float32),
                         np.array([0.5, 1.0, 2.5], dtype=np.float32))

        dist = transformed.Transformed(base, forward_fn)
        reference = tfd.TransformedDistribution(conversion.to_tfp(base),
                                                tfp_bijector())

        points = np.array([0.05, 0.3, 0.95], dtype=np.float32)

        for method_name in ('log_prob', 'prob'):
            expected = getattr(reference, method_name)(points)
            actual = self.variant(getattr(dist, method_name))(points)
            np.testing.assert_allclose(actual, expected, rtol=RTOL)
Exemple #13
0
  def test_batched_bijector_against_tfp(
      self, bijector_fn, block_ndims, bijector_shape, params_shape):
    """Compares a `block.Block`-wrapped (batched) bijector against TFP.

    Args:
      bijector_fn: callable that builds the TFP bijector from an array of
        ones with shape `bijector_shape`.
      block_ndims: number of dimensions passed to `block.Block`.
      bijector_shape: shape of the parameter array given to `bijector_fn`.
      params_shape: shape of both parameter arrays of the base
        `tfd.MultivariateNormalDiag` (zeros and ones respectively).
    """
    base = tfd.MultivariateNormalDiag(
        jnp.zeros(params_shape), jnp.ones(params_shape))

    tfp_bijector = bijector_fn(jnp.ones(bijector_shape))
    dx_bijector = block.Block(tfp_bijector, block_ndims)

    dx_dist = transformed.Transformed(base, dx_bijector)
    tfp_dist = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfp_bijector)

    # These values are reused by several subtests below. Computing them
    # outside any subTest means a failure here stops the test with the real
    # error, instead of each later subtest failing with a NameError.
    dx_sample = self.variant(dx_dist.sample)(seed=self.seed)
    tfp_sample = self.variant(tfp_dist.sample)(seed=self.seed)
    dx_logp_dx = self.variant(dx_dist.log_prob)(dx_sample)

    with self.subTest('event_shape property matches TFP'):
      np.testing.assert_equal(dx_dist.event_shape, tfp_dist.event_shape)

    with self.subTest('sample shape matches TFP'):
      chex.assert_equal_shape([dx_sample, tfp_sample])

    with self.subTest('log_prob(dx_sample) matches TFP'):
      tfp_logp_dx = self.variant(tfp_dist.log_prob)(dx_sample)
      np.testing.assert_allclose(dx_logp_dx, tfp_logp_dx, rtol=RTOL)

    with self.subTest('log_prob(tfp_sample) matches TFP'):
      dx_logp_tfp = self.variant(dx_dist.log_prob)(tfp_sample)
      tfp_logp_tfp = self.variant(tfp_dist.log_prob)(tfp_sample)
      np.testing.assert_allclose(dx_logp_tfp, tfp_logp_tfp, rtol=RTOL)

    with self.subTest('sample/lp shape is self-consistent'):
      second_sample, log_prob = self.variant(dx_dist.sample_and_log_prob)(
          seed=self.seed)
      chex.assert_equal_shape([dx_sample, second_sample])
      chex.assert_equal_shape([dx_logp_dx, log_prob])
Exemple #14
0
 def test_on_tfp_bijector(self):
     """`to_tfp` hands a TFP bijector back as the very same object."""
     exp_bijector = tfb.Exp()
     self.assertIs(conversion.to_tfp(exp_bijector), exp_bijector)
Exemple #15
0
 def test_on_tfp_distribution(self):
     """`to_tfp` hands a TFP distribution back as the very same object."""
     standard_normal = tfd.Normal(0., 1.)
     self.assertIs(conversion.to_tfp(standard_normal), standard_normal)
Exemple #16
0
 def test_on_distrax_distribution(self):
     """`to_tfp` on a Distrax `Normal` keeps it a `Normal` exposing `loc`."""
     normal = Normal(loc=0., scale=1.)
     converted = conversion.to_tfp(normal)
     assert isinstance(converted, Normal)
     # The converted object must still expose the original `loc` value.
     np.testing.assert_almost_equal(converted.loc, 0.)