def test_np_or_jnp(self):
        self.assertIs(jax_util.np_or_jnp(1.), np)
        self.assertIs(jax_util.np_or_jnp(np.array(1.)), np)
        self.assertIs(jax_util.np_or_jnp(jnp.array(1.)), jnp)

        def trace_check(x):
            self.assertIs(jax_util.np_or_jnp(x), jnp)

        jax.make_jaxpr(trace_check)(np.array(1.))
 def trace_check(x):
     self.assertIs(jax_util.np_or_jnp(x), jnp)