Exemple #1
0
    def testRematLambdaFunction(self):
        """Gradients of a lambda are unchanged by wrapping it in `remat`."""
        # Kept as a lambda on purpose: per the test name, this exercises
        # `remat` on an anonymous function, so don't convert it to a `def`.
        loss = lambda x, y: tf_np.sum(tf_np.sqrt(tf_np.exp(x)) + y)
        checkpointed = extensions.remat(loss)

        dims = [10]
        x = tf_np.random.randn(*dims)
        y = tf_np.random.randn(*dims)

        grad_with_remat = extensions.grad(checkpointed)(x, y)
        grad_without_remat = extensions.grad(loss)(x, y)
        self.assertAllClose(grad_with_remat, grad_without_remat)
Exemple #2
0
    def testRematJit(self):
        """`jit(grad(remat(f)))` agrees with `jit(grad(f))`."""
        def objective(x, y):
            return tf_np.sum(tf_np.sqrt(tf_np.exp(x)) + y)

        checkpointed = extensions.remat(objective)

        dims = [10]
        # Tuple assignment evaluates left to right: `x` is drawn before `y`.
        x, y = tf_np.random.randn(*dims), tf_np.random.randn(*dims)

        with_remat = extensions.jit(extensions.grad(checkpointed))(x, y)
        without_remat = extensions.jit(extensions.grad(objective))(x, y)
        self.assertAllClose(with_remat, without_remat)
Exemple #3
0
def train_step(params, inputs, targets, learning_rate=0.1):
  """Apply one plain gradient-descent update to a (weight, bias) pair.

  Args:
    params: Indexable whose entries 0 and 1 are the weight and the bias.
    inputs: Forwarded to `loss_fn` as its second argument.
    targets: Forwarded to `loss_fn` as its third argument.
    learning_rate: Step size that scales each gradient before subtraction.

  Returns:
    Tuple `(new_w, new_b)`; only the first two parameters are updated.
  """
  grads = extensions.grad(loss_fn)(params, inputs, targets)
  w, b = params[0], params[1]
  dw, db = grads[0], grads[1]
  return w - dw * learning_rate, b - db * learning_rate
Exemple #4
0
 def testGradNonArrayOutput(self):
   """`grad` raises ValueError when the function returns a non-ndarray."""
   def constant_fn(_):
     return 1.0
   grad_fn = extensions.grad(constant_fn)
   expected_message = r"result .* must be an ndarray"
   with self.assertRaisesWithPredicateMatch(ValueError, expected_message):
     grad_fn(asarray(1.0))
Exemple #5
0
    def testScanGrad(self, jit_scan, jit_f):
        """Gradients of a `scan`-based loss match `scan_reference` and numerics.

        Args:
          jit_scan: If truthy, `extensions.scan` is wrapped in `extensions.jit`
            with argument 0 (the step function) marked static.
          jit_f: If truthy, the scan step function is wrapped in `jit`.
        """
        rng = np.random.RandomState(0)
        # Closed-over constant: not an argument of any differentiated function.
        d = rng.randn(2)

        def step(carry, x):
            assert x.shape == (3, )
            assert carry.shape == (4, )
            y = (tf_np.sum(tf_np.sin(x)) + tf_np.sum(tf_np.sin(carry)) +
                 tf_np.sum(tf_np.sin(d)))
            next_carry = tf_np.sin(carry * y)
            assert y.shape == ()  # pylint: disable=g-explicit-bool-comparison
            return next_carry, y

        if jit_f:
            step = extensions.jit(step)

        scan = (extensions.jit(extensions.scan, static_argnums=(0, ))
                if jit_scan else extensions.scan)

        # Draw order matters for reproducibility: `d`, then `xs`, then `c`.
        xs = tf_np.asarray(rng.randn(5, 3))
        c = tf_np.asarray(rng.randn(4))

        def losses(scan_impl, init, seq):
            # Flatten the final carry and the stacked outputs into one vector.
            final_carry, ys = scan_impl(step, init, seq)
            flat = tf.nest.flatten(
                tf.nest.map_structure(lambda t: tf_np.reshape(t, [-1]),
                                      (final_carry, ys)))
            return tf_np.concatenate(flat)

        def loss(scan_impl, init, seq):
            return tf_np.sum(losses(scan_impl, init, seq))

        ans = extensions.grad(functools.partial(loss, scan))(c, xs)
        expected = extensions.grad(
            functools.partial(loss, scan_reference))(c, xs)
        self.assertDTypesEqual(expected, ans)
        self.assertAllClose(expected, ans)

        theoretical, numerical = tf.test.compute_gradient(
            to_tf_fn(functools.partial(losses, scan)), (c, xs))
        self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4)
