class LaxBackedScipyTests(jtu.JaxTestCase):

  def _fetch_preconditioner(self, preconditioner, A, rng=None):
    """Returns a preconditioning matrix selected by ``preconditioner``.

    Args:
      preconditioner: ``'identity'``, ``'random'`` or ``'exact'``; any other
        value (including ``None``) means "no preconditioner".
      A: the square system matrix whose inverse the preconditioner is meant
        to approximate.
      rng: optional random generator for the ``'random'`` option; a fresh
        ``jtu.rand_default`` generator is created when omitted.

    Returns:
      A NumPy matrix, or ``None`` when no preconditioner is requested.
    """
    if preconditioner == 'identity':
      M = np.eye(A.shape[0], dtype=A.dtype)
    elif preconditioner == 'random':
      if rng is None:
        rng = jtu.rand_default(self.rng())
      M = np.linalg.inv(rand_sym_pos_def(rng, A.shape, A.dtype))
    elif preconditioner == 'exact':
      M = np.linalg.inv(A)
    else:
      M = None
    return M

  @staticmethod
  def _solution_tol(x):
    """Returns the tolerance for comparing solver output ``x`` to a solution.

    ``x`` is the array returned by a JAX solver, so its dtype reflects the
    precision the solve actually ran in: a tight tolerance for 64-bit
    results, a loose one otherwise.
    """
    # Compare the dtype itself; ``dtype.kind`` is a one-character code such
    # as 'f' or 'c' and never equals a dtype object.
    using_x64 = x.dtype in (np.float64, np.complex128)
    return 1e-8 if using_x64 else 1e-4

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_preconditioner={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(4, 4), (7, 7)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']))
  def test_cg_against_scipy(self, shape, dtype, preconditioner):
    if not config.x64_enabled:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rand_sym_pos_def(rng, shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=1),
        partial(lax_cg, M=M, maxiter=1),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=3),
        partial(lax_cg, M=M, maxiter=3),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_cg, M=M, atol=1e-10),
        args_maker,
        tol=1e-6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(2, 2)]
      for dtype in float_types + complex_types))
  def test_cg_as_solve(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    b = rng(shape[:1], dtype)

    expected = np.linalg.solve(posify(a), b)
    actual = lax_cg(posify(a), b)
    self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

    actual = jit(lax_cg)(posify(a), b)
    self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

    # numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(
        lambda x, y: lax_cg(posify(x), y),
        (a, b), order=2, rtol=2e-1)

  def test_cg_ndarray(self):
    A = lambda x: 2 * x
    b = jnp.arange(9.0).reshape((3, 3))
    expected = b / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual)

  def test_cg_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=6)
    self.assertAlmostEqual(expected["b"], actual["b"], places=6)

  def test_cg_errors(self):
    A = lambda x: x
    b = jnp.zeros((2,))
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching tree structure"):
      jax.scipy.sparse.linalg.cg(A, {'x': b}, {'y': b})
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching shape"):
      jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis])
    with self.assertRaisesRegex(ValueError, "must be a square matrix"):
      jax.scipy.sparse.linalg.cg(jnp.zeros((3, 2)), jnp.zeros((2,)))
    with self.assertRaisesRegex(
        TypeError, "linear operator must be either a function or ndarray"):
      jax.scipy.sparse.linalg.cg([[1]], jnp.zeros((1,)))

  def test_cg_without_pytree_equality(self):

    @register_pytree_node_class
    class MinimalPytree:
      def __init__(self, value):
        self.value = value

      def tree_flatten(self):
        return [self.value], None

      @classmethod
      def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    A = lambda x: MinimalPytree(2 * x.value)
    b = MinimalPytree(jnp.arange(5.0))
    expected = b.value / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual.value)

  def test_cg_weak_types(self):
    # NOTE(review): despite its name this test calls ``bicgstab``, not ``cg``
    # (``test_bicgstab_weak_types`` in this class already covers bicgstab).
    # Whether ``cg`` preserves weak types is not visible from this file --
    # TODO confirm and retarget this test at ``cg``.
    x, _ = jax.scipy.sparse.linalg.bicgstab(lambda x: x, 1.0)
    self.assertTrue(dtypes.is_weakly_typed(x))

  # BICGSTAB

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_preconditioner={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(5, 5)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']
  ))
  def test_bicgstab_against_scipy(
      self, shape, dtype, preconditioner):
    # Same x64 check as the cg/gmres tests in this class.
    if not config.x64_enabled:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_bicgstab, M=M, maxiter=1),
        partial(lax_bicgstab, M=M, maxiter=1),
        args_maker,
        tol=1e-5)

    self._CheckAgainstNumpy(
        partial(scipy_bicgstab, M=M, maxiter=2),
        partial(lax_bicgstab, M=M, maxiter=2),
        args_maker,
        tol=1e-4)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_bicgstab, M=M, atol=1e-6),
        args_maker,
        tol=1e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_preconditioner={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(2, 2), (7, 7)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
  ))
  @jtu.skip_on_devices("gpu")
  def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner):
    A = jnp.eye(shape[1], dtype=dtype)
    solution = jnp.ones(shape[1], dtype=dtype)
    rng = jtu.rand_default(self.rng())
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    tol = shape[0] * jnp.finfo(dtype).eps
    x, _ = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M)
    solution_tol = self._solution_tol(x)
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_preconditioner={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(2, 2), (4, 4)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
  ))
  @jtu.skip_on_devices("gpu")
  def test_bicgstab_on_random_system(self, shape, dtype, preconditioner):
    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    solution = rng(shape[1:], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    tol = shape[0] * jnp.finfo(A.dtype).eps
    x, _ = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M)
    solution_tol = self._solution_tol(x)
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
    # Gradient check currently disabled:
    # solve = lambda A, b: jax.scipy.sparse.linalg.bicgstab(A, b)[0]
    # jtu.check_grads(solve, (A, b), order=1, rtol=3e-1)

  def test_bicgstab_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=5)
    self.assertAlmostEqual(expected["b"], actual["b"], places=5)

  def test_bicgstab_weak_types(self):
    x, _ = jax.scipy.sparse.linalg.bicgstab(lambda x: x, 1.0)
    self.assertTrue(dtypes.is_weakly_typed(x))

  # GMRES

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_preconditioner={}_solve_method={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          preconditioner,
          solve_method),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
       "solve_method": solve_method}
      for shape in [(3, 3)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']
      for solve_method in ['incremental', 'batched']))
  def test_gmres_against_scipy(
      self, shape, dtype, preconditioner, solve_method):
    if not config.x64_enabled:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=1),
        partial(lax_gmres, M=M, restart=1, maxiter=1,
                solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=2),
        partial(lax_gmres, M=M, restart=1, maxiter=2,
                solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=2, maxiter=1),
        partial(lax_gmres, M=M, restart=2, maxiter=1,
                solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method),
        args_maker,
        tol=1e-10)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_preconditioner={}_solve_method={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          preconditioner,
          solve_method),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
       "solve_method": solve_method}
      for shape in [(2, 2), (7, 7)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      for solve_method in ['batched', 'incremental']
  ))
  @jtu.skip_on_devices("gpu")
  def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
                                    solve_method):
    A = jnp.eye(shape[1], dtype=dtype)
    solution = jnp.ones(shape[1], dtype=dtype)
    rng = jtu.rand_default(self.rng())
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    restart = shape[-1]
    tol = shape[0] * jnp.finfo(dtype).eps
    x, _ = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
                                         restart=restart,
                                         M=M, solve_method=solve_method)
    solution_tol = self._solution_tol(x)
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_preconditioner={}_solve_method={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          preconditioner,
          solve_method),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
       "solve_method": solve_method}
      for shape in [(2, 2), (4, 4)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      for solve_method in ['incremental', 'batched']
  ))
  @jtu.skip_on_devices("gpu")
  def test_gmres_on_random_system(self, shape, dtype, preconditioner,
                                  solve_method):
    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    solution = rng(shape[1:], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    b = matmul_high_precision(A, solution)
    restart = shape[-1]
    tol = shape[0] * jnp.finfo(A.dtype).eps
    x, _ = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol,
                                         restart=restart,
                                         M=M, solve_method=solve_method)
    solution_tol = self._solution_tol(x)
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
    # Gradient check currently disabled:
    # solve = lambda A, b: jax.scipy.sparse.linalg.gmres(A, b)[0]
    # jtu.check_grads(solve, (A, b), order=1, rtol=2e-1)

  def test_gmres_pytree(self):
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=5)
    self.assertAlmostEqual(expected["b"], actual["b"], places=5)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_preconditioner={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(2, 2), (3, 3)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity']))
  def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
    """The Arnoldi decomposition within GMRES is correct."""
    if not config.x64_enabled:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)
    if preconditioner is None:
      M = lambda x: x
    else:
      M = partial(matmul_high_precision, M)
    n = shape[0]
    x0 = rng(shape[:1], dtype)
    Q = np.zeros((n, n + 1), dtype=dtype)
    Q[:, 0] = x0 / jnp.linalg.norm(x0)
    Q = jnp.array(Q)
    H = jnp.eye(n, n + 1, dtype=dtype)

    @jax.tree_util.Partial
    def A_mv(x):
      return matmul_high_precision(A, x)

    for k in range(n):
      Q, H, _ = jax._src.scipy.sparse.linalg._kth_arnoldi_iteration(
          k, A_mv, M, Q, H)
    # Q^H A Q must reproduce the Hessenberg matrix built by the iteration.
    QA = matmul_high_precision(Q[:, :n].conj().T, A)
    QAQ = matmul_high_precision(QA, Q[:, :n])
    self.assertAllClose(QAQ, H.T[:n, :], rtol=1e-5, atol=1e-5)

  def test_gmres_weak_types(self):
    x, _ = jax.scipy.sparse.linalg.gmres(lambda x: x, 1.0)
    self.assertTrue(dtypes.is_weakly_typed(x))
class CustomObjectTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_compile={}_primitive={}".format(compile, primitive),
       "compile": compile, "primitive": primitive}
      for primitive in [True, False]
      for compile in [True, False]))
  def testSparseIdentity(self, compile, primitive):
    """An identity map (primitive or lambda, optionally jitted) keeps M intact."""
    f = identity if primitive else (lambda x: x)
    f = jit(f) if compile else f
    rng = jtu.rand_default(self.rng())
    M = make_sparse_array(rng, (10,), jnp.float32)
    M2 = f(M)

    jaxpr = make_jaxpr(f)(M).jaxpr
    core.check_jaxpr(jaxpr)

    self.assertEqual(M.dtype, M2.dtype)
    self.assertEqual(M.index_dtype, M2.index_dtype)
    self.assertAllClose(M.data, M2.data)
    self.assertAllClose(M.indices, M2.indices)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_compile={}".format(compile),
       "compile": compile}
      for compile in [True, False]))
  def testSparseSplit(self, compile):
    """Both outputs of ``split`` equal the input sparse array."""
    f = jit(split) if compile else split
    rng = jtu.rand_default(self.rng())
    M = make_sparse_array(rng, (10,), jnp.float32)
    M2, M3 = f(M)

    jaxpr = make_jaxpr(f)(M).jaxpr
    core.check_jaxpr(jaxpr)

    for MM in M2, M3:
      self.assertEqual(M.dtype, MM.dtype)
      self.assertEqual(M.index_dtype, MM.index_dtype)
      self.assertArraysEqual(M.data, MM.data)
      self.assertArraysEqual(M.indices, MM.indices)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_compile={}_primitive={}".format(compile, primitive),
       "compile": compile, "primitive": primitive}
      for primitive in [True, False]
      for compile in [True, False]))
  def testSparseLaxLoop(self, compile, primitive):
    """A sparse array can be the carry of ``lax.fori_loop``."""
    rng = jtu.rand_default(self.rng())
    f = identity if primitive else (lambda x: x)
    f = jit(f) if compile else f
    body_fun = lambda _, A: f(A)
    M = make_sparse_array(rng, (10,), jnp.float32)
    M2 = lax.fori_loop(0, 10, body_fun, M)
    # body_fun is an identity map, so the loop must hand back M unchanged;
    # previously the result was discarded and nothing was checked.
    self.assertAllClose(M.data, M2.data)
    self.assertAllClose(M.indices, M2.indices)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_attr={}".format(attr), "attr": attr}
      for attr in ["data", "indices"]))
  def testSparseAttrAccess(self, attr):
    """Attribute access on a sparse array works eagerly and under jit."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [make_sparse_array(rng, (10,), jnp.float32)]
    f = lambda x: getattr(x, attr)
    self._CompileAndCheck(f, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(3, 3), (2, 6), (6, 2)]
      for dtype in jtu.dtypes.floating))
  def testSparseMatvec(self, shape, dtype):
    """``matvec`` of a sparse matrix and a dense vector compiles consistently."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [make_sparse_array(rng, shape, dtype),
                          rng(shape[-1:], dtype)]
    self._CompileAndCheck(matvec, args_maker)

  def testLowerToNothing(self):
    """An ``Empty`` value can be traced and passed through jit."""
    empty = Empty(AbstractEmpty())
    jaxpr = make_jaxpr(jit(lambda e: e))(empty).jaxpr
    core.check_jaxpr(jaxpr)

    # cannot return a unit, because CompileAndCheck assumes array output.
    testfunc = lambda e: None
    args_maker = lambda: [empty]
    self._CompileAndCheck(testfunc, args_maker)

  def testConstantHandler(self):
    """A sparse array built inside jit matches the eagerly built one."""
    def make_const_array():
      data = np.arange(3.0)
      indices = np.arange(3)[:, None]
      shape = (5,)
      aval = AbstractSparseArray(shape, data.dtype, indices.dtype,
                                 len(indices))
      return SparseArray(aval, data, indices)

    out1 = make_const_array()
    out2 = jit(make_const_array)()
    self.assertArraysEqual(out1.data, out2.data)
    self.assertArraysEqual(out1.indices, out2.indices)
class TestPolynomial(jtu.JaxTestCase):

  def assertSetsAllClose(self, x, y, rtol=None, atol=None, check_dtypes=True):
    """Assert that x and y contain permutations of the same approximate set of values.

    For non-complex inputs, this is accomplished by comparing the sorted inputs.
    For complex, such an approach can be confounded by numerical errors. In this
    case, we compute the structural rank of the pairwise comparison matrix
    between x and y: if the structural rank is full, the matrix can be permuted
    so that the diagonal is non-zero, which implies a one-to-one approximate
    match between the permuted sets.
    """
    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()

    atol = max(jtu.tolerance(x.dtype, atol), jtu.tolerance(y.dtype, atol))
    rtol = max(jtu.tolerance(x.dtype, rtol), jtu.tolerance(y.dtype, rtol))

    if not (np.issubdtype(x.dtype, np.complexfloating) or
            np.issubdtype(y.dtype, np.complexfloating)):
      return self.assertAllClose(np.sort(x), np.sort(y), atol=atol, rtol=rtol,
                                 check_dtypes=check_dtypes)

    if check_dtypes:
      self.assertEqual(x.dtype, y.dtype)
    self.assertEqual(x.size, y.size)
    if x.size == 0:
      return  # two empty sets trivially match

    # Entry (i, j) is True when x[i] is approximately y[j]. The matrix must
    # compare x against y: comparing x against itself is trivially full rank
    # and would make this assertion vacuous.
    pairwise = np.isclose(x[:, None], y[None, :],
                          atol=atol, rtol=rtol, equal_nan=True)
    rank = csgraph.structural_rank(csr_matrix(pairwise))
    self.assertEqual(rank, x.size)

  @staticmethod
  def _padded_poly_args_maker(rng, dtype, length, leading, trailing):
    """Returns an args_maker producing one zero-padded random polynomial.

    The polynomial has ``length`` random coefficients drawn from ``rng``,
    preceded by ``leading`` zeros and followed by ``trailing`` zeros.
    """
    def args_maker():
      p = rng((length,), dtype)
      return [jnp.concatenate(
          [jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)])]
    return args_maker

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_dtype={}_leading={}_trailing={}".format(
              jtu.format_shape_dtype_string((length + leading + trailing,),
                                            dtype),
              leading, trailing),
          "dtype": dtype, "length": length,
          "leading": leading, "trailing": trailing,
      }
      for dtype in all_dtypes
      for length in [0, 3, 5]
      for leading in [0, 2]
      for trailing in [0, 2]))
  # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testRoots(self, dtype, length, leading, trailing):
    rng = jtu.rand_some_zero(self.rng())
    args_maker = self._padded_poly_args_maker(rng, dtype, length, leading,
                                              trailing)

    jnp_fun = jnp.roots

    def np_fun(arg):
      return np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))

    # Note: outputs have no defined order, so we need to use a special comparator.
    args = args_maker()
    np_roots = np_fun(*args)
    jnp_roots = jnp_fun(*args)
    self.assertSetsAllClose(np_roots, jnp_roots)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_dtype={}_leading={}_trailing={}".format(
              jtu.format_shape_dtype_string((length + leading + trailing,),
                                            dtype),
              leading, trailing),
          "dtype": dtype, "length": length,
          "leading": leading, "trailing": trailing,
      }
      for dtype in all_dtypes
      for length in [0, 3, 5]
      for leading in [0, 2]
      for trailing in [0, 2]))
  # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testRootsNoStrip(self, dtype, length, leading, trailing):
    rng = jtu.rand_some_zero(self.rng())
    args_maker = self._padded_poly_args_maker(rng, dtype, length, leading,
                                              trailing)

    jnp_fun = partial(jnp.roots, strip_zeros=False)

    def np_fun(arg):
      roots = np.roots(arg).astype(dtypes._to_complex_dtype(arg.dtype))
      # np.roots drops roots for stripped leading zeros; pad with NaN so the
      # count matches jnp.roots(strip_zeros=False).
      if len(roots) < len(arg) - 1:
        roots = np.pad(roots, (0, len(arg) - len(roots) - 1),
                       constant_values=complex(np.nan, np.nan))
      return roots

    # Note: outputs have no defined order, so we need to use a special comparator.
    args = args_maker()
    np_roots = np_fun(*args)
    jnp_roots = jnp_fun(*args)
    self.assertSetsAllClose(np_roots, jnp_roots)
    self._CompileAndCheck(jnp_fun, args_maker)
