def test_sum_of_nonzero_typechecks(self):
  """Reducing the output of `nonzero` over all axes yields a jaxpr that typechecks."""
  def sum_nonzero(arr):
    # Trace `nonzero` first, then build the axes tuple, mirroring the
    # original argument-evaluation order.
    picked = nonzero(arr)
    # Axes are taken from the *input's* rank; presumably this matches the
    # rank of `nonzero`'s output for this 1-D input -- confirm if extended.
    axes = tuple(range(len(arr.shape)))
    return reduce_sum(picked, axes)

  inp = jnp.array([1, 0, -2, 0, 3, 0])
  jaxpr, _, _ = djax.make_djaxpr(sum_nonzero, inp)
  djax.typecheck_jaxpr(jaxpr)
def test_nonzero_typechecks(self):
  """Tracing a lone `nonzero` call produces a jaxpr that passes typechecking."""
  inp = jnp.array([1, 0, -2, 0, 3, 0])

  def just_nonzero(arr):
    return nonzero(arr)

  jaxpr, _, _ = djax.make_djaxpr(just_nonzero, inp)
  djax.typecheck_jaxpr(jaxpr)
def test_sin_typechecks(self):
  """Applying `sin` to a `bbarray` input yields a jaxpr that typechecks."""
  def apply_sin(arr):
    return sin(arr)

  # NOTE(review): presumably a bounded array with bound (5,) holding the 3
  # elements of arange(3.) -- confirm against bbarray's definition.
  bounded = bbarray((5,), jnp.arange(3.))
  jaxpr, _, _ = djax.make_djaxpr(apply_sin, bounded)
  djax.typecheck_jaxpr(jaxpr)
def test_identity_typechecks(self):
  """A function returning its argument unchanged yields a jaxpr that typechecks."""
  inp = jnp.array([0, 1])
  jaxpr, _, _ = djax.make_djaxpr(lambda arr: arr, inp)
  djax.typecheck_jaxpr(jaxpr)
def test_sin_and_add_typechecks(self):
  """Chaining two `sin` calls and adding the intermediates yields a jaxpr that typechecks."""
  def sin_chain_sum(arr):
    once = sin(arr)
    # `sin(once)` is traced before `add`, preserving the op order
    # sin -> sin -> add.
    return add(once, sin(once))

  # NOTE(review): presumably a bounded array with bound (5,) holding the 3
  # elements of arange(3.) -- confirm against bbarray's definition.
  bounded = bbarray((5,), jnp.arange(3.))
  jaxpr, _, _ = djax.make_djaxpr(sin_chain_sum, bounded)
  djax.typecheck_jaxpr(jaxpr)
def test_iota_typechecks(self):
  """A zero-argument function calling `iota(3)` yields a jaxpr that typechecks."""
  def make_iota():
    return iota(3)

  jaxpr, _, _ = djax.make_djaxpr(make_iota)
  djax.typecheck_jaxpr(jaxpr)