Example #1
    def test_number_of_nodes_for_module_in_loop__not_input_node(self):
        num_iter = 5

        class LoopModule(nn.Module):
            class Inner(nn.Module):
                def forward(self, x):
                    s = F.sigmoid(x)
                    t = F.tanh(x)
                    result = F.sigmoid(x) * t + F.tanh(x) * s
                    return result

                @staticmethod
                def nodes_number():
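                    # 2 sigmoid + 2 tanh + 2 multiplication + 1 addition nodes = 7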
                    return 7

            def __init__(self):
                super().__init__()
                self.inner = self.Inner()

            def forward(self, x):
                for _ in range(num_iter):
                    x = self.inner(F.relu(x))
                return x

            def nodes_number(self):
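                # Inner's nodes are traced once; the F.relu call in the loop body adds one node per iteration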
                return self.inner.nodes_number() + num_iter

        test_module = LoopModule()
        context = TracingContext()
        context.enable_trace_dynamic_graph()
        with context as ctx:
            _ = test_module(torch.zeros(1))
            assert ctx.graph.get_nodes_count() == test_module.nodes_number()
Example #2
    def test_number_of_nodes_for_module_with_nested_loops(self):
        num_iter = 5

        class TestIterModule(nn.Module):
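            # Modules registered in ITERATION_MODULES act as reset points: nodes traced inside them
            # are not accumulated across loop iterations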
            @ITERATION_MODULES.register()
            class TestIterModule_ResetPoint(nn.Module):
                def __init__(self, loop_module):
                    super().__init__()
                    self.loop_module = loop_module

                def forward(self, x):
                    return self.loop_module(F.relu(x))

            def __init__(self):
                super().__init__()
                self.loop_module = self.LoopModule2()
                self.reset_point = self.TestIterModule_ResetPoint(
                    self.loop_module)

            def forward(self, x):
                for _ in range(num_iter):
                    x = self.reset_point(x)
                return x

            class LoopModule2(nn.Module):
                @ITERATION_MODULES.register()
                class LoopModule2_ResetPoint(nn.Module):
                    def __init__(self, inner):
                        super().__init__()
                        self.inner = inner

                    def forward(self, x):
                        return self.inner(F.relu(x))

                def __init__(self):
                    super().__init__()
                    self.inner = self.Inner()
                    self.reset_helper = self.LoopModule2_ResetPoint(self.inner)

                def forward(self, x):
                    for _ in range(num_iter):
                        self.reset_helper(x)
                    return x

                class Inner(nn.Module):
                    def forward(self, x):
                        s = F.sigmoid(x)
                        t = F.tanh(x)
                        result = t + s
                        return result

        test_module = TestIterModule()
        context = TracingContext()
        context.enable_trace_dynamic_graph()
        with context as ctx:
            _ = test_module(torch.zeros(1))
            assert ctx.graph.get_nodes_count() == num_iter
Example #3
def test_tensor_printing_does_not_inflate_graph():
    context_to_use = TracingContext()
    context_to_use.enable_trace_dynamic_graph()
    with context_to_use as _ctx:
        with torch.no_grad():
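            # Printing, str() and repr() of both plain tensors and TracedTensors must not add graph nodes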
            tensor = torch.ones([1, 2])
            print(tensor)
            str(tensor)
            tensor.__repr__()
            tensor = TracedTensor.from_torch_tensor(
                tensor, TensorMeta(0, 0, tensor.shape))
            print(tensor)
            str(tensor)
            tensor.__repr__()
    assert _ctx.graph.get_nodes_count() == 0
Example #4
    def trace_graph(self,
                    model: torch.nn.Module,
                    context_to_use: Optional['TracingContext'] = None,
                    as_eval: bool = False) -> DynamicGraph:
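        # Snapshot the weights so that the original state can be restored after the tracing forward pass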
        sd = deepcopy(model.state_dict())

        from nncf.torch.dynamic_graph.context import TracingContext
        if context_to_use is None:
            context_to_use = TracingContext()

        context_to_use.enable_trace_dynamic_graph()
        from nncf.torch.utils import training_mode_switcher
        context_to_use.base_module_thread_local_replica = model
        with context_to_use as _ctx:
            with torch.no_grad():
                if as_eval:
                    with training_mode_switcher(model, is_training=False):
                        self.custom_forward_fn(model)
                else:
                    self.custom_forward_fn(model)
        model.load_state_dict(sd)

        if isinstance(model, PostGraphBuildActing):
            model.post_build_graph_actions()
        context_to_use.disable_trace_dynamic_graph()
        return context_to_use.graph
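A minimal usage sketch for the method above, assuming a GraphTracer constructed from a user-supplied forward function as in Example #8 (the model and forward-function names below are illustrative, not taken from the original source):

# Hypothetical usage sketch for GraphTracer.trace_graph (illustrative names)
import torch

def run_model(model: torch.nn.Module):
    # the forward function receives the model and is expected to call it
    model(torch.zeros(1, 3, 32, 32))

model = torch.nn.Conv2d(3, 3, 1)
tracer = GraphTracer(run_model)  # assumed to be stored as custom_forward_fn
dynamic_graph = tracer.trace_graph(model, as_eval=True)
print(dynamic_graph.get_nodes_count())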
Example #5
    def test_number_of_nodes_for_module_in_loop(self):
        num_iter = 5

        class LoopModule(nn.Module):
            @ITERATION_MODULES.register('Inner')
            class Inner(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.operator1 = torch.sigmoid
                    self.operator2 = torch.tanh

                def forward(self, x):
                    s = self.operator1(x)
                    t = self.operator2(x)
                    result = t + s
                    return result

                @staticmethod
                def nodes_number():
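                    # sigmoid, tanh and the addition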
                    return 3

            def __init__(self):
                super().__init__()
                self.inner = self.Inner()

            def forward(self, x):
                for _ in range(num_iter):
                    x = self.inner(x)
                return x

            def nodes_number(self):
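                # Inner is registered in ITERATION_MODULES, so the loop does not duplicate its nodes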
                return self.inner.nodes_number()

        test_module = LoopModule()
        context = TracingContext()
        context.enable_trace_dynamic_graph()
        with context as ctx:
            _ = test_module(torch.zeros(1))
            assert ctx.graph.get_nodes_count() == test_module.nodes_number()
Example #6
    def test_number_of_nodes_for_repeated_module(self):
        class LoopModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.operator = F.relu
                self.layers = nn.ModuleList(
                    [nn.Conv2d(1, 1, 1),
                     nn.Conv2d(1, 1, 1)])

            def forward(self, x):
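                # each pass traces two Conv2d and two relu nodes (4 new nodes per call)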
                for layer in self.layers:
                    x = F.relu(layer(x))
                return x

        test_module = LoopModule()
        context = TracingContext()
        context.enable_trace_dynamic_graph()
        with context as ctx:
            x = test_module(torch.zeros(1, 1, 1, 1))
            assert ctx.graph.get_nodes_count() == 4  # NB: may always fail in debug due to superfluous 'cat' nodes
            _ = test_module(x)
            assert ctx.graph.get_nodes_count() == 8  # NB: may always fail in debug due to superfluous 'cat' nodes
Example #7
    def __init__(self,
                 module,
                 input_infos: List[ModelInputInfo],
                 dummy_forward_fn=None,
                 wrap_inputs_fn=None,
                 scopes_without_shape_matching=None,
                 ignored_scopes=None,
                 target_scopes=None,
                 reset: bool = False,
                 wrap_outputs_fn=None,
                 original_model_accuracy=None):
        super().__init__()
        self._set_nncf_wrapped_model(module)
        self._forward_signature = inspect.signature(module.forward)
        self.input_infos = input_infos

        self._original_model_accuracy = original_model_accuracy

        self.ignored_scopes = ignored_scopes
        self.target_scopes = target_scopes
        self._user_dummy_forward_fn = dummy_forward_fn
        self._kd_loss_handler = None

        try:
            device = next(module.parameters()).device
        except StopIteration:
            # Param-less model, assume CPU
            device = 'cpu'

        if wrap_inputs_fn is not None:
            self._wrap_inputs_fn = wrap_inputs_fn
        else:
            self.__input_infos_based_input_wrapper = InputInfoWrapManager(
                self.input_infos,
                self._forward_signature,
                module_ref_for_device=self)
            self._wrap_inputs_fn = self.__input_infos_based_input_wrapper.wrap_inputs

        if wrap_outputs_fn is not None:
            self._wrap_outputs_fn = wrap_outputs_fn
        else:
            self._wrap_outputs_fn = wrap_nncf_model_outputs_with_objwalk

        self._nncf_module_scopes = []  # type: List[Scope]
        self.scopes_without_shape_matching = scopes_without_shape_matching
        self.debug_interface = CombinedDebugInterface() if is_debug() else None
        self._extra_module_types = []  # type: List[ExtraCompressionModuleType]
        # pylint:disable=line-too-long
        self._insertions_into_original_graph = {
        }  # type: Dict[PTTargetPoint, List[Tuple[Callable, TransformationPriority]]]

        _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=True, with_output_tracing=True)

        nncf_wrapped_model = self.get_nncf_wrapped_model()
        eval_only_op_scopes = self._collect_eval_only_op_scopes(
            nncf_wrapped_model, _orig_graph_build_forward_fn)

        # all modules called in eval mode should be replaced prior to graph building
        self._replace_modules_by_nncf_modules(device, eval_only_op_scopes,
                                              reset)

        _orig_context = TracingContext()

        _orig_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                           ShapeIgnoringTensorMetaComparator())
        _orig_context.add_node_comparators([MODEL_OUTPUT_OP_NAME],
                                           ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            _orig_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())

        self._original_dynamic_graph = GraphTracer(
            _orig_graph_build_forward_fn).trace_graph(nncf_wrapped_model,
                                                      _orig_context,
                                                      as_eval=True)
        self._original_graph = GraphConverter.convert(
            self._original_dynamic_graph, input_infos=self.input_infos)
        self._compressed_graph = None  # type: PTNNCFGraph

        self._compressed_context = TracingContext()

        self._dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=False, with_output_tracing=False)
        self._in_user_dummy_forward = False

        self._compressed_context.add_node_comparators(
            [MODEL_INPUT_OP_NAME], ShapeIgnoringTensorMetaComparator())
        self._compressed_context.add_node_comparators(
            [MODEL_OUTPUT_OP_NAME], ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            self._compressed_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())
        self._load_listener = None
Example #8
class NNCFNetwork(nn.Module, PostGraphBuildActing):
    MODEL_STATE_VERSION_ATTR = '_nncf_model_state_version'
    MODEL_STATE_VERSION = 1

    def __init__(self,
                 module,
                 input_infos: List[ModelInputInfo],
                 dummy_forward_fn=None,
                 wrap_inputs_fn=None,
                 scopes_without_shape_matching=None,
                 ignored_scopes=None,
                 target_scopes=None,
                 reset: bool = False,
                 wrap_outputs_fn=None,
                 original_model_accuracy=None):
        super().__init__()
        self._set_nncf_wrapped_model(module)
        self._forward_signature = inspect.signature(module.forward)
        self.input_infos = input_infos

        self._original_model_accuracy = original_model_accuracy

        self.ignored_scopes = ignored_scopes
        self.target_scopes = target_scopes
        self._user_dummy_forward_fn = dummy_forward_fn
        self._kd_loss_handler = None

        try:
            device = next(module.parameters()).device
        except StopIteration:
            # Param-less model, assume CPU
            device = 'cpu'

        if wrap_inputs_fn is not None:
            self._wrap_inputs_fn = wrap_inputs_fn
        else:
            self.__input_infos_based_input_wrapper = InputInfoWrapManager(
                self.input_infos,
                self._forward_signature,
                module_ref_for_device=self)
            self._wrap_inputs_fn = self.__input_infos_based_input_wrapper.wrap_inputs

        if wrap_outputs_fn is not None:
            self._wrap_outputs_fn = wrap_outputs_fn
        else:
            self._wrap_outputs_fn = wrap_nncf_model_outputs_with_objwalk

        self._nncf_module_scopes = []  # type: List[Scope]
        self.scopes_without_shape_matching = scopes_without_shape_matching
        self.debug_interface = CombinedDebugInterface() if is_debug() else None
        self._extra_module_types = []  # type: List[ExtraCompressionModuleType]
        # pylint:disable=line-too-long
        self._insertions_into_original_graph = {
        }  # type: Dict[PTTargetPoint, List[Tuple[Callable, TransformationPriority]]]

        _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=True, with_output_tracing=True)

        nncf_wrapped_model = self.get_nncf_wrapped_model()
        eval_only_op_scopes = self._collect_eval_only_op_scopes(
            nncf_wrapped_model, _orig_graph_build_forward_fn)

        # all modules called in eval mode should be replaced prior to graph building
        self._replace_modules_by_nncf_modules(device, eval_only_op_scopes,
                                              reset)

        _orig_context = TracingContext()

        _orig_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                           ShapeIgnoringTensorMetaComparator())
        _orig_context.add_node_comparators([MODEL_OUTPUT_OP_NAME],
                                           ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            _orig_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())

        self._original_dynamic_graph = GraphTracer(
            _orig_graph_build_forward_fn).trace_graph(nncf_wrapped_model,
                                                      _orig_context,
                                                      as_eval=True)
        self._original_graph = GraphConverter.convert(
            self._original_dynamic_graph, input_infos=self.input_infos)
        self._compressed_graph = None  # type: PTNNCFGraph

        self._compressed_context = TracingContext()

        self._dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=False, with_output_tracing=False)
        self._in_user_dummy_forward = False

        self._compressed_context.add_node_comparators(
            [MODEL_INPUT_OP_NAME], ShapeIgnoringTensorMetaComparator())
        self._compressed_context.add_node_comparators(
            [MODEL_OUTPUT_OP_NAME], ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            self._compressed_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())
        self._load_listener = None

    @debuggable_forward
    def forward(self, *args, **kwargs):
        with self._compressed_context as ctx:  # type: TracingContext
            ctx.base_module_thread_local_replica = self
            args, kwargs = replicate_same_tensors((args, kwargs))
            if not self._in_user_dummy_forward:
                # If a user supplies their own dummy forward function, they are
                # responsible for correctly wrapping the inputs inside it as well.
                args, kwargs = self._strip_traced_tensors(args, kwargs)
                args, kwargs = self._wrap_inputs_fn(args, kwargs)
            retval = self.get_nncf_wrapped_model()(*args, **kwargs)
            retval = replicate_same_tensors(retval)
            if not self._in_user_dummy_forward:
                retval = self._wrap_outputs_fn(retval)

        if self._kd_loss_handler is not None and self.get_nncf_wrapped_model(
        ).training:
            self._kd_loss_handler(retval, *args, **kwargs)
        return retval

    def _strip_traced_tensors(self, args: Tuple,
                              kwargs: Dict) -> Tuple[Tuple, Dict]:
        """
            Required to guard against new forward calls on tensors that have already passed
            through NNCF's forward once and got turned into TracedTensors by reference access.
        """
        is_traced_tensor_predicate = lambda x: isinstance(x, TracedTensor)

        def strip_fn(tensor: TracedTensor) -> torch.Tensor:
            if hasattr(torch.Tensor, 'as_subclass'):
                return torch.Tensor.as_subclass(tensor, torch.Tensor)
            # Torch < 1.7.0 fallback
            return torch.tensor(tensor,
                                device=tensor.device,
                                requires_grad=tensor.requires_grad)

        args = objwalk(args, is_traced_tensor_predicate, strip_fn)
        kwargs = objwalk(kwargs, is_traced_tensor_predicate, strip_fn)
        return args, kwargs

    def create_knowledge_distillation_loss_handler(self, kd_original_model: nn.Module, calculate_fn)\
            -> KnowledgeDistillationLossHandler:
        """
        Creates a KnowledgeDistillationLossHandler instance that enables the knowledge distillation feature.
        The created handler is also returned so that the caller can control the knowledge distillation logic.

        :param kd_original_model: original, non-compressed model used for distillation
        :param calculate_fn: function used to parse model outputs and calculate knowledge distillation loss
        :return: KnowledgeDistillationLossHandler instance
        """
        device = next(self.get_nncf_wrapped_model().parameters()).device
        self._kd_loss_handler = KnowledgeDistillationLossHandler(
            self._compressed_context, kd_original_model, calculate_fn, device)
        return self._kd_loss_handler

    # Cannot use property syntax here, otherwise the wrapped module will end up
    # being stored twice in the same checkpoint under different prefixes
    def get_nncf_wrapped_model(self):
        return getattr(self, MODEL_WRAPPED_BY_NNCF_ATTR_NAME)

    def _set_nncf_wrapped_model(self, value):
        setattr(self, MODEL_WRAPPED_BY_NNCF_ATTR_NAME, value)

    def get_clean_shallow_copy(self) -> 'NNCFNetwork':
        # WARNING: Will reset pre- and post-ops of the underlying model. Use save_nncf_module_additions
        # and load_nncf_module_additions to preserve these, or temporary_clean_view().
        from nncf.torch.utils import save_module_state, load_module_state
        saved_state = save_module_state(self)
        model_copy = NNCFNetwork(self.get_nncf_wrapped_model(),
                                 self.input_infos,
                                 self._user_dummy_forward_fn,
                                 self._wrap_inputs_fn,
                                 self.scopes_without_shape_matching,
                                 self.ignored_scopes,
                                 self.target_scopes,
                                 reset=True)
        load_module_state(model_copy, saved_state)
        return model_copy

    def get_modules_in_nncf_modules_by_type(self,
                                            types) -> Dict[Scope, nn.Module]:
        nncf_modules = self.get_nncf_modules()
        retval = {}
        for nncf_module_scope, nncf_module in nncf_modules.items():
            nncf_module_scope.pop()
            for relative_scope, target_module in get_all_modules_by_type(
                    nncf_module, types).items():
                retval[nncf_module_scope + relative_scope] = target_module
        return retval

    def insert_at_point(self, point: PTInsertionPoint,
                        fn_list: List[Callable]):
        if point.insertion_type == PTInsertionType.OPERATOR_PRE_HOOK:
            self._compressed_context.register_pre_hooks(
                fn_list, point.op_address, point.input_port_id)
        elif point.insertion_type == PTInsertionType.OPERATOR_POST_HOOK:
            self._compressed_context.register_post_hooks(
                fn_list, point.op_address)
        elif point.insertion_type in [
                PTInsertionType.NNCF_MODULE_PRE_OP,
                PTInsertionType.NNCF_MODULE_POST_OP
        ]:
            norm_target_scope = self._normalize_variable_recurrent_scope(
                point.module_scope)
            norm_nncf_scopes = [
                self._normalize_variable_recurrent_scope(x)
                for x in self._nncf_module_scopes
            ]
            assert norm_target_scope in norm_nncf_scopes  # Required for proper Recurrent/VariableRecurrent addressing
            nncf_module = self.get_module_by_scope(point.module_scope)
            if point.insertion_type == PTInsertionType.NNCF_MODULE_PRE_OP:
                for fn in fn_list:
                    nncf_module.register_pre_forward_operation(fn)
            elif point.insertion_type == PTInsertionType.NNCF_MODULE_POST_OP:
                for fn in fn_list:
                    nncf_module.register_post_forward_operation(fn)
        else:
            raise RuntimeError("Unsupported insertion type: {}".format(
                point.insertion_type))

    def __getattr__(self, name):
        class NotFound:
            pass

        def get_nncf_network_attr(self, name):
            if name in self.__dict__:
                return self.__dict__[name]
            return NotFound

        def get_nncf_module_attr(self, name):
            if hasattr(
                    self.__dict__['_modules'][MODEL_WRAPPED_BY_NNCF_ATTR_NAME],
                    name):
                attr = getattr(
                    self.__dict__['_modules'][MODEL_WRAPPED_BY_NNCF_ATTR_NAME],
                    name)
                if hasattr(attr, '__self__'):  # If it is a bound function
                    from functools import partial
                    attr = partial(attr.__func__, self)
                    return attr
                # If it is not a bound function
                return attr
            return NotFound

        def get_nn_module_attr(self, name):
            return super().__getattr__(name)

        attr = get_nncf_network_attr(self, name)
        if attr != NotFound:
            return attr
        attr = get_nncf_module_attr(self, name)
        if attr != NotFound:
            return attr
        return get_nn_module_attr(self, name)

    def get_graph(self) -> PTNNCFGraph:
        if self._compressed_context.graph.get_nodes_count(
        ) == 0 or self._compressed_graph is None:
            self.rebuild_graph()
        return self._compressed_graph

    def get_dynamic_graph(self) -> DynamicGraph:
        return self._compressed_context.graph

    def get_original_graph(self) -> PTNNCFGraph:
        return self._original_graph

    def get_tracing_context(self) -> TracingContext:
        return self._compressed_context

    def enable_dynamic_graph_building(self):
        self._compressed_context.enable_node_additions()

    def disable_dynamic_graph_building(self):
        self._compressed_context.disable_node_additions()

    def _get_dummy_forward_fn_for_graph_building(self, with_input_tracing,
                                                 with_output_tracing):
        if self._user_dummy_forward_fn is None:
            return create_dummy_forward_fn(
                self.input_infos,
                with_input_tracing=with_input_tracing,
                wrap_inputs_fn=self._wrap_inputs_fn,
                wrap_outputs_fn=self._wrap_outputs_fn,
                with_output_tracing=with_output_tracing)

        def wrapped_user_dummy_forward_fn(*args, **kwargs):
            self._in_user_dummy_forward = True
            retval = self._user_dummy_forward_fn(*args, **kwargs)
            self._in_user_dummy_forward = False
            return retval

        return wrapped_user_dummy_forward_fn

    def _replace_modules_by_nncf_modules(
            self,
            device,
            eval_only_op_scopes: List[Scope] = None,
            reset: bool = False):
        module, self._nncf_module_scopes = replace_modules_by_nncf_modules(
            self.get_nncf_wrapped_model(),
            ignored_scopes=self.ignored_scopes,
            target_scopes=self.target_scopes,
            eval_op_scopes=eval_only_op_scopes,
            reset=reset)
        self._set_nncf_wrapped_model(module.to(device))

    def get_nncf_module_scopes(self) -> List[Scope]:
        return self._nncf_module_scopes

    def get_nncf_modules(self) -> Dict[Scope, torch.nn.Module]:
        nncf_module_names_list = NNCF_MODULES + [
            x.__name__ for x in NNCF_WRAPPED_USER_MODULES_DICT.values()
        ]
        return get_all_modules_by_type(self.get_nncf_wrapped_model(),
                                       nncf_module_names_list)

    def get_weighted_original_graph_nodes(self,
                                          nncf_module_names: List[str] = None
                                          ) -> List[NNCFNode]:
        retval = []
        for nncf_module_scope in self._nncf_module_scopes:
            if nncf_module_names is not None:
                module_name = nncf_module_scope[-1].calling_module_class_name
                if module_name not in nncf_module_names:
                    continue
            nodes_in_scope = self._original_graph.get_op_nodes_in_scope(
                nncf_module_scope)
            for node in nodes_in_scope:
                if node.layer_attributes is not None:  # TODO(vshampor): implement more explicit filtering
                    retval.append(node)
        return retval

    def get_nncf_modules_by_module_names(
            self, nncf_module_names_list: List[str]
    ) -> Dict["Scope", torch.nn.Module]:
        return get_all_modules_by_type(self.get_nncf_wrapped_model(),
                                       nncf_module_names_list)

    def rebuild_graph(self, *input_args):
        self._compressed_context.reset_graph()
        dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=False, with_output_tracing=False)
        builder = GraphBuilder(dummy_forward_fn)
        self._compressed_graph = builder.build_graph(
            self, self._compressed_context, input_infos=self.input_infos)

    def post_build_graph_actions(self):
        # Reset initialization flags (`initialized`) for all quantization modules
        # after dummy `load_state_dict` call.
        quantization_types = [
            class_type.__name__
            for class_type in QUANTIZATION_MODULES.registry_dict.values()
        ]
        all_quantizations = get_state_dict_names_with_modules(
            self, quantization_types)
        for module in all_quantizations.values():
            module.initialized = False

    def is_scope_in_nncf_module_scope(self, scope: Scope):
        # TODO: optimize
        norm_nncf_scopes = [
            self._normalize_variable_recurrent_scope(x)
            for x in self._nncf_module_scopes
        ]
        norm_op_scope = self._normalize_variable_recurrent_scope(scope)
        for nncf_scope in norm_nncf_scopes:
            if norm_op_scope in nncf_scope:
                return True
        return False

    def register_compression_module_type(
            self, compression_module_type: ExtraCompressionModuleType):
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type in self._extra_module_types:
            raise RuntimeError("Module type {} is already registered".format(
                compression_module_type))
        self.__setattr__(attr_name, nn.ModuleDict())
        self._extra_module_types.append(compression_module_type)

    def add_compression_module(
            self, module_key: str, module: nn.Module,
            compression_module_type: ExtraCompressionModuleType):
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(
                compression_module_type))
        storage = self.__getattr__(attr_name)
        if module_key in storage:
            raise RuntimeError(
                "Module {} is already registered under {}".format(
                    module_key, attr_name))
        storage[module_key] = module

    def get_compression_modules_by_type(
            self, compression_module_type: ExtraCompressionModuleType
    ) -> nn.ModuleDict:
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(
                compression_module_type))
        return self.__getattr__(attr_name)

    @staticmethod
    def _compression_module_type_to_attr_name(
            compression_module_type: ExtraCompressionModuleType):
        """
        Required for backward compatibility with checkpoints that store function and activation
        quantizers directly under corresponding attributes of NNCFNetwork.
        """
        if compression_module_type == ExtraCompressionModuleType.EXTERNAL_QUANTIZER:
            return EXTERNAL_QUANTIZERS_STORAGE_NAME
        raise RuntimeError("Unknown extra module type")

    def sort_compression_modules(
            self, compression_module_type: ExtraCompressionModuleType):
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(
                compression_module_type))
        module_dict = self.__getattr__(attr_name)
        # pylint: disable=protected-access
        module_dict._modules = OrderedDict(sorted(
            module_dict._modules.items()))
        self.__setattr__(attr_name, module_dict)

    @staticmethod
    def _normalize_variable_recurrent_scope(scope: Scope):
        """
        Two scopes pointing to an NNCF module that only differ in a Recurrent/VariableRecurrent/VariableRecurrentReverse
        scope node actually point to one and the same module.
        """
        ret_scope = scope.copy()
        for scope_element in ret_scope:
            if scope_element.calling_module_class_name in [
                    "Recurrent", "VariableRecurrent",
                    "VariableRecurrentReverse"
            ]:
                scope_element.calling_module_class_name = "NormalizedName_Recurrent"
        return ret_scope

    def do_dummy_forward(self, force_eval=False):
        """
        Attention: If run with force_eval=False, this may spoil the batchnorm statistics,
        and an eval run of the model will perform much worse than the train run.
        """
        if force_eval:
            train_mode = self.training
            self.eval()
        with torch.no_grad():
            with self._compressed_context as ctx:
                ctx.base_module_thread_local_replica = self
                self._dummy_forward_fn(self)
        if force_eval:
            if train_mode:
                self.train()

    def get_insertion_point_graph(self) -> InsertionPointGraph:
        # Set up pre- and post-hooks on almost every op in PyTorch
        nncf_graph = self.get_original_graph()
        pre_hooks = []  # type: List[PreHookInsertionPoint]
        post_hooks = []  # type: List[PostHookInsertionPoint]
        for node in nncf_graph.get_all_nodes():
            # Pre-hook insertion point nodes
            # Will insert a pre-hook IP for each input edge. The input edge must be marked with
            # a port ID attribute.
            in_edges = nncf_graph.get_input_edges(node)
            for edge in in_edges:
                port_id = edge.input_port_id
                pre_hook_ip = PreHookInsertionPoint(
                    target_node_name=node.node_name, input_port_id=port_id)
                pre_hooks.append(pre_hook_ip)

            if issubclass(node.metatype, PTSplitMetatype):
                # chunk returns a tuple of tensors, which can only be handled in NNCF
                # once post-hook ports are enabled. Work around it for now by disallowing post-hook
                # insertion for chunks
                # TODO: enable post-hook ports and remove this
                continue

            # Post-hook insertion point nodes
            post_hook_ip = PostHookInsertionPoint(node.node_name)
            post_hooks.append(post_hook_ip)

        weighted_nodes = self.get_weighted_original_graph_nodes()
        weighted_node_names = [
            weighted_node.node_name for weighted_node in weighted_nodes
        ]

        ip_graph = InsertionPointGraph(
            self._original_graph,
            weight_modifiable_node_names=weighted_node_names,
            allowed_pre_hook_insertion_points=pre_hooks,
            allowed_post_hook_insertion_points=post_hooks)
        return ip_graph

    def get_module_by_scope(self, scope: Scope) -> Optional[torch.nn.Module]:
        curr_module = self.get_nncf_wrapped_model()
        for scope_element in scope[
                1:]:  # omit first scope element which corresponds to base module
            if scope_element.calling_field_name is None:
                # The module is created in-place on every call and never stored in the model;
                # this happens e.g. for nn.Softmax in BERT implementations.
                return None
            # pylint: disable=protected-access
            next_module = curr_module._modules.get(
                scope_element.calling_field_name)
            if next_module is None:
                raise RuntimeError(
                    "Could not find a {} module member in {} module of scope {} during node search"
                    .format(scope_element.calling_field_name,
                            scope_element.calling_module_class_name,
                            str(scope)))
            curr_module = next_module
        return curr_module

    def get_containing_module(self,
                              node_name: NNCFNodeName) -> torch.nn.Module:
        if self._compressed_graph is not None:
            try:
                scope = self._compressed_graph.get_scope_by_node_name(
                    node_name)
            except RuntimeError:
                nncf_logger.debug(
                    "Node {} not found in compressed graph when trying to determine containing module, "
                    "trying the original graph to see if the node was present there "
                    "during graph building".format(node_name))
                scope = self._original_graph.get_scope_by_node_name(node_name)
        else:
            scope = self._original_graph.get_scope_by_node_name(node_name)
        return self.get_module_by_scope(scope)

    def get_parameters_count_in_model(self):
        """
        Return total amount of model parameters.
        """
        count = 0
        for param in self.parameters():
            count = count + param.numel()
        return count

    def get_flops_per_module(self) -> Dict[NNCFNodeName, int]:
        """
        Calculates FLOPS count for modules.
        """
        model = self
        flops_count_dict = {}

        def get_hook(name):
            return functools.partial(compute_FLOPs_hook,
                                     dict_to_save=flops_count_dict,
                                     module_node_name=name)

        hook_list = []
        for nncf_node in self._original_graph.get_all_nodes():
            node_module = self.get_containing_module(nncf_node.node_name)
            hook_list.append(
                node_module.register_forward_hook(get_hook(
                    nncf_node.node_name)))
        model.do_dummy_forward(force_eval=True)

        for h in hook_list:
            h.remove()
        return flops_count_dict

    def get_MACs_in_model(self):
        """
            Calculates MAC units count for model.
        """
        flops_count_dict = self.get_flops_per_module()
        total_MACs_count = sum(v // 2 for v in flops_count_dict.values())
        return total_MACs_count

    def get_input_infos(self) -> List[ModelInputInfo]:
        return deepcopy(self.input_infos)

    def save_nncf_module_additions(
            self
    ) -> Dict[Scope, Tuple[torch.nn.ModuleDict, torch.nn.ModuleDict]]:
        retval = {}
        for module_scope, nncf_module in self.get_nncf_modules().items():
            retval[module_scope] = (deepcopy(nncf_module.pre_ops),
                                    deepcopy(nncf_module.post_ops))
        return retval

    def load_nncf_module_additions(
        self, scope_vs_pre_post_ops_dict: Dict[Scope,
                                               Tuple[torch.nn.ModuleDict,
                                                     torch.nn.ModuleDict]]):
        for module_scope, nncf_module in self.get_nncf_modules().items():
            nncf_module.pre_ops = scope_vs_pre_post_ops_dict[module_scope][0]
            nncf_module.post_ops = scope_vs_pre_post_ops_dict[module_scope][1]

    def temporary_clean_view(self):
        class Mgr:
            def __init__(self, model: NNCFNetwork):
                self.model = model
                self.storage_dict = {}

            def __enter__(self):
                self.storage_dict = self.model.save_nncf_module_additions()
                clean_model = self.model.get_clean_shallow_copy()
                return clean_model

            def __exit__(self, exc_type, exc_val, exc_tb):
                self.model.load_nncf_module_additions(self.storage_dict)

        return Mgr(self)

    def _collect_eval_only_op_scopes(
            self, model: nn.Module, dummy_forward_fn: Callable) -> List[Scope]:
        """
        Returns scopes of the modules which are executed in evaluation mode only.
        """

        tracer = GraphTracer(dummy_forward_fn)
        result = []
        eval_graph = tracer.trace_graph(model, as_eval=True)
        for dyn_graph_node in eval_graph.get_all_nodes():
            result.append(dyn_graph_node.op_exec_context.scope_in_model)
        return result

    @property
    def original_model_accuracy(self):
        return self._original_model_accuracy

    def get_node_to_op_address_mapping(
            self) -> Dict[NNCFNodeName, OperationAddress]:
        # The IDs of corresponding nodes of the original dynamic graph and original NNCF graph
        # must be equal for this to work.
        retval = {}
        for node in self._original_dynamic_graph.get_all_nodes():
            node_id = node.node_id
            op_address = node.op_exec_context.op_address
            nncf_node = self._original_graph.get_node_by_id(node_id)
            retval[nncf_node.node_name] = op_address
        return retval
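A minimal sketch of how the class above might be instantiated and queried; the ModelInputInfo constructor argument is an assumption based on the input_infos usage in this file, not a confirmed signature:

# Hypothetical usage sketch for NNCFNetwork (constructor arguments are assumed)
model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1), torch.nn.ReLU())
nncf_model = NNCFNetwork(model, input_infos=[ModelInputInfo([1, 1, 4, 4])])
graph = nncf_model.get_graph()  # builds the compressed graph on first access
print(len(graph.get_all_nodes()))
print(nncf_model.get_parameters_count_in_model())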