Exemplo n.º 1
0
 def test_eval_shape_big_random_array(self):
   # eval_shape of a 1e12-element random.normal draw should not error.
   if not config.omnistaging_enabled:
     raise SkipTest("after deleting lazy constants, requires omnistaging")
   big_shape = (int(1e12),)

   def draw(seed):
     key = random.PRNGKey(seed)
     return random.normal(key, big_shape)

   # Checks are disabled because check_jaxpr would materialize the array.
   with core.skipping_checks():
     api.eval_shape(draw, 0)  # doesn't error
Exemplo n.º 2
0
 def testEnumPromotion(self):
   # IntEnum members should convert exactly like their integer values, both
   # through onp and through np.
   class AnEnum(enum.IntEnum):
     A = 42
     B = 101

   member_a, member_b = AnEnum.A, AnEnum.B
   onp.testing.assert_equal(onp.array(42), onp.array(member_a))
   # np.array(AnEnum.A) fails the type check performed in bind, so that
   # conversion runs with checks disabled.
   with core.skipping_checks():
     onp.testing.assert_equal(np.array(42), np.array(member_a))
   onp.testing.assert_equal(onp.int32(101), onp.int32(member_b))
   onp.testing.assert_equal(np.int32(101), np.int32(member_b))
Exemplo n.º 3
0
  def testStopGradient(self):
    # stop_gradient should make its argument behave like an independent,
    # non-differentiated input.
    def with_stop(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def decoupled(x, y):
      return lax.sin(x) * lax.cos(y)

    point = 3.14
    # First- and second-order derivatives must match those of the version in
    # which the stopped argument is a separate input.
    first = api.grad(with_stop)
    second = api.grad(first)
    self.assertAllClose(first(point), api.grad(decoupled)(point, point))
    self.assertAllClose(second(point),
                        api.grad(api.grad(decoupled))(point, point))

    # stop_gradient also accepts pytrees; no gradient flows through it.
    tree_grad = api.grad(lambda v: lax.stop_gradient({'foo': v})['foo'])(3.)
    self.assertAllClose(tree_grad, onp.array(0.0), check_dtypes=False)

    # A non-array argument such as a function is rejected with TypeError.
    with core.skipping_checks(), self.assertRaises(TypeError):
      lax.stop_gradient(lambda x: x)
Exemplo n.º 4
0
 def testHardTanhMemory(self):
     # Tracing hard_tanh over a 10**12-element array must not run out of
     # memory.
     # see https://github.com/google/jax/pull/1640
     # NOTE(review): the array is created inside the traced function, so this
     # relies on jnp.ones being staged into the jaxpr rather than evaluated
     # eagerly (omnistaging) -- confirm for the JAX version under test.
     with core.skipping_checks():  # With checks we materialize the array
         # make_jaxpr returns a function; it has to be called for any tracing
         # to happen. Without the trailing () nothing is exercised at all.
         jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10**12,))))()  # don't oom
Exemplo n.º 5
0
 def test_eval_shape_big_random_array(self):
   # Abstractly evaluating a 1e12-element random.normal draw should succeed.
   def sample_normal(seed):
     prng_key = random.PRNGKey(seed)
     return random.normal(prng_key, (10**12,))

   # check_jaxpr would materialize the array, so checks are switched off.
   with core.skipping_checks():
     api.eval_shape(sample_normal, 0)  # doesn't error
Exemplo n.º 6
0
 def testEluMemory(self):
     # Tracing elu over a 10**12-element input must not run out of memory.
     # see https://github.com/google/jax/pull/1640
     huge_shape = (10**12,)
     # With checks enabled the array would be materialized.
     with core.skipping_checks():
         trace_elu = jax.make_jaxpr(nn.elu)
         trace_elu(jnp.ones(huge_shape))  # don't oom