def testRematLambdaFunction(self):
    """Gradients of a lambda wrapped by `extensions.remat` match the plain lambda's."""
    # The function under test is deliberately a lambda rather than a `def`.
    fn = lambda a, b: tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
    fn_remat = extensions.remat(fn)

    dims = [10]
    lhs = tf_np.random.randn(*dims)
    rhs = tf_np.random.randn(*dims)

    # Evaluate the remat-wrapped gradient first, then the reference gradient.
    grad_with_remat = extensions.grad(fn_remat)(lhs, rhs)
    grad_reference = extensions.grad(fn)(lhs, rhs)

    self.assertAllClose(grad_with_remat, grad_reference)
def testRematJit(self):
    """Under `extensions.jit`, the remat-wrapped gradient equals the plain gradient."""

    def f(a, b):
        # Same op sequence as a one-liner: exp -> sqrt -> add -> sum.
        root = tf_np.sqrt(tf_np.exp(a))
        return tf_np.sum(root + b)

    f_with_remat = extensions.remat(f)

    dims = [10]
    lhs = tf_np.random.randn(*dims)
    rhs = tf_np.random.randn(*dims)

    # Build and run the jitted remat gradient before the jitted reference.
    jitted_remat_grad = extensions.jit(extensions.grad(f_with_remat))
    grad_with_remat = jitted_remat_grad(lhs, rhs)

    jitted_plain_grad = extensions.jit(extensions.grad(f))
    grad_reference = jitted_plain_grad(lhs, rhs)

    self.assertAllClose(grad_with_remat, grad_reference)
def testRematJitXla(self):
    """Remat and plain gradients agree when `jit` is given each compile flag.

    The comparison is run twice on the same inputs: once passing
    `xla_forced_compile=True` to `extensions.jit`, then once passing
    `experimental_compile=True`.
    """

    def f(a, b):
        # Same op sequence as a one-liner: exp -> sqrt -> add -> sum.
        root = tf_np.sqrt(tf_np.exp(a))
        return tf_np.sum(root + b)

    f_with_remat = extensions.remat(f)

    dims = [10]
    lhs = tf_np.random.randn(*dims)
    rhs = tf_np.random.randn(*dims)

    # Order matters for an identical trace: xla_forced_compile runs first.
    for flag_name in ("xla_forced_compile", "experimental_compile"):
        jit_kwargs = {flag_name: True}
        grad_with_remat = extensions.jit(
            extensions.grad(f_with_remat), **jit_kwargs
        )(lhs, rhs)
        grad_reference = extensions.jit(
            extensions.grad(f), **jit_kwargs
        )(lhs, rhs)
        self.assertAllClose(grad_with_remat, grad_reference)