    def test_number_of_nodes_for_module_in_loop__not_input_node(self):
        num_iter = 5

        class LoopModule(nn.Module):
            class Inner(nn.Module):
                def forward(self, x):
                    s = F.sigmoid(x)
                    t = F.tanh(x)
                    result = F.sigmoid(x) * t + F.tanh(x) * s
                    return result

                @staticmethod
                def nodes_number():
                    return 7

            def __init__(self):
                super().__init__()
                self.inner = self.Inner()

            def forward(self, x):
                for _ in range(num_iter):
                    x = self.inner(F.relu(x))
                return x

            def nodes_number(self):
                return self.inner.nodes_number() + num_iter

        test_module = LoopModule()
        context = TracingContext()
        with context as ctx:
            _ = test_module(torch.zeros(1))
            assert ctx.graph.get_nodes_count() == test_module.nodes_number()
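Note: all of these snippets share one tracing pattern: entering a TracingContext records every
patched torch operator call into a dynamic graph that can then be queried for its node count.
A minimal sketch of that pattern, assuming nncf has patched the torch operators on import, as
the tests imply:

import torch
import torch.nn as nn
from nncf.dynamic_graph.context import TracingContext

model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU())  # any plain module is traceable
context = TracingContext()
with context as ctx:
    _ = model(torch.zeros(1, 1, 4, 4))
print(ctx.graph.get_nodes_count())  # expected: 2 (conv2d + relu)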
Example #2
    def __init__(self,
                 module,
                 input_infos: List[ModelInputInfo] = None,
                 dummy_forward_fn=None,
                 scopes_without_shape_matching=None,
                 ignored_scopes=None,
                 target_scopes=None):
        super().__init__()
        self.set_nncf_wrapped_model(module)
        self.input_infos = input_infos
        self.ignored_scopes = ignored_scopes
        self.target_scopes = target_scopes
        self._dummy_forward_fn = dummy_forward_fn
        self._nncf_module_scopes = []  # type: List[Scope]
        self.scopes_without_shape_matching = scopes_without_shape_matching
        self.debug_interface = CombinedDebugInterface() if is_debug() else None
        self._extra_module_types = []  # type: List[CompressionModuleType]
        # pylint:disable=line-too-long
        self._insertions_into_original_graph = {}  # type: Dict[InsertionPoint, List[Tuple[Callable, OperationPriority]]]

        device = next(module.parameters()).device

        # all modules should be replaced prior to graph building
        self._replace_modules_by_nncf_modules(device)

        _orig_context = TracingContext()
        _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=True)

        self._graph_builder = GraphBuilder(_orig_graph_build_forward_fn)

        _orig_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                           ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            _orig_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())

        self._original_graph = self._graph_builder.build_graph(
            self.get_nncf_wrapped_model(), _orig_context)

        self._compressed_context = TracingContext()

        self._dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=False)

        self._compressed_context.add_node_comparators(
            [MODEL_INPUT_OP_NAME], ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            self._compressed_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())
        self._load_listener = None

        self._builders = []  # type: List['CompressionAlgorithmBuilder']
Example #3
    def test_number_of_nodes_for_module_with_nested_loops(self):
        num_iter = 5

        class TestIterModule(nn.Module):
            @ITERATION_MODULES.register()
            class TestIterModule_ResetPoint(nn.Module):
                def __init__(self, loop_module):
                    super().__init__()
                    self.loop_module = loop_module

                def forward(self, x):
                    return self.loop_module(F.relu(x))

            def __init__(self):
                super().__init__()
                self.loop_module = self.LoopModule2()
                self.reset_point = self.TestIterModule_ResetPoint(
                    self.loop_module)

            def forward(self, x):
                for _ in range(num_iter):
                    x = self.reset_point(x)
                return x

            class LoopModule2(nn.Module):
                @ITERATION_MODULES.register()
                class LoopModule2_ResetPoint(nn.Module):
                    def __init__(self, inner):
                        super().__init__()
                        self.inner = inner

                    def forward(self, x):
                        return self.inner(F.relu(x))

                def __init__(self):
                    super().__init__()
                    self.inner = self.Inner()
                    self.reset_helper = self.LoopModule2_ResetPoint(self.inner)

                def forward(self, x):
                    for _ in range(num_iter):
                        self.reset_helper(x)
                    return x

                class Inner(nn.Module):
                    def forward(self, x):
                        s = F.sigmoid(x)
                        t = F.tanh(x)
                        result = t + s
                        return result

        test_module = TestIterModule()
        context = TracingContext()
        with context as ctx:
            _ = test_module(torch.zeros(1))
            assert ctx.graph.get_nodes_count() == num_iter
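Note: the nested test above hinges on ITERATION_MODULES.register(): wrapper modules registered
this way appear to act as reset points for loop tracing, so ops re-executed across iterations
are matched against existing graph nodes instead of growing the graph. A sketch of the
registration pattern, with a made-up class name and assuming ITERATION_MODULES is imported as
in the tests:

@ITERATION_MODULES.register()  # a custom name may also be passed, e.g. .register('Inner')
class MyLoopBody(nn.Module):
    def forward(self, x):
        # Nodes traced here are de-duplicated across loop iterations.
        return torch.tanh(x)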
Example #4
    def build_graph(self, model: torch.nn.Module, context_to_use: Optional['TracingContext'] = None) -> 'NNCFGraph':
        sd = deepcopy(model.state_dict())

        from nncf.dynamic_graph.context import TracingContext
        if context_to_use is None:
            context_to_use = TracingContext()

        with context_to_use as _ctx:
            self.custom_forward_fn(model)
        model.load_state_dict(sd)

        if isinstance(model, PostGraphBuildActing):
            model.post_build_graph_actions()
        return context_to_use.graph
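Note: build_graph() deep-copies the state dict up front and restores it after the dummy
forward, so any parameter mutation caused by tracing does not leak into the model; a fresh
TracingContext is created when the caller does not supply one. A hedged usage sketch (the
model and input shape are placeholders):

def _dummy_forward(model):
    model(torch.zeros(1, 3, 32, 32))  # drive the model once so every operator is traced

builder = GraphBuilder(_dummy_forward)
graph = builder.build_graph(my_model)  # my_model stands in for any nn.Module
print(graph.get_nodes_count())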
Example #5
def test_iterate_module_list():
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.ml = nn.ModuleList([nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)])

        def forward(self, x):
            return [self.ml[0](x), self.ml[1](x)]

    net = Net()

    context = TracingContext()
    with context:
        _ = net(torch.zeros(1, 1, 1, 1))

    check_graph(context.graph, 'case_iterate_module_list.dot', 'original')
Example #6
    def test_number_of_nodes_for_module_in_loop(self):
        num_iter = 5

        class LoopModule(nn.Module):
            @ITERATION_MODULES.register('Inner')
            class Inner(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.operator1 = torch.sigmoid
                    self.operator2 = torch.tanh

                def forward(self, x):
                    s = self.operator1(x)
                    t = self.operator2(x)
                    result = t + s
                    return result

                @staticmethod
                def nodes_number():
                    return 3

            def __init__(self):
                super().__init__()
                self.inner = self.Inner()

            def forward(self, x):
                for _ in range(num_iter):
                    x = self.inner(x)
                return x

            def nodes_number(self):
                return self.inner.nodes_number()

        test_module = LoopModule()
        context = TracingContext()
        with context as ctx:
            _ = test_module(torch.zeros(1))
            assert ctx.graph.get_nodes_count() == test_module.nodes_number()
Example #7
    def test_number_of_nodes_for_repeated_module(self):
        class LoopModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.operator = F.relu
                self.layers = nn.ModuleList(
                    [nn.Conv2d(1, 1, 1),
                     nn.Conv2d(1, 1, 1)])

            def forward(self, x):
                for layer in self.layers:
                    x = F.relu(layer(x))
                return x

        test_module = LoopModule()
        context = TracingContext()
        with context as ctx:
            x = test_module(torch.zeros(1, 1, 1, 1))
            assert ctx.graph.get_nodes_count() == 4  # NB: may always fail in debug due to superfluous 'cat' nodes
            _ = test_module(x)
            assert ctx.graph.get_nodes_count() == 8  # NB: may always fail in debug due to superfluous 'cat' nodes
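Note: the expected counts are consistent with a simple accounting, assuming one node per
traced conv/relu call and no merging between the two passes because the input tensor
metadata differs:

# pass 1: conv0 -> relu -> conv1 -> relu           = 4 nodes
# pass 2: the same four ops with new input metas   = 4 more nodes, 8 in total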
Example #8
    def build_graph(self,
                    model: torch.nn.Module,
                    context_to_use: Optional['TracingContext'] = None,
                    as_eval: bool = False) -> 'NNCFGraph':
        sd = deepcopy(model.state_dict())

        from nncf.dynamic_graph.context import TracingContext
        if context_to_use is None:
            context_to_use = TracingContext()

        from nncf.utils import training_mode_switcher
        context_to_use.base_module_thread_local_replica = model
        with context_to_use as _ctx:
            with torch.no_grad():
                if as_eval:
                    with training_mode_switcher(model, is_training=False):
                        self.custom_forward_fn(model)
                else:
                    self.custom_forward_fn(model)
        model.load_state_dict(sd)

        if isinstance(model, PostGraphBuildActing):
            model.post_build_graph_actions()
        return context_to_use.graph
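Note: this variant differs from the build_graph() in Example #4 by wrapping the dummy forward
in torch.no_grad() and, when as_eval=True, in training_mode_switcher, which presumably flips
the model into eval mode for the trace and restores the original mode afterwards. Usage is
the same apart from the extra flag:

builder = GraphBuilder(_dummy_forward)
graph = builder.build_graph(my_model, as_eval=True)  # trace with eval-mode behavior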
Example #9
class NNCFNetwork(nn.Module, PostGraphBuildActing):
    def __init__(self,
                 module,
                 input_infos: List[ModelInputInfo] = None,
                 dummy_forward_fn=None,
                 wrap_inputs_fn=None,
                 scopes_without_shape_matching=None,
                 ignored_scopes=None,
                 target_scopes=None):
        super().__init__()
        self._set_nncf_wrapped_model(module)
        self._forward_signature = inspect.signature(module.forward)
        self.input_infos = input_infos

        self.ignored_scopes = ignored_scopes
        self.target_scopes = target_scopes
        self._dummy_forward_fn = dummy_forward_fn

        device = next(module.parameters()).device

        if wrap_inputs_fn is not None:
            self._wrap_inputs_fn = wrap_inputs_fn
        else:
            self.__input_infos_based_input_wrapper = InputInfoWrapManager(
                self.input_infos,
                self._forward_signature,
                module_ref_for_device=self)
            self._wrap_inputs_fn = self.__input_infos_based_input_wrapper.wrap_inputs

        self._nncf_module_scopes = []  # type: List[Scope]
        self.scopes_without_shape_matching = scopes_without_shape_matching
        self.debug_interface = CombinedDebugInterface() if is_debug() else None
        self._extra_module_types = []  # type: List[CompressionModuleType]
        # pylint:disable=line-too-long
        self._insertions_into_original_graph = {}  # type: Dict[InsertionPoint, List[Tuple[Callable, OperationPriority]]]

        # all modules should be replaced prior to graph building
        self._replace_modules_by_nncf_modules(device)

        _orig_context = TracingContext()
        _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=True)

        self._graph_builder = GraphBuilder(_orig_graph_build_forward_fn)

        _orig_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                           ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            _orig_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())

        self._original_graph = self._graph_builder.build_graph(
            self.get_nncf_wrapped_model(), _orig_context)

        self._compressed_context = TracingContext()

        self._dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=False)

        self._compressed_context.add_node_comparators(
            [MODEL_INPUT_OP_NAME], ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            self._compressed_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())
        self._load_listener = None

        self._builders = []  # type: List['CompressionAlgorithmBuilder']

    @debuggable_forward
    def forward(self, *args, **kwargs):
        with self._compressed_context as ctx:  # type: TracingContext
            ctx.base_module_thread_local_replica = self
            args, kwargs = self._wrap_inputs_fn(args, kwargs)
            retval = self.get_nncf_wrapped_model()(*args, **kwargs)
        return retval
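    # Usage note (an assumption reconstructed from this snippet, not part of the original
    # source): an NNCFNetwork is called exactly like the module it wraps, and every forward()
    # re-enters the compressed TracingContext so the hooks installed at insertion points fire:
    #     compressed_model = NNCFNetwork(my_model,
    #                                    input_infos=[ModelInputInfo([1, 3, 32, 32])])
    #     output = compressed_model(torch.randn(1, 3, 32, 32))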

    def register_algorithm(self, builder: 'CompressionAlgorithmBuilder'):
        """Should be called during *builder*'s *apply_to* method, otherwise there will be no corresponding
        controller returned by the network on the *commit_compression_changes* stage"""
        self._builders.append(builder)

    # Cannot use property syntax here, otherwise the wrapped module would end up
    # present twice in the same checkpoint, under two different prefixes
    def get_nncf_wrapped_model(self):
        return getattr(self, MODEL_WRAPPED_BY_NNCF_ATTR_NAME)

    def _set_nncf_wrapped_model(self, value):
        setattr(self, MODEL_WRAPPED_BY_NNCF_ATTR_NAME, value)

    def get_modules_in_nncf_modules_by_type(self,
                                            types) -> Dict['Scope', nn.Module]:
        nncf_modules = self.get_nncf_modules()
        retval = {}
        for nncf_module_scope, nncf_module in nncf_modules.items():
            nncf_module_scope.pop()
            for relative_scope, target_module in get_all_modules_by_type(
                    nncf_module, types).items():
                retval[nncf_module_scope + relative_scope] = target_module
        return retval

    def register_insertion_command(self, command: InsertionCommand):
        point = command.insertion_point
        if point not in self._insertions_into_original_graph:
            self._insertions_into_original_graph[point] = [(command.fn,
                                                            command.priority)]
        else:
            self._insertions_into_original_graph[point].append(
                (command.fn, command.priority))

    def commit_compression_changes(self) -> 'CompressionAlgorithmController':
        for insertion_point, fn_list_with_priority in self._insertions_into_original_graph.items():
            fn_list_with_priority = sorted(fn_list_with_priority, key=lambda x: x[1])
            self._insertions_into_original_graph[insertion_point] = fn_list_with_priority
            self._insert_at_point(insertion_point,
                                  [x[0] for x in fn_list_with_priority])

        if self.debug_interface is not None:
            self.debug_interface.init_actual(self)

        quantization_types = [
            class_type.__name__
            for class_type in QUANTIZATION_MODULES.registry_dict.values()
        ]
        all_quantizations = get_state_dict_names_with_modules(
            self, quantization_types)
        self._load_listener = LoadStateListener(self, all_quantizations)

        if not self._builders:
            from nncf.algo_selector import NoCompressionAlgorithmController
            return NoCompressionAlgorithmController(self)

        if len(self._builders) == 1:
            return self._builders[0].build_controller(self)

        from nncf.composite_compression import CompositeCompressionAlgorithmController
        composite_controller = CompositeCompressionAlgorithmController(self)
        for algo_builder in self._builders:
            composite_controller.add(algo_builder.build_controller(self))
        return composite_controller
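    # Intended life cycle, reconstructed from register_algorithm() and
    # commit_compression_changes() above (the builder class name is hypothetical):
    #     builder = MyCompressionAlgorithmBuilder(config)
    #     builder.apply_to(compressed_model)  # expected to call register_algorithm() and
    #                                         # register_insertion_command() internally
    #     controller = compressed_model.commit_compression_changes()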

    def _insert_at_point(self, point: InsertionPoint, fn_list: List[Callable]):
        if point.insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            self._compressed_context.register_pre_hooks(
                fn_list, point.ia_op_exec_context)
        elif point.insertion_type == InsertionType.OPERATOR_POST_HOOK:
            self._compressed_context.register_post_hooks(
                fn_list, point.ia_op_exec_context)
        else:
            norm_target_scope = self._normalize_variable_recurrent_scope(
                point.ia_op_exec_context.scope_in_model)
            norm_nncf_scopes = [
                self._normalize_variable_recurrent_scope(x)
                for x in self._nncf_module_scopes
            ]
            assert norm_target_scope in norm_nncf_scopes  # Required for proper Recurrent/VariableRecurrent addressing
            nncf_module = self.get_module_by_scope(
                point.ia_op_exec_context.scope_in_model)
            if point.insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
                for fn in fn_list:
                    nncf_module.register_pre_forward_operation(fn)
            elif point.insertion_type == InsertionType.NNCF_MODULE_POST_OP:
                for fn in fn_list:
                    nncf_module.register_post_forward_operation(fn)
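    # For reference, the four insertion types handled above map onto two hook mechanisms:
    #   OPERATOR_PRE_HOOK / OPERATOR_POST_HOOK   -> hooks on the compressed TracingContext,
    #                                               keyed by the op's ia_op_exec_context
    #   NNCF_MODULE_PRE_OP / NNCF_MODULE_POST_OP -> hooks registered directly on the wrapped
    #                                               NNCF module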

    def __getattr__(self, name):
        wrapped_module = super().__getattr__(MODEL_WRAPPED_BY_NNCF_ATTR_NAME)
        if hasattr(wrapped_module, name):
            return getattr(wrapped_module, name)
        return super().__getattr__(name)

    def get_graph(self) -> NNCFGraph:
        return self._compressed_context.graph

    def get_original_graph(self) -> NNCFGraph:
        return self._original_graph

    def get_tracing_context(self) -> TracingContext:
        return self._compressed_context

    def _get_dummy_forward_fn_for_graph_building(self, with_input_tracing):
        if self._dummy_forward_fn is None:
            return create_dummy_forward_fn(
                self.input_infos,
                with_input_tracing=with_input_tracing,
                wrap_inputs_fn=self._wrap_inputs_fn)
        return self._dummy_forward_fn

    def _replace_modules_by_nncf_modules(self, device):
        module, self._nncf_module_scopes = replace_modules_by_nncf_modules(
            self.get_nncf_wrapped_model(),
            ignored_scopes=self.ignored_scopes,
            target_scopes=self.target_scopes)
        self._set_nncf_wrapped_model(module.to(device))

    def get_nncf_module_scopes(self) -> List['Scope']:
        return self._nncf_module_scopes

    def get_nncf_modules(self) -> Dict['Scope', torch.nn.Module]:
        nncf_module_names_list = NNCF_MODULES + [
            x.__name__ for x in NNCF_WRAPPED_USER_MODULES_DICT.values()
        ]
        return get_all_modules_by_type(self.get_nncf_wrapped_model(),
                                       nncf_module_names_list)

    def rebuild_graph(self, *input_args):
        self._compressed_context.reset_graph()
        dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=False)
        builder = GraphBuilder(dummy_forward_fn)
        _ = builder.build_graph(self, self._compressed_context)

    def post_build_graph_actions(self):
        # Reset initialization flags (`initialized`) for all quantization modules
        # after dummy `load_state_dict` call.
        quantization_types = [
            class_type.__name__
            for class_type in QUANTIZATION_MODULES.registry_dict.values()
        ]
        all_quantizations = get_state_dict_names_with_modules(
            self, quantization_types)
        for module in all_quantizations.values():
            module.initialized = False

    def get_post_pattern_insertion_points(
            self,
            pattern: 'NNCFNodeExpression',
            omit_nodes_in_nncf_modules=False) -> List[InsertionInfo]:
        io_infos = self._original_graph.get_matching_nncf_graph_pattern_io_list(
            pattern)

        insertion_infos = []
        for io_info in io_infos:
            # The input/output is given in terms of edges, but the post-hooks are currently applied to
            # nodes. Multiple output edges in a pattern I/O info may originate from one and the same
            # node, and we have to ensure that these resolve into just one insertion point - thus the usage of "set".
            pattern_insertion_info_set = set()
            if len(io_info.output_edges) > 1:
                nncf_logger.debug(
                    "WARNING: pattern has more than one activation output")

            for nncf_node in io_info.output_nodes:
                pattern_insertion_info_set.add(
                    InsertionInfo(nncf_node.op_exec_context,
                                  is_output=True,
                                  shape_to_operate_on=None))
                # TODO: determine output shapes for output nodes to enable per-channel quantization

            # Ignore input nodes in the pattern for now, rely on the _quantize_inputs functions.
            # TODO: handle input quantization here as well

            # Since this function is currently only used for activation quantization purposes via operator
            # post-hook mechanism, we may take any edge and it will point from the same node where we will have to
            # insert a quantizer later. However, in the future the output edges may refer to activation tensors
            # with different sizes, in which case we have to insert different per-channel quantizers to
            # accommodate different trainable params if there is a difference in the channel dimension.
            # Furthermore, currently there is no distinction for single tensor output to multiple nodes and
            # multiple tensor output to multiple nodes ("chunk" operation is an example of the latter).
            # The pattern may also have unexpected outputs from a node in the middle of the pattern (see
            # "densenet121.dot" for an example of this) - need to decide what to do with that in terms
            # of quantization.
            # TODO: address the issues above.

            for nncf_edge in io_info.output_edges:
                pattern_insertion_info_set.add(
                    InsertionInfo(nncf_edge.from_node.op_exec_context,
                                  is_output=False,
                                  shape_to_operate_on=nncf_edge.tensor_shape))
            insertion_infos += list(pattern_insertion_info_set)

        insertion_infos = list(
            set(insertion_infos)
        )  # Filter the overlapping insertion points from different matches (happens for GNMT)
        insertion_infos_filtered = []

        for info in insertion_infos:
            if omit_nodes_in_nncf_modules and self.is_scope_in_nncf_module_scope(
                    info.op_exec_context.scope_in_model):
                continue
            insertion_infos_filtered.append(info)

        return insertion_infos_filtered

    def is_scope_in_nncf_module_scope(self, scope: 'Scope'):
        # TODO: optimize
        norm_nncf_scopes = [
            self._normalize_variable_recurrent_scope(x)
            for x in self._nncf_module_scopes
        ]
        norm_op_scope = self._normalize_variable_recurrent_scope(scope)
        for nncf_scope in norm_nncf_scopes:
            if norm_op_scope in nncf_scope:
                return True
        return False

    def register_compression_module_type(
            self, compression_module_type: CompressionModuleType):
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type in self._extra_module_types:
            raise RuntimeError("Module type {} is already registered".format(
                compression_module_type))
        self.__setattr__(attr_name, nn.ModuleDict())
        self._extra_module_types.append(compression_module_type)

    def add_compression_module(self, module_key: str, module: nn.Module,
                               compression_module_type: CompressionModuleType):
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(
                compression_module_type))
        self.__getattr__(attr_name)[module_key] = module

    def get_compression_modules_by_type(
            self,
            compression_module_type: CompressionModuleType) -> nn.ModuleDict:
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(
                compression_module_type))
        return self.__getattr__(attr_name)

    @staticmethod
    def _compression_module_type_to_attr_name(
            compression_module_type: CompressionModuleType):
        """Required for backward compatibility with checkpoints that store function and activation
        quantizers directly under corresponding attributes of NNCFNetwork."""
        if compression_module_type == CompressionModuleType.FUNCTION_QUANTIZER:
            return "function_quantizers"
        if compression_module_type == CompressionModuleType.ACTIVATION_QUANTIZER:
            return "activation_quantizers"
        raise RuntimeError("Unknown extra module type")

    def sort_compression_modules(
            self, compression_module_type: CompressionModuleType):
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(
                compression_module_type))
        module_dict = self.__getattr__(attr_name)
        # pylint: disable=protected-access
        module_dict._modules = OrderedDict(sorted(
            module_dict._modules.items()))
        self.__setattr__(attr_name, module_dict)

    @staticmethod
    def _normalize_variable_recurrent_scope(scope: 'Scope'):
        """
        Two scopes pointing to an NNCF module that only differ in a Recurrent/VariableRecurrent/VariableRecurrentReverse
        scope element actually point to one and the same module.
        """
        ret_scope = scope.copy()
        for scope_element in ret_scope:
            if scope_element.calling_module_class_name in [
                    "Recurrent", "VariableRecurrent",
                    "VariableRecurrentReverse"
            ]:
                scope_element.calling_module_class_name = "NormalizedName_Recurrent"
        return ret_scope

    def do_dummy_forward(self, force_eval=False):
        """Attention: If run with force_eval=False, this may spoil the batchnorm statistics,
        and an eval run of the model will perform much worse than the train run. """
        if force_eval:
            train_mode = self.training
            self.eval()
        with torch.no_grad():
            self._dummy_forward_fn(self)
        if force_eval:
            if train_mode:
                self.train()

    def get_insertion_point_graph(self) -> InsertionPointGraph:
        ip_graph = InsertionPointGraph(
            self._original_graph.get_nx_graph_copy())

        # Mark IP graph operator nodes with associated op metatypes
        # Determining operator metatypes is better suited to the wrap_operator
        # stage, because it might be influenced by specific non-tensor function parameters,
        # but we have to inspect the containing module parameters as well, so the
        # TracingContext in wrap_operator would have to retain a reference to
        # the model that uses it. Since currently we do not need to inspect the
        # function arguments to determine the metatype, we can do this here, but
        # once we need to inspect the arguments, the code will have to be moved to
        # wrap_operator.

        for node_key in ip_graph.nodes:
            ip_graph_node = ip_graph.nodes[node_key]
            ip_graph_node_type = ip_graph_node[InsertionPointGraph.NODE_TYPE_NODE_ATTR]
            if ip_graph_node_type == InsertionPointGraphNodeType.OPERATOR:
                nncf_graph_node_ref = ip_graph_node[InsertionPointGraph.REGULAR_NODE_REF_NODE_ATTR]
                op_exec_context = nncf_graph_node_ref[NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR]
                op_name = op_exec_context.operator_name
                scope = op_exec_context.scope_in_model
                op_arch = OPERATOR_METATYPES.get_operator_metatype_by_op_name(
                    op_name)
                module = self.get_module_by_scope(scope)
                if module is not None:
                    subtype = op_arch.determine_subtype(
                        containing_module=module)
                    if subtype is not None:
                        op_arch = subtype
                ip_graph_node[InsertionPointGraph.OPERATOR_METATYPE_NODE_ATTR] = op_arch
        return ip_graph

    def get_module_by_scope(self, scope: 'Scope') -> torch.nn.Module:
        curr_module = self.get_nncf_wrapped_model()
        for scope_element in scope[1:]:  # omit the first scope element, which corresponds to the base module
            if scope_element.calling_field_name is None:
                # The module used is being created in-place every time and never stored in the model,
                # happens for nn.Softmax in BERT implementations.
                return None
            # pylint: disable=protected-access
            next_module = curr_module._modules.get(
                scope_element.calling_field_name)
            if next_module is None:
                raise RuntimeError(
                    "Could not find a {} module member in {} module of scope {} during node search"
                    .format(scope_element.calling_field_name,
                            scope_element.calling_module_class_name,
                            str(scope)))
            curr_module = next_module
        return curr_module

    def get_parameters_count_in_model(self):
        """
        Return the total number of model parameters.
        """
        count = 0
        for param in self.parameters():
            count = count + param.numel()
        return count

    def get_flops_per_module(self):
        """
        Calculates the FLOP count for each module.
        """
        model = self
        flops_count_dict = {}

        def get_hook(name):
            def compute_MACs_hook(module, input_, output):
                if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                    ks = module.weight.data.shape
                    mac_count = ks[0] * ks[1] * ks[2] * ks[3] * output.shape[3] * output.shape[2]
                elif isinstance(module, nn.Linear):
                    mac_count = input_[0].shape[1] * output.shape[-1]
                elif isinstance(module, nn.BatchNorm2d):
                    mac_count = np.prod(list(input_[0].shape))
                else:
                    return
                flops_count_dict[name] = 2 * mac_count

            return compute_MACs_hook

        hook_list = [
            m.register_forward_hook(get_hook(n))
            for n, m in model.named_modules()
        ]

        model.do_dummy_forward(force_eval=True)

        for h in hook_list:
            h.remove()
        return flops_count_dict
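    # For the conv branch above, ks is the weight shape [out_ch, in_ch, kh, kw], so the hook
    # counts out_ch * in_ch * kh * kw * H_out * W_out MACs (groups and bias are ignored).
    # Worked check with made-up dimensions: Conv2d(16, 32, kernel_size=3), 28x28 output map:
    #   MACs  = 32 * 16 * 3 * 3 * 28 * 28 = 3,612,672
    #   FLOPs = 2 * MACs                  = 7,225,344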

    def get_MACs_in_model(self):
        """
        Calculates the total MAC count for the model.
        """
        flops_count_dict = self.get_flops_per_module()
        total_MACs_count = sum(v // 2 for v in flops_count_dict.values())
        return total_MACs_count

    def get_input_infos(self) -> List[ModelInputInfo]:
        return deepcopy(self.input_infos)
Example #10
    def __init__(self,
                 module,
                 input_infos: List[ModelInputInfo],
                 dummy_forward_fn=None,
                 wrap_inputs_fn=None,
                 scopes_without_shape_matching=None,
                 ignored_scopes=None,
                 target_scopes=None,
                 reset: bool = False):
        super().__init__()
        self._set_nncf_wrapped_model(module)
        self._forward_signature = inspect.signature(module.forward)
        self.input_infos = input_infos

        self.ignored_scopes = ignored_scopes
        self.target_scopes = target_scopes
        self._user_dummy_forward_fn = dummy_forward_fn

        device = next(module.parameters()).device

        if wrap_inputs_fn is not None:
            self._wrap_inputs_fn = wrap_inputs_fn
        else:
            self.__input_infos_based_input_wrapper = InputInfoWrapManager(
                self.input_infos,
                self._forward_signature,
                module_ref_for_device=self)
            self._wrap_inputs_fn = self.__input_infos_based_input_wrapper.wrap_inputs

        self._nncf_module_scopes = []  # type: List[Scope]
        self.scopes_without_shape_matching = scopes_without_shape_matching
        self.debug_interface = CombinedDebugInterface() if is_debug() else None
        self._extra_module_types = []  # type: List[ExtraCompressionModuleType]
        # pylint:disable=line-too-long
        self._insertions_into_original_graph = {}  # type: Dict[InsertionPoint, List[Tuple[Callable, OperationPriority]]]

        _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=True)
        self._graph_builder = GraphBuilder(_orig_graph_build_forward_fn)

        nncf_wrapped_model = self.get_nncf_wrapped_model()
        eval_only_ops_exec_ctx = self.collect_eval_only_ops_exec_context(
            nncf_wrapped_model, self._graph_builder)

        # all modules called in eval mode should be replaced prior to graph building
        self._replace_modules_by_nncf_modules(device, eval_only_ops_exec_ctx,
                                              reset)

        _orig_context = TracingContext()

        _orig_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                           ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            _orig_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())

        self._original_graph = self._graph_builder.build_graph(
            nncf_wrapped_model, _orig_context, as_eval=True)

        self._compressed_context = TracingContext()

        self._dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=False)

        self._compressed_context.add_node_comparators(
            [MODEL_INPUT_OP_NAME], ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            self._compressed_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())
        self._load_listener = None

        self._builders = []  # type: List['CompressionAlgorithmBuilder']
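Note: compared with the constructor in Example #9, this later variant collects the execution
context of eval-only ops before replacing modules, builds the original graph with as_eval=True,
and accepts a reset flag (its exact semantics are assumed here from the argument name). A
hedged construction sketch with a made-up model and input shape:

nncf_net = NNCFNetwork(my_model,
                       input_infos=[ModelInputInfo([1, 3, 224, 224])],
                       reset=True)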