def test_trivial_noninvertible(self):
  """A constant function discards its input, so inverting it must fail."""

  def constant(x):
    del x  # The output does not depend on the input at all.
    return 1.

  with self.assertRaises(ValueError):
    inverse_of_constant = core.inverse(constant)
    inverse_of_constant(1.)
def test_inverse_with_tuple_inputs(self):
  """Custom inverse rules work when the first input is a tuple of params."""

  @custom_inverse.custom_inverse
  def dense(params, x):
    w, b = params
    return w * x + b

  def dense_ildj(invals, out, out_ildj):
    (w, b), _ = invals
    # Inverting with respect to `x` needs both parameters to be known.
    if w is None or b is None:
      raise custom_inverse.NonInvertibleError()
    in_ildj = core.ildj(lambda x: w * x + b)(out)
    return (((w, b), (out - b) / w),
            ((jnp.zeros_like(w), jnp.zeros_like(b)), out_ildj + in_ildj))

  dense_apply = functools.partial(dense, (2., 1.))
  # Before any rule is registered the inverse is still computed correctly.
  self.assertEqual(core.inverse(dense_apply)(5.), 2.)
  dense.def_inverse_and_ildj(dense_ildj)
  self.assertEqual(core.inverse(dense_apply)(5.), 2.)

  def dense_ildj2(invals, out, out_ildj):
    (w, b), _ = invals
    # `w` is used for the inverse and both `w` and `b` feed the zero ILDJ
    # terms, so the rule is inapplicable if EITHER is unknown (previously
    # `and`, which let a half-known tuple fall through to `w * out` /
    # `jnp.zeros_like(None)` instead of raising NonInvertibleError).
    if w is None or b is None:
      raise custom_inverse.NonInvertibleError()
    in_ildj = jnp.ones_like(out)
    return (((w, b), w * out),
            ((jnp.zeros_like(w), jnp.zeros_like(b)), out_ildj + in_ildj))

  # Deliberately not the true inverse: the assertions below can only pass
  # if this newly registered rule is the one in effect.
  dense.def_inverse_and_ildj(dense_ildj2)
  self.assertEqual(core.inverse(dense_apply)(5.), 10.)
  self.assertEqual(core.ildj(dense_apply)(5.), 1.)
def test_pow_inverse(self):
  """lax.pow can be inverted with respect to either the base or exponent."""
  invert_base = core.inverse(lambda base: lax.pow(base, 2.))
  onp.testing.assert_allclose(invert_base(2.), np.sqrt(2.))

  invert_exponent = core.inverse(lambda exponent: lax.pow(2., exponent))
  onp.testing.assert_allclose(invert_exponent(3.), np.log(3.) / np.log(2.))
def test_div_inverse(self):
  """Division inverts whether the unknown is the numerator or denominator."""

  def halve(x):
    return x / 2.

  def two_over(x):
    return 2. / x

  # Both functions map 2. to 1., so their inverses map 1. back to 2.
  for forward in (halve, two_over):
    onp.testing.assert_allclose(core.inverse(forward)(1.0), 2.)
def test_simple_inverse(self):
  """exp is inverted for scalar and vector example inputs."""
  # (example input passed to core.inverse, output to invert, expected input)
  cases = (
      (0.1, 1.0, 0.),
      (np.zeros(2), np.ones(2), np.zeros(2)),
  )
  for example_input, output, expected in cases:
    inverse = core.inverse(lambda x: np.exp(x), example_input)
    onp.testing.assert_allclose(inverse(output), expected)
def test_trivial_inverse(self):
  """The identity is its own inverse, for single and multiple inputs."""

  def identity(x):
    return x

  def identity_pair(x, y):
    return x, y

  onp.testing.assert_allclose(core.inverse(identity)(1.0), 1.0)
  onp.testing.assert_allclose(
      core.inverse(identity_pair)(1.0, 2.0), (1.0, 2.0))