class LaxVmapTest(jtu.JaxTestCase): def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng, rtol=None, atol=None, multiple_results=False): batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes) args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)] args_slice = args_slicer(args, bdims) ans = jax.vmap(op, bdims)(*args) if bdim_size == 0: args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] out = op(*args) if not multiple_results: expected = np.zeros((0,) + out.shape, out.dtype) else: expected = [np.zeros((0,) + o.shape, o.dtype) for o in out] else: outs = [op(*args_slice(i)) for i in range(bdim_size)] if not multiple_results: expected = np.stack(outs) else: expected = [np.stack(xs) for xs in zip(*outs)] self.assertAllClose(ans, expected, rtol=rtol, atol=atol) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": "{}_bdims={}".format( jtu.format_test_name_suffix(rec.op, shapes, itertools.repeat(dtype)), bdims), "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, "bdims": bdims, "tol": rec.tol} for shape_group in compatible_shapes for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs) for bdims in all_bdims(*shapes) for dtype in rec.dtypes) for rec in LAX_OPS)) def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol): rng = rng_factory(self.rng()) op = getattr(lax, op_name) self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ "testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" "_lhs_bdim={}_rhs_bdim={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), feature_group_count, 
batch_group_count, lhs_bdim, rhs_bdim), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "lhs_dil": lhs_dil, "rhs_dil": rhs_dil, "dimension_numbers": dim_nums, "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim, "feature_group_count": feature_group_count, "batch_group_count": batch_group_count, } for batch_group_count, feature_group_count in s([(1, 1), (2, 1), (1, 2)]) for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in s([ ((b * batch_group_count, i * feature_group_count, 6, 7), # lhs_shape (j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape [(1, 1), (1, 2), (2, 1)], # strides [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))], # pads [(1, 1), (2, 1)], # lhs_dils [(1, 1), (2, 2)]) # rhs_dils for b, i, j in itertools.product([1, 2], repeat=3)]) for strides in s(all_strides) for rhs_dil in s(rhs_dils) for lhs_dil in s(lhs_dils) for dtype in s([np.float32]) for padding in s(all_pads) for dim_nums, perms in s([ (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]) for lhs_bdim in s(itertools.chain([cast(Optional[int], None)], range(len(lhs_shape) + 1))) for rhs_bdim in s(itertools.chain([cast(Optional[int], None)], range(len(rhs_shape) + 1))) if (lhs_bdim, rhs_bdim) != (None, None) ))) def testConvGeneralDilatedBatching( self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, dimension_numbers, perms, feature_group_count, batch_group_count, lhs_bdim, rhs_bdim): rng = jtu.rand_default(self.rng()) tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3 # permute shapes to match dim_spec, scale by feature_group_count lhs_perm, rhs_perm = perms lhs_shape = list(np.take(lhs_shape, lhs_perm)) rhs_shape = list(np.take(rhs_shape, rhs_perm)) conv = partial(lax.conv_general_dilated, window_strides=strides, padding=padding, lhs_dilation=lhs_dil, 
rhs_dilation=rhs_dil, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=lax.Precision.HIGHEST) self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape), (dtype, dtype), rng, rtol=tol, atol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( shape, from_dtype, to_dtype, bdims), "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, "bdims": bdims} for from_dtype, to_dtype in itertools.product( [np.float32, np.int32, "float32", "int32"], repeat=2) for shape in [(2, 3)] for bdims in all_bdims(shape))) def testConvertElementType(self, shape, from_dtype, to_dtype, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.convert_element_type(x, to_dtype) self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_nmant={}_nexp={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), nmant, nexp, bdims), "shape": shape, "dtype": dtype, "nmant": nmant, "nexp": nexp, "bdims": bdims} for dtype in float_dtypes for shape in [(2, 4)] for nexp in [1, 3, 5] for nmant in [0, 2, 4] for bdims in all_bdims(shape))) def testReducePrecision(self, shape, dtype, nmant, nexp, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.reduce_precision(x, exponent_bits=nexp, mantissa_bits=nmant) self._CheckBatching(op, 10, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( shape, from_dtype, to_dtype, bdims), "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, "bdims": bdims} for from_dtype, to_dtype in itertools.product( [np.float32, np.int32, "float32", "int32"], repeat=2) for shape in [(2, 3)] for bdims in all_bdims(shape))) def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims,): rng 
= jtu.rand_default(self.rng()) op = lambda x: lax.bitcast_convert_type(x, to_dtype) self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}_bdims={}" .format(jtu.format_shape_dtype_string(min_shape, dtype), jtu.format_shape_dtype_string(operand_shape, dtype), jtu.format_shape_dtype_string(max_shape, dtype), bdims), "min_shape": min_shape, "operand_shape": operand_shape, "max_shape": max_shape, "dtype": dtype, "bdims": bdims} for min_shape, operand_shape, max_shape in [ [(), (2, 3), ()], [(2, 3), (2, 3), ()], [(), (2, 3), (2, 3)], [(2, 3), (2, 3), (2, 3)], ] for dtype in default_dtypes for bdims in all_bdims(min_shape, operand_shape, max_shape))) def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims): rng = jtu.rand_default(self.rng()) shapes = [min_shape, operand_shape, max_shape] self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "bdims": bdims} for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes)) def testDot(self, lhs_shape, rhs_shape, dtype, bdims): rng = jtu.rand_default(self.rng()) op = partial(lax.dot, precision=lax.Precision.HIGHEST) self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng, rtol={np.float16: 5e-2, np.float64: 5e-14}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), lhs_contracting, rhs_contracting, 
bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting, "bdims": bdims} for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ [(5,), (5,), [0], [0]], [(5, 7), (5,), [0], [0]], [(7, 5), (5,), [1], [0]], [(3, 5), (2, 5), [1], [1]], [(5, 3), (5, 2), [0], [0]], [(5, 3, 2), (5, 2, 4), [0], [0]], [(5, 3, 2), (5, 2, 4), [0,2], [0,1]], [(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]], [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]], [(3, 2), (2, 4), [1], [0]], ] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes)) def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting, bdims): rng = jtu.rand_small(self.rng()) dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), dimension_numbers, bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "dimension_numbers": dimension_numbers, "bdims": bdims} for lhs_shape, rhs_shape, dimension_numbers in [ ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))), ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))), ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))), ] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes)) def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype, dimension_numbers, bdims): rng = jtu.rand_small(self.rng()) dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng) # Checks that batching didn't introduce any transposes or broadcasts. 
jaxpr = jax.make_jaxpr(dot)(np.zeros(lhs_shape, dtype), np.zeros(rhs_shape, dtype)) for eqn in jtu.iter_eqns(jaxpr.jaxpr): self.assertFalse(eqn.primitive in ["transpose", "broadcast"]) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format( shape, np.dtype(dtype).name, broadcast_sizes, bdims), "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes, "bdims": bdims} for shape in [(), (2, 3)] for dtype in default_dtypes for broadcast_sizes in [(), (2,), (1, 2)] for bdims in all_bdims(shape))) def testBroadcast(self, shape, dtype, broadcast_sizes, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.broadcast(x, broadcast_sizes) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format( jtu.format_shape_dtype_string(inshape, dtype), outshape, broadcast_dimensions, bdims), "inshape": inshape, "dtype": dtype, "outshape": outshape, "dimensions": broadcast_dimensions, "bdims": bdims} for inshape, outshape, broadcast_dimensions in [ ([2], [2, 2], [0]), ([2], [2, 2], [1]), ([2], [2, 3], [0]), ([], [2, 3], []), ] for dtype in default_dtypes for bdims in all_bdims(inshape))) @unittest.skip("this test has failures in some cases") # TODO(mattjj) def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) self._CheckBatching(op, 5, bdims, (inshape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_dimensions={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, np.float32), dimensions, bdims), "arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims} for arg_shape, dimensions in [ [(1,), (0,)], [(1,), (-1,)], [(2, 1, 4), (1,)], [(2, 1, 4), (-2,)], [(2, 1, 3, 1), (1,)], [(2, 1, 3, 1), (1, 3)], 
[(2, 1, 3, 1), (3,)], [(2, 1, 3, 1), (1, -1)], ] for bdims in all_bdims(arg_shape))) def testSqueeze(self, arg_shape, dimensions, bdims): dtype = np.float32 rng = jtu.rand_default(self.rng()) op = lambda x: lax.squeeze(x, dimensions) self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), jtu.format_shape_dtype_string(out_shape, dtype), dimensions, bdims), "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, "dimensions": dimensions, "bdims": bdims} for dtype in default_dtypes for arg_shape, dimensions, out_shape in [ [(3, 4), None, (12,)], [(2, 1, 4), None, (8,)], [(2, 2, 4), None, (2, 8)], [(2, 2, 4), (0, 1, 2), (2, 8)], [(2, 2, 4), (1, 0, 2), (8, 2)], [(2, 2, 4), (2, 1, 0), (4, 2, 2)] ] for bdims in all_bdims(arg_shape))) def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions) self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_pads={}_bdims={}" .format(jtu.format_shape_dtype_string(shape, dtype), pads, bdims), "shape": shape, "dtype": dtype, "pads": pads, "bdims": bdims} for shape in [(2, 3)] for bdims in all_bdims(shape, ()) for dtype in default_dtypes for pads in [[(1, 2, 1), (0, 1, 0)]])) def testPad(self, shape, dtype, pads, bdims): rng = jtu.rand_small(self.rng()) fun = lambda operand, padding: lax.pad(operand, padding, pads) self._CheckBatching(fun, 5, bdims, (shape, ()), (dtype, dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_predshape={}_argshapes={}_bdims={}".format( jtu.format_shape_dtype_string(pred_shape, np.bool_), jtu.format_shape_dtype_string(arg_shape, arg_dtype), bdims), "pred_shape": pred_shape, "arg_shape": 
arg_shape, "arg_dtype": arg_dtype, "bdims": bdims} for arg_shape in [(), (3,), (2, 3)] for pred_shape in ([(), arg_shape] if arg_shape else [()]) for bdims in all_bdims(pred_shape, arg_shape, arg_shape) for arg_dtype in default_dtypes)) def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims): rng = jtu.rand_default(self.rng()) op = lambda c, x, y: lax.select(c < 0, x, y) self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,), (np.bool_, arg_dtype, arg_dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), start_indices, limit_indices, strides, bdims), "shape": shape, "dtype": dtype, "starts": start_indices, "limits": limit_indices, "strides": strides, "bdims": bdims} for shape, start_indices, limit_indices, strides in [ [(3,), (1,), (2,), None], [(7,), (4,), (7,), None], [(5,), (1,), (5,), (2,)], [(8,), (1,), (6,), (2,)], [(5, 3), (1, 1), (3, 2), None], [(5, 3), (1, 1), (3, 1), None], [(7, 5, 3), (4, 0, 1), (7, 1, 3), None], [(5, 3), (1, 1), (2, 1), (1, 1)], [(5, 3), (1, 1), (5, 3), (2, 1)], ] for bdims in all_bdims(shape) for dtype in default_dtypes)) def testSlice(self, shape, dtype, starts, limits, strides, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.slice(x, starts, limits, strides) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_perm={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), perm, bdims), "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims} for shape, perm in [ [(3, 4), (1, 0)], [(3, 4), (0, 1)], [(3, 4, 5), (2, 1, 0)], [(3, 4, 5), (1, 0, 2)], ] for bdims in all_bdims(shape) for dtype in default_dtypes)) def testTranspose(self, shape, dtype, perm, bdims): rng = jtu.rand_default(self.rng()) op = lambda x: lax.transpose(x, perm) self._CheckBatching(op, 5, 
bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, init_val, bdims), "op": op, "init_val": init_val, "shape": shape, "dtype": dtype, "dims": dims, "bdims": bdims} for init_val, op, dtypes in [ (0, lax.add, default_dtypes), (1, lax.mul, default_dtypes), (0, lax.max, all_dtypes), # non-monoidal (-np.inf, lax.max, float_dtypes), (dtypes.iinfo(np.int32).min, lax.max, [np.int32]), (dtypes.iinfo(np.int64).min, lax.max, [np.int64]), (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]), (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]), (np.inf, lax.min, float_dtypes), (dtypes.iinfo(np.int32).max, lax.min, [np.int32]), (dtypes.iinfo(np.int64).max, lax.min, [np.int64]), (dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]), (dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]), ] for dtype in dtypes for shape, dims in [ [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)], [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)] ] for bdims in all_bdims(shape))) def testReduce(self, op, init_val, shape, dtype, dims, bdims): rng = jtu.rand_small(self.rng()) init_val = np.asarray(init_val, dtype=dtype) fun = lambda operand: lax.reduce(operand, init_val, op, dims) self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_reducedims={}_bdims={}" .format(jtu.format_shape_dtype_string(shape, dtype), dims, bdims), "shape": shape, "dtype": dtype, "dims": dims, "bdims": bdims} for dtype in default_dtypes for shape, dims in [ [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)], [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)] ] for bdims in all_bdims(shape, shape))) def testVariadicReduce(self, shape, dtype, dims, bdims): def op(a, b): x1, y1 = a x2, y2 = b return x1 + x2, y1 * y2 rng = jtu.rand_small(self.rng()) init_val = tuple(np.asarray([0, 1], dtype=dtype)) fun = 
lambda x, y: lax.reduce((x, y), init_val, op, dims) self._CheckBatching(fun, 5, bdims, (shape, shape), (dtype, dtype), rng, multiple_results=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim, bdims), "op": op, "shape": shape, "dtype": dtype, "dim": dim, "bdims": bdims} for op in [lax.argmin, lax.argmax] for dtype in default_dtypes for shape in [(3, 4, 5)] for dim in range(len(shape)) for bdims in all_bdims(shape))) def testArgminmax(self, op, shape, dtype, dim, bdims): rng = jtu.rand_default(self.rng()) fun = lambda operand: op(operand, dim, np.int32) self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}" "_basedilation={}_windowdilation={}") .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, strides, padding, base_dilation, window_dilation), "op": op, "init_val": init_val, "dtype": dtype, "shape": shape, "dims": dims, "strides": strides, "padding": padding, "base_dilation": base_dilation, "window_dilation": window_dilation} for init_val, op, dtypes in [ (0, lax.add, [np.float32]), (-np.inf, lax.max, [np.float32]), (np.inf, lax.min, [np.float32]), ] for shape, dims, strides, padding, base_dilation, window_dilation in ( itertools.chain( itertools.product( [(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1), (1, 2)], ["VALID", "SAME", [(0, 3), (1, 2)]], [(1, 1), (2, 3)], [(1, 1), (1, 2)]), itertools.product( [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], [(1, 2, 2, 1), (1, 1, 1, 1)], ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]], [(1, 1, 1, 1), (2, 1, 3, 2)], [(1, 1, 1, 1), (1, 2, 2, 1)]))) for dtype in dtypes)) def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding, base_dilation, window_dilation): rng = jtu.rand_small(self.rng()) init_val = np.asarray(init_val, 
dtype=dtype) def fun(operand): return lax.reduce_window(operand, init_val, op, dims, strides, padding, base_dilation, window_dilation) for bdims in all_bdims(shape): self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_shape={}_axis={}_bdims={}_reverse={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis, bdims, reverse), "op": op, "shape": shape, "dtype": dtype, "bdims": bdims, "axis": axis, "reverse": reverse} for op, types in [ (lax.cumsum, [np.float32, np.float64]), (lax.cumprod, [np.float32, np.float64]), ] for dtype in types for shape in [[10], [3, 4, 5]] for axis in range(len(shape)) for bdims in all_bdims(shape) for reverse in [False, True])) def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse): rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer) else jtu.rand_small) rng = rng_factory(self.rng()) self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name, padding), "dtype": dtype, "padding": padding} for dtype in float_dtypes for padding in ["VALID", "SAME"])) @jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.ignore_warning(message="Using reduced precision for gradient.*") def testSelectAndGatherAdd(self, dtype, padding): rng = jtu.rand_small(self.rng()) all_configs = itertools.chain( itertools.product( [(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1), (1, 2)]), itertools.product( [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], [(1, 2, 2, 1), (1, 1, 1, 1)])) def fun(operand, tangents): pads = lax.padtype_to_pads(operand.shape, dims, strides, padding) ones = (1,) * len(operand.shape) return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims, strides, pads, ones, ones) for shape, dims, strides in all_configs: for bdims in all_bdims(shape, shape): 
self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": f"_dtype={jtu.format_shape_dtype_string(shape, dtype)}" f"_padding={padding}_dims={dims}_strides={strides}", "dtype": dtype, "padding": padding, "shape": shape, "dims": dims, "strides": strides} for dtype in float_dtypes for padding in ["VALID", "SAME"] for shape in [(3, 2, 4, 6)] for dims in [(1, 1, 2, 1)] for strides in [(1, 2, 2, 1), (1, 1, 1, 1)])) def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides): rng = jtu.rand_small(self.rng()) pads = lax.padtype_to_pads(shape, dims, strides, padding) def fun(operand, cotangents): return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims, strides, pads) ones = (1,) * len(shape) cotangent_shape = jax.eval_shape( lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides, pads, ones, ones), np.ones(shape, dtype)).shape for bdims in all_bdims(cotangent_shape, shape): self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape), (dtype, dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_bdims={}_fft_ndims={}" .format(shape, bdims, fft_ndims), "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims} for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)] for bdims in all_bdims(shape) for fft_ndims in range(0, min(3, len(shape)) + 1))) def testFft(self, fft_ndims, shape, bdims): rng = jtu.rand_default(self.rng()) ndims = len(shape) axes = range(ndims - fft_ndims, ndims) fft_lengths = tuple(shape[axis] for axis in axes) op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths) self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}" .format(jtu.format_shape_dtype_string(shape, dtype), idxs, dnums, slice_sizes, bdims), "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, 
"slice_sizes": slice_sizes, "bdims": bdims} for dtype in all_dtypes for shape, idxs, dnums, slice_sizes in [ ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), ] for bdims in all_bdims(shape, idxs.shape))) def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) self._CheckBatching(fun, 0, bdims, [shape, idxs.shape], [dtype, idxs.dtype], jtu.rand_default(self.rng())) self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype], jtu.rand_default(self.rng())) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums, bdims), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "bdims": bdims} for dtype in float_dtypes for arg_shape, idxs, update_shape, dnums in [ ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ] for bdims in all_bdims(arg_shape, idxs.shape, 
update_shape))) def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims): fun = partial(lax.scatter_add, dimension_numbers=dnums) self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape], [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()), rtol={np.float16: 5e-3, dtypes.bfloat16: 3e-2}) def testShapeUsesBuiltinInt(self): x = lax.iota(np.int32, 3) + 1 self.assertIsInstance(x.shape[0], int) # not np.int64 def testBroadcastShapesReturnsPythonInts(self): shape1, shape2 = (1, 2, 3), (2, 3) out_shape = lax.broadcast_shapes(shape1, shape2) self.assertTrue(all(type(s) is int for s in out_shape)) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_k={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), k, bdims), "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory} for shape in [(4,), (3, 5, 3)] for k in [1, 3] for bdims in all_bdims(shape) # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed: # The top_k indices for integer arrays with identical entries won't match between # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes. # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of # values a bfloat16 can represent exactly to avoid ties. for dtype, rng_factory in itertools.chain( unsafe_zip(default_dtypes, itertools.repeat(jtu.rand_unique_int))))) def testTopK(self, shape, dtype, k, bdims, rng_factory): rng = rng_factory(self.rng()) # _CheckBatching doesn't work with tuple outputs, so test outputs separately. 
op1 = lambda x: lax.top_k(x, k=k)[0] self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng) op2 = lambda x: lax.top_k(x, k=k)[1] self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_dimension={}_arity={}_bdims={}_isstable={}" .format(jtu.format_shape_dtype_string(shape, np.float32), dimension, arity, bdims, is_stable), "shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims, "is_stable": is_stable} for shape in [(2, 3)] for dimension in [0, 1] for arity in range(3) for bdims in all_bdims(*((shape,) * arity)) for is_stable in [False, True])) def testSort(self, shape, dimension, arity, bdims, is_stable): rng = jtu.rand_default(self.rng()) if arity == 1: fun = partial(lax.sort, dimension=dimension) self._CheckBatching(fun, 5, bdims, (shape,) * arity, (np.float32,) * arity, rng) else: for i in range(arity): fun = lambda *args, i=i: lax.sort(args, dimension=dimension, is_stable=is_stable)[i] self._CheckBatching(fun, 5, bdims, (shape,) * arity, (np.float32,) * arity, rng)
class DLPackTest(jtu.JaxTestCase):
  """Round-trip tests for DLPack exchange between JAX, TensorFlow and PyTorch."""

  def setUp(self):
    super().setUp()
    if jtu.device_under_test() == "tpu":
      self.skipTest("DLPack not supported on TPU")

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_take_ownership={}_gpu={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          take_ownership, gpu),
       "shape": shape, "dtype": dtype, "take_ownership": take_ownership,
       "gpu": gpu}
      for shape in all_shapes
      for dtype in dlpack_dtypes
      for take_ownership in [False, True]
      for gpu in [False, True]))
  @jtu.skip_on_devices("rocm")  # TODO(sharadmv,phawkins): see GH issue #10973
  def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
    """JAX -> DLPack -> JAX; a capsule can be consumed only once."""
    rng = jtu.rand_default(self.rng())
    host_value = rng(shape, dtype)
    if gpu and jax.default_backend() == "cpu":
      raise unittest.SkipTest("Skipping GPU test case on CPU")
    target = jax.devices("gpu" if gpu else "cpu")[0]
    arr = jax.device_put(host_value, target)

    capsule = jax.dlpack.to_dlpack(arr, take_ownership=take_ownership)
    # With take_ownership=True the source buffer is expected to be deleted.
    self.assertEqual(take_ownership, arr.device_buffer.is_deleted())

    restored = jax.dlpack.from_dlpack(capsule)
    self.assertEqual(restored.device(), target)
    self.assertAllClose(host_value.astype(arr.dtype), restored)

    with self.assertRaisesRegex(RuntimeError,
                                "DLPack tensor may be consumed at most once"):
      jax.dlpack.from_dlpack(capsule)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in all_shapes
      for dtype in dlpack_dtypes))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  @jtu.skip_on_devices("rocm")  # TODO(sharadmv,phawkins): see GH issue #10973
  def testTensorFlowToJax(self, shape, dtype):
    """TensorFlow tensor -> DLPack -> JAX array."""
    # skipTest raises SkipTest itself, so no explicit `raise` is needed.
    if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64, jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    on_gpu = jtu.device_under_test() == "gpu"
    if on_gpu and not tf.config.list_physical_devices("GPU"):
      self.skipTest("TensorFlow not configured with GPU support")
    if on_gpu and dtype == jnp.int32:
      self.skipTest("TensorFlow does not place int32 tensors on GPU")

    rng = jtu.rand_default(self.rng())
    host_value = rng(shape, dtype)
    with tf.device("/GPU:0" if jtu.device_under_test() == "gpu" else "/CPU:0"):
      tf_tensor = tf.identity(tf.constant(host_value))
    capsule = tf.experimental.dlpack.to_dlpack(tf_tensor)
    restored = jax.dlpack.from_dlpack(capsule)
    self.assertAllClose(host_value, restored)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in all_shapes
      for dtype in dlpack_dtypes))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testJaxToTensorFlow(self, shape, dtype):
    """JAX array -> DLPack -> TensorFlow tensor."""
    if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64, jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    if (jtu.device_under_test() == "gpu" and
        not tf.config.list_physical_devices("GPU")):
      self.skipTest("TensorFlow not configured with GPU support")
    rng = jtu.rand_default(self.rng())
    host_value = rng(shape, dtype)
    arr = jnp.array(host_value)
    # TODO(b/171320191): this line works around a missing context initialization
    # bug in TensorFlow.
    _ = tf.add(1, 1)
    capsule = jax.dlpack.to_dlpack(arr)
    tf_tensor = tf.experimental.dlpack.from_dlpack(capsule)
    self.assertAllClose(host_value, tf_tensor.numpy())

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in all_shapes
      for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testTorchToJax(self, shape, dtype):
    """PyTorch tensor -> DLPack -> JAX array (moved to CUDA on GPU runs)."""
    if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    rng = jtu.rand_default(self.rng())
    host_value = rng(shape, dtype)
    torch_tensor = torch.from_numpy(host_value)
    if jtu.device_under_test() == "gpu":
      torch_tensor = torch_tensor.cuda()
    capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
    restored = jax.dlpack.from_dlpack(capsule)
    self.assertAllClose(host_value, restored)

  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testTorchToJaxFailure(self):
    """A non-compact (strided) PyTorch view is rejected with UNIMPLEMENTED."""
    base = torch.arange(6).reshape((2, 3))
    strided_capsule = torch.utils.dlpack.to_dlpack(base[:, :2])

    backend = xla_bridge.get_backend()
    client = getattr(backend, "client", backend)

    expected_error = (
        r'UNIMPLEMENTED: Only DLPack tensors with trivial \(compact\) '
        r'striding are supported')
    with self.assertRaisesRegex(RuntimeError, expected_error):
      xla_client._xla.dlpack_managed_tensor_to_buffer(strided_capsule, client)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in all_shapes
      for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testJaxToTorch(self, shape, dtype):
    """JAX array -> DLPack -> PyTorch tensor."""
    if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    rng = jtu.rand_default(self.rng())
    host_value = rng(shape, dtype)
    arr = jnp.array(host_value)
    capsule = jax.dlpack.to_dlpack(arr)
    torch_tensor = torch.utils.dlpack.from_dlpack(capsule)
    self.assertAllClose(host_value, torch_tensor.cpu().numpy())
class LaxBackedScipySignalTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.signal implementations."""

  # NOTE(review): a later class statement in this file reuses the name
  # `LaxBackedScipySignalTests`; the later definition rebinds the module-level
  # name, so test discovery would likely not see this class. Confirm and dedupe.

  @staticmethod
  def _conv_tol():
    """Per-dtype tolerances shared by the convolution/correlation tests."""
    return {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12,
            np.complex64: 1e-2, np.complex128: 1e-12}

  @staticmethod
  def _spectral_tol():
    """Per-dtype tolerances for the STFT/CSD/Welch tests.

    Returns the module-level `_TPU_FFT_TOL` when running on TPU.
    """
    if jtu.device_under_test() == 'tpu':
      return _TPU_FFT_TOL
    return {np.float32: 1e-5, np.float64: 1e-12,
            np.complex64: 1e-5, np.complex128: 1e-12}

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
                  op,
                  jtu.format_shape_dtype_string(xshape, dtype),
                  jtu.format_shape_dtype_string(yshape, dtype),
                  mode),
              "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
              "jsp_op": getattr(jsp_signal, op),
              "osp_op": getattr(osp_signal, op)
          }
          for mode in ['full', 'same', 'valid']
          for op in ['convolve', 'correlate']
          for dtype in default_dtypes
          for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
          for xshape in shapeset
          for yshape in shapeset))
  def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
    """jsp convolve/correlate match scipy.signal for 1-D to 3-D inputs."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = self._conv_tol()
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              # Leading underscore added for consistency with the other tests.
              "testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
                  op,
                  jtu.format_shape_dtype_string(xshape, dtype),
                  jtu.format_shape_dtype_string(yshape, dtype),
                  mode),
              "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
              "jsp_op": getattr(jsp_signal, op),
              "osp_op": getattr(osp_signal, op)
          }
          for mode in ['full', 'same', 'valid']
          for op in ['convolve2d', 'correlate2d']
          for dtype in default_dtypes
          for xshape in twodim_shapes
          for yshape in twodim_shapes))
  def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
    """jsp convolve2d/correlate2d match scipy.signal."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = self._conv_tol()
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name": "_shape={}_axis={}_type={}_bp={}".format(
                  jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
              "shape": shape, "dtype": dtype, "axis": axis, "type": type,
              "bp": bp
          }
          for shape in [(5,), (4, 5), (3, 4, 5)]
          for dtype in jtu.dtypes.floating + jtu.dtypes.integer
          for axis in [0, -1]
          for type in ['constant', 'linear']
          for bp in [0, [0, 2]]))
  @jtu.skip_on_devices("rocm")  # will be fixed in rocm-5.1
  def testDetrend(self, shape, dtype, axis, type, bp):
    """jsp detrend matches scipy.signal.detrend."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    osp_fun = partial(osp_signal.detrend, axis=axis, type=type, bp=bp)
    jsp_fun = partial(jsp_signal.detrend, axis=axis, type=type, bp=bp)
    tol = {np.float32: 1e-5, np.float64: 1e-12}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name":
                  f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                  f"_fs={fs}_window={window}_boundary={boundary}"
                  f"_detrend={detrend}"
                  f"_padded={padded}_nperseg={nperseg}_noverlap={noverlap}"
                  f"_axis={timeaxis}_nfft={nfft}",
              "shape": shape, "dtype": dtype, "fs": fs, "window": window,
              "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
              "detrend": detrend, "boundary": boundary, "padded": padded,
              "timeaxis": timeaxis
          }
          for shape, nperseg, noverlap, timeaxis in stft_test_shapes
          for dtype in default_dtypes
          for fs in [1.0, 16000.0]
          for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
          for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
          for detrend in ['constant', 'linear', False]
          for boundary in [None, 'even', 'odd', 'zeros']
          for padded in [True, False]))
  @jtu.skip_on_devices("rocm")  # will be fixed in ROCm 5.1
  def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                           noverlap, nfft, detrend, boundary, padded,
                           timeaxis):
    """jsp stft matches scipy.signal.stft."""
    is_complex = np.dtype(dtype).kind == 'c'
    if is_complex and detrend is not None:
      # Report as skipped (like the CSD/Welch tests) instead of passing silently.
      raise unittest.SkipTest(
          "Complex signal is not supported in lax-backed `signal.detrend`.")

    kwargs = dict(fs=fs, window=window, nfft=nfft, boundary=boundary,
                  padded=padded, detrend=detrend, nperseg=nperseg,
                  noverlap=noverlap, axis=timeaxis,
                  return_onesided=not is_complex)
    osp_fun = partial(osp_signal.stft, **kwargs)
    jsp_fun = partial(jsp_signal.stft, **kwargs)
    tol = self._spectral_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  # Tests with `average == 'median'` are excluded from `testCsd*` because of
  # https://github.com/scipy/scipy/issues/15601
  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name":
                  f"_xshape={jtu.format_shape_dtype_string(xshape, dtype)}"
                  f"_yshape={jtu.format_shape_dtype_string(yshape, dtype)}"
                  f"_average={average}_scaling={scaling}_nfft={nfft}"
                  f"_fs={fs}_window={window}_detrend={detrend}"
                  f"_nperseg={nperseg}_noverlap={noverlap}"
                  f"_axis={timeaxis}",
              "xshape": xshape, "yshape": yshape, "dtype": dtype, "fs": fs,
              "window": window, "nperseg": nperseg, "noverlap": noverlap,
              "nfft": nfft, "detrend": detrend, "scaling": scaling,
              "timeaxis": timeaxis, "average": average
          }
          for xshape, yshape, nperseg, noverlap, timeaxis in csd_test_shapes
          for dtype in default_dtypes
          for fs in [1.0, 16000.0]
          for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
          for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
          for detrend in ['constant', 'linear', False]
          for scaling in ['density', 'spectrum']
          for average in ['mean']))
  @jtu.skip_on_devices("rocm")  # will be fixed in next ROCm version
  def testCsdAgainstNumpy(
      self, *, xshape, yshape, dtype, fs, window, nperseg, noverlap, nfft,
      detrend, scaling, timeaxis, average):
    """jsp csd matches scipy.signal.csd for two independent inputs."""
    is_complex = np.dtype(dtype).kind == 'c'
    if is_complex and detrend is not None:
      raise unittest.SkipTest(
          "Complex signal is not supported in lax-backed `signal.detrend`.")

    kwargs = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                  nfft=nfft, detrend=detrend, return_onesided=not is_complex,
                  scaling=scaling, axis=timeaxis, average=average)
    osp_fun = partial(osp_signal.csd, **kwargs)
    jsp_fun = partial(jsp_signal.csd, **kwargs)
    tol = self._spectral_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name":
                  f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                  f"_average={average}_scaling={scaling}_nfft={nfft}"
                  f"_fs={fs}_window={window}_detrend={detrend}"
                  f"_nperseg={nperseg}_noverlap={noverlap}"
                  f"_axis={timeaxis}",
              "shape": shape, "dtype": dtype, "fs": fs, "window": window,
              "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
              "detrend": detrend, "scaling": scaling, "timeaxis": timeaxis,
              "average": average
          }
          for shape, unused_yshape, nperseg, noverlap, timeaxis
          in csd_test_shapes
          for dtype in default_dtypes
          for fs in [1.0, 16000.0]
          for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
          for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
          for detrend in ['constant', 'linear', False]
          for scaling in ['density', 'spectrum']
          for average in ['mean']))
  @jtu.skip_on_devices("rocm")  # will be fixed in next rocm release
  def testCsdWithSameParamAgainstNumpy(
      self, *, shape, dtype, fs, window, nperseg, noverlap, nfft,
      detrend, scaling, timeaxis, average):
    """jsp csd matches scipy when the same array is passed as both x and y."""
    is_complex = np.dtype(dtype).kind == 'c'
    if is_complex and detrend is not None:
      raise unittest.SkipTest(
          "Complex signal is not supported in lax-backed `signal.detrend`.")

    kwargs = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                  nfft=nfft, detrend=detrend, return_onesided=not is_complex,
                  scaling=scaling, axis=timeaxis, average=average)

    def osp_fun(x, y):
      # When identical arguments are given, the jsp version follows the
      # behavior for copied arguments, so give scipy a copy of `y`.
      freqs, Pxy = osp_signal.csd(x, y.copy(), **kwargs)
      return freqs, Pxy

    jsp_fun = partial(jsp_signal.csd, **kwargs)
    tol = self._spectral_tol()

    rng = jtu.rand_default(self.rng())
    # The same array object is used for both arguments.
    args_maker = lambda: [rng(shape, dtype)] * 2

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name":
                  f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                  f"_fs={fs}_window={window}"
                  f"_nperseg={nperseg}_noverlap={noverlap}_nfft={nfft}"
                  f"_detrend={detrend}_return_onesided={return_onesided}"
                  f"_scaling={scaling}_axis={timeaxis}_average={average}",
              "shape": shape, "dtype": dtype, "fs": fs, "window": window,
              "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
              "detrend": detrend, "return_onesided": return_onesided,
              "scaling": scaling, "timeaxis": timeaxis, "average": average
          }
          for shape, nperseg, noverlap, timeaxis in welch_test_shapes
          for dtype in default_dtypes
          for fs in [1.0, 16000.0]
          for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
          for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
          for detrend in ['constant', 'linear', False]
          for return_onesided in [True, False]
          for scaling in ['density', 'spectrum']
          for average in ['mean', 'median']))
  @jtu.skip_on_devices("rocm")  # will be fixed in next ROCm release
  def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                            noverlap, nfft, detrend, return_onesided,
                            scaling, timeaxis, average):
    """jsp welch matches scipy.signal.welch."""
    if np.dtype(dtype).kind == 'c':
      # One-sided spectra are not used for complex input.
      return_onesided = False
      if detrend is not None:
        raise unittest.SkipTest(
            "Complex signal is not supported in lax-backed `signal.detrend`.")

    kwargs = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                  nfft=nfft, detrend=detrend, return_onesided=return_onesided,
                  scaling=scaling, axis=timeaxis, average=average)
    osp_fun = partial(osp_signal.welch, **kwargs)
    jsp_fun = partial(jsp_signal.welch, **kwargs)
    tol = self._spectral_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name":
                  f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                  f"_nperseg={nperseg}_noverlap={noverlap}"
                  f"_use_nperseg={use_nperseg}_use_overlap={use_noverlap}"
                  f"_axis={timeaxis}",
              "shape": shape, "dtype": dtype,
              "nperseg": nperseg, "noverlap": noverlap,
              "use_nperseg": use_nperseg, "use_noverlap": use_noverlap,
              "timeaxis": timeaxis
          }
          for shape, nperseg, noverlap, timeaxis in welch_test_shapes
          for use_nperseg in [False, True]
          for use_noverlap in [False, True]
          for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
  def testWelchWithDefaultStepArgsAgainstNumpy(
      self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap,
      timeaxis):
    """jsp welch matches scipy when nperseg/noverlap are left at defaults."""
    kwargs = {'axis': timeaxis}

    if use_nperseg:
      kwargs['nperseg'] = nperseg
    else:
      # Without nperseg, the segment length comes from an explicit window.
      kwargs['window'] = osp_signal.get_window('hann', nperseg)
    if use_noverlap:
      kwargs['noverlap'] = noverlap

    osp_fun = partial(osp_signal.welch, **kwargs)
    jsp_fun = partial(jsp_signal.welch, **kwargs)
    tol = self._spectral_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
class NNInitializersTest(jtu.JaxTestCase):
  """Checks shapes and dtypes of values produced by `nn.initializers`."""

  def setUp(self):
    super().setUp()
    # Run every test with strict rank promotion. Remember the previous mode so
    # tearDown can restore it instead of assuming it was "allow".
    self._saved_rank_promotion = config.jax_numpy_rank_promotion
    config.update("jax_numpy_rank_promotion", "raise")

  def tearDown(self):
    # Restore before delegating to the base class so that any config
    # bookkeeping performed there is not overwritten afterwards.
    config.update("jax_numpy_rank_promotion", self._saved_rank_promotion)
    super().tearDown()

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_{}_{}".format(
              rec.name, jtu.format_shape_dtype_string(shape, dtype)),
          "initializer": rec.initializer(),
          "shape": shape,
          "dtype": dtype,
      }
      for rec in INITIALIZER_RECS
      for shape in rec.shapes
      for dtype in rec.dtypes))
  def testInitializer(self, initializer, shape, dtype):
    """An already-built initializer yields the requested shape and dtype."""
    rng = random.PRNGKey(0)
    val = initializer(rng, shape, dtype)

    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_{}_{}".format(
              rec.name, jtu.format_shape_dtype_string(shape, dtype)),
          "initializer_provider": rec.initializer,
          "shape": shape,
          "dtype": dtype,
      }
      for rec in INITIALIZER_RECS
      for shape in rec.shapes
      for dtype in rec.dtypes))
  def testInitializerProvider(self, initializer_provider, shape, dtype):
    """A dtype passed to the provider is honoured by the built initializer."""
    rng = random.PRNGKey(0)
    initializer = initializer_provider(dtype=dtype)
    val = initializer(rng, shape)

    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

  def testVarianceScalingMultiAxis(self):
    """variance_scaling accepts tuples of fan-in / fan-out axes."""
    rng = random.PRNGKey(0)
    shape = (2, 3, 4, 5)
    initializer = nn.initializers.variance_scaling(
        scale=1.0, mode='fan_avg', distribution='truncated_normal',
        in_axis=(0, 1), out_axis=(-2, -1))
    val = initializer(rng, shape)

    self.assertEqual(shape, jnp.shape(val))
class LaxBackedScipySignalTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.signal implementations"""

  @staticmethod
  def _fft_tol():
    """Per-dtype tolerance shared by the STFT / CSD / Welch comparisons."""
    if jtu.device_under_test() == 'tpu':
      return _TPU_FFT_TOL
    return {np.float32: 1e-5, np.float64: 1e-12,
            np.complex64: 1e-5, np.complex128: 1e-12}

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
              op,
              jtu.format_shape_dtype_string(xshape, dtype),
              jtu.format_shape_dtype_string(yshape, dtype),
              mode),
           "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
           "jsp_op": getattr(jsp_signal, op),
           "osp_op": getattr(osp_signal, op)}
          for mode in ['full', 'same', 'valid']
          for op in ['convolve', 'correlate']
          for dtype in default_dtypes
          for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
          for xshape in shapeset
          for yshape in shapeset))
  def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12,
           np.complex64: 1e-2, np.complex128: 1e-12}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "op={}_xshape={}_yshape={}_mode={}".format(
              op,
              jtu.format_shape_dtype_string(xshape, dtype),
              jtu.format_shape_dtype_string(yshape, dtype),
              mode),
           "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
           "jsp_op": getattr(jsp_signal, op),
           "osp_op": getattr(osp_signal, op)}
          for mode in ['full', 'same', 'valid']
          for op in ['convolve2d', 'correlate2d']
          for dtype in default_dtypes
          for xshape in twodim_shapes
          for yshape in twodim_shapes))
  def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12,
           np.complex64: 1e-2, np.complex128: 1e-12}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_shape={}_axis={}_type={}_bp={}".format(
              jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
           "shape": shape, "dtype": dtype, "axis": axis, "type": type,
           "bp": bp}
          for shape in [(5,), (4, 5), (3, 4, 5)]
          for dtype in jtu.dtypes.floating + jtu.dtypes.integer
          for axis in [0, -1]
          for type in ['constant', 'linear']
          for bp in [0, [0, 2]]))
  def testDetrend(self, shape, dtype, axis, type, bp):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    kwds = dict(axis=axis, type=type, bp=bp)

    def osp_fun(x):
      return osp_signal.detrend(x, **kwds).astype(
          dtypes._to_inexact_dtype(x.dtype))

    jsp_fun = partial(jsp_signal.detrend, **kwds)

    if jtu.device_under_test() == 'tpu':
      tol = {np.float32: 3e-2, np.float64: 1e-12}
    else:
      tol = {np.float32: 1e-5, np.float64: 1e-12}

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
              f"_fs={fs}_window={window}_boundary={boundary}_detrend={detrend}"
              f"_padded={padded}_nperseg={nperseg}_noverlap={noverlap}"
              f"_axis={timeaxis}_nfft={nfft}",
          "shape": shape, "dtype": dtype, "fs": fs, "window": window,
          "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
          "detrend": detrend, "boundary": boundary, "padded": padded,
          "timeaxis": timeaxis}
          for shape, nperseg, noverlap, timeaxis in stft_test_shapes
          for dtype in default_dtypes
          for fs in [1.0, 16000.0]
          for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
          for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
          for detrend in ['constant', 'linear', False]
          for boundary in [None, 'even', 'odd', 'zeros']
          for padded in [True, False]))
  def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                           noverlap, nfft, detrend, boundary, padded,
                           timeaxis):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    if is_complex and detrend is not None:
      self.skipTest(
          "Complex signal is not supported in lax-backed `signal.detrend`.")

    kwds = dict(fs=fs, window=window, nfft=nfft, boundary=boundary,
                padded=padded, detrend=detrend, nperseg=nperseg,
                noverlap=noverlap, axis=timeaxis,
                return_onesided=not is_complex)

    def osp_fun(x):
      freqs, time, Zxx = osp_signal.stft(x, **kwds)
      return (freqs.astype(_real_dtype(dtype)),
              time.astype(_real_dtype(dtype)),
              Zxx.astype(_complex_dtype(dtype)))

    jsp_fun = partial(jsp_signal.stft, **kwds)
    tol = self._fft_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  # Tests with `average == 'median'` are excluded from `testCsd*`
  # because of the issue:
  #   https://github.com/scipy/scipy/issues/15601
  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              f"_xshape={jtu.format_shape_dtype_string(xshape, dtype)}"
              f"_yshape={jtu.format_shape_dtype_string(yshape, dtype)}"
              f"_average={average}_scaling={scaling}_nfft={nfft}"
              f"_fs={fs}_window={window}_detrend={detrend}"
              f"_nperseg={nperseg}_noverlap={noverlap}"
              f"_axis={timeaxis}",
          "xshape": xshape, "yshape": yshape, "dtype": dtype, "fs": fs,
          "window": window, "nperseg": nperseg, "noverlap": noverlap,
          "nfft": nfft, "detrend": detrend, "scaling": scaling,
          "timeaxis": timeaxis, "average": average}
          for xshape, yshape, nperseg, noverlap, timeaxis in csd_test_shapes
          for dtype in default_dtypes
          for fs in [1.0, 16000.0]
          for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
          for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
          for detrend in ['constant', 'linear', False]
          for scaling in ['density', 'spectrum']
          for average in ['mean']))
  def testCsdAgainstNumpy(
      self, *, xshape, yshape, dtype, fs, window, nperseg, noverlap, nfft,
      detrend, scaling, timeaxis, average):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    if is_complex and detrend is not None:
      self.skipTest(
          "Complex signal is not supported in lax-backed `signal.detrend`.")

    kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                nfft=nfft, detrend=detrend, return_onesided=not is_complex,
                scaling=scaling, axis=timeaxis, average=average)

    def osp_fun(x, y):
      freqs, Pxy = osp_signal.csd(x, y, **kwds)
      # Make type-casting the same as JAX.
      return freqs.astype(_real_dtype(dtype)), Pxy.astype(_complex_dtype(dtype))

    jsp_fun = partial(jsp_signal.csd, **kwds)
    tol = self._fft_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
              f"_average={average}_scaling={scaling}_nfft={nfft}"
              f"_fs={fs}_window={window}_detrend={detrend}"
              f"_nperseg={nperseg}_noverlap={noverlap}"
              f"_axis={timeaxis}",
          "shape": shape, "dtype": dtype, "fs": fs, "window": window,
          "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
          "detrend": detrend, "scaling": scaling, "timeaxis": timeaxis,
          "average": average}
          for shape, unused_yshape, nperseg, noverlap, timeaxis in csd_test_shapes
          for dtype in default_dtypes
          for fs in [1.0, 16000.0]
          for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
          for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
          for detrend in ['constant', 'linear', False]
          for scaling in ['density', 'spectrum']
          for average in ['mean']))
  def testCsdWithSameParamAgainstNumpy(
      self, *, shape, dtype, fs, window, nperseg, noverlap, nfft, detrend,
      scaling, timeaxis, average):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    if is_complex and detrend is not None:
      self.skipTest(
          "Complex signal is not supported in lax-backed `signal.detrend`.")

    kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                nfft=nfft, detrend=detrend, return_onesided=not is_complex,
                scaling=scaling, axis=timeaxis, average=average)

    def osp_fun(x, y):
      # When identical arguments are given, the jsp version behaves as if
      # the second argument were a copy.
      freqs, Pxy = osp_signal.csd(x, y.copy(), **kwds)
      # Make type-casting the same as JAX.
      return freqs.astype(_real_dtype(dtype)), Pxy.astype(_complex_dtype(dtype))

    jsp_fun = partial(jsp_signal.csd, **kwds)
    tol = self._fft_tol()

    rng = jtu.rand_default(self.rng())
    # The same array object is passed as both x and y.
    args_maker = lambda: [rng(shape, dtype)] * 2

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
              f"_fs={fs}_window={window}"
              f"_nperseg={nperseg}_noverlap={noverlap}_nfft={nfft}"
              f"_detrend={detrend}_return_onesided={return_onesided}"
              f"_scaling={scaling}_axis={timeaxis}_average={average}",
          "shape": shape, "dtype": dtype, "fs": fs, "window": window,
          "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
          "detrend": detrend, "return_onesided": return_onesided,
          "scaling": scaling, "timeaxis": timeaxis, "average": average}
          for shape, nperseg, noverlap, timeaxis in welch_test_shapes
          for dtype in default_dtypes
          for fs in [1.0, 16000.0]
          for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
          for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
          for detrend in ['constant', 'linear', False]
          for return_onesided in [True, False]
          for scaling in ['density', 'spectrum']
          for average in ['mean', 'median']))
  def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                            noverlap, nfft, detrend, return_onesided,
                            scaling, timeaxis, average):
    if np.dtype(dtype).kind == 'c':
      # Complex input always gets a two-sided spectrum.
      return_onesided = False
      if detrend is not None:
        self.skipTest(
            "Complex signal is not supported in lax-backed `signal.detrend`.")

    kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                nfft=nfft, detrend=detrend, return_onesided=return_onesided,
                scaling=scaling, axis=timeaxis, average=average)

    def osp_fun(x):
      freqs, Pxx = osp_signal.welch(x, **kwds)
      return freqs.astype(_real_dtype(dtype)), Pxx.astype(_real_dtype(dtype))

    jsp_fun = partial(jsp_signal.welch, **kwds)
    tol = self._fft_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
              f"_nperseg={nperseg}_noverlap={noverlap}"
              f"_use_nperseg={use_nperseg}_use_overlap={use_noverlap}"
              f"_axis={timeaxis}",
          "shape": shape, "dtype": dtype, "nperseg": nperseg,
          "noverlap": noverlap, "use_nperseg": use_nperseg,
          "use_noverlap": use_noverlap, "timeaxis": timeaxis}
          for shape, nperseg, noverlap, timeaxis in welch_test_shapes
          for use_nperseg in [False, True]
          for use_noverlap in [False, True]
          for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
  def testWelchWithDefaultStepArgsAgainstNumpy(
      self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap,
      timeaxis):
    kwargs = {'axis': timeaxis}

    if use_nperseg:
      kwargs['nperseg'] = nperseg
    else:
      kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg),
                                   dtype=dtypes._to_complex_dtype(dtype))
    if use_noverlap:
      kwargs['noverlap'] = noverlap

    def osp_fun(x):
      freqs, Pxx = osp_signal.welch(x, **kwargs)
      return freqs.astype(_real_dtype(dtype)), Pxx.astype(_real_dtype(dtype))

    jsp_fun = partial(jsp_signal.welch, **kwargs)
    tol = self._fft_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
              f"_fs={fs}_window={window}_boundary={boundary}"
              f"_nperseg={nperseg}_noverlap={noverlap}_onesided={onesided}"
              f"_timeaxis={timeaxis}_freqaxis{freqaxis}_nfft={nfft}",
          "shape": shape, "dtype": dtype, "fs": fs, "window": window,
          "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
          "onesided": onesided, "boundary": boundary,
          "timeaxis": timeaxis, "freqaxis": freqaxis}
          for shape, nperseg, noverlap, timeaxis, freqaxis in istft_test_shapes
          for dtype in default_dtypes
          for fs in [1.0, 16000.0]
          for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
          for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
          for onesided in [False, True]
          for boundary in [False, True]))
  def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                            noverlap, nfft, onesided, boundary,
                            timeaxis, freqaxis):
    if not onesided:
      # Replace the one-sided frequency length n with the matching two-sided
      # length (n - 1) * 2. The axis is normalized first: with a negative
      # freqaxis such as -1, `shape[freqaxis + 1:]` would be the whole shape
      # and the splice would produce an array of the wrong rank.
      axis = freqaxis % len(shape)
      new_freq_len = (shape[axis] - 1) * 2
      shape = shape[:axis] + (new_freq_len,) + shape[axis + 1:]

    kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                nfft=nfft, input_onesided=onesided, boundary=boundary,
                time_axis=timeaxis, freq_axis=freqaxis)

    osp_fun = partial(osp_signal.istft, **kwds)
    osp_fun = jtu.ignore_warning(
        message="NOLA condition failed, STFT may not be invertible")(osp_fun)
    jsp_fun = partial(jsp_signal.istft, **kwds)

    tol = {np.float32: 1e-4, np.float64: 1e-6,
           np.complex64: 1e-4, np.complex128: 1e-6}
    if jtu.device_under_test() == 'tpu':
      tol = _TPU_FFT_TOL

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    # The output dtype of the osp version differs between SciPy releases,
    # and so between test environments; the dtype check is disabled.
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol,
                            check_dtypes=False)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
class AnnTest(jtu.JaxTestCase):
  """Recall and autodiff tests for `ann.approx_max_k` / `ann.approx_min_k`."""

  def _assert_recall_above(self, gt_args, ann_args, k, recall, num_queries):
    """Asserts approximate top-k indices reach the requested recall.

    Args:
      gt_args: exact top-k indices, one row per query.
      ann_args: approximate top-k indices, one row per query.
      k: number of results requested per query.
      recall: the measured recall must be strictly greater than this.
      num_queries: number of query rows used to normalise the hit count.
    """
    self.assertEqual(k, len(ann_args[0]))
    gt_args_sets = [set(np.asarray(x)) for x in gt_args]
    # Count, over all queries, how many approximate indices are true top-k.
    hits = sum(
        sum(1 for x in ann_args_per_q if x.item() in gt_args_sets[q])
        for q, ann_args_per_q in enumerate(ann_args))
    self.assertGreater(hits / (num_queries * k), recall)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_qy={}_db={}_k={}_recall={}".format(
              jtu.format_shape_dtype_string(qy_shape, dtype),
              jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
          "qy_shape": qy_shape, "db_shape": db_shape, "dtype": dtype,
          "k": k, "recall": recall}
          for qy_shape in [(200, 128), (128, 128)]
          for db_shape in [(128, 500), (128, 3000)]
          for dtype in jtu.dtypes.all_floating
          for k in [1, 10, 50]
          for recall in [0.9, 0.95]))
  def test_approx_max_k(self, qy_shape, db_shape, dtype, k, recall):
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    scores = lax.dot(qy, db)
    _, gt_args = lax.top_k(scores, k)
    _, ann_args = ann.approx_max_k(scores, k, recall_target=recall)
    self._assert_recall_above(gt_args, ann_args, k, recall, qy_shape[0])

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_qy={}_db={}_k={}_recall={}".format(
              jtu.format_shape_dtype_string(qy_shape, dtype),
              jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
          "qy_shape": qy_shape, "db_shape": db_shape, "dtype": dtype,
          "k": k, "recall": recall}
          for qy_shape in [(200, 128), (128, 128)]
          for db_shape in [(128, 500), (128, 3000)]
          for dtype in jtu.dtypes.all_floating
          for k in [1, 10, 50]
          for recall in [0.9, 0.95]))
  def test_approx_min_k(self, qy_shape, db_shape, dtype, k, recall):
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    scores = lax.dot(qy, db)
    # Exact min-k is top-k of the negated scores.
    _, gt_args = lax.top_k(-scores, k)
    _, ann_args = ann.approx_min_k(scores, k, recall_target=recall)
    self._assert_recall_above(gt_args, ann_args, k, recall, qy_shape[0])

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_shape={}_k={}_max_k={}".format(
              jtu.format_shape_dtype_string(shape, dtype), k, is_max_k),
          "shape": shape, "dtype": dtype, "k": k, "is_max_k": is_max_k}
          for dtype in [np.float32]
          for shape in [(4,), (5, 5), (2, 1, 4)]
          for k in [1, 3]
          for is_max_k in [True, False]))
  def test_autodiff(self, shape, dtype, k, is_max_k):
    # Distinct values (a permuted arange) keep the top-k selection unambiguous.
    vals = np.arange(prod(shape), dtype=dtype)
    vals = self.rng().permutation(vals).reshape(shape)
    if is_max_k:
      fn = lambda vs: ann.approx_max_k(vs, k=k)[0]
    else:
      fn = lambda vs: ann.approx_min_k(vs, k=k)[0]
    jtu.check_grads(fn, (vals,), 2, ["fwd", "rev"], eps=1e-2)
class SparsifyTest(jtu.JaxTestCase):
  """Tests for `sparsify` applied to functions of BCOO sparse arrays."""

  @classmethod
  def sparsify(cls, f):
    # Use the non-tracer implementation (use_tracer=False).
    return sparsify(f, use_tracer=False)

  def testNotImplementedMessages(self):
    x = BCOO.fromdense(jnp.arange(5.0))
    # Test a densifying primitive
    with self.assertRaisesRegex(
        NotImplementedError,
        r"^sparse rule for cos is not implemented because it would result in dense output\."):
      self.sparsify(lax.cos)(x)

    # Test a generic not implemented primitive.
    with self.assertRaisesRegex(
        NotImplementedError,
        r"^sparse rule for complex is not implemented\.$"):
      self.sparsify(lax.complex)(x, x)

  def testTracerIsInstanceCheck(self):
    @self.sparsify
    def f(x):
      self.assertNotIsInstance(x, SparseTracer)
    f(jnp.arange(5))

  def assertBcooIdentical(self, x, y):
    """Asserts both are BCOO with equal shape, data and indices."""
    self.assertIsInstance(x, BCOO)
    self.assertIsInstance(y, BCOO)
    self.assertEqual(x.shape, y.shape)
    self.assertArraysEqual(x.data, y.data)
    self.assertArraysEqual(x.indices, y.indices)

  def testSparsifyValue(self):
    X = jnp.arange(5)
    X_BCOO = BCOO.fromdense(X)

    args = (X, X_BCOO, X_BCOO)

    # Independent index
    spenv = SparsifyEnv()
    spvalues = arrays_to_spvalues(spenv, args)
    self.assertEqual(len(spvalues), len(args))
    self.assertLen(spenv._buffers, 5)
    self.assertEqual(
        spvalues,
        (SparsifyValue(X.shape, 0, None, indices_sorted=False,
                       unique_indices=False),
         SparsifyValue(X.shape, 1, 2, indices_sorted=True,
                       unique_indices=True),
         SparsifyValue(X.shape, 3, 4, indices_sorted=True,
                       unique_indices=True)))

    args_out = spvalues_to_arrays(spenv, spvalues)
    self.assertEqual(len(args_out), len(args))
    self.assertArraysEqual(args[0], args_out[0])
    self.assertBcooIdentical(args[1], args_out[1])
    self.assertBcooIdentical(args[2], args_out[2])

    # Shared index
    spvalues = (SparsifyValue(X.shape, 0, None), SparsifyValue(X.shape, 1, 2),
                SparsifyValue(X.shape, 3, 2))
    spenv = SparsifyEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data])

    args_out = spvalues_to_arrays(spenv, spvalues)
    self.assertEqual(len(args_out), len(args))
    self.assertArraysEqual(args[0], args_out[0])
    self.assertBcooIdentical(args[1], args_out[1])
    self.assertBcooIdentical(args[2], args_out[2])

  def testDropvar(self):
    def inner(x):
      return x * 2, x * 3

    def f(x):
      _, y = jit(inner)(x)
      return y * 4

    x_dense = jnp.arange(5)
    x_sparse = BCOO.fromdense(x_dense)
    self.assertArraysEqual(self.sparsify(f)(x_sparse).todense(), f(x_dense))

  def testPytreeInput(self):
    f = self.sparsify(lambda x: x)
    args = (jnp.arange(4), BCOO.fromdense(jnp.arange(4)))
    out = f(args)
    self.assertLen(out, 2)
    self.assertArraysEqual(args[0], out[0])
    self.assertBcooIdentical(args[1], out[1])

  @jax.numpy_dtype_promotion('standard')  # explicitly exercises implicit dtype promotion.
  def testSparsify(self):
    M_dense = jnp.arange(24).reshape(4, 6)
    M_sparse = BCOO.fromdense(M_dense)
    v = jnp.arange(M_dense.shape[0])

    @self.sparsify
    def func(x, v):
      return -jnp.sin(jnp.pi * x).T @ (v + 1)

    with jtu.ignore_warning(
        category=CuSparseEfficiencyWarning,
        message="bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
      result_sparse = func(M_sparse, v)
    result_dense = func(M_dense, v)
    self.assertAllClose(result_sparse, result_dense)

  def testSparsifyWithConsts(self):
    M_dense = jnp.arange(24).reshape(4, 6)
    M_sparse = BCOO.fromdense(M_dense)

    @self.sparsify
    def func(x):
      return jit(lambda x: jnp.sum(x, 1))(x)

    result_dense = func(M_dense)
    result_sparse = func(M_sparse)

    self.assertAllClose(result_sparse.todense(), result_dense)

  def testSparseMatmul(self):
    X = jnp.arange(16.0).reshape(4, 4)
    Xsp = BCOO.fromdense(X)
    Y = jnp.ones(4)
    Ysp = BCOO.fromdense(Y)

    func = self.sparsify(operator.matmul)

    # dot_general
    result_sparse = func(Xsp, Y)
    result_dense = operator.matmul(X, Y)
    self.assertAllClose(result_sparse, result_dense)

    # rdot_general
    result_sparse = func(Y, Xsp)
    result_dense = operator.matmul(Y, X)
    self.assertAllClose(result_sparse, result_dense)

    # spdot_general
    result_sparse = self.sparsify(operator.matmul)(Xsp, Ysp)
    result_dense = operator.matmul(X, Y)
    self.assertAllClose(result_sparse.todense(), result_dense)

  def testSparseAdd(self):
    x = BCOO.fromdense(jnp.arange(5))
    y = BCOO.fromdense(2 * jnp.arange(5))

    # Distinct indices
    out = self.sparsify(operator.add)(x, y)
    self.assertEqual(out.nse, 8)  # uses concatenation.
    self.assertArraysEqual(out.todense(), 3 * jnp.arange(5))

    # Shared indices - requires lower level call
    spenv = SparsifyEnv([x.indices, x.data, y.data])
    spvalues = [
        spenv.sparse(x.shape, data_ref=1, indices_ref=0),
        spenv.sparse(y.shape, data_ref=2, indices_ref=0)
    ]

    result = sparsify_raw(operator.add)(spenv, *spvalues)
    args_out, _ = result
    out, = spvalues_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() + y.todense())

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_{}_nbatch={}_ndense={}_unique_indices={}".format(
              jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense,
              unique_indices),
          "shape": shape, "dtype": dtype, "n_batch": n_batch,
          "n_dense": n_dense, "unique_indices": unique_indices}
          for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
          for dtype in (jtu.dtypes.integer + jtu.dtypes.floating +
                        jtu.dtypes.complex)
          for n_batch in range(len(shape) + 1)
          for n_dense in range(len(shape) + 1 - n_batch)
          for unique_indices in [True, False]))
  def testSparseMul(self, shape, dtype, n_batch, n_dense, unique_indices):
    rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
    x = BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch,
                       n_dense=n_dense)

    # Scalar multiplication
    scalar = 2
    y = self.sparsify(operator.mul)(x, scalar)
    self.assertArraysEqual(x.todense() * scalar, y.todense())

    # Shared indices - requires lower level call
    spenv = SparsifyEnv([x.indices, x.data, y.data])
    spvalues = [
        spenv.sparse(x.shape, data_ref=1, indices_ref=0,
                     unique_indices=unique_indices),
        spenv.sparse(y.shape, data_ref=2, indices_ref=0,
                     unique_indices=unique_indices)
    ]

    result = sparsify_raw(operator.mul)(spenv, *spvalues)
    args_out, _ = result
    out, = spvalues_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() * y.todense())

  def testSparseSubtract(self):
    x = BCOO.fromdense(3 * jnp.arange(5))
    y = BCOO.fromdense(jnp.arange(5))

    # Distinct indices
    out = self.sparsify(operator.sub)(x, y)
    self.assertEqual(out.nse, 8)  # uses concatenation.
    self.assertArraysEqual(out.todense(), 2 * jnp.arange(5))

    # Shared indices - requires lower level call
    spenv = SparsifyEnv([x.indices, x.data, y.data])
    spvalues = [
        spenv.sparse(x.shape, data_ref=1, indices_ref=0),
        spenv.sparse(y.shape, data_ref=2, indices_ref=0)
    ]

    result = sparsify_raw(operator.sub)(spenv, *spvalues)
    args_out, _ = result
    out, = spvalues_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() - y.todense())

  def testSparseSum(self):
    x = jnp.arange(20).reshape(4, 5)
    xsp = BCOO.fromdense(x)

    def f(x):
      return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

    result_dense = f(x)
    result_sparse = self.sparsify(f)(xsp)

    # Use a test assertion rather than a bare `assert`, which is stripped
    # under `python -O` and gives a less informative failure message.
    self.assertEqual(len(result_dense), len(result_sparse))

    for res_dense, res_sparse in zip(result_dense, result_sparse):
      if isinstance(res_sparse, BCOO):
        res_sparse = res_sparse.todense()
      self.assertArraysAllClose(res_dense, res_sparse)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_shape={}_dimensions={}_nbatch={}_ndense={}".format(
              jtu.format_shape_dtype_string(shape, np.float32), dimensions,
              n_batch, n_dense),
          "shape": shape, "dimensions": dimensions, "n_batch": n_batch,
          "n_dense": n_dense}
          for shape, dimensions in [
              [(1,), (0,)],
              [(1,), (-1,)],
              [(2, 1, 4), (1,)],
              [(2, 1, 3, 1), (1,)],
              [(2, 1, 3, 1), (1, 3)],
              [(2, 1, 3, 1), (3,)],
          ]
          for n_batch in range(len(shape) + 1)
          for n_dense in range(len(shape) - n_batch + 1)))
  def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
    rng = jtu.rand_default(self.rng())

    M_dense = rng(shape, np.float32)
    M_sparse = BCOO.fromdense(M_dense, n_batch=n_batch, n_dense=n_dense)
    func = self.sparsify(partial(lax.squeeze, dimensions=dimensions))

    result_dense = func(M_dense)
    result_sparse = func(M_sparse).todense()

    self.assertAllClose(result_sparse, result_dense)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_shapes={shapes}_func={func}_nbatch={n_batch}",
          "shapes": shapes, "func": func, "n_batch": n_batch}
          for shapes, func, n_batch in [
              ([(4,), (4,)], "concatenate", 0),
              ([(4,), (4,)], "stack", 0),
              ([(4,), (4,)], "hstack", 0),
              ([(4,), (4,)], "vstack", 0),
              ([(4,), (4,)], "concatenate", 1),
              ([(4,), (4,)], "stack", 1),
              ([(4,), (4,)], "hstack", 1),
              ([(4,), (4,)], "vstack", 1),
              ([(2, 4), (2, 4)], "stack", 0),
              ([(2, 4), (3, 4)], "vstack", 0),
              ([(2, 4), (2, 5)], "hstack", 0),
              ([(2, 4), (3, 4)], "vstack", 1),
              ([(2, 4), (2, 5)], "hstack", 1),
              ([(2, 4), (3, 4)], "vstack", 2),
              ([(2, 4), (2, 5)], "hstack", 2),
              ([(2, 4), (4,), (3, 4)], "vstack", 0),
              ([(1, 4), (4,), (1, 4)], "vstack", 0),
          ]))
  def testSparseConcatenate(self, shapes, func, n_batch):
    f = self.sparsify(getattr(jnp, func))
    rng = jtu.rand_some_zero(self.rng())
    arrs = [rng(shape, 'int32') for shape in shapes]
    sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
    self.assertArraysEqual(f(arrs), f(sparrs).todense())

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}",
          "shape": shape, "new_shape": new_shape, "n_batch": n_batch,
          "n_dense": n_dense}
          for shape, new_shape, n_batch, n_dense in [
              [(6,), (2, 3), 0, 0],
              [(1, 4), (2, 2), 0, 0],
              [(12, 2), (2, 3, 4), 0, 0],
              [(1, 3, 2), (2, 3), 0, 0],
              [(1, 6), (2, 3, 1), 0, 0],
              [(2, 3, 4), (3, 8), 0, 0],
              [(2, 3, 4), (1, 2, 12), 1, 0],
              [(2, 3, 4), (6, 2, 2), 2, 0],
          ]))
  def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense):
    rng = jtu.rand_some_zero(self.rng())
    arr = rng(shape, 'int32')
    arr_sparse = BCOO.fromdense(arr, n_batch=n_batch, n_dense=n_dense)

    arr2 = arr.reshape(new_shape)
    arr2_sparse = arr_sparse.reshape(new_shape)

    self.assertArraysEqual(arr2, arr2_sparse.todense())

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}_dimensions={dimensions}",
          "shape": shape, "new_shape": new_shape, "n_batch": n_batch,
          "n_dense": n_dense, "dimensions": dimensions}
          for shape, new_shape, n_batch, n_dense, dimensions in [
              [(2, 3, 4), (24,), 0, 0, None],
              [(2, 3, 4), (24,), 0, 0, (0, 1, 2)],
              [(2, 3, 4), (24,), 0, 0, (0, 2, 1)],
              [(2, 3, 4), (24,), 0, 0, (1, 0, 2)],
              [(2, 3, 4), (24,), 0, 0, (1, 2, 0)],
              [(2, 3, 4), (24,), 0, 0, (2, 0, 1)],
              [(2, 3, 4), (24,), 0, 0, (2, 1, 0)],
              [(4, 2, 3), (2, 2, 6), 1, 0, (0, 1, 2)],
              [(4, 2, 3), (2, 2, 6), 1, 0, (0, 2, 1)],
              [(2, 3, 4), (6, 4), 2, 0, (0, 1, 2)],
              [(2, 3, 4), (6, 4), 2, 0, (1, 0, 2)],
          ]))
  def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch,
                                      n_dense, dimensions):
    rng = jtu.rand_some_zero(self.rng())
    arr = rng(shape, 'int32')
    arr_sparse = BCOO.fromdense(arr, n_batch=n_batch, n_dense=n_dense)

    f = self.sparsify(
        lambda x: lax.reshape(x, new_shape, dimensions=dimensions))

    arr2 = f(arr)
    arr2_sparse = f(arr_sparse)

    self.assertArraysEqual(arr2, arr2_sparse.todense())

  def testSparseWhileLoop(self):
    def cond_fun(params):
      i, A = params
      return i < 5

    def body_fun(params):
      i, A = params
      return i + 1, 2 * A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A))

    A = jnp.arange(4)
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = self.sparsify(f)(Asp)

    self.assertEqual(len(out_dense), 2)
    self.assertEqual(len(out_sparse), 2)
    # The loop counter is a dense value in both runs; compare it directly.
    self.assertArraysEqual(out_dense[0], out_sparse[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())

  def testSparseWhileLoopDuplicateIndices(self):
    def cond_fun(params):
      i, A, B = params
      return i < 5

    def body_fun(params):
      i, A, B = params
      # TODO(jakevdp): track shared indices through while loop & use this
      # version of the test, which requires shared indices in order for
      # the nse of the result to remain the same.
      # return i + 1, A, A + B

      # This version is fine without shared indices, and tests that we're
      # flattening non-shared indices consistently.
      return i + 1, B, A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A, A))

    A = jnp.arange(4).reshape((2, 2))
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = self.sparsify(f)(Asp)

    self.assertEqual(len(out_dense), 3)
    self.assertEqual(len(out_sparse), 3)
    # The loop counter is a dense value in both runs; compare it directly.
    self.assertArraysEqual(out_dense[0], out_sparse[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
    self.assertArraysEqual(out_dense[2], out_sparse[2].todense())

  def testSparsifyDenseXlaCall(self):
    # Test handling of dense xla_call within jaxpr interpreter.
    out = self.sparsify(jit(lambda x: x + 1))(0.0)
    self.assertEqual(out, 1.0)

  def testSparsifySparseXlaCall(self):
    # Test sparse lowering of XLA call
    def func(M):
      return 2 * M

    M = jnp.arange(6).reshape(2, 3)
    Msp = BCOO.fromdense(M)

    out_dense = func(M)
    out_sparse = self.sparsify(jit(func))(Msp)
    self.assertArraysEqual(out_dense, out_sparse.todense())

  def testSparseForiLoop(self):
    def func(M, x):
      body_fun = lambda i, val: (M @ val) / M.shape[1]
      return lax.fori_loop(0, 2, body_fun, x)

    x = jnp.arange(5.0)
    M = jnp.arange(25).reshape(5, 5)
    M_bcoo = BCOO.fromdense(M)

    # M is integer and x is float, so standard promotion is required.
    with jax.numpy_dtype_promotion('standard'):
      result_dense = func(M, x)
      result_sparse = self.sparsify(func)(M_bcoo, x)

    self.assertArraysAllClose(result_dense, result_sparse)

  def testSparseCondSimple(self):
    def func(x):
      return lax.cond(False, lambda x: x, lambda x: 2 * x, x)

    x = jnp.arange(5.0)
    result_dense = func(x)

    x_bcoo = BCOO.fromdense(x)
    result_sparse = self.sparsify(func)(x_bcoo)

    self.assertArraysAllClose(result_dense, result_sparse.todense())

  def testSparseCondMismatchError(self):
    @self.sparsify
    def func(x, y):
      return lax.cond(False, lambda x: x[0], lambda x: x[1], (x, y))

    x = jnp.arange(5.0)
    y = jnp.arange(5.0)

    x_bcoo = BCOO.fromdense(x)
    y_bcoo = BCOO.fromdense(y)

    func(x, y)  # No error
    func(x_bcoo, y_bcoo)  # No error

    with self.assertRaisesRegex(
        TypeError, "sparsified true_fun and false_fun output.*"):
      func(x_bcoo, y)

  def testToDense(self):
    M = jnp.arange(4)
    Msp = BCOO.fromdense(M)

    @self.sparsify
    def func(M):
      return todense(M) + 1

    self.assertArraysEqual(func(M), M + 1)
    self.assertArraysEqual(func(Msp), M + 1)
    self.assertArraysEqual(jit(func)(M), M + 1)
    self.assertArraysEqual(jit(func)(Msp), M + 1)

  def testWeakTypes(self):
    # Regression test for https://github.com/google/jax/issues/8267
    M = jnp.arange(12, dtype='int32').reshape(3, 4)
    Msp = BCOO.fromdense(M)
    self.assertArraysEqual(
        operator.mul(2, M),
        self.sparsify(operator.mul)(2, Msp).todense(),
        check_dtypes=True,
    )
class ImageTest(jtu.JaxTestCase):
  """Tests for `image.resize` and `image.scale_and_translate`."""

  # 6-element input used by the up-sampling tests; each test reshapes it to
  # its own `image_shape`.
  _UP_DATA = (64, 32, 32, 64, 50, 100)

  # 6x7 input shared by the scale-and-translate down-sampling tests. Each test
  # reshapes it to [1, 6, 7, 1]. Kept in one place so the three tests that use
  # it cannot drift apart.
  _DOWN_DATA = (
      51, 38, 32, 89, 41, 21, 97,
      51, 33, 87, 89, 34, 21, 97,
      43, 25, 25, 92, 41, 11, 84,
      11, 55, 111, 23, 99, 50, 83,
      13, 92, 52, 43, 90, 43, 14,
      89, 71, 32, 23, 23, 35, 93,
  )
  # Per-dimension scale / translation paired with `_DOWN_DATA`.
  _DOWN_SCALE = (1.0, 0.35, 0.4, 1.0)
  _DOWN_TRANSLATION = (0.0, 0.2, 0.1, 0.0)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype),
          method, antialias),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "method": method, "antialias": antialias}
      for dtype in float_dtypes
      for target_shape, image_shape in itertools.combinations_with_replacement(
          [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2)
      for method in ["nearest", "bilinear", "lanczos3", "lanczos5", "bicubic"]
      for antialias in [False, True]))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testResizeAgainstTensorFlow(self, dtype, image_shape, target_shape, method,
                                  antialias):
    """Compares `image.resize` with `tf.image.resize` on random images."""
    # TODO(phawkins): debug this. There is a small mismatch between TF and JAX
    # for some cases of non-antialiased bicubic downscaling; we would expect
    # exact equality.
    # NOTE(review): the condition does not look at `antialias`, so antialiased
    # bicubic downscaling is skipped as well.
    if method == "bicubic" and any(x < y for x, y in
                                   zip(target_shape, image_shape)):
      raise unittest.SkipTest("non-antialiased bicubic downscaling mismatch")
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)

    def tf_fn(x):
      # TF resizes only the spatial dims, hence target_shape[1:-1].
      out = tf.image.resize(
          x.astype(np.float64), tf.constant(target_shape[1:-1]),
          method=method, antialias=antialias).numpy().astype(dtype)
      return out

    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=antialias)
    self._CheckAgainstNumpy(tf_fn, jax_fn, args_maker, check_dtypes=True,
                            tol={np.float16: 2e-2, jnp.bfloat16: 1e-1,
                                 np.float32: 1e-4, np.float64: 1e-4})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "method": method}
      for dtype in [np.float32]
      for target_shape, image_shape in itertools.combinations_with_replacement(
          [[3, 2], [6, 4], [33, 17], [50, 39]], 2)
      for method in ["nearest", "bilinear", "lanczos3", "bicubic"]))
  @unittest.skipIf(not PIL_Image, "Test requires PIL")
  def testResizeAgainstPIL(self, dtype, image_shape, target_shape, method):
    """Compares antialiased `image.resize` with PIL's `Image.resize`."""
    rng = jtu.rand_uniform(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)

    def pil_fn(x):
      pil_methods = {
          "nearest": PIL_Image.NEAREST,
          "bilinear": PIL_Image.BILINEAR,
          "bicubic": PIL_Image.BICUBIC,
          "lanczos3": PIL_Image.LANCZOS,
      }
      img = PIL_Image.fromarray(x.astype(np.float32))
      # PIL takes (width, height), i.e. the reverse of the array shape.
      out = np.asarray(img.resize(target_shape[::-1], pil_methods[method]),
                       dtype=dtype)
      return out

    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=True)
    self._CheckAgainstNumpy(pil_fn, jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "method": method}
      for dtype in inexact_dtypes
      for image_shape, target_shape in [
          ([3, 1, 2], [6, 1, 4]),
          ([1, 3, 2, 1], [1, 6, 4, 1]),
      ]
      for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"]))
  def testResizeUp(self, dtype, image_shape, target_shape, method):
    """Checks 2x up-sampling against hard-coded expected values."""
    expected_data = {
        "nearest": [
            64.0, 64.0, 32.0, 32.0, 64.0, 64.0, 32.0, 32.0, 32.0, 32.0, 64.0,
            64.0, 32.0, 32.0, 64.0, 64.0, 50.0, 50.0, 100.0, 100.0, 50.0, 50.0,
            100.0, 100.0
        ],
        "linear": [
            64.0, 56.0, 40.0, 32.0, 56.0, 52.0, 44.0, 40.0, 40.0, 44.0, 52.0,
            56.0, 36.5, 45.625, 63.875, 73.0, 45.5, 56.875, 79.625, 91.0, 50.0,
            62.5, 87.5, 100.0
        ],
        "lanczos3": [
            75.8294, 59.6281, 38.4313, 22.23, 60.6851, 52.0037, 40.6454,
            31.964, 35.8344, 41.0779, 47.9383, 53.1818, 24.6968, 43.0769,
            67.1244, 85.5045, 35.7939, 56.4713, 83.5243, 104.2017, 44.8138,
            65.1949, 91.8603, 112.2413
        ],
        "lanczos5": [
            77.5699, 60.0223, 40.6694, 23.1219, 61.8253, 51.2369, 39.5593,
            28.9709, 35.7438, 40.8875, 46.5604, 51.7041, 21.5942, 43.5299,
            67.7223, 89.658, 32.1213, 56.784, 83.984, 108.6467, 44.5802,
            66.183, 90.0082, 111.6109
        ],
        "cubic": [
            70.1453, 59.0252, 36.9748, 25.8547, 59.3195, 53.3386, 41.4789,
            35.4981, 36.383, 41.285, 51.0051, 55.9071, 30.2232, 42.151,
            65.8032, 77.731, 41.6492, 55.823, 83.9288, 98.1026, 47.0363,
            62.2744, 92.4903, 107.7284
        ],
    }
    x = np.array(self._UP_DATA, dtype=dtype).reshape(image_shape)
    output = image.resize(x, target_shape, method)
    expected = np.array(expected_data[method], dtype=dtype).reshape(target_shape)
    self.assertAllClose(output, expected, atol=1e-04)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype),
          method, antialias),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "method": method, "antialias": antialias}
      for dtype in [np.float32]
      for target_shape, image_shape in itertools.combinations_with_replacement(
          [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2)
      for method in ["bilinear", "lanczos3", "lanczos5", "bicubic"]
      for antialias in [False, True]))
  def testResizeGradients(self, dtype, image_shape, target_shape, method,
                          antialias):
    """Numerically checks second-order gradients of `image.resize`."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)
    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=antialias)
    jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype),
          method, antialias),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "method": method, "antialias": antialias}
      for dtype in [np.float32]
      for image_shape, target_shape in [
          ([1], [0]),
          ([5, 5], [5, 0]),
          ([5, 5], [0, 1]),
          ([5, 5], [0, 0])
      ]
      for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"]
      for antialias in [False, True]))
  def testResizeEmpty(self, dtype, image_shape, target_shape, method, antialias):
    """Resizing to a shape with a zero-sized dim yields an empty array."""
    # Regression test for https://github.com/google/jax/issues/7586
    # The local is named `img` so it does not shadow the `image` module used
    # by the other tests in this class.
    img = np.ones(image_shape, dtype)
    out = jax.image.resize(img, shape=target_shape, method=method,
                           antialias=antialias)
    self.assertArraysEqual(out, jnp.zeros(target_shape, dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "scale": scale, "translation": translation, "method": method}
      for dtype in inexact_dtypes
      for image_shape, target_shape, scale, translation in [
          ([3, 1, 2], [6, 1, 4], [2.0, 1.0, 2.0], [1.0, 0.0, -1.0]),
          ([1, 3, 2, 1], [1, 6, 4, 1], [1.0, 2.0, 2.0, 1.0],
           [0.0, 1.0, -1.0, 0.0])]
      for method in ["linear", "lanczos3", "lanczos5", "cubic"]))
  def testScaleAndTranslateUp(self, dtype, image_shape, target_shape, scale,
                              translation, method):
    """Checks up-scaling with a translation against hard-coded values."""
    # Note zeros occur in the output because the sampling location is outside
    # the boundaries of the input image.
    expected_data = {
        "linear": [
            0.0, 0.0, 0.0, 0.0, 56.0, 40.0, 32.0, 0.0, 52.0, 44.0, 40.0, 0.0,
            44.0, 52.0, 56.0, 0.0, 45.625, 63.875, 73.0, 0.0, 56.875, 79.625,
            91.0, 0.0
        ],
        "lanczos3": [
            0.0, 0.0, 0.0, 0.0, 59.6281, 38.4313, 22.23, 0.0, 52.0037, 40.6454,
            31.964, 0.0, 41.0779, 47.9383, 53.1818, 0.0, 43.0769, 67.1244,
            85.5045, 0.0, 56.4713, 83.5243, 104.2017, 0.0
        ],
        "lanczos5": [
            0.0, 0.0, 0.0, 0.0, 60.0223, 40.6694, 23.1219, 0.0, 51.2369,
            39.5593, 28.9709, 0.0, 40.8875, 46.5604, 51.7041, 0.0, 43.5299,
            67.7223, 89.658, 0.0, 56.784, 83.984, 108.6467, 0.0
        ],
        "cubic": [
            0.0, 0.0, 0.0, 0.0, 59.0252, 36.9748, 25.8547, 0.0, 53.3386,
            41.4789, 35.4981, 0.0, 41.285, 51.0051, 55.9071, 0.0, 42.151,
            65.8032, 77.731, 0.0, 55.823, 83.9288, 98.1026, 0.0
        ],
    }
    x = np.array(self._UP_DATA, dtype=dtype).reshape(image_shape)
    # Should we test different float types here?
    scale_a = jnp.array(scale, dtype=jnp.float32)
    translation_a = jnp.array(translation, dtype=jnp.float32)
    output = image.scale_and_translate(x, target_shape, range(len(image_shape)),
                                       scale_a, translation_a,
                                       method)
    expected = np.array(
        expected_data[method], dtype=dtype).reshape(target_shape)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_method={}_antialias={}".format(
          jtu.dtype_str(dtype), method, antialias),
       "dtype": dtype, "method": method, "antialias": antialias}
      for dtype in inexact_dtypes
      for method in ["linear", "lanczos3", "lanczos5", "cubic"]
      for antialias in [True, False]))
  def testScaleAndTranslateDown(self, dtype, method, antialias):
    """Checks down-scaling of `_DOWN_DATA` against hard-coded values."""
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    if antialias:
      expected_data = {
          "linear": [
              43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
          ],
          "lanczos3": [
              43.2884, 57.9091, 54.6439, 48.5856, 58.2427, 53.7551, 0, 0, 0
          ],
          "lanczos5": [
              43.9209, 57.6360, 54.9575, 48.9272, 58.1865, 53.1948, 0, 0, 0
          ],
          "cubic": [
              42.9935, 59.1687, 54.2138, 48.2640, 58.2678, 54.4088, 0, 0, 0
          ],
      }
    else:
      expected_data = {
          "linear": [
              43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0
          ],
          "lanczos3": [
              44.1390, 87.8786, 63.3111, 25.1161, 20.8795, 53.6165, 0, 0, 0
          ],
          "lanczos5": [
              44.8835, 85.5896, 66.7231, 16.9983, 19.8891, 47.1446, 0, 0, 0
          ],
          "cubic": [
              43.6426, 88.8854, 60.6638, 31.4685, 22.1204, 58.3457, 0, 0, 0
          ],
      }
    x = np.array(self._DOWN_DATA, dtype=dtype).reshape(image_shape)

    expected = np.array(
        expected_data[method], dtype=dtype).reshape(target_shape)
    scale_a = jnp.array(self._DOWN_SCALE, dtype=jnp.float32)
    translation_a = jnp.array(self._DOWN_TRANSLATION, dtype=jnp.float32)
    output = image.scale_and_translate(
        x, target_shape, (0, 1, 2, 3),
        scale_a, translation_a, method, antialias=antialias)
    self.assertAllClose(output, expected, atol=2e-03)

    # Tests that running with just a subset of dimensions that have non-trivial
    # scale and translation.
    output = image.scale_and_translate(
        x, target_shape, (1, 2),
        scale_a[1:3], translation_a[1:3], method, antialias=antialias)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "antialias={}".format(antialias),
       "antialias": antialias}
      for antialias in [True, False]))
  def testScaleAndTranslateJITs(self, antialias):
    """`scale_and_translate` gives the same result under `jax.jit`."""
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    if antialias:
      expected_data = [
          43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
      ]
    else:
      expected_data = [43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0]
    x = jnp.array(self._DOWN_DATA, dtype=jnp.float32).reshape(image_shape)

    expected = jnp.array(expected_data, dtype=jnp.float32).reshape(target_shape)
    scale_a = jnp.array(self._DOWN_SCALE, dtype=jnp.float32)
    translation_a = jnp.array(self._DOWN_TRANSLATION, dtype=jnp.float32)

    def jit_fn(in_array, s, t):
      return jax.image.scale_and_translate(
          in_array, target_shape, (0, 1, 2, 3), s, t,
          "linear", antialias, precision=jax.lax.Precision.HIGHEST)

    output = jax.jit(jit_fn)(x, scale_a, translation_a)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "antialias={}".format(antialias),
       "antialias": antialias}
      for antialias in [True, False]))
  def testScaleAndTranslateGradFinite(self, antialias):
    """Gradients w.r.t. scale and translation are finite."""
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    x = jnp.array(self._DOWN_DATA, dtype=jnp.float32).reshape(image_shape)
    scale_a = jnp.array(self._DOWN_SCALE, dtype=jnp.float32)
    translation_a = jnp.array(self._DOWN_TRANSLATION, dtype=jnp.float32)

    def scale_fn(s):
      return jnp.sum(jax.image.scale_and_translate(
          x, target_shape, (0, 1, 2, 3), s, translation_a, "linear", antialias,
          precision=jax.lax.Precision.HIGHEST))

    scale_out = jax.grad(scale_fn)(scale_a)
    self.assertTrue(jnp.all(jnp.isfinite(scale_out)))

    def translate_fn(t):
      return jnp.sum(jax.image.scale_and_translate(
          x, target_shape, (0, 1, 2, 3), scale_a, t, "linear", antialias,
          precision=jax.lax.Precision.HIGHEST))

    translate_out = jax.grad(translate_fn)(translation_a)
    self.assertTrue(jnp.all(jnp.isfinite(translate_out)))

  def testScaleAndTranslateNegativeDims(self):
    """Negative spatial dims behave like their non-negative equivalents."""
    data = jnp.full((3, 3), 0.5)
    actual = jax.image.scale_and_translate(
        data, (5, 5), (-2, -1), jnp.ones(2), jnp.zeros(2), "linear")
    expected = jax.image.scale_and_translate(
        data, (5, 5), (0, 1), jnp.ones(2), jnp.zeros(2), "linear")
    self.assertAllClose(actual, expected)

  def testResizeWithUnusualShapes(self):
    """Array-valued shapes are accepted; fractional shapes raise TypeError."""
    x = jnp.ones((3, 4))
    # Array shapes are accepted
    self.assertEqual((10, 17),
                     jax.image.resize(x, jnp.array((10, 17)), "nearest").shape)
    with self.assertRaises(TypeError):
      # Fractional shapes are disallowed
      jax.image.resize(x, [10.5, 17], "bicubic")
class AnnTest(jtu.JaxTestCase):
  """Recall and autodiff tests for `lax.approx_max_k` / `lax.approx_min_k`."""

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_qy={}_db={}_k={}_recall={}".format(
                  jtu.format_shape_dtype_string(qy_shape, dtype),
                  jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
          "qy_shape": qy_shape, "db_shape": db_shape, "dtype": dtype,
          "k": k, "recall": recall}
                          for qy_shape in [(200, 128), (128, 128)]
                          for db_shape in [(128, 500), (128, 3000)]
                          for dtype in jtu.dtypes.all_floating
                          for k in [1, 10, 50]
                          for recall in [0.9, 0.95]))
  def test_approx_max_k(self, qy_shape, db_shape, dtype, k, recall):
    """approx_max_k reaches at least `recall` against exact top_k."""
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    scores = lax.dot(qy, db)
    _, gt_args = lax.top_k(scores, k)
    _, ann_args = lax.approx_max_k(scores, k, recall_target=recall)
    self.assertEqual(k, len(ann_args[0]))
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_qy={}_db={}_k={}_recall={}".format(
                  jtu.format_shape_dtype_string(qy_shape, dtype),
                  jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
          "qy_shape": qy_shape, "db_shape": db_shape, "dtype": dtype,
          "k": k, "recall": recall}
                          for qy_shape in [(200, 128), (128, 128)]
                          for db_shape in [(128, 500), (128, 3000)]
                          for dtype in jtu.dtypes.all_floating
                          for k in [1, 10, 50]
                          for recall in [0.9, 0.95]))
  def test_approx_min_k(self, qy_shape, db_shape, dtype, k, recall):
    """approx_min_k reaches at least `recall` against exact top_k of -scores."""
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    scores = lax.dot(qy, db)
    _, gt_args = lax.top_k(-scores, k)
    _, ann_args = lax.approx_min_k(scores, k, recall_target=recall)
    self.assertEqual(k, len(ann_args[0]))
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}_k={}_max_k={}".format(
                  jtu.format_shape_dtype_string(shape, dtype), k, is_max_k),
          "shape": shape, "dtype": dtype, "k": k, "is_max_k": is_max_k}
                          for dtype in [np.float32]
                          for shape in [(4,), (5, 5), (2, 1, 4)]
                          for k in [1, 3]
                          for is_max_k in [True, False]))
  def test_autodiff(self, shape, dtype, k, is_max_k):
    """Forward- and reverse-mode gradients of the returned values check out."""
    # A permutation of distinct values keeps the top-k selection unambiguous.
    vals = np.arange(prod(shape), dtype=dtype)
    vals = self.rng().permutation(vals).reshape(shape)
    if is_max_k:
      fn = lambda vs: lax.approx_max_k(vs, k=k)[0]
    else:
      fn = lambda vs: lax.approx_min_k(vs, k=k)[0]
    jtu.check_grads(fn, (vals,), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_qy={}_db={}_k={}_recall={}".format(
                  jtu.format_shape_dtype_string(qy_shape, dtype),
                  jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
          "qy_shape": qy_shape, "db_shape": db_shape, "dtype": dtype,
          "k": k, "recall": recall}
                          for qy_shape in [(200, 128), (128, 128)]
                          for db_shape in [(2048, 128)]
                          for dtype in jtu.dtypes.all_floating
                          for k in [1, 10]
                          for recall in [0.9, 0.95]))
  def test_pmap(self, qy_shape, db_shape, dtype, k, recall):
    """Sharded approx_min_k + host-side merge reaches the recall target."""
    num_devices = jax.device_count()
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    db_size, feature_dim = db.shape
    # The database is split evenly across devices below; an uneven split would
    # otherwise surface as an opaque reshape error.
    if db_size % num_devices:
      raise unittest.SkipTest(
          f"db size {db_size} is not divisible by {num_devices} devices")
    gt_scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
    _, gt_args = lax.top_k(-gt_scores, k)  # negate the score to get min-k
    db_per_device = db_size // num_devices
    # Use the actual feature width rather than a hard-coded constant so the
    # test stays correct if `db_shape` changes.
    sharded_db = db.reshape(num_devices, db_per_device, feature_dim)
    db_offsets = np.arange(num_devices, dtype=np.int32) * db_per_device

    def parallel_topk(qy, db, db_offset):
      scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
      ann_vals, ann_args = lax.approx_min_k(
          scores,
          k=k,
          reduction_dimension=1,
          recall_target=recall,
          reduction_input_size_override=db_size,
          aggregate_to_topk=False)
      # Shift per-shard indices into global database indices.
      return (ann_vals, ann_args + db_offset)

    # shape = qy_size, num_devices, approx_dp
    ann_vals, ann_args = jax.pmap(
        parallel_topk,
        in_axes=(None, 0, 0),
        out_axes=(1, 1))(qy, sharded_db, db_offsets)
    # collapse num_devices and approx_dp
    ann_vals = lax.collapse(ann_vals, 1, 3)
    ann_args = lax.collapse(ann_args, 1, 3)
    _, ann_args = lax.sort_key_val(ann_vals, ann_args, dimension=1)
    ann_args = lax.slice_in_dim(ann_args, start_index=0, limit_index=k, axis=1)
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)
class NdimageTest(jtu.JaxTestCase):
  """Checks `lsp_ndimage.map_coordinates` against SciPy reference results."""

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name":
                  "_{}_coordinates={}_order={}_mode={}_cval={}_impl={}_round={}".
                  format(
                      jtu.format_shape_dtype_string(shape, dtype),
                      jtu.format_shape_dtype_string(coords_shape, coords_dtype),
                      order, mode, cval, impl, round_),
              "rng_factory": rng_factory,
              "shape": shape,
              "coords_shape": coords_shape,
              "dtype": dtype,
              "coords_dtype": coords_dtype,
              "order": order,
              "mode": mode,
              "cval": cval,
              "impl": impl,
              "round_": round_
          }
          for shape in [(5,), (3, 4), (3, 4, 5)]
          for coords_shape in [(7,), (2, 3, 4)]
          for dtype in float_dtypes + int_dtypes
          for coords_dtype in float_dtypes
          for order in [0, 1]
          for mode in ['wrap', 'constant', 'nearest', 'mirror', 'reflect']
          for cval in ([0, -1] if mode == 'constant' else [0])
          for impl, rng_factory in [
              ("original", partial(jtu.rand_uniform, low=0, high=1)),
              ("fixed", partial(jtu.rand_uniform, low=-0.75, high=1.75)),
          ]
          for round_ in [True, False]))
  def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype, order,
                         mode, cval, impl, round_, rng_factory):
    """Random coordinates, every boundary mode, against a SciPy reference."""
    rng = rng_factory(self.rng())

    def args_maker():
      grid = np.arange(prod(shape), dtype=dtype).reshape(shape)
      # One coordinate array per input axis, scaled to that axis' extent.
      points = [(extent - 1) * rng(coords_shape, coords_dtype)
                for extent in shape]
      if round_:
        points = [p.round().astype(int) for p in points]
      return grid, points

    if impl == "original":
      reference = osp_ndimage.map_coordinates
    else:
      reference = _fixed_ref_map_coordinates
    options = dict(order=order, mode=mode, cval=cval)

    def lsp_op(x, c):
      return lsp_ndimage.map_coordinates(x, c, **options)

    def osp_op(x, c):
      return reference(x, c, **options)

    # Integer inputs must match exactly; float inputs get a tolerance scaled
    # by the coarser of the two dtypes' machine epsilon.
    tol = 0
    if dtype in float_dtypes:
      tol = 100 * max(
          dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
          for d in [dtype, coords_dtype])
    self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=tol)

  def testMapCoordinatesErrors(self):
    """Unsupported order/mode and mismatched coordinate count are rejected."""
    values = np.arange(5.0)
    coordinates = [np.linspace(0, 5, num=3)]
    with self.assertRaisesRegex(NotImplementedError, 'requires order<=1'):
      lsp_ndimage.map_coordinates(values, coordinates, order=2)
    with self.assertRaisesRegex(NotImplementedError,
                                'does not yet support mode'):
      lsp_ndimage.map_coordinates(values, coordinates, order=1,
                                  mode='grid-wrap')
    with self.assertRaisesRegex(ValueError, 'sequence of length'):
      lsp_ndimage.map_coordinates(values, [coordinates, coordinates], order=1)

  def testMapCoordinateDocstring(self):
    """The public docstring mentions the supported interpolation orders."""
    doc = lsp_ndimage.map_coordinates.__doc__
    self.assertIn("Only nearest neighbor", doc)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_{np.dtype(dtype)}_order={order}",
          "dtype": dtype,
          "order": order
      } for dtype in float_dtypes + int_dtypes for order in [0, 1]))
  def testMapCoordinatesRoundHalf(self, dtype, order):
    """Coordinates exactly halfway between samples round like SciPy."""
    values = np.arange(-3, 3, dtype=dtype)
    halfway = np.array([[.5, 1.5, 2.5, 3.5]])
    args_maker = lambda: (values, halfway)

    lsp_op = partial(lsp_ndimage.map_coordinates, order=order)
    osp_op = partial(osp_ndimage.map_coordinates, order=order)
    self._CheckAgainstNumpy(osp_op, lsp_op, args_maker)

  def testContinuousGradients(self):
    """Gradient of linear interpolation w.r.t. a shift is analytic."""
    # regression test for https://github.com/google/jax/issues/3024

    def loss(delta):
      line = np.arange(100.0)
      margin = 10
      positions = np.arange(line.size) + delta
      # linear interpolation of the linear function y=x should be exact
      shifted = lsp_ndimage.map_coordinates(line, [positions], order=1)
      return ((line - shifted)**2)[margin:-margin].mean()

    # analytical gradient of (x - (x - delta)) ** 2 is 2 * delta
    for delta, want in [(0.5, 1.0), (1.0, 2.0)]:
      self.assertAllClose(grad(loss)(delta), want, check_dtypes=False)
class LaxBackedScipyStatsTests(jtu.JaxTestCase): """Tests for LAX-backed scipy.stats implementations""" @genNamedParametersNArgs(3) def testPoissonLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.poisson.logpmf lax_fun = lsp_stats.poisson.logpmf def args_maker(): k, mu, loc = map(rng, shapes, dtypes) k = np.floor(k) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None) loc = np.floor(loc) return [k, mu, loc] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}) @genNamedParametersNArgs(3) def testPoissonPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.poisson.pmf lax_fun = lsp_stats.poisson.pmf def args_maker(): k, mu, loc = map(rng, shapes, dtypes) k = np.floor(k) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None) loc = np.floor(loc) return [k, mu, loc] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testPoissonCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.poisson.cdf lax_fun = lsp_stats.poisson.cdf def args_maker(): k, mu, loc = map(rng, shapes, dtypes) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None) return [k, mu, loc] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testBernoulliLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.bernoulli.logpmf lax_fun = lsp_stats.bernoulli.logpmf def args_maker(): x, logit, loc = map(rng, shapes, dtypes) x = np.floor(x) p = expit(logit) loc = np.floor(loc) return [x, p, loc] 
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testGeomLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.geom.logpmf lax_fun = lsp_stats.geom.logpmf def args_maker(): x, logit, loc = map(rng, shapes, dtypes) x = np.floor(x) p = expit(logit) loc = np.floor(loc) return [x, p, loc] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(5) def testBetaLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.beta.logpdf lax_fun = lsp_stats.beta.logpdf def args_maker(): x, a, b, loc, scale = map(rng, shapes, dtypes) return [x, a, b, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, rtol={ np.float32: 2e-3, np.float64: 1e-4 }) def testBetaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7645 a = b = 1. 
x = np.array([0., 1.]) self.assertAllClose(osp_stats.beta.pdf(x, a, b), lsp_stats.beta.pdf(x, a, b), atol=1E-6) @genNamedParametersNArgs(3) def testCauchyLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.cauchy.logpdf lax_fun = lsp_stats.cauchy.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": jtu.format_test_name_suffix("", [x_shape, alpha_shape], dtypes), "shapes": [x_shape, alpha_shape], "dtypes": dtypes } for x_shape in one_and_two_dim_shapes for alpha_shape in [( x_shape[0], ), ( x_shape[0] + 1, )] for dtypes in itertools.combinations_with_replacement( jtu.dtypes.floating, 2))) def testDirichletLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) def _normalize(x, alpha): x_norm = x.sum(0) + (0.0 if x.shape[0] == alpha.shape[0] else 0.1) return (x / x_norm).astype(x.dtype), alpha def lax_fun(x, alpha): return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha)) def scipy_fun(x, alpha): # scipy validates the x normalization using float64 arithmetic, so we must # cast x to float64 before normalization to ensure this passes. x, alpha = _normalize(x.astype('float64'), alpha) result = osp_stats.dirichlet.logpdf(x, alpha) # if x.shape is (N, 1), scipy flattens the output, while JAX returns arrays # of a consistent rank. This check ensures the results have the same shape. return result if x.ndim == 1 else np.atleast_1d(result) def args_maker(): # Don't normalize here, because we want normalization to happen at 64-bit # precision in the scipy version. 
x, alpha = map(rng, shapes, dtypes) return x, alpha tol = {np.float32: 1E-3, np.float64: 1e-5} self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol) @genNamedParametersNArgs(3) def testExponLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.expon.logpdf lax_fun = lsp_stats.expon.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testGammaLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.gamma.logpdf lax_fun = lsp_stats.gamma.logpdf def args_maker(): x, a, loc, scale = map(rng, shapes, dtypes) return [x, a, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) def testGammaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7256 self.assertAllClose(osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) @genNamedParametersNArgs(4) def testNBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.nbinom.logpmf lax_fun = lsp_stats.nbinom.logpmf def args_maker(): k, n, logit, loc = map(rng, shapes, dtypes) k = np.floor(np.abs(k)) n = np.ceil(np.abs(n)) p = expit(logit) loc = np.floor(loc) return [k, n, p, loc] tol = {np.float32: 1e-6, np.float64: 1e-8} self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) @genNamedParametersNArgs(3) def testLaplaceLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.laplace.logpdf lax_fun = lsp_stats.laplace.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping 
to ensure that scale is not too low scale = np.clip(scale, a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testLaplaceCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.laplace.cdf lax_fun = lsp_stats.laplace.cdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # ensure that scale is not too low scale = np.clip(scale, a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol={ np.float32: 1e-5, np.float64: 1e-6 }) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.cdf lax_fun = lsp_stats.logistic.cdf def args_maker(): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticLogpdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.logpdf lax_fun = lsp_stats.logistic.logpdf def args_maker(): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.ppf lax_fun = lsp_stats.logistic.ppf def args_maker(): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticSf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.sf lax_fun = lsp_stats.logistic.sf def 
args_maker(): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.logpdf lax_fun = lsp_stats.norm.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormLogCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.logcdf lax_fun = lsp_stats.norm.logcdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.cdf lax_fun = lsp_stats.norm.cdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.ppf lax_fun = lsp_stats.norm.ppf def args_maker(): q, loc, scale = map(rng, shapes, dtypes) # ensure probability is between 0 and 1: q = np.clip(np.abs(q / 3), a_min=None, a_max=1) # clipping to ensure that scale is not too low scale = 
np.clip(np.abs(scale), a_min=0.1, a_max=None) return [q, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) @genNamedParametersNArgs(4) def testParetoLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.pareto.logpdf lax_fun = lsp_stats.pareto.logpdf def args_maker(): x, b, loc, scale = map(rng, shapes, dtypes) return [x, b, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testTLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.t.logpdf lax_fun = lsp_stats.t.logpdf def args_maker(): x, df, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) return [x, df, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}, atol={np.float64: 1e-14}) @genNamedParametersNArgs(3) def testUniformLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.uniform.logpdf lax_fun = lsp_stats.uniform.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) return [x, loc, np.abs(scale)] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testChi2LogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.chi2.logpdf lax_fun = lsp_stats.chi2.logpdf def args_maker(): x, df, loc, scale = map(rng, shapes, dtypes) return [x, df, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(5) def testBetaBinomLogPmf(self, shapes, dtypes): rng = 
jtu.rand_positive(self.rng()) lax_fun = lsp_stats.betabinom.logpmf def args_maker(): k, n, a, b, loc = map(rng, shapes, dtypes) k = np.floor(k) n = np.ceil(n) a = np.clip(a, a_min=0.1, a_max=None) b = np.clip(a, a_min=0.1, a_max=None) loc = np.floor(loc) return [k, n, a, b, loc] if scipy_version >= (1, 4): scipy_fun = osp_stats.betabinom.logpmf self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5) def testIssue972(self): self.assertAllClose(np.ones((4, ), np.float32), lsp_stats.norm.cdf( np.full((4, ), np.inf, np.float32)), check_dtypes=False) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_x={}_mean={}_cov={}".format( jtu.format_shape_dtype_string(x_shape, x_dtype), jtu.format_shape_dtype_string(mean_shape, mean_dtype) if mean_shape is not None else None, jtu.format_shape_dtype_string(cov_shape, cov_dtype) if cov_shape is not None else None), "x_shape": x_shape, "x_dtype": x_dtype, "mean_shape": mean_shape, "mean_dtype": mean_dtype, "cov_shape": cov_shape, "cov_dtype": cov_dtype } for x_shape, mean_shape, cov_shape in [ # # These test cases cover default values for mean/cov, but we don't # # support those yet (and they seem not very valuable). 
# [(), None, None], # [(), (), None], # [(2,), None, None], # [(2,), (), None], # [(2,), (2,), None], # [(3, 2), (3, 2,), None], # [(5, 3, 2), (5, 3, 2,), None], [(), (), ()], [(3, ), (), ()], [(3, ), (3, ), ()], [(3, ), (3, ), (3, 3)], [(3, 4), (4, ), (4, 4)], # # These test cases are where scipy flattens things, which has # # different batch semantics than some might expect # [(5, 3, 2), (5, 3, 2,), ()], # [(5, 3, 2), (5, 3, 2,), (5, 3, 2, 2)], # [(5, 3, 2), (3, 2,), (5, 3, 2, 2)], # [(5, 3, 2), (3, 2,), (2, 2)], ] for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement( jtu.dtypes.floating, 3) if (mean_shape is not None or mean_dtype == np.float32) and (cov_shape is not None or cov_dtype == np.float32)) ) def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape, mean_dtype, cov_shape, cov_dtype): rng = jtu.rand_default(self.rng()) def args_maker(): args = [rng(x_shape, x_dtype)] if mean_shape is not None: args.append(5 * rng(mean_shape, mean_dtype)) if cov_shape is not None: if cov_shape == (): args.append(0.1 + rng(cov_shape, cov_dtype)**2) else: factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1]) factor = rng(factor_shape, cov_dtype) args.append(np.matmul(factor, np.swapaxes(factor, -1, -2))) return args self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf, lsp_stats.multivariate_normal.logpdf, args_maker, tol=1e-3) self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker, rtol=1e-4, atol=1e-4) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_ndim={}_nbatch={}_dtype={}".format(ndim, nbatch, dtype.__name__), "ndim": ndim, "nbatch": nbatch, "dtype": dtype } for ndim in [2, 3] for nbatch in [1, 3, 5] for dtype in jtu.dtypes.floating)) def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype): # Regression test for #5570 rng = jtu.rand_default(self.rng()) x = rng((nbatch, ndim), dtype) mean = 5 * rng((nbatch, ndim), dtype) factor = rng((nbatch, ndim, 2 * ndim), dtype) cov = 
factor @ factor.transpose(0, 2, 1) result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov) result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov) self.assertArraysEqual(result1, result2)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation.

  Most tests compare a ``jax.scipy`` function against the corresponding
  ``scipy`` reference and additionally check that the JAX version compiles
  under ``jit`` with matching results.
  """

  def _GetArgsMaker(self, rng, shapes, dtypes):
    """Returns a thunk that draws one random array per (shape, dtype) pair."""
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
          jtu.format_shape_dtype_string(shapes, dtype),
          axis, keepdims, return_sign, use_b),
       # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
       "shapes": shapes, "dtype": dtype,
       "axis": axis, "keepdims": keepdims,
       "return_sign": return_sign, "use_b": use_b}
      for shape_group in compatible_shapes
      for dtype in float_dtypes + complex_dtypes + int_dtypes
      for use_b in [False, True]
      for shapes in itertools.product(*(
          (shape_group, shape_group) if use_b else (shape_group,)))
      for axis in range(-max(len(shape) for shape in shapes),
                        max(len(shape) for shape in shapes))
      for keepdims in [False, True]
      for return_sign in [False, True]))
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="invalid value encountered in .*")
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testLogSumExp(self, shapes, dtype, axis,
                    keepdims, return_sign, use_b):
    """Compares logsumexp with scipy over axes, keepdims, return_sign and b."""
    # Off CPU the inputs include infs and NaNs; on CPU only finite values
    # are drawn (see the TODO in the parameter list above).
    if jtu.device_under_test() != "cpu":
      rng = jtu.rand_some_inf_and_nan(self.rng())
    else:
      rng = jtu.rand_default(self.rng())
    # TODO(mattjj): test autodiff
    if use_b:
      def scipy_fun(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      def lax_fun(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      args_maker = lambda: [rng(shapes[0], dtype),
                            rng(shapes[1], dtype)]
    else:
      def scipy_fun(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      def lax_fun(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      args_maker = lambda: [rng(shapes[0], dtype)]
    tol = {np.float32: 1E-6, np.float64: 1E-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

  def testLogSumExpZeros(self):
    """logsumexp with some zero entries in ``b``."""
    # Regression test for https://github.com/google/jax/issues/5370
    scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
    lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
    args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker)

  def testLogSumExpOnes(self):
    """logsumexp of an all-ones input with ``jax.debug_infs`` enabled."""
    # Regression test for https://github.com/google/jax/issues/7390
    args_maker = lambda: [np.ones(4, dtype='float32')]
    with jax.debug_infs(True):
      self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker)
      self._CompileAndCheck(lsp_special.logsumexp, args_maker)

  def testLogSumExpNans(self):
    """Scalar logsumexp with ``jax.debug_nans`` enabled and jit disabled."""
    # Regression test for https://github.com/google/jax/issues/7634
    with jax.debug_nans(True):
      with jax.disable_jit():
        result = lsp_special.logsumexp(1.0)
        self.assertEqual(result, 1.0)

        result = lsp_special.logsumexp(1.0, b=1.0)
        self.assertEqual(result, 1.0)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix(
              rec.test_name, shapes, dtypes),
           "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
           "test_autodiff": rec.test_autodiff,
           "nondiff_argnums": rec.nondiff_argnums,
           "scipy_op": getattr(osp_special, rec.name),
           "lax_op": getattr(lsp_special, rec.name)}
          for shapes in itertools.combinations_with_replacement(all_shapes, rec.nargs)
          for dtypes in (itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
                         if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff, nondiff_argnums):
    """Compares each recorded special function with scipy, plus optional grads."""
    if (jtu.device_under_test() == "cpu" and
        (lax_op is lsp_special.gammainc or lax_op is lsp_special.gammaincc)):
      # TODO(b/173608403): re-enable test when LLVM bug is fixed.
      raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

    if test_autodiff:
      def partial_lax_op(*vals):
        # Re-insert the fixed (non-differentiated) arguments at their
        # original positions; this relies on nondiff_argnums being sorted.
        list_args = list(vals)
        for i in nondiff_argnums:
          list_args.insert(i, args[i])
        return lax_op(*list_args)

      # Use a TestCase assertion rather than a bare assert so the check
      # survives `python -O`.
      self.assertEqual(list(nondiff_argnums), sorted(set(nondiff_argnums)))
      diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
      jtu.check_grads(partial_lax_op, diff_args, order=1,
                      atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                      rtol=.1, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_d={}".format(
          jtu.format_shape_dtype_string(shape, dtype), d),
       "shape": shape, "dtype": dtype, "d": d}
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  @jax.numpy_rank_promotion('raise')
  def testMultigammaln(self, shape, dtype, d):
    """Compares multigammaln with scipy for dimensions d in {1, 2, 5}."""
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    rng = jtu.rand_positive(self.rng())
    # Shift inputs by (d - 1) / 2 so every sample satisfies a > (d - 1) / 2.
    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
        lax_fun, args_maker, rtol={
            np.float32: 3e-07,
            np.float64: 4e-15
        })

  def testIssue980(self):
    """expit of a very negative input returns exactly zero."""
    x = np.full((4,), -1e20, dtype=np.float32)
    self.assertAllClose(np.zeros((4,), dtype=np.float32),
                        lsp_special.expit(x))

  @jax.numpy_rank_promotion('raise')
  def testIssue3758(self):
    """zeta(x, q) for large x."""
    x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
    q = np.array([1., 40., 30.], dtype=np.float32)
    self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32), lsp_special.zeta(x, q))

  def testXlogyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
    self.assertAllClose(jax.grad(partial_xlogy)(0.), 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
    self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_lmax={}".format(
          jtu.format_shape_dtype_string(shape, dtype), l_max),
       "l_max": l_max, "shape": shape, "dtype": dtype}
      for l_max in [1, 2, 3]
      for shape in [(5,), (10,)]
      for dtype in float_dtypes))
  def testLpmn(self, l_max, shape, dtype):
    """Compares lpmn values and derivatives with scipy."""
    rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
    args_maker = lambda: [rng(shape, dtype)]

    lax_fun = partial(lsp_special.lpmn, l_max, l_max)

    def scipy_fun(z, m=l_max, n=l_max):
      # scipy only supports scalar inputs for z, so we must loop here.
      vals, derivs = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
      return np.dstack(vals), np.dstack(derivs)

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-6, atol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_lmax={}".format(
          jtu.format_shape_dtype_string(shape, dtype), l_max),
       "l_max": l_max, "shape": shape, "dtype": dtype}
      for l_max in [3, 4, 6, 32]
      for shape in [(2,), (3,), (4,), (64,)]
      for dtype in float_dtypes))
  def testNormalizedLpmnValues(self, l_max, shape, dtype):
    """Compares normalized lpmn_values with normalized scipy lpmn values."""
    rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
    args_maker = lambda: [rng(shape, dtype)]

    # Note: we test only the normalized values, not the derivatives.
    lax_fun = partial(lsp_special.lpmn_values, l_max, l_max, is_normalized=True)

    def scipy_fun(z, m=l_max, n=l_max):
      # scipy only supports scalar inputs for z, so we must loop here.
      vals, _ = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
      a = np.dstack(vals)

      # apply the normalization
      num_m, num_l, _ = a.shape
      a_normalized = np.zeros_like(a)
      for m in range(num_m):
        for l in range(num_l):
          c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m)
          c1 = (4.0 * np.pi) * osp_special.factorial(l + m)
          c2 = np.sqrt(c0 / c1)
          a_normalized[m, l] = c2 * a[m, l]
      return a_normalized

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            rtol=1e-5, atol=1e-5)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)

  def testSphHarmAccuracy(self):
    """Compares sph_harm with scipy on a broadcast grid of orders/degrees."""
    m = jnp.arange(-3, 3)[:, None]
    n = jnp.arange(3, 6)
    n_max = 5
    theta = 0.0
    phi = jnp.pi

    lax_result = lsp_special.sph_harm(m, n, theta, phi, n_max)

    scipy_result = osp_special.sph_harm(m, n, theta, phi)

    self.assertAllClose(scipy_result, lax_result, rtol=1e-8, atol=9e-5)

  def testSphHarmOrderZeroDegreeZero(self):
    """Tests the spherical harmonics of order zero and degree zero."""
    theta = jnp.array([0.3])
    phi = jnp.array([2.3])
    n_max = 0

    expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)])
    actual = jnp.real(
        lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi, n_max))

    self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)

  def testSphHarmOrderZeroDegreeOne(self):
    """Tests the spherical harmonics of order zero and degree one."""
    theta = jnp.array([2.0])
    phi = jnp.array([3.1])
    n_max = 1

    expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi)
    actual = jnp.real(
        lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi, n_max))

    self.assertAllClose(actual, expected, rtol=2e-7, atol=6e-8)

  def testSphHarmOrderOneDegreeOne(self):
    """Tests the spherical harmonics of order one and degree one."""
    theta = jnp.array([2.0])
    phi = jnp.array([2.5])
    n_max = 1

    expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) *
                jnp.sin(phi) * jnp.exp(1j * theta))
    actual = lsp_special.sph_harm(
        jnp.array([1]), jnp.array([1]), theta, phi, n_max)

    self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)

  @parameterized.named_parameters(jtu.cases_from_list(
      {'testcase_name': '_maxdegree={}_inputsize={}_dtype={}'.format(
          l_max, num_z, dtype),
       'l_max': l_max, 'num_z': num_z, 'dtype': dtype}
      for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
      for dtype in jtu.dtypes.all_integer))
  def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
    """Tests against JIT compatibility and Numpy."""
    n_max = l_max
    shape = (num_z,)
    rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)

    lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max)

    def args_maker():
      m = rng(shape, dtype)
      n = abs(m)  # degree n must satisfy n >= |m|
      theta = jnp.linspace(-4.0, 5.0, num_z)
      phi = jnp.linspace(-2.0, 1.0, num_z)
      return m, n, theta, phi

    with self.subTest('Test JIT compatibility'):
      self._CompileAndCheck(lsp_special_fn, args_maker)

    with self.subTest('Test against numpy.'):
      self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)

  def testSphHarmCornerCaseWithWrongNmax(self):
    """Tests the corner case where `n_max` is not the maximum value of `n`."""
    m = jnp.array([2])
    n = jnp.array([10])
    n_clipped = jnp.array([6])
    n_max = 6
    theta = jnp.array([0.9])
    phi = jnp.array([0.2])

    expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

    actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max)

    self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

  @parameterized.named_parameters(jtu.cases_from_list(
      {'testcase_name':
       '_shape={}'
       '_n_zero_sv={}_degeneracy={}_geometric_spectrum={}'
       '_max_sv={}_method={}_side={}'
       '_nonzero_condition_number={}_seed={}'.format(
           jtu.format_shape_dtype_string(
               shape, jnp.dtype(dtype).name).replace(" ", ""),
           n_zero_sv, degeneracy, geometric_spectrum, max_sv,
           method, side, nonzero_condition_number, seed
       ),
       'n_zero_sv': n_zero_sv, 'degeneracy': degeneracy,
       'geometric_spectrum': geometric_spectrum,
       'max_sv': max_sv, 'shape': shape, 'method': method,
       'side': side, 'nonzero_condition_number': nonzero_condition_number,
       'dtype': dtype, 'seed': seed}
      for n_zero_sv in n_zero_svs
      for degeneracy in degeneracies
      for geometric_spectrum in geometric_spectra
      for max_sv in max_svs
      for shape in polar_shapes
      for method in methods
      for side in sides
      for nonzero_condition_number in nonzero_condition_numbers
      for dtype in jtu.dtypes.floating
      for seed in seeds))
  @jtu.skip_on_devices("gpu")  # Fails on A100.
  def testPolar(
      self, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape, method,
      side, nonzero_condition_number, dtype, seed):
    """Tests jax.scipy.linalg.polar.

    Checks unitarity of the unitary factor, Hermiticity and non-negative
    spectrum of the positive factor, and reconstruction of the input.
    """
    if jtu.device_under_test() != "cpu":
      if jnp.dtype(dtype).name in ("bfloat16", "float16"):
        raise unittest.SkipTest("Skip half precision off CPU.")
      if method == "svd":
        raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.")

    matrix, _ = _initialize_polar_test(
        self.rng(), shape, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
        nonzero_condition_number, dtype)
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      self.assertRaises(
          NotImplementedError, jsp.linalg.polar, matrix, method=method,
          side=side)
      return

    unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
    # For a non-square input only one of U^H U / U U^H is the identity.
    if shape[0] >= shape[1]:
      should_be_eye = np.matmul(unitary.conj().T, unitary)
    else:
      should_be_eye = np.matmul(unitary, unitary.conj().T)
    tol = 10 * jnp.finfo(matrix.dtype).eps
    eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
    with self.subTest('Test unitarity.'):
      self.assertAllClose(
          eye_mat, should_be_eye, atol=tol * min(shape))

    with self.subTest('Test Hermiticity.'):
      self.assertAllClose(
          posdef, posdef.conj().T, atol=tol * jnp.linalg.norm(posdef))

    # Ignore eigenvalues that are numerically zero before counting negatives.
    ev, _ = np.linalg.eigh(posdef)
    ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
    negative_ev = jnp.sum(ev < 0.)
    with self.subTest('Test positive definiteness.'):
      # TestCase assertion instead of a bare assert, which `python -O` strips.
      self.assertEqual(int(negative_ev), 0)

    if side == "right":
      recon = jnp.matmul(unitary, posdef, precision=lax.Precision.HIGHEST)
    elif side == "left":
      recon = jnp.matmul(posdef, unitary, precision=lax.Precision.HIGHEST)
    else:
      # Guard against an unbound `recon` if a new side value is added.
      raise ValueError(f"Unexpected side: {side!r}")
    with self.subTest('Test reconstruction.'):
      self.assertAllClose(
          matrix, recon, atol=tol * jnp.linalg.norm(matrix))

  @parameterized.named_parameters(jtu.cases_from_list(
      {'testcase_name':
       '_linear_size_={}_seed={}_dtype={}'.format(
           linear_size, seed, jnp.dtype(dtype).name
       ),
       'linear_size': linear_size, 'seed': seed, 'dtype': dtype}
      for linear_size in linear_sizes
      for seed in seeds
      for dtype in jtu.dtypes.floating))
  def test_spectral_dac_eigh(self, linear_size, seed, dtype):
    """Checks the spectral divide-and-conquer eigh on a random symmetric matrix.

    Eigenvalues are compared with jnp.linalg.eigh, and the eigenpairs are
    checked via H @ V == V * evs.
    """
    # `device_under_test` must be *called*: comparing the function object to
    # a string is always unequal, which previously skipped this test
    # everywhere. Half precision is covered on CPU by the NotImplementedError
    # check below.
    if jtu.device_under_test() != "cpu":
      raise unittest.SkipTest("Skip eigh off CPU for now.")

    rng = self.rng()
    H = rng.randn(linear_size, linear_size)
    H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      self.assertRaises(
          NotImplementedError, jax._src.scipy.eigh.eigh, H)
      return
    evs, V = jax._src.scipy.eigh.eigh(H)
    ev_exp, _ = jnp.linalg.eigh(H)
    HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
    vV = evs * V
    eps = jnp.finfo(H.dtype).eps
    atol = jnp.linalg.norm(H) * eps
    self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
    self.assertAllClose(HV, vV, atol=30 * atol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {'testcase_name':
       '_linear_size_={}_seed={}_dtype={}'.format(
           linear_size, seed, jnp.dtype(dtype).name
       ),
       'linear_size': linear_size, 'seed': seed, 'dtype': dtype}
      for linear_size in linear_sizes
      for seed in seeds
      for dtype in jtu.dtypes.floating))
  @jtu.skip_on_devices("gpu")  # Fails on A100.
  def test_spectral_dac_svd(self, linear_size, seed, dtype):
    """Checks the spectral divide-and-conquer SVD.

    Compares singular values with np.linalg.svd, then checks reconstruction
    and unitarity of both factors.
    """
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      if jtu.device_under_test() != "cpu":
        raise unittest.SkipTest("Skip half precision off CPU.")

    rng = self.rng()
    A = rng.randn(linear_size, linear_size).astype(dtype)
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      self.assertRaises(
          NotImplementedError, jax._src.scipy.eigh.svd, A)
      return
    S_expected = np.linalg.svd(A, compute_uv=False)
    U, S, V = jax._src.scipy.eigh.svd(A)
    recon = jnp.dot((U * jnp.expand_dims(S, 0)), V,
                    precision=lax.Precision.HIGHEST)
    eps = jnp.finfo(dtype).eps
    eps = eps * jnp.linalg.norm(A) * 15
    self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
    self.assertAllClose(A, recon, atol=eps)

    # U is unitary.
    u_unitary_delta = jnp.dot(U.conj().T, U, precision=lax.Precision.HIGHEST)
    u_eye = jnp.eye(u_unitary_delta.shape[0], dtype=dtype)
    self.assertAllClose(u_unitary_delta, u_eye, atol=eps)

    # V is unitary.
    v_unitary_delta = jnp.dot(V.conj().T, V, precision=lax.Precision.HIGHEST)
    v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype)
    self.assertAllClose(v_unitary_delta, v_eye, atol=eps)
