class TestPolynomial(jtu.JaxTestCase):
  """Tests for `jnp.roots` and the unimplemented `jnp.polynomial` stubs.

  The root-finding tests build coefficient vectors padded with leading and/or
  trailing zeros and compare the sorted output of `jnp.roots` with the sorted
  output of `np.roots`.
  """

  @staticmethod
  def _trailing_zeros_args_maker(rng, length, dtype, trailing):
    """Builds an args_maker for coefficients followed by `trailing` zeros.

    Args:
      rng: callable `(shape, dtype) -> array` used to draw the coefficients.
      length: number of random coefficients to draw.
      dtype: dtype of the coefficients.
      trailing: number of zeros appended after the random coefficients.

    Returns:
      A zero-argument callable returning a 1-tuple holding the coefficient
      vector.
    """
    def args_maker():
      p = rng((length,), dtype)
      if length == 0:
        # adding trailing would make input invalid (start with zeros)
        return p,
      return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
    return args_maker

  def testNotImplemented(self):
    """Every name in `_NOT_IMPLEMENTED` must raise NotImplementedError."""
    for name in jnp.polynomial._NOT_IMPLEMENTED:
      func = getattr(jnp.polynomial, name)
      # subTest makes a failure report which stub did not raise.
      with self.subTest(name=name):
        with self.assertRaises(NotImplementedError):
          func()

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_leading={}_trailing={}".format(
          jtu.format_shape_dtype_string((length+leading+trailing,), dtype),
          leading, trailing),
       "dtype": dtype, "rng_factory": rng_factory, "length": length,
       "leading": leading, "trailing": trailing}
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default]
      for length in [0, 3, 9, 10, 17]
      for leading in [0, 1, 2, 3, 5, 7, 10]
      for trailing in [0, 1, 2, 3, 5, 7, 10]))
  def testRoots(self, dtype, rng_factory, length, leading, trailing):
    """Compares `jnp.roots` (default zero stripping) with `np.roots`."""
    # NOTE(review): unlike testRootsJit this test is not skipped on GPU/TPU;
    # confirm whether eager jnp.roots is expected to work on those backends.
    rng = rng_factory(np.random.RandomState(0))

    def args_maker():
      # Random coefficients padded with zeros on both ends.
      p = rng((length,), dtype)
      return jnp.concatenate(
          [jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)]),

    # Sort both results so the comparison does not depend on root order.
    jnp_fn = lambda arg: jnp.sort(jnp.roots(arg))
    np_fn = lambda arg: np.sort(np.roots(arg))
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=3e-6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_trailing={}".format(
          jtu.format_shape_dtype_string((length+trailing,), dtype), trailing),
       "dtype": dtype, "rng_factory": rng_factory, "length": length,
       "trailing": trailing}
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default]
      for length in [0, 1, 3, 10]
      for trailing in [0, 1, 3, 7]))
  def testRootsNostrip(self, length, dtype, rng_factory, trailing):
    """Compares `jnp.roots(..., strip_zeros=False)` with `np.roots`."""
    rng = rng_factory(np.random.RandomState(0))
    args_maker = self._trailing_zeros_args_maker(rng, length, dtype, trailing)

    jnp_fn = lambda arg: jnp.sort(jnp.roots(arg, strip_zeros=False))
    np_fn = lambda arg: np.sort(np.roots(arg))
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_trailing={}".format(
          jtu.format_shape_dtype_string((length + trailing,), dtype),
          trailing),
       "dtype": dtype, "rng_factory": rng_factory, "length": length,
       "trailing": trailing}
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default]
      for length in [0, 1, 3, 10]
      for trailing in [0, 1, 3, 7]))
  # TODO: enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testRootsJit(self, length, dtype, rng_factory, trailing):
    """Compares jit-compiled `jnp.roots(strip_zeros=False)` with `np.roots`."""
    rng = rng_factory(np.random.RandomState(0))
    args_maker = self._trailing_zeros_args_maker(rng, length, dtype, trailing)

    roots_compiled = jit(partial(jnp.roots, strip_zeros=False))
    jnp_fn = lambda arg: jnp.sort(roots_compiled(arg))
    np_fn = lambda arg: np.sort(np.roots(arg))
    # Using strip_zeros=False makes the algorithm less efficient
    # and leads to slightly different values compared to numpy
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_zeros={}_nonzeros={}".format(
          jtu.format_shape_dtype_string((zeros+nonzeros,), dtype),
          zeros, nonzeros),
       "zeros": zeros, "nonzeros": nonzeros, "dtype": dtype,
       "rng_factory": rng_factory}
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default]
      for zeros in [1, 2, 5]
      for nonzeros in [0, 3]))
  @jtu.skip_on_devices("gpu")
  def testRootsInvalid(self, zeros, nonzeros, dtype, rng_factory):
    """Checks the output of `jnp.roots(strip_zeros=False)` on zero-led input."""
    rng = rng_factory(np.random.RandomState(0))

    # The polynomial coefficients here start with zero and would have to
    # be stripped before computing eigenvalues of the companion matrix.
    # Setting strip_zeros=False skips this check,
    # allowing jit transformation but yielding nan's for these inputs.
    p = jnp.concatenate([jnp.zeros(zeros, dtype), rng((nonzeros,), dtype)])
    roots = jnp.roots(p, strip_zeros=False)

    if p.size == 1:
      # polynomial = const has no roots
      self.assertEqual(roots.size, 0)
    else:
      self.assertTrue(jnp.any(jnp.isnan(roots)))
class NumpyLinalgTest(jtu.JaxTestCase):
  """Tests of `np.linalg` (JAX) against `onp.linalg` (NumPy).

  In this class `np` is the JAX NumPy namespace and `onp` is the original
  NumPy. Most tests compare the two implementations, verify that the JAX
  version compiles, and where applicable check gradients.
  """

  @staticmethod
  def _eps_scaled_norm(x, scale, dtype):
    """Returns the matrix norm of `x` in units of `scale * eps(dtype)`.

    Args:
      x: array whose last two axes are treated as a matrix.
      scale: dimension-dependent factor (e.g. `n + 1` or `max(m, n)`).
      dtype: floating or complex dtype whose machine epsilon is used.

    Returns:
      `onp.linalg.norm(x, axis=(-2, -1))` divided by `scale * eps`.
    """
    return onp.linalg.norm(x, axis=(-2, -1)) / (scale * onp.finfo(dtype).eps)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype,
          "rng": rng
      } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
        for dtype in float_types + complex_types
        for rng in [jtu.rand_default()]))
  def testCholesky(self, shape, dtype, rng):
    """Checks Cholesky on matrices of the form `a @ a^H`."""
    _skip_if_unsupported_type(dtype)

    def args_maker():
      # a has twice as many columns as rows; the input matrix is a @ a^H.
      factor_shape = shape[:-1] + (2 * shape[-1], )
      a = rng(factor_shape, dtype)
      return [onp.matmul(a, np.conj(T(a)))]

    if np.issubdtype(dtype, np.complexfloating) and (
        len(shape) > 2 or jtu.device_under_test() != "cpu"):
      self.skipTest("Unimplemented case for complex Cholesky decomposition.")

    self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky,
                            args_maker, check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.cholesky, args_maker, check_dtypes=True)

    # Gradients are only checked for 64-bit floating types.
    if onp.finfo(dtype).bits == 64:
      jtu.check_grads(np.linalg.cholesky, args_maker(), order=2)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
          "n": n,
          "dtype": dtype,
          "rng": rng
      } for n in [0, 4, 5, 25]  # TODO(mattjj): complex64 unstable on large sizes?
        for dtype in float_types + complex_types
        for rng in [jtu.rand_default()]))
  def testDet(self, n, dtype, rng):
    """Compares `np.linalg.det` with NumPy on random square matrices."""
    _skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(onp.linalg.det, np.linalg.det, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
          "n": n,
          "dtype": dtype,
          "rng": rng
      } for n in [0, 4, 10, 200]
        for dtype in float_types + complex_types
        for rng in [jtu.rand_default()]))
  def testSlogdet(self, n, dtype, rng):
    """Compares `np.linalg.slogdet` with NumPy on random square matrices."""
    _skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype,
          "rng": rng
      } for shape in [(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)]
        for dtype in float_types + complex_types
        for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEig(self, shape, dtype, rng):
    """Checks that `np.linalg.eig` satisfies `a @ v == w * v`."""
    _skip_if_unsupported_type(dtype)
    n = shape[-1]
    args_maker = lambda: [rng(shape, dtype)]

    # Norm, adjusted for dimension and type.
    norm = partial(self._eps_scaled_norm, scale=n + 1, dtype=dtype)

    a, = args_maker()
    w, v = np.linalg.eig(a)
    residual = onp.matmul(a, v) - w[..., None, :] * v
    self.assertTrue(onp.all(norm(residual) < 100))

    self._CompileAndCheck(np.linalg.eig, args_maker, check_dtypes=True,
                          rtol=1e-3)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype,
          "rng": rng
      } for shape in [(1, 1), (4, 4), (5, 5)]
        for dtype in float_types + complex_types
        for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testEigBatching(self, shape, dtype, rng):
    """Checks `vmap(np.linalg.eig)` on a batch of ten matrices."""
    _skip_if_unsupported_type(dtype)
    shape = (10, ) + shape
    args = rng(shape, dtype)
    ws, vs = vmap(np.linalg.eig)(args)
    residual = onp.matmul(args, vs) - ws[..., None, :] * vs
    self.assertTrue(onp.all(onp.linalg.norm(residual) < 1e-3))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_n={}_lower={}".format(
                  jtu.format_shape_dtype_string((n, n), dtype), lower),
          "n": n,
          "dtype": dtype,
          "lower": lower,
          "rng": rng
      } for n in [0, 4, 5, 50]
        for dtype in float_types + complex_types
        for lower in [False, True]
        for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for TPU.
  @jtu.skip_on_devices("tpu")
  def testEigh(self, n, dtype, lower, rng):
    """Checks orthonormality and the eigen-equation for `np.linalg.eigh`."""
    _skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng((n, n), dtype)]
    uplo = "L" if lower else "U"

    # Norm, adjusted for dimension and type.
    norm = partial(self._eps_scaled_norm, scale=n + 1, dtype=dtype)

    a, = args_maker()
    a = (a + onp.conj(a.T)) / 2  # Hermitian part of the random matrix.
    triangle = onp.tril(a) if lower else onp.triu(a)
    w, v = np.linalg.eigh(triangle, UPLO=uplo, symmetrize_input=False)

    # v^H v should be the identity, and a v should equal w v.
    self.assertLess(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)), 5)
    self.assertLess(norm(onp.matmul(a, v) - w * v), 30)

    self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker,
                          check_dtypes=True, rtol=1e-3)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}_lower={}".format(
                  jtu.format_shape_dtype_string(shape, dtype), lower),
          "shape": shape,
          "dtype": dtype,
          "rng": rng,
          "lower": lower
      } for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
        for dtype in float_types + complex_types
        for rng in [jtu.rand_default()]
        for lower in [True, False]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for TPU.
  @jtu.skip_on_devices("tpu")
  def testEighGrad(self, shape, dtype, rng, lower):
    """Gradient check for `np.linalg.eigh` (currently always skipped)."""
    # Everything below is kept for when the numeric issues are resolved; the
    # unconditional skip means it does not run today.
    self.skipTest("Test fails with numeric errors.")
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + onp.conj(a.T)) / 2
    a = onp.tril(a) if lower else onp.triu(a)
    # Gradient checks will fail without symmetrization as the eigh jvp rule
    # is only correct for tangents in the symmetric subspace, whereas the
    # checker checks against unconstrained (co)tangents.
    eigh = partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)
    if dtype not in complex_types:
      f = eigh
    else:
      # only check eigenvalue grads for complex matrices
      f = lambda a: eigh(a)[0]
    jtu.check_grads(f, (a, ), 2, rtol=1e-1)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}_lower={}".format(
                  jtu.format_shape_dtype_string(shape, dtype), lower),
          "shape": shape,
          "dtype": dtype,
          "rng": rng,
          "lower": lower,
          "eps": eps
      } for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
        for dtype in complex_types
        for rng in [jtu.rand_default()]
        for lower in [True, False]
        for eps in [1e-4]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for TPU.
  @jtu.skip_on_devices("tpu")
  def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps):
    """Checks complex eigenvector JVPs via eigensystem properties."""
    _skip_if_unsupported_type(dtype)
    # Special case to test for complex eigenvector grad correctness.
    # Exact eigenvector coordinate gradients are hard to test numerically for
    # complex eigensystem solvers given the extra degrees of per-eigenvector
    # phase freedom. Instead, we numerically verify the eigensystem properties
    # on the perturbed eigenvectors. You only ever want to optimize
    # eigenvector directions, not coordinates!
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + onp.conj(a.T)) / 2
    a = onp.tril(a) if lower else onp.triu(a)
    a_dot = eps * rng(shape, dtype)
    a_dot = (a_dot + onp.conj(a_dot.T)) / 2
    a_dot = onp.tril(a_dot) if lower else onp.triu(a_dot)

    # evaluate eigenvector gradient and groundtruth eigensystem for the
    # perturbed input matrix
    f = partial(np.linalg.eigh, UPLO=uplo)
    (_, v), (_, dv) = jvp(f, primals=(a, ), tangents=(a_dot, ))
    new_a = a + a_dot
    new_w, _ = f(new_a)
    new_a = (new_a + onp.conj(new_a.T)) / 2

    rtol = 1e-2
    v_pert = v + dv  # first-order prediction of the perturbed eigenvectors
    a_v = onp.dot(new_a, v_pert)
    w_v = new_w * v_pert

    # Assert rtol eigenvalue delta between perturbed eigenvectors vs new true
    # eigenvalues. unittest assertions are used instead of bare `assert` so
    # the checks still run under `python -O`.
    rayleigh = onp.diag(onp.dot(onp.conj(v_pert.T), a_v))
    self.assertLess(onp.max(onp.abs((rayleigh - new_w) / new_w)), rtol)

    # Redundant to above, but also assert rtol for eigenvector property with
    # new true eigenvalues.
    relative_residual = (onp.linalg.norm(onp.abs(w_v - a_v), axis=0) /
                         onp.linalg.norm(onp.abs(w_v), axis=0))
    self.assertLess(onp.max(relative_residual), rtol)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype,
          "rng": rng
      } for shape in [(1, 1), (4, 4), (5, 5)]
        for dtype in float_types + complex_types
        for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("tpu")
  def testEighBatching(self, shape, dtype, rng):
    """Checks `vmap(jsp.linalg.eigh)` on a batch of ten Hermitian matrices."""
    _skip_if_unsupported_type(dtype)
    shape = (10, ) + shape
    args = rng(shape, dtype)
    args = (args + onp.conj(T(args))) / 2
    ws, vs = vmap(jsp.linalg.eigh)(args)
    residual = onp.matmul(args, vs) - ws[..., None, :] * vs
    self.assertTrue(onp.all(onp.linalg.norm(residual) < 1e-3))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}_ord={}_axis={}_keepdims={}".format(
                  jtu.format_shape_dtype_string(shape, dtype), ord, axis,
                  keepdims),
          "shape": shape,
          "dtype": dtype,
          "axis": axis,
          "keepdims": keepdims,
          "ord": ord,
          "rng": rng
      } for axis, shape in [(None, (1, )), (None, (7, )), (None, (5, 8)),
                            (0, (9, )), (0, (4, 5)), ((1, ), (10, 7, 3)),
                            ((-2, ), (4, 8)), (-1, (6, 3)),
                            ((0, 2), (3, 4, 5)), ((2, 0), (7, 8, 9)),
                            (None, (7, 8, 11))]
        for keepdims in [False, True]
        for ord in (
            [None] if axis is None and len(shape) > 2
            else [None, 0, 1, 2, 3, -1, -2, -3, np.inf, -np.inf]
            if (axis is None and len(shape) == 1) or isinstance(axis, int) or
            (isinstance(axis, tuple) and len(axis) == 1)
            else [None, 'fro', 1, 2, -1, -2, np.inf, -np.inf, 'nuc'])
        for dtype in float_types + complex_types
        for rng in [jtu.rand_default()]))
  def testNorm(self, shape, dtype, ord, axis, keepdims, rng):
    """Compares `np.linalg.norm` with NumPy over orders, axes and keepdims."""
    _skip_if_unsupported_type(dtype)
    if (ord in ('nuc', 2, -2) and
        (jtu.device_under_test() != "cpu" or
         (isinstance(axis, tuple) and len(axis) == 2))):
      self.skipTest("No adequate SVD implementation available")

    args_maker = lambda: [rng(shape, dtype)]
    onp_fn = partial(onp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    # Older numpy versions promote to float64 unnecessarily..
    check_dtypes = numpy_version >= (1, 15)
    self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
                            check_dtypes=check_dtypes, tol=1e-3)
    self._CompileAndCheck(np_fn, args_maker, check_dtypes=check_dtypes)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_n={}_full_matrices={}_compute_uv={}".format(
                  jtu.format_shape_dtype_string((m, n), dtype), full_matrices,
                  compute_uv),
          "m": m,
          "n": n,
          "dtype": dtype,
          "full_matrices": full_matrices,
          "compute_uv": compute_uv,
          "rng": rng
      } for m in [2, 7, 29, 53]
        for n in [2, 7, 29, 53]
        for dtype in float_types + complex_types
        for full_matrices in [False, True]
        for compute_uv in [False, True]
        for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("tpu")
  def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng):
    """Checks reconstruction, unitarity and singular values of the SVD."""
    _skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng((m, n), dtype)]

    # Norm, adjusted for dimension and type.
    norm = partial(self._eps_scaled_norm, scale=max(m, n), dtype=dtype)

    a, = args_maker()
    out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
    if compute_uv:
      u, s, vh = out
      # Check the reconstructed matrices
      if full_matrices:
        # Only the leading k singular vectors take part in reconstruction.
        k = min(m, n)
        if m < n:
          reconstructed = onp.matmul(s * u, vh[:k, :])
        else:
          reconstructed = onp.matmul(s * u[:, :k], vh)
      else:
        reconstructed = onp.matmul(s * u, vh)
      self.assertTrue(onp.all(norm(a - reconstructed) < 50))

      # Check the unitary properties of the singular vector matrices.
      self.assertTrue(
          onp.all(
              norm(onp.eye(u.shape[1]) - onp.matmul(onp.conj(T(u)), u)) < 10))
      if m >= n:
        self.assertTrue(
            onp.all(
                norm(onp.eye(vh.shape[1]) -
                     onp.matmul(onp.conj(T(vh)), vh)) < 10))
      else:
        self.assertTrue(
            onp.all(
                norm(onp.eye(vh.shape[0]) -
                     onp.matmul(vh, onp.conj(T(vh)))) < 20))
    else:
      self.assertTrue(
          onp.allclose(onp.linalg.svd(a, compute_uv=False), onp.asarray(out),
                       atol=1e-4, rtol=1e-4))

    self._CompileAndCheck(
        partial(np.linalg.svd, full_matrices=full_matrices,
                compute_uv=compute_uv),
        args_maker, check_dtypes=True)
    if not full_matrices:
      svd = partial(np.linalg.svd, full_matrices=False)
      jtu.check_jvp(svd, partial(jvp, svd), (a, ),
                    atol=1e-1 if FLAGS.jax_enable_x64 else jtu.ATOL)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}_fullmatrices={}".format(
                  jtu.format_shape_dtype_string(shape, dtype), full_matrices),
          "shape": shape,
          "dtype": dtype,
          "full_matrices": full_matrices,
          "rng": rng
      } for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]
        for dtype in float_types + complex_types
        for full_matrices in [False, True]
        for rng in [jtu.rand_default()]))
  def testQr(self, shape, dtype, full_matrices, rng):
    """Compares `np.linalg.qr` with a per-matrix NumPy QR reference."""
    _skip_if_unsupported_type(dtype)
    if (onp.issubdtype(dtype, onp.complexfloating) and
        (jtu.device_under_test() == "tpu" or jax.lib.version <= (0, 1, 27))):
      self.skipTest("No complex QR implementation")

    m, n = shape[-2:]
    if full_matrices:
      mode, k = "complete", m
    else:
      mode, k = "reduced", min(m, n)

    a = rng(shape, dtype)
    lq, lr = np.linalg.qr(a, mode=mode)

    # onp.linalg.qr doesn't support batch dimensions. But it seems like an
    # inevitable extension so we support it in our version.
    nq = onp.zeros(shape[:-2] + (m, k), dtype)
    nr = onp.zeros(shape[:-2] + (k, n), dtype)
    for index in onp.ndindex(*shape[:-2]):
      nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

    # Norm, adjusted for dimension and type.
    norm = partial(self._eps_scaled_norm, scale=max(m, n), dtype=dtype)

    def compare_orthogonal(q1, q2):
      # Q is unique up to sign, so normalize the sign first.
      sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
      phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
      q1 *= phases
      self.assertTrue(onp.all(norm(q1 - q2) < 30))

    # Check a ~= qr
    self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

    # Compare the first 'k' vectors of Q; the remainder form an arbitrary
    # orthonormal basis for the null space.
    compare_orthogonal(nq[..., :k], lq[..., :k])

    # Check that q is close to unitary.
    self.assertTrue(
        onp.all(norm(onp.eye(k) - onp.matmul(onp.conj(T(lq)), lq)) < 5))

    if not full_matrices and m >= n:
      jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a, ),
                    atol=1e-3)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype,
          "rng": rng
      } for shape in [(10, 4, 5), (5, 3, 3), (7, 6, 4)]
        for dtype in float_types + complex_types
        for rng in [jtu.rand_default()]))
  def testQrBatching(self, shape, dtype, rng):
    """Checks that `vmap(jsp.linalg.qr)` reconstructs its input."""
    # NOTE(review): `dtype` is parameterized but unused; inputs are always
    # float32. Confirm whether that is intentional before changing it.
    args = rng(shape, np.float32)
    qs, rs = vmap(jsp.linalg.qr)(args)
    self.assertTrue(
        onp.all(onp.linalg.norm(args - onp.matmul(qs, rs)) < 1e-3))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_lhs={}_rhs={}".format(
                  jtu.format_shape_dtype_string(lhs_shape, dtype),
                  jtu.format_shape_dtype_string(rhs_shape, dtype)),
          "lhs_shape": lhs_shape,
          "rhs_shape": rhs_shape,
          "dtype": dtype,
          "rng": rng
      } for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 4), (4, )),
          ((8, 8), (8, 4)),
          ((1, 2, 2), (3, 2)),
          ((2, 1, 3, 3), (2, 4, 3, 4)),
      ]
        for dtype in float_types + complex_types
        for rng in [jtu.rand_default()]))
  def testSolve(self, lhs_shape, rhs_shape, dtype, rng):
    """Compares `np.linalg.solve` with NumPy, including broadcast shapes."""
    _skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(onp.linalg.solve, np.linalg.solve, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype,
          "rng": rng
      } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)]
        for dtype in float_types
        for rng in [jtu.rand_default()]))
  def testInv(self, shape, dtype, rng):
    """Compares `np.linalg.inv` with NumPy on invertible random matrices."""
    _skip_if_unsupported_type(dtype)
    if jtu.device_under_test() == "gpu" and shape == (200, 200):
      self.skipTest("Test is flaky on GPU")

    def args_maker():
      # Redraw until NumPy can invert the sample.
      while True:
        a = rng(shape, dtype)
        try:
          onp.linalg.inv(a)
        except onp.linalg.LinAlgError:
          continue
        return [a]

    self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)

  # Regression test for incorrect type for eigenvalues of a complex matrix.
  @jtu.skip_on_devices("tpu")  # TODO(phawkins): No eigh implementation on TPU.
  def testIssue669(self):
    """Gradient of the summed eigenvalues of a complex identity matrix."""
    def test(x):
      val, _ = np.linalg.eigh(x)
      return np.real(np.sum(val))

    grad_test_jc = jit(grad(jit(test)))
    # `onp.complex` was a deprecated alias of the builtin `complex` and has
    # been removed from recent NumPy releases; complex128 is the same dtype.
    xc = onp.eye(3, dtype=onp.complex128)
    self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)

  def testIssue1151(self):
    """Batched solve is consistent and its Jacobians can be evaluated."""
    A = np.array(onp.random.randn(100, 3, 3), dtype=np.float32)
    b = np.array(onp.random.randn(100, 3), dtype=np.float32)
    x = np.linalg.solve(A, b)
    self.assertAllClose(vmap(np.dot)(A, x), b, atol=1e-3, rtol=1e-3,
                        check_dtypes=True)

    # Smoke checks only: the Jacobians must evaluate without raising, for
    # both the batched and the single-matrix inputs. Results are discarded.
    for lhs, rhs in ((A, b), (A[0], b[0])):
      for argnum in (0, 1):
        jax.jacobian(np.linalg.solve, argnums=argnum)(lhs, rhs)
