def start_subtrace(self):
  """Open a nested JaxprTrace on the thread-local trace stack and return it.

  A new ``core.MainTrace`` of type ``pe.JaxprTrace`` is created at the trace
  stack's next level and pushed onto the stack. ``self._count_subtraces`` is
  incremented.

  Returns:
    A ``pe.JaxprTrace`` bound to the new main trace at ``core.cur_sublevel()``.
  """
  # TODO: This follows the __enter__ part of core.new_main.
  stack = core.thread_local_state.trace_state.trace_stack
  # When omnistaging is disabled, both ``next_level`` and ``push`` are called
  # with an additional positional ``False``; otherwise they take no extras.
  extra_args = () if config.omnistaging_enabled else (False,)
  new_main = core.MainTrace(stack.next_level(*extra_args), pe.JaxprTrace)
  stack.push(new_main, *extra_args)
  self._count_subtraces += 1
  return pe.JaxprTrace(new_main, core.cur_sublevel())
def start_subtrace(self):
  """Open a nested JaxprTrace carrying the current name stack and return it.

  A new ``core.MainTrace`` of type ``pe.JaxprTrace`` is created at the trace
  stack's next level. It is tagged with
  ``source_info_util.current_name_stack()`` and pushed onto the thread-local
  trace stack. ``self._count_subtraces`` is incremented.

  Returns:
    A ``pe.JaxprTrace`` bound to the new main trace at ``core.cur_sublevel()``,
    sharing the same name stack as the main trace.
  """
  # TODO: This follows the __enter__ part of core.new_main.
  trace_stack = core.thread_local_state.trace_state.trace_stack
  next_level = trace_stack.next_level()
  names = source_info_util.current_name_stack()
  new_main = core.MainTrace(next_level, pe.JaxprTrace, name_stack=names)
  trace_stack.push(new_main)
  self._count_subtraces += 1
  return pe.JaxprTrace(new_main, core.cur_sublevel(), name_stack=names)
def _copy_main_traces(x):
  """Return a fresh ``core.MainTrace`` copy of ``x``, or ``x`` itself.

  If ``x`` is a ``core.MainTrace``, a new instance is built from ``x.level``,
  ``x.trace_type`` and the keyword arguments stored in ``x.payload``. Any
  other value is returned unchanged (the same object).
  """
  if not isinstance(x, core.MainTrace):
    return x
  return core.MainTrace(x.level, x.trace_type, **x.payload)