Exemplo n.º 1
0
 def test_dtype_is_consistent_with_tfp(self, dist_fn, bijector_fn):
     """Checks that `Transformed` reports the same dtype as the TFP equivalent.

     Args:
       dist_fn: zero-argument callable returning the base distribution.
       bijector_fn: zero-argument callable returning the bijector, which is
         handed unchanged to both implementations.
     """
     base_distribution = dist_fn()
     shared_bijector = bijector_fn()
     dx_dist = transformed.Transformed(base_distribution, shared_bijector)
     reference = tfd.TransformedDistribution(
         conversion.to_tfp(base_distribution), shared_bijector)
     assert dx_dist.dtype == reference.dtype
Exemplo n.º 2
0
  def test_batched_bijector_shapes(self, batch_shape, sample_shape):
    """Checks shapes when the batch dimensions come from the bijector.

    The base distribution is built from length-3 parameter vectors, while the
    `Scale` parameter wrapped in `Block` has shape `batch_shape + (3,)`.

    Args:
      batch_shape: tuple of batch dimensions of the bijector parameter.
      sample_shape: tuple of leading dimensions requested when sampling.
    """
    batched_event = batch_shape + (3,)
    base = tfd.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3))
    dist = transformed.Transformed(
        base, block.Block(tfb.Scale(jnp.ones(batched_event)), 1))

    def draw():
      # Each subtest performs its own draw, always with the same seed.
      return dist.sample_and_log_prob(
          seed=self.seed, sample_shape=sample_shape)

    with self.subTest('batch_shape'):
      chex.assert_equal(dist.batch_shape, batch_shape)

    with self.subTest('sample.shape'):
      drawn = dist.sample(seed=self.seed, sample_shape=sample_shape)
      chex.assert_equal(drawn.shape, sample_shape + batched_event)

    with self.subTest('sample_and_log_prob sample.shape'):
      drawn, _ = draw()
      chex.assert_equal(drawn.shape, sample_shape + batched_event)

    with self.subTest('sample_and_log_prob log_prob.shape'):
      _, logp = draw()
      chex.assert_equal(logp.shape, sample_shape + batch_shape)

    with self.subTest('sample_and_log_prob log_prob value'):
      drawn, logp = draw()
      np.testing.assert_allclose(logp, dist.log_prob(drawn))
Exemplo n.º 3
0
 def test_dtype_is_as_expected(self, dist_fn, bijector_fn, expected_dtype):
   """Checks the dtype property against a drawn sample and the expected dtype.

   Args:
     dist_fn: zero-argument callable returning the base distribution.
     bijector_fn: zero-argument callable returning the bijector.
     expected_dtype: dtype the transformed distribution must report.
   """
   dx_dist = transformed.Transformed(dist_fn(), bijector_fn())
   drawn = self.variant(dx_dist.sample)(seed=self.seed)
   assert dx_dist.dtype == drawn.dtype
   assert dx_dist.dtype == expected_dtype
Exemplo n.º 4
0
    def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
        """Checks `sample_and_log_prob` through a Scale-Tanh-Scale chain.

        The samples must equal the bijector's forward map of base samples
        drawn with the same seed, and the log-probs must agree with the TFP
        `TransformedDistribution` evaluated at those samples.

        Args:
            mu: first argument of `base_dist`.
            sigma: second argument of `base_dist`.
            sample_shape: shape requested when sampling; marked as a static
                argument for the variant wrapper.
            base_dist: callable building the base distribution.
        """
        base = base_dist(mu, sigma)
        dx_chain = chain.Chain([tfb.Scale(10), tfb.Tanh(), tfb.Scale(0.1)])
        dist = transformed.Transformed(base, dx_chain)

        def _draw(key, shape):
            return dist.sample_and_log_prob(seed=key, sample_shape=shape)

        wrapped_draw = self.variant(
            _draw, ignore_argnums=(1, ), static_argnums=(1, ))
        samples, log_prob = wrapped_draw(self.seed, sample_shape)

        base_samples = base.sample(seed=self.seed, sample_shape=sample_shape)
        expected_samples = dx_chain.forward(base_samples)

        # The TFP reference is built from its own, freshly constructed chain.
        reference = tfd.TransformedDistribution(
            conversion.to_tfp(base),
            tfb.Chain([tfb.Scale(10), tfb.Tanh(), tfb.Scale(0.1)]))
        reference_samples = reference.sample(
            seed=self.seed, sample_shape=sample_shape)
        reference_log_prob = reference.log_prob(samples)

        chex.assert_equal_shape([samples, reference_samples])
        np.testing.assert_allclose(log_prob, reference_log_prob, rtol=RTOL)
        np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)
