def commit_compression_changes(self) -> 'CompressionAlgorithmController':
    """Apply pending graph insertions and build the compression controller.

    For every recorded insertion point, the pending ``(fn, priority, ...)``
    entries are ordered by their priority (element ``[1]``). The ordered list
    is written back into ``self._insertions_into_original_graph``, and the
    callables (element ``[0]``) are passed to ``self._insert_at_point``.

    Next, the debug interface (if any) is initialized via ``init_actual``. A
    ``LoadStateListener`` is then created over the modules that
    ``get_state_dict_names_with_modules`` reports for every class registered
    in ``QUANTIZATION_MODULES``, and stored in ``self._load_listener``.

    Returns:
        * ``NoCompressionAlgorithmController(self)`` when ``self._builders``
          is empty;
        * the single builder's ``build_controller(self)`` result when there
          is exactly one builder;
        * otherwise a ``CompositeCompressionAlgorithmController`` to which
          each builder's controller is added, in builder order.
    """
    for point, prioritized_fns in self._insertions_into_original_graph.items():
        # sorted() is stable, so entries with equal priority keep their
        # original relative order.
        ordered_fns = sorted(prioritized_fns, key=lambda entry: entry[1])
        # Reassigning an existing key does not change the dict's size, so it
        # is safe to do while iterating over .items().
        self._insertions_into_original_graph[point] = ordered_fns
        callables = [entry[0] for entry in ordered_fns]
        self._insert_at_point(point, callables)

    debug_interface = self.debug_interface
    if debug_interface is not None:
        debug_interface.init_actual(self)

    registered_type_names = [
        quantizer_cls.__name__
        for quantizer_cls in QUANTIZATION_MODULES.registry_dict.values()
    ]
    quantizer_modules = get_state_dict_names_with_modules(
        self, registered_type_names)
    self._load_listener = LoadStateListener(self, quantizer_modules)

    builders = self._builders
    if not builders:
        # Imported locally, as in the original; presumably this avoids a
        # circular import. TODO confirm.
        from nncf.algo_selector import NoCompressionAlgorithmController
        return NoCompressionAlgorithmController(self)
    if len(builders) == 1:
        return builders[0].build_controller(self)

    from nncf.composite_compression import CompositeCompressionAlgorithmController
    composite = CompositeCompressionAlgorithmController(self)
    for builder in builders:
        sub_controller = builder.build_controller(self)
        composite.add(sub_controller)
    return composite
def post_build_graph_actions(self):
    """Clear the ``initialized`` flag on every quantization module.

    Quantization modules are located through
    ``get_state_dict_names_with_modules``, using the names of all classes
    registered in ``QUANTIZATION_MODULES``. Each returned module gets
    ``initialized = False``. Per the original author, this undoes the flags
    that a dummy ``load_state_dict`` call left behind after graph building.
    """
    registered_classes = QUANTIZATION_MODULES.registry_dict.values()
    type_names = [quantizer_cls.__name__ for quantizer_cls in registered_classes]
    # Presumably maps state-dict names to module objects; only the values
    # are used here. TODO confirm.
    name_to_module = get_state_dict_names_with_modules(self, type_names)
    for quantizer in name_to_module.values():
        quantizer.initialized = False