def testMultiBackend(self, backend):
  if backend not in ('cpu', jtu.device_under_test(), None):
    raise SkipTest("Backend is not CPU or the device under test")

  @partial(jax.jit, backend=backend)
  def fun(x, y):
    return jnp.matmul(x, y)

  x = npr.uniform(size=(10, 10))
  y = npr.uniform(size=(10, 10))
  z_host = np.matmul(x, y)
  z = fun(x, y)
  self.assertAllClose(z, z_host, rtol=1e-2)
  correct_platform = backend if backend else jtu.device_under_test()
  self.assertEqual(z.device_buffer.platform(), correct_platform)
def test_custom_linear_solve_cholesky(self):
  def positive_definite_solve(a, b):
    factors = jsp.linalg.cho_factor(a)
    def solve(matvec, x):
      return jsp.linalg.cho_solve(factors, x)
    matvec = partial(high_precision_dot, a)
    return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

  rng = self.rng()
  a = rng.randn(2, 2)
  b = rng.randn(2)

  tol = {np.float32: 1E-3 if jtu.device_under_test() == "tpu" else 1E-5,
         np.float64: 1E-12}
  expected = jnp.linalg.solve(np.asarray(posify(a)), b)
  actual = positive_definite_solve(posify(a), b)
  self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  actual = jax.jit(positive_definite_solve)(posify(a), b)
  self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  # numerical gradients are only well defined if ``a`` is guaranteed to be
  # positive definite.
  jtu.check_grads(
      lambda x, y: positive_definite_solve(posify(x), y),
      (a, b), order=2, rtol=0.3)
def test_ragged_batched_rnn(self):
  n = 3

  @partial(mask, in_shapes=('(_, _)', '(t, _)', '_'), out_shape='')
  def rnn(W, xs, target):
    def step(h, x):
      new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
      return new_h, ()
    predicted, _ = lax.scan(step, jnp.zeros(n), xs)
    return jnp.sum((predicted - target)**2)

  rng = self.rng()
  W = rng.randn(n, n).astype(jnp.float_)
  seqs = rng.randn(3, 10, n).astype(jnp.float_)
  ts = jnp.array([2, 5, 4])
  ys = rng.randn(3, n)
  ans = grad(lambda W: vmap(rnn, ((None, 0, 0), 0))(
      (W, seqs, ys), dict(t=ts)).sum())(W)

  def rnn_reference(W, seqs, targets):
    total_loss = jnp.array(0.0)
    for xs, target in zip(seqs, targets):
      h = jnp.zeros(n)
      for x in xs:
        h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
      predicted = h
      total_loss = total_loss + jnp.sum((predicted - target)**2)
    return total_loss

  seqs_ = [xs[:t] for xs, t in zip(seqs, ts)]
  expected = grad(lambda W: rnn_reference(W, seqs_, ys).sum())(W)

  self.assertAllClose(
      ans, expected, check_dtypes=False,
      rtol=0.1 if jtu.device_under_test() == "tpu" else 1e-5)
def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
  if jtu.device_under_test() != "cpu":
    rng = jtu.rand_some_inf_and_nan(self.rng())
  else:
    rng = jtu.rand_default(self.rng())
  # TODO(mattjj): test autodiff
  if use_b:
    def scipy_fun(array_to_reduce, scale_array):
      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                   return_sign=return_sign, b=scale_array)

    def lax_fun(array_to_reduce, scale_array):
      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                   return_sign=return_sign, b=scale_array)

    args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
  else:
    def scipy_fun(array_to_reduce):
      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                   return_sign=return_sign)

    def lax_fun(array_to_reduce):
      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                   return_sign=return_sign)

    args_maker = lambda: [rng(shapes[0], dtype)]
  tol = {np.float32: 1E-6, np.float64: 1E-14}
  self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
  self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
