def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
    """
    Builds the filter pruning statistics of the model.

    :param quickly_collected_only: When True, the theoretical borderline dump
        (which is otherwise logged in debug mode) is skipped.
    :return: An `NNCFStatistics` object with the pruning statistics registered
        under the 'filter_pruning' key.
    """
    dump_borderline = not quickly_collected_only and is_debug()
    if dump_borderline:
        borderline = PrunedModelTheoreticalBorderline(
            self._pruned_layers_num,
            self._prunable_layers_num,
            self._max_prunable_flops,
            self._max_prunable_params,
            self.full_flops,
            self.full_params_num,
        )
        nncf_logger.debug(borderline.to_str())

    layers_summary = self._calculate_pruned_layers_summary()

    # Called before the full_*/current_* counters below are read.
    self._update_benchmark_statistics()
    model_statistics = PrunedModelStatistics(
        self.full_flops,
        self.current_flops,
        self.full_params_num,
        self.current_params_num,
        self.full_filters_num,
        self.current_filters_num,
        layers_summary,
    )

    pruning_statistics = FilterPruningStatistics(
        model_statistics,
        self.scheduler.current_pruning_level,
        self.scheduler.target_level,
        self.prune_flops,
    )

    result = NNCFStatistics()
    result.register('filter_pruning', pruning_statistics)
    return result
def __call__(self, saved_inputs: List[TensorMeta], actual_inputs: List[TensorMeta],
             tm_comparators: List[TensorMetaComparator]) -> bool:
    """
    Decides whether the input metas saved for a node match the input metas of an actual call.

    Inputs are compared position by position (extra trailing positions on either side are
    not inspected, because the comparison iterates over `zip` of both lists).

    :param saved_inputs: Input metas stored for the node, or None if nothing was stored.
    :param actual_inputs: Input metas observed for the operator call.
    :param tm_comparators: Callables taking (saved_input, actual_input); every one of them
        must return a truthy value for a pair of non-None metas to be considered matching.
    :return: True if the inputs match, False otherwise.
    """
    if saved_inputs is None:
        # Nothing was saved for the node: only a call without inputs can match it.
        # Without this guard an empty/None `actual_inputs` reached zip(None, ...) below
        # and raised TypeError instead of producing an answer.
        return not actual_inputs

    matched_with_unexpected_tensors = False
    for saved_input, actual_input in zip(saved_inputs, actual_inputs):
        if saved_input is None and actual_input is None:
            continue
        if (saved_input is None) and (actual_input is not None):
            # torch.Tensor.size() seems to return ints when not tracing ONNX
            # and tensors when tracing ONNX. This breaks input-based node matching whenever
            # torch.Tensor.size() return value is passed into a NNCF-traced operation (such as `view`)
            # because at graph building time it expected to see ints as args and now it sees tensors.
            # To mitigate this, will only match inputs against the positions which had tensors during build-time
            # and disregard the rest of the argument positions.
            matched_with_unexpected_tensors = True
            continue
        if (saved_input is not None) and (actual_input is None):
            return False
        for tm_comparator in tm_comparators:
            if not tm_comparator(saved_input, actual_input):
                return False

    if matched_with_unexpected_tensors:
        nncf_logger.debug(
            "Had to match a node to an op which has tensors at positions where there were no tensors "
            "at graph building time:\nNode input metas: {}, but op input metas: {}"
            .format(saved_inputs, actual_inputs))
    return True
def prepare_for_export(self):
    """
    Applies pruning masks to layer weights before exporting the model to ONNX.

    The masks are propagated first; afterwards the statistics of the pruned
    modules are written to the debug log.
    """
    self._propagate_masks()

    stats_table = self.get_stats_for_pruned_modules().draw()
    nncf_logger.debug('Pruned layers statistics: \n%s', stats_table)
def _set_binary_masks_for_pruned_layers_globally(self, pruning_level: float):
    """
    Sets the binary mask values for layer groups according to the global pruning level.
    Filter importance scores in each group are merged into a single global list and a
    threshold value separating the pruning_level proportion of the least important filters
    in the model is calculated. Filters are pruned globally according to the threshold value.

    :param pruning_level: Proportion of filters (over all groups together) to prune.
    """
    nncf_logger.debug('Setting new binary masks for all pruned modules together.')
    filter_importances = {}
    wrapped_layers = collect_wrapped_layers(self._model)

    # 0. Remove masks at the elements of the NNCFGraph
    for node in self._original_graph.topological_sort():
        node.data.pop('output_mask', None)

    # 1. Calculate masks
    # a. Calculate importances for all groups of filters
    for group in self._pruned_layer_groups_info.get_all_clusters():
        cumulative_filters_importance = self._calculate_filters_importance_in_group(group)
        filter_importances[group.id] = cumulative_filters_importance

    # b. Calculate one threshold for all weights
    importances = tf.concat(list(filter_importances.values()), 0)
    total_filters = importances.shape[0]
    # Clamp the index: at pruning_level == 1.0 the raw index equals the number of filters,
    # which is one past the end. The groupwise variant applies the same clamp.
    threshold_index = min(int(pruning_level * total_filters), total_filters - 1)
    threshold = sorted(importances)[threshold_index]

    # c. Initialize masks
    for group in self._pruned_layer_groups_info.get_all_clusters():
        filter_mask = calculate_binary_mask(filter_importances[group.id], threshold)
        for node in group.elements:
            nncf_node = self._original_graph.get_node_by_id(node.nncf_node_id)
            nncf_node.data['output_mask'] = TFNNCFTensor(filter_mask)

    # 2. Propagate masks across the graph
    mask_propagator = MaskPropagationAlgorithm(
        self._original_graph, TF_PRUNING_OPERATOR_METATYPES, TFNNCFPruningTensorProcessor)
    mask_propagator.mask_propagation()

    # 3. Apply masks to the model
    nncf_sorted_nodes = self._original_graph.topological_sort()
    for layer in wrapped_layers:
        # First graph node (in topological order) that belongs to this layer
        nncf_node = [n for n in nncf_sorted_nodes if layer.name == n.layer_name][0]
        if nncf_node.data['output_mask'] is not None:
            self._set_operation_masks([layer], nncf_node.data['output_mask'].tensor)

    # Calculate actual flops with new masks
    self._update_benchmark_statistics()
def remove_unified_scale_from_point(self, qp_id: QuantizationPointId):
    """
    Removes a quantization point from the unified scale group it belongs to.

    If the point is not in any group, the call is a no-op (apart from a debug message).
    If the point was the last member of its group, the group itself is removed as well.

    :param qp_id: Id of the quantization point to detach from its unified scale group.
    """
    gid = self.get_unified_scale_group_id(qp_id)
    if gid is None:
        # The trailing space after "QP" was missing, which glued two words together in the log.
        nncf_logger.debug("Attempted to remove QP id {} from associated unified scale group, but the QP "
                          "is not in any unified scale group - ignoring.".format(qp_id))
        return

    group = self.unified_scale_groups[gid]
    group.discard(qp_id)
    if not group:
        nncf_logger.debug("Removed last entry from a unified scale group {} - removing group itself".format(gid))
        self.unified_scale_groups.pop(gid)
