示例#1
0
 def start_subtrace(self):
   """Open a nested JaxprTrace and hand it back to the caller.

   Builds a new ``core.MainTrace`` for ``pe.JaxprTrace`` at the next level of
   the thread-local trace stack, pushes it onto that stack, increments
   ``self._count_subtraces``, and returns a ``pe.JaxprTrace`` bound to the new
   main trace at ``core.cur_sublevel()``.

   Nothing is popped here. NOTE(review): presumably a companion method ends
   the subtrace and pops the stack -- confirm against the rest of the class.
   """
   # TODO: This follows the __enter__ part of core.new_main.
   # With omnistaging disabled, next_level() and push() are called with one
   # extra positional argument, False (presumably the pre-omnistaging `bottom`
   # flag -- confirm against core); with omnistaging enabled they get none.
   extra_args = () if config.omnistaging_enabled else (False,)
   trace_stack = core.thread_local_state.trace_state.trace_stack
   new_level = trace_stack.next_level(*extra_args)
   main_trace = core.MainTrace(new_level, pe.JaxprTrace)
   trace_stack.push(main_trace, *extra_args)
   self._count_subtraces += 1
   return pe.JaxprTrace(main_trace, core.cur_sublevel())
示例#2
0
文件: loops.py 项目: xueeinstein/jax
 def start_subtrace(self):
     """Open a nested JaxprTrace tagged with the current name stack; return it.

     A new ``core.MainTrace`` for ``pe.JaxprTrace`` is created at the next
     level of the thread-local trace stack. It carries the name stack from
     ``source_info_util.current_name_stack()`` and is pushed onto the stack.
     ``self._count_subtraces`` is incremented. The returned
     ``pe.JaxprTrace`` wraps the new main trace at ``core.cur_sublevel()``
     and carries the same name stack.

     NOTE(review): no pop happens here; presumably the caller or a companion
     method is responsible for it -- confirm.
     """
     # TODO: This follows the __enter__ part of core.new_main.
     trace_stack = core.thread_local_state.trace_state.trace_stack
     new_level = trace_stack.next_level()
     names = source_info_util.current_name_stack()
     main_trace = core.MainTrace(new_level, pe.JaxprTrace, name_stack=names)
     trace_stack.push(main_trace)
     self._count_subtraces += 1
     return pe.JaxprTrace(main_trace, core.cur_sublevel(), name_stack=names)
示例#3
0
def _copy_main_traces(x):
    """Return a fresh copy of ``x`` if it is a ``core.MainTrace``, else ``x``.

    The copy is a new ``core.MainTrace`` constructed from the original's
    ``level`` and ``trace_type``. The original's ``payload`` mapping is
    expanded as keyword arguments. Any other value is returned as-is, not
    copied.
    """
    if not isinstance(x, core.MainTrace):
        return x
    return core.MainTrace(x.level, x.trace_type, **x.payload)