class LaxRandomTest(jtu.JaxTestCase): def _CheckCollisions(self, samples, nbits): fail_prob = 0.01 # conservative bound on statistical fail prob by Chebyshev nitems = len(samples) nbins = 2**nbits nexpected = nbins * (1 - ((nbins - 1) / nbins)**nitems) ncollisions = len(np.unique(samples)) sq_percent_deviation = ((ncollisions - nexpected) / nexpected)**2 self.assertLess(sq_percent_deviation, 1 / np.sqrt(nexpected * fail_prob)) def _CheckKolmogorovSmirnovCDF(self, samples, cdf): fail_prob = 0.01 # conservative bound on statistical fail prob by Kolmo CDF self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob) def _CheckChiSquared(self, samples, pmf): alpha = 0.01 # significance level, threshold for p-value values, actual_freq = np.unique(samples, return_counts=True) expected_freq = pmf(values) * samples.size # per scipy: "A typical rule is that all of the observed and expected # frequencies should be at least 5." valid = (actual_freq > 5) & (expected_freq > 5) self.assertGreater( valid.sum(), 1, msg='not enough valid frequencies for chi-squared test') _, p_value = scipy.stats.chisquare(actual_freq[valid], expected_freq[valid]) self.assertGreater(p_value, alpha, msg=f'Failed chi-squared test with p={p_value}.\n' 'Expected vs. 
actual frequencies:\n' f'{expected_freq[valid]}\n{actual_freq[valid]}') @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name } for dtype in [np.float32, np.float64])) def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype): if not FLAGS.jax_enable_x64 and jnp.issubdtype(dtype, np.float64): raise SkipTest("can't test float64 agreement") bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64 numpy_bits = np.array(1., dtype).view(bits_dtype) xla_bits = api.jit(lambda: lax.bitcast_convert_type( np.array(1., dtype), bits_dtype))() self.assertEqual(numpy_bits, xla_bits) def testThreefry2x32(self): # We test the hash by comparing to known values provided in the test code of # the original reference implementation of Threefry. For the values, see # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32 def result_to_hex(result): return tuple([hex(x.copy()).rstrip("L") for x in result]) expected = ("0x6b200159", "0x99ba4efe") result = random.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0])) self.assertEqual(expected, result_to_hex(result)) expected = ("0x1cb996fc", "0xbb002be7") result = random.threefry_2x32(np.uint32([-1, -1]), np.uint32([-1, -1])) self.assertEqual(expected, result_to_hex(result)) expected = ("0xc4923a9c", "0x483df7a0") result = random.threefry_2x32(np.uint32([0x13198a2e, 0x03707344]), np.uint32([0x243f6a88, 0x85a308d3])) self.assertEqual(expected, result_to_hex(result)) def testThreefry2x32Large(self): n = 10000000 result = random.threefry_2x32( (np.uint32(0x13198a2e), np.uint32(0x03707344)), jnp.concatenate([ jnp.full((n, ), 0x243f6a88, jnp.uint32), jnp.full((n, ), 0x85a308d3, jnp.uint32) ])) np.testing.assert_equal(result[:n], np.full((n, ), 0xc4923a9c, dtype=np.uint32)) np.testing.assert_equal(result[n:], np.full((n, ), 0x483df7a0, dtype=np.uint32)) def 
testRngRandomBitsViewProperty(self): # TODO: add 64-bit if it ever supports this property. # TODO: will this property hold across endian-ness? N = 10 key = random.PRNGKey(1701) nbits = [8, 16, 32] if jtu.device_under_test() == "tpu": # U8 and U16 are not supported on TPU. nbits = [32] rand_bits = [ random._random_bits(key, n, (N * 64 // n, )) for n in nbits ] rand_bits_32 = np.array( [np.array(r).view(np.uint32) for r in rand_bits]) print(rand_bits_32) assert np.all(rand_bits_32 == rand_bits_32[0]) def testRngRandomBits(self): # Test specific outputs to ensure consistent random values between JAX versions. key = random.PRNGKey(1701) # U8 and U16 are not supported on TPU. if jtu.device_under_test() != "tpu": bits8 = random._random_bits(key, 8, (3, )) expected8 = np.array([216, 115, 43], dtype=np.uint8) self.assertArraysEqual(bits8, expected8, check_dtypes=True) bits16 = random._random_bits(key, 16, (3, )) expected16 = np.array([41682, 1300, 55017], dtype=np.uint16) self.assertArraysEqual(bits16, expected16, check_dtypes=True) bits32 = random._random_bits(key, 32, (3, )) expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32) self.assertArraysEqual(bits32, expected32, check_dtypes=True) bits64 = random._random_bits(key, 64, (3, )) if FLAGS.jax_enable_x64: expected64 = np.array([ 3982329540505020460, 16822122385914693683, 7882654074788531506 ], dtype=np.uint64) else: expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32) self.assertArraysEqual(bits64, expected64, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name } for dtype in [np.float32, np.float64])) def testRngUniform(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.uniform(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckCollisions(samples, 
jnp.finfo(dtype).nmant) self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name } for dtype in [np.int32, np.int64])) def testRngRandint(self, dtype): lo = 5 hi = 10 key = random.PRNGKey(0) rand = lambda key: random.randint(key, (10000, ), lo, hi, dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self.assertTrue(np.all(lo <= samples)) self.assertTrue(np.all(samples < hi)) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name } for dtype in [np.float32, np.float64])) def testNormal(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.normal(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name } for dtype in [np.float32, np.float64, np.int32, np.int64])) def testShuffle(self, dtype): key = random.PRNGKey(0) x = np.arange(100).astype(dtype) rand = lambda key: random.shuffle(key, x) crand = api.jit(rand) with self.assertWarns(FutureWarning): perm1 = rand(key) with self.assertWarns(FutureWarning): perm2 = crand(key) self.assertAllClose(perm1, perm2, check_dtypes=True) self.assertFalse(np.all(perm1 == x)) # seems unlikely! 
self.assertAllClose(np.sort(perm1), x, check_dtypes=False) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)), "dtype": np.dtype(dtype).name, "shape": shape } for dtype in [np.float32, np.float64, np.int32, np.int64] for shape in [100, (10, 10), (10, 5, 2)])) def testPermutationArray(self, dtype, shape): key = random.PRNGKey(0) x = jnp.arange(jnp.prod(shape)).reshape(shape).astype(dtype) rand = lambda key: random.permutation(key, x) crand = api.jit(rand) perm1 = rand(key) perm2 = crand(key) self.assertAllClose(perm1, perm2, check_dtypes=True) self.assertFalse(np.all(perm1 == x)) # seems unlikely! self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False) self.assertArraysAllClose( x, jnp.arange(jnp.prod(shape)).reshape(shape).astype(dtype), check_dtypes=True) def testPermutationInteger(self): key = random.PRNGKey(0) x = 100 rand = lambda key: random.permutation(key, x) crand = api.jit(rand) perm1 = rand(key) perm2 = crand(key) self.assertAllClose(perm1, perm2, check_dtypes=True) self.assertEqual(perm1.dtype, perm2.dtype) self.assertFalse(np.all(perm1 == np.arange(100))) # seems unlikely! self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False) def testPermutationErrors(self): key = random.PRNGKey(0) with self.assertRaises(TypeError): random.permutation(key, 10.) 
with self.assertRaises(core.ConcretizationTypeError): api.jit(random.permutation)(key, 10) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_p={}_{}".format(p, dtype), "p": p, "dtype": np.dtype(dtype).name } for p in [0.1, 0.5, 0.9] for dtype in [np.float32, np.float64])) def testBernoulli(self, p, dtype): key = random.PRNGKey(0) p = np.array(p, dtype=dtype) rand = lambda key, p: random.bernoulli(key, p, (10000, )) crand = api.jit(rand) uncompiled_samples = rand(key, p) compiled_samples = crand(key, p) for samples in [uncompiled_samples, compiled_samples]: self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_p={}_{}_{}".format(p, dtype, sample_shape), "p": p, "axis": axis, "dtype": np.dtype(dtype).name, 'sample_shape': sample_shape } for (p, axis) in [ ([.25] * 4, -1), ([.1, .2, .3, .4], -1), ([[.5, .5], [.1, .9]], 1), ([[.5, .1], [.5, .9]], 0), ] for sample_shape in [(10000, ), (5000, 2)] for dtype in [np.float32, np.float64])) def testCategorical(self, p, axis, dtype, sample_shape): key = random.PRNGKey(0) p = np.array(p, dtype=dtype) logits = np.log(p) - 42 # test unnormalized out_shape = tuple(np.delete(logits.shape, axis)) shape = sample_shape + out_shape rand = lambda key, p: random.categorical( key, logits, shape=shape, axis=axis) crand = api.jit(rand) uncompiled_samples = rand(key, p) compiled_samples = crand(key, p) if axis < 0: axis += len(logits.shape) for samples in [uncompiled_samples, compiled_samples]: assert samples.shape == shape samples = jnp.reshape(samples, (10000, ) + out_shape) if len(p.shape[:-1]) > 0: ps = np.transpose(p, (1, 0)) if axis == 0 else p for cat_samples, cat_p in zip(samples.transpose(), ps): self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x]) else: self._CheckChiSquared(samples, pmf=lambda x: p[x]) def testBernoulliShape(self): key = random.PRNGKey(0) x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2)) 
assert x.shape == (3, 2) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_a={}_b={}_{}".format(a, b, dtype), "a": a, "b": b, "dtype": np.dtype(dtype).name } for a in [0.2, 5.] for b in [0.2, 5.] for dtype in [np.float64])) # NOTE: KS test fails with float32 def testBeta(self, a, b, dtype): if not FLAGS.jax_enable_x64: raise SkipTest("skip test except on X64") key = random.PRNGKey(0) rand = lambda key, a, b: random.beta(key, a, b, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, a, b) compiled_samples = crand(key, a, b) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name } for dtype in [np.float32, np.float64])) def testCauchy(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.cauchy(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_alpha={}_{}".format(alpha, dtype), "alpha": alpha, "dtype": np.dtype(dtype).name } for alpha in [ np.array([0.2, 1., 5.]), ] for dtype in [np.float32, np.float64])) def testDirichlet(self, alpha, dtype): key = random.PRNGKey(0) rand = lambda key, alpha: random.dirichlet(key, alpha, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, alpha) compiled_samples = crand(key, alpha) for samples in [uncompiled_samples, compiled_samples]: self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype), check_dtypes=True) alpha_sum = sum(alpha) for i, a in enumerate(alpha): self._CheckKolmogorovSmirnovCDF( samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf) @parameterized.named_parameters( jtu.cases_from_list({ 
"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name } for dtype in [np.float32, np.float64])) def testExponential(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.exponential(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_a={}_{}".format(a, dtype), "a": a, "dtype": np.dtype(dtype).name } for a in [0.1, 1., 10.] for dtype in [np.float32, np.float64])) def testGamma(self, a, dtype): key = random.PRNGKey(0) rand = lambda key, a: random.gamma(key, a, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, a) compiled_samples = crand(key, a) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf) def testGammaShape(self): key = random.PRNGKey(0) x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_a={}".format(alpha), "alpha": alpha } for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4])) def testGammaGrad(self, alpha): rng = random.PRNGKey(0) alphas = np.full((100, ), alpha) z = random.gamma(rng, alphas) actual_grad = api.grad(lambda x: random.gamma(rng, x).sum())(alphas) eps = 0.01 * alpha / (1.0 + np.sqrt(alpha)) cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps) - scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps) pdf = scipy.stats.gamma.pdf(z, alpha) expected_grad = -cdf_dot / pdf self.assertAllClose( actual_grad, expected_grad, check_dtypes=True, rtol=2e-2 if jtu.device_under_test() == "tpu" else 5e-4) def testGammaGradType(self): # Regression test for https://github.com/google/jax/issues/2130 key = random.PRNGKey(0) a = jnp.array(1., dtype=jnp.float32) b = jnp.array(3., dtype=jnp.float32) f = 
lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y # Should not crash with a type error. api.vjp(f, a, b) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lam={}_{}".format(lam, dtype), "lam": lam, "dtype": np.dtype(dtype).name } for lam in [0.5, 3, 9, 11, 50, 500] for dtype in [np.int32, np.int64])) def testPoisson(self, lam, dtype): key = random.PRNGKey(0) rand = lambda key, lam: random.poisson(key, lam, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, lam) compiled_samples = crand(key, lam) for samples in [uncompiled_samples, compiled_samples]: self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf) # TODO(shoyer): determine error bounds for moments more rigorously (e.g., # based on the central limit theorem). self.assertAllClose(samples.mean(), lam, rtol=0.01, check_dtypes=False) self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False) def testPoissonBatched(self): key = random.PRNGKey(0) lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)]) samples = random.poisson(key, lam, shape=(20000, )) self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf) self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf) def testPoissonShape(self): key = random.PRNGKey(0) x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2)) assert x.shape == (3, 2) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name } for dtype in [np.float32, np.float64])) def testGumbel(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.gumbel(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name } for dtype in 
[np.float32, np.float64])) def testLaplace(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.laplace(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name } for dtype in [np.float32, np.float64])) def testLogistic(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.logistic(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_b={}_{}".format(b, dtype), "b": b, "dtype": np.dtype(dtype).name } for b in [0.1, 1., 10.] for dtype in [np.float32, np.float64])) def testPareto(self, b, dtype): key = random.PRNGKey(0) rand = lambda key, b: random.pareto(key, b, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, b) compiled_samples = crand(key, b) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf) def testParetoShape(self): key = random.PRNGKey(0) x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_df={}_{}".format(df, dtype), "df": df, "dtype": np.dtype(dtype).name } for df in [0.1, 1., 10.] 
for dtype in [np.float32, np.float64])) @jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times def testT(self, df, dtype): key = random.PRNGKey(0) rand = lambda key, df: random.t(key, df, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, df) compiled_samples = crand(key, df) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_{}D_{}".format(dim, np.dtype(dtype).name), "dim": dim, "dtype": dtype } for dim in [1, 3, 5] for dtype in [np.float32, np.float64])) def testMultivariateNormal(self, dim, dtype): r = np.random.RandomState(dim) mean = r.randn(dim) cov_factor = r.randn(dim, dim) cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim) key = random.PRNGKey(0) rand = partial(random.multivariate_normal, mean=mean, cov=cov, shape=(10000, )) crand = api.jit(rand) uncompiled_samples = np.asarray(rand(key), np.float64) compiled_samples = np.asarray(crand(key), np.float64) inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov), lower=True)[0] for samples in [uncompiled_samples, compiled_samples]: centered = samples - mean whitened = np.einsum('nj,ij->ni', centered, inv_scale) # This is a quick-and-dirty multivariate normality check that tests that a # uniform mixture of the marginals along the covariance matrix's # eigenvectors follow a standard normal distribution. 
self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf) def testMultivariateNormalCovariance(self): # test code based on https://github.com/google/jax/issues/1869 N = 100000 cov = jnp.array([[0.19, 0.00, -0.13, 0.00], [0.00, 0.29, 0.00, -0.23], [-0.13, 0.00, 0.39, 0.00], [0.00, -0.23, 0.00, 0.49]]) mean = jnp.zeros(4) out_np = np.random.RandomState(0).multivariate_normal(mean, cov, N) key = random.PRNGKey(0) out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N, )) var_np = out_np.var(axis=0) var_jnp = out_jnp.var(axis=0) self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2, check_dtypes=False) var_np = np.cov(out_np, rowvar=False) var_jnp = np.cov(out_jnp, rowvar=False) self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2, check_dtypes=False) def testIssue222(self): x = random.randint(random.PRNGKey(10003), (), 0, 0) assert x == 0 def testFoldIn(self): key = random.PRNGKey(0) keys = [random.fold_in(key, i) for i in range(10)] assert np.unique(np.ravel(keys)).shape == (20, ) def testStaticShapeErrors(self): if config.read("jax_disable_jit"): raise SkipTest("test only relevant when jit enabled") @api.jit def feature_map(n, d, sigma=1.0, seed=123): key = random.PRNGKey(seed) W = random.normal(key, (d, n)) / sigma w = random.normal(key, (d, )) / sigma b = 2 * jnp.pi * random.uniform(key, (d, )) phi = lambda x, t: jnp.sqrt(2.0 / d) * jnp.cos( jnp.matmul(W, x) + w * t + b) return phi self.assertRaisesRegex(TypeError, 'Shapes must be 1D.*', lambda: feature_map(5, 3)) def testIssue756(self): key = random.PRNGKey(0) w = random.normal(key, ()) if FLAGS.jax_enable_x64: self.assertEqual(np.result_type(w), np.float64) else: self.assertEqual(np.result_type(w), np.float32) def testIssue1789(self): def f(x): return random.gamma(random.PRNGKey(0), x) grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2)) def testNoOpByOpUnderHash(self): def fail(*args, **kwargs): assert False apply_primitive, xla.apply_primitive = xla.apply_primitive, fail 
try: out = random.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32)) finally: xla.apply_primitive = apply_primitive def testPRNGValues(self): # Test to ensure consistent random values between JAX versions k = random.PRNGKey(0) randints = random.randint(k, (3, 3), 0, 8) if FLAGS.jax_enable_x64: self.assertAllClose(random.randint(k, (3, 3), 0, 8), np.array([[7, 2, 6], [2, 1, 0], [6, 7, 7]], dtype='int64'), check_dtypes=True) else: self.assertAllClose(random.randint(k, (3, 3), 0, 8), np.array([[2, 1, 3], [6, 1, 5], [6, 3, 4]], dtype='int32'), check_dtypes=True) self.assertAllClose( random.split(k, 4), np.array([[2285895361, 1501764800], [1518642379, 4090693311], [433833334, 4221794875], [839183663, 3740430601]], dtype='uint32'), check_dtypes=True) self.assertAllClose(random.fold_in(k, 4), np.array([2285895361, 433833334], dtype='uint32'), check_dtypes=True)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests of the CG and GMRES solvers in jax.scipy.sparse.linalg."""

  def _fetch_preconditioner(self, preconditioner, A, rng=None,
                            return_function=False):
    """
    Returns one of various preconditioning matrices depending on the identifier
    `preconditioner` and the input matrix A whose inverse it supposedly
    approximates.

    `preconditioner` may be 'identity', 'random' or 'exact'; any other value
    (including None) yields None. When `return_function` is true and a matrix
    was built, a function computing the high-precision product with that
    matrix is returned instead of the matrix itself.
    """
    if preconditioner == 'identity':
      M = np.eye(A.shape[0], dtype=A.dtype)
    elif preconditioner == 'random':
      if rng is None:
        rng = jtu.rand_default(self.rng())
      M = np.linalg.inv(rand_sym_pos_def(rng, A.shape, A.dtype))
    elif preconditioner == 'exact':
      M = np.linalg.inv(A)
    else:
      M = None

    if M is None or not return_function:
      return M
    else:
      return lambda x: jnp.dot(M, x, precision=lax.Precision.HIGHEST)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(4, 4), (7, 7)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']))
  def test_cg_against_scipy(self, shape, dtype, preconditioner):
    """CG matches scipy after 1 and 3 iterations and np.linalg.solve at the end."""
    if not config.FLAGS.jax_enable_x64:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rand_sym_pos_def(rng, shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=1),
        partial(lax_cg, M=M, maxiter=1),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        partial(scipy_cg, M=M, maxiter=3),
        partial(lax_cg, M=M, maxiter=3),
        args_maker,
        tol=1e-12)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_cg, M=M, atol=1e-10),
        args_maker,
        tol=1e-6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(2, 2)]
      for dtype in float_types + complex_types))
  def test_cg_as_solve(self, shape, dtype):
    """CG agrees with np.linalg.solve eagerly, under jit, and in its gradients."""
    rng = jtu.rand_default(self.rng())
    a = rng(shape, dtype)
    b = rng(shape[:1], dtype)

    expected = np.linalg.solve(posify(a), b)
    actual = lax_cg(posify(a), b)
    self.assertAllClose(expected, actual)

    actual = jit(lax_cg)(posify(a), b)
    self.assertAllClose(expected, actual)

    # numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(
        lambda x, y: lax_cg(posify(x), y),
        (a, b), order=2, rtol=1e-2)

  def test_cg_ndarray(self):
    """CG accepts a linear function and a 2-D right-hand side."""
    A = lambda x: 2 * x
    b = jnp.arange(9.0).reshape((3, 3))
    expected = b / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual)

  def test_cg_pytree(self):
    """CG accepts a dict-valued right-hand side and returns the same keys."""
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=6)
    self.assertAlmostEqual(expected["b"], actual["b"], places=6)

  def test_cg_errors(self):
    """CG rejects x0 whose tree structure or shape differs from b."""
    A = lambda x: x
    b = jnp.zeros((2,))
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching tree structure"):
      jax.scipy.sparse.linalg.cg(A, {'x': b}, {'y': b})
    with self.assertRaisesRegex(
        ValueError, "x0 and b must have matching shape"):
      jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis])

  def test_cg_without_pytree_equality(self):
    """CG works on a registered pytree class that defines no __eq__."""

    @register_pytree_node_class
    class MinimalPytree:
      def __init__(self, value):
        self.value = value
      def tree_flatten(self):
        return [self.value], None
      @classmethod
      def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    A = lambda x: MinimalPytree(2 * x.value)
    b = MinimalPytree(jnp.arange(5.0))
    expected = b.value / 2
    actual, _ = jax.scipy.sparse.linalg.cg(A, b)
    self.assertAllClose(expected, actual.value)

  # GMRES
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
            jtu.format_shape_dtype_string(shape, dtype),
            preconditioner,
            solve_method),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
       "solve_method": solve_method}
      for shape in [(3, 3)]
      for dtype in [np.float64, np.complex128]
      for preconditioner in [None, 'identity', 'exact', 'random']
      for solve_method in ['incremental', 'batched']))
  def test_gmres_against_scipy(
      self, shape, dtype, preconditioner, solve_method):
    """GMRES matches scipy for small restart/maxiter and np.linalg.solve at the end."""
    if not config.FLAGS.jax_enable_x64:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    b = rng(shape[:1], dtype)
    M = self._fetch_preconditioner(preconditioner, A, rng=rng)

    def args_maker():
      return A, b

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=1),
        partial(lax_gmres, M=M, restart=1, maxiter=1,
                solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=1, maxiter=2),
        partial(lax_gmres, M=M, restart=1, maxiter=2,
                solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        partial(scipy_gmres, M=M, restart=2, maxiter=1),
        partial(lax_gmres, M=M, restart=2, maxiter=1,
                solve_method=solve_method),
        args_maker,
        tol=1e-10)

    self._CheckAgainstNumpy(
        np.linalg.solve,
        partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method),
        args_maker,
        tol=1e-10)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         preconditioner,
         solve_method),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
       "solve_method": solve_method}
      for shape in [(2, 2), (7, 7)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      for solve_method in ['batched', 'incremental']
      ))
  def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
                                    solve_method):
    """GMRES recovers the all-ones solution of an identity system."""
    A = jnp.eye(shape[1], dtype=dtype)

    solution = jnp.ones(shape[1], dtype=dtype)

    @jax.tree_util.Partial
    def A_mv(x):
      return matmul_high_precision(A, x)

    rng = jtu.rand_default(self.rng())
    M = self._fetch_preconditioner(preconditioner, A, rng=rng,
                                   return_function=True)
    b = A_mv(solution)
    restart = shape[-1]
    tol = shape[0] * jnp.finfo(dtype).eps
    x, _ = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
                                         restart=restart,
                                         M=M, solve_method=solve_method)
    # Compare the dtype itself: `dtype.kind` is a one-character code and can
    # never equal np.float64 / np.complex128.
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}_solve_method={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         preconditioner,
         solve_method),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner,
       "solve_method": solve_method}
      for shape in [(2, 2), (4, 4)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity', 'exact']
      for solve_method in ['incremental', 'batched']
      ))
  def test_gmres_on_random_system(self, shape, dtype, preconditioner,
                                  solve_method):
    """GMRES recovers a random solution of a random dense system."""
    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)

    solution = rng(shape[1:], dtype)

    @jax.tree_util.Partial
    def A_mv(x):
      return matmul_high_precision(A, x)

    M = self._fetch_preconditioner(preconditioner, A, rng=rng,
                                   return_function=True)
    b = A_mv(solution)
    restart = shape[-1]
    tol = shape[0] * jnp.finfo(A.dtype).eps
    x, _ = jax.scipy.sparse.linalg.gmres(A_mv, b, tol=tol, atol=tol,
                                         restart=restart,
                                         M=M, solve_method=solve_method)
    # Compare the dtype itself: `dtype.kind` is a one-character code and can
    # never equal np.float64 / np.complex128.
    using_x64 = solution.dtype in (np.float64, np.complex128)
    solution_tol = 1e-8 if using_x64 else 1e-4
    self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

  def test_gmres_pytree(self):
    """GMRES accepts a dict-valued right-hand side and returns the same keys."""
    A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
    b = {"a": 1.0, "b": -4.0}
    expected = {"a": 4.0, "b": -6.0}
    actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
    self.assertEqual(expected.keys(), actual.keys())
    self.assertAlmostEqual(expected["a"], actual["a"], places=5)
    self.assertAlmostEqual(expected["b"], actual["b"], places=5)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_preconditioner={}".format(
         jtu.format_shape_dtype_string(shape, dtype),
         preconditioner),
       "shape": shape, "dtype": dtype, "preconditioner": preconditioner}
      for shape in [(2, 2), (3, 3)]
      for dtype in float_types + complex_types
      for preconditioner in [None, 'identity']))
  def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
    """
    The Arnoldi decomposition within GMRES is correct.
    """
    if not config.FLAGS.jax_enable_x64:
      raise unittest.SkipTest("requires x64 mode")

    rng = jtu.rand_default(self.rng())
    A = rng(shape, dtype)
    if preconditioner is None:
      M = lambda x: x
    else:
      M = self._fetch_preconditioner(preconditioner, A, rng=rng,
                                     return_function=True)

    n = shape[0]
    x0 = rng(shape[:1], dtype)
    # Q holds the normalized start vector in column 0; the Arnoldi iterations
    # below fill in the rest.
    Q = np.zeros((n, n + 1), dtype=dtype)
    Q[:, 0] = x0/jnp.linalg.norm(x0)
    Q = jnp.array(Q)
    H = jnp.eye(n, n + 1, dtype=dtype)

    @jax.tree_util.Partial
    def A_mv(x):
      return matmul_high_precision(A, x)

    for k in range(n):
      Q, H, _ = jax._src.scipy.sparse.linalg._kth_arnoldi_iteration(
          k, A_mv, M, Q, H)
    # Check Q[:, :n]^H A Q[:, :n] against the leading rows of H^T.
    QA = matmul_high_precision(Q[:, :n].conj().T, A)
    QAQ = matmul_high_precision(QA, Q[:, :n])
    self.assertAllClose(QAQ, H.T[:n, :], rtol=1e-5,
                        atol=1e-5)
class BatchingTest(jtu.JaxTestCase): def testConstantFunction(self): ans = vmap(lambda x: 3)(np.ones(4)) expected = 3 * np.ones(4) self.assertAllClose(ans, expected, check_dtypes=False) def testNestedBatchingMatMat(self): matvec = vmap(jnp.vdot, in_axes=(0, None)) matmat = vmap(matvec, in_axes=(None, 1), out_axes=1) R = np.random.RandomState(0).randn A = R(4, 3) B = R(3, 2) ans = matmat(A, B) expected = np.dot(A, B) self.assertAllClose( ans, expected, check_dtypes=False, rtol={np.float32:1e-2} if jtu.device_under_test() == "tpu" else None) jaxpr = make_jaxpr(matmat)(A, B) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) def testPerExampleGradients(self): def predict(params, inputs): for W, b in params: outputs = jnp.dot(W, inputs) + b inputs = jnp.tanh(outputs) return outputs def loss(params, data): inputs, targets = data predictions = predict(params, inputs) return jnp.sum((predictions - targets)**2) batch_size = 5 layer_sizes = [3, 2, 4] R = np.random.RandomState(0).randn params = [(R(m, n), R(m)) for m, n in zip(layer_sizes[1:], layer_sizes[:-1])] input_batch = R(5, 3) target_batch = R(5, 4) batch = (input_batch, target_batch) ans = vmap(partial(grad(loss), params))(batch) for ans_pair, param_pair in zip(ans, params): dW, db = ans_pair W, b = param_pair self.assertEqual(dW.shape, (batch_size,) + W.shape) self.assertEqual(db.shape, (batch_size,) + b.shape) def testJacobians(self): def jacbwd(f, x): y, pullback = vjp(f, x) std_basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y)) jac_flat, = vmap(pullback, out_axes=np.ndim(y))(std_basis) return jac_flat.reshape(np.shape(y) + np.shape(x)) def jacfwd(f, x): pushfwd = lambda v: jvp(f, (x,), (v,)) std_basis = np.eye(np.size(x)).reshape((-1,) + np.shape(x)) y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis) return jac_flat.reshape(np.shape(y) + np.shape(x)) R = np.random.RandomState(0).randn A = R(4, 3) b = R(4) f = lambda x: jnp.tanh(jnp.dot(A, x) + b) x = R(3) self.assertAllClose(jacfwd(f, x), jacbwd(f, x), 
check_dtypes=False) def testBatchOfCompile(self): side = [] @jit def f(x): side.append(None) return x + x g = jit(vmap(f)) self.assertAllClose(g(np.ones(2)), 2 * np.ones(2), check_dtypes=False) self.assertEqual(len(side), 1) self.assertAllClose(g(2 * np.ones(2)), 4 * np.ones(2), check_dtypes=False) self.assertEqual(len(side), 1) def testSliceLax(self): fun = lambda x: lax.slice(x, (2,), (4,)) R = np.random.RandomState(0).randn x = R(5, 10) ans = vmap(fun)(x) expected_ans = x[:, 2:4] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testSliceNumpy(self): fun = lambda x: x[:, 2] R = np.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(fun)(x) expected_ans = x[:, :, 2] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testRevLax(self): fun = lambda x: lax.rev(x, [0]) R = np.random.RandomState(0).randn x = R(2, 3) ans = vmap(fun)(x) expected_ans = x[:, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (1,), 1)(x) expected_ans = x[::-1, :] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testRevNumpy(self): fun = lambda x: x[:, ::-1] R = np.random.RandomState(0).randn x = R(3, 2, 4) ans = vmap(fun)(x) expected_ans = x[:, :, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (1,), 1)(x) expected_ans = x[:, :, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (2,), 2)(x) expected_ans = x[:, ::-1, :] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testNpMaximum(self): fun = lambda x: jnp.maximum(x, 0.0) R = np.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(fun)(x) expected_ans = np.maximum(x, 0.0) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testNpGtrThan(self): R = np.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(lambda x: x > 1.0)(x) expected_ans = x > 1.0 self.assertAllClose(ans, expected_ans) def testNpMaximumPerExampleGrad(self): R = np.random.RandomState(0).randn x = R(10, 
5) W = R(5, 5) fun = lambda W, x: jnp.sum(jnp.maximum(jnp.dot(x, W), 0.0) ** 2) ans = vmap(partial(grad(fun), W))(x) W_t = jnp.transpose(W) for i in range(10): x_ex = x[i:i + 1] expected_ans = 2.0 * jnp.dot( jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0), x_ex) expected_ans = jnp.transpose(expected_ans) self.assertAllClose( ans[i], expected_ans, check_dtypes=False, atol={np.float32:5e-2} if jtu.device_under_test() == "tpu" else None) def testDotGeneral(self): R = np.random.RandomState(0).randn x = R(10, 3, 4, 5) y = R(10, 3, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun)(x, y) expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))]) self.assertAllClose(ans, expected) x = R(3, 4, 10, 5) y = R(3, 10, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(2, 1))(x, y) expected = np.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)]) self.assertAllClose(ans, expected) x = R(3, 4, 5, 10) y = R(3, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(3, None))(x, y) expected = np.stack([fun(x[..., i], y) for i in range(10)]) self.assertAllClose(ans, expected) x = R(3, 4, 5) y = R(3, 5, 10, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(None, 2))(x, y) expected = np.stack([fun(x, y[..., i, :]) for i in range(10)]) self.assertAllClose(ans, expected) x = R(4) y = R(4, 10) fun = lambda x, y: lax.dot_general(x, y, [((0,), (0,)), ((), ())]) ans = vmap(fun, in_axes=(None, 1))(x, y) expected = np.stack([fun(x, y[..., i]) for i in range(10)]) self.assertAllClose(ans, expected) def testDot(self): # these tests are based on @shoyer's notebook studying gufuncs def vecvec(a, b): dot = jnp.dot for ndim in range(1, max(a.ndim, b.ndim)): a_ax = 0 if a.ndim > ndim else None b_ax = 0 if b.ndim > ndim else None dot = vmap(dot, in_axes=(a_ax, b_ax)) return dot(a, b) assert 
vecvec(jnp.zeros((3,)), jnp.zeros((3,))).shape == () assert vecvec(jnp.zeros((2, 3)), jnp.zeros((3,))).shape == (2,) assert vecvec(jnp.zeros((4, 2, 3)), jnp.zeros((3,))).shape == (4, 2) def testDot2(self): R = np.random.RandomState(0).randn xs = R(10, 3) ys = R(10, 3) ans = vmap(jnp.dot)(xs, ys) expected = np.einsum('ni,ni->n', xs, ys) self.assertAllClose(ans, expected, check_dtypes=False) def testDot3(self): R = np.random.RandomState(0).randn xs = R(5, 8, 10) ys = R(10, 1) ans = vmap(jnp.dot, in_axes=(1, None))(xs, ys) expected = np.einsum('inj,jk->nik', xs, ys) self.assertAllClose(ans, expected, check_dtypes=False) def testDot4(self): R = np.random.RandomState(0).randn xs = R(3, 2) ys = R(3) ans = vmap(jnp.dot, in_axes=(1, None))(xs, ys) expected = np.einsum('ij,i->j', xs, ys) self.assertAllClose(ans, expected, check_dtypes=False) def testPad(self): R = np.random.RandomState(0).randn fun = lambda x: lax.pad(x, np.float32(0), [(1, 2, 1)]) x = R(5, 10).astype(np.float32) ans = vmap(fun)(x) expected_ans = jnp.stack(list(map(fun, x))) self.assertAllClose(ans, expected_ans, check_dtypes=False) fun = lambda x: lax.pad(x, np.float32(0), [(1, 2, 1), (0, 1, 0)]) x = R(5, 10, 3).astype(np.float32) ans = vmap(fun)(x) expected_ans = jnp.stack(list(map(fun, x))) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testConcatenate(self): R = lambda *shape: np.random.RandomState(0).randn(*shape).astype(np.float32) fun = lambda *args: lax.concatenate(args, dimension=0) x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3) ans = vmap(fun, in_axes=(0, 1, None))(x, y, z) expected_ans = np.concatenate([x, np.swapaxes(y, 0, 1), np.broadcast_to(z, (10, 4, 3))], 1) self.assertAllClose(ans, expected_ans, check_dtypes=False) fun = lambda *args: lax.concatenate(args, dimension=1) x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10) ans = vmap(fun, in_axes=(0, None, 2))(x, y, z) expected_ans = np.concatenate([x, np.broadcast_to(y, (10, 2, 3)), np.moveaxis(z, 2, 0)], 2) self.assertAllClose(ans, 
expected_ans, check_dtypes=False) def testJacobianIssue54(self): # test modeling the code in https://github.com/google/jax/issues/54 def func(xs): return jnp.array(list(xs)) xs = jnp.ones((5, 1)) jacrev(func)(xs) # don't crash jacfwd(func)(xs) # don't crash def testAny(self): # test modeling the code in https://github.com/google/jax/issues/108 ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]])) expected = jnp.array([True, False]) self.assertAllClose(ans, expected) @jtu.skip_on_devices("tpu") def testHessian(self): # test based on code from sindhwani@google def fun(x, t): return jnp.sum(jnp.power(jnp.maximum(x, 0.0), 2)) + t x = np.array([-1., -0.5, 0., 0.5, 1.0]) ans = hessian(lambda x: fun(x, 0.0))(x) expected = np.array([[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0.,0.5, 0., 0.], [0., 0., 0., 2., 0.], [0., 0., 0., 0., 2.]]) self.assertAllClose(ans, expected, check_dtypes=False) def testDynamicSlice(self): # test dynamic_slice via numpy indexing syntax # see https://github.com/google/jax/issues/1613 for an explanation of why we # need to use np rather than np to create x and idx x = jnp.arange(30).reshape((10, 3)) ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1) expected = x[:, 1] self.assertAllClose(ans, expected, check_dtypes=False) idx = jnp.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx) expected = x[np.arange(10), idx] self.assertAllClose(ans, expected, check_dtypes=False) x = jnp.arange(3) idx = jnp.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx) expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testDynamicUpdateSlice(self): x = np.random.randn(10, 3) y = np.random.randn(10) ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0), in_axes=(0, 0, None))(x, y, 1) expected = x.copy() expected[:, 1] = y self.assertAllClose(ans, expected, check_dtypes=False) x = np.random.randn(3) idx = np.array([0, 1, 2, 1, 0] * 2) ans = 
vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
               in_axes=(None, 0, 0))(x, y, idx)
    expected = np.broadcast_to(x, (10, 3)).copy()
    expected[np.arange(10), idx] = y
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testRandom(self):
    """vmap over PRNG keys matches sampling each key separately."""
    seeds = vmap(random.PRNGKey)(np.arange(10))
    ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
    expected = np.stack([random.normal(random.PRNGKey(seed), (3, 2))
                         for seed in np.arange(10)])
    self.assertAllClose(ans, expected, check_dtypes=False)
    # Every sampled value should be distinct across the 10 batched draws.
    assert len(np.unique(ans)) == 10 * 3 * 2

  def testSort(self):
    """vmap of lax.sort over several in/out axis choices."""
    v = np.arange(12)[::-1].reshape(3, 4)

    sv = vmap(partial(lax.sort, dimension=0), (0,))(v)
    self.assertAllClose(sv, v[:, ::-1])

    sv = vmap(partial(lax.sort, dimension=-1), (0,))(v)
    self.assertAllClose(sv, v[:, ::-1])

    sv = vmap(partial(lax.sort, dimension=0), (1,))(v)
    self.assertAllClose(sv, v[::-1, :].T)

    sv = vmap(partial(lax.sort, dimension=0), (1,), 1)(v)
    self.assertAllClose(sv, v[::-1, :])

  def testSortKeyVal(self):
    """vmap of lax.sort_key_val with mapped, transposed and unmapped operands."""
    k = np.arange(12)[::-1].reshape(3, 4)
    v = np.random.RandomState(0).permutation(12).reshape(3, 4)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v)
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, v[:, ::-1])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v)
    self.assertAllClose(sk, k[::-1, :])
    self.assertAllClose(sv, v[::-1, :])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T)
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, v[:, ::-1])

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v)
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, v[:, ::-1])

    # Keys are not mapped here; the expected keys are broadcast to the batch.
    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v)
    self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4)))
    self.assertAllClose(sv, v[:, ::-1])

    # Values are not mapped here; the expected values are broadcast instead.
    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0])
    self.assertAllClose(sk, k[:, ::-1])
    self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4)))

  def testConvGeneralDilated(self):
    """Per-example conv (forward and grad) under vmap matches direct evaluation."""
    W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
    X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      return y
    grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example = jnp.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          jnp.reshape(g, (1,) + g.shape)]
    per_example_direct = jnp.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct,
                        rtol=2e-2, atol=2e-3)

  def testConvGeneralDilatedBatchNotMajor(self):
    """vmap of a conv whose batch dimension is not the leading one ('HNWC')."""
    W = jnp.array(np.random.randn(3, 3, 1, 4), dtype=np.float32)
    x = jnp.array(np.random.randn(3, 5, 7, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('HNWC', 'HWIO', 'HWNC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      return y

    per_example = vmap(partial(f, W))(x)
    per_example = jnp.reshape(jnp.transpose(per_example, (1, 2, 0, 3, 4)),
                              (5, 5, 21, 4))
    per_example_direct = f(W, jnp.reshape(jnp.transpose(x, (1, 0, 2, 3, 4)),
                                          (5, 21, 5, 1)))
    self.assertAllClose(per_example, per_example_direct)

  @parameterized.named_parameters(
    {"testcase_name": "_op={}".format(name), "op": op, "unit": unit}
    for name, op, unit in [("max", lax.max, -jnp.inf), ("min", lax.min, jnp.inf)])
  def testMinMaxPool(self, op, unit):
    """Conv followed by a max/min reduce_window: vmap vs. direct (forward and grad)."""
    W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
    X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      y = lax.reduce_window(
          y, unit, op, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
      return y
    grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example = jnp.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          jnp.reshape(g, (1,) + g.shape)]
    per_example_direct = jnp.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct, rtol=5e-2, atol=1e-3)

  def testSumPool(self):
    """Conv followed by an additive reduce_window: vmap vs. direct (forward and grad)."""
    W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
    X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      y = lax.reduce_window(
          y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
      return y
    grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example = jnp.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          jnp.reshape(g, (1,) + g.shape)]
    per_example_direct = jnp.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct,
                        rtol=3e-2, atol=1e-3)

  def testCumProd(self):
    """vmap of jnp.cumprod along the last axis matches np.cumprod along axis 1."""
    x = jnp.arange(9).reshape(3, 3) + 1
    y = vmap(lambda x: jnp.cumprod(x, axis=-1))(x)
    self.assertAllClose(np.cumprod(x, axis=1, dtype=jnp.int_), y)

  def testSelect(self):
    """vmap of lax.select with various combinations of mapped/unmapped operands."""
    pred = np.array([True, False])
    on_true = np.array([0, 1])
    on_false = np.array([2, 3])
    ans = vmap(lax.select)(pred, on_true, on_false)
    expected = np.array([0, 3])
    self.assertAllClose(ans, expected)

    pred = np.array([False, True])
    on_true = np.array([0, 1])
    on_false = np.array([2, 3])
    ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false)
    expected = np.array([[2, 3],
                         [0, 1]])
    self.assertAllClose(ans, expected)

    pred = True
    on_true = np.array([0, 1], np.float32)
    on_false = np.array(3, np.float32)
    ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false)
    expected = np.array([0, 1], np.float32)
    self.assertAllClose(ans, expected)

    pred = np.array([False, True])
    on_true = np.array([0, 1], np.float32)
    on_false = np.array(3, np.float32)
    ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false)
    expected = np.array([3, 1], np.float32)
    self.assertAllClose(ans, expected)

    pred = np.array([False, True])
    on_true = np.array([2], np.float32)
    on_false = np.array([[3, 4]], np.float32)
    ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false)
    expected = np.array([[3, 2]], np.float32)
    self.assertAllClose(ans, expected)

  def testLaxLinalgCholesky(self):
    """vmap of lax.linalg.cholesky matches np.linalg.cholesky on a batch."""
    a = np.random.RandomState(0).randn(10, 5, 5).astype(np.float32)
    a = np.matmul(a, np.conj(np.swapaxes(a, -1, -2)))

    ans = vmap(lax.linalg.cholesky)(a)
    expected = np.linalg.cholesky(a)
    self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4)

    # Same check with the batch dimension moved to axis 1 of the input.
    b = np.random.RandomState(0).randn(10, 5, 5).astype(np.float32)
    b = np.matmul(b, np.conj(np.swapaxes(b, -1, -2)))
    b_trans = np.swapaxes(b, 0, 1)  # shape is (5, 10, 5)

    ans = vmap(lax.linalg.cholesky, in_axes=1, out_axes=0)(b_trans)
    expected = np.linalg.cholesky(b)
    self.assertAllClose(ans, expected, check_dtypes=False, rtol=1e-4)

  def testLaxLinalgTriangularSolve(self):
    """vmap of lax.linalg.triangular_solve matches a stack of per-slice solves."""
    a = np.random.RandomState(0).randn(4, 10, 4).astype(np.float32)
    a += np.eye(4, dtype=jnp.float32)[:, None, :]
    b = np.random.RandomState(0).randn(5, 4, 10).astype(np.float32)

    ans = vmap(lax.linalg.triangular_solve, in_axes=(1, 2))(a, b)
    expected = np.stack(
      [lax.linalg.triangular_solve(a[:, i], b[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected)

    ans = vmap(lax.linalg.triangular_solve, in_axes=(None, 2))(a[:, 0], b)
    expected = np.stack(
      [lax.linalg.triangular_solve(a[:, 0], b[..., i]) for i in range(10)])
    self.assertAllClose(ans, expected)

    ans = vmap(lax.linalg.triangular_solve, in_axes=(1, None))(a, b[..., 0])
    expected = np.stack(
      [lax.linalg.triangular_solve(a[:, i], b[..., 0]) for i in range(10)])
    self.assertAllClose(ans, expected)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          axis, idxs, dnums, slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory}
      for dtype in [np.float32, np.int32]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,)),
          (1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,)),
          (1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3)),
          (2, (10, 5, 3), np.array([[0, 2], [1, 0]]),
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,),
             start_index_map=(0, 1)),
            (1, 3)),
      ]
      for rng_factory in [jtu.rand_default])
  def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory):
    """lax.gather with only the operand mapped matches per-slice gathers."""
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    ans = vmap(fun, (axis, None))(operand, idxs)
    # operand[(slice(None),) * axis + (i,)] selects index i along `axis`.
    expected = np.stack([fun(operand[(slice(None),) * axis + (i,)], idxs)
                         for i in range(operand.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          axis, idxs, dnums, slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory}
      for dtype in [np.float32, np.float64]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (3, 5), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,)),
          (1, (10, 3), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,)),
          (1, (10, 3, 5), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3)),
          (2, (10, 5, 3), np.array([[0, 2], [1, 0]]),
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,),
             start_index_map=(0, 1)),
            (1, 3)),
      ]
      for rng_factory in [jtu.rand_default])
  def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory):
    """Gradient of a gather-based scalar with only the operand mapped."""
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    ans = vmap(gfun, (axis, None))(operand, idxs)
    expected = np.stack([gfun(operand[(slice(None),) * axis + (i,)], idxs)
                         for i in range(operand.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          axis, idxs, dnums, slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory}
      for dtype in [np.float32, np.int32]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
          (1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
          (1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
          (0, (10, 5), np.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]),
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
      ]
      for rng_factory in [jtu.rand_default])
  def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory):
    """lax.gather with only the indices mapped matches per-slice gathers."""
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    ans = vmap(fun, (None, axis))(operand, idxs)
    expected = np.stack([fun(operand, idxs[(slice(None),) * axis + (i,)])
                         for i in range(idxs.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          axis, idxs, dnums, slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory}
      for dtype in [np.float32, np.float64]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (5,), np.array([[[0], [2]], [[1], [3]]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)),
          (1, (10,), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)),
          (1, (10, 5), np.array([[0, 2, 1], [0, 3, 3]]).T[..., None],
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)),
          (0, (10, 5), np.array([[[0, 1], [2, 0]], [[1, 0], [2, 3]]]),
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)),
      ]
      for rng_factory in [jtu.rand_default])
  def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory):
    """Gradient of a gather-based scalar with only the indices mapped."""
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    ans = vmap(gfun, (None, axis))(operand, idxs)
    expected = np.stack([gfun(operand, idxs[(slice(None),) * axis + (i,)])
                         for i in range(idxs.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          op_axis, idxs_axis, idxs, dnums, slice_sizes),
       "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
       dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
       "rng_factory": rng_factory}
      for dtype in [np.float32, np.int32]
      for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
          (0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
           lax.GatherDimensionNumbers(
             offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1,)),
          (1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
           (2,)),
          (0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T,
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1, 3)),
          (2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]),
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
           (1, 3)),
      ]
      for rng_factory in [jtu.rand_default])
  def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory):
    """lax.gather with both operand and indices mapped matches per-slice gathers."""
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    # Sanity check on the test table: both mapped axes must have equal size.
    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
    ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
    expected = np.stack([fun(operand[(slice(None),) * op_axis + (i,)],
                              idxs[(slice(None),) * idxs_axis + (i,)])
                         for i in range(idxs.shape[idxs_axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          op_axis, idxs_axis, idxs, dnums, slice_sizes),
       "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
       dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
       "rng_factory": rng_factory}
      for dtype in [np.float32]
      for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
          (0, 0, (2, 5), np.array([[[0], [2]], [[1], [3]]]),
           lax.GatherDimensionNumbers(
             offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1,)),
          (1, 1, (10, 2), np.array([[0, 0, 0], [0, 2, 1]]).T[..., None],
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
           (2,)),
          (0, 1, (2, 10, 5,), np.array([[[0, 2, 1], [0, 3, 3]]]).T,
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
           (1, 3)),
          (2, 0, (10, 5, 2), np.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]),
           lax.GatherDimensionNumbers(
             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
           (1, 3)),
      ]
      for rng_factory in [jtu.rand_default])
  def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes, rng_factory):
    """Gradient of a gather-based scalar with both operand and indices mapped."""
    rng = rng_factory(self.rng())
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    # Sanity check on the test table: both mapped axes must have equal size.
    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
    ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs)
    expected = np.stack([gfun(operand[(slice(None),) * op_axis + (i,)],
                              idxs[(slice(None),) * idxs_axis + (i,)])
                         for i in range(idxs.shape[idxs_axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testNumpyIndexing1(self):
    """Array indexing with a mapped index array matches a stack of per-row results."""
    a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))
    ind = np.array([[0, 1],
                    [2, 0]])
    def f(a, ind):
      return a[:, ind]
    expected = np.stack([f(a, ind[i, :]) for i in range(ind.shape[0])])
    ans = vmap(f, (None, 0))(a, ind)
    assert np.all(ans == expected)

  def testNumpyIndexing2(self):
    """Array indexing with a constant index array under vmap of the operand."""
    a = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))
    def f(a):
      inds = jnp.array([0, 2])
      return a[:, inds]
    ans = vmap(f)(a)
    expected = np.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1)
    assert np.all(ans == expected)

  def testTranspose(self):
    """vmap of `x + x.T` matches adding the last-two-axes swap."""
    x = np.arange(4 * 3 * 3).reshape((4, 3, 3))
    ans = vmap(lambda x: x + x.T)(x)
    expected = x + np.swapaxes(x, -1, -2)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testTransposePermutation(self):
    """vmap of jnp.transpose with explicit permutations, including in_axes=2."""
    x = np.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
    ans = vmap(lambda x: jnp.transpose(x, (1, 0, 2)))(x)
    expected = np.transpose(x, (0, 2, 1, 3))
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = np.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
    ans = vmap(lambda x: jnp.transpose(x, (1, 2, 0)))(x)
    expected = np.transpose(x, (0, 2, 3, 1))
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = np.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5))
    ans = vmap(lambda x: jnp.transpose(x, (1, 2, 0)), in_axes=2)(x)
    expected = np.transpose(x, (2, 1, 3, 0))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testIssue354(self):
    """grad of a vmapped Cholesky-based function matches per-example grads."""
    psd_mat = np.random.randn(20, 10)
    psd_mat = psd_mat.T.dot(psd_mat)
    vec = np.random.randn(10)

    def f(scale):
      scaled_mat = scale * psd_mat
      chol = jnp.linalg.cholesky(scaled_mat)
      return -0.5 * jnp.sum((jnp.einsum('ij,j->i', chol, vec))**2)
    vmapped_f = vmap(f)
    vmapped_f_grad = grad(lambda x: jnp.sum(vmapped_f(x)))

    scales = np.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
    ans = vmapped_f_grad(scales)  # don't crash!
    expected = np.stack([grad(f)(scale) for scale in scales])
    self.assertAllClose(ans, expected, check_dtypes=False,
                        rtol=jtu.default_gradient_tolerance)

  def testIssue387(self):
    """Smoke test: hessian of a jitted function with an unused sub-computation."""
    # https://github.com/google/jax/issues/387
    R = np.random.RandomState(0).rand(100, 2)

    def dist_sq(R):
      dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :]
      zero = jnp.zeros_like(dR)
      dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR))
      return jnp.sum(dR ** 2, axis=2)

    @jit
    def f(R):
      # The result is deliberately discarded; only tracing it matters here.
      _ = dist_sq(R)
      return jnp.sum(R ** 2)

    _ = hessian(f)(R)  # don't crash on UnshapedArray

  def testIssue489(self):
    """Smoke test: vmap over a function containing a lax.while_loop on PRNG keys."""
    def f(key):
      def body_fn(uk):
        key = uk[1]
        u = random.uniform(key, (), dtype=jnp.float64)
        key, _ = random.split(key)
        return u, key

      u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn,
                            (jnp.float64(1.), key))
      return u

    # The printed value is not asserted on; the test only checks for no crash.
    print(vmap(f)(random.split(random.PRNGKey(0), 2)))  # no crash

  def testEmptyTuples(self):
    """vmap tolerates empty tuples in both inputs and outputs."""
    # Ensure there is no crash when a vectorized input contains empty tuples.
    result = vmap(lambda x, _: x + 1)(np.array([0, 1]), ())
    self.assertAllClose(result, np.array([1, 2]), check_dtypes=False)
    # Ensure there is no crash when a vectorized output contains empty tuples.
    result, empty_tuple = vmap(lambda x: (x + 1, ()))(np.array([0, 1]))
    self.assertAllClose(result, np.array([1, 2]), check_dtypes=False)
    self.assertEqual((), empty_tuple)

  def testIndexAddBatchedIndexesOnly(self):
    """index_add with only the index mapped yields the 10x10 identity."""
    f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y)
    result = vmap(f, (None, 0, None))(np.zeros((10,)), np.arange(10,), 1.)
    self.assertAllClose(result, np.eye(10), check_dtypes=False)

  def testIssue1170(self):
    """Keyword and positional calls of jit(pmap(f)) give the same result."""
    def f(index1, index2):
      return jnp.arange(36).reshape(6, 6)[index1, index2]
    g = jax.jit(jax.pmap(f))
    ans = g(index1=np.asarray([1]), index2=np.asarray([2]))
    expected = g(np.asarray([1]), np.asarray([2]))
    self.assertAllClose(ans, expected)

  def testIssue3883(self):
    """vmap of dynamic_slice / dynamic_update_slice with empty index lists."""
    def scalar_f(x):
      return lax.dynamic_slice(x, [], [])

    xs = jnp.array([1, 2, 3, 4])
    ans = vmap(scalar_f)(xs)
    expected = jnp.array([scalar_f(x) for x in xs])
    self.assertAllClose(ans, expected)

    def scalar_f2(x):
      return lax.dynamic_update_slice(x, 7, [])

    xs = jnp.array([1, 2, 3, 4])
    ans = vmap(scalar_f2)(xs)
    expected = jnp.array([scalar_f2(x) for x in xs])
    self.assertAllClose(ans, expected)

  @parameterized.named_parameters(
      {"testcase_name": "_collective={}".format(seq.__name__).replace(" ", ""),
       "collective": collective,
       "seq": seq}
      for collective, seq in [(lax.psum, jnp.sum),
                              (lax.pmean, jnp.mean),
                              (lambda x, n: lax.pmax(x, n), jnp.max),
                              (lambda x, n: lax.pmin(x, n), jnp.min)])
  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testCollective(self, collective, seq):
    """Collectives over named vmap axes match the corresponding jnp reduction."""
    x = jnp.arange(64).reshape((4, 4, 4))
    self.assertAllClose(
      vmap(lambda x: x - collective(x, 'i'), axis_name='i')(x),
      x - seq(x, axis=0))

    self.assertAllClose(
      vmap(vmap(lambda x: x - collective(x, ('j', 'i')), axis_name='i'), axis_name='j')(x),
      x - seq(x, axis=(0, 1)))

    self.assertAllClose(
      vmap(vmap(lambda x: x - collective(x, ('i', 'j')), axis_name='i'), axis_name='j')(x),
      x - seq(x, axis=(1, 0)))

  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testPPermute(self):
    """lax.ppermute over a named vmap axis, checked on random permutations."""
    nelem = 10
    ntests = 10
    x = np.arange(nelem)
    rng = np.random.RandomState(1)
    for i in range(ntests):
      perm = np.arange(nelem)
      rng.shuffle(perm)
      perm_pairs = np.stack([np.arange(nelem), perm], axis=-1)
      # Shuffle the order of the pairs; each row still pairs i with perm[i].
      rng.shuffle(perm_pairs)
      # NOTE(review): the expectation indexes with `perm` directly; confirm
      # this matches ppermute's (source, destination) pair convention.
      self.assertAllClose(
        vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs), axis_name='i')(x),
        x - x[perm])

  @parameterized.named_parameters(
      {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
       "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
      for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testAllToAll(self, vmap_axis, split_axis, concat_axis):
    """The all_to_all collective batching rule produces the expected output shape."""
    d = vmap_axis

    def shape_fun(x, out_d):
      # Expected shape: drop the vmapped and split dims, then re-insert them
      # at the concat position and at the reported output batch dim.
      shape = list(x.shape)
      vmap_dim_id = shape.pop(d)
      split_dim_id = shape.pop(split_axis)
      shape.insert(concat_axis, vmap_dim_id)
      shape.insert(out_d, split_dim_id)
      return tuple(shape)

    shape = (2, 3, 4, 5)
    x = np.arange(np.prod(shape)).reshape(shape)
    rule = batching.collective_rules[lax.all_to_all_p]
    (y,), (out_d,) = rule((x,), (d,), None, None, split_axis, concat_axis)
    exp_shape = shape_fun(x, out_d)
    self.assertEqual(y.shape, exp_shape)

  @parameterized.named_parameters(
      {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
       "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
      for split_axis, concat_axis, vmap_axis in it.product(range(2), range(2), range(3)))
  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testAllToAllSplitAxis(self, vmap_axis, split_axis, concat_axis):
    """lax.all_to_all over two nested named vmap axes matches a moveaxis reference."""
    shape = (4, 4, 4)
    x = np.arange(np.prod(shape)).reshape(shape)

    @partial(vmap, in_axes=vmap_axis, axis_name='i')
    @partial(vmap, in_axes=vmap_axis, axis_name='j')
    def f(x):
      return lax.all_to_all(x, ('i', 'j'), split_axis, concat_axis)

    # NOTE(review): this first assignment is dead; it is overwritten on the
    # very next line.
    unroll_shape = (2, 2, *shape[1:])
    unroll_shape = list(shape)
    # Split the vmapped axis (size 4) into two axes of size 2.
    unroll_shape[vmap_axis:vmap_axis+1] = (2, 2)
    x_unroll = x.reshape(unroll_shape)
    y_unrolled = f(x_unroll)
    y = y_unrolled.reshape(shape)

    # split_axis is relative to the per-example array; shift it past the
    # vmapped axis to index into the full array.
    if vmap_axis <= split_axis:
      split_axis += 1
    ref = jnp.moveaxis(x, (vmap_axis, split_axis),
                          (concat_axis + 1, 0))
    self.assertAllClose(y, ref)

  def testNegativeAxes(self):
    """Negative in_axes/out_axes are accepted and out-of-range ones raise ValueError."""
    x = np.arange(3*4*5).reshape(3, 4, 5)
    self.assertAllClose(jax.vmap(jnp.sum, in_axes=-3)(x),
                        jnp.sum(x, axis=(1, 2)))
    self.assertAllClose(jax.vmap(jnp.sum, in_axes=-2)(x),
                        jnp.sum(x, axis=(0, 2)))
    self.assertAllClose(jax.vmap(jnp.sum, in_axes=-1)(x),
                        jnp.sum(x, axis=(0, 1)))

    with self.assertRaisesRegex(ValueError, "vmap got arg 0 of rank 3 but axis to be mapped -4"):
      jax.vmap(jnp.sum, in_axes=-4)(x)

    # NOTE: `id` shadows the builtin id() for the rest of this test.
    id = lambda y: y
    self.assertAllClose(x, jax.vmap(id, in_axes=0, out_axes=-3)(x))
    self.assertAllClose(x.transpose(1, 0, 2),
                        jax.vmap(id, in_axes=0, out_axes=-2)(x))
    self.assertAllClose(x.transpose(1, 2, 0),
                        jax.vmap(id, in_axes=0, out_axes=-1)(x))

    with self.assertRaisesRegex(ValueError, "axis -4 is out of bounds.*"):
      jax.vmap(id, in_axes=0, out_axes=-4)(x)

    self.assertAllClose(
      np.full((5,), 7),
      jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -1))(
        np.arange(5), 7)[1])

    with self.assertRaisesRegex(ValueError, "axis -2 is out of bounds.*"):
      jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -2))(
        np.arange(5), 7)

  @skipIf(not jax.config.omnistaging_enabled,
          "vmap collectives only supported when omnistaging is enabled")
  def testAxisIndex(self):
    """lax.axis_index over a named vmap axis equals arange over that axis."""
    x = np.arange(10)
    self.assertAllClose(
      vmap(lambda x: x - lax.axis_index('i'), axis_name='i')(x),
      x - np.arange(x.shape[0]))
