Example #1
0
  def test_batched_bijector_shapes(self, batch_shape, sample_shape):
    """Checks output shapes when the blocked bijector carries a batch shape.

    A 3-dimensional `MultivariateNormalDiag` is transformed by a `Scale`
    bijector whose parameter has shape `batch_shape + (3,)`, wrapped in a
    `Block` with `ndims=1`. Every sub-test makes its own call on the
    resulting distribution.
    """
    event_shape = (3,)
    scale = tfb.Scale(jnp.ones(batch_shape + event_shape))
    dist = transformed.Transformed(
        tfd.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3)),
        block.Block(scale, 1))

    expected_sample_shape = sample_shape + batch_shape + event_shape
    expected_log_prob_shape = sample_shape + batch_shape

    def draw_with_log_prob():
      # Invoked inside each sub-test so that a failure is attributed to it.
      return dist.sample_and_log_prob(
          seed=self.seed, sample_shape=sample_shape)

    with self.subTest('batch_shape'):
      chex.assert_equal(dist.batch_shape, batch_shape)

    with self.subTest('sample.shape'):
      drawn = dist.sample(seed=self.seed, sample_shape=sample_shape)
      chex.assert_equal(drawn.shape, expected_sample_shape)

    with self.subTest('sample_and_log_prob sample.shape'):
      drawn, _ = draw_with_log_prob()
      chex.assert_equal(drawn.shape, expected_sample_shape)

    with self.subTest('sample_and_log_prob log_prob.shape'):
      _, lp = draw_with_log_prob()
      chex.assert_equal(lp.shape, expected_log_prob_shape)

    with self.subTest('sample_and_log_prob log_prob value'):
      drawn, lp = draw_with_log_prob()
      np.testing.assert_allclose(lp, dist.log_prob(drawn))
Example #2
0
    def test_jittable(self):
        """Passes a `Block` bijector as an argument to a `jax.jit`-ed function.

        There is no explicit assertion: the test succeeds as long as the
        jitted call does not raise.
        """
        blocked = block_bijector.Block(scalar_affine.ScalarAffine(0), 1)
        inputs = np.zeros((2, 3))

        def apply_forward(x, b):
            return b.forward(x)

        jax.jit(apply_forward)(inputs, blocked)
Example #3
0
 def test_raises_on_invalid_input_shape(self):
     """Each `Block(..., 1)` method raises `ValueError` on a rank-0 input."""
     blocked = block_bijector.Block(lambda x: x, 1)
     method_names = (
         'forward',
         'inverse',
         'forward_log_det_jacobian',
         'inverse_log_det_jacobian',
         'forward_and_log_det',
         'inverse_and_log_det',
     )
     # Resolve all bound methods up front, then exercise them one by one.
     methods = [getattr(blocked, name) for name in method_names]
     for method in methods:
         with self.assertRaises(ValueError):
             method(jnp.array(0))
Example #4
0
 def test_kl_divergence_raises_on_event_shape(self):
   """`kl_divergence` raises `ValueError` for a vector vs. scalar pairing.

   One distribution has a 3-dimensional `MultivariateNormalDiag` base and an
   identity `Lambda` bijector wrapped in `Block(ndims=1)`; the other has a
   scalar `Normal` base and a plain identity `Lambda` bijector.
   """
   vector_dist = transformed.Transformed(
       tfd.MultivariateNormalDiag([0.1, 0.5, 0.9], [0.1, 1.1, 2.5]),
       block.Block(lambda_bijector.Lambda(lambda x: x), ndims=1))
   scalar_dist = transformed.Transformed(
       tfd.Normal(-0.1, 1.5),
       lambda_bijector.Lambda(lambda x: x))
   with self.assertRaises(ValueError):
     vector_dist.kl_divergence(scalar_dist)
Example #5
0
 def test_raises_if_inner_bijector_is_not_scalar(self):
     """`MaskedCoupling` raises `ValueError` for a `Block`-wrapped inner bijector.

     The coupling is built with a random boolean mask over a (2, 3) event and
     an inner bijector factory that returns `Block(identity, 1)`; both the
     forward and the inverse pass are expected to fail.
     """
     event_shape = (2, 3)
     random_mask = jax.random.choice(
         jax.random.PRNGKey(101), jnp.array([True, False]), event_shape)
     coupling = masked_coupling.MaskedCoupling(
         mask=random_mask,
         conditioner=lambda x: x,
         bijector=lambda _: block.Block(lambda x: x, 1))
     transforms = (coupling.forward_and_log_det, coupling.inverse_and_log_det)
     for transform in transforms:
         with self.assertRaises(ValueError):
             transform(jnp.zeros(event_shape))
