Example #1
0
    def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
        """Check `sample_and_log_prob` of a chain-transformed distribution.

        Log-probs are compared with `tfd.TransformedDistribution.log_prob`
        evaluated at the drawn samples; samples are compared with the base
        samples (same seed) pushed through the chain's `forward`.
        """

        def make_layers():
            # Build new bijector objects on each call so the two chains
            # below are constructed from separate instances.
            return [tfb.Scale(10), tfb.Tanh(), tfb.Scale(0.1)]

        base = base_dist(mu, sigma)
        bijector = chain.Chain(make_layers())
        dist = transformed.Transformed(base, bijector)

        def draw(seed, sample_shape):
            return dist.sample_and_log_prob(
                seed=seed, sample_shape=sample_shape)

        # Positional argument 1 (`sample_shape`) is listed in both
        # `ignore_argnums` and `static_argnums` for the variant wrapper.
        wrapped_draw = self.variant(
            draw, ignore_argnums=(1, ), static_argnums=(1, ))
        samples, log_prob = wrapped_draw(self.seed, sample_shape)

        base_samples = base.sample(seed=self.seed, sample_shape=sample_shape)
        expected_samples = bijector.forward(base_samples)

        tfp_bijector = tfb.Chain(make_layers())
        tfp_base = conversion.to_tfp(base)
        tfp_dist = tfd.TransformedDistribution(tfp_base, tfp_bijector)
        tfp_samples = tfp_dist.sample(
            sample_shape=sample_shape, seed=self.seed)
        tfp_log_prob = tfp_dist.log_prob(samples)

        # Only the shape of the reference samples is compared.
        chex.assert_equal_shape([samples, tfp_samples])
        np.testing.assert_allclose(log_prob, tfp_log_prob, rtol=RTOL)
        np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)
Example #2
0
    def test_jittable(self):
        """Smoke test: a `Chain` can be passed into a `jax.jit`-ed function.

        Nothing is asserted about the result; the call only has to complete.
        """

        def apply_forward(point, bij):
            return bij.forward(point)

        jitted_forward = jax.jit(apply_forward)

        single_affine = chain.Chain([scalar_affine.ScalarAffine(0, 1)])
        scalar_zero = np.zeros(())
        jitted_forward(scalar_zero, single_affine)
Example #3
0
    def test_integer_inputs(self, inputs):
        """Check `forward_and_log_det` of a shift-by-one chain on `inputs`.

        The output is compared with `inputs` cast to float32 plus 1.0, and
        the log-determinant with float32 zeros shaped like `inputs`.
        """
        shift_chain = chain.Chain([scalar_affine.ScalarAffine(shift=1.0)])
        forward_fn = self.variant(shift_chain.forward_and_log_det)
        output, log_det = forward_fn(inputs)

        expected_out = jnp.array(inputs, dtype=jnp.float32) + 1.0
        expected_log_det = jnp.zeros_like(inputs, dtype=jnp.float32)

        # Both results must match exactly, output first and then log-det.
        checks = ((output, expected_out), (log_det, expected_log_det))
        for got, want in checks:
            np.testing.assert_array_equal(got, want)
Example #4
0
    def test_event_shape(self, mu, sigma, base_dist):
        """Check that `event_shape` matches `tfd.TransformedDistribution`.

        Both distributions wrap the same base and a Scale/Shift/Tanh chain.
        """

        def make_layers():
            # New bijector objects per call; the two chains do not share
            # instances.
            return [tfb.Scale(2), tfb.Shift(3), tfb.Tanh()]

        base = base_dist(mu, sigma)
        dist = transformed.Transformed(base, chain.Chain(make_layers()))

        tfp_bijector = tfb.Chain(make_layers())
        tfp_base = conversion.to_tfp(base)
        tfp_dist = tfd.TransformedDistribution(tfp_base, tfp_bijector)

        assert dist.event_shape == tfp_dist.event_shape
Example #5
0
    def test_prob(self, mu, sigma, value, base_dist):
        """Check `prob(value)` against `tfd.TransformedDistribution.prob`.

        The comparison is exact (`assert_array_equal`), not a tolerance
        check.
        """

        def make_layers():
            # New bijector objects per call; the two chains do not share
            # instances.
            return [tfb.Scale(2), tfb.Shift(3), tfb.Tanh()]

        base = base_dist(mu, sigma)
        dist = transformed.Transformed(base, chain.Chain(make_layers()))
        prob_fn = self.variant(dist.prob)
        actual = prob_fn(value)

        tfp_bijector = tfb.Chain(make_layers())
        tfp_base = conversion.to_tfp(base)
        tfp_dist = tfd.TransformedDistribution(tfp_base, tfp_bijector)
        expected = tfp_dist.prob(value)

        np.testing.assert_array_equal(actual, expected)
