Ejemplo n.º 1
0
 def np_fn(op, indexer, x, y):
   """NumPy reference for an indexed update of ``x`` at ``indexer`` by ``y``.

   Works on a copy, so the caller's ``x`` is left untouched. ``op`` (an
   ``UpdateOps`` member) selects how the existing values ``x[indexer]`` are
   combined with ``y``; an ``op`` that is not handled here raises ``KeyError``.
   Returns the updated copy.
   """
   out = x.copy()

   def current():
     return out[indexer]

   # Division and power may emit RuntimeWarnings (presumably divide-by-zero /
   # invalid-value cases); these are deliberately silenced for the reference.
   @jtu.ignore_warning(category=RuntimeWarning)
   def divide():
     return current() / y.astype(out.dtype)

   @jtu.ignore_warning(category=RuntimeWarning)
   def power():
     return current() ** y.astype(out.dtype)

   handlers = {
     UpdateOps.UPDATE: lambda: y,
     UpdateOps.ADD: lambda: current() + y,
     UpdateOps.MUL: lambda: current() * y,
     UpdateOps.DIV: divide,
     UpdateOps.POW: power,
     UpdateOps.MIN: lambda: np.minimum(current(), y),
     UpdateOps.MAX: lambda: np.maximum(current(), y),
   }
   out[indexer] = handlers[op]()
   return out
Ejemplo n.º 2
0
    def testRngRandomBits(self):
        """Pin ``_random_bits`` output for a fixed key.

        Draws three values at 8, 16, 32 and 64 bits from ``PRNGKey(1701)`` and
        compares them to hard-coded arrays, so that any change in the generated
        bits between JAX versions is caught. The 64-bit expectation depends on
        ``config.x64_enabled``: uint64 values when enabled, uint32 otherwise.
        """
        key = random.PRNGKey(1701)
        random_bits = jax._src.random._random_bits

        fixed_width_cases = (
            (8, np.array([216, 115, 43], dtype=np.uint8)),
            (16, np.array([41682, 1300, 55017], dtype=np.uint16)),
            (32, np.array([56197195, 4200222568, 961309823],
                          dtype=np.uint32)),
        )
        for width, expected in fixed_width_cases:
            self.assertArraysEqual(random_bits(key, width, (3,)), expected)

        # The 64-bit request may warn with "Explicitly requested dtype ...";
        # that UserWarning is suppressed here.
        with jtu.ignore_warning(category=UserWarning,
                                message="Explicitly requested dtype.*"):
            wide = random_bits(key, 64, (3,))

        if config.x64_enabled:
            wide_expected = np.array(
                [3982329540505020460, 16822122385914693683,
                 7882654074788531506],
                dtype=np.uint64)
        else:
            wide_expected = np.array([676898860, 3164047411, 4010691890],
                                     dtype=np.uint32)
        self.assertArraysEqual(wide, wide_expected)
Ejemplo n.º 3
0
  def test_minimize(self, maxiter, func_and_init):
    """Compare the experimental JAX L-BFGS minimizer with SciPy's L-BFGS-B.

    ``func_and_init`` is a ``(factory, x0)`` pair: the factory is called with
    ``jnp`` or ``np`` to obtain the objective for that array namespace, and
    ``x0`` is the starting point. A few objectives get special treatment in
    the comparison; see the branches at the end.
    """
    objective, initial = func_and_init

    @jit
    def min_op(start):
      options = dict(maxiter=maxiter, gtol=1e-7)
      solution = jax.scipy.optimize.minimize(
          objective(jnp),
          start,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=options,
      )
      return solution.x

    jax_res = min_op(initial)

    # With no bounds given, L-BFGS-B behaves as plain L-BFGS, so SciPy serves
    # as the reference. DeprecationWarnings matching the "tostring ... is
    # deprecated" pattern raised during this call are silenced.
    with jtu.ignore_warning(category=DeprecationWarning,
                            message=".*tostring.*is deprecated.*"):
      scipy_res = scipy.optimize.minimize(
          objective(np), initial, method='L-BFGS-B').x

    name = objective.__name__
    if name == 'matyas':
      # SciPy does poorly on Matyas, so check against the true minimum (the
      # zero vector) instead of SciPy's answer.
      self.assertAllClose(jax_res, jnp.zeros_like(jax_res), atol=1e-7)
    elif name == 'eggholder':
      # Neither SciPy nor JAX reaches the true eggholder minimum; the two
      # (incorrect) results are only compared with a loose tolerance.
      self.assertAllClose(jax_res, scipy_res, atol=1e-3)
    else:
      self.assertAllClose(jax_res, scipy_res, atol=2e-5, check_dtypes=False)
Ejemplo n.º 4
0
 def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory):
   """Gradient-check ``lax.convert_element_type`` from one dtype to another.

   Builds a random ``(2, 3)`` array of ``from_dtype`` and runs ``check_grads``
   up to order 2 in both forward and reverse mode. The tolerance is the larger
   of the default gradient tolerances for ``from_dtype`` and ``to_dtype``.
   """
   rng = rng_factory(self.rng())
   tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
             jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
   args = (rng((2, 3), from_dtype),)

   # NumPy emits ComplexWarning when a complex value is cast to a real dtype;
   # that warning is expected here and ignored. The class lives in
   # ``numpy.exceptions`` since NumPy 1.25 and was removed from the top-level
   # namespace in NumPy 2.0, so prefer the new location and fall back to the
   # top-level name on older NumPy.
   complex_warning = getattr(onp, "exceptions", onp).ComplexWarning

   @jtu.ignore_warning(category=complex_warning)
   def convert_element_type(x):
     return lax.convert_element_type(x, to_dtype)

   check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)