class IndexedUpdateTest(jtu.JaxTestCase):
  """Tests for indexed update operations (`UpdateOps`) and segment sums."""

  def _check_indexed_update(self, shape, dtype, update_shape, update_dtype,
                            rng_factory, indexer, op):
    """Shared body of the static/advanced/mixed indexing tests.

    Builds random operand and update arrays, then runs the NumPy-side and
    JAX-side implementations of `op` at `indexer` through
    `_CheckAgainstNumpy` and `_CompileAndCheck`.
    """
    rng = rng_factory()

    def args_maker():
      return [rng(shape, dtype), rng(update_shape, update_dtype)]

    def onp_fn(x, y):
      return UpdateOps.onp_fn(op, indexer, x, y)

    def jax_fn(x, y):
      return UpdateOps.jax_fn(op, indexer, x, y)

    self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
    self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="{}_inshape={}_indexer={}_update={}_op={}".format(
              name, jtu.format_shape_dtype_string(shape, dtype), indexer,
              jtu.format_shape_dtype_string(update_shape, update_dtype),
              op.name),
          shape=shape, dtype=dtype, rng_factory=rng_factory,
          indexer=indexer, update_shape=update_shape,
          update_dtype=update_dtype, op=op)
      for name, index_specs in STATIC_INDEXING_TESTS
      for shape, indexer in index_specs
      for op in UpdateOps
      for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
      for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
      for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes)
      for rng_factory in [jtu.rand_default]))
  def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                         rng_factory, indexer, op):
    """Indexed updates with the static indexers in STATIC_INDEXING_TESTS."""
    self._check_indexed_update(shape, dtype, update_shape, update_dtype,
                               rng_factory, indexer, op)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="{}_inshape={}_indexer={}_update={}_op={}".format(
              name, jtu.format_shape_dtype_string(shape, dtype), indexer,
              jtu.format_shape_dtype_string(update_shape, update_dtype),
              op.name),
          shape=shape, dtype=dtype, rng_factory=rng_factory,
          indexer=indexer, update_shape=update_shape,
          update_dtype=update_dtype, op=op)
      for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
      for shape, indexer in index_specs
      for op in UpdateOps
      for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
      for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
      for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes)
      for rng_factory in [jtu.rand_default]))
  def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                           rng_factory, indexer, op):
    """Indexed updates with the indexers in ADVANCED_INDEXING_TESTS_NO_REPEATS."""
    self._check_indexed_update(shape, dtype, update_shape, update_dtype,
                               rng_factory, indexer, op)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="{}_inshape={}_indexer={}_update={}_op={}".format(
              name, jtu.format_shape_dtype_string(shape, dtype), indexer,
              jtu.format_shape_dtype_string(update_shape, update_dtype),
              op.name),
          shape=shape, dtype=dtype, rng_factory=rng_factory,
          indexer=indexer, update_shape=update_shape,
          update_dtype=update_dtype, op=op)
      for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
      for shape, indexer in index_specs
      for op in UpdateOps
      for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
      for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
      for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes)
      for rng_factory in [jtu.rand_default]))
  def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                                rng_factory, indexer, op):
    """Indexed updates with the indexers in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS."""
    self._check_indexed_update(shape, dtype, update_shape, update_dtype,
                               rng_factory, indexer, op)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="{}_inshape={}_indexer={}_update={}_op={}".format(
              name, jtu.format_shape_dtype_string(shape, dtype), indexer,
              jtu.format_shape_dtype_string(update_shape, update_dtype),
              op.name),
          shape=shape, dtype=dtype, rng_factory=rng_factory,
          indexer=indexer, update_shape=update_shape,
          update_dtype=update_dtype, op=op)
      for name, index_specs in STATIC_INDEXING_TESTS
      for shape, indexer in index_specs
      for op in UpdateOps
      for dtype in float_dtypes
      for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
      for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes)
      for rng_factory in [jtu.rand_default]))
  @jtu.skip_on_devices("tpu")  # TODO(mattjj,phawkins): tpu issues
  def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                              rng_factory, indexer, op):
    """Gradient check (up to order 2) of index_update / index_add at static indexers."""
    rng = rng_factory()

    # UPDATE maps to index_update; every other op is checked via index_add.
    if op == UpdateOps.UPDATE:
      jax_op = ops.index_update
    else:
      jax_op = ops.index_add

    def jax_fn(x, y):
      return jax_op(x, indexer, y)

    operand = rng(shape, dtype)
    update = rng(update_shape, update_dtype)
    check_grads(jax_fn, (operand, update), 2, rtol=1e-3, atol=1e-3, eps=1.)

  def testSegmentSumBehavior(self):
    """index_add accumulates over repeated indices like a segment sum."""
    # testAdvancedIndexing compares against NumPy and therefore never
    # exercises repeated indices. This is a small manual check, based on
    # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
    data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
    segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

    accumulator = onp.zeros(onp.max(segment_ids) + 1)
    ans = ops.index_add(accumulator, segment_ids, data)
    expected = onp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testSegmentSum(self):
    """ops.segment_sum with and without an explicit num_segments."""
    data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
    segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

    # With an explicit segment count.
    with_count = ops.segment_sum(data, segment_ids, num_segments=4)
    self.assertAllClose(with_count, onp.array([13, 2, 7, 4]),
                        check_dtypes=False)

    # With no explicit segment count.
    without_count = ops.segment_sum(data, segment_ids)
    self.assertAllClose(without_count, onp.array([13, 2, 7, 4]),
                        check_dtypes=False)
class NumpyLinalgTest(jtu.JaxTestCase):
  """Tests of `np.linalg` routines, mostly against `onp.linalg`.

  NOTE(review): in this section `onp` appears to be NumPy and `np` the library
  under test (presumably jax.numpy) -- confirm against the file's imports.
  """

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype,
          "rng": rng
      } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
                          for dtype in float_types()
                          for rng in [jtu.rand_default()]))
  def testCholesky(self, shape, dtype, rng):
    """np.linalg.cholesky vs. onp.linalg.cholesky on inputs of the form a @ conj(T(a))."""
    def args_maker():
      a = rng(shape, dtype)
      return [onp.matmul(a, np.conj(T(a)))]

    self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.cholesky, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
          "n": n,
          "dtype": dtype,
          "rng": rng
      } for n in [0, 4, 5, 50]
                          for dtype in float_types() | complex_types()
                          for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testDet(self, n, dtype, rng):
    """np.linalg.det vs. onp.linalg.det on random n x n matrices."""
    if not hasattr(lapack, "jax_getrf"):
      self.skipTest("No LU implementation available")
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(onp.linalg.det, np.linalg.det, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
          "n": n,
          "dtype": dtype,
          "rng": rng
      } for n in [0, 4, 10, 200]
                          for dtype in float_types() | complex_types()
                          for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testSlogdet(self, n, dtype, rng):
    """np.linalg.slogdet vs. onp.linalg.slogdet on random n x n matrices."""
    if not hasattr(lapack, "jax_getrf"):
      self.skipTest("No LU implementation available")
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_n={}_lower={}".format(
                  jtu.format_shape_dtype_string((n, n), dtype), lower),
          "n": n,
          "dtype": dtype,
          "lower": lower,
          "rng": rng
      } for n in [0, 4, 5, 50]
                          for dtype in float_types() | complex_types()
                          for lower in [False, True]
                          for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEigh(self, n, dtype, lower, rng):
    """np.linalg.eigh: eigenvectors are near-orthonormal and satisfy a v ~= w v."""
    if not hasattr(lapack, "jax_syevd"):
      self.skipTest(
          "No symmetric eigendecomposition implementation available")
    args_maker = lambda: [rng((n, n), dtype)]

    uplo = "L" if lower else "U"

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = onp.linalg.norm(x, axis=(-2, -1))
      return norm / ((n + 1) * onp.finfo(dtype).eps)

    a, = args_maker()
    # Symmetrize (Hermitian-ize) the random input.
    a = (a + onp.conj(a.T)) / 2
    # Only the triangle selected by UPLO is passed in.
    w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a), UPLO=uplo)

    self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
    self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30)

    self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker,
                          check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_n={}_full_matrices={}_compute_uv={}".format(
                  jtu.format_shape_dtype_string((
                      m, n), dtype), full_matrices, compute_uv),
          "m": m,
          "n": n,
          "dtype": dtype,
          "full_matrices": full_matrices,
          "compute_uv": compute_uv,
          "rng": rng
      } for m in [2, 7, 29, 53]
                          for n in [2, 7, 29, 53]
                          for dtype in float_types() | complex_types()
                          for full_matrices in [False, True]
                          for compute_uv in [False, True]
                          for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng):
    """np.linalg.svd: reconstruction and unitarity checks, or singular values only."""
    if not hasattr(lapack, "jax_gesdd"):
      self.skipTest(
          "No singular value decomposition implementation available")
    args_maker = lambda: [rng((m, n), dtype)]

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = onp.linalg.norm(x, axis=(-2, -1))
      return norm / (max(m, n) * onp.finfo(dtype).eps)

    a, = args_maker()
    out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

    if compute_uv:
      # Check the reconstructed matrices
      if full_matrices:
        # With full matrices only the first k = min(m, n) singular vectors
        # take part in the reconstruction.
        k = min(m, n)
        if m < n:
          self.assertTrue(
              onp.all(
                  norm(a - onp.matmul(out[1] * out[0], out[2][:k, :])) < 50))
        else:
          self.assertTrue(
              onp.all(
                  norm(a - onp.matmul(out[1] * out[0][:, :k], out[2])) < 50))
      else:
        self.assertTrue(
            onp.all(
                norm(a - onp.matmul(out[1] * out[0], out[2])) < 50))

      # Check the unitary properties of the singular vector matrices.
      self.assertTrue(
          onp.all(
              norm(
                  onp.eye(out[0].shape[1]) -
                  onp.matmul(onp.conj(T(out[0])), out[0])) < 10))
      if m >= n:
        self.assertTrue(
            onp.all(
                norm(
                    onp.eye(out[2].shape[1]) -
                    onp.matmul(onp.conj(T(out[2])), out[2])) < 10))
      else:
        self.assertTrue(
            onp.all(
                norm(
                    onp.eye(out[2].shape[0]) -
                    onp.matmul(out[2], onp.conj(T(out[2])))) < 20))

    else:
      # Singular values only: compare directly against onp.linalg.svd.
      self.assertTrue(
          onp.allclose(onp.linalg.svd(a, compute_uv=False), onp.asarray(out),
                       atol=1e-4, rtol=1e-4))

    self._CompileAndCheck(partial(np.linalg.svd, full_matrices=full_matrices,
                                  compute_uv=compute_uv),
                          args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}_fullmatrices={}".format(
                  jtu.format_shape_dtype_string(shape, dtype), full_matrices),
          "shape": shape,
          "dtype": dtype,
          "full_matrices": full_matrices,
          "rng": rng
      } for shape in [(1, 1), (3, 4), (2, 10, 5), (2, 200, 100)]
                          for dtype in float_types()
                          for full_matrices in [False, True]
                          for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("cpu")
  def testQr(self, shape, dtype, full_matrices, rng):
    """np.linalg.qr vs. a per-matrix onp.linalg.qr reference, plus a JVP check."""
    m, n = shape[-2:]

    if full_matrices:
      mode, k = "complete", m
    else:
      mode, k = "reduced", min(m, n)

    a = rng(shape, dtype)
    lq, lr = np.linalg.qr(a, mode=mode)

    # onp.linalg.qr doesn't support broadcasting. But it seems like an
    # inevitable extension so we support it in our version.
    nq = onp.zeros(shape[:-2] + (m, k), dtype)
    nr = onp.zeros(shape[:-2] + (k, n), dtype)
    for index in onp.ndindex(*shape[:-2]):
      nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

    max_rank = max(m, n)

    # Norm, adjusted for dimension and type.
    def norm(x):
      # NOTE: this local `n` shadows the outer `n` only inside norm().
      n = onp.linalg.norm(x, axis=(-2, -1))
      return n / (max_rank * onp.finfo(dtype).eps)

    def compare_orthogonal(q1, q2):
      # Q is unique up to sign, so normalize the sign first.
      sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
      phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
      # In-place multiply: this mutates the array passed in as q1.
      q1 *= phases
      self.assertTrue(onp.all(norm(q1 - q2) < 30))

    # Check a ~= qr
    self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

    # Compare the first 'k' vectors of Q; the remainder form an arbitrary
    # orthonormal basis for the null space.
    compare_orthogonal(nq[..., :k], lq[..., :k])

    # Check that q is close to unitary.
    self.assertTrue(onp.all(norm(onp.eye(k) - onp.matmul(T(lq), lq)) < 5))

    # The JVP is only checked for the reduced mode with m >= n.
    if not full_matrices and m >= n:
        jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a, ))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_lhs={}_rhs={}".format(
                  jtu.format_shape_dtype_string(lhs_shape, dtype),
                  jtu.format_shape_dtype_string(rhs_shape, dtype)),
          "lhs_shape": lhs_shape,
          "rhs_shape": rhs_shape,
          "dtype": dtype,
          "rng": rng
      } for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 4), (4, )),
          ((8, 8), (8, 4)),
      ] for dtype in float_types() | complex_types()
                          for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu") def testSolve(self, lhs_shape, rhs_shape, dtype, rng): if not hasattr(lapack, "jax_getrf"): self.skipTest("No LU implementation available") args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] self._CheckAgainstNumpy(onp.linalg.solve, np.linalg.solve, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)] for dtype in float_types() for rng in [jtu.rand_default()])) def testInv(self, shape, dtype, rng): def args_maker(): invertible = False while not invertible: a = rng(shape, dtype) try: onp.linalg.inv(a) invertible = True except onp.linalg.LinAlgError: pass return [a] self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)
class SparseObjectTest(jtu.JaxTestCase):
  """Tests for the sparse container objects `CSR`, `CSC` and `COO`."""

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
      for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO])
  def test_attrs(self, Obj, shape=(5, 8), dtype=np.float16):
    # Build a sparse object via Obj.fromdense and check that its metadata
    # (type, shape, dtype, nnz) and buffer lengths are self-consistent.
    rng = rand_sparse(self.rng(), post=Obj.fromdense)
    M = rng(shape, dtype)

    assert isinstance(M, Obj)
    assert M.shape == shape
    assert M.dtype == dtype
    assert M.nnz == (M.todense() != 0).sum()
    assert M.data.dtype == dtype

    if isinstance(M, sparse_ops.CSR):
      # One index per stored value; one pointer per row, plus one.
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[0] + 1
    elif isinstance(M, sparse_ops.CSC):
      # One index per stored value; one pointer per column, plus one.
      assert len(M.data) == len(M.indices)
      assert len(M.indptr) == M.shape[1] + 1
    elif isinstance(M, sparse_ops.COO):
      assert len(M.data) == len(M.row) == len(M.col)
    else:
      # f-string so the offending class is actually interpolated into the
      # message (the placeholder was previously emitted literally).
      raise ValueError(f"Obj={Obj} not expected.")

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_Obj={}".format(
              jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
           "shape": shape, "dtype": dtype, "Obj": Obj}
          for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
          for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
      for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
  def test_dense_round_trip(self, shape, dtype, Obj):
    # dense -> sparse -> dense must reproduce the input exactly.
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M, Msparse.todense())

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_Obj={}".format(
              jtu.format_shape_dtype_string(shape, dtype), Obj.__name__),
           "shape": shape, "dtype": dtype, "Obj": Obj}
          for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
          for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
      for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
  def test_transpose(self, shape, dtype, Obj):
    # Transposing the sparse object must agree with transposing the dense one.
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    Msparse = Obj.fromdense(M)
    self.assertArraysEqual(M.T, Msparse.T.todense())

  @unittest.skipIf(jtu.device_under_test() == "tpu",
                   "TPU has insufficient precision")
  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_Obj={}_bshape={}".format(
              jtu.format_shape_dtype_string(shape, dtype), Obj.__name__,
              bshape),
           "shape": shape, "dtype": dtype, "Obj": Obj, "bshape": bshape}
          for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
          # Right-hand side is a vector or a matrix with 3 or 4 columns whose
          # leading dimension matches the last dimension of `shape`.
          for bshape in [shape[-1:] + s for s in [(), (3,), (4,)]]
          for dtype in jtu.dtypes.floating + jtu.dtypes.complex)
      for Obj in [sparse_ops.CSR, sparse_ops.CSC, sparse_ops.COO]))
  def test_matmul(self, shape, dtype, Obj, bshape):
    # `sparse @ dense` must match `dense @ dense` up to MATMUL_TOL.
    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())
    M = rng(shape, dtype)
    Msp = Obj.fromdense(M)
    x = rng_b(bshape, dtype)
    x = jnp.asarray(x)

    self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)
