Example #1
  def testAdvancedIndexingManually(self):
    x = onp.random.RandomState(0).randn(3, 4, 5)
    index_array = onp.array([0, 2, -1, 0])

    op = lambda x, index_array: x[..., index_array, :]
    cop = npe.jit(op)

    a1 = op(x, index_array)
    a2 = cop(x, index_array)

    self.assertAllClose(a1, a2, check_dtypes=True)

    op = lambda x, index_array: x[..., index_array, :, index_array, None]
    cop = npe.jit(op)

    a1 = op(x, index_array)
    a2 = cop(x, index_array)

    self.assertAllClose(a1, a2, check_dtypes=True)

    op = lambda x, index_array: x[index_array, ..., index_array[:, None], None]
    cop = npe.jit(op)

    a1 = op(x, index_array)
    a2 = cop(x, index_array)

    self.assertAllClose(a1, a2, check_dtypes=True)
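For reference, the shape semantics of the first op above can be checked directly with plain NumPy; this is a small sketch (the onp alias matches the examples), and the jitted result is expected to agree:

import numpy as onp

x = onp.random.RandomState(0).randn(3, 4, 5)
index_array = onp.array([0, 2, -1, 0])
# The ellipsis covers axis 0, index_array gathers four slices along axis 1,
# and ':' keeps axis 2, so the result has shape (3, 4, 5).
assert x[..., index_array, :].shape == (3, 4, 5)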
Example #2
  def testScanImpl(self, jit_scan, jit_f):
    rng = np.random.RandomState(0)

    d = rng.randn(2)
    def f(c, a):
      assert a.shape == (3,)
      assert c.shape == (4,)
      b = tf_np.cos(tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
                    tf_np.sum(tf_np.tan(d)))
      c = tf_np.sin(c * b)
      assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
      return c, b

    if jit_f:
      f = extensions.jit(f)

    if jit_scan:
      scan = extensions.jit(extensions.scan, (0,))
    else:
      scan = extensions.scan

    xs = rng.randn(5, 3)
    c = rng.randn(4)

    ans = scan(f, c, xs)
    expected = scan_reference(f, c, xs)
    self.assertDTypesEqual(expected, ans)
    self.assertAllClose(expected, ans)
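The tests compare against a scan_reference helper that is not reproduced in this listing. A minimal sketch of such a reference implementation, written as a plain Python loop over the leading axis (an assumption; the real helper may differ in details), could look like:

def scan_reference(f, init, xs):
  # Unrolled reference: apply f to each slice along the leading axis of xs,
  # threading the carry through and stacking the per-step outputs.
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(tf_np.reshape(y, (1,) + tuple(y.shape)))
  ys = tf_np.concatenate(ys, axis=0)
  return carry, ys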
Example #3
    def testRematJit(self):
        def f(a, b):
            return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

        f_remat = extensions.remat(f)

        shape = [10]
        a = tf_np.random.randn(*shape)
        b = tf_np.random.randn(*shape)

        actual = extensions.jit(extensions.grad(f_remat))(a, b)
        expected = extensions.jit(extensions.grad(f))(a, b)
        self.assertAllClose(actual, expected)
Example #4
  def testFloatIndexingError(self):
    error_regex = "only integers, slices.*are valid indices"
    # Verify onp behavior
    with self.assertRaisesRegex(IndexError, error_regex):
      _ = onp.zeros((2, 2))[(0, 0.)]
    # Test jnp
    with self.assertRaisesRegex(IndexError, error_regex):
      jnp.zeros(2)[0.]
    with self.assertRaisesRegex(IndexError, error_regex):
      jnp.zeros((2, 2))[(0, 0.)]
    # Test with jit
    with self.assertRaisesRegex(IndexError, error_regex):
      npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))
Example #5
    def testFloatIndexingError(self):
        BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
        with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
            jnp.zeros(2)[0.]
        with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
            jnp.zeros((2, 2))[(0, 0.)]
        with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
            jnp.zeros((2, 2))[(0, 0.)]
        with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
            npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))
        with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
            ops.index_add(jnp.zeros(2), 0., 1.)
        with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
            ops.index_update(jnp.zeros(2), 0., 1.)
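For contrast, a small sketch of the same accesses written with integer indices, which are expected to succeed (the static_argnums usage mirrors other examples in this listing):

_ = jnp.zeros(2)[0]
_ = jnp.zeros((2, 2))[(0, 0)]
_ = npe.jit(lambda idx: jnp.zeros((2, 2))[idx], static_argnums=(0,))((0, 0))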
Example #6
  def testIndexingEmptyDimension(self):
    # Issue 2671: XLA error when indexing into dimension of size 0
    x = jnp.ones((2, 0))
    # The following work, even on axis 1 of size 0
    _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2]

    with self.assertRaisesRegex(IndexError,
                                "index .* is out of bounds for axis .* with size 0"):
      _ = onp.ones((2, 0))[0, 0]  # The numpy error
    with self.assertRaisesRegex(IndexError,
                                "index is out of bounds for axis .* with size 0"):
      _ = x[0, 0]  # JAX indexing
    with self.assertRaisesRegex(IndexError,
                                "index is out of bounds for axis .* with size 0"):
      npe.jit(lambda i: x[0, i])(0)  # JAX indexing under jit
Example #7
        def transformer(f, **kwargs):
            def f_prime(*args):
                res = extensions.eval_on_shapes(f, **kwargs)(*args)
                return tf.nest.map_structure(
                    lambda x: tf_np.zeros(x.shape, x.dtype), res)

            return extensions.jit(f_prime, kwargs.get("static_argnums", ()))
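A hedged usage sketch of the transformer above (the wrapped function and input are made up for illustration): extensions.eval_on_shapes only propagates shapes and dtypes, so f_prime returns zero-filled arrays with the shape and dtype of the real outputs without computing them.

# Hypothetical usage, assuming the tf, tf_np and extensions imports from the surrounding tests.
g = transformer(lambda x: tf_np.sum(x, axis=0))
out = g(tf_np.ones((3, 4)))
# out is expected to be a (4,)-shaped array of zeros with the dtype of the real result.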
Example #8
  def testIssue187(self):
    x = jnp.ones((5, 5))
    x[[0, 2, 4], [0, 2, 4]]  # doesn't crash

    x = onp.arange(25).reshape((5, 5))
    ans = npe.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x)
    expected = x[[0, 2, 4], [0, 2, 4]]
    self.assertAllClose(ans, expected, check_dtypes=False)
Example #9
    def testScanGrad(self, jit_scan, jit_f):
        rng = np.random.RandomState(0)

        d = rng.randn(2)

        def f(c, a):
            assert a.shape == (3, )
            assert c.shape == (4, )
            b = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.sin(c)) +
                 tf_np.sum(tf_np.sin(d)))
            c = tf_np.sin(c * b)
            assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
            return c, b

        if jit_f:
            f = extensions.jit(f)

        if jit_scan:
            scan = extensions.jit(extensions.scan, static_argnums=(0, ))
        else:
            scan = extensions.scan

        xs = tf_np.asarray(rng.randn(5, 3))
        c = tf_np.asarray(rng.randn(4))

        def losses(scan, c, xs):
            c, ys = scan(f, c, xs)
            return tf_np.concatenate(
                tf.nest.flatten(
                    tf.nest.map_structure(lambda a: tf_np.reshape(a, [-1]),
                                          (c, ys))))

        def loss(scan, c, xs):
            return tf_np.sum(losses(scan, c, xs))

        ans = extensions.grad(functools.partial(loss, scan))(c, xs)
        expected = extensions.grad(functools.partial(loss, scan_reference))(c,
                                                                            xs)
        self.assertDTypesEqual(expected, ans)
        self.assertAllClose(expected, ans)

        theoretical, numerical = tf.test.compute_gradient(
            to_tf_fn(functools.partial(losses, scan)), (c, xs))
        self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4)
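The gradient checks pass the losses through a to_tf_fn helper that is not shown in this listing. A hypothetical sketch, assuming the helper only bridges between the tf.Tensor inputs and outputs that tf.test.compute_gradient expects and tf-numpy arrays:

def to_tf_fn(f):
  # Hypothetical: wrap a tf-numpy function so it accepts and returns tf.Tensors.
  def tf_f(*args):
    np_args = tf.nest.map_structure(tf_np.asarray, args)
    return tf.nest.map_structure(tf.convert_to_tensor, f(*np_args))
  return tf_f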
Example #10
  def testJit(self):
    def f(a, b):
      return math.sum(math.sqrt(math.exp(a)) + b)
    f_jitted = extensions.jit(f)
    shape = [10]
    a = random.randn(*shape)
    b = random.randn(*shape)
    self.assertAllClose(f(a, b), f_jitted(a, b))
    # Call again since the code path is different on second call
    self.assertAllClose(f(a, b), f_jitted(a, b))
Example #11
    def testUnpacking(self):
        def foo(x):
            a, b, c = x
            return a + b + c

        cfoo = npe.jit(foo)

        a1 = foo(onp.arange(3))
        a2 = cfoo(onp.arange(3))

        self.assertAllClose(a1, a2, check_dtypes=True)
Example #12
    def testRematJitXla(self):
        def f(a, b):
            return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

        f_remat = extensions.remat(f)

        shape = [10]
        a = tf_np.random.randn(*shape)
        b = tf_np.random.randn(*shape)

        actual = extensions.jit(extensions.grad(f_remat),
                                xla_forced_compile=True)(a, b)
        expected = extensions.jit(extensions.grad(f),
                                  xla_forced_compile=True)(a, b)
        self.assertAllClose(actual, expected)

        actual = extensions.jit(extensions.grad(f_remat),
                                experimental_compile=True)(a, b)
        expected = extensions.jit(extensions.grad(f),
                                  experimental_compile=True)(a, b)
        self.assertAllClose(actual, expected)
Example #13
    def testScanImpl(self, jit_scan, jit_f):
        rng = np.random.RandomState(0)

        d = rng.randn(2)

        def f(c, a):
            assert a.shape == (3, )
            assert c.shape == (4, )
            b = tf_np.cos(
                tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
                tf_np.sum(tf_np.tan(d)))
            c = tf_np.sin(c * b)
            assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
            return c, b

        if jit_f:
            f = extensions.jit(f)

        if jit_scan == "no_xla":
            scan = extensions.jit(extensions.scan, static_argnums=(0, ))
        elif jit_scan == "xla_forced_compile":
            scan = extensions.jit(extensions.scan,
                                  static_argnums=(0, ),
                                  xla_forced_compile=True)
        else:
            scan = extensions.scan

        xs = rng.randn(5, 3)
        c = rng.randn(4)

        ans = scan(f, c, xs)
        expected = scan_reference(f, c, xs)
        if jit_scan == "xla_forced_compile":
            # xla.compile doesn't preserve list-vs-tuple properly for the outputs, so
            # we canonicalize them to lists here.
            expected = list(expected)
            ans = list(ans)
        self.assertDTypesEqual(expected, ans)
        self.assertAllClose(expected, ans)
Example #14
        def check_compile(**kwargs):
            # `wrapped_fun` and `python_should_be_executing` are used to check that
            # when the jitted function is called the second time, the original Python
            # function won't be executed.
            def wrapped_fun(*args):
                self.assertTrue(python_should_be_executing)
                return fun(*args)

            cfun = npe.jit(wrapped_fun,
                           static_argnums=static_argnums,
                           **kwargs)
            python_should_be_executing = True
            monitored_ans = cfun(*args)

            python_should_be_executing = False
            compiled_ans = cfun(*args)

            self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol,
                                rtol)
            self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol,
                                rtol)

            # Run `cfun` with a different set of arguments to check that changing
            # arguments won't cause recompilation.

            new_args = args_maker()

            skip_retracing_test = False
            for old, new in zip(tf.nest.flatten(args),
                                tf.nest.flatten(new_args)):
                if npe.most_precise_int_dtype(
                        old) != npe.most_precise_int_dtype(new):
                    # If the old and new arguments result in different dtypes (because
                    # they fall into different value ranges), tf-numpy will retrace, so we
                    # skip the no-retrace test.
                    skip_retracing_test = True

            if not skip_retracing_test:
                python_should_be_executing = True
                new_python_ans = fun(*new_args)
                python_should_be_executing = False
                compiled_ans = cfun(*new_args)
                self.assertAllClose(new_python_ans, compiled_ans, check_dtypes,
                                    atol, rtol)
Example #15
  def _CompileAndCheck(self,
                       fun,
                       args_maker,
                       check_dtypes,
                       rtol=None,
                       atol=None,
                       check_eval_on_shapes=True,
                       check_incomplete_shape=False,
                       check_unknown_rank=True,
                       static_argnums=()):
    """Compiles the function and checks the results.

    Args:
      fun: the function to be checked.
      args_maker: a callable that returns a tuple which will be used as the
        positional arguments.
      check_dtypes: whether to check that the result dtypes from non-compiled
        and compiled runs agree.
      rtol: relative tolerance for allclose assertions.
      atol: absolute tolerance for allclose assertions.
      check_eval_on_shapes: whether to run `eval_on_shapes` on the function and
        check that the result shapes and dtypes are correct.
      check_incomplete_shape: whether to check that the function can handle
        incomplete shapes (including those with and without a known rank).
      check_unknown_rank: (only has effect when check_incomplete_shape is True)
        whether to check that the function can handle unknown ranks.
      static_argnums: indices of arguments to be treated as static arguments for
        `jit` and `eval_on_shapes`.
    """
    args = args_maker()

    for x in args:
      if not hasattr(x, 'dtype'):
        # If there is an input that doesn't have dtype info, jit and
        # eval_on_shapes may pick a different dtype for it than numpy, so we
        # skip the dtype check.
        check_dtypes = False

    # `wrapped_fun` and `python_should_be_executing` are used to check that when
    # the jitted function is called the second time, the original Python
    # function won't be executed.
    def wrapped_fun(*args):
      self.assertTrue(python_should_be_executing)
      return fun(*args)

    python_ans = fun(*args)

    python_shapes = tf.nest.map_structure(lambda x: onp.shape(x), python_ans)
    onp_shapes = tf.nest.map_structure(lambda x: onp.shape(onp.asarray(x)),
                                       python_ans)
    self.assertEqual(python_shapes, onp_shapes)

    cfun = npe.jit(wrapped_fun, static_argnums=static_argnums)
    python_should_be_executing = True
    monitored_ans = cfun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol)
    self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)

    # Run `cfun` with a different set of arguments to check that changing
    # arguments won't cause recompilation.

    new_args = args_maker()

    skip_retracing_test = False
    for old, new in zip(args, new_args):
      if npe.most_precise_int_dtype(old) != npe.most_precise_int_dtype(new):
        # If the old and new arguments result in different dtypes (because they
        # fall into different value ranges), tf-numpy will retrace, so we skip
        # the no-retrace test.
        skip_retracing_test = True

    if not skip_retracing_test:
      python_should_be_executing = True
      new_python_ans = fun(*new_args)
      python_should_be_executing = False
      compiled_ans = cfun(*new_args)
      self.assertAllClose(new_python_ans, compiled_ans, check_dtypes, atol,
                          rtol)

    if check_eval_on_shapes:
      # Check that npe.eval_on_shapes can get complete output shapes given
      # complete input shapes.
      cfun = npe.eval_on_shapes(fun, static_argnums=static_argnums)
      compiled_ans = cfun(*args)
      flat_python_ans = tf.nest.flatten(python_ans)
      flat_compiled_ans = tf.nest.flatten(compiled_ans)
      self.assertEqual(len(flat_python_ans), len(flat_compiled_ans))
      for a, b in zip(flat_python_ans, flat_compiled_ans):
        if hasattr(a, 'shape'):
          self.assertEqual(a.shape, b.shape)
        if check_dtypes and hasattr(a, 'dtype'):
          self.assertEqual(tf.as_dtype(a.dtype), b.dtype)

    # If some argument doesn't have a `dtype` attr (e.g. a Python scalar), we
    # skip incomplete-shape checks, since shape specs need dtype. It's OK to
    # skip since the same incomplete-shape checks will run for []-shaped arrays.
    if check_incomplete_shape and all(hasattr(x, 'dtype') for x in args):
      # Check partial shapes with known ranks.
      # Numpy scalars (created by e.g. np.int32(5)) have `dtype` but not
      # `shape`.
      if all(hasattr(x, 'shape') for x in args):
        specs = [tf.TensorSpec([None] * len(x.shape), x.dtype) for x in args]
        cfun = npe.jit(
            fun, static_argnums=static_argnums, input_signature=specs)
        compiled_ans = cfun(*args)
        self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)

      if check_unknown_rank:
        # Check unknown ranks.
        specs = [tf.TensorSpec(None, x.dtype) for x in args]
        cfun = npe.jit(
            fun, static_argnums=static_argnums, input_signature=specs)
        compiled_ans = cfun(*args)
        self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)
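A hypothetical call site for _CompileAndCheck (the operation under test and the test name are made up for illustration):

  def testAddJitted(self):
    args_maker = lambda: (onp.ones((3,), dtype=onp.float32),
                          onp.ones((3,), dtype=onp.float32))
    self._CompileAndCheck(lambda a, b: jnp.add(a, b), args_maker,
                          check_dtypes=True, check_incomplete_shape=True)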
Example #16
File: math.py Project: zzszmyf/trax
def _tf_jit(*args, **kwargs):
  kwargs['xla_forced_compile'] = tf_xla_forced_compile_enabled()
  return tf_np_extensions.jit(*args, **kwargs)
Example #17
def _tf_jit(*args, **kwargs):
    kwargs['xla_forced_compile'] = tf_xla_forced_compile_enabled()
    kwargs.pop('donate_argnums', None)  # donate_argnums not used in TF
    return tf_np_extensions.jit(*args, **kwargs)
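A minimal usage sketch of the _tf_jit wrapper (the wrapped function is made up for illustration, and the tf_np alias is assumed from the test examples above): it forwards to tf_np_extensions.jit, forces or skips XLA compilation according to the global setting, and drops the JAX-only donate_argnums argument.

# Hypothetical usage through the trax math backend wrapper.
def dot_loss(w, x):
    return tf_np.sum(tf_np.dot(x, w) ** 2)

dot_loss_jitted = _tf_jit(dot_loss, static_argnums=())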
Example #18
  def testBooleanIndexingDynamicShapeError(self):
    x = onp.zeros(3)
    i = onp.array([True, True, False])
    self.assertRaises(IndexError, lambda: npe.jit(lambda x, i: x[i])(x, i))
Example #19
    def testScanGrad(self, jit_grad, jit_scan, jit_f):
        rng = np.random.RandomState(0)

        d = rng.randn(2)

        def f(c, a):
            assert a.shape == (3, )
            assert c.shape == (4, )
            b = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.sin(c)) +
                 tf_np.sum(tf_np.sin(d)))
            c = tf_np.sin(c * b)
            assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
            return c, b

        if jit_f:
            f = extensions.jit(f)

        if jit_scan == "no_xla":
            scan = extensions.jit(extensions.scan, static_argnums=(0, ))
        elif jit_scan == "xla_forced_compile":
            # TODO(b/187107596): Remove `skipTest`
            self.skipTest(
                "Taking gradients of `jit(scan, experimental_compile=True)` triggers "
                "'Support for TensorList crossing the XLA/TF boundary is not "
                "implemented' error")
            # `xla_forced_compile=True` doesn't support gradients, so we use
            # `experimental_compile=True`.
            scan = extensions.jit(extensions.scan,
                                  static_argnums=(0, ),
                                  experimental_compile=True)
        else:
            scan = extensions.scan

        xs = tf_np.asarray(rng.randn(5, 3))
        c = tf_np.asarray(rng.randn(4))

        def losses(scan, c, xs):
            c, ys = scan(f, c, xs)
            return tf_np.concatenate(
                tf.nest.flatten(
                    tf.nest.map_structure(lambda a: tf_np.reshape(a, [-1]),
                                          (c, ys))))

        def loss(scan, c, xs):
            return tf_np.sum(losses(scan, c, xs))

        def grad_origin(c, xs):
            return extensions.grad(functools.partial(loss, scan))(c, xs)

        if jit_grad == "no_xla":
            grad_jit = extensions.jit(grad_origin)
        elif jit_grad == "xla_forced_compile":
            grad_jit = extensions.jit(grad_origin, xla_forced_compile=True)
        else:
            grad_jit = grad_origin

        ans = grad_jit(c, xs)
        expected = extensions.grad(functools.partial(loss, scan_reference))(c,
                                                                            xs)
        self.assertDTypesEqual(expected, ans)
        self.assertAllClose(expected, ans)

        theoretical, numerical = tf.test.compute_gradient(
            to_tf_fn(functools.partial(losses, scan)), (c, xs))
        self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4)