def np_fn(op, indexer, x, y):
  """NumPy reference for an indexed update of `x` at `indexer` with `y`.

  The input `x` is left untouched: a copy is made, the slice selected by
  `indexer` is replaced according to `op`, and the copy is returned.

  Args:
    op: an `UpdateOps` member selecting how `y` is combined with the
      current contents of `x[indexer]`.
    indexer: any index expression accepted by NumPy `__getitem__` /
      `__setitem__`.
    x: the array to update (copied, not mutated).
    y: the update values.

  Returns:
    The updated copy of `x`.

  Raises:
    KeyError: if `op` is not one of the handled `UpdateOps` members.
  """
  updated = x.copy()

  def current():
    # Read lazily so ops that ignore the old contents (UPDATE) never index it.
    return updated[indexer]

  def y_as_x_dtype():
    return y.astype(updated.dtype)

  # Pick the combiner first; the assignment below only happens once an op
  # has been resolved. DIV and POW suppress RuntimeWarnings (e.g. from
  # division by zero or invalid powers) raised while computing the reference.
  combine = {
      UpdateOps.UPDATE: lambda: y,
      UpdateOps.ADD: lambda: current() + y,
      UpdateOps.MUL: lambda: current() * y,
      UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)(
          lambda: current() / y_as_x_dtype()),
      UpdateOps.POW: jtu.ignore_warning(category=RuntimeWarning)(
          lambda: current() ** y_as_x_dtype()),
      UpdateOps.MIN: lambda: np.minimum(current(), y),
      UpdateOps.MAX: lambda: np.maximum(current(), y),
  }[op]
  updated[indexer] = combine()
  return updated
def testRngRandomBits(self):
  """Pin the raw output of `_random_bits` for a fixed key at each bit width.

  The expected arrays are hard-coded so that random values stay consistent
  between JAX versions; any change to bit generation shows up here.
  """
  key = random.PRNGKey(1701)
  random_bits = jax._src.random._random_bits

  # (bit width, expected values, expected dtype) for the widths that need no
  # special handling.
  golden = [
      (8, [216, 115, 43], np.uint8),
      (16, [41682, 1300, 55017], np.uint16),
      (32, [56197195, 4200222568, 961309823], np.uint32),
  ]
  for width, values, dtype in golden:
    self.assertArraysEqual(random_bits(key, width, (3,)),
                           np.array(values, dtype=dtype))

  # Requesting 64 bits may emit an "Explicitly requested dtype" UserWarning;
  # it is silenced for the call only.
  with jtu.ignore_warning(category=UserWarning,
                          message="Explicitly requested dtype.*"):
    wide = random_bits(key, 64, (3,))

  # With x64 disabled the expected result is a uint32 array instead.
  expected_wide = (
      np.array([3982329540505020460, 16822122385914693683,
                7882654074788531506], dtype=np.uint64)
      if config.x64_enabled
      else np.array([676898860, 3164047411, 4010691890], dtype=np.uint32))
  self.assertArraysEqual(wide, expected_wide)
def test_minimize(self, maxiter, func_and_init):
  """Compare JAX's experimental L-BFGS minimizer with SciPy's L-BFGS-B.

  Args:
    maxiter: iteration cap passed to the JAX minimizer.
    func_and_init: pair of (objective factory, initial point). The factory
      is called with a NumPy-like module (`jnp` or `np`) and returns the
      objective written against that module.
  """
  func, x0 = func_and_init

  @jit
  def run_jax_lbfgs(init):
    result = jax.scipy.optimize.minimize(
        func(jnp),
        init,
        method='l-bfgs-experimental-do-not-rely-on-this',
        options={'maxiter': maxiter, 'gtol': 1e-7},
    )
    return result.x

  jax_x = run_jax_lbfgs(x0)

  # With no bounds supplied, L-BFGS-B is just L-BFGS, so SciPy's L-BFGS-B
  # serves as the reference. A ".*tostring.*is deprecated.*"
  # DeprecationWarning is silenced around the SciPy call.
  with jtu.ignore_warning(category=DeprecationWarning,
                          message=".*tostring.*is deprecated.*"):
    scipy_x = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x

  name = func.__name__
  if name == 'matyas':
    # SciPy performs badly on Matyas, so check against the true minimum
    # (the origin) instead of the SciPy result.
    self.assertAllClose(jax_x, jnp.zeros_like(jax_x), atol=1e-7)
  elif name == 'eggholder':
    # Neither SciPy nor JAX finds the true minimum of eggholder with L-BFGS,
    # so the two (false) results are only compared loosely.
    self.assertAllClose(jax_x, scipy_x, atol=1e-3)
  else:
    self.assertAllClose(jax_x, scipy_x, atol=2e-5, check_dtypes=False)
def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory):
  """Check forward- and reverse-mode gradients of `lax.convert_element_type`.

  Gradients up to order 2 are checked on a (2, 3) input of `from_dtype`
  converted to `to_dtype`. The looser of the two dtypes' default gradient
  tolerances is used for both atol and rtol.

  Args:
    from_dtype: dtype of the generated input.
    to_dtype: dtype passed to `lax.convert_element_type`.
    rng_factory: factory that, given the test's RNG, returns a callable
      `(shape, dtype) -> array` for generating inputs.
  """
  rng = rng_factory(self.rng())
  tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
            jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
  args = (rng((2, 3), from_dtype),)

  # NumPy 1.25 moved ComplexWarning into `numpy.exceptions`, and NumPy 2.0
  # removed the top-level `numpy.ComplexWarning` alias, so `onp.ComplexWarning`
  # raises AttributeError on current NumPy. Prefer the new location and fall
  # back to the top-level name on older releases.
  complex_warning = getattr(onp, "exceptions", onp).ComplexWarning

  # ComplexWarning is what NumPy emits when a complex value is cast to a real
  # dtype (discarding the imaginary part); presumably it fires here for
  # complex -> real parameterizations, which are still valid to differentiate.
  convert_element_type = jtu.ignore_warning(category=complex_warning)(
      lambda x: lax.convert_element_type(x, to_dtype))

  check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)