def test_logit_ildj(self):
  """Compares inverting a hand-written logit with `jax.scipy.special.logit`.

  At the point -100, inverting the hand-written logit (and evaluating its
  ILDJ) is expected to raise `ValueError`. The library logit must instead
  agree with `jax.scipy.special.expit` for the inverse and with
  `tfb.Sigmoid().forward_log_det_jacobian` for the ILDJ.
  """

  def naive_logit(x):
    # This is the default JAX implementation of logit.
    return np.log(x / (1. - x))

  y = -100.

  # Both transforms of the hand-written logit must fail at this point.
  invert_naive = core.inverse(naive_logit)
  ildj_naive = core.ildj(naive_logit)
  with self.assertRaises(ValueError):
    invert_naive(y)
  with self.assertRaises(ValueError):
    ildj_naive(y)

  # The library logit must invert cleanly and match the reference values.
  logit = jax.scipy.special.logit
  invert_logit = core.inverse(logit)
  ildj_logit = core.ildj(logit)

  actual_inverse = invert_logit(y)
  expected_inverse = jax.scipy.special.expit(y)
  onp.testing.assert_allclose(actual_inverse, expected_inverse)

  actual_ildj = ildj_logit(y)
  expected_ildj = tfb.Sigmoid().forward_log_det_jacobian(y, 0)
  onp.testing.assert_allclose(actual_ildj, expected_ildj)
def test_sigmoid_ildj(self):
  """Compares inverting a hand-written sigmoid with `jax.nn.sigmoid`.

  At the point 0.9999, the inverse and ILDJ obtained from the hand-written
  sigmoid are expected NOT to match the reference values (the closeness
  check must fail), whereas those obtained from `jax.nn.sigmoid` must match
  `jax.scipy.special.logit` and `tfb.Sigmoid().inverse_log_det_jacobian`.
  """

  def naive_sigmoid(x):
    # This is the default JAX implementation of sigmoid.
    return 1. / (1 + np.exp(-x))

  y = 0.9999
  expected_inverse = jax.scipy.special.logit(y)
  expected_ildj = tfb.Sigmoid().inverse_log_det_jacobian(y, 0)

  # The hand-written sigmoid is expected to miss the reference values.
  invert_naive = core.inverse(naive_sigmoid)
  ildj_naive = core.ildj(naive_sigmoid)
  with self.assertRaises(AssertionError):
    onp.testing.assert_allclose(invert_naive(y), expected_inverse)
  with self.assertRaises(AssertionError):
    onp.testing.assert_allclose(ildj_naive(y), expected_ildj)

  # The library sigmoid must reproduce the reference values.
  invert_sigmoid = core.inverse(jax.nn.sigmoid)
  ildj_sigmoid = core.ildj(jax.nn.sigmoid)
  onp.testing.assert_allclose(invert_sigmoid(y), expected_inverse)
  onp.testing.assert_allclose(ildj_sigmoid(y), expected_ildj)
def test_binary_inverse_and_ildj(self):
  """Registers custom binary inverse/ILDJ rules and checks they are applied.

  A two-argument `custom_inverse` function is inverted with either operand
  fixed: first with no rule registered, then with a rule that matches
  ordinary addition, and finally with a second rule whose results differ from
  the real inverse of addition.
  """

  @custom_inverse.custom_inverse
  def add(x, y):
    return x + y

  def add_ildj(invals, z, z_ildj):
    # Solve for whichever operand is unknown; the known operand gets zero ILDJ.
    x, y = invals
    if z is None:
      raise custom_inverse.NonInvertibleError()
    have_x = x is not None
    have_y = y is not None
    if have_x and not have_y:
      return (x, z - x), (jnp.zeros_like(z_ildj), z_ildj)
    if have_y and not have_x:
      return (z - y, y), (z_ildj, jnp.zeros_like(z_ildj))
    return (None, None), (0., 0.)

  add_one_left = functools.partial(add, 1.)

  def add_one_right(x):
    return add(x, 1.)

  # Inversion already works before any custom rule has been registered.
  self.assertEqual(core.inverse(add_one_left)(2.), 1.)

  add.def_inverse_and_ildj(add_ildj)
  self.assertEqual(core.inverse(add_one_left)(2.), 1.)
  self.assertEqual(core.inverse(add_one_right)(2.), 1.)

  def add_ildj2(invals, z, z_ildj):
    # Returns values that differ from the real inverse of addition, so the
    # assertions below can only hold if this rule is the one being used.
    x, y = invals
    if z is None:
      raise custom_inverse.NonInvertibleError()
    have_x = x is not None
    have_y = y is not None
    if have_x and not have_y:
      return (x, z**2), (jnp.zeros_like(z_ildj), 3. + z_ildj)
    if have_y and not have_x:
      return (z / 4, y), (2. + z_ildj, jnp.zeros_like(z_ildj))
    return (None, None), (0., 0.)

  # Registering a second rule replaces the results seen by inverse/ildj.
  add.def_inverse_and_ildj(add_ildj2)
  self.assertEqual(core.inverse(add_one_left)(2.), 4.)
  self.assertEqual(core.ildj(add_one_left)(2.), 3.)
  self.assertEqual(core.inverse(add_one_right)(2.), 0.5)
  self.assertEqual(core.ildj(add_one_right)(2.), 2.)
