def wrapped(*args, **kwargs):
    """Call the wrapped ``operator``, recording it in the current tracing context.

    ``operator`` is called directly, without any tracing, when:
      * there is no current context,
      * another operator is already being traced (``ctx.in_operator`` is set),
        i.e. this is a nested call, or
      * the context is not tracing.

    Otherwise, one of two paths runs:
      * ``operator_info.custom_trace_fn`` handles the call, if one is set.
      * The standard path registers the call, runs pre-hooks (which may
        replace the arguments), locates the graph node, and calls the
        operator. It then attaches trace info to the outputs and runs
        post-hooks (which may replace the result).

    ``ctx.in_operator`` is always reset, even if tracing, a hook, or the
    operator raises. Otherwise the context would stay stuck in the
    "in operator" state and silently stop tracing all later calls.
    """
    ctx = get_current_context()
    if not ctx or getattr(ctx, 'in_operator', False) or not ctx.is_tracing:
        return operator(*args, **kwargs)

    ctx.in_operator = True
    try:
        if operator_info.custom_trace_fn is not None:
            result = operator_info.custom_trace_fn(operator, *args, **kwargs)
        else:
            ia_op_exec_context = ctx.get_caller_context(operator_info.name)
            ctx.register_operator_call(ia_op_exec_context.operator_name,
                                       ia_op_exec_context.scope_in_model)

            # Pre-hooks may rewrite the positional and keyword arguments.
            op_input = OperatorInput(list(args), kwargs)
            processed_input = ctx.execute_pre_hooks(ia_op_exec_context, op_input)
            args = tuple(processed_input.op_args)
            kwargs = processed_input.op_kwargs

            fargs = flatten_args(args, kwargs)
            node = ctx.find_operator_node(fargs, ia_op_exec_context)
            if is_debug():
                ctx.register_node_call(ctx.graph.get_node_key_by_id(node.node_id))

            result = operator(*args, **kwargs)
            result = trace_tensors(result, node)
            # Post-hooks may replace the operator's result.
            result = ctx.execute_post_hooks(ia_op_exec_context, result)
    finally:
        # Reset on every exit path so a failure cannot disable tracing
        # for all subsequent operator calls.
        ctx.in_operator = False
    return result
def is_in_nncf_tracing():
    """Report whether NNCF is enabled and the current context is tracing.

    Returns:
        ``False`` if NNCF is disabled (the context is then not queried) or
        there is no current context. Otherwise, the current context's
        ``is_tracing`` value, returned unchanged.
    """
    if is_nncf_enabled():
        current_ctx = get_current_context()
        if current_ctx is not None:
            return current_ctx.is_tracing
    return False
def wrapped(self, *args, **kwargs):
    """Run ``module_call`` for ``self`` inside a tracing scope for that module.

    If there is no current context, or the module's class is in
    ``_IGNORED_SCOPES``, the call is forwarded without scoping. In that case
    a warning is emitted first when ``self`` is a ``DataParallel``.

    Otherwise ``self`` is pushed as a scope for the duration of the call.
    Modules whose class name is registered in ``ITERATION_MODULES`` get
    their per-scope operator call counts reset after a successful call.
    The scope is popped even if ``module_call`` raises, so the scope stack
    stays balanced for callers that handle the exception and continue.
    """
    ctx = get_current_context()
    if not ctx or self.__class__ in _IGNORED_SCOPES:
        if isinstance(self, DataParallel):
            _warn_data_parallel()
        return module_call(self, *args, **kwargs)

    ctx.push_scope(self)
    try:
        retval = module_call(self, *args, **kwargs)
        if type(self).__name__ in ITERATION_MODULES.registry_dict.keys():
            ctx.reset_operator_call_count_in_scope(ctx.scope)
    finally:
        # Always balance push_scope, even when the module call fails.
        ctx.pop_scope()
    return retval
def wrapped(*args, **kwargs):
    """Call the wrapped ``operator``, recording it in the current tracing context.

    ``operator`` is called directly, without any tracing, when:
      * there is no current context,
      * another operator is already being traced (``ctx.in_operator`` is set),
        i.e. this is a nested call, or
      * the context is not tracing.

    Otherwise, one of two paths runs:
      * ``operator_info.custom_trace_fn`` handles the call, if one is set.
      * The standard path registers the call, runs pre-hooks (which may
        replace the arguments), locates the graph node, and calls the
        operator. It then attaches trace info to the outputs and runs
        post-hooks (which may replace the result).

    ``ctx.in_operator`` is reset on every exit path, including exceptions
    from either branch. For example, the ``__repr__`` call an IDE debugger
    makes to display tensor contents can raise instead of returning
    normally. Without the reset, the context would stay stuck in the
    ``in_operator == True`` state and stop tracing all later calls.
    """
    ctx = get_current_context()
    if not ctx or getattr(ctx, 'in_operator', False) or not ctx.is_tracing:
        return operator(*args, **kwargs)

    ctx.in_operator = True
    try:
        if operator_info.custom_trace_fn is not None:
            result = operator_info.custom_trace_fn(operator, *args, **kwargs)
        else:
            ia_op_exec_context = ctx.get_caller_context(operator_info.name)
            ctx.register_operator_call(ia_op_exec_context.operator_name,
                                       ia_op_exec_context.scope_in_model)

            # Pre-hooks may rewrite the positional and keyword arguments.
            op_input = OperatorInput(list(args), kwargs)
            processed_input = ctx.execute_pre_hooks(ia_op_exec_context, op_input)
            args = tuple(processed_input.op_args)
            kwargs = processed_input.op_kwargs

            fargs = flatten_args(args, kwargs)
            node = ctx.find_operator_node(fargs, ia_op_exec_context)
            if is_debug():
                ctx.register_node_call(ctx.graph.get_node_key_by_id(node.node_id))

            result = operator(*args, **kwargs)
            result = trace_tensors(result, node)
            # Post-hooks may replace the operator's result.
            result = ctx.execute_post_hooks(ia_op_exec_context, result)
    finally:
        # Covers both branches; replaces the earlier bare ``except:`` that
        # only protected the custom_trace_fn path.
        ctx.in_operator = False
    return result