def test_typecheck_staging_nested(self):
  """core.check_jaxpr must catch type errors around a nested call.

  Stages a function with dynamically-shaped inputs that calls a jitted inner
  function. It then corrupts the resulting jaxpr in two ways and checks that
  `core.check_jaxpr` rejects each one with the expected TypeError message:
    1. the call is applied to an operand whose type disagrees with the called
       jaxpr's input binders ("passes operand");
    2. the call's result binder is given a type that disagrees with the called
       jaxpr's output ("inconsistently typed as").
  """
  # Two scalar int32 inputs; they serve as the sizes of the array inputs below.
  n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
  m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
  # Dynamically-shaped float32 vectors. Going by the printed jaxpr below,
  # DBIdx(0)/DBIdx(1) make the length of `a`/`b` refer to the 0th/1st input
  # binder (c:f32[a], d:f32[b]).
  a = core.DShapedArray((DBIdx(0), ), jnp.dtype('float32'), weak_type=False)
  b = core.DShapedArray((DBIdx(1), ), jnp.dtype('float32'), weak_type=False)
  @lu.wrap_init
  def f(a, b):
    # Nested jitted identity: this is what produces the inner call
    # (`xla_call` in the printed jaxpr) whose typing is checked below.
    @jax.jit
    def g(x): return x
    # Trailing comma: f returns a 1-tuple of outputs.
    return g(a),
  # keep_inputs marks only the last two avals as explicit inputs, matching
  # f's two parameters. The two int32 sizes are presumably implicit inputs,
  # yet they still appear as binders in the jaxpr.
  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [n, m, a, b], keep_inputs=[False, False, True, True])
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[a] = xla_call[
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a c
  #   in (e,) }
  core.check_jaxpr(jaxpr)  # no problems here...

  # Let's introduce a type error by applying the called jaxpr to arguments
  # with types which aren't consistent with its input binders:
  _, _, c, d = jaxpr.invars
  # Swap the array operand c:f32[a] for d:f32[b]. The size operand is still
  # `a`, so the inner binder g:f32[f] no longer matches the operand type.
  jaxpr.eqns[0].invars[1] = d
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[a] = xla_call[
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a d   !!! type error here !!!
  #   in (e,) }
  with self.assertRaisesRegex(TypeError, "passes operand"):
    core.check_jaxpr(jaxpr)

  # Restore the original jaxpr:
  # The jaxpr was mutated in place, so undo the swap before the next case.
  jaxpr.eqns[0].invars[1] = c
  core.check_jaxpr(jaxpr)  # no problems here...

  # Let's introduce another type error by setting the call result let binders
  # to have the wrong type:
  # The replacement binder has d's type f32[b], but the call returns f32[a].
  jaxpr.eqns[0].outvars[0] = core.Var(0, '', d.aval)
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[b] = xla_call[   !!! type error here !!!
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a c
  #   in (h,) }
  # (jaxpr.outvars is not updated here. It still refers to the original
  # binder, which is presumably why the printed output is named `h`, not `e`.)
  with self.assertRaisesRegex(TypeError, "inconsistently typed as"):
    core.check_jaxpr(jaxpr)
def cond_bind(*args, branches, linear):
  """Bind `cond_p` to `args`, optionally type-checking the branches first.

  When `config.jax_enable_checks` is set, the operands are checked against
  the branches via `_cond_typecheck`, and every branch's inner jaxpr is
  validated with `core.check_jaxpr` before binding.

  Args:
    *args: operands forwarded unchanged to the primitive bind.
    branches: branch jaxprs; each one's `.jaxpr` attribute is checked.
    linear: forwarded unchanged to the typecheck and to the bind.

  Returns:
    Whatever `core.AxisPrimitive.bind` returns for `cond_p`.
  """
  if config.jax_enable_checks:
    operand_avals = [core.get_aval(operand) for operand in args]
    # The typecheck takes atoms, so wrap each aval in a placeholder Var.
    # The (0, '') id/suffix pair is a dummy value.
    placeholder_atoms = [core.Var(0, '', aval) for aval in operand_avals]
    _cond_typecheck(*placeholder_atoms, branches=branches, linear=linear)
    for branch in branches:
      core.check_jaxpr(branch.jaxpr)
  return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear)
def mk_new_var(aval: core.AbstractValue) -> core.Var:
  """Return a new `core.Var` of type `aval` with an empty suffix.

  The Var's id is the next value drawn from the module-level `mk_new_id`
  iterator. That iterator is presumably a counter, which would keep ids
  distinct across calls; verify against its definition.
  """
  fresh_id = next(mk_new_id)
  return core.Var(fresh_id, '', aval)