Example #1
0
    def test_sin_typechecks(self):
        """Trace a single `sin` call with djax and check the djaxpr typechecks."""
        def apply_sin(arr):
            return sin(arr)

        # NOTE(review): bbarray((5,), data) presumably wraps 3 values in a
        # buffer of max shape (5,) — confirm against bbarray's definition.
        padded = bbarray((5,), jnp.arange(3.))
        traced, _, _ = djax.make_djaxpr(apply_sin, padded)
        djax.typecheck_jaxpr(traced)
Example #2
0
 def test_sin_and_add_typechecks(self):
   """Trace sin -> sin -> add with djax and check the djaxpr typechecks."""
   def sin_then_add(arr):
     once = sin(arr)
     twice = sin(once)
     # Reuse the first result so the traced graph has a shared intermediate.
     return add(once, twice)

   padded = bbarray((5,), jnp.arange(3.))
   traced, _, _ = djax.make_djaxpr(sin_then_add, padded)
   djax.typecheck_jaxpr(traced)
Example #3
0
  def test_jvp(self):
    """Check jax.jvp through a djax.djit function against plain jnp code.

    The djit function sums sin over axis 0. The reference computes
    jnp.sin(x).sum() on the same 2 values with an all-ones tangent.
    """
    @djax.djit
    def summed_sin(arr):
      return reduce_sum(sin(arr), axes=(0,))

    primal_in = bbarray((5,), jnp.arange(2.))
    tangent_in = ones_like(primal_in)
    z, z_dot = jax.jvp(summed_sin, (primal_in,), (tangent_in,))

    def reference(arr):
      return jnp.sin(arr).sum()
    want_z, want_z_dot = jax.jvp(reference, (np.arange(2.),), (np.ones(2),))

    # Compare primal output first, then tangent output.
    for got, want in ((z, want_z), (z_dot, want_z_dot)):
      self.assertAllClose(np.array(got), want, check_dtypes=False)
Example #4
0
  def test_linearize(self):
    """Check jax.linearize through a djax.djit function against plain jnp.

    The linearized function is applied to an all-ones tangent. The primal
    and tangent results are compared with jax.jvp of jnp.sin(x).sum().
    """
    @djax.djit
    def summed_sin(arr):
      return reduce_sum(sin(arr), axes=(0,))

    primal_in = bbarray((5,), jnp.arange(2.))
    # Checks are disabled only around linearize itself.
    with jax.enable_checks(False):  # TODO implement dxla_call abs eval rule
      z, lin_fn = jax.linearize(summed_sin, primal_in)
    z_dot = lin_fn(ones_like(primal_in))

    def reference(arr):
      return jnp.sin(arr).sum()
    want_z, want_z_dot = jax.jvp(reference, (np.arange(2.),), (np.ones(2),))

    for got, want in ((z, want_z), (z_dot, want_z_dot)):
      self.assertAllClose(np.array(got), want, check_dtypes=False)