Exemple #6
0
    def testRematJitXla(self):
        """Under both XLA `jit` modes, `remat` leaves the gradient unchanged."""
        def objective(x, y):
            return tf_np.sum(tf_np.sqrt(tf_np.exp(x)) + y)

        checkpointed = extensions.remat(objective)

        dims = [10]
        x = tf_np.random.randn(*dims)
        y = tf_np.random.randn(*dims)

        # Each compilation flag is checked in turn; within each round the
        # remat variant is evaluated before the plain one.
        for xla_kwargs in ({"xla_forced_compile": True},
                           {"experimental_compile": True}):
            with_remat = extensions.jit(extensions.grad(checkpointed),
                                        **xla_kwargs)(x, y)
            without_remat = extensions.jit(extensions.grad(objective),
                                           **xla_kwargs)(x, y)
            self.assertAllClose(with_remat, without_remat)
Exemple #7
0
 def testGrad(self):
   """`extensions.grad` matches a gradient computed with `tf.GradientTape`."""
   def f(a, b):
     return math.sum(math.sqrt(math.exp(a)) + b)
   grad_f = extensions.grad(f)

   def check_against_tape(x, y):
     # Watch `x.data` (presumably the backing tf.Tensor) so the tape
     # differentiates with respect to `x` only.
     with tf.GradientTape() as tape:
       tape.watch(x.data)
       out = f(x, y)
     tape_grad = tape.gradient(out.data, x.data)
     self.assertAllEqual(tape_grad, grad_f(x, y))

   dims = [10]
   x = random.randn(*dims)
   y = random.randn(*dims)
   check_against_tape(x, y)
Exemple #8
0
 def testGradNonScalarOutput(self):
   """`grad` rejects non-scalar results, both eagerly and inside `jit`."""
   def identity(x):
     return x
   scalar_message = r"result .* must be a scalar"

   eager_grad = extensions.grad(identity)
   with self.assertRaisesWithPredicateMatch(ValueError, scalar_message):
     eager_grad(asarray([1.0, 2.0]))

   def jitted_grad(x):
     return extensions.grad(identity)(x)
   jitted_grad = extensions.jit(jitted_grad)

   # A scalar input is expected to go through cleanly under `jit`...
   jitted_grad(asarray(1.0))
   # ...while a vector input must raise the same error as in eager mode.
   with self.assertRaisesWithPredicateMatch(ValueError, scalar_message):
     jitted_grad(asarray([1.0, 2.0]))
Exemple #9
0
def _tf_grad(f, **kwargs):
  """Wrap `tf_np_extensions.grad` with an `argnums` option.

  This wrapper assumes `tf_np_extensions.grad` differentiates with respect to
  the first positional argument. To differentiate `f` with respect to
  positional argument `argnums`, it swaps that argument into slot 0 before
  the gradient is taken, and swaps it back before `f` itself runs.

  Args:
    f: Function to differentiate.
    **kwargs: Forwarded to `tf_np_extensions.grad`, except `argnums`
      (default 0), an integer index of the positional argument to
      differentiate with respect to.

  Returns:
    A callable taking the same positional arguments as `f`.
  """
  argnums = kwargs.pop('argnums', 0)
  if argnums == 0:
    return tf_np_extensions.grad(f, **kwargs)

  def _swapped(args):
    # Exchange positions 0 and `argnums`; applying it twice restores order.
    reordered = list(args)
    reordered[0], reordered[argnums] = reordered[argnums], reordered[0]
    return reordered

  def g(*args, **call_kwargs):
    return f(*_swapped(args), **call_kwargs)

  grad_g = tf_np_extensions.grad(g, **kwargs)

  def grad_f(*args, **call_kwargs):
    return grad_g(*_swapped(args), **call_kwargs)

  return grad_f