class cuSparseTest(jtu.JaxTestCase):
  """Tests for the functional `sparse_ops` CSR/COO primitives.

  Results are compared against `sparse.csr_matrix` / `sparse.coo_matrix`
  references, both when called directly and under `jit`.
  """

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.indices, M.indptr)
    todense = lambda *args: sparse_ops.csr_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_csr_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_csr = sparse.csr_matrix(M)

    # nnz is taken from the reference matrix and closed over by the lambda.
    nnz = M_csr.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.csr_fromdense(
        M, nnz=nnz, index_dtype=index_dtype)

    data, indices, indptr = fromdense(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

    data, indices, indptr = jit(fromdense)(M)
    self.assertArraysEqual(data, M_csr.data.astype(dtype))
    self.assertArraysEqual(indices, M_csr.indices.astype(index_dtype))
    self.assertArraysEqual(indptr, M_csr.indptr.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matvec(self, shape, dtype, transpose):
    # `op` applies the same (optional) transpose to the reference matrix.
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.indices, M.indptr, v)
    matvec = lambda *args: sparse_ops.csr_matvec(
        *args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_csr_matmat(self, shape, dtype, transpose):
    # `op` applies the same (optional) transpose to the reference matrix.
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.csr_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.indices, M.indptr, B)
    matmat = lambda *args: sparse_ops.csr_matmat(
        *args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)

    args = (M.data, M.row, M.col)
    todense = lambda *args: sparse_ops.coo_todense(*args, shape=M.shape)

    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense(self, shape, dtype):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    M_coo = sparse.coo_matrix(M)

    # nnz is taken from the reference matrix and closed over by the lambda.
    nnz = M_coo.nnz
    index_dtype = jnp.int32
    fromdense = lambda M: sparse_ops.coo_fromdense(
        M, nnz=nnz, index_dtype=index_dtype)

    data, row, col = fromdense(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    # Rebind row/col from the jitted call so the assertions below check the
    # jitted outputs rather than re-checking the eager ones.
    data, row, col = jit(fromdense)(M)
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matvec(self, shape, dtype, transpose):
    # `op` applies the same (optional) transpose to the reference matrix.
    op = lambda M: M.T if transpose else M

    v_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)
    v = v_rng(op(M).shape[1], dtype)

    args = (M.data, M.row, M.col, v)
    matvec = lambda *args: sparse_ops.coo_matvec(
        *args, shape=M.shape, transpose=transpose)

    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(
          jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for transpose in [True, False]))
  def test_coo_matmat(self, shape, dtype, transpose):
    # `op` applies the same (optional) transpose to the reference matrix.
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.row, M.col, B)
    matmat = lambda *args: sparse_ops.coo_matmat(
        *args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  def test_gpu_translation_rule(self):
    # The last whitespace-separated token of platform_version is parsed as an
    # integer CUDA version; "<unknown>" is treated as "no version".
    version = xla_bridge.get_backend().platform_version
    cuda_version = None if version == "<unknown>" else int(
        version.split()[-1])
    if cuda_version is None or cuda_version < 11000:
      # Below the 11000 threshold no GPU translation rule should be
      # registered for csr_todense.
      self.assertFalse(cusparse and cusparse.is_supported)
      self.assertNotIn(sparse_ops.csr_todense_p,
                       xla.backend_specific_translations["gpu"])
    else:
      self.assertTrue(cusparse and cusparse.is_supported)
      self.assertIn(sparse_ops.csr_todense_p,
                    xla.backend_specific_translations["gpu"])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(shape, dtype), mat_type),
       "shape": shape, "dtype": dtype, "mat_type": mat_type}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for mat_type in ['csr', 'coo']))
  def test_extra_nnz(self, shape, dtype, mat_type):
    # Request 5 more stored entries than there are nonzeros; the round trip
    # through the sparse format must still reproduce the dense input.
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    nnz = (M != 0).sum() + 5
    fromdense = getattr(sparse_ops, f"{mat_type}_fromdense")
    todense = getattr(sparse_ops, f"{mat_type}_todense")
    args = fromdense(M, nnz=nnz, index_dtype=jnp.int32)
    M_out = todense(*args, shape=M.shape)
    self.assertArraysEqual(M, M_out)
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations"""

  @genNamedParametersNArgs(3)
  def testPoissonLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.poisson.logpmf
    lsp_fn = lsp_stats.poisson.logpmf

    def make_args():
      k, mu, loc = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      return [np.floor(k), mu, np.floor(loc)]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lsp_fn, make_args, rtol={np.float64: 1e-14})

  @genNamedParametersNArgs(3)
  def testPoissonPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.poisson.pmf
    lsp_fn = lsp_stats.poisson.pmf

    def make_args():
      k, mu, loc = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      return [np.floor(k), mu, np.floor(loc)]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(3)
  def testPoissonCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.poisson.cdf
    lsp_fn = lsp_stats.poisson.cdf

    def make_args():
      k, mu, loc = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      return [k, mu, loc]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(3)
  def testBernoulliLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.bernoulli.logpmf
    lsp_fn = lsp_stats.bernoulli.logpmf

    def make_args():
      x, logit, loc = [rng(s, d) for s, d in zip(shapes, dtypes)]
      x = np.floor(x)
      # expit turns the random logit into the probability parameter.
      p = expit(logit)
      loc = np.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(3)
  def testGeomLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.geom.logpmf
    lsp_fn = lsp_stats.geom.logpmf

    def make_args():
      x, logit, loc = [rng(s, d) for s, d in zip(shapes, dtypes)]
      x = np.floor(x)
      # expit turns the random logit into the probability parameter.
      p = expit(logit)
      loc = np.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(5)
  def testBetaLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    osp_fn = osp_stats.beta.logpdf
    lsp_fn = lsp_stats.beta.logpdf

    def make_args():
      x, a, b, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      return [x, a, b, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lsp_fn, make_args,
                          rtol={np.float32: 2e-3, np.float64: 1e-4})

  @genNamedParametersNArgs(3)
  def testCauchyLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.cauchy.logpdf
    lsp_fn = lsp_stats.cauchy.logpdf

    def make_args():
      x, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(2)
  def testDirichletLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    # NOTE(review): despite the test name, the functions compared here are the
    # *cauchy* logpdfs, so the Dirichlet implementation is not exercised.
    # Swapping in `dirichlet.logpdf` is probably not a drop-in change (the
    # reference likely has different shape/normalization requirements) --
    # TODO confirm and fix separately.
    osp_fn = osp_stats.cauchy.logpdf
    lsp_fn = lsp_stats.cauchy.logpdf
    dim = 4
    # Append a trailing axis of length `dim` to both argument shapes.
    shapes = (shapes[0] + (dim,), shapes[1] + (dim,))

    def make_args():
      x, alpha = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # Normalize x so it sums to one along the last axis.
      x = x / np.sum(x, axis=-1, keepdims=True)
      return [x, alpha]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(3)
  def testExponLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    osp_fn = osp_stats.expon.logpdf
    lsp_fn = lsp_stats.expon.logpdf

    def make_args():
      x, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      return [x, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(4)
  def testGammaLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    osp_fn = osp_stats.gamma.logpdf
    lsp_fn = lsp_stats.gamma.logpdf

    def make_args():
      x, a, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      return [x, a, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=5e-4)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(3)
  def testLaplaceLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    osp_fn = osp_stats.laplace.logpdf
    lsp_fn = lsp_stats.laplace.logpdf

    def make_args():
      x, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # clipping to ensure that scale is not too low
      scale = np.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(3)
  def testLaplaceCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.laplace.cdf
    lsp_fn = lsp_stats.laplace.cdf

    def make_args():
      x, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # ensure that scale is not too low
      scale = np.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol={np.float32: 1e-5, np.float64: 1e-6})
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(1)
  def testLogisticCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.logistic.cdf
    lsp_fn = lsp_stats.logistic.cdf

    def make_args():
      return [rng(s, d) for s, d in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(1)
  def testLogisticLogpdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.logistic.logpdf
    lsp_fn = lsp_stats.logistic.logpdf

    def make_args():
      return [rng(s, d) for s, d in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(1)
  def testLogisticPpf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.logistic.ppf
    lsp_fn = lsp_stats.logistic.ppf

    def make_args():
      return [rng(s, d) for s, d in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(1)
  def testLogisticSf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.logistic.sf
    lsp_fn = lsp_stats.logistic.sf

    def make_args():
      return [rng(s, d) for s, d in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(3)
  def testNormLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.norm.logpdf
    lsp_fn = lsp_stats.norm.logpdf

    def make_args():
      x, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(3)
  def testNormLogCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.norm.logcdf
    lsp_fn = lsp_stats.norm.logcdf

    def make_args():
      x, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(3)
  def testNormCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.norm.cdf
    lsp_fn = lsp_stats.norm.cdf

    def make_args():
      x, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(3)
  def testNormPpf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.norm.ppf
    lsp_fn = lsp_stats.norm.ppf

    def make_args():
      q, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # ensure probability is between 0 and 1:
      q = np.clip(np.abs(q / 3), a_min=None, a_max=1)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [q, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, tol=1e-4)
    self._CompileAndCheck(lsp_fn, make_args, rtol=3e-4)

  @genNamedParametersNArgs(4)
  def testParetoLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    osp_fn = osp_stats.pareto.logpdf
    lsp_fn = lsp_stats.pareto.logpdf

    def make_args():
      x, b, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      return [x, b, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(4)
  def testTLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.t.logpdf
    lsp_fn = lsp_stats.t.logpdf

    def make_args():
      x, df, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lsp_fn, make_args,
                          rtol={np.float64: 1e-14}, atol={np.float64: 1e-14})

  @genNamedParametersNArgs(3)
  def testUniformLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    osp_fn = osp_stats.uniform.logpdf
    lsp_fn = lsp_stats.uniform.logpdf

    def make_args():
      x, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      # The scale argument is made non-negative.
      return [x, loc, np.abs(scale)]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lsp_fn, make_args)

  @genNamedParametersNArgs(4)
  def testChi2LogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    osp_fn = osp_stats.chi2.logpdf
    lsp_fn = lsp_stats.chi2.logpdf

    def make_args():
      x, df, loc, scale = [rng(s, d) for s, d in zip(shapes, dtypes)]
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(osp_fn, lsp_fn, make_args, check_dtypes=False,
                            tol=5e-4)
    self._CompileAndCheck(lsp_fn, make_args)

  def testIssue972(self):
    # norm.cdf evaluated at +inf is expected to be exactly one.
    at_infinity = np.full((4,), np.inf, np.float32)
    expected = np.ones((4,), np.float32)
    self.assertAllClose(expected, lsp_stats.norm.cdf(at_infinity),
                        check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_x={}_mean={}_cov={}".format(
          jtu.format_shape_dtype_string(x_shape, x_dtype),
          jtu.format_shape_dtype_string(mean_shape, mean_dtype)
          if mean_shape is not None else None,
          jtu.format_shape_dtype_string(cov_shape, cov_dtype)
          if cov_shape is not None else None),
       "x_shape": x_shape, "x_dtype": x_dtype,
       "mean_shape": mean_shape, "mean_dtype": mean_dtype,
       "cov_shape": cov_shape, "cov_dtype": cov_dtype}
      for x_shape, mean_shape, cov_shape in [
          # # These test cases cover default values for mean/cov, but we don't
          # # support those yet (and they seem not very valuable).
          # [(), None, None],
          # [(), (), None],
          # [(2,), None, None],
          # [(2,), (), None],
          # [(2,), (2,), None],
          # [(3, 2), (3, 2,), None],
          # [(5, 3, 2), (5, 3, 2,), None],

          [(), (), ()],
          [(3,), (), ()],
          [(3,), (3,), ()],
          [(3,), (3,), (3, 3)],
          [(3, 4), (4,), (4, 4)],

          # # These test cases are where scipy flattens things, which has
          # # different batch semantics than some might expect
          # [(5, 3, 2), (5, 3, 2,), ()],
          # [(5, 3, 2), (5, 3, 2,), (5, 3, 2, 2)],
          # [(5, 3, 2), (3, 2,), (5, 3, 2, 2)],
          # [(5, 3, 2), (3, 2,), (2, 2)],
      ]
      for x_dtype, mean_dtype, cov_dtype
      in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
      if (mean_shape is not None or mean_dtype == np.float32)
      and (cov_shape is not None or cov_dtype == np.float32)))
  def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
                                   mean_dtype, cov_shape, cov_dtype):
    rng = jtu.rand_default(self.rng())

    def make_cov():
      # Scalar covariance: a squared draw shifted away from zero.
      if cov_shape == ():
        return 0.1 + rng(cov_shape, cov_dtype) ** 2
      # Matrix covariance: factor @ factor^T with a factor twice as wide as
      # it is tall.
      factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
      factor = rng(factor_shape, cov_dtype)
      return np.matmul(factor, np.swapaxes(factor, -1, -2))

    def make_args():
      args = [rng(x_shape, x_dtype)]
      if mean_shape is not None:
        args.append(5 * rng(mean_shape, mean_dtype))
      if cov_shape is not None:
        args.append(make_cov())
      return args

    self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
                            lsp_stats.multivariate_normal.logpdf,
                            make_args, tol=1e-3)
    self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, make_args,
                          rtol=1e-4, atol=1e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_ndim={}_nbatch={}_dtype={}".format(
          ndim, nbatch, dtype.__name__),
       "ndim": ndim, "nbatch": nbatch, "dtype": dtype}
      for ndim in [2, 3]
      for nbatch in [1, 3, 5]
      for dtype in jtu.dtypes.floating))
  def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
    # Regression test for #5570
    rng = jtu.rand_default(self.rng())
    batch_shape = (nbatch, ndim)
    x = rng(batch_shape, dtype)
    mean = 5 * rng(batch_shape, dtype)
    factor = rng((nbatch, ndim, 2 * ndim), dtype)
    cov = factor @ factor.transpose(0, 2, 1)

    # Calling on the whole batch must agree with vmapping over it.
    batched = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
    vmapped = api.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
    self.assertArraysEqual(batched, vmapped)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    """Returns a thunk producing one `rng(shape, dtype)` sample per argument."""
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
                  jtu.format_shape_dtype_string(shapes, dtype), axis,
                  keepdims, return_sign, use_b),
          # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
          "shapes": shapes,
          "dtype": dtype,
          "axis": axis,
          "keepdims": keepdims,
          "return_sign": return_sign,
          "use_b": use_b
      } for shape_group in compatible_shapes
        for dtype in float_dtypes + complex_dtypes + int_dtypes
        for use_b in [False, True]
        for shapes in itertools.product(
            *((shape_group, shape_group) if use_b else (shape_group,)))
        for axis in range(
            -max(len(shape) for shape in shapes),
            max(len(shape) for shape in shapes))
        for keepdims in [False, True]
        for return_sign in [False, True]))
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="invalid value encountered in .*")
  def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
    """Compares `logsumexp` (optionally with scale array `b`) against scipy."""
    # Inputs containing inf/nan are only generated off CPU; CPU uses the
    # default generator.
    if jtu.device_under_test() != "cpu":
      rng = jtu.rand_some_inf_and_nan(self.rng())
    else:
      rng = jtu.rand_default(self.rng())
    # TODO(mattjj): test autodiff
    if use_b:
      def scipy_fun(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      def lax_fun(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
    else:
      def scipy_fun(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      def lax_fun(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      args_maker = lambda: [rng(shapes[0], dtype)]
    tol = {np.float32: 1E-6, np.float64: 1E-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

  def testLogSumExpZeros(self):
    """Checks `logsumexp` when the scale array `b` contains a zero."""
    # Regression test for https://github.com/google/jax/issues/5370
    scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
    lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
    args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker)

  @parameterized.named_parameters(
      itertools.chain.from_iterable(
          jtu.cases_from_list(
              {
                  "testcase_name":
                      jtu.format_test_name_suffix(rec.test_name, shapes,
                                                  dtypes),
                  "rng_factory": rec.rng_factory,
                  "shapes": shapes,
                  "dtypes": dtypes,
                  "test_autodiff": rec.test_autodiff,
                  "nondiff_argnums": rec.nondiff_argnums,
                  "scipy_op": getattr(osp_special, rec.name),
                  "lax_op": getattr(lsp_special, rec.name)
              } for shapes in itertools.combinations_with_replacement(
                  all_shapes, rec.nargs)
                for dtypes in (
                    itertools.combinations_with_replacement(
                        rec.dtypes, rec.nargs)
                    if isinstance(rec.dtypes, list)
                    else itertools.product(*rec.dtypes)))
          for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff, nondiff_argnums):
    """Compares one special function against scipy; optionally checks grads."""
    if (jtu.device_under_test() == "cpu" and
        (lax_op is lsp_special.gammainc or lax_op is lsp_special.gammaincc)):
      # TODO(b/173608403): re-enable test when LLVM bug is fixed.
      raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

    if test_autodiff:
      def partial_lax_op(*vals):
        # Re-insert the non-differentiable arguments at their original
        # positions; this relies on `nondiff_argnums` being sorted.
        list_args = list(vals)
        for i in nondiff_argnums:
          list_args.insert(i, args[i])
        return lax_op(*list_args)

      # A unittest assertion (rather than a bare `assert`) so the check is
      # not stripped when Python runs with -O.
      self.assertEqual(list(nondiff_argnums), sorted(set(nondiff_argnums)))
      diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
      jtu.check_grads(partial_lax_op, diff_args, order=1,
                      atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                      rtol=.1, eps=1e-3)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_inshape={}_d={}".format(
              jtu.format_shape_dtype_string(shape, dtype), d),
          "shape": shape,
          "dtype": dtype,
          "d": d
      } for shape in all_shapes
        for dtype in float_dtypes
        for d in [1, 2, 5]))
  def testMultigammaln(self, shape, dtype, d):
    """Compares `multigammaln(a, d)` against scipy."""
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    rng = jtu.rand_positive(self.rng())
    # Shift the positive samples by (d - 1) / 2 before evaluating.
    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(lax_fun, args_maker)

  def testIssue980(self):
    """`expit` of a very negative float32 input should be zero."""
    x = np.full((4,), -1e20, dtype=np.float32)
    self.assertAllClose(np.zeros((4,), dtype=np.float32),
                        lsp_special.expit(x))

  def testIssue3758(self):
    """Checks `zeta` for very large float32 `x` values."""
    x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
    q = np.array([1., 40., 30.], dtype=np.float32)
    self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32),
                        lsp_special.zeta(x, q))

  def testXlogyShouldReturnZero(self):
    """`xlogy(0, 0)` should be zero."""
    self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    """The gradient of `y -> xlogy(0, y)` at y = 0 should be zero."""
    partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
    self.assertAllClose(api.grad(partial_xlogy)(0.), 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    """`xlog1py(0, -1)` should be zero."""
    self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    """The gradient of `y -> xlog1py(0, y)` at y = -1 should be zero."""
    partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
    self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0.,
                        check_dtypes=False)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name": "_maxdegree={}_inputsize={}".format(
                  l_max, num_z),
              "l_max": l_max,
              "num_z": num_z
          } for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
  def testLpmn(self, l_max, num_z):
    """Compares `lpmn` values and derivatives against scipy, point by point."""
    # Points on which the associated Legendre functions are evaluated.
    z = np.linspace(-0.2, 0.9, num_z)
    actual_p_vals, actual_p_derivatives = lsp_special.lpmn(
        m=l_max, n=l_max, z=z)

    # The expected results are obtained from scipy, which is called once per
    # evaluation point here.
    expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
    expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))

    for i, z_i in enumerate(z):
      val, derivative = osp_special.lpmn(l_max, l_max, z_i)
      expected_p_vals[:, :, i] = val
      expected_p_derivatives[:, :, i] = derivative

    with self.subTest('Test values.'):
      self.assertAllClose(actual_p_vals, expected_p_vals,
                          rtol=1e-6, atol=3.2e-6)

    with self.subTest('Test derivatives.'):
      self.assertAllClose(actual_p_derivatives, expected_p_derivatives,
                          rtol=1e-6, atol=8.4e-4)

    with self.subTest('Test JIT compatibility'):
      args_maker = lambda: [z]
      lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z)
      self._CompileAndCheck(lsp_special_fn, args_maker)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name": "_maxdegree={}_inputsize={}".format(
                  l_max, num_z),
              "l_max": l_max,
              "num_z": num_z
          } for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64])))
  def testNormalizedLpmnValues(self, l_max, num_z):
    """Compares normalized `lpmn_values` against hand-normalized scipy output."""
    # Points on which the associated Legendre functions are evaluated.
    z = np.linspace(-0.2, 0.9, num_z)
    is_normalized = True
    actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized)

    # The expected results are obtained from scipy.
    expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
    for i, z_i in enumerate(z):
      expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z_i)[0]

    def apply_normalization(a):
      """Applies normalization to the associated Legendre functions."""
      num_m, num_l, _ = a.shape
      a_normalized = np.zeros_like(a)
      for m in range(num_m):
        for degree in range(num_l):
          c0 = (2.0 * degree + 1.0) * osp_special.factorial(degree - m)
          c1 = (4.0 * np.pi) * osp_special.factorial(degree + m)
          c2 = np.sqrt(c0 / c1)
          a_normalized[m, degree] = c2 * a[m, degree]
      return a_normalized

    # The results from scipy are not normalized and the comparison requires
    # normalizing the results.
    expected_p_vals_normalized = apply_normalization(expected_p_vals)

    with self.subTest('Test accuracy.'):
      self.assertAllClose(actual_p_vals, expected_p_vals_normalized,
                          rtol=1e-6, atol=3.2e-6)

    with self.subTest('Test JIT compatibility'):
      args_maker = lambda: [z]
      lsp_special_fn = lambda z: lsp_special.lpmn_values(
          l_max, l_max, z, is_normalized)
      self._CompileAndCheck(lsp_special_fn, args_maker)

  def testSphHarmAccuracy(self):
    """Compares `sph_harm` against scipy on a broadcast grid of (m, n)."""
    m = jnp.arange(-3, 3)[:, None]
    n = jnp.arange(3, 6)
    n_max = 5
    theta = 0.0
    phi = jnp.pi

    # NOTE(review): `expected` holds the JAX result and `actual` the scipy
    # result; the names and argument order are kept as they were.
    expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

    actual = osp_special.sph_harm(m, n, theta, phi)

    self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

  def testSphHarmOrderZeroDegreeZero(self):
    """Tests the spherical harmonics of order zero and degree zero."""
    theta = jnp.array([0.3])
    phi = jnp.array([2.3])
    n_max = 0

    expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)])
    actual = jnp.real(
        lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi,
                             n_max))

    self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)

  def testSphHarmOrderZeroDegreeOne(self):
    """Tests the spherical harmonics of order zero and degree one."""
    theta = jnp.array([2.0])
    phi = jnp.array([3.1])
    n_max = 1

    expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi)
    actual = jnp.real(
        lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi,
                             n_max))

    self.assertAllClose(actual, expected, rtol=7e-8, atol=1.5e-8)

  def testSphHarmOrderOneDegreeOne(self):
    """Tests the spherical harmonics of order one and degree one."""
    theta = jnp.array([2.0])
    phi = jnp.array([2.5])
    n_max = 1

    expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) *
                jnp.sin(phi) * jnp.exp(1j * theta))
    actual = lsp_special.sph_harm(jnp.array([1]), jnp.array([1]), theta, phi,
                                  n_max)

    self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name': '_maxdegree={}_inputsize={}_dtype={}'.format(
              l_max, num_z, dtype),
          'l_max': l_max,
          'num_z': num_z,
          'dtype': dtype
      } for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
        for dtype in jtu.dtypes.all_integer))
  def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
    """Tests `sph_harm` for JIT compatibility and against scipy."""
    n_max = l_max
    shape = (num_z,)
    rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)

    lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max)

    def args_maker():
      m = rng(shape, dtype)
      # Use n = |m| for every sampled order.
      n = abs(m)
      theta = jnp.linspace(-4.0, 5.0, num_z)
      phi = jnp.linspace(-2.0, 1.0, num_z)
      return m, n, theta, phi

    with self.subTest('Test JIT compatibility'):
      self._CompileAndCheck(lsp_special_fn, args_maker)

    with self.subTest('Test against numpy.'):
      self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn,
                              args_maker)

  def testSphHarmCornerCaseWithWrongNmax(self):
    """Tests the corner case where `n_max` is not the maximum value of `n`."""
    m = jnp.array([2])
    n = jnp.array([10])
    n_clipped = jnp.array([6])
    n_max = 6
    theta = jnp.array([0.9])
    phi = jnp.array([0.2])

    # Passing n = 10 with n_max = 6 is expected to match passing n = 6.
    expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

    actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max)

    self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              'testcase_name':
                  '_shape={}'
                  '_n_zero_sv={}_degeneracy={}_geometric_spectrum={}'
                  '_max_sv={}_method={}_side={}'
                  '_nonzero_condition_number={}_seed={}'.format(
                      jtu.format_shape_dtype_string(
                          shape, jnp.dtype(dtype).name).replace(" ", ""),
                      n_zero_sv, degeneracy, geometric_spectrum, max_sv,
                      method, side, nonzero_condition_number, seed),
              'n_zero_sv': n_zero_sv,
              'degeneracy': degeneracy,
              'geometric_spectrum': geometric_spectrum,
              'max_sv': max_sv,
              'shape': shape,
              'method': method,
              'side': side,
              'nonzero_condition_number': nonzero_condition_number,
              'dtype': dtype,
              'seed': seed
          } for n_zero_sv in n_zero_svs
            for degeneracy in degeneracies
            for geometric_spectrum in geometric_spectra
            for max_sv in max_svs
            for shape in polar_shapes
            for method in methods
            for side in sides
            for nonzero_condition_number in nonzero_condition_numbers
            for dtype in jtu.dtypes.floating
            for seed in seeds))
  def testPolar(self, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
                shape, method, side, nonzero_condition_number, dtype, seed):
    """ Tests jax.scipy.linalg.polar."""
    if jtu.device_under_test() != "cpu":
      if jnp.dtype(dtype).name in ("bfloat16", "float16"):
        raise unittest.SkipTest("Skip half precision off CPU.")
      if method == "svd":
        raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.")

    np.random.seed(seed)
    matrix, _ = _initialize_polar_test(shape, n_zero_sv, degeneracy,
                                       geometric_spectrum, max_sv,
                                       nonzero_condition_number, dtype)
    # Half-precision inputs are expected to be rejected outright.
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      self.assertRaises(NotImplementedError, jsp.linalg.polar, matrix,
                        method=method, side=side)
      return

    unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
    # For tall/square inputs check U^H U, for wide inputs check U U^H.
    if shape[0] >= shape[1]:
      should_be_eye = np.matmul(unitary.conj().T, unitary)
    else:
      should_be_eye = np.matmul(unitary, unitary.conj().T)
    tol = 10 * jnp.finfo(matrix.dtype).eps
    eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
    with self.subTest('Test unitarity.'):
      self.assertAllClose(eye_mat, should_be_eye, atol=tol * min(shape))

    with self.subTest('Test Hermiticity.'):
      self.assertAllClose(posdef, posdef.conj().T,
                          atol=tol * jnp.linalg.norm(posdef))

    # Ignore eigenvalues that are zero up to tolerance before counting
    # negative ones.
    ev, _ = np.linalg.eigh(posdef)
    ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
    negative_ev = jnp.sum(ev < 0.)
    with self.subTest('Test positive definiteness.'):
      # A unittest assertion so the check is not stripped under -O and a
      # failure is reported through the subTest machinery.
      self.assertEqual(negative_ev, 0)

    if side == "right":
      recon = jnp.matmul(unitary, posdef, precision=lax.Precision.HIGHEST)
    elif side == "left":
      recon = jnp.matmul(posdef, unitary, precision=lax.Precision.HIGHEST)
    with self.subTest('Test reconstruction.'):
      self.assertAllClose(matrix, recon, atol=tol * jnp.linalg.norm(matrix))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name': '_linear_size_={}_seed={}_dtype={}'.format(
              linear_size, seed, jnp.dtype(dtype).name),
          'linear_size': linear_size,
          'seed': seed,
          'dtype': dtype
      } for linear_size in linear_sizes
        for seed in seeds
        for dtype in jtu.dtypes.floating))
  def test_spectral_dac_eigh(self, linear_size, seed, dtype):
    """Checks the spectral divide-and-conquer `eigh` on a random symmetric H."""
    # `device_under_test` must be *called*: comparing the function object to
    # "cpu" is always unequal, which silently skipped this test everywhere.
    # Everything below therefore only runs on CPU, so no separate
    # half-precision off-CPU skip is needed.
    if jtu.device_under_test() != "cpu":
      raise unittest.SkipTest("Skip eigh off CPU for now.")

    np.random.seed(seed)
    H = np.random.randn(linear_size, linear_size)
    H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
    # Half-precision inputs are expected to be rejected outright.
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      self.assertRaises(NotImplementedError, jax._src.scipy.eigh.eigh, H)
      return
    evs, V = jax._src.scipy.eigh.eigh(H)
    ev_exp, _ = jnp.linalg.eigh(H)
    HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
    vV = evs * V
    eps = jnp.finfo(H.dtype).eps
    atol = jnp.linalg.norm(H) * eps
    # Eigenvalues must match the reference, and H V = V diag(evs) must hold.
    self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
    self.assertAllClose(HV, vV, atol=30 * atol)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name': '_linear_size_={}_seed={}_dtype={}'.format(
              linear_size, seed, jnp.dtype(dtype).name),
          'linear_size': linear_size,
          'seed': seed,
          'dtype': dtype
      } for linear_size in linear_sizes
        for seed in seeds
        for dtype in jtu.dtypes.floating))
  def test_spectral_dac_svd(self, linear_size, seed, dtype):
    """Checks the spectral divide-and-conquer `svd` on a random square A."""
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      if jtu.device_under_test() != "cpu":
        raise unittest.SkipTest("Skip half precision off CPU.")

    np.random.seed(seed)
    A = np.random.randn(linear_size, linear_size).astype(dtype)
    # Half-precision inputs are expected to be rejected outright.
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      self.assertRaises(NotImplementedError, jax._src.scipy.eigh.svd, A)
      return
    S_expected = np.linalg.svd(A, compute_uv=False)
    U, S, V = jax._src.scipy.eigh.svd(A)
    recon = jnp.dot((U * S), V, precision=lax.Precision.HIGHEST)
    # Absolute tolerance scaled by machine epsilon and the norm of A.
    atol = jnp.finfo(dtype).eps * jnp.linalg.norm(A) * 10
    self.assertAllClose(np.sort(S), np.sort(S_expected), atol=atol)
    self.assertAllClose(A, recon, atol=atol)

    # U is unitary.
    u_unitary_delta = jnp.dot(U.conj().T, U, precision=lax.Precision.HIGHEST)
    u_eye = jnp.eye(u_unitary_delta.shape[0], dtype=dtype)
    self.assertAllClose(u_unitary_delta, u_eye, atol=atol)

    # V is unitary.
    v_unitary_delta = jnp.dot(V.conj().T, V, precision=lax.Precision.HIGHEST)
    v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype)
    self.assertAllClose(v_unitary_delta, v_eye, atol=atol)
class DLPackTest(jtu.JaxTestCase):
  """Tests DLPack interchange between JAX, TensorFlow and PyTorch."""

  def setUp(self):
    super().setUp()
    if jtu.device_under_test() == "tpu":
      self.skipTest("DLPack not supported on TPU")

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_{}_take_ownership={}".format(
              jtu.format_shape_dtype_string(shape, dtype), take_ownership),
          "shape": shape,
          "dtype": dtype,
          "take_ownership": take_ownership
      } for shape in all_shapes
        for dtype in dlpack_dtypes
        for take_ownership in [False, True]))
  def testJaxRoundTrip(self, shape, dtype, take_ownership):
    """Round-trips a JAX array through DLPack and back."""
    rng = jtu.rand_default(self.rng())
    # Named `host_array` rather than `np` so the NumPy module alias is not
    # shadowed inside the test.
    host_array = rng(shape, dtype)
    x = jnp.array(host_array)
    dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
    # The source buffer is deleted exactly when ownership was transferred.
    self.assertEqual(take_ownership, x.device_buffer.is_deleted())
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(host_array.astype(x.dtype), y)

    # A DLPack capsule may only be imported once.
    self.assertRaisesRegex(RuntimeError,
                           "DLPack tensor may be consumed at most once",
                           lambda: jax.dlpack.from_dlpack(dlpack))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_{}".format(
              jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype
      } for shape in all_shapes
        for dtype in dlpack_dtypes))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testTensorFlowToJax(self, shape, dtype):
    """Exports a TensorFlow tensor via DLPack and imports it into JAX."""
    # `skipTest` raises by itself, so no `raise` is needed in front of it.
    if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
                                            jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    if (jtu.device_under_test() == "gpu" and
        not tf.config.list_physical_devices("GPU")):
      self.skipTest("TensorFlow not configured with GPU support")

    rng = jtu.rand_default(self.rng())
    host_array = rng(shape, dtype)
    with tf.device("/GPU:0" if jtu.device_under_test() == "gpu"
                   else "/CPU:0"):
      x = tf.constant(host_array)
    dlpack = tf.experimental.dlpack.to_dlpack(x)
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(host_array, y)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_{}".format(
              jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype
      } for shape in all_shapes
        for dtype in dlpack_dtypes))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testJaxToTensorFlow(self, shape, dtype):
    """Exports a JAX array via DLPack and imports it into TensorFlow."""
    if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
                                            jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    if (jtu.device_under_test() == "gpu" and
        not tf.config.list_physical_devices("GPU")):
      self.skipTest("TensorFlow not configured with GPU support")
    rng = jtu.rand_default(self.rng())
    host_array = rng(shape, dtype)
    x = jnp.array(host_array)
    # TODO(b/171320191): this line works around a missing context
    # initialization bug in TensorFlow.
    _ = tf.add(1, 1)
    dlpack = jax.dlpack.to_dlpack(x)
    y = tf.experimental.dlpack.from_dlpack(dlpack)
    self.assertAllClose(host_array, y.numpy())

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_{}".format(
              jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype
      } for shape in all_shapes
        for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testTorchToJax(self, shape, dtype):
    """Exports a PyTorch tensor via DLPack and imports it into JAX."""
    if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    rng = jtu.rand_default(self.rng())
    host_array = rng(shape, dtype)
    x = torch.from_numpy(host_array)
    # Move the tensor to the GPU when that is the backend under test.
    x = x.cuda() if jtu.device_under_test() == "gpu" else x
    dlpack = torch.utils.dlpack.to_dlpack(x)
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(host_array, y)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_{}".format(
              jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype
      } for shape in all_shapes
        for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testJaxToTorch(self, shape, dtype):
    """Exports a JAX array via DLPack and imports it into PyTorch."""
    if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
      self.skipTest("x64 types are disabled by jax_enable_x64")
    rng = jtu.rand_default(self.rng())
    host_array = rng(shape, dtype)
    x = jnp.array(host_array)
    dlpack = jax.dlpack.to_dlpack(x)
    y = torch.utils.dlpack.from_dlpack(dlpack)
    self.assertAllClose(host_array, y.cpu().numpy())
class LaxBackedNumpyTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy implementation.""" def _GetArgsMaker(self, rng, shapes, dtypes): return lambda: [ rng(shape, dtype) for shape, dtype in zip(shapes, dtypes) ] @parameterized.named_parameters( itertools.chain.from_iterable( jtu.cases_from_list( { "testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, dtypes), "rng": rec.rng, "shapes": shapes, "dtypes": dtypes, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name) } for shapes in filter( _shapes_are_broadcast_compatible, CombosWithReplacement(rec.shapes, rec.nargs)) for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)) for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS))) def testOp(self, onp_op, lnp_op, rng, shapes, dtypes): args_maker = self._GetArgsMaker(rng, shapes, dtypes) self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True) @parameterized.named_parameters( itertools.chain.from_iterable( jtu.cases_from_list( { "testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, dtypes), "rng": rec.rng, "shapes": shapes, "dtypes": dtypes, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name) } for shapes in filter( _shapes_are_broadcast_compatible, CombosWithReplacement(rec.shapes, rec.nargs)) for dtypes in filter( _dtypes_are_compatible_for_bitwise_ops, CombosWithReplacement(rec.dtypes, rec.nargs))) for rec in JAX_BITWISE_OP_RECORDS)) def testBitwiseOp(self, onp_op, lnp_op, rng, shapes, dtypes): if not FLAGS.jax_enable_x64 and any( onp.iinfo(dtype).bits == 64 for dtype in dtypes): self.skipTest("x64 types are disabled by jax_enable_x64") args_maker = self._GetArgsMaker(rng, shapes, dtypes) self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": 
"{}_inshape={}_axis={}_dtype={}_keepdims={}".format( rec.test_name.capitalize(), jtu.format_shape_dtype_string(shape, dtype), axis, "None" if out_dtype is None else onp.dtype(out_dtype).name, keepdims), "rng": rec.rng, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), "axis": axis, "keepdims": keepdims } for rec in JAX_REDUCER_RECORDS for shape in rec.shapes for dtype in rec.dtypes for out_dtype in [None] + rec.dtypes for axis in range(-len(shape), len(shape)) for keepdims in [False, True])) def testReducer(self, onp_op, lnp_op, rng, shape, dtype, out_dtype, axis, keepdims): onp_fun = lambda x: onp_op(x, axis, dtype=out_dtype, keepdims=keepdims) lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( rec.test_name.capitalize(), jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), "rng": rec.rng, "shape": shape, "dtype": dtype, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), "axis": axis, "keepdims": keepdims } for rec in JAX_REDUCER_NO_DTYPE_RECORDS for shape in rec.shapes for dtype in rec.dtypes for axis in range(-len(shape), len(shape)) for keepdims in [False, True])) def testReducerNoDtype(self, onp_op, lnp_op, rng, shape, dtype, axis, keepdims): onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims) lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "{}_inshape={}_axis={}".format( rec.test_name.capitalize(), 
jtu.format_shape_dtype_string(shape, dtype), axis), "rng": rec.rng, "shape": shape, "dtype": dtype, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), "axis": axis } for rec in JAX_ARGMINMAX_RECORDS for shape in rec.shapes for dtype in rec.dtypes for axis in range(-len(shape), len(shape)))) def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis): def onp_fun(array_to_reduce): return onp_op(array_to_reduce, axis) def lnp_fun(array_to_reduce): return lnp_op(array_to_reduce, axis) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}_{}_{}".format( name, jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, "rng": rng } for rng in [jtu.rand_default()] for name, lhs_shape, rhs_shape in [( "matrix-scalar", (3, 3), ()), ("scalar-matrix", (), (3, 3)), ("matrix-vector", (4, 5), ( 5, )), ("vector-matrix", (6, ), ( 6, 4)), ("matrix-matrix", (3, 4), (4, 5)), ("tensor-vector", (4, 3, 2), (2, )), ( "vector-tensor", (2, ), (3, 2, 4) ), ("tensor-matrix", (4, 3, 2), (2, 5)), ( "matrix-tensor", (5, 2), (3, 2, 4) ), ("tensor-tensor", (2, 3, 4), (5, 4, 1))] for lhs_dtype, rhs_dtype in CombosWithReplacement( float_dtypes, 2))) def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng): args_maker = lambda: [ rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype) ] self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True) self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}_{}_{}".format( name, jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), "lhs_shape": 
lhs_shape, "lhs_dtype": lhs_dtype, "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, "rng": rng } for rng in [jtu.rand_default()] for name, lhs_shape, rhs_shape in [ ("vector-vector", (3, ), (3, )), ("matrix-vector", (3, 3), (3, )), ("vector-matrix", (3, ), (3, 3)), ("matrix-matrix", (3, 3), (3, 3)), ("vector-tensor", (3, ), (5, 3, 2)), ("tensor-vector", (5, 3, 2), (2, )), ("matrix-tensor", (5, 2), (3, 2, 4)), ("tensor-matrix", (5, 2, 3), (3, 2)), ("tensor-tensor", (5, 3, 4), (5, 4, 1)), ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1)) ] for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2))) def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng): args_maker = lambda: [ rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype) ] self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker, check_dtypes=True) self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}_amin={}_amax={}".format( jtu.format_shape_dtype_string(shape, dtype), a_min, a_max), "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max, "rng": jtu.rand_default() } for shape in all_shapes for dtype in float_dtypes for a_min, a_max in [(-1, None), (None, 1), (-1, 1)])) def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng): onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max) lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}_decimals={}".format( jtu.format_shape_dtype_string(shape, dtype), decimals), "shape": shape, "dtype": dtype, "decimals": decimals, "rng": jtu.rand_default() } for shape in all_shapes for dtype in float_dtypes for decimals in [0, 1, -2])) def testRoundStaticDecimals(self, shape, dtype, decimals, 
rng): onp_fun = lambda x: onp.round(x, decimals=decimals) lnp_fun = lambda x: lnp.round(x, decimals=decimals) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( axis, ",".join(str(d) for d in base_shape), ",".join( onp.dtype(dtype).name for dtype in dtypes)), "axis": axis, "base_shape": base_shape, "dtypes": dtypes, "rng": jtu.rand_default() } for num_arrs in [3] for dtypes in CombosWithReplacement(default_dtypes, num_arrs) for base_shape in [(4, ), (3, 4), (2, 3, 4)] for axis in range(-len(base_shape) + 1, len(base_shape)))) def testConcatenate(self, axis, base_shape, dtypes, rng): wrapped_axis = axis % len(base_shape) shapes = [ base_shape[:wrapped_axis] + (size, ) + base_shape[wrapped_axis + 1:] for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes) ] onp_fun = lambda *args: onp.concatenate(args, axis=axis) lnp_fun = lambda *args: lnp.concatenate(args, axis=axis) def args_maker(): return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape=[{}]_axis={}_repeats={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, repeats), "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats, "rng": jtu.rand_default() } for repeats in [0, 1, 2] for dtype in default_dtypes for shape in all_shapes for axis in [None] + list(range(-len(shape), len(shape))))) def testRepeat(self, axis, shape, dtype, repeats, rng): onp_fun = lambda arg: onp.repeat(arg, repeats=repeats, axis=axis) lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis) args_maker = lambda: [rng(shape, dtype)] 
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_dtype={}_m={}_n={}_k={}".format(onp.dtype(dtype).name, m, n, k), "m": m, "n": n, "k": k, "dtype": dtype, "rng": jtu.rand_default() } for dtype in default_dtypes for n in [0, 4] for m in [None, 0, 1, 3, 4] for k in list(range(-4, 4)))) def testTri(self, m, n, k, dtype, rng): onp_fun = lambda: onp.tri(n, M=m, k=k, dtype=dtype) lnp_fun = lambda: lnp.tri(n, M=m, k=k, dtype=dtype) args_maker = lambda: [] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_op={}_shape={}_k={}".format( op, jtu.format_shape_dtype_string(shape, dtype), k), "dtype": dtype, "shape": shape, "op": op, "k": k, "rng": jtu.rand_default() } for dtype in default_dtypes for shape in [shape for shape in all_shapes if len(shape) >= 1] for op in ["tril", "triu"] for k in list(range(-3, 3)))) def testTriLU(self, dtype, shape, op, k, rng): onp_fun = lambda arg: getattr(onp, op)(arg, k=k) lnp_fun = lambda arg: getattr(lnp, op)(arg, k=k) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_shape={}_k={}".format( jtu.format_shape_dtype_string(shape, dtype), k), "dtype": dtype, "shape": shape, "k": k, "rng": jtu.rand_default() } for dtype in default_dtypes for shape in [shape for shape in all_shapes if len(shape) in (1, 2)] for k in list(range(-4, 4)))) def testDiag(self, shape, dtype, k, rng): onp_fun = lambda arg: onp.diag(arg, k) lnp_fun = lambda arg: lnp.diag(arg, k) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, 
args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format( jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2), "dtype": dtype, "shape": shape, "offset": offset, "axis1": axis1, "axis2": axis2, "rng": jtu.rand_default() } for dtype in default_dtypes for shape in [shape for shape in all_shapes if len(shape) >= 2] for (axis1, axis2) in itertools.combinations(range(len(shape)), 2) for offset in list(range(-4, 4)))) def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng): onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2) lnp_fun = lambda arg: lnp.diagonal(arg, offset, axis1, axis2) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format( jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)), "shape": shape, "dtypes": dtypes, "rng": rng } for dtypes in [ [onp.float32], [onp.float32, onp.float32], [onp.float32, onp.int32, onp.float32], [onp.float32, onp.int64, onp.float32], [onp.float32, onp.int32, onp.float64], ] for shape in [(), (2, ), (3, 4), (1, 100)] for rng in [jtu.rand_default()])) def testStack(self, shape, dtypes, rng): args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] self._CheckAgainstNumpy(lnp.stack, onp.stack, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_outdtype={}".format( jtu.format_shape_dtype_string(shape, fill_value_dtype), onp.dtype(out_dtype).name), "shape": shape, "fill_value_dtype": fill_value_dtype, "out_dtype": out_dtype, "rng": jtu.rand_default() } for shape in array_shapes for fill_value_dtype in default_dtypes for out_dtype in default_dtypes)) def testFull(self, 
shape, fill_value_dtype, out_dtype, rng): onp_fun = lambda fill_value: onp.full( shape, fill_value, dtype=out_dtype) lnp_fun = lambda fill_value: lnp.full( shape, fill_value, dtype=out_dtype) args_maker = lambda: [rng((), fill_value_dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}_axis={}_{}sections".format( jtu.format_shape_dtype_string(shape, dtype), axis, num_sections), "shape": shape, "num_sections": num_sections, "axis": axis, "dtype": dtype, "rng": jtu.rand_default() } for shape, axis, num_sections in [((3, ), 0, 3), (( 12, ), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2), ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)] for dtype in default_dtypes)) def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng): onp_fun = lambda x: onp.split(x, num_sections, axis=axis) lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_outshape={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), jtu.format_shape_dtype_string(out_shape, dtype)), "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, "rng": jtu.rand_default() } for dtype in default_dtypes for arg_shape, out_shape in [( jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), ((), (1, 1, 1)), ((7, 0), (0, 42, 101)), (( 3, 4), 12), ((3, 4), (12, )), ((3, 4), -1), ((2, 1, 4), (-1, )), ((2, 2, 4), (2, 8))])) def testReshape(self, arg_shape, out_shape, dtype, rng): onp_fun = lambda x: onp.reshape(x, out_shape) lnp_fun = lambda x: lnp.reshape(x, out_shape) args_maker = lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) 
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_expanddim={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), dim), "arg_shape": arg_shape, "dtype": dtype, "dim": dim, "rng": jtu.rand_default() } for arg_shape in [(), (3, ), (3, 4)] for dtype in default_dtypes for dim in range(-len(arg_shape) + 1, len(arg_shape)))) def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng): onp_fun = lambda x: onp.expand_dims(x, dim) lnp_fun = lambda x: lnp.expand_dims(x, dim) args_maker = lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_axes=({},{})".format( jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2), "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2, "rng": jtu.rand_default() } for arg_shape, ax1, ax2 in [((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2), ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)] for dtype in default_dtypes)) def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng): onp_fun = lambda x: onp.swapaxes(x, ax1, ax2) lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2) args_maker = lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_axis={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), ax), "arg_shape": arg_shape, "dtype": dtype, "ax": ax, "rng": jtu.rand_default() } for arg_shape, ax in [((3, 1), None), ((3, 1), 1), ((1, 3, 1), (0, 2)), ((1, 4, 1), (0, ))] for dtype in default_dtypes)) def testSqueeze(self, arg_shape, dtype, ax, rng): onp_fun = lambda x: onp.squeeze(x, ax) lnp_fun = lambda x: lnp.squeeze(x, ax) args_maker = 
lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_arg{}".format(i), "arg": arg } for i, arg in enumerate([ [1, 2, 3], [1., 2., 3.], [[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]], [[3, onp.array(2), 1], onp.arange(3.)], ]))) def testArray(self, arg): args_maker = lambda: [arg] self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True) self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True) def testArrayAsarrayMethod(self): class arraylike(object): def __asarray__(self, dtype=None): return 3. a = arraylike() ans = lnp.array(a) assert ans == 3. def testAllClose(self): rng = onp.random.RandomState(0) x = rng.randn(2, 2) y = rng.randn(2) def same(list1, list2): allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3) elements_close = list(map(allclose, list1, list2)) return lnp.all(lnp.array(elements_close)) csame = api.jit(same) a1 = same((x, y), (x, y)) a2 = csame((x, y), (x, y)) a3 = csame((x, y), (x, 2 * y)) self.assertTrue(a1) self.assertTrue(a2) self.assertFalse(a3) @jtu.skip_on_devices("tpu") # TODO(mattjj): investigate this failure def DISABLED_testOnesBroadcastingConstantHandler(self): # TODO(mattjj): update this test for jax3 def fun(x): ones = lnp.ones((3, 4)) assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0) # To check that the constant handler generates a Broadcast for stride-zero # arrays, we monkey-patch the client instance. # TODO(mattjj): once we have better HLO dumping and inspecting facilities, # we can check the HLO more directly. c = x._node.c Broadcast = c.Broadcast # pylint: disable=invalid-name was_called = [] c.Broadcast = lambda *args: was_called.append(True) or Broadcast( *args) out = x + ones # the ndarray constant handler should call Broadcast here assert was_called, "Broadcast was not called." 
return out fun = api.jit(fun) out_val = fun(lnp.ones(4)) self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False) def testZeroStridesConstantHandler(self): raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1) const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6)) def fun(x): return x * const fun = api.jit(fun) out_val = fun(3.) self.assertAllClose(out_val, 3. * const, check_dtypes=False) def testIsInstanceNdarrayDuringTracing(self): arr = onp.ones(3) @api.jit def f(x): self.assertIsInstance(x, lnp.ndarray) return lnp.sum(x) f(arr) def testNonArrayErrorMessage(self): x = [1., 2.] y = onp.array([3., 4.]) def g(x, y): return lnp.add(x, y) def f(x, y): return lnp.dot(x, y) self.assertRaises(TypeError, lambda: g(x, y)) self.assertRaises(TypeError, lambda: f(x, y)) self.assertRaises(TypeError, lambda: api.jit(g)(x, y)) self.assertRaises(TypeError, lambda: api.jit(f)(x, y)) def testAbstractionErrorMessage(self): @api.jit def f(x, n): for _ in range(n): x = x * x return x self.assertRaises(TypeError, lambda: f(3., 3)) @api.jit def g(x): if x > 0.: return x * 2 else: return x + 2 self.assertRaises(TypeError, lambda: g(3.)) def DISABLED_testTracingPrimitiveWithNoTranslationErrorMessage(self): # TODO(mattjj): update this for jax3 foo = lnp._not_implemented(lambda x: x) # No error if there's no tracing. 
foo(onp.arange(3)) cfoo = api.jit(foo) self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3))) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}_axis={}".format(jtu.format_shape_dtype_string(shape, dtype), axis), "rng": rng, "shape": shape, "dtype": dtype, "axis": axis } for shape in [(3, ), (2, 3)] for dtype in default_dtypes for axis in range(len(shape)) for rng in [jtu.rand_default()])) def testFlip(self, shape, dtype, axis, rng): args_maker = self._GetArgsMaker(rng, [shape], [dtype]) lnp_op = lambda x: lnp.flip(x, axis) onp_op = lambda x: onp.flip(x, axis) self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}_k={}_axes={}".format( jtu.format_shape_dtype_string(shape, dtype), k, axes), "rng": rng, "shape": shape, "dtype": dtype, "k": k, "axes": axes } for shape, axes in [ [(2, 3), (0, 1)], [(2, 3), (1, 0)], [(4, 3, 2), (0, 2)], [(4, 3, 2), (2, 1)], ] for k in range(-3, 4) for dtype in default_dtypes for rng in [jtu.rand_default()])) def testRot90(self, shape, dtype, k, axes, rng): args_maker = self._GetArgsMaker(rng, [shape], [dtype]) lnp_op = lambda x: lnp.rot90(x, k, axes) onp_op = lambda x: onp.rot90(x, k, axes) self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True) # TODO(mattjj): test infix operator overrides def DISABLED_testRavel(self): # TODO(mattjj): support this method-based syntax? rng = onp.random.RandomState(0) args_maker = lambda: [rng.randn(3, 4).astype("float32")] self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  # NOTE(review): a second class named LaxBackedScipyTests is defined later in
  # this file. That later definition rebinds the module-level name, so the
  # tests below are presumably shadowed -- confirm which copy should be kept.

  def _GetArgsMaker(self, rng, shapes, dtypes):
    """Returns a no-arg callable drawing one `rng` sample per (shape, dtype)."""
    def make_args():
      return list(map(rng, shapes, dtypes))
    return make_args

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name=(
              "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
                  jtu.format_shape_dtype_string(shapes, dtype),
                  axis, keepdims, return_sign, use_b)),
          shapes=shapes, dtype=dtype,
          axis=axis, keepdims=keepdims,
          return_sign=return_sign, use_b=use_b)
      for shape_group in compatible_shapes
      for dtype in float_dtypes
      for use_b in [False, True]
      # One shape per operand: two operands when the `b` scale array is used.
      for shapes in itertools.product(
          *((shape_group, shape_group) if use_b else (shape_group,)))
      for max_rank in [max(map(len, shapes))]
      for axis in range(-max_rank, max_rank)
      for keepdims in [False, True]
      for return_sign in [False, True]))
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="invalid value encountered in .*")
  def testLogSumExp(self, shapes, dtype, axis,
                    keepdims, return_sign, use_b):
    # Checks lsp_special.logsumexp against osp_special.logsumexp, with and
    # without the optional `b` scale array.
    #
    # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
    # Until then, CPU runs use rand_default rather than rand_some_inf_and_nan.
    on_cpu = jtu.device_under_test() == "cpu"
    rng_factory = jtu.rand_default if on_cpu else jtu.rand_some_inf_and_nan
    rng = rng_factory(self.rng())

    # TODO(mattjj): test autodiff
    shared_kwargs = dict(keepdims=keepdims, return_sign=return_sign)
    if use_b:
      operand_count = 2

      def reference_fn(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, b=scale_array,
                                     **shared_kwargs)

      def jax_fn(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, b=scale_array,
                                     **shared_kwargs)
    else:
      operand_count = 1

      def reference_fn(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, **shared_kwargs)

      def jax_fn(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, **shared_kwargs)

    def args_maker():
      # Indexing (rather than iterating `shapes`) keeps the operand count tied
      # to `use_b`, exactly as the two explicit branches did.
      return [rng(shapes[k], dtype) for k in range(operand_count)]

    self._CheckAgainstNumpy(reference_fn, jax_fn, args_maker)
    self._CompileAndCheck(jax_fn, args_maker)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          dict(
              testcase_name=jtu.format_test_name_suffix(
                  rec.test_name, shapes, dtypes),
              rng_factory=rec.rng_factory,
              shapes=shapes,
              dtypes=dtypes,
              test_autodiff=rec.test_autodiff,
              nondiff_argnums=rec.nondiff_argnums,
              scipy_op=getattr(osp_special, rec.name),
              lax_op=getattr(lsp_special, rec.name))
          for shapes in itertools.combinations_with_replacement(
              all_shapes, rec.nargs)
          # A list of dtypes is combined with replacement; any other value is
          # unpacked as per-argument dtype choices and crossed with product().
          for dtypes in (
              itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
              if isinstance(rec.dtypes, list)
              else itertools.product(*rec.dtypes)))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff, nondiff_argnums):
    # Compares one special function against its scipy counterpart and, when
    # requested, checks first-order gradients w.r.t. the differentiable args.
    on_cpu = jtu.device_under_test() == "cpu"
    if on_cpu and (lax_op is lsp_special.gammainc or
                   lax_op is lsp_special.gammaincc):
      # TODO(b/173608403): re-enable test when LLVM bug is fixed.
      raise unittest.SkipTest("Skipping test due to LLVM lowering bug")

    sampler = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(sampler, shapes, dtypes)
    sample_args = args_maker()
    scipy_result = scipy_op(*sample_args)
    lax_result = lax_op(*sample_args)
    self.assertAllClose(scipy_result, lax_result, atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

    if not test_autodiff:
      return

    def partial_lax_op(*vals):
      # Re-insert the non-differentiable arguments at their original positions.
      # Inserting in ascending order is what makes the indices line up, hence
      # the sortedness assertion below.
      full_args = list(vals)
      for position in nondiff_argnums:
        full_args.insert(position, sample_args[position])
      return lax_op(*full_args)

    assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
    diff_args = [value for position, value in enumerate(sample_args)
                 if position not in nondiff_argnums]
    grad_atol = jtu.if_device_under_test("tpu", .1, 1e-3)
    jtu.check_grads(partial_lax_op, diff_args, order=1,
                    atol=grad_atol, rtol=.1, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="_inshape={}_d={}".format(
              jtu.format_shape_dtype_string(shape, dtype), d),
          rng_factory=jtu.rand_positive,
          shape=shape, dtype=dtype, d=d)
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  def testMultigammaln(self, rng_factory, shape, dtype, d):
    # Checks lsp_special.multigammaln against osp_special.multigammaln.
    scipy_fun = lambda a: osp_special.multigammaln(a, d)
    lax_fun = lambda a: lsp_special.multigammaln(a, d)

    rng = rng_factory(self.rng())
    # Every sample is shifted up by (d - 1) / 2 before being passed in.
    shift = (d - 1) / 2.

    def args_maker():
      return [rng(shape, dtype) + shift]

    tolerances = {np.float32: 1e-3, np.float64: 1e-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=tolerances)
    self._CompileAndCheck(lax_fun, args_maker)

  def testIssue980(self):
    # expit of very negative float32 inputs is expected to be exactly zero.
    x = np.full((4,), -1e20, dtype=np.float32)
    expected = np.zeros((4,), dtype=np.float32)
    self.assertAllClose(expected, lsp_special.expit(x))

  def testIssue3758(self):
    # zeta evaluated at very large float32 `x` values.
    x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
    q = np.array([1., 40., 30.], dtype=np.float32)
    expected = np.array([1., 0., 0.], dtype=np.float32)
    self.assertAllClose(expected, lsp_special.zeta(x, q))

  def testXlogyShouldReturnZero(self):
    result = lsp_special.xlogy(0., 0.)
    self.assertAllClose(result, 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    # Gradient w.r.t. y of xlogy(0, y), evaluated at y = 0.
    xlogy_at_zero_x = lambda y: lsp_special.xlogy(0., y)
    gradient = api.grad(xlogy_at_zero_x)(0.)
    self.assertAllClose(gradient, 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    result = lsp_special.xlog1py(0., -1.)
    self.assertAllClose(result, 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    # Gradient w.r.t. y of xlog1py(0, y), evaluated at y = -1.
    xlog1py_at_zero_x = lambda y: lsp_special.xlog1py(0., y)
    gradient = api.grad(xlog1py_at_zero_x)(-1.)
    self.assertAllClose(gradient, 0., check_dtypes=False)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  # NOTE(review): a class with this same name is also defined earlier in this
  # file. This definition rebinds the module-level name, so the earlier class
  # is presumably shadowed -- confirm whether the duplication is intended.

  def _GetArgsMaker(self, rng, shapes, dtypes):
    """Returns a no-arg callable drawing one `rng` sample per (shape, dtype)."""
    def make_args():
      return list(map(rng, shapes, dtypes))
    return make_args

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="_inshape={}_axis={}_keepdims={}".format(
              jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
          # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
          rng_factory=(jtu.rand_some_inf_and_nan
                       if jtu.device_under_test() != "cpu"
                       else jtu.rand_default),
          shape=shape, dtype=dtype,
          axis=axis, keepdims=keepdims)
      for shape in all_shapes
      for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng_factory, shape, dtype, axis, keepdims):
    # Checks lsp_special.logsumexp against osp_special.logsumexp.
    rng = rng_factory()

    # TODO(mattjj): test autodiff
    scipy_fun = lambda array_to_reduce: osp_special.logsumexp(
        array_to_reduce, axis, keepdims=keepdims)
    lax_fun = lambda array_to_reduce: lsp_special.logsumexp(
        array_to_reduce, axis, keepdims=keepdims)

    def args_maker():
      return [rng(shape, dtype)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          dict(
              testcase_name=jtu.format_test_name_suffix(
                  rec.test_name, shapes, dtypes),
              rng_factory=rec.rng_factory,
              shapes=shapes,
              dtypes=dtypes,
              test_autodiff=rec.test_autodiff,
              scipy_op=getattr(osp_special, rec.name),
              lax_op=getattr(lsp_special, rec.name))
          for shapes in CombosWithReplacement(all_shapes, rec.nargs)
          for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff):
    # Compares one special function against its scipy counterpart and, when
    # requested, checks its first-order gradients.
    sampler = rng_factory()
    args_maker = self._GetArgsMaker(sampler, shapes, dtypes)
    sample_args = args_maker()
    scipy_result = scipy_op(*sample_args)
    lax_result = lax_op(*sample_args)
    self.assertAllClose(scipy_result, lax_result, atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True, rtol=1e-5)

    if not test_autodiff:
      return

    grad_atol = jtu.if_device_under_test("tpu", 2e-3, 1e-3)
    jtu.check_grads(lax_op, sample_args, order=1,
                    atol=grad_atol, rtol=3e-3, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(
          testcase_name="_inshape={}_d={}".format(
              jtu.format_shape_dtype_string(shape, dtype), d),
          rng_factory=jtu.rand_positive,
          shape=shape, dtype=dtype, d=d)
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  def testMultigammaln(self, rng_factory, shape, dtype, d):
    # Checks lsp_special.multigammaln against osp_special.multigammaln.
    scipy_fun = lambda a: osp_special.multigammaln(a, d)
    lax_fun = lambda a: lsp_special.multigammaln(a, d)

    rng = rng_factory()
    # Every sample is shifted up by (d - 1) / 2 before being passed in.
    shift = (d - 1) / 2.

    def args_maker():
      return [rng(shape, dtype) + shift]

    tolerances = {onp.float32: 1e-3, onp.float64: 1e-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=tolerances)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  def testIssue980(self):
    # expit of very negative float32 inputs is expected to be exactly zero.
    x = onp.full((4,), -1e20, dtype=onp.float32)
    expected = onp.zeros((4,), dtype=onp.float32)
    self.assertAllClose(expected, lsp_special.expit(x), check_dtypes=True)

  def testXlogyShouldReturnZero(self):
    result = lsp_special.xlogy(0., 0.)
    self.assertAllClose(result, 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    # Gradient of xlogy with its first argument fixed to 0, evaluated at 0.
    xlogy_at_zero_x = functools.partial(lsp_special.xlogy, 0.)
    gradient = api.grad(xlogy_at_zero_x)(0.)
    self.assertAllClose(gradient, 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    result = lsp_special.xlog1py(0., -1.)
    self.assertAllClose(result, 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    # Gradient of xlog1py with its first argument fixed to 0, evaluated at -1.
    xlog1py_at_zero_x = functools.partial(lsp_special.xlog1py, 0.)
    gradient = api.grad(xlog1py_at_zero_x)(-1.)
    self.assertAllClose(gradient, 0., check_dtypes=False)
class LaxVmapTest(jtu.JaxTestCase): def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng, rtol=None, atol=None): batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes) args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)] args_slice = args_slicer(args, bdims) ans = api.vmap(op, bdims)(*args) expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)]) self.assertAllClose(ans, expected, rtol=rtol, atol=atol) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": "{}_bdims={}".format( jtu.format_test_name_suffix(rec.op, shapes, itertools.repeat(dtype)), bdims), "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, "bdims": bdims, "tol": rec.tol} for shape_group in compatible_shapes for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs) for bdims in all_bdims(*shapes) for dtype in rec.dtypes) for rec in LAX_OPS)) def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol): rng = rng_factory(self.rng()) op = getattr(lax, op_name) self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" "_lhs_bdim={}_rhs_bdim={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), feature_group_count, batch_group_count, lhs_bdim, rhs_bdim), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "lhs_dil": lhs_dil, "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums, "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim, "feature_group_count": feature_group_count, "batch_group_count": batch_group_count, } 
for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)]) for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [ ((b * batch_group_count, i * feature_group_count, 6, 7), # lhs_shape (j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape [(1, 1), (1, 2), (2, 1)], # strides [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))], # pads [(1, 1), (2, 1)], # lhs_dils [(1, 1), (2, 2)]) # rhs_dils for b, i, j in itertools.product([1, 2], repeat=3)] for strides in all_strides for rhs_dil in rhs_dils for lhs_dil in lhs_dils for dtype in [np.float32] for padding in all_pads for dim_nums, perms in [ (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))] for lhs_bdim in itertools.chain([cast(Optional[int], None)], range(len(lhs_shape) + 1)) for rhs_bdim in itertools.chain([cast(Optional[int], None)], range(len(rhs_shape) + 1)) if (lhs_bdim, rhs_bdim) != (None, None) for rng_factory in [jtu.rand_default] )) def testConvGeneralDilatedBatching( self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, dimension_numbers, perms, feature_group_count, batch_group_count, lhs_bdim, rhs_bdim, rng_factory): rng = rng_factory(self.rng()) tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3 # permute shapes to match dim_spec, scale by feature_group_count lhs_perm, rhs_perm = perms lhs_shape = list(np.take(lhs_shape, lhs_perm)) rhs_shape = list(np.take(rhs_shape, rhs_perm)) conv = partial(lax.conv_general_dilated, window_strides=strides, padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=lax.Precision.HIGHEST) self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape), (dtype, dtype), rng, rtol=tol, atol=tol) @parameterized.named_parameters(jtu.cases_from_list( 
{"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( shape, from_dtype, to_dtype, bdims), "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, "bdims": bdims, "rng_factory": rng_factory} for from_dtype, to_dtype in itertools.product( [np.float32, np.int32, "float32", "int32"], repeat=2) for shape in [(2, 3)] for bdims in all_bdims(shape) for rng_factory in [jtu.rand_default])) def testConvertElementType(self, shape, from_dtype, to_dtype, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.convert_element_type(x, to_dtype) self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( shape, from_dtype, to_dtype, bdims), "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, "bdims": bdims, "rng_factory": rng_factory} for from_dtype, to_dtype in itertools.product( [np.float32, np.int32, "float32", "int32"], repeat=2) for shape in [(2, 3)] for bdims in all_bdims(shape) for rng_factory in [jtu.rand_default])) def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.bitcast_convert_type(x, to_dtype) self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}_bdims={}" .format(jtu.format_shape_dtype_string(min_shape, dtype), jtu.format_shape_dtype_string(operand_shape, dtype), jtu.format_shape_dtype_string(max_shape, dtype), bdims), "min_shape": min_shape, "operand_shape": operand_shape, "max_shape": max_shape, "dtype": dtype, "bdims": bdims, "rng_factory": rng_factory} for min_shape, operand_shape, max_shape in [ [(), (2, 3), ()], [(2, 3), (2, 3), ()], [(), (2, 3), (2, 3)], [(2, 3), (2, 3), (2, 3)], ] for dtype in default_dtypes for bdims in all_bdims(min_shape, operand_shape, max_shape) for 
rng_factory in [jtu.rand_default])) def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims, rng_factory): rng = rng_factory(self.rng()) raise SkipTest("batching rule for clamp not implemented") # TODO(mattj) shapes = [min_shape, operand_shape, max_shape] self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "bdims": bdims, "rng_factory": rng_factory} for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes for rng_factory in [jtu.rand_default])) def testDot(self, lhs_shape, rhs_shape, dtype, bdims, rng_factory): rng = rng_factory(self.rng()) op = partial(lax.dot, precision=lax.Precision.HIGHEST) self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng, rtol={np.float16: 5e-2, np.float64: 5e-14}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), lhs_contracting, rhs_contracting, bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting, "bdims": bdims, "rng_factory": rng_factory} for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ [(5,), (5,), [0], [0]], [(5, 7), (5,), [0], [0]], [(7, 5), (5,), [1], [0]], [(3, 5), (2, 5), [1], [1]], [(5, 3), (5, 2), [0], [0]], [(5, 3, 2), (5, 2, 4), [0], [0]], [(5, 3, 2), (5, 2, 4), [0,2], [0,1]], [(5, 3, 2), (3, 5, 2, 4), [0,2], [1,2]], [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]], [(3, 2), (2, 4), [1], [0]], ] for bdims in 
all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes for rng_factory in [jtu.rand_small])) def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting, bdims, rng_factory): rng = rng_factory(self.rng()) dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), dimension_numbers, bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "dimension_numbers": dimension_numbers, "bdims": bdims, "rng_factory": rng_factory} for lhs_shape, rhs_shape, dimension_numbers in [ ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))), ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))), ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))), ] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes for rng_factory in [jtu.rand_small])) def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype, dimension_numbers, bdims, rng_factory): rng = rng_factory(self.rng()) dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng) # Checks that batching didn't introduce any transposes or broadcasts. 
jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype), np.zeros(rhs_shape, dtype)) for eqn in jtu.iter_eqns(jaxpr.jaxpr): self.assertFalse(eqn.primitive in ["transpose", "broadcast"]) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format( shape, np.dtype(dtype).name, broadcast_sizes, bdims), "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes, "bdims": bdims, "rng_factory": rng_factory} for shape in [(), (2, 3)] for dtype in default_dtypes for broadcast_sizes in [(), (2,), (1, 2)] for bdims in all_bdims(shape) for rng_factory in [jtu.rand_default])) def testBroadcast(self, shape, dtype, broadcast_sizes, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.broadcast(x, broadcast_sizes) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format( jtu.format_shape_dtype_string(inshape, dtype), outshape, broadcast_dimensions, bdims), "inshape": inshape, "dtype": dtype, "outshape": outshape, "dimensions": broadcast_dimensions, "bdims": bdims, "rng_factory": rng_factory} for inshape, outshape, broadcast_dimensions in [ ([2], [2, 2], [0]), ([2], [2, 2], [1]), ([2], [2, 3], [0]), ([], [2, 3], []), ] for dtype in default_dtypes for bdims in all_bdims(inshape) for rng_factory in [jtu.rand_default])) def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims, rng_factory): rng = rng_factory(self.rng()) raise SkipTest("this test has failures in some cases") # TODO(mattjj) op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) self._CheckBatching(op, 5, bdims, (inshape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_dimensions={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, np.float32), dimensions, bdims), "arg_shape": arg_shape, "dimensions": dimensions, "bdims": 
bdims, "rng_factory": rng_factory} for arg_shape, dimensions in [ [(1,), (0,)], [(1,), (-1,)], [(2, 1, 4), (1,)], [(2, 1, 4), (-2,)], [(2, 1, 3, 1), (1,)], [(2, 1, 3, 1), (1, 3)], [(2, 1, 3, 1), (3,)], [(2, 1, 3, 1), (1, -1)], ] for bdims in all_bdims(arg_shape) for rng_factory in [jtu.rand_default])) def testSqueeze(self, arg_shape, dimensions, bdims, rng_factory): dtype = np.float32 rng = rng_factory(self.rng()) op = lambda x: lax.squeeze(x, dimensions) self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), jtu.format_shape_dtype_string(out_shape, dtype), dimensions, bdims), "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, "dimensions": dimensions, "bdims": bdims, "rng_factory": rng_factory} for dtype in default_dtypes for arg_shape, dimensions, out_shape in [ [(3, 4), None, (12,)], [(2, 1, 4), None, (8,)], [(2, 2, 4), None, (2, 8)], [(2, 2, 4), (0, 1, 2), (2, 8)], [(2, 2, 4), (1, 0, 2), (8, 2)], [(2, 2, 4), (2, 1, 0), (4, 2, 2)] ] for bdims in all_bdims(arg_shape) for rng_factory in [jtu.rand_default])) def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions) self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_pads={}_bdims={}" .format(jtu.format_shape_dtype_string(shape, dtype), pads, bdims), "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small, "bdims": bdims} for shape in [(2, 3)] for bdims in all_bdims(shape) for dtype in default_dtypes for pads in [[(1, 2, 1), (0, 1, 0)]])) def testPad(self, shape, dtype, pads, bdims, rng_factory): rng = rng_factory(self.rng()) fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads) 
self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_predshape={}_argshapes={}_bdims={}".format( jtu.format_shape_dtype_string(pred_shape, np.bool_), jtu.format_shape_dtype_string(arg_shape, arg_dtype), bdims), "pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype, "bdims": bdims, "rng_factory": rng_factory} for arg_shape in [(), (3,), (2, 3)] for pred_shape in ([(), arg_shape] if arg_shape else [()]) for bdims in all_bdims(pred_shape, arg_shape, arg_shape) for arg_dtype in default_dtypes for rng_factory in [jtu.rand_default])) def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda c, x, y: lax.select(c < 0, x, y) self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,), (np.bool_, arg_dtype, arg_dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), start_indices, limit_indices, strides, bdims), "shape": shape, "dtype": dtype, "starts": start_indices, "limits": limit_indices, "strides": strides, "bdims": bdims, "rng_factory": rng_factory} for shape, start_indices, limit_indices, strides in [ [(3,), (1,), (2,), None], [(7,), (4,), (7,), None], [(5,), (1,), (5,), (2,)], [(8,), (1,), (6,), (2,)], [(5, 3), (1, 1), (3, 2), None], [(5, 3), (1, 1), (3, 1), None], [(7, 5, 3), (4, 0, 1), (7, 1, 3), None], [(5, 3), (1, 1), (2, 1), (1, 1)], [(5, 3), (1, 1), (5, 3), (2, 1)], ] for bdims in all_bdims(shape) for dtype in default_dtypes for rng_factory in [jtu.rand_default])) def testSlice(self, shape, dtype, starts, limits, strides, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.slice(x, starts, limits, strides) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( 
{"testcase_name": "_shape={}_perm={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), perm, bdims), "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims, "rng_factory": rng_factory} for shape, perm in [ [(3, 4), (1, 0)], [(3, 4), (0, 1)], [(3, 4, 5), (2, 1, 0)], [(3, 4, 5), (1, 0, 2)], ] for bdims in all_bdims(shape) for dtype in default_dtypes for rng_factory in [jtu.rand_default])) def testTranspose(self, shape, dtype, perm, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.transpose(x, perm) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, init_val, bdims), "op": op, "init_val": init_val, "shape": shape, "dtype": dtype, "dims": dims, "bdims": bdims, "rng_factory": rng_factory} for init_val, op, dtypes in [ (0, lax.add, default_dtypes), (1, lax.mul, default_dtypes), (0, lax.max, all_dtypes), # non-monoidal (-np.inf, lax.max, float_dtypes), (dtypes.iinfo(np.int32).min, lax.max, [np.int32]), (dtypes.iinfo(np.int64).min, lax.max, [np.int64]), (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]), (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]), (np.inf, lax.min, float_dtypes), (dtypes.iinfo(np.int32).max, lax.min, [np.int32]), (dtypes.iinfo(np.int64).max, lax.min, [np.int64]), (dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]), (dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]), ] for dtype in dtypes for shape, dims in [ [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)], [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)] ] for bdims in all_bdims(shape) for rng_factory in [jtu.rand_small])) def testReduce(self, op, init_val, shape, dtype, dims, bdims, rng_factory): rng = rng_factory(self.rng()) init_val = np.asarray(init_val, dtype=dtype) fun = lambda operand: lax.reduce(operand, init_val, op, dims) self._CheckBatching(fun, 5, bdims, 
(shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim, bdims), "op": op, "shape": shape, "dtype": dtype, "dim": dim, "bdims": bdims} for op in [lax.argmin, lax.argmax] for dtype in default_dtypes for shape in [(3, 4, 5)] for dim in range(len(shape)) for bdims in all_bdims(shape))) def testArgminmax(self, op, shape, dtype, dim, bdims): rng = jtu.rand_default(self.rng()) fun = lambda operand: op(operand, dim, np.int32) self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}" "_basedilation={}_windowdilation={}") .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, strides, padding, base_dilation, window_dilation), "op": op, "init_val": init_val, "dtype": dtype, "shape": shape, "dims": dims, "strides": strides, "padding": padding, "base_dilation": base_dilation, "window_dilation": window_dilation} for init_val, op, dtypes in [ (0, lax.add, [np.float32]), (-np.inf, lax.max, [np.float32]), (np.inf, lax.min, [np.float32]), ] for shape, dims, strides, padding, base_dilation, window_dilation in ( itertools.chain( itertools.product( [(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1), (1, 2)], ["VALID", "SAME", [(0, 3), (1, 2)]], [(1, 1), (2, 3)], [(1, 1), (1, 2)]), itertools.product( [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], [(1, 2, 2, 1), (1, 1, 1, 1)], ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]], [(1, 1, 1, 1), (2, 1, 3, 2)], [(1, 1, 1, 1), (1, 2, 2, 1)]))) for dtype in dtypes)) def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding, base_dilation, window_dilation): rng = jtu.rand_small(self.rng()) init_val = np.asarray(init_val, dtype=dtype) def fun(operand): return lax.reduce_window(operand, init_val, op, dims, strides, padding, base_dilation, 
window_dilation) for bdims in all_bdims(shape): self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_shape={}_axis={}_bdims={}_reverse={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis, bdims, reverse), "op": op, "shape": shape, "dtype": dtype, "bdims": bdims, "axis": axis, "reverse": reverse} for op, types in [ (lax.cumsum, [np.float32, np.float64]), (lax.cumprod, [np.float32, np.float64]), ] for dtype in types for shape in [[10], [3, 4, 5]] for axis in range(len(shape)) for bdims in all_bdims(shape) for reverse in [False, True])) def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse): rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer) else jtu.rand_small) rng = rng_factory(self.rng()) self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name, padding), "dtype": dtype, "padding": padding, "rng_factory": rng_factory} for dtype in float_dtypes for padding in ["VALID", "SAME"] for rng_factory in [jtu.rand_small])) @jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.ignore_warning(message="Using reduced precision for gradient.*") def testSelectAndGatherAdd(self, dtype, padding, rng_factory): if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16: raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu") rng = rng_factory(self.rng()) all_configs = itertools.chain( itertools.product( [(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1), (1, 2)]), itertools.product( [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], [(1, 2, 2, 1), (1, 1, 1, 1)])) def fun(operand, tangents): pads = lax.padtype_to_pads(operand.shape, dims, strides, padding) ones = (1,) * len(operand.shape) return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims, strides, pads, ones, 
ones) for shape, dims, strides in all_configs: for bdims in all_bdims(shape, shape): self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": f"_dtype={jtu.format_shape_dtype_string(shape, dtype)}" f"_padding={padding}_dims={dims}_strides={strides}", "dtype": dtype, "padding": padding, "shape": shape, "dims": dims, "strides": strides} for dtype in float_dtypes for padding in ["VALID", "SAME"] for shape in [(3, 2, 4, 6)] for dims in [(1, 1, 2, 1)] for strides in [(1, 2, 2, 1), (1, 1, 1, 1)])) def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides): rng = jtu.rand_small(self.rng()) pads = lax.padtype_to_pads(shape, dims, strides, padding) def fun(operand, cotangents): return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims, strides, pads) ones = (1,) * len(shape) cotangent_shape = api.eval_shape( lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides, pads, ones, ones), np.ones(shape, dtype)).shape for bdims in all_bdims(cotangent_shape, shape): self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape), (dtype, dtype), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_bdims={}_fft_ndims={}" .format(shape, bdims, fft_ndims), "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims, "rng_factory": rng_factory} for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)] for bdims in all_bdims(shape) for fft_ndims in range(0, min(3, len(shape)) + 1) for rng_factory in [jtu.rand_default])) @jtu.skip_on_devices("tpu") # TODO(b/137993701): unimplemented cases. 
def testFft(self, fft_ndims, shape, bdims, rng_factory): rng = rng_factory(self.rng()) ndims = len(shape) axes = range(ndims - fft_ndims, ndims) fft_lengths = [shape[axis] for axis in axes] op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths) self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}" .format(jtu.format_shape_dtype_string(shape, dtype), idxs, dnums, slice_sizes, bdims), "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "bdims": bdims} for dtype in all_dtypes for shape, idxs, dnums, slice_sizes in [ ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,)), ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,)), ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3)), ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), ] for bdims in all_bdims(shape, idxs.shape))) def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype], jtu.rand_default(self.rng())) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums, bdims), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "bdims": bdims} for dtype in float_dtypes for arg_shape, idxs, update_shape, dnums in [ ((5,), np.array([[0], [2]]), (2,), 
lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ] for bdims in all_bdims(arg_shape, idxs.shape, update_shape))) def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims): fun = partial(lax.scatter_add, dimension_numbers=dnums) self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape], [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()), rtol={np.float16: 5e-3}) def testShapeUsesBuiltinInt(self): x = lax.iota(np.int32, 3) + 1 self.assertIsInstance(x.shape[0], int) # not np.int64 def testBroadcastShapesReturnsPythonInts(self): shape1, shape2 = (1, 2, 3), (2, 3) out_shape = lax.broadcast_shapes(shape1, shape2) self.assertTrue(all(type(s) is int for s in out_shape)) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_k={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), k, bdims), "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory} for shape in [(4,), (3, 5, 3)] for k in [1, 3] for bdims in all_bdims(shape) # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed: # The top_k indices for integer arrays with identical entries won't match between # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes. # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of # values a bfloat16 can represent exactly to avoid ties. 
for dtype, rng_factory in itertools.chain( unsafe_zip(default_dtypes, itertools.repeat(jtu.rand_unique_int))))) def testTopK(self, shape, dtype, k, bdims, rng_factory): rng = rng_factory(self.rng()) # _CheckBatching doesn't work with tuple outputs, so test outputs separately. op1 = lambda x: lax.top_k(x, k=k)[0] self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng) op2 = lambda x: lax.top_k(x, k=k)[1] self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_dimension={}_arity={}_bdims={}_isstable={}" .format(jtu.format_shape_dtype_string(shape, np.float32), dimension, arity, bdims, is_stable), "shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims, "is_stable": is_stable} for shape in [(2, 3)] for dimension in [0, 1] for arity in range(3) for bdims in all_bdims(*((shape,) * arity)) for is_stable in [False, True])) def testSort(self, shape, dimension, arity, bdims, is_stable): rng = jtu.rand_default(self.rng()) if arity == 1: fun = partial(lax.sort, dimension=dimension) self._CheckBatching(fun, 5, bdims, (shape,) * arity, (np.float32,) * arity, rng) else: for i in range(arity): fun = lambda *args, i=i: lax.sort(args, dimension=dimension, is_stable=is_stable)[i] self._CheckBatching(fun, 5, bdims, (shape,) * arity, (np.float32,) * arity, rng)
class IndexingTest(jtu.JaxTestCase):
    """Tests for Numpy indexing translation rules."""

    # Static indexing: the indexer is closed over by the function under test
    # rather than being passed in as an argument.
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "{}_inshape={}_indexer={}".format(
                group, jtu.format_shape_dtype_string(shape, dtype), indexer),
            "shape": shape,
            "dtype": dtype,
            "rng_factory": rng_factory,
            "indexer": indexer,
        } for group, specs in STATIC_INDEXING_TESTS
          for shape, indexer in specs
          for dtype in all_dtypes
          for rng_factory in [jtu.rand_default]))
    def testStaticIndexing(self, shape, dtype, rng_factory, indexer):
        sample = rng_factory()

        def make_args():
            return [sample(shape, dtype)]

        def apply_index(x):
            return x[indexer]

        self._CompileAndCheck(apply_index, make_args, check_dtypes=True)

    # Gradient check (up to second order) of the squared static-index result.
    @parameterized.named_parameters({
        "testcase_name": "{}_inshape={}_indexer={}".format(
            group, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape": shape,
        "dtype": dtype,
        "rng_factory": rng_factory,
        "indexer": indexer,
    } for group, specs in STATIC_INDEXING_GRAD_TESTS
      for shape, indexer in specs
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default])
    def testStaticIndexingGrads(self, shape, dtype, rng_factory, indexer):
        sample = rng_factory()
        # 32-bit floats get an explicit tolerance; other widths pass None.
        if lnp.finfo(dtype).bits == 32:
            tol = 1e-2
        else:
            tol = None
        operand = sample(shape, dtype)

        def squared_index(x):
            return x[indexer]**2

        check_grads(squared_index, (operand, ), 2, tol, tol, tol)

    def _ReplaceSlicesWithTuples(self, idx):
        """Split an indexer into slice-free values and a function that rebuilds it.

        Returns a pair ``(values, pack)``:
          * for a ``slice``, ``values`` is the (start, stop, step) triple passed
            through ``util.subvals`` with 0 at every position that was None, and
            ``pack`` turns such a triple back into a slice with None put back at
            those positions;
          * for a non-empty tuple or list, each element is processed recursively
            and ``pack`` rebuilds a container of the original type;
          * anything else is returned unchanged together with an identity
            function.
        """
        if isinstance(idx, slice):
            fields = (idx.start, idx.stop, idx.step)
            none_positions = [
                pos for pos, field in enumerate(fields) if field is None
            ]
            filled = util.subvals(
                fields, zip(none_positions, itertools.repeat(0)))

            def to_slice(values):
                restored = util.subvals(
                    values, zip(none_positions, itertools.repeat(None)))
                return slice(*restored)

            return filled, to_slice

        if isinstance(idx, (tuple, list)) and idx:
            container = type(idx)
            parts, rebuilders = zip(*map(self._ReplaceSlicesWithTuples, idx))

            def to_container(values):
                return container(
                    rebuild(value)
                    for rebuild, value in zip(rebuilders, values))

            return parts, to_container

        return idx, lambda x: x

    # Slices rebuilt from arguments of a jitted function are expected to raise
    # IndexError.
    @parameterized.named_parameters({
        "testcase_name": "{}_inshape={}_indexer={}".format(
            group, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape": shape,
        "dtype": dtype,
        "rng_factory": rng_factory,
        "indexer": indexer,
    } for group, specs in [
        ("OneSliceIndex", [
            IndexSpec(shape=(5, ), indexer=slice(1, 3)),
            IndexSpec(shape=(5, 4), indexer=slice(1, 3)),
        ]),
        ("TwoSliceIndices", [
            IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2))),
        ]),
        ("NonUnitStrides", [
            IndexSpec(shape=(3, ), indexer=slice(None, None, -1)),
            IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
            IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)),
        ]),
        ("OnlyStartOrStopDynamic", [
            IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))),
        ]),
    ] for shape, indexer in specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
    def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng_factory,
                                            indexer):
        sample = rng_factory()
        flat_indexer, rebuild = self._ReplaceSlicesWithTuples(indexer)

        @api.jit
        def fun(x, unpacked_indexer):
            return x[rebuild(unpacked_indexer)]

        def make_args():
            return [sample(shape, dtype), flat_indexer]

        with self.assertRaises(IndexError):
            fun(*make_args())

    # Integer indices passed as function arguments.
    @parameterized.named_parameters({
        "testcase_name": "{}_inshape={}_indexer={}".format(
            group, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape": shape,
        "dtype": dtype,
        "rng_factory": rng_factory,
        "indexer": indexer,
    } for group, specs in [
        ("OneIntIndex", [
            IndexSpec(shape=(3, ), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3, ), indexer=-1),
            IndexSpec(shape=(3, ), indexer=-2),
        ]),
        ("TwoIntIndices", [
            IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)),
        ]),
        ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
    ] for shape, indexer in specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
    def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory,
                                        indexer):
        sample = rng_factory()
        flat_indexer, rebuild = self._ReplaceSlicesWithTuples(indexer)

        def fun(x, unpacked_indexer):
            return x[rebuild(unpacked_indexer)]

        def make_args():
            return [sample(shape, dtype), flat_indexer]

        self._CompileAndCheck(fun, make_args, check_dtypes=True)

    # Gradient check for integer indices passed as arguments to a jitted
    # function; the indexer is bound with `partial`, so only the array is
    # differentiated.
    @parameterized.named_parameters({
        "testcase_name": "{}_inshape={}_indexer={}".format(
            group, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape": shape,
        "dtype": dtype,
        "rng_factory": rng_factory,
        "indexer": indexer,
    } for group, specs in [
        ("OneIntIndex", [
            IndexSpec(shape=(3, ), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3, ), indexer=-1),
            IndexSpec(shape=(3, ), indexer=-2),
        ]),
        ("TwoIntIndices", [
            IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)),
        ]),
        ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
    ] for shape, indexer in specs
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default])
    def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng_factory,
                                             indexer):
        sample = rng_factory()
        if lnp.finfo(dtype).bits == 32:
            tol = 1e-2
        else:
            tol = None
        flat_indexer, rebuild = self._ReplaceSlicesWithTuples(indexer)

        @api.jit
        def fun(unpacked_indexer, x):
            return x[rebuild(unpacked_indexer)]

        operand = sample(shape, dtype)
        check_grads(partial(fun, flat_indexer), (operand, ), 2, tol, tol, tol)

    # Advanced (integer-array) indexers passed as function arguments.
    @parameterized.named_parameters({
        "testcase_name": "{}_inshape={}_indexer={}".format(
            group, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape": shape,
        "dtype": dtype,
        "rng_factory": rng_factory,
        "indexer": indexer,
    } for group, specs in ADVANCED_INDEXING_TESTS
      for shape, indexer in specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
    def testAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
        sample = rng_factory()

        def make_args():
            return [sample(shape, dtype), indexer]

        def apply_index(x, idx):
            return x[idx]

        self._CompileAndCheck(apply_index, make_args, check_dtypes=True)

    # Gradient check of the squared result of advanced integer indexing, with
    # the indexer closed over.
    @parameterized.named_parameters({
        "testcase_name": "{}_inshape={}_indexer={}".format(
            group, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape": shape,
        "dtype": dtype,
        "rng_factory": rng_factory,
        "indexer": indexer,
    } for group, specs in [
        ("One1DIntArrayIndex", [
            IndexSpec(shape=(3, ), indexer=onp.array([0, 1])),
            IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])),
            IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])),
            IndexSpec(shape=(3, ), indexer=onp.array([-1, 1])),
            IndexSpec(shape=(3, ), indexer=onp.array([-2, -1])),
        ]),
        ("One2DIntArrayIndex", [
            IndexSpec(shape=(3, ), indexer=onp.array([[0, 0]])),
            IndexSpec(shape=(3, 3),
                      indexer=onp.array([[1, 2, 1], [0, 1, -1]])),
            IndexSpec(shape=(3, 4, 5),
                      indexer=onp.array([[0, 2, 0, 1], [-1, -2, 1, 0]])),
        ]),
        ("Two1DIntArrayIndicesNoBroadcasting", [
            IndexSpec(shape=(3, 3),
                      indexer=[onp.array([0, 1]), onp.array([1, 2])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[onp.array([0, 2, 0, 1]),
                               onp.array([-1, 0, -1, 2])]),
        ]),
        ("Two1DIntArrayIndicesWithBroadcasting", [
            IndexSpec(shape=(3, 3),
                      indexer=[onp.array([[0, 1]]), onp.array([1, 2])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[onp.array([[0, 2, 0, 1]]),
                               onp.array([-1, 0, -1, 2])]),
        ]),
        ("ListOfPythonInts", [
            IndexSpec(shape=(3, ), indexer=[0, 1, 0]),
            IndexSpec(shape=(3, 4, 5), indexer=[0, -1]),
        ]),
        ("ListOfListsOfPythonInts", [
            IndexSpec(shape=(3, 4, 5), indexer=[[0, 1]]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[[[0], [-1]], [[2, 3, 0, 3]]]),
        ]),
        ("ListOfPythonIntsAndIntArrays", [
            IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[0, 1, onp.array([[2, 3, 0, 3]])]),
        ]),
        ("ListOfListsOfPythonIntsAndIntArrays", [
            IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[[[0], [-1]], onp.array([[2, 3, 0, 3]])]),
        ]),
    ] for shape, indexer in specs
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default])
    def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng_factory,
                                         indexer):
        sample = rng_factory()
        if lnp.finfo(dtype).bits == 32:
            tol = 1e-2
        else:
            tol = None
        operand = sample(shape, dtype)

        def squared_index(x):
            return x[indexer]**2

        check_grads(squared_index, (operand, ), 2, tol, tol, tol)

    # Indexers mixing ndarrays with other entries: only the ndarrays are passed
    # as arguments; every other entry is replaced by a () placeholder and put
    # back inside the function under test.
    @parameterized.named_parameters({
        "testcase_name": "{}_inshape={}_indexer={}".format(
            group, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape": shape,
        "dtype": dtype,
        "rng_factory": rng_factory,
        "indexer": indexer,
    } for group, specs in MIXED_ADVANCED_INDEXING_TESTS
      for shape, indexer in specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
    def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng_factory,
                                         indexer):
        sample = rng_factory()
        placeholders = []
        for entry in indexer:
            if isinstance(entry, onp.ndarray):
                placeholders.append(entry)
            else:
                placeholders.append(())
        substitutes = [(pos, entry) for pos, entry in enumerate(indexer)
                       if not isinstance(entry, onp.ndarray)]

        def make_args():
            return [sample(shape, dtype), placeholders]

        def fun(x, indexer_with_dummies):
            restored = util.subvals(indexer_with_dummies, substitutes)
            return x[type(indexer)(restored)]

        self._CompileAndCheck(fun, make_args, check_dtypes=True)

    def testAdvancedIndexingManually(self):
        # Each indexing pattern is run directly and through api.jit on the same
        # inputs, and the two results are compared.
        data = onp.random.RandomState(0).randn(3, 4, 5)
        positions = onp.array([0, 2, -1, 0])

        patterns = (
            lambda x, index_array: x[..., index_array, :],
            lambda x, index_array: x[..., index_array, :, index_array, None],
            lambda x, index_array: x[index_array, ...,
                                     index_array[:, None], None],
        )
        for op in patterns:
            cop = api.jit(op)
            direct = op(data, positions)
            jitted = cop(data, positions)
            self.assertAllClose(direct, jitted, check_dtypes=True)

    def testUnpacking(self):
        # Tuple-unpacking a length-3 array must agree with and without api.jit.
        def foo(x):
            first, second, third = x
            return first + second + third

        jitted_foo = api.jit(foo)

        direct = foo(onp.arange(3))
        jitted = jitted_foo(onp.arange(3))
        self.assertAllClose(direct, jitted, check_dtypes=True)

    def testBooleanIndexingArray1D(self):
        mask = onp.array([True, True, False])
        device_values = api.device_put(onp.arange(3))
        actual = device_values[mask]
        reference = onp.arange(3)[mask]
        self.assertAllClose(actual, reference, check_dtypes=False)

    def testBooleanIndexingList1D(self):
        mask = [True, True, False]
        device_values = api.device_put(onp.arange(3))
        actual = device_values[mask]
        reference = onp.arange(3)[mask]
        self.assertAllClose(actual, reference, check_dtypes=False)

    def testBooleanIndexingArray2DBroadcast(self):
        mask = onp.array([True, True, False, True])
        host_values = onp.arange(8).reshape(4, 2)
        actual = api.device_put(host_values)[mask]
        reference = host_values[mask]
        self.assertAllClose(actual, reference, check_dtypes=False)

    def testBooleanIndexingList2DBroadcast(self):
        mask = [True, True, False, True]
        host_values = onp.arange(8).reshape(4, 2)
        actual = api.device_put(host_values)[mask]
        reference = host_values[mask]
        self.assertAllClose(actual, reference, check_dtypes=False)

    def testBooleanIndexingArray2D(self):
        mask = onp.array([[True, False],
                          [False, True],
                          [False, False],
                          [True, True]])
        host_values = onp.arange(8).reshape(4, 2)
        actual = api.device_put(host_values)[mask]
        reference = host_values[mask]
        self.assertAllClose(actual, reference, check_dtypes=False)

    def testBooleanIndexingDynamicShapeError(self):
        # A boolean mask passed as an argument to a jitted function is expected
        # to raise IndexError.
        values = onp.zeros(3)
        mask = onp.array([True, True, False])
        with self.assertRaises(IndexError):
            api.jit(lambda x, i: x[i])(values, mask)

    def testIssue187(self):
        ones = lnp.ones((5, 5))
        ones[[0, 2, 4], [0, 2, 4]]  # doesn't crash

        grid = onp.arange(25).reshape((5, 5))
        actual = api.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(grid)
        reference = grid[[0, 2, 4], [0, 2, 4]]
        self.assertAllClose(actual, reference, check_dtypes=False)

    def testJVPOfGradOfIndexing(self):
        # Should return a value, even though we didn't pass a symbolic zero as
        # the index tangent.
        x = lnp.ones((3, 4), lnp.float32)
        i = lnp.ones((3, ), lnp.int32)

        def summed_rows(x, i):
            return lnp.sum(x[i])

        primals, tangents = api.jvp(
            api.grad(summed_rows), (x, i), (x, onp.zeros_like(i)))
        row_counts = onp.array([0, 3, 0], dtype=onp.float32)
        expected = onp.broadcast_to(row_counts[:, None], (3, 4))
        self.assertAllClose(expected, primals, check_dtypes=True)
        self.assertAllClose(onp.zeros_like(x), tangents, check_dtypes=True)

    def testTrivialGatherIsntGenerated(self):
        # https://github.com/google/jax/issues/1621
        traced = api.make_jaxpr(lambda x: x[:, None])(onp.arange(4))
        self.assertEqual(len(traced.jaxpr.eqns), 1)
        self.assertNotIn('gather', str(traced))

    def testBooleanIndexingWithEmptyResult(self):
        # based on a TensorFlow Probability test that started failing after
        # #1622
        values = lnp.array([-1])
        mask = lnp.array([False])
        actual = values[mask]  # doesn't crash

        reference = onp.array([-1])[onp.array([False])]
        self.assertAllClose(actual, reference, check_dtypes=False)

    def testFloatIndexingError(self):
        # A float index is expected to raise TypeError.
        values = lnp.array([1, 2, 3])
        with self.assertRaises(TypeError):
            values[3.5]
class CustomObjectTest(jtu.JaxTestCase):

    # Passes a sparse array through an identity function (either `identity` or
    # a plain lambda, optionally jitted) and checks that dtype, index dtype,
    # data and indices are preserved.
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_compile={}_primitive={}".format(
                use_jit, use_primitive),
            "compile": use_jit,
            "primitive": use_primitive,
        } for use_primitive in [True, False]
          for use_jit in [True, False]))
    def testSparseIdentity(self, compile, primitive):
        if primitive:
            f = identity
        else:
            f = lambda x: x
        if compile:
            f = jit(f)
        rng = jtu.rand_default(self.rng())
        sparse_in = make_sparse_array(rng, (10, ), jnp.float32)
        sparse_out = f(sparse_in)

        # The traced program must pass the jaxpr checker as well.
        traced = make_jaxpr(f)(sparse_in).jaxpr
        core.check_jaxpr(traced)

        self.assertEqual(sparse_in.dtype, sparse_out.dtype)
        self.assertEqual(sparse_in.index_dtype, sparse_out.index_dtype)
        self.assertAllClose(sparse_in.data, sparse_out.data)
        self.assertAllClose(sparse_in.indices, sparse_out.indices)

    # `split` returns two outputs; each is compared with the input.
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_compile={}".format(use_jit),
            "compile": use_jit,
        } for use_jit in [True, False]))
    def testSparseSplit(self, compile):
        if compile:
            f = jit(split)
        else:
            f = split
        rng = jtu.rand_default(self.rng())
        sparse_in = make_sparse_array(rng, (10, ), jnp.float32)
        first, second = f(sparse_in)

        traced = make_jaxpr(f)(sparse_in).jaxpr
        core.check_jaxpr(traced)

        for piece in (first, second):
            self.assertEqual(sparse_in.dtype, piece.dtype)
            self.assertEqual(sparse_in.index_dtype, piece.index_dtype)
            self.assertArraysEqual(sparse_in.data, piece.data)
            self.assertArraysEqual(sparse_in.indices, piece.indices)

    # Uses a sparse array as the carried value of lax.fori_loop; the test only
    # requires that the loop runs without raising.
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_compile={}_primitive={}".format(
                use_jit, use_primitive),
            "compile": use_jit,
            "primitive": use_primitive,
        } for use_primitive in [True, False]
          for use_jit in [True, False]))
    def testSparseLaxLoop(self, compile, primitive):
        rng = jtu.rand_default(self.rng())
        if primitive:
            f = identity
        else:
            f = lambda x: x
        if compile:
            f = jit(f)

        def body_fun(_, carry):
            return f(carry)

        sparse_in = make_sparse_array(rng, (10, ), jnp.float32)
        lax.fori_loop(0, 10, body_fun, sparse_in)

    # Reads the named attribute ("data" or "indices") off a sparse array inside
    # the function under test.
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_attr={}".format(attr),
            "attr": attr,
        } for attr in ["data", "indices"]))
    def testSparseAttrAccess(self, attr):
        rng = jtu.rand_default(self.rng())

        def make_args():
            return [make_sparse_array(rng, (10, ), jnp.float32)]

        def read_attr(x):
            return getattr(x, attr)

        self._CompileAndCheck(read_attr, make_args)

    # Applies `matvec` to a sparse array of the given shape and a vector whose
    # length is the array's last dimension.
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_{}".format(
                jtu.format_shape_dtype_string(shape, dtype)),
            "shape": shape,
            "dtype": dtype,
        } for shape in [(3, 3), (2, 6), (6, 2)]
          for dtype in jtu.dtypes.floating))
    def testSparseMatvec(self, shape, dtype):
        rng = jtu.rand_default(self.rng())

        def make_args():
            sparse_operand = make_sparse_array(rng, shape, dtype)
            vector = rng(shape[-1:], dtype)
            return [sparse_operand, vector]

        self._CompileAndCheck(matvec, make_args)

    def testLowerToNothing(self):
        empty = Empty(AbstractEmpty())
        traced = make_jaxpr(jit(lambda e: e))(empty).jaxpr
        core.check_jaxpr(traced)

        # cannot return a unit, because CompileAndCheck assumes array output.
        def testfunc(e):
            return None

        def make_args():
            return [empty]

        self._CompileAndCheck(testfunc, make_args)
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations"""

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testPoissonLogPmf(self, rng_factory, shapes, dtypes):
    """Compares poisson.logpmf with scipy on sampled (k, mu, loc)."""
    sample = rng_factory()
    reference = osp_stats.poisson.logpmf
    candidate = lsp_stats.poisson.logpmf

    def make_args():
      k, mu, loc = (sample(s, d) for s, d in zip(shapes, dtypes))
      return [
          onp.floor(k),
          # lower clip keeps the rate parameter strictly positive
          onp.clip(onp.abs(mu), a_min=0.1, a_max=None),
          onp.floor(loc),
      ]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testPoissonPmf(self, rng_factory, shapes, dtypes):
    """Compares poisson.pmf with scipy on sampled (k, mu, loc)."""
    sample = rng_factory()
    reference = osp_stats.poisson.pmf
    candidate = lsp_stats.poisson.pmf

    def make_args():
      k, mu, loc = (sample(s, d) for s, d in zip(shapes, dtypes))
      return [
          onp.floor(k),
          # lower clip keeps the rate parameter strictly positive
          onp.clip(onp.abs(mu), a_min=0.1, a_max=None),
          onp.floor(loc),
      ]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testBernoulliLogPmf(self, rng_factory, shapes, dtypes):
    """Compares bernoulli.logpmf with scipy; p is expit of a sampled logit."""
    sample = rng_factory()
    reference = osp_stats.bernoulli.logpmf
    candidate = lsp_stats.bernoulli.logpmf

    def make_args():
      x, logit, loc = (sample(s, d) for s, d in zip(shapes, dtypes))
      return [onp.floor(x), expit(logit), onp.floor(loc)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(5, jtu.rand_positive)
  def testBetaLogPdf(self, rng_factory, shapes, dtypes):
    """Compares beta.logpdf with scipy on sampled (x, a, b, loc, scale)."""
    sample = rng_factory()
    reference = osp_stats.beta.logpdf
    candidate = lsp_stats.beta.logpdf

    def make_args():
      x, a, b, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      return [x, a, b, loc, scale]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True, rtol=1e-4)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testCauchyLogPdf(self, rng_factory, shapes, dtypes):
    """Compares cauchy.logpdf with scipy on sampled (x, loc, scale)."""
    sample = rng_factory()
    reference = osp_stats.cauchy.logpdf
    candidate = lsp_stats.cauchy.logpdf

    def make_args():
      x, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      # the magnitude of scale is clipped so that it is not too low
      return [x, loc, onp.clip(onp.abs(scale), a_min=0.1, a_max=None)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(2, jtu.rand_positive)
  def testDirichletLogPdf(self, rng_factory, shapes, dtypes):
    """Compares a logpdf with scipy on (x, alpha) with a trailing axis of 4."""
    sample = rng_factory()
    # NOTE(review): despite the test name, both functions here are the Cauchy
    # logpdf, not the Dirichlet one. This looks like a copy-paste slip; confirm
    # before switching, since scipy's dirichlet.logpdf may not accept the
    # batched shapes generated for this test.
    reference = osp_stats.cauchy.logpdf
    candidate = lsp_stats.cauchy.logpdf
    trailing = 4
    shapes = (shapes[0] + (trailing,), shapes[1] + (trailing,))

    def make_args():
      x, alpha = (sample(s, d) for s, d in zip(shapes, dtypes))
      # rescale x so that it sums to one along the last axis
      return [x / onp.sum(x, axis=-1, keepdims=True), alpha]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive)
  def testExponLogPdf(self, rng_factory, shapes, dtypes):
    """Compares expon.logpdf with scipy on sampled (x, loc, scale)."""
    sample = rng_factory()
    reference = osp_stats.expon.logpdf
    candidate = lsp_stats.expon.logpdf

    def make_args():
      x, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      return [x, loc, scale]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_positive)
  def testGammaLogPdf(self, rng_factory, shapes, dtypes):
    """Compares gamma.logpdf with scipy on sampled (x, a, loc, scale)."""
    sample = rng_factory()
    reference = osp_stats.gamma.logpdf
    candidate = lsp_stats.gamma.logpdf

    def make_args():
      x, a, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      return [x, a, loc, scale]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=5e-4)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_positive)
  def testLaplaceLogPdf(self, rng_factory, shapes, dtypes):
    """Compares laplace.logpdf with scipy on sampled (x, loc, scale)."""
    sample = rng_factory()
    reference = osp_stats.laplace.logpdf
    candidate = lsp_stats.laplace.logpdf

    def make_args():
      x, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      # scale is clipped from below so that it is not too low
      return [x, loc, onp.clip(scale, a_min=0.1, a_max=None)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testLaplaceCdf(self, rng_factory, shapes, dtypes):
    """Compares laplace.cdf with scipy on sampled (x, loc, scale)."""
    sample = rng_factory()
    reference = osp_stats.laplace.cdf
    candidate = lsp_stats.laplace.cdf

    def make_args():
      x, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      # scale is clipped from below so that it is not too low
      return [x, loc, onp.clip(scale, a_min=0.1, a_max=None)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-6)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(1, jtu.rand_default)
  def testLogisticCdf(self, rng_factory, shapes, dtypes):
    """Compares logistic.cdf with scipy on a single sampled argument."""
    sample = rng_factory()
    reference = osp_stats.logistic.cdf
    candidate = lsp_stats.logistic.cdf

    def make_args():
      return [sample(s, d) for s, d in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-6)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(1, jtu.rand_default)
  def testLogisticLogpdf(self, rng_factory, shapes, dtypes):
    """Compares logistic.logpdf with scipy on a single sampled argument."""
    sample = rng_factory()
    reference = osp_stats.logistic.logpdf
    candidate = lsp_stats.logistic.logpdf

    def make_args():
      return [sample(s, d) for s, d in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(1, jtu.rand_default)
  def testLogisticPpf(self, rng_factory, shapes, dtypes):
    """Compares logistic.ppf with scipy on a single sampled argument."""
    sample = rng_factory()
    reference = osp_stats.logistic.ppf
    candidate = lsp_stats.logistic.ppf

    def make_args():
      return [sample(s, d) for s, d in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(1, jtu.rand_default)
  def testLogisticSf(self, rng_factory, shapes, dtypes):
    """Compares logistic.sf with scipy on a single sampled argument."""
    sample = rng_factory()
    reference = osp_stats.logistic.sf
    candidate = lsp_stats.logistic.sf

    def make_args():
      return [sample(s, d) for s, d in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-6)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testNormLogPdf(self, rng_factory, shapes, dtypes):
    """Compares norm.logpdf with scipy on sampled (x, loc, scale)."""
    sample = rng_factory()
    reference = osp_stats.norm.logpdf
    candidate = lsp_stats.norm.logpdf

    def make_args():
      x, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      # the magnitude of scale is clipped so that it is not too low
      return [x, loc, onp.clip(onp.abs(scale), a_min=0.1, a_max=None)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testNormLogCdf(self, rng_factory, shapes, dtypes):
    """Compares norm.logcdf with scipy on sampled (x, loc, scale)."""
    sample = rng_factory()
    reference = osp_stats.norm.logcdf
    candidate = lsp_stats.norm.logcdf

    def make_args():
      x, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      # the magnitude of scale is clipped so that it is not too low
      return [x, loc, onp.clip(onp.abs(scale), a_min=0.1, a_max=None)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testNormCdf(self, rng_factory, shapes, dtypes):
    """Compares norm.cdf with scipy on sampled (x, loc, scale)."""
    sample = rng_factory()
    reference = osp_stats.norm.cdf
    candidate = lsp_stats.norm.cdf

    def make_args():
      x, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      # the magnitude of scale is clipped so that it is not too low
      return [x, loc, onp.clip(onp.abs(scale), a_min=0.1, a_max=None)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-6)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testNormPpf(self, rng_factory, shapes, dtypes):
    """Compares norm.ppf with scipy on sampled (q, loc, scale), dtypes included."""
    sample = rng_factory()
    reference = osp_stats.norm.ppf
    candidate = lsp_stats.norm.ppf

    def make_args():
      q, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      return [
          # |q / 3| capped at 1 keeps the probability between 0 and 1
          onp.clip(onp.abs(q / 3), a_min=None, a_max=1),
          loc,
          # the magnitude of scale is clipped so that it is not too low
          onp.clip(onp.abs(scale), a_min=0.1, a_max=None),
      ]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=True, tol=1e-4)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True, rtol=1e-5)

  @genNamedParametersNArgs(4, jtu.rand_positive)
  def testParetoLogPdf(self, rng_factory, shapes, dtypes):
    """Compares pareto.logpdf with scipy on sampled (x, b, loc, scale)."""
    sample = rng_factory()
    reference = osp_stats.pareto.logpdf
    candidate = lsp_stats.pareto.logpdf

    def make_args():
      x, b, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      return [x, b, loc, scale]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(4, jtu.rand_default)
  def testTLogPdf(self, rng_factory, shapes, dtypes):
    """Compares t.logpdf with scipy on sampled (x, df, loc, scale)."""
    sample = rng_factory()
    reference = osp_stats.t.logpdf
    candidate = lsp_stats.t.logpdf

    def make_args():
      x, df, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      # the magnitude of scale is clipped so that it is not too low
      return [x, df, loc, onp.clip(onp.abs(scale), a_min=0.1, a_max=None)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  @genNamedParametersNArgs(3, jtu.rand_default)
  def testUniformLogPdf(self, rng_factory, shapes, dtypes):
    """Compares uniform.logpdf with scipy; scale is passed as its magnitude."""
    sample = rng_factory()
    reference = osp_stats.uniform.logpdf
    candidate = lsp_stats.uniform.logpdf

    def make_args():
      x, loc, scale = (sample(s, d) for s, d in zip(shapes, dtypes))
      return [x, loc, onp.abs(scale)]

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True)

  def testIssue972(self):
    """norm.cdf evaluated at +inf must be close to one."""
    expected = onp.ones((4,), onp.float32)
    at_infinity = onp.full((4,), onp.inf, onp.float32)
    self.assertAllClose(expected, lsp_stats.norm.cdf(at_infinity),
                        check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_x={}_mean={}_cov={}".format(
          jtu.format_shape_dtype_string(x_shape, x_dtype),
          jtu.format_shape_dtype_string(mean_shape, mean_dtype)
          if mean_shape is not None else None,
          jtu.format_shape_dtype_string(cov_shape, cov_dtype)
          if cov_shape is not None else None),
       "x_shape": x_shape, "x_dtype": x_dtype,
       "mean_shape": mean_shape, "mean_dtype": mean_dtype,
       "cov_shape": cov_shape, "cov_dtype": cov_dtype,
       "rng_factory": rng_factory}
      for x_shape, mean_shape, cov_shape in [
          # # These test cases cover default values for mean/cov, but we don't
          # # support those yet (and they seem not very valuable).
          # [(), None, None],
          # [(), (), None],
          # [(2,), None, None],
          # [(2,), (), None],
          # [(2,), (2,), None],
          # [(3, 2), (3, 2,), None],
          # [(5, 3, 2), (5, 3, 2,), None],

          [(), (), ()],
          [(3,), (), ()],
          [(3,), (3,), ()],
          [(3,), (3,), (3, 3)],
          [(3, 4), (4,), (4, 4)],

          # # These test cases are where scipy flattens things, which has
          # # different batch semantics than some might expect
          # [(5, 3, 2), (5, 3, 2,), ()],
          # [(5, 3, 2), (5, 3, 2,), (5, 3, 2, 2)],
          # [(5, 3, 2), (3, 2,), (5, 3, 2, 2)],
          # [(5, 3, 2), (3, 2,), (2, 2)],
      ]
      for x_dtype, mean_dtype, cov_dtype in CombosWithReplacement(float_dtypes, 3)
      if (mean_shape is not None or mean_dtype == onp.float32)
      and (cov_shape is not None or cov_dtype == onp.float32)
      for rng_factory in [jtu.rand_default]))
  def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
                                   mean_dtype, cov_shape, cov_dtype,
                                   rng_factory):
    """Compares multivariate_normal.logpdf with scipy on (x[, mean[, cov]])."""
    sample = rng_factory()
    reference = osp_stats.multivariate_normal.logpdf
    candidate = lsp_stats.multivariate_normal.logpdf

    def make_args():
      inputs = [sample(x_shape, x_dtype)]
      if mean_shape is not None:
        inputs.append(5 * sample(mean_shape, mean_dtype))
      if cov_shape is None:
        return inputs
      if cov_shape == ():
        # Scalar covariance: a squared sample shifted up by 0.1.
        inputs.append(0.1 + sample(cov_shape, cov_dtype) ** 2)
        return inputs
      # Matrix covariance: factor @ factor^T, where the factor has twice as
      # many columns as rows.
      *leading, last = cov_shape
      factor = sample((*leading, 2 * last), cov_dtype)
      inputs.append(onp.matmul(factor, onp.swapaxes(factor, -1, -2)))
      return inputs

    self._CheckAgainstNumpy(reference, candidate, make_args,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(candidate, make_args, check_dtypes=True,
                          rtol=1e-4, atol=1e-4)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    """Returns a zero-argument callable producing one `rng` sample per shape/dtype pair."""
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_inshape={}_axis={}_keepdims={}".format(
              jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
          "rng": jtu.rand_default(),
          "shape": shape,
          "dtype": dtype,
          "axis": axis,
          "keepdims": keepdims
      } for shape in all_shapes for dtype in float_dtypes
                          for axis in range(-len(shape), len(shape))
                          for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    """Compares logsumexp along `axis` with the scipy implementation."""
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
      return lsp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name":
                  jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
              "rng": rec.rng,
              "shapes": shapes,
              "dtypes": dtypes,
              "modes": rec.diff_modes,
              "scipy_op": getattr(osp_special, rec.name),
              "lax_op": getattr(lsp_special, rec.name)
          }
          for rec in JAX_SPECIAL_FUNCTION_RECORDS
          for shapes in CombosWithReplacement(all_shapes, rec.nargs)
          for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes, modes):
    """Compares one scipy.special function with its LAX-backed counterpart.

    The combination of a TPU device under test and a scalar first argument is
    skipped.
    """
    # TODO(mattjj): unskip this test combination when real() on tpu is improved
    # TODO(mattjj): test autodiff
    if (FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu") and
        not shapes[0]):
      # `unittest.skip(...)` only builds a decorator; returning it would report
      # the case as passed. skipTest raises SkipTest so it is recorded as
      # skipped.
      self.skipTest("real() on scalar not supported on tpu")

    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)
class LaxAutodiffTest(jtu.JaxTestCase):
  """Numerical gradient checks (check_grads) for lax operations."""

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix(
          rec.name, shapes, itertools.repeat(dtype)),
       "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes,
       "dtype": dtype, "order": rec.order, "tol": rec.tol}
      for shape_group in compatible_shapes
      for shapes in itertools.combinations_with_replacement(shape_group,
                                                            rec.nargs)
      for dtype in rec.dtypes)
    for rec in LAX_GRAD_OPS))
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    """Checks gradients of `op` up to `order` in forward and reverse mode."""
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.pow:
      raise SkipTest("pow grad imprecise on tpu")
    # For 32-bit floats the record's tolerance is joined with 1e-1.
    tol = jtu.join_tolerance(1e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol
    args = tuple(rng(shape, dtype) for shape in shapes)
    check_grads(op, args, order, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
       "op": rec.op, "special_value": special_value, "tol": rec.tol}
      for special_value in rec.values)
    for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
  def testOpGradSpecialValue(self, op, special_value, tol):
    """Checks second-order gradients of `op` at one special input value."""
    check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}".format(
          jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
       "from_dtype": from_dtype, "to_dtype": to_dtype,
       "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(inexact_dtypes, repeat=2)
      for rng_factory in [jtu.rand_default]))
  def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory):
    """Checks gradients of lax.convert_element_type between inexact dtypes."""
    rng = rng_factory(self.rng())
    # Use the looser of the two dtypes' default gradient tolerances.
    tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
              jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
    args = (rng((2, 3), from_dtype),)
    convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
    # ComplexWarning raised by the conversion is suppressed.
    convert_element_type = jtu.ignore_warning(category=np.ComplexWarning)(
      convert_element_type)
    check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in grad_float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testClampGrad(self, shape, dtype, rng_factory):
    """Checks lax.clamp gradients for three orderings of operand, low and high."""
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    low = operand - dtype(10)
    high = operand + dtype(10)
    # Avoids points near the boundary where the gradient may be inaccurate.
    check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
          dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name,
          num_arrs),
       "dim": dim, "base_shape": base_shape, "dtype": dtype,
       "num_arrs": num_arrs, "rng_factory": rng_factory}
      for num_arrs in [3]
      for dtype in float_dtypes
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))
      for rng_factory in [jtu.rand_default]))
  def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory):
    """Checks gradients of lax.concatenate along `dim`."""
    rng = rng_factory(self.rng())
    # Operand sizes along `dim` cycle through 3, 1, 4.
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    operands = tuple(rng(shape, dtype) for shape in shapes)
    concatenate = lambda *args: lax.concatenate(args, dim)
    check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rng_factory": rng_factory,}
      for lhs_shape, rhs_shape, all_strides in itertools.chain(
          [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
           for b, i, j in itertools.product([2, 3], repeat=3)],
          [((4, 2, 1), (3, 2, 1), [(1,)])])
      for strides in all_strides
      for dtype in float_dtypes
      for padding in ["VALID", "SAME"]
      for rng_factory in [jtu.rand_small]))
  def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding,
                   rng_factory):
    """Checks gradients of lax.conv with string padding."""
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv, window_strides=strides, padding=padding,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory}
      for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in
      itertools.chain(
          [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
            [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
            [(1, 1), (2, 1)], [(1, 1)])
           for b, i, j in itertools.product([2, 3], repeat=3)],
          [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)],
            [(1,), (2,)], [(1,), (2,)])])
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in float_dtypes
      for padding in all_pads
      for rng_factory in [jtu.rand_small]))
  def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                     padding, lhs_dil, rhs_dil, rng_factory):
    """Checks gradients of lax.conv_with_general_padding with dilations."""
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_with_general_padding, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, batch_group_count),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory,
       "dimension_numbers": dim_nums, "perms": perms,
       "feature_group_count": feature_group_count,
       "batch_group_count": batch_group_count}
      for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)])
      for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
          ([(b * batch_group_count, i * feature_group_count, 6, 7),
            (b * batch_group_count, i * feature_group_count, 0, 4)],  # lhs_shape
           (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],  # strides
           [(1, 1), (2, 1)],  # lhs_dils
           [(1, 1), (2, 2)])  # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for lhs_shape in lhs_shapes
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in grad_inexact_dtypes
      for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] +
                      ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
      for rng_factory in [jtu.rand_default]
  ))
  def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                 padding, lhs_dil, rhs_dil, dimension_numbers,
                                 perms, feature_group_count, batch_group_count,
                                 rng_factory):
    """Checks gradients of lax.conv_general_dilated; float16 is skipped."""
    if dtype == np.float16:
      raise SkipTest("float16 numerical issues")  # TODO(mattjj): resolve

    rng = rng_factory(self.rng())
    tol = {dtypes.bfloat16: 1e-0, np.float16: 5e-1, np.float32: 1e-3}

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(np.take(lhs_shape, lhs_perm))
    rhs_shape = list(np.take(rhs_shape, rhs_perm))

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng_factory": jtu.rand_default}
      for lhs_shape in [(2,), (3, 2)]
      for rhs_shape in [(2,), (2, 4)]
      for dtype in float_dtypes))
  def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
    """Checks lax.dot gradients and that HIGHEST precision reaches the pullback."""
    rng = rng_factory(self.rng())
    tol = {np.float16: 1e-1, np.float32: 1e-4}
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)
    # check that precision config is preserved
    result, pullback = api.vjp(dot, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    # assertIn is not stripped under `python -O` and reports the jaxpr text on
    # failure, unlike a bare assert.
    self.assertIn("precision=HIGHEST", s)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "rng_factory": jtu.rand_small}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 2), (2, 4), (([1], [0]), ([], []))),
          ((3, 5), (2, 5), (([1], [1]), ([], []))),
          ((5, 3), (5, 2), (([0], [0]), ([], []))),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
          ((3, 5, 2), (2, 4, 5), (([2], [0]), ([1], [2]))),
          ((7, 3, 5, 2), (2, 2, 4, 5), (([3], [0]), ([2], [3]))),
      ]
      for dtype in float_dtypes))
  def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers, rng_factory):
    """Checks lax.dot_general gradients and that precision reaches the pullback."""
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                          precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
    # check that precision config is preserved
    result, pullback = api.vjp(dot_general, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    # assertIn is not stripped under `python -O` and reports the jaxpr text on
    # failure, unlike a bare assert.
    self.assertIn("precision=HIGHEST", s)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
          shape, np.dtype(dtype).name, broadcast_sizes),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in float_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for rng_factory in [jtu.rand_default]))
  def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory):
    """Checks gradients of lax.broadcast."""
    rng = rng_factory(self.rng())
    args = (rng(shape, dtype),)
    broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
    check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "rng_factory": rng_factory}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions,
                             rng_factory):
    """Checks gradients of lax.broadcast_in_dim."""
    rng = rng_factory(self.rng())
    operand = rng(inshape, dtype)
    broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_perm={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          permutation),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng_factory": rng_factory, "permutation": permutation}
      for dtype in float_dtypes
      for arg_shape, out_shape, permutation in [
          [(3, 4), (12,), None],
          [(2, 1, 4), (8,), None],
          [(2, 2, 4), (2, 8), None],
          [(3, 4), (12,), (0, 1)],
          [(3, 4), (12,), (1, 0)],
          [(2, 1, 4), (8,), (0, 2, 1)],
          [(2, 1, 4), (8,), (2, 0, 1)],
          [(2, 2, 4), (2, 8), (0, 2, 1)],
          [(2, 2, 4), (2, 8), (2, 0, 1)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype,
                      rng_factory):
    """Checks gradients of lax.reshape, with and without a permutation."""
    rng = rng_factory(self.rng())
    operand = rng(arg_shape, dtype)
    reshape = lambda x: lax.reshape(x, out_shape, permutation)
    check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads),
       "shape": shape, "dtype": dtype, "pads": pads,
       "rng_factory": jtu.rand_small}
      for shape in [(2, 3)]
      for dtype in float_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]))
  def testPadGrad(self, shape, dtype, pads, rng_factory):
    """Checks lax.pad gradients w.r.t. the operand, then operand and padding value."""
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    pad = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
    check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

    operand = rng(shape, dtype)
    padding_value = np.array(0., dtype)
    pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
    check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)

  def testReverseGrad(self):
    """Checks gradients of lax.rev over one and then two dimensions."""
    # `dimensions` is looked up when `rev` is called, so rebinding it below
    # changes which axes are reversed.
    rev = lambda operand: lax.rev(operand, dimensions)

    dimensions = [0]
    check_grads(rev, (np.array([3., 2., 1.]),), 2)

    dimensions = [0, 1]
    check_grads(rev, (np.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
                rtol={np.float32: 3e-3})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}".format(
          jtu.format_shape_dtype_string(pred_shape, np.bool_),
          jtu.format_shape_dtype_string(arg_shape, dtype)),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory):
    """Checks lax.select gradients w.r.t. both branches for a fixed predicate."""
    rng = rng_factory(self.rng())
    pred = rng(pred_shape, np.bool_)
    on_true = rng(arg_shape, dtype)
    on_false = rng(arg_shape, dtype)
    select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
    check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           start_indices, limit_indices, strides),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "rng_factory": rng_factory}
      for shape, start_indices, limit_indices, strides in [
          [(3,), (1,), (2,), None],
          [(7,), (4,), (7,), None],
          [(5,), (1,), (5,), (2,)],
          [(8,), (1,), (6,), (2,)],
          [(5, 3), (1, 1), (3, 2), None],
          [(5, 3), (1, 1), (3, 1), None],
          [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
          [(5, 3), (1, 1), (2, 1), (1, 1)],
          [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory):
    """Checks gradients of lax.slice."""
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    # Named `slice_op` so the builtin `slice` is not shadowed.
    slice_op = lambda x: lax.slice(x, starts, limits, strides)
    check_grads(slice_op, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, size_indices),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "size_indices": size_indices, "rng_factory": rng_factory}
      for shape, start_indices, size_indices in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices,
                           rng_factory):
    """Checks gradients of lax.dynamic_slice w.r.t. the operand."""
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
    check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, update_shape),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "update_shape": update_shape, "rng_factory": rng_factory}
      for shape, start_indices, update_shape in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices,
                                 update_shape, rng_factory):
    """Checks lax.dynamic_update_slice gradients jointly and per argument."""
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    update = rng(update_shape, dtype)
    start_indices = np.array(start_indices)

    # Differentiate w.r.t. both operand and update.
    dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
    check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)

    # Differentiate w.r.t. the operand only.
    dus = lambda x: lax.dynamic_update_slice(x, update, start_indices)
    check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.)

    # Differentiate w.r.t. the update only.
    dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
    check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm),
       "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory}
      for shape, perm in [
          [(3, 4), (1, 0)],
          [(3, 4), (0, 1)],
          [(3, 4, 5), (2, 1, 0)],
          [(3, 4, 5), (1, 0, 2)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testTransposeGrad(self, shape, dtype, perm, rng_factory):
    """Checks gradients of lax.transpose."""
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    transpose = lambda x: lax.transpose(x, perm)
    check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_op={}_inshape={}_reducedims={}"
     .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
     "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
     "dims": dims, "rng_factory": rng_factory}
    for init_val, op, dtypes, rng_factory in [
        (0, lax.add, float_dtypes + jtu.dtypes.complex, jtu.rand_default),
        (-np.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
        (np.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
        (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
    ]
    for dtype in dtypes
    for shape, dims in [
        [(3, 4, 5), ()],
        [(3, 4, 5), (0,)],
        [(3, 4, 5), (1, 2)],
        [(3, 4, 5), (0, 2)],
        [(3, 4, 5), (0, 1, 2)],
        [(3, 1), (1,)],
        [(3, 0, 5), (1,)],
    ]))
def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
  # Checks derivatives of lax.reduce for add/max/min/mul over several
  # reduction-dimension choices, including a shape with a zero-sized axis.
  make_array = rng_factory(self.rng())
  if jtu.device_under_test() == "tpu" and op is lax.mul:
    raise SkipTest("unimplemented case")
  tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-1, np.float32: 1e-1,
         np.float64: 1e-3, np.complex64: 1e-1}
  x = make_array(shape, dtype)
  init_val = np.asarray(init_val, dtype=dtype)

  def reduce_fn(arr):
    return lax.reduce(arr, init_val, op, dims)

  # Finite-difference step: first matching rule wins, None otherwise.
  if dtypes.finfo(dtype).bits == 16 and op is lax.add:
    eps = 1.0
  elif dtype == dtypes.bfloat16:
    eps = 1e-1
  elif dtypes.finfo(dtype).bits == 32:
    eps = 1e-2
  else:
    eps = None

  # max/min are only checked when every axis is non-empty.
  if op in (lax.max, lax.min) and not all(d > 0 for d in shape):
    return
  check_grads(reduce_fn, (x,), 2, ["fwd", "rev"], tol, tol, eps)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_op={}_dtype={}_padding={}"
     .format(op.__name__, np.dtype(dtype).name, padding),
     "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
     "rng_factory": rng_factory}
    for init_val, op, dtypes, rng_factory in [
        (0, lax.add, grad_float_dtypes, jtu.rand_small),
        (-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
        (np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
    ]
    for dtype in dtypes
    for padding in ["VALID", "SAME"]))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.ignore_warning(category=UserWarning,
                    message="Using reduced precision for gradient.*")
def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
  # Checks derivatives of lax.reduce_window over a grid of
  # (shape, window_dimensions, strides) configurations.
  make_array = rng_factory(self.rng())
  init_val = np.asarray(init_val, dtype=dtype)

  # We need this conditional and the corresponding loop logic to be in the
  # test method, rather than at the parameterized test level, because it
  # depends on FLAGS for the device under test.
  # TODO(b/31565929): enable when fixed.
  if jtu.device_under_test() == "tpu" and op is not lax.add:
    all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]

    # TODO(b/73062247): need variadic reduce-window for better precision.
    gradient_order = 1
  else:
    configs_2d = itertools.product(
        [(4, 6)],  # shapes
        [(2, 1), (1, 2)],  # window_dimensions
        [(1, 1), (2, 1), (1, 2)]  # strides
    )
    configs_4d = itertools.product(
        [(3, 2, 4, 6)],  # shapes
        [(1, 1, 2, 1), (2, 1, 2, 1)],  # window_dimensions
        [(1, 2, 2, 1), (1, 1, 1, 1)]  # strides
    )
    all_configs = itertools.chain(configs_2d, configs_4d)
    gradient_order = 3

  for shape, dims, strides in all_configs:
    x = make_array(shape, dtype)
    if op is not lax.add:
      # this test can fail if there are duplicates in operand
      self.assertEqual(np.unique(x).size, x.size,
                       msg="test requires operand elements to be unique.")
      eps = 1e-2
      tol = {np.float16: 1e-1, np.float32: 6e-2, np.float64: 6e-2}
    else:
      eps = 1.
      tol = None

    def windowed_reduce(arr):
      return lax.reduce_window(arr, init_val, op, dims, strides, padding)

    check_grads(windowed_reduce, (x,), gradient_order, ["fwd", "rev"],
                tol, tol, eps)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_op={}_shape={}_axis={}"
     .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
     "op": op, "shape": shape, "dtype": dtype,
     "axis": axis, "rng_factory": rng_factory}
    for op, types in [
        (lax.cumsum, [np.float32, np.float64]),
        (lax.cumprod, [np.float32, np.float64]),
    ]
    for dtype in types
    for shape in [[10], [3, 4, 5]]
    for axis in range(len(shape))
    for rng_factory in [
        jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
        else jtu.rand_small]))
def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
  # Checks second-order derivatives of cumsum/cumprod along each axis.
  make_array = rng_factory(self.rng())
  cumulative = partial(op, axis=axis)
  check_grads(cumulative, (make_array(shape, dtype),), order=2)

# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_axis={}_isstable={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
     "rng_factory": rng_factory, "shape": shape, "dtype": dtype,
     "axis": axis, "is_stable": is_stable}
    for dtype in [np.float32]
    for shape in [(5,), (5, 7)]
    for axis in [len(shape) - 1]
    for is_stable in [False, True]
    for rng_factory in [jtu.rand_default]))
def testSortGrad(self, shape, dtype, axis, is_stable, rng_factory):
  # Checks derivatives of lax.sort along the last axis.
  make_array = rng_factory(self.rng())
  x = make_array(shape, dtype)

  def sort_fn(arr):
    return lax.sort(arr, dimension=axis, is_stable=is_stable)

  check_grads(sort_fn, (x,), 2, ["fwd", "rev"], eps=1e-2)

# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
        jtu.format_shape_dtype_string(shape, key_dtype),
        jtu.format_shape_dtype_string(shape, val_dtype),
        axis, is_stable),
     "rng_factory": rng_factory, "shape": shape,
     "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis,
     "is_stable": is_stable}
    for key_dtype in [np.float32]
    for val_dtype in [np.float32]
    for shape in [(3,), (5, 3)]
    for axis in [len(shape) - 1]
    for is_stable in [False, True]
    for rng_factory in [jtu.rand_default]))
def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable,
                       rng_factory):
  # Checks derivatives of lax.sort_key_val with respect to keys and values.
  make_array = rng_factory(self.rng())
  # This test relies on the property that wherever keys are tied, values are
  # too, since we don't guarantee the same ordering of values with equal keys.
  # To avoid that case, we generate unique keys (globally in the key array).
  distinct_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
  keys = self.rng().permutation(distinct_keys).reshape(shape)
  values = make_array(shape, val_dtype)

  def sort_pairs(ks, vs):
    return lax.sort_key_val(ks, vs, axis, is_stable)

  check_grads(sort_pairs, (keys, values), 2, ["fwd", "rev"],
              1e-2, 1e-2, 1e-2)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_k={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k),
     "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
    for dtype in [np.float32,]
    for shape in [(4,), (5, 5), (2, 1, 4)]
    for k in [1, 3]
    for rng_factory in [jtu.rand_default]))
def testTopKGrad(self, shape, dtype, k, rng_factory):
  # Checks derivatives of the values output of lax.top_k. The input is a
  # permutation of distinct integers, so there are no ties.
  distinct = np.arange(np.prod(shape, dtype=int), dtype=dtype)
  values = self.rng().permutation(distinct).reshape(shape)

  def top_values(vs):
    return lax.top_k(vs, k=k)[0]

  check_grads(top_values, (values,), 2, ["fwd", "rev"], eps=1e-2)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_axes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
     "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes,
     "rng_factory": rng_factory}
    for dtype in float_dtypes
    for shape, idxs, axes in [
        [(3, 4, 5), (np.array([0, 2, 1]),), (0,)],
        [(3, 4, 5), (np.array([-1, -2]),), (0,)],
        [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)],
        [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 2)],
    ]
    for rng_factory in [jtu.rand_default]))
def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory):
  # Checks derivatives of lax.index_take with respect to the source array.
  make_array = rng_factory(self.rng())
  source = make_array(shape, dtype)

  def take(arr):
    return lax.index_take(arr, idxs, axes)

  check_grads(take, (source,), 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        idxs, dnums, slice_sizes),
     "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
     "slice_sizes": slice_sizes, "rng_factory": rng_factory,
     "rng_idx_factory": rng_idx_factory}
    for dtype in grad_float_dtypes
    for shape, idxs, dnums, slice_sizes, max_idx in [
        ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,),
            start_index_map=(0,)),
         (1,), 5),
        ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(),
            start_index_map=(0,)),
         (2,), 9),
        ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,),
            start_index_map=(0,)),
         (1, 3), 3),
    ]
    for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]
    for rng_factory in [jtu.rand_default]))
def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
                   rng_idx_factory):
  # Checks derivatives of lax.gather. Only the shape and dtype of the
  # parameterized `idxs` are used; the index values are redrawn at random.
  make_array = rng_factory(self.rng())
  make_indices = rng_idx_factory(self.rng())
  idxs = make_indices(idxs.shape, idxs.dtype)

  def gather_fn(arr):
    return lax.gather(arr, idxs, dimension_numbers=dnums,
                      slice_sizes=slice_sizes)

  x = make_array(shape, dtype)
  check_grads(gather_fn, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
    for dtype in grad_float_dtypes
    for arg_shape, idxs, update_shape, dnums, max_idx in [
        ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
        ((10,), np.array([[0], [0], [0]]), (3, 2),
         lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(),
             scatter_dims_to_operand_dims=(0,)), 9),
        ((10, 5,), np.array([[0], [2], [1]]), (3, 3),
         lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(0,),
             scatter_dims_to_operand_dims=(0,)), 3),
    ]
    for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]
    for rng_factory in [jtu.rand_default]))
def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                       rng_factory, rng_idx_factory):
  # Checks derivatives of lax.scatter_add in both operand and updates. Only
  # the shape and dtype of the parameterized `idxs` are used; the index
  # values are redrawn at random.
  make_array = rng_factory(self.rng())
  make_indices = rng_idx_factory(self.rng())
  idxs = make_indices(idxs.shape, idxs.dtype)

  def scatter_add_fn(operand, updates):
    return lax.scatter_add(operand, idxs, updates, dimension_numbers=dnums)

  x = make_array(arg_shape, dtype)
  y = make_array(update_shape, dtype)
  check_grads(scatter_add_fn, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
    for dtype in grad_float_dtypes
    for arg_shape, idxs, update_shape, dnums, max_idx in [
        ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
        ((10,), np.array([[0], [0], [0]]), (3, 2),
         lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(),
             scatter_dims_to_operand_dims=(0,)), 9),
        ((10, 5,), np.array([[0], [2], [1]]), (3, 3),
         lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(0,),
             scatter_dims_to_operand_dims=(0,)), 3),
    ]
    # Scatters with conflicting indices are not deterministic on GPU, so we
    # use indices that do not collide.
    for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)]
    for rng_factory in [jtu.rand_default]))
def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                    rng_factory, rng_idx_factory):
  # Checks derivatives of lax.scatter in both operand and updates, using
  # freshly drawn, non-colliding indices of the parameterized shape/dtype.
  make_array = rng_factory(self.rng())
  make_indices = rng_idx_factory(self.rng())
  idxs = make_indices(idxs.shape, idxs.dtype)

  def scatter_fn(operand, updates):
    return lax.scatter(operand, idxs, updates, dimension_numbers=dnums)

  x = make_array(arg_shape, dtype)
  y = make_array(update_shape, dtype)
  check_grads(scatter_fn, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

def testScatterGradSymbolicZeroUpdate(self):
  # https://github.com/google/jax/issues/1901
  # The update written to the diagonal is a constant, so it carries no
  # tangent; differentiation must still succeed.
  def overwrite_diagonal(x):
    size = x.shape[0]
    diag_values = np.arange(size, dtype=x.dtype)
    return jax.ops.index_update(x, np.diag_indices(size), diag_values)

  make_array = jtu.rand_default(self.rng())
  check_grads(overwrite_diagonal, (make_array((5, 5), np.float32),), 2,
              ["fwd", "rev"], 1e-2, 1e-2, 1.)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
    for dtype in grad_float_dtypes
    for arg_shape, idxs, update_shape, dnums in [
        ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
        ((10,), np.array([[0], [0], [0]]), (3, 2),
         lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(),
             scatter_dims_to_operand_dims=(0,))),
        ((10, 5,), np.array([[0], [2], [1]]), (3, 3),
         lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(0,),
             scatter_dims_to_operand_dims=(0,))),
    ]
    for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
    for rng_factory in [jtu.rand_default]))
def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
                   rng_factory, rng_idx_factory):
  # Checks derivatives of lax.scatter_max in both operand and updates, with
  # randomly redrawn indices of the parameterized shape/dtype.
  make_array = rng_factory(self.rng())
  make_indices = rng_idx_factory(self.rng())
  idxs = make_indices(idxs.shape, idxs.dtype)

  def scatter_max_fn(operand, updates):
    return lax.scatter_max(operand, idxs, updates, dnums)

  x = make_array(arg_shape, dtype)
  y = make_array(update_shape, dtype)
  check_grads(scatter_max_fn, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
    for dtype in grad_float_dtypes
    for arg_shape, idxs, update_shape, dnums in [
        ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
        ((10,), np.array([[0], [0], [0]]), (3, 2),
         lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(),
             scatter_dims_to_operand_dims=(0,))),
        ((10, 5,), np.array([[0], [2], [1]]), (3, 3),
         lax.ScatterDimensionNumbers(
             update_window_dims=(1,), inserted_window_dims=(0,),
             scatter_dims_to_operand_dims=(0,))),
    ]
    for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
    for rng_factory in [jtu.rand_default]))
def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums,
                   rng_factory, rng_idx_factory):
  # Checks derivatives of lax.scatter_min in both operand and updates, with
  # randomly redrawn indices of the parameterized shape/dtype.
  make_array = rng_factory(self.rng())
  make_indices = rng_idx_factory(self.rng())
  idxs = make_indices(idxs.shape, idxs.dtype)

  def scatter_min_fn(operand, updates):
    return lax.scatter_min(operand, idxs, updates, dnums)

  x = make_array(arg_shape, dtype)
  y = make_array(update_shape, dtype)
  check_grads(scatter_min_fn, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

def testStopGradient(self):
  # stop_gradient(x) must behave like an independent argument held fixed.
  def with_stop(x):
    return lax.sin(x) * lax.cos(lax.stop_gradient(x))

  def two_arg_reference(x, y):
    return lax.sin(x) * lax.cos(y)

  point = 3.14

  # First derivative.
  actual = api.grad(with_stop)(point)
  wanted = api.grad(two_arg_reference)(point, point)
  self.assertAllClose(actual, wanted)

  # Second derivative.
  actual = api.grad(api.grad(with_stop))(point)
  wanted = api.grad(api.grad(two_arg_reference))(point, point)
  self.assertAllClose(actual, wanted)

  # stop_gradient applied to a dict container: the gradient is zero.
  actual = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
  wanted = np.array(0.0)
  self.assertAllClose(actual, wanted, check_dtypes=False)

  # A function is not a valid argument to stop_gradient.
  with core.skipping_checks():
    with self.assertRaises(TypeError):
      lax.stop_gradient(lambda x: x)

# TODO(mattjj): make this a more systematic test
def testRemainder(self):
  # Checks derivatives of lax.rem for two broadcasting layouts. The RNG is
  # re-seeded for each layout, exactly as in the straight-line version.
  for x_shape, y_shape in [((3, 4), (3, 1)), ((1, 4), (3, 4))]:
    rng = np.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=x_shape)
    y = rng.uniform(0.7, 1.9, size=y_shape)
    # No value may appear in both x and y.
    assert not set(np.unique(x)) & set(np.unique(y))
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

def testHigherOrderGradientOfReciprocal(self):
  # Regression test for https://github.com/google/jax/issues/3136

  def inv(x):
    # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
    return 1 / x

  # Sixth derivative of 1/x is 720 / x**7, i.e. 720 / 4**7 at x = 4.
  grad_fn = inv
  for _ in range(6):
    grad_fn = jax.grad(grad_fn)
  self.assertAllClose(np.float32(0.0439453125), grad_fn(np.float32(4.)))
class DLPackTest(jtu.JaxTestCase):
  # Round-trip tests for DLPack interchange: JAX <-> JAX and JAX <-> PyTorch.

  def setUp(self):
    # Run the base-class fixture first so that any state it prepares exists
    # even when the test is subsequently skipped.
    super().setUp()
    if jtu.device_under_test() == "tpu":
      self.skipTest("DLPack not supported on TPU")

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype
      } for shape in all_shapes for dtype in dlpack_dtypes))
  def testJaxRoundTrip(self, shape, dtype):
    # Export a JAX array to a DLPack capsule and import it back.
    rng = jtu.rand_default()
    # Named `host_array` rather than `np` so the numpy alias is not shadowed.
    host_array = rng(shape, dtype)
    x = jnp.array(host_array)
    dlpack = jax.dlpack.to_dlpack(x)
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(host_array.astype(x.dtype), y, check_dtypes=True)

    # A capsule may be consumed only once; a second import must fail.
    self.assertRaisesRegex(RuntimeError,
                           "DLPack tensor may be consumed at most once",
                           lambda: jax.dlpack.from_dlpack(dlpack))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype
      } for shape in all_shapes for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testTorchToJax(self, shape, dtype):
    # Export a PyTorch tensor to DLPack and import it into JAX.
    rng = jtu.rand_default()
    host_array = rng(shape, dtype)
    x = torch.from_numpy(host_array)
    # Move the tensor to the GPU when that is the device under test.
    x = x.cuda() if jtu.device_under_test() == "gpu" else x
    dlpack = torch.utils.dlpack.to_dlpack(x)
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(host_array, y, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
          "shape": shape,
          "dtype": dtype
      } for shape in all_shapes for dtype in torch_dtypes))
  @unittest.skipIf(not torch, "Test requires PyTorch")
  def testJaxToTorch(self, shape, dtype):
    # Export a JAX array to DLPack and import it into PyTorch.
    rng = jtu.rand_default()
    host_array = rng(shape, dtype)
    x = jnp.array(host_array)
    dlpack = jax.dlpack.to_dlpack(x)
    y = torch.utils.dlpack.from_dlpack(dlpack)
    self.assertAllClose(host_array, y.numpy(), check_dtypes=True)
class DoubleDoubleTest(jtu.JaxTestCase):
  # Tests for the `doubledouble` transformation and the `_DoubleDouble`
  # helper class.

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_{}_{}".format(
                  op.__name__, jtu.format_shape_dtype_string(shape, dtype)),
          "dtype": dtype,
          "shape": shape,
          "op": op
      } for dtype in (jnp.float16, jnp.float32, jnp.float64)
        for shape in ((), (5, ), (2, 3), (2, 3, 4))
        for op in (abs, operator.neg, operator.pos, jnp.sqrt)))
  def testUnaryOp(self, dtype, shape, op):
    # A doubled unary op must agree with the plain op.
    make_array = jtu.rand_default(self.rng())
    doubled = doubledouble(op)
    operands = (make_array(shape, dtype), )
    plain_result = op(*operands)
    self.assertAllClose(plain_result, doubled(*operands))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_{}_{}".format(
                  op.__name__, jtu.format_shape_dtype_string(shape, dtype)),
          "dtype": dtype,
          "shape": shape,
          "op": op
      } for dtype in (jnp.float16, jnp.float32, jnp.float64)
        for shape in ((), (5, ), (2, 3), (2, 3, 4))
        for op in (operator.add, operator.sub, operator.mul,
                   operator.truediv, operator.gt, operator.ge, operator.lt,
                   operator.le, operator.eq, operator.ne)))
  def testBinaryOp(self, dtype, shape, op):
    # A doubled binary op must agree with the plain op.
    make_array = jtu.rand_default(self.rng())
    doubled = doubledouble(op)
    operands = make_array(shape, dtype), make_array(shape, dtype)
    plain_result = op(*operands)
    self.assertAllClose(plain_result, doubled(*operands))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_{}_{}".format(
                  jtu.format_shape_dtype_string(shape, dtype), label),
          "shape": shape,
          "dtype": dtype,
          "op1": op1,
          "op2": op2
      } for dtype in (jnp.float32, jnp.float64)
        for shape in ((), (5, ), (2, 3), (2, 3, 4))
        for label, op1, op2 in [
            ("add_sub", lambda x, y: x + y - x, lambda x, y: y),
            ("add_neg_add", lambda x, y: -(x + y) + x, lambda x, y: -y),
            ("add_mul_sub", lambda x, y: 2 * (x + y) - 2 * x,
             lambda x, y: 2 * y),
            ("add_div_sub", lambda x, y: (x + y) / 2 - x / 2,
             lambda x, y: y / 2),
        ]))
  def testDoubledPrecision(self, shape, dtype, op1, op2):
    """Test operations that would lose precision without doubling."""
    make_array = jtu.rand_default(self.rng())
    doubled = doubledouble(op1)
    # The first operand is scaled by 1E20 relative to the second.
    big = 1E20 * make_array(shape, dtype)
    small = make_array(shape, dtype)
    check_dtypes = not FLAGS.jax_enable_x64

    self.assertAllClose(doubled(big, small), op2(big, small),
                        check_dtypes=check_dtypes)

    # Sanity check: make sure test fails for regular precision.
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      self.assertAllClose(op1(big, small), op2(big, small),
                          check_dtypes=check_dtypes)

  def testTypeConversion(self):
    # Compares a float16 add-then-cast with and without doubling.
    x = jnp.arange(10, dtype='float16')

    def add_then_cast(a, b):
      return (a + b).astype('float32')

    doubled = doubledouble(add_then_cast)
    self.assertAllClose(add_then_cast(1E2 * x, 1E-2 * x),
                        1E2 * x.astype('float32'))
    self.assertAllClose(doubled(1E2 * x, 1E-2 * x),
                        100.01 * x.astype('float32'))

  def testRepeatedDoubling(self):
    # Applying doubledouble once and then twice to the same expression.
    def f(x, y, z):
      return x + y + z - x - y

    once = doubledouble(f)
    twice = doubledouble(once)
    dtype = jnp.float32
    x = dtype(1E20)
    y = dtype(1.0)
    z = dtype(1E-20)
    self.assertEqual(f(x, y, z), -y)
    self.assertEqual(once(x, y, z), 0)
    self.assertEqual(twice(x, y, z), z)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_{}_{}".format(dtype, val),
          "dtype": dtype,
          "val": val
      } for dtype in ["float16", "float32", "float64"]
        for val in ["6.0221409e23", "3.14159265358", "0", 123456789]))
  def testClassInstantiation(self, dtype, val):
    # Building a _DoubleDouble from a string or an int and converting it back
    # with to_array() must match the scalar type's own constructor.
    scalar_type = jnp.dtype(dtype).type
    expected = scalar_type(val)
    self.assertEqual(expected, _DoubleDouble(val, scalar_type).to_array())

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_{}_{}".format(
                  jtu.format_shape_dtype_string(shape, dtype), op.__name__),
          "shape": shape,
          "dtype": dtype,
          "op": op
      } for dtype in (jnp.float32, jnp.float64)
        for shape in ((), (5, ), (2, 3), (2, 3, 4))
        for op in (operator.neg, operator.abs)))
  def testClassUnaryOp(self, dtype, shape, op):
    # Unary operators applied to _DoubleDouble must agree with the plain op.
    make_array = jtu.rand_default(self.rng())
    operands = (make_array(shape, dtype), )

    def via_class(x):
      return op(_DoubleDouble(x)).to_array()

    plain_result = op(*operands)
    self.assertAllClose(plain_result, via_class(*operands))

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_{}_{}".format(
                  jtu.format_shape_dtype_string(shape, dtype), op.__name__),
          "shape": shape,
          "dtype": dtype,
          "op": op
      } for dtype in (jnp.float32, jnp.float64)
        for shape in ((), (5, ), (2, 3), (2, 3, 4))
        for op in (operator.add, operator.sub, operator.mul,
                   operator.truediv, operator.gt, operator.ge, operator.lt,
                   operator.le, operator.eq, operator.ne)))
  def testClassBinaryOp(self, dtype, shape, op):
    # Binary operators applied to _DoubleDouble must agree with the plain op.
    make_array = jtu.rand_default(self.rng())
    operands = make_array(shape, dtype), make_array(shape, dtype)

    def via_class(x, y):
      outcome = op(_DoubleDouble(x), _DoubleDouble(y))
      # Only unwrap when the operator produced a _DoubleDouble.
      if not isinstance(outcome, _DoubleDouble):
        return outcome
      return outcome.to_array()

    plain_result = op(*operands)
    self.assertAllClose(plain_result, via_class(*operands))
class IndexingTest(jtu.JaxTestCase):
  """Tests for Numpy indexing translation rules."""

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "{}_inshape={}_indexer={}".format(
              name, jtu.format_shape_dtype_string(shape, dtype), indexer),
          "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
      } for name, index_specs in STATIC_INDEXING_TESTS
        for shape, indexer in index_specs
        for dtype in all_dtypes
        for rng in [jtu.rand_default()]))
  def testStaticIndexing(self, shape, dtype, rng, indexer):
    # The indexer is closed over, so it is a compile-time constant.
    def args_maker():
      return [rng(shape, dtype)]

    def index_with_constant(x):
      return x[indexer]

    self._CompileAndCheck(index_with_constant, args_maker, check_dtypes=True)

  @parameterized.named_parameters({
      "testcase_name": "{}_inshape={}_indexer={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer),
      "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
  } for name, index_specs in STATIC_INDEXING_GRAD_TESTS
    for shape, indexer in index_specs
    for dtype in float_dtypes
    for rng in [jtu.rand_default()])
  def testStaticIndexingGrads(self, shape, dtype, rng, indexer):
    # Second-order gradient check of x[indexer] ** 2 with a static indexer.
    tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
    operand = rng(shape, dtype)

    def squared_selection(x):
      return x[indexer]**2

    check_grads(squared_selection, (operand, ), 2, tol, tol, tol)

  def _ReplaceSlicesWithTuples(self, idx):
    """Helper method to replace slices with tuples for dynamic indexing args.

    Returns a pair `(unpacked, pack)`. `unpacked` is `idx` with each slice
    replaced by a `(start, stop, step)` triple in which `None` entries have
    been substituted by 0. `pack` rebuilds the original index structure
    (restoring the `None` entries) from a value shaped like `unpacked`.
    Non-empty tuples and lists are processed recursively; anything else is
    returned unchanged together with an identity `pack`.
    """
    if isinstance(idx, slice):
      triple = idx.start, idx.stop, idx.step
      none_positions = [pos for pos, part in enumerate(triple)
                        if part is None]
      unpacked = lax.subvals(triple, zip(none_positions, itertools.repeat(0)))

      def pack_slice(out):
        restored = lax.subvals(
            out, zip(none_positions, itertools.repeat(None)))
        return slice(*restored)

      return unpacked, pack_slice

    if isinstance(idx, (tuple, list)) and idx:
      container = type(idx)
      elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx))

      def pack_sequence(elts):
        # A generator is handed to the container, as before.
        return container((pack(i) for pack, i in zip(packs, elts)))

      return elts, pack_sequence

    return idx, lambda x: x

  @parameterized.named_parameters({
      "testcase_name": "{}_inshape={}_indexer={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer),
      "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
  } for name, index_specs in [
      ("OneSliceIndex", [
          IndexSpec(shape=(5, ), indexer=slice(1, 3)),
          IndexSpec(shape=(5, 4), indexer=slice(1, 3))
      ]),
      ("TwoSliceIndices", [
          IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
          IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))
      ]),
      ("NonUnitStrides", [
          IndexSpec(shape=(3, ), indexer=slice(None, None, -1)),
          IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
          IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
      ]),
      ("OnlyStartOrStopDynamic", [
          IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
          IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
      ]),
  ] for shape, indexer in index_specs
    for dtype in all_dtypes
    for rng in [jtu.rand_default()])
  def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng, indexer):
    # Slices whose bounds are arguments of a jitted function must raise
    # IndexError.
    unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

    @api.jit
    def fun(x, unpacked_indexer):
      rebuilt = pack_indexer(unpacked_indexer)
      return x[rebuilt]

    def args_maker():
      return [rng(shape, dtype), unpacked_indexer]

    self.assertRaises(IndexError, lambda: fun(*args_maker()))

  @parameterized.named_parameters({
      "testcase_name": "{}_inshape={}_indexer={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer),
      "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
  } for name, index_specs in [
      ("OneIntIndex", [
          IndexSpec(shape=(3, ), indexer=1),
          IndexSpec(shape=(3, 3), indexer=0),
          IndexSpec(shape=(3, 4, 5), indexer=2),
          IndexSpec(shape=(3, ), indexer=-1),
          IndexSpec(shape=(3, ), indexer=-2)
      ]),
      ("TwoIntIndices", [
          IndexSpec(shape=(3, 3), indexer=(2, 1)),
          IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
          IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))
      ]),
      ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
  ] for shape, indexer in index_specs
    for dtype in all_dtypes
    for rng in [jtu.rand_default()])
  def testDynamicIndexingWithIntegers(self, shape, dtype, rng, indexer):
    # Integer indices passed as function arguments rather than closed over.
    unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

    def fun(x, unpacked_indexer):
      rebuilt = pack_indexer(unpacked_indexer)
      return x[rebuilt]

    def args_maker():
      return [rng(shape, dtype), unpacked_indexer]

    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  # unittest.skip must be called with a reason; applied bare it would receive
  # the parameterized object as the "reason" and replace the method with the
  # inner decorator closure.
  @unittest.skip("test works but for grad-of-compile, in flux")
  @parameterized.named_parameters({
      "testcase_name": "{}_inshape={}_indexer={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer),
      "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
  } for name, index_specs in [
      ("OneIntIndex", [
          IndexSpec(shape=(3, ), indexer=1),
          IndexSpec(shape=(3, 3), indexer=0),
          IndexSpec(shape=(3, 4, 5), indexer=2),
          IndexSpec(shape=(3, ), indexer=-1),
          IndexSpec(shape=(3, ), indexer=-2),
      ]),
      ("TwoIntIndices", [
          IndexSpec(shape=(3, 3), indexer=(2, 1)),
          IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
          IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)),
      ]),
      ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
  ] for shape, indexer in index_specs
    for dtype in float_dtypes
    for rng in [jtu.rand_default()])
  def DISABLED_testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng,
                                                    indexer):
    # TODO(mattjj): re-enable (test works but for grad-of-compile, in flux)
    tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
    unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

    @api.jit
    def fun(unpacked_indexer, x):
      rebuilt = pack_indexer(unpacked_indexer)
      return x[rebuilt]

    operand = rng(shape, dtype)
    check_grads(partial(fun, unpacked_indexer), (operand, ), 2, tol, tol, tol)

  @parameterized.named_parameters({
      "testcase_name": "{}_inshape={}_indexer={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer),
      "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
  } for name, index_specs in ADVANCED_INDEXING_TESTS
    for shape, indexer in index_specs
    for dtype in all_dtypes
    for rng in [jtu.rand_default()])
  def testAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
    # The integer-array indexer is passed as a function argument.
    def args_maker():
      return [rng(shape, dtype), indexer]

    def index_with_argument(x, idx):
      return x[idx]

    self._CompileAndCheck(index_with_argument, args_maker, check_dtypes=True)

  @parameterized.named_parameters({
      "testcase_name": "{}_inshape={}_indexer={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer),
      "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
  } for name, index_specs in [
      ("One1DIntArrayIndex", [
          IndexSpec(shape=(3, ), indexer=onp.array([0, 1])),
          IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])),
          IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])),
          IndexSpec(shape=(3, ), indexer=onp.array([-1, 1])),
          IndexSpec(shape=(3, ), indexer=onp.array([-2, -1])),
      ]),
      ("One2DIntArrayIndex", [
          IndexSpec(shape=(3, ), indexer=onp.array([[0, 0]])),
          IndexSpec(shape=(3, 3),
                    indexer=onp.array([[1, 2, 1], [0, 1, -1]])),
          IndexSpec(shape=(3, 4, 5),
                    indexer=onp.array([[0, 2, 0, 1], [-1, -2, 1, 0]])),
      ]),
      ("Two1DIntArrayIndicesNoBroadcasting", [
          IndexSpec(shape=(3, 3),
                    indexer=[onp.array([0, 1]), onp.array([1, 2])]),
          IndexSpec(
              shape=(3, 4, 5),
              indexer=[onp.array([0, 2, 0, 1]), onp.array([-1, 0, -1, 2])]),
      ]),
      ("Two1DIntArrayIndicesWithBroadcasting", [
          IndexSpec(shape=(3, 3),
                    indexer=[onp.array([[0, 1]]), onp.array([1, 2])]),
          IndexSpec(shape=(3, 4, 5), indexer=[
              onp.array([[0, 2, 0, 1]]), onp.array([-1, 0, -1, 2])
          ]),
      ]),
      ("ListOfPythonInts", [
          IndexSpec(shape=(3, ), indexer=[0, 1, 0]),
          IndexSpec(shape=(3, 4, 5), indexer=[0, -1]),
      ]),
      ("ListOfListsOfPythonInts", [
          IndexSpec(shape=(3, 4, 5), indexer=[[0, 1]]),
          IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], [[2, 3, 0, 3]]]),
      ]),
      ("ListOfPythonIntsAndIntArrays", [
          IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]),
          IndexSpec(shape=(3, 4, 5),
                    indexer=[0, 1, onp.array([[2, 3, 0, 3]])]),
      ]),
      ("ListOfListsOfPythonIntsAndIntArrays", [
          IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]),
          IndexSpec(shape=(3, 4, 5),
                    indexer=[[[0], [-1]], onp.array([[2, 3, 0, 3]])]),
      ]),
  ] for shape, indexer in index_specs
    for dtype in float_dtypes
    for rng in [jtu.rand_default()])
  def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng, indexer):
    # Second-order gradient check of x[indexer] ** 2 with advanced indexers.
    tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
    operand = rng(shape, dtype)

    def squared_selection(x):
      return x[indexer]**2

    check_grads(squared_selection, (operand, ), 2, tol, tol, tol)

  @parameterized.named_parameters({
      "testcase_name": "{}_inshape={}_indexer={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer),
      "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
  } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
    for shape, indexer in index_specs
    for dtype in all_dtypes
    for rng in [jtu.rand_default()])
  def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
    # Array components of the indexer are passed as arguments. Every other
    # component is replaced by an empty-tuple placeholder and substituted
    # back in inside the function.
    indexer_with_dummies = []
    substitutes = []
    for position, component in enumerate(indexer):
      if isinstance(component, onp.ndarray):
        indexer_with_dummies.append(component)
      else:
        indexer_with_dummies.append(())
        substitutes.append((position, component))

    def args_maker():
      return [rng(shape, dtype), indexer_with_dummies]

    def fun(x, indexer_with_dummies):
      idx = type(indexer)(lax.subvals(indexer_with_dummies, substitutes))
      return x[idx]

    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  def testAdvancedIndexingManually(self):
    # Compares eager and jitted results for three hand-written index patterns.
    x = onp.random.RandomState(0).randn(3, 4, 5)
    index_array = onp.array([0, 2, -1, 0])

    index_ops = [
        lambda x, index_array: x[..., index_array, :],
        lambda x, index_array: x[..., index_array, :, index_array, None],
        lambda x, index_array: x[index_array, ..., index_array[:, None],
                                 None],
    ]
    for op in index_ops:
      cop = api.jit(op)
      eager = op(x, index_array)
      compiled = cop(x, index_array)
      self.assertAllClose(eager, compiled, check_dtypes=True)

  def testUnpacking(self):
    # Tuple-unpacking of the argument must work under jit.
    def foo(x):
      a, b, c = x
      return a + b + c

    cfoo = api.jit(foo)

    eager = foo(onp.arange(3))
    compiled = cfoo(onp.arange(3))
    self.assertAllClose(eager, compiled, check_dtypes=True)

  def testBooleanIndexingArray1D(self):
    # Boolean mask given as a numpy array on a 1-D device array.
    mask = onp.array([True, True, False])
    x = api.device_put(onp.arange(3))
    ans = x[mask]
    expected = onp.arange(3)[mask]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingList1D(self):
    # Boolean mask given as a Python list on a 1-D device array.
    mask = [True, True, False]
    x = api.device_put(onp.arange(3))
    ans = x[mask]
    expected = onp.arange(3)[mask]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingArray2DBroadcast(self):
    # 1-D boolean array mask applied to the leading axis of a 2-D array.
    mask = onp.array([True, True, False, True])
    x = onp.arange(8).reshape(4, 2)
    ans = api.device_put(x)[mask]
    expected = x[mask]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingList2DBroadcast(self):
    # 1-D boolean list mask applied to the leading axis of a 2-D array.
    mask = [True, True, False, True]
    x = onp.arange(8).reshape(4, 2)
    ans = api.device_put(x)[mask]
    expected = x[mask]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingArray2D(self):
    # Boolean mask with the same 2-D shape as the indexed array.
    mask = onp.array([[True, False],
                      [False, True],
                      [False, False],
                      [True, True]])
    x = onp.arange(8).reshape(4, 2)
    ans = api.device_put(x)[mask]
    expected = x[mask]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingDynamicShapeError(self):
    # A boolean mask passed as an argument of a jitted function must raise
    # IndexError.
    x = onp.zeros(3)
    mask = onp.array([True, True, False])
    self.assertRaises(IndexError, lambda: api.jit(lambda x, i: x[i])(x, mask))

  def testIssue187(self):
    # Indexing with two lists of Python ints, eagerly and under jit.
    x = lnp.ones((5, 5))
    x[[0, 2, 4], [0, 2, 4]]  # doesn't crash

    x = onp.arange(25).reshape((5, 5))
    ans = api.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x)
    expected = x[[0, 2, 4], [0, 2, 4]]
    self.assertAllClose(ans, expected, check_dtypes=False)
class ScipyLinalgTest(jtu.JaxTestCase):
    """Tests for jsp.linalg: lu, lu_factor, solve and solve_triangular."""

    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @parameterized.named_parameters(jtu.cases_from_list(
        {"testcase_name":
             "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
         "shape": shape, "dtype": dtype, "rng": rng}
        for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
        for dtype in float_types() + complex_types()
        for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testLu(self, shape, dtype, rng):
        # Compares jsp.linalg.lu with osp.linalg.lu on a random matrix of the
        # given shape/dtype, then hands both to the _CompileAndCheck harness.
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(jsp.linalg.lu, osp.linalg.lu, args_maker,
                                check_dtypes=True, tol=1e-3)
        self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)

    @parameterized.named_parameters(jtu.cases_from_list(
        {"testcase_name":
             "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
         "shape": shape, "dtype": dtype, "rng": rng}
        for shape in [(1, 1), (4, 5), (10, 5), (10, 10)]
        for dtype in float_types() + complex_types()
        for rng in [jtu.rand_default()]))
    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testLuGrad(self, shape, dtype, rng):
        # Second-order gradient check of jsp.linalg.lu at one random input.
        a = rng(shape, dtype)

        jtu.check_grads(jsp.linalg.lu, (a,), 2, rtol=1e-1)

    @jtu.skip_on_devices("gpu", "tpu")
    def testLuBatching(self):
        # Everything below the skip is currently unreachable; it is kept so
        # the test can be re-enabled by deleting this one line.
        self.skipTest("Test disabled until Jaxlib 0.1.14 is released")
        shape = (4, 5)
        dtype = np.float32
        rng = jtu.rand_default()
        # Draw with the declared dtype instead of repeating the literal.
        args = [rng(shape, dtype) for _ in range(10)]
        # Reference: factor every matrix independently with scipy, then stack
        # the P, L and U factors along a new leading axis.
        expected = [osp.linalg.lu(x) for x in args]
        ps = onp.stack([out[0] for out in expected])
        ls = onp.stack([out[1] for out in expected])
        us = onp.stack([out[2] for out in expected])

        actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args))
        self.assertAllClose(ps, actual_ps, check_dtypes=True)
        self.assertAllClose(ls, actual_ls, check_dtypes=True)
        self.assertAllClose(us, actual_us, check_dtypes=True)

    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @parameterized.named_parameters(jtu.cases_from_list(
        {"testcase_name":
             "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
         "n": n, "dtype": dtype, "rng": rng}
        for n in [1, 4, 5, 200]
        for dtype in float_types() + complex_types()
        for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testLuFactor(self, n, dtype, rng):
        # Same structure as testLu, for lu_factor on square (n, n) inputs.
        args_maker = lambda: [rng((n, n), dtype)]

        self._CheckAgainstNumpy(jsp.linalg.lu_factor, osp.linalg.lu_factor,
                                args_maker, check_dtypes=True, tol=1e-3)
        self._CompileAndCheck(jsp.linalg.lu_factor, args_maker,
                              check_dtypes=True)

    @parameterized.named_parameters(jtu.cases_from_list(
        {"testcase_name":
             "_lhs={}_rhs={}_sym_pos={}_lower={}".format(
                 jtu.format_shape_dtype_string(lhs_shape, dtype),
                 jtu.format_shape_dtype_string(rhs_shape, dtype),
                 sym_pos, lower),
         "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
         "sym_pos": sym_pos, "lower": lower, "rng": rng}
        for lhs_shape, rhs_shape in [
            ((1, 1), (1, 1)),
            ((4, 4), (4,)),
            ((8, 8), (8, 4)),
        ]
        for sym_pos, lower in [
            (False, False),
            (True, False),
            (True, True),
        ]
        for dtype in float_types() + complex_types()
        for rng in [jtu.rand_default()]))
    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng):
        # Compares jsp.linalg.solve with osp.linalg.solve for the given
        # sym_pos/lower flags.
        osp_fun = lambda lhs, rhs: osp.linalg.solve(
            lhs, rhs, sym_pos=sym_pos, lower=lower)
        jsp_fun = lambda lhs, rhs: jsp.linalg.solve(
            lhs, rhs, sym_pos=sym_pos, lower=lower)

        def args_maker():
            a = rng(lhs_shape, dtype)
            if sym_pos:
                # a @ conj(a^T), then keep only the triangle selected by
                # `lower`.
                a = onp.matmul(a, onp.conj(T(a)))
                a = onp.tril(a) if lower else onp.triu(a)
            return [a, rng(rhs_shape, dtype)]

        self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                                check_dtypes=True, tol=1e-3)
        self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(jtu.cases_from_list(
        {"testcase_name":
             "_lhs={}_rhs={}_lower={}_transposea={}".format(
                 jtu.format_shape_dtype_string(lhs_shape, dtype),
                 jtu.format_shape_dtype_string(rhs_shape, dtype),
                 lower, transpose_a),
         "lower": lower, "transpose_a": transpose_a,
         "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
         "rng": rng}
        for lower, transpose_a in itertools.product([False, True], repeat=2)
        for lhs_shape, rhs_shape in [
            ((4, 4), (4,)),
            ((4, 4), (4, 3)),
            ((2, 8, 8), (2, 8, 10)),
        ]
        for dtype in float_types()
        for rng in [jtu.rand_default()]))
    def testSolveTriangular(self, lower, transpose_a, lhs_shape,
                            rhs_shape, dtype, rng):
        # Builds a triangular factor via Cholesky of k @ k^T + n * I and
        # checks solve_triangular against an explicit numpy inverse.
        k = rng(lhs_shape, dtype)
        l = onp.linalg.cholesky(onp.matmul(k, T(k))
                                + lhs_shape[-1] * onp.eye(lhs_shape[-1]))
        l = l.astype(k.dtype)
        b = rng(rhs_shape, dtype)

        a = l if lower else T(l)
        inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
        if len(lhs_shape) == len(rhs_shape):
            onp_ans = onp.matmul(inv, b)
        else:
            # Right-hand side has one fewer dimension: matrix-vector product.
            onp_ans = onp.einsum("...ij,...j->...i", inv, b)

        # The standard scipy.linalg.solve_triangular doesn't support
        # broadcasting. But it seems like an inevitable extension so we
        # support it.
        # `a` is exactly the triangular operand, so it is reused here rather
        # than recomputed.
        ans = jsp.linalg.solve_triangular(
            a, b, trans=1 if transpose_a else 0, lower=lower)

        self.assertAllClose(onp_ans, ans, check_dtypes=True)

    @parameterized.named_parameters(jtu.cases_from_list(
        {"testcase_name":
             "_lhs={}_rhs={}_lower={}_transposea={}".format(
                 jtu.format_shape_dtype_string(lhs_shape, dtype),
                 jtu.format_shape_dtype_string(rhs_shape, dtype),
                 lower, transpose_a),
         "lower": lower, "transpose_a": transpose_a,
         "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
         "rng": rng}
        for lower, transpose_a in itertools.product([False, True], repeat=2)
        for lhs_shape, rhs_shape in [
            ((4, 4), (4,)),
            ((4, 4), (4, 3)),
            ((2, 8, 8), (2, 8, 10)),
        ]
        for dtype in float_types()
        for rng in [jtu.rand_default()]))
    def testSolveTriangularGrad(self, lower, transpose_a, lhs_shape,
                                rhs_shape, dtype, rng):
        # TODO(frostig): change ensemble to support a bigger rtol
        # Currently skipped unconditionally; the body below is unreachable.
        self.skipTest("rtol does not cover all devices and precision modes")
        A = np.tril(rng(lhs_shape, dtype)
                    + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
        A = A if lower else T(A)
        B = rng(rhs_shape, dtype)
        f = partial(jsp.linalg.solve_triangular, lower=lower,
                    trans=1 if transpose_a else 0)
        jtu.check_grads(f, (A, B), 2, rtol=1e-3)
class ScipyLinalgTest(jtu.JaxTestCase):
    """LU and triangular-solve checks for jsp.linalg."""

    def _skip_unless_lu_available(self):
        # TODO(phawkins): remove this after a jaxlib release.
        if not hasattr(lapack, "jax_getrf"):
            self.skipTest("No LU implementation available")

    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_shape={}".format(
                    jtu.format_shape_dtype_string(shape, dtype)),
                "shape": shape,
                "dtype": dtype,
                "rng": rng,
            }
            for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
            for dtype in float_types() | complex_types()
            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testLu(self, shape, dtype, rng):
        # jsp.linalg.lu versus osp.linalg.lu on one random input.
        self._skip_unless_lu_available()

        def make_args():
            return [rng(shape, dtype)]

        self._CheckAgainstNumpy(
            jsp.linalg.lu, osp.linalg.lu, make_args,
            check_dtypes=True, tol=1e-3)
        self._CompileAndCheck(jsp.linalg.lu, make_args, check_dtypes=True)

    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_n={}".format(
                    jtu.format_shape_dtype_string((n, n), dtype)),
                "n": n,
                "dtype": dtype,
                "rng": rng,
            }
            for n in [1, 4, 5, 200]
            for dtype in float_types() | complex_types()
            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testLuFactor(self, n, dtype, rng):
        # jsp.linalg.lu_factor versus osp.linalg.lu_factor on an (n, n) input.
        self._skip_unless_lu_available()

        def make_args():
            return [rng((n, n), dtype)]

        self._CheckAgainstNumpy(
            jsp.linalg.lu_factor, osp.linalg.lu_factor, make_args,
            check_dtypes=True, tol=1e-3)
        self._CompileAndCheck(
            jsp.linalg.lu_factor, make_args, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_lhs={}_rhs={}_lower={}_transposea={}".format(
                    jtu.format_shape_dtype_string(lhs_shape, dtype),
                    jtu.format_shape_dtype_string(rhs_shape, dtype),
                    lower, transpose_a),
                "lower": lower,
                "transpose_a": transpose_a,
                "lhs_shape": lhs_shape,
                "rhs_shape": rhs_shape,
                "dtype": dtype,
                "rng": rng,
            }
            for lower, transpose_a in itertools.product([False, True], repeat=2)
            for lhs_shape, rhs_shape in [
                ((4, 4), (4,)),
                ((4, 4), (4, 3)),
                ((2, 8, 8), (2, 8, 10)),
            ]
            for dtype in float_types()
            for rng in [jtu.rand_default()]))
    def testSolveTriangular(self, lower, transpose_a, lhs_shape, rhs_shape,
                            dtype, rng):
        # Reference answer comes from an explicit numpy inverse of the
        # triangular operand; the operand is a Cholesky factor of
        # seed @ seed^T + n * I.
        size = lhs_shape[-1]
        seed = rng(lhs_shape, dtype)
        gram = onp.matmul(seed, T(seed)) + size * onp.eye(size)
        chol = onp.linalg.cholesky(gram).astype(seed.dtype)
        rhs = rng(rhs_shape, dtype)

        triangular = chol if lower else T(chol)
        operand = T(triangular) if transpose_a else triangular
        inverse = onp.linalg.inv(operand).astype(triangular.dtype)
        if len(lhs_shape) != len(rhs_shape):
            # Vector-shaped right-hand side.
            reference = onp.einsum("...ij,...j->...i", inverse, rhs)
        else:
            reference = onp.matmul(inverse, rhs)

        # The standard scipy.linalg.solve_triangular doesn't support
        # broadcasting. But it seems like an inevitable extension so we
        # support it.
        trans = 1 if transpose_a else 0
        result = jsp.linalg.solve_triangular(
            triangular, rhs, trans=trans, lower=lower)

        self.assertAllClose(reference, result, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_lhs={}_rhs={}_lower={}_transposea={}".format(
                    jtu.format_shape_dtype_string(lhs_shape, dtype),
                    jtu.format_shape_dtype_string(rhs_shape, dtype),
                    lower, transpose_a),
                "lower": lower,
                "transpose_a": transpose_a,
                "lhs_shape": lhs_shape,
                "rhs_shape": rhs_shape,
                "dtype": dtype,
                "rng": rng,
            }
            for lower, transpose_a in itertools.product([False, True], repeat=2)
            for lhs_shape, rhs_shape in [
                ((4, 4), (4,)),
                ((4, 4), (4, 3)),
                ((2, 8, 8), (2, 8, 10)),
            ]
            for dtype in float_types()
            for rng in [jtu.rand_default()]))
    def testSolveTriangularGrad(self, lower, transpose_a, lhs_shape,
                                rhs_shape, dtype, rng):
        # TODO(frostig): change ensemble to support a bigger rtol
        # Lower-triangular random matrix with 5 added on the diagonal,
        # transposed when an upper-triangular operand is requested.
        shifted = rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype)
        lower_factor = np.tril(shifted)
        matrix = lower_factor if lower else T(lower_factor)
        rhs = rng(rhs_shape, dtype)

        trans = 1 if transpose_a else 0
        solve = partial(jsp.linalg.solve_triangular, lower=lower, trans=trans)
        jtu.check_grads(solve, (matrix, rhs), 2, rtol=1e-3)
class ScipyLinalgTest(jtu.JaxTestCase):
    """Tests for jsp.linalg: lu, lu_factor, solve and solve_triangular."""

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_shape={}".format(
                    jtu.format_shape_dtype_string(shape, dtype)),
                "shape": shape,
                "dtype": dtype,
                "rng": rng,
            }
            for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
            for dtype in float_types + complex_types
            for rng in [jtu.rand_default()]))
    def testLu(self, shape, dtype, rng):
        # Checks that the factors returned by jsp.linalg.lu multiply back to
        # the input (x == p @ l @ u).
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng(shape, dtype)]
        x, = args_maker()
        p, l, u = jsp.linalg.lu(x)
        self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)),
                            check_dtypes=True)
        self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)

    # TODO(phawkins): figure out why this test fails on Travis and reenable.
    @unittest.skip("Test fails on travis")
    def testLuOfSingularMatrixReturnsNans(self):
        # Asserts that every entry of the lu_factor output for this 2x2
        # input is NaN.
        xs = np.array([[-1., 3. / 2], [2. / 3, -1.]])
        lu, _ = jsp.linalg.lu_factor(xs)
        self.assertTrue(onp.all(onp.isnan(lu)))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_shape={}".format(
                    jtu.format_shape_dtype_string(shape, dtype)),
                "shape": shape,
                "dtype": dtype,
                "rng": rng,
            }
            for shape in [(1, 1), (4, 5), (10, 5), (10, 10), (6, 7, 7)]
            for dtype in float_types + complex_types
            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")  # TODO(phawkins): precision problems on TPU.
    def testLuGrad(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        a = rng(shape, dtype)
        # Inputs with more than two dimensions go through vmap.
        lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu
        jtu.check_grads(lu, (a,), 2, atol=5e-2, rtol=1e-1)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_shape={}".format(
                    jtu.format_shape_dtype_string(shape, dtype)),
                "shape": shape,
                "dtype": dtype,
                "rng": rng,
            }
            for shape in [(4, 5), (6, 5)]
            for dtype in [np.float32]
            for rng in [jtu.rand_default()]))
    def testLuBatching(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        # Draw with the parameterized dtype so the inputs always match the
        # dtype named in the test case (previously hard-coded np.float32).
        args = [rng(shape, dtype) for _ in range(10)]
        # Reference: factor each matrix independently with scipy and stack
        # the P, L and U factors along a new leading axis.
        expected = [osp.linalg.lu(x) for x in args]
        ps = onp.stack([out[0] for out in expected])
        ls = onp.stack([out[1] for out in expected])
        us = onp.stack([out[2] for out in expected])

        actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args))
        self.assertAllClose(ps, actual_ps, check_dtypes=True)
        self.assertAllClose(ls, actual_ls, check_dtypes=True)
        self.assertAllClose(us, actual_us, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_n={}".format(
                    jtu.format_shape_dtype_string((n, n), dtype)),
                "n": n,
                "dtype": dtype,
                "rng": rng,
            }
            for n in [1, 4, 5, 200]
            for dtype in float_types + complex_types
            for rng in [jtu.rand_default()]))
    def testLuFactor(self, n, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng((n, n), dtype)]

        x, = args_maker()
        lu, piv = jsp.linalg.lu_factor(x)
        # Split the packed result: strictly-lower part plus identity is L,
        # upper triangle is U.
        l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype)
        u = onp.triu(lu)
        # Apply the recorded row swaps to x in order, so the permuted x can
        # be compared with l @ u.
        for i in range(n):
            x[[i, piv[i]], ] = x[[piv[i], i], ]
        self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3)
        self._CompileAndCheck(jsp.linalg.lu_factor, args_maker,
                              check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_lhs={}_rhs={}_sym_pos={}_lower={}".format(
                    jtu.format_shape_dtype_string(lhs_shape, dtype),
                    jtu.format_shape_dtype_string(rhs_shape, dtype),
                    sym_pos, lower),
                "lhs_shape": lhs_shape,
                "rhs_shape": rhs_shape,
                "dtype": dtype,
                "sym_pos": sym_pos,
                "lower": lower,
                "rng": rng,
            }
            for lhs_shape, rhs_shape in [
                ((1, 1), (1, 1)),
                ((4, 4), (4,)),
                ((8, 8), (8, 4)),
            ]
            for sym_pos, lower in [
                (False, False),
                (True, False),
                (True, True),
            ]
            for dtype in float_types + complex_types
            for rng in [jtu.rand_default()]))
    def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng):
        _skip_if_unsupported_type(dtype)
        if (sym_pos and onp.issubdtype(dtype, onp.complexfloating)
                and jtu.device_under_test() == "tpu"):
            raise unittest.SkipTest(
                "Complex Cholesky decomposition not implemented on TPU")
        osp_fun = lambda lhs, rhs: osp.linalg.solve(
            lhs, rhs, sym_pos=sym_pos, lower=lower)
        jsp_fun = lambda lhs, rhs: jsp.linalg.solve(
            lhs, rhs, sym_pos=sym_pos, lower=lower)

        def args_maker():
            a = rng(lhs_shape, dtype)
            if sym_pos:
                # a @ conj(a^T), then keep only the triangle selected by
                # `lower`.
                a = onp.matmul(a, onp.conj(T(a)))
                a = onp.tril(a) if lower else onp.triu(a)
            return [a, rng(rhs_shape, dtype)]

        self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                                check_dtypes=True, tol=1e-3)
        self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                    "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}".format(
                        jtu.format_shape_dtype_string(lhs_shape, dtype),
                        jtu.format_shape_dtype_string(rhs_shape, dtype),
                        lower, transpose_a, unit_diagonal),
                "lower": lower,
                "transpose_a": transpose_a,
                "unit_diagonal": unit_diagonal,
                "lhs_shape": lhs_shape,
                "rhs_shape": rhs_shape,
                "dtype": dtype,
                "rng": rng,
            }
            for lower in [False, True]
            for transpose_a in [False, True]
            for unit_diagonal in [False, True]
            for lhs_shape, rhs_shape in [
                ((4, 4), (4,)),
                ((4, 4), (4, 3)),
                ((2, 8, 8), (2, 8, 10)),
            ]
            for dtype in float_types
            for rng in [jtu.rand_default()]))
    def testSolveTriangular(self, lower, transpose_a, unit_diagonal,
                            lhs_shape, rhs_shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        k = rng(lhs_shape, dtype)
        l = onp.linalg.cholesky(
            onp.matmul(k, T(k)) + lhs_shape[-1] * onp.eye(lhs_shape[-1]))
        l = l.astype(k.dtype)
        b = rng(rhs_shape, dtype)

        # The reference operand `a` has its diagonal replaced by ones when
        # unit_diagonal is set; the function under test still receives the
        # original factor `l` together with the unit_diagonal flag.
        if unit_diagonal:
            a = onp.tril(l, -1) + onp.eye(lhs_shape[-1], dtype=dtype)
        else:
            a = l
        a = a if lower else T(a)

        inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
        if len(lhs_shape) == len(rhs_shape):
            onp_ans = onp.matmul(inv, b)
        else:
            onp_ans = onp.einsum("...ij,...j->...i", inv, b)

        # The standard scipy.linalg.solve_triangular doesn't support
        # broadcasting. But it seems like an inevitable extension so we
        # support it.
        ans = jsp.linalg.solve_triangular(
            l if lower else T(l), b,
            trans=1 if transpose_a else 0,
            lower=lower,
            unit_diagonal=unit_diagonal)

        self.assertAllClose(onp_ans, ans, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                    "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}".format(
                        jtu.format_shape_dtype_string(lhs_shape, dtype),
                        jtu.format_shape_dtype_string(rhs_shape, dtype),
                        lower, transpose_a, unit_diagonal),
                "lower": lower,
                "transpose_a": transpose_a,
                "unit_diagonal": unit_diagonal,
                "lhs_shape": lhs_shape,
                "rhs_shape": rhs_shape,
                "dtype": dtype,
                "rng": rng,
            }
            for lower in [False, True]
            for unit_diagonal in [False, True]
            for dtype in float_types + complex_types
            # trans=2 is only generated for the non-floating (complex) dtypes.
            for transpose_a in (
                [0, 1] if onp.issubdtype(dtype, np.floating) else [0, 1, 2])
            for lhs_shape, rhs_shape in [
                ((4, 4), (4,)),
                ((4, 4), (4, 3)),
                ((2, 8, 8), (2, 8, 10)),
            ]
            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")  # TODO(phawkins): Test fails on TPU.
    def testSolveTriangularGrad(self, lower, transpose_a, unit_diagonal,
                                lhs_shape, rhs_shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        # Lower-triangular random matrix with 5 added on the diagonal.
        A = np.tril(
            rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
        A = A if lower else T(A)
        B = rng(rhs_shape, dtype)
        f = partial(jsp.linalg.solve_triangular, lower=lower,
                    trans=transpose_a, unit_diagonal=unit_diagonal)
        jtu.check_grads(f, (A, B), 2, rtol=2e-2, eps=1e-3)
class NumpyLinalgTest(jtu.JaxTestCase):
    """Compares np.linalg routines against their onp.linalg counterparts."""

    def _skip_unless_lu_available(self):
        if not hasattr(lapack, "jax_getrf"):
            self.skipTest("No LU implementation available")

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_shape={}".format(
                    jtu.format_shape_dtype_string(shape, dtype)),
                "shape": shape,
                "dtype": dtype,
                "rng": rng,
            }
            for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
            for dtype in float_types()
            for rng in [jtu.rand_default()]))
    def testCholesky(self, shape, dtype, rng):
        # The input handed to cholesky is factor @ factor^T.
        def make_args():
            factor = rng(shape, dtype)
            return [onp.matmul(factor, T(factor))]

        self._CheckAgainstNumpy(
            onp.linalg.cholesky, np.linalg.cholesky, make_args,
            check_dtypes=True, tol=1e-3)
        self._CompileAndCheck(np.linalg.cholesky, make_args, check_dtypes=True)

    # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_n={}".format(
                    jtu.format_shape_dtype_string((n, n), dtype)),
                "n": n,
                "dtype": dtype,
                "rng": rng,
            }
            for n in [0, 4, 5, 50]
            for dtype in float_types() | complex_types()
            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testDet(self, n, dtype, rng):
        self._skip_unless_lu_available()

        def make_args():
            return [rng((n, n), dtype)]

        self._CheckAgainstNumpy(
            onp.linalg.det, np.linalg.det, make_args,
            check_dtypes=True, tol=1e-3)
        self._CompileAndCheck(np.linalg.det, make_args, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_n={}".format(
                    jtu.format_shape_dtype_string((n, n), dtype)),
                "n": n,
                "dtype": dtype,
                "rng": rng,
            }
            for n in [0, 4, 10, 200]
            for dtype in float_types() | complex_types()
            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testSlogdet(self, n, dtype, rng):
        self._skip_unless_lu_available()

        def make_args():
            return [rng((n, n), dtype)]

        self._CheckAgainstNumpy(
            onp.linalg.slogdet, np.linalg.slogdet, make_args,
            check_dtypes=True, tol=1e-3)
        self._CompileAndCheck(np.linalg.slogdet, make_args, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_shape={}_fullmatrices={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), full_matrices),
                "shape": shape,
                "dtype": dtype,
                "full_matrices": full_matrices,
                "rng": rng,
            }
            for shape in [(1, 1), (3, 4), (2, 10, 5), (2, 200, 100)]
            for dtype in float_types()
            for full_matrices in [False, True]
            for rng in [jtu.rand_default()]))
    def testQr(self, shape, dtype, full_matrices, rng):
        batch_dims = shape[:-2]
        m, n = shape[-2:]
        mode = "complete" if full_matrices else "reduced"
        k = m if full_matrices else min(m, n)

        a = rng(shape, dtype)
        lq, lr = np.linalg.qr(a, mode=mode)

        # onp.linalg.qr doesn't support broadcasting. But it seems like an
        # inevitable extension so we support it in our version.
        nq = onp.zeros(batch_dims + (m, k), dtype)
        nr = onp.zeros(batch_dims + (k, n), dtype)
        for index in onp.ndindex(*batch_dims):
            nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

        # Norm, adjusted for dimension and type.
        scale = max(m, n) * onp.finfo(dtype).eps

        def norm(x):
            return onp.linalg.norm(x, axis=(-2, -1)) / scale

        def compare_orthogonal(q1, q2):
            # Q is unique up to sign, so normalize the sign first.
            ratio_sums = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
            # In-place on purpose: q1 is rescaled where it lives.
            q1 *= onp.divide(ratio_sums, onp.abs(ratio_sums))
            self.assertTrue(onp.all(norm(q1 - q2) < 30))

        # Check a ~= qr
        reconstruction_error = norm(a - onp.matmul(lq, lr))
        self.assertTrue(onp.all(reconstruction_error < 30))

        # Compare the first 'k' vectors of Q; the remainder form an arbitrary
        # orthonormal basis for the null space.
        compare_orthogonal(nq[..., :k], lq[..., :k])

        # Check that q is close to unitary.
        unitarity_error = norm(onp.eye(k) - onp.matmul(T(lq), lq))
        self.assertTrue(onp.all(unitarity_error < 5))

        if m >= n and not full_matrices:
            jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_shape={}".format(
                    jtu.format_shape_dtype_string(shape, dtype)),
                "shape": shape,
                "dtype": dtype,
                "rng": rng,
            }
            for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)]
            for dtype in float_types()
            for rng in [jtu.rand_default()]))
    def testInv(self, shape, dtype, rng):
        def make_args():
            # Redraw until onp.linalg.inv accepts the sample.
            while True:
                candidate = rng(shape, dtype)
                try:
                    onp.linalg.inv(candidate)
                except onp.linalg.LinAlgError:
                    continue
                return [candidate]

        self._CheckAgainstNumpy(
            onp.linalg.inv, np.linalg.inv, make_args,
            check_dtypes=True, tol=1e-3)
        self._CompileAndCheck(np.linalg.inv, make_args, check_dtypes=True)
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Conjugate-gradient solver tests (lax_cg / jax.scipy.sparse.linalg.cg)."""

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_shape={}_preconditioner={}".format(
                    jtu.format_shape_dtype_string(shape, dtype),
                    preconditioner),
                "shape": shape,
                "dtype": dtype,
                "preconditioner": preconditioner,
            }
            for shape in [(4, 4), (7, 7), (32, 32)]
            for dtype in float_types + complex_types
            for preconditioner in [None, 'identity', 'exact']))
    # TODO(#2951): reenable 'random' preconditioner.
    def test_cg_against_scipy(self, shape, dtype, preconditioner):
        rng = jtu.rand_default(self.rng())
        A = rand_sym_pos_def(rng, shape, dtype)
        b = rng(shape[:1], dtype)

        # Builders are thunks so that only the selected one runs (the
        # 'random' builder draws from rng).
        builders = {
            'identity': lambda: np.eye(shape[0], dtype=dtype),
            'random': lambda: np.linalg.inv(
                rand_sym_pos_def(rng, shape, dtype)),
            'exact': lambda: np.linalg.inv(A),
        }
        build = builders.get(preconditioner)
        M = build() if build is not None else None

        args_maker = lambda: (A, b)

        # TODO(shoyer,mattjj): I had to loosen the tolerance for complex64[7,7]
        # with preconditioner=random
        # (that note applies to the maxiter=3 entry below).
        for maxiter, tol in [(1, 1e-3), (3, 3e-3)]:
            self._CheckAgainstNumpy(
                partial(scipy_cg, M=M, maxiter=maxiter),
                partial(lax_cg, M=M, maxiter=maxiter),
                args_maker,
                tol=tol)

        self._CheckAgainstNumpy(
            np.linalg.solve,
            partial(lax_cg, M=M, atol=1e-6),
            args_maker,
            tol=2e-2)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_shape={}".format(
                    jtu.format_shape_dtype_string(shape, dtype)),
                "shape": shape,
                "dtype": dtype,
            }
            for shape in [(2, 2)]
            for dtype in float_types + complex_types))
    def test_cg_as_solve(self, shape, dtype):
        rng = jtu.rand_default(self.rng())
        a = rng(shape, dtype)
        b = rng(shape[:1], dtype)

        expected = np.linalg.solve(posify(a), b)

        # Plain call first, then the jitted solver, against the same target.
        actual = lax_cg(posify(a), b)
        self.assertAllClose(expected, actual)

        jitted_cg = jit(lax_cg)
        actual = jitted_cg(posify(a), b)
        self.assertAllClose(expected, actual)

        # numerical gradients are only well defined if ``a`` is guaranteed to
        # be positive definite.
        def solve_posified(x, y):
            return lax_cg(posify(x), y)

        jtu.check_grads(solve_posified, (a, b), order=2, rtol=1e-2)

    def test_cg_ndarray(self):
        # Operator is x -> 2 * x, so the expected solution is b / 2.
        double = lambda x: 2 * x
        b = jnp.arange(9.0).reshape((3, 3))
        actual, _ = jax.scipy.sparse.linalg.cg(double, b)
        self.assertAllClose(b / 2, actual)

    def test_cg_pytree(self):
        def A(x):
            return {
                "a": x["a"] + 0.5 * x["b"],
                "b": 0.5 * x["a"] + x["b"],
            }

        b = {"a": 1.0, "b": -4.0}
        expected = {"a": 4.0, "b": -6.0}
        actual, _ = jax.scipy.sparse.linalg.cg(A, b)
        self.assertEqual(expected.keys(), actual.keys())
        for key in ("a", "b"):
            self.assertAlmostEqual(expected[key], actual[key], places=6)

    def test_cg_errors(self):
        identity = lambda x: x
        b = jnp.zeros((2, 1))
        x0 = jnp.zeros((2,))
        # b is (2, 1) while x0 is (2,), which cg must reject.
        with self.assertRaisesRegex(
                ValueError, "x0 and b must have matching shape"):
            jax.scipy.sparse.linalg.cg(identity, b, x0)