def test_mul_inverse(self):
  """Multiplication by a constant inverts regardless of operand order."""

  def double_right(x):
    return x * 2.

  def double_left(x):
    return 2. * x

  # Both forms map 0.5 to 1., so their inverses map 1. back to 0.5.
  for forward in (double_right, double_left):
    onp.testing.assert_allclose(core.inverse(forward)(1.0), 0.5)
def test_integer_pow_inverse(self):
  """lax.integer_pow inverts to the matching root."""

  def square(x):
    return lax.integer_pow(x, 2)

  def cube(x):
    return lax.integer_pow(x, 3)

  onp.testing.assert_allclose(core.inverse(square)(2.), np.sqrt(2.))
  onp.testing.assert_allclose(core.inverse(cube)(2.), onp.cbrt(2.))
def test_noninvertible_error_should_cause_unary_inverse_to_fail(self):
  """A unary rule raising NonInvertibleError makes inversion raise ValueError."""

  @custom_inverse.custom_inverse
  def add_one(x):
    return x + 1.

  def refuse_to_invert(unused_y):
    raise custom_inverse.NonInvertibleError()

  add_one.def_inverse_unary(refuse_to_invert)
  with self.assertRaises(ValueError):
    core.inverse(add_one)(1.)
def test_unary_inverse(self):
  """A custom_inverse function with no registered rule is still invertible."""

  @custom_inverse.custom_inverse
  def add_one(x):
    return x + 1.

  recovered = core.inverse(add_one)(1.)
  self.assertEqual(recovered, 0.)
def test_advanced_inverse_two(self):
  """Inverts two outputs where the second mixes both inputs."""

  def f(x, y):
    return np.exp(x), x**2 + y

  # x comes from the first output; y is then solved from the second.
  expected_x = np.log(2.)
  expected_y = 2 - expected_x**2
  inverse = core.inverse(f, 0.1, 0.2)
  onp.testing.assert_allclose(inverse(2.0, 2.0), (expected_x, expected_y))
def test_advanced_inverse_three(self):
  """Inverts three outputs, each solvable once the previous input is known."""

  def f(x, y, z):
    return np.exp(x), x ** 2 + y, np.exp(z + y)

  log_two = np.log(2.)
  expected_y = 2 - log_two ** 2
  expected = (log_two, expected_y, np.log(2.0) - expected_y)
  inverse = core.inverse(f, 0., 0., 0.)
  onp.testing.assert_allclose(inverse(2.0, 2.0, 2.0), expected)
def test_logit_ildj(self):
  """The naive logit cannot be inverted at -100; jax's logit can."""

  def naive_logit(x):
    # This is the default JAX implementation of logit.
    return np.log(x / (1. - x))

  naive_transforms = (core.inverse(naive_logit), core.ildj(naive_logit))
  for naive_transform in naive_transforms:
    with self.assertRaises(ValueError):
      naive_transform(-100.)

  logit = jax.scipy.special.logit
  onp.testing.assert_allclose(
      core.inverse(logit)(-100.), jax.scipy.special.expit(-100.))
  onp.testing.assert_allclose(
      core.ildj(logit)(-100.),
      tfb.Sigmoid().forward_log_det_jacobian(-100., 0))
def test_sigmoid_ildj(self):
  """Naive sigmoid disagrees with references at 0.9999; jax.nn.sigmoid agrees."""

  def naive_sigmoid(x):
    # This is the default JAX implementation of sigmoid.
    return 1. / (1 + np.exp(-x))

  point = 0.9999
  reference_inverse = jax.scipy.special.logit(point)
  reference_ildj = tfb.Sigmoid().inverse_log_det_jacobian(point, 0)

  naive_inv = core.inverse(naive_sigmoid)
  naive_ildj = core.ildj(naive_sigmoid)
  with self.assertRaises(AssertionError):
    onp.testing.assert_allclose(naive_inv(point), reference_inverse)
  with self.assertRaises(AssertionError):
    onp.testing.assert_allclose(naive_ildj(point), reference_ildj)

  onp.testing.assert_allclose(
      core.inverse(jax.nn.sigmoid)(point), reference_inverse)
  onp.testing.assert_allclose(
      core.ildj(jax.nn.sigmoid)(point), reference_ildj)
