示例#1
0
文件: jax2tf.py 项目: ekelsen/jax
def _interpret_fun(fun: lu.WrappedFun,
                   in_vals: Sequence[TfValOrUnit]) -> Sequence[TfValOrUnit]:
    """Evaluate ``fun`` on ``in_vals`` under a fresh ``TensorFlowTrace`` master.

    A new master is opened with ``core.new_master(TensorFlowTrace)``. ``fun``
    is wrapped by ``_interpret_subtrace`` for that master and then invoked.

    Args:
      fun: the wrapped function to interpret.
      in_vals: positional input values forwarded to ``fun.call_wrapped``.

    Returns:
      The output values returned by the wrapped ``fun.call_wrapped``.
    """
    with core.new_master(TensorFlowTrace) as master:
        fun = _interpret_subtrace(fun, master)
        out_vals: Sequence[TfValOrUnit] = fun.call_wrapped(*in_vals)
        # Drop this frame's reference to ``master`` before the context manager
        # exits. NOTE(review): presumably so this frame does not keep the
        # master alive when ``new_master`` finishes (e.g. for a leak check)
        # -- confirm before moving or removing this line.
        del master
    return out_vals
示例#2
0
 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
     """Handle a call primitive encountered by this trace.

     When ``self.master.strip_calls`` is truthy, ``f`` is invoked directly on
     the tracers, so the call is inlined. Otherwise the tracers' raw values
     are extracted, ``f`` is wrapped with ``callback_subtrace`` for this
     master, ``call_primitive`` is re-bound on the raw values, and every
     output is re-wrapped in a ``CallbackTracer`` belonging to this trace.
     """
     if not self.master.strip_calls:  # type: ignore
         raw_inputs = [tracer.val for tracer in tracers]
         wrapped_fun = callback_subtrace(f, self.master)
         raw_outputs = call_primitive.bind(wrapped_fun, *raw_inputs, **params)
         return [CallbackTracer(self, out) for out in raw_outputs]
     # Calls are being stripped: run the body inline on the tracers.
     return f.call_wrapped(*tracers)
示例#3
0
def checkify_flat(fun: lu.WrappedFun,
                  enabled_errors: FrozenSet['ErrorCategory'], *args):
    """Run ``fun`` on flat ``args`` under the checkify transformation.

    ``fun`` is wrapped by ``checkify_subtrace`` and then by
    ``checkify_traceable``, which receives ``init_error``'s messages as a tuple
    of items together with ``enabled_errors``. The wrapped function is called
    with ``init_error.err`` and ``init_error.code`` prepended to ``args``.

    Returns:
      ``((err, code, outvals), msgs)``. ``err`` and ``code`` are the first two
      results of the call, ``outvals`` is a list of the remaining results, and
      ``msgs`` is the value produced by the thunk from ``checkify_subtrace``.
    """
    traced, get_msgs = checkify_subtrace(fun)
    initial_msgs = tuple(init_error.msgs.items())
    traced = checkify_traceable(traced, initial_msgs, enabled_errors)
    first, second, *rest = traced.call_wrapped(init_error.err,
                                               init_error.code, *args)
    return (first, second, rest), get_msgs()
示例#4
0
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
                              in_avals: Sequence[core.AbstractValue]):
  """Trace ``fun`` into a jaxpr under ``main`` using a ``DJaxprTrace``.

  A ``DJaxprStackFrame`` is pushed with ``pe.extend_jaxpr_stack`` while the
  inputs are created and ``fun`` runs. The output dimension tracers are
  extracted after that scope closes. The result is whatever
  ``DJaxprStackFrame.to_jaxpr`` returns for the input/output dim tracers and
  value tracers.
  """
  stack_frame = DJaxprStackFrame()
  with pe.extend_jaxpr_stack(main, stack_frame):
    dyn_trace = DJaxprTrace(main, core.cur_sublevel())
    # Insert input dimension tracers into the avals' shapes. The returned
    # avals replace the incoming ones for argument creation.
    dim_tracers_in, arg_avals = _place_in_dim_tracers_in_shapes(
        dyn_trace, in_avals)
    # NOTE(review): ``map`` is presumably JAX's list-returning ``safe_map``.
    # With builtin ``map``, ``arg_tracers`` would already be exhausted by
    # ``call_wrapped`` before reaching ``to_jaxpr`` -- confirm.
    arg_tracers = map(dyn_trace.new_arg, arg_avals)
    raw_outs = fun.call_wrapped(*arg_tracers)
    result_tracers = map(dyn_trace.full_raise, raw_outs)
  dim_tracers_out = _extract_out_dim_tracers_from_shapes(
      main, dim_tracers_in, result_tracers)
  return stack_frame.to_jaxpr(dim_tracers_in, arg_tracers, dim_tracers_out,
                              result_tracers)
示例#5
0
def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
    """Batch ``f`` over ``args`` with ``batching.BatchTrace`` and call it.

    ``f`` is wrapped by ``batching.batch_subtrace`` and then by
    ``batching._batch_outer`` using ``axis_name``, ``axis_size`` and
    ``in_axes``.

    Returns:
      ``(outs, out_axes)``. ``outs`` are the results of calling the wrapped
      function on ``args``. ``out_axes`` is read from the thunk returned by
      ``batch_subtrace`` after the call.
    """
    batched, get_out_axes = batching.batch_subtrace(f)
    batched = batching._batch_outer(batched, axis_name, axis_size, in_axes,
                                    batching.BatchTrace)
    results = batched.call_wrapped(*args)
    return results, get_out_axes()
示例#6
0
def callback_fun(fun: lu.WrappedFun, in_vals, callback, strip_calls):
    """Call ``fun`` on ``in_vals`` with callback tracing applied.

    ``fun`` is first wrapped with ``callback_subtrace``, then with
    ``_callback_fun`` (given ``callback`` and ``strip_calls``). The result of
    calling the doubly-wrapped function is returned unchanged.
    """
    wrapped = _callback_fun(callback_subtrace(fun), callback, strip_calls)
    return wrapped.call_wrapped(*in_vals)
示例#7
0
def check_errors_flat(fun: lu.WrappedFun, *args):
    """Run ``fun`` on flat ``args`` with error checking applied.

    ``fun`` is wrapped by ``check_errors_subtrace`` and then by
    ``check_errors_toplevel`` before being called on ``args``.

    Returns:
      ``((err, code, out_vals), msgs)``. ``err`` and ``code`` are the first two
      results, ``out_vals`` is a list of the rest, and ``msgs`` comes from the
      thunk returned by ``check_errors_subtrace``.
    """
    traced, get_msgs = check_errors_subtrace(fun)
    toplevel = check_errors_toplevel(traced)
    err, code, *remaining = toplevel.call_wrapped(*args)
    return (err, code, remaining), get_msgs()