Esempio n. 1
0
  def test_linearize(self):
    # Linearizes a djit-compiled sin-then-sum over a bbarray input and checks
    # the primal output and the tangent output against plain jax.jvp on the
    # equivalent dense computation.
    @djax.djit
    def f(x):
      return reduce_sum(sin(x), axes=(0,))

    # NOTE(review): bbarray((5,), ...) appears to wrap a 2-element array with a
    # size bound of 5 — confirm against the djax helpers.
    bounded_input = bbarray((5,), jnp.arange(2.))
    # TODO: implement the dxla_call abstract-eval rule; until it exists,
    # checks must stay disabled while linearizing.
    with jax.enable_checks(False):
      primal_out, linear_fn = jax.linearize(f, bounded_input)
    tangent_out = linear_fn(ones_like(bounded_input))

    # Dense reference: the same computation on an ordinary 2-element array.
    reference_primal, reference_tangent = jax.jvp(
        lambda v: jnp.sin(v).sum(), (np.arange(2.),), (np.ones(2),))

    self.assertAllClose(
        np.array(primal_out), reference_primal, check_dtypes=False)
    self.assertAllClose(
        np.array(tangent_out), reference_tangent, check_dtypes=False)
Esempio n. 2
0
  def testStopGradient(self):
    # `with_stop` blocks differentiation through its second use of x.
    # `two_args` makes that second use an independent argument instead.
    # Differentiating only w.r.t. the first argument of `two_args` must give
    # the same first and second derivatives as differentiating `with_stop`.
    def with_stop(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def two_args(x, y):
      return lax.sin(x) * lax.cos(y)

    point = 3.14
    derivative_pairs = [
        (api.grad(with_stop), api.grad(two_args)),
        (api.grad(api.grad(with_stop)), api.grad(api.grad(two_args))),
    ]
    for actual_fn, reference_fn in derivative_pairs:
      self.assertAllClose(actual_fn(point), reference_fn(point, point))

    # stop_gradient also accepts a pytree (here a dict); the gradient flowing
    # through it is zero.
    zero_grad = api.grad(lambda v: lax.stop_gradient({'foo': v})['foo'])(3.)
    self.assertAllClose(zero_grad, np.array(0.0), check_dtypes=False)

    # A non-array argument (a Python function) must raise TypeError.
    with jax.enable_checks(False), self.assertRaises(TypeError):
      lax.stop_gradient(lambda x: x)
Esempio n. 3
0
File: nn_test.py Progetto: gtr8/jax
 def testHardTanhMemory(self):
   # Regression test for https://github.com/google/jax/pull/1640: tracing
   # hard_tanh on a 10**12-element input must not run out of memory.
   with jax.enable_checks(False):  # With checks we materialize the array
     # make_jaxpr returns a function, and tracing only happens when that
     # function is called. Without the trailing call nothing was traced, so
     # the test exercised nothing.
     closed_jaxpr = jax.make_jaxpr(
         lambda: nn.hard_tanh(jnp.ones((10 ** 12,))))()  # don't oom
   # The traced output should keep the full input shape.
   self.assertEqual(closed_jaxpr.out_avals[0].shape, (10 ** 12,))