Beispiel #1
0
 def test_check_jaxpr_correct(self):
     jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
     core.check_jaxpr(jaxpr)
Beispiel #2
0
 def test_check_jaxpr_cond_correct(self):
     jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(
         1.).jaxpr
     core.check_jaxpr(jaxpr)