def testAdvancedIndexingManually(self):
  """Jitted advanced indexing must agree with eager indexing.

  Each indexer mixes an Ellipsis, an integer index array (including a negative
  entry), slices and `None`. It is evaluated both eagerly and through
  `npe.jit`, and the two results are compared with dtype checking on.
  """
  x = onp.random.RandomState(0).randn(3, 4, 5)
  index_array = onp.array([0, 2, -1, 0])

  indexers = (
      lambda x, index_array: x[..., index_array, :],
      lambda x, index_array: x[..., index_array, :, index_array, None],
      lambda x, index_array: x[index_array, ..., index_array[:, None], None],
  )
  for op in indexers:
    jitted_op = npe.jit(op)
    eager_result = op(x, index_array)
    jitted_result = jitted_op(x, index_array)
    self.assertAllClose(eager_result, jitted_result, check_dtypes=True)
def testScanImpl(self, jit_scan, jit_f):
  """Compares `extensions.scan` with `scan_reference` on a small problem.

  Args:
    jit_scan: if truthy, `extensions.scan` is wrapped in `extensions.jit` with
      its first argument (the step function) marked static.
    jit_f: if truthy, the step function itself is wrapped in `extensions.jit`.
  """
  rng = np.random.RandomState(0)
  d = rng.randn(2)

  def f(c, a):
    assert a.shape == (3,)
    assert c.shape == (4,)
    pre_activation = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
                      tf_np.sum(tf_np.tan(d)))
    b = tf_np.cos(pre_activation)
    new_c = tf_np.sin(c * b)
    assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
    return new_c, b

  if jit_f:
    f = extensions.jit(f)
  scan = (extensions.jit(extensions.scan, static_argnums=(0,))
          if jit_scan else extensions.scan)

  # Draw order from `rng` (d, then xs, then c) must stay fixed.
  xs = rng.randn(5, 3)
  c = rng.randn(4)
  ans = scan(f, c, xs)
  expected = scan_reference(f, c, xs)
  self.assertDTypesEqual(expected, ans)
  self.assertAllClose(expected, ans)
def testRematJit(self):
  """Jitted gradients of `extensions.remat(f)` match jitted gradients of `f`."""
  def f(a, b):
    return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

  f_remat = extensions.remat(f)
  shape = [10]
  a = tf_np.random.randn(*shape)
  b = tf_np.random.randn(*shape)

  remat_grad_fn = extensions.jit(extensions.grad(f_remat))
  plain_grad_fn = extensions.jit(extensions.grad(f))
  self.assertAllClose(remat_grad_fn(a, b), plain_grad_fn(a, b))
def testFloatIndexingError(self):
  """A float index raises IndexError in onp, in eager jnp, and under jit."""
  error_regex = "only integers, slices.*are valid indices"

  # Reference: plain NumPy rejects the float index.
  with self.assertRaisesRegex(IndexError, error_regex):
    _ = onp.zeros((2, 2))[(0, 0.)]

  bad_index_calls = (
      lambda: jnp.zeros(2)[0.],
      lambda: jnp.zeros((2, 2))[(0, 0.)],
      # Under jit the offending index arrives as a function argument.
      lambda: npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.)),
  )
  for call in bad_index_calls:
    with self.assertRaisesRegex(IndexError, error_regex):
      call()
def testFloatIndexingError(self):
  """Float indices raise TypeError with a descriptive message.

  Covers eager indexing with a scalar and a tuple index, indexing under
  `npe.jit`, and the functional update ops `ops.index_add` and
  `ops.index_update`.
  """
  BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jnp.zeros(2)[0.]
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jnp.zeros((2, 2))[(0, 0.)]
  # Same tuple index, but passed through a jitted function as an argument.
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    ops.index_add(jnp.zeros(2), 0., 1.)
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    ops.index_update(jnp.zeros(2), 0., 1.)
def testIndexingEmptyDimension(self):
  """Regression test for issue 2671 (XLA error indexing a size-0 dimension)."""
  x = jnp.ones((2, 0))
  # Indexing that selects nothing along the size-0 axis 1 must still work.
  _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2]

  # NumPy's own error, for reference.
  with self.assertRaisesRegex(
      IndexError, "index .* is out of bounds for axis .* with size 0"):
    _ = onp.ones((2, 0))[0, 0]

  jax_error = "index is out of bounds for axis .* with size 0"
  with self.assertRaisesRegex(IndexError, jax_error):
    _ = x[0, 0]  # eager indexing
  with self.assertRaisesRegex(IndexError, jax_error):
    npe.jit(lambda i: x[0, i])(0)  # indexing under jit
