def test_eval_shape_big_random_array(self):
    """eval_shape on a 1e12-element random.normal draw should not error.

    Skipped unless omnistaging is enabled.
    NOTE(review): a second method with this exact name appears later in this
    chunk; if both live in the same class, the later one replaces this one
    and this skip guard never runs — confirm which definition is intended.
    """
    if not config.omnistaging_enabled:
        raise SkipTest("after deleting lazy constants, requires omnistaging")

    def draw_normal(seed):
        prng_key = random.PRNGKey(seed)
        return random.normal(prng_key, (int(1e12),))

    # Checks are turned off: check_jaxpr would materialize the huge array.
    with core.skipping_checks():
        api.eval_shape(draw_normal, 0)  # must not raise
def testEnumPromotion(self):
    """IntEnum members convert to the same values as their plain ints.

    Checked for both onp (NumPy) and np, via array() and int32().
    """
    class AnEnum(enum.IntEnum):
        A = 42
        B = 101

    member_a = AnEnum.A
    member_b = AnEnum.B

    onp.testing.assert_equal(onp.array(42), onp.array(member_a))
    # Checks are skipped because handing AnEnum.A to np.array trips the
    # type check performed in bind.
    with core.skipping_checks():
        onp.testing.assert_equal(np.array(42), np.array(member_a))

    onp.testing.assert_equal(onp.int32(101), onp.int32(member_b))
    onp.testing.assert_equal(np.int32(101), np.int32(member_b))
def testStopGradient(self):
    """stop_gradient removes its argument from differentiation.

    A function using stop_gradient(x) must differentiate like an equivalent
    two-argument function differentiated only in its first argument. The
    check is done for first and second derivatives. A gradient taken through
    a stop_gradient'ed dict is zero, and passing a function raises TypeError.
    """
    def with_stop(x):
        return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def two_args(x, y):
        return lax.sin(x) * lax.cos(y)

    def grad_once(fun):
        return api.grad(fun)

    def grad_twice(fun):
        return api.grad(api.grad(fun))

    point = 3.14
    for derivative in (grad_once, grad_twice):
        ans = derivative(with_stop)(point)
        expected = derivative(two_args)(point, point)
        self.assertAllClose(ans, expected)

    # Gradient through a stop_gradient'ed pytree (a dict) is zero.
    ans = api.grad(lambda x: lax.stop_gradient({'foo': x})['foo'])(3.)
    expected = onp.array(0.0)
    self.assertAllClose(ans, expected, check_dtypes=False)

    # Passing a function (not an array) to stop_gradient raises TypeError.
    with core.skipping_checks(), self.assertRaises(TypeError):
        lax.stop_gradient(lambda x: x)
def testHardTanhMemory(self):
    """Tracing nn.hard_tanh over a 10**12-element array should not OOM.

    See https://github.com/google/jax/pull/1640.
    """
    # see https://github.com/google/jax/pull/1640
    with core.skipping_checks():  # With checks we materialize the array
        # jax.make_jaxpr returns a function; it has to be *called* for the
        # lambda to be traced. Previously the returned function was discarded
        # uncalled, so nn.hard_tanh was never exercised and the test passed
        # vacuously.
        # NOTE(review): relies on make_jaxpr staging out the nullary lambda's
        # ops (omnistaging) so jnp.ones is never materialized — confirm it is
        # enabled wherever this test runs.
        jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10**12,))))()  # don't oom
def test_eval_shape_big_random_array(self):
    """eval_shape on a 1e12-element random.normal draw should not error.

    NOTE(review): another method with this exact name appears earlier in this
    chunk; if both live in the same class, this one replaces the earlier one
    — confirm the duplication is intended.
    """
    def draw_normal(seed):
        prng_key = random.PRNGKey(seed)
        return random.normal(prng_key, (int(1e12),))

    # Checks are turned off: check_jaxpr would materialize the huge array.
    with core.skipping_checks():
        api.eval_shape(draw_normal, 0)  # must not raise
def testEluMemory(self):
    """Tracing nn.elu on a 10**12-element input should not run out of memory.

    See https://github.com/google/jax/pull/1640.
    """
    # Checks are disabled because with them the array would be materialized.
    with core.skipping_checks():
        trace_elu = jax.make_jaxpr(nn.elu)
        # NOTE(review): this input is created outside the traced function, so
        # avoiding a huge allocation depends on jnp.ones not materializing it
        # eagerly here — confirm.
        huge_input = jnp.ones((10**12, ))
        trace_elu(huge_input)  # don't oom