示例#1
0
    def testRematLambdaFunction(self):
        """Check that remat of a lambda yields the same gradient as the lambda."""
        # Kept as a lambda on purpose: this case covers remat of an anonymous
        # function rather than a named `def`.
        fn = lambda x, y: tf_np.sum(tf_np.sqrt(tf_np.exp(x)) + y)  # noqa: E731
        fn_remat = extensions.remat(fn)

        dims = (10,)
        x = tf_np.random.randn(*dims)
        y = tf_np.random.randn(*dims)

        grad_with_remat = extensions.grad(fn_remat)(x, y)
        grad_without_remat = extensions.grad(fn)(x, y)
        self.assertAllClose(grad_with_remat, grad_without_remat)
示例#2
0
    def testRematJit(self):
        """Check that jitted gradients agree with and without remat."""

        def loss(x, y):
            return tf_np.sum(tf_np.sqrt(tf_np.exp(x)) + y)

        loss_remat = extensions.remat(loss)

        dims = (10,)
        x = tf_np.random.randn(*dims)
        y = tf_np.random.randn(*dims)

        jitted_grad_remat = extensions.jit(extensions.grad(loss_remat))
        jitted_grad_plain = extensions.jit(extensions.grad(loss))
        self.assertAllClose(jitted_grad_remat(x, y), jitted_grad_plain(x, y))
示例#3
0
    def testRematJitXla(self):
        """Check that compiled gradients agree with and without remat.

        The comparison is repeated once with ``xla_forced_compile=True`` and
        once with ``experimental_compile=True`` passed to ``extensions.jit``.
        """

        def loss(x, y):
            return tf_np.sum(tf_np.sqrt(tf_np.exp(x)) + y)

        loss_remat = extensions.remat(loss)

        dims = (10,)
        x = tf_np.random.randn(*dims)
        y = tf_np.random.randn(*dims)

        # Same comparison for each compile flag, in the original order.
        for compile_flag in ('xla_forced_compile', 'experimental_compile'):
            jit_kwargs = {compile_flag: True}
            grad_with_remat = extensions.jit(
                extensions.grad(loss_remat), **jit_kwargs)(x, y)
            grad_without_remat = extensions.jit(
                extensions.grad(loss), **jit_kwargs)(x, y)
            self.assertAllClose(grad_with_remat, grad_without_remat)