示例#1
0
    def trace_to_jaxpr_finalize(in_tracers,
                                out_tracers,
                                trace,
                                instantiate=True):
        """Finish a partial-evaluation trace and package it as a TypedJaxpr.

        Args:
          in_tracers: tracers created for the traced function's inputs; their
            ``pval`` avals become the trailing input avals of the result.
          out_tracers: values produced by the traced function. Each is lowered
            with ``core.full_lower``, raised into ``trace`` and then passed
            through ``pe.instantiate_const_at`` before the jaxpr is built.
          trace: the partial-eval trace the tracers belong to.
          instantiate: either a single flag applied to every output, or a
            list/tuple with one flag per output (mirroring
            ``partial_eval.trace_to_subjaxpr``). Each flag is forwarded to
            ``pe.instantiate_const_at``.

        Returns:
          A ``(typed_jaxpr, consts)`` pair. ``typed_jaxpr`` is built from
          ``pe.convert_constvars_jaxpr(jaxpr)`` with no literal consts, and its
          input avals are the avals of ``consts`` followed by the avals of
          ``in_tracers``. ``consts`` is the list of constants captured by the
          trace, to be supplied in that leading position.
        """
        # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
        # Like trace_to_subjaxpr, accept a per-output sequence of flags; any
        # other value (e.g. the default True) is broadcast to every output.
        if isinstance(instantiate, (list, tuple)):
            instantiate = list(instantiate)
        else:
            instantiate = [instantiate] * len(out_tracers)
        out_tracers = safe_map(trace.full_raise,
                               safe_map(core.full_lower, out_tracers))
        # NOTE(review): safe_map presumably enforces equal lengths, so a
        # per-output flag list of the wrong length should fail here — confirm.
        out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
                               instantiate, out_tracers)
        jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
        out_pvals = [t.pval for t in out_tracers]
        # TODO: this is from partial_eval.trace_to_jaxpr. Share.
        # An empty env is required: this helper has nowhere to return one.
        assert not env

        # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
        # The first component of each output pval is its aval; raise it to a
        # shaped aval for the TypedJaxpr signature.
        out_avals = safe_map(abstract_arrays.raise_to_shaped,
                             unzip2(out_pvals)[0])
        const_avals = tuple(
            abstract_arrays.raise_to_shaped(core.get_aval(c)) for c in consts)

        in_pvals = [t.pval for t in in_tracers]
        in_avals = tuple(
            safe_map(abstract_arrays.raise_to_shaped,
                     unzip2(in_pvals)[0]))

        # Consts come first in the input avals, matching the leading inputs
        # that pe.convert_constvars_jaxpr is expected to create from constvars.
        typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), (),
                                      const_avals + in_avals, out_avals)
        return typed_jaxpr, consts
示例#2
0
 def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
   """Finish a partial-evaluation trace and package it as a ClosedJaxpr.

   Args:
     in_tracers: tracers created for the traced function's inputs.
     out_tracers: values produced by the traced function. Each is lowered with
       ``core.full_lower``, raised into ``trace`` and then passed through
       ``pe.instantiate_const_at`` before the jaxpr is built.
     trace: the partial-eval trace the tracers belong to.
     instantiate: either a single flag applied to every output, or a
       list/tuple with one flag per output (mirroring
       ``partial_eval.trace_to_subjaxpr``).

   Returns:
     A ``(closed_jaxpr, consts)`` pair. ``closed_jaxpr`` wraps
     ``pe.convert_constvars_jaxpr(jaxpr)`` with an empty consts tuple, so the
     captured constants are not stored in it; they are returned separately as
     ``consts`` for the caller to supply.
   """
   # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
   # Like trace_to_subjaxpr, accept a per-output sequence of flags; any other
   # value (e.g. the default True) is broadcast to every output.
   if isinstance(instantiate, (list, tuple)):
     instantiate = list(instantiate)
   else:
     instantiate = [instantiate] * len(out_tracers)
   out_tracers = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers))
   # NOTE(review): safe_map presumably enforces equal lengths, so a per-output
   # flag list of the wrong length should fail here — confirm.
   out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
                          instantiate, out_tracers)
   jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
   assert not env  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
   closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
   return closed_jaxpr, consts