def _get_metatypes_for_hw_config_op(self, hw_config_op: HWConfigOpName) -> Set[Type[OperatorMetatype]]:
    """
    Collects the operator metatypes that list the given HW config operation name.

    :param hw_config_op: Operation name as it appears in the HW config.
    :return: The set of available metatypes whose `hw_config_names` contain `hw_config_op`;
        empty (with a debug message) if there is none.
    """
    matching = {
        candidate
        for candidate in self._get_available_operator_metatypes_for_matching()
        if hw_config_op in candidate.hw_config_names
    }
    if matching:
        return matching

    nncf_logger.debug(
        'Operation name {} in HW config is not registered in NNCF under any supported '
        'operation metatype - will be ignored'.format(hw_config_op))
    return matching
def kd_loss_fn(teacher_output: torch.Tensor, student_output: torch.Tensor):
    """
    MSE knowledge distillation loss between the teacher and the student outputs.

    `scale` is not defined in this function; it is resolved from the surrounding scope.

    :param teacher_output: Output tensor of the teacher model.
    :param student_output: Output tensor of the student model.
    :return: `scale` times the MSE of the two outputs, or a single-element zero tensor
        on the student's device when the teacher output has fewer than 2 dimensions.
    """
    mse = torch.nn.MSELoss()
    teacher_rank = len(teacher_output.shape)
    if teacher_rank >= 2:
        return scale * mse(teacher_output, student_output)

    # Such an output is skipped rather than distilled (the message guesses it is a loss value).
    nncf_logger.debug(
        "Incompatible number of dimensions of the model output tensor for MSE KD "
        "(student - {}, teacher - {}, number of dims {} should be > 1)"
        " (most likely loss) - ignoring!".format(
            student_output.shape, teacher_output.shape, len(student_output.shape)))
    return torch.zeros([1]).to(student_output.device)
def _cross_match_version_agnostic_names(normalized_keys_to_load: List[str],
                                        normalized_model_keys: List[str]) -> Dict[str, str]:
    """
    Handles the situation where the normalized_keys_to_load contain legacy version-agnostic names
    of operations, such as `RELU`.

    :param normalized_keys_to_load: A list of keys in the checkpoint, potentially with version-agnostic names
    :param normalized_model_keys: A list of keys in the model, without version-agnostic names.
    :return: A mapping of the checkpoint key to a model key that matches version-agnostic names with their
      torch-specific counterparts.
    """
    version_agnostic_to_specific_names = {"RELU": {"relu", "relu_"}}

    # Built once so that every candidate lookup below is O(1) instead of a linear scan
    # (plus a second linear `.index()` call) over the checkpoint keys.
    keys_to_load = set(normalized_keys_to_load)

    retval = {}
    for model_key in normalized_model_keys:
        # Have to take care not to replace the matches to the class names.
        # The op names in existing checkpoints can only appear in external quantizers,
        # i.e. external_quantizers.ResNet/ReLU[relu]/relu_0.signed_tensor, so the substitution
        # is applied only to the portion after the last slash.
        prefix, separator, last_portion = model_key.rpartition('/')

        for agnostic_op_name, specific_op_name_set in version_agnostic_to_specific_names.items():
            matches_for_curr_agnostic_op_name = []
            has_specific_op_name = False
            for specific_op_name in specific_op_name_set:
                candidate_last_portion = last_portion
                if specific_op_name in last_portion:
                    candidate_last_portion = last_portion.replace(specific_op_name, agnostic_op_name, 1)
                    has_specific_op_name = True
                agnostic_version_of_model_key = prefix + separator + candidate_last_portion
                if agnostic_version_of_model_key in keys_to_load:
                    matches_for_curr_agnostic_op_name.append(agnostic_version_of_model_key)

            if not has_specific_op_name:
                # The model key has no version-specific op name; it can only match itself.
                if matches_for_curr_agnostic_op_name:
                    retval[matches_for_curr_agnostic_op_name[0]] = model_key
            elif len(matches_for_curr_agnostic_op_name) == 1:
                retval[matches_for_curr_agnostic_op_name[0]] = model_key
            elif not matches_for_curr_agnostic_op_name:
                nncf_logger.debug("Failed to match a version-specific key: {}".format(model_key))
            else:
                nncf_logger.debug("More than one match for the version specific key: {}\n"
                                  "Matches:\n"
                                  "{}".format(model_key, ', '.join(matches_for_curr_agnostic_op_name)))
    return retval
def kd_loss_fn(teacher_output: torch.Tensor, student_output: torch.Tensor):
    """
    Softmax knowledge distillation loss between the teacher and the student outputs.

    `scale` and `temperature` are not defined in this function; they are resolved from
    the surrounding scope.

    :param teacher_output: Output tensor of the teacher model.
    :param student_output: Output tensor of the student model.
    :return: The scaled loss, or a single-element zero tensor on the student's device
        when either output is not 2-dimensional.
    """
    both_two_dimensional = len(student_output.shape) == 2 and len(teacher_output.shape) == 2
    if not both_two_dimensional:
        nncf_logger.debug(
            "Incompatible number of dimensions of the model output tensor for softmax KD "
            "(student - {}, teacher - {}, number of dims {} should be == 2)"
            " - ignoring!".format(student_output.shape, teacher_output.shape, len(student_output.shape)))
        return torch.zeros([1]).to(student_output.device)

    student_log_probs = nn.functional.log_softmax(student_output / temperature, dim=1)
    teacher_probs = nn.functional.softmax(teacher_output / temperature, dim=1)
    negated_mean = -(student_log_probs * teacher_probs).mean()
    rescale = student_output.shape[1] * temperature * temperature
    return scale * negated_mean * rescale
def save_first_iteration_node(self, inputs: 'OperatorInput', node: DynamicGraphNode):
    """
    Finds and saves the "starting" points of an iteration, so that on the next iteration
    calls can be matched against them instead of adding new nodes for each iteration.

    A node is a "starting" point of an iteration scope if
     * at least one of its traced tensor inputs was created by a node outside of that scope,
     * or none of its tensor inputs is a TracedTensor.

    :param inputs: Inputs of the operator call that produced `node`.
    :param node: The graph node to examine.
    """
    op_exec_context = node.op_exec_context
    name = str(node)
    iter_scopes = op_exec_context.scope_in_model.get_iteration_scopes()
    if not iter_scopes:
        return

    for iter_scope in iter_scopes:
        if iter_scope not in self._first_iteration_nodes:
            self._first_iteration_nodes[iter_scope] = {}
        first_nodes = self._first_iteration_nodes[iter_scope]

        # Sort the inputs into plain tensors, traced tensors and everything else
        untraced_tensors = []
        traced_tensors = []
        non_tensors = []
        for op_input in inputs:
            value = op_input.getter()
            if not isinstance(value, Tensor):
                non_tensors.append(value)
            elif isinstance(value, TracedTensor):
                traced_tensors.append(value)
            else:
                untraced_tensors.append(value)

        is_first_node = False
        for traced in traced_tensors:
            creator_node = self.get_node_by_id(traced.tensor_meta.creator_id)
            creator_ctx = creator_node[DynamicGraph.OP_EXEC_CONTEXT_NODE_ATTR]
            creator_scopes = creator_ctx.scope_in_model.get_iteration_scopes()
            if iter_scope not in creator_scopes:
                is_first_node = True

        # All tensor inputs are untraced
        if len(untraced_tensors) == (len(inputs) - len(non_tensors)):
            is_first_node = True

        if not is_first_node:
            continue

        first_nodes[str(op_exec_context.op_address)] = node
        nncf_logger.debug('Found first iteration node: {} in scope: {}'.format(name, iter_scope))
