Example #1
0
 def f(a, b, reverse=False):
   """Return sum(sqrt(exp(a)) + b) bundled with static extra outputs.

   The result is the tuple (total, 10). When the enclosing-scope flag
   `allow_static_outputs` is truthy, Thing(20) is appended as a third element.
   If `reverse` is true, the element order of the tuple is reversed.
   """
   total = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
   outputs = [total, 10]
   if allow_static_outputs:
     outputs.append(Thing(20))
   if reverse:
     outputs.reverse()
   # Callers receive a tuple, never the intermediate list.
   return tuple(outputs)
Example #2
0
    def testRematLambdaFunction(self):
        """Gradients through remat(lambda) must match gradients of the lambda."""
        # Deliberately a lambda bound to a name (not a `def`): this test
        # exercises remat applied to a lambda function.
        fn = lambda a, b: tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
        rematted_fn = extensions.remat(fn)

        # Draw `x` before `y` so the random stream is consumed in a fixed order.
        x = tf_np.random.randn(10)
        y = tf_np.random.randn(10)

        grad_of_remat = extensions.grad(rematted_fn)(x, y)
        grad_of_plain = extensions.grad(fn)(x, y)
        self.assertAllClose(grad_of_remat, grad_of_plain)
Example #3
0
 def f(a, b):
   """Return tf_np.sum(sqrt(exp(a)) + b)."""
   # Keep sqrt(exp(.)) as written rather than exp(a / 2): same math,
   # but the exact op sequence is preserved for identical numerics.
   root_of_exp = tf_np.sqrt(tf_np.exp(a))
   return tf_np.sum(root_of_exp + b)
Example #4
0
 def f(a, b):
   """Return sum(sqrt(exp(a)) + b), optionally paired with an aux value.

   When the enclosing-scope flag `has_aux` is truthy, the result is the pair
   (total, tf_np.asarray(1)); otherwise only the total is returned.
   """
   total = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
   if not has_aux:
     return total
   return total, tf_np.asarray(1)