def testForiLoopTupleState(self):
  def sum_first_n(arr, num):
    def body_fun(i, state):
      arr, total = state
      arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
      return (arr, lax.add(total, arr_i))

    init_val = (arr, 0.)
    _, total = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun,
                             init_val)
    return total

  cfun = api.jit(sum_first_n)
  x = npr.RandomState(0).randn(10)

  for num in [0, 5, 10, 15]:
    self.assertAllClose(sum_first_n(x, num), onp.sum(x[:num]),
                        check_dtypes=False)
    self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
    self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
def test_custom_root_vector_with_solve_closure(self):
  def vector_solve(f, y):
    return np.linalg.solve(api.jacobian(f)(y), y)

  def linear_solve(a, b):
    f = lambda y: high_precision_dot(a, y) - b
    x0 = np.zeros_like(b)
    solution = np.linalg.solve(a, b)
    # The oracle ignores its arguments and returns the precomputed solution
    # captured in its closure.
    oracle = lambda func, x0: solution
    return lax.custom_root(f, x0, oracle, vector_solve)

  rng = onp.random.RandomState(0)
  a = rng.randn(2, 2)
  b = rng.randn(2)
  jtu.check_grads(linear_solve, (a, b), order=2)

  actual = api.jit(linear_solve)(a, b)
  expected = np.linalg.solve(a, b)
  self.assertAllClose(expected, actual, check_dtypes=True)
def test_jit_interleaving(self):
  # Several jit's without data dependencies; they may interfere
  count = 0  # Count tap invocations
  nr_arrays = 5

  def tap_func(arg, **kwargs):
    nonlocal count
    assert len(arg) == nr_arrays
    count += 1

  # This is the function that we'll run multiple times
  def func(x, count):
    for i in range(count):
      x = hcb.id_tap(tap_func, [x + i for i in range(nr_arrays)], i=i)[-1]
    return x

  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    x = jnp.array(1, dtype=np.int32)
    res = 0
    for i in range(10):
      # No dependencies between the jit invocations
      res += api.jit(lambda x: func(x, 10))(x)
  self.assertEqual(100, count)
def testPoisson(self, lam, dtype):
  key = random.PRNGKey(0)
  rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, lam)
  compiled_samples = crand(key, lam)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
    # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
    # based on the central limit theorem).
    self.assertAllClose(samples.mean(), lam, rtol=0.01, check_dtypes=False)
    self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
def testWeibullSample(self, concentration, scale):
  num_samples = 10**5
  rng = random.PRNGKey(0)

  rand = lambda x: random.weibull_min(x, scale, concentration, (num_samples,))
  crand = api.jit(rand)

  loc = scipy.stats.weibull_min.mean(c=concentration, scale=scale)
  std = scipy.stats.weibull_min.std(c=concentration, scale=scale)

  uncompiled_samples = rand(rng)
  compiled_samples = crand(rng)
  for samples in [uncompiled_samples, compiled_samples]:
    # Check first and second moments.
    self.assertEqual((num_samples,), samples.shape)
    self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1)
    self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)
    self._CheckKolmogorovSmirnovCDF(
        samples, scipy.stats.weibull_min(c=concentration, scale=scale).cdf)
def test_grad_of_jit_compilation_caching(self):
  if not hasattr(self, "assertLogs"):
    raise unittest.SkipTest("test requires assertLogs (python 3)")

  lax.add(1, 2)  # make sure some initial warnings are already printed

  sin = api.jit(np.sin)

  prev_level = logging.get_verbosity()
  try:
    logging.set_verbosity('DEBUG')
    with self.assertLogs(level=logging.DEBUG) as l:
      ans1 = api.grad(sin)(2.)
      ans2 = api.grad(sin)(3.)
  finally:
    logging.set_verbosity(prev_level)
  self.assertLen(l.output, 2)

  self.assertAllClose(ans1, onp.cos(2.), check_dtypes=False)
  self.assertAllClose(ans2, onp.cos(3.), check_dtypes=False)
def test_jit_unknown_tap(self):
  # Simulate an unknown tap function
  def func(x):
    x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    x2 = hcb.id_tap(hcb._unknown_testing_consumer, x1 + 1, what="err")
    x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
    return x3

  with self.assertRaises(hcb.TapFunctionException):
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
      res = api.jit(func)(0)
      # Even though the receiver thread raised, the main thread should still
      # return 3.
      self.assertEqual(3, res)
  # We should have received all others
  assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
  testing_stream.reset()
def test_loop_1(self):
  """One loop with one state var, with transforms."""
  def f_op(inc):
    with loops.Scope() as s:
      s.out = 10.
      for _ in s.range(5):
        s.out += inc
      return s.out

  def f_expected(inc):
    return 10 + 5 * inc

  self.assertAllClose(f_expected(2.), f_op(2.))
  self.assertAllClose(f_expected(2.), api.jit(f_op)(2.))
  self.assertAllClose(5., api.grad(f_op)(2.))
  self.assertAllClose(5., api.grad(f_op)(2.))
  inc_batch = np.arange(5, dtype=jnp.float_)
  self.assertAllClose(
      jnp.array([f_expected(inc) for inc in inc_batch], dtype=jnp.float_),
      api.vmap(f_op)(inc_batch))
def test_while(self):
  def f_op(init):
    with loops.Scope() as s:
      s.out = init
      for _ in s.while_range(lambda: s.out < 5.):
        s.out += 2.
      s.out += 1.
      return s.out

  def f_expected(init):
    out = init
    while out < 5.:
      out += 2.
    out += 1.
    return out

  self.assertAllClose(f_expected(2.), f_op(2.), check_dtypes=True)
  self.assertAllClose(f_expected(2.), api.jit(f_op)(2.), check_dtypes=True)
  self.assertAllClose(f_expected(1.), f_op(1.), check_dtypes=True)
  init_batch = np.array([1., 2., 3.])
  self.assertAllClose(np.array([f_expected(init) for init in init_batch]),
                      api.vmap(f_op)(init_batch), check_dtypes=True)
def test_root_scalar(self):

  def scalar_solve(f, y):
    return y / f(1.0)

  def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
    del x0  # unused

    def cond(state):
      low, high = state
      return high - low > tolerance

    def body(state):
      low, high = state
      midpoint = 0.5 * (low + high)
      update_upper = func(midpoint) > 0
      low = np.where(update_upper, low, midpoint)
      high = np.where(update_upper, midpoint, high)
      return (low, high)

    solution, _ = lax.while_loop(cond, body, (low, high))
    return solution

  def sqrt_cubed(x, tangent_solve=scalar_solve):
    f = lambda y: y ** 2 - x ** 3
    return lax.root(f, 0.0, binary_search, tangent_solve)

  value, grad = api.value_and_grad(sqrt_cubed)(5.0)
  self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
  self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)

  jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)

  # TODO(shoyer): reenable when batching works
  # inputs = np.array([4.0, 5.0])
  # results = api.vmap(sqrt_cubed)(inputs)
  # self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)

  results = api.jit(sqrt_cubed)(5.0)
  self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False)
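# A minimal standalone sketch (not part of the test above; names `g` and `u`
# are illustrative only) of why the `scalar_solve` helper works: for a linear
# scalar map g(u) = a * u, the solution of g(u) = y is u = y / a, and the
# slope a can be recovered as g(1.0).
def _scalar_solve_sketch():
  scalar_solve = lambda g, y: y / g(1.0)
  g = lambda u: 3.0 * u                 # linear map with slope a = 3
  assert scalar_solve(g, 6.0) == 2.0    # solves g(u) = 6.0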
def testCategorical(self, p, axis, dtype, sample_shape):
  key = random.PRNGKey(0)
  p = onp.array(p, dtype=dtype)
  logits = onp.log(p) - 42  # test unnormalized
  shape = sample_shape + tuple(onp.delete(logits.shape, axis))
  rand = lambda key, p: random.categorical(key, logits, shape=shape, axis=axis)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, p)
  compiled_samples = crand(key, p)

  for samples in [uncompiled_samples, compiled_samples]:
    if axis < 0:
      axis += len(logits.shape)

    assert samples.shape == shape

    if len(p.shape[:-1]) > 0:
      for cat_index, p_ in enumerate(p):
        self._CheckChiSquared(samples[:, cat_index], pmf=lambda x: p_[x])
    else:
      self._CheckChiSquared(samples, pmf=lambda x: p[x])
def testWhileWithTuple(self):
  limit = 10

  def loop_cond(state):
    pos, _ = state
    return lax.lt(pos, limit)

  def loop_body(state):
    pos, count = state
    return (lax.add(pos, 1), lax.add(count, 1))

  def loop(init):
    result = lax.while_loop(loop_cond, loop_body, (init, 0))
    _, count = result
    return count

  cloop = api.jit(loop)

  self.assertEqual(loop(2), limit - 2)
  self.assertEqual(cloop(2), limit - 2)
  self.assertEqual(cloop(2), limit - 2)
  self.assertEqual(cloop(3), limit - 3)
def testRadamacher(self):
  rng = random.PRNGKey(0)
  num_samples = 10**5

  rand = lambda x: random.rademacher(x, (num_samples,))
  crand = api.jit(rand)

  uncompiled_samples = rand(rng)
  compiled_samples = crand(rng)

  for samples in [uncompiled_samples, compiled_samples]:
    unique_values, counts = np.unique(samples, return_counts=True)
    assert len(unique_values) == 2
    assert len(counts) == 2

    self.assertAllClose(counts[0] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
    self.assertAllClose(counts[1] / num_samples, 0.5, rtol=1e-02, atol=1e-02)
def testMultivariateNormal(self, mean, cov, dtype):
  key = random.PRNGKey(0)
  rand = lambda key, mean, cov: random.multivariate_normal(
      key, mean, cov, (1000,), dtype)
  crand = api.jit(rand)

  if ((hasattr(cov, "shape") and cov.ndim > 2) or
      (hasattr(mean, "shape") and mean.ndim > 1)):
    self.assertRaises(ValueError, lambda: rand(key, mean, cov))
    self.assertRaises(ValueError, lambda: crand(key, mean, cov))
    return

  uncompiled_samples = rand(key, mean, cov)
  compiled_samples = crand(key, mean, cov)

  if hasattr(cov, "shape") and cov.ndim == 2:
    inv_scale = scipy.linalg.lapack.dtrtri(onp.linalg.cholesky(cov),
                                           lower=True)[0]
    rescale = lambda x: onp.tensordot(x, inv_scale, axes=(-1, 1))
  else:
    rescale = lambda x: x / np.sqrt(cov)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckKolmogorovSmirnovCDF(
        rescale(samples - mean).reshape(-1), scipy.stats.norm().cdf)
def testPoisson(self, lam, dtype):
  if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
    raise SkipTest("random.poisson() not supported on TPU for 16-bit types.")
  key = random.PRNGKey(0)
  rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, lam)
  compiled_samples = crand(key, lam)

  for samples in [uncompiled_samples, compiled_samples]:
    self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
    # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
    # based on the central limit theorem).
    self.assertAllClose(samples.mean(), lam, rtol=0.01, check_dtypes=False)
    self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
def testLoopWithConjunctionCondition(self):
  def sum_first_n(arr, num):  # pylint: disable=missing-docstring
    def cond_fun(state):
      arr, num, i, _ = state
      return lax.bitwise_and(lax.lt(i, num), lax.lt(i, arr.shape[0]))

    def body_fun(state):
      arr, num, i, total = state
      arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
      return (arr, num, lax.add(i, 1), lax.add(total, arr_i))

    init_val = (arr, num, 0, 0.)
    _, _, _, total = lax.while_loop(cond_fun, body_fun, init_val)
    return total

  cfun = api.jit(sum_first_n)
  x = npr.RandomState(0).randn(10)

  for num in [0, 5, 10, 15]:
    self.assertAllClose(sum_first_n(x, num), onp.sum(x[:num]),
                        check_dtypes=False)
    self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
    self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
def DISABLED_testOnesBroadcastingConstantHandler(self):
  # TODO(mattjj): update this test for jax3
  def fun(x):
    ones = lnp.ones((3, 4))
    assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

    # To check that the constant handler generates a Broadcast for stride-zero
    # arrays, we monkey-patch the client instance.
    # TODO(mattjj): once we have better HLO dumping and inspecting facilities,
    # we can check the HLO more directly.
    c = x._node.c
    Broadcast = c.Broadcast  # pylint: disable=invalid-name
    was_called = []
    c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args)
    out = x + ones  # the ndarray constant handler should call Broadcast here
    assert was_called, "Broadcast was not called."

    return out

  fun = api.jit(fun)
  out_val = fun(lnp.ones(4))
  self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)
def test_custom_linear_solve(self, symmetric):

  def explicit_jacobian_solve(matvec, b):
    return lax.stop_gradient(np.linalg.solve(api.jacobian(matvec)(b), b))

  def matrix_free_solve(matvec, b):
    return lax.custom_linear_solve(
        matvec, b, explicit_jacobian_solve, explicit_jacobian_solve,
        symmetric=symmetric)

  def linear_solve(a, b):
    return matrix_free_solve(partial(high_precision_dot, a), b)

  rng = onp.random.RandomState(0)
  a = rng.randn(3, 3)
  if symmetric:
    a = a + a.T
  b = rng.randn(3)
  jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3)

  expected = np.linalg.solve(a, b)
  actual = api.jit(linear_solve)(a, b)
  self.assertAllClose(expected, actual, check_dtypes=True)
def testMultivariateNormal(self, dim, dtype, method):
  r = np.random.RandomState(dim)
  mean = r.randn(dim)
  cov_factor = r.randn(dim, dim)
  cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)

  key = random.PRNGKey(0)
  rand = partial(random.multivariate_normal, mean=mean, cov=cov,
                 shape=(10000,), method=method)
  crand = api.jit(rand)

  uncompiled_samples = np.asarray(rand(key), np.float64)
  compiled_samples = np.asarray(crand(key), np.float64)

  inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov),
                                         lower=True)[0]
  for samples in [uncompiled_samples, compiled_samples]:
    centered = samples - mean
    whitened = np.einsum('nj,ij->ni', centered, inv_scale)

    # This is a quick-and-dirty multivariate normality check that tests that a
    # uniform mixture of the marginals along the covariance matrix's
    # eigenvectors follow a standard normal distribution.
    self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf)
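# A minimal sketch (an added illustration, not from the test suite) of the
# whitening identity the check above relies on: if cov = L @ L.T with L
# lower-triangular and inv_scale = L^{-1}, then inv_scale @ cov @ inv_scale.T
# is the identity, so `centered @ inv_scale.T` has (approximately) identity
# covariance and its entries can be compared against a standard normal.
def _whitening_identity_sketch(dim=3):
  import numpy as np
  import scipy.linalg

  r = np.random.RandomState(0)
  a = r.randn(dim, dim)
  cov = a @ a.T + dim * np.eye(dim)        # positive-definite covariance
  chol = np.linalg.cholesky(cov)           # lower-triangular factor L
  inv_scale = scipy.linalg.lapack.dtrtri(chol, lower=True)[0]  # L^{-1}
  np.testing.assert_allclose(inv_scale @ cov @ inv_scale.T, np.eye(dim),
                             atol=1e-8)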
def testDoublesidedMaxwellSample(self, loc, scale):
  num_samples = 10**5
  rng = random.PRNGKey(0)

  rand = lambda key: random.double_sided_maxwell(
      rng, loc, scale, (num_samples,))
  crand = api.jit(rand)

  mean = loc
  std = np.sqrt(3.) * scale

  uncompiled_samples = rand(rng)
  compiled_samples = crand(rng)

  # Compute the double sided maxwell CDF through the one sided maxwell cdf.
  # This is done as follows:
  # P(DSM <= x) = P(loc + scale * rademacher_sample * one_sided_sample <= x) =
  # P(rademacher_sample * one_sided_sample <= (x - loc) / scale) =
  # 1/2 P(one_sided_sample <= (x - loc) / scale)
  #   + 1/2 P(- one_sided_sample <= (x - loc) / scale) =
  # 1/2 P(one_sided_sample <= (x - loc) / scale)
  #   + 1/2 P(one_sided_sample >= - (x - loc) / scale) =
  # 1/2 CDF_one_maxwell((x - loc) / scale)
  #   + 1/2 (1 - CDF_one_maxwell(- (x - loc) / scale))
  def double_sided_maxwell_cdf(x, loc, scale):
    pos = scipy.stats.maxwell().cdf((x - loc) / scale)
    neg = (1 - scipy.stats.maxwell().cdf((-x + loc) / scale))
    return (pos + neg) / 2

  for samples in [uncompiled_samples, compiled_samples]:
    # Check first and second moments.
    self.assertEqual((num_samples,), samples.shape)
    self.assertAllClose(np.mean(samples), mean, atol=0., rtol=0.1)
    self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)

    self._CheckKolmogorovSmirnovCDF(
        samples, lambda x: double_sided_maxwell_cdf(x, loc, scale))
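# A small sanity-check sketch (an added illustration, not part of the original
# test) for the CDF derivation in the comments above: at x = loc the
# double-sided CDF should be exactly 1/2, since the one-sided Maxwell CDF is 0
# at 0, giving (0 + (1 - 0)) / 2 = 0.5.
def _double_sided_maxwell_cdf_sketch(loc=1.0, scale=2.0):
  import scipy.stats

  def double_sided_maxwell_cdf(x, loc, scale):
    pos = scipy.stats.maxwell().cdf((x - loc) / scale)
    neg = 1 - scipy.stats.maxwell().cdf((-x + loc) / scale)
    return (pos + neg) / 2

  assert abs(double_sided_maxwell_cdf(loc, loc, scale) - 0.5) < 1e-12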
def testCategorical(self, p, axis, dtype, sample_shape):
  key = random.PRNGKey(0)
  p = np.array(p, dtype=dtype)
  logits = np.log(p) - 42  # test unnormalized
  out_shape = tuple(np.delete(logits.shape, axis))
  shape = sample_shape + out_shape
  rand = partial(random.categorical, shape=shape, axis=axis)
  crand = api.jit(rand)

  uncompiled_samples = rand(key, logits)
  compiled_samples = crand(key, logits)

  if axis < 0:
    axis += len(logits.shape)

  for samples in [uncompiled_samples, compiled_samples]:
    assert samples.shape == shape
    samples = jnp.reshape(samples, (10000,) + out_shape)
    if len(p.shape[:-1]) > 0:
      ps = np.transpose(p, (1, 0)) if axis == 0 else p
      for cat_samples, cat_p in zip(samples.transpose(), ps):
        self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
    else:
      self._CheckChiSquared(samples, pmf=lambda x: p[x])
def test_custom_linear_solve_cholesky(self):

  def positive_definive_solve(a, b):
    factors = jsp.linalg.cho_factor(a)

    def solve(matvec, x):
      return jsp.linalg.cho_solve(factors, x)

    return lax.custom_linear_solve(
        partial(np.dot, a), b, solve, symmetric=True)

  rng = onp.random.RandomState(0)
  a = rng.randn(2, 2)
  b = rng.randn(2)

  expected = np.linalg.solve(np.dot(a, a.T), b)
  actual = positive_definive_solve(np.dot(a, a.T), b)
  self.assertAllClose(expected, actual, check_dtypes=True)

  actual = api.jit(positive_definive_solve)(np.dot(a, a.T), b)
  self.assertAllClose(expected, actual, check_dtypes=True)

  # numerical gradients are only well defined if ``a`` is guaranteed to be
  # positive definite.
  jtu.check_grads(
      lambda x, y: positive_definive_solve(np.dot(x, x.T), y),
      (a, b), order=2)
def test_custom_linear_solve_lu(self):

  def linear_solve(a, b):
    a_factors = jsp.linalg.lu_factor(a)
    at_factors = jsp.linalg.lu_factor(a.T)

    def solve(matvec, x):
      return jsp.linalg.lu_solve(a_factors, x)

    def transpose_solve(vecmat, x):
      return jsp.linalg.lu_solve(at_factors, x)

    return lax.custom_linear_solve(
        partial(np.dot, a), b, solve, transpose_solve)

  rng = onp.random.RandomState(0)
  a = rng.randn(3, 3)
  b = rng.randn(3)

  expected = np.linalg.solve(a, b)
  actual = linear_solve(a, b)
  self.assertAllClose(expected, actual, check_dtypes=True)

  jtu.check_grads(linear_solve, (a, b), order=2)

  # regression test for https://github.com/google/jax/issues/1536
  jtu.check_grads(api.jit(linear_solve), (a, b), order=2)
def mc_sampling(count=10):
  # Monte Carlo estimate of the mean test-set prediction under
  # gradient-descent-MSE inference, averaged over `count` random network
  # initializations.
  empirical_mean = 0.
  key = random.PRNGKey(100)
  init_fn, f, _ = _build_network(train_shape[1:], network, out_logits)
  _kernel_fn = empirical.empirical_kernel_fn(f)
  kernel_fn = jit(lambda x1, x2, params: _kernel_fn(x1, x2, params, 'ntk'))

  for _ in range(count):
    key, split = random.split(key)
    _, params = init_fn(split, train_shape)

    g_dd = kernel_fn(x_train, None, params)
    g_td = kernel_fn(x_test, x_train, params)
    predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)

    fx_initial_train = f(params, x_train)
    fx_initial_test = f(params, x_test)

    _, fx_pred_test = predictor(1.0e8, fx_initial_train, fx_initial_test)
    empirical_mean += fx_pred_test
  return empirical_mean / count
def f_pmapped(x, *args, **kwargs):
  args_np, args_np_idxs = [], []
  args_other = {}

  # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
  # https://github.com/google/jax/issues/912
  # Filter out `np.ndarray`s from other arguments.
  for i, arg in enumerate(args):
    if _is_np_ndarray(arg):
      args_np.append(arg)
      args_np_idxs.append(i)
    else:
      args_other[i] = arg

  # Check cache before jitting.
  _key = key + tuple(args_other.items()) + tuple(kwargs.items())
  if _key in cache:
    _f = cache[_key]
  else:
    # Define a `np.ndarray`-only function as a closure over other arguments.
    def _f(_x, *_args_np):
      # Merge args.
      _args_np = {i: _arg_np for i, _arg_np in zip(args_np_idxs, _args_np)}
      _args = _merge_dicts(_args_np, args_other)
      _args = tuple(v for k, v in sorted(_args.items()))
      return f(_x, *_args, **kwargs)

    _f = jit(_f) if device_count == 0 else pmap(_f)
    cache[_key] = _f

  # Broadcast `np.ndarray` arguments and apply the new function to them.
  args_np = tree_map(broadcast, args_np)
  return _f(x, *args_np)
def loop_body(state):
  # `inc` and `effect` are free variables captured from the enclosing test.
  effect[0] = True
  pos, count = state
  f = lambda pos, inc: (lax.add(pos, 1), lax.add(count, inc))
  return api.jit(f)(pos, inc)
def testPermutationErrors(self):
  key = random.PRNGKey(0)
  with self.assertRaises(TypeError):
    random.permutation(key, 10.)
  with self.assertRaises(core.ConcretizationTypeError):
    api.jit(random.permutation)(key, 10)
def test_jit_error_no_consumer(self):
  # Check for errors if starting jit without a consumer active
  with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"):
    api.jit(lambda x: hcb.id_print(x))(0)
def test_jit_several_together(self):
  arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))(
        arg, jnp.ones(100, dtype=jnp.int32))
def test_jit_large(self):
  arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1))
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    api.jit(hcb.id_print)(arg)