def get_containing_module(self, node_name: NNCFNodeName) -> torch.nn.Module:
    """
    Returns the module that contains the node with the given name.

    The scope of the node is looked up in the compressed graph if one is available,
    falling back to the original graph when that lookup raises RuntimeError;
    without a compressed graph the original graph is used directly.

    :param node_name: Name of the node to look up.
    :return: The result of `get_module_by_scope` for the scope of the node.
    """
    if self._compressed_graph is not None:
        try:
            scope = self._compressed_graph.get_scope_by_node_name(node_name)
        except RuntimeError:
            # The message has a placeholder for the node name; it was previously logged
            # unformatted, i.e. with a literal "{}".
            nncf_logger.debug(
                "Node {} not found in compressed graph when trying to determine containing module, "
                "trying the original graph to see if the node was present there "
                "during graph building".format(node_name))
            scope = self._original_graph.get_scope_by_node_name(node_name)
    else:
        scope = self._original_graph.get_scope_by_node_name(node_name)
    return self.get_module_by_scope(scope)
def select_qconfigs(self, qp_id_vs_selected_qconfig_dict: Dict[QuantizationPointId, QuantizerConfig],
                    strict: bool = True) -> \
        SingleConfigQuantizerSetup:
    """
    Produces a single-config quantizer setup by picking one quantizer config per quantization point.

    :param qp_id_vs_selected_qconfig_dict: The selected config for every quantization point id;
        its keys must be exactly the ids of the quantization points in this setup.
    :param strict: If True, the selection for each point goes through the point's own
        `select_qconfig`; if False, the selected config is used as given.
    :return: The resulting `SingleConfigQuantizerSetup`, with the unified scale groups
        re-segregated according to the chosen configs.
    :raises ValueError: If the ids in the selection differ from the ids in this setup.
    """
    # Validate first, so that an inconsistent selection fails before any deep copy is made.
    if Counter(qp_id_vs_selected_qconfig_dict.keys()) != Counter(self.quantization_points.keys()):
        raise ValueError("The set of quantization points for a selection is inconsistent with quantization "
                         "points in the quantizer setup!")

    retval = SingleConfigQuantizerSetup()
    retval.unified_scale_groups = deepcopy(self.unified_scale_groups)
    retval.shared_input_operation_set_groups = deepcopy(self.shared_input_operation_set_groups)

    for qp_id, qp in self.quantization_points.items():
        if strict:
            retval.quantization_points[qp_id] = qp.select_qconfig(qp_id_vs_selected_qconfig_dict[qp_id])
        else:
            qconfig = qp_id_vs_selected_qconfig_dict[qp_id]
            retval.quantization_points[qp_id] = SingleConfigQuantizationPoint(
                qp.insertion_point, qconfig, qp.directly_quantized_operator_node_names)

    # Segregate the unified scale groups into sub-groups based on what exact config was chosen.
    for us_group in self.unified_scale_groups.values():
        per_channel_qids = set()
        per_tensor_qids = set()
        for us_qid in us_group:
            final_qconfig = retval.quantization_points[us_qid].qconfig
            if final_qconfig.per_channel:
                per_channel_qids.add(us_qid)
            else:
                per_tensor_qids.add(us_qid)

        if per_tensor_qids:
            # Detach the per-tensor points from the copied group and re-register them together
            for qid in per_tensor_qids:
                retval.remove_unified_scale_from_point(qid)
            retval.register_unified_scale_group(list(per_tensor_qids))

        for per_channel_qid in per_channel_qids:
            us_type = self._unified_scale_qpid_vs_type[per_channel_qid]
            if us_type is UnifiedScaleType.UNIFY_ONLY_PER_TENSOR:
                nncf_logger.debug("Per-channel quantizer config selected in a MultiConfigQuantizerSetup for a "
                                  "unified scale point that only supports per-tensor scale unification, disabling "
                                  "unified scales for this point.")
                retval.remove_unified_scale_from_point(per_channel_qid)
    return retval
def _set_binary_masks_for_pruned_layers_groupwise(self, pruning_level: float):
    """
    Sets the binary masks of the pruned layers, computing a separate threshold
    for every group of layers from the same pruning level.

    :param pruning_level: Pruning level applied to each group.
    """
    nncf_logger.debug('Setting new binary masks for pruned layers.')
    wrapped_layers = collect_wrapped_layers(self._model)
    graph = self._original_graph

    # 0. Drop the masks currently stored on the graph nodes
    for graph_node in graph.topological_sort():
        graph_node.data.pop('output_mask', None)

    # 1. Compute a mask per group and store it on the nodes of the group
    for group in self._pruned_layer_groups_info.get_all_clusters():
        # a. Cumulative importance of every filter in the group
        importance = self._calculate_filters_importance_in_group(group)
        filters_num = len(importance)

        # b. Threshold: the importance at the (clamped) index of the last pruned element
        pruned_count = get_rounded_pruned_element_number(importance.shape[0], pruning_level)
        threshold_index = min(pruned_count, filters_num - 1)
        threshold = sorted(importance)[threshold_index]

        # c. Mask for the group
        filter_mask = calculate_binary_mask(importance, threshold)
        for member in group.elements:
            member_node = graph.get_node_by_id(member.nncf_node_id)
            member_node.data['output_mask'] = TFNNCFTensor(filter_mask)

    # 2. Propagate the masks across the graph
    MaskPropagationAlgorithm(
        graph, TF_PRUNING_OPERATOR_METATYPES, TFNNCFPruningTensorProcessor).mask_propagation()

    # 3. Apply the masks to the model
    sorted_nodes = graph.topological_sort()
    for layer in wrapped_layers:
        layer_node = [candidate for candidate in sorted_nodes if layer.name == candidate.layer_name][0]
        if layer_node.data['output_mask'] is None:
            continue
        self._set_operation_masks([layer], layer_node.data['output_mask'].tensor)

    # Calculate actual flops and weights number with new masks
    self._update_benchmark_statistics()