def test_noninvertible_error_should_cause_binary_inverse_to_fail(self):
  """A binary rule raising NonInvertibleError makes that direction fail."""

  @custom_inverse.custom_inverse
  def add(x, y):
    return x + y

  def add_ildj(invals, z, z_ildj):
    x, y = invals
    if z is None:
      raise custom_inverse.NonInvertibleError()
    if x is not None and y is None:
      return (x, z - x), (jnp.zeros_like(z_ildj), z_ildj)
    elif x is None and y is not None:
      # Cannot invert if we don't know x
      raise custom_inverse.NonInvertibleError()
    return (None, None), (0., 0.)

  add.def_inverse_and_ildj(add_ildj)
  # The supported direction (x known, solve for y) must produce the right
  # value, not merely avoid raising; previously the result was unchecked.
  self.assertEqual(core.inverse(lambda x: add(1., x))(2.), 1.)
  with self.assertRaises(ValueError):
    core.inverse(lambda x: add(x, 1.))(2.)
def test_unary_ildj(self):
  """A custom unary ILDJ is used while the inverse is derived automatically."""

  @custom_inverse.custom_inverse
  def add_one(x):
    return x + 1.

  def sentinel_ildj(y):
    del y  # A fixed value makes it obvious this rule was used.
    return 4.

  add_one.def_inverse_unary(f_ildj=sentinel_ildj)
  self.assertEqual(core.inverse(add_one)(2.), 1.)
  self.assertEqual(core.ildj(add_one)(2.), 4.)
def test_binary_inverse_and_ildj(self):
  """A binary rule can invert either argument and can be replaced later."""

  @custom_inverse.custom_inverse
  def add(x, y):
    return x + y

  def add_ildj(invals, z, z_ildj):
    x, y = invals
    if z is None:
      raise custom_inverse.NonInvertibleError()
    x_known = x is not None
    y_known = y is not None
    if x_known and not y_known:
      # Solve for y; the already-known x gets a zero ILDJ.
      return (x, z - x), (jnp.zeros_like(z_ildj), z_ildj)
    if y_known and not x_known:
      # Solve for x; the already-known y gets a zero ILDJ.
      return (z - y, y), (z_ildj, jnp.zeros_like(z_ildj))
    return (None, None), (0., 0.)

  add_one_left = functools.partial(add, 1.)

  def add_one_right(x):
    return add(x, 1.)

  # Inversion works before any rule is registered.
  self.assertEqual(core.inverse(add_one_left)(2.), 1.)
  add.def_inverse_and_ildj(add_ildj)
  self.assertEqual(core.inverse(add_one_left)(2.), 1.)
  self.assertEqual(core.inverse(add_one_right)(2.), 1.)

  def add_ildj2(invals, z, z_ildj):
    # Deliberately "wrong" values: the assertions below can only pass if
    # this replacement rule is the one in effect.
    x, y = invals
    if z is None:
      raise custom_inverse.NonInvertibleError()
    x_known = x is not None
    y_known = y is not None
    if x_known and not y_known:
      return (x, z**2), (jnp.zeros_like(z_ildj), 3. + z_ildj)
    if y_known and not x_known:
      return (z / 4, y), (2. + z_ildj, jnp.zeros_like(z_ildj))
    return (None, None), (0., 0.)

  add.def_inverse_and_ildj(add_ildj2)
  self.assertEqual(core.inverse(add_one_left)(2.), 4.)
  self.assertEqual(core.ildj(add_one_left)(2.), 3.)
  self.assertEqual(core.inverse(add_one_right)(2.), 0.5)
  self.assertEqual(core.ildj(add_one_right)(2.), 2.)
