def testQdwhWithRandomMatrix(self, m, n, log_cond):
  """Checks qdwh's polar factors on a random matrix of controlled conditioning.

  A random (m, n) matrix is drawn, its singular values are replaced by an
  evenly spaced spectrum running from 10**log_cond down to 1, and qdwh is then
  checked for jit compatibility and against the reference ``osp_linalg.polar``.
  """
  sampler = jtu.rand_uniform(self.rng(), low=0.3, high=0.9)
  base = sampler((m, n), _QDWH_TEST_DTYPE)
  # Keep the singular vectors of the random draw but swap in a prescribed
  # spectrum, so the rebuilt matrix has condition number 10**log_cond.
  left, _, right = jnp.linalg.svd(base, full_matrices=False)
  spectrum = jnp.linspace(10**log_cond, 1, min(m, n))
  matrix = (left * spectrum) @ right
  is_symmetric = _check_symmetry(matrix)
  max_iterations = 10

  def lsp_linalg_fn(x):
    # qdwh returns four values; only the first two (the polar factors
    # compared against osp_linalg.polar) are checked here.
    u_factor, h_factor, _, _ = qdwh.qdwh(
        x, is_symmetric=is_symmetric, max_iterations=max_iterations)
    return u_factor, h_factor

  def args_maker():
    return [matrix]

  # Relative tolerance expressed as a multiple of the test dtype's epsilon.
  rtol = 1E6 * _QDWH_TEST_EPS

  with self.subTest('Test JIT compatibility'):
    self._CompileAndCheck(lsp_linalg_fn, args_maker)

  with self.subTest('Test against numpy.'):
    self._CheckAgainstNumpy(osp_linalg.polar, lsp_linalg_fn, args_maker,
                            rtol=rtol, atol=1E-3)