def test_function_dynamic_shape(self):
  if jtu.device_under_test() == "tpu":
    raise unittest.SkipTest("TODO: why does this fail on TPU?")
  # Call a function for which shape inference does not give an output shape.
  x = np.array([-1, 0, 1], dtype=np.int32)
  def fun_tf(x):  # x: i32[3]
    # The output shape depends on the value of x.
    res = tf.where(x >= 0)
    return res

  # Call in eager mode. Should work!
  res1 = jax2tf.call_tf(fun_tf)(x)
  expected = np.array([[1], [2]])
  self.assertAllClose(expected, res1, check_dtypes=False)

  # Now under jit, should fail because the function is not compileable.
  with self.assertRaisesRegex(
      ValueError,
      "Error compiling TensorFlow function. call_tf can used in a staged context"):
    fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
    fun_jax(x)

  # TODO(necula): this should work in op-by-op mode, but it fails because
  # jax2tf.convert does abstract evaluation.
  with self.assertRaisesRegex(
      ValueError,
      "Error compiling TensorFlow function. call_tf can used in a staged context"):
    fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf))
    fun_tf_rt(x)
def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                        test_autodiff, nondiff_argnums):
  if (jtu.device_under_test() == "cpu" and
      (lax_op is lsp_special.gammainc or lax_op is lsp_special.gammaincc)):
    # TODO(b/173608403): re-enable test when LLVM bug is fixed.
    raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
  rng = rng_factory(self.rng())
  args_maker = self._GetArgsMaker(rng, shapes, dtypes)
  args = args_maker()
  self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                      check_dtypes=False)
  self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

  if test_autodiff:
    def partial_lax_op(*vals):
      list_args = list(vals)
      for i in nondiff_argnums:
        list_args.insert(i, args[i])
      return lax_op(*list_args)

    assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
    diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
    jtu.check_grads(partial_lax_op, diff_args, order=1,
                    atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                    rtol=.1, eps=1e-3)
def test_jax_implemented(self, harness: primitive_harness.Harness):
  """Runs all harnesses just with JAX to verify the jax_unimplemented field."""
  jax_unimpl = [l for l in harness.jax_unimplemented
                if l.filter(device=jtu.device_under_test(),
                            dtype=harness.dtype)]
  if any(lim.skip_run for lim in jax_unimpl):
    logging.info(
        "Skipping run with expected JAX limitations: %s in harness %s",
        [u.description for u in jax_unimpl], harness.fullname)
    return
  try:
    harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
  except Exception as e:
    if jax_unimpl:
      logging.info(
          "Found expected JAX error %s with expected JAX limitations: "
          "%s in harness %s", e,
          [u.description for u in jax_unimpl], harness.fullname)
      return
    else:
      raise e
  if jax_unimpl:
    logging.warning(
        "Found no JAX error but expected JAX limitations: %s in "
        "harness: %s",
        [u.description for u in jax_unimpl], harness.fullname)
def testDetrend(self, shape, dtype, axis, type, bp):
  signal = np.random.normal(loc=2, size=shape)

  if type == 'constant':
    trend = np.ones_like(signal)
  elif type == 'linear':
    trend = np.linspace(-np.pi, np.pi, shape[0])
    if len(shape) == 1:
      trend = np.broadcast_to(trend, shape)
    elif len(shape) == 2:
      trend = np.broadcast_to(trend[:, None], shape)
    elif len(shape) == 3:
      trend = np.broadcast_to(trend[:, None, None], shape)

  args_maker = lambda: [signal, trend]

  def osp_fun(signal, trend):
    return osp_signal.detrend(signal + trend, axis=axis, type=type, bp=bp) - trend

  def jsp_fun(signal, trend):
    return jsp_signal.detrend(signal + trend, axis=axis, type=type, bp=bp) - trend

  if jtu.device_under_test() == 'tpu':
    tol = {np.float32: 3e-2, np.float64: 1e-12}
  else:
    tol = {np.float32: 1e-5, np.float64: 1e-12}

  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
  self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
def testWelchWithDefaultStepArgsAgainstNumpy(self, *, shape, dtype, nperseg,
                                             noverlap, use_nperseg,
                                             use_noverlap, timeaxis):
  kwargs = {'axis': timeaxis}

  if use_nperseg:
    kwargs['nperseg'] = nperseg
  else:
    kwargs['window'] = osp_signal.get_window('hann', nperseg)
  if use_noverlap:
    kwargs['noverlap'] = noverlap

  osp_fun = partial(osp_signal.welch, **kwargs)
  jsp_fun = partial(jsp_signal.welch, **kwargs)
  tol = {
      np.float32: 1e-5, np.float64: 1e-12,
      np.complex64: 1e-5, np.complex128: 1e-12
  }
  if jtu.device_under_test() == 'tpu':
    tol = _TPU_FFT_TOL

  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, dtype)]

  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
  self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