0
    def test_event_shape(self, mu, sigma, base_dist):
        """Checks that `event_shape` matches TFP for a `Scale(2)` bijector.

        Args:
            mu: first argument of `base_dist`.
            sigma: second argument of `base_dist`.
            base_dist: callable building the base distribution.
        """
        base = base_dist(mu, sigma)
        # The same TFP bijector instance is given to both implementations.
        doubling = tfb.Scale(2)
        dx_dist = transformed.Transformed(base, doubling)
        reference = tfd.TransformedDistribution(
            conversion.to_tfp(base), doubling)

        assert dx_dist.event_shape == reference.event_shape
Exemplo n.º 6
0
  def test_prob(self, mu, sigma, value, base_dist):
    """Checks that `prob` equals the TFP result exactly for `Scale(2)`.

    Args:
      mu: first argument of `base_dist`.
      sigma: second argument of `base_dist`.
      value: point(s) at which the density is evaluated.
      base_dist: callable building the base distribution.
    """
    base = base_dist(mu, sigma)
    doubling = tfb.Scale(2)
    dx_prob = self.variant(transformed.Transformed(base, doubling).prob)
    actual = dx_prob(value)

    reference = tfd.TransformedDistribution(conversion.to_tfp(base), doubling)
    # Element-wise equality, not a tolerance-based comparison.
    np.testing.assert_array_equal(actual, reference.prob(value))
Exemplo n.º 7
0
  def test_method(self, function_string, mu, sigma, base_dist):
    """Checks a zero-argument method against TFP for a `Scale(2)` bijector.

    Args:
      function_string: name of the method looked up on both distributions
        and called without arguments.
      mu: first argument of `base_dist`.
      sigma: second argument of `base_dist`.
      base_dist: callable building the base distribution.
    """
    base = base_dist(mu, sigma)
    doubling = tfb.Scale(2)
    dx_dist = transformed.Transformed(base, doubling)
    reference = tfd.TransformedDistribution(conversion.to_tfp(base), doubling)

    dx_method = self.variant(getattr(dx_dist, function_string))
    reference_method = getattr(reference, function_string)
    np.testing.assert_allclose(dx_method(), reference_method())
Exemplo n.º 8
0
  def test_jittable(self):
    """Checks that a `Transformed` can be an argument of a jitted function."""
    def log_prob_at(x, d):
      return d.log_prob(x)

    jitted_log_prob = jax.jit(log_prob_at)

    dist = transformed.Transformed(
        normal.Normal(0, 1), scalar_affine.ScalarAffine(0, 1))
    # No value is asserted: the call only has to trace and run without raising.
    jitted_log_prob(np.zeros(()), dist)
Exemplo n.º 9
0
    def test_event_shape(self, mu, sigma, base_dist):
        """Checks `event_shape` of a `Lambda(jnp.tanh)` flow against TFP `Tanh`.

        Args:
            mu: first argument of `base_dist`.
            sigma: second argument of `base_dist`.
            base_dist: callable building the base distribution.
        """
        base = base_dist(mu, sigma)
        dx_dist = transformed.Transformed(
            base, lambda_bijector.Lambda(jnp.tanh))

        reference = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfb.Tanh())

        assert dx_dist.event_shape == reference.event_shape
Exemplo n.º 10
0
    def test_event_shape(self, mu, sigma, base_dist):
        """Checks `event_shape` of a Scale-Shift-Tanh chain against TFP.

        Args:
            mu: first argument of `base_dist`.
            sigma: second argument of `base_dist`.
            base_dist: callable building the base distribution.
        """
        base = base_dist(mu, sigma)
        dx_chain = chain.Chain([tfb.Scale(2), tfb.Shift(3), tfb.Tanh()])
        dx_dist = transformed.Transformed(base, dx_chain)

        # The reference uses TFP's own Chain over freshly built bijectors.
        reference_chain = tfb.Chain([tfb.Scale(2), tfb.Shift(3), tfb.Tanh()])
        reference = tfd.TransformedDistribution(
            conversion.to_tfp(base), reference_chain)

        assert dx_dist.event_shape == reference.event_shape
Exemplo n.º 11
0
  def test_prob(self, mu, sigma, value, base_dist):
    """Checks `prob` under a `tanh.Tanh` bijector against TFP's `Tanh`.

    Args:
      mu: first argument of `base_dist`.
      sigma: second argument of `base_dist`.
      value: point(s) at which the density is evaluated.
      base_dist: callable building the base distribution.
    """
    base = base_dist(mu, sigma)
    dx_dist = transformed.Transformed(base, tanh.Tanh())
    actual = self.variant(dx_dist.prob)(value)

    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Tanh())
    np.testing.assert_allclose(actual, reference.prob(value), atol=1e-9)