class NdimageTest(jtu.JaxTestCase):
    """Tests for lsp_ndimage.map_coordinates."""

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                    "_{}_coordinates={}_order={}_mode={}_cval={}_impl={}_round={}".format(
                        jtu.format_shape_dtype_string(shape, dtype),
                        jtu.format_shape_dtype_string(coords_shape, coords_dtype),
                        order, mode, cval, impl, round_),
                "rng_factory": rng_factory,
                "shape": shape,
                "coords_shape": coords_shape,
                "dtype": dtype,
                "coords_dtype": coords_dtype,
                "order": order,
                "mode": mode,
                "cval": cval,
                "impl": impl,
                "round_": round_,
            }
            for shape in [(5,), (3, 4), (3, 4, 5)]
            for coords_shape in [(7,), (2, 3, 4)]
            for dtype in float_dtypes + int_dtypes
            for coords_dtype in float_dtypes
            for order in [0, 1]
            for mode in ['wrap', 'constant', 'nearest']
            for cval in ([0, -1] if mode == 'constant' else [0])
            for impl, rng_factory in [
                ("original", partial(jtu.rand_uniform, low=0, high=1)),
                ("fixed", partial(jtu.rand_uniform, low=-0.75, high=1.75)),
            ]
            for round_ in [True, False]))
    def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype,
                           order, mode, cval, impl, round_, rng_factory):
        rng = rng_factory(self.rng())

        def args_maker():
            image = onp.arange(onp.prod(shape), dtype=dtype).reshape(shape)
            coords = []
            for size in shape:
                # One coordinate array per input axis, scaled by (size - 1).
                axis_coords = (size - 1) * rng(coords_shape, coords_dtype)
                if round_:
                    axis_coords = axis_coords.round().astype(int)
                coords.append(axis_coords)
            return image, coords

        options = dict(order=order, mode=mode, cval=cval)
        lsp_op = partial(lsp_ndimage.map_coordinates, **options)
        if impl == "original":
            reference_impl = osp_ndimage.map_coordinates
        else:
            reference_impl = _fixed_ref_map_coordinates
        osp_op = partial(reference_impl, **options)

        # Float inputs get a tolerance of 100x the larger eps of the two
        # dtypes; other inputs must match exactly.
        if dtype in float_dtypes:
            tol = 100 * max(
                dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
                for d in (dtype, coords_dtype))
        else:
            tol = 0
        self._CheckAgainstNumpy(lsp_op, osp_op, args_maker, tol=tol)

    def testMapCoordinatesErrors(self):
        x = onp.arange(5.0)
        c = [onp.linspace(0, 5, num=3)]
        map_coordinates = lsp_ndimage.map_coordinates

        with self.assertRaisesRegex(NotImplementedError, 'requires order<=1'):
            map_coordinates(x, c, order=2)

        with self.assertRaisesRegex(NotImplementedError,
                                    'does not yet support mode'):
            map_coordinates(x, c, order=1, mode='reflect')

        with self.assertRaisesRegex(ValueError, 'sequence of length'):
            map_coordinates(x, [c, c], order=1)

    def testMapCoordinateDocstring(self):
        docstring = lsp_ndimage.map_coordinates.__doc__
        self.assertIn("Only linear interpolation", docstring)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_{}_order={}".format(onp.dtype(dtype), order),
                "dtype": dtype,
                "order": order,
            }
            for dtype in float_dtypes + int_dtypes
            for order in [0, 1]))
    def testMapCoordinatesRoundHalf(self, dtype, order):
        # Fixed input with coordinates that all end in .5.
        x = onp.arange(-3, 3, dtype=dtype)
        c = onp.array([[.5, 1.5, 2.5, 3.5]])
        args_maker = lambda: (x, c)

        lsp_op = partial(lsp_ndimage.map_coordinates, order=order)
        osp_op = partial(osp_ndimage.map_coordinates, order=order)
        self._CheckAgainstNumpy(lsp_op, osp_op, args_maker)

    def testContinuousGradients(self):
        # regression test for https://github.com/google/jax/issues/3024

        def loss(delta):
            border = 10
            x = onp.arange(100.0)
            indices = onp.arange(x.size) + delta
            # linear interpolation of the linear function y=x should be exact
            shifted = lsp_ndimage.map_coordinates(x, [indices], order=1)
            squared_error = (x - shifted) ** 2
            return squared_error[border:-border].mean()

        # analytical gradient of (x - (x - delta)) ** 2 is 2 * delta
        for delta, expected_grad in [(0.5, 1.0), (1.0, 2.0)]:
            self.assertAllClose(grad(loss)(delta), expected_grad,
                                check_dtypes=False)