def test_unary_ildj(self):
  """Checks a unary rule that supplies only a custom ILDJ.

  With only `f_ildj` given, the inverse is expected to be the ordinary one
  (2 -> 1 for `x + 1`) while the ILDJ is the custom constant 4.
  """

  @custom_inverse.custom_inverse
  def add_one(x):
    return x + 1.

  def constant_ildj(_):
    # The argument is ignored; the constant makes the custom rule observable.
    return 4.

  add_one.def_inverse_unary(f_ildj=constant_ildj)

  inverse_fn = core.inverse(add_one)
  ildj_fn = core.ildj(add_one)
  self.assertEqual(inverse_fn(2.), 1.)
  self.assertEqual(ildj_fn(2.), 4.)
def test_unary_inverse_and_ildj(self):
  """Checks a unary rule that supplies both a custom inverse and ILDJ.

  The custom inverse (`exp`) and ILDJ (constant 4) are not the real ones for
  `x + 1`, so matching them shows the registered functions are used.
  """

  @custom_inverse.custom_inverse
  def add_one(x):
    return x + 1.

  def exp_inverse(y):
    return jnp.exp(y)

  def constant_ildj(_):
    # The argument is ignored; only the constant matters for the assertion.
    return 4.

  add_one.def_inverse_unary(f_inv=exp_inverse, f_ildj=constant_ildj)

  inverse_fn = core.inverse(add_one)
  ildj_fn = core.ildj(add_one)
  self.assertEqual(inverse_fn(2.), jnp.exp(2.))
  self.assertEqual(ildj_fn(2.), 4.)
def def_inverse_unary(self, f_inv=None, f_ildj=None):
  """Defines a unary inverse rule.

  Builds an inverse-and-ILDJ rule for a single-input function from `f_inv`
  and `f_ildj` and registers it through `self.def_inverse_and_ildj`.

  Args:
    f_inv: An optional unary function that returns the inverse of this
      function. If not provided (`None`), we automatically invert the forward
      function.
    f_ildj: An optional unary function that computes the ILDJ of this
      function. If it is not provided (`None`), we automatically compute it
      from `f_inv`.

  The registered rule raises `NonInvertibleError` when the output value is
  unknown or when the input value is already known.
  """
  # Compare against None explicitly rather than relying on truthiness: a
  # user-supplied callable object may be falsy (e.g. it defines `__bool__`, or
  # `__len__` returning 0) and must not be silently replaced by the default.
  f_inv_ = core.inverse(self.func) if f_inv is None else f_inv
  if f_ildj is None:
    f_ildj_ = core.ildj(core.inverse(f_inv_), reduce_ildj=False)
  else:
    f_ildj_ = f_ildj

  def f_inverse_and_ildj(invals, outval, outildj):
    inval = invals[0]
    # Only the "output known, input unknown" configuration can be inverted.
    if outval is None or inval is not None:
      raise NonInvertibleError()
    new_inval = f_inv_(outval)
    # Add this function's ILDJ to the ILDJ passed in for the output.
    new_inildj = f_ildj_(outval) + outildj
    return (new_inval,), (new_inildj,)

  self.def_inverse_and_ildj(f_inverse_and_ildj)
def wrapped(*args):
  """Evaluates `target_log_prob` on mapped arguments, minus the summed ILDJ.

  The arguments are passed through `mapping_fn`; the ILDJ terms returned by
  `inverse.ildj(mapping_fn, *args)` at the mapped value are flattened, summed
  and subtracted from the target log-probability of the mapped value.
  """
  transformed = mapping_fn(*args)
  ildj_fn = inverse.ildj(mapping_fn, *args)
  ildj_tree = ildj_fn(transformed)
  log_prob = target_log_prob(transformed)
  # Collapse every leaf of the ILDJ structure into a single scalar term.
  ildj_leaves = tree_util.tree_leaves(ildj_tree)
  total_ildj = np.sum(np.array(ildj_leaves))
  return log_prob - total_ildj
def test_sqrt_ildj(self):
  """Checks the ILDJ of `np.sqrt` at 3 against log(2) + log(3)."""

  def square_root(x):
    return np.sqrt(x)

  ildj_fn = core.ildj(square_root)
  actual = ildj_fn(3.)
  # The inverse of sqrt is y**2, whose derivative 2*y has log value
  # log(2) + log(y).
  expected = np.log(2.) + np.log(3.)
  onp.testing.assert_allclose(actual, expected)
def test_reciprocal_ildj(self):
  """Checks the ILDJ of `np.reciprocal` at 2 against log(1/4)."""

  def reciprocal(x):
    return np.reciprocal(x)

  ildj_fn = core.ildj(reciprocal)
  actual = ildj_fn(2.)
  # The inverse of 1/x is 1/y, whose derivative has magnitude 1/y**2.
  expected = onp.log(1 / 4.)
  onp.testing.assert_allclose(actual, expected)