Example #6
0
    def test_method(self, function_string, mu, sigma, base_dist):
        """Check a zero-argument method against `tfd.TransformedDistribution`.

        `function_string` names the method looked up (via `getattr`) on both
        distributions; the two results are compared with `assert_allclose`
        at its default tolerances.
        """

        def make_layers():
            # New bijector objects per call; the two chains do not share
            # instances.
            return [tfb.Identity(), tfb.Scale(2), tfb.Shift(3)]

        base = base_dist(mu, sigma)
        dist = transformed.Transformed(base, chain.Chain(make_layers()))

        tfp_bijector = tfb.Chain(make_layers())
        tfp_base = conversion.to_tfp(base)
        tfp_dist = tfd.TransformedDistribution(tfp_base, tfp_bijector)

        method_under_test = self.variant(getattr(dist, function_string))
        actual = method_under_test()
        expected = getattr(tfp_dist, function_string)()
        np.testing.assert_allclose(actual, expected)
Example #7
0
    def test_batched_bijectors(self, outer_shape, inner_shape, input_shape):
        """Check `forward` of a chain of two batched `tfb.Shift` bijectors.

        The same two `Shift` instances are wrapped by both `chain.Chain` and
        `tfb.Chain`; the forward outputs on a zero input of `input_shape`
        must agree in shape and in value (within `RTOL`).
        """
        outer_shift = tfb.Shift(5. * jnp.ones(outer_shape))
        inner_shift = tfb.Shift(7. * jnp.ones(inner_shape))

        chain_under_test = chain.Chain([outer_shift, inner_shift])
        reference_chain = tfb.Chain([outer_shift, inner_shift])

        zeros_in = jnp.zeros(input_shape)
        forward_under_test = self.variant(chain_under_test.forward)
        forward_reference = self.variant(reference_chain.forward)
        actual = forward_under_test(zeros_in)
        expected = forward_reference(zeros_in)

        chex.assert_equal_shape([actual, expected])
        np.testing.assert_allclose(actual, expected, rtol=RTOL)
Example #8
0
    def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
        """Check that sample shapes match `tfd.TransformedDistribution`.

        Only shapes are compared; sample values are not checked here.
        """

        def make_layers():
            # New bijector objects per call; the two chains do not share
            # instances.
            return [tfb.Scale(2), tfb.Shift(3), tfb.Tanh()]

        base = base_dist(mu, sigma)
        dist = transformed.Transformed(base, chain.Chain(make_layers()))

        def draw(seed, sample_shape):
            return dist.sample(sample_shape=sample_shape, seed=seed)

        # Positional argument 1 (`sample_shape`) is given as both
        # `ignore_argnums` and `static_argnums` for the variant wrapper.
        wrapped_draw = self.variant(
            draw, ignore_argnums=(1, ), static_argnums=1)
        samples = wrapped_draw(self.seed, sample_shape)

        tfp_bijector = tfb.Chain(make_layers())
        tfp_base = conversion.to_tfp(base)
        tfp_dist = tfd.TransformedDistribution(tfp_base, tfp_bijector)
        tfp_samples = tfp_dist.sample(
            seed=self.seed, sample_shape=sample_shape)

        chex.assert_equal_shape([samples, tfp_samples])
Example #9
0
 def test_properties(self):
     """Check every entry of `Chain.bijectors` is a `base_bijector.Bijector`.

     The chain is built from two `tfb` bijectors and the plain callable
     `jnp.tanh`.
     """
     layers = [tfb.Scale(2), tfb.Shift(3), jnp.tanh]
     composed = chain.Chain(layers)
     assert all(
         isinstance(item, base_bijector.Bijector)
         for item in composed.bijectors)
Example #10
0
 def test_raises_on_incompatible_dimensions(self):
     """Check that chaining `jnp.log` with `block.Block(jnp.exp, 1)` raises.

     A `ValueError` is expected, presumably because the two layers disagree
     on event dimensions (per the test name).
     """
     with self.assertRaises(ValueError):
         # Everything is built inside the context so that a ValueError from
         # any construction step is caught.
         plain_layer = jnp.log
         blocked_layer = block.Block(jnp.exp, 1)
         chain.Chain([plain_layer, blocked_layer])
Example #11
0
 def test_raises_on_empty_list(self):
     """Check that constructing a `Chain` from an empty list raises.

     A `ValueError` is expected.
     """
     self.assertRaises(ValueError, chain.Chain, [])