示例#1
0
 def test_check_jaxpr_correct(self):
     jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
     core.check_jaxpr(jaxpr)
示例#2
0
文件: core_test.py 项目: d3sm0/jax
 def test_check_jaxpr_cond_correct(self):
     jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(
         1.).jaxpr
     core.check_jaxpr(jaxpr)