def test_check_jaxpr_eqn_mismatch(self):
  """core.check_jaxpr must reject a mistyped equation output variable.

  Builds the jaxpr of ``sin(x) + cos(x)`` and overwrites the aval of the
  ``sin`` equation's output variable with an incompatible one. It then
  checks that ``core.check_jaxpr`` raises ``TypeError``, with a message
  contrasting the declared (LHS) type against the type inferred from the
  RHS.
  """
  def f(x):
    return jnp.sin(x) + jnp.cos(x)

  def new_jaxpr():
    # A fresh jaxpr per case, because each case mutates an outvar's aval.
    return make_jaxpr(f)(1.).jaxpr

  # jaxpr is:
  #
  # { lambda ; a.
  #   let b = sin a
  #       c = cos a
  #       d = add b c
  #   in (d,) }
  #
  # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

  # Parentheses are escaped so they match literally instead of forming
  # regex groups.
  pattern = (r"Jaxpr equation LHS .* is ShapedArray\(.*\), "
             r"RHS is inferred as ShapedArray\(.*\), in '.* = sin .*'")

  # Each value gives an aval that disagrees with what `sin a` infers for
  # the float scalar input `1.`. The first is an int (not a float). The
  # second has shape (2, 3) instead of being a scalar.
  for bad_value in (2, np.ones((2, 3))):
    jaxpr = new_jaxpr()
    jaxpr.eqns[0].outvars[0].aval = make_shaped_array(bad_value)
    # Bind `jaxpr` as a default so the thunk does not close over the
    # loop variable.
    jtu.check_raises_regexp(
        lambda jaxpr=jaxpr: core.check_jaxpr(jaxpr), TypeError, pattern)
def test_check_jaxpr_eqn_mismatch(self):
  """core.check_jaxpr must reject a mistyped equation output variable.

  Builds the jaxpr of ``sin(x) + cos(x)`` and overwrites the aval of the
  ``sin`` equation's output variable with an incompatible one. It then
  checks that ``core.check_jaxpr`` raises ``core.JaxprTypeError``, with a
  message naming the inconsistently typed variable and the offending
  equation.
  """
  def f(x):
    return jnp.sin(x) + jnp.cos(x)

  def new_jaxpr():
    # A fresh jaxpr per case, because each case mutates an outvar's aval.
    return make_jaxpr(f)(1.).jaxpr

  # jaxpr is:
  #
  # { lambda ; a.
  #   let b = sin a
  #       c = cos a
  #       d = add b c
  #   in (d,) }
  #
  # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

  # Parentheses are escaped so they match literally instead of forming
  # regex groups.
  pattern = (r"Variable '.' inconsistently typed as ShapedArray\(.*\), "
             r"bound as ShapedArray\(.*\)\n\nin equation:\n\n . = sin .")

  # Each value gives an aval that disagrees with what `sin a` infers for
  # the float scalar input `1.`. The first is an int (not a float). The
  # second has shape (2, 3) instead of being a scalar.
  for bad_value in (2, np.ones((2, 3))):
    jaxpr = new_jaxpr()
    jaxpr.eqns[0].outvars[0].aval = make_shaped_array(bad_value)
    # Pass the callable and its argument directly instead of a lambda.
    self.assertRaisesRegex(
        core.JaxprTypeError, pattern, core.check_jaxpr, jaxpr)
def test_check_jaxpr_eqn_mismatch(self):
  """core.check_jaxpr must reject a mistyped equation output variable.

  Builds the jaxpr of ``sin(x) + cos(x)`` and overwrites the aval of the
  ``sin`` equation's output variable with an incompatible one. It then
  checks that ``core.check_jaxpr`` raises ``TypeError``, with a message
  naming the inconsistently typed variable and the offending equation.
  """
  def f(x):
    return jnp.sin(x) + jnp.cos(x)

  def new_jaxpr():
    # A fresh jaxpr per case, because each case mutates an outvar's aval.
    return make_jaxpr(f)(1.).jaxpr

  # jaxpr is:
  #
  # { lambda ; a.
  #   let b = sin a
  #       c = cos a
  #       d = add b c
  #   in (d,) }
  #
  # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

  # Parentheses are escaped so they match literally instead of forming
  # regex groups.
  pattern = (r"Variable '.' inconsistently typed as ShapedArray\(.*\), "
             r"bound as ShapedArray\(.*\) in '. = sin .'")

  # Each value gives an aval that disagrees with what `sin a` infers for
  # the float scalar input `1.`. The first is an int (not a float). The
  # second has shape (2, 3) instead of being a scalar.
  for bad_value in (2, np.ones((2, 3))):
    jaxpr = new_jaxpr()
    jaxpr.eqns[0].outvars[0].aval = make_shaped_array(bad_value)
    # Bind `jaxpr` as a default so the thunk does not close over the
    # loop variable.
    jtu.check_raises_regexp(
        lambda jaxpr=jaxpr: core.check_jaxpr(jaxpr), TypeError, pattern)