def __init__(self, module, input_infos: List[ModelInputInfo] = None,
             dummy_forward_fn=None, scopes_without_shape_matching=None,
             ignored_scopes=None, target_scopes=None):
    """Wrap *module* for NNCF compression and build its tracing graphs.

    Replaces supported submodules with their NNCF-wrapped counterparts,
    then traces the model twice: once (with input tracing) to build the
    graph of the original model, and once more to prepare the context used
    for the compressed model's forward passes.

    :param module: the original torch module to be compressed.
    :param input_infos: descriptions of the model inputs, used to generate
        dummy inputs for graph-building forward passes.
    :param dummy_forward_fn: optional user-supplied callable performing a
        forward pass for graph building; if None, one is created from
        ``input_infos``.
    :param scopes_without_shape_matching: scopes in which tensor shape
        differences are ignored when matching traced nodes.
    :param ignored_scopes: scopes excluded from compression.
    :param target_scopes: if set, restricts compression to these scopes.
    """
    super().__init__()
    self.set_nncf_wrapped_model(module)
    self.input_infos = input_infos
    self.ignored_scopes = ignored_scopes
    self.target_scopes = target_scopes
    self._dummy_forward_fn = dummy_forward_fn
    self._nncf_module_scopes = []  # type: List[Scope]
    self.scopes_without_shape_matching = scopes_without_shape_matching
    self.debug_interface = CombinedDebugInterface() if is_debug() else None
    self._extra_module_types = []  # type: List[CompressionModuleType]
    # pylint:disable=line-too-long
    self._insertions_into_original_graph = {}  # type: Dict[InsertionPoint, List[Tuple[Callable, OperationPriority]]]

    # Capture the device before module replacement so the wrapped modules
    # can be moved back onto it.
    device = next(module.parameters()).device

    # all modules should be replaced prior to graph building
    self._replace_modules_by_nncf_modules(device)

    # Context/builder for tracing the ORIGINAL model (inputs are traced too).
    _orig_context = TracingContext()
    _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
        with_input_tracing=True)
    self._graph_builder = GraphBuilder(_orig_graph_build_forward_fn)

    # Model-input nodes are matched regardless of input tensor shapes.
    _orig_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                       ShapeIgnoringTensorMetaComparator())
    if self.scopes_without_shape_matching:
        _orig_context.add_node_comparators(scopes_without_shape_matching,
                                           ShapeIgnoringTensorMetaComparator())

    self._original_graph = self._graph_builder.build_graph(
        self.get_nncf_wrapped_model(), _orig_context)

    # Separate context used for all subsequent (compressed-model) forwards.
    self._compressed_context = TracingContext()

    # NOTE(review): if the user supplied dummy_forward_fn, this reassignment
    # is a no-op since _get_dummy_forward_fn_for_graph_building returns it
    # unchanged; otherwise a generated dummy forward fn is stored here.
    self._dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
        with_input_tracing=False)

    self._compressed_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                                  ShapeIgnoringTensorMetaComparator())
    if self.scopes_without_shape_matching:
        self._compressed_context.add_node_comparators(scopes_without_shape_matching,
                                                      ShapeIgnoringTensorMetaComparator())
    self._load_listener = None

    self._builders = []  # type: List['CompressionAlgorithmBuilder']
