コード例 #1
0
ファイル: nncf_network.py プロジェクト: zbrnwpu/nncf
    def commit_compression_changes(self) -> 'CompressionAlgorithmController':
        """Apply pending graph insertions and build the compression controller.

        For every insertion point, the registered entries are sorted by their
        second element (the priority), stored back in that order, and their
        first elements (the callables) are inserted via ``_insert_at_point``.
        Afterwards the debug interface (if any) is initialized and a
        ``LoadStateListener`` is attached for all quantization modules.

        Returns:
            ``NoCompressionAlgorithmController`` when no builders are
            registered; the single builder's controller when exactly one is
            registered; otherwise a ``CompositeCompressionAlgorithmController``
            that aggregates the controller of every builder, in builder order.
        """
        for point, entries in self._insertions_into_original_graph.items():
            # ``sorted`` is stable, so entries sharing a priority keep their
            # original registration order.
            ordered_entries = sorted(entries, key=lambda entry: entry[1])
            # Only an existing key is reassigned, so iterating ``.items()``
            # while writing back is safe (the dict's size does not change).
            self._insertions_into_original_graph[point] = ordered_entries
            hooks = [entry[0] for entry in ordered_entries]
            self._insert_at_point(point, hooks)

        debug = self.debug_interface
        if debug is not None:
            debug.init_actual(self)

        # Class names of every module type registered as a quantization module.
        registered_type_names = [
            module_cls.__name__
            for module_cls in QUANTIZATION_MODULES.registry_dict.values()
        ]
        # NOTE(review): presumably maps state-dict names to the matching
        # modules (inferred from the helper's name) — confirm.
        quantization_modules = get_state_dict_names_with_modules(
            self, registered_type_names)
        self._load_listener = LoadStateListener(self, quantization_modules)

        builders = self._builders
        if not builders:
            from nncf.algo_selector import NoCompressionAlgorithmController
            return NoCompressionAlgorithmController(self)

        if len(builders) == 1:
            return builders[0].build_controller(self)

        from nncf.composite_compression import CompositeCompressionAlgorithmController
        composite = CompositeCompressionAlgorithmController(self)
        for builder in builders:
            composite.add(builder.build_controller(self))
        return composite
コード例 #2
0
ファイル: nncf_network.py プロジェクト: zbrnwpu/nncf
 def post_build_graph_actions(self):
     """Clear the ``initialized`` flag on every quantization module.

     Meant to run after the dummy ``load_state_dict`` call made during graph
     building, so quantization modules are not left marked as initialized.
     """
     registered_type_names = [
         module_cls.__name__
         for module_cls in QUANTIZATION_MODULES.registry_dict.values()
     ]
     # NOTE(review): presumably maps state-dict names to the matching modules
     # (inferred from the helper's name) — only the values are used here.
     name_to_module = get_state_dict_names_with_modules(
         self, registered_type_names)
     for quantizer in name_to_module.values():
         quantizer.initialized = False