예제 #1
0
 def test_check_jaxpr_correct(self):
     jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
     core.check_jaxpr(jaxpr)
예제 #2
0
파일: core_test.py 프로젝트: d3sm0/jax
 def test_check_jaxpr_cond_correct(self):
     jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(
         1.).jaxpr
     core.check_jaxpr(jaxpr)