def transformer(f, **kwargs):
  """Builds a jitted stand-in for `f` that returns zeros shaped like `f`'s output.

  On every call, `extensions.eval_on_shapes(f, **kwargs)` is applied to the
  arguments. Each leaf of the resulting structure is then replaced by
  `tf_np.zeros` with the same shape and dtype. The wrapper is passed to
  `extensions.jit`, forwarding `kwargs["static_argnums"]` (default `()`).
  """
  def zeros_shaped_like_f(*args):
    output_specs = extensions.eval_on_shapes(f, **kwargs)(*args)
    return tf.nest.map_structure(
        lambda spec: tf_np.zeros(spec.shape, spec.dtype), output_specs)

  return extensions.jit(zeros_shaped_like_f,
                        static_argnums=kwargs.get("static_argnums", ()))
def testIssue187(self):
  """Regression test for issue 187: paired integer-list fancy indexing."""
  rows = cols = [0, 2, 4]
  jnp.ones((5, 5))[rows, cols]  # doesn't crash

  x = onp.arange(25).reshape((5, 5))
  ans = npe.jit(lambda y: y[rows, cols])(x)
  expected = x[rows, cols]
  self.assertAllClose(ans, expected, check_dtypes=False)
def testScanGrad(self, jit_scan, jit_f):
  """Gradients through `extensions.scan` match a reference and numerics.

  Args:
    jit_scan: if truthy, `extensions.scan` is wrapped in `extensions.jit` with
      its first argument (the step function) marked static.
    jit_f: if truthy, the step function is wrapped in `extensions.jit`.
  """
  rng = np.random.RandomState(0)
  d = rng.randn(2)

  def f(c, a):
    assert a.shape == (3, )
    assert c.shape == (4, )
    b = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.sin(c)) +
         tf_np.sum(tf_np.sin(d)))
    c = tf_np.sin(c * b)
    assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
    return c, b

  if jit_f:
    f = extensions.jit(f)
  scan = (extensions.jit(extensions.scan, static_argnums=(0, ))
          if jit_scan else extensions.scan)

  xs = tf_np.asarray(rng.randn(5, 3))
  c = tf_np.asarray(rng.randn(4))

  def losses(scan, c, xs):
    # Concatenate the flattened final carry and all stacked outputs.
    final_carry, ys = scan(f, c, xs)
    flat_pieces = tf.nest.map_structure(lambda a: tf_np.reshape(a, [-1]),
                                        (final_carry, ys))
    return tf_np.concatenate(tf.nest.flatten(flat_pieces))

  def loss(scan, c, xs):
    return tf_np.sum(losses(scan, c, xs))

  ans = extensions.grad(functools.partial(loss, scan))(c, xs)
  expected = extensions.grad(functools.partial(loss, scan_reference))(c, xs)
  self.assertDTypesEqual(expected, ans)
  self.assertAllClose(expected, ans)

  # Cross-check the analytic Jacobian against finite differences.
  theoretical, numerical = tf.test.compute_gradient(
      to_tf_fn(functools.partial(losses, scan)), (c, xs))
  self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4)
def testJit(self):
  """`extensions.jit(f)` agrees with `f` on the first and a repeated call."""
  def f(a, b):
    return math.sum(math.sqrt(math.exp(a)) + b)

  f_jitted = extensions.jit(f)
  shape = [10]
  a = random.randn(*shape)
  b = random.randn(*shape)
  # The second pass exercises a different code path than the first call.
  for _ in range(2):
    self.assertAllClose(f(a, b), f_jitted(a, b))
def testUnpacking(self):
  """Tuple-unpacking an array argument behaves the same under `npe.jit`."""
  def foo(x):
    first, second, third = x
    return first + second + third

  jitted_foo = npe.jit(foo)
  eager_result = foo(onp.arange(3))
  jitted_result = jitted_foo(onp.arange(3))
  self.assertAllClose(eager_result, jitted_result, check_dtypes=True)
def testRematJitXla(self):
  """Remat gradients equal plain gradients under both XLA-compiling jit modes."""
  def f(a, b):
    return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

  f_remat = extensions.remat(f)
  shape = [10]
  a = tf_np.random.randn(*shape)
  b = tf_np.random.randn(*shape)

  for jit_kwargs in ({"xla_forced_compile": True},
                     {"experimental_compile": True}):
    actual = extensions.jit(extensions.grad(f_remat), **jit_kwargs)(a, b)
    expected = extensions.jit(extensions.grad(f), **jit_kwargs)(a, b)
    self.assertAllClose(actual, expected)