Exemple #10
0
 def g_jitted(a):
   """Return the gradient of the free name `f`, evaluated at `a`."""
   grad_of_f = extensions.grad(f)
   return grad_of_f(a)
Exemple #11
0
 def grad_origin(c, xs):
     """Gradient of `loss` (with `scan` bound as its first argument) at `(c, xs)`."""
     loss_with_scan = functools.partial(loss, scan)
     return extensions.grad(loss_with_scan)(c, xs)
Exemple #12
0
    def testScanGrad(self, jit_grad, jit_scan, jit_f):
        """Gradients through `scan` under the various `jit`/XLA combinations.

        Args:
          jit_grad: "no_xla" wraps the gradient computation in
            `extensions.jit`; "xla_forced_compile" does the same with
            `xla_forced_compile=True`; any other value leaves it unwrapped.
          jit_scan: "no_xla" wraps `extensions.scan` in `jit` (step function
            static); "xla_forced_compile" currently skips the test; any other
            value uses plain `extensions.scan`.
          jit_f: If truthy, the scan step function is wrapped in `jit`.
        """
        rng = np.random.RandomState(0)
        # Closed-over constant: not an argument of any differentiated function.
        d = rng.randn(2)

        def step(carry, x):
            assert x.shape == (3, )
            assert carry.shape == (4, )
            y = (tf_np.sum(tf_np.sin(x)) + tf_np.sum(tf_np.sin(carry)) +
                 tf_np.sum(tf_np.sin(d)))
            next_carry = tf_np.sin(carry * y)
            assert y.shape == ()  # pylint: disable=g-explicit-bool-comparison
            return next_carry, y

        if jit_f:
            step = extensions.jit(step)

        if jit_scan == "no_xla":
            scan = extensions.jit(extensions.scan, static_argnums=(0, ))
        elif jit_scan == "xla_forced_compile":
            # TODO(b/187107596): Remove `skipTest`
            self.skipTest(
                "Taking gradients of `jit(scan, experimental_compile=True)` triggers "
                "'Support for TensorList crossing the XLA/TF boundary is not "
                "implemented' error")
            # Unreachable while the skip above is in place (`skipTest` raises).
            # `xla_forced_compile=True` doesn't support gradients, so
            # `experimental_compile=True` is used instead.
            scan = extensions.jit(extensions.scan,
                                  static_argnums=(0, ),
                                  experimental_compile=True)
        else:
            scan = extensions.scan

        # Draw order matters for reproducibility: `d`, then `xs`, then `c`.
        xs = tf_np.asarray(rng.randn(5, 3))
        c = tf_np.asarray(rng.randn(4))

        def losses(scan_impl, init, seq):
            # Flatten the final carry and the stacked outputs into one vector.
            final_carry, ys = scan_impl(step, init, seq)
            flat = tf.nest.flatten(
                tf.nest.map_structure(lambda t: tf_np.reshape(t, [-1]),
                                      (final_carry, ys)))
            return tf_np.concatenate(flat)

        def loss(scan_impl, init, seq):
            return tf_np.sum(losses(scan_impl, init, seq))

        def grad_origin(init, seq):
            return extensions.grad(functools.partial(loss, scan))(init, seq)

        # Start from the unwrapped gradient and optionally wrap it in `jit`.
        compute_grad = grad_origin
        if jit_grad == "no_xla":
            compute_grad = extensions.jit(compute_grad)
        elif jit_grad == "xla_forced_compile":
            compute_grad = extensions.jit(compute_grad, xla_forced_compile=True)

        ans = compute_grad(c, xs)
        expected = extensions.grad(
            functools.partial(loss, scan_reference))(c, xs)
        self.assertDTypesEqual(expected, ans)
        self.assertAllClose(expected, ans)

        theoretical, numerical = tf.test.compute_gradient(
            to_tf_fn(functools.partial(losses, scan)), (c, xs))
        self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4)