예제 #1
0
 def start_subtrace(self):
   """Open a nested JaxprTrace and return it.

   A new ``core.MainTrace`` is created at the trace stack's next level,
   pushed onto the thread-local trace stack, and counted in
   ``self._count_subtraces``. A ``pe.JaxprTrace`` wrapping that main trace
   at the current sublevel is returned.
   """
   # TODO: This follows the __enter__ part of core.new_main.
   # The two trace-stack APIs differ only in arity: with omnistaging,
   # next_level/push take no extra argument; without it, both receive an
   # additional positional ``False``.
   extra_args = () if config.omnistaging_enabled else (False,)
   trace_stack = core.thread_local_state.trace_state.trace_stack
   next_lvl = trace_stack.next_level(*extra_args)
   main_trace = core.MainTrace(next_lvl, pe.JaxprTrace)
   trace_stack.push(main_trace, *extra_args)
   self._count_subtraces += 1
   return pe.JaxprTrace(main_trace, core.cur_sublevel())
예제 #2
0
 def start_subtrace():
     """Open a nested JaxprTrace and return it.

     Creates a ``core.MasterTrace`` at ``trace_stack.next_level(False)``,
     pushes it onto ``core.trace_state.trace_stack`` (again with ``False``),
     and returns a ``pe.JaxprTrace`` for it at the current sublevel.
     """
     # TODO: This follows the __enter__ part of core.new_master; consider
     # sharing the logic with it instead of duplicating it here.
     trace_stack = core.trace_state.trace_stack
     master_trace = core.MasterTrace(trace_stack.next_level(False),
                                     pe.JaxprTrace)
     trace_stack.push(master_trace, False)
     return pe.JaxprTrace(master_trace, core.cur_sublevel())
예제 #3
0
 def start_subtrace(self):
     """Open a nested JaxprTrace and return it.

     Creates a ``core.MasterTrace`` at ``trace_stack.next_level(False)``,
     pushes it onto the thread-local trace stack, increments
     ``self._count_subtraces``, and returns a ``pe.JaxprTrace`` for it at
     the current sublevel.
     """
     # TODO: This follows the __enter__ part of core.new_master.
     trace_stack = core.thread_local_state.trace_state.trace_stack
     master_trace = core.MasterTrace(trace_stack.next_level(False),
                                     pe.JaxprTrace)
     trace_stack.push(master_trace, False)
     self._count_subtraces += 1
     return pe.JaxprTrace(master_trace, core.cur_sublevel())
예제 #4
0
파일: loops.py 프로젝트: xueeinstein/jax
 def start_subtrace(self):
     """Open a nested JaxprTrace and return it.

     Creates a ``core.MainTrace`` at the trace stack's next level, tagged
     with the current source-info name stack. It is pushed onto the
     thread-local trace stack and counted in ``self._count_subtraces``.
     A ``pe.JaxprTrace`` for it, carrying the same name stack, is returned.
     """
     # TODO: This follows the __enter__ part of core.new_main.
     trace_stack = core.thread_local_state.trace_state.trace_stack
     new_level = trace_stack.next_level()
     # Read the name stack once so the MainTrace and the returned
     # JaxprTrace are constructed with the identical value.
     current_names = source_info_util.current_name_stack()
     main_trace = core.MainTrace(new_level, pe.JaxprTrace,
                                 name_stack=current_names)
     trace_stack.push(main_trace)
     self._count_subtraces += 1
     return pe.JaxprTrace(main_trace, core.cur_sublevel(),
                          name_stack=current_names)