class NNCFNetwork(nn.Module, PostGraphBuildActing):
    """A wrapper around a user's torch module that enables NNCF compression.

    Holds the wrapped model, the traced graphs of the original and the
    compressed model, and the machinery for registering compression
    algorithm builders and inserting compression operations (hooks) into
    the model's execution.
    """

    def __init__(self, module, input_infos: List[ModelInputInfo] = None,
                 dummy_forward_fn=None, wrap_inputs_fn=None,
                 scopes_without_shape_matching=None, ignored_scopes=None,
                 target_scopes=None):
        """Wrap *module*, replace its compressible submodules with NNCF
        counterparts and trace it to build the original/compressed graphs.

        :param module: the original torch module to be compressed.
        :param input_infos: descriptions of model inputs used to generate
            dummy forward inputs for tracing.
        :param dummy_forward_fn: optional user-supplied callable running a
            forward pass for graph building; if None, one is created from
            ``input_infos``.
        :param wrap_inputs_fn: optional callable that wraps the (args,
            kwargs) of each forward call; if None, an
            ``InputInfoWrapManager`` based on ``input_infos`` is used.
        :param scopes_without_shape_matching: scopes where tensor shapes
            are ignored when matching traced nodes.
        :param ignored_scopes: scopes excluded from compression.
        :param target_scopes: if set, restricts compression to these scopes.
        """
        super().__init__()
        self._set_nncf_wrapped_model(module)
        self._forward_signature = inspect.signature(module.forward)
        self.input_infos = input_infos
        self.ignored_scopes = ignored_scopes
        self.target_scopes = target_scopes
        self._dummy_forward_fn = dummy_forward_fn
        device = next(module.parameters()).device

        if wrap_inputs_fn is not None:
            self._wrap_inputs_fn = wrap_inputs_fn
        else:
            # Default input wrapping is driven by the declared input infos.
            self.__input_infos_based_input_wrapper = InputInfoWrapManager(
                self.input_infos, self._forward_signature,
                module_ref_for_device=self)
            self._wrap_inputs_fn = self.__input_infos_based_input_wrapper.wrap_inputs

        self._nncf_module_scopes = []  # type: List[Scope]
        self.scopes_without_shape_matching = scopes_without_shape_matching
        self.debug_interface = CombinedDebugInterface() if is_debug() else None
        self._extra_module_types = []  # type: List[CompressionModuleType]
        # pylint:disable=line-too-long
        self._insertions_into_original_graph = {}  # type: Dict[InsertionPoint, List[Tuple[Callable, OperationPriority]]]

        # all modules should be replaced prior to graph building
        self._replace_modules_by_nncf_modules(device)

        # Context/builder for tracing the ORIGINAL model (inputs traced too).
        _orig_context = TracingContext()
        _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=True)
        self._graph_builder = GraphBuilder(_orig_graph_build_forward_fn)

        _orig_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                           ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            _orig_context.add_node_comparators(scopes_without_shape_matching,
                                               ShapeIgnoringTensorMetaComparator())

        self._original_graph = self._graph_builder.build_graph(
            self.get_nncf_wrapped_model(), _orig_context)

        # Separate context for the compressed model's forward passes.
        self._compressed_context = TracingContext()

        self._dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=False)

        self._compressed_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                                      ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            self._compressed_context.add_node_comparators(scopes_without_shape_matching,
                                                          ShapeIgnoringTensorMetaComparator())
        self._load_listener = None

        self._builders = []  # type: List['CompressionAlgorithmBuilder']

    @debuggable_forward
    def forward(self, *args, **kwargs):
        """Run the wrapped model's forward inside the compressed tracing
        context, wrapping inputs first so traced ops can be intercepted."""
        with self._compressed_context as ctx:  # type: TracingContext
            ctx.base_module_thread_local_replica = self
            args, kwargs = self._wrap_inputs_fn(args, kwargs)
            retval = self.get_nncf_wrapped_model()(*args, **kwargs)
        return retval

    def register_algorithm(self, builder: 'CompressionAlgorithmBuilder'):
        """Should be called during *builder*'s *apply_to* method, otherwise there will be no corresponding
        controller returned by the network on the *commit_compression_changes* stage"""
        self._builders.append(builder)

    # Cannot use property syntax here, otherwise the wrapped module will end up
    # being twice in the same checkpoint with different prefixes
    def get_nncf_wrapped_model(self):
        """Return the user's original (now NNCF-module-replaced) model."""
        return getattr(self, MODEL_WRAPPED_BY_NNCF_ATTR_NAME)

    def _set_nncf_wrapped_model(self, value):
        # Store the wrapped model under a well-known attribute name.
        setattr(self, MODEL_WRAPPED_BY_NNCF_ATTR_NAME, value)

    def get_modules_in_nncf_modules_by_type(self, types) -> Dict['Scope', nn.Module]:
        """Collect submodules of the given *types* found inside NNCF modules,
        keyed by their absolute scope within the wrapped model."""
        nncf_modules = self.get_nncf_modules()
        retval = {}
        for nncf_module_scope, nncf_module in nncf_modules.items():
            # Drop the NNCF module's own scope element so that concatenating
            # with the relative scope yields an absolute scope.
            nncf_module_scope.pop()
            for relative_scope, target_module in get_all_modules_by_type(
                    nncf_module, types).items():
                retval[nncf_module_scope + relative_scope] = target_module
        return retval

    def register_insertion_command(self, command: InsertionCommand):
        """Queue an insertion command; commands are grouped per insertion
        point and applied in priority order at commit time."""
        point = command.insertion_point
        if point not in self._insertions_into_original_graph:
            self._insertions_into_original_graph[point] = [(command.fn, command.priority)]
        else:
            self._insertions_into_original_graph[point].append(
                (command.fn, command.priority))

    def commit_compression_changes(self) -> 'CompressionAlgorithmController':
        """Apply all queued insertion commands, set up state-dict loading
        hooks, and return the controller for the registered algorithm(s).

        :return: a single algorithm's controller, a composite controller for
            several algorithms, or a no-op controller if none registered.
        """
        for insertion_point, fn_list_with_priority in self._insertions_into_original_graph.items():
            # Apply insertions at each point in ascending priority order.
            fn_list_with_priority = sorted(fn_list_with_priority, key=lambda x: x[1])
            self._insertions_into_original_graph[insertion_point] = fn_list_with_priority
            self._insert_at_point(insertion_point,
                                  [x[0] for x in fn_list_with_priority])

        if self.debug_interface is not None:
            self.debug_interface.init_actual(self)

        quantization_types = [class_type.__name__
                              for class_type in QUANTIZATION_MODULES.registry_dict.values()]
        all_quantizations = get_state_dict_names_with_modules(self, quantization_types)
        # Listener resets quantizer init flags after checkpoint loading.
        self._load_listener = LoadStateListener(self, all_quantizations)

        if not self._builders:
            from nncf.algo_selector import NoCompressionAlgorithmController
            return NoCompressionAlgorithmController(self)

        if len(self._builders) == 1:
            return self._builders[0].build_controller(self)

        from nncf.composite_compression import CompositeCompressionAlgorithmController
        composite_controller = CompositeCompressionAlgorithmController(self)
        for algo_builder in self._builders:
            composite_controller.add(algo_builder.build_controller(self))
        return composite_controller

    def _insert_at_point(self, point: InsertionPoint, fn_list: List[Callable]):
        """Register *fn_list* at the given insertion point — either as
        operator pre/post hooks on the tracing context, or as pre/post-op
        operations on the target NNCF module."""
        if point.insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            self._compressed_context.register_pre_hooks(fn_list, point.ia_op_exec_context)
        elif point.insertion_type == InsertionType.OPERATOR_POST_HOOK:
            self._compressed_context.register_post_hooks(fn_list, point.ia_op_exec_context)
        else:
            norm_target_scope = self._normalize_variable_recurrent_scope(
                point.ia_op_exec_context.scope_in_model)
            norm_nncf_scopes = [self._normalize_variable_recurrent_scope(x)
                                for x in self._nncf_module_scopes]
            assert norm_target_scope in norm_nncf_scopes  # Required for proper Recurrent/VariableRecurrent addressing
            nncf_module = self.get_module_by_scope(point.ia_op_exec_context.scope_in_model)
            if point.insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
                for fn in fn_list:
                    nncf_module.register_pre_forward_operation(fn)
            elif point.insertion_type == InsertionType.NNCF_MODULE_POST_OP:
                for fn in fn_list:
                    nncf_module.register_post_forward_operation(fn)

    def __getattr__(self, name):
        # Delegate unknown attribute access to the wrapped model first so
        # the NNCFNetwork is a transparent stand-in for the user's model.
        wrapped_module = super().__getattr__(MODEL_WRAPPED_BY_NNCF_ATTR_NAME)
        if hasattr(wrapped_module, name):
            return getattr(wrapped_module, name)
        return super().__getattr__(name)

    def get_graph(self) -> NNCFGraph:
        """Return the graph traced in the compressed model's context."""
        return self._compressed_context.graph

    def get_original_graph(self) -> NNCFGraph:
        """Return the graph of the original (uncompressed) model."""
        return self._original_graph

    def get_tracing_context(self) -> TracingContext:
        """Return the tracing context used for compressed forwards."""
        return self._compressed_context

    def _get_dummy_forward_fn_for_graph_building(self, with_input_tracing):
        # Prefer a user-supplied dummy forward fn; otherwise generate one
        # from the declared input infos.
        if self._dummy_forward_fn is None:
            return create_dummy_forward_fn(self.input_infos,
                                           with_input_tracing=with_input_tracing,
                                           wrap_inputs_fn=self._wrap_inputs_fn)
        return self._dummy_forward_fn

    def _replace_modules_by_nncf_modules(self, device):
        """Replace compressible submodules with NNCF-wrapped versions and
        move the result back to *device*."""
        module, self._nncf_module_scopes = replace_modules_by_nncf_modules(
            self.get_nncf_wrapped_model(), ignored_scopes=self.ignored_scopes,
            target_scopes=self.target_scopes)
        self._set_nncf_wrapped_model(module.to(device))

    def get_nncf_module_scopes(self) -> List['Scope']:
        """Return the scopes of all modules replaced by NNCF modules."""
        return self._nncf_module_scopes

    def get_nncf_modules(self) -> Dict['Scope', torch.nn.Module]:
        """Return all NNCF modules (built-in and user-wrapped) by scope."""
        nncf_module_names_list = NNCF_MODULES + [x.__name__ for x in NNCF_WRAPPED_USER_MODULES_DICT.values()]
        return get_all_modules_by_type(self.get_nncf_wrapped_model(), nncf_module_names_list)

    def rebuild_graph(self, *input_args):
        """Re-trace the (possibly modified) compressed model, replacing the
        graph stored in the compressed context."""
        self._compressed_context.reset_graph()
        dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(with_input_tracing=False)
        builder = GraphBuilder(dummy_forward_fn)
        _ = builder.build_graph(self, self._compressed_context)

    def post_build_graph_actions(self):
        """Hook invoked after graph building (PostGraphBuildActing)."""
        # Reset initialization flags (`initialized`) for all quantization modules
        # after dummy `load_state_dict` call.
        quantization_types = [class_type.__name__
                              for class_type in QUANTIZATION_MODULES.registry_dict.values()]
        all_quantizations = get_state_dict_names_with_modules(self, quantization_types)
        for module in all_quantizations.values():
            module.initialized = False

    def get_post_pattern_insertion_points(self, pattern: 'NNCFNodeExpression',
                                          omit_nodes_in_nncf_modules=False) -> List[InsertionInfo]:
        """Find insertion points located after matches of *pattern* in the
        original graph.

        :param pattern: the node expression pattern to match.
        :param omit_nodes_in_nncf_modules: if True, drop insertion points
            whose scope lies inside an NNCF module.
        :return: deduplicated list of InsertionInfo objects.
        """
        io_infos = self._original_graph.get_matching_nncf_graph_pattern_io_list(pattern)

        insertion_infos = []
        for io_info in io_infos:
            # The input/output is given in terms of edges, but the post-hooks are currently applied to
            # nodes. Multiple output edges in a pattern I/O info may originate from one and the same
            # node, and we have to ensure that these resolve into just one insertion point - thus the usage of "set".
            pattern_insertion_info_set = set()
            if len(io_info.output_edges) > 1:
                nncf_logger.debug("WARNING: pattern has more than one activation output")

            for nncf_node in io_info.output_nodes:
                pattern_insertion_info_set.add(InsertionInfo(nncf_node.op_exec_context,
                                                             is_output=True,
                                                             shape_to_operate_on=None))
                # TODO: determine output shapes for output nodes to enable per-channel quantization

            # Ignore input nodes in the pattern for now, rely on the _quantize_inputs functions.
            # TODO: handle input quantization here as well

            # Since this function is currently only used for activation quantization purposes via operator
            # post-hook mechanism, we may take any edge and it will point from the same node where we will have to
            # insert a quantizer later. However, in the future the output edges may refer to activation tensors
            # with different sizes, in which case we have to insert different per-channel quantizers to
            # accommodate different trainable params if there is a difference in the channel dimension.
            # Furthermore, currently there is no distinction for single tensor output to multiple nodes and
            # multiple tensor output to multiple nodes ("chunk" operation is an example of the latter).
            # The pattern may also have unexpected outputs from a node in the middle of the pattern (see
            # "densenet121.dot" for an example of this) - need to decide what to do with that in terms
            # of quantization.
            # TODO: address the issues above.
            for nncf_edge in io_info.output_edges:
                pattern_insertion_info_set.add(InsertionInfo(nncf_edge.from_node.op_exec_context,
                                                             is_output=False,
                                                             shape_to_operate_on=nncf_edge.tensor_shape))
            insertion_infos += list(pattern_insertion_info_set)

        insertion_infos = list(set(insertion_infos))  # Filter the overlapping insertion points from different matches (happens for GNMT)
        insertion_infos_filtered = []

        for info in insertion_infos:
            if omit_nodes_in_nncf_modules and self.is_scope_in_nncf_module_scope(
                    info.op_exec_context.scope_in_model):
                continue
            insertion_infos_filtered.append(info)

        return insertion_infos_filtered

    def is_scope_in_nncf_module_scope(self, scope: 'Scope'):
        """Return True if *scope* lies inside any NNCF module's scope
        (Recurrent scope variants are treated as equivalent)."""
        # TODO: optimize
        norm_nncf_scopes = [self._normalize_variable_recurrent_scope(x)
                            for x in self._nncf_module_scopes]
        norm_op_scope = self._normalize_variable_recurrent_scope(scope)
        for nncf_scope in norm_nncf_scopes:
            if norm_op_scope in nncf_scope:
                return True
        return False

    def register_compression_module_type(self, compression_module_type: CompressionModuleType):
        """Create the ModuleDict container for a compression module type.

        :raises RuntimeError: if the type is already registered.
        """
        attr_name = self._compression_module_type_to_attr_name(compression_module_type)
        if compression_module_type in self._extra_module_types:
            raise RuntimeError("Module type {} is already registered".format(compression_module_type))
        self.__setattr__(attr_name, nn.ModuleDict())
        self._extra_module_types.append(compression_module_type)

    def add_compression_module(self, module_key: str, module: nn.Module,
                               compression_module_type: CompressionModuleType):
        """Store *module* under *module_key* in the container registered for
        the given type.

        :raises RuntimeError: if the type was never registered.
        """
        attr_name = self._compression_module_type_to_attr_name(compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(compression_module_type))
        self.__getattr__(attr_name)[module_key] = module

    def get_compression_modules_by_type(self, compression_module_type: CompressionModuleType) -> nn.ModuleDict:
        """Return the ModuleDict of compression modules of the given type.

        :raises RuntimeError: if the type was never registered.
        """
        attr_name = self._compression_module_type_to_attr_name(compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(compression_module_type))
        return self.__getattr__(attr_name)

    @staticmethod
    def _compression_module_type_to_attr_name(compression_module_type: CompressionModuleType):
        """Required for backward compatibility with checkpoints that store function and activation
        quantizers directly under corresponding attributes of NNCFNetwork."""
        if compression_module_type == CompressionModuleType.FUNCTION_QUANTIZER:
            return "function_quantizers"
        if compression_module_type == CompressionModuleType.ACTIVATION_QUANTIZER:
            return "activation_quantizers"
        raise RuntimeError("Unknown extra module type")

    def sort_compression_modules(self, compression_module_type: CompressionModuleType):
        """Sort the registered compression modules of the given type by key
        (makes checkpoint state-dict ordering deterministic).

        :raises RuntimeError: if the type was never registered.
        """
        attr_name = self._compression_module_type_to_attr_name(compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(compression_module_type))
        module_dict = self.__getattr__(attr_name)
        # pylint: disable=protected-access
        module_dict._modules = OrderedDict(sorted(module_dict._modules.items()))
        self.__setattr__(attr_name, module_dict)

    @staticmethod
    def _normalize_variable_recurrent_scope(scope: 'Scope'):
        """
        Two scopes pointing to an NNCF module that only differ in a Recurrent/VariableRecurrent/VariableRecurrentReverse
        scope element actually point to one and the same module.
        """
        ret_scope = scope.copy()
        for scope_element in ret_scope:
            if scope_element.calling_module_class_name in ["Recurrent", "VariableRecurrent",
                                                           "VariableRecurrentReverse"]:
                scope_element.calling_module_class_name = "NormalizedName_Recurrent"
        return ret_scope

    def do_dummy_forward(self, force_eval=False):
        """Attention: If run with force_eval=False, this may spoil the batchnorm statistics,
        and an eval run of the model will perform much worse than the train run. """
        if force_eval:
            train_mode = self.training
            self.eval()
        with torch.no_grad():
            self._dummy_forward_fn(self)
        if force_eval:
            if train_mode:
                self.train()

    def get_insertion_point_graph(self) -> InsertionPointGraph:
        """Build an InsertionPointGraph from the original graph, annotating
        each operator node with its (possibly module-specialized) metatype."""
        ip_graph = InsertionPointGraph(self._original_graph.get_nx_graph_copy())

        # Mark IP graph operator nodes with associated op metatypes
        # Determining operator metatypes is more suited to occur at wrap_operator
        # stage, because it might be influenced by specific non-tensor function parameters,
        # but we have to inspect the containing module parameters as well, so the
        # TracingContext in wrap_operator would have to retain a reference to
        # the model that uses it. Since currently we do not need to inspect the
        # function arguments to determine the metatype, we can do this here, but
        # once we need to inspect the arguments, the code will have to be moved to
        # wrap_operator.
        for node_key in ip_graph.nodes:
            ip_graph_node = ip_graph.nodes[node_key]
            ip_graph_node_type = ip_graph_node[InsertionPointGraph.NODE_TYPE_NODE_ATTR]
            if ip_graph_node_type == InsertionPointGraphNodeType.OPERATOR:
                nncf_graph_node_ref = ip_graph_node[InsertionPointGraph.REGULAR_NODE_REF_NODE_ATTR]
                op_exec_context = nncf_graph_node_ref[NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR]
                op_name = op_exec_context.operator_name
                scope = op_exec_context.scope_in_model
                op_arch = OPERATOR_METATYPES.get_operator_metatype_by_op_name(op_name)
                module = self.get_module_by_scope(scope)
                if module is not None:
                    subtype = op_arch.determine_subtype(containing_module=module)
                    if subtype is not None:
                        op_arch = subtype
                ip_graph_node[InsertionPointGraph.OPERATOR_METATYPE_NODE_ATTR] = op_arch
        return ip_graph

    def get_module_by_scope(self, scope: 'Scope') -> torch.nn.Module:
        """Resolve *scope* to the module it denotes, or None for modules
        created in-place and never stored on the model.

        :raises RuntimeError: if a scope element does not match any child.
        """
        curr_module = self.get_nncf_wrapped_model()
        for scope_element in scope[1:]:  # omit first scope element which corresponds to base module
            if scope_element.calling_field_name is None:
                # The module used is being created in-place every time and never stored in the model,
                # happens for nn.Softmax in BERT implementations.
                return None
            # pylint: disable=protected-access
            next_module = curr_module._modules.get(scope_element.calling_field_name)
            if next_module is None:
                raise RuntimeError("Could not find a {} module member in {} module of scope {} during node search"
                                   .format(scope_element.calling_field_name,
                                           scope_element.calling_module_class_name,
                                           str(scope)))
            curr_module = next_module
        return curr_module

    def get_parameters_count_in_model(self):
        """
        Return total amount of model parameters.
        """
        count = 0
        for param in self.parameters():
            count = count + param.numel()
        return count

    def get_flops_per_module(self):
        """
        Calculates FLOPS count for modules.
        """
        model = self
        flops_count_dict = {}

        def get_hook(name):
            def compute_MACs_hook(module, input_, output):
                # NOTE(review): the conv estimate ignores grouping/stride
                # specifics; values here are approximations per output map.
                if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                    ks = module.weight.data.shape
                    mac_count = ks[0] * ks[1] * ks[2] * ks[3] * output.shape[3] * output.shape[2]
                elif isinstance(module, nn.Linear):
                    mac_count = input_[0].shape[1] * output.shape[-1]
                elif isinstance(module, nn.BatchNorm2d):
                    mac_count = np.prod(list(input_[0].shape))
                else:
                    return
                # 1 MAC = 2 FLOPs (one multiply + one add).
                flops_count_dict[name] = 2 * mac_count
            return compute_MACs_hook

        hook_list = [m.register_forward_hook(get_hook(n)) for n, m in model.named_modules()]
        model.do_dummy_forward(force_eval=True)
        for h in hook_list:
            h.remove()
        return flops_count_dict

    def get_MACs_in_model(self):
        """
        Calculates MAC units count for model.
        """
        flops_count_dict = self.get_flops_per_module()
        total_MACs_count = sum(v // 2 for v in flops_count_dict.values())
        return total_MACs_count

    def get_input_infos(self) -> List[ModelInputInfo]:
        """Return a deep copy of the declared model input infos."""
        return deepcopy(self.input_infos)
def __init__(self, module, input_infos: List[ModelInputInfo],
             dummy_forward_fn=None, wrap_inputs_fn=None,
             scopes_without_shape_matching=None, ignored_scopes=None,
             target_scopes=None, reset: bool = False):
    """Wrap *module* for NNCF compression, replacing eval-mode modules and
    building the original/compressed tracing graphs.

    :param module: the original torch module to be compressed.
    :param input_infos: descriptions of model inputs used to generate dummy
        forward inputs for tracing (required in this variant).
    :param dummy_forward_fn: optional user-supplied callable running a
        forward pass for graph building; stored separately so the generated
        dummy forward fn does not overwrite it.
    :param wrap_inputs_fn: optional callable wrapping (args, kwargs) of each
        forward call; if None, an ``InputInfoWrapManager`` based on
        ``input_infos`` is used.
    :param scopes_without_shape_matching: scopes where tensor shapes are
        ignored when matching traced nodes.
    :param ignored_scopes: scopes excluded from compression.
    :param target_scopes: if set, restricts compression to these scopes.
    :param reset: forwarded to module replacement — presumably resets
        previously applied NNCF wrapping; TODO confirm against
        _replace_modules_by_nncf_modules.
    """
    super().__init__()
    self._set_nncf_wrapped_model(module)
    self._forward_signature = inspect.signature(module.forward)
    self.input_infos = input_infos
    self.ignored_scopes = ignored_scopes
    self.target_scopes = target_scopes
    # Kept separate from the generated dummy forward fn (see below).
    self._user_dummy_forward_fn = dummy_forward_fn
    device = next(module.parameters()).device

    if wrap_inputs_fn is not None:
        self._wrap_inputs_fn = wrap_inputs_fn
    else:
        # Default input wrapping is driven by the declared input infos.
        self.__input_infos_based_input_wrapper = InputInfoWrapManager(
            self.input_infos, self._forward_signature,
            module_ref_for_device=self)
        self._wrap_inputs_fn = self.__input_infos_based_input_wrapper.wrap_inputs

    self._nncf_module_scopes = []  # type: List[Scope]
    self.scopes_without_shape_matching = scopes_without_shape_matching
    self.debug_interface = CombinedDebugInterface() if is_debug() else None
    self._extra_module_types = []  # type: List[ExtraCompressionModuleType]
    # pylint:disable=line-too-long
    self._insertions_into_original_graph = {}  # type: Dict[InsertionPoint, List[Tuple[Callable, OperationPriority]]]

    _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
        with_input_tracing=True)
    self._graph_builder = GraphBuilder(_orig_graph_build_forward_fn)

    nncf_wrapped_model = self.get_nncf_wrapped_model()
    # Operations executed only in eval mode must be known before module
    # replacement so they can be wrapped as well.
    eval_only_ops_exec_ctx = self.collect_eval_only_ops_exec_context(
        nncf_wrapped_model, self._graph_builder)

    # all modules called in eval mode should be replaced prior to graph building
    self._replace_modules_by_nncf_modules(device, eval_only_ops_exec_ctx, reset)

    _orig_context = TracingContext()
    _orig_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                       ShapeIgnoringTensorMetaComparator())
    if self.scopes_without_shape_matching:
        _orig_context.add_node_comparators(scopes_without_shape_matching,
                                           ShapeIgnoringTensorMetaComparator())

    self._original_graph = self._graph_builder.build_graph(
        nncf_wrapped_model, _orig_context, as_eval=True)

    # Separate context used for the compressed model's forward passes.
    self._compressed_context = TracingContext()

    self._dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
        with_input_tracing=False)

    self._compressed_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                                  ShapeIgnoringTensorMetaComparator())
    if self.scopes_without_shape_matching:
        self._compressed_context.add_node_comparators(scopes_without_shape_matching,
                                                      ShapeIgnoringTensorMetaComparator())
    self._load_listener = None

    self._builders = []  # type: List['CompressionAlgorithmBuilder']