def _set_binary_masks_for_pruned_modules_globally(self, pruning_level: float) -> None:
    """
    Set the binary mask values for layer groups according to the global pruning level.
    Filter importance scores in each group are merged into a single global list and a
    threshold value separating the pruning_level proportion of the least important filters
    in the model is calculated. Filters are pruned globally according to the threshold value.

    :param pruning_level: Proportion of filters (over all groups together) to prune.
    """
    nncf_logger.debug("Setting new binary masks for all pruned modules together.")
    filter_importances = []

    # 1. Calculate importances for all groups of filters
    for group in self.pruned_module_groups_info.get_all_clusters():
        filters_num = torch.tensor([get_filters_num(minfo.module) for minfo in group.elements])
        # All modules of a group are expected to have the same number of filters
        assert torch.all(filters_num == filters_num[0])
        device = group.elements[0].module.weight.device

        cumulative_filters_importance = torch.zeros(filters_num[0]).to(device)
        # Calculate cumulative importance for all filters in this group
        for minfo in group.elements:
            normalized_weight = self.weights_normalizer(minfo.module.weight)
            filters_importance = self.filter_importance(
                normalized_weight, minfo.module.target_weight_dim_for_compression)
            cumulative_filters_importance += filters_importance
        filter_importances.append(cumulative_filters_importance)

    # 2. Calculate one threshold for all weights
    importances = torch.cat(filter_importances)
    total_filters = importances.size(0)
    # Clamp the index: at pruning_level == 1.0 the raw index equals the number of filters,
    # which is one past the end. The groupwise variant applies the same clamp.
    threshold_index = min(int(pruning_level * total_filters), total_filters - 1)
    threshold = sorted(importances)[threshold_index]

    # 3. Set binary masks for filters in groups
    for i, group in enumerate(self.pruned_module_groups_info.get_all_clusters()):
        mask = calculate_binary_mask(filter_importances[i], threshold)
        for minfo in group.elements:
            pruning_module = minfo.operand
            pruning_module.binary_filter_pruning_mask = mask

    # Calculate actual flops and weights number with new masks
    self._update_benchmark_statistics()
def get_filters_num(layer: NNCFWrapper):
    """
    Returns the number of filters of a wrapped layer.

    :param layer: The wrapped layer; its metatype must define exactly one weight.
    :return: The size of the weight along its filter axis.
    :raises ValueError: If the layer metatype does not define exactly one weight.
    """
    metatype = get_keras_layer_metatype(layer)
    weight_definitions = metatype.weight_definitions
    if len(weight_definitions) != 1:
        raise ValueError(f'Could not calculate the number of filters '
                         f'for the layer {layer.layer.name}.')

    weight_attr = weight_definitions[0].weight_attr_name
    uses_beta = metatype is TFBatchNormalizationLayerMetatype and not layer.layer.scale
    if uses_beta:
        nncf_logger.debug(
            'Fused gamma parameter encountered in BatchNormalization layer. '
            'Use beta parameter instead to calculate the number of filters.')
        weight_attr = 'beta'

    filter_axis = get_filter_axis(layer, weight_attr)
    return layer.layer_weights[weight_attr].shape[filter_axis]
def apply_minmax_init(self, min_values, max_values, log_module_name: str = None):
    """
    Initializes the quantizer from collected minimum/maximum statistics.

    min_values and max_values must have the same shape as specified in self.scale_shape.

    :param min_values: Tensor of collected minimum values.
    :param max_values: Tensor of collected maximum values.
    :param log_module_name: Name used in log and error messages.
    :raises AttributeError: If `min_values` contains +inf or `max_values` contains -inf.
    """
    if self.initialized:
        nncf_logger.debug("Skipped initializing {} - loaded from checkpoint".format(log_module_name))
        return

    # +inf minimum / -inf maximum is treated as "no statistics were collected"
    statistics_missing = torch.any(torch.eq(min_values, np.inf)) or torch.any(torch.eq(max_values, -np.inf))
    if statistics_missing:
        raise AttributeError('Statistics is not collected for {}'.format(log_module_name))

    # Move the statistics to the device of this module's own parameters
    target_device = next(self.parameters()).device
    self._apply_minmax_init(min_values.to(target_device), max_values.to(target_device), log_module_name)
def _propagate_masks(self):
    """
    Propagates the pruning masks over the original graph and assigns the resulting
    masks to the pruning operands of the Batch/Group Norm modules.
    """
    nncf_logger.debug("Propagating pruning masks")

    # 1. Propagate masks for all modules
    graph = self.model.get_original_graph()
    init_output_masks_in_graph(graph, self.pruned_module_groups_info.get_all_nodes())
    propagator = MaskPropagationAlgorithm(graph, PT_PRUNING_OPERATOR_METATYPES, PTNNCFPruningTensorProcessor)
    propagator.mask_propagation()

    # 2. Set the masks for Batch/Group Norms; each module is handled only once
    handled_modules = []
    for node, pruning_block, node_module in self._pruned_norms_operators:
        if node_module in handled_modules:
            continue
        pruning_block.binary_filter_pruning_mask = node.data['output_mask'].tensor
        handled_modules.append(node_module)
def _set_binary_masks_for_pruned_modules_groupwise(
        self, pruning_level: Union[float, Dict[int, float]]) -> None:
    """
    Set the binary mask values according to groupwise pruning level.
    If pruning_level is a float, set the pruning level uniformly across groups.
    If pruning_level is a dict, set specific pruning levels corresponding to each group.

    :param pruning_level: A single pruning level, or a mapping from group id to pruning level.
    """
    nncf_logger.debug("Updating binary masks for pruned modules.")
    levels_are_per_group = isinstance(pruning_level, dict)

    for group in self.pruned_module_groups_info.get_all_clusters():
        if levels_are_per_group:
            level_for_group = pruning_level[group.id]
        else:
            level_for_group = pruning_level

        filters_num = torch.tensor([get_filters_num(minfo.module) for minfo in group.elements])
        # All modules of a group are expected to have the same number of filters
        assert torch.all(filters_num == filters_num[0])
        device = group.elements[0].module.weight.device
        importance_sum = torch.zeros(filters_num[0]).to(device)

        # 1. Accumulate the importance of every filter over the modules of the group
        for minfo in group.elements:
            importance_sum += self.filter_importance(
                minfo.module.weight, minfo.module.target_weight_dim_for_compression)

        # 2. Threshold: the importance at the (clamped) index of the last pruned element
        pruned_count = get_rounded_pruned_element_number(importance_sum.size(0), level_for_group)
        threshold_index = min(pruned_count, filters_num[0] - 1)
        threshold = sorted(importance_sum)[threshold_index]
        mask = calculate_binary_mask(importance_sum, threshold)

        # 3. Every pruning operand of the group shares the same mask
        for minfo in group.elements:
            minfo.operand.binary_filter_pruning_mask = mask

    # Calculate actual flops and weights number with new masks
    self._update_benchmark_statistics()
def _get_operations_with_attribute_values(self, attribute_name_vs_required_value: Dict[str, Any]) -> \
        Set[Type[OperatorMetatype]]:
    """
    Collects the metatypes of the HW config operations that have at least one of the
    given attributes set to the required value.

    :param attribute_name_vs_required_value: Attribute name -> value the attribute must have.
    :return: The union of the metatypes of all operations with a matching attribute.
    """
    result = set()
    for op_dict in self:
        if self.ATTRIBUTES_NAME not in op_dict:
            continue
        op_attributes = op_dict[self.ATTRIBUTES_NAME]
        for attr_name, attr_value in attribute_name_vs_required_value.items():
            # Test for presence before indexing: the value used to be read first, so an
            # operation that lacked the attribute raised KeyError on a plain mapping.
            if attr_name in op_attributes and op_attributes[attr_name] == attr_value:
                hw_config_op_name = op_dict.type
                metatypes = self._get_metatypes_for_hw_config_op(hw_config_op_name)
                if not metatypes:
                    nncf_logger.debug(
                        'Operation name {} in HW config is not registered in NNCF under any supported '
                        'operation metatype - will be ignored'.format(hw_config_op_name))
                result.update(metatypes)
    return result
