Example 1
 def test_forward_methods_are_correct(self):
     key = jax.random.PRNGKey(42)
     x = jax.random.normal(key, (2, 3, 4, 5))
     bij_no_scale = scalar_affine.ScalarAffine(shift=3.)
     bij_with_scale = scalar_affine.ScalarAffine(shift=3., scale=1.)
     bij_with_log_scale = scalar_affine.ScalarAffine(shift=3., log_scale=0.)
     for bij in [bij_no_scale, bij_with_scale, bij_with_log_scale]:
         y, logdet = self.variant(bij.forward_and_log_det)(x)
         np.testing.assert_allclose(y, x + 3., atol=1e-8)
         np.testing.assert_allclose(logdet, 0., atol=1e-8)
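
These snippets are test methods taken out of their modules, so the imports and test-class scaffolding are missing; the self.variant wrapper suggests chex-style test variants, though that part is not shown. A minimal import block they appear to assume is sketched below; the module paths are guesses based on the identifiers used (they match Distrax's layout) and should be adjusted to the actual package.

    # Imports assumed by the snippets in this section (module paths are guesses
    # based on the names scalar_affine, block_bijector, inverse, normal,
    # transformed; adjust to the real package layout).
    import jax
    import jax.numpy as jnp
    import numpy as np

    from distrax._src.bijectors import block as block_bijector
    from distrax._src.bijectors import inverse
    from distrax._src.bijectors import scalar_affine
    from distrax._src.distributions import normal
    from distrax._src.distributions import transformed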
Example 2
 def test_inverse_methods_are_correct(self):
     k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
     x = jax.random.normal(k1, (2, 3, 4, 5))
     shift = jax.random.normal(k2, (4, 5))
     scale = jax.random.uniform(k3, (3, 4, 5)) + 0.1
     log_scale = jax.random.normal(k4, (3, 4, 5))
     bij_no_scale = scalar_affine.ScalarAffine(shift)
     bij_with_scale = scalar_affine.ScalarAffine(shift, scale=scale)
     bij_with_log_scale = scalar_affine.ScalarAffine(shift,
                                                     log_scale=log_scale)
     for bij in [bij_no_scale, bij_with_scale, bij_with_log_scale]:
         y, logdet_fwd = self.variant(bij.forward_and_log_det)(x)
         x_rec, logdet_inv = self.variant(bij.inverse_and_log_det)(y)
         np.testing.assert_allclose(x_rec, x, atol=1e-5)
         np.testing.assert_allclose(logdet_fwd, -logdet_inv, atol=3e-6)
Example 3
 def test_properties(self):
     bij = scalar_affine.ScalarAffine(shift=0., scale=1.)
     self.assertTrue(bij.is_constant_jacobian)
     self.assertTrue(bij.is_constant_log_det)
     np.testing.assert_allclose(bij.shift, 0.)
     np.testing.assert_allclose(bij.scale, 1.)
     np.testing.assert_allclose(bij.log_scale, 0.)
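
For reference, the transform these tests exercise is the elementwise affine map y = scale * x + shift; its Jacobian is constant, so the forward log-determinant is log|scale| broadcast to the input shape and the inverse log-determinant is its negation. A small library-free check of those identities, written with plain jax.numpy, is sketched below (function names are mine, not the library's).

    import jax.numpy as jnp

    def affine_forward_and_log_det(x, shift, scale):
        # y = scale * x + shift; constant Jacobian, so logdet = log|scale|.
        y = scale * x + shift
        return y, jnp.broadcast_to(jnp.log(jnp.abs(scale)), y.shape)

    def affine_inverse_and_log_det(y, shift, scale):
        # x = (y - shift) / scale; log-determinant of the inverse is -log|scale|.
        x = (y - shift) / scale
        return x, jnp.broadcast_to(-jnp.log(jnp.abs(scale)), x.shape)

    x = jnp.linspace(-1.0, 1.0, 5)
    y, ld_fwd = affine_forward_and_log_det(x, shift=3.0, scale=2.0)
    x_rec, ld_inv = affine_inverse_and_log_det(y, shift=3.0, scale=2.0)
    assert jnp.allclose(x_rec, x) and jnp.allclose(ld_fwd, -ld_inv)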
Example 4
    def test_jittable(self):
        @jax.jit
        def f(x, b):
            return b.forward(x)

        bijector = block_bijector.Block(scalar_affine.ScalarAffine(0), 1)
        x = np.zeros((2, 3))
        f(x, bijector)
Example 5
    def test_jittable(self):
        @jax.jit
        def f(x, b):
            return b.forward(x)

        bijector = scalar_affine.ScalarAffine(0, 1)
        x = np.zeros(())
        f(x, bijector)
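
Passing the bijector object itself as an argument to a jax.jit-compiled function, as the two test_jittable snippets above do, requires the object to be registered with JAX as a pytree so its array leaves can be flattened and traced. The toy class below illustrates that mechanism with jax.tree_util.register_pytree_node_class; it is an assumption-level illustration of why such tests can pass, not the library's own registration code.

    import jax
    import jax.numpy as jnp

    @jax.tree_util.register_pytree_node_class
    class ToyAffine:
        # A minimal affine "bijector" that can cross a jax.jit boundary
        # because it is registered as a pytree.
        def __init__(self, shift, scale):
            self.shift = shift
            self.scale = scale

        def forward(self, x):
            return self.scale * x + self.shift

        def tree_flatten(self):
            return (self.shift, self.scale), None   # (leaves, static aux data)

        @classmethod
        def tree_unflatten(cls, aux_data, leaves):
            return cls(*leaves)

    @jax.jit
    def f(x, b):
        return b.forward(x)

    f(jnp.zeros(()), ToyAffine(0.0, 1.0))  # compiles and runs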
Example 6
    def test_integer_inputs(self, inputs):
        bijector = inverse.Inverse(scalar_affine.ScalarAffine(shift=1.0))
        output, log_det = self.variant(bijector.forward_and_log_det)(inputs)

        expected_out = jnp.array(inputs, dtype=jnp.float32) - 1.0
        expected_log_det = jnp.zeros_like(inputs, dtype=jnp.float32)

        np.testing.assert_array_equal(output, expected_out)
        np.testing.assert_array_equal(log_det, expected_log_det)
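
In this example the affine bijector is wrapped in Inverse, so the wrapped bijector's forward pass is the affine map's inverse: with shift=1.0 it maps x to x - 1.0, and the log-determinant stays zero because the slope is 1. A library-free restatement of that relationship, with hypothetical helper names:

    import jax.numpy as jnp

    def affine_forward(x, shift=1.0):
        return x + shift                  # the underlying bijector's forward

    def inverse_forward(y, shift=1.0):
        return y - shift                  # forward of Inverse(affine) == inverse of affine

    inputs = jnp.asarray([0, 1, 2], dtype=jnp.float32)  # integer-valued inputs cast to float
    assert jnp.allclose(inverse_forward(affine_forward(inputs)), inputs)
    assert jnp.allclose(inverse_forward(inputs), inputs - 1.0)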
Example 7
 def test_shapes_are_correct(self):
     k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
     x = jax.random.normal(k1, (2, 3, 4, 5))
     shift = jax.random.normal(k2, (4, 5))
     scale = jax.random.uniform(k3, (3, 4, 5)) + 0.1
     log_scale = jax.random.normal(k4, (3, 4, 5))
     bij_no_scale = scalar_affine.ScalarAffine(shift)
     bij_with_scale = scalar_affine.ScalarAffine(shift, scale=scale)
     bij_with_log_scale = scalar_affine.ScalarAffine(shift,
                                                     log_scale=log_scale)
     for bij in [bij_no_scale, bij_with_scale, bij_with_log_scale]:
         # Forward methods.
         y, logdet = self.variant(bij.forward_and_log_det)(x)
         self.assertEqual(y.shape, (2, 3, 4, 5))
         self.assertEqual(logdet.shape, (2, 3, 4, 5))
         # Inverse methods.
         x, logdet = self.variant(bij.inverse_and_log_det)(y)
         self.assertEqual(x.shape, (2, 3, 4, 5))
         self.assertEqual(logdet.shape, (2, 3, 4, 5))
Example 8
  def test_jittable(self):
    @jax.jit
    def f(x, d):
      return d.log_prob(x)

    base = normal.Normal(0, 1)
    bijector = scalar_affine.ScalarAffine(0, 1)
    dist = transformed.Transformed(base, bijector)
    x = np.zeros(())
    f(x, dist)
Example 9
    def test_batched_parameters(self, scale_batch_shape, shift_batch_shape,
                                input_batch_shape):
        k1, k2, k3 = jax.random.split(jax.random.PRNGKey(42), 3)
        log_scale = jax.random.normal(k1, scale_batch_shape)
        shift = jax.random.normal(k2, shift_batch_shape)
        bijector = scalar_affine.ScalarAffine(shift, log_scale=log_scale)

        x = jax.random.normal(k3, input_batch_shape)
        y, logdet_fwd = self.variant(bijector.forward_and_log_det)(x)
        z, logdet_inv = self.variant(bijector.inverse_and_log_det)(x)

        output_batch_shape = jnp.broadcast_arrays(log_scale, shift, x)[0].shape

        self.assertEqual(y.shape, output_batch_shape)
        self.assertEqual(z.shape, output_batch_shape)
        self.assertEqual(logdet_fwd.shape, output_batch_shape)
        self.assertEqual(logdet_inv.shape, output_batch_shape)

        log_scale = jnp.broadcast_to(log_scale, output_batch_shape).flatten()
        shift = jnp.broadcast_to(shift, output_batch_shape).flatten()
        x = jnp.broadcast_to(x, output_batch_shape).flatten()
        y = y.flatten()
        z = z.flatten()
        logdet_fwd = logdet_fwd.flatten()
        logdet_inv = logdet_inv.flatten()

        for i in range(np.prod(output_batch_shape)):
            bijector = scalar_affine.ScalarAffine(shift[i],
                                                  jnp.exp(log_scale[i]))
            this_y, this_logdet_fwd = self.variant(
                bijector.forward_and_log_det)(x[i])
            this_z, this_logdet_inv = self.variant(
                bijector.inverse_and_log_det)(x[i])
            np.testing.assert_allclose(this_y, y[i], atol=1e-7)
            np.testing.assert_allclose(this_z, z[i], atol=1e-5)
            np.testing.assert_allclose(this_logdet_fwd,
                                       logdet_fwd[i],
                                       atol=1e-4)
            np.testing.assert_allclose(this_logdet_inv,
                                       logdet_inv[i],
                                       atol=1e-4)
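
The output shape asserted here follows NumPy-style broadcasting between log_scale, shift, and x. The same shape can be computed without materializing any arrays via jnp.broadcast_shapes, as in this small check (the concrete shapes are arbitrary illustrations):

    import jax.numpy as jnp

    scale_batch_shape, shift_batch_shape, input_batch_shape = (3,), (2, 3), (2, 1)
    out_shape = jnp.broadcast_shapes(scale_batch_shape, shift_batch_shape, input_batch_shape)
    assert out_shape == (2, 3)
    # Equivalent to the materializing form used in the test:
    # jnp.broadcast_arrays(jnp.zeros(scale_batch_shape),
    #                      jnp.zeros(shift_batch_shape),
    #                      jnp.zeros(input_batch_shape))[0].shape == out_shape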
Example 10
  def test_integer_inputs(self, inputs, base_dist):
    base = base_dist(jnp.zeros_like(inputs, dtype=jnp.float32),
                     jnp.ones_like(inputs, dtype=jnp.float32))
    bijector = scalar_affine.ScalarAffine(shift=0.0)
    dist = transformed.Transformed(base, bijector)

    log_prob = self.variant(dist.log_prob)(inputs)

    standard_normal_log_prob_of_zero = -0.9189385
    expected_log_prob = jnp.full_like(
        inputs, standard_normal_log_prob_of_zero, dtype=jnp.float32)

    np.testing.assert_array_equal(log_prob, expected_log_prob)
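
The hard-coded constant -0.9189385 is the standard normal log-density at zero, -0.5 * log(2*pi). With shift=0.0 the bijector is the identity, so the change of variables log p_Y(y) = log p_X(x) - log|dy/dx| reduces to the base log-prob, which is what the test asserts. A quick numeric check of the constant:

    import jax.numpy as jnp

    # Standard normal log-density at 0: -0.5 * log(2 * pi) - 0.5 * 0**2
    log_prob_at_zero = -0.5 * jnp.log(2.0 * jnp.pi)
    print(log_prob_at_zero)   # approximately -0.9189385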
Example 11
 def test_composite_methods_are_consistent(self):
     k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
     bij = scalar_affine.ScalarAffine(shift=jax.random.normal(k1, (4, 5)),
                                      log_scale=jax.random.normal(
                                          k2, (4, 5)))
     # Forward methods.
     x = jax.random.normal(k3, (2, 3, 4, 5))
     y1 = self.variant(bij.forward)(x)
     logdet1 = self.variant(bij.forward_log_det_jacobian)(x)
     y2, logdet2 = self.variant(bij.forward_and_log_det)(x)
     np.testing.assert_allclose(y1, y2, atol=1e-12)
     np.testing.assert_allclose(logdet1, logdet2, atol=1e-12)
     # Inverse methods.
     y = jax.random.normal(k4, (2, 3, 4, 5))
     x1 = self.variant(bij.inverse)(y)
     logdet1 = self.variant(bij.inverse_log_det_jacobian)(y)
     x2, logdet2 = self.variant(bij.inverse_and_log_det)(y)
     np.testing.assert_allclose(x1, x2, atol=1e-12)
     np.testing.assert_allclose(logdet1, logdet2, atol=1e-12)
Example 12
 def test_raises_if_both_scale_and_log_scale_are_specified(self):
     with self.assertRaises(ValueError):
         scalar_affine.ScalarAffine(shift=0., scale=1., log_scale=0.)
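
Here scale and log_scale parameterize the same multiplier (scale == exp(log_scale)), so accepting both would be ambiguous and the constructor is expected to reject the combination. A generic validation pattern for mutually exclusive keyword arguments, not the library's actual implementation, looks like:

    def _check_scale_args(scale=None, log_scale=None):
        # Hypothetical helper: at most one of `scale` / `log_scale` may be given.
        if scale is not None and log_scale is not None:
            raise ValueError(
                "Only one of `scale` and `log_scale` can be specified.")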