def setUp(self):
  super().setUp()
  if jtu.device_under_test() != "tpu":
    raise SkipTest("serialize executable only works on TPU")
  if jax._src.lib.xla_bridge.get_backend().runtime_type == "tfrt":
    raise SkipTest("the new TFRT runtime does not support serialization")
def test_spectral_dac_svd(self, linear_size, seed, dtype):
  if jnp.dtype(dtype).name in ("bfloat16", "float16"):
    if jtu.device_under_test() != "cpu":
      raise unittest.SkipTest("Skip half precision off CPU.")

  rng = self.rng()
  A = rng.randn(linear_size, linear_size).astype(dtype)
  if jnp.dtype(dtype).name in ("bfloat16", "float16"):
    self.assertRaises(NotImplementedError, jax._src.scipy.eigh.svd, A)
    return
  S_expected = np.linalg.svd(A, compute_uv=False)
  U, S, V = jax._src.scipy.eigh.svd(A)
  recon = jnp.dot((U * jnp.expand_dims(S, 0)), V,
                  precision=lax.Precision.HIGHEST)
  eps = jnp.finfo(dtype).eps
  eps = eps * jnp.linalg.norm(A) * 15
  self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
  self.assertAllClose(A, recon, atol=eps)

  # U is unitary.
  u_unitary_delta = jnp.dot(U.conj().T, U, precision=lax.Precision.HIGHEST)
  u_eye = jnp.eye(u_unitary_delta.shape[0], dtype=dtype)
  self.assertAllClose(u_unitary_delta, u_eye, atol=eps)

  # V is unitary.
  v_unitary_delta = jnp.dot(V.conj().T, V, precision=lax.Precision.HIGHEST)
  v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype)
  self.assertAllClose(v_unitary_delta, v_eye, atol=eps)
def testTensorFlowToJax(self, shape, dtype):
  if (not config.x64_enabled and
      dtype in [jnp.int64, jnp.uint64, jnp.float64]):
    raise self.skipTest("x64 types are disabled by jax_enable_x64")
  if (jtu.device_under_test() == "gpu" and
      not tf.config.list_physical_devices("GPU")):
    raise self.skipTest("TensorFlow not configured with GPU support")

  if jtu.device_under_test() == "gpu" and dtype == jnp.int32:
    raise self.skipTest("TensorFlow does not place int32 tensors on GPU")

  rng = jtu.rand_default(self.rng())
  np = rng(shape, dtype)
  with tf.device("/GPU:0" if jtu.device_under_test() == "gpu" else "/CPU:0"):
    x = tf.identity(tf.constant(np))
  dlpack = tf.experimental.dlpack.to_dlpack(x)
  y = jax.dlpack.from_dlpack(dlpack)
  self.assertAllClose(np, y)
def test_eval_numpy_no_copy(self):
  if jtu.device_under_test() != "cpu":
    raise unittest.SkipTest("no_copy test works only on CPU")
  # For ndarray, zero-copy only works for sufficiently-aligned arrays.
  x = np.ones((16, 16), dtype=np.float32)
  res = jax2tf.call_tf(lambda x: x)(x)
  self.assertAllClose(x, res)
  self.assertTrue(np.shares_memory(x, res))
def test_eval_devicearray_no_copy(self):
  if jtu.device_under_test() != "cpu":
    # TODO(necula): add tests for GPU and TPU
    raise unittest.SkipTest("no_copy test works only on CPU")
  # For DeviceArray, zero-copy works even if the buffer is not aligned.
  x = jnp.ones((3, 3))
  res = jax2tf.call_tf(lambda x: x)(x)
  self.assertAllClose(x, res)
  self.assertTrue(np.shares_memory(x, res))
