def test_staging_nested_including_shape_arg(self):
  # This test covers the _get_tracers_only_in_shapes logic in partial_eval.py.
  # NOTE(review): another test with this exact name appears later in this
  # file; if both live in the same class, the later definition shadows this
  # one at class-creation time — confirm and rename one of them.
  axis = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
  vec1 = core.DShapedArray((axis,), jnp.dtype('float32'), weak_type=False)
  vec2 = core.DShapedArray((axis,), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(x, y):
    @jax.jit
    def g(_, x, y, z, w):
      return (x, w)
    # Pass the (traced) axis size itself as the first argument.
    return g(x.shape[0], x, y, x, y)

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [axis, vec1, vec2], keep_inputs=[False, True, True])

  self.assertLen(jaxpr.eqns, 1)
  eqn, = jaxpr.eqns
  self.assertIsInstance(eqn.primitive, core.CallPrimitive)
  inner_jaxpr = eqn.params['call_jaxpr']
  self.assertIsInstance(inner_jaxpr, core.Jaxpr)
  # One leading axis-size binder, then four array binders whose shapes all
  # refer back to that first binder.
  self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
  expected_shape = (inner_jaxpr.invars[0],)
  for binder in inner_jaxpr.invars[1:]:
    self.assertEqual(expected_shape, binder.aval.shape)
def test_staging_primitive_applications(self):
  # Stage out a few primitive applications on dynamically-shaped inputs and
  # check that dynamic shapes propagate through equation outputs.
  axis = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
  operand1 = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'),
                               weak_type=False)
  operand2 = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'),
                               weak_type=False)

  @lu.wrap_init
  def f(x, y):
    product = lax.mul(x, y)
    sines = lax.sin(product)
    total = lax_internal._reduce_sum(sines, [0])
    return (total,)

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [axis, operand1, operand2], keep_inputs=[False, True, True])

  self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
  self.assertLen(jaxpr.eqns, 3)  # mul, sin, reduce_sum
  self.assertLen(jaxpr.eqns[0].outvars, 1)
  # The elementwise mul's output shares the dynamic shape of its operands.
  self.assertEqual(jaxpr.eqns[0].outvars[0].aval.shape,
                   jaxpr.invars[1].aval.shape)
  self.assertLen(jaxpr.outvars, 1)
  # The full reduction produces a scalar.
  self.assertEqual(jaxpr.outvars[0].aval.shape, ())
def test_staging_nested(self):
  # Stage a jit-nested function on dynamically-shaped inputs; the inner call
  # jaxpr should carry its own axis-size binder referenced by its operands.
  axis = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
  vec1 = core.DShapedArray((axis,), jnp.dtype('float32'), weak_type=False)
  vec2 = core.DShapedArray((axis,), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(x, y):
    @jax.jit
    def g(x, y, z, w):
      return (x, w)
    return g(x, y, x, y)

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [axis, vec1, vec2], keep_inputs=[False, True, True])

  self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
  outer_shape = (jaxpr.invars[0],)
  self.assertEqual(outer_shape, jaxpr.invars[1].aval.shape)
  self.assertEqual(outer_shape, jaxpr.invars[2].aval.shape)
  self.assertLen(jaxpr.outvars, 2)
  for outvar in jaxpr.outvars:
    self.assertEqual(outer_shape, outvar.aval.shape)

  self.assertLen(jaxpr.eqns, 1)
  eqn, = jaxpr.eqns
  self.assertIsInstance(eqn.primitive, core.CallPrimitive)
  inner_jaxpr = eqn.params['call_jaxpr']
  self.assertIsInstance(inner_jaxpr, core.Jaxpr)
  self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
  inner_shape = (inner_jaxpr.invars[0],)
  for binder in inner_jaxpr.invars[1:]:
    self.assertEqual(inner_shape, binder.aval.shape)