def test_pow_ildj(self):
  """ILDJ of lax.pow matches references when inverting base or exponent."""

  def f(x, y):
    return lax.pow(x, y)

  f_x_ildj = core.ildj(lambda x: f(x, 2.))
  # Pass event_ndims=0 explicitly, matching the other bijector comparisons;
  # older TFP releases made this argument mandatory.
  onp.testing.assert_allclose(
      f_x_ildj(3.), tfb.Power(2.).inverse_log_det_jacobian(3., 0))

  f_y_ildj = core.ildj(lambda y: f(2., y))
  f_y_inv = core.inverse(lambda y: f(2., y))
  # Preimage of 3. under y -> 2**y; renamed so it no longer shadows the
  # lambda parameters named `y`.
  preimage = f_y_inv(3.)
  # ILDJ at an output equals -log|f'| evaluated at its preimage ...
  onp.testing.assert_allclose(
      f_y_ildj(3.), -np.log(np.abs(jax.grad(lambda y: f(2., y))(preimage))))
  # ... and, equivalently, log|(f^-1)'| evaluated at the output.
  onp.testing.assert_allclose(
      f_y_ildj(3.), np.log(np.abs(jax.grad(f_y_inv)(3.))))
def def_inverse_unary(self, f_inv=None, f_ildj=None):
  """Registers an inverse/ILDJ rule for a single-input function.

  Args:
    f_inv: Optional unary callable mapping an output back to its input. When
      omitted (or falsy), the forward function `self.func` is inverted
      automatically.
    f_ildj: Optional unary callable giving the ILDJ at an output. When
      omitted (or falsy), it is derived automatically from the inverse.
  """
  inverse_fn = f_inv or core.inverse(self.func)
  # Re-inverting `inverse_fn` yields a forward map whose ILDJ is computed
  # per-element (reduce_ildj=False).
  ildj_fn = f_ildj or core.ildj(core.inverse(inverse_fn), reduce_ildj=False)

  def f_inverse_and_ildj(invals, outval, outildj):
    inval = invals[0]
    # Applicable only when the output is known and the input is not yet
    # known; every other combination is reported as non-invertible.
    if outval is None or inval is not None:
      raise NonInvertibleError()
    recovered_input = inverse_fn(outval)
    accumulated_ildj = ildj_fn(outval) + outildj
    return (recovered_input,), (accumulated_ildj,)

  self.def_inverse_and_ildj(f_inverse_and_ildj)
def test_unary_inverse_and_ildj(self):
  """A custom inverse and custom ILDJ both override automatic derivation."""

  @custom_inverse.custom_inverse
  def add_one(x):
    return x + 1.

  # Intentionally not the true inverse of `add_one`, so the assertion below
  # can only pass if the registered inverse is used.
  def fake_inverse(y):
    return jnp.exp(y)

  def sentinel_ildj(y):
    del y
    return 4.

  add_one.def_inverse_unary(fake_inverse, sentinel_ildj)
  self.assertEqual(core.inverse(add_one)(2.), jnp.exp(2.))
  self.assertEqual(core.ildj(add_one)(2.), 4.)
def test_conditional_inverse(self):
  """Solves y from the second output once x is known from the first."""

  def f(x, y):
    return x + 1., np.exp(x + 1.) + y

  inverse = core.inverse(f, 0., 2.)
  expected_inputs = (-1., 1.)
  onp.testing.assert_allclose(inverse(0., 2.), expected_inputs)
def test_sqrt_inverse(self):
  """The inverse of sqrt squares its argument."""
  sqrt_inverse = core.inverse(lambda x: np.sqrt(x))
  onp.testing.assert_allclose(sqrt_inverse(2.), 4.)
def test_reciprocal_inverse(self):
  """The reciprocal is its own inverse."""
  reciprocal_inverse = core.inverse(lambda x: np.reciprocal(x))
  onp.testing.assert_allclose(reciprocal_inverse(2.), 0.5)
def test_sow_happens_in_forward_pass(self):
  """Values sown inside the function are reaped when its inverse runs."""

  def f(x, y):
    tagged_x = harvest.sow(x, name='x', tag='foo')
    return x, tagged_x * y

  reaped = harvest.reap(core.inverse(f), tag='foo')(1., 1.)
  self.assertDictEqual(reaped, {'x': 1.})
def test_noninvertible(self):
  """Both outputs equal x + y, so x and y cannot be separated."""

  def f(x, y):
    return x + y, x + y

  with self.assertRaises(ValueError):
    inverse = core.inverse(f)
    inverse(1., 2.)