def testAdvancedIndexingManually(self):
  x = np.random.RandomState(0).randn(3, 4, 5)
  index_array = np.array([0, 2, -1, 0])

  op = lambda x, index_array: x[..., index_array, :]
  cop = api.jit(op)

  a1 = op(x, index_array)
  a2 = cop(x, index_array)

  self.assertAllClose(a1, a2)

  op = lambda x, index_array: x[..., index_array, :, index_array, None]
  cop = api.jit(op)

  a1 = op(x, index_array)
  a2 = cop(x, index_array)

  self.assertAllClose(a1, a2)

  op = lambda x, index_array: x[index_array, ..., index_array[:, None], None]
  cop = api.jit(op)

  a1 = op(x, index_array)
  a2 = cop(x, index_array)

  self.assertAllClose(a1, a2)
def test_jit_on_nondefault_backend(self):
  cpus = api.devices("cpu")
  self.assertNotEmpty(cpus)

  # Since we are not on CPU, some other backend will be the default
  default_dev = api.devices()[0]
  self.assertNotEqual(default_dev.platform, "cpu")

  data_on_cpu = api.device_put(1, device=cpus[0])
  self.assertEqual(data_on_cpu.device_buffer.device(), cpus[0])

  def my_sin(x):
    return jnp.sin(x)

  # jit without any device spec follows the data
  result1 = api.jit(my_sin)(2)
  self.assertEqual(result1.device_buffer.device(), default_dev)
  result2 = api.jit(my_sin)(data_on_cpu)
  self.assertEqual(result2.device_buffer.device(), cpus[0])

  # jit with `device` spec places the data on the specified device
  result3 = api.jit(my_sin, device=cpus[0])(2)
  self.assertEqual(result3.device_buffer.device(), cpus[0])

  # jit with `backend` spec places the data on the specified backend
  result4 = api.jit(my_sin, backend="cpu")(2)
  self.assertEqual(result4.device_buffer.device(), cpus[0])
def test_closed_over_values_device_placement(self):
  # see https://github.com/google/jax/issues/1431
  def f():
    return jnp.add(3., 4.)

  self.assertNotEqual(api.jit(f)().device_buffer.device(),
                      api.devices('cpu')[0])
  self.assertEqual(api.jit(f, backend='cpu')().device_buffer.device(),
                   api.devices('cpu')[0])
def testFloatIndexingError(self):
  BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jnp.zeros(2)[0.]
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jnp.zeros((2, 2))[(0, 0.)]
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    jnp.zeros((2, 2))[(0, 0.)]
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    api.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    ops.index_add(jnp.zeros(2), 0., 1.)
  with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
    ops.index_update(jnp.zeros(2), 0., 1.)
def testIndexingEmptyDimension(self):
  # Issue 2671: XLA error when indexing into dimension of size 0
  x = jnp.ones((2, 0))
  # The following work, even on axis 1 of size 0
  _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2]

  with self.assertRaisesRegex(IndexError,
                              "index .* is out of bounds for axis .* with size 0"):
    _ = np.ones((2, 0))[0, 0]  # The numpy error
  with self.assertRaisesRegex(IndexError,
                              "index is out of bounds for axis .* with size 0"):
    _ = x[0, 0]  # JAX indexing
  with self.assertRaisesRegex(IndexError,
                              "index is out of bounds for axis .* with size 0"):
    api.jit(lambda i: x[0, i])(0)  # JAX indexing under jit
def testCategorical(self, p, axis, dtype, sample_shape):
  key = random.PRNGKey(0)
  p = np.array(p, dtype=dtype)
  logits = np.log(p) - 42  # test unnormalized
  out_shape = tuple(np.delete(logits.shape, axis))
  shape = sample_shape + out_shape
  rand = partial(random.categorical, shape=shape, axis=axis)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, logits)
  compiled_samples = crand(key, logits)

  if axis < 0:
    axis += len(logits.shape)

  for samples in [uncompiled_samples, compiled_samples]:
    assert samples.shape == shape
    samples = jnp.reshape(samples, (10000,) + out_shape)
    if len(p.shape[:-1]) > 0:
      ps = np.transpose(p, (1, 0)) if axis == 0 else p
      for cat_samples, cat_p in zip(samples.transpose(), ps):
        pmf = lambda x: np.where(x < len(cat_p),
                                 cat_p[np.minimum(len(cat_p) - 1, x)], 0.0)
        self._CheckChiSquared(cat_samples, pmf=pmf)
    else:
      pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
      self._CheckChiSquared(samples, pmf=pmf)
def testMultivariateNormal(self, dim, dtype, method):
  r = np.random.RandomState(dim)
  mean = r.randn(dim)
  cov_factor = r.randn(dim, dim)
  cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)

  key = random.PRNGKey(0)
  rand = partial(random.multivariate_normal, mean=mean, cov=cov,
                 shape=(10000,), method=method)
  crand = api.jit(rand)

  uncompiled_samples = np.asarray(rand(key), np.float64)
  compiled_samples = np.asarray(crand(key), np.float64)

  inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov), lower=True)[0]
  for samples in [uncompiled_samples, compiled_samples]:
    centered = samples - mean
    whitened = np.einsum('nj,ij->ni', centered, inv_scale)

    # This is a quick-and-dirty multivariate normality check that tests that a
    # uniform mixture of the marginals along the covariance matrix's
    # eigenvectors follows a standard normal distribution.
    self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf)
def testIssue187(self):
  x = jnp.ones((5, 5))
  x[[0, 2, 4], [0, 2, 4]]  # doesn't crash

  x = np.arange(25).reshape((5, 5))
  ans = api.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x)
  expected = x[[0, 2, 4], [0, 2, 4]]
  self.assertAllClose(ans, expected, check_dtypes=False)
def testGamma(self, a, dtype):
  key = random.PRNGKey(0)
  rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, a)
  compiled_samples = crand(key, a)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)
def testLogistic(self, dtype):
  key = random.PRNGKey(0)
  rand = lambda key: random.logistic(key, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf)
def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False):
  def fn(x1, x2):
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
    return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
  fn = jit(fn, inline=True)
  if lax_doc:
    doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
    return _wraps(numpy_fn, lax_description=doc)(fn)
  else:
    return _wraps(numpy_fn)(fn)
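# A minimal usage sketch of the helper above, assuming the lax module and the
# promotion/wrapping utilities from jax/_src/numpy/lax_numpy.py are in scope.
# It illustrates the pattern: a NumPy-style binop that dispatches to a
# boolean-specific lax primitive when both operands are bool; the exact set of
# functions defined this way should be checked against lax_numpy.py itself.
add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and)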
def testPareto(self, b, dtype):
  key = random.PRNGKey(0)
  rand = lambda key, b: random.pareto(key, b, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, b)
  compiled_samples = crand(key, b)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf)
def testT(self, df, dtype):
  key = random.PRNGKey(0)
  rand = lambda key, df: random.t(key, df, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, df)
  compiled_samples = crand(key, df)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)
def testRngUniform(self, dtype):
  key = random.PRNGKey(0)
  rand = lambda key: random.uniform(key, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
  if promote_to_inexact:
    fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
  else:
    fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
  fn = jit(fn, inline=True)
  if lax_doc:
    doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
    return _wraps(numpy_fn, lax_description=doc)(fn)
  else:
    return _wraps(numpy_fn)(fn)
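# A minimal usage sketch, again assuming the lax_numpy.py module context:
# _one_to_one_binop maps a NumPy binop directly onto a single lax primitive,
# optionally promoting the arguments to an inexact (floating/complex) dtype
# first. The two assignments below are illustrative examples of the pattern,
# not an exhaustive list; see lax_numpy.py for the definitive set.
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, promote_to_inexact=True)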
def test_convert_element_type(self):
  # Regression test for part of https://github.com/google/jax/issues/5982
  with enable_x64():
    x = jnp.int64(1)
    self.assertEqual(x.dtype, jnp.int64)

    y = x.astype(jnp.int32)
    self.assertEqual(y.dtype, jnp.int32)

    z = api.jit(lambda x: x.astype(jnp.int32))(x)
    self.assertEqual(z.dtype, jnp.int32)
def testBernoulli(self, p, dtype):
  key = random.PRNGKey(0)
  p = np.array(p, dtype=dtype)
  rand = lambda key, p: random.bernoulli(key, p, (10000,))
  crand = api.jit(rand)

  uncompiled_samples = rand(key, p)
  compiled_samples = crand(key, p)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)
def test_prng_seeds_and_keys(self, seed, type, jit, key):
  if (jit and type is int and not config.x64_enabled and
      (seed < np.iinfo('int32').min or seed > np.iinfo('int32').max)):
    self.skipTest("Expected failure: integer out of range for jit.")
  seed = type(seed)
  if jit:
    actual = api.jit(random.PRNGKey)(seed)
  else:
    actual = random.PRNGKey(seed)
  expected = jnp.array(key, dtype=jnp.uint32)
  self.assertArraysEqual(actual, expected)
def testNormalComplex(self, dtype):
  key = random.PRNGKey(0)
  rand = lambda key: random.normal(key, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(jnp.real(samples),
                                    scipy.stats.norm(scale=1 / np.sqrt(2)).cdf)
    self._CheckKolmogorovSmirnovCDF(jnp.imag(samples),
                                    scipy.stats.norm(scale=1 / np.sqrt(2)).cdf)
    self.assertEqual(dtype, samples.dtype)
def testUnpacking(self):

  def foo(x):
    a, b, c = x
    return a + b + c

  cfoo = api.jit(foo)

  a1 = foo(np.arange(3))
  a2 = cfoo(np.arange(3))

  self.assertAllClose(a1, a2)
def testBeta(self, a, b, dtype):
  if not config.x64_enabled:
    raise SkipTest("skip test except on X64")
  key = random.PRNGKey(0)
  rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, a, b)
  compiled_samples = crand(key, a, b)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf)
def test_impl(self, ad="simple"):
  call_tf = CALL_TF_IMPLEMENTATIONS[ad]

  def f_jax(x):
    return jnp.sin(x)

  def f_outside(x):
    return call_tf(tf.math.sin, x, result_shape=x)

  res = f_outside(3.)
  self.assertAllClose(f_jax(3.), res)
  self.assertAllClose(f_jax(3.), api.jit(f_outside)(3.))
def testDirichlet(self, alpha, dtype):
  key = random.PRNGKey(0)
  rand = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, alpha)
  compiled_samples = crand(key, alpha)

  for samples in [uncompiled_samples, compiled_samples]:
    self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype))
    alpha_sum = sum(alpha)
    for i, a in enumerate(alpha):
      self._CheckKolmogorovSmirnovCDF(samples[..., i],
                                      scipy.stats.beta(a, alpha_sum - a).cdf)
def testPermutationInteger(self):
  key = random.PRNGKey(0)
  x = 100
  rand = lambda key: random.permutation(key, x)
  crand = api.jit(rand)

  perm1 = rand(key)
  perm2 = crand(key)

  self.assertAllClose(perm1, perm2)
  self.assertEqual(perm1.dtype, perm2.dtype)
  self.assertFalse(np.all(perm1 == np.arange(100)))  # seems unlikely!
  self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False)
def testPoisson(self, lam, dtype):
  key = random.PRNGKey(0)
  rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, lam)
  compiled_samples = crand(key, lam)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
    # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
    # based on the central limit theorem).
    self.assertAllClose(samples.mean(), lam, rtol=0.01, check_dtypes=False)
    self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
def testRngRandint(self, dtype):
  lo = 5
  hi = 10

  key = random.PRNGKey(0)
  rand = lambda key: random.randint(key, (10000,), lo, hi, dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  for samples in [uncompiled_samples, compiled_samples]:
    self.assertTrue(np.all(lo <= samples))
    self.assertTrue(np.all(samples < hi))
def testShuffle(self, dtype):
  key = random.PRNGKey(0)
  x = np.arange(100).astype(dtype)
  rand = lambda key: random.shuffle(key, x)
  crand = api.jit(rand)

  with self.assertWarns(FutureWarning):
    perm1 = rand(key)
  with self.assertWarns(FutureWarning):
    perm2 = crand(key)

  self.assertAllClose(perm1, perm2)
  self.assertFalse(np.all(perm1 == x))  # seems unlikely!
  self.assertAllClose(np.sort(perm1), x, check_dtypes=False)
def testTruncatedNormal(self, dtype):
  key = random.PRNGKey(0)
  rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key)
  compiled_samples = crand(key)

  min_val = np.min(uncompiled_samples)
  max_val = np.max(uncompiled_samples)
  self.assertTrue(min_val > -0.3)
  self.assertTrue(max_val < 0.3)
  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.truncnorm(-0.3, 0.3).cdf)
def testPermutationArray(self, dtype, shape):
  key = random.PRNGKey(0)
  x = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)
  rand = lambda key: random.permutation(key, x)
  crand = api.jit(rand)

  perm1 = rand(key)
  perm2 = crand(key)

  self.assertAllClose(perm1, perm2)
  if x.shape[0] > 1:
    self.assertFalse(np.all(perm1 == x))  # seems unlikely!
  self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
  self.assertArraysAllClose(
      x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype))
def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True,
                     rtol=None, atol=None, check_cache_misses=True):
  """Helper method for running JAX compilation and allclose assertions."""
  args = args_maker()

  def wrapped_fun(*args):
    self.assertTrue(python_should_be_executing)
    return fun(*args)

  python_should_be_executing = True
  python_ans = fun(*args)

  python_shapes = tree_map(lambda x: np.shape(x), python_ans)
  np_shapes = tree_map(lambda x: np.shape(np.asarray(x)), python_ans)
  self.assertEqual(python_shapes, np_shapes)

  cache_misses = dispatch.xla_primitive_callable.cache_info().misses
  python_ans = fun(*args)
  if check_cache_misses:
    self.assertEqual(
        cache_misses, dispatch.xla_primitive_callable.cache_info().misses,
        "Compilation detected during second call of {} in op-by-op "
        "mode.".format(fun))

  cfun = api.jit(wrapped_fun)
  python_should_be_executing = True
  monitored_ans = cfun(*args)

  python_should_be_executing = False
  compiled_ans = cfun(*args)

  self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes,
                      atol=atol, rtol=rtol)
  self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
                      atol=atol, rtol=rtol)

  args = args_maker()

  python_should_be_executing = True
  python_ans = fun(*args)

  python_should_be_executing = False
  compiled_ans = cfun(*args)

  self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
                      atol=atol, rtol=rtol)
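# A minimal usage sketch of _CompileAndCheck. The test name and the argument
# shapes below are hypothetical; the example only shows the args_maker
# convention used throughout the JAX test suite: a zero-argument callable that
# produces fresh inputs, so the helper can run `fun` both op-by-op and under
# jit and compare the results.
def testDotCompiles(self):
  args_maker = lambda: [np.arange(6.).reshape(2, 3),
                        np.arange(12.).reshape(3, 4)]
  self._CompileAndCheck(jnp.dot, args_maker)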