def testScanImpl(self, jit_scan, jit_f):
  """Compares `extensions.scan` with `scan_reference` across jit modes.

  Args:
    jit_scan: "no_xla" jits `extensions.scan`, "xla_forced_compile" jits it
      with `xla_forced_compile=True`, and anything else uses it un-jitted.
    jit_f: if truthy, the step function is wrapped in `extensions.jit`.
  """
  rng = np.random.RandomState(0)
  d = rng.randn(2)

  def f(c, a):
    assert a.shape == (3, )
    assert c.shape == (4, )
    pre_activation = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
                      tf_np.sum(tf_np.tan(d)))
    b = tf_np.cos(pre_activation)
    new_c = tf_np.sin(c * b)
    assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
    return new_c, b

  if jit_f:
    f = extensions.jit(f)

  # Extra `extensions.jit` kwargs for each mode that jits the scan.
  scan_jit_kwargs = {
      "no_xla": {},
      "xla_forced_compile": {"xla_forced_compile": True},
  }
  if jit_scan in scan_jit_kwargs:
    scan = extensions.jit(extensions.scan, static_argnums=(0, ),
                          **scan_jit_kwargs[jit_scan])
  else:
    scan = extensions.scan

  xs = rng.randn(5, 3)
  c = rng.randn(4)
  ans = scan(f, c, xs)
  expected = scan_reference(f, c, xs)
  if jit_scan == "xla_forced_compile":
    # xla.compile doesn't preserve list-vs-tuple properly for the outputs, so
    # we canonicalize them to lists here.
    expected, ans = list(expected), list(ans)
  self.assertDTypesEqual(expected, ans)
  self.assertAllClose(expected, ans)
def check_compile(**kwargs):
  """Checks `npe.jit(fun, ...)` against the eager result and for retracing.

  NOTE(review): this is a nested helper. `self`, `fun`, `args`, `args_maker`,
  `python_ans`, `static_argnums`, `check_dtypes`, `atol` and `rtol` are free
  variables from an enclosing scope that is not visible here.

  Args:
    **kwargs: extra keyword arguments forwarded to `npe.jit` (presumably jit
      mode flags — confirm against callers).
  """
  # `wrapped_fun` and `python_should_be_executing` are used to check that
  # when the jitted function is called the second time, the original Python
  # function won't be executed.
  def wrapped_fun(*args):
    self.assertTrue(python_should_be_executing)
    return fun(*args)

  cfun = npe.jit(wrapped_fun, static_argnums=static_argnums, **kwargs)
  # First call: the Python body is allowed to run.
  python_should_be_executing = True
  monitored_ans = cfun(*args)
  # Same arguments again: if the Python body runs now, `wrapped_fun` fails.
  python_should_be_executing = False
  compiled_ans = cfun(*args)
  self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol)
  self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)

  # Run `cfun` with a different set of arguments to check that changing
  # arguments won't cause recompilation.
  new_args = args_maker()
  skip_retracing_test = False
  # Compare argument leaves pairwise (nested structures are flattened first).
  for old, new in zip(tf.nest.flatten(args), tf.nest.flatten(new_args)):
    if npe.most_precise_int_dtype(
        old) != npe.most_precise_int_dtype(new):
      # If the old and new arguments result in different dtypes (because
      # they fall into different value ranges), tf-numpy will retrace, so we
      # skip the no-retrace test.
      skip_retracing_test = True
  if not skip_retracing_test:
    python_should_be_executing = True
    new_python_ans = fun(*new_args)
    python_should_be_executing = False
    # With the flag off, `cfun` must reuse its existing trace.
    compiled_ans = cfun(*new_args)
    self.assertAllClose(new_python_ans, compiled_ans, check_dtypes, atol, rtol)