Example #6
0
 def test_against_tfp_semantics(self, tfp_bijector_fn, ndims):
     """Compares `Block` log-det-Jacobians with TFP's `event_ndims` argument.

     TFP is asked to reduce over `ndims` plus the bijector's own minimum
     event rank; `Block(tfp_bijector, ndims)` must agree within `atol=2e-5`
     in both the forward and the inverse direction.
     """
     tfp_bijector = tfp_bijector_fn()
     x = jax.random.normal(self.seed, [2, 3, 4, 5, 6])
     y = tfp_bijector(x)
     blocked = block_bijector.Block(tfp_bijector, ndims)

     expected_fwd = tfp_bijector.forward_log_det_jacobian(
         x, ndims + tfp_bijector.forward_min_event_ndims)
     actual_fwd = self.variant(blocked.forward_log_det_jacobian)(x)
     np.testing.assert_allclose(expected_fwd, actual_fwd, atol=2e-5)

     expected_inv = tfp_bijector.inverse_log_det_jacobian(
         y, ndims + tfp_bijector.inverse_min_event_ndims)
     actual_inv = self.variant(blocked.inverse_log_det_jacobian)(y)
     np.testing.assert_allclose(expected_inv, actual_inv, atol=2e-5)
Example #7
0
 def test_raises_on_invalid_inner_bijector(self):
     """`SplitCoupling` raises `ValueError` when the inner bijector is too wide.

     The inner bijector factory returns a `Block` whose `ndims` is one more
     than the coupling's `event_ndims`; every public method is expected to
     fail on a (2, 3) input.
     """
     event_shape = (2, 3)
     event_ndims = len(event_shape)
     coupling = split_coupling.SplitCoupling(
         split_index=event_shape[-1] // 2,
         event_ndims=event_ndims,
         conditioner=lambda x: x,
         bijector=lambda _: block.Block(lambda x: x, event_ndims + 1))
     method_names = (
         'forward',
         'inverse',
         'forward_log_det_jacobian',
         'inverse_log_det_jacobian',
         'forward_and_log_det',
         'inverse_and_log_det',
     )
     # Resolve all bound methods up front, then exercise them one by one.
     methods = [getattr(coupling, name) for name in method_names]
     for method in methods:
         with self.assertRaises(ValueError):
             method(jnp.zeros(event_shape))
Example #8
0
 def test_forward_inverse_work_as_expected(self, bijector_fn, ndims):
     """`Block` returns the same transformed values as the bijector it wraps.

     Compares `forward`/`inverse` outputs, and the first element returned by
     `forward_and_log_det`/`inverse_and_log_det`, for exact equality.
     """
     inner = conversion.as_bijector(bijector_fn())
     x = jax.random.normal(self.seed, [2, 3])
     blocked = block_bijector.Block(inner, ndims)

     for name in ('forward', 'inverse'):
         np.testing.assert_array_equal(
             self.variant(getattr(inner, name))(x),
             self.variant(getattr(blocked, name))(x))

     # The `*_and_log_det` methods return a pair; index 0 is compared here.
     for name in ('forward_and_log_det', 'inverse_and_log_det'):
         np.testing.assert_array_equal(
             self.variant(getattr(inner, name))(x)[0],
             self.variant(getattr(blocked, name))(x)[0])
Example #9
0
 def _inner_bijector(self, params: BijectorParams) -> base.Bijector:
     """Builds the inner bijector from `params` and matches its event rank.

     The object returned by `self._bijector(params)` is converted with
     `conversion.as_bijector`. When it has fewer event dimensions than this
     coupling bijector it is wrapped in `block.Block` to make up the
     difference; when the counts are equal it is returned unchanged.

     Args:
       params: value forwarded to the `self._bijector` callable.

     Returns:
       The inner bijector, wrapped in `block.Block` if needed.

     Raises:
       ValueError: if the inner bijector's `event_ndims_in` differs from its
         `event_ndims_out`, or if it has more event dimensions than
         `self.event_ndims_in`.
     """
     bijector = conversion.as_bijector(self._bijector(params))
     if bijector.event_ndims_in != bijector.event_ndims_out:
         raise ValueError(
             f'The inner bijector must have `event_ndims_in==event_ndims_out`. '
             f'Instead, it has `event_ndims_in=={bijector.event_ndims_in}` and '
             f'`event_ndims_out=={bijector.event_ndims_out}`.')

     missing_ndims = self.event_ndims_in - bijector.event_ndims_in
     if missing_ndims == 0:
         # Ranks already agree: nothing to wrap.
         return bijector
     if missing_ndims > 0:
         return block.Block(bijector, missing_ndims)
     raise ValueError(
         f'The inner bijector can\'t have more event dimensions than the '
         f'coupling bijector. Got {bijector.event_ndims_in} for the inner '
         f'bijector and {self.event_ndims_in} for the coupling bijector.'
     )
