def testRematLambdaFunction(self):
  f = lambda a, b: tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
  f_remat = extensions.remat(f)

  shape = [10]
  a = tf_np.random.randn(*shape)
  b = tf_np.random.randn(*shape)

  actual = extensions.grad(f_remat)(a, b)
  expected = extensions.grad(f)(a, b)
  self.assertAllClose(actual, expected)

def testRematJit(self):

  def f(a, b):
    return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

  f_remat = extensions.remat(f)

  shape = [10]
  a = tf_np.random.randn(*shape)
  b = tf_np.random.randn(*shape)

  actual = extensions.jit(extensions.grad(f_remat))(a, b)
  expected = extensions.jit(extensions.grad(f))(a, b)
  self.assertAllClose(actual, expected)

def train_step(params, inputs, targets, learning_rate=0.1):
  grad_fn = extensions.grad(loss_fn)
  grads = grad_fn(params, inputs, targets)
  new_w = params[0] - (grads[0] * learning_rate)
  new_b = params[1] - (grads[1] * learning_rate)
  return new_w, new_b

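# The `loss_fn` referenced by `train_step` above is defined elsewhere and is
# not shown in this excerpt. A minimal sketch of a compatible loss, assuming
# params == (w, b) for a linear model with a mean-squared-error objective
# (the name, model, and formula here are assumptions, not taken from the
# original source):
def loss_fn_sketch(params, inputs, targets):
  w, b = params
  predictions = tf_np.dot(inputs, w) + b
  return tf_np.mean((predictions - targets)**2)

# Hypothetical usage: repeatedly apply train_step to refine (w, b), e.g.
#   params = train_step(params, inputs, targets, learning_rate=0.1)
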
def testGradNonArrayOutput(self):

  def f(_):
    return 1.0

  g = extensions.grad(f)
  with self.assertRaisesWithPredicateMatch(
      ValueError, r"result .* must be an ndarray"):
    g(asarray(1.0))

def testScanGrad(self, jit_scan, jit_f):
  rng = np.random.RandomState(0)

  d = rng.randn(2)

  def f(c, a):
    assert a.shape == (3,)
    assert c.shape == (4,)
    b = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.sin(c)) +
         tf_np.sum(tf_np.sin(d)))
    c = tf_np.sin(c * b)
    assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
    return c, b

  if jit_f:
    f = extensions.jit(f)
  if jit_scan:
    scan = extensions.jit(extensions.scan, static_argnums=(0,))
  else:
    scan = extensions.scan

  xs = tf_np.asarray(rng.randn(5, 3))
  c = tf_np.asarray(rng.randn(4))

  def losses(scan, c, xs):
    c, ys = scan(f, c, xs)
    return tf_np.concatenate(
        tf.nest.flatten(
            tf.nest.map_structure(lambda a: tf_np.reshape(a, [-1]), (c, ys))))

  def loss(scan, c, xs):
    return tf_np.sum(losses(scan, c, xs))

  ans = extensions.grad(functools.partial(loss, scan))(c, xs)
  expected = extensions.grad(functools.partial(loss, scan_reference))(c, xs)
  self.assertDTypesEqual(expected, ans)
  self.assertAllClose(expected, ans)

  theoretical, numerical = tf.test.compute_gradient(
      to_tf_fn(functools.partial(losses, scan)), (c, xs))
  self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4)

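# `scan_reference` and `to_tf_fn` are helpers defined elsewhere in the test
# module and are not shown in this excerpt. A minimal sketch of a pure-Python
# reference scan (an assumption about its behavior, not the original
# implementation) would thread the carry through the leading axis of `xs` and
# stack the per-step outputs:
def scan_reference_sketch(f, init, xs):
  carry = init
  ys = []
  for i in range(xs.shape[0]):
    carry, y = f(carry, xs[i])
    ys.append(y)
  return carry, tf_np.stack(ys)
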
def testRematJitXla(self):

  def f(a, b):
    return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

  f_remat = extensions.remat(f)

  shape = [10]
  a = tf_np.random.randn(*shape)
  b = tf_np.random.randn(*shape)

  actual = extensions.jit(
      extensions.grad(f_remat), xla_forced_compile=True)(a, b)
  expected = extensions.jit(extensions.grad(f), xla_forced_compile=True)(a, b)
  self.assertAllClose(actual, expected)

  actual = extensions.jit(
      extensions.grad(f_remat), experimental_compile=True)(a, b)
  expected = extensions.jit(extensions.grad(f), experimental_compile=True)(a, b)
  self.assertAllClose(actual, expected)

def testGrad(self):

  def f(a, b):
    return math.sum(math.sqrt(math.exp(a)) + b)

  g = extensions.grad(f)

  def compare(a, b):
    with tf.GradientTape() as tape:
      tape.watch(a.data)
      r = f(a, b)
    expected = tape.gradient(r.data, a.data)
    self.assertAllEqual(expected, g(a, b))

  shape = [10]
  a = random.randn(*shape)
  b = random.randn(*shape)
  compare(a, b)

def testGradNonScalarOutput(self):

  def f(a):
    return a

  g = extensions.grad(f)
  with self.assertRaisesWithPredicateMatch(
      ValueError, r"result .* must be a scalar"):
    g(asarray([1.0, 2.0]))

  @extensions.jit
  def g_jitted(a):
    return extensions.grad(f)(a)

  g_jitted(asarray(1.0))
  with self.assertRaisesWithPredicateMatch(
      ValueError, r"result .* must be a scalar"):
    g_jitted(asarray([1.0, 2.0]))

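# Illustration (not part of the original tests): since extensions.grad requires
# the differentiated function to return a scalar ndarray, a vector-valued
# function is typically reduced, e.g. with tf_np.sum, before differentiation.
# The wrapper below is a hypothetical example of that pattern.
def _grad_of_reduced(f):
  return extensions.grad(lambda a: tf_np.sum(f(a)))
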
def _tf_grad(f, **kwargs):
  """Grad with support for argnums."""
  argnums = kwargs.pop('argnums', 0)
  if argnums != 0:

    def g(*args, **kwargs):
      args = list(args)
      args[0], args[argnums] = args[argnums], args[0]
      return f(*args, **kwargs)
  else:
    g = f
  grad_g = tf_np_extensions.grad(g, **kwargs)
  if argnums == 0:
    return grad_g

  def grad_f(*args, **kwargs):
    args = list(args)
    args[0], args[argnums] = args[argnums], args[0]
    return grad_g(*args, **kwargs)

  return grad_f

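# Hypothetical usage of _tf_grad (not from the original source): with
# argnums=1 the wrapper swaps arguments 0 and 1 before and after calling
# tf_np_extensions.grad, so the returned function differentiates with respect
# to the second argument while keeping the original calling convention.
def _dot_sum(x, y):
  return tf_np.sum(x * y)

d_dy = _tf_grad(_dot_sum, argnums=1)
# d_dy(x, y) is the gradient of sum(x * y) with respect to y, i.e. x.
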
def testScanGrad(self, jit_grad, jit_scan, jit_f):
  rng = np.random.RandomState(0)

  d = rng.randn(2)

  def f(c, a):
    assert a.shape == (3,)
    assert c.shape == (4,)
    b = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.sin(c)) +
         tf_np.sum(tf_np.sin(d)))
    c = tf_np.sin(c * b)
    assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
    return c, b

  if jit_f:
    f = extensions.jit(f)
  if jit_scan == "no_xla":
    scan = extensions.jit(extensions.scan, static_argnums=(0,))
  elif jit_scan == "xla_forced_compile":
    # TODO(b/187107596): Remove `skipTest`
    self.skipTest(
        "Taking gradients of `jit(scan, experimental_compile=True)` triggers "
        "'Support for TensorList crossing the XLA/TF boundary is not "
        "implemented' error")
    # `xla_forced_compile=True` doesn't support gradients, so we use
    # `experimental_compile=True`.
    scan = extensions.jit(
        extensions.scan, static_argnums=(0,), experimental_compile=True)
  else:
    scan = extensions.scan

  xs = tf_np.asarray(rng.randn(5, 3))
  c = tf_np.asarray(rng.randn(4))

  def losses(scan, c, xs):
    c, ys = scan(f, c, xs)
    return tf_np.concatenate(
        tf.nest.flatten(
            tf.nest.map_structure(lambda a: tf_np.reshape(a, [-1]), (c, ys))))

  def loss(scan, c, xs):
    return tf_np.sum(losses(scan, c, xs))

  def grad_origin(c, xs):
    return extensions.grad(functools.partial(loss, scan))(c, xs)

  if jit_grad == "no_xla":
    grad_jit = extensions.jit(grad_origin)
  elif jit_grad == "xla_forced_compile":
    grad_jit = extensions.jit(grad_origin, xla_forced_compile=True)
  else:
    grad_jit = grad_origin

  ans = grad_jit(c, xs)
  expected = extensions.grad(functools.partial(loss, scan_reference))(c, xs)
  self.assertDTypesEqual(expected, ans)
  self.assertAllClose(expected, ans)

  theoretical, numerical = tf.test.compute_gradient(
      to_tf_fn(functools.partial(losses, scan)), (c, xs))
  self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4)