class NNInitializersTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_{}_{}".format(rec.name, jtu.format_shape_dtype_string(shape, dtype)),
       "initializer": rec.initializer(),
       "shape": shape, "dtype": dtype}
      for rec in INITIALIZER_RECS
      for shape in rec.shapes
      for dtype in rec.dtypes))
  def testInitializer(self, initializer, shape, dtype):
    rng = random.PRNGKey(0)
    val = initializer(rng, shape, dtype)
    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_{}_{}".format(rec.name, jtu.format_shape_dtype_string(shape, dtype)),
       "initializer_provider": rec.initializer,
       "shape": shape, "dtype": dtype}
      for rec in INITIALIZER_RECS
      for shape in rec.shapes
      for dtype in rec.dtypes))
  def testInitializerProvider(self, initializer_provider, shape, dtype):
    rng = random.PRNGKey(0)
    initializer = initializer_provider(dtype=dtype)
    val = initializer(rng, shape)
    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

  def testVarianceScalingMultiAxis(self):
    rng = random.PRNGKey(0)
    shape = (2, 3, 4, 5)
    initializer = nn.initializers.variance_scaling(
        scale=1.0, mode='fan_avg', distribution='truncated_normal',
        in_axis=(0, 1), out_axis=(-2, -1))
    val = initializer(rng, shape)
    self.assertEqual(shape, jnp.shape(val))

  def testVarianceScalingBatchAxis(self):
    rng = random.PRNGKey(0)
    shape = (2, 3, 4, 5)
    initializer = nn.initializers.variance_scaling(
        scale=1.0, mode='fan_avg', distribution='truncated_normal',
        in_axis=0, out_axis=(2, 3), batch_axis=1)
    val = initializer(rng, shape)
    self.assertEqual(shape, jnp.shape(val))
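# Hedged sketch (not part of the test class above): minimal standalone use of
# the jax.nn.initializers API that NNInitializersTest exercises. Shapes and
# hyperparameters here are illustrative only.
import jax.numpy as jnp
from jax import nn, random

key = random.PRNGKey(0)
glorot = nn.initializers.glorot_normal()           # returns an init function
w = glorot(key, (128, 64), jnp.float32)            # called as init(key, shape, dtype)

vs = nn.initializers.variance_scaling(
    scale=1.0, mode='fan_avg', distribution='truncated_normal',
    in_axis=(0, 1), out_axis=(-2, -1))             # which axes count as fan-in/fan-out
k = vs(key, (3, 3, 8, 16))
print(w.shape, k.shape)                            # (128, 64) (3, 3, 8, 16)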
class LaxBackedScipyFftTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.fft implementations"""

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=
           f"_shape={jtu.format_shape_dtype_string(shape, dtype)}_n={n}_axis={axis}_norm={norm}",
           shape=shape, dtype=dtype, n=n, axis=axis, norm=norm)
      for dtype in real_dtypes
      for shape in [(10,), (2, 5)]
      for n in [None, 1, 7, 13, 20]
      for axis in [-1, 0]
      for norm in [None, 'ortho']))
  @jtu.skip_on_devices("rocm")
  def testDct(self, shape, dtype, n, axis, norm):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    jnp_fn = lambda a: jsp_fft.dct(a, n=n, axis=axis, norm=norm)
    np_fn = lambda a: osp_fft.dct(a, n=n, axis=axis, norm=norm)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=
           f"_shape={jtu.format_shape_dtype_string(shape, dtype)}_axes={axes}_s={s}_norm={norm}",
           shape=shape, dtype=dtype, s=s, axes=axes, norm=norm)
      for dtype in real_dtypes
      for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5)]
      for axes in _get_dctn_test_axes(shape)
      for s in _get_dctn_test_s(shape, axes)
      for norm in [None, 'ortho']))
  @jtu.skip_on_devices("rocm")
  def testDctn(self, shape, dtype, s, axes, norm):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    jnp_fn = lambda a: jsp_fft.dctn(a, s=s, axes=axes, norm=norm)
    np_fn = lambda a: osp_fft.dctn(a, shape=s, axes=axes, norm=norm)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
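# Hedged sketch (illustrative, not part of the tests): jax.scipy.fft.dct is
# checked against scipy.fft.dct above; a direct comparison looks like this.
# Tolerances are arbitrary and chosen for float32.
import numpy as np
import scipy.fft
import jax.scipy.fft

x = np.random.RandomState(0).randn(2, 5).astype(np.float32)
np.testing.assert_allclose(
    jax.scipy.fft.dct(x, n=7, axis=-1, norm='ortho'),
    scipy.fft.dct(x, n=7, axis=-1, norm='ortho'),
    rtol=1e-4, atol=1e-4)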
class NNInitializersTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    config.update("jax_numpy_rank_promotion", "raise")

  def tearDown(self):
    super().tearDown()
    config.update("jax_numpy_rank_promotion", "allow")

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_{}_{}".format(rec.name, jtu.format_shape_dtype_string(shape, dtype)),
       "initializer": rec.initializer(),
       "shape": shape, "dtype": dtype}
      for rec in INITIALIZER_RECS
      for shape in rec.shapes
      for dtype in rec.dtypes))
  def testInitializer(self, initializer, shape, dtype):
    rng = random.PRNGKey(0)
    val = initializer(rng, shape, dtype)
    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_{}_{}".format(rec.name, jtu.format_shape_dtype_string(shape, dtype)),
       "initializer_provider": rec.initializer,
       "shape": shape, "dtype": dtype}
      for rec in INITIALIZER_RECS
      for shape in rec.shapes
      for dtype in rec.dtypes))
  def testInitializerProvider(self, initializer_provider, shape, dtype):
    rng = random.PRNGKey(0)
    initializer = initializer_provider(dtype=dtype)
    val = initializer(rng, shape)
    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))
def genNamedParametersNArgs(n):
  return parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
           "shapes": shapes, "dtypes": dtypes}
          for shapes in itertools.combinations_with_replacement(all_shapes, n)
          for dtypes in itertools.combinations_with_replacement(
              jtu.dtypes.floating, n)))
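# Hedged usage sketch: genNamedParametersNArgs(n) builds a named_parameters
# decorator that sweeps n-argument combinations of shapes and float dtypes.
# The class and method names below are illustrative, not from the suite;
# all_shapes and jtu come from the surrounding test module.
class _ExampleStatsTest(jtu.JaxTestCase):

  @genNamedParametersNArgs(2)
  def testExampleTwoArgs(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
    self.assertEqual(len(args), 2)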
class CustomErrorsTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_{errorclass}", "errorclass": errorclass}
      for errorclass in dir(jax.errors)
      if errorclass.endswith('Error')
      and errorclass not in ['JaxIndexError', 'JAXTypeError']))
  def testErrorsURL(self, errorclass):
    class FakeTracer(core.Tracer):
      aval = None
    ErrorClass = getattr(jax.errors, errorclass)
    err = ErrorClass(FakeTracer(None))

    self.assertIn(f'https://jax.readthedocs.io/en/latest/errors.html#jax.errors.{errorclass}',
                  str(err))
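# Hedged sketch of what testErrorsURL verifies in practice: converting a
# traced value to a NumPy array inside jit raises a jax.errors exception
# whose message carries the documentation URL checked above.
import numpy as np
import jax
import jax.numpy as jnp

try:
  jax.jit(lambda x: np.asarray(x))(jnp.arange(3))
except jax.errors.TracerArrayConversionError as e:
  assert 'https://jax.readthedocs.io/en/latest/errors.html' in str(e)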
class LaxBackedScipyInterpolateTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.interpolate implementations"""

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_spaces={spaces}_method={method}",
       "spaces": spaces, "method": method}
      for spaces in (((0., 10., 10),), ((-15., 20., 12), (3., 4., 24)))
      for method in ("linear", "nearest")))
  def testRegularGridInterpolator(self, spaces, method):
    rng = jtu.rand_default(self.rng())
    scipy_fun = lambda init_args, call_args: sp_interp.RegularGridInterpolator(
        *init_args[:2], method, False, *init_args[2:])(*call_args)
    lax_fun = lambda init_args, call_args: jsp_interp.RegularGridInterpolator(
        *init_args[:2], method, False, *init_args[2:])(*call_args)

    def args_maker():
      points = tuple(map(lambda x: np.linspace(*x), spaces))
      values = rng(reduce(operator.add, tuple(map(np.shape, points))), float)
      fill_value = np.nan
      init_args = (points, values, fill_value)
      n_validation_points = 50
      valid_points = tuple(
          map(
              lambda x: np.linspace(x[0] - 0.2 * (x[1] - x[0]),
                                    x[1] + 0.2 * (x[1] - x[0]),
                                    n_validation_points),
              spaces))
      valid_points = np.squeeze(np.stack(valid_points, axis=1))
      call_args = (valid_points,)
      return init_args, call_args

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})
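# Hedged sketch: minimal 1-D use of jax.scipy.interpolate.RegularGridInterpolator,
# the API exercised above, assuming its constructor mirrors SciPy's
# (points, values, method, bounds_error, fill_value).
import numpy as np
from jax.scipy.interpolate import RegularGridInterpolator

xs = np.linspace(0.0, 10.0, 11)
ys = np.sin(xs)
interp = RegularGridInterpolator((xs,), ys, method="linear",
                                 bounds_error=False, fill_value=np.nan)
queries = np.array([[0.5], [2.25], [9.9]])   # shape (n_points, n_dims)
print(interp(queries))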
class CudaArrayInterfaceTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if jtu.device_under_test() != "gpu":
      self.skipTest("__cuda_array_interface__ is only supported on GPU")

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in all_shapes
      for dtype in dlpack_dtypes))
  @unittest.skipIf(not cupy, "Test requires CuPy")
  def testJaxToCuPy(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    x = rng(shape, dtype)
    y = jnp.array(x)
    z = cupy.asarray(y)
    self.assertEqual(y.__cuda_array_interface__["data"][0],
                     z.__cuda_array_interface__["data"][0])
    self.assertAllClose(x, cupy.asnumpy(z))
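# Hedged sketch (GPU-only, requires CuPy and a CUDA-enabled jaxlib): the
# zero-copy handoff the test above checks. cupy.asarray consumes a JAX GPU
# array via __cuda_array_interface__, so both objects report the same device
# pointer.
import jax.numpy as jnp
import cupy

y = jnp.arange(8, dtype=jnp.float32)   # lives on the GPU
z = cupy.asarray(y)                    # no host round-trip
assert (y.__cuda_array_interface__["data"][0]
        == z.__cuda_array_interface__["data"][0])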
class LaxBackedScipyTests(jtu.JaxTestCase): """Tests for LAX-backed Scipy implementation.""" def _GetArgsMaker(self, rng, shapes, dtypes): return lambda: [ rng(shape, dtype) for shape, dtype in zip(shapes, dtypes) ] @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format( jtu.format_shape_dtype_string(shapes, dtype), axis, keepdims, return_sign, use_b), # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU. "shapes": shapes, "dtype": dtype, "axis": axis, "keepdims": keepdims, "return_sign": return_sign, "use_b": use_b } for shape_group in compatible_shapes for dtype in float_dtypes + complex_dtypes + int_dtypes for use_b in [False, True] for shapes in itertools.product( *((shape_group, shape_group) if use_b else (shape_group, ))) for axis in range( -max(len(shape) for shape in shapes), max(len(shape) for shape in shapes)) for keepdims in [False, True] for return_sign in [False, True])) @jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered in .*") @jax.numpy_rank_promotion( 'allow') # This test explicitly exercises implicit rank promotion. def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b): if jtu.device_under_test() != "cpu": rng = jtu.rand_some_inf_and_nan(self.rng()) else: rng = jtu.rand_default(self.rng()) # TODO(mattjj): test autodiff if use_b: def scipy_fun(array_to_reduce, scale_array): res = osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign, b=scale_array) if dtype == np.int32: res = tree_map(lambda x: x.astype('float32'), res) return res def lax_fun(array_to_reduce, scale_array): return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign, b=scale_array) args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)] else: def scipy_fun(array_to_reduce): res = osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign) if dtype == np.int32: res = tree_map(lambda x: x.astype('float32'), res) return res def lax_fun(array_to_reduce): return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign) args_maker = lambda: [rng(shapes[0], dtype)] tol = {np.float32: 1E-6, np.float64: 1E-14} self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) def testLogSumExpZeros(self): # Regression test for https://github.com/google/jax/issues/5370 scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b) lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b) args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) def testLogSumExpOnes(self): # Regression test for https://github.com/google/jax/issues/7390 args_maker = lambda: [np.ones(4, dtype='float32')] with jax.debug_infs(True): self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker) self._CompileAndCheck(lsp_special.logsumexp, args_maker) def testLogSumExpNans(self): # Regression test for https://github.com/google/jax/issues/7634 with jax.debug_nans(True): with jax.disable_jit(): result = lsp_special.logsumexp(1.0) self.assertEqual(result, 1.0) result = lsp_special.logsumexp(1.0, b=1.0) self.assertEqual(result, 1.0) @parameterized.named_parameters( itertools.chain.from_iterable( jtu.cases_from_list( { "testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, dtypes), 
"rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "test_autodiff": rec.test_autodiff, "nondiff_argnums": rec.nondiff_argnums, "scipy_op": getattr(osp_special, rec.name), "lax_op": getattr(lsp_special, rec.name) } for shapes in itertools.combinations_with_replacement( all_shapes, rec.nargs) for dtypes in (itertools.combinations_with_replacement( rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes))) for rec in JAX_SPECIAL_FUNCTION_RECORDS)) @jax.numpy_rank_promotion( 'allow') # This test explicitly exercises implicit rank promotion. @jax.numpy_dtype_promotion( 'standard') # This test explicitly exercises dtype promotion def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes, test_autodiff, nondiff_argnums): if (jtu.device_under_test() == "cpu" and (lax_op is lsp_special.gammainc or lax_op is lsp_special.gammaincc)): # TODO(b/173608403): re-enable test when LLVM bug is fixed. raise unittest.SkipTest("Skipping test due to LLVM lowering bug") rng = rng_factory(self.rng()) args_maker = self._GetArgsMaker(rng, shapes, dtypes) args = args_maker() self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3, check_dtypes=False) self._CompileAndCheck(lax_op, args_maker, rtol=1e-4) if test_autodiff: def partial_lax_op(*vals): list_args = list(vals) for i in nondiff_argnums: list_args.insert(i, args[i]) return lax_op(*list_args) assert list(nondiff_argnums) == sorted(set(nondiff_argnums)) diff_args = [ x for i, x in enumerate(args) if i not in nondiff_argnums ] jtu.check_grads(partial_lax_op, diff_args, order=1, atol=jtu.if_device_under_test("tpu", .1, 1e-3), rtol=.1, eps=1e-3) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_d={}".format( jtu.format_shape_dtype_string(shape, dtype), d), "shape": shape, "dtype": dtype, "d": d } for shape in all_shapes for dtype in float_dtypes for d in [1, 2, 5])) @jax.numpy_rank_promotion('raise') def testMultigammaln(self, shape, dtype, d): def scipy_fun(a): return osp_special.multigammaln(a, d) def lax_fun(a): return lsp_special.multigammaln(a, d) rng = jtu.rand_positive(self.rng()) args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol={ np.float32: 1e-3, np.float64: 1e-14 }) self._CompileAndCheck(lax_fun, args_maker, rtol={ np.float32: 3e-07, np.float64: 4e-15 }) def testIssue980(self): x = np.full((4, ), -1e20, dtype=np.float32) self.assertAllClose(np.zeros((4, ), dtype=np.float32), lsp_special.expit(x)) @jax.numpy_rank_promotion('raise') def testIssue3758(self): x = np.array([1e5, 1e19, 1e10], dtype=np.float32) q = np.array([1., 40., 30.], dtype=np.float32) self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32), lsp_special.zeta(x, q)) def testXlogyShouldReturnZero(self): self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False) def testGradOfXlogyAtZero(self): partial_xlogy = functools.partial(lsp_special.xlogy, 0.) self.assertAllClose(jax.grad(partial_xlogy)(0.), 0., check_dtypes=False) def testXlog1pyShouldReturnZero(self): self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False) def testGradOfXlog1pyAtZero(self): partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.) 
self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}_lmax={}".format(jtu.format_shape_dtype_string(shape, dtype), l_max), "l_max": l_max, "shape": shape, "dtype": dtype } for l_max in [1, 2, 3, 6] for shape in [(5, ), (10, )] for dtype in float_dtypes)) def testLpmn(self, l_max, shape, dtype): rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] lax_fun = partial(lsp_special.lpmn, l_max, l_max) def scipy_fun(z, m=l_max, n=l_max): # scipy only supports scalar inputs for z, so we must loop here. vals, derivs = zip(*(osp_special.lpmn(m, n, zi) for zi in z)) return np.dstack(vals), np.dstack(derivs) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5, atol=3e-3, check_dtypes=False) self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}_lmax={}".format(jtu.format_shape_dtype_string(shape, dtype), l_max), "l_max": l_max, "shape": shape, "dtype": dtype } for l_max in [3, 4, 6, 32] for shape in [(2, ), (3, ), (4, ), (64, )] for dtype in float_dtypes)) def testNormalizedLpmnValues(self, l_max, shape, dtype): rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] # Note: we test only the normalized values, not the derivatives. lax_fun = partial(lsp_special.lpmn_values, l_max, l_max, is_normalized=True) def scipy_fun(z, m=l_max, n=l_max): # scipy only supports scalar inputs for z, so we must loop here. vals, _ = zip(*(osp_special.lpmn(m, n, zi) for zi in z)) a = np.dstack(vals) # apply the normalization num_m, num_l, _ = a.shape a_normalized = np.zeros_like(a) for m in range(num_m): for l in range(num_l): c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m) c1 = (4.0 * np.pi) * osp_special.factorial(l + m) c2 = np.sqrt(c0 / c1) a_normalized[m, l] = c2 * a[m, l] return a_normalized self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5, atol=1e-5, check_dtypes=False) self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6) @jax.numpy_dtype_promotion( 'standard') # This test explicitly exercises dtype promotion def testSphHarmAccuracy(self): m = jnp.arange(-3, 3)[:, None] n = jnp.arange(3, 6) n_max = 5 theta = 0.0 phi = jnp.pi expected = lsp_special.sph_harm(m, n, theta, phi, n_max) actual = osp_special.sph_harm(m, n, theta, phi) self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5) @jax.numpy_dtype_promotion( 'standard') # This test explicitly exercises dtype promotion def testSphHarmOrderZeroDegreeZero(self): """Tests the spherical harmonics of order zero and degree zero.""" theta = jnp.array([0.3]) phi = jnp.array([2.3]) n_max = 0 expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)]) actual = jnp.real( lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi, n_max)) self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8) @jtu.skip_on_devices("rocm") # rtol and atol needs to be adjusted for ROCm @jax.numpy_dtype_promotion( 'standard') # This test explicitly exercises dtype promotion def testSphHarmOrderZeroDegreeOne(self): """Tests the spherical harmonics of order one and degree zero.""" theta = jnp.array([2.0]) phi = jnp.array([3.1]) n_max = 1 expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi) actual = jnp.real( lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi, n_max)) self.assertAllClose(actual, expected, rtol=2e-7, atol=6e-8) @jax.numpy_dtype_promotion( 
'standard') # This test explicitly exercises dtype promotion def testSphHarmOrderOneDegreeOne(self): """Tests the spherical harmonics of order one and degree one.""" theta = jnp.array([2.0]) phi = jnp.array([2.5]) n_max = 1 expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) * jnp.sin(phi) * jnp.exp(1j * theta)) actual = lsp_special.sph_harm(jnp.array([1]), jnp.array([1]), theta, phi, n_max) self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': f'_maxdegree={l_max}_inputsize={num_z}_dtype={dtype.__name__}', 'l_max': l_max, 'num_z': num_z, 'dtype': dtype } for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8]) for dtype in jtu.dtypes.all_integer)) @jax.numpy_dtype_promotion( 'standard') # This test explicitly exercises dtype promotion def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype): """Tests against JIT compatibility and Numpy.""" n_max = l_max shape = (num_z, ) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max) def args_maker(): m = rng(shape, dtype) n = abs(m) theta = jnp.linspace(-4.0, 5.0, num_z) phi = jnp.linspace(-2.0, 1.0, num_z) return m, n, theta, phi with self.subTest('Test JIT compatibility'): self._CompileAndCheck(lsp_special_fn, args_maker) with self.subTest('Test against numpy.'): self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker) @jax.numpy_dtype_promotion( 'standard') # This test explicitly exercises dtype promotion def testSphHarmCornerCaseWithWrongNmax(self): """Tests the corner case where `n_max` is not the maximum value of `n`.""" m = jnp.array([2]) n = jnp.array([10]) n_clipped = jnp.array([6]) n_max = 6 theta = jnp.array([0.9]) phi = jnp.array([0.2]) expected = lsp_special.sph_harm(m, n, theta, phi, n_max) actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max) self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_shape={}' '_n_zero_sv={}_degeneracy={}_geometric_spectrum={}' '_max_sv={}_method={}_side={}' '_nonzero_condition_number={}_seed={}'.format( jtu.format_shape_dtype_string( shape, jnp.dtype(dtype).name).replace(" ", ""), n_zero_sv, degeneracy, geometric_spectrum, max_sv, method, side, nonzero_condition_number, seed), 'n_zero_sv': n_zero_sv, 'degeneracy': degeneracy, 'geometric_spectrum': geometric_spectrum, 'max_sv': max_sv, 'shape': shape, 'method': method, 'side': side, 'nonzero_condition_number': nonzero_condition_number, 'dtype': dtype, 'seed': seed } for n_zero_sv in n_zero_svs for degeneracy in degeneracies for geometric_spectrum in geometric_spectra for max_sv in max_svs for shape in polar_shapes for method in methods for side in sides for nonzero_condition_number in nonzero_condition_numbers for dtype in jtu.dtypes.inexact for seed in seeds)) @jtu.skip_on_devices("gpu") # Fails on A100. 
def testPolar(self, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape, method, side, nonzero_condition_number, dtype, seed): """ Tests jax.scipy.linalg.polar.""" if jtu.device_under_test() != "cpu": if jnp.dtype(dtype).name in ("bfloat16", "float16"): raise unittest.SkipTest("Skip half precision off CPU.") m, n = shape if (method == "qdwh" and ((side == "left" and m >= n) or (side == "right" and m < n))): raise unittest.SkipTest("method=qdwh does not support these sizes") matrix, _ = _initialize_polar_test(self.rng(), shape, n_zero_sv, degeneracy, geometric_spectrum, max_sv, nonzero_condition_number, dtype) if jnp.dtype(dtype).name in ("bfloat16", "float16"): self.assertRaises(NotImplementedError, jsp.linalg.polar, matrix, method=method, side=side) return unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side) if shape[0] >= shape[1]: should_be_eye = np.matmul(unitary.conj().T, unitary) else: should_be_eye = np.matmul(unitary, unitary.conj().T) tol = 500 * float(jnp.finfo(matrix.dtype).eps) eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype) with self.subTest('Test unitarity.'): self.assertAllClose(eye_mat, should_be_eye, atol=tol * min(shape)) with self.subTest('Test Hermiticity.'): self.assertAllClose(posdef, posdef.conj().T, atol=tol * jnp.linalg.norm(posdef)) ev, _ = np.linalg.eigh(posdef) ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)] negative_ev = jnp.sum(ev < 0.) with self.subTest('Test positive definiteness.'): self.assertEqual(negative_ev, 0) if side == "right": recon = jnp.matmul(unitary, posdef, precision=lax.Precision.HIGHEST) elif side == "left": recon = jnp.matmul(posdef, unitary, precision=lax.Precision.HIGHEST) with self.subTest('Test reconstruction.'): self.assertAllClose(matrix, recon, atol=tol * jnp.linalg.norm(matrix)) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_linear_size={}_dtype={}_termination_size={}'.format( linear_size, jnp.dtype(dtype).name, termination_size), 'linear_size': linear_size, 'dtype': dtype, 'termination_size': termination_size } for linear_size in linear_sizes for dtype in jtu.dtypes.floating + jtu.dtypes.complex for termination_size in [1, 19])) def test_spectral_dac_eigh(self, linear_size, dtype, termination_size): if jtu.device_under_test() != "tpu" and termination_size != 1: raise unittest.SkipTest( "Termination sizes greater than 1 only work on TPU") rng = self.rng() H = rng.randn(linear_size, linear_size) H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype) if jnp.dtype(dtype).name in ("bfloat16", "float16"): self.assertRaises(NotImplementedError, jax._src.lax.eigh.eigh, H) return evs, V = jax._src.lax.eigh.eigh(H, termination_size=termination_size) ev_exp, eV_exp = jnp.linalg.eigh(H) HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST) vV = evs.astype(V.dtype)[None, :] * V eps = jnp.finfo(H.dtype).eps atol = jnp.linalg.norm(H) * eps self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol) self.assertAllClose( HV, vV, atol=atol * (80 if jnp.issubdtype(dtype, jnp.complexfloating) else 30)) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_{jtu.format_shape_dtype_string((n_obs, n_codes, *n_feats), dtype)}", "n_obs": n_obs, "n_codes": n_codes, "n_feats": n_feats, "dtype": dtype } for n_obs in [1, 3, 5] for n_codes in [1, 2, 4] for n_feats in [()] + [(i, ) for i in range(1, 3)] for dtype in float_dtypes + int_dtypes) ) # scipy doesn't support complex def test_vq(self, n_obs, n_codes, n_feats, dtype): rng = jtu.rand_default(self.rng()) args_maker = 
lambda: [ rng((n_obs, *n_feats), dtype), rng((n_codes, *n_feats), dtype) ] self._CheckAgainstNumpy(osp_cluster.vq.vq, lsp_cluster.vq.vq, args_maker, check_dtypes=False) self._CompileAndCheck(lsp_cluster.vq.vq, args_maker)
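# Hedged sketch: jax.scipy.special.logsumexp as exercised by testLogSumExp
# above, including the scaling vector b and return_sign. Values are
# illustrative.
import numpy as np
from jax.scipy.special import logsumexp

a = np.array([-1000.0, -2.0])
b = np.array([1.0, 0.5])
out, sign = logsumexp(a, b=b, return_sign=True)
print(out, sign)   # log|sum_i b_i exp(a_i)| and its sign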
class SvdTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {  # pylint:disable=g-complex-comprehension
          'testcase_name': '_m={}_by_n={}_log_cond={}'.format(m, n, log_cond),
          'm': m, 'n': n, 'log_cond': log_cond}
      for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
      for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
  @jtu.skip_on_devices("rocm")  # will be fixed on rocm-5.1
  def testSvdWithRectangularInput(self, m, n, log_cond):
    """Tests SVD with rectangular input."""
    with jax.default_matmul_precision('float32'):
      a = np.random.uniform(low=0.3, high=0.9,
                            size=(m, n)).astype(_SVD_TEST_DTYPE)
      u, s, v = jnp.linalg.svd(a, full_matrices=False)
      cond = 10**log_cond
      s = jnp.linspace(cond, 1, min(m, n))
      a = (u * s) @ v
      a = a + 1j * a

      osp_linalg_fn = functools.partial(osp_linalg.svd, full_matrices=False)
      actual_u, actual_s, actual_v = svd.svd(a)

      k = min(m, n)
      if m > n:
        unitary_u = jnp.abs(actual_u.T.conj() @ actual_u)
        unitary_v = jnp.abs(actual_v.T.conj() @ actual_v)
      else:
        unitary_u = jnp.abs(actual_u @ actual_u.T.conj())
        unitary_v = jnp.abs(actual_v @ actual_v.T.conj())

      _, expected_s, _ = osp_linalg_fn(a)

      args_maker = lambda: [a]

      with self.subTest('Test JIT compatibility'):
        self._CompileAndCheck(svd.svd, args_maker)

      with self.subTest('Test unitary u.'):
        self.assertAllClose(np.eye(k), unitary_u, rtol=_SVD_RTOL, atol=2E-3)

      with self.subTest('Test unitary v.'):
        self.assertAllClose(np.eye(k), unitary_v, rtol=_SVD_RTOL, atol=2E-3)

      with self.subTest('Test s.'):
        self.assertAllClose(expected_s, jnp.real(actual_s), rtol=_SVD_RTOL,
                            atol=1E-6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {'testcase_name': '_m={}_by_n={}'.format(m, n), 'm': m, 'n': n}
      for m, n in zip([50, 6], [3, 60])))
  def testSvdWithSkinnyTallInput(self, m, n):
    """Tests SVD with skinny and tall input."""
    # Generates a skinny and tall input
    with jax.default_matmul_precision('float32'):
      np.random.seed(1235)
      a = np.random.randn(m, n).astype(_SVD_TEST_DTYPE)
      u, s, v = svd.svd(a, is_hermitian=False)
      relative_diff = np.linalg.norm(a - (u * s) @ v) / np.linalg.norm(a)

      np.testing.assert_almost_equal(relative_diff, 1E-6, decimal=6)

  @parameterized.named_parameters(jtu.cases_from_list(
      {  # pylint:disable=g-complex-comprehension
          'testcase_name': '_m={}_r={}_log_cond={}'.format(m, r, log_cond),
          'm': m, 'r': r, 'log_cond': log_cond}
      for m, r in zip([8, 8, 8, 10], [3, 5, 7, 9])
      for log_cond in np.linspace(1, 3, 3)))
  @jtu.skip_on_devices("rocm")  # will be fixed on rocm-5.1
  def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
    """Tests SVD with rank-deficient input."""
    with jax.default_matmul_precision('float32'):
      a = jnp.triu(jnp.ones((m, m))).astype(_SVD_TEST_DTYPE)

      # Generates a rank-deficient input.
      u, s, v = jnp.linalg.svd(a, full_matrices=False)
      cond = 10**log_cond
      s = jnp.linspace(cond, 1, m)
      s = s.at[r:m].set(jnp.zeros((m - r,)))
      a = (u * s) @ v

      with jax.default_matmul_precision('float32'):
        u, s, v = svd.svd(a, is_hermitian=False)
      diff = np.linalg.norm(a - (u * s) @ v)

      np.testing.assert_almost_equal(diff, 1E-4, decimal=2)
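# Hedged sketch: the reconstruction identity these tests rely on, written with
# the public jnp.linalg.svd API (the svd.svd symbol above is an internal
# implementation under test).
import numpy as np
import jax.numpy as jnp

a = np.random.RandomState(0).randn(6, 4).astype(np.float32)
u, s, vh = jnp.linalg.svd(a, full_matrices=False)
np.testing.assert_allclose(a, (u * s) @ vh, atol=1e-4)   # a == U diag(s) V^H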
class TestPromotionTables(jtu.JaxTestCase): @parameterized.named_parameters({ "testcase_name": f"_jaxtype={jaxtype}", "jaxtype": jaxtype } for jaxtype in dtypes._jax_types + dtypes._weak_types) def testJaxTypeFromType(self, jaxtype): self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(jaxtype)), jaxtype) @parameterized.named_parameters({ "testcase_name": f"_jaxtype={jaxtype}", "jaxtype": jaxtype } for jaxtype in dtypes._jax_types + dtypes._weak_types) def testJaxTypeFromVal(self, jaxtype): try: val = jaxtype(0) except TypeError: val = jaxtype.type(0) self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(val)), jaxtype) @parameterized.named_parameters({ "testcase_name": f"_dtype={dtype}", "dtype": dtype } for dtype in dtypes._jax_types) def testJaxTypeWeak(self, dtype): jax_type = dtypes._jax_type(dtype, weak_type=True) if dtypes.issubdtype(jax_type, np.complexfloating): self.assertIs(jax_type, complex) elif dtypes.issubdtype(jax_type, np.floating): self.assertIs(jax_type, float) elif dtypes.issubdtype(jax_type, np.integer): self.assertIs(jax_type, int) else: self.assertIs(jax_type, np.dtype(bool)) def testResultTypeNone(self): # This matches the behavior of np.result_type(None) => np.float_ self.assertEqual(dtypes.result_type(None), dtypes.canonicalize_dtype(dtypes.float_)) def testResultTypeWeakFlag(self): float_ = dtypes.canonicalize_dtype(dtypes.float_) x_weak = jnp.array(1.) x_strong = x_weak.astype(float_) self.assertEqual(dtypes.result_type(x_weak), float_) self.assertEqual( dtypes.result_type(x_weak, return_weak_type_flag=True), (float_, True)) self.assertEqual(dtypes.result_type(x_strong), float_) self.assertEqual( dtypes.result_type(x_strong, return_weak_type_flag=True), (float_, False)) @jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*") @jax.numpy_dtype_promotion('standard') def testObservedPromotionTable(self): """Test that the weak & strong dtype promotion table does not change over time.""" # Note: * here refers to weakly-typed values typecodes = \ ['b1','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i*','f*','c*'] if config.x64_enabled: expected = [ [ 'b1', 'u1', 'u2', 'u4', 'u8', 'i1', 'i2', 'i4', 'i8', 'bf', 'f2', 'f4', 'f8', 'c4', 'c8', 'i*', 'f*', 'c*' ], [ 'u1', 'u1', 'u2', 'u4', 'u8', 'i2', 'i2', 'i4', 'i8', 'bf', 'f2', 'f4', 'f8', 'c4', 'c8', 'u1', 'f*', 'c*' ], [ 'u2', 'u2', 'u2', 'u4', 'u8', 'i4', 'i4', 'i4', 'i8', 'bf', 'f2', 'f4', 'f8', 'c4', 'c8', 'u2', 'f*', 'c*' ], [ 'u4', 'u4', 'u4', 'u4', 'u8', 'i8', 'i8', 'i8', 'i8', 'bf', 'f2', 'f4', 'f8', 'c4', 'c8', 'u4', 'f*', 'c*' ], [ 'u8', 'u8', 'u8', 'u8', 'u8', 'f*', 'f*', 'f*', 'f*', 'bf', 'f2', 'f4', 'f8', 'c4', 'c8', 'u8', 'f*', 'c*' ], [ 'i1', 'i2', 'i4', 'i8', 'f*', 'i1', 'i2', 'i4', 'i8', 'bf', 'f2', 'f4', 'f8', 'c4', 'c8', 'i1', 'f*', 'c*' ], [ 'i2', 'i2', 'i4', 'i8', 'f*', 'i2', 'i2', 'i4', 'i8', 'bf', 'f2', 'f4', 'f8', 'c4', 'c8', 'i2', 'f*', 'c*' ], [ 'i4', 'i4', 'i4', 'i8', 'f*', 'i4', 'i4', 'i4', 'i8', 'bf', 'f2', 'f4', 'f8', 'c4', 'c8', 'i4', 'f*', 'c*' ], [ 'i8', 'i8', 'i8', 'i8', 'f*', 'i8', 'i8', 'i8', 'i8', 'bf', 'f2', 'f4', 'f8', 'c4', 'c8', 'i8', 'f*', 'c*' ], [ 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'f4', 'f4', 'f8', 'c4', 'c8', 'bf', 'bf', 'c4' ], [ 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f4', 'f2', 'f4', 'f8', 'c4', 'c8', 'f2', 'f2', 'c4' ], [ 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f8', 'c4', 'c8', 'f4', 'f4', 'c4' ], [ 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 
'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'c8', 'c8', 'f8', 'f8', 'c8' ], [ 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c8', 'c4', 'c8', 'c4', 'c4', 'c4' ], [ 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8' ], [ 'i*', 'u1', 'u2', 'u4', 'u8', 'i1', 'i2', 'i4', 'i8', 'bf', 'f2', 'f4', 'f8', 'c4', 'c8', 'i*', 'f*', 'c*' ], [ 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'bf', 'f2', 'f4', 'f8', 'c4', 'c8', 'f*', 'f*', 'c*' ], [ 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c4', 'c4', 'c4', 'c8', 'c4', 'c8', 'c*', 'c*', 'c*' ], ] else: expected = [ [ 'b1', 'u1', 'u2', 'u4', 'u4', 'i1', 'i2', 'i4', 'i4', 'bf', 'f2', 'f4', 'f4', 'c4', 'c4', 'i*', 'f*', 'c*' ], [ 'u1', 'u1', 'u2', 'u4', 'u4', 'i2', 'i2', 'i4', 'i4', 'bf', 'f2', 'f4', 'f4', 'c4', 'c4', 'u1', 'f*', 'c*' ], [ 'u2', 'u2', 'u2', 'u4', 'u4', 'i4', 'i4', 'i4', 'i4', 'bf', 'f2', 'f4', 'f4', 'c4', 'c4', 'u2', 'f*', 'c*' ], [ 'u4', 'u4', 'u4', 'u4', 'u4', 'i4', 'i4', 'i4', 'i4', 'bf', 'f2', 'f4', 'f4', 'c4', 'c4', 'u4', 'f*', 'c*' ], [ 'u4', 'u4', 'u4', 'u4', 'u4', 'i4', 'i4', 'i4', 'i4', 'bf', 'f2', 'f4', 'f4', 'c4', 'c4', 'u4', 'f*', 'c*' ], [ 'i1', 'i2', 'i4', 'i4', 'i4', 'i1', 'i2', 'i4', 'i4', 'bf', 'f2', 'f4', 'f4', 'c4', 'c4', 'i1', 'f*', 'c*' ], [ 'i2', 'i2', 'i4', 'i4', 'i4', 'i2', 'i2', 'i4', 'i4', 'bf', 'f2', 'f4', 'f4', 'c4', 'c4', 'i2', 'f*', 'c*' ], [ 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'bf', 'f2', 'f4', 'f4', 'c4', 'c4', 'i4', 'f*', 'c*' ], [ 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'bf', 'f2', 'f4', 'f4', 'c4', 'c4', 'i4', 'f*', 'c*' ], [ 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'f4', 'f4', 'f4', 'c4', 'c4', 'bf', 'bf', 'c4' ], [ 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f4', 'f2', 'f4', 'f4', 'c4', 'c4', 'f2', 'f2', 'c4' ], [ 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'c4', 'c4', 'f4', 'f4', 'c4' ], [ 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'c4', 'c4', 'f4', 'f4', 'c4' ], [ 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4' ], [ 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4' ], [ 'i*', 'u1', 'u2', 'u4', 'u4', 'i1', 'i2', 'i4', 'i4', 'bf', 'f2', 'f4', 'f4', 'c4', 'c4', 'i*', 'f*', 'c*' ], [ 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'bf', 'f2', 'f4', 'f4', 'c4', 'c4', 'f*', 'f*', 'c*' ], [ 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c*', 'c*', 'c*' ], ] typecode_to_dtype = { 'b1': jnp.bool_, 'u1': jnp.uint8, 'u2': jnp.uint16, 'u4': jnp.uint32, 'u8': jnp.uint64, 'i1': jnp.int8, 'i2': jnp.int16, 'i4': jnp.int32, 'i8': jnp.int64, 'bf': jnp.bfloat16, 'f2': jnp.float16, 'f4': jnp.float32, 'f8': jnp.float64, 'c4': jnp.complex64, 'c8': jnp.complex128, 'i*': jnp.int64, 'f*': jnp.float64, 'c*': jnp.complex128, } dtype_to_typecode = { jnp.dtype(v): k for k, v in typecode_to_dtype.items() if not k.endswith('*') } def typecode_to_val(typecode): weak_type = typecode.endswith('*') dtype = typecode_to_dtype[typecode] val = dtype(0) if weak_type: val = val.item() return val def val_to_typecode(val): dtype = dtypes.result_type(val) weak_type = dtypes.is_weakly_typed(val) typecode = dtype_to_typecode[dtype] if weak_type: typecode = typecode[:-1] + '*' return typecode vals = [typecode_to_val(t) for t in typecodes] table = [[val_to_typecode(v1 
+ v2) for v1 in vals] for v2 in vals] def show_differences(epected, actual): diffs = "" for i, t1 in enumerate(typecodes): for j, t2 in enumerate(typecodes): if expected[i][j] != actual[i][j]: diffs += f"\n{t1}, {t2} -> want {expected[i][j]}, got {actual[i][j]}" return diffs self.assertEqual(table, expected, show_differences(expected, table)) @parameterized.named_parameters({ "testcase_name": "_xtype={}_ytype={}_xfun={}_yfun={}".format(xtype.__name__, ytype.__name__, xfun.__name__, yfun.__name__), "xtype": xtype, "ytype": ytype, "xfun": xfun, "yfun": yfun } for xtype, ytype in itertools.product( [int, float, jnp.int16, jnp.int32, jnp.float16, jnp.float32], repeat=2) for xfun, yfun in itertools.product( [identity, abs, jnp.array], repeat=2)) @jax.numpy_dtype_promotion('standard') def testBinaryPromotionJitInvariance(self, xtype, ytype, xfun, yfun): """Test jit invariance of simple binary promotion rules with and without weak types.""" f = lambda x, y: xfun(x) + yfun(y) args_maker = lambda: [xtype(1), ytype(1)] self._CompileAndCheck(f, args_maker, check_dtypes=True) @parameterized.named_parameters({ "testcase_name": f"_dtype={dtype}_weak_type={weak_type}", "dtype": dtype, "weak_type": weak_type } for dtype in all_dtypes for weak_type in [True, False]) def testUnaryPromotion(self, dtype, weak_type): # Regression test for https://github.com/google/jax/issues/6051 x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) if weak_type: expected = dtypes.canonicalize_dtype( dtypes._default_types['f' if x.dtype == 'bfloat16' else x.dtype.kind]) else: expected = x.dtype self.assertEqual(dtypes.result_type(x), expected) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_dtype={dtype}_weak_type={weak_type}_promotion={promotion}", "dtype": dtype, "weak_type": weak_type, "promotion": promotion } for dtype in all_dtypes for weak_type in [True, False] for promotion in ['standard', 'strict'])) def testBinaryNonPromotion(self, dtype, weak_type, promotion): # Regression test for https://github.com/google/jax/issues/6051 x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) with jax.numpy_dtype_promotion(promotion): y = (x + x) if promotion == 'standard' or not weak_type or dtype == dtypes.bool_: expected_dtype = dtype elif dtypes.issubdtype(dtype, np.complexfloating): expected_dtype = dtypes.complex_ elif dtypes.issubdtype(dtype, np.floating): expected_dtype = dtypes.float_ else: expected_dtype = dtypes.int_ # No boolean weak types. expected_weak_type = weak_type and dtype != bool expected_dtype = dtypes.canonicalize_dtype(expected_dtype) self.assertEqual(y.dtype, expected_dtype) self.assertEqual(dtypes.is_weakly_typed(y), expected_weak_type) @parameterized.named_parameters({ "testcase_name": f"_dtype={dtype}_weak_type={weak_type}", "dtype": dtype, "weak_type": weak_type } for dtype in all_dtypes for weak_type in [True, False]) def testDeviceArrayRepr(self, dtype, weak_type): val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) rep = repr(val) self.assertStartsWith(rep, 'DeviceArray(') if weak_type: self.assertEndsWith(rep, f"dtype={val.dtype.name}, weak_type=True)") else: self.assertEndsWith(rep, f"dtype={val.dtype.name})")
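# Hedged sketch: the weak-type behaviour the promotion-table test above pins
# down. Python scalars are weakly typed, so adding one to a narrow array does
# not widen the result; under strict promotion, mixing two strong dtypes is
# disallowed. Exception type is left generic here on purpose.
import jax
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float16)
print((x + 1.0).dtype)           # float16: the Python float is weakly typed
print(jnp.result_type(x, 1.0))   # float16

with jax.numpy_dtype_promotion('strict'):
  try:
    _ = x + jnp.float32(1.0)     # two strong dtypes: rejected in strict mode
  except Exception as e:
    print(type(e).__name__)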
class VectorizeTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape,
       "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((2, 3), (3, 4), (2, 4)),
          ((2, 3), (1, 3, 4), (1, 2, 4)),
          ((5, 2, 3), (1, 3, 4), (5, 2, 4)),
          ((6, 5, 2, 3), (3, 4), (6, 5, 2, 4)),
      ]))
  def test_matmat(self, left_shape, right_shape, result_shape):
    matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
    self.assertEqual(matmat(jnp.zeros(left_shape),
                            jnp.zeros(right_shape)).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape,
       "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((2, 3), (3,), (2,)),
          ((2, 3), (1, 3), (1, 2)),
          ((4, 2, 3), (1, 3), (4, 2)),
          ((5, 4, 2, 3), (1, 3), (5, 4, 2)),
      ]))
  def test_matvec(self, left_shape, right_shape, result_shape):
    matvec = jnp.vectorize(jnp.dot, signature='(n,m),(m)->(n)')
    self.assertEqual(matvec(jnp.zeros(left_shape),
                            jnp.zeros(right_shape)).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape,
       "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((3,), (3,), ()),
          ((2, 3), (3,), (2,)),
          ((4, 2, 3), (3,), (4, 2)),
      ]))
  def test_vecmat(self, left_shape, right_shape, result_shape):
    vecvec = jnp.vectorize(jnp.dot, signature='(m),(m)->()')
    self.assertEqual(vecvec(jnp.zeros(left_shape),
                            jnp.zeros(right_shape)).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((3,), ()),
          ((2, 3), (2,)),
          ((1, 2, 3), (1, 2)),
      ]))
  def test_magnitude(self, shape, result_shape):
    size = 1
    for x in shape:
      size *= x
    inputs = jnp.arange(size).reshape(shape)

    @partial(jnp.vectorize, signature='(n)->()')
    def magnitude(x):
      return jnp.dot(x, x)

    self.assertEqual(magnitude(inputs).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((3,), ()),
          ((2, 3), (2,)),
          ((1, 2, 3, 4), (1, 2, 3)),
      ]))
  def test_mean(self, shape, result_shape):
    mean = jnp.vectorize(jnp.mean, signature='(n)->()')
    self.assertEqual(mean(jnp.zeros(shape)).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((), (2,)),
          ((3,), (3, 2)),
      ]))
  def test_stack_plus_minus(self, shape, result_shape):

    @partial(jnp.vectorize, signature='()->(n)')
    def stack_plus_minus(x):
      return jnp.stack([x, -x])

    self.assertEqual(stack_plus_minus(jnp.zeros(shape)).shape, result_shape)

  def test_center(self):

    @partial(jnp.vectorize, signature='(n)->(),(n)')
    def center(array):
      bias = jnp.mean(array)
      debiased = array - bias
      return bias, debiased

    b, a = center(jnp.arange(3))
    self.assertEqual(a.shape, (3,))
    self.assertEqual(b.shape, ())
    self.assertAllClose(1.0, b, check_dtypes=False)

    b, a = center(jnp.arange(6).reshape(2, 3))
    self.assertEqual(a.shape, (2, 3))
    self.assertEqual(b.shape, (2,))
    self.assertAllClose(jnp.array([1.0, 4.0]), b, check_dtypes=False)

  def test_exclude_first(self):

    @partial(jnp.vectorize, excluded={0})
    def f(x, y):
      assert x == 'foo'
      assert y.ndim == 0
      return y

    x = jnp.arange(3)
    self.assertAllClose(x, f('foo', x))
    self.assertAllClose(x, jax.jit(f, static_argnums=0)('foo', x))

  def test_exclude_second(self):

    @partial(jnp.vectorize, excluded={1})
    def f(x, y):
      assert x.ndim == 0
      assert y == 'foo'
      return x

    x = jnp.arange(3)
    self.assertAllClose(x, f(x, 'foo'))
    self.assertAllClose(x, jax.jit(f, static_argnums=1)(x, 'foo'))

  def test_exclude_errors(self):
    with self.assertRaisesRegex(TypeError,
                                "jax.numpy.vectorize can only exclude"):
      jnp.vectorize(lambda x: x, excluded={'foo'})
    with self.assertRaisesRegex(ValueError,
                                r"excluded=\{-1\} contains negative numbers"):
      jnp.vectorize(lambda x: x, excluded={-1})
    f = jnp.vectorize(lambda x: x, excluded={1})
    with self.assertRaisesRegex(
        ValueError, r"excluded=\{1\} is invalid for 1 argument\(s\)"):
      f(1.0)

  def test_bad_inputs(self):
    matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
    with self.assertRaisesRegex(TypeError,
                                "wrong number of positional arguments"):
      matmat(jnp.zeros((3, 2)))
    with self.assertRaisesRegex(
        ValueError,
        r"input with shape \(2,\) does not have enough dimensions"):
      matmat(jnp.zeros((2,)), jnp.zeros((2, 2)))
    with self.assertRaisesRegex(
        ValueError, r"inconsistent size for core dimension 'm'"):
      matmat(jnp.zeros((2, 3)), jnp.zeros((4, 5)))

  def test_wrong_output_type(self):
    f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k),()')
    with self.assertRaisesRegex(TypeError, "output must be a tuple"):
      f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))

  def test_wrong_num_outputs(self):
    f = jnp.vectorize(lambda *args: args, signature='(),()->(),(),()')
    with self.assertRaisesRegex(TypeError,
                                "wrong number of output arguments"):
      f(1, 2)

  def test_wrong_output_shape(self):
    f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n)')
    with self.assertRaisesRegex(ValueError,
                                r"output shape \(2, 2\) does not match"):
      f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))

  def test_inconsistent_output_size(self):
    f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,n)')
    with self.assertRaisesRegex(
        ValueError, r"inconsistent size for core dimension 'n'"):
      f(jnp.zeros((2, 3)), jnp.zeros((3, 4)))
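# Hedged sketch: jnp.vectorize with a gufunc-style signature, the behaviour
# the class above checks. The signature names core dimensions; leading batch
# dimensions broadcast.
import jax.numpy as jnp

matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
out = matmat(jnp.zeros((5, 2, 3)), jnp.zeros((1, 3, 4)))
print(out.shape)   # (5, 2, 4): batch dims (5,) and (1,) broadcast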
class MaskingTest(jtu.JaxTestCase): def test_sum(self): @partial(mask, in_shapes=['n'], out_shape='') def padded_sum(x): return jnp.sum(x) ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=3)) expected = 8 self.assertAllClose(ans, expected, check_dtypes=False) ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=4)) expected = 9 self.assertAllClose(ans, expected, check_dtypes=False) def test_sum_vmap(self): @partial(mask, in_shapes=['n'], out_shape='') def padded_sum(x): return jnp.sum(x) ans = vmap(padded_sum)([jnp.ones((5, 10))], dict(n=jnp.arange(5))) expected = np.array([0, 1, 2, 3, 4]) self.assertAllClose(ans, expected, check_dtypes=False) def check(self, fun, in_shapes, out_shape, logical_env, padded_in_shapes, dtypes, rng, rtol=None, atol=None): shapecheck(in_shapes, out_shape)(fun) masked_fun = mask(fun, in_shapes, out_shape) padded_args = [rng(shape, dtype) for shape, dtype in zip(padded_in_shapes, dtypes)] padded_outs, outs_tree = tree_flatten(masked_fun(padded_args, logical_env)) out_specs, _ = tree_flatten(out_shape) out_specs = map(parse_spec, out_specs) out_specs = map(finalize_spec, out_specs, map(np.shape, padded_outs)) logical_out_shapes = [eval_poly_shape(s, logical_env) for s in out_specs] logical_out_slices = [tuple(map(slice, s)) for s in logical_out_shapes] logical_outs = [o[s] for o, s in zip(padded_outs, logical_out_slices)] in_specs = map(parse_spec, in_shapes) in_specs = map(finalize_spec, in_specs, padded_in_shapes) logical_in_shapes = [eval_poly_shape(s, logical_env) for s in in_specs] logical_in_slices = [tuple(map(slice, s)) for s in logical_in_shapes] logical_args = [a[s] for a, s in zip(padded_args, logical_in_slices)] logical_outs_expected, logical_outs_tree = tree_flatten(fun(*logical_args)) assert outs_tree == logical_outs_tree self.assertAllClose(logical_outs, logical_outs_expected, check_dtypes=True, atol=atol, rtol=rtol) # Check that abstract evaluation works padded_outs_jit, _ = tree_flatten(jit(masked_fun)(padded_args, logical_env)) self.assertAllClose(padded_outs_jit, padded_outs, check_dtypes=True, atol=atol, rtol=rtol) def test_add(self): self.check(lax.add, ['n', ''], 'n', {'n': 3}, [(4,), ()], ['float_', 'float_'], jtu.rand_default(self.rng())) addvecs = mask(lax.add, in_shapes=['n', 'n'], out_shape='n') x = jnp.array([3, 1, 4, 1, 5, 9]) y = jnp.array([2, 6, 5, 3, 5, 8]) ans = addvecs([x, y], dict(n=3)) expected = np.array([5, 7, 9]) self.assertAllClose(ans[:3], expected, check_dtypes=False) thunk = lambda: addvecs([jnp.arange(5), jnp.arange(6)], dict(n=3)) self.assertRaisesRegex(ShapeError, "", thunk) def test_scan(self): @partial(mask, in_shapes=['n'], out_shape='') def cumsum(arr): out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr) return out n = np.uint8(3) # Test non-default integer type for dynamic length. 
ans = cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=n)) expected = 16 self.assertAllClose(ans, expected, check_dtypes=False) def test_scan_vmap(self): @partial(mask, in_shapes=['n'], out_shape='') def cumsum(arr): out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr) return out ans = vmap(cumsum)([jnp.arange(6).reshape(2, 3)], dict(n=jnp.array([1, 2]))) expected = np.array([0, 7]) self.assertAllClose(ans, expected, check_dtypes=False) def test_scan_jit(self): @partial(mask, in_shapes=['n'], out_shape='') def cumsum(arr): out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr) return out @jit def jit_cumsum(args, shape_env): assert python_should_be_executing return cumsum(args, shape_env) python_should_be_executing = True ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=3)) expected = 16 self.assertAllClose(ans, expected, check_dtypes=False) python_should_be_executing = False ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=4)) expected = 17 self.assertAllClose(ans, expected, check_dtypes=False) python_should_be_executing = False ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=1)) expected = 5 self.assertAllClose(ans, expected, check_dtypes=False) # TODO Shapecheck fails - shape_as_value can't deal with abstract eval yet @unittest.skip("Shapecheck fails") def test_mean(self): self.check(lambda x: jnp.sum(x) / shape_as_value(x.shape)[0], ['n'], '', {'n': 3}, [(4,)], ['float_'], jtu.rand_default(self.rng())) @unittest.skip("Failing after fixing Poly unsoundness #4878") def test_arithmetic(self): @partial(mask, in_shapes=['(n, m)', 'm'], out_shape='(n, m)') def times(x, y): return x * y # TODO(shoyer): enable this check when broadcast_in_dim supports masking with self.assertRaisesRegex( NotImplementedError, 'Masking rule for broadcast_in_dim not implemented yet.'): times([jnp.array([[1, 2], [3, 4], [5, 6]]), jnp.array([1, 2])], dict(n=4, m=5)) # expected = np.array([[1, 2, 3], [8, 10, 12]]) # self.assertAllClose(ans, expected, check_dtypes=False) def test_stack(self): @partial(mask, in_shapes=['n','n'], out_shape='(2, n)') def stack(x, y): return jnp.stack([x, y], 0) # TODO(shoyer): enable this check when broadcast_in_dim supports masking with self.assertRaisesRegex( NotImplementedError, 'Masking rule for broadcast_in_dim not implemented yet.'): stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], dict(n=10)) # expected = np.array([[1, 2, 3], [4, 5, 6]]) # self.assertAllClose(ans, expected, check_dtypes=False) def test_monomorphic(self): @partial(mask, in_shapes=['(_, n)'], out_shape='') def padded_sum(x): return jnp.sum(x) ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1)) expected = 8 self.assertAllClose(ans, expected, check_dtypes=False) def test_monomorphic2(self): @partial(mask, in_shapes=['(_, n)'], out_shape='n') def padded_sum(x): return jnp.sum(x, axis=0) ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=2)) expected = jnp.array([8, 10]) self.assertAllClose(ans, expected, check_dtypes=False) def test_monomorphic3(self): @partial(mask, in_shapes=['(_, n)'], out_shape='_') def padded_sum(x): return jnp.sum(x, axis=1) ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1)) expected = jnp.array([3, 5]) self.assertAllClose(ans, expected, check_dtypes=False) @shapecheck(['(2*n, n)'], '_, n') def identity(x): return x def test_rnn(self): n = 3 @partial(mask, in_shapes=['(_, _)', '(t, _)'], out_shape='_') def rnn(W, xs): def step(h, x): new_h = jnp.dot(W, h) + jnp.dot(W, x) return new_h, () predicted, _ = lax.scan(step, jnp.zeros(n), xs) return predicted rng = 
self.rng() W = jnp.eye(n) xs = rng.randn(10, n).astype(jnp.float_) ans = rnn([W, xs], dict(t=4)) expected = xs[:4].sum(0) self.assertAllClose(ans, expected, check_dtypes=False) def test_rnn_grad(self): n = 3 @partial(mask, in_shapes=['(_, _)', '(t, _)', '_'], out_shape='') def rnn(W, xs, target): def step(h, x): new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x)) return new_h, () predicted, _ = lax.scan(step, jnp.zeros(n), xs) return jnp.sum((predicted - target)**2) rng = self.rng() W = rng.randn(n, n).astype(jnp.float_) xs = rng.randn(10, n).astype(jnp.float_) y = rng.randn(n).astype(jnp.float_) ans = grad(lambda W: rnn([W, xs, y], dict(t=4)))(W) def rnn_reference(W, xs, target): h = jnp.zeros(n) for x in xs: h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x)) predicted = h return jnp.sum((predicted - target)**2) expected = grad(lambda W: rnn_reference(W, xs[:4], y))(W) self.assertAllClose(ans, expected, check_dtypes=False, rtol={np.float64: 1e-14, np.float32: 1e-5}) def test_ragged_batched_rnn(self): n = 3 @partial(mask, in_shapes=('(_, _)', '(t, _)', '_'), out_shape='') def rnn(W, xs, target): def step(h, x): new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x)) return new_h, () predicted, _ = lax.scan(step, jnp.zeros(n), xs) return jnp.sum((predicted - target)**2) rng = self.rng() W = rng.randn(n, n).astype(jnp.float_) seqs = rng.randn(3, 10, n).astype(jnp.float_) ts = jnp.array([2, 5, 4]) ys = rng.randn(3, n) ans = grad(lambda W: vmap(rnn, ((None, 0, 0), 0))((W, seqs, ys), dict(t=ts)).sum())(W) def rnn_reference(W, seqs, targets): total_loss = jnp.array(0.0) for xs, target in zip(seqs, targets): h = jnp.zeros(n) for x in xs: h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x)) predicted = h total_loss = total_loss + jnp.sum((predicted - target)**2) return total_loss seqs_ = [xs[:t] for xs, t in zip(seqs, ts)] expected = grad(lambda W: rnn_reference(W, seqs_, ys).sum())(W) self.assertAllClose( ans, expected, check_dtypes=False, rtol=0.1 if jtu.device_under_test() == "tpu" else 1e-5) def test_concatenate(self): self.check(lambda x, y, z: lax.concatenate([x, y, z], 0), ['n', 'm', 'n'], 'm + 2 * n', {'n': 2, 'm': 3}, [(4,), (3,), (4,)], ['float_', 'float_', 'float_'], jtu.rand_default(self.rng())) def test_dot(self): self.check(lax.dot, ['(m, k)', '(k, n)'], '(m, n)', dict(m=2, k=3, n=4), [(4, 5), (5, 7)], ['float_', 'float_'], jtu.rand_default(self.rng())) self.check(lax.dot, ['(m, n)', 'n'], 'm', dict(m=2, n=3), [(4, 5), (5,)], ['float_', 'float_'], jtu.rand_default(self.rng())) # TODO(mattjj,j-towns): fix test failure and reenable. 
@jtu.skip_on_devices("tpu") def test_jit(self): @partial(mask, in_shapes=['n'], out_shape='2*n') @jit def duplicate(x): assert python_should_be_executing return lax.concatenate([x, x], 0) python_should_be_executing = True out = duplicate([jnp.arange(3)], dict(n=2)) assert np.all(np.array([0, 1, 0, 1]) == out[:4]) python_should_be_executing = False out = duplicate([jnp.arange(3)], dict(n=2)) assert np.all(np.array([0, 1, 0, 1]) == out[:4]) @unittest.skip("broken by omnistaging") # TODO(mattjj): update def test_jit2(self): # Trigger MaskTrace.post_process_call def fun(x): @jit def concat(y): return lax.concatenate([x, y], 0) return concat(jnp.array([1., 2., 3.], dtype='float32')) self.check(fun, ['n'], '(n+3,)', {'n': 2}, [(3,)], ['float32'], jtu.rand_default(self.rng())) @parameterized.named_parameters({ 'testcase_name': "padding_config={}_shapes={}".format(padding_config, shape), 'padding_config': padding_config, 'shape': shape} for padding_config, shape in ( (((1, 2, 0),), (2,)), (((1, 2, 0), (3, 4, 0)), (1, 2)), (((0, 0, 0), (0, 0, 0)), (1, 2)), (((1, 2, 3),), (2,)), (((1, 2, 1), (3, 4, 2)), (3, 2)), (((-1, 2, 0),), (2,)), (((-1, -2, 0), (1, 2, 0)), (4, 2)), (((-1, 2, 0), (1, 2, 2)), (4, 2)), (((-1, -2, 2),), (5,)), (((-1, -2, 1), (1, 2, 2)), (4, 2)))) @unittest.skip("Failing after fixing Poly unsoundness #4878") def test_pad(self, padding_config, shape): def pad(x): return lax.pad(x, jnp.array(1., x.dtype), padding_config) if len(shape) == 1: padding_config_, = padding_config linear_coeff = padding_config_[2] + 1 const_coeff = sum(padding_config_[:2]) - padding_config_[2] out_shape = str(linear_coeff) + ' * h + ' + str(const_coeff) self.check(pad, ['h'], out_shape, dict(h=shape[0]), [tuple(np.add(shape, 1))], ['float_'], jtu.rand_default(self.rng())) # TODO(mattjj,j-towns): fix test failure and reenable. @jtu.skip_on_devices("tpu") @unittest.skip("broken by omnistaging") # TODO(mattjj): update def test_numpy_pad(self): def numpy_pad(x): return jnp.pad(x, (0, 1), constant_values=5.) 
self.check(numpy_pad, ['n'], 'n + 1', dict(n=2), [(3,)], ['float_'], jtu.rand_default(self.rng())) @parameterized.named_parameters(jtu.cases_from_list( {'testcase_name': "padding={}_lhs_dilation={}_" "dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format( padding, lhs_dilation, dimension_numbers, lhs_perm, rhs_perm, out_perm), 'padding': padding, 'lhs_dilation': lhs_dilation, 'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm, 'rhs_perm': rhs_perm, 'out_perm': out_perm} for padding in ['SAME', 'VALID', ((0, 1), (2, 0))] for lhs_dilation in (None, (1, 2)) for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in ( (("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))), (("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))), (("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1))) ) # String padding is not implemented for transposed convolution, see # conv_general_dilated implementation: if (lhs_dilation is None or not isinstance(padding, str)))) @unittest.skip("Failing after fixing Poly unsoundness #4878") def test_conv( self, padding, lhs_dilation, dimension_numbers, lhs_perm, rhs_perm, out_perm): def conv(lhs, rhs): return lax.conv_general_dilated( lhs, rhs, (1, 1), padding, lhs_dilation=lhs_dilation, dimension_numbers=dimension_numbers) template = '({}, {}, {}, {})' lhs_shape = template.format(*np.take(['n', 'c', 'h', 'w'], lhs_perm)) rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm)) if padding == 'VALID': out_shape = template.format( *np.take(['n', 'o', 'h+-1', 'w+-2'], out_perm)) elif lhs_dilation: out_shape = template.format( *np.take(['n', 'o', 'h', '2*w+-1'], out_perm)) else: out_shape = template.format( *np.take(['n', 'o', 'h', 'w'], out_perm)) logical_env = dict(n=3, c=2, h=4, w=5, o=6) self.check(conv, [lhs_shape, rhs_shape], out_shape, logical_env, [tuple(np.take([4, 3, 6, 7], lhs_perm)), tuple(np.take([7, 3, 2, 3], rhs_perm))], ['float_', 'float_'], jtu.rand_default(self.rng()), rtol=1e-4, atol=1e-4) @parameterized.named_parameters(jtu.cases_from_list( {'testcase_name': "padding={}_lhs_dilation={}_" "dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format( padding, lhs_dilation, dimension_numbers, lhs_perm, rhs_perm, out_perm), 'padding': padding, 'lhs_dilation': lhs_dilation, 'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm, 'rhs_perm': rhs_perm, 'out_perm': out_perm} for padding in ['SAME', 'VALID', ((0, 1), (2, 0))] for lhs_dilation in (None, (1, 2)) for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in ( (("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))), (("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))), (("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1))) ) # String padding is not implemented for transposed convolution, see # conv_general_dilated implementation: if (lhs_dilation is None or not isinstance(padding, str)))) @unittest.skip("Failing after fixing Poly unsoundness #4878") def test_conv_strided( self, padding, lhs_dilation, dimension_numbers, lhs_perm, rhs_perm, out_perm): def conv(lhs, rhs): return lax.conv_general_dilated( lhs, rhs, (2, 1), padding, lhs_dilation=lhs_dilation, dimension_numbers=dimension_numbers) template = '({}, {}, {}, {})' rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm)) if padding == 'VALID': lhs_shape = template.format(*np.take(['n', 'c', '2*h+1', 'w'], lhs_perm)) lhs_shape_padded = tuple(np.take([4, 3, 5, 7], lhs_perm)) out_shape = 
template.format(*np.take(['n', 'o', 'h', 'w+-2'], out_perm)) elif lhs_dilation: lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm)) lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm)) out_shape = template.format(*np.take(['n', 'o', 'h', '2*w+-1'], out_perm)) else: lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm)) lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm)) out_shape = template.format(*np.take(['n', 'o', 'h', 'w'], out_perm)) logical_env = dict(n=3, c=2, h=4, w=5, o=6) self.check(conv, [lhs_shape, rhs_shape], out_shape, logical_env, [lhs_shape_padded, tuple(np.take([7, 3, 2, 3], rhs_perm))], ['float_', 'float_'], jtu.rand_default(self.rng()), rtol=1e-4, atol=1e-4) @unittest.skip("requires gather support") def test_indexing(self): self.check(lambda x: x[0], ['n'], '', {'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng())) self.check(lambda x: x[-1], ['n'], '', {'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng())) @unittest.skip("requires gather support") def test_slicing(self): self.check(lambda x: x[1:], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_']) self.check(lambda x: x[:-1], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_']) self.check(lambda x: x[..., -1], ['(n,3)'], 'n', {'n': 2}, [(3, 4)], ['float_']) def test_rev(self): @shapecheck(['n'], 'n') def rev1(x): return lax.rev(x, (0,)) @shapecheck(['(m, n)'], '(m, n)') def rev2(x): return lax.rev(x, (1,)) @unittest.skip("TODO") def test_rev_by_indexing(self): @shapecheck(['n'], 'n+-1') def rev1(x): return x[:0:-1] @shapecheck(['n'], 'n+-1') def rev2(x): return x[-2::-1] # TODO implement masking for rev_p: # self.check(lambda x: x[:0:-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1') # self.check(lambda x: x[-2::-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1') @unittest.skip("Failing after fixing Poly unsoundness #4878") def test_lax_slice(self): self.check(lambda x: lax.slice(x, (1,), (x.shape[0],)), ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng())) # TODO self.check(lambda x: lax.slice(x, (x.shape[0] // 2,), (x.shape[0],)), # ['2*n'], 'n', {'n': 2}, [(6,)], ['float_'], jtu.rand_default(self.rng())) self.check(lambda x: lax.slice(x, (0,), (x.shape[0],), (x.shape[0],)), ['n'], '1', {'n': 2}, [(5,)], ['float_'], jtu.rand_default(self.rng())) @unittest.skip("Failing after fixing Poly unsoundness #4878") def test_reshape(self): self.check(lambda x: jnp.reshape(x, (x.shape[1], 2, 4, 1)), ['1, n, 4, 2'], 'n, 2, 4, 1', dict(n=2), [(1, 3, 4, 2)], ['float_'], jtu.rand_default(self.rng())) self.check(lambda x: jnp.reshape(x, (x.shape[0] * 2,)), ['n, 2'], '2 * n', dict(n=2), [(3, 2)], ['float_'], jtu.rand_default(self.rng())) self.check(lambda x: jnp.reshape(x, (x.shape[0] // 2, 2)), ['2 * n'], 'n, 2', dict(n=2), [(6,)], ['float_'], jtu.rand_default(self.rng())) self.check(lambda x: jnp.reshape(x, (x.shape[0] * 4, 2)), ['n, 2, 4'], '4 * n, 2', dict(n=2), [(3, 2, 4)], ['float_'], jtu.rand_default(self.rng())) self.check(lambda x: jnp.reshape(x, ((x.shape[0] - 1) // 4 + 1, 2, 4)), ['4 * n + 4, 2'], 'n + 1, 2, 4', dict(n=2), [(12, 2)], ['float_'], jtu.rand_default(self.rng())) msg = "Reshape on padded dimensions causing fragmentation is not supported." 
with self.assertRaisesRegex(NotImplementedError, msg): self.check(lambda x: jnp.reshape(x, np.prod(x.shape)), ['a, b'], 'a*b', dict(a=2, b=3), [(3, 4)], ['float_'], jtu.rand_default(self.rng())) with self.assertRaisesRegex(NotImplementedError, msg): self.check(lambda x: jnp.reshape(x, (x.shape[1], x.shape[0])), ['a, b'], 'b, a', dict(a=2, b=3), [(3, 4)], ['float_'], jtu.rand_default(self.rng())) with self.assertRaisesRegex(NotImplementedError, msg): self.check(lambda x: jnp.reshape(x, (x.shape[1] * 2,)), ['2, n'], '2 * n', dict(n=2), [(2, 3)], ['float_'], jtu.rand_default(self.rng())) self.check(lambda x: jnp.reshape(x, (x.shape[0], -1)), ['n, 3, 1, 2'], 'n, 6', dict(n=1), [(2, 3, 1, 2)], ['float_'], jtu.rand_default(self.rng())) def test_transpose(self): self.check(lambda x: lax.transpose(x, (1, 0, 2)), ['(a, b, c)'], 'b, a, c', dict(a=2, b=3, c=4), [(3, 4, 5)], ['float_'], jtu.rand_default(self.rng())) def test_sum_2d(self): self.check(jnp.sum, ['(m, n)'], '', dict(m=2, n=3), [(3, 4)], ['float_'], jtu.rand_default(self.rng())) @unittest.skip("custom_jvp doesn't work with masking yet") def test_expit(self): self.check(expit, ['n'], 'n', dict(n=3), [(4,)], ['float_'], jtu.rand_default(self.rng())) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name} for dtype in [np.float32, np.float64])) @unittest.skip("not yet implemented") def test_uniform(self, dtype): # TODO needs fix for https://github.com/google/jax/issues/2155 pass @unittest.skip("not yet implemented") def test_broadcast_in_dim(self): pass def test_destructure(self): def d(key): key1, key2 = key return key1 self.check(d, ['2'], '', {}, [(2,)], ['int_'], jtu.rand_int(self.rng(), 0, 10)) # TODO(mattjj,j-towns): fix test failure and reenable. @jtu.skip_on_devices("tpu") def test_where(self): self.check(lambda x: jnp.where(x < 0, x, 0. * x), ['n'], 'n', {'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng())) @unittest.skip("Failing after fixing Poly unsoundness #4878") def test_split(self): self.check(lambda x: jnp.split(x, 2), ['2*n'], ['n', 'n'], dict(n=4), [(8,)], ['float_'], jtu.rand_default(self.rng())) self.check(lambda x: jnp.split(x, [10]), ['n'], ['10', 'n+-10'], dict(n=12), [(12,)], ['float_'], jtu.rand_default(self.rng())) @parameterized.named_parameters(jtu.cases_from_list([{ 'testcase_name': "operator={}".format(operator.__name__), 'operator': operator} for operator in [jnp.sum, jnp.prod, jnp.max, jnp.min]])) def test_reduce(self, operator): self.check(operator, ['(m+1, n+1)'], '', {'m': 3, 'n': 4}, [(4, 5)], ['float_'], jtu.rand_default(self.rng())) def test_output_shape_error(self): def thunk(): shapecheck(['n'], 'n+-1')(lambda x: x) message = "Output shapes should be (n + -1,) but are (n,)." self.assertRaisesWithLiteralMatch(ShapeError, message, thunk) def thunk(): shapecheck(['n'], ['7*n', 'n'])(lambda x: (x, x)) message = "Output shapes should be [(7 n,), (n,)] but are ((n,), (n,))." self.assertRaisesWithLiteralMatch(ShapeError, message, thunk) def test_output_tree_error(self): def thunk(): shapecheck(['n'], ('n', 'n'))(lambda x: [x, x]) message = "Output shapes should be ((n,), (n,)) but are [(n,), (n,)]." self.assertRaisesWithLiteralMatch(ShapeError, message, thunk) def test_unsupported_op(self): p = core.Primitive('unsupported_op') p.def_abstract_eval(lambda x: x) p.def_impl(lambda x: x) def thunk(): mask(p.bind, ['n'], 'n')([np.arange(3)], {'n': 2}) message = "Masking rule for unsupported_op not implemented yet." 
self.assertRaisesWithLiteralMatch(NotImplementedError, message, thunk) @unittest.skip("not yet implemented") def test_nesting(self): @partial(mask, in_shapes=['n'], out_shape='') def padded_sum(x): return jnp.sum(x) batched_sum = vmap(padded_sum) @partial(mask, in_shapes=['(m, _)', 'm'], out_shape='') def fun(x, ns): return batched_sum([x], dict(n=ns)).sum() x = jnp.array([[3, 1, 4, 1], [5, 9, 2, 6], [5, 3, 5, 8]]) ns = jnp.array([2, 3, 2]) ans = fun([x, ns], dict(m=2)) expected = 3+1 + 5+9+2 self.assertAllClose(ans, expected, check_dtypes=False) def test_slice_oob_indexing(self): # https://github.com/google/jax/issues/2245 self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10]) self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:]) def test_jaxpr_doesnt_include_trivial_operations(self): @partial(mask, in_shapes=['n'], out_shape='') def foo(x): return np.sum(x) padded_x = np.array([0, 1, 2, 3, 999, 999]) jaxpr = make_jaxpr(foo)([padded_x], dict(n=3)) self.assertNotIn('mul', str(jaxpr)) self.assertNotIn('add', str(jaxpr)) def test_return_shape_to_user(self): @partial(mask, in_shapes=['n']) def foo(x): return [x, np.sum(x)] out, out_shape = foo([np.arange(5)], dict(n=2)) self.assertIsInstance(out_shape, list) self.assertLen(out_shape, 2) a, b = out_shape self.assertEqual(a.shape, (2,)) self.assertEqual(b.shape, ())
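# Illustrative sketch (not a test case): the basic call pattern for the
# experimental masking API exercised by the tests above. It assumes `mask`,
# `jnp`, and `partial` are imported at the top of this file, as the tests in
# this class use them; the masking API is experimental and not part of current
# JAX releases.
def _mask_usage_sketch():
  @partial(mask, in_shapes=['n'], out_shape='')
  def padded_sum(x):
    return jnp.sum(x)

  # Physical buffer of size 5; only the first n=3 entries are logically valid,
  # so the masked sum ignores the padding values.
  padded = jnp.array([1., 2., 3., 999., 999.])
  return padded_sum([padded], dict(n=3))  # expected to equal 6.0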
class TestBFGS(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": f"_func={func_and_init[0].__name__}_maxiter={maxiter}", "maxiter": maxiter, "func_and_init": func_and_init} for maxiter in [None] for func_and_init in [(rosenbrock, np.zeros(2)), (himmelblau, np.zeros(2)), (matyas, np.ones(2) * 6.), (eggholder, np.ones(2) * 100.)])) def test_minimize(self, maxiter, func_and_init): # Note, cannot compare step for step with scipy BFGS because our line search is _slightly_ different. func, x0 = func_and_init @jit def min_op(x0): result = jax.scipy.optimize.minimize( func(jnp), x0, method='BFGS', options=dict(maxiter=maxiter, gtol=1e-6), ) return result.x jax_res = min_op(x0) scipy_res = scipy.optimize.minimize(func(np), x0, method='BFGS').x self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False) def test_fixes4594(self): n = 2 A = jnp.eye(n) * 1e4 def f(x): return jnp.mean((A @ x) ** 2) results = jax.scipy.optimize.minimize(f, jnp.ones(n), method='BFGS') self.assertAllClose(results.x, jnp.zeros(n), atol=1e-6, rtol=1e-6) @jtu.skip_on_flag('jax_enable_x64', False) def test_zakharov(self): def zakharov_fn(x): ii = jnp.arange(1, len(x) + 1, step=1) answer = zakharovFromIndices(x=x, ii=ii) return answer x0 = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4]) eval_func = jax.jit(zakharov_fn) jax_res = jax.scipy.optimize.minimize(fun=eval_func, x0=x0, method='BFGS') self.assertLess(jax_res.fun, 1e-6) def test_minimize_bad_initial_values(self): # This test runs deliberately "bad" initial values to test that handling # of failed line search, etc. is the same across implementations initial_value = jnp.array([92, 0.001]) opt_fn = himmelblau(jnp) jax_res = jax.scipy.optimize.minimize( fun=opt_fn, x0=initial_value, method='BFGS', ).x scipy_res = scipy.optimize.minimize( fun=opt_fn, jac=jax.grad(opt_fn), method='BFGS', x0=initial_value ).x self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False) def test_args_must_be_tuple(self): A = jnp.eye(2) * 1e4 def f(x): return jnp.mean((A @ x) ** 2) with self.assertRaisesRegex(TypeError, "args .* must be a tuple"): jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS')
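# Illustrative sketch (not a test case): the jax.scipy.optimize.minimize BFGS
# call pattern exercised by TestBFGS above, applied to a simple quadratic with
# a known minimum. Assumes `jax` and `jnp` are imported as elsewhere in this
# file; the objective and its minimizer are made up for illustration.
def _bfgs_usage_sketch():
  target = jnp.array([1., -2.])

  def quadratic(x):
    # Convex objective with its minimum at `target`.
    return jnp.sum((x - target) ** 2)

  result = jax.scipy.optimize.minimize(
      quadratic, jnp.zeros(2), method='BFGS',
      options=dict(maxiter=100, gtol=1e-6))
  # result.x holds the minimizer, result.fun the objective value there, and
  # result.success whether the optimizer reports convergence.
  return result.x, result.fun, result.success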
class SvdTest(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list( { # pylint:disable=g-complex-comprehension 'testcase_name': '_m={}_by_n={}_log_cond={}_full_matrices={}'.format( m, n, log_cond, full_matrices), 'm': m, 'n': n, 'log_cond': log_cond, 'full_matrices': full_matrices } for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18]) for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4) for full_matrices in [True, False])) def testSvdWithRectangularInput(self, m, n, log_cond, full_matrices): """Tests SVD with rectangular input.""" with jax.default_matmul_precision('float32'): a = np.random.uniform(low=0.3, high=0.9, size=(m, n)).astype(_SVD_TEST_DTYPE) u, s, v = osp_linalg.svd(a, full_matrices=False) cond = 10**log_cond s = jnp.linspace(cond, 1, min(m, n)) a = (u * s) @ v a = a.astype(complex) * (1 + 1j) osp_linalg_fn = functools.partial(osp_linalg.svd, full_matrices=full_matrices) actual_u, actual_s, actual_v = svd.svd(a, full_matrices=full_matrices) k = min(m, n) if m > n: unitary_u = jnp.real(actual_u.T.conj() @ actual_u) unitary_v = jnp.real(actual_v.T.conj() @ actual_v) unitary_u_size = m if full_matrices else k unitary_v_size = k else: unitary_u = jnp.real(actual_u @ actual_u.T.conj()) unitary_v = jnp.real(actual_v @ actual_v.T.conj()) unitary_u_size = k unitary_v_size = n if full_matrices else k _, expected_s, _ = osp_linalg_fn(a) svd_fn = lambda a: svd.svd(a, full_matrices=full_matrices) args_maker = lambda: [a] with self.subTest('Test JIT compatibility'): self._CompileAndCheck(svd_fn, args_maker) with self.subTest('Test unitary u.'): self.assertAllClose(np.eye(unitary_u_size), unitary_u, rtol=_SVD_RTOL, atol=2E-3) with self.subTest('Test unitary v.'): self.assertAllClose(np.eye(unitary_v_size), unitary_v, rtol=_SVD_RTOL, atol=2E-3) with self.subTest('Test s.'): self.assertAllClose(expected_s, jnp.real(actual_s), rtol=_SVD_RTOL, atol=1E-6) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': f'_m={m}_by_n={n}', 'm': m, 'n': n } for m, n in zip([50, 6], [3, 60]))) def testSvdWithSkinnyTallInput(self, m, n): """Tests SVD with skinny and tall input.""" # Generates a skinny and tall input with jax.default_matmul_precision('float32'): np.random.seed(1235) a = np.random.randn(m, n).astype(_SVD_TEST_DTYPE) u, s, v = svd.svd(a, full_matrices=False, hermitian=False) relative_diff = np.linalg.norm(a - (u * s) @ v) / np.linalg.norm(a) np.testing.assert_almost_equal(relative_diff, 1E-6, decimal=6) @parameterized.named_parameters( jtu.cases_from_list({ # pylint:disable=g-complex-comprehension 'testcase_name': f'_m={m}_r={r}_log_cond={log_cond}', 'm': m, 'r': r, 'log_cond': log_cond } for m, r in zip([8, 8, 8, 10], [3, 5, 7, 9]) for log_cond in np.linspace(1, 3, 3))) def testSvdWithOnRankDeficientInput(self, m, r, log_cond): """Tests SVD with rank-deficient input.""" with jax.default_matmul_precision('float32'): a = jnp.triu(jnp.ones((m, m))).astype(_SVD_TEST_DTYPE) # Generates a rank-deficient input. 
u, s, v = jnp.linalg.svd(a, full_matrices=False) cond = 10**log_cond s = jnp.linspace(cond, 1, m) s = s.at[r:m].set(jnp.zeros((m - r, ))) a = (u * s) @ v with jax.default_matmul_precision('float32'): u, s, v = svd.svd(a, full_matrices=False, hermitian=False) diff = np.linalg.norm(a - (u * s) @ v) np.testing.assert_almost_equal(diff, 1E-4, decimal=2) @parameterized.named_parameters( jtu.cases_from_list( { # pylint:disable=g-complex-comprehension 'testcase_name': '_m={}_by_n={}_log_cond={}_full_matrices={}'.format( m, n, log_cond, full_matrices), 'm': m, 'n': n, 'log_cond': log_cond, 'full_matrices': full_matrices } for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18]) for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4) for full_matrices in [True, False])) def testSingularValues(self, m, n, log_cond, full_matrices): """Tests singular values.""" with jax.default_matmul_precision('float32'): a = np.random.uniform(low=0.3, high=0.9, size=(m, n)).astype(_SVD_TEST_DTYPE) u, s, v = osp_linalg.svd(a, full_matrices=False) cond = 10**log_cond s = np.linspace(cond, 1, min(m, n)) a = (u * s) @ v a = a + 1j * a # Only computes singular values. compute_uv = False osp_linalg_fn = functools.partial(osp_linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv) actual_s = svd.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) expected_s = osp_linalg_fn(a) svd_fn = lambda a: svd.svd(a, full_matrices=full_matrices) args_maker = lambda: [a] with self.subTest('Test JIT compatibility'): self._CompileAndCheck(svd_fn, args_maker) with self.subTest('Test s.'): self.assertAllClose(expected_s, actual_s, rtol=_SVD_RTOL, atol=1E-6) with self.subTest('Test non-increasing order.'): # Computes `actual_diff[i] = s[i+1] - s[i]`. actual_diff = jnp.diff(actual_s, append=0) np.testing.assert_array_less(actual_diff, np.zeros_like(actual_diff)) @parameterized.named_parameters([ { 'testcase_name': f'_m={m}_by_n={n}_full_matrices={full_matrices}_' # pylint:disable=g-complex-comprehension f'compute_uv={compute_uv}_dtype={dtype}', 'm': m, 'n': n, 'full_matrices': full_matrices, # pylint:disable=undefined-variable 'compute_uv': compute_uv, 'dtype': dtype } # pylint:disable=undefined-variable for m, n in zip([2, 4, 8], [4, 4, 6]) for full_matrices in [True, False] for compute_uv in [True, False] for dtype in jtu.dtypes.floating + jtu.dtypes.complex ]) def testSvdOnZero(self, m, n, full_matrices, compute_uv, dtype): """Tests SVD on matrix of all zeros.""" osp_fun = functools.partial(osp_linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv) lax_fun = functools.partial(svd.svd, full_matrices=full_matrices, compute_uv=compute_uv) args_maker_svd = lambda: [jnp.zeros((m, n), dtype=dtype)] self._CheckAgainstNumpy(osp_fun, lax_fun, args_maker_svd) self._CompileAndCheck(lax_fun, args_maker_svd) @parameterized.named_parameters([{ 'testcase_name': f'_m={m}_by_n={n}_r={r}_c={c}_dtype={dtype}', 'm': m, 'n': n, 'r': r, 'c': c, 'dtype': dtype } for m, n, r, c in zip([2, 4, 8], [4, 4, 6], [1, 0, 1], [1, 0, 1]) for dtype in jtu.dtypes.floating]) def testSvdOnTinyElement(self, m, n, r, c, dtype): """Tests SVD on matrix of zeros and close-to-zero entries.""" a = jnp.zeros((m, n), dtype=dtype) tiny_element = jnp.finfo(a).tiny a = a.at[r, c].set(tiny_element) @jax.jit def lax_fun(a): return svd.svd(a, full_matrices=False, compute_uv=False, hermitian=False) actual_s = lax_fun(a) k = min(m, n) expected_s = np.zeros((k, ), dtype=dtype) expected_s[0] = tiny_element self.assertAllClose(expected_s, jnp.real(actual_s), rtol=_SVD_RTOL, 
atol=1E-6)
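# Illustrative sketch (not a test case): the thin-SVD reconstruction identity
# that the tests above check, written against the public jnp.linalg.svd API
# rather than the internal `svd` module used in this file. Assumes `jnp` and
# `np` are imported as elsewhere in this file.
def _svd_reconstruction_sketch():
  a = jnp.asarray(np.random.RandomState(0).randn(5, 3).astype(np.float32))
  u, s, vh = jnp.linalg.svd(a, full_matrices=False)  # thin SVD
  reconstructed = (u * s) @ vh                       # same form as the tests
  relative_err = jnp.linalg.norm(a - reconstructed) / jnp.linalg.norm(a)
  return relative_err  # expected to be small, up to float32 round-off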
class LaxBackedScipyStatsTests(jtu.JaxTestCase): """Tests for LAX-backed scipy.stats implementations""" @genNamedParametersNArgs(3) def testPoissonLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.poisson.logpmf lax_fun = lsp_stats.poisson.logpmf def args_maker(): k, mu, loc = map(rng, shapes, dtypes) k = np.floor(k) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None) loc = np.floor(loc) return [k, mu, loc] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}) @genNamedParametersNArgs(3) def testPoissonPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.poisson.pmf lax_fun = lsp_stats.poisson.pmf def args_maker(): k, mu, loc = map(rng, shapes, dtypes) k = np.floor(k) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None) loc = np.floor(loc) return [k, mu, loc] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testPoissonCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.poisson.cdf lax_fun = lsp_stats.poisson.cdf def args_maker(): k, mu, loc = map(rng, shapes, dtypes) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None) return [k, mu, loc] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testBernoulliLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.bernoulli.logpmf lax_fun = lsp_stats.bernoulli.logpmf def args_maker(): x, logit, loc = map(rng, shapes, dtypes) x = np.floor(x) p = expit(logit) loc = np.floor(loc) return [x, p, loc] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testGeomLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.geom.logpmf lax_fun = lsp_stats.geom.logpmf def args_maker(): x, logit, loc = map(rng, shapes, dtypes) x = np.floor(x) p = expit(logit) loc = np.floor(loc) return [x, p, loc] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(5) def testBetaLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.beta.logpdf lax_fun = lsp_stats.beta.logpdf def args_maker(): x, a, b, loc, scale = map(rng, shapes, dtypes) return [x, a, b, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, rtol={ np.float32: 2e-3, np.float64: 1e-4 }) def testBetaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7645 a = b = 1. 
x = np.array([0., 1.]) self.assertAllClose(osp_stats.beta.pdf(x, a, b), lsp_stats.beta.pdf(x, a, b), atol=1E-6) @genNamedParametersNArgs(3) def testCauchyLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.cauchy.logpdf lax_fun = lsp_stats.cauchy.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": jtu.format_test_name_suffix("", [x_shape, alpha_shape], dtypes), "shapes": [x_shape, alpha_shape], "dtypes": dtypes } for x_shape in one_and_two_dim_shapes for alpha_shape in [( x_shape[0], ), ( x_shape[0] + 1, )] for dtypes in itertools.combinations_with_replacement( jtu.dtypes.floating, 2))) def testDirichletLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) def _normalize(x, alpha): x_norm = x.sum(0) + (0.0 if x.shape[0] == alpha.shape[0] else 0.1) return (x / x_norm).astype(x.dtype), alpha def lax_fun(x, alpha): return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha)) def scipy_fun(x, alpha): # scipy validates the x normalization using float64 arithmetic, so we must # cast x to float64 before normalization to ensure this passes. x, alpha = _normalize(x.astype('float64'), alpha) result = osp_stats.dirichlet.logpdf(x, alpha) # if x.shape is (N, 1), scipy flattens the output, while JAX returns arrays # of a consistent rank. This check ensures the results have the same shape. return result if x.ndim == 1 else np.atleast_1d(result) def args_maker(): # Don't normalize here, because we want normalization to happen at 64-bit # precision in the scipy version. 
x, alpha = map(rng, shapes, dtypes) return x, alpha tol = {np.float32: 1E-3, np.float64: 1e-5} self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol) @genNamedParametersNArgs(3) def testExponLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.expon.logpdf lax_fun = lsp_stats.expon.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testGammaLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.gamma.logpdf lax_fun = lsp_stats.gamma.logpdf def args_maker(): x, a, loc, scale = map(rng, shapes, dtypes) return [x, a, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) def testGammaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7256 self.assertAllClose(osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) @genNamedParametersNArgs(4) def testNBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.nbinom.logpmf lax_fun = lsp_stats.nbinom.logpmf def args_maker(): k, n, logit, loc = map(rng, shapes, dtypes) k = np.floor(np.abs(k)) n = np.ceil(np.abs(n)) p = expit(logit) loc = np.floor(loc) return [k, n, p, loc] tol = {np.float32: 1e-6, np.float64: 1e-8} self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) @genNamedParametersNArgs(3) def testLaplaceLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.laplace.logpdf lax_fun = lsp_stats.laplace.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(scale, a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testLaplaceCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.laplace.cdf lax_fun = lsp_stats.laplace.cdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # ensure that scale is not too low scale = np.clip(scale, a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol={ np.float32: 1e-5, np.float64: 1e-6 }) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.cdf lax_fun = lsp_stats.logistic.cdf def args_maker(): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticLogpdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.logpdf lax_fun = lsp_stats.logistic.logpdf def args_maker(): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticPpf(self, shapes, dtypes): rng = 
jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.ppf lax_fun = lsp_stats.logistic.ppf def args_maker(): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticSf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.sf lax_fun = lsp_stats.logistic.sf def args_maker(): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.logpdf lax_fun = lsp_stats.norm.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormLogCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.logcdf lax_fun = lsp_stats.norm.logcdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.cdf lax_fun = lsp_stats.norm.cdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.ppf lax_fun = lsp_stats.norm.ppf def args_maker(): q, loc, scale = map(rng, shapes, dtypes) # ensure probability is between 0 and 1: q = np.clip(np.abs(q / 3), a_min=None, a_max=1) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) return [q, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) @genNamedParametersNArgs(4) def testParetoLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.pareto.logpdf lax_fun = lsp_stats.pareto.logpdf def args_maker(): x, b, loc, scale = map(rng, shapes, dtypes) return [x, b, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testTLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.t.logpdf lax_fun = lsp_stats.t.logpdf def args_maker(): x, df, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) return [x, df, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}, atol={np.float64: 
1e-14}) @genNamedParametersNArgs(3) def testUniformLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.uniform.logpdf lax_fun = lsp_stats.uniform.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) return [x, loc, np.abs(scale)] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testChi2LogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.chi2.logpdf lax_fun = lsp_stats.chi2.logpdf def args_maker(): x, df, loc, scale = map(rng, shapes, dtypes) return [x, df, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(5) def testBetaBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) lax_fun = lsp_stats.betabinom.logpmf def args_maker(): k, n, a, b, loc = map(rng, shapes, dtypes) k = np.floor(k) n = np.ceil(n) a = np.clip(a, a_min=0.1, a_max=None) b = np.clip(b, a_min=0.1, a_max=None) loc = np.floor(loc) return [k, n, a, b, loc] if scipy_version >= (1, 4): scipy_fun = osp_stats.betabinom.logpmf self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5) def testIssue972(self): self.assertAllClose(np.ones((4, ), np.float32), lsp_stats.norm.cdf( np.full((4, ), np.inf, np.float32)), check_dtypes=False) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_x={}_mean={}_cov={}".format( jtu.format_shape_dtype_string(x_shape, x_dtype), jtu.format_shape_dtype_string(mean_shape, mean_dtype) if mean_shape is not None else None, jtu.format_shape_dtype_string(cov_shape, cov_dtype) if cov_shape is not None else None), "x_shape": x_shape, "x_dtype": x_dtype, "mean_shape": mean_shape, "mean_dtype": mean_dtype, "cov_shape": cov_shape, "cov_dtype": cov_dtype } for x_shape, mean_shape, cov_shape in [ # # These test cases cover default values for mean/cov, but we don't # # support those yet (and they seem not very valuable). 
# [(), None, None], # [(), (), None], # [(2,), None, None], # [(2,), (), None], # [(2,), (2,), None], # [(3, 2), (3, 2,), None], # [(5, 3, 2), (5, 3, 2,), None], [(), (), ()], [(3, ), (), ()], [(3, ), (3, ), ()], [(3, ), (3, ), (3, 3)], [(3, 4), (4, ), (4, 4)], [(2, 3, 4), (4, ), (4, 4)], ] for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement( jtu.dtypes.floating, 3) if (mean_shape is not None or mean_dtype == np.float32) and (cov_shape is not None or cov_dtype == np.float32)) ) def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape, mean_dtype, cov_shape, cov_dtype): rng = jtu.rand_default(self.rng()) def args_maker(): args = [rng(x_shape, x_dtype)] if mean_shape is not None: args.append(5 * rng(mean_shape, mean_dtype)) if cov_shape is not None: if cov_shape == (): args.append(0.1 + rng(cov_shape, cov_dtype)**2) else: factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1]) factor = rng(factor_shape, cov_dtype) args.append(np.matmul(factor, np.swapaxes(factor, -1, -2))) return args self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf, lsp_stats.multivariate_normal.logpdf, args_maker, tol=1e-3) self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker, rtol=1e-4, atol=1e-4) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_x={}_mean={}_cov={}".format( jtu.format_shape_dtype_string(x_shape, x_dtype), jtu.format_shape_dtype_string(mean_shape, mean_dtype) if mean_shape is not None else None, jtu.format_shape_dtype_string(cov_shape, cov_dtype) if cov_shape is not None else None), "x_shape": x_shape, "x_dtype": x_dtype, "mean_shape": mean_shape, "mean_dtype": mean_dtype, "cov_shape": cov_shape, "cov_dtype": cov_dtype } for x_shape, mean_shape, cov_shape in [ # These test cases are where scipy flattens things, which has # different batch semantics than some might expect, so we manually # vectorize scipy's outputs for the sake of testing. 
[(5, 3, 2), (5, 3, 2), (5, 3, 2, 2)], [(2, ), (5, 3, 2), (5, 3, 2, 2)], [(5, 3, 2), (2, ), (5, 3, 2, 2)], [(5, 3, 2), ( 5, 3, 2, ), (2, 2)], [(1, 3, 2), ( 3, 2, ), (5, 1, 2, 2)], [(5, 3, 2), ( 1, 2, ), (2, 2)], ] for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement( jtu.dtypes.floating, 3) if (mean_shape is not None or mean_dtype == np.float32) and (cov_shape is not None or cov_dtype == np.float32)) ) def testMultivariateNormalLogpdfBroadcasted(self, x_shape, x_dtype, mean_shape, mean_dtype, cov_shape, cov_dtype): rng = jtu.rand_default(self.rng()) def args_maker(): args = [rng(x_shape, x_dtype)] if mean_shape is not None: args.append(5 * rng(mean_shape, mean_dtype)) if cov_shape is not None: if cov_shape == (): args.append(0.1 + rng(cov_shape, cov_dtype)**2) else: factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1]) factor = rng(factor_shape, cov_dtype) args.append(np.matmul(factor, np.swapaxes(factor, -1, -2))) return args osp_fun = np.vectorize(osp_stats.multivariate_normal.logpdf, signature="(n),(n),(n,n)->()") self._CheckAgainstNumpy(osp_fun, lsp_stats.multivariate_normal.logpdf, args_maker, tol=1e-3) self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker, rtol=1e-4, atol=1e-4) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_ndim={}_nbatch={}_dtype={}".format(ndim, nbatch, dtype.__name__), "ndim": ndim, "nbatch": nbatch, "dtype": dtype } for ndim in [2, 3] for nbatch in [1, 3, 5] for dtype in jtu.dtypes.floating)) def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype): # Regression test for #5570 rng = jtu.rand_default(self.rng()) x = rng((nbatch, ndim), dtype) mean = 5 * rng((nbatch, ndim), dtype) factor = rng((nbatch, ndim, 2 * ndim), dtype) cov = factor @ factor.transpose(0, 2, 1) result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov) result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov) self.assertArraysEqual(result1, result2)
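# Illustrative sketch (not a test case): the comparison pattern used throughout
# LaxBackedScipyStatsTests, evaluating the same distribution function with
# scipy.stats (osp_stats) and jax.scipy.stats (lsp_stats) and checking that the
# results agree. Assumes `np`, `osp_stats`, and `lsp_stats` are imported as
# elsewhere in this file.
def _scipy_stats_comparison_sketch():
  x = np.linspace(-3., 3., 7)
  expected = osp_stats.norm.logpdf(x, loc=0.5, scale=2.0)  # SciPy reference
  actual = lsp_stats.norm.logpdf(x, loc=0.5, scale=2.0)    # JAX implementation
  return np.allclose(expected, np.asarray(actual), atol=1e-5)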
class CheckifyTransformTests(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_jit={jit}", "jit": jit } for jit in [False, True])) @jtu.skip_on_devices("tpu") def test_jit_nan(self, jit): def f(x1, x2): y1 = jnp.sin(x1) y2 = jnp.sin(x2) return y1 + y2 f = jax.jit(f) if jit else f checked_f = checkify.checkify(f, errors=checkify.float_checks) err, _ = checked_f(3., 4.) self.assertIs(err.get(), None) err, _ = checked_f(3., jnp.inf) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive sin") @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_jit={jit}", "jit": jit } for jit in [False, True])) def test_jit_oob(self, jit): def f(x, i): y = jnp.sin(x) z = y[i] w = jnp.cos(z) return w f = jax.jit(f) if jit else f checked_f = checkify.checkify(f, errors=checkify.index_checks) err, _ = checked_f(jnp.arange(3), 2) self.assertIs(err.get(), None) err, _ = checked_f(jnp.arange(3), 5) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "out-of-bounds indexing") @parameterized.named_parameters( { "testcase_name": f"_updatefn={update_fn}", "update_fn": update_fn } for update_fn in ["set", "add", "multiply", "divide", "power", "min", "max", "get"]) def test_jit_oob_update(self, update_fn): def f(x, i): return getattr(x.at[i], update_fn)(1) f = jax.jit(f) checked_f = checkify.checkify(f, errors=checkify.index_checks) err, _ = checked_f(jnp.arange(3), 2) self.assertIs(err.get(), None) err, _ = checked_f(jnp.arange(3), 3) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "out-of-bounds indexing") @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_jit={jit}", "jit": jit } for jit in [False, True])) def test_jit_div_errors(self, jit): def f(x, y): return x / y f = jax.jit(f) if jit else f checked_f = checkify.checkify(f, errors=checkify.float_checks) err, _ = checked_f(jnp.ones((3, )), jnp.ones((3, ))) self.assertIs(err.get(), None) err, _ = checked_f(jnp.ones((3, )), jnp.array([1., 0., 1.])) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") err, _ = checked_f(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1])) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive div") @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_jit={jit}", "jit": jit } for jit in [False, True])) @jtu.skip_on_devices("tpu") def test_jit_multi(self, jit): def f(x, i): y = x[i] z = jnp.cos(y) return z f = jax.jit(f) if jit else f checked_f = checkify.checkify(f, errors=checkify.automatic_checks) # no error err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2) self.assertIs(err.get(), None) # oob error err, _ = checked_f(jnp.array([0., 1., 2.]), 5) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "out-of-bounds indexing") # nan error err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 2) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive cos") def test_numpy_indexing_oobs(self): def raises_oob(fn, idx, *expected_strs): err, _ = checkify.checkify(fn, errors=checkify.index_checks)(x, idx) error_txt = err.get() self.assertIsNotNone(error_txt) self.assertStartsWith(error_txt, "out-of-bounds indexing") for s in expected_strs: self.assertIn(s, error_txt) x = jnp.ones((2, 3, 7)) axis0_msg = "axis 0 with size 2" axis1_msg = "axis 1 with size 3" axis2_msg = "axis 2 with size 7" single_idx = lambda x, i: x[i] raises_oob(single_idx, 5, "index 
5", axis0_msg) raises_oob(single_idx, -5, "index -3", axis0_msg) raises_oob(single_idx, (0, 100), "index 100", axis1_msg) raises_oob(single_idx, (0, 5, 100), "index 5", axis1_msg) raises_oob(single_idx, (0, 0, 100), "index 100", axis2_msg) raises_oob(single_idx, ((1, 20), (1, 4)), "index 20", axis0_msg) raises_oob(single_idx, ((1, 20), (3, 4)), "index 3", axis1_msg) raises_oob(single_idx, (((1, 1), (1, 20)), 3), "index 3", axis1_msg) raises_oob(single_idx, (((1, 1), (1, 20)), 0), "index 20", axis0_msg) multi_idx = lambda x, i: x[i[0], :, i[1]] raises_oob(multi_idx, (0, 9), "index 9", axis2_msg) # TODO(lenamartens): numpy reports index -5 here, need to normalize? raises_oob(multi_idx, (-5, 9), "index -3", axis0_msg) raises_oob(multi_idx, (5, -9), "index 5", axis0_msg) raises_oob(multi_idx, ((0, 9), 0), "index 9", axis0_msg) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_jit={jit}", "jit": jit } for jit in [False, True])) def test_jit_ordering(self, jit): def f(x, i): y = x[i] z = jnp.sin(x) return y * z f = jax.jit(f) if jit else f checked_f = checkify.checkify(f, errors=checkify.automatic_checks) # both oob and nan error, but oob happens first err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 5) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "out-of-bounds indexing") @jtu.skip_on_devices("tpu") def test_pmap_basic(self): if len(jax.devices()) < 2: raise unittest.SkipTest("requires at least 2 devices") @jax.pmap def f(x1, x2): y1 = jnp.sin(x1) y2 = jnp.sin(x2) return y1 + y2 checked_f = checkify.checkify(f, errors=checkify.float_checks) xs = jnp.array([0., 2.]) err, _ = checked_f(xs, xs) self.assertIs(err.get(), None) ys = jnp.array([3., jnp.inf]) err, _ = checked_f(xs, ys) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive sin") @jtu.skip_on_devices("tpu") def test_cond_basic(self): @jax.jit def f(x): return lax.cond(x > 0, lambda: jnp.sin(x), lambda: x) checked_f = checkify.checkify(f, errors=checkify.float_checks) err, _ = checked_f(3.) self.assertIs(err.get(), None) err, _ = checked_f(jnp.inf) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive sin") err, _ = checked_f(-jnp.inf) self.assertIs(err.get(), None) @jtu.skip_on_devices("tpu") def test_scan_map(self): def scan_body(_, x): return None, jnp.sin(x) @jax.jit def f(xs): return lax.scan(scan_body, None, xs) checked_f = checkify.checkify(f, errors=checkify.float_checks) xs = jnp.array([0., 2.]) err, (_, ch_outs) = checked_f(xs) _, outs = f(xs) self.assertIs(err.get(), None) self.assertArraysEqual(ch_outs, outs) xs = jnp.array([3., jnp.inf]) err, (_, ch_outs) = checked_f(xs) _, outs = f(xs) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive sin") self.assertArraysEqual(ch_outs, outs) @jtu.skip_on_devices("tpu") def test_scan_carry(self): def scan_body(carry, x): carry = carry - 1. possible_nan = jnp.sin(1. 
/ carry) return carry, x + possible_nan @jax.jit def f(carry, xs): return lax.scan(scan_body, carry, xs) checked_f = checkify.checkify(f, errors=checkify.float_checks) carry, xs = 3., jnp.ones((2, )) err, (ch_out_carry, ch_outs) = checked_f(carry, xs) out_carry, outs = f(carry, xs) self.assertIs(err.get(), None) self.assertArraysEqual(ch_outs, outs) self.assertArraysEqual(ch_out_carry, out_carry) # error happens on first iteration carry, xs = 1., jnp.ones((2, )) err, (ch_out_carry, ch_outs) = checked_f(carry, xs) out_carry, outs = f(carry, xs) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") self.assertArraysEqual(ch_outs, outs) self.assertArraysEqual(ch_out_carry, out_carry) # error happens on second iteration carry, xs = 2., jnp.ones((4, )) err, (ch_out_carry, ch_outs) = checked_f(carry, xs) out_carry, outs = f(carry, xs) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") self.assertArraysEqual(ch_outs, outs) self.assertArraysEqual(ch_out_carry, out_carry) @jtu.skip_on_devices("tpu") def test_while_loop_body_error(self): def while_cond(val): i, _ = val return i < 2 def while_body(val): i, x = val possible_nan = jnp.sin(1. / i) return i + 1., x + possible_nan @jax.jit def f(init_val): return lax.while_loop(while_cond, while_body, (init_val, 0.)) checked_f = checkify.checkify(f, errors=checkify.float_checks) init_val = 1. err, ch_out = checked_f(init_val) out = f(init_val) self.assertIs(err.get(), None) self.assertArraysEqual(ch_out, out) init_val = 0. err, ch_out = checked_f(init_val) out = f(init_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") self.assertArraysEqual(ch_out, out) @jtu.skip_on_devices("tpu") def test_while_loop_cond_error(self): def while_cond(val): _ = jnp.sin(1. / val) return val < 2. def while_body(val): return val + 1. @jax.jit def f(init_val): return lax.while_loop(while_cond, while_body, init_val) checked_f = checkify.checkify(f, errors=checkify.float_checks) init_val = 1. err, ch_out = checked_f(init_val) out = f(init_val) self.assertIs(err.get(), None) self.assertArraysEqual(ch_out, out) init_val = 0. err, ch_out = checked_f(init_val) out = f(init_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") self.assertArraysEqual(ch_out, out) @jtu.skip_on_devices("tpu") def test_while_loop_cond_error_and_false(self): # Tests if an error is generated when cond returns False. def while_cond(val): possible_nan = jnp.sin(1. / val) return jnp.logical_not(jnp.isnan(possible_nan)) @jax.jit def f(init_val): return lax.while_loop(while_cond, lambda val: val - 1, init_val) checked_f = checkify.checkify(f, errors=checkify.float_checks) # error on first cond init_val = 0. err, _ = checked_f(init_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") # error on second cond init_val = 1. err, _ = checked_f(init_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") @jtu.skip_on_devices("tpu") def test_while_loop_body_and_cond_error(self): def while_cond(val): i, cond_val, _ = val _ = jnp.sin(cond_val) return i < 2 def while_body(val): i, cond_val, body_val = val possible_nan = jnp.cos(body_val) return i + 1., cond_val, possible_nan @jax.jit def f(cond_val, body_val): return lax.while_loop(while_cond, while_body, (0., cond_val, body_val)) checked_f = checkify.checkify(f, errors=checkify.float_checks) cond_val = jnp.inf body_val = 1. 
err, _ = checked_f(cond_val, body_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive sin") cond_val = 1. body_val = jnp.inf err, _ = checked_f(cond_val, body_val) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive cos") cond_val = jnp.inf body_val = jnp.inf err, _ = checked_f(cond_val, body_val) self.assertIsNotNone(err.get()) # first error which occurs is in cond self.assertStartsWith(err.get(), "nan generated by primitive sin") def test_empty_enabled_errors(self): def multi_errors(x): x = x / 0 # DIV x = jnp.sin(x) # NAN x = x[500] # OOB checkify.check(x < 0, "must be negative!") # ASSERT return x x = jnp.ones((2, )) err, _ = checkify.checkify(multi_errors, errors=set())(x) self.assertIsNone(err.get()) @parameterized.named_parameters( ("assert", checkify.user_checks, "must be negative!"), ("div", {checkify.ErrorCategory.DIV}, "divided by zero"), ("nan", {checkify.ErrorCategory.NAN}, "nan generated"), ("oob", checkify.index_checks, "out-of-bounds indexing"), ("automatic_checks", checkify.automatic_checks, "divided by zero"), ) @jtu.skip_on_devices("tpu") def test_enabled_errors(self, error_set, expected_error): def multi_errors(x): checkify.check(jnp.all(x < 0), "must be negative!") # ASSERT x = x / 0 # DIV x = jnp.sin(x) # NAN x = x[500] # OOB return x x = jnp.ones((2, )) err, _ = checkify.checkify(multi_errors, errors=error_set)(x) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), expected_error) @jtu.skip_on_devices("tpu") def test_post_process_call(self): @partial(checkify.checkify, errors=checkify.float_checks) def g(x): @jax.jit def f(y): return jnp.sin(x * y) return f(jnp.inf) err, _ = g(2.) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive sin") @jtu.skip_on_devices("tpu") def test_post_process_map(self): @partial(checkify.checkify, errors=checkify.float_checks) def g(x): @jax.pmap def f(y): return jnp.sin(x * y) return f(jnp.array([jnp.inf]))[0] err, _ = g(2.) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'nan generated by primitive sin') @jtu.skip_on_devices("tpu") def test_custom_jvp(self): @jax.custom_jvp def sin(x): return jnp.sin(x) @sin.defjvp def sin_jvp(primals, tangents): (x, ), (xdot, ) = primals, tangents return sin(x), jnp.cos(x) * xdot f = checkify.checkify(sin, errors=checkify.float_checks) err, y = f(3.) self.assertIsNone(err.get()) err, y = f(jnp.inf) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'nan generated by primitive sin') # When we hit the custom jvp rule with jvp-of-checkify, no checks are added. (err, y), (errdot, ydot) = jax.jvp(f, (3., ), (1., )) # doesn't crash self.assertIsNone(err.get()) # no error self.assertEmpty(err.msgs) # and no checks were added! self.assertEmpty(errdot.msgs) y_expected, ydot_expected = jax.jvp(jnp.sin, (3., ), (1., )) self.assertAllClose(y, y_expected) self.assertAllClose(ydot, ydot_expected) # Grad-of-checkify doesn't crash either. x_bar = jax.grad(lambda x: f(x)[1])(3.) self.assertAllClose(x_bar, jnp.cos(3.)) # Checkify-of-jvp adds checks (unlike jvp-of-checkify above). g = checkify.checkify(lambda x, xdot: jax.jvp(sin, (x, ), (xdot, )), errors=checkify.float_checks) err, (y, ydot) = g(3., 1.) # doesn't crash self.assertIsNone(err.get()) # no error self.assertNotEmpty(err.msgs) # but checks were added! self.assertAllClose(y, jnp.sin(3.)) self.assertAllClose(ydot, jnp.cos(3.)) err, _ = g(jnp.inf, 1.) 
self.assertIsNotNone(err.get()) # yes error self.assertStartsWith(err.get(), 'nan generated by primitive sin') @jtu.skip_on_devices("tpu") def test_custom_vjp(self): @jax.custom_vjp def sin(x): return jnp.sin(x) def sin_fwd(x): return jnp.sin(x), 2. * x def sin_bwd(x2, g): return jnp.cos(x2 / 2.) * g, sin.defvjp(sin_fwd, sin_bwd) f = checkify.checkify(sin, errors=checkify.float_checks) # no differentiation, no error err, y = f(3.) self.assertIsNone(err.get()) # no differentiation, yes error err, y = f(jnp.inf) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'nan generated by primitive sin') # When we hit the custom vjp rule with vjp-of-checkify, no checks are added. (err, y), f_vjp = jax.vjp(f, 3.) self.assertIsNone(err.get()) # no error self.assertEmpty(err.msgs) # and no checks were added! # Checkify-of-vjp adds checks (unlike vjp-of-checkify above). err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_checks)(3.) self.assertIsNone(err.get()) # no error self.assertNotEmpty(err.msgs) # but checks were added! err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_checks)(jnp.inf) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "nan generated by primitive sin") def test_scan_consts(self): def f(xs): def scan_body(carry, _): # closes over xs return carry + 1, xs[carry] return lax.scan(scan_body, 1, xs)[1] checked_f = checkify.checkify(f, errors=checkify.index_checks) err, _ = checked_f(jnp.ones((7, 3))) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "out-of-bounds indexing") def test_scan_consts2(self): def f(xs): def scan_body(carry, _): # add more consts! _ = xs[carry], xs[carry], jnp.sin(np.arange(11.)) return carry + 1, xs[carry] return lax.scan(scan_body, 1, xs)[1] checked_f = checkify.checkify(f, errors=checkify.index_checks) err, _ = checked_f(jnp.ones((7, 3))) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "out-of-bounds indexing") def test_while_consts(self): def f(xs): def while_cond(carry): i, _ = carry _ = xs[i], jnp.sin(np.arange(11.)) return i > -1 def while_body(carry): i, _ = carry x = xs[i] return i - 1, x / i return lax.while_loop(while_cond, while_body, (0, jnp.zeros_like(xs[0]))) checked_f = checkify.checkify(f, errors=checkify.float_checks) err, _ = checked_f(jnp.ones((7, 3))) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "divided by zero") def test_multiple_payloads(self): def f(x): _ = x[5] _ = x[6] err, _ = checkify.checkify(f, errors=checkify.index_checks)(jnp.ones( (2, ))) self.assertIsNotNone(err.get()) self.assertIn("index 5", err.get()) def test_nd_payloads(self): cf = checkify.checkify(lambda x, i: x[i], errors=checkify.index_checks) errs, _ = jax.vmap(cf)(jnp.ones((3, 2)), jnp.array([5, 0, 100])) self.assertIsNotNone(errs.get()) self.assertIn("index 5", errs.get()) self.assertIn("index 100", errs.get()) def test_mapped_error_one_payload(self): def f(x, i): x = x[i] return x / 0 cf = checkify.checkify(f, errors=checkify.automatic_checks) errs, _ = jax.vmap(cf)(jnp.ones((2, 1)), jnp.array([0, 100])) self.assertIsNotNone(errs.get()) self.assertIn("divided by zero", errs.get()) self.assertIn("index 100", errs.get())
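# Illustrative sketch (not a test case): the basic checkify workflow exercised
# by CheckifyTransformTests above. Assumes `jnp` and `checkify` are imported as
# elsewhere in this file.
def _checkify_usage_sketch():
  def f(x, i):
    return jnp.sin(x)[i]

  checked_f = checkify.checkify(f, errors=checkify.automatic_checks)
  err, out = checked_f(jnp.arange(3.), 5)  # out-of-bounds index
  # err.get() is None when no error was recorded and a message string
  # otherwise; err.throw() would re-raise it as a Python exception.
  return err.get(), out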
class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): # This test runs for all primitive harnesses. For each primitive "xxx" the # test will be called "test_prim_xxx_..." and the custom parameters for # the test are defined in the class method "jax2tf_limitations.Jax2TfLimitation.xxx". # See more details in the comment at top of file and in Jax2TfLimitation class. # If you want to run this test for only one harness, add parameter # `one_containing="foo"` to parameterized below. @primitive_harness.parameterized( primitive_harness.all_harnesses, include_jax_unimpl=False, #one_containing="cumprod_dtype_by_fun_shape=float16[8,9]_axis=0_reverse=False" ) @jtu.ignore_warning(category=UserWarning, message="Using reduced precision for gradient.*") def test_prim(self, harness: primitive_harness.Harness): limitations = Jax2TfLimitation.limitations_for_harness(harness) device = jtu.device_under_test() limitations = tuple( filter(lambda l: l.filter(device=device, dtype=harness.dtype), limitations)) func_jax = harness.dyn_fun args = harness.dyn_args_maker(self.rng()) enable_xla = harness.params.get("enable_xla", True) associative_scan_reductions = harness.params.get( "associative_scan_reductions", False) with jax.jax2tf_associative_scan_reductions( associative_scan_reductions): self.ConvertAndCompare(func_jax, *args, limitations=limitations, enable_xla=enable_xla) def test_primitive_coverage(self): """Fail if there are JAX primitives that are not implemented.""" # Harvest primitives from XLA translation tables all_primitives = (set(xla._translations) | set(xla._backend_specific_translations["cpu"]) | set(xla._backend_specific_translations["gpu"]) | set(xla._backend_specific_translations["tpu"]) | set(mlir._lowerings) | set(mlir._platform_specific_lowerings["cpu"]) | set(mlir._platform_specific_lowerings["gpu"]) | set(mlir._platform_specific_lowerings["tpu"])) tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl) | set( jax.experimental.jax2tf.jax2tf.tf_impl_with_avals) tf_not_yet_impl = set(jax.experimental.jax2tf.jax2tf.tf_not_yet_impl) all_primitives = tuple(sorted(all_primitives, key=str)) for p in all_primitives: if p.name == "axis_index": continue # TODO: Remove once tensorflow is 2.10.0 everywhere. if p.name == "optimization_barrier": continue if p.name in tf_not_yet_impl: self.assertNotIn( p, tf_impl ) # Should not be in both tf_impl and tf_not_yet_impl else: self.assertIn(p, tf_impl) def test_generate_limitations_doc(self): """Generates primitives_with_limited_support.md. See the doc for instructions. 
""" harnesses = [ h for h in primitive_harness.all_harnesses if h.filter(h, include_jax_unimpl=True) ] print(f"Found {len(harnesses)} test harnesses that work in JAX") def unique_hash(h: primitive_harness.Harness, l: Jax2TfLimitation): return (h.group_name, l.description, l.devices, tuple(np.dtype(d).name for d in l.dtypes), l.modes) unique_limitations: Dict[Any, Tuple[primitive_harness.Harness, Jax2TfLimitation]] = {} for h in harnesses: for l in h.jax_unimplemented: if l.enabled: # Fake a Jax2TFLimitation from the Limitation tfl = Jax2TfLimitation( description="Not implemented in JAX: " + l.description, devices=l.devices, dtypes=l.dtypes, expect_tf_error=False, skip_tf_run=True) unique_limitations[hash(unique_hash(h, tfl))] = (h, tfl) for h in harnesses: for l in Jax2TfLimitation.limitations_for_harness(h): unique_limitations[hash(unique_hash(h, l))] = (h, l) print(f"Found {len(unique_limitations)} unique limitations") tf_error_table = [ """ | Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes | | --- | --- | --- | --- | --- |""" ] tf_numerical_discrepancies_table = list(tf_error_table) # a copy for h, l in sorted(unique_limitations.values(), key=lambda pair: unique_hash(*pair)): devices = ", ".join(sorted(l.devices)) modes = ", ".join(sorted(l.modes)) description = l.description if l.skip_comparison: description = "Numeric comparison disabled: " + description if l.expect_tf_error: description = "TF error: " + description if l.skip_tf_run: description = "TF test skipped: " + description if l.skip_tf_run or l.expect_tf_error: to_table = tf_error_table elif l.skip_comparison or l.custom_assert: to_table = tf_numerical_discrepancies_table else: continue to_table.append( f"| {h.group_name} | {description} | " f"{primitive_harness.dtypes_to_str(l.dtypes, empty_means_all=True)} | {devices} | {modes} |" ) if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"): raise unittest.SkipTest( "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation" ) # The CPU has more supported types, and harnesses self.assertEqual("cpu", jtu.device_under_test()) self.assertTrue( config.x64_enabled, "Documentation generation must be run with JAX_ENABLE_X64=1") with open( os.path.join( os.path.dirname(__file__), "../g3doc/primitives_with_limited_support.md.template") ) as f: template = f.read() output_file = os.path.join( os.path.dirname(__file__), "../g3doc/primitives_with_limited_support.md") with open(output_file, "w") as f: f.write(template.replace("{{generation_date}}", str(datetime.date.today())) \ .replace("{{tf_error_table}}", "\n".join(tf_error_table)) \ .replace("{{tf_numerical_discrepancies_table}}", "\n".join(tf_numerical_discrepancies_table)) \ ) # The rest of the test are checking special cases @parameterized.named_parameters( dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) for f_jax in [ jnp.add, jnp.subtract, jnp.multiply, jnp.divide, jnp.less, jnp.less_equal, jnp.equal, jnp.greater, jnp.greater_equal, jnp.not_equal, jnp.maximum, jnp.minimum ]) def test_type_promotion(self, f_jax=jnp.add): # We only test a few types here, as tensorflow does not support many # types like uint* or bool in binary ops. 
types = [dtypes.bfloat16, np.int32, np.int64, np.float32] for x_dtype in types: for y_dtype in types: x = np.array([1, 2], dtype=x_dtype) y = np.array([3, 4], dtype=y_dtype) self.ConvertAndCompare(f_jax, x, y) def test_integer_div(self): x = jnp.array([-4, -3, -1, 0, 1, 3, 6]) y = np.int32(3) self.ConvertAndCompare(jnp.floor_divide, x, y) expected = jnp.floor_divide(x, y) # Try it with TF 1 as well (#5831) with tf.compat.v1.Session() as sess: tf1_res = sess.run(jax2tf.convert(jnp.floor_divide)(x, y)) self.assertAllClose(expected, tf1_res) def test_boolean_gather(self): values = np.array([[True, True], [False, True], [False, False]], dtype=np.bool_) indices = np.array([0, 1], dtype=np.int32) for axis in [0, 1]: f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis)) # pylint: disable=cell-var-from-loop self.ConvertAndCompare(f_jax, values, indices) def test_gather_rank_change(self): params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]]) indices = jnp.array([[1, 1, 2], [0, 1, 0]]) f_jax = jax.jit(lambda i: params[i]) self.ConvertAndCompare(f_jax, indices) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) for f_jax in REDUCE)) def test_reduce_ops_with_numerical_input(self, f_jax): values = np.array([1, 2, 3], dtype=np.float32) self.ConvertAndCompare(f_jax, values) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_{op}", op=op) for op in ("add", "max", "min", "multiply", "set"))) def test_scatter_static(self, op): values = np.ones((5, 6), dtype=np.float32) update = np.float32(6.) f_jax = jax.jit(lambda v, u: getattr(v.at[::2, 3:], op)(u)) self.ConvertAndCompare(f_jax, values, update) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) for f_jax in REDUCE)) def test_reduce_ops_with_boolean_input(self, f_jax): values = np.array([True, False, True], dtype=np.bool_) self.ConvertAndCompare(f_jax, values)
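# ---------------------------------------------------------------------------
# Illustrative sketch (standalone, not part of JaxPrimitiveTest): the tests
# above rely on the ConvertAndCompare helper from tf_test_util. Assuming only
# the public jax2tf.convert API, the round trip it performs amounts to running
# the JAX function natively and through its TF conversion, then comparing:
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

def _roundtrip_check(f_jax, *args, atol=1e-6):
  """Runs f_jax natively and via jax2tf.convert, then compares the outputs."""
  f_tf = jax2tf.convert(f_jax)
  res_jax = f_jax(*args)
  res_tf = f_tf(*args)  # an eager tf.Tensor; .numpy() yields an ndarray
  np.testing.assert_allclose(np.asarray(res_jax), res_tf.numpy(), atol=atol)

# Example mirroring test_integer_div above:
_roundtrip_check(jnp.floor_divide,
                 np.array([-4, -3, -1, 0, 1, 3, 6], dtype=np.int32),
                 np.int32(3))
# ---------------------------------------------------------------------------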
class StaxTest(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(shape), "shape": shape } for shape in [(2, 3), (5, )])) def testRandnInitShape(self, shape): key = random.PRNGKey(0) out = stax.randn()(key, shape) self.assertEqual(out.shape, shape) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(shape), "shape": shape } for shape in [(2, 3), (2, 3, 4)])) def testGlorotInitShape(self, shape): key = random.PRNGKey(0) out = stax.glorot()(key, shape) self.assertEqual(out.shape, shape) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}" .format(channels, filter_shape, padding, strides, input_shape), "channels": channels, "filter_shape": filter_shape, "padding": padding, "strides": strides, "input_shape": input_shape } for channels in [2, 3] for filter_shape in [(1, 1), (2, 3)] for padding in ["SAME", "VALID"] for strides in [None, (2, 1)] for input_shape in [(2, 10, 11, 1)])) def testConvShape(self, channels, filter_shape, padding, strides, input_shape): init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides, padding=padding) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}" .format(channels, filter_shape, padding, strides, input_shape), "channels": channels, "filter_shape": filter_shape, "padding": padding, "strides": strides, "input_shape": input_shape } for channels in [2, 3] for filter_shape in [(1, 1), (2, 3), (3, 3)] for padding in ["SAME", "VALID"] for strides in [None, (2, 1), (2, 2)] for input_shape in [(2, 10, 11, 1)])) def testConvTransposeShape(self, channels, filter_shape, padding, strides, input_shape): init_fun, apply_fun = stax.ConvTranspose( channels, filter_shape, # 2D strides=strides, padding=padding) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}" .format(channels, filter_shape, padding, strides, input_shape), "channels": channels, "filter_shape": filter_shape, "padding": padding, "strides": strides, "input_shape": input_shape } for channels in [2, 3] for filter_shape in [(1, ), (2, ), (3, )] for padding in ["SAME", "VALID"] for strides in [None, (1, ), (2, )] for input_shape in [(2, 10, 1)])) def testConv1DTransposeShape(self, channels, filter_shape, padding, strides, input_shape): init_fun, apply_fun = stax.Conv1DTranspose(channels, filter_shape, strides=strides, padding=padding) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_out_dim={}_input_shape={}".format(out_dim, input_shape), "out_dim": out_dim, "input_shape": input_shape } for out_dim in [3, 4] for input_shape in [(2, 3), (3, 4)])) def testDenseShape(self, out_dim, input_shape): init_fun, apply_fun = stax.Dense(out_dim) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_input_shape={}_nonlinear={}".format(input_shape, nonlinear), "input_shape": input_shape, "nonlinear": nonlinear } for input_shape in [(2, 3), (2, 3, 4)] for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"])) def testNonlinearShape(self, input_shape, nonlinear): init_fun, 
apply_fun = getattr(stax, nonlinear) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_window_shape={}_padding={}_strides={}_input_shape={}" "_maxpool={}_spec={}".format(window_shape, padding, strides, input_shape, max_pool, spec), "window_shape": window_shape, "padding": padding, "strides": strides, "input_shape": input_shape, "max_pool": max_pool, "spec": spec } for window_shape in [(1, 1), (2, 3)] for padding in ["VALID"] for strides in [None, (2, 1)] for input_shape in [(2, 5, 6, 4)] for max_pool in [False, True] for spec in ["NHWC", "NCHW", "WHNC", "WHCN"])) def testPoolingShape(self, window_shape, padding, strides, input_shape, max_pool, spec): layer = stax.MaxPool if max_pool else stax.AvgPool init_fun, apply_fun = layer(window_shape, padding=padding, strides=strides, spec=spec) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(input_shape), "input_shape": input_shape } for input_shape in [(2, 3), (2, 3, 4)])) def testFlattenShape(self, input_shape): init_fun, apply_fun = stax.Flatten _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_input_shape={}_spec={}".format( input_shape, i), "input_shape": input_shape, "spec": spec } for input_shape in [(2, 5, 6, 1)] for i, spec in enumerate([[stax.Conv(3, ( 2, 2))], [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]]))) def testSerialComposeLayersShape(self, input_shape, spec): init_fun, apply_fun = stax.serial(*spec) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_input_shape={}".format(input_shape), "input_shape": input_shape } for input_shape in [(3, 4), (2, 5, 6, 1)])) def testDropoutShape(self, input_shape): init_fun, apply_fun = stax.Dropout(0.9) _CheckShapeAgreement(self, init_fun, apply_fun, input_shape) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_input_shape={}".format(input_shape), "input_shape": input_shape } for input_shape in [(3, 4), (2, 5, 6, 1)])) def testFanInSum(self, input_shape): init_fun, apply_fun = stax.FanInSum _CheckShapeAgreement(self, init_fun, apply_fun, [input_shape, input_shape]) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_inshapes={}_axis={}".format( input_shapes, axis), "input_shapes": input_shapes, "axis": axis } for input_shapes, axis in [ ([(2, 3), (2, 1)], 1), ([(2, 3), (2, 1)], -1), ([(1, 2, 4), (1, 1, 4)], 1), ])) def testFanInConcat(self, input_shapes, axis): init_fun, apply_fun = stax.FanInConcat(axis) _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes) def testIssue182(self): key = random.PRNGKey(0) init_fun, apply_fun = stax.Softmax input_shape = (10, 3) inputs = np.arange(30.).astype("float32").reshape(input_shape) out_shape, params = init_fun(key, input_shape) out = apply_fun(params, inputs) assert out_shape == out.shape assert np.allclose(np.sum(np.asarray(out), -1), 1.) 
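# ---------------------------------------------------------------------------
# Illustrative sketch (standalone, not a method of StaxTest): the shape tests
# above call a _CheckShapeAgreement helper defined elsewhere in this file.
# Assuming `stax` is jax.example_libraries.stax and the usual stax contract
# -- init_fun(rng, input_shape) -> (output_shape, params) and
# apply_fun(params, inputs) -> outputs -- a minimal equivalent check is:
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

def check_shape_agreement(init_fun, apply_fun, input_shape):
  rng = random.PRNGKey(0)
  result_shape, params = init_fun(rng, input_shape)
  outputs = apply_fun(params, jnp.zeros(input_shape, jnp.float32))
  assert outputs.shape == result_shape, (outputs.shape, result_shape)

# Example: Dense(4) maps (batch, features) -> (batch, 4).
init_fun, apply_fun = stax.Dense(4)
check_shape_agreement(init_fun, apply_fun, (2, 3))
# ---------------------------------------------------------------------------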
def testBatchNormNoScaleOrCenter(self): key = random.PRNGKey(0) axes = (0, 1, 2) init_fun, apply_fun = stax.BatchNorm(axis=axes, center=False, scale=False) input_shape = (4, 5, 6, 7) inputs = random_inputs(np.random.RandomState(0), input_shape) out_shape, params = init_fun(key, input_shape) out = apply_fun(params, inputs) means = np.mean(out, axis=(0, 1, 2)) std_devs = np.std(out, axis=(0, 1, 2)) assert np.allclose(means, np.zeros_like(means), atol=1e-4) assert np.allclose(std_devs, np.ones_like(std_devs), atol=1e-4) def testBatchNormShapeNHWC(self): key = random.PRNGKey(0) init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2)) input_shape = (4, 5, 6, 7) inputs = random_inputs(np.random.RandomState(0), input_shape) out_shape, params = init_fun(key, input_shape) out = apply_fun(params, inputs) self.assertEqual(out_shape, input_shape) beta, gamma = params self.assertEqual(beta.shape, (7, )) self.assertEqual(gamma.shape, (7, )) self.assertEqual(out_shape, out.shape) def testBatchNormShapeNCHW(self): key = random.PRNGKey(0) # Regression test for https://github.com/google/jax/issues/461 init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3)) input_shape = (4, 5, 6, 7) inputs = random_inputs(np.random.RandomState(0), input_shape) out_shape, params = init_fun(key, input_shape) out = apply_fun(params, inputs) self.assertEqual(out_shape, input_shape) beta, gamma = params self.assertEqual(beta.shape, (5, )) self.assertEqual(gamma.shape, (5, )) self.assertEqual(out_shape, out.shape)
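# ---------------------------------------------------------------------------
# Illustrative sketch (standalone, not part of StaxTest): the BatchNorm checks
# above reduce to the property that normalizing over `axis` leaves each
# remaining channel with mean ~0 and std ~1. Assuming `stax` is
# jax.example_libraries.stax, this mirrors testBatchNormNoScaleOrCenter:
import numpy as np
from jax import random
from jax.example_libraries import stax

init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2), center=False, scale=False)
inputs = np.random.RandomState(0).randn(4, 5, 6, 7).astype(np.float32)
_, params = init_fun(random.PRNGKey(0), inputs.shape)
out = apply_fun(params, inputs)
# Statistics over the normalized axes, one value per channel (size 7):
assert np.allclose(np.mean(out, axis=(0, 1, 2)), 0., atol=1e-4)
assert np.allclose(np.std(out, axis=(0, 1, 2)), 1., atol=1e-4)
# ---------------------------------------------------------------------------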
class ImageTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method, antialias), "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape, "method": method, "antialias": antialias} for dtype in float_dtypes for target_shape, image_shape in itertools.combinations_with_replacement( [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2) for method in ["nearest", "bilinear", "lanczos3", "lanczos5", "bicubic"] for antialias in [False, True])) @unittest.skipIf(not tf, "Test requires TensorFlow") def testResizeAgainstTensorFlow(self, dtype, image_shape, target_shape, method, antialias): # TODO(phawkins): debug this. There is a small mismatch between TF and JAX # for some cases of non-antialiased bicubic downscaling; we would expect # exact equality. if method == "bicubic" and any(x < y for x, y in zip(target_shape, image_shape)): raise unittest.SkipTest("non-antialiased bicubic downscaling mismatch") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(image_shape, dtype),) def tf_fn(x): out = tf.image.resize( x.astype(np.float64), tf.constant(target_shape[1:-1]), method=method, antialias=antialias).numpy().astype(dtype) return out jax_fn = partial(image.resize, shape=target_shape, method=method, antialias=antialias) self._CheckAgainstNumpy(tf_fn, jax_fn, args_maker, check_dtypes=True, tol={np.float16: 2e-2, jnp.bfloat16: 1e-1, np.float32: 1e-4, np.float64: 1e-4}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_target={}_method={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method), "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape, "method": method} for dtype in [np.float32] for target_shape, image_shape in itertools.combinations_with_replacement( [[3, 2], [6, 4], [33, 17], [50, 39]], 2) for method in ["nearest", "bilinear", "lanczos3", "bicubic"])) @unittest.skipIf(not PIL_Image, "Test requires PIL") def testResizeAgainstPIL(self, dtype, image_shape, target_shape, method): rng = jtu.rand_uniform(self.rng()) args_maker = lambda: (rng(image_shape, dtype),) def pil_fn(x): pil_methods = { "nearest": PIL_Image.NEAREST, "bilinear": PIL_Image.BILINEAR, "bicubic": PIL_Image.BICUBIC, "lanczos3": PIL_Image.LANCZOS, } img = PIL_Image.fromarray(x.astype(np.float32)) out = np.asarray(img.resize(target_shape[::-1], pil_methods[method]), dtype=dtype) return out jax_fn = partial(image.resize, shape=target_shape, method=method, antialias=True) self._CheckAgainstNumpy(pil_fn, jax_fn, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_target={}_method={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method), "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape, "method": method} for dtype in inexact_dtypes for image_shape, target_shape in [ ([3, 1, 2], [6, 1, 4]), ([1, 3, 2, 1], [1, 6, 4, 1]), ] for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"])) def testResizeUp(self, dtype, image_shape, target_shape, method): data = [64, 32, 32, 64, 50, 100] expected_data = {} expected_data["nearest"] = [ 64.0, 64.0, 32.0, 32.0, 64.0, 64.0, 32.0, 32.0, 32.0, 32.0, 64.0, 64.0, 32.0, 32.0, 64.0, 64.0, 50.0, 50.0, 
100.0, 100.0, 50.0, 50.0, 100.0, 100.0 ] expected_data["linear"] = [ 64.0, 56.0, 40.0, 32.0, 56.0, 52.0, 44.0, 40.0, 40.0, 44.0, 52.0, 56.0, 36.5, 45.625, 63.875, 73.0, 45.5, 56.875, 79.625, 91.0, 50.0, 62.5, 87.5, 100.0 ] expected_data["lanczos3"] = [ 75.8294, 59.6281, 38.4313, 22.23, 60.6851, 52.0037, 40.6454, 31.964, 35.8344, 41.0779, 47.9383, 53.1818, 24.6968, 43.0769, 67.1244, 85.5045, 35.7939, 56.4713, 83.5243, 104.2017, 44.8138, 65.1949, 91.8603, 112.2413 ] expected_data["lanczos5"] = [ 77.5699, 60.0223, 40.6694, 23.1219, 61.8253, 51.2369, 39.5593, 28.9709, 35.7438, 40.8875, 46.5604, 51.7041, 21.5942, 43.5299, 67.7223, 89.658, 32.1213, 56.784, 83.984, 108.6467, 44.5802, 66.183, 90.0082, 111.6109 ] expected_data["cubic"] = [ 70.1453, 59.0252, 36.9748, 25.8547, 59.3195, 53.3386, 41.4789, 35.4981, 36.383, 41.285, 51.0051, 55.9071, 30.2232, 42.151, 65.8032, 77.731, 41.6492, 55.823, 83.9288, 98.1026, 47.0363, 62.2744, 92.4903, 107.7284 ] x = np.array(data, dtype=dtype).reshape(image_shape) output = image.resize(x, target_shape, method) expected = np.array(expected_data[method], dtype=dtype).reshape(target_shape) self.assertAllClose(output, expected, atol=1e-04) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method, antialias), "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape, "method": method, "antialias": antialias} for dtype in [np.float32] for target_shape, image_shape in itertools.combinations_with_replacement( [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2) for method in ["bilinear", "lanczos3", "lanczos5", "bicubic"] for antialias in [False, True])) def testResizeGradients(self, dtype, image_shape, target_shape, method, antialias): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(image_shape, dtype),) jax_fn = partial(image.resize, shape=target_shape, method=method, antialias=antialias) jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.) 
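# ---------------------------------------------------------------------------
# Illustrative sketch (standalone, not a method of ImageTest): jax.image.resize
# takes the full target shape plus an interpolation method, and it is
# differentiable, which is what testResizeGradients verifies with
# jtu.check_grads. A minimal usage example:
import jax
import jax.numpy as jnp

x = jnp.arange(6., dtype=jnp.float32).reshape(3, 2)
y = jax.image.resize(x, shape=(6, 4), method="linear", antialias=True)
assert y.shape == (6, 4)

# Gradients flow through the resize:
g = jax.grad(lambda im: jax.image.resize(im, (6, 4), "linear").sum())(x)
assert g.shape == x.shape
# ---------------------------------------------------------------------------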
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method, antialias), "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape, "method": method, "antialias": antialias} for dtype in [np.float32] for image_shape, target_shape in [ ([1], [0]), ([5, 5], [5, 0]), ([5, 5], [0, 1]), ([5, 5], [0, 0]) ] for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"] for antialias in [False, True])) def testResizeEmpty(self, dtype, image_shape, target_shape, method, antialias): # Regression test for https://github.com/google/jax/issues/7586 image = np.ones(image_shape, dtype) out = jax.image.resize(image, shape=target_shape, method=method, antialias=antialias) self.assertArraysEqual(out, jnp.zeros(target_shape, dtype)) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_target={}_method={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method), "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape, "scale": scale, "translation": translation, "method": method} for dtype in inexact_dtypes for image_shape, target_shape, scale, translation in [ ([3, 1, 2], [6, 1, 4], [2.0, 1.0, 2.0], [1.0, 0.0, -1.0]), ([1, 3, 2, 1], [1, 6, 4, 1], [1.0, 2.0, 2.0, 1.0], [0.0, 1.0, -1.0, 0.0])] for method in ["linear", "lanczos3", "lanczos5", "cubic"])) def testScaleAndTranslateUp(self, dtype, image_shape, target_shape, scale, translation, method): data = [64, 32, 32, 64, 50, 100] # Note zeros occur in the output because the sampling location is outside # the boundaries of the input image. expected_data = {} expected_data["linear"] = [ 0.0, 0.0, 0.0, 0.0, 56.0, 40.0, 32.0, 0.0, 52.0, 44.0, 40.0, 0.0, 44.0, 52.0, 56.0, 0.0, 45.625, 63.875, 73.0, 0.0, 56.875, 79.625, 91.0, 0.0 ] expected_data["lanczos3"] = [ 0.0, 0.0, 0.0, 0.0, 59.6281, 38.4313, 22.23, 0.0, 52.0037, 40.6454, 31.964, 0.0, 41.0779, 47.9383, 53.1818, 0.0, 43.0769, 67.1244, 85.5045, 0.0, 56.4713, 83.5243, 104.2017, 0.0 ] expected_data["lanczos5"] = [ 0.0, 0.0, 0.0, 0.0, 60.0223, 40.6694, 23.1219, 0.0, 51.2369, 39.5593, 28.9709, 0.0, 40.8875, 46.5604, 51.7041, 0.0, 43.5299, 67.7223, 89.658, 0.0, 56.784, 83.984, 108.6467, 0.0 ] expected_data["cubic"] = [ 0.0, 0.0, 0.0, 0.0, 59.0252, 36.9748, 25.8547, 0.0, 53.3386, 41.4789, 35.4981, 0.0, 41.285, 51.0051, 55.9071, 0.0, 42.151, 65.8032, 77.731, 0.0, 55.823, 83.9288, 98.1026, 0.0 ] x = np.array(data, dtype=dtype).reshape(image_shape) # Should we test different float types here? 
scale_a = jnp.array(scale, dtype=jnp.float32) translation_a = jnp.array(translation, dtype=jnp.float32) output = image.scale_and_translate(x, target_shape, range(len(image_shape)), scale_a, translation_a, method) expected = np.array( expected_data[method], dtype=dtype).reshape(target_shape) self.assertAllClose(output, expected, atol=2e-03) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dtype={}_method={}_antialias={}".format( jtu.dtype_str(dtype), method, antialias), "dtype": dtype, "method": method, "antialias": antialias} for dtype in inexact_dtypes for method in ["linear", "lanczos3", "lanczos5", "cubic"] for antialias in [True, False])) def testScaleAndTranslateDown(self, dtype, method, antialias): image_shape = [1, 6, 7, 1] target_shape = [1, 3, 3, 1] data = [ 51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92, 41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89, 71, 32, 23, 23, 35, 93 ] if antialias: expected_data = {} expected_data["linear"] = [ 43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0 ] expected_data["lanczos3"] = [ 43.2884, 57.9091, 54.6439, 48.5856, 58.2427, 53.7551, 0, 0, 0 ] expected_data["lanczos5"] = [ 43.9209, 57.6360, 54.9575, 48.9272, 58.1865, 53.1948, 0, 0, 0 ] expected_data["cubic"] = [ 42.9935, 59.1687, 54.2138, 48.2640, 58.2678, 54.4088, 0, 0, 0 ] else: expected_data = {} expected_data["linear"] = [ 43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0 ] expected_data["lanczos3"] = [ 44.1390, 87.8786, 63.3111, 25.1161, 20.8795, 53.6165, 0, 0, 0 ] expected_data["lanczos5"] = [ 44.8835, 85.5896, 66.7231, 16.9983, 19.8891, 47.1446, 0, 0, 0 ] expected_data["cubic"] = [ 43.6426, 88.8854, 60.6638, 31.4685, 22.1204, 58.3457, 0, 0, 0 ] x = np.array(data, dtype=dtype).reshape(image_shape) expected = np.array( expected_data[method], dtype=dtype).reshape(target_shape) scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32) translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32) output = image.scale_and_translate( x, target_shape, (0,1,2,3), scale_a, translation_a, method, antialias=antialias) self.assertAllClose(output, expected, atol=2e-03) # Tests that running with just a subset of dimensions that have non-trivial # scale and translation. 
output = image.scale_and_translate( x, target_shape, (1,2), scale_a[1:3], translation_a[1:3], method, antialias=antialias) self.assertAllClose(output, expected, atol=2e-03) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "antialias={}".format(antialias), "antialias": antialias} for antialias in [True, False])) def testScaleAndTranslateJITs(self, antialias): image_shape = [1, 6, 7, 1] target_shape = [1, 3, 3, 1] data = [ 51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92, 41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89, 71, 32, 23, 23, 35, 93 ] if antialias: expected_data = [ 43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0 ] else: expected_data = [43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0] x = jnp.array(data, dtype=jnp.float32).reshape(image_shape) expected = jnp.array(expected_data, dtype=jnp.float32).reshape(target_shape) scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32) translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32) def jit_fn(in_array, s, t): return jax.image.scale_and_translate( in_array, target_shape, (0, 1, 2, 3), s, t, "linear", antialias, precision=jax.lax.Precision.HIGHEST) output = jax.jit(jit_fn)(x, scale_a, translation_a) self.assertAllClose(output, expected, atol=2e-03) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "antialias={}".format(antialias), "antialias": antialias} for antialias in [True, False])) def testScaleAndTranslateGradFinite(self, antialias): image_shape = [1, 6, 7, 1] target_shape = [1, 3, 3, 1] data = [ 51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92, 41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89, 71, 32, 23, 23, 35, 93 ] x = jnp.array(data, dtype=jnp.float32).reshape(image_shape) scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32) translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32) def scale_fn(s): return jnp.sum(jax.image.scale_and_translate( x, target_shape, (0, 1, 2, 3), s, translation_a, "linear", antialias, precision=jax.lax.Precision.HIGHEST)) scale_out = jax.grad(scale_fn)(scale_a) self.assertTrue(jnp.all(jnp.isfinite(scale_out))) def translate_fn(t): return jnp.sum(jax.image.scale_and_translate( x, target_shape, (0, 1, 2, 3), scale_a, t, "linear", antialias, precision=jax.lax.Precision.HIGHEST)) translate_out = jax.grad(translate_fn)(translation_a) self.assertTrue(jnp.all(jnp.isfinite(translate_out))) def testResizeWithUnusualShapes(self): x = jnp.ones((3, 4)) # Array shapes are accepted self.assertEqual((10, 17), jax.image.resize(x, jnp.array((10, 17)), "nearest").shape) with self.assertRaises(TypeError): # Fractional shapes are disallowed jax.image.resize(x, [10.5, 17], "bicubic")
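# ---------------------------------------------------------------------------
# Illustrative sketch (standalone, not part of ImageTest):
# jax.image.scale_and_translate generalizes resize -- each dimension listed in
# spatial_dims is resampled with its own scale and translation, and sample
# locations that fall outside the input contribute zeros, which is why the
# expected outputs above contain zero entries. A minimal call:
import jax
import jax.numpy as jnp

x = jnp.arange(24., dtype=jnp.float32).reshape(1, 4, 6, 1)
scale = jnp.array([0.5, 0.5], dtype=jnp.float32)        # halve H and W
translation = jnp.array([0.0, 0.0], dtype=jnp.float32)
out = jax.image.scale_and_translate(
    x, shape=(1, 2, 3, 1), spatial_dims=(1, 2),
    scale=scale, translation=translation, method="linear", antialias=True)
assert out.shape == (1, 2, 3, 1)
# ---------------------------------------------------------------------------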
class SparsifyTest(jtu.JaxTestCase): @classmethod def sparsify(cls, f): return sparsify(f, use_tracer=False) def testTracerIsInstanceCheck(self): @self.sparsify def f(x): self.assertNotIsInstance(x, SparseTracer) f(jnp.arange(5)) def assertBcooIdentical(self, x, y): self.assertIsInstance(x, BCOO) self.assertIsInstance(y, BCOO) self.assertEqual(x.shape, y.shape) self.assertArraysEqual(x.data, y.data) self.assertArraysEqual(x.indices, y.indices) def testArgSpec(self): X = jnp.arange(5) X_BCOO = BCOO.fromdense(X) args = (X, X_BCOO, X_BCOO) # Independent index spenv = SparseEnv() argspecs = arrays_to_argspecs(spenv, args) self.assertEqual(len(argspecs), len(args)) self.assertEqual(spenv.size(), 5) self.assertEqual(argspecs, (ArgSpec( X.shape, 0, None), ArgSpec(X.shape, 1, 2), ArgSpec(X.shape, 3, 4))) args_out = argspecs_to_arrays(spenv, argspecs) self.assertEqual(len(args_out), len(args)) self.assertArraysEqual(args[0], args_out[0]) self.assertBcooIdentical(args[1], args_out[1]) self.assertBcooIdentical(args[2], args_out[2]) # Shared index argspecs = (ArgSpec(X.shape, 0, None), ArgSpec(X.shape, 1, 2), ArgSpec(X.shape, 3, 2)) spenv = SparseEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data]) args_out = argspecs_to_arrays(spenv, argspecs) self.assertEqual(len(args_out), len(args)) self.assertArraysEqual(args[0], args_out[0]) self.assertBcooIdentical(args[1], args_out[1]) self.assertBcooIdentical(args[2], args_out[2]) def testUnitHandling(self): x = BCOO.fromdense(jnp.arange(5)) f = jit(lambda x, y: x) result = self.sparsify(jit(f))(x, core.unit) self.assertBcooIdentical(result, x) def testDropvar(self): def inner(x): return x * 2, x * 3 def f(x): _, y = jit(inner)(x) return y * 4 x_dense = jnp.arange(5) x_sparse = BCOO.fromdense(x_dense) self.assertArraysEqual( self.sparsify(f)(x_sparse).todense(), f(x_dense)) def testPytreeInput(self): f = self.sparsify(lambda x: x) args = (jnp.arange(4), BCOO.fromdense(jnp.arange(4))) out = f(args) self.assertLen(out, 2) self.assertArraysEqual(args[0], out[0]) self.assertBcooIdentical(args[1], out[1]) def testSparsify(self): M_dense = jnp.arange(24).reshape(4, 6) M_sparse = BCOO.fromdense(M_dense) v = jnp.arange(M_dense.shape[0]) @self.sparsify def func(x, v): return -jnp.sin(jnp.pi * x).T @ (v + 1) result_dense = func(M_dense, v) result_sparse = func(M_sparse, v) self.assertAllClose(result_sparse, result_dense) def testSparsifyWithConsts(self): M_dense = jnp.arange(24).reshape(4, 6) M_sparse = BCOO.fromdense(M_dense) @self.sparsify def func(x): return jit(lambda x: jnp.sum(x, 1))(x) result_dense = func(M_dense) result_sparse = func(M_sparse) self.assertAllClose(result_sparse.todense(), result_dense) def testSparseMatmul(self): X = jnp.arange(16).reshape(4, 4) Xsp = BCOO.fromdense(X) Y = jnp.ones(4) Ysp = BCOO.fromdense(Y) # dot_general result_sparse = self.sparsify(operator.matmul)(Xsp, Y) result_dense = operator.matmul(X, Y) self.assertAllClose(result_sparse, result_dense) # rdot_general result_sparse = self.sparsify(operator.matmul)(Y, Xsp) result_dense = operator.matmul(Y, X) self.assertAllClose(result_sparse, result_dense) # spdot_general result_sparse = self.sparsify(operator.matmul)(Xsp, Ysp) result_dense = operator.matmul(X, Y) self.assertAllClose(result_sparse.todense(), result_dense) def testSparseAdd(self): x = BCOO.fromdense(jnp.arange(5)) y = BCOO.fromdense(2 * jnp.arange(5)) # Distinct indices out = self.sparsify(operator.add)(x, y) self.assertEqual(out.nse, 8) # uses concatenation. 
self.assertArraysEqual(out.todense(), 3 * jnp.arange(5)) # Shared indices – requires lower level call argspecs = [ArgSpec(x.shape, 1, 0), ArgSpec(y.shape, 2, 0)] spenv = SparseEnv([x.indices, x.data, y.data]) result = sparsify_raw(operator.add)(spenv, *argspecs) args_out, _ = result out, = argspecs_to_arrays(spenv, args_out) self.assertAllClose(out.todense(), x.todense() + y.todense()) def testSparseMul(self): x = BCOO.fromdense(jnp.arange(5)) y = BCOO.fromdense(2 * jnp.arange(5)) # Scalar multiplication out = self.sparsify(operator.mul)(x, 2.5) self.assertArraysEqual(out.todense(), x.todense() * 2.5) # Shared indices – requires lower level call argspecs = [ArgSpec(x.shape, 1, 0), ArgSpec(y.shape, 2, 0)] spenv = SparseEnv([x.indices, x.data, y.data]) result = sparsify_raw(operator.mul)(spenv, *argspecs) args_out, _ = result out, = argspecs_to_arrays(spenv, args_out) self.assertAllClose(out.todense(), x.todense() * y.todense()) def testSparseSubtract(self): x = BCOO.fromdense(3 * jnp.arange(5)) y = BCOO.fromdense(jnp.arange(5)) # Distinct indices out = self.sparsify(operator.sub)(x, y) self.assertEqual(out.nse, 8) # uses concatenation. self.assertArraysEqual(out.todense(), 2 * jnp.arange(5)) # Shared indices – requires lower level call argspecs = [ArgSpec(x.shape, 1, 0), ArgSpec(y.shape, 2, 0)] spenv = SparseEnv([x.indices, x.data, y.data]) result = sparsify_raw(operator.sub)(spenv, *argspecs) args_out, _ = result out, = argspecs_to_arrays(spenv, args_out) self.assertAllClose(out.todense(), x.todense() - y.todense()) def testSparseSum(self): x = jnp.arange(20).reshape(4, 5) xsp = BCOO.fromdense(x) def f(x): return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1)) result_dense = f(x) result_sparse = self.sparsify(f)(xsp) assert len(result_dense) == len(result_sparse) for res_dense, res_sparse in zip(result_dense, result_sparse): if isinstance(res_sparse, BCOO): res_sparse = res_sparse.todense() self.assertArraysAllClose(res_dense, res_sparse) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_dimensions={}_nbatch={}, ndense={}".format( jtu.format_shape_dtype_string(shape, np.float32), dimensions, n_batch, n_dense), "shape": shape, "dimensions": dimensions, "n_batch": n_batch, "n_dense": n_dense } for shape, dimensions in [ [(1, ), (0, )], [(1, ), (-1, )], [(2, 1, 4), (1, )], [(2, 1, 3, 1), (1, )], [(2, 1, 3, 1), (1, 3)], [(2, 1, 3, 1), (3, )], ] for n_batch in range(len(shape) + 1) for n_dense in range(len(shape) - n_batch + 1))) def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense): rng = jtu.rand_default(self.rng()) M_dense = rng(shape, np.float32) M_sparse = BCOO.fromdense(M_dense, n_batch=n_batch, n_dense=n_dense) func = self.sparsify(partial(lax.squeeze, dimensions=dimensions)) result_dense = func(M_dense) result_sparse = func(M_sparse).todense() self.assertAllClose(result_sparse, result_dense) def testSparseWhileLoop(self): def cond_fun(params): i, A = params return i < 5 def body_fun(params): i, A = params return i + 1, 2 * A def f(A): return lax.while_loop(cond_fun, body_fun, (0, A)) A = jnp.arange(4) out_dense = f(A) Asp = BCOO.fromdense(A) out_sparse = self.sparsify(f)(Asp) self.assertEqual(len(out_dense), 2) self.assertEqual(len(out_sparse), 2) self.assertArraysEqual(out_dense[0], out_dense[0]) self.assertArraysEqual(out_dense[1], out_sparse[1].todense()) def testSparseWhileLoopDuplicateIndices(self): def cond_fun(params): i, A, B = params return i < 5 def body_fun(params): i, A, B = params # TODO(jakevdp): track shared indices through while 
loop & use this # version of the test, which requires shared indices in order for # the nse of the result to remain the same. # return i + 1, A, A + B # This version is fine without shared indices, and tests that we're # flattening non-shared indices consistently. return i + 1, B, A def f(A): return lax.while_loop(cond_fun, body_fun, (0, A, A)) A = jnp.arange(4).reshape((2, 2)) out_dense = f(A) Asp = BCOO.fromdense(A) out_sparse = self.sparsify(f)(Asp) self.assertEqual(len(out_dense), 3) self.assertEqual(len(out_sparse), 3) self.assertArraysEqual(out_dense[0], out_dense[0]) self.assertArraysEqual(out_dense[1], out_sparse[1].todense()) self.assertArraysEqual(out_dense[2], out_sparse[2].todense()) def testSparsifyDenseXlaCall(self): # Test handling of dense xla_call within jaxpr interpreter. out = self.sparsify(jit(lambda x: x + 1))(0.0) self.assertEqual(out, 1.0) def testSparsifySparseXlaCall(self): # Test sparse lowering of XLA call def func(M): return 2 * M M = jnp.arange(6).reshape(2, 3) Msp = BCOO.fromdense(M) out_dense = func(M) out_sparse = self.sparsify(jit(func))(Msp) self.assertArraysEqual(out_dense, out_sparse.todense()) def testSparseForiLoop(self): def func(M, x): body_fun = lambda i, val: (M @ val) / M.shape[1] return lax.fori_loop(0, 2, body_fun, x) x = jnp.arange(5.0) M = jnp.arange(25).reshape(5, 5) M_bcoo = BCOO.fromdense(M) result_dense = func(M, x) result_sparse = self.sparsify(func)(M_bcoo, x) self.assertArraysAllClose(result_dense, result_sparse) def testSparseCondSimple(self): def func(x): return lax.cond(False, lambda x: x, lambda x: 2 * x, x) x = jnp.arange(5.0) result_dense = func(x) x_bcoo = BCOO.fromdense(x) result_sparse = self.sparsify(func)(x_bcoo) self.assertArraysAllClose(result_dense, result_sparse.todense()) def testSparseCondMismatchError(self): @self.sparsify def func(x, y): return lax.cond(False, lambda x: x[0], lambda x: x[1], (x, y)) x = jnp.arange(5.0) y = jnp.arange(5.0) x_bcoo = BCOO.fromdense(x) y_bcoo = BCOO.fromdense(y) func(x, y) # No error func(x_bcoo, y_bcoo) # No error with self.assertRaisesRegex( TypeError, "sparsified true_fun and false_fun output.*"): func(x_bcoo, y) def testToDense(self): M = jnp.arange(4) Msp = BCOO.fromdense(M) @self.sparsify def func(M): return todense(M) + 1 self.assertArraysEqual(func(M), M + 1) self.assertArraysEqual(func(Msp), M + 1) self.assertArraysEqual(jit(func)(M), M + 1) self.assertArraysEqual(jit(func)(Msp), M + 1) def testWeakTypes(self): # Regression test for https://github.com/google/jax/issues/8267 M = jnp.arange(12, dtype='int32').reshape(3, 4) Msp = BCOO.fromdense(M) self.assertArraysEqual( operator.mul(2, M), self.sparsify(operator.mul)(2, Msp).todense(), check_dtypes=True, )
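# ---------------------------------------------------------------------------
# Illustrative sketch (standalone, not part of SparsifyTest): the tests above
# exercise the sparsify transform from jax.experimental.sparse, which traces a
# function written for dense arrays and runs it on BCOO sparse operands.
# A minimal round trip mirroring testSparseMatmul:
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.arange(12.).reshape(3, 4)
M_sp = sparse.BCOO.fromdense(M)
v = jnp.ones(4)

def matvec(x, v):
  return x @ v

dense_result = matvec(M, v)
sparse_result = sparse.sparsify(matvec)(M_sp, v)  # dense rhs => dense result
assert jnp.allclose(dense_result, sparse_result)
# ---------------------------------------------------------------------------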
class Jax2TfTest(tf_test_util.JaxToTfTestCase): def test_basics(self): f_jax = lambda x: jnp.sin(jnp.cos(x)) _, res_tf = self.ConvertAndCompare(f_jax, 0.7) def test_input_output_naming(self): @jax2tf.convert def f(xs, y): return [jnp.add(x, y) for x in xs] @tf.function(autograph=False) def u(xs, y): xs = tf.nest.map_structure(tf.convert_to_tensor, xs) with tf.GradientTape() as tape: tf.nest.map_structure(tape.watch, xs) y = f(xs, y) tape.gradient(y, xs) return y cf = u.get_concrete_function([1., 2., 3.], 4.) g = cf.graph g.get_operation_by_name("jax2tf_arg_0") g.get_operation_by_name("jax2tf_arg_1") g.get_operation_by_name("jax2tf_arg_2") g.get_operation_by_name("jax2tf_arg_3") g.get_operation_by_name("jax2tf_out") g.get_operation_by_name("jax2tf_out_1") g.get_operation_by_name("jax2tf_out_2") with self.assertRaises(KeyError): g.get_operation_by_name("jax2tf_arg_4") with self.assertRaises(KeyError): g.get_operation_by_name("jax2tf_out_3") g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_0") g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_1") g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_2") g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_3") g.get_operation_by_name("jax2tf_vjp/jax2tf_out") g.get_operation_by_name("jax2tf_vjp/jax2tf_out_1") g.get_operation_by_name("jax2tf_vjp/jax2tf_out_2") g.get_operation_by_name("jax2tf_vjp/jax2tf_out_3") def test_pytrees(self): # Take and return pytrees def f_jax(x: Tuple[float, Dict[str, float]]) -> Tuple[float, Dict[str, float]]: x_a, x_dict = x return x_a * 2., {k: v * 3. for k, v in x_dict.items()} x = (.7, {"a": .8, "b": .9}) self.ConvertAndCompare(f_jax, x) def test_variable_input(self): f_jax = lambda x: jnp.sin(jnp.cos(x)) f_tf = jax2tf.convert(f_jax) v = tf.Variable(0.7, dtype=jax2tf.dtype_of_val(0.7)) self.assertIsInstance(f_tf(v), tf.Tensor) self.assertAllClose(f_jax(0.7), f_tf(v)) def test_jit(self): f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x))) self.ConvertAndCompare(f_jax, 0.7) def test_nested_jit(self): f_jax = jax.jit(lambda x: jnp.sin(jax.jit(jnp.cos)(x))) f_tf = jax2tf.convert(f_jax) np.testing.assert_allclose(f_jax(0.7), f_tf(0.7)) def test_nested_jit_is_compiled(self): # Check that nested jax.jit are compiled with tf.function(jit_compile=True) # We do this by looking for the _XlaMustCompile attribute in the function graph def has_xla_must_compile(f_tf, x): f_conc = tf.function(f_tf, autograph=True).get_concrete_function(tf.convert_to_tensor(x)) for n in f_conc.graph._nodes_by_id.values(): try: n.get_attr("_XlaMustCompile") return True except ValueError: continue return False x = np.array(0.7) f_no_jit = lambda x: x self.assertFalse(has_xla_must_compile(jax2tf.convert(f_no_jit), x)) f_jit = lambda x: jax.jit(jnp.sin)(x) # TODO(b/207464757): TF compilation is disabled self.assertFalse(has_xla_must_compile(jax2tf.convert(f_jit), x)) def test_converts_jax_arrays(self): f_tf = tf.function(lambda x: x) self.assertEqual(f_tf(jnp.zeros([])).numpy(), 0.) self.assertEqual(f_tf(jnp.ones([])).numpy(), 1.) f_tf = tf.function(lambda x: x + x) self.assertEqual(f_tf(jnp.ones([])).numpy(), 2.) # Test with ShardedDeviceArray. 
n = jax.local_device_count() mk_sharded = lambda f: jax.pmap(lambda x: x)(f([n])) f_tf = tf.function(lambda x: x) self.assertAllClose(f_tf(mk_sharded(jnp.zeros)).numpy(), jnp.zeros([n])) self.assertAllClose(f_tf(mk_sharded(jnp.ones)).numpy(), jnp.ones([n])) @jtu.skip_on_devices("gpu") def test_bfloat16_passed_by_tf(self): f_jax = lambda a, b: a + b f_tf = tf.function(jax2tf.convert(f_jax), input_signature=[tf.TensorSpec([512, 512], tf.bfloat16), tf.TensorSpec([512, 512], tf.bfloat16)]) self.assertIsNotNone(f_tf.get_concrete_function()) @jtu.skip_on_devices("gpu") def test_bfloat16_returned_by_jax(self): f_jax = lambda a, b: (a + b).astype(jnp.bfloat16) f_tf = jax2tf.convert(f_jax) self.assertEqual(f_tf(1., 2.).dtype, tf.bfloat16) @jtu.skip_on_devices("gpu") def test_bfloat16_tf_grad(self): f_jax = lambda a, b: a + b def _tf_grad(a, b): with tf.GradientTape() as tape: tape.watch(a) result = jax2tf.convert(f_jax)(a, b) return result, tape.gradient(result, a) f_tf = tf.function(_tf_grad, input_signature=[tf.TensorSpec([512, 512], tf.bfloat16), tf.TensorSpec([512, 512], tf.bfloat16)]) self.assertIsNotNone(f_tf.get_concrete_function()) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"_dtype={dtype.__name__}_function={with_function}", dtype=dtype, with_function=with_function) for dtype in [np.int64, np.float64] for with_function in [True, False])) def test_converts_64bit(self, dtype=np.int64, with_function=False): if not config.jax_enable_x64: self.skipTest("requires x64 mode") big_const = np.full((5,), 2 ** 33, dtype=dtype) self.ConvertAndCompare(jnp.sin, big_const) f_conv = jax2tf.convert(jnp.sin) if with_function: f_conv = tf.function(f_conv) # We check also when we pass tf.Variable or tf.Tensor into the # converted function self.assertAllClose(jnp.sin(big_const), f_conv(tf.Variable(big_const))) self.assertAllClose(jnp.sin(big_const), f_conv(tf.constant(big_const))) def test_64bit_behavior_enable_x64(self): if not config.jax_enable_x64: self.skipTest("requires x64 mode") # JAX and TF have different default float types if JAX_ENABLE_X64=1 self.assertEqual(tf.math.sin(0.7).dtype, tf.float32) self.assertEqual(jnp.sin(0.7).dtype, jnp.float64) # jax2tf.convert has the same behavior as JAX self.assertEqual(jax2tf.convert(jnp.sin)(0.7).dtype, tf.float64) def test_64bit_behavior_not_enable_x64(self): if config.jax_enable_x64: self.skipTest("requires not x64 mode") # JAX and TF have same default float types if JAX_ENABLE_X64=1 self.assertEqual(tf.math.sin(0.7).dtype, tf.float32) self.assertEqual(jnp.sin(0.7).dtype, jnp.float32) # Except that JAX forces values to 32-bit self.assertEqual(jnp.sin(np.float64(0.7)).dtype, jnp.float32) # jax2tf.convert has the same behavior as JAX self.assertEqual(jax2tf.convert(jnp.sin)(0.7).dtype, tf.float32) self.assertEqual(jax2tf.convert(jnp.sin)(np.float64(0.7)).dtype, tf.float32) def test_function(self): f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x))) self.ConvertAndCompare(f_jax, 0.7) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients_disabled(self, with_function=False): f_tf = jax2tf.convert(jnp.tan, with_gradient=False) if with_function: f_tf = tf.function(f_tf, autograph=False) x = tf.ones([]) # With tf.function the error is raised when we evaluate f_tf(x), in # eager mode when we evaluate tape.gradient(y, x) with self.assertRaisesRegex(LookupError, "Gradient explicitly disabled.*The jax2tf-converted 
function does not support gradients"): with tf.GradientTape() as tape: tape.watch(x) y = f_tf(x) _ = tape.gradient(y, x) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients(self, with_function=True): def f(x, y): return x * x, x * y f_tf = jax2tf.convert(f, with_gradient=True) if with_function: f_tf = tf.function(f_tf, autograph=False) default_float_type = jax2tf.dtype_of_val(4.) x = tf.Variable(4., dtype=jax2tf.dtype_of_val(4.)) y = tf.Variable(5., dtype=default_float_type) with tf.GradientTape(persistent=True) as tape: u, v = f_tf(x, y) self.assertAllClose(2. * 4., tape.gradient(u, x)) self.assertAllClose(0., tape.gradient(u, y)) self.assertAllClose(5., tape.gradient(v, x)) self.assertAllClose(4., tape.gradient(v, y)) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients_pytree(self, with_function=True): def f(xy: Tuple[float, float]) -> Dict[str, float]: x, y = xy return dict(one=x * x, two=x * y) f_tf = jax2tf.convert(f, with_gradient=True) if with_function: f_tf = tf.function(f_tf, autograph=False) default_float_dtype = jax2tf.dtype_of_val(4.) x = tf.Variable(4., dtype=default_float_dtype) y = tf.Variable(5., dtype=default_float_dtype) with tf.GradientTape(persistent=True) as tape: uv = f_tf((x, y)) self.assertAllClose(2. * 4., tape.gradient(uv["one"], x)) self.assertAllClose(0., tape.gradient(uv["one"], y)) self.assertAllClose(5., tape.gradient(uv["two"], x)) self.assertAllClose(4., tape.gradient(uv["two"], y)) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients_with_ordered_dict_input(self, with_function=True): def f(inputs): out = 0.0 for v in inputs.values(): out += jnp.sum(v) return out f_tf = jax2tf.convert(f, with_gradient=True) if with_function: f_tf = tf.function(f_tf, autograph=False) default_float_type = jax2tf.dtype_of_val(4.) inputs = OrderedDict() x = tf.Variable([4.], dtype=default_float_type) y = tf.Variable([4., 5.], dtype=default_float_type) inputs = OrderedDict() inputs['r'] = x inputs['d'] = y with tf.GradientTape(persistent=True) as tape: u = f_tf(inputs) self.assertAllClose(np.array([1.]), tape.gradient(u, x).numpy()) self.assertAllClose(np.array([1., 1.]), tape.gradient(u, y).numpy()) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients_with_custom_jvp(self, with_function=True): """Check gradients, for a function with custom JVP.""" @jax.custom_jvp def f(x): return x * x @f.defjvp def f_jvp(primals, tangents): # 3 * x * x_t x, = primals x_dot, = tangents primal_out = f(x) tangent_out = 3. * x * x_dot return primal_out, tangent_out self.assertAllClose(4. * 4., f(4.)) self.assertAllClose(3. * 4., jax.grad(f)(4.)) f_tf = jax2tf.convert(f, with_gradient=True) if with_function: f_tf = tf.function(f_tf, autograph=False) self.assertAllClose(4. * 4., f_tf(4.)) x = tf.Variable(4., dtype=jax2tf.dtype_of_val(4.)) with tf.GradientTape() as tape: tape.watch(x) y = f_tf(x) self.assertAllClose(4. * 4., y) self.assertAllClose(3. 
* 4., tape.gradient(y, x)) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients_with_custom_vjp(self, with_function=True): """Check gradients, for a function with custom VJP.""" @jax.custom_vjp def f(x): return x * x # f_fwd: a -> (b, residual) def f_fwd(x): return f(x), 3. * x # f_bwd: (residual, CT b) -> [CT a] def f_bwd(residual, ct_b): return residual * ct_b, f.defvjp(f_fwd, f_bwd) self.assertAllClose(4. * 4., f(4.)) self.assertAllClose(3. * 4., jax.grad(f)(4.)) f_tf = jax2tf.convert(f, with_gradient=True) if with_function: f_tf = tf.function(f_tf, autograph=False) self.assertAllClose(4. * 4., f_tf(4.)) x = tf.Variable(4., dtype=jax2tf.dtype_of_val(4.)) with tf.GradientTape() as tape: tape.watch(x) y = f_tf(x) self.assertAllClose(4. * 4., y) self.assertAllClose(3. * 4., tape.gradient(y, x)) def test_gradient_with_float0_intermediate(self): # Gradient over integer-argument functions def f(x, y): # x is an int, y is a float return 2 * x + y def g(x): # x: f32 return 2. * f(3 * x.astype("int32"), x * 4.) x = 2. grad_g = jax.grad(g) self.ConvertAndCompare(grad_g, x) def test_gradient_with_float0_result(self): # Gradient over integer-argument functions, with float0 result def f(x, y): # x is an int, y is a float return 2 * x + y def g(x): # x: i32 return jnp.sum(2. * f(3 * x, 4. * jnp.array(x, jnp.dtype("float32")))) grad_g = jax.grad(g, allow_int=True) x = 2 d_dx_jax = grad_g(x) d_dx_tf = jax2tf.convert(grad_g)(x) self.assertEqual(d_dx_jax.dtype, dtypes.float0) self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), np.int32), d_dx_tf.numpy()) shape = (3, 4) x = np.ones(shape, dtype=np.int32) d_dx_jax = grad_g(x) d_dx_tf = jax2tf.convert(grad_g)(x) self.assertEqual(d_dx_jax.dtype, dtypes.float0) self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), np.int32), d_dx_tf.numpy()) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients_unused_argument_readme(self, with_function=True): # x2 and x3 are not used. x3 has integer type. def fn(x0, x1, x2, x3): return x0 * 0. + x2 * 2. xs = [tf.Variable(x) for x in [10., 11., 12., 13]] with tf.GradientTape(persistent=True) as tape: res = fn(*xs) g_tf_native = tape.gradient(res, xs) self.assertAllClose(g_tf_native[0].numpy(), np.float32(0.)) self.assertIsNone(g_tf_native[1]) self.assertAllClose(g_tf_native[2].numpy(), np.float32(2.)) self.assertIsNone(g_tf_native[3]) g_tf_native_0 = tape.gradient(res, xs, unconnected_gradients=tf.UnconnectedGradients.ZERO) self.assertAllClose(g_tf_native_0[0].numpy(), np.float32(0.)) self.assertAllClose(g_tf_native_0[1].numpy(), np.float32(0.)) self.assertAllClose(g_tf_native_0[2].numpy(), np.float32(2.)) self.assertAllClose(g_tf_native_0[3].numpy(), np.int32(0)) # Now with jax2tf.convert with tf.GradientTape(persistent=True) as tape: conv_fn = jax2tf.convert(fn, with_gradient=True) if with_function: conv_fn = tf.function(conv_fn, autograph=False) res = conv_fn(*xs) g_jax2tf = tape.gradient(res, xs) # Returns: 0., 0., 2., None # Note that the gradient for x1 is 0. 
self.assertAllClose(g_jax2tf[0].numpy(), np.float32(0.)) self.assertAllClose(g_jax2tf[1].numpy(), np.float32(0.)) self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.)) self.assertIsNone(g_jax2tf[3]) g_jax2tf = tape.gradient(res, xs, unconnected_gradients=tf.UnconnectedGradients.ZERO) self.assertAllClose(g_jax2tf[0].numpy(), np.float32(0.)) self.assertAllClose(g_jax2tf[1].numpy(), np.float32(0.)) self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.)) self.assertAllClose(g_jax2tf[3].numpy(), np.int32(0)) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients_int_argument(self, with_function=True): # https://github.com/google/jax/issues/6975 # Also issue #6975. # An expanded version of test_gradients_unused_argument state = dict( float_used=np.array([0.7, 0.9], dtype=np.float32), float_passthrough=np.float16(1.), float_unused=np.array([1.1, 2.2, 3.3], dtype=np.float32), int_used=np.int16(5), int_passthrough=np.int8(7), int_unused=np.array([1, 2, 3], dtype=np.uint32), bool_used=np.array([True, False, False, True], dtype=np.bool_), bool_passthrough=np.array([True, False, False, True, False], dtype=np.bool_), bool_unused=np.array([[True, False], [False, True]], dtype=np.bool_), ) def jax_f(state): res = dict(state, float_used=2. * state["float_used"], int_used=3 * state["int_used"], bool_used=(state["bool_used"] == state["bool_used"])) del res["float_unused"] del res["int_unused"] del res["bool_unused"] return res args = (state,) res_jax = jax_f(*args) # Native JAX AD vjp_jax_fun, args_vjp = tf_test_util.TransformJaxVJP(jax_f, args, res_jax) grad_jax, = vjp_jax_fun(*args_vjp) def compare_with_overrides(*, what, expected, **expected_overrides): what_keys = set(what.keys()) expected_keys = set(expected.keys()) self.assertEqual(what_keys, expected_keys) for k, w in what.items(): e = expected[k] if k in expected_overrides: if expected_overrides[k] == "ZERO": e = np.zeros_like(w) elif expected_overrides[k] == "ZERO_INT32": e = np.zeros(np.shape(w), dtype=np.int32) elif expected_overrides[k] == "ONE": e = np.ones_like(w) else: e = expected_overrides[k] if e is None: self.assertIsNone(w, msg=k) else: self.assertIsNotNone(w, msg=k) w = w.numpy() if isinstance(w, tf.Tensor) else e e = e.numpy() if isinstance(e, tf.Tensor) else e try: self.assertAllClose(e, w, err_msg=k) except: print(f"Failed at {k}") raise # compare_with_overrides(g_jax, {}, # bool_passthrough=np.zeros(state["bool_passthrough"].shape, dtype=dtypes.float0), # bool_unused=np.zeros(state["bool_unused"].shape, dtype=dtypes.float0), # bool_used=np.zeros(state["bool_used"].shape, dtype=dtypes.float0), # float_passthrough=np.ones_like(state["float_passthrough"]), # float_unused=np.zeros_like(state["float_unused"]), # float_used=np.ones_like(state["float_used"]) * np.array(2., dtype=state["float_used"].dtype), # int_passthrough=np.zeros(state["int_passthrough"].shape, dtype=dtypes.float0), # int_unused=np.zeros(state["int_unused"].shape, dtype=dtypes.float0), # int_used=np.zeros(state["int_used"].shape, dtype=dtypes.float0)) # Now native TF gradients, only to test how native TF AD works _, (grad_tf_0,) = tf_test_util.ComputeTfValueAndGrad( jax_f, args, unconnected_gradients=tf.UnconnectedGradients.ZERO) compare_with_overrides(what=grad_tf_0, expected=grad_jax, float_unused="ZERO", bool_used="ZERO", bool_passthrough="ONE", bool_unused="ZERO", int_used="ZERO", int_passthrough="ONE", int_unused="ZERO") _, 
(grad_tf_None,) = tf_test_util.ComputeTfValueAndGrad( jax_f, args, unconnected_gradients=tf.UnconnectedGradients.NONE) compare_with_overrides(what=grad_tf_None, expected=grad_tf_0, float_unused=None, int_used=None, int_unused=None, bool_used=None, bool_unused=None) f_tf_jax = jax2tf.convert(jax_f) if with_function: f_tf_jax = tf.function(f_tf_jax, autograph=False) _, (grad_tf_jax_0,) = tf_test_util.ComputeTfValueAndGrad(f_tf_jax, args) # Same results as TF native AD with tf.UnconnectedGradients.ZERO compare_with_overrides(what=grad_tf_jax_0, expected=grad_tf_0, int_passthrough="ZERO", bool_passthrough="ZERO") _, (grad_tf_jax_None,) = tf_test_util.ComputeTfValueAndGrad( f_tf_jax, args, unconnected_gradients=tf.UnconnectedGradients.NONE) compare_with_overrides(what=grad_tf_jax_None, expected=grad_tf_0, int_used=None, int_passthrough=None, int_unused=None, bool_unused=None, bool_used=None, bool_passthrough=None) # Not convert the JAX gradient function tf_vjp_jax_fun = jax2tf.convert(vjp_jax_fun) grad_tf_vjp_jax, = tf_vjp_jax_fun(*args_vjp) compare_with_overrides(what=grad_tf_vjp_jax, expected=grad_tf_0, bool_passthrough="ZERO_INT32", bool_unused="ZERO_INT32", bool_used="ZERO_INT32", int_passthrough="ZERO_INT32", int_unused="ZERO_INT32", int_used="ZERO_INT32") def test_readme_gradient_int(self): x = np.array(2, dtype=np.int16) def f_jax(x): # x: int16 return x.astype(np.float32) * 2. print(jax.grad(f_jax, allow_int=True)(x)) # returns a special `float0`: array((b'',), dtype=[('float0', 'V')]) print(jax2tf.convert(jax.grad(f_jax, allow_int=True))(x)) # returns a 0 with same shape as x, but with dtype int32 def f_tf(x): # x: int16 return tf.cast(x, tf.float32) * 2. xv = tf.Variable(x) with tf.GradientTape(persistent=True) as tape: print(tape.gradient(f_tf(xv), xv)) # returns None print(tape.gradient(f_tf(xv), xv, unconnected_gradients=tf.UnconnectedGradients.ZERO)) # returns 0 with the same shape and dtype as x def test_convert_argument_non_callable_error(self): with self.assertRaisesRegex(TypeError, "Expected a callable value"): jax2tf.convert(5.) def test_convert_argument_non_tensor_error(self): with self.assertRaisesRegex(TypeError, "Argument.*should be NumPy array"): jax2tf.convert(lambda x: x)(lambda y: y) def test_argument_eager_tensor(self): x = jax2tf.convert(jnp.sin)(1.) jax2tf.convert(jnp.cos)(x) # No error def test_checkpoint_wrapper_types(self): m = tf.Module() m.a = [tf.Module(), tf.Module()] m.b = (tf.Module(), tf.Module()) m.c = {'a': tf.Module(), 'b': tf.Module()} self.assertNotEqual(type(m.a), list) self.assertNotEqual(type(m.b), tuple) self.assertNotEqual(type(m.c), dict) self.assertLen(jax.tree_leaves(m.a), 2) self.assertLen(jax.tree_leaves(m.b), 2) self.assertLen(jax.tree_leaves(m.c), 2) def test_custom_jvp(self): """Conversion of function with custom JVP""" @jax.custom_jvp def f(x): return x * x @f.defjvp def f_jvp(primals, tangents): x, = primals x_dot, = tangents primal_out = f(x) tangent_out = 3. * x * x_dot return primal_out, tangent_out arg = 0.7 self.TransformConvertAndCompare(f, arg, None) self.TransformConvertAndCompare(f, arg, "jvp") self.TransformConvertAndCompare(f, arg, "vmap") self.TransformConvertAndCompare(f, arg, "jvp_vmap") self.TransformConvertAndCompare(f, arg, "grad") self.TransformConvertAndCompare(f, arg, "grad_vmap") def test_custom_vjp(self): """Conversion of function with custom VJP""" @jax.custom_vjp def f(x): return x * x # f_fwd: a -> (b, residual) def f_fwd(x): return f(x), 3. 
* x # f_bwd: (residual, CT b) -> [CT a] def f_bwd(residual, ct_b): return residual * ct_b, f.defvjp(f_fwd, f_bwd) arg = 0.7 self.TransformConvertAndCompare(f, arg, None) self.TransformConvertAndCompare(f, arg, "vmap") self.TransformConvertAndCompare(f, arg, "grad") self.TransformConvertAndCompare(f, arg, "grad_vmap") @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"_{flavor}", flavor=flavor) for flavor in ["old", "new"])) def test_remat(self, flavor="old"): def f(x1): x2 = jnp.sin(x1) x3 = jnp.sin(x2) x4 = jnp.sin(x3) return x4 remat_f = jax.remat(f) if flavor == "old" else ad_checkpoint.checkpoint(f) # The computation of grad_f computes "sin" 5 times, 3 for the forward pass # and then to rematerialize "x2" and "x3" in the backward pass. arg = np.array(3.) # Check that we have a Sin under a conditional f_tf = tf.function(jax2tf.convert(jax.grad(remat_f)), autograph=False) f_tf_graph = f_tf.get_concrete_function(arg).graph.as_graph_def() if flavor == "old": raise unittest.SkipTest("TODO: CSE widget not yet implemented for old-style remat") if jax.config.jax_remat_opt_barrier: self.assertRegex( str(f_tf_graph), r"remat_checkpoint_/XlaOptimizationBarrier") elif config.jax_experimental_name_stack: self.assertRegex(str(f_tf_graph), r'transpose/jax2tf_f_/jvp/checkpoint/remat_checkpoint_/cond/branch_1_fun/Sin') else: self.assertRegex(str(f_tf_graph), r'remat_checkpoint_/switch_case/indexed_case/Sin') def test_remat_free_var(self): def f(x): y = 2 * x @ad_checkpoint.checkpoint def g(): return y return g() arg = 3. self.TransformConvertAndCompare(f, arg, None) self.TransformConvertAndCompare(f, arg, "grad") def test_convert_nullary_func(self): # Even nullary functions are converted to TF (as opposed to constant-folded # in JAX prior to conversion). def f_jax(): return jnp.sin(1.) f_tf = tf.function(jax2tf.convert(f_jax), autograph=False) f_tf_graph = f_tf.get_concrete_function().graph.as_graph_def() self.assertIn('op: "Sin"', str(f_tf_graph)) def test_convert_of_nested_independent_jit(self): def func(x): def inner1(y): return x + y # The JIT does not have data dependency return jax.jit(inner1)(1.) jax2tf.convert(func)(2.) def test_convert_of_nested_dependent_jit(self): def func(x): def inner1(y): return x + y # The JIT does have data dependency return jax.jit(inner1)(x) jax2tf.convert(func)(2.) # No error def test_nested_convert_error(self): def outer(y): return jax2tf.convert(jnp.sin)(y) # Inner convert takes tracer args with self.assertRaisesRegex( ValueError, "convert must be used outside all JAX transformations"): jax2tf.convert(outer)(np.ones((4, ))) def test_nested_convert_error_non_tracer(self): """The inner convert takes non-tracer arguments""" def outer(y): sin_1 = jax2tf.convert(jnp.sin)(1.) # Inner convert takes non-tracer arg return y + sin_1 with self.assertRaisesRegex( ValueError, "convert must be used outside all JAX transformations"): jax2tf.convert(outer)(2.) 
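# ---------------------------------------------------------------------------
# Illustrative sketch (standalone, not a method of Jax2TfTest): the custom_vjp
# pattern exercised by test_custom_vjp above, in pure JAX. The custom rule
# makes grad(f) return 3*x rather than the true derivative 2*x, which is the
# value the converted TF gradient is compared against.
import jax

@jax.custom_vjp
def f(x):
  return x * x

def f_fwd(x):
  return f(x), 3. * x          # (primal output, residual)

def f_bwd(residual, ct):
  return (residual * ct,)      # one cotangent per primal argument

f.defvjp(f_fwd, f_bwd)

assert f(4.) == 16.
assert jax.grad(f)(4.) == 12.
# ---------------------------------------------------------------------------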
@parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"_{transform}", transform=transform) for transform in ["jit", "jvp", "grad", "vmap"])) def test_convert_under_transform_error(self, transform="vmap"): def outer(y): return jax2tf.convert(jnp.sin)(y) # Inner convert takes tracer args with self.assertRaisesRegex( ValueError, "convert must be used outside all JAX transformations"): self.TransformConvertAndCompare(outer, np.ones((4,)), transform) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"_{transform}", transform=transform) for transform in ["jit", "jvp", "grad", "vmap"])) def test_convert_under_transform_error_non_tracer(self, transform="vmap"): def outer(y): sin_1 = jax2tf.convert(jnp.sin)(1.) # Inner convert takes non-tracer arg return y + sin_1 with self.assertRaisesRegex( ValueError, "convert must be used outside all JAX transformations"): self.TransformConvertAndCompare(outer, np.ones((4,)), transform) def test_name_scope(self): @tf.function(autograph=False) def run(): @jax.named_call def my_test_function(x): return x * x def caller(x): return my_test_function(jnp.sin(x)) out = jax2tf.convert(caller, with_gradient=False)(2.) # When we use `with_gradient=False` the raw output of `caller` is passed # through a `tf.raw_ops.PreventGradient` and a `tf.identity`, clobbering # the name scope of the `mul` op. We need to get the grandparent of the # `out` tensor to see the name scope of the result of the `mul`. grandparent_op = out.op.inputs[0].op.inputs[0] self.assertIn("my_test_function", grandparent_op.name) return out run() def test_bfloat16_constant(self): # Re: https://github.com/google/jax/issues/3942 def jax_fn_scalar(x): x = x.astype(jnp.bfloat16) x *= 2. return x def jax_fn_array(x): x = x.astype(jnp.bfloat16) x *= np.array([1.5, 2.5, 3.5], jnp.bfloat16) return x tf_fn_scalar = jax2tf.convert(jax_fn_scalar) self.assertAllClose(tf_fn_scalar(1.375).numpy(), jnp.bfloat16(2.750)) tf_fn_array = jax2tf.convert(jax_fn_array) self.assertAllClose( tf_fn_array(np.array([3, 4, 5])), np.array([4.5, 10, 17.5], jnp.bfloat16)) def test_shared_constants(self): # Check that the constants are shared properly in converted functions # See https://github.com/google/jax/issues/7992. const = np.ones((16, 16)) def f(x): return x + const + const + const + const f_tf_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f), const) self.assertEqual(f_tf_nr_consts, 1) def test_shared_constants_under_cond(self): # Check that the constants are shared properly in converted functions # See https://github.com/google/jax/issues/7992. const = np.arange(256, dtype=np.float32) x = np.ones((256,), dtype=np.float32) def f1(x): return lax.cond(x[0] >= 0., lambda x: x + const, lambda x: x * const, x) + const def f2(x): return f1(x) + const # The extra const should not cost anything f1_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f1), x) f2_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f2), x) self.assertEqual(f1_nr_consts, f2_nr_consts) def test_shared_constants_under_scan(self): # See https://github.com/google/jax/issues/7992. 
const = np.arange(256, dtype=np.float32) xs = np.ones((8, 256), dtype=np.float32) def f1(xs): res, _ = lax.scan(lambda carry, x: (carry + x + const, None), np.zeros((256,), dtype=np.float32), xs) return res def f2(xs): return f1(xs) + const # The extra const should not be saved f1_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f1), xs) f2_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f2), xs) self.assertEqual(f1_nr_consts, f2_nr_consts) def test_shared_constants_under_jit(self): # We do not share constants under jit. const = np.ones((16, 16)) @jax.jit def g_jit(x): return x * const def f(x): return g_jit(x) + const + const f_tf_graph_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f), const) # TODO(b/207464757): TF compilation is disabled self.assertEqual(f_tf_graph_nr_consts, 1) def test_weak_types(self): mul = jax.jit(jnp.multiply) # The value `2` here should be weakly typed, and should not lead to # promotion. tf_fn = jax2tf.convert(lambda x: mul(x, 2.)) self.assertAllClose(tf_fn(tf.constant(1.375, tf.bfloat16)).numpy(), jnp.bfloat16(2.750)) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_kwargs(self, with_function=True): # Re: https://github.com/google/jax/issues/6791 def f_jax(*, x): return jnp.sum(x) f_tf = jax2tf.convert(f_jax) if with_function: f_tf = tf.function(f_tf) self.assertAllClose( f_tf(x=np.zeros(3, dtype=np.float32)), # Call with kwargs. np.zeros((), dtype=np.float32)) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_grad_kwargs(self, with_function=False): # Re: https://github.com/google/jax/issues/6791 x = (np.zeros(3, dtype=np.float32), np.zeros(4, dtype=np.float32)) def f_jax(*, x=(1., 2.)): return jnp.sum(x[0]) + 2. * jnp.sum(x[1]) f_tf = jax2tf.convert(f_jax) if with_function: f_tf = tf.function(f_tf) xv = tf.nest.map_structure(tf.Variable, x) with tf.GradientTape() as tape: res = f_tf(x=xv) grad_tf = tape.gradient(res, xv) self.assertAllClose((np.full_like(x[0], fill_value=1.), np.full_like(x[1], fill_value=2.)), (grad_tf[0].numpy(), grad_tf[1].numpy())) def test_enable_xla(self): # Tests that enable_xla flag is properly scoped to a conversion. 
def fun(x): # lax.reduce is unlikely to ever be convertible with enable_xla=False return lax.reduce(x, np.float32(0), lambda v, acc: v + acc, dimensions=(0, 1)) tf_fun_with_xla = jax2tf.convert(fun, enable_xla=True) tf_fun_without_xla = jax2tf.convert(fun, enable_xla=False) x = np.ones((2, 3), dtype=np.float32) self.assertAllClose(fun(x), tf_fun_with_xla(x)) with self.assertRaisesRegex(NotImplementedError, "Call to reduce cannot be converted with enable_xla=False"): tf_fun_without_xla(x) # Now in reverse order (we had bugs with the management of enable_xla global) tf_fun2_without_xla = jax2tf.convert(lambda x: fun(x), enable_xla=False) tf_fun2_with_xla = jax2tf.convert(lambda x: fun(x), enable_xla=True) with self.assertRaisesRegex(NotImplementedError, "Call to reduce cannot be converted with enable_xla=False"): tf_fun2_without_xla(x) self.assertAllClose(fun(x), tf_fun2_with_xla(x)) def test_device_array_arg(self): self.ConvertAndCompare(jnp.sin, jnp.zeros((2, 3), jnp.float32)) def test_randint(self): def randint(): return jax.random.randint( jax.random.PRNGKey(42), shape=(), minval=0, maxval=1) self.ConvertAndCompare(randint) def test_op_metadata_simple(self): self.skipTest("include_xla_op_metadata not yet enabled") # A simple example # The user_frame is used to compute line numbers for ops in the test. user_frame = source_info_util.user_frame(source_info_util.current()) def f_simple(x): return jnp.sin(x) x = np.ones((2, 3), np.float32) self.CheckOpMetadata( f_simple, x, [tf_test_util.OpMetadataGraph(tf_type="Sin", source_file=__file__, source_line=user_frame.line_num + 2, op_name="jax2tf(f_simple)/sin", op_type="sin") ] ) def test_op_metadata_sub_jit(self): self.skipTest("include_xla_op_metadata not yet enabled") # Calling a jitted-function # The user_frame is used to compute line numbers for ops in the test. user_frame = source_info_util.user_frame(source_info_util.current()) def f_callee(x): return jnp.cos(x) def f_caller(x): y = jnp.tanh(x) z = jax.jit(f_callee)(y) return jnp.sin(z) x = np.ones((2, 3), np.float32) self.CheckOpMetadata( f_caller, x, [tf_test_util.OpMetadataGraph(tf_type="Tanh", source_file=__file__, source_line=user_frame.line_num + 4, op_name="jax2tf(f_caller)/tanh", op_type="tanh"), tf_test_util.OpMetadataGraph(tf_type="Cos", source_file=__file__, source_line=user_frame.line_num + 2, op_name="jax2tf(f_caller)/jit(f_callee)/cos", op_type="cos"), tf_test_util.OpMetadataGraph(tf_type="Sin", source_file=__file__, source_line=user_frame.line_num + 6, op_name="jax2tf(f_caller)/sin", op_type="sin"), ] ) def test_op_metadata_named(self): self.skipTest("include_xla_op_metadata not yet enabled") # Calling a jax.named_call # The user_frame is used to compute line numbers for ops in the test. 
user_frame = source_info_util.user_frame(source_info_util.current()) def f_callee(x): return jnp.cos(x) def f_caller(x): y = jnp.tanh(x) z = jax.named_call(f_callee, name="callee")(y) return jnp.sin(z) x = np.ones((2, 3), np.float32) self.CheckOpMetadata( f_caller, x, [tf_test_util.OpMetadataGraph(tf_type="Tanh", source_file=__file__, source_line=user_frame.line_num + 4, op_name="jax2tf(f_caller)/tanh", op_type="tanh"), tf_test_util.OpMetadataGraph(tf_type="Cos", source_file=__file__, source_line=user_frame.line_num + 2, op_name="jax2tf(f_caller)/named(callee)/cos", op_type="cos"), tf_test_util.OpMetadataGraph(tf_type="Sin", source_file=__file__, source_line=user_frame.line_num + 6, op_name="jax2tf(f_caller)/sin", op_type="sin"), ] ) def test_op_metadata_while_and_cond(self): self.skipTest("include_xla_op_metadata not yet enabled") # An example with while and cond # The user_frame is used to compute line numbers for ops in the test. user_frame = source_info_util.user_frame(source_info_util.current()) def f_while_cond(x): def body_fun(i_acc): i, acc = i_acc return (i + 1, (jnp.cos(acc) + lax.cond(jnp.mod(i, 2) == 0, lambda acc: jnp.sin(acc), lambda acc: acc, acc))) _, acc = lax.while_loop( lambda i_acc: i_acc[0] <= 5, body_fun, (0, x)) return acc x = np.ones((2, 3), np.float32) self.CheckOpMetadata( f_while_cond, x, [tf_test_util.OpMetadataGraph(tf_type="Cos", source_file=__file__, source_line=user_frame.line_num + 5, op_name="jax2tf(f_while_cond)/while/body/cos", op_type="cos"), tf_test_util.OpMetadataGraph(tf_type="Sin", source_file=__file__, source_line=user_frame.line_num + 7, op_name="jax2tf(f_while_cond)/while/body/branch_1_fun/sin", op_type="sin"), tf_test_util.OpMetadataGraph(tf_type="FloorMod", source_file=__file__, source_line=user_frame.line_num + 6, op_name="jax2tf(f_while_cond)/while/body/rem", op_type="rem"), ] ) def test_op_metadata_batched_while(self): self.skipTest("include_xla_op_metadata not yet enabled") # An example with while and cond # The user_frame is used to compute line numbers for ops in the test. user_frame = source_info_util.user_frame(source_info_util.current()) @jax.vmap def f_while(x): def body_fun(carry): new_carry = jnp.sin(carry) # We look for "sin" in the graph return new_carry _, carry = lax.while_loop( lambda carry: jnp.all(carry <= x), # We look for "le" in the graph body_fun, x) return carry shape = (3, 2) x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) jax_comp = jax.xla_computation(f_while)(x) backend = jax._src.lib.xla_bridge.get_backend() modules = backend.compile(jax_comp).hlo_modules() jax_opt_hlo = modules[0].to_string() print(f"JAX OPT HLO = {jax_opt_hlo}") self.CheckOpMetadata( f_while, x, [tf_test_util.OpMetadataGraph(tf_type="Sin", source_file=__file__, source_line=user_frame.line_num + 4, op_name="jax2tf(f_while)/while/body/sin", op_type="sin"), tf_test_util.OpMetadataGraph(tf_type="LessEqual", source_file=__file__, source_line=user_frame.line_num + 8, op_name="jax2tf(f_while)/while/body_pred/le", op_type="le"), ] ) def test_op_metadata_disabled(self): self.skipTest("include_xla_op_metadata not yet enabled") def f_simple(x): return jnp.sin(x) x = np.ones((2, 3), np.float32) self.CheckOpMetadata( f_simple, x, [], include_xla_op_metadata=False )
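# --- Illustrative sketch, not part of the original test suite ---
# Several tests above inspect the serialized TF graph of a converted function
# (looking for a Sin op under remat, or for XLA op metadata). This is the bare
# pattern they rely on: build a concrete function and search its GraphDef. As in
# those tests, it assumes jax2tf's graph lowering, where JAX ops appear as
# individual TF ops such as "Sin".
import numpy as np
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f(x):
  return jnp.sin(x)

f_tf = tf.function(jax2tf.convert(f), autograph=False)
graph_def = f_tf.get_concrete_function(np.array(0.7, dtype=np.float32)).graph.as_graph_def()
assert 'op: "Sin"' in str(graph_def)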
class FftTest(jtu.JaxTestCase): def testNotImplemented(self): for name in jnp.fft._NOT_IMPLEMENTED: func = getattr(jnp.fft, name) with self.assertRaises(NotImplementedError): func() def testLaxFftAcceptsStringTypes(self): rng = jtu.rand_default(self.rng()) x = rng((10, ), np.complex64) self.assertAllClose( np.fft.fft(x).astype(np.complex64), lax.fft(x, "FFT", fft_lengths=(10, ))) @parameterized.parameters((np.float32, ), (np.float64, )) def testLaxIrfftDoesNotMutateInputs(self, dtype): if dtype == np.float64 and not config.x64_enabled: raise self.skipTest("float64 requires jax_enable_x64=true") x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtype) * (1 + 1j) y = np.asarray(jnp.fft.irfft2(x)) z = np.asarray(jnp.fft.irfft2(x)) self.assertAllClose(y, z) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_inverse={}_real={}_shape={}_axes={}_s={}_norm={}".format( inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes, s, norm), "axes": axes, "shape": shape, "dtype": dtype, "inverse": inverse, "real": real, "s": s, "norm": norm } for inverse in [False, True] for real in [False, True] for dtype in (real_dtypes if real and not inverse else all_dtypes) for shape in [(10, ), (10, 10), (9, ), (2, 3, 4), (2, 3, 4, 5)] for axes in _get_fftn_test_axes(shape) for s in _get_fftn_test_s(shape, axes) for norm in FFT_NORMS)) def testFftn(self, inverse, real, shape, dtype, axes, s, norm): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), ) jnp_op = _get_fftn_func(jnp.fft, inverse, real) np_op = _get_fftn_func(np.fft, inverse, real) jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm) np_fn = lambda a: np_op(a, axes=axes, norm=norm ) if axes is None or axes else a # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker) # Test gradient for differentiable types. if (config.x64_enabled and dtype in (float_dtypes if real and not inverse else inexact_dtypes)): # TODO(skye): can we be more precise? tol = 0.15 jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) # check dtypes dtype = jnp_fn(rng(shape, dtype)).dtype expected_dtype = jnp.promote_types( float if inverse and real else complex, dtype) self.assertEqual(dtype, expected_dtype) def testIrfftTranspose(self): # regression test for https://github.com/google/jax/issues/6223 def build_matrix(linear_func, size): return jax.vmap(linear_func)(jnp.eye(size, size)) def func(x): return jnp.fft.irfft( jnp.concatenate([jnp.zeros(1), x[:2] + 1j * x[2:]])) def func_transpose(x): return jax.linear_transpose(func, x)(x)[0] matrix = build_matrix(func, 4) matrix2 = build_matrix(func_transpose, 4).T self.assertAllClose(matrix, matrix2) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_inverse={}_real={}".format(inverse, real), "inverse": inverse, "real": real } for inverse in [False, True] for real in [False, True])) def testFftnErrors(self, inverse, real): rng = jtu.rand_default(self.rng()) name = 'fftn' if real: name = 'r' + name if inverse: name = 'i' + name func = _get_fftn_func(jnp.fft, inverse, real) self.assertRaisesRegex( ValueError, "jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. " "Got axes None with input rank 4.".format(name), lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None)) self.assertRaisesRegex( ValueError, "jax.numpy.fft.{} does not support repeated axes. Got axes \\[1, 1\\]." 
.format(name), lambda: func(rng([2, 3], dtype=np.float64), axes=[1, 1])) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2])) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3])) def testFftEmpty(self): out = jnp.fft.fft(jnp.zeros((0, ), jnp.complex64)).block_until_ready() self.assertArraysEqual(jnp.zeros((0, ), jnp.complex64), out) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_inverse={}_real={}_hermitian={}_shape={}_n={}_axis={}". format(inverse, real, hermitian, jtu.format_shape_dtype_string(shape, dtype), n, axis), "axis": axis, "shape": shape, "dtype": dtype, "inverse": inverse, "real": real, "hermitian": hermitian, "n": n } for inverse in [False, True] for real in [False, True] for hermitian in [False, True] for dtype in (real_dtypes if (real and not inverse) or ( hermitian and inverse) else all_dtypes) for shape in [(10, )] for n in [None, 1, 7, 13, 20] for axis in [-1, 0])) def testFft(self, inverse, real, hermitian, shape, dtype, n, axis): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), ) name = 'fft' if real: name = 'r' + name elif hermitian: name = 'h' + name if inverse: name = 'i' + name jnp_op = getattr(jnp.fft, name) np_op = getattr(np.fft, name) jnp_fn = lambda a: jnp_op(a, n=n, axis=axis) np_fn = lambda a: np_op(a, n=n, axis=axis) # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inverse={}_real={}_hermitian={}".format(inverse, real, hermitian), "inverse": inverse, "real": real, "hermitian": hermitian } for inverse in [False, True] for real in [False, True] for hermitian in [False, True])) def testFftErrors(self, inverse, real, hermitian): rng = jtu.rand_default(self.rng()) name = 'fft' if real: name = 'r' + name elif hermitian: name = 'h' + name if inverse: name = 'i' + name func = getattr(jnp.fft, name) self.assertRaisesRegex( ValueError, f"jax.numpy.fft.{name} does not support multiple axes. " f"Please use jax.numpy.fft.{name}n. Got axis = \\[1, 1\\].", lambda: func(rng([2, 3], dtype=np.float64), axis=[1, 1])) self.assertRaisesRegex( ValueError, f"jax.numpy.fft.{name} does not support multiple axes. " f"Please use jax.numpy.fft.{name}n. 
Got axis = \\(1, 1\\).", lambda: func(rng([2, 3], dtype=np.float64), axis=(1, 1))) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[2])) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[-3])) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_inverse={}_real={}_shape={}_axes={}_norm={}".format( inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes, norm), "axes": axes, "shape": shape, "dtype": dtype, "inverse": inverse, "real": real, "norm": norm } for inverse in [False, True] for real in [False, True] for dtype in (real_dtypes if real and not inverse else all_dtypes) for shape in [(16, 8, 4, 8), (16, 8, 4, 8, 4)] for axes in [(-2, -1), (0, 1), (1, 3), (-1, 2)] for norm in FFT_NORMS)) def testFft2(self, inverse, real, shape, dtype, axes, norm): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), ) name = 'fft2' if real: name = 'r' + name if inverse: name = 'i' + name jnp_op = getattr(jnp.fft, name) np_op = getattr(np.fft, name) jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm) np_fn = lambda a: np_op(a, axes=axes, norm=norm ) if axes is None or axes else a # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_inverse={}_real={}".format(inverse, real), "inverse": inverse, "real": real } for inverse in [False, True] for real in [False, True])) def testFft2Errors(self, inverse, real): rng = jtu.rand_default(self.rng()) name = 'fft2' if real: name = 'r' + name if inverse: name = 'i' + name func = getattr(jnp.fft, name) self.assertRaisesRegex( ValueError, "jax.numpy.fft.{} only supports 2 axes. " "Got axes = \\[0\\].".format(name), lambda: func(rng([2, 3], dtype=np.float64), axes=[0])) self.assertRaisesRegex( ValueError, "jax.numpy.fft.{} only supports 2 axes. " "Got axes = \\(0, 1, 2\\).".format(name), lambda: func(rng([2, 3, 3], dtype=np.float64), axes=(0, 1, 2))) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2, 3])) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3, -4])) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_size={}_d={}".format( jtu.format_shape_dtype_string([size], dtype), d), "dtype": dtype, "size": size, "d": d } for dtype in all_dtypes for size in [9, 10, 101, 102] for d in [0.1, 2.])) def testFftfreq(self, size, d, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng([size], dtype), ) jnp_op = jnp.fft.fftfreq np_op = np.fft.fftfreq jnp_fn = lambda a: jnp_op(size, d=d) np_fn = lambda a: np_op(size, d=d) # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker) # Test gradient for differentiable types. if dtype in inexact_dtypes: tol = 0.15 # TODO(skye): can we be more precise? jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_n={}".format(n), "n": n } for n in [[0, 1, 2]])) def testFftfreqErrors(self, n): name = 'fftfreq' func = jnp.fft.fftfreq self.assertRaisesRegex( ValueError, "The n argument of jax.numpy.fft.{} only takes an int. 
" "Got n = \\[0, 1, 2\\].".format(name), lambda: func(n=n)) self.assertRaisesRegex( ValueError, "The d argument of jax.numpy.fft.{} only takes a single value. " "Got d = \\[0, 1, 2\\].".format(name), lambda: func(n=10, d=n)) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_size={}_d={}".format( jtu.format_shape_dtype_string([size], dtype), d), "dtype": dtype, "size": size, "d": d } for dtype in all_dtypes for size in [9, 10, 101, 102] for d in [0.1, 2.])) def testRfftfreq(self, size, d, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng([size], dtype), ) jnp_op = jnp.fft.rfftfreq np_op = np.fft.rfftfreq jnp_fn = lambda a: jnp_op(size, d=d) np_fn = lambda a: np_op(size, d=d) # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker) # Test gradient for differentiable types. if dtype in inexact_dtypes: tol = 0.15 # TODO(skye): can we be more precise? jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_n={}".format(n), "n": n } for n in [[0, 1, 2]])) def testRfftfreqErrors(self, n): name = 'rfftfreq' func = jnp.fft.rfftfreq self.assertRaisesRegex( ValueError, "The n argument of jax.numpy.fft.{} only takes an int. " "Got n = \\[0, 1, 2\\].".format(name), lambda: func(n=n)) self.assertRaisesRegex( ValueError, "The d argument of jax.numpy.fft.{} only takes a single value. " "Got d = \\[0, 1, 2\\].".format(name), lambda: func(n=10, d=n)) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "dtype={}_axes={}".format( jtu.format_shape_dtype_string(shape, dtype), axes), "dtype": dtype, "shape": shape, "axes": axes } for dtype in all_dtypes for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]] for axes in _get_fftn_test_axes(shape))) def testFftshift(self, shape, dtype, axes): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), ) jnp_fn = lambda arg: jnp.fft.fftshift(arg, axes=axes) np_fn = lambda arg: np.fft.fftshift(arg, axes=axes) self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "dtype={}_axes={}".format( jtu.format_shape_dtype_string(shape, dtype), axes), "dtype": dtype, "shape": shape, "axes": axes } for dtype in all_dtypes for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]] for axes in _get_fftn_test_axes(shape))) def testIfftshift(self, shape, dtype, axes): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), ) jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes) np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes) self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
class HigherOrderPrimitiveTest(jtu.JaxTestCase): def test_core_call_primitive_inherits_effects(self): def f(x): @lu.wrap_init def f_(x): effect_p.bind(effect='foo') effect_p.bind(effect='bar') return [x] return core.call(f_, x)[0] with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'): jax.make_jaxpr(f)(2.) def test_xla_call_primitive_inherits_effects(self): @jax.jit def f(x): effect_p.bind(effect='foo') effect_p.bind(effect='bar') return x with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'): jax.make_jaxpr(f)(2.) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_{flavor}", flavor=flavor) for flavor in ["old", "new"])) def test_remat_call_primitive_inherits_effects(self, flavor): remat = jax.remat if flavor == "old" else ad_checkpoint.checkpoint @remat def f(x): effect_p.bind(effect='foo') effect_p.bind(effect='bar') return x with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'): jax.make_jaxpr(f)(2.) def test_custom_jvp_primitive_inherits_effects(self): @jax.custom_jvp def f(x): effect_p.bind(effect='foo') effect_p.bind(effect='bar') return x f.defjvp(lambda x, t: (x, t)) with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'): jax.make_jaxpr(f)(2.) def test_custom_vjp_primitive_inherits_effects(self): @jax.custom_vjp def f(x): effect_p.bind(effect='foo') effect_p.bind(effect='bar') return x f.defvjp(fwd=lambda x: (x, ()), bwd=lambda _, g: g) with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'): jax.make_jaxpr(f)(2.) def test_pmap_inherits_effects(self): @jax.pmap def f(x): effect_p.bind(effect='foo') effect_p.bind(effect='bar') return x with self.assertRaisesRegex( ValueError, "Ordered effects not supported for map primitives: {'foo'}"): jax.make_jaxpr(f)(jnp.arange(jax.local_device_count())) def test_xmap_inherits_effects(self): def f(x): effect_p.bind(effect='foo') effect_p.bind(effect='bar') return x f = maps.xmap(f, in_axes=['a'], out_axes=['a']) jaxpr = jax.make_jaxpr(f)(jnp.arange(jax.local_device_count())) self.assertSetEqual(jaxpr.effects, {"foo", "bar"}) def test_pjit_inherits_effects(self): def f(x): effect_p.bind(effect='foo') effect_p.bind(effect='bar') return x f = pjit.pjit(f, in_axis_resources=pjit.PartitionSpec('x'), out_axis_resources=pjit.PartitionSpec('x')) with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'): with maps.Mesh(np.array(jax.devices()), ['x']): jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
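# --- Illustrative sketch, not part of the original test suite ---
# The custom_jvp/custom_vjp tests above wrap effectful functions; this shows the
# plain (effect-free) custom_vjp pattern they build on: a forward rule returning
# (primal output, residuals) and a backward rule mapping (residuals, output
# cotangent) to a tuple with one cotangent per primal input.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sin(x)

def f_fwd(x):
  return jnp.sin(x), jnp.cos(x)   # primal output and residual

def f_bwd(residual, ct_out):
  return (residual * ct_out,)     # cotangent w.r.t. the single primal input

f.defvjp(f_fwd, f_bwd)
assert jnp.allclose(jax.grad(f)(0.7), jnp.cos(0.7))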
class NdimageTest(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_{}_coordinates={}_order={}_mode={}_cval={}_impl={}_round={}". format( jtu.format_shape_dtype_string(shape, dtype), jtu.format_shape_dtype_string(coords_shape, coords_dtype), order, mode, cval, impl, round_), "rng_factory": rng_factory, "shape": shape, "coords_shape": coords_shape, "dtype": dtype, "coords_dtype": coords_dtype, "order": order, "mode": mode, "cval": cval, "impl": impl, "round_": round_ } for shape in [(5, ), (3, 4), (3, 4, 5)] for coords_shape in [(7, ), (2, 3, 4)] for dtype in float_dtypes + int_dtypes for coords_dtype in float_dtypes for order in [0, 1] for mode in ['wrap', 'constant', 'nearest', 'mirror', 'reflect'] for cval in ([0, -1] if mode == 'constant' else [0]) for impl, rng_factory in [ ("original", partial(jtu.rand_uniform, low=0, high=1)), ("fixed", partial(jtu.rand_uniform, low=-0.75, high=1.75)), ] for round_ in [True, False])) def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype, order, mode, cval, impl, round_, rng_factory): def args_maker(): x = np.arange(prod(shape), dtype=dtype).reshape(shape) coords = [(size - 1) * rng(coords_shape, coords_dtype) for size in shape] if round_: coords = [c.round().astype(int) for c in coords] return x, coords rng = rng_factory(self.rng()) lsp_op = lambda x, c: lsp_ndimage.map_coordinates( x, c, order=order, mode=mode, cval=cval) impl_fun = (osp_ndimage.map_coordinates if impl == "original" else _fixed_ref_map_coordinates) osp_op = lambda x, c: impl_fun(x, c, order=order, mode=mode, cval=cval) with jtu.strict_promotion_if_dtypes_match( [dtype, int if round else coords_dtype]): if dtype in float_dtypes: epsilon = max( dtypes.finfo(dtypes.canonicalize_dtype(d)).eps for d in [dtype, coords_dtype]) self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=100 * epsilon) else: self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=0) def testMapCoordinatesErrors(self): x = np.arange(5.0) c = [np.linspace(0, 5, num=3)] with self.assertRaisesRegex(NotImplementedError, 'requires order<=1'): lsp_ndimage.map_coordinates(x, c, order=2) with self.assertRaisesRegex(NotImplementedError, 'does not yet support mode'): lsp_ndimage.map_coordinates(x, c, order=1, mode='grid-wrap') with self.assertRaisesRegex(ValueError, 'sequence of length'): lsp_ndimage.map_coordinates(x, [c, c], order=1) def testMapCoordinateDocstring(self): self.assertIn("Only nearest neighbor", lsp_ndimage.map_coordinates.__doc__) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_{np.dtype(dtype)}_order={order}", "dtype": dtype, "order": order } for dtype in float_dtypes + int_dtypes for order in [0, 1])) def testMapCoordinatesRoundHalf(self, dtype, order): x = np.arange(-3, 3, dtype=dtype) c = np.array([[.5, 1.5, 2.5, 3.5]]) def args_maker(): return x, c lsp_op = lambda x, c: lsp_ndimage.map_coordinates(x, c, order=order) osp_op = lambda x, c: osp_ndimage.map_coordinates(x, c, order=order) with jtu.strict_promotion_if_dtypes_match([dtype, c.dtype]): self._CheckAgainstNumpy(osp_op, lsp_op, args_maker) def testContinuousGradients(self): # regression test for https://github.com/google/jax/issues/3024 def loss(delta): x = np.arange(100.0) border = 10 indices = np.arange(x.size, dtype=x.dtype) + delta # linear interpolation of the linear function y=x should be exact shifted = lsp_ndimage.map_coordinates(x, [indices], order=1) return ((x - shifted)**2)[border:-border].mean() # analytical gradient of (x - (x - delta)) ** 2 is 2 * 
delta self.assertAllClose(grad(loss)(0.5), 1.0, check_dtypes=False) self.assertAllClose(grad(loss)(1.0), 2.0, check_dtypes=False)
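# --- Illustrative sketch, not part of the original test suite ---
# NdimageTest targets jax.scipy.ndimage.map_coordinates; with order=1 it performs
# linear interpolation at (possibly fractional) coordinates. For the identity
# signal y = x that interpolation is exact, which is also what the gradient test
# above exploits.
import numpy as np
from jax.scipy import ndimage as lsp_ndimage

x = np.arange(10.0)
coords = [np.array([0.5, 2.25, 7.0])]   # one coordinate array per input dimension
out = lsp_ndimage.map_coordinates(x, coords, order=1)
np.testing.assert_allclose(out, [0.5, 2.25, 7.0])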
class TestLBFGS(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": f"_func={func_and_init[0].__name__}_maxiter={maxiter}", "maxiter": maxiter, "func_and_init": func_and_init} for maxiter in [None] for func_and_init in [(rosenbrock, np.zeros(2)), (himmelblau, np.zeros(2)), (matyas, np.ones(2) * 6.), (eggholder, np.ones(2) * 100.)])) def test_minimize(self, maxiter, func_and_init): func, x0 = func_and_init @jit def min_op(x0): result = jax.scipy.optimize.minimize( func(jnp), x0, method='l-bfgs-experimental-do-not-rely-on-this', options=dict(maxiter=maxiter, gtol=1e-7), ) return result.x jax_res = min_op(x0) # Note that without bounds, L-BFGS-B is just L-BFGS with jtu.ignore_warning(category=DeprecationWarning, message=".*tostring.*is deprecated.*"): scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x if func.__name__ == 'matyas': # scipy performs badly for Matyas, compare to true minimum instead self.assertAllClose(jax_res, jnp.zeros_like(jax_res), atol=1e-7) return if func.__name__ == 'eggholder': # L-BFGS performs poorly for the eggholder function. # Neither scipy nor jax find the true minimum, so we can only loosely (with high atol) compare the false results self.assertAllClose(jax_res, scipy_res, atol=1e-3) return self.assertAllClose(jax_res, scipy_res, atol=2e-5, check_dtypes=False) def test_minimize_complex_sphere(self): z0 = jnp.array([1., 2. - 3.j, 4., -5.j]) def f(z): return jnp.real(jnp.dot(jnp.conj(z - z0), z - z0)) @jit def min_op(x0): result = jax.scipy.optimize.minimize( f, x0, method='l-bfgs-experimental-do-not-rely-on-this', options=dict(gtol=1e-6), ) return result.x jax_res = min_op(jnp.zeros_like(z0)) self.assertAllClose(jax_res, z0) def test_complex_rosenbrock(self): complex_dim = 5 f_re = rosenbrock(jnp) init_re = jnp.zeros((2 * complex_dim,)) expect_re = jnp.ones((2 * complex_dim,)) def f(z): x_re = jnp.concatenate([jnp.real(z), jnp.imag(z)]) return f_re(x_re) init = init_re[:complex_dim] + 1.j * init_re[complex_dim:] expect = expect_re[:complex_dim] + 1.j * expect_re[complex_dim:] @jit def min_op(z0): result = jax.scipy.optimize.minimize( f, z0, method='l-bfgs-experimental-do-not-rely-on-this', options=dict(gtol=1e-6), ) return result.x jax_res = min_op(init) self.assertAllClose(jax_res, expect, atol=2e-5)
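# --- Illustrative sketch, not part of the original test suite ---
# TestLBFGS drives jax.scipy.optimize.minimize with the experimental L-BFGS method.
# On a convex quadratic the solver should recover the known minimizer; `quadratic`,
# the target point, and the assert tolerance are made up for illustration.
import jax.numpy as jnp
import jax.scipy.optimize

def quadratic(z):
  return jnp.sum((z - jnp.array([1.0, -2.0])) ** 2)

result = jax.scipy.optimize.minimize(
    quadratic, jnp.zeros(2),
    method='l-bfgs-experimental-do-not-rely-on-this',
    options=dict(gtol=1e-7))
assert jnp.allclose(result.x, jnp.array([1.0, -2.0]), atol=1e-5)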
class LaxBackedScipyStatsTests(jtu.JaxTestCase): """Tests for LAX-backed scipy.stats implementations""" @genNamedParametersNArgs(3) def testPoissonLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.poisson.logpmf lax_fun = lsp_stats.poisson.logpmf def args_maker(): k, mu, loc = map(rng, shapes, dtypes) k = np.floor(k) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype) loc = np.floor(loc) return [k, mu, loc] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}) @genNamedParametersNArgs(3) def testPoissonPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.poisson.pmf lax_fun = lsp_stats.poisson.pmf def args_maker(): k, mu, loc = map(rng, shapes, dtypes) k = np.floor(k) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype) loc = np.floor(loc) return [k, mu, loc] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testPoissonCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.poisson.cdf lax_fun = lsp_stats.poisson.cdf def args_maker(): k, mu, loc = map(rng, shapes, dtypes) # clipping to ensure that rate parameter is strictly positive mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype) return [k, mu, loc] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testBernoulliLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.bernoulli.logpmf lax_fun = lsp_stats.bernoulli.logpmf def args_maker(): x, logit, loc = map(rng, shapes, dtypes) x = np.floor(x) p = expit(logit) loc = np.floor(loc) return [x, p, loc] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testGeomLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.geom.logpmf lax_fun = lsp_stats.geom.logpmf def args_maker(): x, logit, loc = map(rng, shapes, dtypes) x = np.floor(x) p = expit(logit) loc = np.floor(loc) return [x, p, loc] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(5) def testBetaLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.beta.logpdf lax_fun = lsp_stats.beta.logpdf def args_maker(): x, a, b, loc, scale = map(rng, shapes, dtypes) return [x, a, b, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, rtol={np.float32: 2e-3, np.float64: 1e-4}) def testBetaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7645 a = b = 1. 
x = np.array([0., 1.]) self.assertAllClose( osp_stats.beta.pdf(x, a, b), lsp_stats.beta.pdf(x, a, b), atol=1E-6) @genNamedParametersNArgs(3) def testCauchyLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.cauchy.logpdf lax_fun = lsp_stats.cauchy.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @parameterized.named_parameters( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix("", [x_shape, alpha_shape], dtypes), "shapes": [x_shape, alpha_shape], "dtypes": dtypes} for x_shape in one_and_two_dim_shapes for alpha_shape in [(x_shape[0],), (x_shape[0] + 1,)] for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, 2) )) def testDirichletLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) def _normalize(x, alpha): x_norm = x.sum(0) + (0.0 if x.shape[0] == alpha.shape[0] else 0.1) return (x / x_norm).astype(x.dtype), alpha def lax_fun(x, alpha): return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha)) def scipy_fun(x, alpha): # scipy validates the x normalization using float64 arithmetic, so we must # cast x to float64 before normalization to ensure this passes. x, alpha = _normalize(x.astype('float64'), alpha) result = osp_stats.dirichlet.logpdf(x, alpha) # if x.shape is (N, 1), scipy flattens the output, while JAX returns arrays # of a consistent rank. This check ensures the results have the same shape. return result if x.ndim == 1 else np.atleast_1d(result) def args_maker(): # Don't normalize here, because we want normalization to happen at 64-bit # precision in the scipy version. 
x, alpha = map(rng, shapes, dtypes) return x, alpha tol = {np.float32: 1E-3, np.float64: 1e-5} with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol) @genNamedParametersNArgs(3) def testExponLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.expon.logpdf lax_fun = lsp_stats.expon.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testGammaLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.gamma.logpdf lax_fun = lsp_stats.gamma.logpdf def args_maker(): x, a, loc, scale = map(rng, shapes, dtypes) return [x, a, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) def testGammaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7256 self.assertAllClose( osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) @genNamedParametersNArgs(2) def testGenNormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.gennorm.logpdf lax_fun = lsp_stats.gennorm.logpdf def args_maker(): x, p = map(rng, shapes, dtypes) return [x, p] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4, rtol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(2) def testGenNormCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.gennorm.cdf lax_fun = lsp_stats.gennorm.cdf def args_maker(): x, p = map(rng, shapes, dtypes) return [x, p] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4, rtol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testNBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.nbinom.logpmf lax_fun = lsp_stats.nbinom.logpmf def args_maker(): k, n, logit, loc = map(rng, shapes, dtypes) k = np.floor(np.abs(k)) n = np.ceil(np.abs(n)) p = expit(logit) loc = np.floor(loc) return [k, n, p, loc] tol = {np.float32: 1e-6, np.float64: 1e-8} with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) @genNamedParametersNArgs(3) def testLaplaceLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.laplace.logpdf lax_fun = lsp_stats.laplace.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testLaplaceCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.laplace.cdf lax_fun = lsp_stats.laplace.cdf def 
args_maker(): x, loc, scale = map(rng, shapes, dtypes) # ensure that scale is not too low scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol={np.float32: 1e-5, np.float64: 1e-6}) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.cdf lax_fun = lsp_stats.logistic.cdf def args_maker(): return list(map(rng, shapes, dtypes)) with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticLogpdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.logpdf lax_fun = lsp_stats.logistic.logpdf def args_maker(): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) def testLogisticLogpdfOverflow(self): # Regression test for https://github.com/google/jax/issues/10219 self.assertAllClose( np.array([-100, -100], np.float32), lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)), check_dtypes=False) @genNamedParametersNArgs(1) def testLogisticPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.ppf lax_fun = lsp_stats.logistic.ppf def args_maker(): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticSf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.sf lax_fun = lsp_stats.logistic.sf def args_maker(): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.logpdf lax_fun = lsp_stats.norm.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormLogCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.logcdf lax_fun = lsp_stats.norm.logcdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.cdf lax_fun = lsp_stats.norm.cdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, 
a_max=None).astype(scale.dtype) return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.norm.ppf lax_fun = lsp_stats.norm.ppf def args_maker(): q, loc, scale = map(rng, shapes, dtypes) # ensure probability is between 0 and 1: q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [q, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) @genNamedParametersNArgs(4) def testParetoLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.pareto.logpdf lax_fun = lsp_stats.pareto.logpdf def args_maker(): x, b, loc, scale = map(rng, shapes, dtypes) return [x, b, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testTLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.t.logpdf lax_fun = lsp_stats.t.logpdf def args_maker(): x, df, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, df, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}, atol={np.float64: 1e-14}) @genNamedParametersNArgs(3) def testUniformLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.uniform.logpdf lax_fun = lsp_stats.uniform.logpdf def args_maker(): x, loc, scale = map(rng, shapes, dtypes) return [x, loc, np.abs(scale)] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testChi2LogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.chi2.logpdf lax_fun = lsp_stats.chi2.logpdf def args_maker(): x, df, loc, scale = map(rng, shapes, dtypes) return [x, df, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(5) def testBetaBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) lax_fun = lsp_stats.betabinom.logpmf def args_maker(): k, n, a, b, loc = map(rng, shapes, dtypes) k = np.floor(k) n = np.ceil(n) a = np.clip(a, a_min = 0.1, a_max=None).astype(a.dtype) b = np.clip(a, a_min = 0.1, a_max=None).astype(b.dtype) loc = np.floor(loc) return [k, n, a, b, loc] with jtu.strict_promotion_if_dtypes_match(dtypes): scipy_fun = osp_stats.betabinom.logpmf self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5) def testIssue972(self): self.assertAllClose( np.ones((4,), np.float32), 
lsp_stats.norm.cdf(np.full((4,), np.inf, np.float32)), check_dtypes=False) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_x={}_mean={}_cov={}".format( jtu.format_shape_dtype_string(x_shape, x_dtype), jtu.format_shape_dtype_string(mean_shape, mean_dtype) if mean_shape is not None else None, jtu.format_shape_dtype_string(cov_shape, cov_dtype) if cov_shape is not None else None), "x_shape": x_shape, "x_dtype": x_dtype, "mean_shape": mean_shape, "mean_dtype": mean_dtype, "cov_shape": cov_shape, "cov_dtype": cov_dtype} for x_shape, mean_shape, cov_shape in [ # # These test cases cover default values for mean/cov, but we don't # # support those yet (and they seem not very valuable). # [(), None, None], # [(), (), None], # [(2,), None, None], # [(2,), (), None], # [(2,), (2,), None], # [(3, 2), (3, 2,), None], # [(5, 3, 2), (5, 3, 2,), None], [(), (), ()], [(3,), (), ()], [(3,), (3,), ()], [(3,), (3,), (3, 3)], [(3, 4), (4,), (4, 4)], [(2, 3, 4), (4,), (4, 4)], ] for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3) if (mean_shape is not None or mean_dtype == np.float32) and (cov_shape is not None or cov_dtype == np.float32))) def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape, mean_dtype, cov_shape, cov_dtype): rng = jtu.rand_default(self.rng()) def args_maker(): args = [rng(x_shape, x_dtype)] if mean_shape is not None: args.append(5 * rng(mean_shape, mean_dtype)) if cov_shape is not None: if cov_shape == (): args.append(0.1 + rng(cov_shape, cov_dtype) ** 2) else: factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1]) factor = rng(factor_shape, cov_dtype) args.append(np.matmul(factor, np.swapaxes(factor, -1, -2))) return [a.astype(x_dtype) for a in args] self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf, lsp_stats.multivariate_normal.logpdf, args_maker, tol=1e-3, check_dtypes=False) self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker, rtol=1e-4, atol=1e-4) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_x={}_mean={}_cov={}".format( jtu.format_shape_dtype_string(x_shape, x_dtype), jtu.format_shape_dtype_string(mean_shape, mean_dtype) if mean_shape is not None else None, jtu.format_shape_dtype_string(cov_shape, cov_dtype) if cov_shape is not None else None), "x_shape": x_shape, "x_dtype": x_dtype, "mean_shape": mean_shape, "mean_dtype": mean_dtype, "cov_shape": cov_shape, "cov_dtype": cov_dtype} for x_shape, mean_shape, cov_shape in [ # These test cases are where scipy flattens things, which has # different batch semantics than some might expect, so we manually # vectorize scipy's outputs for the sake of testing. 
[(5, 3, 2), (5, 3, 2), (5, 3, 2, 2)], [(2,), (5, 3, 2), (5, 3, 2, 2)], [(5, 3, 2), (2,), (5, 3, 2, 2)], [(5, 3, 2), (5, 3, 2,), (2, 2)], [(1, 3, 2), (3, 2,), (5, 1, 2, 2)], [(5, 3, 2), (1, 2,), (2, 2)], ] for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3) if (mean_shape is not None or mean_dtype == np.float32) and (cov_shape is not None or cov_dtype == np.float32))) def testMultivariateNormalLogpdfBroadcasted(self, x_shape, x_dtype, mean_shape, mean_dtype, cov_shape, cov_dtype): rng = jtu.rand_default(self.rng()) def args_maker(): args = [rng(x_shape, x_dtype)] if mean_shape is not None: args.append(5 * rng(mean_shape, mean_dtype)) if cov_shape is not None: if cov_shape == (): args.append(0.1 + rng(cov_shape, cov_dtype) ** 2) else: factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1]) factor = rng(factor_shape, cov_dtype) args.append(np.matmul(factor, np.swapaxes(factor, -1, -2))) return [a.astype(x_dtype) for a in args] osp_fun = np.vectorize(osp_stats.multivariate_normal.logpdf, signature="(n),(n),(n,n)->()") self._CheckAgainstNumpy(osp_fun, lsp_stats.multivariate_normal.logpdf, args_maker, tol=1e-3, check_dtypes=False) self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker, rtol=1e-4, atol=1e-4) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": f"_ndim={ndim}_nbatch={nbatch}_dtype={dtype.__name__}", "ndim": ndim, "nbatch": nbatch, "dtype": dtype} for ndim in [2, 3] for nbatch in [1, 3, 5] for dtype in jtu.dtypes.floating)) def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype): # Regression test for #5570 rng = jtu.rand_default(self.rng()) x = rng((nbatch, ndim), dtype) mean = 5 * rng((nbatch, ndim), dtype) factor = rng((nbatch, ndim, 2 * ndim), dtype) cov = factor @ factor.transpose(0, 2, 1) result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov) result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov) self.assertArraysEqual(result1, result2, check_dtypes=False) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outsize={}_weights={}_method={}_func={}".format( jtu.format_shape_dtype_string(inshape, dtype), outsize, weights, method, func), "dtype": dtype, "inshape": inshape, "outsize": outsize, "weights": weights, "method": method, "func": func} for inshape in [(50,), (3, 50), (2, 12)] for dtype in jtu.dtypes.floating for outsize in [None, 10] for weights in [False, True] for method in [None, "scott", "silverman", 1.5, "callable"] for func in [None, "evaluate", "logpdf", "pdf"])) def testKde(self, inshape, dtype, outsize, weights, method, func): if method == "callable": method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4)) def scipy_fun(dataset, points, w): w = np.abs(w) if weights else None kde = osp_stats.gaussian_kde(dataset, bw_method=method, weights=w) if func is None: result = kde(points) else: result = getattr(kde, func)(points) # Note: the scipy implementation _always_ returns float64 return result.astype(dtype) def lax_fun(dataset, points, w): w = jax.numpy.abs(w) if weights else None kde = lsp_stats.gaussian_kde(dataset, bw_method=method, weights=w) if func is None: result = kde(points) else: result = getattr(kde, func)(points) return result if outsize is None: outshape = inshape else: outshape = inshape[:-1] + (outsize,) rng = jtu.rand_default(self.rng()) args_maker = lambda: [ rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)] self._CheckAgainstNumpy( scipy_fun, lax_fun, args_maker, tol={ np.float32: 
1e-2 if jtu.device_under_test() == "tpu" else 1e-3, np.float64: 1e-14 }) self._CompileAndCheck( lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), "dtype": dtype, "shape": shape} for shape in [(15,), (3, 15), (1, 12)] for dtype in jtu.dtypes.floating)) def testKdeIntegrateGaussian(self, shape, dtype): def scipy_fun(dataset, weights): kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights)) # Note: the scipy implementation _always_ returns float64 return kde.integrate_gaussian(mean, covariance).astype(dtype) def lax_fun(dataset, weights): kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights)) return kde.integrate_gaussian(mean, covariance) # Construct a random mean and positive definite covariance matrix rng = jtu.rand_default(self.rng()) ndim = shape[0] if len(shape) > 1 else 1 mean = rng(ndim, dtype) L = rng((ndim, ndim), dtype) L[np.triu_indices(ndim, 1)] = 0.0 L[np.diag_indices(ndim)] = np.exp(np.diag(L)) + 0.01 covariance = L @ L.T args_maker = lambda: [ rng(shape, dtype), rng(shape[-1:], dtype)] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol={np.float32: 1e-3, np.float64: 1e-14}) self._CompileAndCheck( lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), "dtype": dtype, "shape": shape} for shape in [(15,), (12,)] for dtype in jtu.dtypes.floating)) def testKdeIntegrateBox1d(self, shape, dtype): def scipy_fun(dataset, weights): kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights)) # Note: the scipy implementation _always_ returns float64 return kde.integrate_box_1d(-0.5, 1.5).astype(dtype) def lax_fun(dataset, weights): kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights)) return kde.integrate_box_1d(-0.5, 1.5) rng = jtu.rand_default(self.rng()) args_maker = lambda: [ rng(shape, dtype), rng(shape[-1:], dtype)] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol={np.float32: 1e-3, np.float64: 1e-14}) self._CompileAndCheck( lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), "dtype": dtype, "shape": shape} for shape in [(15,), (3, 15), (1, 12)] for dtype in jtu.dtypes.floating)) def testKdeIntegrateKde(self, shape, dtype): def scipy_fun(dataset, weights): kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights)) other = osp_stats.gaussian_kde( dataset[..., :-3] + 0.1, weights=np.abs(weights[:-3])) # Note: the scipy implementation _always_ returns float64 return kde.integrate_kde(other).astype(dtype) def lax_fun(dataset, weights): kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights)) other = lsp_stats.gaussian_kde( dataset[..., :-3] + 0.1, weights=jax.numpy.abs(weights[:-3])) return kde.integrate_kde(other) rng = jtu.rand_default(self.rng()) args_maker = lambda: [ rng(shape, dtype), rng(shape[-1:], dtype)] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol={np.float32: 1e-3, np.float64: 1e-14}) self._CompileAndCheck( lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), "dtype": dtype, "shape": shape} for shape in [(15,), (3, 15), (1, 12)] for dtype in 
jtu.dtypes.floating)) def testKdeResampleShape(self, shape, dtype): def resample(key, dataset, weights, *, shape): kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights)) return kde.resample(key, shape=shape) rng = jtu.rand_default(self.rng()) args_maker = lambda: [ jax.random.PRNGKey(0), rng(shape, dtype), rng(shape[-1:], dtype)] ndim = shape[0] if len(shape) > 1 else 1 args = args_maker() func = partial(resample, shape=()) self._CompileAndCheck( func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) result = func(*args) assert result.shape == (ndim,) func = partial(resample, shape=(4,)) self._CompileAndCheck( func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) result = func(*args) assert result.shape == (ndim, 4) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), "dtype": dtype, "shape": shape} for shape in [(15,), (1, 12)] for dtype in jtu.dtypes.floating)) def testKdeResample1d(self, shape, dtype): rng = jtu.rand_default(self.rng()) dataset = rng(shape, dtype) weights = jax.numpy.abs(rng(shape[-1:], dtype)) kde = lsp_stats.gaussian_kde(dataset, weights=weights) samples = jax.numpy.squeeze(kde.resample(jax.random.PRNGKey(5), shape=(1000,))) def cdf(x): result = jax.vmap(partial(kde.integrate_box_1d, -np.inf))(x) # Manually casting to numpy in order to avoid type promotion error return np.array(result) self.assertGreater(osp_stats.kstest(samples, cdf).pvalue, 0.01) def testKdePyTree(self): @jax.jit def evaluate_kde(kde, x): return kde.evaluate(x) dtype = np.float32 rng = jtu.rand_default(self.rng()) dataset = rng((3, 15), dtype) x = rng((3, 12), dtype) kde = lsp_stats.gaussian_kde(dataset) leaves, treedef = tree_util.tree_flatten(kde) kde2 = tree_util.tree_unflatten(treedef, leaves) tree_util.tree_map(lambda a, b: self.assertAllClose(a, b), kde, kde2) self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))
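# --- Illustrative sketch, not part of the original test suite ---
# The scipy.stats tests above all follow the same recipe: evaluate the
# jax.scipy.stats function and its scipy.stats counterpart on identical inputs and
# compare, with loose tolerances when JAX computes in float32. Two standalone
# instances of that recipe, with illustrative inputs and tolerances:
import numpy as np
import scipy.stats as osp_stats
import jax.scipy.stats as lsp_stats

x = np.linspace(-3.0, 3.0, 7).astype(np.float32)
np.testing.assert_allclose(lsp_stats.norm.logpdf(x, loc=0.5, scale=2.0),
                           osp_stats.norm.logpdf(x, loc=0.5, scale=2.0),
                           rtol=1e-5)

dataset = np.random.RandomState(0).randn(2, 50).astype(np.float32)
points = np.random.RandomState(1).randn(2, 10).astype(np.float32)
np.testing.assert_allclose(lsp_stats.gaussian_kde(dataset).evaluate(points),
                           osp_stats.gaussian_kde(dataset).evaluate(points),
                           rtol=1e-3, atol=1e-3)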
class LaxBackedScipySignalTests(jtu.JaxTestCase): """Tests for LAX-backed scipy.stats implementations""" @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format( op, jtu.format_shape_dtype_string(xshape, dtype), jtu.format_shape_dtype_string(yshape, dtype), mode), "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode, "jsp_op": getattr(jsp_signal, op), "osp_op": getattr(osp_signal, op) } for mode in ['full', 'same', 'valid'] for op in ['convolve', 'correlate'] for dtype in default_dtypes for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes] for xshape in shapeset for yshape in shapeset)) def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] osp_fun = partial(osp_op, mode=mode) jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST) tol = { np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12, np.complex64: 1e-2, np.complex128: 1e-12 } self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "op={}_xshape={}_yshape={}_mode={}".format( op, jtu.format_shape_dtype_string(xshape, dtype), jtu.format_shape_dtype_string(yshape, dtype), mode), "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode, "jsp_op": getattr(jsp_signal, op), "osp_op": getattr(osp_signal, op) } for mode in ['full', 'same', 'valid'] for op in ['convolve2d', 'correlate2d'] for dtype in default_dtypes for xshape in twodim_shapes for yshape in twodim_shapes)) def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] osp_fun = partial(osp_op, mode=mode) jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST) tol = { np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12, np.complex64: 1e-2, np.complex128: 1e-12 } self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_axis={}_type={}_bp={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, type, bp), "shape": shape, "dtype": dtype, "axis": axis, "type": type, "bp": bp } for shape in [(5, ), (4, 5), (3, 4, 5)] for dtype in jtu.dtypes.floating + jtu.dtypes.integer for axis in [0, -1] for type in ['constant', 'linear'] for bp in [0, [0, 2]])) def testDetrend(self, shape, dtype, axis, type, bp): signal = np.random.normal(loc=2, size=shape) if type == 'constant': trend = np.ones_like(signal) elif type == 'linear': trend = np.linspace(-np.pi, np.pi, shape[0]) if len(shape) == 1: trend = np.broadcast_to(trend, shape) elif len(shape) == 2: trend = np.broadcast_to(trend[:, None], shape) elif len(shape) == 3: trend = np.broadcast_to(trend[:, None, None], shape) args_maker = lambda: [signal, trend] def osp_fun(signal, trend): return osp_signal.detrend( signal + trend, axis=axis, type=type, bp=bp) - trend def jsp_fun(signal, noise): return jsp_signal.detrend( signal + noise, axis=axis, type=type, bp=bp) - trend if jtu.device_under_test() == 'tpu': tol = {np.float32: 3e-2, np.float64: 1e-12} else: tol = {np.float32: 1e-5, np.float64: 1e-12} self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol) 
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) @parameterized.named_parameters( jtu. cases_from_list({ "testcase_name": f"_shape={jtu.format_shape_dtype_string(shape, dtype)}" f"_fs={fs}_window={window}_boundary={boundary}_detrend={detrend}" f"_padded={padded}_nperseg={nperseg}_noverlap={noverlap}" f"_axis={timeaxis}_nfft={nfft}", "shape": shape, "dtype": dtype, "fs": fs, "window": window, "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft, "detrend": detrend, "boundary": boundary, "padded": padded, "timeaxis": timeaxis } for shape, nperseg, noverlap, timeaxis in stft_test_shapes for dtype in default_dtypes for fs in [1.0, 16000.0] for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann'] for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2] for detrend in ['constant', 'linear', False] for boundary in [None, 'even', 'odd', 'zeros'] for padded in [True, False])) def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, noverlap, nfft, detrend, boundary, padded, timeaxis): is_complex = np.dtype(dtype).kind == 'c' if is_complex and detrend is not None: return osp_fun = partial(osp_signal.stft, fs=fs, window=window, nfft=nfft, boundary=boundary, padded=padded, detrend=detrend, nperseg=nperseg, noverlap=noverlap, axis=timeaxis, return_onesided=not is_complex) jsp_fun = partial(jsp_signal.stft, fs=fs, window=window, nfft=nfft, boundary=boundary, padded=padded, detrend=detrend, nperseg=nperseg, noverlap=noverlap, axis=timeaxis, return_onesided=not is_complex) tol = { np.float32: 1e-5, np.float64: 1e-12, np.complex64: 1e-5, np.complex128: 1e-12 } if jtu.device_under_test() == 'tpu': tol = _TPU_FFT_TOL rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) # Tests with `average == 'median'`` is excluded from `testCsd*` # due to the issue: # https://github.com/scipy/scipy/issues/15601 @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": f"_xshape={jtu.format_shape_dtype_string(xshape, dtype)}" f"_yshape={jtu.format_shape_dtype_string(yshape, dtype)}" f"_average={average}_scaling={scaling}_nfft={nfft}" f"_fs={fs}_window={window}_detrend={detrend}" f"_nperseg={nperseg}_noverlap={noverlap}" f"_axis={timeaxis}", "xshape": xshape, "yshape": yshape, "dtype": dtype, "fs": fs, "window": window, "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft, "detrend": detrend, "scaling": scaling, "timeaxis": timeaxis, "average": average } for xshape, yshape, nperseg, noverlap, timeaxis in csd_test_shapes for dtype in default_dtypes for fs in [1.0, 16000.0] for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann'] for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2] for detrend in ['constant', 'linear', False] for scaling in ['density', 'spectrum'] for average in ['mean'])) def testCsdAgainstNumpy(self, *, xshape, yshape, dtype, fs, window, nperseg, noverlap, nfft, detrend, scaling, timeaxis, average): is_complex = np.dtype(dtype).kind == 'c' if is_complex and detrend is not None: raise unittest.SkipTest( "Complex signal is not supported in lax-backed `signal.detrend`." 
) osp_fun = partial(osp_signal.csd, fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=not is_complex, scaling=scaling, axis=timeaxis, average=average) jsp_fun = partial(jsp_signal.csd, fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=not is_complex, scaling=scaling, axis=timeaxis, average=average) tol = { np.float32: 1e-5, np.float64: 1e-12, np.complex64: 1e-5, np.complex128: 1e-12 } if jtu.device_under_test() == 'tpu': tol = _TPU_FFT_TOL rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": f"_shape={jtu.format_shape_dtype_string(shape, dtype)}" f"_average={average}_scaling={scaling}_nfft={nfft}" f"_fs={fs}_window={window}_detrend={detrend}" f"_nperseg={nperseg}_noverlap={noverlap}" f"_axis={timeaxis}", "shape": shape, "dtype": dtype, "fs": fs, "window": window, "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft, "detrend": detrend, "scaling": scaling, "timeaxis": timeaxis, "average": average } for shape, unused_yshape, nperseg, noverlap, timeaxis in csd_test_shapes for dtype in default_dtypes for fs in [1.0, 16000.0] for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann'] for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2] for detrend in ['constant', 'linear', False] for scaling in ['density', 'spectrum'] for average in ['mean'])) def testCsdWithSameParamAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, noverlap, nfft, detrend, scaling, timeaxis, average): is_complex = np.dtype(dtype).kind == 'c' if is_complex and detrend is not None: raise unittest.SkipTest( "Complex signal is not supported in lax-backed `signal.detrend`." ) def osp_fun(x, y): # When the identical parameters are given, jsp-version follows # the behavior with copied parameters. 
freqs, Pxy = osp_signal.csd(x, y.copy(), fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=not is_complex, scaling=scaling, axis=timeaxis, average=average) return freqs, Pxy jsp_fun = partial(jsp_signal.csd, fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=not is_complex, scaling=scaling, axis=timeaxis, average=average) tol = { np.float32: 1e-5, np.float64: 1e-12, np.complex64: 1e-5, np.complex128: 1e-12 } if jtu.device_under_test() == 'tpu': tol = _TPU_FFT_TOL rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] * 2 self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": f"_shape={jtu.format_shape_dtype_string(shape, dtype)}" f"_fs={fs}_window={window}" f"_nperseg={nperseg}_noverlap={noverlap}_nfft={nfft}" f"_detrend={detrend}_return_onesided={return_onesided}" f"_scaling={scaling}_axis={timeaxis}_average={average}", "shape": shape, "dtype": dtype, "fs": fs, "window": window, "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft, "detrend": detrend, "return_onesided": return_onesided, "scaling": scaling, "timeaxis": timeaxis, "average": average } for shape, nperseg, noverlap, timeaxis in welch_test_shapes for dtype in default_dtypes for fs in [1.0, 16000.0] for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann'] for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2] for detrend in ['constant', 'linear', False] for return_onesided in [True, False] for scaling in ['density', 'spectrum'] for average in ['mean', 'median'])) def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, noverlap, nfft, detrend, return_onesided, scaling, timeaxis, average): if np.dtype(dtype).kind == 'c': return_onesided = False if detrend is not None: raise unittest.SkipTest( "Complex signal is not supported in lax-backed `signal.detrend`." 
) osp_fun = partial(osp_signal.welch, fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=timeaxis, average=average) jsp_fun = partial(jsp_signal.welch, fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=detrend, return_onesided=return_onesided, scaling=scaling, axis=timeaxis, average=average) tol = { np.float32: 1e-5, np.float64: 1e-12, np.complex64: 1e-5, np.complex128: 1e-12 } if jtu.device_under_test() == 'tpu': tol = _TPU_FFT_TOL rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": f"_shape={jtu.format_shape_dtype_string(shape, dtype)}" f"_nperseg={nperseg}_noverlap={noverlap}" f"_use_nperseg={use_nperseg}_use_overlap={use_noverlap}" f"_axis={timeaxis}", "shape": shape, "dtype": dtype, "nperseg": nperseg, "noverlap": noverlap, "use_nperseg": use_nperseg, "use_noverlap": use_noverlap, "timeaxis": timeaxis } for shape, nperseg, noverlap, timeaxis in welch_test_shapes for use_nperseg in [False, True] for use_noverlap in [False, True] for dtype in jtu.dtypes.floating + jtu.dtypes.integer)) def testWelchWithDefaultStepArgsAgainstNumpy(self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap, timeaxis): kwargs = {'axis': timeaxis} if use_nperseg: kwargs['nperseg'] = nperseg else: kwargs['window'] = osp_signal.get_window('hann', nperseg) if use_noverlap: kwargs['noverlap'] = noverlap osp_fun = partial(osp_signal.welch, **kwargs) jsp_fun = partial(jsp_signal.welch, **kwargs) tol = { np.float32: 1e-5, np.float64: 1e-12, np.complex64: 1e-5, np.complex128: 1e-12 } if jtu.device_under_test() == 'tpu': tol = _TPU_FFT_TOL rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": f"_shape={jtu.format_shape_dtype_string(shape, dtype)}" f"_fs={fs}_window={window}_boundary={boundary}" f"_nperseg={nperseg}_noverlap={noverlap}_onesided={onesided}" f"_timeaxis={timeaxis}_freqaxis{freqaxis}_nfft={nfft}", "shape": shape, "dtype": dtype, "fs": fs, "window": window, "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft, "onesided": onesided, "boundary": boundary, "timeaxis": timeaxis, "freqaxis": freqaxis } for shape, nperseg, noverlap, timeaxis, freqaxis in istft_test_shapes for dtype in default_dtypes for fs in [1.0, 16000.0] for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann'] for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2] for onesided in [False, True] for boundary in [False, True])) def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, noverlap, nfft, onesided, boundary, timeaxis, freqaxis): if not onesided: new_freq_len = (shape[freqaxis] - 1) * 2 shape = shape[:freqaxis] + (new_freq_len, ) + shape[freqaxis + 1:] def osp_fun(x, fs): # Ignore UserWarning in osp so we can also test over ill-posed cases. 
with warnings.catch_warnings(): warnings.simplefilter("ignore") result = osp_signal.istft(x, fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, input_onesided=onesided, boundary=boundary, time_axis=timeaxis, freq_axis=freqaxis) return result jsp_fun = partial(jsp_signal.istft, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, input_onesided=onesided, boundary=boundary, time_axis=timeaxis, freq_axis=freqaxis) tol = { np.float32: 1e-4, np.float64: 1e-6, np.complex64: 1e-4, np.complex128: 1e-6 } if jtu.device_under_test() == 'tpu': tol = _TPU_FFT_TOL rng = jtu.rand_default(self.rng()) rng_fs = jtu.rand_uniform(self.rng(), 1.0, 16000.0) args_maker = lambda: [rng(shape, dtype), rng_fs((), np.float64)] # The dtype of the output signal differs across osp versions, # and hence across test environments, so the dtype check is disabled. self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol, check_dtypes=False) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
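# --------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite above): the signal tests
# compare jax.scipy.signal against scipy.signal. The helper below is a
# hypothetical usage example of the spectral functions exercised here
# (stft / istft / welch) on a toy sine wave.
def _spectral_usage_sketch():
  import jax.numpy as jnp
  from jax.scipy import signal as jsig

  fs = 16000.0
  t = jnp.arange(0, 1.0, 1.0 / fs)
  x = jnp.sin(2 * jnp.pi * 440.0 * t)

  # Short-time Fourier transform and its inverse (round trip up to edge effects).
  f, seg_times, Zxx = jsig.stft(x, fs=fs, window='hann', nperseg=256, noverlap=128)
  _, x_rec = jsig.istft(Zxx, fs=fs, window='hann', nperseg=256, noverlap=128)

  # Welch power spectral density estimate.
  freqs, psd = jsig.welch(x, fs=fs, nperseg=256)
  return f, seg_times, Zxx, x_rec, freqs, psd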
class TestLineSearch(jtu.JaxTestCase): # -- scalar functions; must have dphi(0.) < 0 def assert_wolfe(self, s, phi, derphi, c1=1e-4, c2=0.9, err_msg=""): """ Check that strong Wolfe conditions apply """ phi1 = phi(s) phi0 = phi(0) derphi0 = derphi(0) derphi1 = derphi(s) msg = "s = %s; phi(0) = %s; phi(s) = %s; phi'(0) = %s; phi'(s) = %s; %s" % ( s, phi0, phi1, derphi0, derphi1, err_msg) self.assertTrue(phi1 <= phi0 + c1 * s * derphi0, "Wolfe 1 failed: " + msg) self.assertTrue( abs(derphi1) <= abs(c2 * derphi0), "Wolfe 2 failed: " + msg) def assert_line_wolfe(self, x, p, s, f, fprime, **kw): self.assert_wolfe(s, phi=lambda sp: f(x + p * sp), derphi=lambda sp: jnp.dot(fprime(x + p * sp), p), **kw) def _scalar_func_1(self, s): p = -s - s**3 + s**4 dp = -1 - 3 * s**2 + 4 * s**3 return p, dp def _scalar_func_2(self, s): p = jnp.exp(-4 * s) + s**2 dp = -4 * jnp.exp(-4 * s) + 2 * s return p, dp def _scalar_func_3(self, s): p = -jnp.sin(10 * s) dp = -10 * jnp.cos(10 * s) return p, dp # -- n-d functions def _line_func_1(self, x): f = jnp.dot(x, x) df = 2 * x return f, df def _line_func_2(self, x): f = jnp.dot(x, jnp.dot(self.A, x)) + 1 df = jnp.dot(self.A + self.A.T, x) return f, df # -- Generic scalar searches @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_name={}".format(name), "name": name } for name in ['_scalar_func_1', '_scalar_func_2', '_scalar_func_3'])) def test_scalar_search_wolfe2(self, name): def bind_index(func, idx): # Remember Python's closure semantics! return lambda *a, **kw: func(*a, **kw)[idx] value = getattr(self, name) phi = bind_index(value, 0) derphi = bind_index(value, 1) for old_phi0 in self.rng().randn(3): res = line_search(phi, 0., 1.) s, phi1, derphi1 = res.a_k, res.f_k, res.g_k self.assertAllClose(phi1, phi(s), check_dtypes=False, atol=1e-6) if derphi1 is not None: self.assertAllClose(derphi1, derphi(s), check_dtypes=False, atol=1e-6) self.assert_wolfe(s, phi, derphi, err_msg="%s %g" % (name, old_phi0)) # -- Generic line searches @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_name={}".format(name), "name": name } for name in ['_line_func_1', '_line_func_2'])) def test_line_search_wolfe2(self, name): def bind_index(func, idx): # Remember Python's closure semantics! return lambda *a, **kw: func(*a, **kw)[idx] value = getattr(self, name) f = bind_index(value, 0) fprime = bind_index(value, 1) k = 0 N = 20 rng = self.rng() # sets A in one of the line funcs self.A = self.rng().randn(N, N) while k < 9: x = rng.randn(N) p = rng.randn(N) if jnp.dot(p, fprime(x)) >= 0: # always pick a descent pk continue k += 1 f0 = f(x) g0 = fprime(x) self.fcount = 0 res = line_search(f, x, p, old_fval=f0, gfk=g0) s = res.a_k fv = res.f_k gv = res.g_k self.assertAllClose(fv, f(x + s * p), check_dtypes=False, atol=1e-5) if gv is not None: self.assertAllClose(gv, fprime(x + s * p), check_dtypes=False, atol=1e-5) def test_line_search_wolfe2_bounds(self): # See gh-7475 # For this f and p, starting at a point on axis 0, the strong Wolfe # condition 2 is met if and only if the step length s satisfies # |x + s| <= c2 * |x| f = lambda x: jnp.dot(x, x) fp = lambda x: 2 * x p = jnp.array([1, 0]) # Smallest s satisfying strong Wolfe conditions for these arguments is 30 x = -60 * p c2 = 0.5 res = line_search(f, x, p, c2=c2) s = res.a_k # s, _, _, _, _, _ = ls.line_search_wolfe2(f, fp, x, p, amax=30, c2=c2) self.assert_line_wolfe(x, p, s, f, fp) self.assertTrue(s >= 30.) 
res = line_search(f, x, p, c2=c2, maxiter=5) self.assertTrue(res.failed) # s=30 will only be tried on the 6th iteration, so this won't converge def test_line_search(self): def f(x): return jnp.cos(jnp.sum(jnp.exp(-x))**2) # assert not line_search(jax.value_and_grad(f), np.ones(2), np.array([-0.5, -0.25])).failed xk = jnp.ones(2) pk = jnp.array([-0.5, -0.25]) res = line_search(f, xk, pk, maxiter=100) scipy_res = line_search_wolfe2(f, grad(f), xk, pk) self.assertAllClose(scipy_res[0], res.a_k, atol=1e-5, check_dtypes=False) self.assertAllClose(scipy_res[3], res.f_k, atol=1e-5, check_dtypes=False)
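# --------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite above): the line-search
# tests check the strong Wolfe conditions for the internal `line_search`
# routine, which JAX's BFGS minimizer drives under the hood. The helper below
# is a hypothetical example of the public entry point,
# jax.scipy.optimize.minimize(method='BFGS').
def _bfgs_usage_sketch():
  import jax.numpy as jnp
  from jax.scipy.optimize import minimize

  def rosenbrock(x):
    # Classic non-convex test function with minimum at x = (1, 1).
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

  result = minimize(rosenbrock, jnp.zeros(2), method='BFGS')
  return result.x, result.fun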
class CustomObjectTest(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_compile={}_primitive={}".format(compile, primitive), "compile": compile, "primitive": primitive } for primitive in [True, False] for compile in [True, False])) def testSparseIdentity(self, compile, primitive): f = identity if primitive else (lambda x: x) f = jit(f) if compile else f rng = jtu.rand_default(self.rng()) M = make_sparse_array(rng, (10, ), jnp.float32) M2 = f(M) jaxpr = make_jaxpr(f)(M).jaxpr core.check_jaxpr(jaxpr) self.assertEqual(M.dtype, M2.dtype) self.assertEqual(M.index_dtype, M2.index_dtype) self.assertAllClose(M.data, M2.data) self.assertAllClose(M.indices, M2.indices) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_compile={}".format(compile), "compile": compile } for compile in [True, False])) def testSparseSplit(self, compile): f = jit(split) if compile else split rng = jtu.rand_default(self.rng()) M = make_sparse_array(rng, (10, ), jnp.float32) M2, M3 = f(M) jaxpr = make_jaxpr(f)(M).jaxpr core.check_jaxpr(jaxpr) for MM in M2, M3: self.assertEqual(M.dtype, MM.dtype) self.assertEqual(M.index_dtype, MM.index_dtype) self.assertArraysEqual(M.data, MM.data) self.assertArraysEqual(M.indices, MM.indices) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_compile={}_primitive={}".format(compile, primitive), "compile": compile, "primitive": primitive } for primitive in [True, False] for compile in [True, False])) def testSparseLaxLoop(self, compile, primitive): rng = jtu.rand_default(self.rng()) f = identity if primitive else (lambda x: x) f = jit(f) if compile else f body_fun = lambda _, A: f(A) M = make_sparse_array(rng, (10, ), jnp.float32) lax.fori_loop(0, 10, body_fun, M) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_attr={}".format(attr), "attr": attr } for attr in ["data", "indices"])) def testSparseAttrAccess(self, attr): rng = jtu.rand_default(self.rng()) args_maker = lambda: [make_sparse_array(rng, (10, ), jnp.float32)] f = lambda x: getattr(x, attr) self._CompileAndCheck(f, args_maker) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype } for shape in [(3, 3), (2, 6), (6, 2)] for dtype in jtu.dtypes.floating)) def testSparseMatvec(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [ make_sparse_array(rng, shape, dtype), rng(shape[-1:], dtype) ] self._CompileAndCheck(matvec, args_maker) def testLowerToNothing(self): empty = Empty(AbstractEmpty()) jaxpr = make_jaxpr(jit(lambda e: e))(empty).jaxpr core.check_jaxpr(jaxpr) # cannot return a unit, because CompileAndCheck assumes array output. testfunc = lambda e: None args_maker = lambda: [empty] self._CompileAndCheck(testfunc, args_maker) def testConstantHandler(self): def make_const_array(): data = np.arange(3.0) indices = np.arange(3)[:, None] shape = (5, ) aval = AbstractSparseArray(shape, data.dtype, indices.dtype, len(indices)) return SparseArray(aval, data, indices) out1 = make_const_array() out2 = jit(make_const_array)() self.assertArraysEqual(out1.data, out2.data) self.assertArraysEqual(out1.indices, out2.indices)
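# --------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite above): CustomObjectTest
# exercises custom array-like objects (SparseArray, Empty) flowing through
# jit, jaxprs, and lax control flow. The simpler public mechanism for making a
# container type transparent to JAX transformations is pytree registration,
# sketched below with a hypothetical `_PairExample` class; this does not
# reproduce the abstract-value machinery the tests above rely on.
from jax.tree_util import register_pytree_node_class as _register_pytree_example

@_register_pytree_example
class _PairExample:
  def __init__(self, a, b):
    self.a = a
    self.b = b

  def tree_flatten(self):
    return (self.a, self.b), None          # (children, aux_data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

# Once registered, instances can be passed through jit-compiled functions, e.g.
#   jit(lambda p: _PairExample(p.a * 2, p.b * 2))(_PairExample(jnp.ones(3), jnp.zeros(3)))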
class SparsifyTest(jtu.JaxTestCase): @classmethod def sparsify(cls, f): return sparsify(f, use_tracer=False) def testNotImplementedMessages(self): x = BCOO.fromdense(jnp.arange(5.0)) # Test a densifying primitive with self.assertRaisesRegex( NotImplementedError, r"^sparse rule for cos is not implemented because it would result in dense output\." ): self.sparsify(lax.cos)(x) # Test a generic not implemented primitive. with self.assertRaisesRegex( NotImplementedError, r"^sparse rule for complex is not implemented\.$"): self.sparsify(lax.complex)(x, x) def testTracerIsInstanceCheck(self): @self.sparsify def f(x): self.assertNotIsInstance(x, SparseTracer) f(jnp.arange(5)) def assertBcooIdentical(self, x, y): self.assertIsInstance(x, BCOO) self.assertIsInstance(y, BCOO) self.assertEqual(x.shape, y.shape) self.assertArraysEqual(x.data, y.data) self.assertArraysEqual(x.indices, y.indices) def testSparsifyValue(self): X = jnp.arange(5) X_BCOO = BCOO.fromdense(X) args = (X, X_BCOO, X_BCOO) # Independent index spenv = SparsifyEnv() spvalues = arrays_to_spvalues(spenv, args) self.assertEqual(len(spvalues), len(args)) self.assertLen(spenv._buffers, 5) self.assertEqual( spvalues, (SparsifyValue(X.shape, 0, None), SparsifyValue( X.shape, 1, 2), SparsifyValue(X.shape, 3, 4))) args_out = spvalues_to_arrays(spenv, spvalues) self.assertEqual(len(args_out), len(args)) self.assertArraysEqual(args[0], args_out[0]) self.assertBcooIdentical(args[1], args_out[1]) self.assertBcooIdentical(args[2], args_out[2]) # Shared index spvalues = (SparsifyValue(X.shape, 0, None), SparsifyValue(X.shape, 1, 2), SparsifyValue(X.shape, 3, 2)) spenv = SparsifyEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data]) args_out = spvalues_to_arrays(spenv, spvalues) self.assertEqual(len(args_out), len(args)) self.assertArraysEqual(args[0], args_out[0]) self.assertBcooIdentical(args[1], args_out[1]) self.assertBcooIdentical(args[2], args_out[2]) def testDropvar(self): def inner(x): return x * 2, x * 3 def f(x): _, y = jit(inner)(x) return y * 4 x_dense = jnp.arange(5) x_sparse = BCOO.fromdense(x_dense) self.assertArraysEqual( self.sparsify(f)(x_sparse).todense(), f(x_dense)) def testPytreeInput(self): f = self.sparsify(lambda x: x) args = (jnp.arange(4), BCOO.fromdense(jnp.arange(4))) out = f(args) self.assertLen(out, 2) self.assertArraysEqual(args[0], out[0]) self.assertBcooIdentical(args[1], out[1]) def testSparsify(self): M_dense = jnp.arange(24).reshape(4, 6) M_sparse = BCOO.fromdense(M_dense) v = jnp.arange(M_dense.shape[0]) @self.sparsify def func(x, v): return -jnp.sin(jnp.pi * x).T @ (v + 1) with jtu.ignore_warning( category=CuSparseEfficiencyWarning, message= "bcoo_dot_general GPU lowering requires matrices with sorted indices*" ): result_sparse = func(M_sparse, v) result_dense = func(M_dense, v) self.assertAllClose(result_sparse, result_dense) def testSparsifyWithConsts(self): M_dense = jnp.arange(24).reshape(4, 6) M_sparse = BCOO.fromdense(M_dense) @self.sparsify def func(x): return jit(lambda x: jnp.sum(x, 1))(x) result_dense = func(M_dense) result_sparse = func(M_sparse) self.assertAllClose(result_sparse.todense(), result_dense) def testSparseMatmul(self): X = jnp.arange(16).reshape(4, 4) Xsp = BCOO.fromdense(X) Y = jnp.ones(4) Ysp = BCOO.fromdense(Y) func = self.sparsify(operator.matmul) # dot_general with jtu.ignore_warning( category=CuSparseEfficiencyWarning, message= "bcoo_dot_general GPU lowering requires matrices with sorted indices*" ): result_sparse = func(Xsp, Y) result_dense = operator.matmul(X, Y) 
self.assertAllClose(result_sparse, result_dense) # rdot_general with jtu.ignore_warning( category=CuSparseEfficiencyWarning, message= "bcoo_dot_general GPU lowering requires matrices with sorted indices*" ): result_sparse = func(Y, Xsp) result_dense = operator.matmul(Y, X) self.assertAllClose(result_sparse, result_dense) # spdot_general result_sparse = self.sparsify(operator.matmul)(Xsp, Ysp) result_dense = operator.matmul(X, Y) self.assertAllClose(result_sparse.todense(), result_dense) def testSparseAdd(self): x = BCOO.fromdense(jnp.arange(5)) y = BCOO.fromdense(2 * jnp.arange(5)) # Distinct indices out = self.sparsify(operator.add)(x, y) self.assertEqual(out.nse, 8) # uses concatenation. self.assertArraysEqual(out.todense(), 3 * jnp.arange(5)) # Shared indices – requires lower level call spenv = SparsifyEnv([x.indices, x.data, y.data]) spvalues = [ spenv.sparse(x.shape, data_ref=1, indices_ref=0), spenv.sparse(y.shape, data_ref=2, indices_ref=0) ] result = sparsify_raw(operator.add)(spenv, *spvalues) args_out, _ = result out, = spvalues_to_arrays(spenv, args_out) self.assertAllClose(out.todense(), x.todense() + y.todense()) def testSparseMul(self): x = BCOO.fromdense(jnp.arange(5)) y = BCOO.fromdense(2 * jnp.arange(5)) # Scalar multiplication out = self.sparsify(operator.mul)(x, 2.5) self.assertArraysEqual(out.todense(), x.todense() * 2.5) # Shared indices – requires lower level call spenv = SparsifyEnv([x.indices, x.data, y.data]) spvalues = [ spenv.sparse(x.shape, data_ref=1, indices_ref=0), spenv.sparse(y.shape, data_ref=2, indices_ref=0) ] result = sparsify_raw(operator.mul)(spenv, *spvalues) args_out, _ = result out, = spvalues_to_arrays(spenv, args_out) self.assertAllClose(out.todense(), x.todense() * y.todense()) def testSparseSubtract(self): x = BCOO.fromdense(3 * jnp.arange(5)) y = BCOO.fromdense(jnp.arange(5)) # Distinct indices out = self.sparsify(operator.sub)(x, y) self.assertEqual(out.nse, 8) # uses concatenation. 
self.assertArraysEqual(out.todense(), 2 * jnp.arange(5)) # Shared indices – requires lower level call spenv = SparsifyEnv([x.indices, x.data, y.data]) spvalues = [ spenv.sparse(x.shape, data_ref=1, indices_ref=0), spenv.sparse(y.shape, data_ref=2, indices_ref=0) ] result = sparsify_raw(operator.sub)(spenv, *spvalues) args_out, _ = result out, = spvalues_to_arrays(spenv, args_out) self.assertAllClose(out.todense(), x.todense() - y.todense()) def testSparseSum(self): x = jnp.arange(20).reshape(4, 5) xsp = BCOO.fromdense(x) def f(x): return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1)) result_dense = f(x) result_sparse = self.sparsify(f)(xsp) assert len(result_dense) == len(result_sparse) for res_dense, res_sparse in zip(result_dense, result_sparse): if isinstance(res_sparse, BCOO): res_sparse = res_sparse.todense() self.assertArraysAllClose(res_dense, res_sparse) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_dimensions={}_nbatch={}_ndense={}".format( jtu.format_shape_dtype_string(shape, np.float32), dimensions, n_batch, n_dense), "shape": shape, "dimensions": dimensions, "n_batch": n_batch, "n_dense": n_dense } for shape, dimensions in [ [(1, ), (0, )], [(1, ), (-1, )], [(2, 1, 4), (1, )], [(2, 1, 3, 1), (1, )], [(2, 1, 3, 1), (1, 3)], [(2, 1, 3, 1), (3, )], ] for n_batch in range(len(shape) + 1) for n_dense in range(len(shape) - n_batch + 1))) def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense): rng = jtu.rand_default(self.rng()) M_dense = rng(shape, np.float32) M_sparse = BCOO.fromdense(M_dense, n_batch=n_batch, n_dense=n_dense) func = self.sparsify(partial(lax.squeeze, dimensions=dimensions)) result_dense = func(M_dense) result_sparse = func(M_sparse).todense() self.assertAllClose(result_sparse, result_dense) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_shapes={shapes}_func={func}_nbatch={n_batch}", "shapes": shapes, "func": func, "n_batch": n_batch } for shapes, func, n_batch in [ ([(4, ), (4, )], "concatenate", 0), ([(4, ), (4, )], "stack", 0), ([(4, ), (4, )], "hstack", 0), ([(4, ), (4, )], "vstack", 0), ([(4, ), (4, )], "concatenate", 1), ([(4, ), (4, )], "stack", 1), ([(4, ), (4, )], "hstack", 1), ([(4, ), (4, )], "vstack", 1), ([(2, 4), (2, 4)], "stack", 0), ([(2, 4), (3, 4)], "vstack", 0), ([(2, 4), (2, 5)], "hstack", 0), ([(2, 4), (3, 4)], "vstack", 1), ([(2, 4), (2, 5)], "hstack", 1), ([(2, 4), (3, 4)], "vstack", 2), ([(2, 4), (2, 5)], "hstack", 2), ([(2, 4), (4, ), (3, 4)], "vstack", 0), ([(1, 4), (4, ), (1, 4)], "vstack", 0), ])) def testSparseConcatenate(self, shapes, func, n_batch): f = self.sparsify(getattr(jnp, func)) rng = jtu.rand_some_zero(self.rng()) arrs = [rng(shape, 'int32') for shape in shapes] sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs] self.assertArraysEqual(f(arrs), f(sparrs).todense()) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}", "shape": shape, "new_shape": new_shape, "n_batch": n_batch, "n_dense": n_dense } for shape, new_shape, n_batch, n_dense in [ [(6, ), (2, 3), 0, 0], [(1, 4), (2, 2), 0, 0], [(12, 2), (2, 3, 4), 0, 0], [(1, 3, 2), (2, 3), 0, 0], [(1, 6), (2, 3, 1), 0, 0], [(2, 3, 4), (3, 8), 0, 0], [(2, 3, 4), (1, 2, 12), 1, 0], [(2, 3, 4), (6, 2, 2), 2, 0], ])) def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense): rng = jtu.rand_some_zero(self.rng()) arr = rng(shape, 'int32') arr_sparse = BCOO.fromdense(arr, n_batch=n_batch, n_dense=n_dense) arr2 = 
arr.reshape(new_shape) arr2_sparse = arr_sparse.reshape(new_shape) self.assertArraysEqual(arr2, arr2_sparse.todense()) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}_dimensions={dimensions}", "shape": shape, "new_shape": new_shape, "n_batch": n_batch, "n_dense": n_dense, "dimensions": dimensions } for shape, new_shape, n_batch, n_dense, dimensions in [ [(2, 3, 4), (24, ), 0, 0, None], [(2, 3, 4), (24, ), 0, 0, (0, 1, 2)], [(2, 3, 4), (24, ), 0, 0, (0, 2, 1)], [(2, 3, 4), (24, ), 0, 0, (1, 0, 2)], [(2, 3, 4), (24, ), 0, 0, (1, 2, 0)], [(2, 3, 4), (24, ), 0, 0, (2, 0, 1)], [(2, 3, 4), (24, ), 0, 0, (2, 1, 0)], [(4, 2, 3), (2, 2, 6), 1, 0, (0, 1, 2)], [(4, 2, 3), (2, 2, 6), 1, 0, (0, 2, 1)], [(2, 3, 4), (6, 4), 2, 0, (0, 1, 2)], [(2, 3, 4), (6, 4), 2, 0, (1, 0, 2)], ])) def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch, n_dense, dimensions): rng = jtu.rand_some_zero(self.rng()) arr = rng(shape, 'int32') arr_sparse = BCOO.fromdense(arr, n_batch=n_batch, n_dense=n_dense) f = self.sparsify( lambda x: lax.reshape(x, new_shape, dimensions=dimensions)) arr2 = f(arr) arr2_sparse = f(arr_sparse) self.assertArraysEqual(arr2, arr2_sparse.todense()) def testSparseWhileLoop(self): def cond_fun(params): i, A = params return i < 5 def body_fun(params): i, A = params return i + 1, 2 * A def f(A): return lax.while_loop(cond_fun, body_fun, (0, A)) A = jnp.arange(4) out_dense = f(A) Asp = BCOO.fromdense(A) out_sparse = self.sparsify(f)(Asp) self.assertEqual(len(out_dense), 2) self.assertEqual(len(out_sparse), 2) self.assertArraysEqual(out_dense[0], out_dense[0]) self.assertArraysEqual(out_dense[1], out_sparse[1].todense()) def testSparseWhileLoopDuplicateIndices(self): def cond_fun(params): i, A, B = params return i < 5 def body_fun(params): i, A, B = params # TODO(jakevdp): track shared indices through while loop & use this # version of the test, which requires shared indices in order for # the nse of the result to remain the same. # return i + 1, A, A + B # This version is fine without shared indices, and tests that we're # flattening non-shared indices consistently. return i + 1, B, A def f(A): return lax.while_loop(cond_fun, body_fun, (0, A, A)) A = jnp.arange(4).reshape((2, 2)) out_dense = f(A) Asp = BCOO.fromdense(A) out_sparse = self.sparsify(f)(Asp) self.assertEqual(len(out_dense), 3) self.assertEqual(len(out_sparse), 3) self.assertArraysEqual(out_dense[0], out_dense[0]) self.assertArraysEqual(out_dense[1], out_sparse[1].todense()) self.assertArraysEqual(out_dense[2], out_sparse[2].todense()) def testSparsifyDenseXlaCall(self): # Test handling of dense xla_call within jaxpr interpreter. 
out = self.sparsify(jit(lambda x: x + 1))(0.0) self.assertEqual(out, 1.0) def testSparsifySparseXlaCall(self): # Test sparse lowering of XLA call def func(M): return 2 * M M = jnp.arange(6).reshape(2, 3) Msp = BCOO.fromdense(M) out_dense = func(M) out_sparse = self.sparsify(jit(func))(Msp) self.assertArraysEqual(out_dense, out_sparse.todense()) def testSparseForiLoop(self): def func(M, x): body_fun = lambda i, val: (M @ val) / M.shape[1] return lax.fori_loop(0, 2, body_fun, x) x = jnp.arange(5.0) M = jnp.arange(25).reshape(5, 5) M_bcoo = BCOO.fromdense(M) result_dense = func(M, x) with jtu.ignore_warning( category=CuSparseEfficiencyWarning, message= "bcoo_dot_general GPU lowering requires matrices with sorted indices*" ): result_sparse = self.sparsify(func)(M_bcoo, x) self.assertArraysAllClose(result_dense, result_sparse) def testSparseCondSimple(self): def func(x): return lax.cond(False, lambda x: x, lambda x: 2 * x, x) x = jnp.arange(5.0) result_dense = func(x) x_bcoo = BCOO.fromdense(x) result_sparse = self.sparsify(func)(x_bcoo) self.assertArraysAllClose(result_dense, result_sparse.todense()) def testSparseCondMismatchError(self): @self.sparsify def func(x, y): return lax.cond(False, lambda x: x[0], lambda x: x[1], (x, y)) x = jnp.arange(5.0) y = jnp.arange(5.0) x_bcoo = BCOO.fromdense(x) y_bcoo = BCOO.fromdense(y) func(x, y) # No error func(x_bcoo, y_bcoo) # No error with self.assertRaisesRegex( TypeError, "sparsified true_fun and false_fun output.*"): func(x_bcoo, y) def testToDense(self): M = jnp.arange(4) Msp = BCOO.fromdense(M) @self.sparsify def func(M): return todense(M) + 1 self.assertArraysEqual(func(M), M + 1) self.assertArraysEqual(func(Msp), M + 1) self.assertArraysEqual(jit(func)(M), M + 1) self.assertArraysEqual(jit(func)(Msp), M + 1) def testWeakTypes(self): # Regression test for https://github.com/google/jax/issues/8267 M = jnp.arange(12, dtype='int32').reshape(3, 4) Msp = BCOO.fromdense(M) self.assertArraysEqual( operator.mul(2, M), self.sparsify(operator.mul)(2, Msp).todense(), check_dtypes=True, )
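# --------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite above): SparsifyTest
# exercises jax.experimental.sparse.sparsify, which lets functions written
# against dense jax.numpy run on BCOO inputs. The helper below is a
# hypothetical usage example mirroring the matvec pattern tested above.
def _sparsify_usage_sketch():
  import jax.numpy as jnp
  from jax.experimental import sparse

  M = jnp.arange(12.0).reshape(3, 4)
  v = jnp.ones(4)
  M_sp = sparse.BCOO.fromdense(M)          # batched-COO copy of M

  @sparse.sparsify
  def f(mat, vec):
    return jnp.sin(mat) @ vec              # written against dense jax.numpy

  dense_out = f(M, v)                      # plain dense evaluation
  sparse_out = f(M_sp, v)                  # same function via sparse rules; result is dense here
  return dense_out, sparse_out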