Example #1
    def test_negative_axes(self):
        # Checks that extensions.vmap accepts negative `in_axes`/`out_axes`.
        # NOTE(review): the lines that open each assertion call (presumably a
        # `self.assert...(` line before each argument group below) are missing
        # from this copy, so the indented argument lines are orphaned and the
        # method does not parse as shown. Confirm against the upstream test.
        x = np.arange(3 * 4 * 5).reshape(3, 4, 5)
        # Each vmap'd sum over a negative in_axes is paired with tf_np.sum over
        # the two axes that are not the mapped one.
            extensions.vmap(tf_np.sum, in_axes=-3)(x), tf_np.sum(x,
                                                                 axis=(1, 2)))
            extensions.vmap(tf_np.sum, in_axes=-2)(x), tf_np.sum(x,
                                                                 axis=(0, 2)))
            extensions.vmap(tf_np.sum, in_axes=-1)(x), tf_np.sum(x,
                                                                 axis=(0, 1)))

        identity = lambda y: y
        # Mapping the identity with in_axes=0 and a negative out_axes moves the
        # mapped axis; for out_axes=-2/-1 the expected values shown are the
        # matching transposes of x.
            extensions.vmap(identity, in_axes=0, out_axes=-3)(x))
            x.transpose(1, 0, 2),
            extensions.vmap(identity, in_axes=0, out_axes=-2)(x))
            x.transpose(1, 2, 0),
            extensions.vmap(identity, in_axes=0, out_axes=-1)(x))

        # in_axes=None broadcasts the scalar 7; with out_axes=-1 for that
        # output, element [1] of the result is expected to be np.full((5,), 7).
            np.full((5, ), 7),
            extensions.vmap(lambda *xs: xs,
                            in_axes=(0, None),
                            out_axes=(0, -1))(np.arange(5), 7)[1])
Example #2
 def f(c, a):
   # Step function: takes a carry `c` of shape (4,) and an element `a` of
   # shape (3,), and returns the new carry and a scalar output `b`.
   assert a.shape == (3,)
   assert c.shape == (4,)
   # NOTE(review): the continuation of this expression is missing (the line
   # ends with `+` and the parenthesis is never closed), so the function does
   # not parse as shown.
   b = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.sin(c)) +
   c = tf_np.sin(c * b)
   assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
   return c, b