class SparsifyTest(jtu.JaxTestCase):
  """Tests for the ``sparsify`` transformation applied to BCOO arrays.

  Most tests run the same function on a dense array and on its BCOO
  equivalent and check that the results agree.
  """

  def assertBcooIdentical(self, x, y):
    """Asserts two BCOO arrays have the same shape, data and indices."""
    self.assertIsInstance(x, BCOO)
    self.assertIsInstance(y, BCOO)
    self.assertEqual(x.shape, y.shape)
    self.assertArraysEqual(x.data, y.data)
    self.assertArraysEqual(x.indices, y.indices)

  def testArgSpec(self):
    """Round-trips arrays through ArgSpecs, with independent and shared indices."""
    X = jnp.arange(5)
    X_BCOO = BCOO.fromdense(X)

    args = (X, X_BCOO, X_BCOO)

    # Independent index
    spenv = SparseEnv()
    argspecs = arrays_to_argspecs(spenv, args)
    self.assertEqual(len(argspecs), len(args))
    self.assertEqual(spenv.size(), 5)
    self.assertEqual(argspecs,
                     (ArgSpec(X.shape, 0, None), ArgSpec(X.shape, 1, 2),
                      ArgSpec(X.shape, 3, 4)))

    args_out = argspecs_to_arrays(spenv, argspecs)
    self.assertEqual(len(args_out), len(args))
    self.assertArraysEqual(args[0], args_out[0])
    self.assertBcooIdentical(args[1], args_out[1])
    self.assertBcooIdentical(args[2], args_out[2])

    # Shared index
    argspecs = (ArgSpec(X.shape, 0, None), ArgSpec(X.shape, 1, 2),
                ArgSpec(X.shape, 3, 2))
    spenv = SparseEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data])

    args_out = argspecs_to_arrays(spenv, argspecs)
    self.assertEqual(len(args_out), len(args))
    self.assertArraysEqual(args[0], args_out[0])
    self.assertBcooIdentical(args[1], args_out[1])
    self.assertBcooIdentical(args[2], args_out[2])

  def testUnitHandling(self):
    """sparsify passes through a ``core.unit`` argument."""
    x = BCOO.fromdense(jnp.arange(5))
    f = jit(lambda x, y: x)
    result = sparsify(jit(f))(x, core.unit)
    self.assertBcooIdentical(result, x)

  def testDropvar(self):
    """A discarded output of an inner jit call does not break sparsify."""
    def inner(x):
      return x * 2, x * 3

    def f(x):
      _, y = jit(inner)(x)
      return y * 4

    x_dense = jnp.arange(5)
    x_sparse = BCOO.fromdense(x_dense)
    self.assertArraysEqual(sparsify(f)(x_sparse).todense(), f(x_dense))

  def testPytreeInput(self):
    """sparsify accepts a tuple mixing dense and BCOO arrays."""
    f = sparsify(lambda x: x)
    args = (jnp.arange(4), BCOO.fromdense(jnp.arange(4)))
    out = f(args)
    self.assertLen(out, 2)
    self.assertArraysEqual(args[0], out[0])
    self.assertBcooIdentical(args[1], out[1])

  def testSparsify(self):
    M_dense = jnp.arange(24).reshape(4, 6)
    M_sparse = BCOO.fromdense(M_dense)
    v = jnp.arange(M_dense.shape[0])

    @sparsify
    def func(x, v):
      return -jnp.sin(jnp.pi * x).T @ (v + 1)

    result_dense = func(M_dense, v)
    result_sparse = func(M_sparse, v)

    self.assertAllClose(result_sparse, result_dense)

  def testSparsifyWithConsts(self):
    M_dense = jnp.arange(24).reshape(4, 6)
    M_sparse = BCOO.fromdense(M_dense)

    @sparsify
    def func(x):
      return jit(lambda x: jnp.sum(x, 1))(x)

    result_dense = func(M_dense)
    result_sparse = func(M_sparse)

    self.assertAllClose(result_sparse.todense(), result_dense)

  def testSparseMatmul(self):
    """Matmul with sparse-dense, dense-sparse and sparse-sparse operands."""
    X = jnp.arange(16).reshape(4, 4)
    Xsp = BCOO.fromdense(X)
    Y = jnp.ones(4)
    Ysp = BCOO.fromdense(Y)

    # dot_general
    result_sparse = sparsify(operator.matmul)(Xsp, Y)
    result_dense = operator.matmul(X, Y)
    self.assertAllClose(result_sparse, result_dense)

    # rdot_general
    result_sparse = sparsify(operator.matmul)(Y, Xsp)
    result_dense = operator.matmul(Y, X)
    self.assertAllClose(result_sparse, result_dense)

    # spdot_general
    result_sparse = sparsify(operator.matmul)(Xsp, Ysp)
    result_dense = operator.matmul(X, Y)
    self.assertAllClose(result_sparse.todense(), result_dense)

  def testSparseAdd(self):
    x = BCOO.fromdense(jnp.arange(5))
    y = BCOO.fromdense(2 * jnp.arange(5))

    # Distinct indices
    out = sparsify(operator.add)(x, y)
    self.assertEqual(out.nse, 8)  # uses concatenation.
    self.assertArraysEqual(out.todense(), 3 * jnp.arange(5))

    # Shared indices - requires lower level call
    argspecs = [
        ArgSpec(x.shape, 1, 0),
        ArgSpec(y.shape, 2, 0)
    ]
    spenv = SparseEnv([x.indices, x.data, y.data])

    result = sparsify_raw(operator.add)(spenv, *argspecs)
    args_out, _ = result
    out, = argspecs_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() + y.todense())

  def testSparseMul(self):
    x = BCOO.fromdense(jnp.arange(5))
    y = BCOO.fromdense(2 * jnp.arange(5))

    # Scalar multiplication
    out = sparsify(operator.mul)(x, 2.5)
    self.assertArraysEqual(out.todense(), x.todense() * 2.5)

    # Shared indices - requires lower level call
    argspecs = [
        ArgSpec(x.shape, 1, 0),
        ArgSpec(y.shape, 2, 0)
    ]
    spenv = SparseEnv([x.indices, x.data, y.data])

    result = sparsify_raw(operator.mul)(spenv, *argspecs)
    args_out, _ = result
    out, = argspecs_to_arrays(spenv, args_out)

    self.assertAllClose(out.todense(), x.todense() * y.todense())

  def testSparseSum(self):
    """Full and per-axis sums agree between dense and sparse inputs."""
    x = jnp.arange(20).reshape(4, 5)
    xsp = BCOO.fromdense(x)

    def f(x):
      return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

    result_dense = f(x)
    result_sparse = sparsify(f)(xsp)

    # TestCase assertion instead of a bare assert, which `python -O` strips.
    self.assertEqual(len(result_dense), len(result_sparse))

    for res_dense, res_sparse in zip(result_dense, result_sparse):
      if isinstance(res_sparse, BCOO):
        res_sparse = res_sparse.todense()
      self.assertArraysAllClose(res_dense, res_sparse)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dimensions={}_nbatch={}, ndense={}".format(
          jtu.format_shape_dtype_string(shape, np.float32), dimensions, n_batch, n_dense),
       "shape": shape,
       "dimensions": dimensions,
       "n_batch": n_batch,
       "n_dense": n_dense}
      for shape, dimensions in [
          [(1,), (0,)],
          [(1,), (-1,)],
          [(2, 1, 4), (1,)],
          [(2, 1, 3, 1), (1,)],
          [(2, 1, 3, 1), (1, 3)],
          [(2, 1, 3, 1), (3,)],
      ]
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) - n_batch + 1)))
  def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
    """lax.squeeze under sparsify for each batch/dense dimension split."""
    rng = jtu.rand_default(self.rng())

    M_dense = rng(shape, np.float32)
    M_sparse = BCOO.fromdense(M_dense, n_batch=n_batch, n_dense=n_dense)
    func = sparsify(partial(lax.squeeze, dimensions=dimensions))

    result_dense = func(M_dense)
    result_sparse = func(M_sparse).todense()

    self.assertAllClose(result_sparse, result_dense)

  def testSparseWhileLoop(self):
    """while_loop carrying a dense counter and a sparse array."""
    def cond_fun(params):
      i, A = params
      return i < 5

    def body_fun(params):
      i, A = params
      return i + 1, 2 * A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A))

    A = jnp.arange(4)
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = sparsify(f)(Asp)

    self.assertEqual(len(out_dense), 2)
    self.assertEqual(len(out_sparse), 2)
    # Compare the loop counters of both runs. The previous code compared
    # out_dense[0] with itself, which never checked the sparse result.
    self.assertArraysEqual(out_dense[0], out_sparse[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())

  def testSparseWhileLoopDuplicateIndices(self):
    """while_loop whose body swaps two sparse carries built from one array."""
    def cond_fun(params):
      i, A, B = params
      return i < 5

    def body_fun(params):
      i, A, B = params
      # TODO(jakevdp): track shared indices through while loop & use this
      #   version of the test, which requires shared indices in order for
      #   the nse of the result to remain the same.
      # return i + 1, A, A + B

      # This version is fine without shared indices, and tests that we're
      # flattening non-shared indices consistently.
      return i + 1, B, A

    def f(A):
      return lax.while_loop(cond_fun, body_fun, (0, A, A))

    A = jnp.arange(4).reshape((2, 2))
    out_dense = f(A)

    Asp = BCOO.fromdense(A)
    out_sparse = sparsify(f)(Asp)

    self.assertEqual(len(out_dense), 3)
    self.assertEqual(len(out_sparse), 3)
    # Compare the loop counters of both runs (previously a self-comparison).
    self.assertArraysEqual(out_dense[0], out_sparse[0])
    self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
    self.assertArraysEqual(out_dense[2], out_sparse[2].todense())

  def testSparsifyDenseXlaCall(self):
    # Test handling of dense xla_call within jaxpr interpreter.
    out = sparsify(jit(lambda x: x + 1))(0.0)
    self.assertEqual(out, 1.0)

  def testSparsifySparseXlaCall(self):
    # Test sparse lowering of XLA call
    def func(M):
      return 2 * M

    M = jnp.arange(6).reshape(2, 3)
    Msp = BCOO.fromdense(M)

    out_dense = func(M)
    out_sparse = sparsify(jit(func))(Msp)
    self.assertArraysEqual(out_dense, out_sparse.todense())

  def testSparseForiLoop(self):
    """fori_loop whose body closes over a sparse matrix."""
    def func(M, x):
      body_fun = lambda i, val: (M @ val) / M.shape[1]
      return lax.fori_loop(0, 2, body_fun, x)

    x = jnp.arange(5.0)
    M = jnp.arange(25).reshape(5, 5)
    M_bcoo = BCOO.fromdense(M)

    result_dense = func(M, x)
    result_sparse = sparsify(func)(M_bcoo, x)

    self.assertArraysAllClose(result_dense, result_sparse)

  def testSparseCondSimple(self):
    def func(x):
      return lax.cond(False, lambda x: x, lambda x: 2 * x, x)

    x = jnp.arange(5.0)
    result_dense = func(x)

    x_bcoo = BCOO.fromdense(x)
    result_sparse = sparsify(func)(x_bcoo)

    self.assertArraysAllClose(result_dense, result_sparse.todense())

  def testSparseCondMismatchError(self):
    """cond branches returning one sparse and one dense output raise TypeError."""
    @sparsify
    def func(x, y):
      return lax.cond(False, lambda x: x[0], lambda x: x[1], (x, y))

    x = jnp.arange(5.0)
    y = jnp.arange(5.0)

    x_bcoo = BCOO.fromdense(x)
    y_bcoo = BCOO.fromdense(y)

    func(x, y)  # No error
    func(x_bcoo, y_bcoo)  # No error

    with self.assertRaisesRegex(TypeError, "sparsified true_fun and false_fun output.*"):
      func(x_bcoo, y)
class FftTest(jtu.JaxTestCase):
  """Tests for `jax.numpy.fft` and `lax.fft`, mostly compared against `np.fft`."""

  @staticmethod
  def _fft_func_name(base, inverse, real, hermitian=False):
    """Returns the FFT function name built from `base` and the variant flags.

    The real prefix 'r' takes precedence over the hermitian prefix 'h', and the
    inverse prefix 'i' is applied last; e.g. ('fft', inverse=True, real=True)
    gives 'irfft'.
    """
    name = base
    if real:
      name = 'r' + name
    elif hermitian:
      name = 'h' + name
    if inverse:
      name = 'i' + name
    return name

  def testNotImplemented(self):
    """Each name in `jnp.fft._NOT_IMPLEMENTED` raises NotImplementedError."""
    for name in jnp.fft._NOT_IMPLEMENTED:
      func = getattr(jnp.fft, name)
      with self.assertRaises(NotImplementedError):
        func()

  def testLaxFftAcceptsStringTypes(self):
    """`lax.fft` accepts the transform type given as the string "FFT"."""
    rng = jtu.rand_default(self.rng())
    x = rng((10,), np.complex64)
    self.assertAllClose(np.fft.fft(x).astype(np.complex64),
                        lax.fft(x, "FFT", fft_lengths=(10,)))

  @parameterized.parameters((np.float32,), (np.float64,))
  def testLaxIrfftDoesNotMutateInputs(self, dtype):
    """Two `irfft2` calls on the same input array produce the same output."""
    if dtype == np.float64 and not config.x64_enabled:
      # skipTest raises unittest.SkipTest itself, so no `raise` is needed.
      self.skipTest("float64 requires jax_enable_x64=true")
    x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]],
                             dtype=dtypes._to_complex_dtype(dtype))
    y = np.asarray(jnp.fft.irfft2(x))
    z = np.asarray(jnp.fft.irfft2(x))
    self.assertAllClose(y, z)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_inverse={}_real={}_shape={}_axes={}_s={}_norm={}".format(
              inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes, s,
              norm),
          "axes": axes, "shape": shape, "dtype": dtype, "inverse": inverse,
          "real": real, "s": s, "norm": norm}
          for inverse in [False, True]
          for real in [False, True]
          for dtype in (real_dtypes if real and not inverse else all_dtypes)
          for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5)]
          for axes in _get_fftn_test_axes(shape)
          for s in _get_fftn_test_s(shape, axes)
          for norm in FFT_NORMS))
  def testFftn(self, inverse, real, shape, dtype, axes, s, norm):
    """Compares (i)(r)fftn with numpy and under jit; checks grads and dtype."""
    # NOTE(review): `s` is part of the parameterization but is never passed to
    # the FFTs below, so cases that differ only in `s` run identically.
    # Confirm whether `s=s` should be forwarded to both jnp_op and np_op.
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    jnp_op = _get_fftn_func(jnp.fft, inverse, real)
    np_op = _get_fftn_func(np.fft, inverse, real)
    jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm)
    # An empty (but non-None) `axes` makes the numpy reference the identity.
    np_fn = lambda a: np_op(a, axes=axes, norm=norm) if axes is None or axes else a
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker)
    # Test gradient for differentiable types.
    if (config.x64_enabled and
        dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
      # TODO(skye): can we be more precise?
      tol = 0.15
      jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
    # The output must already be a float (inverse real transforms) or complex
    # (all other transforms) type, so promoting it with that kind is a no-op.
    out_dtype = jnp_fn(rng(shape, dtype)).dtype
    expected_dtype = jnp.promote_types(
        float if inverse and real else complex, out_dtype)
    self.assertEqual(out_dtype, expected_dtype)

  def testIrfftTranspose(self):
    """The linear transpose of an irfft-based map equals its matrix transpose."""
    # regression test for https://github.com/google/jax/issues/6223
    def build_matrix(linear_func, size):
      return jax.vmap(linear_func)(jnp.eye(size, size))

    def func(x):
      x, = _promote_dtypes_complex(x)
      return jnp.fft.irfft(
          jnp.concatenate([jnp.zeros_like(x, shape=1), x[:2] + 1j * x[2:]]))

    def func_transpose(x):
      return jax.linear_transpose(func, x)(x)[0]

    matrix = build_matrix(func, 4)
    matrix2 = build_matrix(func_transpose, 4).T
    self.assertAllClose(matrix, matrix2)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_inverse={inverse}_real={real}",
          "inverse": inverse, "real": real}
          for inverse in [False, True]
          for real in [False, True]))
  def testFftnErrors(self, inverse, real):
    """(i)(r)fftn reject rank > 3 with axes=None, repeated and out-of-range axes."""
    rng = jtu.rand_default(self.rng())
    name = self._fft_func_name('fftn', inverse, real)
    func = _get_fftn_func(jnp.fft, inverse, real)
    self.assertRaisesRegex(
        ValueError,
        "jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. "
        "Got axes None with input rank 4.".format(name),
        lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None))
    self.assertRaisesRegex(
        ValueError,
        f"jax.numpy.fft.{name} does not support repeated axes. Got axes \\[1, 1\\].",
        lambda: func(rng([2, 3], dtype=np.float64), axes=[1, 1]))
    self.assertRaises(
        ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2]))
    self.assertRaises(
        ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3]))

  def testFftEmpty(self):
    """fft of a length-0 array returns a length-0 complex64 array."""
    out = jnp.fft.fft(jnp.zeros((0,), jnp.complex64)).block_until_ready()
    self.assertArraysEqual(jnp.zeros((0,), jnp.complex64), out)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_inverse={}_real={}_hermitian={}_shape={}_n={}_axis={}".format(
              inverse, real, hermitian, jtu.format_shape_dtype_string(shape, dtype),
              n, axis),
          "axis": axis, "shape": shape, "dtype": dtype, "inverse": inverse,
          "real": real, "hermitian": hermitian, "n": n}
          for inverse in [False, True]
          for real in [False, True]
          for hermitian in [False, True]
          for dtype in (real_dtypes
                        if (real and not inverse) or (hermitian and inverse)
                        else all_dtypes)
          for shape in [(10,)]
          for n in [None, 1, 7, 13, 20]
          for axis in [-1, 0]))
  def testFft(self, inverse, real, hermitian, shape, dtype, n, axis):
    """Compares 1-D (i)(r|h)fft with numpy, with and without jit."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    name = self._fft_func_name('fft', inverse, real, hermitian)
    jnp_op = getattr(jnp.fft, name)
    np_op = getattr(np.fft, name)
    jnp_fn = lambda a: jnp_op(a, n=n, axis=axis)
    np_fn = lambda a: np_op(a, n=n, axis=axis)
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    # Compile the parameterized function so `n` and `axis` are exercised
    # under jit too, not just the op's defaults.
    self._CompileAndCheck(jnp_fn, args_maker)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_inverse={inverse}_real={real}_hermitian={hermitian}",
          "inverse": inverse, "real": real, "hermitian": hermitian}
          for inverse in [False, True]
          for real in [False, True]
          for hermitian in [False, True]))
  def testFftErrors(self, inverse, real, hermitian):
    """1-D FFTs reject multiple axes and out-of-range axes."""
    rng = jtu.rand_default(self.rng())
    name = self._fft_func_name('fft', inverse, real, hermitian)
    func = getattr(jnp.fft, name)

    self.assertRaisesRegex(
        ValueError,
        f"jax.numpy.fft.{name} does not support multiple axes. "
        f"Please use jax.numpy.fft.{name}n. Got axis = \\[1, 1\\].",
        lambda: func(rng([2, 3], dtype=np.float64), axis=[1, 1]))
    self.assertRaisesRegex(
        ValueError,
        f"jax.numpy.fft.{name} does not support multiple axes. "
        f"Please use jax.numpy.fft.{name}n. Got axis = \\(1, 1\\).",
        lambda: func(rng([2, 3], dtype=np.float64), axis=(1, 1)))
    self.assertRaises(
        ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[2]))
    self.assertRaises(
        ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[-3]))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_inverse={}_real={}_shape={}_axes={}_norm={}".format(
              inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes,
              norm),
          "axes": axes, "shape": shape, "dtype": dtype, "inverse": inverse,
          "real": real, "norm": norm}
          for inverse in [False, True]
          for real in [False, True]
          for dtype in (real_dtypes if real and not inverse else all_dtypes)
          for shape in [(16, 8, 4, 8), (16, 8, 4, 8, 4)]
          for axes in [(-2, -1), (0, 1), (1, 3), (-1, 2)]
          for norm in FFT_NORMS))
  def testFft2(self, inverse, real, shape, dtype, axes, norm):
    """Compares (i)(r)fft2 with numpy, with and without jit."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    name = self._fft_func_name('fft2', inverse, real)
    jnp_op = getattr(jnp.fft, name)
    np_op = getattr(np.fft, name)
    jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm)
    np_fn = lambda a: np_op(a, axes=axes, norm=norm) if axes is None or axes else a
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    # Compile the parameterized function so `axes` and `norm` are exercised
    # under jit too, not just the op's defaults.
    self._CompileAndCheck(jnp_fn, args_maker)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": f"_inverse={inverse}_real={real}",
          "inverse": inverse, "real": real}
          for inverse in [False, True]
          for real in [False, True]))
  def testFft2Errors(self, inverse, real):
    """(i)(r)fft2 reject anything other than exactly two valid axes."""
    rng = jtu.rand_default(self.rng())
    name = self._fft_func_name('fft2', inverse, real)
    func = getattr(jnp.fft, name)

    self.assertRaisesRegex(
        ValueError,
        "jax.numpy.fft.{} only supports 2 axes. "
        "Got axes = \\[0\\].".format(name),
        lambda: func(rng([2, 3], dtype=np.float64), axes=[0]))
    self.assertRaisesRegex(
        ValueError,
        "jax.numpy.fft.{} only supports 2 axes. "
        "Got axes = \\(0, 1, 2\\).".format(name),
        lambda: func(rng([2, 3, 3], dtype=np.float64), axes=(0, 1, 2)))
    self.assertRaises(
        ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2, 3]))
    self.assertRaises(
        ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3, -4]))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_size={}_d={}".format(
              jtu.format_shape_dtype_string([size], dtype), d),
          "dtype": dtype, "size": size, "d": d}
          for dtype in all_dtypes
          for size in [9, 10, 101, 102]
          for d in [0.1, 2.]))
  def testFftfreq(self, size, d, dtype):
    """Compares `jnp.fft.fftfreq` with numpy for several sizes and spacings."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng([size], dtype),)
    jnp_op = jnp.fft.fftfreq
    np_op = np.fft.fftfreq
    # fftfreq depends only on `size` and `d`; the generated argument is
    # ignored and exists only to drive the comparison harness.
    jnp_fn = lambda a: jnp_op(size, d=d)
    np_fn = lambda a: np_op(size, d=d)
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker)
    # Test gradient for differentiable types.
    if dtype in inexact_dtypes:
      tol = 0.15  # TODO(skye): can we be more precise?
      jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list({"testcase_name": f"_n={n}", "n": n}
                          for n in [[0, 1, 2]]))
  def testFftfreqErrors(self, n):
    """fftfreq rejects a list-valued `n` or `d`."""
    name = 'fftfreq'
    func = jnp.fft.fftfreq
    self.assertRaisesRegex(
        ValueError,
        "The n argument of jax.numpy.fft.{} only takes an int. "
        "Got n = \\[0, 1, 2\\].".format(name),
        lambda: func(n=n))
    self.assertRaisesRegex(
        ValueError,
        "The d argument of jax.numpy.fft.{} only takes a single value. "
        "Got d = \\[0, 1, 2\\].".format(name),
        lambda: func(n=10, d=n))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_size={}_d={}".format(
              jtu.format_shape_dtype_string([size], dtype), d),
          "dtype": dtype, "size": size, "d": d}
          for dtype in all_dtypes
          for size in [9, 10, 101, 102]
          for d in [0.1, 2.]))
  def testRfftfreq(self, size, d, dtype):
    """Compares `jnp.fft.rfftfreq` with numpy for several sizes and spacings."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng([size], dtype),)
    jnp_op = jnp.fft.rfftfreq
    np_op = np.fft.rfftfreq
    # rfftfreq depends only on `size` and `d`; the generated argument is
    # ignored and exists only to drive the comparison harness.
    jnp_fn = lambda a: jnp_op(size, d=d)
    np_fn = lambda a: np_op(size, d=d)
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker)
    # Test gradient for differentiable types.
    if dtype in inexact_dtypes:
      tol = 0.15  # TODO(skye): can we be more precise?
      jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list({"testcase_name": f"_n={n}", "n": n}
                          for n in [[0, 1, 2]]))
  def testRfftfreqErrors(self, n):
    """rfftfreq rejects a list-valued `n` or `d`."""
    name = 'rfftfreq'
    func = jnp.fft.rfftfreq
    self.assertRaisesRegex(
        ValueError,
        "The n argument of jax.numpy.fft.{} only takes an int. "
        "Got n = \\[0, 1, 2\\].".format(name),
        lambda: func(n=n))
    self.assertRaisesRegex(
        ValueError,
        "The d argument of jax.numpy.fft.{} only takes a single value. "
        "Got d = \\[0, 1, 2\\].".format(name),
        lambda: func(n=10, d=n))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "dtype={}_axes={}".format(
              jtu.format_shape_dtype_string(shape, dtype), axes),
          "dtype": dtype, "shape": shape, "axes": axes}
          for dtype in all_dtypes
          for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]]
          for axes in _get_fftn_test_axes(shape)))
  def testFftshift(self, shape, dtype, axes):
    """Compares `jnp.fft.fftshift` with numpy."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    jnp_fn = lambda arg: jnp.fft.fftshift(arg, axes=axes)
    np_fn = lambda arg: np.fft.fftshift(arg, axes=axes)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "dtype={}_axes={}".format(
              jtu.format_shape_dtype_string(shape, dtype), axes),
          "dtype": dtype, "shape": shape, "axes": axes}
          for dtype in all_dtypes
          for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]]
          for axes in _get_fftn_test_axes(shape)))
  def testIfftshift(self, shape, dtype, axes):
    """Compares `jnp.fft.ifftshift` with numpy."""
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes)
    np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
class TestPolynomial(jtu.JaxTestCase):
  """Tests for `jnp.roots`, compared against `np.roots`."""

  @staticmethod
  def _coeffs_with_trailing_zeros(rng, length, dtype, trailing):
    """Returns a 1-tuple holding `length` random coefficients.

    Unless `length` is 0, `trailing` zeros are appended to the coefficients.
    An empty coefficient vector is left unpadded, because adding trailing
    zeros would make the input invalid (it would start with zeros).
    """
    p = rng((length,), dtype)
    if length != 0:
      return (jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),)
    return (p,)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_dtype={}_leading={}_trailing={}".format(
              jtu.format_shape_dtype_string((length + leading + trailing,), dtype),
              leading, trailing),
          "dtype": dtype, "length": length, "leading": leading,
          "trailing": trailing}
          for dtype in all_dtypes
          for length in [0, 3, 9, 10, 17]
          for leading in [0, 1, 2, 3, 5, 7, 10]
          for trailing in [0, 1, 2, 3, 5, 7, 10]))
  # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testRoots(self, dtype, length, leading, trailing):
    """Roots of zero-padded polynomials match numpy (default zero stripping)."""
    rng = jtu.rand_default(np.random.RandomState(0))

    def args_maker():
      p = rng((length,), dtype)
      return (jnp.concatenate([jnp.zeros(leading, p.dtype), p,
                               jnp.zeros(trailing, p.dtype)]),)

    jnp_fn = lambda arg: jnp.sort(jnp.roots(arg))
    np_fn = lambda arg: np.sort(np.roots(arg))
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=3e-6)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_dtype={}_trailing={}".format(
              jtu.format_shape_dtype_string((length + trailing,), dtype),
              trailing),
          "dtype": dtype, "length": length, "trailing": trailing}
          for dtype in all_dtypes
          for length in [0, 1, 3, 10]
          for trailing in [0, 1, 3, 7]))
  # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testRootsNostrip(self, length, dtype, trailing):
    """Roots with strip_zeros=False match numpy for trailing-zero inputs."""
    rng = jtu.rand_default(np.random.RandomState(0))
    args_maker = lambda: self._coeffs_with_trailing_zeros(
        rng, length, dtype, trailing)

    jnp_fn = lambda arg: jnp.sort(jnp.roots(arg, strip_zeros=False))
    np_fn = lambda arg: np.sort(np.roots(arg))
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-6)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_dtype={}_trailing={}".format(
              jtu.format_shape_dtype_string((length + trailing,), dtype),
              trailing),
          "dtype": dtype, "length": length, "trailing": trailing}
          for dtype in all_dtypes
          for length in [0, 1, 3, 10]
          for trailing in [0, 1, 3, 7]))
  # TODO: enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testRootsJit(self, length, dtype, trailing):
    """Jitted roots (strip_zeros=False) match numpy for trailing-zero inputs."""
    rng = jtu.rand_default(np.random.RandomState(0))
    args_maker = lambda: self._coeffs_with_trailing_zeros(
        rng, length, dtype, trailing)

    roots_compiled = jit(partial(jnp.roots, strip_zeros=False))
    jnp_fn = lambda arg: jnp.sort(roots_compiled(arg))
    np_fn = lambda arg: np.sort(np.roots(arg))
    # Using strip_zeros=False makes the algorithm less efficient
    # and leads to slightly different values compared to numpy.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-6)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_dtype={}_zeros={}_nonzeros={}".format(
              jtu.format_shape_dtype_string((zeros + nonzeros,), dtype),
              zeros, nonzeros),
          "zeros": zeros, "nonzeros": nonzeros, "dtype": dtype}
          for dtype in all_dtypes
          for zeros in [1, 2, 5]
          for nonzeros in [0, 3]))
  @jtu.skip_on_devices("gpu")
  @unittest.skip("getting segfaults on MKL")  # TODO(#3711)
  def testRootsInvalid(self, zeros, nonzeros, dtype):
    """Leading-zero inputs with strip_zeros=False give no roots or nan roots."""
    rng = jtu.rand_default(np.random.RandomState(0))

    # The polynomial coefficients here start with zero and would have to
    # be stripped before computing eigenvalues of the companion matrix.
    # Setting strip_zeros=False skips this check,
    # allowing jit transformation but yielding nan's for these inputs.
    p = jnp.concatenate([jnp.zeros(zeros, dtype), rng((nonzeros,), dtype)])

    if p.size == 1:
      # polynomial = const has no roots
      self.assertEqual(jnp.roots(p, strip_zeros=False).size, 0)
    else:
      self.assertTrue(jnp.any(jnp.isnan(jnp.roots(p, strip_zeros=False))))