def _CompileAndCheck(self, fun, args_maker, check_dtypes, rtol=None, atol=None,
                     check_eval_on_shapes=True, check_incomplete_shape=False,
                     check_unknown_rank=True, static_argnums=()):
  """Compiles the function and checks the results.

  Args:
    fun: the function to be checked.
    args_maker: a callable that returns a tuple which will be used as the
      positional arguments.
    check_dtypes: whether to check that the result dtypes from non-compiled and
      compiled runs agree.
    rtol: relative tolerance for allclose assertions.
    atol: absolute tolerance for allclose assertions.
    check_eval_on_shapes: whether to run `eval_on_shapes` on the function and
      check that the result shapes and dtypes are correct.
    check_incomplete_shape: whether to check that the function can handle
      incomplete shapes (including those with and without a known rank).
    check_unknown_rank: (only has effect when check_incomplete_shape is True)
      whether to check that the function can handle unknown ranks.
    static_argnums: indices of arguments to be treated as static arguments for
      `jit` and `eval_on_shapes`.
  """
  args = args_maker()

  for x in args:
    if not hasattr(x, 'dtype'):
      # If there is a input that doesn't have dtype info, jit and
      # eval_on_shapes may pick a different dtype for it than numpy, so we
      # skip the dtype check.
      check_dtypes = False

  # `wrapped_fun` and `python_should_be_executing` are used to check that when
  # the jitted function is called the second time, the original Python
  # function won't be executed.
  def wrapped_fun(*args):
    self.assertTrue(python_should_be_executing)
    return fun(*args)

  python_ans = fun(*args)

  python_shapes = tf.nest.map_structure(onp.shape, python_ans)
  onp_shapes = tf.nest.map_structure(lambda x: onp.shape(onp.asarray(x)),
                                     python_ans)
  self.assertEqual(python_shapes, onp_shapes)

  cfun = npe.jit(wrapped_fun, static_argnums=static_argnums)
  python_should_be_executing = True
  monitored_ans = cfun(*args)

  # Same arguments again: if the Python body runs now, `wrapped_fun` fails.
  python_should_be_executing = False
  compiled_ans = cfun(*args)

  self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol)
  self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)

  # Run `cfun` with a different set of arguments to check that changing
  # arguments won't cause recompilation.
  new_args = args_maker()
  # If the old and new arguments result in different dtypes (because they
  # fall into different value ranges), tf-numpy will retrace, so we skip the
  # no-retrace test. Arguments are compared leaf by leaf (via
  # `tf.nest.flatten`) so nested arguments are not handed whole to
  # `npe.most_precise_int_dtype`.
  skip_retracing_test = any(
      npe.most_precise_int_dtype(old) != npe.most_precise_int_dtype(new)
      for old, new in zip(tf.nest.flatten(args), tf.nest.flatten(new_args)))

  if not skip_retracing_test:
    python_should_be_executing = True
    new_python_ans = fun(*new_args)
    python_should_be_executing = False
    compiled_ans = cfun(*new_args)
    self.assertAllClose(new_python_ans, compiled_ans, check_dtypes, atol, rtol)

  if check_eval_on_shapes:
    # Check that npe.eval_on_shapes can get complete output shapes given
    # complete input shapes.
    cfun = npe.eval_on_shapes(fun, static_argnums=static_argnums)
    compiled_ans = cfun(*args)
    flat_python_ans = tf.nest.flatten(python_ans)
    flat_compiled_ans = tf.nest.flatten(compiled_ans)
    self.assertEqual(len(flat_python_ans), len(flat_compiled_ans))
    for a, b in zip(flat_python_ans, flat_compiled_ans):
      if hasattr(a, 'shape'):
        self.assertEqual(a.shape, b.shape)
      if check_dtypes and hasattr(a, 'dtype'):
        self.assertEqual(tf.as_dtype(a.dtype), b.dtype)

  # If some argument doesn't have a `dtype` attr (e.g. a Python scalar), we
  # skip incomplete-shape checks, since shape specs need dtype. It's OK to
  # skip since the same incomplete-shape checks will run for []-shaped arrays.
  if check_incomplete_shape and all(hasattr(x, 'dtype') for x in args):
    # Check partial shapes with known ranks.
    # Numpy scalars (created by e.g. np.int32(5)) have `dtype` but not
    # `shape`.
    if all(hasattr(x, 'shape') for x in args):
      specs = [tf.TensorSpec([None] * len(x.shape), x.dtype) for x in args]
      cfun = npe.jit(
          fun, static_argnums=static_argnums, input_signature=specs)
      compiled_ans = cfun(*args)
      self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)

    if check_unknown_rank:
      # Check unknown ranks.
      specs = [tf.TensorSpec(None, x.dtype) for x in args]
      cfun = npe.jit(
          fun, static_argnums=static_argnums, input_signature=specs)
      compiled_ans = cfun(*args)
      self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)
