Example #1
0
    def wrapped(*args, **kwargs):
        """Call ``operator`` and, if a tracing context is active, record the call.

        ``operator`` and ``operator_info`` are free variables taken from the
        enclosing scope, which is not shown here (presumably the arguments of the
        decorator that produced this wrapper — confirm against the caller).

        The operator is called directly, with nothing recorded, when any of these hold:
        - there is no current context;
        - the context is already inside a wrapped operator (``in_operator`` is set);
        - the context is not tracing.

        Otherwise ``ctx.in_operator`` is set while the call is handled, so that
        wrapped operators invoked from inside this one take the direct path above.
        The call is handled either by ``operator_info.custom_trace_fn`` (when set)
        or by the default path. The default path:
        1. registers the call with the context;
        2. runs the pre-hooks, which may replace the arguments;
        3. looks up the graph node for the flattened arguments;
        4. executes the operator;
        5. passes the result through ``trace_tensors`` and the post-hooks.

        ``ctx.in_operator`` is reset in a ``finally`` clause. An exception raised
        anywhere during tracing therefore cannot leave the context stuck in the
        "in operator" state, which would silently disable tracing for all later calls.
        """
        ctx = get_current_context()
        if not ctx or getattr(ctx, 'in_operator', False) or not ctx.is_tracing:
            return operator(*args, **kwargs)

        ctx.in_operator = True
        try:
            if operator_info.custom_trace_fn is not None:
                return operator_info.custom_trace_fn(operator, *args, **kwargs)

            ia_op_exec_context = ctx.get_caller_context(operator_info.name)
            ctx.register_operator_call(ia_op_exec_context.operator_name, ia_op_exec_context.scope_in_model)

            op_input = OperatorInput(list(args), kwargs)
            processed_input = ctx.execute_pre_hooks(ia_op_exec_context, op_input)
            # Pre-hooks may substitute the operator's arguments; use what they return.
            args = tuple(processed_input.op_args)
            kwargs = processed_input.op_kwargs
            fargs = flatten_args(args, kwargs)

            node = ctx.find_operator_node(fargs, ia_op_exec_context)
            if is_debug():
                ctx.register_node_call(ctx.graph.get_node_key_by_id(node.node_id))

            result = operator(*args, **kwargs)

            result = trace_tensors(result, node)
            return ctx.execute_post_hooks(ia_op_exec_context, result)
        finally:
            # Leave the "in operator" state on success and on any exception alike.
            ctx.in_operator = False