def get_metatype_vs_quantizer_configs_map(
        self, for_weights=False
) -> Dict[Type[OperatorMetatype], Optional[List[QuantizerConfig]]]:
    """
    Maps every available operator metatype to the quantizer configs the HW config allows for it.

    :param for_weights: If True, the 'weights' section of each operation is read,
        otherwise the 'activations' section.
    :return: Metatype -> list of quantizer configs without duplicates. The value is None for
        ops unspecified in the HW config and an empty list for wildcard quantization ops.
    """
    section = 'weights' if for_weights else 'activations'

    # Everything starts as "unspecified"; entries are overwritten per HW config operation below
    result = dict.fromkeys(self._get_available_operator_metatypes_for_matching())

    for op_dict in self:
        op_name = op_dict.type
        metatypes = self._get_metatypes_for_hw_config_op(op_name)
        if not metatypes:
            nncf_logger.debug(
                'Operation name {} in HW config is not registered in NNCF under any supported operation '
                'metatype - will be ignored'.format(op_name))

        if self.QUANTIZATION_ALGORITHM_NAME in op_dict:
            allowed_subdicts = op_dict[self.QUANTIZATION_ALGORITHM_NAME][section]
        else:
            allowed_subdicts = []

        parsed_qconfs = [self.get_qconf_from_hw_config_subdict(subdict) for subdict in allowed_subdicts]
        # Drop duplicates while keeping the first-seen order
        unique_qconfs = list(OrderedDict.fromkeys(parsed_qconfs))

        for metatype in metatypes:
            result[metatype] = unique_qconfs
    return result
def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
    """
    Builds the filter pruning statistics of the model.

    :param quickly_collected_only: When True, the theoretical borderline dump
        (which is otherwise logged in debug mode) is skipped.
    :return: An `NNCFStatistics` object with the pruning statistics registered
        under the 'filter_pruning' key.
    """
    dump_borderline = not quickly_collected_only and is_debug()
    if dump_borderline:
        borderline = PrunedModelTheoreticalBorderline(
            self._pruned_layers_num,
            self._prunable_layers_num,
            self._max_prunable_flops,
            self._max_prunable_params,
            self.full_flops,
            self.full_params_num,
        )
        nncf_logger.debug(borderline.to_str())

    # One summary per distinct module scope; the first module info seen for a scope wins
    summary_by_layer = {}
    for minfo in self.pruned_module_groups_info.get_all_nodes():
        layer_name = str(minfo.module_scope)
        if layer_name in summary_by_layer:
            continue
        summary_by_layer[layer_name] = PrunedLayerSummary(
            layer_name,
            list(minfo.module.weight.size()),
            list(self.mask_shape(minfo)),
            self.pruning_level_for_mask(minfo),
        )

    # Called before the full_*/current_* counters below are read.
    self._update_benchmark_statistics()
    model_statistics = PrunedModelStatistics(
        self.full_flops,
        self.current_flops,
        self.full_params_num,
        self.current_params_num,
        self.full_filters_num,
        self.current_filters_num,
        list(summary_by_layer.values()),
    )

    pruning_statistics = FilterPruningStatistics(
        model_statistics,
        self.scheduler.current_pruning_level,
        self.scheduler.target_level,
        self.prune_flops,
    )

    result = NNCFStatistics()
    result.register('filter_pruning', pruning_statistics)
    return result
def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
    """
    Computes necessary model transformations (pruning mask insertions) to enable pruning.

    Besides building the layout, this method stores the converted graph in `self._graph`
    and the clusters of pruned layers in `self._pruned_layer_groups_info`.

    :param model: The original uncompressed model.
    :return: The instance of the `TransformationLayout` class containing
        a list of pruning mask insertions.
    """
    converter = TFModelConverterFactory.create(model)
    self._graph = converter.convert()
    groups_of_nodes_to_prune = self._pruning_node_selector.create_pruning_groups(self._graph)

    transformations = TFTransformationLayout()
    # Names of shared layers that already got their mask insertions registered
    shared_layers = set()

    self._pruned_layer_groups_info = Clusterization[PrunedLayerInfo](lambda x: x.layer_name)

    for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
        group_minfos = []
        for node in group.elements:
            layer_name = get_layer_identifier(node)
            layer = model.get_layer(layer_name)
            group_minfos.append(PrunedLayerInfo(node.node_name, layer_name, node.node_id,
                                                is_prunable_depthwise_conv(node)))

            # Add output_mask to elements to run mask_propagation
            # and detect spec_nodes that will be pruned.
            # It should be done for all elements of shared layer.
            node.data['output_mask'] = TFNNCFTensor(tf.ones(node.layer_attributes.out_channels))
            # A shared layer gets its mask insertions only once, for the first node seen
            if layer_name in shared_layers:
                continue
            if node.is_shared():
                shared_layers.add(layer_name)
            # Check that we need to prune weights in this op
            assert self._is_pruned_layer(layer)
            nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

            _, layer_info = converter.get_layer_info_for_node(node.node_name)
            for weight_def in node.metatype.weight_definitions:
                transformations.register(
                    self._get_insertion_command_binary_mask(
                        layer_info.layer_name, weight_def.weight_attr_name)
                )
            # The bias mask is added only if the metatype names a bias attribute
            # and the layer actually has it
            if node.metatype.bias_attr_name is not None and \
                    getattr(layer, node.metatype.bias_attr_name) is not None:
                transformations.register(
                    self._get_insertion_command_binary_mask(
                        layer_info.layer_name, node.metatype.bias_attr_name)
                )

        cluster = Cluster[PrunedLayerInfo](i, group_minfos, [n.node_id for n in group.elements])
        self._pruned_layer_groups_info.add_cluster(cluster)

    # Propagating masks across the graph to detect spec_nodes that will be pruned
    mask_propagator = MaskPropagationAlgorithm(self._graph, TF_PRUNING_OPERATOR_METATYPES,
                                               TFNNCFPruningTensorProcessor)
    mask_propagator.mask_propagation()

    # Add masks for all spec modules, because prunable batchnorm layers can be determined
    # at the moment of mask propagation
    types_spec_layers = [TFBatchNormalizationLayerMetatype] \
        if self._prune_batch_norms else []

    spec_nodes = self._graph.get_nodes_by_metatypes(types_spec_layers)
    for spec_node in spec_nodes:
        layer_name = get_layer_identifier(spec_node)
        layer = model.get_layer(layer_name)
        if spec_node.data['output_mask'] is None:
            # Skip elements that will not be pruned
            continue
        if layer_name in shared_layers:
            continue
        if spec_node.is_shared():
            shared_layers.add(layer_name)
        nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

        _, layer_info = converter.get_layer_info_for_node(spec_node.node_name)
        for weight_def in spec_node.metatype.weight_definitions:
            # A BatchNormalization layer with `scale` disabled gets no mask on 'gamma'
            if spec_node.metatype is TFBatchNormalizationLayerMetatype \
                    and not layer.scale and weight_def.weight_attr_name == 'gamma':
                nncf_logger.debug('Fused gamma parameter encountered in BatchNormalization layer. '
                                  'Do not add mask to it.')
                continue

            transformations.register(
                self._get_insertion_command_binary_mask(
                    layer_info.layer_name, weight_def.weight_attr_name)
            )
        # NOTE(review): unlike the prunable-layer branch above, the bias mask is registered here
        # without checking that the bias attribute is set on the layer - confirm this is intended.
        transformations.register(
            self._get_insertion_command_binary_mask(
                layer_info.layer_name, spec_node.metatype.bias_attr_name)
        )
    return transformations