def binary_check(self, fun, lims=None, order=3, finite=True, dtype=None):
  """Runs a jet check on the binary function ``fun`` with random inputs.

  Args:
    fun: binary function under test.
    lims: ``[low, high]`` limits shared by both inputs, or a tuple
      ``(x_lims, y_lims)`` giving separate limits for each argument.
      ``None`` (the default) means ``[-2, 2]``.
    order: number of series terms generated per argument.
    finite: if True, compare with ``self.check_jet``; otherwise compare with
      ``self.check_jet_finite``.
    dtype: if None, the primals are ``rng.rand`` samples passed through
      ``transform`` with their limits, and the series terms are ``rng.randn``
      samples. Otherwise the primals and all series terms are drawn with
      ``jtu.rand_uniform`` over the per-argument limits in this dtype.
  """
  if lims is None:
    # A None sentinel avoids sharing one mutable list default across calls.
    lims = [-2, 2]
  dims = 2, 3
  # Fixed seed keeps the generated inputs reproducible across runs.
  rng = np.random.RandomState(0)
  if isinstance(lims, tuple):
    x_lims, y_lims = lims
  else:
    x_lims, y_lims = lims, lims
  if dtype is None:
    primal_in = (transform(x_lims, rng.rand(*dims)),
                 transform(y_lims, rng.rand(*dims)))
    series_in = ([rng.randn(*dims) for _ in range(order)],
                 [rng.randn(*dims) for _ in range(order)])
  else:
    # One sampler per argument, so tuple-valued (per-argument) limits are
    # honoured. Unpacking the whole tuple would pass two lists as low/high.
    # Both samplers wrap the same RandomState and are called in the same
    # order as before, so list-valued limits produce identical draws.
    x_rng = jtu.rand_uniform(rng, *x_lims)
    y_rng = jtu.rand_uniform(rng, *y_lims)
    primal_in = (x_rng(dims, dtype), y_rng(dims, dtype))
    series_in = ([x_rng(dims, dtype) for _ in range(order)],
                 [y_rng(dims, dtype) for _ in range(order)])
  if finite:
    self.check_jet(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
  else:
    self.check_jet_finite(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
def unary_check(self, fun, lims=None, order=3, dtype=None):
  """Runs ``self.check_jet`` on the unary function ``fun`` with random inputs.

  Inputs have shape (2, 3), and the comparison uses atol = rtol = 1e-4.

  Args:
    fun: unary function under test.
    lims: ``[low, high]`` limits for the input. ``None`` (the default)
      means ``[-2, 2]``.
    order: number of series terms generated.
    dtype: if None, the primal is an ``rng.rand`` sample passed through
      ``transform`` with ``lims``, and the series terms are ``rng.randn``
      samples. Otherwise the primal and every series term are drawn with
      ``jtu.rand_uniform`` over ``lims`` in this dtype.
  """
  if lims is None:
    # A None sentinel avoids sharing one mutable list default across calls.
    lims = [-2, 2]
  dims = 2, 3
  # Fixed seed keeps the generated inputs reproducible across runs.
  rng = np.random.RandomState(0)
  if dtype is None:
    primal_in = transform(lims, rng.rand(*dims))
    terms_in = [rng.randn(*dims) for _ in range(order)]
  else:
    sampler = jtu.rand_uniform(rng, *lims)
    primal_in = sampler(dims, dtype)
    terms_in = [sampler(dims, dtype) for _ in range(order)]
  self.check_jet(fun, (primal_in,), (terms_in,), atol=1e-4, rtol=1e-4)
def testResizeAgainstPIL(self, dtype, image_shape, target_shape, method):
  """Compares antialiased ``image.resize`` with PIL's resize for ``method``."""
  sampler = jtu.rand_uniform(self.rng())

  def make_args():
    return (sampler(image_shape, dtype), )

  def pil_fn(x):
    # Map JAX method names to the corresponding PIL resampling filters.
    resample_filters = {
        "bilinear": PIL_Image.BILINEAR,
        "bicubic": PIL_Image.BICUBIC,
        "lanczos3": PIL_Image.LANCZOS,
    }
    source = PIL_Image.fromarray(x.astype(np.float32))
    # PIL's resize takes (width, height); presumably target_shape is
    # (height, width), hence the reversal — confirm against callers.
    resized = source.resize(target_shape[::-1], resample_filters[method])
    return np.asarray(resized, dtype=dtype)

  jax_fn = partial(image.resize, shape=target_shape, method=method,
                   antialias=True)
  self._CheckAgainstNumpy(pil_fn, jax_fn, make_args, check_dtypes=True)
test_name = test_name or name return OpRecord(name, nargs, dtypes, rng, test_grad, test_name) JAX_SPECIAL_FUNCTION_RECORDS = [ # TODO: digamma has no JVP implemented. op_record("digamma", 1, float_dtypes, jtu.rand_positive(), False), op_record("erf", 1, float_dtypes, jtu.rand_small_positive(), True), op_record("erfc", 1, float_dtypes, jtu.rand_small_positive(), True), op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive(), True), op_record("expit", 1, float_dtypes, jtu.rand_small_positive(), True), # TODO: gammaln has slightly high error. op_record("gammaln", 1, float_dtypes, jtu.rand_positive(), False), op_record("logit", 1, float_dtypes, jtu.rand_small_positive(), False), op_record("log_ndtr", 1, float_dtypes, jtu.rand_small(), True), op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0., 1.), True), op_record("ndtr", 1, float_dtypes, jtu.rand_default(), True), ] CombosWithReplacement = itertools.combinations_with_replacement class LaxBackedScipyTests(jtu.JaxTestCase): """Tests for LAX-backed Scipy implementation.""" def _GetArgsMaker(self, rng, shapes, dtypes): return lambda: [ rng(shape, dtype) for shape, dtype in zip(shapes, dtypes) ] @parameterized.named_parameters( jtu.cases_from_list({
op_record("less_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("maximum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("minimum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]), op_record("power", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("tan", 1, number_dtypes, all_shapes, jtu.rand_uniform(-1.5, 1.5), ["rev"]), op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("arcsin", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arccos", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arctan", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arctanh", 1, 
number_dtypes, all_shapes, jtu.rand_small(), ["rev"]), ] JAX_COMPOUND_OP_RECORDS = [ op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),