コード例 #1
0
    def wrapped(*args, **kwargs):
        """Call ``operator``, tracing the call through the current context.

        The operator is invoked directly, with no tracing, when there is no
        current context, when the context is already inside an operator call
        (``ctx.in_operator`` is truthy), or when the context is not tracing.

        Otherwise the call is either delegated to
        ``operator_info.custom_trace_fn`` or routed through the context's
        pre-hooks, operator-node lookup, ``trace_tensors`` and post-hooks.

        Returns:
            The result produced by whichever of the paths above was taken.

        Raises:
            Whatever the operator, the hooks or the custom trace function
            raise; the ``in_operator`` flag is cleared before propagating.
        """
        ctx = get_current_context()
        if not ctx or getattr(ctx, 'in_operator', False) or not ctx.is_tracing:
            return operator(*args, **kwargs)

        # Mark the context as busy so that operator calls made while this one
        # is being processed take the untraced path above.
        ctx.in_operator = True
        try:
            if operator_info.custom_trace_fn is not None:
                result = operator_info.custom_trace_fn(operator, *args, **kwargs)
            else:
                ia_op_exec_context = ctx.get_caller_context(operator_info.name)
                ctx.register_operator_call(ia_op_exec_context.operator_name, ia_op_exec_context.scope_in_model)

                # Pre-hooks may replace the arguments the operator receives.
                op_input = OperatorInput(list(args), kwargs)
                processed_input = ctx.execute_pre_hooks(ia_op_exec_context, op_input)
                args = tuple(processed_input.op_args)
                kwargs = processed_input.op_kwargs
                fargs = flatten_args(args, kwargs)

                node = ctx.find_operator_node(fargs, ia_op_exec_context)
                if is_debug():
                    ctx.register_node_call(ctx.graph.get_node_key_by_id(node.node_id))

                result = operator(*args, **kwargs)

                result = trace_tensors(result, node)
                result = ctx.execute_post_hooks(ia_op_exec_context, result)
        finally:
            # Always clear the flag, even when something above raises;
            # otherwise the context stays stuck in the "in_operator == True"
            # state and every later call would skip tracing.
            ctx.in_operator = False
        return result
コード例 #2
0
def is_in_nncf_tracing():
    """Return the current context's ``is_tracing`` value, or ``False``.

    ``False`` is returned when ``is_nncf_enabled()`` is falsy or when
    ``get_current_context()`` yields ``None``.
    """
    if is_nncf_enabled():
        current = get_current_context()
        if current is not None:
            return current.is_tracing
    return False
コード例 #3
0
ファイル: wrappers.py プロジェクト: xiaming9880/nncf
 def wrapped(self, *args, **kwargs):
     """Run ``module_call`` on ``self`` inside a scope pushed on the context.

     When there is no current context, or the class of ``self`` is listed in
     ``_IGNORED_SCOPES``, the module is called directly without touching any
     scope; ``_warn_data_parallel()`` is invoked first if ``self`` is a
     ``DataParallel`` instance.

     Returns:
         The value returned by ``module_call``.

     Raises:
         Whatever ``module_call`` raises; the pushed scope is popped before
         the exception propagates.
     """
     ctx = get_current_context()
     if not ctx or self.__class__ in _IGNORED_SCOPES:
         if isinstance(self, DataParallel):
             _warn_data_parallel()
         return module_call(self, *args, **kwargs)

     ctx.push_scope(self)
     try:
         retval = module_call(self, *args, **kwargs)
         # Modules registered in ITERATION_MODULES get their operator call
         # counters reset for the current scope after each call.
         if type(self).__name__ in ITERATION_MODULES.registry_dict.keys():
             ctx.reset_operator_call_count_in_scope(ctx.scope)
     finally:
         # Keep push/pop balanced even if the module call raises, so a
         # failed call does not leave a stale scope on the context.
         ctx.pop_scope()
     return retval
コード例 #4
0
ファイル: wrappers.py プロジェクト: zbrnwpu/nncf
    def wrapped(*args, **kwargs):
        """Call ``operator``, tracing the call through the current context.

        The operator is invoked directly, with no tracing, when there is no
        current context, when the context is already inside an operator call
        (``ctx.in_operator`` is truthy), or when the context is not tracing.

        Otherwise the call is either delegated to
        ``operator_info.custom_trace_fn`` or routed through the context's
        pre-hooks, operator-node lookup, ``trace_tensors`` and post-hooks.

        Returns:
            The result produced by whichever of the paths above was taken.

        Raises:
            Whatever the operator, the hooks or the custom trace function
            raise; the ``in_operator`` flag is cleared before propagating.
        """
        ctx = get_current_context()
        if not ctx or getattr(ctx, 'in_operator', False) or not ctx.is_tracing:
            return operator(*args, **kwargs)

        # Mark the context as busy so that operator calls made while this one
        # is being processed take the untraced path above.
        ctx.in_operator = True
        try:
            if operator_info.custom_trace_fn is not None:
                result = operator_info.custom_trace_fn(operator, *args, **kwargs)
            else:
                ia_op_exec_context = ctx.get_caller_context(operator_info.name)
                ctx.register_operator_call(ia_op_exec_context.operator_name, ia_op_exec_context.scope_in_model)

                # Pre-hooks may replace the arguments the operator receives.
                op_input = OperatorInput(list(args), kwargs)
                processed_input = ctx.execute_pre_hooks(ia_op_exec_context, op_input)
                args = tuple(processed_input.op_args)
                kwargs = processed_input.op_kwargs
                fargs = flatten_args(args, kwargs)

                node = ctx.find_operator_node(fargs, ia_op_exec_context)
                if is_debug():
                    ctx.register_node_call(ctx.graph.get_node_key_by_id(node.node_id))

                result = operator(*args, **kwargs)

                result = trace_tensors(result, node)
                result = ctx.execute_post_hooks(ia_op_exec_context, result)
        finally:
            # Looks like the __repr__ call made during IDE debug to display tensor contents does not exit properly,
            # but instead throws an exception. Resetting the flag in ``finally`` handles that situation, and any
            # other failure in either branch above. Otherwise the context is stuck in the
            # "in_operator == True" state.
            ctx.in_operator = False
        return result