Example #2
0
    def __call__(self, operator: Callable, *args, **kwargs):
        """ This wrapper override will result in the operator not being added to the graph,
        but the result will still have TracedTensors with parent IDs left the same as in input.
        Useful for operators which are not likely to be present in patterns considered for
        compression, but still have to be accounted for so that the NNCF internal graph representation
        does not become disjoint.

        Tensor metadata is forwarded from the traced inputs to the outputs as follows:

        * no TracedTensor among the (flattened) inputs -> the result is returned untouched,
          whether it is a single value or a list/tuple;
        * exactly one TracedTensor input -> a copy of its metadata, with the output's own
          shape, is attached to every Tensor output;
        * several TracedTensor inputs -> the result must be a list/tuple with the same number
          of Tensor elements, and metadata is forwarded pairwise in order.

        :param operator: The callable to execute with ``*args`` and ``**kwargs``.
        :return: The operator result with Tensor outputs wrapped as TracedTensors where
            metadata was forwarded. A tuple result stays a tuple; any other list/tuple
            result is returned as a list.
        :raises RuntimeError: If the number of traced inputs cannot be matched with the
            number of Tensor outputs.
        """
        result = operator(*args, **kwargs)

        fargs = flatten_args(args, kwargs)
        traced_inputs = [arg for arg in fargs if isinstance(arg, TracedTensor)]

        if not traced_inputs:
            # No traced tensors in input: nothing to forward, return the plain result
            # (consistently for single-tensor and list/tuple results).
            return result

        def _forward(source, tensor):
            # Wrap `tensor` as a TracedTensor carrying a copy of `source`'s metadata;
            # only the shape is updated to match the new tensor.
            forwarded_meta = deepcopy(source.tensor_meta)
            forwarded_meta.shape = tuple(tensor.shape)
            return TracedTensor.from_torch_tensor(tensor, forwarded_meta)

        def _mismatch_error():
            # Not every callable has __name__ (e.g. functools.partial); fall back to repr
            # so the intended RuntimeError is not masked by an AttributeError.
            op_name = getattr(operator, '__name__', repr(operator))
            return RuntimeError(
                "Unable to forward trace through operator {} - "
                "input and output tensor count mismatch!".format(op_name))

        if not isinstance(result, (list, tuple)):
            if len(traced_inputs) > 1:
                raise _mismatch_error()
            return _forward(traced_inputs[0], result)

        was_tuple = isinstance(result, tuple)
        outputs = list(result)
        tensor_output_indices = [idx for idx, out in enumerate(outputs) if isinstance(out, Tensor)]

        if len(traced_inputs) == 1:
            # Broadcast one and the same creator ID of input to all outputs
            sources = traced_inputs * len(tensor_output_indices)
        elif len(traced_inputs) == len(tensor_output_indices):
            # Assume that output tensor order corresponds to input tensor order
            sources = traced_inputs
        else:
            raise _mismatch_error()

        for source, out_idx in zip(sources, tensor_output_indices):
            outputs[out_idx] = _forward(source, outputs[out_idx])

        return tuple(outputs) if was_tuple else outputs
Example #3
0
    def wrapped(*args, **kwargs):
        """Call ``operator`` and, if a tracing context is active, record the call.

        ``operator`` and ``operator_info`` are free variables taken from the
        enclosing scope, which is not shown here (presumably the arguments of the
        decorator that produced this wrapper — confirm against the caller).

        The operator is called directly, with nothing recorded, when any of these hold:
        - there is no current context;
        - the context is already inside a wrapped operator (``in_operator`` is set);
        - the context is not tracing.

        Otherwise ``ctx.in_operator`` is set while the call is handled. The call is
        handled either by ``operator_info.custom_trace_fn`` (when set) or by the
        default path. The default path:
        1. registers the call;
        2. runs the pre-hooks, which may replace the arguments;
        3. looks up the graph node;
        4. executes the operator;
        5. passes the result through ``trace_tensors`` and the post-hooks.
        """
        ctx = get_current_context()
        if not ctx or getattr(ctx, 'in_operator', False) or not ctx.is_tracing:
            return operator(*args, **kwargs)

        ctx.in_operator = True
        try:
            if operator_info.custom_trace_fn is not None:
                return operator_info.custom_trace_fn(operator, *args, **kwargs)

            ia_op_exec_context = ctx.get_caller_context(operator_info.name)
            ctx.register_operator_call(ia_op_exec_context.operator_name, ia_op_exec_context.scope_in_model)

            op_input = OperatorInput(list(args), kwargs)
            processed_input = ctx.execute_pre_hooks(ia_op_exec_context, op_input)
            # Pre-hooks may substitute the operator's arguments; use what they return.
            args = tuple(processed_input.op_args)
            kwargs = processed_input.op_kwargs
            fargs = flatten_args(args, kwargs)

            node = ctx.find_operator_node(fargs, ia_op_exec_context)
            if is_debug():
                ctx.register_node_call(ctx.graph.get_node_key_by_id(node.node_id))

            result = operator(*args, **kwargs)

            result = trace_tensors(result, node)
            return ctx.execute_post_hooks(ia_op_exec_context, result)
        finally:
            # Reset the flag on every exit path, including exceptions from the custom
            # trace function, the hooks, or the operator itself. One such case: an
            # IDE debugger's __repr__ call on a tensor may not exit normally but raise
            # instead. Without this reset the context would stay stuck with
            # "in_operator == True" and silently skip tracing for all later calls.
            ctx.in_operator = False