Ejemplo n.º 1
0
    def test_sum_of_nonzero_typechecks(self):
        def f(x):
            return reduce_sum(nonzero(x), tuple(range(len(x.shape))))

        x = jnp.array([1, 0, -2, 0, 3, 0])
        jaxpr, _, _ = djax.make_djaxpr(f, x)
        djax.typecheck_jaxpr(jaxpr)
Ejemplo n.º 2
0
    def test_nonzero_typechecks(self):
        def f(x):
            return nonzero(x)

        x = jnp.array([1, 0, -2, 0, 3, 0])
        jaxpr, _, _ = djax.make_djaxpr(f, x)
        djax.typecheck_jaxpr(jaxpr)
Ejemplo n.º 3
0
    def test_sin_typechecks(self):
        def f(x):
            return sin(x)

        x = bbarray((5, ), jnp.arange(3.))
        jaxpr, _, _ = djax.make_djaxpr(f, x)
        djax.typecheck_jaxpr(jaxpr)
Ejemplo n.º 4
0
    def test_identity_typechecks(self):
        def f(x):
            return x

        x = jnp.array([0, 1])
        jaxpr, _, _ = djax.make_djaxpr(f, x)
        djax.typecheck_jaxpr(jaxpr)
Ejemplo n.º 5
0
 def test_sin_and_add_typechecks(self):
   def f(x):
     y = sin(x)
     z = sin(y)
     return add(y, z)
   x = bbarray((5,), jnp.arange(3.))
   jaxpr, _, _ = djax.make_djaxpr(f, x)
   djax.typecheck_jaxpr(jaxpr)
Ejemplo n.º 6
0
 def test_iota_typechecks(self):
   def f():
     return iota(3)
   jaxpr, _, _ = djax.make_djaxpr(f)
   djax.typecheck_jaxpr(jaxpr)