def init_with_key_list(self, key_list: List):
    """
    Resets the call counters so that every key of `key_list` starts at zero.

    :param key_list: The keys to track.
    """
    self.call_counts = dict.fromkeys(key_list, 0)
    registered = len(self.call_counts)
    nncf_logger.debug("{} tracker: registered {} entries".format(self.name, registered))
def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
        self, target_flops_pruning_rate: float):
    """
    Prunes least important filters one-by-one until target FLOPs pruning rate is achieved.
    Filters are sorted by filter importance score.

    Note that `self._pruning_quotas` is decremented for every pruned filter, also when the
    target cannot be reached and the method raises.

    :param target_flops_pruning_rate: Proportion of the full model FLOPs that must be pruned.
    :raises RuntimeError: If the target FLOPs cannot be reached within the pruning quotas.
    """
    nncf_logger.debug('Setting new binary masks for pruned layers.')
    target_flops = self.full_flops * (1 - target_flops_pruning_rate)
    wrapped_layers = collect_wrapped_layers(self._model)
    # Group id -> per-filter mask values. This must be a mapping: assigning by group id
    # into an empty list raises IndexError. The values are plain Python lists because
    # tf.Tensor objects do not support the item assignment done in step 2.
    masks = {}

    nncf_sorted_nodes = self._original_graph.topological_sort()
    for layer in wrapped_layers:
        nncf_node = [n for n in nncf_sorted_nodes
                     if layer.layer.name == get_original_name(n.node_name)][0]
        nncf_node.data['output_mask'] = tf.ones(get_filters_num(layer))

    # 1. Calculate importances for all groups of filters. Initialize masks.
    filter_importances = []
    group_indexes = []
    filter_indexes = []
    for group in self._pruned_layer_groups_info.get_all_clusters():
        cumulative_filters_importance = \
            self._calculate_filters_importance_in_group(group, wrapped_layers)
        filter_importances.extend(cumulative_filters_importance)
        filters_num = len(cumulative_filters_importance)
        group_indexes.extend([group.id] * filters_num)
        filter_indexes.extend(range(filters_num))
        masks[group.id] = [1.0] * filters_num

    # 2. Prune filters one at a time, least important first, until the FLOPs target is met
    tmp_in_channels = self._layers_in_channels.copy()
    tmp_out_channels = self._layers_out_channels.copy()
    sorted_importances = sorted(zip(filter_importances, group_indexes, filter_indexes),
                                key=lambda x: x[0])
    for _, group_id, filter_index in sorted_importances:
        if self._pruning_quotas[group_id] == 0:
            continue
        masks[group_id][filter_index] = 0.0
        self._pruning_quotas[group_id] -= 1

        # Update input/output shapes of pruned nodes
        group = self._pruned_layer_groups_info.get_cluster_by_id(group_id)
        for node in group.nodes:
            tmp_out_channels[node.layer_name] -= 1
        for node_name in self._next_nodes[group_id]:
            tmp_in_channels[node_name] -= 1

        flops = sum(
            count_flops_for_nodes(self._original_graph,
                                  self._layers_in_shapes,
                                  self._layers_out_shapes,
                                  input_channels=tmp_in_channels,
                                  output_channels=tmp_out_channels,
                                  conv_op_types=GENERAL_CONV_LAYERS,
                                  linear_op_types=LINEAR_LAYERS).values())
        if flops <= target_flops:
            # 3. Add masks to the graph and propagate them
            for group in self._pruned_layer_groups_info.get_all_clusters():
                # Same dtype as the tf.ones() masks assigned to the other nodes above
                group_mask = tf.constant(masks[group.id], dtype=tf.float32)
                for node in group.nodes:
                    nncf_node = self._original_graph.get_node_by_id(node.nncf_node_id)
                    nncf_node.data['output_mask'] = group_mask

            mask_propagator = MaskPropagationAlgorithm(
                self._original_graph, TF_PRUNING_OPERATOR_METATYPES)
            mask_propagator.mask_propagation()

            # 4. Set binary masks to the model
            self.current_flops = flops
            nncf_sorted_nodes = self._original_graph.topological_sort()
            for layer in wrapped_layers:
                nncf_node = [n for n in nncf_sorted_nodes
                             if layer.layer.name == get_original_name(n.node_name)][0]
                if nncf_node.data['output_mask'] is not None:
                    self._set_operation_masks([layer], nncf_node.data['output_mask'])
            return

    raise RuntimeError(
        f'Unable to prune model to required flops pruning rate:'
        f' {target_flops_pruning_rate}')
def wrapped(*args, **kwargs):
    """
    Calls `operator`, routing the call through the current tracing context when one is active.

    Without a context, when the context is already inside an operator call, or
    when tracing is off, the operator is invoked directly with the given arguments.
    Otherwise pre-hooks, graph node lookup/registration, output tensor tracing and
    post-hooks are executed around the operator call.
    """
    ctx = get_current_context()
    if not ctx or getattr(ctx, 'in_operator', False) or not ctx.is_tracing:
        return operator(*args, **kwargs)

    ctx.in_operator = True
    # The flag has to be cleared on every exit path. In particular, the __repr__ call
    # made during IDE debug to display tensor contents appears to throw an exception
    # instead of exiting properly; without the reset the context would stay stuck in
    # the "in_operator == True" state.
    try:
        if operator_info.skip_trace:
            result = operator(*args, **kwargs)
        elif ctx.is_forwarding:
            from nncf.torch.dynamic_graph.trace_functions import forward_trace_only
            result = forward_trace_only(operator, *args, **kwargs)
        else:
            op_name = operator_info.name
            op_address = ctx.get_caller_context(op_name)
            node = None
            layer_attrs = None
            ignored_algos = []

            # Collect module attributes, if required
            if ctx.trace_dynamic_graph and op_name in OP_NAMES_REQUIRING_MODULE_ATTRS:
                curr_module = ctx.get_current_module()
                if curr_module is None:
                    raise RuntimeError("Operation {} requires module attributes, "
                                       "but it was executed outside any module".format(op_name))
                layer_attrs = _get_layer_attributes(curr_module, op_name)
                if isinstance(curr_module, _NNCFModuleMixin):
                    ignored_algos = deepcopy(curr_module.ignored_algorithms)

            ctx.register_operator_call(op_address.operator_name, op_address.scope_in_model)
            processed_input = ctx.execute_pre_hooks(op_address, OperatorInput(list(args), kwargs))

            if ctx.trace_dynamic_graph:
                tensor_metas = make_tensor_metas(processed_input)
                node = ctx.find_operator_node(tensor_metas, op_address)

            # The operator runs on the arguments as returned by the pre-hooks.
            result = operator(*tuple(processed_input.op_args), **processed_input.op_kwargs)

            if isinstance(result, type(NotImplemented)):
                nncf_logger.debug("Operation {} returned NotImplemented".format(op_name))
            elif ctx.trace_dynamic_graph and node is None:
                node = ctx.maybe_add_node(processed_input, tensor_metas, op_address,
                                          layer_attrs, ignored_algos)

            if is_debug() and ctx.trace_dynamic_graph and node is not None:
                ctx.register_node_call(node)
            result = ctx.execute_post_hooks(op_address, trace_tensors(result, node))
    finally:
        ctx.in_operator = False
    return result
