Example 1
    def apply_init(self):
        self.medians = torch.ones(self.scale_shape).to(self.device)
        self.median_absolute_deviations = torch.ones(self.scale_shape).to(
            self.device)

        per_channel_history = get_per_channel_history(self.input_history,
                                                      self.scale_shape,
                                                      discard_zeros=True)
        per_channel_median = [
            np.median(channel_hist) for channel_hist in per_channel_history
        ]
        per_channel_mad = []
        for idx, median in enumerate(per_channel_median):
            # The constant factor depends on the distribution; 1.4826 assumes a normal distribution
            per_channel_mad.append(
                1.4826 * np.median(abs(per_channel_history[idx] - median)))

        numpy_median = np.asarray(per_channel_median)
        numpy_mad = np.asarray(per_channel_mad)
        median_tensor = torch.from_numpy(numpy_median).to(self.device,
                                                          dtype=torch.float)
        mad_tensor = torch.from_numpy(numpy_mad).to(self.device,
                                                    dtype=torch.float)

        median_tensor = expand_like(median_tensor, self.scale_shape)
        mad_tensor = expand_like(mad_tensor, self.scale_shape)

        nncf_logger.debug("Statistics: median={} MAD={}".format(
            get_flat_tensor_contents_string(median_tensor),
            get_flat_tensor_contents_string(mad_tensor)))
        self.quantize_module.apply_minmax_init(median_tensor - 3 * mad_tensor,
                                               median_tensor + 3 * mad_tensor,
                                               self.log_module_name)
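The 1.4826 constant above is the standard consistency factor for the median absolute deviation: for normally distributed data, MAD is about 0.6745 * sigma, so scaling by 1/0.6745 (about 1.4826) turns the MAD into a robust estimate of sigma, and median +/- 3*MAD approximates the usual mu +/- 3*sigma range while staying insensitive to outliers. A minimal standalone sketch of that statistic, independent of NNCF:

    import numpy as np

    rng = np.random.default_rng(0)
    samples = rng.normal(loc=2.0, scale=0.5, size=100_000)

    median = np.median(samples)
    # 1.4826 rescales the MAD to sigma under a normality assumption
    mad = 1.4826 * np.median(np.abs(samples - median))  # ~0.5

    low, high = median - 3 * mad, median + 3 * mad
    print(low, high)  # roughly (0.5, 3.5), i.e. mu +/- 3*sigma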
Example 2
    def add_node(self, op_exec_context: OperationExecutionContext, inputs) -> NNCFNode:
        node_id = len(self._node_id_to_key_dict)

        name_parts = (str(op_exec_context.scope_in_model), op_exec_context.operator_name)
        node_key = '{idx} {uri}'.format(uri='/'.join(name_parts), idx=node_id)

        nncf_logger.debug("New node added to NNCF graph: {}".format(node_key))

        self._node_id_to_key_dict[node_id] = node_key
        attrs = {
            NNCFGraph.ID_NODE_ATTR: node_id,
            NNCFGraph.KEY_NODE_ATTR: node_key,
            NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR: op_exec_context,
        }
        self._nx_graph.add_node(node_key, **attrs)

        has_traced_inputs = False
        for i, info in enumerate(op_exec_context.tensor_metas):
            if info is None or info.creator_id is None:
                continue
            parent = self._node_id_to_key_dict[info.creator_id]
            self._nx_graph.add_edge(parent, node_key)
            has_traced_inputs = True
            self._nx_graph.edges[parent, node_key][NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR] = info.shape
            self._nx_graph.edges[parent, node_key][NNCFGraph.IN_PORT_NAME] = i

        if not has_traced_inputs:
            self._inputless_nx_nodes[node_key] = self._nx_graph.nodes[node_key]

        return self._nx_node_to_nncf_node(self._nx_graph.nodes[node_key])
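The node-key and edge-attribute bookkeeping above maps directly onto plain networkx primitives. A toy sketch of the same pattern (the attribute names here are illustrative, not NNCF's actual constants):

    import networkx as nx

    g = nx.DiGraph()
    # Node keys are "<id> <scope>/<operator>", as in add_node() above
    parent_key = '{idx} {uri}'.format(uri='/'.join(('Model', 'conv2d')), idx=0)
    child_key = '{idx} {uri}'.format(uri='/'.join(('Model', 'relu')), idx=1)
    g.add_node(parent_key, id=0)
    g.add_node(child_key, id=1)

    # Edges carry the activation shape and the input port of the consumer
    g.add_edge(parent_key, child_key)
    g.edges[parent_key, child_key]['activation_shape'] = (1, 16, 32, 32)
    g.edges[parent_key, child_key]['in_port'] = 0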
Example 3
    def apply_minmax_init(self,
                          min_values,
                          max_values,
                          log_module_name: str = None):
        if self.initialized:
            nncf_logger.debug(
                "Skipped initializing {} - loaded from checkpoint".format(
                    log_module_name))
            return
        if torch.any(torch.eq(min_values, np.inf)) or torch.any(
                torch.eq(max_values, -np.inf)):
            raise AttributeError(
                'Statistics are not collected for {}'.format(log_module_name))
        sign = torch.any(torch.lt(min_values, 0))
        if self.signedness_to_force is not None and sign != self.signedness_to_force:
            nncf_logger.warning("Forcing signed to {} for module {}".format(
                self.signedness_to_force, log_module_name))
            sign = self.signedness_to_force
        self.signed = int(sign)

        abs_max = torch.max(torch.abs(max_values), torch.abs(min_values))
        SCALE_LOWER_THRESHOLD = 0.1
        self.scale.fill_(SCALE_LOWER_THRESHOLD)
        self.scale.masked_scatter_(torch.gt(abs_max, SCALE_LOWER_THRESHOLD),
                                   abs_max)

        nncf_logger.info("Set sign: {} and scale: {} for {}".format(
            self.signed, get_flat_tensor_contents_string(self.scale),
            log_module_name))
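In other words, the symmetric variant derives one per-channel scale as max(|min|, |max|), floored at SCALE_LOWER_THRESHOLD so that dead channels do not produce a degenerate quantization range. A standalone sketch of that logic, using torch.where in place of the masked_scatter_ call:

    import torch

    min_values = torch.tensor([-0.8, -0.01, 0.0])
    max_values = torch.tensor([0.5, 0.02, 0.05])

    abs_max = torch.max(torch.abs(max_values), torch.abs(min_values))
    SCALE_LOWER_THRESHOLD = 0.1
    # Channels whose |max| falls below the threshold keep the floor value
    scale = torch.where(abs_max > SCALE_LOWER_THRESHOLD,
                        abs_max,
                        torch.full_like(abs_max, SCALE_LOWER_THRESHOLD))
    print(scale)  # tensor([0.8000, 0.1000, 0.1000])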
Example 4
    def apply_minmax_init(self,
                          min_values,
                          max_values,
                          log_module_name: str = None):
        if self.initialized:
            nncf_logger.debug(
                "Skipped initializing {} - loaded from checkpoint".format(
                    log_module_name))
            return
        if torch.any(torch.eq(min_values, np.inf)) or torch.any(
                torch.eq(max_values, -np.inf)):
            raise AttributeError(
                'Statistics are not collected for {}'.format(log_module_name))
        ranges = max_values - min_values
        max_range = torch.max(max_values - min_values)
        eps = 1e-2
        correction = (clamp(ranges, low=eps * max_range, high=max_range) -
                      ranges) * 0.5
        self.input_range.data = (ranges + 2 * correction).data
        self.input_low.data = (min_values - correction).data

        nncf_logger.info("Set input_low: {} and input_range: {} for {}".format(
            get_flat_tensor_contents_string(self.input_low),
            get_flat_tensor_contents_string(self.input_range),
            log_module_name))
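The asymmetric variant keeps an explicit (input_low, input_range) pair instead of a single scale. The correction term widens any channel whose range is narrower than eps * max_range, splitting the widening evenly between the low and high bounds. A standalone sketch, with torch.clamp standing in for the clamp(..., low=..., high=...) helper:

    import torch

    min_values = torch.tensor([-1.0, 0.0, 0.499])
    max_values = torch.tensor([1.0, 1e-5, 0.501])

    ranges = max_values - min_values
    max_range = torch.max(ranges).item()
    eps = 1e-2
    # Widen degenerate channels up to eps * max_range, symmetrically
    correction = (torch.clamp(ranges, min=eps * max_range, max=max_range) - ranges) * 0.5
    input_range = ranges + 2 * correction  # now every channel is >= eps * max_range wide
    input_low = min_values - correction    # low bound shifted down by half the widening
    print(input_low, input_range)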
Example 5
    def _apply_masks(self):
        nncf_logger.debug("Applying pruning binary masks")

        def _apply_binary_mask_to_module_weight_and_bias(
                module, mask, module_scope):
            with torch.no_grad():
                dim = module.target_weight_dim_for_compression if isinstance(
                    module, _NNCFModuleMixin) else 0
                # Applying mask to weights
                inplace_apply_filter_binary_mask(mask, module.weight,
                                                 module_scope, dim)
                # Applying the mask to the bias too (if it exists)
                if module.bias is not None:
                    inplace_apply_filter_binary_mask(mask, module.bias,
                                                     module_scope)

        for minfo in self.pruned_module_groups_info.get_all_nodes():
            _apply_binary_mask_to_module_weight_and_bias(
                minfo.module, minfo.operand.binary_filter_pruning_mask,
                minfo.module_scope)

            # Applying mask to the BatchNorm node
            related_modules = minfo.related_modules
            if related_modules is not None and PrunedModuleInfo.BN_MODULE_NAME in related_modules \
                    and related_modules[PrunedModuleInfo.BN_MODULE_NAME].module is not None:
                bn_module = related_modules[
                    PrunedModuleInfo.BN_MODULE_NAME].module
                _apply_binary_mask_to_module_weight_and_bias(
                    bn_module, minfo.operand.binary_filter_pruning_mask,
                    minfo.module_scope)
Example 6
    def _apply_masks(self):
        nncf_logger.debug("Applying pruning binary masks")

        def _apply_binary_mask_to_module_weight_and_bias(
                module, mask, module_name=""):
            with torch.no_grad():
                # Applying mask to weights
                inplace_apply_filter_binary_mask(mask, module.weight,
                                                 module_name)
                # Applying the mask to the bias too (if it exists)
                if module.bias is not None:
                    inplace_apply_filter_binary_mask(mask, module.bias,
                                                     module_name)

        for minfo in self.pruned_module_info:
            _apply_binary_mask_to_module_weight_and_bias(
                minfo.module, minfo.operand.binary_filter_pruning_mask,
                minfo.module_name)

            # Applying mask to the BatchNorm node
            related_modules = minfo.related_modules
            if related_modules is not None and PrunedModuleInfo.BN_MODULE_NAME in related_modules \
                    and related_modules[PrunedModuleInfo.BN_MODULE_NAME] is not None:
                bn_module = related_modules[PrunedModuleInfo.BN_MODULE_NAME]
                _apply_binary_mask_to_module_weight_and_bias(
                    bn_module, minfo.operand.binary_filter_pruning_mask)
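Both variants delegate the actual zeroing to inplace_apply_filter_binary_mask. A guess at its core behavior (the real helper lives in NNCF and also handles logging and shape validation): broadcast the per-filter 0/1 mask along the filter dimension and multiply the parameter in place.

    import torch

    weight = torch.randn(4, 3, 3, 3)       # Conv2d weight: [out_ch, in_ch, kH, kW]
    bias = torch.randn(4)
    mask = torch.tensor([1., 0., 1., 1.])  # filter 1 is pruned

    def apply_filter_mask_(tensor, mask, dim=0):
        # Reshape the mask so it broadcasts along `dim` of the tensor
        shape = [1] * tensor.dim()
        shape[dim] = mask.size(0)
        tensor.mul_(mask.view(shape))

    with torch.no_grad():
        apply_filter_mask_(weight, mask)
        apply_filter_mask_(bias, mask)  # bias is 1-D, the mask applies directly

    assert torch.all(weight[1] == 0) and bias[1] == 0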
Example 7
    def apply_init(self):
        self.min_values = torch.ones(self.scale_shape).to(self.device) * np.inf
        self.max_values = torch.ones(self.scale_shape).to(
            self.device) * (-np.inf)

        per_channel_history = get_per_channel_history(self.input_history,
                                                      self.scale_shape)
        per_channel_min_percentiles = [
            np.percentile(channel_hist, self.min_percentile)
            for channel_hist in per_channel_history
        ]
        per_channel_max_percentiles = [
            np.percentile(channel_hist, self.max_percentile)
            for channel_hist in per_channel_history
        ]

        numpy_mins = np.asarray(per_channel_min_percentiles)
        numpy_maxs = np.asarray(per_channel_max_percentiles)
        mins_tensor = torch.from_numpy(numpy_mins).to(self.device,
                                                      dtype=torch.float)
        maxs_tensor = torch.from_numpy(numpy_maxs).to(self.device,
                                                      dtype=torch.float)

        mins_tensor = expand_like(mins_tensor, self.scale_shape)
        maxs_tensor = expand_like(maxs_tensor, self.scale_shape)

        nncf_logger.debug("Statistics: Min ({}%th) percentile = {},"
                          " Max ({}%th) percentile = {}".format(
                              self.min_percentile,
                              get_flat_tensor_contents_string(mins_tensor),
                              self.max_percentile,
                              get_flat_tensor_contents_string(maxs_tensor)))
        self.quantize_module.apply_minmax_init(mins_tensor, maxs_tensor,
                                               self.log_module_name)
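A standalone numpy sketch of the per-channel percentile statistics used above; the [0.1, 99.9] percentile pair is illustrative, the actual values come from self.min_percentile / self.max_percentile:

    import numpy as np

    rng = np.random.default_rng(0)
    per_channel_history = [rng.normal(size=10_000) for _ in range(3)]

    min_percentile, max_percentile = 0.1, 99.9
    mins = np.asarray([np.percentile(h, min_percentile) for h in per_channel_history])
    maxs = np.asarray([np.percentile(h, max_percentile) for h in per_channel_history])
    # Extreme outliers beyond the chosen percentiles are ignored (~ +/- 3.1 here)
    print(mins, maxs)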
Example 8
 def apply_init(self):
     nncf_logger.debug("Statistics: min={} max={}".format(
         get_flat_tensor_contents_string(self.min_values),
         get_flat_tensor_contents_string(self.max_values)))
     self.quantize_module.apply_minmax_init(self.min_values,
                                            self.max_values,
                                            self.log_module_name)
Example 9
    def _set_binary_masks_for_filters(self, pruning_rate):
        nncf_logger.debug("Setting new binary masks for pruned modules.")

        with torch.no_grad():
            for group in self.pruned_module_groups_info.get_all_clusters():
                filters_num = torch.tensor(
                    [get_filters_num(minfo.module) for minfo in group.nodes])
                assert torch.all(filters_num == filters_num[0])
                device = group.nodes[0].module.weight.device

                cumulative_filters_importance = torch.zeros(
                    filters_num[0]).to(device)
                # 1. Calculate cumulative importance for all filters in group
                for minfo in group.nodes:
                    filters_importance = self.filter_importance(
                        minfo.module.weight,
                        minfo.module.target_weight_dim_for_compression)
                    cumulative_filters_importance += filters_importance

                # 2. Calculate threshold
                num_of_sparse_elems = get_rounded_pruned_element_number(
                    cumulative_filters_importance.size(0), pruning_rate)
                threshold = sorted(cumulative_filters_importance)[min(
                    num_of_sparse_elems, filters_num[0] - 1)]
                mask = calculate_binary_mask(cumulative_filters_importance,
                                             threshold)

                # 3. Set binary masks for filter
                for minfo in group.nodes:
                    pruning_module = minfo.operand
                    pruning_module.binary_filter_pruning_mask = mask

        # Calculate actual flops with new masks
        self.current_flops = self._calculate_flops_pruned_model_by_masks()
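A standalone sketch of steps 2-3, assuming calculate_binary_mask keeps filters with importance >= threshold: sort the cumulative importances, pick the threshold at the pruning-rate cut, and share one mask across the whole group.

    import torch

    cumulative_importance = torch.tensor([0.9, 0.1, 0.5, 0.3])
    pruning_rate = 0.5
    num_to_prune = int(pruning_rate * cumulative_importance.numel())

    sorted_importance = torch.sort(cumulative_importance).values
    threshold = sorted_importance[min(num_to_prune, cumulative_importance.numel() - 1)]
    mask = (cumulative_importance >= threshold).float()
    print(mask)  # tensor([1., 0., 1., 0.]): the two least important filters are pruned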
Example 10
    def get_post_pattern_insertion_points(
            self,
            pattern: 'NNCFNodeExpression',
            omit_nodes_in_nncf_modules=False) -> List[InsertionInfo]:
        io_infos = self._original_graph.get_matching_nncf_graph_pattern_io_list(
            pattern)

        insertion_infos = []
        for io_info in io_infos:
            # The input/output is given in terms of edges, but the post-hooks are currently applied to
            # nodes. Multiple output edges in a pattern I/O info may originate from one and the same
            # node, and we have to ensure that these resolve into just one insertion point - thus the usage of "set".
            pattern_insertion_info_set = set()
            if len(io_info.output_edges) > 1:
                nncf_logger.debug(
                    "WARNING: pattern has more than one activation output")

            for nncf_node in io_info.output_nodes:
                pattern_insertion_info_set.add(
                    InsertionInfo(nncf_node.op_exec_context,
                                  is_output=True,
                                  shape_to_operate_on=None))
                # TODO: determine output shapes for output nodes to enable per-channel quantization

            # Ignore input nodes in the pattern for now, rely on the _quantize_inputs functions.
            # TODO: handle input quantization here as well

            # Since this function is currently only used for activation quantization purposes via operator
            # post-hook mechanism, we may take any edge and it will point from the same node where we will have to
            # insert a quantizer later. However, in the future the output edges may refer to activation tensors
            # with different sizes, in which case we have to insert different per-channel quantizers to
            # accommodate different trainable params if there is a difference in the channel dimension.
            # Furthermore, currently there is no distinction for single tensor output to multiple nodes and
            # multiple tensor output to multiple nodes ("chunk" operation is an example of the latter).
            # The pattern may also have unexpected outputs from a node in the middle of the pattern (see
            # "densenet121.dot" for an example of this) - need to decide what to do with that in terms
            # of quantization.
            # TODO: address the issues above.

            for nncf_edge in io_info.output_edges:
                pattern_insertion_info_set.add(
                    InsertionInfo(nncf_edge.from_node.op_exec_context,
                                  is_output=False,
                                  shape_to_operate_on=nncf_edge.tensor_shape))
            insertion_infos += list(pattern_insertion_info_set)

        insertion_infos = list(
            set(insertion_infos)
        )  # Filter the overlapping insertion points from different matches (happens for GNMT)
        insertion_infos_filtered = []

        for info in insertion_infos:
            if omit_nodes_in_nncf_modules and self.is_scope_in_nncf_module_scope(
                    info.op_exec_context.scope_in_model):
                continue
            insertion_infos_filtered.append(info)

        return insertion_infos_filtered
Example 11
    def _generate_qid_nodekey_map(
            self, quantization_controller: QuantizationController,
            quantized_network: NNCFNetwork) -> Dict[QuantizerId, str]:
        """
        Create a lookup mapping for each QuantizerId to its corresponding quantize node in network graph
        :param quantization_controller:
        :param quantized_network:
        :return: dict with key of QuantizerId and value of node key string
        """
        # Collect quantize node keys, split into weight and non-weight groups
        weight_quantize_nodekeys = []
        non_weight_quantize_nodekeys = []
        qid_nodekey_map = OrderedDict()

        quantized_network.rebuild_graph()
        g = quantized_network.get_graph()

        for nodekey in g.get_all_node_keys():
            if 'symmetric_quantize' in nodekey and 'UpdateWeight' in nodekey:
                weight_quantize_nodekeys.append(nodekey)
            if 'symmetric_quantize' in nodekey and 'UpdateWeight' not in nodekey:
                non_weight_quantize_nodekeys.append(nodekey)

        # Find nodekey of Weight Quantizer
        for qid, _ in quantization_controller.weight_quantizers.items():
            quantize_nodekeys = []
            for nodekey in weight_quantize_nodekeys:
                if str(qid.scope) in nodekey:
                    quantize_nodekeys.append(nodekey)

            if len(quantize_nodekeys) == 1:
                qid_nodekey_map[qid] = quantize_nodekeys[0]
            else:
                raise ValueError(
                    "Quantize node not found or multiple nodes found for WQid: {}"
                    .format(qid))

        # Find nodekey of Non-Weight Quantizer
        for qid, _ in quantization_controller.non_weight_quantizers.items():
            quantize_nodekeys = []
            for nodekey in non_weight_quantize_nodekeys:
                if str(qid.ia_op_exec_context.scope_in_model) in nodekey:
                    quantize_nodekeys.append(nodekey)

            if len(quantize_nodekeys) > 0:
                qid_nodekey_map[qid] = quantize_nodekeys[
                    qid.ia_op_exec_context.call_order]
            else:
                raise ValueError(
                    "Quantize node not found for NWQid: {}".format(qid))

        if logger.level == logging.DEBUG:
            for qid, nodekey in qid_nodekey_map.items():
                logger.debug("QuantizerId: {}".format(qid))
                logger.debug("\tnodekey: {}".format(nodekey))

        return qid_nodekey_map
Example 12
 def apply_init(self):
     min_values = torch.ones(self.scale_shape).to(self.device) * (-np.inf)
     max_values = torch.ones(self.scale_shape).to(self.device) * np.inf
     if self.all_min_values:
         stacked_min = torch.stack(self.all_min_values)
         min_values = stacked_min.mean(dim=0).view(self.scale_shape)
     if self.all_max_values:
         stacked_max = torch.stack(self.all_max_values)
         max_values = stacked_max.mean(dim=0).view(self.scale_shape)
     nncf_logger.debug("Statistics: min={} max={}".format(get_flat_tensor_contents_string(min_values),
                                                          get_flat_tensor_contents_string(max_values)))
     self.quantize_module.apply_minmax_init(min_values, max_values, self.log_module_name)
Example 13
 def choose_configuration(self, configuration_metric: List[Tensor], bits_configurations: List[List[int]],
                          traces_order: List[int]) -> List[int]:
     num_weights = len(traces_order)
     ordered_config = [0] * num_weights
     median_metric = torch.Tensor(configuration_metric).to(self._device).median()
     configuration_index = configuration_metric.index(median_metric)
     bit_configuration = bits_configurations[configuration_index]
     for i, bitwidth in enumerate(bit_configuration):
         ordered_config[traces_order[i]] = bitwidth
     if is_main_process():
         nncf_logger.info('Chosen HAWQ configuration (bitwidth per weightable layer)={}'.format(ordered_config))
         nncf_logger.debug('Order of the weightable layers in the HAWQ configuration={}'.format(traces_order))
     return ordered_config
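The median selection above works because torch.median returns an element of the input (the lower of the two middle values for even sizes), so list.index can locate it by tensor equality. A minimal sketch:

    import torch

    configuration_metric = [torch.tensor(0.7), torch.tensor(0.2), torch.tensor(0.5)]
    # The median is itself one of the metric tensors, so index() can find it
    median_metric = torch.stack(configuration_metric).median()
    configuration_index = configuration_metric.index(median_metric)
    print(configuration_index)  # 2, because 0.5 is the median of the metrics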
Example 14
    def _set_binary_masks_for_filters(self):
        nncf_logger.debug("Setting new binary masks for pruned modules.")

        with torch.no_grad():
            for minfo in self.pruned_module_info:
                pruning_module = minfo.operand
                # 1. Calculate importance for all filters in all weights
                # 2. Calculate thresholds for every weight
                # 3. Set binary masks for filter
                filters_importance = self.filter_importance(minfo.module.weight)
                num_of_sparse_elems = get_rounded_pruned_element_number(filters_importance.size(0),
                                                                        self.pruning_rate)
                threshold = sorted(filters_importance)[num_of_sparse_elems]
                pruning_module.binary_filter_pruning_mask = calculate_binary_mask(filters_importance, threshold)
Example 15
    def apply_init(self):
        self.medians = torch.ones(self.scale_shape).to(self.device)
        self.median_absolute_deviations = torch.ones(self.scale_shape).to(
            self.device)

        channel_count, _ = get_channel_count_and_dim_idx(self.scale_shape)
        per_channel_history = [None for i in range(channel_count)]
        while not self.input_history.empty():
            entry = self.input_history.get()
            split = split_into_channels(entry, self.scale_shape)
            for i in range(channel_count):
                flat_channel_split = split[i].flatten()

                # For post-RELU quantizers exact zeros may prevail and lead to
                # zero mean and MAD - discard them
                flat_channel_split = flat_channel_split[
                    flat_channel_split != 0]

                if per_channel_history[i] is None:
                    per_channel_history[i] = flat_channel_split
                else:
                    per_channel_history[i] = np.concatenate(
                        [per_channel_history[i], flat_channel_split])
        per_channel_median = [
            np.median(channel_hist) for channel_hist in per_channel_history
        ]
        per_channel_mad = []
        for idx, median in enumerate(per_channel_median):
            # The constant factor depends on the distribution; 1.4826 assumes a normal distribution
            per_channel_mad.append(
                1.4826 * np.median(abs(per_channel_history[idx] - median)))

        numpy_median = np.asarray(per_channel_median)
        numpy_mad = np.asarray(per_channel_mad)
        median_tensor = torch.from_numpy(numpy_median).to(self.device,
                                                          dtype=torch.float)
        mad_tensor = torch.from_numpy(numpy_mad).to(self.device,
                                                    dtype=torch.float)

        nncf_logger.debug("Statistics: median={} MAD={}".format(
            get_flat_tensor_contents_string(median_tensor),
            get_flat_tensor_contents_string(mad_tensor)))
        self.quantize_module.apply_minmax_init(median_tensor - 3 * mad_tensor,
                                               median_tensor + 3 * mad_tensor,
                                               self.is_distributed,
                                               self.log_module_name)
Example 16
 def apply_minmax_init(self,
                       min_values,
                       max_values,
                       log_module_name: str = None):
     """min_values and max_values must have the same shape as specified in self.scale_shape"""
     if self.initialized:
         nncf_logger.debug(
             "Skipped initializing {} - loaded from checkpoint".format(
                 log_module_name))
         return
     if torch.any(torch.eq(min_values, np.inf)) or torch.any(
             torch.eq(max_values, -np.inf)):
          raise AttributeError(
              'Statistics are not collected for {}'.format(log_module_name))
     own_device = next(self.parameters()).device
     min_values = min_values.to(own_device)
     max_values = max_values.to(own_device)
     self._apply_minmax_init(min_values, max_values, log_module_name)
Example 17
 def get_minmax_values(self, tensor_statistics: Dict[InsertionPoint, Dict[ReductionShape, TensorStatistic]]) -> \
         Dict[QuantizationPointId, MinMaxTensorStatistic]:
     retval = {}
     for qp_id, qp in self.quantization_points.items():
         ip = qp.insertion_point
         if ip not in tensor_statistics:
             nncf_logger.debug("IP {} not found in tensor statistics".format(ip))
             retval[qp_id] = None
         else:
             input_shape = self.quantization_points[qp_id].qconfig.input_shape
             scale_shape = tuple(get_scale_shape(input_shape,
                                                 qp.is_weight_quantization_point(),
                                                 qp.qconfig.per_channel))
              if scale_shape not in tensor_statistics[ip]:
                  nncf_logger.debug("Did not collect tensor statistics at {} for shape {}".format(ip, scale_shape))
                  retval[qp_id] = None
                  continue
              minmax_stat = MinMaxTensorStatistic.from_stat(tensor_statistics[ip][scale_shape])
              retval[qp_id] = minmax_stat
     return retval
Example 18
    def _set_binary_masks_for_all_pruned_modules(self):
        nncf_logger.debug("Setting new binary masks for all pruned modules together.")

        normalized_weights = []
        filter_importances = []
        for minfo in self.pruned_module_info:
            pruning_module = minfo.operand
            # 1. Calculate importance for all filters in all weights
            # 2. Calculate thresholds for every weight
            # 3. Set binary masks for filter
            normalized_weight = self.weights_normalizer(minfo.module.weight)
            normalized_weights.append(normalized_weight)

            filter_importances.append(self.filter_importance(normalized_weight))
        importances = torch.cat(filter_importances)
        threshold = sorted(importances)[int(self.pruning_rate * importances.size(0))]

        for i, minfo in enumerate(self.pruned_module_info):
            pruning_module = minfo.operand
            pruning_module.binary_filter_pruning_mask = calculate_binary_mask(filter_importances[i], threshold)
Example 19
 def save_first_iteration_node(self, inputs, node: NNCFNode):
     """
      Finds and saves the "starting" points of an iteration for matching against them on the next
      iteration, instead of adding new nodes for each iteration. "Starting" points of an iteration
      are nodes
          * that have at least one input node outside the iteration scope, or
          * none of whose inputs are TracedTensors
     """
     op_exec_context = node.op_exec_context
     name = node
     iter_scopes = self._get_iteration_scopes(
         op_exec_context.scope_in_model)
     if iter_scopes:
         for iter_scope in iter_scopes:
             if iter_scope not in self._first_iteration_nodes:
                 self._first_iteration_nodes[iter_scope] = {}
             first_nodes = self._first_iteration_nodes[iter_scope]
             has_input_outside_iteration = False
             not_traced_count = 0
             for i in inputs:
                 if isinstance(i, Tensor):
                     has_input_outside_iteration = True
                     break
                 if not isinstance(i, TracedTensor):
                     not_traced_count += 1
                     continue
                 creator_id = i.tensor_meta.creator_id
                 creator_node = self.get_node_by_id(creator_id)
                 creator_node_op_exec_ctx = creator_node[
                     NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR]
                 within_scopes = self._get_iteration_scopes(
                     creator_node_op_exec_ctx.scope_in_model)
                 if iter_scope not in within_scopes:
                     has_input_outside_iteration = True
             if not_traced_count == len(inputs):
                 has_input_outside_iteration = True
             if has_input_outside_iteration:
                 node_name = str(op_exec_context.input_agnostic)
                 first_nodes[node_name] = node
                 nncf_logger.debug(
                     'Found first iteration node: {} in scope: {}'.format(
                         name, iter_scope))
Example 20
    def _set_binary_masks_for_all_pruned_modules(self, pruning_rate):
        nncf_logger.debug(
            "Setting new binary masks for all pruned modules together.")
        filter_importances = []
        # 1. Calculate importances for all groups of filters
        for group in self.pruned_module_groups_info.get_all_clusters():
            filters_num = torch.tensor(
                [get_filters_num(minfo.module) for minfo in group.nodes])
            assert torch.all(filters_num == filters_num[0])
            device = group.nodes[0].module.weight.device

            cumulative_filters_importance = torch.zeros(
                filters_num[0]).to(device)
            # Calculate cumulative importance for all filters in this group
            for minfo in group.nodes:
                normalized_weight = self.weights_normalizer(
                    minfo.module.weight)
                filters_importance = self.filter_importance(
                    normalized_weight,
                    minfo.module.target_weight_dim_for_compression)
                cumulative_filters_importance += filters_importance

            filter_importances.append(cumulative_filters_importance)

        # 2. Calculate one threshold for all weights
        importances = torch.cat(filter_importances)
        threshold = sorted(importances)[int(pruning_rate *
                                            importances.size(0))]

        # 3. Set binary masks for filters in groups
        for i, group in enumerate(
                self.pruned_module_groups_info.get_all_clusters()):
            mask = calculate_binary_mask(filter_importances[i], threshold)
            for minfo in group.nodes:
                pruning_module = minfo.operand
                pruning_module.binary_filter_pruning_mask = mask

        # Calculate actual flops with new masks
        self.current_flops = self._calculate_flops_pruned_model_by_masks()
Example 21
 def init_with_key_list(self, key_list: List):
     self.call_counts = {key: 0 for key in key_list}
     nncf_logger.debug("{} tracker: registered {} entries".format(
         self.name, len(self.call_counts)))
Example 22
    def apply_init(self):
        if not self._quantizers_handler.get_weight_quantizers_in_execution_order_per_id():
            return None
        original_device = next(self._model.parameters()).device
        self._model.to(self._init_device)

        traces_per_layer = self._calc_traces(self._criterion_fn,
                                             self._criterion,
                                             self._iter_number,
                                             self._tolerance)
        if not traces_per_layer:
            raise RuntimeError('Failed to calculate hessian traces!')

        traces_order = traces_per_layer.traces_order
        num_weights = len(self._weight_quantizations_by_execution_order)
        bits_configurations = self.get_configs_constrained_by_traces_order(
            self._bits, num_weights)

        weight_quantizer_ids_in_execution_order = list(
            self._weight_quantizations_by_execution_order.keys())

        if self._bitwidth_assignment_mode == BitwidthAssignmentMode.STRICT:
            self._merge_constraints_for_adjacent_quantizers(
                self._groups_of_adjacent_quantizers,
                self._hw_precision_constraints)

        bits_configurations = self._filter_configs_by_precision_constraints(
            bits_configurations, self._hw_precision_constraints,
            weight_quantizer_ids_in_execution_order, traces_order)
        if not bits_configurations:
            warnings.warn(
                'All bits configurations are incompatible with HW Config!',
                RuntimeWarning)
            return None

        if self._bitwidth_assignment_mode == BitwidthAssignmentMode.STRICT:
            bits_configurations = \
                self._filter_configs_by_grouped_weight_quantizers(bits_configurations,
                                                                  weight_quantizer_ids_in_execution_order,
                                                                  self._groups_of_adjacent_quantizers,
                                                                  traces_order)
        if not bits_configurations:
            warnings.warn(
                'No bits configurations are left after removing inconsistent groups of weight quantizers'
                ' with adjacent activation quantizers!', RuntimeWarning)
            return None

        flops_bits_per_config = self.get_flops_bits_per_config(
            bits_configurations, traces_order)
        min_ratio = min(flops_bits_per_config)
        max_ratio = max(flops_bits_per_config)
        if not min_ratio <= self._compression_ratio <= max_ratio:
            raise AttributeError(
                'Invalid compression ratio={}. Should be within range [{:.3f}, {:.3f}]'
                .format(self._compression_ratio, min_ratio, max_ratio))

        perturbations, weight_observers = self.calc_quantization_noise()

        configuration_metric = self.calc_hawq_metric_per_configuration(
            bits_configurations, perturbations, traces_per_layer,
            self._init_device)

        config_index = self.choose_configuration(configuration_metric,
                                                 flops_bits_per_config)
        chosen_config_in_traces_order = bits_configurations[config_index]
        chosen_config_in_execution_order = traces_order.get_execution_order_config(
            chosen_config_in_traces_order)
        nncf_logger.info(
            'Chosen HAWQ configuration with ratio={:.2f}, bitwidth per weightable layer={}'
            .format(flops_bits_per_config[config_index],
                    chosen_config_in_execution_order))
        nncf_logger.debug(
            'Order of the weightable layers in the HAWQ configuration (in descending order of average '
            'Hessian traces) = {}'.format(traces_order))

        self.set_chosen_config(chosen_config_in_execution_order)
        self._model.rebuild_graph()
        if is_debug() or self._dump_hawq_data:
            hawq_debugger = HAWQDebugger(bits_configurations, perturbations,
                                         weight_observers, traces_per_layer,
                                         self._bits)
            hawq_debugger.dump_metric_MB(configuration_metric)
            hawq_debugger.dump_metric_flops(configuration_metric,
                                            flops_bits_per_config,
                                            config_index)
            hawq_debugger.dump_avg_traces()
            hawq_debugger.dump_density_of_quantization_noise()
            hawq_debugger.dump_perturbations_ratio()
            hawq_debugger.dump_bitwidth_graph(
                self._algo, self._model, self._groups_of_adjacent_quantizers)
        str_bw = [str(element) for element in self.get_bitwidth_per_scope()]
        nncf_logger.info('\n'.join(
            ['\n\"bitwidth_per_scope\": [', ',\n'.join(str_bw), ']']))

        self._model.to(original_device)

        ordered_metric_per_layer = self.get_metric_per_layer(
            chosen_config_in_execution_order, perturbations, traces_per_layer)
        return ordered_metric_per_layer