def test_np_or_jnp(self): self.assertIs(jax_util.np_or_jnp(1.), np) self.assertIs(jax_util.np_or_jnp(np.array(1.)), np) self.assertIs(jax_util.np_or_jnp(jnp.array(1.)), jnp) def trace_check(x): self.assertIs(jax_util.np_or_jnp(x), jnp) jax.make_jaxpr(trace_check)(np.array(1.))
def trace_check(x): self.assertIs(jax_util.np_or_jnp(x), jnp)