def add_node(self,
             op_exec_context: OperationExecutionContext,
             inputs,
             layer_attrs: BaseLayerAttributes = None,
             ignored_algorithms: List[str] = None,
             is_in_iteration_scope: bool = False) -> DynamicGraphNode:
    """
    Registers a new operation node in the underlying networkx graph and connects it to its producers.

    The node id is the current size of the id-to-key mapping; the node key is
    "<id> <scope>/<operator name>". One edge is added for every input tensor meta
    that has a known creator node. Nodes without any such input are additionally
    stored in `self._inputless_nodes`.

    :param op_exec_context: Execution context of the operation; supplies the scope,
        the operator name and the input tensor metas.
    :param inputs: Not used by this method.
    :param layer_attrs: Optional layer attributes stored on the node.
    :param ignored_algorithms: Optional list stored on the node; an empty list is
        stored when None is passed.
    :param is_in_iteration_scope: Flag stored on the node as is.
    :return: The DynamicGraphNode built from the attributes stored in the graph.
    """
    node_id = len(self._node_id_to_key_dict)
    uri = '/'.join((str(op_exec_context.scope_in_model), op_exec_context.operator_name))
    node_key = '{idx} {uri}'.format(uri=uri, idx=node_id)

    nncf_logger.debug("New node added to NNCF graph: {}".format(node_key))
    self._node_id_to_key_dict[node_id] = node_key

    node_attrs = {
        DynamicGraph.ID_NODE_ATTR: node_id,
        DynamicGraph.KEY_NODE_ATTR: node_key,
        DynamicGraph.OP_EXEC_CONTEXT_NODE_ATTR: op_exec_context,
        DynamicGraph.IS_IN_ITERATION_SCOPE_NODE_ATTR: is_in_iteration_scope,
    }
    if layer_attrs is not None:
        node_attrs[DynamicGraph.LAYER_ATTRIBUTES] = layer_attrs
    node_attrs[DynamicGraph.IGNORED_ALGOS_NODE_ATTR] = \
        ignored_algorithms if ignored_algorithms is not None else []
    self._nx_graph.add_node(node_key, **node_attrs)

    has_traced_inputs = False
    for input_port_id, meta in enumerate(op_exec_context.tensor_metas):
        # Inputs that are not tensors produced by a known node do not form edges.
        if meta is None or meta.creator_id is None:
            continue
        parent_key = self._node_id_to_key_dict[meta.creator_id]
        self._nx_graph.add_edge(parent_key, node_key)
        has_traced_inputs = True

        edge_data = self._nx_graph.edges[parent_key, node_key]
        edge_data[DynamicGraph.ACTIVATION_SHAPE_EDGE_ATTR] = meta.shape
        edge_data[DynamicGraph.INPUT_PORT_ID_EDGE_ATTR] = input_port_id
        edge_data[DynamicGraph.OUTPUT_PORT_ID_EDGE_ATTR] = meta.index
        edge_data[DynamicGraph.ACTIVATION_DTYPE_EDGE_ATTR] = meta.dtype

    # Build the node object from what has actually been stored in the graph.
    stored = self._nx_graph.nodes[node_key]
    node = DynamicGraphNode(
        node_id=stored[DynamicGraph.ID_NODE_ATTR],
        node_key=stored[DynamicGraph.KEY_NODE_ATTR],
        layer_attributes=stored.get(DynamicGraph.LAYER_ATTRIBUTES),
        op_exec_context=stored[DynamicGraph.OP_EXEC_CONTEXT_NODE_ATTR],
        ignored_algorithms=stored[DynamicGraph.IGNORED_ALGOS_NODE_ATTR],
        is_in_iteration_scope=stored[DynamicGraph.IS_IN_ITERATION_SCOPE_NODE_ATTR])

    if not has_traced_inputs:
        self._inputless_nodes[node_key] = node
    return node
def wrap_operator(operator, operator_info: 'PatchedOperatorInfo'):
    """
    Wraps the input callable object (`operator`) so that calls to it can be tracked by the
    currently set global TracingContext.

    Wrapped calls can be intercepted, have their arguments and return values modified by
    hooks, and, for functions that correspond to operations on tensors in a DNN, have their
    position and address in the model control flow graph established.

    :param operator: A callable object to be wrapped.
    :param operator_info: An informational struct (PatchedOperatorInfo) containing the
        specifics of wrapping the `operator` in question.
    :return: The wrapped version of `operator` that, without a TracingContext, performs
        functionally the same as the unwrapped version, but within a TracingContext is able
        to be tracked and hooked. If `operator` is already wrapped, it is returned unchanged.
    """
    # do not wrap function twice
    already_wrapped_op = getattr(operator, '_original_op', None)
    if already_wrapped_op is not None:
        nncf_logger.debug("Operator: {} is already wrapped".format(already_wrapped_op.__name__))
        return operator

    def wrapped(*args, **kwargs):
        ctx = get_current_context()
        if not ctx or getattr(ctx, 'in_operator', False) or not ctx.is_tracing:
            return operator(*args, **kwargs)

        ctx.in_operator = True
        # The flag has to be cleared on every exit path. In particular, the __repr__ call
        # made during IDE debug to display tensor contents appears to throw an exception
        # instead of exiting properly; without the reset the context would stay stuck in
        # the "in_operator == True" state.
        try:
            if operator_info.skip_trace:
                result = operator(*args, **kwargs)
            elif ctx.is_forwarding:
                from nncf.torch.dynamic_graph.trace_functions import forward_trace_only
                result = forward_trace_only(operator, *args, **kwargs)
            else:
                op_name = operator_info.name
                op_address = ctx.get_caller_context(op_name)
                node = None
                layer_attrs = None
                ignored_algos = []

                # Collect module attributes, if required
                if ctx.trace_dynamic_graph and op_name in OP_NAMES_REQUIRING_MODULE_ATTRS:
                    curr_module = ctx.get_current_module()
                    if curr_module is None:
                        raise RuntimeError("Operation {} requires module attributes, "
                                           "but it was executed outside any module".format(op_name))
                    layer_attrs = _get_layer_attributes(curr_module, op_name)
                    if isinstance(curr_module, _NNCFModuleMixin):
                        ignored_algos = deepcopy(curr_module.ignored_algorithms)

                ctx.register_operator_call(op_address.operator_name, op_address.scope_in_model)
                processed_input = ctx.execute_pre_hooks(op_address,
                                                        OperatorInput(list(args), kwargs))

                if ctx.trace_dynamic_graph:
                    tensor_metas = make_tensor_metas(processed_input)
                    node = ctx.find_operator_node(tensor_metas, op_address)

                # The operator runs on the arguments as returned by the pre-hooks.
                result = operator(*tuple(processed_input.op_args), **processed_input.op_kwargs)

                if isinstance(result, type(NotImplemented)):
                    nncf_logger.debug("Operation {} returned NotImplemented".format(op_name))
                elif ctx.trace_dynamic_graph and node is None:
                    node = ctx.maybe_add_node(processed_input, tensor_metas, op_address,
                                              layer_attrs, ignored_algos)

                if is_debug() and ctx.trace_dynamic_graph and node is not None:
                    ctx.register_node_call(node)
                result = ctx.execute_post_hooks(op_address, trace_tensors(result, node))
        finally:
            ctx.in_operator = False
        return result

    # pylint: disable=protected-access
    wrapped._original_op = operator
    wrapped._operator_namespace = operator_info.operator_namespace
    return wrapped
