def test_check_jaxpr_eqn_mismatch(self):
  """check_jaxpr must reject an equation whose outvar is retyped at its binding."""
  def f(x):
    return jnp.sin(x) + jnp.cos(x)

  # For an f32 scalar input the traced jaxpr is:
  #
  #   { lambda ; a.
  #     let b = sin a
  #         c = cos a
  #         d = add b c
  #     in (d,) }
  #
  # 'b' appears both as eqns[0].outvars[0] (where it is bound) and as
  # eqns[2].invars[0] (where it is used). Overwriting only the aval at the
  # binding makes the two occurrences of 'b' disagree.
  def check_rebinding_fails(bad_aval, expected_msg):
    # Re-trace for every case so one case's mutation cannot affect another.
    jaxpr = make_jaxpr(f)(jnp.float32(1.)).jaxpr
    jaxpr.eqns[0].outvars[0].aval = bad_aval
    with self.assertRaisesRegex(core.JaxprTypeError, expected_msg):
      core.check_jaxpr(jaxpr)

  # Dtype mismatch: int, not float!
  check_rebinding_fails(
      make_shaped_array(jnp.int32(2)),
      r"Variable 'b' inconsistently typed as f32\[\], "
      r"bound as i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a")

  # Shape mismatch: same dtype, but a (2, 3) array instead of a scalar.
  check_rebinding_fails(
      make_shaped_array(np.ones((2, 3), dtype=jnp.float32)),
      r"Variable 'b' inconsistently typed as f32\[\], "
      r"bound as f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a")
def test_check_jaxpr_eqn_mismatch(self):
  """check_jaxpr must reject an equation whose outvar is retyped at its binding."""
  # NOTE(review): a method with this exact name is also defined just before
  # this one. If both live in the same class, this later definition silently
  # replaces the earlier one, so only one of them ever runs -- confirm which
  # version is intended and delete or rename the other.
  def f(x):
    return jnp.sin(x) + jnp.cos(x)

  def new_jaxpr():
    # Pin the input dtype so the traced types do not depend on the x64 flag.
    return make_jaxpr(f)(jnp.float32(1.)).jaxpr

  # jaxpr is:
  #
  # { lambda ; a.
  #   let b = sin a
  #       c = cos a
  #       d = add b c
  #   in (d,) }
  #
  # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'.

  # Parentheses are escaped so they match the literal "ShapedArray(...)" text.
  # Unescaped, they form a regex group, and the pattern would accept
  # "ShapedArray" followed by anything at all.
  expected_msg = (
      r"Variable '.' inconsistently typed as ShapedArray\(.*\), "
      r"bound as ShapedArray\(.*\)\n\nin equation:\n\n . = sin .")

  # Dtype mismatch: int, not float!
  jaxpr = new_jaxpr()
  jaxpr.eqns[0].outvars[0].aval = make_shaped_array(jnp.int32(2))
  self.assertRaisesRegex(
      core.JaxprTypeError,
      expected_msg,
      lambda: core.check_jaxpr(jaxpr))

  # Shape mismatch: same dtype, but a (2, 3) array instead of a scalar.
  jaxpr = new_jaxpr()
  jaxpr.eqns[0].outvars[0].aval = make_shaped_array(
      np.ones((2, 3), dtype=np.float32))
  self.assertRaisesRegex(
      core.JaxprTypeError,
      expected_msg,
      lambda: core.check_jaxpr(jaxpr))