Exemplo n.º 12
0
  def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
    """Checks that samples have the same shape as samples from TFP.

    Args:
      mu: first argument of `base_dist`.
      sigma: second argument of `base_dist`.
      sample_shape: shape requested when sampling; marked as a static
        argument for the variant wrapper.
      base_dist: callable building the base distribution.
    """
    base = base_dist(mu, sigma)
    doubling = tfb.Scale(2)
    dx_dist = transformed.Transformed(base, doubling)

    def _draw(key, shape):
      return dx_dist.sample(seed=key, sample_shape=shape)

    wrapped_draw = self.variant(_draw, ignore_argnums=(1,), static_argnums=1)
    dx_samples = wrapped_draw(self.seed, sample_shape)

    reference = tfd.TransformedDistribution(conversion.to_tfp(base), doubling)
    reference_samples = reference.sample(
        sample_shape=sample_shape, seed=self.seed)

    chex.assert_equal_shape([dx_samples, reference_samples])
Exemplo n.º 13
0
    def test_method(self, function_string, mu, sigma, base_dist):
        """Checks a zero-argument method of a `Lambda` shift flow against TFP.

        The flow under test wraps `lambda x: x + 3`; the reference uses
        TFP's `Shift(3)`.

        Args:
            function_string: name of the method looked up on both
                distributions and called without arguments.
            mu: first argument of `base_dist`.
            sigma: second argument of `base_dist`.
            base_dist: callable building the base distribution.
        """
        base = base_dist(mu, sigma)
        dx_dist = transformed.Transformed(
            base, lambda_bijector.Lambda(lambda x: x + 3))

        reference = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfb.Shift(3))

        dx_method = self.variant(getattr(dx_dist, function_string))
        reference_method = getattr(reference, function_string)
        np.testing.assert_allclose(
            dx_method(), reference_method(), rtol=RTOL)
Exemplo n.º 14
0
  def test_integer_inputs(self, inputs, base_dist):
    """Checks `log_prob` on `inputs` against a constant float32 reference.

    Args:
      inputs: array passed to `log_prob`; integer-typed according to the test
        name, and presumably all zeros given the expected value (confirm
        against the parameterization).
      base_dist: callable building the base distribution from float32 arrays
        of zeros and ones shaped like `inputs`.
    """
    zeros = jnp.zeros_like(inputs, dtype=jnp.float32)
    ones = jnp.ones_like(inputs, dtype=jnp.float32)
    dist = transformed.Transformed(
        base_dist(zeros, ones), scalar_affine.ScalarAffine(shift=0.0))

    actual = self.variant(dist.log_prob)(inputs)

    # Every entry of the reference array holds this hard-coded constant.
    standard_normal_log_prob_of_zero = -0.9189385
    expected = jnp.full_like(
        inputs, standard_normal_log_prob_of_zero, dtype=jnp.float32)

    np.testing.assert_array_equal(actual, expected)
Exemplo n.º 15
0
    def test_auto_lambda(self, forward_fn, tfp_bijector, base_dist):
        """Checks that a bare callable can be used in place of a bijector.

        `forward_fn` is handed directly to `Transformed`; `log_prob` and
        `prob` are then compared against a TFP distribution built with
        `tfp_bijector()`.

        Args:
            forward_fn: callable passed to `Transformed` as the bijector.
            tfp_bijector: zero-argument callable returning the reference
                TFP bijector.
            base_dist: callable building the base distribution.
        """
        mu = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
        sigma = np.array([0.5, 1.0, 2.5], dtype=np.float32)
        base = base_dist(mu, sigma)

        dx_dist = transformed.Transformed(base, forward_fn)
        reference = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfp_bijector())

        y = np.array([0.05, 0.3, 0.95], dtype=np.float32)

        # Same check for both methods: reference first, then the variant call.
        for method_name in ('log_prob', 'prob'):
            expected = getattr(reference, method_name)(y)
            actual = self.variant(getattr(dx_dist, method_name))(y)
            np.testing.assert_allclose(actual, expected, rtol=RTOL)