def apply_init(self) -> SingleConfigQuantizerSetup:
    """
    Chooses a bitwidth configuration for weight quantizers based on Hessian traces and
    returns the corresponding quantizer setup.

    The model is moved to `self._init_device` for the duration of the computation and is
    always moved back to the device of its first parameter afterwards, including on early
    returns and when an exception is raised.

    As a side effect, the chosen bitwidth per scope is logged and dumped to
    `bitwidth_per_scope.json` inside DEBUG_LOG_DIR.

    :return: The quantizer setup for the chosen bitwidth sequence. The setup for the
        current state of the algorithm is returned when there are no weight quantizers or
        when no bitwidth configuration survives filtering. None is returned when no
        configuration is produced for the calculated traces order at all.
    :raises RuntimeError: If Hessian traces could not be calculated.
    :raises AttributeError: If `self._compression_ratio` is outside the range achievable
        with the remaining bitwidth configurations.
    """
    if not self._weight_quantizations_by_execution_order:
        return self._algo.get_quantizer_setup_for_current_state()

    original_device = next(self._model.parameters()).device
    self._model.to(self._init_device)
    # try/finally guarantees the model does not stay on the init device when the
    # procedure bails out early or fails.
    try:
        traces_per_layer = self._calc_traces(self._criterion_fn, self._criterion,
                                             self._iter_number, self._tolerance)
        if not traces_per_layer:
            raise RuntimeError('Failed to calculate hessian traces!')

        traces_order = traces_per_layer.traces_order
        weight_qconfig_sequences_in_trace_order, covering_qconfig_sequences = \
            self.get_qconfig_sequences_constrained_by_traces_order(traces_order)

        weight_quantizer_ids_in_execution_order = list(
            self._weight_quantizations_by_execution_order.keys())

        if not weight_qconfig_sequences_in_trace_order:
            warnings.warn(
                'All bitwidths configurations are incompatible with HW Config!',
                RuntimeWarning)
            return None

        weight_qconfig_sequences_in_trace_order = \
            self._filter_qconfig_sequences_by_excessive_bitwidth(weight_qconfig_sequences_in_trace_order)

        if self._bitwidth_assignment_mode == BitwidthAssignmentMode.STRICT:
            weight_qconfig_sequences_in_trace_order = \
                self._filter_qconfig_sequences_by_grouped_weight_quantizers(
                    weight_qconfig_sequences_in_trace_order,
                    weight_quantizer_ids_in_execution_order,
                    self._groups_of_adjacent_quantizers,
                    traces_order)
        if not weight_qconfig_sequences_in_trace_order:
            warnings.warn(
                'No bitwidths configurations are left after removing inconsistent groups of weight quantizers'
                ' with adjacent activation quantizers!', RuntimeWarning)
            return self._algo.get_quantizer_setup_for_current_state()

        compression_ratio_per_qconfig = self.get_compression_ratio_per_qconfig_sequence(
            weight_qconfig_sequences_in_trace_order, traces_order)
        min_ratio = min(compression_ratio_per_qconfig)
        max_ratio = max(compression_ratio_per_qconfig)
        if not min_ratio <= self._compression_ratio <= max_ratio:
            raise AttributeError(
                'Invalid compression ratio={}. Should be within range [{:.3f}, {:.3f}]'
                .format(self._compression_ratio, min_ratio, max_ratio))

        perturbations, weight_observers = self.calc_quantization_noise(
            covering_qconfig_sequences, traces_order)

        metric_per_qconfig_sequence = self.calc_hawq_metric_per_qconfig_sequence(
            weight_qconfig_sequences_in_trace_order, perturbations,
            traces_per_layer, self._init_device)

        qconfig_sequence_index = self.choose_qconfig_sequence(
            metric_per_qconfig_sequence, compression_ratio_per_qconfig,
            self._compression_ratio)
        chosen_qconfig_sequence_in_traces_order = weight_qconfig_sequences_in_trace_order[
            qconfig_sequence_index]
        chosen_qconfig_sequence_in_execution_order = traces_order.get_execution_order_configs(
            chosen_qconfig_sequence_in_traces_order)
        bitwidth_sequence = [
            qconfig.num_bits
            for qconfig in chosen_qconfig_sequence_in_execution_order
        ]
        nncf_logger.info(
            'Chosen HAWQ bitwidth sequence with ratio={:.2f}, bitwidth per weightable layer={}'
            .format(compression_ratio_per_qconfig[qconfig_sequence_index],
                    bitwidth_sequence))
        nncf_logger.debug(
            'Order of the weightable layers in the HAWQ bitwidth sequence (in descending order of average'
            ' Hessian traces) ={}'.format(traces_order))

        final_quantizer_setup = self.get_quantizer_setup_for_qconfig_sequence(
            chosen_qconfig_sequence_in_traces_order, traces_order)
        if is_debug() or self._dump_hawq_data:
            hawq_debugger = HAWQDebugger(
                weight_qconfig_sequences_in_trace_order, perturbations,
                weight_observers, traces_per_layer, self._bitwidths)
            hawq_debugger.dump_metric_MB(metric_per_qconfig_sequence)
            hawq_debugger.dump_metric_flops(metric_per_qconfig_sequence,
                                            compression_ratio_per_qconfig,
                                            qconfig_sequence_index)
            hawq_debugger.dump_avg_traces()
            hawq_debugger.dump_density_of_quantization_noise()
            hawq_debugger.dump_perturbations_ratio()
            new_ctrl, new_model = self._algo.apply_new_quantizer_setup(
                final_quantizer_setup)
            groups_of_adjacent_quantizers = new_ctrl.groups_of_adjacent_quantizers
            hawq_debugger.dump_bitwidth_graph(new_ctrl, new_model,
                                              groups_of_adjacent_quantizers)

        # Computed once and reused for both the log message and the JSON dump.
        bitwidth_per_scope = self.get_bitwidth_per_scope(final_quantizer_setup)
        str_bw = [str(element) for element in bitwidth_per_scope]
        nncf_logger.info('\n'.join(
            ['\n\"bitwidth_per_scope\": [', ',\n'.join(str_bw), ']']))

        from nncf.common.utils.debug import DEBUG_LOG_DIR
        Path(DEBUG_LOG_DIR).mkdir(parents=True, exist_ok=True)
        with safe_open(Path(DEBUG_LOG_DIR) / 'bitwidth_per_scope.json', "w") as outfile:
            json.dump({'bitwidth_per_scope': bitwidth_per_scope},
                      outfile,
                      indent=4,
                      sort_keys=False)
        return final_quantizer_setup
    finally:
        self._model.to(original_device)