def _tf_jit(*args, **kwargs):
  """Forwards to `tf_np_extensions.jit` with `xla_forced_compile` overridden.

  Any caller-supplied `xla_forced_compile` is replaced by the value of
  `tf_xla_forced_compile_enabled()`. All other arguments pass through as-is.
  """
  kwargs.update(xla_forced_compile=tf_xla_forced_compile_enabled())
  return tf_np_extensions.jit(*args, **kwargs)
def _tf_jit(*args, **kwargs):
  """Forwards to `tf_np_extensions.jit`, adapting JAX-style keyword arguments.

  `xla_forced_compile` is set from `tf_xla_forced_compile_enabled()`, replacing
  any caller-supplied value. `donate_argnums` is dropped because TF does not
  use it. All other arguments pass through unchanged.
  """
  forced = tf_xla_forced_compile_enabled()
  jit_kwargs = {
      name: value for name, value in kwargs.items()
      if name != 'donate_argnums'  # donate_argnums not used in TF
  }
  jit_kwargs['xla_forced_compile'] = forced
  return tf_np_extensions.jit(*args, **jit_kwargs)
def testBooleanIndexingDynamicShapeError(self):
  """Boolean-mask indexing under `npe.jit` raises IndexError."""
  x = onp.zeros(3)
  mask = onp.array([True, True, False])
  with self.assertRaises(IndexError):
    npe.jit(lambda x, i: x[i])(x, mask)
def testScanGrad(self, jit_grad, jit_scan, jit_f):
  """Gradients through `extensions.scan` match a reference across jit modes.

  Args:
    jit_grad: "no_xla" or "xla_forced_compile" jits the gradient function
      (the latter with `xla_forced_compile=True`); anything else leaves it
      un-jitted.
    jit_scan: "no_xla" jits `extensions.scan`; "xla_forced_compile" is
      currently skipped; anything else uses it un-jitted.
    jit_f: if truthy, the step function is wrapped in `extensions.jit`.
  """
  rng = np.random.RandomState(0)
  d = rng.randn(2)

  def f(c, a):
    assert a.shape == (3, )
    assert c.shape == (4, )
    b = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.sin(c)) +
         tf_np.sum(tf_np.sin(d)))
    c = tf_np.sin(c * b)
    assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
    return c, b

  if jit_f:
    f = extensions.jit(f)

  if jit_scan == "xla_forced_compile":
    # TODO(b/187107596): Remove `skipTest`
    self.skipTest(
        "Taking gradients of `jit(scan, experimental_compile=True)` triggers "
        "'Support for TensorList crossing the XLA/TF boundary is not "
        "implemented' error")
    # `xla_forced_compile=True` doesn't support gradients, so we use
    # `experimental_compile=True`.
    scan = extensions.jit(extensions.scan, static_argnums=(0, ),
                          experimental_compile=True)
  elif jit_scan == "no_xla":
    scan = extensions.jit(extensions.scan, static_argnums=(0, ))
  else:
    scan = extensions.scan

  xs = tf_np.asarray(rng.randn(5, 3))
  c = tf_np.asarray(rng.randn(4))

  def losses(scan, c, xs):
    # Concatenate the flattened final carry and all stacked outputs.
    final_carry, ys = scan(f, c, xs)
    flat_pieces = tf.nest.map_structure(lambda a: tf_np.reshape(a, [-1]),
                                        (final_carry, ys))
    return tf_np.concatenate(tf.nest.flatten(flat_pieces))

  def loss(scan, c, xs):
    return tf_np.sum(losses(scan, c, xs))

  def grad_origin(c, xs):
    return extensions.grad(functools.partial(loss, scan))(c, xs)

  # Extra `extensions.jit` kwargs for each mode that jits the gradient.
  grad_jit_kwargs = {
      "no_xla": {},
      "xla_forced_compile": {"xla_forced_compile": True},
  }
  if jit_grad in grad_jit_kwargs:
    grad_jit = extensions.jit(grad_origin, **grad_jit_kwargs[jit_grad])
  else:
    grad_jit = grad_origin

  ans = grad_jit(c, xs)
  expected = extensions.grad(functools.partial(loss, scan_reference))(c, xs)
  self.assertDTypesEqual(expected, ans)
  self.assertAllClose(expected, ans)

  # Cross-check the analytic Jacobian against finite differences.
  theoretical, numerical = tf.test.compute_gradient(
      to_tf_fn(functools.partial(losses, scan)), (c, xs))
  self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4)