def testPolar(self, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape,
              method, side, nonzero_condition_number, dtype, seed):
  """Tests jax.scipy.linalg.polar."""
  if jtu.device_under_test() != "cpu":
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      raise unittest.SkipTest("Skip half precision off CPU.")

  m, n = shape
  if (method == "qdwh" and ((side == "left" and m >= n) or
                            (side == "right" and m < n))):
    raise unittest.SkipTest("method=qdwh does not support these sizes")

  matrix, _ = _initialize_polar_test(self.rng(), shape, n_zero_sv, degeneracy,
                                     geometric_spectrum, max_sv,
                                     nonzero_condition_number, dtype)
  if jnp.dtype(dtype).name in ("bfloat16", "float16"):
    self.assertRaises(NotImplementedError, jsp.linalg.polar, matrix,
                      method=method, side=side)
    return

  unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
  if shape[0] >= shape[1]:
    should_be_eye = np.matmul(unitary.conj().T, unitary)
  else:
    should_be_eye = np.matmul(unitary, unitary.conj().T)
  tol = 500 * float(jnp.finfo(matrix.dtype).eps)
  eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
  with self.subTest('Test unitarity.'):
    self.assertAllClose(eye_mat, should_be_eye, atol=tol * min(shape))

  with self.subTest('Test Hermiticity.'):
    self.assertAllClose(posdef, posdef.conj().T,
                        atol=tol * jnp.linalg.norm(posdef))

  ev, _ = np.linalg.eigh(posdef)
  ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
  negative_ev = jnp.sum(ev < 0.)
  with self.subTest('Test positive definiteness.'):
    self.assertEqual(negative_ev, 0)

  if side == "right":
    recon = jnp.matmul(unitary, posdef, precision=lax.Precision.HIGHEST)
  elif side == "left":
    recon = jnp.matmul(posdef, unitary, precision=lax.Precision.HIGHEST)
  with self.subTest('Test reconstruction.'):
    self.assertAllClose(matrix, recon, atol=tol * jnp.linalg.norm(matrix))
def testMultiBackendNestedJitConflict(self, ordering):
  outer, inner = ordering
  if outer not in ('cpu', jtu.device_under_test(), None):
    raise SkipTest("Backend is not CPU or the device under test")
  if inner not in ('cpu', jtu.device_under_test(), None):
    raise SkipTest("Backend is not CPU or the device under test")

  @partial(jax.jit, backend=outer)
  def fun(x, y):
    @partial(jax.jit, backend=inner)
    def infun(x, y):
      return jnp.matmul(x, y)
    return infun(x, y) + jnp.ones_like(x)

  x = npr.uniform(size=(10, 10))
  y = npr.uniform(size=(10, 10))
  # The inner and outer backend specifications conflict, so the call must fail.
  self.assertRaises(ValueError, lambda: fun(x, y))
def testCsdWithSameParamAgainstNumpy(self, *, shape, dtype, fs, window,
                                     nperseg, noverlap, nfft, detrend,
                                     scaling, timeaxis, average):
  is_complex = np.dtype(dtype).kind == 'c'
  if is_complex and detrend is not None:
    raise unittest.SkipTest(
        "Complex signal is not supported in lax-backed `signal.detrend`.")

  def osp_fun(x, y):
    # When identical parameters are given, the jsp version follows the
    # behavior obtained with copied parameters.
    freqs, Pxy = osp_signal.csd(x, y.copy(), fs=fs, window=window,
                                nperseg=nperseg, noverlap=noverlap, nfft=nfft,
                                detrend=detrend,
                                return_onesided=not is_complex,
                                scaling=scaling, axis=timeaxis,
                                average=average)
    return freqs, Pxy

  jsp_fun = partial(jsp_signal.csd, fs=fs, window=window, nperseg=nperseg,
                    noverlap=noverlap, nfft=nfft, detrend=detrend,
                    return_onesided=not is_complex, scaling=scaling,
                    axis=timeaxis, average=average)
  tol = {
      np.float32: 1e-5, np.float64: 1e-12,
      np.complex64: 1e-5, np.complex128: 1e-12
  }
  if jtu.device_under_test() == 'tpu':
    tol = _TPU_FFT_TOL

  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, dtype)] * 2

  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
  self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
def testTorchToJax(self, shape, dtype):
  if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
    self.skipTest("x64 types are disabled by jax_enable_x64")
  rng = jtu.rand_default(self.rng())
  np = rng(shape, dtype)
  x = torch.from_numpy(np)
  x = x.cuda() if jtu.device_under_test() == "gpu" else x
  dlpack = torch.utils.dlpack.to_dlpack(x)
  y = jax.dlpack.from_dlpack(dlpack)
  self.assertAllClose(np, y)
