def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
  """Build a closure-converted ``TypedJaxpr`` from already-traced tracers.

  Args:
    in_tracers: tracers standing for the inputs; each must expose ``.pval``.
    out_tracers: values produced by the traced computation. They are passed
      through ``core.full_lower`` and then ``trace.full_raise`` before use.
    trace: the trace whose ``full_raise`` is applied to the outputs and which
      is handed to ``pe.instantiate_const_at``.
    instantiate: a single flag, replicated once per output and forwarded to
      ``pe.instantiate_const_at``.
      NOTE(review): a per-output list is not unpacked here; it would itself be
      replicated once per output -- confirm callers only ever pass a bool.

  Returns:
    A ``(typed_jaxpr, consts)`` pair. The input avals recorded on
    ``typed_jaxpr`` are the avals of ``consts`` followed by the avals of
    ``in_tracers``.

  Raises:
    AssertionError: if ``pe.tracers_to_jaxpr`` reports a non-empty ``env``.
  """
  # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
  instantiate_flags = [instantiate] * len(out_tracers)
  lowered_outs = safe_map(core.full_lower, out_tracers)
  raised_outs = safe_map(trace.full_raise, lowered_outs)
  instantiate_in_trace = partial(pe.instantiate_const_at, trace)
  final_outs = safe_map(instantiate_in_trace, instantiate_flags, raised_outs)
  jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, final_outs)
  out_pvals = [tracer.pval for tracer in final_outs]

  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
  assert not env

  # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
  to_shaped = abstract_arrays.raise_to_shaped
  # Each pval is unzipped as a pair; only the first component is used below.
  out_avals = safe_map(to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(to_shaped(core.get_aval(const)) for const in consts)

  in_pvals = [tracer.pval for tracer in in_tracers]
  in_avals = tuple(safe_map(to_shaped, unzip2(in_pvals)[0]))

  # Const avals are listed ahead of the input avals, next to the
  # closure-converted jaxpr (presumably because closure conversion turns the
  # constants into leading arguments -- confirm against pe.closure_convert_jaxpr).
  converted_jaxpr = pe.closure_convert_jaxpr(jaxpr)
  typed_jaxpr = core.TypedJaxpr(
      converted_jaxpr, (), const_avals + in_avals, out_avals)
  return typed_jaxpr, consts
def _initial_style_jaxpr(fun, in_tree, in_avals):
  """Trace ``fun`` into a closure-converted ``TypedJaxpr``.

  Args:
    fun: the callable to trace; it is wrapped with ``lu.wrap_init`` and
      flattened against ``in_tree`` via ``flatten_fun_nokwargs``.
    in_tree: tree structure describing how the flat inputs map onto ``fun``'s
      arguments (passed straight to ``flatten_fun_nokwargs``).
    in_avals: abstract values of the flat inputs. The value is concatenated
      onto a tuple of const avals, so it must support ``tuple + in_avals``
      (in practice, a tuple).

  Returns:
    A ``(typed_jaxpr, consts, out_tree)`` triple, where the input avals of
    ``typed_jaxpr`` are the avals of ``consts`` followed by ``in_avals`` and
    ``out_tree`` is the result of calling the thunk returned by
    ``flatten_fun_nokwargs``.
  """
  # Every input is paired with core.unit when building its PartialVal.
  in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
      flat_fun, in_pvals, instantiate=True)

  # Only the first component of each output pval is used.
  out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(raise_to_shaped(core.get_aval(const)) for const in consts)

  # Const avals are listed ahead of the input avals, next to the
  # closure-converted jaxpr.
  converted_jaxpr = pe.closure_convert_jaxpr(jaxpr)
  typed_jaxpr = core.TypedJaxpr(
      converted_jaxpr, (), const_avals + in_avals, out_avals)
  # The thunk is called only here, after tracing has run.
  return typed_jaxpr, consts, out_tree_thunk()
def _flat_initial_style_jaxpr(fun, in_avals):
  """lax_control_flow._initial_style_jaxpr, but for flat arguments and results.

  Args:
    fun: the function handed to ``_instantiated_trace_to_jaxpr``.
    in_avals: abstract values of the flat inputs. The value is appended to
      ``_abstractified(consts)`` with ``+``, so both must be the same
      concatenable sequence type.

  Returns:
    A ``(typed_jaxpr, consts)`` pair. The input avals of ``typed_jaxpr`` are
    the abstractified ``consts`` followed by ``in_avals``.
  """
  jaxpr, traced_out_avals, consts = _instantiated_trace_to_jaxpr(fun, in_avals)
  converted_jaxpr = closure_convert_jaxpr(jaxpr)
  all_in_avals = _abstractified(consts) + in_avals
  # NOTE(review): if `map` is the builtin here (not a module-level rebinding to
  # a list-returning helper), this is a one-shot iterator -- confirm.
  shaped_out_avals = map(raise_to_shaped, traced_out_avals)
  typed_jaxpr = TypedJaxpr(
      converted_jaxpr, (), in_avals=all_in_avals, out_avals=shaped_out_avals)
  return typed_jaxpr, consts