Exemplo n.º 16
0
  def test_bijector_that_assumes_batch_dimensions(self):
    """Checks a bijector whose conditioner expects a leading batch dimension.

    Shape and dtype properties are asserted; `sample` and `log_prob` are then
    run through the variant wrapper, with only the sample dtype asserted.
    """
    # Haiku conditioner written for inputs with a single batch dimension.
    def forward(x):
      return hk.Sequential(
          [hk.Flatten(preserve_dims=1), hk.Linear(3)])(x)

    init_fn, apply_fn = hk.transform(forward)
    params = init_fn(self.seed, jnp.ones((2, 3)))
    conditioner = functools.partial(apply_fn, params, self.seed)

    mask = jnp.ones(3) > 0
    bijector = masked_coupling.MaskedCoupling(mask, conditioner, tfb.Scale)

    base = tfd.MultivariateNormalDiag(jnp.zeros((2, 3)), jnp.ones((2, 3)))
    dist = transformed.Transformed(base, bijector)

    # Exercise the trace-based functions.
    assert dist.batch_shape == (2,)
    assert dist.event_shape == (3,)
    assert dist.dtype == jnp.float32

    drawn = self.variant(dist.sample)(seed=self.seed)
    assert drawn.dtype == dist.dtype
    # The result is not inspected; the call only has to succeed.
    self.variant(dist.log_prob)(drawn)
Exemplo n.º 17
0
  def test_batched_bijector_against_tfp(
      self, bijector_fn, block_ndims, bijector_shape, params_shape):
    """Compares a `Block`-wrapped batched bijector with the raw TFP bijector.

    Values consumed by more than one subtest (the two samples and the
    log-prob of the first sample) are computed between the subtests rather
    than inside one of them. Otherwise a failure in the producing subtest
    would leave those names unbound and every later subtest would fail with
    an unrelated `NameError`.

    Args:
      bijector_fn: callable building a TFP bijector from an array of ones of
        shape `bijector_shape`.
      block_ndims: number of dimensions handed to `block.Block`.
      bijector_shape: shape of the parameter array given to `bijector_fn`.
      params_shape: shape of the zeros / ones arrays used to build the base
        `MultivariateNormalDiag`.
    """
    base = tfd.MultivariateNormalDiag(
        jnp.zeros(params_shape), jnp.ones(params_shape))

    tfp_bijector = bijector_fn(jnp.ones(bijector_shape))
    dx_bijector = block.Block(tfp_bijector, block_ndims)

    dx_dist = transformed.Transformed(base, dx_bijector)
    tfp_dist = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfp_bijector)

    with self.subTest('event_shape property matches TFP'):
      np.testing.assert_equal(dx_dist.event_shape, tfp_dist.event_shape)

    # Shared by all subtests below, so drawn outside any single subtest.
    dx_sample = self.variant(dx_dist.sample)(seed=self.seed)
    tfp_sample = self.variant(tfp_dist.sample)(seed=self.seed)

    with self.subTest('sample shape matches TFP'):
      chex.assert_equal_shape([dx_sample, tfp_sample])

    # Also read by the final shape-consistency subtest.
    dx_logp_dx = self.variant(dx_dist.log_prob)(dx_sample)

    with self.subTest('log_prob(dx_sample) matches TFP'):
      tfp_logp_dx = self.variant(tfp_dist.log_prob)(dx_sample)
      np.testing.assert_allclose(dx_logp_dx, tfp_logp_dx, rtol=RTOL)

    with self.subTest('log_prob(tfp_sample) matches TFP'):
      dx_logp_tfp = self.variant(dx_dist.log_prob)(tfp_sample)
      tfp_logp_tfp = self.variant(tfp_dist.log_prob)(tfp_sample)
      np.testing.assert_allclose(dx_logp_tfp, tfp_logp_tfp, rtol=RTOL)

    with self.subTest('sample/lp shape is self-consistent'):
      second_sample, log_prob = self.variant(dx_dist.sample_and_log_prob)(
          seed=self.seed)
      chex.assert_equal_shape([dx_sample, second_sample])
      chex.assert_equal_shape([dx_logp_dx, log_prob])
Exemplo n.º 18
0
 def test_raises_on_incorrect_shape(self, block_dims):
   """Checks that building `Transformed` raises for this base/bijector pair.

   The base distribution uses parameters of shape (2, 3) while the wrapped
   `Scale` bijector uses a parameter of shape (1, 2, 3).

   Args:
     block_dims: number of dimensions handed to `block.Block`.
   """
   base = tfd.MultivariateNormalDiag(jnp.zeros((2, 3)), jnp.ones((2, 3)))
   wrapped = block.Block(tfb.Scale(jnp.ones((1, 2, 3))), block_dims)
   # Only the constructor call itself is expected to raise.
   self.assertRaises(ValueError, transformed.Transformed, base, wrapped)