def test_typecheck_staging_nested(self):
  # Build a jaxpr with a nested call, then check that the jaxpr typechecker
  # catches deliberately-introduced inconsistencies around the call.
  n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
  m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
  arg_a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'),
                            weak_type=False)
  arg_b = core.DShapedArray((DBIdx(1),), jnp.dtype('float32'),
                            weak_type=False)

  @lu.wrap_init
  def f(a, b):
    @jax.jit
    def g(x):
      return x
    return g(a),

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [n, m, arg_a, arg_b], keep_inputs=[False, False, True, True])
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[a] = xla_call[
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a c
  #   in (e,) }
  core.check_jaxpr(jaxpr)  # no problems here...

  # Introduce a type error: apply the called jaxpr to arguments whose types
  # are inconsistent with its input binders.
  _, _, var_c, var_d = jaxpr.invars
  jaxpr.eqns[0].invars[1] = var_d
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[a] = xla_call[
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a d   !!! type error here !!!
  #   in (e,) }
  with self.assertRaisesRegex(TypeError, "passes operand"):
    core.check_jaxpr(jaxpr)

  # Restore the original jaxpr:
  jaxpr.eqns[0].invars[1] = var_c
  core.check_jaxpr(jaxpr)  # no problems here...

  # Introduce a second type error: give the call's result let-binder the
  # wrong type.
  jaxpr.eqns[0].outvars[0] = core.Var(0, '', var_d.aval)
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[b] = xla_call[   !!! type error here !!!
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a c
  #   in (h,) }
  with self.assertRaisesRegex(TypeError, "inconsistently typed as"):
    core.check_jaxpr(jaxpr)
def make_aval(arg, spec):
  # Build an abstract value for `arg`; `spec` maps axis index -> dimension
  # name for axes that should be treated as dynamic.
  # NOTE(review): relies on enclosing-scope `sizes` (name -> concrete size,
  # for consistency checking) and `env` (name -> dimension variable/tracer)
  # — confirm against the surrounding definition.
  if not spec:
    # No dynamic axes: a plain static aval suffices.
    return shaped_abstractify(arg)
  # Record each named axis size on first sight and check later occurrences
  # agree with it. (Side effect lives inside the assert, as in the original:
  # skipped entirely under -O.)
  for axis, name in spec.items():
    assert arg.shape[axis] == sizes.setdefault(name, arg.shape[axis])
  dims = tuple(env[spec[axis]] if axis in spec else size
               for axis, size in enumerate(arg.shape))
  return core.DShapedArray(dims, arg.dtype, False)
def test_staging_nested_including_shape_arg(self):
  # Like test_staging_nested, but the (traced) axis size x.shape[0] is also
  # passed as an explicit argument to the inner jit-ed function.
  n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
  a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
  b = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(x, y):
    @jax.jit
    def g(_, x, y, z, w):
      return (x, w)
    return g(x.shape[0], x, y, x, y)

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [n, a, b], keep_inputs=[False, True, True])
  # Expected structure (illustrative; the assertions below are what's
  # actually checked — one inner axis-size binder plus four array binders):
  # { lambda ; a:i32[] b:f32[a] c:f32[a]. let
  #     d:f32[a] e:f32[a] = xla_call[
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f] h:f32[f] i:f32[f] j:f32[f].
  #                    let  in (g, j) }
  #       name=g
  #     ] a b c b c
  #   in (d, e) }
  self.assertLen(jaxpr.eqns, 1)
  eqn = jaxpr.eqns[0]
  self.assertIsInstance(eqn.primitive, core.CallPrimitive)
  inner_jaxpr = eqn.params['call_jaxpr']
  self.assertIsInstance(inner_jaxpr, core.Jaxpr)
  self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)
def test_staging_basic(self):
  # An identity function over dynamically-shaped inputs should stage out to
  # a jaxpr whose in/out avals all reference the shared axis-size binder.
  axis = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
  x_aval = core.DShapedArray((axis,), jnp.dtype('float32'), weak_type=False)
  y_aval = core.DShapedArray((axis,), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(x, y):
    return x, y

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [axis, x_aval, y_aval], keep_inputs=[False, True, True])

  self.assertLen(jaxpr.invars, 3)
  dyn_shape = (jaxpr.invars[0],)
  self.assertEqual(dyn_shape, jaxpr.invars[1].aval.shape)
  self.assertEqual(dyn_shape, jaxpr.invars[2].aval.shape)
  self.assertLen(jaxpr.outvars, 2)
  self.assertEqual(dyn_shape, jaxpr.outvars[0].aval.shape)
  self.assertEqual(dyn_shape, jaxpr.outvars[1].aval.shape)
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                           named_shape_rule, *avals, **kwargs):
  # Generic abstract evaluation for a single-result primitive: pick the
  # least specialized aval class among the inputs and construct the output
  # aval from the per-aspect rules (shape/dtype/weak_type/named_shape).
  assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
  assert not prim.multiple_results
  weak_type = weak_type_rule(*avals, **kwargs)
  least_specialized = _max(map(type, avals),
                           key=operator.attrgetter('array_abstraction_level'))
  if least_specialized is core.ConcreteArray:
    # Every input is concrete, so just run the primitive's implementation.
    out = prim.impl(*[aval.val for aval in avals], **kwargs)
    return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
  if least_specialized is core.ShapedArray:
    return core.ShapedArray(shape_rule(*avals, **kwargs),
                            dtype_rule(*avals, **kwargs),
                            weak_type=weak_type,
                            named_shape=named_shape_rule(*avals, **kwargs))
  if least_specialized is core.DShapedArray:
    # Dynamic shapes carry no named_shape.
    return core.DShapedArray(shape_rule(*avals, **kwargs),
                             dtype_rule(*avals, **kwargs), weak_type)
  if least_specialized is core.UnshapedArray:
    return core.UnshapedArray(dtype_rule(*avals, **kwargs),
                              weak_type=weak_type)
  raise TypeError(avals, least_specialized)