def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                          noverlap, nfft, onesided, boundary, timeaxis,
                          freqaxis):
  if not onesided:
    new_freq_len = (shape[freqaxis] - 1) * 2
    shape = shape[:freqaxis] + (new_freq_len,) + shape[freqaxis + 1:]

  def osp_fun(x, fs):
    # Ignore UserWarning in osp so we can also test over ill-posed cases.
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")
      result = osp_signal.istft(x, fs=fs, window=window, nperseg=nperseg,
                                noverlap=noverlap, nfft=nfft,
                                input_onesided=onesided, boundary=boundary,
                                time_axis=timeaxis, freq_axis=freqaxis)
    return result

  jsp_fun = partial(jsp_signal.istft, window=window, nperseg=nperseg,
                    noverlap=noverlap, nfft=nfft, input_onesided=onesided,
                    boundary=boundary, time_axis=timeaxis, freq_axis=freqaxis)

  tol = {
      np.float32: 1e-4, np.float64: 1e-6,
      np.complex64: 1e-4, np.complex128: 1e-6
  }
  if jtu.device_under_test() == 'tpu':
    tol = _TPU_FFT_TOL

  rng = jtu.rand_default(self.rng())
  rng_fs = jtu.rand_uniform(self.rng(), 1.0, 16000.0)
  args_maker = lambda: [rng(shape, dtype), rng_fs((), np.float64)]

  # The dtype of the output signal differs across osp versions, and hence
  # across test environments, so the dtype check is disabled.
  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol,
                          check_dtypes=False)
  self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
def testMultiBackendNestedJit(self, ordering):
  outer, inner = ordering
  if outer not in ('cpu', jtu.device_under_test(), None):
    raise SkipTest("Backend is not CPU or the device under test")

  @partial(jax.jit, backend=outer)
  def fun(x, y):
    @partial(jax.jit, backend=inner)
    def infun(x, y):
      return jnp.matmul(x, y)
    return infun(x, y) + jnp.ones_like(x)

  x = npr.uniform(size=(10, 10))
  y = npr.uniform(size=(10, 10))
  z_host = np.matmul(x, y) + np.ones_like(x)
  z = fun(x, y)
  self.assertAllClose(z, z_host, rtol=1e-2)
  correct_platform = outer if outer else jtu.device_under_test()
  self.assertEqual(z.device_buffer.platform(), correct_platform)
def test_jit_cache(self):
  if jtu.device_under_test() == "tpu":
    self.skipTest("64-bit random not available on TPU")
  f = partial(random.uniform, random.PRNGKey(0), (1,), 'float64', -1, 1)
  with disable_x64():
    for _ in range(2):
      f()
  with enable_x64():
    for _ in range(2):
      f()
def test_prim(self, harness: primitive_harness.Harness):
  limitations = Jax2TfLimitation.limitations_for_harness(harness)
  device = jtu.device_under_test()
  limitations = tuple(filter(lambda l: l.filter(device=device,
                                                dtype=harness.dtype),
                             limitations))
  func_jax = harness.dyn_fun
  args = harness.dyn_args_maker(self.rng())
  enable_xla = harness.params.get("enable_xla", True)
  associative_scan_reductions = harness.params.get(
      "associative_scan_reductions", False)
  with jax.jax2tf_associative_scan_reductions(associative_scan_reductions):
    self.ConvertAndCompare(func_jax, *args, limitations=limitations,
                           enable_xla=enable_xla)
def test_with_var_read_x64(self, with_jit=True):
  if jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("Test fails on GPU")
  outer_var_array = np.array([3., 4.], dtype=np.float64)
  outer_var = tf.Variable(outer_var_array)

  def fun_tf(x):
    return x * tf.cast(outer_var, x.dtype) + 1.

  x = np.array([2., 5.], dtype=np.float32)
  res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
  self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)
def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                          noverlap, nfft, detrend, return_onesided,
                          scaling, timeaxis, average):
  if np.dtype(dtype).kind == 'c':
    return_onesided = False
    if detrend is not None:
      raise unittest.SkipTest(
          "Complex signal is not supported in lax-backed `signal.detrend`.")

  osp_fun = partial(osp_signal.welch, fs=fs, window=window, nperseg=nperseg,
                    noverlap=noverlap, nfft=nfft, detrend=detrend,
                    return_onesided=return_onesided, scaling=scaling,
                    axis=timeaxis, average=average)
  jsp_fun = partial(jsp_signal.welch, fs=fs, window=window, nperseg=nperseg,
                    noverlap=noverlap, nfft=nfft, detrend=detrend,
                    return_onesided=return_onesided, scaling=scaling,
                    axis=timeaxis, average=average)
  tol = {
      np.float32: 1e-5, np.float64: 1e-12,
      np.complex64: 1e-5, np.complex128: 1e-12
  }
  if jtu.device_under_test() == 'tpu':
    tol = _TPU_FFT_TOL

  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [rng(shape, dtype)]

  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
  self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
