Example #1
0
    def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
        """Checks `sample_and_log_prob` under an inverted `Scale(2)` bijector.

        Three things are verified:
        - Samples equal the bijector's `forward` applied to base samples
          drawn with the same seed and sample shape.
        - Log-probs equal those of the equivalent TFP
          `TransformedDistribution`, evaluated at the same samples.
        - The sample shape agrees with TFP's samples.
        """
        base = base_dist(mu, sigma)
        inverted_scale = inverse.Inverse(tfb.Scale(2))
        dist = transformed.Transformed(base, inverted_scale)

        def draw(rng, shape):
            return dist.sample_and_log_prob(seed=rng, sample_shape=shape)

        # Argument 1 (the sample shape) is passed to the chex variant as both
        # static and ignored, so only the seed is treated as a traced input.
        variant_draw = self.variant(draw,
                                    ignore_argnums=(1,),
                                    static_argnums=(1,))
        samples, log_prob = variant_draw(self.seed, sample_shape)

        base_samples = base.sample(seed=self.seed, sample_shape=sample_shape)
        expected_samples = inverted_scale.forward(base_samples)

        reference = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfb.Invert(tfb.Scale(2)))
        reference_samples = reference.sample(seed=self.seed,
                                             sample_shape=sample_shape)
        reference_log_prob = reference.log_prob(samples)

        chex.assert_equal_shape([samples, reference_samples])
        np.testing.assert_allclose(log_prob, reference_log_prob, rtol=RTOL)
        np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)
Example #2
0
    def test_jittable(self):
        """Checks that an `Inverse` bijector can be passed as an argument to a
        `jax.jit`-compiled function and used to call `forward`.

        The test passes if compilation and execution complete without raising.
        """

        def apply_forward(value, bij):
            return bij.forward(value)

        jitted_forward = jax.jit(apply_forward)
        wrapped = inverse.Inverse(scalar_affine.ScalarAffine(0, 1))
        jitted_forward(np.zeros(()), wrapped)
Example #3
0
    def test_integer_inputs(self, inputs):
        """Checks `forward_and_log_det` of `Inverse(ScalarAffine(shift=1.0))`
        on (presumably integer) `inputs`.

        The expected output is `inputs - 1.0` as float32. The expected
        log-determinant is all zeros with the same shape, also as float32.
        """
        wrapped = inverse.Inverse(scalar_affine.ScalarAffine(shift=1.0))
        fwd_and_ldj = self.variant(wrapped.forward_and_log_det)
        output, log_det = fwd_and_ldj(inputs)

        # Expectations are built in float32 explicitly, so the comparison
        # also checks that integer inputs were promoted to floats.
        want_output = jnp.array(inputs, dtype=jnp.float32) - 1.0
        want_log_det = jnp.zeros_like(inputs, dtype=jnp.float32)

        np.testing.assert_array_equal(output, want_output)
        np.testing.assert_array_equal(log_det, want_log_det)
Example #4
0
    def test_event_shape(self, mu, sigma, base_dist):
        """Checks that `event_shape` matches the equivalent TFP
        `TransformedDistribution` built with `tfb.Invert(tfb.Scale(2))`.
        """
        base = base_dist(mu, sigma)
        dist = transformed.Transformed(base, inverse.Inverse(tfb.Scale(2)))
        reference = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfb.Invert(tfb.Scale(2)))

        assert dist.event_shape == reference.event_shape
Example #5
0
    def test_prob(self, mu, sigma, value, base_dist):
        """Checks `prob(value)` under an inverted `Scale(2)` bijector.

        The result is compared against the equivalent TFP
        `TransformedDistribution`, with an absolute tolerance of 1e-9.
        """
        base = base_dist(mu, sigma)
        dist = transformed.Transformed(base, inverse.Inverse(tfb.Scale(2)))
        got = self.variant(dist.prob)(value)

        reference = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfb.Invert(tfb.Scale(2)))
        want = reference.prob(value)

        np.testing.assert_allclose(got, want, atol=1e-9)
Example #6
0
    def test_method(self, function_string, mu, sigma, base_dist):
        """Checks a zero-argument method, looked up by name, against TFP.

        The method named by `function_string` is called on both the
        transformed distribution and the equivalent TFP
        `TransformedDistribution` (inverted `Scale(2)` in both cases), and the
        results are compared.
        """
        base = base_dist(mu, sigma)
        dist = transformed.Transformed(base, inverse.Inverse(tfb.Scale(2)))
        reference = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfb.Invert(tfb.Scale(2)))

        got = self.variant(getattr(dist, function_string))()
        want = getattr(reference, function_string)()
        np.testing.assert_allclose(got, want)
Example #7
0
    def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
        """Checks that samples have the same shape as those from TFP.

        Both the distribution under test and the equivalent TFP
        `TransformedDistribution` use an inverted `Scale(2)` bijector.
        """
        base = base_dist(mu, sigma)
        dist = transformed.Transformed(base, inverse.Inverse(tfb.Scale(2)))

        def draw(rng, shape):
            return dist.sample(seed=rng, sample_shape=shape)

        # Argument 1 (the sample shape) is passed to the chex variant as both
        # static and ignored, matching how the seed-only tracing is set up.
        variant_draw = self.variant(draw,
                                    ignore_argnums=(1,),
                                    static_argnums=(1,))
        samples = variant_draw(self.seed, sample_shape)

        reference = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfb.Invert(tfb.Scale(2)))
        reference_samples = reference.sample(sample_shape=sample_shape,
                                             seed=self.seed)

        chex.assert_equal_shape([samples, reference_samples])
Example #8
0
 def test_properties(self):
     """Checks that the `bijector` attribute of an `Inverse` is a
     `base_bijector.Bijector` instance, even though it wraps a TFP
     `tfb.Scale`.
     """
     wrapped = inverse.Inverse(tfb.Scale(2)).bijector
     assert isinstance(wrapped, base_bijector.Bijector)