def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  @lu.wrap_init
  def wrapped(*args_flat):
    nonlocal out_tree
    args = tree_unflatten(in_tree, args_flat)
    argspecs = arrays_to_argspecs(spenv, args)
    result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, argspecs, spenv)
    out = argspecs_to_arrays(spenv, result)
    out_flat, out_tree = tree_flatten(out)
    return out_flat

  args = argspecs_to_arrays(spenv, argspecs)
  args_flat, in_tree = tree_flatten(args)
  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
  sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
  sp_jaxpr = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(sp_jaxpr), consts)
  return sp_jaxpr, out_tree
def fwd(*args, **kwargs):
  ans, rule = fun(*args, **kwargs)
  ans_flat, out_tree = tree_flatten((ans,))
  rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
  ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
  return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
def _trace_to_jaxpr_with_refs(
    f, state_tree: PyTreeDef, state_avals: Sequence[core.AbstractValue]
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
  f, out_tree_thunk = flatten_fun_nokwargs(
      lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(f, state_avals)
  return jaxpr, consts, out_tree_thunk()
def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  @lu.wrap_init
  def wrapped(*args_flat):
    # TODO(frostig,jakevdp): This closes over `spenv`, which can bring
    # in buffers from the "outer scope" as constants. Is this a
    # problem for primitives like cond and while_loop, which always
    # convert constvars to invars when staging out their subjaxprs?
    nonlocal out_tree
    args = tree_unflatten(in_tree, args_flat)
    spvalues = arrays_to_spvalues(spenv, args)
    result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, spvalues, spenv)
    out = spvalues_to_arrays(spenv, result)
    out_flat, out_tree = tree_flatten(out)
    return out_flat

  args = spvalues_to_arrays(spenv, spvalues)
  args_flat, in_tree = tree_flatten(args)
  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
  sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
  sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
  return sp_jaxpr, out_tree
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  assert not jaxpr.constvars
  cell = lambda: None

  @lu.wrap_init
  def transposed(*args):
    in_primals, out_cts = tree_unflatten(treedef, args)
    in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else
                pe.PartialVal.known(x) for x in in_primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    t_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(primal_fun, in_pvals, False)
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars]
    in_cts = ad.backward_pass(t_jaxpr, reduce_axes, False, consts, dummy_args,
                              out_cts)
    in_cts_ = iter(in_cts)
    in_cts = [next(in_cts_) if ad.is_undefined_primal(x) else ad_util.Zero(x.aval)
              for x in in_primals]
    assert next(in_cts_, None) is None
    in_cts, cell.treedef = tree_flatten(in_cts)
    return in_cts

  args, treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
  in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
  return tree_unflatten(cell.treedef, in_cts)  # type: ignore
def test_staging_nested(self):
  n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
  a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
  b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(x, y):
    @jax.jit
    def g(x, y, z, w):
      return (x, w)
    return g(x, y, x, y)

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
                                          keep_inputs=[False, True, True])

  self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
  self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
  self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape)

  self.assertLen(jaxpr.outvars, 2)
  self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape)
  self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape)

  self.assertLen(jaxpr.eqns, 1)
  eqn = jaxpr.eqns[0]
  self.assertIsInstance(eqn.primitive, core.CallPrimitive)
  inner_jaxpr = eqn.params['call_jaxpr']
  self.assertIsInstance(inner_jaxpr, core.Jaxpr)

  self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)
def test_staging_nested_including_shape_arg(self):
  # This test covers the _get_tracers_only_in_shapes logic in partial_eval.py.
  n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
  a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
  b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(x, y):
    @jax.jit
    def g(_, x, y, z, w):
      return (x, w)
    return g(x.shape[0], x, y, x, y)

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
                                          keep_inputs=[False, True, True])

  self.assertLen(jaxpr.eqns, 1)
  eqn = jaxpr.eqns[0]
  self.assertIsInstance(eqn.primitive, core.CallPrimitive)
  inner_jaxpr = eqn.params['call_jaxpr']
  self.assertIsInstance(inner_jaxpr, core.Jaxpr)

  self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)
def test_staging_primitive_applications(self):
  n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
  a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
  b = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(x, y):
    z = lax.mul(x, y)
    w = lax.sin(z)
    u = lax_internal._reduce_sum(w, [0])
    return (u,)

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [n, a, b], keep_inputs=[False, True, True])

  self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
  self.assertLen(jaxpr.eqns, 3)
  self.assertLen(jaxpr.eqns[0].outvars, 1)
  self.assertEqual(jaxpr.eqns[0].outvars[0].aval.shape,
                   jaxpr.invars[1].aval.shape)

  self.assertLen(jaxpr.outvars, 1)
  self.assertEqual(jaxpr.outvars[0].aval.shape, ())
def closure_convert(fun, in_tree, in_avals):
  if config.omnistaging_enabled:
    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  else:
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    with core.initial_style_staging():  # type: ignore
      jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
          wrapped_fun, in_pvals, instantiate=True, stage_out=False)  # type: ignore
  out_tree = out_tree()

  # We only want to closure convert for constants with respect to which we're
  # differentiating. As a proxy for that, we hoist consts with float dtype.
  # TODO(mattjj): revise this approach
  is_float = lambda c: dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)
  (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
  num_consts = len(hoisted_consts)

  def converted_fun(y, t, *hconsts_args):
    hoisted_consts, args = split_list(hconsts_args, [num_consts])
    consts = merge(closure_consts, hoisted_consts)
    all_args, in_tree2 = tree_flatten((y, t, *args))
    assert in_tree == in_tree2
    out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, hoisted_consts
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
  f, msgs = checkify_subtrace(f)
  f = checkify_traceable(f, tuple(error.msgs.items()), enabled_errors)
  err_aval = core.raise_to_shaped(core.get_aval(error.err))
  code_aval = core.raise_to_shaped(core.get_aval(error.code))
  avals_in = [err_aval, code_aval, *in_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
                              primitive_name: Optional[str] = None):
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  debug = pe.debug_info(fun, in_tree, False, primitive_name or "<unknown>")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  return jaxpr, consts, out_tree()
def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
  assert len(jaxpr.in_avals) == len(nonzeros)
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f_jvp, out_nonzeros = f_jvp_traceable(
      jvp(f, instantiate=instantiate, transform_stack=False), nonzeros)
  tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
  jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
                      axis_name, main_type):
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  f, out_batched = _batch_jaxpr_inner(f, axis_size, out_axes_dest)
  f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
  avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval)
              if b is not not_mapped else aval
              for aval, b in zip(closed_jaxpr.in_avals, in_axes)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
def checkify_jaxpr(jaxpr, error):
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f, msgs = check_errors_subtrace(f)
  f = check_errors_traceable(f, tuple(error.msgs.items()))
  err_aval = core.raise_to_shaped(core.get_aval(error.err))
  code_aval = core.raise_to_shaped(core.get_aval(error.code))
  avals_in = [err_aval, code_aval, *jaxpr.in_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
def test_typecheck_staging_nested(self):
  n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
  m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
  a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
  b = core.DShapedArray((DBIdx(1),), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(a, b):
    @jax.jit
    def g(x): return x
    return g(a),

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [n, m, a, b], keep_inputs=[False, False, True, True])
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[a] = xla_call[
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a c
  #   in (e,) }
  core.check_jaxpr(jaxpr)  # no problems here...

  # Let's introduce a type error by applying the called jaxpr to arguments
  # with types which aren't consistent with its input binders:
  _, _, c, d = jaxpr.invars
  jaxpr.eqns[0].invars[1] = d
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[a] = xla_call[
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a d   !!! type error here !!!
  #   in (e,) }
  with self.assertRaisesRegex(TypeError, "passes operand"):
    core.check_jaxpr(jaxpr)

  # Restore the original jaxpr:
  jaxpr.eqns[0].invars[1] = c
  core.check_jaxpr(jaxpr)  # no problems here...

  # Let's introduce another type error by setting the call result let binders
  # to have the wrong type:
  jaxpr.eqns[0].outvars[0] = core.Var(0, '', d.aval)
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[b] = xla_call[   !!! type error here !!!
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a c
  #   in (h,) }
  with self.assertRaisesRegex(TypeError, "inconsistently typed as"):
    core.check_jaxpr(jaxpr)
def wrapped(spenv: SparsifyEnv, *spvalues: SparsifyValue,
            **params: Any) -> Tuple[Sequence[SparsifyValue], bool]:
  spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue)
  in_avals_flat = spvalues_to_avals(spenv, spvalues_flat)
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
  jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
  result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
  if len(out_avals_flat) != len(result):
    raise Exception("Internal: eval_sparse does not return expected number of arguments. "
                    f"Got {result} for avals {out_avals_flat}")
  return result, out_tree()
def wrapped(spenv: SparseEnv, *argspecs: ArgSpec,
            **params: Any) -> Tuple[Sequence[ArgSpec], bool]:
  in_avals = argspecs_to_avals(spenv, argspecs)
  in_avals_flat, in_tree = tree_flatten(in_avals)
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
  jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
  result = eval_sparse(jaxpr, consts, argspecs, spenv)
  if len(out_avals_flat) != len(result):
    raise Exception("Internal: eval_sparse does not return expected number of arguments. "
                    f"Got {result} for avals {out_avals_flat}")
  return result, out_tree()
def fun_remat(*args, **kwargs):
  args_flat, in_tree = tree_flatten((args, kwargs))
  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  debug = pe.debug_info(fun, in_tree, False, "checkpoint")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  out_flat = remat_p.bind(
      *consts, *args_flat, jaxpr=pe.convert_constvars_jaxpr(jaxpr),
      prevent_cse=prevent_cse, differentiated=False, policy=policy)
  return tree_unflatten(out_tree(), out_flat)
def make_jaxpr(fun: Callable,
               in_tree: PyTreeDef,
               in_avals: typing.Tuple[core.AbstractValue],  # with DBIdx in them
               keep_inputs: typing.Tuple[bool]
               ) -> typing.Tuple[core.Jaxpr, List[Any], PyTreeDef]:
  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  debug = pe.debug_info(fun, in_tree, False, "dex_jit")
  jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug,
                                                keep_inputs=keep_inputs)
  jaxpr = pe.convert_constvars_jaxpr(jaxpr_)
  consts = [_canonicalize_arg(c) for c in consts]
  return jaxpr, consts, out_tree()
def f_new(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
          avals_out: Optional[Sequence[core.AbstractValue]],
          *xla_args: xc.XlaOp,
          **params) -> Sequence[xc.XlaOp]:
  wrapped_fun = lu.wrap_init(fun, params)
  if not multiple_results:
    wrapped_fun = _tuple_output(wrapped_fun)
  with core.extend_axis_env_nd(zip(ctx.axis_env.names, ctx.axis_env.sizes)):
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals_in)
  return jaxpr_subcomp(ctx, jaxpr, _xla_consts(ctx.builder, consts), *xla_args)
def f_lowered(ctx, avals_in, avals_out, *args, **params):
  if multiple_results:
    f = fun
  else:
    f = lambda *args, **kw: (fun(*args, **kw),)
  wrapped_fun = lu.wrap_init(f, params)
  with core.extend_axis_env_nd(zip(ctx.axis_env.names, ctx.axis_env.sizes)):
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals_in)
  return jaxpr_subcomp(ctx, jaxpr, _ir_consts(consts),
                       *map(wrap_singleton_ir_values, args))
def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any]
                    ) -> Tuple[core.Jaxpr, List[Any]]:
  """Converts a jaxpr that takes in `Ref`s into one that doesn't."""
  in_avals = [core.ShapedArray(v.aval.shape, v.aval.dtype)
              if type(v.aval) is ShapedArrayRef
              else v.aval for v in jaxpr.invars]
  eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr, consts))
  new_jaxpr, _, new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
  return new_jaxpr, new_consts
def __call__(self, *args, **kwargs):
  assert not kwargs
  args_flat, in_tree = tree_flatten(args)
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  assert not len(consts)
  closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  out_flat = custom_vmap_p.bind(*consts, *args_flat,
                                call=closed_call,
                                rule=self.vmap_rule,
                                in_tree=in_tree)
  return tree_unflatten(out_tree(), out_flat)
def f_with_avals(c, avals, xla_args, params):
  # parallelism is only supported via the new-style API.
  axis_env = AxisEnv(1, (), ())
  wrapped_fun = lu.wrap_init(fun, params)
  if not multiple_results:
    wrapped_fun = _tuple_output(wrapped_fun)
  with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
  ctx = TranslationContext(c, backend, axis_env, new_name_stack())
  outs = jaxpr_subcomp(ctx, jaxpr, _xla_consts(c, consts), *xla_args)
  if (multiple_results or
      any(len(aval_to_xla_shapes(v.aval)) > 1 for v in jaxpr.outvars)):
    return xops.Tuple(c, outs)
  else:
    assert len(outs) == 1, outs
    return outs[0]
def wrapped(*args, **kwargs):
  fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
  flat_avals = safe_map(get_shaped_aval, flat_args)
  if dynamic:
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
  else:
    pvals = [pe.PartialVal.unknown(aval) for aval in flat_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, pvals, instantiate=True)
  typed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
  return typed_jaxpr, (in_tree, out_tree())
def test_staging_nested_including_shape_arg(self):
  n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
  a = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)
  b = core.DShapedArray((pe.DBIdx(0),), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(x, y):
    @jax.jit
    def g(_, x, y, z, w):
      return (x, w)
    return g(x.shape[0], x, y, x, y)

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [n, a, b], keep_inputs=[False, True, True])
  print(jaxpr)
  # { lambda ; a:i32[] b:f32[a] c:f32[a]. let
  #     d:f32[a] e:f32[a] = xla_call[
  #       call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f] i:f32[f] j:f32[f] k:f32[f]. let
  #
  #         in (h, k) }
  #       name=g
  #     ] a a b c b c
  #   in (d, e) }

  self.assertLen(jaxpr.eqns, 1)
  eqn = jaxpr.eqns[0]
  self.assertIsInstance(eqn.primitive, core.CallPrimitive)
  inner_jaxpr = eqn.params['call_jaxpr']
  self.assertIsInstance(inner_jaxpr, core.Jaxpr)

  self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape)
  self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)
def wrapped(*args, **kwargs):
  fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
  flat_avals = safe_map(get_shaped_aval, flat_args)
  if not jax.config.omnistaging_enabled:
    raise ValueError('Oryx must be used with JAX omnistaging enabled.')
  if dynamic:
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
  else:
    pvals = [pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, pvals, instantiate=True)
  typed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
  return typed_jaxpr, (in_tree, out_tree())
def _closure_convert_for_avals(fun, in_tree, in_avals):
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  out_tree = out_tree()

  (closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)
  num_consts = len(hoisted_consts)

  def converted_fun(*args_hconsts):
    num_args = len(args_hconsts) - num_consts
    args, hoisted_consts = split_list(args_hconsts, [num_args])
    consts = merge(closure_consts, hoisted_consts)
    all_args, in_tree2 = tree_flatten(tuple(args))
    assert in_tree == in_tree2
    out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, hoisted_consts
def wrapped(*args, **kwargs):
  fun = lu.wrap_init(f, kwargs)
  flat_args, in_tree = tree_util.tree_flatten(args)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
  flat_avals = safe_map(get_shaped_aval, flat_args)
  if dynamic and jax.config.omnistaging_enabled:
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
  else:
    pvals = [pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals]
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, pvals,
                                                 instantiate=True,
                                                 stage_out=True)
    out_avals = [pval.get_aval() for pval in out_pvals]
  typed_jaxpr = jax_core.TypedJaxpr(jaxpr, consts, flat_avals, out_avals)
  return typed_jaxpr, (in_tree, out_tree())
def test_staging_basic(self):
  n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
  a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
  b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(x, y):
    return x, y

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
                                          keep_inputs=[False, True, True])

  self.assertLen(jaxpr.invars, 3)
  self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
  self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape)

  self.assertLen(jaxpr.outvars, 2)
  self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape)
  self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape)
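The call sites above share one staging recipe: flatten the pytree arguments, wrap the Python callable with lu.wrap_init, abstract the flat arguments to avals, trace with pe.trace_to_jaxpr_dynamic, and rebuild structured outputs from the flat results. The sketch below is a minimal, self-contained version of that recipe, not code from any of the projects above; it assumes a JAX release from the same era as these snippets, in which jax.linear_util, jax.api_util, and jax.interpreters.partial_eval are importable under these paths (exact module paths vary across versions), and the helper name stage_and_eval is illustrative only.

# A minimal sketch of the staging pattern repeated in the snippets above.
# Assumption: an older JAX release exposing jax.linear_util / jax.api_util;
# `stage_and_eval` is a hypothetical helper name, not a JAX API.
import jax.numpy as jnp
from jax import core, linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten

def stage_and_eval(fun, *args):
  # Flatten the pytree arguments and wrap the callable for tracing.
  args_flat, in_tree = tree_flatten(args)
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  # Abstract each argument to a shaped aval and stage the function out.
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
  # Evaluate the staged jaxpr on the original flat arguments and
  # restore the output pytree structure.
  out_flat = core.eval_jaxpr(jaxpr, consts, *args_flat)
  return core.ClosedJaxpr(jaxpr, consts), tree_unflatten(out_tree(), out_flat)

closed, out = stage_and_eval(lambda x, y: x * y + 1.0,
                             jnp.arange(3.0), jnp.ones(3))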