Example #1
    def statistics(self,
                   quickly_collected_only: bool = False) -> NNCFStatistics:
        if not quickly_collected_only and is_debug():
            stats = PrunedModelTheoreticalBorderline(self._pruned_layers_num,
                                                     self._prunable_layers_num,
                                                     self._max_prunable_flops,
                                                     self._max_prunable_params,
                                                     self.full_flops,
                                                     self.full_params_num)

            nncf_logger.debug(stats.to_str())

        pruned_layers_summary = self._calculate_pruned_layers_summary()
        self._update_benchmark_statistics()
        model_statistics = PrunedModelStatistics(
            self.full_flops, self.current_flops, self.full_params_num,
            self.current_params_num, self.full_filters_num,
            self.current_filters_num, pruned_layers_summary)

        stats = FilterPruningStatistics(model_statistics,
                                        self.scheduler.current_pruning_level,
                                        self.scheduler.target_level,
                                        self.prune_flops)

        nncf_stats = NNCFStatistics()
        nncf_stats.register('filter_pruning', stats)
        return nncf_stats
Example #2
    def __call__(self, saved_inputs: List[TensorMeta],
                 actual_inputs: List[TensorMeta],
                 tm_comparators: List[TensorMetaComparator]) -> bool:
        if saved_inputs is None and actual_inputs:
            return False

        matched_with_unexpected_tensors = False
        for saved_input, actual_input in zip(saved_inputs, actual_inputs):
            if saved_input is None and actual_input is None:
                continue
            if (saved_input is None) and (actual_input is not None):
                # torch.Tensor.size() seems to return ints when not tracing ONNX
                # and tensors when tracing ONNX. This breaks input-based node matching whenever
                # the torch.Tensor.size() return value is passed into an NNCF-traced operation (such as `view`),
                # because at graph building time the operation expected to see ints as args, whereas now it sees tensors.
                # To mitigate this, we only match inputs against the positions that had tensors at build time
                # and disregard the rest of the argument positions.
                matched_with_unexpected_tensors = True
                continue
            if (saved_input is not None) and (actual_input is None):
                return False
            for tm_comparator in tm_comparators:
                if not tm_comparator(saved_input, actual_input):
                    return False
        if matched_with_unexpected_tensors:
            nncf_logger.debug(
                "Had to match a node to an op which has tensors at positions where there were no tensors "
                "at graph building time:\nNode input metas: {}, but op input metas: {}"
                .format(saved_inputs, actual_inputs))
        return True
Example #3
    def prepare_for_export(self):
        """
        Applies pruning masks to layer weights before exporting the model to ONNX.
        """
        self._propagate_masks()

        pruned_layers_stats = self.get_stats_for_pruned_modules()
        nncf_logger.debug('Pruned layers statistics: \n%s',
                          pruned_layers_stats.draw())
Example #4
    def _set_binary_masks_for_pruned_layers_globally(self,
                                                     pruning_level: float):
        """
        Sets the binary mask values for layer groups according to the global pruning level.
        Filter importance scores in each group are merged into a single global list and a
        threshold value separating the pruning_level proportion of the least important filters
        in the model is calculated. Filters are pruned globally according to the threshold value.
        """
        nncf_logger.debug(
            'Setting new binary masks for all pruned modules together.')
        filter_importances = {}
        wrapped_layers = collect_wrapped_layers(self._model)

        # 0. Remove masks at the elements of the NNCFGraph
        for node in self._original_graph.topological_sort():
            node.data.pop('output_mask', None)

        # 1. Calculate masks
        # a. Calculate importances for all groups of filters
        for group in self._pruned_layer_groups_info.get_all_clusters():
            cumulative_filters_importance = self._calculate_filters_importance_in_group(
                group)
            filter_importances[group.id] = cumulative_filters_importance

        # b. Calculate one threshold for all weights
        importances = tf.concat(list(filter_importances.values()), 0)
        threshold = sorted(importances)[int(pruning_level *
                                            importances.shape[0])]

        # c. Initialize masks
        for group in self._pruned_layer_groups_info.get_all_clusters():
            filter_mask = calculate_binary_mask(filter_importances[group.id],
                                                threshold)
            for node in group.elements:
                nncf_node = self._original_graph.get_node_by_id(
                    node.nncf_node_id)
                nncf_node.data['output_mask'] = TFNNCFTensor(filter_mask)

        # 2. Propagate masks across the graph
        mask_propagator = MaskPropagationAlgorithm(
            self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
            TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # 3. Apply masks to the model
        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes if layer.name == n.layer_name
            ][0]
            if nncf_node.data['output_mask'] is not None:
                self._set_operation_masks([layer],
                                          nncf_node.data['output_mask'].tensor)

        # Calculate actual flops with new masks
        self._update_benchmark_statistics()
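
The global-threshold step above (concatenate the per-group importances, take the value at the pruning_level position of the sorted list, and binarize against it) can be checked on toy numbers. Below is a minimal plain-Python sketch of that logic, assuming calculate_binary_mask keeps the filters whose importance is not below the threshold:

    # Toy illustration of the global threshold selection used above.
    pruning_level = 0.5
    filter_importances = {
        0: [0.9, 0.1, 0.4],   # pruning group 0
        1: [0.3, 0.8],        # pruning group 1
    }
    importances = [v for group in filter_importances.values() for v in group]
    threshold = sorted(importances)[int(pruning_level * len(importances))]
    masks = {gid: [1.0 if imp >= threshold else 0.0 for imp in group]
             for gid, group in filter_importances.items()}
    print(threshold)  # 0.4 - the 2 least important of the 5 filters fall below it
    print(masks)      # {0: [1.0, 0.0, 1.0], 1: [0.0, 1.0]}
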
Example #5
 def remove_unified_scale_from_point(self, qp_id: QuantizationPointId):
     gid = self.get_unified_scale_group_id(qp_id)
     if gid is None:
         nncf_logger.debug("Attempted to remove QP id {} from associated unified scale group, but the QP"
                           "is not in any unified scale group - ignoring.".format(qp_id))
         return
     self.unified_scale_groups[gid].discard(qp_id)
     if not self.unified_scale_groups[gid]:
         nncf_logger.debug("Removed last entry from a unified scale group {} - removing group itself".format(gid))
         self.unified_scale_groups.pop(gid)
Example #6
 def _get_metatypes_for_hw_config_op(
         self, hw_config_op: HWConfigOpName) -> Set[Type[OperatorMetatype]]:
     metatypes = set()
     for op_meta in self._get_available_operator_metatypes_for_matching():
         if hw_config_op in op_meta.hw_config_names:
             metatypes.add(op_meta)
     if not metatypes:
         nncf_logger.debug(
             'Operation name {} in HW config is not registered in NNCF under any supported '
             'operation metatype - will be ignored'.format(hw_config_op))
     return metatypes
Example #7
 def kd_loss_fn(teacher_output: torch.Tensor,
                student_output: torch.Tensor):
     mse = torch.nn.MSELoss()
     if len(teacher_output.shape) < 2:
         nncf_logger.debug(
             "Incompatible number of dimensions of the model output tensor for MSE KD "
             "(student - {}, teacher - {}, number of dims {} should be > 1)"
             " (most likely loss) - ignoring!".format(
                 student_output.shape, teacher_output.shape,
                 len(student_output.shape)))
         return torch.zeros([1]).to(student_output.device)
     return scale * mse(teacher_output, student_output)
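
Note that `scale` is not defined inside this snippet: in NNCF the knowledge-distillation loss function is built inside an enclosing scope that supplies the scaling coefficient. A minimal sketch of that pattern, with a hypothetical factory name, could look like:

    import torch

    def make_mse_kd_loss(scale: float = 1.0):  # hypothetical factory, not the NNCF API
        def kd_loss_fn(teacher_output: torch.Tensor, student_output: torch.Tensor):
            if teacher_output.dim() < 2:
                # the output is most likely a scalar loss value - nothing to distill
                return torch.zeros([1]).to(student_output.device)
            return scale * torch.nn.functional.mse_loss(teacher_output, student_output)
        return kd_loss_fn
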
Example #8
    def _cross_match_version_agnostic_names(normalized_keys_to_load: List[str],
                                            normalized_model_keys: List[str]) -> Dict[str, str]:
        """
        Handles the situation where the normalized_keys_to_load contain legacy version-agnostic names
        of operations, such as `RELU`.

        :param normalized_keys_to_load: A list of keys in the checkpoint, potentially with version-agnostic names
        :param normalized_model_keys: A list of keys in the model, without version-agnostic names.
        :return: A mapping of the checkpoint key to a model key that matches version-agnostic names with their
            torch-specific counterparts.
        """
        version_agnostic_to_specific_names = {"RELU": {"relu", "relu_"}}
        retval = {}
        processed_keys_to_load = normalized_keys_to_load

        for model_key in normalized_model_keys:
            for agnostic_op_name, specific_op_name_set in version_agnostic_to_specific_names.items():
                matches_for_curr_agnostic_op_name = []
                has_specific_op_name = False
                for specific_op_name in specific_op_name_set:
                    # Take care not to replace matches against the class names.
                    # The op names in an existing checkpoint can only appear in external quantizer keys,
                    # e.g. external_quantizers.ResNet/ReLU[relu]/relu_0.signed_tensor, so only the last
                    # slash-separated portion of the key is rewritten below.
                    slash_split_str = model_key.split('/')
                    last_portion = slash_split_str[-1]
                    if specific_op_name in last_portion:
                        last_portion = last_portion.replace(specific_op_name, agnostic_op_name, 1)
                        has_specific_op_name = True
                    slash_split_str[-1] = last_portion
                    agnostic_version_of_model_key = '/'.join(slash_split_str)
                    processed_agnostic_version_of_model_key = agnostic_version_of_model_key
                    if processed_agnostic_version_of_model_key in processed_keys_to_load:
                        idx = processed_keys_to_load.index(processed_agnostic_version_of_model_key)
                        matches_for_curr_agnostic_op_name.append(normalized_keys_to_load[idx])

                if not has_specific_op_name:
                    if matches_for_curr_agnostic_op_name:
                        checkpoint_matched_key = next(iter(matches_for_curr_agnostic_op_name))
                        retval[checkpoint_matched_key] = model_key
                elif len(matches_for_curr_agnostic_op_name) == 1:
                    checkpoint_matched_key = next(iter(matches_for_curr_agnostic_op_name))
                    retval[checkpoint_matched_key] = model_key
                elif len(matches_for_curr_agnostic_op_name) == 0:
                    nncf_logger.debug("Failed to match a version-specific key: {}".format(model_key))
                elif len(matches_for_curr_agnostic_op_name) > 1:
                    nncf_logger.debug("More than one match for the version specific key: {}\n"
                                      "Matches:\n"
                                      "{}".format(model_key, ', '.join(matches_for_curr_agnostic_op_name)))

        return retval
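
A usage sketch with hypothetical key strings, patterned after the comment inside the method: a checkpoint produced by an older NNCF version stores the version-agnostic RELU name, while the current model uses the torch-specific relu name, and the method is expected to map the former onto the latter.

    # Hypothetical inputs and the mapping expected per the docstring above.
    normalized_keys_to_load = ['external_quantizers.ResNet/ReLU[relu]/RELU_0.signed_tensor']
    normalized_model_keys = ['external_quantizers.ResNet/ReLU[relu]/relu_0.signed_tensor']
    # Expected result ({checkpoint key: model key}):
    # {'external_quantizers.ResNet/ReLU[relu]/RELU_0.signed_tensor':
    #  'external_quantizers.ResNet/ReLU[relu]/relu_0.signed_tensor'}
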
Example #9
 def kd_loss_fn(teacher_output: torch.Tensor,
                student_output: torch.Tensor):
     if len(student_output.shape) != 2 or len(
             teacher_output.shape) != 2:
         nncf_logger.debug(
             "Incompatible number of dimensions of the model output tensor for softmax KD "
             "(student - {}, teacher - {}, number of dims {} should be == 2)"
             " - ignoring!".format(student_output.shape,
                                   teacher_output.shape,
                                   len(student_output.shape)))
         return torch.zeros([1]).to(student_output.device)
     return scale * -(nn.functional.log_softmax(student_output / temperature, dim=1) *
                      nn.functional.softmax(teacher_output / temperature, dim=1)).mean() \
            * (student_output.shape[1] * temperature * temperature)
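
As in the MSE variant, `scale` and `temperature` come from the enclosing scope. The returned value is the soft cross-entropy between the temperature-softened teacher and student distributions: .mean() averages over batch and classes, so multiplying by the number of classes and by temperature squared restores the conventional per-sample scaling of distillation losses. A self-contained sketch with the closure variables turned into parameters (hypothetical factory name):

    import torch
    from torch import nn

    def make_softmax_kd_loss(scale: float = 1.0, temperature: float = 1.0):  # hypothetical factory
        def kd_loss_fn(teacher_output: torch.Tensor, student_output: torch.Tensor):
            if student_output.dim() != 2 or teacher_output.dim() != 2:
                return torch.zeros([1]).to(student_output.device)
            soft_ce = -(nn.functional.log_softmax(student_output / temperature, dim=1) *
                        nn.functional.softmax(teacher_output / temperature, dim=1)).mean()
            return scale * soft_ce * student_output.shape[1] * temperature * temperature
        return kd_loss_fn
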
Example #10
    def save_first_iteration_node(self, inputs: 'OperatorInput',
                                  node: DynamicGraphNode):
        """
        Finds and saves the "starting" points of an iteration for further matching on the next
        iteration, instead of adding new nodes for each iteration. "Starting" points of an iteration
        are nodes
            * that have at least one input node outside of the iteration scope,
            * or none of whose inputs is a TracedTensor.
        """
        op_exec_context = node.op_exec_context
        name = str(node)
        iter_scopes = op_exec_context.scope_in_model.get_iteration_scopes()
        if iter_scopes:
            for iter_scope in iter_scopes:
                if iter_scope not in self._first_iteration_nodes:
                    self._first_iteration_nodes[iter_scope] = {}
                first_nodes = self._first_iteration_nodes[iter_scope]
                has_input_outside_iteration = False
                untraced_tensor_inputs = []
                traced_tensor_inputs = []
                non_tensor_inputs = []
                for i in inputs:
                    input_obj = i.getter()
                    if isinstance(input_obj, Tensor):
                        if not isinstance(input_obj, TracedTensor):
                            untraced_tensor_inputs.append(input_obj)
                        else:
                            traced_tensor_inputs.append(input_obj)
                    else:
                        non_tensor_inputs.append(input_obj)

                for i in traced_tensor_inputs:
                    creator_id = i.tensor_meta.creator_id
                    creator_node = self.get_node_by_id(creator_id)
                    creator_node_op_exec_ctx = creator_node[
                        DynamicGraph.OP_EXEC_CONTEXT_NODE_ATTR]
                    within_scopes = creator_node_op_exec_ctx.scope_in_model.get_iteration_scopes(
                    )
                    if iter_scope not in within_scopes:
                        has_input_outside_iteration = True

                if len(untraced_tensor_inputs) == (len(inputs) -
                                                   len(non_tensor_inputs)):
                    has_input_outside_iteration = True
                if has_input_outside_iteration:
                    node_name = str(op_exec_context.op_address)
                    first_nodes[node_name] = node
                    nncf_logger.debug(
                        'Found first iteration node: {} in scope: {}'.format(
                            name, iter_scope))
Example #11
 def get_containing_module(self,
                           node_name: NNCFNodeName) -> torch.nn.Module:
     if self._compressed_graph is not None:
         try:
             scope = self._compressed_graph.get_scope_by_node_name(
                 node_name)
         except RuntimeError:
             nncf_logger.debug(
                 "Node {} not found in compressed graph when trying to determine containing module, "
                 "trying the original graph to see if the node was present there "
                 "during graph building")
             scope = self._original_graph.get_scope_by_node_name(node_name)
     else:
         scope = self._original_graph.get_scope_by_node_name(node_name)
     return self.get_module_by_scope(scope)
Example #12
    def select_qconfigs(self, qp_id_vs_selected_qconfig_dict: Dict[QuantizationPointId, QuantizerConfig],
                        strict: bool = True) -> \
            SingleConfigQuantizerSetup:
        retval = SingleConfigQuantizerSetup()
        retval.unified_scale_groups = deepcopy(self.unified_scale_groups)
        retval.shared_input_operation_set_groups = deepcopy(self.shared_input_operation_set_groups)

        if Counter(qp_id_vs_selected_qconfig_dict.keys()) != Counter(self.quantization_points.keys()):
            raise ValueError("The set of quantization points for a selection is inconsistent with quantization"
                             "points in the quantizer setup!")
        for qp_id, qp in self.quantization_points.items():
            if strict:
                retval.quantization_points[qp_id] = qp.select_qconfig(
                    qp_id_vs_selected_qconfig_dict[qp_id]
                )
            else:
                multi_qp = qp
                qconfig = qp_id_vs_selected_qconfig_dict[qp_id]
                retval.quantization_points[qp_id] = SingleConfigQuantizationPoint(
                    multi_qp.insertion_point, qconfig,
                    multi_qp.directly_quantized_operator_node_names)

        # Segregate the unified scale groups into sub-groups based on what exact config was chosen.
        for us_group in self.unified_scale_groups.values():
            per_channel_qids = set()
            per_tensor_qids = set()
            for us_qid in us_group:
                final_qconfig = retval.quantization_points[us_qid].qconfig
                if final_qconfig.per_channel:
                    per_channel_qids.add(us_qid)
                else:
                    per_tensor_qids.add(us_qid)

            if per_tensor_qids:
                for qid in per_tensor_qids:
                    retval.remove_unified_scale_from_point(qid)

                retval.register_unified_scale_group(list(per_tensor_qids))

            for per_channel_qid in per_channel_qids:
                us_type = self._unified_scale_qpid_vs_type[per_channel_qid]
                if us_type is UnifiedScaleType.UNIFY_ONLY_PER_TENSOR:
                    nncf_logger.debug("Per-channel quantizer config selected in a MultiConfigQuantizerSetup for a "
                                      "unified scale point that only supports per-tensor scale unification, disabling "
                                      "unified scales for this point.")
                retval.remove_unified_scale_from_point(per_channel_qid)

        return retval
Example #13
    def _set_binary_masks_for_pruned_layers_groupwise(self,
                                                      pruning_level: float):
        nncf_logger.debug('Setting new binary masks for pruned layers.')
        wrapped_layers = collect_wrapped_layers(self._model)

        # 0. Removing masks at the elements of the NNCFGraph
        for node in self._original_graph.topological_sort():
            node.data.pop('output_mask', None)

        # 1. Calculate masks
        for group in self._pruned_layer_groups_info.get_all_clusters():
            # a. Calculate the cumulative importance for all filters in the group
            cumulative_filters_importance = self._calculate_filters_importance_in_group(
                group)
            filters_num = len(cumulative_filters_importance)

            # b. Calculate threshold
            num_of_sparse_elems = get_rounded_pruned_element_number(
                cumulative_filters_importance.shape[0], pruning_level)
            threshold = sorted(cumulative_filters_importance)[min(
                num_of_sparse_elems, filters_num - 1)]

            # c. Initialize masks
            filter_mask = calculate_binary_mask(cumulative_filters_importance,
                                                threshold)
            for node in group.elements:
                nncf_node = self._original_graph.get_node_by_id(
                    node.nncf_node_id)
                nncf_node.data['output_mask'] = TFNNCFTensor(filter_mask)

        # 2. Propagating masks across the graph
        mask_propagator = MaskPropagationAlgorithm(
            self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
            TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # 3. Apply masks to the model
        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes if layer.name == n.layer_name
            ][0]
            if nncf_node.data['output_mask'] is not None:
                self._set_operation_masks([layer],
                                          nncf_node.data['output_mask'].tensor)

        # Calculate actual flops and weights number with new masks
        self._update_benchmark_statistics()
Example #14
    def _set_binary_masks_for_pruned_modules_globally(
            self, pruning_level: float) -> None:
        """
        Set the binary mask values for layer groups according to the global pruning level.
        Filter importance scores in each group are merged into a single global list and a
        threshold value separating the pruning_level proportion of the least important filters
        in the model is calculated. Filters are pruned globally according to the threshold value.
        """
        nncf_logger.debug(
            "Setting new binary masks for all pruned modules together.")
        filter_importances = []
        # 1. Calculate importances for all groups of filters
        for group in self.pruned_module_groups_info.get_all_clusters():
            filters_num = torch.tensor(
                [get_filters_num(minfo.module) for minfo in group.elements])
            assert torch.all(filters_num == filters_num[0])
            device = group.elements[0].module.weight.device

            cumulative_filters_importance = torch.zeros(
                filters_num[0]).to(device)
            # Calculate cumulative importance for all filters in this group
            for minfo in group.elements:
                normalized_weight = self.weights_normalizer(
                    minfo.module.weight)
                filters_importance = self.filter_importance(
                    normalized_weight,
                    minfo.module.target_weight_dim_for_compression)
                cumulative_filters_importance += filters_importance

            filter_importances.append(cumulative_filters_importance)

        # 2. Calculate one threshold for all weights
        importances = torch.cat(filter_importances)
        threshold = sorted(importances)[int(pruning_level *
                                            importances.size(0))]

        # 3. Set binary masks for filters in groups
        for i, group in enumerate(
                self.pruned_module_groups_info.get_all_clusters()):
            mask = calculate_binary_mask(filter_importances[i], threshold)
            for minfo in group.elements:
                pruning_module = minfo.operand
                pruning_module.binary_filter_pruning_mask = mask

        # Calculate actual flops and weights number with new masks
        self._update_benchmark_statistics()
Example #15
def get_filters_num(layer: NNCFWrapper):
    layer_metatype = get_keras_layer_metatype(layer)
    if len(layer_metatype.weight_definitions) != 1:
        raise ValueError(f'Could not calculate the number of filters '
                         f'for the layer {layer.layer.name}.')

    weight_def = layer_metatype.weight_definitions[0]
    weight_attr = weight_def.weight_attr_name

    if layer_metatype is TFBatchNormalizationLayerMetatype and not layer.layer.scale:
        nncf_logger.debug(
            'Fused gamma parameter encountered in BatchNormalization layer. '
            'Using the beta parameter instead to calculate the number of filters.')
        weight_attr = 'beta'

    filter_axis = get_filter_axis(layer, weight_attr)
    filters_num = layer.layer_weights[weight_attr].shape[filter_axis]
    return filters_num
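
The "filter axis" resolved by get_filter_axis is the kernel dimension that indexes output channels. For a standard Keras Conv2D the kernel has shape (kernel_h, kernel_w, in_channels, out_channels), so the filter count is the size of the last axis; a quick check outside of NNCF:

    import tensorflow as tf

    layer = tf.keras.layers.Conv2D(filters=64, kernel_size=3)
    layer.build((None, 32, 32, 3))
    kernel = layer.get_weights()[0]
    print(kernel.shape)      # (3, 3, 3, 64) - filters live on the last axis
    print(kernel.shape[-1])  # 64
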
Example #16
 def apply_minmax_init(self,
                       min_values,
                       max_values,
                       log_module_name: str = None):
     """min_values and max_values must have the same shape as specified in self.scale_shape"""
     if self.initialized:
         nncf_logger.debug(
             "Skipped initializing {} - loaded from checkpoint".format(
                 log_module_name))
         return
     if torch.any(torch.eq(min_values, np.inf)) or torch.any(
             torch.eq(max_values, -np.inf)):
         raise AttributeError(
              'Statistics have not been collected for {}'.format(log_module_name))
     own_device = next(self.parameters()).device
     min_values = min_values.to(own_device)
     max_values = max_values.to(own_device)
     self._apply_minmax_init(min_values, max_values, log_module_name)
Example #17
    def _propagate_masks(self):
        nncf_logger.debug("Propagating pruning masks")
        # 1. Propagate masks for all modules
        graph = self.model.get_original_graph()

        init_output_masks_in_graph(
            graph, self.pruned_module_groups_info.get_all_nodes())
        MaskPropagationAlgorithm(
            graph, PT_PRUNING_OPERATOR_METATYPES,
            PTNNCFPruningTensorProcessor).mask_propagation()

        # 2. Set the masks for Batch/Group Norms
        pruned_node_modules = []
        for node, pruning_block, node_module in self._pruned_norms_operators:
            if node_module not in pruned_node_modules:
                # Setting masks for BN nodes
                pruning_block.binary_filter_pruning_mask = node.data[
                    'output_mask'].tensor
                pruned_node_modules.append(node_module)
Example #18
    def _set_binary_masks_for_pruned_modules_groupwise(
            self, pruning_level: Union[float, Dict[int, float]]) -> None:
        """
        Set the binary mask values according to groupwise pruning level.
        If pruning_level is a float, set the pruning level uniformly across groups.
        If pruning_level is a dict, set specific pruning levels corresponding to each group.
        """
        nncf_logger.debug("Updating binary masks for pruned modules.")
        groupwise_pruning_levels_set = isinstance(pruning_level, dict)

        for group in self.pruned_module_groups_info.get_all_clusters():
            group_pruning_level = pruning_level[group.id] if groupwise_pruning_levels_set \
                else pruning_level

            filters_num = torch.tensor(
                [get_filters_num(minfo.module) for minfo in group.elements])
            assert torch.all(filters_num == filters_num[0])
            device = group.elements[0].module.weight.device

            cumulative_filters_importance = torch.zeros(
                filters_num[0]).to(device)
            # 1. Calculate cumulative importance for all filters in group
            for minfo in group.elements:
                filters_importance = self.filter_importance(
                    minfo.module.weight,
                    minfo.module.target_weight_dim_for_compression)
                cumulative_filters_importance += filters_importance

            # 2. Calculate threshold
            num_of_sparse_elems = get_rounded_pruned_element_number(
                cumulative_filters_importance.size(0), group_pruning_level)
            threshold = sorted(cumulative_filters_importance)[min(
                num_of_sparse_elems, filters_num[0] - 1)]
            mask = calculate_binary_mask(cumulative_filters_importance,
                                         threshold)

            # 3. Set binary masks for filter
            for minfo in group.elements:
                pruning_module = minfo.operand
                pruning_module.binary_filter_pruning_mask = mask

        # Calculate actual flops and weights number with new masks
        self._update_benchmark_statistics()
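
The per-group threshold step depends on get_rounded_pruned_element_number, which turns the group pruning level into an integer count of filters to drop, and on the min(..., filters_num - 1) guard that keeps at least one filter alive. A toy run of that arithmetic, assuming the helper simply rounds n * level and that calculate_binary_mask keeps filters with importance >= threshold:

    filters_num = 8
    group_pruning_level = 0.4
    importance = [0.7, 0.1, 0.9, 0.3, 0.5, 0.2, 0.8, 0.6]
    k = min(round(filters_num * group_pruning_level), filters_num - 1)  # 3 filters to prune
    threshold = sorted(importance)[k]                                   # 0.5
    mask = [1.0 if v >= threshold else 0.0 for v in importance]
    print(mask)  # [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0] - the 3 least important filters are pruned
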
Example #19
 def _get_operations_with_attribute_values(self, attribute_name_vs_required_value: Dict[str, Any]) -> \
     Set[Type[OperatorMetatype]]:
     result = set()
     for op_dict in self:
         if self.ATTRIBUTES_NAME not in op_dict:
             continue
         for attr_name, attr_value in attribute_name_vs_required_value.items(
         ):
             is_value_matched = op_dict[
                 self.ATTRIBUTES_NAME][attr_name] == attr_value
             is_attr_set = attr_name in op_dict[self.ATTRIBUTES_NAME]
             if is_value_matched and is_attr_set:
                 hw_config_op_name = op_dict.type
                 metatypes = self._get_metatypes_for_hw_config_op(
                     hw_config_op_name)
                 if not metatypes:
                     nncf_logger.debug(
                         'Operation name {} in HW config is not registered in NNCF under any supported '
                         'operation metatype - will be ignored'.format(
                             hw_config_op_name))
                 result.update(metatypes)
     return result
Example #20
    def get_metatype_vs_quantizer_configs_map(
        self,
        for_weights=False
    ) -> Dict[Type[OperatorMetatype], Optional[List[QuantizerConfig]]]:
        # 'None' for ops unspecified in HW config, empty list for wildcard quantization ops
        retval = {
            k: None
            for k in self._get_available_operator_metatypes_for_matching()
        }
        config_key = 'weights' if for_weights else 'activations'
        for op_dict in self:
            hw_config_op_name = op_dict.type

            metatypes = self._get_metatypes_for_hw_config_op(hw_config_op_name)
            if not metatypes:
                nncf_logger.debug(
                    'Operation name {} in HW config is not registered in NNCF under any supported operation '
                    'metatype - will be ignored'.format(hw_config_op_name))

            if self.QUANTIZATION_ALGORITHM_NAME in op_dict:
                allowed_qconfs = op_dict[
                    self.QUANTIZATION_ALGORITHM_NAME][config_key]
            else:
                allowed_qconfs = []

            qconf_list_with_possible_duplicates = []
            for hw_config_qconf_dict in allowed_qconfs:
                qconf_list_with_possible_duplicates.append(
                    self.get_qconf_from_hw_config_subdict(
                        hw_config_qconf_dict))

            qconf_list = list(
                OrderedDict.fromkeys(qconf_list_with_possible_duplicates))

            for meta in metatypes:
                retval[meta] = qconf_list

        return retval
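
Per the comment at the top of the method, the returned map distinguishes three cases: None (the operation is not mentioned in the HW config at all), an empty list (the operation is listed without explicit quantization configs, i.e. wildcard), and a non-empty list of allowed configs. A sketch of how a caller might interpret the map (assumed semantics based on that comment, not an NNCF API):

    def describe(metatype, metatype_vs_qconfigs):
        qconfs = metatype_vs_qconfigs[metatype]
        if qconfs is None:
            return 'not in the HW config - fall back to default quantization behaviour'
        if not qconfs:
            return 'wildcard - any quantizer config may be used'
        return 'restricted to one of {} allowed quantizer configs'.format(len(qconfs))
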
Example #21
    def statistics(self,
                   quickly_collected_only: bool = False) -> NNCFStatistics:
        if not quickly_collected_only and is_debug():
            stats = PrunedModelTheoreticalBorderline(self._pruned_layers_num,
                                                     self._prunable_layers_num,
                                                     self._max_prunable_flops,
                                                     self._max_prunable_params,
                                                     self.full_flops,
                                                     self.full_params_num)

            nncf_logger.debug(stats.to_str())

        pruned_layers_summary = {}
        for minfo in self.pruned_module_groups_info.get_all_nodes():
            layer_name = str(minfo.module_scope)
            if layer_name not in pruned_layers_summary:
                pruned_layers_summary[layer_name] = \
                    PrunedLayerSummary(layer_name,
                                       list(minfo.module.weight.size()),
                                       list(self.mask_shape(minfo)),
                                       self.pruning_level_for_mask(minfo))

        self._update_benchmark_statistics()
        model_statistics = PrunedModelStatistics(
            self.full_flops, self.current_flops, self.full_params_num,
            self.current_params_num, self.full_filters_num,
            self.current_filters_num, list(pruned_layers_summary.values()))

        stats = FilterPruningStatistics(model_statistics,
                                        self.scheduler.current_pruning_level,
                                        self.scheduler.target_level,
                                        self.prune_flops)

        nncf_stats = NNCFStatistics()
        nncf_stats.register('filter_pruning', stats)
        return nncf_stats
Example #22
    def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
        """
        Computes necessary model transformations (pruning mask insertions) to enable pruning.

        :param model: The original uncompressed model.
        :return: The instance of the `TransformationLayout` class containing
            a list of pruning mask insertions.
        """
        converter = TFModelConverterFactory.create(model)
        self._graph = converter.convert()
        groups_of_nodes_to_prune = self._pruning_node_selector.create_pruning_groups(self._graph)

        transformations = TFTransformationLayout()
        shared_layers = set()

        self._pruned_layer_groups_info = Clusterization[PrunedLayerInfo](lambda x: x.layer_name)

        for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
            group_minfos = []
            for node in group.elements:
                layer_name = get_layer_identifier(node)
                layer = model.get_layer(layer_name)
                group_minfos.append(PrunedLayerInfo(node.node_name, layer_name, node.node_id,
                                                    is_prunable_depthwise_conv(node)))

                # Add output_mask to elements to run mask_propagation
                # and detect spec_nodes that will be pruned.
                # It should be done for all elements of shared layer.
                node.data['output_mask'] = TFNNCFTensor(tf.ones(node.layer_attributes.out_channels))
                if layer_name in shared_layers:
                    continue
                if node.is_shared():
                    shared_layers.add(layer_name)
                # Check that we need to prune weights in this op
                assert self._is_pruned_layer(layer)
                nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

                _, layer_info = converter.get_layer_info_for_node(node.node_name)
                for weight_def in node.metatype.weight_definitions:
                    transformations.register(
                        self._get_insertion_command_binary_mask(
                            layer_info.layer_name, weight_def.weight_attr_name)
                    )
                if node.metatype.bias_attr_name is not None and \
                        getattr(layer, node.metatype.bias_attr_name) is not None:
                    transformations.register(
                        self._get_insertion_command_binary_mask(
                            layer_info.layer_name, node.metatype.bias_attr_name)
                    )

            cluster = Cluster[PrunedLayerInfo](i, group_minfos, [n.node_id for n in group.elements])
            self._pruned_layer_groups_info.add_cluster(cluster)

        # Propagating masks across the graph to detect spec_nodes that will be pruned
        mask_propagator = MaskPropagationAlgorithm(self._graph, TF_PRUNING_OPERATOR_METATYPES,
                                                   TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # Add masks for all spec modules, because prunable batchnorm layers can be determined
        # at the moment of mask propagation
        types_spec_layers = [TFBatchNormalizationLayerMetatype] \
            if self._prune_batch_norms else []

        spec_nodes = self._graph.get_nodes_by_metatypes(types_spec_layers)
        for spec_node in spec_nodes:
            layer_name = get_layer_identifier(spec_node)
            layer = model.get_layer(layer_name)
            if spec_node.data['output_mask'] is None:
                # Skip elements that will not be pruned
                continue
            if layer_name in shared_layers:
                continue
            if spec_node.is_shared():
                shared_layers.add(layer_name)
            nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

            _, layer_info = converter.get_layer_info_for_node(spec_node.node_name)
            for weight_def in spec_node.metatype.weight_definitions:
                if spec_node.metatype is TFBatchNormalizationLayerMetatype \
                        and not layer.scale and weight_def.weight_attr_name == 'gamma':
                    nncf_logger.debug('Fused gamma parameter encountered in BatchNormalization layer. '
                                      'Do not add mask to it.')
                    continue

                transformations.register(
                    self._get_insertion_command_binary_mask(
                        layer_info.layer_name, weight_def.weight_attr_name)
                )
            transformations.register(
                self._get_insertion_command_binary_mask(
                    layer_info.layer_name, spec_node.metatype.bias_attr_name)
            )
        return transformations
Example #23
 def init_with_key_list(self, key_list: List):
     self.call_counts = {key: 0 for key in key_list}
     nncf_logger.debug("{} tracker: registered {} entries".format(self.name, len(self.call_counts)))
Example #24
    def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
            self, target_flops_pruning_rate: float):
        """
        Prunes least important filters one-by-one until target FLOPs pruning rate is achieved.
        Filters are sorted by filter importance score.
        """
        nncf_logger.debug('Setting new binary masks for pruned layers.')
        target_flops = self.full_flops * (1 - target_flops_pruning_rate)
        wrapped_layers = collect_wrapped_layers(self._model)
        masks = {}  # keyed by pruning group id - see the masks[group.id] assignments below

        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes
                if layer.layer.name == get_original_name(n.node_name)
            ][0]
            nncf_node.data['output_mask'] = tf.ones(get_filters_num(layer))

        # 1. Calculate importances for all groups of filters. Initialize masks.
        filter_importances = []
        group_indexes = []
        filter_indexes = []
        for group in self._pruned_layer_groups_info.get_all_clusters():
            cumulative_filters_importance = \
                self._calculate_filters_importance_in_group(group, wrapped_layers)

            filter_importances.extend(cumulative_filters_importance)
            filters_num = len(cumulative_filters_importance)
            group_indexes.extend([group.id] * filters_num)
            filter_indexes.extend(range(filters_num))
            masks[group.id] = tf.ones(filters_num)

        # 2. Greedily prune the least important filters until the FLOPs target is reached
        tmp_in_channels = self._layers_in_channels.copy()
        tmp_out_channels = self._layers_out_channels.copy()
        sorted_importances = sorted(zip(filter_importances, group_indexes,
                                        filter_indexes),
                                    key=lambda x: x[0])
        for _, group_id, filter_index in sorted_importances:
            if self._pruning_quotas[group_id] == 0:
                continue
            masks[group_id][filter_index] = 0
            self._pruning_quotas[group_id] -= 1

            # Update input/output shapes of pruned nodes
            group = self._pruned_layer_groups_info.get_cluster_by_id(group_id)
            for node in group.nodes:
                tmp_out_channels[node.layer_name] -= 1
            for node_name in self._next_nodes[group_id]:
                tmp_in_channels[node_name] -= 1

            flops = sum(
                count_flops_for_nodes(self._original_graph,
                                      self._layers_in_shapes,
                                      self._layers_out_shapes,
                                      input_channels=tmp_in_channels,
                                      output_channels=tmp_out_channels,
                                      conv_op_types=GENERAL_CONV_LAYERS,
                                      linear_op_types=LINEAR_LAYERS).values())
            if flops <= target_flops:
                # 3. Add masks to the graph and propagate them
                for group in self._pruned_layer_groups_info.get_all_clusters():
                    for node in group.nodes:
                        nncf_node = self._original_graph.get_node_by_id(
                            node.nncf_node_id)
                        nncf_node.data['output_mask'] = masks[group.id]

                mask_propagator = MaskPropagationAlgorithm(
                    self._original_graph, TF_PRUNING_OPERATOR_METATYPES)
                mask_propagator.mask_propagation()

                # 4. Set binary masks to the model
                self.current_flops = flops
                nncf_sorted_nodes = self._original_graph.topological_sort()
                for layer in wrapped_layers:
                    nncf_node = [
                        n for n in nncf_sorted_nodes
                        if layer.layer.name == get_original_name(n.node_name)
                    ][0]
                    if nncf_node.data['output_mask'] is not None:
                        self._set_operation_masks(
                            [layer], nncf_node.data['output_mask'])
                return
        raise RuntimeError(
            f'Unable to prune model to required flops pruning rate:'
            f' {target_flops_pruning_rate}')
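
The control flow above ranks every filter globally by importance, zeroes filters out one at a time, re-estimates FLOPs after each removal, and stops as soon as the target is met. The same greedy loop on toy data, with a made-up cost function standing in for count_flops_for_nodes:

    # Toy greedy loop; cost() and the numbers are illustrative only.
    full_flops = 100.0
    target_flops_pruning_rate = 0.3
    target_flops = full_flops * (1 - target_flops_pruning_rate)  # 70.0
    importances = [0.9, 0.2, 0.5, 0.1, 0.7, 0.4]                 # one entry per filter
    mask = [1.0] * len(importances)

    def cost(mask):
        return full_flops * sum(mask) / len(mask)  # pretend FLOPs scale with the kept filters

    for _, idx in sorted((imp, i) for i, imp in enumerate(importances)):
        mask[idx] = 0.0
        if cost(mask) <= target_flops:
            break
    print(mask)  # [1.0, 0.0, 1.0, 0.0, 1.0, 1.0] - the two least important filters were pruned
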
Example #25
    def wrapped(*args, **kwargs):
        ctx = get_current_context()
        if not ctx or getattr(ctx, 'in_operator', False) or not ctx.is_tracing:
            op1 = operator(*args, **kwargs)
            return op1

        ctx.in_operator = True

        try:
            if operator_info.skip_trace:
                result = operator(*args, **kwargs)
            elif ctx.is_forwarding:
                from nncf.torch.dynamic_graph.trace_functions import forward_trace_only
                result = forward_trace_only(operator, *args, **kwargs)
            else:
                node = None
                op_name = operator_info.name
                op_address = ctx.get_caller_context(op_name)

                layer_attrs = None
                ignored_algos = []
                # Collect module attributes, if required
                if ctx.trace_dynamic_graph:
                    if op_name in OP_NAMES_REQUIRING_MODULE_ATTRS:
                        curr_module = ctx.get_current_module()
                        if curr_module is None:
                            raise RuntimeError("Operation {} requires module attributes, "
                                               "but it was executed outside any module".format(op_name))
                        layer_attrs = _get_layer_attributes(curr_module, op_name)
                        if isinstance(curr_module, _NNCFModuleMixin):
                            ignored_algos = deepcopy(curr_module.ignored_algorithms)

                ctx.register_operator_call(op_address.operator_name, op_address.scope_in_model)
                op_input = OperatorInput(list(args), kwargs)
                processed_input = ctx.execute_pre_hooks(op_address, op_input)

                if ctx.trace_dynamic_graph:
                    tensor_metas = make_tensor_metas(processed_input)
                    node = ctx.find_operator_node(tensor_metas, op_address)

                args = tuple(processed_input.op_args)
                kwargs = processed_input.op_kwargs
                result = operator(*args, **kwargs)

                if isinstance(result, type(NotImplemented)):
                    nncf_logger.debug("Operation {} returned NotImplemented".format(op_name))
                elif ctx.trace_dynamic_graph and node is None:
                    node = ctx.maybe_add_node(processed_input, tensor_metas, op_address, layer_attrs, ignored_algos)

                if is_debug() and ctx.trace_dynamic_graph and node is not None:
                    ctx.register_node_call(node)
                result = trace_tensors(result, node)
                result = ctx.execute_post_hooks(op_address, result)
        except:
            # Looks like the __repr__ call made during IDE debug to display tensor contents does not exit properly,
            # but instead throws an exception. This try...except block handles such a situation.
            # Otherwise the context is stuck in the "in_operator == True" state.
            ctx.in_operator = False
            raise

        ctx.in_operator = False
        return result
Example #26
    def add_node(self,
                 op_exec_context: OperationExecutionContext,
                 inputs,
                 layer_attrs: BaseLayerAttributes = None,
                 ignored_algorithms: List[str] = None,
                 is_in_iteration_scope: bool = False) -> DynamicGraphNode:
        node_id = len(self._node_id_to_key_dict)

        name_parts = (str(op_exec_context.scope_in_model),
                      op_exec_context.operator_name)
        node_key = '{idx} {uri}'.format(uri='/'.join(name_parts), idx=node_id)

        nncf_logger.debug("New node added to NNCF graph: {}".format(node_key))

        self._node_id_to_key_dict[node_id] = node_key
        attrs = {
            DynamicGraph.ID_NODE_ATTR: node_id,
            DynamicGraph.KEY_NODE_ATTR: node_key,
            DynamicGraph.OP_EXEC_CONTEXT_NODE_ATTR: op_exec_context,
            DynamicGraph.IS_IN_ITERATION_SCOPE_NODE_ATTR: is_in_iteration_scope
        }
        if layer_attrs is not None:
            attrs[DynamicGraph.LAYER_ATTRIBUTES] = layer_attrs

        if ignored_algorithms is not None:
            attrs[DynamicGraph.IGNORED_ALGOS_NODE_ATTR] = ignored_algorithms
        else:
            attrs[DynamicGraph.IGNORED_ALGOS_NODE_ATTR] = []

        self._nx_graph.add_node(node_key, **attrs)

        has_traced_inputs = False
        for i, info in enumerate(op_exec_context.tensor_metas):
            if info is None or info.creator_id is None:
                continue
            parent = self._node_id_to_key_dict[info.creator_id]
            self._nx_graph.add_edge(parent, node_key)
            has_traced_inputs = True
            self._nx_graph.edges[parent, node_key][
                DynamicGraph.ACTIVATION_SHAPE_EDGE_ATTR] = info.shape
            self._nx_graph.edges[parent, node_key][
                DynamicGraph.INPUT_PORT_ID_EDGE_ATTR] = i
            self._nx_graph.edges[parent, node_key][
                DynamicGraph.OUTPUT_PORT_ID_EDGE_ATTR] = info.index
            self._nx_graph.edges[parent, node_key][
                DynamicGraph.ACTIVATION_DTYPE_EDGE_ATTR] = info.dtype

        nx_node_dict = self._nx_graph.nodes[node_key]
        node = DynamicGraphNode(
            node_id=nx_node_dict[DynamicGraph.ID_NODE_ATTR],
            node_key=nx_node_dict[DynamicGraph.KEY_NODE_ATTR],
            layer_attributes=nx_node_dict.get(DynamicGraph.LAYER_ATTRIBUTES),
            op_exec_context=nx_node_dict[
                DynamicGraph.OP_EXEC_CONTEXT_NODE_ATTR],
            ignored_algorithms=nx_node_dict[
                DynamicGraph.IGNORED_ALGOS_NODE_ATTR],
            is_in_iteration_scope=nx_node_dict[
                DynamicGraph.IS_IN_ITERATION_SCOPE_NODE_ATTR])

        if not has_traced_inputs:
            self._inputless_nodes[node_key] = node

        return node
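
The node key is just the node id followed by the slash-joined "address" of the operation, so repeated calls of the same op in the same scope still produce distinct keys. With a hypothetical scope string:

    name_parts = ('ResNet/Sequential[layer1]/BasicBlock[0]', 'conv2d')  # hypothetical scope and op name
    node_id = 5
    node_key = '{idx} {uri}'.format(uri='/'.join(name_parts), idx=node_id)
    print(node_key)  # 5 ResNet/Sequential[layer1]/BasicBlock[0]/conv2d
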
Example #27
def wrap_operator(operator, operator_info: 'PatchedOperatorInfo'):
    """
    Wraps the input callable object (`operator`) with the functionality that allows the calls to this object
    to be tracked by the currently set global TracingContext. The wrapped functions can be then intercepted,
    their arguments and return values modified arbitrarily and, for functions that correspond to operations on
    tensors in a DNN, their general position and address in the DNN's model control flow graph can be established.

    :param operator: A callable object to be wrapped.
    :param operator_info: A `PatchedOperatorInfo` struct containing the specifics of wrapping
        the `operator` in question.

    :return: The wrapped version of `operator` that, without a TracingContext, performs functionally the same as
             the unwrapped version, but within a TracingContext is able to be tracked and hooked.
    """
    # do not wrap function twice
    _orig_op = getattr(operator, '_original_op', None)
    if _orig_op is not None:
        nncf_logger.debug("Operator: {} is already wrapped".format(_orig_op.__name__))
        return operator

    def wrapped(*args, **kwargs):
        ctx = get_current_context()
        if not ctx or getattr(ctx, 'in_operator', False) or not ctx.is_tracing:
            op1 = operator(*args, **kwargs)
            return op1

        ctx.in_operator = True

        try:
            if operator_info.skip_trace:
                result = operator(*args, **kwargs)
            elif ctx.is_forwarding:
                from nncf.torch.dynamic_graph.trace_functions import forward_trace_only
                result = forward_trace_only(operator, *args, **kwargs)
            else:
                node = None
                op_name = operator_info.name
                op_address = ctx.get_caller_context(op_name)

                layer_attrs = None
                ignored_algos = []
                # Collect module attributes, if required
                if ctx.trace_dynamic_graph:
                    if op_name in OP_NAMES_REQUIRING_MODULE_ATTRS:
                        curr_module = ctx.get_current_module()
                        if curr_module is None:
                            raise RuntimeError("Operation {} requires module attributes, "
                                               "but it was executed outside any module".format(op_name))
                        layer_attrs = _get_layer_attributes(curr_module, op_name)
                        if isinstance(curr_module, _NNCFModuleMixin):
                            ignored_algos = deepcopy(curr_module.ignored_algorithms)

                ctx.register_operator_call(op_address.operator_name, op_address.scope_in_model)
                op_input = OperatorInput(list(args), kwargs)
                processed_input = ctx.execute_pre_hooks(op_address, op_input)

                if ctx.trace_dynamic_graph:
                    tensor_metas = make_tensor_metas(processed_input)
                    node = ctx.find_operator_node(tensor_metas, op_address)

                args = tuple(processed_input.op_args)
                kwargs = processed_input.op_kwargs
                result = operator(*args, **kwargs)

                if isinstance(result, type(NotImplemented)):
                    nncf_logger.debug("Operation {} returned NotImplemented".format(op_name))
                elif ctx.trace_dynamic_graph and node is None:
                    node = ctx.maybe_add_node(processed_input, tensor_metas, op_address, layer_attrs, ignored_algos)

                if is_debug() and ctx.trace_dynamic_graph and node is not None:
                    ctx.register_node_call(node)
                result = trace_tensors(result, node)
                result = ctx.execute_post_hooks(op_address, result)
        except:
            # Looks like the __repr__ call made during IDE debug to display tensor contents does not exit properly,
            # but instead throws an exception. This try...except block handles such a situation.
            # Otherwise the context is stuck in the "in_operator == True" state.
            ctx.in_operator = False
            raise

        ctx.in_operator = False
        return result

    # pylint: disable=protected-access
    wrapped._original_op = operator
    wrapped._operator_namespace = operator_info.operator_namespace
    return wrapped
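
The `_original_op` attribute is what makes the wrapping idempotent: a second wrap_operator call sees the marker and returns the callable unchanged. The same guard pattern, stripped of the tracing machinery and with hypothetical names (a simplified analogue, not the NNCF implementation):

    def wrap_once(fn):
        if getattr(fn, '_original_op', None) is not None:
            return fn  # already wrapped - return as-is
        def wrapped(*args, **kwargs):
            # ... tracing hooks would go here ...
            return fn(*args, **kwargs)
        wrapped._original_op = fn
        return wrapped

    f = wrap_once(abs)
    assert wrap_once(f) is f  # wrapping twice is a no-op
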
Example #28
    def apply_init(self) -> SingleConfigQuantizerSetup:
        if not self._weight_quantizations_by_execution_order:
            return self._algo.get_quantizer_setup_for_current_state()

        original_device = next(self._model.parameters()).device
        self._model.to(self._init_device)

        traces_per_layer = self._calc_traces(self._criterion_fn,
                                             self._criterion,
                                             self._iter_number,
                                             self._tolerance)
        if not traces_per_layer:
            raise RuntimeError('Failed to calculate hessian traces!')

        traces_order = traces_per_layer.traces_order
        weight_qconfig_sequences_in_trace_order, covering_qconfig_sequences = \
            self.get_qconfig_sequences_constrained_by_traces_order(traces_order)

        weight_quantizer_ids_in_execution_order = list(
            self._weight_quantizations_by_execution_order.keys())

        if not weight_qconfig_sequences_in_trace_order:
            warnings.warn(
                'All bitwidths configurations are incompatible with HW Config!',
                RuntimeWarning)
            return None

        weight_qconfig_sequences_in_trace_order = \
            self._filter_qconfig_sequences_by_excessive_bitwidth(weight_qconfig_sequences_in_trace_order)

        if self._bitwidth_assignment_mode == BitwidthAssignmentMode.STRICT:
            weight_qconfig_sequences_in_trace_order = \
                self._filter_qconfig_sequences_by_grouped_weight_quantizers(weight_qconfig_sequences_in_trace_order,
                                                                            weight_quantizer_ids_in_execution_order,
                                                                            self._groups_of_adjacent_quantizers,
                                                                            traces_order)
        if not weight_qconfig_sequences_in_trace_order:
            warnings.warn(
                'No bitwidths configurations are left after removing inconsistent groups of weight quantizers'
                ' with adjacent activation quantizers!', RuntimeWarning)
            return self._algo.get_quantizer_setup_for_current_state()

        compression_ratio_per_qconfig = self.get_compression_ratio_per_qconfig_sequence(
            weight_qconfig_sequences_in_trace_order, traces_order)
        min_ratio = min(compression_ratio_per_qconfig)
        max_ratio = max(compression_ratio_per_qconfig)
        if not min_ratio <= self._compression_ratio <= max_ratio:
            raise AttributeError(
                'Invalid compression ratio={}. Should be within range [{:.3f}, {:.3f}]'
                .format(self._compression_ratio, min_ratio, max_ratio))

        perturbations, weight_observers = self.calc_quantization_noise(
            covering_qconfig_sequences, traces_order)

        metric_per_qconfig_sequence = self.calc_hawq_metric_per_qconfig_sequence(
            weight_qconfig_sequences_in_trace_order, perturbations,
            traces_per_layer, self._init_device)

        qconfig_sequence_index = self.choose_qconfig_sequence(
            metric_per_qconfig_sequence, compression_ratio_per_qconfig,
            self._compression_ratio)
        chosen_qconfig_sequence_in_traces_order = weight_qconfig_sequences_in_trace_order[
            qconfig_sequence_index]
        chosen_qconfig_sequence_in_execution_order = traces_order.get_execution_order_configs(
            chosen_qconfig_sequence_in_traces_order)
        bitwidth_sequence = [
            qconfig.num_bits
            for qconfig in chosen_qconfig_sequence_in_execution_order
        ]
        nncf_logger.info(
            'Chosen HAWQ bitwidth sequence with ratio={:.2f}, bitwidth per weightable layer={}'
            .format(compression_ratio_per_qconfig[qconfig_sequence_index],
                    bitwidth_sequence))
        nncf_logger.debug(
            'Order of the weightable layers in the HAWQ bitwidth sequence (in descending order of average'
            ' Hessian traces) ={}'.format(traces_order))

        final_quantizer_setup = self.get_quantizer_setup_for_qconfig_sequence(
            chosen_qconfig_sequence_in_traces_order, traces_order)
        if is_debug() or self._dump_hawq_data:
            hawq_debugger = HAWQDebugger(
                weight_qconfig_sequences_in_trace_order, perturbations,
                weight_observers, traces_per_layer, self._bitwidths)
            hawq_debugger.dump_metric_MB(metric_per_qconfig_sequence)
            hawq_debugger.dump_metric_flops(metric_per_qconfig_sequence,
                                            compression_ratio_per_qconfig,
                                            qconfig_sequence_index)
            hawq_debugger.dump_avg_traces()
            hawq_debugger.dump_density_of_quantization_noise()
            hawq_debugger.dump_perturbations_ratio()
            new_ctrl, new_model = self._algo.apply_new_quantizer_setup(
                final_quantizer_setup)
            groups_of_adjacent_quantizers = new_ctrl.groups_of_adjacent_quantizers
            hawq_debugger.dump_bitwidth_graph(new_ctrl, new_model,
                                              groups_of_adjacent_quantizers)
        bitwidth_per_scope = self.get_bitwidth_per_scope(final_quantizer_setup)
        str_bw = [
            str(element)
            for element in self.get_bitwidth_per_scope(final_quantizer_setup)
        ]
        nncf_logger.info('\n'.join(
            ['\n\"bitwidth_per_scope\": [', ',\n'.join(str_bw), ']']))
        from nncf.common.utils.debug import DEBUG_LOG_DIR
        Path(DEBUG_LOG_DIR).mkdir(parents=True, exist_ok=True)
        with safe_open(Path(DEBUG_LOG_DIR) / 'bitwidth_per_scope.json',
                       "w") as outfile:
            json.dump({'bitwidth_per_scope': bitwidth_per_scope},
                      outfile,
                      indent=4,
                      sort_keys=False)
        self._model.to(original_device)
        return final_quantizer_setup
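
The compression-ratio guard in the middle of apply_init only accepts a requested ratio that lies between the smallest and largest ratios achievable by the candidate qconfig sequences; a toy illustration of that check with made-up numbers:

    compression_ratio_per_qconfig = [1.2, 1.6, 2.0, 2.9]  # one made-up ratio per candidate sequence
    requested_ratio = 1.8
    min_ratio = min(compression_ratio_per_qconfig)
    max_ratio = max(compression_ratio_per_qconfig)
    if not min_ratio <= requested_ratio <= max_ratio:
        raise AttributeError('Invalid compression ratio={}. Should be within range [{:.3f}, {:.3f}]'
                             .format(requested_ratio, min_ratio, max_ratio))
    # 1.8 lies inside [1.2, 2.9], so the check passes and a sequence can be selected.
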