def test_with_multiple_capture(self, with_jit=True):
  if jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("Test fails on GPU")
  v2 = tf.Variable(2., dtype=np.float32)
  v3 = tf.Variable(3., dtype=np.float32)
  t4 = tf.constant(4., dtype=np.float32)
  t5 = tf.constant(5., dtype=np.float32)

  def fun_tf(x):
    return (x * v3 + t4 + v2) * v3 + t5

  x = np.float32(2.)
  res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
  self.assertAllClose((x * 3. + 4. + 2.) * 3. + 5., res, check_dtypes=False)
def testGpuMultiBackendOpByOpReturn(self, backend):
  if backend not in ('cpu', jtu.device_under_test()):
    raise SkipTest("Backend is not CPU or the device under test")

  @partial(jax.jit, backend=backend)
  def fun(x, y):
    return jnp.matmul(x, y)

  x = npr.uniform(size=(10, 10))
  y = npr.uniform(size=(10, 10))
  z = fun(x, y)
  w = jnp.sin(z)
  self.assertEqual(z.device_buffer.platform(), backend)
  self.assertEqual(w.device_buffer.platform(), backend)
def test_with_var_different_shape(self):
  # See https://github.com/google/jax/issues/6050
  if jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("Test fails on GPU")

  v = tf.Variable((4., 2.), dtype=tf.float32)

  def tf_func(x):
    return v + x

  x = np.float32(123.)
  tf_out = tf_func(x)

  jax_func = jax.jit(jax2tf.call_tf(tf_func))
  jax_out = jax_func(x)

  self.assertAllClose(tf_out, jax_out, check_dtypes=False)
def test_spectral_dac_eigh(self, linear_size, seed, dtype, termination_size):
  if jnp.dtype(dtype).name in ("bfloat16", "float16"):
    if jtu.device_under_test() != "cpu":
      raise unittest.SkipTest("Skip half precision off CPU.")
  if jtu.device_under_test() != "tpu" and termination_size != 1:
    raise unittest.SkipTest(
        "Termination sizes greater than 1 only work on TPU")

  rng = self.rng()
  H = rng.randn(linear_size, linear_size)
  H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
  if jnp.dtype(dtype).name in ("bfloat16", "float16"):
    self.assertRaises(NotImplementedError, jax._src.lax.eigh.eigh, H)
    return
  evs, V = jax._src.lax.eigh.eigh(H, termination_size=termination_size)
  ev_exp, eV_exp = jnp.linalg.eigh(H)
  HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
  vV = evs[None, :] * V
  eps = jnp.finfo(H.dtype).eps
  atol = jnp.linalg.norm(H) * eps
  self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
  self.assertAllClose(HV, vV, atol=30 * atol)
def setUp(self):
  super().setUp()
  # Ensure that all TF ops are created on the proper device (TPU, GPU, or CPU).
  tf_preferred_devices = (
      tf.config.list_logical_devices("TPU") +
      tf.config.list_logical_devices("GPU") +
      tf.config.list_logical_devices())
  self.tf_default_device = tf_preferred_devices[0]
  logging.info("Running jax2tf converted code on %s.", self.tf_default_device)
  # We need the --config=cuda build flag for TF to see the GPUs.
  self.assertEqual(jtu.device_under_test().upper(),
                   self.tf_default_device.device_type)

  with contextlib.ExitStack() as stack:
    stack.enter_context(tf.device(self.tf_default_device))
    self.addCleanup(stack.pop_all().close)
def testProgrammaticProfilingContextManager(self):
  with tempfile.TemporaryDirectory() as tmpdir:
    with jax.profiler.trace(tmpdir):
      jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'),
               axis_name='i')(jnp.ones(jax.local_device_count()))

    proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"),
                           recursive=True)
    self.assertEqual(len(proto_path), 1)
    with open(proto_path[0], "rb") as f:
      proto = f.read()
    # Sanity check that the serialized proto contains host and device traces
    # without deserializing.
    self.assertIn(b"/host:CPU", proto)
    if jtu.device_under_test() == "tpu":
      self.assertIn(b"/device:TPU", proto)