Example #10
0
 def test_log_det_jacobian_works_as_expected(self, bijector_fn, ndims):
     """`Block` log-dets equal the inner log-dets summed over `ndims` axes.

     The wrapped bijector's log-det-Jacobian is summed over its last `ndims`
     axes and compared (within `RTOL`) with what `Block` reports, for the
     stand-alone log-det methods and for the `*_and_log_det` pairs.
     """
     inner = conversion.as_bijector(bijector_fn())
     x = jax.random.normal(self.seed, [2, 3])
     blocked = block_bijector.Block(inner, ndims)
     reduced_axes = tuple(range(-ndims, 0))

     for name in ('forward_log_det_jacobian', 'inverse_log_det_jacobian'):
         np.testing.assert_allclose(
             self.variant(getattr(inner, name))(x).sum(reduced_axes),
             self.variant(getattr(blocked, name))(x),
             rtol=RTOL)

     # The `*_and_log_det` methods return a pair; index 1 is the log-det.
     for name in ('forward_and_log_det', 'inverse_and_log_det'):
         np.testing.assert_allclose(
             self.variant(getattr(inner, name))(x)[1].sum(reduced_axes),
             self.variant(getattr(blocked, name))(x)[1],
             rtol=RTOL)
Example #11
0
  def test_batched_bijector_against_tfp(
      self, bijector_fn, block_ndims, bijector_shape, params_shape):
    """Compares a `Block`-transformed distribution with TFP's equivalent.

    The same base distribution and TFP bijector are used on both sides; on
    the Distrax side the bijector is wrapped in `Block(..., block_ndims)`.

    Note that the log-prob sub-tests reuse the samples drawn in the
    'sample shape matches TFP' sub-test, and the final sub-test reuses the
    log-prob computed in 'log_prob(dx_sample) matches TFP'.
    """
    base_dist = tfd.MultivariateNormalDiag(
        jnp.zeros(params_shape), jnp.ones(params_shape))
    tfp_bijector = bijector_fn(jnp.ones(bijector_shape))

    ours = transformed.Transformed(
        base_dist, block.Block(tfp_bijector, block_ndims))
    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base_dist), tfp_bijector)

    with self.subTest('event_shape property matches TFP'):
      np.testing.assert_equal(ours.event_shape, reference.event_shape)

    with self.subTest('sample shape matches TFP'):
      our_sample = self.variant(ours.sample)(seed=self.seed)
      ref_sample = self.variant(reference.sample)(seed=self.seed)
      chex.assert_equal_shape([our_sample, ref_sample])

    with self.subTest('log_prob(dx_sample) matches TFP'):
      our_lp = self.variant(ours.log_prob)(our_sample)
      ref_lp = self.variant(reference.log_prob)(our_sample)
      np.testing.assert_allclose(our_lp, ref_lp, rtol=RTOL)

    with self.subTest('log_prob(tfp_sample) matches TFP'):
      our_lp_on_ref = self.variant(ours.log_prob)(ref_sample)
      ref_lp_on_ref = self.variant(reference.log_prob)(ref_sample)
      np.testing.assert_allclose(our_lp_on_ref, ref_lp_on_ref, rtol=RTOL)

    with self.subTest('sample/lp shape is self-consistent'):
      resampled, resampled_lp = self.variant(ours.sample_and_log_prob)(
          seed=self.seed)
      chex.assert_equal_shape([our_sample, resampled])
      chex.assert_equal_shape([our_lp, resampled_lp])
Example #12
0
 def test_invalid_properties(self):
     """Constructing a `Block` with `ndims=-1` raises `ValueError`."""
     tanh_bijector = conversion.as_bijector(jnp.tanh)
     self.assertRaises(ValueError, block_bijector.Block, tanh_bijector, -1)
Example #13
0
 def test_properties(self):
     """Checks that `Block` exposes its `ndims` and its wrapped `bijector`."""
     bijct = conversion.as_bijector(jnp.tanh)
     block = block_bijector.Block(bijct, 1)
     # Use TestCase assertions rather than bare `assert`: they are not
     # stripped under `python -O` and they report the offending values.
     self.assertEqual(block.ndims, 1)
     self.assertIsInstance(block.bijector, base.Bijector)
Example #14
0
 def test_raises_on_incorrect_shape(self, block_dims):
   """`Transformed` raises `ValueError` for the shape pairing built here.

   The base `MultivariateNormalDiag` has (2, 3)-shaped parameters, while the
   `Scale` bijector wrapped in `Block(..., block_dims)` has a (1, 2, 3)-shaped
   parameter.
   """
   batched_scale = tfb.Scale(jnp.ones((1, 2, 3)))
   blocked = block.Block(batched_scale, block_dims)
   base_dist = tfd.MultivariateNormalDiag(jnp.zeros((2, 3)), jnp.ones((2, 3)))
   with self.assertRaises(ValueError):
     transformed.Transformed(base_dist, blocked)
Example #15
0
 def test_raises_on_incompatible_dimensions(self):
     """`Chain` of `jnp.log` and a blocked `jnp.exp` raises `ValueError`."""
     with self.assertRaises(ValueError):
         # Built inside the context so any `ValueError` raised here counts.
         mismatched = [jnp.log, block.Block(jnp.exp, 1)]
         chain.Chain(mismatched)