예제 #1
0
파일: djax_test.py 프로젝트: yashk2810/jax
  def test_jvp(self):
    """Compare jax.jvp of a djit-wrapped function with a plain jnp reference.

    The djit-wrapped function applies `sin` followed by `reduce_sum` over
    axis 0; the reference computes `jnp.sin(x).sum()` on NumPy inputs built
    from the same values. Primal and tangent outputs must match (dtypes are
    not compared).
    """
    def reference(v):
      return jnp.sin(v).sum()

    @djax.djit
    def summed_sin(v):
      return reduce_sum(sin(v), axes=(0,))

    # NOTE(review): presumably (5,) is a size bound on the wrapped data;
    # confirm against the bbarray definition.
    operand = bbarray((5,), jnp.arange(2.))
    tangent = ones_like(operand)
    z, z_dot = jax.jvp(summed_sin, (operand,), (tangent,))

    # Reference (primal, tangent) pair from ordinary NumPy arrays.
    expected = jax.jvp(reference, (np.arange(2.),), (np.ones(2),))

    # Check the primal output first, then the tangent output.
    for actual, wanted in zip((z, z_dot), expected):
      self.assertAllClose(np.array(actual), wanted, check_dtypes=False)
예제 #2
0
파일: djax_test.py 프로젝트: yashk2810/jax
  def test_linearize(self):
    """Compare jax.linearize of a djit-wrapped function with a jvp reference.

    The djit-wrapped function applies `sin` followed by `reduce_sum` over
    axis 0. Its linearized primal output and the linear map applied to a
    ones tangent are checked against `jax.jvp` of `jnp.sin(x).sum()` on
    NumPy inputs built from the same values (dtypes are not compared).
    """
    def reference(v):
      return jnp.sin(v).sum()

    @djax.djit
    def summed_sin(v):
      return reduce_sum(sin(v), axes=(0,))

    # NOTE(review): presumably (5,) is a size bound on the wrapped data;
    # confirm against the bbarray definition.
    operand = bbarray((5,), jnp.arange(2.))

    # Only the linearize call runs with checks disabled.
    with jax.enable_checks(False):  # TODO implement dxla_call abs eval rule
      z, tangent_map = jax.linearize(summed_sin, operand)
    z_dot = tangent_map(ones_like(operand))

    # Reference (primal, tangent) pair from ordinary NumPy arrays.
    expected = jax.jvp(reference, (np.arange(2.),), (np.ones(2),))

    # Check the primal output first, then the tangent output.
    for actual, wanted in zip((z, z_dot), expected):
      self.assertAllClose(np.array(actual), wanted, check_dtypes=False)