def check_correct_nncf_modules_replacement(model: NNCFNetwork, compressed_model: NNCFNetwork) \
        -> Tuple[Dict[Scope, Module], Dict[Scope, Module]]:
    """
    Check that every module of a type listed in NNCF_MODULES_MAP was replaced by its NNCF counterpart.

    For each such module found in `model`, the scope it is expected to have in the compressed model is
    built by swapping the class name of the last scope element for the corresponding NNCF class name.
    That scope must be present among the NNCF modules of the compressed model's wrapped model.

    :param model: original (uncompressed) model
    :param compressed_model: compressed model
    :return: a pair of scope -> module dicts: the replaceable modules found in the original model, and
        the NNCF modules found in the compressed model
    :raises AssertionError: if the module counts differ or an expected NNCF module scope is missing
    """
    # NNCF_MODULES_MAP is keyed by NNCF module class names, and its values are the original class
    # names they replace. Invert it to look up the NNCF replacement for an original class name.
    nncf_name_by_original_name = {original: nncf for nncf, original in NNCF_MODULES_MAP.items()}

    original_modules = get_all_modules_by_type(model, list(NNCF_MODULES_MAP.values()))
    nncf_modules = get_all_modules_by_type(compressed_model.get_nncf_wrapped_model(),
                                           list(NNCF_MODULES_MAP.keys()))
    assert len(original_modules) == len(nncf_modules), \
        'Expected {} NNCF modules, found {}'.format(len(original_modules), len(nncf_modules))

    for scope in original_modules:
        # Deep copy so that editing the last scope element does not mutate the key of `original_modules`.
        expected_scope = deepcopy(scope)
        elt = expected_scope.pop()  # type: ScopeElement
        elt.calling_module_class_name = nncf_name_by_original_name[elt.calling_module_class_name]
        expected_scope.push(elt)
        assert expected_scope in nncf_modules, \
            'Module at {} was not replaced by an NNCF module (expected scope {})'.format(scope, expected_scope)
    return original_modules, nncf_modules
def __init__(self, model: NNCFNetwork, weight_quantizers: Dict[WeightQuantizerId, WeightQuantizerInfo],
             constraints: HardwareQuantizationConstraints):
    """
    Collect the weight quantizer modules of `model` and sort them into selectable and skipped ones.

    A weight quantizer is kept (in the "execution order" dicts) when `constraints` reports a number of
    unique bitwidths for it other than one. Otherwise it is recorded as skipped.

    :param model: compressed model whose wrapped module tree is searched for quantizer modules
    :param weight_quantizers: weight quantizer descriptors; only the keys (quantizer ids) are used here,
        matched to modules through each id's `target_node_name`
    :param constraints: hardware constraints queried for the unique bitwidths of each quantizer id
    """
    node_name_to_qid = {}
    for wq_id in weight_quantizers:
        node_name_to_qid[wq_id.target_node_name] = wq_id
    self._wq_affected_module_node_name_vs_qid_dict = node_name_to_qid

    self._quantizer_module_scope_vs_qid_dict = {}  # type: Dict[Scope, WeightQuantizerId]
    self._skipped_quantized_weight_node_names = []
    self._skipped_weight_quantizers = {}  # type: Dict[WeightQuantizerId, BaseQuantizer]
    self._weight_quantizers_in_execution_order_per_scope = OrderedDict()  # type: Dict[Scope, BaseQuantizer]
    self._weight_quantizers_in_execution_order = OrderedDict()  # type: Dict[WeightQuantizerId, BaseQuantizer]

    # Quantizer modules are looked up by the class names of all registered quantization modules.
    quantizer_class_names = [cls.__name__ for cls in QUANTIZATION_MODULES.registry_dict.values()]
    wrapped_model = model.get_nncf_wrapped_model()
    # NOTE(review): the attribute names assume get_all_modules_by_type yields modules in execution
    # order - confirm against its implementation.
    found_quantizers = get_all_modules_by_type(wrapped_model, quantizer_class_names)

    for quantizer_scope, quantizer_module in found_quantizers.items():
        if not self.is_wq_scope(quantizer_scope):
            continue

        owner_scope = self.get_owning_module_scope_from_wq_scope(quantizer_scope)
        owner_node = model.get_original_graph().get_op_nodes_in_scope(owner_scope)[0]
        if owner_node.node_name not in self._wq_affected_module_node_name_vs_qid_dict:
            continue

        wq_id = self._wq_affected_module_node_name_vs_qid_dict[owner_node.node_name]
        if len(constraints.get_all_unique_bitwidths(wq_id)) == 1:
            # Exactly one unique bitwidth: the quantizer is excluded from the ordered collections.
            self._skipped_quantized_weight_node_names.append(owner_node.node_name)
            self._skipped_weight_quantizers[wq_id] = quantizer_module
        else:
            self._weight_quantizers_in_execution_order_per_scope[quantizer_scope] = quantizer_module
            self._weight_quantizers_in_execution_order[wq_id] = quantizer_module