Example #3
 def f(c_g, a_e):
   # Step function over a structured carry (c, g) and element (a, e).
   # Returns the new carry [c, g] and the per-step outputs (b, f).
   c, g = c_g
   a, e = a_e
   assert a.shape == (3,)
   assert e.shape == ()  # pylint: disable=g-explicit-bool-comparison
   assert c.shape == (4,)
   assert g.shape == (2,)
   # NOTE(review): the continuation of this expression is missing (the line
   # ends with `+` and tf_np.cos(...) is never closed), so the function does
   # not parse as shown.
   b = tf_np.cos(tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
   f = tf_np.cos(a)
   c = tf_np.sin(c * b)
   g = tf_np.sin(g * b)
   assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
   assert f.shape == (3,)
   return [c, g], (b, f)
Example #4
 def f(a, b, reverse=False):
   """Returns (sum(sqrt(exp(a)) + b), 10[, Thing(20)]), optionally reversed.

   The trailing Thing(20) element is appended only when the enclosing
   `allow_static_outputs` flag is truthy. When `reverse` is truthy the
   tuple's order is reversed.
   """
   total = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
   outputs = [total, 10]
   if allow_static_outputs:
     outputs.append(Thing(20))
   if reverse:
     outputs.reverse()
   return tuple(outputs)
Example #5
 def f(c_g_i, a_e_h):
   # Step function over the nested carry ((c, g), i) and element (a, (e, h)).
   # `i` and `h` must be None placeholders. Returns the same nesting:
   # carry [(c, g), i] and outputs (b, [f, h]).
   c_g, i = c_g_i
   c, g = c_g
   a, e_h = a_e_h
   e, h = e_h
   assert a.shape == (3,)
   assert e.shape == ()  # pylint: disable=g-explicit-bool-comparison
   assert c.shape == (4,)
   assert g.shape == (2,)
   assert i is None
   assert h is None
   # NOTE(review): the continuation of this expression is missing (the line
   # ends with `+` and tf_np.cos(...) is never closed), so the function does
   # not parse as shown.
   b = tf_np.cos(tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
   f = tf_np.cos(a)
   c = tf_np.sin(c * b)
   g = tf_np.sin(g * b)
   assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
   assert f.shape == (3,)
   return [(c, g), i], (b, [f, h])
Example #6
    def testRematLambdaFunction(self):
        """extensions.remat applied to a lambda keeps the lambda's gradients."""
        # Deliberately a lambda: this test covers remat on lambda functions.
        objective = lambda a, b: tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
        rematerialized = extensions.remat(objective)

        dims = (10,)
        a = tf_np.random.randn(*dims)
        b = tf_np.random.randn(*dims)

        self.assertAllClose(extensions.grad(rematerialized)(a, b),
                            extensions.grad(objective)(a, b))
Example #7
 def testJVPOfGradOfIndexing(self):
     """JVP of grad(sum(x[i])) works with a concrete zero index tangent.

     A plain zeros array, not a symbolic zero, is passed as the tangent of
     the integer index `i`. The call should still return a value.
     """
     def indexed_sum(arr, idx):
         return jnp.sum(arr[idx])

     x = jnp.ones((3, 4), jnp.float32)
     i = jnp.ones((3, ), jnp.int32)
     index_tangent = onp.zeros_like(i)
     primals, tangents = api.jvp(api.grad(indexed_sum), (x, i),
                                 (x, index_tangent))
     # Every entry of i is 1, so row 1 of x is gathered three times: its
     # gradient is 3 and the other rows' gradients are 0.
     row_grads = onp.array([0, 3, 0], dtype=onp.float32)
     expected = onp.broadcast_to(row_grads[:, None], (3, 4))
     self.assertAllClose(expected, primals, check_dtypes=True)
     # The gradient is constant in x, so its tangent along x is all zeros.
     self.assertAllClose(onp.zeros_like(x), tangents, check_dtypes=True)
    def evaluate(self, x, y):
        """Returns the number of correct predictions.

        Args:
          x: 2-d array of size batch_size x image_size.
          y: 2-d array of size batch_size x num_classes.

        Returns:
          The number of correct predictions, as a Python int.
        """
        # Class index per example: argmax over the class axis.
        y_actual = np.argmax(y, axis=1)
        y_predicted = np.argmax(self.forward(x), axis=1)
        # Cast the boolean match mask to int32 before summing. np.asarray
        # copies only when the conversion requires it, whereas
        # np.array(..., copy=False) raises under NumPy >= 2 because a
        # bool -> int32 cast always needs a copy.
        matches = np.asarray(y_actual == y_predicted, dtype=np.int32)
        return int(np.sum(matches))
Example #9
 def loss(scan, c, xs):
   """Sums everything returned by `losses(scan, c, xs)` into one scalar."""
   per_element = losses(scan, c, xs)
   return tf_np.sum(per_element)
Example #10
 def f(a, b):
   """Returns the scalar sum of sqrt(exp(a)) + b."""
   root_exp = tf_np.sqrt(tf_np.exp(a))
   return tf_np.sum(root_exp + b)
Example #11
 def f(a, b):
   """Returns sum(sqrt(exp(a)) + b), plus an auxiliary value when `has_aux`.

   `has_aux` is read from the enclosing scope. When it is truthy the result
   is `(y, tf_np.asarray(1))`; otherwise it is just `y`.
   """
   y = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
   if has_aux:
     return y, tf_np.asarray(1)
   # Non-aux path. This return must sit outside the `if` to be reachable;
   # otherwise the function returns None when has_aux is falsy.
   return y
Example #12
class ExtensionsTest(tf.test.TestCase, parameterized.TestCase):
  # Tests for the tf_numpy `extensions` module: grad/vjp/custom gradients,
  # jit, eval_on_shapes, scan, dot/conv/pool wrappers, PRNG helpers and the
  # multi-device pmap/psum/pmean/vmap transforms.
  # NOTE(review): many lines of this class appear to be missing (decorator
  # headers, `else:` lines, `with`-block bodies, expression continuations),
  # so it does not parse as shown. Each gap is flagged where it occurs.

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    # NOTE(review): no call to the base-class __init__ is visible here.
    physical_devices = tf.config.experimental.list_physical_devices("CPU")
    # NOTE(review): the call that consumes `physical_devices[0], [` below
    # (and its closing brackets) is missing. Presumably it splits the first
    # CPU into the logical devices "CPU:0"/"CPU:1" used by _get_two_devices;
    # confirm.
        physical_devices[0], [
    if extensions.tpu_devices():
      # NOTE(review): `resolver` is never used in the visible lines; the call
      # that consumes it appears to be missing.
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")

  def _hasGPU(self):
    # Returns the (possibly empty) list of physical GPU devices; the visible
    # caller uses it as a boolean.
    physical_devices = tf.config.experimental.list_physical_devices("GPU")
    return physical_devices

  def testCustomGrad(self):
    """Test for custom_grad."""
    # `f` pairs the forward value with a hand-written vjp that returns
    # (dy * scale1 * a, dy * scale2 * b). extensions.vjp must return that
    # forward value and exactly those cotangents.
    x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10]))
    y_shape = (tf.TensorShape([]))
    dtype = np.float32
    scale1 = 5.0
    scale2 = 6.0

    def fwd(a, b):
      return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

    def f(a, b):
      y = fwd(a, b)

      def vjp(dy):
        return dy * scale1 * a, dy * scale2 * b

      return y, vjp

    rng = tf.random.Generator.from_seed(1234)
    x, dy = tf.nest.map_structure(lambda shape: uniform(rng, shape, dtype),
                                  [x_shape, y_shape])
    expected_y = fwd(*x)
    expected_dx = (dy * scale1 * x[0], dy * scale2 * x[1])
    y, vjp = extensions.vjp(f, *x)
    dx = vjp(dy)
    self.assertAllClose(to_tf(expected_y), to_tf(y))
    self.assertAllClose(to_tf(expected_dx), to_tf(dx))

  # NOTE(review): the decorator header these parameter tuples belong to
  # (presumably `@parameterized.named_parameters(`) and its closing bracket
  # are missing.
      (  # pylint: disable=g-complex-comprehension
          ("_%s_%s_%s" % (decorator_id, x_struct, y_struct)).replace(
              " ", "").replace("None", ""), decorator, x_struct, y_struct)
      for y_struct in [[None, ()], (None, (), [], (None, ((), None)))]
      for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))]
      for decorator_id, decorator in enumerate([lambda f: f, extensions.jit])
  def testCustomGradStructure(self, decorator, x_struct, y_struct):
    """Tests that custom_grad can handle structured inputs/outputs."""
    # `f` ignores its inputs and returns zeros shaped like `y_struct`. Its
    # vjp checks that `dy` has `y_struct`'s structure and returns zeros
    # shaped like `x_struct`.

    def zeros(x):
      return tf.nest.map_structure(lambda _: tf_np.zeros([], np.float32), x)

    def get_struct(x):
      return tf.nest.map_structure(lambda _: None, x)

    def f(*x):
      del x

      def vjp(dy):
        self.assertEqual(y_struct, get_struct(dy))
        return zeros(x_struct)

      return zeros(y_struct), vjp

    x, dy = zeros([x_struct, y_struct])

    def run(x, dy):
      y, vjp = extensions.vjp(f, *x)
      dx = vjp(dy)
      return dx, y

    # NOTE(review): `decorator` is never applied in the visible lines (e.g.
    # to `run`); that line appears to be missing.
    dx, y = run(x, dy)
    self.assertEqual(x_struct, get_struct(dx))
    self.assertEqual(y_struct, get_struct(y))

  # NOTE(review): the decorator header and closing bracket for these
  # has_aux parameters are missing.
      ("_%s" % has_aux, has_aux) for has_aux in [True, False]
  def testVjp(self, has_aux):
    # Checks extensions.vjp (optionally with has_aux) against
    # tf.GradientTape gradients for two random cotangents.
    x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10]))
    y_shape = (tf.TensorShape([]))
    dtype = np.float32

    def f(a, b):
      y = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
      if has_aux:
        return y, tf_np.asarray(1)
        # NOTE(review): an `else:` line appears to be missing before this
        # return. As written it is unreachable, and f returns None when
        # has_aux is False.
        return y

    rng = tf.random.Generator.from_seed(1234)
    x, dy_list = tf.nest.map_structure(lambda shape: uniform(rng, shape, dtype),
                                       [x_shape, [y_shape] * 2])
    tf_x = to_tf(x)
    outputs = extensions.vjp(f, *x, has_aux=has_aux)
    if has_aux:
      y, vjp, aux = outputs
      # NOTE(review): an `else:` line appears to be missing before the next
      # line. As written the next line always runs and fails to unpack the
      # 3-tuple returned when has_aux is True.
      y, vjp = outputs
    with tf.GradientTape(persistent=True) as tape:
      outputs = f(*x)
      if has_aux:
        expected_y, expected_aux = outputs
        self.assertAllClose(to_tf(expected_aux), to_tf(aux))
        # NOTE(review): an `else:` line appears to be missing before the next
        # line. As written it overwrites expected_y with the (y, aux) pair.
        expected_y = outputs
    self.assertAllClose(to_tf(expected_y), to_tf(y))
    for dy in dy_list:
      expected_dx = tape.gradient(
          to_tf(expected_y), tf_x, output_gradients=to_tf(dy))
      self.assertAllClose(expected_dx, to_tf(vjp(dy)))

  def testGrad(self):
    # Compares extensions.grad(f) with tf.GradientTape applied to the
    # underlying tensors (`.data`).

    def f(a, b):
      return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

    g = extensions.grad(f)

    def compare(a, b):
      with tf.GradientTape() as tape:
        r = f(a, b)
      expected = tape.gradient(r.data, a.data)
      self.assertAllEqual(expected, g(a, b))

    shape = [10]
    a = tf_np.random.randn(*shape)
    b = tf_np.random.randn(*shape)
    compare(a, b)

  def testGradNonArrayOutput(self):

    def f(_):
      return 1.0

    g = extensions.grad(f)
    with self.assertRaisesWithPredicateMatch(ValueError,
                                             r"result .* must be an ndarray"):
      # NOTE(review): the body of this `with` block (presumably a call to
      # `g`) is missing.

  def testGradNonScalarOutput(self):

    def f(a):
      return a

    g = extensions.grad(f)
    with self.assertRaisesWithPredicateMatch(ValueError,
                                             r"result .* must be a scalar"):
      g(tf_np.asarray([1.0, 2.0]))

    # NOTE(review): despite its name, `g_jitted` is not wrapped by
    # extensions.jit in the visible lines; a decorator line may be missing.
    def g_jitted(a):
      return extensions.grad(f)(a)

    with self.assertRaisesWithPredicateMatch(ValueError,
                                             r"result .* must be a scalar"):
      g_jitted(tf_np.asarray([1.0, 2.0]))

  def testJit(self):
    # extensions.jit(f) must agree with the eager f.

    def f(a, b):
      return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

    f_jitted = extensions.jit(f)
    shape = [10]
    a = tf_np.random.randn(*shape)
    b = tf_np.random.randn(*shape)
    self.assertAllClose(f(a, b), f_jitted(a, b))
    # Call again since the code path is different on second call
    self.assertAllClose(f(a, b), f_jitted(a, b))

  def testJitNoUnnecessaryTracing(self):

    def num_traces(f):
      # Number of concrete functions traced so far by the tf.function behind
      # `f`. Relies on a private tf.function method.
      return len(f.tf_function._list_all_concrete_functions_for_serialization())

    def check_trace_only_once(arg1, arg2):

      def f(a):
        return a + 1

      # NOTE(review): as shown, `f` is a plain function (no jit decorator)
      # and is never called between these assertions (arg1/arg2 are
      # unused). The decorator and the calls of `f` appear to be missing.
      self.assertAllEqual(0, num_traces(f))
      self.assertAllEqual(1, num_traces(f))
      self.assertAllEqual(1, num_traces(f))

    check_trace_only_once(1, 2)
    check_trace_only_once(1.1, 2.1)
    check_trace_only_once(tf_np.asarray(1), tf_np.asarray(2))
    # NOTE(review): the opening of the call taking these arguments is missing.
        tf.convert_to_tensor(value=1), tf.convert_to_tensor(value=2))

  def _testEvalOnShapes(self, transformer, allow_static_outputs):

    # A class that's not convertible to tensor
    class Thing:

      def __init__(self, value):
        self.value = value

    def f(a, b, reverse=False):
      res = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
      res = (res, 10)
      if allow_static_outputs:
        res = res + (Thing(20),)
      if reverse:
        res = tuple(reversed(res))
      return res

    f_prime = transformer(
        f, static_argnums=(2,), allow_static_outputs=allow_static_outputs)
    shape = [10]
    dtype = np.float16
    a = tf_np.zeros(shape=shape, dtype=dtype)
    b = tf_np.zeros(shape=shape, dtype=dtype)
    expected, *_ = f(a, b)
    got = f_prime(a, b)
    # NOTE(review): `check` is never called in the visible lines. A
    # `check(got)` after each f_prime call appears to be missing.
    def check(got):
      self.assertIsInstance(got[0], (tf.TensorSpec, tf_np.ndarray))
      self.assertAllEqual(expected.shape, got[0].shape)
      self.assertAllEqual(expected.dtype, got[0].dtype)
      if allow_static_outputs:
        self.assertIsInstance(got[1], int)
        self.assertEqual(10, got[1])
        self.assertIsInstance(got[2], Thing)
        self.assertEqual(20, got[2].value)
        # NOTE(review): an `else:` line appears to be missing before the next
        # two assertions; they contradict the int/Thing checks above.
        self.assertIsInstance(got[1], (tf.TensorSpec, tf_np.ndarray))
        self.assertAllEqual((), got[1].shape)
    # Call again since the code path is different on second call
    got = f_prime(a, b)
    # Retrace and check again
    got = f_prime(a, b, True)
    got = f_prime(a, b, True)

  @parameterized.named_parameters(("_%s" % b, b) for b in [False, True])
  def testEvalOnShapes(self, allow_static_outputs):
    self._testEvalOnShapes(extensions.eval_on_shapes, allow_static_outputs)

  def testEvalOnShapesNested(self):
    transformer = functools.partial(extensions.eval_on_shapes,
    # NOTE(review): the remaining arguments and closing parenthesis of the
    # partial(...) call above are missing. So are the lines that apply
    # `transformer` to `outer`/`inner`; as shown, outer() is a plain call
    # returning 3.
    def outer():
      def inner():
        return 1
      return inner() + 2
    r = outer()
    self.assertIsInstance(r, int)
    self.assertEqual(3, r)

  def testJitOfEvalOnShapes(self):
    """Tests that eval_on_shapes can be called within jit."""
    # The transformer runs eval_on_shapes inside a jitted function and turns
    # the resulting shapes/dtypes back into zero arrays.

    def transformer(f, **kwargs):
      def f_prime(*args):
        res = extensions.eval_on_shapes(f, **kwargs)(*args)
        return tf.nest.map_structure(
            lambda x: tf_np.zeros(x.shape, x.dtype), res)
      return extensions.jit(f_prime, kwargs.get("static_argnums", ()))

    self._testEvalOnShapes(transformer, False)

  def testEvalOnShapesNoUnnecessaryTracing(self):

    def num_traces(f):
      return len(
      # NOTE(review): the argument and closing parenthesis of len(...) are
      # missing.

    def check_trace_only_once(arg1, arg2):

      def f(a):
        return a + 1

      # NOTE(review): as in testJitNoUnnecessaryTracing, the decorator on `f`
      # and the calls of `f` on arg1/arg2 appear to be missing.
      self.assertAllEqual(0, num_traces(f))
      self.assertAllEqual(1, num_traces(f))
      self.assertAllEqual(1, num_traces(f))

    check_trace_only_once(1, 2)
    check_trace_only_once(1.1, 2.1)
    check_trace_only_once(tf_np.asarray(1), tf_np.asarray(2))
    # NOTE(review): the opening of the call taking these arguments is missing.
        tf.convert_to_tensor(value=1), tf.convert_to_tensor(value=2))

  # NOTE(review): the decorator header and the `{` / `}` around each group
  # of three keys ("lhs_np", "rhs_np", "dims") are missing. Each group is
  # presumably one parameter dict for test_tf_dot_general.
          "lhs_np": np.ones((5, 3)),
          "rhs_np": np.ones((3, 2)),
          "dims": (((1,), (0,)), ((), ()))
          "lhs_np": np.ones((5, 3)),
          "rhs_np": np.ones((5, 3)),
          "dims": (((0, 1), (0, 1)), ((), ()))
          "lhs_np": np.ones((5, 3, 2)),
          "rhs_np": np.ones((2, 3, 2)),
          "dims": (((1, 2), (1, 0)), ((), ()))
          "lhs_np": np.ones((6, 5, 3)),
          "rhs_np": np.ones((6, 3, 2)),
          "dims": (((2,), (1,)), ((0,), (0,)))
          "lhs_np": np.ones((6, 3, 5)),
          "rhs_np": np.ones((6, 3, 2)),
          "dims": (((1,), (1,)), ((0,), (0,)))
          "lhs_np": np.ones((5, 3, 2, 2)),
          "rhs_np": np.ones((5, 2, 2, 6)),
          "dims": (((2, 3), (1, 2)), ((0,), (0,)))
          "lhs_np": np.ones((2, 2, 5, 3)),
          "rhs_np": np.ones((2, 2, 3, 2)),
          "dims": (((3,), (2,)), ((0, 1), (0, 1)))
          "lhs_np": np.ones((2, 2, 5, 2)),
          "rhs_np": np.ones((2, 2, 3, 2)),
          "dims": (((3,), (1,)), ((0,), (0,)))
          "lhs_np": np.ones((2, 2, 5, 3, 3)),
          "rhs_np": np.ones((2, 3, 2, 3, 2)),
          "dims": (((4,), (1,)), ((0,), (0,)))
  def test_tf_dot_general(self, lhs_np, rhs_np, dims):
    # extensions.tf_dot_general must match lax.dot_general.
    ans = lax.dot_general(lhs_np, rhs_np, dims)
    result = extensions.tf_dot_general(lhs_np, rhs_np, dims)
    self.assertAllClose(result, np.array(ans))

  # NOTE(review): the decorator header and several closing brackets are
  # missing; e.g. the list opened on the `lhs_dilation, rhs_dilation` line is
  # never closed. As shown, the name template has 5 `{}` placeholders but
  # .format() receives 10 arguments. str.format ignores the extras, so
  # `_perms=` would show lhs_dilation; a template line appears missing.
      ("_lhs_shape={}_rhs_shape={}_strides={}_padding={}"  # pylint: disable=g-complex-comprehension
       "_perms={}".format(lhs_shape, rhs_shape,
                          strides, padding, lhs_dilation, rhs_dilation,
                          feature_group_count, batch_group_count, ",".join(
                              dimension_numbers), perms),
       lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation,
       feature_group_count, batch_group_count, dimension_numbers, perms)
      for batch_group_count, feature_group_count in [(1, 1)]
      for lhs_shape, rhs_shape in [
          ((b * batch_group_count, i * feature_group_count, 9, w),
           (j * feature_group_count * batch_group_count, i, 4, 5))
          for w in [0, 10]
          for b, i, j in itertools.product([2, 3], repeat=3)]
      for strides in [(1, 1), (2, 1)]
      for padding in ["SAME"]
      for lhs_dilation, rhs_dilation in [
          (None, (1, 1))
      for dimension_numbers, perms in [
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0]))
  def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides,
                             padding, lhs_dilation, rhs_dilation,
                             feature_group_count, batch_group_count,
                             dimension_numbers, perms):
    lhs_perm, rhs_perm = perms  # permute to compatible shapes

    lhs = np.transpose(np.ones(lhs_shape), lhs_perm)
    rhs = np.transpose(np.ones(rhs_shape), rhs_perm)

    jax_conv = lax.conv_general_dilated(lhs, rhs, strides, padding,
                                        lhs_dilation, rhs_dilation,
    # NOTE(review): the remaining arguments and closing parenthesis of the
    # call above are missing, and likewise for tf_conv below.

    tf_conv = extensions.tf_conv_general_dilated(lhs, rhs, strides,
                                                 padding, None,
                                                 lhs_dilation, rhs_dilation,

    self.assertAllEqual(tf_conv, tf_np.asarray(jax_conv))

  def testConv(self):
    y = extensions.conv(
        np.ones([5, 320, 480, 3], dtype=np.float32),
        np.ones([3, 4, 3, 11], dtype=np.float32), [1, 1], "SAME",
        ("NHWC", "HWIO", "NHWC"))
    self.assertAllClose(y.shape, [5, 320, 480, 11])
    # NOTE(review): the lines around these keyword arguments are missing;
    # presumably `y` is compared against a TensorFlow reference op called
    # with them.
            input=tf.ones([5, 320, 480, 3], dtype=tf.float32),
            filters=tf.ones([3, 4, 3, 11], dtype=tf.float32),

  def testAvgPool(self):
    y = extensions.avg_pool(np.ones([5, 320, 480, 3]), [3, 5], [2, 3], "VALID")
    # NOTE(review): the lines around these keyword arguments are missing;
    # presumably `y` is compared against a TensorFlow reference pooling op.
            input=tf.ones([5, 320, 480, 3]),
            window_shape=[3, 5],
            strides=[2, 3],

  def testMaxPool(self):
    y = extensions.max_pool(np.ones([5, 320, 480, 3]), [3, 5], [2, 3], "VALID")
    # NOTE(review): the lines around these keyword arguments are missing;
    # presumably `y` is compared against a TensorFlow reference pooling op.
            input=tf.ones([5, 320, 480, 3]),
            window_shape=[3, 5],
            strides=[2, 3],

  def assertDTypesEqual(self, a, b):
    # Asserts that `a` and `b` have equal nested structures of dtypes.
    get_dtype = lambda t: t.dtype
    self.assertEqual(tf.nest.map_structure(get_dtype, a),
                     tf.nest.map_structure(get_dtype, b))

  # NOTE(review): the decorator header for these (jit_scan, jit_f)
  # parameters is missing.
      (f"_{jit_scan}_{jit_f}", jit_scan, jit_f)  # pylint: disable=g-complex-comprehension
      for jit_f in [False, True]
      for jit_scan in [False, True])
  def testScanImpl(self, jit_scan, jit_f):
    # Compares extensions.scan (optionally jitted, with an optionally jitted
    # step function) against scan_reference in both dtypes and values.
    rng = np.random.RandomState(0)

    d = rng.randn(2)
    def f(c, a):
      assert a.shape == (3,)
      assert c.shape == (4,)
      # NOTE(review): the continuation of this expression (the line ends
      # with `+`) is missing; `d` above is otherwise unused in the visible
      # lines.
      b = tf_np.cos(tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
      c = tf_np.sin(c * b)
      assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
      return c, b

    if jit_f:
      f = extensions.jit(f)

    if jit_scan:
      scan = extensions.jit(extensions.scan, (0,))
      # NOTE(review): an `else:` line appears to be missing before the next
      # line. As written `scan` is always the un-jitted extensions.scan, so
      # jit_scan has no effect.
      scan = extensions.scan

    xs = rng.randn(5, 3)
    c = rng.randn(4)

    ans = scan(f, c, xs)
    expected = scan_reference(f, c, xs)
    self.assertDTypesEqual(expected, ans)
    self.assertAllClose(expected, ans)

  def testScanStruct(self):
    # extensions.scan must preserve the list/tuple nesting of its carry and
    # outputs, including the None leaves `i` and `h`.
    rng = np.random.RandomState(0)

    d = rng.randn(2)
    def f(c_g_i, a_e_h):
      c_g, i = c_g_i
      c, g = c_g
      a, e_h = a_e_h
      e, h = e_h
      assert a.shape == (3,)
      assert e.shape == ()  # pylint: disable=g-explicit-bool-comparison
      assert c.shape == (4,)
      assert g.shape == (2,)
      assert i is None
      assert h is None
      # NOTE(review): the continuation of this expression (the line ends
      # with `+`) is missing.
      b = tf_np.cos(tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
      f = tf_np.cos(a)
      c = tf_np.sin(c * b)
      g = tf_np.sin(g * b)
      assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
      assert f.shape == (3,)
      return [(c, g), i], (b, [f, h])

    xs = (rng.randn(5, 3), [rng.randn(5), None])
    init = [(rng.randn(4), rng.randn(2)), None]

    c_g_i, b_f_h = extensions.scan(f, init, xs)
    self.assertIsInstance(c_g_i, list)
    self.assertIsInstance(b_f_h, tuple)
    c_g, i = c_g_i
    c, g = c_g
    self.assertIsInstance(c_g, tuple)
    self.assertEqual((4,), c.shape)
    self.assertEqual((2,), g.shape)
    b, f_h = b_f_h
    f, h = f_h
    self.assertIsInstance(f_h, list)
    self.assertEqual((5,), b.shape)
    self.assertEqual((5, 3), f.shape)

  # NOTE(review): the decorator header for these (jit_scan, jit_f)
  # parameters is missing.
      (f"_{jit_scan}_{jit_f}", jit_scan, jit_f)  # pylint: disable=g-complex-comprehension
      for jit_f in [False, True]
      for jit_scan in [False, True])
  def testScanGrad(self, jit_scan, jit_f):
    # Gradients through extensions.scan must match those through
    # scan_reference and agree with tf.test.compute_gradient's numeric
    # estimate.
    rng = np.random.RandomState(0)

    d = rng.randn(2)
    def f(c, a):
      assert a.shape == (3,)
      assert c.shape == (4,)
      # NOTE(review): the continuation of this expression (the line ends
      # with `+` and the parenthesis is never closed) is missing.
      b = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.sin(c)) +
      c = tf_np.sin(c * b)
      assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
      return c, b

    if jit_f:
      f = extensions.jit(f)

    if jit_scan:
      scan = extensions.jit(extensions.scan, static_argnums=(0,))
      # NOTE(review): an `else:` line appears to be missing before the next
      # line. As written `scan` is always the un-jitted extensions.scan.
      scan = extensions.scan

    xs = tf_np.asarray(rng.randn(5, 3))
    c = tf_np.asarray(rng.randn(4))

    def losses(scan, c, xs):
      c, ys = scan(f, c, xs)
      return tf_np.concatenate(tf.nest.flatten(tf.nest.map_structure(
          lambda a: tf_np.reshape(a, [-1]), (c, ys))))
    def loss(scan, c, xs):
      return tf_np.sum(losses(scan, c, xs))

    ans = extensions.grad(functools.partial(loss, scan))(c, xs)
    expected = extensions.grad(functools.partial(loss, scan_reference))(c, xs)
    self.assertDTypesEqual(expected, ans)
    self.assertAllClose(expected, ans)

    theoretical, numerical = tf.test.compute_gradient(
        to_tf_fn(functools.partial(losses, scan)), (c, xs))
    self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4)

  # NOTE(review): the decorator header and the closing brackets of the
  # enumerate([...]) list are missing. `spec` is not defined in the visible
  # code.
      (f"_{i}", *args)  # pylint: disable=g-complex-comprehension
      for i, args in enumerate([
          (lambda c, x: (c + 1, tf_np.sum(c + x, 0)),
           [spec(2), spec(4, 3, 2)], [spec(2), spec(4, 2)]),
          (lambda c, x: (c + 1, tf_np.sum(c + x, 0)),
           [spec(2), spec(0, 3, 2), 0], [spec(2), spec(0, 2)]),
  def testScanShape(self, f, inputs, expected_outputs):
    # eval_on_shapes over extensions.scan must yield the expected outputs.
    outputs = extensions.eval_on_shapes(
        functools.partial(extensions.scan, f), static_argnums=(2,))(*inputs)
    self.assertAllEqual(expected_outputs, outputs)

  def testPrng(self):
    # extensions.prng(seed) must equal the seed as an int64 array.
    self.assertAllEqual(tf_np.asarray(123, np.int64), extensions.prng(123))

  def testUniform(self):
    # The sample mean should be close to the midpoint of [minval, maxval].
    minval = 0.43
    maxval = 3.10
    shape = [13, 34, 29]
    atol = 0.1
    outputs = extensions.uniform(123, shape, minval=minval, maxval=maxval)
    self.assertAllClose((minval + maxval) / 2.0, np.mean(outputs), atol=atol)

  def testNormal(self):
    # Sample mean should be close to 0 and sample std close to 1.
    shape = [13, 34, 29]
    atol = 0.1
    outputs = extensions.normal(123, shape)
    self.assertAllClose(0, np.mean(outputs), atol=atol)
    self.assertAllClose(1, np.std(outputs), atol=atol)

  def testBernoulli(self):
    # The sample mean should be close to the requested mean.
    mean = 0.23
    shape = [13, 34, 29]
    atol = 0.1
    outputs = extensions.bernoulli(123, mean, shape)
    self.assertAllClose(mean, np.mean(outputs), atol=atol)

  def testBernoulliWrongShape(self):
    # A 2-element mean is incompatible with shape [3] and must be rejected.
    mean = [0.1, 0.2]
    shape = [3]
    with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError,
                                             r"Incompatible shapes"):
      extensions.bernoulli(123, mean, shape)

  def testDatasetAsNumpy(self):
    arrs = extensions.dataset_as_numpy(
        [tf.constant([1, 2]), tf.constant([3, 4])])
    for a in arrs:
      self.assertIsInstance(a, tf_np.ndarray)
    # NOTE(review): the exception-type argument of
    # assertRaisesWithPredicateMatch appears to be missing (only the pattern
    # is passed), and `f` is defined but never invoked inside the `with`.
    with self.assertRaisesWithPredicateMatch(
        r"dataset_as_numpy must be run in eager mode outside tf.function"):

      def f():
        return extensions.dataset_as_numpy([tf.constant([1, 2])])


  def _get_two_devices(self, require_same_type=False):
    # Picks two devices for the multi-device tests. From the visible lines,
    # the choices are a TPU pair, ("CPU:0", "GPU:0") or ("CPU:0", "CPU:1").
    # NOTE(review): several lines appear to be missing:
    #   - an `else:` before the `raise`;
    #   - the end of the ValueError message;
    #   - an `else:` for the non-TPU branch;
    #   - an `else:` before the ("CPU:0", "CPU:1") fallback.
    tpus = extensions.tpu_devices()
    if FLAGS.requires_tpu:
      if len(tpus) == 2:
        res = tpus
        raise ValueError("This test requires 2 TPU cores but %s are found" %
      if len(tpus) == 2:
        res = tpus
      elif self._hasGPU() and not require_same_type:
        res = ("CPU:0", "GPU:0")
        res = ("CPU:0", "CPU:1")
    return res

  def testPmap(self):
    # pmap over two devices with functions that return a tuple, a single
    # array and a one-element list.
    devices = self._get_two_devices()

    @functools.partial(extensions.pmap, devices=devices)
    def return_three(f):
      return f, f + 1.0, f + 2.0

    result = return_three(tf.ones((2, 20)))
    # The function returned 3 items, so we got 3 items back.
    self.assertLen(result, 3)

    # Each of the items should be a ShardedNdarray that when converted to tensor
    # should produce a tensor of shape (2, 20)
    converted = tf.nest.map_structure(tf.convert_to_tensor, result)

    self.assertLen(result, 3)

    self.assertAllEqual(converted[0].shape, converted[1].shape)
    self.assertAllEqual(converted[0].shape, converted[2].shape)

    self.assertAllEqual(converted[0], tf.ones((2, 20)))
    self.assertAllEqual(converted[1], 1 + tf.ones((2, 20)))
    self.assertAllEqual(converted[2], 2 + tf.ones((2, 20)))

    @functools.partial(extensions.pmap, devices=devices)
    def return_one(f):
      return f + 2.0

    result = return_one(tf.ones((2, 20)))

    # Only a single item is returned, so we can convert it directly.
    converted = tf.convert_to_tensor(value=result)
    self.assertAllEqual(converted, 2 + tf.ones((2, 20)))

    @functools.partial(extensions.pmap, devices=devices)
    def return_list(f):
      return [f + 2.0]

    result = return_list(tf.ones((2, 20)))

    # A singleton list is returned.
    self.assertLen(result, 1)
    converted = tf.convert_to_tensor(value=result[0])
    self.assertAllEqual(converted, 2 + tf.ones((2, 20)))

  def testGradSimpleModel(self):
    # Runs 50 plain (non-pmapped) train_step updates and checks the
    # parameters end up near params_true.
    params, params_true, inputs, targets = generate_params_inputs_targets()

    for _ in range(50):
      params = train_step(params, inputs, targets)

    # This is not trained super well, but it usually gets "close".
    self.assertAllClose(params[0], params_true[0], atol=1e-1)
    self.assertAllClose(params[1], params_true[1], atol=1e-1)

  # NOTE: Compare to testGradSimpleModel to see the differences when pmapping.
  def testPmapSimpleModel(self):
    devices = self._get_two_devices(require_same_type=True)
    n_devices = len(devices)

    params, params_true, inputs, targets = generate_params_inputs_targets()

    # Each device trains on its slice of the batch. psum(...) / n_devices
    # then averages the updated parameters across devices.
    def _train_and_reduce(params, inputs, targets, learning_rate=0.1):
      new_w, new_b = train_step(params, inputs, targets, learning_rate)

      return (extensions.psum(new_w) / n_devices,
              extensions.psum(new_b) / n_devices)

    train_step_pmapped = extensions.pmap(_train_and_reduce, devices=devices)

    def replicate(x, num_devices=2):
      return tf_np.broadcast_to(x, (num_devices,) + x.shape)

    params = tf.nest.map_structure(replicate, params)

    def reshape(x, num_devices=2):
      x_shape = list(x.shape)
      batch_size = x_shape[0]
      batch_size_per_device = batch_size // num_devices

      # New shape.
      new_shape_prefix = [num_devices, batch_size_per_device]
      return tf_np.reshape(x, new_shape_prefix + x_shape[1:])

    inputs = tf.nest.map_structure(reshape, inputs)
    targets = tf.nest.map_structure(reshape, targets)

    for _ in range(50):
      params = train_step_pmapped(params, inputs, targets)

    # PMAP returns sharded tensors.

    # Since the inputs are identical, the returned tensors should be identical
    self.assertAllClose(params[0][0], params[0][1])
    self.assertAllClose(params[1][0], params[1][1])

    # This is not trained super well, but it usually gets "close".
    self.assertAllClose(params[0][0], params_true[0], atol=1e-1)
    self.assertAllClose(params[1][0], params_true[1], atol=1e-1)

  def testPsum(self):
    # Each device receives one element of [1, 3]; psum must give 4 on both.
    devices = self._get_two_devices(require_same_type=True)

    def reduce_sum(f):
      return extensions.psum(f)

    data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))
    pmapped = extensions.pmap(reduce_sum, devices=devices)
    result = pmapped(data)

    self.assertAllClose(result[0], 4)
    self.assertAllClose(result[1], 4)

  def testPsumStruct(self):
    # psum over a list of arrays must reduce each leaf and keep tf_np.ndarray
    # leaves.
    devices = self._get_two_devices(require_same_type=True)

    def reduce_sum(a):
      a = extensions.psum(a)
      # NOTE(review): the opening of the call that takes this lambda
      # (presumably `tf.nest.map_structure(`) is missing.
          lambda x: self.assertIsInstance(x, tf_np.ndarray), a)
      return a

    data = [tf_np.asarray([1, 3]), tf_np.asarray([2, 4], np.int64)]
    pmapped = extensions.pmap(reduce_sum, devices=devices)
    result = pmapped(data)

    self.assertIsInstance(result[0][0], tf_np.ndarray)
    self.assertIsInstance(result[0][1], tf_np.ndarray)
    self.assertIsInstance(result[1][0], tf_np.ndarray)
    self.assertIsInstance(result[1][1], tf_np.ndarray)
    self.assertAllClose(result[0][0], 4)
    self.assertAllClose(result[0][1], 4)
    self.assertAllClose(result[1][0], 6)
    self.assertAllClose(result[1][1], 6)

  def testPmean(self):
    # pmean of [1, 3] across the two devices must be 2 on each device.
    if extensions.tpu_devices():
      self.skipTest("pmean for TPU is not supported yet")
    devices = self._get_two_devices(require_same_type=True)

    def reduce_mean(f):
      return extensions.pmean(f)

    data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))
    pmapped = extensions.pmap(reduce_mean, devices=devices)
    result = pmapped(data)

    self.assertAllClose(result[0], 2)
    self.assertAllClose(result[1], 2)

  def testAxisName(self):
    devices = self._get_two_devices(require_same_type=True)

    def reduce_sum(f):
      return extensions.psum(f, axis_name="foo")

    data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))
    pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices)
    # NOTE(review): `pmapped` is never called or checked in the visible lines.

  def testWrongAxisName(self):
    devices = self._get_two_devices(require_same_type=True)

    def reduce_sum(f):
      return extensions.psum(f, axis_name="bar")

    data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))
    with self.assertRaisesWithPredicateMatch(
        ValueError, r"axis_name (.*) is not equal to that of the surrounding"):
      pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices)
      # NOTE(review): `pmapped` is not called inside this `with` in the
      # visible lines; presumably the mismatch error is raised on call.

  def testNoNestedPmap(self):
    devices = self._get_two_devices(require_same_type=True)

    def f(x):
      return x + 1.0

    data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))
    with self.assertRaisesWithPredicateMatch(ValueError,
                                             r"Nested pmap is not supported"):
      f = extensions.pmap(f, devices=devices)
      f = extensions.pmap(f, devices=devices)
      # NOTE(review): the doubly-pmapped `f` is never called (e.g. on `data`)
      # in the visible lines.

  def testVmap(self):
    # vmap over the leading axis: a one-argument function on tf_np and
    # tf.Tensor inputs, and a two-argument function whose per-example inputs
    # broadcast against each other.
    fn1 = extensions.vmap(lambda z: z * z)

    x = tf_np.arange(10)
    self.assertAllClose(x * x, fn1(x))

    y = tf.range(10)
    np_y = tf_np.asarray(y)
    output = fn1(y)
    self.assertIsInstance(output, tf_np.ndarray)
    self.assertAllClose(np_y * np_y, output)

    fn2 = extensions.vmap(lambda x, y: x + y)
    x = tf_np.random.randn(10, 3)
    y = tf_np.random.randn(10, 2, 3)
    self.assertAllClose(tf_np.expand_dims(x, 1) + y, fn2(x, y))
 def mean_squared_error(x, y):
     """Sum of squared differences between x and y, divided by len(x).

     For 1-d inputs this is the mean squared error. For higher-rank inputs
     the squared error is summed over all elements and divided by the size
     of the first axis.
     """
     residual = x - y
     squared_total = np.sum(residual * residual)
     return squared_total / len(x)
Example #14
def _fold_in(rng, d):
    """Folds the data `d` into `rng`; intended as jax.random.fold_in's analogue.

    The sum of `d`, cast to int64, is added to `rng`. That value is split in
    two and the second half is returned as the new key.
    """
    # TODO(lukaszkaiser): verify that this function has good randomness
    # properties or switch to an implementation equivalent to JAX.
    offset = tf_np.sum(d).astype(tf_np.int64)
    _, folded_rng = tf_np_extensions.split(rng + offset, 2)
    return folded_rng